瀏覽代碼

!219 数据权限插件支持SQL多表(JOIN连表查询)查询场景
Merge pull request !219 from 侯坤林/3.0

青苗 2 年之前
父節點
當前提交
6c31e16396

+ 1 - 1
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/handler/DataPermissionHandler.java

@@ -30,7 +30,7 @@ public interface DataPermissionHandler {
      *
      * @param where             待执行 SQL Where 条件表达式
      * @param mappedStatementId Mybatis MappedStatement Id 根据该参数可以判断具体执行方法
-     * @return JSqlParser 条件表达式
+     * @return JSqlParser 条件表达式,返回的条件表达式会覆盖原有的条件表达式
      */
     Expression getSqlSegment(Expression where, String mappedStatementId);
 }

+ 52 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/handler/MultiDataPermissionHandler.java

@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2011-2022, baomidou (jobob@qq.com).
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.baomidou.mybatisplus.extension.plugins.handler;
+
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.schema.Table;
+
+/**
+ * 支持多表的数据权限处理器
+ *
+ * @author houkunlin
+ * @since 3.5.2 +
+ */
+public interface MultiDataPermissionHandler extends DataPermissionHandler {
+    /**
+     * 为兼容旧版数据权限处理器,继承了 {@link DataPermissionHandler} 但是新的多表数据权限处理又不会调用此方法,因此标记过时
+     *
+     * @param where             待执行 SQL Where 条件表达式
+     * @param mappedStatementId Mybatis MappedStatement Id 根据该参数可以判断具体执行方法
+     * @return JSqlParser 条件表达式
+     * @deprecated 新的多表数据权限处理不会调用此方法,因此标记过时
+     */
+    @Deprecated
+    @Override
+    default Expression getSqlSegment(Expression where, String mappedStatementId) {
+        return where;
+    }
+
+    /**
+     * 获取数据权限 SQL 片段。
+     * <p>旧的 {@link MultiDataPermissionHandler#getSqlSegment(Expression, String)} 方法第一个参数包含所有的 where 条件信息,如果 return 了 null 会覆盖原有的 where 数据,</p>
+     * <p>新版的 {@link MultiDataPermissionHandler#getSqlSegment(Table, String)} 方法不能覆盖原有的 where 数据,如果 return 了 null 则表示不追加任何 where 条件</p>
+     *
+     * @param table             所执行的数据库表信息,可以通过此参数获取表名和表别名
+     * @param mappedStatementId Mybatis MappedStatement Id 根据该参数可以判断具体执行方法
+     * @return JSqlParser 条件表达式,返回的条件表达式会拼接在原有的表达式后面(不会覆盖原有的表达式)
+     */
+    Expression getSqlSegment(final Table table, final String mappedStatementId);
+}

+ 420 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/BaseMultiTableInnerInterceptor.java

