Procházet zdrojové kódy

(Class<T>) entity.getClass() 封装成方法,减少警告注解使用次数
getOne、getMap、count、list、listMaps 方法增加传入实体的方法重载,利用mp的特性实现根据实体不为空的属性查询

mahuibo před 2 roky
rodič
revize
936c061810

+ 96 - 25
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/toolkit/Db.java

@@ -15,25 +15,36 @@
  */
 package com.baomidou.mybatisplus.extension.toolkit;
 
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+import org.apache.ibatis.binding.MapperMethod;
+import org.apache.ibatis.logging.Log;
+import org.apache.ibatis.logging.LogFactory;
+
 import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
 import com.baomidou.mybatisplus.core.enums.SqlMethod;
 import com.baomidou.mybatisplus.core.metadata.IPage;
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
 import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
-import com.baomidou.mybatisplus.core.toolkit.*;
+import com.baomidou.mybatisplus.core.toolkit.Assert;
+import com.baomidou.mybatisplus.core.toolkit.ClassUtils;
+import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
+import com.baomidou.mybatisplus.core.toolkit.Constants;
+import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
+import com.baomidou.mybatisplus.core.toolkit.StringUtils;
+import com.baomidou.mybatisplus.core.toolkit.Wrappers;
 import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
 import com.baomidou.mybatisplus.extension.conditions.query.LambdaQueryChainWrapper;
 import com.baomidou.mybatisplus.extension.conditions.query.QueryChainWrapper;
 import com.baomidou.mybatisplus.extension.conditions.update.LambdaUpdateChainWrapper;
 import com.baomidou.mybatisplus.extension.conditions.update.UpdateChainWrapper;
 import com.baomidou.mybatisplus.extension.service.IService;
