You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

3450 lines
74 KiB

/*
Copyright 2017 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlparser
import (
"bytes"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"strings"
"github.com/xwb1989/sqlparser/dependency/querypb"
"github.com/xwb1989/sqlparser/dependency/sqltypes"
)
// Instructions for creating new types: If a type
// needs to satisfy an interface, declare that function
// along with that interface. This will help users
// identify the list of types to which they can assert
// those interfaces.
// If the member of a type has a string with a predefined
// list of values, declare those values as const following
// the type.
// For interfaces that define dummy functions to consolidate
// a set of types, define the function as iTypeName.
// This will help avoid name collisions.
// Parse parses sql in full and returns the resulting Statement (the AST of
// the query).
//
// If the parse fails but the tokenizer captured a partially parsed DDL
// statement, the syntax error is logged and discarded, and that partial DDL
// is returned with a nil error. Any other parse failure returns the
// tokenizer's LastError.
func Parse(sql string) (Statement, error) {
	tkn := NewStringTokenizer(sql)
	if yyParse(tkn) == 0 {
		return tkn.ParseTree, nil
	}
	if tkn.partialDDL == nil {
		return nil, tkn.LastError
	}
	// Tolerate DDL the grammar only partly understands: keep what was parsed.
	log.Printf("ignoring error parsing DDL '%s': %v", sql, tkn.LastError)
	tkn.ParseTree = tkn.partialDDL
	return tkn.ParseTree, nil
}
// ParseStrictDDL behaves like Parse, with one difference: a partially
// parsed DDL statement is not accepted, so any parse failure is returned
// as the tokenizer's LastError.
func ParseStrictDDL(sql string) (Statement, error) {
	tkn := NewStringTokenizer(sql)
	if status := yyParse(tkn); status != 0 {
		return nil, tkn.LastError
	}
	return tkn.ParseTree, nil
}
// ParseNext parses one SQL statement from tokenizer and returns its AST.
//
// A ';' left over from the previous statement is skipped first. When the
// input is exhausted, ParseNext returns io.EOF, so callers can loop until
// they see that error. As in Parse, a partially parsed DDL statement is
// returned with a nil error. Unlike Parse, the discarded error is not
// logged.
func ParseNext(tokenizer *Tokenizer) (Statement, error) {
	// Step over the terminator of the previously parsed statement.
	if tokenizer.lastChar == ';' {
		tokenizer.next()
		tokenizer.skipBlank()
	}
	if tokenizer.lastChar == eofChar {
		return nil, io.EOF
	}

	tokenizer.reset()
	tokenizer.multi = true
	if yyParse(tokenizer) == 0 {
		return tokenizer.ParseTree, nil
	}
	if tokenizer.partialDDL == nil {
		return nil, tokenizer.LastError
	}
	tokenizer.ParseTree = tokenizer.partialDDL
	return tokenizer.ParseTree, nil
}
// SplitStatement scans blob token by token until it reaches a ';' or the end
// of input. If a ';' is found, it returns the text before it and the text
// after it. Otherwise it returns blob unchanged with an empty remainder.
// A tokenizer error is returned as-is, with empty strings.
func SplitStatement(blob string) (string, string, error) {
	tokenizer := NewStringTokenizer(blob)
	tkn := 0
	for stop := false; !stop; {
		tkn, _ = tokenizer.Scan()
		stop = tkn == 0 || tkn == ';' || tkn == eofChar
	}

	if err := tokenizer.LastError; err != nil {
		return "", "", err
	}
	if tkn != ';' {
		return blob, "", nil
	}
	pos := tokenizer.Position
	return blob[:pos-2], blob[pos-1:], nil
}
// SplitStatementToPieces splits blob, which may hold several ';'-separated
// SQL statements, into the individual statement texts. A trailing piece
// after the last ';' is included only if it is non-empty by position. The
// tokenizer's LastError (possibly nil) is returned as err.
func SplitStatementToPieces(blob string) (pieces []string, err error) {
	pieces = make([]string, 0, 16)
	tokenizer := NewStringTokenizer(blob)
	start := 0
	for {
		tkn, _ := tokenizer.Scan()
		if tkn == ';' {
			pieces = append(pieces, blob[start:tokenizer.Position-2])
			start = tokenizer.Position - 1
			continue
		}
		if tkn == 0 || tkn == eofChar {
			if tail := tokenizer.Position - 2; start < tail {
				pieces = append(pieces, blob[start:tail+1])
			}
			break
		}
	}
	return pieces, tokenizer.LastError
}
// SQLNode is implemented by every node the parser produces.
type SQLNode interface {
	// walkSubtree invokes visit on each node beneath this one, but not on
	// this node itself. It must stop and return as soon as visit
	// reports an error.
	walkSubtree(visit Visit) error
	// Format writes the SQL text of the node into buf.
	Format(buf *TrackedBuffer)
}

// Visit is the callback used by Walk. Returning kontinue=true asks Walk to
// descend into the node's children. A non-nil err aborts the walk.
type Visit func(node SQLNode) (kontinue bool, err error)
// Walk applies visit to each of nodes in order. Whenever visit returns
// true, Walk also descends into that node's subtree. The first error from
// visit, or from a subtree walk, stops the traversal and is returned.
//
// Only nil interface values are skipped. A typed nil pointer stored in an
// SQLNode is still handed to visit and to its walkSubtree method.
func Walk(visit Visit, nodes ...SQLNode) error {
	for _, node := range nodes {
		if node == nil {
			continue
		}
		descend, err := visit(node)
		switch {
		case err != nil:
			return err
		case !descend:
			continue
		}
		if err := node.walkSubtree(visit); err != nil {
			return err
		}
	}
	return nil
}
// String renders node as SQL text. A nil interface renders as "<nil>".
func String(node SQLNode) string {
	if node != nil {
		buf := NewTrackedBuffer(nil)
		buf.Myprintf("%v", node)
		return buf.String()
	}
	return "<nil>"
}
// Append writes the SQL text of node to the end of buf. Unlike String,
// node must not be a nil interface, because node.Format is called directly.
func Append(buf *bytes.Buffer, node SQLNode) {
	node.Format(&TrackedBuffer{Buffer: buf})
}
// Statement is implemented by every top-level SQL statement node. The
// unexported iStatement marker restricts the set of implementations to
// this package.
type Statement interface {
	SQLNode
	iStatement()
}

func (*Begin) iStatement()      {}
func (*Commit) iStatement()     {}
func (*DBDDL) iStatement()      {}
func (*DDL) iStatement()        {}
func (*Delete) iStatement()     {}
func (*Insert) iStatement()     {}
func (*OtherAdmin) iStatement() {}
func (*OtherRead) iStatement()  {}
func (*Rollback) iStatement()   {}
func (*Select) iStatement()     {}
func (*Set) iStatement()        {}
func (*Show) iStatement()       {}
func (*Stream) iStatement()     {}
func (*Union) iStatement()      {}
func (*Update) iStatement()     {}
func (*Use) iStatement()        {}

// ParenSelect is never a top-level statement in practice. It still has to
// satisfy Statement because SelectStatement embeds iStatement.
func (*ParenSelect) iStatement() {}
// SelectStatement any SELECT statement.
type SelectStatement interface {
iSelectStatement()
iStatement()
iInsertRows()
AddOrder(*Order)
SetLimit(*Limit)
SQLNode
}
func (*Select) iSelectStatement() {}
func (*Union) iSelectStatement() {}
func (*ParenSelect) iSelectStatement() {}
// Select represents a SELECT statement.
// Format emits the fields in this order: Comments, Cache, Distinct, Hints,
// SelectExprs, From, Where, GroupBy, Having, OrderBy, Limit, Lock.
type Select struct {
	// Cache is SQLCacheStr, SQLNoCacheStr, or "" (see constants below).
	Cache    string
	Comments Comments
	// Distinct is DistinctStr or "".
	Distinct string
	// Hints is written verbatim after Distinct. NOTE(review): presumably
	// holds StraightJoinHint or "" — confirm against the grammar.
	Hints       string
	SelectExprs SelectExprs
	From        TableExprs
	Where       *Where
	GroupBy     GroupBy
	// Having reuses the Where node; AddHaving builds it with Type HavingStr.
	Having  *Where
	OrderBy OrderBy
	Limit   *Limit
	// Lock is ForUpdateStr, ShareModeStr, or "".
	Lock string
}

// Select.Distinct
// (StraightJoinHint is grouped here but is presumably used for Select.Hints.)
const (
	DistinctStr      = "distinct "
	StraightJoinHint = "straight_join "
)

// Select.Lock
const (
	ForUpdateStr = " for update"
	ShareModeStr = " lock in share mode"
)

// Select.Cache
const (
	SQLCacheStr   = "sql_cache "
	SQLNoCacheStr = "sql_no_cache "
)
// AddOrder appends order to the statement's ORDER BY list.
func (node *Select) AddOrder(order *Order) {
	node.OrderBy = append(node.OrderBy, order)
}

// SetLimit replaces the statement's LIMIT clause with limit.
func (node *Select) SetLimit(limit *Limit) {
	node.Limit = limit
}

// Format writes the SELECT statement to buf. The sub-nodes are responsible
// for their own leading keywords and spacing.
func (node *Select) Format(buf *TrackedBuffer) {
	const layout = "select %v%s%s%s%v from %v%v%v%v%v%v%s"
	buf.Myprintf(layout,
		node.Comments, node.Cache, node.Distinct, node.Hints,
		node.SelectExprs,
		node.From,
		node.Where, node.GroupBy, node.Having,
		node.OrderBy, node.Limit,
		node.Lock,
	)
}

// walkSubtree visits the statement's child nodes. The string-valued
// options (Cache, Distinct, Hints, Lock) are not nodes and are skipped.
func (node *Select) walkSubtree(visit Visit) error {
	if node == nil {
		return nil
	}
	children := []SQLNode{
		node.Comments, node.SelectExprs, node.From, node.Where,
		node.GroupBy, node.Having, node.OrderBy, node.Limit,
	}
	return Walk(visit, children...)
}
// AddWhere ANDs expr into the WHERE clause, creating the clause if it does
// not exist yet. An OR expression is wrapped in parentheses first, because
// OR is the only operator that binds more loosely than AND.
func (node *Select) AddWhere(expr Expr) {
	if _, isOr := expr.(*OrExpr); isOr {
		expr = &ParenExpr{Expr: expr}
	}
	if node.Where != nil {
		node.Where.Expr = &AndExpr{Left: node.Where.Expr, Right: expr}
		return
	}
	node.Where = &Where{Type: WhereStr, Expr: expr}
}

// AddHaving ANDs expr into the HAVING clause, creating the clause if it
// does not exist yet. An OR expression is wrapped in parentheses first, for
// the same precedence reason as in AddWhere.
func (node *Select) AddHaving(expr Expr) {
	if _, isOr := expr.(*OrExpr); isOr {
		expr = &ParenExpr{Expr: expr}
	}
	if node.Having != nil {
		node.Having.Expr = &AndExpr{Left: node.Having.Expr, Right: expr}
		return
	}
	node.Having = &Where{Type: HavingStr, Expr: expr}
}
// ParenSelect is a SELECT-like statement wrapped in parentheses.
type ParenSelect struct {
	Select SelectStatement
}

// AddOrder satisfies SelectStatement. It is not expected to be called on a
// ParenSelect and panics if it is.
func (node *ParenSelect) AddOrder(order *Order) {
	panic("unreachable")
}

// SetLimit satisfies SelectStatement. It is not expected to be called on a
// ParenSelect and panics if it is.
func (node *ParenSelect) SetLimit(limit *Limit) {
	panic("unreachable")
}

// Format writes the wrapped statement surrounded by parentheses.
func (node *ParenSelect) Format(buf *TrackedBuffer) {
	buf.Myprintf("(%v)", node.Select)
}

// walkSubtree visits the wrapped statement. A nil receiver is a no-op.
func (node *ParenSelect) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Select)
	}
	return nil
}
// Union represents a UNION statement: Left <Type> Right, followed by an
// optional union-level ORDER BY, LIMIT and lock clause (see Union.Format).
type Union struct {
	// Type is one of UnionStr, UnionAllStr, UnionDistinctStr.
	Type        string
	Left, Right SelectStatement
	OrderBy     OrderBy
	Limit       *Limit
	// Lock is written verbatim after Limit (e.g. ForUpdateStr).
	Lock string
}

// Union.Type
const (
	UnionStr         = "union"
	UnionAllStr      = "union all"
	UnionDistinctStr = "union distinct"
)
// AddOrder appends order to the union-level ORDER BY list.
func (node *Union) AddOrder(order *Order) {
	node.OrderBy = append(node.OrderBy, order)
}

// SetLimit replaces the union-level LIMIT clause with limit.
func (node *Union) SetLimit(limit *Limit) {
	node.Limit = limit
}

// Format writes "<left> <type> <right>" followed by the union-level
// ORDER BY, LIMIT and lock text.
func (node *Union) Format(buf *TrackedBuffer) {
	const layout = "%v %s %v%v%v%s"
	buf.Myprintf(layout,
		node.Left, node.Type, node.Right,
		node.OrderBy, node.Limit, node.Lock,
	)
}
// walkSubtree visits both sides of the union, then its union-level ORDER BY
// and LIMIT clauses. This mirrors Select.walkSubtree, so visitors also see
// nodes that appear only in the trailing ORDER BY / LIMIT (previously those
// were silently skipped). A nil receiver is a no-op.
func (node *Union) walkSubtree(visit Visit) error {
	if node == nil {
		return nil
	}
	return Walk(
		visit,
		node.Left,
		node.Right,
		node.OrderBy,
		node.Limit,
	)
}
// Stream represents a STREAM statement:
// "stream <comments><select expr> from <table>".
type Stream struct {
	Comments   Comments
	SelectExpr SelectExpr
	Table      TableName
}

// Format writes the STREAM statement to buf.
func (node *Stream) Format(buf *TrackedBuffer) {
	const layout = "stream %v%v from %v"
	buf.Myprintf(layout, node.Comments, node.SelectExpr, node.Table)
}

// walkSubtree visits the comments, the streamed expression and the table.
func (node *Stream) walkSubtree(visit Visit) error {
	if node == nil {
		return nil
	}
	children := []SQLNode{node.Comments, node.SelectExpr, node.Table}
	return Walk(visit, children...)
}
// Insert represents an INSERT or REPLACE statement.
// Per the MySQL docs, http://dev.mysql.com/doc/refman/5.7/en/replace.html
// Replace is the counterpart to `INSERT IGNORE`, and works exactly like a
// normal INSERT except if the row exists. In that case it first deletes
// the row and re-inserts with new values. For that reason we keep it as an Insert struct.
// Replaces are currently disallowed in sharded schemas because
// of the implications the deletion part may have on vindexes.
// If you add fields here, consider adding them to calls to validateSubquerySamePlan.
type Insert struct {
	// Action is InsertStr or ReplaceStr; Format writes it first.
	Action   string
	Comments Comments
	// Ignore is written verbatim between the comments and "into".
	Ignore     string
	Table      TableName
	Partitions Partitions
	Columns    Columns
	Rows       InsertRows
	OnDup      OnDup
}

// Insert.Action values.
const (
	InsertStr  = "insert"
	ReplaceStr = "replace"
)
// Format writes the INSERT/REPLACE statement to buf.
func (node *Insert) Format(buf *TrackedBuffer) {
	const layout = "%s %v%sinto %v%v%v %v%v"
	buf.Myprintf(layout,
		node.Action, node.Comments, node.Ignore,
		node.Table, node.Partitions, node.Columns,
		node.Rows, node.OnDup,
	)
}

// walkSubtree visits the statement's child nodes. Partitions is not
// visited (NOTE(review): unclear whether that omission is intentional).
func (node *Insert) walkSubtree(visit Visit) error {
	if node == nil {
		return nil
	}
	children := []SQLNode{
		node.Comments, node.Table, node.Columns, node.Rows, node.OnDup,
	}
	return Walk(visit, children...)
}
// InsertRows is the row source of an INSERT: either a literal Values list
// or a SELECT-like statement.
type InsertRows interface {
	SQLNode
	iInsertRows()
}

func (*ParenSelect) iInsertRows() {}
func (*Select) iInsertRows()      {}
func (*Union) iInsertRows()       {}
func (Values) iInsertRows()       {}
// Update represents an UPDATE statement.
// If you add fields here, consider adding them to calls to validateSubquerySamePlan.
type Update struct {
	Comments   Comments
	TableExprs TableExprs
	Exprs      UpdateExprs
	Where      *Where
	OrderBy    OrderBy
	Limit      *Limit
}

// Format writes the UPDATE statement to buf: comments, tables, the SET
// list, then WHERE, ORDER BY and LIMIT, in that order.
func (node *Update) Format(buf *TrackedBuffer) {
	const layout = "update %v%v set %v%v%v%v"
	buf.Myprintf(layout,
		node.Comments,
		node.TableExprs,
		node.Exprs,
		node.Where,
		node.OrderBy,
		node.Limit,
	)
}

// walkSubtree visits every child node of the statement.
func (node *Update) walkSubtree(visit Visit) error {
	if node == nil {
		return nil
	}
	children := []SQLNode{
		node.Comments, node.TableExprs, node.Exprs,
		node.Where, node.OrderBy, node.Limit,
	}
	return Walk(visit, children...)
}
// Delete represents a DELETE statement.
// If you add fields here, consider adding them to calls to validateSubquerySamePlan.
type Delete struct {
	Comments   Comments
	Targets    TableNames
	TableExprs TableExprs
	Partitions Partitions
	Where      *Where
	OrderBy    OrderBy
	Limit      *Limit
}

// Format writes the DELETE statement to buf. The multi-table target list is
// written between "delete" and "from" only when Targets is non-nil.
func (node *Delete) Format(buf *TrackedBuffer) {
	buf.Myprintf("delete %v", node.Comments)
	if node.Targets != nil {
		buf.Myprintf("%v ", node.Targets)
	}
	const tail = "from %v%v%v%v%v"
	buf.Myprintf(tail,
		node.TableExprs, node.Partitions,
		node.Where, node.OrderBy, node.Limit,
	)
}

// walkSubtree visits the statement's child nodes. Partitions is not
// visited (NOTE(review): unclear whether that omission is intentional).
func (node *Delete) walkSubtree(visit Visit) error {
	if node == nil {
		return nil
	}
	children := []SQLNode{
		node.Comments, node.Targets, node.TableExprs,
		node.Where, node.OrderBy, node.Limit,
	}
	return Walk(visit, children...)
}
// Set represents a SET statement, optionally qualified by a scope.
type Set struct {
	Comments Comments
	Exprs    SetExprs
	Scope    string
}

// Set.Scope or Show.Scope
const (
	SessionStr = "session"
	GlobalStr  = "global"
)

// Format writes "set <comments>[<scope> ]<exprs>" to buf.
func (node *Set) Format(buf *TrackedBuffer) {
	if node.Scope != "" {
		buf.Myprintf("set %v%s %v", node.Comments, node.Scope, node.Exprs)
		return
	}
	buf.Myprintf("set %v%v", node.Comments, node.Exprs)
}

// walkSubtree visits the comments and the assignment list.
func (node *Set) walkSubtree(visit Visit) error {
	if node == nil {
		return nil
	}
	return Walk(visit, node.Comments, node.Exprs)
}
// DBDDL represents a CREATE DATABASE or DROP DATABASE statement.
type DBDDL struct {
	Action   string
	DBName   string
	IfExists bool
	Collate  string
	Charset  string
}

// Format writes the statement to buf. Only CreateStr and DropStr actions
// are rendered; any other action writes nothing. Collate and Charset are
// not included in the output.
func (node *DBDDL) Format(buf *TrackedBuffer) {
	var text string
	switch node.Action {
	case CreateStr:
		text = fmt.Sprintf("%s database %s", node.Action, node.DBName)
	case DropStr:
		ifExists := ""
		if node.IfExists {
			ifExists = " if exists"
		}
		text = fmt.Sprintf("%s database%s %v", node.Action, ifExists, node.DBName)
	default:
		return
	}
	buf.WriteString(text)
}

// walkSubtree is a no-op: DBDDL holds no child nodes.
func (node *DBDDL) walkSubtree(visit Visit) error {
	return nil
}
// DDL represents a CREATE, ALTER, DROP, RENAME or TRUNCATE statement.
// Table is set for AlterStr, DropStr, RenameStr, TruncateStr
// NewName is set for AlterStr, CreateStr, RenameStr.
// VindexSpec is set for CreateVindexStr, DropVindexStr, AddColVindexStr, DropColVindexStr
// (NOTE(review): DropVindexStr is not declared in the const block below — confirm it exists.)
// VindexCols is set for AddColVindexStr
type DDL struct {
	// Action is one of the DDL strings below; it selects the Format layout.
	Action  string
	Table   TableName
	NewName TableName
	// IfExists adds " if exists" when formatting a DropStr action.
	IfExists bool
	// TableSpec, when non-nil, is appended to a CreateStr statement.
	TableSpec *TableSpec
	// PartitionSpec, when non-nil, is appended to an AlterStr statement.
	PartitionSpec *PartitionSpec
	VindexSpec    *VindexSpec
	VindexCols    []ColIdent
}

// DDL strings.
const (
	CreateStr        = "create"
	AlterStr         = "alter"
	DropStr          = "drop"
	RenameStr        = "rename"
	TruncateStr      = "truncate"
	CreateVindexStr  = "create vindex"
	AddColVindexStr  = "add vindex"
	DropColVindexStr = "drop vindex"

	// Vindex DDL param to specify the owner of a vindex
	VindexOwnerStr = "owner"
)
// Format writes the DDL statement to buf, choosing the layout by Action.
// Unrecognized actions (including TruncateStr) fall back to
// "<action> table <table>". The vindex actions dereference VindexSpec, so
// it must be non-nil for them.
func (node *DDL) Format(buf *TrackedBuffer) {
	switch node.Action {
	case CreateStr:
		if node.TableSpec != nil {
			buf.Myprintf("%s table %v %v", node.Action, node.NewName, node.TableSpec)
			return
		}
		buf.Myprintf("%s table %v", node.Action, node.NewName)
	case DropStr:
		var ifExists string
		if node.IfExists {
			ifExists = " if exists"
		}
		buf.Myprintf("%s table%s %v", node.Action, ifExists, node.Table)
	case RenameStr:
		buf.Myprintf("%s table %v to %v", node.Action, node.Table, node.NewName)
	case AlterStr:
		if node.PartitionSpec == nil {
			buf.Myprintf("%s table %v", node.Action, node.Table)
			return
		}
		buf.Myprintf("%s table %v %v", node.Action, node.Table, node.PartitionSpec)
	case CreateVindexStr:
		buf.Myprintf("%s %v %v", node.Action, node.VindexSpec.Name, node.VindexSpec)
	case AddColVindexStr:
		buf.Myprintf("alter table %v %s %v (", node.Table, node.Action, node.VindexSpec.Name)
		sep := ""
		for _, col := range node.VindexCols {
			buf.Myprintf("%s%v", sep, col)
			sep = ", "
		}
		buf.Myprintf(")")
		// The "using ..." suffix is only emitted when a vindex type was given.
		if node.VindexSpec.Type.String() != "" {
			buf.Myprintf(" %v", node.VindexSpec)
		}
	case DropColVindexStr:
		buf.Myprintf("alter table %v %s %v", node.Table, node.Action, node.VindexSpec.Name)
	default:
		buf.Myprintf("%s table %v", node.Action, node.Table)
	}
}

// walkSubtree visits the Table and NewName nodes only.
func (node *DDL) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Table, node.NewName)
	}
	return nil
}
// Partition strings
const (
	ReorganizeStr = "reorganize partition"
)

// PartitionSpec describes a partition action (for ALTER and CREATE).
type PartitionSpec struct {
	Action      string
	Name        ColIdent
	Definitions []*PartitionDefinition
}

// Format writes the partition action to buf. Only ReorganizeStr is
// supported; any other action panics with "unimplemented".
func (node *PartitionSpec) Format(buf *TrackedBuffer) {
	if node.Action != ReorganizeStr {
		panic("unimplemented")
	}
	buf.Myprintf("%s %v into (", node.Action, node.Name)
	for i, pd := range node.Definitions {
		if i > 0 {
			buf.Myprintf(", ")
		}
		buf.Myprintf("%v", pd)
	}
	buf.Myprintf(")")
}

// walkSubtree visits the partition name, then each definition in order.
func (node *PartitionSpec) walkSubtree(visit Visit) error {
	if node == nil {
		return nil
	}
	children := make([]SQLNode, 0, 1+len(node.Definitions))
	children = append(children, node.Name)
	for _, def := range node.Definitions {
		children = append(children, def)
	}
	return Walk(visit, children...)
}
// PartitionDefinition is a minimal "partition <name> values less than (...)"
// clause. When Maxvalue is set, Limit is ignored.
type PartitionDefinition struct {
	Name     ColIdent
	Limit    Expr
	Maxvalue bool
}

// Format writes the partition definition to buf.
func (node *PartitionDefinition) Format(buf *TrackedBuffer) {
	if node.Maxvalue {
		buf.Myprintf("partition %v values less than (maxvalue)", node.Name)
		return
	}
	buf.Myprintf("partition %v values less than (%v)", node.Name, node.Limit)
}

// walkSubtree visits the partition name and its limit expression.
func (node *PartitionDefinition) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Name, node.Limit)
	}
	return nil
}
// TableSpec holds the column and index definitions, plus the raw trailing
// table options, of a CREATE TABLE statement.
type TableSpec struct {
	Columns []*ColumnDefinition
	Indexes []*IndexDefinition
	Options string
}

// Format writes the parenthesized, one-per-line column and index list,
// followed by Options. Each ", " in Options becomes ",\n " so that each
// option starts on its own line.
func (ts *TableSpec) Format(buf *TrackedBuffer) {
	buf.Myprintf("(\n")
	for i, col := range ts.Columns {
		sep := ",\n\t"
		if i == 0 {
			sep = "\t"
		}
		buf.Myprintf("%s%v", sep, col)
	}
	for _, idx := range ts.Indexes {
		buf.Myprintf(",\n\t%v", idx)
	}
	opts := strings.Replace(ts.Options, ", ", ",\n ", -1)
	buf.Myprintf("\n)%s", opts)
}

// AddColumn appends cd to the spec's column list.
func (ts *TableSpec) AddColumn(cd *ColumnDefinition) {
	ts.Columns = append(ts.Columns, cd)
}

// AddIndex appends id to the spec's index list.
func (ts *TableSpec) AddIndex(id *IndexDefinition) {
	ts.Indexes = append(ts.Indexes, id)
}

// walkSubtree visits every column definition, then every index definition.
func (ts *TableSpec) walkSubtree(visit Visit) error {
	if ts == nil {
		return nil
	}
	children := make([]SQLNode, 0, len(ts.Columns)+len(ts.Indexes))
	for _, col := range ts.Columns {
		children = append(children, col)
	}
	for _, idx := range ts.Indexes {
		children = append(children, idx)
	}
	return Walk(visit, children...)
}
// ColumnDefinition is a single "<name> <type ...>" entry of CREATE TABLE.
type ColumnDefinition struct {
	Name ColIdent
	Type ColumnType
}

// Format writes the column name followed by its formatted type.
func (col *ColumnDefinition) Format(buf *TrackedBuffer) {
	typ := &col.Type
	buf.Myprintf("%v %v", col.Name, typ)
}

// walkSubtree visits the column name and its type.
func (col *ColumnDefinition) walkSubtree(visit Visit) error {
	if col != nil {
		return Walk(visit, col.Name, &col.Type)
	}
	return nil
}
// ColumnType represents a sql type in a CREATE TABLE statement.
// Pointer-typed options are nil, string options are "", and BoolVal options
// are false when not specified.
type ColumnType struct {
	// The base type string
	Type string

	// Generic field options.
	NotNull       BoolVal
	Autoincrement BoolVal
	Default       *SQLVal
	OnUpdate      *SQLVal
	Comment       *SQLVal

	// Numeric field options
	// Length and Scale render as "(Length,Scale)"; Scale is ignored
	// without Length (see Format).
	Length   *SQLVal
	Unsigned BoolVal
	Zerofill BoolVal
	Scale    *SQLVal

	// Text field options
	Charset string
	Collate string

	// Enum values
	// Written verbatim, comma-joined, inside parentheses by Format.
	EnumValues []string

	// Key specification
	// One of the colKey* constants; colKeyNone renders nothing.
	KeyOpt ColumnKeyOption
}
// Format writes a canonical rendering of the type: base type, optional
// "(length[,scale])", optional enum list, then space-separated options in a
// fixed order (unsigned, zerofill, charset, collate, not null, default,
// on update, auto_increment, comment, key option).
func (ct *ColumnType) Format(buf *TrackedBuffer) {
	buf.Myprintf("%s", ct.Type)

	switch {
	case ct.Length != nil && ct.Scale != nil:
		buf.Myprintf("(%v,%v)", ct.Length, ct.Scale)
	case ct.Length != nil:
		buf.Myprintf("(%v)", ct.Length)
	}

	if ct.EnumValues != nil {
		buf.Myprintf("(%s)", strings.Join(ct.EnumValues, ", "))
	}

	var opts []string
	add := func(words ...string) { opts = append(opts, words...) }

	if ct.Unsigned {
		add(keywordStrings[UNSIGNED])
	}
	if ct.Zerofill {
		add(keywordStrings[ZEROFILL])
	}
	if ct.Charset != "" {
		add(keywordStrings[CHARACTER], keywordStrings[SET], ct.Charset)
	}
	if ct.Collate != "" {
		add(keywordStrings[COLLATE], ct.Collate)
	}
	if ct.NotNull {
		add(keywordStrings[NOT], keywordStrings[NULL])
	}
	if ct.Default != nil {
		add(keywordStrings[DEFAULT], String(ct.Default))
	}
	if ct.OnUpdate != nil {
		add(keywordStrings[ON], keywordStrings[UPDATE], String(ct.OnUpdate))
	}
	if ct.Autoincrement {
		add(keywordStrings[AUTO_INCREMENT])
	}
	if ct.Comment != nil {
		add(keywordStrings[COMMENT_KEYWORD], String(ct.Comment))
	}

	// KeyOpt holds a single value, so at most one of these applies.
	switch ct.KeyOpt {
	case colKeyPrimary:
		add(keywordStrings[PRIMARY], keywordStrings[KEY])
	case colKeyUnique:
		add(keywordStrings[UNIQUE])
	case colKeyUniqueKey:
		add(keywordStrings[UNIQUE], keywordStrings[KEY])
	case colKeySpatialKey:
		add(keywordStrings[SPATIAL], keywordStrings[KEY])
	case colKey:
		add(keywordStrings[KEY])
	}

	if len(opts) > 0 {
		buf.Myprintf(" %s", strings.Join(opts, " "))
	}
}
// DescribeType returns the short type description used by DESCRIBE TABLE:
// the base type, optional "(length[,scale])", and the unsigned/zerofill
// flags. All other column options are omitted.
func (ct *ColumnType) DescribeType() string {
	buf := NewTrackedBuffer(nil)
	buf.Myprintf("%s", ct.Type)
	if ct.Length != nil {
		if ct.Scale != nil {
			buf.Myprintf("(%v,%v)", ct.Length, ct.Scale)
		} else {
			buf.Myprintf("(%v)", ct.Length)
		}
	}

	var flags []string
	if ct.Unsigned {
		flags = append(flags, keywordStrings[UNSIGNED])
	}
	if ct.Zerofill {
		flags = append(flags, keywordStrings[ZEROFILL])
	}
	if len(flags) > 0 {
		buf.Myprintf(" %s", strings.Join(flags, " "))
	}
	return buf.String()
}
// SQLType maps the column's base type string to a sqltypes type code.
// Integer types honor the Unsigned flag. An unrecognized type panics with
// "unimplemented type <type>".
func (ct *ColumnType) SQLType() querypb.Type {
	// pick returns the unsigned variant when the column is UNSIGNED.
	pick := func(signed, unsigned querypb.Type) querypb.Type {
		if ct.Unsigned {
			return unsigned
		}
		return signed
	}

	// Case lists keep the original lookup order, so the first match wins
	// exactly as before.
	switch ct.Type {
	case keywordStrings[TINYINT]:
		return pick(sqltypes.Int8, sqltypes.Uint8)
	case keywordStrings[SMALLINT]:
		return pick(sqltypes.Int16, sqltypes.Uint16)
	case keywordStrings[MEDIUMINT]:
		return pick(sqltypes.Int24, sqltypes.Uint24)
	case keywordStrings[INT], keywordStrings[INTEGER]:
		return pick(sqltypes.Int32, sqltypes.Uint32)
	case keywordStrings[BIGINT]:
		return pick(sqltypes.Int64, sqltypes.Uint64)
	case keywordStrings[TEXT], keywordStrings[TINYTEXT],
		keywordStrings[MEDIUMTEXT], keywordStrings[LONGTEXT]:
		return sqltypes.Text
	case keywordStrings[BLOB], keywordStrings[TINYBLOB],
		keywordStrings[MEDIUMBLOB], keywordStrings[LONGBLOB]:
		return sqltypes.Blob
	case keywordStrings[CHAR]:
		return sqltypes.Char
	case keywordStrings[VARCHAR]:
		return sqltypes.VarChar
	case keywordStrings[BINARY]:
		return sqltypes.Binary
	case keywordStrings[VARBINARY]:
		return sqltypes.VarBinary
	case keywordStrings[DATE]:
		return sqltypes.Date
	case keywordStrings[TIME]:
		return sqltypes.Time
	case keywordStrings[DATETIME]:
		return sqltypes.Datetime
	case keywordStrings[TIMESTAMP]:
		return sqltypes.Timestamp
	case keywordStrings[YEAR]:
		return sqltypes.Year
	case keywordStrings[FLOAT_TYPE]:
		return sqltypes.Float32
	case keywordStrings[DOUBLE]:
		return sqltypes.Float64
	case keywordStrings[DECIMAL]:
		return sqltypes.Decimal
	case keywordStrings[BIT]:
		return sqltypes.Bit
	case keywordStrings[ENUM]:
		return sqltypes.Enum
	case keywordStrings[SET]:
		return sqltypes.Set
	case keywordStrings[JSON]:
		return sqltypes.TypeJSON
	case keywordStrings[GEOMETRY], keywordStrings[POINT],
		keywordStrings[LINESTRING], keywordStrings[POLYGON],
		keywordStrings[GEOMETRYCOLLECTION], keywordStrings[MULTIPOINT],
		keywordStrings[MULTILINESTRING], keywordStrings[MULTIPOLYGON]:
		return sqltypes.Geometry
	}
	panic("unimplemented type " + ct.Type)
}

// walkSubtree is a no-op: ColumnType has no child nodes to visit.
func (ct *ColumnType) walkSubtree(visit Visit) error {
	return nil
}
// IndexDefinition is an index clause of a CREATE TABLE statement: index
// info, the indexed columns (with optional prefix lengths), and trailing
// options.
type IndexDefinition struct {
	Info    *IndexInfo
	Columns []*IndexColumn
	Options []*IndexOption
}

// Format writes "<info> (<col>[(len)], ...)" followed by each option. An
// option prints its Using text when set, otherwise its Value.
func (idx *IndexDefinition) Format(buf *TrackedBuffer) {
	buf.Myprintf("%v (", idx.Info)
	for i, col := range idx.Columns {
		if i > 0 {
			buf.Myprintf(", ")
		}
		buf.Myprintf("%v", col.Column)
		if col.Length != nil {
			buf.Myprintf("(%v)", col.Length)
		}
	}
	buf.Myprintf(")")

	for _, opt := range idx.Options {
		buf.Myprintf(" %s", opt.Name)
		if opt.Using == "" {
			buf.Myprintf(" %v", opt.Value)
		} else {
			buf.Myprintf(" %s", opt.Using)
		}
	}
}

// walkSubtree visits the indexed column names only. Info and Options are
// not visited.
func (idx *IndexDefinition) walkSubtree(visit Visit) error {
	if idx == nil {
		return nil
	}
	for i := range idx.Columns {
		if err := Walk(visit, idx.Columns[i].Column); err != nil {
			return err
		}
	}
	return nil
}
// IndexInfo describes the name and type of an index in a CREATE TABLE statement.
type IndexInfo struct {
	Type    string
	Name    ColIdent
	Primary bool
	Spatial bool
	Unique  bool
}

// Format writes the index type and, for non-primary indexes, the index name.
func (ii *IndexInfo) Format(buf *TrackedBuffer) {
	if ii.Primary {
		buf.Myprintf("%s", ii.Type)
	} else {
		buf.Myprintf("%s %v", ii.Type, ii.Name)
	}
}

// walkSubtree visits the index name. A nil receiver is a no-op, matching the
// other pointer-receiver walkSubtree methods in this package. Walk skips only
// untyped nil interfaces, so a (*IndexInfo)(nil) can reach this method.
// Without the guard it would panic dereferencing ii.Name.
func (ii *IndexInfo) walkSubtree(visit Visit) error {
	if ii == nil {
		return nil
	}
	return Walk(visit, ii.Name)
}
// IndexColumn describes a column in an index definition with optional length
// (rendered as "col(len)" by IndexDefinition.Format when Length is non-nil).
type IndexColumn struct {
	Column ColIdent
	Length *SQLVal
}

// LengthScaleOption is used for types that have an optional length
// and scale
type LengthScaleOption struct {
	Length *SQLVal
	Scale  *SQLVal
}

// IndexOption is used for trailing options for indexes: COMMENT, KEY_BLOCK_SIZE, USING.
// IndexDefinition.Format prints Name, then Using if non-empty, otherwise Value.
type IndexOption struct {
	Name  string
	Value *SQLVal
	Using string
}

// ColumnKeyOption indicates whether or not the given column is defined as an
// index element and contains the type of the option
type ColumnKeyOption int

// ColumnKeyOption values. ColumnType.Format renders them as:
// colKeyPrimary "primary key", colKeySpatialKey "spatial key",
// colKeyUnique "unique", colKeyUniqueKey "unique key", colKey "key"
// (keyword text taken from keywordStrings); colKeyNone renders nothing.
// The iota order must not change.
const (
	colKeyNone ColumnKeyOption = iota
	colKeyPrimary
	colKeySpatialKey
	colKeyUnique
	colKeyUniqueKey
	colKey
)
// VindexSpec defines a vindex for a CREATE VINDEX or DROP VINDEX statement.
type VindexSpec struct {
	Name   ColIdent
	Type   ColIdent
	Params []VindexParam
}

// ParseParams splits the vindex parameter list into the special "owner"
// parameter (a key whose Lowered() form equals VindexOwnerStr) and a map of
// all other parameters keyed by Key.String(). If "owner" appears more than
// once, the last value wins.
func (node *VindexSpec) ParseParams() (string, map[string]string) {
	var owner string
	params := map[string]string{}
	for _, p := range node.Params {
		if p.Key.Lowered() == VindexOwnerStr {
			owner = p.Val
		} else {
			params[p.Key.String()] = p.Val
		}
	}
	return owner, params
}

// Format writes "using <type>" and, if there are parameters,
// " with k=v, k=v, ...". The "CREATE VINDEX" preamble is written by the
// enclosing DDL node's Format. An "owner" parameter, if present, is printed
// like any other parameter.
func (node *VindexSpec) Format(buf *TrackedBuffer) {
	buf.Myprintf("using %v", node.Type)

	numParams := len(node.Params)
	if numParams != 0 {
		buf.Myprintf(" with ")
		for i, p := range node.Params {
			if i != 0 {
				buf.Myprintf(", ")
			}
			buf.Myprintf("%v", p)
		}
	}
}

// walkSubtree visits the vindex name, then each parameter. The Type is not
// visited. A nil receiver is a no-op, matching the other pointer-receiver
// walkSubtree methods in this package. Walk does not skip typed nil pointers.
func (node *VindexSpec) walkSubtree(visit Visit) error {
	if node == nil {
		return nil
	}
	if err := Walk(visit, node.Name); err != nil {
		return err
	}
	for _, p := range node.Params {
		if err := Walk(visit, p); err != nil {
			return err
		}
	}
	return nil
}
// VindexParam is a single "key=value" parameter of a CREATE VINDEX statement.
type VindexParam struct {
	Key ColIdent
	Val string
}

// Format writes the parameter as "key=value", with no quoting of the value.
func (node VindexParam) Format(buf *TrackedBuffer) {
	key := node.Key.String()
	buf.Myprintf("%s=%s", key, node.Val)
}

// walkSubtree visits the parameter key. The value is a plain string.
func (node VindexParam) walkSubtree(visit Visit) error {
	return Walk(visit, node.Key)
}
// Show represents a SHOW statement.
type Show struct {
	Type          string
	OnTable       TableName
	ShowTablesOpt *ShowTablesOpt
	Scope         string
}

// Format writes the SHOW statement. "show tables" with options gets its own
// layout ("show [extended ][full ]tables[ from db][ filter]"). Every other
// form is "show [<scope> ]<type>[ on <table>]".
func (node *Show) Format(buf *TrackedBuffer) {
	if node.Type == "tables" && node.ShowTablesOpt != nil {
		opt := node.ShowTablesOpt
		buf.Myprintf("show %s%stables", opt.Extended, opt.Full)
		if opt.DbName != "" {
			buf.Myprintf(" from %s", opt.DbName)
		}
		if opt.Filter != nil {
			buf.Myprintf(" %v", opt.Filter)
		}
		return
	}

	scope := ""
	if node.Scope != "" {
		scope = node.Scope + " "
	}
	buf.Myprintf("show %s%s", scope, node.Type)
	if node.HasOnTable() {
		buf.Myprintf(" on %v", node.OnTable)
	}
}

// HasOnTable reports whether OnTable has a non-empty table name, i.e.
// whether the statement carries an "on <table>" clause.
func (node *Show) HasOnTable() bool {
	name := node.OnTable.Name.v
	return name != ""
}

// walkSubtree is a no-op: child nodes of SHOW are not visited.
func (node *Show) walkSubtree(visit Visit) error {
	return nil
}
// ShowTablesOpt holds the modifiers of a SHOW TABLES statement. Extended
// and Full are written verbatim before "tables".
type ShowTablesOpt struct {
	Extended string
	Full     string
	DbName   string
	Filter   *ShowFilter
}

// ShowFilter is the LIKE pattern or WHERE expression of a SHOW TABLES
// statement. Like takes precedence when non-empty.
type ShowFilter struct {
	Like   string
	Filter Expr
}

// Format writes "like '<pattern>'" or "where <expr>".
// NOTE(review): Like is written between single quotes without escaping, so
// a pattern containing a quote yields invalid SQL.
func (node *ShowFilter) Format(buf *TrackedBuffer) {
	if node.Like == "" {
		buf.Myprintf("where %v", node.Filter)
		return
	}
	buf.Myprintf("like '%s'", node.Like)
}

// walkSubtree is a no-op: the filter expression is not visited.
func (node *ShowFilter) walkSubtree(visit Visit) error {
	return nil
}
// Use represents a USE statement. An empty DBName is written as bare "use".
type Use struct {
	DBName TableIdent
}

// Format writes "use" or "use <db>".
func (node *Use) Format(buf *TrackedBuffer) {
	if node.DBName.v == "" {
		buf.Myprintf("use")
		return
	}
	buf.Myprintf("use %v", node.DBName)
}

// walkSubtree visits the database identifier.
func (node *Use) walkSubtree(visit Visit) error {
	return Walk(visit, node.DBName)
}
// Begin represents a BEGIN statement. It carries no fields.
type Begin struct{}

// Format writes the literal "begin".
func (*Begin) Format(buf *TrackedBuffer) {
	buf.WriteString("begin")
}

// walkSubtree is a no-op: Begin has no children.
func (*Begin) walkSubtree(visit Visit) error {
	return nil
}

// Commit represents a COMMIT statement. It carries no fields.
type Commit struct{}

// Format writes the literal "commit".
func (*Commit) Format(buf *TrackedBuffer) {
	buf.WriteString("commit")
}

// walkSubtree is a no-op: Commit has no children.
func (*Commit) walkSubtree(visit Visit) error {
	return nil
}

// Rollback represents a ROLLBACK statement. It carries no fields.
type Rollback struct{}

// Format writes the literal "rollback".
func (*Rollback) Format(buf *TrackedBuffer) {
	buf.WriteString("rollback")
}

// walkSubtree is a no-op: Rollback has no children.
func (*Rollback) walkSubtree(visit Visit) error {
	return nil
}
// OtherRead is a placeholder for read-only statements such as DESCRIBE or
// EXPLAIN. It only marks the statement kind and keeps no AST, so Format
// emits a fixed token rather than the original SQL.
type OtherRead struct{}

// Format writes the placeholder token "otherread".
func (*OtherRead) Format(buf *TrackedBuffer) {
	buf.WriteString("otherread")
}

// walkSubtree is a no-op: OtherRead has no children.
func (*OtherRead) walkSubtree(visit Visit) error {
	return nil
}

// OtherAdmin is a placeholder for miscellaneous statements that need ADMIN
// privileges, such as REPAIR, OPTIMIZE or TRUNCATE. Like OtherRead, it
// keeps no AST.
type OtherAdmin struct{}

// Format writes the placeholder token "otheradmin".
func (*OtherAdmin) Format(buf *TrackedBuffer) {
	buf.WriteString("otheradmin")
}

// walkSubtree is a no-op: OtherAdmin has no children.
func (*OtherAdmin) walkSubtree(visit Visit) error {
	return nil
}
// Comments is the list of raw comment texts attached to a statement.
type Comments [][]byte

// Format writes each comment followed by a single space.
func (node Comments) Format(buf *TrackedBuffer) {
	for i := range node {
		buf.Myprintf("%s ", node[i])
	}
}

// walkSubtree is a no-op: comments have no child nodes.
func (Comments) walkSubtree(visit Visit) error {
	return nil
}
// SelectExprs is the list of expressions in a SELECT clause.
type SelectExprs []SelectExpr

// Format writes the expressions separated by ", ".
func (node SelectExprs) Format(buf *TrackedBuffer) {
	sep := ""
	for i := range node {
		buf.Myprintf("%s%v", sep, node[i])
		sep = ", "
	}
}

// walkSubtree walks each expression in order and stops at the first error.
func (node SelectExprs) walkSubtree(visit Visit) error {
	for i := range node {
		err := Walk(visit, node[i])
		if err != nil {
			return err
		}
	}
	return nil
}
// SelectExpr is implemented by every node that can be an item of a SELECT
// expression list: *StarExpr, *AliasedExpr and Nextval.
type SelectExpr interface {
	SQLNode
	iSelectExpr()
}

func (Nextval) iSelectExpr()      {}
func (*AliasedExpr) iSelectExpr() {}
func (*StarExpr) iSelectExpr()    {}
// StarExpr is a '*' or 'table.*' item in a SELECT list. An empty TableName
// means a bare '*'.
type StarExpr struct {
	TableName TableName
}

// Format writes "*", prefixed by "<table>." when a table name is present.
func (node *StarExpr) Format(buf *TrackedBuffer) {
	if node.TableName.IsEmpty() {
		buf.Myprintf("*")
		return
	}
	buf.Myprintf("%v.*", node.TableName)
}

// walkSubtree walks the table name; a nil *StarExpr has no children.
func (node *StarExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.TableName)
	}
	return nil
}
// AliasedExpr is a SELECT expression with an optional alias. An empty As
// means no alias was given.
type AliasedExpr struct {
	Expr Expr
	As   ColIdent
}

// Format writes the expression, followed by " as <alias>" when As is set.
func (node *AliasedExpr) Format(buf *TrackedBuffer) {
	if node.As.IsEmpty() {
		buf.Myprintf("%v", node.Expr)
		return
	}
	buf.Myprintf("%v as %v", node.Expr, node.As)
}

// walkSubtree walks the expression, then the alias. A nil *AliasedExpr has
// no children.
func (node *AliasedExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Expr, node.As)
	}
	return nil
}
// Nextval is the NEXT VALUE select expression; Expr is the count of values.
type Nextval struct {
	Expr Expr
}

// Format writes "next <expr> values".
func (node Nextval) Format(buf *TrackedBuffer) {
	count := node.Expr
	buf.Myprintf("next %v values", count)
}

// walkSubtree walks the single Expr child.
func (node Nextval) walkSubtree(visit Visit) error {
	count := node.Expr
	return Walk(visit, count)
}
// Columns represents a column list, such as the column list of an INSERT.
type Columns []ColIdent

// Format formats the node as "(a, b, ...)". A nil list prints nothing; an
// empty but non-nil list prints "()" (previously a lone ")" was emitted
// because the opening parenthesis was only written before the first item).
func (node Columns) Format(buf *TrackedBuffer) {
	if node == nil {
		return
	}
	buf.WriteString("(")
	for i, n := range node {
		if i > 0 {
			buf.WriteString(", ")
		}
		buf.Myprintf("%v", n)
	}
	buf.WriteString(")")
}

// walkSubtree walks each column identifier in order and stops at the first
// error.
func (node Columns) walkSubtree(visit Visit) error {
	for _, n := range node {
		if err := Walk(visit, n); err != nil {
			return err
		}
	}
	return nil
}

// FindColumn finds a column in the column list, returning the index of the
// first entry for which ColIdent.Equal reports true, or -1 otherwise.
func (node Columns) FindColumn(col ColIdent) int {
	for i, colName := range node {
		if colName.Equal(col) {
			return i
		}
	}
	return -1
}
// Partitions is a list of partition names. It is a distinct named type with
// Columns as its underlying type so it can have its own Format method.
type Partitions Columns

// Format formats the node as " partition (p1, p2, ...)". A nil list prints
// nothing; an empty but non-nil list prints " partition ()" rather than the
// malformed " partition )"-less ")" that writing "(" only before the first
// item would produce.
func (node Partitions) Format(buf *TrackedBuffer) {
	if node == nil {
		return
	}
	buf.WriteString(" partition (")
	for i, n := range node {
		if i > 0 {
			buf.WriteString(", ")
		}
		buf.Myprintf("%v", n)
	}
	buf.WriteString(")")
}

// walkSubtree walks each partition identifier in order and stops at the
// first error.
func (node Partitions) walkSubtree(visit Visit) error {
	for _, n := range node {
		if err := Walk(visit, n); err != nil {
			return err
		}
	}
	return nil
}
// TableExprs is a list of table expressions, e.g. the items of a FROM
// clause.
type TableExprs []TableExpr

// Format writes the table expressions separated by ", ".
func (node TableExprs) Format(buf *TrackedBuffer) {
	for i, te := range node {
		sep := ", "
		if i == 0 {
			sep = ""
		}
		buf.Myprintf("%s%v", sep, te)
	}
}

// walkSubtree walks each table expression in order and stops at the first
// error.
func (node TableExprs) walkSubtree(visit Visit) error {
	for i := range node {
		if err := Walk(visit, node[i]); err != nil {
			return err
		}
	}
	return nil
}
// TableExpr is implemented by every node that can appear as a table
// expression: *AliasedTableExpr, *ParenTableExpr and *JoinTableExpr.
type TableExpr interface {
	SQLNode
	iTableExpr()
}

func (*JoinTableExpr) iTableExpr()    {}
func (*ParenTableExpr) iTableExpr()   {}
func (*AliasedTableExpr) iTableExpr() {}
// AliasedTableExpr is a simple table expression with an optional partition
// list, alias and index hints. An empty As means no alias was used; nil
// Hints means no index hints were given.
type AliasedTableExpr struct {
	Expr       SimpleTableExpr
	Partitions Partitions
	As         TableIdent
	Hints      *IndexHints
}

// Format formats the node.
func (node *AliasedTableExpr) Format(buf *TrackedBuffer) {
	// Partitions.Format supplies its own leading space, or prints nothing.
	buf.Myprintf("%v%v", node.Expr, node.Partitions)
	if !node.As.IsEmpty() {
		buf.Myprintf(" as %v", node.As)
	}
	if node.Hints != nil {
		// Hint node provides the space padding.
		buf.Myprintf("%v", node.Hints)
	}
}

// walkSubtree walks every child node in field order, including the
// partition list, so visitors see each child that Format prints.
func (node *AliasedTableExpr) walkSubtree(visit Visit) error {
	if node == nil {
		return nil
	}
	return Walk(
		visit,
		node.Expr,
		node.Partitions,
		node.As,
		node.Hints,
	)
}

// RemoveHints returns a shallow copy of node with Hints set to nil. The
// receiver is not modified.
func (node *AliasedTableExpr) RemoveHints() *AliasedTableExpr {
	noHints := *node
	noHints.Hints = nil
	return &noHints
}
// SimpleTableExpr is implemented by the nodes that can be the Expr of an
// AliasedTableExpr: a TableName or a *Subquery.
type SimpleTableExpr interface {
	SQLNode
	iSimpleTableExpr()
}

func (*Subquery) iSimpleTableExpr() {}
func (TableName) iSimpleTableExpr() {}
// TableNames is a list of TableName values.
type TableNames []TableName

// Format writes the table names separated by ", ".
func (node TableNames) Format(buf *TrackedBuffer) {
	for i, name := range node {
		if i == 0 {
			buf.Myprintf("%s%v", "", name)
			continue
		}
		buf.Myprintf("%s%v", ", ", name)
	}
}

// walkSubtree walks each table name in order and stops at the first error.
func (node TableNames) walkSubtree(visit Visit) error {
	for i := range node {
		if err := Walk(visit, node[i]); err != nil {
			return err
		}
	}
	return nil
}
// TableName is a table name with an optional Qualifier naming a database or
// keyspace. It is a value struct whose fields are case sensitive, so two
// TableNames can be compared with == and a TableName can be a map key.
type TableName struct {
	Name, Qualifier TableIdent
}

// Format writes "[qualifier.]name". An empty TableName writes nothing.
func (node TableName) Format(buf *TrackedBuffer) {
	switch {
	case node.IsEmpty():
		return
	case node.Qualifier.IsEmpty():
		buf.Myprintf("%v", node.Name)
	default:
		buf.Myprintf("%v.%v", node.Qualifier, node.Name)
	}
}

// walkSubtree walks Name, then Qualifier.
func (node TableName) walkSubtree(visit Visit) error {
	return Walk(visit, node.Name, node.Qualifier)
}

// IsEmpty reports whether the TableName is empty. Only Name is checked: by
// the convention of this type an empty Name implies an empty Qualifier.
func (node TableName) IsEmpty() bool {
	return node.Name.IsEmpty()
}

// ToViewName returns a TableName suitable for use as a VIEW name. View
// names are always lowercase, so Name is lowercased; databases are case
// sensitive, so Qualifier is left untouched.
func (node TableName) ToViewName() TableName {
	lowered := strings.ToLower(node.Name.v)
	return TableName{
		Name:      NewTableIdent(lowered),
		Qualifier: node.Qualifier,
	}
}
// ParenTableExpr is a parenthesized list of table expressions.
type ParenTableExpr struct {
	Exprs TableExprs
}

// Format writes the wrapped list inside parentheses.
func (node *ParenTableExpr) Format(buf *TrackedBuffer) {
	inner := node.Exprs
	buf.Myprintf("(%v)", inner)
}

// walkSubtree walks the wrapped list; a nil *ParenTableExpr has no children.
func (node *ParenTableExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Exprs)
	}
	return nil
}
// JoinCondition holds the condition of a JoinTableExpr: an ON expression, a
// USING column list, or neither.
type JoinCondition struct {
	On    Expr
	Using Columns
}

// Format writes " on <expr>" and/or " using (<cols>)" for whichever parts
// are set, in that order.
func (node JoinCondition) Format(buf *TrackedBuffer) {
	hasOn, hasUsing := node.On != nil, node.Using != nil
	if hasOn {
		buf.Myprintf(" on %v", node.On)
	}
	if hasUsing {
		buf.Myprintf(" using %v", node.Using)
	}
}

// walkSubtree walks On, then Using.
func (node JoinCondition) walkSubtree(visit Visit) error {
	return Walk(visit, node.On, node.Using)
}
// JoinTableExpr is a table expression that joins LeftExpr and RightExpr.
// Join holds the join keyword text and Condition the optional ON/USING
// clause.
type JoinTableExpr struct {
	LeftExpr  TableExpr
	Join      string
	RightExpr TableExpr
	Condition JoinCondition
}

// JoinTableExpr.Join values.
const (
	JoinStr             = "join"
	StraightJoinStr     = "straight_join"
	LeftJoinStr         = "left join"
	RightJoinStr        = "right join"
	NaturalJoinStr      = "natural join"
	NaturalLeftJoinStr  = "natural left join"
	NaturalRightJoinStr = "natural right join"
)

// Format writes "<left> <join> <right>" followed by the condition, which
// supplies its own leading space.
func (node *JoinTableExpr) Format(buf *TrackedBuffer) {
	buf.Myprintf("%v %s ", node.LeftExpr, node.Join)
	buf.Myprintf("%v%v", node.RightExpr, node.Condition)
}

// walkSubtree walks the left side, the right side, then the condition. A nil
// *JoinTableExpr has no children.
func (node *JoinTableExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.LeftExpr, node.RightExpr, node.Condition)
	}
	return nil
}
// IndexHints represents a USE/IGNORE/FORCE INDEX hint. Type is one of the
// UseStr/IgnoreStr/ForceStr constants (each includes a trailing space).
type IndexHints struct {
	Type    string
	Indexes []ColIdent
}

// Index hints.
const (
	UseStr    = "use "
	IgnoreStr = "ignore "
	ForceStr  = "force "
)

// Format formats the node as " <type>index (a, b, ...)". The opening
// parenthesis is written unconditionally, so an empty Indexes list yields
// "()" instead of an unbalanced ")".
func (node *IndexHints) Format(buf *TrackedBuffer) {
	buf.Myprintf(" %sindex (", node.Type)
	for i, n := range node.Indexes {
		if i > 0 {
			buf.WriteString(", ")
		}
		buf.Myprintf("%v", n)
	}
	buf.WriteString(")")
}

// walkSubtree walks each index identifier in order and stops at the first
// error. A nil *IndexHints has no children.
func (node *IndexHints) walkSubtree(visit Visit) error {
	if node == nil {
		return nil
	}
	for _, n := range node.Indexes {
		if err := Walk(visit, n); err != nil {
			return err
		}
	}
	return nil
}
// Where is a WHERE or HAVING clause; Type says which.
type Where struct {
	Type string
	Expr Expr
}

// Where.Type values.
const (
	WhereStr  = "where"
	HavingStr = "having"
)

// NewWhere builds a WHERE or HAVING clause from an expression. It returns
// nil when expr is nil, so an absent clause is represented by a nil *Where.
func NewWhere(typ string, expr Expr) *Where {
	if expr != nil {
		return &Where{Type: typ, Expr: expr}
	}
	return nil
}

// Format writes " <type> <expr>". A nil clause, or one with a nil Expr,
// writes nothing.
func (node *Where) Format(buf *TrackedBuffer) {
	if node != nil && node.Expr != nil {
		buf.Myprintf(" %s %v", node.Type, node.Expr)
	}
}

// walkSubtree walks the condition expression; a nil *Where has no children.
func (node *Where) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Expr)
	}
	return nil
}
// Expr is implemented by every node that can appear in an expression
// position.
type Expr interface {
	SQLNode
	iExpr()
	// replace substitutes to for any subexpression that matches from and
	// reports whether a substitution was made. Implementations can use the
	// replaceExprs convenience function.
	replace(from, to Expr) bool
}

// Logical operators and predicates.
func (*AndExpr) iExpr()        {}
func (*OrExpr) iExpr()         {}
func (*NotExpr) iExpr()        {}
func (*ParenExpr) iExpr()      {}
func (*ComparisonExpr) iExpr() {}
func (*RangeCond) iExpr()      {}
func (*IsExpr) iExpr()         {}
func (*ExistsExpr) iExpr()     {}

// Literals, references and placeholders.
func (*SQLVal) iExpr()   {}
func (*NullVal) iExpr()  {}
func (BoolVal) iExpr()   {}
func (*ColName) iExpr()  {}
func (ValTuple) iExpr()  {}
func (*Subquery) iExpr() {}
func (ListArg) iExpr()   {}
func (*Default) iExpr()  {}

// Operators and function-like expressions.
func (*BinaryExpr) iExpr()       {}
func (*UnaryExpr) iExpr()        {}
func (*IntervalExpr) iExpr()     {}
func (*CollateExpr) iExpr()      {}
func (*FuncExpr) iExpr()         {}
func (*GroupConcatExpr) iExpr()  {}
func (*CaseExpr) iExpr()         {}
func (*ValuesFuncExpr) iExpr()   {}
func (*ConvertExpr) iExpr()      {}
func (*ConvertUsingExpr) iExpr() {}
func (*SubstrExpr) iExpr()       {}
func (*MatchExpr) iExpr()        {}
// ReplaceExpr finds the from expression inside root and replaces it with
// to. If from is root itself, to is returned; otherwise root is modified in
// place and returned.
func ReplaceExpr(root, from, to Expr) Expr {
	if isSameExpr(root, from) {
		return to
	}
	root.replace(from, to)
	return root
}

// replaceExprs is a convenience function used by implementors of the
// replace method. It checks each non-nil *expr in order: if it is from, it
// is overwritten with to; otherwise replacement is attempted recursively.
// It returns true as soon as one replacement has been made.
func replaceExprs(from, to Expr, exprs ...*Expr) bool {
	for _, expr := range exprs {
		if *expr == nil {
			continue
		}
		if isSameExpr(*expr, from) {
			*expr = to
			return true
		}
		if (*expr).replace(from, to) {
			return true
		}
	}
	return false
}

// isSameExpr reports whether a and b are the same expression node.
//
// A plain a == b panics at run time when both interface values hold the same
// uncomparable dynamic type. ValTuple and ListArg are slice types, so for
// them identity is taken to mean "same length and same non-empty backing
// array". Every other case falls back to ==, which never panics when the
// dynamic types differ.
func isSameExpr(a, b Expr) bool {
	switch x := a.(type) {
	case ValTuple:
		y, ok := b.(ValTuple)
		return ok && len(x) > 0 && len(x) == len(y) && &x[0] == &y[0]
	case ListArg:
		y, ok := b.(ListArg)
		return ok && len(x) > 0 && len(x) == len(y) && &x[0] == &y[0]
	}
	return a == b
}
// Exprs is a list of value expressions. It is not itself a valid expression
// because it is not parenthesized.
type Exprs []Expr

// Format writes the expressions separated by ", ".
func (node Exprs) Format(buf *TrackedBuffer) {
	sep := ""
	for i := range node {
		buf.Myprintf("%s%v", sep, node[i])
		sep = ", "
	}
}

// walkSubtree walks each expression in order and stops at the first error.
func (node Exprs) walkSubtree(visit Visit) error {
	for i := range node {
		if err := Walk(visit, node[i]); err != nil {
			return err
		}
	}
	return nil
}
// AndExpr is a logical AND of Left and Right.
type AndExpr struct {
	Left, Right Expr
}

// Format writes "<left> and <right>".
func (node *AndExpr) Format(buf *TrackedBuffer) {
	lhs, rhs := node.Left, node.Right
	buf.Myprintf("%v and %v", lhs, rhs)
}

// walkSubtree walks Left, then Right; a nil *AndExpr has no children.
func (node *AndExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Left, node.Right)
	}
	return nil
}

func (node *AndExpr) replace(from, to Expr) bool {
	return replaceExprs(from, to, &node.Left, &node.Right)
}
// OrExpr is a logical OR of Left and Right.
type OrExpr struct {
	Left, Right Expr
}

// Format writes "<left> or <right>".
func (node *OrExpr) Format(buf *TrackedBuffer) {
	lhs, rhs := node.Left, node.Right
	buf.Myprintf("%v or %v", lhs, rhs)
}

// walkSubtree walks Left, then Right; a nil *OrExpr has no children.
func (node *OrExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Left, node.Right)
	}
	return nil
}

func (node *OrExpr) replace(from, to Expr) bool {
	return replaceExprs(from, to, &node.Left, &node.Right)
}
// NotExpr is a logical NOT applied to Expr.
type NotExpr struct {
	Expr Expr
}

// Format writes "not <expr>".
func (node *NotExpr) Format(buf *TrackedBuffer) {
	operand := node.Expr
	buf.Myprintf("not %v", operand)
}

// walkSubtree walks the operand; a nil *NotExpr has no children.
func (node *NotExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Expr)
	}
	return nil
}

func (node *NotExpr) replace(from, to Expr) bool {
	return replaceExprs(from, to, &node.Expr)
}
// ParenExpr is an expression wrapped in parentheses.
type ParenExpr struct {
	Expr Expr
}

// Format writes "(<expr>)".
func (node *ParenExpr) Format(buf *TrackedBuffer) {
	inner := node.Expr
	buf.Myprintf("(%v)", inner)
}

// walkSubtree walks the wrapped expression; a nil *ParenExpr has no
// children.
func (node *ParenExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Expr)
	}
	return nil
}

func (node *ParenExpr) replace(from, to Expr) bool {
	return replaceExprs(from, to, &node.Expr)
}
// ComparisonExpr compares Left with Right using Operator. Escape, when
// non-nil, is printed as an ESCAPE clause after the comparison.
type ComparisonExpr struct {
	Operator    string
	Left, Right Expr
	Escape      Expr
}

// ComparisonExpr.Operator values.
const (
	EqualStr             = "="
	LessThanStr          = "<"
	GreaterThanStr       = ">"
	LessEqualStr         = "<="
	GreaterEqualStr      = ">="
	NotEqualStr          = "!="
	NullSafeEqualStr     = "<=>"
	InStr                = "in"
	NotInStr             = "not in"
	LikeStr              = "like"
	NotLikeStr           = "not like"
	RegexpStr            = "regexp"
	NotRegexpStr         = "not regexp"
	JSONExtractOp        = "->"
	JSONUnquoteExtractOp = "->>"
)

// Format writes "<left> <op> <right>", then " escape <expr>" if Escape is
// set.
func (node *ComparisonExpr) Format(buf *TrackedBuffer) {
	buf.Myprintf("%v %s %v", node.Left, node.Operator, node.Right)
	if escape := node.Escape; escape != nil {
		buf.Myprintf(" escape %v", escape)
	}
}

// walkSubtree walks Left, Right, then Escape; a nil *ComparisonExpr has no
// children.
func (node *ComparisonExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Left, node.Right, node.Escape)
	}
	return nil
}

func (node *ComparisonExpr) replace(from, to Expr) bool {
	return replaceExprs(from, to, &node.Left, &node.Right, &node.Escape)
}
// RangeCond is a BETWEEN or NOT BETWEEN test of Left against From..To.
type RangeCond struct {
	Operator string
	Left     Expr
	From, To Expr
}

// RangeCond.Operator values.
const (
	BetweenStr    = "between"
	NotBetweenStr = "not between"
)

// Format writes "<left> <op> <from> and <to>".
func (node *RangeCond) Format(buf *TrackedBuffer) {
	lo, hi := node.From, node.To
	buf.Myprintf("%v %s %v and %v", node.Left, node.Operator, lo, hi)
}

// walkSubtree walks Left, From, then To; a nil *RangeCond has no children.
func (node *RangeCond) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Left, node.From, node.To)
	}
	return nil
}

func (node *RangeCond) replace(from, to Expr) bool {
	return replaceExprs(from, to, &node.Left, &node.From, &node.To)
}
// IsExpr is an "IS ..." or "IS NOT ..." test applied to Expr.
type IsExpr struct {
	Operator string
	Expr     Expr
}

// IsExpr.Operator values.
const (
	IsNullStr     = "is null"
	IsNotNullStr  = "is not null"
	IsTrueStr     = "is true"
	IsNotTrueStr  = "is not true"
	IsFalseStr    = "is false"
	IsNotFalseStr = "is not false"
)

// Format writes "<expr> <op>".
func (node *IsExpr) Format(buf *TrackedBuffer) {
	operand, op := node.Expr, node.Operator
	buf.Myprintf("%v %s", operand, op)
}

// walkSubtree walks the tested expression; a nil *IsExpr has no children.
func (node *IsExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Expr)
	}
	return nil
}

func (node *IsExpr) replace(from, to Expr) bool {
	return replaceExprs(from, to, &node.Expr)
}
// ExistsExpr is an EXISTS test over a subquery.
type ExistsExpr struct {
	Subquery *Subquery
}

// Format writes "exists <subquery>".
func (node *ExistsExpr) Format(buf *TrackedBuffer) {
	sub := node.Subquery
	buf.Myprintf("exists %v", sub)
}

// walkSubtree walks the subquery; a nil *ExistsExpr has no children.
func (node *ExistsExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Subquery)
	}
	return nil
}

// replace never substitutes anything: it does not descend into the
// subquery.
func (node *ExistsExpr) replace(from, to Expr) bool {
	return false
}
// ExprFromValue converts a sqltypes.Value into the matching AST literal:
// NULL becomes *NullVal, integral values IntVal, float and decimal values
// FloatVal, quoted values StrVal. Any other type yields an error.
func ExprFromValue(value sqltypes.Value) (Expr, error) {
	// The type checks here follow the rules defined in sqltypes/types.go.
	typ := value.Type()
	if typ == sqltypes.Null {
		return &NullVal{}, nil
	}
	if value.IsIntegral() {
		return NewIntVal(value.ToBytes()), nil
	}
	if value.IsFloat() || typ == sqltypes.Decimal {
		return NewFloatVal(value.ToBytes()), nil
	}
	if value.IsQuoted() {
		return NewStrVal(value.ToBytes()), nil
	}
	// sqltypes.Expression, or any other invalid type, cannot be supported.
	return nil, fmt.Errorf("cannot convert value %v to AST", value)
}
// ValType identifies the kind of literal held by a SQLVal.
type ValType int

// The possible ValType values. HexNum is a 0x... literal; it is kept apart
// from the other numeric kinds because it can be interpreted differently
// depending on context.
const (
	StrVal ValType = iota
	IntVal
	FloatVal
	HexNum
	HexVal
	ValArg
	BitVal
)
// SQLVal is a single literal value or bind argument; Type says how Val is
// to be interpreted and formatted.
type SQLVal struct {
	Type ValType
	Val  []byte
}

// NewStrVal builds a new StrVal.
func NewStrVal(in []byte) *SQLVal { return &SQLVal{Type: StrVal, Val: in} }

// NewIntVal builds a new IntVal.
func NewIntVal(in []byte) *SQLVal { return &SQLVal{Type: IntVal, Val: in} }

// NewFloatVal builds a new FloatVal.
func NewFloatVal(in []byte) *SQLVal { return &SQLVal{Type: FloatVal, Val: in} }

// NewHexNum builds a new HexNum.
func NewHexNum(in []byte) *SQLVal { return &SQLVal{Type: HexNum, Val: in} }

// NewHexVal builds a new HexVal.
func NewHexVal(in []byte) *SQLVal { return &SQLVal{Type: HexVal, Val: in} }

// NewBitVal builds a new BitVal containing a bit literal.
func NewBitVal(in []byte) *SQLVal { return &SQLVal{Type: BitVal, Val: in} }

// NewValArg builds a new ValArg.
func NewValArg(in []byte) *SQLVal { return &SQLVal{Type: ValArg, Val: in} }

// Format writes the value according to its Type: strings are SQL-encoded
// via sqltypes, numbers are written verbatim, hex and bit literals get
// their X'...' / B'...' wrappers, and bind arguments go through WriteArg.
// It panics on an unknown Type.
func (node *SQLVal) Format(buf *TrackedBuffer) {
	raw := node.Val
	switch node.Type {
	case IntVal, FloatVal, HexNum:
		buf.Myprintf("%s", raw)
	case StrVal:
		sqltypes.MakeTrusted(sqltypes.VarBinary, raw).EncodeSQL(buf)
	case HexVal:
		buf.Myprintf("X'%s'", raw)
	case BitVal:
		buf.Myprintf("B'%s'", raw)
	case ValArg:
		buf.WriteArg(string(raw))
	default:
		panic("unexpected")
	}
}

// walkSubtree is a no-op: a literal has no child nodes.
func (node *SQLVal) walkSubtree(visit Visit) error { return nil }

// replace is a no-op: a literal has no subexpressions.
func (node *SQLVal) replace(from, to Expr) bool { return false }
// HexDecode decodes the hex digits in Val into raw bytes. It returns the
// error from encoding/hex (e.g. for odd length or a non-hex character) and a
// nil slice if Val is not valid hex.
func (node *SQLVal) HexDecode() ([]byte, error) {
	// Val is already a []byte, so no conversion is needed.
	dst := make([]byte, hex.DecodedLen(len(node.Val)))
	if _, err := hex.Decode(dst, node.Val); err != nil {
		return nil, err
	}
	return dst, nil
}
// NullVal is the NULL literal.
type NullVal struct{}

// Format writes "null".
func (node *NullVal) Format(buf *TrackedBuffer) {
	buf.Myprintf("null")
}

// walkSubtree is a no-op: NULL has no child nodes.
func (node *NullVal) walkSubtree(visit Visit) error { return nil }

// replace is a no-op: NULL has no subexpressions.
func (node *NullVal) replace(from, to Expr) bool { return false }
// BoolVal is the TRUE or FALSE literal.
type BoolVal bool

// Format writes "true" or "false".
func (node BoolVal) Format(buf *TrackedBuffer) {
	if !node {
		buf.Myprintf("false")
		return
	}
	buf.Myprintf("true")
}

// walkSubtree is a no-op: a boolean literal has no child nodes.
func (node BoolVal) walkSubtree(visit Visit) error { return nil }

// replace is a no-op: a boolean literal has no subexpressions.
func (node BoolVal) replace(from, to Expr) bool { return false }
// ColName is a possibly table-qualified column reference.
type ColName struct {
	// Metadata is not populated by the parser. It is a placeholder where
	// analyzers can store extra data, typically about which table or column
	// this node references.
	Metadata  interface{}
	Name      ColIdent
	Qualifier TableName
}

// Format writes "[qualifier.]name".
func (node *ColName) Format(buf *TrackedBuffer) {
	if node.Qualifier.IsEmpty() {
		buf.Myprintf("%v", node.Name)
		return
	}
	buf.Myprintf("%v.%v", node.Qualifier, node.Name)
}

// walkSubtree walks Name, then Qualifier; a nil *ColName has no children.
func (node *ColName) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Name, node.Qualifier)
	}
	return nil
}

// replace is a no-op: a column reference has no subexpressions.
func (node *ColName) replace(from, to Expr) bool { return false }

// Equal reports whether both column names have equal Name and Qualifier.
// Metadata is ignored. As a failsafe, a nil operand never compares equal,
// not even to another nil.
func (node *ColName) Equal(c *ColName) bool {
	if node != nil && c != nil {
		return node.Name.Equal(c.Name) && node.Qualifier == c.Qualifier
	}
	return false
}
// ColTuple is an expression that yields a list of column values: a
// ValTuple, a *Subquery or a ListArg.
type ColTuple interface {
	Expr
	iColTuple()
}

func (ListArg) iColTuple()   {}
func (*Subquery) iColTuple() {}
func (ValTuple) iColTuple()  {}
// ValTuple is a parenthesized tuple of value expressions.
type ValTuple Exprs

// Format writes "(<expr>, <expr>, ...)".
func (node ValTuple) Format(buf *TrackedBuffer) {
	items := Exprs(node)
	buf.Myprintf("(%v)", items)
}

// walkSubtree walks the tuple as an Exprs list.
func (node ValTuple) walkSubtree(visit Visit) error {
	items := Exprs(node)
	return Walk(visit, items)
}

// replace tries each element in order and stops after the first
// substitution.
func (node ValTuple) replace(from, to Expr) bool {
	for i := 0; i < len(node); i++ {
		if replaceExprs(from, to, &node[i]) {
			return true
		}
	}
	return false
}
// Subquery is a parenthesized SELECT statement used as an expression or
// table.
type Subquery struct {
	Select SelectStatement
}

// Format writes "(<select>)".
func (node *Subquery) Format(buf *TrackedBuffer) {
	stmt := node.Select
	buf.Myprintf("(%v)", stmt)
}

// walkSubtree walks the inner statement; a nil *Subquery has no children.
func (node *Subquery) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Select)
	}
	return nil
}

// replace never substitutes anything: it does not descend into the inner
// statement.
func (node *Subquery) replace(from, to Expr) bool { return false }
// ListArg is a named list bind argument.
type ListArg []byte

// Format writes the argument through the buffer's WriteArg.
func (node ListArg) Format(buf *TrackedBuffer) {
	name := string(node)
	buf.WriteArg(name)
}

// walkSubtree is a no-op: a list argument has no child nodes.
func (node ListArg) walkSubtree(visit Visit) error { return nil }

// replace is a no-op: a list argument has no subexpressions.
func (node ListArg) replace(from, to Expr) bool { return false }
// BinaryExpr applies a binary arithmetic or bitwise Operator to Left and
// Right.
type BinaryExpr struct {
	Operator    string
	Left, Right Expr
}

// BinaryExpr.Operator values.
const (
	BitAndStr     = "&"
	BitOrStr      = "|"
	BitXorStr     = "^"
	PlusStr       = "+"
	MinusStr      = "-"
	MultStr       = "*"
	DivStr        = "/"
	IntDivStr     = "div"
	ModStr        = "%"
	ShiftLeftStr  = "<<"
	ShiftRightStr = ">>"
)

// Format writes "<left> <op> <right>".
func (node *BinaryExpr) Format(buf *TrackedBuffer) {
	lhs, op, rhs := node.Left, node.Operator, node.Right
	buf.Myprintf("%v %s %v", lhs, op, rhs)
}

// walkSubtree walks Left, then Right; a nil *BinaryExpr has no children.
func (node *BinaryExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Left, node.Right)
	}
	return nil
}

func (node *BinaryExpr) replace(from, to Expr) bool {
	return replaceExprs(from, to, &node.Left, &node.Right)
}
// UnaryExpr applies a unary Operator to Expr.
type UnaryExpr struct {
	Operator string
	Expr     Expr
}

// UnaryExpr.Operator values.
const (
	UPlusStr   = "+"
	UMinusStr  = "-"
	TildaStr   = "~"
	BangStr    = "!"
	BinaryStr  = "binary "
	UBinaryStr = "_binary "
)

// Format writes the operator followed by the operand. When the operand is
// itself a *UnaryExpr, a space separates the two operators.
func (node *UnaryExpr) Format(buf *TrackedBuffer) {
	layout := "%s%v"
	if _, nested := node.Expr.(*UnaryExpr); nested {
		layout = "%s %v"
	}
	buf.Myprintf(layout, node.Operator, node.Expr)
}

// walkSubtree walks the operand; a nil *UnaryExpr has no children.
func (node *UnaryExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Expr)
	}
	return nil
}

func (node *UnaryExpr) replace(from, to Expr) bool {
	return replaceExprs(from, to, &node.Expr)
}
// IntervalExpr is a date-time INTERVAL of Expr in the given Unit.
type IntervalExpr struct {
	Expr Expr
	Unit string
}

// Format writes "interval <expr> <unit>".
func (node *IntervalExpr) Format(buf *TrackedBuffer) {
	amount, unit := node.Expr, node.Unit
	buf.Myprintf("interval %v %s", amount, unit)
}

// walkSubtree walks the amount expression; a nil *IntervalExpr has no
// children.
func (node *IntervalExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Expr)
	}
	return nil
}

func (node *IntervalExpr) replace(from, to Expr) bool {
	return replaceExprs(from, to, &node.Expr)
}
// CollateExpr applies a COLLATE clause to Expr.
type CollateExpr struct {
	Expr    Expr
	Charset string
}

// Format writes "<expr> collate <charset>".
func (node *CollateExpr) Format(buf *TrackedBuffer) {
	operand, collation := node.Expr, node.Charset
	buf.Myprintf("%v collate %s", operand, collation)
}

// walkSubtree walks the operand; a nil *CollateExpr has no children.
func (node *CollateExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Expr)
	}
	return nil
}

func (node *CollateExpr) replace(from, to Expr) bool {
	return replaceExprs(from, to, &node.Expr)
}
// FuncExpr is a possibly qualified function call, optionally with DISTINCT.
type FuncExpr struct {
	Qualifier TableIdent
	Name      ColIdent
	Distinct  bool
	Exprs     SelectExprs
}

// Format writes "[qualifier.]name([distinct ]args)".
func (node *FuncExpr) Format(buf *TrackedBuffer) {
	if !node.Qualifier.IsEmpty() {
		buf.Myprintf("%v.", node.Qualifier)
	}
	modifier := ""
	if node.Distinct {
		modifier = "distinct "
	}
	// Function names are printed as-is, never back-quoted, even if they
	// match a reserved word.
	buf.Myprintf("%s(%s%v)", node.Name.String(), modifier, node.Exprs)
}

// walkSubtree walks Qualifier, Name, then the arguments; a nil *FuncExpr has
// no children.
func (node *FuncExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Qualifier, node.Name, node.Exprs)
	}
	return nil
}

// replace looks only inside *AliasedExpr arguments and stops after the first
// substitution.
func (node *FuncExpr) replace(from, to Expr) bool {
	for _, arg := range node.Exprs {
		if aliased, ok := arg.(*AliasedExpr); ok && replaceExprs(from, to, &aliased.Expr) {
			return true
		}
	}
	return false
}
// Aggregates is the set of aggregate function names, keyed in lowercase.
var Aggregates = map[string]bool{
	"avg":          true,
	"bit_and":      true,
	"bit_or":       true,
	"bit_xor":      true,
	"count":        true,
	"group_concat": true,
	"max":          true,
	"min":          true,
	"std":          true,
	"stddev":       true,
	"stddev_pop":   true,
	"stddev_samp":  true,
	"sum":          true,
	"var_pop":      true,
	"var_samp":     true,
	"variance":     true,
}

// IsAggregate reports whether the function's name, as returned by
// Name.Lowered(), is listed in Aggregates.
func (node *FuncExpr) IsAggregate() bool {
	name := node.Name.Lowered()
	return Aggregates[name]
}
// GroupConcatExpr is a call to GROUP_CONCAT. Distinct and Separator hold
// the literal text printed in their positions.
type GroupConcatExpr struct {
	Distinct  string
	Exprs     SelectExprs
	OrderBy   OrderBy
	Separator string
}

// Format writes "group_concat(<distinct><exprs><order by><separator>)".
func (node *GroupConcatExpr) Format(buf *TrackedBuffer) {
	d, args, order, sep := node.Distinct, node.Exprs, node.OrderBy, node.Separator
	buf.Myprintf("group_concat(%s%v%v%s)", d, args, order, sep)
}

// walkSubtree walks the arguments, then the ORDER BY list; a nil
// *GroupConcatExpr has no children.
func (node *GroupConcatExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Exprs, node.OrderBy)
	}
	return nil
}

// replace looks inside *AliasedExpr arguments first, then the ORDER BY
// expressions, and stops after the first substitution.
func (node *GroupConcatExpr) replace(from, to Expr) bool {
	for _, arg := range node.Exprs {
		if aliased, ok := arg.(*AliasedExpr); ok && replaceExprs(from, to, &aliased.Expr) {
			return true
		}
	}
	for i := range node.OrderBy {
		if replaceExprs(from, to, &node.OrderBy[i].Expr) {
			return true
		}
	}
	return false
}
// ValuesFuncExpr is a VALUES(column) function call.
type ValuesFuncExpr struct {
	Name *ColName
}

// Format writes "values(<column>)".
func (node *ValuesFuncExpr) Format(buf *TrackedBuffer) {
	col := node.Name
	buf.Myprintf("values(%v)", col)
}

// walkSubtree walks the column name; a nil *ValuesFuncExpr has no children.
func (node *ValuesFuncExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Name)
	}
	return nil
}

// replace is a no-op: the column reference is never substituted.
func (node *ValuesFuncExpr) replace(from, to Expr) bool { return false }
// SubstrExpr is a SUBSTR call on a column with a start position and an
// optional length, i.e. substr(col, from) or substr(col, from, to). The
// "col FROM x FOR y" syntax is also represented by this node.
type SubstrExpr struct {
	Name *ColName
	From Expr
	To   Expr
}

// Format always writes the comma form: "substr(col, from)" or
// "substr(col, from, to)" when To is set.
func (node *SubstrExpr) Format(buf *TrackedBuffer) {
	if node.To != nil {
		buf.Myprintf("substr(%v, %v, %v)", node.Name, node.From, node.To)
		return
	}
	buf.Myprintf("substr(%v, %v)", node.Name, node.From)
}

// walkSubtree walks Name, From, then To; a nil *SubstrExpr has no children.
func (node *SubstrExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Name, node.From, node.To)
	}
	return nil
}

// replace considers From and To only; the column reference is not
// substituted.
func (node *SubstrExpr) replace(from, to Expr) bool {
	return replaceExprs(from, to, &node.From, &node.To)
}
// ConvertExpr is a call to CONVERT(expr, type) or its equivalent
// CAST(expr AS type); both are printed in the CONVERT form.
type ConvertExpr struct {
	Expr Expr
	Type *ConvertType
}

// Format writes "convert(<expr>, <type>)".
func (node *ConvertExpr) Format(buf *TrackedBuffer) {
	operand, target := node.Expr, node.Type
	buf.Myprintf("convert(%v, %v)", operand, target)
}

// walkSubtree walks the operand, then the target type; a nil *ConvertExpr
// has no children.
func (node *ConvertExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Expr, node.Type)
	}
	return nil
}

func (node *ConvertExpr) replace(from, to Expr) bool {
	return replaceExprs(from, to, &node.Expr)
}
// ConvertUsingExpr is a call to CONVERT(expr USING charset); Type holds the
// charset name.
type ConvertUsingExpr struct {
	Expr Expr
	Type string
}

// Format writes "convert(<expr> using <charset>)".
func (node *ConvertUsingExpr) Format(buf *TrackedBuffer) {
	operand, charset := node.Expr, node.Type
	buf.Myprintf("convert(%v using %s)", operand, charset)
}

// walkSubtree walks the operand; a nil *ConvertUsingExpr has no children.
func (node *ConvertUsingExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Expr)
	}
	return nil
}

func (node *ConvertUsingExpr) replace(from, to Expr) bool {
	return replaceExprs(from, to, &node.Expr)
}
// ConvertType is the target type of a CONVERT(expr, type) call, with an
// optional length/scale and an optional "<operator> <charset>" suffix.
type ConvertType struct {
	Type     string
	Length   *SQLVal
	Scale    *SQLVal
	Operator string
	Charset  string
}

// CharacterSetStr is the text " character set". Its leading space lets it
// be printed directly after the type name, as Format does with Operator.
const (
	CharacterSetStr = " character set"
)

// Format writes the type name, then "(length)" or "(length, scale)" when
// Length is set, then "<operator> <charset>" when Charset is set. Scale is
// ignored unless Length is set.
func (node *ConvertType) Format(buf *TrackedBuffer) {
	buf.Myprintf("%s", node.Type)
	switch {
	case node.Length != nil && node.Scale != nil:
		buf.Myprintf("(%v, %v)", node.Length, node.Scale)
	case node.Length != nil:
		buf.Myprintf("(%v)", node.Length)
	}
	if node.Charset != "" {
		buf.Myprintf("%s %s", node.Operator, node.Charset)
	}
}

// walkSubtree is a no-op: the Length and Scale literals are not walked.
func (node *ConvertType) walkSubtree(visit Visit) error { return nil }
// MatchExpr is a MATCH(columns) AGAINST (expr [option]) call. Option holds
// one of the mode constants below, or is empty.
type MatchExpr struct {
	Columns SelectExprs
	Expr    Expr
	Option  string
}

// MatchExpr.Option values; each carries its own leading space.
const (
	BooleanModeStr                           = " in boolean mode"
	NaturalLanguageModeStr                   = " in natural language mode"
	NaturalLanguageModeWithQueryExpansionStr = " in natural language mode with query expansion"
	QueryExpansionStr                        = " with query expansion"
)

// Format writes "match(<columns>) against (<expr><option>)".
func (node *MatchExpr) Format(buf *TrackedBuffer) {
	cols, search, mode := node.Columns, node.Expr, node.Option
	buf.Myprintf("match(%v) against (%v%s)", cols, search, mode)
}

// walkSubtree walks the columns, then the search expression; a nil
// *MatchExpr has no children.
func (node *MatchExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Columns, node.Expr)
	}
	return nil
}

// replace looks inside *AliasedExpr columns first, then the search
// expression, and stops after the first substitution.
func (node *MatchExpr) replace(from, to Expr) bool {
	for _, col := range node.Columns {
		if aliased, ok := col.(*AliasedExpr); ok && replaceExprs(from, to, &aliased.Expr) {
			return true
		}
	}
	return replaceExprs(from, to, &node.Expr)
}
// CaseExpr is a CASE expression with an optional operand, WHEN branches and
// an optional ELSE.
type CaseExpr struct {
	Expr  Expr
	Whens []*When
	Else  Expr
}

// Format writes "case [expr ]when ... then ... [else x ]end".
func (node *CaseExpr) Format(buf *TrackedBuffer) {
	buf.Myprintf("case ")
	if operand := node.Expr; operand != nil {
		buf.Myprintf("%v ", operand)
	}
	for i := range node.Whens {
		buf.Myprintf("%v ", node.Whens[i])
	}
	if fallback := node.Else; fallback != nil {
		buf.Myprintf("else %v ", fallback)
	}
	buf.Myprintf("end")
}

// walkSubtree walks the operand, each WHEN branch, then ELSE, stopping at
// the first error. A nil *CaseExpr has no children.
func (node *CaseExpr) walkSubtree(visit Visit) error {
	if node == nil {
		return nil
	}
	err := Walk(visit, node.Expr)
	for i := 0; err == nil && i < len(node.Whens); i++ {
		err = Walk(visit, node.Whens[i])
	}
	if err != nil {
		return err
	}
	return Walk(visit, node.Else)
}

// replace tries the WHEN branches first (condition, then value), then the
// operand and ELSE, stopping after the first substitution.
func (node *CaseExpr) replace(from, to Expr) bool {
	for i := range node.Whens {
		branch := node.Whens[i]
		if replaceExprs(from, to, &branch.Cond, &branch.Val) {
			return true
		}
	}
	return replaceExprs(from, to, &node.Expr, &node.Else)
}
// Default is a DEFAULT expression, optionally naming a column.
type Default struct {
	ColName string
}

// Format writes "default", or "default(<col>)" when ColName is set.
func (node *Default) Format(buf *TrackedBuffer) {
	if node.ColName == "" {
		buf.Myprintf("default")
		return
	}
	buf.Myprintf("default")
	buf.Myprintf("(%s)", node.ColName)
}

// walkSubtree is a no-op: Default has no child nodes.
func (node *Default) walkSubtree(visit Visit) error { return nil }

// replace is a no-op: Default has no subexpressions.
func (node *Default) replace(from, to Expr) bool { return false }
// When is one WHEN <cond> THEN <val> branch of a CaseExpr.
type When struct {
	Cond Expr
	Val  Expr
}

// Format writes "when <cond> then <val>".
func (node *When) Format(buf *TrackedBuffer) {
	cond, result := node.Cond, node.Val
	buf.Myprintf("when %v then %v", cond, result)
}

// walkSubtree walks Cond, then Val; a nil *When has no children.
func (node *When) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Cond, node.Val)
	}
	return nil
}
// GroupBy is the list of expressions in a GROUP BY clause.
type GroupBy []Expr

// Format writes " group by a, b, ..."; an empty list writes nothing.
func (node GroupBy) Format(buf *TrackedBuffer) {
	for i, expr := range node {
		sep := ", "
		if i == 0 {
			sep = " group by "
		}
		buf.Myprintf("%s%v", sep, expr)
	}
}

// walkSubtree walks each expression in order and stops at the first error.
func (node GroupBy) walkSubtree(visit Visit) error {
	for i := range node {
		if err := Walk(visit, node[i]); err != nil {
			return err
		}
	}
	return nil
}
// OrderBy is the list of items in an ORDER BY clause.
type OrderBy []*Order

// Format writes " order by a asc, b desc, ..."; an empty list writes
// nothing.
func (node OrderBy) Format(buf *TrackedBuffer) {
	for i, item := range node {
		sep := ", "
		if i == 0 {
			sep = " order by "
		}
		buf.Myprintf("%s%v", sep, item)
	}
}

// walkSubtree walks each item in order and stops at the first error.
func (node OrderBy) walkSubtree(visit Visit) error {
	for i := range node {
		if err := Walk(visit, node[i]); err != nil {
			return err
		}
	}
	return nil
}
// Order represents one ORDER BY item: an expression and its direction.
type Order struct {
	Expr      Expr
	Direction string
}

// Order.Direction values.
const (
	AscScr  = "asc"
	DescScr = "desc"
)

// Format formats the node as "<expr> <direction>". Two expressions are
// printed without a direction: a *NullVal, and a *FuncExpr whose lowered
// name is "rand".
func (node *Order) Format(buf *TrackedBuffer) {
	// Bind the asserted value to expr instead of re-declaring node, which
	// would shadow the receiver inside each branch.
	switch expr := node.Expr.(type) {
	case *NullVal:
		buf.Myprintf("%v", expr)
		return
	case *FuncExpr:
		if expr.Name.Lowered() == "rand" {
			buf.Myprintf("%v", expr)
			return
		}
	}
	buf.Myprintf("%v %s", node.Expr, node.Direction)
}

// walkSubtree walks the ordering expression; a nil *Order has no children.
func (node *Order) walkSubtree(visit Visit) error {
	if node == nil {
		return nil
	}
	return Walk(
		visit,
		node.Expr,
	)
}
// Limit is a LIMIT clause with an optional offset.
type Limit struct {
	Offset   Expr
	Rowcount Expr
}

// Format writes " limit <rowcount>", or " limit <offset>, <rowcount>" when
// Offset is set. A nil *Limit writes nothing.
func (node *Limit) Format(buf *TrackedBuffer) {
	if node == nil {
		return
	}
	buf.Myprintf(" limit ")
	if offset := node.Offset; offset != nil {
		buf.Myprintf("%v, ", offset)
	}
	buf.Myprintf("%v", node.Rowcount)
}

// walkSubtree visits Offset and then Rowcount. A nil *Limit has nothing to
// visit.
func (node *Limit) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Offset, node.Rowcount)
	}
	return nil
}
// Values is the list of row tuples in a VALUES clause.
type Values []ValTuple

// Format writes "values " followed by the comma-separated tuples.
// An empty Values writes nothing.
func (node Values) Format(buf *TrackedBuffer) {
	for i, tuple := range node {
		if i == 0 {
			buf.Myprintf("%s%v", "values ", tuple)
			continue
		}
		buf.Myprintf("%s%v", ", ", tuple)
	}
}

// walkSubtree visits every tuple in order and returns the first error.
func (node Values) walkSubtree(visit Visit) error {
	for i := range node {
		if err := Walk(visit, node[i]); err != nil {
			return err
		}
	}
	return nil
}
// UpdateExprs is a list of update assignments.
type UpdateExprs []*UpdateExpr

// Format writes the assignments separated by ", " with no leading keyword.
func (node UpdateExprs) Format(buf *TrackedBuffer) {
	for i, expr := range node {
		sep := ""
		if i > 0 {
			sep = ", "
		}
		buf.Myprintf("%s%v", sep, expr)
	}
}

// walkSubtree visits each assignment in order and stops at the first error.
func (node UpdateExprs) walkSubtree(visit Visit) error {
	for i := range node {
		if err := Walk(visit, node[i]); err != nil {
			return err
		}
	}
	return nil
}
// UpdateExpr is a single "<column> = <expr>" assignment.
type UpdateExpr struct {
	Name *ColName
	Expr Expr
}

// Format writes the assignment as "<name> = <expr>".
func (node *UpdateExpr) Format(buf *TrackedBuffer) {
	buf.Myprintf("%v = %v", node.Name, node.Expr)
}

// walkSubtree visits Name and then Expr. A nil *UpdateExpr has nothing to
// visit.
func (node *UpdateExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Name, node.Expr)
	}
	return nil
}
// SetExprs is a list of SET assignments.
type SetExprs []*SetExpr

// Format writes the assignments separated by ", " with no leading keyword.
func (node SetExprs) Format(buf *TrackedBuffer) {
	for i := range node {
		sep := ", "
		if i == 0 {
			sep = ""
		}
		buf.Myprintf("%s%v", sep, node[i])
	}
}

// walkSubtree visits each assignment in turn and returns the first error.
func (node SetExprs) walkSubtree(visit Visit) error {
	for _, expr := range node {
		err := Walk(visit, expr)
		if err != nil {
			return err
		}
	}
	return nil
}
// SetExpr is a single SET assignment of Expr to the variable Name.
type SetExpr struct {
	Name ColIdent
	Expr Expr
}

// Format writes "<name> <expr>" for the charset and names variables and
// "<name> = <expr>" otherwise. The variable name is printed verbatim via
// String(), never backtick-quoted.
func (node *SetExpr) Format(buf *TrackedBuffer) {
	name := node.Name.String()
	switch {
	case node.Name.EqualString("charset"), node.Name.EqualString("names"):
		buf.Myprintf("%s %v", name, node.Expr)
	default:
		buf.Myprintf("%s = %v", name, node.Expr)
	}
}

// walkSubtree visits Name and then Expr. A nil *SetExpr has nothing to
// visit.
func (node *SetExpr) walkSubtree(visit Visit) error {
	if node != nil {
		return Walk(visit, node.Name, node.Expr)
	}
	return nil
}
// OnDup represents an ON DUPLICATE KEY UPDATE clause as a list of update
// assignments.
type OnDup UpdateExprs

// Format writes " on duplicate key update <assignments>". Nothing is written
// when the clause has no assignments. Checking len rather than nil also
// covers a non-nil empty OnDup, which would otherwise emit a dangling
// " on duplicate key update " with nothing after it.
func (node OnDup) Format(buf *TrackedBuffer) {
	if len(node) == 0 {
		return
	}
	buf.Myprintf(" on duplicate key update %v", UpdateExprs(node))
}

// walkSubtree walks the clause's assignments as an UpdateExprs.
func (node OnDup) walkSubtree(visit Visit) error {
	return Walk(visit, UpdateExprs(node))
}
// ColIdent is a case insensitive SQL identifier. It will be escaped with
// backquotes if necessary.
type ColIdent struct {
	// This artifact prevents this struct from being compared
	// with itself. It consumes no space as long as it's not the
	// last field in the struct.
	_ [0]struct{ _ []byte }
	// val is the identifier as written. lowered, when non-empty, is returned
	// by Lowered in place of recomputing strings.ToLower(val).
	val, lowered string
}

// NewColIdent makes a new ColIdent. Only val is set; lowered is left empty.
func NewColIdent(str string) ColIdent {
	return ColIdent{
		val: str,
	}
}

// Format formats the node, backtick-quoting it via formatID when required.
func (node ColIdent) Format(buf *TrackedBuffer) {
	formatID(buf, node.val, node.Lowered())
}

// walkSubtree is a no-op: ColIdent has no child nodes.
func (node ColIdent) walkSubtree(visit Visit) error {
	return nil
}

// IsEmpty returns true if the name is empty.
func (node ColIdent) IsEmpty() bool {
	return node.val == ""
}

// String returns the unescaped column name. It must
// not be used for SQL generation. Use sqlparser.String
// instead. The Stringer conformance is for usage
// in templates.
func (node ColIdent) String() string {
	return node.val
}

// CompliantName returns a compliant id name
// that can be used for a bind var.
func (node ColIdent) CompliantName() string {
	return compliantName(node.val)
}

// Lowered returns a lower-cased column name.
// This function should generally be used only for optimizing
// comparisons.
//
// A non-empty lowered field is returned as-is. Otherwise the result is
// computed on every call. Lowered has a value receiver, so assigning to
// node.lowered here would only modify this call's copy and could never act
// as a cache.
func (node ColIdent) Lowered() string {
	if node.val == "" {
		return ""
	}
	if node.lowered != "" {
		return node.lowered
	}
	return strings.ToLower(node.val)
}

// Equal performs a case-insensitive compare by comparing the Lowered forms.
func (node ColIdent) Equal(in ColIdent) bool {
	return node.Lowered() == in.Lowered()
}

// EqualString performs a case-insensitive compare with str, lower-casing
// both sides with strings.ToLower.
func (node ColIdent) EqualString(str string) bool {
	return node.Lowered() == strings.ToLower(str)
}

// MarshalJSON marshals the unescaped name as a JSON string.
func (node ColIdent) MarshalJSON() ([]byte, error) {
	return json.Marshal(node.val)
}

// UnmarshalJSON unmarshals a JSON string into the identifier. Any previously
// held lower-cased form is discarded so that Lowered, Equal and Format
// reflect the new value rather than a stale one.
func (node *ColIdent) UnmarshalJSON(b []byte) error {
	var result string
	if err := json.Unmarshal(b, &result); err != nil {
		return err
	}
	node.val = result
	node.lowered = ""
	return nil
}
// TableIdent is a case-sensitive SQL identifier. When formatted it is
// backtick-quoted if necessary.
type TableIdent struct {
	v string
}

// NewTableIdent returns a TableIdent holding str.
func NewTableIdent(str string) TableIdent {
	var id TableIdent
	id.v = str
	return id
}

// Format writes the identifier via formatID, using its lower-cased form only
// for the keyword check.
func (node TableIdent) Format(buf *TrackedBuffer) {
	name := node.v
	formatID(buf, name, strings.ToLower(name))
}

// walkSubtree is a no-op: TableIdent has no child nodes.
func (node TableIdent) walkSubtree(visit Visit) error {
	return nil
}

// IsEmpty reports whether the TableIdent holds an empty name.
func (node TableIdent) IsEmpty() bool {
	return len(node.v) == 0
}

// String returns the raw, unescaped table name. Do not use it to generate
// SQL; use sqlparser.String for that. It exists so TableIdent satisfies
// fmt.Stringer for use in templates.
func (node TableIdent) String() string {
	return node.v
}

// CompliantName returns a version of the name that is safe to use as a bind
// variable name.
func (node TableIdent) CompliantName() string {
	return compliantName(node.String())
}

// MarshalJSON encodes the raw name as a JSON string.
func (node TableIdent) MarshalJSON() ([]byte, error) {
	return json.Marshal(node.String())
}

// UnmarshalJSON decodes a JSON string into the identifier.
func (node *TableIdent) UnmarshalJSON(b []byte) error {
	var s string
	if err := json.Unmarshal(b, &s); err != nil {
		return err
	}
	node.v = s
	return nil
}
// Backtick produces a backticked literal given an input string.
// Embedded backticks are escaped by doubling them.
//
// The input is copied byte-for-byte. Because '`' is ASCII, doubling it never
// splits a multi-byte sequence. Input that is not valid UTF-8 is preserved
// rather than having invalid bytes rewritten to U+FFFD, which is what ranging
// over the string and calling WriteRune would do.
func Backtick(in string) string {
	return "`" + strings.Replace(in, "`", "``", -1) + "`"
}
// formatID writes the identifier original to buf, backtick-quoting it when it
// cannot be emitted bare. lowered is expected to be the lower-cased form of
// original and is used only for the keyword lookup.
//
// original is written bare only when both of these hold:
//   - every rune passes isLetter; or passes isDigit and is not the first rune;
//     or, for names starting with "@@", passes isCarat;
//   - lowered is not a key of keywords.
//
// Otherwise it is wrapped in backticks, with embedded backticks doubled.
func formatID(buf *TrackedBuffer, original, lowered string) {
	isDbSystemVariable := strings.HasPrefix(original, "@@")
	mustEscape := false
	for i, c := range original {
		// The character classifiers take a uint16. A rune above 0xFFFF would
		// be truncated (e.g. U+10061 becomes 'a') and could be misread as a
		// plain letter or digit, so such runes always force quoting.
		if c > 0xFFFF {
			mustEscape = true
			break
		}
		ch := uint16(c)
		if isLetter(ch) || (isDbSystemVariable && isCarat(ch)) {
			continue
		}
		// Digits are accepted anywhere except the first position.
		if i == 0 || !isDigit(ch) {
			mustEscape = true
			break
		}
	}
	if !mustEscape {
		if _, ok := keywords[lowered]; ok {
			mustEscape = true
		}
	}
	if !mustEscape {
		buf.Myprintf("%s", original)
		return
	}
	// Copy byte-by-byte so invalid UTF-8 is preserved verbatim instead of
	// being replaced by U+FFFD; '`' is ASCII, so doubling it is byte-safe.
	buf.WriteByte('`')
	for i := 0; i < len(original); i++ {
		b := original[i]
		buf.WriteByte(b)
		if b == '`' {
			buf.WriteByte('`')
		}
	}
	buf.WriteByte('`')
}
// compliantName rewrites in so it can be used as a bind variable name. Each
// rune that is not a letter (per isLetter) is replaced by '_'. The exception
// is a rune that passes isDigit and is not the first rune; it is kept.
//
// Runes above 0xFFFF are always replaced. Converting them to the uint16 the
// classifiers take would truncate them (e.g. U+10061 becomes 'a'), so they
// would be misclassified and copied through unchanged.
func compliantName(in string) string {
	var buf bytes.Buffer
	buf.Grow(len(in))
	for i, c := range in {
		if c > 0xFFFF {
			buf.WriteByte('_')
			continue
		}
		ch := uint16(c)
		if !isLetter(ch) && (i == 0 || !isDigit(ch)) {
			buf.WriteByte('_')
			continue
		}
		buf.WriteRune(c)
	}
	return buf.String()
}