浏览代码

[优化] kt类优化

miemie 6 年之前
父节点
当前提交
8341421d40

+ 1 - 1
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/conditions/AbstractWrapper.java

@@ -83,7 +83,7 @@ public abstract class AbstractWrapper<T, R, Children extends AbstractWrapper<T,
     }
 
     protected void initEntityClass() {
-        if (this.entity != null) {
+        if (this.entityClass == null && this.entity != null) {
             this.entityClass = (Class<T>) entity.getClass();
         }
     }

+ 1 - 1
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/kotlin/AbstractKtWrapper.kt

@@ -38,7 +38,7 @@ abstract class AbstractKtWrapper<T, This : AbstractKtWrapper<T, This>> : Abstrac
 
     override fun initEntityClass() {
         super.initEntityClass()
-        columnMap = LambdaUtils.getColumnMap(this.entityClass.name)
+        columnMap = LambdaUtils.getColumnMap(this.checkEntityClass.name)
     }
 
     override fun columnsToString(vararg columns: KProperty<*>): String {

+ 7 - 2
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/kotlin/KtQueryWrapper.kt

@@ -38,7 +38,12 @@ class KtQueryWrapper<T : Any> : AbstractKtWrapper<T, KtQueryWrapper<T>>, Query<K
     private var sqlSelect: String? = null
 
     constructor(entity: T) {
-        this.entity = entity
+        this.setEntity(entity)
+        this.initNeed()
+    }
+
+    constructor(entityClass: Class<T>) {
+        this.entityClass = entityClass
         this.initEntityClass()
         this.initNeed()
     }
@@ -46,11 +51,11 @@ class KtQueryWrapper<T : Any> : AbstractKtWrapper<T, KtQueryWrapper<T>>, Query<K
     internal constructor(entity: T, entityClass: Class<T>?, sqlSelect: String?, paramNameSeq: AtomicInteger,
                          paramNameValuePairs: Map<String, Any>, mergeSegments: MergeSegments) {
         this.entity = entity
+        this.entityClass = entityClass
         this.paramNameSeq = paramNameSeq
         this.paramNameValuePairs = paramNameValuePairs
         this.expression = mergeSegments
         this.sqlSelect = sqlSelect
-        this.entityClass = entityClass
     }
 
     /**

+ 6 - 1
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/kotlin/KtUpdateWrapper.kt

@@ -38,7 +38,12 @@ class KtUpdateWrapper<T : Any> : AbstractKtWrapper<T, KtUpdateWrapper<T>>, Updat
     private val sqlSet = ArrayList<String>()
 
     constructor(entity: T) {
-        this.entity = entity
+        this.setEntity(entity)
+        this.initNeed()
+    }
+
+    constructor(entityClass: Class<T>) {
+        this.entityClass = entityClass
         this.initEntityClass()
         this.initNeed()
     }

+ 10 - 6
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/kotlin/WrapperTest.kt

@@ -17,10 +17,16 @@ package com.baomidou.mybatisplus.extension.kotlin
 
 import com.baomidou.mybatisplus.core.conditions.ISqlSegment
 import com.baomidou.mybatisplus.core.toolkit.TableInfoHelper
+import org.junit.jupiter.api.BeforeEach
 import org.junit.jupiter.api.Test
 
 class WrapperTest {
 
+    @BeforeEach
+    fun beforeInit() {
+        TableInfoHelper.initTableInfo(null, User::class.java)
+    }
+
     private fun logSqlSegment(explain: String, sqlSegment: ISqlSegment) {
         println(String.format(" ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓   ->(%s)<-   ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓", explain))
         println(sqlSegment.sqlSegment)
@@ -28,15 +34,13 @@ class WrapperTest {
 
     @Test
     fun testLambdaQuery() {
-        TableInfoHelper.initTableInfo(null, User::class.java)
-        val queryWrapper = KtQueryWrapper(User()).eq(User::name, "sss").eq(User::roleId, "sss2")
-        logSqlSegment("测试 LambdaKt", queryWrapper)
+        logSqlSegment("测试1.1 LambdaKt", KtQueryWrapper(User()).eq(User::name, "sss").eq(User::roleId, "sss2"))
+        logSqlSegment("测试1.2 LambdaKt", KtQueryWrapper(User::class.java).eq(User::name, "sss").eq(User::roleId, "sss2"))
     }
 
     @Test
     fun testLambdaUpdate() {
-        TableInfoHelper.initTableInfo(null, User::class.java)
-        val updateWrapperKt = KtUpdateWrapper(User()).eq(User::name, "sss").eq(User::roleId, "sss2")
-        logSqlSegment("测试 LambdaKt", updateWrapperKt)
+        logSqlSegment("测试2.1 LambdaKt", KtUpdateWrapper(User()).eq(User::name, "sss").eq(User::roleId, "sss2"))
+        logSqlSegment("测试2.2 LambdaKt", KtUpdateWrapper(User::class.java).eq(User::name, "sss").eq(User::roleId, "sss2"))
     }
 }