Browse Source

多租户ID 值表达式,支持多个 ID 条件查询

hubin 5 years ago
parent
commit
059c622402

+ 4 - 3
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/tenant/TenantHandler.java

@@ -26,13 +26,14 @@ import net.sf.jsqlparser.expression.Expression;
 public interface TenantHandler {
 
     /**
-     * 获取租户值
+     * 获取租户 ID 表达式,支持多个 ID 条件查询
      * <p>
      * 支持自定义表达式,比如:tenant_id in (1,2) @since 2019-8-2
      *
-     * @return 租户值
+     * @param where 参数 true 表示为 where 条件 false 表示为 insert 或者 select 条件
+     * @return 租户 ID 值表达式
      */
-    Expression getTenantId();
+    Expression getTenantId(boolean where);
 
     /**
      * 获取租户字段名

+ 13 - 26
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/tenant/TenantSqlParser.java

@@ -15,13 +15,10 @@
  */
 package com.baomidou.mybatisplus.extension.plugins.tenant;
 
-import java.util.List;
-
 import com.baomidou.mybatisplus.core.parser.AbstractJsqlParser;
 import com.baomidou.mybatisplus.core.toolkit.Assert;
 import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
 import com.baomidou.mybatisplus.core.toolkit.StringPool;
-
 import lombok.Data;
 import lombok.EqualsAndHashCode;
 import lombok.experimental.Accessors;
@@ -30,29 +27,17 @@ import net.sf.jsqlparser.expression.Expression;
 import net.sf.jsqlparser.expression.Parenthesis;
 import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
 import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
-import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
-import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
-import net.sf.jsqlparser.expression.operators.relational.InExpression;
-import net.sf.jsqlparser.expression.operators.relational.ItemsList;
-import net.sf.jsqlparser.expression.operators.relational.MultiExpressionList;
-import net.sf.jsqlparser.expression.operators.relational.SupportsOldOracleJoinSyntax;
+import net.sf.jsqlparser.expression.operators.relational.*;
 import net.sf.jsqlparser.schema.Column;
 import net.sf.jsqlparser.schema.Table;
+import net.sf.jsqlparser.statement.Statement;
 import net.sf.jsqlparser.statement.delete.Delete;
 import net.sf.jsqlparser.statement.insert.Insert;
-import net.sf.jsqlparser.statement.select.FromItem;
-import net.sf.jsqlparser.statement.select.Join;
-import net.sf.jsqlparser.statement.select.LateralSubSelect;
-import net.sf.jsqlparser.statement.select.PlainSelect;
-import net.sf.jsqlparser.statement.select.SelectBody;
-import net.sf.jsqlparser.statement.select.SelectExpressionItem;
-import net.sf.jsqlparser.statement.select.SetOperationList;
-import net.sf.jsqlparser.statement.select.SubJoin;
-import net.sf.jsqlparser.statement.select.SubSelect;
-import net.sf.jsqlparser.statement.select.ValuesList;
-import net.sf.jsqlparser.statement.select.WithItem;
+import net.sf.jsqlparser.statement.select.*;
 import net.sf.jsqlparser.statement.update.Update;
 
+import java.util.List;
+
 /**
  * 租户 SQL 解析器( TenantId 行级 )
  *
@@ -102,9 +87,9 @@ public class TenantSqlParser extends AbstractJsqlParser {
             // fixed github pull/295
             ItemsList itemsList = insert.getItemsList();
             if (itemsList instanceof MultiExpressionList) {
-                ((MultiExpressionList) itemsList).getExprList().forEach(el -> el.getExpressions().add(tenantHandler.getTenantId()));
+                ((MultiExpressionList) itemsList).getExprList().forEach(el -> el.getExpressions().add(tenantHandler.getTenantId(false)));
             } else {
-                ((ExpressionList) insert.getItemsList()).getExpressions().add(tenantHandler.getTenantId());
+                ((ExpressionList) insert.getItemsList()).getExpressions().add(tenantHandler.getTenantId(false));
             }
         } else {
             throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId");
@@ -146,7 +131,7 @@ public class TenantSqlParser extends AbstractJsqlParser {
         //获得where条件表达式
         EqualsTo equalsTo = new EqualsTo();
         equalsTo.setLeftExpression(this.getAliasColumn(table));
-        equalsTo.setRightExpression(tenantHandler.getTenantId());
+        equalsTo.setRightExpression(tenantHandler.getTenantId(true));
         if (null != where) {
             if (where instanceof OrExpression) {
                 return new AndExpression(equalsTo, new Parenthesis(where));
@@ -174,10 +159,12 @@ public class TenantSqlParser extends AbstractJsqlParser {
         FromItem fromItem = plainSelect.getFromItem();
         if (fromItem instanceof Table) {
             Table fromTable = (Table) fromItem;
-            if (!this.getTenantHandler().doTableFilter(fromTable.getName())) {//#1186 github
+            if (!this.getTenantHandler().doTableFilter(fromTable.getName())) {
+                //#1186 github
                 plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
                 if (addColumn) {
-                    plainSelect.getSelectItems().add(new SelectExpressionItem(new Column(this.getTenantHandler().getTenantIdColumn())));
+                    plainSelect.getSelectItems().add(new SelectExpressionItem(
+                        new Column(this.getTenantHandler().getTenantIdColumn())));
                 }
             }
         } else {
@@ -242,7 +229,7 @@ public class TenantSqlParser extends AbstractJsqlParser {
      * 默认tenantId的表达式: LongValue(1)这种依旧支持
      */
     protected Expression builderExpression(Expression currentExpression, Table table) {
-        final Expression tenantExpression = this.getTenantHandler().getTenantId();
+        final Expression tenantExpression = this.getTenantHandler().getTenantId(false);
         Expression appendExpression;
         if (!(tenantExpression instanceof SupportsOldOracleJoinSyntax)) {
             appendExpression = new EqualsTo();