瀏覽代碼

发布 3.5.0 测试用例修复

hubin 3 年之前
父節點
當前提交
d9389eae39

+ 12 - 6
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java

@@ -78,7 +78,9 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
         MappedStatement ms = mpSh.mappedStatement();
         SqlCommandType sct = ms.getSqlCommandType();
         if (sct == SqlCommandType.INSERT || sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
-            if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) return;
+            if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) {
+                return;
+            }
             PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
             mpBs.sql(parserMulti(mpBs.sql(), null));
         }
@@ -225,10 +227,14 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
      * @param selectItems SelectItem
      */
     protected void appendSelectItem(List<SelectItem> selectItems) {
-        if (CollectionUtils.isEmpty(selectItems)) return;
+        if (CollectionUtils.isEmpty(selectItems)) {
+            return;
+        }
         if (selectItems.size() == 1) {
             SelectItem item = selectItems.get(0);
-            if (item instanceof AllColumns || item instanceof AllTableColumns) return;
+            if (item instanceof AllColumns || item instanceof AllTableColumns) {
+                return;
+            }
         }
         selectItems.add(new SelectExpressionItem(new Column(tenantLineHandler.getTenantIdColumn())));
     }
@@ -326,9 +332,9 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
             } else if (where instanceof InExpression) {
                 // in
                 InExpression expression = (InExpression) where;
-                ItemsList itemsList = expression.getRightItemsList();
-                if (itemsList instanceof SubSelect) {
-                    processSelectBody(((SubSelect) itemsList).getSelectBody());
+                Expression inExpression = expression.getRightExpression();
+                if (inExpression instanceof SubSelect) {
+                    processSelectBody(((SubSelect) inExpression).getSelectBody());
                 }
             } else if (where instanceof ExistsExpression) {
                 // exists

+ 12 - 10
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/DynamicTableNameInnerInterceptorTest.java

@@ -20,25 +20,27 @@ class DynamicTableNameInnerInterceptorTest {
     @SuppressWarnings({"SqlDialectInspection", "SqlNoDataSourceInspection"})
     void doIt() {
         DynamicTableNameInnerInterceptor interceptor = new DynamicTableNameInnerInterceptor();
-        interceptor.setTableNameHandler((sql, tableName) -> "t_user_r");
+        interceptor.setTableNameHandler((sql, tableName) -> tableName + "_r");
+
         // 表名相互包含
         @Language("SQL")
-        String origin = "SELECT * FROM t_user, t_user_role", replaced = "SELECT * FROM t_user_r, t_user_role";
-        assertEquals(replaced, interceptor.changeTable(origin));
+        String origin = "SELECT * FROM t_user, t_user_role";
+        assertEquals("SELECT * FROM t_user_r, t_user_role_r", interceptor.changeTable(origin));
+
         // 表名在末尾
         origin = "SELECT * FROM t_user";
-        replaced = "SELECT * FROM t_user_r";
-        assertEquals(replaced, interceptor.changeTable(origin));
+        assertEquals("SELECT * FROM t_user_r", interceptor.changeTable(origin));
+
         // 表名前后有注释
         origin = "SELECT * FROM /**/t_user/* t_user */";
-        replaced = "SELECT * FROM /**/t_user_r/* t_user */";
-        assertEquals(replaced, interceptor.changeTable(origin));
+        assertEquals("SELECT * FROM /**/t_user_r/* t_user */", interceptor.changeTable(origin));
+
         // 值中带有表名
         origin = "SELECT * FROM t_user WHERE name = 't_user'";
-        replaced = "SELECT * FROM t_user_r WHERE name = 't_user'";
-        assertEquals(replaced, interceptor.changeTable(origin));
+        assertEquals("SELECT * FROM t_user_r WHERE name = 't_user'", interceptor.changeTable(origin));
+
         // 别名被声明要替换
         origin = "SELECT t_user.* FROM t_user_real t_user";
-        assertEquals(origin, interceptor.changeTable(origin));
+        assertEquals("SELECT t_user.* FROM t_user_real_r t_user", interceptor.changeTable(origin));
     }
 }

+ 13 - 13
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/PaginationInnerInterceptorTest.java

@@ -20,47 +20,47 @@ class PaginationInnerInterceptorTest {
     void optimizeCount() {
         /* 能进行优化的 SQL */
         assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id",
-            "SELECT COUNT(*) FROM user u");
+            "SELECT COUNT(*) AS total FROM user u");
 
         assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id WHERE u.xx = ?",
-            "SELECT COUNT(*) FROM user u WHERE u.xx = ?");
+            "SELECT COUNT(*) AS total FROM user u WHERE u.xx = ?");
 
         assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id LEFT JOIN permission p on p.id = u.per_id",
