Quellcode durchsuchen

增加一个批量操作结果集转换bool工具方法.

nieqiurong vor 10 Monaten
Ursprung
Commit
acbe9a2d80

+ 4 - 6
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/toolkit/Db.java

@@ -36,10 +36,8 @@ import org.apache.ibatis.logging.Log;
 import org.apache.ibatis.logging.LogFactory;
 
 import java.io.Serializable;
-import java.sql.Statement;
 import java.util.*;
 import java.util.stream.Collectors;
-import java.util.stream.IntStream;
 
 /**
  * 以静态方式调用Service中的函数
@@ -89,7 +87,7 @@ public class Db {
         }
         Class<T> entityClass = getEntityClass(entityList);
         List<BatchResult> batchResults = SqlHelper.execute(entityClass, baseMapper -> baseMapper.insert(entityList, batchSize));
-        return batchResults.stream().flatMapToInt(r -> IntStream.of(r.getUpdateCounts())).allMatch(i -> i > 0 || i == Statement.SUCCESS_NO_INFO);
+        return SqlHelper.retBool(batchResults);
     }
 
     /**
@@ -113,7 +111,7 @@ public class Db {
         }
         Class<T> entityClass = getEntityClass(entityList);
         List<BatchResult> batchResults = SqlHelper.execute(entityClass, baseMapper -> baseMapper.insertOrUpdate(entityList, batchSize));
-        return batchResults.stream().flatMapToInt(r -> IntStream.of(r.getUpdateCounts())).allMatch(i -> i > 0 || i == Statement.SUCCESS_NO_INFO);
+        return SqlHelper.retBool(batchResults);
     }
 
     /**
@@ -196,7 +194,7 @@ public class Db {
     public static <T> boolean updateBatchById(Collection<T> entityList, int batchSize) {
         Class<T> entityClass = getEntityClass(entityList);
         List<BatchResult> batchResults = SqlHelper.execute(entityClass, baseMapper -> baseMapper.updateById(entityList, batchSize));
-        return batchResults.stream().flatMapToInt(r -> IntStream.of(r.getUpdateCounts())).allMatch(i -> i > 0 || i == Statement.SUCCESS_NO_INFO);
+        return SqlHelper.retBool(batchResults);
     }
 
     /**
@@ -651,7 +649,7 @@ public class Db {
     protected static <T> Class<T> getEntityClass(Collection<T> entityList) {
         Class<T> entityClass = null;
         for (T entity : entityList) {
-            if (entity != null && entity.getClass() != null) {
+            if (entity != null) {
                 entityClass = getEntityClass(entity);
                 break;
             }

+ 14 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/toolkit/SqlHelper.java

@@ -23,6 +23,7 @@ import com.baomidou.mybatisplus.core.toolkit.*;
 import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
 import lombok.SneakyThrows;
 import org.apache.ibatis.exceptions.PersistenceException;
+import org.apache.ibatis.executor.BatchResult;
 import org.apache.ibatis.logging.Log;
 import org.apache.ibatis.reflection.ExceptionUtil;
 import org.apache.ibatis.session.ExecutorType;
@@ -33,6 +34,7 @@ import org.mybatis.spring.SqlSessionHolder;
 import org.mybatis.spring.SqlSessionUtils;
 import org.springframework.transaction.support.TransactionSynchronizationManager;
 
+import java.sql.Statement;
 import java.util.Collection;
 import java.util.List;
 import java.util.Optional;
@@ -40,6 +42,7 @@ import java.util.function.BiConsumer;
 import java.util.function.BiPredicate;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import java.util.stream.IntStream;
 
 /**
  * SQL 辅助类
@@ -124,6 +127,17 @@ public final class SqlHelper {
         return null != result && result >= 1;
     }
 
+    /**
+     * 批量操作是否成功
+     *
+     * @param result 批量操作结果集
+     * @return 操作结果(批量行记录全满足成功的的情况下为true)
+     * @since 3.5.8
+     */
+    public static boolean retBool(List<BatchResult> result) {
+        return result != null && result.stream().flatMapToInt(r -> IntStream.of(r.getUpdateCounts())).allMatch(i -> i > 0 || i == Statement.SUCCESS_NO_INFO);
+    }
+
     /**
      * 返回SelectCount执行结果
      *

+ 37 - 55
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserMapperTest.java

@@ -10,6 +10,7 @@ import com.baomidou.mybatisplus.core.toolkit.IdWorker;
 import com.baomidou.mybatisplus.core.toolkit.MybatisBatchUtils;
 import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
+import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
 import com.baomidou.mybatisplus.test.h2.entity.H2User;
 import com.baomidou.mybatisplus.test.h2.entity.SuperEntity;
 import com.baomidou.mybatisplus.test.h2.enums.AgeEnum;
@@ -31,6 +32,7 @@ import java.util.Date;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.IntStream;
 
 import static java.util.stream.Collectors.toList;
 
@@ -58,7 +60,8 @@ class H2UserMapperTest extends BaseTest {
     void testMapperSaveBatch() {
         var list = List.of(new H2User("秋秋1"), new H2User("秋秋2"));
         List<BatchResult> batchResults = userMapper.insert(list);
-        Assertions.assertEquals(2, batchResults.get(0).getUpdateCounts().length);
+        Assertions.assertTrue(SqlHelper.retBool(batchResults));
+        Assertions.assertEquals(2, batchResults.getFirst().getUpdateCounts().length);
     }
 
     @Test
@@ -69,7 +72,8 @@ class H2UserMapperTest extends BaseTest {
             h2User.setName("test" + 1);
         }
         List<BatchResult> batchResults = userMapper.updateById(list);
-        Assertions.assertEquals(2, batchResults.get(0).getUpdateCounts().length);
+        Assertions.assertTrue(SqlHelper.retBool(batchResults));
+        Assertions.assertEquals(2, batchResults.getFirst().getUpdateCounts().length);
     }
 
 
@@ -108,15 +112,12 @@ class H2UserMapperTest extends BaseTest {
         MybatisBatch.Method<H2User> mapperMethod = new MybatisBatch.Method<>(H2UserMapper.class);
         // 执行批量插入
         List<BatchResult> batchResults = MybatisBatchUtils.execute(sqlSessionFactory, h2UserList, mapperMethod.insert());
-        int[] updateCounts = batchResults.get(0).getUpdateCounts();
+        Assertions.assertTrue(SqlHelper.retBool(batchResults));
+        int[] updateCounts = batchResults.getFirst().getUpdateCounts();
         Assertions.assertEquals(batchSize, updateCounts.length);
-        for (int updateCount : updateCounts) {
-            Assertions.assertEquals(1, updateCount);
-        }
-
         List<Long> ids = Arrays.asList(120000L, 120001L);
         MybatisBatch.Method<H2User> method = new MybatisBatch.Method<>(H2UserMapper.class);
-        MybatisBatchUtils.execute(sqlSessionFactory, ids, method.insert(H2User::ofId));
+        Assertions.assertTrue(SqlHelper.retBool(MybatisBatchUtils.execute(sqlSessionFactory, ids, method.insert(H2User::ofId))));
     }
 
     @Test
@@ -131,34 +132,26 @@ class H2UserMapperTest extends BaseTest {
         MybatisBatch.Method<H2User> method = new MybatisBatch.Method<>(H2UserMapper.class);
         // 执行批量插入
         batchResults = MybatisBatchUtils.execute(sqlSessionFactory, h2UserList, method.get("myInsertWithoutParam"));
-        updateCounts = batchResults.get(0).getUpdateCounts();
+        Assertions.assertTrue(SqlHelper.retBool(batchResults));
+        updateCounts = batchResults.getFirst().getUpdateCounts();
         Assertions.assertEquals(batchSize, updateCounts.length);
-        for (int updateCount : updateCounts) {
-            Assertions.assertEquals(1, updateCount);
-        }
-
         h2UserList = new ArrayList<>();
         for (int i = 0; i < batchSize; i++) {
             h2UserList.add(new H2User("myInsertWithParam" + i));
         }
         // 执行批量插入
         batchResults = MybatisBatchUtils.execute(sqlSessionFactory, h2UserList, method.get("myInsertWithParam", parameter -> Map.of("user1", parameter)));
-        updateCounts = batchResults.get(0).getUpdateCounts();
+        Assertions.assertTrue(SqlHelper.retBool(batchResults));
+        updateCounts = batchResults.getFirst().getUpdateCounts();
         Assertions.assertEquals(batchSize, updateCounts.length);
-        for (int updateCount : updateCounts) {
-            Assertions.assertEquals(1, updateCount);
-        }
-
         h2UserList = new ArrayList<>();
         for (int i = 0; i < batchSize; i++) {
             h2UserList.add(new H2User("myInsertWithParam" + i));
         }
         batchResults = MybatisBatchUtils.execute(sqlSessionFactory, h2UserList, method.get("myInsertWithParam", parameter -> Map.of("user1", parameter)));
-        updateCounts = batchResults.get(0).getUpdateCounts();
+        Assertions.assertTrue(SqlHelper.retBool(batchResults));
+        updateCounts = batchResults.getFirst().getUpdateCounts();
         Assertions.assertEquals(batchSize, updateCounts.length);
-        for (int updateCount : updateCounts) {
-            Assertions.assertEquals(1, updateCount);
-        }
     }
 
     @Test
@@ -177,20 +170,17 @@ class H2UserMapperTest extends BaseTest {
             userList.add(h2User);
             return h2User;
         }));
-        int[] updateCounts = batchResults.get(0).getUpdateCounts();
+        Assertions.assertFalse(SqlHelper.retBool(batchResults));
+        int[] updateCounts = batchResults.getFirst().getUpdateCounts();
         Assertions.assertEquals(batchSize, updateCounts.length);
-        for (int updateCount : updateCounts) {
-            Assertions.assertEquals(0, updateCount);
-        }
         for (H2User h2User : userList) {
             Assertions.assertNotNull(h2User.getLastUpdatedDt());
         }
         // 不能走填充
         batchResults = MybatisBatchUtils.execute(sqlSessionFactory, ids, method.deleteById());
-        updateCounts = batchResults.get(0).getUpdateCounts();
-        for (int updateCount : updateCounts) {
-            Assertions.assertEquals(0, updateCount);
-        }
+        Assertions.assertFalse(SqlHelper.retBool(batchResults));
+        updateCounts = batchResults.getFirst().getUpdateCounts();
+        Assertions.assertEquals(batchSize, updateCounts.length);
     }
 
     @Test
@@ -203,20 +193,18 @@ class H2UserMapperTest extends BaseTest {
         MybatisBatch.Method<H2User> mapperMethod = new MybatisBatch.Method<>(H2UserMapper.class);
         // 执行批量更新
         List<BatchResult> batchResults = MybatisBatchUtils.execute(sqlSessionFactory, h2UserList, mapperMethod.updateById());
-        int[] updateCounts = batchResults.get(0).getUpdateCounts();
+        Assertions.assertFalse(SqlHelper.retBool(batchResults));
+        int[] updateCounts = batchResults.getFirst().getUpdateCounts();
         Assertions.assertEquals(batchSize, updateCounts.length);
-        for (int updateCount : updateCounts) {
-            Assertions.assertEquals(0, updateCount);
-        }
 
         List<Long> ids = Arrays.asList(120000L, 120001L);
         MybatisBatch.Method<H2User> method = new MybatisBatch.Method<>(H2UserMapper.class);
 
-        MybatisBatchUtils.execute(sqlSessionFactory, ids, method.update(id -> Wrappers.<H2User>lambdaUpdate().set(H2User::getName, "updateTest").eq(H2User::getTestId, id)));
-        MybatisBatchUtils.execute(sqlSessionFactory, ids, method.update(id -> new H2User().setName("updateTest2"), id -> Wrappers.<H2User>lambdaUpdate().eq(H2User::getTestId, id)));
+        Assertions.assertFalse(SqlHelper.retBool(MybatisBatchUtils.execute(sqlSessionFactory, ids, method.update(id -> Wrappers.<H2User>lambdaUpdate().set(H2User::getName, "updateTest").eq(H2User::getTestId, id)))));
+        Assertions.assertFalse(SqlHelper.retBool(MybatisBatchUtils.execute(sqlSessionFactory, ids, method.update(id -> new H2User().setName("updateTest2"), id -> Wrappers.<H2User>lambdaUpdate().eq(H2User::getTestId, id)))));
 
-        MybatisBatchUtils.execute(sqlSessionFactory, h2UserList, method.update(user -> Wrappers.<H2User>update().set("name", "updateTest3").eq("test_id", user.getTestId())));
-        MybatisBatchUtils.execute(sqlSessionFactory, h2UserList, method.update(user -> new H2User("updateTests4"), p -> Wrappers.<H2User>update().eq("test_id", p.getTestId())));
+        Assertions.assertFalse(SqlHelper.retBool(MybatisBatchUtils.execute(sqlSessionFactory, h2UserList, method.update(user -> Wrappers.<H2User>update().set("name", "updateTest3").eq("test_id", user.getTestId())))));
+        Assertions.assertFalse(SqlHelper.retBool(MybatisBatchUtils.execute(sqlSessionFactory, h2UserList, method.update(user -> new H2User("updateTests4"), p -> Wrappers.<H2User>update().eq("test_id", p.getTestId())))));
     }
 
     @Test
@@ -232,11 +220,9 @@ class H2UserMapperTest extends BaseTest {
                 ((sqlSession, h2User) -> userMapper.selectById(h2User.getTestId()) == null),
                 mapperMethod.updateById());
         // 没有使用共享的sqlSession,由于都是新增返回还是一个批次
-        int[] updateCounts = batchResults.get(0).getUpdateCounts();
+        Assertions.assertTrue(SqlHelper.retBool(batchResults));
+        int[] updateCounts = batchResults.getFirst().getUpdateCounts();
         Assertions.assertEquals(batchSize, updateCounts.length);
-        for (int updateCount : updateCounts) {
-            Assertions.assertEquals(1, updateCount);
-        }
     }
 
     @Test
@@ -247,9 +233,9 @@ class H2UserMapperTest extends BaseTest {
             h2UserList.add(new H2User(Long.valueOf(140000 + i), "test" + i));
         }
         List<BatchResult> batchResults = userMapper.insertOrUpdate(h2UserList);
+        Assertions.assertTrue(SqlHelper.retBool(batchResults));
         Assertions.assertEquals(batchSize, batchResults.size());
         // 使用共享的sqlSession,等于每次都是刷新了,批次总结果集就等于数据大小了
-        Assertions.assertEquals(batchSize, batchResults.size());
         for (BatchResult batchResult : batchResults) {
             Assertions.assertEquals(batchResult.getUpdateCounts().length, 1);
             Assertions.assertEquals(1, batchResult.getUpdateCounts()[0]);
@@ -264,12 +250,10 @@ class H2UserMapperTest extends BaseTest {
             h2UserList.add(new H2User(Long.valueOf(40000 + i), "test" + i));
         }
         List<BatchResult> batchResults = userMapper.insertOrUpdate(h2UserList,((sqlSession, h2User) -> userMapper.selectById(h2User.getTestId()) == null));
+        Assertions.assertTrue(SqlHelper.retBool(batchResults));
         // 没有使用共享的sqlSession,由于都是新增返回还是一个批次
-        int[] updateCounts = batchResults.get(0).getUpdateCounts();
+        int[] updateCounts = batchResults.getFirst().getUpdateCounts();
         Assertions.assertEquals(batchSize, updateCounts.length);
-        for (int updateCount : updateCounts) {
-            Assertions.assertEquals(1, updateCount);
-        }
     }
 
     @Test
@@ -290,10 +274,8 @@ class H2UserMapperTest extends BaseTest {
         // 共享一个sqlSession,每次selectById都会刷新一下,第二条记录为update.
         var batchResults = userMapper.insertOrUpdate(h2UserList,
             ((sqlSession, h2User) -> sqlSession.selectList(mapperMethod.get("selectById").getStatementId(), h2User.getTestId()).isEmpty()));
-        var updateCounts = batchResults.get(0).getUpdateCounts();
-        for (int updateCount : updateCounts) {
-            Assertions.assertEquals(1, updateCount);
-        }
+        Assertions.assertTrue(SqlHelper.retBool(batchResults));
+        Assertions.assertEquals(h2UserList.size(), batchResults.stream().flatMapToInt(r -> IntStream.of(r.getUpdateCounts())).count());
         Assertions.assertEquals(userMapper.selectById(id).getName(), "testSaveOrUpdateBatchMapper4-1");
     }
 
@@ -301,12 +283,12 @@ class H2UserMapperTest extends BaseTest {
     void testRemoveByIds() {
         Assertions.assertEquals(userMapper.deleteByIds(List.of(666666661, "2")), userMapper.deleteByIds(List.of(666666661, "2"), false));
         H2User h2User = new H2User("testRemoveByIds");
-        userMapper.insert(h2User);
-        userMapper.deleteByIds(List.of(h2User));
+        Assertions.assertEquals(1, userMapper.insert(h2User));
+        Assertions.assertEquals(1, userMapper.deleteByIds(List.of(h2User)));
         Assertions.assertNotNull(userMapper.getById(h2User.getTestId()).getLastUpdatedDt());
         h2User = new H2User("testRemoveByIds");
-        userMapper.insert(h2User);
-        userMapper.deleteByIds(List.of(h2User), false);
+        Assertions.assertEquals(1, userMapper.insert(h2User));
+        Assertions.assertEquals(1, userMapper.deleteByIds(List.of(h2User), false));
         Assertions.assertNull(userMapper.getById(h2User.getTestId()).getLastUpdatedDt());
     }
 

+ 16 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/optimisticlocker/OptimisticLockerTest.java

@@ -3,8 +3,10 @@ package com.baomidou.mybatisplus.test.optimisticlocker;
 import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
 import com.baomidou.mybatisplus.extension.plugins.inner.OptimisticLockerInnerInterceptor;
+import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
 import com.baomidou.mybatisplus.test.BaseDbTest;
 import org.apache.ibatis.plugin.Interceptor;
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 import java.util.Arrays;
@@ -52,6 +54,20 @@ public class OptimisticLockerTest extends BaseDbTest<EntityMapper> {
         });
     }
 
+    @Test
+    void testBatch() {
+        var entity1 = new Entity().setName("testBatch").setVersion(1);
+        var entity2 = new Entity().setName("testBatch").setVersion(1);
+        var entity3 = new Entity().setName("testBatch").setVersion(1);
+        var entityList = List.of(entity1, entity2, entity3);
+        doTest(mapper -> {
+            Assertions.assertTrue(SqlHelper.retBool(mapper.insert(entityList)));
+            Assertions.assertTrue(SqlHelper.retBool(mapper.updateById(entityList)));
+            entity2.setVersion(6);
+            Assertions.assertFalse(SqlHelper.retBool(mapper.updateById(entityList)));
+        });
+    }
+
     @Override
     protected List<Interceptor> interceptors() {
         MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();