Caratacus 8 年之前
父节点
当前提交
585d75150b

+ 16 - 1
mybatis-plus/pom.xml

@@ -37,6 +37,8 @@
 		<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
 		<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
 		<mybatis-spring.version>1.3.0</mybatis-spring.version>
 		<mybatis-spring.version>1.3.0</mybatis-spring.version>
 		<mybatis.version>3.4.1</mybatis.version>
 		<mybatis.version>3.4.1</mybatis.version>
+		<jsqlparser.version>0.9.6</jsqlparser.version>
+		<alibaba.druid.version>1.0.24</alibaba.druid.version>
 		<slf4j.version>1.7.21</slf4j.version>
 		<slf4j.version>1.7.21</slf4j.version>
 		<logback-classic.version>1.1.7</logback-classic.version>
 		<logback-classic.version>1.1.7</logback-classic.version>
 		<mysql-connector-java.version>5.1.38</mysql-connector-java.version>
 		<mysql-connector-java.version>5.1.38</mysql-connector-java.version>
@@ -89,7 +91,20 @@
 			<version>${spring.version}</version>
 			<version>${spring.version}</version>
 			<scope>provided</scope>
 			<scope>provided</scope>
 		</dependency>
 		</dependency>
-
+		<!--jsqlparser-->
+		<dependency>
+			<groupId>com.github.jsqlparser</groupId>
+			<artifactId>jsqlparser</artifactId>
+			<version>${jsqlparser.version}</version>
+			<scope>provided</scope>
+		</dependency>
+		<!--druid-->
+		<dependency>
+			<groupId>com.alibaba</groupId>
+			<artifactId>druid</artifactId>
+			<version>${alibaba.druid.version}</version>
+			<scope>provided</scope>
+		</dependency>
 		<!-- test begin -->
 		<!-- test begin -->
 		<dependency>
 		<dependency>
 			<groupId>mysql</groupId>
 			<groupId>mysql</groupId>

+ 10 - 4
mybatis-plus/src/main/java/com/baomidou/mybatisplus/plugins/PaginationInterceptor.java

@@ -60,10 +60,10 @@ public class PaginationInterceptor implements Interceptor {
 
 
 	/* 溢出总页数,设置第一页 */
 	/* 溢出总页数,设置第一页 */
 	private boolean overflowCurrent = false;
 	private boolean overflowCurrent = false;
-
+	/* Count优化方式 */
+	private String optimizeType = "default";
 	/* 方言类型 */
 	/* 方言类型 */
 	private String dialectType;
 	private String dialectType;
-
 	/* 方言实现类 */
 	/* 方言实现类 */
 	private String dialectClazz;
 	private String dialectClazz;
 
 
@@ -107,7 +107,8 @@ public class PaginationInterceptor implements Interceptor {
 					/*
 					/*
 					 * COUNT 查询,去掉 ORDER BY 优化执行 SQL
 					 * COUNT 查询,去掉 ORDER BY 优化执行 SQL
 					 */
 					 */
-					CountOptimize countOptimize = SqlUtils.getCountOptimize(originalSql, page.isOptimizeCount());
+					CountOptimize countOptimize = SqlUtils.getCountOptimize(originalSql, optimizeType, dialectType,
+							page.isOptimizeCount());
 					orderBy = countOptimize.isOrderBy();
 					orderBy = countOptimize.isOrderBy();
 				}
 				}
 				/* 执行 SQL */
 				/* 执行 SQL */
