浏览代码

优化 page 分页可自动填充 records

= 7 年之前
父节点
当前提交
4268655268

+ 79 - 19
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/plugins/PaginationInterceptor.java

@@ -16,9 +16,12 @@
 package com.baomidou.mybatisplus.plugins;
 
 import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
 import java.util.List;
 import java.util.Properties;
 
+import org.apache.ibatis.builder.StaticSqlSource;
 import org.apache.ibatis.cache.CacheKey;
 import org.apache.ibatis.executor.Executor;
 import org.apache.ibatis.executor.statement.StatementHandler;
@@ -27,6 +30,7 @@ import org.apache.ibatis.logging.LogFactory;
 import org.apache.ibatis.mapping.BoundSql;
 import org.apache.ibatis.mapping.MappedStatement;
 import org.apache.ibatis.mapping.SqlCommandType;
+import org.apache.ibatis.mapping.SqlSource;
 import org.apache.ibatis.plugin.Interceptor;
 import org.apache.ibatis.plugin.Intercepts;
 import org.apache.ibatis.plugin.Invocation;
@@ -34,10 +38,12 @@ import org.apache.ibatis.plugin.Plugin;
 import org.apache.ibatis.plugin.Signature;
 import org.apache.ibatis.reflection.MetaObject;
 import org.apache.ibatis.reflection.SystemMetaObject;
+import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
 import org.apache.ibatis.session.Configuration;
 import org.apache.ibatis.session.ResultHandler;
 import org.apache.ibatis.session.RowBounds;
 
+import com.baomidou.mybatisplus.MybatisDefaultParameterHandler;
 import com.baomidou.mybatisplus.enums.DBType;
 import com.baomidou.mybatisplus.plugins.pagination.DialectFactory;
 import com.baomidou.mybatisplus.plugins.pagination.PageHelper;
@@ -154,27 +160,10 @@ public class PaginationInterceptor extends SqlParserHandler implements Intercept
                 if (page.isSearchCount()) {
                     SqlInfo sqlInfo = SqlUtils.getOptimizeCountSql(page.isOptimizeCountSql(), sqlParser, originalSql);
                     orderBy = sqlInfo.isOrderBy();
-                    BoundSql countBoundSql = new BoundSql(configuration, sqlInfo.getSql(), boundSql.getParameterMappings(), parameter);
-                    CacheKey countCacheKey = executor.createCacheKey(mappedStatement, parameter, RowBounds.DEFAULT, countBoundSql);
                     // 查询总记录数
-                    Object countObject = executor.query(mappedStatement, parameter, RowBounds.DEFAULT, resultHandler, countCacheKey, countBoundSql);
-//                Map tempMap = (Map) countList.get(0);
-//                Object[] tempArray = tempMap.entrySet().toArray();
-//                Map.Entry totalCount = (Map.Entry) tempArray[0];
-//                page.setTotal((Long) totalCount.getValue());
-                    page.setTotal(6);
-                    // 溢出总页数,设置第一页
-                    long pages = page.getPages();
-                    if (overflowCurrent && (page.getCurrent() > pages)) {
-                        // 设置为第一条
-                        page.setCurrent(1);
-                    }
-                    if (page.getTotal() <= 0L) {
-                        return invocation.proceed();
-                    }
+                    this.queryTotal(overflowCurrent, mappedStatement, boundSql, connection, parameter, sqlInfo, page);
                 }
-                String buildSql = SqlUtils.concatOrderBy(originalSql, page, orderBy);
-                originalSql = DialectFactory.buildPaginationSql(page, buildSql, dbType, dialectClazz);
+                originalSql = DialectFactory.buildPaginationSql(page, SqlUtils.concatOrderBy(originalSql, page, orderBy), dbType, dialectClazz);
             } else {
                 // support physical Pagination for RowBounds
                 originalSql = DialectFactory.buildPaginationSql(rowBounds, originalSql, dbType, dialectClazz);
@@ -191,6 +180,77 @@ public class PaginationInterceptor extends SqlParserHandler implements Intercept
         return invocation.proceed();
     }
 
+    /**
+     * <p>
+     * 查询总记录条数
+     * </p>
+     *
+     * @param overflowCurrent
+     * @param mappedStatement
+     * @param boundSql
+     * @param connection
+     * @param parameter
+     * @param sqlInfo
+     * @param page
+     */
+    protected void queryTotal(boolean overflowCurrent, MappedStatement mappedStatement, BoundSql boundSql, Connection connection, Object parameter, SqlInfo sqlInfo, Pagination page) {
+        MappedStatement countMappedStatement = this.getCountMappedStatement(mappedStatement, boundSql, sqlInfo, parameter);
+        try (PreparedStatement statement = connection.prepareStatement(sqlInfo.getSql())) {
+            DefaultParameterHandler parameterHandler = new MybatisDefaultParameterHandler(countMappedStatement, boundSql.getParameterObject(), boundSql);
+            parameterHandler.setParameters(statement);
+            int total = 0;
+            try (ResultSet resultSet = statement.executeQuery()) {
+                if (resultSet.next()) {
+                    total = resultSet.getInt(1);
+                }
+            }
+            page.setTotal(total);
+            /*
+             * 溢出总页数,设置第一页
+             */
+            long pages = page.getPages();
+            if (overflowCurrent && (Long.valueOf(page.getCurrent()) > pages)) {
+                // 设置为第一条
+                page.setCurrent(1);
+            }
+        } catch (Exception e) {
+            logger.error("Error: Method queryTotal execution error !", e);
+        }
+    }
+
+    /**
+     * <p>
+     * 获取 Count MappedStatement 如果存在 countStatementId 使用 XML 中的 SQL 如果不存在构建一个
+     * </p>
+     * <p>
+     * 例如: selectPage 分页查询,自定义 XML COUNT 查询未 selectPageCount
+     * </p>
+     *
+     * @param mappedStatement
+     * @param boundSql
+     * @param sqlInfo
+     * @param parameter
+     * @return
+     */
+    private MappedStatement getCountMappedStatement(MappedStatement mappedStatement, BoundSql boundSql, SqlInfo sqlInfo, Object parameter) {
+        Configuration configuration = mappedStatement.getConfiguration();
+        BoundSql countBoundSql = new BoundSql(configuration, sqlInfo.getSql(), boundSql.getParameterMappings(), parameter);
+        String countStatementId = mappedStatement.getId() + "Count";
+        MappedStatement countMappedStatement = null;
+        try {
+            countMappedStatement = configuration.getMappedStatement(countStatementId, false);
+        } catch (Throwable t) {
+            if (null == countMappedStatement) {
+                // 查询结果集
+                SqlSource sqlsource = new StaticSqlSource(configuration, countBoundSql.getSql(), countBoundSql.getParameterMappings());
+                MappedStatement.Builder builder = new MappedStatement.Builder(configuration, countStatementId, sqlsource,
+                    SqlCommandType.SELECT);
+                mappedStatement = builder.build();
+            }
+        }
+        return mappedStatement;
+    }
+
     @Override
     public Object plugin(Object target) {
         if (target instanceof StatementHandler || target instanceof Executor) {