@@ -0,0 +1,420 @@
+/*
+ * Copyright (c) 2011-2022, baomidou (jobob@qq.com).
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.baomidou.mybatisplus.extension.plugins.inner;
+
+import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
+import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
+import lombok.Data;
+import lombok.EqualsAndHashCode;
+import lombok.NoArgsConstructor;
+import lombok.ToString;
+import net.sf.jsqlparser.expression.*;
+import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
+import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
+import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
+import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
+import net.sf.jsqlparser.expression.operators.relational.InExpression;
+import net.sf.jsqlparser.schema.Table;
+import net.sf.jsqlparser.statement.select.*;
+
+import java.util.*;
+import java.util.stream.Collectors;
+
+/**
+ * 多表条件处理基对象,从原有的 {@link TenantLineInnerInterceptor} 拦截器中提取出来
+ *
+ * @author houkunlin
+ * @since 3.5.2
+ */
+@Data
+@NoArgsConstructor
+@ToString(callSuper = true)
+@EqualsAndHashCode(callSuper = true)
+@SuppressWarnings({"rawtypes"})
+public abstract class BaseMultiTableInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {
+
+    protected void processSelectBody(SelectBody selectBody, final String whereSegment) {
+        if (selectBody == null) {
+            return;
+        }
+        if (selectBody instanceof PlainSelect) {
+            processPlainSelect((PlainSelect) selectBody, whereSegment);
+        } else if (selectBody instanceof WithItem) {
+            WithItem withItem = (WithItem) selectBody;
+            processSelectBody(withItem.getSubSelect().getSelectBody(), whereSegment);
+        } else {
+            SetOperationList operationList = (SetOperationList) selectBody;
+            List<SelectBody> selectBodyList = operationList.getSelects();
+            if (CollectionUtils.isNotEmpty(selectBodyList)) {
+                selectBodyList.forEach(body -> processSelectBody(body, whereSegment));
+            }
+        }
+    }
+
+    /**
+     * delete update 语句 where 处理
+     */
+    protected Expression andExpression(Table table, Expression where, final String whereSegment) {
+        //获得where条件表达式
+        final Expression expression = buildTableExpression(table, whereSegment);
+        if (null != where) {
+            if (where instanceof OrExpression) {
+                return new AndExpression(expression, new Parenthesis(where));
+            } else {
+                return new AndExpression(expression, where);
+            }
+        }
+        return expression;
+    }
+
+    /**
+     * 处理 PlainSelect
+     */
+    protected void processPlainSelect(final PlainSelect plainSelect, final String whereSegment) {
+        //#3087 github
+        List<SelectItem> selectItems = plainSelect.getSelectItems();
+        if (CollectionUtils.isNotEmpty(selectItems)) {
+            selectItems.forEach(selectItem -> processSelectItem(selectItem, whereSegment));
+        }
+
+        // 处理 where 中的子查询
+        Expression where = plainSelect.getWhere();
+        processWhereSubSelect(where, whereSegment);
+
+        // 处理 fromItem
+        FromItem fromItem = plainSelect.getFromItem();
+        List<Table> list = processFromItem(fromItem, whereSegment);
+        List<Table> mainTables = new ArrayList<>(list);
+
+        // 处理 join
+        List<Join> joins = plainSelect.getJoins();
+        if (CollectionUtils.isNotEmpty(joins)) {
+            mainTables = processJoins(mainTables, joins, whereSegment);
+        }
+
+        // 当有 mainTable 时,进行 where 条件追加
+        if (CollectionUtils.isNotEmpty(mainTables)) {
+            plainSelect.setWhere(builderExpression(where, mainTables, whereSegment));
+        }
+    }
+
+    private List<Table> processFromItem(FromItem fromItem, final String whereSegment) {
+        // 处理括号括起来的表达式
+        while (fromItem instanceof ParenthesisFromItem) {
+            fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
+        }
+
+        List<Table> mainTables = new ArrayList<>();
+        // 无 join 时的处理逻辑
+        if (fromItem instanceof Table) {
+            Table fromTable = (Table) fromItem;
+            mainTables.add(fromTable);
+        } else if (fromItem instanceof SubJoin) {
+            // SubJoin 类型则还需要添加上 where 条件
+            List<Table> tables = processSubJoin((SubJoin) fromItem, whereSegment);
+            mainTables.addAll(tables);
+        } else {
+            // 处理下 fromItem
+            processOtherFromItem(fromItem, whereSegment);
+        }
+        return mainTables;
+    }
+
+    /**
+     * 处理where条件内的子查询
+     * <p>
+     * 支持如下:
+     * <ol>
+     *     <li>in</li>
+     *     <li>=</li>
+     *     <li>&gt;</li>
+     *     <li>&lt;</li>
+     *     <li>&gt;=</li>
+     *     <li>&lt;=</li>
+     *     <li>&lt;&gt;</li>
+     *     <li>EXISTS</li>
+     *     <li>NOT EXISTS</li>
+     * </ol>
+     * <p>
+     * 前提条件:
+     * 1. 子查询必须放在小括号中
+     * 2. 子查询一般放在比较操作符的右边
+     *
+     * @param where where 条件
+     */
+    protected void processWhereSubSelect(Expression where, final String whereSegment) {
+        if (where == null) {
+            return;
+        }
+        if (where instanceof FromItem) {
+            processOtherFromItem((FromItem) where, whereSegment);
+            return;
+        }
+        if (where.toString().indexOf("SELECT") > 0) {
+            // 有子查询
+            if (where instanceof BinaryExpression) {
+                // 比较符号 , and , or , 等等
+                BinaryExpression expression = (BinaryExpression) where;
+                processWhereSubSelect(expression.getLeftExpression(), whereSegment);
+                processWhereSubSelect(expression.getRightExpression(), whereSegment);
+            } else if (where instanceof InExpression) {
+                // in
+                InExpression expression = (InExpression) where;
+                Expression inExpression = expression.getRightExpression();
+                if (inExpression instanceof SubSelect) {
+                    processSelectBody(((SubSelect) inExpression).getSelectBody(), whereSegment);
+                }
+            } else if (where instanceof ExistsExpression) {
+                // exists
+                ExistsExpression expression = (ExistsExpression) where;
+                processWhereSubSelect(expression.getRightExpression(), whereSegment);
+            } else if (where instanceof NotExpression) {
+                // not exists
+                NotExpression expression = (NotExpression) where;
+                processWhereSubSelect(expression.getExpression(), whereSegment);
+            } else if (where instanceof Parenthesis) {
+                Parenthesis expression = (Parenthesis) where;
+                processWhereSubSelect(expression.getExpression(), whereSegment);
+            }
+        }
+    }
+
+    protected void processSelectItem(SelectItem selectItem, final String whereSegment) {
+        if (selectItem instanceof SelectExpressionItem) {
+            SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
+            final Expression expression = selectExpressionItem.getExpression();
+            if (expression instanceof SubSelect) {
+                processSelectBody(((SubSelect) expression).getSelectBody(), whereSegment);
+            } else if (expression instanceof Function) {
+                processFunction((Function) expression, whereSegment);
+            }
+        }
+    }
+
+    /**
+     * 处理函数
+     * <p>支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)<p>
+     * <p> fixed gitee pulls/141</p>
+     *
+     * @param function
+     */
+    protected void processFunction(Function function, final String whereSegment) {
+        ExpressionList parameters = function.getParameters();
+        if (parameters != null) {
+            parameters.getExpressions().forEach(expression -> {
+                if (expression instanceof SubSelect) {
+                    processSelectBody(((SubSelect) expression).getSelectBody(), whereSegment);
+                } else if (expression instanceof Function) {
+                    processFunction((Function) expression, whereSegment);
+                }
+            });
+        }
+    }
+
+    /**
+     * 处理子查询等
+     */
+    protected void processOtherFromItem(FromItem fromItem, final String whereSegment) {
+        // 去除括号
+        while (fromItem instanceof ParenthesisFromItem) {
+            fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
+        }
+
+        if (fromItem instanceof SubSelect) {
+            SubSelect subSelect = (SubSelect) fromItem;
+            if (subSelect.getSelectBody() != null) {
+                processSelectBody(subSelect.getSelectBody(), whereSegment);
+            }
+        } else if (fromItem instanceof ValuesList) {
+            logger.debug("Perform a subQuery, if you do not give us feedback");
+        } else if (fromItem instanceof LateralSubSelect) {
+            LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
+            if (lateralSubSelect.getSubSelect() != null) {
+                SubSelect subSelect = lateralSubSelect.getSubSelect();
+                if (subSelect.getSelectBody() != null) {
+                    processSelectBody(subSelect.getSelectBody(), whereSegment);
+                }
+            }
+        }
+    }
+
+    /**
+     * 处理 sub join
+     *
+     * @param subJoin subJoin
+     * @return Table subJoin 中的主表
+     */
+    private List<Table> processSubJoin(SubJoin subJoin, final String whereSegment) {
+        List<Table> mainTables = new ArrayList<>();
+        if (subJoin.getJoinList() != null) {
+            List<Table> list = processFromItem(subJoin.getLeft(), whereSegment);
+            mainTables.addAll(list);
+            mainTables = processJoins(mainTables, subJoin.getJoinList(), whereSegment);
+        }
+        return mainTables;
+    }
+
+    /**
+     * 处理 joins
+     *
+     * @param mainTables 可以为 null
+     * @param joins      join 集合
+     * @return List<Table> 右连接查询的 Table 列表
+     */
+    private List<Table> processJoins(List<Table> mainTables, List<Join> joins, final String whereSegment) {
+        // join 表达式中最终的主表
+        Table mainTable = null;
+        // 当前 join 的左表
+        Table leftTable = null;
+
+        if (mainTables == null) {
+            mainTables = new ArrayList<>();
+        } else if (mainTables.size() == 1) {
+            mainTable = mainTables.get(0);
+            leftTable = mainTable;
+        }
+
+        //对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
+        Deque<List<Table>> onTableDeque = new LinkedList<>();
+        for (Join join : joins) {
+            // 处理 on 表达式
+            FromItem joinItem = join.getRightItem();
+
+            // 获取当前 join 的表,subJoint 可以看作是一张表
+            List<Table> joinTables = null;
+            if (joinItem instanceof Table) {
+                joinTables = new ArrayList<>();
+                joinTables.add((Table) joinItem);
+            } else if (joinItem instanceof SubJoin) {
+                joinTables = processSubJoin((SubJoin) joinItem, whereSegment);
+            }
+
+            if (joinTables != null) {
+
+                // 如果是隐式内连接
+                if (join.isSimple()) {
+                    mainTables.addAll(joinTables);
+                    continue;
+                }
+
+                // 当前表是否忽略
+                Table joinTable = joinTables.get(0);
+
+                List<Table> onTables = null;
+                // 如果不要忽略,且是右连接,则记录下当前表
+                if (join.isRight()) {
+                    mainTable = joinTable;
+                    if (leftTable != null) {
+                        onTables = Collections.singletonList(leftTable);
+                    }
+                } else if (join.isLeft()) {
+                    onTables = Collections.singletonList(joinTable);
+                } else if (join.isInner()) {
+                    if (mainTable == null) {
+                        onTables = Collections.singletonList(joinTable);
+                    } else {
+                        onTables = Arrays.asList(mainTable, joinTable);
+                    }
+                    mainTable = null;
+                }
+
+                mainTables = new ArrayList<>();
+                if (mainTable != null) {
+                    mainTables.add(mainTable);
+                }
+
+                // 获取 join 尾缀的 on 表达式列表
+                Collection<Expression> originOnExpressions = join.getOnExpressions();
+                // 正常 join on 表达式只有一个,立刻处理
+                if (originOnExpressions.size() == 1 && onTables != null) {
+                    List<Expression> onExpressions = new LinkedList<>();
+                    onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables, whereSegment));
+                    join.setOnExpressions(onExpressions);
+                    leftTable = joinTable;
+                    continue;
+                }
+                // 表名压栈,忽略的表压入 null,以便后续不处理
+                onTableDeque.push(onTables);
+                // 尾缀多个 on 表达式的时候统一处理
+                if (originOnExpressions.size() > 1) {
+                    Collection<Expression> onExpressions = new LinkedList<>();
+                    for (Expression originOnExpression : originOnExpressions) {
+                        List<Table> currentTableList = onTableDeque.poll();
+                        if (CollectionUtils.isEmpty(currentTableList)) {
+                            onExpressions.add(originOnExpression);
+                        } else {
+                            onExpressions.add(builderExpression(originOnExpression, currentTableList, whereSegment));
+                        }
+                    }
+                    join.setOnExpressions(onExpressions);
+                }
+                leftTable = joinTable;
+            } else {
+                processOtherFromItem(joinItem, whereSegment);
+                leftTable = null;
+            }
+        }
+
+        return mainTables;
+    }
+
+    /**
+     * 处理条件
+     */
+    protected Expression builderExpression(Expression currentExpression, List<Table> tables, final String whereSegment) {
+        // 没有表需要处理直接返回
+        if (CollectionUtils.isEmpty(tables)) {
+            return currentExpression;
+        }
+        // 构造每张表的条件
+        List<Expression> expressions = tables.stream()
+            .map(item -> buildTableExpression(item, whereSegment))
+            .filter(Objects::nonNull)
+            .collect(Collectors.toList());
+
+        // 没有表需要处理直接返回
+        if (CollectionUtils.isEmpty(expressions)) {
+            return currentExpression;
+        }
+
+        // 注入的表达式
+        Expression injectExpression = expressions.get(0);
+        // 如果有多表,则用 and 连接
+        if (expressions.size() > 1) {
+            for (int i = 1; i < expressions.size(); i++) {
+                injectExpression = new AndExpression(injectExpression, expressions.get(i));
+            }
+        }
+
+        if (currentExpression == null) {
+            return injectExpression;
+        }
+        if (currentExpression instanceof OrExpression) {
+            return new AndExpression(new Parenthesis(currentExpression), injectExpression);
+        } else {
+            return new AndExpression(currentExpression, injectExpression);
+        }
+    }
+
+    /**
+     * 构建数据库表的查询条件
+     *
+     * @param table        表对象
+     * @param whereSegment 所属Mapper对象全路径
+     * @return 需要拼接的新条件(不会覆盖原有的where条件,只会在原有条件上再加条件),为 null 则不加入新的条件
+     */
+    public abstract Expression buildTableExpression(final Table table, final String whereSegment);
+}

