Browse Source

BaseMapper新增批量保存与更新方法.

nieqiurong 1 year ago
parent
commit
68a5042601

+ 3 - 2
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/batch/MybatisBatch.java

@@ -26,6 +26,7 @@ import org.apache.ibatis.session.SqlSessionFactory;
 
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -58,9 +59,9 @@ public class MybatisBatch<T> {
 
     private final SqlSessionFactory sqlSessionFactory;
 
-    private final List<T> dataList;
+    private final Collection<T> dataList;
 
-    public MybatisBatch(SqlSessionFactory sqlSessionFactory, List<T> dataList) {
+    public MybatisBatch(SqlSessionFactory sqlSessionFactory, Collection<T> dataList) {
         this.sqlSessionFactory = sqlSessionFactory;
         this.dataList = dataList;
     }

+ 33 - 1
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/mapper/BaseMapper.java

@@ -15,18 +15,24 @@
  */
 package com.baomidou.mybatisplus.core.mapper;
 
+import com.baomidou.mybatisplus.core.batch.MybatisBatch;
 import com.baomidou.mybatisplus.core.conditions.Wrapper;
-import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
 import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
 import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
 import com.baomidou.mybatisplus.core.metadata.IPage;
+import com.baomidou.mybatisplus.core.override.MybatisMapperProxy;
 import com.baomidou.mybatisplus.core.toolkit.Constants;
+import com.baomidou.mybatisplus.core.toolkit.MybatisBatchUtils;
+import com.baomidou.mybatisplus.core.toolkit.MybatisUtils;
 import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import org.apache.ibatis.annotations.Param;
 import org.apache.ibatis.exceptions.TooManyResultsException;
+import org.apache.ibatis.executor.BatchResult;
 import org.apache.ibatis.session.ResultHandler;
+import org.apache.ibatis.session.SqlSessionFactory;
 
 import java.io.Serializable;
+import java.lang.reflect.Proxy;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
@@ -361,4 +367,30 @@ public interface BaseMapper<T> extends Mapper<T> {
         return page;
     }
 
+    /**
+     * 插入(批量)
+     *
+     * @param entityList 实体对象集合
+     * @since 3.5.7
+     */
+    default List<BatchResult> saveBatch(Collection<T> entityList) {
+        MybatisMapperProxy<?> mybatisMapperProxy = (MybatisMapperProxy<?>) Proxy.getInvocationHandler(this);
+        MybatisBatch.Method<T> method = new MybatisBatch.Method<>(mybatisMapperProxy.getMapperInterface());
+        SqlSessionFactory sqlSessionFactory = MybatisUtils.getSqlSessionFactory(mybatisMapperProxy);
+        return MybatisBatchUtils.execute(sqlSessionFactory, entityList, method.insert());
+    }
+
+    /**
+     * 根据ID 批量更新
+     *
+     * @param entityList 实体对象集合
+     * @since 3.5.7
+     */
+    default List<BatchResult> updateBatchById(Collection<T> entityList) {
+        MybatisMapperProxy<?> mybatisMapperProxy = (MybatisMapperProxy<?>) Proxy.getInvocationHandler(this);
+        MybatisBatch.Method<T> method = new MybatisBatch.Method<>(mybatisMapperProxy.getMapperInterface());
+        SqlSessionFactory sqlSessionFactory = MybatisUtils.getSqlSessionFactory(mybatisMapperProxy);
+        return MybatisBatchUtils.execute(sqlSessionFactory, entityList, method.updateById());
+    }
+
 }

+ 4 - 1
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/override/MybatisMapperProxy.java

@@ -15,7 +15,6 @@
  */
 package com.baomidou.mybatisplus.core.override;
 
-import org.apache.ibatis.binding.MapperMethod;
 import org.apache.ibatis.binding.MapperProxy;
 import org.apache.ibatis.reflection.ExceptionUtil;
 import org.apache.ibatis.session.SqlSession;
@@ -119,6 +118,10 @@ public class MybatisMapperProxy<T> implements InvocationHandler, Serializable {
         return sqlSession;
     }
 
+    public Class<T> getMapperInterface() {
+        return mapperInterface;
+    }
+
     private MethodHandle getMethodHandleJava9(Method method)
         throws NoSuchMethodException, IllegalAccessException, InvocationTargetException {
         final Class<?> declaringClass = method.getDeclaringClass();

+ 9 - 8
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/MybatisBatchUtils.java

@@ -22,6 +22,7 @@ import com.baomidou.mybatisplus.core.batch.ParameterConvert;
 import org.apache.ibatis.executor.BatchResult;
 import org.apache.ibatis.session.SqlSessionFactory;
 
+import java.util.Collection;
 import java.util.List;
 import java.util.function.BiPredicate;
 
@@ -40,7 +41,7 @@ public class MybatisBatchUtils {
      * @param <T>               泛型
      * @return 批处理结果
      */
-    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, List<T> dataList, String statement) {
+    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, String statement) {
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(statement);
     }
 
@@ -54,7 +55,7 @@ public class MybatisBatchUtils {
      * @param <T>               泛型
      * @return 批处理结果
      */
-    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, List<T> dataList, String statement, ParameterConvert<T> parameterConvert) {
+    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, String statement, ParameterConvert<T> parameterConvert) {
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(statement, parameterConvert);
     }
 
@@ -68,7 +69,7 @@ public class MybatisBatchUtils {
      * @param <T>               泛型
      * @return 批处理结果
      */
-    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, List<T> dataList, boolean autoCommit, String statement) {
+    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, boolean autoCommit, String statement) {
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(autoCommit, statement);
     }
 
@@ -83,7 +84,7 @@ public class MybatisBatchUtils {
      * @param <T>               泛型
      * @return 批处理结果
      */
-    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, List<T> dataList, boolean autoCommit, String statement, ParameterConvert<T> parameterConvert) {
+    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, boolean autoCommit, String statement, ParameterConvert<T> parameterConvert) {
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(autoCommit, statement, parameterConvert);
     }
 
@@ -96,7 +97,7 @@ public class MybatisBatchUtils {
      * @param <T>               泛型
      * @return 批处理结果
      */
-    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, List<T> dataList, BatchMethod<T> batchMethod) {
+    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, BatchMethod<T> batchMethod) {
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(batchMethod);
     }
 
@@ -110,7 +111,7 @@ public class MybatisBatchUtils {
      * @param <T>               泛型
      * @return 批处理结果
      */
-    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, List<T> dataList, boolean autoCommit, BatchMethod<T> batchMethod) {
+    public static <T> List<BatchResult> execute(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, boolean autoCommit, BatchMethod<T> batchMethod) {
         return new MybatisBatch<>(sqlSessionFactory, dataList).execute(autoCommit, batchMethod);
     }
 
@@ -129,7 +130,7 @@ public class MybatisBatchUtils {
      * @param <T>               泛型
      * @return 批处理结果
      */
-    public static <T> List<BatchResult> saveOrUpdate(SqlSessionFactory sqlSessionFactory, List<T> dataList, BatchMethod<T> insertMethod, BiPredicate<BatchSqlSession, T> insertPredicate, BatchMethod<T> updateMethod) {
+    public static <T> List<BatchResult> saveOrUpdate(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, BatchMethod<T> insertMethod, BiPredicate<BatchSqlSession, T> insertPredicate, BatchMethod<T> updateMethod) {
         return new MybatisBatch<>(sqlSessionFactory, dataList).saveOrUpdate(insertMethod, insertPredicate, updateMethod);
     }
 
@@ -149,7 +150,7 @@ public class MybatisBatchUtils {
      * @param <T>               泛型
      * @return 批处理结果
      */
-    public static <T> List<BatchResult> saveOrUpdate(SqlSessionFactory sqlSessionFactory, List<T> dataList, boolean autoCommit, BatchMethod<T> insertMethod, BiPredicate<BatchSqlSession, T> insertPredicate, BatchMethod<T> updateMethod) {
+    public static <T> List<BatchResult> saveOrUpdate(SqlSessionFactory sqlSessionFactory, Collection<T> dataList, boolean autoCommit, BatchMethod<T> insertMethod, BiPredicate<BatchSqlSession, T> insertPredicate, BatchMethod<T> updateMethod) {
         return new MybatisBatch<>(sqlSessionFactory, dataList).saveOrUpdate(autoCommit, insertMethod, insertPredicate, updateMethod);
     }
 

+ 20 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/MybatisUtils.java

@@ -1,8 +1,12 @@
 package com.baomidou.mybatisplus.core.toolkit;
 
 import com.baomidou.mybatisplus.core.handlers.IJsonTypeHandler;
+import com.baomidou.mybatisplus.core.override.MybatisMapperProxy;
 import lombok.experimental.UtilityClass;
 import lombok.extern.slf4j.Slf4j;
+import org.apache.ibatis.session.SqlSession;
+import org.apache.ibatis.session.SqlSessionFactory;
+import org.apache.ibatis.session.defaults.DefaultSqlSession;
 import org.apache.ibatis.type.TypeException;
 import org.apache.ibatis.type.TypeHandler;
 
@@ -49,4 +53,20 @@ public class MybatisUtils {
         return result;
     }
 
+    public static SqlSessionFactory getSqlSessionFactory(MybatisMapperProxy<?> mybatisMapperProxy) {
+        SqlSession sqlSession = mybatisMapperProxy.getSqlSession();
+        if (sqlSession instanceof DefaultSqlSession) {
+            // TODO 原生mybatis下只能这样了.
+            return GlobalConfigUtils.getGlobalConfig(mybatisMapperProxy.getSqlSession().getConfiguration()).getSqlSessionFactory();
+        }
+        Field declaredField;
+        try {
+            declaredField = sqlSession.getClass().getDeclaredField("sqlSessionFactory");
+            declaredField.setAccessible(true);
+            return (SqlSessionFactory) declaredField.get(sqlSession);
+        } catch (NoSuchFieldException | IllegalAccessException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
 }

+ 18 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserMapperTest.java

@@ -70,6 +70,24 @@ class H2UserMapperTest extends BaseTest {
     @Autowired
     private TransactionTemplate transactionTemplate;
 
+    @Test
+    void testMapperSaveBatch() {
+        var list = List.of(new H2User("秋秋1"), new H2User("秋秋2"));
+        List<BatchResult> batchResults = userMapper.saveBatch(list);
+        Assertions.assertEquals(2, batchResults.get(0).getUpdateCounts().length);
+    }
+
+    @Test
+    void testMapperUpdateBatch() {
+        var list = List.of(new H2User("秋秋1"), new H2User("秋秋2"));
+        userMapper.saveBatch(list);
+        for (H2User h2User : list) {
+            h2User.setName("test" + 1);
+        }
+        List<BatchResult> batchResults = userMapper.updateBatchById(list);
+        Assertions.assertEquals(2, batchResults.get(0).getUpdateCounts().length);
+    }
+
 
     @Test
     void testBatchTransaction() {

+ 1 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/issues/aop/MultiAopTest.java

@@ -39,6 +39,7 @@ public class MultiAopTest {
         );
         demoService.save(new Demo());
         demoService.saveBatch(List.of(new Demo()));
+        demoService.getBaseMapper().saveBatch(List.of(new Demo()));
     }
 
 }

+ 1 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/issues/aop/NoAopTest.java

@@ -39,6 +39,7 @@ public class NoAopTest {
         );
         demoService.save(new Demo());
         demoService.saveBatch(List.of(new Demo()));
+        demoService.getBaseMapper().saveBatch(List.of(new Demo()));
     }
 
 

+ 1 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/issues/aop/SingleAopTest.java

@@ -39,6 +39,7 @@ public class SingleAopTest {
         );
         demoService.save(new Demo());
         demoService.saveBatch(List.of(new Demo()));
+        demoService.getBaseMapper().saveBatch(List.of(new Demo()));
     }
 
 }