瀏覽代碼

!146 修复在mysql特殊语句on duplicate key update下把字段名判断为表名的问题
Merge pull request !146 from Normcorer/20210506_fixOnDuplicateKeyUpdate

青苗 4 年之前
父節點
當前提交
ed8a9c148f

+ 23 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/TableNameParser.java

@@ -47,6 +47,7 @@ public final class TableNameParser {
     private static final String KEYWORD_FROM = "from";
     private static final String KEYWORD_USING = "using";
     private static final String KEYWORD_UPDATE = "update";
+    private static final String KEYWORD_DUPLICATE = "duplicate";
 
     private static final List<String> concerned = Arrays.asList(KEYWORD_TABLE, KEYWORD_INTO, KEYWORD_JOIN, KEYWORD_USING, KEYWORD_UPDATE);
     private static final List<String> ignored = Arrays.asList(TOKEN_GROUP_START, TOKEN_SET, TOKEN_OF, TOKEN_DUAL);
@@ -96,6 +97,8 @@ public final class TableNameParser {
                 String current = tokens.get(index++).getValue();
                 if (isFromToken(current)) {
                     processFromToken(tokens, index, visitor);
+                } else if (isOnDuplicateKeyUpdate(current, index)) {
+                    index = skipDuplicateKeyUpdateIndex(index);
                 } else if (concerned.contains(current.toLowerCase())) {
                     if (hasMoreTokens(tokens, index)) {
                         SqlToken next = tokens.get(index++);
@@ -166,6 +169,21 @@ public final class TableNameParser {
         return false;
     }
 
+    /**
+     * @param current 当前token
+     * @param index   索引
+     * @return 判断是否是mysql的特殊语法 on duplicate key update
+     */
+    private boolean isOnDuplicateKeyUpdate(String current, int index) {
+        if (KEYWORD_DUPLICATE.equals(current.toLowerCase())) {
+            if (hasMoreTokens(tokens, index++)) {
+                String next = tokens.get(index).getValue();
+                return KEYWORD_UPDATE.equals(next.toLowerCase());
+            }
+        }
+        return false;
+    }
+
     private static boolean hasIthToken(List<SqlToken> tokens, int currentIndex) {
         return hasMoreTokens(tokens, currentIndex) && tokens.size() > currentIndex + 3;
     }
@@ -174,6 +192,11 @@ public final class TableNameParser {
         return KEYWORD_FROM.equals(currentToken.toLowerCase());
     }
 
+    private int skipDuplicateKeyUpdateIndex(int index) {
+        // on duplicate key update为mysql的固定写法,直接跳过即可。
+        return index + 2;
+    }
+
     private static void processFromToken(List<SqlToken> tokens, int index, TableNameVisitor visitor) {
         SqlToken sqlToken = tokens.get(index++);
         visitNameToken(sqlToken, visitor);

+ 6 - 0
mybatis-plus-core/src/test/java/com/baomidou/mybatisplus/test/toolkit/TableNameParserTest.java

@@ -486,6 +486,12 @@ public class TableNameParserTest {
         assertThat(new TableNameParser("select * from mp where id = 1 for update").tables()).isEqualTo(asSet("mp"));
     }
 
+    @Test
+    public void testOnDuplicateKeyUpdate () {
+        String sql = "INSERT INTO cf_procedure (_id,password) VALUES ('1','password') ON DUPLICATE KEY UPDATE id = 'UpId', password = 'upPassword';";
+        assertThat(new TableNameParser(sql).tables()).isEqualTo(asSet("cf_procedure"));
+    }
+
     private static Collection<String> asSet(String... a) {
         Set<String> result = new HashSet<>();
         Collections.addAll(result, a);