Explorar el Código

feat: 解决多租户插件与源仓库代码冲突问题

HouKunLin hace 2 años
padre
commit
5721250370

+ 3 - 1
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/handler/MultiDataPermissionHandler.java

@@ -40,7 +40,9 @@ public interface MultiDataPermissionHandler extends DataPermissionHandler {
     }
 
     /**
-     * 获取数据权限 SQL 片段
+     * 获取数据权限 SQL 片段。
+     * <p>旧的 {@link MultiDataPermissionHandler#getSqlSegment(Expression, String)} 方法第一个参数包含所有的 where 条件信息,如果 return 了 null 会覆盖原有的 where 数据,</p>
+     * <p>新版的 {@link MultiDataPermissionHandler#getSqlSegment(Table, String)} 方法不能覆盖原有的 where 数据,如果 return 了 null 则表示不追加任何 where 条件</p>
      *
      * @param table             所执行的数据库表信息,可以通过此参数获取表名和表别名
      * @param mappedStatementId Mybatis MappedStatement Id 根据该参数可以判断具体执行方法

+ 21 - 17
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/BaseMultiTableInnerInterceptor.java

@@ -17,7 +17,10 @@ package com.baomidou.mybatisplus.extension.plugins.inner;
 
 import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
-import lombok.*;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.NoArgsConstructor;
+import lombok.ToString;
 import net.sf.jsqlparser.expression.*;
 import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
 import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
@@ -67,9 +70,6 @@ public abstract class BaseMultiTableInnerInterceptor extends JsqlParserSupport i
     protected Expression andExpression(Table table, Expression where, final String whereSegment) {
         //获得where条件表达式
         final Expression expression = buildTableExpression(table, whereSegment);
-        if (expression == null) {
-            return where;
-        }
         if (null != where) {
             if (where instanceof OrExpression) {
                 return new AndExpression(expression, new Parenthesis(where));
@@ -137,15 +137,17 @@ public abstract class BaseMultiTableInnerInterceptor extends JsqlParserSupport i
      * 处理where条件内的子查询
      * <p>
      * 支持如下:
-     * 1. in
-     * 2. =
-     * 3. >
-     * 4. <
-     * 5. >=
-     * 6. <=
-     * 7. <>
-     * 8. EXISTS
-     * 9. NOT EXISTS
+     * <ol>
+     *     <li>in</li>
+     *     <li>=</li>
+     *     <li>&gt;</li>
+     *     <li>&lt;</li>
+     *     <li>&gt;=</li>
+     *     <li>&lt;=</li>
+     *     <li>&lt;&gt;</li>
+     *     <li>EXISTS</li>
+     *     <li>NOT EXISTS</li>
+     * </ol>
      * <p>
      * 前提条件:
      * 1. 子查询必须放在小括号中
@@ -193,10 +195,11 @@ public abstract class BaseMultiTableInnerInterceptor extends JsqlParserSupport i
     protected void processSelectItem(SelectItem selectItem, final String whereSegment) {
         if (selectItem instanceof SelectExpressionItem) {
             SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
-            if (selectExpressionItem.getExpression() instanceof SubSelect) {
-                processSelectBody(((SubSelect) selectExpressionItem.getExpression()).getSelectBody(), whereSegment);
-            } else if (selectExpressionItem.getExpression() instanceof Function) {
-                processFunction((Function) selectExpressionItem.getExpression(), whereSegment);
+            final Expression expression = selectExpressionItem.getExpression();
+            if (expression instanceof SubSelect) {
+                processSelectBody(((SubSelect) expression).getSelectBody(), whereSegment);
+            } else if (expression instanceof Function) {
+                processFunction((Function) expression, whereSegment);
             }
         }
     }
@@ -382,6 +385,7 @@ public abstract class BaseMultiTableInnerInterceptor extends JsqlParserSupport i
             .filter(Objects::nonNull)
             .collect(Collectors.toList());
 
+        // 没有表需要处理直接返回
         if (CollectionUtils.isEmpty(expressions)) {
             return currentExpression;
         }

+ 40 - 388
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java

@@ -17,14 +17,15 @@ package com.baomidou.mybatisplus.extension.plugins.inner;
 
 import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
 import com.baomidou.mybatisplus.core.toolkit.*;
-import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
 import com.baomidou.mybatisplus.extension.plugins.handler.TenantLineHandler;
 import com.baomidou.mybatisplus.extension.toolkit.PropertyMapper;
 import lombok.*;
-import net.sf.jsqlparser.expression.*;
-import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
-import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
-import net.sf.jsqlparser.expression.operators.relational.*;
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.StringValue;
+import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
+import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
+import net.sf.jsqlparser.expression.operators.relational.ItemsList;
+import net.sf.jsqlparser.expression.operators.relational.MultiExpressionList;
 import net.sf.jsqlparser.schema.Column;
 import net.sf.jsqlparser.schema.Table;
 import net.sf.jsqlparser.statement.delete.Delete;
@@ -41,8 +42,8 @@ import org.apache.ibatis.session.RowBounds;
 
 import java.sql.Connection;
 import java.sql.SQLException;
-import java.util.*;
-import java.util.stream.Collectors;
+import java.util.List;
+import java.util.Properties;
 
 /**
  * @author hubin
@@ -54,13 +55,15 @@ import java.util.stream.Collectors;
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuppressWarnings({"rawtypes"})
-public class TenantLineInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {
+public class TenantLineInnerInterceptor extends BaseMultiTableInnerInterceptor implements InnerInterceptor {
 
     private TenantLineHandler tenantLineHandler;
 
     @Override
     public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
-        if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) return;
+        if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) {
+            return;
+        }
         PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
         mpBs.sql(parserSingle(mpBs.sql(), null));
     }
@@ -81,28 +84,11 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
 
     @Override
     protected void processSelect(Select select, int index, String sql, Object obj) {
-        processSelectBody(select.getSelectBody());
+        final String whereSegment = (String) obj;
+        processSelectBody(select.getSelectBody(), whereSegment);
         List<WithItem> withItemsList = select.getWithItemsList();
         if (!CollectionUtils.isEmpty(withItemsList)) {
-            withItemsList.forEach(this::processSelectBody);
-        }
-    }
-
-    protected void processSelectBody(SelectBody selectBody) {
-        if (selectBody == null) {
-            return;
-        }
-        if (selectBody instanceof PlainSelect) {
-            processPlainSelect((PlainSelect) selectBody);
-        } else if (selectBody instanceof WithItem) {
-            WithItem withItem = (WithItem) selectBody;
-            processSelectBody(withItem.getSubSelect().getSelectBody());
-        } else {
-            SetOperationList operationList = (SetOperationList) selectBody;
-            List<SelectBody> selectBodyList = operationList.getSelects();
-            if (CollectionUtils.isNotEmpty(selectBodyList)) {
-                selectBodyList.forEach(this::processSelectBody);
-            }
+            withItemsList.forEach(withItem -> processSelectBody(withItem, whereSegment));
         }
     }
 
@@ -135,7 +121,7 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
 
         Select select = insert.getSelect();
         if (select != null) {
-            processInsertSelect(select.getSelectBody());
+            this.processInsertSelect(select.getSelectBody(), (String) obj);
         } else if (insert.getItemsList() != null) {
             // fixed github pull/295
             ItemsList itemsList = insert.getItemsList();
@@ -160,7 +146,7 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
             // 过滤退出执行
             return;
         }
-        update.setWhere(this.andExpression(table, update.getWhere()));
+        update.setWhere(this.andExpression(table, update.getWhere(), (String) obj));
     }
 
     /**
@@ -172,28 +158,9 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
             // 过滤退出执行
             return;
         }
-        delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
+        delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere(), (String) obj));
     }
 
-    /**
-     * delete update 语句 where 处理
-     */
-    protected BinaryExpression andExpression(Table table, Expression where) {
-        //获得where条件表达式
-        EqualsTo equalsTo = new EqualsTo();
-        equalsTo.setLeftExpression(this.getAliasColumn(table));
-        equalsTo.setRightExpression(tenantLineHandler.getTenantId());
-        if (null != where) {
-            if (where instanceof OrExpression) {
-                return new AndExpression(equalsTo, new Parenthesis(where));
-            } else {
-                return new AndExpression(equalsTo, where);
-            }
-        }
-        return equalsTo;
-    }
-
-
     /**
      * 处理 insert into select
      * <p>
@@ -201,17 +168,17 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
      *
      * @param selectBody SelectBody
      */
