소스 검색

Wrapper 优化ing

miemie 5 년 전
부모
커밋
0ffc249f68

+ 3 - 2
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/AbstractWrapper.java

@@ -35,7 +35,8 @@ import java.util.function.BiPredicate;
 import java.util.function.Consumer;
 
 import static com.baomidou.mybatisplus.core.enums.SqlKeyword.*;
-import static com.baomidou.mybatisplus.core.enums.WrapperKeyword.*;
+import static com.baomidou.mybatisplus.core.enums.WrapperKeyword.APPLY;
+import static com.baomidou.mybatisplus.core.enums.WrapperKeyword.BRACKET;
 import static java.util.stream.Collectors.joining;
 
 /**
@@ -364,7 +365,7 @@ public abstract class AbstractWrapper<T, R, Children extends AbstractWrapper<T,
         if (condition) {
             final Children instance = instance();
             consumer.accept(instance);
-            return doIt(true, LEFT_BRACKET, instance, RIGHT_BRACKET);
+            return doIt(true, BRACKET, instance);
         }
         return typedThis;
     }

+ 7 - 1
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/Wrapper.java

@@ -52,7 +52,9 @@ public abstract class Wrapper<T> implements ISqlSegment {
         return null;
     }
 
-    public String getSqlFirst() { return null; }
+    public String getSqlFirst() {
+        return null;
+    }
 
     /**
      * 获取 MergeSegments
@@ -158,4 +160,8 @@ public abstract class Wrapper<T> implements ISqlSegment {
     public boolean isEmptyOfEntity() {
         return !nonEmptyOfEntity();
     }
+
+    public String getTargetSql() {
+        return getSqlSegment().replaceAll("#\\{.+?}", "?");
+    }
 }

+ 1 - 2
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/segments/MatchSegment.java

@@ -37,8 +37,7 @@ public enum MatchSegment {
     EXISTS(i -> i == SqlKeyword.EXISTS),
     HAVING(i -> i == SqlKeyword.HAVING),
     APPLY(i -> i == WrapperKeyword.APPLY),
-    LEFT_BRACKET(i -> i == WrapperKeyword.LEFT_BRACKET),
-    RIGHT_BRACKET(i -> i == WrapperKeyword.RIGHT_BRACKET);
+    BRACKET(i -> i == WrapperKeyword.BRACKET);
 
     private final Predicate<ISqlSegment> predicate;
 

+ 5 - 6
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/segments/NormalSegmentList.java

@@ -71,15 +71,14 @@ public class NormalSegmentList extends AbstractISegmentList {
                 list.add(MatchSegment.EXISTS.match(firstSegment) ? 0 : 1, SqlKeyword.NOT);
                 executeNot = true;
             }
-            if (!MatchSegment.AND_OR.match(lastValue) && !isEmpty()) {
-                add(SqlKeyword.AND);
-            }
             if (MatchSegment.APPLY.match(firstSegment)) {
                 list.remove(0);
             }
-            if (MatchSegment.LEFT_BRACKET.match(firstSegment) && MatchSegment.RIGHT_BRACKET.match(lastSegment)) {
+            if (MatchSegment.BRACKET.match(firstSegment)) {
                 list.remove(0);
-                list.remove(list.size() - 1);
+            }
+            if (!MatchSegment.AND_OR.match(lastValue) && !isEmpty()) {
+                add(SqlKeyword.AND);
             }
         }
         return true;
@@ -91,6 +90,6 @@ public class NormalSegmentList extends AbstractISegmentList {
             removeAndFlushLast();
         }
         final String str = this.stream().map(ISqlSegment::getSqlSegment).collect(Collectors.joining(SPACE));
-        return (str.startsWith(LEFT_BRACKET) && str.endsWith(RIGHT_BRACKET)) ? str : (LEFT_BRACKET + str + RIGHT_BRACKET);
+        return (LEFT_BRACKET + str + RIGHT_BRACKET);
     }
 }

+ 1 - 3
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/enums/WrapperKeyword.java

@@ -16,7 +16,6 @@
 package com.baomidou.mybatisplus.core.enums;
 
 import com.baomidou.mybatisplus.core.conditions.ISqlSegment;
-import com.baomidou.mybatisplus.core.toolkit.StringPool;
 
 /**
  * wrapper 内部使用枚举
@@ -29,8 +28,7 @@ public enum WrapperKeyword implements ISqlSegment {
      * 只用作于辨识,不用于其他
      */
     APPLY(null),
