浏览代码

抽取批量执行代码,方便用户自定义.

聂秋秋 5 年之前
父节点
当前提交
30d54334f3

+ 38 - 46
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/service/impl/ServiceImpl.java

@@ -127,18 +127,7 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
     @Override
     @Override
     public boolean saveBatch(Collection<T> entityList, int batchSize) {
     public boolean saveBatch(Collection<T> entityList, int batchSize) {
         String sqlStatement = sqlStatement(SqlMethod.INSERT_ONE);
         String sqlStatement = sqlStatement(SqlMethod.INSERT_ONE);
-        int size = entityList.size();
-        executeBatch(sqlSession -> {
-            int i = 1;
-            for (T entity : entityList) {
-                sqlSession.insert(sqlStatement, entity);
-                if ((i % batchSize == 0) || i == size) {
-                    sqlSession.flushStatements();
-                }
-                i++;
-            }
-        });
-        return true;
+        return executeBatch(sqlSession -> execute(sqlSession, entityList, batchSize, entity -> sqlSession.insert(sqlStatement, entity)));
     }
     }
 
 
     /**
     /**
@@ -171,26 +160,16 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
         Assert.notNull(tableInfo, "error: can not execute. because can not find cache of TableInfo for entity!");
         Assert.notNull(tableInfo, "error: can not execute. because can not find cache of TableInfo for entity!");
         String keyProperty = tableInfo.getKeyProperty();
         String keyProperty = tableInfo.getKeyProperty();
         Assert.notEmpty(keyProperty, "error: can not execute. because can not find column for id from entity!");
         Assert.notEmpty(keyProperty, "error: can not execute. because can not find column for id from entity!");
-        int size = entityList.size();
-        executeBatch(sqlSession -> {
-            int i = 1;
-            for (T entity : entityList) {
-                Object idVal = ReflectionKit.getMethodValue(cls, entity, keyProperty);
-                if (StringUtils.checkValNull(idVal) || Objects.isNull(getById((Serializable) idVal))) {
-                    sqlSession.insert(sqlStatement(SqlMethod.INSERT_ONE), entity);
-                } else {
-                    MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
-                    param.put(Constants.ENTITY, entity);
-                    sqlSession.update(sqlStatement(SqlMethod.UPDATE_BY_ID), param);
-                }
-                // 不知道以后会不会有人说更新失败了还要执行插入 😂😂😂
-                if ((i % batchSize == 0) || i == size) {
-                    sqlSession.flushStatements();
-                }
-                i++;
+        return executeBatch(sqlSession -> execute(sqlSession, entityList, batchSize, entity -> {
+            Object idVal = ReflectionKit.getMethodValue(cls, entity, keyProperty);
+            if (StringUtils.checkValNull(idVal) || Objects.isNull(getById((Serializable) idVal))) {
+                sqlSession.insert(sqlStatement(SqlMethod.INSERT_ONE), entity);
+            } else {
+                MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
+                param.put(Constants.ENTITY, entity);
+                sqlSession.update(sqlStatement(SqlMethod.UPDATE_BY_ID), param);
             }
             }
-        });
-        return true;
+        }));
     }
     }
 
 
     @Override
     @Override
@@ -232,20 +211,11 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
     public boolean updateBatchById(Collection<T> entityList, int batchSize) {
     public boolean updateBatchById(Collection<T> entityList, int batchSize) {
         Assert.notEmpty(entityList, "error: entityList must not be empty");
         Assert.notEmpty(entityList, "error: entityList must not be empty");
         String sqlStatement = sqlStatement(SqlMethod.UPDATE_BY_ID);
         String sqlStatement = sqlStatement(SqlMethod.UPDATE_BY_ID);
-        int size = entityList.size();
-        executeBatch(sqlSession -> {
-            int i = 1;
-            for (T anEntityList : entityList) {
-                MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
-                param.put(Constants.ENTITY, anEntityList);
-                sqlSession.update(sqlStatement, param);
-                if ((i % batchSize == 0) || i == size) {
-                    sqlSession.flushStatements();
-                }
-                i++;
-            }
-        });
-        return true;
+        return executeBatch(sqlSession -> execute(sqlSession, entityList, batchSize, entity -> {
+            MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
+            param.put(Constants.ENTITY, entity);
+            sqlSession.update(sqlStatement, param);
+        }));
     }
     }
 
 
     @Override
     @Override
@@ -317,7 +287,7 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
      * @param fun fun
      * @param fun fun
      * @since 3.3.0
      * @since 3.3.0
      */
      */
-    protected void executeBatch(Consumer<SqlSession> fun) {
+    protected boolean executeBatch(Consumer<SqlSession> fun) {
         Class<T> tClass = currentModelClass();
         Class<T> tClass = currentModelClass();
         SqlSessionFactory sqlSessionFactory = SqlHelper.sqlSessionFactory(tClass);
         SqlSessionFactory sqlSessionFactory = SqlHelper.sqlSessionFactory(tClass);
         SqlSessionHolder sqlSessionHolder = (SqlSessionHolder) TransactionSynchronizationManager.getResource(sqlSessionFactory);
         SqlSessionHolder sqlSessionHolder = (SqlSessionHolder) TransactionSynchronizationManager.getResource(sqlSessionFactory);
@@ -336,6 +306,7 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
             fun.accept(sqlSession);
             fun.accept(sqlSession);
             //非事物情况下,强制commit。
             //非事物情况下,强制commit。
             sqlSession.commit(!transaction);
             sqlSession.commit(!transaction);
+            return true;
         } catch (Throwable t) {
         } catch (Throwable t) {
             sqlSession.rollback();
             sqlSession.rollback();
             Throwable unwrapped = ExceptionUtil.unwrapThrowable(t);
             Throwable unwrapped = ExceptionUtil.unwrapThrowable(t);
@@ -349,4 +320,25 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
             sqlSession.close();
             sqlSession.close();
         }
         }
     }
     }
+
+    /**
+     * 执行批量操作
+     *
+     * @param sqlSession sqlSession
+     * @param entityList 数据集合
+     * @param batchSize  批量大小
+     * @param consumer   执行方法
+     * @since 3.3.1
+     */
+    protected <E> void execute(SqlSession sqlSession, Collection<E> entityList, int batchSize, Consumer<E> consumer) {
+        int size = entityList.size();
+        int i = 1;
+        for (E entity : entityList) {
+            consumer.accept(entity);
+            if ((i % batchSize == 0) || i == size) {
+                sqlSession.flushStatements();
+            }
+            i++;
+        }
+    }
 }
 }

+ 1 - 1
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/cache/service/impl/CacheServiceImpl.java

@@ -17,7 +17,7 @@ public class CacheServiceImpl extends ServiceImpl<CacheMapper, CacheModel> imple
     //手动撸一个批量删除.
     //手动撸一个批量删除.
     private void removeBatchById(Collection<Long> idList) {
     private void removeBatchById(Collection<Long> idList) {
         String sqlStatement = sqlStatement(SqlMethod.DELETE_BY_ID);
         String sqlStatement = sqlStatement(SqlMethod.DELETE_BY_ID);
-        executeBatch(sqlSession -> idList.forEach(id -> sqlSession.delete(sqlStatement, id)));
+        executeBatch(sqlSession -> execute(sqlSession, idList, idList.size(), id -> sqlSession.delete(sqlStatement, id)));
     }
     }
 
 
     @Override
     @Override