Browse Source

[优化] lambdaQueryWrapper 的 select 优化

miemie 6 years ago
parent
commit
b2dc3e94c9

+ 22 - 7
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/AbstractLambdaWrapper.java

@@ -15,17 +15,18 @@
  */
 package com.baomidou.mybatisplus.core.conditions;
 
-import com.baomidou.mybatisplus.core.toolkit.Assert;
-import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
-import com.baomidou.mybatisplus.core.toolkit.LambdaUtils;
-import com.baomidou.mybatisplus.core.toolkit.StringUtils;
+import com.baomidou.mybatisplus.core.toolkit.*;
+import com.baomidou.mybatisplus.core.toolkit.support.ColumnCache;
 import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
 import com.baomidou.mybatisplus.core.toolkit.support.SerializedLambda;
 
+import java.util.Arrays;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Optional;
 
+import static java.util.stream.Collectors.joining;
+
 /**
  * <p>
  * Lambda 语法使用 Wrapper
@@ -39,7 +40,7 @@ import java.util.Optional;
 public abstract class AbstractLambdaWrapper<T, Children extends AbstractLambdaWrapper<T, Children>>
     extends AbstractWrapper<T, SFunction<T, ?>, Children> {
 
-    private Map<String, String> columnMap = null;
+    private Map<String, ColumnCache> columnMap = null;
     private boolean initColumnMap = false;
 
     @Override
@@ -51,12 +52,25 @@ public abstract class AbstractLambdaWrapper<T, Children extends AbstractLambdaWr
         }
     }
 
+    @Override
+    protected String columnsToString(SFunction<T, ?>... columns) {
+        return columnsToString(true, columns);
+    }
+
+    protected String columnsToString(boolean onlyColumn, SFunction<T, ?>... columns) {
+        return Arrays.stream(columns).map(i -> columnToString(i, onlyColumn)).collect(joining(StringPool.COMMA));
+    }
+
     @Override
     protected String columnToString(SFunction<T, ?> column) {
-        return getColumn(LambdaUtils.resolve(column));
+        return columnToString(column, true);
+    }
+
+    protected String columnToString(SFunction<T, ?> column, boolean onlyColumn) {
+        return getColumn(LambdaUtils.resolve(column), onlyColumn);
     }
 
-    private String getColumn(SerializedLambda lambda) {
+    private String getColumn(SerializedLambda lambda, boolean onlyColumn) {
         String fieldName = StringUtils.resolveFieldName(lambda.getImplMethodName());
         if (!initColumnMap || !columnMap.containsKey(fieldName.toUpperCase(Locale.ENGLISH))) {
             String entityClassName = lambda.getImplClassName();
@@ -66,6 +80,7 @@ public abstract class AbstractLambdaWrapper<T, Children extends AbstractLambdaWr
             initColumnMap = true;
         }
         return Optional.ofNullable(columnMap.get(fieldName.toUpperCase(Locale.ENGLISH)))
+            .map(onlyColumn ? ColumnCache::getColumn : ColumnCache::getColumnSelect)
             .orElseThrow(() -> ExceptionUtils.mpe("your property named %s cannot find the corresponding database column name!", fieldName));
     }
 }

+ 1 - 1
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/query/LambdaQueryWrapper.java

@@ -74,7 +74,7 @@ public class LambdaQueryWrapper<T> extends AbstractLambdaWrapper<T, LambdaQueryW
     @Override
     public final LambdaQueryWrapper<T> select(SFunction<T, ?>... columns) {
         if (ArrayUtils.isNotEmpty(columns)) {
-            this.sqlSelect.setStringValue(this.columnsToString(columns));
+            this.sqlSelect.setStringValue(columnsToString(false, columns));
         }
         return typedThis;
     }

+ 19 - 14
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/LambdaUtils.java

@@ -17,6 +17,7 @@
 package com.baomidou.mybatisplus.core.toolkit;
 
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
+import com.baomidou.mybatisplus.core.toolkit.support.ColumnCache;
 import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
 import com.baomidou.mybatisplus.core.toolkit.support.SerializedLambda;
 
@@ -39,7 +40,7 @@ import static java.util.Locale.ENGLISH;
  */
 public final class LambdaUtils {
 
-    private static final Map<String, Map<String, String>> LAMBDA_CACHE = new ConcurrentHashMap<>();
+    private static final Map<String, Map<String, ColumnCache>> LAMBDA_CACHE = new ConcurrentHashMap<>();
 
     /**
      * SerializedLambda 反序列化缓存
@@ -81,13 +82,13 @@ public final class LambdaUtils {
     /**
      * 保存缓存信息
      *
-     * @param className 类名
-     * @param property  属性
-     * @param sqlSelect 字段搜索
+     * @param className   类名
+     * @param property    属性
+     * @param columnCache 字段信息
      */
-    private static void saveCache(String className, String property, String sqlSelect) {
-        Map<String, String> cacheMap = LAMBDA_CACHE.getOrDefault(className, new HashMap<>());
-        cacheMap.put(property, sqlSelect);
+    private static void saveCache(String className, String property, ColumnCache columnCache) {
+        Map<String, ColumnCache> cacheMap = LAMBDA_CACHE.getOrDefault(className, new HashMap<>());
+        cacheMap.put(property, columnCache);
         LAMBDA_CACHE.put(className, cacheMap);
     }
 
@@ -99,24 +100,28 @@ public final class LambdaUtils {
      * @param tableInfo 表信息
      * @return 缓存 map
      */
-    private static Map<String, String> createLambdaMap(TableInfo tableInfo, Class clazz) {
-        Map<String, String> map = new HashMap<>();
+    private static Map<String, ColumnCache> createLambdaMap(TableInfo tableInfo, Class clazz) {
+        Map<String, ColumnCache> map = new HashMap<>();
         String keyProperty = tableInfo.getKeyProperty();
         if (StringUtils.isNotEmpty(keyProperty)) {
             keyProperty = keyProperty.toUpperCase(ENGLISH);
             String keyColumn = tableInfo.getKeyColumn();
+            String keySelect = tableInfo.getSqlSelect();
+            ColumnCache cache = new ColumnCache(keyColumn, keySelect);
             if (tableInfo.getClazz() != clazz) {
-                saveCache(tableInfo.getClazz().getName(), keyProperty, keyColumn);
+                saveCache(tableInfo.getClazz().getName(), keyProperty, cache);
             }
-            map.put(keyProperty, keyColumn);
+            map.put(keyProperty, cache);
         }
         tableInfo.getFieldList().forEach(i -> {
             String property = i.getProperty().toUpperCase(ENGLISH);
             String column = i.getColumn();
+            String columnSelect = i.getSqlSelect(tableInfo.getDbType());
+            ColumnCache cache = new ColumnCache(column, columnSelect);
             if (i.getClazz() != clazz) {
-                saveCache(i.getClazz().getName(), property, column);
+                saveCache(i.getClazz().getName(), property, cache);
             }
-            map.put(property, column);
+            map.put(property, cache);
         });
         return map;
     }
@@ -129,7 +134,7 @@ public final class LambdaUtils {
      * @param entityClassName 实体类名
      * @return 缓存 map
      */
-    public static Map<String, String> getColumnMap(String entityClassName) {
+    public static Map<String, ColumnCache> getColumnMap(String entityClassName) {
         return LAMBDA_CACHE.getOrDefault(entityClassName, Collections.emptyMap());
     }
 }

+ 22 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/support/ColumnCache.java

@@ -0,0 +1,22 @@
+package com.baomidou.mybatisplus.core.toolkit.support;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+
+/**
+ * @author miemie
+ * @since 2018-12-30
+ */
+@Data
+@AllArgsConstructor
+public class ColumnCache {
+
+    /**
+     * 使用 column
+     */
+    private String column;
+    /**
+     * 查询 column
+     */
+    private String columnSelect;
+}

+ 20 - 2
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/kotlin/AbstractKtWrapper.kt

@@ -16,8 +16,12 @@
 package com.baomidou.mybatisplus.extension.kotlin
 
 import com.baomidou.mybatisplus.core.conditions.AbstractWrapper
+import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils
 import com.baomidou.mybatisplus.core.toolkit.LambdaUtils
+import com.baomidou.mybatisplus.core.toolkit.StringPool
+import com.baomidou.mybatisplus.core.toolkit.support.ColumnCache
 import java.util.*
+import java.util.stream.Collectors.joining
 import kotlin.reflect.KProperty
 
 /**
@@ -29,14 +33,28 @@ import kotlin.reflect.KProperty
  */
 abstract class AbstractKtWrapper<T, This : AbstractKtWrapper<T, This>> : AbstractWrapper<T, KProperty<*>, This>() {
 
-    private var columnMap: Map<String, String>? = null
+    private var columnMap: Map<String, ColumnCache>? = null
 
     override fun initEntityClass() {
         super.initEntityClass()
         columnMap = LambdaUtils.getColumnMap(this.entityClass.name)
     }
 
+    override fun columnsToString(vararg columns: KProperty<*>): String {
+        return columnsToString(true, *columns)
+    }
+
+    fun columnsToString(onlyColumn: Boolean, vararg columns: KProperty<*>): String {
+        return Arrays.stream(columns).map { i -> columnToString(i, onlyColumn) }.collect(joining(StringPool.COMMA))
+    }
+
     override fun columnToString(kProperty: KProperty<*>): String? {
-        return columnMap?.get(kProperty.name.toUpperCase(Locale.ENGLISH))
+        return columnToString(kProperty, true)
+    }
+
+    fun columnToString(kProperty: KProperty<*>, onlyColumn: Boolean): String? {
+        return Optional.ofNullable(columnMap?.get(kProperty.name.toUpperCase(Locale.ENGLISH)))
+            .map(if (onlyColumn) ColumnCache::getColumn else ColumnCache::getColumnSelect)
+            .orElseThrow { ExceptionUtils.mpe("your property named %s cannot find the corresponding database column name!", kProperty.name) }
     }
 }

+ 3 - 3
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/kotlin/KtQueryWrapper.kt

@@ -43,8 +43,8 @@ class KtQueryWrapper<T : Any> : AbstractKtWrapper<T, KtQueryWrapper<T>>, Query<K
         this.initNeed()
     }
 
-    internal constructor(entity: T, entityClass: Class<T>?, sqlSelect: String?, paramNameSeq: AtomicInteger, paramNameValuePairs: Map<String, Any>,
-                         mergeSegments: MergeSegments) {
+    internal constructor(entity: T, entityClass: Class<T>?, sqlSelect: String?, paramNameSeq: AtomicInteger,
+                         paramNameValuePairs: Map<String, Any>, mergeSegments: MergeSegments) {
         this.entity = entity
         this.paramNameSeq = paramNameSeq
         this.paramNameValuePairs = paramNameValuePairs
@@ -61,7 +61,7 @@ class KtQueryWrapper<T : Any> : AbstractKtWrapper<T, KtQueryWrapper<T>>, Query<K
     @SafeVarargs
     override fun select(vararg columns: KProperty<*>): KtQueryWrapper<T> {
         if (ArrayUtils.isNotEmpty(columns)) {
-            this.sqlSelect = this.columnsToString(*columns)
+            this.sqlSelect = this.columnsToString(false, *columns)
         }
         return typedThis
     }

+ 2 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/base/entity/mysql/MysqlData.java

@@ -20,4 +20,6 @@ public class MysqlData {
     private Integer group;
     @TableField(strategy = FieldStrategy.NOT_EMPTY)
     private String testStr;
+    @TableField("lambda_str")
+    private String yaHoStr;
 }

+ 31 - 28
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/mysql/MysqlTestDataMapperTest.java

@@ -1,21 +1,5 @@
 package com.baomidou.mybatisplus.test.mysql;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-import javax.annotation.Resource;
-
-import org.junit.Assert;
-import org.junit.FixMethodOrder;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.MethodSorters;
-import org.springframework.test.context.ContextConfiguration;
-import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
-
 import com.alibaba.fastjson.JSON;
 import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
 import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
@@ -32,6 +16,18 @@ import com.baomidou.mybatisplus.test.base.enums.TestEnum;
 import com.baomidou.mybatisplus.test.base.mapper.commons.CommonDataMapper;
 import com.baomidou.mybatisplus.test.base.mapper.commons.CommonLogicDataMapper;
 import com.baomidou.mybatisplus.test.base.mapper.mysql.MysqlDataMapper;
+import com.baomidou.mybatisplus.test.mysql.config.MysqlDb;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.FixMethodOrder;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.MethodSorters;
+import org.springframework.test.context.ContextConfiguration;
+import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
+
+import javax.annotation.Resource;
+import java.util.*;
 
 
 /**
@@ -54,20 +50,21 @@ public class MysqlTestDataMapperTest {
     @Resource
     private MysqlDataMapper mysqlMapper;
 
-//    @BeforeClass
-//    public static void init() throws Exception {
-//        MysqlDb.initMysqlData();
-//        System.out.println("init success");
-//    }
+    @BeforeClass
+    public static void init() throws Exception {
+        MysqlDb.initMysqlData();
+        System.out.println("init success");
+    }
 
     @Test
     public void a1_insertForeach() {
         for (int i = 1; i < 20; i++) {
             Long id = (long) i;
-            commonMapper.insert(new CommonData().setTestInt(i).setTestStr(String.format("第%s条数据", i)).setId(id)
+            String str = String.format("第%s条数据", i);
+            commonMapper.insert(new CommonData().setTestInt(i).setTestStr(str).setId(id)
                 .setTestEnum(TestEnum.ONE));
-            commonLogicMapper.insert(new CommonLogicData().setTestInt(i).setTestStr(String.format("第%s条数据", i)).setId(id));
-            mysqlMapper.insert(new MysqlData().setOrder(i).setGroup(i).setId(id).setTestStr(String.format("第%s条数据", i)));
+            commonLogicMapper.insert(new CommonLogicData().setTestInt(i).setTestStr(str).setId(id));
+            mysqlMapper.insert(new MysqlData().setOrder(i).setGroup(i).setId(id).setTestStr(str).setYaHoStr(str));
         }
     }
 
@@ -77,9 +74,10 @@ public class MysqlTestDataMapperTest {
         List<CommonData> commonDataList = new ArrayList<>();
         List<CommonLogicData> commonLogicDataList = new ArrayList<>();
         for (int i = 0; i < 9; i++) {
-            mysqlDataList.add(new MysqlData().setOrder(i).setGroup(i).setTestStr(i + "条"));
-            commonDataList.add(new CommonData().setTestInt(i).setTestEnum(TestEnum.TWO).setTestStr(i + "条"));
-            commonLogicDataList.add(new CommonLogicData().setTestInt(i).setTestStr(i + "条"));
+            String str = i + "条";
+            mysqlDataList.add(new MysqlData().setOrder(i).setGroup(i).setTestStr(str).setYaHoStr(str));
+            commonDataList.add(new CommonData().setTestInt(i).setTestEnum(TestEnum.TWO).setTestStr(str));
+            commonLogicDataList.add(new CommonLogicData().setTestInt(i).setTestStr(str));
         }
         Assert.assertEquals(9, mysqlMapper.insertBatchSomeColumn(mysqlDataList));
         Assert.assertEquals(9, commonMapper.insertBatchSomeColumn(commonDataList));
@@ -355,7 +353,7 @@ public class MysqlTestDataMapperTest {
     @Test
     public void d11_testWrapperCustomSql() {
         // 1. 只有 order by 或者 last
-        mysqlMapper.getAll(Wrappers.<MysqlData>query().lambda().orderByDesc(MysqlData::getOrder).last("limit 1"));
+        mysqlMapper.getAll(Wrappers.<MysqlData>lambdaQuery().orderByDesc(MysqlData::getOrder).last("limit 1"));
         // 2. 什么都没有情况
         mysqlMapper.getAll(Wrappers.emptyWrapper());
         // 3. 只有 where 条件
@@ -379,6 +377,11 @@ public class MysqlTestDataMapperTest {
 //        commonMapper.selectPage(new Page<>(1, 10), wrapper);
     }
 
+    @Test
+    public void testLambdaColumnCache() {
+        mysqlMapper.selectList(Wrappers.<MysqlData>lambdaQuery().select(MysqlData::getYaHoStr)).forEach(System.out::print);
+    }
+
     @Test
     public void testUpdateNotEntity() {
         mysqlMapper.update(null, Wrappers.<MysqlData>lambdaUpdate().set(MysqlData::getOrder, 1));

+ 8 - 10
mybatis-plus/src/test/resources/mysql/test_data.ddl.sql

@@ -10,8 +10,7 @@ CREATE TABLE common_data (
     version   integer default 0,
     test_enum integer,
     tenant_id bigint
-)
-    ENGINE = innodb
+)ENGINE = innodb
 DEFAULT CHARSET = utf8;
 
 CREATE TABLE common_logic_data (
@@ -22,15 +21,14 @@ CREATE TABLE common_logic_data (
     u_time   datetime,
     deleted  tinyint default 0,
     version  integer default 0
-)
-    ENGINE = innodb
+)ENGINE = innodb
 DEFAULT CHARSET = utf8;
 
 CREATE TABLE mysql_data (
-    id       BIGINT primary key,
-    `order`  integer,
-    `group`  integer,
-    test_str varchar(255)
-)
-    ENGINE = innodb
+    id         BIGINT primary key,
+    `order`    integer,
+    `group`    integer,
+    test_str   varchar(255),
+    lambda_str varchar(255) default ''
+)ENGINE = innodb
 DEFAULT CHARSET = utf8;