瀏覽代碼

SaveOrUpdateBatch支持自定义条件.

https://gitee.com/baomidou/mybatis-plus/issues/I1MK3L
nieqiuqiu 5 年之前
父節點
當前提交
ff646d14d4

+ 44 - 56
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/service/impl/ServiceImpl.java

@@ -26,16 +26,10 @@ import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
 import org.apache.ibatis.binding.MapperMethod;
 import org.apache.ibatis.logging.Log;
 import org.apache.ibatis.logging.LogFactory;
-import org.apache.ibatis.reflection.ExceptionUtil;
-import org.apache.ibatis.session.ExecutorType;
 import org.apache.ibatis.session.SqlSession;
-import org.apache.ibatis.session.SqlSessionFactory;
-import org.mybatis.spring.MyBatisExceptionTranslator;
-import org.mybatis.spring.SqlSessionHolder;
 import org.mybatis.spring.SqlSessionUtils;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.transaction.annotation.Transactional;
-import org.springframework.transaction.support.TransactionSynchronizationManager;
 
 import java.io.Serializable;
 import java.util.Collection;
@@ -44,6 +38,8 @@ import java.util.Objects;
 import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Function;
+import java.util.function.Predicate;
+import java.util.function.Supplier;
 
 /**
  * IService 实现类( 泛型:M 是 mapper 对象,T 是实体 )
@@ -155,15 +151,13 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
         Assert.notNull(tableInfo, "error: can not execute. because can not find cache of TableInfo for entity!");
         String keyProperty = tableInfo.getKeyProperty();
         Assert.notEmpty(keyProperty, "error: can not execute. because can not find column for id from entity!");
-        return executeBatch(entityList, batchSize, (sqlSession, entity) -> {
+        return SqlHelper.saveOrUpdateBatch(this.entityClass, this.log, entityList, batchSize, entity -> {
             Object idVal = ReflectionKit.getFieldValue(entity, keyProperty);
-            if (StringUtils.checkValNull(idVal) || Objects.isNull(getById((Serializable) idVal))) {
-                sqlSession.insert(tableInfo.getSqlStatement(SqlMethod.INSERT_ONE.getMethod()), entity);
-            } else {
-                MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
-                param.put(Constants.ENTITY, entity);
-                sqlSession.update(tableInfo.getSqlStatement(SqlMethod.UPDATE_BY_ID.getMethod()), param);
-            }
+            return StringUtils.checkValNull(idVal) || Objects.isNull(getById((Serializable) idVal));
+        }, (sqlSession, entity) -> {
+            MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
+            param.put(Constants.ENTITY, entity);
+            sqlSession.update(tableInfo.getSqlStatement(SqlMethod.UPDATE_BY_ID.getMethod()), param);
         });
     }
 
@@ -205,36 +199,7 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
      */
     @Deprecated
     protected boolean executeBatch(Consumer<SqlSession> consumer) {
-        SqlSessionFactory sqlSessionFactory = SqlHelper.sqlSessionFactory(entityClass);
-        SqlSessionHolder sqlSessionHolder = (SqlSessionHolder) TransactionSynchronizationManager.getResource(sqlSessionFactory);
-        boolean transaction = TransactionSynchronizationManager.isSynchronizationActive();
-        if (sqlSessionHolder != null) {
-            SqlSession sqlSession = sqlSessionHolder.getSqlSession();
-            //原生无法支持执行器切换,当存在批量操作时,会嵌套两个session的,优先commit上一个session
-            //按道理来说,这里的值应该一直为false。
-            sqlSession.commit(!transaction);
-        }
-        SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH);
-        if (!transaction) {
-            log.warn("SqlSession [" + sqlSession + "] was not registered for synchronization because DataSource is not transactional");
-        }
-        try {
-            consumer.accept(sqlSession);
-            //非事物情况下,强制commit。
-            sqlSession.commit(!transaction);
-            return true;
-        } catch (Throwable t) {
-            sqlSession.rollback();
-            Throwable unwrapped = ExceptionUtil.unwrapThrowable(t);
-            if (unwrapped instanceof RuntimeException) {
-                MyBatisExceptionTranslator myBatisExceptionTranslator
-                    = new MyBatisExceptionTranslator(sqlSessionFactory.getConfiguration().getEnvironment().getDataSource(), true);
-                throw Objects.requireNonNull(myBatisExceptionTranslator.translateExceptionIfPossible((RuntimeException) unwrapped));
-            }
-            throw ExceptionUtils.mpe(unwrapped);
-        } finally {
-            sqlSession.close();
-        }
+        return SqlHelper.executeBatch(this.entityClass, this.log, consumer);
     }
 
     /**
@@ -248,18 +213,7 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
      * @since 3.3.1
      */
     protected <E> boolean executeBatch(Collection<E> list, int batchSize, BiConsumer<SqlSession, E> consumer) {
-        Assert.isFalse(batchSize < 1, "batchSize must not be less than one");
-        return !CollectionUtils.isEmpty(list) && executeBatch(sqlSession -> {
-            int size = list.size();
-            int i = 1;
-            for (E element : list) {
-                consumer.accept(sqlSession, element);
-                if ((i % batchSize == 0) || i == size) {
-                    sqlSession.flushStatements();
-                }
-                i++;
-            }
-        });
+        return SqlHelper.executeBatch(this.entityClass, this.log, list, batchSize, consumer);
     }
 
     /**
@@ -274,4 +228,38 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
     protected <E> boolean executeBatch(Collection<E> list, BiConsumer<SqlSession, E> consumer) {
         return executeBatch(list, DEFAULT_BATCH_SIZE, consumer);
     }
+
+    /**
+     * 批量更新或新增
+     *
+     * @param list      数据集合
+     * @param batchSize 批量大小
+     * @param predicate 新增条件 notnull
+     * @param function  更新条件 notnull
+     * @return 操作结果
+     * @since 3.3.3
+     */
+    protected boolean saveOrUpdateBatch(Collection<T> list, int batchSize, Predicate<T> predicate, Function<T, Wrapper<T>> function) {
+        TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
+        return SqlHelper.saveOrUpdateBatch(this.entityClass, this.log, list, batchSize, predicate, ((sqlSession, entity) -> {
+            String sqlStatement = tableInfo.getSqlStatement(SqlMethod.UPDATE.getMethod());
+            MapperMethod.ParamMap<Object> param = new MapperMethod.ParamMap<>();
+            param.put(Constants.ENTITY, entity);
+            param.put(Constants.WRAPPER, function.apply(entity));
+            sqlSession.update(sqlStatement, param);
+        }));
+    }
+
+    /**
+     * 批量更新或新增
+     *
+     * @param list           数据集合
+     * @param predicate      新增条件
+     * @param updateFunction 更新条件
+     * @return 操作结果
+     * @since 3.3.3
+     */
+    protected boolean saveOrUpdateBatch(Collection<T> list, Predicate<T> predicate, Function<T,Wrapper<T>> updateFunction) {
+        return saveOrUpdateBatch(list, DEFAULT_BATCH_SIZE, predicate, updateFunction);
+    }
 }

+ 102 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/toolkit/SqlHelper.java

@@ -15,20 +15,29 @@
  */
 package com.baomidou.mybatisplus.extension.toolkit;
 
+import com.baomidou.mybatisplus.core.enums.SqlMethod;
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
 import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
 import com.baomidou.mybatisplus.core.toolkit.Assert;
 import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
+import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
 import com.baomidou.mybatisplus.core.toolkit.GlobalConfigUtils;
 import org.apache.ibatis.logging.Log;
+import org.apache.ibatis.reflection.ExceptionUtil;
 import org.apache.ibatis.session.ExecutorType;
 import org.apache.ibatis.session.SqlSession;
 import org.apache.ibatis.session.SqlSessionFactory;
+import org.mybatis.spring.MyBatisExceptionTranslator;
 import org.mybatis.spring.SqlSessionHolder;
 import org.mybatis.spring.SqlSessionUtils;
 import org.springframework.transaction.support.TransactionSynchronizationManager;
 
+import java.util.Collection;
 import java.util.List;
+import java.util.Objects;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+import java.util.function.Predicate;
 
 /**
  * SQL 辅助类
@@ -141,4 +150,97 @@ public final class SqlHelper {
             sqlSession.clearCache();
         }
     }
+
+    /**
+     * 执行批量操作
+     *
+     * @param entityClass 实体
+     * @param log         日志对象
+     * @param consumer    consumer
+     * @return 操作结果
+     * @since 3.3.3
+     */
+    public static boolean executeBatch(Class<?> entityClass, Log log, Consumer<SqlSession> consumer) {
+        SqlSessionFactory sqlSessionFactory = sqlSessionFactory(entityClass);
+        SqlSessionHolder sqlSessionHolder = (SqlSessionHolder) TransactionSynchronizationManager.getResource(sqlSessionFactory);
+        boolean transaction = TransactionSynchronizationManager.isSynchronizationActive();
+        if (sqlSessionHolder != null) {
+            SqlSession sqlSession = sqlSessionHolder.getSqlSession();
+            //原生无法支持执行器切换,当存在批量操作时,会嵌套两个session的,优先commit上一个session
+            //按道理来说,这里的值应该一直为false。
+            sqlSession.commit(!transaction);
+        }
+        SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH);
+        if (!transaction) {
+            log.warn("SqlSession [" + sqlSession + "] was not registered for synchronization because DataSource is not transactional");
+        }
+        try {
+            consumer.accept(sqlSession);
+            //非事物情况下,强制commit。
+            sqlSession.commit(!transaction);
+            return true;
+        } catch (Throwable t) {
+            sqlSession.rollback();
+            Throwable unwrapped = ExceptionUtil.unwrapThrowable(t);
+            if (unwrapped instanceof RuntimeException) {
+                MyBatisExceptionTranslator myBatisExceptionTranslator
+                    = new MyBatisExceptionTranslator(sqlSessionFactory.getConfiguration().getEnvironment().getDataSource(), true);
+                throw Objects.requireNonNull(myBatisExceptionTranslator.translateExceptionIfPossible((RuntimeException) unwrapped));
+            }
+            throw ExceptionUtils.mpe(unwrapped);
+        } finally {
+            sqlSession.close();
+        }
+    }
+
+    /**
+     * 执行批量操作
+     *
+     * @param entityClass 实体类
+     * @param log         日志对象
+     * @param list        数据集合
+     * @param batchSize   批次大小
+     * @param consumer    consumer
+     * @param <E>         T
+     * @return 操作结果
+     * @since 3.3.3
+     */
+    public static <E> boolean executeBatch(Class<?> entityClass, Log log, Collection<E> list, int batchSize, BiConsumer<SqlSession, E> consumer) {
+        Assert.isFalse(batchSize < 1, "batchSize must not be less than one");
+        return !CollectionUtils.isEmpty(list) && executeBatch(entityClass, log, sqlSession -> {
+            int size = list.size();
+            int i = 1;
+            for (E element : list) {
+                consumer.accept(sqlSession, element);
+                if ((i % batchSize == 0) || i == size) {
+                    sqlSession.flushStatements();
+                }
+                i++;
+            }
+        });
+    }
+
+    /**
+     * 批量更新或保存
+     *
+     * @param entityClass 实体
+     * @param log         日志对象
+     * @param list        数据集合
+     * @param batchSize   批次大小
+     * @param predicate   predicate(新增条件)
+     * @param consumer    consumer(更新处理)
+     * @param <E>         E
+     * @return 操作结果
+     * @since 3.3.3
+     */
+    public static <E> boolean saveOrUpdateBatch(Class<?> entityClass, Log log, Collection<E> list, int batchSize, Predicate<E> predicate, BiConsumer<SqlSession, E> consumer) {
+        TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
+        return executeBatch(entityClass, log, list, batchSize, (sqlSession, entity) -> {
+            if (predicate.test(entity)) {
+                sqlSession.insert(tableInfo.getSqlStatement(SqlMethod.INSERT_ONE.getMethod()), entity);
+            } else {
+                consumer.accept(sqlSession, entity);
+            }
+        });
+    }
 }

