Browse Source

add optLocker support for update(et, wrapper) method
et.setVersion(oldval) -> update version=oldval+1 where version=oldval
previous version: ew.et.setVersion(oldval)

yuxiaobin 7 years ago
parent
commit
47c39bec58

+ 118 - 103
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/plugins/OptimisticLockerInterceptor.java

@@ -1,21 +1,37 @@
 package com.baomidou.mybatisplus.extension.plugins;
 
+import java.lang.reflect.Field;
+import java.sql.Timestamp;
+import java.time.LocalDateTime;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.ibatis.executor.Executor;
+import org.apache.ibatis.mapping.MappedStatement;
+import org.apache.ibatis.mapping.SqlCommandType;
+import org.apache.ibatis.plugin.Interceptor;
+import org.apache.ibatis.plugin.Intercepts;
+import org.apache.ibatis.plugin.Invocation;
+import org.apache.ibatis.plugin.Plugin;
+import org.apache.ibatis.plugin.Signature;
+
 import com.baomidou.mybatisplus.annotation.Version;
+import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
 import com.baomidou.mybatisplus.core.conditions.Wrapper;
 import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
 import com.baomidou.mybatisplus.core.metadata.TableInfo;
-import com.baomidou.mybatisplus.core.toolkit.*;
-import lombok.Data;
-import org.apache.ibatis.executor.Executor;
-import org.apache.ibatis.mapping.MappedStatement;
-import org.apache.ibatis.mapping.SqlCommandType;
-import org.apache.ibatis.plugin.*;
+import com.baomidou.mybatisplus.core.toolkit.ClassUtils;
+import com.baomidou.mybatisplus.core.toolkit.Constants;
+import com.baomidou.mybatisplus.core.toolkit.ReflectionKit;
+import com.baomidou.mybatisplus.core.toolkit.StringPool;
+import com.baomidou.mybatisplus.core.toolkit.TableInfoHelper;
 
