浏览代码

:ambulance: 调整了多租户 sql 解析流程,修复 right join、subJoin 的问题,以及优化 innerJoin

Hccake 3 年之前
父节点
当前提交
f8f209c261

+ 103 - 86
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java

@@ -41,8 +41,9 @@ import org.apache.ibatis.session.RowBounds;
 
 import java.sql.Connection;
 import java.sql.SQLException;
-import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.Deque;
 import java.util.LinkedList;
 import java.util.List;
@@ -235,38 +236,48 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
      * 处理 PlainSelect
      */
     protected void processPlainSelect(PlainSelect plainSelect) {
-        FromItem fromItem = plainSelect.getFromItem();
-
         //#3087 github
         List<SelectItem> selectItems = plainSelect.getSelectItems();
         if (CollectionUtils.isNotEmpty(selectItems)) {
             selectItems.forEach(this::processSelectItem);
         }
 
-        // #I4FP6E gitee:右连接查询时,where 条件需要过滤
-        List<Table> rightJointTables;
+        // 处理 where 中的子查询
+        Expression where = plainSelect.getWhere();
+        processWhereSubSelect(where);
+
+        // 处理 fromItem
+        FromItem fromItem = plainSelect.getFromItem();
+        Table mainTable = processFromItem(fromItem);
+
+        // 处理 join
         List<Join> joins = plainSelect.getJoins();
         if (CollectionUtils.isNotEmpty(joins)) {
-            rightJointTables = processJoins(joins);
-        }else {
-            rightJointTables = new ArrayList<>();
+            mainTable = processJoins(mainTable, joins);
         }
 
-        Expression where = plainSelect.getWhere();
-        processWhereSubSelect(where);
+        // 当有 mainTable 时,进行 where 条件追加
+        if (mainTable != null) {
+            plainSelect.setWhere(builderExpression(where, Collections.singletonList(mainTable)));
+        }
+    }
+
+    private Table processFromItem(FromItem fromItem) {
+        Table mainTable = null;
+        // 无 join 时的处理逻辑
         if (fromItem instanceof Table) {
             Table fromTable = (Table) fromItem;
-            boolean needIgnore = tenantLineHandler.ignoreTable(fromTable.getName());
-            if (needIgnore) {
-                plainSelect.setWhere(builderExpression(where, null, rightJointTables));
-            }else {
-                //#1186 github
-                plainSelect.setWhere(builderExpression(where, fromTable, rightJointTables));
+            if (!tenantLineHandler.ignoreTable(fromTable.getName())) {
+                mainTable = fromTable;
             }
+        } else if (fromItem instanceof SubJoin) {
+            // SubJoin 类型则还需要添加上 where 条件
+            mainTable = processSubJoin((SubJoin) fromItem);
         } else {
-            processFromItem(fromItem);
+            // 处理下 fromItem
+            processOtherFromItem(fromItem);
         }
-
+        return mainTable;
     }
 
     /**
@@ -294,7 +305,7 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
             return;
         }
         if (where instanceof FromItem) {
-            processFromItem((FromItem) where);
+            processOtherFromItem((FromItem) where);
             return;
         }
         if (where.toString().indexOf("SELECT") > 0) {
@@ -360,16 +371,8 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
     /**
      * 处理子查询等
      */
-    protected void processFromItem(FromItem fromItem) {
-        if (fromItem instanceof SubJoin) {
-            SubJoin subJoin = (SubJoin) fromItem;
-            if (subJoin.getJoinList() != null) {
-                processJoins(subJoin.getJoinList());
-            }
-            if (subJoin.getLeft() != null) {
-                processFromItem(subJoin.getLeft());
-            }
-        } else if (fromItem instanceof SubSelect) {
+    protected void processOtherFromItem(FromItem fromItem) {
+        if (fromItem instanceof SubSelect) {
             SubSelect subSelect = (SubSelect) fromItem;
             if (subSelect.getSelectBody() != null) {
                 processSelectBody(subSelect.getSelectBody());
@@ -387,104 +390,118 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
         }
     }
 
+    /**
+     * 处理 sub join
+     *
+     * @param subJoin subJoin
+     * @return Table subJoin 中的主表
+     */
+    private Table processSubJoin(SubJoin subJoin) {
+        Table mainTable = null;
+        if (subJoin.getJoinList() != null) {
+            mainTable = processFromItem(subJoin.getLeft());
+            mainTable = processJoins(mainTable, subJoin.getJoinList());
+        }
+        return mainTable;
+    }
+
     /**
      * 处理 joins
      *
-     * @param joins join 集合
+     * @param fromTable 可以为 null
+     * @param joins     join 集合
      * @return List<Table> 右连接查询的 Table 列表
      */
-    private List<Table> processJoins(List<Join> joins) {
-
-        List<Table> rightJointTables = new ArrayList<>();
+    private Table processJoins(Table fromTable, List<Join> joins) {
+        // join 表达式中最终的主表
+        Table mainTable = fromTable;
+        // 当前 join 的左表
+        Table leftTable = fromTable;
 
         //对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
-        Deque<Table> tables = new LinkedList<>();
+        Deque<List<Table>> onTableDeque = new LinkedList<>();
         for (Join join : joins) {
+            List<Table> onTables = null;
             // 处理 on 表达式
-            FromItem fromItem = join.getRightItem();
-            if (fromItem instanceof Table) {
-                Table fromTable = (Table) fromItem;
+            FromItem joinItem = join.getRightItem();
+
+            // 获取当前 join 的表,subJoint 可以看作是一张表
+            Table joinTable = null;
+            if (joinItem instanceof Table) {
+                joinTable = (Table) joinItem;
+            } else if (joinItem instanceof SubJoin) {
+                joinTable = processSubJoin((SubJoin) joinItem);
+            }
+
+            if (joinTable != null) {
                 // 获取 join 尾缀的 on 表达式列表
                 Collection<Expression> originOnExpressions = join.getOnExpressions();
 
                 // 当前表是否忽略
-                boolean needIgnore = tenantLineHandler.ignoreTable(fromTable.getName());
+                boolean joinTableNeedIgnore = tenantLineHandler.ignoreTable(joinTable.getName());
+
                 // 如果不要忽略,且是右连接,则记录下当前表
-                if (!needIgnore && join.isRight()) {
-                    rightJointTables.add(fromTable);
+                if (join.isRight()) {
+                    mainTable = joinTableNeedIgnore ? null : joinTable;
+                    if (leftTable != null) {
+                        onTables = Collections.singletonList(leftTable);
+                    }
+                } else if (join.isLeft()) {
+                    if (!joinTableNeedIgnore) {
+                        onTables = Collections.singletonList(joinTable);
+                    }
+                } else if (join.isInner()) {
+                    if (mainTable == null) {
+                        onTables = Collections.singletonList(joinTable);
+                    } else {
+                        onTables = Arrays.asList(mainTable, joinTable);
+                    }
+                    mainTable = null;
                 }
 
                 // 正常 join on 表达式只有一个,立刻处理
-                if (originOnExpressions.size() == 1) {
-                    processJoin(join);
+                if (originOnExpressions.size() == 1 && onTables != null) {
+                    List<Expression> onExpressions = new LinkedList<>();
+                    onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
+                    join.setOnExpressions(onExpressions);
+                    leftTable = joinTable;
                     continue;
                 }
 
                 // 表名压栈,忽略的表压入 null,以便后续不处理
-                tables.push(needIgnore ? null : fromTable);
+                onTableDeque.push(onTables);
                 // 尾缀多个 on 表达式的时候统一处理
                 if (originOnExpressions.size() > 1) {
                     Collection<Expression> onExpressions = new LinkedList<>();
                     for (Expression originOnExpression : originOnExpressions) {
-                        Table currentTable = tables.poll();
-                        if (currentTable == null) {
+                        List<Table> currentTableList = onTableDeque.poll();
+                        if (CollectionUtils.isEmpty(currentTableList)) {
                             onExpressions.add(originOnExpression);
                         } else {
-                            onExpressions.add(builderExpression(originOnExpression, currentTable));
+                            onExpressions.add(builderExpression(originOnExpression, currentTableList));
                         }
                     }
                     join.setOnExpressions(onExpressions);
                 }
+                leftTable = joinTable;
             } else {
-                // 处理右边连接的子表达式
-                processFromItem(fromItem);
+                processOtherFromItem(joinItem);
+                leftTable = null;
             }
-        }
 
-        return rightJointTables;
-    }
-
-    /**
-     * 处理联接语句
-     */
-    protected void processJoin(Join join) {
-        if (join.getRightItem() instanceof Table) {
-            Table fromTable = (Table) join.getRightItem();
-            if (tenantLineHandler.ignoreTable(fromTable.getName())) {
-                // 过滤退出执行
-                return;
-            }
-            // 走到这里说明 on 表达式肯定只有一个
-            Collection<Expression> originOnExpressions = join.getOnExpressions();
-            List<Expression> onExpressions = new LinkedList<>();
-            onExpressions.add(builderExpression(originOnExpressions.iterator().next(), fromTable));
-            join.setOnExpressions(onExpressions);
         }
-    }
 
-    /**
-     * 处理条件
-     */
-    protected Expression builderExpression(Expression currentExpression, Table table) {
-        return builderExpression(currentExpression, table, new ArrayList<>());
+        return mainTable;
     }
 
     /**
      * 处理条件
      */
-    protected Expression builderExpression(Expression currentExpression, Table table, List<Table> rightJointTables) {
-       // 没有表需要处理直接返回
-        if(table == null && CollectionUtils.isEmpty(rightJointTables)){
-           return currentExpression;
-       }
-
-        // 当前需要处理的表
-        List<Table> tables = new ArrayList<>();
-        if(table != null){
-            tables.add(table);
+    protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
+        // 没有表需要处理直接返回
+        if (CollectionUtils.isEmpty(tables)) {
+            return currentExpression;
         }
-        tables.addAll(rightJointTables);
-
         // 租户
         Expression tenantId = tenantLineHandler.getTenantId();
         // 构造每张表的条件
@@ -494,7 +511,7 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
         // 注入的表达式
         Expression injectExpression = equalsTos.get(0);
         // 如果有多表,则用 and 连接
-        if(equalsTos.size() > 1){
+        if (equalsTos.size() > 1) {
             for (int i = 1; i < equalsTos.size(); i++) {
                 injectExpression = new AndExpression(injectExpression, equalsTos.get(i));
             }

+ 114 - 12
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java

@@ -178,6 +178,14 @@ class TenantLineInnerInterceptorTest {
             "SELECT * FROM entity e " +
                 "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
                 "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
+
+        assertSql("SELECT * FROM entity e " +
+                "left join entity1 e1 on e1.id = e.id " +
+                "left join entity2 e2 on e1.id = e2.id",
+            "SELECT * FROM entity e " +
+                "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
+                "LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " +
+                "WHERE e.tenant_id = 1");
     }
 
     @Test
@@ -186,31 +194,125 @@ class TenantLineInnerInterceptorTest {
         assertSql("SELECT * FROM entity e " +
                 "right join entity1 e1 on e1.id = e.id",
             "SELECT * FROM entity e " +
-                "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
-                "WHERE e.tenant_id = 1 AND e1.tenant_id = 1");
+                "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
+                "WHERE e1.tenant_id = 1");
 
         assertSql("SELECT * FROM with_as_1 e " +
                 "right join entity1 e1 on e1.id = e.id",
             "SELECT * FROM with_as_1 e " +
-                "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
+                "RIGHT JOIN entity1 e1 ON e1.id = e.id " +
                 "WHERE e1.tenant_id = 1");
 
         assertSql("SELECT * FROM entity e " +
                 "right join entity1 e1 on e1.id = e.id " +
                 "WHERE e.id = ? OR e.name = ?",
             "SELECT * FROM entity e " +
-                "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
-                "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e1.tenant_id = 1");
+                "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
+                "WHERE (e.id = ? OR e.name = ?) AND e1.tenant_id = 1");
 
         assertSql("SELECT * FROM entity e " +
                 "right join entity1 e1 on e1.id = e.id " +
                 "right join entity2 e2 on e1.id = e2.id ",
             "SELECT * FROM entity e " +
-                "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
-                "RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " +
-                "WHERE e.tenant_id = 1 AND e1.tenant_id = 1 AND e2.tenant_id = 1");
+                "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
+                "RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " +
+                "WHERE e2.tenant_id = 1");
     }
 
+    @Test
+    void selectMixJoin(){
+        assertSql("SELECT * FROM entity e " +
+                "right join entity1 e1 on e1.id = e.id " +
+                "left join entity2 e2 on e1.id = e2.id",
+            "SELECT * FROM entity e " +
+                "RIGHT JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 " +
+                "LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " +
+                "WHERE e1.tenant_id = 1");
+
+        assertSql("SELECT * FROM entity e " +
+                "left join entity1 e1 on e1.id = e.id " +
+                "right join entity2 e2 on e1.id = e2.id",
+            "SELECT * FROM entity e " +
+                "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
+                "RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1 " +
+                "WHERE e2.tenant_id = 1");
+
+        assertSql("SELECT * FROM entity e " +
+                "left join entity1 e1 on e1.id = e.id " +
+                "inner join entity2 e2 on e1.id = e2.id",
+            "SELECT * FROM entity e " +
+                "LEFT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
+                "INNER JOIN entity2 e2 ON e1.id = e2.id AND e.tenant_id = 1 AND e2.tenant_id = 1");
+    }
+
+
+    @Test
+    void selectJoinSubSelect(){
+        assertSql("select * from (select * from entity) e1 " +
+            "left join entity2 e2 on e1.id = e2.id",
+            "SELECT * FROM (SELECT * FROM entity WHERE tenant_id = 1) e1 " +
+                "LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1");
+
+        assertSql("select * from entity1 e1 " +
+                "left join (select * from entity2) e2 " +
+                "on e1.id = e2.id",
+            "SELECT * FROM entity1 e1 " +
+                "LEFT JOIN (SELECT * FROM entity2 WHERE tenant_id = 1) e2 " +
+                "ON e1.id = e2.id " +
+                "WHERE e1.tenant_id = 1");
+    }
+
+    @Test
+    void selectSubJoin(){
+
+        assertSql("select * FROM " +
+                "(entity1 e1 right JOIN entity2 e2 ON e1.id = e2.id)",
+            "SELECT * FROM " +
+                "(entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " +
+                "WHERE e2.tenant_id = 1");
+
+        assertSql("select * FROM " +
+                "(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id)",
+            "SELECT * FROM " +
+                "(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
+                "WHERE e1.tenant_id = 1");
+
+
+        assertSql("select * FROM " +
+                "(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id) " +
+                "right join entity3 e3 on e1.id = e3.id",
+            "SELECT * FROM " +
+                "(entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
+                "RIGHT JOIN entity3 e3 ON e1.id = e3.id AND e1.tenant_id = 1 " +
+                "WHERE e3.tenant_id = 1");
+
+
+        assertSql("select * FROM entity e " +
+                "LEFT JOIN (entity1 e1 right join entity2 e2 ON e1.id = e2.id) " +
+                "on e.id = e2.id",
+            "SELECT * FROM entity e " +
+                "LEFT JOIN (entity1 e1 RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e1.tenant_id = 1) " +
+                "ON e.id = e2.id AND e2.tenant_id = 1 " +
+                "WHERE e.tenant_id = 1");
+
+        assertSql("select * FROM entity e " +
+                "LEFT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " +
+                "on e.id = e2.id",
+            "SELECT * FROM entity e " +
+                "LEFT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
+                "ON e.id = e2.id AND e1.tenant_id = 1 " +
+                "WHERE e.tenant_id = 1");
+
+        assertSql("select * FROM entity e " +
+                "RIGHT JOIN (entity1 e1 left join entity2 e2 ON e1.id = e2.id) " +
+                "on e.id = e2.id",
+            "SELECT * FROM entity e " +
+                "RIGHT JOIN (entity1 e1 LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1) " +
+                "ON e.id = e2.id AND e.tenant_id = 1 " +
+                "WHERE e1.tenant_id = 1");
+    }
+
+
     @Test
     void selectLeftJoinMultipleTrailingOn() {
         // 多个 on 尾缀的
@@ -244,15 +346,15 @@ class TenantLineInnerInterceptorTest {
                 "inner join entity1 e1 on e1.id = e.id " +
                 "WHERE e.id = ? OR e.name = ?",
             "SELECT * FROM entity e " +
-                "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
-                "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
+                "INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " +
+                "WHERE e.id = ? OR e.name = ?");
 
         assertSql("SELECT * FROM entity e " +
                 "inner join entity1 e1 on e1.id = e.id " +
                 "WHERE (e.id = ? OR e.name = ?)",
             "SELECT * FROM entity e " +
-                "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
-                "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
+                "INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " +
+                "WHERE (e.id = ? OR e.name = ?)");
 
         // 垃圾 inner join todo
 //        assertSql("SELECT * FROM entity,entity1 " +