support string concatenation ||

This commit is contained in:
chrislu
2025-09-04 16:38:08 -07:00
parent e528629944
commit 279abda3b7
2 changed files with 486 additions and 3 deletions

View File

@@ -0,0 +1,245 @@
package engine
import (
"testing"
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
)
func TestArithmeticExpressionParsing(t *testing.T) {
tests := []struct {
name string
expression string
expectNil bool
leftCol string
rightCol string
operator string
}{
{
name: "simple addition",
expression: "id+user_id",
expectNil: false,
leftCol: "id",
rightCol: "user_id",
operator: "+",
},
{
name: "simple subtraction",
expression: "col1-col2",
expectNil: false,
leftCol: "col1",
rightCol: "col2",
operator: "-",
},
{
name: "multiplication with spaces",
expression: "a * b",
expectNil: false,
leftCol: "a",
rightCol: "b",
operator: "*",
},
{
name: "string concatenation",
expression: "first_name||last_name",
expectNil: false,
leftCol: "first_name",
rightCol: "last_name",
operator: "||",
},
{
name: "string concatenation with spaces",
expression: "prefix || suffix",
expectNil: false,
leftCol: "prefix",
rightCol: "suffix",
operator: "||",
},
{
name: "not arithmetic",
expression: "simple_column",
expectNil: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := parseArithmeticExpression(tt.expression)
if tt.expectNil {
if result != nil {
t.Errorf("Expected nil for %s, got %v", tt.expression, result)
}
return
}
if result == nil {
t.Errorf("Expected arithmetic expression for %s, got nil", tt.expression)
return
}
if result.Operator != tt.operator {
t.Errorf("Expected operator %s, got %s", tt.operator, result.Operator)
}
// Check left operand
if leftCol, ok := result.Left.(*ColName); ok {
if leftCol.Name.String() != tt.leftCol {
t.Errorf("Expected left column %s, got %s", tt.leftCol, leftCol.Name.String())
}
} else {
t.Errorf("Expected left operand to be ColName, got %T", result.Left)
}
// Check right operand
if rightCol, ok := result.Right.(*ColName); ok {
if rightCol.Name.String() != tt.rightCol {
t.Errorf("Expected right column %s, got %s", tt.rightCol, rightCol.Name.String())
}
} else {
t.Errorf("Expected right operand to be ColName, got %T", result.Right)
}
})
}
}
func TestArithmeticExpressionEvaluation(t *testing.T) {
engine := NewSQLEngine("")
// Create test data
result := HybridScanResult{
Values: map[string]*schema_pb.Value{
"id": {Kind: &schema_pb.Value_Int64Value{Int64Value: 10}},
"user_id": {Kind: &schema_pb.Value_Int64Value{Int64Value: 5}},
"price": {Kind: &schema_pb.Value_DoubleValue{DoubleValue: 25.5}},
"qty": {Kind: &schema_pb.Value_Int64Value{Int64Value: 3}},
"first_name": {Kind: &schema_pb.Value_StringValue{StringValue: "John"}},
"last_name": {Kind: &schema_pb.Value_StringValue{StringValue: "Doe"}},
"prefix": {Kind: &schema_pb.Value_StringValue{StringValue: "Hello"}},
"suffix": {Kind: &schema_pb.Value_StringValue{StringValue: "World"}},
},
}
tests := []struct {
name string
expression string
expected interface{}
}{
{
name: "integer addition",
expression: "id+user_id",
expected: int64(15),
},
{
name: "integer subtraction",
expression: "id-user_id",
expected: int64(5),
},
{
name: "mixed types multiplication",
expression: "price*qty",
expected: float64(76.5),
},
{
name: "string concatenation",
expression: "first_name||last_name",
expected: "JohnDoe",
},
{
name: "string concatenation with spaces",
expression: "prefix || suffix",
expected: "HelloWorld",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Parse the arithmetic expression
arithmeticExpr := parseArithmeticExpression(tt.expression)
if arithmeticExpr == nil {
t.Fatalf("Failed to parse arithmetic expression: %s", tt.expression)
}
// Evaluate the expression
value, err := engine.evaluateArithmeticExpression(arithmeticExpr, result)
if err != nil {
t.Fatalf("Failed to evaluate expression: %v", err)
}
if value == nil {
t.Fatalf("Got nil value for expression: %s", tt.expression)
}
// Check the result
switch expected := tt.expected.(type) {
case int64:
if intVal, ok := value.Kind.(*schema_pb.Value_Int64Value); ok {
if intVal.Int64Value != expected {
t.Errorf("Expected %d, got %d", expected, intVal.Int64Value)
}
} else {
t.Errorf("Expected int64 result, got %T", value.Kind)
}
case float64:
if doubleVal, ok := value.Kind.(*schema_pb.Value_DoubleValue); ok {
if doubleVal.DoubleValue != expected {
t.Errorf("Expected %f, got %f", expected, doubleVal.DoubleValue)
}
} else {
t.Errorf("Expected double result, got %T", value.Kind)
}
case string:
if stringVal, ok := value.Kind.(*schema_pb.Value_StringValue); ok {
if stringVal.StringValue != expected {
t.Errorf("Expected %s, got %s", expected, stringVal.StringValue)
}
} else {
t.Errorf("Expected string result, got %T", value.Kind)
}
}
})
}
}
func TestSelectArithmeticExpression(t *testing.T) {
// Test parsing a SELECT with arithmetic and string concatenation expressions
stmt, err := ParseSQL("SELECT id+user_id, user_id*2, first_name||last_name FROM test_table")
if err != nil {
t.Fatalf("Failed to parse SQL: %v", err)
}
selectStmt := stmt.(*SelectStatement)
if len(selectStmt.SelectExprs) != 3 {
t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs))
}
// Check first expression (id+user_id)
aliasedExpr1 := selectStmt.SelectExprs[0].(*AliasedExpr)
if arithmeticExpr1, ok := aliasedExpr1.Expr.(*ArithmeticExpr); ok {
if arithmeticExpr1.Operator != "+" {
t.Errorf("Expected + operator, got %s", arithmeticExpr1.Operator)
}
} else {
t.Errorf("Expected arithmetic expression, got %T", aliasedExpr1.Expr)
}
// Check second expression (user_id*2)
aliasedExpr2 := selectStmt.SelectExprs[1].(*AliasedExpr)
if arithmeticExpr2, ok := aliasedExpr2.Expr.(*ArithmeticExpr); ok {
if arithmeticExpr2.Operator != "*" {
t.Errorf("Expected * operator, got %s", arithmeticExpr2.Operator)
}
} else {
t.Errorf("Expected arithmetic expression, got %T", aliasedExpr2.Expr)
}
// Check third expression (first_name||last_name)
aliasedExpr3 := selectStmt.SelectExprs[2].(*AliasedExpr)
if arithmeticExpr3, ok := aliasedExpr3.Expr.(*ArithmeticExpr); ok {
if arithmeticExpr3.Operator != "||" {
t.Errorf("Expected || operator, got %s", arithmeticExpr3.Operator)
}
} else {
t.Errorf("Expected string concatenation expression, got %T", aliasedExpr3.Expr)
}
}

