Browse Source

to qiuqiu

miemie 5 years ago
parent
commit
498bb3f4c8

+ 101 - 28
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/PluginUtils.java

@@ -15,12 +15,20 @@
  */
 package com.baomidou.mybatisplus.core.toolkit;
 
+import org.apache.ibatis.executor.Executor;
+import org.apache.ibatis.executor.parameter.ParameterHandler;
+import org.apache.ibatis.executor.statement.StatementHandler;
 import org.apache.ibatis.mapping.BoundSql;
+import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.mapping.ParameterMapping;
 import org.apache.ibatis.reflection.MetaObject;
 import org.apache.ibatis.reflection.SystemMetaObject;
+import org.apache.ibatis.session.Configuration;
 
-import java.lang.reflect.Field;
 import java.lang.reflect.Proxy;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 
 /**
@@ -32,18 +40,6 @@ import java.util.Map;
 public abstract class PluginUtils {
     public static final String DELEGATE_BOUNDSQL_SQL = "delegate.boundSql.sql";
 
-    private final static Field additionalParametersField = initBoundSqlAdditionalParametersField();
-
-    private static Field initBoundSqlAdditionalParametersField() {
-        try {
-            Field field = BoundSql.class.getDeclaredField("additionalParameters");
-            field.setAccessible(true);
-            return field;
-        } catch (NoSuchFieldException e) {
-            throw ExceptionUtils.mpe("can not find field['additionalParameters'] from BoundSql, why?", e);
-        }
-    }
-
     /**
      * 获得真正的处理对象,可能多层代理.
      */
@@ -56,21 +52,6 @@ public abstract class PluginUtils {
         return (T) target;
     }
 
-    /**
-     * 获取 BoundSql 属性值 additionalParameters
-     *
-     * @param boundSql BoundSql
-     * @return additionalParameters
-     */
-    @SuppressWarnings("unchecked")
-    public static Map<String, Object> getAdditionalParameter(BoundSql boundSql) {
-        try {
-            return (Map<String, Object>) additionalParametersField.get(boundSql);
-        } catch (IllegalAccessException e) {
-            throw ExceptionUtils.mpe("获取 BoundSql 属性值 additionalParameters 失败: " + e, e);
-        }
-    }
-
     /**
      * 给 BoundSql 设置 additionalParameters
      *
@@ -80,4 +61,96 @@ public abstract class PluginUtils {
     public static void setAdditionalParameter(BoundSql boundSql, Map<String, Object> additionalParameters) {
         additionalParameters.forEach(boundSql::setAdditionalParameter);
     }
+
+    public static MPBoundSql mpBoundSql(BoundSql boundSql) {
+        return new MPBoundSql(boundSql);
+    }
+
+    public static MPStatementHandler mpStatementHandler(StatementHandler statementHandler) {
+        MetaObject object = SystemMetaObject.forObject(statementHandler);
+        return new MPStatementHandler(SystemMetaObject.forObject(object.getValue("delegate")));
+    }
+
+    /**
+     * {@link org.apache.ibatis.executor.statement.BaseStatementHandler}
+     */
+    public static class MPStatementHandler {
+        private final MetaObject statementHandler;
+
+        MPStatementHandler(MetaObject statementHandler) {
+            this.statementHandler = statementHandler;
+        }
+
+        public ParameterHandler parameterHandler() {
+            return get("parameterHandler");
+        }
+
+        public MappedStatement mappedStatement() {
+            return get("mappedStatement");
+        }
+
+        public Executor executor() {
+            return get("executor");
+        }
+
+        public MPBoundSql mPBoundSql() {
+            return new MPBoundSql(boundSql());
+        }
+
+        public BoundSql boundSql() {
+            return get("boundSql");
+        }
+
+        public Configuration configuration() {
+            return get("configuration");
+        }
+
+        @SuppressWarnings("unchecked")
+        private <T> T get(String property) {
+            return (T) statementHandler.getValue(property);
+        }
+    }
+
+    /**
+     * {@link BoundSql}
+     */
+    public static class MPBoundSql {
+        private final MetaObject boundSql;
+        private final BoundSql delegate;
+
+        MPBoundSql(BoundSql boundSql) {
+            this.delegate = boundSql;
+            this.boundSql = SystemMetaObject.forObject(boundSql);
+        }
+
+        public String sql() {
+            return delegate.getSql();
+        }
+
+        public void sql(String sql) {
+            boundSql.setValue("sql", sql);
+        }
+
+        public List<ParameterMapping> parameterMappings() {
+            List<ParameterMapping> parameterMappings = delegate.getParameterMappings();
+            return new ArrayList<>(parameterMappings);
+        }
+
+        public void parameterMappings(List<ParameterMapping> parameterMappings) {
+            boundSql.setValue("parameterMappings", Collections.unmodifiableList(parameterMappings));
+        }
+
+        public Object parameterObject() {
+            return get("parameterObject");
+        }
+
+        public Map<String, Object> additionalParameters() {
+            return get("additionalParameters");
+        }
+
+        @SuppressWarnings("unchecked")
+        private <T> T get(String property) {
+            return (T) boundSql.getValue(property);
+        }
+    }
 }

