瀏覽代碼

修复IllegalSQLInnerInterceptor分析嵌套count语句错误.

https://github.com/baomidou/mybatis-plus/issues/6311
nieqiurong 1 年之前
父節點
當前提交
880fb82183

+ 8 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptor.java

@@ -35,7 +35,9 @@ import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionLi
 import net.sf.jsqlparser.schema.Column;
 import net.sf.jsqlparser.schema.Table;
 import net.sf.jsqlparser.statement.delete.Delete;
+import net.sf.jsqlparser.statement.select.FromItem;
 import net.sf.jsqlparser.statement.select.Join;
+import net.sf.jsqlparser.statement.select.ParenthesedSelect;
 import net.sf.jsqlparser.statement.select.PlainSelect;
 import net.sf.jsqlparser.statement.select.Select;
 import net.sf.jsqlparser.statement.update.Update;
@@ -120,6 +122,12 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
     protected void processSelect(Select select, int index, String sql, Object obj) {
         if (select instanceof PlainSelect) {
             PlainSelect plainSelect = (PlainSelect) select;
+            FromItem fromItem = ((PlainSelect) select).getFromItem();
+            while (fromItem instanceof ParenthesedSelect) {
+                ParenthesedSelect parenthesedSelect = (ParenthesedSelect) fromItem;
+                plainSelect = (PlainSelect) parenthesedSelect.getSelect();
+                fromItem = plainSelect.getFromItem();
+            }
             Expression where = plainSelect.getWhere();
             Assert.notNull(where, "非法SQL,必须要有where条件");
             Table table = (Table) plainSelect.getFromItem();

+ 8 - 0
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/test/extension/plugins/inner/IllegalSQLInnerInterceptorTest.java

@@ -111,5 +111,13 @@ class IllegalSQLInnerInterceptorTest {
         Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where b >= (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
         Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `T_DEMO` where b <= (SELECT b FROM T_TEST limit 1) ", dataSource.getConnection()));
     }
+    @Test
+    void testCount() {
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from (select * from T_DEMO where a = 1 and `b` = 2) a", dataSource.getConnection()));
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from (select count(*) from (select * from T_DEMO where a = 1 and `b` = 2) a) c", dataSource.getConnection()));
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from (select * from `T_DEMO`) a ", dataSource.getConnection()));
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from (select * from `T_DEMO` where b = (SELECT b FROM T_TEST limit 1)) a ", dataSource.getConnection()));
+    }
 
 }