Browse Source

批量更新批次数据提交切割处理 (#6086)

VampireAchao 1 year ago
parent
commit
661341fa2d

+ 20 - 6
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/batch/MybatisBatch.java

@@ -17,6 +17,7 @@ package com.baomidou.mybatisplus.core.batch;
 
 
 import com.baomidou.mybatisplus.core.conditions.Wrapper;
 import com.baomidou.mybatisplus.core.conditions.Wrapper;
 import com.baomidou.mybatisplus.core.enums.SqlMethod;
 import com.baomidou.mybatisplus.core.enums.SqlMethod;
+import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.core.toolkit.Constants;
 import com.baomidou.mybatisplus.core.toolkit.Constants;
 import com.baomidou.mybatisplus.core.toolkit.StringPool;
 import com.baomidou.mybatisplus.core.toolkit.StringPool;
 import org.apache.ibatis.executor.BatchResult;
 import org.apache.ibatis.executor.BatchResult;
@@ -61,9 +62,18 @@ public class MybatisBatch<T> {
 
 
     private final Collection<T> dataList;
     private final Collection<T> dataList;
 
 
+    private final int batchSize;
+
     public MybatisBatch(SqlSessionFactory sqlSessionFactory, Collection<T> dataList) {
     public MybatisBatch(SqlSessionFactory sqlSessionFactory, Collection<T> dataList) {
         this.sqlSessionFactory = sqlSessionFactory;
         this.sqlSessionFactory = sqlSessionFactory;
         this.dataList = dataList;
         this.dataList = dataList;
+        this.batchSize = Constants.DEFAULT_BATCH_SIZE;
+    }
+
+    public MybatisBatch(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, int batchSize) {
+        this.sqlSessionFactory = sqlSessionFactory;
+        this.dataList = dataList;
+        this.batchSize = batchSize;
     }
     }
 
 
     /**
     /**
@@ -129,13 +139,17 @@ public class MybatisBatch<T> {
      * @return 批处理结果
      * @return 批处理结果
      */
      */
     public List<BatchResult> execute(boolean autoCommit, String statement, ParameterConvert<T> parameterConvert) {
     public List<BatchResult> execute(boolean autoCommit, String statement, ParameterConvert<T> parameterConvert) {
+        List<BatchResult> resultList = new ArrayList<>(dataList.size());
         try (SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH, autoCommit)) {
         try (SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH, autoCommit)) {
-            for (T data : dataList) {
-                sqlSession.update(statement, toParameter(parameterConvert, data));
-            }
-            List<BatchResult> resultList = sqlSession.flushStatements();
-            if(!autoCommit) {
-                sqlSession.commit();
+            List<List<T>> split = CollectionUtils.split(dataList, batchSize);
+            for (List<T> splitedList : split) {
+                for (T data : splitedList) {
+                    sqlSession.update(statement, toParameter(parameterConvert, data));
+                }
+                resultList.addAll(sqlSession.flushStatements());
+                if (!autoCommit) {
+                    sqlSession.commit();
+                }
             }
             }
             return resultList;
             return resultList;
         }
         }

+ 50 - 12
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/mapper/BaseMapper.java

@@ -25,14 +25,7 @@ import com.baomidou.mybatisplus.core.metadata.IPage;
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
 import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
 import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
 import com.baomidou.mybatisplus.core.override.MybatisMapperProxy;
 import com.baomidou.mybatisplus.core.override.MybatisMapperProxy;
-import com.baomidou.mybatisplus.core.toolkit.Assert;
-import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
-import com.baomidou.mybatisplus.core.toolkit.Constants;
-import com.baomidou.mybatisplus.core.toolkit.MybatisBatchUtils;
-import com.baomidou.mybatisplus.core.toolkit.MybatisUtils;
-import com.baomidou.mybatisplus.core.toolkit.StringPool;
-import com.baomidou.mybatisplus.core.toolkit.StringUtils;
-import com.baomidou.mybatisplus.core.toolkit.Wrappers;
+import com.baomidou.mybatisplus.core.toolkit.*;
 import com.baomidou.mybatisplus.core.toolkit.reflect.GenericTypeUtils;
 import com.baomidou.mybatisplus.core.toolkit.reflect.GenericTypeUtils;
 import org.apache.ibatis.annotations.Param;
 import org.apache.ibatis.annotations.Param;
 import org.apache.ibatis.exceptions.TooManyResultsException;
 import org.apache.ibatis.exceptions.TooManyResultsException;
@@ -43,6 +36,7 @@ import org.apache.ibatis.session.SqlSession;
 import org.apache.ibatis.session.SqlSessionFactory;
 import org.apache.ibatis.session.SqlSessionFactory;
 
 
 import java.io.Serializable;
 import java.io.Serializable;
+import java.lang.reflect.Proxy;
 import java.util.ArrayList;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashMap;
@@ -461,10 +455,21 @@ public interface BaseMapper<T> extends Mapper<T> {
      * @since 3.5.7
      * @since 3.5.7
      */
      */
     default List<BatchResult> saveBatch(Collection<T> entityList) {
     default List<BatchResult> saveBatch(Collection<T> entityList) {
+        return saveBatch(entityList, Constants.DEFAULT_BATCH_SIZE);
+    }
+
+    /**
+     * 插入(批量)
+     *
+     * @param entityList 实体对象集合
+     * @param batchSize  插入批次数量
+     * @since 3.5.7
+     */
+    default List<BatchResult> saveBatch(Collection<T> entityList, int batchSize) {
         MybatisMapperProxy<?> mybatisMapperProxy = MybatisUtils.getMybatisMapperProxy(this);
         MybatisMapperProxy<?> mybatisMapperProxy = MybatisUtils.getMybatisMapperProxy(this);
         MybatisBatch.Method<T> method = new MybatisBatch.Method<>(mybatisMapperProxy.getMapperInterface());
         MybatisBatch.Method<T> method = new MybatisBatch.Method<>(mybatisMapperProxy.getMapperInterface());
         SqlSessionFactory sqlSessionFactory = MybatisUtils.getSqlSessionFactory(mybatisMapperProxy);
         SqlSessionFactory sqlSessionFactory = MybatisUtils.getSqlSessionFactory(mybatisMapperProxy);
-        return MybatisBatchUtils.execute(sqlSessionFactory, entityList, method.insert());
+        return MybatisBatchUtils.execute(sqlSessionFactory, entityList, method.insert(), batchSize);
     }
     }
 
 
     /**
     /**
@@ -474,10 +479,21 @@ public interface BaseMapper<T> extends Mapper<T> {
      * @since 3.5.7
      * @since 3.5.7
      */
      */
     default List<BatchResult> updateBatchById(Collection<T> entityList) {
     default List<BatchResult> updateBatchById(Collection<T> entityList) {
+        return updateBatchById(entityList, Constants.DEFAULT_BATCH_SIZE);
+    }
+
+    /**
+     * 根据ID 批量更新
+     *
+     * @param entityList 实体对象集合
+     * @param batchSize  插入批次数量
+     * @since 3.5.7
+     */
+    default List<BatchResult> updateBatchById(Collection<T> entityList, int batchSize) {
         MybatisMapperProxy<?> mybatisMapperProxy = MybatisUtils.getMybatisMapperProxy(this);
         MybatisMapperProxy<?> mybatisMapperProxy = MybatisUtils.getMybatisMapperProxy(this);
         MybatisBatch.Method<T> method = new MybatisBatch.Method<>(mybatisMapperProxy.getMapperInterface());
         MybatisBatch.Method<T> method = new MybatisBatch.Method<>(mybatisMapperProxy.getMapperInterface());
         SqlSessionFactory sqlSessionFactory = MybatisUtils.getSqlSessionFactory(mybatisMapperProxy);
         SqlSessionFactory sqlSessionFactory = MybatisUtils.getSqlSessionFactory(mybatisMapperProxy);
-        return MybatisBatchUtils.execute(sqlSessionFactory, entityList, method.updateById());
+        return MybatisBatchUtils.execute(sqlSessionFactory, entityList, method.updateById(), batchSize);
     }
     }
 
 
     /**
     /**
@@ -487,6 +503,17 @@ public interface BaseMapper<T> extends Mapper<T> {
      * @since 3.5.7
      * @since 3.5.7
      */
      */
     default List<BatchResult> saveOrUpdateBatch(Collection<T> entityList) {
     default List<BatchResult> saveOrUpdateBatch(Collection<T> entityList) {
+        return saveOrUpdateBatch(entityList, Constants.DEFAULT_BATCH_SIZE);
+    }
+
+    /**
+     * 批量修改或插入
+     *
+     * @param entityList 实体对象集合
+     * @param batchSize  插入批次数量
+     * @since 3.5.7
+     */
+    default List<BatchResult> saveOrUpdateBatch(Collection<T> entityList, int batchSize) {
         MybatisMapperProxy<?> mybatisMapperProxy = MybatisUtils.getMybatisMapperProxy(this);
         MybatisMapperProxy<?> mybatisMapperProxy = MybatisUtils.getMybatisMapperProxy(this);
         Class<?> entityClass = GenericTypeUtils.resolveTypeArguments(getClass(), BaseMapper.class)[0];
         Class<?> entityClass = GenericTypeUtils.resolveTypeArguments(getClass(), BaseMapper.class)[0];
         TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
         TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
@@ -495,7 +522,7 @@ public interface BaseMapper<T> extends Mapper<T> {
         return saveOrUpdateBatch(entityList, (sqlSession, entity) -> {
         return saveOrUpdateBatch(entityList, (sqlSession, entity) -> {
             Object idVal = tableInfo.getPropertyValue(entity, keyProperty);
             Object idVal = tableInfo.getPropertyValue(entity, keyProperty);
             return StringUtils.checkValNull(idVal) || CollectionUtils.isEmpty(sqlSession.selectList(statement, entity));
             return StringUtils.checkValNull(idVal) || CollectionUtils.isEmpty(sqlSession.selectList(statement, entity));
-        });
+        }, batchSize);
     }
     }
 
 
     /**
     /**
@@ -505,10 +532,21 @@ public interface BaseMapper<T> extends Mapper<T> {
      * @since 3.5.7
      * @since 3.5.7
      */
      */
     default List<BatchResult> saveOrUpdateBatch(Collection<T> entityList, BiPredicate<BatchSqlSession, T> insertPredicate) {
     default List<BatchResult> saveOrUpdateBatch(Collection<T> entityList, BiPredicate<BatchSqlSession, T> insertPredicate) {
+        return saveOrUpdateBatch(entityList, insertPredicate, Constants.DEFAULT_BATCH_SIZE);
+    }
+
+    /**
+     * 批量修改或插入
+     *
+     * @param entityList 实体对象集合
+     * @param batchSize  插入批次数量
+     * @since 3.5.7
+     */
+    default List<BatchResult> saveOrUpdateBatch(Collection<T> entityList, BiPredicate<BatchSqlSession, T> insertPredicate, int batchSize) {
         MybatisMapperProxy<?> mybatisMapperProxy = MybatisUtils.getMybatisMapperProxy(this);
         MybatisMapperProxy<?> mybatisMapperProxy = MybatisUtils.getMybatisMapperProxy(this);
         MybatisBatch.Method<T> method = new MybatisBatch.Method<>(mybatisMapperProxy.getMapperInterface());
         MybatisBatch.Method<T> method = new MybatisBatch.Method<>(mybatisMapperProxy.getMapperInterface());
         SqlSessionFactory sqlSessionFactory = MybatisUtils.getSqlSessionFactory(mybatisMapperProxy);
         SqlSessionFactory sqlSessionFactory = MybatisUtils.getSqlSessionFactory(mybatisMapperProxy);
-        return MybatisBatchUtils.saveOrUpdate(sqlSessionFactory, entityList, method.insert(), insertPredicate, method.updateById());
+        return MybatisBatchUtils.saveOrUpdate(sqlSessionFactory, entityList, method.insert(), insertPredicate, method.updateById(), batchSize);
     }
     }
 
 
 }
 }

+ 26 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/CollectionUtils.java

@@ -21,11 +21,14 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.List;
 import java.util.List;
 import java.util.Map;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.Optional;
 import java.util.function.Function;
 import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 
 /**
 /**
  * Collection工具类
  * Collection工具类
@@ -232,4 +235,27 @@ public class CollectionUtils {
         return Collections.emptyList();
         return Collections.emptyList();
     }
     }
 
 
+    /**
+     * 切割集合为多个集合
+     * @param entityList 数据集合
+     * @param batchSize 每批集合的大小
+     * @return 切割后的多个集合
+     * @param <T> 数据类型
+     */
+    public static <T> List<List<T>> split(Collection<T> entityList, int batchSize) {
+        if (isEmpty(entityList)) {
+            return Collections.emptyList();
+        }
+        Assert.isFalse(batchSize < 1, "batchSize must not be less than one");
+        final Iterator<T> iterator = entityList.iterator();
+        final List<List<T>> results = new ArrayList<>(entityList.size() / batchSize);
+        while (iterator.hasNext()) {
+            final List<T> list = IntStream.range(0, batchSize).filter(x -> iterator.hasNext())
+                .mapToObj(i -> iterator.next()).collect(Collectors.toList());
+            if (!list.isEmpty()) {
+                results.add(list);
+            }
+        }
+        return results;
+    }
 }
 }

