Ver código fonte

支持多重继承获取泛型

hubin 4 anos atrás
pai
commit
7210b461b2

+ 3 - 33
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/injector/AbstractSqlInjector.java

@@ -15,19 +15,16 @@
  */
 package com.baomidou.mybatisplus.core.injector;
 
+import com.baomidou.mybatisplus.core.mapper.Mapper;
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
 import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
-import com.baomidou.mybatisplus.core.toolkit.ArrayUtils;
 import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.core.toolkit.GlobalConfigUtils;
+import com.baomidou.mybatisplus.core.toolkit.ReflectionKit;
 import org.apache.ibatis.builder.MapperBuilderAssistant;
 import org.apache.ibatis.logging.Log;
 import org.apache.ibatis.logging.LogFactory;
 
-import java.lang.reflect.ParameterizedType;
-import java.lang.reflect.Type;
-import java.lang.reflect.TypeVariable;
-import java.lang.reflect.WildcardType;
 import java.util.List;
 import java.util.Set;
 
@@ -43,7 +40,7 @@ public abstract class AbstractSqlInjector implements ISqlInjector {
 
     @Override
     public void inspectInject(MapperBuilderAssistant builderAssistant, Class<?> mapperClass) {
-        Class<?> modelClass = extractModelClass(mapperClass);
+        Class<?> modelClass = ReflectionKit.getSuperClassGenericType(mapperClass, Mapper.class, 0);
         if (modelClass != null) {
             String className = mapperClass.toString();
             Set<String> mapperRegistryCache = GlobalConfigUtils.getMapperRegistryCache(builderAssistant.getConfiguration());
@@ -72,31 +69,4 @@ public abstract class AbstractSqlInjector implements ISqlInjector {
      */
     public abstract List<AbstractMethod> getMethodList(Class<?> mapperClass);
 
-    /**
-     * 提取泛型模型,多泛型的时候请将泛型T放在第一位
-     *
-     * @param mapperClass mapper 接口
-     * @return mapper 泛型
-     */
-    protected Class<?> extractModelClass(Class<?> mapperClass) {
-        Type[] types = mapperClass.getGenericInterfaces();
-        ParameterizedType target = null;
-        for (Type type : types) {
-            if (type instanceof ParameterizedType) {
-                Type[] typeArray = ((ParameterizedType) type).getActualTypeArguments();
-                if (ArrayUtils.isNotEmpty(typeArray)) {
-                    for (Type t : typeArray) {
-                        if (t instanceof TypeVariable || t instanceof WildcardType) {
-                            break;
-                        } else {
-                            target = (ParameterizedType) type;
-                            break;
-                        }
-                    }
-                }
-                break;
-            }
-        }
-        return target == null ? null : (Class<?>) target.getActualTypeArguments()[0];
-    }
 }

+ 20 - 31
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/ReflectionKit.java

@@ -17,8 +17,11 @@ package com.baomidou.mybatisplus.core.toolkit;
 
 import org.apache.ibatis.logging.Log;
 import org.apache.ibatis.logging.LogFactory;
+import org.springframework.core.GenericTypeResolver;
 
-import java.lang.reflect.*;
+import java.lang.reflect.AccessibleObject;
+import java.lang.reflect.Field;
+import java.lang.reflect.Modifier;
 import java.security.AccessController;
 import java.util.*;
 import java.util.concurrent.ConcurrentHashMap;
@@ -85,28 +88,14 @@ public final class ReflectionKit {
      * 反射对象获取泛型
      * </p>
      *
-     * @param clazz 对象
-     * @param index 泛型所在位置
+     * @param clazz      对象
+     * @param genericIfc 所属泛型父类
+     * @param index      泛型所在位置
      * @return Class
      */
-    public static Class<?> getSuperClassGenericType(final Class<?> clazz, final int index) {
-        Type genType = clazz.getGenericSuperclass();
-        if (!(genType instanceof ParameterizedType)) {
-            logger.warn(String.format("Warn: %s's superclass not ParameterizedType", clazz.getSimpleName()));
-            return Object.class;
-        }
-        Type[] params = ((ParameterizedType) genType).getActualTypeArguments();
-        if (index >= params.length || index < 0) {
-            logger.warn(String.format("Warn: Index: %s, Size of %s's Parameterized Type: %s .", index,
-                    clazz.getSimpleName(), params.length));
-            return Object.class;
-        }
-        if (!(params[index] instanceof Class)) {
-            logger.warn(String.format("Warn: %s not set the actual class on superclass generic parameter",
-                    clazz.getSimpleName()));
-            return Object.class;
-        }
-        return (Class<?>) params[index];
+    public static Class<?> getSuperClassGenericType(final Class<?> clazz, final Class<?> genericIfc, final int index) {
+        Class<?>[] typeArguments = GenericTypeResolver.resolveTypeArguments(ClassUtils.getUserClass(clazz), genericIfc);
+        return null == typeArguments ? null : typeArguments[index];
     }
 
     /**
@@ -149,11 +138,11 @@ public final class ReflectionKit {
              * 中间表实体重写父类属性 ` private transient Date createTime; `
              */
             return fieldMap.values().stream()
-                    /* 过滤静态属性 */
-                    .filter(f -> !Modifier.isStatic(f.getModifiers()))
-                    /* 过滤 transient关键字修饰的属性 */
-                    .filter(f -> !Modifier.isTransient(f.getModifiers()))
-                    .collect(Collectors.toList());
+                /* 过滤静态属性 */
+                .filter(f -> !Modifier.isStatic(f.getModifiers()))
+                /* 过滤 transient关键字修饰的属性 */
+                .filter(f -> !Modifier.isTransient(f.getModifiers()))
+                .collect(Collectors.toList());
         });
     }
 
@@ -168,12 +157,12 @@ public final class ReflectionKit {
     public static Map<String, Field> excludeOverrideSuperField(Field[] fields, List<Field> superFieldList) {
         // 子类属性
         Map<String, Field> fieldMap = Stream.of(fields).collect(toMap(Field::getName, identity(),
-                (u, v) -> {
-                    throw new IllegalStateException(String.format("Duplicate key %s", u));
-                },
-                LinkedHashMap::new));
+            (u, v) -> {
+                throw new IllegalStateException(String.format("Duplicate key %s", u));
+            },
+            LinkedHashMap::new));
         superFieldList.stream().filter(field -> !fieldMap.containsKey(field.getName()))
-                .forEach(f -> fieldMap.put(f.getName(), f));
+            .forEach(f -> fieldMap.put(f.getName(), f));
         return fieldMap;
     }
 

+ 35 - 0
mybatis-plus-core/src/test/java/com/baomidou/mybatisplus/test/pom/ReflectionKitTest.java

@@ -0,0 +1,35 @@
+package com.baomidou.mybatisplus.test.pom;
+
+import com.baomidou.mybatisplus.core.mapper.BaseMapper;
+import com.baomidou.mybatisplus.core.mapper.Mapper;
+import com.baomidou.mybatisplus.core.toolkit.ReflectionKit;
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * 反工具类测试
+ */
+public class ReflectionKitTest {
+
+    public class MyEntity {
+    }
+
+    public interface Mapper1<T> extends BaseMapper<T> {
+    }
+
+    public interface Mapper2 extends Mapper<MyEntity> {
+    }
+
+    public interface Mapper3 extends Mapper2 {
+    }
+
+    @Test
+    void testSuperClassGenericType() {
+        // 多重继承测试
+        assertThat(ReflectionKit.getSuperClassGenericType(Mapper2.class,
+            Mapper.class, 0).equals(MyEntity.class));
+        assertThat(ReflectionKit.getSuperClassGenericType(Mapper3.class,
+            Mapper.class, 0).equals(MyEntity.class));
+    }
+}

+ 2 - 10
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/service/impl/ServiceImpl.java

@@ -29,7 +29,6 @@ import org.apache.ibatis.logging.LogFactory;
 import org.apache.ibatis.session.SqlSession;
 import org.mybatis.spring.SqlSessionUtils;
 import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.core.ResolvableType;
 import org.springframework.transaction.annotation.Transactional;
 
 import java.io.Serializable;
@@ -81,20 +80,13 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
     }
 
     protected Class<T> currentMapperClass() {
-        return (Class<T>) this.getResolvableType().as(ServiceImpl.class).getGeneric(0).getType();
+        return (Class<T>) ReflectionKit.getSuperClassGenericType(this.getClass(), ServiceImpl.class, 0);
     }
 
     protected Class<T> currentModelClass() {
-        return (Class<T>) this.getResolvableType().as(ServiceImpl.class).getGeneric(1).getType();
+        return (Class<T>) ReflectionKit.getSuperClassGenericType(this.getClass(), ServiceImpl.class, 1);
     }
 
