Browse Source

:bug: 解决年久失修的隐式内连接问题

Hccake 3 years ago
parent
commit
5ba2470814

+ 66 - 30
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java

@@ -41,6 +41,7 @@ import org.apache.ibatis.session.RowBounds;
 
 import java.sql.Connection;
 import java.sql.SQLException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
@@ -248,36 +249,43 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
 
         // 处理 fromItem
         FromItem fromItem = plainSelect.getFromItem();
-        Table mainTable = processFromItem(fromItem);
+        List<Table> list = processFromItem(fromItem);
+        List<Table> mainTables = new ArrayList<>(list);
 
         // 处理 join
         List<Join> joins = plainSelect.getJoins();
         if (CollectionUtils.isNotEmpty(joins)) {
-            mainTable = processJoins(mainTable, joins);
+            mainTables = processJoins(mainTables, joins);
         }
 
         // 当有 mainTable 时,进行 where 条件追加
-        if (mainTable != null) {
-            plainSelect.setWhere(builderExpression(where, Collections.singletonList(mainTable)));
+        if (CollectionUtils.isNotEmpty(mainTables)) {
+            plainSelect.setWhere(builderExpression(where, mainTables));
         }
     }
 
-    private Table processFromItem(FromItem fromItem) {
-        Table mainTable = null;
+    private List<Table> processFromItem(FromItem fromItem) {
+        // 处理括号括起来的表达式
+        while (fromItem instanceof ParenthesisFromItem) {
+            fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
+        }
+
+        List<Table> mainTables = new ArrayList<>();
         // 无 join 时的处理逻辑
         if (fromItem instanceof Table) {
             Table fromTable = (Table) fromItem;
             if (!tenantLineHandler.ignoreTable(fromTable.getName())) {
-                mainTable = fromTable;
+                mainTables.add(fromTable);
             }
         } else if (fromItem instanceof SubJoin) {
             // SubJoin 类型则还需要添加上 where 条件
-            mainTable = processSubJoin((SubJoin) fromItem);
+            List<Table> tables = processSubJoin((SubJoin) fromItem);
+            mainTables.addAll(tables);
         } else {
             // 处理下 fromItem
             processOtherFromItem(fromItem);
         }
-        return mainTable;
+        return mainTables;
     }
 
     /**
@@ -372,6 +380,11 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
      * 处理子查询等
      */
     protected void processOtherFromItem(FromItem fromItem) {
+        // 去除括号
+        while (fromItem instanceof ParenthesisFromItem) {
+            fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
+        }
+
         if (fromItem instanceof SubSelect) {
             SubSelect subSelect = (SubSelect) fromItem;
             if (subSelect.getSelectBody() != null) {
@@ -396,50 +409,65 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
      * @param subJoin subJoin
      * @return Table subJoin 中的主表
      */
-    private Table processSubJoin(SubJoin subJoin) {
-        Table mainTable = null;
+    private List<Table> processSubJoin(SubJoin subJoin) {
+        List<Table> mainTables = new ArrayList<>();
         if (subJoin.getJoinList() != null) {
-            mainTable = processFromItem(subJoin.getLeft());
-            mainTable = processJoins(mainTable, subJoin.getJoinList());
+            List<Table> list = processFromItem(subJoin.getLeft());
+            mainTables.addAll(list);
+            mainTables = processJoins(mainTables, subJoin.getJoinList());
         }
-        return mainTable;
+        return mainTables;
     }
 
     /**
      * 处理 joins
      *
-     * @param fromTable 可以为 null
-     * @param joins     join 集合
+     * @param mainTables 可以为 null
+     * @param joins      join 集合
      * @return List<Table> 右连接查询的 Table 列表
      */
-    private Table processJoins(Table fromTable, List<Join> joins) {
+    private List<Table> processJoins(List<Table> mainTables, List<Join> joins) {
+        if (mainTables == null) {
+            mainTables = new ArrayList<>();
+        }
+
         // join 表达式中最终的主表
-        Table mainTable = fromTable;
+        Table mainTable = null;
         // 当前 join 的左表
-        Table leftTable = fromTable;
+        Table leftTable = null;
+        if (mainTables.size() == 1) {
+            mainTable = mainTables.get(0);
+            leftTable = mainTable;
+        }
 
         //对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
         Deque<List<Table>> onTableDeque = new LinkedList<>();
         for (Join join : joins) {
-            List<Table> onTables = null;
             // 处理 on 表达式
             FromItem joinItem = join.getRightItem();
 
             // 获取当前 join 的表,subJoint 可以看作是一张表
-            Table joinTable = null;
+            List<Table> joinTables = null;
             if (joinItem instanceof Table) {
-                joinTable = (Table) joinItem;
+                joinTables = new ArrayList<>();
+                joinTables.add((Table) joinItem);
             } else if (joinItem instanceof SubJoin) {
-                joinTable = processSubJoin((SubJoin) joinItem);
+                joinTables = processSubJoin((SubJoin) joinItem);
             }
 
-            if (joinTable != null) {
-                // 获取 join 尾缀的 on 表达式列表
-                Collection<Expression> originOnExpressions = join.getOnExpressions();
+            if (joinTables != null) {
+
+                // 如果是隐式内连接
+                if (join.isSimple()) {
+                    mainTables.addAll(joinTables);
+                    continue;
+                }
 
                 // 当前表是否忽略
+                Table joinTable = joinTables.get(0);
                 boolean joinTableNeedIgnore = tenantLineHandler.ignoreTable(joinTable.getName());
 
+                List<Table> onTables = null;
                 // 如果不要忽略,且是右连接,则记录下当前表
                 if (join.isRight()) {
                     mainTable = joinTableNeedIgnore ? null : joinTable;
@@ -458,7 +486,13 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
                     }
                     mainTable = null;
                 }
+                mainTables = new ArrayList<>();
+                if (mainTable != null) {
+                    mainTables.add(mainTable);
+                }
 
+                // 获取 join 尾缀的 on 表达式列表
+                Collection<Expression> originOnExpressions = join.getOnExpressions();
                 // 正常 join on 表达式只有一个,立刻处理
                 if (originOnExpressions.size() == 1 && onTables != null) {
                     List<Expression> onExpressions = new LinkedList<>();
@@ -467,7 +501,6 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
                     leftTable = joinTable;
                     continue;
                 }
-
                 // 表名压栈,忽略的表压入 null,以便后续不处理
                 onTableDeque.push(onTables);
                 // 尾缀多个 on 表达式的时候统一处理
@@ -491,7 +524,7 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
 
         }
 
-        return mainTable;
+        return mainTables;
     }
 
     /**
@@ -536,10 +569,13 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
      */
     protected Column getAliasColumn(Table table) {
         StringBuilder column = new StringBuilder();
+        // 为了兼容隐式内连接,没有别名时条件就需要加上表名
         if (table.getAlias() != null) {
-            column.append(table.getAlias().getName()).append(StringPool.DOT);
+            column.append(table.getAlias().getName());
+        } else {
+            column.append(table.getName());
         }
-        column.append(tenantLineHandler.getTenantIdColumn());
+        column.append(StringPool.DOT).append(tenantLineHandler.getTenantIdColumn());
         return new Column(column.toString());
     }
 

+ 50 - 24
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java

@@ -38,48 +38,48 @@ class TenantLineInnerInterceptorTest {
             "INSERT INTO entity (id, name, tenant_id) VALUES (?, ?, ?)");
         // insert into select
         assertSql("insert into entity (id,name) select id,name from entity2",
-            "INSERT INTO entity (id, name, tenant_id) SELECT id, name, tenant_id FROM entity2 WHERE tenant_id = 1");
+            "INSERT INTO entity (id, name, tenant_id) SELECT id, name, tenant_id FROM entity2 WHERE entity2.tenant_id = 1");
 
         assertSql("insert into entity (id,name) select * from entity2",
-            "INSERT INTO entity (id, name, tenant_id) SELECT * FROM entity2 WHERE tenant_id = 1");
+            "INSERT INTO entity (id, name, tenant_id) SELECT * FROM entity2 WHERE entity2.tenant_id = 1");
 
         assertSql("insert into entity (id,name) select id,name from (select id,name from entity3) t",
-            "INSERT INTO entity (id, name, tenant_id) SELECT id, name, tenant_id FROM (SELECT id, name, tenant_id FROM entity3 WHERE tenant_id = 1) t");
+            "INSERT INTO entity (id, name, tenant_id) SELECT id, name, tenant_id FROM (SELECT id, name, tenant_id FROM entity3 WHERE entity3.tenant_id = 1) t");
 
         assertSql("insert into entity (id,name) select * from (select id,name from entity3) t",
-            "INSERT INTO entity (id, name, tenant_id) SELECT * FROM (SELECT id, name, tenant_id FROM entity3 WHERE tenant_id = 1) t");
+            "INSERT INTO entity (id, name, tenant_id) SELECT * FROM (SELECT id, name, tenant_id FROM entity3 WHERE entity3.tenant_id = 1) t");
 
         assertSql("insert into entity (id,name) select t.* from (select id,name from entity3) t",
-            "INSERT INTO entity (id, name, tenant_id) SELECT t.* FROM (SELECT id, name, tenant_id FROM entity3 WHERE tenant_id = 1) t");
+            "INSERT INTO entity (id, name, tenant_id) SELECT t.* FROM (SELECT id, name, tenant_id FROM entity3 WHERE entity3.tenant_id = 1) t");
     }
 
     @Test
     void delete() {
         assertSql("delete from entity where id = ?",
-            "DELETE FROM entity WHERE tenant_id = 1 AND id = ?");
+            "DELETE FROM entity WHERE entity.tenant_id = 1 AND id = ?");
     }
 
     @Test
     void update() {
         assertSql("update entity set name = ? where id = ?",
-            "UPDATE entity SET name = ? WHERE tenant_id = 1 AND id = ?");
+            "UPDATE entity SET name = ? WHERE entity.tenant_id = 1 AND id = ?");
     }
 
     @Test
     void selectSingle() {
         // 单表
         assertSql("select * from entity where id = ?",
-            "SELECT * FROM entity WHERE id = ? AND tenant_id = 1");
+            "SELECT * FROM entity WHERE id = ? AND entity.tenant_id = 1");
 
         assertSql("select * from entity where id = ? or name = ?",
-            "SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
+            "SELECT * FROM entity WHERE (id = ? OR name = ?) AND entity.tenant_id = 1");
 
         assertSql("SELECT * FROM entity WHERE (id = ? OR name = ?)",
-            "SELECT * FROM entity WHERE (id = ? OR name = ?) AND tenant_id = 1");
+            "SELECT * FROM entity WHERE (id = ? OR name = ?) AND entity.tenant_id = 1");
 
         /* not */
         assertSql("SELECT * FROM entity WHERE not (id = ? OR name = ?)",
-            "SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND tenant_id = 1");
+            "SELECT * FROM entity WHERE NOT (id = ? OR name = ?) AND entity.tenant_id = 1");
     }
 
     @Test
