Browse Source

解决表名或字段名包裹导致无法获取索引信息和索引字段校验问题.

https://github.com/baomidou/mybatis-plus/issues/5578
nieqiurong 1 year ago
parent
commit
a84e466478

+ 1 - 0
mybatis-plus-extension/build.gradle

@@ -27,4 +27,5 @@ dependencies {
     testImplementation "com.google.guava:guava:33.0.0-jre"
     testImplementation "io.github.classgraph:classgraph:4.8.165"
     testImplementation "${lib.h2}"
+    testImplementation "${lib.mysql}"
 }

+ 21 - 18
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptor.java

@@ -22,6 +22,7 @@ import com.baomidou.mybatisplus.core.toolkit.EncryptUtils;
 import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
 import com.baomidou.mybatisplus.core.toolkit.StringUtils;
 import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
+import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils;
 import lombok.Data;
 import net.sf.jsqlparser.expression.BinaryExpression;
 import net.sf.jsqlparser.expression.Expression;
@@ -214,23 +215,25 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
     private void validUseIndex(Table table, String columnName, Connection connection) {
         //是否使用索引
         boolean useIndexFlag = false;
-
-        String tableInfo = table.getName();
-        //表存在的索引
-        String dbName = null;
-        String tableName;
-        String[] tableArray = tableInfo.split("\\.");
-        if (tableArray.length == 1) {
-            tableName = tableArray[0];
-        } else {
-            dbName = tableArray[0];
-            tableName = tableArray[1];
-        }
-        List<IndexInfo> indexInfos = getIndexInfos(dbName, tableName, connection);
-        for (IndexInfo indexInfo : indexInfos) {
-            if (null != columnName && columnName.equalsIgnoreCase(indexInfo.getColumnName())) {
-                useIndexFlag = true;
-                break;
+        if (StringUtils.isNotBlank(columnName)) {
+            String tableInfo = table.getName();
+            //表存在的索引
+            String dbName = null;
+            String tableName;
+            String[] tableArray = tableInfo.split("\\.");
+            if (tableArray.length == 1) {
+                tableName = tableArray[0];
+            } else {
+                dbName = tableArray[0];
+                tableName = tableArray[1];
+            }
+            columnName = SqlParserUtils.removeWrapperSymbol(columnName);
+            List<IndexInfo> indexInfos = getIndexInfos(dbName, tableName, connection);
+            for (IndexInfo indexInfo : indexInfos) {
+                if (indexInfo.getColumnName().equalsIgnoreCase(columnName)) {
+                    useIndexFlag = true;
+                    break;
+                }
             }
         }
         if (!useIndexFlag) {
@@ -323,7 +326,7 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
                 DatabaseMetaData metadata = conn.getMetaData();
                 String catalog = StringUtils.isBlank(dbName) ? conn.getCatalog() : dbName;
                 String schema = StringUtils.isBlank(dbName) ? conn.getSchema() : dbName;
-                rs = metadata.getIndexInfo(catalog, schema, tableName, false, true);
+                rs = metadata.getIndexInfo(catalog, schema, SqlParserUtils.removeWrapperSymbol(tableName), false, true);
                 indexInfos = new ArrayList<>();
                 while (rs.next()) {
                     //索引中的列序列号等于1,才有效

+ 21 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/toolkit/SqlParserUtils.java

@@ -15,6 +15,8 @@
  */
 package com.baomidou.mybatisplus.extension.toolkit;
 
+import com.baomidou.mybatisplus.core.toolkit.StringUtils;
+
 /**
  * SQL 解析工具类
  *
@@ -32,4 +34,23 @@ public class SqlParserUtils {
     public static String getOriginalCountSql(String originalSql) {
         return String.format("SELECT COUNT(*) FROM (%s) TOTAL", originalSql);
     }
+
+    /**
+     * 去除表或字段包裹符号
+     *
+     * @param tableOrColumn 表名或字段名
+     * @return str
+     * @since 3.5.6
+     */
+    public static String removeWrapperSymbol(String tableOrColumn) {
+        if (StringUtils.isBlank(tableOrColumn)) {
+            return null;
+        }
+        if (tableOrColumn.startsWith("`") || tableOrColumn.startsWith("\"")
+            || tableOrColumn.startsWith("[") || tableOrColumn.startsWith("<")) {
+            return tableOrColumn.substring(1, tableOrColumn.length() - 1);
+        }
+        return tableOrColumn;
+    }
+
 }

+ 77 - 0
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptorTest.java

@@ -1,9 +1,17 @@
 package com.baomidou.mybatisplus.extension.plugins.inner;
 
 import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
+import com.mysql.cj.jdbc.MysqlDataSource;
+import org.apache.ibatis.jdbc.SqlRunner;
+import org.h2.jdbcx.JdbcDataSource;
 import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 
+import java.sql.Connection;
+import java.sql.SQLException;
+
 /**
  * @author miemie
  * @since 2022-04-11
@@ -11,6 +19,26 @@ import org.junit.jupiter.api.Test;
 class IllegalSQLInnerInterceptorTest {
 
     private final IllegalSQLInnerInterceptor interceptor = new IllegalSQLInnerInterceptor();
+//
+      // 待研究为啥H2读不到索引信息
+//    private static Connection connection;
+//
+//    @BeforeAll
+//    public static void beforeAll() throws SQLException {
+//        var jdbcDataSource = new JdbcDataSource();
+//        jdbcDataSource.setURL("jdbc:h2:mem:test;MODE=mysql;DB_CLOSE_DELAY=-1;DB_CLOSE_ON_EXIT=FALSE");
+//        connection = jdbcDataSource.getConnection("sa","");
+//        var sql = """
+//            CREATE TABLE t_demo (
+//              `a` int DEFAULT NULL,
+//              `b` int DEFAULT NULL,
+//              `c` int DEFAULT NULL,
+//              KEY `ab_index` (`a`,`b`)
+//            );
+//            """;
+//        SqlRunner sqlRunner = new SqlRunner(connection);
+//        sqlRunner.run(sql);
+//    }
 
     @Test
     void test() {
@@ -20,6 +48,55 @@ class IllegalSQLInnerInterceptorTest {
         Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("delete from t_user set age = 18", null));
         Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user where age != 1", null));
         Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_user where age = 1 or name = 'test'", null));
+//        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `t_demo` where a = 1 and b = 2", connection));
+    }
+
+    @Test
+//    @Disabled
+    void testMysql(){
+        /*
+         *   CREATE TABLE `t_demo` (
+              `a` int DEFAULT NULL,
+              `b` int DEFAULT NULL,
+              `c` int DEFAULT NULL,
+              KEY `ab_index` (`a`,`b`)
+            );
+            CREATE TABLE `test` (
+              `a` int DEFAULT NULL,
+              `b` int DEFAULT NULL,
+              `c` int DEFAULT NULL,
+              KEY `ab_index` (`a`,`b`)
+            ) ;
+         */
+        var dataSource = new MysqlDataSource();
+        dataSource.setUrl("jdbc:mysql://127.0.0.1:3306/test?serverTimezone=Asia/Shanghai");
+        dataSource.setUser("root");
+        dataSource.setPassword("123456");
+
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from t_demo where `a` = 1 and `b` = 2", dataSource.getConnection()));
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from t_demo where a = 1 and `b` = 2", dataSource.getConnection()));
+
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `t_demo` where `a` = 1 and `b` = 2", dataSource.getConnection()));
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select * from `t_demo` where a = 1 and `b` = 2", dataSource.getConnection()));
+
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from `t_demo` where c = 3", dataSource.getConnection()));
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from t_demo where c = 3", dataSource.getConnection()));
+
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from test.`t_demo` where c = 3", dataSource.getConnection()));
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select * from test.t_demo where c = 3", dataSource.getConnection()));
+
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM `t_demo` a INNER JOIN `test` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM `t_demo` a INNER JOIN `test` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
+
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM test.`t_demo` a INNER JOIN test.`test` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM test.`t_demo` a INNER JOIN test.`test` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
+
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM t_demo a INNER JOIN `test` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM t_demo a INNER JOIN `test` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
+
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM `t_demo` a LEFT JOIN `test` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM `t_demo` a LEFT JOIN `test` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
+
     }
 
 }

+ 22 - 0
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/test/SqlParserUtilsTest.java

@@ -0,0 +1,22 @@
+package com.baomidou.mybatisplus.test;
+
+import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+public class SqlParserUtilsTest {
+
+    @Test
+    void testRemoveWrapperSymbol() {
+        //用SQLServer的人喜欢写这种
+        Assertions.assertEquals(SqlParserUtils.removeWrapperSymbol("[Demo]"), "Demo");
+        Assertions.assertEquals(SqlParserUtils.removeWrapperSymbol("Demo"), "Demo");
+        //mysql比较常见
+        Assertions.assertEquals(SqlParserUtils.removeWrapperSymbol("`Demo`"), "Demo");
+        //用关键字表的
+        Assertions.assertEquals(SqlParserUtils.removeWrapperSymbol("\"Demo\""), "Demo");
+        //这种少
+        Assertions.assertEquals(SqlParserUtils.removeWrapperSymbol("<Demo>"), "Demo");
+    }
+
+}