Jelajahi Sumber

优化批量操作方法.

聂秋秋 5 tahun lalu
induk
melakukan
1e78257dec

+ 49 - 3
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/service/impl/ServiceImpl.java

@@ -43,6 +43,7 @@ import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -127,7 +128,7 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
     @Override
     public boolean saveBatch(Collection<T> entityList, int batchSize) {
         String sqlStatement = sqlStatement(SqlMethod.INSERT_ONE);
-        return executeBatch(sqlSession -> execute(sqlSession, entityList, batchSize, entity -> sqlSession.insert(sqlStatement, entity)));
+        return executeBatch(entityList, batchSize, (sqlSession, entity) -> sqlSession.insert(sqlStatement, entity));
     }
 
     /**
@@ -160,7 +161,7 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
         Assert.notNull(tableInfo, "error: can not execute. because can not find cache of TableInfo for entity!");
         String keyProperty = tableInfo.getKeyProperty();
         Assert.notEmpty(keyProperty, "error: can not execute. because can not find column for id from entity!");
-        return executeBatch(sqlSession -> execute(sqlSession, entityList, batchSize, entity -> {
+        return executeBatch(entityList, batchSize, ((sqlSession, entity) -> {
             Object idVal = ReflectionKit.getMethodValue(cls, entity, keyProperty);
             if (StringUtils.checkValNull(idVal) || Objects.isNull(getById((Serializable) idVal))) {
                 sqlSession.insert(sqlStatement(SqlMethod.INSERT_ONE), entity);
@@ -211,7 +212,7 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
     public boolean updateBatchById(Collection<T> entityList, int batchSize) {
         Assert.notEmpty(entityList, "error: entityList must not be empty");
         String sqlStatement = sqlStatement(SqlMethod.UPDATE_BY_ID);
-        return executeBatch(sqlSession -> execute(sqlSession, entityList, batchSize, entity -> {
+        return executeBatch(entityList, batchSize, ((sqlSession, entity) -> {
             MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
             param.put(Constants.ENTITY, entity);
             sqlSession.update(sqlStatement, param);
@@ -286,7 +287,9 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
      *
      * @param fun fun
      * @since 3.3.0
+     * @deprecated 3.3.1
      */
+    @Deprecated
     protected boolean executeBatch(Consumer<SqlSession> fun) {
         Class<T> tClass = currentModelClass();
         SqlSessionFactory sqlSessionFactory = SqlHelper.sqlSessionFactory(tClass);
@@ -321,6 +324,48 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
         }
     }
 
+    protected <E> boolean executeBatch(Collection<E> entityList, int batchSize, BiConsumer<SqlSession, E> consumer) {
+        Class<T> tClass = currentModelClass();
+        SqlSessionFactory sqlSessionFactory = SqlHelper.sqlSessionFactory(tClass);
+        SqlSessionHolder sqlSessionHolder = (SqlSessionHolder) TransactionSynchronizationManager.getResource(sqlSessionFactory);
+        boolean transaction = TransactionSynchronizationManager.isSynchronizationActive();
+        if (sqlSessionHolder != null) {
+            SqlSession sqlSession = sqlSessionHolder.getSqlSession();
+            //原生无法支持执行器切换,当存在批量操作时,会嵌套两个session的,优先commit上一个session
+            //按道理来说,这里的值应该一直为false。
+            sqlSession.commit(!transaction);
+        }
+        SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH);
+        if (!transaction) {
+            log.warn("SqlSession [" + sqlSession + "] was not registered for synchronization because DataSource is not transactional");
+        }
+        try {
+            int size = entityList.size();
+            int i = 1;
+            for (E entity : entityList) {
+                consumer.accept(sqlSession, entity);
+                if ((i % batchSize == 0) || i == size) {
+                    sqlSession.flushStatements();
+                }
+                i++;
+            }
+            //非事物情况下,强制commit。
+            sqlSession.commit(!transaction);
+            return true;
+        } catch (Throwable t) {
+            sqlSession.rollback();
+            Throwable unwrapped = ExceptionUtil.unwrapThrowable(t);
+            if (unwrapped instanceof RuntimeException) {
+                MyBatisExceptionTranslator myBatisExceptionTranslator
+                    = new MyBatisExceptionTranslator(sqlSessionFactory.getConfiguration().getEnvironment().getDataSource(), true);
+                throw Objects.requireNonNull(myBatisExceptionTranslator.translateExceptionIfPossible((RuntimeException) unwrapped));
+            }
+            throw ExceptionUtils.mpe(unwrapped);
+        } finally {
+            sqlSession.close();
+        }
+    }
+
     /**
      * 执行批量操作
      *
@@ -330,6 +375,7 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
      * @param consumer   执行方法
      * @since 3.3.1
      */
+    @Deprecated
     protected <E> void execute(SqlSession sqlSession, Collection<E> entityList, int batchSize, Consumer<E> consumer) {
         int size = entityList.size();
         int i = 1;

+ 1 - 1
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/cache/service/impl/CacheServiceImpl.java

@@ -17,7 +17,7 @@ public class CacheServiceImpl extends ServiceImpl<CacheMapper, CacheModel> imple
     //手动撸一个批量删除.
     private void removeBatchById(Collection<Long> idList) {
         String sqlStatement = sqlStatement(SqlMethod.DELETE_BY_ID);
-        executeBatch(sqlSession -> execute(sqlSession, idList, idList.size(), id -> sqlSession.delete(sqlStatement, id)));
+        executeBatch(idList,idList.size(),(sqlSession, id) -> sqlSession.delete(sqlStatement,id));
     }
 
     @Override