Jelajahi Sumber

修复非法SQL拦截插件索引信息读取问题.

nieqiurong 4 bulan lalu
induk
melakukan
8cd0353687

+ 39 - 22
mybatis-plus-jsqlparser-support/mybatis-plus-jsqlparser-4.9/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptor.java

@@ -232,19 +232,12 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
         //是否使用索引
         boolean useIndexFlag = false;
         if (StringUtils.isNotBlank(columnName)) {
-            String tableInfo = table.getName();
+            String tableName = table.getName();
             //表存在的索引
-            String dbName = null;
-            String tableName;
-            String[] tableArray = tableInfo.split("\\.");
-            if (tableArray.length == 1) {
-                tableName = tableArray[0];
-            } else {
-                dbName = tableArray[0];
-                tableName = tableArray[1];
-            }
+            String dbName = getPartItemValue(table, 1);
+            String catalogName = getPartItemValue(table, 2);
             columnName = SqlParserUtils.removeWrapperSymbol(columnName);
-            List<IndexInfo> indexInfos = getIndexInfos(dbName, tableName, connection);
+            List<IndexInfo> indexInfos = getIndexInfos(catalogName, dbName, tableName, connection);
             for (IndexInfo indexInfo : indexInfos) {
                 if (indexInfo.getColumnName().equalsIgnoreCase(columnName)) {
                     useIndexFlag = true;
@@ -253,10 +246,14 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
             }
         }
         if (!useIndexFlag) {
-            throw new MybatisPlusException("非法SQL,SQL未使用到索引, table:" + table + ", columnName:" + columnName);
+            throw new MybatisPlusException("非法SQL,SQL未使用到索引, table:" + table.getName() + ", columnName:" + columnName);
         }
     }
 