+ 6 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/Constants.java

@@ -179,4 +179,10 @@ public interface Constants extends StringPool, Serializable {
      */
      */
     String WRAPPER_PARAM = "MPGENVAL";
     String WRAPPER_PARAM = "MPGENVAL";
     String WRAPPER_PARAM_MIDDLE = ".paramNameValuePairs" + DOT;
     String WRAPPER_PARAM_MIDDLE = ".paramNameValuePairs" + DOT;
+
+
+    /**
+     * 默认批次提交数量
+     */
+    int DEFAULT_BATCH_SIZE = 1000;
 }
 }

+ 130 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/MybatisBatchUtils.java

@@ -45,6 +45,20 @@ public class MybatisBatchUtils {
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(statement);
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(statement);
     }
     }
 
 
+    /**
+     * 执行批量操作
+     *
+     * @param sqlSessionFactory {@link SqlSessionFactory}
+     * @param dataList          数据集列表
+     * @param statement         执行的 mapper 方法 (示例: com.baomidou.mybatisplus.core.mapper.BaseMapper.insert )
+     * @param <T>               泛型
+     * @param batchSize         插入批次数量
+     * @return 批处理结果
+     */
+    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, String statement, int batchSize) {
+        return new MybatisBatch<>(sqlSessionFactory, dataList, batchSize).execute(statement);
+    }
+
     /**
     /**
      * 执行批量操作
      * 执行批量操作
      *
      *
@@ -59,6 +73,21 @@ public class MybatisBatchUtils {
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(statement, parameterConvert);
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(statement, parameterConvert);
     }
     }
 
 
+    /**
+     * 执行批量操作
+     *
+     * @param sqlSessionFactory {@link SqlSessionFactory}
+     * @param dataList          数据集列表
+     * @param statement         执行的 mapper 方法 (示例: com.baomidou.mybatisplus.core.mapper.BaseMapper.insert )
+     * @param parameterConvert  参数转换器
+     * @param <T>               泛型
+     * @param batchSize         插入批次数量
+     * @return 批处理结果
+     */
+    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, String statement, ParameterConvert<T> parameterConvert, int batchSize) {
+        return new MybatisBatch<>(sqlSessionFactory, dataList, batchSize).execute(statement, parameterConvert);
+    }
+
     /**
     /**
      * 执行批量操作
      * 执行批量操作
      *
      *
@@ -73,6 +102,21 @@ public class MybatisBatchUtils {
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(autoCommit, statement);
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(autoCommit, statement);
     }
     }
 
 
+    /**
+     * 执行批量操作
+     *
+     * @param sqlSessionFactory {@link SqlSessionFactory}
+     * @param dataList          数据集列表
+     * @param autoCommit        是否自动提交(这里生效的前提依赖于事务管理器 {@link org.apache.ibatis.transaction.Transaction})
+     * @param statement         执行的 mapper 方法 (示例: com.baomidou.mybatisplus.core.mapper.BaseMapper.insert )
+     * @param <T>               泛型
+     * @param batchSize         插入批次数量
+     * @return 批处理结果
+     */
+    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, boolean autoCommit, String statement, int batchSize) {
+        return new MybatisBatch<>(sqlSessionFactory, dataList, batchSize).execute(autoCommit, statement);
+    }
+
     /**
     /**
      * 执行批量操作
      * 执行批量操作
      *
      *
@@ -88,6 +132,22 @@ public class MybatisBatchUtils {
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(autoCommit, statement, parameterConvert);
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(autoCommit, statement, parameterConvert);
     }
     }
 
 
+    /**
+     * 执行批量操作
+     *
+     * @param sqlSessionFactory {@link SqlSessionFactory}
+     * @param dataList          数据集列表
+     * @param autoCommit        是否自动提交(这里生效的前提依赖于事务管理器 {@link org.apache.ibatis.transaction.Transaction})
+     * @param statement         执行的 mapper 方法 (示例: com.baomidou.mybatisplus.core.mapper.BaseMapper.insert )
+     * @param parameterConvert  参数转换器
+     * @param <T>               泛型
+     * @param batchSize         插入批次数量
+     * @return 批处理结果
+     */
+    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, boolean autoCommit, String statement, ParameterConvert<T> parameterConvert, int batchSize) {
+        return new MybatisBatch<>(sqlSessionFactory, dataList, batchSize).execute(autoCommit, statement, parameterConvert);
+    }
+
     /**
     /**
      * 执行批量操作
      * 执行批量操作
      *
      *
@@ -101,6 +161,20 @@ public class MybatisBatchUtils {
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(batchMethod);
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(batchMethod);
     }
     }
 
 
+    /**
+     * 执行批量操作
+     *
+     * @param sqlSessionFactory sqlSessionFactory {@link SqlSessionFactory}
+     * @param dataList          数据集列表
+     * @param batchMethod       批量操作方法
+     * @param <T>               泛型
+     * @param batchSize         插入批次数量
+     * @return 批处理结果
+     */
+    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, BatchMethod<T> batchMethod, int batchSize) {
+        return new MybatisBatch<>(sqlSessionFactory, dataList, batchSize).execute(batchMethod);
+    }
+
     /**
     /**
      * 执行批量操作
      * 执行批量操作
      *
      *
@@ -115,6 +189,21 @@ public class MybatisBatchUtils {
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(autoCommit, batchMethod);
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(autoCommit, batchMethod);
     }
     }
 
 
+    /**
+     * 执行批量操作
+     *
+     * @param sqlSessionFactory sqlSessionFactory {@link SqlSessionFactory}
+     * @param dataList          数据集列表
+     * @param autoCommit        是否自动提交(这里生效的前提依赖于事务管理器 {@link org.apache.ibatis.transaction.Transaction})
+     * @param batchMethod       批量操作方法
+     * @param <T>               泛型
+     * @param batchSize         插入批次数量
+     * @return 批处理结果
+     */
+    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, boolean autoCommit, BatchMethod<T> batchMethod, int batchSize) {
+        return new MybatisBatch<>(sqlSessionFactory, dataList, batchSize).execute(autoCommit, batchMethod);
+    }
+
     /**
     /**
      * 批量保存或更新
      * 批量保存或更新
      * 这里需要注意一下,如果在insertPredicate里判断调用其他sqlSession(类似mapper.xxx)时,要注意一级缓存问题或数据感知问题(因为当前会话数据还未提交)
      * 这里需要注意一下,如果在insertPredicate里判断调用其他sqlSession(类似mapper.xxx)时,要注意一级缓存问题或数据感知问题(因为当前会话数据还未提交)
@@ -134,6 +223,26 @@ public class MybatisBatchUtils {
         return new MybatisBatch<>(sqlSessionFactory, dataList).saveOrUpdate(insertMethod, insertPredicate, updateMethod);
         return new MybatisBatch<>(sqlSessionFactory, dataList).saveOrUpdate(insertMethod, insertPredicate, updateMethod);
     }
     }
 
 
+    /**
+     * 批量保存或更新
+     * 这里需要注意一下,如果在insertPredicate里判断调用其他sqlSession(类似mapper.xxx)时,要注意一级缓存问题或数据感知问题(因为当前会话数据还未提交)
+     * 举个例子(事务开启状态下):
+     * 如果当前批次里面执行两个主键相同的数据,当调用mapper.selectById时,如果数据库未有这条记录,在同个sqlSession下,由于一级缓存的问题,下次再查就还是null,导致插入主键冲突,
+     * 但使用 {@link BatchSqlSession}时,由于每次select操作都会触发一次flushStatements,就会执行更新操作
+     *
+     * @param sqlSessionFactory sqlSessionFactory {@link SqlSessionFactory}
+     * @param dataList          数据集列表
+     * @param insertMethod      插入方法
+     * @param insertPredicate   插入条件 (当条件满足时执行插入方法,否则执行更新方法)
+     * @param updateMethod      更新方法
+     * @param <T>               泛型
+     * @param batchSize         插入批次数量
+     * @return 批处理结果
+     */
+    public static <T> List<BatchResult> saveOrUpdate(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, BatchMethod<T> insertMethod, BiPredicate<BatchSqlSession, T> insertPredicate, BatchMethod<T> updateMethod, int batchSize) {
+        return new MybatisBatch<>(sqlSessionFactory, dataList, batchSize).saveOrUpdate(insertMethod, insertPredicate, updateMethod);
+    }
+
     /**
     /**
      * 批量保存或更新
      * 批量保存或更新
      * 这里需要注意一下,如果在insertPredicate里判断调用其他sqlSession(类似mapper.xxx)时,要注意一级缓存问题或数据感知问题(因为当前会话数据还未提交)
      * 这里需要注意一下,如果在insertPredicate里判断调用其他sqlSession(类似mapper.xxx)时,要注意一级缓存问题或数据感知问题(因为当前会话数据还未提交)
@@ -154,5 +263,26 @@ public class MybatisBatchUtils {
         return new MybatisBatch<>(sqlSessionFactory, dataList).saveOrUpdate(autoCommit, insertMethod, insertPredicate, updateMethod);
         return new MybatisBatch<>(sqlSessionFactory, dataList).saveOrUpdate(autoCommit, insertMethod, insertPredicate, updateMethod);
     }
     }
 
 
+    /**
+     * 批量保存或更新
+     * 这里需要注意一下,如果在insertPredicate里判断调用其他sqlSession(类似mapper.xxx)时,要注意一级缓存问题或数据感知问题(因为当前会话数据还未提交)
+     * 举个例子(事务开启状态下):
+     * 如果当前批次里面执行两个主键相同的数据,当调用mapper.selectById时,如果数据库未有这条记录,在同个sqlSession下,由于一级缓存的问题,下次再查就还是null,导致插入主键冲突,
+     * 但使用 {@link BatchSqlSession}时,由于每次select操作都会触发一次flushStatements,就会执行更新操作
+     *
+     * @param sqlSessionFactory sqlSessionFactory {@link SqlSessionFactory}
+     * @param dataList          数据集列表
+     * @param autoCommit        是否自动提交(这里生效的前提依赖于事务管理器 {@link org.apache.ibatis.transaction.Transaction})
+     * @param insertMethod      插入方法
+     * @param insertPredicate   插入条件 (当条件满足时执行插入方法,否则执行更新方法)
+     * @param updateMethod      更新方法
+     * @param <T>               泛型
+     * @param batchSize         插入批次数量
+     * @return 批处理结果
+     */
+    public static <T> List<BatchResult> saveOrUpdate(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, boolean autoCommit, BatchMethod<T> insertMethod, BiPredicate<BatchSqlSession, T> insertPredicate, BatchMethod<T> updateMethod, int batchSize) {
+        return new MybatisBatch<>(sqlSessionFactory, dataList, batchSize).saveOrUpdate(autoCommit, insertMethod, insertPredicate, updateMethod);
+    }
+
 
 
 }
 }