+ 70 - 5
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/DataPermissionInterceptor.java

@@ -17,20 +17,26 @@ package com.baomidou.mybatisplus.extension.plugins.inner;
 
 import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
 import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
-import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
 import com.baomidou.mybatisplus.extension.plugins.handler.DataPermissionHandler;
+import com.baomidou.mybatisplus.extension.plugins.handler.MultiDataPermissionHandler;
 import lombok.*;
 import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.schema.Table;
+import net.sf.jsqlparser.statement.delete.Delete;
 import net.sf.jsqlparser.statement.select.PlainSelect;
 import net.sf.jsqlparser.statement.select.Select;
 import net.sf.jsqlparser.statement.select.SelectBody;
 import net.sf.jsqlparser.statement.select.SetOperationList;
+import net.sf.jsqlparser.statement.update.Update;
 import org.apache.ibatis.executor.Executor;
+import org.apache.ibatis.executor.statement.StatementHandler;
 import org.apache.ibatis.mapping.BoundSql;
 import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.mapping.SqlCommandType;
 import org.apache.ibatis.session.ResultHandler;
 import org.apache.ibatis.session.RowBounds;
 
+import java.sql.Connection;
 import java.sql.SQLException;
 import java.util.List;
 
@@ -38,7 +44,7 @@ import java.util.List;
  * 数据权限处理器
  *
  * @author hubin