-    LEFT_BRACKET(StringPool.LEFT_BRACKET),
-    RIGHT_BRACKET(StringPool.RIGHT_BRACKET);
+    BRACKET(null);
 
     private final String keyword;
 

+ 47 - 32
mybatis-plus-core/src/test/java/com/baomidou/mybatisplus/core/test/WrapperTest.java

@@ -16,13 +16,13 @@
 package com.baomidou.mybatisplus.core.test;
 
 import com.baomidou.mybatisplus.core.MybatisConfiguration;
-import com.baomidou.mybatisplus.core.conditions.ISqlSegment;
 import com.baomidou.mybatisplus.core.conditions.Wrapper;
 import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
 import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
 import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
 import com.baomidou.mybatisplus.core.toolkit.StringPool;
 import org.apache.ibatis.builder.MapperBuilderAssistant;
+import org.assertj.core.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 import java.time.LocalDate;
@@ -34,9 +34,11 @@ class WrapperTest {
         System.out.println(message);
     }
 
-    private void logSqlSegment(String explain, ISqlSegment sqlSegment) {
+    private void logSqlSegment(String explain, Wrapper<?> wrapper, String targetSql) {
         System.out.println(String.format(" ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓   ->(%s)<-   ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓", explain));
-        System.out.println(sqlSegment.getSqlSegment());
+        System.out.println(wrapper.getSqlSegment());
+        System.out.println(wrapper.getTargetSql());
+        Assertions.assertThat(wrapper.getTargetSql().trim()).isEqualTo(targetSql);
     }
 
     private <T> void logParams(QueryWrapper<T> wrapper) {
@@ -59,15 +61,17 @@ class WrapperTest {
     @Test
     void test1() {
         QueryWrapper<User> ew = new QueryWrapper<User>() {
-          /**
-           *  serialVersionUID
-           */
-          private static final long serialVersionUID = 4719966531503901490L;
-        {
-            eq("xxx", 123);
-            and(i -> i.eq("andx", 65444).le("ande", 66666));
-            ne("xxx", 222);
-        }};
+            /**
+             * serialVersionUID
+             */
+            private static final long serialVersionUID = 4719966531503901490L;
+
+            {
+                eq("xxx", 123);
+                and(i -> i.eq("andx", 65444).le("ande", 66666));
+                ne("xxx", 222);
+            }
+        };
         log(ew.getSqlSegment());
         log(ew.getSqlSegment());
         ew.gt("x22", 333);
@@ -101,40 +105,51 @@ class WrapperTest {
     @Test
     void testQueryWrapper() {
         logSqlSegment("去除第一个 or,以及自动拼接 and,以及手动拼接 or,以及去除最后的多个or", new QueryWrapper<User>().or()
-            .ge("age", 3).or().ge("age", 3).ge("age", 3).or().or().or().or());
+                .ge("age", 3).or().ge("age", 3).ge("age", 3).or().or().or().or(),
+            "(age >= ? OR age >= ? AND age >= ?)");
 
         logSqlSegment("多个 or 相连接,去除多余的 or", new QueryWrapper<User>()
-            .ge("age", 3).or().or().or().ge("age", 3).or().or().ge("age", 3));
+                .ge("age", 3).or().or().or().ge("age", 3).or().or().ge("age", 3),
+            "(age >= ? OR age >= ? OR age >= ?)");
 
         logSqlSegment("嵌套,正常嵌套", new QueryWrapper<User>()
-            .nested(i -> i.eq("id", 1)).eq("id", 1));
+                .nested(i -> i.eq("id", 1)).eq("id", 1),
+            "((id = ?) AND id = ?)");
 
         logSqlSegment("嵌套,第一个套外的 and 自动消除", new QueryWrapper<User>()
-            .and(i -> i.eq("id", 1)).eq("id", 1));
+                .and(i -> i.eq("id", 1)).eq("id", 1),
+            "((id = ?) AND id = ?)");
 
         logSqlSegment("嵌套,多层嵌套", new QueryWrapper<User>()
-            .and(i -> i.eq("id", 1).and(j -> j.eq("id", 1))));
+                .and(i -> i.eq("id", 1).and(j -> j.eq("id", 1))),
+            "((id = ?) AND (id = ?))");
 
         logSqlSegment("嵌套,第一个套外的 or 自动消除", new QueryWrapper<User>()
-            .or(i -> i.eq("id", 1)).eq("id", 1));
+                .or(i -> i.eq("id", 1)).eq("id", 1),
+            "((id = ?) AND id = ?)");
 
         logSqlSegment("嵌套,套内外自动拼接 and", new QueryWrapper<User>()
-            .eq("id", 11).and(i -> i.eq("id", 1)).eq("id", 1));
+                .eq("id", 11).and(i -> i.eq("id", 1)).eq("id", 1),
+            "(id = ? AND (id = ?) AND id = ?)");
 
         logSqlSegment("嵌套,套内外手动拼接 or,去除套内第一个 or", new QueryWrapper<User>()
-            .eq("id", 11).or(i -> i.or().eq("id", 1)).or().eq("id", 1));
+                .eq("id", 11).or(i -> i.or().eq("id", 1)).or().eq("id", 1),
+            "(id = ? OR (id = ?) OR id = ?)");
 
         logSqlSegment("多个 order by 和 group by 拼接,自动优化顺序,last方法拼接在最后", new QueryWrapper<User>()
-            .eq("id", 11)
-            .last("limit 1")
-            .orderByAsc("id", "name", "sex").orderByDesc("age", "txl")
-            .groupBy("id", "name", "sex").groupBy("id", "name"));
+                .eq("id", 11)
+                .last("limit 1")
+                .orderByAsc("id", "name", "sex").orderByDesc("age", "txl")
+                .groupBy("id", "name", "sex").groupBy("id", "name"),
+            "(id = ?) GROUP BY id,name,sex,id,name ORDER BY id ASC,name ASC,sex ASC,age DESC,txl DESC limit 1");
 
         logSqlSegment("只存在 order by", new QueryWrapper<User>()
-            .orderByAsc("id", "name", "sex").orderByDesc("age", "txl"));
+                .orderByAsc("id", "name", "sex").orderByDesc("age", "txl"),
+            "ORDER BY id ASC,name ASC,sex ASC,age DESC,txl DESC");
 
         logSqlSegment("只存在 group by", new QueryWrapper<User>()
-            .groupBy("id", "name", "sex").groupBy("id", "name"));
+                .groupBy("id", "name", "sex").groupBy("id", "name"),
+            "GROUP BY id,name,sex,id,name");
     }
 
     @Test
