瀏覽代碼

fix: #1186 github: select from a join b: tenant_id in table b
support: tenant_id in (1,2)

yuxiaobin 5 年之前
父節點
當前提交
c7be69175b

+ 3 - 1
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/tenant/TenantHandler.java

@@ -27,6 +27,8 @@ public interface TenantHandler {
 
     /**
      * 获取租户值
+     * <p>
+     * 支持自定义表达式,比如:tenant_id in (1,2) @since 2019-8-2
      *
      * @return 租户值
      */
@@ -43,7 +45,7 @@ public interface TenantHandler {
      * 根据表名判断是否进行过滤
      *
      * @param tableName 表名
-     * @return 是否进行过滤
+     * @return 是否进行过滤, true:表示忽略,false:需要解析多租户字段
      */
     boolean doTableFilter(String tableName);
 }

+ 73 - 40
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/tenant/TenantSqlParser.java

@@ -15,10 +15,13 @@
  */
 package com.baomidou.mybatisplus.extension.plugins.tenant;
 
+import java.util.List;
+
 import com.baomidou.mybatisplus.core.parser.AbstractJsqlParser;
 import com.baomidou.mybatisplus.core.toolkit.Assert;
 import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
 import com.baomidou.mybatisplus.core.toolkit.StringPool;
+
 import lombok.Data;
 import lombok.EqualsAndHashCode;
 import lombok.experimental.Accessors;
@@ -27,16 +30,29 @@ import net.sf.jsqlparser.expression.Expression;
 import net.sf.jsqlparser.expression.Parenthesis;
 import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
 import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
-import net.sf.jsqlparser.expression.operators.relational.*;
+import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
+import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
+import net.sf.jsqlparser.expression.operators.relational.InExpression;
+import net.sf.jsqlparser.expression.operators.relational.ItemsList;
+import net.sf.jsqlparser.expression.operators.relational.MultiExpressionList;
+import net.sf.jsqlparser.expression.operators.relational.SupportsOldOracleJoinSyntax;
 import net.sf.jsqlparser.schema.Column;
 import net.sf.jsqlparser.schema.Table;
 import net.sf.jsqlparser.statement.delete.Delete;
 import net.sf.jsqlparser.statement.insert.Insert;
-import net.sf.jsqlparser.statement.select.*;
+import net.sf.jsqlparser.statement.select.FromItem;
+import net.sf.jsqlparser.statement.select.Join;
+import net.sf.jsqlparser.statement.select.LateralSubSelect;
+import net.sf.jsqlparser.statement.select.PlainSelect;
+import net.sf.jsqlparser.statement.select.SelectBody;
+import net.sf.jsqlparser.statement.select.SelectExpressionItem;
+import net.sf.jsqlparser.statement.select.SetOperationList;
+import net.sf.jsqlparser.statement.select.SubJoin;
+import net.sf.jsqlparser.statement.select.SubSelect;
+import net.sf.jsqlparser.statement.select.ValuesList;
+import net.sf.jsqlparser.statement.select.WithItem;
 import net.sf.jsqlparser.statement.update.Update;
 
-import java.util.List;
-
 /**
  * 租户 SQL 解析器( TenantId 行级 )
  *
@@ -158,13 +174,11 @@ public class TenantSqlParser extends AbstractJsqlParser {
         FromItem fromItem = plainSelect.getFromItem();
         if (fromItem instanceof Table) {
             Table fromTable = (Table) fromItem;
-            if (tenantHandler.doTableFilter(fromTable.getName())) {
-                // 过滤退出执行
-                return;
-            }
-            plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
-            if (addColumn) {
-                plainSelect.getSelectItems().add(new SelectExpressionItem(new Column(tenantHandler.getTenantIdColumn())));
+            if (!this.getTenantHandler().doTableFilter(fromTable.getName())) {//#1186 github
+                plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
+                if (addColumn) {
+                    plainSelect.getSelectItems().add(new SelectExpressionItem(new Column(this.getTenantHandler().getTenantIdColumn())));
+                }
             }
         } else {
             processFromItem(fromItem);
@@ -223,40 +237,59 @@ public class TenantSqlParser extends AbstractJsqlParser {
     }
 
     /**
-     * 处理条件
+     * 处理条件:
+     * 支持 getTenantHandler().getTenantId()是一个完整的表达式:tenant in (1,2)
+     * 默认tenantId的表达式: LongValue(1)这种依旧支持
      */
-    protected Expression builderExpression(Expression expression, Table table) {
-        //生成字段名
-        EqualsTo equalsTo = new EqualsTo();
-        equalsTo.setLeftExpression(this.getAliasColumn(table));
-        equalsTo.setRightExpression(tenantHandler.getTenantId());
-        //加入判断防止条件为空时生成 "and null" 导致查询结果为空
-        if (expression == null) {
-            return equalsTo;
+    protected Expression builderExpression(Expression currentExpression, Table table) {
+        final Expression tenantExpression = this.getTenantHandler().getTenantId();
+        Expression appendExpression;
+        if (!(tenantExpression instanceof SupportsOldOracleJoinSyntax)) {
+            appendExpression = new EqualsTo();
+            ((EqualsTo) appendExpression).setLeftExpression(this.getAliasColumn(table));
+            ((EqualsTo) appendExpression).setRightExpression(tenantExpression);
         } else {
-            if (expression instanceof BinaryExpression) {
-                BinaryExpression binaryExpression = (BinaryExpression) expression;
-                if (binaryExpression.getLeftExpression() instanceof FromItem) {
-                    processFromItem((FromItem) binaryExpression.getLeftExpression());
-                }
-                if (binaryExpression.getRightExpression() instanceof FromItem) {
-                    processFromItem((FromItem) binaryExpression.getRightExpression());
-                }
+            appendExpression = processTableAlias4CustomizedTenantIdExpression(tenantExpression, table);
+        }
+        if (currentExpression == null) {
+            return appendExpression;
+        }
+        if (currentExpression instanceof BinaryExpression) {
+            BinaryExpression binaryExpression = (BinaryExpression) currentExpression;
+            if (binaryExpression.getLeftExpression() instanceof FromItem) {
+                processFromItem((FromItem) binaryExpression.getLeftExpression());
             }
-            if (expression instanceof OrExpression) {
-                return new AndExpression(equalsTo, new Parenthesis(expression));
-            } else {
-                // fix github 1201
-                if (expression instanceof InExpression) {
-                    InExpression inExp = (InExpression) expression;
-                    ItemsList rightItems = inExp.getRightItemsList();
-                    if (rightItems instanceof SubSelect) {
-                        processSelectBody(((SubSelect) rightItems).getSelectBody());
-                    }
-                }
-                return new AndExpression(equalsTo, expression);
+            if (binaryExpression.getRightExpression() instanceof FromItem) {
+                processFromItem((FromItem) binaryExpression.getRightExpression());
+            }
+        } else if (currentExpression instanceof InExpression) {
+            InExpression inExp = (InExpression) currentExpression;
+            ItemsList rightItems = inExp.getRightItemsList();
+            if (rightItems instanceof SubSelect) {
+                processSelectBody(((SubSelect) rightItems).getSelectBody());
             }
         }
+        if (currentExpression instanceof OrExpression) {
+            return new AndExpression(new Parenthesis(currentExpression), appendExpression);
+        } else {
+            return new AndExpression(currentExpression, appendExpression);
+        }
+    }
+
+    /**
+     * 目前: 针对自定义的tenantId的条件表达式[tenant_id in (1,2,3)],无法处理多租户的字段加上表别名
+     * select a.id, b.name
+     * from a
+     * join b on b.aid = a.id and [b.]tenant_id in (1,2) --别名[b.]无法加上 TODO
+     *
+     * @param expression
+     * @param table
+     * @return 加上别名的多租户字段表达式
+     */
+    protected Expression processTableAlias4CustomizedTenantIdExpression(Expression expression, Table table) {
+        //cannot add table alias for customized tenantId expression,
+        // when tables including tenantId at the join table poistion
+        return expression;
     }
 
     /**