瀏覽代碼

to qiuqiu

miemie 5 年之前
父節點
當前提交
e682ffaf0a

+ 15 - 17
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/parser/JsqlParserSupport.java

@@ -24,19 +24,19 @@ public abstract class JsqlParserSupport {
      */
     protected final Log logger = LogFactory.getLog(this.getClass());
 
-    protected String parserSingle(String sql) {
+    protected String parserSingle(String sql, Object obj) {
         if (logger.isDebugEnabled()) {
             logger.debug("Original SQL: " + sql);
         }
         try {
             Statement statement = CCJSqlParserUtil.parse(sql);
-            return processParser(statement);
+            return processParser(statement, 0, obj);
         } catch (JSQLParserException e) {
             throw ExceptionUtils.mpe("Failed to process, please exclude the tableName or statementId.\n Error SQL: %s", e, sql);
         }
     }
 
-    protected String parserMulti(String sql) {
+    protected String parserMulti(String sql, Object obj) {
         if (logger.isDebugEnabled()) {
             logger.debug("Original SQL: " + sql);
         }
@@ -46,12 +46,10 @@ public abstract class JsqlParserSupport {
             Statements statements = CCJSqlParserUtil.parseStatements(sql);
             int i = 0;
             for (Statement statement : statements.getStatements()) {
-                if (null != statement) {
-                    if (i++ > 0) {
-                        sb.append(StringPool.SEMICOLON);
-                    }
-                    sb.append(processParser(statement));
+                if (i++ > 0) {
+                    sb.append(StringPool.SEMICOLON);
                 }
+                sb.append(processParser(statement, i, obj));
             }
             return sb.toString();
         } catch (JSQLParserException e) {
@@ -65,15 +63,15 @@ public abstract class JsqlParserSupport {
      * @param statement JsqlParser Statement
      * @return sql
      */
-    public String processParser(Statement statement) {
+    public String processParser(Statement statement, int index, Object obj) {
         if (statement instanceof Insert) {
-            this.processInsert((Insert) statement);
+            this.processInsert((Insert) statement, index, obj);
         } else if (statement instanceof Select) {
-            this.processSelect((Select) statement);
+            this.processSelect((Select) statement, index, obj);
         } else if (statement instanceof Update) {
-            this.processUpdate((Update) statement);
+            this.processUpdate((Update) statement, index, obj);
         } else if (statement instanceof Delete) {
-            this.processDelete((Delete) statement);
+            this.processDelete((Delete) statement, index, obj);
         }
         final String sql = statement.toString();
         if (logger.isDebugEnabled()) {
@@ -85,28 +83,28 @@ public abstract class JsqlParserSupport {
     /**
      * 新增
      */
-    protected void processInsert(Insert insert) {
+    protected void processInsert(Insert insert, int index, Object obj) {
         throw new UnsupportedOperationException();
     }
 
     /**
      * 删除
      */
-    protected void processDelete(Delete delete) {
+    protected void processDelete(Delete delete, int index, Object obj) {
         throw new UnsupportedOperationException();
     }
 
     /**
      * 更新
      */
-    protected void processUpdate(Update update) {
+    protected void processUpdate(Update update, int index, Object obj) {
         throw new UnsupportedOperationException();
     }
 
     /**
      * 查询
      */
-    protected void processSelect(Select select) {
+    protected void processSelect(Select select, int index, Object obj) {
         throw new UnsupportedOperationException();
     }
 }

+ 3 - 3
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/chain/BlockAttackQiuQiu.java

@@ -25,17 +25,17 @@ public class BlockAttackQiuQiu extends JsqlParserSupport implements QiuQiu {
         SqlCommandType sct = ms.getSqlCommandType();
         if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
             BoundSql boundSql = handler.boundSql();
-            parserMulti(boundSql.getSql());
+            parserMulti(boundSql.getSql(), null);
         }
     }
 
     @Override
-    protected void processDelete(Delete delete) {
+    protected void processDelete(Delete delete, int index, Object obj) {
         Assert.notNull(delete.getWhere(), "Prohibition of full table deletion");
     }
 
     @Override
-    protected void processUpdate(Update update) {
+    protected void processUpdate(Update update, int index, Object obj) {
         Assert.notNull(update.getWhere(), "Prohibition of table update operation");
     }
 }

+ 312 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/chain/IllegalSQLQiuQiu.java

@@ -0,0 +1,312 @@
+package com.baomidou.mybatisplus.extension.plugins.chain;
+
+import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
+import com.baomidou.mybatisplus.core.parser.SqlParserHelper;
+import com.baomidou.mybatisplus.core.toolkit.Assert;
+import com.baomidou.mybatisplus.core.toolkit.EncryptUtils;
+import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
+import com.baomidou.mybatisplus.core.toolkit.StringUtils;
+import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
+import lombok.Data;
+import net.sf.jsqlparser.expression.BinaryExpression;
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.Function;
+import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
+import net.sf.jsqlparser.expression.operators.relational.InExpression;
+import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
+import net.sf.jsqlparser.schema.Column;
+import net.sf.jsqlparser.schema.Table;
+import net.sf.jsqlparser.statement.delete.Delete;
+import net.sf.jsqlparser.statement.select.Join;
+import net.sf.jsqlparser.statement.select.PlainSelect;
+import net.sf.jsqlparser.statement.select.Select;
+import net.sf.jsqlparser.statement.select.SubSelect;
+import net.sf.jsqlparser.statement.update.Update;
+import org.apache.ibatis.executor.statement.StatementHandler;
+import org.apache.ibatis.mapping.BoundSql;
+import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.mapping.SqlCommandType;
+
+import java.sql.Connection;
+import java.sql.DatabaseMetaData;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * @author miemie
+ * @since 2020-06-22
+ */
+public class IllegalSQLQiuQiu extends JsqlParserSupport implements QiuQiu {
+
+    /**
+     * 缓存验证结果,提高性能
+     */
+    private static final Set<String> cacheValidResult = new HashSet<>();
+    /**
+     * 缓存表的索引信息
+     */
+    private static final Map<String, List<IndexInfo>> indexInfoMap = new ConcurrentHashMap<>();
+
+    @Override
+    public void prepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
+        PluginUtils.MPStatementHandler mpStatementHandler = PluginUtils.mpStatementHandler(sh);
+        MappedStatement ms = mpStatementHandler.mappedStatement();
+        SqlCommandType sct = ms.getSqlCommandType();
+        if (sct == SqlCommandType.INSERT || SqlParserHelper.getSqlParserInfo(ms)) {
+            return;
+        }
+        BoundSql boundSql = mpStatementHandler.boundSql();
+        String originalSql = boundSql.getSql();
+        logger.debug("检查SQL是否合规,SQL:" + originalSql);
+        String md5Base64 = EncryptUtils.md5Base64(originalSql);
+        if (cacheValidResult.contains(md5Base64)) {
+            logger.debug("该SQL已验证,无需再次验证,,SQL:" + originalSql);
+            return;
+        }
+        parserSingle(originalSql, connection);
+        //缓存验证结果
+        cacheValidResult.add(md5Base64);
+    }
+
+    @Override
+    protected void processSelect(Select select, int index, Object obj) {
+        PlainSelect plainSelect = (PlainSelect) select.getSelectBody();
+        Expression where = plainSelect.getWhere();
+        Assert.notNull(where, "非法SQL,必须要有where条件");
+        Table table = (Table) plainSelect.getFromItem();
+        List<Join> joins = plainSelect.getJoins();
+        validWhere(where, table, (Connection) obj);
+        validJoins(joins, table, (Connection) obj);
+    }
+
+    @Override
+    protected void processUpdate(Update update, int index, Object obj) {
+        Expression where = update.getWhere();
+        Assert.notNull(where, "非法SQL,必须要有where条件");
+        Table table = update.getTable();
+        List<Join> joins = update.getJoins();
+        validWhere(where, table, (Connection) obj);
+        validJoins(joins, table, (Connection) obj);
+    }
+
+    @Override
+    protected void processDelete(Delete delete, int index, Object obj) {
+        Expression where = delete.getWhere();
+        Assert.notNull(where, "非法SQL,必须要有where条件");
+        Table table = delete.getTable();
+        List<Join> joins = delete.getJoins();
+        validWhere(where, table, (Connection) obj);
+        validJoins(joins, table, (Connection) obj);
+    }
+
+    /**
+     * 验证expression对象是不是 or、not等等
+     *
+     * @param expression ignore
+     */
+    private void validExpression(Expression expression) {
+        //where条件使用了 or 关键字
+        if (expression instanceof OrExpression) {
+            OrExpression orExpression = (OrExpression) expression;
+            throw new MybatisPlusException("非法SQL,where条件中不能使用【or】关键字,错误or信息:" + orExpression.toString());
+        } else if (expression instanceof NotEqualsTo) {
+            NotEqualsTo notEqualsTo = (NotEqualsTo) expression;
+            throw new MybatisPlusException("非法SQL,where条件中不能使用【!=】关键字,错误!=信息:" + notEqualsTo.toString());
+        } else if (expression instanceof BinaryExpression) {
+            BinaryExpression binaryExpression = (BinaryExpression) expression;
+            // TODO 升级 jsqlparser 后待实现
+//            if (binaryExpression.isNot()) {
+//                throw new MybatisPlusException("非法SQL,where条件中不能使用【not】关键字,错误not信息:" + binaryExpression.toString());
+//            }
+            if (binaryExpression.getLeftExpression() instanceof Function) {
+                Function function = (Function) binaryExpression.getLeftExpression();
+                throw new MybatisPlusException("非法SQL,where条件中不能使用数据库函数,错误函数信息:" + function.toString());
+            }
+            if (binaryExpression.getRightExpression() instanceof SubSelect) {
+                SubSelect subSelect = (SubSelect) binaryExpression.getRightExpression();
+                throw new MybatisPlusException("非法SQL,where条件中不能使用子查询,错误子查询SQL信息:" + subSelect.toString());
+            }
+        } else if (expression instanceof InExpression) {
+            InExpression inExpression = (InExpression) expression;
+            if (inExpression.getRightItemsList() instanceof SubSelect) {
+                SubSelect subSelect = (SubSelect) inExpression.getRightItemsList();
+                throw new MybatisPlusException("非法SQL,where条件中不能使用子查询,错误子查询SQL信息:" + subSelect.toString());
+            }
+        }
+
+    }
+
+    /**
+     * 如果SQL用了 left Join,验证是否有or、not等等,并且验证是否使用了索引
+     *
+     * @param joins      ignore
+     * @param table      ignore
+     * @param connection ignore
+     */
+    private void validJoins(List<Join> joins, Table table, Connection connection) {
+        //允许执行join,验证jion是否使用索引等等
+        if (joins != null) {
+            for (Join join : joins) {
+                Table rightTable = (Table) join.getRightItem();
+                Expression expression = join.getOnExpression();
+                validWhere(expression, table, rightTable, connection);
+            }
+        }
+    }
+
+    /**
+     * 检查是否使用索引
+     *
+     * @param table      ignore
+     * @param columnName ignore
+     * @param connection ignore
+     */
+    private void validUseIndex(Table table, String columnName, Connection connection) {
+        //是否使用索引
+        boolean useIndexFlag = false;
+
+        String tableInfo = table.getName();
+        //表存在的索引
+        String dbName = null;
+        String tableName;
+        String[] tableArray = tableInfo.split("\\.");
+        if (tableArray.length == 1) {
+            tableName = tableArray[0];
+        } else {
+            dbName = tableArray[0];
+            tableName = tableArray[1];
+        }
+        List<IndexInfo> indexInfos = getIndexInfos(dbName, tableName, connection);
+        for (IndexInfo indexInfo : indexInfos) {
+            if (null != columnName && columnName.equalsIgnoreCase(indexInfo.getColumnName())) {
+                useIndexFlag = true;
+                break;
+            }
+        }
+        if (!useIndexFlag) {
+            throw new MybatisPlusException("非法SQL,SQL未使用到索引, table:" + table + ", columnName:" + columnName);
+        }
+    }
+
+    /**
+     * 验证where条件的字段,是否有not、or等等,并且where的第一个字段,必须使用索引
+     *
+     * @param expression ignore
+     * @param table      ignore
+     * @param connection ignore
+     */
+    private void validWhere(Expression expression, Table table, Connection connection) {
+        validWhere(expression, table, null, connection);
+    }
+
+    /**
+     * 验证where条件的字段,是否有not、or等等,并且where的第一个字段,必须使用索引
+     *
+     * @param expression ignore
+     * @param table      ignore
+     * @param joinTable  ignore
+     * @param connection ignore
+     */
+    private void validWhere(Expression expression, Table table, Table joinTable, Connection connection) {
+        validExpression(expression);
+        if (expression instanceof BinaryExpression) {
+            //获得左边表达式
+            Expression leftExpression = ((BinaryExpression) expression).getLeftExpression();
+            validExpression(leftExpression);
+
+            //如果左边表达式为Column对象,则直接获得列名
+            if (leftExpression instanceof Column) {
+                Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
+                if (joinTable != null && rightExpression instanceof Column) {
+                    if (Objects.equals(((Column) rightExpression).getTable().getName(), table.getAlias().getName())) {
+                        validUseIndex(table, ((Column) rightExpression).getColumnName(), connection);
+                        validUseIndex(joinTable, ((Column) leftExpression).getColumnName(), connection);
+                    } else {
+                        validUseIndex(joinTable, ((Column) rightExpression).getColumnName(), connection);
+                        validUseIndex(table, ((Column) leftExpression).getColumnName(), connection);
+                    }
+                } else {
+                    //获得列名
+                    validUseIndex(table, ((Column) leftExpression).getColumnName(), connection);
+                }
+            }
+            //如果BinaryExpression,进行迭代
+            else if (leftExpression instanceof BinaryExpression) {
+                validWhere(leftExpression, table, joinTable, connection);
+            }
+
+            //获得右边表达式,并分解
+            Expression rightExpression = ((BinaryExpression) expression).getRightExpression();
+            validExpression(rightExpression);
+        }
+    }
+
+    /**
+     * 得到表的索引信息
+     *
+     * @param dbName    ignore
+     * @param tableName ignore
+     * @param conn      ignore
+     * @return ignore
+     */
+    public List<IndexInfo> getIndexInfos(String dbName, String tableName, Connection conn) {
+        return getIndexInfos(null, dbName, tableName, conn);
+    }
+
+    /**
+     * 得到表的索引信息
+     *
+     * @param key       ignore
+     * @param dbName    ignore
+     * @param tableName ignore
+     * @param conn      ignore
+     * @return ignore
+     */
+    public List<IndexInfo> getIndexInfos(String key, String dbName, String tableName, Connection conn) {
+        List<IndexInfo> indexInfos = null;
+        if (StringUtils.isNotBlank(key)) {
+            indexInfos = indexInfoMap.get(key);
+        }
+        if (indexInfos == null || indexInfos.isEmpty()) {
+            ResultSet rs;
+            try {
+                DatabaseMetaData metadata = conn.getMetaData();
+                String catalog = StringUtils.isBlank(dbName) ? conn.getCatalog() : dbName;
+                String schema = StringUtils.isBlank(dbName) ? conn.getSchema() : dbName;
+                rs = metadata.getIndexInfo(catalog, schema, tableName, false, true);
+                indexInfos = new ArrayList<>();
+                while (rs.next()) {
+                    //索引中的列序列号等于1,才有效
+                    if (Objects.equals(rs.getString(8), "1")) {
+                        IndexInfo indexInfo = new IndexInfo();
+                        indexInfo.setDbName(rs.getString(1));
+                        indexInfo.setTableName(rs.getString(3));
+                        indexInfo.setColumnName(rs.getString(9));
+                        indexInfos.add(indexInfo);
+                    }
+                }
+                if (StringUtils.isNotBlank(key)) {
+                    indexInfoMap.put(key, indexInfos);
+                }
+            } catch (SQLException e) {
+                e.printStackTrace();
+            }
+        }
+        return indexInfos;
+    }
+
+    /**
+     * 索引对象
+     */
+    @Data
+    private static class IndexInfo {
+
+        private String dbName;
+
+        private String tableName;
+
+        private String columnName;
+    }
+}

+ 6 - 6
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/chain/TenantQiuQiu.java

@@ -52,7 +52,7 @@ public class TenantQiuQiu extends JsqlParserSupport implements QiuQiu {
             return;
         }
         PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
-        mpBs.sql(parserSingle(mpBs.sql()));
+        mpBs.sql(parserSingle(mpBs.sql(), null));
     }
 
     @Override
@@ -65,12 +65,12 @@ public class TenantQiuQiu extends JsqlParserSupport implements QiuQiu {
         SqlCommandType sct = ms.getSqlCommandType();
         if (sct == SqlCommandType.INSERT || sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
             PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
-            mpBs.sql(parserMulti(mpBs.sql()));
+            mpBs.sql(parserMulti(mpBs.sql(), null));
         }
     }
 
     @Override
-    protected void processSelect(Select select) {
+    protected void processSelect(Select select, int index, Object obj) {
         processSelectBody(select.getSelectBody());
     }
 
@@ -91,7 +91,7 @@ public class TenantQiuQiu extends JsqlParserSupport implements QiuQiu {
     }
 
     @Override
-    protected void processInsert(Insert insert) {
+    protected void processInsert(Insert insert, int index, Object obj) {
         if (tenantHandler.doTableFilter(insert.getTable().getName())) {
             // 过滤退出执行
             return;
@@ -116,7 +116,7 @@ public class TenantQiuQiu extends JsqlParserSupport implements QiuQiu {
      * update 语句处理
      */
     @Override
-    protected void processUpdate(Update update) {
+    protected void processUpdate(Update update, int index, Object obj) {
         final Table table = update.getTable();
         if (tenantHandler.doTableFilter(table.getName())) {
             // 过滤退出执行
@@ -129,7 +129,7 @@ public class TenantQiuQiu extends JsqlParserSupport implements QiuQiu {
      * delete 语句处理
      */
     @Override
-    protected void processDelete(Delete delete) {
+    protected void processDelete(Delete delete, int index, Object obj) {
         if (tenantHandler.doTableFilter(delete.getTable().getName())) {
             // 过滤退出执行
             return;