yuxiaobin 8 vuotta sitten
vanhempi
commit
da4ff8bb07

+ 58 - 3
mybatis-plus/src/main/java/com/baomidou/mybatisplus/plugins/OptimisticLockerInterceptor.java

@@ -35,6 +35,7 @@ import com.baomidou.mybatisplus.toolkit.StringUtils;
 
 import net.sf.jsqlparser.expression.BinaryExpression;
 import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
 import net.sf.jsqlparser.expression.LongValue;
 import net.sf.jsqlparser.expression.StringValue;
 import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
@@ -82,9 +83,16 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 		if (parameterObject == null || !ms.getSqlCommandType().equals(SqlCommandType.UPDATE)) {
 			return invocation.proceed();
 		}
+		
 		// 获得参数类型,去缓存中快速判断是否有version注解才继续执行
 		Class<? extends Object> parameterClass = parameterObject.getClass();
-		VersionCache versionPo = versionCache.get(parameterClass);
+		Class<?> realClass = null;
+		if(org.apache.ibatis.javassist.util.proxy.ProxyFactory.isProxyClass(parameterClass)){
+			realClass = parameterClass.getSuperclass();
+		}else{
+			realClass = parameterClass;
+		}
+		VersionCache versionPo = versionCache.get(realClass);
 		if (versionPo != null) {
 			if (versionPo.isVersionControl) {
 				processChangeSql(ms, parameterObject, versionPo);
@@ -92,7 +100,7 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 		} else {
 			String versionColumn = null;
 			Field versionField = null;
-			for (Field field : parameterClass.getDeclaredFields()) {
+			for (Field field : realClass.getDeclaredFields()) {
 				if (field.isAnnotationPresent(Version.class)) {
 					if (!typeHandlers.containsKey(field.getType())) {
 						throw new TypeException("乐观锁不支持" + field.getType().getName() + "类型,请自定义实现");
@@ -136,6 +144,42 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 			if (!columnNames.contains(versionColumn)) {// 如果sql没有version手动加一个
 				columns.add(new Column(versionColumn));
 				parse.setColumns(columns);
+			}else{
+				VersionHandler targetHandler = typeHandlers.get(versionField.getType());
+				Expression plusExpression = targetHandler.getPlusExpression(versionValue);
+				VersionExpSimpleTypeVisitor visitor = new VersionExpSimpleTypeVisitor(); 
+				plusExpression.accept(visitor);
+				if(visitor.isLongValue()){
+					Object versionNewValue = null;
+					String versionClassname = versionValue.getClass().getSimpleName();
+					switch (versionClassname) {
+					case "Long":
+						versionNewValue = Long.parseLong(plusExpression.toString());
+						break;
+					case "long":
+						versionNewValue = Long.parseLong(plusExpression.toString());
+						break;
+					case "Integer":
+						versionNewValue = Integer.parseInt(plusExpression.toString());
+						break;
+					case "int":
+						versionNewValue = Integer.parseInt(plusExpression.toString());
+						break;
+					case "Short":
+						versionNewValue = Short.parseShort(plusExpression.toString());
+						break;
+					case "short":
+						versionNewValue = Short.parseShort(plusExpression.toString());
+						break;
+					default:
+						versionNewValue = versionValue;//not support
+						break;
+					}
+					versionField.set(parameterObject, versionNewValue);
+				}else{
+					//TODO: 自定义VersionHandler处理 
+					
+				}
 			}
 			BinaryExpression expression = (BinaryExpression) parse.getWhere();
 			if (expression != null && !expression.toString().contains(versionColumn)) {
@@ -225,7 +269,7 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 		}
 
 		public Expression getPlusExpression(Object param) {
-			return new LongValue(param.toString() + 1);
+			return new LongValue(Long.parseLong(param.toString()) + 1);
 		}
 
 	}
@@ -266,5 +310,16 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 		}
 
 	}
+	
+	private static class VersionExpSimpleTypeVisitor extends ExpressionVisitorAdapter {
+		private boolean longValue = false;
+		@Override
+		public void visit(LongValue value) {
+			longValue = true;
+		}
+		public boolean isLongValue() {
+			return longValue;
+		}
+	}
 
 }

+ 5 - 1
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/plugin/OptimisticLockerInterceptorJUnitTest.java

@@ -8,6 +8,7 @@ import org.junit.runner.RunWith;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
 
+import com.baomidou.mybatisplus.mapper.EntityWrapper;
 import com.baomidou.mybatisplus.test.mysql.entity.User;
 
 
@@ -32,7 +33,7 @@ public class OptimisticLockerInterceptorJUnitTest extends UserTestBase{
 		userService.updateById(user);
 		user = userService.selectById(11);
 		Assert.assertEquals(2, user.getAge().intValue());
-		Assert.assertEquals(1, user.getVersion().intValue());
+		Assert.assertEquals(null, user.getVersion());
 	}
 	
 	@Test
@@ -42,6 +43,9 @@ public class OptimisticLockerInterceptorJUnitTest extends UserTestBase{
 		Assert.assertEquals(2, user.getAge().intValue());
 		user.setAge(3);
 		userService.updateById(user);
+		User where = new User();
+		where.setId(userId);
+		userService.update(user, new EntityWrapper<User>(where));
 		user = userService.selectById(userId);
 		Assert.assertEquals(3, user.getAge().intValue());
 		Assert.assertEquals(2, user.getVersion().intValue());