Browse Source

初步支持多参数

小锅盖 8 years ago
parent
commit
20f9f7f646

+ 26 - 27
mybatis-plus/src/main/java/com/baomidou/mybatisplus/plugins/OptimisticLockerInterceptor.java

@@ -17,7 +17,6 @@ import java.util.concurrent.ConcurrentHashMap;
 import org.apache.ibatis.binding.MapperMethod.ParamMap;
 import org.apache.ibatis.exceptions.ExceptionFactory;
 import org.apache.ibatis.executor.Executor;
-import org.apache.ibatis.javassist.util.proxy.ProxyFactory;
 import org.apache.ibatis.mapping.BoundSql;
 import org.apache.ibatis.mapping.MappedStatement;
 import org.apache.ibatis.mapping.ParameterMapping;
@@ -29,17 +28,18 @@ import org.apache.ibatis.plugin.Invocation;
 import org.apache.ibatis.plugin.Plugin;
 import org.apache.ibatis.plugin.Signature;
 import org.apache.ibatis.reflection.MetaObject;
-import org.apache.ibatis.reflection.SystemMetaObject;
 import org.apache.ibatis.session.Configuration;
 import org.apache.ibatis.type.TypeException;
 import org.apache.ibatis.type.UnknownTypeHandler;
 
 import com.baomidou.mybatisplus.annotations.TableName;
 import com.baomidou.mybatisplus.annotations.Version;
+import com.baomidou.mybatisplus.mapper.EntityWrapper;
 import com.baomidou.mybatisplus.toolkit.StringUtils;
 
 import net.sf.jsqlparser.expression.BinaryExpression;
 import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.JdbcParameter;
 import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
 import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
 import net.sf.jsqlparser.parser.CCJSqlParserUtil;
@@ -99,20 +99,17 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 		}
 		// 获得参数类型,去缓存中快速判断是否有version注解才继续执行
 		Class<? extends Object> parameterClass = parameterObject.getClass();
-		Class<?> realClass = null;
+		Class<?> realClass = ms.getParameterMap().getType();
+		;
+		Object realParameterObject = parameterObject;
 		if (parameterObject instanceof ParamMap) {
-			// FIXME
-			ParamMap<?> tt = (ParamMap<?>) parameterObject;
-			realClass = tt.get("param1").getClass();
-		} else if (ProxyFactory.isProxyClass(parameterClass)) {
-			realClass = parameterClass.getSuperclass();
-		} else {
-			realClass = parameterClass;
+			EntityWrapper<?> tt = (EntityWrapper<?>) ((ParamMap<?>) parameterObject).get("ew");
+			realParameterObject = tt.getEntity();
 		}
 		VersionCache versionPo = versionCache.get(realClass);
 		if (versionPo != null) {
 			if (versionPo.isVersionControl) {
-				return processChangeSql(ms, parameterObject, versionPo, invocation);
+				return processChangeSql(ms, parameterObject, realParameterObject, versionPo, invocation);
 			}
 		} else {
 			String versionColumn = null;
@@ -136,7 +133,7 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 				versionField.setAccessible(true);
 				VersionCache cachePo = new VersionCache(true, versionColumn, versionField);
 				versionCache.put(parameterClass, cachePo);
-				return processChangeSql(ms, parameterObject, cachePo, invocation);
+				return processChangeSql(ms, parameterObject, realParameterObject, cachePo, invocation);
 			} else {
 				versionCache.put(parameterClass, new VersionCache(false));
 			}
@@ -145,43 +142,48 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 
 	}
 
+	private static final Expression JDBCPARAMETER = new JdbcParameter();
+	private static final Expression RIGHTEXPRESSION = new Column("?");
+
 	@SuppressWarnings({ "rawtypes", "unchecked" })
