浏览代码

to qiuqiu

miemie 5 年之前
父节点
当前提交
d9656ec118

+ 35 - 8
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/PluginUtils.java

@@ -15,11 +15,13 @@
  */
 package com.baomidou.mybatisplus.core.toolkit;
 
+import org.apache.ibatis.mapping.BoundSql;
 import org.apache.ibatis.reflection.MetaObject;
 import org.apache.ibatis.reflection.SystemMetaObject;
 
+import java.lang.reflect.Field;
 import java.lang.reflect.Proxy;
-import java.util.Properties;
+import java.util.Map;
 
 /**
  * 插件工具类
@@ -27,11 +29,19 @@ import java.util.Properties;
  * @author TaoYu , hubin
  * @since 2017-06-20
  */
-public final class PluginUtils {
+public abstract class PluginUtils {
     public static final String DELEGATE_BOUNDSQL_SQL = "delegate.boundSql.sql";
 
-    private PluginUtils() {
-        // to do nothing
+    private final static Field additionalParametersField = initBoundSqlAdditionalParametersField();
+
+    private static Field initBoundSqlAdditionalParametersField() {
+        try {
+            Field field = BoundSql.class.getDeclaredField("additionalParameters");
+            field.setAccessible(true);
+            return field;
+        } catch (NoSuchFieldException e) {
+            throw ExceptionUtils.mpe("can not find field['additionalParameters'] from BoundSql, why?", e);
+        }
     }
 
     /**
@@ -47,10 +57,27 @@ public final class PluginUtils {
     }
 
     /**
-     * 根据 key 获取 Properties 的值
+     * 获取 BoundSql 属性值 additionalParameters
+     *
+     * @param boundSql BoundSql
+     * @return additionalParameters
+     */
+    @SuppressWarnings("unchecked")
+    public static Map<String, Object> getAdditionalParameter(BoundSql boundSql) {
+        try {
+            return (Map<String, Object>) additionalParametersField.get(boundSql);
+        } catch (IllegalAccessException e) {
+            throw ExceptionUtils.mpe("获取 BoundSql 属性值 additionalParameters 失败: " + e, e);
+        }
+    }
+
+    /**
+     * 给 BoundSql 设置 additionalParameters
+     *
+     * @param boundSql             BoundSql
+     * @param additionalParameters additionalParameters
      */
-    public static String getProperty(Properties properties, String key) {
-        String value = properties.getProperty(key);
-        return StringUtils.isBlank(value) ? null : value;
+    public static void setAdditionalParameter(BoundSql boundSql, Map<String, Object> additionalParameters) {
+        additionalParameters.forEach(boundSql::setAdditionalParameter);
     }
 }

+ 3 - 2
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/chain/PageBeforeQuery.java

@@ -7,6 +7,7 @@ import com.baomidou.mybatisplus.core.parser.ISqlParser;
 import com.baomidou.mybatisplus.core.parser.SqlInfo;
 import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.core.toolkit.ParameterUtils;
+import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
 import com.baomidou.mybatisplus.core.toolkit.StringUtils;
 import com.baomidou.mybatisplus.extension.plugins.pagination.DialectFactory;
 import com.baomidou.mybatisplus.extension.plugins.pagination.DialectModel;
@@ -94,13 +95,13 @@ public class PageBeforeQuery implements BeforeQuery {
                 return boundSql;
             }
         }
-        DbType dbType = this.dbType == null ? JdbcUtils.getDbType(JdbcUtils.getJdbcUrl(ms)) : this.dbType;
+        DbType dbType = this.dbType == null ? JdbcUtils.getDbType(ms) : this.dbType;
         IDialect dialect = Optional.ofNullable(this.dialect).orElseGet(() -> DialectFactory.getDialect(dbType));
         String buildSql = concatOrderBy(originalSql, page);
         DialectModel model = dialect.buildPaginationSql(buildSql, page.offset(), page.getSize());
         final Configuration configuration = ms.getConfiguration();
         List<ParameterMapping> mappings = new ArrayList<>(boundSql.getParameterMappings());
-        Map<String, Object> additionalParameter = JdbcUtils.getAdditionalParameter(boundSql);
+        Map<String, Object> additionalParameter = PluginUtils.getAdditionalParameter(boundSql);
         model.consumers(mappings, configuration, additionalParameter);
         boundSql = new BoundSql(configuration, model.getDialectSql(), mappings, parameter);
         for (Map.Entry<String, Object> entry : additionalParameter.entrySet()) {

+ 14 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/pagination/count/CountSqlParser.java

@@ -0,0 +1,14 @@
+package com.baomidou.mybatisplus.extension.plugins.pagination.count;
+
+/**
+ * @author miemie
+ * @since 2020-06-16
+ */
+public interface CountSqlParser {
+
+    String parser(String sql);
+
+    default String defaultCount(String sql) {
+        return String.format("SELECT COUNT(1) FROM ( %s ) TOTAL", sql);
+    }
+}

+ 105 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/pagination/count/JsqlParserCountSqlParser.java

@@ -0,0 +1,105 @@
+package com.baomidou.mybatisplus.extension.plugins.pagination.count;
+
+import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
+import com.baomidou.mybatisplus.core.toolkit.StringPool;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import net.sf.jsqlparser.expression.Alias;
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.Function;
+import net.sf.jsqlparser.expression.LongValue;
+import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
+import net.sf.jsqlparser.parser.CCJSqlParserUtil;
+import net.sf.jsqlparser.schema.Table;
+import net.sf.jsqlparser.statement.select.*;
+import org.apache.ibatis.logging.Log;
+import org.apache.ibatis.logging.LogFactory;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * @author miemie
+ * @since 2020-06-16
+ */
+@Data
+@NoArgsConstructor
+@AllArgsConstructor
+public class JsqlParserCountSqlParser implements CountSqlParser {
+
+    private static final List<SelectItem> COUNT_SELECT_ITEM = Collections.singletonList(defaultCountSelectItem());
+    private final Log logger = LogFactory.getLog(JsqlParserCountSqlParser.class);
+
+    private boolean optimizeJoin = false;
+
+    /**
+     * 获取jsqlparser中count的SelectItem
+     */
+    private static SelectItem defaultCountSelectItem() {
+        Function function = new Function();
+        ExpressionList expressionList = new ExpressionList(Collections.singletonList(new LongValue(1)));
+        function.setName("COUNT");
+        function.setParameters(expressionList);
+        return new SelectExpressionItem(function);
+    }
+
+    @Override
+    public String parser(String originalSql) {
+        if (logger.isDebugEnabled()) {
+            logger.debug("JsqlParserCountOptimize sql=" + originalSql);
+        }
+        try {
+            Select selectStatement = (Select) CCJSqlParserUtil.parse(originalSql);
+            PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody();
+            Distinct distinct = plainSelect.getDistinct();
+            GroupByElement groupBy = plainSelect.getGroupBy();
+            List<OrderByElement> orderBy = plainSelect.getOrderByElements();
+
+            // 添加包含groupBy 不去除orderBy
+            if (null == groupBy && CollectionUtils.isNotEmpty(orderBy)) {
+                plainSelect.setOrderByElements(null);
+            }
+            //#95 Github, selectItems contains #{} ${}, which will be translated to ?, and it may be in a function: power(#{myInt},2)
+            for (SelectItem item : plainSelect.getSelectItems()) {
+                if (item.toString().contains(StringPool.QUESTION_MARK)) {
+                    return defaultCount(selectStatement.toString());
+                }
+            }
+            // 包含 distinct、groupBy不优化
+            if (distinct != null || null != groupBy) {
+                return defaultCount(selectStatement.toString());
+            }
+            // 包含 join 连表,进行判断是否移除 join 连表
+            List<Join> joins = plainSelect.getJoins();
+            if (optimizeJoin && CollectionUtils.isNotEmpty(joins)) {
+                boolean canRemoveJoin = true;
+                String whereS = Optional.ofNullable(plainSelect.getWhere()).map(Expression::toString).orElse(StringPool.EMPTY);
+                for (Join join : joins) {
+                    if (!join.isLeft()) {
+                        canRemoveJoin = false;
+                        break;
+                    }
+                    Table table = (Table) join.getRightItem();
+                    String str = Optional.ofNullable(table.getAlias()).map(Alias::getName).orElse(table.getName()) + StringPool.DOT;
+                    String onExpressionS = join.getOnExpression().toString();
+                    /* 如果 join 里包含 ?(代表有入参) 或者 where 条件里包含使用 join 的表的字段作条件,就不移除 join */
+                    if (onExpressionS.contains(StringPool.QUESTION_MARK) || whereS.contains(str)) {
+                        canRemoveJoin = false;
+                        break;
+                    }
+                }
+                if (canRemoveJoin) {
+                    plainSelect.setJoins(null);
+                }
+            }
+            // 优化 SQL
+            plainSelect.setSelectItems(COUNT_SELECT_ITEM);
+            return defaultCount(selectStatement.toString());
+        } catch (Throwable e) {
+            // 无法优化使用原 SQL
+            return defaultCount(originalSql);
+        }
+    }
+}

+ 2 - 31
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/toolkit/JdbcUtils.java

@@ -21,14 +21,11 @@ import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
 import com.baomidou.mybatisplus.core.toolkit.StringUtils;
 import org.apache.ibatis.logging.Log;
 import org.apache.ibatis.logging.LogFactory;
-import org.apache.ibatis.mapping.BoundSql;
 import org.apache.ibatis.mapping.MappedStatement;
 
 import javax.sql.DataSource;
-import java.lang.reflect.Field;
 import java.sql.Connection;
 import java.sql.SQLException;
-import java.util.Map;
 
 /**
  * JDBC 工具类
@@ -40,36 +37,10 @@ public class JdbcUtils {
 
     private static final Log logger = LogFactory.getLog(JdbcUtils.class);
 
-    private static Field additionalParametersField;
-
-    static {
-        try {
-            additionalParametersField = BoundSql.class.getDeclaredField("additionalParameters");
-            additionalParametersField.setAccessible(true);
-        } catch (NoSuchFieldException e) {
-            throw ExceptionUtils.mpe("获取 BoundSql 属性 additionalParameters 失败: " + e, e);
-        }
-    }
-
-    /**
-     * 获取 BoundSql 属性值 additionalParameters
-     *
-     * @param boundSql
-     * @return
-     */
-    @SuppressWarnings("unchecked")
-    public static Map<String, Object> getAdditionalParameter(BoundSql boundSql) {
-        try {
-            return (Map<String, Object>) additionalParametersField.get(boundSql);
-        } catch (IllegalAccessException e) {
-            throw ExceptionUtils.mpe("获取 BoundSql 属性值 additionalParameters 失败: " + e, e);
-        }
-    }
-
-    public static String getJdbcUrl(MappedStatement ms) {
+    public static DbType getDbType(MappedStatement ms) {
         DataSource dataSource = ms.getConfiguration().getEnvironment().getDataSource();
         try (Connection conn = dataSource.getConnection()) {
-            return conn.getMetaData().getURL();
+            return getDbType(conn.getMetaData().getURL());
         } catch (SQLException e) {
             throw ExceptionUtils.mpe(e);
         }