-    /**
-     * @see ResolvableType
-     * @since 3.4.3
-     */
-    protected ResolvableType getResolvableType() {
-        return ResolvableType.forClass(ClassUtils.getUserClass(getClass()));
-    }
 
     /**
      * 批量操作 SqlSession

+ 19 - 0
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/test/service/ServiceTest.java

@@ -7,6 +7,8 @@ import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import static org.assertj.core.api.Assertions.assertThat;
+
 /**
  * @author nieqiurong 2021/1/19.
  */
@@ -38,4 +40,21 @@ public class ServiceTest {
         }
     }
 
+
+    static class MyServiceImpl<M extends BaseMapper<T>, T> extends ServiceImpl<M, T> {
+
+    }
+
+    static class MyServiceExtend extends MyServiceImpl<DemoMapper, Demo> {
+
+    }
+
+    @Test
+    void testSuperClassGenericType() {
+        // 多重继承测试
+        assertThat(ReflectionKit.getSuperClassGenericType(MyServiceExtend.class,
+            ServiceImpl.class, 0).equals(DemoMapper.class));
+        assertThat(ReflectionKit.getSuperClassGenericType(MyServiceExtend.class,
+            ServiceImpl.class, 1).equals(Demo.class));
+    }
 }

+ 13 - 13
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/MybatisMapperRegistryTest.java