+ 11 - 0
mybatis-plus-core/src/test/java/com/baomidou/mybatisplus/test/toolkit/CollectionUtilsTest.java

@@ -8,9 +8,13 @@ import org.junit.jupiter.api.condition.JRE;
 
 
 import java.lang.reflect.Field;
 import java.lang.reflect.Field;
 import java.util.HashMap;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentHashMap;
 
 
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
+
 /**
 /**
  * @author nieqiuqiu 2020/7/2
  * @author nieqiuqiu 2020/7/2
  */
  */
@@ -46,6 +50,13 @@ class CollectionUtilsTest {
         computeIfAbsent();
         computeIfAbsent();
     }
     }
 
 
+    @Test
+    void testSplit() {
+        List<Integer> list = asList(1, 2, 3, 4, 5);
+        List<List<Integer>> lists = CollectionUtils.split(list, 2);
+        Assertions.assertEquals(asList(asList(1, 2), asList(3, 4), singletonList(5)), lists);
+    }
+
     private Map<String, String> newHashMapWithExpectedSize(int size) {
     private Map<String, String> newHashMapWithExpectedSize(int size) {
         Map<String, String> hashMap = CollectionUtils.newHashMapWithExpectedSize(size);
         Map<String, String> hashMap = CollectionUtils.newHashMapWithExpectedSize(size);
         hashMap.put("1", "1");
         hashMap.put("1", "1");

+ 2 - 1
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/service/IService.java

@@ -20,6 +20,7 @@ import com.baomidou.mybatisplus.core.mapper.BaseMapper;
 import com.baomidou.mybatisplus.core.metadata.IPage;
 import com.baomidou.mybatisplus.core.metadata.IPage;
 import com.baomidou.mybatisplus.core.toolkit.Assert;
 import com.baomidou.mybatisplus.core.toolkit.Assert;
 import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
+import com.baomidou.mybatisplus.core.toolkit.Constants;
 import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import com.baomidou.mybatisplus.extension.conditions.query.ChainQuery;
 import com.baomidou.mybatisplus.extension.conditions.query.ChainQuery;
 import com.baomidou.mybatisplus.extension.conditions.query.LambdaQueryChainWrapper;
 import com.baomidou.mybatisplus.extension.conditions.query.LambdaQueryChainWrapper;
@@ -49,7 +50,7 @@ public interface IService<T> {
     /**
     /**
      * 默认批次提交数量
      * 默认批次提交数量
      */
      */
-    int DEFAULT_BATCH_SIZE = 1000;
+    int DEFAULT_BATCH_SIZE = Constants.DEFAULT_BATCH_SIZE;
 
 
     /**
     /**
      * 插入一条记录(选择字段,策略插入)
      * 插入一条记录(选择字段,策略插入)

+ 12 - 26
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/toolkit/Db.java

@@ -16,11 +16,13 @@
 package com.baomidou.mybatisplus.extension.toolkit;
 package com.baomidou.mybatisplus.extension.toolkit;
 
 
 import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
 import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
-import com.baomidou.mybatisplus.core.enums.SqlMethod;
 import com.baomidou.mybatisplus.core.metadata.IPage;
 import com.baomidou.mybatisplus.core.metadata.IPage;
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
 import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
 import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
-import com.baomidou.mybatisplus.core.toolkit.*;
+import com.baomidou.mybatisplus.core.toolkit.Assert;
+import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
+import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
+import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
 import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
 import com.baomidou.mybatisplus.extension.conditions.query.LambdaQueryChainWrapper;
 import com.baomidou.mybatisplus.extension.conditions.query.LambdaQueryChainWrapper;
 import com.baomidou.mybatisplus.extension.conditions.query.QueryChainWrapper;
 import com.baomidou.mybatisplus.extension.conditions.query.QueryChainWrapper;
@@ -29,13 +31,14 @@ import com.baomidou.mybatisplus.extension.conditions.update.UpdateChainWrapper;
 import com.baomidou.mybatisplus.extension.kotlin.KtQueryChainWrapper;
 import com.baomidou.mybatisplus.extension.kotlin.KtQueryChainWrapper;
 import com.baomidou.mybatisplus.extension.kotlin.KtUpdateChainWrapper;
 import com.baomidou.mybatisplus.extension.kotlin.KtUpdateChainWrapper;
 import com.baomidou.mybatisplus.extension.service.IService;
 import com.baomidou.mybatisplus.extension.service.IService;
-import org.apache.ibatis.binding.MapperMethod;
+import org.apache.ibatis.executor.BatchResult;
 import org.apache.ibatis.logging.Log;
 import org.apache.ibatis.logging.Log;
 import org.apache.ibatis.logging.LogFactory;
 import org.apache.ibatis.logging.LogFactory;
 
 
 import java.io.Serializable;
 import java.io.Serializable;
 import java.util.*;
 import java.util.*;
 import java.util.stream.Collectors;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 
 /**
 /**
  * 以静态方式调用Service中的函数
  * 以静态方式调用Service中的函数
@@ -84,9 +87,8 @@ public class Db {
             return false;
             return false;
         }
         }
         Class<T> entityClass = getEntityClass(entityList);
         Class<T> entityClass = getEntityClass(entityList);
-        Class<?> mapperClass = ClassUtils.toClassConfident(getTableInfo(entityClass).getCurrentNamespace());
-        String sqlStatement = SqlHelper.getSqlStatement(mapperClass, SqlMethod.INSERT_ONE);
-        return SqlHelper.executeBatch(entityClass, log, entityList, batchSize, (sqlSession, entity) -> sqlSession.insert(sqlStatement, entity));
+        List<BatchResult> batchResults = SqlHelper.execute(entityClass, baseMapper -> baseMapper.saveBatch(entityList, batchSize));
+        return batchResults.stream().flatMapToInt(r -> IntStream.of(r.getUpdateCounts())).allMatch(i -> i > 0);
     }
     }
 
 
     /**
     /**
@@ -109,19 +111,8 @@ public class Db {
             return false;
             return false;
         }
         }
         Class<T> entityClass = getEntityClass(entityList);
         Class<T> entityClass = getEntityClass(entityList);
-        TableInfo tableInfo = getTableInfo(entityClass);
-        Class<?> mapperClass = ClassUtils.toClassConfident(tableInfo.getCurrentNamespace());
-        String keyProperty = tableInfo.getKeyProperty();
-        Assert.notEmpty(keyProperty, "error: can not execute. because can not find column for primary key from entity!");
-        return SqlHelper.saveOrUpdateBatch(entityClass, mapperClass, log, entityList, batchSize, (sqlSession, entity) -> {
-            Object idVal = tableInfo.getPropertyValue(entity, keyProperty);
-            return StringUtils.checkValNull(idVal)
-                || CollectionUtils.isEmpty(sqlSession.selectList(SqlHelper.getSqlStatement(mapperClass, SqlMethod.SELECT_BY_ID), entity));
-        }, (sqlSession, entity) -> {
-            MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
-            param.put(Constants.ENTITY, entity);
-            sqlSession.update(SqlHelper.getSqlStatement(mapperClass, SqlMethod.UPDATE_BY_ID), param);
-        });
+        List<BatchResult> batchResults = SqlHelper.execute(entityClass, baseMapper -> baseMapper.saveOrUpdateBatch(entityList, batchSize));
+        return batchResults.stream().flatMapToInt(r -> IntStream.of(r.getUpdateCounts())).allMatch(i -> i > 0);
     }
     }
 
 
     /**
     /**
@@ -203,13 +194,8 @@ public class Db {
      */
      */
     public static <T> boolean updateBatchById(Collection<T> entityList, int batchSize) {
     public static <T> boolean updateBatchById(Collection<T> entityList, int batchSize) {
         Class<T> entityClass = getEntityClass(entityList);
         Class<T> entityClass = getEntityClass(entityList);
-        TableInfo tableInfo = getTableInfo(entityClass);
-        String sqlStatement = SqlHelper.getSqlStatement(ClassUtils.toClassConfident(tableInfo.getCurrentNamespace()), SqlMethod.UPDATE_BY_ID);
-        return SqlHelper.executeBatch(entityClass, log, entityList, batchSize, (sqlSession, entity) -> {
-            MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
-            param.put(Constants.ENTITY, entity);
-            sqlSession.update(sqlStatement, param);
-        });
+        List<BatchResult> batchResults = SqlHelper.execute(entityClass, baseMapper -> baseMapper.updateBatchById(entityList, batchSize));
+        return batchResults.stream().flatMapToInt(r -> IntStream.of(r.getUpdateCounts())).allMatch(i -> i > 0);
     }
     }
 
 
     /**
     /**