Explorar o código

新增批量执行方法.

nieqiurong hai 1 ano
pai
achega
227d5c0cc9

+ 38 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/batch/BatchMethod.java

@@ -0,0 +1,38 @@
+package com.baomidou.mybatisplus.core.batch;
+
+import org.apache.ibatis.mapping.MappedStatement;
+
+/**
+ * @author nieqiurong
+ * @since 3.5.4
+ */
+public class BatchMethod<T> {
+
+    /**
+     * 执行的{@link MappedStatement#getId()}
+     */
+    private final String statementId;
+
+    /**
+     * 方法参数转换器,默认传递批量的entity的参数
+     */
+    private ParameterConvert<T> parameterConvert;
+
+    public BatchMethod(String statementId) {
+        this.statementId = statementId;
+    }
+
+    public BatchMethod(String statementId, ParameterConvert<T> parameterConvert) {
+        this.statementId = statementId;
+        this.parameterConvert = parameterConvert;
+    }
+
+    public String getStatementId() {
+        return statementId;
+    }
+
+    public ParameterConvert<T> getParameterConvert() {
+        return parameterConvert;
+    }
+
+}

+ 71 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/batch/BatchSqlSession.java

@@ -0,0 +1,71 @@
+package com.baomidou.mybatisplus.core.batch;
+
+import org.apache.ibatis.executor.BatchResult;
+import org.apache.ibatis.session.RowBounds;
+import org.apache.ibatis.session.SqlSession;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * 当使用Batch混合查询时,每次都会将原来的结果集清空,建议使用Batch时就不要混合使用select了 (后面看看要不要改成动态代理把...)
+ *
+ * @author nieqiurong
+ * @since 3.5.4
+ */
+public class BatchSqlSession {
+
+    private final SqlSession sqlSession;
+
+    private final List<BatchResult> resultBatchList = new ArrayList<>();
+
+    public BatchSqlSession(SqlSession sqlSession) {
+        this.sqlSession = sqlSession;
+    }
+
+    public <T> T selectOne(String statement) {
+        resultBatchList.addAll(sqlSession.flushStatements());
+        return sqlSession.selectOne(statement);
+    }
+
+    public <T> T selectOne(String statement, Object parameter) {
+        resultBatchList.addAll(sqlSession.flushStatements());
+        return sqlSession.selectOne(statement, parameter);
+    }
+
+    public <E> List<E> selectList(String statement) {
+        resultBatchList.addAll(sqlSession.flushStatements());
+        return sqlSession.selectList(statement);
+    }
+
+    public <E> List<E> selectList(String statement, Object parameter) {
+        resultBatchList.addAll(sqlSession.flushStatements());
+        return sqlSession.selectList(statement, parameter);
+    }
+
+    public <E> List<E> selectList(String statement, Object parameter, RowBounds rowBounds) {
+        resultBatchList.addAll(sqlSession.flushStatements());
+        return sqlSession.selectList(statement, parameter, rowBounds);
+    }
+
+    public <K, V> Map<K, V> selectMap(String statement, String mapKey) {
+        resultBatchList.addAll(sqlSession.flushStatements());
+        return sqlSession.selectMap(statement, mapKey);
+    }
+
+    public <K, V> Map<K, V> selectMap(String statement, Object parameter, String mapKey) {
+        resultBatchList.addAll(sqlSession.flushStatements());
+        return sqlSession.selectMap(statement, parameter, mapKey);
+    }
+
+    public <K, V> Map<K, V> selectMap(String statement, Object parameter, String mapKey, RowBounds rowBounds) {
+        resultBatchList.addAll(sqlSession.flushStatements());
+        return sqlSession.selectMap(statement, parameter, mapKey, rowBounds);
+    }
+
+    public List<BatchResult> getResultBatchList() {
+        return resultBatchList;
+    }
+
+}

+ 140 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/batch/MybatisBatch.java