-import java.lang.reflect.Field;
-import java.sql.Timestamp;
-import java.time.LocalDateTime;
-import java.util.*;
-import java.util.concurrent.ConcurrentHashMap;
+import lombok.Data;
 
 /**
  * <p>
@@ -60,14 +76,6 @@ public class OptimisticLockerInterceptor implements Interceptor {
             return invocation.proceed();
         }
         Object param = args[1];
-        // JavaBean class
-        Class<?> entityClass = null;
-        // optimistic, Version Field
-        Field versionField = null;
-        // optimistic field annotated by Version
-        String versionColumnName = null;
-        // new version value, Long, Integer ...
-        Object updatedVersionVal = null;
 
         // wrapper = ew
         Wrapper ew = null;
@@ -75,96 +83,82 @@ public class OptimisticLockerInterceptor implements Interceptor {
         Object et = null;
         if (param instanceof Map) {
             Map map = (Map) param;
+            if (map.containsKey(NAME_ENTITY)) {
+                et = map.get(NAME_ENTITY);//updateById(et), update(et, wrapper);
+            }
             if (map.containsKey(NAME_ENTITY_WRAPPER)) {
                 // mapper.update(updEntity, QueryWrapper<>(whereEntity);
                 ew = (Wrapper) map.get(NAME_ENTITY_WRAPPER);
             }
-            //else updateById(entity) -->> change updateById(entity) to updateById(@Param("et") entity)
-
-            // TODO 待验证逻辑
-            // if mannual sql or updagteById(entity),unsupport OCC,proceed as usual unless use updateById(@Param("et") entity)
-            //if(!map.containsKey(NAME_ENTITY)) {
-            //    return invocation.proceed();
-            //}
-            if (map.containsKey(NAME_ENTITY)) {
-                et = map.get(NAME_ENTITY);
-            }
-            if (ew != null) { // entityWrapper. baseMapper.update(et,ew);
-                Object entity = ew.getEntity();
-                if (entity != null) {
-                    entityClass = ClassUtils.getUserClass(entity.getClass());
-                    EntityField ef = getVersionField(entityClass);
-                    versionField = ef == null ? null : ef.getField();
-                    if (versionField != null) {
-                        Object originalVersionVal = versionField.get(entity);
-                        if (originalVersionVal != null) {
-                            updatedVersionVal = getUpdatedVersionVal(originalVersionVal);
-                            versionField.set(et, updatedVersionVal);
-                        }
-                    }
-                }
-            } else if (et != null) { // entity
+            if (et != null) { // entity
                 String methodId = ms.getId();
                 String updateMethodName = methodId.substring(ms.getId().lastIndexOf(StringPool.DOT) + 1);
-                if (PARAM_UPDATE_METHOD_NAME.equals(updateMethodName)) {
-                    // update(entityClass, null) -->> update all. ignore version
+                Class<?> entityClass = et.getClass();
+                TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
+                // fixed github 299
+                while (tableInfo == null && entityClass != null) {
+                    entityClass = ClassUtils.getUserClass(entityClass.getSuperclass());
+                    tableInfo = TableInfoHelper.getTableInfo(entityClass);
+                }
+                EntityField entityVersionField = this.getVersionField(entityClass, tableInfo);
+                if (entityVersionField == null) {
                     return invocation.proceed();
                 }
-                //invoke: baseMapper.updateById()
-                entityClass = ClassUtils.getUserClass(et.getClass());
-                EntityField entityField = this.getVersionField(entityClass);
-                versionField = entityField == null ? null : entityField.getField();
-                Object originalVersionVal;
-                if (versionField != null && (originalVersionVal = versionField.get(et)) != null) {
-                    TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
-                    // fixed github 299
-                    while (null == tableInfo && null != entityClass) {
-                        entityClass = ClassUtils.getUserClass(entityClass.getSuperclass());
-                        tableInfo = TableInfoHelper.getTableInfo(entityClass);
-                    }
-                    Map<String, Object> entityMap = new HashMap<>();
-                    List<EntityField> fields = getEntityFields(entityClass);
-                    for (EntityField ef : fields) {
-                        Field fd = ef.getField();
-                        if (fd.isAccessible()) {
-                            entityMap.put(fd.getName(), fd.get(et));
-                            if (ef.isVersion()) {
-                                versionField = fd;
-                            }
-                        }
+                Field versionField = entityVersionField.getField();
+                Object originalVersionVal = entityVersionField.getField().get(et);
+                Object updatedVersionVal = getUpdatedVersionVal(originalVersionVal);
+                if (PARAM_UPDATE_METHOD_NAME.equals(updateMethodName)) {
+                    // update(entity, wrapper)
+                    if (ew instanceof AbstractWrapper) {
+                        AbstractWrapper aw = (AbstractWrapper) ew;
+                        aw.eq(entityVersionField.getColumnName(), originalVersionVal);
+                        versionField.set(et, updatedVersionVal);
+                        //TODO: should remove version=oldval condition from aw; 0827
                     }
-                    String versionPropertyName = versionField.getName();
-                    List<TableFieldInfo> fieldList = tableInfo.getFieldList();
-                    versionColumnName = entityField.getColumnName();
-                    if (versionColumnName == null) {
-                        for (TableFieldInfo tf : fieldList) {
-                            if (versionPropertyName.equals(tf.getProperty())) {
-                                versionColumnName = tf.getColumn();
-                            }
+                    return invocation.proceed();
+                } else {
+                    dealUpdateById(entityClass, et, entityVersionField, originalVersionVal, updatedVersionVal, map);
+                    Object resultObj = invocation.proceed();
+                    if (resultObj instanceof Integer) {
+                        Integer effRow = (Integer) resultObj;
+                        if (effRow != 0 && versionField != null && updatedVersionVal != null) {
+                            //updated version value set to entity.
+                            versionField.set(et, updatedVersionVal);
                         }
                     }
-                    if (versionColumnName != null) {
-                        entityField.setColumnName(versionColumnName);
-                        updatedVersionVal = getUpdatedVersionVal(originalVersionVal);
-                        entityMap.put(versionField.getName(), updatedVersionVal);
-                        entityMap.put(MP_OPTLOCK_VERSION_ORIGINAL, originalVersionVal);
-                        entityMap.put(MP_OPTLOCK_VERSION_COLUMN, versionColumnName);
-                        entityMap.put(MP_OPTLOCK_ET_ORIGINAL, et);
-                        map.put(NAME_ENTITY, entityMap);
-                    }
+                    return resultObj;
                 }
             }
         }
+        return invocation.proceed();
+    }
 
-        Object resultObj = invocation.proceed();
-        if (resultObj instanceof Integer) {
-            Integer effRow = (Integer) resultObj;
-            if (effRow != 0 && et != null && versionField != null && updatedVersionVal != null) {
-                //updated version value set to entity.
-                versionField.set(et, updatedVersionVal);
-            }
+    /**
+     * 处理updateById(entity)乐观锁逻辑
+     *
+     * @param entityClass        实体类
+     * @param et                 参数entity
+     * @param entityVersionField
+     * @param originalVersionVal 原来版本的value
+     * @param updatedVersionVal  乐观锁自动更新的新value
+     * @param map
+     */
+    private void dealUpdateById(Class<?> entityClass, Object et, EntityField entityVersionField,
+                                Object originalVersionVal, Object updatedVersionVal, Map map) throws IllegalAccessException {
+        List<EntityField> fields = getEntityFields(entityClass);
+        Map<String, Object> entityMap = new HashMap<>();
+        for (EntityField ef : fields) {
+            Field fd = ef.getField();
+            entityMap.put(fd.getName(), fd.get(et));
         }
-        return resultObj;
+        Field versionField = entityVersionField.getField();
+        String versionColumnName = entityVersionField.getColumnName();
+        entityVersionField.setColumnName(versionColumnName);//update to cache
+        entityMap.put(versionField.getName(), updatedVersionVal);
+        entityMap.put(MP_OPTLOCK_VERSION_ORIGINAL, originalVersionVal);
+        entityMap.put(MP_OPTLOCK_VERSION_COLUMN, versionColumnName);
+        entityMap.put(MP_OPTLOCK_ET_ORIGINAL, et);
+        map.put(NAME_ENTITY, entityMap);
     }
 
     /**
@@ -207,13 +201,13 @@ public class OptimisticLockerInterceptor implements Interceptor {
         // to do nothing
     }
 
-    private EntityField getVersionField(Class<?> parameterClass) {
+    private EntityField getVersionField(Class<?> parameterClass, TableInfo tableInfo) {
         synchronized (parameterClass.getName()) {
             if (versionFieldCache.containsKey(parameterClass)) {
                 return versionFieldCache.get(parameterClass);
             }
             // 缓存类信息
-            EntityField field = this.getVersionFieldRegular(parameterClass);
+            EntityField field = this.getVersionFieldRegular(parameterClass, tableInfo);
             if (field != null) {
                 versionFieldCache.put(parameterClass, field);
                 return field;
@@ -227,23 +221,37 @@ public class OptimisticLockerInterceptor implements Interceptor {
      * 反射检查参数类是否启动乐观锁
      * </p>
      *
-     * @param parameterClass 参数类
+     * @param parameterClass 实体类
+     * @param tableInfo      实体数据库反射信息
      * @return
      */
