miemie преди 4 години
родител
ревизия
cc08de61ca

+ 8 - 33
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java

@@ -93,13 +93,14 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
     }
     }
 
 
     protected void processSelectBody(SelectBody selectBody) {
     protected void processSelectBody(SelectBody selectBody) {
+        if (selectBody == null) {
+            return;
+        }
         if (selectBody instanceof PlainSelect) {
         if (selectBody instanceof PlainSelect) {
             processPlainSelect((PlainSelect) selectBody);
             processPlainSelect((PlainSelect) selectBody);
         } else if (selectBody instanceof WithItem) {
         } else if (selectBody instanceof WithItem) {
             WithItem withItem = (WithItem) selectBody;
             WithItem withItem = (WithItem) selectBody;
-            if (withItem.getSelectBody() != null) {
-                processSelectBody(withItem.getSelectBody());
-            }
+            processSelectBody(withItem.getSelectBody());
         } else {
         } else {
             SetOperationList operationList = (SetOperationList) selectBody;
             SetOperationList operationList = (SetOperationList) selectBody;
             if (operationList.getSelects() != null && operationList.getSelects().size() > 0) {
             if (operationList.getSelects() != null && operationList.getSelects().size() > 0) {
@@ -285,7 +286,10 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
                 processWhereSubSelect(expression.getRightExpression());
                 processWhereSubSelect(expression.getRightExpression());
             } else if (where instanceof InExpression) {
             } else if (where instanceof InExpression) {
                 InExpression expression = (InExpression) where;
                 InExpression expression = (InExpression) where;
-                processItemsList(expression.getRightItemsList());
+                ItemsList itemsList = expression.getRightItemsList();
+                if (itemsList instanceof SubSelect) {
+                    processSelectBody(((SubSelect) itemsList).getSelectBody());
+                }
             } else if (where instanceof ComparisonOperator) {
             } else if (where instanceof ComparisonOperator) {
                 ComparisonOperator expression = (ComparisonOperator) where;
                 ComparisonOperator expression = (ComparisonOperator) where;
                 processWhereSubSelect(expression.getRightExpression());
                 processWhereSubSelect(expression.getRightExpression());
@@ -299,12 +303,6 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
         }
         }
     }
     }
 
 
-    protected void processItemsList(ItemsList itemsList) {
-        if (itemsList instanceof SubSelect) {
-            processSelectBody(((SubSelect) itemsList).getSelectBody());
-        }
-    }
-
     /**
     /**
      * 处理子查询等
      * 处理子查询等
      */
      */
@@ -359,17 +357,6 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
         if (currentExpression == null) {
         if (currentExpression == null) {
             return equalsTo;
             return equalsTo;
         }
         }
-        if (currentExpression instanceof BinaryExpression) {
-            BinaryExpression binaryExpression = (BinaryExpression) currentExpression;
-            doExpression(binaryExpression.getLeftExpression());
-            doExpression(binaryExpression.getRightExpression());
-        } else if (currentExpression instanceof InExpression) {
-            InExpression inExp = (InExpression) currentExpression;
-            ItemsList rightItems = inExp.getRightItemsList();
-            if (rightItems instanceof SubSelect) {
-                processSelectBody(((SubSelect) rightItems).getSelectBody());
-            }
-        }
         if (currentExpression instanceof OrExpression) {
         if (currentExpression instanceof OrExpression) {
             return new AndExpression(new Parenthesis(currentExpression), equalsTo);
             return new AndExpression(new Parenthesis(currentExpression), equalsTo);
         } else {
         } else {
@@ -377,18 +364,6 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
         }
         }
     }
     }
 
 
-    protected void doExpression(Expression expression) {
-        if (expression instanceof FromItem) {
-            processFromItem((FromItem) expression);
-        } else if (expression instanceof InExpression) {
-            InExpression inExp = (InExpression) expression;
-            ItemsList rightItems = inExp.getRightItemsList();
-            if (rightItems instanceof SubSelect) {
-                processSelectBody(((SubSelect) rightItems).getSelectBody());
-            }
-        }
-    }
-
     /**
     /**
      * 租户字段别名设置
      * 租户字段别名设置
      * <p>tenantId 或 tableAlias.tenantId</p>
      * <p>tenantId 或 tableAlias.tenantId</p>

+ 40 - 6
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java

@@ -80,21 +80,55 @@ class TenantLineInnerInterceptorTest {
 
 
     @Test
     @Test
     void selectSubSelect() {
     void selectSubSelect() {
-        // in
+        /* in */
+        assertSql("SELECT * FROM entity e WHERE e.id IN (select e1.id from entity1 e1 where e1.id = ?)",
+            "SELECT * FROM entity e WHERE e.id IN (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
+        // 在最前
+        assertSql("SELECT * FROM entity e WHERE e.id IN " +
+                "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
+            "SELECT * FROM entity e WHERE e.id IN " +
+                "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
+        // 在最后
+        assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id IN " +
+                "(select e1.id from entity1 e1 where e1.id = ?)",
+            "SELECT * FROM entity e WHERE e.id = ? AND e.id IN " +
+                "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
+        // 在中间
         assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id IN " +
         assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id IN " +
                 "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
                 "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
             "SELECT * FROM entity e WHERE e.id = ? AND e.id IN " +
             "SELECT * FROM entity e WHERE e.id = ? AND e.id IN " +
                 "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
                 "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
-        // =
-        assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id = " +
+
+
+        /* = */
+        assertSql("SELECT * FROM entity e WHERE e.id = (select e1.id from entity1 e1 where e1.id = ?)",
+            "SELECT * FROM entity e WHERE e.id = (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
+        // 在最前
+        assertSql("SELECT * FROM entity e WHERE e.id = " +
                 "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
                 "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
+            "SELECT * FROM entity e WHERE e.id = " +
+                "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
+        // 在最后
+        assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id = " +
+                "(select e1.id from entity1 e1 where e1.id = ?)",
             "SELECT * FROM entity e WHERE e.id = ? AND e.id = " +
             "SELECT * FROM entity e WHERE e.id = ? AND e.id = " +
+                "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
+
+
+        /* >= */
+        assertSql("SELECT * FROM entity e WHERE e.id >= (select e1.id from entity1 e1 where e1.id = ?)",
+            "SELECT * FROM entity e WHERE e.id >= (SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
+        // 在最前
+        assertSql("SELECT * FROM entity e WHERE e.id >= (select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
+            "SELECT * FROM entity e WHERE e.id >= " +
                 "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
                 "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
-        // >=
+        // 在最后
         assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id >= " +
         assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id >= " +
-                "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
+                "(select e1.id from entity1 e1 where e1.id = ?)",
             "SELECT * FROM entity e WHERE e.id = ? AND e.id >= " +
             "SELECT * FROM entity e WHERE e.id = ? AND e.id >= " +
-                "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
+                "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.tenant_id = 1");
+
+
         // <=
         // <=
         assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id <= " +
         assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id <= " +
                 "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
                 "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",