Browse Source

乐观锁优化 > update(et,ew) > 支持乐观锁:et带上version

yuxiaobin 7 years ago
parent
commit
82a4714892

+ 147 - 116
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/plugins/OptimisticLockerInterceptor.java

@@ -22,11 +22,13 @@ import org.apache.ibatis.plugin.Signature;
 import com.baomidou.mybatisplus.annotations.Version;
 import com.baomidou.mybatisplus.entity.TableFieldInfo;
 import com.baomidou.mybatisplus.entity.TableInfo;
+import com.baomidou.mybatisplus.mapper.EntityWrapper;
 import com.baomidou.mybatisplus.mapper.Wrapper;
 import com.baomidou.mybatisplus.toolkit.ClassUtils;
 import com.baomidou.mybatisplus.toolkit.ReflectionKit;
 import com.baomidou.mybatisplus.toolkit.TableInfoHelper;
 
+
 /**
  * <p>
  * Optimistic Lock Light version<BR>
@@ -47,20 +49,19 @@ import com.baomidou.mybatisplus.toolkit.TableInfoHelper;
  * </p>
  *
  * @author yuxiaobin
- * @date 2017/5/24
+ * @since 2017/5/24
  */
 @Intercepts({@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})})
 public class OptimisticLockerInterceptor implements Interceptor {
 
-    private final Map<Class<?>, EntityField> versionFieldCache = new ConcurrentHashMap<>();
-    private final Map<Class<?>, List<EntityField>> entityFieldsCache = new ConcurrentHashMap<>();
-
-    private static final String MP_OPTLOCK_VERSION_ORIGINAL = "MP_OPTLOCK_VERSION_ORIGINAL";
-    private static final String MP_OPTLOCK_VERSION_COLUMN = "MP_OPTLOCK_VERSION_COLUMN";
+    public static final String MP_OPTLOCK_VERSION_ORIGINAL = "MP_OPTLOCK_VERSION_ORIGINAL";
+    public static final String MP_OPTLOCK_VERSION_COLUMN = "MP_OPTLOCK_VERSION_COLUMN";
     public static final String MP_OPTLOCK_ET_ORIGINAL = "MP_OPTLOCK_ET_ORIGINAL";
     private static final String NAME_ENTITY = "et";
     private static final String NAME_ENTITY_WRAPPER = "ew";
     private static final String PARAM_UPDATE_METHOD_NAME = "update";
+    private final Map<Class<?>, EntityField> versionFieldCache = new ConcurrentHashMap<>();
+    private final Map<Class<?>, List<EntityField>> entityFieldsCache = new ConcurrentHashMap<>();
 
     @Override
     @SuppressWarnings("unchecked")
@@ -71,93 +72,102 @@ public class OptimisticLockerInterceptor implements Interceptor {
             return invocation.proceed();
         }
         Object param = args[1];
-        Object et = null;
-        Field versionField = null;
-        Object updatedVersionVal = null;
-        if (param instanceof HashMap) {
-            HashMap map = (HashMap) param;
-            Wrapper ew = null;
-            if (map.containsKey(NAME_ENTITY_WRAPPER)) {
-                // mapper.update(updEntity, EntityWrapper<>(whereEntity);
-                ew = (Wrapper) map.get(NAME_ENTITY_WRAPPER);
-            }
-            //else updateById(entity) -->> change updateById(entity) to updateById(@Param("et") entity)
 
+        // wrapper = ew
+        Wrapper ew = null;
+        // entity = et
+        Object et = null;
+        if (param instanceof Map) {
+            Map map = (Map) param;
             if (map.containsKey(NAME_ENTITY)) {
+                //updateById(et), update(et, wrapper);
                 et = map.get(NAME_ENTITY);
             }
-            if (ew != null) {
-                Object entity = ew.getEntity();
-                if (entity != null) {
-                    Class<?> entityClass = ClassUtils.getUserClass(entity.getClass());
-                    EntityField ef = getVersionField(entityClass);
-                    versionField = ef == null ? null : ef.getField();
-                    if (versionField != null) {
-                        Object originalVersionVal = versionField.get(entity);
-                        if (originalVersionVal != null) {
-                            versionField.set(et, updatedVersionVal = getUpdatedVersionVal(originalVersionVal));
-                        }
-                    }
-                }
-            } else if (et != null) {
+            if (map.containsKey(NAME_ENTITY_WRAPPER)) {
+                // mapper.update(updEntity, QueryWrapper<>(whereEntity);
+                ew = (Wrapper) map.get(NAME_ENTITY_WRAPPER);
+            }
+            if (et != null) {
+                // entity
                 String methodId = ms.getId();
                 String updateMethodName = methodId.substring(ms.getId().lastIndexOf(".") + 1);
-                if (PARAM_UPDATE_METHOD_NAME.equals(updateMethodName)) {
-                    //update(entityClass, null) -->> update all. ignore version
+                Class<?> entityClass = et.getClass();
+                TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
+                // fixed github 299
+                while (tableInfo == null && entityClass != null) {
+                    entityClass = ClassUtils.getUserClass(entityClass.getSuperclass());
+                    tableInfo = TableInfoHelper.getTableInfo(entityClass);
+                }
+                EntityField entityVersionField = this.getVersionField(entityClass, tableInfo);
+                if (entityVersionField == null) {
                     return invocation.proceed();
                 }
-                Class<?> entityClass = ClassUtils.getUserClass(et.getClass());
-                EntityField entityField = this.getVersionField(entityClass);
-                versionField = entityField == null ? null : entityField.getField();
-                Object originalVersionVal;
-                if (versionField != null && (originalVersionVal = versionField.get(et)) != null) {
-                    TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
-                    // fixed github 299
-                    while (null == tableInfo && null != entityClass) {
-                        entityClass = ClassUtils.getUserClass(entityClass.getSuperclass());
-                        tableInfo = TableInfoHelper.getTableInfo(entityClass);
-                    }
-                    Map<String, Object> entityMap = new HashMap<>();
-                    List<EntityField> fields = getEntityFields(entityClass);
-                    for (EntityField ef : fields) {
-                        Field fd = ef.getField();
-                        if (fd.isAccessible()) {
-                            entityMap.put(fd.getName(), fd.get(et));
-                            if (ef.isVersion()) {
-                                versionField = fd;
-                            }
+                Field versionField = entityVersionField.getField();
+                Object originalVersionVal = entityVersionField.getField().get(et);
+                Object updatedVersionVal = getUpdatedVersionVal(originalVersionVal);
+                if (PARAM_UPDATE_METHOD_NAME.equals(updateMethodName)) {
+                    // update(entity, wrapper)
+                    if (originalVersionVal != null) {
+                        if (ew == null) {
+                            Wrapper aw = new EntityWrapper();
+                            aw.eq(entityVersionField.getColumnName(), originalVersionVal);
+                            map.put(NAME_ENTITY_WRAPPER, aw);
+                            versionField.set(et, updatedVersionVal);
+                        } else if (ew instanceof EntityWrapper) {
+                            EntityWrapper aw = (EntityWrapper) ew;
+                            aw.eq(entityVersionField.getColumnName(), originalVersionVal);
+                            versionField.set(et, updatedVersionVal);
+                            //TODO: should remove version=oldval condition from aw; 0827
                         }
                     }
-                    String versionPropertyName = versionField.getName();
-                    List<TableFieldInfo> fieldList = tableInfo.getFieldList();
-                    String versionColumnName = entityField.getColumnName();
-                    if (versionColumnName == null) {
-                        for (TableFieldInfo tf : fieldList) {
-                            if (versionPropertyName.equals(tf.getProperty())) {
-                                versionColumnName = tf.getColumn();
-                            }
+                    return invocation.proceed();
+                } else {
+                    dealUpdateById(entityClass, et, entityVersionField, originalVersionVal, updatedVersionVal, map);
+                    Object resultObj = invocation.proceed();
+                    if (resultObj instanceof Integer) {
+                        Integer effRow = (Integer) resultObj;
+                        if (updatedVersionVal != null && effRow != 0 && versionField != null) {
+                            //updated version value set to entity.
+                            versionField.set(et, updatedVersionVal);
                         }
                     }
-                    if (versionColumnName != null) {
-                        entityField.setColumnName(versionColumnName);
-                        entityMap.put(versionField.getName(), updatedVersionVal = getUpdatedVersionVal(originalVersionVal));
-                        entityMap.put(MP_OPTLOCK_VERSION_ORIGINAL, originalVersionVal);
-                        entityMap.put(MP_OPTLOCK_VERSION_COLUMN, versionColumnName);
-                        entityMap.put(MP_OPTLOCK_ET_ORIGINAL, et);
-                        map.put(NAME_ENTITY, entityMap);
-                    }
+                    return resultObj;
                 }
             }
         }
-        Object resultObj = invocation.proceed();
-        if (resultObj instanceof Integer) {
-            Integer effRow = (Integer) resultObj;
-            if (effRow != 0 && et != null && versionField != null && updatedVersionVal != null) {
-                //updated version value set to entity.
-                versionField.set(et, updatedVersionVal);
-            }
+        return invocation.proceed();
+    }
+
+    /**
+     * 处理updateById(entity)乐观锁逻辑
+     *
+     * @param entityClass        实体类
+     * @param et                 参数entity
+     * @param entityVersionField
+     * @param originalVersionVal 原来版本的value
+     * @param updatedVersionVal  乐观锁自动更新的新value
+     * @param map
+     */
+    private void dealUpdateById(Class<?> entityClass, Object et, EntityField entityVersionField,
+                                Object originalVersionVal, Object updatedVersionVal, Map map) throws IllegalAccessException {
+        if (originalVersionVal == null) {
+            return;
+        }
+        List<EntityField> fields = getEntityFields(entityClass);
+        Map<String, Object> entityMap = new HashMap<>();
+        for (EntityField ef : fields) {
+            Field fd = ef.getField();
+            entityMap.put(fd.getName(), fd.get(et));
         }
-        return resultObj;
+        Field versionField = entityVersionField.getField();
+        String versionColumnName = entityVersionField.getColumnName();
+        //update to cache
+        entityVersionField.setColumnName(versionColumnName);
+        entityMap.put(versionField.getName(), updatedVersionVal);
+        entityMap.put(MP_OPTLOCK_VERSION_ORIGINAL, originalVersionVal);
+        entityMap.put(MP_OPTLOCK_VERSION_COLUMN, versionColumnName);
+        entityMap.put(MP_OPTLOCK_ET_ORIGINAL, et);
+        map.put(NAME_ENTITY, entityMap);
     }
 
     /**
@@ -168,6 +178,9 @@ public class OptimisticLockerInterceptor implements Interceptor {
      * @return updated version val
      */
     protected Object getUpdatedVersionVal(Object originalVersionVal) {
+        if (null == originalVersionVal) {
+            return null;
+        }
         Class<?> versionValClass = originalVersionVal.getClass();
         if (long.class.equals(versionValClass)) {
             return ((long) originalVersionVal) + 1;
@@ -182,7 +195,8 @@ public class OptimisticLockerInterceptor implements Interceptor {
         } else if (Timestamp.class.equals(versionValClass)) {
             return new Timestamp(System.currentTimeMillis());
         }
-        return originalVersionVal;//not supported type, return original val.
+        //not supported type, return original val.
+        return originalVersionVal;
     }
 
     @Override
@@ -198,13 +212,13 @@ public class OptimisticLockerInterceptor implements Interceptor {
         // to do nothing
     }
 
-    private EntityField getVersionField(Class<?> parameterClass) {
+    private EntityField getVersionField(Class<?> parameterClass, TableInfo tableInfo) {
         synchronized (parameterClass.getName()) {
             if (versionFieldCache.containsKey(parameterClass)) {
                 return versionFieldCache.get(parameterClass);
             }
             // 缓存类信息
-            EntityField field = this.getVersionFieldRegular(parameterClass);
+            EntityField field = this.getVersionFieldRegular(parameterClass, tableInfo);
             if (field != null) {
                 versionFieldCache.put(parameterClass, field);
                 return field;
@@ -218,36 +232,48 @@ public class OptimisticLockerInterceptor implements Interceptor {
      * 反射检查参数类是否启动乐观锁
      * </p>
      *
-     * @param parameterClass 参数类
+     * @param parameterClass 实体类
+     * @param tableInfo      实体数据库反射信息
      * @return
      */
-    private EntityField getVersionFieldRegular(Class<?> parameterClass) {
+    private EntityField getVersionFieldRegular(Class<?> parameterClass, TableInfo tableInfo) {
         if (parameterClass != Object.class) {
             for (Field field : parameterClass.getDeclaredFields()) {
                 if (field.isAnnotationPresent(Version.class)) {
                     field.setAccessible(true);
-                    return new EntityField(field, true);
+                    String versionPropertyName = field.getName();
+                    String versionColumnName = null;
+                    for (TableFieldInfo fieldInfo : tableInfo.getFieldList()) {
+                        if (versionPropertyName.equals(fieldInfo.getProperty())) {
+                            versionColumnName = fieldInfo.getColumn();
+                        }
+                    }
+                    return new EntityField(field, true, versionColumnName);
                 }
             }
             // 递归父类
-            return this.getVersionFieldRegular(parameterClass.getSuperclass());
+            return this.getVersionFieldRegular(parameterClass.getSuperclass(), tableInfo);
         }
         return null;
     }
 
+    /**
+     * 获取实体的反射属性(类似getter)
+     *
+     * @param parameterClass
+     * @return
+     */
     private List<EntityField> getEntityFields(Class<?> parameterClass) {
         if (entityFieldsCache.containsKey(parameterClass)) {
             return entityFieldsCache.get(parameterClass);
         }
-        List<EntityField> fields = this.getFieldsFromClazz(parameterClass, null);
+        List<EntityField> fields = this.getFieldsFromClazz(parameterClass);
         entityFieldsCache.put(parameterClass, fields);
         return fields;
     }
 
-    private List<EntityField> getFieldsFromClazz(Class<?> parameterClass, List<EntityField> fieldList) {
-        if (fieldList == null) {
-            fieldList = new ArrayList<>();
-        }
+    private List<EntityField> getFieldsFromClazz(Class<?> parameterClass) {
+        List<EntityField> fieldList = new ArrayList<>();
         List<Field> fields = ReflectionKit.getFieldList(parameterClass);
         for (Field field : fields) {
             field.setAccessible(true);
@@ -259,41 +285,46 @@ public class OptimisticLockerInterceptor implements Interceptor {
         }
         return fieldList;
     }
-}
 
-class EntityField {
+    private class EntityField {
 
-    private Field field;
-    private boolean version;
-    private String columnName;
+        EntityField(Field field, boolean version) {
+            this.field = field;
+            this.version = version;
+        }
 
-    public EntityField(Field field, boolean version) {
-        this.field = field;
-        this.version = version;
-    }
+        public EntityField(Field field, boolean version, String columnName) {
+            this.field = field;
+            this.version = version;
+            this.columnName = columnName;
+        }
 
-    public Field getField() {
-        return field;
-    }
+        private Field field;
+        private boolean version;
+        private String columnName;
 
-    public void setField(Field field) {
-        this.field = field;
-    }
+        public Field getField() {
+            return field;
+        }
 
-    public boolean isVersion() {
-        return version;
-    }
+        public void setField(Field field) {
+            this.field = field;
+        }
 
-    public void setVersion(boolean version) {
-        this.version = version;
-    }
+        public boolean isVersion() {
+            return version;
+        }
 
-    public String getColumnName() {
-        return columnName;
-    }
+        public void setVersion(boolean version) {
+            this.version = version;
+        }
 
-    public void setColumnName(String columnName) {
-        this.columnName = columnName;
+        public String getColumnName() {
+            return columnName;
+        }
+
+        public void setColumnName(String columnName) {
+            this.columnName = columnName;
+        }
     }
 }
-