- * @since 3.4.1 +
+ * @since 3.5.2
  */
 @Data
 @NoArgsConstructor
@@ -46,16 +52,32 @@ import java.util.List;
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuppressWarnings({"rawtypes"})
-public class DataPermissionInterceptor extends JsqlParserSupport implements InnerInterceptor {
+public class DataPermissionInterceptor extends BaseMultiTableInnerInterceptor implements InnerInterceptor {
     private DataPermissionHandler dataPermissionHandler;
 
     @Override
     public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
-        if (InterceptorIgnoreHelper.willIgnoreDataPermission(ms.getId())) return;
+        if (InterceptorIgnoreHelper.willIgnoreDataPermission(ms.getId())) {
+            return;
+        }
         PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
         mpBs.sql(parserSingle(mpBs.sql(), ms.getId()));
     }
 
+    @Override
+    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
+        PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
+        MappedStatement ms = mpSh.mappedStatement();
+        SqlCommandType sct = ms.getSqlCommandType();
+        if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
+            if (InterceptorIgnoreHelper.willIgnoreDataPermission(ms.getId())) {
+                return;
+            }
+            PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
+            mpBs.sql(parserMulti(mpBs.sql(), ms.getId()));
+        }
+    }
+
     @Override
     protected void processSelect(Select select, int index, String sql, Object obj) {
         SelectBody selectBody = select.getSelectBody();
@@ -75,9 +97,52 @@ public class DataPermissionInterceptor extends JsqlParserSupport implements Inne
      * @param whereSegment 查询条件片段
      */
     protected void setWhere(PlainSelect plainSelect, String whereSegment) {
-        Expression sqlSegment = dataPermissionHandler.getSqlSegment(plainSelect.getWhere(), whereSegment);
+        if (dataPermissionHandler instanceof MultiDataPermissionHandler) {
+            processPlainSelect(plainSelect, whereSegment);
+            return;
+        }
+        // 兼容旧版的数据权限处理
+        final Expression sqlSegment = dataPermissionHandler.getSqlSegment(plainSelect.getWhere(), whereSegment);
         if (null != sqlSegment) {
             plainSelect.setWhere(sqlSegment);
         }
     }
+
+    /**
+     * update 语句处理
+     */
+    @Override
+    protected void processUpdate(Update update, int index, String sql, Object obj) {
+        final Expression sqlSegment = getUpdateOrDeleteExpression(update.getTable(), update.getWhere(), (String) obj);
+        if (null != sqlSegment) {
+            update.setWhere(sqlSegment);
+        }
+    }
+
+    /**
+     * delete 语句处理
+     */
+    @Override
+    protected void processDelete(Delete delete, int index, String sql, Object obj) {
+        final Expression sqlSegment = getUpdateOrDeleteExpression(delete.getTable(), delete.getWhere(), (String) obj);
+        if (null != sqlSegment) {
+            delete.setWhere(sqlSegment);
+        }
+    }
+
+    protected Expression getUpdateOrDeleteExpression(final Table table, final Expression where, final String whereSegment) {
+        if (dataPermissionHandler instanceof MultiDataPermissionHandler) {
+            return andExpression(table, where, whereSegment);
+        } else {
+            // 兼容旧版的数据权限处理
+            return dataPermissionHandler.getSqlSegment(where, whereSegment);
+        }
+    }
+
+    @Override
+    public Expression buildTableExpression(final Table table, final String whereSegment) {
+        // 只有新版数据权限处理器才会执行到这里
+        final MultiDataPermissionHandler handler = (MultiDataPermissionHandler) dataPermissionHandler;
+        return handler.getSqlSegment(table, whereSegment);
+    }
 }

+ 40 - 388
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptor.java

@@ -17,14 +17,15 @@ package com.baomidou.mybatisplus.extension.plugins.inner;
 
 import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
 import com.baomidou.mybatisplus.core.toolkit.*;
-import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
 import com.baomidou.mybatisplus.extension.plugins.handler.TenantLineHandler;
 import com.baomidou.mybatisplus.extension.toolkit.PropertyMapper;
 import lombok.*;
-import net.sf.jsqlparser.expression.*;
-import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
-import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
-import net.sf.jsqlparser.expression.operators.relational.*;
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.StringValue;
+import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
+import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
+import net.sf.jsqlparser.expression.operators.relational.ItemsList;
+import net.sf.jsqlparser.expression.operators.relational.MultiExpressionList;
 import net.sf.jsqlparser.schema.Column;
 import net.sf.jsqlparser.schema.Table;
 import net.sf.jsqlparser.statement.delete.Delete;
@@ -41,8 +42,8 @@ import org.apache.ibatis.session.RowBounds;
 
 import java.sql.Connection;
 import java.sql.SQLException;
