Przeglądaj źródła

bugfix for Opt Locker : update(entity, null)

add more test case
yuxiaobin 8 lat temu
rodzic
commit
811560b48c

+ 41 - 32
mybatis-plus/src/main/java/com/baomidou/mybatisplus/plugins/OptimisticLockerInterceptor.java

@@ -55,7 +55,14 @@ public class OptimisticLockerInterceptor implements Interceptor {
     private final Map<Class<?>, EntityField> versionFieldCache = new HashMap<>();
     private final Map<Class<?>, List<EntityField>> entityFieldsCache = new HashMap<>();
 
+    private static final String MP_OPTLOCK_VERSION_ORIGINAL = "MP_OPTLOCK_VERSION_ORIGINAL";
+    private static final String MP_OPTLOCK_VERSION_COLUMN = "MP_OPTLOCK_VERSION_COLUMN";
+    private static final String NAME_ENTITY = "et";
+    private static final String NAME_ENTITY_WRAPPER = "ew";
+    private static final String PARAM_UPDATE_METHOD_NAME = "update";
+
     @Override
+    @SuppressWarnings("unchecked")
     public Object intercept(Invocation invocation) throws Throwable {
         Object[] args = invocation.getArgs();
         MappedStatement ms = (MappedStatement) args[0];
@@ -66,11 +73,10 @@ public class OptimisticLockerInterceptor implements Interceptor {
         if(param instanceof MapperMethod.ParamMap){
             MapperMethod.ParamMap map = (MapperMethod.ParamMap) param;
             Wrapper ew = null;
-            if(map.containsKey("ew")){
-                //mapper.update(updEntity, EntityWrapper<>(whereEntity);
-                ew = (Wrapper) map.get("ew");
+            if(map.containsKey(NAME_ENTITY_WRAPPER)){//mapper.update(updEntity, EntityWrapper<>(whereEntity);
+                ew = (Wrapper) map.get(NAME_ENTITY_WRAPPER);
             }//else updateById(entity) -->> change updateById(entity) to updateById(@Param("et") entity)
-            Object et = map.get("et");
+            Object et = map.get(NAME_ENTITY);
             if(ew!=null){
                 Object entity = ew.getEntity();
                 if(entity!=null){
@@ -84,40 +90,43 @@ public class OptimisticLockerInterceptor implements Interceptor {
                     }
                 }
             }else{
+                String methodId = ms.getId();
+                String updateMethodName = methodId.substring(ms.getId().lastIndexOf(".")+1);
+                if(PARAM_UPDATE_METHOD_NAME.equals(updateMethodName)){//update(entity, null) -->> update all. ignore version
+                    return invocation.proceed();
+                }
                 EntityField entityField = getVersionField(et.getClass());
                 Field versionField = entityField==null?null:entityField.getField();
-                if(versionField!=null) {
-                    Object originalVersionVal = versionField.get(et);
-                    if (originalVersionVal != null) {
-                        TableInfo tableInfo = TableInfoHelper.getTableInfo(et.getClass());
-                        Map<String,Object> entityMap = new HashMap<>();
-                        List<EntityField> fields = getEntityFields(et.getClass());
-                        for(EntityField ef : fields){
-                            Field fd = ef.getField();
-                            if(fd.isAccessible()) {
-                                entityMap.put(fd.getName(), fd.get(et));
-                                if (ef.isVersion()) {
-                                    versionField = fd;
-                                }
+                Object originalVersionVal;
+                if(versionField!=null && (originalVersionVal=versionField.get(et))!=null) {
+                    TableInfo tableInfo = TableInfoHelper.getTableInfo(et.getClass());
+                    Map<String,Object> entityMap = new HashMap<>();
+                    List<EntityField> fields = getEntityFields(et.getClass());
+                    for(EntityField ef : fields){
+                        Field fd = ef.getField();
+                        if(fd.isAccessible()) {
+                            entityMap.put(fd.getName(), fd.get(et));
+                            if (ef.isVersion()) {
+                                versionField = fd;
                             }
                         }
-                        String versionPropertyName = versionField.getName();
-                        List<TableFieldInfo> fieldList = tableInfo.getFieldList();
-                        String versionColumnName = entityField.getColumnName();
-                        if(versionColumnName==null) {
-                            for (TableFieldInfo tf : fieldList) {
-                                if (versionPropertyName.equals(tf.getProperty())) {
-                                    versionColumnName = tf.getColumn();
-                                }
+                    }
+                    String versionPropertyName = versionField.getName();
+                    List<TableFieldInfo> fieldList = tableInfo.getFieldList();
+                    String versionColumnName = entityField.getColumnName();
+                    if(versionColumnName==null) {
+                        for (TableFieldInfo tf : fieldList) {
+                            if (versionPropertyName.equals(tf.getProperty())) {
+                                versionColumnName = tf.getColumn();
                             }
                         }
-                        if (versionColumnName != null) {
-                            entityField.setColumnName(versionColumnName);
-                            entityMap.put(versionField.getName(), getUpdatedVersionVal(originalVersionVal));
-                            entityMap.put("MP_OPTLOCK_VERSION_ORIGINAL", originalVersionVal);
-                            entityMap.put("MP_OPTLOCK_VERSION_COLUMN", versionColumnName);
-                            map.put("et", entityMap);
-                        }
+                    }
+                    if (versionColumnName != null) {
+                        entityField.setColumnName(versionColumnName);
+                        entityMap.put(versionField.getName(), getUpdatedVersionVal(originalVersionVal));
+                        entityMap.put(MP_OPTLOCK_VERSION_ORIGINAL, originalVersionVal);
+                        entityMap.put(MP_OPTLOCK_VERSION_COLUMN, versionColumnName);
+                        map.put(NAME_ENTITY, entityMap);
                     }
                 }
             }

+ 311 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserNoOptLockTest.java

@@ -0,0 +1,311 @@
+package com.baomidou.mybatisplus.test.h2;
+
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.sql.Connection;
+import java.sql.SQLException;
+import java.sql.Statement;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import javax.sql.DataSource;
+
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.ApplicationContext;
+import org.springframework.context.support.ClassPathXmlApplicationContext;
+import org.springframework.test.context.ContextConfiguration;
+import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
+
+import com.baomidou.mybatisplus.mapper.EntityWrapper;
+import com.baomidou.mybatisplus.plugins.Page;
+import com.baomidou.mybatisplus.test.h2.entity.persistent.H2User;
+import com.baomidou.mybatisplus.test.h2.entity.service.IH2UserService;
+
+/**
+ * <p>
+ * Mybatis Plus H2 Junit Test
+ * </p>
+ *
+ * @author Caratacus
+ * @date 2017/4/1
+ */
+@RunWith(SpringJUnit4ClassRunner.class)
+@ContextConfiguration(locations = {"classpath:h2/spring-test-no-opt-lock-h2.xml"})
+public class H2UserNoOptLockTest {
+
+    @Autowired
+    private IH2UserService userService;
+
+    @BeforeClass
+    public static void initDB() throws SQLException, IOException {
+        @SuppressWarnings("resource")
+        ApplicationContext context = new ClassPathXmlApplicationContext("classpath:h2/spring-test-h2.xml");
+        DataSource ds = (DataSource) context.getBean("dataSource");
+        try (Connection conn = ds.getConnection()) {
+            String createTableSql = readFile("user.ddl.sql");
+            Statement stmt = conn.createStatement();
+            stmt.execute(createTableSql);
+            stmt.execute("truncate table h2user");
+            insertUsers(stmt);
+            conn.commit();
+        }
+    }
+
+    private static void insertUsers(Statement stmt) throws SQLException, IOException {
+        String filename = "user.insert.sql";
+        String filePath = H2UserNoOptLockTest.class.getClassLoader().getResource("").getPath() + "/h2/" + filename;
+        try (
+                BufferedReader reader = new BufferedReader(new FileReader(filePath))
+        ) {
+            String line;
+            while ((line = reader.readLine()) != null) {
+                stmt.execute(line.replace(";", ""));
+            }
+        }
+    }
+
+    private static String readFile(String filename) {
+        StringBuilder builder = new StringBuilder();
+        String filePath = H2UserNoOptLockTest.class.getClassLoader().getResource("").getPath() + "/h2/" + filename;
+        try (
+                BufferedReader reader = new BufferedReader(new FileReader(filePath))
+        ) {
+            String line;
+            while ((line = reader.readLine()) != null)
+                builder.append(line).append(" ");
+        } catch (IOException e) {
+            e.printStackTrace();
+        }
+        return builder.toString();
+    }
+
+    @Test
+    public void testInsert() {
+        H2User user = new H2User();
+        user.setAge(1);
+        user.setPrice(new BigDecimal("9.99"));
+        userService.insert(user);
+        Assert.assertNotNull(user.getId());
+        user.setDesc("Caratacus");
+        userService.insertOrUpdate(user);
+        H2User userFromDB = userService.selectById(user.getId());
+        Assert.assertEquals("Caratacus", userFromDB.getDesc());
+    }
+
+    @Test
+    public void testDelete() {
+        H2User user = new H2User();
+        user.setAge(1);
+        user.setPrice(new BigDecimal("9.99"));
+        userService.insert(user);
+        Long userId = user.getId();
+        Assert.assertNotNull(userId);
+        userService.deleteById(userId);
+        Assert.assertNull(userService.selectById(userId));
+    }
+
+    @Test
+    public void testSelectByid() {
+        Long userId = 101L;
+        Assert.assertNotNull(userService.selectById(userId));
+    }
+
+    @Test
+    public void testSelectOne() {
+        H2User user = new H2User();
+        user.setId(105L);
+        EntityWrapper<H2User> ew = new EntityWrapper<>(user);
+        H2User userFromDB = userService.selectOne(ew);
+        Assert.assertNotNull(userFromDB);
+    }
+
+    @Test
+    public void testSelectList() {
+        H2User user = new H2User();
+        EntityWrapper<H2User> ew = new EntityWrapper<>(user);
+        List<H2User> list = userService.selectList(ew);
+        Assert.assertNotNull(list);
+        Assert.assertNotEquals(0, list.size());
+    }
+
+    @Test
+    public void testSelectPage() {
+        Page<H2User> page = userService.selectPage(new Page<H2User>(1, 3));
+        Assert.assertEquals(3, page.getRecords().size());
+    }
+
+    @Test
+    public void testUpdateByIdOptLock(){
+        Long id = 991L;
+        H2User user = new H2User();
+        user.setId(id);
+        user.setName("991");
+        user.setAge(91);
+        user.setPrice(BigDecimal.TEN);
+        user.setDesc("asdf");
+        user.setTestType(1);
+        user.setVersion(1);
+        userService.insertAllColumn(user);
+
+        H2User userDB = userService.selectById(id);
+        Assert.assertEquals(1, userDB.getVersion().intValue());
+
+        userDB.setName("991");
+        userService.updateById(userDB);
+
+        userDB = userService.selectById(id);
+        Assert.assertEquals(1, userDB.getVersion().intValue());
+        Assert.assertEquals("991", userDB.getName());
+    }
+
+    @Test
+    public void testUpdateAllColumnByIdOptLock(){
+        Long id = 997L;
+        H2User user = new H2User();
+        user.setId(id);
+        user.setName("991");
+        user.setAge(91);
+        user.setPrice(BigDecimal.TEN);
+        user.setDesc("asdf");
+        user.setTestType(1);
+        user.setVersion(1);
+        userService.insertAllColumn(user);
+
+        H2User userDB = userService.selectById(id);
+        Assert.assertEquals(1, userDB.getVersion().intValue());
+
+        userDB.setName("991");
+        userService.updateAllColumnById(userDB);
+
+        userDB = userService.selectById(id);
+        Assert.assertEquals(1, userDB.getVersion().intValue());
+        Assert.assertEquals("991", userDB.getName());
+
+        userDB.setName("990");
+        userService.updateById(userDB);
+
+        userDB = userService.selectById(id);
+        Assert.assertEquals(1, userDB.getVersion().intValue());
+        Assert.assertEquals("990", userDB.getName());
+    }
+
+    @Test
+    public void testUpdateByEntityWrapperOptLock(){
+        Long id = 992L;
+        H2User user = new H2User();
+        user.setId(id);
+        user.setName("992");
+        user.setAge(92);
+        user.setPrice(BigDecimal.TEN);
+        user.setDesc("asdf");
+        user.setTestType(1);
+        user.setVersion(1);
+        userService.insertAllColumn(user);
+
+        H2User userDB = userService.selectById(id);
+        Assert.assertEquals(1, userDB.getVersion().intValue());
+
+        H2User updUser = new H2User();
+        updUser.setName("999");
+
+        userService.update(updUser, new EntityWrapper<>(userDB));
+
+        userDB = userService.selectById(id);
+        Assert.assertEquals(1, userDB.getVersion().intValue());
+        Assert.assertEquals("999", userDB.getName());
+    }
+
+    @Test
+    public void testUpdateByEntityWrapperOptLockWithoutVersion(){
+        Long id = 993L;
+        H2User user = new H2User();
+        user.setId(id);
+        user.setName("992");
+        user.setAge(92);
+        user.setPrice(BigDecimal.TEN);
+        user.setDesc("asdf");
+        user.setTestType(1);
+        user.setVersion(1);
+        userService.insertAllColumn(user);
+
+        H2User userDB = userService.selectById(id);
+        Assert.assertEquals(1, userDB.getVersion().intValue());
+
+        H2User updUser = new H2User();
+        updUser.setName("999");
+        userDB.setVersion(null);
+        userService.update(updUser, new EntityWrapper<>(userDB));
+
+        userDB = userService.selectById(id);
+        Assert.assertEquals(1, userDB.getVersion().intValue());
+        Assert.assertEquals("999", userDB.getName());
+    }
+
+    @Test
+    public void testUpdateBatch(){
+        List<H2User> list = userService.selectList(new EntityWrapper<H2User>());
+        Map<Long, Integer> userVersionMap = new HashMap<>();
+        for(H2User u:list){
+            userVersionMap.put(u.getId(),u.getVersion());
+        }
+
+        Assert.assertTrue(userService.updateBatchById(list));
+        list = userService.selectList(new EntityWrapper<H2User>());
+        for(H2User user:list){
+            Assert.assertEquals(userVersionMap.get(user.getId()).intValue(), user.getVersion().intValue());
+        }
+
+    }
+
+    @Test
+    public void testUpdateInLoop(){
+        List<H2User> list = userService.selectList(new EntityWrapper<H2User>());
+        Map<Long,Integer> versionBefore = new HashMap<>();
+        Map<Long,String> nameExpect = new HashMap<>();
+        for (H2User h2User : list) {
+            Long id = h2User.getId();
+            Integer versionVal = h2User.getVersion();
+            versionBefore.put(id, versionVal);
+            String randomName = h2User.getName()+"_"+new Random().nextInt(10);
+            nameExpect.put(id, randomName);
+            h2User.setName(randomName);
+            userService.updateById(h2User);
+        }
+
+        list = userService.selectList(new EntityWrapper<H2User>());
+        for(H2User u:list){
+            Assert.assertEquals(u.getName(), nameExpect.get(u.getId()));
+            Assert.assertEquals(versionBefore.get(u.getId()).intValue(), u.getVersion().intValue());
+        }
+    }
+    @Test
+    public void testUpdateAllColumnInLoop(){
+        List<H2User> list = userService.selectList(new EntityWrapper<H2User>());
+        Map<Long,Integer> versionBefore = new HashMap<>();
+        Map<Long,String> nameExpect = new HashMap<>();
+        for (H2User h2User : list) {
+            Long id = h2User.getId();
+            Integer versionVal = h2User.getVersion();
+            versionBefore.put(id, versionVal);
+            String randomName = h2User.getName()+"_"+new Random().nextInt(10);
+            nameExpect.put(id, randomName);
+            h2User.setName(randomName);
+            userService.updateAllColumnById(h2User);
+        }
+
+        list = userService.selectList(new EntityWrapper<H2User>());
+        for(H2User u:list){
+            Assert.assertEquals(u.getName(), nameExpect.get(u.getId()));
+            Assert.assertEquals(versionBefore.get(u.getId()).intValue(), u.getVersion().intValue());
+        }
+    }
+
+}

+ 60 - 1
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/H2UserTest.java

@@ -224,7 +224,7 @@ public class H2UserTest {
     }
 
     @Test
-    public void testUpdateByEntityWrapperOptLockWithoutVersion(){
+    public void testUpdateByEntityWrapperOptLockWithoutVersionVal(){
         Long id = 993L;
         H2User user = new H2User();
         user.setId(id);
@@ -249,6 +249,65 @@ public class H2UserTest {
         Assert.assertEquals("999", userDB.getName());
     }
 
+    @Test
+    public void testUpdateByEntityWrapperNoEntity(){
+        Long id = 998L;
+        H2User user = new H2User();
+        user.setId(id);
+        user.setName("992");
+        user.setAge(92);
+        user.setPrice(BigDecimal.TEN);
+        user.setDesc("asdf");
+        user.setTestType(1);
+        user.setVersion(1);
+        userService.insertAllColumn(user);
+
+        H2User userDB = userService.selectById(id);
+        Assert.assertEquals(1, userDB.getVersion().intValue());
+        H2User updateUser = new H2User();
+        updateUser.setName("998");
+        boolean result = userService.update(updateUser, new EntityWrapper<H2User>());
+        Assert.assertTrue(result);
+        userDB = userService.selectById(id);
+        Assert.assertEquals(1, userDB.getVersion().intValue());
+        EntityWrapper<H2User> param = new EntityWrapper<>();
+        param.eq("name","998");
+        List<H2User> userList = userService.selectList(param);
+        Assert.assertTrue(userList.size()>1);
+    }
+
+    @Test
+    public void testUpdateByEntityWrapperNull(){
+        Long id = 918L;
+        H2User user = new H2User();
+        user.setId(id);
+        user.setName("992");
+        user.setAge(92);
+        user.setPrice(BigDecimal.TEN);
+        user.setDesc("asdf");
+        user.setTestType(1);
+        user.setVersion(1);
+        userService.insertAllColumn(user);
+
+        H2User userDB = userService.selectById(id);
+        Assert.assertEquals(1, userDB.getVersion().intValue());
+        H2User updateUser = new H2User();
+        updateUser.setName("918");
+        updateUser.setVersion(1);
+        Assert.assertTrue(userService.update(updateUser,null));
+        EntityWrapper<H2User> ew = new EntityWrapper<>();
+        int count1 = userService.selectCount(ew);
+        ew.eq("name","918").eq("version",1);
+        int count2 = userService.selectCount(ew);
+        List<H2User> userList = userService.selectList(new EntityWrapper<H2User>());
+        for(H2User u:userList){
+            System.out.println(u);
+        }
+        System.out.println("count1="+count1+", count2="+count2);
+        Assert.assertTrue(count2>0);
+        Assert.assertEquals(count1,count2);
+    }
+
     @Test
     public void testUpdateBatch(){
         List<H2User> list = userService.selectList(new EntityWrapper<H2User>());

+ 56 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/config/MybatisPlusNoOptLockConfig.java

@@ -0,0 +1,56 @@
+package com.baomidou.mybatisplus.test.h2.config;
+
+import javax.sql.DataSource;
+
+import org.apache.ibatis.plugin.Interceptor;
+import org.apache.ibatis.session.SqlSessionFactory;
+import org.apache.ibatis.type.JdbcType;
+import org.mybatis.spring.annotation.MapperScan;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.core.io.ResourceLoader;
+
+import com.baomidou.mybatisplus.MybatisConfiguration;
+import com.baomidou.mybatisplus.MybatisXMLLanguageDriver;
+import com.baomidou.mybatisplus.entity.GlobalConfiguration;
+import com.baomidou.mybatisplus.plugins.PaginationInterceptor;
+import com.baomidou.mybatisplus.spring.MybatisSqlSessionFactoryBean;
+
+/**
+ * <p>
+ * Mybatis Plus Config without OptimisLock
+ * </p>
+ *
+ * @author Caratacus
+ * @date 2017/4/1
+ */
+@Configuration
+@MapperScan("com.baomidou.mybatisplus.test.h2.entity.mapper")
+public class MybatisPlusNoOptLockConfig {
+
+    @Bean("mybatisSqlSession")
+    public SqlSessionFactory sqlSessionFactory(DataSource dataSource, ResourceLoader resourceLoader, GlobalConfiguration globalConfiguration) throws Exception {
+        MybatisSqlSessionFactoryBean sqlSessionFactory = new MybatisSqlSessionFactoryBean();
+        sqlSessionFactory.setDataSource(dataSource);
+//        sqlSessionFactory.setConfigLocation(resourceLoader.getResource("classpath:mybatis-config.xml"));
+        sqlSessionFactory.setTypeAliasesPackage("com.baomidou.mybatisplus.test.h2.entity.persistent");
+        MybatisConfiguration configuration = new MybatisConfiguration();
+        configuration.setDefaultScriptingLanguage(MybatisXMLLanguageDriver.class);
+        configuration.setJdbcTypeForNull(JdbcType.NULL);
+        sqlSessionFactory.setConfiguration(configuration);
+        PaginationInterceptor pagination = new PaginationInterceptor();
+        pagination.setDialectType("h2");
+        sqlSessionFactory.setPlugins(new Interceptor[]{
+                pagination,
+        });
+        sqlSessionFactory.setGlobalConfig(globalConfiguration);
+        return sqlSessionFactory.getObject();
+    }
+
+    @Bean
+    public GlobalConfiguration globalConfiguration() {
+        GlobalConfiguration globalConfiguration = new GlobalConfiguration();
+        globalConfiguration.setIdType(2);
+        return globalConfiguration;
+    }
+}