Browse Source

to qiuqiu

miemie 5 years ago
parent
commit
35aeb2e582

+ 112 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/parser/JsqlParserSupport.java

@@ -0,0 +1,112 @@
+package com.baomidou.mybatisplus.extension.parser;
+
+import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
+import com.baomidou.mybatisplus.core.toolkit.StringPool;
+import net.sf.jsqlparser.JSQLParserException;
+import net.sf.jsqlparser.parser.CCJSqlParserUtil;
+import net.sf.jsqlparser.statement.Statement;
+import net.sf.jsqlparser.statement.Statements;
+import net.sf.jsqlparser.statement.delete.Delete;
+import net.sf.jsqlparser.statement.insert.Insert;
+import net.sf.jsqlparser.statement.select.Select;
+import net.sf.jsqlparser.statement.update.Update;
+import org.apache.ibatis.logging.Log;
+import org.apache.ibatis.logging.LogFactory;
+
+/**
+ * @author miemie
+ * @since 2020-06-22
+ */
+public abstract class JsqlParserSupport {
+
+    /**
+     * 日志
+     */
+    protected final Log logger = LogFactory.getLog(this.getClass());
+
+    protected String parserSingle(String sql) {
+        if (logger.isDebugEnabled()) {
+            logger.debug("Original SQL: " + sql);
+        }
+        try {
+            Statement statement = CCJSqlParserUtil.parse(sql);
+            return processParser(statement);
+        } catch (JSQLParserException e) {
+            throw ExceptionUtils.mpe("Failed to process, please exclude the tableName or statementId.\n Error SQL: %s", e, sql);
+        }
+    }
+
+    protected String parserMulti(String sql) {
+        if (logger.isDebugEnabled()) {
+            logger.debug("Original SQL: " + sql);
+        }
+        try {
+            // fixed github pull/295
+            StringBuilder sb = new StringBuilder();
+            Statements statements = CCJSqlParserUtil.parseStatements(sql);
+            int i = 0;
+            for (Statement statement : statements.getStatements()) {
+                if (null != statement) {
+                    if (i++ > 0) {
+                        sb.append(StringPool.SEMICOLON);
+                    }
+                    sb.append(processParser(statement));
+                }
+            }
+            return sb.toString();
+        } catch (JSQLParserException e) {
+            throw ExceptionUtils.mpe("Failed to process, please exclude the tableName or statementId.\n Error SQL: %s", e, sql);
+        }
+    }
+
+    /**
+     * 执行 SQL 解析
+     *
+     * @param statement JsqlParser Statement
+     * @return sql
+     */
+    public String processParser(Statement statement) {
+        if (statement instanceof Insert) {
+            this.processInsert((Insert) statement);
+        } else if (statement instanceof Select) {
+            this.processSelect((Select) statement);
+        } else if (statement instanceof Update) {
+            this.processUpdate((Update) statement);
+        } else if (statement instanceof Delete) {
+            this.processDelete((Delete) statement);
+        }
+        final String sql = statement.toString();
+        if (logger.isDebugEnabled()) {
+            logger.debug("parser sql: " + sql);
+        }
+        return sql;
+    }
+
+    /**
+     * 新增
+     */
+    protected void processInsert(Insert insert) {
+        throw new UnsupportedOperationException();
+    }
+
+    /**
+     * 删除
+     */
+    protected void processDelete(Delete delete) {
+        throw new UnsupportedOperationException();
+    }
+
+    /**
+     * 更新
+     */
+    protected void processUpdate(Update update) {
+        throw new UnsupportedOperationException();
+    }
+
+    /**
+     * 查询
+     */
+    protected void processSelect(Select select) {
+        throw new UnsupportedOperationException();
+    }
+}

+ 1 - 1
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/MybatisPlusInterceptor.java

@@ -60,7 +60,7 @@ public class MybatisPlusInterceptor implements Interceptor {
                 }
                 CacheKey cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql);
                 return executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql);
-            } else if (isUpdate && ms.getSqlCommandType() == SqlCommandType.UPDATE) {
+            } else if (isUpdate) {
                 for (QiuQiu query : qiuQius) {
                     query.update(executor, ms, parameter);
                 }

+ 41 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/chain/BlockAttackQiuQiu.java

@@ -0,0 +1,41 @@
+package com.baomidou.mybatisplus.extension.plugins.chain;
+
+import com.baomidou.mybatisplus.core.toolkit.Assert;
+import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
+import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
+import net.sf.jsqlparser.statement.delete.Delete;
+import net.sf.jsqlparser.statement.update.Update;
+import org.apache.ibatis.executor.statement.StatementHandler;
+import org.apache.ibatis.mapping.BoundSql;
+import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.mapping.SqlCommandType;
+
+import java.sql.Connection;
+
+/**
+ * @author miemie
+ * @since 2020-06-22
+ */
+public class BlockAttackQiuQiu extends JsqlParserSupport implements QiuQiu {
+
+    @Override
+    public void prepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
+        PluginUtils.MPStatementHandler handler = PluginUtils.mpStatementHandler(sh);
+        MappedStatement ms = handler.mappedStatement();
+        SqlCommandType sct = ms.getSqlCommandType();
+        if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
+            BoundSql boundSql = handler.boundSql();
+            parserMulti(boundSql.getSql());
+        }
+    }
+
+    @Override
+    protected void processDelete(Delete delete) {
+        Assert.notNull(delete.getWhere(), "Prohibition of full table deletion");
+    }
+
+    @Override
+    protected void processUpdate(Update update) {
+        Assert.notNull(update.getWhere(), "Prohibition of table update operation");
+    }
+}

+ 14 - 54
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/chain/TenantQiuQiu.java

