Browse Source

修复创建索引获取表错误.

nieqiurong 1 tháng trước cách đây
mục cha
commit
df9239e5bc

+ 50 - 17
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/TableNameParser.java

@@ -22,6 +22,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.HashSet;
 import java.util.List;
 import java.util.List;
 import java.util.Map;
 import java.util.Map;
+import java.util.Set;
 import java.util.regex.Matcher;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 import java.util.regex.Pattern;
 
 
@@ -59,6 +60,13 @@ public final class TableNameParser {
     private static final List<String> concerned = Arrays.asList(KEYWORD_TABLE, KEYWORD_INTO, KEYWORD_JOIN, KEYWORD_USING, KEYWORD_UPDATE, KEYWORD_STRAIGHT_JOIN);
     private static final List<String> concerned = Arrays.asList(KEYWORD_TABLE, KEYWORD_INTO, KEYWORD_JOIN, KEYWORD_USING, KEYWORD_UPDATE, KEYWORD_STRAIGHT_JOIN);
     private static final List<String> ignored = Arrays.asList(StringPool.LEFT_BRACKET, TOKEN_SET, TOKEN_OF, TOKEN_DUAL);
     private static final List<String> ignored = Arrays.asList(StringPool.LEFT_BRACKET, TOKEN_SET, TOKEN_OF, TOKEN_DUAL);
 
 
+    /**
+     * 索引类型
+     *
+     * @since 3.5.11
+     */
+    private static final Set<String> INDEX_TYPES = new HashSet<>(Arrays.asList("UNIQUE", "FULLTEXT", "SPATIAL", "CLUSTERED", "NONCLUSTERED"));
+
     /**
     /**
      * 该表达式会匹配 SQL 中不是 SQL TOKEN 的部分,比如换行符,注释信息,结尾的 {@code ;} 等。
      * 该表达式会匹配 SQL 中不是 SQL TOKEN 的部分,比如换行符,注释信息,结尾的 {@code ;} 等。
      * <p>
      * <p>
@@ -96,11 +104,16 @@ public final class TableNameParser {
         int index = 0;
         int index = 0;
         String first = tokens.get(index).getValue();
         String first = tokens.get(index).getValue();
         if (isOracleSpecialDelete(first, tokens, index)) {
         if (isOracleSpecialDelete(first, tokens, index)) {
-            visitNameToken(tokens.get(index + 1), visitor);
+            visitNameToken(safeGetToken(index + 1), visitor);
         } else if (isCreateIndex(first, tokens, index)) {
         } else if (isCreateIndex(first, tokens, index)) {
-            visitNameToken(tokens.get(index + 4), visitor);
+            String value = tokens.get(index + 4).getValue();
+            if("ON".equalsIgnoreCase(value)) {
+                visitNameToken(safeGetToken(index + 5), visitor);
+            } else {
+                visitNameToken(safeGetToken(index + 4), visitor);
+            }
         } else if (isCreateTableIfNotExist(first, tokens, index)) {
         } else if (isCreateTableIfNotExist(first, tokens, index)) {
-            visitNameToken(tokens.get(index + 5), visitor);
+            visitNameToken(safeGetToken(index + 5), visitor);
         } else {
         } else {
             while (hasMoreTokens(tokens, index)) {
             while (hasMoreTokens(tokens, index)) {
                 String current = tokens.get(index++).getValue();
                 String current = tokens.get(index++).getValue();
@@ -122,6 +135,17 @@ public final class TableNameParser {
         }
         }
     }
     }
 
 
+    /**
+     * 安全访问获取SqlToken
+     *
+     * @param index 索引
+     * @return 超出索引返回 null,否则返回SqlToken
+     * @since 3.5.11
+     */
+    private SqlToken safeGetToken(int index) {
+        return index < tokens.size() ? tokens.get(index) : null;
+    }
+
     /**
     /**
      * 表名访问器
      * 表名访问器
      */
      */
@@ -173,23 +197,27 @@ public final class TableNameParser {
         return false;
         return false;
     }
     }
 
 
+    // CREATE INDEX temp_name_idx ON table1(name) NOLOGGING PARALLEL (DEGREE 8);
+    // CREATE FULLTEXT INDEX ft_users_content ON users(content);
     private boolean isCreateIndex(String current, List<SqlToken> tokens, int index) {
     private boolean isCreateIndex(String current, List<SqlToken> tokens, int index) {
-        index++; // Point to next token
-        if (TOKEN_CREATE.equalsIgnoreCase(current) && hasIthToken(tokens, index)) {
-            String next = tokens.get(index).getValue();
+        if (TOKEN_CREATE.equalsIgnoreCase(current) && hasMoreTokens(tokens, index + 4)) {
+            String next = tokens.get(index + 1).getValue();
+            if (INDEX_TYPES.contains(next.toUpperCase())) {
+                next = tokens.get(index + 2).getValue();
+            }
             return TOKEN_INDEX.equalsIgnoreCase(next);
             return TOKEN_INDEX.equalsIgnoreCase(next);
         }
         }
         return false;
         return false;
     }
     }
