miemie 5 years ago
parent
commit
120678caf7

+ 21 - 25
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/override/MybatisMapperMethod.java

@@ -15,8 +15,9 @@
  */
 package com.baomidou.mybatisplus.core.override;
 
-import com.baomidou.mybatisplus.core.metadata.PageList;
 import com.baomidou.mybatisplus.core.metadata.IPage;
+import com.baomidou.mybatisplus.core.metadata.PageList;
+import com.baomidou.mybatisplus.core.toolkit.Assert;
 import org.apache.ibatis.binding.BindingException;
 import org.apache.ibatis.binding.MapperMethod;
 import org.apache.ibatis.cursor.Cursor;
@@ -51,7 +52,6 @@ public class MybatisMapperMethod {
         this.method = new MapperMethod.MethodSignature(config, mapperInterface, method);
     }
 
-    @SuppressWarnings("unchecked")
     public Object execute(SqlSession sqlSession, Object[] args) {
         Object result;
         switch (command.getType()) {
@@ -84,25 +84,7 @@ public class MybatisMapperMethod {
                     Object param = method.convertArgsToSqlCommandParam(args);
                     // TODO 这里下面改了
                     if (IPage.class.isAssignableFrom(method.getReturnType())) {
-                        assert args != null;
-                        IPage<?> page = null;
-                        for (Object arg : args) {
-                            if (arg instanceof IPage) {
-                                page = (IPage) arg;
-                                break;
-                            }
-                        }
-                        assert page != null;
                         result = executeForIPage(sqlSession, args);
-                        if (result instanceof PageList) {
-                            PageList pageList = (PageList) result;
-                            page.setRecords(pageList.getRecords());
-                            page.setTotal(pageList.getTotal());
-                            result = page;
-                        } else {
-                            List list = (List<Object>) result;
-                            result = page.setRecords(list);
-                        }
                         // TODO 这里上面改了
                     } else {
                         result = sqlSession.selectOne(command.getName(), param);
@@ -126,12 +108,26 @@ public class MybatisMapperMethod {
         return result;
     }
 
-    /**
-     * TODO IPage 专用
-     */
-    private <E> List<E> executeForIPage(SqlSession sqlSession, Object[] args) {
+    @SuppressWarnings("all")
+    private <E> Object executeForIPage(SqlSession sqlSession, Object[] args) {
+        IPage<E> result = null;
+        for (Object arg : args) {
+            if (arg instanceof IPage) {
+                result = (IPage<E>) arg;
+                break;
+            }
+        }
+        Assert.notNull(result, "can't found IPage for args!");
         Object param = method.convertArgsToSqlCommandParam(args);
-        return sqlSession.selectList(command.getName(), param);
+        List<E> list = sqlSession.selectList(command.getName(), param);
+        if (list instanceof PageList) {
+            PageList<E> pageList = (PageList<E>) list;
+            result.setRecords(pageList.getRecords());
+            result.setTotal(pageList.getTotal());
+        } else {
+            result.setRecords(list);
+        }
+        return result;
     }
 
     private Object rowCountResult(int rowCount) {