Browse Source

重写乐观锁,完成所有测试. 改拦截stamentHandler.

小锅盖 8 years ago
parent
commit
f3c031ffb7

+ 31 - 119
mybatis-plus/src/main/java/com/baomidou/mybatisplus/plugins/OptimisticLockerInterceptor.java

@@ -3,43 +3,41 @@ package com.baomidou.mybatisplus.plugins;
 import java.lang.reflect.Field;
 import java.lang.reflect.ParameterizedType;
 import java.lang.reflect.Type;
+import java.sql.Statement;
 import java.sql.Timestamp;
-import java.util.ArrayList;
 import java.util.Date;
 import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
-import java.util.Map.Entry;
 import java.util.Properties;
 import java.util.concurrent.ConcurrentHashMap;
 
 import org.apache.ibatis.binding.MapperMethod.ParamMap;
 import org.apache.ibatis.exceptions.ExceptionFactory;
-import org.apache.ibatis.executor.Executor;
+import org.apache.ibatis.executor.statement.StatementHandler;
 import org.apache.ibatis.mapping.BoundSql;
 import org.apache.ibatis.mapping.MappedStatement;
 import org.apache.ibatis.mapping.ParameterMapping;
 import org.apache.ibatis.mapping.SqlCommandType;
-import org.apache.ibatis.mapping.SqlSource;
 import org.apache.ibatis.plugin.Interceptor;
 import org.apache.ibatis.plugin.Intercepts;
 import org.apache.ibatis.plugin.Invocation;
 import org.apache.ibatis.plugin.Plugin;
 import org.apache.ibatis.plugin.Signature;
 import org.apache.ibatis.reflection.MetaObject;
+import org.apache.ibatis.reflection.SystemMetaObject;
 import org.apache.ibatis.session.Configuration;
 import org.apache.ibatis.type.TypeException;
 import org.apache.ibatis.type.UnknownTypeHandler;
 
 import com.baomidou.mybatisplus.annotations.TableName;
 import com.baomidou.mybatisplus.annotations.Version;
-import com.baomidou.mybatisplus.mapper.EntityWrapper;
+import com.baomidou.mybatisplus.toolkit.PluginUtils;
 import com.baomidou.mybatisplus.toolkit.StringUtils;
 
 import net.sf.jsqlparser.expression.BinaryExpression;
 import net.sf.jsqlparser.expression.Expression;
-import net.sf.jsqlparser.expression.JdbcParameter;
 import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
 import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
 import net.sf.jsqlparser.parser.CCJSqlParserUtil;
@@ -61,7 +59,7 @@ import net.sf.jsqlparser.statement.update.Update;
  *
  * @author TaoYu
  */