@@ -0,0 +1,140 @@
+package com.baomidou.mybatisplus.core.batch;
+
+import com.baomidou.mybatisplus.core.conditions.Wrapper;
+import com.baomidou.mybatisplus.core.enums.SqlMethod;
+import com.baomidou.mybatisplus.core.toolkit.Constants;
+import com.baomidou.mybatisplus.core.toolkit.StringPool;
+import org.apache.ibatis.executor.BatchResult;
+import org.apache.ibatis.session.ExecutorType;
+import org.apache.ibatis.session.SqlSession;
+import org.apache.ibatis.session.SqlSessionFactory;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.BiPredicate;
+import java.util.function.Function;
+
+/**
+ * <li>事务需要自行控制</li>
+ * <li>批次数据尽量自行切割处理</li>
+ * <li>返回值为批处理结果,如果对返回值比较关心的可接收判断处理</li>
+ * <li>saveOrUpdate尽量少用把,保持批处理为简单的插入或更新</li>
+ * <li>关于saveOrUpdate中的sqlSession,如果执行了select操作的话,BatchExecutor都会触发一次flushStatements,为了保证结果集,故使用包装了部分sqlSession查询操作</li>
+ * <pre>
+ *     Spring示例:
+ * 		transactionTemplate.execute(new TransactionCallback<List<BatchResult>>() {
+ *            {@code @Override}
+ * 			public List<BatchResult> doInTransaction(TransactionStatus status) {
+ * 				MybatisBatch.Method<Demo> method = new MybatisBatch.Method<>(DemoMapper.class);
+ * 				return new MybatisBatch<>(sqlSessionFactory,demoList).execute(true, method.insert());
+ *            }
+ *        });
+ * </pre>
+ *
+ * @author nieqiurong
+ * @since 3.5.4
+ */
+public class MybatisBatch<T> {
+
+    private final SqlSessionFactory sqlSessionFactory;
+
+    private final List<T> dataList;
+
+    public MybatisBatch(SqlSessionFactory sqlSessionFactory, List<T> dataList) {
+        this.sqlSessionFactory = sqlSessionFactory;
+        this.dataList = dataList;
+    }
+
+    public List<BatchResult> execute(String statement) {
+        return execute(false, statement, (entity) -> entity);
+    }
+
+    public List<BatchResult> execute(String statement, ParameterConvert<T> parameterConvert) {
+        return execute(false, statement, parameterConvert);
+    }
+
+    public List<BatchResult> execute(boolean autoCommit, String statement) {
+        return execute(autoCommit, statement, null);
+    }
+
+    public List<BatchResult> execute(BatchMethod<T> batchMethod) {
+        return execute(false, batchMethod);
+    }
+
+    public List<BatchResult> execute(boolean autoCommit, BatchMethod<T> batchMethod) {
+        try (SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH, autoCommit)) {
+            for (T data : dataList) {
+                ParameterConvert<T> parameterConvert = batchMethod.getParameterConvert();
+                sqlSession.update(batchMethod.getStatementId(), toParameter(parameterConvert, data));
+            }
+            return sqlSession.flushStatements();
+        }
+    }
+
+    public List<BatchResult> execute(boolean autoCommit, String statement, ParameterConvert<T> parameterConvert) {
+        try (SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH, autoCommit)) {
+            for (T data : dataList) {
+                sqlSession.update(statement, parameterConvert != null ? parameterConvert.convert(data) : data);
+            }
+            return sqlSession.flushStatements();
+        }
+    }
+
+    public List<BatchResult> saveOrUpdate(BatchMethod<T> insertMethod, BiPredicate<BatchSqlSession, T> insertPredicate, BatchMethod<T> updateMethod) {
+        return saveOrUpdate(false, insertMethod, insertPredicate, updateMethod);
+    }
+
+    public List<BatchResult> saveOrUpdate(boolean autoCommit, BatchMethod<T> insertMethod, BiPredicate<BatchSqlSession, T> insertPredicate, BatchMethod<T> updateMethod) {
+        List<BatchResult> resultList = new ArrayList<>();
+        try (SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH, autoCommit)) {
+            BatchSqlSession session = new BatchSqlSession(sqlSession);
+            for (T data : dataList) {
+                if (insertPredicate.test(session, data)) {
+                    sqlSession.insert(insertMethod.getStatementId(), toParameter(insertMethod.getParameterConvert(), data));
+                } else {
+                    sqlSession.update(updateMethod.getStatementId(), toParameter(updateMethod.getParameterConvert(), data));
+                }
+            }
+            resultList.addAll(sqlSession.flushStatements());
+            resultList.addAll(session.getResultBatchList());
+            return resultList;
+        }
+    }
+
+    protected Object toParameter(ParameterConvert<T> parameterConvert, T data) {
+        return parameterConvert != null ? parameterConvert.convert(data) : data;
+    }
+
+    public static class Method<T> {
+
+        private final String namespace;
+
+        public Method(Class<?> mapperClass) {
+            this.namespace = mapperClass.getName();
+        }
+
+        public BatchMethod<T> insert() {
+            return new BatchMethod<>(namespace + StringPool.DOT + SqlMethod.INSERT_ONE.getMethod());
+        }
+
+        public BatchMethod<T> updateById() {
+            return new BatchMethod<>(namespace + StringPool.DOT + SqlMethod.UPDATE_BY_ID.getMethod(), (entity) -> {
+                Map<String, Object> param = new HashMap<>();
+                param.put(Constants.ENTITY, entity);
+                return param;
+            });
+        }
+
+        public BatchMethod<T> update(Function<T, Wrapper<T>> wrapperFunction) {
+            return new BatchMethod<>(namespace + StringPool.DOT + SqlMethod.UPDATE.getMethod(), (entity) -> {
+                Map<String, Object> param = new HashMap<>();
+                param.put(Constants.ENTITY, entity);
+                param.put(Constants.WRAPPER, wrapperFunction.apply(entity));
+                return param;
+            });
+        }
+    }
+
+}

