Explorar o código

乐观锁插件支持根据wrapper填充 github pull/3664

hubin %!s(int64=3) %!d(string=hai) anos
pai
achega
9a70683c31

+ 162 - 7
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/inner/OptimisticLockerInnerInterceptor.java

@@ -17,12 +17,19 @@ package com.baomidou.mybatisplus.extension.plugins.inner;
 
 import com.baomidou.mybatisplus.annotation.Version;
 import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
+import com.baomidou.mybatisplus.core.conditions.ISqlSegment;
+import com.baomidou.mybatisplus.core.conditions.Wrapper;
+import com.baomidou.mybatisplus.core.conditions.segments.NormalSegmentList;
+import com.baomidou.mybatisplus.core.conditions.update.Update;
 import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
+import com.baomidou.mybatisplus.core.enums.SqlKeyword;
+import com.baomidou.mybatisplus.core.mapper.Mapper;
 import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
 import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
 import com.baomidou.mybatisplus.core.toolkit.Constants;
 import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
+import com.baomidou.mybatisplus.core.toolkit.ReflectionKit;
 import com.baomidou.mybatisplus.core.toolkit.StringPool;
 import org.apache.ibatis.executor.Executor;
 import org.apache.ibatis.mapping.MappedStatement;
@@ -34,6 +41,10 @@ import java.sql.Timestamp;
 import java.time.LocalDateTime;
 import java.util.Date;
 import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 
 /**
  * Optimistic Lock Light version
@@ -63,6 +74,31 @@ public class OptimisticLockerInnerInterceptor implements InnerInterceptor {
         this.exception = exception;
     }
 
+    /**
+     * entity类缓存
+     */
+    private static final Map<String, Class<?>> ENTITY_CLASS_CACHE = new ConcurrentHashMap<>();
+    /**
+     * 变量占位符正则
+     */
+    private static final Pattern PARAM_PAIRS_RE = Pattern.compile("#\\{ew\\.paramNameValuePairs\\.(" + Constants.WRAPPER_PARAM + "\\d+)\\}");
+    /**
+     * paramNameValuePairs存放的version值的key
+     */
+    private static final String UPDATED_VERSION_VAL_KEY = "#updatedVersionVal#";
+    /**
+     * Support wrapper mode
+     */
+    private final boolean wrapperMode;
+
+    public OptimisticLockerInnerInterceptor() {
+        this(false);
+    }
+
+    public OptimisticLockerInnerInterceptor(boolean wrapperMode) {
+        this.wrapperMode = wrapperMode;
+    }
+
     @Override
     public void beforeUpdate(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
         if (SqlCommandType.UPDATE != ms.getSqlCommandType()) {
@@ -75,17 +111,17 @@ public class OptimisticLockerInnerInterceptor implements InnerInterceptor {
     }
 
     protected void doOptimisticLocker(Map<String, Object> map, String msId) {
-        //updateById(et), update(et, wrapper);
+        // updateById(et), update(et, wrapper);
         Object et = map.getOrDefault(Constants.ENTITY, null);
-        if (et != null) {
-            // entity
-            String methodName = msId.substring(msId.lastIndexOf(StringPool.DOT) + 1);
-            TableInfo tableInfo = TableInfoHelper.getTableInfo(et.getClass());
-            if (tableInfo == null || !tableInfo.isWithVersion()) {
+        if (Objects.nonNull(et)) {
+
+            // version field
+            TableFieldInfo fieldInfo = this.getVersionFieldInfo(et.getClass());
+            if (null == fieldInfo) {
                 return;
             }
+
             try {
-                TableFieldInfo fieldInfo = tableInfo.getVersionFieldInfo();
                 Field versionField = fieldInfo.getField();
                 // 旧的 version 值
                 Object originalVersionVal = versionField.get(et);
@@ -101,6 +137,7 @@ public class OptimisticLockerInnerInterceptor implements InnerInterceptor {
                 String versionColumn = fieldInfo.getColumn();
                 // 新的 version 值
                 Object updatedVersionVal = this.getUpdatedVersionVal(fieldInfo.getPropertyType(), originalVersionVal);
+                String methodName = msId.substring(msId.lastIndexOf(StringPool.DOT) + 1);
                 if ("update".equals(methodName)) {
                     AbstractWrapper<?, ?, ?> aw = (AbstractWrapper<?, ?, ?>) map.getOrDefault(Constants.WRAPPER, null);
                     if (aw == null) {
@@ -118,6 +155,124 @@ public class OptimisticLockerInnerInterceptor implements InnerInterceptor {
                 throw ExceptionUtils.mpe(e);
             }
         }
+
+        // update(LambdaUpdateWrapper) or update(UpdateWrapper)
+        else if (wrapperMode && map.entrySet().stream().anyMatch(t -> Objects.equals(t.getKey(), Constants.WRAPPER))) {
+            setVersionByWrapper(map, msId);
+        }
+    }
+
+    protected TableFieldInfo getVersionFieldInfo(Class<?> entityClazz) {
+        TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClazz);
+        return (null != tableInfo && tableInfo.isWithVersion()) ? tableInfo.getVersionFieldInfo() : null;
+    }
+
+    private void setVersionByWrapper(Map<String, Object> map, String msId) {
+        Object ew = map.get(Constants.WRAPPER);
+        if (null != ew && ew instanceof AbstractWrapper && ew instanceof Update) {
+            Class<?> entityClass = ENTITY_CLASS_CACHE.get(msId);
+            if (null == entityClass) {
+                try {
+                    final String className = msId.substring(0, msId.lastIndexOf('.'));
+                    entityClass = ReflectionKit.getSuperClassGenericType(Class.forName(className), Mapper.class, 0);
+                    ENTITY_CLASS_CACHE.put(msId, entityClass);
+                } catch (ClassNotFoundException e) {
+                    throw ExceptionUtils.mpe(e);
+                }
+            }
+
+            final TableFieldInfo versionField = getVersionFieldInfo(entityClass);
+            if (null == versionField) {
+                return;
+            }
+
+            final String versionColumn = versionField.getColumn();
+            final FieldEqFinder fieldEqFinder = new FieldEqFinder(versionColumn, (Wrapper<?>) ew);
+            if (!fieldEqFinder.isPresent()) {
+                return;
+            }
+            final Map<String, Object> paramNameValuePairs = ((AbstractWrapper<?, ?, ?>) ew).getParamNameValuePairs();
+            final Object originalVersionValue = paramNameValuePairs.get(fieldEqFinder.valueKey);
+            if (originalVersionValue == null) {
+                return;
+            }
+            final Object updatedVersionVal = getUpdatedVersionVal(originalVersionValue.getClass(), originalVersionValue);
+            if (originalVersionValue == updatedVersionVal) {
+                return;
+            }
+            // 拼接新的version值
+            paramNameValuePairs.put(UPDATED_VERSION_VAL_KEY, updatedVersionVal);
+            ((Update<?, ?>) ew).setSql(String.format("%s = #{%s.%s}", versionColumn, "ew.paramNameValuePairs", UPDATED_VERSION_VAL_KEY));
+        }
+    }
+
+    /**
+     * EQ字段查找器
+     */
+    private static class FieldEqFinder {
+
+        /**
+         * 状态机
+         */
+        enum State {
+            INIT,
+            FIELD_FOUND,
+            EQ_FOUND,
+            VERSION_VALUE_PRESENT;
+
+        }
+
+        /**
+         * 字段值的key
+         */
+        private String valueKey;
+        /**
+         * 当前状态
+         */
+        private State state;
+        /**
+         * 字段名
+         */
+        private final String fieldName;
+
+        public FieldEqFinder(String fieldName, Wrapper<?> wrapper) {
+            this.fieldName = fieldName;
+            state = State.INIT;
+            find(wrapper);
+        }
+
+        /**
+         * 是否已存在
+         */
+        public boolean isPresent() {
+            return state == State.VERSION_VALUE_PRESENT;
+        }
+
+        private boolean find(Wrapper<?> wrapper) {
+            Matcher matcher;
+            final NormalSegmentList segments = wrapper.getExpression().getNormal();
+            for (ISqlSegment segment : segments) {
+                // 如果字段已找到并且当前segment为EQ
+                if (state == State.FIELD_FOUND && segment == SqlKeyword.EQ) {
+                    this.state = State.EQ_FOUND;
+                    // 如果EQ找到并且value已找到
+                } else if (state == State.EQ_FOUND
+                    && (matcher = PARAM_PAIRS_RE.matcher(segment.getSqlSegment())).matches()) {
+                    this.valueKey = matcher.group(1);
+                    this.state = State.VERSION_VALUE_PRESENT;
+                    return true;
+                    // 处理嵌套
+                } else if (segment instanceof Wrapper) {
+                    if (find((Wrapper<?>) segment)) {
+                        return true;
+                    }
+                    // 判断字段是否是要查找字段
+                } else if (segment.getSqlSegment().equals(this.fieldName)) {
+                    this.state = State.FIELD_FOUND;
+                }
+            }
+            return false;
+        }
     }
 
     /**

+ 2 - 1
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/spring/MybatisSqlSessionFactoryBean.java

@@ -604,10 +604,11 @@ public class MybatisSqlSessionFactoryBean implements FactoryBean<SqlSessionFacto
         }
 
         final SqlSessionFactory sqlSessionFactory = new MybatisSqlSessionFactoryBuilder().build(targetConfiguration);
+
         // TODO SqlRunner
         SqlHelper.FACTORY = sqlSessionFactory;
 
-        // TODO 打印骚东西 Banner
+        // TODO 打印 Banner
         if (globalConfig.isBanner()) {
             System.out.println(" _ _   |_  _ _|_. ___ _ |    _ ");
             System.out.println("| | |\\/|_)(_| | |_\\  |_)||_|_\\ ");

+ 40 - 2
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/version/VersionTest.java

@@ -13,12 +13,50 @@ import java.util.List;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
+import lombok.extern.slf4j.Slf4j;
+
 /**
  * @author miemie
+ * @author raylax
  * @since 2020-07-04
  */
+@Slf4j
 public class VersionTest extends BaseDbTest<EntityMapper> {
 
+
+    @Test
+    void testWrapperMode() {
+        log.info("[wrapper mode] test");
+
+        doTestAutoCommit(i -> {
+            int result = i.update(null, Wrappers.<Entity>update()
+                .eq("id", 3)
+                .set("version", 1)
+            );
+            assertThat(result).as("[wrapper mode] 设置version值成功").isEqualTo(1);
+        });
+
+        doTestAutoCommit(i -> {
+            int result = i.update(null, Wrappers.<Entity>update()
+                .eq("id", 3)
+                .eq("version", 1)
+            );
+            assertThat(result).as("[wrapper mode] 设置version值匹配更新成功").isEqualTo(1);
+            final Entity entity = i.selectById(3);
+            assertThat(entity.getVersion()).isEqualTo(2);
+        });
+
+        doTestAutoCommit(i -> {
+            int result = i.update(null, Wrappers.<Entity>update()
+                .eq("id", 3)
+                .eq("version", 1)
+            );
+            assertThat(result).as("[wrapper mode] 设置version值匹配更新失败").isEqualTo(0);
+            final Entity entity = i.selectById(3);
+            assertThat(entity.getVersion()).isEqualTo(2);
+        });
+    }
+
     @Test
     void test() {
         doTestAutoCommit(i -> {
@@ -59,13 +97,13 @@ public class VersionTest extends BaseDbTest<EntityMapper> {
     @Override
     protected List<Interceptor> interceptors() {
         MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
-        interceptor.addInnerInterceptor(new OptimisticLockerInnerInterceptor());
+        interceptor.addInnerInterceptor(new OptimisticLockerInnerInterceptor(true));
         return Collections.singletonList(interceptor);
     }
 
     @Override
     protected String tableDataSql() {
-        return "insert into entity(id,name) values(1,'老王'),(2,'老李')";
+        return "insert into entity(id,name) values(1,'老王'),(2,'老李'),(3,'老赵')";
     }
 
     @Override