@@ -50,31 +50,31 @@ import java.util.Map;
 @MapperScan(value = "com.baomidou.mybatisplus.test.h2")
 @ContextConfiguration(classes = {MybatisMapperRegistryTest.class, DBConfig.class})
 class MybatisMapperRegistryTest extends BaseTest {
-    
+
     private interface H2StudentChildrenMapper extends H2StudentMapper {
-    
+
     }
-    
+
     @Bean
     DBConfig dbConfig() {
         return new DBConfig();
     }
-    
+
     @Bean
     MybatisPlusConfig mybatisPlusConfig() {
         return new MybatisPlusConfig();
     }
-    
+
     @Bean
     SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
         MybatisSqlSessionFactoryBean sqlSessionFactory = new MybatisSqlSessionFactoryBean();
         sqlSessionFactory.setDataSource(dataSource);
         return sqlSessionFactory.getObject();
     }
-    
+
     @Autowired
     private SqlSessionFactory sqlSessionFactory;
-    
+
     @SuppressWarnings("unchecked")
     @Test
     void test() throws ReflectiveOperationException {
@@ -84,21 +84,21 @@ class MybatisMapperRegistryTest extends BaseTest {
             Assertions.assertTrue(mapperRegistry.hasMapper(H2UserMapper.class));
             Assertions.assertTrue(mapperRegistry.hasMapper(H2StudentChildrenMapper.class));
             H2StudentMapper studentMapper = mapperRegistry.getMapper(H2StudentMapper.class, sqlSession);
-            
+
             Assertions.assertTrue(configuration.hasStatement(H2StudentMapper.class.getName() + ".selectById"));
             studentMapper.selectById(1);
-            
+
             Field field = mapperRegistry.getClass().getDeclaredField("knownMappers");
             field.setAccessible(true);
             Map<Class<?>, MybatisMapperProxyFactory<?>> knownMappers = (Map<Class<?>, MybatisMapperProxyFactory<?>>) field.get(mapperRegistry);
             MybatisMapperProxyFactory<?> mybatisMapperProxyFactory = knownMappers.get(H2StudentChildrenMapper.class);
-            
-            
+
+
             H2StudentChildrenMapper h2StudentChildrenMapper = mapperRegistry.getMapper(H2StudentChildrenMapper.class, sqlSession);
-            Assertions.assertFalse(configuration.hasStatement(H2StudentChildrenMapper.class.getName() + ".selectById"));
+            Assertions.assertTrue(configuration.hasStatement(H2StudentChildrenMapper.class.getName() + ".selectById"));
             Map<Method, ?> methodCache = mybatisMapperProxyFactory.getMethodCache();
             Assertions.assertTrue(methodCache.isEmpty());
-            
+
             h2StudentChildrenMapper.selectById(2);
             methodCache = mybatisMapperProxyFactory.getMethodCache();
             Assertions.assertFalse(methodCache.isEmpty());