/* * Minio Cloud Storage, (C) 2018 Minio, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package s3select import ( "math" "strconv" "strings" "github.com/xwb1989/sqlparser" ) // SelectFuncs contains the relevant values from the parser for S3 Select // Functions type SelectFuncs struct { funcExpr []*sqlparser.FuncExpr index []int } // RunSqlParser allows us to easily bundle all the functions from above and run // them in the appropriate order. func (reader *Input) runSelectParser(selectExpression string, myRow chan *Row) { reqCols, alias, myLimit, whereClause, aggFunctionNames, myFuncs, myErr := reader.ParseSelect(selectExpression) if myErr != nil { rowStruct := &Row{ err: myErr, } myRow <- rowStruct return } reader.processSelectReq(reqCols, alias, whereClause, myLimit, aggFunctionNames, myRow, myFuncs) } // ParseSelect parses the SELECT expression, and effectively tokenizes it into // its separate parts. It returns the requested column names,alias,limit of // records, and the where clause. func (reader *Input) ParseSelect(sqlInput string) ([]string, string, int, interface{}, []string, *SelectFuncs, error) { // return columnNames, alias, limitOfRecords, whereclause,coalStore, nil stmt, err := sqlparser.Parse(sqlInput) var whereClause interface{} var alias string var limit int myFuncs := &SelectFuncs{} // TODO Maybe can parse their errors a bit to return some more of the s3 errors if err != nil { return nil, "", 0, nil, nil, nil, ErrLexerInvalidChar } switch stmt := stmt.(type) { case *sqlparser.Select: // evaluates the where clause functionNames := make([]string, len(stmt.SelectExprs)) columnNames := make([]string, len(stmt.SelectExprs)) if stmt.Where != nil { switch expr := stmt.Where.Expr.(type) { default: whereClause = expr case *sqlparser.ComparisonExpr: whereClause = expr } } if stmt.SelectExprs != nil { for i := 0; i < len(stmt.SelectExprs); i++ { switch expr := stmt.SelectExprs[i].(type) { case *sqlparser.StarExpr: columnNames[0] = "*" case *sqlparser.AliasedExpr: switch smallerexpr := expr.Expr.(type) { case *sqlparser.FuncExpr: if smallerexpr.IsAggregate() { functionNames[i] = smallerexpr.Name.CompliantName() // Will return function name // Case to deal with if we have functions and not an asterix switch tempagg := smallerexpr.Exprs[0].(type) { case *sqlparser.StarExpr: columnNames[0] = "*" if smallerexpr.Name.CompliantName() != "count" { return nil, "", 0, nil, nil, nil, ErrParseUnsupportedCallWithStar } case *sqlparser.AliasedExpr: switch col := tempagg.Expr.(type) { case *sqlparser.BinaryExpr: return nil, "", 0, nil, nil, nil, ErrParseNonUnaryAgregateFunctionCall case *sqlparser.ColName: columnNames[i] = col.Name.CompliantName() } } // Case to deal with if COALESCE was used.. } else if supportedFunc(smallerexpr.Name.CompliantName()) { if myFuncs.funcExpr == nil { myFuncs.funcExpr = make([]*sqlparser.FuncExpr, len(stmt.SelectExprs)) myFuncs.index = make([]int, len(stmt.SelectExprs)) } myFuncs.funcExpr[i] = smallerexpr myFuncs.index[i] = i } else { return nil, "", 0, nil, nil, nil, ErrUnsupportedSQLOperation } case *sqlparser.ColName: columnNames[i] = smallerexpr.Name.CompliantName() } } } } // This code retrieves the alias and makes sure it is set to the correct // value, if not it sets it to the tablename if (stmt.From) != nil { for i := 0; i < len(stmt.From); i++ { switch smallerexpr := stmt.From[i].(type) { case *sqlparser.JoinTableExpr: return nil, "", 0, nil, nil, nil, ErrParseMalformedJoin case *sqlparser.AliasedTableExpr: alias = smallerexpr.As.CompliantName() if alias == "" { alias = sqlparser.GetTableName(smallerexpr.Expr).CompliantName() } } } } if stmt.Limit != nil { switch expr := stmt.Limit.Rowcount.(type) { case *sqlparser.SQLVal: // The Value of how many rows we're going to limit by limit, _ = strconv.Atoi(string(expr.Val[:])) } } if stmt.GroupBy != nil { return nil, "", 0, nil, nil, nil, ErrParseUnsupportedLiteralsGroupBy } if stmt.OrderBy != nil { return nil, "", 0, nil, nil, nil, ErrParseUnsupportedToken } if err := reader.parseErrs(columnNames, whereClause, alias, myFuncs); err != nil { return nil, "", 0, nil, nil, nil, err } return columnNames, alias, limit, whereClause, functionNames, myFuncs, nil } return nil, "", 0, nil, nil, nil, nil } // This is the main function, It goes row by row and for records which validate // the where clause it currently prints the appropriate row given the requested // columns. func (reader *Input) processSelectReq(reqColNames []string, alias string, whereClause interface{}, limitOfRecords int, functionNames []string, myRow chan *Row, myFunc *SelectFuncs) { counter := -1 filtrCount := 0 functionFlag := false // My values is used to store our aggregation values if we need to store them. myAggVals := make([]float64, len(reqColNames)) var columns []string // LowercasecolumnsMap is used in accordance with hasDuplicates so that we can // raise the error "Ambigious" if a case insensitive column is provided and we // have multiple matches. lowercaseColumnsMap := make(map[string]int) hasDuplicates := make(map[string]bool) // ColumnsMap stores our columns and their index. columnsMap := make(map[string]int) if limitOfRecords == 0 { limitOfRecords = math.MaxInt64 } for { record := reader.ReadRecord() reader.stats.BytesProcessed += processSize(record) if record == nil { if functionFlag { rowStruct := &Row{ record: reader.aggFuncToStr(myAggVals) + "\n", } myRow <- rowStruct } close(myRow) return } if counter == -1 && reader.options.HeaderOpt && len(reader.header) > 0 { columns = reader.Header() myErr := checkForDuplicates(columns, columnsMap, hasDuplicates, lowercaseColumnsMap) if myErr != nil { rowStruct := &Row{ err: myErr, } myRow <- rowStruct return } } else if counter == -1 && len(reader.header) > 0 { columns = reader.Header() } // When we have reached our limit, on what the user specified as the number // of rows they wanted, we terminate our interpreter. if filtrCount == limitOfRecords && limitOfRecords != 0 { close(myRow) return } // The call to the where function clause,ensures that the rows we print match our where clause. condition, myErr := matchesMyWhereClause(record, columnsMap, alias, whereClause) if myErr != nil { rowStruct := &Row{ err: myErr, } myRow <- rowStruct return } if condition { // if its an asterix we just print everything in the row if reqColNames[0] == "*" && functionNames[0] == "" { rowStruct := &Row{ record: reader.printAsterix(record) + "\n", } myRow <- rowStruct } else if alias != "" { // This is for dealing with the case of if we have to deal with a // request for a column with an index e.g A_1. if representsInt(reqColNames[0]) { // This checks whether any aggregation function was called as now we // no longer will go through printing each row, and only print at the // end if len(functionNames) > 0 && functionNames[0] != "" { functionFlag = true aggregationFunctions(counter, filtrCount, myAggVals, columnsMap, reqColNames, functionNames, record) } else { // The code below finds the appropriate columns of the row given the // indicies provided in the SQL request and utilizes the map to // retrieve the correct part of the row. myQueryRow, myErr := reader.processColNameIndex(record, reqColNames, columns) if myErr != nil { rowStruct := &Row{ err: myErr, } myRow <- rowStruct return } rowStruct := &Row{ record: myQueryRow + "\n", } myRow <- rowStruct } } else { // This code does aggregation if we were provided column names in the // form of acutal names rather an indices. if len(functionNames) > 0 && functionNames[0] != "" { functionFlag = true aggregationFunctions(counter, filtrCount, myAggVals, columnsMap, reqColNames, functionNames, record) } else { // This code prints the appropriate part of the row given the filter // and select request, if the select request was based on column // names rather than indices. myQueryRow, myErr := reader.processColNameLiteral(record, reqColNames, columns, columnsMap, myFunc) if myErr != nil { rowStruct := &Row{ err: myErr, } myRow <- rowStruct return } rowStruct := &Row{ record: myQueryRow + "\n", } myRow <- rowStruct } } } filtrCount++ } counter++ } } // printAsterix helps to print out the entire row if an asterix is used. func (reader *Input) printAsterix(record []string) string { myRow := record[0] for i := 1; i < len(record); i++ { myRow = myRow + reader.options.OutputFieldDelimiter + record[i] } return myRow } // processColumnNames is a function which allows for cleaning of column names. func (reader *Input) processColumnNames(reqColNames []string, alias string) error { for i := 0; i < len(reqColNames); i++ { // The code below basically cleans the column name of its alias and other // syntax, so that we can extract its pure name. reqColNames[i] = cleanCol(reqColNames[i], alias) } return nil } // processColNameIndex is the function which creates the row for an index based // query. func (reader *Input) processColNameIndex(record []string, reqColNames []string, columns []string) (string, error) { myRow := "" for i := 0; i < len(reqColNames); i++ { // COALESCE AND NULLIF do not support index based access. if reqColNames[0] == "0" { return "", ErrInvalidColumnIndex } // Subtract 1 because AWS Indexing is not 0 based, it starts at 1. mytempindex, err := strconv.Atoi(reqColNames[i]) mytempindex = mytempindex - 1 if mytempindex > len(columns) { return "", ErrInvalidColumnIndex } myRow = writeRow(myRow, record[mytempindex], reader.options.OutputFieldDelimiter, len(reqColNames)) if err != nil { return "", ErrMissingHeaders } } if len(myRow) > 1000000 { return "", ErrOverMaxRecordSize } if strings.Count(myRow, reader.options.OutputFieldDelimiter) != len(reqColNames)-1 { myRow = qualityCheck(myRow, len(reqColNames)-1-strings.Count(myRow, reader.options.OutputFieldDelimiter), reader.options.OutputFieldDelimiter) } return myRow, nil } // processColNameLiteral is the function which creates the row for an name based // query. func (reader *Input) processColNameLiteral(record []string, reqColNames []string, columns []string, columnsMap map[string]int, myFunc *SelectFuncs) (string, error) { myRow := "" for i := 0; i < len(reqColNames); i++ { // this is the case to deal with COALESCE. if reqColNames[i] == "" && isValidFunc(myFunc.index, i) { myVal := evaluateFuncExpr(myFunc.funcExpr[i], "", record, columnsMap) myRow = writeRow(myRow, myVal, reader.options.OutputFieldDelimiter, len(reqColNames)) continue } myTempIndex, notFound := columnsMap[trimQuotes(reqColNames[i])] if !notFound { return "", ErrMissingHeaders } myRow = writeRow(myRow, record[myTempIndex], reader.options.OutputFieldDelimiter, len(reqColNames)) } if len(myRow) > 1000000 { return "", ErrOverMaxRecordSize } if strings.Count(myRow, reader.options.OutputFieldDelimiter) != len(reqColNames)-1 { myRow = qualityCheck(myRow, len(reqColNames)-1-strings.Count(myRow, reader.options.OutputFieldDelimiter), reader.options.OutputFieldDelimiter) } return myRow, nil } // aggregationFunctions is a function which performs the actual aggregation // methods on the given row, it uses an array defined the the main parsing // function to keep track of values. func aggregationFunctions(counter int, filtrCount int, myAggVals []float64, columnsMap map[string]int, storeReqCols []string, storeFunctions []string, record []string) error { for i := 0; i < len(storeFunctions); i++ { if storeFunctions[i] == "" { i++ } else if storeFunctions[i] == "count" { myAggVals[i]++ } else { // If column names are provided as an index it'll use this if statement instead of the else/ var convAggFloat float64 if representsInt(storeReqCols[i]) { myIndex, _ := strconv.Atoi(storeReqCols[i]) convAggFloat, _ = strconv.ParseFloat(record[myIndex], 64) } else { // case that the columns are in the form of named columns rather than indices. convAggFloat, _ = strconv.ParseFloat(record[columnsMap[trimQuotes(storeReqCols[i])]], 64) } // This if statement is for calculating the min. if storeFunctions[i] == "min" { if counter == -1 { myAggVals[i] = math.MaxFloat64 } if convAggFloat < myAggVals[i] { myAggVals[i] = convAggFloat } } else if storeFunctions[i] == "max" { // This if statement is for calculating the max. if counter == -1 { myAggVals[i] = math.SmallestNonzeroFloat64 } if convAggFloat > myAggVals[i] { myAggVals[i] = convAggFloat } } else if storeFunctions[i] == "sum" { // This if statement is for calculating the sum. myAggVals[i] += convAggFloat } else if storeFunctions[i] == "avg" { // This if statement is for calculating the average. if filtrCount == 0 { myAggVals[i] = convAggFloat } else { myAggVals[i] = (convAggFloat + (myAggVals[i] * float64(filtrCount))) / float64((filtrCount + 1)) } } else { return ErrParseNonUnaryAgregateFunctionCall } } } return nil }