-            "SELECT COUNT(*) FROM user u");
+            "SELECT COUNT(*) AS total FROM user u");
 
         assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id LEFT JOIN permission p on p.id = u.per_id WHERE u.xx = ?",
-            "SELECT COUNT(*) FROM user u WHERE u.xx = ?");
+            "SELECT COUNT(*) AS total FROM user u WHERE u.xx = ?");
     }
 
     @Test
     void notOptimizeCount() {
         /* 不能进行优化的 SQL */
         assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id AND r.name = ? where u.xx = ?",
-            "SELECT COUNT(*) FROM user u LEFT JOIN role r ON r.id = u.role_id AND r.name = ? WHERE u.xx = ?");
+            "SELECT COUNT(*) AS total FROM user u LEFT JOIN role r ON r.id = u.role_id AND r.name = ? WHERE u.xx = ?");
 
         /* join 表与 where 条件大小写不同的情况 */
         assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id where R.NAME = ?",
-            "SELECT COUNT(*) FROM user u LEFT JOIN role r ON r.id = u.role_id WHERE R.NAME = ?");
+            "SELECT COUNT(*) AS total FROM user u LEFT JOIN role r ON r.id = u.role_id WHERE R.NAME = ?");
 
         assertsCountSql("select * from user u LEFT JOIN role r ON r.id = u.role_id WHERE u.xax = ? AND r.cc = ? AND r.qq = ?",
-            "SELECT COUNT(*) FROM user u LEFT JOIN role r ON r.id = u.role_id WHERE u.xax = ? AND r.cc = ? AND r.qq = ?");
+            "SELECT COUNT(*) AS total FROM user u LEFT JOIN role r ON r.id = u.role_id WHERE u.xax = ? AND r.cc = ? AND r.qq = ?");
     }
 
     @Test
     void optimizeCountOrderBy() {
         /* order by 里不带参数,去除order by */
         assertsCountSql("SELECT * FROM comment ORDER BY name",
-            "SELECT COUNT(*) FROM comment");
+            "SELECT COUNT(*) AS total FROM comment");
 
         /* order by 里带参数,不去除order by */
         assertsCountSql("SELECT * FROM comment ORDER BY (CASE WHEN creator = ? THEN 0 ELSE 1 END)",
-            "SELECT COUNT(*) FROM comment ORDER BY (CASE WHEN creator = ? THEN 0 ELSE 1 END)");
+            "SELECT COUNT(*) AS total FROM comment ORDER BY (CASE WHEN creator = ? THEN 0 ELSE 1 END)");
     }
 
     @Test
     void withAsCount() {
         assertsCountSql("with A as (select * from class) select * from A",
-            "WITH A AS (SELECT * FROM class) SELECT COUNT(*) FROM A");
+            "WITH A AS (SELECT * FROM class) SELECT COUNT(*) AS total FROM A");
     }
 
     @Test
@@ -83,15 +83,15 @@ class PaginationInnerInterceptorTest {
                 "from reseller_acquire_log ral " +
                 "group by ral.reseller_id) rlr on r.id = rlr.reseller_id " +
                 "order by r.created_at desc",
-            "SELECT COUNT(*) FROM reseller r");
+            "SELECT COUNT(*) AS total FROM reseller r");
 
         // 不优化
         assertsCountSql("SELECT f.ca, f.cb FROM table_a f LEFT JOIN " +
                 "(SELECT ca FROM table_b WHERE cc = ?) rf on rf.ca = f.ca",
-            "SELECT COUNT(*) FROM table_a f LEFT JOIN (SELECT ca FROM table_b WHERE cc = ?) rf ON rf.ca = f.ca");
+            "SELECT COUNT(*) AS total FROM table_a f LEFT JOIN (SELECT ca FROM table_b WHERE cc = ?) rf ON rf.ca = f.ca");
 
         assertsCountSql("select * from order_info left join (select count(1) from order_info where create_time between ? and ?) tt on 1=1 WHERE equipment_id=?",
-            "SELECT COUNT(*) FROM order_info LEFT JOIN (SELECT count(1) FROM order_info WHERE create_time BETWEEN ? AND ?) tt ON 1 = 1 WHERE equipment_id = ?");
+            "SELECT COUNT(*) AS total FROM order_info LEFT JOIN (SELECT count(1) FROM order_info WHERE create_time BETWEEN ? AND ?) tt ON 1 = 1 WHERE equipment_id = ?");
     }
 
     void assertsCountSql(String sql, String targetSql) {