You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
141 lines
3.9 KiB
141 lines
3.9 KiB
6 years ago
|
/*
|
||
|
Copyright 2017 Google Inc.
|
||
|
|
||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
you may not use this file except in compliance with the License.
|
||
|
You may obtain a copy of the License at
|
||
|
|
||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
||
|
Unless required by applicable law or agreed to in writing, software
|
||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
See the License for the specific language governing permissions and
|
||
|
limitations under the License.
|
||
|
*/
|
||
|
|
||
|
package sqlparser
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"fmt"
|
||
|
)
|
||
|
|
||
|
// NodeFormatter defines the signature of a custom node formatter
|
||
|
// function that can be given to TrackedBuffer for code generation.
|
||
|
type NodeFormatter func(buf *TrackedBuffer, node SQLNode)
|
||
|
|
||
|
// TrackedBuffer is used to rebuild a query from the ast.
|
||
|
// bindLocations keeps track of locations in the buffer that
|
||
|
// use bind variables for efficient future substitutions.
|
||
|
// nodeFormatter is the formatting function the buffer will
|
||
|
// use to format a node. By default(nil), it's FormatNode.
|
||
|
// But you can supply a different formatting function if you
|
||
|
// want to generate a query that's different from the default.
|
||
|
type TrackedBuffer struct {
|
||
|
*bytes.Buffer
|
||
|
bindLocations []bindLocation
|
||
|
nodeFormatter NodeFormatter
|
||
|
}
|
||
|
|
||
|
// NewTrackedBuffer creates a new TrackedBuffer.
|
||
|
func NewTrackedBuffer(nodeFormatter NodeFormatter) *TrackedBuffer {
|
||
|
return &TrackedBuffer{
|
||
|
Buffer: new(bytes.Buffer),
|
||
|
nodeFormatter: nodeFormatter,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// WriteNode function, initiates the writing of a single SQLNode tree by passing
|
||
|
// through to Myprintf with a default format string
|
||
|
func (buf *TrackedBuffer) WriteNode(node SQLNode) *TrackedBuffer {
|
||
|
buf.Myprintf("%v", node)
|
||
|
return buf
|
||
|
}
|
||
|
|
||
|
// Myprintf mimics fmt.Fprintf(buf, ...), but limited to Node(%v),
|
||
|
// Node.Value(%s) and string(%s). It also allows a %a for a value argument, in
|
||
|
// which case it adds tracking info for future substitutions.
|
||
|
//
|
||
|
// The name must be something other than the usual Printf() to avoid "go vet"
|
||
|
// warnings due to our custom format specifiers.
|
||
|
func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) {
|
||
|
end := len(format)
|
||
|
fieldnum := 0
|
||
|
for i := 0; i < end; {
|
||
|
lasti := i
|
||
|
for i < end && format[i] != '%' {
|
||
|
i++
|
||
|
}
|
||
|
if i > lasti {
|
||
|
buf.WriteString(format[lasti:i])
|
||
|
}
|
||
|
if i >= end {
|
||
|
break
|
||
|
}
|
||
|
i++ // '%'
|
||
|
switch format[i] {
|
||
|
case 'c':
|
||
|
switch v := values[fieldnum].(type) {
|
||
|
case byte:
|
||
|
buf.WriteByte(v)
|
||
|
case rune:
|
||
|
buf.WriteRune(v)
|
||
|
default:
|
||
|
panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
|
||
|
}
|
||
|
case 's':
|
||
|
switch v := values[fieldnum].(type) {
|
||
|
case []byte:
|
||
|
buf.Write(v)
|
||
|
case string:
|
||
|
buf.WriteString(v)
|
||
|
default:
|
||
|
panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
|
||
|
}
|
||
|
case 'v':
|
||
|
node := values[fieldnum].(SQLNode)
|
||
|
if buf.nodeFormatter == nil {
|
||
|
node.Format(buf)
|
||
|
} else {
|
||
|
buf.nodeFormatter(buf, node)
|
||
|
}
|
||
|
case 'a':
|
||
|
buf.WriteArg(values[fieldnum].(string))
|
||
|
default:
|
||
|
panic("unexpected")
|
||
|
}
|
||
|
fieldnum++
|
||
|
i++
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// WriteArg writes a value argument into the buffer along with
|
||
|
// tracking information for future substitutions. arg must contain
|
||
|
// the ":" or "::" prefix.
|
||
|
func (buf *TrackedBuffer) WriteArg(arg string) {
|
||
|
buf.bindLocations = append(buf.bindLocations, bindLocation{
|
||
|
offset: buf.Len(),
|
||
|
length: len(arg),
|
||
|
})
|
||
|
buf.WriteString(arg)
|
||
|
}
|
||
|
|
||
|
// ParsedQuery returns a ParsedQuery that contains bind
|
||
|
// locations for easy substitution.
|
||
|
func (buf *TrackedBuffer) ParsedQuery() *ParsedQuery {
|
||
|
return &ParsedQuery{Query: buf.String(), bindLocations: buf.bindLocations}
|
||
|
}
|
||
|
|
||
|
// HasBindVars returns true if the parsed query uses bind vars.
|
||
|
func (buf *TrackedBuffer) HasBindVars() bool {
|
||
|
return len(buf.bindLocations) != 0
|
||
|
}
|
||
|
|
||
|
// BuildParsedQuery builds a ParsedQuery from the input.
|
||
|
func BuildParsedQuery(in string, vars ...interface{}) *ParsedQuery {
|
||
|
buf := NewTrackedBuffer(nil)
|
||
|
buf.Myprintf(in, vars...)
|
||
|
return buf.ParsedQuery()
|
||
|
}
|