OptimisticLockerInterceptor.java 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. package com.baomidou.mybatisplus.plugins;
  2. import java.lang.reflect.Field;
  3. import java.sql.Timestamp;
  4. import java.util.ArrayList;
  5. import java.util.Date;
  6. import java.util.HashMap;
  7. import java.util.List;
  8. import java.util.Map;
  9. import java.util.Properties;
  10. import java.util.concurrent.ConcurrentHashMap;
  11. import org.apache.ibatis.binding.MapperMethod;
  12. import org.apache.ibatis.executor.Executor;
  13. import org.apache.ibatis.mapping.MappedStatement;
  14. import org.apache.ibatis.mapping.SqlCommandType;
  15. import org.apache.ibatis.plugin.Interceptor;
  16. import org.apache.ibatis.plugin.Intercepts;
  17. import org.apache.ibatis.plugin.Invocation;
  18. import org.apache.ibatis.plugin.Plugin;
  19. import org.apache.ibatis.plugin.Signature;
  20. import com.baomidou.mybatisplus.annotations.Version;
  21. import com.baomidou.mybatisplus.entity.TableFieldInfo;
  22. import com.baomidou.mybatisplus.entity.TableInfo;
  23. import com.baomidou.mybatisplus.mapper.Wrapper;
  24. import com.baomidou.mybatisplus.toolkit.ClassUtils;
  25. import com.baomidou.mybatisplus.toolkit.ReflectionKit;
  26. import com.baomidou.mybatisplus.toolkit.TableInfoHelper;
  27. /**
  28. * <p>
  29. * Optimistic Lock Light version<BR>
  30. * Intercept on {@link Executor}.update;<BR>
  31. * Support version types: int/Integer, long/Long, java.util.Date, java.sql.Timestamp<BR>
  32. * For extra types, please define a subclass and override {@code getUpdatedVersionVal}() method.<BR>
  33. * <BR>
  34. * How to use?<BR>
  35. * (1) Define an Entity and add {@link Version} annotation on one entity field.<BR>
  36. * (2) Add {@link OptimisticLockerInterceptor} into mybatis plugin.
  37. * <p>
  38. * How to work?<BR>
  39. * if update entity with version column=1:<BR>
  40. * (1) no {@link OptimisticLockerInterceptor}:<BR>
  41. * SQL: update tbl_test set name='abc' where id=100001;<BR>
  42. * (2) add {@link OptimisticLockerInterceptor}:<BR>
  43. * SQL: update tbl_test set name='abc',version=2 where id=100001 and version=1;
  44. * </p>
  45. *
  46. * @author yuxiaobin
  47. * @date 2017/5/24
  48. */
  49. @Intercepts({@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})})
  50. public class OptimisticLockerInterceptor implements Interceptor {
  51. private final Map<Class<?>, EntityField> versionFieldCache = new ConcurrentHashMap<>();
  52. private final Map<Class<?>, List<EntityField>> entityFieldsCache = new ConcurrentHashMap<>();
  53. private static final String MP_OPTLOCK_VERSION_ORIGINAL = "MP_OPTLOCK_VERSION_ORIGINAL";
  54. private static final String MP_OPTLOCK_VERSION_COLUMN = "MP_OPTLOCK_VERSION_COLUMN";
  55. public static final String MP_OPTLOCK_ET_ORIGINAL = "MP_OPTLOCK_ET_ORIGINAL";
  56. private static final String NAME_ENTITY = "et";
  57. private static final String NAME_ENTITY_WRAPPER = "ew";
  58. private static final String PARAM_UPDATE_METHOD_NAME = "update";
  59. @Override
  60. @SuppressWarnings("unchecked")
  61. public Object intercept(Invocation invocation) throws Throwable {
  62. Object[] args = invocation.getArgs();
  63. MappedStatement ms = (MappedStatement) args[0];
  64. if (SqlCommandType.UPDATE != ms.getSqlCommandType()) {
  65. return invocation.proceed();
  66. }
  67. Object param = args[1];
  68. if (param instanceof MapperMethod.ParamMap) {
  69. MapperMethod.ParamMap map = (MapperMethod.ParamMap) param;
  70. Wrapper ew = null;
  71. if (map.containsKey(NAME_ENTITY_WRAPPER)) {//mapper.update(updEntity, EntityWrapper<>(whereEntity);
  72. ew = (Wrapper) map.get(NAME_ENTITY_WRAPPER);
  73. }//else updateById(entity) -->> change updateById(entity) to updateById(@Param("et") entity)
  74. // TODO 待验证逻辑
  75. // if mannual sql or updagteById(entity),unsupport OCC,proceed as usual unless use updateById(@Param("et") entity)
  76. //if(!map.containsKey(NAME_ENTITY)) {
  77. // return invocation.proceed();
  78. //}
  79. Object et = null;
  80. if (map.containsKey(NAME_ENTITY)) {
  81. et = map.get(NAME_ENTITY);
  82. }
  83. if (ew != null) {
  84. Object entity = ew.getEntity();
  85. if (entity != null) {
  86. Class<?> entityClass = ClassUtils.getUserClass(entity.getClass());
  87. EntityField ef = getVersionField(entityClass);
  88. Field versionField = ef == null ? null : ef.getField();
  89. if (versionField != null) {
  90. Object originalVersionVal = versionField.get(entity);
  91. if (originalVersionVal != null) {
  92. versionField.set(et, getUpdatedVersionVal(originalVersionVal));
  93. }
  94. }
  95. }
  96. } else if (et != null) {
  97. String methodId = ms.getId();
  98. String updateMethodName = methodId.substring(ms.getId().lastIndexOf(".") + 1);
  99. if (PARAM_UPDATE_METHOD_NAME.equals(updateMethodName)) {//update(entityClass, null) -->> update all. ignore version
  100. return invocation.proceed();
  101. }
  102. Class<?> entityClass = ClassUtils.getUserClass(et.getClass());
  103. EntityField entityField = this.getVersionField(entityClass);
  104. Field versionField = entityField == null ? null : entityField.getField();
  105. Object originalVersionVal;
  106. if (versionField != null && (originalVersionVal = versionField.get(et)) != null) {
  107. TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
  108. Map<String, Object> entityMap = new HashMap<>();
  109. List<EntityField> fields = getEntityFields(entityClass);
  110. for (EntityField ef : fields) {
  111. Field fd = ef.getField();
  112. if (fd.isAccessible()) {
  113. entityMap.put(fd.getName(), fd.get(et));
  114. if (ef.isVersion()) {
  115. versionField = fd;
  116. }
  117. }
  118. }
  119. String versionPropertyName = versionField.getName();
  120. List<TableFieldInfo> fieldList = tableInfo.getFieldList();
  121. String versionColumnName = entityField.getColumnName();
  122. if (versionColumnName == null) {
  123. for (TableFieldInfo tf : fieldList) {
  124. if (versionPropertyName.equals(tf.getProperty())) {
  125. versionColumnName = tf.getColumn();
  126. }
  127. }
  128. }
  129. if (versionColumnName != null) {
  130. entityField.setColumnName(versionColumnName);
  131. entityMap.put(versionField.getName(), getUpdatedVersionVal(originalVersionVal));
  132. entityMap.put(MP_OPTLOCK_VERSION_ORIGINAL, originalVersionVal);
  133. entityMap.put(MP_OPTLOCK_VERSION_COLUMN, versionColumnName);
  134. entityMap.put(MP_OPTLOCK_ET_ORIGINAL, et);
  135. map.put(NAME_ENTITY, entityMap);
  136. }
  137. }
  138. }
  139. }
  140. return invocation.proceed();
  141. }
  142. /**
  143. * This method provides the control for version value.<BR>
  144. * Returned value type must be the same as original one.
  145. *
  146. * @param originalVersionVal
  147. * @return updated version val
  148. */
  149. protected Object getUpdatedVersionVal(Object originalVersionVal) {
  150. Class<?> versionValClass = originalVersionVal.getClass();
  151. if (long.class.equals(versionValClass)) {
  152. return ((long) originalVersionVal) + 1;
  153. } else if (Long.class.equals(versionValClass)) {
  154. return ((Long) originalVersionVal) + 1;
  155. } else if (int.class.equals(versionValClass)) {
  156. return ((int) originalVersionVal) + 1;
  157. } else if (Integer.class.equals(versionValClass)) {
  158. return ((Integer) originalVersionVal) + 1;
  159. } else if (Date.class.equals(versionValClass)) {
  160. return new Date();
  161. } else if (Timestamp.class.equals(versionValClass)) {
  162. return new Timestamp(System.currentTimeMillis());
  163. }
  164. return originalVersionVal;//not supported type, return original val.
  165. }
  166. @Override
  167. public Object plugin(Object target) {
  168. if (target instanceof Executor) {
  169. return Plugin.wrap(target, this);
  170. }
  171. return target;
  172. }
  173. @Override
  174. public void setProperties(Properties properties) {
  175. // to do nothing
  176. }
  177. private EntityField getVersionField(Class<?> parameterClass) {
  178. synchronized (parameterClass.getName()) {
  179. if (versionFieldCache.containsKey(parameterClass)) {
  180. return versionFieldCache.get(parameterClass);
  181. }
  182. // 缓存类信息
  183. EntityField field = this.getVersionFieldRegular(parameterClass);
  184. if(field != null) {
  185. versionFieldCache.put(parameterClass, field);
  186. return field;
  187. }
  188. return null;
  189. }
  190. }
  191. /**
  192. * <p>
  193. * 反射检查参数类是否启动乐观锁
  194. * </p>
  195. *
  196. * @param parameterClass 参数类
  197. * @return
  198. */
  199. private EntityField getVersionFieldRegular(Class<?> parameterClass) {
  200. if (parameterClass != Object.class) {
  201. for (Field field : parameterClass.getDeclaredFields()) {
  202. if (field.isAnnotationPresent(Version.class)) {
  203. field.setAccessible(true);
  204. return new EntityField(field, true);
  205. }
  206. }
  207. // 递归父类
  208. return this.getVersionFieldRegular(parameterClass.getSuperclass());
  209. }
  210. return null;
  211. }
  212. private List<EntityField> getEntityFields(Class<?> parameterClass) {
  213. if (entityFieldsCache.containsKey(parameterClass)) {
  214. return entityFieldsCache.get(parameterClass);
  215. }
  216. List<EntityField> fields = this.getFieldsFromClazz(parameterClass, null);
  217. entityFieldsCache.put(parameterClass, fields);
  218. return fields;
  219. }
  220. private List<EntityField> getFieldsFromClazz(Class<?> parameterClass, List<EntityField> fieldList) {
  221. if (fieldList == null) {
  222. fieldList = new ArrayList<>();
  223. }
  224. List<Field> fields = ReflectionKit.getFieldList(parameterClass);
  225. for (Field field : fields) {
  226. field.setAccessible(true);
  227. if (field.isAnnotationPresent(Version.class)) {
  228. fieldList.add(new EntityField(field, true));
  229. } else {
  230. fieldList.add(new EntityField(field, false));
  231. }
  232. }
  233. return fieldList;
  234. }
  235. }
  /**
   * Mutable holder for one reflected entity field. It records whether the field is flagged as
   * the version field and the database column name that a caller may attach later.
   */
  class EntityField {

      /** Mapped column name; stays {@code null} until {@link #setColumnName(String)} is called. */
      private String columnName;
      private boolean version;
      private Field field;

      /**
       * @param reflectedField the reflected field
       * @param versionFlag    {@code true} when this field is the version field
       */
      public EntityField(Field reflectedField, boolean versionFlag) {
          this.version = versionFlag;
          this.field = reflectedField;
      }

      /** @return the reflected field */
      public Field getField() {
          return this.field;
      }

      /** @param reflectedField the new reflected field */
      public void setField(Field reflectedField) {
          this.field = reflectedField;
      }

      /** @return whether this field is flagged as the version field */
      public boolean isVersion() {
          return this.version;
      }

      /** @param versionFlag whether this field is the version field */
      public void setVersion(boolean versionFlag) {
          this.version = versionFlag;
      }

      /** @return the mapped column name, or {@code null} if not set yet */
      public String getColumnName() {
          return this.columnName;
      }

      /** @param column the mapped column name */
      public void setColumnName(String column) {
          this.columnName = column;
      }
  }