浏览代码

to qiuqiu

miemie 5 年之前
父节点
当前提交
14725a55e6

+ 13 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/parser/SqlParserHelper.java

@@ -75,6 +75,19 @@ public class SqlParserHelper {
         }
     }
 
+    /**
+     * 获取 SqlParser 注解信息
+     */
+    public static boolean getSqlParserInfo(MappedStatement ms) {
+        String id = ms.getId();
+        Boolean value = SQL_PARSER_INFO_CACHE.get(id);
+        if (value != null) {
+            return value;
+        }
+        String mapperName = id.substring(0, id.lastIndexOf(StringPool.DOT));
+        return SQL_PARSER_INFO_CACHE.getOrDefault(mapperName, false);
+    }
+
     /**
      * 获取 SqlParser 注解信息
      *

+ 311 - 3
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/chain/TenantQiuQiu.java

@@ -1,24 +1,43 @@
 package com.baomidou.mybatisplus.extension.plugins.chain;
 
+import com.baomidou.mybatisplus.core.parser.SqlParserHelper;
 import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
 import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
+import com.baomidou.mybatisplus.core.toolkit.StringPool;
 import com.baomidou.mybatisplus.extension.plugins.tenant.TenantHandler;
 import lombok.Data;
+import lombok.RequiredArgsConstructor;
 import lombok.experimental.Accessors;
 import net.sf.jsqlparser.JSQLParserException;
+import net.sf.jsqlparser.expression.BinaryExpression;
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.Parenthesis;
+import net.sf.jsqlparser.expression.ValueListExpression;
+import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
+import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
+import net.sf.jsqlparser.expression.operators.relational.*;
 import net.sf.jsqlparser.parser.CCJSqlParserUtil;
-import net.sf.jsqlparser.statement.select.Select;
+import net.sf.jsqlparser.schema.Column;
+import net.sf.jsqlparser.schema.Table;
+import net.sf.jsqlparser.statement.Statement;
+import net.sf.jsqlparser.statement.Statements;
+import net.sf.jsqlparser.statement.delete.Delete;
+import net.sf.jsqlparser.statement.insert.Insert;
+import net.sf.jsqlparser.statement.select.*;
+import net.sf.jsqlparser.statement.update.Update;
 import org.apache.ibatis.executor.Executor;
 import org.apache.ibatis.executor.statement.StatementHandler;
 import org.apache.ibatis.logging.Log;
 import org.apache.ibatis.logging.LogFactory;
 import org.apache.ibatis.mapping.BoundSql;
 import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.mapping.SqlCommandType;
 import org.apache.ibatis.session.ResultHandler;
 import org.apache.ibatis.session.RowBounds;
 
 import java.sql.Connection;
 import java.sql.SQLException;
+import java.util.List;
 
 /**
  * @author miemie
@@ -26,18 +45,25 @@ import java.sql.SQLException;
  */
 @Data
 @Accessors(chain = true)