@@ -4,11 +4,11 @@ import com.baomidou.mybatisplus.core.parser.SqlParserHelper;
 import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
 import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
 import com.baomidou.mybatisplus.core.toolkit.StringPool;
+import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
 import com.baomidou.mybatisplus.extension.plugins.tenant.TenantHandler;
 import lombok.Data;
 import lombok.RequiredArgsConstructor;
 import lombok.experimental.Accessors;
-import net.sf.jsqlparser.JSQLParserException;
 import net.sf.jsqlparser.expression.BinaryExpression;
 import net.sf.jsqlparser.expression.Expression;
 import net.sf.jsqlparser.expression.Parenthesis;
@@ -16,19 +16,14 @@ import net.sf.jsqlparser.expression.ValueListExpression;
 import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
 import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
 import net.sf.jsqlparser.expression.operators.relational.*;
-import net.sf.jsqlparser.parser.CCJSqlParserUtil;
 import net.sf.jsqlparser.schema.Column;
 import net.sf.jsqlparser.schema.Table;
-import net.sf.jsqlparser.statement.Statement;
-import net.sf.jsqlparser.statement.Statements;
 import net.sf.jsqlparser.statement.delete.Delete;
 import net.sf.jsqlparser.statement.insert.Insert;
 import net.sf.jsqlparser.statement.select.*;
 import net.sf.jsqlparser.statement.update.Update;
 import org.apache.ibatis.executor.Executor;
 import org.apache.ibatis.executor.statement.StatementHandler;
-import org.apache.ibatis.logging.Log;
-import org.apache.ibatis.logging.LogFactory;
 import org.apache.ibatis.mapping.BoundSql;
 import org.apache.ibatis.mapping.MappedStatement;
 import org.apache.ibatis.mapping.SqlCommandType;
@@ -47,8 +42,7 @@ import java.util.List;
 @Accessors(chain = true)
 @RequiredArgsConstructor
 @SuppressWarnings({"rawtypes"})
-public class TenantQiuQiu implements QiuQiu {
-    protected final Log logger = LogFactory.getLog(this.getClass());
+public class TenantQiuQiu extends JsqlParserSupport implements QiuQiu {
 
     private final TenantHandler tenantHandler;
 
@@ -58,15 +52,7 @@ public class TenantQiuQiu implements QiuQiu {
             return;
         }
         PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
-        String sql = mpBs.sql();
-        try {
-            Select select = (Select) CCJSqlParserUtil.parse(sql);
-            SelectBody selectBody = select.getSelectBody();
-            processSelectBody(selectBody);
-            mpBs.sql(select.toString());
-        } catch (JSQLParserException e) {
-            throw ExceptionUtils.mpe("Failed to process, please exclude the tableName or statementId.\n Error SQL: %s", e, sql);
-        }
+        mpBs.sql(parserSingle(mpBs.sql()));
     }
 
     @Override
@@ -76,47 +62,18 @@ public class TenantQiuQiu implements QiuQiu {
         if (SqlParserHelper.getSqlParserInfo(ms)) {
             return;
         }
-        SqlCommandType sqlCommandType = ms.getSqlCommandType();
-        final boolean insert = sqlCommandType == SqlCommandType.INSERT;
-        final boolean update = sqlCommandType == SqlCommandType.UPDATE;
-        final boolean delete = sqlCommandType == SqlCommandType.DELETE;
-        if (insert || update || delete) {
+        SqlCommandType sct = ms.getSqlCommandType();
+        if (sct == SqlCommandType.INSERT || sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
             PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
-            String sql = mpBs.sql();
-            try {
-                if (logger.isDebugEnabled()) {
-                    logger.debug("Original SQL: " + sql);
-                }
-                // fixed github pull/295
-                StringBuilder sb = new StringBuilder();
-                Statements statements = CCJSqlParserUtil.parseStatements(sql);
-                int i = 0;
-                for (Statement statement : statements.getStatements()) {
-                    if (null != statement) {
-                        if (i++ > 0) {
-                            sb.append(';');
-                        }
-                        if (insert) {
-                            processInsert((Insert) statement);
-                            sb.append(statement.toString());
-                        } else if (update) {
-                            processUpdate((Update) statement);
-                            sb.append(statement.toString());
-                        } else {
-                            processDelete((Delete) statement);
-                            sb.append(statement.toString());
-                        }
-                    }
-                }
-                if (sb.length() > 0) {
-                    mpBs.sql(sb.toString());
-                }
-            } catch (JSQLParserException e) {
-                throw ExceptionUtils.mpe("Failed to process, please exclude the tableName or statementId.\n Error SQL: %s", e, sql);
-            }
+            mpBs.sql(parserMulti(mpBs.sql()));
         }
     }
 
+    @Override
+    protected void processSelect(Select select) {
+        processSelectBody(select.getSelectBody());
+    }
+
     protected void processSelectBody(SelectBody selectBody) {
         if (selectBody instanceof PlainSelect) {
             processPlainSelect((PlainSelect) selectBody);
@@ -133,6 +90,7 @@ public class TenantQiuQiu implements QiuQiu {
         }
     }
 
+    @Override
     protected void processInsert(Insert insert) {
         if (tenantHandler.doTableFilter(insert.getTable().getName())) {
             // 过滤退出执行
@@ -157,6 +115,7 @@ public class TenantQiuQiu implements QiuQiu {
     /**
      * update 语句处理
      */
+    @Override
     protected void processUpdate(Update update) {
         final Table table = update.getTable();
         if (tenantHandler.doTableFilter(table.getName())) {
@@ -169,6 +128,7 @@ public class TenantQiuQiu implements QiuQiu {
     /**
      * delete 语句处理
      */
+    @Override
     protected void processDelete(Delete delete) {
         if (tenantHandler.doTableFilter(delete.getTable().getName())) {
             // 过滤退出执行