浏览代码

多线程下问题修复, 仍不能解决多参数问题.

小锅盖 8 年之前
父节点
当前提交
efc122f19d

+ 99 - 61
mybatis-plus/src/main/java/com/baomidou/mybatisplus/plugins/OptimisticLockerInterceptor.java

@@ -48,6 +48,7 @@ import net.sf.jsqlparser.statement.update.Update;
 
 /**
  * MyBatis乐观锁插件
+ * <p>
  * 
  * <pre>
  * 之前:update user set name = ?, password = ? where id = ?
@@ -57,7 +58,7 @@ import net.sf.jsqlparser.statement.update.Update;
  * 支持short,Short,int Integer, long Long, Date Timestamp
  * 其他类型可以自定义实现,注入versionHandlers,多个以逗号分隔
  * </pre>
- * 
+ *
  * @author TaoYu
  */
 @Intercepts({ @Signature(type = Executor.class, method = "update", args = { MappedStatement.class, Object.class }) })
@@ -66,7 +67,7 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 	/**
 	 * 根据对象类型缓存version基本信息
 	 */
-	private static final Map<Class<?>, VersionCache> versionCache = new ConcurrentHashMap<Class<?>, VersionCache>();
+	private static final Map<Class<?>, CachePo> versionCache = new ConcurrentHashMap<Class<?>, CachePo>();
 
 	/**
 	 * 根据version字段类型缓存的处理器
@@ -74,11 +75,11 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 	private static final Map<Class<?>, VersionHandler<?>> typeHandlers = new HashMap<Class<?>, VersionHandler<?>>();
 
 	static {
-		ShortTypeHnadler shortTypeHnadler = new ShortTypeHnadler();
+		ShortTypeHandler shortTypeHnadler = new ShortTypeHandler();
 		typeHandlers.put(short.class, shortTypeHnadler);
 		typeHandlers.put(Short.class, shortTypeHnadler);
 
-		IntegerTypeHnadler integerTypeHnadler = new IntegerTypeHnadler();
+		IntegerTypeHandler integerTypeHnadler = new IntegerTypeHandler();
 		typeHandlers.put(int.class, integerTypeHnadler);
 		typeHandlers.put(Integer.class, integerTypeHnadler);
 
@@ -100,13 +101,12 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 		// 获得参数类型,去缓存中快速判断是否有version注解才继续执行
 		Class<? extends Object> parameterClass = parameterObject.getClass();
 		Class<?> realClass = ms.getParameterMap().getType();
-		;
 		Object realParameterObject = parameterObject;
 		if (parameterObject instanceof ParamMap) {
 			EntityWrapper<?> tt = (EntityWrapper<?>) ((ParamMap<?>) parameterObject).get("ew");
 			realParameterObject = tt.getEntity();
 		}
-		VersionCache versionPo = versionCache.get(realClass);
+		CachePo versionPo = versionCache.get(realClass);
 		if (versionPo != null) {
 			if (versionPo.isVersionControl) {
 				return processChangeSql(ms, parameterObject, realParameterObject, versionPo, invocation);
@@ -131,11 +131,11 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 			}
 			if (versionField != null) {
 				versionField.setAccessible(true);
-				VersionCache cachePo = new VersionCache(true, versionColumn, versionField);
+				CachePo cachePo = new CachePo(true, versionColumn, versionField);
 				versionCache.put(parameterClass, cachePo);
 				return processChangeSql(ms, parameterObject, realParameterObject, cachePo, invocation);
 			} else {
-				versionCache.put(parameterClass, new VersionCache(false));
+				versionCache.put(parameterClass, new CachePo(false));
 			}
 		}
 		return invocation.proceed();
@@ -146,55 +146,67 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 	private static final Expression RIGHTEXPRESSION = new Column("?");
 
 	@SuppressWarnings({ "rawtypes", "unchecked" })
-	private Object processChangeSql(MappedStatement ms, Object parameterObject, Object realParameterObject, VersionCache versionPo, Invocation invocation) throws Exception {
-		Field versionField = versionPo.versionField;
-		String versionColumn = versionPo.versionColumn;
+	private Object processChangeSql(MappedStatement ms, Object parameterObject, Object realParameterObject, CachePo cachePo, Invocation invocation) throws Exception {
+		Field versionField = cachePo.versionField;
+		String versionColumn = cachePo.versionColumn;
 		final Object versionValue = versionField.get(realParameterObject);
+		SqlSource originSqlSource;
+		if (cachePo.getSqlSource() == null) {
+			originSqlSource = ms.getSqlSource();
+			cachePo.setSqlSource(originSqlSource);
+		} else {
+			originSqlSource = cachePo.getSqlSource();
+		}
+		Configuration configuration = ms.getConfiguration();
+		MetaObject metaObject = configuration.newMetaObject(ms);
 		if (versionValue != null) {// 先判断传参是否携带version,没带跳过插件
-			Configuration configuration = ms.getConfiguration();
-			BoundSql originBoundSql = ms.getBoundSql(parameterObject);
-			SqlSource originSqlSource = ms.getSqlSource();
-			MetaObject metaObject = configuration.newMetaObject(ms);
-			try {
-				// 处理
-				Update jsqlSql = (Update) CCJSqlParserUtil.parse(originBoundSql.getSql());
-				List<Column> columns = jsqlSql.getColumns();
-				List<String> columnNames = new ArrayList<String>();
-				for (Column column : columns) {
-					columnNames.add(column.getColumnName());
-				}
-				List<Expression> expressions = jsqlSql.getExpressions();
-				if (!columnNames.contains(versionColumn)) {
-					columns.add(new Column(versionColumn));
-					jsqlSql.setColumns(columns);
-					expressions.add(JDBCPARAMETER);
-					jsqlSql.setExpressions(expressions);
-				}
-				// 处理where条件,添加?
-				BinaryExpression expression = (BinaryExpression) jsqlSql.getWhere();
-				if (expression != null && !expression.toString().contains(versionColumn)) {
-					EqualsTo equalsTo = new EqualsTo();
-					equalsTo.setLeftExpression(new Column(versionColumn));
-					equalsTo.setRightExpression(RIGHTEXPRESSION);
-					jsqlSql.setWhere(new AndExpression(equalsTo, expression));
-				}
-				// 给字段赋新值
-				VersionHandler targetHandler = typeHandlers.get(versionField.getType());
-				targetHandler.plusVersion(realParameterObject, versionField, versionValue);
-				// 设置sqlSource
+			BoundSql originBoundSql = originSqlSource.getBoundSql(parameterObject);
+			// 处理
+			Update jsqlSql = (Update) CCJSqlParserUtil.parse(originBoundSql.getSql());
+			List<Column> columns = jsqlSql.getColumns();
+			List<String> columnNames = new ArrayList<String>();
+			for (Column column : columns) {
+				columnNames.add(column.getColumnName());
+			}
+			List<Expression> expressions = jsqlSql.getExpressions();
+			if (!columnNames.contains(versionColumn)) {
+				columns.add(new Column(versionColumn));
+				jsqlSql.setColumns(columns);
+				expressions.add(JDBCPARAMETER);
+				jsqlSql.setExpressions(expressions);
+			}
+
+			// 给字段赋新值
+			VersionHandler targetHandler = typeHandlers.get(versionField.getType());
+			targetHandler.plusVersion(realParameterObject, versionField, versionValue);
+			// 设置sqlSource
+			SqlSource sqlSource = ms.getSqlSource();
+			if (!(sqlSource instanceof OptimisticLockerSqlSource)) {
+				sqlSource = new OptimisticLockerSqlSource(configuration);
+				metaObject.setValue("sqlSource", sqlSource);
+			}
+			OptimisticLockerSqlSource optimisticLockerSqlSource = (OptimisticLockerSqlSource) sqlSource;
+			// 处理where条件,添加?
+			BinaryExpression expression = (BinaryExpression) jsqlSql.getWhere();
+			if (expression != null && !expression.toString().contains(versionColumn)) {
+				EqualsTo equalsTo = new EqualsTo();
+				equalsTo.setLeftExpression(new Column(versionColumn));
+				equalsTo.setRightExpression(RIGHTEXPRESSION);
+				jsqlSql.setWhere(new AndExpression(equalsTo, expression));
 				List<ParameterMapping> parameterMappings = new LinkedList<ParameterMapping>(originBoundSql.getParameterMappings());
 				parameterMappings.add(expressions.size(), createVersionMapping(configuration));
-				Map<String, Object> additionalParameters = new HashMap<String, Object>();
-				additionalParameters.put("originVersionValue", versionValue);
-				SqlSource sqlSource = new OptimisticLockerSqlSource(configuration, jsqlSql.toString(), parameterMappings, additionalParameters);
-				metaObject.setValue("sqlSource", sqlSource);
-				return invocation.proceed();
-			} catch (Exception e) {
-				throw ExceptionFactory.wrapException("乐观锁插件执行失败", e);
-			} finally {
-				metaObject.setValue("sqlSource", originSqlSource);
+				optimisticLockerSqlSource.setParameterMappings(parameterMappings);
+			} else {
+				optimisticLockerSqlSource.setParameterMappings(originBoundSql.getParameterMappings());
 			}
-
+			// 设置参数
+			Map<String, Object> additionalParameters = new HashMap<String, Object>();
+			additionalParameters.put("originVersionValue", versionValue);
+			additionalParameters.putAll((Map<String, Object>) configuration.newMetaObject(originBoundSql).getValue("additionalParameters"));
+			optimisticLockerSqlSource.setSql(jsqlSql.toString());
+			optimisticLockerSqlSource.setAdditionalParameters(additionalParameters);
+		} else {
+			metaObject.setValue("sqlSource", originSqlSource);
 		}
 		return invocation.proceed();
 	}
@@ -247,7 +259,7 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 	/**
 	 * 缓存对象
 	 */
