diff --git a/pkg/s3select/sql/statement.go b/pkg/s3select/sql/statement.go index 7a0347dad..bfd06bedb 100644 --- a/pkg/s3select/sql/statement.go +++ b/pkg/s3select/sql/statement.go @@ -122,9 +122,35 @@ func (e *SelectStatement) AggregateResult(output Record) error { return nil } +func (e *SelectStatement) isPassingWhereClause(input Record) (bool, error) { + if e.selectAST.Where == nil { + return true, nil + } + value, err := e.selectAST.Where.evalNode(input) + if err != nil { + return false, err + } + + b, ok := value.ToBool() + if !ok { + err = fmt.Errorf("WHERE expression did not return bool") + return false, err + } + + return b, nil +} + // AggregateRow - aggregates the input record. Applies only to // aggregation queries. func (e *SelectStatement) AggregateRow(input Record) error { + ok, err := e.isPassingWhereClause(input) + if err != nil { + return err + } + if !ok { + return nil + } + for _, expr := range e.selectAST.Expression.Expressions { err := expr.aggregateRow(input) if err != nil { @@ -137,22 +163,12 @@ func (e *SelectStatement) AggregateRow(input Record) error { // Eval - evaluates the Select statement for the given record. It // applies only to non-aggregation queries. func (e *SelectStatement) Eval(input, output Record) (Record, error) { - if whereExpr := e.selectAST.Where; whereExpr != nil { - value, err := whereExpr.evalNode(input) - if err != nil { - return nil, err - } - - b, ok := value.ToBool() - if !ok { - err = fmt.Errorf("WHERE expression did not return bool") - return nil, err - } - - if !b { - // Where clause is not satisfied by the row - return nil, nil - } + ok, err := e.isPassingWhereClause(input) + if err != nil { + return nil, err + } + if !ok { + return nil, nil } if e.selectAST.Expression.All {