From a49c682672c553d082ff62dd22defa0938a8a84a Mon Sep 17 00:00:00 2001 From: sub0Zero Date: Fri, 15 Mar 2019 22:02:59 +0100 Subject: [PATCH] Extend BoolExpression with logical operators. --- sqlbuilder/bool_expresion.go | 349 +++++++++++++++++++++++++++++++++++ sqlbuilder/expression.go | 319 -------------------------------- sqlbuilder/statement.go | 2 +- sqlbuilder/types.go | 3 + tests/generator_test.go | 3 +- 5 files changed, 354 insertions(+), 322 deletions(-) create mode 100644 sqlbuilder/bool_expresion.go diff --git a/sqlbuilder/bool_expresion.go b/sqlbuilder/bool_expresion.go new file mode 100644 index 0000000..76f8453 --- /dev/null +++ b/sqlbuilder/bool_expresion.go @@ -0,0 +1,349 @@ +package sqlbuilder + +import ( + "bytes" + "github.com/dropbox/godropbox/database/sqltypes" + "github.com/dropbox/godropbox/errors" + "reflect" + "time" +) + +// Returns a representation of "a=b" +func Eq(lhs, rhs Expression) BoolExpression { + lit, ok := rhs.(*literalExpression) + if ok && sqltypes.Value(lit.value).IsNull() { + return newBoolExpression(lhs, rhs, []byte(" IS ")) + } + return newBoolExpression(lhs, rhs, []byte(" = ")) +} + +// Returns a representation of "a=b", where b is a literal +func EqL(lhs Expression, val interface{}) BoolExpression { + return Eq(lhs, Literal(val)) +} + +// Returns a representation of "a!=b" +func Neq(lhs, rhs Expression) BoolExpression { + lit, ok := rhs.(*literalExpression) + if ok && sqltypes.Value(lit.value).IsNull() { + return newBoolExpression(lhs, rhs, []byte(" IS NOT ")) + } + return newBoolExpression(lhs, rhs, []byte("!=")) +} + +// Returns a representation of "a!=b", where b is a literal +func NeqL(lhs Expression, val interface{}) BoolExpression { + return Neq(lhs, Literal(val)) +} + +// Returns a representation of "ab" +func Gt(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(">")) +} + +// Returns a representation of "a>b", where b is a literal +func GtL(lhs Expression, val interface{}) BoolExpression { + return Gt(lhs, Literal(val)) +} + +// Returns a representation of "a>=b" +func Gte(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(">=")) +} + +// Returns a representation of "a>=b", where b is a literal +func GteL(lhs Expression, val interface{}) BoolExpression { + return Gte(lhs, Literal(val)) +} + +// Returns a representation of "not expr" +func Not(expr BoolExpression) BoolExpression { + return &negateExpression{ + nested: expr, + } +} + +// Returns a representation of "c[0] AND ... AND c[n-1]" for c in clauses +func And(expressions ...BoolExpression) BoolExpression { + return &conjunctExpression{ + expressions: expressions, + conjunction: []byte(" AND "), + } +} + +// Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses +func Or(expressions ...BoolExpression) BoolExpression { + return &conjunctExpression{ + expressions: expressions, + conjunction: []byte(" OR "), + } +} + +func Like(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(" LIKE ")) +} + +func LikeL(lhs Expression, val string) BoolExpression { + return Like(lhs, Literal(val)) +} + +func Regexp(lhs, rhs Expression) BoolExpression { + return newBoolExpression(lhs, rhs, []byte(" REGEXP ")) +} + +func RegexpL(lhs Expression, val string) BoolExpression { + return Regexp(lhs, Literal(val)) +} + +// Returns a representation of "a IN (b[0], ..., b[n-1])", where b is a list +// of literals valList must be a slice type +func In(lhs Expression, valList interface{}) BoolExpression { + var clauses []Clause + switch val := valList.(type) { + // This atrocious body of copy-paste code is due to the fact that if you + // try to merge the cases, you can't treat val as a list + case []int: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []int32: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []int64: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []uint: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []uint32: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []uint64: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []float64: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []string: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case [][]byte: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []time.Time: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.Numeric: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.Fractional: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.String: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + case []sqltypes.Value: + clauses = make([]Clause, 0, len(val)) + for _, v := range val { + clauses = append(clauses, Literal(v)) + } + default: + return &inExpression{ + err: errors.Newf( + "Unknown value list type in IN clause: %s", + reflect.TypeOf(valList)), + } + } + + expr := &inExpression{lhs: lhs} + if len(clauses) > 0 { + expr.rhs = &listClause{clauses: clauses, includeParentheses: true} + } + return expr +} + +type boolExpressionImpl struct { + isExpression + isBoolExpression +} + +func (c *boolExpressionImpl) And(expression BoolExpression) BoolExpression { + return And(c, expression) +} + +func (c *boolExpressionImpl) Or(expression BoolExpression) BoolExpression { + return Or(c, expression) +} + +func (conj *boolExpressionImpl) SerializeSql(out *bytes.Buffer) (err error) { + return errors.New("Not implemented") +} + +// Representation of n-ary conjunctions (AND/OR) +type conjunctExpression struct { + boolExpressionImpl + expressions []BoolExpression + conjunction []byte +} + +func (conj *conjunctExpression) SerializeSql(out *bytes.Buffer) (err error) { + if len(conj.expressions) == 0 { + return errors.Newf( + "Empty conjunction. Generated sql: %s", + out.String()) + } + + clauses := make([]Clause, len(conj.expressions), len(conj.expressions)) + for i, expr := range conj.expressions { + clauses[i] = expr + } + + useParentheses := len(clauses) > 1 + if useParentheses { + _ = out.WriteByte('(') + } + + if err = serializeClauses(clauses, conj.conjunction, out); err != nil { + return + } + + if useParentheses { + _ = out.WriteByte(')') + } + + return nil +} + +// A not expression which negates a expression value +type negateExpression struct { + boolExpressionImpl + + nested BoolExpression +} + +func (c *negateExpression) SerializeSql(out *bytes.Buffer) (err error) { + _, _ = out.WriteString("NOT (") + + if c.nested == nil { + return errors.Newf("nil nested. Generated sql: %s", out.String()) + } + if err = c.nested.SerializeSql(out); err != nil { + return + } + + _ = out.WriteByte(')') + return nil +} + +// A binary expression that evaluates to a boolean value. +type boolBinaryExpression struct { + boolExpressionImpl + binaryExpression binaryExpression +} + +func (b *boolBinaryExpression) And(expression BoolExpression) BoolExpression { + return And(b, expression) +} + +func newBoolExpression(lhs, rhs Expression, operator []byte) *boolBinaryExpression { + // go does not allow {} syntax for initializing promoted fields ... + expr := new(boolBinaryExpression) + expr.binaryExpression.lhs = lhs + expr.binaryExpression.rhs = rhs + expr.binaryExpression.operator = operator + return expr +} + +func (b *boolBinaryExpression) SerializeSql(out *bytes.Buffer) (err error) { + return b.binaryExpression.SerializeSql(out) +} + +// in expression representation +type inExpression struct { + boolExpressionImpl + + lhs Expression + rhs *listClause + + err error +} + +func (c *inExpression) SerializeSql(out *bytes.Buffer) error { + if c.err != nil { + return errors.Wrap(c.err, "Invalid IN expression") + } + + if c.lhs == nil { + return errors.Newf( + "lhs of in expression is nil. Generated sql: %s", + out.String()) + } + + // We'll serialize the lhs even if we don't need it to ensure no error + buf := &bytes.Buffer{} + + err := c.lhs.SerializeSql(buf) + if err != nil { + return err + } + + if c.rhs == nil { + _, _ = out.WriteString("FALSE") + return nil + } + + _, _ = out.WriteString(buf.String()) + _, _ = out.WriteString(" IN ") + + err = c.rhs.SerializeSql(out) + if err != nil { + return err + } + + return nil +} diff --git a/sqlbuilder/expression.go b/sqlbuilder/expression.go index b42084b..3e426fb 100644 --- a/sqlbuilder/expression.go +++ b/sqlbuilder/expression.go @@ -3,7 +3,6 @@ package sqlbuilder import ( "bytes" - "reflect" "strconv" "strings" "time" @@ -87,42 +86,6 @@ func serializeClauses( return nil } -// Representation of n-ary conjunctions (AND/OR) -type conjunctExpression struct { - isExpression - isBoolExpression - expressions []BoolExpression - conjunction []byte -} - -func (conj *conjunctExpression) SerializeSql(out *bytes.Buffer) (err error) { - if len(conj.expressions) == 0 { - return errors.Newf( - "Empty conjunction. Generated sql: %s", - out.String()) - } - - clauses := make([]Clause, len(conj.expressions), len(conj.expressions)) - for i, expr := range conj.expressions { - clauses[i] = expr - } - - useParentheses := len(clauses) > 1 - if useParentheses { - _ = out.WriteByte('(') - } - - if err = serializeClauses(clauses, conj.conjunction, out); err != nil { - return - } - - if useParentheses { - _ = out.WriteByte(')') - } - - return nil -} - // Representation of n-ary arithmetic (+ - * /) type arithmeticExpression struct { isExpression @@ -204,35 +167,6 @@ func (list *listClause) SerializeSql(out *bytes.Buffer) error { return nil } -// A not expression which negates a expression value -type negateExpression struct { - isExpression - isBoolExpression - - nested BoolExpression -} - -func (c *negateExpression) SerializeSql(out *bytes.Buffer) (err error) { - _, _ = out.WriteString("NOT (") - - if c.nested == nil { - return errors.Newf("nil nested. Generated sql: %s", out.String()) - } - if err = c.nested.SerializeSql(out); err != nil { - return - } - - _ = out.WriteByte(')') - return nil -} - -// Returns a representation of "not expr" -func Not(expr BoolExpression) BoolExpression { - return &negateExpression{ - nested: expr, - } -} - // Representation of binary operations (e.g. comparisons, arithmetic) type binaryExpression struct { isExpression @@ -260,21 +194,6 @@ func (c *binaryExpression) SerializeSql(out *bytes.Buffer) (err error) { return nil } -// A binary expression that evaluates to a boolean value. -type boolExpression struct { - isBoolExpression - binaryExpression -} - -func newBoolExpression(lhs, rhs Expression, operator []byte) *boolExpression { - // go does not allow {} syntax for initializing promoted fields ... - expr := new(boolExpression) - expr.lhs = lhs - expr.rhs = rhs - expr.operator = operator - return expr -} - type funcExpression struct { isExpression funcName string @@ -373,38 +292,6 @@ func Literal(v interface{}) Expression { return &literalExpression{value: value} } -// Returns a representation of "c[0] AND ... AND c[n-1]" for c in clauses -func And(expressions ...BoolExpression) BoolExpression { - return &conjunctExpression{ - expressions: expressions, - conjunction: []byte(" AND "), - } -} - -// Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses -func Or(expressions ...BoolExpression) BoolExpression { - return &conjunctExpression{ - expressions: expressions, - conjunction: []byte(" OR "), - } -} - -func Like(lhs, rhs Expression) BoolExpression { - return newBoolExpression(lhs, rhs, []byte(" LIKE ")) -} - -func LikeL(lhs Expression, val string) BoolExpression { - return Like(lhs, Literal(val)) -} - -func Regexp(lhs, rhs Expression) BoolExpression { - return newBoolExpression(lhs, rhs, []byte(" REGEXP ")) -} - -func RegexpL(lhs Expression, val string) BoolExpression { - return Regexp(lhs, Literal(val)) -} - // Returns a representation of "c[0] + ... + c[n-1]" for c in clauses func Add(expressions ...Expression) Expression { return &arithmeticExpression{ @@ -437,74 +324,6 @@ func Div(expressions ...Expression) Expression { } } -// Returns a representation of "a=b" -func Eq(lhs, rhs Expression) BoolExpression { - lit, ok := rhs.(*literalExpression) - if ok && sqltypes.Value(lit.value).IsNull() { - return newBoolExpression(lhs, rhs, []byte(" IS ")) - } - return newBoolExpression(lhs, rhs, []byte(" = ")) -} - -// Returns a representation of "a=b", where b is a literal -func EqL(lhs Expression, val interface{}) BoolExpression { - return Eq(lhs, Literal(val)) -} - -// Returns a representation of "a!=b" -func Neq(lhs, rhs Expression) BoolExpression { - lit, ok := rhs.(*literalExpression) - if ok && sqltypes.Value(lit.value).IsNull() { - return newBoolExpression(lhs, rhs, []byte(" IS NOT ")) - } - return newBoolExpression(lhs, rhs, []byte("!=")) -} - -// Returns a representation of "a!=b", where b is a literal -func NeqL(lhs Expression, val interface{}) BoolExpression { - return Neq(lhs, Literal(val)) -} - -// Returns a representation of "ab" -func Gt(lhs, rhs Expression) BoolExpression { - return newBoolExpression(lhs, rhs, []byte(">")) -} - -// Returns a representation of "a>b", where b is a literal -func GtL(lhs Expression, val interface{}) BoolExpression { - return Gt(lhs, Literal(val)) -} - -// Returns a representation of "a>=b" -func Gte(lhs, rhs Expression) BoolExpression { - return newBoolExpression(lhs, rhs, []byte(">=")) -} - -// Returns a representation of "a>=b", where b is a literal -func GteL(lhs Expression, val interface{}) BoolExpression { - return Gte(lhs, Literal(val)) -} - func BitOr(lhs, rhs Expression) Expression { return &binaryExpression{ lhs: lhs, @@ -545,144 +364,6 @@ func Minus(lhs, rhs Expression) Expression { } } -// in expression representation -type inExpression struct { - isExpression - isBoolExpression - - lhs Expression - rhs *listClause - - err error -} - -func (c *inExpression) SerializeSql(out *bytes.Buffer) error { - if c.err != nil { - return errors.Wrap(c.err, "Invalid IN expression") - } - - if c.lhs == nil { - return errors.Newf( - "lhs of in expression is nil. Generated sql: %s", - out.String()) - } - - // We'll serialize the lhs even if we don't need it to ensure no error - buf := &bytes.Buffer{} - - err := c.lhs.SerializeSql(buf) - if err != nil { - return err - } - - if c.rhs == nil { - _, _ = out.WriteString("FALSE") - return nil - } - - _, _ = out.WriteString(buf.String()) - _, _ = out.WriteString(" IN ") - - err = c.rhs.SerializeSql(out) - if err != nil { - return err - } - - return nil -} - -// Returns a representation of "a IN (b[0], ..., b[n-1])", where b is a list -// of literals valList must be a slice type -func In(lhs Expression, valList interface{}) BoolExpression { - var clauses []Clause - switch val := valList.(type) { - // This atrocious body of copy-paste code is due to the fact that if you - // try to merge the cases, you can't treat val as a list - case []int: - clauses = make([]Clause, 0, len(val)) - for _, v := range val { - clauses = append(clauses, Literal(v)) - } - case []int32: - clauses = make([]Clause, 0, len(val)) - for _, v := range val { - clauses = append(clauses, Literal(v)) - } - case []int64: - clauses = make([]Clause, 0, len(val)) - for _, v := range val { - clauses = append(clauses, Literal(v)) - } - case []uint: - clauses = make([]Clause, 0, len(val)) - for _, v := range val { - clauses = append(clauses, Literal(v)) - } - case []uint32: - clauses = make([]Clause, 0, len(val)) - for _, v := range val { - clauses = append(clauses, Literal(v)) - } - case []uint64: - clauses = make([]Clause, 0, len(val)) - for _, v := range val { - clauses = append(clauses, Literal(v)) - } - case []float64: - clauses = make([]Clause, 0, len(val)) - for _, v := range val { - clauses = append(clauses, Literal(v)) - } - case []string: - clauses = make([]Clause, 0, len(val)) - for _, v := range val { - clauses = append(clauses, Literal(v)) - } - case [][]byte: - clauses = make([]Clause, 0, len(val)) - for _, v := range val { - clauses = append(clauses, Literal(v)) - } - case []time.Time: - clauses = make([]Clause, 0, len(val)) - for _, v := range val { - clauses = append(clauses, Literal(v)) - } - case []sqltypes.Numeric: - clauses = make([]Clause, 0, len(val)) - for _, v := range val { - clauses = append(clauses, Literal(v)) - } - case []sqltypes.Fractional: - clauses = make([]Clause, 0, len(val)) - for _, v := range val { - clauses = append(clauses, Literal(v)) - } - case []sqltypes.String: - clauses = make([]Clause, 0, len(val)) - for _, v := range val { - clauses = append(clauses, Literal(v)) - } - case []sqltypes.Value: - clauses = make([]Clause, 0, len(val)) - for _, v := range val { - clauses = append(clauses, Literal(v)) - } - default: - return &inExpression{ - err: errors.Newf( - "Unknown value list type in IN clause: %s", - reflect.TypeOf(valList)), - } - } - - expr := &inExpression{lhs: lhs} - if len(clauses) > 0 { - expr.rhs = &listClause{clauses: clauses, includeParentheses: true} - } - return expr -} - type ifExpression struct { isExpression conditional BoolExpression diff --git a/sqlbuilder/statement.go b/sqlbuilder/statement.go index 5a51a4e..471c0bc 100644 --- a/sqlbuilder/statement.go +++ b/sqlbuilder/statement.go @@ -25,10 +25,10 @@ type SelectStatement interface { GroupBy(expressions ...Expression) SelectStatement OrderBy(clauses ...OrderByClause) SelectStatement Limit(limit int64) SelectStatement + Offset(offset int64) SelectStatement Distinct() SelectStatement WithSharedLock() SelectStatement ForUpdate() SelectStatement - Offset(offset int64) SelectStatement Comment(comment string) SelectStatement Copy() SelectStatement } diff --git a/sqlbuilder/types.go b/sqlbuilder/types.go index 82a6452..271f130 100644 --- a/sqlbuilder/types.go +++ b/sqlbuilder/types.go @@ -23,6 +23,9 @@ type Expression interface { type BoolExpression interface { Clause isBoolExpressionInterface + + And(expression BoolExpression) BoolExpression + Or(expression BoolExpression) BoolExpression } // A clause that is selectable. diff --git a/tests/generator_test.go b/tests/generator_test.go index a0491de..ee97f84 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -4,7 +4,6 @@ import ( "database/sql" "fmt" "github.com/sub0Zero/go-sqlbuilder/generator" - "github.com/sub0Zero/go-sqlbuilder/sqlbuilder" "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model" . "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table" "gotest.tools/assert" @@ -138,7 +137,7 @@ func TestJoinQueryStruct(t *testing.T) { InnerJoinUsing(Film, FilmActor.FilmID, Film.FilmID). InnerJoinUsing(Language, Film.LanguageID, Language.LanguageID). Select(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns). - Where(sqlbuilder.And(FilmActor.ActorID.GteLiteral(1), FilmActor.ActorID.LteLiteral(2))) + Where(FilmActor.ActorID.GteLiteral(1).And(FilmActor.ActorID.LteLiteral(2))) queryStr, err := query.String() assert.NilError(t, err)