-    private EntityField getVersionFieldRegular(Class<?> parameterClass) {
+    private EntityField getVersionFieldRegular(Class<?> parameterClass, TableInfo tableInfo) {
         if (parameterClass != Object.class) {
             for (Field field : parameterClass.getDeclaredFields()) {
                 if (field.isAnnotationPresent(Version.class)) {
                     field.setAccessible(true);
-                    return new EntityField(field, true);
+                    String versionPropertyName = field.getName();
+                    String versionColumnName = null;
+                    for (TableFieldInfo fieldInfo : tableInfo.getFieldList()) {
+                        if (versionPropertyName.equals(fieldInfo.getProperty())) {
+                            versionColumnName = fieldInfo.getColumn();
+                        }
+                    }
+                    return new EntityField(field, true, versionColumnName);
                 }
             }
             // 递归父类
-            return this.getVersionFieldRegular(parameterClass.getSuperclass());
+            return this.getVersionFieldRegular(parameterClass.getSuperclass(), tableInfo);
         }
         return null;
     }
 
+    /**
+     * 获取实体的反射属性(类似getter)
+     *
+     * @param parameterClass
+     * @return
+     */
     private List<EntityField> getEntityFields(Class<?> parameterClass) {
         if (entityFieldsCache.containsKey(parameterClass)) {
             return entityFieldsCache.get(parameterClass);
@@ -272,13 +280,20 @@ public class OptimisticLockerInterceptor implements Interceptor {
     @Data
     private class EntityField {
 
-        private Field field;
-        private boolean version;
-        private String columnName;
-
         EntityField(Field field, boolean version) {
             this.field = field;
             this.version = version;
         }
+
+        public EntityField(Field field, boolean version, String columnName) {
+            this.field = field;
+            this.version = version;
+            this.columnName = columnName;
+        }
+
+        private Field field;
+        private boolean version;
+        private String columnName;
+
     }
 }

+ 21 - 2
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserTest.java

@@ -147,14 +147,33 @@ public class H2UserTest extends BaseTest {
 
     @Test
     public void testUpdateByEwWithOptLock() {
+        H2User userInsert = new H2User();
+        userInsert.setName("optLockerTest");
+        userInsert.setAge(AgeEnum.THREE);
+        userInsert.setPrice(BigDecimal.TEN);
+        userInsert.setDesc("asdf");
+        userInsert.setTestType(1);
+        userInsert.setVersion(99);
+        userService.save(userInsert);
+
         QueryWrapper<H2User> ew = new QueryWrapper<>();
-        ew.gt("age", 13);
+        ew.ge("age", AgeEnum.TWO.getValue());
+        Long id99 = null;
         for (H2User u : userService.list(ew)) {
             System.out.println(u.getName() + "," + u.getAge() + "," + u.getVersion());
+            if (u.getVersion() != null && u.getVersion() == 99) {
+                id99 = u.getTestId();
+            }
         }
-        userService.update(new H2User().setPrice(BigDecimal.TEN), ew);
+        userService.update(new H2User().setPrice(BigDecimal.TEN).setVersion(99), ew);
+        System.out.println("============after update");
+        ew = new QueryWrapper<>();
+        ew.ge("age", AgeEnum.TWO.getValue());
         for (H2User u : userService.list(ew)) {
             System.out.println(u.getName() + "," + u.getAge() + "," + u.getVersion());
+            if (id99 != null && u.getTestId().equals(id99)) {
+                Assert.assertEquals("optLocker should update version+=1", 100, u.getVersion().intValue());
+            }
         }
     }