-    protected void processInsertSelect(SelectBody selectBody) {
+    protected void processInsertSelect(SelectBody selectBody, final String whereSegment) {
         PlainSelect plainSelect = (PlainSelect) selectBody;
         FromItem fromItem = plainSelect.getFromItem();
         if (fromItem instanceof Table) {
             // fixed gitee pulls/141 duplicate update
-            processPlainSelect(plainSelect);
+            processPlainSelect(plainSelect, whereSegment);
             appendSelectItem(plainSelect.getSelectItems());
         } else if (fromItem instanceof SubSelect) {
             SubSelect subSelect = (SubSelect) fromItem;
             appendSelectItem(plainSelect.getSelectItems());
-            processInsertSelect(subSelect.getSelectBody());
+            processInsertSelect(subSelect.getSelectBody(), whereSegment);
         }
     }
 
@@ -233,336 +200,6 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
         selectItems.add(new SelectExpressionItem(new Column(tenantLineHandler.getTenantIdColumn())));
     }
 
-    /**
-     * 处理 PlainSelect
-     */
-    protected void processPlainSelect(PlainSelect plainSelect) {
-        //#3087 github
-        List<SelectItem> selectItems = plainSelect.getSelectItems();
-        if (CollectionUtils.isNotEmpty(selectItems)) {
-            selectItems.forEach(this::processSelectItem);
-        }
-
-        // 处理 where 中的子查询
-        Expression where = plainSelect.getWhere();
-        processWhereSubSelect(where);
-
-        // 处理 fromItem
-        FromItem fromItem = plainSelect.getFromItem();
-        List<Table> list = processFromItem(fromItem);
-        List<Table> mainTables = new ArrayList<>(list);
-
-        // 处理 join
-        List<Join> joins = plainSelect.getJoins();
-        if (CollectionUtils.isNotEmpty(joins)) {
-            mainTables = processJoins(mainTables, joins);
-        }
-
-        // 当有 mainTable 时,进行 where 条件追加
-        if (CollectionUtils.isNotEmpty(mainTables)) {
-            plainSelect.setWhere(builderExpression(where, mainTables));
-        }
-    }
-
-    private List<Table> processFromItem(FromItem fromItem) {
-        // 处理括号括起来的表达式
-        while (fromItem instanceof ParenthesisFromItem) {
-            fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
-        }
-
-        List<Table> mainTables = new ArrayList<>();
-        // 无 join 时的处理逻辑
-        if (fromItem instanceof Table) {
-            Table fromTable = (Table) fromItem;
-            mainTables.add(fromTable);
-        } else if (fromItem instanceof SubJoin) {
-            // SubJoin 类型则还需要添加上 where 条件
-            List<Table> tables = processSubJoin((SubJoin) fromItem);
-            mainTables.addAll(tables);
-        } else {
-            // 处理下 fromItem
-            processOtherFromItem(fromItem);
-        }
-        return mainTables;
-    }
-
-    /**
-     * 处理where条件内的子查询
-     * <p>
-     * 支持如下:
-     * 1. in
-     * 2. =
-     * 3. >
-     * 4. <
-     * 5. >=
-     * 6. <=
-     * 7. <>
-     * 8. EXISTS
-     * 9. NOT EXISTS
-     * <p>
-     * 前提条件:
-     * 1. 子查询必须放在小括号中
-     * 2. 子查询一般放在比较操作符的右边
-     *
-     * @param where where 条件
-     */
-    protected void processWhereSubSelect(Expression where) {
-        if (where == null) {
-            return;
-        }
-        if (where instanceof FromItem) {
-            processOtherFromItem((FromItem) where);
-            return;
-        }
-        if (where.toString().indexOf("SELECT") > 0) {
-            // 有子查询
-            if (where instanceof BinaryExpression) {
-                // 比较符号 , and , or , 等等
-                BinaryExpression expression = (BinaryExpression) where;
-                processWhereSubSelect(expression.getLeftExpression());
-                processWhereSubSelect(expression.getRightExpression());
-            } else if (where instanceof InExpression) {
-                // in
-                InExpression expression = (InExpression) where;
-                Expression inExpression = expression.getRightExpression();
-                if (inExpression instanceof SubSelect) {
-                    processSelectBody(((SubSelect) inExpression).getSelectBody());
-                }
-            } else if (where instanceof ExistsExpression) {
-                // exists
-                ExistsExpression expression = (ExistsExpression) where;
-                processWhereSubSelect(expression.getRightExpression());
-            } else if (where instanceof NotExpression) {
-                // not exists
-                NotExpression expression = (NotExpression) where;
-                processWhereSubSelect(expression.getExpression());
-            } else if (where instanceof Parenthesis) {
-                Parenthesis expression = (Parenthesis) where;
-                processWhereSubSelect(expression.getExpression());
-            }
-        }
-    }
-
-    protected void processSelectItem(SelectItem selectItem) {
-        if (selectItem instanceof SelectExpressionItem) {
-            SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
-            if (selectExpressionItem.getExpression() instanceof SubSelect) {
-                processSelectBody(((SubSelect) selectExpressionItem.getExpression()).getSelectBody());
-            } else if (selectExpressionItem.getExpression() instanceof Function) {
-                processFunction((Function) selectExpressionItem.getExpression());
-            }
-        }
-    }
-
-    /**
-     * 处理函数
-     * <p>支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)<p>
-     * <p> fixed gitee pulls/141</p>
-     *
-     * @param function
-     */
-    protected void processFunction(Function function) {
-        ExpressionList parameters = function.getParameters();
-        if (parameters != null) {
-            parameters.getExpressions().forEach(expression -> {
-                if (expression instanceof SubSelect) {
-                    processSelectBody(((SubSelect) expression).getSelectBody());
-                } else if (expression instanceof Function) {
-                    processFunction((Function) expression);
-                }
-            });
-        }
-    }
-
-    /**
-     * 处理子查询等
-     */
-    protected void processOtherFromItem(FromItem fromItem) {
-        // 去除括号
-        while (fromItem instanceof ParenthesisFromItem) {
-            fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
-        }
-
-        if (fromItem instanceof SubSelect) {
-            SubSelect subSelect = (SubSelect) fromItem;
-            if (subSelect.getSelectBody() != null) {
-                processSelectBody(subSelect.getSelectBody());
-            }
-        } else if (fromItem instanceof ValuesList) {
-            logger.debug("Perform a subQuery, if you do not give us feedback");
-        } else if (fromItem instanceof LateralSubSelect) {
-            LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
-            if (lateralSubSelect.getSubSelect() != null) {
-                SubSelect subSelect = lateralSubSelect.getSubSelect();
-                if (subSelect.getSelectBody() != null) {
-                    processSelectBody(subSelect.getSelectBody());
-                }
-            }
-        }
-    }
-
-    /**
-     * 处理 sub join
-     *
-     * @param subJoin subJoin
-     * @return Table subJoin 中的主表
-     */
-    private List<Table> processSubJoin(SubJoin subJoin) {
-        List<Table> mainTables = new ArrayList<>();
-        if (subJoin.getJoinList() != null) {
-            List<Table> list = processFromItem(subJoin.getLeft());
-            mainTables.addAll(list);
-            mainTables = processJoins(mainTables, subJoin.getJoinList());
-        }
-        return mainTables;
-    }
-
-    /**
-     * 处理 joins
-     *
-     * @param mainTables 可以为 null
-     * @param joins      join 集合
-     * @return List<Table> 右连接查询的 Table 列表
-     */
-    private List<Table> processJoins(List<Table> mainTables, List<Join> joins) {
-        // join 表达式中最终的主表
-        Table mainTable = null;
-        // 当前 join 的左表
-        Table leftTable = null;
-
-        if (mainTables == null) {
-            mainTables = new ArrayList<>();
-        } else if (mainTables.size() == 1) {
-            mainTable = mainTables.get(0);
-            leftTable = mainTable;
-        }
-
-        //对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
-        Deque<List<Table>> onTableDeque = new LinkedList<>();
-        for (Join join : joins) {
-            // 处理 on 表达式
-            FromItem joinItem = join.getRightItem();
-
-            // 获取当前 join 的表,subJoint 可以看作是一张表
-            List<Table> joinTables = null;
-            if (joinItem instanceof Table) {
-                joinTables = new ArrayList<>();
-                joinTables.add((Table) joinItem);
-            } else if (joinItem instanceof SubJoin) {
-                joinTables = processSubJoin((SubJoin) joinItem);
-            }
-
-            if (joinTables != null) {
-
-                // 如果是隐式内连接
-                if (join.isSimple()) {
-                    mainTables.addAll(joinTables);
-                    continue;
-                }
-
-                // 当前表是否忽略
-                Table joinTable = joinTables.get(0);
-
-                List<Table> onTables = null;
-                // 如果不要忽略,且是右连接,则记录下当前表
-                if (join.isRight()) {
-                    mainTable = joinTable;
-                    if (leftTable != null) {
-                        onTables = Collections.singletonList(leftTable);
-                    }
-                } else if (join.isLeft()) {
-                    onTables = Collections.singletonList(joinTable);
-                } else if (join.isInner()) {
-                    if (mainTable == null) {
-                        onTables = Collections.singletonList(joinTable);
-                    } else {
-                        onTables = Arrays.asList(mainTable, joinTable);
-                    }
-                    mainTable = null;
-                }
-
-                mainTables = new ArrayList<>();
-                if (mainTable != null) {
-                    mainTables.add(mainTable);
-                }
-
-                // 获取 join 尾缀的 on 表达式列表
-                Collection<Expression> originOnExpressions = join.getOnExpressions();
-                // 正常 join on 表达式只有一个,立刻处理
-                if (originOnExpressions.size() == 1 && onTables != null) {
-                    List<Expression> onExpressions = new LinkedList<>();
-                    onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
-                    join.setOnExpressions(onExpressions);
-                    leftTable = joinTable;
-                    continue;
-                }
-                // 表名压栈,忽略的表压入 null,以便后续不处理
-                onTableDeque.push(onTables);
-                // 尾缀多个 on 表达式的时候统一处理
-                if (originOnExpressions.size() > 1) {
-                    Collection<Expression> onExpressions = new LinkedList<>();
-                    for (Expression originOnExpression : originOnExpressions) {
-                        List<Table> currentTableList = onTableDeque.poll();
-                        if (CollectionUtils.isEmpty(currentTableList)) {
-                            onExpressions.add(originOnExpression);
-                        } else {
-                            onExpressions.add(builderExpression(originOnExpression, currentTableList));
-                        }
-                    }
-                    join.setOnExpressions(onExpressions);
-                }
-                leftTable = joinTable;
-            } else {
-                processOtherFromItem(joinItem);
-                leftTable = null;
-            }
-        }
-
-        return mainTables;
-    }
-
-    /**
-     * 处理条件
-     */
-    protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
-        // 没有表需要处理直接返回
-        if (CollectionUtils.isEmpty(tables)) {
-            return currentExpression;
-        }
-        // 构造每张表的条件
-        List<Table> tempTables = tables.stream()
-            .filter(x -> !tenantLineHandler.ignoreTable(x.getName()))
-            .collect(Collectors.toList());
-
-        // 没有表需要处理直接返回
-        if (CollectionUtils.isEmpty(tempTables)) {
-            return currentExpression;
-        }
-
-        Expression tenantId = tenantLineHandler.getTenantId();
-        List<EqualsTo> equalsTos = tempTables.stream()
-            .map(item -> new EqualsTo(getAliasColumn(item), tenantId))
-            .collect(Collectors.toList());
-
-        // 注入的表达式
-        Expression injectExpression = equalsTos.get(0);
-        // 如果有多表,则用 and 连接
-        if (equalsTos.size() > 1) {
-            for (int i = 1; i < equalsTos.size(); i++) {
-                injectExpression = new AndExpression(injectExpression, equalsTos.get(i));
-            }
-        }
-
-        if (currentExpression == null) {
-            return injectExpression;
-        }
-        if (currentExpression instanceof OrExpression) {
-            return new AndExpression(new Parenthesis(currentExpression), injectExpression);
-        } else {
-            return new AndExpression(currentExpression, injectExpression);
-        }
-    }
-
     /**
      * 租户字段别名设置
      * <p>tenantId 或 tableAlias.tenantId</p>
@@ -575,9 +212,11 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
         // 禁止 `为了兼容隐式内连接,没有别名时条件就需要加上表名`
         // 该起别名就要起别名
         if (table.getAlias() != null) {
-            column.append(table.getAlias().getName()).append(StringPool.DOT);
+            column.append(table.getAlias().getName());
+        } else {
+            column.append(table.getName());
         }
-        column.append(tenantLineHandler.getTenantIdColumn());
+        column.append(StringPool.DOT).append(tenantLineHandler.getTenantIdColumn());
         return new Column(column.toString());
     }
 
@@ -586,6 +225,19 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
         PropertyMapper.newInstance(properties).whenNotBlank("tenantLineHandler",
             ClassUtils::newInstance, this::setTenantLineHandler);
     }
-}
-
 
+    /**
+     * 构建租户条件表达式
+     *
+     * @param table        表对象
+     * @param whereSegment 所属Mapper对象全路径(在原租户拦截器功能中,这个参数并不需要参与相关判断)
+     * @return 租户条件表达式
+     */
+    @Override
+    public Expression buildTableExpression(final Table table, final String whereSegment) {
+        if (tenantLineHandler.ignoreTable(table.getName())) {
+            return null;
+        }
+        return new EqualsTo(getAliasColumn(table), tenantLineHandler.getTenantId());
+    }
+}