column name can be on left or right in where conditions

This commit is contained in:
chrislu
2025-09-02 17:05:31 -07:00
parent 900bd94456
commit ed7102df6e

View File

@@ -1210,18 +1210,56 @@ func (e *SQLEngine) buildPredicate(expr sqlparser.Expr) (func(*schema_pb.RecordV
} }
// buildComparisonPredicate creates a predicate for comparison operations (=, <, >, etc.) // buildComparisonPredicate creates a predicate for comparison operations (=, <, >, etc.)
// Handles column names on both left and right sides of the comparison
func (e *SQLEngine) buildComparisonPredicate(expr *sqlparser.ComparisonExpr) (func(*schema_pb.RecordValue) bool, error) { func (e *SQLEngine) buildComparisonPredicate(expr *sqlparser.ComparisonExpr) (func(*schema_pb.RecordValue) bool, error) {
// Extract column name (left side) var columnName string
colName, ok := expr.Left.(*sqlparser.ColName) var compareValue interface{}
if !ok { var operator string
return nil, fmt.Errorf("unsupported comparison left side: %T", expr.Left)
// Check if column is on the left side (normal case: column > value)
if colName, ok := expr.Left.(*sqlparser.ColName); ok {
columnName = colName.Name.String()
operator = expr.Operator
// Extract comparison value from right side
val, err := e.extractComparisonValue(expr.Right)
if err != nil {
return nil, fmt.Errorf("failed to extract right-side value: %v", err)
}
compareValue = val
} else if colName, ok := expr.Right.(*sqlparser.ColName); ok {
// Column is on the right side (reversed case: value < column)
columnName = colName.Name.String()
// Reverse the operator when column is on right side
operator = e.reverseOperator(expr.Operator)
// Extract comparison value from left side
val, err := e.extractComparisonValue(expr.Left)
if err != nil {
return nil, fmt.Errorf("failed to extract left-side value: %v", err)
}
compareValue = val
} else {
return nil, fmt.Errorf("no column name found in comparison expression, left: %T, right: %T", expr.Left, expr.Right)
} }
columnName := colName.Name.String() // Create predicate based on operator
return func(record *schema_pb.RecordValue) bool {
fieldValue, exists := record.Fields[columnName]
if !exists {
return false
}
// Extract comparison value (right side) return e.evaluateComparison(fieldValue, operator, compareValue)
var compareValue interface{} }, nil
switch val := expr.Right.(type) { }
// extractComparisonValue extracts the comparison value from a SQL expression
func (e *SQLEngine) extractComparisonValue(expr sqlparser.Expr) (interface{}, error) {
switch val := expr.(type) {
case *sqlparser.SQLVal: case *sqlparser.SQLVal:
switch val.Type { switch val.Type {
case sqlparser.IntVal: case sqlparser.IntVal:
@@ -1229,9 +1267,9 @@ func (e *SQLEngine) buildComparisonPredicate(expr *sqlparser.ComparisonExpr) (fu
if err != nil { if err != nil {
return nil, err return nil, err
} }
compareValue = intVal return intVal, nil
case sqlparser.StrVal: case sqlparser.StrVal:
compareValue = string(val.Val) return string(val.Val), nil
default: default:
return nil, fmt.Errorf("unsupported SQL value type: %v", val.Type) return nil, fmt.Errorf("unsupported SQL value type: %v", val.Type)
} }
@@ -1253,22 +1291,10 @@ func (e *SQLEngine) buildComparisonPredicate(expr *sqlparser.ComparisonExpr) (fu
} }
} }
} }
compareValue = inValues return inValues, nil
default: default:
return nil, fmt.Errorf("unsupported comparison right side: %T", expr.Right) return nil, fmt.Errorf("unsupported comparison value type: %T", expr)
} }
// Create predicate based on operator
operator := expr.Operator
return func(record *schema_pb.RecordValue) bool {
fieldValue, exists := record.Fields[columnName]
if !exists {
return false
}
return e.evaluateComparison(fieldValue, operator, compareValue)
}, nil
} }
// evaluateComparison performs the actual comparison // evaluateComparison performs the actual comparison