Преглед на файлове

新增SimpleQuery#group以及listGroupBy的重载,支持传入自定义下游操作,解锁高级玩法

VampireAchao преди 3 години
родител
ревизия
4dc8c6542f

+ 93 - 18
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/toolkit/SimpleQuery.java

@@ -6,8 +6,8 @@ import com.baomidou.mybatisplus.core.toolkit.LambdaUtils;
 import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
 import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
 
 
 import java.util.*;
 import java.util.*;
-import java.util.function.Consumer;
-import java.util.function.Function;
+import java.util.function.*;
+import java.util.stream.Collector;
 import java.util.stream.Collectors;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import java.util.stream.Stream;
 import java.util.stream.StreamSupport;
 import java.util.stream.StreamSupport;
@@ -93,20 +93,40 @@ public class SimpleQuery {
         return listGroupBy(SqlHelper.getMapper(getType(sFunction)).selectList(wrapper), sFunction, peeks);
         return listGroupBy(SqlHelper.getMapper(getType(sFunction)).selectList(wrapper), sFunction, peeks);
     }
     }
 
 
+    /**
+     * ignore
+     */
+    @SafeVarargs
+    public static <T, K> Map<K, List<T>> group(LambdaQueryWrapper<T> wrapper, SFunction<T, K> sFunction, boolean isParallel, Consumer<T>... peeks) {
+        return listGroupBy(SqlHelper.getMapper(getType(sFunction)).selectList(wrapper), sFunction, isParallel, peeks);
+    }
+
+    /**
+     * ignore
+     */
+    @SafeVarargs
+    public static <T, K, D, A, M extends Map<K, D>> M group(LambdaQueryWrapper<T> wrapper, SFunction<T, K> sFunction, Collector<? super T, A, D> downstream, Consumer<T>... peeks) {
+        return listGroupBy(SqlHelper.getMapper(getType(sFunction)).selectList(wrapper), sFunction, downstream, false, peeks);
+    }
+
     /**
     /**
      * 传入Wrappers和key,从数据库中根据条件查询出对应的列表,封装成Map
      * 传入Wrappers和key,从数据库中根据条件查询出对应的列表,封装成Map
      *
      *
      * @param wrapper    条件构造器
      * @param wrapper    条件构造器
      * @param sFunction  分组依据
      * @param sFunction  分组依据
+     * @param downstream 下游操作
      * @param isParallel 是否并行流
      * @param isParallel 是否并行流
      * @param peeks      后续操作
      * @param peeks      后续操作
-     * @param <E>        实体类型
-     * @param <A>        实体中的属性类型
+     * @param <T>        实体类型
+     * @param <K>        实体中的分组依据对应类型,也是Map中key的类型
+     * @param <D>        下游操作对应返回类型,也是Map中value的类型
+     * @param <A>        下游操作在进行中间操作时对应类型
+     * @param <M>        最后返回结果Map类型
      * @return Map<实体中的属性, List < 实体>>
      * @return Map<实体中的属性, List < 实体>>
      */
      */
     @SafeVarargs
     @SafeVarargs
-    public static <E, A> Map<A, List<E>> group(LambdaQueryWrapper<E> wrapper, SFunction<E, A> sFunction, boolean isParallel, Consumer<E>... peeks) {
-        return listGroupBy(SqlHelper.getMapper(getType(sFunction)).selectList(wrapper), sFunction, isParallel, peeks);
+    public static <T, K, D, A, M extends Map<K, D>> M group(LambdaQueryWrapper<T> wrapper, SFunction<T, K> sFunction, Collector<? super T, A, D> downstream, boolean isParallel, Consumer<T>... peeks) {
+        return listGroupBy(SqlHelper.getMapper(getType(sFunction)).selectList(wrapper), sFunction, downstream, isParallel, peeks);
     }
     }
 
 
     /**
     /**
@@ -159,31 +179,86 @@ public class SimpleQuery {
      * ignore
      * ignore
      */
      */
     @SafeVarargs
     @SafeVarargs
-    public static <A, E> Map<A, List<E>> listGroupBy(List<E> list, SFunction<E, A> sFunction, Consumer<E>... peeks) {
+    public static <K, T> Map<K, List<T>> listGroupBy(List<T> list, SFunction<T, K> sFunction, Consumer<T>... peeks) {
         return listGroupBy(list, sFunction, false, peeks);
         return listGroupBy(list, sFunction, false, peeks);
     }
     }
 
 
