From 1db1206827f3baeb416e5ce6e6626c3c1452aa17 Mon Sep 17 00:00:00 2001 From: chrislu Date: Wed, 3 Sep 2025 17:47:24 -0700 Subject: [PATCH] fix splitting multiple SQLs --- weed/command/sql.go | 127 +++++++++++++++++++++++++++------ weed/command/sql_test.go | 149 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 253 insertions(+), 23 deletions(-) create mode 100644 weed/command/sql_test.go diff --git a/weed/command/sql.go b/weed/command/sql.go index 7e10234c7..928e0c4f8 100644 --- a/weed/command/sql.go +++ b/weed/command/sql.go @@ -18,43 +18,124 @@ import ( ) // splitSQLStatements splits a query string into individual SQL statements -// This is a simple implementation that splits on semicolons outside of quoted strings +// This robust implementation handles SQL comments, quoted strings, and escaped characters func splitSQLStatements(query string) []string { var statements []string var current strings.Builder - inSingleQuote := false - inDoubleQuote := false - + query = strings.TrimSpace(query) if query == "" { return []string{} } - for _, char := range query { - switch char { - case '\'': - if !inDoubleQuote { - inSingleQuote = !inSingleQuote + runes := []rune(query) + i := 0 + + for i < len(runes) { + char := runes[i] + + // Handle single-line comments (-- comment) + if char == '-' && i+1 < len(runes) && runes[i+1] == '-' { + // Include the entire comment in the current statement + for i < len(runes) && runes[i] != '\n' && runes[i] != '\r' { + current.WriteRune(runes[i]) + i++ } - current.WriteRune(char) - case '"': - if !inSingleQuote { - inDoubleQuote = !inDoubleQuote + // Include the newline if present + if i < len(runes) { + current.WriteRune(runes[i]) + i++ } - current.WriteRune(char) - case ';': - if !inSingleQuote && !inDoubleQuote { - stmt := strings.TrimSpace(current.String()) - if stmt != "" { - statements = append(statements, stmt) + continue + } + + // Handle multi-line comments (/* comment */) + if char == '/' && i+1 < len(runes) && runes[i+1] == '*' { + current.WriteRune(char) // Include the /* + i++ + if i < len(runes) { + current.WriteRune(runes[i]) + i++ + } + + // Skip to end of comment or end of input + for i < len(runes) { + current.WriteRune(runes[i]) + if runes[i] == '*' && i+1 < len(runes) && runes[i+1] == '/' { + i++ + current.WriteRune(runes[i]) // Include the */ + i++ + break } - current.Reset() - } else { - current.WriteRune(char) + i++ } - default: + continue + } + + // Handle single-quoted strings + if char == '\'' { + current.WriteRune(char) + i++ + + for i < len(runes) { + char = runes[i] + current.WriteRune(char) + + if char == '\'' { + // Check if it's an escaped quote + if i+1 < len(runes) && runes[i+1] == '\'' { + i++ // Skip the next quote (it's escaped) + if i < len(runes) { + current.WriteRune(runes[i]) + } + } else { + break // End of string + } + } + i++ + } + i++ + continue + } + + // Handle double-quoted identifiers + if char == '"' { + current.WriteRune(char) + i++ + + for i < len(runes) { + char = runes[i] + current.WriteRune(char) + + if char == '"' { + // Check if it's an escaped quote + if i+1 < len(runes) && runes[i+1] == '"' { + i++ // Skip the next quote (it's escaped) + if i < len(runes) { + current.WriteRune(runes[i]) + } + } else { + break // End of identifier + } + } + i++ + } + i++ + continue + } + + // Handle semicolon (statement separator) + if char == ';' { + stmt := strings.TrimSpace(current.String()) + if stmt != "" { + statements = append(statements, stmt) + } + current.Reset() + } else { current.WriteRune(char) } + + i++ } // Add any remaining statement diff --git a/weed/command/sql_test.go b/weed/command/sql_test.go new file mode 100644 index 000000000..0d15e01e0 --- /dev/null +++ b/weed/command/sql_test.go @@ -0,0 +1,149 @@ +package command + +import ( + "reflect" + "testing" +) + +func TestSplitSQLStatements(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "Simple single statement", + input: "SELECT * FROM users", + expected: []string{"SELECT * FROM users"}, + }, + { + name: "Multiple statements", + input: "SELECT * FROM users; SELECT * FROM orders;", + expected: []string{"SELECT * FROM users", "SELECT * FROM orders"}, + }, + { + name: "Semicolon in single quotes", + input: "SELECT 'hello;world' FROM users; SELECT * FROM orders;", + expected: []string{"SELECT 'hello;world' FROM users", "SELECT * FROM orders"}, + }, + { + name: "Semicolon in double quotes", + input: `SELECT "column;name" FROM users; SELECT * FROM orders;`, + expected: []string{`SELECT "column;name" FROM users`, "SELECT * FROM orders"}, + }, + { + name: "Escaped quotes in strings", + input: `SELECT 'don''t split; here' FROM users; SELECT * FROM orders;`, + expected: []string{`SELECT 'don''t split; here' FROM users`, "SELECT * FROM orders"}, + }, + { + name: "Escaped quotes in identifiers", + input: `SELECT "column""name" FROM users; SELECT * FROM orders;`, + expected: []string{`SELECT "column""name" FROM users`, "SELECT * FROM orders"}, + }, + { + name: "Single line comment", + input: "SELECT * FROM users; -- This is a comment\nSELECT * FROM orders;", + expected: []string{"SELECT * FROM users", "-- This is a comment\nSELECT * FROM orders"}, + }, + { + name: "Single line comment with semicolon", + input: "SELECT * FROM users; -- Comment with; semicolon\nSELECT * FROM orders;", + expected: []string{"SELECT * FROM users", "-- Comment with; semicolon\nSELECT * FROM orders"}, + }, + { + name: "Multi-line comment", + input: "SELECT * FROM users; /* Multi-line\ncomment */ SELECT * FROM orders;", + expected: []string{"SELECT * FROM users", "/* Multi-line\ncomment */ SELECT * FROM orders"}, + }, + { + name: "Multi-line comment with semicolon", + input: "SELECT * FROM users; /* Comment with; semicolon */ SELECT * FROM orders;", + expected: []string{"SELECT * FROM users", "/* Comment with; semicolon */ SELECT * FROM orders"}, + }, + { + name: "Complex mixed case", + input: `SELECT 'test;string', "quoted;id" FROM users; -- Comment; here + /* Another; comment */ + INSERT INTO users VALUES ('name''s value', "id""field");`, + expected: []string{ + `SELECT 'test;string', "quoted;id" FROM users`, + `-- Comment; here + /* Another; comment */ + INSERT INTO users VALUES ('name''s value', "id""field")`, + }, + }, + { + name: "Empty statements filtered", + input: "SELECT * FROM users;;; SELECT * FROM orders;", + expected: []string{"SELECT * FROM users", "SELECT * FROM orders"}, + }, + { + name: "Whitespace handling", + input: " SELECT * FROM users ; SELECT * FROM orders ; ", + expected: []string{"SELECT * FROM users", "SELECT * FROM orders"}, + }, + { + name: "Single statement without semicolon", + input: "SELECT * FROM users", + expected: []string{"SELECT * FROM users"}, + }, + { + name: "Empty query", + input: "", + expected: []string{}, + }, + { + name: "Only whitespace", + input: " \n\t ", + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := splitSQLStatements(tt.input) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("splitSQLStatements() = %v, expected %v", result, tt.expected) + } + }) + } +} + +func TestSplitSQLStatements_EdgeCases(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "Nested comments are not supported but handled gracefully", + input: "SELECT * FROM users; /* Outer /* inner */ comment */ SELECT * FROM orders;", + expected: []string{"SELECT * FROM users", "/* Outer /* inner */ comment */ SELECT * FROM orders"}, + }, + { + name: "Unterminated string (malformed SQL)", + input: "SELECT 'unterminated string; SELECT * FROM orders;", + expected: []string{"SELECT 'unterminated string; SELECT * FROM orders;"}, + }, + { + name: "Unterminated comment (malformed SQL)", + input: "SELECT * FROM users; /* unterminated comment", + expected: []string{"SELECT * FROM users", "/* unterminated comment"}, + }, + { + name: "Multiple semicolons in quotes", + input: "SELECT ';;;' FROM users; SELECT ';;;' FROM orders;", + expected: []string{"SELECT ';;;' FROM users", "SELECT ';;;' FROM orders"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := splitSQLStatements(tt.input) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("splitSQLStatements() = %v, expected %v", result, tt.expected) + } + }) + } +}