miemie vor 5 Jahren
Ursprung
Commit
56dc59a9ba

+ 65 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/MybatisPlusInterceptor.java

@@ -0,0 +1,65 @@
+package com.baomidou.mybatisplus.extension.plugins;
+
+import com.baomidou.mybatisplus.extension.plugins.chain.BeforeQuery;
+import org.apache.ibatis.cache.CacheKey;
+import org.apache.ibatis.executor.Executor;
+import org.apache.ibatis.mapping.BoundSql;
+import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.plugin.*;
+import org.apache.ibatis.session.ResultHandler;
+import org.apache.ibatis.session.RowBounds;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Properties;
+
+/**
+ * @author miemie
+ * @since 2020-06-16
+ */
+@SuppressWarnings({"rawtypes", "unchecked"})
+@Intercepts(
+    {
+        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
+        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
+    }
+)
+public class MybatisPlusInterceptor implements Interceptor {
+
+    private final List<BeforeQuery> beforeQueries = new ArrayList<>();
+
+    @Override
+    public Object intercept(Invocation invocation) throws Throwable {
+        Object[] args = invocation.getArgs();
+        MappedStatement ms = (MappedStatement) args[0];
+        Object parameter = args[1];
+        RowBounds rowBounds = (RowBounds) args[2];
+        ResultHandler resultHandler = (ResultHandler) args[3];
+        Executor executor = (Executor) invocation.getTarget();
+        BoundSql boundSql;
+        if (args.length == 4) {
+            boundSql = ms.getBoundSql(parameter);
+        } else {
+            // 几乎不可能走进这里面
+            boundSql = (BoundSql) args[5];
+        }
+        for (BeforeQuery query : beforeQueries) {
+            boundSql = query.change(executor, ms, parameter, rowBounds, resultHandler, boundSql);
+        }
+        CacheKey cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql);
+        return executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql);
+    }
+
+    @Override
+    public Object plugin(Object target) {
+        if (target instanceof Executor) {
+            return Plugin.wrap(target, this);
+        }
+        return target;
+    }
+
+    @Override
+    public void setProperties(Properties properties) {
+
+    }
+}

+ 29 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/chain/BeforeQuery.java

