|
@@ -17,12 +17,19 @@ package com.baomidou.mybatisplus.extension.plugins.inner;
|
|
|
|
|
|
import com.baomidou.mybatisplus.annotation.Version;
|
|
|
import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
|
|
|
+import com.baomidou.mybatisplus.core.conditions.ISqlSegment;
|
|
|
+import com.baomidou.mybatisplus.core.conditions.Wrapper;
|
|
|
+import com.baomidou.mybatisplus.core.conditions.segments.NormalSegmentList;
|
|
|
+import com.baomidou.mybatisplus.core.conditions.update.Update;
|
|
|
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
|
|
|
+import com.baomidou.mybatisplus.core.enums.SqlKeyword;
|
|
|
+import com.baomidou.mybatisplus.core.mapper.Mapper;
|
|
|
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
|
|
|
import com.baomidou.mybatisplus.core.metadata.TableInfo;
|
|
|
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
|
|
|
import com.baomidou.mybatisplus.core.toolkit.Constants;
|
|
|
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
|
|
|
+import com.baomidou.mybatisplus.core.toolkit.ReflectionKit;
|
|
|
import com.baomidou.mybatisplus.core.toolkit.StringPool;
|
|
|
import org.apache.ibatis.executor.Executor;
|
|
|
import org.apache.ibatis.mapping.MappedStatement;
|
|
@@ -34,6 +41,10 @@ import java.sql.Timestamp;
|
|
|
import java.time.LocalDateTime;
|
|
|
import java.util.Date;
|
|
|
import java.util.Map;
|
|
|
+import java.util.Objects;
|
|
|
+import java.util.concurrent.ConcurrentHashMap;
|
|
|
+import java.util.regex.Matcher;
|
|
|
+import java.util.regex.Pattern;
|
|
|
|
|
|
/**
|
|
|
* Optimistic Lock Light version
|
|
@@ -63,6 +74,31 @@ public class OptimisticLockerInnerInterceptor implements InnerInterceptor {
|
|
|
this.exception = exception;
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * entity类缓存
|
|
|
+ */
|
|
|
+ private static final Map<String, Class<?>> ENTITY_CLASS_CACHE = new ConcurrentHashMap<>();
|
|
|
+ /**
|
|
|
+ * 变量占位符正则
|
|
|
+ */
|
|
|
+ private static final Pattern PARAM_PAIRS_RE = Pattern.compile("#\\{ew\\.paramNameValuePairs\\.(" + Constants.WRAPPER_PARAM + "\\d+)\\}");
|
|
|
+ /**
|
|
|
+ * paramNameValuePairs存放的version值的key
|
|
|
+ */
|
|
|
+ private static final String UPDATED_VERSION_VAL_KEY = "#updatedVersionVal#";
|
|
|
+ /**
|
|
|
+ * Support wrapper mode
|
|
|
+ */
|
|
|
+ private final boolean wrapperMode;
|
|
|
+
|
|
|
+ public OptimisticLockerInnerInterceptor() {
|
|
|
+ this(false);
|
|
|
+ }
|
|
|
+
|
|
|
+ public OptimisticLockerInnerInterceptor(boolean wrapperMode) {
|
|
|
+ this.wrapperMode = wrapperMode;
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
public void beforeUpdate(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
|
|
|
if (SqlCommandType.UPDATE != ms.getSqlCommandType()) {
|
|
@@ -75,17 +111,17 @@ public class OptimisticLockerInnerInterceptor implements InnerInterceptor {
|
|
|
}
|
|
|
|
|
|
protected void doOptimisticLocker(Map<String, Object> map, String msId) {
|
|
|
- //updateById(et), update(et, wrapper);
|
|
|
+ // updateById(et), update(et, wrapper);
|
|
|
Object et = map.getOrDefault(Constants.ENTITY, null);
|
|
|
- if (et != null) {
|
|
|
- // entity
|
|
|
- String methodName = msId.substring(msId.lastIndexOf(StringPool.DOT) + 1);
|
|
|
- TableInfo tableInfo = TableInfoHelper.getTableInfo(et.getClass());
|
|
|
- if (tableInfo == null || !tableInfo.isWithVersion()) {
|
|
|
+ if (Objects.nonNull(et)) {
|
|
|
+
|
|
|
+ // version field
|
|
|
+ TableFieldInfo fieldInfo = this.getVersionFieldInfo(et.getClass());
|
|
|
+ if (null == fieldInfo) {
|
|
|
return;
|
|
|
}
|
|
|
+
|
|
|
try {
|
|
|
- TableFieldInfo fieldInfo = tableInfo.getVersionFieldInfo();
|
|
|
Field versionField = fieldInfo.getField();
|
|
|
// 旧的 version 值
|
|
|
Object originalVersionVal = versionField.get(et);
|
|
@@ -101,6 +137,7 @@ public class OptimisticLockerInnerInterceptor implements InnerInterceptor {
|
|
|
String versionColumn = fieldInfo.getColumn();
|
|
|
// 新的 version 值
|
|
|
Object updatedVersionVal = this.getUpdatedVersionVal(fieldInfo.getPropertyType(), originalVersionVal);
|
|
|
+ String methodName = msId.substring(msId.lastIndexOf(StringPool.DOT) + 1);
|
|
|
if ("update".equals(methodName)) {
|
|
|
AbstractWrapper<?, ?, ?> aw = (AbstractWrapper<?, ?, ?>) map.getOrDefault(Constants.WRAPPER, null);
|
|
|
if (aw == null) {
|
|
@@ -118,6 +155,124 @@ public class OptimisticLockerInnerInterceptor implements InnerInterceptor {
|
|
|
throw ExceptionUtils.mpe(e);
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ // update(LambdaUpdateWrapper) or update(UpdateWrapper)
|
|
|
+ else if (wrapperMode && map.entrySet().stream().anyMatch(t -> Objects.equals(t.getKey(), Constants.WRAPPER))) {
|
|
|
+ setVersionByWrapper(map, msId);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ protected TableFieldInfo getVersionFieldInfo(Class<?> entityClazz) {
|
|
|
+ TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClazz);
|
|
|
+ return (null != tableInfo && tableInfo.isWithVersion()) ? tableInfo.getVersionFieldInfo() : null;
|
|
|
+ }
|
|
|
+
|
|
|
+ private void setVersionByWrapper(Map<String, Object> map, String msId) {
|
|
|
+ Object ew = map.get(Constants.WRAPPER);
|
|
|
+ if (null != ew && ew instanceof AbstractWrapper && ew instanceof Update) {
|
|
|
+ Class<?> entityClass = ENTITY_CLASS_CACHE.get(msId);
|
|
|
+ if (null == entityClass) {
|
|
|
+ try {
|
|
|
+ final String className = msId.substring(0, msId.lastIndexOf('.'));
|
|
|
+ entityClass = ReflectionKit.getSuperClassGenericType(Class.forName(className), Mapper.class, 0);
|
|
|
+ ENTITY_CLASS_CACHE.put(msId, entityClass);
|
|
|
+ } catch (ClassNotFoundException e) {
|
|
|
+ throw ExceptionUtils.mpe(e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ final TableFieldInfo versionField = getVersionFieldInfo(entityClass);
|
|
|
+ if (null == versionField) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ final String versionColumn = versionField.getColumn();
|
|
|
+ final FieldEqFinder fieldEqFinder = new FieldEqFinder(versionColumn, (Wrapper<?>) ew);
|
|
|
+ if (!fieldEqFinder.isPresent()) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ final Map<String, Object> paramNameValuePairs = ((AbstractWrapper<?, ?, ?>) ew).getParamNameValuePairs();
|
|
|
+ final Object originalVersionValue = paramNameValuePairs.get(fieldEqFinder.valueKey);
|
|
|
+ if (originalVersionValue == null) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ final Object updatedVersionVal = getUpdatedVersionVal(originalVersionValue.getClass(), originalVersionValue);
|
|
|
+ if (originalVersionValue == updatedVersionVal) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ // 拼接新的version值
|
|
|
+ paramNameValuePairs.put(UPDATED_VERSION_VAL_KEY, updatedVersionVal);
|
|
|
+ ((Update<?, ?>) ew).setSql(String.format("%s = #{%s.%s}", versionColumn, "ew.paramNameValuePairs", UPDATED_VERSION_VAL_KEY));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * EQ字段查找器
|
|
|
+ */
|
|
|
+ private static class FieldEqFinder {
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 状态机
|
|
|
+ */
|
|
|
+ enum State {
|
|
|
+ INIT,
|
|
|
+ FIELD_FOUND,
|
|
|
+ EQ_FOUND,
|
|
|
+ VERSION_VALUE_PRESENT;
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 字段值的key
|
|
|
+ */
|
|
|
+ private String valueKey;
|
|
|
+ /**
|
|
|
+ * 当前状态
|
|
|
+ */
|
|
|
+ private State state;
|
|
|
+ /**
|
|
|
+ * 字段名
|
|
|
+ */
|
|
|
+ private final String fieldName;
|
|
|
+
|
|
|
+ public FieldEqFinder(String fieldName, Wrapper<?> wrapper) {
|
|
|
+ this.fieldName = fieldName;
|
|
|
+ state = State.INIT;
|
|
|
+ find(wrapper);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 是否已存在
|
|
|
+ */
|
|
|
+ public boolean isPresent() {
|
|
|
+ return state == State.VERSION_VALUE_PRESENT;
|
|
|
+ }
|
|
|
+
|
|
|
+ private boolean find(Wrapper<?> wrapper) {
|
|
|
+ Matcher matcher;
|
|
|
+ final NormalSegmentList segments = wrapper.getExpression().getNormal();
|
|
|
+ for (ISqlSegment segment : segments) {
|
|
|
+ // 如果字段已找到并且当前segment为EQ
|
|
|
+ if (state == State.FIELD_FOUND && segment == SqlKeyword.EQ) {
|
|
|
+ this.state = State.EQ_FOUND;
|
|
|
+ // 如果EQ找到并且value已找到
|
|
|
+ } else if (state == State.EQ_FOUND
|
|
|
+ && (matcher = PARAM_PAIRS_RE.matcher(segment.getSqlSegment())).matches()) {
|
|
|
+ this.valueKey = matcher.group(1);
|
|
|
+ this.state = State.VERSION_VALUE_PRESENT;
|
|
|
+ return true;
|
|
|
+ // 处理嵌套
|
|
|
+ } else if (segment instanceof Wrapper) {
|
|
|
+ if (find((Wrapper<?>) segment)) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ // 判断字段是否是要查找字段
|
|
|
+ } else if (segment.getSqlSegment().equals(this.fieldName)) {
|
|
|
+ this.state = State.FIELD_FOUND;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/**
|