@@ -220,7 +220,7 @@ class TenantLineInnerInterceptorTest {
     }
 
     @Test
-    void selectMixJoin(){
+    void selectMixJoin() {
         assertSql("SELECT * FROM entity e " +
                 "right join entity1 e1 on e1.id = e.id " +
                 "left join entity2 e2 on e1.id = e2.id",
@@ -247,23 +247,23 @@ class TenantLineInnerInterceptorTest {
 
 
     @Test
-    void selectJoinSubSelect(){
+    void selectJoinSubSelect() {
         assertSql("select * from (select * from entity) e1 " +
-            "left join entity2 e2 on e1.id = e2.id",
-            "SELECT * FROM (SELECT * FROM entity WHERE tenant_id = 1) e1 " +
+                "left join entity2 e2 on e1.id = e2.id",
+            "SELECT * FROM (SELECT * FROM entity WHERE entity.tenant_id = 1) e1 " +
                 "LEFT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1");
 
         assertSql("select * from entity1 e1 " +
                 "left join (select * from entity2) e2 " +
                 "on e1.id = e2.id",
             "SELECT * FROM entity1 e1 " +
-                "LEFT JOIN (SELECT * FROM entity2 WHERE tenant_id = 1) e2 " +
+                "LEFT JOIN (SELECT * FROM entity2 WHERE entity2.tenant_id = 1) e2 " +
                 "ON e1.id = e2.id " +
                 "WHERE e1.tenant_id = 1");
     }
 
     @Test
-    void selectSubJoin(){
+    void selectSubJoin() {
 
         assertSql("select * FROM " +
                 "(entity1 e1 right JOIN entity2 e2 ON e1.id = e2.id)",
@@ -356,19 +356,45 @@ class TenantLineInnerInterceptorTest {
                 "INNER JOIN entity1 e1 ON e1.id = e.id AND e.tenant_id = 1 AND e1.tenant_id = 1 " +
                 "WHERE (e.id = ? OR e.name = ?)");
 
-        // 垃圾 inner join todo
-//        assertSql("SELECT * FROM entity,entity1 " +
-//                "WHERE entity.id = entity1.id",
-//            "SELECT * FROM entity e " +
-//                "INNER JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
-//                "WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
+        // 隐式内连接
+        assertSql("SELECT * FROM entity,entity1 " +
+                "WHERE entity.id = entity1.id",
+            "SELECT * FROM entity, entity1 " +
+                "WHERE entity.id = entity1.id AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
+
+        // SubJoin with 隐式内连接
+        assertSql("SELECT * FROM (entity,entity1) " +
+                "WHERE entity.id = entity1.id",
+            "SELECT * FROM (entity, entity1) " +
+                "WHERE entity.id = entity1.id " +
+                "AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
+
+        assertSql("SELECT * FROM ((entity,entity1),entity2) " +
+                "WHERE entity.id = entity1.id and entity.id = entity2.id",
+            "SELECT * FROM ((entity, entity1), entity2) " +
+                "WHERE entity.id = entity1.id AND entity.id = entity2.id " +
+                "AND entity.tenant_id = 1 AND entity1.tenant_id = 1 AND entity2.tenant_id = 1");
+
+        assertSql("SELECT * FROM (entity,(entity1,entity2)) " +
+                "WHERE entity.id = entity1.id and entity.id = entity2.id",
+            "SELECT * FROM (entity, (entity1, entity2)) " +
+                "WHERE entity.id = entity1.id AND entity.id = entity2.id " +
+                "AND entity.tenant_id = 1 AND entity1.tenant_id = 1 AND entity2.tenant_id = 1");
+
+        // 沙雕的括号写法
+        assertSql("SELECT * FROM (((entity,entity1))) " +
+                "WHERE entity.id = entity1.id",
+            "SELECT * FROM (((entity, entity1))) " +
+                "WHERE entity.id = entity1.id " +
+                "AND entity.tenant_id = 1 AND entity1.tenant_id = 1");
+
     }
 
 
     @Test
     void selectWithAs() {
         assertSql("with with_as_A as (select * from entity) select * from with_as_A",
-            "WITH with_as_A AS (SELECT * FROM entity WHERE tenant_id = 1) SELECT * FROM with_as_A");
+            "WITH with_as_A AS (SELECT * FROM entity WHERE entity.tenant_id = 1) SELECT * FROM with_as_A");
     }
 
     void assertSql(String sql, String targetSql) {