Browse Source

修复selectObjs泛型错误问题

Caratacus 6 years ago
parent
commit
e0dde28c53

+ 4 - 3
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/injector/methods/SelectObjs.java

@@ -15,11 +15,12 @@
  */
 package com.baomidou.mybatisplus.core.injector.methods;
 
+import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.mapping.SqlSource;
+
 import com.baomidou.mybatisplus.core.enums.SqlMethod;
 import com.baomidou.mybatisplus.core.injector.AbstractMethod;
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
-import org.apache.ibatis.mapping.MappedStatement;
-import org.apache.ibatis.mapping.SqlSource;
 
 /**
  * <p>
@@ -37,6 +38,6 @@ public class SelectObjs extends AbstractMethod {
         String sql = String.format(sqlMethod.getSql(), sqlSelectObjsColumns(tableInfo),
             tableInfo.getTableName(), this.sqlWhereEntityWrapper(true, tableInfo));
         SqlSource sqlSource = languageDriver.createSqlSource(configuration, sql, modelClass);
-        return this.addSelectMappedStatement(mapperClass, sqlMethod.getMethod(), sqlSource, modelClass, tableInfo);
+        return this.addSelectMappedStatement(mapperClass, sqlMethod.getMethod(), sqlSource, Object.class, tableInfo);
     }
 }

+ 10 - 6
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/service/IService.java

@@ -19,10 +19,11 @@ import java.io.Serializable;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
+import java.util.function.Function;
 
 import com.baomidou.mybatisplus.core.conditions.Wrapper;
-import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import com.baomidou.mybatisplus.core.metadata.IPage;
+import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
 
 /**
@@ -234,9 +235,10 @@ public interface IService<T> {
      * </p>
      *
      * @param queryWrapper 实体对象封装操作类 {@link com.baomidou.mybatisplus.core.conditions.query.QueryWrapper}
+     * @param mapper       转换函数
      */
