فهرست منبع

重构ServiceImpl参数提取.

https://github.com/baomidou/mybatis-plus/pull/6067
nieqiurong 1 سال پیش
والد
کامیت
3c1702ce30

+ 47 - 41
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/service/impl/ServiceImpl.java

@@ -31,15 +31,11 @@ import org.apache.ibatis.logging.LogFactory;
 import org.apache.ibatis.session.ExecutorType;
 import org.apache.ibatis.session.SqlSession;
 import org.apache.ibatis.session.SqlSessionFactory;
-import org.mybatis.spring.SqlSessionTemplate;
 import org.mybatis.spring.SqlSessionUtils;
-import org.springframework.aop.framework.AopProxyUtils;
-import org.springframework.aop.support.AopUtils;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.transaction.annotation.Transactional;
 
 import java.io.Serializable;
-import java.lang.reflect.Proxy;
 import java.util.Collection;
 import java.util.Map;
 import java.util.Optional;
@@ -61,48 +57,36 @@ public abstract class ServiceImpl<M extends BaseMapper<T>, T> implements IServic
     @Autowired
     protected M baseMapper;
 
-    protected final Class<?>[] typeArguments = GenericTypeUtils.resolveTypeArguments(getClass(), ServiceImpl.class);
-
     @Override
     public M getBaseMapper() {
-        return baseMapper;
+        Assert.notNull(this.baseMapper, "baseMapper can not be null");
+        return this.baseMapper;
     }
 
-    protected final Class<T> entityClass = currentModelClass();
+    /**
+     * @see #getEntityClass()
+     */
+    private Class<T> entityClass;
 
     @Override
     public Class<T> getEntityClass() {
-        return entityClass;
+        if(this.entityClass == null) {
+            this.entityClass = (Class<T>) GenericTypeUtils.resolveTypeArguments(this.getMapperClass(), BaseMapper.class)[0];
+        }
+        return this.entityClass;
     }
 
-    protected final Class<M> mapperClass = currentMapperClass();
+    /**
+     *  @see #currentMapperClass()
+     */
+    private Class<M> mapperClass;
 
     private volatile SqlSessionFactory sqlSessionFactory;
 
-    @SuppressWarnings({"rawtypes", "deprecation"})
     protected SqlSessionFactory getSqlSessionFactory() {
         if (this.sqlSessionFactory == null) {
-            synchronized (this) {
-                if (this.sqlSessionFactory == null) {
-                    Object target = this.baseMapper;
-                    // 这个检查目前看着来说基本上可以不用判断Aop是不是存在了.
-                    if (com.baomidou.mybatisplus.core.toolkit.AopUtils.isLoadSpringAop()) {
-                        while (AopUtils.isAopProxy(target)) {
-                            target = AopProxyUtils.getSingletonTarget(target);
-                        }
-                    }
-                    if (target != null && Proxy.isProxyClass(target.getClass())) {
-                        target = Proxy.getInvocationHandler(target);
-                    }
-                    if (target instanceof MybatisMapperProxy) {
-                        MybatisMapperProxy mybatisMapperProxy = (MybatisMapperProxy) target;
-                        SqlSessionTemplate sqlSessionTemplate = (SqlSessionTemplate) mybatisMapperProxy.getSqlSession();
-                        this.sqlSessionFactory = sqlSessionTemplate.getSqlSessionFactory();
-                    } else {
-                        this.sqlSessionFactory = GlobalConfigUtils.currentSessionFactory(this.entityClass);
-                    }
-                }
-            }
+            MybatisMapperProxy<?> mybatisMapperProxy = MybatisUtils.getMybatisMapperProxy(this.getBaseMapper());
+            this.sqlSessionFactory = MybatisUtils.getSqlSessionFactory(mybatisMapperProxy);
         }
         return this.sqlSessionFactory;
     }
@@ -119,12 +103,34 @@ public abstract class ServiceImpl<M extends BaseMapper<T>, T> implements IServic
         return SqlHelper.retBool(result);
     }
 
+    /**
+     * @return baseMapper 真实类型
+     * @deprecated 3.5.7 {@link #getMapperClass()}
+     */
+    @Deprecated
     protected Class<M> currentMapperClass() {
-        return (Class<M>) this.typeArguments[0];
+        return this.getMapperClass();
     }
 
+    /**
+     * @return baseMapper 真实类型
+     * @since 3.5.7
+     */
+    public Class<M> getMapperClass() {
+        if (this.mapperClass == null) {
+            MybatisMapperProxy<?> mybatisMapperProxy = MybatisUtils.getMybatisMapperProxy(this.getBaseMapper());
+            this.mapperClass = (Class<M>) mybatisMapperProxy.getMapperInterface();
+        }
+        return this.mapperClass;
+    }
+
+    /**
+     * @return 实体类型
+     * @deprecated 3.5.7 {@link #getEntityClass()}
+     */
+    @Deprecated
     protected Class<T> currentModelClass() {
-        return (Class<T>) this.typeArguments[1];
+        return getEntityClass();
     }
 
 
