Parcourir la source

fix: 租户插件支持`update set subSelect`的情况

miemie il y a 2 ans
Parent
commit
35f509c9fb

+ 1 - 0
changelog-temp.md

@@ -1,3 +1,4 @@
 fix: 修复在选择springdoc文档注释时entity描述异常
 fix: 在主键的`IdType`为`AUTO`的情况下,`Table#getAllInsertSqlColumnMaybeIf("xx.")`所生成sql错误问题
 perf: `wrapper#apply`支持配置`mapping`比如`column={0,javaType=int,jdbcType=NUMERIC,typeHandler=xxx.xxx.MyTypeHandler}`
+fix: 租户插件支持`update set subSelect`的情况

+ 26 - 14
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java

@@ -20,7 +20,10 @@ import com.baomidou.mybatisplus.core.toolkit.*;
 import com.baomidou.mybatisplus.extension.plugins.handler.TenantLineHandler;
 import com.baomidou.mybatisplus.extension.toolkit.PropertyMapper;
 import lombok.*;
-import net.sf.jsqlparser.expression.*;
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.Parenthesis;
+import net.sf.jsqlparser.expression.RowConstructor;
+import net.sf.jsqlparser.expression.StringValue;
 import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
 import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
 import net.sf.jsqlparser.expression.operators.relational.ItemsList;
@@ -31,6 +34,7 @@ import net.sf.jsqlparser.statement.delete.Delete;
 import net.sf.jsqlparser.statement.insert.Insert;
 import net.sf.jsqlparser.statement.select.*;
 import net.sf.jsqlparser.statement.update.Update;
+import net.sf.jsqlparser.statement.update.UpdateSet;
 import org.apache.ibatis.executor.Executor;
 import org.apache.ibatis.executor.statement.StatementHandler;
 import org.apache.ibatis.mapping.BoundSql;
@@ -128,25 +132,25 @@ public class TenantLineInnerInterceptor extends BaseMultiTableInnerInterceptor i
             Expression tenantId = tenantLineHandler.getTenantId();
             if (itemsList instanceof MultiExpressionList) {
                 ((MultiExpressionList) itemsList).getExpressionLists().forEach(el -> el.getExpressions().add(tenantId));
-            }else {
+            } else {
                 List<Expression> expressions = ((ExpressionList) itemsList).getExpressions();
-                if(CollectionUtils.isNotEmpty(expressions)){//fix github issue 4998 jsqlparse 4.5 批量insert ItemsList不是MultiExpressionList 了,需要特殊处理
-                    int len=expressions.size();
-                    for(int i=0;i<len;i++){
-                        Expression expression=expressions.get(i);
-                        if( expression instanceof RowConstructor){
+                if (CollectionUtils.isNotEmpty(expressions)) {//fix github issue 4998 jsqlparse 4.5 批量insert ItemsList不是MultiExpressionList 了,需要特殊处理
+                    int len = expressions.size();
+                    for (int i = 0; i < len; i++) {
+                        Expression expression = expressions.get(i);
+                        if (expression instanceof RowConstructor) {
                             ((RowConstructor) expression).getExprList().getExpressions().add(tenantId);
-                        }else if (expression instanceof Parenthesis){
-                            RowConstructor rowConstructor=new RowConstructor()
-                                .withExprList(new ExpressionList(((Parenthesis) expression).getExpression(),tenantId));
-                            expressions.set(i,rowConstructor);
-                        }else {
-                            if(len-1==i){ // (?,?) 只有最后一个expre的时候才拼接tenantId
+                        } else if (expression instanceof Parenthesis) {
+                            RowConstructor rowConstructor = new RowConstructor()
+                                .withExprList(new ExpressionList(((Parenthesis) expression).getExpression(), tenantId));
+                            expressions.set(i, rowConstructor);
+                        } else {
+                            if (len - 1 == i) { // (?,?) 只有最后一个expre的时候才拼接tenantId
                                 expressions.add(tenantId);
                             }
                         }
                     }
-                }else{
+                } else {
                     expressions.add(tenantId);
                 }
             }
@@ -165,6 +169,14 @@ public class TenantLineInnerInterceptor extends BaseMultiTableInnerInterceptor i
             // 过滤退出执行
             return;
         }
+        ArrayList<UpdateSet> sets = update.getUpdateSets();
+        if (!CollectionUtils.isEmpty(sets)) {
+            sets.forEach(us -> us.getExpressions().forEach(ex -> {
+                if (ex instanceof SubSelect) {
+                    processSelectBody(((SubSelect) ex).getSelectBody(), (String) obj);
+                }
+            }));
+        }
         update.setWhere(this.andExpression(table, update.getWhere(), (String) obj));
     }
 

+ 7 - 6
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java

@@ -37,18 +37,14 @@ class TenantLineInnerInterceptorTest {
 
     @Test
     void insert() {
-        assertSql("insert into entity (id) values (?)",
-            "INSERT INTO entity (id, tenant_id) VALUES (?, 1)");
-        assertSql("insert into entity (id) values (?),(?) ",
-            "INSERT INTO entity (id, tenant_id) VALUES (?, 1), (?, 1)");
-        assertSql("insert into entity (id) values (?),(?) ",
-            "INSERT INTO entity (id, tenant_id) VALUES (?, 1),(?, 1)");
         // plain
         assertSql("insert into entity (id) values (?)",
             "INSERT INTO entity (id, tenant_id) VALUES (?, 1)");
         assertSql("insert into entity (id,name) values (?,?)",
             "INSERT INTO entity (id, name, tenant_id) VALUES (?, ?, 1)");
         // batch
+        assertSql("insert into entity (id) values (?),(?)",
+            "INSERT INTO entity (id, tenant_id) VALUES (?, 1), (?, 1)");
         assertSql("insert into entity (id,name) values (?,?),(?,?)",
             "INSERT INTO entity (id, name, tenant_id) VALUES (?, ?, 1), (?, ?, 1)");
         // 无 insert的列
@@ -84,6 +80,11 @@ class TenantLineInnerInterceptorTest {
     void update() {
         assertSql("update entity set name = ? where id = ?",
             "UPDATE entity SET name = ? WHERE id = ? AND tenant_id = 1");
+
+        // set subSelect
+        assertSql("UPDATE entity e SET e.cq = (SELECT e1.total FROM entity e1 WHERE e1.id = ?) WHERE e.id = ?",
+            "UPDATE entity e SET e.cq = (SELECT e1.total FROM entity e1 WHERE e1.id = ? AND e1.tenant_id = 1) " +
+                "WHERE e.id = ? AND e.tenant_id = 1");
     }
 
     @Test