@@ -147,7 +162,7 @@ class WrapperTest {
             .or().between("id", 1, 2).notBetween("id", 1, 3)
             .like("id", 1).notLike("id", 1)
             .or().likeLeft("id", 1).likeRight("id", 1);
-        logSqlSegment("测试 Compare 下的方法", queryWrapper);
+        logSqlSegment("测试 Compare 下的方法", queryWrapper, null);
         logParams(queryWrapper);
     }
 
@@ -161,7 +176,7 @@ class WrapperTest {
             .in("inArray").notIn("notInArray", 1, 2, 3)
             .inSql("inSql", "1,2,3,4,5").notInSql("inSql", "1,2,3,4,5")
             .having("sum(age) > {0}", 1).having("id is not null");
-        logSqlSegment("测试 Func 下的方法", queryWrapper);
+        logSqlSegment("测试 Func 下的方法", queryWrapper, null);
         logParams(queryWrapper);
     }
 
@@ -173,7 +188,7 @@ class WrapperTest {
             .apply("date_format(column,'%Y-%m-%d') = {0}", LocalDate.now())
             .or().exists("select id from table where age = 1")
             .or().notExists("select id from table where age = 1");
-        logSqlSegment("测试 Join 下的方法", queryWrapper);
+        logSqlSegment("测试 Join 下的方法", queryWrapper, null);
         logParams(queryWrapper);
     }
 
@@ -183,7 +198,7 @@ class WrapperTest {
             .and(i -> i.eq("id", 1).nested(j -> j.ne("id", 2)))
             .or(i -> i.eq("id", 1).and(j -> j.ne("id", 2)))
             .nested(i -> i.eq("id", 1).or(j -> j.ne("id", 2)));
-        logSqlSegment("测试 Nested 下的方法", queryWrapper);
+        logSqlSegment("测试 Nested 下的方法", queryWrapper, null);
         logParams(queryWrapper);
     }
 
