Jelajahi Sumber

Merge pull request #5821 from houkunlin-fork/feat/data-permission

修复数据权限多表支持在某些查询数据场景下失效问题
qmdx 1 tahun lalu
induk
melakukan
e998982f57

+ 20 - 15
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/DataPermissionInterceptor.java

@@ -16,6 +16,7 @@
 package com.baomidou.mybatisplus.extension.plugins.inner;
 package com.baomidou.mybatisplus.extension.plugins.inner;
 
 
 import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
 import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
+import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
 import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
 import com.baomidou.mybatisplus.extension.plugins.handler.DataPermissionHandler;
 import com.baomidou.mybatisplus.extension.plugins.handler.DataPermissionHandler;
 import com.baomidou.mybatisplus.extension.plugins.handler.MultiDataPermissionHandler;
 import com.baomidou.mybatisplus.extension.plugins.handler.MultiDataPermissionHandler;
@@ -23,10 +24,7 @@ import lombok.*;
 import net.sf.jsqlparser.expression.Expression;
 import net.sf.jsqlparser.expression.Expression;
 import net.sf.jsqlparser.schema.Table;
 import net.sf.jsqlparser.schema.Table;
 import net.sf.jsqlparser.statement.delete.Delete;
 import net.sf.jsqlparser.statement.delete.Delete;
-import net.sf.jsqlparser.statement.select.PlainSelect;
-import net.sf.jsqlparser.statement.select.Select;
-import net.sf.jsqlparser.statement.select.SelectBody;
-import net.sf.jsqlparser.statement.select.SetOperationList;
+import net.sf.jsqlparser.statement.select.*;
 import net.sf.jsqlparser.statement.update.Update;
 import net.sf.jsqlparser.statement.update.Update;
 import org.apache.ibatis.executor.Executor;
 import org.apache.ibatis.executor.Executor;
 import org.apache.ibatis.executor.statement.StatementHandler;
 import org.apache.ibatis.executor.statement.StatementHandler;
@@ -80,13 +78,24 @@ public class DataPermissionInterceptor extends BaseMultiTableInnerInterceptor im
 
 
     @Override
     @Override
     protected void processSelect(Select select, int index, String sql, Object obj) {
     protected void processSelect(Select select, int index, String sql, Object obj) {
-        SelectBody selectBody = select.getSelectBody();
-        if (selectBody instanceof PlainSelect) {
-            this.setWhere((PlainSelect) selectBody, (String) obj);
-        } else if (selectBody instanceof SetOperationList) {
-            SetOperationList setOperationList = (SetOperationList) selectBody;
-            List<SelectBody> selectBodyList = setOperationList.getSelects();
-            selectBodyList.forEach(s -> this.setWhere((PlainSelect) s, (String) obj));
+        if (dataPermissionHandler instanceof MultiDataPermissionHandler) {
+            // 参照 com.baomidou.mybatisplus.extension.plugins.inner.TenantLineInnerInterceptor.processSelect 做的修改
+            final String whereSegment = (String) obj;
+            processSelectBody(select.getSelectBody(), whereSegment);
+            List<WithItem> withItemsList = select.getWithItemsList();
+            if (!CollectionUtils.isEmpty(withItemsList)) {
+                withItemsList.forEach(withItem -> processSelectBody(withItem, whereSegment));
+            }
+        } else {
+            // 兼容原来的旧版 DataPermissionHandler 场景
+            SelectBody selectBody = select.getSelectBody();
+            if (selectBody instanceof PlainSelect) {
+                this.setWhere((PlainSelect) selectBody, (String) obj);
+            } else if (selectBody instanceof SetOperationList) {
+                SetOperationList setOperationList = (SetOperationList) selectBody;
+                List<SelectBody> selectBodyList = setOperationList.getSelects();
+                selectBodyList.forEach(s -> this.setWhere((PlainSelect) s, (String) obj));
+            }
         }
         }
     }
     }
 
 