-import java.util.*;
-import java.util.stream.Collectors;
+import java.util.List;
+import java.util.Properties;
 
 /**
  * @author hubin
@@ -54,13 +55,15 @@ import java.util.stream.Collectors;
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuppressWarnings({"rawtypes"})
-public class TenantLineInnerInterceptor extends JsqlParserSupport implements InnerInterceptor {
+public class TenantLineInnerInterceptor extends BaseMultiTableInnerInterceptor implements InnerInterceptor {
 
     private TenantLineHandler tenantLineHandler;
 
     @Override
     public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
-        if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) return;
+        if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) {
+            return;
+        }
         PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
         mpBs.sql(parserSingle(mpBs.sql(), null));
     }
@@ -81,28 +84,11 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
 
     @Override
     protected void processSelect(Select select, int index, String sql, Object obj) {
-        processSelectBody(select.getSelectBody());
+        final String whereSegment = (String) obj;
+        processSelectBody(select.getSelectBody(), whereSegment);
         List<WithItem> withItemsList = select.getWithItemsList();
         if (!CollectionUtils.isEmpty(withItemsList)) {
-            withItemsList.forEach(this::processSelectBody);
-        }
-    }
-
-    protected void processSelectBody(SelectBody selectBody) {
-        if (selectBody == null) {
-            return;
-        }
-        if (selectBody instanceof PlainSelect) {
-            processPlainSelect((PlainSelect) selectBody);
-        } else if (selectBody instanceof WithItem) {
-            WithItem withItem = (WithItem) selectBody;
-            processSelectBody(withItem.getSubSelect().getSelectBody());
-        } else {
-            SetOperationList operationList = (SetOperationList) selectBody;
-            List<SelectBody> selectBodyList = operationList.getSelects();
-            if (CollectionUtils.isNotEmpty(selectBodyList)) {
-                selectBodyList.forEach(this::processSelectBody);
-            }
+            withItemsList.forEach(withItem -> processSelectBody(withItem, whereSegment));
         }
     }
 
@@ -135,7 +121,7 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
 
         Select select = insert.getSelect();
         if (select != null) {
-            processInsertSelect(select.getSelectBody());
+            this.processInsertSelect(select.getSelectBody(), (String) obj);
         } else if (insert.getItemsList() != null) {
             // fixed github pull/295
             ItemsList itemsList = insert.getItemsList();
@@ -160,7 +146,7 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
             // 过滤退出执行
             return;
         }
-        update.setWhere(this.andExpression(table, update.getWhere()));
+        update.setWhere(this.andExpression(table, update.getWhere(), (String) obj));
     }
 
     /**
@@ -172,28 +158,9 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
             // 过滤退出执行
             return;
         }
-        delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere()));
+        delete.setWhere(this.andExpression(delete.getTable(), delete.getWhere(), (String) obj));
     }
 
-    /**
-     * delete update 语句 where 处理
-     */
-    protected BinaryExpression andExpression(Table table, Expression where) {
-        //获得where条件表达式
-        EqualsTo equalsTo = new EqualsTo();
-        equalsTo.setLeftExpression(this.getAliasColumn(table));
-        equalsTo.setRightExpression(tenantLineHandler.getTenantId());
-        if (null != where) {
-            if (where instanceof OrExpression) {
-                return new AndExpression(equalsTo, new Parenthesis(where));
-            } else {
-                return new AndExpression(equalsTo, where);
-            }
-        }
-        return equalsTo;
-    }
-
-
     /**
      * 处理 insert into select
      * <p>
@@ -201,17 +168,17 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
      *
      * @param selectBody SelectBody
      */
-    protected void processInsertSelect(SelectBody selectBody) {
+    protected void processInsertSelect(SelectBody selectBody, final String whereSegment) {
         PlainSelect plainSelect = (PlainSelect) selectBody;
         FromItem fromItem = plainSelect.getFromItem();
         if (fromItem instanceof Table) {
             // fixed gitee pulls/141 duplicate update
-            processPlainSelect(plainSelect);
+            processPlainSelect(plainSelect, whereSegment);
             appendSelectItem(plainSelect.getSelectItems());
         } else if (fromItem instanceof SubSelect) {
             SubSelect subSelect = (SubSelect) fromItem;
             appendSelectItem(plainSelect.getSelectItems());
-            processInsertSelect(subSelect.getSelectBody());
+            processInsertSelect(subSelect.getSelectBody(), whereSegment);
         }
     }
 
@@ -233,336 +200,6 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
         selectItems.add(new SelectExpressionItem(new Column(tenantLineHandler.getTenantIdColumn())));
     }
 