@@ -193,14 +208,14 @@ class WrapperTest {
         QueryWrapper<User> queryWrapper = new QueryWrapper<>();
         queryWrapper.lambda().eq(User::getName, "sss");
         queryWrapper.lambda().eq(User::getName, "sss2");
-        logSqlSegment("测试 PluralLambda", queryWrapper);
+        logSqlSegment("测试 PluralLambda", queryWrapper, null);
         logParams(queryWrapper);
     }
 
     @Test
     void testInEmptyColl() {
         QueryWrapper<User> queryWrapper = new QueryWrapper<User>().in("xxx", Collections.emptyList());
-        logSqlSegment("测试 empty 的 coll", queryWrapper);
+        logSqlSegment("测试 empty 的 coll", queryWrapper, null);
     }
 
     private List<Object> getList() {

+ 3 - 3
mybatis-plus-extension/src/main/kotlin/com/baomidou/mybatisplus/extension/kotlin/KtQueryWrapper.kt

@@ -50,12 +50,12 @@ class KtQueryWrapper<T : Any> : AbstractKtWrapper<T, KtQueryWrapper<T>>, Query<K
     }
 
     internal constructor(entity: T?, entityClass: Class<T>, sqlSelect: SharedString, paramNameSeq: AtomicInteger,
-                         paramNameValuePairs: Map<String, Any>, mergeSegments: MergeSegments, columnMap: Map<String, ColumnCache>,
+                         paramNameValuePairs: Map<String, Any>, columnMap: Map<String, ColumnCache>,
                          lastSql: SharedString, sqlComment: SharedString, sqlFirst: SharedString) {
         this.entity = entity
         this.paramNameSeq = paramNameSeq
         this.paramNameValuePairs = paramNameValuePairs
-        this.expression = mergeSegments
+        this.expression = MergeSegments()
         this.columnMap = columnMap
         this.sqlSelect = sqlSelect
         this.entityClass = entityClass
@@ -110,7 +110,7 @@ class KtQueryWrapper<T : Any> : AbstractKtWrapper<T, KtQueryWrapper<T>>, Query<K
      * 故 sqlSelect 不向下传递
      */
     override fun instance(): KtQueryWrapper<T> {
-        return KtQueryWrapper(entity, entityClass, sqlSelect, paramNameSeq, paramNameValuePairs, expression, columnMap,
+        return KtQueryWrapper(entity, entityClass, sqlSelect, paramNameSeq, paramNameValuePairs, columnMap,
             SharedString.emptyString(), SharedString.emptyString(), SharedString.emptyString())
     }
 }

+ 4 - 4
mybatis-plus-extension/src/main/kotlin/com/baomidou/mybatisplus/extension/kotlin/KtUpdateWrapper.kt

@@ -50,12 +50,12 @@ class KtUpdateWrapper<T : Any> : AbstractKtWrapper<T, KtUpdateWrapper<T>>, Updat
     }
 
     internal constructor(entity: T?, paramNameSeq: AtomicInteger, paramNameValuePairs: Map<String, Any>,
-                         mergeSegments: MergeSegments, columnMap: Map<String, ColumnCache>,
-                         lastSql: SharedString, sqlComment: SharedString, sqlFirst: SharedString) {
+                         columnMap: Map<String, ColumnCache>, lastSql: SharedString, sqlComment: SharedString,
+                         sqlFirst: SharedString) {
         this.entity = entity
         this.paramNameSeq = paramNameSeq
         this.paramNameValuePairs = paramNameValuePairs
-        this.expression = mergeSegments
+        this.expression = MergeSegments()
         this.columnMap = columnMap
         this.lastSql = lastSql
         this.sqlComment = sqlComment
@@ -82,7 +82,7 @@ class KtUpdateWrapper<T : Any> : AbstractKtWrapper<T, KtUpdateWrapper<T>>, Updat
     }
 
     override fun instance(): KtUpdateWrapper<T> {
-        return KtUpdateWrapper(entity, paramNameSeq, paramNameValuePairs, expression, columnMap,
+        return KtUpdateWrapper(entity, paramNameSeq, paramNameValuePairs, columnMap,
             SharedString.emptyString(), SharedString.emptyString(), SharedString.emptyString())
     }
 }

+ 32 - 36
mybatis-plus-extension/src/test/kotlin/com/baomidou/mybatisplus/extension/kotlin/FixIssue1986.kt

