You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
550 lines
13 KiB
550 lines
13 KiB
/*
|
|
* Minio Cloud Storage, (C) 2019 Minio, Inc.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package sql
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// FuncName - name of an SQL function as a string, for example "AVG" or
// "TO_TIMESTAMP".
type FuncName string
|
|
|
|
const (
|
|
// Avg - aggregate SQL function AVG().
|
|
Avg FuncName = "AVG"
|
|
|
|
// Count - aggregate SQL function COUNT().
|
|
Count FuncName = "COUNT"
|
|
|
|
// Max - aggregate SQL function MAX().
|
|
Max FuncName = "MAX"
|
|
|
|
// Min - aggregate SQL function MIN().
|
|
Min FuncName = "MIN"
|
|
|
|
// Sum - aggregate SQL function SUM().
|
|
Sum FuncName = "SUM"
|
|
|
|
// Coalesce - conditional SQL function COALESCE().
|
|
Coalesce FuncName = "COALESCE"
|
|
|
|
// NullIf - conditional SQL function NULLIF().
|
|
NullIf FuncName = "NULLIF"
|
|
|
|
// ToTimestamp - conversion SQL function TO_TIMESTAMP().
|
|
ToTimestamp FuncName = "TO_TIMESTAMP"
|
|
|
|
// UTCNow - date SQL function UTCNOW().
|
|
UTCNow FuncName = "UTCNOW"
|
|
|
|
// CharLength - string SQL function CHAR_LENGTH().
|
|
CharLength FuncName = "CHAR_LENGTH"
|
|
|
|
// CharacterLength - string SQL function CHARACTER_LENGTH() same as CHAR_LENGTH().
|
|
CharacterLength FuncName = "CHARACTER_LENGTH"
|
|
|
|
// Lower - string SQL function LOWER().
|
|
Lower FuncName = "LOWER"
|
|
|
|
// Substring - string SQL function SUBSTRING().
|
|
Substring FuncName = "SUBSTRING"
|
|
|
|
// Trim - string SQL function TRIM().
|
|
Trim FuncName = "TRIM"
|
|
|
|
// Upper - string SQL function UPPER().
|
|
Upper FuncName = "UPPER"
|
|
|
|
// DateAdd FuncName = "DATE_ADD"
|
|
// DateDiff FuncName = "DATE_DIFF"
|
|
// Extract FuncName = "EXTRACT"
|
|
// ToString FuncName = "TO_STRING"
|
|
// Cast FuncName = "CAST" // CAST('2007-04-05T14:30Z' AS TIMESTAMP)
|
|
)
|
|
|
|
func isAggregateFuncName(s string) bool {
|
|
switch FuncName(s) {
|
|
case Avg, Count, Max, Min, Sum:
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func callForNumber(f Expr, record Record) (*Value, error) {
|
|
value, err := f.Eval(record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !value.Type().isNumber() {
|
|
err := fmt.Errorf("%v evaluated to %v; not to number", f, value.Type())
|
|
return nil, errExternalEvalException(err)
|
|
}
|
|
|
|
return value, nil
|
|
}
|
|
|
|
func callForInt(f Expr, record Record) (*Value, error) {
|
|
value, err := f.Eval(record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if value.Type() != Int {
|
|
err := fmt.Errorf("%v evaluated to %v; not to int", f, value.Type())
|
|
return nil, errExternalEvalException(err)
|
|
}
|
|
|
|
return value, nil
|
|
}
|
|
|
|
func callForString(f Expr, record Record) (*Value, error) {
|
|
value, err := f.Eval(record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if value.Type() != String {
|
|
err := fmt.Errorf("%v evaluated to %v; not to string", f, value.Type())
|
|
return nil, errExternalEvalException(err)
|
|
}
|
|
|
|
return value, nil
|
|
}
|
|
|
|
// funcExpr - SQL function.
|
|
type funcExpr struct {
|
|
args []Expr
|
|
name FuncName
|
|
|
|
sumValue float64
|
|
countValue int64
|
|
maxValue float64
|
|
minValue float64
|
|
}
|
|
|
|
// String - returns string representation of this function.
|
|
func (f *funcExpr) String() string {
|
|
var argStrings []string
|
|
for _, arg := range f.args {
|
|
argStrings = append(argStrings, fmt.Sprintf("%v", arg))
|
|
}
|
|
|
|
return fmt.Sprintf("%v(%v)", f.name, strings.Join(argStrings, ","))
|
|
}
|
|
|
|
func (f *funcExpr) sum(record Record) (*Value, error) {
|
|
value, err := callForNumber(f.args[0], record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
f.sumValue += value.FloatValue()
|
|
f.countValue++
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *funcExpr) count(record Record) (*Value, error) {
|
|
value, err := f.args[0].Eval(record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if value.valueType != Null {
|
|
f.countValue++
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *funcExpr) max(record Record) (*Value, error) {
|
|
value, err := callForNumber(f.args[0], record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
v := value.FloatValue()
|
|
if v > f.maxValue {
|
|
f.maxValue = v
|
|
}
|
|
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *funcExpr) min(record Record) (*Value, error) {
|
|
value, err := callForNumber(f.args[0], record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
v := value.FloatValue()
|
|
if v < f.minValue {
|
|
f.minValue = v
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *funcExpr) charLength(record Record) (*Value, error) {
|
|
value, err := callForString(f.args[0], record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return NewInt(int64(len(value.StringValue()))), nil
|
|
}
|
|
|
|
func (f *funcExpr) trim(record Record) (*Value, error) {
|
|
value, err := callForString(f.args[0], record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return NewString(strings.TrimSpace(value.StringValue())), nil
|
|
}
|
|
|
|
func (f *funcExpr) lower(record Record) (*Value, error) {
|
|
value, err := callForString(f.args[0], record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return NewString(strings.ToLower(value.StringValue())), nil
|
|
}
|
|
|
|
func (f *funcExpr) upper(record Record) (*Value, error) {
|
|
value, err := callForString(f.args[0], record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return NewString(strings.ToUpper(value.StringValue())), nil
|
|
}
|
|
|
|
func (f *funcExpr) substring(record Record) (*Value, error) {
|
|
stringValue, err := callForString(f.args[0], record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
offsetValue, err := callForInt(f.args[1], record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var lengthValue *Value
|
|
if len(f.args) == 3 {
|
|
lengthValue, err = callForInt(f.args[2], record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
value := stringValue.StringValue()
|
|
offset := int(offsetValue.FloatValue())
|
|
if offset < 0 || offset > len(value) {
|
|
offset = 0
|
|
}
|
|
length := len(value)
|
|
if lengthValue != nil {
|
|
length = int(lengthValue.FloatValue())
|
|
if length < 0 || length > len(value) {
|
|
length = len(value)
|
|
}
|
|
}
|
|
|
|
return NewString(value[offset:length]), nil
|
|
}
|
|
|
|
func (f *funcExpr) coalesce(record Record) (*Value, error) {
|
|
values := make([]*Value, len(f.args))
|
|
var err error
|
|
for i := range f.args {
|
|
values[i], err = f.args[i].Eval(record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
for i := range values {
|
|
if values[i].Type() != Null {
|
|
return values[i], nil
|
|
}
|
|
}
|
|
|
|
return values[0], nil
|
|
}
|
|
|
|
func (f *funcExpr) nullIf(record Record) (*Value, error) {
|
|
value1, err := f.args[0].Eval(record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
value2, err := f.args[1].Eval(record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result, err := equal(value1, value2)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if result {
|
|
return NewNull(), nil
|
|
}
|
|
|
|
return value1, nil
|
|
}
|
|
|
|
func (f *funcExpr) toTimeStamp(record Record) (*Value, error) {
|
|
value, err := callForString(f.args[0], record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
t, err := time.Parse(time.RFC3339, value.StringValue())
|
|
if err != nil {
|
|
err := fmt.Errorf("%v: value '%v': %v", f, value, err)
|
|
return nil, errValueParseFailure(err)
|
|
}
|
|
|
|
return NewTime(t), nil
|
|
}
|
|
|
|
func (f *funcExpr) utcNow(record Record) (*Value, error) {
|
|
return NewTime(time.Now().UTC()), nil
|
|
}
|
|
|
|
// Call - evaluates this function for given arg values and returns result as Value.
|
|
func (f *funcExpr) Eval(record Record) (*Value, error) {
|
|
switch f.name {
|
|
case Avg, Sum:
|
|
return f.sum(record)
|
|
case Count:
|
|
return f.count(record)
|
|
case Max:
|
|
return f.max(record)
|
|
case Min:
|
|
return f.min(record)
|
|
case Coalesce:
|
|
return f.coalesce(record)
|
|
case NullIf:
|
|
return f.nullIf(record)
|
|
case ToTimestamp:
|
|
return f.toTimeStamp(record)
|
|
case UTCNow:
|
|
return f.utcNow(record)
|
|
case Substring:
|
|
return f.substring(record)
|
|
case CharLength, CharacterLength:
|
|
return f.charLength(record)
|
|
case Trim:
|
|
return f.trim(record)
|
|
case Lower:
|
|
return f.lower(record)
|
|
case Upper:
|
|
return f.upper(record)
|
|
}
|
|
|
|
panic(fmt.Sprintf("unsupported aggregate function %v", f.name))
|
|
}
|
|
|
|
// AggregateValue - returns aggregated value.
|
|
func (f *funcExpr) AggregateValue() (*Value, error) {
|
|
switch f.name {
|
|
case Avg:
|
|
return NewFloat(f.sumValue / float64(f.countValue)), nil
|
|
case Count:
|
|
return NewInt(f.countValue), nil
|
|
case Max:
|
|
return NewFloat(f.maxValue), nil
|
|
case Min:
|
|
return NewFloat(f.minValue), nil
|
|
case Sum:
|
|
return NewFloat(f.sumValue), nil
|
|
}
|
|
|
|
err := fmt.Errorf("%v is not aggreate function", f)
|
|
return nil, errExternalEvalException(err)
|
|
}
|
|
|
|
// Type - returns Function or aggregateFunction type.
|
|
func (f *funcExpr) Type() Type {
|
|
switch f.name {
|
|
case Avg, Count, Max, Min, Sum:
|
|
return aggregateFunction
|
|
}
|
|
|
|
return function
|
|
}
|
|
|
|
// ReturnType - returns respective primitive type depending on SQL function.
|
|
func (f *funcExpr) ReturnType() Type {
|
|
switch f.name {
|
|
case Avg, Max, Min, Sum:
|
|
return Float
|
|
case Count:
|
|
return Int
|
|
case CharLength, CharacterLength, Trim, Lower, Upper, Substring:
|
|
return String
|
|
case ToTimestamp, UTCNow:
|
|
return Timestamp
|
|
case Coalesce, NullIf:
|
|
return column
|
|
}
|
|
|
|
return function
|
|
}
|
|
|
|
// newFuncExpr - creates new SQL function.
|
|
func newFuncExpr(funcName FuncName, funcs ...Expr) (*funcExpr, error) {
|
|
switch funcName {
|
|
case Avg, Max, Min, Sum:
|
|
if len(funcs) != 1 {
|
|
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
|
|
return nil, errParseNonUnaryAgregateFunctionCall(err)
|
|
}
|
|
|
|
if !funcs[0].ReturnType().isNumberKind() {
|
|
err := fmt.Errorf("%v(): argument %v evaluate to %v, not number", funcName, funcs[0], funcs[0].ReturnType())
|
|
return nil, errIncorrectSQLFunctionArgumentType(err)
|
|
}
|
|
|
|
return &funcExpr{
|
|
args: funcs,
|
|
name: funcName,
|
|
}, nil
|
|
|
|
case Count:
|
|
if len(funcs) != 1 {
|
|
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
|
|
return nil, errParseNonUnaryAgregateFunctionCall(err)
|
|
}
|
|
|
|
switch funcs[0].ReturnType() {
|
|
case Null, Bool, Int, Float, String, Timestamp, column, record:
|
|
default:
|
|
err := fmt.Errorf("%v(): argument %v evaluate to %v is incompatible", funcName, funcs[0], funcs[0].ReturnType())
|
|
return nil, errIncorrectSQLFunctionArgumentType(err)
|
|
}
|
|
|
|
return &funcExpr{
|
|
args: funcs,
|
|
name: funcName,
|
|
}, nil
|
|
|
|
case CharLength, CharacterLength, Trim, Lower, Upper, ToTimestamp:
|
|
if len(funcs) != 1 {
|
|
err := fmt.Errorf("%v(): exactly one argument expected; got %v", funcName, len(funcs))
|
|
return nil, errEvaluatorInvalidArguments(err)
|
|
}
|
|
|
|
if !funcs[0].ReturnType().isStringKind() {
|
|
err := fmt.Errorf("%v(): argument %v evaluate to %v, not string", funcName, funcs[0], funcs[0].ReturnType())
|
|
return nil, errIncorrectSQLFunctionArgumentType(err)
|
|
}
|
|
|
|
return &funcExpr{
|
|
args: funcs,
|
|
name: funcName,
|
|
}, nil
|
|
|
|
case Coalesce:
|
|
if len(funcs) < 1 {
|
|
err := fmt.Errorf("%v(): one or more argument expected; got %v", funcName, len(funcs))
|
|
return nil, errEvaluatorInvalidArguments(err)
|
|
}
|
|
|
|
for i := range funcs {
|
|
if !funcs[i].ReturnType().isBaseKind() {
|
|
err := fmt.Errorf("%v(): argument-%v %v evaluate to %v is incompatible", funcName, i+1, funcs[i], funcs[i].ReturnType())
|
|
return nil, errIncorrectSQLFunctionArgumentType(err)
|
|
}
|
|
}
|
|
|
|
return &funcExpr{
|
|
args: funcs,
|
|
name: funcName,
|
|
}, nil
|
|
|
|
case NullIf:
|
|
if len(funcs) != 2 {
|
|
err := fmt.Errorf("%v(): exactly two arguments expected; got %v", funcName, len(funcs))
|
|
return nil, errEvaluatorInvalidArguments(err)
|
|
}
|
|
|
|
if !funcs[0].ReturnType().isBaseKind() {
|
|
err := fmt.Errorf("%v(): argument-1 %v evaluate to %v is incompatible", funcName, funcs[0], funcs[0].ReturnType())
|
|
return nil, errIncorrectSQLFunctionArgumentType(err)
|
|
}
|
|
|
|
if !funcs[1].ReturnType().isBaseKind() {
|
|
err := fmt.Errorf("%v(): argument-2 %v evaluate to %v is incompatible", funcName, funcs[1], funcs[1].ReturnType())
|
|
return nil, errIncorrectSQLFunctionArgumentType(err)
|
|
}
|
|
|
|
return &funcExpr{
|
|
args: funcs,
|
|
name: funcName,
|
|
}, nil
|
|
|
|
case UTCNow:
|
|
if len(funcs) != 0 {
|
|
err := fmt.Errorf("%v(): no argument expected; got %v", funcName, len(funcs))
|
|
return nil, errEvaluatorInvalidArguments(err)
|
|
}
|
|
|
|
return &funcExpr{
|
|
args: funcs,
|
|
name: funcName,
|
|
}, nil
|
|
|
|
case Substring:
|
|
if len(funcs) < 2 || len(funcs) > 3 {
|
|
err := fmt.Errorf("%v(): exactly two or three arguments expected; got %v", funcName, len(funcs))
|
|
return nil, errEvaluatorInvalidArguments(err)
|
|
}
|
|
|
|
if !funcs[0].ReturnType().isStringKind() {
|
|
err := fmt.Errorf("%v(): argument-1 %v evaluate to %v, not string", funcName, funcs[0], funcs[0].ReturnType())
|
|
return nil, errIncorrectSQLFunctionArgumentType(err)
|
|
}
|
|
|
|
if !funcs[1].ReturnType().isIntKind() {
|
|
err := fmt.Errorf("%v(): argument-2 %v evaluate to %v, not int", funcName, funcs[1], funcs[1].ReturnType())
|
|
return nil, errIncorrectSQLFunctionArgumentType(err)
|
|
}
|
|
|
|
if len(funcs) > 2 {
|
|
if !funcs[2].ReturnType().isIntKind() {
|
|
err := fmt.Errorf("%v(): argument-3 %v evaluate to %v, not int", funcName, funcs[2], funcs[2].ReturnType())
|
|
return nil, errIncorrectSQLFunctionArgumentType(err)
|
|
}
|
|
}
|
|
|
|
return &funcExpr{
|
|
args: funcs,
|
|
name: funcName,
|
|
}, nil
|
|
}
|
|
|
|
return nil, errUnsupportedFunction(fmt.Errorf("unknown function name %v", funcName))
|
|
}
|
|
|