-    
+
+    //create table if not exists `user_info`
     private boolean isCreateTableIfNotExist(String current, List<SqlToken> tokens, int index) {
     private boolean isCreateTableIfNotExist(String current, List<SqlToken> tokens, int index) {
-        index++; // Point to next token
-        if (TOKEN_CREATE.equalsIgnoreCase(current) && hasIthToken(tokens, index)) {
-            String tableIfNotExist = tokens.get(index).getValue();
-            tableIfNotExist += tokens.get(++index).getValue();
-            tableIfNotExist += tokens.get(++index).getValue();
-            tableIfNotExist += tokens.get(++index).getValue();
-            return "tableifnotexists".equalsIgnoreCase(tableIfNotExist);
+        if (TOKEN_CREATE.equalsIgnoreCase(current) && hasMoreTokens(tokens, index + 5)) {
+            StringBuilder tableIfNotExist = new StringBuilder();
+            for (int i = index; i <= index + 4; i++) {
+                tableIfNotExist.append(tokens.get(i).getValue());
+            }
+            return "createtableifnotexists".equalsIgnoreCase(tableIfNotExist.toString());
         }
         }
         return false;
         return false;
     }
     }
@@ -209,6 +237,9 @@ public final class TableNameParser {
         return false;
         return false;
     }
     }
 
 
+    /**
+     * @deprecated 3.5.11 建议使用 {@link #hasMoreTokens(List, int)} 判断
+     */
     private static boolean hasIthToken(List<SqlToken> tokens, int currentIndex) {
     private static boolean hasIthToken(List<SqlToken> tokens, int currentIndex) {
         return hasMoreTokens(tokens, currentIndex) && tokens.size() > currentIndex + 3;
         return hasMoreTokens(tokens, currentIndex) && tokens.size() > currentIndex + 3;
     }
     }
@@ -280,9 +311,11 @@ public final class TableNameParser {
     }
     }
 
 
     private static void visitNameToken(SqlToken token, TableNameVisitor visitor) {
     private static void visitNameToken(SqlToken token, TableNameVisitor visitor) {
-        String value = token.getValue().toLowerCase();
-        if (!ignored.contains(value)) {
-            visitor.visit(token);
+        if (token != null) {
+            String value = token.getValue().toLowerCase();
+            if (!ignored.contains(value)) {
+                visitor.visit(token);
+            }
         }
         }
     }
     }
 
 

+ 28 - 0
mybatis-plus-core/src/test/java/com/baomidou/mybatisplus/test/toolkit/TableNameParserTest.java

@@ -492,6 +492,34 @@ public class TableNameParserTest {
         assertThat(new TableNameParser(sql).tables()).isEqualTo(asSet("student"));
         assertThat(new TableNameParser(sql).tables()).isEqualTo(asSet("student"));
     }
     }
 
 
+    @Test
+    void testCreateTableIfNotExists() {
+        var sql = """
+            CREATE TABLE IF NOT EXISTS `user_info` (
+                `id` INT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
+                `username` VARCHAR(50) NOT NULL UNIQUE,
+                `email` VARCHAR(100) NOT NULL UNIQUE,
+                `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
+            """;
+        assertThat(new TableNameParser(sql).tables()).isEqualTo(asSet("`user_info`"));
+    }
+
+    @Test
+    void testCreateUniqueIndex() {
+        var sql = "CREATE UNIQUE INDEX index_name ON table1 (a, b)";
+        assertThat(new TableNameParser(sql).tables()).isEqualTo(asSet("table1"));
+        sql = "ALTER TABLE table1 ADD UNIQUE INDEX `a`(`a`)";
+        assertThat(new TableNameParser(sql).tables()).isEqualTo(asSet("table1"));
+    }
+
+    @Test
+    void testCreateFullTextIndex(){
+        var sql = "CREATE FULLTEXT INDEX index_name ON table1 (a, b)";
+        assertThat(new TableNameParser(sql).tables()).isEqualTo(asSet("table1"));
+        sql = "ALTER TABLE table1 ADD FULLTEXT INDEX `a`(`a`)";
+        assertThat(new TableNameParser(sql).tables()).isEqualTo(asSet("table1"));
+    }
 
 
 
 
     private static Collection<String> asSet(String... a) {
     private static Collection<String> asSet(String... a) {