@@ -0,0 +1,29 @@
+package com.baomidou.mybatisplus.extension.plugins.chain;
+
+import org.apache.ibatis.executor.Executor;
+import org.apache.ibatis.mapping.BoundSql;
+import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.session.ResultHandler;
+import org.apache.ibatis.session.RowBounds;
+
+import java.sql.SQLException;
+
+/**
+ * @author miemie
+ * @since 2020-06-16
+ */
+public interface BeforeQuery {
+
+    /**
+     * 拦截 Executor.query 执行前对执行sql进行处理
+     *
+     * @param executor      Executor(可能是代理对象)
+     * @param ms            MappedStatement
+     * @param parameter     parameter
+     * @param rowBounds     rowBounds
+     * @param resultHandler resultHandler
+     * @param boundSql      boundSql
+     * @return 新的 boundSql
+     */
+    BoundSql change(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException;
+}

+ 189 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/chain/PageBeforeQuery.java

@@ -0,0 +1,189 @@
+package com.baomidou.mybatisplus.extension.plugins.chain;
+
+import com.baomidou.mybatisplus.annotation.DbType;
+import com.baomidou.mybatisplus.core.metadata.IPage;
+import com.baomidou.mybatisplus.core.metadata.OrderItem;
+import com.baomidou.mybatisplus.core.parser.ISqlParser;
+import com.baomidou.mybatisplus.core.parser.SqlInfo;
+import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
+import com.baomidou.mybatisplus.core.toolkit.ParameterUtils;
+import com.baomidou.mybatisplus.extension.plugins.pagination.DialectFactory;
+import com.baomidou.mybatisplus.extension.plugins.pagination.DialectModel;
+import com.baomidou.mybatisplus.extension.plugins.pagination.dialects.IDialect;
+import com.baomidou.mybatisplus.extension.toolkit.JdbcUtils;
+import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils;
+import lombok.Data;
+import net.sf.jsqlparser.JSQLParserException;
+import net.sf.jsqlparser.parser.CCJSqlParserUtil;
+import net.sf.jsqlparser.statement.select.*;
+import org.apache.ibatis.cache.CacheKey;
+import org.apache.ibatis.executor.Executor;
+import org.apache.ibatis.mapping.BoundSql;
+import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.mapping.ParameterMapping;
+import org.apache.ibatis.session.Configuration;
+import org.apache.ibatis.session.ResultHandler;
+import org.apache.ibatis.session.RowBounds;
+
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * @author miemie
+ * @since 2020-06-16
+ */
+@Data
+public class PageBeforeQuery implements BeforeQuery {
+
+    /**
+     * COUNT SQL 解析
+     */
+    protected ISqlParser countSqlParser;
+    /**
+     * 溢出总页数后是否进行处理
+     */
+    protected boolean overflow = false;
+    /**
+     * 单页限制 500 条,小于 0 如 -1 不受限制
+     */
+    protected long limit = 500L;
+    /**
+     * 数据库类型
+     *
+     * @since 3.3.1
+     */
+    private DbType dbType;
+    /**
+     * 方言实现类
+     *
+     * @since 3.3.1
+     */
+    private IDialect dialect;
+
+    @Override
+    public BoundSql change(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
+        // 判断参数里是否有page对象
+        IPage<?> page = ParameterUtils.findPage(parameter).orElse(null);
+        /*
+         * 不需要分页的场合,如果 size 小于 0 返回结果集
+         */
+        if (null == page || page.getSize() < 0) {
+            return boundSql;
+        }
+
+        if (this.limit > 0 && this.limit <= page.getSize()) {
+            //处理单页条数限制
+            handlerLimit(page);
+        }
+        String originalSql = boundSql.getSql();
+        if (page.isSearchCount() && !page.isHitCount()) {
+            SqlInfo sqlInfo = SqlParserUtils.getOptimizeCountSql(page.optimizeCountSql(), countSqlParser, originalSql, null);
+            ms = buildCountMappedStatement(ms);
+            CacheKey cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql);
+            long count = (long) executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql).get(0);
+            page.setTotal(count);
+            if (!this.continueLimit(page)) {
+                return boundSql;
+            }
+        }
+        DbType dbType = this.dbType == null ? JdbcUtils.getDbType(JdbcUtils.getJdbcUrl(ms)) : this.dbType;
+        IDialect dialect = Optional.ofNullable(this.dialect).orElseGet(() -> DialectFactory.getDialect(dbType));
+        String buildSql = concatOrderBy(originalSql, page);
+        DialectModel model = dialect.buildPaginationSql(buildSql, page.offset(), page.getSize());
+        final Configuration configuration = ms.getConfiguration();
+        List<ParameterMapping> mappings = new ArrayList<>(boundSql.getParameterMappings());
+        Map<String, Object> additionalParameter = JdbcUtils.getAdditionalParameter(boundSql);
+        model.consumers(mappings, configuration, additionalParameter);
+        boundSql = new BoundSql(configuration, model.getDialectSql(), mappings, parameter);
+        for (Map.Entry<String, Object> entry : additionalParameter.entrySet()) {
+            boundSql.setAdditionalParameter(entry.getKey(), entry.getValue());
+        }
+        return boundSql;
+    }
+
+    private MappedStatement buildCountMappedStatement(MappedStatement ms) {
+        //todo
+        return new MappedStatement.Builder(ms.getConfiguration(), ).build();
+    }
+
+    /**
+     * 查询SQL拼接Order By
+     *
+     * @param originalSql 需要拼接的SQL
+     * @param page        page对象
+     * @return ignore
+     */
+    public String concatOrderBy(String originalSql, IPage<?> page) {
+        if (CollectionUtils.isNotEmpty(page.orders())) {
+            try {
+                List<OrderItem> orderList = page.orders();
+                Select selectStatement = (Select) CCJSqlParserUtil.parse(originalSql);
+                if (selectStatement.getSelectBody() instanceof PlainSelect) {
+                    PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody();
+                    List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
+                    List<OrderByElement> orderByElementsReturn = addOrderByElements(orderList, orderByElements);
+                    plainSelect.setOrderByElements(orderByElementsReturn);
+                    return plainSelect.toString();
+                } else if (selectStatement.getSelectBody() instanceof SetOperationList) {
+                    SetOperationList setOperationList = (SetOperationList) selectStatement.getSelectBody();
+                    List<OrderByElement> orderByElements = setOperationList.getOrderByElements();
+                    List<OrderByElement> orderByElementsReturn = addOrderByElements(orderList, orderByElements);
+                    setOperationList.setOrderByElements(orderByElementsReturn);
+                    return setOperationList.toString();
+                } else if (selectStatement.getSelectBody() instanceof WithItem) {
+                    // todo: don't known how to resole
+                    return originalSql;
+                } else {
+                    return originalSql;
+                }
+
+            } catch (JSQLParserException e) {
+                logger.warn("failed to concat orderBy from IPage, exception=" + e.getMessage());
+            }
+        }
+        return originalSql;
+    }
+
+    /**
+     * 判断是否继续执行 Limit 逻辑
+     *
+     * @param page 分页对象
+     * @return
+     */
+    protected boolean continueLimit(IPage<?> page) {
+        if (page.getTotal() <= 0) {
+            return false;
+        }
+        if (page.getCurrent() > page.getPages()) {
+            if (this.overflow) {
+                //溢出总页数处理
+                handlerOverflow(page);
+            } else {
+                // 超过最大范围,未设置溢出逻辑中断 list 执行
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /**
+     * 处理超出分页条数限制,默认归为限制数
+     *
+     * @param page IPage
+     */
+    protected void handlerLimit(IPage<?> page) {
+        page.setSize(this.limit);
+    }
+
+    /**
+     * 处理页数溢出,默认设置为第一页
+     *
+     * @param page IPage
+     */
+    protected void handlerOverflow(IPage<?> page) {
+        page.setCurrent(1);
+    }
+}

+ 44 - 0
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/toolkit/JdbcUtils.java

@@ -17,9 +17,18 @@ package com.baomidou.mybatisplus.extension.toolkit;
 
 import com.baomidou.mybatisplus.annotation.DbType;
 import com.baomidou.mybatisplus.core.toolkit.Assert;
+import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
 import com.baomidou.mybatisplus.core.toolkit.StringUtils;
 import org.apache.ibatis.logging.Log;
 import org.apache.ibatis.logging.LogFactory;
+import org.apache.ibatis.mapping.BoundSql;
+import org.apache.ibatis.mapping.MappedStatement;
+
+import javax.sql.DataSource;
+import java.lang.reflect.Field;
+import java.sql.Connection;
+import java.sql.SQLException;
+import java.util.Map;
 
 /**
  * JDBC 工具类
@@ -31,6 +40,41 @@ public class JdbcUtils {
 
     private static final Log logger = LogFactory.getLog(JdbcUtils.class);
 
+    private static Field additionalParametersField;
+
+    static {
+        try {
+            additionalParametersField = BoundSql.class.getDeclaredField("additionalParameters");
+            additionalParametersField.setAccessible(true);
+        } catch (NoSuchFieldException e) {
+            throw ExceptionUtils.mpe("获取 BoundSql 属性 additionalParameters 失败: " + e, e);
+        }
+    }
+
+    /**
+     * 获取 BoundSql 属性值 additionalParameters
+     *
+     * @param boundSql
+     * @return
+     */
+    @SuppressWarnings("unchecked")
+    public static Map<String, Object> getAdditionalParameter(BoundSql boundSql) {
+        try {
+            return (Map<String, Object>) additionalParametersField.get(boundSql);
+        } catch (IllegalAccessException e) {
+            throw ExceptionUtils.mpe("获取 BoundSql 属性值 additionalParameters 失败: " + e, e);
+        }
+    }
+
+    public static String getJdbcUrl(MappedStatement ms) {
+        DataSource dataSource = ms.getConfiguration().getEnvironment().getDataSource();
+        try (Connection conn = dataSource.getConnection()) {
+            return conn.getMetaData().getURL();
+        } catch (SQLException e) {
+            throw ExceptionUtils.mpe(e);
+        }
+    }
+
     /**
      * 根据连接地址判断数据库类型
      *