-    /**
-     * 处理 PlainSelect
-     */
-    protected void processPlainSelect(PlainSelect plainSelect) {
-        //#3087 github
-        List<SelectItem> selectItems = plainSelect.getSelectItems();
-        if (CollectionUtils.isNotEmpty(selectItems)) {
-            selectItems.forEach(this::processSelectItem);
-        }
-
-        // 处理 where 中的子查询
-        Expression where = plainSelect.getWhere();
-        processWhereSubSelect(where);
-
-        // 处理 fromItem
-        FromItem fromItem = plainSelect.getFromItem();
-        List<Table> list = processFromItem(fromItem);
-        List<Table> mainTables = new ArrayList<>(list);
-
-        // 处理 join
-        List<Join> joins = plainSelect.getJoins();
-        if (CollectionUtils.isNotEmpty(joins)) {
-            mainTables = processJoins(mainTables, joins);
-        }
-
-        // 当有 mainTable 时,进行 where 条件追加
-        if (CollectionUtils.isNotEmpty(mainTables)) {
-            plainSelect.setWhere(builderExpression(where, mainTables));
-        }
-    }
-
-    private List<Table> processFromItem(FromItem fromItem) {
-        // 处理括号括起来的表达式
-        while (fromItem instanceof ParenthesisFromItem) {
-            fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
-        }
-
-        List<Table> mainTables = new ArrayList<>();
-        // 无 join 时的处理逻辑
-        if (fromItem instanceof Table) {
-            Table fromTable = (Table) fromItem;
-            mainTables.add(fromTable);
-        } else if (fromItem instanceof SubJoin) {
-            // SubJoin 类型则还需要添加上 where 条件
-            List<Table> tables = processSubJoin((SubJoin) fromItem);
-            mainTables.addAll(tables);
-        } else {
-            // 处理下 fromItem
-            processOtherFromItem(fromItem);
-        }
-        return mainTables;
-    }
-
-    /**
-     * 处理where条件内的子查询
-     * <p>
-     * 支持如下:
-     * 1. in
-     * 2. =
-     * 3. >
-     * 4. <
-     * 5. >=
-     * 6. <=
-     * 7. <>
-     * 8. EXISTS
-     * 9. NOT EXISTS
-     * <p>
-     * 前提条件:
-     * 1. 子查询必须放在小括号中
-     * 2. 子查询一般放在比较操作符的右边
-     *
-     * @param where where 条件
-     */
-    protected void processWhereSubSelect(Expression where) {
-        if (where == null) {
-            return;
-        }
-        if (where instanceof FromItem) {
-            processOtherFromItem((FromItem) where);
-            return;
-        }
-        if (where.toString().indexOf("SELECT") > 0) {
-            // 有子查询
-            if (where instanceof BinaryExpression) {
-                // 比较符号 , and , or , 等等
-                BinaryExpression expression = (BinaryExpression) where;
-                processWhereSubSelect(expression.getLeftExpression());
-                processWhereSubSelect(expression.getRightExpression());
-            } else if (where instanceof InExpression) {
-                // in
-                InExpression expression = (InExpression) where;
-                Expression inExpression = expression.getRightExpression();
-                if (inExpression instanceof SubSelect) {
-                    processSelectBody(((SubSelect) inExpression).getSelectBody());
-                }
-            } else if (where instanceof ExistsExpression) {
-                // exists
-                ExistsExpression expression = (ExistsExpression) where;
-                processWhereSubSelect(expression.getRightExpression());
-            } else if (where instanceof NotExpression) {
-                // not exists
-                NotExpression expression = (NotExpression) where;
-                processWhereSubSelect(expression.getExpression());
-            } else if (where instanceof Parenthesis) {
-                Parenthesis expression = (Parenthesis) where;
-                processWhereSubSelect(expression.getExpression());
-            }
-        }
-    }
-
-    protected void processSelectItem(SelectItem selectItem) {
-        if (selectItem instanceof SelectExpressionItem) {
-            SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
-            if (selectExpressionItem.getExpression() instanceof SubSelect) {
-                processSelectBody(((SubSelect) selectExpressionItem.getExpression()).getSelectBody());
-            } else if (selectExpressionItem.getExpression() instanceof Function) {
-                processFunction((Function) selectExpressionItem.getExpression());
-            }
-        }
-    }
-
-    /**
-     * 处理函数
-     * <p>支持: 1. select fun(args..) 2. select fun1(fun2(args..),args..)<p>
-     * <p> fixed gitee pulls/141</p>
-     *
-     * @param function
-     */
-    protected void processFunction(Function function) {
-        ExpressionList parameters = function.getParameters();
-        if (parameters != null) {
-            parameters.getExpressions().forEach(expression -> {
-                if (expression instanceof SubSelect) {
-                    processSelectBody(((SubSelect) expression).getSelectBody());
-                } else if (expression instanceof Function) {
-                    processFunction((Function) expression);
-                }
-            });
-        }
-    }
-
-    /**
-     * 处理子查询等
-     */
-    protected void processOtherFromItem(FromItem fromItem) {
-        // 去除括号
-        while (fromItem instanceof ParenthesisFromItem) {
-            fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
-        }
-
-        if (fromItem instanceof SubSelect) {
-            SubSelect subSelect = (SubSelect) fromItem;
-            if (subSelect.getSelectBody() != null) {
-                processSelectBody(subSelect.getSelectBody());
-            }
-        } else if (fromItem instanceof ValuesList) {
-            logger.debug("Perform a subQuery, if you do not give us feedback");
-        } else if (fromItem instanceof LateralSubSelect) {
-            LateralSubSelect lateralSubSelect = (LateralSubSelect) fromItem;
-            if (lateralSubSelect.getSubSelect() != null) {
-                SubSelect subSelect = lateralSubSelect.getSubSelect();
-                if (subSelect.getSelectBody() != null) {
-                    processSelectBody(subSelect.getSelectBody());
-                }
-            }
-        }
-    }
-
-    /**
-     * 处理 sub join
-     *
-     * @param subJoin subJoin
-     * @return Table subJoin 中的主表
-     */
-    private List<Table> processSubJoin(SubJoin subJoin) {
-        List<Table> mainTables = new ArrayList<>();
-        if (subJoin.getJoinList() != null) {
-            List<Table> list = processFromItem(subJoin.getLeft());
-            mainTables.addAll(list);
-            mainTables = processJoins(mainTables, subJoin.getJoinList());
-        }
-        return mainTables;
-    }
-
-    /**
-     * 处理 joins
-     *
-     * @param mainTables 可以为 null
-     * @param joins      join 集合
-     * @return List<Table> 右连接查询的 Table 列表
-     */
-    private List<Table> processJoins(List<Table> mainTables, List<Join> joins) {
-        // join 表达式中最终的主表
-        Table mainTable = null;
-        // 当前 join 的左表
-        Table leftTable = null;
-
-        if (mainTables == null) {
-            mainTables = new ArrayList<>();
-        } else if (mainTables.size() == 1) {
-            mainTable = mainTables.get(0);
-            leftTable = mainTable;
-        }
-
-        //对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
-        Deque<List<Table>> onTableDeque = new LinkedList<>();
-        for (Join join : joins) {
-            // 处理 on 表达式
-            FromItem joinItem = join.getRightItem();
-
-            // 获取当前 join 的表,subJoint 可以看作是一张表
-            List<Table> joinTables = null;
-            if (joinItem instanceof Table) {
-                joinTables = new ArrayList<>();
-                joinTables.add((Table) joinItem);
-            } else if (joinItem instanceof SubJoin) {
-                joinTables = processSubJoin((SubJoin) joinItem);
-            }
-
-            if (joinTables != null) {
-
-                // 如果是隐式内连接
-                if (join.isSimple()) {
-                    mainTables.addAll(joinTables);
-                    continue;
-                }
-
-                // 当前表是否忽略
-                Table joinTable = joinTables.get(0);
-
-                List<Table> onTables = null;
-                // 如果不要忽略,且是右连接,则记录下当前表
-                if (join.isRight()) {
-                    mainTable = joinTable;
-                    if (leftTable != null) {
-                        onTables = Collections.singletonList(leftTable);
-                    }
-                } else if (join.isLeft()) {
-                    onTables = Collections.singletonList(joinTable);
-                } else if (join.isInner()) {
-                    if (mainTable == null) {
-                        onTables = Collections.singletonList(joinTable);
-                    } else {
-                        onTables = Arrays.asList(mainTable, joinTable);
-                    }
-                    mainTable = null;
-                }
-
-                mainTables = new ArrayList<>();
-                if (mainTable != null) {
-                    mainTables.add(mainTable);
-                }
-
-                // 获取 join 尾缀的 on 表达式列表
-                Collection<Expression> originOnExpressions = join.getOnExpressions();
-                // 正常 join on 表达式只有一个,立刻处理
-                if (originOnExpressions.size() == 1 && onTables != null) {
-                    List<Expression> onExpressions = new LinkedList<>();
-                    onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
-                    join.setOnExpressions(onExpressions);
-                    leftTable = joinTable;
-                    continue;
-                }
-                // 表名压栈,忽略的表压入 null,以便后续不处理
-                onTableDeque.push(onTables);
-                // 尾缀多个 on 表达式的时候统一处理
-                if (originOnExpressions.size() > 1) {
-                    Collection<Expression> onExpressions = new LinkedList<>();
-                    for (Expression originOnExpression : originOnExpressions) {
-                        List<Table> currentTableList = onTableDeque.poll();
-                        if (CollectionUtils.isEmpty(currentTableList)) {
-                            onExpressions.add(originOnExpression);
-                        } else {
-                            onExpressions.add(builderExpression(originOnExpression, currentTableList));
-                        }
-                    }
-                    join.setOnExpressions(onExpressions);
-                }
-                leftTable = joinTable;
-            } else {
-                processOtherFromItem(joinItem);
-                leftTable = null;
-            }
-        }
-
-        return mainTables;
-    }
-
-    /**
-     * 处理条件
-     */
-    protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
-        // 没有表需要处理直接返回
-        if (CollectionUtils.isEmpty(tables)) {
-            return currentExpression;
-        }
-        // 构造每张表的条件
-        List<Table> tempTables = tables.stream()
-            .filter(x -> !tenantLineHandler.ignoreTable(x.getName()))
-            .collect(Collectors.toList());
-
-        // 没有表需要处理直接返回
-        if (CollectionUtils.isEmpty(tempTables)) {
-            return currentExpression;
-        }
-
-        Expression tenantId = tenantLineHandler.getTenantId();
-        List<EqualsTo> equalsTos = tempTables.stream()
-            .map(item -> new EqualsTo(getAliasColumn(item), tenantId))
-            .collect(Collectors.toList());
-
-        // 注入的表达式
-        Expression injectExpression = equalsTos.get(0);
-        // 如果有多表,则用 and 连接
-        if (equalsTos.size() > 1) {
-            for (int i = 1; i < equalsTos.size(); i++) {
-                injectExpression = new AndExpression(injectExpression, equalsTos.get(i));
-            }
-        }
-
-        if (currentExpression == null) {
-            return injectExpression;
-        }
-        if (currentExpression instanceof OrExpression) {
-            return new AndExpression(new Parenthesis(currentExpression), injectExpression);
-        } else {
-            return new AndExpression(currentExpression, injectExpression);
-        }
-    }
-
     /**
      * 租户字段别名设置
      * <p>tenantId 或 tableAlias.tenantId</p>
@@ -575,9 +212,11 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
         // 禁止 `为了兼容隐式内连接,没有别名时条件就需要加上表名`
         // 该起别名就要起别名
         if (table.getAlias() != null) {
-            column.append(table.getAlias().getName()).append(StringPool.DOT);
+            column.append(table.getAlias().getName());
+        } else {
+            column.append(table.getName());
         }
-        column.append(tenantLineHandler.getTenantIdColumn());
+        column.append(StringPool.DOT).append(tenantLineHandler.getTenantIdColumn());
         return new Column(column.toString());
     }
 
@@ -586,6 +225,19 @@ public class TenantLineInnerInterceptor extends JsqlParserSupport implements Inn
         PropertyMapper.newInstance(properties).whenNotBlank("tenantLineHandler",
             ClassUtils::newInstance, this::setTenantLineHandler);
     }
-}
-
 
+    /**
+     * 构建租户条件表达式
+     *
+     * @param table        表对象
+     * @param whereSegment 所属Mapper对象全路径(在原租户拦截器功能中,这个参数并不需要参与相关判断)
+     * @return 租户条件表达式
+     */
+    @Override
+    public Expression buildTableExpression(final Table table, final String whereSegment) {
+        if (tenantLineHandler.ignoreTable(table.getName())) {
+            return null;
+        }
+        return new EqualsTo(getAliasColumn(table), tenantLineHandler.getTenantId());
+    }
+}