View File

@@ -164,6 +164,15 @@ type ColName struct {
func (c *ColName) isExprNode() {}
// ArithmeticExpr represents arithmetic operations like id+user_id and string concatenation like name||suffix
type ArithmeticExpr struct {
Left ExprNode
Right ExprNode
Operator string // +, -, *, /, %, ||
}
func (a *ArithmeticExpr) isExprNode() {}
type ComparisonExpr struct {
Left ExprNode
Right ExprNode
@@ -312,7 +321,7 @@ func parseSelectStatement(sql string) (*SelectStatement, error) {
if part == "*" {
s.SelectExprs = append(s.SelectExprs, &StarExpr{})
} else {
// Handle column names and functions
// Handle column names, functions, and arithmetic expressions
expr := &AliasedExpr{}
if strings.Contains(strings.ToUpper(part), "(") && strings.Contains(part, ")") {
// Function expression
@@ -326,6 +335,9 @@ func parseSelectStatement(sql string) (*SelectStatement, error) {
Exprs: funcArgs,
}
expr.Expr = funcExpr
} else if arithmeticExpr := parseArithmeticExpression(part); arithmeticExpr != nil {
// Arithmetic expression (id+user_id, col1-col2, etc.)
expr.Expr = arithmeticExpr
} else {
// Column name
colExpr := &ColName{Name: stringValue(part)}
@@ -438,6 +450,64 @@ func extractFunctionName(expr string) string {
return strings.TrimSpace(expr[:parenIdx])
}
// parseArithmeticExpression parses arithmetic expressions like id+user_id, col1*col2, etc.
func parseArithmeticExpression(expr string) *ArithmeticExpr {
// Remove spaces for easier parsing
expr = strings.ReplaceAll(expr, " ", "")
// Check for arithmetic and string operators (order matters for precedence)
// String concatenation (||) has lower precedence than arithmetic operators
operators := []string{"||", "+", "-", "*", "/", "%"}
for _, op := range operators {
// Find the operator position (skip operators inside parentheses)
opPos := -1
parenLevel := 0
for i, char := range expr {
if char == '(' {
parenLevel++
} else if char == ')' {
parenLevel--
} else if parenLevel == 0 && strings.HasPrefix(expr[i:], op) {
opPos = i
break
}
}
if opPos > 0 && opPos < len(expr)-len(op) {
leftExpr := strings.TrimSpace(expr[:opPos])
rightExpr := strings.TrimSpace(expr[opPos+len(op):])
if leftExpr != "" && rightExpr != "" {
// Create left and right expressions (recursively handle complex expressions)
var left, right ExprNode
// Parse left side
if leftArithmetic := parseArithmeticExpression(leftExpr); leftArithmetic != nil {
left = leftArithmetic
} else {
left = &ColName{Name: stringValue(leftExpr)}
}
// Parse right side
if rightArithmetic := parseArithmeticExpression(rightExpr); rightArithmetic != nil {
right = rightArithmetic
} else {
right = &ColName{Name: stringValue(rightExpr)}
}
return &ArithmeticExpr{
Left: left,
Right: right,
Operator: op,
}
}
}
}
return nil
}
// extractFunctionArguments extracts the arguments from a function call expression
func extractFunctionArguments(expr string) ([]SelectExpr, error) {
// Find the parentheses
@@ -1383,6 +1453,9 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *SelectStat
switch col := expr.Expr.(type) {
case *ColName:
columns = append(columns, col.Name.String())
case *ArithmeticExpr:
// Handle arithmetic expressions like id+user_id and string concatenation like name||suffix
columns = append(columns, e.getArithmeticExpressionAlias(col))
case *FuncExpr:
// Handle aggregation functions
aggSpec, err := e.parseAggregationFunction(col, expr)
@@ -1482,9 +1555,11 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *SelectStat
// Convert to SQL result format
if selectAll {
columns = nil // Let converter determine all columns
return hybridScanner.ConvertToSQLResult(results, columns), nil
}
return hybridScanner.ConvertToSQLResult(results, columns), nil
// Handle custom column expressions (including arithmetic)
return e.ConvertToSQLResultWithExpressions(hybridScanner, results, stmt.SelectExprs), nil
}
// executeSelectStatementWithBrokerStats handles SELECT queries with broker buffer statistics capture
@@ -1573,6 +1648,9 @@ func (e *SQLEngine) executeSelectStatementWithBrokerStats(ctx context.Context, s
switch col := expr.Expr.(type) {
case *ColName:
columns = append(columns, col.Name.String())
case *ArithmeticExpr:
// Handle arithmetic expressions like id+user_id and string concatenation like name||suffix
columns = append(columns, e.getArithmeticExpressionAlias(col))
case *FuncExpr:
// Handle aggregation functions
aggSpec, err := e.parseAggregationFunction(col, expr)
@@ -1705,9 +1783,11 @@ func (e *SQLEngine) executeSelectStatementWithBrokerStats(ctx context.Context, s
// Convert to SQL result format
if selectAll {
columns = nil // Let converter determine all columns
return hybridScanner.ConvertToSQLResult(results, columns), nil
}
return hybridScanner.ConvertToSQLResult(results, columns), nil
// Handle custom column expressions (including arithmetic)
return e.ConvertToSQLResultWithExpressions(hybridScanner, results, stmt.SelectExprs), nil
}
// extractTimeFilters extracts time range filters from WHERE clause for optimization
@@ -3347,3 +3427,161 @@ func (e *SQLEngine) discoverAndRegisterTopic(ctx context.Context, database, tabl
// Note: This is a discovery operation, not query execution, so it's okay to always log
return nil
}
// getArithmeticExpressionAlias generates a display alias for arithmetic expressions
func (e *SQLEngine) getArithmeticExpressionAlias(expr *ArithmeticExpr) string {
leftAlias := e.getExpressionAlias(expr.Left)
rightAlias := e.getExpressionAlias(expr.Right)
return leftAlias + expr.Operator + rightAlias
}
// getExpressionAlias generates an alias for any expression node
func (e *SQLEngine) getExpressionAlias(expr ExprNode) string {
switch exprType := expr.(type) {
case *ColName:
return exprType.Name.String()
case *ArithmeticExpr:
return e.getArithmeticExpressionAlias(exprType)
default:
return "expr"
}
}
// evaluateArithmeticExpression evaluates an arithmetic expression for a given record
func (e *SQLEngine) evaluateArithmeticExpression(expr *ArithmeticExpr, result HybridScanResult) (*schema_pb.Value, error) {
// Get left operand value
leftValue, err := e.evaluateExpressionValue(expr.Left, result)
if err != nil {
return nil, fmt.Errorf("error evaluating left operand: %v", err)
}
// Get right operand value
rightValue, err := e.evaluateExpressionValue(expr.Right, result)
if err != nil {
return nil, fmt.Errorf("error evaluating right operand: %v", err)
}
// Handle string concatenation operator
if expr.Operator == "||" {
return e.Concat(leftValue, rightValue)
}
// Perform arithmetic operation
var op ArithmeticOperator
switch expr.Operator {
case "+":
op = OpAdd
case "-":
op = OpSub
case "*":
op = OpMul
case "/":
op = OpDiv
case "%":
op = OpMod
default:
return nil, fmt.Errorf("unsupported arithmetic operator: %s", expr.Operator)
}
return e.EvaluateArithmeticExpression(leftValue, rightValue, op)
}
// evaluateExpressionValue evaluates any expression to get its value from a record
func (e *SQLEngine) evaluateExpressionValue(expr ExprNode, result HybridScanResult) (*schema_pb.Value, error) {
switch exprType := expr.(type) {
case *ColName:
columnName := exprType.Name.String()
value := e.findColumnValue(result, columnName)
if value == nil {
return nil, nil
}
return value, nil
case *ArithmeticExpr:
return e.evaluateArithmeticExpression(exprType, result)
default:
return nil, fmt.Errorf("unsupported expression type: %T", expr)
}
}
// ConvertToSQLResultWithExpressions converts HybridScanResults to SQL query results with expression evaluation
func (e *SQLEngine) ConvertToSQLResultWithExpressions(hms *HybridMessageScanner, results []HybridScanResult, selectExprs []SelectExpr) *QueryResult {
if len(results) == 0 {
columns := make([]string, 0, len(selectExprs))
for _, selectExpr := range selectExprs {
switch expr := selectExpr.(type) {
case *AliasedExpr:
switch col := expr.Expr.(type) {
case *ColName:
columns = append(columns, col.Name.String())
case *ArithmeticExpr:
columns = append(columns, e.getArithmeticExpressionAlias(col))
default:
columns = append(columns, "expr")
}
}
}
return &QueryResult{
Columns: columns,
Rows: [][]sqltypes.Value{},
Database: hms.topic.Namespace,
Table: hms.topic.Name,
}
}
// Build columns from SELECT expressions
columns := make([]string, 0, len(selectExprs))
for _, selectExpr := range selectExprs {
switch expr := selectExpr.(type) {
case *AliasedExpr:
switch col := expr.Expr.(type) {
case *ColName:
columns = append(columns, col.Name.String())
case *ArithmeticExpr:
columns = append(columns, e.getArithmeticExpressionAlias(col))
default:
columns = append(columns, "expr")
}
}
}
// Convert to SQL rows with expression evaluation
rows := make([][]sqltypes.Value, len(results))
for i, result := range results {
row := make([]sqltypes.Value, len(selectExprs))
for j, selectExpr := range selectExprs {
switch expr := selectExpr.(type) {
case *AliasedExpr:
switch col := expr.Expr.(type) {
case *ColName:
// Handle regular column
columnName := col.Name.String()
if value := e.findColumnValue(result, columnName); value != nil {
row[j] = convertSchemaValueToSQL(value)
} else {
row[j] = sqltypes.NULL
}
case *ArithmeticExpr:
// Handle arithmetic expression
if value, err := e.evaluateArithmeticExpression(col, result); err == nil && value != nil {
row[j] = convertSchemaValueToSQL(value)
} else {
row[j] = sqltypes.NULL
}
default:
row[j] = sqltypes.NULL
}
default:
row[j] = sqltypes.NULL
}
}
rows[i] = row
}
return &QueryResult{
Columns: columns,
Rows: rows,
Database: hms.topic.Namespace,
Table: hms.topic.Name,
}
}