You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
236 lines
5.6 KiB
236 lines
5.6 KiB
// Copyright 2018 Huan Du. All rights reserved. |
|
// Licensed under the MIT license that can be found in the LICENSE file. |
|
|
|
package sqlbuilder |
|
|
|
import ( |
|
"bytes" |
|
"database/sql" |
|
"fmt" |
|
"sort" |
|
"strconv" |
|
"strings" |
|
) |
|
|
|
// Args stores arguments associated with a SQL. |
|
type Args struct { |
|
// The default flavor used by `Args#Compile` |
|
Flavor Flavor |
|
|
|
args []interface{} |
|
namedArgs map[string]int |
|
sqlNamedArgs map[string]int |
|
onlyNamed bool |
|
} |
|
|
|
// Add adds an arg to Args and returns a placeholder. |
|
func (args *Args) Add(arg interface{}) string { |
|
return fmt.Sprintf("$%v", args.add(arg)) |
|
} |
|
|
|
func (args *Args) add(arg interface{}) int { |
|
idx := len(args.args) |
|
|
|
switch a := arg.(type) { |
|
case sql.NamedArg: |
|
if args.sqlNamedArgs == nil { |
|
args.sqlNamedArgs = map[string]int{} |
|
} |
|
|
|
if p, ok := args.sqlNamedArgs[a.Name]; ok { |
|
arg = args.args[p] |
|
break |
|
} |
|
|
|
args.sqlNamedArgs[a.Name] = idx |
|
case namedArgs: |
|
if args.namedArgs == nil { |
|
args.namedArgs = map[string]int{} |
|
} |
|
|
|
if p, ok := args.namedArgs[a.name]; ok { |
|
arg = args.args[p] |
|
break |
|
} |
|
|
|
// Find out the real arg and add it to args. |
|
idx = args.add(a.arg) |
|
args.namedArgs[a.name] = idx |
|
return idx |
|
} |
|
|
|
args.args = append(args.args, arg) |
|
return idx |
|
} |
|
|
|
// Compile compiles builder's format to standard sql and returns associated args. |
|
// |
|
// The format string uses a special syntax to represent arguments. |
|
// |
|
// $? refers successive arguments passed in the call. It works similar as `%v` in `fmt.Sprintf`. |
|
// $0 $1 ... $n refers nth-argument passed in the call. Next $? will use arguments n+1. |
|
// ${name} refers a named argument created by `Named` with `name`. |
|
// $$ is a "$" string. |
|
func (args *Args) Compile(format string, intialValue ...interface{}) (query string, values []interface{}) { |
|
return args.CompileWithFlavor(format, args.Flavor, intialValue...) |
|
} |
|
|
|
// CompileWithFlavor compiles builder's format to standard sql with flavor and returns associated args. |
|
// |
|
// See doc for `Compile` to learn details. |
|
func (args *Args) CompileWithFlavor(format string, flavor Flavor, intialValue ...interface{}) (query string, values []interface{}) { |
|
buf := &bytes.Buffer{} |
|
idx := strings.IndexRune(format, '$') |
|
offset := 0 |
|
values = intialValue |
|
|
|
if flavor == invalidFlavor { |
|
flavor = DefaultFlavor |
|
} |
|
|
|
for idx >= 0 && len(format) > 0 { |
|
if idx > 0 { |
|
buf.WriteString(format[:idx]) |
|
} |
|
|
|
format = format[idx+1:] |
|
|
|
// Should not happen. |
|
if len(format) == 0 { |
|
break |
|
} |
|
|
|
if format[0] == '$' { |
|
buf.WriteRune('$') |
|
format = format[1:] |
|
} else if format[0] == '{' { |
|
format, values = args.compileNamed(buf, flavor, format, values) |
|
} else if !args.onlyNamed && '0' <= format[0] && format[0] <= '9' { |
|
format, values, offset = args.compileDigits(buf, flavor, format, values, offset) |
|
} else if !args.onlyNamed && format[0] == '?' { |
|
format, values, offset = args.compileSuccessive(buf, flavor, format[1:], values, offset) |
|
} |
|
|
|
idx = strings.IndexRune(format, '$') |
|
} |
|
|
|
if len(format) > 0 { |
|
buf.WriteString(format) |
|
} |
|
|
|
query = buf.String() |
|
|
|
if len(args.sqlNamedArgs) > 0 { |
|
// Stabilize the sequence to make it easier to write test cases. |
|
ints := make([]int, 0, len(args.sqlNamedArgs)) |
|
|
|
for _, p := range args.sqlNamedArgs { |
|
ints = append(ints, p) |
|
} |
|
|
|
sort.Ints(ints) |
|
|
|
for _, i := range ints { |
|
values = append(values, args.args[i]) |
|
} |
|
} |
|
|
|
return |
|
} |
|
|
|
func (args *Args) compileNamed(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}) (string, []interface{}) { |
|
i := 1 |
|
|
|
for ; i < len(format) && format[i] != '}'; i++ { |
|
// Nothing. |
|
} |
|
|
|
// Invalid $ format. Ignore it. |
|
if i == len(format) { |
|
return format, values |
|
} |
|
|
|
name := format[1:i] |
|
format = format[i+1:] |
|
|
|
if p, ok := args.namedArgs[name]; ok { |
|
format, values, _ = args.compileSuccessive(buf, flavor, format, values, p) |
|
} |
|
|
|
return format, values |
|
} |
|
|
|
func (args *Args) compileDigits(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) { |
|
i := 1 |
|
|
|
for ; i < len(format) && '0' <= format[i] && format[i] <= '9'; i++ { |
|
// Nothing. |
|
} |
|
|
|
digits := format[:i] |
|
format = format[i:] |
|
|
|
if pointer, err := strconv.Atoi(digits); err == nil { |
|
return args.compileSuccessive(buf, flavor, format, values, pointer) |
|
} |
|
|
|
return format, values, offset |
|
} |
|
|
|
func (args *Args) compileSuccessive(buf *bytes.Buffer, flavor Flavor, format string, values []interface{}, offset int) (string, []interface{}, int) { |
|
if offset >= len(args.args) { |
|
return format, values, offset |
|
} |
|
|
|
arg := args.args[offset] |
|
values = args.compileArg(buf, flavor, values, arg) |
|
|
|
return format, values, offset + 1 |
|
} |
|
|
|
func (args *Args) compileArg(buf *bytes.Buffer, flavor Flavor, values []interface{}, arg interface{}) []interface{} { |
|
switch a := arg.(type) { |
|
case Builder: |
|
var s string |
|
s, values = a.BuildWithFlavor(flavor, values...) |
|
buf.WriteString(s) |
|
case sql.NamedArg: |
|
buf.WriteRune('@') |
|
buf.WriteString(a.Name) |
|
case rawArgs: |
|
buf.WriteString(a.expr) |
|
case listArgs: |
|
if len(a.args) > 0 { |
|
values = args.compileArg(buf, flavor, values, a.args[0]) |
|
} |
|
|
|
for i := 1; i < len(a.args); i++ { |
|
buf.WriteString(", ") |
|
values = args.compileArg(buf, flavor, values, a.args[i]) |
|
} |
|
default: |
|
switch flavor { |
|
case MySQL: |
|
buf.WriteRune('?') |
|
case PostgreSQL: |
|
fmt.Fprintf(buf, "$%v", len(values)+1) |
|
default: |
|
panic(fmt.Errorf("Args.CompileWithFlavor: invalid flavor %v (%v)", flavor, int(flavor))) |
|
} |
|
|
|
values = append(values, arg) |
|
} |
|
|
|
return values |
|
} |
|
|
|
// Copy is |
|
func (args *Args) Copy() *Args { |
|
return &Args{ |
|
Flavor: args.Flavor, |
|
args: args.args, |
|
namedArgs: args.namedArgs, |
|
sqlNamedArgs: args.sqlNamedArgs, |
|
onlyNamed: args.onlyNamed, |
|
} |
|
}
|
|
|