mirror of
https://github.com/seaweedfs/seaweedfs.git
synced 2025-09-19 22:19:23 +08:00
support aggregation functions
This commit is contained in:
@@ -168,9 +168,11 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *sqlparser.
|
||||
return e.executeSelectWithSampleData(ctx, stmt, database, tableName)
|
||||
}
|
||||
|
||||
// Parse SELECT columns
|
||||
// Parse SELECT columns and detect aggregation functions
|
||||
var columns []string
|
||||
var aggregations []AggregationSpec
|
||||
selectAll := false
|
||||
hasAggregations := false
|
||||
|
||||
for _, selectExpr := range stmt.SelectExprs {
|
||||
switch expr := selectExpr.(type) {
|
||||
@@ -180,6 +182,14 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *sqlparser.
|
||||
switch col := expr.Expr.(type) {
|
||||
case *sqlparser.ColName:
|
||||
columns = append(columns, col.Name.String())
|
||||
case *sqlparser.FuncExpr:
|
||||
// Handle aggregation functions
|
||||
aggSpec, err := e.parseAggregationFunction(col, expr)
|
||||
if err != nil {
|
||||
return &QueryResult{Error: err}, err
|
||||
}
|
||||
aggregations = append(aggregations, *aggSpec)
|
||||
hasAggregations = true
|
||||
default:
|
||||
err := fmt.Errorf("unsupported SELECT expression: %T", col)
|
||||
return &QueryResult{Error: err}, err
|
||||
@@ -190,6 +200,11 @@ func (e *SQLEngine) executeSelectStatement(ctx context.Context, stmt *sqlparser.
|
||||
}
|
||||
}
|
||||
|
||||
// If we have aggregations, use aggregation query path
|
||||
if hasAggregations {
|
||||
return e.executeAggregationQuery(ctx, hybridScanner, aggregations, stmt)
|
||||
}
|
||||
|
||||
// Parse WHERE clause for predicate pushdown
|
||||
var predicate func(*schema_pb.RecordValue) bool
|
||||
if stmt.Where != nil {
|
||||
@@ -988,6 +1003,338 @@ func (e *SQLEngine) dropTable(ctx context.Context, stmt *sqlparser.DDL) (*QueryR
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// AggregationSpec defines an aggregation function to be computed
|
||||
type AggregationSpec struct {
|
||||
Function string // COUNT, SUM, AVG, MIN, MAX
|
||||
Column string // Column name, or "*" for COUNT(*)
|
||||
Alias string // Optional alias for the result column
|
||||
}
|
||||
|
||||
// AggregationResult holds the computed result of an aggregation
|
||||
type AggregationResult struct {
|
||||
Count int64
|
||||
Sum float64
|
||||
Min interface{}
|
||||
Max interface{}
|
||||
}
|
||||
|
||||
// parseAggregationFunction parses an aggregation function expression
|
||||
func (e *SQLEngine) parseAggregationFunction(funcExpr *sqlparser.FuncExpr, aliasExpr *sqlparser.AliasedExpr) (*AggregationSpec, error) {
|
||||
funcName := strings.ToUpper(funcExpr.Name.String())
|
||||
|
||||
// Get alias name if specified
|
||||
alias := funcName // Default alias is the function name
|
||||
if !aliasExpr.As.IsEmpty() {
|
||||
alias = aliasExpr.As.String()
|
||||
}
|
||||
|
||||
spec := &AggregationSpec{
|
||||
Function: funcName,
|
||||
Alias: alias,
|
||||
}
|
||||
|
||||
// Parse function arguments
|
||||
switch funcName {
|
||||
case "COUNT":
|
||||
if len(funcExpr.Exprs) != 1 {
|
||||
return nil, fmt.Errorf("COUNT function expects exactly 1 argument")
|
||||
}
|
||||
|
||||
switch arg := funcExpr.Exprs[0].(type) {
|
||||
case *sqlparser.StarExpr:
|
||||
spec.Column = "*"
|
||||
case *sqlparser.AliasedExpr:
|
||||
if colName, ok := arg.Expr.(*sqlparser.ColName); ok {
|
||||
spec.Column = colName.Name.String()
|
||||
} else {
|
||||
return nil, fmt.Errorf("COUNT argument must be a column name or *")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported COUNT argument: %T", arg)
|
||||
}
|
||||
|
||||
case "SUM", "AVG", "MIN", "MAX":
|
||||
if len(funcExpr.Exprs) != 1 {
|
||||
return nil, fmt.Errorf("%s function expects exactly 1 argument", funcName)
|
||||
}
|
||||
|
||||
switch arg := funcExpr.Exprs[0].(type) {
|
||||
case *sqlparser.AliasedExpr:
|
||||
if colName, ok := arg.Expr.(*sqlparser.ColName); ok {
|
||||
spec.Column = colName.Name.String()
|
||||
} else {
|
||||
return nil, fmt.Errorf("%s argument must be a column name", funcName)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported %s argument: %T", funcName, arg)
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported aggregation function: %s", funcName)
|
||||
}
|
||||
|
||||
return spec, nil
|
||||
}
|
||||
|
||||
// executeAggregationQuery handles SELECT queries with aggregation functions
|
||||
func (e *SQLEngine) executeAggregationQuery(ctx context.Context, hybridScanner *HybridMessageScanner, aggregations []AggregationSpec, stmt *sqlparser.Select) (*QueryResult, error) {
|
||||
// Parse WHERE clause for filtering
|
||||
var predicate func(*schema_pb.RecordValue) bool
|
||||
var err error
|
||||
if stmt.Where != nil {
|
||||
predicate, err = e.buildPredicate(stmt.Where.Expr)
|
||||
if err != nil {
|
||||
return &QueryResult{Error: err}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Extract time filters for optimization
|
||||
startTimeNs, stopTimeNs := int64(0), int64(0)
|
||||
if stmt.Where != nil {
|
||||
startTimeNs, stopTimeNs = e.extractTimeFilters(stmt.Where.Expr)
|
||||
}
|
||||
|
||||
// Build scan options for full table scan (aggregations need all data)
|
||||
hybridScanOptions := HybridScanOptions{
|
||||
StartTimeNs: startTimeNs,
|
||||
StopTimeNs: stopTimeNs,
|
||||
Limit: 0, // No limit for aggregations - need all data
|
||||
Predicate: predicate,
|
||||
}
|
||||
|
||||
// Execute the hybrid scan to get all matching records
|
||||
results, err := hybridScanner.Scan(ctx, hybridScanOptions)
|
||||
if err != nil {
|
||||
return &QueryResult{Error: err}, err
|
||||
}
|
||||
|
||||
// Compute aggregations
|
||||
aggResults := e.computeAggregations(results, aggregations)
|
||||
|
||||
// Build result set
|
||||
columns := make([]string, len(aggregations))
|
||||
row := make([]sqltypes.Value, len(aggregations))
|
||||
|
||||
for i, spec := range aggregations {
|
||||
columns[i] = spec.Alias
|
||||
row[i] = e.formatAggregationResult(spec, aggResults[i])
|
||||
}
|
||||
|
||||
return &QueryResult{
|
||||
Columns: columns,
|
||||
Rows: [][]sqltypes.Value{row},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// computeAggregations computes aggregation functions over the scan results
|
||||
func (e *SQLEngine) computeAggregations(results []HybridScanResult, aggregations []AggregationSpec) []AggregationResult {
|
||||
aggResults := make([]AggregationResult, len(aggregations))
|
||||
|
||||
for i, spec := range aggregations {
|
||||
switch spec.Function {
|
||||
case "COUNT":
|
||||
if spec.Column == "*" {
|
||||
// COUNT(*) counts all rows
|
||||
aggResults[i].Count = int64(len(results))
|
||||
} else {
|
||||
// COUNT(column) counts non-null values
|
||||
count := int64(0)
|
||||
for _, result := range results {
|
||||
if value, exists := result.Values[spec.Column]; exists && value != nil {
|
||||
if !e.isNullValue(value) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
aggResults[i].Count = count
|
||||
}
|
||||
|
||||
case "SUM":
|
||||
sum := float64(0)
|
||||
for _, result := range results {
|
||||
if value, exists := result.Values[spec.Column]; exists && value != nil {
|
||||
if numValue := e.convertToNumber(value); numValue != nil {
|
||||
sum += *numValue
|
||||
}
|
||||
}
|
||||
}
|
||||
aggResults[i].Sum = sum
|
||||
|
||||
case "AVG":
|
||||
sum := float64(0)
|
||||
count := int64(0)
|
||||
for _, result := range results {
|
||||
if value, exists := result.Values[spec.Column]; exists && value != nil {
|
||||
if numValue := e.convertToNumber(value); numValue != nil {
|
||||
sum += *numValue
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
if count > 0 {
|
||||
aggResults[i].Sum = sum / float64(count) // Store average in Sum field
|
||||
aggResults[i].Count = count
|
||||
}
|
||||
|
||||
case "MIN":
|
||||
var min interface{}
|
||||
for _, result := range results {
|
||||
if value, exists := result.Values[spec.Column]; exists && value != nil {
|
||||
if min == nil || e.compareValues(value, min) < 0 {
|
||||
min = e.extractRawValue(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
aggResults[i].Min = min
|
||||
|
||||
case "MAX":
|
||||
var max interface{}
|
||||
for _, result := range results {
|
||||
if value, exists := result.Values[spec.Column]; exists && value != nil {
|
||||
if max == nil || e.compareValues(value, max) > 0 {
|
||||
max = e.extractRawValue(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
aggResults[i].Max = max
|
||||
}
|
||||
}
|
||||
|
||||
return aggResults
|
||||
}
|
||||
|
||||
// Helper functions for aggregation processing
|
||||
|
||||
func (e *SQLEngine) isNullValue(value *schema_pb.Value) bool {
|
||||
return value == nil || value.Kind == nil
|
||||
}
|
||||
|
||||
func (e *SQLEngine) convertToNumber(value *schema_pb.Value) *float64 {
|
||||
switch v := value.Kind.(type) {
|
||||
case *schema_pb.Value_Int32Value:
|
||||
result := float64(v.Int32Value)
|
||||
return &result
|
||||
case *schema_pb.Value_Int64Value:
|
||||
result := float64(v.Int64Value)
|
||||
return &result
|
||||
case *schema_pb.Value_FloatValue:
|
||||
result := float64(v.FloatValue)
|
||||
return &result
|
||||
case *schema_pb.Value_DoubleValue:
|
||||
return &v.DoubleValue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *SQLEngine) extractRawValue(value *schema_pb.Value) interface{} {
|
||||
switch v := value.Kind.(type) {
|
||||
case *schema_pb.Value_Int32Value:
|
||||
return v.Int32Value
|
||||
case *schema_pb.Value_Int64Value:
|
||||
return v.Int64Value
|
||||
case *schema_pb.Value_FloatValue:
|
||||
return v.FloatValue
|
||||
case *schema_pb.Value_DoubleValue:
|
||||
return v.DoubleValue
|
||||
case *schema_pb.Value_StringValue:
|
||||
return v.StringValue
|
||||
case *schema_pb.Value_BoolValue:
|
||||
return v.BoolValue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *SQLEngine) compareValues(value1 *schema_pb.Value, value2 interface{}) int {
|
||||
raw1 := e.extractRawValue(value1)
|
||||
if raw1 == nil {
|
||||
return -1
|
||||
}
|
||||
|
||||
// Simple comparison - in a full implementation this would handle type coercion
|
||||
switch v1 := raw1.(type) {
|
||||
case int32:
|
||||
if v2, ok := value2.(int32); ok {
|
||||
if v1 < v2 {
|
||||
return -1
|
||||
} else if v1 > v2 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case int64:
|
||||
if v2, ok := value2.(int64); ok {
|
||||
if v1 < v2 {
|
||||
return -1
|
||||
} else if v1 > v2 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case float64:
|
||||
if v2, ok := value2.(float64); ok {
|
||||
if v1 < v2 {
|
||||
return -1
|
||||
} else if v1 > v2 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case string:
|
||||
if v2, ok := value2.(string); ok {
|
||||
if v1 < v2 {
|
||||
return -1
|
||||
} else if v1 > v2 {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (e *SQLEngine) formatAggregationResult(spec AggregationSpec, result AggregationResult) sqltypes.Value {
|
||||
switch spec.Function {
|
||||
case "COUNT":
|
||||
return sqltypes.NewInt64(result.Count)
|
||||
case "SUM":
|
||||
return sqltypes.NewFloat64(result.Sum)
|
||||
case "AVG":
|
||||
return sqltypes.NewFloat64(result.Sum) // Sum contains the average for AVG
|
||||
case "MIN":
|
||||
if result.Min != nil {
|
||||
return e.convertRawValueToSQL(result.Min)
|
||||
}
|
||||
return sqltypes.NULL
|
||||
case "MAX":
|
||||
if result.Max != nil {
|
||||
return e.convertRawValueToSQL(result.Max)
|
||||
}
|
||||
return sqltypes.NULL
|
||||
}
|
||||
return sqltypes.NULL
|
||||
}
|
||||
|
||||
func (e *SQLEngine) convertRawValueToSQL(value interface{}) sqltypes.Value {
|
||||
switch v := value.(type) {
|
||||
case int32:
|
||||
return sqltypes.NewInt32(v)
|
||||
case int64:
|
||||
return sqltypes.NewInt64(v)
|
||||
case float32:
|
||||
return sqltypes.NewFloat32(v)
|
||||
case float64:
|
||||
return sqltypes.NewFloat64(v)
|
||||
case string:
|
||||
return sqltypes.NewVarChar(v)
|
||||
case bool:
|
||||
if v {
|
||||
return sqltypes.NewVarChar("1")
|
||||
}
|
||||
return sqltypes.NewVarChar("0")
|
||||
}
|
||||
return sqltypes.NULL
|
||||
}
|
||||
|
||||
// discoverAndRegisterTopic attempts to discover an existing topic and register it in the SQL catalog
|
||||
func (e *SQLEngine) discoverAndRegisterTopic(ctx context.Context, database, tableName string) error {
|
||||
// First, check if topic exists by trying to get its schema from the broker/filer
|
||||
|
Reference in New Issue
Block a user