Refactoring to support parameterized queries.
This commit is contained in:
parent
bc6a2bbcac
commit
fef8f0ef83
33 changed files with 1112 additions and 1206 deletions
|
|
@ -2,124 +2,88 @@
|
|||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dropbox/godropbox/database/sqltypes"
|
||||
"github.com/dropbox/godropbox/errors"
|
||||
)
|
||||
|
||||
type orderByClause struct {
|
||||
isOrderByClause
|
||||
expression Expression
|
||||
ascent bool
|
||||
}
|
||||
|
||||
func (o *orderByClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
|
||||
if o.expression == nil {
|
||||
return errors.Newf(
|
||||
"nil order by clause. Generated sql: %s",
|
||||
out.String())
|
||||
}
|
||||
|
||||
if err := o.expression.SerializeSql(out); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if o.ascent {
|
||||
_, _ = out.WriteString(" ASC")
|
||||
} else {
|
||||
_, _ = out.WriteString(" DESC")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Asc(expression Expression) OrderByClause {
|
||||
return &orderByClause{expression: expression, ascent: true}
|
||||
}
|
||||
|
||||
func Desc(expression Expression) OrderByClause {
|
||||
return &orderByClause{expression: expression, ascent: false}
|
||||
}
|
||||
|
||||
func serializeClauses(
|
||||
clauses []Clause,
|
||||
separator []byte,
|
||||
out *bytes.Buffer) (err error) {
|
||||
|
||||
if clauses == nil || len(clauses) == 0 {
|
||||
return errors.Newf("Empty clauses. Generated sql: %s", out.String())
|
||||
}
|
||||
|
||||
if clauses[0] == nil {
|
||||
return errors.Newf("nil clause. Generated sql: %s", out.String())
|
||||
}
|
||||
if err = clauses[0].SerializeSql(out); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, c := range clauses[1:] {
|
||||
_, _ = out.Write(separator)
|
||||
|
||||
if c == nil {
|
||||
return errors.Newf("nil clause. Generated sql: %s", out.String())
|
||||
}
|
||||
if err = c.SerializeSql(out); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Representation of n-ary arithmetic (+ - * /)
|
||||
type arithmeticExpression struct {
|
||||
expressionInterfaceImpl
|
||||
expressions []Expression
|
||||
operator []byte
|
||||
}
|
||||
|
||||
func (arith *arithmeticExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
|
||||
if len(arith.expressions) == 0 {
|
||||
return errors.Newf(
|
||||
"Empty arithmetic expression. Generated sql: %s",
|
||||
out.String())
|
||||
}
|
||||
|
||||
clauses := make([]Clause, len(arith.expressions), len(arith.expressions))
|
||||
for i, expr := range arith.expressions {
|
||||
clauses[i] = expr
|
||||
}
|
||||
|
||||
useParentheses := len(clauses) > 1
|
||||
if useParentheses {
|
||||
_ = out.WriteByte('(')
|
||||
}
|
||||
|
||||
if err = serializeClauses(clauses, arith.operator, out); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if useParentheses {
|
||||
_ = out.WriteByte(')')
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
//func serializeClauses(
|
||||
// clauses []Clause,
|
||||
// separator []byte,
|
||||
// out *bytes.Buffer) (err error) {
|
||||
//
|
||||
// if clauses == nil || len(clauses) == 0 {
|
||||
// return errors.Newf("Empty clauses.")
|
||||
// }
|
||||
//
|
||||
// if clauses[0] == nil {
|
||||
// return errors.Newf("nil clause.")
|
||||
// }
|
||||
// if err = clauses[0].Serialize(out); err != nil {
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// for _, c := range clauses[1:] {
|
||||
// _, _ = out.Write(separator)
|
||||
//
|
||||
// if c == nil {
|
||||
// return errors.Newf("nil clause.")
|
||||
// }
|
||||
// if err = c.Serialize(out); err != nil {
|
||||
// return
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// return nil
|
||||
//}
|
||||
//
|
||||
//// Representation of n-ary arithmetic (+ - * /)
|
||||
//type arithmeticExpression struct {
|
||||
// expressionInterfaceImpl
|
||||
// expressions []Expression
|
||||
// operator []byte
|
||||
//}
|
||||
//
|
||||
//func (arith *arithmeticExpression) Serialize(out *queryData, options ...serializeOption) error {
|
||||
// if len(arith.expressions) == 0 {
|
||||
// return errors.Newf(
|
||||
// "Empty arithmetic expression.")
|
||||
// }
|
||||
//
|
||||
// clauses := make([]Clause, len(arith.expressions), len(arith.expressions))
|
||||
// for i, expr := range arith.expressions {
|
||||
// clauses[i] = expr
|
||||
// }
|
||||
//
|
||||
// useParentheses := len(clauses) > 1
|
||||
// if useParentheses {
|
||||
// _ = out.WriteByte('(')
|
||||
// }
|
||||
//
|
||||
// if err = serializeClauses(clauses, arith.operator, out); err != nil {
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// if useParentheses {
|
||||
// _ = out.WriteByte(')')
|
||||
// }
|
||||
//
|
||||
// return nil
|
||||
//}
|
||||
//
|
||||
|
||||
type tupleExpression struct {
|
||||
expressionInterfaceImpl
|
||||
elements listClause
|
||||
}
|
||||
|
||||
func (tuple *tupleExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
|
||||
if len(tuple.elements.clauses) < 1 {
|
||||
func (tuple *tupleExpression) Serialize(out *queryData, options ...serializeOption) error {
|
||||
if len(tuple.elements.clauses) == 0 {
|
||||
return errors.Newf("Tuples must include at least one element")
|
||||
}
|
||||
return tuple.elements.SerializeSql(out)
|
||||
return tuple.elements.Serialize(out)
|
||||
}
|
||||
|
||||
func Tuple(exprs ...Expression) Expression {
|
||||
|
|
@ -141,61 +105,62 @@ type listClause struct {
|
|||
includeParentheses bool
|
||||
}
|
||||
|
||||
func (list *listClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
|
||||
func (list *listClause) Serialize(out *queryData, options ...serializeOption) error {
|
||||
if list.includeParentheses {
|
||||
_ = out.WriteByte('(')
|
||||
out.WriteByte('(')
|
||||
}
|
||||
|
||||
if err := serializeClauses(list.clauses, []byte(","), out); err != nil {
|
||||
if err := serializeClauseList(list.clauses, out); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if list.includeParentheses {
|
||||
_ = out.WriteByte(')')
|
||||
out.WriteByte(')')
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type funcExpression struct {
|
||||
expressionInterfaceImpl
|
||||
funcName string
|
||||
args *listClause
|
||||
}
|
||||
|
||||
func (c *funcExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
|
||||
if !validIdentifierName(c.funcName) {
|
||||
return errors.Newf(
|
||||
"Invalid function name: %s. Generated sql: %s",
|
||||
c.funcName,
|
||||
out.String())
|
||||
}
|
||||
_, _ = out.WriteString(c.funcName)
|
||||
if c.args == nil {
|
||||
_, _ = out.WriteString("()")
|
||||
} else {
|
||||
return c.args.SerializeSql(out)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns a representation of sql function call "func_call(c[0], ..., c[n-1])
|
||||
func SqlFunc(funcName string, expressions ...Expression) Expression {
|
||||
f := &funcExpression{
|
||||
funcName: funcName,
|
||||
}
|
||||
if len(expressions) > 0 {
|
||||
args := make([]Clause, len(expressions), len(expressions))
|
||||
for i, expr := range expressions {
|
||||
args[i] = expr
|
||||
}
|
||||
|
||||
f.args = &listClause{
|
||||
clauses: args,
|
||||
includeParentheses: true,
|
||||
}
|
||||
}
|
||||
return f
|
||||
}
|
||||
//
|
||||
//type funcExpression struct {
|
||||
// expressionInterfaceImpl
|
||||
// funcName string
|
||||
// args *listClause
|
||||
//}
|
||||
//
|
||||
//func (c *funcExpression) Serialize(out *queryData, options ...serializeOption) error {
|
||||
// if !validIdentifierName(c.funcName) {
|
||||
// return errors.Newf(
|
||||
// "Invalid function name: %s.",
|
||||
// c.funcName,
|
||||
// out.String())
|
||||
// }
|
||||
// _, _ = out.WriteString(c.funcName)
|
||||
// if c.args == nil {
|
||||
// _, _ = out.WriteString("()")
|
||||
// } else {
|
||||
// return c.args.Serialize(out)
|
||||
// }
|
||||
// return nil
|
||||
//}
|
||||
//
|
||||
//// Returns a representation of sql function call "func_call(c[0], ..., c[n-1])
|
||||
//func SqlFunc(funcName string, expressions ...Expression) Expression {
|
||||
// f := &funcExpression{
|
||||
// funcName: funcName,
|
||||
// }
|
||||
// if len(expressions) > 0 {
|
||||
// args := make([]Clause, len(expressions), len(expressions))
|
||||
// for i, expr := range expressions {
|
||||
// args[i] = expr
|
||||
// }
|
||||
//
|
||||
// f.args = &listClause{
|
||||
// clauses: args,
|
||||
// includeParentheses: true,
|
||||
// }
|
||||
// }
|
||||
// return f
|
||||
//}
|
||||
|
||||
type intervalExpression struct {
|
||||
expressionInterfaceImpl
|
||||
|
|
@ -205,23 +170,24 @@ type intervalExpression struct {
|
|||
|
||||
var intervalSep = ":"
|
||||
|
||||
func (c *intervalExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
|
||||
func (c *intervalExpression) Serialize(out *queryData, options ...serializeOption) error {
|
||||
hours := c.duration / time.Hour
|
||||
minutes := (c.duration % time.Hour) / time.Minute
|
||||
sec := (c.duration % time.Minute) / time.Second
|
||||
msec := (c.duration % time.Second) / time.Microsecond
|
||||
_, _ = out.WriteString("INTERVAL '")
|
||||
out.WriteString("INTERVAL '")
|
||||
if c.negative {
|
||||
_, _ = out.WriteString("-")
|
||||
out.WriteString("-")
|
||||
}
|
||||
_, _ = out.WriteString(strconv.FormatInt(int64(hours), 10))
|
||||
_, _ = out.WriteString(intervalSep)
|
||||
_, _ = out.WriteString(strconv.FormatInt(int64(minutes), 10))
|
||||
_, _ = out.WriteString(intervalSep)
|
||||
_, _ = out.WriteString(strconv.FormatInt(int64(sec), 10))
|
||||
_, _ = out.WriteString(intervalSep)
|
||||
_, _ = out.WriteString(strconv.FormatInt(int64(msec), 10))
|
||||
_, _ = out.WriteString("' HOUR_MICROSECOND")
|
||||
out.WriteString(strconv.FormatInt(int64(hours), 10))
|
||||
out.WriteString(intervalSep)
|
||||
out.WriteString(strconv.FormatInt(int64(minutes), 10))
|
||||
out.WriteString(intervalSep)
|
||||
out.WriteString(strconv.FormatInt(int64(sec), 10))
|
||||
out.WriteString(intervalSep)
|
||||
out.WriteString(strconv.FormatInt(int64(msec), 10))
|
||||
out.WriteString("' HOUR_MICROSECOND")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -246,45 +212,45 @@ func EscapeForLike(s string) string {
|
|||
}
|
||||
|
||||
// Returns an escaped literal string
|
||||
func Literal(v interface{}) Expression {
|
||||
value, err := sqltypes.BuildValue(v)
|
||||
if err != nil {
|
||||
panic(errors.Wrap(err, "Invalid literal value"))
|
||||
}
|
||||
return NewLiteralExpression(value)
|
||||
}
|
||||
|
||||
// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses
|
||||
func Add(expressions ...Expression) Expression {
|
||||
return &arithmeticExpression{
|
||||
expressions: expressions,
|
||||
operator: []byte(" + "),
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses
|
||||
func Sub(expressions ...Expression) Expression {
|
||||
return &arithmeticExpression{
|
||||
expressions: expressions,
|
||||
operator: []byte(" - "),
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses
|
||||
func Mul(expressions ...Expression) Expression {
|
||||
return &arithmeticExpression{
|
||||
expressions: expressions,
|
||||
operator: []byte(" * "),
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses
|
||||
func Div(expressions ...Expression) Expression {
|
||||
return &arithmeticExpression{
|
||||
expressions: expressions,
|
||||
operator: []byte(" / "),
|
||||
}
|
||||
}
|
||||
//func Literal(v interface{}) Expression {
|
||||
// value, err := sqltypes.BuildValue(v)
|
||||
// if err != nil {
|
||||
// panic(errors.Wrap(err, "Invalid literal value"))
|
||||
// }
|
||||
// return NewLiteralExpression(value)
|
||||
//}
|
||||
//
|
||||
//// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses
|
||||
//func Add(expressions ...Expression) Expression {
|
||||
// return &arithmeticExpression{
|
||||
// expressions: expressions,
|
||||
// operator: []byte(" + "),
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses
|
||||
//func Sub(expressions ...Expression) Expression {
|
||||
// return &arithmeticExpression{
|
||||
// expressions: expressions,
|
||||
// operator: []byte(" - "),
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses
|
||||
//func Mul(expressions ...Expression) Expression {
|
||||
// return &arithmeticExpression{
|
||||
// expressions: expressions,
|
||||
// operator: []byte(" * "),
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses
|
||||
//func Div(expressions ...Expression) Expression {
|
||||
// return &arithmeticExpression{
|
||||
// expressions: expressions,
|
||||
// operator: []byte(" / "),
|
||||
// }
|
||||
//}
|
||||
|
||||
//TODO: Uncomment
|
||||
//
|
||||
|
|
@ -336,14 +302,15 @@ type ifExpression struct {
|
|||
falseExpression Expression
|
||||
}
|
||||
|
||||
func (exp *ifExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
|
||||
_, _ = out.WriteString("IF(")
|
||||
_ = exp.conditional.SerializeSql(out)
|
||||
_, _ = out.WriteString(",")
|
||||
_ = exp.trueExpression.SerializeSql(out)
|
||||
_, _ = out.WriteString(",")
|
||||
_ = exp.falseExpression.SerializeSql(out)
|
||||
_, _ = out.WriteString(")")
|
||||
func (exp *ifExpression) Serialize(out *queryData, options ...serializeOption) error {
|
||||
out.WriteString("IF(")
|
||||
_ = exp.conditional.Serialize(out)
|
||||
out.WriteString(",")
|
||||
_ = exp.trueExpression.Serialize(out)
|
||||
out.WriteString(",")
|
||||
_ = exp.falseExpression.Serialize(out)
|
||||
out.WriteString(")")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -371,7 +338,7 @@ func If(conditional BoolExpression,
|
|||
// }
|
||||
//}
|
||||
//
|
||||
//func (cv *columnValueExpression) SerializeSql(out *bytes.Buffer) error {
|
||||
//func (cv *columnValueExpression) Serialize(out *bytes.Buffer) error {
|
||||
// _, _ = out.WriteString("VALUES(")
|
||||
// _ = cv.column.SerializeSqlForColumnList(out)
|
||||
// _ = out.WriteByte(')')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue