Browse Source

批量操作异常转换为DataAccessException(不兼容改动).

聂秋秋 5 năm trước cách đây
mục cha
commit
e7eea9400f

+ 38 - 16
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/service/impl/ServiceImpl.java

@@ -27,7 +27,10 @@ import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
 import org.apache.ibatis.binding.MapperMethod;
 import org.apache.ibatis.logging.Log;
 import org.apache.ibatis.logging.LogFactory;
+import org.apache.ibatis.session.ExecutorType;
 import org.apache.ibatis.session.SqlSession;
+import org.apache.ibatis.session.SqlSessionFactory;
+import org.mybatis.spring.MyBatisExceptionTranslator;
 import org.mybatis.spring.SqlSessionUtils;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.transaction.annotation.Transactional;
@@ -37,6 +40,7 @@ import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -75,7 +79,10 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
 
     /**
      * 批量操作 SqlSession
+     *
+     * @deprecated 3.2.1
      */
+    @Deprecated
     protected SqlSession sqlSessionBatch() {
         return SqlHelper.sqlSessionBatch(currentModelClass());
     }
@@ -114,19 +121,18 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
     @Transactional(rollbackFor = Exception.class)
     @Override
     public boolean saveBatch(Collection<T> entityList, int batchSize) {
-        SqlHelper.clearCache(currentModelClass());
         String sqlStatement = sqlStatement(SqlMethod.INSERT_ONE);
         int size = entityList.size();
-        try (SqlSession batchSqlSession = sqlSessionBatch()) {
+        executeBatch(sqlSession -> {
             int i = 1;
             for (T entity : entityList) {
-                batchSqlSession.insert(sqlStatement, entity);
+                sqlSession.insert(sqlStatement, entity);
                 if ((i % batchSize == 0) || i == size) {
-                    batchSqlSession.flushStatements();
+                    sqlSession.flushStatements();
                 }
                 i++;
             }
-        }
+        });
         return true;
     }
 
@@ -156,30 +162,29 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
     public boolean saveOrUpdateBatch(Collection<T> entityList, int batchSize) {
         Assert.notEmpty(entityList, "error: entityList must not be empty");
         Class<?> cls = currentModelClass();
-        SqlHelper.clearCache(cls);
         TableInfo tableInfo = TableInfoHelper.getTableInfo(cls);
         Assert.notNull(tableInfo, "error: can not execute. because can not find cache of TableInfo for entity!");
         String keyProperty = tableInfo.getKeyProperty();
         Assert.notEmpty(keyProperty, "error: can not execute. because can not find column for id from entity!");
         int size = entityList.size();
-        try (SqlSession batchSqlSession = sqlSessionBatch()) {
+        executeBatch(sqlSession -> {
             int i = 1;
             for (T entity : entityList) {
                 Object idVal = ReflectionKit.getMethodValue(cls, entity, keyProperty);
                 if (StringUtils.checkValNull(idVal) || Objects.isNull(getById((Serializable) idVal))) {
-                    batchSqlSession.insert(sqlStatement(SqlMethod.INSERT_ONE), entity);
+                    sqlSession.insert(sqlStatement(SqlMethod.INSERT_ONE), entity);
                 } else {
                     MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
                     param.put(Constants.ENTITY, entity);
-                    batchSqlSession.update(sqlStatement(SqlMethod.UPDATE_BY_ID), param);
+                    sqlSession.update(sqlStatement(SqlMethod.UPDATE_BY_ID), param);
                 }
                 // 不知道以后会不会有人说更新失败了还要执行插入 😂😂😂
                 if ((i % batchSize == 0) || i == size) {
-                    batchSqlSession.flushStatements();
+                    sqlSession.flushStatements();
                 }
                 i++;
             }
-        }
+        });
         return true;
     }
 
@@ -219,20 +224,19 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
     public boolean updateBatchById(Collection<T> entityList, int batchSize) {
         Assert.notEmpty(entityList, "error: entityList must not be empty");
         String sqlStatement = sqlStatement(SqlMethod.UPDATE_BY_ID);
-        SqlHelper.clearCache(currentModelClass());
         int size = entityList.size();
-        try (SqlSession batchSqlSession = sqlSessionBatch()) {
+        executeBatch(sqlSession -> {
             int i = 1;
             for (T anEntityList : entityList) {
                 MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
                 param.put(Constants.ENTITY, anEntityList);
-                batchSqlSession.update(sqlStatement, param);
+                sqlSession.update(sqlStatement, param);
                 if ((i % batchSize == 0) || i == size) {
-                    batchSqlSession.flushStatements();
+                    sqlSession.flushStatements();
                 }
                 i++;
             }
-        }
+        });
         return true;
     }
 
@@ -298,4 +302,22 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
     public <V> V getObj(Wrapper<T> queryWrapper, Function<? super Object, V> mapper) {
         return SqlHelper.getObject(log, listObjs(queryWrapper, mapper));
     }
+
+    /**
+     * 执行批量操作
+     *
+     * @param fun fun
+     * @since 3.2.1
+     */
+    protected void executeBatch(Consumer<SqlSession> fun) {
+        Class<T> tClass = currentModelClass();
+        SqlHelper.clearCache(tClass);
+        SqlSessionFactory sqlSessionFactory = SqlHelper.sqlSessionFactory(tClass);
+        try (SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH)) {
+            fun.accept(sqlSession);
+        } catch (RuntimeException ex) {
+            MyBatisExceptionTranslator myBatisExceptionTranslator = new MyBatisExceptionTranslator(sqlSessionFactory.getConfiguration().getEnvironment().getDataSource(), true);
+            throw Objects.requireNonNull(myBatisExceptionTranslator.translateExceptionIfPossible(ex));
+        }
+    }
 }

+ 12 - 1
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/toolkit/SqlHelper.java

@@ -54,7 +54,18 @@ public final class SqlHelper {
      */
     public static SqlSession sqlSessionBatch(Class<?> clazz) {
         // TODO 暂时让能用先,但日志会显示Closing non transactional SqlSession,因为这个并没有绑定.
-        return GlobalConfigUtils.currentSessionFactory(clazz).openSession(ExecutorType.BATCH);
+        return sqlSessionFactory(clazz).openSession(ExecutorType.BATCH);
+    }
+
+    /**
+     * 获取SqlSessionFactory
+     *
+     * @param clazz 实体类
+     * @return SqlSessionFactory
+     * @since 3.2.1
+     */
+    public static SqlSessionFactory sqlSessionFactory(Class<?> clazz) {
+        return GlobalConfigUtils.currentSessionFactory(clazz);
     }
 
     /**

+ 2 - 2
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserTest.java

@@ -24,10 +24,10 @@ import com.baomidou.mybatisplus.test.h2.enums.AgeEnum;
 import com.baomidou.mybatisplus.test.h2.service.IH2UserService;
 import net.sf.jsqlparser.parser.CCJSqlParserUtil;
 import net.sf.jsqlparser.statement.select.Select;
-import org.apache.ibatis.exceptions.PersistenceException;
 import org.junit.jupiter.api.*;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.dao.DataAccessException;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.junit.jupiter.SpringExtension;
 import org.springframework.transaction.annotation.Transactional;
@@ -350,7 +350,7 @@ class H2UserTest extends BaseTest {
                     new H2User(1L, "andy")
             ));
         } catch (Exception e) {
-            Assertions.assertTrue(e instanceof PersistenceException);
+            Assertions.assertTrue(e instanceof DataAccessException);
         }
     }