Browse Source

修改获取属性方法,增加测试类,使用try-with-resources释放资源.

聂秋秋 6 years ago
parent
commit
6ac1f19266

+ 14 - 13
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/ReflectionKit.java

@@ -196,19 +196,20 @@ public class ReflectionKit {
      * @param clazz 反射类
      */
     public static List<Field> doGetFieldList(Class<?> clazz) {
-        List<Field> fieldList = Stream.of(clazz.getDeclaredFields())
-            /* 过滤静态属性 */
-            .filter(field -> !Modifier.isStatic(field.getModifiers()))
-            /* 过滤 transient关键字修饰的属性 */
-            .filter(field -> !Modifier.isTransient(field.getModifiers()))
-            .collect(toCollection(LinkedList::new));
-        /* 处理父类字段 */
-        Class<?> superClass = clazz.getSuperclass();
-        if (superClass == null) {
-            return fieldList;
+        if (clazz.getSuperclass() != null) {
+            List<Field> fieldList = Stream.of(clazz.getDeclaredFields())
+                /* 过滤静态属性 */
+                .filter(field -> !Modifier.isStatic(field.getModifiers()))
+                /* 过滤 transient关键字修饰的属性 */
+                .filter(field -> !Modifier.isTransient(field.getModifiers()))
+                .collect(toCollection(LinkedList::new));
+            /* 处理父类字段 */
+            Class<?> superClass = clazz.getSuperclass();
+            /* 排除重载属性 */
+            return excludeOverrideSuperField(fieldList, getFieldList(superClass));
+        } else {
+            return Collections.emptyList();
         }
-        /* 排除重载属性 */
-        return excludeOverrideSuperField(fieldList, getFieldList(superClass));
     }
 
     /**
@@ -222,7 +223,7 @@ public class ReflectionKit {
     public static List<Field> excludeOverrideSuperField(List<Field> fieldList, List<Field> superFieldList) {
         // 子类属性
         Map<String, Field> fieldMap = fieldList.stream().collect(toMap(Field::getName, identity()));
-        superFieldList.stream().filter(field -> fieldMap.get(field.getName()) == null).forEach(fieldList::add);
+        superFieldList.stream().filter(field -> !fieldMap.containsKey(field.getName())).forEach(fieldList::add);
         return fieldList;
     }
 }

+ 84 - 0
mybatis-plus-core/src/test/java/com/baomidou/mybatisplus/core/toolkit/ReflectionKitTest.java

@@ -0,0 +1,84 @@
+package com.baomidou.mybatisplus.core.toolkit;
+
+import lombok.Data;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.lang.reflect.Field;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * 反射工具类测试
+ *
+ * @author nieqiuqiu 2019/1/16.
+ */
+public class ReflectionKitTest {
+    
+    @Data
+    private static class A {
+        
+        private transient String test;
+        
+        private static String testStatic;
+        
+        private String name;
+        
+        private Boolean testWrap;
+        
+        private boolean testSimple;
+        
+    }
+    
+    @Data
+    private static class B extends A {
+        
+        private Integer age;
+        
+    }
+    
+    @Data
+    private static class C extends B {
+        
+        private String sex;
+        
+    }
+    
+    @Test
+    public void testGetFieldList() {
+        List<Field> fieldList = ReflectionKit.getFieldList(C.class);
+        Assert.assertEquals(5, fieldList.size());
+    }
+    
+    @Test
+    public void testGetFieldMap() throws NoSuchFieldException {
+        Map<String, Field> fieldMap = ReflectionKit.getFieldMap(C.class);
+        Assert.assertEquals(5, fieldMap.size());
+        Assert.assertEquals(fieldMap.get("sex"), C.class.getDeclaredField("sex"));
+        Assert.assertEquals(fieldMap.get("age"), B.class.getDeclaredField("age"));
+        Assert.assertEquals(fieldMap.get("name"), A.class.getDeclaredField("name"));
+    }
+    
+    @Test
+    public void testGetMethodCapitalize() throws NoSuchFieldException {
+        Field field = C.class.getDeclaredField("sex");
+        String getMethod = ReflectionKit.getMethodCapitalize(field, "sex");
+        Assert.assertEquals("getSex", getMethod);
+        field = A.class.getDeclaredField("testWrap");
+        getMethod = ReflectionKit.getMethodCapitalize(field, "testWrap");
+        Assert.assertEquals("getTestWrap", getMethod);
+        field = A.class.getDeclaredField("testSimple");
+        getMethod = ReflectionKit.getMethodCapitalize(field, "testSimple");
+        Assert.assertEquals("isTestSimple", getMethod);
+    }
+    
+    @Test
+    public void testGetMethodValue() {
+        C c = new C();
+        c.setSex("女");
+        c.setName("妹纸");
+        c.setAge(18);
+        Assert.assertEquals(c.getSex(), ReflectionKit.getMethodValue(c.getClass(), c, "sex"));
+        Assert.assertEquals(c.getAge(), ReflectionKit.getMethodValue(c, "age"));
+    }
+}

+ 92 - 107
mybatis-plus-generator/src/main/java/com/baomidou/mybatisplus/generator/config/builder/ConfigBuilder.java

@@ -432,7 +432,6 @@ public class ConfigBuilder {
 
         //不存在的表名
         Set<String> notExistTables = new HashSet<>();
-        PreparedStatement preparedStatement = null;
         try {
             String tablesSql = dbQuery.tablesSql();
             if (DbType.POSTGRE_SQL == dbQuery.dbType()) {
@@ -467,45 +466,45 @@ public class ConfigBuilder {
                     tablesSql = sb.toString();
                 }
             }
-            preparedStatement = connection.prepareStatement(tablesSql);
-            ResultSet results = preparedStatement.executeQuery();
             TableInfo tableInfo;
-            while (results.next()) {
-                String tableName = results.getString(dbQuery.tableName());
-                if (StringUtils.isNotEmpty(tableName)) {
-                    String tableComment = results.getString(dbQuery.tableComment());
-                    if (config.isSkipView() && "VIEW".equals(tableComment)) {
-                        // 跳过视图
-                        continue;
-                    }
-                    tableInfo = new TableInfo();
-                    tableInfo.setName(tableName);
-                    tableInfo.setComment(tableComment);
-                    if (isInclude) {
-                        for (String includeTable : config.getInclude()) {
-                            // 忽略大小写等于 或 正则 true
-                            if (tableNameMatches(includeTable, tableName)) {
-                                includeTableList.add(tableInfo);
-                            } else {
-                                notExistTables.add(includeTable);
-                            }
+            try (PreparedStatement preparedStatement = connection.prepareStatement(tablesSql);
+                 ResultSet results = preparedStatement.executeQuery()) {
+                while (results.next()) {
+                    String tableName = results.getString(dbQuery.tableName());
+                    if (StringUtils.isNotEmpty(tableName)) {
+                        String tableComment = results.getString(dbQuery.tableComment());
+                        if (config.isSkipView() && "VIEW".equals(tableComment)) {
+                            // 跳过视图
+                            continue;
                         }
-                    } else if (isExclude) {
-                        for (String excludeTable : config.getExclude()) {
-                            // 忽略大小写等于 或 正则 true
-                            if (tableNameMatches(excludeTable, tableName)) {
-                                excludeTableList.add(tableInfo);
-                            } else {
-                                notExistTables.add(excludeTable);
+                        tableInfo = new TableInfo();
+                        tableInfo.setName(tableName);
+                        tableInfo.setComment(tableComment);
+                        if (isInclude) {
+                            for (String includeTable : config.getInclude()) {
+                                // 忽略大小写等于 或 正则 true
+                                if (tableNameMatches(includeTable, tableName)) {
+                                    includeTableList.add(tableInfo);
+                                } else {
+                                    notExistTables.add(includeTable);
+                                }
+                            }
+                        } else if (isExclude) {
+                            for (String excludeTable : config.getExclude()) {
+                                // 忽略大小写等于 或 正则 true
+                                if (tableNameMatches(excludeTable, tableName)) {
+                                    excludeTableList.add(tableInfo);
+                                } else {
+                                    notExistTables.add(excludeTable);
+                                }
                             }
                         }
+                        tableList.add(tableInfo);
+                    } else {
+                        System.err.println("当前数据库为空!!!");
                     }
-                    tableList.add(tableInfo);
-                } else {
-                    System.err.println("当前数据库为空!!!");
                 }
             }
-
             // 将已经存在的表移除,获取配置中数据库不存在的表
             for (TableInfo tabInfo : tableList) {
                 notExistTables.remove(tabInfo.getName());
@@ -528,18 +527,6 @@ public class ConfigBuilder {
             includeTableList.forEach(ti -> convertTableFields(ti, config.getColumnNaming()));
         } catch (SQLException e) {
             e.printStackTrace();
-        } finally {
-            // 释放资源
-            try {
-                if (preparedStatement != null) {
-                    preparedStatement.close();
-                }
-                if (connection != null) {
-                    connection.close();
-                }
-            } catch (SQLException e) {
-                e.printStackTrace();
-            }
         }
         return processTable(includeTableList, config.getNaming(), config);
     }
@@ -558,7 +545,6 @@ public class ConfigBuilder {
         return setTableName.equals(dbTableName)
             || StringUtils.matches(setTableName, dbTableName);
     }
-
     /**
      * <p>
      * 将字段信息与表信息关联
@@ -584,79 +570,78 @@ public class ConfigBuilder {
                 tableFieldsSql = String.format(tableFieldsSql.replace("#schema", dataSourceConfig.getSchemaName()), tableName);
             } else if (DbType.H2 == dbType) {
                 tableName = tableName.toUpperCase();
-                PreparedStatement pkQueryStmt = connection.prepareStatement(String.format(H2Query.PK_QUERY_SQL, tableName));
-                ResultSet pkResults = pkQueryStmt.executeQuery();
-                while (pkResults.next()) {
-                    String primaryKey = pkResults.getString(dbQuery.fieldKey());
-                    if ("TRUE".equalsIgnoreCase(primaryKey)) {
-                        h2PkColumns.add(pkResults.getString(dbQuery.fieldName()));
+                try (PreparedStatement pkQueryStmt = connection.prepareStatement(String.format(H2Query.PK_QUERY_SQL, tableName));
+                     ResultSet pkResults = pkQueryStmt.executeQuery()) {
+                    while (pkResults.next()) {
+                        String primaryKey = pkResults.getString(dbQuery.fieldKey());
+                        if (Boolean.valueOf(primaryKey)) {
+                            h2PkColumns.add(pkResults.getString(dbQuery.fieldName()));
+                        }
                     }
                 }
-                pkResults.close();
-                pkQueryStmt.close();
                 tableFieldsSql = String.format(tableFieldsSql, tableName);
             } else {
                 tableFieldsSql = String.format(tableFieldsSql, tableName);
             }
-            PreparedStatement preparedStatement = connection.prepareStatement(tableFieldsSql);
-            ResultSet results = preparedStatement.executeQuery();
-            while (results.next()) {
-                TableField field = new TableField();
-                String columnName = results.getString(dbQuery.fieldName());
-                // 避免多重主键设置,目前只取第一个找到ID,并放到list中的索引为0的位置
-                boolean isId;
-                if(DbType.H2 == dbType){
-                    isId = h2PkColumns.contains(columnName);
-                }else{
-                    String key = results.getString(dbQuery.fieldKey());
-                    if (DbType.DB2 == dbType) {
-                        isId = StringUtils.isNotEmpty(key) && "1".equals(key);
+            try (
+                PreparedStatement preparedStatement = connection.prepareStatement(tableFieldsSql);
+                ResultSet results = preparedStatement.executeQuery()) {
+                while (results.next()) {
+                    TableField field = new TableField();
+                    String columnName = results.getString(dbQuery.fieldName());
+                    // 避免多重主键设置,目前只取第一个找到ID,并放到list中的索引为0的位置
+                    boolean isId;
+                    if (DbType.H2 == dbType) {
+                        isId = h2PkColumns.contains(columnName);
                     } else {
-                        isId = StringUtils.isNotEmpty(key) && "PRI".equals(key.toUpperCase());
+                        String key = results.getString(dbQuery.fieldKey());
+                        if (DbType.DB2 == dbType) {
+                            isId = StringUtils.isNotEmpty(key) && "1".equals(key);
+                        } else {
+                            isId = StringUtils.isNotEmpty(key) && "PRI".equals(key.toUpperCase());
+                        }
                     }
-                }
-
-                // 处理ID
-                if (isId && !haveId) {
-                    field.setKeyFlag(true);
-                    if (DbType.H2 == dbType || dbQuery.isKeyIdentity(results)) {
-                        field.setKeyIdentityFlag(true);
+            
+                    // 处理ID
+                    if (isId && !haveId) {
+                        field.setKeyFlag(true);
+                        if (DbType.H2 == dbType || dbQuery.isKeyIdentity(results)) {
+                            field.setKeyIdentityFlag(true);
+                        }
+                        haveId = true;
+                    } else {
+                        field.setKeyFlag(false);
                     }
-                    haveId = true;
-                } else {
-                    field.setKeyFlag(false);
-                }
-                // 自定义字段查询
-                String[] fcs = dbQuery.fieldCustom();
-                if (null != fcs) {
-                    Map<String, Object> customMap = new HashMap<>();
-                    for (String fc : fcs) {
-                        customMap.put(fc, results.getObject(fc));
+                    // 自定义字段查询
+                    String[] fcs = dbQuery.fieldCustom();
+                    if (null != fcs) {
+                        Map<String, Object> customMap = new HashMap<>();
+                        for (String fc : fcs) {
+                            customMap.put(fc, results.getObject(fc));
+                        }
+                        field.setCustomMap(customMap);
                     }
-                    field.setCustomMap(customMap);
-                }
-                // 处理其它信息
-                field.setName(columnName);
-                field.setType(results.getString(dbQuery.fieldType()));
-                field.setPropertyName(strategyConfig, processName(field.getName(), strategy));
-                field.setColumnType(dataSourceConfig.getTypeConvert().processTypeConvert(globalConfig, field.getType()));
-                field.setComment(results.getString(dbQuery.fieldComment()));
-                if (strategyConfig.includeSuperEntityColumns(field.getName())) {
-                    // 跳过公共字段
-                    commonFieldList.add(field);
-                    continue;
-                }
-                // 填充逻辑判断
-                List<TableFill> tableFillList = getStrategyConfig().getTableFillList();
-                if (null != tableFillList) {
-                    // 忽略大写字段问题
-                    tableFillList.stream().filter(tf -> tf.getFieldName().equalsIgnoreCase(field.getName()))
-                        .findFirst().ifPresent(tf -> field.setFill(tf.getFieldFill().name()));
+                    // 处理其它信息
+                    field.setName(columnName);
+                    field.setType(results.getString(dbQuery.fieldType()));
+                    field.setPropertyName(strategyConfig, processName(field.getName(), strategy));
+                    field.setColumnType(dataSourceConfig.getTypeConvert().processTypeConvert(globalConfig, field.getType()));
+                    field.setComment(results.getString(dbQuery.fieldComment()));
+                    if (strategyConfig.includeSuperEntityColumns(field.getName())) {
+                        // 跳过公共字段
+                        commonFieldList.add(field);
+                        continue;
+                    }
+                    // 填充逻辑判断
+                    List<TableFill> tableFillList = getStrategyConfig().getTableFillList();
+                    if (null != tableFillList) {
+                        // 忽略大写字段问题
+                        tableFillList.stream().filter(tf -> tf.getFieldName().equalsIgnoreCase(field.getName()))
+                            .findFirst().ifPresent(tf -> field.setFill(tf.getFieldFill().name()));
+                    }
+                    fieldList.add(field);
                 }
-                fieldList.add(field);
             }
-            results.close();
-            preparedStatement.close();
         } catch (SQLException e) {
             System.err.println("SQL Exception:" + e.getMessage());
         }