@@ -156,7 +157,8 @@ public class PaginationInterceptor implements Interceptor {
 						/*
 						/*
 						 * COUNT 查询,去掉 ORDER BY 优化执行 SQL
 						 * COUNT 查询,去掉 ORDER BY 优化执行 SQL
 						 */
 						 */
-						CountOptimize countOptimize = SqlUtils.getCountOptimize(originalSql, page.isOptimizeCount());
+						CountOptimize countOptimize = SqlUtils.getCountOptimize(originalSql, optimizeType, dialectType,
+								page.isOptimizeCount());
 						page = this.count(countOptimize.getCountSQL(), connection, mappedStatement, boundSql, page);
 						page = this.count(countOptimize.getCountSQL(), connection, mappedStatement, boundSql, page);
 						/** 总数 0 跳出执行 */
 						/** 总数 0 跳出执行 */
 						if (page.getTotal() <= 0) {
 						if (page.getTotal() <= 0) {
@@ -274,4 +276,8 @@ public class PaginationInterceptor implements Interceptor {
 	public void setOverflowCurrent(boolean overflowCurrent) {
 	public void setOverflowCurrent(boolean overflowCurrent) {
 		this.overflowCurrent = overflowCurrent;
 		this.overflowCurrent = overflowCurrent;
 	}
 	}
+
+	public void setOptimizeType(String optimizeType) {
+		this.optimizeType = optimizeType;
+	}
 }
 }

+ 75 - 0
mybatis-plus/src/main/java/com/baomidou/mybatisplus/plugins/entity/Optimize.java

@@ -0,0 +1,75 @@
+/**
+ * Copyright (c) 2011-2014, hubin (jobob@qq.com).
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.baomidou.mybatisplus.plugins.entity;
+
+/**
+ * <p>
+ * Count优化枚举
+ * </p>
+ * 
+ * @author Caratacus
+ * @Date 2016-11-30
+ */
+public enum Optimize {
+	/**
+	 * 默认支持方式
+	 */
+	DEFAULT("default", "默认方式"),
+	/**
+	 * aliDruid,需添加相关依赖jar包
+	 */
+	ALI_DRUID("aliDruid", "依赖aliDruid模式"),
+	/**
+	 * jsqlparser方式,需添加相关依赖jar包
+	 */
+	JSQLPARSER("jsqlparser", "jsqlparser方式");
+
+	private final String optimize;
+
+	private final String desc;
+
+	Optimize(final String optimize, final String desc) {
+		this.optimize = optimize;
+		this.desc = desc;
+	}
+
+	/**
+	 * <p>
+	 * 获取优化类型.如果没有找到默认DEFAULT
+	 * </p>
+	 * 
+	 * @param optimizeType
+	 *            优化方式
+	 * @return
+	 */
+	public static Optimize getOptimizeType(String optimizeType) {
+		for (Optimize optimize : Optimize.values()) {
+			if (optimize.getOptimize().equalsIgnoreCase(optimizeType)) {
+				return optimize;
+			}
+		}
+		return DEFAULT;
+	}
+
+	public String getOptimize() {
+		return this.optimize;
+	}
+
+	public String getDesc() {
+		return this.desc;
+	}
+
+}

+ 122 - 22
mybatis-plus/src/main/java/com/baomidou/mybatisplus/toolkit/SqlUtils.java

@@ -15,9 +15,25 @@
  */
  */
 package com.baomidou.mybatisplus.toolkit;
 package com.baomidou.mybatisplus.toolkit;
 
 
+import com.alibaba.druid.sql.PagerUtils;
 import com.baomidou.mybatisplus.plugins.SQLFormatter;
 import com.baomidou.mybatisplus.plugins.SQLFormatter;
 import com.baomidou.mybatisplus.plugins.entity.CountOptimize;
 import com.baomidou.mybatisplus.plugins.entity.CountOptimize;
+import com.baomidou.mybatisplus.plugins.entity.Optimize;
 import com.baomidou.mybatisplus.plugins.pagination.Pagination;
 import com.baomidou.mybatisplus.plugins.pagination.Pagination;
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.Function;
+import net.sf.jsqlparser.expression.LongValue;
+import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
+import net.sf.jsqlparser.parser.CCJSqlParserUtil;
+import net.sf.jsqlparser.statement.select.Distinct;
+import net.sf.jsqlparser.statement.select.OrderByElement;
+import net.sf.jsqlparser.statement.select.PlainSelect;
+import net.sf.jsqlparser.statement.select.Select;
+import net.sf.jsqlparser.statement.select.SelectExpressionItem;
+import net.sf.jsqlparser.statement.select.SelectItem;
+
+import java.util.ArrayList;
+import java.util.List;
 
 
 /**
 /**
  * <p>
  * <p>
@@ -29,46 +45,73 @@ import com.baomidou.mybatisplus.plugins.pagination.Pagination;
  */
  */
 public class SqlUtils {
 public class SqlUtils {
 	private final static SQLFormatter sqlFormatter = new SQLFormatter();
 	private final static SQLFormatter sqlFormatter = new SQLFormatter();
+	private static final String SQL_BASE_COUNT = "SELECT COUNT(1) FROM ( %s )";
+	private static List<SelectItem> countSelectItem = null;
 
 
 	/**
 	/**
 	 * 获取CountOptimize
 	 * 获取CountOptimize
 	 * 
 	 * 
 	 * @param originalSql
 	 * @param originalSql
 	 *            需要计算Count SQL
 	 *            需要计算Count SQL
+	 * @param optimizeType
+	 *            count优化方式
 	 * @param isOptimizeCount
 	 * @param isOptimizeCount
 	 *            是否需要优化Count
 	 *            是否需要优化Count
 	 * @return CountOptimize
 	 * @return CountOptimize
 	 */
 	 */
-	public static CountOptimize getCountOptimize(String originalSql, boolean isOptimizeCount) {
-		boolean optimize = false;
+	public static CountOptimize getCountOptimize(String originalSql, String optimizeType, String dialectType,
+			boolean isOptimizeCount) {
 		CountOptimize countOptimize = CountOptimize.newInstance();
 		CountOptimize countOptimize = CountOptimize.newInstance();
-		StringBuffer countSql = new StringBuffer("SELECT COUNT(1) AS TOTAL ");
 		if (isOptimizeCount) {
 		if (isOptimizeCount) {
 			String tempSql = originalSql.replaceAll("(?i)ORDER[\\s]+BY", "ORDER BY").replaceAll("(?i)GROUP[\\s]+BY", "GROUP BY");
 			String tempSql = originalSql.replaceAll("(?i)ORDER[\\s]+BY", "ORDER BY").replaceAll("(?i)GROUP[\\s]+BY", "GROUP BY");
 			String indexOfSql = tempSql.toUpperCase();
 			String indexOfSql = tempSql.toUpperCase();
-			if (!indexOfSql.contains("DISTINCT") && !indexOfSql.contains("GROUP BY")) {
-				int formIndex = indexOfSql.indexOf("FROM");
-				if (formIndex > -1) {
-					// 有排序情况
-					int orderByIndex = indexOfSql.lastIndexOf("ORDER BY");
-					if (orderByIndex > -1) {
-						tempSql = tempSql.substring(0, orderByIndex);
-						countSql.append(tempSql.substring(formIndex));
-						countOptimize.setOrderBy(false);
-						// 无排序情况
-					} else {
-						countSql.append(tempSql.substring(formIndex));
+			// 有排序情况
+			int orderByIndex = indexOfSql.lastIndexOf("ORDER BY");
+			// 只针对 ALI_DRUID DEFAULT 这2种情况
+			if (orderByIndex > -1) {
+				countOptimize.setOrderBy(false);
+			}
+			Optimize opType = Optimize.getOptimizeType(optimizeType);
+			switch (opType) {
+			case ALI_DRUID:
+				/**
+				 * 调用ali druid方式 插件dbType一定要设置为小写与JdbcConstants保持一致
+				 * 
+				 * @see com.alibaba.druid.util.JdbcConstants
+				 */
+				String aliCountSql = PagerUtils.count(originalSql, dialectType);
+				countOptimize.setCountSQL(aliCountSql);
+				break;
+			case JSQLPARSER:
+				jsqlparserCount(countOptimize, originalSql);
+				break;
+			default:
+				StringBuffer countSql = new StringBuffer("SELECT COUNT(1) AS TOTAL ");
+				boolean optimize = false;
+				if (!indexOfSql.contains("DISTINCT") && !indexOfSql.contains("GROUP BY")) {
+					int formIndex = indexOfSql.indexOf("FROM");
+					if (formIndex > -1) {
+						if (orderByIndex > -1) {
+							tempSql = tempSql.substring(0, orderByIndex);
+							countSql.append(tempSql.substring(formIndex));
+							// 无排序情况
+						} else {
+							countSql.append(tempSql.substring(formIndex));
+						}
+						// 执行优化
+						optimize = true;
 					}
 					}
-					// 执行优化
-					optimize = true;
 				}
 				}
+				if (!optimize) {
+					// 无优化SQL
+					countSql.append("FROM (").append(originalSql).append(") A");
+				}
+				countOptimize.setCountSQL(countSql.toString());
+				;
 			}
 			}
+
 		}
 		}
-		if (!optimize) {
-			// 无优化SQL
-			countSql.append("FROM (").append(originalSql).append(") A");
-		}
-		countOptimize.setCountSQL(countSql.toString());
+
 		return countOptimize;
 		return countOptimize;
 	}
 	}
 
 
@@ -107,4 +150,61 @@ public class SqlUtils {
 			return boundSql.replaceAll("[\\s]+", " ");
 			return boundSql.replaceAll("[\\s]+", " ");
 		}
 		}
 	}
 	}
+
+	/**
+	 * jsqlparser方式获取select的count语句
+	 *
+	 * @param originalSql
+	 *            selectSQL
+	 * @return
+	 */
+	public static CountOptimize jsqlparserCount(CountOptimize countOptimize, String originalSql) {
+		String sqlCount;
+		try {
+			Select selectStatement = (Select) CCJSqlParserUtil.parse(originalSql);
+			PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody();
+			Distinct distinct = plainSelect.getDistinct();
+			List<Expression> groupBy = plainSelect.getGroupByColumnReferences();
+			// 包含 distinct、groupBy不优化
+			if (distinct != null || CollectionUtil.isNotEmpty(groupBy)) {
+				sqlCount = String.format(SQL_BASE_COUNT, originalSql);
+			}
+			// 优化Order by
+			List<OrderByElement> orderBy = plainSelect.getOrderByElements();
+			if (CollectionUtil.isNotEmpty(orderBy)) {
+				plainSelect.setOrderByElements(null);
+				countOptimize.setOrderBy(false);
+			}
+			List<SelectItem> selectCount = countSelectItem();
+			plainSelect.setSelectItems(selectCount);
+			sqlCount = selectStatement.toString();
+		} catch (Exception e) {
+			sqlCount = String.format(SQL_BASE_COUNT, originalSql);
+		}
+		countOptimize.setCountSQL(sqlCount);
+		return countOptimize;
+	}
+
+	/**
+	 * 获取jsqlparser中count的SelectItem
+	 *
+	 * @return
+	 */
+	private static List<SelectItem> countSelectItem() {
+		if (CollectionUtil.isNotEmpty(countSelectItem)) {
+			return countSelectItem;
+		}
+		Function function = new Function();
+		function.setName("COUNT");
+		List<Expression> expressions = new ArrayList<Expression>();
+		LongValue longValue = new LongValue(1);
+		ExpressionList expressionList = new ExpressionList();
+		expressions.add(longValue);
+		expressionList.setExpressions(expressions);
+		function.setParameters(expressionList);
+		countSelectItem = new ArrayList<SelectItem>();
+		SelectExpressionItem selectExpressionItem = new SelectExpressionItem(function);
+		countSelectItem.add(selectExpressionItem);
+		return countSelectItem;
+	}
 }
 }

+ 2 - 0
mybatis-plus/src/test/resources/wiki/plugin.md

@@ -24,6 +24,8 @@
     <!-- 配置方式二、使用自定义方言实现类 -->
     <!-- 配置方式二、使用自定义方言实现类 -->
     <plugin interceptor="com.baomidou.mybatisplus.plugins.PaginationInterceptor">
     <plugin interceptor="com.baomidou.mybatisplus.plugins.PaginationInterceptor">
         <property name="dialectClazz" value="xxx.dialect.XXDialect" />
         <property name="dialectClazz" value="xxx.dialect.XXDialect" />
+        <!--支持aliDruid与jsqlparser 默认default-->
+        <property name="optimizeType" value="aliDruid" />
     </plugin>
     </plugin>
     <!-- SQL 执行性能分析,开发环境使用,线上不推荐。 maxTime 指的是 sql 最大执行时长 -->
     <!-- SQL 执行性能分析,开发环境使用,线上不推荐。 maxTime 指的是 sql 最大执行时长 -->
     <plugin interceptor="com.baomidou.mybatisplus.plugins.PerformanceInterceptor">
     <plugin interceptor="com.baomidou.mybatisplus.plugins.PerformanceInterceptor">