+ 5 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/cache/CacheTest.java

@@ -210,6 +210,11 @@ class CacheTest {
         cacheService.page(page1);
         Assertions.assertEquals(cache.getSize(), 2);
     }
+
+    @Test
+    void testCustomSaveOrUpdateBatch(){
+        Assertions.assertTrue(cacheService.testCustomSaveOrUpdateBatch());
+    }
     
     private Cache getCache() {
         return sqlSessionFactory.getConfiguration().getCache(CacheMapper.class.getName());

+ 2 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/cache/service/ICacheService.java

@@ -18,4 +18,6 @@ public interface ICacheService extends IService<CacheModel> {
     long testBatchTransactionalClear6();
 
     long testBatchTransactionalClear7();
+
+    boolean testCustomSaveOrUpdateBatch();
 }

+ 12 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/cache/service/impl/CacheServiceImpl.java

@@ -1,5 +1,6 @@
 package com.baomidou.mybatisplus.test.h2.cache.service.impl;
 
+import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
 import com.baomidou.mybatisplus.core.enums.SqlMethod;
 import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
 import com.baomidou.mybatisplus.test.h2.cache.mapper.CacheMapper;
@@ -8,6 +9,7 @@ import com.baomidou.mybatisplus.test.h2.cache.service.ICacheService;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
 
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 
@@ -20,6 +22,16 @@ public class CacheServiceImpl extends ServiceImpl<CacheMapper, CacheModel> imple
         executeBatch(idList, (sqlSession, id) -> sqlSession.delete(sqlStatement, id));
     }
 
+    @Override
+    @Transactional
+    public boolean testCustomSaveOrUpdateBatch() {
+        CacheModel model1 = new CacheModel();
+        CacheModel model2 = new CacheModel("旺仔");
+        //name为空写入,不为空按条件更新
+        boolean result = saveOrUpdateBatch(Arrays.asList(model1, model2), entity -> entity.getName() == null, (entity) -> new QueryWrapper<CacheModel>().lambda().eq(CacheModel::getName, entity.getName()));
+        return model1.getId() != null && model2.getId() == null && result;
+    }
+
     @Override
     @Transactional
     public long testBatchTransactionalClear1() {