From 279abda3b725e3b80f20be008bde7b7c40793501 Mon Sep 17 00:00:00 2001 From: chrislu Date: Thu, 4 Sep 2025 16:38:08 -0700 Subject: [PATCH] support string concatenation || --- weed/query/engine/arithmetic_test.go | 245 +++++++++++++++++++++++++++ weed/query/engine/engine.go | 244 +++++++++++++++++++++++++- 2 files changed, 486 insertions(+), 3 deletions(-) create mode 100644 weed/query/engine/arithmetic_test.go diff --git a/weed/query/engine/arithmetic_test.go b/weed/query/engine/arithmetic_test.go new file mode 100644 index 000000000..3903d1b44 --- /dev/null +++ b/weed/query/engine/arithmetic_test.go @@ -0,0 +1,245 @@ +package engine + +import ( + "testing" + + "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" +) + +func TestArithmeticExpressionParsing(t *testing.T) { + tests := []struct { + name string + expression string + expectNil bool + leftCol string + rightCol string + operator string + }{ + { + name: "simple addition", + expression: "id+user_id", + expectNil: false, + leftCol: "id", + rightCol: "user_id", + operator: "+", + }, + { + name: "simple subtraction", + expression: "col1-col2", + expectNil: false, + leftCol: "col1", + rightCol: "col2", + operator: "-", + }, + { + name: "multiplication with spaces", + expression: "a * b", + expectNil: false, + leftCol: "a", + rightCol: "b", + operator: "*", + }, + { + name: "string concatenation", + expression: "first_name||last_name", + expectNil: false, + leftCol: "first_name", + rightCol: "last_name", + operator: "||", + }, + { + name: "string concatenation with spaces", + expression: "prefix || suffix", + expectNil: false, + leftCol: "prefix", + rightCol: "suffix", + operator: "||", + }, + { + name: "not arithmetic", + expression: "simple_column", + expectNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseArithmeticExpression(tt.expression) + + if tt.expectNil { + if result != nil { + t.Errorf("Expected nil for %s, got %v", tt.expression, result) + } + return + } + + if result == nil { + t.Errorf("Expected arithmetic expression for %s, got nil", tt.expression) + return + } + + if result.Operator != tt.operator { + t.Errorf("Expected operator %s, got %s", tt.operator, result.Operator) + } + + // Check left operand + if leftCol, ok := result.Left.(*ColName); ok { + if leftCol.Name.String() != tt.leftCol { + t.Errorf("Expected left column %s, got %s", tt.leftCol, leftCol.Name.String()) + } + } else { + t.Errorf("Expected left operand to be ColName, got %T", result.Left) + } + + // Check right operand + if rightCol, ok := result.Right.(*ColName); ok { + if rightCol.Name.String() != tt.rightCol { + t.Errorf("Expected right column %s, got %s", tt.rightCol, rightCol.Name.String()) + } + } else { + t.Errorf("Expected right operand to be ColName, got %T", result.Right) + } + }) + } +} + +func TestArithmeticExpressionEvaluation(t *testing.T) { + engine := NewSQLEngine("") + + // Create test data + result := HybridScanResult{ + Values: map[string]*schema_pb.Value{ + "id": {Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, + "user_id": {Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + "price": {Kind: &schema_pb.Value_DoubleValue{DoubleValue: 25.5}}, + "qty": {Kind: &schema_pb.Value_Int64Value{Int64Value: 3}}, + "first_name": {Kind: &schema_pb.Value_StringValue{StringValue: "John"}}, + "last_name": {Kind: &schema_pb.Value_StringValue{StringValue: "Doe"}}, + "prefix": {Kind: &schema_pb.Value_StringValue{StringValue: "Hello"}}, + "suffix": {Kind: &schema_pb.Value_StringValue{StringValue: "World"}}, + }, + } + + tests := []struct { + name string + expression string + expected interface{} + }{ + { + name: "integer addition", + expression: "id+user_id", + expected: int64(15), + }, + { + name: "integer subtraction", + expression: "id-user_id", + expected: int64(5), + }, + { + name: "mixed types multiplication", + expression: "price*qty", + expected: float64(76.5), + }, + { + name: "string concatenation", + expression: "first_name||last_name", + expected: "JohnDoe", + }, + { + name: "string concatenation with spaces", + expression: "prefix || suffix", + expected: "HelloWorld", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Parse the arithmetic expression + arithmeticExpr := parseArithmeticExpression(tt.expression) + if arithmeticExpr == nil { + t.Fatalf("Failed to parse arithmetic expression: %s", tt.expression) + } + + // Evaluate the expression + value, err := engine.evaluateArithmeticExpression(arithmeticExpr, result) + if err != nil { + t.Fatalf("Failed to evaluate expression: %v", err) + } + + if value == nil { + t.Fatalf("Got nil value for expression: %s", tt.expression) + } + + // Check the result + switch expected := tt.expected.(type) { + case int64: + if intVal, ok := value.Kind.(*schema_pb.Value_Int64Value); ok { + if intVal.Int64Value != expected { + t.Errorf("Expected %d, got %d", expected, intVal.Int64Value) + } + } else { + t.Errorf("Expected int64 result, got %T", value.Kind) + } + case float64: + if doubleVal, ok := value.Kind.(*schema_pb.Value_DoubleValue); ok { + if doubleVal.DoubleValue != expected { + t.Errorf("Expected %f, got %f", expected, doubleVal.DoubleValue) + } + } else { + t.Errorf("Expected double result, got %T", value.Kind) + } + case string: + if stringVal, ok := value.Kind.(*schema_pb.Value_StringValue); ok { + if stringVal.StringValue != expected { + t.Errorf("Expected %s, got %s", expected, stringVal.StringValue) + } + } else { + t.Errorf("Expected string result, got %T", value.Kind) + } + } + }) + } +} + +func TestSelectArithmeticExpression(t *testing.T) { + // Test parsing a SELECT with arithmetic and string concatenation expressions + stmt, err := ParseSQL("SELECT id+user_id, user_id*2, first_name||last_name FROM test_table") + if err != nil { + t.Fatalf("Failed to parse SQL: %v", err) + } + + selectStmt := stmt.(*SelectStatement) + if len(selectStmt.SelectExprs) != 3 { + t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs)) + } + + // Check first expression (id+user_id) + aliasedExpr1 := selectStmt.SelectExprs[0].(*AliasedExpr) + if arithmeticExpr1, ok := aliasedExpr1.Expr.(*ArithmeticExpr); ok { + if arithmeticExpr1.Operator != "+" { + t.Errorf("Expected + operator, got %s", arithmeticExpr1.Operator) + } + } else { + t.Errorf("Expected arithmetic expression, got %T", aliasedExpr1.Expr) + } + + // Check second expression (user_id*2) + aliasedExpr2 := selectStmt.SelectExprs[1].(*AliasedExpr) + if arithmeticExpr2, ok := aliasedExpr2.Expr.(*ArithmeticExpr); ok { + if arithmeticExpr2.Operator != "*" { + t.Errorf("Expected * operator, got %s", arithmeticExpr2.Operator) + } + } else { + t.Errorf("Expected arithmetic expression, got %T", aliasedExpr2.Expr) + } + + // Check third expression (first_name||last_name) + aliasedExpr3 := selectStmt.SelectExprs[2].(*AliasedExpr) + if arithmeticExpr3, ok := aliasedExpr3.Expr.(*ArithmeticExpr); ok { + if arithmeticExpr3.Operator != "||" { + t.Errorf("Expected || operator, got %s", arithmeticExpr3.Operator) + } + } else { + t.Errorf("Expected string concatenation expression, got %T", aliasedExpr3.Expr) + } +} diff --git a/weed/query/engine/engine.go b/weed/query/engine/engine.go index 985bcd4fe..472306652 100644 --- a/weed/query/engine/engine.go +++ b/weed/query/engine/engine.go @@ -164,6 +164,15 @@ type ColName struct { func (c *ColName) isExprNode() {} +// ArithmeticExpr represents arithmetic operations like id+user_id and string concatenation like name||suffix +type ArithmeticExpr struct { + Left ExprNode + Right ExprNode + Operator string // +, -, *, /, %, || +} + +func (a *ArithmeticExpr) isExprNode() {} + type ComparisonExpr struct { Left ExprNode Right ExprNode @@ -312,7 +321,7 @@ func parseSelectStatement(sql string) (*SelectStatement, error) { if part == "*" { s.SelectExprs = append(s.SelectExprs, &StarExpr{}) } else { - // Handle column names and functions + // Handle column names, functions, and arithmetic expressions expr := &AliasedExpr{} if strings.Contains(strings.ToUpper(part), "(") && strings.Contains(part, ")") { // Function expression @@ -326,6 +335,9 @@ func parseSelectStatement(sql string) (*SelectStatement, error) { Exprs: funcArgs, } expr.Expr = funcExpr + } else if arithmeticExpr := parseArithmeticExpression(part); arithmeticExpr != nil { + // Arithmetic expression (id+user_id, col1-col2, etc.) + expr.Expr = arithmeticExpr } else { // Column name colExpr := &ColName{Name: stringValue(part)} @@ -438,6 +450,64 @@ func extractFunctionName(expr string) string { return strings.TrimSpace(expr[:parenIdx]) } +// parseArithmeticExpression parses arithmetic expressions like id+user_id, col1*col2, etc. +func parseArithmeticExpression(expr string) *ArithmeticExpr { + // Remove spaces for easier parsing + expr = strings.ReplaceAll(expr, " ", "") + + // Check for arithmetic and string operators (order matters for precedence) + // String concatenation (||) has lower precedence than arithmetic operators + operators := []string{"||", "+", "-", "*", "/", "%"} + + for _, op := range operators { + // Find the operator position (skip operators inside parentheses) + opPos := -1 + parenLevel := 0 + for i, char := range expr { + if char == '(' { + parenLevel++ + } else if char == ')' { + parenLevel-- + } else if parenLevel == 0 && strings.HasPrefix(expr[i:], op) { + opPos = i + break + } + } + + if opPos > 0 && opPos < len(expr)-len(op) { + leftExpr := strings.TrimSpace(expr[:opPos]) + rightExpr := strings.TrimSpace(expr[opPos+len(op):]) + + if leftExpr != "" && rightExpr != "" { + // Create left and right expressions (recursively handle complex expressions) + var left, right ExprNode + + // Parse left side + if leftArithmetic := parseArithmeticExpression(leftExpr); leftArithmetic != nil { + left = leftArithmetic + } else { + left = &ColName{Name: stringValue(leftExpr)} + } + + // Parse right side + if rightArithmetic := parseArithmeticExpression(rightExpr); rightArithmetic != nil { + right = rightArithmetic + } else { + right = &ColName{Name: stringValue(rightExpr)} + } + + return &ArithmeticExpr{ + Left: left, + Right: right, + Operator: op, + } + } + } + } + + return nil +} + // extractFunctionArguments extracts the arguments from a function call expression func extractFunctionArguments(expr string) ([]SelectExpr, error) { // Find the parentheses @@ -1383,6 +1453,9 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *SelectStat switch col := expr.Expr.(type) { case *ColName: columns = append(columns, col.Name.String()) + case *ArithmeticExpr: + // Handle arithmetic expressions like id+user_id and string concatenation like name||suffix + columns = append(columns, e.getArithmeticExpressionAlias(col)) case *FuncExpr: // Handle aggregation functions aggSpec, err := e.parseAggregationFunction(col, expr) @@ -1482,9 +1555,11 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *SelectStat // Convert to SQL result format if selectAll { columns = nil // Let converter determine all columns + return hybridScanner.ConvertToSQLResult(results, columns), nil } - return hybridScanner.ConvertToSQLResult(results, columns), nil + // Handle custom column expressions (including arithmetic) + return e.ConvertToSQLResultWithExpressions(hybridScanner, results, stmt.SelectExprs), nil } // executeSelectStatementWithBrokerStats handles SELECT queries with broker buffer statistics capture @@ -1573,6 +1648,9 @@ func (e *SQLEngine) executeSelectStatementWithBrokerStats(ctx context.Context, s switch col := expr.Expr.(type) { case *ColName: columns = append(columns, col.Name.String()) + case *ArithmeticExpr: + // Handle arithmetic expressions like id+user_id and string concatenation like name||suffix + columns = append(columns, e.getArithmeticExpressionAlias(col)) case *FuncExpr: // Handle aggregation functions aggSpec, err := e.parseAggregationFunction(col, expr) @@ -1705,9 +1783,11 @@ func (e *SQLEngine) executeSelectStatementWithBrokerStats(ctx context.Context, s // Convert to SQL result format if selectAll { columns = nil // Let converter determine all columns + return hybridScanner.ConvertToSQLResult(results, columns), nil } - return hybridScanner.ConvertToSQLResult(results, columns), nil + // Handle custom column expressions (including arithmetic) + return e.ConvertToSQLResultWithExpressions(hybridScanner, results, stmt.SelectExprs), nil } // extractTimeFilters extracts time range filters from WHERE clause for optimization @@ -3347,3 +3427,161 @@ func (e *SQLEngine) discoverAndRegisterTopic(ctx context.Context, database, tabl // Note: This is a discovery operation, not query execution, so it's okay to always log return nil } + +// getArithmeticExpressionAlias generates a display alias for arithmetic expressions +func (e *SQLEngine) getArithmeticExpressionAlias(expr *ArithmeticExpr) string { + leftAlias := e.getExpressionAlias(expr.Left) + rightAlias := e.getExpressionAlias(expr.Right) + return leftAlias + expr.Operator + rightAlias +} + +// getExpressionAlias generates an alias for any expression node +func (e *SQLEngine) getExpressionAlias(expr ExprNode) string { + switch exprType := expr.(type) { + case *ColName: + return exprType.Name.String() + case *ArithmeticExpr: + return e.getArithmeticExpressionAlias(exprType) + default: + return "expr" + } +} + +// evaluateArithmeticExpression evaluates an arithmetic expression for a given record +func (e *SQLEngine) evaluateArithmeticExpression(expr *ArithmeticExpr, result HybridScanResult) (*schema_pb.Value, error) { + // Get left operand value + leftValue, err := e.evaluateExpressionValue(expr.Left, result) + if err != nil { + return nil, fmt.Errorf("error evaluating left operand: %v", err) + } + + // Get right operand value + rightValue, err := e.evaluateExpressionValue(expr.Right, result) + if err != nil { + return nil, fmt.Errorf("error evaluating right operand: %v", err) + } + + // Handle string concatenation operator + if expr.Operator == "||" { + return e.Concat(leftValue, rightValue) + } + + // Perform arithmetic operation + var op ArithmeticOperator + switch expr.Operator { + case "+": + op = OpAdd + case "-": + op = OpSub + case "*": + op = OpMul + case "/": + op = OpDiv + case "%": + op = OpMod + default: + return nil, fmt.Errorf("unsupported arithmetic operator: %s", expr.Operator) + } + + return e.EvaluateArithmeticExpression(leftValue, rightValue, op) +} + +// evaluateExpressionValue evaluates any expression to get its value from a record +func (e *SQLEngine) evaluateExpressionValue(expr ExprNode, result HybridScanResult) (*schema_pb.Value, error) { + switch exprType := expr.(type) { + case *ColName: + columnName := exprType.Name.String() + value := e.findColumnValue(result, columnName) + if value == nil { + return nil, nil + } + return value, nil + case *ArithmeticExpr: + return e.evaluateArithmeticExpression(exprType, result) + default: + return nil, fmt.Errorf("unsupported expression type: %T", expr) + } +} + +// ConvertToSQLResultWithExpressions converts HybridScanResults to SQL query results with expression evaluation +func (e *SQLEngine) ConvertToSQLResultWithExpressions(hms *HybridMessageScanner, results []HybridScanResult, selectExprs []SelectExpr) *QueryResult { + if len(results) == 0 { + columns := make([]string, 0, len(selectExprs)) + for _, selectExpr := range selectExprs { + switch expr := selectExpr.(type) { + case *AliasedExpr: + switch col := expr.Expr.(type) { + case *ColName: + columns = append(columns, col.Name.String()) + case *ArithmeticExpr: + columns = append(columns, e.getArithmeticExpressionAlias(col)) + default: + columns = append(columns, "expr") + } + } + } + + return &QueryResult{ + Columns: columns, + Rows: [][]sqltypes.Value{}, + Database: hms.topic.Namespace, + Table: hms.topic.Name, + } + } + + // Build columns from SELECT expressions + columns := make([]string, 0, len(selectExprs)) + for _, selectExpr := range selectExprs { + switch expr := selectExpr.(type) { + case *AliasedExpr: + switch col := expr.Expr.(type) { + case *ColName: + columns = append(columns, col.Name.String()) + case *ArithmeticExpr: + columns = append(columns, e.getArithmeticExpressionAlias(col)) + default: + columns = append(columns, "expr") + } + } + } + + // Convert to SQL rows with expression evaluation + rows := make([][]sqltypes.Value, len(results)) + for i, result := range results { + row := make([]sqltypes.Value, len(selectExprs)) + for j, selectExpr := range selectExprs { + switch expr := selectExpr.(type) { + case *AliasedExpr: + switch col := expr.Expr.(type) { + case *ColName: + // Handle regular column + columnName := col.Name.String() + if value := e.findColumnValue(result, columnName); value != nil { + row[j] = convertSchemaValueToSQL(value) + } else { + row[j] = sqltypes.NULL + } + case *ArithmeticExpr: + // Handle arithmetic expression + if value, err := e.evaluateArithmeticExpression(col, result); err == nil && value != nil { + row[j] = convertSchemaValueToSQL(value) + } else { + row[j] = sqltypes.NULL + } + default: + row[j] = sqltypes.NULL + } + default: + row[j] = sqltypes.NULL + } + } + rows[i] = row + } + + return &QueryResult{ + Columns: columns, + Rows: rows, + Database: hms.topic.Namespace, + Table: hms.topic.Name, + } +}