Explorar el Código

BaseMapper新增批量修改或更新方法.

nieqiurong hace 1 año
padre
commit
bb05cdbdd8

+ 22 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/mapper/BaseMapper.java

@@ -24,9 +24,11 @@ import com.baomidou.mybatisplus.core.metadata.IPage;
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
 import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
 import com.baomidou.mybatisplus.core.override.MybatisMapperProxy;
+import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.core.toolkit.Constants;
 import com.baomidou.mybatisplus.core.toolkit.MybatisBatchUtils;
 import com.baomidou.mybatisplus.core.toolkit.MybatisUtils;
+import com.baomidou.mybatisplus.core.toolkit.StringUtils;
 import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import com.baomidou.mybatisplus.core.toolkit.reflect.GenericTypeUtils;
 import org.apache.ibatis.annotations.Param;
@@ -423,4 +425,24 @@ public interface BaseMapper<T> extends Mapper<T> {
         return MybatisBatchUtils.execute(sqlSessionFactory, entityList, method.updateById());
     }
 
+    /**
+     * 批量修改或插入
+     *
+     * @param entityList 实体对象集合
+     * @since 3.5.7
+     */
+    default List<BatchResult> saveOrUpdateBatch(Collection<T> entityList) {
+        Class<?> entityClass = GenericTypeUtils.resolveTypeArguments(getClass(), BaseMapper.class)[0];
+        TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
+        MybatisMapperProxy<?> mybatisMapperProxy = (MybatisMapperProxy<?>) Proxy.getInvocationHandler(this);
+        MybatisBatch.Method<T> method = new MybatisBatch.Method<>(mybatisMapperProxy.getMapperInterface());
+        SqlSessionFactory sqlSessionFactory = MybatisUtils.getSqlSessionFactory(mybatisMapperProxy);
+        String keyProperty = tableInfo.getKeyProperty();
+        String statementId = method.get(SqlMethod.SELECT_BY_ID.getMethod()).getStatementId();
+        return MybatisBatchUtils.saveOrUpdate(sqlSessionFactory, entityList, method.insert(), (sqlSession, entity) -> {
+            Object idVal = tableInfo.getPropertyValue(entity, keyProperty);
+            return StringUtils.checkValNull(idVal) || CollectionUtils.isEmpty(sqlSession.selectList(statementId, entity));
+        }, method.updateById());
+    }
+
 }

+ 18 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserMapperTest.java

@@ -255,6 +255,24 @@ class H2UserMapperTest extends BaseTest {
         }
     }
 
+    @Test
+    void testSaveOrUpdateBatchMapper1() {
+        int batchSize = 10;
+        List<H2User> h2UserList = new ArrayList<>();
+        for (int i = 0; i < batchSize; i++) {
+            h2UserList.add(new H2User(Long.valueOf(140000 + i), "test" + i));
+        }
+        List<BatchResult> batchResults = userMapper.saveOrUpdateBatch(h2UserList);
+        Assertions.assertEquals(batchSize, batchResults.size());
+        // 使用共享的sqlSession,等于每次都是刷新了,批次总结果集就等于数据大小了
+        Assertions.assertEquals(batchSize, batchResults.size());
+        for (BatchResult batchResult : batchResults) {
+            Assertions.assertEquals(batchResult.getUpdateCounts().length, 1);
+            Assertions.assertEquals(1, batchResult.getUpdateCounts()[0]);
+        }
+    }
+
+
     @Test
     void testSaveOrUpdateBatch2() {
         int batchSize = 10;