@@ -97,10 +106,6 @@ public class DataPermissionInterceptor extends BaseMultiTableInnerInterceptor im
      * @param whereSegment 查询条件片段
      * @param whereSegment 查询条件片段
      */
      */
     protected void setWhere(PlainSelect plainSelect, String whereSegment) {
     protected void setWhere(PlainSelect plainSelect, String whereSegment) {
-        if (dataPermissionHandler instanceof MultiDataPermissionHandler) {
-            processPlainSelect(plainSelect, whereSegment);
-            return;
-        }
         // 兼容旧版的数据权限处理
         // 兼容旧版的数据权限处理
         final Expression sqlSegment = dataPermissionHandler.getSqlSegment(plainSelect.getWhere(), whereSegment);
         final Expression sqlSegment = dataPermissionHandler.getSqlSegment(plainSelect.getWhere(), whereSegment);
         if (null != sqlSegment) {
         if (null != sqlSegment) {

+ 34 - 0
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/inner/MultiDataPermissionInterceptorTest.java

@@ -1,5 +1,6 @@
 package com.baomidou.mybatisplus.extension.plugins.inner;
 package com.baomidou.mybatisplus.extension.plugins.inner;
 
 
+import com.baomidou.mybatisplus.core.toolkit.StringPool;
 import com.baomidou.mybatisplus.extension.plugins.handler.MultiDataPermissionHandler;
 import com.baomidou.mybatisplus.extension.plugins.handler.MultiDataPermissionHandler;
 import com.google.common.collect.HashBasedTable;
 import com.google.common.collect.HashBasedTable;
 import net.sf.jsqlparser.JSQLParserException;
 import net.sf.jsqlparser.JSQLParserException;
@@ -32,6 +33,11 @@ public class MultiDataPermissionInterceptorTest {
     private static String TEST_5 = "com.baomidou.roleMapper.selectByRoleId";
     private static String TEST_5 = "com.baomidou.roleMapper.selectByRoleId";
     private static String TEST_6 = "com.baomidou.roleMapper.selectUserInfo";
     private static String TEST_6 = "com.baomidou.roleMapper.selectUserInfo";
     private static String TEST_7 = "com.baomidou.roleMapper.summarySum";
     private static String TEST_7 = "com.baomidou.roleMapper.summarySum";
+    private static String TEST_8_1 = "com.baomidou.CustomMapper.selectByOnlyMyData";
+    private static String TEST_8_2 = "com.baomidou.CustomMapper.selectByOnlyOrgData";
+    private static String TEST_8_3 = "com.baomidou.CustomMapper.selectByOnlyDeptData";
+    private static String TEST_8_4 = "com.baomidou.CustomMapper.selectByMyDataOrDeptData";
+    private static String TEST_8_5 = "com.baomidou.CustomMapper.selectByMyData";
 
 
     static {
     static {
         sqlSegmentMap = HashBasedTable.create();
         sqlSegmentMap = HashBasedTable.create();
@@ -44,6 +50,12 @@ public class MultiDataPermissionInterceptorTest {
         sqlSegmentMap.put(TEST_6, "sys_user_role", "r.role_id=3 AND r.role_id IN (7,9,11)");
         sqlSegmentMap.put(TEST_6, "sys_user_role", "r.role_id=3 AND r.role_id IN (7,9,11)");
         sqlSegmentMap.put(TEST_7, "`fund`", "a.id = 1 AND a.year = 2022 AND a.create_user_id = 1111");
         sqlSegmentMap.put(TEST_7, "`fund`", "a.id = 1 AND a.year = 2022 AND a.create_user_id = 1111");
         sqlSegmentMap.put(TEST_7, "`fund_month`", "b.fund_id = 2 AND b.month <= '2022-05'");
         sqlSegmentMap.put(TEST_7, "`fund_month`", "b.fund_id = 2 AND b.month <= '2022-05'");
+        sqlSegmentMap.put(TEST_8_1, "fund", "user_id=1");
+        sqlSegmentMap.put(TEST_8_2, "fund", "org_id=1");
+        sqlSegmentMap.put(TEST_8_3, "fund", "dept_id=1");
+        sqlSegmentMap.put(TEST_8_4, "fund", "user_id=1 or dept_id=1");
+        sqlSegmentMap.put(TEST_8_5, "table1", "u.user_id=1");
+        sqlSegmentMap.put(TEST_8_5, "table2", "u.dept_id=1");
         interceptor = new DataPermissionInterceptor(new MultiDataPermissionHandler() {
         interceptor = new DataPermissionInterceptor(new MultiDataPermissionHandler() {
 
 
             @Override
             @Override
@@ -54,6 +66,10 @@ public class MultiDataPermissionInterceptorTest {
                         logger.info("{} {} AS {} : NOT FOUND", mappedStatementId, table.getName(), table.getAlias());
                         logger.info("{} {} AS {} : NOT FOUND", mappedStatementId, table.getName(), table.getAlias());
                         return null;
                         return null;
                     }
                     }
+                    if (table.getAlias() != null) {
+                        // 替换表别名
+                        sqlSegment = sqlSegment.replaceAll("u\\.", table.getAlias().getName() + StringPool.DOT);
+                    }
                     Expression sqlSegmentExpression = CCJSqlParserUtil.parseCondExpression(sqlSegment);
                     Expression sqlSegmentExpression = CCJSqlParserUtil.parseCondExpression(sqlSegment);
                     logger.info("{} {} AS {} : {}", mappedStatementId, table.getName(), table.getAlias(), sqlSegmentExpression.toString());
                     logger.info("{} {} AS {} : {}", mappedStatementId, table.getName(), table.getAlias(), sqlSegmentExpression.toString());
                     return sqlSegmentExpression;
                     return sqlSegmentExpression;
@@ -114,6 +130,24 @@ public class MultiDataPermissionInterceptorTest {
             "SELECT c.doc AS title, sum(c.total_paid_amount) AS total_paid_amount, sum(c.balance_amount) AS balance_amount FROM (SELECT `a`.`id`, `a`.`doc`, `b`.`month`, `b`.`total_paid_amount`, `b`.`balance_amount`, row_number() OVER (PARTITION BY `a`.`id` ORDER BY `b`.`month` DESC) AS `row_index` FROM `fund` `a` LEFT JOIN `fund_month` `b` ON `a`.`id` = `b`.`fund_id` AND `b`.`submit` = TRUE AND b.fund_id = 2 AND b.month <= '2022-05' WHERE a.id = 1 AND a.year = 2022 AND a.create_user_id = 1111) c WHERE c.row_index = 1 GROUP BY title LIMIT 20");
             "SELECT c.doc AS title, sum(c.total_paid_amount) AS total_paid_amount, sum(c.balance_amount) AS balance_amount FROM (SELECT `a`.`id`, `a`.`doc`, `b`.`month`, `b`.`total_paid_amount`, `b`.`balance_amount`, row_number() OVER (PARTITION BY `a`.`id` ORDER BY `b`.`month` DESC) AS `row_index` FROM `fund` `a` LEFT JOIN `fund_month` `b` ON `a`.`id` = `b`.`fund_id` AND `b`.`submit` = TRUE AND b.fund_id = 2 AND b.month <= '2022-05' WHERE a.id = 1 AND a.year = 2022 AND a.create_user_id = 1111) c WHERE c.row_index = 1 GROUP BY title LIMIT 20");
     }
     }
 
 
+    @Test
+    void test8() {
+        assertSql(TEST_8_1, "select * from fund where id=3",
+            "SELECT * FROM fund WHERE id = 3 AND user_id = 1");
+        assertSql(TEST_8_2, "select * from fund where id=3",
+            "SELECT * FROM fund WHERE id = 3 AND org_id = 1");
+        assertSql(TEST_8_3, "select * from fund where id=3",
+            "SELECT * FROM fund WHERE id = 3 AND dept_id = 1");
+        assertSql(TEST_8_4, "select * from fund where id=3",
+            "SELECT * FROM fund WHERE id = 3 AND user_id = 1 OR dept_id = 1");
+        // 修改之前旧版的多表数据权限对这个SQL的表现形式:
+        // 输入 "WITH temp AS (SELECT t1.field1, t2.field2 FROM table1 t1 LEFT JOIN table2 t2 on t1.uid = t2.uid) SELECT * FROM temp"
+        // 输出 "WITH temp AS (SELECT t1.field1, t2.field2 FROM table1 t1 LEFT JOIN table2 t2 ON t1.uid = t2.uid) SELECT * FROM temp"
+        // 修改之后的多表数据权限对这个SQL的表现形式
+        assertSql(TEST_8_5, "WITH temp AS (SELECT t1.field1, t2.field2 FROM table1 t1 LEFT JOIN table2 t2 on t1.uid = t2.uid) SELECT * FROM temp",
+            "WITH temp AS (SELECT t1.field1, t2.field2 FROM table1 t1 LEFT JOIN table2 t2 ON t1.uid = t2.uid AND t2.dept_id = 1 WHERE t1.user_id = 1) SELECT * FROM temp");
+    }
+
     void assertSql(String mappedStatementId, String sql, String targetSql) {
     void assertSql(String mappedStatementId, String sql, String targetSql) {
         assertThat(interceptor.parserSingle(sql, mappedStatementId)).isEqualTo(targetSql);
         assertThat(interceptor.parserSingle(sql, mappedStatementId)).isEqualTo(targetSql);
     }
     }