Ver código fonte

处理批量操作嵌套事物问题.

聂秋秋 5 anos atrás
pai
commit
a09cd7fa4f

+ 11 - 2
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/service/impl/ServiceImpl.java

@@ -32,9 +32,11 @@ import org.apache.ibatis.session.ExecutorType;
 import org.apache.ibatis.session.SqlSession;
 import org.apache.ibatis.session.SqlSessionFactory;
 import org.mybatis.spring.MyBatisExceptionTranslator;
+import org.mybatis.spring.SqlSessionHolder;
 import org.mybatis.spring.SqlSessionUtils;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.transaction.annotation.Transactional;
+import org.springframework.transaction.support.TransactionSynchronizationManager;
 
 import java.io.Serializable;
 import java.util.Collection;
@@ -317,12 +319,19 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
      */
     protected void executeBatch(Consumer<SqlSession> fun) {
         Class<T> tClass = currentModelClass();
-        SqlHelper.clearCache(tClass);
         SqlSessionFactory sqlSessionFactory = SqlHelper.sqlSessionFactory(tClass);
+        SqlSessionHolder sqlSessionHolder = (SqlSessionHolder) TransactionSynchronizationManager.getResource(sqlSessionFactory);
+        boolean transaction = TransactionSynchronizationManager.isSynchronizationActive();
+        if (sqlSessionHolder != null) {
+            SqlSession sqlSession = sqlSessionHolder.getSqlSession();
+            //原生无法支持执行器切换,当存在批量操作时,会嵌套两个session的,优先commit上一个session。
+            sqlSession.commit();
+        }
         SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH);
         try {
             fun.accept(sqlSession);
-            sqlSession.commit();
+            //非事物情况下,强制commit。
+            sqlSession.commit(!transaction);
         } catch (Throwable t) {
             sqlSession.rollback();
             Throwable unwrapped = ExceptionUtil.unwrapThrowable(t);

+ 1 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/toolkit/SqlHelper.java

@@ -131,6 +131,7 @@ public final class SqlHelper {
      *
      * @param clazz 实体类
      */
+    @Deprecated
     public static void clearCache(Class<?> clazz) {
         SqlSessionFactory sqlSessionFactory = GlobalConfigUtils.currentSessionFactory(clazz);
         SqlSessionHolder sqlSessionHolder = (SqlSessionHolder) TransactionSynchronizationManager.getResource(sqlSessionFactory);

+ 30 - 10
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserTest.java

@@ -281,28 +281,48 @@ class H2UserTest extends BaseTest {
     @Order(21)
     void testSaveBatch() {
         Assertions.assertTrue(userService.saveBatch(Arrays.asList(new H2User("saveBatch1"), new H2User("saveBatch2"), new H2User("saveBatch3"), new H2User("saveBatch4"))));
+        Assertions.assertEquals(4, userService.count(new QueryWrapper<H2User>().like("name", "saveBatch")));
         Assertions.assertTrue(userService.saveBatch(Arrays.asList(new H2User("saveBatch5"), new H2User("saveBatch6"), new H2User("saveBatch7"), new H2User("saveBatch8")), 2));
+        Assertions.assertEquals(8, userService.count(new QueryWrapper<H2User>().like("name", "saveBatch")));
     }
 
     @Test
     @Order(22)
     void testUpdateBatch() {
-        Assertions.assertTrue(userService.updateBatchById(Arrays.asList(new H2User(1010L, "batch1010"), new H2User(1011L, "batch1011"), new H2User(1010L, "batch1010"), new H2User(1012L, "batch1012"))));
-        Assertions.assertTrue(userService.updateBatchById(Arrays.asList(new H2User(1010L, "batch1010A"), new H2User(1011L, "batch1011A"), new H2User(1010L, "batch1010"), new H2User(1012L, "batch1012")), 1));
+        Assertions.assertTrue(userService.updateBatchById(Arrays.asList(new H2User(1010L, "batch1010"),
+            new H2User(1011L, "batch1011"), new H2User(1010L, "batch1010"), new H2User(1012L, "batch1012"))));
+        Assertions.assertEquals(userService.getById(1010L).getName(), "batch1010");
+        Assertions.assertEquals(userService.getById(1011L).getName(), "batch1011");
+        Assertions.assertEquals(userService.getById(1012L).getName(), "batch1012");
+        Assertions.assertTrue(userService.updateBatchById(Arrays.asList(new H2User(1010L, "batch1010A"),
+            new H2User(1011L, "batch1011A"), new H2User(1010L, "batch1010"), new H2User(1012L, "batch1012")), 1));
+        Assertions.assertEquals(userService.getById(1010L).getName(), "batch1010");
+        Assertions.assertEquals(userService.getById(1011L).getName(), "batch1011A");
+        Assertions.assertEquals(userService.getById(1012L).getName(), "batch1012");
     }
 
     @Test
     @Order(23)
     void testSaveOrUpdateBatch() {
-        Assertions.assertTrue(userService.saveOrUpdateBatch(Arrays.asList(new H2User(1010L, "batch1010"), new H2User("batch1011"), new H2User(1010L, "batch1010"), new H2User("batch1015"))));
-        Assertions.assertTrue(userService.saveOrUpdateBatch(Arrays.asList(new H2User(1010L, "batch1010A"), new H2User("batch1011A"), new H2User(1010L, "batch1010"), new H2User("batch1016")), 1));
+        Assertions.assertTrue(userService.saveOrUpdateBatch(Arrays.asList(new H2User(1010L, "batch1010"),
+            new H2User("batch1011"), new H2User(1010L, "batch1010"), new H2User("batch1015"))));
+        Assertions.assertEquals(userService.getById(1010L).getName(), "batch1010");
+        Assertions.assertEquals(userService.count(new QueryWrapper<H2User>().eq("name","batch1011")), 1);
+        Assertions.assertEquals(userService.count(new QueryWrapper<H2User>().eq("name","batch1015")), 1);
+        Assertions.assertTrue(userService.saveOrUpdateBatch(Arrays.asList(new H2User(1010L, "batch1010A"),
+            new H2User("batch1011AB"), new H2User(1010L, "batch1010"), new H2User("batch1016")), 1));
+        Assertions.assertEquals(userService.getById(1010L).getName(), "batch1010");
+        Assertions.assertEquals(userService.count(new QueryWrapper<H2User>().eq("name","batch1011AB")), 1);
+        Assertions.assertEquals(userService.count(new QueryWrapper<H2User>().eq("name","batch1016")), 1);
     }
 
     @Test
     @Order(24)
     void testSimpleAndBatch() {
         Assertions.assertTrue(userService.save(new H2User("testSimpleAndBatch1", 0)));
+        Assertions.assertEquals(1, userService.count(new QueryWrapper<H2User>().eq("name", "testSimpleAndBatch1")));
         Assertions.assertTrue(userService.saveOrUpdateBatch(Arrays.asList(new H2User("testSimpleAndBatch2"), new H2User("testSimpleAndBatch3"), new H2User("testSimpleAndBatch4")), 1));
+        Assertions.assertEquals(4, userService.count(new QueryWrapper<H2User>().like("name", "testSimpleAndBatch")));
     }
 
     @Test
@@ -324,10 +344,10 @@ class H2UserTest extends BaseTest {
         Assertions.assertNotEquals(0L, userService.lambdaQuery().like(H2User::getName, "a").count().longValue());
 
         List<H2User> users = userService.lambdaQuery().like(H2User::getName, "T")
-                .ne(H2User::getAge, AgeEnum.TWO)
-                .ge(H2User::getVersion, 1)
-                .isNull(H2User::getPrice)
-                .list();
+            .ne(H2User::getAge, AgeEnum.TWO)
+            .ge(H2User::getVersion, 1)
+            .isNull(H2User::getPrice)
+            .list();
         Assertions.assertTrue(users.isEmpty());
     }
 
@@ -346,8 +366,8 @@ class H2UserTest extends BaseTest {
     void testSaveBatchException() {
         try {
             userService.saveBatch(Arrays.asList(
-                    new H2User(1L, "tom"),
-                    new H2User(1L, "andy")
+                new H2User(1L, "tom"),
+                new H2User(1L, "andy")
             ));
         } catch (Exception e) {
             Assertions.assertTrue(e instanceof DataAccessException);

+ 26 - 4
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/cache/CacheTest.java

@@ -5,10 +5,7 @@ import com.baomidou.mybatisplus.core.metadata.IPage;
 import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
 import com.baomidou.mybatisplus.test.h2.cache.model.CacheModel;
 import com.baomidou.mybatisplus.test.h2.cache.service.ICacheService;
-import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.MethodOrderer;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.TestMethodOrder;
+import org.junit.jupiter.api.*;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.test.context.ContextConfiguration;
@@ -25,6 +22,7 @@ class CacheTest {
     private ICacheService cacheService;
 
     @Test
+    @Order(1)
     void testPageCache() {
         IPage<CacheModel> cacheModelIPage1 = cacheService.page(new Page<>(1, 3), new QueryWrapper<>());
         IPage<CacheModel> cacheModelIPage2 = cacheService.page(new Page<>(1, 3), new QueryWrapper<>());
@@ -60,6 +58,7 @@ class CacheTest {
     }
 
     @Test
+    @Order(2)
     void testCleanBatchCache() {
         CacheModel model = new CacheModel("靓仔");
         cacheService.save(model);
@@ -68,4 +67,27 @@ class CacheTest {
         Assertions.assertEquals(cacheService.getById(model.getId()).getName(),"旺仔");
     }
 
+    @Test
+    @Order(3)
+    void testBatchTransactionalClear1() {
+        long id = cacheService.testBatchTransactionalClear1();
+        CacheModel cacheModel = cacheService.getById(id);
+        Assertions.assertEquals(cacheModel.getName(), "旺仔");
+    }
+
+    @Test
+    @Order(4)
+    void testBatchTransactionalClear2() {
+        long id = cacheService.testBatchTransactionalClear2();
+        CacheModel cacheModel = cacheService.getById(id);
+        Assertions.assertEquals(cacheModel.getName(), "小红");
+    }
+
+    @Test
+    @Order(5)
+    void testBatchTransactionalClear3() {
+        long id = cacheService.testBatchTransactionalClear3();
+        CacheModel cacheModel = cacheService.getById(id);
+        Assertions.assertEquals(cacheModel.getName(), "小红");
+    }
 }

+ 5 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/cache/service/ICacheService.java

@@ -5,4 +5,9 @@ import com.baomidou.mybatisplus.test.h2.cache.model.CacheModel;
 
 public interface ICacheService extends IService<CacheModel> {
 
+    long testBatchTransactionalClear1();
+
+    long testBatchTransactionalClear2();
+
+    long testBatchTransactionalClear3();
 }

+ 36 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/cache/service/impl/CacheServiceImpl.java

@@ -5,8 +5,44 @@ import com.baomidou.mybatisplus.test.h2.cache.mapper.CacheMapper;
 import com.baomidou.mybatisplus.test.h2.cache.model.CacheModel;
 import com.baomidou.mybatisplus.test.h2.cache.service.ICacheService;
 import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
+
+import java.util.Collections;
 
 @Service
 public class CacheServiceImpl extends ServiceImpl<CacheMapper, CacheModel> implements ICacheService {
 
+    @Override
+    @Transactional
+    public long testBatchTransactionalClear1() {
+        CacheModel model = new CacheModel("靓仔");
+        save(model);
+        getById(model.getId());
+        updateBatchById(Collections.singletonList(new CacheModel(model.getId(), "旺仔")));
+        return model.getId();
+    }
+
+    @Override
+    @Transactional
+    public long testBatchTransactionalClear2() {
+        CacheModel model = new CacheModel("靓仔");
+        save(model);
+        getById(model.getId());
+        updateBatchById(Collections.singletonList(new CacheModel(model.getId(), "旺仔")));
+        model.setName("小红");
+        updateById(model);
+        return model.getId();
+    }
+
+    @Override
+    @Transactional
+    public long testBatchTransactionalClear3() {
+        CacheModel model = new CacheModel("靓仔");
+        save(model);
+        updateBatchById(Collections.singletonList(new CacheModel(model.getId(), "旺仔")));
+        model.setName("小红");
+        updateById(model);
+        getById(model.getId());
+        return model.getId();
+    }
 }