Переглянути джерело

Merge pull request #1445 from kana112233/3.0

用SetOperationList处理sql带union的语句
qmdx 5 роки тому
батько
коміт
107658da81

+ 52 - 43
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/PaginationInterceptor.java

@@ -15,33 +15,6 @@
  */
 package com.baomidou.mybatisplus.extension.plugins;
 
-import java.sql.Connection;
-import java.sql.PreparedStatement;
-import java.sql.ResultSet;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.Properties;
-
-import org.apache.ibatis.executor.statement.StatementHandler;
-import org.apache.ibatis.logging.Log;
-import org.apache.ibatis.logging.LogFactory;
-import org.apache.ibatis.mapping.BoundSql;
-import org.apache.ibatis.mapping.MappedStatement;
-import org.apache.ibatis.mapping.ParameterMapping;
-import org.apache.ibatis.mapping.SqlCommandType;
-import org.apache.ibatis.mapping.StatementType;
-import org.apache.ibatis.plugin.Interceptor;
-import org.apache.ibatis.plugin.Intercepts;
-import org.apache.ibatis.plugin.Invocation;
-import org.apache.ibatis.plugin.Plugin;
-import org.apache.ibatis.plugin.Signature;
-import org.apache.ibatis.reflection.MetaObject;
-import org.apache.ibatis.reflection.SystemMetaObject;
-import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
-import org.apache.ibatis.session.Configuration;
-import org.apache.ibatis.session.RowBounds;
-
 import com.baomidou.mybatisplus.annotation.DbType;
 import com.baomidou.mybatisplus.core.MybatisDefaultParameterHandler;
 import com.baomidou.mybatisplus.core.metadata.IPage;
@@ -57,15 +30,31 @@ import com.baomidou.mybatisplus.extension.plugins.pagination.DialectFactory;
 import com.baomidou.mybatisplus.extension.plugins.pagination.DialectModel;
 import com.baomidou.mybatisplus.extension.toolkit.JdbcUtils;
 import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils;
-
 import lombok.Setter;
 import lombok.experimental.Accessors;
 import net.sf.jsqlparser.JSQLParserException;
 import net.sf.jsqlparser.parser.CCJSqlParserUtil;
 import net.sf.jsqlparser.schema.Column;
-import net.sf.jsqlparser.statement.select.OrderByElement;
-import net.sf.jsqlparser.statement.select.PlainSelect;
-import net.sf.jsqlparser.statement.select.Select;
+import net.sf.jsqlparser.statement.select.*;
+import org.apache.ibatis.executor.statement.StatementHandler;
+import org.apache.ibatis.logging.Log;
+import org.apache.ibatis.logging.LogFactory;
+import org.apache.ibatis.mapping.*;
+import org.apache.ibatis.plugin.*;
+import org.apache.ibatis.reflection.MetaObject;
+import org.apache.ibatis.reflection.SystemMetaObject;
+import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
+import org.apache.ibatis.session.Configuration;
+import org.apache.ibatis.session.RowBounds;
+import org.jetbrains.annotations.NotNull;
+
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
 
 /**
  * 分页拦截器
@@ -113,19 +102,25 @@ public class PaginationInterceptor extends AbstractSqlParserHandler implements I
             try {
                 List<OrderItem> orderList = page.orders();
                 Select selectStatement = (Select) CCJSqlParserUtil.parse(originalSql);
-                PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody();
-                List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
-                if (orderByElements == null || orderByElements.isEmpty()) {
-                    orderByElements = new ArrayList<>(orderList.size());
-                }
-                for (OrderItem item : orderList) {
-                    OrderByElement element = new OrderByElement();
-                    element.setExpression(new Column(item.getColumn()));
-                    element.setAsc(item.isAsc());
-                    orderByElements.add(element);
+                if (selectStatement.getSelectBody() instanceof PlainSelect) {
+                    PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody();
+                    List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
+                    List<OrderByElement> orderByElementsReturn = addOrderByElements(orderList, orderByElements);
+                    plainSelect.setOrderByElements(orderByElementsReturn);
+                    return plainSelect.toString();
+                } else if (selectStatement.getSelectBody() instanceof SetOperationList) {
+                    SetOperationList setOperationList = (SetOperationList) selectStatement.getSelectBody();
+                    List<OrderByElement> orderByElements = setOperationList.getOrderByElements();
+                    List<OrderByElement> orderByElementsReturn = addOrderByElements(orderList, orderByElements);
+                    setOperationList.setOrderByElements(orderByElementsReturn);
+                    return setOperationList.toString();
+                } else if (selectStatement.getSelectBody() instanceof WithItem) {
+                    // todo: don't known how to resole
+                    return originalSql;
+                } else {
+                    return originalSql;
                 }
-                plainSelect.setOrderByElements(orderByElements);
-                return plainSelect.toString();
+
             } catch (JSQLParserException e) {
                 logger.warn("failed to concat orderBy from IPage, exception=" + e.getMessage());
             }
@@ -133,6 +128,20 @@ public class PaginationInterceptor extends AbstractSqlParserHandler implements I
         return originalSql;
     }
 
+    @NotNull
+    private static List<OrderByElement> addOrderByElements(List<OrderItem> orderList, List<OrderByElement> orderByElements) {
+        if (orderByElements == null || orderByElements.isEmpty()) {
+            orderByElements = new ArrayList<>(orderList.size());
+        }
+        for (OrderItem item : orderList) {
+            OrderByElement element = new OrderByElement();
+            element.setExpression(new Column(item.getColumn()));
+            element.setAsc(item.isAsc());
+            orderByElements.add(element);
+        }
+        return orderByElements;
+    }
+
     /**
      * Physical Page Interceptor for all the queries with parameter {@link RowBounds}
      */