@@ -24,6 +24,7 @@ class FixIssue1986 {
     fun test1986() {
         val wrapper = fillQueryWrapper(OpportunityWebPageQuery())
         var sql = wrapper.toSql()
+        print(sql)
         // (valid = ? AND district_id = ? AND name LIKE ? OR phone LIKE ? AND (valid = ? AND district_id = ?))
     }
 
@@ -35,49 +36,44 @@ fun KtQueryWrapper<*>.toSql() = sqlSegment?.replace(Regex("#\\{.+?}"), "?") ?: "
  * 用户代码
  */
 private fun fillQueryWrapper(query: OpportunityWebPageQuery): KtQueryWrapper<CustomerEntity> {
-    val wrapper = KtQueryWrapper(CustomerEntity::class.java)
-    wrapper.eq(CustomerEntity::valid, query.valid)
-    if (!query.districtId.isNullOrEmpty()) {
-        wrapper.eq(CustomerEntity::districtId, query.districtId)
-    } else if (!query.cityId.isNullOrEmpty()) {
-        wrapper.eq(CustomerEntity::cityId, query.cityId)
-    } else if (!query.provinceId.isNullOrEmpty()) {
-        wrapper.eq(CustomerEntity::provinceId, query.provinceId)
-    } else if (!query.region.isNullOrEmpty() && RegionType.of(query.region!!.toInt())?.areaCodes?.toList()?.isNullOrEmpty() != false) {
-        wrapper.`in`(CustomerEntity::provinceId, RegionType.of(query.region!!.toInt())?.areaCodes?.toList())
-    }
-    wrapper.entity = CustomerEntity()
-    if (!query.searchKey.isNullOrEmpty()) {
-        wrapper.and { itemWrapper ->
-            itemWrapper.like(CustomerEntity::name, query.searchKey).or()
-                    .like(CustomerEntity::phone, query.searchKey)
+//    return KtQueryWrapper(CustomerEntity::class.java)
+//        .eq(CustomerEntity::valid, query.valid)
+//        .eq(!query.districtId.isNullOrEmpty(), CustomerEntity::districtId, query.districtId)
+//        .eq(!query.cityId.isNullOrEmpty(), CustomerEntity::cityId, query.cityId)
+//        .eq(!query.provinceId.isNullOrEmpty(), CustomerEntity::provinceId, query.provinceId)
+//        .`in`(!query.region.isNullOrEmpty() && RegionType.of(query.region!!.toInt())?.areaCodes?.toList()?.isNullOrEmpty() != false,
+//            CustomerEntity::provinceId, RegionType.of(query.region!!.toInt())?.areaCodes?.toList())
+//        .and(!query.searchKey.isNullOrEmpty()) { i ->
+//            i.like(CustomerEntity::name, query.searchKey).or()
+//                .like(CustomerEntity::phone, query.searchKey)
+//        }
+//        .eq(query.opportunityType != 0, CustomerEntity::type, query.opportunityType)
+    return KtQueryWrapper(CustomerEntity::class.java)
+        .eq(CustomerEntity::valid, query.valid)
+        .and { i ->
+            i.like(CustomerEntity::name, query.searchKey).or().like(CustomerEntity::phone, query.searchKey)
         }
-    }
-    if (query.opportunityType != 0) {
-        wrapper.eq(CustomerEntity::type, query.opportunityType)
-    }
-    return wrapper
 }
 
 // 用户代码模拟补全
 class CustomerEntity(
-        var valid: String? = null,
-        var name: String? = null,
-        var phone: String? = null,
-        var provinceId: String? = null,
-        var districtId: String? = null,
-        var cityId: String? = null,
-        var type: Int = 0
+    var valid: String? = null,
+    var name: String? = null,
+    var phone: String? = null,
+    var provinceId: String? = null,
+    var districtId: String? = null,
+    var cityId: String? = null,
+    var type: Int = 0
 )
 
 class OpportunityWebPageQuery(
-        var valid: String = "123",
-        var searchKey: String? = "123",
-        var provinceId: String? = "123",
-        var districtId: String? = "123",
-        var cityId: String? = "123",
-        var region: String? = "123",
-        var opportunityType: Int = 0
+    var valid: String = "123",
+    var searchKey: String? = "123",
+    var provinceId: String? = "123",
+    var districtId: String? = "123",
+    var cityId: String? = "123",
+    var region: String? = "123",
+    var opportunityType: Int = 0
 )
 
 object RegionType {
@@ -86,4 +82,4 @@ object RegionType {
 
 class Area(i: Int) {
     var areaCodes: IntArray? = IntArray(i) { it }
-}
+}