Sfoglia il codice sorgente

fix:https://github.com/baomidou/mybatis-plus/issues/373

聂秋秋 6 anni fa
parent
commit
6e37c5ef42

+ 1 - 0
mybatis-plus-core/build.gradle

@@ -6,6 +6,7 @@ dependencies {
     compile project(":mybatis-plus-annotation")
     compile rootProject.ext.dependencies["mybatis"]
     compile rootProject.ext.dependencies["jsqlparser"]
+    compile rootProject.ext.dependencies["mybatis-spring"]
 
     provided rootProject.ext.dependencies["cglib"]
     provided rootProject.ext.dependencies["spring-aop"]

+ 7 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/config/GlobalConfig.java

@@ -19,6 +19,7 @@ import java.io.Serializable;
 import java.util.Set;
 import java.util.concurrent.ConcurrentSkipListSet;
 
+import org.apache.ibatis.session.SqlSession;
 import org.apache.ibatis.session.SqlSessionFactory;
 
 import com.baomidou.mybatisplus.annotation.DbType;
@@ -31,6 +32,7 @@ import com.baomidou.mybatisplus.core.toolkit.GlobalConfigUtils;
 
 import lombok.Data;
 import lombok.experimental.Accessors;
+import org.mybatis.spring.SqlSessionTemplate;
 
 /**
  * <p>
@@ -65,6 +67,10 @@ public class GlobalConfig implements Serializable {
      * SQL注入器
      */
     private ISqlInjector sqlInjector;
+    /**
+     * 单例重用SqlSession
+     */
+    private SqlSession sqlSession;
     /**
      * 缓存当前Configuration的SqlSessionFactory
      */
@@ -85,6 +91,7 @@ public class GlobalConfig implements Serializable {
      */
     public SqlSessionFactory signGlobalConfig(SqlSessionFactory sqlSessionFactory) {
         if (null != sqlSessionFactory) {
+            this.sqlSession = new SqlSessionTemplate(sqlSessionFactory);
             GlobalConfigUtils.setGlobalConfig(sqlSessionFactory.getConfiguration(), this);
         }
         return sqlSessionFactory;

+ 5 - 5
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/TableInfoHelper.java

@@ -78,14 +78,14 @@ public class TableInfoHelper {
             return tableInfo;
         }
         //尝试获取父类缓存
-        Class c = clazz;
-        while (null == tableInfo && Object.class != c) {
-            c = c.getSuperclass();
-            tableInfo = TABLE_INFO_CACHE.get(ClassUtils.getUserClass(c).getName());
+        Class currentClass = clazz;
+        while (null == tableInfo && Object.class != currentClass) {
+            currentClass = currentClass.getSuperclass();
+            tableInfo = TABLE_INFO_CACHE.get(ClassUtils.getUserClass(currentClass).getName());
         }
         if (null == tableInfo) {
             //找不到了,我也很绝望呀
-            logger.warn(ClassUtils.getUserClass(clazz).getName() + "Not Found TableInfoCache.");
+            throw ExceptionUtils.mpe(ClassUtils.getUserClass(clazz).getName() + "Not Found TableInfoCache.");
         } else {
             TABLE_INFO_CACHE.put(ClassUtils.getUserClass(clazz).getName(), tableInfo);
         }

+ 20 - 17
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/sql/SqlHelper.java

@@ -22,9 +22,11 @@ import com.baomidou.mybatisplus.core.metadata.TableInfo;
 import com.baomidou.mybatisplus.core.toolkit.*;
 import org.apache.ibatis.logging.Log;
 import org.apache.ibatis.logging.LogFactory;
+import org.apache.ibatis.session.Configuration;
 import org.apache.ibatis.session.ExecutorType;
 import org.apache.ibatis.session.SqlSession;
 import org.apache.ibatis.session.SqlSessionFactory;
+import org.mybatis.spring.SqlSessionTemplate;
 
 import java.util.List;
 
@@ -41,44 +43,45 @@ public final class SqlHelper {
     private static final Log logger = LogFactory.getLog(SqlHelper.class);
     public static SqlSessionFactory FACTORY;
 
-
+    
     /**
      * <p>
-     * 获取Session 默认自动提交
-     * </p>
-     * <p>
-     * 特别说明:这里获取SqlSession时这里虽然设置了自动提交但是如果事务托管了的话 是不起作用的 切记!!
+     * 批量操作 SqlSession
      * </p>
      *
+     * @param clazz 实体类
      * @return SqlSession
      */
-    public static SqlSession sqlSession(Class<?> clazz) {
-        return SqlHelper.sqlSession(clazz, true);
+    public static SqlSession sqlSessionBatch(Class<?> clazz) {
+        return GlobalConfigUtils.currentSessionFactory(clazz).openSession(ExecutorType.BATCH);
     }
-
+    
     /**
      * <p>
-     * 批量操作 SqlSession
+     * 获取sqlSession
      * </p>
      *
-     * @param clazz 实体
-     * @return SqlSession
+     * @param clazz 对象
+     * @return
      */
-    public static SqlSession sqlSessionBatch(Class<?> clazz) {
-        return GlobalConfigUtils.currentSessionFactory(clazz).openSession(ExecutorType.BATCH);
+    private static SqlSession getSqlSession(Class<?> clazz) {
+        SqlSession session;
+        SqlSessionFactory sqlSessionFactory = GlobalConfigUtils.currentSessionFactory(clazz);
+        Configuration configuration = sqlSessionFactory.getConfiguration();
+        session = GlobalConfigUtils.getGlobalConfig(configuration).getSqlSession();
+        return session !=null ? session : new SqlSessionTemplate(sqlSessionFactory);
     }
-
+    
     /**
      * <p>
      * 获取Session
      * </p>
      *
      * @param clazz      实体类
-     * @param autoCommit true自动提交false则相反
      * @return SqlSession
      */
-    public static SqlSession sqlSession(Class<?> clazz, boolean autoCommit) {
-        return GlobalConfigUtils.currentSessionFactory(clazz).openSession(autoCommit);
+    public static SqlSession sqlSession(Class<?> clazz) {
+        return SqlHelper.getSqlSession(clazz);
     }
 
     /**

+ 10 - 30
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/activerecord/Model.java

@@ -51,9 +51,7 @@ public abstract class Model<T extends Model> implements Serializable {
      */
     @Transactional(rollbackFor = Exception.class)
     public boolean insert() {
-        try (SqlSession session = sqlSession()) {
-            return SqlHelper.retBool(session.insert(sqlStatement(SqlMethod.INSERT_ONE), this));
-        }
+        return SqlHelper.retBool(sqlSession().insert(sqlStatement(SqlMethod.INSERT_ONE), this));
     }
 
     /**
@@ -84,9 +82,7 @@ public abstract class Model<T extends Model> implements Serializable {
      */
     @Transactional(rollbackFor = Exception.class)
     public boolean deleteById(Serializable id) {
-        try (SqlSession session = sqlSession()) {
-            return SqlHelper.delBool(session.delete(sqlStatement(SqlMethod.DELETE_BY_ID), id));
-        }
+        return SqlHelper.delBool(sqlSession().delete(sqlStatement(SqlMethod.DELETE_BY_ID), id));
     }
 
     /**
@@ -114,9 +110,7 @@ public abstract class Model<T extends Model> implements Serializable {
     public boolean delete(Wrapper wrapper) {
         Map<String, Object> map = new HashMap<>(1);
         map.put(Constants.WRAPPER, wrapper);
-        try (SqlSession session = sqlSession()) {
-            return SqlHelper.delBool(session.delete(sqlStatement(SqlMethod.DELETE), map));
-        }
+        return SqlHelper.delBool(sqlSession().delete(sqlStatement(SqlMethod.DELETE), map));
     }
 
     /**
@@ -130,9 +124,7 @@ public abstract class Model<T extends Model> implements Serializable {
         // updateById
         Map<String, Object> map = new HashMap<>(1);
         map.put(Constants.ENTITY, this);
-        try(SqlSession sqlSession = sqlSession()) {
-            return SqlHelper.retBool(sqlSession.update(sqlStatement(SqlMethod.UPDATE_BY_ID), map));
-        }
+        return SqlHelper.retBool(sqlSession().update(sqlStatement(SqlMethod.UPDATE_BY_ID), map));
     }
 
     /**
@@ -149,9 +141,7 @@ public abstract class Model<T extends Model> implements Serializable {
         map.put(Constants.ENTITY, this);
         map.put(Constants.WRAPPER, wrapper);
         // update
-        try (SqlSession session = sqlSession()) {
-            return SqlHelper.retBool(session.update(sqlStatement(SqlMethod.UPDATE), map));
-        }
+        return SqlHelper.retBool(sqlSession().update(sqlStatement(SqlMethod.UPDATE), map));
     }
 
     /**
@@ -162,9 +152,7 @@ public abstract class Model<T extends Model> implements Serializable {
      * @return
      */
     public List<T> selectAll() {
-        try (SqlSession session = sqlSession()) {
-            return session.selectList(sqlStatement(SqlMethod.SELECT_LIST));
-        }
+        return sqlSession().selectList(sqlStatement(SqlMethod.SELECT_LIST));
     }
 
     /**
@@ -176,9 +164,7 @@ public abstract class Model<T extends Model> implements Serializable {
      * @return
      */
     public T selectById(Serializable id) {
-        try (SqlSession session = sqlSession()) {
-            return session.selectOne(sqlStatement(SqlMethod.SELECT_BY_ID), id);
-        }
+        return sqlSession().selectOne(sqlStatement(SqlMethod.SELECT_BY_ID), id);
     }
 
     /**
@@ -205,9 +191,7 @@ public abstract class Model<T extends Model> implements Serializable {
     public List<T> selectList(Wrapper wrapper) {
         Map<String, Object> map = new HashMap<>(1);
         map.put(Constants.WRAPPER, wrapper);
-        try (SqlSession session = sqlSession()) {
-            return session.selectList(sqlStatement(SqlMethod.SELECT_LIST), map);
-        }
+        return sqlSession().selectList(sqlStatement(SqlMethod.SELECT_LIST), map);
     }
 
     /**
@@ -235,9 +219,7 @@ public abstract class Model<T extends Model> implements Serializable {
         Map<String, Object> map = new HashMap<>(2);
         map.put(Constants.WRAPPER, SqlHelper.fillWrapper(page, wrapper));
         map.put("page", page);
-        try (SqlSession session = sqlSession()) {
-            page.setRecords(session.selectList(sqlStatement(SqlMethod.SELECT_PAGE), map));
-        }
+        page.setRecords(sqlSession().selectList(sqlStatement(SqlMethod.SELECT_PAGE), map));
         return page;
     }
 
@@ -252,9 +234,7 @@ public abstract class Model<T extends Model> implements Serializable {
     public int selectCount(Wrapper wrapper) {
         Map<String, Object> map = new HashMap<>(1);
         map.put(Constants.WRAPPER, wrapper);
-        try (SqlSession session = sqlSession()) {
-            return SqlHelper.retCount(session.<Integer>selectOne(sqlStatement(SqlMethod.SELECT_COUNT), map));
-        }
+        return SqlHelper.retCount(sqlSession().<Integer>selectOne(sqlStatement(SqlMethod.SELECT_COUNT), map));
     }
 
     /**

+ 17 - 23
mybatis-plus-extension/src/main/java/com/baomidou/mybatisplus/extension/toolkit/SqlRunner.java

@@ -21,6 +21,7 @@ import java.util.Map;
 
 import org.apache.ibatis.session.SqlSession;
 import org.apache.ibatis.session.SqlSessionFactory;
+import org.mybatis.spring.SqlSessionTemplate;
 import org.springframework.transaction.annotation.Transactional;
 
 import com.baomidou.mybatisplus.core.assist.ISqlRunner;
@@ -43,6 +44,8 @@ public class SqlRunner implements ISqlRunner {
     // 默认FACTORY
 //    public static SqlSessionFactory FACTORY;
     private SqlSessionFactory sqlSessionFactory;
+    
+    private SqlSession sqlSession;
 
     private Class<?> clazz;
 
@@ -80,21 +83,17 @@ public class SqlRunner implements ISqlRunner {
     public static SqlRunner db(Class<?> clazz) {
         return new SqlRunner(clazz);
     }
-
+    
     @Transactional
     @Override
     public boolean insert(String sql, Object... args) {
-        try(SqlSession session = sqlSession()) {
-            return SqlHelper.retBool(session.insert(INSERT, sqlMap(sql, args)));
-        }
+        return SqlHelper.retBool(sqlSession().insert(INSERT, sqlMap(sql, args)));
     }
-
+    
     @Transactional
     @Override
     public boolean delete(String sql, Object... args) {
-        try(SqlSession session = sqlSession()) {
-            return SqlHelper.retBool(session.delete(DELETE, sqlMap(sql, args)));
-        }
+        return SqlHelper.retBool(sqlSession().delete(DELETE, sqlMap(sql, args)));
     }
 
     /**
@@ -109,13 +108,11 @@ public class SqlRunner implements ISqlRunner {
         sqlMap.put(SQL, StringUtils.sqlArgsFill(sql, args));
         return sqlMap;
     }
-
+    
     @Transactional
     @Override
     public boolean update(String sql, Object... args) {
-        try(SqlSession session = sqlSession()) {
-            return SqlHelper.retBool(session.update(UPDATE, sqlMap(sql, args)));
-        }
+        return SqlHelper.retBool(sqlSession().update(UPDATE, sqlMap(sql, args)));
     }
 
     /**
@@ -128,9 +125,7 @@ public class SqlRunner implements ISqlRunner {
      */
     @Override
     public List<Map<String, Object>> selectList(String sql, Object... args) {
-        try(SqlSession session = sqlSession()) {
-            return session.selectList(SELECT_LIST, sqlMap(sql, args));
-        }
+        return sqlSession().selectList(SELECT_LIST, sqlMap(sql, args));
     }
 
     /**
@@ -143,9 +138,7 @@ public class SqlRunner implements ISqlRunner {
      */
     @Override
     public List<Object> selectObjs(String sql, Object... args) {
-        try(SqlSession session = sqlSession()) {
-            return session.selectList(SELECT_OBJS, sqlMap(sql, args));
-        }
+        return sqlSession().selectList(SELECT_OBJS, sqlMap(sql, args));
     }
 
     /**
@@ -160,12 +153,10 @@ public class SqlRunner implements ISqlRunner {
     public Object selectObj(String sql, Object... args) {
         return SqlHelper.getObject(selectObjs(sql, args));
     }
-
+    
     @Override
     public int selectCount(String sql, Object... args) {
-        try(SqlSession session = sqlSession()) {
-            return SqlHelper.retCount(session.<Integer>selectOne(COUNT, sqlMap(sql, args)));
-        }
+        return SqlHelper.retCount(sqlSession().<Integer>selectOne(COUNT, sqlMap(sql, args)));
     }
 
     @Override
@@ -190,7 +181,10 @@ public class SqlRunner implements ISqlRunner {
      * <p/>
      */
     private SqlSession sqlSession() {
-        return (clazz != null) ? SqlHelper.sqlSession(clazz) : SqlHelper.FACTORY.openSession(true);
+        if(sqlSession == null){
+            this.sqlSession = new SqlSessionTemplate(DEFAULT.sqlSessionFactory);
+        }
+        return (clazz != null) ? SqlHelper.sqlSession(clazz) : sqlSession;
     }
 
 }

+ 4 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/ActiveRecordTest.java

@@ -14,6 +14,8 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
+import org.springframework.transaction.annotation.Transactional;
+
 import java.io.IOException;
 import java.sql.SQLException;
 import java.util.List;
@@ -35,9 +37,11 @@ public class ActiveRecordTest {
     }
 
     @Test
+    @Transactional
     public void testInsert() {
         H2Student student = new H2Student(null, "测试学生", 2);
         Assert.assertTrue(student.insert());
+        Assert.assertTrue(student.insert());
     }
 
     @Test

+ 2 - 0
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/SqlRunnerTest.java

@@ -14,6 +14,7 @@ import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
 import com.baomidou.mybatisplus.extension.toolkit.SqlRunner;
 import com.baomidou.mybatisplus.test.h2.config.H2Db;
 import com.baomidou.mybatisplus.test.h2.entity.persistent.H2Student;
+import org.springframework.transaction.annotation.Transactional;
 
 
 /**
@@ -42,6 +43,7 @@ public class SqlRunnerTest {
     }
 
     @Test
+    @Transactional
     public void testInsert(){
         Assert.assertTrue(SqlRunner.db().insert("INSERT INTO h2student ( name, age ) VALUES ( {0}, {1} )","测试学生",2));
         Assert.assertTrue(SqlRunner.db(H2Student.class).insert("INSERT INTO h2student ( name, age ) VALUES ( {0}, {1} )","测试学生2",3));

+ 4 - 0
mybatis-plus/src/test/resources/logback.xml

@@ -10,4 +10,8 @@
     <logger name="com.baomidou.mybatisplus.test" level="DEBUG" additivity="false">
         <appender-ref ref="STDOUT"/>
     </logger>
+
+    <root level="debug">
+        <appender-ref ref="STDOUT"/>
+    </root>
 </configuration>