+ 91 - 0
mybatis-plus-extension/src/test/java/com/baomidou/mybatisplus/extension/plugins/pagination/SelectBodyToPlainSelectTest.java

@@ -0,0 +1,91 @@
+package com.baomidou.mybatisplus.extension.plugins.pagination;
+
+import com.baomidou.mybatisplus.core.metadata.OrderItem;
+import com.baomidou.mybatisplus.extension.plugins.PaginationInterceptor;
+import net.sf.jsqlparser.JSQLParserException;
+import net.sf.jsqlparser.parser.CCJSqlParserUtil;
+import net.sf.jsqlparser.statement.select.PlainSelect;
+import net.sf.jsqlparser.statement.select.Select;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * SelectBody强转PlainSelect不支持sql里面最外层带union
+ * 用SetOperationList处理sql带union的语句
+ */
+class SelectBodyToPlainSelectTest {
+
+    /**
+     * 报错的测试
+     */
+    @Test
+    void testSelectBodyToPlainSelectThrowException() {
+        Select selectStatement = null;
+        try {
+            String originalUnionSql = "select * from test union select * from test";
+            selectStatement = (Select) CCJSqlParserUtil.parse(originalUnionSql);
+        } catch (JSQLParserException e) {
+            e.printStackTrace();
+        }
+        assert selectStatement != null;
+        try {
+            PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody();
+            assert false;
+        } catch (Exception e) {
+            assertThat(e.getMessage()).isEqualTo("net.sf.jsqlparser.statement.select.SetOperationList cannot be cast to net.sf.jsqlparser.statement.select.PlainSelect");
+        }
+    }
+
+    private Page<?> page = new Page<>();
+
+    @BeforeEach
+    void setup() {
+        List<OrderItem> orderItems = new ArrayList<>();
+        OrderItem order = new OrderItem();
+        order.setAsc(true);
+        order.setColumn("column");
+        orderItems.add(order);
+        page.setOrders(orderItems);
+    }
+
+    @Test
+    void testPaginationInterceptorConcatOrderByBefore() {
+        String actualSql = PaginationInterceptor
+            .concatOrderBy("select * from test", page);
+
+        assertThat(actualSql).isEqualTo("SELECT * FROM test ORDER BY column");
+
+        String actualSqlWhere = PaginationInterceptor
+            .concatOrderBy("select * from test where 1 = 1", page);
+
+        assertThat(actualSqlWhere).isEqualTo("SELECT * FROM test WHERE 1 = 1 ORDER BY column");
+    }
+
+    @Test
+    void testPaginationInterceptorConcatOrderByFix() {
+        String actualSql = PaginationInterceptor
+            .concatOrderBy("select * from test union select * from test2", page);
+        assertThat(actualSql).isEqualTo("SELECT * FROM test UNION SELECT * FROM test2 ORDER BY column");
+
+        String actualSqlUnionAll = PaginationInterceptor
+            .concatOrderBy("select * from test union all select * from test2", page);
+        assertThat(actualSqlUnionAll).isEqualTo("SELECT * FROM test UNION ALL SELECT * FROM test2 ORDER BY column");
+    }
+
+    @Test
+    void testPaginationInterceptorConcatOrderByFixWithWhere() {
+        String actualSqlWhere = PaginationInterceptor
+            .concatOrderBy("select * from test where 1 = 1 union select * from test2 where 1 = 1", page);
+        assertThat(actualSqlWhere).isEqualTo("SELECT * FROM test WHERE 1 = 1 UNION SELECT * FROM test2 WHERE 1 = 1 ORDER BY column");
+
+        String actualSqlUnionAll = PaginationInterceptor
+            .concatOrderBy("select * from test where 1 = 1 union all select * from test2 where 1 = 1 ", page);
+        assertThat(actualSqlUnionAll).isEqualTo("SELECT * FROM test WHERE 1 = 1 UNION ALL SELECT * FROM test2 WHERE 1 = 1 ORDER BY column");
+    }
+
+}