+    /**
+     * ignore
+     */
+    @SafeVarargs
+    public static <K, T> Map<K, List<T>> listGroupBy(List<T> list, SFunction<T, K> sFunction, boolean isParallel, Consumer<T>... peeks) {
+        return listGroupBy(list, sFunction, Collectors.toList(), isParallel, peeks);
+    }
+
+    /**
+     * ignore
+     */
+    @SafeVarargs
+    public static <T, K, D, A, M extends Map<K, D>> M listGroupBy(List<T> list, SFunction<T, K> sFunction, Collector<? super T, A, D> downstream, Consumer<T>... peeks) {
+        return listGroupBy(list, sFunction, downstream, false, peeks);
+    }
+
     /**
     /**
      * 对list进行groupBy操作
      * 对list进行groupBy操作
      *
      *
      * @param list       数据
      * @param list       数据
      * @param sFunction  分组的key,依据
      * @param sFunction  分组的key,依据
+     * @param downstream 下游操作
      * @param isParallel 是否并行流
      * @param isParallel 是否并行流
      * @param peeks      封装成map时可能需要的后续操作,不需要可以不传
      * @param peeks      封装成map时可能需要的后续操作,不需要可以不传
-     * @param <E>        实体类型
-     * @param <A>        实体中的属性类型
+     * @param <T>        实体类型
+     * @param <K>        实体中的分组依据对应类型,也是Map中key的类型
+     * @param <D>        下游操作对应返回类型,也是Map中value的类型
+     * @param <A>        下游操作在进行中间操作时对应类型
+     * @param <M>        最后返回结果Map类型
      * @return Map<实体中的属性, List < 实体>>
      * @return Map<实体中的属性, List < 实体>>
      */
      */
     @SafeVarargs
     @SafeVarargs
-    public static <A, E> Map<A, List<E>> listGroupBy(List<E> list, SFunction<E, A> sFunction, boolean isParallel, Consumer<E>... peeks) {
-        return peekStream(list, isParallel, peeks).collect(HashMap::new, (m, v) -> {
-            A key = Optional.ofNullable(v).map(sFunction).orElse(null);
-            List<E> values = m.computeIfAbsent(key, k -> new ArrayList<>(list.size()));
-            values.add(v);
-        }, (totalMap, nowMap) -> nowMap.forEach((key, v) -> {
-            List<E> values = totalMap.computeIfAbsent(key, k -> new ArrayList<>(list.size()));
-            values.addAll(v);
-        }));
+    @SuppressWarnings("unchecked")
+    public static <T, K, D, A, M extends Map<K, D>> M listGroupBy(List<T> list, SFunction<T, K> sFunction, Collector<? super T, A, D> downstream, boolean isParallel, Consumer<T>... peeks) {
+        boolean hasFinished = downstream.characteristics().contains(Collector.Characteristics.IDENTITY_FINISH);
+        return peekStream(list, isParallel, peeks).collect(new Collector<T, HashMap<K, A>, M>() {
+            @Override
+            public Supplier<HashMap<K, A>> supplier() {
+                return HashMap::new;
+            }
+
+            @Override
+            public BiConsumer<HashMap<K, A>, T> accumulator() {
+                return (m, t) -> {
+                    // 只此一处,和原版groupingBy修改只此一处,成功在支持下游操作的情况下支持null值
+                    K key = Optional.ofNullable(t).map(sFunction).orElse(null);
+                    A container = m.computeIfAbsent(key, k -> downstream.supplier().get());
+                    downstream.accumulator().accept(container, t);
+                };
+            }
+
+            @Override
+            public BinaryOperator<HashMap<K, A>> combiner() {
+                return (m1, m2) -> {
+                    for (Map.Entry<K, A> e : m2.entrySet()) {
+                        m1.merge(e.getKey(), e.getValue(), downstream.combiner());
+                    }
+                    return m1;
+                };
+            }
+
+            @Override
+            public Function<HashMap<K, A>, M> finisher() {
+                return hasFinished ? i -> (M) i : intermediate -> {
+                    intermediate.replaceAll((k, v) -> (A) downstream.finisher().apply(v));
+                    @SuppressWarnings("unchecked")
+                    M castResult = (M) intermediate;
+                    return castResult;
+                };
+            }
+
+            @Override
+            public Set<Characteristics> characteristics() {
+                return hasFinished ? Collections.unmodifiableSet(EnumSet.of(Collector.Characteristics.IDENTITY_FINISH)) : Collections.emptySet();
+            }
+        });
     }
     }
 
 
 
 

+ 7 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/toolkit/SimpleQueryTest.java

@@ -11,6 +11,7 @@ import com.baomidou.mybatisplus.test.rewrite.EntityMapper;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Test;
 
 
 import java.util.*;
 import java.util.*;
+import java.util.stream.Collectors;
 
 
 /**
 /**
  * 简单查询工具类测试
  * 简单查询工具类测试
@@ -79,6 +80,12 @@ public class SimpleQueryTest extends BaseDbTest<EntityMapper> {
         map.put("ruben", Arrays.asList(ruben, ruben2));
         map.put("ruben", Arrays.asList(ruben, ruben2));
         Assert.isTrue(nameUsersMap.equals(map), "Ops!");
         Assert.isTrue(nameUsersMap.equals(map), "Ops!");
 
 
+        // 解锁高级玩法:
+        // 获取Map<name,List<id>>
+        Map<String, List<Long>> nameIdMap = SimpleQuery.group(Wrappers.lambdaQuery(), Entity::getName, Collectors.mapping(Entity::getId, Collectors.toList()));
+        // 获取Map<name,个数>
+        Map<String, Long> nameCountMap = SimpleQuery.group(Wrappers.lambdaQuery(), Entity::getName, Collectors.counting());
+        // ...超多花样
     }
     }
 
 
     @Override
     @Override