+    private String getPartItemValue(Table table, int index) {
+        return index < table.getNameParts().size() ? table.getNameParts().get(index) : null;
+    }
+
     /**
      * 验证where条件的字段,是否有not、or等等,并且where的第一个字段,必须使用索引
      *
@@ -315,10 +312,12 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
     /**
      * 得到表的索引信息
      *
-     * @param dbName    ignore
-     * @param tableName ignore
-     * @param conn      ignore
-     * @return ignore
+     * @param dbName    数据库名
+     * @param tableName 表名
+     * @param conn      数据库连接
+     * @return 索引信息
+     * @see #getIndexInfos(String, String, String, String, Connection)
+     * @deprecated 3.5.11
      */
     public List<IndexInfo> getIndexInfos(String dbName, String tableName, Connection conn) {
         return getIndexInfos(null, dbName, tableName, conn);
@@ -327,13 +326,31 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
     /**
      * 得到表的索引信息
      *
-     * @param key       ignore
-     * @param dbName    ignore
-     * @param tableName ignore
-     * @param conn      ignore
-     * @return ignore
+     * @param key       缓存key
+     * @param dbName    数据库名
+     * @param tableName 表名
+     * @param conn      数据库连接
+     * @return 索引信息
+     * @see #getIndexInfos(String, String, String, String, Connection)
+     * @deprecated 3.5.11
      */
+    @Deprecated
     public List<IndexInfo> getIndexInfos(String key, String dbName, String tableName, Connection conn) {
+        return getIndexInfos(key, null, dbName, tableName, conn);
+    }
+
+    /**
+     * 得到表的索引信息
+     *
+     * @param key         缓存key
+     * @param catalogName catalogName
+     * @param dbName      数据库名
+     * @param tableName   表名
+     * @param conn        数据库连接
+     * @return 索引信息
+     * @since 3.5.11
+     */
+    public List<IndexInfo> getIndexInfos(String key, String catalogName, String dbName, String tableName, Connection conn) {
         List<IndexInfo> indexInfos = null;
         if (StringUtils.isNotBlank(key)) {
             indexInfos = indexInfoMap.get(key);
@@ -342,7 +359,7 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
             ResultSet rs;
             try {
                 DatabaseMetaData metadata = conn.getMetaData();
-                String catalog = StringUtils.isBlank(dbName) ? conn.getCatalog() : dbName;
+                String catalog = StringUtils.isBlank(catalogName) ? conn.getCatalog() : catalogName;
                 String schema = StringUtils.isBlank(dbName) ? conn.getSchema() : dbName;
                 rs = metadata.getIndexInfo(catalog, schema, SqlParserUtils.removeWrapperSymbol(tableName), false, true);
                 indexInfos = new ArrayList<>();

+ 12 - 2
mybatis-plus-jsqlparser-support/mybatis-plus-jsqlparser-4.9/src/test/java/com/baomidou/mybatisplus/test/extension/plugins/inner/IllegalSQLInnerInterceptorTest.java

@@ -78,8 +78,8 @@ class IllegalSQLInnerInterceptorTest {
         Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM `T_DEMO` a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
         Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM `T_DEMO` a INNER JOIN `test` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
 
-        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM test.`T_DEMO` a INNER JOIN test.`T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
-        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM test.`T_DEMO` a INNER JOIN test.`T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM PUBLIC.`T_DEMO` a INNER JOIN PUBLIC.`T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM PUBLIC.`T_DEMO` a INNER JOIN PUBLIC.`T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
 
         Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM T_DEMO a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
         Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM T_DEMO a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
@@ -120,4 +120,14 @@ class IllegalSQLInnerInterceptorTest {
         Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from (select * from `T_DEMO` where b = (SELECT b FROM T_TEST limit 1)) a ", dataSource.getConnection()));
     }
 
+    @Test
+    void testCatalogAndSchemaName() {
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from TEST.PUBLIC.T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from PUBLIC.T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
+        // 非同一模式,读不到索引的情况
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from DB.T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from PUBLIC.DB.T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
+    }
+
+
 }

+ 24 - 25
mybatis-plus-jsqlparser-support/mybatis-plus-jsqlparser-5.0/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptor.java

@@ -231,19 +231,12 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
         //是否使用索引
         boolean useIndexFlag = false;
         if (StringUtils.isNotBlank(columnName)) {
-            String tableInfo = table.getName();
             //表存在的索引
-            String dbName = null;
-            String tableName;
-            String[] tableArray = tableInfo.split("\\.");
-            if (tableArray.length == 1) {
-                tableName = tableArray[0];
-            } else {
-                dbName = tableArray[0];
-                tableName = tableArray[1];
-            }
+            String dbName = table.getSchemaName();
+            String tableName = table.getName();
+            String catalogName = table.getCatalogName();
             columnName = SqlParserUtils.removeWrapperSymbol(columnName);
-            List<IndexInfo> indexInfos = getIndexInfos(dbName, tableName, connection);
+            List<IndexInfo> indexInfos = getIndexInfos(catalogName, dbName, tableName, connection);
             for (IndexInfo indexInfo : indexInfos) {
                 if (indexInfo.getColumnName().equalsIgnoreCase(columnName)) {
                     useIndexFlag = true;
@@ -252,7 +245,7 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
             }
         }
         if (!useIndexFlag) {
-            throw new MybatisPlusException("非法SQL,SQL未使用到索引, table:" + table + ", columnName:" + columnName);
+            throw new MybatisPlusException("非法SQL,SQL未使用到索引, table:" + table.getName() + ", columnName:" + columnName);
         }
     }
 
@@ -314,25 +307,31 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
     /**
      * 得到表的索引信息
      *
-     * @param dbName    ignore
-     * @param tableName ignore
-     * @param conn      ignore
-     * @return ignore
+     * @param key       缓存key
+     * @param dbName    数据库名
+     * @param tableName 表名
+     * @param conn      数据库连接
+     * @return 索引信息
+     * @see #getIndexInfos(String, String, String, String, Connection)
+     * @deprecated 3.5.11
      */
-    public List<IndexInfo> getIndexInfos(String dbName, String tableName, Connection conn) {
-        return getIndexInfos(null, dbName, tableName, conn);
+    @Deprecated
+    public List<IndexInfo> getIndexInfos(String key, String dbName, String tableName, Connection conn) {
+        return getIndexInfos(key, null, dbName, tableName, conn);
     }
 
     /**
      * 得到表的索引信息
      *
-     * @param key       ignore
-     * @param dbName    ignore
-     * @param tableName ignore
-     * @param conn      ignore
-     * @return ignore
+     * @param key         缓存key
+     * @param catalogName catalogName
+     * @param dbName      数据库名
+     * @param tableName   表名
+     * @param conn        数据库连接
+     * @return 索引信息
+     * @since 3.5.11
      */
-    public List<IndexInfo> getIndexInfos(String key, String dbName, String tableName, Connection conn) {
+    public List<IndexInfo> getIndexInfos(String key, String catalogName, String dbName, String tableName, Connection conn) {
         List<IndexInfo> indexInfos = null;
         if (StringUtils.isNotBlank(key)) {
             indexInfos = indexInfoMap.get(key);
@@ -341,7 +340,7 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
             ResultSet rs;
             try {
                 DatabaseMetaData metadata = conn.getMetaData();
-                String catalog = StringUtils.isBlank(dbName) ? conn.getCatalog() : dbName;
+                String catalog = StringUtils.isBlank(catalogName) ? conn.getCatalog() : catalogName;
                 String schema = StringUtils.isBlank(dbName) ? conn.getSchema() : dbName;
                 rs = metadata.getIndexInfo(catalog, schema, SqlParserUtils.removeWrapperSymbol(tableName), false, true);
                 indexInfos = new ArrayList<>();

+ 13 - 2
mybatis-plus-jsqlparser-support/mybatis-plus-jsqlparser-5.0/src/test/java/com/baomidou/mybatisplus/test/extension/plugins/inner/IllegalSQLInnerInterceptorTest.java

@@ -78,8 +78,8 @@ class IllegalSQLInnerInterceptorTest {
         Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM `T_DEMO` a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
         Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM `T_DEMO` a INNER JOIN `test` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
 
-        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM test.`T_DEMO` a INNER JOIN test.`T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
-        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM test.`T_DEMO` a INNER JOIN test.`T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM PUBLIC.`T_DEMO` a INNER JOIN PUBLIC.`T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM PUBLIC.`T_DEMO` a INNER JOIN PUBLIC.`T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
 
         Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM T_DEMO a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
         Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM T_DEMO a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
@@ -120,4 +120,15 @@ class IllegalSQLInnerInterceptorTest {
         Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from (select * from `T_DEMO` where b = (SELECT b FROM T_TEST limit 1)) a ", dataSource.getConnection()));
     }
 
+
+    @Test
+    void testCatalogAndSchemaName() {
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from TEST.PUBLIC.T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from PUBLIC.T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
+        // 非同一模式,读不到索引的情况
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from DB.T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from PUBLIC.DB.T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
+    }
+
+
 }

+ 34 - 12
mybatis-plus-jsqlparser-support/mybatis-plus-jsqlparser/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptor.java

@@ -234,8 +234,9 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
             //表存在的索引
             String dbName = table.getSchemaName();
             String tableName = table.getName();
+            String catalogName = table.getCatalogName();
             columnName = SqlParserUtils.removeWrapperSymbol(columnName);
-            List<IndexInfo> indexInfos = getIndexInfos(dbName, tableName, connection);
+            List<IndexInfo> indexInfos = getIndexInfos(null, catalogName, dbName, tableName, connection);
             for (IndexInfo indexInfo : indexInfos) {
                 if (indexInfo.getColumnName().equalsIgnoreCase(columnName)) {
                     useIndexFlag = true;
@@ -244,7 +245,7 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
             }
         }
         if (!useIndexFlag) {
-            throw new MybatisPlusException("非法SQL,SQL未使用到索引, table:" + table + ", columnName:" + columnName);
+            throw new MybatisPlusException("非法SQL,SQL未使用到索引, table:" + table.getName() + ", columnName:" + columnName);
         }
     }
 
@@ -306,11 +307,14 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
     /**
      * 得到表的索引信息
      *
-     * @param dbName    ignore
-     * @param tableName ignore
-     * @param conn      ignore
-     * @return ignore
+     * @param dbName    数据库名
+     * @param tableName 表名
+     * @param conn      数据库连接
+     * @return 索引信息
+     * @see #getIndexInfos(String, String, String, String, Connection)
+     * @deprecated 3.5.11
      */
+    @Deprecated
     public List<IndexInfo> getIndexInfos(String dbName, String tableName, Connection conn) {
         return getIndexInfos(null, dbName, tableName, conn);
     }
@@ -318,13 +322,31 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
     /**
      * 得到表的索引信息
      *
-     * @param key       ignore
-     * @param dbName    ignore
-     * @param tableName ignore
-     * @param conn      ignore
-     * @return ignore
+     * @param key       缓存key
+     * @param dbName    数据库名
+     * @param tableName 表名
+     * @param conn      数据库连接
+     * @return 索引信息
+     * @see #getIndexInfos(String, String, String, String, Connection)
+     * @deprecated 3.5.11
      */
+    @Deprecated
     public List<IndexInfo> getIndexInfos(String key, String dbName, String tableName, Connection conn) {
+        return getIndexInfos(key, null, dbName, tableName, conn);
+    }
+
+    /**
+     * 得到表的索引信息
+     *
+     * @param key         缓存key
+     * @param catalogName catalogName
+     * @param dbName      数据库名
+     * @param tableName   表名
+     * @param conn        数据库连接
+     * @return 索引信息
+     * @since 3.5.11
+     */
+    public List<IndexInfo> getIndexInfos(String key, String catalogName, String dbName, String tableName, Connection conn) {
         List<IndexInfo> indexInfos = null;
         if (StringUtils.isNotBlank(key)) {
             indexInfos = indexInfoMap.get(key);
@@ -333,7 +355,7 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
             ResultSet rs;
             try {
                 DatabaseMetaData metadata = conn.getMetaData();
-                String catalog = StringUtils.isBlank(dbName) ? conn.getCatalog() : dbName;
+                String catalog = StringUtils.isBlank(catalogName) ? conn.getCatalog() : catalogName;
                 String schema = StringUtils.isBlank(dbName) ? conn.getSchema() : dbName;
                 rs = metadata.getIndexInfo(catalog, schema, SqlParserUtils.removeWrapperSymbol(tableName), false, true);
                 indexInfos = new ArrayList<>();

+ 11 - 2
mybatis-plus-jsqlparser-support/mybatis-plus-jsqlparser/src/test/java/com/baomidou/mybatisplus/test/extension/plugins/inner/IllegalSQLInnerInterceptorTest.java

@@ -78,8 +78,8 @@ class IllegalSQLInnerInterceptorTest {
         Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM `T_DEMO` a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
         Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM `T_DEMO` a INNER JOIN `test` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
 
-        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM test.`T_DEMO` a INNER JOIN test.`T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
-        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM test.`T_DEMO` a INNER JOIN test.`T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM PUBLIC.`T_DEMO` a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM PUBLIC.`T_DEMO` a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
 
         Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("SELECT * FROM T_DEMO a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.a = 1", dataSource.getConnection()));
         Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("SELECT * FROM T_DEMO a INNER JOIN `T_TEST` b ON a.a = b.a WHERE a.b = 1", dataSource.getConnection()));
@@ -120,4 +120,13 @@ class IllegalSQLInnerInterceptorTest {
         Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from (select * from `T_DEMO` where b = (SELECT b FROM T_TEST limit 1)) a ", dataSource.getConnection()));
     }
 
+    @Test
+    void testCatalogAndSchemaName() {
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from TEST.PUBLIC.T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
+        Assertions.assertDoesNotThrow(() -> interceptor.parserSingle("select count(*) from PUBLIC.T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
+        // 非同一模式,读不到索引的情况
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from DB.T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
+        Assertions.assertThrows(MybatisPlusException.class, () -> interceptor.parserSingle("select count(*) from PUBLIC.DB.T_DEMO where a = 1 and `b` = 2", dataSource.getConnection()));
+    }
+
 }