瀏覽代碼

ipage 新增功能,函数式获取 count

miemie 6 年之前
父節點
當前提交
1e26e140fc

+ 12 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/metadata/IPage.java

@@ -18,6 +18,7 @@ package com.baomidou.mybatisplus.core.metadata;
 import java.io.Serializable;
 import java.util.List;
 import java.util.Map;
+import java.util.function.LongSupplier;
 
 /**
  * <p>
@@ -29,6 +30,17 @@ import java.util.Map;
  */
 public interface IPage<T> extends Serializable {
 
+    /**
+     * <p>
+     * 自定义获取 count 的提供方
+     * </p>
+     *
+     * @return LongSupplier
+     */
+    default LongSupplier getSupplier() {
+        return null;
+    }
+
     /**
      * <p>
      * 降序字段数组

+ 5 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/PaginationInterceptor.java

@@ -44,6 +44,7 @@ import java.sql.Connection;
 import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.util.*;
+import java.util.function.LongSupplier;
 
 import static java.util.stream.Collectors.joining;
 
@@ -163,14 +164,18 @@ public class PaginationInterceptor extends AbstractSqlParserHandler implements I
         DbType dbType = StringUtils.isNotEmpty(dialectType) ? DbType.getDbType(dialectType)
             : JdbcUtils.getDbType(connection.getMetaData().getURL());
 
+        LongSupplier supplier = page.getSupplier();
         boolean orderBy = true;
         if (page.getTotal() == 0) {
+            // total 为0 才进行 count
             SqlInfo sqlInfo = SqlParserUtils.getOptimizeCountSql(page.optimizeCountSql(), sqlParser, originalSql);
             orderBy = sqlInfo.isOrderBy();
             this.queryTotal(overflow, sqlInfo.getSql(), mappedStatement, boundSql, page, connection);
             if (page.getTotal() <= 0) {
                 return invocation.proceed();
             }
+        } else if (supplier != null) {
+            page.setTotal(supplier.getAsLong());
         }
 
         String buildSql = concatOrderBy(originalSql, page, orderBy);

+ 20 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/pagination/Page.java

@@ -20,6 +20,7 @@ import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
 
 import java.util.Collections;
 import java.util.List;
+import java.util.function.LongSupplier;
 
 /**
  * <p>
@@ -67,6 +68,12 @@ public class Page<T> implements IPage<T> {
      * </p>
      */
     private boolean optimizeCountSql = true;
+    /**
+     * <p>
+     * 自定义获取 count 的提供方
+     * </p>
+     */
+    private transient LongSupplier supplier = null;
 
     public Page() {
         // to do nothing
@@ -92,6 +99,14 @@ public class Page<T> implements IPage<T> {
         this.total = total;
     }
 
+    /**
+     * 后台使用的构造函数
+     */
+    public Page(long current, long size, LongSupplier supplier) {
+        this(current, size, -1);
+        this.supplier = supplier;
+    }
+
     /**
      * <p>
      * 是否存在上一页
@@ -216,4 +231,9 @@ public class Page<T> implements IPage<T> {
         this.optimizeCountSql = optimizeCountSql;
         return this;
     }
+
+    @Override
+    public LongSupplier getSupplier() {
+        return this.supplier;
+    }
 }

+ 3 - 2
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/ActiveRecordTest.java

@@ -19,6 +19,7 @@ import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
 import org.springframework.transaction.annotation.Transactional;
+
 import java.util.List;
 
 
@@ -49,7 +50,7 @@ public class ActiveRecordTest {
         H2Student student = new H2Student(1L,"Tom长大了",2);
         Assert.assertTrue(student.updateById());
         student.setName("不听话的学生");
-        Assert.assertTrue(student.update(new QueryWrapper<>().gt("id",10)));
+        Assert.assertTrue(student.update(new QueryWrapper<H2Student>().gt("id", 10)));
     }
 
     @Test
@@ -125,6 +126,6 @@ public class ActiveRecordTest {
         student.setId(2L);
         Assert.assertTrue(student.deleteById());
         Assert.assertTrue(student.deleteById(12L));
-        Assert.assertTrue(student.delete(new QueryWrapper<>().gt("id",10)));
+        Assert.assertTrue(student.delete(new QueryWrapper<H2Student>().gt("id", 10)));
     }
 }

+ 8 - 8
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/mysql/MysqlTestDataMapperTest.java

@@ -15,9 +15,7 @@ import com.baomidou.mybatisplus.test.base.enums.TestEnum;
 import com.baomidou.mybatisplus.test.base.mapper.commons.CommonDataMapper;
 import com.baomidou.mybatisplus.test.base.mapper.commons.CommonLogicDataMapper;
 import com.baomidou.mybatisplus.test.base.mapper.mysql.MysqlDataMapper;
-import com.baomidou.mybatisplus.test.mysql.config.MysqlDb;
 import org.junit.Assert;
-import org.junit.BeforeClass;
 import org.junit.FixMethodOrder;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -52,11 +50,11 @@ public class MysqlTestDataMapperTest {
     @Resource
     private MysqlDataMapper mysqlMapper;
 
-    @BeforeClass
-    public static void init() throws Exception {
-        MysqlDb.initMysqlData();
-        System.out.println("init success");
-    }
+//    @BeforeClass
+//    public static void init() throws Exception {
+//        MysqlDb.initMysqlData();
+//        System.out.println("init success");
+//    }
 
     @Test
     public void a_insertForeach() {
@@ -290,7 +288,9 @@ public class MysqlTestDataMapperTest {
 
     @Test
     public void xxx() {
-        mysqlMapper.selectPage(new Page<>(1, 5),
+        IPage<MysqlData> page = mysqlMapper.selectPage(new Page<>(2, 5, mysqlMapper.selectCount(null)),
             Condition.<MysqlData>create().gt("`order`", 1).gt("`group`", 2));
+        System.out.println(page.getTotal());
+        System.out.println(page.getRecords().size());
     }
 }

+ 1 - 1
mybatis-plus/src/test/resources/logback.xml

@@ -11,7 +11,7 @@
         <appender-ref ref="STDOUT"/>
     </logger>
 
-    <root level="debug">
+    <root level="info">
         <appender-ref ref="STDOUT"/>
     </root>
 </configuration>