diff --git a/weed/query/engine/arithmetic_functions.go b/weed/query/engine/arithmetic_functions.go new file mode 100644 index 000000000..fd8ac1684 --- /dev/null +++ b/weed/query/engine/arithmetic_functions.go @@ -0,0 +1,218 @@ +package engine + +import ( + "fmt" + "math" + + "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" +) + +// =============================== +// ARITHMETIC OPERATORS +// =============================== + +// ArithmeticOperator represents basic arithmetic operations +type ArithmeticOperator string + +const ( + OpAdd ArithmeticOperator = "+" + OpSub ArithmeticOperator = "-" + OpMul ArithmeticOperator = "*" + OpDiv ArithmeticOperator = "/" + OpMod ArithmeticOperator = "%" +) + +// EvaluateArithmeticExpression evaluates basic arithmetic operations between two values +func (e *SQLEngine) EvaluateArithmeticExpression(left, right *schema_pb.Value, operator ArithmeticOperator) (*schema_pb.Value, error) { + if left == nil || right == nil { + return nil, fmt.Errorf("arithmetic operation requires non-null operands") + } + + // Convert values to numeric types for calculation + leftNum, err := e.valueToFloat64(left) + if err != nil { + return nil, fmt.Errorf("left operand conversion error: %v", err) + } + + rightNum, err := e.valueToFloat64(right) + if err != nil { + return nil, fmt.Errorf("right operand conversion error: %v", err) + } + + var result float64 + var resultErr error + + switch operator { + case OpAdd: + result = leftNum + rightNum + case OpSub: + result = leftNum - rightNum + case OpMul: + result = leftNum * rightNum + case OpDiv: + if rightNum == 0 { + return nil, fmt.Errorf("division by zero") + } + result = leftNum / rightNum + case OpMod: + if rightNum == 0 { + return nil, fmt.Errorf("modulo by zero") + } + result = math.Mod(leftNum, rightNum) + default: + return nil, fmt.Errorf("unsupported arithmetic operator: %s", operator) + } + + if resultErr != nil { + return nil, resultErr + } + + // Convert result back to appropriate schema value type + // If both operands were integers and operation doesn't produce decimal, return integer + if e.isIntegerValue(left) && e.isIntegerValue(right) && + (operator == OpAdd || operator == OpSub || operator == OpMul || operator == OpMod) { + return &schema_pb.Value{ + Kind: &schema_pb.Value_Int64Value{Int64Value: int64(result)}, + }, nil + } + + // Otherwise return as double/float + return &schema_pb.Value{ + Kind: &schema_pb.Value_DoubleValue{DoubleValue: result}, + }, nil +} + +// Add evaluates addition (left + right) +func (e *SQLEngine) Add(left, right *schema_pb.Value) (*schema_pb.Value, error) { + return e.EvaluateArithmeticExpression(left, right, OpAdd) +} + +// Subtract evaluates subtraction (left - right) +func (e *SQLEngine) Subtract(left, right *schema_pb.Value) (*schema_pb.Value, error) { + return e.EvaluateArithmeticExpression(left, right, OpSub) +} + +// Multiply evaluates multiplication (left * right) +func (e *SQLEngine) Multiply(left, right *schema_pb.Value) (*schema_pb.Value, error) { + return e.EvaluateArithmeticExpression(left, right, OpMul) +} + +// Divide evaluates division (left / right) +func (e *SQLEngine) Divide(left, right *schema_pb.Value) (*schema_pb.Value, error) { + return e.EvaluateArithmeticExpression(left, right, OpDiv) +} + +// Modulo evaluates modulo operation (left % right) +func (e *SQLEngine) Modulo(left, right *schema_pb.Value) (*schema_pb.Value, error) { + return e.EvaluateArithmeticExpression(left, right, OpMod) +} + +// =============================== +// MATHEMATICAL FUNCTIONS +// =============================== + +// Round rounds a numeric value to the nearest integer or specified decimal places +func (e *SQLEngine) Round(value *schema_pb.Value, precision ...*schema_pb.Value) (*schema_pb.Value, error) { + if value == nil { + return nil, fmt.Errorf("ROUND function requires non-null value") + } + + num, err := e.valueToFloat64(value) + if err != nil { + return nil, fmt.Errorf("ROUND function conversion error: %v", err) + } + + // Default precision is 0 (round to integer) + precisionValue := 0 + if len(precision) > 0 && precision[0] != nil { + precFloat, err := e.valueToFloat64(precision[0]) + if err != nil { + return nil, fmt.Errorf("ROUND precision conversion error: %v", err) + } + precisionValue = int(precFloat) + } + + // Apply rounding + multiplier := math.Pow(10, float64(precisionValue)) + rounded := math.Round(num*multiplier) / multiplier + + // Return as integer if precision is 0 and original was integer, otherwise as double + if precisionValue == 0 && e.isIntegerValue(value) { + return &schema_pb.Value{ + Kind: &schema_pb.Value_Int64Value{Int64Value: int64(rounded)}, + }, nil + } + + return &schema_pb.Value{ + Kind: &schema_pb.Value_DoubleValue{DoubleValue: rounded}, + }, nil +} + +// Ceil returns the smallest integer greater than or equal to the value +func (e *SQLEngine) Ceil(value *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil { + return nil, fmt.Errorf("CEIL function requires non-null value") + } + + num, err := e.valueToFloat64(value) + if err != nil { + return nil, fmt.Errorf("CEIL function conversion error: %v", err) + } + + result := math.Ceil(num) + + return &schema_pb.Value{ + Kind: &schema_pb.Value_Int64Value{Int64Value: int64(result)}, + }, nil +} + +// Floor returns the largest integer less than or equal to the value +func (e *SQLEngine) Floor(value *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil { + return nil, fmt.Errorf("FLOOR function requires non-null value") + } + + num, err := e.valueToFloat64(value) + if err != nil { + return nil, fmt.Errorf("FLOOR function conversion error: %v", err) + } + + result := math.Floor(num) + + return &schema_pb.Value{ + Kind: &schema_pb.Value_Int64Value{Int64Value: int64(result)}, + }, nil +} + +// Abs returns the absolute value of a number +func (e *SQLEngine) Abs(value *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil { + return nil, fmt.Errorf("ABS function requires non-null value") + } + + num, err := e.valueToFloat64(value) + if err != nil { + return nil, fmt.Errorf("ABS function conversion error: %v", err) + } + + result := math.Abs(num) + + // Return same type as input if possible + if e.isIntegerValue(value) { + return &schema_pb.Value{ + Kind: &schema_pb.Value_Int64Value{Int64Value: int64(result)}, + }, nil + } + + // Check if original was float32 + if _, ok := value.Kind.(*schema_pb.Value_FloatValue); ok { + return &schema_pb.Value{ + Kind: &schema_pb.Value_FloatValue{FloatValue: float32(result)}, + }, nil + } + + // Default to double + return &schema_pb.Value{ + Kind: &schema_pb.Value_DoubleValue{DoubleValue: result}, + }, nil +} diff --git a/weed/query/engine/arithmetic_functions_test.go b/weed/query/engine/arithmetic_functions_test.go new file mode 100644 index 000000000..8c5e11dec --- /dev/null +++ b/weed/query/engine/arithmetic_functions_test.go @@ -0,0 +1,530 @@ +package engine + +import ( + "testing" + + "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" +) + +func TestArithmeticOperations(t *testing.T) { + engine := NewTestSQLEngine() + + tests := []struct { + name string + left *schema_pb.Value + right *schema_pb.Value + operator ArithmeticOperator + expected *schema_pb.Value + expectErr bool + }{ + // Addition tests + { + name: "Add two integers", + left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, + right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + operator: OpAdd, + expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 15}}, + expectErr: false, + }, + { + name: "Add integer and float", + left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, + right: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 5.5}}, + operator: OpAdd, + expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 15.5}}, + expectErr: false, + }, + // Subtraction tests + { + name: "Subtract two integers", + left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, + right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 3}}, + operator: OpSub, + expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 7}}, + expectErr: false, + }, + // Multiplication tests + { + name: "Multiply two integers", + left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 6}}, + right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 7}}, + operator: OpMul, + expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 42}}, + expectErr: false, + }, + { + name: "Multiply with float", + left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + right: &schema_pb.Value{Kind: &schema_pb.Value_FloatValue{FloatValue: 2.5}}, + operator: OpMul, + expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 12.5}}, + expectErr: false, + }, + // Division tests + { + name: "Divide two integers", + left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 20}}, + right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 4}}, + operator: OpDiv, + expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 5.0}}, + expectErr: false, + }, + { + name: "Division by zero", + left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, + right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 0}}, + operator: OpDiv, + expected: nil, + expectErr: true, + }, + // Modulo tests + { + name: "Modulo operation", + left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 17}}, + right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + operator: OpMod, + expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 2}}, + expectErr: false, + }, + { + name: "Modulo by zero", + left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, + right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 0}}, + operator: OpMod, + expected: nil, + expectErr: true, + }, + // String conversion tests + { + name: "Add string number to integer", + left: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "15"}}, + right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + operator: OpAdd, + expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 20.0}}, + expectErr: false, + }, + { + name: "Invalid string conversion", + left: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "not_a_number"}}, + right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + operator: OpAdd, + expected: nil, + expectErr: true, + }, + // Boolean conversion tests + { + name: "Add boolean to integer", + left: &schema_pb.Value{Kind: &schema_pb.Value_BoolValue{BoolValue: true}}, + right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + operator: OpAdd, + expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 6.0}}, + expectErr: false, + }, + // Null value tests + { + name: "Add with null left operand", + left: nil, + right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + operator: OpAdd, + expected: nil, + expectErr: true, + }, + { + name: "Add with null right operand", + left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + right: nil, + operator: OpAdd, + expected: nil, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := engine.EvaluateArithmeticExpression(tt.left, tt.right, tt.operator) + + if tt.expectErr { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if !valuesEqual(result, tt.expected) { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestIndividualArithmeticFunctions(t *testing.T) { + engine := NewTestSQLEngine() + + left := &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}} + right := &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 3}} + + // Test Add function + result, err := engine.Add(left, right) + if err != nil { + t.Errorf("Add function failed: %v", err) + } + expected := &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 13}} + if !valuesEqual(result, expected) { + t.Errorf("Add: Expected %v, got %v", expected, result) + } + + // Test Subtract function + result, err = engine.Subtract(left, right) + if err != nil { + t.Errorf("Subtract function failed: %v", err) + } + expected = &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 7}} + if !valuesEqual(result, expected) { + t.Errorf("Subtract: Expected %v, got %v", expected, result) + } + + // Test Multiply function + result, err = engine.Multiply(left, right) + if err != nil { + t.Errorf("Multiply function failed: %v", err) + } + expected = &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 30}} + if !valuesEqual(result, expected) { + t.Errorf("Multiply: Expected %v, got %v", expected, result) + } + + // Test Divide function + result, err = engine.Divide(left, right) + if err != nil { + t.Errorf("Divide function failed: %v", err) + } + expected = &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 10.0/3.0}} + if !valuesEqual(result, expected) { + t.Errorf("Divide: Expected %v, got %v", expected, result) + } + + // Test Modulo function + result, err = engine.Modulo(left, right) + if err != nil { + t.Errorf("Modulo function failed: %v", err) + } + expected = &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 1}} + if !valuesEqual(result, expected) { + t.Errorf("Modulo: Expected %v, got %v", expected, result) + } +} + +func TestMathematicalFunctions(t *testing.T) { + engine := NewTestSQLEngine() + + t.Run("ROUND function tests", func(t *testing.T) { + tests := []struct { + name string + value *schema_pb.Value + precision *schema_pb.Value + expected *schema_pb.Value + expectErr bool + }{ + { + name: "Round float to integer", + value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.7}}, + precision: nil, + expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 4.0}}, + expectErr: false, + }, + { + name: "Round integer stays integer", + value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + precision: nil, + expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + expectErr: false, + }, + { + name: "Round with precision 2", + value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.14159}}, + precision: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 2}}, + expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.14}}, + expectErr: false, + }, + { + name: "Round negative number", + value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: -3.7}}, + precision: nil, + expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: -4.0}}, + expectErr: false, + }, + { + name: "Round null value", + value: nil, + precision: nil, + expected: nil, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var result *schema_pb.Value + var err error + + if tt.precision != nil { + result, err = engine.Round(tt.value, tt.precision) + } else { + result, err = engine.Round(tt.value) + } + + if tt.expectErr { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if !valuesEqual(result, tt.expected) { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } + }) + + t.Run("CEIL function tests", func(t *testing.T) { + tests := []struct { + name string + value *schema_pb.Value + expected *schema_pb.Value + expectErr bool + }{ + { + name: "Ceil positive decimal", + value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.2}}, + expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 4}}, + expectErr: false, + }, + { + name: "Ceil negative decimal", + value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: -3.2}}, + expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: -3}}, + expectErr: false, + }, + { + name: "Ceil integer", + value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + expectErr: false, + }, + { + name: "Ceil null value", + value: nil, + expected: nil, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := engine.Ceil(tt.value) + + if tt.expectErr { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if !valuesEqual(result, tt.expected) { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } + }) + + t.Run("FLOOR function tests", func(t *testing.T) { + tests := []struct { + name string + value *schema_pb.Value + expected *schema_pb.Value + expectErr bool + }{ + { + name: "Floor positive decimal", + value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.8}}, + expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 3}}, + expectErr: false, + }, + { + name: "Floor negative decimal", + value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: -3.2}}, + expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: -4}}, + expectErr: false, + }, + { + name: "Floor integer", + value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + expectErr: false, + }, + { + name: "Floor null value", + value: nil, + expected: nil, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := engine.Floor(tt.value) + + if tt.expectErr { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if !valuesEqual(result, tt.expected) { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } + }) + + t.Run("ABS function tests", func(t *testing.T) { + tests := []struct { + name string + value *schema_pb.Value + expected *schema_pb.Value + expectErr bool + }{ + { + name: "Abs positive integer", + value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + expectErr: false, + }, + { + name: "Abs negative integer", + value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: -5}}, + expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, + expectErr: false, + }, + { + name: "Abs positive double", + value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.14}}, + expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.14}}, + expectErr: false, + }, + { + name: "Abs negative double", + value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: -3.14}}, + expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.14}}, + expectErr: false, + }, + { + name: "Abs positive float", + value: &schema_pb.Value{Kind: &schema_pb.Value_FloatValue{FloatValue: 2.5}}, + expected: &schema_pb.Value{Kind: &schema_pb.Value_FloatValue{FloatValue: 2.5}}, + expectErr: false, + }, + { + name: "Abs negative float", + value: &schema_pb.Value{Kind: &schema_pb.Value_FloatValue{FloatValue: -2.5}}, + expected: &schema_pb.Value{Kind: &schema_pb.Value_FloatValue{FloatValue: 2.5}}, + expectErr: false, + }, + { + name: "Abs zero", + value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 0}}, + expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 0}}, + expectErr: false, + }, + { + name: "Abs null value", + value: nil, + expected: nil, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := engine.Abs(tt.value) + + if tt.expectErr { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if !valuesEqual(result, tt.expected) { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } + }) +} + +// Helper function to compare two schema_pb.Value objects +func valuesEqual(v1, v2 *schema_pb.Value) bool { + if v1 == nil && v2 == nil { + return true + } + if v1 == nil || v2 == nil { + return false + } + + switch v1Kind := v1.Kind.(type) { + case *schema_pb.Value_Int32Value: + if v2Kind, ok := v2.Kind.(*schema_pb.Value_Int32Value); ok { + return v1Kind.Int32Value == v2Kind.Int32Value + } + case *schema_pb.Value_Int64Value: + if v2Kind, ok := v2.Kind.(*schema_pb.Value_Int64Value); ok { + return v1Kind.Int64Value == v2Kind.Int64Value + } + case *schema_pb.Value_FloatValue: + if v2Kind, ok := v2.Kind.(*schema_pb.Value_FloatValue); ok { + return v1Kind.FloatValue == v2Kind.FloatValue + } + case *schema_pb.Value_DoubleValue: + if v2Kind, ok := v2.Kind.(*schema_pb.Value_DoubleValue); ok { + return v1Kind.DoubleValue == v2Kind.DoubleValue + } + case *schema_pb.Value_StringValue: + if v2Kind, ok := v2.Kind.(*schema_pb.Value_StringValue); ok { + return v1Kind.StringValue == v2Kind.StringValue + } + case *schema_pb.Value_BoolValue: + if v2Kind, ok := v2.Kind.(*schema_pb.Value_BoolValue); ok { + return v1Kind.BoolValue == v2Kind.BoolValue + } + } + + return false +} diff --git a/weed/query/engine/datetime_functions.go b/weed/query/engine/datetime_functions.go new file mode 100644 index 000000000..2ece58e15 --- /dev/null +++ b/weed/query/engine/datetime_functions.go @@ -0,0 +1,195 @@ +package engine + +import ( + "fmt" + "strings" + "time" + + "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" +) + +// =============================== +// DATE/TIME CONSTANTS +// =============================== + +// CurrentDate returns the current date as a string in YYYY-MM-DD format +func (e *SQLEngine) CurrentDate() (*schema_pb.Value, error) { + now := time.Now() + dateStr := now.Format("2006-01-02") + + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: dateStr}, + }, nil +} + +// CurrentTimestamp returns the current timestamp +func (e *SQLEngine) CurrentTimestamp() (*schema_pb.Value, error) { + now := time.Now() + + // Return as TimestampValue with microseconds + timestampMicros := now.UnixMicro() + + return &schema_pb.Value{ + Kind: &schema_pb.Value_TimestampValue{ + TimestampValue: &schema_pb.TimestampValue{ + TimestampMicros: timestampMicros, + }, + }, + }, nil +} + +// CurrentTime returns the current time as a string in HH:MM:SS format +func (e *SQLEngine) CurrentTime() (*schema_pb.Value, error) { + now := time.Now() + timeStr := now.Format("15:04:05") + + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: timeStr}, + }, nil +} + +// Now is an alias for CurrentTimestamp (common SQL function name) +func (e *SQLEngine) Now() (*schema_pb.Value, error) { + return e.CurrentTimestamp() +} + +// =============================== +// EXTRACT FUNCTION +// =============================== + +// DatePart represents the part of a date/time to extract +type DatePart string + +const ( + PartYear DatePart = "YEAR" + PartMonth DatePart = "MONTH" + PartDay DatePart = "DAY" + PartHour DatePart = "HOUR" + PartMinute DatePart = "MINUTE" + PartSecond DatePart = "SECOND" + PartWeek DatePart = "WEEK" + PartDayOfYear DatePart = "DOY" + PartDayOfWeek DatePart = "DOW" + PartQuarter DatePart = "QUARTER" + PartEpoch DatePart = "EPOCH" +) + +// Extract extracts a specific part from a date/time value +func (e *SQLEngine) Extract(part DatePart, value *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil { + return nil, fmt.Errorf("EXTRACT function requires non-null value") + } + + // Convert value to time + t, err := e.valueToTime(value) + if err != nil { + return nil, fmt.Errorf("EXTRACT function time conversion error: %v", err) + } + + var result int64 + + switch strings.ToUpper(string(part)) { + case string(PartYear): + result = int64(t.Year()) + case string(PartMonth): + result = int64(t.Month()) + case string(PartDay): + result = int64(t.Day()) + case string(PartHour): + result = int64(t.Hour()) + case string(PartMinute): + result = int64(t.Minute()) + case string(PartSecond): + result = int64(t.Second()) + case string(PartWeek): + _, week := t.ISOWeek() + result = int64(week) + case string(PartDayOfYear): + result = int64(t.YearDay()) + case string(PartDayOfWeek): + result = int64(t.Weekday()) + case string(PartQuarter): + month := t.Month() + result = int64((month-1)/3 + 1) + case string(PartEpoch): + result = t.Unix() + default: + return nil, fmt.Errorf("unsupported date part: %s", part) + } + + return &schema_pb.Value{ + Kind: &schema_pb.Value_Int64Value{Int64Value: result}, + }, nil +} + +// =============================== +// DATE_TRUNC FUNCTION +// =============================== + +// DateTrunc truncates a date/time to the specified precision +func (e *SQLEngine) DateTrunc(precision string, value *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil { + return nil, fmt.Errorf("DATE_TRUNC function requires non-null value") + } + + // Convert value to time + t, err := e.valueToTime(value) + if err != nil { + return nil, fmt.Errorf("DATE_TRUNC function time conversion error: %v", err) + } + + var truncated time.Time + + switch strings.ToLower(precision) { + case "microsecond", "microseconds": + // No truncation needed for microsecond precision + truncated = t + case "millisecond", "milliseconds": + truncated = t.Truncate(time.Millisecond) + case "second", "seconds": + truncated = t.Truncate(time.Second) + case "minute", "minutes": + truncated = t.Truncate(time.Minute) + case "hour", "hours": + truncated = t.Truncate(time.Hour) + case "day", "days": + truncated = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) + case "week", "weeks": + // Truncate to beginning of week (Monday) + days := int(t.Weekday()) + if days == 0 { // Sunday = 0, adjust to make Monday = 0 + days = 6 + } else { + days = days - 1 + } + truncated = time.Date(t.Year(), t.Month(), t.Day()-days, 0, 0, 0, 0, t.Location()) + case "month", "months": + truncated = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location()) + case "quarter", "quarters": + month := t.Month() + quarterMonth := ((int(month)-1)/3)*3 + 1 + truncated = time.Date(t.Year(), time.Month(quarterMonth), 1, 0, 0, 0, 0, t.Location()) + case "year", "years": + truncated = time.Date(t.Year(), 1, 1, 0, 0, 0, 0, t.Location()) + case "decade", "decades": + year := (t.Year()/10) * 10 + truncated = time.Date(year, 1, 1, 0, 0, 0, 0, t.Location()) + case "century", "centuries": + year := ((t.Year()-1)/100)*100 + 1 + truncated = time.Date(year, 1, 1, 0, 0, 0, 0, t.Location()) + case "millennium", "millennia": + year := ((t.Year()-1)/1000)*1000 + 1 + truncated = time.Date(year, 1, 1, 0, 0, 0, 0, t.Location()) + default: + return nil, fmt.Errorf("unsupported date truncation precision: %s", precision) + } + + // Return as TimestampValue + return &schema_pb.Value{ + Kind: &schema_pb.Value_TimestampValue{ + TimestampValue: &schema_pb.TimestampValue{ + TimestampMicros: truncated.UnixMicro(), + }, + }, + }, nil +} diff --git a/weed/query/engine/datetime_functions_test.go b/weed/query/engine/datetime_functions_test.go new file mode 100644 index 000000000..5dba04ce3 --- /dev/null +++ b/weed/query/engine/datetime_functions_test.go @@ -0,0 +1,418 @@ +package engine + +import ( + "testing" + "time" + + "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" +) + +func TestDateTimeFunctions(t *testing.T) { + engine := NewTestSQLEngine() + + t.Run("CURRENT_DATE function tests", func(t *testing.T) { + result, err := engine.CurrentDate() + if err != nil { + t.Errorf("CurrentDate failed: %v", err) + } + + if result == nil { + t.Errorf("CurrentDate returned nil result") + return + } + + stringVal, ok := result.Kind.(*schema_pb.Value_StringValue) + if !ok { + t.Errorf("CurrentDate should return string value, got %T", result.Kind) + return + } + + // Check format (YYYY-MM-DD) + today := time.Now().Format("2006-01-02") + if stringVal.StringValue != today { + t.Errorf("Expected current date %s, got %s", today, stringVal.StringValue) + } + }) + + t.Run("CURRENT_TIMESTAMP function tests", func(t *testing.T) { + before := time.Now() + result, err := engine.CurrentTimestamp() + after := time.Now() + + if err != nil { + t.Errorf("CurrentTimestamp failed: %v", err) + } + + if result == nil { + t.Errorf("CurrentTimestamp returned nil result") + return + } + + timestampVal, ok := result.Kind.(*schema_pb.Value_TimestampValue) + if !ok { + t.Errorf("CurrentTimestamp should return timestamp value, got %T", result.Kind) + return + } + + timestamp := time.UnixMicro(timestampVal.TimestampValue.TimestampMicros) + + // Check that timestamp is within reasonable range + if timestamp.Before(before) || timestamp.After(after) { + t.Errorf("Timestamp %v should be between %v and %v", timestamp, before, after) + } + }) + + t.Run("NOW function tests", func(t *testing.T) { + result, err := engine.Now() + if err != nil { + t.Errorf("Now failed: %v", err) + } + + if result == nil { + t.Errorf("Now returned nil result") + return + } + + // Should return same type as CurrentTimestamp + _, ok := result.Kind.(*schema_pb.Value_TimestampValue) + if !ok { + t.Errorf("Now should return timestamp value, got %T", result.Kind) + } + }) + + t.Run("CURRENT_TIME function tests", func(t *testing.T) { + result, err := engine.CurrentTime() + if err != nil { + t.Errorf("CurrentTime failed: %v", err) + } + + if result == nil { + t.Errorf("CurrentTime returned nil result") + return + } + + stringVal, ok := result.Kind.(*schema_pb.Value_StringValue) + if !ok { + t.Errorf("CurrentTime should return string value, got %T", result.Kind) + return + } + + // Check format (HH:MM:SS) + if len(stringVal.StringValue) != 8 || stringVal.StringValue[2] != ':' || stringVal.StringValue[5] != ':' { + t.Errorf("CurrentTime should return HH:MM:SS format, got %s", stringVal.StringValue) + } + }) +} + +func TestExtractFunction(t *testing.T) { + engine := NewTestSQLEngine() + + // Create a test timestamp: 2023-06-15 14:30:45 + // Use local time to avoid timezone conversion issues + testTime := time.Date(2023, 6, 15, 14, 30, 45, 0, time.Local) + testTimestamp := &schema_pb.Value{ + Kind: &schema_pb.Value_TimestampValue{ + TimestampValue: &schema_pb.TimestampValue{ + TimestampMicros: testTime.UnixMicro(), + }, + }, + } + + tests := []struct { + name string + part DatePart + value *schema_pb.Value + expected int64 + expectErr bool + }{ + { + name: "Extract YEAR", + part: PartYear, + value: testTimestamp, + expected: 2023, + expectErr: false, + }, + { + name: "Extract MONTH", + part: PartMonth, + value: testTimestamp, + expected: 6, + expectErr: false, + }, + { + name: "Extract DAY", + part: PartDay, + value: testTimestamp, + expected: 15, + expectErr: false, + }, + { + name: "Extract HOUR", + part: PartHour, + value: testTimestamp, + expected: 14, + expectErr: false, + }, + { + name: "Extract MINUTE", + part: PartMinute, + value: testTimestamp, + expected: 30, + expectErr: false, + }, + { + name: "Extract SECOND", + part: PartSecond, + value: testTimestamp, + expected: 45, + expectErr: false, + }, + { + name: "Extract QUARTER from June", + part: PartQuarter, + value: testTimestamp, + expected: 2, // June is in Q2 + expectErr: false, + }, + { + name: "Extract from string date", + part: PartYear, + value: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "2023-06-15"}}, + expected: 2023, + expectErr: false, + }, + { + name: "Extract from Unix timestamp", + part: PartYear, + value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: testTime.Unix()}}, + expected: 2023, + expectErr: false, + }, + { + name: "Extract from null value", + part: PartYear, + value: nil, + expected: 0, + expectErr: true, + }, + { + name: "Extract invalid part", + part: DatePart("INVALID"), + value: testTimestamp, + expected: 0, + expectErr: true, + }, + { + name: "Extract from invalid string", + part: PartYear, + value: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "invalid-date"}}, + expected: 0, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := engine.Extract(tt.part, tt.value) + + if tt.expectErr { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if result == nil { + t.Errorf("Extract returned nil result") + return + } + + intVal, ok := result.Kind.(*schema_pb.Value_Int64Value) + if !ok { + t.Errorf("Extract should return int64 value, got %T", result.Kind) + return + } + + if intVal.Int64Value != tt.expected { + t.Errorf("Expected %d, got %d", tt.expected, intVal.Int64Value) + } + }) + } +} + +func TestDateTruncFunction(t *testing.T) { + engine := NewTestSQLEngine() + + // Create a test timestamp: 2023-06-15 14:30:45.123456 + testTime := time.Date(2023, 6, 15, 14, 30, 45, 123456000, time.Local) // nanoseconds + testTimestamp := &schema_pb.Value{ + Kind: &schema_pb.Value_TimestampValue{ + TimestampValue: &schema_pb.TimestampValue{ + TimestampMicros: testTime.UnixMicro(), + }, + }, + } + + tests := []struct { + name string + precision string + value *schema_pb.Value + expectErr bool + expectedCheck func(result time.Time) bool // Custom check function + }{ + { + name: "Truncate to second", + precision: "second", + value: testTimestamp, + expectErr: false, + expectedCheck: func(result time.Time) bool { + return result.Year() == 2023 && result.Month() == 6 && result.Day() == 15 && + result.Hour() == 14 && result.Minute() == 30 && result.Second() == 45 && + result.Nanosecond() == 0 + }, + }, + { + name: "Truncate to minute", + precision: "minute", + value: testTimestamp, + expectErr: false, + expectedCheck: func(result time.Time) bool { + return result.Year() == 2023 && result.Month() == 6 && result.Day() == 15 && + result.Hour() == 14 && result.Minute() == 30 && result.Second() == 0 && + result.Nanosecond() == 0 + }, + }, + { + name: "Truncate to hour", + precision: "hour", + value: testTimestamp, + expectErr: false, + expectedCheck: func(result time.Time) bool { + return result.Year() == 2023 && result.Month() == 6 && result.Day() == 15 && + result.Hour() == 14 && result.Minute() == 0 && result.Second() == 0 && + result.Nanosecond() == 0 + }, + }, + { + name: "Truncate to day", + precision: "day", + value: testTimestamp, + expectErr: false, + expectedCheck: func(result time.Time) bool { + return result.Year() == 2023 && result.Month() == 6 && result.Day() == 15 && + result.Hour() == 0 && result.Minute() == 0 && result.Second() == 0 && + result.Nanosecond() == 0 + }, + }, + { + name: "Truncate to month", + precision: "month", + value: testTimestamp, + expectErr: false, + expectedCheck: func(result time.Time) bool { + return result.Year() == 2023 && result.Month() == 6 && result.Day() == 1 && + result.Hour() == 0 && result.Minute() == 0 && result.Second() == 0 && + result.Nanosecond() == 0 + }, + }, + { + name: "Truncate to quarter", + precision: "quarter", + value: testTimestamp, + expectErr: false, + expectedCheck: func(result time.Time) bool { + // June (month 6) should truncate to April (month 4) - start of Q2 + return result.Year() == 2023 && result.Month() == 4 && result.Day() == 1 && + result.Hour() == 0 && result.Minute() == 0 && result.Second() == 0 && + result.Nanosecond() == 0 + }, + }, + { + name: "Truncate to year", + precision: "year", + value: testTimestamp, + expectErr: false, + expectedCheck: func(result time.Time) bool { + return result.Year() == 2023 && result.Month() == 1 && result.Day() == 1 && + result.Hour() == 0 && result.Minute() == 0 && result.Second() == 0 && + result.Nanosecond() == 0 + }, + }, + { + name: "Truncate with plural precision", + precision: "minutes", // Test plural form + value: testTimestamp, + expectErr: false, + expectedCheck: func(result time.Time) bool { + return result.Year() == 2023 && result.Month() == 6 && result.Day() == 15 && + result.Hour() == 14 && result.Minute() == 30 && result.Second() == 0 && + result.Nanosecond() == 0 + }, + }, + { + name: "Truncate from string date", + precision: "day", + value: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "2023-06-15 14:30:45"}}, + expectErr: false, + expectedCheck: func(result time.Time) bool { + // The result should be the start of day 2023-06-15 in local timezone + expectedDay := time.Date(2023, 6, 15, 0, 0, 0, 0, result.Location()) + return result.Equal(expectedDay) + }, + }, + { + name: "Truncate null value", + precision: "day", + value: nil, + expectErr: true, + expectedCheck: nil, + }, + { + name: "Invalid precision", + precision: "invalid", + value: testTimestamp, + expectErr: true, + expectedCheck: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := engine.DateTrunc(tt.precision, tt.value) + + if tt.expectErr { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if result == nil { + t.Errorf("DateTrunc returned nil result") + return + } + + timestampVal, ok := result.Kind.(*schema_pb.Value_TimestampValue) + if !ok { + t.Errorf("DateTrunc should return timestamp value, got %T", result.Kind) + return + } + + resultTime := time.UnixMicro(timestampVal.TimestampValue.TimestampMicros) + + if !tt.expectedCheck(resultTime) { + t.Errorf("DateTrunc result check failed for precision %s, got time: %v", tt.precision, resultTime) + } + }) + } +} diff --git a/weed/query/engine/function_helpers.go b/weed/query/engine/function_helpers.go new file mode 100644 index 000000000..3f1025ed9 --- /dev/null +++ b/weed/query/engine/function_helpers.go @@ -0,0 +1,131 @@ +package engine + +import ( + "fmt" + "strconv" + "time" + + "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" +) + +// Helper function to convert schema_pb.Value to float64 +func (e *SQLEngine) valueToFloat64(value *schema_pb.Value) (float64, error) { + switch v := value.Kind.(type) { + case *schema_pb.Value_Int32Value: + return float64(v.Int32Value), nil + case *schema_pb.Value_Int64Value: + return float64(v.Int64Value), nil + case *schema_pb.Value_FloatValue: + return float64(v.FloatValue), nil + case *schema_pb.Value_DoubleValue: + return v.DoubleValue, nil + case *schema_pb.Value_StringValue: + // Try to parse string as number + if f, err := strconv.ParseFloat(v.StringValue, 64); err == nil { + return f, nil + } + return 0, fmt.Errorf("cannot convert string '%s' to number", v.StringValue) + case *schema_pb.Value_BoolValue: + if v.BoolValue { + return 1, nil + } + return 0, nil + default: + return 0, fmt.Errorf("cannot convert value type to number") + } +} + +// Helper function to check if a value is an integer type +func (e *SQLEngine) isIntegerValue(value *schema_pb.Value) bool { + switch value.Kind.(type) { + case *schema_pb.Value_Int32Value, *schema_pb.Value_Int64Value: + return true + default: + return false + } +} + +// Helper function to convert schema_pb.Value to string +func (e *SQLEngine) valueToString(value *schema_pb.Value) (string, error) { + switch v := value.Kind.(type) { + case *schema_pb.Value_StringValue: + return v.StringValue, nil + case *schema_pb.Value_Int32Value: + return strconv.FormatInt(int64(v.Int32Value), 10), nil + case *schema_pb.Value_Int64Value: + return strconv.FormatInt(v.Int64Value, 10), nil + case *schema_pb.Value_FloatValue: + return strconv.FormatFloat(float64(v.FloatValue), 'g', -1, 32), nil + case *schema_pb.Value_DoubleValue: + return strconv.FormatFloat(v.DoubleValue, 'g', -1, 64), nil + case *schema_pb.Value_BoolValue: + if v.BoolValue { + return "true", nil + } + return "false", nil + case *schema_pb.Value_BytesValue: + return string(v.BytesValue), nil + default: + return "", fmt.Errorf("cannot convert value type to string") + } +} + +// Helper function to convert schema_pb.Value to int64 +func (e *SQLEngine) valueToInt64(value *schema_pb.Value) (int64, error) { + switch v := value.Kind.(type) { + case *schema_pb.Value_Int32Value: + return int64(v.Int32Value), nil + case *schema_pb.Value_Int64Value: + return v.Int64Value, nil + case *schema_pb.Value_FloatValue: + return int64(v.FloatValue), nil + case *schema_pb.Value_DoubleValue: + return int64(v.DoubleValue), nil + case *schema_pb.Value_StringValue: + if i, err := strconv.ParseInt(v.StringValue, 10, 64); err == nil { + return i, nil + } + return 0, fmt.Errorf("cannot convert string '%s' to integer", v.StringValue) + default: + return 0, fmt.Errorf("cannot convert value type to integer") + } +} + +// Helper function to convert schema_pb.Value to time.Time +func (e *SQLEngine) valueToTime(value *schema_pb.Value) (time.Time, error) { + switch v := value.Kind.(type) { + case *schema_pb.Value_TimestampValue: + if v.TimestampValue == nil { + return time.Time{}, fmt.Errorf("null timestamp value") + } + return time.UnixMicro(v.TimestampValue.TimestampMicros), nil + case *schema_pb.Value_StringValue: + // Try to parse various date/time string formats + dateFormats := []struct { + format string + useLocal bool + }{ + {"2006-01-02 15:04:05", true}, // Local time assumed for non-timezone formats + {"2006-01-02T15:04:05Z", false}, // UTC format + {"2006-01-02T15:04:05", true}, // Local time assumed + {"2006-01-02", true}, // Local time assumed for date only + {"15:04:05", true}, // Local time assumed for time only + } + + for _, formatSpec := range dateFormats { + if t, err := time.Parse(formatSpec.format, v.StringValue); err == nil { + if formatSpec.useLocal { + // Convert to local timezone if no timezone was specified + return time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.Local), nil + } + return t, nil + } + } + return time.Time{}, fmt.Errorf("unable to parse date/time string: %s", v.StringValue) + case *schema_pb.Value_Int64Value: + // Assume Unix timestamp (seconds) + return time.Unix(v.Int64Value, 0), nil + default: + return time.Time{}, fmt.Errorf("cannot convert value type to date/time") + } +} diff --git a/weed/query/engine/sql_functions.go b/weed/query/engine/sql_functions.go deleted file mode 100644 index 168d09a65..000000000 --- a/weed/query/engine/sql_functions.go +++ /dev/null @@ -1,850 +0,0 @@ -package engine - -import ( - "fmt" - "math" - "strconv" - "strings" - "time" - - "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" -) - -// ArithmeticOperator represents basic arithmetic operations -type ArithmeticOperator string - -const ( - OpAdd ArithmeticOperator = "+" - OpSub ArithmeticOperator = "-" - OpMul ArithmeticOperator = "*" - OpDiv ArithmeticOperator = "/" - OpMod ArithmeticOperator = "%" -) - -// EvaluateArithmeticExpression evaluates basic arithmetic operations between two values -func (e *SQLEngine) EvaluateArithmeticExpression(left, right *schema_pb.Value, operator ArithmeticOperator) (*schema_pb.Value, error) { - if left == nil || right == nil { - return nil, fmt.Errorf("arithmetic operation requires non-null operands") - } - - // Convert values to numeric types for calculation - leftNum, err := e.valueToFloat64(left) - if err != nil { - return nil, fmt.Errorf("left operand conversion error: %v", err) - } - - rightNum, err := e.valueToFloat64(right) - if err != nil { - return nil, fmt.Errorf("right operand conversion error: %v", err) - } - - var result float64 - var resultErr error - - switch operator { - case OpAdd: - result = leftNum + rightNum - case OpSub: - result = leftNum - rightNum - case OpMul: - result = leftNum * rightNum - case OpDiv: - if rightNum == 0 { - return nil, fmt.Errorf("division by zero") - } - result = leftNum / rightNum - case OpMod: - if rightNum == 0 { - return nil, fmt.Errorf("modulo by zero") - } - result = math.Mod(leftNum, rightNum) - default: - return nil, fmt.Errorf("unsupported arithmetic operator: %s", operator) - } - - if resultErr != nil { - return nil, resultErr - } - - // Convert result back to appropriate schema value type - // If both operands were integers and operation doesn't produce decimal, return integer - if e.isIntegerValue(left) && e.isIntegerValue(right) && - (operator == OpAdd || operator == OpSub || operator == OpMul || operator == OpMod) { - return &schema_pb.Value{ - Kind: &schema_pb.Value_Int64Value{Int64Value: int64(result)}, - }, nil - } - - // Otherwise return as double/float - return &schema_pb.Value{ - Kind: &schema_pb.Value_DoubleValue{DoubleValue: result}, - }, nil -} - -// Helper function to convert schema_pb.Value to float64 -func (e *SQLEngine) valueToFloat64(value *schema_pb.Value) (float64, error) { - switch v := value.Kind.(type) { - case *schema_pb.Value_Int32Value: - return float64(v.Int32Value), nil - case *schema_pb.Value_Int64Value: - return float64(v.Int64Value), nil - case *schema_pb.Value_FloatValue: - return float64(v.FloatValue), nil - case *schema_pb.Value_DoubleValue: - return v.DoubleValue, nil - case *schema_pb.Value_StringValue: - // Try to parse string as number - if f, err := strconv.ParseFloat(v.StringValue, 64); err == nil { - return f, nil - } - return 0, fmt.Errorf("cannot convert string '%s' to number", v.StringValue) - case *schema_pb.Value_BoolValue: - if v.BoolValue { - return 1, nil - } - return 0, nil - default: - return 0, fmt.Errorf("cannot convert value type to number") - } -} - -// Helper function to check if a value is an integer type -func (e *SQLEngine) isIntegerValue(value *schema_pb.Value) bool { - switch value.Kind.(type) { - case *schema_pb.Value_Int32Value, *schema_pb.Value_Int64Value: - return true - default: - return false - } -} - -// Add evaluates addition (left + right) -func (e *SQLEngine) Add(left, right *schema_pb.Value) (*schema_pb.Value, error) { - return e.EvaluateArithmeticExpression(left, right, OpAdd) -} - -// Subtract evaluates subtraction (left - right) -func (e *SQLEngine) Subtract(left, right *schema_pb.Value) (*schema_pb.Value, error) { - return e.EvaluateArithmeticExpression(left, right, OpSub) -} - -// Multiply evaluates multiplication (left * right) -func (e *SQLEngine) Multiply(left, right *schema_pb.Value) (*schema_pb.Value, error) { - return e.EvaluateArithmeticExpression(left, right, OpMul) -} - -// Divide evaluates division (left / right) -func (e *SQLEngine) Divide(left, right *schema_pb.Value) (*schema_pb.Value, error) { - return e.EvaluateArithmeticExpression(left, right, OpDiv) -} - -// Modulo evaluates modulo operation (left % right) -func (e *SQLEngine) Modulo(left, right *schema_pb.Value) (*schema_pb.Value, error) { - return e.EvaluateArithmeticExpression(left, right, OpMod) -} - -// =============================== -// MATHEMATICAL FUNCTIONS -// =============================== - -// Round rounds a numeric value to the nearest integer or specified decimal places -func (e *SQLEngine) Round(value *schema_pb.Value, precision ...*schema_pb.Value) (*schema_pb.Value, error) { - if value == nil { - return nil, fmt.Errorf("ROUND function requires non-null value") - } - - num, err := e.valueToFloat64(value) - if err != nil { - return nil, fmt.Errorf("ROUND function conversion error: %v", err) - } - - // Default precision is 0 (round to integer) - precisionValue := 0 - if len(precision) > 0 && precision[0] != nil { - precFloat, err := e.valueToFloat64(precision[0]) - if err != nil { - return nil, fmt.Errorf("ROUND precision conversion error: %v", err) - } - precisionValue = int(precFloat) - } - - // Apply rounding - multiplier := math.Pow(10, float64(precisionValue)) - rounded := math.Round(num*multiplier) / multiplier - - // Return as integer if precision is 0 and original was integer, otherwise as double - if precisionValue == 0 && e.isIntegerValue(value) { - return &schema_pb.Value{ - Kind: &schema_pb.Value_Int64Value{Int64Value: int64(rounded)}, - }, nil - } - - return &schema_pb.Value{ - Kind: &schema_pb.Value_DoubleValue{DoubleValue: rounded}, - }, nil -} - -// Ceil returns the smallest integer greater than or equal to the value -func (e *SQLEngine) Ceil(value *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil { - return nil, fmt.Errorf("CEIL function requires non-null value") - } - - num, err := e.valueToFloat64(value) - if err != nil { - return nil, fmt.Errorf("CEIL function conversion error: %v", err) - } - - result := math.Ceil(num) - - return &schema_pb.Value{ - Kind: &schema_pb.Value_Int64Value{Int64Value: int64(result)}, - }, nil -} - -// Floor returns the largest integer less than or equal to the value -func (e *SQLEngine) Floor(value *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil { - return nil, fmt.Errorf("FLOOR function requires non-null value") - } - - num, err := e.valueToFloat64(value) - if err != nil { - return nil, fmt.Errorf("FLOOR function conversion error: %v", err) - } - - result := math.Floor(num) - - return &schema_pb.Value{ - Kind: &schema_pb.Value_Int64Value{Int64Value: int64(result)}, - }, nil -} - -// Abs returns the absolute value of a number -func (e *SQLEngine) Abs(value *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil { - return nil, fmt.Errorf("ABS function requires non-null value") - } - - num, err := e.valueToFloat64(value) - if err != nil { - return nil, fmt.Errorf("ABS function conversion error: %v", err) - } - - result := math.Abs(num) - - // Return same type as input if possible - if e.isIntegerValue(value) { - return &schema_pb.Value{ - Kind: &schema_pb.Value_Int64Value{Int64Value: int64(result)}, - }, nil - } - - // Check if original was float32 - if _, ok := value.Kind.(*schema_pb.Value_FloatValue); ok { - return &schema_pb.Value{ - Kind: &schema_pb.Value_FloatValue{FloatValue: float32(result)}, - }, nil - } - - // Default to double - return &schema_pb.Value{ - Kind: &schema_pb.Value_DoubleValue{DoubleValue: result}, - }, nil -} - -// =============================== -// DATE/TIME CONSTANTS -// =============================== - -// CurrentDate returns the current date as a string in YYYY-MM-DD format -func (e *SQLEngine) CurrentDate() (*schema_pb.Value, error) { - now := time.Now() - dateStr := now.Format("2006-01-02") - - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: dateStr}, - }, nil -} - -// CurrentTimestamp returns the current timestamp -func (e *SQLEngine) CurrentTimestamp() (*schema_pb.Value, error) { - now := time.Now() - - // Return as TimestampValue with microseconds - timestampMicros := now.UnixMicro() - - return &schema_pb.Value{ - Kind: &schema_pb.Value_TimestampValue{ - TimestampValue: &schema_pb.TimestampValue{ - TimestampMicros: timestampMicros, - }, - }, - }, nil -} - -// CurrentTime returns the current time as a string in HH:MM:SS format -func (e *SQLEngine) CurrentTime() (*schema_pb.Value, error) { - now := time.Now() - timeStr := now.Format("15:04:05") - - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: timeStr}, - }, nil -} - -// Now is an alias for CurrentTimestamp (common SQL function name) -func (e *SQLEngine) Now() (*schema_pb.Value, error) { - return e.CurrentTimestamp() -} - -// =============================== -// EXTRACT FUNCTION -// =============================== - -// DatePart represents the part of a date/time to extract -type DatePart string - -const ( - PartYear DatePart = "YEAR" - PartMonth DatePart = "MONTH" - PartDay DatePart = "DAY" - PartHour DatePart = "HOUR" - PartMinute DatePart = "MINUTE" - PartSecond DatePart = "SECOND" - PartWeek DatePart = "WEEK" - PartDayOfYear DatePart = "DOY" - PartDayOfWeek DatePart = "DOW" - PartQuarter DatePart = "QUARTER" - PartEpoch DatePart = "EPOCH" -) - -// Extract extracts a specific part from a date/time value -func (e *SQLEngine) Extract(part DatePart, value *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil { - return nil, fmt.Errorf("EXTRACT function requires non-null value") - } - - // Convert value to time - t, err := e.valueToTime(value) - if err != nil { - return nil, fmt.Errorf("EXTRACT function time conversion error: %v", err) - } - - var result int64 - - switch strings.ToUpper(string(part)) { - case string(PartYear): - result = int64(t.Year()) - case string(PartMonth): - result = int64(t.Month()) - case string(PartDay): - result = int64(t.Day()) - case string(PartHour): - result = int64(t.Hour()) - case string(PartMinute): - result = int64(t.Minute()) - case string(PartSecond): - result = int64(t.Second()) - case string(PartWeek): - _, week := t.ISOWeek() - result = int64(week) - case string(PartDayOfYear): - result = int64(t.YearDay()) - case string(PartDayOfWeek): - result = int64(t.Weekday()) - case string(PartQuarter): - month := t.Month() - result = int64((month-1)/3 + 1) - case string(PartEpoch): - result = t.Unix() - default: - return nil, fmt.Errorf("unsupported date part: %s", part) - } - - return &schema_pb.Value{ - Kind: &schema_pb.Value_Int64Value{Int64Value: result}, - }, nil -} - -// Helper function to convert schema_pb.Value to time.Time -func (e *SQLEngine) valueToTime(value *schema_pb.Value) (time.Time, error) { - switch v := value.Kind.(type) { - case *schema_pb.Value_TimestampValue: - if v.TimestampValue == nil { - return time.Time{}, fmt.Errorf("null timestamp value") - } - return time.UnixMicro(v.TimestampValue.TimestampMicros), nil - case *schema_pb.Value_StringValue: - // Try to parse various date/time string formats - dateFormats := []struct { - format string - useLocal bool - }{ - {"2006-01-02 15:04:05", true}, // Local time assumed for non-timezone formats - {"2006-01-02T15:04:05Z", false}, // UTC format - {"2006-01-02T15:04:05", true}, // Local time assumed - {"2006-01-02", true}, // Local time assumed for date only - {"15:04:05", true}, // Local time assumed for time only - } - - for _, formatSpec := range dateFormats { - if t, err := time.Parse(formatSpec.format, v.StringValue); err == nil { - if formatSpec.useLocal { - // Convert to local timezone if no timezone was specified - return time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.Local), nil - } - return t, nil - } - } - return time.Time{}, fmt.Errorf("unable to parse date/time string: %s", v.StringValue) - case *schema_pb.Value_Int64Value: - // Assume Unix timestamp (seconds) - return time.Unix(v.Int64Value, 0), nil - default: - return time.Time{}, fmt.Errorf("cannot convert value type to date/time") - } -} - -// =============================== -// DATE_TRUNC FUNCTION -// =============================== - -// DateTrunc truncates a date/time to the specified precision -func (e *SQLEngine) DateTrunc(precision string, value *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil { - return nil, fmt.Errorf("DATE_TRUNC function requires non-null value") - } - - // Convert value to time - t, err := e.valueToTime(value) - if err != nil { - return nil, fmt.Errorf("DATE_TRUNC function time conversion error: %v", err) - } - - var truncated time.Time - - switch strings.ToLower(precision) { - case "microsecond", "microseconds": - // No truncation needed for microsecond precision - truncated = t - case "millisecond", "milliseconds": - truncated = t.Truncate(time.Millisecond) - case "second", "seconds": - truncated = t.Truncate(time.Second) - case "minute", "minutes": - truncated = t.Truncate(time.Minute) - case "hour", "hours": - truncated = t.Truncate(time.Hour) - case "day", "days": - truncated = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) - case "week", "weeks": - // Truncate to beginning of week (Monday) - days := int(t.Weekday()) - if days == 0 { // Sunday = 0, adjust to make Monday = 0 - days = 6 - } else { - days = days - 1 - } - truncated = time.Date(t.Year(), t.Month(), t.Day()-days, 0, 0, 0, 0, t.Location()) - case "month", "months": - truncated = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location()) - case "quarter", "quarters": - month := t.Month() - quarterMonth := ((int(month)-1)/3)*3 + 1 - truncated = time.Date(t.Year(), time.Month(quarterMonth), 1, 0, 0, 0, 0, t.Location()) - case "year", "years": - truncated = time.Date(t.Year(), 1, 1, 0, 0, 0, 0, t.Location()) - case "decade", "decades": - year := (t.Year()/10) * 10 - truncated = time.Date(year, 1, 1, 0, 0, 0, 0, t.Location()) - case "century", "centuries": - year := ((t.Year()-1)/100)*100 + 1 - truncated = time.Date(year, 1, 1, 0, 0, 0, 0, t.Location()) - case "millennium", "millennia": - year := ((t.Year()-1)/1000)*1000 + 1 - truncated = time.Date(year, 1, 1, 0, 0, 0, 0, t.Location()) - default: - return nil, fmt.Errorf("unsupported date truncation precision: %s", precision) - } - - // Return as TimestampValue - return &schema_pb.Value{ - Kind: &schema_pb.Value_TimestampValue{ - TimestampValue: &schema_pb.TimestampValue{ - TimestampMicros: truncated.UnixMicro(), - }, - }, - }, nil -} - -// =============================== -// STRING FUNCTIONS -// =============================== - -// Length returns the length of a string -func (e *SQLEngine) Length(value *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil { - return nil, fmt.Errorf("LENGTH function requires non-null value") - } - - str, err := e.valueToString(value) - if err != nil { - return nil, fmt.Errorf("LENGTH function conversion error: %v", err) - } - - length := int64(len(str)) - return &schema_pb.Value{ - Kind: &schema_pb.Value_Int64Value{Int64Value: length}, - }, nil -} - -// Upper converts a string to uppercase -func (e *SQLEngine) Upper(value *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil { - return nil, fmt.Errorf("UPPER function requires non-null value") - } - - str, err := e.valueToString(value) - if err != nil { - return nil, fmt.Errorf("UPPER function conversion error: %v", err) - } - - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: strings.ToUpper(str)}, - }, nil -} - -// Lower converts a string to lowercase -func (e *SQLEngine) Lower(value *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil { - return nil, fmt.Errorf("LOWER function requires non-null value") - } - - str, err := e.valueToString(value) - if err != nil { - return nil, fmt.Errorf("LOWER function conversion error: %v", err) - } - - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: strings.ToLower(str)}, - }, nil -} - -// Trim removes leading and trailing whitespace from a string -func (e *SQLEngine) Trim(value *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil { - return nil, fmt.Errorf("TRIM function requires non-null value") - } - - str, err := e.valueToString(value) - if err != nil { - return nil, fmt.Errorf("TRIM function conversion error: %v", err) - } - - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: strings.TrimSpace(str)}, - }, nil -} - -// LTrim removes leading whitespace from a string -func (e *SQLEngine) LTrim(value *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil { - return nil, fmt.Errorf("LTRIM function requires non-null value") - } - - str, err := e.valueToString(value) - if err != nil { - return nil, fmt.Errorf("LTRIM function conversion error: %v", err) - } - - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: strings.TrimLeft(str, " \t\n\r")}, - }, nil -} - -// RTrim removes trailing whitespace from a string -func (e *SQLEngine) RTrim(value *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil { - return nil, fmt.Errorf("RTRIM function requires non-null value") - } - - str, err := e.valueToString(value) - if err != nil { - return nil, fmt.Errorf("RTRIM function conversion error: %v", err) - } - - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: strings.TrimRight(str, " \t\n\r")}, - }, nil -} - -// Substring extracts a substring from a string -func (e *SQLEngine) Substring(value *schema_pb.Value, start *schema_pb.Value, length ...*schema_pb.Value) (*schema_pb.Value, error) { - if value == nil || start == nil { - return nil, fmt.Errorf("SUBSTRING function requires non-null value and start position") - } - - str, err := e.valueToString(value) - if err != nil { - return nil, fmt.Errorf("SUBSTRING function value conversion error: %v", err) - } - - startPos, err := e.valueToInt64(start) - if err != nil { - return nil, fmt.Errorf("SUBSTRING function start position conversion error: %v", err) - } - - // Convert to 0-based indexing (SQL uses 1-based) - if startPos < 1 { - startPos = 1 - } - startIdx := int(startPos - 1) - - if startIdx >= len(str) { - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: ""}, - }, nil - } - - var result string - if len(length) > 0 && length[0] != nil { - lengthVal, err := e.valueToInt64(length[0]) - if err != nil { - return nil, fmt.Errorf("SUBSTRING function length conversion error: %v", err) - } - - if lengthVal <= 0 { - result = "" - } else { - endIdx := startIdx + int(lengthVal) - if endIdx > len(str) { - endIdx = len(str) - } - result = str[startIdx:endIdx] - } - } else { - result = str[startIdx:] - } - - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: result}, - }, nil -} - -// Concat concatenates multiple strings -func (e *SQLEngine) Concat(values ...*schema_pb.Value) (*schema_pb.Value, error) { - if len(values) == 0 { - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: ""}, - }, nil - } - - var result strings.Builder - for i, value := range values { - if value == nil { - continue // Skip null values - } - - str, err := e.valueToString(value) - if err != nil { - return nil, fmt.Errorf("CONCAT function value %d conversion error: %v", i, err) - } - result.WriteString(str) - } - - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: result.String()}, - }, nil -} - -// Replace replaces all occurrences of a substring with another substring -func (e *SQLEngine) Replace(value, oldStr, newStr *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil || oldStr == nil || newStr == nil { - return nil, fmt.Errorf("REPLACE function requires non-null values") - } - - str, err := e.valueToString(value) - if err != nil { - return nil, fmt.Errorf("REPLACE function value conversion error: %v", err) - } - - old, err := e.valueToString(oldStr) - if err != nil { - return nil, fmt.Errorf("REPLACE function old string conversion error: %v", err) - } - - new, err := e.valueToString(newStr) - if err != nil { - return nil, fmt.Errorf("REPLACE function new string conversion error: %v", err) - } - - result := strings.ReplaceAll(str, old, new) - - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: result}, - }, nil -} - -// Position returns the position of a substring in a string (1-based, 0 if not found) -func (e *SQLEngine) Position(substring, value *schema_pb.Value) (*schema_pb.Value, error) { - if substring == nil || value == nil { - return nil, fmt.Errorf("POSITION function requires non-null values") - } - - str, err := e.valueToString(value) - if err != nil { - return nil, fmt.Errorf("POSITION function string conversion error: %v", err) - } - - substr, err := e.valueToString(substring) - if err != nil { - return nil, fmt.Errorf("POSITION function substring conversion error: %v", err) - } - - pos := strings.Index(str, substr) - if pos == -1 { - pos = 0 // SQL returns 0 for not found - } else { - pos = pos + 1 // Convert to 1-based indexing - } - - return &schema_pb.Value{ - Kind: &schema_pb.Value_Int64Value{Int64Value: int64(pos)}, - }, nil -} - -// Left returns the leftmost characters of a string -func (e *SQLEngine) Left(value *schema_pb.Value, length *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil || length == nil { - return nil, fmt.Errorf("LEFT function requires non-null values") - } - - str, err := e.valueToString(value) - if err != nil { - return nil, fmt.Errorf("LEFT function string conversion error: %v", err) - } - - lengthVal, err := e.valueToInt64(length) - if err != nil { - return nil, fmt.Errorf("LEFT function length conversion error: %v", err) - } - - if lengthVal <= 0 { - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: ""}, - }, nil - } - - if int(lengthVal) >= len(str) { - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: str}, - }, nil - } - - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: str[:lengthVal]}, - }, nil -} - -// Right returns the rightmost characters of a string -func (e *SQLEngine) Right(value *schema_pb.Value, length *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil || length == nil { - return nil, fmt.Errorf("RIGHT function requires non-null values") - } - - str, err := e.valueToString(value) - if err != nil { - return nil, fmt.Errorf("RIGHT function string conversion error: %v", err) - } - - lengthVal, err := e.valueToInt64(length) - if err != nil { - return nil, fmt.Errorf("RIGHT function length conversion error: %v", err) - } - - if lengthVal <= 0 { - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: ""}, - }, nil - } - - if int(lengthVal) >= len(str) { - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: str}, - }, nil - } - - startPos := len(str) - int(lengthVal) - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: str[startPos:]}, - }, nil -} - -// Reverse reverses a string -func (e *SQLEngine) Reverse(value *schema_pb.Value) (*schema_pb.Value, error) { - if value == nil { - return nil, fmt.Errorf("REVERSE function requires non-null value") - } - - str, err := e.valueToString(value) - if err != nil { - return nil, fmt.Errorf("REVERSE function conversion error: %v", err) - } - - // Reverse the string rune by rune to handle Unicode correctly - runes := []rune(str) - for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 { - runes[i], runes[j] = runes[j], runes[i] - } - - return &schema_pb.Value{ - Kind: &schema_pb.Value_StringValue{StringValue: string(runes)}, - }, nil -} - -// Helper function to convert schema_pb.Value to string -func (e *SQLEngine) valueToString(value *schema_pb.Value) (string, error) { - switch v := value.Kind.(type) { - case *schema_pb.Value_StringValue: - return v.StringValue, nil - case *schema_pb.Value_Int32Value: - return strconv.FormatInt(int64(v.Int32Value), 10), nil - case *schema_pb.Value_Int64Value: - return strconv.FormatInt(v.Int64Value, 10), nil - case *schema_pb.Value_FloatValue: - return strconv.FormatFloat(float64(v.FloatValue), 'g', -1, 32), nil - case *schema_pb.Value_DoubleValue: - return strconv.FormatFloat(v.DoubleValue, 'g', -1, 64), nil - case *schema_pb.Value_BoolValue: - if v.BoolValue { - return "true", nil - } - return "false", nil - case *schema_pb.Value_BytesValue: - return string(v.BytesValue), nil - default: - return "", fmt.Errorf("cannot convert value type to string") - } -} - -// Helper function to convert schema_pb.Value to int64 -func (e *SQLEngine) valueToInt64(value *schema_pb.Value) (int64, error) { - switch v := value.Kind.(type) { - case *schema_pb.Value_Int32Value: - return int64(v.Int32Value), nil - case *schema_pb.Value_Int64Value: - return v.Int64Value, nil - case *schema_pb.Value_FloatValue: - return int64(v.FloatValue), nil - case *schema_pb.Value_DoubleValue: - return int64(v.DoubleValue), nil - case *schema_pb.Value_StringValue: - if i, err := strconv.ParseInt(v.StringValue, 10, 64); err == nil { - return i, nil - } - return 0, fmt.Errorf("cannot convert string '%s' to integer", v.StringValue) - default: - return 0, fmt.Errorf("cannot convert value type to integer") - } -} diff --git a/weed/query/engine/sql_functions_test.go b/weed/query/engine/sql_functions_test.go deleted file mode 100644 index 30d9b0b31..000000000 --- a/weed/query/engine/sql_functions_test.go +++ /dev/null @@ -1,1205 +0,0 @@ -package engine - -import ( - "testing" - "time" - - "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" -) - -func TestArithmeticOperations(t *testing.T) { - engine := NewTestSQLEngine() - - tests := []struct { - name string - left *schema_pb.Value - right *schema_pb.Value - operator ArithmeticOperator - expected *schema_pb.Value - expectErr bool - }{ - // Addition tests - { - name: "Add two integers", - left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, - right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - operator: OpAdd, - expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 15}}, - expectErr: false, - }, - { - name: "Add integer and float", - left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, - right: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 5.5}}, - operator: OpAdd, - expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 15.5}}, - expectErr: false, - }, - // Subtraction tests - { - name: "Subtract two integers", - left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, - right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 3}}, - operator: OpSub, - expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 7}}, - expectErr: false, - }, - // Multiplication tests - { - name: "Multiply two integers", - left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 6}}, - right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 7}}, - operator: OpMul, - expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 42}}, - expectErr: false, - }, - { - name: "Multiply with float", - left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - right: &schema_pb.Value{Kind: &schema_pb.Value_FloatValue{FloatValue: 2.5}}, - operator: OpMul, - expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 12.5}}, - expectErr: false, - }, - // Division tests - { - name: "Divide two integers", - left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 20}}, - right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 4}}, - operator: OpDiv, - expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 5.0}}, - expectErr: false, - }, - { - name: "Division by zero", - left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, - right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 0}}, - operator: OpDiv, - expected: nil, - expectErr: true, - }, - // Modulo tests - { - name: "Modulo operation", - left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 17}}, - right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - operator: OpMod, - expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 2}}, - expectErr: false, - }, - { - name: "Modulo by zero", - left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}}, - right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 0}}, - operator: OpMod, - expected: nil, - expectErr: true, - }, - // String conversion tests - { - name: "Add string number to integer", - left: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "15"}}, - right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - operator: OpAdd, - expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 20.0}}, - expectErr: false, - }, - { - name: "Invalid string conversion", - left: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "not_a_number"}}, - right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - operator: OpAdd, - expected: nil, - expectErr: true, - }, - // Boolean conversion tests - { - name: "Add boolean to integer", - left: &schema_pb.Value{Kind: &schema_pb.Value_BoolValue{BoolValue: true}}, - right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - operator: OpAdd, - expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 6.0}}, - expectErr: false, - }, - // Null value tests - { - name: "Add with null left operand", - left: nil, - right: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - operator: OpAdd, - expected: nil, - expectErr: true, - }, - { - name: "Add with null right operand", - left: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - right: nil, - operator: OpAdd, - expected: nil, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := engine.EvaluateArithmeticExpression(tt.left, tt.right, tt.operator) - - if tt.expectErr { - if err == nil { - t.Errorf("Expected error but got none") - } - return - } - - if err != nil { - t.Errorf("Unexpected error: %v", err) - return - } - - if !valuesEqual(result, tt.expected) { - t.Errorf("Expected %v, got %v", tt.expected, result) - } - }) - } -} - -func TestIndividualArithmeticFunctions(t *testing.T) { - engine := NewTestSQLEngine() - - left := &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 10}} - right := &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 3}} - - // Test Add function - result, err := engine.Add(left, right) - if err != nil { - t.Errorf("Add function failed: %v", err) - } - expected := &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 13}} - if !valuesEqual(result, expected) { - t.Errorf("Add: Expected %v, got %v", expected, result) - } - - // Test Subtract function - result, err = engine.Subtract(left, right) - if err != nil { - t.Errorf("Subtract function failed: %v", err) - } - expected = &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 7}} - if !valuesEqual(result, expected) { - t.Errorf("Subtract: Expected %v, got %v", expected, result) - } - - // Test Multiply function - result, err = engine.Multiply(left, right) - if err != nil { - t.Errorf("Multiply function failed: %v", err) - } - expected = &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 30}} - if !valuesEqual(result, expected) { - t.Errorf("Multiply: Expected %v, got %v", expected, result) - } - - // Test Divide function - result, err = engine.Divide(left, right) - if err != nil { - t.Errorf("Divide function failed: %v", err) - } - expected = &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 10.0/3.0}} - if !valuesEqual(result, expected) { - t.Errorf("Divide: Expected %v, got %v", expected, result) - } - - // Test Modulo function - result, err = engine.Modulo(left, right) - if err != nil { - t.Errorf("Modulo function failed: %v", err) - } - expected = &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 1}} - if !valuesEqual(result, expected) { - t.Errorf("Modulo: Expected %v, got %v", expected, result) - } -} - -// Helper function to compare two schema_pb.Value objects -func valuesEqual(v1, v2 *schema_pb.Value) bool { - if v1 == nil && v2 == nil { - return true - } - if v1 == nil || v2 == nil { - return false - } - - switch v1Kind := v1.Kind.(type) { - case *schema_pb.Value_Int32Value: - if v2Kind, ok := v2.Kind.(*schema_pb.Value_Int32Value); ok { - return v1Kind.Int32Value == v2Kind.Int32Value - } - case *schema_pb.Value_Int64Value: - if v2Kind, ok := v2.Kind.(*schema_pb.Value_Int64Value); ok { - return v1Kind.Int64Value == v2Kind.Int64Value - } - case *schema_pb.Value_FloatValue: - if v2Kind, ok := v2.Kind.(*schema_pb.Value_FloatValue); ok { - return v1Kind.FloatValue == v2Kind.FloatValue - } - case *schema_pb.Value_DoubleValue: - if v2Kind, ok := v2.Kind.(*schema_pb.Value_DoubleValue); ok { - return v1Kind.DoubleValue == v2Kind.DoubleValue - } - case *schema_pb.Value_StringValue: - if v2Kind, ok := v2.Kind.(*schema_pb.Value_StringValue); ok { - return v1Kind.StringValue == v2Kind.StringValue - } - case *schema_pb.Value_BoolValue: - if v2Kind, ok := v2.Kind.(*schema_pb.Value_BoolValue); ok { - return v1Kind.BoolValue == v2Kind.BoolValue - } - } - - return false -} - -func TestMathematicalFunctions(t *testing.T) { - engine := NewTestSQLEngine() - - t.Run("ROUND function tests", func(t *testing.T) { - tests := []struct { - name string - value *schema_pb.Value - precision *schema_pb.Value - expected *schema_pb.Value - expectErr bool - }{ - { - name: "Round float to integer", - value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.7}}, - precision: nil, - expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 4.0}}, - expectErr: false, - }, - { - name: "Round integer stays integer", - value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - precision: nil, - expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - expectErr: false, - }, - { - name: "Round with precision 2", - value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.14159}}, - precision: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 2}}, - expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.14}}, - expectErr: false, - }, - { - name: "Round negative number", - value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: -3.7}}, - precision: nil, - expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: -4.0}}, - expectErr: false, - }, - { - name: "Round null value", - value: nil, - precision: nil, - expected: nil, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var result *schema_pb.Value - var err error - - if tt.precision != nil { - result, err = engine.Round(tt.value, tt.precision) - } else { - result, err = engine.Round(tt.value) - } - - if tt.expectErr { - if err == nil { - t.Errorf("Expected error but got none") - } - return - } - - if err != nil { - t.Errorf("Unexpected error: %v", err) - return - } - - if !valuesEqual(result, tt.expected) { - t.Errorf("Expected %v, got %v", tt.expected, result) - } - }) - } - }) - - t.Run("CEIL function tests", func(t *testing.T) { - tests := []struct { - name string - value *schema_pb.Value - expected *schema_pb.Value - expectErr bool - }{ - { - name: "Ceil positive decimal", - value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.2}}, - expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 4}}, - expectErr: false, - }, - { - name: "Ceil negative decimal", - value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: -3.2}}, - expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: -3}}, - expectErr: false, - }, - { - name: "Ceil integer", - value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - expectErr: false, - }, - { - name: "Ceil null value", - value: nil, - expected: nil, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := engine.Ceil(tt.value) - - if tt.expectErr { - if err == nil { - t.Errorf("Expected error but got none") - } - return - } - - if err != nil { - t.Errorf("Unexpected error: %v", err) - return - } - - if !valuesEqual(result, tt.expected) { - t.Errorf("Expected %v, got %v", tt.expected, result) - } - }) - } - }) - - t.Run("FLOOR function tests", func(t *testing.T) { - tests := []struct { - name string - value *schema_pb.Value - expected *schema_pb.Value - expectErr bool - }{ - { - name: "Floor positive decimal", - value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.8}}, - expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 3}}, - expectErr: false, - }, - { - name: "Floor negative decimal", - value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: -3.2}}, - expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: -4}}, - expectErr: false, - }, - { - name: "Floor integer", - value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - expectErr: false, - }, - { - name: "Floor null value", - value: nil, - expected: nil, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := engine.Floor(tt.value) - - if tt.expectErr { - if err == nil { - t.Errorf("Expected error but got none") - } - return - } - - if err != nil { - t.Errorf("Unexpected error: %v", err) - return - } - - if !valuesEqual(result, tt.expected) { - t.Errorf("Expected %v, got %v", tt.expected, result) - } - }) - } - }) - - t.Run("ABS function tests", func(t *testing.T) { - tests := []struct { - name string - value *schema_pb.Value - expected *schema_pb.Value - expectErr bool - }{ - { - name: "Abs positive integer", - value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - expectErr: false, - }, - { - name: "Abs negative integer", - value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: -5}}, - expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}, - expectErr: false, - }, - { - name: "Abs positive double", - value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.14}}, - expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.14}}, - expectErr: false, - }, - { - name: "Abs negative double", - value: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: -3.14}}, - expected: &schema_pb.Value{Kind: &schema_pb.Value_DoubleValue{DoubleValue: 3.14}}, - expectErr: false, - }, - { - name: "Abs positive float", - value: &schema_pb.Value{Kind: &schema_pb.Value_FloatValue{FloatValue: 2.5}}, - expected: &schema_pb.Value{Kind: &schema_pb.Value_FloatValue{FloatValue: 2.5}}, - expectErr: false, - }, - { - name: "Abs negative float", - value: &schema_pb.Value{Kind: &schema_pb.Value_FloatValue{FloatValue: -2.5}}, - expected: &schema_pb.Value{Kind: &schema_pb.Value_FloatValue{FloatValue: 2.5}}, - expectErr: false, - }, - { - name: "Abs zero", - value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 0}}, - expected: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 0}}, - expectErr: false, - }, - { - name: "Abs null value", - value: nil, - expected: nil, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := engine.Abs(tt.value) - - if tt.expectErr { - if err == nil { - t.Errorf("Expected error but got none") - } - return - } - - if err != nil { - t.Errorf("Unexpected error: %v", err) - return - } - - if !valuesEqual(result, tt.expected) { - t.Errorf("Expected %v, got %v", tt.expected, result) - } - }) - } - }) -} - -func TestDateTimeFunctions(t *testing.T) { - engine := NewTestSQLEngine() - - t.Run("CURRENT_DATE function tests", func(t *testing.T) { - result, err := engine.CurrentDate() - if err != nil { - t.Errorf("CurrentDate failed: %v", err) - } - - if result == nil { - t.Errorf("CurrentDate returned nil result") - return - } - - stringVal, ok := result.Kind.(*schema_pb.Value_StringValue) - if !ok { - t.Errorf("CurrentDate should return string value, got %T", result.Kind) - return - } - - // Check format (YYYY-MM-DD) - today := time.Now().Format("2006-01-02") - if stringVal.StringValue != today { - t.Errorf("Expected current date %s, got %s", today, stringVal.StringValue) - } - }) - - t.Run("CURRENT_TIMESTAMP function tests", func(t *testing.T) { - before := time.Now() - result, err := engine.CurrentTimestamp() - after := time.Now() - - if err != nil { - t.Errorf("CurrentTimestamp failed: %v", err) - } - - if result == nil { - t.Errorf("CurrentTimestamp returned nil result") - return - } - - timestampVal, ok := result.Kind.(*schema_pb.Value_TimestampValue) - if !ok { - t.Errorf("CurrentTimestamp should return timestamp value, got %T", result.Kind) - return - } - - timestamp := time.UnixMicro(timestampVal.TimestampValue.TimestampMicros) - - // Check that timestamp is within reasonable range - if timestamp.Before(before) || timestamp.After(after) { - t.Errorf("Timestamp %v should be between %v and %v", timestamp, before, after) - } - }) - - t.Run("NOW function tests", func(t *testing.T) { - result, err := engine.Now() - if err != nil { - t.Errorf("Now failed: %v", err) - } - - if result == nil { - t.Errorf("Now returned nil result") - return - } - - // Should return same type as CurrentTimestamp - _, ok := result.Kind.(*schema_pb.Value_TimestampValue) - if !ok { - t.Errorf("Now should return timestamp value, got %T", result.Kind) - } - }) - - t.Run("CURRENT_TIME function tests", func(t *testing.T) { - result, err := engine.CurrentTime() - if err != nil { - t.Errorf("CurrentTime failed: %v", err) - } - - if result == nil { - t.Errorf("CurrentTime returned nil result") - return - } - - stringVal, ok := result.Kind.(*schema_pb.Value_StringValue) - if !ok { - t.Errorf("CurrentTime should return string value, got %T", result.Kind) - return - } - - // Check format (HH:MM:SS) - if len(stringVal.StringValue) != 8 || stringVal.StringValue[2] != ':' || stringVal.StringValue[5] != ':' { - t.Errorf("CurrentTime should return HH:MM:SS format, got %s", stringVal.StringValue) - } - }) -} - -func TestExtractFunction(t *testing.T) { - engine := NewTestSQLEngine() - - // Create a test timestamp: 2023-06-15 14:30:45 - // Use local time to avoid timezone conversion issues - testTime := time.Date(2023, 6, 15, 14, 30, 45, 0, time.Local) - testTimestamp := &schema_pb.Value{ - Kind: &schema_pb.Value_TimestampValue{ - TimestampValue: &schema_pb.TimestampValue{ - TimestampMicros: testTime.UnixMicro(), - }, - }, - } - - tests := []struct { - name string - part DatePart - value *schema_pb.Value - expected int64 - expectErr bool - }{ - { - name: "Extract YEAR", - part: PartYear, - value: testTimestamp, - expected: 2023, - expectErr: false, - }, - { - name: "Extract MONTH", - part: PartMonth, - value: testTimestamp, - expected: 6, - expectErr: false, - }, - { - name: "Extract DAY", - part: PartDay, - value: testTimestamp, - expected: 15, - expectErr: false, - }, - { - name: "Extract HOUR", - part: PartHour, - value: testTimestamp, - expected: 14, - expectErr: false, - }, - { - name: "Extract MINUTE", - part: PartMinute, - value: testTimestamp, - expected: 30, - expectErr: false, - }, - { - name: "Extract SECOND", - part: PartSecond, - value: testTimestamp, - expected: 45, - expectErr: false, - }, - { - name: "Extract QUARTER from June", - part: PartQuarter, - value: testTimestamp, - expected: 2, // June is in Q2 - expectErr: false, - }, - { - name: "Extract from string date", - part: PartYear, - value: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "2023-06-15"}}, - expected: 2023, - expectErr: false, - }, - { - name: "Extract from Unix timestamp", - part: PartYear, - value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: testTime.Unix()}}, - expected: 2023, - expectErr: false, - }, - { - name: "Extract from null value", - part: PartYear, - value: nil, - expected: 0, - expectErr: true, - }, - { - name: "Extract invalid part", - part: DatePart("INVALID"), - value: testTimestamp, - expected: 0, - expectErr: true, - }, - { - name: "Extract from invalid string", - part: PartYear, - value: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "invalid-date"}}, - expected: 0, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := engine.Extract(tt.part, tt.value) - - if tt.expectErr { - if err == nil { - t.Errorf("Expected error but got none") - } - return - } - - if err != nil { - t.Errorf("Unexpected error: %v", err) - return - } - - if result == nil { - t.Errorf("Extract returned nil result") - return - } - - intVal, ok := result.Kind.(*schema_pb.Value_Int64Value) - if !ok { - t.Errorf("Extract should return int64 value, got %T", result.Kind) - return - } - - if intVal.Int64Value != tt.expected { - t.Errorf("Expected %d, got %d", tt.expected, intVal.Int64Value) - } - }) - } -} - -func TestDateTruncFunction(t *testing.T) { - engine := NewTestSQLEngine() - - // Create a test timestamp: 2023-06-15 14:30:45.123456 - testTime := time.Date(2023, 6, 15, 14, 30, 45, 123456000, time.Local) // nanoseconds - testTimestamp := &schema_pb.Value{ - Kind: &schema_pb.Value_TimestampValue{ - TimestampValue: &schema_pb.TimestampValue{ - TimestampMicros: testTime.UnixMicro(), - }, - }, - } - - tests := []struct { - name string - precision string - value *schema_pb.Value - expectErr bool - expectedCheck func(result time.Time) bool // Custom check function - }{ - { - name: "Truncate to second", - precision: "second", - value: testTimestamp, - expectErr: false, - expectedCheck: func(result time.Time) bool { - return result.Year() == 2023 && result.Month() == 6 && result.Day() == 15 && - result.Hour() == 14 && result.Minute() == 30 && result.Second() == 45 && - result.Nanosecond() == 0 - }, - }, - { - name: "Truncate to minute", - precision: "minute", - value: testTimestamp, - expectErr: false, - expectedCheck: func(result time.Time) bool { - return result.Year() == 2023 && result.Month() == 6 && result.Day() == 15 && - result.Hour() == 14 && result.Minute() == 30 && result.Second() == 0 && - result.Nanosecond() == 0 - }, - }, - { - name: "Truncate to hour", - precision: "hour", - value: testTimestamp, - expectErr: false, - expectedCheck: func(result time.Time) bool { - return result.Year() == 2023 && result.Month() == 6 && result.Day() == 15 && - result.Hour() == 14 && result.Minute() == 0 && result.Second() == 0 && - result.Nanosecond() == 0 - }, - }, - { - name: "Truncate to day", - precision: "day", - value: testTimestamp, - expectErr: false, - expectedCheck: func(result time.Time) bool { - return result.Year() == 2023 && result.Month() == 6 && result.Day() == 15 && - result.Hour() == 0 && result.Minute() == 0 && result.Second() == 0 && - result.Nanosecond() == 0 - }, - }, - { - name: "Truncate to month", - precision: "month", - value: testTimestamp, - expectErr: false, - expectedCheck: func(result time.Time) bool { - return result.Year() == 2023 && result.Month() == 6 && result.Day() == 1 && - result.Hour() == 0 && result.Minute() == 0 && result.Second() == 0 && - result.Nanosecond() == 0 - }, - }, - { - name: "Truncate to quarter", - precision: "quarter", - value: testTimestamp, - expectErr: false, - expectedCheck: func(result time.Time) bool { - // June (month 6) should truncate to April (month 4) - start of Q2 - return result.Year() == 2023 && result.Month() == 4 && result.Day() == 1 && - result.Hour() == 0 && result.Minute() == 0 && result.Second() == 0 && - result.Nanosecond() == 0 - }, - }, - { - name: "Truncate to year", - precision: "year", - value: testTimestamp, - expectErr: false, - expectedCheck: func(result time.Time) bool { - return result.Year() == 2023 && result.Month() == 1 && result.Day() == 1 && - result.Hour() == 0 && result.Minute() == 0 && result.Second() == 0 && - result.Nanosecond() == 0 - }, - }, - { - name: "Truncate with plural precision", - precision: "minutes", // Test plural form - value: testTimestamp, - expectErr: false, - expectedCheck: func(result time.Time) bool { - return result.Year() == 2023 && result.Month() == 6 && result.Day() == 15 && - result.Hour() == 14 && result.Minute() == 30 && result.Second() == 0 && - result.Nanosecond() == 0 - }, - }, - { - name: "Truncate from string date", - precision: "day", - value: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "2023-06-15 14:30:45"}}, - expectErr: false, - expectedCheck: func(result time.Time) bool { - // The result should be the start of day 2023-06-15 in local timezone - expectedDay := time.Date(2023, 6, 15, 0, 0, 0, 0, result.Location()) - return result.Equal(expectedDay) - }, - }, - { - name: "Truncate null value", - precision: "day", - value: nil, - expectErr: true, - expectedCheck: nil, - }, - { - name: "Invalid precision", - precision: "invalid", - value: testTimestamp, - expectErr: true, - expectedCheck: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := engine.DateTrunc(tt.precision, tt.value) - - if tt.expectErr { - if err == nil { - t.Errorf("Expected error but got none") - } - return - } - - if err != nil { - t.Errorf("Unexpected error: %v", err) - return - } - - if result == nil { - t.Errorf("DateTrunc returned nil result") - return - } - - timestampVal, ok := result.Kind.(*schema_pb.Value_TimestampValue) - if !ok { - t.Errorf("DateTrunc should return timestamp value, got %T", result.Kind) - return - } - - resultTime := time.UnixMicro(timestampVal.TimestampValue.TimestampMicros) - - if !tt.expectedCheck(resultTime) { - t.Errorf("DateTrunc result check failed for precision %s, got time: %v", tt.precision, resultTime) - } - }) - } -} - -func TestStringFunctions(t *testing.T) { - engine := NewTestSQLEngine() - - t.Run("LENGTH function tests", func(t *testing.T) { - tests := []struct { - name string - value *schema_pb.Value - expected int64 - expectErr bool - }{ - { - name: "Length of string", - value: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World"}}, - expected: 11, - expectErr: false, - }, - { - name: "Length of empty string", - value: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: ""}}, - expected: 0, - expectErr: false, - }, - { - name: "Length of number", - value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 12345}}, - expected: 5, - expectErr: false, - }, - { - name: "Length of null value", - value: nil, - expected: 0, - expectErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := engine.Length(tt.value) - - if tt.expectErr { - if err == nil { - t.Errorf("Expected error but got none") - } - return - } - - if err != nil { - t.Errorf("Unexpected error: %v", err) - return - } - - intVal, ok := result.Kind.(*schema_pb.Value_Int64Value) - if !ok { - t.Errorf("LENGTH should return int64 value, got %T", result.Kind) - return - } - - if intVal.Int64Value != tt.expected { - t.Errorf("Expected %d, got %d", tt.expected, intVal.Int64Value) - } - }) - } - }) - - t.Run("UPPER/LOWER function tests", func(t *testing.T) { - // Test UPPER - result, err := engine.Upper(&schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World"}}) - if err != nil { - t.Errorf("UPPER failed: %v", err) - } - stringVal, _ := result.Kind.(*schema_pb.Value_StringValue) - if stringVal.StringValue != "HELLO WORLD" { - t.Errorf("Expected 'HELLO WORLD', got '%s'", stringVal.StringValue) - } - - // Test LOWER - result, err = engine.Lower(&schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World"}}) - if err != nil { - t.Errorf("LOWER failed: %v", err) - } - stringVal, _ = result.Kind.(*schema_pb.Value_StringValue) - if stringVal.StringValue != "hello world" { - t.Errorf("Expected 'hello world', got '%s'", stringVal.StringValue) - } - }) - - t.Run("TRIM function tests", func(t *testing.T) { - tests := []struct { - name string - function func(*schema_pb.Value) (*schema_pb.Value, error) - input string - expected string - }{ - {"TRIM whitespace", engine.Trim, " Hello World ", "Hello World"}, - {"LTRIM whitespace", engine.LTrim, " Hello World ", "Hello World "}, - {"RTRIM whitespace", engine.RTrim, " Hello World ", " Hello World"}, - {"TRIM with tabs and newlines", engine.Trim, "\t\nHello\t\n", "Hello"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := tt.function(&schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: tt.input}}) - if err != nil { - t.Errorf("Function failed: %v", err) - return - } - - stringVal, ok := result.Kind.(*schema_pb.Value_StringValue) - if !ok { - t.Errorf("Function should return string value, got %T", result.Kind) - return - } - - if stringVal.StringValue != tt.expected { - t.Errorf("Expected '%s', got '%s'", tt.expected, stringVal.StringValue) - } - }) - } - }) - - t.Run("SUBSTRING function tests", func(t *testing.T) { - testStr := &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World"}} - - // Test substring with start and length - result, err := engine.Substring(testStr, - &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 7}}, - &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}) - if err != nil { - t.Errorf("SUBSTRING failed: %v", err) - } - stringVal, _ := result.Kind.(*schema_pb.Value_StringValue) - if stringVal.StringValue != "World" { - t.Errorf("Expected 'World', got '%s'", stringVal.StringValue) - } - - // Test substring with just start position - result, err = engine.Substring(testStr, - &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 7}}) - if err != nil { - t.Errorf("SUBSTRING failed: %v", err) - } - stringVal, _ = result.Kind.(*schema_pb.Value_StringValue) - if stringVal.StringValue != "World" { - t.Errorf("Expected 'World', got '%s'", stringVal.StringValue) - } - }) - - t.Run("CONCAT function tests", func(t *testing.T) { - result, err := engine.Concat( - &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello"}}, - &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: " "}}, - &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "World"}}, - ) - if err != nil { - t.Errorf("CONCAT failed: %v", err) - } - stringVal, _ := result.Kind.(*schema_pb.Value_StringValue) - if stringVal.StringValue != "Hello World" { - t.Errorf("Expected 'Hello World', got '%s'", stringVal.StringValue) - } - - // Test with mixed types - result, err = engine.Concat( - &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Number: "}}, - &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 42}}, - ) - if err != nil { - t.Errorf("CONCAT failed: %v", err) - } - stringVal, _ = result.Kind.(*schema_pb.Value_StringValue) - if stringVal.StringValue != "Number: 42" { - t.Errorf("Expected 'Number: 42', got '%s'", stringVal.StringValue) - } - }) - - t.Run("REPLACE function tests", func(t *testing.T) { - result, err := engine.Replace( - &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World World"}}, - &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "World"}}, - &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Universe"}}, - ) - if err != nil { - t.Errorf("REPLACE failed: %v", err) - } - stringVal, _ := result.Kind.(*schema_pb.Value_StringValue) - if stringVal.StringValue != "Hello Universe Universe" { - t.Errorf("Expected 'Hello Universe Universe', got '%s'", stringVal.StringValue) - } - }) - - t.Run("POSITION function tests", func(t *testing.T) { - result, err := engine.Position( - &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "World"}}, - &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World"}}, - ) - if err != nil { - t.Errorf("POSITION failed: %v", err) - } - intVal, _ := result.Kind.(*schema_pb.Value_Int64Value) - if intVal.Int64Value != 7 { - t.Errorf("Expected 7, got %d", intVal.Int64Value) - } - - // Test not found - result, err = engine.Position( - &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "NotFound"}}, - &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World"}}, - ) - if err != nil { - t.Errorf("POSITION failed: %v", err) - } - intVal, _ = result.Kind.(*schema_pb.Value_Int64Value) - if intVal.Int64Value != 0 { - t.Errorf("Expected 0 for not found, got %d", intVal.Int64Value) - } - }) - - t.Run("LEFT/RIGHT function tests", func(t *testing.T) { - testStr := &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World"}} - - // Test LEFT - result, err := engine.Left(testStr, &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}) - if err != nil { - t.Errorf("LEFT failed: %v", err) - } - stringVal, _ := result.Kind.(*schema_pb.Value_StringValue) - if stringVal.StringValue != "Hello" { - t.Errorf("Expected 'Hello', got '%s'", stringVal.StringValue) - } - - // Test RIGHT - result, err = engine.Right(testStr, &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}) - if err != nil { - t.Errorf("RIGHT failed: %v", err) - } - stringVal, _ = result.Kind.(*schema_pb.Value_StringValue) - if stringVal.StringValue != "World" { - t.Errorf("Expected 'World', got '%s'", stringVal.StringValue) - } - }) - - t.Run("REVERSE function tests", func(t *testing.T) { - result, err := engine.Reverse(&schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello"}}) - if err != nil { - t.Errorf("REVERSE failed: %v", err) - } - stringVal, _ := result.Kind.(*schema_pb.Value_StringValue) - if stringVal.StringValue != "olleH" { - t.Errorf("Expected 'olleH', got '%s'", stringVal.StringValue) - } - - // Test with Unicode - result, err = engine.Reverse(&schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "🙂👍"}}) - if err != nil { - t.Errorf("REVERSE failed: %v", err) - } - stringVal, _ = result.Kind.(*schema_pb.Value_StringValue) - if stringVal.StringValue != "👍🙂" { - t.Errorf("Expected '👍🙂', got '%s'", stringVal.StringValue) - } - }) -} diff --git a/weed/query/engine/string_functions.go b/weed/query/engine/string_functions.go new file mode 100644 index 000000000..26acd8f4e --- /dev/null +++ b/weed/query/engine/string_functions.go @@ -0,0 +1,333 @@ +package engine + +import ( + "fmt" + "strings" + + "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" +) + +// =============================== +// STRING FUNCTIONS +// =============================== + +// Length returns the length of a string +func (e *SQLEngine) Length(value *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil { + return nil, fmt.Errorf("LENGTH function requires non-null value") + } + + str, err := e.valueToString(value) + if err != nil { + return nil, fmt.Errorf("LENGTH function conversion error: %v", err) + } + + length := int64(len(str)) + return &schema_pb.Value{ + Kind: &schema_pb.Value_Int64Value{Int64Value: length}, + }, nil +} + +// Upper converts a string to uppercase +func (e *SQLEngine) Upper(value *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil { + return nil, fmt.Errorf("UPPER function requires non-null value") + } + + str, err := e.valueToString(value) + if err != nil { + return nil, fmt.Errorf("UPPER function conversion error: %v", err) + } + + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: strings.ToUpper(str)}, + }, nil +} + +// Lower converts a string to lowercase +func (e *SQLEngine) Lower(value *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil { + return nil, fmt.Errorf("LOWER function requires non-null value") + } + + str, err := e.valueToString(value) + if err != nil { + return nil, fmt.Errorf("LOWER function conversion error: %v", err) + } + + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: strings.ToLower(str)}, + }, nil +} + +// Trim removes leading and trailing whitespace from a string +func (e *SQLEngine) Trim(value *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil { + return nil, fmt.Errorf("TRIM function requires non-null value") + } + + str, err := e.valueToString(value) + if err != nil { + return nil, fmt.Errorf("TRIM function conversion error: %v", err) + } + + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: strings.TrimSpace(str)}, + }, nil +} + +// LTrim removes leading whitespace from a string +func (e *SQLEngine) LTrim(value *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil { + return nil, fmt.Errorf("LTRIM function requires non-null value") + } + + str, err := e.valueToString(value) + if err != nil { + return nil, fmt.Errorf("LTRIM function conversion error: %v", err) + } + + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: strings.TrimLeft(str, " \t\n\r")}, + }, nil +} + +// RTrim removes trailing whitespace from a string +func (e *SQLEngine) RTrim(value *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil { + return nil, fmt.Errorf("RTRIM function requires non-null value") + } + + str, err := e.valueToString(value) + if err != nil { + return nil, fmt.Errorf("RTRIM function conversion error: %v", err) + } + + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: strings.TrimRight(str, " \t\n\r")}, + }, nil +} + +// Substring extracts a substring from a string +func (e *SQLEngine) Substring(value *schema_pb.Value, start *schema_pb.Value, length ...*schema_pb.Value) (*schema_pb.Value, error) { + if value == nil || start == nil { + return nil, fmt.Errorf("SUBSTRING function requires non-null value and start position") + } + + str, err := e.valueToString(value) + if err != nil { + return nil, fmt.Errorf("SUBSTRING function value conversion error: %v", err) + } + + startPos, err := e.valueToInt64(start) + if err != nil { + return nil, fmt.Errorf("SUBSTRING function start position conversion error: %v", err) + } + + // Convert to 0-based indexing (SQL uses 1-based) + if startPos < 1 { + startPos = 1 + } + startIdx := int(startPos - 1) + + if startIdx >= len(str) { + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: ""}, + }, nil + } + + var result string + if len(length) > 0 && length[0] != nil { + lengthVal, err := e.valueToInt64(length[0]) + if err != nil { + return nil, fmt.Errorf("SUBSTRING function length conversion error: %v", err) + } + + if lengthVal <= 0 { + result = "" + } else { + endIdx := startIdx + int(lengthVal) + if endIdx > len(str) { + endIdx = len(str) + } + result = str[startIdx:endIdx] + } + } else { + result = str[startIdx:] + } + + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: result}, + }, nil +} + +// Concat concatenates multiple strings +func (e *SQLEngine) Concat(values ...*schema_pb.Value) (*schema_pb.Value, error) { + if len(values) == 0 { + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: ""}, + }, nil + } + + var result strings.Builder + for i, value := range values { + if value == nil { + continue // Skip null values + } + + str, err := e.valueToString(value) + if err != nil { + return nil, fmt.Errorf("CONCAT function value %d conversion error: %v", i, err) + } + result.WriteString(str) + } + + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: result.String()}, + }, nil +} + +// Replace replaces all occurrences of a substring with another substring +func (e *SQLEngine) Replace(value, oldStr, newStr *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil || oldStr == nil || newStr == nil { + return nil, fmt.Errorf("REPLACE function requires non-null values") + } + + str, err := e.valueToString(value) + if err != nil { + return nil, fmt.Errorf("REPLACE function value conversion error: %v", err) + } + + old, err := e.valueToString(oldStr) + if err != nil { + return nil, fmt.Errorf("REPLACE function old string conversion error: %v", err) + } + + new, err := e.valueToString(newStr) + if err != nil { + return nil, fmt.Errorf("REPLACE function new string conversion error: %v", err) + } + + result := strings.ReplaceAll(str, old, new) + + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: result}, + }, nil +} + +// Position returns the position of a substring in a string (1-based, 0 if not found) +func (e *SQLEngine) Position(substring, value *schema_pb.Value) (*schema_pb.Value, error) { + if substring == nil || value == nil { + return nil, fmt.Errorf("POSITION function requires non-null values") + } + + str, err := e.valueToString(value) + if err != nil { + return nil, fmt.Errorf("POSITION function string conversion error: %v", err) + } + + substr, err := e.valueToString(substring) + if err != nil { + return nil, fmt.Errorf("POSITION function substring conversion error: %v", err) + } + + pos := strings.Index(str, substr) + if pos == -1 { + pos = 0 // SQL returns 0 for not found + } else { + pos = pos + 1 // Convert to 1-based indexing + } + + return &schema_pb.Value{ + Kind: &schema_pb.Value_Int64Value{Int64Value: int64(pos)}, + }, nil +} + +// Left returns the leftmost characters of a string +func (e *SQLEngine) Left(value *schema_pb.Value, length *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil || length == nil { + return nil, fmt.Errorf("LEFT function requires non-null values") + } + + str, err := e.valueToString(value) + if err != nil { + return nil, fmt.Errorf("LEFT function string conversion error: %v", err) + } + + lengthVal, err := e.valueToInt64(length) + if err != nil { + return nil, fmt.Errorf("LEFT function length conversion error: %v", err) + } + + if lengthVal <= 0 { + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: ""}, + }, nil + } + + if int(lengthVal) >= len(str) { + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: str}, + }, nil + } + + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: str[:lengthVal]}, + }, nil +} + +// Right returns the rightmost characters of a string +func (e *SQLEngine) Right(value *schema_pb.Value, length *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil || length == nil { + return nil, fmt.Errorf("RIGHT function requires non-null values") + } + + str, err := e.valueToString(value) + if err != nil { + return nil, fmt.Errorf("RIGHT function string conversion error: %v", err) + } + + lengthVal, err := e.valueToInt64(length) + if err != nil { + return nil, fmt.Errorf("RIGHT function length conversion error: %v", err) + } + + if lengthVal <= 0 { + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: ""}, + }, nil + } + + if int(lengthVal) >= len(str) { + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: str}, + }, nil + } + + startPos := len(str) - int(lengthVal) + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: str[startPos:]}, + }, nil +} + +// Reverse reverses a string +func (e *SQLEngine) Reverse(value *schema_pb.Value) (*schema_pb.Value, error) { + if value == nil { + return nil, fmt.Errorf("REVERSE function requires non-null value") + } + + str, err := e.valueToString(value) + if err != nil { + return nil, fmt.Errorf("REVERSE function conversion error: %v", err) + } + + // Reverse the string rune by rune to handle Unicode correctly + runes := []rune(str) + for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 { + runes[i], runes[j] = runes[j], runes[i] + } + + return &schema_pb.Value{ + Kind: &schema_pb.Value_StringValue{StringValue: string(runes)}, + }, nil +} diff --git a/weed/query/engine/string_functions_test.go b/weed/query/engine/string_functions_test.go new file mode 100644 index 000000000..91b5e269a --- /dev/null +++ b/weed/query/engine/string_functions_test.go @@ -0,0 +1,271 @@ +package engine + +import ( + "testing" + + "github.com/seaweedfs/seaweedfs/weed/pb/schema_pb" +) + +func TestStringFunctions(t *testing.T) { + engine := NewTestSQLEngine() + + t.Run("LENGTH function tests", func(t *testing.T) { + tests := []struct { + name string + value *schema_pb.Value + expected int64 + expectErr bool + }{ + { + name: "Length of string", + value: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World"}}, + expected: 11, + expectErr: false, + }, + { + name: "Length of empty string", + value: &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: ""}}, + expected: 0, + expectErr: false, + }, + { + name: "Length of number", + value: &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 12345}}, + expected: 5, + expectErr: false, + }, + { + name: "Length of null value", + value: nil, + expected: 0, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := engine.Length(tt.value) + + if tt.expectErr { + if err == nil { + t.Errorf("Expected error but got none") + } + return + } + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + intVal, ok := result.Kind.(*schema_pb.Value_Int64Value) + if !ok { + t.Errorf("LENGTH should return int64 value, got %T", result.Kind) + return + } + + if intVal.Int64Value != tt.expected { + t.Errorf("Expected %d, got %d", tt.expected, intVal.Int64Value) + } + }) + } + }) + + t.Run("UPPER/LOWER function tests", func(t *testing.T) { + // Test UPPER + result, err := engine.Upper(&schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World"}}) + if err != nil { + t.Errorf("UPPER failed: %v", err) + } + stringVal, _ := result.Kind.(*schema_pb.Value_StringValue) + if stringVal.StringValue != "HELLO WORLD" { + t.Errorf("Expected 'HELLO WORLD', got '%s'", stringVal.StringValue) + } + + // Test LOWER + result, err = engine.Lower(&schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World"}}) + if err != nil { + t.Errorf("LOWER failed: %v", err) + } + stringVal, _ = result.Kind.(*schema_pb.Value_StringValue) + if stringVal.StringValue != "hello world" { + t.Errorf("Expected 'hello world', got '%s'", stringVal.StringValue) + } + }) + + t.Run("TRIM function tests", func(t *testing.T) { + tests := []struct { + name string + function func(*schema_pb.Value) (*schema_pb.Value, error) + input string + expected string + }{ + {"TRIM whitespace", engine.Trim, " Hello World ", "Hello World"}, + {"LTRIM whitespace", engine.LTrim, " Hello World ", "Hello World "}, + {"RTRIM whitespace", engine.RTrim, " Hello World ", " Hello World"}, + {"TRIM with tabs and newlines", engine.Trim, "\t\nHello\t\n", "Hello"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tt.function(&schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: tt.input}}) + if err != nil { + t.Errorf("Function failed: %v", err) + return + } + + stringVal, ok := result.Kind.(*schema_pb.Value_StringValue) + if !ok { + t.Errorf("Function should return string value, got %T", result.Kind) + return + } + + if stringVal.StringValue != tt.expected { + t.Errorf("Expected '%s', got '%s'", tt.expected, stringVal.StringValue) + } + }) + } + }) + + t.Run("SUBSTRING function tests", func(t *testing.T) { + testStr := &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World"}} + + // Test substring with start and length + result, err := engine.Substring(testStr, + &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 7}}, + &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}) + if err != nil { + t.Errorf("SUBSTRING failed: %v", err) + } + stringVal, _ := result.Kind.(*schema_pb.Value_StringValue) + if stringVal.StringValue != "World" { + t.Errorf("Expected 'World', got '%s'", stringVal.StringValue) + } + + // Test substring with just start position + result, err = engine.Substring(testStr, + &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 7}}) + if err != nil { + t.Errorf("SUBSTRING failed: %v", err) + } + stringVal, _ = result.Kind.(*schema_pb.Value_StringValue) + if stringVal.StringValue != "World" { + t.Errorf("Expected 'World', got '%s'", stringVal.StringValue) + } + }) + + t.Run("CONCAT function tests", func(t *testing.T) { + result, err := engine.Concat( + &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello"}}, + &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: " "}}, + &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "World"}}, + ) + if err != nil { + t.Errorf("CONCAT failed: %v", err) + } + stringVal, _ := result.Kind.(*schema_pb.Value_StringValue) + if stringVal.StringValue != "Hello World" { + t.Errorf("Expected 'Hello World', got '%s'", stringVal.StringValue) + } + + // Test with mixed types + result, err = engine.Concat( + &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Number: "}}, + &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 42}}, + ) + if err != nil { + t.Errorf("CONCAT failed: %v", err) + } + stringVal, _ = result.Kind.(*schema_pb.Value_StringValue) + if stringVal.StringValue != "Number: 42" { + t.Errorf("Expected 'Number: 42', got '%s'", stringVal.StringValue) + } + }) + + t.Run("REPLACE function tests", func(t *testing.T) { + result, err := engine.Replace( + &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World World"}}, + &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "World"}}, + &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Universe"}}, + ) + if err != nil { + t.Errorf("REPLACE failed: %v", err) + } + stringVal, _ := result.Kind.(*schema_pb.Value_StringValue) + if stringVal.StringValue != "Hello Universe Universe" { + t.Errorf("Expected 'Hello Universe Universe', got '%s'", stringVal.StringValue) + } + }) + + t.Run("POSITION function tests", func(t *testing.T) { + result, err := engine.Position( + &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "World"}}, + &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World"}}, + ) + if err != nil { + t.Errorf("POSITION failed: %v", err) + } + intVal, _ := result.Kind.(*schema_pb.Value_Int64Value) + if intVal.Int64Value != 7 { + t.Errorf("Expected 7, got %d", intVal.Int64Value) + } + + // Test not found + result, err = engine.Position( + &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "NotFound"}}, + &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World"}}, + ) + if err != nil { + t.Errorf("POSITION failed: %v", err) + } + intVal, _ = result.Kind.(*schema_pb.Value_Int64Value) + if intVal.Int64Value != 0 { + t.Errorf("Expected 0 for not found, got %d", intVal.Int64Value) + } + }) + + t.Run("LEFT/RIGHT function tests", func(t *testing.T) { + testStr := &schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello World"}} + + // Test LEFT + result, err := engine.Left(testStr, &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}) + if err != nil { + t.Errorf("LEFT failed: %v", err) + } + stringVal, _ := result.Kind.(*schema_pb.Value_StringValue) + if stringVal.StringValue != "Hello" { + t.Errorf("Expected 'Hello', got '%s'", stringVal.StringValue) + } + + // Test RIGHT + result, err = engine.Right(testStr, &schema_pb.Value{Kind: &schema_pb.Value_Int64Value{Int64Value: 5}}) + if err != nil { + t.Errorf("RIGHT failed: %v", err) + } + stringVal, _ = result.Kind.(*schema_pb.Value_StringValue) + if stringVal.StringValue != "World" { + t.Errorf("Expected 'World', got '%s'", stringVal.StringValue) + } + }) + + t.Run("REVERSE function tests", func(t *testing.T) { + result, err := engine.Reverse(&schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "Hello"}}) + if err != nil { + t.Errorf("REVERSE failed: %v", err) + } + stringVal, _ := result.Kind.(*schema_pb.Value_StringValue) + if stringVal.StringValue != "olleH" { + t.Errorf("Expected 'olleH', got '%s'", stringVal.StringValue) + } + + // Test with Unicode + result, err = engine.Reverse(&schema_pb.Value{Kind: &schema_pb.Value_StringValue{StringValue: "🙂👍"}}) + if err != nil { + t.Errorf("REVERSE failed: %v", err) + } + stringVal, _ = result.Kind.(*schema_pb.Value_StringValue) + if stringVal.StringValue != "👍🙂" { + t.Errorf("Expected '👍🙂', got '%s'", stringVal.StringValue) + } + }) +}