@@ -159,7 +165,7 @@ public abstract class ServiceImpl<M extends BaseMapper<T>, T> implements IServic
      */
     @Deprecated
     protected String sqlStatement(SqlMethod sqlMethod) {
-        return SqlHelper.table(entityClass).getSqlStatement(sqlMethod.getMethod());
+        return SqlHelper.table(getEntityClass()).getSqlStatement(sqlMethod.getMethod());
     }
 
     /**
@@ -184,7 +190,7 @@ public abstract class ServiceImpl<M extends BaseMapper<T>, T> implements IServic
      * @since 3.4.0
      */
     protected String getSqlStatement(SqlMethod sqlMethod) {
-        return SqlHelper.getSqlStatement(mapperClass, sqlMethod);
+        return SqlHelper.getSqlStatement(this.currentMapperClass(), sqlMethod);
     }
 
     /**
@@ -201,11 +207,11 @@ public abstract class ServiceImpl<M extends BaseMapper<T>, T> implements IServic
     @Transactional(rollbackFor = Exception.class)
     @Override
     public boolean saveOrUpdateBatch(Collection<T> entityList, int batchSize) {
-        TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
+        TableInfo tableInfo = TableInfoHelper.getTableInfo(this.getEntityClass());
         Assert.notNull(tableInfo, "error: can not execute. because can not find cache of TableInfo for entity!");
         String keyProperty = tableInfo.getKeyProperty();
         Assert.notEmpty(keyProperty, "error: can not execute. because can not find column for id from entity!");
-        return SqlHelper.saveOrUpdateBatch(getSqlSessionFactory(), this.mapperClass, this.log, entityList, batchSize, (sqlSession, entity) -> {
+        return SqlHelper.saveOrUpdateBatch(getSqlSessionFactory(), this.currentMapperClass(), this.log, entityList, batchSize, (sqlSession, entity) -> {
             Object idVal = tableInfo.getPropertyValue(entity, keyProperty);
             return StringUtils.checkValNull(idVal)
                 || CollectionUtils.isEmpty(sqlSession.selectList(getSqlStatement(SqlMethod.SELECT_BY_ID), entity));
@@ -229,17 +235,17 @@ public abstract class ServiceImpl<M extends BaseMapper<T>, T> implements IServic
 
     @Override
     public T getOne(Wrapper<T> queryWrapper, boolean throwEx) {
-        return baseMapper.selectOne(queryWrapper, throwEx);
+        return getBaseMapper().selectOne(queryWrapper, throwEx);
     }
 
     @Override
     public Optional<T> getOneOpt(Wrapper<T> queryWrapper, boolean throwEx) {
-        return Optional.ofNullable(baseMapper.selectOne(queryWrapper, throwEx));
+        return Optional.ofNullable(getBaseMapper().selectOne(queryWrapper, throwEx));
     }
 
     @Override
     public Map<String, Object> getMap(Wrapper<T> queryWrapper) {
-        return SqlHelper.getObject(log, baseMapper.selectMaps(queryWrapper));
+        return SqlHelper.getObject(log, getBaseMapper().selectMaps(queryWrapper));
     }
 
     @Override

+ 16 - 3
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/test/service/ServiceTest.java

@@ -1,11 +1,15 @@
 package com.baomidou.mybatisplus.test.service;
 
 import com.baomidou.mybatisplus.core.mapper.BaseMapper;
+import com.baomidou.mybatisplus.core.override.MybatisMapperProxyFactory;
 import com.baomidou.mybatisplus.core.toolkit.ReflectionKit;
 import com.baomidou.mybatisplus.extension.service.IService;
 import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
+import org.apache.ibatis.session.SqlSession;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -24,19 +28,28 @@ public class ServiceTest {
 
     static class DemoServiceImpl extends ServiceImpl<DemoMapper, Demo> {
 
+        public DemoServiceImpl(BaseMapper<Demo> baseMapper) {
+            super.baseMapper = (DemoMapper) baseMapper;
+        }
     }
 
     static class DemoServiceExtend extends DemoServiceImpl {
 
+        public DemoServiceExtend(BaseMapper<Demo> baseMapper) {
+            super(baseMapper);
+        }
     }
 
     @Test
     @SuppressWarnings("unchecked")
     void genericTest() {
-        IService<Demo>[] services = new IService[]{new DemoServiceImpl(), new DemoServiceExtend()};
+        MybatisMapperProxyFactory<? extends BaseMapper<?>> mybatisMapperProxyFactory = new MybatisMapperProxyFactory<>(DemoMapper.class);
+        BaseMapper<Demo> baseMapper = (BaseMapper<Demo>) mybatisMapperProxyFactory.newInstance(Mockito.mock(SqlSession.class));
+        IService<Demo>[] services = new IService[]{new DemoServiceImpl(baseMapper), new DemoServiceExtend(baseMapper)};
         for (IService<Demo> service : services) {
-            Assertions.assertEquals(Demo.class, service.getEntityClass());
-            Assertions.assertEquals(DemoMapper.class, ReflectionKit.getFieldValue(service, "mapperClass"));
+            ServiceImpl<?,?> impl = (ServiceImpl<?,?>) service;
+            Assertions.assertEquals(Demo.class, impl.getEntityClass());
+            Assertions.assertEquals(DemoMapper.class, impl.getMapperClass());
         }
     }