+ 18 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/batch/ParameterConvert.java

@@ -0,0 +1,18 @@
+package com.baomidou.mybatisplus.core.batch;
+
+/**
+ * @author nieqiurong
+ * @since 3.5.4
+ */
+@FunctionalInterface
+public interface ParameterConvert<T> {
+
+    /**
+     * 转换当前实体参数为mapper方法参数
+     *
+     * @param entity 实体对象
+     * @return mapper方法参数.
+     */
+    Object convert(T entity);
+
+}

+ 121 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserMapperTest.java

@@ -15,6 +15,7 @@
  */
  */
 package com.baomidou.mybatisplus.test.h2;
 package com.baomidou.mybatisplus.test.h2;
 
 
+import com.baomidou.mybatisplus.core.batch.MybatisBatch;
 import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
 import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
 import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
 import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
 import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
 import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
@@ -25,12 +26,20 @@ import com.baomidou.mybatisplus.test.h2.entity.H2User;
 import com.baomidou.mybatisplus.test.h2.entity.SuperEntity;
 import com.baomidou.mybatisplus.test.h2.entity.SuperEntity;
 import com.baomidou.mybatisplus.test.h2.enums.AgeEnum;
 import com.baomidou.mybatisplus.test.h2.enums.AgeEnum;
 import com.baomidou.mybatisplus.test.h2.mapper.H2UserMapper;
 import com.baomidou.mybatisplus.test.h2.mapper.H2UserMapper;
+import org.apache.ibatis.executor.BatchResult;
+import org.apache.ibatis.session.SqlSessionFactory;
 import org.junit.jupiter.api.*;
 import org.junit.jupiter.api.*;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.junit.jupiter.SpringExtension;
 import org.springframework.test.context.junit.jupiter.SpringExtension;
+import org.springframework.transaction.TransactionStatus;
+import org.springframework.transaction.support.TransactionCallback;
+import org.springframework.transaction.support.TransactionTemplate;
 
 
 import javax.annotation.Resource;
 import javax.annotation.Resource;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Date;
 import java.util.Date;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.List;
 import java.util.List;
