ソースを参照

多租户很麻烦啊

miemie 5 年 前
コミット
f3c875c229

+ 1 - 1
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/tenant/TenantHandler.java

@@ -30,7 +30,7 @@ public interface TenantHandler {
      * <p>
      * 支持自定义表达式,比如:tenant_id in (1,2) @since 2019-8-2
      *
-     * @param where 参数 true 表示为 where 条件 false 表示为 insert 或者 select 条件
+     * @param where 参数 true 表示为 select 下的 where 条件,false 表示 insert/update/delete 下的条件
      * @return 租户 ID 值表达式
      */
     Expression getTenantId(boolean where);

+ 17 - 13
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/tenant/TenantSqlParser.java

@@ -26,6 +26,7 @@ import lombok.experimental.Accessors;
 import net.sf.jsqlparser.expression.BinaryExpression;
 import net.sf.jsqlparser.expression.Expression;
 import net.sf.jsqlparser.expression.Parenthesis;
+import net.sf.jsqlparser.expression.ValueListExpression;
 import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
 import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
 import net.sf.jsqlparser.expression.operators.relational.*;
@@ -130,7 +131,7 @@ public class TenantSqlParser extends AbstractJsqlParser {
         //获得where条件表达式
         EqualsTo equalsTo = new EqualsTo();
         equalsTo.setLeftExpression(this.getAliasColumn(table));
-        equalsTo.setRightExpression(tenantHandler.getTenantId(true));
+        equalsTo.setRightExpression(tenantHandler.getTenantId(false));
         if (null != where) {
             if (where instanceof OrExpression) {
                 return new AndExpression(equalsTo, new Parenthesis(where));
@@ -227,15 +228,8 @@ public class TenantSqlParser extends AbstractJsqlParser {
      * 默认tenantId的表达式: LongValue(1)这种依旧支持
      */
     protected Expression builderExpression(Expression currentExpression, Table table) {
-        final Expression tenantExpression = tenantHandler.getTenantId(false);
-        Expression appendExpression;
-        if (!(tenantExpression instanceof SupportsOldOracleJoinSyntax)) {
-            appendExpression = new EqualsTo();
-            ((EqualsTo) appendExpression).setLeftExpression(this.getAliasColumn(table));
-            ((EqualsTo) appendExpression).setRightExpression(tenantExpression);
-        } else {
-            appendExpression = processTableAlias4CustomizedTenantIdExpression(tenantExpression, table);
-        }
+        final Expression tenantExpression = tenantHandler.getTenantId(true);
+        Expression appendExpression = this.processTableAlias4CustomizedTenantIdExpression(tenantExpression, table);
         if (currentExpression == null) {
             return appendExpression;
         }
@@ -280,9 +274,19 @@ public class TenantSqlParser extends AbstractJsqlParser {
      * @return 加上别名的多租户字段表达式
      */
     protected Expression processTableAlias4CustomizedTenantIdExpression(Expression expression, Table table) {
-        //cannot add table alias for customized tenantId expression,
-        // when tables including tenantId at the join table poistion
-        return expression;
+        Expression target;
+        if (expression instanceof ValueListExpression) {
+            InExpression inExpression = new InExpression();
+            inExpression.setLeftExpression(this.getAliasColumn(table));
+            inExpression.setRightItemsList(((ValueListExpression) expression).getExpressionList());
+            target = inExpression;
+        } else {
+            EqualsTo equalsTo = new EqualsTo();
+            equalsTo.setLeftExpression(this.getAliasColumn(table));
+            equalsTo.setRightExpression(expression);
+            target = equalsTo;
+        }
+        return target;
     }
 
     /**

+ 10 - 12
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/tenant/TenantSqlParserTest.java

@@ -3,10 +3,9 @@ package com.baomidou.mybatisplus.extension.plugins.tenant;
 import net.sf.jsqlparser.JSQLParserException;
 import net.sf.jsqlparser.expression.Expression;
 import net.sf.jsqlparser.expression.LongValue;
+import net.sf.jsqlparser.expression.ValueListExpression;
 import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
-import net.sf.jsqlparser.expression.operators.relational.InExpression;
 import net.sf.jsqlparser.parser.CCJSqlParserUtil;
-import net.sf.jsqlparser.schema.Column;
 import net.sf.jsqlparser.statement.Statements;
 import net.sf.jsqlparser.statement.select.Select;
 import net.sf.jsqlparser.statement.update.Update;
@@ -28,11 +27,10 @@ public class TenantSqlParserTest {
                 if (!where) {
                     return new LongValue(1);
                 }
-                final InExpression inExpression = new InExpression();
-                inExpression.setLeftExpression(new Column(getTenantIdColumn()));
-                final ExpressionList itemsList = new ExpressionList(new LongValue(1), new LongValue(2));
-                inExpression.setRightItemsList(itemsList);
-                return inExpression;
+                ValueListExpression expression = new ValueListExpression();
+                ExpressionList list = new ExpressionList(new LongValue(1), new LongValue(2));
+                expression.setExpressionList(list);
+                return expression;
             }
 
             @Override
@@ -49,15 +47,15 @@ public class TenantSqlParserTest {
     @Test
     public void processSelectBody() throws JSQLParserException {
         select("select * from user",
-            "select * from user where t_id = 1");
+            "select * from user where t_id in (1, 2)");
         select("select * from user u",
-            "select * from user u where u.t_id = 1");
+            "select * from user u where u.t_id in (1, 2)");
         select("select * from user where id in (select id from user)",
-            "select * from user where id in (select id from user where t_id = 1) and t_id = 1");
+            "select * from user where id in (select id from user where t_id in (1, 2)) and t_id in (1, 2)");
         select("select * from user where id = 1 and id in (select id from user)",
-            "select * from user where id = 1 and id in (select id from user where t_id = 1) and t_id = 1");
+            "select * from user where id = 1 and id in (select id from user where t_id in (1, 2)) and t_id in (1, 2)");
         select("select * from user where id = 1 or id in (select id from user)",
-            "select * from user where (id = 1 or id in (select id from user where t_id = 1)) and t_id = 1");
+            "select * from user where (id = 1 or id in (select id from user where t_id in (1, 2))) and t_id in (1, 2)");
 
         update("update user set age = 1",
             "update user set age = 1 where t_id = 1");