+ 120 - 0
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/MultiDataPermissionInterceptorTest.java

@@ -0,0 +1,120 @@
+package com.baomidou.mybatisplus.extension.plugins.inner;
+
+import com.baomidou.mybatisplus.extension.plugins.handler.MultiDataPermissionHandler;
+import com.google.common.collect.HashBasedTable;
+import net.sf.jsqlparser.JSQLParserException;
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.parser.CCJSqlParserUtil;
+import net.sf.jsqlparser.schema.Table;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * SQL多表场景的数据权限拦截器测试
+ *
+ * @author houkunlin
+ * @since 3.5.2 +
+ */
+public class MultiDataPermissionInterceptorTest {
+    private static final Logger logger = LoggerFactory.getLogger(MultiDataPermissionInterceptorTest.class);
+    /**
+     * 这里可以理解为数据库配置的数据权限规则 SQL
+     */
+    private static final com.google.common.collect.Table<String, String, String> sqlSegmentMap;
+    private static final DataPermissionInterceptor interceptor;
+    private static String TEST_1 = "com.baomidou.userMapper.selectByUsername";
+    private static String TEST_2 = "com.baomidou.userMapper.selectById";
+    private static String TEST_3 = "com.baomidou.roleMapper.selectByCompanyId";
+    private static String TEST_4 = "com.baomidou.roleMapper.selectById";
+    private static String TEST_5 = "com.baomidou.roleMapper.selectByRoleId";
+    private static String TEST_6 = "com.baomidou.roleMapper.selectUserInfo";
+    private static String TEST_7 = "com.baomidou.roleMapper.summarySum";
+
+    static {
+        sqlSegmentMap = HashBasedTable.create();
+        sqlSegmentMap.put(TEST_1, "sys_user", "username='123' or userId IN (1,2,3)");
+        sqlSegmentMap.put(TEST_2, "sys_user", "u.state=1 and u.amount > 1000");
+        sqlSegmentMap.put(TEST_3, "sys_role", "companyId in (1,2,3)");
+        sqlSegmentMap.put(TEST_4, "sys_role", "username like 'abc%'");
+        sqlSegmentMap.put(TEST_5, "sys_role", "id=1 and role_id in (select id from sys_role)");
+        sqlSegmentMap.put(TEST_6, "sys_user", "u.state=1 and u.amount > 1000");
+        sqlSegmentMap.put(TEST_6, "sys_user_role", "r.role_id=3 AND r.role_id IN (7,9,11)");
+        sqlSegmentMap.put(TEST_7, "`fund`", "a.id = 1 AND a.year = 2022 AND a.create_user_id = 1111");
+        sqlSegmentMap.put(TEST_7, "`fund_month`", "b.fund_id = 2 AND b.month <= '2022-05'");
+        interceptor = new DataPermissionInterceptor(new MultiDataPermissionHandler() {
+
+            @Override
+            public Expression getSqlSegment(final Table table, final String mappedStatementId) {
+                try {
+                    String sqlSegment = sqlSegmentMap.get(mappedStatementId, table.getName());
+                    if (sqlSegment == null) {
+                        logger.info("{} {} AS {} : NOT FOUND", mappedStatementId, table.getName(), table.getAlias());
+                        return null;
+                    }
+                    Expression sqlSegmentExpression = CCJSqlParserUtil.parseCondExpression(sqlSegment);
+                    logger.info("{} {} AS {} : {}", mappedStatementId, table.getName(), table.getAlias(), sqlSegmentExpression.toString());
+                    return sqlSegmentExpression;
+                } catch (JSQLParserException e) {
+                    e.printStackTrace();
+                }
+                return null;
+            }
+        });
+    }
+
+    @Test
+    void test1() {
+        assertSql(TEST_1, "select * from sys_user",
+            "SELECT * FROM sys_user WHERE username = '123' OR userId IN (1, 2, 3)");
+    }
+
+    @Test
+    void test2() {
+        assertSql(TEST_2, "select u.username from sys_user u join sys_user_role r on u.id=r.user_id where r.role_id=3",
+            "SELECT u.username FROM sys_user u JOIN sys_user_role r ON u.id = r.user_id WHERE r.role_id = 3 AND u.state = 1 AND u.amount > 1000");
+    }
+
+    @Test
+    void test3() {
+        assertSql(TEST_3, "select * from sys_role where company_id=6",
+            "SELECT * FROM sys_role WHERE company_id = 6 AND companyId IN (1, 2, 3)");
+    }
+
+    @Test
+    void test3unionAll() {
+        assertSql(TEST_3, "select * from sys_role where company_id=6 union all select * from sys_role where company_id=7",
+            "SELECT * FROM sys_role WHERE company_id = 6 AND companyId IN (1, 2, 3) UNION ALL SELECT * FROM sys_role WHERE company_id = 7 AND companyId IN (1, 2, 3)");
+    }
+
+    @Test
+    void test4() {
+        assertSql(TEST_4, "select * from sys_role where id=3",
+            "SELECT * FROM sys_role WHERE id = 3 AND username LIKE 'abc%'");
+    }
+
+    @Test
+    void test5() {
+        assertSql(TEST_5, "select * from sys_role where id=3",
+            "SELECT * FROM sys_role WHERE id = 3 AND id = 1 AND role_id IN (SELECT id FROM sys_role)");
+    }
+
+    @Test
+    void test6() {
+        // 显式指定 JOIN 类型时 JOIN 右侧表才能进行拼接条件
+        assertSql(TEST_6, "select u.username from sys_user u LEFT join sys_user_role r on u.id=r.user_id",
+            "SELECT u.username FROM sys_user u LEFT JOIN sys_user_role r ON u.id = r.user_id AND r.role_id = 3 AND r.role_id IN (7, 9, 11) WHERE u.state = 1 AND u.amount > 1000");
+    }
+
+    @Test
+    void test7() {
+        assertSql(TEST_7, "SELECT c.doc AS title, sum(c.total_paid_amount) AS total_paid_amount, sum(c.balance_amount) AS balance_amount FROM (SELECT `a`.`id`, `a`.`doc`, `b`.`month`, `b`.`total_paid_amount`, `b`.`balance_amount`, row_number() OVER (PARTITION BY `a`.`id` ORDER BY `b`.`month` DESC) AS `row_index` FROM `fund` `a` LEFT JOIN `fund_month` `b` ON `a`.`id` = `b`.`fund_id` AND `b`.`submit` = TRUE) c WHERE c.row_index = 1 GROUP BY title LIMIT 20",
+            "SELECT c.doc AS title, sum(c.total_paid_amount) AS total_paid_amount, sum(c.balance_amount) AS balance_amount FROM (SELECT `a`.`id`, `a`.`doc`, `b`.`month`, `b`.`total_paid_amount`, `b`.`balance_amount`, row_number() OVER (PARTITION BY `a`.`id` ORDER BY `b`.`month` DESC) AS `row_index` FROM `fund` `a` LEFT JOIN `fund_month` `b` ON `a`.`id` = `b`.`fund_id` AND `b`.`submit` = TRUE AND b.fund_id = 2 AND b.month <= '2022-05' WHERE a.id = 1 AND a.year = 2022 AND a.create_user_id = 1111) c WHERE c.row_index = 1 GROUP BY title LIMIT 20");
+    }
+
+    void assertSql(String mappedStatementId, String sql, String targetSql) {
+        assertThat(interceptor.parserSingle(sql, mappedStatementId)).isEqualTo(targetSql);
+    }
+}