@@ -52,6 +61,118 @@ class H2UserMapperTest extends BaseTest {
     @Resource
     @Resource
     protected H2UserMapper userMapper;
     protected H2UserMapper userMapper;
 
 
+    @Autowired
+    private SqlSessionFactory sqlSessionFactory;
+
+    @Autowired
+    private TransactionTemplate transactionTemplate;
+
+
+    @Test
+    void testBatchTransaction(){
+        List<H2User> h2UserList = Arrays.asList(new H2User(1000036L, "测试12323232"), new H2User(10000367L, "测试3323232"));
+        try {
+            transactionTemplate.execute(new TransactionCallback<List<BatchResult>>() {
+                @Override
+                public List<BatchResult> doInTransaction(TransactionStatus status) {
+                    MybatisBatch.Method<H2User> mapperMethod = new MybatisBatch.Method<>(H2UserMapper.class);
+                    // 执行批量插入
+                    new MybatisBatch<>(sqlSessionFactory, h2UserList).execute(mapperMethod.insert());
+                    throw new RuntimeException("出错了");
+                }
+            });
+        } catch (Exception exception) {
+            for (H2User h2User : h2UserList) {
+                Assertions.assertNull(userMapper.selectById(h2User.getTestId()));
+            }
+        }
+        transactionTemplate.execute(new TransactionCallback<List<BatchResult>>() {
+            @Override
+            public List<BatchResult> doInTransaction(TransactionStatus status) {
+                MybatisBatch.Method<H2User> mapperMethod = new MybatisBatch.Method<>(H2UserMapper.class);
+                // 执行批量插入
+                return new MybatisBatch<>(sqlSessionFactory, h2UserList).execute(mapperMethod.insert());
+            }
+        });
+        for (H2User h2User : h2UserList) {
+            Assertions.assertNotNull(userMapper.selectById(h2User.getTestId()));
+        }
+    }
+
+    @Test
+    void testInsertBatch() {
+        int batchSize = 1000;
+        List<H2User> h2UserList = new ArrayList<>();
+        for (int i = 0; i < batchSize; i++) {
+            h2UserList.add(new H2User("test" + i));
+        }
+        MybatisBatch.Method<H2User> mapperMethod = new MybatisBatch.Method<>(H2UserMapper.class);
+        // 执行批量插入
+        List<BatchResult> batchResults = new MybatisBatch<>(sqlSessionFactory, h2UserList).execute(mapperMethod.insert());
+        int[] updateCounts = batchResults.get(0).getUpdateCounts();
+        Assertions.assertEquals(batchSize, updateCounts.length);
+        for (int updateCount : updateCounts) {
+            Assertions.assertEquals(1, updateCount);
+        }
+    }
+
+    @Test
+    void testUpdateBatch() {
+        int batchSize = 1000;
+        List<H2User> h2UserList = new ArrayList<>();
+        for (int i = 0; i < batchSize; i++) {
+            h2UserList.add(new H2User(Long.valueOf(30000 + i), "test" + i));
+        }
+        MybatisBatch.Method<H2User> mapperMethod = new MybatisBatch.Method<>(H2UserMapper.class);
+        // 执行批量更新
+        List<BatchResult> batchResults = new MybatisBatch<>(sqlSessionFactory, h2UserList).execute(mapperMethod.updateById());
+        int[] updateCounts = batchResults.get(0).getUpdateCounts();
+        Assertions.assertEquals(batchSize, updateCounts.length);
+        for (int updateCount : updateCounts) {
+            Assertions.assertEquals(0, updateCount);
+        }
+    }
+
+    @Test
+    void testSaveOrUpdateBatch1(){
+        int batchSize = 10;
+        List<H2User> h2UserList = new ArrayList<>();
+        for (int i = 0; i < batchSize; i++) {
+            h2UserList.add(new H2User(Long.valueOf(40000 + i), "test" + i));
+        }
+        MybatisBatch.Method<H2User> mapperMethod = new MybatisBatch.Method<>(H2UserMapper.class);
+        List<BatchResult> batchResults = new MybatisBatch<>(sqlSessionFactory, h2UserList).saveOrUpdate(
+                mapperMethod.insert(),
+                ((sqlSession, h2User) -> userMapper.selectById(h2User.getTestId()) == null),
+                mapperMethod.updateById());
+        // 没有使用共享的sqlSession,由于都是新增返回还是一个批次
+        int[] updateCounts = batchResults.get(0).getUpdateCounts();
+        Assertions.assertEquals(batchSize, updateCounts.length);
+        for (int updateCount : updateCounts) {
+            Assertions.assertEquals(1, updateCount);
+        }
+    }
+    @Test
+    void testSaveOrUpdateBatch2(){
+        int batchSize = 10;
+        List<H2User> h2UserList = new ArrayList<>();
+        for (int i = 0; i < batchSize; i++) {
+            h2UserList.add(new H2User(Long.valueOf(50000 + i), "test" + i));
+        }
+        MybatisBatch.Method<H2User> mapperMethod = new MybatisBatch.Method<>(H2UserMapper.class);
+        List<BatchResult> batchResults = new MybatisBatch<>(sqlSessionFactory, h2UserList).saveOrUpdate(
+                mapperMethod.insert(),
+				((sqlSession, h2User) -> sqlSession.selectList(H2UserMapper.class.getName() + ".selectById", h2User.getTestId()).isEmpty()),
+                mapperMethod.updateById());
+        // 使用共享的sqlSession,等于每次都是刷新了,批次总结果集就等于数据大小了
+        Assertions.assertEquals(batchSize, batchResults.size());
+        for (BatchResult batchResult : batchResults) {
+            Assertions.assertEquals(batchResult.getUpdateCounts().length,1);
+            Assertions.assertEquals(1, batchResult.getUpdateCounts()[0]);
+        }
+    }
+
+
     @Test
     @Test
     @Order(1)
     @Order(1)
     void crudTest() {
     void crudTest() {

+ 8 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/config/MybatisPlusConfig.java

@@ -36,6 +36,8 @@ import org.apache.ibatis.type.JdbcType;
 import org.mybatis.spring.annotation.MapperScan;
 import org.mybatis.spring.annotation.MapperScan;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
 import org.springframework.context.annotation.Configuration;
+import org.springframework.transaction.PlatformTransactionManager;
+import org.springframework.transaction.support.TransactionTemplate;
 
 
 import javax.sql.DataSource;
 import javax.sql.DataSource;
 import java.util.List;
 import java.util.List;
@@ -101,4 +103,10 @@ public class MybatisPlusConfig {
                 .setLogicNotDeleteValue("0"));
                 .setLogicNotDeleteValue("0"));
         return conf;
         return conf;
     }
     }
+
+    @Bean
+    public TransactionTemplate transactionTemplate(PlatformTransactionManager platformTransactionManager){
+        return new TransactionTemplate(platformTransactionManager);
+    }
+
 }
 }