-import org.apache.ibatis.binding.MapperMethod;
-import org.apache.ibatis.logging.Log;
-import org.apache.ibatis.logging.LogFactory;
-
-import java.io.Serializable;
-import java.util.*;
-import java.util.stream.Collectors;
 
 /**
  * 以静态方式调用Service中的函数
@@ -58,9 +69,7 @@ public class Db {
         if (Objects.isNull(entity)) {
             return false;
         }
-        @SuppressWarnings("unchecked")
-        Class<T> entityClass = (Class<T>) entity.getClass();
-        Integer result = SqlHelper.execute(entityClass, baseMapper -> baseMapper.insert(entity));
+        Integer result = SqlHelper.execute(getEntityClass(entity), baseMapper -> baseMapper.insert(entity));
         return SqlHelper.retBool(result);
     }
 
@@ -143,9 +152,7 @@ public class Db {
         if (Objects.isNull(entity)) {
             return false;
         }
-        @SuppressWarnings("unchecked")
-        Class<T> entityClass = (Class<T>) entity.getClass();
-        return SqlHelper.execute(entityClass, baseMapper -> SqlHelper.retBool(baseMapper.deleteById(entity)));
+        return SqlHelper.execute(getEntityClass(entity), baseMapper -> SqlHelper.retBool(baseMapper.deleteById(entity)));
     }
 
     /**
@@ -166,9 +173,7 @@ public class Db {
         if (Objects.isNull(entity)) {
             return false;
         }
-        @SuppressWarnings("unchecked")
-        Class<T> entityClass = (Class<T>) entity.getClass();
-        return SqlHelper.execute(entityClass, baseMapper -> SqlHelper.retBool(baseMapper.updateById(entity)));
+        return SqlHelper.execute(getEntityClass(entity), baseMapper -> SqlHelper.retBool(baseMapper.updateById(entity)));
     }
 
     /**
@@ -245,8 +250,7 @@ public class Db {
         if (Objects.isNull(entity)) {
             return false;
         }
-        @SuppressWarnings("unchecked")
-        Class<T> entityClass = (Class<T>) entity.getClass();
+        Class<T> entityClass = getEntityClass(entity);
         TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
         Assert.notNull(tableInfo, "error: can not execute. because can not find cache of TableInfo for entity!");
         String keyProperty = tableInfo.getKeyProperty();
@@ -275,6 +279,26 @@ public class Db {
         return getOne(queryWrapper, true);
     }
 
+    /**
+     * 根据 entity里不为空的字段,查询一条记录 <br/>
+     * <p>结果集,如果是多个会抛出异常,随机取一条加上限制条件 wrapper.last("LIMIT 1")</p>
+     *
+     * @param entity 实体对象
+     */
+    public static <T> T getOne(T entity) {
+        return getOne(Wrappers.lambdaQuery(entity), true);
+    }
+
+    /**
+     * 根据 entity里不为空的字段,查询一条记录
+     *
+     * @param entity  实体对象
+     * @param throwEx 有多个 result 是否抛出异常
+     */
+    public static <T> T getOne(T entity, boolean throwEx) {
+        return getOne(Wrappers.lambdaQuery(entity), throwEx);
+    }
+
     /**
      * 根据 Wrapper,查询一条记录
      *
@@ -318,6 +342,15 @@ public class Db {
         return SqlHelper.execute(getEntityClass(queryWrapper), baseMapper -> SqlHelper.getObject(log, baseMapper.selectMaps(queryWrapper)));
     }
 
+    /**
+     * 根据 entity不为空条件,查询一条记录
+     *
+     * @param entity 实体对象
+     */
+    public static <T> Map<String, Object> getMap(T entity) {
+        return getMap(Wrappers.lambdaQuery(entity));
+    }
+
     /**
      * 查询总记录数
      *
@@ -328,6 +361,15 @@ public class Db {
         return SqlHelper.execute(entityClass, baseMapper -> baseMapper.selectCount(null));
     }
 
+    /**
+     * 根据entity中不为空的数据查询记录数
+     *
+     * @param entity 实体类
+     */
+    public static <T> long count(T entity) {
+        return count(Wrappers.lambdaQuery(entity));
+    }
+
     /**
      * 根据 Wrapper 条件,查询总记录数
      *
@@ -343,8 +385,7 @@ public class Db {
      * @param queryWrapper 实体对象封装操作类 {@link com.baomidou.mybatisplus.core.conditions.query.QueryWrapper}
      */
     public static <T> List<T> list(AbstractWrapper<T, ?, ?> queryWrapper) {
-        Class<T> entityClass = getEntityClass(queryWrapper);
-        return SqlHelper.execute(entityClass, baseMapper -> baseMapper.selectList(queryWrapper));
+        return SqlHelper.execute(getEntityClass(queryWrapper), baseMapper -> baseMapper.selectList(queryWrapper));
     }
 
     /**
@@ -357,6 +398,16 @@ public class Db {
         return SqlHelper.execute(entityClass, baseMapper -> baseMapper.selectList(null));
     }
 
+    /**
+     * 根据entity中不为空的字段进行查询
+     *
+     * @param entity 实体类
+     * @see Wrappers#emptyWrapper()
+     */
+    public static <T> List<T> list(T entity) {
+        return list(Wrappers.lambdaQuery(entity));
+    }
+
     /**
      * 查询列表
      *
@@ -376,6 +427,15 @@ public class Db {
         return SqlHelper.execute(entityClass, baseMapper -> baseMapper.selectMaps(null));
     }
 
+    /**
+     * 根据entity不为空的条件查询列表
+     *
+     * @param entity 实体类
+     */
+    public static <T> List<Map<String, Object>> listMaps(T entity) {
+        return listMaps(Wrappers.lambdaQuery(entity));
+    }
+
     /**
      * 查询全部记录
      *
@@ -523,12 +583,11 @@ public class Db {
      * @param <T>        实体类型
      * @return 实体类型
      */
