You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
344 lines
8.2 KiB
344 lines
8.2 KiB
6 years ago
|
/*
|
||
|
Copyright 2017 Google Inc.
|
||
|
|
||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
you may not use this file except in compliance with the License.
|
||
|
You may obtain a copy of the License at
|
||
|
|
||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
||
|
Unless required by applicable law or agreed to in writing, software
|
||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
See the License for the specific language governing permissions and
|
||
|
limitations under the License.
|
||
|
*/
|
||
|
|
||
|
package sqlparser
|
||
|
|
||
|
// analyzer.go contains utility analysis functions.
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
"unicode"
|
||
|
|
||
|
"github.com/xwb1989/sqlparser/dependency/sqltypes"
|
||
|
)
|
||
|
|
||
|
// Statement type identifiers, as returned by Preview. The values are
// consecutive integers starting at zero, in the order declared here.
const (
	// StmtSelect is returned for statements whose first word is "select".
	StmtSelect = iota
	// StmtStream is returned for statements whose first word is "stream".
	StmtStream
	// StmtInsert is returned for statements whose first word is "insert".
	StmtInsert
	// StmtReplace is returned for statements whose first word is "replace".
	StmtReplace
	// StmtUpdate is returned for statements whose first word is "update".
	StmtUpdate
	// StmtDelete is returned for statements whose first word is "delete".
	StmtDelete
	// StmtDDL covers "create", "alter", "rename", "drop" and "truncate".
	StmtDDL
	// StmtBegin covers "begin" and "start transaction".
	StmtBegin
	// StmtCommit is returned for "commit".
	StmtCommit
	// StmtRollback is returned for "rollback".
	StmtRollback
	// StmtSet is returned for statements whose first word is "set".
	StmtSet
	// StmtShow is returned for statements whose first word is "show".
	StmtShow
	// StmtUse is returned for statements whose first word is "use".
	StmtUse
	// StmtOther covers "analyze", "describe", "desc", "explain", "repair"
	// and "optimize".
	StmtOther
	// StmtUnknown is returned when no other type matches.
	StmtUnknown
	// StmtComment is returned for otherwise unrecognized statements that
	// start with "/*!".
	StmtComment
)
|
||
|
|
||
|
// Preview analyzes the beginning of the query using a simpler and faster
|
||
|
// textual comparison to identify the statement type.
|
||
|
func Preview(sql string) int {
|
||
|
trimmed := StripLeadingComments(sql)
|
||
|
|
||
|
firstWord := trimmed
|
||
|
if end := strings.IndexFunc(trimmed, unicode.IsSpace); end != -1 {
|
||
|
firstWord = trimmed[:end]
|
||
|
}
|
||
|
firstWord = strings.TrimLeftFunc(firstWord, func(r rune) bool { return !unicode.IsLetter(r) })
|
||
|
// Comparison is done in order of priority.
|
||
|
loweredFirstWord := strings.ToLower(firstWord)
|
||
|
switch loweredFirstWord {
|
||
|
case "select":
|
||
|
return StmtSelect
|
||
|
case "stream":
|
||
|
return StmtStream
|
||
|
case "insert":
|
||
|
return StmtInsert
|
||
|
case "replace":
|
||
|
return StmtReplace
|
||
|
case "update":
|
||
|
return StmtUpdate
|
||
|
case "delete":
|
||
|
return StmtDelete
|
||
|
}
|
||
|
// For the following statements it is not sufficient to rely
|
||
|
// on loweredFirstWord. This is because they are not statements
|
||
|
// in the grammar and we are relying on Preview to parse them.
|
||
|
// For instance, we don't want: "BEGIN JUNK" to be parsed
|
||
|
// as StmtBegin.
|
||
|
trimmedNoComments, _ := SplitMarginComments(trimmed)
|
||
|
switch strings.ToLower(trimmedNoComments) {
|
||
|
case "begin", "start transaction":
|
||
|
return StmtBegin
|
||
|
case "commit":
|
||
|
return StmtCommit
|
||
|
case "rollback":
|
||
|
return StmtRollback
|
||
|
}
|
||
|
switch loweredFirstWord {
|
||
|
case "create", "alter", "rename", "drop", "truncate":
|
||
|
return StmtDDL
|
||
|
case "set":
|
||
|
return StmtSet
|
||
|
case "show":
|
||
|
return StmtShow
|
||
|
case "use":
|
||
|
return StmtUse
|
||
|
case "analyze", "describe", "desc", "explain", "repair", "optimize":
|
||
|
return StmtOther
|
||
|
}
|
||
|
if strings.Index(trimmed, "/*!") == 0 {
|
||
|
return StmtComment
|
||
|
}
|
||
|
return StmtUnknown
|
||
|
}
|
||
|
|
||
|
// StmtType returns the statement type as a string
|
||
|
func StmtType(stmtType int) string {
|
||
|
switch stmtType {
|
||
|
case StmtSelect:
|
||
|
return "SELECT"
|
||
|
case StmtStream:
|
||
|
return "STREAM"
|
||
|
case StmtInsert:
|
||
|
return "INSERT"
|
||
|
case StmtReplace:
|
||
|
return "REPLACE"
|
||
|
case StmtUpdate:
|
||
|
return "UPDATE"
|
||
|
case StmtDelete:
|
||
|
return "DELETE"
|
||
|
case StmtDDL:
|
||
|
return "DDL"
|
||
|
case StmtBegin:
|
||
|
return "BEGIN"
|
||
|
case StmtCommit:
|
||
|
return "COMMIT"
|
||
|
case StmtRollback:
|
||
|
return "ROLLBACK"
|
||
|
case StmtSet:
|
||
|
return "SET"
|
||
|
case StmtShow:
|
||
|
return "SHOW"
|
||
|
case StmtUse:
|
||
|
return "USE"
|
||
|
case StmtOther:
|
||
|
return "OTHER"
|
||
|
default:
|
||
|
return "UNKNOWN"
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// IsDML returns true if the query is an INSERT, UPDATE or DELETE statement.
|
||
|
func IsDML(sql string) bool {
|
||
|
switch Preview(sql) {
|
||
|
case StmtInsert, StmtReplace, StmtUpdate, StmtDelete:
|
||
|
return true
|
||
|
}
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
// GetTableName returns the table name from the SimpleTableExpr
|
||
|
// only if it's a simple expression. Otherwise, it returns "".
|
||
|
func GetTableName(node SimpleTableExpr) TableIdent {
|
||
|
if n, ok := node.(TableName); ok && n.Qualifier.IsEmpty() {
|
||
|
return n.Name
|
||
|
}
|
||
|
// sub-select or '.' expression
|
||
|
return NewTableIdent("")
|
||
|
}
|
||
|
|
||
|
// IsColName returns true if the Expr is a *ColName.
|
||
|
func IsColName(node Expr) bool {
|
||
|
_, ok := node.(*ColName)
|
||
|
return ok
|
||
|
}
|
||
|
|
||
|
// IsValue returns true if the Expr is a string, integral or value arg.
|
||
|
// NULL is not considered to be a value.
|
||
|
func IsValue(node Expr) bool {
|
||
|
switch v := node.(type) {
|
||
|
case *SQLVal:
|
||
|
switch v.Type {
|
||
|
case StrVal, HexVal, IntVal, ValArg:
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
// IsNull returns true if the Expr is SQL NULL
|
||
|
func IsNull(node Expr) bool {
|
||
|
switch node.(type) {
|
||
|
case *NullVal:
|
||
|
return true
|
||
|
}
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
// IsSimpleTuple returns true if the Expr is a ValTuple that
|
||
|
// contains simple values or if it's a list arg.
|
||
|
func IsSimpleTuple(node Expr) bool {
|
||
|
switch vals := node.(type) {
|
||
|
case ValTuple:
|
||
|
for _, n := range vals {
|
||
|
if !IsValue(n) {
|
||
|
return false
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
case ListArg:
|
||
|
return true
|
||
|
}
|
||
|
// It's a subquery
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
// NewPlanValue builds a sqltypes.PlanValue from an Expr.
|
||
|
func NewPlanValue(node Expr) (sqltypes.PlanValue, error) {
|
||
|
switch node := node.(type) {
|
||
|
case *SQLVal:
|
||
|
switch node.Type {
|
||
|
case ValArg:
|
||
|
return sqltypes.PlanValue{Key: string(node.Val[1:])}, nil
|
||
|
case IntVal:
|
||
|
n, err := sqltypes.NewIntegral(string(node.Val))
|
||
|
if err != nil {
|
||
|
return sqltypes.PlanValue{}, fmt.Errorf("%v", err)
|
||
|
}
|
||
|
return sqltypes.PlanValue{Value: n}, nil
|
||
|
case StrVal:
|
||
|
return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, node.Val)}, nil
|
||
|
case HexVal:
|
||
|
v, err := node.HexDecode()
|
||
|
if err != nil {
|
||
|
return sqltypes.PlanValue{}, fmt.Errorf("%v", err)
|
||
|
}
|
||
|
return sqltypes.PlanValue{Value: sqltypes.MakeTrusted(sqltypes.VarBinary, v)}, nil
|
||
|
}
|
||
|
case ListArg:
|
||
|
return sqltypes.PlanValue{ListKey: string(node[2:])}, nil
|
||
|
case ValTuple:
|
||
|
pv := sqltypes.PlanValue{
|
||
|
Values: make([]sqltypes.PlanValue, 0, len(node)),
|
||
|
}
|
||
|
for _, val := range node {
|
||
|
innerpv, err := NewPlanValue(val)
|
||
|
if err != nil {
|
||
|
return sqltypes.PlanValue{}, err
|
||
|
}
|
||
|
if innerpv.ListKey != "" || innerpv.Values != nil {
|
||
|
return sqltypes.PlanValue{}, errors.New("unsupported: nested lists")
|
||
|
}
|
||
|
pv.Values = append(pv.Values, innerpv)
|
||
|
}
|
||
|
return pv, nil
|
||
|
case *NullVal:
|
||
|
return sqltypes.PlanValue{}, nil
|
||
|
}
|
||
|
return sqltypes.PlanValue{}, fmt.Errorf("expression is too complex '%v'", String(node))
|
||
|
}
|
||
|
|
||
|
// StringIn reports whether str is equal to at least one of values.
// It returns false when no values are supplied.
func StringIn(str string, values ...string) bool {
	for i := range values {
		if values[i] == str {
			return true
		}
	}
	return false
}
|
||
|
|
||
|
// SetKey identifies one variable assignment extracted from a SetExpr.
// It is used as the map key in the result of ExtractSetValues.
type SetKey struct {
	// Key is the name of the variable being set.
	Key string
	// Scope is the scope in which the variable is being set.
	Scope string
}
|
||
|
|
||
|
// ExtractSetValues returns a map of key-value pairs
|
||
|
// if the query is a SET statement. Values can be bool, int64 or string.
|
||
|
// Since set variable names are case insensitive, all keys are returned
|
||
|
// as lower case.
|
||
|
func ExtractSetValues(sql string) (keyValues map[SetKey]interface{}, scope string, err error) {
|
||
|
stmt, err := Parse(sql)
|
||
|
if err != nil {
|
||
|
return nil, "", err
|
||
|
}
|
||
|
setStmt, ok := stmt.(*Set)
|
||
|
if !ok {
|
||
|
return nil, "", fmt.Errorf("ast did not yield *sqlparser.Set: %T", stmt)
|
||
|
}
|
||
|
result := make(map[SetKey]interface{})
|
||
|
for _, expr := range setStmt.Exprs {
|
||
|
scope := SessionStr
|
||
|
key := expr.Name.Lowered()
|
||
|
switch {
|
||
|
case strings.HasPrefix(key, "@@global."):
|
||
|
scope = GlobalStr
|
||
|
key = strings.TrimPrefix(key, "@@global.")
|
||
|
case strings.HasPrefix(key, "@@session."):
|
||
|
key = strings.TrimPrefix(key, "@@session.")
|
||
|
case strings.HasPrefix(key, "@@"):
|
||
|
key = strings.TrimPrefix(key, "@@")
|
||
|
}
|
||
|
|
||
|
if strings.HasPrefix(expr.Name.Lowered(), "@@") {
|
||
|
if setStmt.Scope != "" && scope != "" {
|
||
|
return nil, "", fmt.Errorf("unsupported in set: mixed using of variable scope")
|
||
|
}
|
||
|
_, out := NewStringTokenizer(key).Scan()
|
||
|
key = string(out)
|
||
|
}
|
||
|
|
||
|
setKey := SetKey{
|
||
|
Key: key,
|
||
|
Scope: scope,
|
||
|
}
|
||
|
|
||
|
switch expr := expr.Expr.(type) {
|
||
|
case *SQLVal:
|
||
|
switch expr.Type {
|
||
|
case StrVal:
|
||
|
result[setKey] = strings.ToLower(string(expr.Val))
|
||
|
case IntVal:
|
||
|
num, err := strconv.ParseInt(string(expr.Val), 0, 64)
|
||
|
if err != nil {
|
||
|
return nil, "", err
|
||
|
}
|
||
|
result[setKey] = num
|
||
|
default:
|
||
|
return nil, "", fmt.Errorf("invalid value type: %v", String(expr))
|
||
|
}
|
||
|
case BoolVal:
|
||
|
var val int64
|
||
|
if expr {
|
||
|
val = 1
|
||
|
}
|
||
|
result[setKey] = val
|
||
|
case *ColName:
|
||
|
result[setKey] = expr.Name.String()
|
||
|
case *NullVal:
|
||
|
result[setKey] = nil
|
||
|
case *Default:
|
||
|
result[setKey] = "default"
|
||
|
default:
|
||
|
return nil, "", fmt.Errorf("invalid syntax: %s", String(expr))
|
||
|
}
|
||
|
}
|
||
|
return result, strings.ToLower(setStmt.Scope), nil
|
||
|
}
|