Browse Source

重构SqlRunner执行SQL.

https://github.com/baomidou/mybatis-plus/issues/6666
nieqiurong 1 month ago
parent
commit
44bab1fc6c

+ 5 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/assist/ISqlRunner.java

@@ -15,6 +15,7 @@
  */
 package com.baomidou.mybatisplus.core.assist;
 
+import com.baomidou.mybatisplus.core.injector.SqlRunnerInjector;
 import com.baomidou.mybatisplus.core.metadata.IPage;
 
 import java.util.List;
@@ -34,6 +35,10 @@ public interface ISqlRunner {
     String SELECT_LIST = "com.baomidou.mybatisplus.core.mapper.SqlRunner.SelectList";
     String SELECT_OBJS = "com.baomidou.mybatisplus.core.mapper.SqlRunner.SelectObjs";
     String COUNT = "com.baomidou.mybatisplus.core.mapper.SqlRunner.Count";
+
+    /**
+     * @deprecated 3.5.12 {@link SqlRunnerInjector#SQL_SCRIPT}
+     */
     String SQL_SCRIPT = "${sql}";
     String SQL = "sql";
     String PAGE = "page";

+ 14 - 6
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/injector/SqlRunnerInjector.java

@@ -41,6 +41,14 @@ public class SqlRunnerInjector {
     protected Configuration configuration;
     protected LanguageDriver languageDriver;
 
+    /**
+     * @since 3.5.12
+     */
+    public static final String SQL_SCRIPT = "<script>" +
+        "${sql}\n" +
+        "<if test=\"true\"></if>" +
+        "</script>";
+
 
     public void inject(Configuration configuration) {
         this.configuration = configuration;
@@ -109,7 +117,7 @@ public class SqlRunnerInjector {
             logger.warn("MappedStatement 'SqlRunner.SelectList' Already Exists");
             return;
         }
-        SqlSource sqlSource = languageDriver.createSqlSource(configuration, ISqlRunner.SQL_SCRIPT, Map.class);
+        SqlSource sqlSource = languageDriver.createSqlSource(configuration, SQL_SCRIPT, Map.class);
         createSelectMappedStatement(ISqlRunner.SELECT_LIST, sqlSource, Map.class);
     }
 
@@ -121,7 +129,7 @@ public class SqlRunnerInjector {
             logger.warn("MappedStatement 'SqlRunner.SelectObjs' Already Exists");
             return;
         }
-        SqlSource sqlSource = languageDriver.createSqlSource(configuration, ISqlRunner.SQL_SCRIPT, Object.class);
+        SqlSource sqlSource = languageDriver.createSqlSource(configuration, SQL_SCRIPT, Object.class);
         createSelectMappedStatement(ISqlRunner.SELECT_OBJS, sqlSource, Object.class);
     }
 
@@ -133,7 +141,7 @@ public class SqlRunnerInjector {
             logger.warn("MappedStatement 'SqlRunner.Count' Already Exists");
             return;
         }
-        SqlSource sqlSource = languageDriver.createSqlSource(configuration, ISqlRunner.SQL_SCRIPT, Map.class);
+        SqlSource sqlSource = languageDriver.createSqlSource(configuration, SQL_SCRIPT, Map.class);
         createSelectMappedStatement(ISqlRunner.COUNT, sqlSource, Long.class);
     }
 
@@ -145,7 +153,7 @@ public class SqlRunnerInjector {
             logger.warn("MappedStatement 'SqlRunner.Insert' Already Exists");
             return;
         }
-        SqlSource sqlSource = languageDriver.createSqlSource(configuration, ISqlRunner.SQL_SCRIPT, Map.class);
+        SqlSource sqlSource = languageDriver.createSqlSource(configuration, SQL_SCRIPT, Map.class);
         createUpdateMappedStatement(ISqlRunner.INSERT, sqlSource, SqlCommandType.INSERT);
     }
 