+ 46 - 22
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/MybatisPlusInterceptor.java

@@ -1,15 +1,18 @@
 package com.baomidou.mybatisplus.extension.plugins;
 
-import com.baomidou.mybatisplus.extension.plugins.chain.BeforeQuery;
-import com.baomidou.mybatisplus.extension.plugins.chain.PageBeforeQuery;
+import com.baomidou.mybatisplus.extension.plugins.chain.PageQiuQiu;
+import com.baomidou.mybatisplus.extension.plugins.chain.QiuQiu;
 import org.apache.ibatis.cache.CacheKey;
 import org.apache.ibatis.executor.Executor;
+import org.apache.ibatis.executor.statement.StatementHandler;
 import org.apache.ibatis.mapping.BoundSql;
 import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.mapping.SqlCommandType;
 import org.apache.ibatis.plugin.*;
 import org.apache.ibatis.session.ResultHandler;
 import org.apache.ibatis.session.RowBounds;
 
+import java.sql.Connection;
 import java.util.Collections;
 import java.util.List;
 import java.util.Properties;
@@ -21,42 +24,63 @@ import java.util.Properties;
 @SuppressWarnings({"rawtypes"})
 @Intercepts(
     {
+        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class}),
+        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
         @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
         @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
     }
 )
 public class MybatisPlusInterceptor implements Interceptor {
 
-    private final List<BeforeQuery> beforeQueries = Collections.singletonList(new PageBeforeQuery());
+    private final List<QiuQiu> qiuQius = Collections.singletonList(new PageQiuQiu());
 
     @Override
     public Object intercept(Invocation invocation) throws Throwable {
+        Object target = invocation.getTarget();
         Object[] args = invocation.getArgs();
-        MappedStatement ms = (MappedStatement) args[0];
-        Object parameter = args[1];
-        RowBounds rowBounds = (RowBounds) args[2];
-        ResultHandler resultHandler = (ResultHandler) args[3];
-        Executor executor = (Executor) invocation.getTarget();
-        BoundSql boundSql;
-        if (args.length == 4) {
-            boundSql = ms.getBoundSql(parameter);
+        if (target instanceof Executor) {
+            final Executor executor = (Executor) target;
+            Object parameter = args[1];
+            boolean isUpdate = args.length == 2;
+            MappedStatement ms = (MappedStatement) args[0];
+            if (!isUpdate && ms.getSqlCommandType() == SqlCommandType.SELECT) {
+                RowBounds rowBounds = (RowBounds) args[2];
+                ResultHandler resultHandler = (ResultHandler) args[3];
+                BoundSql boundSql;
+                if (args.length == 4) {
+                    boundSql = ms.getBoundSql(parameter);
+                } else {
+                    // 几乎不可能走进这里面,除非使用Executor的代理对象调用query[args[6]]
+                    boundSql = (BoundSql) args[5];
+                }
+                for (QiuQiu query : qiuQius) {
+                    if (!query.willDoQuery(executor, ms, parameter, rowBounds, resultHandler, boundSql)) {
+                        return Collections.emptyList();
+                    }
+                    query.beforeQuery(executor, ms, parameter, rowBounds, resultHandler, boundSql);
+                }
+                CacheKey cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql);
+                return executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql);
+            } else if (isUpdate && ms.getSqlCommandType() == SqlCommandType.UPDATE) {
+                for (QiuQiu query : qiuQius) {
+                    query.update(executor, ms, parameter);
+                }
+            }
         } else {
-            // 几乎不可能走进这里面,除非使用Executor的代理对象调用query[args[6]]
-            boundSql = (BoundSql) args[5];
-        }
-        for (BeforeQuery query : beforeQueries) {
-            if (!query.canChange(executor, ms, parameter, rowBounds, resultHandler, boundSql)) {
-                return Collections.emptyList();
+            // StatementHandler
+            final StatementHandler sh = (StatementHandler) target;
+            Connection connections = (Connection) args[0];
+            Integer transactionTimeout = (Integer) args[1];
+            for (QiuQiu qiuQiu : qiuQius) {
+                qiuQiu.prepare(sh, connections, transactionTimeout);
             }
-            boundSql = query.change(executor, ms, parameter, rowBounds, resultHandler, boundSql);
         }
-        CacheKey cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql);
-        return executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql);
+        return invocation.proceed();
     }
 
     @Override
     public Object plugin(Object target) {
-        if (target instanceof Executor) {
+        if (target instanceof Executor || target instanceof StatementHandler) {
             return Plugin.wrap(target, this);
         }
         return target;
@@ -64,6 +88,6 @@ public class MybatisPlusInterceptor implements Interceptor {
 
     @Override
     public void setProperties(Properties properties) {
-
+        // todo
     }
 }

+ 13 - 12
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/chain/PageBeforeQuery.java → mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/chain/PageQiuQiu.java

@@ -41,9 +41,9 @@ import java.util.stream.Collectors;
  * @since 2020-06-16
  */
 @Data
-public class PageBeforeQuery implements BeforeQuery {
+public class PageQiuQiu implements QiuQiu {
 
-    protected static final Log logger = LogFactory.getLog(PageBeforeQuery.class);
+    protected static final Log logger = LogFactory.getLog(PageQiuQiu.class);
 
     /**
      * COUNT SQL 解析
@@ -74,7 +74,7 @@ public class PageBeforeQuery implements BeforeQuery {
      * 这里进行count,如果count为0这返回false(就是不再执行sql了)
      */
     @Override
-    public boolean canChange(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
+    public boolean willDoQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
         // 判断参数里是否有page对象
         IPage<?> page = ParameterUtils.findPage(parameter).orElse(null);
         if (page != null) {
@@ -89,8 +89,9 @@ public class PageBeforeQuery implements BeforeQuery {
                 final String originalSql = boundSql.getSql();
                 SqlInfo sqlInfo = SqlParserUtils.getOptimizeCountSql(page.optimizeCountSql(), countSqlParser, originalSql, SystemMetaObject.forObject(parameter));
                 MappedStatement countMappedStatement = buildCountMappedStatement(ms);
-                BoundSql countSql = new BoundSql(countMappedStatement.getConfiguration(), sqlInfo.getSql(), boundSql.getParameterMappings(), parameter);
-                PluginUtils.setAdditionalParameter(countSql, PluginUtils.getAdditionalParameter(boundSql));
+                PluginUtils.MPBoundSql mpBoundSql = PluginUtils.mpBoundSql(boundSql);
+                BoundSql countSql = new BoundSql(countMappedStatement.getConfiguration(), sqlInfo.getSql(), mpBoundSql.parameterMappings(), parameter);
+                PluginUtils.setAdditionalParameter(countSql, mpBoundSql.additionalParameters());
                 CacheKey cacheKey = executor.createCacheKey(countMappedStatement, parameter, rowBounds, countSql);
                 long count = (long) executor.query(countMappedStatement, parameter, rowBounds, resultHandler, cacheKey, countSql).get(0);
                 page.setTotal(count);
@@ -101,14 +102,14 @@ public class PageBeforeQuery implements BeforeQuery {
     }
 
     @Override
-    public BoundSql change(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
+    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
         // 判断参数里是否有page对象
         IPage<?> page = ParameterUtils.findPage(parameter).orElse(null);
         /*
          * 不需要分页的场合,如果 size 小于 0 返回结果集
          */
         if (null == page || page.getSize() < 0) {
-            return boundSql;
+            return;
         }
 
         if (this.limit > 0 && this.limit <= page.getSize()) {
@@ -121,12 +122,12 @@ public class PageBeforeQuery implements BeforeQuery {
         String buildSql = this.concatOrderBy(originalSql, page);
         DialectModel model = dialect.buildPaginationSql(buildSql, page.offset(), page.getSize());
         final Configuration configuration = ms.getConfiguration();
-        List<ParameterMapping> mappings = new ArrayList<>(boundSql.getParameterMappings());
-        Map<String, Object> additionalParameter = PluginUtils.getAdditionalParameter(boundSql);
+        PluginUtils.MPBoundSql mpBoundSql = PluginUtils.mpBoundSql(boundSql);
+        List<ParameterMapping> mappings = mpBoundSql.parameterMappings();
+        Map<String, Object> additionalParameter = mpBoundSql.additionalParameters();
         model.consumers(mappings, configuration, additionalParameter);
-        boundSql = new BoundSql(configuration, model.getDialectSql(), mappings, parameter);
-        PluginUtils.setAdditionalParameter(boundSql, additionalParameter);
-        return boundSql;
+        mpBoundSql.sql(model.getDialectSql());
+        mpBoundSql.parameterMappings(mappings);
     }
 
     protected MappedStatement buildCountMappedStatement(MappedStatement ms) {

+ 23 - 8
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/chain/BeforeQuery.java → mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/chain/QiuQiu.java

@@ -1,22 +1,22 @@
 package com.baomidou.mybatisplus.extension.plugins.chain;
 
 import org.apache.ibatis.executor.Executor;
+import org.apache.ibatis.executor.statement.StatementHandler;
 import org.apache.ibatis.mapping.BoundSql;
 import org.apache.ibatis.mapping.MappedStatement;
 import org.apache.ibatis.session.ResultHandler;
 import org.apache.ibatis.session.RowBounds;
 
+import java.sql.Connection;
 import java.sql.SQLException;
 
 /**
  * @author miemie
  * @since 2020-06-16
  */
-public interface BeforeQuery {
+public interface QiuQiu {
 
     /**
-     * 拦截 Executor.query 执行前对执行sql进行处理
-     *
      * @param executor      Executor(可能是代理对象)
      * @param ms            MappedStatement
      * @param parameter     parameter
@@ -25,20 +25,35 @@ public interface BeforeQuery {
      * @param boundSql      boundSql
      * @return 新的 boundSql
      */
-    default boolean canChange(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
+    default boolean willDoQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
         return true;
     }
 
     /**
-     * 拦截 Executor.query 执行前对执行sql进行处理
-     *
      * @param executor      Executor(可能是代理对象)
      * @param ms            MappedStatement
      * @param parameter     parameter
      * @param rowBounds     rowBounds
      * @param resultHandler resultHandler
      * @param boundSql      boundSql
-     * @return 新的 boundSql
      */
-    BoundSql change(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException;
+    default void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
+        // do nothing
+    }
+
+    /**
+     * @param executor  Executor(可能是代理对象)
+     * @param ms        MappedStatement
+     * @param parameter parameter
+     */
+    default void update(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
+        // do nothing
+    }
+
+    /**
+     *
+     */
+    default void prepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
+        // do nothing
+    }
 }