-    @SuppressWarnings("unchecked")
     protected static <T> Class<T> getEntityClass(Collection<T> entityList) {
         Class<T> entityClass = null;
         for (T entity : entityList) {
             if (entity != null && entity.getClass() != null) {
-                entityClass = (Class<T>) entity.getClass();
+                entityClass = getEntityClass(entity);
                 break;
             }
         }
@@ -543,19 +602,31 @@ public class Db {
      * @param <T>          实体类型
      * @return 实体类型
      */
-    @SuppressWarnings("unchecked")
     protected static <T> Class<T> getEntityClass(AbstractWrapper<T, ?, ?> queryWrapper) {
         Class<T> entityClass = queryWrapper.getEntityClass();
         if (entityClass == null) {
             T entity = queryWrapper.getEntity();
             if (entity != null) {
-                entityClass = (Class<T>) entity.getClass();
+                entityClass = getEntityClass(entity);
             }
         }
         Assert.notNull(entityClass, "error: can not get entityClass from wrapper");
         return entityClass;
     }
 
+    /**
+     * 从entity中尝试获取实体类型
+     *
+     * @param entity 实体
+     * @param <T>    实体类型
+     * @return 实体类型
+     */
+    @SuppressWarnings("unchecked")
+    protected static <T> Class<T> getEntityClass(T entity) {
+        return (Class<T>) entity.getClass();
+    }
+
+
     /**
      * 获取表信息,获取不到报错提示
      *

+ 30 - 7
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/toolkit/DbTest.java

@@ -1,5 +1,16 @@
 package com.baomidou.mybatisplus.test.toolkit;
 
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.ibatis.exceptions.TooManyResultsException;
+import org.apache.ibatis.plugin.Interceptor;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
 import com.baomidou.mybatisplus.annotation.DbType;
 import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
 import com.baomidou.mybatisplus.core.metadata.IPage;
@@ -15,12 +26,6 @@ import com.baomidou.mybatisplus.extension.toolkit.Db;
 import com.baomidou.mybatisplus.test.BaseDbTest;
 import com.baomidou.mybatisplus.test.sqlrunner.Entity;
 import com.baomidou.mybatisplus.test.sqlrunner.EntityMapper;
-import org.apache.ibatis.exceptions.TooManyResultsException;
-import org.apache.ibatis.plugin.Interceptor;
-import org.junit.jupiter.api.Assertions;
-import org.junit.jupiter.api.Test;
-
-import java.util.*;
 
 /**
  * 以静态方式调用Service中的函数
@@ -151,6 +156,10 @@ class DbTest extends BaseDbTest<EntityMapper> {
         Assertions.assertThrows(TooManyResultsException.class, () -> Db.getOne(wrapper));
         Entity one = Db.getOne(wrapper, false);
         Assertions.assertNotNull(one);
+        Entity entity = new Entity();
+        entity.setId(1L);
+        one = Db.getOne(entity);
+        Assertions.assertNotNull(one);
     }
 
     @Test
@@ -172,6 +181,11 @@ class DbTest extends BaseDbTest<EntityMapper> {
     void testGetMap() {
         Map<String, Object> map = Db.getMap(Wrappers.lambdaQuery(Entity.class));
         Assertions.assertNotNull(map);
+
+        Entity entity = new Entity();
+        entity.setId(1L);
+        map = Db.getMap(entity);
+        Assertions.assertNotNull(map);
     }
 
     @Test
@@ -181,6 +195,11 @@ class DbTest extends BaseDbTest<EntityMapper> {
 
         list = Db.list(Entity.class);
         Assertions.assertEquals(2, list.size());
+
+        Entity entity = new Entity();
+        entity.setId(1L);
+        list = Db.list(entity);
+        Assertions.assertEquals(1, list.size());
     }
 
     @Test
@@ -190,6 +209,11 @@ class DbTest extends BaseDbTest<EntityMapper> {
 
         list = Db.listMaps(Entity.class);
         Assertions.assertEquals(2, list.size());
+
+        Entity entity = new Entity();
+        entity.setId(1L);
+        list = Db.listMaps(entity);
+        Assertions.assertEquals(1, list.size());
     }
 
     @Test
@@ -224,7 +248,6 @@ class DbTest extends BaseDbTest<EntityMapper> {
     void testPage() {
         IPage<Entity> page = Db.page(new Page<>(1, 1), Entity.class);
         Assertions.assertEquals(2, page.getTotal());
-
         page = Db.page(new Page<>(1, 1), Wrappers.lambdaQuery(Entity.class));
         Assertions.assertEquals(1, page.getRecords().size());
     }