diff --git a/hutool-db/src/main/java/cn/hutool/v7/db/sql/NamedSql.java b/hutool-db/src/main/java/cn/hutool/v7/db/sql/NamedSql.java index 5ec772b208..a6222878c4 100644 --- a/hutool-db/src/main/java/cn/hutool/v7/db/sql/NamedSql.java +++ b/hutool-db/src/main/java/cn/hutool/v7/db/sql/NamedSql.java @@ -18,7 +18,6 @@ package cn.hutool.v7.db.sql; import cn.hutool.v7.core.array.ArrayUtil; import cn.hutool.v7.core.map.MapUtil; -import cn.hutool.v7.core.text.StrUtil; import java.util.Collection; import java.util.Map; @@ -144,7 +143,7 @@ public class NamedSql extends BoundSql { if (paramMap.containsKey(nameStr)) { // 有变量对应值(值可以为null),替换占位符为?,变量值放入相应index位置 Object paramValue = paramMap.get(nameStr); - if ((paramValue instanceof Collection || ArrayUtil.isArray(paramValue)) && StrUtil.containsIgnoreCase(sqlBuilder, "in")) { + if ((paramValue instanceof Collection || ArrayUtil.isArray(paramValue)) && SqlUtil.isInClause(sqlBuilder)) { if (paramValue instanceof Collection) { // 转为数组 paramValue = ((Collection) paramValue).toArray(); diff --git a/hutool-db/src/main/java/cn/hutool/v7/db/sql/SqlUtil.java b/hutool-db/src/main/java/cn/hutool/v7/db/sql/SqlUtil.java index 405f862682..04ee76ab06 100644 --- a/hutool-db/src/main/java/cn/hutool/v7/db/sql/SqlUtil.java +++ b/hutool-db/src/main/java/cn/hutool/v7/db/sql/SqlUtil.java @@ -46,6 +46,10 @@ public class SqlUtil { * 创建SQL中的order by语句的正则 */ private static final Pattern PATTERN_ORDER_BY = PatternPool.get("(.*)\\s+order\\s+by\\s+[^\\s]+", Pattern.CASE_INSENSITIVE); + /** + * SQL中的in语句部分的正则 + */ + private static final Pattern PATTERN_IN_CLAUSE = PatternPool.get("\\s+in\\s+[(]", Pattern.CASE_INSENSITIVE); /** * 构件相等条件的where语句
@@ -283,4 +287,15 @@ public class SqlUtil { // 去除order by 子句 return ReUtil.getGroup1(PATTERN_ORDER_BY, selectSql); } + + /** + * 判断当前上下文是否在 IN 子句中 + * 通过检查变量前的SQL文本,判断是否符合 IN 子句的模式 + * + * @param sql 当前已构建的SQL + * @return 是否在 IN 子句中 + */ + public static boolean isInClause(final CharSequence sql) { + return ReUtil.contains(PATTERN_IN_CLAUSE, sql); + } } diff --git a/hutool-db/src/test/java/cn/hutool/v7/db/NamedSqlTest.java b/hutool-db/src/test/java/cn/hutool/v7/db/NamedSqlTest.java index 580198d8ba..5ffbe434ba 100644 --- a/hutool-db/src/test/java/cn/hutool/v7/db/NamedSqlTest.java +++ b/hutool-db/src/test/java/cn/hutool/v7/db/NamedSqlTest.java @@ -25,6 +25,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + public class NamedSqlTest { @Test @@ -39,9 +42,9 @@ public class NamedSqlTest { final NamedSql namedSql = new NamedSql(sql, paramMap); //未指定参数原样输出 - Assertions.assertEquals("select * from table where id=@id and name = ? and nickName = ?", namedSql.getSql()); - Assertions.assertEquals("张三", namedSql.getParamArray()[0]); - Assertions.assertEquals("小豆豆", namedSql.getParamArray()[1]); + assertEquals("select * from table where id=@id and name = ? and nickName = ?", namedSql.getSql()); + assertEquals("张三", namedSql.getParamArray()[0]); + assertEquals("小豆豆", namedSql.getParamArray()[1]); } @Test @@ -56,11 +59,11 @@ public class NamedSqlTest { .build(); final NamedSql namedSql = new NamedSql(sql, paramMap); - Assertions.assertEquals("select * from table where id=? and name = ? and nickName = ?", namedSql.getSql()); + assertEquals("select * from table where id=? and name = ? and nickName = ?", namedSql.getSql()); //指定了null参数的依旧替换,参数值为null Assertions.assertNull(namedSql.getParamArray()[0]); - Assertions.assertEquals("张三", namedSql.getParamArray()[1]); - Assertions.assertEquals("小豆豆", namedSql.getParamArray()[2]); + assertEquals("张三", namedSql.getParamArray()[1]); + assertEquals("小豆豆", namedSql.getParamArray()[2]); } @Test @@ -73,7 +76,7 @@ public class NamedSqlTest { .build(); final NamedSql namedSql = new NamedSql(sql, paramMap); - Assertions.assertEquals(sql, namedSql.getSql()); + assertEquals(sql, namedSql.getSql()); } @Test @@ -86,7 +89,7 @@ public class NamedSqlTest { .build(); final NamedSql namedSql = new NamedSql(sql, paramMap); - Assertions.assertEquals(sql, namedSql.getSql()); + assertEquals(sql, namedSql.getSql()); } @Test @@ -95,10 +98,10 @@ public class NamedSqlTest { final HashMap paramMap = MapUtil.of("ids", new int[]{1, 2, 3}); final NamedSql namedSql = new NamedSql(sql, paramMap); - Assertions.assertEquals("select * from user where id in (?,?,?)", namedSql.getSql()); - Assertions.assertEquals(1, namedSql.getParamArray()[0]); - Assertions.assertEquals(2, namedSql.getParamArray()[1]); - Assertions.assertEquals(3, namedSql.getParamArray()[2]); + assertEquals("select * from user where id in (?,?,?)", namedSql.getSql()); + assertEquals(1, namedSql.getParamArray()[0]); + assertEquals(2, namedSql.getParamArray()[1]); + assertEquals(3, namedSql.getParamArray()[2]); } @Test @@ -109,10 +112,34 @@ public class NamedSqlTest { final String sql = "select * from user where name = @name1 and age = @age1"; List query = Db.of().query(sql, paramMap); - Assertions.assertEquals(1, query.size()); + assertEquals(1, query.size()); // 采用传统方式查询是否能识别Map类型参数 query = Db.of().query(sql, new Object[]{paramMap}); - Assertions.assertEquals(1, query.size()); + assertEquals(1, query.size()); + } + + @Test + public void parseInTest2() { + // 测试表名包含"in"但不是IN子句的情况 + final String sql = "select * from information where info_data = :info"; + final HashMap paramMap = MapUtil.of("info", new int[]{10, 20}); + + final NamedSql namedSql = new NamedSql(sql, paramMap); + // sql语句不包含IN子句,不会展开数组 + assertEquals("select * from information where info_data = ?", namedSql.getSql()); + assertArrayEquals(new int[]{10, 20}, (int[]) namedSql.getParamArray()[0]); + } + + @Test + public void parseInTest3() { + // 测试字符串中包含"in"关键字但不是IN子句的情况 + final String sql = "select * from user where comment = 'include in text' and id = :id"; + final HashMap paramMap = MapUtil.of("id", new int[]{5, 6}); + + final NamedSql namedSql = new NamedSql(sql, paramMap); + // sql语句不包含IN子句,不会展开数组 + assertEquals("select * from user where comment = 'include in text' and id = ?", namedSql.getSql()); + assertArrayEquals(new int[]{5, 6}, (int[]) namedSql.getParamArray()[0]); } }