-	private Object processChangeSql(MappedStatement ms, Object parameterObject, VersionCache versionPo, Invocation invocation) throws Exception {
+	private Object processChangeSql(MappedStatement ms, Object parameterObject, Object realParameterObject, VersionCache versionPo, Invocation invocation) throws Exception {
 		Field versionField = versionPo.versionField;
 		String versionColumn = versionPo.versionColumn;
-		final Object versionValue = versionField.get(parameterObject);
+		final Object versionValue = versionField.get(realParameterObject);
 		if (versionValue != null) {// 先判断传参是否携带version,没带跳过插件
 			Configuration configuration = ms.getConfiguration();
 			BoundSql originBoundSql = ms.getBoundSql(parameterObject);
 			SqlSource originSqlSource = ms.getSqlSource();
-			MetaObject metaObject = SystemMetaObject.forObject(ms);
-			// 解析sql,预处理更新字段没有version字段的情况
+			MetaObject metaObject = configuration.newMetaObject(ms);
 			try {
+				// 处理
 				Update jsqlSql = (Update) CCJSqlParserUtil.parse(originBoundSql.getSql());
 				List<Column> columns = jsqlSql.getColumns();
 				List<String> columnNames = new ArrayList<String>();
 				for (Column column : columns) {
 					columnNames.add(column.getColumnName());
 				}
+				List<Expression> expressions = jsqlSql.getExpressions();
 				if (!columnNames.contains(versionColumn)) {
 					columns.add(new Column(versionColumn));
 					jsqlSql.setColumns(columns);
+					expressions.add(JDBCPARAMETER);
+					jsqlSql.setExpressions(expressions);
 				}
-				// 添加条件
+				// 处理where条件,添加?
 				BinaryExpression expression = (BinaryExpression) jsqlSql.getWhere();
 				if (expression != null && !expression.toString().contains(versionColumn)) {
 					EqualsTo equalsTo = new EqualsTo();
 					equalsTo.setLeftExpression(new Column(versionColumn));
-					Expression rightExpression = new Column("?");
-					equalsTo.setRightExpression(rightExpression);
+					equalsTo.setRightExpression(RIGHTEXPRESSION);
 					jsqlSql.setWhere(new AndExpression(equalsTo, expression));
 				}
 				// 给字段赋新值
 				VersionHandler targetHandler = typeHandlers.get(versionField.getType());
-				targetHandler.plusVersion(parameterObject, versionField, versionValue);
+				targetHandler.plusVersion(realParameterObject, versionField, versionValue);
 				// 设置sqlSource
 				List<ParameterMapping> parameterMappings = new LinkedList<ParameterMapping>(originBoundSql.getParameterMappings());
-				parameterMappings.add(jsqlSql.getExpressions().size(), createVersionMapping(configuration));
+				parameterMappings.add(expressions.size(), createVersionMapping(configuration));
 				Map<String, Object> additionalParameters = new HashMap<String, Object>();
 				additionalParameters.put("originVersionValue", versionValue);
 				SqlSource sqlSource = new OptimisticLockerSqlSource(configuration, jsqlSql.toString(), parameterMappings, additionalParameters);
@@ -242,6 +244,9 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 		typeHandlers.put(type, versionHandler);
 	}
 
+	/**
+	 * 缓存对象
+	 */
 	private class VersionCache {
 
 		private Boolean isVersionControl;
@@ -298,33 +303,27 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 	}
 
 	private static class IntegerTypeHnadler implements VersionHandler<Integer> {
-
 		public void plusVersion(Object paramObj, Field field, Integer versionValue) throws Exception {
 			field.set(paramObj, versionValue + 1);
 		}
 	}
 
 	private static class LongTypeHnadler implements VersionHandler<Long> {
-
 		public void plusVersion(Object paramObj, Field field, Long versionValue) throws Exception {
 			field.set(paramObj, versionValue + 1);
 		}
-
 	}
 
 	// ***************************** 时间类型处理器*****************************
 	private static class DateTypeHandler implements VersionHandler<Date> {
-
 		public void plusVersion(Object paramObj, Field field, Date versionValue) throws Exception {
 			field.set(paramObj, new Date());
 		}
 	}
 
 	private static class TimestampTypeHandler implements VersionHandler<Timestamp> {
-
 		public void plusVersion(Object paramObj, Field field, Timestamp versionValue) throws Exception {
 			field.set(paramObj, new Timestamp(new Date().getTime()));
 		}
 	}
-
 }

+ 4 - 2
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/plugins/optimisticLocker/OptimisticLockerInterceptorTest.java

@@ -17,6 +17,7 @@ import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
 
+import com.baomidou.mybatisplus.mapper.EntityWrapper;
 import com.baomidou.mybatisplus.test.plugins.optimisticLocker.entity.DateVersionUser;
 import com.baomidou.mybatisplus.test.plugins.optimisticLocker.entity.IntVersionUser;
 import com.baomidou.mybatisplus.test.plugins.optimisticLocker.entity.LongVersionUser;
@@ -158,8 +159,9 @@ public class OptimisticLockerInterceptorTest {
 		IntVersionUser versionUser = intVersionUserMapper.selectById(2);
 		Integer originVersion = versionUser.getVersion();
 		// 更新数据
-		versionUser.setName("苗神");
-		intVersionUserMapper.updateById(versionUser);
+		IntVersionUser intVersionUser = new IntVersionUser();
+		intVersionUser.setName("苗神");
+		intVersionUserMapper.update(versionUser, new EntityWrapper<IntVersionUser>(versionUser));
 		Assert.assertTrue(versionUser.getVersion() == originVersion + 1);
 	}
 }