@@ -157,7 +165,7 @@ public class SqlRunnerInjector {
             logger.warn("MappedStatement 'SqlRunner.Update' Already Exists");
             return;
         }
-        SqlSource sqlSource = languageDriver.createSqlSource(configuration, ISqlRunner.SQL_SCRIPT, Map.class);
+        SqlSource sqlSource = languageDriver.createSqlSource(configuration, SQL_SCRIPT, Map.class);
         createUpdateMappedStatement(ISqlRunner.UPDATE, sqlSource, SqlCommandType.UPDATE);
     }
 
@@ -169,7 +177,7 @@ public class SqlRunnerInjector {
             logger.warn("MappedStatement 'SqlRunner.Delete' Already Exists");
             return;
         }
-        SqlSource sqlSource = languageDriver.createSqlSource(configuration, ISqlRunner.SQL_SCRIPT, Map.class);
+        SqlSource sqlSource = languageDriver.createSqlSource(configuration, SQL_SCRIPT, Map.class);
         createUpdateMappedStatement(ISqlRunner.DELETE, sqlSource, SqlCommandType.DELETE);
     }
 }

+ 11 - 0
mybatis-plus-core/src/main/java/com/baomidou/mybatisplus/core/toolkit/StringUtils.java

@@ -226,7 +226,9 @@ public final class StringUtils {
      *
      * @param content 填充内容
      * @param args    填充参数
+     * @deprecated 3.5.12
      */
+    @Deprecated
     public static String sqlArgsFill(String content, Object... args) {
         if (StringUtils.isNotBlank(content) && ArrayUtils.isNotEmpty(args)) {
             // 索引不能使用,因为 SQL 中的占位符数字与索引不相同
@@ -246,7 +248,9 @@ public final class StringUtils {
      * @param ptn      需要替换部分的正则表达式
      * @param replacer 替换处理器
      * @return 返回字符串构建起
+     * @deprecated 3.5.12
      */
+    @Deprecated
     public static StringBuilder replace(CharSequence src, Pattern ptn, BiIntFunction<Matcher, CharSequence> replacer) {
         int idx = 0, last = 0, len = src.length();
         Matcher m = ptn.matcher(src);
@@ -267,7 +271,10 @@ public final class StringUtils {
 
     /**
      * 获取SQL PARAMS字符串
+     *
+     * @deprecated 3.5.12
      */
+    @Deprecated
     public static String sqlParam(Object obj) {
         String repStr;
         if (obj instanceof Collection) {
@@ -283,7 +290,9 @@ public final class StringUtils {
      *
      * @param obj 原字符串
      * @return 单引号包含的原字符串
+     * @deprecated 3.5.12
      */
+    @Deprecated
     public static String quotaMark(Object obj) {
         String srcStr = String.valueOf(obj);
         if (obj instanceof CharSequence) {
@@ -298,7 +307,9 @@ public final class StringUtils {
      *
      * @param coll 集合
      * @return 单引号包含的原字符串的集合形式
+     * @deprecated 3.5.12
      */
+    @Deprecated
     public static String quotaMarkList(Collection<?> coll) {
         return coll.stream().map(StringUtils::quotaMark)
             .collect(joining(StringPool.COMMA, StringPool.LEFT_BRACKET, StringPool.RIGHT_BRACKET));

+ 39 - 9
mybatis-plus-spring/src/main/java/com/baomidou/mybatisplus/extension/toolkit/SqlRunner.java

@@ -19,14 +19,15 @@ import com.baomidou.mybatisplus.core.assist.ISqlRunner;
 import com.baomidou.mybatisplus.core.metadata.IPage;
 import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
 import com.baomidou.mybatisplus.core.toolkit.GlobalConfigUtils;
-import com.baomidou.mybatisplus.core.toolkit.StringUtils;
 import org.apache.ibatis.logging.Log;
 import org.apache.ibatis.logging.LogFactory;
+import org.apache.ibatis.parsing.GenericTokenParser;
 import org.apache.ibatis.session.SqlSession;
 import org.apache.ibatis.session.SqlSessionFactory;
 import org.mybatis.spring.SqlSessionUtils;
 import org.springframework.transaction.annotation.Transactional;
 
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -39,7 +40,8 @@ import java.util.Optional;
  */
 public class SqlRunner implements ISqlRunner {
 
-    private final Log log = LogFactory.getLog(SqlRunner.class);
+    private static final Log LOG = LogFactory.getLog(SqlRunner.class);
+
     // 单例Query
     public static final SqlRunner DEFAULT = new SqlRunner();
 
@@ -100,12 +102,40 @@ public class SqlRunner implements ISqlRunner {
      * @param args 仅支持String
      * @return ignore
      */
-    private Map<String, String> sqlMap(String sql, Object... args) {
-        Map<String, String> sqlMap = CollectionUtils.newHashMapWithExpectedSize(1);
-        sqlMap.put(SQL, StringUtils.sqlArgsFill(sql, args));
+    private Map<String, Object> sqlMap(String sql, Object... args) {
+        Map<String, Object> sqlMap = getParams(args);
+        sqlMap.put(SQL, parse(sql));
         return sqlMap;
     }
 
+    /**
+     * 获取执行语句
+     *
+     * @param sql    原始sql
+     * @return 执行语句
+     */
+    private String parse(String sql) {
+        return new GenericTokenParser("{", "}", content -> "#{" + content + "}").parse(sql);
+    }
+
+    /**
+     * 获取参数列表
+     *
+     * @param args 参数
+     * @return 参数map
+     * @since 3.5.12
+     */
+    private Map<String, Object> getParams(Object... args) {
+        if (args != null) {
+            Map<String, Object> params = CollectionUtils.newHashMapWithExpectedSize(args.length);
+            for (int i = 0; i < args.length; i++) {
+                params.put(String.valueOf(i), args[i]);
+            }
+            return params;
+        }
+        return new HashMap<>();
+    }
+
     /**
      * 获取sqlMap参数
      *
@@ -115,9 +145,9 @@ public class SqlRunner implements ISqlRunner {
      * @return ignore
      */
     private Map<String, Object> sqlMap(String sql, IPage<?> page, Object... args) {
-        Map<String, Object> sqlMap = CollectionUtils.newHashMapWithExpectedSize(2);
+        Map<String, Object> sqlMap = getParams(args);
         sqlMap.put(PAGE, page);
-        sqlMap.put(SQL, StringUtils.sqlArgsFill(sql, args));
+        sqlMap.put(SQL, parse(sql));
         return sqlMap;
     }
 
@@ -178,7 +208,7 @@ public class SqlRunner implements ISqlRunner {
      */
     @Override
     public Object selectObj(String sql, Object... args) {
-        return SqlHelper.getObject(log, selectObjs(sql, args));
+        return SqlHelper.getObject(LOG, selectObjs(sql, args));
     }
 
     @Override
@@ -193,7 +223,7 @@ public class SqlRunner implements ISqlRunner {
 
     @Override
     public Map<String, Object> selectOne(String sql, Object... args) {
-        return SqlHelper.getObject(log, selectList(sql, args));
+        return SqlHelper.getObject(LOG, selectList(sql, args));
     }
 
     @Override

+ 30 - 8
mybatis-plus/src/test/java/com/baomidou/mybatisplus/test/h2/SqlRunnerTest.java

@@ -19,6 +19,7 @@ import java.util.List;
 
 /**
  * SqlRunner测试
+ *
  * @author nieqiurong 2018/8/25 11:05.
  */
 @ExtendWith(SpringExtension.class)
@@ -32,37 +33,38 @@ class SqlRunnerTest {
 
     @Test
     @Order(3)
-    void testSelectCount(){
+    void testSelectCount() {
         long count = SqlRunner.db().selectCount("select count(1) from h2student");
         Assertions.assertTrue(count > 0);
-        count = SqlRunner.db().selectCount("select count(1) from h2student where id > {0}",0);
+        count = SqlRunner.db().selectCount("select count(1) from h2student where id > {0}", 0);
         Assertions.assertTrue(count > 0);
         count = SqlRunner.db(H2Student.class).selectCount("select count(1) from h2student");
         Assertions.assertTrue(count > 0);
-        count = SqlRunner.db(H2Student.class).selectCount("select count(1) from h2student where id > {0}",0);
+        count = SqlRunner.db(H2Student.class).selectCount("select count(1) from h2student where id > {0}", 0);
         Assertions.assertTrue(count > 0);
     }
 
     @Test
     @Transactional
     @Order(1)
-    void testInsert(){
-        Assertions.assertTrue(SqlRunner.db().insert("INSERT INTO h2student ( name, age ) VALUES ( {0}, {1} )","测试学生",2));
-        Assertions.assertTrue(SqlRunner.db(H2Student.class).insert("INSERT INTO h2student ( name, age ) VALUES ( {0}, {1} )","测试学生2",3));
+    void testInsert() {
+        Assertions.assertTrue(SqlRunner.db().insert("INSERT INTO h2student ( name, age ) VALUES ( {0}, {1} )", "测试学生", 2));
+        Assertions.assertTrue(SqlRunner.db(H2Student.class).insert("INSERT INTO h2student ( name, age ) VALUES ( {0}, {1} )", "测试学生2", 3));
     }
 
     @Test
     @Order(2)
-    void testTransactional(){
+    void testTransactional() {
         try {
             studentService.testSqlRunnerTransactional();
-        } catch (RuntimeException e){
+        } catch (RuntimeException e) {
             List<H2Student> list = studentService.list(new QueryWrapper<H2Student>().like("name", "sqlRunnerTx"));
             Assertions.assertTrue(CollectionUtils.isEmpty(list));
         }
     }
 
     @Test
+    @Order(4)
     void testSelectPage() {
         IPage page1 = SqlRunner.db().selectPage(new Page(1, 3), "select * from h2student");
         Assertions.assertEquals(page1.getRecords().size(), 3);
@@ -71,4 +73,24 @@ class SqlRunnerTest {
         IPage page3 = SqlRunner.db().selectPage(new Page(1, 3), "select * from h2student where id = {0}", 10086);
         Assertions.assertEquals(page3.getRecords().size(), 0);
     }
+
+    @Test
+    @Order(5)
+    void testInsertByDisorderParameter() {
+        Assertions.assertTrue(SqlRunner.db().insert("INSERT INTO h2student (id, name, age ) VALUES ( {3}, {2}, {1} )", "测试学生", 2, "'六翻了'", 10000));
+        Assertions.assertTrue(SqlRunner.db(H2Student.class).insert("INSERT INTO h2student ( name, age, id ) VALUES ( {0}, {1}, {2} )", "测试学生2", 3, 10001));
+        Assertions.assertEquals(2, SqlRunner.db().selectCount("select count(1) from h2student where (id = 10000 or id = 10001)"));
+    }
+
+    @Test
+    @Order(6)
+    void testSpecialParameters() {
+        var name = "`测`的'的'\\//塞'2";
+        Assertions.assertTrue(SqlRunner.db().insert("INSERT INTO h2student (id, name, age ) VALUES ( {3}, {0}, {1} )", name, 2, "'六翻了'", 10004));
+        Assertions.assertEquals(10004L, SqlRunner.db().selectObj("select id from h2student where name = {0}", name));
+        name = "`测`的'的'\\//塞'2" + "2";
+        Assertions.assertTrue(SqlRunner.db().update("update h2student set name = {0} where id = {1}", name, 10004L));
+        Assertions.assertEquals(10004L, SqlRunner.db().selectObj("select id from h2student where name = {0}", name));
+    }
+
 }