Explorar o código

!255 修复IllegalSQLInnerInterceptor类ClassCastException异常,并优化日志
Merge pull request !255 from uyong/3.0

青苗 %!s(int64=2) %!d(string=hai) anos
pai
achega
91bac93ea7

+ 19 - 9
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/IllegalSQLInnerInterceptor.java

@@ -35,6 +35,7 @@ import net.sf.jsqlparser.statement.delete.Delete;
 import net.sf.jsqlparser.statement.select.Join;
 import net.sf.jsqlparser.statement.select.PlainSelect;
 import net.sf.jsqlparser.statement.select.Select;
+import net.sf.jsqlparser.statement.select.SelectBody;
 import net.sf.jsqlparser.statement.select.SubSelect;
 import net.sf.jsqlparser.statement.update.Update;
 import org.apache.ibatis.executor.statement.StatementHandler;
@@ -46,7 +47,13 @@ import java.sql.Connection;
 import java.sql.DatabaseMetaData;
 import java.sql.ResultSet;
 import java.sql.SQLException;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 
 /**
@@ -108,13 +115,16 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
 
     @Override
     protected void processSelect(Select select, int index, String sql, Object obj) {
-        PlainSelect plainSelect = (PlainSelect) select.getSelectBody();
-        Expression where = plainSelect.getWhere();
-        Assert.notNull(where, "非法SQL,必须要有where条件");
-        Table table = (Table) plainSelect.getFromItem();
-        List<Join> joins = plainSelect.getJoins();
-        validWhere(where, table, (Connection) obj);
-        validJoins(joins, table, (Connection) obj);
+        SelectBody selectBody = select.getSelectBody();
+        if (selectBody instanceof PlainSelect) {
+            PlainSelect plainSelect = (PlainSelect) selectBody;
+            Expression where = plainSelect.getWhere();
+            Assert.notNull(where, "非法SQL,必须要有where条件");
+            Table table = (Table) plainSelect.getFromItem();
+            List<Join> joins = plainSelect.getJoins();
+            validWhere(where, table, (Connection) obj);
+            validJoins(joins, table, (Connection) obj);
+        }
     }
 
     @Override
@@ -329,7 +339,7 @@ public class IllegalSQLInnerInterceptor extends JsqlParserSupport implements Inn
                     indexInfoMap.put(key, indexInfos);
                 }
             } catch (SQLException e) {
-                e.printStackTrace();
+                logger.error(String.format("getIndexInfo fault, with key:%s, dbName:%s, tableName:%s", key, dbName, tableName), e);
             }
         }
         return indexInfos;