-    default Object getObj(Wrapper<T> queryWrapper) {
-        return SqlHelper.getObject(listObjs(queryWrapper));
+    default <V> V getObj(Wrapper<T> queryWrapper, Function<? super Object, V> mapper) {
+        return SqlHelper.getObject(listObjs(queryWrapper, mapper));
     }
 
     /**
@@ -327,18 +329,20 @@ public interface IService<T> {
      * </p>
      *
      * @param queryWrapper 实体对象封装操作类 {@link com.baomidou.mybatisplus.core.conditions.query.QueryWrapper}
+     * @param mapper       转换函数
      */
-    List<Object> listObjs(Wrapper<T> queryWrapper);
+    <V> List<V> listObjs(Wrapper<T> queryWrapper, Function<? super Object, V> mapper);
 
     /**
      * <p>
      * 查询全部记录
      * </p>
      *
+     * @param mapper 转换函数
      * @see Wrappers#emptyWrapper()
      */
-    default List<Object> listObjs() {
-        return listObjs(Wrappers.<T>emptyWrapper());
+    default <V> List<V> listObjs(Function<? super Object, V> mapper) {
+        return listObjs(Wrappers.<T>emptyWrapper(), mapper);
     }
 
     /**

+ 59 - 43
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/service/impl/ServiceImpl.java

@@ -15,26 +15,34 @@
  */
 package com.baomidou.mybatisplus.extension.service.impl;
 
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+import org.apache.ibatis.binding.MapperMethod;
+import org.apache.ibatis.session.SqlSession;
+import org.mybatis.spring.SqlSessionUtils;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.transaction.annotation.Transactional;
+
 import com.baomidou.mybatisplus.core.conditions.Wrapper;
 import com.baomidou.mybatisplus.core.enums.SqlMethod;
 import com.baomidou.mybatisplus.core.mapper.BaseMapper;
 import com.baomidou.mybatisplus.core.metadata.IPage;
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
-import com.baomidou.mybatisplus.core.toolkit.*;
+import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
+import com.baomidou.mybatisplus.core.toolkit.Constants;
+import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
+import com.baomidou.mybatisplus.core.toolkit.GlobalConfigUtils;
+import com.baomidou.mybatisplus.core.toolkit.ObjectUtils;
+import com.baomidou.mybatisplus.core.toolkit.ReflectionKit;
+import com.baomidou.mybatisplus.core.toolkit.StringUtils;
+import com.baomidou.mybatisplus.core.toolkit.TableInfoHelper;
 import com.baomidou.mybatisplus.extension.service.IService;
 import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
-import org.apache.ibatis.binding.MapperMethod;
-import org.apache.ibatis.session.SqlSession;
-import org.mybatis.spring.SqlSessionUtils;
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.transaction.annotation.Transactional;
-
-import java.io.Serializable;
-import java.util.Collection;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.stream.Collectors;
 
 /**
  * <p>
@@ -103,16 +111,16 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
     /**
      * 批量插入
      *
-     * @param entityList 实体类集合
-     * @param batchSize  每次提交的量
-     * @return 执行是否成功
+     * @param entityList
+     * @param batchSize
+     * @return
      */
     @Transactional(rollbackFor = Exception.class)
     @Override
     public boolean saveBatch(Collection<T> entityList, int batchSize) {
+        int i = 0;
         String sqlStatement = sqlStatement(SqlMethod.INSERT_ONE);
         try (SqlSession batchSqlSession = sqlSessionBatch()) {
-            int i = 0;
             for (T anEntityList : entityList) {
                 batchSqlSession.insert(sqlStatement, anEntityList);
                 if (i >= 1 && i % batchSize == 0) {
@@ -139,11 +147,12 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
         if (null != entity) {
             Class<?> cls = entity.getClass();
             TableInfo tableInfo = TableInfoHelper.getTableInfo(cls);
-            Assert.notNull(tableInfo, "error: can not execute. because can not find cache of TableInfo for entity!");
-            String keyProperty = tableInfo.getKeyProperty();
-            Assert.notEmpty(keyProperty, "error: can not execute. because can not find column for id from entity!");
-            Object idVal = ReflectionKit.getMethodValue(cls, entity, tableInfo.getKeyProperty());
-            return StringUtils.checkValNull(idVal) || Objects.isNull(getById((Serializable) idVal)) ? save(entity) : updateById(entity);
+            if (null != tableInfo && StringUtils.isNotEmpty(tableInfo.getKeyProperty())) {
+                Object idVal = ReflectionKit.getMethodValue(cls, entity, tableInfo.getKeyProperty());
+                return StringUtils.checkValNull(idVal) || Objects.isNull(getById((Serializable) idVal)) ? save(entity) : updateById(entity);
+            } else {
+                throw ExceptionUtils.mpe("Error:  Can not execute. Could not find @TableId.");
+            }
         }
         return false;
     }
@@ -151,30 +160,33 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
     @Transactional(rollbackFor = Exception.class)
     @Override
     public boolean saveOrUpdateBatch(Collection<T> entityList, int batchSize) {
-        Assert.notEmpty(entityList, "error: entityList must not be empty");
+        if (CollectionUtils.isEmpty(entityList)) {
+            throw new IllegalArgumentException("Error: entityList must not be empty");
+        }
         Class<?> cls = currentModelClass();
         TableInfo tableInfo = TableInfoHelper.getTableInfo(cls);
-        Assert.notNull(tableInfo, "error: can not execute. because can not find cache of TableInfo for entity!");
-        String keyProperty = tableInfo.getKeyProperty();
-        Assert.notEmpty(keyProperty, "error: can not execute. because can not find column for id from entity!");
+        int i = 0;
         try (SqlSession batchSqlSession = sqlSessionBatch()) {
-            int i = 0;
-            for (T entity : entityList) {
-                Object idVal = ReflectionKit.getMethodValue(cls, entity, keyProperty);
-                if (StringUtils.checkValNull(idVal) || Objects.isNull(getById((Serializable) idVal))) {
-                    batchSqlSession.insert(sqlStatement(SqlMethod.INSERT_ONE), entity);
+            for (T anEntityList : entityList) {
+                if (null != tableInfo && StringUtils.isNotEmpty(tableInfo.getKeyProperty())) {
+                    Object idVal = ReflectionKit.getMethodValue(cls, anEntityList, tableInfo.getKeyProperty());
+                    if (StringUtils.checkValNull(idVal) || Objects.isNull(getById((Serializable) idVal))) {
+                        batchSqlSession.insert(sqlStatement(SqlMethod.INSERT_ONE), anEntityList);
+                    } else {
+                        MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
+                        param.put(Constants.ENTITY, anEntityList);
+                        batchSqlSession.update(sqlStatement(SqlMethod.UPDATE_BY_ID), param);
+                    }
+                    //不知道以后会不会有人说更新失败了还要执行插入 😂😂😂
+                    if (i >= 1 && i % batchSize == 0) {
+                        batchSqlSession.flushStatements();
+                    }
+                    i++;
                 } else {
-                    MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
-                    param.put(Constants.ENTITY, entity);
-                    batchSqlSession.update(sqlStatement(SqlMethod.UPDATE_BY_ID), param);
-                }
-                //不知道以后会不会有人说更新失败了还要执行插入 😂😂😂
-                if (i >= 1 && i % batchSize == 0) {
-                    batchSqlSession.flushStatements();
+                    throw ExceptionUtils.mpe("Error:  Can not execute. Could not find @TableId.");
                 }
-                i++;
+                batchSqlSession.flushStatements();
             }
-            batchSqlSession.flushStatements();
         }
         return true;
     }
@@ -188,7 +200,9 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
     @Transactional(rollbackFor = Exception.class)
     @Override
     public boolean removeByMap(Map<String, Object> columnMap) {
-        Assert.notEmpty(columnMap, "error: columnMap must not be empty");
+        if (ObjectUtils.isEmpty(columnMap)) {
+            throw ExceptionUtils.mpe("removeByMap columnMap is empty.");
+        }
         return SqlHelper.delBool(baseMapper.deleteByMap(columnMap));
     }
 
@@ -219,10 +233,12 @@ public class ServiceImpl<M extends BaseMapper<T>, T> implements IService<T> {
     @Transactional(rollbackFor = Exception.class)
     @Override
     public boolean updateBatchById(Collection<T> entityList, int batchSize) {
-        Assert.notEmpty(entityList, "error: entityList must not be empty");
+        if (CollectionUtils.isEmpty(entityList)) {
+            throw new IllegalArgumentException("Error: entityList must not be empty");
+        }
+        int i = 0;
         String sqlStatement = sqlStatement(SqlMethod.UPDATE_BY_ID);
         try (SqlSession batchSqlSession = sqlSessionBatch()) {
-            int i = 0;
             for (T anEntityList : entityList) {
                 MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
                 param.put(Constants.ENTITY, anEntityList);