+@RequiredArgsConstructor
 @SuppressWarnings({"rawtypes"})
 public class TenantQiuQiu implements QiuQiu {
     protected final Log logger = LogFactory.getLog(this.getClass());
 
-    private TenantHandler tenantHandler;
+    private final TenantHandler tenantHandler;
 
     @Override
     public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
+        if (SqlParserHelper.getSqlParserInfo(ms)) {
+            return;
+        }
         PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
         String sql = mpBs.sql();
         try {
             Select select = (Select) CCJSqlParserUtil.parse(sql);
+            SelectBody selectBody = select.getSelectBody();
+            processSelectBody(selectBody);
+            mpBs.sql(select.toString());
         } catch (JSQLParserException e) {
             throw ExceptionUtils.mpe("Failed to process, please exclude the tableName or statementId.\n Error SQL: %s", e, sql);
         }
@@ -47,6 +73,288 @@ public class TenantQiuQiu implements QiuQiu {
     public void prepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
         PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
         MappedStatement ms = mpSh.mappedStatement();
-        PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
+        if (SqlParserHelper.getSqlParserInfo(ms)) {
+            return;
+        }
+        SqlCommandType sqlCommandType = ms.getSqlCommandType();
+        final boolean insert = sqlCommandType == SqlCommandType.INSERT;
+        final boolean update = sqlCommandType == SqlCommandType.UPDATE;
+        final boolean delete = sqlCommandType == SqlCommandType.DELETE;
+        if (insert || update || delete) {
+            PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
+            String sql = mpBs.sql();
+            try {
+                if (logger.isDebugEnabled()) {
+                    logger.debug("Original SQL: " + sql);
+                }
+                // fixed github pull/295
+                StringBuilder sb = new StringBuilder();
+                Statements statements = CCJSqlParserUtil.parseStatements(sql);
+                int i = 0;
+                for (Statement statement : statements.getStatements()) {
+                    if (null != statement) {
+                        if (i++ > 0) {
+                            sb.append(';');
+                        }
+                        if (insert) {
+                            processInsert((Insert) statement);
+                            sb.append(statement.toString());
+                        } else if (update) {
+                            processUpdate((Update) statement);
+                            sb.append(statement.toString());
+                        } else {
+                            processDelete((Delete) statement);
+                            sb.append(statement.toString());
+                        }
+                    }
+                }
+                if (sb.length() > 0) {
+                    mpBs.sql(sb.toString());
+                }
+            } catch (JSQLParserException e) {
+                throw ExceptionUtils.mpe("Failed to process, please exclude the tableName or statementId.\n Error SQL: %s", e, sql);
+            }
+        }
+    }
+
+    protected void processSelectBody(SelectBody selectBody) {
+        if (selectBody instanceof PlainSelect) {
+            processPlainSelect((PlainSelect) selectBody);
+        } else if (selectBody instanceof WithItem) {
+            WithItem withItem = (WithItem) selectBody;
+            if (withItem.getSelectBody() != null) {
+                processSelectBody(withItem.getSelectBody());
+            }
+        } else {
+            SetOperationList operationList = (SetOperationList) selectBody;
+            if (operationList.getSelects() != null && operationList.getSelects().size() > 0) {
+                operationList.getSelects().forEach(this::processSelectBody);
+            }
+        }
+    }
+
+    protected void processInsert(Insert insert) {
+        if (tenantHandler.doTableFilter(insert.getTable().getName())) {
+            // 过滤退出执行
+            return;
+        }
+        insert.getColumns().add(new Column(tenantHandler.getTenantIdColumn()));
+        if (insert.getSelect() != null) {
+            processPlainSelect((PlainSelect) insert.getSelect().getSelectBody(), true);
+        } else if (insert.getItemsList() != null) {
+            // fixed github pull/295
+            ItemsList itemsList = insert.getItemsList();
+            if (itemsList instanceof MultiExpressionList) {
+                ((MultiExpressionList) itemsList).getExprList().forEach(el -> el.getExpressions().add(tenantHandler.getTenantId(false)));
+            } else {
+                ((ExpressionList) insert.getItemsList()).getExpressions().add(tenantHandler.getTenantId(false));
+            }
+        } else {
+            throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId");
+        }
+    }
+
+    /**
+     * update 语句处理
+     */
+    protected void processUpdate(Update update) {
+        final Table table = update.getTable();
+        if (tenantHandler.doTableFilter(table.getName())) {
+            // 过滤退出执行
+            return;
+        }
+        update.setWhere(this.andExpression(table, update.getWhere()));
+    }
+
+    /**
+     * delete 语句处理
+     */
+    protected void processDelete(Delete delete) {
+        if (tenantHandler.doTableFilter(delete.getTable().getName())) {
+            // 过滤退出执行
+            return;
+        }
+        delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
+    }
+
+    /**
+     * delete update 语句 where 处理
+     */
+    protected BinaryExpression andExpression(Table table, Expression where) {
+        //获得where条件表达式
+        EqualsTo equalsTo = new EqualsTo();
+        equalsTo.setLeftExpression(this.getAliasColumn(table));
+        equalsTo.setRightExpression(tenantHandler.getTenantId(false));
+        if (null != where) {
+            if (where instanceof OrExpression) {
+                return new AndExpression(equalsTo, new Parenthesis(where));
+            } else {
+                return new AndExpression(equalsTo, where);
+            }
+        }
+        return equalsTo;
+    }
+
+    /**
+     * 处理 PlainSelect
+     */
+    protected void processPlainSelect(PlainSelect plainSelect) {
+        processPlainSelect(plainSelect, false);
+    }
+
+    /**
+     * 处理 PlainSelect
+     *
+     * @param plainSelect ignore
+     * @param addColumn   是否添加租户列,insert into select语句中需要
+     */
+    protected void processPlainSelect(PlainSelect plainSelect, boolean addColumn) {
+        FromItem fromItem = plainSelect.getFromItem();
+        if (fromItem instanceof Table) {
+            Table fromTable = (Table) fromItem;
+            if (!tenantHandler.doTableFilter(fromTable.getName())) {
+                //#1186 github
+                plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
+                if (addColumn) {
+                    plainSelect.getSelectItems().add(new SelectExpressionItem(new Column(tenantHandler.getTenantIdColumn())));
+                }
+            }
+        } else {
+            processFromItem(fromItem);
+        }
+        List<Join> joins = plainSelect.getJoins();
+        if (joins != null && joins.size() > 0) {
+            joins.forEach(j -> {
+                processJoin(j);
+                processFromItem(j.getRightItem());
+            });
+        }
+    }
+
+    /**
+     * 处理子查询等
+     */
+    protected void processFromItem(FromItem fromItem) {
+        if (fromItem instanceof SubJoin) {
+            SubJoin subJoin = (SubJoin) fromItem;
+            if (subJoin.getJoinList() != null) {
+                subJoin.getJoinList().forEach(this::processJoin);
+            }
+            if (subJoin.getLeft() != null) {
+                processFromItem(subJoin.getLeft());
+            }
+        } else if (fromItem instanceof SubSelect) {
+            SubSelect subSelect = (SubSelect) fromItem;
+            if (subSelect.getSelectBody() != null) {
+                processSelectBody(subSelect.getSelectBody());
+            }
+        } else if (fromItem instanceof ValuesList) {
+            logger.debug("Perform a subquery, if you do not give us feedback");
+        } else if (fromItem instanceof LateralSubSelect) {
+            LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
+            if (lateralSubSelect.getSubSelect() != null) {
+                SubSelect subSelect = lateralSubSelect.getSubSelect();
+                if (subSelect.getSelectBody() != null) {
+                    processSelectBody(subSelect.getSelectBody());
+                }
+            }
+        }
+    }
+
+    /**
+     * 处理联接语句
+     */
+    protected void processJoin(Join join) {
+        if (join.getRightItem() instanceof Table) {
+            Table fromTable = (Table) join.getRightItem();
+            if (this.tenantHandler.doTableFilter(fromTable.getName())) {
+                // 过滤退出执行
+                return;
+            }
+            join.setOnExpression(builderExpression(join.getOnExpression(), fromTable));
+        }
+    }
+
+    /**
+     * 处理条件:
+     * 支持 getTenantHandler().getTenantId()是一个完整的表达式:tenant in (1,2)
+     * 默认tenantId的表达式: LongValue(1)这种依旧支持
+     */
+    protected Expression builderExpression(Expression currentExpression, Table table) {
+        final Expression tenantExpression = tenantHandler.getTenantId(true);
+        Expression appendExpression = this.processTableAlias4CustomizedTenantIdExpression(tenantExpression, table);
+        if (currentExpression == null) {
+            return appendExpression;
+        }
+        if (currentExpression instanceof BinaryExpression) {
+            BinaryExpression binaryExpression = (BinaryExpression) currentExpression;
+            doExpression(binaryExpression.getLeftExpression());
+            doExpression(binaryExpression.getRightExpression());
+        } else if (currentExpression instanceof InExpression) {
+            InExpression inExp = (InExpression) currentExpression;
+            ItemsList rightItems = inExp.getRightItemsList();
+            if (rightItems instanceof SubSelect) {
+                processSelectBody(((SubSelect) rightItems).getSelectBody());
+            }
+        }
+        if (currentExpression instanceof OrExpression) {
+            return new AndExpression(new Parenthesis(currentExpression), appendExpression);
+        } else {
+            return new AndExpression(currentExpression, appendExpression);
+        }
+    }
+
+    protected void doExpression(Expression expression) {
+        if (expression instanceof FromItem) {
+            processFromItem((FromItem) expression);
+        } else if (expression instanceof InExpression) {
+            InExpression inExp = (InExpression) expression;
+            ItemsList rightItems = inExp.getRightItemsList();
+            if (rightItems instanceof SubSelect) {
+                processSelectBody(((SubSelect) rightItems).getSelectBody());
+            }
+        }
+    }
+
+    /**
+     * 目前: 针对自定义的tenantId的条件表达式[tenant_id in (1,2,3)],无法处理多租户的字段加上表别名
+     * select a.id, b.name
+     * from a
+     * join b on b.aid = a.id and [b.]tenant_id in (1,2) --别名[b.]无法加上 TODO
+     *
+     * @param expression
+     * @param table
+     * @return 加上别名的多租户字段表达式
+     */
+    protected Expression processTableAlias4CustomizedTenantIdExpression(Expression expression, Table table) {
+        Expression target;
+        if (expression instanceof ValueListExpression) {
+            InExpression inExpression = new InExpression();
+            inExpression.setLeftExpression(this.getAliasColumn(table));
+            inExpression.setRightItemsList(((ValueListExpression) expression).getExpressionList());
+            target = inExpression;
+        } else {
+            EqualsTo equalsTo = new EqualsTo();
+            equalsTo.setLeftExpression(this.getAliasColumn(table));
+            equalsTo.setRightExpression(expression);
+            target = equalsTo;
+        }
+        return target;
+    }
+
+    /**
+     * 租户字段别名设置
+     * <p>tenantId 或 tableAlias.tenantId</p>
+     *
+     * @param table 表对象
+     * @return 字段
+     */
+    protected Column getAliasColumn(Table table) {
+        StringBuilder column = new StringBuilder();
+        if (table.getAlias() != null) {
+            column.append(table.getAlias().getName()).append(StringPool.DOT);
+        }
+        column.append(tenantHandler.getTenantIdColumn());
+        return new Column(column.toString());
     }
 }