瀏覽代碼

多租户 insert into select

miemie 4 年之前
父節點
當前提交
7c2594989c

+ 42 - 13
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java

@@ -115,20 +115,26 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
             return;
         }
         List<Column> columns = insert.getColumns();
+        if (CollectionUtils.isEmpty(columns)) {
+            // 针对不给列名的insert 不处理
+            return;
+        }
         String tenantIdColumn = tenantLineHandler.getTenantIdColumn();
         if (columns.stream().map(Column::getColumnName).anyMatch(i -> i.equals(tenantIdColumn))) {
+            // 针对已给出租户列的insert 不处理
             return;
         }
         columns.add(new Column(tenantLineHandler.getTenantIdColumn()));
-        if (insert.getSelect() != null) {
-            processPlainSelect((PlainSelect) insert.getSelect().getSelectBody(), true);
+        Select select = insert.getSelect();
+        if (select != null) {
+            this.processInsertSelect(select.getSelectBody());
         } else if (insert.getItemsList() != null) {
             // fixed github pull/295
             ItemsList itemsList = insert.getItemsList();
             if (itemsList instanceof MultiExpressionList) {
                 ((MultiExpressionList) itemsList).getExprList().forEach(el -> el.getExpressions().add(tenantLineHandler.getTenantId()));
             } else {
-                ((ExpressionList) insert.getItemsList()).getExpressions().add(tenantLineHandler.getTenantId());
+                ((ExpressionList) itemsList).getExpressions().add(tenantLineHandler.getTenantId());
             }
         } else {
             throw ExceptionUtils.mpe("Failed to process multiple-table update, please exclude the tableName or statementId");
@@ -178,29 +184,52 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
         return equalsTo;
     }
 
+
     /**
-     * 处理 PlainSelect
+     * 处理 insert into select
+     * <p>
+     * 进入这里表示需要 insert 的表启用了多租户,则 select 的表都启动了
+     *
+     * @param selectBody SelectBody
      */
-    protected void processPlainSelect(PlainSelect plainSelect) {
-        processPlainSelect(plainSelect, false);
+    protected void processInsertSelect(SelectBody selectBody) {
+        PlainSelect plainSelect = (PlainSelect) selectBody;
+        FromItem fromItem = plainSelect.getFromItem();
+        if (fromItem instanceof Table) {
+            Table fromTable = (Table) fromItem;
+            plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
+            appendSelectItem(plainSelect.getSelectItems());
+        } else if (fromItem instanceof SubSelect) {
+            SubSelect subSelect = (SubSelect) fromItem;
+            appendSelectItem(plainSelect.getSelectItems());
+            processInsertSelect(subSelect.getSelectBody());
+        }
     }
 
     /**
-     * 处理 PlainSelect
+     * 追加 SelectItem
      *
-     * @param plainSelect ignore
-     * @param addColumn   是否添加租户列,insert into select语句中需要
+     * @param selectItems SelectItem
+     */
+    protected void appendSelectItem(List<SelectItem> selectItems) {
+        if (CollectionUtils.isEmpty(selectItems)) return;
+        if (selectItems.size() == 1) {
+            SelectItem item = selectItems.get(0);
+            if (item instanceof AllColumns || item instanceof AllTableColumns) return;
+        }
+        selectItems.add(new SelectExpressionItem(new Column(tenantLineHandler.getTenantIdColumn())));
+    }
+
+    /**
+     * 处理 PlainSelect
      */
-    protected void processPlainSelect(PlainSelect plainSelect, boolean addColumn) {
+    protected void processPlainSelect(PlainSelect plainSelect) {
         FromItem fromItem = plainSelect.getFromItem();
         if (fromItem instanceof Table) {
             Table fromTable = (Table) fromItem;
             if (!tenantLineHandler.ignoreTable(fromTable.getName())) {
                 //#1186 github
                 plainSelect.setWhere(builderExpression(plainSelect.getWhere(), fromTable));
-                if (addColumn) {
-                    plainSelect.getSelectItems().add(new SelectExpressionItem(new Column(tenantLineHandler.getTenantIdColumn())));
-                }
             }
         } else {
             processFromItem(fromItem);

+ 22 - 0
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java

@@ -27,8 +27,30 @@ class TenantLineInnerInterceptorTest {
 
     @Test
     void insert() {
+        // plain
         assertSql("insert into entity (id,name) value (?,?)",
             "INSERT INTO entity (id, name, tenant_id) VALUES (?, ?, 1)");
+        // 无 insert的列
+        assertSql("insert into entity value (?,?)",
+            "INSERT INTO entity VALUES (?, ?)");
+        // 自己加了insert的列
+        assertSql("insert into entity (id,name,tenant_id) value (?,?,?)",
+            "INSERT INTO entity (id, name, tenant_id) VALUES (?, ?, ?)");
+        // insert into select
+        assertSql("insert into entity (id,name) select id,name from entity2",
+            "INSERT INTO entity (id, name, tenant_id) SELECT id, name, tenant_id FROM entity2 WHERE tenant_id = 1");
+
+        assertSql("insert into entity (id,name) select * from entity2",
+            "INSERT INTO entity (id, name, tenant_id) SELECT * FROM entity2 WHERE tenant_id = 1");
+
+        assertSql("insert into entity (id,name) select id,name from (select id,name from entity3) t",
+            "INSERT INTO entity (id, name, tenant_id) SELECT id, name, tenant_id FROM (SELECT id, name, tenant_id FROM entity3 WHERE tenant_id = 1) t");
+
+        assertSql("insert into entity (id,name) select * from (select id,name from entity3) t",
+            "INSERT INTO entity (id, name, tenant_id) SELECT * FROM (SELECT id, name, tenant_id FROM entity3 WHERE tenant_id = 1) t");
+
+        assertSql("insert into entity (id,name) select t.* from (select id,name from entity3) t",
+            "INSERT INTO entity (id, name, tenant_id) SELECT t.* FROM (SELECT id, name, tenant_id FROM entity3 WHERE tenant_id = 1) t");
     }
 
     @Test