|
@@ -1,37 +1,10 @@
|
|
|
package com.baomidou.mybatisplus.extension.plugins;
|
|
|
|
|
|
-import java.sql.Connection;
|
|
|
-import java.sql.DatabaseMetaData;
|
|
|
-import java.sql.ResultSet;
|
|
|
-import java.sql.SQLException;
|
|
|
-import java.util.ArrayList;
|
|
|
-import java.util.HashSet;
|
|
|
-import java.util.List;
|
|
|
-import java.util.Map;
|
|
|
-import java.util.Objects;
|
|
|
-import java.util.Properties;
|
|
|
-import java.util.Set;
|
|
|
-import java.util.concurrent.ConcurrentHashMap;
|
|
|
-
|
|
|
-import org.apache.ibatis.executor.statement.StatementHandler;
|
|
|
-import org.apache.ibatis.logging.Log;
|
|
|
-import org.apache.ibatis.logging.LogFactory;
|
|
|
-import org.apache.ibatis.mapping.BoundSql;
|
|
|
-import org.apache.ibatis.mapping.MappedStatement;
|
|
|
-import org.apache.ibatis.mapping.SqlCommandType;
|
|
|
-import org.apache.ibatis.plugin.Interceptor;
|
|
|
-import org.apache.ibatis.plugin.Intercepts;
|
|
|
-import org.apache.ibatis.plugin.Invocation;
|
|
|
-import org.apache.ibatis.plugin.Plugin;
|
|
|
-import org.apache.ibatis.plugin.Signature;
|
|
|
-import org.apache.ibatis.reflection.MetaObject;
|
|
|
-import org.apache.ibatis.reflection.SystemMetaObject;
|
|
|
-
|
|
|
import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
|
|
|
import com.baomidou.mybatisplus.core.toolkit.EncryptUtils;
|
|
|
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
|
|
|
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
|
|
|
-
|
|
|
+import lombok.Data;
|
|
|
import net.sf.jsqlparser.expression.BinaryExpression;
|
|
|
import net.sf.jsqlparser.expression.Expression;
|
|
|
import net.sf.jsqlparser.expression.Function;
|
|
@@ -48,6 +21,22 @@ import net.sf.jsqlparser.statement.select.PlainSelect;
|
|
|
import net.sf.jsqlparser.statement.select.Select;
|
|
|
import net.sf.jsqlparser.statement.select.SubSelect;
|
|
|
import net.sf.jsqlparser.statement.update.Update;
|
|
|
+import org.apache.ibatis.executor.statement.StatementHandler;
|
|
|
+import org.apache.ibatis.logging.Log;
|
|
|
+import org.apache.ibatis.logging.LogFactory;
|
|
|
+import org.apache.ibatis.mapping.BoundSql;
|
|
|
+import org.apache.ibatis.mapping.MappedStatement;
|
|
|
+import org.apache.ibatis.mapping.SqlCommandType;
|
|
|
+import org.apache.ibatis.plugin.*;
|
|
|
+import org.apache.ibatis.reflection.MetaObject;
|
|
|
+import org.apache.ibatis.reflection.SystemMetaObject;
|
|
|
+
|
|
|
+import java.sql.Connection;
|
|
|
+import java.sql.DatabaseMetaData;
|
|
|
+import java.sql.ResultSet;
|
|
|
+import java.sql.SQLException;
|
|
|
+import java.util.*;
|
|
|
+import java.util.concurrent.ConcurrentHashMap;
|
|
|
|
|
|
/**
|
|
|
* @author willenfoo
|
|
@@ -90,56 +79,6 @@ public class IllegalSQLInterceptor implements Interceptor {
|
|
|
*/
|
|
|
private static Map<String, List<IndexInfo>> indexInfoMap = new ConcurrentHashMap<>();
|
|
|
|
|
|
- @Override
|
|
|
- public Object intercept(Invocation invocation) throws Throwable {
|
|
|
- StatementHandler statementHandler = (StatementHandler) PluginUtils.realTarget(invocation.getTarget());
|
|
|
- MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
|
|
|
- // 如果是insert操作,不进行验证
|
|
|
- MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
|
|
|
- if (SqlCommandType.INSERT.equals(mappedStatement.getSqlCommandType())) {
|
|
|
- return invocation.proceed();
|
|
|
- }
|
|
|
-
|
|
|
- BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
|
|
|
- String originalSql = boundSql.getSql();
|
|
|
- logger.debug("检查SQL是否合规,SQL:" + originalSql);
|
|
|
- String md5Base64 = EncryptUtils.md5Base64(originalSql);
|
|
|
- if (cacheValidResult.contains(md5Base64)) {
|
|
|
- logger.debug("该SQL已验证,无需再次验证,,SQL:" + originalSql);
|
|
|
- return invocation.proceed();
|
|
|
- }
|
|
|
- Connection connection = (Connection) invocation.getArgs()[0];
|
|
|
- Statement statement = CCJSqlParserUtil.parse(originalSql);
|
|
|
- Expression where = null;
|
|
|
- Table table = null;
|
|
|
- List<Join> joins = null;
|
|
|
- if (statement instanceof Select) {
|
|
|
- PlainSelect plainSelect = (PlainSelect) ((Select) statement).getSelectBody();
|
|
|
- where = plainSelect.getWhere();
|
|
|
- table = (Table) plainSelect.getFromItem();
|
|
|
- joins = plainSelect.getJoins();
|
|
|
- } else if (statement instanceof Update) {
|
|
|
- Update update = (Update) statement;
|
|
|
- where = update.getWhere();
|
|
|
- table = update.getTables().get(0);
|
|
|
- joins = update.getJoins();
|
|
|
- } else if (statement instanceof Delete) {
|
|
|
- Delete delete = (Delete) statement;
|
|
|
- where = delete.getWhere();
|
|
|
- table = delete.getTable();
|
|
|
- joins = delete.getJoins();
|
|
|
- }
|
|
|
- //where条件不能为空
|
|
|
- if (where == null) {
|
|
|
- throw new MybatisPlusException("非法SQL,必须要有where条件");
|
|
|
- }
|
|
|
- validWhere(where, table, connection);
|
|
|
- validJoins(joins, table, connection);
|
|
|
- //缓存验证结果
|
|
|
- cacheValidResult.add(md5Base64);
|
|
|
- return invocation.proceed();
|
|
|
- }
|
|
|
-
|
|
|
/**
|
|
|
* 验证expression对象是不是 or、not等等
|
|
|
*
|
|
@@ -285,42 +224,6 @@ public class IllegalSQLInterceptor implements Interceptor {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- /**
|
|
|
- * 索引对象
|
|
|
- */
|
|
|
- private static class IndexInfo {
|
|
|
-
|
|
|
- private String dbName;
|
|
|
-
|
|
|
- private String tableName;
|
|
|
-
|
|
|
- private String columnName;
|
|
|
-
|
|
|
- public String getDbName() {
|
|
|
- return dbName;
|
|
|
- }
|
|
|
-
|
|
|
- public void setDbName(String dbName) {
|
|
|
- this.dbName = dbName;
|
|
|
- }
|
|
|
-
|
|
|
- public String getTableName() {
|
|
|
- return tableName;
|
|
|
- }
|
|
|
-
|
|
|
- public void setTableName(String tableName) {
|
|
|
- this.tableName = tableName;
|
|
|
- }
|
|
|
-
|
|
|
- public String getColumnName() {
|
|
|
- return columnName;
|
|
|
- }
|
|
|
-
|
|
|
- public void setColumnName(String columnName) {
|
|
|
- this.columnName = columnName;
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
/**
|
|
|
* 得到表的索引信息
|
|
|
*
|
|
@@ -373,6 +276,55 @@ public class IllegalSQLInterceptor implements Interceptor {
|
|
|
return indexInfos;
|
|
|
}
|
|
|
|
|
|
+ @Override
|
|
|
+ public Object intercept(Invocation invocation) throws Throwable {
|
|
|
+ StatementHandler statementHandler = (StatementHandler) PluginUtils.realTarget(invocation.getTarget());
|
|
|
+ MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
|
|
|
+ // 如果是insert操作,不进行验证
|
|
|
+ MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
|
|
|
+ if (SqlCommandType.INSERT.equals(mappedStatement.getSqlCommandType())) {
|
|
|
+ return invocation.proceed();
|
|
|
+ }
|
|
|
+
|
|
|
+ BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
|
|
|
+ String originalSql = boundSql.getSql();
|
|
|
+ logger.debug("检查SQL是否合规,SQL:" + originalSql);
|
|
|
+ String md5Base64 = EncryptUtils.md5Base64(originalSql);
|
|
|
+ if (cacheValidResult.contains(md5Base64)) {
|
|
|
+ logger.debug("该SQL已验证,无需再次验证,,SQL:" + originalSql);
|
|
|
+ return invocation.proceed();
|
|
|
+ }
|
|
|
+ Connection connection = (Connection) invocation.getArgs()[0];
|
|
|
+ Statement statement = CCJSqlParserUtil.parse(originalSql);
|
|
|
+ Expression where = null;
|
|
|
+ Table table = null;
|
|
|
+ List<Join> joins = null;
|
|
|
+ if (statement instanceof Select) {
|
|
|
+ PlainSelect plainSelect = (PlainSelect) ((Select) statement).getSelectBody();
|
|
|
+ where = plainSelect.getWhere();
|
|
|
+ table = (Table) plainSelect.getFromItem();
|
|
|
+ joins = plainSelect.getJoins();
|
|
|
+ } else if (statement instanceof Update) {
|
|
|
+ Update update = (Update) statement;
|
|
|
+ where = update.getWhere();
|
|
|
+ table = update.getTables().get(0);
|
|
|
+ joins = update.getJoins();
|
|
|
+ } else if (statement instanceof Delete) {
|
|
|
+ Delete delete = (Delete) statement;
|
|
|
+ where = delete.getWhere();
|
|
|
+ table = delete.getTable();
|
|
|
+ joins = delete.getJoins();
|
|
|
+ }
|
|
|
+ //where条件不能为空
|
|
|
+ if (where == null) {
|
|
|
+ throw new MybatisPlusException("非法SQL,必须要有where条件");
|
|
|
+ }
|
|
|
+ validWhere(where, table, connection);
|
|
|
+ validJoins(joins, table, connection);
|
|
|
+ //缓存验证结果
|
|
|
+ cacheValidResult.add(md5Base64);
|
|
|
+ return invocation.proceed();
|
|
|
+ }
|
|
|
|
|
|
@Override
|
|
|
public Object plugin(Object target) {
|
|
@@ -387,4 +339,16 @@ public class IllegalSQLInterceptor implements Interceptor {
|
|
|
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * 索引对象
|
|
|
+ */
|
|
|
+ @Data
|
|
|
+ private static class IndexInfo {
|
|
|
+
|
|
|
+ private String dbName;
|
|
|
+
|
|
|
+ private String tableName;
|
|
|
+
|
|
|
+ private String columnName;
|
|
|
+ }
|
|
|
}
|