fix splitting multiple SQLs

This commit is contained in:
chrislu
2025-09-03 17:47:24 -07:00
parent ea758d0d9f
commit 1db1206827
2 changed files with 253 additions and 23 deletions

View File

@@ -18,43 +18,124 @@ import (
)
// splitSQLStatements splits a query string into individual SQL statements
// This is a simple implementation that splits on semicolons outside of quoted strings
// This robust implementation handles SQL comments, quoted strings, and escaped characters
func splitSQLStatements(query string) []string {
var statements []string
var current strings.Builder
inSingleQuote := false
inDoubleQuote := false
query = strings.TrimSpace(query)
if query == "" {
return []string{}
}
for _, char := range query {
switch char {
case '\'':
if !inDoubleQuote {
inSingleQuote = !inSingleQuote
runes := []rune(query)
i := 0
for i < len(runes) {
char := runes[i]
// Handle single-line comments (-- comment)
if char == '-' && i+1 < len(runes) && runes[i+1] == '-' {
// Include the entire comment in the current statement
for i < len(runes) && runes[i] != '\n' && runes[i] != '\r' {
current.WriteRune(runes[i])
i++
}
current.WriteRune(char)
case '"':
if !inSingleQuote {
inDoubleQuote = !inDoubleQuote
// Include the newline if present
if i < len(runes) {
current.WriteRune(runes[i])
i++
}
current.WriteRune(char)
case ';':
if !inSingleQuote && !inDoubleQuote {
stmt := strings.TrimSpace(current.String())
if stmt != "" {
statements = append(statements, stmt)
continue
}
// Handle multi-line comments (/* comment */)
if char == '/' && i+1 < len(runes) && runes[i+1] == '*' {
current.WriteRune(char) // Include the /*
i++
if i < len(runes) {
current.WriteRune(runes[i])
i++
}
// Skip to end of comment or end of input
for i < len(runes) {
current.WriteRune(runes[i])
if runes[i] == '*' && i+1 < len(runes) && runes[i+1] == '/' {
i++
current.WriteRune(runes[i]) // Include the */
i++
break
}
current.Reset()
} else {
current.WriteRune(char)
i++
}
default:
continue
}
// Handle single-quoted strings
if char == '\'' {
current.WriteRune(char)
i++
for i < len(runes) {
char = runes[i]
current.WriteRune(char)
if char == '\'' {
// Check if it's an escaped quote
if i+1 < len(runes) && runes[i+1] == '\'' {
i++ // Skip the next quote (it's escaped)
if i < len(runes) {
current.WriteRune(runes[i])
}
} else {
break // End of string
}
}
i++
}
i++
continue
}
// Handle double-quoted identifiers
if char == '"' {
current.WriteRune(char)
i++
for i < len(runes) {
char = runes[i]
current.WriteRune(char)
if char == '"' {
// Check if it's an escaped quote
if i+1 < len(runes) && runes[i+1] == '"' {
i++ // Skip the next quote (it's escaped)
if i < len(runes) {
current.WriteRune(runes[i])
}
} else {
break // End of identifier
}
}
i++
}
i++
continue
}
// Handle semicolon (statement separator)
if char == ';' {
stmt := strings.TrimSpace(current.String())
if stmt != "" {
statements = append(statements, stmt)
}
current.Reset()
} else {
current.WriteRune(char)
}
i++
}
// Add any remaining statement

149
weed/command/sql_test.go Normal file
View File

@@ -0,0 +1,149 @@
package command
import (
"reflect"
"testing"
)
func TestSplitSQLStatements(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{
name: "Simple single statement",
input: "SELECT * FROM users",
expected: []string{"SELECT * FROM users"},
},
{
name: "Multiple statements",
input: "SELECT * FROM users; SELECT * FROM orders;",
expected: []string{"SELECT * FROM users", "SELECT * FROM orders"},
},
{
name: "Semicolon in single quotes",
input: "SELECT 'hello;world' FROM users; SELECT * FROM orders;",
expected: []string{"SELECT 'hello;world' FROM users", "SELECT * FROM orders"},
},
{
name: "Semicolon in double quotes",
input: `SELECT "column;name" FROM users; SELECT * FROM orders;`,
expected: []string{`SELECT "column;name" FROM users`, "SELECT * FROM orders"},
},
{
name: "Escaped quotes in strings",
input: `SELECT 'don''t split; here' FROM users; SELECT * FROM orders;`,
expected: []string{`SELECT 'don''t split; here' FROM users`, "SELECT * FROM orders"},
},
{
name: "Escaped quotes in identifiers",
input: `SELECT "column""name" FROM users; SELECT * FROM orders;`,
expected: []string{`SELECT "column""name" FROM users`, "SELECT * FROM orders"},
},
{
name: "Single line comment",
input: "SELECT * FROM users; -- This is a comment\nSELECT * FROM orders;",
expected: []string{"SELECT * FROM users", "-- This is a comment\nSELECT * FROM orders"},
},
{
name: "Single line comment with semicolon",
input: "SELECT * FROM users; -- Comment with; semicolon\nSELECT * FROM orders;",
expected: []string{"SELECT * FROM users", "-- Comment with; semicolon\nSELECT * FROM orders"},
},
{
name: "Multi-line comment",
input: "SELECT * FROM users; /* Multi-line\ncomment */ SELECT * FROM orders;",
expected: []string{"SELECT * FROM users", "/* Multi-line\ncomment */ SELECT * FROM orders"},
},
{
name: "Multi-line comment with semicolon",
input: "SELECT * FROM users; /* Comment with; semicolon */ SELECT * FROM orders;",
expected: []string{"SELECT * FROM users", "/* Comment with; semicolon */ SELECT * FROM orders"},
},
{
name: "Complex mixed case",
input: `SELECT 'test;string', "quoted;id" FROM users; -- Comment; here
/* Another; comment */
INSERT INTO users VALUES ('name''s value', "id""field");`,
expected: []string{
`SELECT 'test;string', "quoted;id" FROM users`,
`-- Comment; here
/* Another; comment */
INSERT INTO users VALUES ('name''s value', "id""field")`,
},
},
{
name: "Empty statements filtered",
input: "SELECT * FROM users;;; SELECT * FROM orders;",
expected: []string{"SELECT * FROM users", "SELECT * FROM orders"},
},
{
name: "Whitespace handling",
input: " SELECT * FROM users ; SELECT * FROM orders ; ",
expected: []string{"SELECT * FROM users", "SELECT * FROM orders"},
},
{
name: "Single statement without semicolon",
input: "SELECT * FROM users",
expected: []string{"SELECT * FROM users"},
},
{
name: "Empty query",
input: "",
expected: []string{},
},
{
name: "Only whitespace",
input: " \n\t ",
expected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := splitSQLStatements(tt.input)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("splitSQLStatements() = %v, expected %v", result, tt.expected)
}
})
}
}
func TestSplitSQLStatements_EdgeCases(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{
name: "Nested comments are not supported but handled gracefully",
input: "SELECT * FROM users; /* Outer /* inner */ comment */ SELECT * FROM orders;",
expected: []string{"SELECT * FROM users", "/* Outer /* inner */ comment */ SELECT * FROM orders"},
},
{
name: "Unterminated string (malformed SQL)",
input: "SELECT 'unterminated string; SELECT * FROM orders;",
expected: []string{"SELECT 'unterminated string; SELECT * FROM orders;"},
},
{
name: "Unterminated comment (malformed SQL)",
input: "SELECT * FROM users; /* unterminated comment",
expected: []string{"SELECT * FROM users", "/* unterminated comment"},
},
{
name: "Multiple semicolons in quotes",
input: "SELECT ';;;' FROM users; SELECT ';;;' FROM orders;",
expected: []string{"SELECT ';;;' FROM users", "SELECT ';;;' FROM orders"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := splitSQLStatements(tt.input)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("splitSQLStatements() = %v, expected %v", result, tt.expected)
}
})
}
}