Explorar o código

优化枚举处理,修改测试用例.

聂秋秋 %!s(int64=6) %!d(string=hai) anos
pai
achega
fa010e4995

+ 14 - 21
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/handlers/EnumAnnotationTypeHandler.java

@@ -46,28 +46,29 @@ import java.util.concurrent.ConcurrentHashMap;
  */
 @Deprecated
 public class EnumAnnotationTypeHandler<E extends Enum<E>> extends BaseTypeHandler<E> {
-
+    
+    private static final Log LOGGER = LogFactory.getLog(EnumAnnotationTypeHandler.class);
+    
+    private static final Map<Class<?>, Method> TABLE_METHOD_OF_ENUM_TYPES = new ConcurrentHashMap<>();
+    
     private final Class<E> type;
+    
+    private final Method method;
 
     public EnumAnnotationTypeHandler(Class<E> type) {
         if (type == null) {
             throw new IllegalArgumentException("Type argument cannot be null");
         }
         this.type = type;
+        this.method = TABLE_METHOD_OF_ENUM_TYPES.computeIfAbsent(type, k-> {
+            Field field = dealEnumType(this.type).orElseThrow(() -> new IllegalArgumentException(String.format("Could not find @EnumValue in Class: %s.", type.getName())));
+            return ReflectionKit.getMethod(this.type, field);
+        });
     }
 
-    private static final Log LOGGER = LogFactory.getLog(EnumAnnotationTypeHandler.class);
-
-    private static final Map<Class<?>, Method> TABLE_FIELD_OF_ENUM_TYPES = new ConcurrentHashMap<>();
-
-    public static void addEnumType(Class<?> clazz, Method method) {
-        TABLE_FIELD_OF_ENUM_TYPES.put(clazz, method);
-    }
-
-
+    @SuppressWarnings("Duplicates")
     @Override
     public void setNonNullParameter(PreparedStatement ps, int i, Enum parameter, JdbcType jdbcType) throws SQLException {
-        Method method = getMethod(type);
         try {
             method.setAccessible(true);
             if (jdbcType == null) {
@@ -92,7 +93,7 @@ public class EnumAnnotationTypeHandler<E extends Enum<E>> extends BaseTypeHandle
         if (s == null) {
             return null;
         }
-        return EnumUtils.valueOf(type, s, getMethod(type));
+        return EnumUtils.valueOf(type, s, method);
     }
 
     @Override
@@ -108,13 +109,5 @@ public class EnumAnnotationTypeHandler<E extends Enum<E>> extends BaseTypeHandle
     public static Optional<Field> dealEnumType(Class<?> clazz) {
         return clazz.isEnum() ? Arrays.stream(clazz.getDeclaredFields()).filter(field -> field.isAnnotationPresent(EnumValue.class)).findFirst() : Optional.empty();
     }
-
-    private Method getMethod(Class<?> clazz) {
-        return Optional.ofNullable(TABLE_FIELD_OF_ENUM_TYPES.get(type)).orElseGet(() -> {
-            Field field = dealEnumType(clazz).orElseThrow(() -> new IllegalArgumentException("当前[" + type.getName() + "]枚举类未找到标有@EnumValue注解的字段"));
-            Method method = ReflectionKit.getMethod(clazz, field);
-            addEnumType(clazz, method);
-            return method;
-        });
-    }
+    
 }

+ 28 - 36
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/handlers/EnumTypeHandler.java

@@ -45,15 +45,15 @@ import java.util.concurrent.ConcurrentHashMap;
  * @since 2017-10-11
  */
 public class EnumTypeHandler<E extends Enum<?>> extends BaseTypeHandler<Enum> {
-
+    
     private static final Log LOGGER = LogFactory.getLog(EnumTypeHandler.class);
-
-    private static final Map<Class<?>, Method> TABLE_FIELD_OF_ENUM_TYPES = new ConcurrentHashMap<>();
-
-    private Class<E> type;
-
-    private Method method;
-
+    
+    private static final Map<Class<?>, Method> TABLE_METHOD_OF_ENUM_TYPES = new ConcurrentHashMap<>();
+    
+    private final Class<E> type;
+    
+    private final Method method;
+    
     public EnumTypeHandler(Class<E> type) {
         if (type == null) {
             throw new IllegalArgumentException("Type argument cannot be null");
@@ -61,71 +61,63 @@ public class EnumTypeHandler<E extends Enum<?>> extends BaseTypeHandler<Enum> {
         this.type = type;
         if (IEnum.class.isAssignableFrom(type)) {
             try {
-                method = type.getMethod("getValue");
+                this.method = type.getMethod("getValue");
             } catch (NoSuchMethodException e) {
-                throw new IllegalArgumentException("当前[" + type.getName() + "]枚举类未找到getValue方法");
+                throw new IllegalArgumentException(String.format("NoSuchMethod getValue() in Class: %s.", type.getName()));
             }
         } else {
-            method = getMethod(type);
+            this.method = TABLE_METHOD_OF_ENUM_TYPES.computeIfAbsent(type, k -> {
+                Field field = dealEnumType(this.type).orElseThrow(() -> new IllegalArgumentException(String.format("Could not find @EnumValue in Class: %s.", type.getName())));
+                return ReflectionKit.getMethod(this.type, field);
+            });
         }
     }
-
+    
+    @SuppressWarnings("Duplicates")
     @Override
     public void setNonNullParameter(PreparedStatement ps, int i, Enum parameter, JdbcType jdbcType)
         throws SQLException {
         try {
-            method.setAccessible(true);
+            this.method.setAccessible(true);
             if (jdbcType == null) {
-                ps.setObject(i, method.invoke(parameter));
+                ps.setObject(i, this.method.invoke(parameter));
             } else {
                 // see r3589
-                ps.setObject(i, method.invoke(parameter), jdbcType.TYPE_CODE);
+                ps.setObject(i, this.method.invoke(parameter), jdbcType.TYPE_CODE);
             }
         } catch (IllegalAccessException e) {
             LOGGER.error("unrecognized jdbcType, failed to set StringValue for type=" + parameter);
         } catch (InvocationTargetException e) {
-            throw ExceptionUtils.mpe("Error: NoSuchMethod in %s.  Cause:", e, type.getName());
+            throw ExceptionUtils.mpe("Error: NoSuchMethod in %s.  Cause:", e, this.type.getName());
         }
     }
-
+    
     @Override
     public E getNullableResult(ResultSet rs, String columnName) throws SQLException {
         if (null == rs.getObject(columnName) && rs.wasNull()) {
             return null;
         }
-        return EnumUtils.valueOf(type, rs.getObject(columnName), method);
+        return EnumUtils.valueOf(this.type, rs.getObject(columnName), this.method);
     }
-
+    
     @Override
     public E getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
         if (null == rs.getObject(columnIndex) && rs.wasNull()) {
             return null;
         }
-        return EnumUtils.valueOf(type, rs.getObject(columnIndex), method);
+        return EnumUtils.valueOf(this.type, rs.getObject(columnIndex), this.method);
     }
-
+    
     @Override
     public E getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
         if (null == cs.getObject(columnIndex) && cs.wasNull()) {
             return null;
         }
-        return EnumUtils.valueOf(type, cs.getObject(columnIndex), method);
+        return EnumUtils.valueOf(this.type, cs.getObject(columnIndex), this.method);
     }
-
+    
     public static Optional<Field> dealEnumType(Class<?> clazz) {
         return clazz.isEnum() ? Arrays.stream(clazz.getDeclaredFields()).filter(field -> field.isAnnotationPresent(EnumValue.class)).findFirst() : Optional.empty();
     }
-
-    public static void addEnumType(Class<?> clazz, Method method) {
-        TABLE_FIELD_OF_ENUM_TYPES.put(clazz, method);
-    }
-
-    private Method getMethod(Class<?> clazz) {
-        return Optional.ofNullable(TABLE_FIELD_OF_ENUM_TYPES.get(type)).orElseGet(() -> {
-            Field field = dealEnumType(clazz).orElseThrow(() -> new IllegalArgumentException("当前[" + type.getName() + "]枚举类未找到标有@EnumValue注解的字段"));
-            Method method = ReflectionKit.getMethod(clazz, field);
-            addEnumType(clazz, method);
-            return method;
-        });
-    }
+    
 }

+ 4 - 17
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/spring/MybatisSqlSessionFactoryBean.java

@@ -21,7 +21,6 @@ import com.baomidou.mybatisplus.core.MybatisXMLConfigBuilder;
 import com.baomidou.mybatisplus.core.config.GlobalConfig;
 import com.baomidou.mybatisplus.core.enums.IEnum;
 import com.baomidou.mybatisplus.core.toolkit.*;
-import com.baomidou.mybatisplus.extension.handlers.EnumAnnotationTypeHandler;
 import com.baomidou.mybatisplus.extension.handlers.EnumTypeHandler;
 import com.baomidou.mybatisplus.extension.toolkit.AopUtils;
 import com.baomidou.mybatisplus.extension.toolkit.JdbcUtils;
@@ -58,7 +57,6 @@ import org.springframework.jdbc.datasource.TransactionAwareDataSourceProxy;
 
 import javax.sql.DataSource;
 import java.io.IOException;
-import java.lang.reflect.Field;
 import java.sql.Connection;
 import java.sql.SQLException;
 import java.util.*;
@@ -526,21 +524,10 @@ public class MybatisSqlSessionFactoryBean implements FactoryBean<SqlSessionFacto
             }
             // 取得类型转换注册器
             TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
-            classes.forEach(cls ->{
-                if (cls.isEnum()) {
-                    if (IEnum.class.isAssignableFrom(cls)) {
-                        // 接口方式
-                        typeHandlerRegistry.register(cls, EnumTypeHandler.class);
-                    } else {
-                        // 注解方式
-                        Optional<Field> optional = EnumTypeHandler.dealEnumType(cls);
-                        if (optional.isPresent()) {
-                            EnumTypeHandler.addEnumType(cls, ReflectionKit.getMethod(cls, optional.get()));
-                            typeHandlerRegistry.register(cls, EnumTypeHandler.class);
-                        }
-                    }
-                }
-            });
+            classes.stream()
+                .filter(Class::isEnum)
+                .filter(cls -> IEnum.class.isAssignableFrom(cls) || EnumTypeHandler.dealEnumType(cls).isPresent())
+                .forEach(cls -> typeHandlerRegistry.register(cls, EnumTypeHandler.class));
         }
 
         if (!isEmpty(this.typeAliases)) {

+ 4 - 3
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/handlers/EnumAnnotationTypeHandlerTest.java

@@ -12,6 +12,7 @@ import lombok.AllArgsConstructor;
 import lombok.Getter;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -44,7 +45,7 @@ public class EnumAnnotationTypeHandlerTest extends BaseTypeHandlerTest {
     @Override
     public void getResultFromResultSetByColumnName() throws Exception {
         when(resultSet.getObject("column")).thenReturn(null);
-        assertEquals(null, HANDLER.getResult(resultSet, "column"));
+        assertNull(HANDLER.getResult(resultSet, "column"));
         when(resultSet.getObject("column")).thenReturn(1);
         assertEquals(SexEnum.MAN, HANDLER.getResult(resultSet, "column"));
         when(resultSet.getObject("column")).thenReturn(2);
@@ -59,7 +60,7 @@ public class EnumAnnotationTypeHandlerTest extends BaseTypeHandlerTest {
         when(resultSet.getObject(2)).thenReturn(2);
         assertEquals(SexEnum.WO_MAN, HANDLER.getResult(resultSet, 2));
         when(resultSet.getObject(3)).thenReturn(null);
-        assertEquals(null, HANDLER.getResult(resultSet, 3));
+        assertNull(HANDLER.getResult(resultSet, 3));
     }
 
     @Test
@@ -70,7 +71,7 @@ public class EnumAnnotationTypeHandlerTest extends BaseTypeHandlerTest {
         when(callableStatement.getObject(2)).thenReturn(2);
         assertEquals(SexEnum.WO_MAN, HANDLER.getResult(callableStatement, 2));
         when(callableStatement.getObject(3)).thenReturn(null);
-        assertEquals(null, HANDLER.getResult(callableStatement, 3));
+        assertNull(HANDLER.getResult(callableStatement, 3));
 
     }
 

+ 69 - 22
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/handlers/EnumTypeHandlerTest.java

@@ -1,5 +1,6 @@
 package com.baomidou.mybatisplus.extension.handlers;
 
+import com.baomidou.mybatisplus.annotation.EnumValue;
 import com.baomidou.mybatisplus.core.enums.IEnum;
 
 import org.apache.ibatis.type.JdbcType;
@@ -17,64 +18,110 @@ import static org.mockito.Mockito.when;
 
 @ExtendWith(MockitoExtension.class)
 public class EnumTypeHandlerTest extends BaseTypeHandlerTest {
-
-    private static final EnumTypeHandler<SexEnum> HANDLER = new EnumTypeHandler<>(SexEnum.class);
-
+    
+    private static final EnumTypeHandler<SexEnum> SEX_ENUM_ENUM_TYPE_HANDLER = new EnumTypeHandler<>(SexEnum.class);
+    
+    private static final EnumTypeHandler<GradeEnum> GRADE_ENUM_ENUM_TYPE_HANDLER = new EnumTypeHandler<>(GradeEnum.class);
+    
     @Getter
     @AllArgsConstructor
     enum SexEnum implements IEnum<Integer> {
-        MAN(1, "1"), WO_MAN(2, "2");
+        
+        MAN(1, "1"),
+        WO_MAN(2, "2");
         Integer code;
         String desc;
-
+        
         @Override
         public Integer getValue() {
             return this.code;
         }
     }
-
+    
+    @Getter
+    @AllArgsConstructor
+    enum GradeEnum {
+        
+        PRIMARY(1, "小学"),
+        SECONDARY(2, "中学"),
+        HIGH(3, "高中");
+        
+        @EnumValue
+        private final int code;
+        
+        private final String desc;
+    }
+    
     @Test
     @Override
     public void setParameter() throws Exception {
-        HANDLER.setParameter(preparedStatement, 1, SexEnum.MAN, null);
+        SEX_ENUM_ENUM_TYPE_HANDLER.setParameter(preparedStatement, 1, SexEnum.MAN, null);
         verify(preparedStatement).setObject(1, 1);
-        HANDLER.setParameter(preparedStatement, 2, SexEnum.WO_MAN, null);
+        SEX_ENUM_ENUM_TYPE_HANDLER.setParameter(preparedStatement, 2, SexEnum.WO_MAN, null);
         verify(preparedStatement).setObject(2, 2);
-        HANDLER.setParameter(preparedStatement, 3, null, JdbcType.INTEGER);
+        SEX_ENUM_ENUM_TYPE_HANDLER.setParameter(preparedStatement, 3, null, JdbcType.INTEGER);
         verify(preparedStatement).setNull(3, JdbcType.INTEGER.TYPE_CODE);
+        
+        GRADE_ENUM_ENUM_TYPE_HANDLER.setParameter(preparedStatement, 4, GradeEnum.PRIMARY, null);
+        verify(preparedStatement).setObject(4, 1);
+        GRADE_ENUM_ENUM_TYPE_HANDLER.setParameter(preparedStatement, 5, GradeEnum.SECONDARY, null);
+        verify(preparedStatement).setObject(5, 2);
+        GRADE_ENUM_ENUM_TYPE_HANDLER.setParameter(preparedStatement, 6, null, JdbcType.INTEGER);
+        verify(preparedStatement).setNull(6, JdbcType.INTEGER.TYPE_CODE);
     }
-
+    
     @Test
     @Override
     public void getResultFromResultSetByColumnName() throws Exception {
         when(resultSet.getObject("column")).thenReturn(null);
-        assertNull(HANDLER.getResult(resultSet, "column"));
+        assertNull(SEX_ENUM_ENUM_TYPE_HANDLER.getResult(resultSet, "column"));
+        when(resultSet.getObject("column")).thenReturn(1);
+        assertEquals(SexEnum.MAN, SEX_ENUM_ENUM_TYPE_HANDLER.getResult(resultSet, "column"));
+        when(resultSet.getObject("column")).thenReturn(2);
+        assertEquals(SexEnum.WO_MAN, SEX_ENUM_ENUM_TYPE_HANDLER.getResult(resultSet, "column"));
+        when(resultSet.getObject("column")).thenReturn(null);
+        
+        assertNull(GRADE_ENUM_ENUM_TYPE_HANDLER.getResult(resultSet, "column"));
         when(resultSet.getObject("column")).thenReturn(1);
-        assertEquals(SexEnum.MAN, HANDLER.getResult(resultSet, "column"));
+        assertEquals(GradeEnum.PRIMARY, GRADE_ENUM_ENUM_TYPE_HANDLER.getResult(resultSet, "column"));
         when(resultSet.getObject("column")).thenReturn(2);
-        assertEquals(SexEnum.WO_MAN, HANDLER.getResult(resultSet, "column"));
+        assertEquals(GradeEnum.SECONDARY, GRADE_ENUM_ENUM_TYPE_HANDLER.getResult(resultSet, "column"));
     }
-
+    
     @Test
     @Override
     public void getResultFromResultSetByColumnIndex() throws Exception {
         when(resultSet.getObject(1)).thenReturn(1);
-        assertEquals(SexEnum.MAN, HANDLER.getResult(resultSet, 1));
+        assertEquals(SexEnum.MAN, SEX_ENUM_ENUM_TYPE_HANDLER.getResult(resultSet, 1));
         when(resultSet.getObject(2)).thenReturn(2);
-        assertEquals(SexEnum.WO_MAN, HANDLER.getResult(resultSet, 2));
+        assertEquals(SexEnum.WO_MAN, SEX_ENUM_ENUM_TYPE_HANDLER.getResult(resultSet, 2));
         when(resultSet.getObject(3)).thenReturn(null);
-        assertNull(HANDLER.getResult(resultSet, 3));
+        assertNull(SEX_ENUM_ENUM_TYPE_HANDLER.getResult(resultSet, 3));
+        
+        when(resultSet.getObject(4)).thenReturn(1);
+        assertEquals(GradeEnum.PRIMARY, GRADE_ENUM_ENUM_TYPE_HANDLER.getResult(resultSet, 4));
+        when(resultSet.getObject(5)).thenReturn(2);
+        assertEquals(GradeEnum.SECONDARY, GRADE_ENUM_ENUM_TYPE_HANDLER.getResult(resultSet, 5));
+        when(resultSet.getObject(6)).thenReturn(null);
+        assertNull(GRADE_ENUM_ENUM_TYPE_HANDLER.getResult(resultSet, 6));
     }
-
+    
     @Test
     @Override
     public void getResultFromCallableStatement() throws Exception {
         when(callableStatement.getObject(1)).thenReturn(1);
-        assertEquals(SexEnum.MAN, HANDLER.getResult(callableStatement, 1));
+        assertEquals(SexEnum.MAN, SEX_ENUM_ENUM_TYPE_HANDLER.getResult(callableStatement, 1));
         when(callableStatement.getObject(2)).thenReturn(2);
-        assertEquals(SexEnum.WO_MAN, HANDLER.getResult(callableStatement, 2));
+        assertEquals(SexEnum.WO_MAN, SEX_ENUM_ENUM_TYPE_HANDLER.getResult(callableStatement, 2));
         when(callableStatement.getObject(3)).thenReturn(null);
-        assertNull(HANDLER.getResult(callableStatement, 3));
+        assertNull(SEX_ENUM_ENUM_TYPE_HANDLER.getResult(callableStatement, 3));
+    
+        when(callableStatement.getObject(4)).thenReturn(1);
+        assertEquals(GradeEnum.PRIMARY, GRADE_ENUM_ENUM_TYPE_HANDLER.getResult(callableStatement, 4));
+        when(callableStatement.getObject(5)).thenReturn(2);
+        assertEquals(GradeEnum.SECONDARY, GRADE_ENUM_ENUM_TYPE_HANDLER.getResult(callableStatement, 5));
+        when(callableStatement.getObject(6)).thenReturn(null);
+        assertNull(GRADE_ENUM_ENUM_TYPE_HANDLER.getResult(callableStatement, 6));
     }
-
+    
 }