浏览代码

sql 解析抽象封装优化

jobob 8 年之前
父节点
当前提交
6c9e538c3f

+ 16 - 11
src/main/java/com/baomidou/mybatisplus/plugins/CachePaginationInterceptor.java

@@ -32,9 +32,10 @@ import org.apache.ibatis.reflection.SystemMetaObject;
 import org.apache.ibatis.session.ResultHandler;
 import org.apache.ibatis.session.RowBounds;
 
-import com.baomidou.mybatisplus.entity.CountOptimize;
 import com.baomidou.mybatisplus.plugins.pagination.DialectFactory;
 import com.baomidou.mybatisplus.plugins.pagination.Pagination;
+import com.baomidou.mybatisplus.plugins.parser.AbstractSqlParser;
+import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 import com.baomidou.mybatisplus.toolkit.PluginUtils;
 import com.baomidou.mybatisplus.toolkit.SqlUtils;
 import com.baomidou.mybatisplus.toolkit.StringUtils;
@@ -47,12 +48,11 @@ import com.baomidou.mybatisplus.toolkit.StringUtils;
  * @author hubin
  * @Date 2016-01-23
  */
-@Intercepts({
-        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class,
-                ResultHandler.class}),
+@Intercepts({@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
         @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
 public class CachePaginationInterceptor extends PaginationInterceptor implements Interceptor {
-
+    // COUNT SQL 解析
+    private AbstractSqlParser sqlParser;
     /* Count优化方式 */
     private String optimizeType = "default";
     /* 方言类型 */
@@ -82,9 +82,11 @@ public class CachePaginationInterceptor extends PaginationInterceptor implements
                 Pagination page = (Pagination) rowBounds;
                 boolean orderBy = true;
                 if (page.isSearchCount()) {
-                    CountOptimize countOptimize = SqlUtils.getCountOptimize(originalSql, optimizeType, dialectType,
-                            page.isOptimizeCount());
-                    orderBy = countOptimize.isOrderBy();
+                    String tempSql = originalSql.replaceAll("(?i)ORDER[\\s]+BY", "ORDER BY");
+                    int orderByIndex = tempSql.toUpperCase().lastIndexOf("ORDER BY");
+                    if(orderByIndex <= -1) {
+                        orderBy = false;
+                    }
                 }
                 String buildSql = SqlUtils.concatOrderBy(originalSql, page, orderBy);
                 originalSql = DialectFactory.buildPaginationSql(page, buildSql, dialectType, dialectClazz);
@@ -110,9 +112,9 @@ public class CachePaginationInterceptor extends PaginationInterceptor implements
             if (rowBounds instanceof Pagination) {
                 Pagination page = (Pagination) rowBounds;
                 if (page.isSearchCount()) {
-                    CountOptimize countOptimize = SqlUtils.getCountOptimize(originalSql, optimizeType, dialectType,
-                            page.isOptimizeCount());
-                    super.queryTotal(countOptimize.getCountSQL(), mappedStatement, boundSql, page, connection);
+                    SqlInfo sqlInfo = SqlUtils.getCountOptimize(sqlParser, originalSql, optimizeType,
+                            dialectType, page.isOptimizeCount());
+                    super.queryTotal(sqlInfo.getSql(), mappedStatement, boundSql, page, connection);
                     if (page.getTotal() <= 0) {
                         return invocation.proceed();
                     }
@@ -151,4 +153,7 @@ public class CachePaginationInterceptor extends PaginationInterceptor implements
         this.optimizeType = optimizeType;
     }
 
+    public void setSqlParser(AbstractSqlParser sqlParser) {
+        this.sqlParser = sqlParser;
+    }
 }

+ 13 - 6
src/main/java/com/baomidou/mybatisplus/plugins/PaginationInterceptor.java

@@ -37,9 +37,10 @@ import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
 import org.apache.ibatis.session.RowBounds;
 
 import com.baomidou.mybatisplus.MybatisDefaultParameterHandler;
-import com.baomidou.mybatisplus.entity.CountOptimize;
 import com.baomidou.mybatisplus.plugins.pagination.DialectFactory;
 import com.baomidou.mybatisplus.plugins.pagination.Pagination;
+import com.baomidou.mybatisplus.plugins.parser.AbstractSqlParser;
+import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 import com.baomidou.mybatisplus.toolkit.JdbcUtils;
 import com.baomidou.mybatisplus.toolkit.PluginUtils;
 import com.baomidou.mybatisplus.toolkit.SqlUtils;
@@ -55,9 +56,10 @@ import com.baomidou.mybatisplus.toolkit.StringUtils;
  */
 @Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
 public class PaginationInterceptor implements Interceptor {
-
+    // 日志
     private static final Log logger = LogFactory.getLog(PaginationInterceptor.class);
-
+    // COUNT SQL 解析
+    private AbstractSqlParser sqlParser;
     /* 溢出总页数,设置第一页 */
     private boolean overflowCurrent = false;
     /* 是否设置动态数据源 设置之后动态获取当前数据源 */
@@ -95,9 +97,10 @@ public class PaginationInterceptor implements Interceptor {
             Pagination page = (Pagination) rowBounds;
             boolean orderBy = true;
             if (page.isSearchCount()) {
-                CountOptimize countOptimize = SqlUtils.getCountOptimize(originalSql, optimizeType, dialectType, page.isOptimizeCount());
-                orderBy = countOptimize.isOrderBy();
-                this.queryTotal(countOptimize.getCountSQL(), mappedStatement, boundSql, page, connection);
+                SqlInfo sqlInfo = SqlUtils.getCountOptimize(sqlParser, originalSql, optimizeType,
+                        dialectType, page.isOptimizeCount());
+                orderBy = sqlInfo.isOrderBy();
+                this.queryTotal(sqlInfo.getSql(), mappedStatement, boundSql, page, connection);
                 if (page.getTotal() <= 0) {
                     return invocation.proceed();
                 }
@@ -192,4 +195,8 @@ public class PaginationInterceptor implements Interceptor {
     public void setDynamicDataSource(boolean dynamicDataSource) {
         this.dynamicDataSource = dynamicDataSource;
     }
+
+    public void setSqlParser(AbstractSqlParser sqlParser) {
+        this.sqlParser = sqlParser;
+    }
 }

+ 57 - 0
src/main/java/com/baomidou/mybatisplus/plugins/TenancyInterceptor.java

@@ -0,0 +1,57 @@
+/**
+ * Copyright (c) 2011-2020, hubin (jobob@qq.com).
+ * <p>
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.baomidou.mybatisplus.plugins;
+
+import java.util.Properties;
+
+import org.apache.ibatis.executor.Executor;
+import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.plugin.Interceptor;
+import org.apache.ibatis.plugin.Intercepts;
+import org.apache.ibatis.plugin.Invocation;
+import org.apache.ibatis.plugin.Signature;
+import org.apache.ibatis.session.ResultHandler;
+import org.apache.ibatis.session.RowBounds;
+
+/**
+ * <p>
+ * 租户拦截器,解决 SAAS 共享数据库租户场景
+ * </p>
+ *
+ * @author hubin
+ * @Date 2016-08-16
+ */
+@Intercepts({
+        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
+        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})
+})
+public class TenancyInterceptor implements Interceptor {
+
+    @Override
+    public Object intercept(Invocation invocation) throws Throwable {
+        return null;
+    }
+
+    @Override
+    public Object plugin(Object target) {
+        return null;
+    }
+
+    @Override
+    public void setProperties(Properties properties) {
+
+    }
+}

+ 51 - 0
src/main/java/com/baomidou/mybatisplus/plugins/pagination/optimize/AliDruidCountOptimize.java

@@ -0,0 +1,51 @@
+/**
+ * Copyright (c) 2011-2014, hubin (jobob@qq.com).
+ * <p>
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.baomidou.mybatisplus.plugins.pagination.optimize;
+
+import com.alibaba.druid.sql.PagerUtils;
+import com.baomidou.mybatisplus.plugins.parser.AbstractSqlParser;
+import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
+
+/**
+ * <p>
+ * Ali Druid Count Optimize
+ * </p>
+ *
+ * @author hubin
+ * @Date 2017-06-20
+ */
+public class AliDruidCountOptimize extends AbstractSqlParser {
+
+    public AliDruidCountOptimize(String sql, String dbType) {
+        super(sql, dbType);
+    }
+
+    @Override
+    public SqlInfo optimizeSql() {
+        String sql = this.getSql();
+        String dbType = this.getDbType();
+        if (logger.isDebugEnabled()) {
+            logger.debug(" AliDruidCountOptimize sql=" + sql + ", dbType=" + dbType);
+        }
+        SqlInfo sqlInfo = SqlInfo.newInstance();
+        // 调整SQL便于解析
+        String tempSql = sql.replaceAll("(?i)ORDER[\\s]+BY", "ORDER BY");
+        int orderByIndex = tempSql.toUpperCase().lastIndexOf("ORDER BY");
+        sqlInfo.setOrderBy(orderByIndex > -1);
+        sqlInfo.setSql(PagerUtils.count(sql, dbType));
+        return sqlInfo;
+    }
+}

+ 75 - 0
src/main/java/com/baomidou/mybatisplus/plugins/pagination/optimize/DefaultCountOptimize.java

@@ -0,0 +1,75 @@
+/**
+ * Copyright (c) 2011-2014, hubin (jobob@qq.com).
+ * <p>
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.baomidou.mybatisplus.plugins.pagination.optimize;
+
+import com.baomidou.mybatisplus.plugins.parser.AbstractSqlParser;
+import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
+
+/**
+ * <p>
+ * Default Count Optimize
+ * </p>
+ *
+ * @author hubin
+ * @Date 2017-06-20
+ */
+public class DefaultCountOptimize extends AbstractSqlParser {
+
+    public DefaultCountOptimize(String sql, String dbType) {
+        super(sql, dbType);
+    }
+
+    @Override
+    public SqlInfo optimizeSql() {
+        String sql = this.getSql();
+        String dbType = this.getDbType();
+        if (logger.isDebugEnabled()) {
+            logger.debug(" DefaultCountOptimize sql=" + sql + ", dbType=" + dbType);
+        }
+        SqlInfo sqlInfo = SqlInfo.newInstance();
+        // 调整SQL便于解析
+        String tempSql = sql.replaceAll("(?i)ORDER[\\s]+BY", "ORDER BY").replaceAll("(?i)GROUP[\\s]+BY", "GROUP BY");
+        String indexOfSql = tempSql.toUpperCase();
+        // 有排序情况
+        int orderByIndex = indexOfSql.lastIndexOf("ORDER BY");
+        // 只针对 ALI_DRUID DEFAULT 这2种情况
+        if (orderByIndex > -1) {
+            sqlInfo.setOrderBy(false);
+        }
+        StringBuilder countSql = new StringBuilder("SELECT COUNT(1) ");
+        boolean optimize = false;
+        if (!indexOfSql.contains("DISTINCT") && !indexOfSql.contains("GROUP BY")) {
+            int formIndex = indexOfSql.indexOf("FROM");
+            if (formIndex > -1) {
+                if (orderByIndex > -1) {
+                    tempSql = tempSql.substring(0, orderByIndex);
+                    countSql.append(tempSql.substring(formIndex));
+                    // 无排序情况
+                } else {
+                    countSql.append(tempSql.substring(formIndex));
+                }
+                // 执行优化
+                optimize = true;
+            }
+        }
+        if (!optimize) {
+            // 无优化SQL
+            countSql.append("FROM ( ").append(sql).append(" ) TOTAL");
+        }
+        sqlInfo.setSql(countSql.toString());
+        return sqlInfo;
+    }
+}

+ 32 - 25
src/main/java/com/baomidou/mybatisplus/toolkit/JsqlParserUtils.java → src/main/java/com/baomidou/mybatisplus/plugins/pagination/optimize/JsqlParserCountOptimize.java

@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2011-2020, hubin (jobob@qq.com).
+ * Copyright (c) 2011-2014, hubin (jobob@qq.com).
  * <p>
  * Licensed under the Apache License, Version 2.0 (the "License"); you may not
  * use this file except in compliance with the License. You may obtain a copy of
@@ -13,12 +13,15 @@
  * License for the specific language governing permissions and limitations under
  * the License.
  */
-package com.baomidou.mybatisplus.toolkit;
+package com.baomidou.mybatisplus.plugins.pagination.optimize;
 
 import java.util.ArrayList;
 import java.util.List;
 
-import com.baomidou.mybatisplus.entity.CountOptimize;
+import com.baomidou.mybatisplus.plugins.parser.AbstractSqlParser;
+import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
+import com.baomidou.mybatisplus.toolkit.CollectionUtils;
+import com.baomidou.mybatisplus.toolkit.SqlUtils;
 
 import net.sf.jsqlparser.expression.Expression;
 import net.sf.jsqlparser.expression.Function;
@@ -34,27 +37,30 @@ import net.sf.jsqlparser.statement.select.SelectItem;
 
 /**
  * <p>
- * JsqlParserUtils工具类
+ * JsqlParser Count Optimize
  * </p>
  *
- * @author Caratacus
- * @Date 2016-11-30
+ * @author hubin
+ * @Date 2017-06-20
  */
-public class JsqlParserUtils {
-
+public class JsqlParserCountOptimize extends AbstractSqlParser {
     private static final List<SelectItem> countSelectItem = countSelectItem();
 
-    /**
-     * <p>
-     * jsqlparser方式获取select的count语句
-     * </p>
-     *
-     * @param originalSql selectSQL
-     * @return
-     */
-    public static CountOptimize jsqlparserCount(CountOptimize countOptimize, String originalSql) {
+    public JsqlParserCountOptimize(String sql, String dbType) {
+        super(sql, dbType);
+    }
+
+
+    @Override
+    public SqlInfo optimizeSql() {
+        String sql = this.getSql();
+        String dbType = this.getDbType();
+        if (logger.isDebugEnabled()) {
+            logger.debug(" JsqlParserCountOptimize sql=" + sql + ", dbType=" + dbType);
+        }
+        SqlInfo sqlInfo = SqlInfo.newInstance();
         try {
-            Select selectStatement = (Select) CCJSqlParserUtil.parse(originalSql);
+            Select selectStatement = (Select) CCJSqlParserUtil.parse(sql);
             PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody();
             Distinct distinct = plainSelect.getDistinct();
             List<Expression> groupBy = plainSelect.getGroupByColumnReferences();
@@ -63,26 +69,27 @@ public class JsqlParserUtils {
             // 添加包含groupBy 不去除orderBy
             if (CollectionUtils.isEmpty(groupBy) && CollectionUtils.isNotEmpty(orderBy)) {
                 plainSelect.setOrderByElements(null);
-                countOptimize.setOrderBy(false);
+                sqlInfo.setOrderBy(false);
             }
 
             // 包含 distinct、groupBy不优化
             if (distinct != null || CollectionUtils.isNotEmpty(groupBy)) {
-                countOptimize.setCountSQL(String.format(SqlUtils.SQL_BASE_COUNT, selectStatement.toString()));
-                return countOptimize;
+                sqlInfo.setSql(String.format(SqlUtils.SQL_BASE_COUNT, selectStatement.toString()));
+                return sqlInfo;
             }
 
             // 优化 SQL
             plainSelect.setSelectItems(countSelectItem);
-            countOptimize.setCountSQL(selectStatement.toString());
-            return countOptimize;
+            sqlInfo.setSql(selectStatement.toString());
+            return sqlInfo;
         } catch (Throwable e) {
             // 无法优化使用原 SQL
-            countOptimize.setCountSQL(String.format(SqlUtils.SQL_BASE_COUNT, originalSql));
-            return countOptimize;
+            sqlInfo.setSql(String.format(SqlUtils.SQL_BASE_COUNT, sql));
+            return sqlInfo;
         }
     }
 
+
     /**
      * <p>
      * 获取jsqlparser中count的SelectItem

+ 47 - 0
src/main/java/com/baomidou/mybatisplus/plugins/parser/AbstractSqlParser.java

@@ -0,0 +1,47 @@
+package com.baomidou.mybatisplus.plugins.parser;
+
+import org.apache.ibatis.logging.Log;
+import org.apache.ibatis.logging.LogFactory;
+
+/**
+ * <p>
+ * 抽象 SQL 解析类
+ * </p>
+ */
+public abstract class AbstractSqlParser {
+
+    // 日志
+    protected static final Log logger = LogFactory.getLog(AbstractSqlParser.class);
+    private String sql;// SQL 语句
+    private String dbType; // 数据库类型
+
+    public AbstractSqlParser(String sql, String dbType) {
+        this.sql = sql;
+        this.dbType = dbType;
+    }
+
+    /**
+     * <p>
+     * 获取优化 SQL 方法
+     * </p>
+     *
+     * @return SQL 信息
+     */
+    public abstract SqlInfo optimizeSql();
+
+    public String getSql() {
+        return sql;
+    }
+
+    public void setSql(String sql) {
+        this.sql = sql;
+    }
+
+    public String getDbType() {
+        return dbType;
+    }
+
+    public void setDbType(String dbType) {
+        this.dbType = dbType;
+    }
+}

+ 30 - 0
src/main/java/com/baomidou/mybatisplus/plugins/parser/SqlInfo.java

@@ -0,0 +1,30 @@
+package com.baomidou.mybatisplus.plugins.parser;
+
+/**
+ * Created by jobob on 17/6/20.
+ */
+public class SqlInfo {
+
+    private String sql;// SQL 内容
+    private boolean orderBy = true;// 是否排序
+
+    public static SqlInfo newInstance() {
+        return new SqlInfo();
+    }
+
+    public String getSql() {
+        return sql;
+    }
+
+    public void setSql(String sql) {
+        this.sql = sql;
+    }
+
+    public boolean isOrderBy() {
+        return orderBy;
+    }
+
+    public void setOrderBy(boolean orderBy) {
+        this.orderBy = orderBy;
+    }
+}

+ 57 - 0
src/main/java/com/baomidou/mybatisplus/plugins/tenancy/TenancySqlParser.java

@@ -0,0 +1,57 @@
+package com.baomidou.mybatisplus.plugins.tenancy;
+
+import com.baomidou.mybatisplus.plugins.parser.AbstractSqlParser;
+import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
+
+import net.sf.jsqlparser.statement.insert.Insert;
+import net.sf.jsqlparser.statement.select.SelectBody;
+import net.sf.jsqlparser.statement.update.Update;
+
+/**
+ * Created by jobob on 17/6/20.
+ */
+public class TenancySqlParser extends AbstractSqlParser {
+
+    public TenancySqlParser(String sql, String dbType) {
+        super(sql, dbType);
+    }
+
+    @Override
+    public SqlInfo optimizeSql() {
+        return null;
+    }
+
+    /**
+     * <p>
+     * select 语句处理
+     * </p>
+     *
+     * @param selectBody
+     */
+    public void processSelectBody(SelectBody selectBody) {
+
+    }
+
+    /**
+     * <p>
+     * insert 语句处理
+     * </p>
+     *
+     * @param insert
+     */
+    public void processInsert(Insert insert) {
+
+    }
+
+    /**
+     * <p>
+     * update 语句处理
+     * </p>
+     *
+     * @param update
+     */
+    public void processUpdate(Update update) {
+
+    }
+
+}

+ 0 - 43
src/main/java/com/baomidou/mybatisplus/toolkit/DruidUtils.java

@@ -1,43 +0,0 @@
-/**
- * Copyright (c) 2011-2020, hubin (jobob@qq.com).
- * <p>
- * Licensed under the Apache License, Version 2.0 (the "License"); you may not
- * use this file except in compliance with the License. You may obtain a copy of
- * the License at
- * <p>
- * http://www.apache.org/licenses/LICENSE-2.0
- * <p>
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations under
- * the License.
- */
-package com.baomidou.mybatisplus.toolkit;
-
-import com.alibaba.druid.sql.PagerUtils;
-
-/**
- * <p>
- * DruidUtils工具类
- * </p>
- *
- * @author Caratacus
- * @Date 2016-11-30
- */
-public class DruidUtils {
-
-    /**
-     * <p>
-     * 通过Druid方式获取count语句
-     * </p>
-     *
-     * @param originalSql 原执行 SQL
-     * @param dialectType 数据库方言类型
-     * @return
-     */
-    public static String count(String originalSql, String dialectType) {
-        return PagerUtils.count(originalSql, dialectType);
-    }
-
-}

+ 32 - 51
src/main/java/com/baomidou/mybatisplus/toolkit/SqlUtils.java

@@ -15,10 +15,14 @@
  */
 package com.baomidou.mybatisplus.toolkit;
 
-import com.baomidou.mybatisplus.entity.CountOptimize;
 import com.baomidou.mybatisplus.enums.Optimize;
 import com.baomidou.mybatisplus.enums.SqlLike;
 import com.baomidou.mybatisplus.plugins.pagination.Pagination;
+import com.baomidou.mybatisplus.plugins.pagination.optimize.AliDruidCountOptimize;
+import com.baomidou.mybatisplus.plugins.pagination.optimize.DefaultCountOptimize;
+import com.baomidou.mybatisplus.plugins.pagination.optimize.JsqlParserCountOptimize;
+import com.baomidou.mybatisplus.plugins.parser.AbstractSqlParser;
+import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 
 /**
  * <p>
@@ -30,77 +34,54 @@ import com.baomidou.mybatisplus.plugins.pagination.Pagination;
  */
 public class SqlUtils {
 
-    public static final String SQL_BASE_COUNT = "SELECT COUNT(1) FROM ( %s ) TOTAL";
     private final static SqlFormatter sqlFormatter = new SqlFormatter();
+    public static final String SQL_BASE_COUNT = "SELECT COUNT(1) FROM ( %s ) TOTAL";
 
     /**
+     * <p>
      * 获取CountOptimize
+     * </p>
      *
+     * @param sqlParser       Count SQL 解析类
      * @param originalSql     需要计算Count SQL
      * @param optimizeType    count优化方式
      * @param isOptimizeCount 是否需要优化Count
-     * @return CountOptimize
+     * @return SqlInfo
      */
-    public static CountOptimize getCountOptimize(String originalSql, String optimizeType, String dialectType,
-                                                 boolean isOptimizeCount) {
-        CountOptimize countOptimize = CountOptimize.newInstance();
-        // 获取优化类型
+    public static SqlInfo getCountOptimize(AbstractSqlParser sqlParser, String originalSql,
+                                           String optimizeType, String dialectType,
+                                           boolean isOptimizeCount) {
         Optimize opType = Optimize.getOptimizeType(optimizeType);
-        // 调整SQL便于解析
-        String tempSql = originalSql.replaceAll("(?i)ORDER[\\s]+BY", "ORDER BY").replaceAll("(?i)GROUP[\\s]+BY", "GROUP BY");
-        String indexOfSql = tempSql.toUpperCase();
-        // 有排序情况
-        int orderByIndex = indexOfSql.lastIndexOf("ORDER BY");
-        // 只针对 ALI_DRUID DEFAULT 这2种情况
-        if (orderByIndex > -1) {
-            countOptimize.setOrderBy(false);
+
+        // COUNT SQL 不优化
+        if (!isOptimizeCount && Optimize.DEFAULT == opType) {
+            SqlInfo sqlInfo = SqlInfo.newInstance();
+            String tempSql = originalSql.replaceAll("(?i)ORDER[\\s]+BY", "ORDER BY");
+            int orderByIndex = tempSql.toUpperCase().lastIndexOf("ORDER BY");
+            sqlInfo.setOrderBy(orderByIndex > -1);
+            sqlInfo.setSql(String.format(SQL_BASE_COUNT, originalSql));
+            return sqlInfo;
         }
-        if (!isOptimizeCount && opType.equals(Optimize.DEFAULT)) {
-            countOptimize.setCountSQL(String.format(SQL_BASE_COUNT, originalSql));
-            return countOptimize;
+
+        // 用户自定义 COUNT SQL 解析
+        if (null != sqlParser) {
+            return sqlParser.optimizeSql();
         }
 
+        // 默认存在的优化类型
         switch (opType) {
             case ALI_DRUID:
-                /**
-                 * 调用ali druid方式 插件dbType一定要设置为小写与JdbcConstants保持一致
-                 *
-                 * @see com.alibaba.druid.util.JdbcConstants
-                 */
-                String aliCountSql = DruidUtils.count(originalSql, dialectType);
-                countOptimize.setCountSQL(aliCountSql);
+                sqlParser = new AliDruidCountOptimize(originalSql, dialectType);
                 break;
             case JSQLPARSER:
-                /**
-                 * 调用JsqlParser方式
-                 */
-                JsqlParserUtils.jsqlparserCount(countOptimize, originalSql);
+                sqlParser = new JsqlParserCountOptimize(originalSql, dialectType);
                 break;
             default:
-                StringBuilder countSql = new StringBuilder("SELECT COUNT(1) ");
-                boolean optimize = false;
-                if (!indexOfSql.contains("DISTINCT") && !indexOfSql.contains("GROUP BY")) {
-                    int formIndex = indexOfSql.indexOf("FROM");
-                    if (formIndex > -1) {
-                        if (orderByIndex > -1) {
-                            tempSql = tempSql.substring(0, orderByIndex);
-                            countSql.append(tempSql.substring(formIndex));
-                            // 无排序情况
-                        } else {
-                            countSql.append(tempSql.substring(formIndex));
-                        }
-                        // 执行优化
-                        optimize = true;
-                    }
-                }
-                if (!optimize) {
-                    // 无优化SQL
-                    countSql.append("FROM ( ").append(originalSql).append(" ) TOTAL");
-                }
-                countOptimize.setCountSQL(countSql.toString());
+                sqlParser = new DefaultCountOptimize(originalSql, dialectType);
+                break;
         }
 
-        return countOptimize;
+        return sqlParser.optimizeSql();
     }
 
     /**

+ 64 - 100
src/test/java/com/baomidou/mybatisplus/test/SqlUtilsTest.java

@@ -3,8 +3,10 @@ package com.baomidou.mybatisplus.test;
 import org.junit.Assert;
 import org.junit.Test;
 
-import com.baomidou.mybatisplus.entity.CountOptimize;
-import com.baomidou.mybatisplus.toolkit.SqlUtils;
+import com.baomidou.mybatisplus.plugins.pagination.optimize.AliDruidCountOptimize;
+import com.baomidou.mybatisplus.plugins.pagination.optimize.DefaultCountOptimize;
+import com.baomidou.mybatisplus.plugins.pagination.optimize.JsqlParserCountOptimize;
+import com.baomidou.mybatisplus.plugins.parser.SqlInfo;
 
 /**
  * <p>
@@ -16,18 +18,19 @@ import com.baomidou.mybatisplus.toolkit.SqlUtils;
  */
 public class SqlUtilsTest {
 
+    public SqlInfo jsqlParserCountSqlInfo(String sql) {
+        return new JsqlParserCountOptimize(sql, "mysql").optimizeSql();
+    }
+
     /**
      * 测试jsqlparser方式
      */
     @Test
     public void sqlCountOptimize1() {
-
-        CountOptimize countOptimize = SqlUtils
-                .getCountOptimize(
-                        "select * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 order by (select 1 from dual)",
-                        "jsqlparser", "mysql", true);
-        String countsql = countOptimize.getCountSQL();
-        boolean orderBy = countOptimize.isOrderBy();
+        SqlInfo sqlInfo = jsqlParserCountSqlInfo(
+                "select * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 order by (select 1 from dual)");
+        String countsql = sqlInfo.getSql();
+        boolean orderBy = sqlInfo.isOrderBy();
         System.out.println(countsql);
         System.out.println(orderBy);
         Assert.assertFalse(orderBy);
@@ -41,12 +44,11 @@ public class SqlUtilsTest {
      */
     @Test
     public void sqlCountOptimize2() {
-        CountOptimize countOptimize = SqlUtils
-                .getCountOptimize(
-                        "select distinct * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 order by (select 1 from dual)",
-                        "jsqlparser", "mysql", true);
-        String countsql = countOptimize.getCountSQL();
-        boolean orderBy = countOptimize.isOrderBy();
+        SqlInfo sqlInfo = jsqlParserCountSqlInfo(
+                "select distinct * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 order by (select 1 from dual)"
+        );
+        String countsql = sqlInfo.getSql();
+        boolean orderBy = sqlInfo.isOrderBy();
         System.out.println(countsql);
         System.out.println(orderBy);
         Assert.assertFalse(orderBy);
@@ -60,31 +62,33 @@ public class SqlUtilsTest {
      */
     @Test
     public void sqlCountOptimize3() {
-        CountOptimize countOptimize = SqlUtils
-                .getCountOptimize(
-                        "select * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 group by a.id order by (select 1 from dual)",
-                        "jsqlparser", "mysql", true);
-        String countsql = countOptimize.getCountSQL();
-        boolean orderBy = countOptimize.isOrderBy();
+        SqlInfo sqlInfo = jsqlParserCountSqlInfo(
+                "select * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 group by a.id order by (select 1 from dual)"
+        );
+        String countsql = sqlInfo.getSql();
+        boolean orderBy = sqlInfo.isOrderBy();
         System.out.println(countsql);
         System.out.println(orderBy);
-        Assert.assertFalse(orderBy);
+        Assert.assertTrue(orderBy);
         Assert.assertEquals(
                 "SELECT COUNT(1) FROM ( SELECT * FROM user a LEFT JOIN (SELECT uuid FROM user2) b ON b.id = a.aid WHERE a = 1 GROUP BY a.id ORDER BY (SELECT 1 FROM dual) ) TOTAL",
                 countsql);
     }
 
+
+    public SqlInfo defaultCountSqlInfo(String sql) {
+        return new DefaultCountOptimize(sql, "mysql").optimizeSql();
+    }
+
     /**
      * 测试default方式
      */
     @Test
     public void sqlCountOptimize4() {
-        CountOptimize countOptimize = SqlUtils
-                .getCountOptimize(
-                        "select * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 group by a.id order by (select 1 from dual)",
-                        "default", "mysql", false);
-        String countsql = countOptimize.getCountSQL();
-        boolean orderBy = countOptimize.isOrderBy();
+        SqlInfo sqlInfo = defaultCountSqlInfo(
+                "select * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 group by a.id order by (select 1 from dual)");
+        String countsql = sqlInfo.getSql();
+        boolean orderBy = sqlInfo.isOrderBy();
         System.out.println(countsql);
         System.out.println(orderBy);
         Assert.assertFalse(orderBy);
@@ -98,43 +102,32 @@ public class SqlUtilsTest {
      */
     @Test
     public void sqlCountOptimize5() {
-        CountOptimize countOptimize = SqlUtils.getCountOptimize("select * from test where 1= 1 order by id ", "default", "mysql",
-                true);
-        String countsql = countOptimize.getCountSQL();
-        boolean orderBy = countOptimize.isOrderBy();
+        SqlInfo sqlInfo = defaultCountSqlInfo("select * from test where 1= 1 order by id ");
+        String countsql = sqlInfo.getSql();
+        boolean orderBy = sqlInfo.isOrderBy();
         System.out.println(countsql);
         System.out.println(orderBy);
         Assert.assertFalse(orderBy);
         Assert.assertEquals("SELECT COUNT(1) from test where 1= 1 ", countsql);
     }
 
-    /**
-     * 测试default方式
-     */
-    @Test
-    public void sqlCountOptimize6() {
-        CountOptimize countOptimize = SqlUtils.getCountOptimize("select * from test where 1= 1 order by id ", "default", "mysql",
-                false);
-        String countsql = countOptimize.getCountSQL();
-        boolean orderBy = countOptimize.isOrderBy();
-        System.out.println(countsql);
-        System.out.println(orderBy);
-        Assert.assertFalse(orderBy);
-        Assert.assertEquals("SELECT COUNT(1) FROM ( select * from test where 1= 1 order by id  ) TOTAL", countsql);
-    }
-
     /**
      * 测试default方式
      */
     @Test
     public void sqlCountOptimize7() {
-        CountOptimize countOptimize = SqlUtils.getCountOptimize("select * from test where 1= 1 ", "default", "mysql", false);
-        String countsql = countOptimize.getCountSQL();
-        boolean orderBy = countOptimize.isOrderBy();
+        SqlInfo sqlInfo = defaultCountSqlInfo("select * from test where 1= 1 ");
+        String countsql = sqlInfo.getSql();
+        boolean orderBy = sqlInfo.isOrderBy();
         System.out.println(countsql);
         System.out.println(orderBy);
         Assert.assertTrue(orderBy);
-        Assert.assertEquals("SELECT COUNT(1) FROM ( select * from test where 1= 1  ) TOTAL", countsql);
+        Assert.assertEquals("SELECT COUNT(1) from test where 1= 1 ", countsql);
+    }
+
+
+    public SqlInfo aliDruidCountSqlInfo(String sql) {
+        return new AliDruidCountOptimize(sql, "mysql").optimizeSql();
     }
 
     /**
@@ -142,13 +135,12 @@ public class SqlUtilsTest {
      */
     @Test
     public void sqlCountOptimize8() {
-        CountOptimize countOptimize = SqlUtils.getCountOptimize("select * from test where 1= 1 order by id ", "aliDruid",
-                "mysql", false);
-        String countsql = countOptimize.getCountSQL();
-        boolean orderBy = countOptimize.isOrderBy();
+        SqlInfo sqlInfo = aliDruidCountSqlInfo("select * from test where 1= 1 order by id ");
+        String countsql = sqlInfo.getSql();
+        boolean orderBy = sqlInfo.isOrderBy();
         System.out.println(countsql);
         System.out.println(orderBy);
-        Assert.assertFalse(orderBy);
+        Assert.assertTrue(orderBy);
         Assert.assertEquals("SELECT COUNT(*)\n" + "FROM test\n" + "WHERE 1 = 1", countsql);
     }
 
@@ -157,12 +149,12 @@ public class SqlUtilsTest {
      */
     @Test
     public void sqlCountOptimize9() {
-        CountOptimize countOptimize = SqlUtils.getCountOptimize("select * from test where 1= 1 ", "aliDruid", "mysql", false);
-        String countsql = countOptimize.getCountSQL();
-        boolean orderBy = countOptimize.isOrderBy();
+        SqlInfo sqlInfo = aliDruidCountSqlInfo("select * from test where 1= 1 ");
+        String countsql = sqlInfo.getSql();
+        boolean orderBy = sqlInfo.isOrderBy();
         System.out.println(countsql);
         System.out.println(orderBy);
-        Assert.assertTrue(orderBy);
+        Assert.assertFalse(orderBy);
         Assert.assertEquals("SELECT COUNT(*)\n" + "FROM test\n" + "WHERE 1 = 1", countsql);
     }
 
@@ -171,16 +163,13 @@ public class SqlUtilsTest {
      */
     @Test
     public void sqlCountOptimize10() {
+        SqlInfo sqlInfo = aliDruidCountSqlInfo("select * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 order by (select 1 from dual)");
 
-        CountOptimize countOptimize = SqlUtils
-                .getCountOptimize(
-                        "select * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 order by (select 1 from dual)",
-                        "aliDruid", "mysql", true);
-        String countsql = countOptimize.getCountSQL();
-        boolean orderBy = countOptimize.isOrderBy();
+        String countsql = sqlInfo.getSql();
+        boolean orderBy = sqlInfo.isOrderBy();
         System.out.println(countsql);
         System.out.println(orderBy);
-        Assert.assertFalse(orderBy);
+        Assert.assertTrue(orderBy);
         Assert.assertEquals("SELECT COUNT(*)\n" + "FROM user a\n" + "\tLEFT JOIN (SELECT uuid\n" + "\t\tFROM user2\n"
                 + "\t\t) b ON b.id = a.aid\n" + "WHERE a = 1", countsql);
 
@@ -191,15 +180,12 @@ public class SqlUtilsTest {
      */
     @Test
     public void sqlCountOptimize11() {
-        CountOptimize countOptimize = SqlUtils
-                .getCountOptimize(
-                        "select distinct * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 order by (select 1 from dual)",
-                        "aliDruid", "mysql", true);
-        String countsql = countOptimize.getCountSQL();
-        boolean orderBy = countOptimize.isOrderBy();
+        SqlInfo sqlInfo = aliDruidCountSqlInfo("select distinct * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 order by (select 1 from dual)");
+        String countsql = sqlInfo.getSql();
+        boolean orderBy = sqlInfo.isOrderBy();
         System.out.println(countsql);
         System.out.println(orderBy);
-        Assert.assertFalse(orderBy);
+        Assert.assertTrue(orderBy);
         Assert.assertEquals("SELECT COUNT(DISTINCT *)\n" + "FROM user a\n" + "\tLEFT JOIN (SELECT uuid\n" + "\t\tFROM user2\n"
                 + "\t\t) b ON b.id = a.aid\n" + "WHERE a = 1", countsql);
     }
@@ -209,34 +195,12 @@ public class SqlUtilsTest {
      */
     @Test
     public void sqlCountOptimize12() {
-        CountOptimize countOptimize = SqlUtils
-                .getCountOptimize(
-                        "select * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 group by a.id order by (select 1 from dual)",
-                        "aliDruid", "mysql", true);
-        String countsql = countOptimize.getCountSQL();
-        boolean orderBy = countOptimize.isOrderBy();
+        SqlInfo sqlInfo = aliDruidCountSqlInfo("select * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 group by a.id order by (select 1 from dual)");
+        String countsql = sqlInfo.getSql();
+        boolean orderBy = sqlInfo.isOrderBy();
         System.out.println(countsql);
         System.out.println(orderBy);
-        Assert.assertFalse(orderBy);
-        Assert.assertEquals("SELECT COUNT(*)\n" + "FROM (SELECT *\n" + "\tFROM user a\n" + "\t\tLEFT JOIN (SELECT uuid\n"
-                + "\t\t\tFROM user2\n" + "\t\t\t) b ON b.id = a.aid\n" + "\tWHERE a = 1\n" + "\tGROUP BY a.id\n"
-                + "\t) ALIAS_COUNT", countsql);
-    }
-
-    /**
-     * 测试aliDruid方式
-     */
-    @Test
-    public void sqlCountOptimize13() {
-        CountOptimize countOptimize = SqlUtils
-                .getCountOptimize(
-                        "select * from user a left join (select uuid from user2) b on b.id = a.aid where a=1 group by a.id order by (select 1 from dual)",
-                        "aliDruid", "mysql", false);
-        String countsql = countOptimize.getCountSQL();
-        boolean orderBy = countOptimize.isOrderBy();
-        System.out.println(countsql);
-        System.out.println(orderBy);
-        Assert.assertFalse(orderBy);
+        Assert.assertTrue(orderBy);
         Assert.assertEquals("SELECT COUNT(*)\n" + "FROM (SELECT *\n" + "\tFROM user a\n" + "\t\tLEFT JOIN (SELECT uuid\n"
                 + "\t\t\tFROM user2\n" + "\t\t\t) b ON b.id = a.aid\n" + "\tWHERE a = 1\n" + "\tGROUP BY a.id\n"
                 + "\t) ALIAS_COUNT", countsql);