column name can be on left or right in where conditions

This commit is contained in:
chrislu
2025-09-02 17:05:31 -07:00
parent 900bd94456
commit ed7102df6e

View File

@@ -1210,18 +1210,56 @@ func (e *SQLEngine) buildPredicate(expr sqlparser.Expr) (func(*schema_pb.RecordV
}
// buildComparisonPredicate creates a predicate for comparison operations (=, <, >, etc.)
// Handles column names on both left and right sides of the comparison
func (e *SQLEngine) buildComparisonPredicate(expr *sqlparser.ComparisonExpr) (func(*schema_pb.RecordValue) bool, error) {
// Extract column name (left side)
colName, ok := expr.Left.(*sqlparser.ColName)
if !ok {
return nil, fmt.Errorf("unsupported comparison left side: %T", expr.Left)
var columnName string
var compareValue interface{}
var operator string
// Check if column is on the left side (normal case: column > value)
if colName, ok := expr.Left.(*sqlparser.ColName); ok {
columnName = colName.Name.String()
operator = expr.Operator
// Extract comparison value from right side
val, err := e.extractComparisonValue(expr.Right)
if err != nil {
return nil, fmt.Errorf("failed to extract right-side value: %v", err)
}
compareValue = val
} else if colName, ok := expr.Right.(*sqlparser.ColName); ok {
// Column is on the right side (reversed case: value < column)
columnName = colName.Name.String()
// Reverse the operator when column is on right side
operator = e.reverseOperator(expr.Operator)
// Extract comparison value from left side
val, err := e.extractComparisonValue(expr.Left)
if err != nil {
return nil, fmt.Errorf("failed to extract left-side value: %v", err)
}
compareValue = val
} else {
return nil, fmt.Errorf("no column name found in comparison expression, left: %T, right: %T", expr.Left, expr.Right)
}
columnName := colName.Name.String()
// Create predicate based on operator
return func(record *schema_pb.RecordValue) bool {
fieldValue, exists := record.Fields[columnName]
if !exists {
return false
}
// Extract comparison value (right side)
var compareValue interface{}
switch val := expr.Right.(type) {
return e.evaluateComparison(fieldValue, operator, compareValue)
}, nil
}
// extractComparisonValue extracts the comparison value from a SQL expression
func (e *SQLEngine) extractComparisonValue(expr sqlparser.Expr) (interface{}, error) {
switch val := expr.(type) {
case *sqlparser.SQLVal:
switch val.Type {
case sqlparser.IntVal:
@@ -1229,9 +1267,9 @@ func (e *SQLEngine) buildComparisonPredicate(expr *sqlparser.ComparisonExpr) (fu
if err != nil {
return nil, err
}
compareValue = intVal
return intVal, nil
case sqlparser.StrVal:
compareValue = string(val.Val)
return string(val.Val), nil
default:
return nil, fmt.Errorf("unsupported SQL value type: %v", val.Type)
}
@@ -1253,22 +1291,10 @@ func (e *SQLEngine) buildComparisonPredicate(expr *sqlparser.ComparisonExpr) (fu
}
}
}
compareValue = inValues
return inValues, nil
default:
return nil, fmt.Errorf("unsupported comparison right side: %T", expr.Right)
return nil, fmt.Errorf("unsupported comparison value type: %T", expr)
}
// Create predicate based on operator
operator := expr.Operator
return func(record *schema_pb.RecordValue) bool {
fieldValue, exists := record.Fields[columnName]
if !exists {
return false
}
return e.evaluateComparison(fieldValue, operator, compareValue)
}, nil
}
// evaluateComparison performs the actual comparison