mirror of
https://github.com/seaweedfs/seaweedfs.git
synced 2025-09-19 05:59:23 +08:00
support string concatenation ||
This commit is contained in:
245
weed/query/engine/arithmetic_test.go
Normal file
245
weed/query/engine/arithmetic_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/seaweedfs/seaweedfs/weed/pb/schema_pb"
|
||||
)
|
||||
|
||||
func TestArithmeticExpressionParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expression string
|
||||
expectNil bool
|
||||
leftCol string
|
||||
rightCol string
|
||||
operator string
|
||||
}{
|
||||
{
|
||||
name: "simple addition",
|
||||
expression: "id+user_id",
|
||||
expectNil: false,
|
||||
leftCol: "id",
|
||||
rightCol: "user_id",
|
||||
operator: "+",
|
||||
},
|
||||
{
|
||||
name: "simple subtraction",
|
||||
expression: "col1-col2",
|
||||
expectNil: false,
|
||||
leftCol: "col1",
|
||||
rightCol: "col2",
|
||||
operator: "-",
|
||||
},
|
||||
{
|
||||
name: "multiplication with spaces",
|
||||
expression: "a * b",
|
||||
expectNil: false,
|
||||
leftCol: "a",
|
||||
rightCol: "b",
|
||||
operator: "*",
|
||||
},
|
||||
{
|
||||
name: "string concatenation",
|
||||
expression: "first_name||last_name",
|
||||
expectNil: false,
|
||||
leftCol: "first_name",
|
||||
rightCol: "last_name",
|
||||
operator: "||",
|
||||
},
|
||||
{
|
||||
name: "string concatenation with spaces",
|
||||
expression: "prefix || suffix",
|
||||
expectNil: false,
|
||||
leftCol: "prefix",
|
||||
rightCol: "suffix",
|
||||
operator: "||",
|
||||
},
|
||||
{
|
||||
name: "not arithmetic",
|
||||
expression: "simple_column",
|
||||
expectNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := parseArithmeticExpression(tt.expression)
|
||||
|
||||
if tt.expectNil {
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil for %s, got %v", tt.expression, result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Errorf("Expected arithmetic expression for %s, got nil", tt.expression)
|
||||
return
|
||||
}
|
||||
|
||||
if result.Operator != tt.operator {
|
||||
t.Errorf("Expected operator %s, got %s", tt.operator, result.Operator)
|
||||
}
|
||||
|
||||
// Check left operand
|
||||
if leftCol, ok := result.Left.(*ColName); ok {
|
||||
if leftCol.Name.String() != tt.leftCol {
|
||||
t.Errorf("Expected left column %s, got %s", tt.leftCol, leftCol.Name.String())
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Expected left operand to be ColName, got %T", result.Left)
|
||||
}
|
||||
|
||||
// Check right operand
|
||||
if rightCol, ok := result.Right.(*ColName); ok {
|
||||
if rightCol.Name.String() != tt.rightCol {
|
||||
t.Errorf("Expected right column %s, got %s", tt.rightCol, rightCol.Name.String())
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Expected right operand to be ColName, got %T", result.Right)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestArithmeticExpressionEvaluation(t *testing.T) {
|
||||
engine := NewSQLEngine("")
|
||||
|
||||
// Create test data
|
||||
result := HybridScanResult{
|
||||
Values: map[string]*schema_pb.Value{
|
||||
"id": {Kind: &schema_pb.Value_Int64Value{Int64Value: 10}},
|
||||
"user_id": {Kind: &schema_pb.Value_Int64Value{Int64Value: 5}},
|
||||
"price": {Kind: &schema_pb.Value_DoubleValue{DoubleValue: 25.5}},
|
||||
"qty": {Kind: &schema_pb.Value_Int64Value{Int64Value: 3}},
|
||||
"first_name": {Kind: &schema_pb.Value_StringValue{StringValue: "John"}},
|
||||
"last_name": {Kind: &schema_pb.Value_StringValue{StringValue: "Doe"}},
|
||||
"prefix": {Kind: &schema_pb.Value_StringValue{StringValue: "Hello"}},
|
||||
"suffix": {Kind: &schema_pb.Value_StringValue{StringValue: "World"}},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
expression string
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
name: "integer addition",
|
||||
expression: "id+user_id",
|
||||
expected: int64(15),
|
||||
},
|
||||
{
|
||||
name: "integer subtraction",
|
||||
expression: "id-user_id",
|
||||
expected: int64(5),
|
||||
},
|
||||
{
|
||||
name: "mixed types multiplication",
|
||||
expression: "price*qty",
|
||||
expected: float64(76.5),
|
||||
},
|
||||
{
|
||||
name: "string concatenation",
|
||||
expression: "first_name||last_name",
|
||||
expected: "JohnDoe",
|
||||
},
|
||||
{
|
||||
name: "string concatenation with spaces",
|
||||
expression: "prefix || suffix",
|
||||
expected: "HelloWorld",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Parse the arithmetic expression
|
||||
arithmeticExpr := parseArithmeticExpression(tt.expression)
|
||||
if arithmeticExpr == nil {
|
||||
t.Fatalf("Failed to parse arithmetic expression: %s", tt.expression)
|
||||
}
|
||||
|
||||
// Evaluate the expression
|
||||
value, err := engine.evaluateArithmeticExpression(arithmeticExpr, result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to evaluate expression: %v", err)
|
||||
}
|
||||
|
||||
if value == nil {
|
||||
t.Fatalf("Got nil value for expression: %s", tt.expression)
|
||||
}
|
||||
|
||||
// Check the result
|
||||
switch expected := tt.expected.(type) {
|
||||
case int64:
|
||||
if intVal, ok := value.Kind.(*schema_pb.Value_Int64Value); ok {
|
||||
if intVal.Int64Value != expected {
|
||||
t.Errorf("Expected %d, got %d", expected, intVal.Int64Value)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Expected int64 result, got %T", value.Kind)
|
||||
}
|
||||
case float64:
|
||||
if doubleVal, ok := value.Kind.(*schema_pb.Value_DoubleValue); ok {
|
||||
if doubleVal.DoubleValue != expected {
|
||||
t.Errorf("Expected %f, got %f", expected, doubleVal.DoubleValue)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Expected double result, got %T", value.Kind)
|
||||
}
|
||||
case string:
|
||||
if stringVal, ok := value.Kind.(*schema_pb.Value_StringValue); ok {
|
||||
if stringVal.StringValue != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, stringVal.StringValue)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Expected string result, got %T", value.Kind)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectArithmeticExpression(t *testing.T) {
|
||||
// Test parsing a SELECT with arithmetic and string concatenation expressions
|
||||
stmt, err := ParseSQL("SELECT id+user_id, user_id*2, first_name||last_name FROM test_table")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse SQL: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*SelectStatement)
|
||||
if len(selectStmt.SelectExprs) != 3 {
|
||||
t.Fatalf("Expected 3 select expressions, got %d", len(selectStmt.SelectExprs))
|
||||
}
|
||||
|
||||
// Check first expression (id+user_id)
|
||||
aliasedExpr1 := selectStmt.SelectExprs[0].(*AliasedExpr)
|
||||
if arithmeticExpr1, ok := aliasedExpr1.Expr.(*ArithmeticExpr); ok {
|
||||
if arithmeticExpr1.Operator != "+" {
|
||||
t.Errorf("Expected + operator, got %s", arithmeticExpr1.Operator)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Expected arithmetic expression, got %T", aliasedExpr1.Expr)
|
||||
}
|
||||
|
||||
// Check second expression (user_id*2)
|
||||
aliasedExpr2 := selectStmt.SelectExprs[1].(*AliasedExpr)
|
||||
if arithmeticExpr2, ok := aliasedExpr2.Expr.(*ArithmeticExpr); ok {
|
||||
if arithmeticExpr2.Operator != "*" {
|
||||
t.Errorf("Expected * operator, got %s", arithmeticExpr2.Operator)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Expected arithmetic expression, got %T", aliasedExpr2.Expr)
|
||||
}
|
||||
|
||||
// Check third expression (first_name||last_name)
|
||||
aliasedExpr3 := selectStmt.SelectExprs[2].(*AliasedExpr)
|
||||
if arithmeticExpr3, ok := aliasedExpr3.Expr.(*ArithmeticExpr); ok {
|
||||
if arithmeticExpr3.Operator != "||" {
|
||||
t.Errorf("Expected || operator, got %s", arithmeticExpr3.Operator)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Expected string concatenation expression, got %T", aliasedExpr3.Expr)
|
||||
}
|
||||
}
|
@@ -164,6 +164,15 @@ type ColName struct {
|
||||
|
||||
func (c *ColName) isExprNode() {}
|
||||
|
||||
// ArithmeticExpr represents arithmetic operations like id+user_id and string concatenation like name||suffix
|
||||
type ArithmeticExpr struct {
|
||||
Left ExprNode
|
||||
Right ExprNode
|
||||
Operator string // +, -, *, /, %, ||
|
||||
}
|
||||
|
||||
func (a *ArithmeticExpr) isExprNode() {}
|
||||
|
||||
type ComparisonExpr struct {
|
||||
Left ExprNode
|
||||
Right ExprNode
|
||||
@@ -312,7 +321,7 @@ func parseSelectStatement(sql string) (*SelectStatement, error) {
|
||||
if part == "*" {
|
||||
s.SelectExprs = append(s.SelectExprs, &StarExpr{})
|
||||
} else {
|
||||
// Handle column names and functions
|
||||
// Handle column names, functions, and arithmetic expressions
|
||||
expr := &AliasedExpr{}
|
||||
if strings.Contains(strings.ToUpper(part), "(") && strings.Contains(part, ")") {
|
||||
// Function expression
|
||||
@@ -326,6 +335,9 @@ func parseSelectStatement(sql string) (*SelectStatement, error) {
|
||||
Exprs: funcArgs,
|
||||
}
|
||||
expr.Expr = funcExpr
|
||||
} else if arithmeticExpr := parseArithmeticExpression(part); arithmeticExpr != nil {
|
||||
// Arithmetic expression (id+user_id, col1-col2, etc.)
|
||||
expr.Expr = arithmeticExpr
|
||||
} else {
|
||||
// Column name
|
||||
colExpr := &ColName{Name: stringValue(part)}
|
||||
@@ -438,6 +450,64 @@ func extractFunctionName(expr string) string {
|
||||
return strings.TrimSpace(expr[:parenIdx])
|
||||
}
|
||||
|
||||
// parseArithmeticExpression parses arithmetic expressions like id+user_id, col1*col2, etc.
|
||||
func parseArithmeticExpression(expr string) *ArithmeticExpr {
|
||||
// Remove spaces for easier parsing
|
||||
expr = strings.ReplaceAll(expr, " ", "")
|
||||
|
||||
// Check for arithmetic and string operators (order matters for precedence)
|
||||
// String concatenation (||) has lower precedence than arithmetic operators
|
||||
operators := []string{"||", "+", "-", "*", "/", "%"}
|
||||
|
||||
for _, op := range operators {
|
||||
// Find the operator position (skip operators inside parentheses)
|
||||
opPos := -1
|
||||
parenLevel := 0
|
||||
for i, char := range expr {
|
||||
if char == '(' {
|
||||
parenLevel++
|
||||
} else if char == ')' {
|
||||
parenLevel--
|
||||
} else if parenLevel == 0 && strings.HasPrefix(expr[i:], op) {
|
||||
opPos = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if opPos > 0 && opPos < len(expr)-len(op) {
|
||||
leftExpr := strings.TrimSpace(expr[:opPos])
|
||||
rightExpr := strings.TrimSpace(expr[opPos+len(op):])
|
||||
|
||||
if leftExpr != "" && rightExpr != "" {
|
||||
// Create left and right expressions (recursively handle complex expressions)
|
||||
var left, right ExprNode
|
||||
|
||||
// Parse left side
|
||||
if leftArithmetic := parseArithmeticExpression(leftExpr); leftArithmetic != nil {
|
||||
left = leftArithmetic
|
||||
} else {
|
||||
left = &ColName{Name: stringValue(leftExpr)}
|
||||
}
|
||||
|
||||
// Parse right side
|
||||
if rightArithmetic := parseArithmeticExpression(rightExpr); rightArithmetic != nil {
|
||||
right = rightArithmetic
|
||||
} else {
|
||||
right = &ColName{Name: stringValue(rightExpr)}
|
||||
}
|
||||
|
||||
return &ArithmeticExpr{
|
||||
Left: left,
|
||||
Right: right,
|
||||
Operator: op,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractFunctionArguments extracts the arguments from a function call expression
|
||||
func extractFunctionArguments(expr string) ([]SelectExpr, error) {
|
||||
// Find the parentheses
|
||||
@@ -1383,6 +1453,9 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *SelectStat
|
||||
switch col := expr.Expr.(type) {
|
||||
case *ColName:
|
||||
columns = append(columns, col.Name.String())
|
||||
case *ArithmeticExpr:
|
||||
// Handle arithmetic expressions like id+user_id and string concatenation like name||suffix
|
||||
columns = append(columns, e.getArithmeticExpressionAlias(col))
|
||||
case *FuncExpr:
|
||||
// Handle aggregation functions
|
||||
aggSpec, err := e.parseAggregationFunction(col, expr)
|
||||
@@ -1482,9 +1555,11 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *SelectStat
|
||||
// Convert to SQL result format
|
||||
if selectAll {
|
||||
columns = nil // Let converter determine all columns
|
||||
return hybridScanner.ConvertToSQLResult(results, columns), nil
|
||||
}
|
||||
|
||||
return hybridScanner.ConvertToSQLResult(results, columns), nil
|
||||
// Handle custom column expressions (including arithmetic)
|
||||
return e.ConvertToSQLResultWithExpressions(hybridScanner, results, stmt.SelectExprs), nil
|
||||
}
|
||||
|
||||
// executeSelectStatementWithBrokerStats handles SELECT queries with broker buffer statistics capture
|
||||
@@ -1573,6 +1648,9 @@ func (e *SQLEngine) executeSelectStatementWithBrokerStats(ctx context.Context, s
|
||||
switch col := expr.Expr.(type) {
|
||||
case *ColName:
|
||||
columns = append(columns, col.Name.String())
|
||||
case *ArithmeticExpr:
|
||||
// Handle arithmetic expressions like id+user_id and string concatenation like name||suffix
|
||||
columns = append(columns, e.getArithmeticExpressionAlias(col))
|
||||
case *FuncExpr:
|
||||
// Handle aggregation functions
|
||||
aggSpec, err := e.parseAggregationFunction(col, expr)
|
||||
@@ -1705,9 +1783,11 @@ func (e *SQLEngine) executeSelectStatementWithBrokerStats(ctx context.Context, s
|
||||
// Convert to SQL result format
|
||||
if selectAll {
|
||||
columns = nil // Let converter determine all columns
|
||||
return hybridScanner.ConvertToSQLResult(results, columns), nil
|
||||
}
|
||||
|
||||
return hybridScanner.ConvertToSQLResult(results, columns), nil
|
||||
// Handle custom column expressions (including arithmetic)
|
||||
return e.ConvertToSQLResultWithExpressions(hybridScanner, results, stmt.SelectExprs), nil
|
||||
}
|
||||
|
||||
// extractTimeFilters extracts time range filters from WHERE clause for optimization
|
||||
@@ -3347,3 +3427,161 @@ func (e *SQLEngine) discoverAndRegisterTopic(ctx context.Context, database, tabl
|
||||
// Note: This is a discovery operation, not query execution, so it's okay to always log
|
||||
return nil
|
||||
}
|
||||
|
||||
// getArithmeticExpressionAlias generates a display alias for arithmetic expressions
|
||||
func (e *SQLEngine) getArithmeticExpressionAlias(expr *ArithmeticExpr) string {
|
||||
leftAlias := e.getExpressionAlias(expr.Left)
|
||||
rightAlias := e.getExpressionAlias(expr.Right)
|
||||
return leftAlias + expr.Operator + rightAlias
|
||||
}
|
||||
|
||||
// getExpressionAlias generates an alias for any expression node
|
||||
func (e *SQLEngine) getExpressionAlias(expr ExprNode) string {
|
||||
switch exprType := expr.(type) {
|
||||
case *ColName:
|
||||
return exprType.Name.String()
|
||||
case *ArithmeticExpr:
|
||||
return e.getArithmeticExpressionAlias(exprType)
|
||||
default:
|
||||
return "expr"
|
||||
}
|
||||
}
|
||||
|
||||
// evaluateArithmeticExpression evaluates an arithmetic expression for a given record
|
||||
func (e *SQLEngine) evaluateArithmeticExpression(expr *ArithmeticExpr, result HybridScanResult) (*schema_pb.Value, error) {
|
||||
// Get left operand value
|
||||
leftValue, err := e.evaluateExpressionValue(expr.Left, result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error evaluating left operand: %v", err)
|
||||
}
|
||||
|
||||
// Get right operand value
|
||||
rightValue, err := e.evaluateExpressionValue(expr.Right, result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error evaluating right operand: %v", err)
|
||||
}
|
||||
|
||||
// Handle string concatenation operator
|
||||
if expr.Operator == "||" {
|
||||
return e.Concat(leftValue, rightValue)
|
||||
}
|
||||
|
||||
// Perform arithmetic operation
|
||||
var op ArithmeticOperator
|
||||
switch expr.Operator {
|
||||
case "+":
|
||||
op = OpAdd
|
||||
case "-":
|
||||
op = OpSub
|
||||
case "*":
|
||||
op = OpMul
|
||||
case "/":
|
||||
op = OpDiv
|
||||
case "%":
|
||||
op = OpMod
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported arithmetic operator: %s", expr.Operator)
|
||||
}
|
||||
|
||||
return e.EvaluateArithmeticExpression(leftValue, rightValue, op)
|
||||
}
|
||||
|
||||
// evaluateExpressionValue evaluates any expression to get its value from a record
|
||||
func (e *SQLEngine) evaluateExpressionValue(expr ExprNode, result HybridScanResult) (*schema_pb.Value, error) {
|
||||
switch exprType := expr.(type) {
|
||||
case *ColName:
|
||||
columnName := exprType.Name.String()
|
||||
value := e.findColumnValue(result, columnName)
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return value, nil
|
||||
case *ArithmeticExpr:
|
||||
return e.evaluateArithmeticExpression(exprType, result)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported expression type: %T", expr)
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertToSQLResultWithExpressions converts HybridScanResults to SQL query results with expression evaluation
|
||||
func (e *SQLEngine) ConvertToSQLResultWithExpressions(hms *HybridMessageScanner, results []HybridScanResult, selectExprs []SelectExpr) *QueryResult {
|
||||
if len(results) == 0 {
|
||||
columns := make([]string, 0, len(selectExprs))
|
||||
for _, selectExpr := range selectExprs {
|
||||
switch expr := selectExpr.(type) {
|
||||
case *AliasedExpr:
|
||||
switch col := expr.Expr.(type) {
|
||||
case *ColName:
|
||||
columns = append(columns, col.Name.String())
|
||||
case *ArithmeticExpr:
|
||||
columns = append(columns, e.getArithmeticExpressionAlias(col))
|
||||
default:
|
||||
columns = append(columns, "expr")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &QueryResult{
|
||||
Columns: columns,
|
||||
Rows: [][]sqltypes.Value{},
|
||||
Database: hms.topic.Namespace,
|
||||
Table: hms.topic.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// Build columns from SELECT expressions
|
||||
columns := make([]string, 0, len(selectExprs))
|
||||
for _, selectExpr := range selectExprs {
|
||||
switch expr := selectExpr.(type) {
|
||||
case *AliasedExpr:
|
||||
switch col := expr.Expr.(type) {
|
||||
case *ColName:
|
||||
columns = append(columns, col.Name.String())
|
||||
case *ArithmeticExpr:
|
||||
columns = append(columns, e.getArithmeticExpressionAlias(col))
|
||||
default:
|
||||
columns = append(columns, "expr")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to SQL rows with expression evaluation
|
||||
rows := make([][]sqltypes.Value, len(results))
|
||||
for i, result := range results {
|
||||
row := make([]sqltypes.Value, len(selectExprs))
|
||||
for j, selectExpr := range selectExprs {
|
||||
switch expr := selectExpr.(type) {
|
||||
case *AliasedExpr:
|
||||
switch col := expr.Expr.(type) {
|
||||
case *ColName:
|
||||
// Handle regular column
|
||||
columnName := col.Name.String()
|
||||
if value := e.findColumnValue(result, columnName); value != nil {
|
||||
row[j] = convertSchemaValueToSQL(value)
|
||||
} else {
|
||||
row[j] = sqltypes.NULL
|
||||
}
|
||||
case *ArithmeticExpr:
|
||||
// Handle arithmetic expression
|
||||
if value, err := e.evaluateArithmeticExpression(col, result); err == nil && value != nil {
|
||||
row[j] = convertSchemaValueToSQL(value)
|
||||
} else {
|
||||
row[j] = sqltypes.NULL
|
||||
}
|
||||
default:
|
||||
row[j] = sqltypes.NULL
|
||||
}
|
||||
default:
|
||||
row[j] = sqltypes.NULL
|
||||
}
|
||||
}
|
||||
rows[i] = row
|
||||
}
|
||||
|
||||
return &QueryResult{
|
||||
Columns: columns,
|
||||
Rows: rows,
|
||||
Database: hms.topic.Namespace,
|
||||
Table: hms.topic.Name,
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user