Caratacus 8 lat temu
rodzic
commit
7b2f2f6810

+ 95 - 34
mybatis-plus/src/main/java/com/baomidou/mybatisplus/plugins/PaginationInterceptor.java

@@ -22,6 +22,7 @@ import com.baomidou.mybatisplus.plugins.pagination.IDialect;
 import com.baomidou.mybatisplus.plugins.pagination.Pagination;
 import com.baomidou.mybatisplus.toolkit.SqlUtils;
 import com.baomidou.mybatisplus.toolkit.StringUtils;
+import org.apache.ibatis.executor.Executor;
 import org.apache.ibatis.executor.parameter.ParameterHandler;
 import org.apache.ibatis.executor.statement.StatementHandler;
 import org.apache.ibatis.mapping.BoundSql;
@@ -34,6 +35,7 @@ import org.apache.ibatis.plugin.Signature;
 import org.apache.ibatis.reflection.MetaObject;
 import org.apache.ibatis.reflection.SystemMetaObject;
 import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
+import org.apache.ibatis.session.ResultHandler;
 import org.apache.ibatis.session.RowBounds;
 
 import java.sql.Connection;
@@ -50,7 +52,10 @@ import java.util.Properties;
  * @author hubin
  * @Date 2016-01-23
  */
-@Intercepts({ @Signature(type = StatementHandler.class, method = "prepare", args = { Connection.class, Integer.class }) })
+@Intercepts({
+		@Signature(type = Executor.class, method = "query", args = { MappedStatement.class, Object.class, RowBounds.class,
+				ResultHandler.class }),
+		@Signature(type = StatementHandler.class, method = "prepare", args = { Connection.class, Integer.class }) })
 public class PaginationInterceptor implements Interceptor {
 
 	/* 溢出总页数,设置第一页 */
@@ -63,6 +68,7 @@ public class PaginationInterceptor implements Interceptor {
 	private String dialectClazz;
 
 	public Object intercept(Invocation invocation) throws Throwable {
+
 		Object target = invocation.getTarget();
 		if (target instanceof StatementHandler) {
 			StatementHandler statementHandler = (StatementHandler) target;
@@ -75,39 +81,66 @@ public class PaginationInterceptor implements Interceptor {
 			}
 
 			/* 定义数据库方言 */
-			IDialect dialect = null;
-			if (StringUtils.isNotEmpty(dialectType)) {
-				dialect = DialectFactory.getDialectByDbtype(dialectType);
-			} else {
-				if (StringUtils.isNotEmpty(dialectClazz)) {
-					try {
-						Class<?> clazz = Class.forName(dialectClazz);
-						if (IDialect.class.isAssignableFrom(clazz)) {
-							dialect = (IDialect) clazz.newInstance();
-						}
-					} catch (ClassNotFoundException e) {
-						throw new MybatisPlusException("Class :" + dialectClazz + " is not found");
-					}
+			IDialect dialect = getiDialect();
+
+			/*
+			 * <p> 禁用内存分页 </p> <p> 内存分页会查询所有结果出来处理(这个很吓人的),如果结果变化频繁这个数据还会不准。
+			 * </p>
+			 */
+			BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
+			String originalSql = (String) boundSql.getSql();
+			metaStatementHandler.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET);
+			metaStatementHandler.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT);
+
+			/**
+			 * <p>
+			 * 分页逻辑
+			 * </p>
+			 * <p>
+			 * 查询总记录数 count
+			 * </p>
+			 */
+			if (rowBounds instanceof Pagination) {
+				Pagination page = (Pagination) rowBounds;
+				boolean orderBy = true;
+				if (page.isSearchCount()) {
+					/*
+					 * COUNT 查询,去掉 ORDER BY 优化执行 SQL
+					 */
+					CountOptimize countOptimize = SqlUtils.getCountOptimize(originalSql, page.isOptimizeCount());
+					orderBy = countOptimize.isOrderBy();
 				}
+				/* 执行 SQL */
+				String buildSql = SqlUtils.concatOrderBy(originalSql, page, orderBy);
+				originalSql = dialect.buildPaginationSql(buildSql, page.getOffsetCurrent(), page.getSize());
 			}
 
-			/* 未配置方言则抛出异常 */
-			if (dialect == null) {
-				throw new MybatisPlusException("The value of the dialect property in mybatis configuration.xml is not defined.");
+			/**
+			 * 查询 SQL 设置
+			 */
+			metaStatementHandler.setValue("delegate.boundSql.sql", originalSql);
+		} else {
+			MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
+			Object parameterObject = null;
+			RowBounds rowBounds = null;
+			if (invocation.getArgs().length > 1) {
+				parameterObject = invocation.getArgs()[1];
+				rowBounds = (RowBounds) invocation.getArgs()[2];
+			}
+			/* 不需要分页的场合 */
+			if (rowBounds == null || rowBounds == RowBounds.DEFAULT) {
+				return invocation.proceed();
 			}
 
+			/* 定义数据库方言 */
+			IDialect dialect = getiDialect();
+
+			BoundSql boundSql = mappedStatement.getBoundSql(parameterObject);
 			/*
-			 * <p>
-			 * 禁用内存分页
-			 * </p>
-			 * <p>
-			 * 内存分页会查询所有结果出来处理(这个很吓人的),如果结果变化频繁这个数据还会不准。
+			 * <p> 禁用内存分页 </p> <p> 内存分页会查询所有结果出来处理(这个很吓人的),如果结果变化频繁这个数据还会不准。
 			 * </p>
 			 */
-			BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");
 			String originalSql = (String) boundSql.getSql();
-			metaStatementHandler.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET);
-			metaStatementHandler.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT);
 
 			/**
 			 * <p>
@@ -118,8 +151,7 @@ public class PaginationInterceptor implements Interceptor {
 			 * </p>
 			 */
 			if (rowBounds instanceof Pagination) {
-				MappedStatement mappedStatement = (MappedStatement) metaStatementHandler.getValue("delegate.mappedStatement");
-				Connection connection = (Connection) invocation.getArgs()[0];
+				Connection connection = mappedStatement.getConfiguration().getEnvironment().getDataSource().getConnection();
 				Pagination page = (Pagination) rowBounds;
 				boolean orderBy = true;
 				if (page.isSearchCount()) {
@@ -137,15 +169,42 @@ public class PaginationInterceptor implements Interceptor {
 				/* 执行 SQL */
 				String buildSql = SqlUtils.concatOrderBy(originalSql, page, orderBy);
 				originalSql = dialect.buildPaginationSql(buildSql, page.getOffsetCurrent(), page.getSize());
+				/* 更新需要执行SQL */
+				SystemMetaObject.forObject(boundSql).setValue("sql", originalSql);
 			}
-
-			/**
-			 * 查询 SQL 设置
-			 */
-			metaStatementHandler.setValue("delegate.boundSql.sql", originalSql);
 		}
 
 		return invocation.proceed();
+
+	}
+
+	/**
+	 * 获取数据库方言
+	 *
+	 * @return
+	 * @throws Exception
+	 */
+	private IDialect getiDialect() throws Exception {
+		IDialect dialect = null;
+		if (StringUtils.isNotEmpty(dialectType)) {
+			dialect = DialectFactory.getDialectByDbtype(dialectType);
+		} else {
+			if (StringUtils.isNotEmpty(dialectClazz)) {
+				try {
+					Class<?> clazz = Class.forName(dialectClazz);
+					if (IDialect.class.isAssignableFrom(clazz)) {
+						dialect = (IDialect) clazz.newInstance();
+					}
+				} catch (ClassNotFoundException e) {
+					throw new MybatisPlusException("Class :" + dialectClazz + " is not found");
+				}
+			}
+		}
+		/* 未配置方言则抛出异常 */
+		if (dialect == null) {
+			throw new MybatisPlusException("The value of the dialect property in mybatis configuration.xml is not defined.");
+		}
+		return dialect;
 	}
 
 	/**
@@ -198,6 +257,9 @@ public class PaginationInterceptor implements Interceptor {
 	}
 
 	public Object plugin(Object target) {
+		if (target instanceof Executor) {
+			return Plugin.wrap(target, this);
+		}
 		if (target instanceof StatementHandler) {
 			return Plugin.wrap(target, this);
 		}
@@ -226,5 +288,4 @@ public class PaginationInterceptor implements Interceptor {
 	public void setOverflowCurrent(boolean overflowCurrent) {
 		this.overflowCurrent = overflowCurrent;
 	}
-
-}
+}

+ 2 - 2
mybatis-plus/src/main/java/com/baomidou/mybatisplus/toolkit/SqlUtils.java

@@ -44,9 +44,9 @@ public class SqlUtils {
 		CountOptimize countOptimize = CountOptimize.newInstance();
 		StringBuffer countSql = new StringBuffer("SELECT COUNT(1) AS TOTAL ");
 		if (isOptimizeCount) {
-			String tempSql = originalSql.replaceAll("(?i)ORDER[\\s]+BY", "ORDER BY");
+			String tempSql = originalSql.replaceAll("(?i)ORDER[\\s]+BY", "ORDER BY").replaceAll("(?i)GROUP[\\s]+BY", "GROUP BY");
 			String indexOfSql = tempSql.toUpperCase();
-			if (!indexOfSql.contains("DISTINCT")) {
+			if (!indexOfSql.contains("DISTINCT") && !indexOfSql.contains("GROUP BY")) {
 				int formIndex = indexOfSql.indexOf("FROM");
 				if (formIndex > -1) {
 					// 有排序情况