miemie 5 years ago
parent
commit
61a277ac31

+ 14 - 6
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/tenant/TenantSqlParser.java

@@ -238,12 +238,8 @@ public class TenantSqlParser extends AbstractJsqlParser {
         }
         if (currentExpression instanceof BinaryExpression) {
             BinaryExpression binaryExpression = (BinaryExpression) currentExpression;
-            if (binaryExpression.getLeftExpression() instanceof FromItem) {
-                processFromItem((FromItem) binaryExpression.getLeftExpression());
-            }
-            if (binaryExpression.getRightExpression() instanceof FromItem) {
-                processFromItem((FromItem) binaryExpression.getRightExpression());
-            }
+            doExpression(binaryExpression.getLeftExpression());
+            doExpression(binaryExpression.getRightExpression());
         } else if (currentExpression instanceof InExpression) {
             InExpression inExp = (InExpression) currentExpression;
             ItemsList rightItems = inExp.getRightItemsList();
@@ -258,6 +254,18 @@ public class TenantSqlParser extends AbstractJsqlParser {
         }
     }
 
+    protected void doExpression(Expression expression) {
+        if (expression instanceof FromItem) {
+            processFromItem((FromItem) expression);
+        } else if (expression instanceof InExpression) {
+            InExpression inExp = (InExpression) expression;
+            ItemsList rightItems = inExp.getRightItemsList();
+            if (rightItems instanceof SubSelect) {
+                processSelectBody(((SubSelect) rightItems).getSelectBody());
+            }
+        }
+    }
+
     /**
      * 目前: 针对自定义的tenantId的条件表达式[tenant_id in (1,2,3)],无法处理多租户的字段加上表别名
      * select a.id, b.name

+ 49 - 0
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/tenant/TenantSqlParserTest.java

@@ -0,0 +1,49 @@
+package com.baomidou.mybatisplus.extension.plugins.tenant;
+
+import net.sf.jsqlparser.JSQLParserException;
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.LongValue;
+import net.sf.jsqlparser.parser.CCJSqlParserUtil;
+import net.sf.jsqlparser.statement.Statements;
+import net.sf.jsqlparser.statement.select.Select;
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author miemie
+ * @since 2019-11-02
+ */
+
+public class TenantSqlParserTest {
+
+    private final TenantSqlParser parser = new TenantSqlParser()
+        .setTenantHandler(new TenantHandler() {
+            @Override
+            public Expression getTenantId(boolean where) {
+                return new LongValue(1);
+            }
+
+            @Override
+            public String getTenantIdColumn() {
+                return "t_id";
+            }
+
+            @Override
+            public boolean doTableFilter(String tableName) {
+                return false;
+            }
+        });
+
+    @Test
+    public void processSelectBody() throws JSQLParserException {
+        m("select * from user", "select * from user where user.t_id = 1");
+    }
+
+    private void m(String sql, String target) throws JSQLParserException {
+        Statements statement = CCJSqlParserUtil.parseStatements(sql);
+        Select select = (Select) statement.getStatements().get(0);
+        parser.processSelectBody(select.getSelectBody());
+        assertThat(select.toString().toLowerCase()).isEqualTo(target);
+    }
+}