Browse Source

乐观锁 回写更新后的version到实体

yuxiaobin 7 years ago
parent
commit
a4b8dc7001

+ 28 - 24
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/OptimisticLockerInterceptor.java

@@ -1,5 +1,24 @@
 package com.baomidou.mybatisplus.extension.plugins;
 
+import java.lang.reflect.Field;
+import java.sql.Timestamp;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.ibatis.executor.Executor;
+import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.mapping.SqlCommandType;
+import org.apache.ibatis.plugin.Interceptor;
+import org.apache.ibatis.plugin.Intercepts;
+import org.apache.ibatis.plugin.Invocation;
+import org.apache.ibatis.plugin.Plugin;
+import org.apache.ibatis.plugin.Signature;
+
 import com.baomidou.mybatisplus.annotation.Version;
 import com.baomidou.mybatisplus.core.conditions.Wrapper;
 import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
@@ -7,17 +26,6 @@ import com.baomidou.mybatisplus.core.metadata.TableInfo;
 import com.baomidou.mybatisplus.core.toolkit.ClassUtils;
 import com.baomidou.mybatisplus.core.toolkit.ReflectionKit;
 import com.baomidou.mybatisplus.core.toolkit.TableInfoHelper;
-import org.apache.ibatis.binding.MapperMethod;
-import org.apache.ibatis.executor.Executor;
-import org.apache.ibatis.mapping.MappedStatement;
-import org.apache.ibatis.mapping.SqlCommandType;
-import org.apache.ibatis.plugin.*;
-
-import java.lang.reflect.Field;
-import java.lang.reflect.Method;
-import java.sql.Timestamp;
-import java.util.*;
-import java.util.concurrent.ConcurrentHashMap;
 
 /**
  * <p>
@@ -76,8 +84,8 @@ public class OptimisticLockerInterceptor implements Interceptor {
         Wrapper ew = null;
         // entity = et
         Object et = null;
-        if (param instanceof MapperMethod.ParamMap) {
-            MapperMethod.ParamMap map = (MapperMethod.ParamMap) param;
+        if (param instanceof Map) {
+            Map map = (Map) param;
             if (map.containsKey(NAME_ENTITY_WRAPPER)) {
                 // mapper.update(updEntity, QueryWrapper<>(whereEntity);
                 ew = (Wrapper) map.get(NAME_ENTITY_WRAPPER);
@@ -92,7 +100,7 @@ public class OptimisticLockerInterceptor implements Interceptor {
             if (map.containsKey(NAME_ENTITY)) {
                 et = map.get(NAME_ENTITY);
             }
-            if (ew != null) { // entityWrapper
+            if (ew != null) { // entityWrapper. baseMapper.update(et,ew);
                 Object entity = ew.getEntity();
                 if (entity != null) {
                     entityClass = ClassUtils.getUserClass(entity.getClass());
@@ -113,6 +121,7 @@ public class OptimisticLockerInterceptor implements Interceptor {
                     // update(entityClass, null) -->> update all. ignore version
                     return invocation.proceed();
                 }
+                //invoke: baseMapper.updateById()
                 entityClass = ClassUtils.getUserClass(et.getClass());
                 EntityField entityField = this.getVersionField(entityClass);
                 versionField = entityField == null ? null : entityField.getField();
@@ -159,16 +168,11 @@ public class OptimisticLockerInterceptor implements Interceptor {
         }
 
         Object resultObj = invocation.proceed();
-        if (Objects.equals(1, resultObj)) {
-            // setVersion, Long.class
-            String _setterMethodName = ReflectionKit.setMethodCapitalize(versionField, versionColumnName);
-            Class<?> _fieldType = versionField.getType();
-            if (ew != null) {
-                Method md = ew.getClass().getMethod(_setterMethodName, _fieldType);
-                Object o = md.invoke(ew, updatedVersionVal);
-            } else if (et != null) {
-                Method md = et.getClass().getMethod(_setterMethodName, _fieldType);
-                Object o = md.invoke(et, updatedVersionVal);
+        if (resultObj instanceof Integer) {
+            Integer effRow = (Integer) resultObj;
+            if (effRow != 0 && et != null && versionField != null && updatedVersionVal != null) {
+                //updated version value set to entity.
+                versionField.set(et, updatedVersionVal);
             }
         }
         return resultObj;

+ 40 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserTest.java

@@ -1,6 +1,7 @@
 package com.baomidou.mybatisplus.test.h2;
 
 import java.io.IOException;
+import java.math.BigDecimal;
 import java.sql.SQLException;
 import java.util.HashMap;
 import java.util.List;
@@ -117,4 +118,43 @@ public class H2UserTest extends BaseTest {
         Assert.assertNotEquals(0, count);
     }
 
+    @Test
+    public void testUpdateByIdWitiOptLock(){
+        Long id = 991L;
+        H2User user = new H2User();
+        user.setTestId(id);
+        user.setName("991");
+        user.setAge(91);
+        user.setPrice(BigDecimal.TEN);
+        user.setDesc("asdf");
+        user.setTestType(1);
+        user.setVersion(1);
+        userService.insert(user);
+
+        H2User userDB = userService.selectById(id);
+        Assert.assertEquals(1, userDB.getVersion().intValue());
+
+        userDB.setName("992");
+        userService.updateById(userDB);
+        Assert.assertEquals("updated version value should be updated to entity",2, userDB.getVersion().intValue());
+
+        userDB = userService.selectById(id);
+        Assert.assertEquals(2, userDB.getVersion().intValue());
+        Assert.assertEquals("992", userDB.getName());
+    }
+
+    @Test
+    public void testUpdateByEwWithOptLock(){
+        QueryWrapper<H2User> ew = new QueryWrapper<>();
+        ew.gt("age",13);
+        for(H2User u: userService.selectList(ew)){
+            System.out.println(u.getName()+","+u.getAge()+","+u.getVersion());
+        }
+        userService.update(new H2User().setPrice(BigDecimal.TEN), ew);
+        for(H2User u: userService.selectList(ew)){
+            System.out.println(u.getName()+","+u.getAge()+","+u.getVersion());
+        }
+    }
+
+
 }