Bool expression refactoring.

This commit is contained in:
zer0sub 2019-03-31 09:17:28 +02:00
parent 5a7563d4af
commit 38007810c1
15 changed files with 818 additions and 583 deletions

View file

@ -0,0 +1,34 @@
package sqlbuilder
import "bytes"
type Alias struct {
Clause
expression Expression
alias string
}
func NewAlias(expression Expression, alias string) *Alias {
if !validIdentifierName(alias) {
panic("Invalid alias")
}
return &Alias{
expression: expression,
alias: alias,
}
}
func (a *Alias) SerializeSql(out *bytes.Buffer) error {
err := a.expression.SerializeSql(out)
if err != nil {
return err
}
out.WriteString(" AS \"" + a.alias + "\"")
return nil
}

View file

@ -8,13 +8,160 @@ import (
"time" "time"
) )
type BoolExpression interface {
Expression
And(expression BoolExpression) BoolExpression
Or(expression BoolExpression) BoolExpression
IsTrue() BoolExpression
IsFalse() BoolExpression
}
type boolInterfaceImpl struct {
parent BoolExpression
}
func (b *boolInterfaceImpl) And(expression BoolExpression) BoolExpression {
return And(b.parent, expression)
}
func (b *boolInterfaceImpl) Or(expression BoolExpression) BoolExpression {
return Or(b.parent, expression)
}
func (b *boolInterfaceImpl) IsTrue() BoolExpression {
return IsTrue(b.parent)
}
func (b *boolInterfaceImpl) IsFalse() BoolExpression {
return nil
}
//---------------------------------------------------//
type boolLiteralExpression struct {
boolInterfaceImpl
literalExpression
}
func NewBoolLiteralExpression(value bool) BoolExpression {
boolLiteralExpression := boolLiteralExpression{}
sqlValue, err := sqltypes.BuildValue(value)
if err != nil {
panic(errors.Wrap(err, "Invalid literal value"))
}
boolLiteralExpression.literalExpression = *NewLiteralExpression(sqlValue)
boolLiteralExpression.boolInterfaceImpl.parent = &boolLiteralExpression
return &boolLiteralExpression
}
//---------------------------------------------------//
type binaryBoolExpression struct {
boolInterfaceImpl
binaryExpression
}
func NewBinaryBoolExpression(lhs, rhs Expression, operator []byte) BoolExpression {
boolExpression := binaryBoolExpression{}
boolExpression.binaryExpression = *NewBinaryExpression(lhs, rhs, operator, &boolExpression)
boolExpression.boolInterfaceImpl.parent = &boolExpression
return &boolExpression
}
//---------------------------------------------------//
type prefixBoolExpression struct {
boolInterfaceImpl
prefixExpression
}
func NewPrefixBoolExpression(expression Expression, operator []byte) BoolExpression {
boolExpression := prefixBoolExpression{}
boolExpression.prefixExpression = *NewPrefixExpression(expression, operator, &boolExpression)
boolExpression.boolInterfaceImpl.parent = &boolExpression
return &boolExpression
}
//---------------------------------------------------//
type conjunctBoolExpression struct {
boolInterfaceImpl
conjunctExpression
name string
}
func NewConjunctBoolExpression(operator []byte, expressions ...BoolExpression) BoolExpression {
boolExpression := conjunctBoolExpression{
conjunctExpression: conjunctExpression{
expressions: expressions,
conjunction: operator,
},
}
//boolExpression.expressionInterfaceImpl.parent = &boolExpression
//boolExpression.boolInterfaceImpl.parent = &boolExpression
return &boolExpression
}
//---------------------------------------------------//
type inExpression struct {
expressionInterfaceImpl
boolInterfaceImpl
lhs Expression
rhs *listClause
err error
}
func (c *inExpression) SerializeSql(out *bytes.Buffer) error {
if c.err != nil {
return errors.Wrap(c.err, "Invalid IN expression")
}
if c.lhs == nil {
return errors.Newf(
"lhs of in expression is nil. Generated sql: %s",
out.String())
}
// We'll serialize the lhs even if we don't need it to ensure no error
buf := &bytes.Buffer{}
err := c.lhs.SerializeSql(buf)
if err != nil {
return err
}
if c.rhs == nil {
_, _ = out.WriteString("FALSE")
return nil
}
_, _ = out.WriteString(buf.String())
_, _ = out.WriteString(" IN ")
err = c.rhs.SerializeSql(out)
if err != nil {
return err
}
return nil
}
// Returns a representation of "a=b" // Returns a representation of "a=b"
func Eq(lhs, rhs Expression) BoolExpression { func Eq(lhs, rhs Expression) BoolExpression {
lit, ok := rhs.(*literalExpression) lit, ok := rhs.(*literalExpression)
if ok && sqltypes.Value(lit.value).IsNull() { if ok && sqltypes.Value(lit.value).IsNull() {
return newBoolExpression(lhs, rhs, []byte(" IS ")) return NewBinaryBoolExpression(lhs, rhs, []byte(" IS "))
} }
return newBoolExpression(lhs, rhs, []byte(" = ")) return NewBinaryBoolExpression(lhs, rhs, []byte(" = "))
} }
// Returns a representation of "a=b", where b is a literal // Returns a representation of "a=b", where b is a literal
@ -26,9 +173,9 @@ func EqL(lhs Expression, val interface{}) BoolExpression {
func Neq(lhs, rhs Expression) BoolExpression { func Neq(lhs, rhs Expression) BoolExpression {
lit, ok := rhs.(*literalExpression) lit, ok := rhs.(*literalExpression)
if ok && sqltypes.Value(lit.value).IsNull() { if ok && sqltypes.Value(lit.value).IsNull() {
return newBoolExpression(lhs, rhs, []byte(" IS NOT ")) return NewBinaryBoolExpression(lhs, rhs, []byte(" IS NOT "))
} }
return newBoolExpression(lhs, rhs, []byte("!=")) return NewBinaryBoolExpression(lhs, rhs, []byte("!="))
} }
// Returns a representation of "a!=b", where b is a literal // Returns a representation of "a!=b", where b is a literal
@ -38,7 +185,7 @@ func NeqL(lhs Expression, val interface{}) BoolExpression {
// Returns a representation of "a<b" // Returns a representation of "a<b"
func Lt(lhs Expression, rhs Expression) BoolExpression { func Lt(lhs Expression, rhs Expression) BoolExpression {
return newBoolExpression(lhs, rhs, []byte("<")) return NewBinaryBoolExpression(lhs, rhs, []byte("<"))
} }
// Returns a representation of "a<b", where b is a literal // Returns a representation of "a<b", where b is a literal
@ -48,7 +195,7 @@ func LtL(lhs Expression, val interface{}) BoolExpression {
// Returns a representation of "a<=b" // Returns a representation of "a<=b"
func Lte(lhs, rhs Expression) BoolExpression { func Lte(lhs, rhs Expression) BoolExpression {
return newBoolExpression(lhs, rhs, []byte("<=")) return NewBinaryBoolExpression(lhs, rhs, []byte("<="))
} }
// Returns a representation of "a<=b", where b is a literal // Returns a representation of "a<=b", where b is a literal
@ -58,7 +205,7 @@ func LteL(lhs Expression, val interface{}) BoolExpression {
// Returns a representation of "a>b" // Returns a representation of "a>b"
func Gt(lhs, rhs Expression) BoolExpression { func Gt(lhs, rhs Expression) BoolExpression {
return newBoolExpression(lhs, rhs, []byte(">")) return NewBinaryBoolExpression(lhs, rhs, []byte(">"))
} }
// Returns a representation of "a>b", where b is a literal // Returns a representation of "a>b", where b is a literal
@ -68,7 +215,7 @@ func GtL(lhs Expression, val interface{}) BoolExpression {
// Returns a representation of "a>=b" // Returns a representation of "a>=b"
func Gte(lhs, rhs Expression) BoolExpression { func Gte(lhs, rhs Expression) BoolExpression {
return newBoolExpression(lhs, rhs, []byte(">=")) return NewBinaryBoolExpression(lhs, rhs, []byte(">="))
} }
// Returns a representation of "a>=b", where b is a literal // Returns a representation of "a>=b", where b is a literal
@ -78,29 +225,25 @@ func GteL(lhs Expression, val interface{}) BoolExpression {
// Returns a representation of "not expr" // Returns a representation of "not expr"
func Not(expr BoolExpression) BoolExpression { func Not(expr BoolExpression) BoolExpression {
return &negateExpression{ return NewPrefixBoolExpression(expr, []byte(" NOT "))
nested: expr, }
}
func IsTrue(expr BoolExpression) BoolExpression {
return NewPrefixBoolExpression(expr, []byte(" IS TRUE "))
} }
// Returns a representation of "c[0] AND ... AND c[n-1]" for c in clauses // Returns a representation of "c[0] AND ... AND c[n-1]" for c in clauses
func And(expressions ...BoolExpression) BoolExpression { func And(expressions ...BoolExpression) BoolExpression {
return &conjunctExpression{ return NewConjunctBoolExpression([]byte(" AND "), expressions...)
expressions: expressions,
conjunction: []byte(" AND "),
}
} }
// Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses // Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses
func Or(expressions ...BoolExpression) BoolExpression { func Or(expressions ...BoolExpression) BoolExpression {
return &conjunctExpression{ return NewConjunctBoolExpression([]byte(" OR "), expressions...)
expressions: expressions,
conjunction: []byte(" OR "),
}
} }
func Like(lhs, rhs Expression) BoolExpression { func Like(lhs, rhs Expression) BoolExpression {
return newBoolExpression(lhs, rhs, []byte(" LIKE ")) return NewBinaryBoolExpression(lhs, rhs, []byte(" LIKE "))
} }
func LikeL(lhs Expression, val string) BoolExpression { func LikeL(lhs Expression, val string) BoolExpression {
@ -108,7 +251,7 @@ func LikeL(lhs Expression, val string) BoolExpression {
} }
func Regexp(lhs, rhs Expression) BoolExpression { func Regexp(lhs, rhs Expression) BoolExpression {
return newBoolExpression(lhs, rhs, []byte(" REGEXP ")) return NewBinaryBoolExpression(lhs, rhs, []byte(" REGEXP "))
} }
func RegexpL(lhs Expression, val string) BoolExpression { func RegexpL(lhs Expression, val string) BoolExpression {
@ -206,144 +349,3 @@ func In(lhs Expression, valList interface{}) BoolExpression {
} }
return expr return expr
} }
type boolExpressionImpl struct {
isExpression
isBoolExpression
}
func (c *boolExpressionImpl) And(expression BoolExpression) BoolExpression {
return And(c, expression)
}
func (c *boolExpressionImpl) Or(expression BoolExpression) BoolExpression {
return Or(c, expression)
}
func (conj *boolExpressionImpl) SerializeSql(out *bytes.Buffer) (err error) {
return errors.New("Not implemented")
}
// Representation of n-ary conjunctions (AND/OR)
type conjunctExpression struct {
boolExpressionImpl
expressions []BoolExpression
conjunction []byte
}
func (conj *conjunctExpression) SerializeSql(out *bytes.Buffer) (err error) {
if len(conj.expressions) == 0 {
return errors.Newf(
"Empty conjunction. Generated sql: %s",
out.String())
}
clauses := make([]Clause, len(conj.expressions), len(conj.expressions))
for i, expr := range conj.expressions {
clauses[i] = expr
}
useParentheses := len(clauses) > 1
if useParentheses {
_ = out.WriteByte('(')
}
if err = serializeClauses(clauses, conj.conjunction, out); err != nil {
return
}
if useParentheses {
_ = out.WriteByte(')')
}
return nil
}
// A not expression which negates a expression value
type negateExpression struct {
boolExpressionImpl
nested BoolExpression
}
func (c *negateExpression) SerializeSql(out *bytes.Buffer) (err error) {
_, _ = out.WriteString("NOT (")
if c.nested == nil {
return errors.Newf("nil nested. Generated sql: %s", out.String())
}
if err = c.nested.SerializeSql(out); err != nil {
return
}
_ = out.WriteByte(')')
return nil
}
// A binary expression that evaluates to a boolean value.
type boolBinaryExpression struct {
boolExpressionImpl
binaryExpression binaryExpression
}
func (b *boolBinaryExpression) And(expression BoolExpression) BoolExpression {
return And(b, expression)
}
func newBoolExpression(lhs, rhs Expression, operator []byte) *boolBinaryExpression {
// go does not allow {} syntax for initializing promoted fields ...
expr := new(boolBinaryExpression)
expr.binaryExpression.lhs = lhs
expr.binaryExpression.rhs = rhs
expr.binaryExpression.operator = operator
return expr
}
func (b *boolBinaryExpression) SerializeSql(out *bytes.Buffer) (err error) {
return b.binaryExpression.SerializeSql(out)
}
// in expression representation
type inExpression struct {
boolExpressionImpl
lhs Expression
rhs *listClause
err error
}
func (c *inExpression) SerializeSql(out *bytes.Buffer) error {
if c.err != nil {
return errors.Wrap(c.err, "Invalid IN expression")
}
if c.lhs == nil {
return errors.Newf(
"lhs of in expression is nil. Generated sql: %s",
out.String())
}
// We'll serialize the lhs even if we don't need it to ensure no error
buf := &bytes.Buffer{}
err := c.lhs.SerializeSql(buf)
if err != nil {
return err
}
if c.rhs == nil {
_, _ = out.WriteString("FALSE")
return nil
}
_, _ = out.WriteString(buf.String())
_, _ = out.WriteString(" IN ")
err = c.rhs.SerializeSql(out)
if err != nil {
return err
}
return nil
}

View file

@ -0,0 +1,108 @@
package sqlbuilder
import (
"bytes"
"gotest.tools/assert"
"testing"
)
func TestBinaryExpression(t *testing.T) {
boolExpression := Eq(Literal(2), Literal(3))
out := bytes.Buffer{}
err := boolExpression.SerializeSql(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "2 = 3")
t.Run("alias", func(t *testing.T) {
alias := boolExpression.As("alias_eq_expression")
out := bytes.Buffer{}
err := alias.SerializeSql(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `2 = 3 AS "alias_eq_expression"`)
})
t.Run("and", func(t *testing.T) {
exp := boolExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{}
err := exp.SerializeSql(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `(2 = 3 AND 4 = 5)`)
})
t.Run("or", func(t *testing.T) {
exp := boolExpression.Or(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{}
err := exp.SerializeSql(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `(2 = 3 OR 4 = 5)`)
})
}
func TestUnaryExpression(t *testing.T) {
notExpression := Not(Eq(Literal(2), Literal(1)))
out := bytes.Buffer{}
err := notExpression.SerializeSql(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), " NOT 2 = 1")
t.Run("alias", func(t *testing.T) {
alias := notExpression.As("alias_not_expression")
out := bytes.Buffer{}
err := alias.SerializeSql(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), ` NOT 2 = 1 AS "alias_not_expression"`)
})
t.Run("and", func(t *testing.T) {
exp := notExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{}
err := exp.SerializeSql(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `( NOT 2 = 1 AND 4 = 5)`)
})
}
func TestUnaryIsTrueExpression(t *testing.T) {
notExpression := IsTrue(Eq(Literal(2), Literal(1)))
out := bytes.Buffer{}
err := notExpression.SerializeSql(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), " IS TRUE 2 = 1")
t.Run("and", func(t *testing.T) {
exp := notExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{}
err := exp.SerializeSql(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `( IS TRUE 2 = 1 AND 4 = 5)`)
})
}
func TestBoolLiteral(t *testing.T) {
literal := NewBoolLiteralExpression(true)
out := bytes.Buffer{}
err := literal.SerializeSql(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "true")
}

7
sqlbuilder/clause.go Normal file
View file

@ -0,0 +1,7 @@
package sqlbuilder
import "bytes"
type Clause interface {
SerializeSql(out *bytes.Buffer) error
}

View file

@ -14,18 +14,14 @@ import (
// Representation of a tableName for query generation // Representation of a tableName for query generation
type Column interface { type Column interface {
Expression
isProjectionInterface isProjectionInterface
isExpressionInterface
As(alias string) Projection
Name() string Name() string
TableName() string TableName() string
// Serialization for use in column lists // Serialization for use in column lists
SerializeSqlForColumnList(out *bytes.Buffer) error SerializeSqlForColumnList(out *bytes.Buffer) error
// Serialization for use in an expression (Clause)
SerializeSql(out *bytes.Buffer) error
// Internal function for tracking tableName that a column belongs to // Internal function for tracking tableName that a column belongs to
// for the purpose of serialization // for the purpose of serialization
@ -54,7 +50,6 @@ const (
// A column that can be refer to outside of the projection list // A column that can be refer to outside of the projection list
type NonAliasColumn interface { type NonAliasColumn interface {
Column Column
isOrderByClauseInterface
} }
type Collation string type Collation string
@ -74,20 +69,20 @@ const (
// The base type for real materialized columns. // The base type for real materialized columns.
type baseColumn struct { type baseColumn struct {
expressionInterfaceImpl
isProjection isProjection
isExpression
name string name string
nullable NullableColumn nullable NullableColumn
tableName string tableName string
alias string alias string
} }
func (c *baseColumn) As(alias string) Projection { //func (c *baseColumn) As(alias string) Projection {
newBaseColumn := *c // newBaseColumn := *c
newBaseColumn.alias = alias // newBaseColumn.alias = alias
//
return &newBaseColumn // return &newBaseColumn
} //}
func (c *baseColumn) Name() string { func (c *baseColumn) Name() string {
return c.name return c.name
@ -167,7 +162,6 @@ func (c *baseColumn) Desc() OrderByClause {
type bytesColumn struct { type bytesColumn struct {
baseColumn baseColumn
isExpression
} }
// Representation of VARBINARY/BLOB columns // Representation of VARBINARY/BLOB columns
@ -184,7 +178,6 @@ func BytesColumn(name string, nullable NullableColumn) NonAliasColumn {
type stringColumn struct { type stringColumn struct {
baseColumn baseColumn
isExpression
charset Charset charset Charset
collation Collation collation Collation
} }
@ -208,7 +201,6 @@ func StrColumn(
type dateTimeColumn struct { type dateTimeColumn struct {
baseColumn baseColumn
isExpression
} }
// Representation of DateTime columns // Representation of DateTime columns
@ -225,7 +217,6 @@ func DateTimeColumn(name string, nullable NullableColumn) NonAliasColumn {
type IntegerColumn struct { type IntegerColumn struct {
baseColumn baseColumn
isExpression
} }
// Representation of any integer column // Representation of any integer column
@ -242,7 +233,6 @@ func IntColumn(name string, nullable NullableColumn) *IntegerColumn {
type doubleColumn struct { type doubleColumn struct {
baseColumn baseColumn
isExpression
} }
// Representation of any double column // Representation of any double column
@ -259,7 +249,6 @@ func DoubleColumn(name string, nullable NullableColumn) NonAliasColumn {
type booleanColumn struct { type booleanColumn struct {
baseColumn baseColumn
isExpression
// XXX: Maybe allow isBoolExpression (for now, not included because // XXX: Maybe allow isBoolExpression (for now, not included because
// the deferred lookup equivalent can never be isBoolExpression) // the deferred lookup equivalent can never be isBoolExpression)
@ -322,12 +311,12 @@ func (c *aliasColumn) setTableName(table string) error {
} }
// Representation of aliased clauses (expression AS name) // Representation of aliased clauses (expression AS name)
func Alias(name string, c Expression) Column { //func Alias(name string, c Expression) Column {
ac := &aliasColumn{} // ac := &aliasColumn{}
ac.name = name // ac.name = name
ac.expression = c // ac.expression = c
return ac // return ac
} //}
// This is a strict subset of the actual allowed identifiers // This is a strict subset of the actual allowed identifiers
var validIdentifierRegexp = regexp.MustCompile("^[a-zA-Z_]\\w*$") var validIdentifierRegexp = regexp.MustCompile("^[a-zA-Z_]\\w*$")

View file

@ -1,3 +1,5 @@
// +build disabled
package sqlbuilder package sqlbuilder
import ( import (

View file

@ -1,3 +1,5 @@
// +build disabled
package sqlbuilder package sqlbuilder
import "fmt" import "fmt"

View file

@ -1,179 +1,56 @@
// Query building functions for expression components
package sqlbuilder package sqlbuilder
import ( import (
"bytes" "bytes"
"strconv"
"strings"
"time"
"github.com/dropbox/godropbox/database/sqltypes" "github.com/dropbox/godropbox/database/sqltypes"
"github.com/dropbox/godropbox/errors" "github.com/dropbox/godropbox/errors"
) )
type orderByClause struct { // An expression
isOrderByClause type Expression interface {
expression Expression Clause
ascent bool
As(alias string) Clause
IsDistinct(expression Expression) BoolExpression
IsNull(expression Expression) BoolExpression
} }
func (o *orderByClause) SerializeSql(out *bytes.Buffer) error { type expressionInterfaceImpl struct {
if o.expression == nil { parent Expression
return errors.Newf( }
"nil order by clause. Generated sql: %s",
out.String())
}
if err := o.expression.SerializeSql(out); err != nil { func (e *expressionInterfaceImpl) As(alias string) Clause {
return err return NewAlias(e.parent, alias)
} }
if o.ascent {
_, _ = out.WriteString(" ASC")
} else {
_, _ = out.WriteString(" DESC")
}
func (e *expressionInterfaceImpl) IsDistinct(expression Expression) BoolExpression {
return nil return nil
} }
func Asc(expression Expression) OrderByClause { func (e *expressionInterfaceImpl) IsNull(expression Expression) BoolExpression {
return &orderByClause{expression: expression, ascent: true}
}
func Desc(expression Expression) OrderByClause {
return &orderByClause{expression: expression, ascent: false}
}
// Representation of an escaped literal
type literalExpression struct {
isExpression
value sqltypes.Value
}
func (c literalExpression) SerializeSql(out *bytes.Buffer) error {
sqltypes.Value(c.value).EncodeSql(out)
return nil
}
func serializeClauses(
clauses []Clause,
separator []byte,
out *bytes.Buffer) (err error) {
if clauses == nil || len(clauses) == 0 {
return errors.Newf("Empty clauses. Generated sql: %s", out.String())
}
if clauses[0] == nil {
return errors.Newf("nil clause. Generated sql: %s", out.String())
}
if err = clauses[0].SerializeSql(out); err != nil {
return
}
for _, c := range clauses[1:] {
_, _ = out.Write(separator)
if c == nil {
return errors.Newf("nil clause. Generated sql: %s", out.String())
}
if err = c.SerializeSql(out); err != nil {
return
}
}
return nil
}
// Representation of n-ary arithmetic (+ - * /)
type arithmeticExpression struct {
isExpression
expressions []Expression
operator []byte
}
func (arith *arithmeticExpression) SerializeSql(out *bytes.Buffer) (err error) {
if len(arith.expressions) == 0 {
return errors.Newf(
"Empty arithmetic expression. Generated sql: %s",
out.String())
}
clauses := make([]Clause, len(arith.expressions), len(arith.expressions))
for i, expr := range arith.expressions {
clauses[i] = expr
}
useParentheses := len(clauses) > 1
if useParentheses {
_ = out.WriteByte('(')
}
if err = serializeClauses(clauses, arith.operator, out); err != nil {
return
}
if useParentheses {
_ = out.WriteByte(')')
}
return nil
}
type tupleExpression struct {
isExpression
elements listClause
}
func (tuple *tupleExpression) SerializeSql(out *bytes.Buffer) error {
if len(tuple.elements.clauses) < 1 {
return errors.Newf("Tuples must include at least one element")
}
return tuple.elements.SerializeSql(out)
}
func Tuple(exprs ...Expression) Expression {
clauses := make([]Clause, 0, len(exprs))
for _, expr := range exprs {
clauses = append(clauses, expr)
}
return &tupleExpression{
elements: listClause{
clauses: clauses,
includeParentheses: true,
},
}
}
// Representation of a tuple enclosed, comma separated list of clauses
type listClause struct {
clauses []Clause
includeParentheses bool
}
func (list *listClause) SerializeSql(out *bytes.Buffer) error {
if list.includeParentheses {
_ = out.WriteByte('(')
}
if err := serializeClauses(list.clauses, []byte(","), out); err != nil {
return err
}
if list.includeParentheses {
_ = out.WriteByte(')')
}
return nil return nil
} }
// Representation of binary operations (e.g. comparisons, arithmetic) // Representation of binary operations (e.g. comparisons, arithmetic)
type binaryExpression struct { type binaryExpression struct {
isExpression expressionInterfaceImpl
lhs, rhs Expression lhs, rhs Expression
operator []byte operator []byte
} }
func NewBinaryExpression(lhs, rhs Expression, operator []byte, parent ...Expression) *binaryExpression {
binaryExpression := binaryExpression{
lhs: lhs,
rhs: rhs,
operator: operator,
}
if len(parent) > 0 {
binaryExpression.parent = parent[0]
}
return &binaryExpression
}
func (c *binaryExpression) SerializeSql(out *bytes.Buffer) (err error) { func (c *binaryExpression) SerializeSql(out *bytes.Buffer) (err error) {
if c.lhs == nil { if c.lhs == nil {
return errors.Newf("nil lhs. Generated sql: %s", out.String()) return errors.Newf("nil lhs. Generated sql: %s", out.String())
@ -194,220 +71,90 @@ func (c *binaryExpression) SerializeSql(out *bytes.Buffer) (err error) {
return nil return nil
} }
type funcExpression struct { // A not expression which negates a expression value
isExpression type prefixExpression struct {
funcName string expressionInterfaceImpl
args *listClause
expression Expression
operator []byte
} }
func (c *funcExpression) SerializeSql(out *bytes.Buffer) (err error) { func NewPrefixExpression(expression Expression, operator []byte, parent ...Expression) *prefixExpression {
if !validIdentifierName(c.funcName) { prefixExpression := prefixExpression{
expression: expression,
operator: operator,
}
if len(parent) > 0 {
prefixExpression.parent = parent[0]
}
return &prefixExpression
}
func (p *prefixExpression) SerializeSql(out *bytes.Buffer) (err error) {
_, _ = out.Write(p.operator)
if p.expression == nil {
return errors.Newf("nil prefix expression. Generated sql: %s", out.String())
}
if err = p.expression.SerializeSql(out); err != nil {
return
}
return nil
}
// Representation of n-ary conjunctions (AND/OR)
type conjunctExpression struct {
expressionInterfaceImpl
expressions []BoolExpression
conjunction []byte
}
func (conj *conjunctExpression) SerializeSql(out *bytes.Buffer) (err error) {
if len(conj.expressions) == 0 {
return errors.Newf( return errors.Newf(
"Invalid function name: %s. Generated sql: %s", "Empty conjunction. Generated sql: %s",
c.funcName,
out.String()) out.String())
} }
_, _ = out.WriteString(c.funcName)
if c.args == nil { clauses := make([]Clause, len(conj.expressions), len(conj.expressions))
_, _ = out.WriteString("()") for i, expr := range conj.expressions {
} else { clauses[i] = expr
return c.args.SerializeSql(out)
} }
useParentheses := len(clauses) > 1
if useParentheses {
_ = out.WriteByte('(')
}
if err = serializeClauses(clauses, conj.conjunction, out); err != nil {
return
}
if useParentheses {
_ = out.WriteByte(')')
}
return nil return nil
} }
// Returns a representation of sql function call "func_call(c[0], ..., c[n-1]) //--------------------------------------------------------------
func SqlFunc(funcName string, expressions ...Expression) Expression {
f := &funcExpression{
funcName: funcName,
}
if len(expressions) > 0 {
args := make([]Clause, len(expressions), len(expressions))
for i, expr := range expressions {
args[i] = expr
}
f.args = &listClause{ // Representation of an escaped literal
clauses: args, type literalExpression struct {
includeParentheses: true, expressionInterfaceImpl
} value sqltypes.Value
}
return f
} }
type intervalExpression struct { func NewLiteralExpression(value sqltypes.Value) *literalExpression {
isExpression exp := literalExpression{value: value}
duration time.Duration exp.expressionInterfaceImpl.parent = &exp
negative bool
return &exp
} }
var intervalSep = ":" func (c literalExpression) SerializeSql(out *bytes.Buffer) error {
sqltypes.Value(c.value).EncodeSql(out)
func (c *intervalExpression) SerializeSql(out *bytes.Buffer) (err error) {
hours := c.duration / time.Hour
minutes := (c.duration % time.Hour) / time.Minute
sec := (c.duration % time.Minute) / time.Second
msec := (c.duration % time.Second) / time.Microsecond
_, _ = out.WriteString("INTERVAL '")
if c.negative {
_, _ = out.WriteString("-")
}
_, _ = out.WriteString(strconv.FormatInt(int64(hours), 10))
_, _ = out.WriteString(intervalSep)
_, _ = out.WriteString(strconv.FormatInt(int64(minutes), 10))
_, _ = out.WriteString(intervalSep)
_, _ = out.WriteString(strconv.FormatInt(int64(sec), 10))
_, _ = out.WriteString(intervalSep)
_, _ = out.WriteString(strconv.FormatInt(int64(msec), 10))
_, _ = out.WriteString("' HOUR_MICROSECOND")
return nil
}
// Interval returns a representation of duration
// in a form "INTERVAL `hour:min:sec:microsec` HOUR_MICROSECOND"
func Interval(duration time.Duration) Expression {
negative := false
if duration < 0 {
negative = true
duration = -duration
}
return &intervalExpression{
duration: duration,
negative: negative,
}
}
var likeEscaper = strings.NewReplacer("_", "\\_", "%", "\\%")
func EscapeForLike(s string) string {
return likeEscaper.Replace(s)
}
// Returns an escaped literal string
func Literal(v interface{}) Expression {
value, err := sqltypes.BuildValue(v)
if err != nil {
panic(errors.Wrap(err, "Invalid literal value"))
}
return &literalExpression{value: value}
}
// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses
func Add(expressions ...Expression) Expression {
return &arithmeticExpression{
expressions: expressions,
operator: []byte(" + "),
}
}
// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses
func Sub(expressions ...Expression) Expression {
return &arithmeticExpression{
expressions: expressions,
operator: []byte(" - "),
}
}
// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses
func Mul(expressions ...Expression) Expression {
return &arithmeticExpression{
expressions: expressions,
operator: []byte(" * "),
}
}
// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses
func Div(expressions ...Expression) Expression {
return &arithmeticExpression{
expressions: expressions,
operator: []byte(" / "),
}
}
func BitOr(lhs, rhs Expression) Expression {
return &binaryExpression{
lhs: lhs,
rhs: rhs,
operator: []byte(" | "),
}
}
func BitAnd(lhs, rhs Expression) Expression {
return &binaryExpression{
lhs: lhs,
rhs: rhs,
operator: []byte(" & "),
}
}
func BitXor(lhs, rhs Expression) Expression {
return &binaryExpression{
lhs: lhs,
rhs: rhs,
operator: []byte(" ^ "),
}
}
func Plus(lhs, rhs Expression) Expression {
return &binaryExpression{
lhs: lhs,
rhs: rhs,
operator: []byte(" + "),
}
}
func Minus(lhs, rhs Expression) Expression {
return &binaryExpression{
lhs: lhs,
rhs: rhs,
operator: []byte(" - "),
}
}
type ifExpression struct {
isExpression
conditional BoolExpression
trueExpression Expression
falseExpression Expression
}
func (exp *ifExpression) SerializeSql(out *bytes.Buffer) error {
_, _ = out.WriteString("IF(")
_ = exp.conditional.SerializeSql(out)
_, _ = out.WriteString(",")
_ = exp.trueExpression.SerializeSql(out)
_, _ = out.WriteString(",")
_ = exp.falseExpression.SerializeSql(out)
_, _ = out.WriteString(")")
return nil
}
// Returns a representation of an if-expression, of the form:
// IF (BOOLEAN TEST, VALUE-IF-TRUE, VALUE-IF-FALSE)
func If(conditional BoolExpression,
trueExpression Expression,
falseExpression Expression) Expression {
return &ifExpression{
conditional: conditional,
trueExpression: trueExpression,
falseExpression: falseExpression,
}
}
type columnValueExpression struct {
isExpression
column NonAliasColumn
}
func ColumnValue(col NonAliasColumn) Expression {
return &columnValueExpression{
column: col,
}
}
func (cv *columnValueExpression) SerializeSql(out *bytes.Buffer) error {
_, _ = out.WriteString("VALUES(")
_ = cv.column.SerializeSqlForColumnList(out)
_ = out.WriteByte(')')
return nil return nil
} }

View file

@ -0,0 +1,379 @@
// Query building functions for expression components
package sqlbuilder
import (
"bytes"
"strconv"
"strings"
"time"
"github.com/dropbox/godropbox/database/sqltypes"
"github.com/dropbox/godropbox/errors"
)
type orderByClause struct {
isOrderByClause
expression Expression
ascent bool
}
func (o *orderByClause) SerializeSql(out *bytes.Buffer) error {
if o.expression == nil {
return errors.Newf(
"nil order by clause. Generated sql: %s",
out.String())
}
if err := o.expression.SerializeSql(out); err != nil {
return err
}
if o.ascent {
_, _ = out.WriteString(" ASC")
} else {
_, _ = out.WriteString(" DESC")
}
return nil
}
func Asc(expression Expression) OrderByClause {
return &orderByClause{expression: expression, ascent: true}
}
func Desc(expression Expression) OrderByClause {
return &orderByClause{expression: expression, ascent: false}
}
func serializeClauses(
clauses []Clause,
separator []byte,
out *bytes.Buffer) (err error) {
if clauses == nil || len(clauses) == 0 {
return errors.Newf("Empty clauses. Generated sql: %s", out.String())
}
if clauses[0] == nil {
return errors.Newf("nil clause. Generated sql: %s", out.String())
}
if err = clauses[0].SerializeSql(out); err != nil {
return
}
for _, c := range clauses[1:] {
_, _ = out.Write(separator)
if c == nil {
return errors.Newf("nil clause. Generated sql: %s", out.String())
}
if err = c.SerializeSql(out); err != nil {
return
}
}
return nil
}
// Representation of n-ary arithmetic (+ - * /)
type arithmeticExpression struct {
expressionInterfaceImpl
expressions []Expression
operator []byte
}
func (arith *arithmeticExpression) SerializeSql(out *bytes.Buffer) (err error) {
if len(arith.expressions) == 0 {
return errors.Newf(
"Empty arithmetic expression. Generated sql: %s",
out.String())
}
clauses := make([]Clause, len(arith.expressions), len(arith.expressions))
for i, expr := range arith.expressions {
clauses[i] = expr
}
useParentheses := len(clauses) > 1
if useParentheses {
_ = out.WriteByte('(')
}
if err = serializeClauses(clauses, arith.operator, out); err != nil {
return
}
if useParentheses {
_ = out.WriteByte(')')
}
return nil
}
type tupleExpression struct {
expressionInterfaceImpl
elements listClause
}
func (tuple *tupleExpression) SerializeSql(out *bytes.Buffer) error {
if len(tuple.elements.clauses) < 1 {
return errors.Newf("Tuples must include at least one element")
}
return tuple.elements.SerializeSql(out)
}
func Tuple(exprs ...Expression) Expression {
clauses := make([]Clause, 0, len(exprs))
for _, expr := range exprs {
clauses = append(clauses, expr)
}
return &tupleExpression{
elements: listClause{
clauses: clauses,
includeParentheses: true,
},
}
}
// Representation of a tuple enclosed, comma separated list of clauses
type listClause struct {
clauses []Clause
includeParentheses bool
}
func (list *listClause) SerializeSql(out *bytes.Buffer) error {
if list.includeParentheses {
_ = out.WriteByte('(')
}
if err := serializeClauses(list.clauses, []byte(","), out); err != nil {
return err
}
if list.includeParentheses {
_ = out.WriteByte(')')
}
return nil
}
type funcExpression struct {
expressionInterfaceImpl
funcName string
args *listClause
}
func (c *funcExpression) SerializeSql(out *bytes.Buffer) (err error) {
if !validIdentifierName(c.funcName) {
return errors.Newf(
"Invalid function name: %s. Generated sql: %s",
c.funcName,
out.String())
}
_, _ = out.WriteString(c.funcName)
if c.args == nil {
_, _ = out.WriteString("()")
} else {
return c.args.SerializeSql(out)
}
return nil
}
// Returns a representation of sql function call "func_call(c[0], ..., c[n-1])
func SqlFunc(funcName string, expressions ...Expression) Expression {
f := &funcExpression{
funcName: funcName,
}
if len(expressions) > 0 {
args := make([]Clause, len(expressions), len(expressions))
for i, expr := range expressions {
args[i] = expr
}
f.args = &listClause{
clauses: args,
includeParentheses: true,
}
}
return f
}
type intervalExpression struct {
expressionInterfaceImpl
duration time.Duration
negative bool
}
var intervalSep = ":"
func (c *intervalExpression) SerializeSql(out *bytes.Buffer) (err error) {
hours := c.duration / time.Hour
minutes := (c.duration % time.Hour) / time.Minute
sec := (c.duration % time.Minute) / time.Second
msec := (c.duration % time.Second) / time.Microsecond
_, _ = out.WriteString("INTERVAL '")
if c.negative {
_, _ = out.WriteString("-")
}
_, _ = out.WriteString(strconv.FormatInt(int64(hours), 10))
_, _ = out.WriteString(intervalSep)
_, _ = out.WriteString(strconv.FormatInt(int64(minutes), 10))
_, _ = out.WriteString(intervalSep)
_, _ = out.WriteString(strconv.FormatInt(int64(sec), 10))
_, _ = out.WriteString(intervalSep)
_, _ = out.WriteString(strconv.FormatInt(int64(msec), 10))
_, _ = out.WriteString("' HOUR_MICROSECOND")
return nil
}
// Interval returns a representation of duration
// in a form "INTERVAL `hour:min:sec:microsec` HOUR_MICROSECOND"
func Interval(duration time.Duration) Expression {
negative := false
if duration < 0 {
negative = true
duration = -duration
}
return &intervalExpression{
duration: duration,
negative: negative,
}
}
var likeEscaper = strings.NewReplacer("_", "\\_", "%", "\\%")
func EscapeForLike(s string) string {
return likeEscaper.Replace(s)
}
// Returns an escaped literal string
func Literal(v interface{}) Expression {
value, err := sqltypes.BuildValue(v)
if err != nil {
panic(errors.Wrap(err, "Invalid literal value"))
}
return NewLiteralExpression(value)
}
// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses
func Add(expressions ...Expression) Expression {
return &arithmeticExpression{
expressions: expressions,
operator: []byte(" + "),
}
}
// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses
func Sub(expressions ...Expression) Expression {
return &arithmeticExpression{
expressions: expressions,
operator: []byte(" - "),
}
}
// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses
func Mul(expressions ...Expression) Expression {
return &arithmeticExpression{
expressions: expressions,
operator: []byte(" * "),
}
}
// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses
func Div(expressions ...Expression) Expression {
return &arithmeticExpression{
expressions: expressions,
operator: []byte(" / "),
}
}
//TODO: Uncomment
//
//func BitOr(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" | "),
// }
//}
//
//func BitAnd(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" & "),
// }
//}
//
//func BitXor(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" ^ "),
// }
//}
//
//func Plus(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" + "),
// }
//}
//
//func Minus(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" - "),
// }
//}
type ifExpression struct {
expressionInterfaceImpl
conditional BoolExpression
trueExpression Expression
falseExpression Expression
}
func (exp *ifExpression) SerializeSql(out *bytes.Buffer) error {
_, _ = out.WriteString("IF(")
_ = exp.conditional.SerializeSql(out)
_, _ = out.WriteString(",")
_ = exp.trueExpression.SerializeSql(out)
_, _ = out.WriteString(",")
_ = exp.falseExpression.SerializeSql(out)
_, _ = out.WriteString(")")
return nil
}
// Returns a representation of an if-expression, of the form:
// IF (BOOLEAN TEST, VALUE-IF-TRUE, VALUE-IF-FALSE)
func If(conditional BoolExpression,
trueExpression Expression,
falseExpression Expression) Expression {
return &ifExpression{
conditional: conditional,
trueExpression: trueExpression,
falseExpression: falseExpression,
}
}
//TODO: Uncomment
//type columnValueExpression struct {
// isExpression
// column NonAliasColumn
//}
//
//func ColumnValue(col NonAliasColumn) Expression {
// return &columnValueExpression{
// column: col,
// }
//}
//
//func (cv *columnValueExpression) SerializeSql(out *bytes.Buffer) error {
// _, _ = out.WriteString("VALUES(")
// _ = cv.column.SerializeSqlForColumnList(out)
// _ = out.WriteByte(')')
// return nil
//}

View file

@ -1,3 +1,5 @@
// +build disabled
package sqlbuilder package sqlbuilder
import ( import (

View file

@ -3,7 +3,6 @@ package sqlbuilder
import "bytes" import "bytes"
type FuncExpression struct { type FuncExpression struct {
isExpression
isProjection isProjection
name string name string

View file

@ -33,7 +33,8 @@ type SelectStatement interface {
// NOTE: SelectStatement purposely does not implement the Table interface since // NOTE: SelectStatement purposely does not implement the Table interface since
// mysql's subquery performance is horrible. // mysql's subquery performance is horrible.
type selectStatementImpl struct { type selectStatementImpl struct {
isExpression expressionInterfaceImpl
table ReadableTable table ReadableTable
projections []Projection projections []Projection
where BoolExpression where BoolExpression

View file

@ -1,3 +1,5 @@
// +build disabled
package sqlbuilder package sqlbuilder
import ( import (

View file

@ -1,3 +1,5 @@
// +build disabled
package sqlbuilder package sqlbuilder
import ( import (

View file

@ -4,36 +4,17 @@ import (
"bytes" "bytes"
) )
type Clause interface {
SerializeSql(out *bytes.Buffer) error
}
// A clause that can be used in order by // A clause that can be used in order by
type OrderByClause interface { type OrderByClause interface {
Clause Clause
isOrderByClauseInterface isOrderByClauseInterface
} }
// An expression
type Expression interface {
Clause
isExpressionInterface
}
type BoolExpression interface {
Clause
isBoolExpressionInterface
And(expression BoolExpression) BoolExpression
Or(expression BoolExpression) BoolExpression
}
// A clause that is selectable. // A clause that is selectable.
type Projection interface { type Projection interface {
Clause Clause
isProjectionInterface isProjectionInterface
As(alias string) Projection
SerializeSqlForColumnList(out *bytes.Buffer) error SerializeSqlForColumnList(out *bytes.Buffer) error
} }
@ -82,28 +63,6 @@ type isOrderByClause struct {
func (o *isOrderByClause) isOrderByClauseType() { func (o *isOrderByClause) isOrderByClauseType() {
} }
type isExpressionInterface interface {
isExpressionType()
}
type isExpression struct {
isOrderByClause // can always use expression in order by.
}
func (e *isExpression) isExpressionType() {
}
type isBoolExpressionInterface interface {
isExpressionInterface
isBoolExpressionType()
}
type isBoolExpression struct {
}
func (e *isBoolExpression) isBoolExpressionType() {
}
type isProjectionInterface interface { type isProjectionInterface interface {
isProjectionType() isProjectionType()
} }