+ 14 - 0
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/TenantLineInnerInterceptorTest.java

@@ -428,6 +428,20 @@ class TenantLineInnerInterceptorTest {
             "SELECT dict.dict_code, item.item_text AS \"text\", item.item_value AS \"value\" FROM sys_dict_item item INNER JOIN sys_dict dict ON dict.id = item.dict_id AND item.tenant_id = 1 WHERE dict.dict_code IN (1, 2, 3) AND item.item_value IN (1, 2, 3)");
     }
 
+    @Test
+    void test6() {
+        // 不显式指定 JOIN 类型时 JOIN 右侧表无法识进行拼接条件(在未改动之前就已经有这个问题)
+        assertSql("select u.username from sys_user u join sys_user_role r on u.id=r.user_id",
+            "SELECT u.username FROM sys_user u JOIN sys_user_role r ON u.id = r.user_id WHERE u.tenant_id = 1");
+    }
+
+    @Test
+    void test7() {
+        // 显式指定 JOIN 类型时 JOIN 右侧表才能进行拼接条件
+        assertSql("select u.username from sys_user u LEFT join sys_user_role r on u.id=r.user_id",
+            "SELECT u.username FROM sys_user u LEFT JOIN sys_user_role r ON u.id = r.user_id AND r.tenant_id = 1 WHERE u.tenant_id = 1");
+    }
+
     void assertSql(String sql, String targetSql) {
         assertThat(interceptor.parserSingle(sql, null)).isEqualTo(targetSql);
     }