-@Intercepts({ @Signature(type = Executor.class, method = "update", args = { MappedStatement.class, Object.class }) })
+@Intercepts({ @Signature(type = StatementHandler.class, method = "update", args = { Statement.class }) })
 public final class OptimisticLockerInterceptor implements Interceptor {
 
 	/**
@@ -92,29 +90,25 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 	}
 
 	public Object intercept(Invocation invocation) throws Exception {
+		StatementHandler statementHandler = (StatementHandler) PluginUtils.realTarget(invocation.getTarget());
+		MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
 		// 先判断入参为null或者不是真正的UPDATE语句
-		MappedStatement ms = (MappedStatement) invocation.getArgs()[0];
-		Object parameterObject = invocation.getArgs()[1];
-		if (parameterObject == null || !ms.getSqlCommandType().equals(SqlCommandType.UPDATE)) {
+		MappedStatement ms = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
+		if (!ms.getSqlCommandType().equals(SqlCommandType.UPDATE)) {
 			return invocation.proceed();
 		}
+		BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
 		// 获得参数类型,去缓存中快速判断是否有version注解才继续执行
-		Class<? extends Object> parameterClass = parameterObject.getClass();
-		Class<?> realClass = ms.getParameterMap().getType();
-		Object realParameterObject = parameterObject;
-		if (parameterObject instanceof ParamMap) {
-			EntityWrapper<?> tt = (EntityWrapper<?>) ((ParamMap<?>) parameterObject).get("ew");
-			realParameterObject = tt.getEntity();
-		}
-		CachePo versionPo = versionCache.get(realClass);
+		Class<?> parameterClass = ms.getParameterMap().getType();
+		CachePo versionPo = versionCache.get(parameterClass);
 		if (versionPo != null) {
 			if (versionPo.isVersionControl) {
-				return processChangeSql(ms, parameterObject, realParameterObject, versionPo, invocation);
+				processChangeSql(ms, boundSql, versionPo);
 			}
 		} else {
 			String versionColumn = null;
 			Field versionField = null;
-			for (final Field field : realClass.getDeclaredFields()) {
+			for (final Field field : parameterClass.getDeclaredFields()) {
 				if (field.isAnnotationPresent(Version.class)) {
 					if (!typeHandlers.containsKey(field.getType())) {
 						throw new TypeException("乐观锁不支持" + field.getType().getName() + "类型,请自定义实现");
@@ -133,7 +127,7 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 				versionField.setAccessible(true);
 				CachePo cachePo = new CachePo(true, versionColumn, versionField);
 				versionCache.put(parameterClass, cachePo);
-				return processChangeSql(ms, parameterObject, realParameterObject, cachePo, invocation);
+				processChangeSql(ms, boundSql, cachePo);
 			} else {
 				versionCache.put(parameterClass, new CachePo(false));
 			}
@@ -142,76 +136,42 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 
 	}
 
-	private static final Expression JDBCPARAMETER = new JdbcParameter();
 	private static final Expression RIGHTEXPRESSION = new Column("?");
 
 	@SuppressWarnings({ "rawtypes", "unchecked" })
-	private Object processChangeSql(MappedStatement ms, Object parameterObject, Object realParameterObject, CachePo cachePo, Invocation invocation) throws Exception {
+	private void processChangeSql(MappedStatement ms, BoundSql boundSql, CachePo cachePo) throws Exception {
 		Field versionField = cachePo.versionField;
 		String versionColumn = cachePo.versionColumn;
-		final Object versionValue = versionField.get(realParameterObject);
-		SqlSource originSqlSource;
-		if (cachePo.getSqlSource() == null) {
-			originSqlSource = ms.getSqlSource();
-			cachePo.setSqlSource(originSqlSource);
-		} else {
-			originSqlSource = cachePo.getSqlSource();
+		Object parameterObject = boundSql.getParameterObject();
+		if (parameterObject instanceof ParamMap) {
+			parameterObject = ((ParamMap) parameterObject).get("et");
 		}
-		Configuration configuration = ms.getConfiguration();
-		MetaObject metaObject = configuration.newMetaObject(ms);
+		final Object versionValue = versionField.get(parameterObject);
 		if (versionValue != null) {// 先判断传参是否携带version,没带跳过插件
-			BoundSql originBoundSql = originSqlSource.getBoundSql(parameterObject);
-			// 处理
-			Update jsqlSql = (Update) CCJSqlParserUtil.parse(originBoundSql.getSql());
-			List<Column> columns = jsqlSql.getColumns();
-			List<String> columnNames = new ArrayList<String>();
-			for (Column column : columns) {
-				columnNames.add(column.getColumnName());
-			}
-			List<Expression> expressions = jsqlSql.getExpressions();
-			if (!columnNames.contains(versionColumn)) {
-				columns.add(new Column(versionColumn));
-				jsqlSql.setColumns(columns);
-				expressions.add(JDBCPARAMETER);
-				jsqlSql.setExpressions(expressions);
-			}
-
+			Configuration configuration = ms.getConfiguration();
 			// 给字段赋新值
 			VersionHandler targetHandler = typeHandlers.get(versionField.getType());
-			targetHandler.plusVersion(realParameterObject, versionField, versionValue);
-			// 设置sqlSource
-			SqlSource sqlSource = ms.getSqlSource();
-			if (!(sqlSource instanceof OptimisticLockerSqlSource)) {
-				sqlSource = new OptimisticLockerSqlSource(configuration);
-				metaObject.setValue("sqlSource", sqlSource);
-			}
-			OptimisticLockerSqlSource optimisticLockerSqlSource = (OptimisticLockerSqlSource) sqlSource;
+			targetHandler.plusVersion(parameterObject, versionField, versionValue);
 			// 处理where条件,添加?
+			Update jsqlSql = (Update) CCJSqlParserUtil.parse(boundSql.getSql());
 			BinaryExpression expression = (BinaryExpression) jsqlSql.getWhere();
 			if (expression != null && !expression.toString().contains(versionColumn)) {
 				EqualsTo equalsTo = new EqualsTo();
 				equalsTo.setLeftExpression(new Column(versionColumn));
 				equalsTo.setRightExpression(RIGHTEXPRESSION);
 				jsqlSql.setWhere(new AndExpression(equalsTo, expression));
-				List<ParameterMapping> parameterMappings = new LinkedList<ParameterMapping>(originBoundSql.getParameterMappings());
-				parameterMappings.add(expressions.size(), createVersionMapping(configuration));
-				optimisticLockerSqlSource.setParameterMappings(parameterMappings);
-			} else {
-				optimisticLockerSqlSource.setParameterMappings(originBoundSql.getParameterMappings());
+				List<ParameterMapping> parameterMappings = new LinkedList<ParameterMapping>(boundSql.getParameterMappings());
+				parameterMappings.add(jsqlSql.getExpressions().size(), createVersionMapping(configuration));
+				MetaObject boundSqlMeta = configuration.newMetaObject(boundSql);
+				boundSqlMeta.setValue("sql", jsqlSql.toString());
+				boundSqlMeta.setValue("parameterMappings", parameterMappings);
 			}
 			// 设置参数
-			Map<String, Object> additionalParameters = new HashMap<String, Object>();
-			additionalParameters.put("originVersionValue", versionValue);
-			additionalParameters.putAll((Map<String, Object>) configuration.newMetaObject(originBoundSql).getValue("additionalParameters"));
-			optimisticLockerSqlSource.setSql(jsqlSql.toString());
-			optimisticLockerSqlSource.setAdditionalParameters(additionalParameters);
-		} else {
-			metaObject.setValue("sqlSource", originSqlSource);
+			boundSql.setAdditionalParameter("originVersionValue", versionValue);
 		}
-		return invocation.proceed();
 	}
 
-	private ParameterMapping parameterMapping;
+	private volatile ParameterMapping parameterMapping;
 
 	private ParameterMapping createVersionMapping(Configuration configuration) {
 		if (parameterMapping == null) {
@@ -225,7 +185,7 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 	}
 
 	public Object plugin(Object target) {
-		if (target instanceof Executor) {
+		if (target instanceof StatementHandler) {
 			return Plugin.wrap(target, this);
 		}
 		return target;
@@ -267,8 +227,6 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 
 		private Field versionField;
 
-		private SqlSource sqlSource;
-
 		public CachePo(Boolean isVersionControl) {
 			this.isVersionControl = isVersionControl;
 		}
@@ -279,54 +237,8 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 			this.versionField = versionField;
 		}
 
-		public SqlSource getSqlSource() {
-			return sqlSource;
-		}
-
-		public void setSqlSource(SqlSource sqlSource) {
-			this.sqlSource = sqlSource;
-		}
-
 	}
 
-	/**
-	 * 乐观锁数据源,主要是为动态参数设计
-	 */
-	private class OptimisticLockerSqlSource implements SqlSource {
-
-		private Configuration configuration;
-		private String sql;
-		private List<ParameterMapping> parameterMappings;
-		private Map<String, Object> additionalParameters;
-
-		public OptimisticLockerSqlSource(Configuration configuration) {
-			this.configuration = configuration;
-		}
-
-		public BoundSql getBoundSql(Object parameterObject) {
-			BoundSql boundSql = new BoundSql(configuration, sql, parameterMappings, parameterObject);
-			if (additionalParameters != null && additionalParameters.size() > 0) {
-				for (Entry<String, Object> item : additionalParameters.entrySet()) {
-					boundSql.setAdditionalParameter(item.getKey(), item.getValue());
-				}
-			}
-			additionalParameters.clear();
-			return boundSql;
-		}
-
-		public void setSql(String sql) {
-			this.sql = sql;
-		}
-
-		public void setParameterMappings(List<ParameterMapping> parameterMappings) {
-			this.parameterMappings = parameterMappings;
-		}
-
-		public void setAdditionalParameters(Map<String, Object> additionalParameters) {
-			this.additionalParameters = additionalParameters;
-		}
-
-	}
 	// *****************************基本类型处理器*****************************
 
 	private static class ShortTypeHandler implements VersionHandler<Short> {

+ 2 - 1
mybatis-plus/src/test/resources/plugins/optimisticLockerInterceptor.xml

@@ -19,7 +19,7 @@
 			value="com.baomidou.mybatisplus.test.plugins.optimisticLocker.entity" />
 		<property name="plugins">
 			<array>
-				<bean class="com.baomidou.mybatisplus.plugins.PerformanceInterceptor" />
+				
 				<bean class="com.baomidou.mybatisplus.plugins.OptimisticLockerInterceptor">
 					<property name="properties">
 						<value>
@@ -27,6 +27,7 @@
 						</value>
 					</property>
 				</bean>
+				<bean class="com.baomidou.mybatisplus.plugins.PerformanceInterceptor" />
 			</array>
 		</property>
 	</bean>