miemie 5 yıl önce
ebeveyn
işleme
96109b6713

+ 14 - 9
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/BaseDbTest.java

@@ -14,6 +14,7 @@ import org.apache.ibatis.plugin.Interceptor;
 import org.apache.ibatis.session.SqlSession;
 import org.apache.ibatis.session.SqlSessionFactory;
 import org.apache.ibatis.transaction.jdbc.JdbcTransactionFactory;
+import org.apache.ibatis.type.TypeReference;
 import org.h2.Driver;
 import org.springframework.core.io.ClassPathResource;
 import org.springframework.jdbc.core.JdbcTemplate;
@@ -28,10 +29,12 @@ import java.util.function.Consumer;
  * @author miemie
  * @since 2020-06-23
  */
-public abstract class BaseDbTest {
+public abstract class BaseDbTest<T> extends TypeReference<T> {
 
     protected SqlSessionFactory sqlSessionFactory;
+    protected Class<T> mapper;
 
+    @SuppressWarnings("unchecked")
     public BaseDbTest() {
         DataSource ds = dataSource();
         List<String> tableSql = tableSql();
@@ -39,7 +42,7 @@ public abstract class BaseDbTest {
         String mapperXml = mapperXml();
         GlobalConfig globalConfig = globalConfig();
         List<Interceptor> interceptors = interceptors();
-        List<Class<?>> mappers = mappers();
+        mapper = (Class<T>) getRawType();
 
         JdbcTemplate template = new JdbcTemplate(ds);
         if (CollectionUtils.isNotEmpty(tableSql)) {
@@ -68,7 +71,7 @@ public abstract class BaseDbTest {
                 throw ExceptionUtils.mpe(e);
             }
         }
-        mappers.forEach(configuration::addMapper);
+        configuration.addMapper(mapper);
         if (CollectionUtils.isNotEmpty(interceptors)) {
             interceptors.forEach(configuration::addInterceptor);
         }
@@ -84,13 +87,17 @@ public abstract class BaseDbTest {
         return dataSource;
     }
 
-    protected <T> void doTest(Class<T> mapper, Consumer<T> consumer) {
-        try (SqlSession sqlSession = sqlSessionFactory.openSession(true)) {
-            doTest(sqlSession, mapper, consumer);
+    protected SqlSession autoCommitSession() {
+        return sqlSessionFactory.openSession(true);
+    }
+
+    protected void doTestAutoCommit(Consumer<T> consumer) {
+        try (SqlSession sqlSession = autoCommitSession()) {
+            doTest(sqlSession, consumer);
         }
     }
 
-    protected <T> void doTest(SqlSession sqlSession, Class<T> mapper, Consumer<T> consumer) {
+    protected void doTest(SqlSession sqlSession, Consumer<T> consumer) {
         T t = sqlSession.getMapper(mapper);
         consumer.accept(t);
     }
@@ -111,8 +118,6 @@ public abstract class BaseDbTest {
         return null;
     }
 
-    protected abstract List<Class<?>> mappers();
-
     protected GlobalConfig globalConfig() {
         return GlobalConfigUtils.defaults();
     }

+ 7 - 10
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/pagination/PaginationTest.java

@@ -19,14 +19,14 @@ import static org.assertj.core.api.Assertions.assertThat;
  * @author miemie
  * @since 2020-06-23
  */
-public class PaginationTest extends BaseDbTest {
+public class PaginationTest extends BaseDbTest<EntityMapper> {
 
     @Test
     void page() {
         Cache cache = sqlSessionFactory.getConfiguration().getCache(EntityMapper.class.getName());
         assertThat(cache).as("使用 @CacheNamespace 指定了使用缓存").isNotNull();
 
-        doTest(EntityMapper.class, m -> {
+        doTestAutoCommit(m -> {
             Page<Entity> page = new Page<>(1, 5);
             IPage<Entity> result = m.selectPage(page, null);
             assertThat(page).isEqualTo(result);
@@ -36,7 +36,7 @@ public class PaginationTest extends BaseDbTest {
         assertThat(cache.getSize()).as("一条count缓存一条分页缓存").isEqualTo(2);
 
 
-        doTest(EntityMapper.class, m -> {
+        doTestAutoCommit(m -> {
             Page<Entity> page = new Page<>(1, 5);
             IPage<Entity> result = m.selectPage(page, null);
             assertThat(page).isEqualTo(result);
@@ -45,10 +45,12 @@ public class PaginationTest extends BaseDbTest {
         });
         assertThat(cache.getSize()).as("一条count缓存一条分页缓存").isEqualTo(2);
 
-        doTest(EntityMapper.class, m -> m.insert(new Entity()));
+
+        doTestAutoCommit(m -> m.insert(new Entity()));
         assertThat(cache.getSize()).as("update 操作清除了所有缓存").isEqualTo(0);
 
-        doTest(EntityMapper.class, m -> {
+
+        doTestAutoCommit(m -> {
             Page<Entity> page = new Page<>(1, 5);
             IPage<Entity> result = m.selectPage(page, null);
             assertThat(page).isEqualTo(result);
@@ -65,11 +67,6 @@ public class PaginationTest extends BaseDbTest {
         return Collections.singletonList(interceptor);
     }
 
-    @Override
-    protected List<Class<?>> mappers() {
-        return Collections.singletonList(EntityMapper.class);
-    }
-
     @Override
     protected String tableDataSql() {
         return "insert into entity(id,name) values(1,'1');\n" +