Prechádzať zdrojové kódy

新增重载方法支持自定义条件.

nieqiurong 1 rok pred
rodič
commit
52c7497c5b

+ 20 - 6
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/mapper/BaseMapper.java

@@ -15,6 +15,7 @@
  */
 package com.baomidou.mybatisplus.core.mapper;
 
+import com.baomidou.mybatisplus.core.batch.BatchSqlSession;
 import com.baomidou.mybatisplus.core.batch.MybatisBatch;
 import com.baomidou.mybatisplus.core.conditions.Wrapper;
 import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
@@ -28,6 +29,7 @@ import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.core.toolkit.Constants;
 import com.baomidou.mybatisplus.core.toolkit.MybatisBatchUtils;
 import com.baomidou.mybatisplus.core.toolkit.MybatisUtils;
+import com.baomidou.mybatisplus.core.toolkit.StringPool;
 import com.baomidou.mybatisplus.core.toolkit.StringUtils;
 import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import com.baomidou.mybatisplus.core.toolkit.reflect.GenericTypeUtils;
@@ -44,6 +46,7 @@ import java.lang.reflect.Proxy;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
+import java.util.function.BiPredicate;
 
 /*
 
@@ -432,17 +435,28 @@ public interface BaseMapper<T> extends Mapper<T> {
      * @since 3.5.7
      */
     default List<BatchResult> saveOrUpdateBatch(Collection<T> entityList) {
+        MybatisMapperProxy<?> mybatisMapperProxy = (MybatisMapperProxy<?>) Proxy.getInvocationHandler(this);
         Class<?> entityClass = GenericTypeUtils.resolveTypeArguments(getClass(), BaseMapper.class)[0];
         TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
+        String keyProperty = tableInfo.getKeyProperty();
+        String statement = mybatisMapperProxy.getMapperInterface().getName() + StringPool.DOT + SqlMethod.SELECT_BY_ID.getMethod();
+        return saveOrUpdateBatch(entityList, (sqlSession, entity) -> {
+            Object idVal = tableInfo.getPropertyValue(entity, keyProperty);
+            return StringUtils.checkValNull(idVal) || CollectionUtils.isEmpty(sqlSession.selectList(statement, entity));
+        });
+    }
+
+    /**
+     * 批量修改或插入
+     *
+     * @param entityList 实体对象集合
+     * @since 3.5.7
+     */
+    default List<BatchResult> saveOrUpdateBatch(Collection<T> entityList, BiPredicate<BatchSqlSession, T> insertPredicate) {
         MybatisMapperProxy<?> mybatisMapperProxy = (MybatisMapperProxy<?>) Proxy.getInvocationHandler(this);
         MybatisBatch.Method<T> method = new MybatisBatch.Method<>(mybatisMapperProxy.getMapperInterface());
         SqlSessionFactory sqlSessionFactory = MybatisUtils.getSqlSessionFactory(mybatisMapperProxy);
-        String keyProperty = tableInfo.getKeyProperty();
-        String statementId = method.get(SqlMethod.SELECT_BY_ID.getMethod()).getStatementId();
-        return MybatisBatchUtils.saveOrUpdate(sqlSessionFactory, entityList, method.insert(), (sqlSession, entity) -> {
-            Object idVal = tableInfo.getPropertyValue(entity, keyProperty);
-            return StringUtils.checkValNull(idVal) || CollectionUtils.isEmpty(sqlSession.selectList(statementId, entity));
-        }, method.updateById());
+        return MybatisBatchUtils.saveOrUpdate(sqlSessionFactory, entityList, method.insert(), insertPredicate, method.updateById());
     }
 
 }

+ 43 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserMapperTest.java

@@ -272,6 +272,49 @@ class H2UserMapperTest extends BaseTest {
         }
     }
 
+    @Test
+    void testSaveOrUpdateBatchMapper2() {
+        int batchSize = 10;
+        List<H2User> h2UserList = new ArrayList<>();
+        for (int i = 0; i < batchSize; i++) {
+            h2UserList.add(new H2User(Long.valueOf(40000 + i), "test" + i));
+        }
+        List<BatchResult> batchResults = userMapper.saveOrUpdateBatch(h2UserList,((sqlSession, h2User) -> userMapper.selectById(h2User.getTestId()) == null));
+        // 没有使用共享的sqlSession,由于都是新增返回还是一个批次
+        int[] updateCounts = batchResults.get(0).getUpdateCounts();
+        Assertions.assertEquals(batchSize, updateCounts.length);
+        for (int updateCount : updateCounts) {
+            Assertions.assertEquals(1, updateCount);
+        }
+    }
+
+    @Test
+    void testSaveOrUpdateBatchMapper3() {
+        var id = IdWorker.getId();
+        var h2UserList = List.of(new H2User(id, "testSaveOrUpdateBatchMapper3"), new H2User(id, "testSaveOrUpdateBatchMapper3-1"));
+        // 由于没有共享一个sqlSession,第二条记录selectById的时候第一个sqlSession的数据还没提交,会执行插入导致主键冲突.
+        Assertions.assertThrowsExactly(PersistenceException.class, () -> {
+            userMapper.saveOrUpdateBatch(h2UserList, ((sqlSession, h2User) -> userMapper.selectById(h2User.getTestId()) == null));
+        });
+    }
+
+    @Test
+    void testSaveOrUpdateBatchMapper4() {
+        var id = IdWorker.getId();
+        var h2UserList = List.of(new H2User(id, "testSaveOrUpdateBatchMapper4"), new H2User(id, "testSaveOrUpdateBatchMapper4-1"));
+        var mapperMethod = new MybatisBatch.Method<H2User>(H2UserMapper.class);
+        // 共享一个sqlSession,每次selectById都会刷新一下,第二条记录为update.
+        var batchResults = userMapper.saveOrUpdateBatch(h2UserList,
+            ((sqlSession, h2User) -> sqlSession.selectList(mapperMethod.get("selectById").getStatementId(), h2User.getTestId()).isEmpty()));
+        var updateCounts = batchResults.get(0).getUpdateCounts();
+        for (int updateCount : updateCounts) {
+            Assertions.assertEquals(1, updateCount);
+        }
+        Assertions.assertEquals(userMapper.selectById(id).getName(), "testSaveOrUpdateBatchMapper4-1");
+    }
+
+
+
 
     @Test
     void testSaveOrUpdateBatch2() {