SqlHelper.java 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. /*
  2. * Copyright (c) 2011-2021, baomidou (jobob@qq.com).
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. package com.baomidou.mybatisplus.extension.toolkit;
  17. import com.baomidou.mybatisplus.core.enums.SqlMethod;
  18. import com.baomidou.mybatisplus.core.metadata.TableInfo;
  19. import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
  20. import com.baomidou.mybatisplus.core.toolkit.*;
  21. import lombok.SneakyThrows;
  22. import org.apache.ibatis.exceptions.PersistenceException;
  23. import org.apache.ibatis.logging.Log;
  24. import org.apache.ibatis.reflection.ExceptionUtil;
  25. import org.apache.ibatis.session.ExecutorType;
  26. import org.apache.ibatis.session.SqlSession;
  27. import org.apache.ibatis.session.SqlSessionFactory;
  28. import org.mybatis.spring.MyBatisExceptionTranslator;
  29. import org.mybatis.spring.SqlSessionHolder;
  30. import org.mybatis.spring.SqlSessionUtils;
  31. import org.springframework.transaction.support.TransactionSynchronizationManager;
  32. import java.util.Collection;
  33. import java.util.List;
  34. import java.util.function.BiConsumer;
  35. import java.util.function.BiPredicate;
  36. import java.util.function.Consumer;
  37. import java.util.function.Supplier;
  38. /**
  39. * SQL 辅助类
  40. *
  41. * @author hubin
  42. * @since 2016-11-06
  43. */
  44. public final class SqlHelper {
  45. /**
  46. * 主要用于 service 和 ar
  47. */
  48. public static SqlSessionFactory FACTORY;
  49. /**
  50. * 批量操作 SqlSession
  51. *
  52. * @param clazz 实体类
  53. * @return SqlSession
  54. */
  55. public static SqlSession sqlSessionBatch(Class<?> clazz) {
  56. // TODO 暂时让能用先,但日志会显示Closing non transactional SqlSession,因为这个并没有绑定.
  57. return sqlSessionFactory(clazz).openSession(ExecutorType.BATCH);
  58. }
  59. /**
  60. * 获取SqlSessionFactory
  61. *
  62. * @param clazz 实体类
  63. * @return SqlSessionFactory
  64. * @since 3.3.0
  65. */
  66. public static SqlSessionFactory sqlSessionFactory(Class<?> clazz) {
  67. return GlobalConfigUtils.currentSessionFactory(clazz);
  68. }
  69. /**
  70. * 获取Session
  71. *
  72. * @param clazz 实体类
  73. * @return SqlSession
  74. */
  75. public static SqlSession sqlSession(Class<?> clazz) {
  76. return SqlSessionUtils.getSqlSession(GlobalConfigUtils.currentSessionFactory(clazz));
  77. }
  78. /**
  79. * 获取TableInfo
  80. *
  81. * @param clazz 对象类
  82. * @return TableInfo 对象表信息
  83. */
  84. public static TableInfo table(Class<?> clazz) {
  85. TableInfo tableInfo = TableInfoHelper.getTableInfo(clazz);
  86. Assert.notNull(tableInfo, "Error: Cannot execute table Method, ClassGenricType not found .");
  87. return tableInfo;
  88. }
  89. /**
  90. * 判断数据库操作是否成功
  91. *
  92. * @param result 数据库操作返回影响条数
  93. * @return boolean
  94. */
  95. public static boolean retBool(Integer result) {
  96. return null != result && result >= 1;
  97. }
  98. /**
  99. * 返回SelectCount执行结果
  100. *
  101. * @param result ignore
  102. * @return int
  103. */
  104. public static long retCount(Long result) {
  105. return (null == result) ? 0 : result;
  106. }
  107. /**
  108. * 从list中取第一条数据返回对应List中泛型的单个结果
  109. *
  110. * @param list ignore
  111. * @param <E> ignore
  112. * @return ignore
  113. */
  114. public static <E> E getObject(Log log, List<E> list) {
  115. return getObject(() -> log, list);
  116. }
  117. /**
  118. * @since 3.4.3
  119. */
  120. public static <E> E getObject(Supplier<Log> supplier, List<E> list) {
  121. if (CollectionUtils.isNotEmpty(list)) {
  122. int size = list.size();
  123. if (size > 1) {
  124. Log log = supplier.get();
  125. log.warn(String.format("Warn: execute Method There are %s results.", size));
  126. }
  127. return list.get(0);
  128. }
  129. return null;
  130. }
  131. /**
  132. * 执行批量操作
  133. *
  134. * @param entityClass 实体
  135. * @param log 日志对象
  136. * @param consumer consumer
  137. * @return 操作结果
  138. * @since 3.4.0
  139. */
  140. @SneakyThrows
  141. public static boolean executeBatch(Class<?> entityClass, Log log, Consumer<SqlSession> consumer) {
  142. SqlSessionFactory sqlSessionFactory = sqlSessionFactory(entityClass);
  143. SqlSession sqlSession = sqlSessionFactory.openSession(ExecutorType.BATCH);
  144. try {
  145. consumer.accept(sqlSession);
  146. sqlSession.commit(true);
  147. return true;
  148. } catch (Throwable t) {
  149. sqlSession.rollback(true);
  150. Throwable unwrapped = ExceptionUtil.unwrapThrowable(t);
  151. if (unwrapped instanceof PersistenceException) {
  152. MyBatisExceptionTranslator myBatisExceptionTranslator
  153. = new MyBatisExceptionTranslator(sqlSessionFactory.getConfiguration().getEnvironment().getDataSource(), true);
  154. Throwable throwable = myBatisExceptionTranslator.translateExceptionIfPossible((PersistenceException) unwrapped);
  155. if (throwable != null) {
  156. throw throwable;
  157. }
  158. }
  159. throw ExceptionUtils.mpe(unwrapped);
  160. } finally {
  161. sqlSession.close();
  162. }
  163. }
  164. /**
  165. * 执行批量操作
  166. *
  167. * @param entityClass 实体类
  168. * @param log 日志对象
  169. * @param list 数据集合
  170. * @param batchSize 批次大小
  171. * @param consumer consumer
  172. * @param <E> T
  173. * @return 操作结果
  174. * @since 3.4.0
  175. */
  176. public static <E> boolean executeBatch(Class<?> entityClass, Log log, Collection<E> list, int batchSize, BiConsumer<SqlSession, E> consumer) {
  177. Assert.isFalse(batchSize < 1, "batchSize must not be less than one");
  178. return !CollectionUtils.isEmpty(list) && executeBatch(entityClass, log, sqlSession -> {
  179. int size = list.size();
  180. int i = 1;
  181. for (E element : list) {
  182. consumer.accept(sqlSession, element);
  183. if ((i % batchSize == 0) || i == size) {
  184. sqlSession.flushStatements();
  185. }
  186. i++;
  187. }
  188. });
  189. }
  190. /**
  191. * 批量更新或保存
  192. *
  193. * @param entityClass 实体
  194. * @param log 日志对象
  195. * @param list 数据集合
  196. * @param batchSize 批次大小
  197. * @param predicate predicate(新增条件) notNull
  198. * @param consumer consumer(更新处理) notNull
  199. * @param <E> E
  200. * @return 操作结果
  201. * @since 3.4.0
  202. */
  203. public static <E> boolean saveOrUpdateBatch(Class<?> entityClass, Class<?> mapper, Log log, Collection<E> list, int batchSize, BiPredicate<SqlSession,E> predicate, BiConsumer<SqlSession, E> consumer) {
  204. String sqlStatement = getSqlStatement(mapper, SqlMethod.INSERT_ONE);
  205. return executeBatch(entityClass, log, list, batchSize, (sqlSession, entity) -> {
  206. if (predicate.test(sqlSession, entity)) {
  207. sqlSession.insert(sqlStatement, entity);
  208. } else {
  209. consumer.accept(sqlSession, entity);
  210. }
  211. });
  212. }
  213. /**
  214. * 获取mapperStatementId
  215. *
  216. * @param sqlMethod 方法名
  217. * @return 命名id
  218. * @since 3.4.0
  219. */
  220. public static String getSqlStatement(Class<?> mapper, SqlMethod sqlMethod) {
  221. return mapper.getName() + StringPool.DOT + sqlMethod.getMethod();
  222. }
  223. }