Browse Source

多租户子查询

miemie 4 years ago
parent
commit
52a1ba84af

+ 51 - 1
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java

@@ -224,11 +224,13 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
      */
     protected void processPlainSelect(PlainSelect plainSelect) {
         FromItem fromItem = plainSelect.getFromItem();
+        Expression where = plainSelect.getWhere();
+        processWhere(where);
         if (fromItem instanceof Table) {
             Table fromTable = (Table) fromItem;
             if (!tenantLineHandler.ignoreTable(fromTable.getName())) {
                 //#1186 github
-                plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
+                plainSelect.setWhere(builderExpression(where, fromTable));
             }
         } else {
             processFromItem(fromItem);
@@ -242,6 +244,54 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
         }
     }
 
+    /**
+     * 处理子查询
+     * <p>
+     * 支持如下:
+     * 1. in
+     * 2. =
+     * 3. >
+     * 4. <
+     * 5. >=
+     * 6. <=
+     * 7. <>
+     *
+     * @param where where 条件
+     */
+    protected void processWhere(Expression where) {
+        if (where == null) {
+            return;
+        }
+        if (where instanceof SubSelect) {
+            processSelectBody(((SubSelect) where).getSelectBody());
+            return;
+        }
+        if (where.toString().indexOf("SELECT") > 0) {
+            // 有子查询
+            if (where instanceof AndExpression) {
+                AndExpression expression = (AndExpression) where;
+                processWhere(expression.getLeftExpression());
+                processWhere(expression.getRightExpression());
+            } else if (where instanceof OrExpression) {
+                OrExpression expression = (OrExpression) where;
+                processWhere(expression.getLeftExpression());
+                processWhere(expression.getRightExpression());
+            } else if (where instanceof InExpression) {
+                InExpression expression = (InExpression) where;
+                processItemsList(expression.getRightItemsList());
+            } else if (where instanceof ComparisonOperator) {
+                ComparisonOperator expression = (ComparisonOperator) where;
+                processWhere(expression.getRightExpression());
+            }
+        }
+    }
+
+    protected void processItemsList(ItemsList itemsList) {
+        if (itemsList instanceof SubSelect) {
+            processSelectBody(((SubSelect) itemsList).getSelectBody());
+        }
+    }
+
     /**
      * 处理子查询等
      */

+ 28 - 0
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java

@@ -78,6 +78,34 @@ class TenantLineInnerInterceptorTest {
             "SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
     }
 
+    @Test
+    void selectSubSelect() {
+        assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id IN " +
+                "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
+            "SELECT * FROM entity e WHERE e.id = ? AND e.id IN " +
+                "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
+
+        assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id = " +
+                "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
+            "SELECT * FROM entity e WHERE e.id = ? AND e.id = " +
+                "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
+
+        assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id >= " +
+                "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
+            "SELECT * FROM entity e WHERE e.id = ? AND e.id >= " +
+                "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
+
+        assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id <= " +
+                "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
+            "SELECT * FROM entity e WHERE e.id = ? AND e.id <= " +
+                "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
+
+        assertSql("SELECT * FROM entity e WHERE e.id = ? and e.id <> " +
+                "(select e1.id from entity1 e1 where e1.id = ?) and e.id = ?",
+            "SELECT * FROM entity e WHERE e.id = ? AND e.id <> " +
+                "(SELECT e1.id FROM entity1 e1 WHERE e1.id = ? AND e1.tenant_id = 1) AND e.id = ? AND e.tenant_id = 1");
+    }
+
     @Test
     void selectLeftJoin() {
         // left join