Add support for CASE operator.

This commit is contained in:
zer0sub 2019-05-06 12:42:15 +02:00
parent 3367df247c
commit 4f9323ddca
18 changed files with 243 additions and 272 deletions

View file

@ -8,10 +8,10 @@ type BoolExpression interface {
GtEq(rhs Expression) BoolExpression
LtEq(rhs Expression) BoolExpression
And(expression BoolExpression) BoolExpression
Or(expression BoolExpression) BoolExpression
IsTrue() BoolExpression
IsFalse() BoolExpression
AND(expression BoolExpression) BoolExpression
OR(expression BoolExpression) BoolExpression
IS_TRUE() BoolExpression
IS_FALSE() BoolExpression
}
type boolInterfaceImpl struct {
@ -34,18 +34,18 @@ func (b *boolInterfaceImpl) LtEq(rhs Expression) BoolExpression {
return LtEq(b.parent, rhs)
}
func (b *boolInterfaceImpl) And(expression BoolExpression) BoolExpression {
func (b *boolInterfaceImpl) AND(expression BoolExpression) BoolExpression {
return And(b.parent, expression)
}
func (b *boolInterfaceImpl) Or(expression BoolExpression) BoolExpression {
func (b *boolInterfaceImpl) OR(expression BoolExpression) BoolExpression {
return Or(b.parent, expression)
}
func (b *boolInterfaceImpl) IsTrue() BoolExpression {
func (b *boolInterfaceImpl) IS_TRUE() BoolExpression {
return IsTrue(b.parent)
}
func (b *boolInterfaceImpl) IsFalse() BoolExpression {
func (b *boolInterfaceImpl) IS_FALSE() BoolExpression {
return nil
}
@ -106,7 +106,7 @@ func EXISTS(subQuery SelectStatement) BoolExpression {
// Returns a representation of "a=b"
func Eq(lhs, rhs Expression) BoolExpression {
return newBinaryBoolExpression(lhs, rhs, " = ")
return newBinaryBoolExpression(lhs, rhs, "=")
}
// Returns a representation of "a=b", where b is a literal
@ -166,24 +166,24 @@ func GteL(lhs Expression, val interface{}) BoolExpression {
// Returns a representation of "not expr"
func Not(expr BoolExpression) BoolExpression {
return newPrefixBoolExpression(expr, " NOT")
return newPrefixBoolExpression(expr, "NOT")
}
func IsTrue(expr BoolExpression) BoolExpression {
return newPrefixBoolExpression(expr, " IS TRUE")
return newPrefixBoolExpression(expr, "IS TRUE")
}
func And(lhs, rhs Expression) BoolExpression {
return newBinaryBoolExpression(lhs, rhs, " AND ")
return newBinaryBoolExpression(lhs, rhs, "AND")
}
// Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses
func Or(lhs, rhs Expression) BoolExpression {
return newBinaryBoolExpression(lhs, rhs, " OR ")
return newBinaryBoolExpression(lhs, rhs, "OR")
}
func Like(lhs, rhs Expression) BoolExpression {
return newBinaryBoolExpression(lhs, rhs, " LIKE ")
return newBinaryBoolExpression(lhs, rhs, "LIKE")
}
func LikeL(lhs Expression, val string) BoolExpression {
@ -191,7 +191,7 @@ func LikeL(lhs Expression, val string) BoolExpression {
}
func Regexp(lhs, rhs Expression) BoolExpression {
return newBinaryBoolExpression(lhs, rhs, " REGEXP ")
return newBinaryBoolExpression(lhs, rhs, "REGEXP")
}
func RegexpL(lhs Expression, val string) BoolExpression {

View file

@ -27,7 +27,7 @@ func TestBinaryExpression(t *testing.T) {
})
t.Run("and", func(t *testing.T) {
exp := boolExpression.And(Eq(Literal(4), Literal(5)))
exp := boolExpression.AND(Eq(Literal(4), Literal(5)))
out := queryData{}
err := exp.Serialize(&out)
@ -37,7 +37,7 @@ func TestBinaryExpression(t *testing.T) {
})
t.Run("or", func(t *testing.T) {
exp := boolExpression.Or(Eq(Literal(4), Literal(5)))
exp := boolExpression.OR(Eq(Literal(4), Literal(5)))
out := queryData{}
err := exp.Serialize(&out)
@ -54,7 +54,7 @@ func TestUnaryExpression(t *testing.T) {
err := notExpression.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), " NOT $1 = $2")
assert.Equal(t, out.buff.String(), "NOT $1 = $2")
t.Run("alias", func(t *testing.T) {
alias := notExpression.AS("alias_not_expression")
@ -63,17 +63,17 @@ func TestUnaryExpression(t *testing.T) {
err := alias.SerializeForProjection(&out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), ` NOT $1 = $2 AS "alias_not_expression"`)
assert.Equal(t, out.buff.String(), `NOT $1 = $2 AS "alias_not_expression"`)
})
t.Run("and", func(t *testing.T) {
exp := notExpression.And(Eq(Literal(4), Literal(5)))
exp := notExpression.AND(Eq(Literal(4), Literal(5)))
out := queryData{}
err := exp.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `( NOT $1 = $2 AND $3 = $4)`)
assert.Equal(t, out.buff.String(), `(NOT $1 = $2 AND $3 = $4)`)
})
}
@ -84,16 +84,16 @@ func TestUnaryIsTrueExpression(t *testing.T) {
err := notExpression.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), " IS TRUE $1 = $2")
assert.Equal(t, out.buff.String(), "IS TRUE $1 = $2")
t.Run("and", func(t *testing.T) {
exp := notExpression.And(Eq(Literal(4), Literal(5)))
exp := notExpression.AND(Eq(Literal(4), Literal(5)))
out := queryData{}
err := exp.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `( IS TRUE $1 = $2 AND $3 = $4)`)
assert.Equal(t, out.buff.String(), `(IS TRUE $1 = $2 AND $3 = $4)`)
})
}

View file

@ -10,8 +10,6 @@ type serializeOption int
const (
FOR_PROJECTION = iota
UNION_ORDER_BY
NO_TABLE_NAME
)
type Clause interface {

View file

@ -487,7 +487,14 @@ func initializePtrValue(value reflect.Value) {
}
func getCellValue(scanContext *scanContext, tableName, fieldName string) interface{} {
columnName := tableName + "." + snaker.CamelToSnake(fieldName)
columnName := ""
if tableName == "" {
columnName = snaker.CamelToSnake(fieldName)
} else {
columnName = tableName + "." + snaker.CamelToSnake(fieldName)
}
//columnName := snaker.CamelToSnake(fieldName)
////fmt.Println(columnName)

View file

@ -24,11 +24,11 @@ type expressionInterfaceImpl struct {
}
func (e *expressionInterfaceImpl) IN(subQuery SelectStatement) BoolExpression {
return newBinaryBoolExpression(e.parent, subQuery, " IN ")
return newBinaryBoolExpression(e.parent, subQuery, "IN")
}
func (e *expressionInterfaceImpl) NOT_IN(subQuery SelectStatement) BoolExpression {
return newBinaryBoolExpression(e.parent, subQuery, " NOT_IN ")
return newBinaryBoolExpression(e.parent, subQuery, "NOT_IN")
}
func (e *expressionInterfaceImpl) AS(alias string) Projection {
@ -103,7 +103,7 @@ func (c *binaryExpression) Serialize(out *queryData, options ...serializeOption)
return err
}
out.WriteString(c.operator)
out.WriteString(" " + c.operator + " ")
if err := c.rhs.Serialize(out); err != nil {
return err

View file

@ -7,69 +7,6 @@ import (
"time"
)
// Representation of a tuple enclosed, comma separated list of clauses
//type listClause struct {
// clauses []Clause
// includeParentheses bool
//}
//
//func (list *listClause) Serialize(out *queryData, options ...serializeOption) error {
// if list.includeParentheses {
// out.WriteByte('(')
// }
//
// if err := serializeClauseList(list.clauses, out); err != nil {
// return err
// }
//
// if list.includeParentheses {
// out.WriteByte(')')
// }
// return nil
//}
//
//type funcExpression struct {
// expressionInterfaceImpl
// funcName string
// args *listClause
//}
//
//func (c *funcExpression) Serialize(out *queryData, options ...serializeOption) error {
// if !validIdentifierName(c.funcName) {
// return errors.Newf(
// "Invalid function name: %s.",
// c.funcName,
// out.String())
// }
// _, _ = out.WriteString(c.funcName)
// if c.args == nil {
// _, _ = out.WriteString("()")
// } else {
// return c.args.Serialize(out)
// }
// return nil
//}
//
//// Returns a representation of sql function call "func_call(c[0], ..., c[n-1])
//func SqlFunc(funcName string, expressions ...Expression) Expression {
// f := &funcExpression{
// funcName: funcName,
// }
// if len(expressions) > 0 {
// args := make([]Clause, len(expressions), len(expressions))
// for i, expr := range expressions {
// args[i] = expr
// }
//
// f.args = &listClause{
// clauses: args,
// includeParentheses: true,
// }
// }
// return f
//}
type intervalExpression struct {
expressionInterfaceImpl
duration time.Duration
@ -118,137 +55,3 @@ var likeEscaper = strings.NewReplacer("_", "\\_", "%", "\\%")
func EscapeForLike(s string) string {
return likeEscaper.Replace(s)
}
// Returns an escaped literal string
//func Literal(v interface{}) Expression {
// value, err := sqltypes.BuildValue(v)
// if err != nil {
// panic(errors.Wrap(err, "Invalid literal value"))
// }
// return NewLiteralExpression(value)
//}
//
//// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses
//func Add(expressions ...Expression) Expression {
// return &arithmeticExpression{
// expressions: expressions,
// operator: []byte(" + "),
// }
//}
//
//// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses
//func Sub(expressions ...Expression) Expression {
// return &arithmeticExpression{
// expressions: expressions,
// operator: []byte(" - "),
// }
//}
//
//// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses
//func Mul(expressions ...Expression) Expression {
// return &arithmeticExpression{
// expressions: expressions,
// operator: []byte(" * "),
// }
//}
//
//// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses
//func Div(expressions ...Expression) Expression {
// return &arithmeticExpression{
// expressions: expressions,
// operator: []byte(" / "),
// }
//}
//TODO: Uncomment
//
//func BitOr(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" | "),
// }
//}
//
//func BitAnd(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" & "),
// }
//}
//
//func BitXor(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" ^ "),
// }
//}
//
//func Plus(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" + "),
// }
//}
//
//func Minus(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" - "),
// }
//}
type ifExpression struct {
expressionInterfaceImpl
conditional BoolExpression
trueExpression Expression
falseExpression Expression
}
func (exp *ifExpression) Serialize(out *queryData, options ...serializeOption) error {
out.WriteString("IF(")
_ = exp.conditional.Serialize(out)
out.WriteString(",")
_ = exp.trueExpression.Serialize(out)
out.WriteString(",")
_ = exp.falseExpression.Serialize(out)
out.WriteString(")")
return nil
}
// Returns a representation of an if-expression, of the form:
// IF (BOOLEAN TEST, VALUE-IF-TRUE, VALUE-IF-FALSE)
func If(conditional BoolExpression,
trueExpression Expression,
falseExpression Expression) Expression {
return &ifExpression{
conditional: conditional,
trueExpression: trueExpression,
falseExpression: falseExpression,
}
}
//TODO: Uncomment
//type columnValueExpression struct {
// isExpression
// column NonAliasColumn
//}
//
//func ColumnValue(col NonAliasColumn) Expression {
// return &columnValueExpression{
// column: col,
// }
//}
//
//func (cv *columnValueExpression) Serialize(out *bytes.Buffer) error {
// _, _ = out.WriteString("VALUES(")
// _ = cv.column.SerializeSqlForColumnList(out)
// _ = out.WriteByte(')')
// return nil
//}

View file

@ -375,22 +375,6 @@ func (s *ExprSuite) TestDesc(c *gc.C) {
c.Assert(sql, gc.Equals, "table1.col1 DESC")
}
func (s *ExprSuite) TestIf(c *gc.C) {
test := GtL(table1Col1, 1.1)
clause := If(test, table1Col1, table1Col2)
buf := &bytes.Buffer{}
err := clause.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(
sql,
gc.Equals,
"IF(table1.col1>1.1,table1.col1,table1.col2)")
}
func (s *ExprSuite) TestColumnValue(c *gc.C) {
clause := ColumnValue(table1Col1)

View file

@ -1,5 +1,7 @@
package sqlbuilder
import "errors"
type funcExpressionImpl struct {
expressionInterfaceImpl
@ -63,3 +65,98 @@ func MAX(expression NumericExpression) NumericExpression {
func SUM(expression NumericExpression) NumericExpression {
return NewNumericFunc("SUM", expression)
}
type caseInterface interface {
Expression
WHEN(condition Expression) caseInterface
THEN(then Expression) caseInterface
ELSE(els Expression) caseInterface
}
type caseExpression struct {
expressionInterfaceImpl
expression Expression
when []Expression
then []Expression
els Expression
}
func CASE(expression ...Expression) caseInterface {
caseExp := &caseExpression{}
if len(expression) == 1 {
caseExp.expression = expression[0]
}
caseExp.expressionInterfaceImpl.parent = caseExp
return caseExp
}
func (c *caseExpression) WHEN(when Expression) caseInterface {
c.when = append(c.when, when)
return c
}
func (c *caseExpression) THEN(then Expression) caseInterface {
c.then = append(c.then, then)
return c
}
func (c *caseExpression) ELSE(els Expression) caseInterface {
c.els = els
return c
}
func (c *caseExpression) Serialize(out *queryData, options ...serializeOption) error {
out.WriteString("(CASE")
if c.expression != nil {
out.WriteString(" ")
err := c.expression.Serialize(out)
if err != nil {
return err
}
}
if len(c.when) == 0 || len(c.then) == 0 {
return errors.New("Invalid case statement. There should be at least one when/then expression pair. ")
}
if len(c.when) != len(c.then) {
return errors.New("When and then expression count mismatch. ")
}
for i, when := range c.when {
out.WriteString(" WHEN ")
err := when.Serialize(out)
if err != nil {
return err
}
out.WriteString(" THEN ")
err = c.then[i].Serialize(out)
if err != nil {
return err
}
}
if c.els != nil {
out.WriteString(" ELSE ")
err := c.els.Serialize(out)
if err != nil {
return err
}
}
out.WriteString(" END)")
return nil
}

View file

@ -0,0 +1,33 @@
package sqlbuilder
import (
"gotest.tools/assert"
"testing"
)
func TestCase1(t *testing.T) {
query := CASE().
WHEN(table3Col1.EqL(1)).THEN(table3Col1.Add(IntLiteral(1))).
WHEN(table3Col1.EqL(2)).THEN(table3Col1.Add(IntLiteral(2)))
queryData := &queryData{}
err := query.Serialize(queryData)
assert.NilError(t, err)
assert.Equal(t, queryData.buff.String(), `(CASE WHEN table3.col1 = $1 THEN table3.col1 + $2 WHEN table3.col1 = $3 THEN table3.col1 + $4 END)`)
}
func TestCase2(t *testing.T) {
query := CASE(table3Col1).
WHEN(IntLiteral(1)).THEN(table3Col1.Add(IntLiteral(1))).
WHEN(IntLiteral(2)).THEN(table3Col1.Add(IntLiteral(2))).
ELSE(IntLiteral(0))
queryData := &queryData{}
err := query.Serialize(queryData)
assert.NilError(t, err)
assert.Equal(t, queryData.buff.String(), `(CASE table3.col1 WHEN $1 THEN table3.col1 + $2 WHEN $3 THEN table3.col1 + $4 ELSE $5 END)`)
}

View file

@ -20,3 +20,19 @@ func (l literalExpression) Serialize(out *queryData, options ...serializeOption)
return nil
}
type numLiteralExpression struct {
literalExpression
numericInterfaceImpl
}
func IntLiteral(value int) NumericExpression {
numLiteral := &numLiteralExpression{}
numLiteral.literalExpression = *Literal(value)
numLiteral.literalExpression.parent = numLiteral
numLiteral.numericInterfaceImpl.parent = numLiteral
return numLiteral
}

View file

@ -62,19 +62,19 @@ func (n *numericInterfaceImpl) LtEqL(literal interface{}) BoolExpression {
}
func (n *numericInterfaceImpl) Add(expression NumericExpression) NumericExpression {
return newBinaryNumericExpression(n.parent, expression, " + ")
return newBinaryNumericExpression(n.parent, expression, "+")
}
func (n *numericInterfaceImpl) Sub(expression NumericExpression) NumericExpression {
return newBinaryNumericExpression(n.parent, expression, " - ")
return newBinaryNumericExpression(n.parent, expression, "-")
}
func (n *numericInterfaceImpl) Mul(expression NumericExpression) NumericExpression {
return newBinaryNumericExpression(n.parent, expression, " * ")
return newBinaryNumericExpression(n.parent, expression, "*")
}
func (n *numericInterfaceImpl) Div(expression NumericExpression) NumericExpression {
return newBinaryNumericExpression(n.parent, expression, " / ")
return newBinaryNumericExpression(n.parent, expression, "/")
}
//---------------------------------------------------//

View file

@ -33,7 +33,7 @@ func (o *orderByClause) Serialize(out *queryData, options ...serializeOption) er
return errors.Newf("nil orderBy by clause.")
}
if err := o.expression.Serialize(out, UNION_ORDER_BY); err != nil {
if err := o.expression.Serialize(out); err != nil {
return err
}

View file

@ -482,7 +482,7 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) {
func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) {
// tests on outer statement: Group By, Order By, LIMIT
// on inner statement: AndWhere, WHERE (with And), Order By, LIMIT
// on inner statement: AndWhere, WHERE (with AND), Order By, LIMIT
select_queries := make([]SelectStatement, 0, 3)
// We're not trying to write a SQL parser, so we won't warn if you do something silly like

View file

@ -4,9 +4,9 @@ type StringExpression interface {
Expression
Eq(expression StringExpression) BoolExpression
EqL(value string) BoolExpression
EqString(value string) BoolExpression
NotEq(expression StringExpression) BoolExpression
NotEqL(value string) BoolExpression
NotEqString(value string) BoolExpression
}
type stringInterfaceImpl struct {
@ -14,17 +14,17 @@ type stringInterfaceImpl struct {
}
func (b *stringInterfaceImpl) Eq(expression StringExpression) BoolExpression {
return newBinaryBoolExpression(b.parent, expression, " = ")
return Eq(b.parent, expression)
}
func (b *stringInterfaceImpl) EqL(value string) BoolExpression {
return newBinaryBoolExpression(b.parent, Literal(value), " = ")
func (b *stringInterfaceImpl) EqString(value string) BoolExpression {
return EqL(b.parent, value)
}
func (b *stringInterfaceImpl) NotEq(expression StringExpression) BoolExpression {
return newBinaryBoolExpression(b.parent, expression, " != ")
return NotEq(b.parent, expression)
}
func (b *stringInterfaceImpl) NotEqL(value string) BoolExpression {
return newBinaryBoolExpression(b.parent, Literal(value), " != ")
func (b *stringInterfaceImpl) NotEqString(value string) BoolExpression {
return NotEq(b.parent, Literal(value))
}

View file

@ -42,7 +42,7 @@ func TestUpdate(t *testing.T) {
//
//func (s *StmtSuite) TestUpdateSingleValue(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, Literal(1))
// stmt.WHERE(EqL(table1Col2, 2))
// stmt.WHERE(EqString(table1Col2, 2))
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//
@ -54,7 +54,7 @@ func TestUpdate(t *testing.T) {
//
//func (s *StmtSuite) TestUpdateUsingDeferredLookupColumns(c *gc.C) {
// stmt := table1.UPDATE().SET(table1.C("col1"), Literal(1))
// stmt.WHERE(EqL(table1Col2, 2))
// stmt.WHERE(EqString(table1Col2, 2))
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//
@ -68,7 +68,7 @@ func TestUpdate(t *testing.T) {
// stmt := table1.UPDATE()
// stmt.SET(table1Col1, Literal(1))
// stmt.SET(table1Col2, Literal(2))
// stmt.WHERE(EqL(table1Col2, 3))
// stmt.WHERE(EqString(table1Col2, 3))
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//
@ -82,7 +82,7 @@ func TestUpdate(t *testing.T) {
//
//func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, Literal(1))
// stmt.WHERE(EqL(table1Col2, 2))
// stmt.WHERE(EqString(table1Col2, 2))
// stmt.ORDER_BY(table1Col2)
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
@ -98,7 +98,7 @@ func TestUpdate(t *testing.T) {
//
//func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, Literal(1))
// stmt.WHERE(EqL(table1Col2, 2))
// stmt.WHERE(EqString(table1Col2, 2))
// stmt.LIMIT(5)
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)

View file

@ -12,7 +12,7 @@ import (
func TestUUIDType(t *testing.T) {
query := table.AllTypes.
SELECT(table.AllTypes.AllColumns).
WHERE(table.AllTypes.UUID.EqL("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
WHERE(table.AllTypes.UUID.EqString("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
queryStr, args, err := query.Sql()

View file

@ -2,6 +2,7 @@ package tests
import (
"fmt"
"github.com/davecgh/go-spew/spew"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder"
"github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model"
. "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table"
@ -101,7 +102,7 @@ func TestSelectAndUnionInProjection(t *testing.T) {
// INNER_JOIN(Film, FilmActor.FilmID.Eq(Film.FilmID)).
// INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
// SELECT(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns).
// WHERE(FilmActor.ActorID.GtEq(1).And(FilmActor.ActorID.LteLiteral(2)))
// WHERE(FilmActor.ActorID.GtEq(1).AND(FilmActor.ActorID.LteLiteral(2)))
//
// queryStr, args, err := query.Sql()
// assert.NilError(t, err)
@ -131,7 +132,7 @@ func TestJoinQuerySlice(t *testing.T) {
query := Film.
INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
SELECT(Language.AllColumns, Film.AllColumns).
WHERE(Film.Rating.EqL(string(model.MpaaRating_NC17))).
WHERE(Film.Rating.EqString(string(model.MpaaRating_NC17))).
LIMIT(15)
queryStr, args, err := query.Sql()
@ -317,7 +318,7 @@ func TestSelectSelfJoin(t *testing.T) {
f2 := Film.AS("f2")
query := f1.
INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).And(f1.Length.Eq(f2.Length))).
INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).AND(f1.Length.Eq(f2.Length))).
SELECT(f1.AllColumns, f2.AllColumns).
ORDER_BY(f1.FilmID.ASC())
@ -356,7 +357,7 @@ func TestSelectAliasColumn(t *testing.T) {
}
query := f1.
INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).And(f1.Length.Eq(f2.Length))).
INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).AND(f1.Length.Eq(f2.Length))).
SELECT(f1.Title.AS("thesame_length_films.title1"),
f2.Title.AS("thesame_length_films.title2"),
f1.Length.AS("thesame_length_films.length")).
@ -443,7 +444,7 @@ func TestSubQuery(t *testing.T) {
// WHERE(Actor.LastName.Neq(avrgCustomer))
rFilmsOnly := Film.SELECT(Film.FilmID, Film.Title, Film.Rating).
WHERE(Film.Rating.EqL("R")).
WHERE(Film.Rating.EqString("R")).
AsTable("films")
query := Actor.INNER_JOIN(FilmActor, Actor.ActorID.Eq(FilmActor.FilmID)).
@ -532,7 +533,7 @@ func TestSelectGroupByHaving(t *testing.T) {
assert.NilError(t, err)
fmt.Println(queryStr)
assert.Equal(t, len(args), 1)
assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id", SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount)>$1 ORDER BY SUM(payment.amount) ASC`)
assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id", SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount) > $1 ORDER BY SUM(payment.amount) ASC`)
type CustomerPaymentSum struct {
CustomerID int16
@ -666,6 +667,38 @@ func TestUnion(t *testing.T) {
})
}
func TestSelectWithCase(t *testing.T) {
query := Payment.SELECT(
sqlbuilder.CASE(Payment.StaffID).
WHEN(sqlbuilder.IntLiteral(1)).THEN(sqlbuilder.Literal("ONE")).
WHEN(sqlbuilder.IntLiteral(2)).THEN(sqlbuilder.Literal("TWO")).
WHEN(sqlbuilder.IntLiteral(3)).THEN(sqlbuilder.Literal("THREE")).
ELSE(sqlbuilder.Literal("OTHER")).AS("staff_id_num"),
).
ORDER_BY(Payment.PaymentID.ASC()).
LIMIT(20)
queryStr, args, err := query.Sql()
assert.NilError(t, err)
fmt.Println(queryStr)
fmt.Println(args)
dest := []struct {
StaffIdNum string
}{}
err = query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 20)
assert.Equal(t, dest[0].StaffIdNum, "TWO")
assert.Equal(t, dest[1].StaffIdNum, "ONE")
spew.Dump(dest)
}
func int16Ptr(i int16) *int16 {
return &i
}

View file

@ -21,7 +21,7 @@ func TestUpdateValues(t *testing.T) {
query := table.Link.
UPDATE(table.Link.Name, table.Link.URL).
SET("Bong", "http://bong.com").
WHERE(table.Link.Name.EqL("Bing"))
WHERE(table.Link.Name.EqString("Bing"))
queryStr, args, err := query.Sql()
@ -38,7 +38,7 @@ func TestUpdateValues(t *testing.T) {
links := []model.Link{}
err = table.Link.SELECT(table.Link.AllColumns).
WHERE(table.Link.Name.EqL("Bong")).
WHERE(table.Link.Name.EqString("Bong")).
Query(db, &links)
assert.NilError(t, err)
@ -60,7 +60,7 @@ func TestUpdateAndReturning(t *testing.T) {
stmt := table.Link.
UPDATE(table.Link.Name, table.Link.URL).
SET("DuckDuckGo", "http://www.duckduckgo.com").
WHERE(table.Link.Name.EqL("Ask")).
WHERE(table.Link.Name.EqString("Ask")).
RETURNING(table.Link.AllColumns)
stmtStr, args, err := stmt.Sql()