-	private class VersionCache {
+	private class CachePo {
 
 		private Boolean isVersionControl;
 
@@ -255,15 +267,26 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 
 		private Field versionField;
 
-		public VersionCache(Boolean isVersionControl) {
+		private SqlSource sqlSource;
+
+		public CachePo(Boolean isVersionControl) {
 			this.isVersionControl = isVersionControl;
 		}
 
-		public VersionCache(Boolean isVersionControl, String versionColumn, Field versionField) {
+		public CachePo(Boolean isVersionControl, String versionColumn, Field versionField) {
 			this.isVersionControl = isVersionControl;
 			this.versionColumn = versionColumn;
 			this.versionField = versionField;
 		}
+
+		public SqlSource getSqlSource() {
+			return sqlSource;
+		}
+
+		public void setSqlSource(SqlSource sqlSource) {
+			this.sqlSource = sqlSource;
+		}
+
 	}
 
 	/**
@@ -276,11 +299,8 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 		private List<ParameterMapping> parameterMappings;
 		private Map<String, Object> additionalParameters;
 
-		public OptimisticLockerSqlSource(Configuration configuration, String sql, List<ParameterMapping> parameterMappings, Map<String, Object> additionalParameters) {
+		public OptimisticLockerSqlSource(Configuration configuration) {
 			this.configuration = configuration;
-			this.sql = sql;
-			this.parameterMappings = parameterMappings;
-			this.additionalParameters = additionalParameters;
 		}
 
 		public BoundSql getBoundSql(Object parameterObject) {
@@ -290,25 +310,41 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 					boundSql.setAdditionalParameter(item.getKey(), item.getValue());
 				}
 			}
+			additionalParameters.clear();
 			return boundSql;
 		}
 
+		public void setSql(String sql) {
+			this.sql = sql;
+		}
+
+		public void setParameterMappings(List<ParameterMapping> parameterMappings) {
+			this.parameterMappings = parameterMappings;
+		}
+
+		public void setAdditionalParameters(Map<String, Object> additionalParameters) {
+			this.additionalParameters = additionalParameters;
+		}
+
 	}
 	// *****************************基本类型处理器*****************************
 
-	private static class ShortTypeHnadler implements VersionHandler<Short> {
+	private static class ShortTypeHandler implements VersionHandler<Short> {
+
 		public void plusVersion(Object paramObj, Field field, Short versionValue) throws Exception {
 			field.set(paramObj, (short) (versionValue + 1));
 		}
 	}
 
-	private static class IntegerTypeHnadler implements VersionHandler<Integer> {
+	private static class IntegerTypeHandler implements VersionHandler<Integer> {
+
 		public void plusVersion(Object paramObj, Field field, Integer versionValue) throws Exception {
 			field.set(paramObj, versionValue + 1);
 		}
 	}
 
 	private static class LongTypeHnadler implements VersionHandler<Long> {
+
 		public void plusVersion(Object paramObj, Field field, Long versionValue) throws Exception {
 			field.set(paramObj, versionValue + 1);
 		}
@@ -316,12 +352,14 @@ public final class OptimisticLockerInterceptor implements Interceptor {
 
 	// ***************************** 时间类型处理器*****************************
 	private static class DateTypeHandler implements VersionHandler<Date> {
+
 		public void plusVersion(Object paramObj, Field field, Date versionValue) throws Exception {
 			field.set(paramObj, new Date());
 		}
 	}
 
 	private static class TimestampTypeHandler implements VersionHandler<Timestamp> {
+
 		public void plusVersion(Object paramObj, Field field, Timestamp versionValue) throws Exception {
 			field.set(paramObj, new Timestamp(new Date().getTime()));
 		}

+ 28 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/plugins/optimisticLocker/OptimisticLockerInterceptorTest.java

@@ -4,6 +4,7 @@ import java.io.Reader;
 import java.sql.Connection;
 import java.sql.Timestamp;
 import java.util.Date;
+import java.util.Random;
 
 import org.apache.ibatis.io.Resources;
 import org.apache.ibatis.jdbc.ScriptRunner;
@@ -34,6 +35,7 @@ import com.baomidou.mybatisplus.test.plugins.optimisticLocker.mapper.TimestampVe
 @RunWith(SpringJUnit4ClassRunner.class)
 @ContextConfiguration(locations = { "/plugins/optimisticLockerInterceptor.xml" })
 public class OptimisticLockerInterceptorTest {
+
 	@Autowired
 	private ShortVersionUserMapper shortVersionUserMapper;
 	@Autowired
@@ -153,6 +155,32 @@ public class OptimisticLockerInterceptorTest {
 		Assert.assertEquals(versionUser.getVersion(), String.valueOf(Long.parseLong(originVersion) + 1));
 	}
 
+	@Test
+	public void multiThreadVersionTest() {
+		final Random random = new Random();
+		for (int i = 50; i < 150; i++) {
+			new Thread(new Runnable() {
+				public void run() {
+					IntVersionUser intVersionUser = new IntVersionUser();
+					intVersionUser.setId(random.nextLong());
+					int version = random.nextInt();
+					intVersionUser.setName("改前" + version);
+					intVersionUser.setVersion(version);
+					intVersionUserMapper.insert(intVersionUser);
+					intVersionUser.setName("改后" + version);
+					intVersionUserMapper.updateById(intVersionUser);
+					Assert.assertTrue(intVersionUser.getVersion() == version + 1);
+				}
+			}, "编号" + i).start();
+		}
+
+		try {
+			Thread.sleep(4000);
+		} catch (InterruptedException e) {
+			e.printStackTrace();
+		}
+	}
+
 	@Test
 	public void multiParamVersionTest() {
 		// 查询数据