Browse Source

feat: 增加批量更新数据上限设置:当该功能开启时,默认一条更新语句更新记录超过1000,就拦截不予执行。防止误操作出现全表更新

yuxiaobin 2 years ago
parent
commit
cf4c3bf846

+ 1 - 0
build.gradle

@@ -175,6 +175,7 @@ subprojects {
         dependsOn("cleanTest", "generatePomFileForMavenJavaPublication")
         useJUnitPlatform()
         // 解决 IdeaProxyLambdaMetaTest 和 LambdaUtilsTest 测试失败问题
+        //JDK 8测试,请删除以下两行 jvmArgs
         jvmArgs += ["--add-opens", "java.base/java.lang=ALL-UNNAMED",
                     "--add-opens", "java.base/java.lang.invoke=ALL-UNNAMED"]
         exclude("**/phoenix/**")

+ 136 - 19
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/DataChangeRecorderInnerInterceptor.java

@@ -15,9 +15,43 @@
  */
 package com.baomidou.mybatisplus.extension.plugins.inner;
 
+import java.io.Reader;
+import java.math.BigDecimal;
+import java.sql.Clob;
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.ResultSetMetaData;
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Properties;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.ibatis.executor.statement.StatementHandler;
+import org.apache.ibatis.mapping.BoundSql;
+import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.mapping.ParameterMapping;
+import org.apache.ibatis.mapping.SqlCommandType;
+import org.apache.ibatis.reflection.MetaObject;
+import org.apache.ibatis.reflection.SystemMetaObject;
+import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
 import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
 import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
+
 import lombok.Data;
 import net.sf.jsqlparser.expression.Expression;
 import net.sf.jsqlparser.expression.JdbcParameter;
@@ -27,32 +61,22 @@ import net.sf.jsqlparser.schema.Table;
 import net.sf.jsqlparser.statement.Statement;
 import net.sf.jsqlparser.statement.delete.Delete;
 import net.sf.jsqlparser.statement.insert.Insert;
-import net.sf.jsqlparser.statement.select.*;
+import net.sf.jsqlparser.statement.select.AllColumns;
+import net.sf.jsqlparser.statement.select.FromItem;
+import net.sf.jsqlparser.statement.select.PlainSelect;
+import net.sf.jsqlparser.statement.select.Select;
+import net.sf.jsqlparser.statement.select.SelectExpressionItem;
+import net.sf.jsqlparser.statement.select.SelectItem;
 import net.sf.jsqlparser.statement.update.Update;
 import net.sf.jsqlparser.statement.update.UpdateSet;
-import org.apache.ibatis.executor.statement.StatementHandler;
-import org.apache.ibatis.mapping.BoundSql;
-import org.apache.ibatis.mapping.MappedStatement;
-import org.apache.ibatis.mapping.ParameterMapping;
-import org.apache.ibatis.mapping.SqlCommandType;
-import org.apache.ibatis.reflection.MetaObject;
-import org.apache.ibatis.reflection.SystemMetaObject;
-import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.Reader;
-import java.math.BigDecimal;
-import java.sql.*;
-import java.util.Date;
-import java.util.*;
-import java.util.concurrent.ConcurrentHashMap;
 
 /**
  * <p>
  * 数据变动记录插件
  * 默认会生成一条log,格式:
  * ----------------------INSERT LOG------------------------------
+ * </p>
+ * <p>
  * {
  * "tableName": "h2user",
  * "operation": "insert",
@@ -96,6 +120,10 @@ public class DataChangeRecorderInnerInterceptor implements InnerInterceptor {
 
     private final Map<String, Set<String>> ignoredTableColumns = new ConcurrentHashMap<>();
     private final Set<String> ignoreAllColumns = new HashSet<>();//全部表的这些字段名,INSERT/UPDATE都忽略,delete暂时保留
+    //批量更新上限, 默认一次最多1000条
+    private int BATCH_UPDATE_LIMIT = 1000;
+    private boolean batchUpdateLimitationOpened = false;
+    private final Map<String, Integer> BATCH_UPDATE_LIMIT_MAP = new ConcurrentHashMap<>();//表名->批量更新上限
 
     @Override
     public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
@@ -120,6 +148,9 @@ public class DataChangeRecorderInnerInterceptor implements InnerInterceptor {
                     return;
                 }
             } catch (Exception e) {
+                if (e instanceof DataUpdateLimitationException) {
+                    throw (DataUpdateLimitationException) e;
+                }
                 logger.error("Unexpected error for mappedStatement={}, sql={}", ms.getId(), mpBs.sql(), e);
                 return;
             }
@@ -316,7 +347,13 @@ public class DataChangeRecorderInnerInterceptor implements InnerInterceptor {
             final ResultSetMetaData metaData = resultSet.getMetaData();
             int columnCount = metaData.getColumnCount();
             StringBuilder sb = new StringBuilder("[");
+            int count = 0;
             while (resultSet.next()) {
+                ++count;
+                if (checkTableBatchLimitExceeded(selectStmt, count)) {
+                    logger.error("batch delete limit exceed: count={}, BATCH_UPDATE_LIMIT={}", count, BATCH_UPDATE_LIMIT);
+                    throw DataUpdateLimitationException.DEFAULT;
+                }
                 sb.append("{");
                 for (int i = 1; i <= columnCount; ++i) {
                     sb.append("\"").append(metaData.getColumnName(i)).append("\":\"");
@@ -334,6 +371,9 @@ public class DataChangeRecorderInnerInterceptor implements InnerInterceptor {
             resultSet.close();
             return sb.toString();
         } catch (Exception e) {
+            if (e instanceof DataUpdateLimitationException) {
+                throw (DataUpdateLimitationException) e;
+            }
             logger.error("try to get record tobe deleted for selectStmt={}", selectStmt, e);
             return "failed to get original data";
         }
@@ -341,12 +381,17 @@ public class DataChangeRecorderInnerInterceptor implements InnerInterceptor {
 
     private OriginalDataObj buildOriginalObjectData(Select selectStmt, Column pk, MappedStatement mappedStatement, BoundSql boundSql, Connection connection) {
         try (PreparedStatement statement = connection.prepareStatement(selectStmt.toString())) {
-
             DefaultParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), boundSql);
             parameterHandler.setParameters(statement);
             ResultSet resultSet = statement.executeQuery();
             List<DataChangedRecord> originalObjectDatas = new LinkedList<>();
+            int count = 0;
             while (resultSet.next()) {
+                ++count;
+                if (checkTableBatchLimitExceeded(selectStmt, count)) {
+                    logger.error("batch update limit exceed: count={}, BATCH_UPDATE_LIMIT={}", count, BATCH_UPDATE_LIMIT);
+                    throw DataUpdateLimitationException.DEFAULT;
+                }
                 originalObjectDatas.add(prepareOriginalDataObj(resultSet, pk));
             }
             OriginalDataObj result = new OriginalDataObj();
@@ -354,11 +399,51 @@ public class DataChangeRecorderInnerInterceptor implements InnerInterceptor {
             resultSet.close();
             return result;
         } catch (Exception e) {
+            if (e instanceof DataUpdateLimitationException) {
+                throw (DataUpdateLimitationException) e;
+            }
             logger.error("try to get record tobe updated for selectStmt={}", selectStmt, e);
             return new OriginalDataObj();
         }
     }
 
+    /**
+     * 防止出现全表批量更新
+     * 默认一次更新不超过1000条
+     *
+     * @param selectStmt
+     * @param count
+     * @return
+     */
+    private boolean checkTableBatchLimitExceeded(Select selectStmt, int count) {
+        if (!batchUpdateLimitationOpened) {
+            return false;
+        }
+        final PlainSelect selectBody = (PlainSelect) selectStmt.getSelectBody();
+        final FromItem fromItem = selectBody.getFromItem();
+        if (fromItem instanceof Table) {
+            Table fromTable = (Table) fromItem;
+            final String tableName = fromTable.getName().toUpperCase();
+            if (!BATCH_UPDATE_LIMIT_MAP.containsKey(tableName)) {
+                if (count > BATCH_UPDATE_LIMIT) {
+                    logger.error("batch update limit exceed for tableName={}, BATCH_UPDATE_LIMIT={}, count={}",
+                        tableName, BATCH_UPDATE_LIMIT, count);
+                    return true;
+                }
+                return false;
+            }
+            final Integer limit = BATCH_UPDATE_LIMIT_MAP.get(tableName);
+            if (count > limit) {
+                logger.error("batch update limit exceed for configured tableName={}, BATCH_UPDATE_LIMIT={}, count={}",
+                    tableName, limit, count);
+                return true;
+            }
+            return false;
+        }
+        return count > BATCH_UPDATE_LIMIT;
+    }
+
+
     /**
      * get records : include related column with original data in DB
      *
@@ -460,6 +545,27 @@ public class DataChangeRecorderInnerInterceptor implements InnerInterceptor {
         return result;
     }
 
+    /**
+     * 设置批量更新记录条数上限
+     *
+     * @param limit
+     * @return
+     */
+    public DataChangeRecorderInnerInterceptor setBatchUpdateLimit(int limit) {
+        this.BATCH_UPDATE_LIMIT = limit;
+        return this;
+    }
+
+    public DataChangeRecorderInnerInterceptor openBatchUpdateLimitation() {
+        this.batchUpdateLimitationOpened = true;
+        return this;
+    }
+
+    public DataChangeRecorderInnerInterceptor configTableLimitation(String tableName, int limit) {
+        this.BATCH_UPDATE_LIMIT_MAP.put(tableName.toUpperCase(), limit);
+        return this;
+    }
+
     /**
      * ignoredColumns = TABLE_NAME1.COLUMN1,COLUMN2; TABLE2.COLUMN1,COLUMN2; TABLE3.*; *.COLUMN1,COLUMN2
      * 多个表用分号分隔
@@ -728,4 +834,15 @@ public class DataChangeRecorderInnerInterceptor implements InnerInterceptor {
             return obj.toString().replace("\"", "\\\"");
         }
     }
+
+    public static class DataUpdateLimitationException extends MybatisPlusException {
+
+        public DataUpdateLimitationException(String message) {
+            super(message);
+        }
+
+        public static DataUpdateLimitationException DEFAULT = new DataUpdateLimitationException("本次操作 因超过系统安全阈值 被拦截,如需继续,请联系管理员!");
+
+    }
+
 }

+ 76 - 15
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserTest.java

@@ -15,6 +15,32 @@
  */
 package com.baomidou.mybatisplus.test.h2;
 
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.AbstractList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.ibatis.plugin.Interceptor;
+import org.apache.ibatis.session.Configuration;
+import org.apache.ibatis.session.SqlSessionFactory;
+import org.apache.ibatis.session.defaults.DefaultSqlSessionFactory;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.MethodOrderer;
+import org.junit.jupiter.api.Order;
+import org.junit.jupiter.api.RepeatedTest;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestMethodOrder;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.dao.DataAccessException;
+import org.springframework.test.context.ContextConfiguration;
+import org.springframework.test.context.junit.jupiter.SpringExtension;
+import org.springframework.transaction.annotation.Transactional;
+
 import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
 import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
 import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
@@ -24,26 +50,26 @@ import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.core.toolkit.StringUtils;
 import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
+import com.baomidou.mybatisplus.extension.plugins.MybatisPlusInterceptor;
+import com.baomidou.mybatisplus.extension.plugins.inner.DataChangeRecorderInnerInterceptor;
+import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
 import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
 import com.baomidou.mybatisplus.test.h2.entity.H2User;
 import com.baomidou.mybatisplus.test.h2.enums.AgeEnum;
 import com.baomidou.mybatisplus.test.h2.service.IH2UserService;
+
 import net.sf.jsqlparser.parser.CCJSqlParserUtil;
 import net.sf.jsqlparser.statement.select.Select;
-import org.junit.jupiter.api.*;
-import org.junit.jupiter.api.extension.ExtendWith;
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.dao.DataAccessException;
-import org.springframework.test.context.ContextConfiguration;
-import org.springframework.test.context.junit.jupiter.SpringExtension;
-import org.springframework.transaction.annotation.Transactional;
-
-import java.math.BigDecimal;
-import java.math.RoundingMode;
-import java.util.*;
 
 /**
  * Mybatis Plus H2 Junit Test
+ * JDK 8 run test:
+ * <p>"Error: Could not create the Java Virtual Machine."</p>
+ * <p>Go to build.gradle: remove below configuration:</p>
+ * <p>
+ * //  jvmArgs += ["--add-opens", "java.base/java.lang=ALL-UNNAMED",
+ * //                    "--add-opens", "java.base/java.lang.invoke=ALL-UNNAMED"]
+ * </p>
  *
  * @author Caratacus
  * @since 2017/4/1
@@ -55,6 +81,24 @@ class H2UserTest extends BaseTest {
 
     @Autowired
     protected IH2UserService userService;
+    @Autowired
+    SqlSessionFactory sqlSessionFactory;
+
+    public void initBatchLimitation(int limitation) {
+        if (sqlSessionFactory instanceof DefaultSqlSessionFactory) {
+            Configuration configuration = sqlSessionFactory.getConfiguration();
+            for (Interceptor interceptor : configuration.getInterceptors()) {
+                if (interceptor instanceof MybatisPlusInterceptor) {
+                    List<InnerInterceptor> innerInterceptors = ((MybatisPlusInterceptor) interceptor).getInterceptors();
+                    for (InnerInterceptor innerInterceptor : innerInterceptors) {
+                        if (innerInterceptor instanceof DataChangeRecorderInnerInterceptor) {
+                            ((DataChangeRecorderInnerInterceptor) innerInterceptor).setBatchUpdateLimit(limitation).openBatchUpdateLimitation();
+                        }
+                    }
+                }
+            }
+        }
+    }
 
     @Test
     @Order(1)
@@ -231,7 +275,24 @@ class H2UserTest extends BaseTest {
             System.out.println(u.getName() + "," + u.getAge() + "," + u.getVersion());
             Assertions.assertEquals(u.getPrice().setScale(2, RoundingMode.HALF_UP).intValue(), BigDecimal.ZERO.setScale(2, RoundingMode.HALF_UP).intValue(), "all records should be updated");
         }
+        try {
+            initBatchLimitation(3);
+            userService.update(new H2User().setPrice(BigDecimal.ZERO), null);
+            Assertions.fail("SHOULD NOT REACH HERE");
+        } catch (Exception e) {
+            e.printStackTrace();
+            Assertions.assertTrue(checkIsDataUpdateLimitationException(e));
+        }
+    }
 
+    private boolean checkIsDataUpdateLimitationException(Throwable e) {
+        if (e instanceof DataChangeRecorderInnerInterceptor.DataUpdateLimitationException) {
+            return true;
+        }
+        if (e.getCause() == null) {
+            return false;
+        }
+        return checkIsDataUpdateLimitationException(e.getCause());
     }
 
     @Test
@@ -545,7 +606,7 @@ class H2UserTest extends BaseTest {
 //        userService.removeById("100000");
         userService.removeById(h2User);
         userService.removeByIds(Arrays.asList(10000L, h2User));
-        userService.removeByIds(Arrays.asList(10000L, h2User),false);
+        userService.removeByIds(Arrays.asList(10000L, h2User), false);
     }
 
     @Test
@@ -575,11 +636,11 @@ class H2UserTest extends BaseTest {
         userService.removeById(h2User, true);
         userService.removeById(h2User, false);
         userService.removeBatchByIds(Arrays.asList(1L, 2L, h2User));
-        userService.removeBatchByIds(Arrays.asList(1L, 2L, h2User),2);
+        userService.removeBatchByIds(Arrays.asList(1L, 2L, h2User), 2);
         userService.removeBatchByIds(Arrays.asList(1L, 2L, h2User), true);
         userService.removeBatchByIds(Arrays.asList(1L, 2L, h2User), false);
-        userService.removeBatchByIds(Arrays.asList(1L, 2L, h2User),2,true);
-        userService.removeBatchByIds(Arrays.asList(1L, 2L, h2User),2,false);
+        userService.removeBatchByIds(Arrays.asList(1L, 2L, h2User), 2, true);
+        userService.removeBatchByIds(Arrays.asList(1L, 2L, h2User), 2, false);
     }
 
     @Test