Add support for CASE operator.

This commit is contained in:
zer0sub 2019-05-06 12:42:15 +02:00
parent 3367df247c
commit 4f9323ddca
18 changed files with 243 additions and 272 deletions

View file

@ -8,10 +8,10 @@ type BoolExpression interface {
GtEq(rhs Expression) BoolExpression GtEq(rhs Expression) BoolExpression
LtEq(rhs Expression) BoolExpression LtEq(rhs Expression) BoolExpression
And(expression BoolExpression) BoolExpression AND(expression BoolExpression) BoolExpression
Or(expression BoolExpression) BoolExpression OR(expression BoolExpression) BoolExpression
IsTrue() BoolExpression IS_TRUE() BoolExpression
IsFalse() BoolExpression IS_FALSE() BoolExpression
} }
type boolInterfaceImpl struct { type boolInterfaceImpl struct {
@ -34,18 +34,18 @@ func (b *boolInterfaceImpl) LtEq(rhs Expression) BoolExpression {
return LtEq(b.parent, rhs) return LtEq(b.parent, rhs)
} }
func (b *boolInterfaceImpl) And(expression BoolExpression) BoolExpression { func (b *boolInterfaceImpl) AND(expression BoolExpression) BoolExpression {
return And(b.parent, expression) return And(b.parent, expression)
} }
func (b *boolInterfaceImpl) Or(expression BoolExpression) BoolExpression { func (b *boolInterfaceImpl) OR(expression BoolExpression) BoolExpression {
return Or(b.parent, expression) return Or(b.parent, expression)
} }
func (b *boolInterfaceImpl) IsTrue() BoolExpression { func (b *boolInterfaceImpl) IS_TRUE() BoolExpression {
return IsTrue(b.parent) return IsTrue(b.parent)
} }
func (b *boolInterfaceImpl) IsFalse() BoolExpression { func (b *boolInterfaceImpl) IS_FALSE() BoolExpression {
return nil return nil
} }
@ -106,7 +106,7 @@ func EXISTS(subQuery SelectStatement) BoolExpression {
// Returns a representation of "a=b" // Returns a representation of "a=b"
func Eq(lhs, rhs Expression) BoolExpression { func Eq(lhs, rhs Expression) BoolExpression {
return newBinaryBoolExpression(lhs, rhs, " = ") return newBinaryBoolExpression(lhs, rhs, "=")
} }
// Returns a representation of "a=b", where b is a literal // Returns a representation of "a=b", where b is a literal
@ -166,24 +166,24 @@ func GteL(lhs Expression, val interface{}) BoolExpression {
// Returns a representation of "not expr" // Returns a representation of "not expr"
func Not(expr BoolExpression) BoolExpression { func Not(expr BoolExpression) BoolExpression {
return newPrefixBoolExpression(expr, " NOT") return newPrefixBoolExpression(expr, "NOT")
} }
func IsTrue(expr BoolExpression) BoolExpression { func IsTrue(expr BoolExpression) BoolExpression {
return newPrefixBoolExpression(expr, " IS TRUE") return newPrefixBoolExpression(expr, "IS TRUE")
} }
func And(lhs, rhs Expression) BoolExpression { func And(lhs, rhs Expression) BoolExpression {
return newBinaryBoolExpression(lhs, rhs, " AND ") return newBinaryBoolExpression(lhs, rhs, "AND")
} }
// Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses // Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses
func Or(lhs, rhs Expression) BoolExpression { func Or(lhs, rhs Expression) BoolExpression {
return newBinaryBoolExpression(lhs, rhs, " OR ") return newBinaryBoolExpression(lhs, rhs, "OR")
} }
func Like(lhs, rhs Expression) BoolExpression { func Like(lhs, rhs Expression) BoolExpression {
return newBinaryBoolExpression(lhs, rhs, " LIKE ") return newBinaryBoolExpression(lhs, rhs, "LIKE")
} }
func LikeL(lhs Expression, val string) BoolExpression { func LikeL(lhs Expression, val string) BoolExpression {
@ -191,7 +191,7 @@ func LikeL(lhs Expression, val string) BoolExpression {
} }
func Regexp(lhs, rhs Expression) BoolExpression { func Regexp(lhs, rhs Expression) BoolExpression {
return newBinaryBoolExpression(lhs, rhs, " REGEXP ") return newBinaryBoolExpression(lhs, rhs, "REGEXP")
} }
func RegexpL(lhs Expression, val string) BoolExpression { func RegexpL(lhs Expression, val string) BoolExpression {

View file

@ -27,7 +27,7 @@ func TestBinaryExpression(t *testing.T) {
}) })
t.Run("and", func(t *testing.T) { t.Run("and", func(t *testing.T) {
exp := boolExpression.And(Eq(Literal(4), Literal(5))) exp := boolExpression.AND(Eq(Literal(4), Literal(5)))
out := queryData{} out := queryData{}
err := exp.Serialize(&out) err := exp.Serialize(&out)
@ -37,7 +37,7 @@ func TestBinaryExpression(t *testing.T) {
}) })
t.Run("or", func(t *testing.T) { t.Run("or", func(t *testing.T) {
exp := boolExpression.Or(Eq(Literal(4), Literal(5))) exp := boolExpression.OR(Eq(Literal(4), Literal(5)))
out := queryData{} out := queryData{}
err := exp.Serialize(&out) err := exp.Serialize(&out)
@ -54,7 +54,7 @@ func TestUnaryExpression(t *testing.T) {
err := notExpression.Serialize(&out) err := notExpression.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), " NOT $1 = $2") assert.Equal(t, out.buff.String(), "NOT $1 = $2")
t.Run("alias", func(t *testing.T) { t.Run("alias", func(t *testing.T) {
alias := notExpression.AS("alias_not_expression") alias := notExpression.AS("alias_not_expression")
@ -63,17 +63,17 @@ func TestUnaryExpression(t *testing.T) {
err := alias.SerializeForProjection(&out) err := alias.SerializeForProjection(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), ` NOT $1 = $2 AS "alias_not_expression"`) assert.Equal(t, out.buff.String(), `NOT $1 = $2 AS "alias_not_expression"`)
}) })
t.Run("and", func(t *testing.T) { t.Run("and", func(t *testing.T) {
exp := notExpression.And(Eq(Literal(4), Literal(5))) exp := notExpression.AND(Eq(Literal(4), Literal(5)))
out := queryData{} out := queryData{}
err := exp.Serialize(&out) err := exp.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `( NOT $1 = $2 AND $3 = $4)`) assert.Equal(t, out.buff.String(), `(NOT $1 = $2 AND $3 = $4)`)
}) })
} }
@ -84,16 +84,16 @@ func TestUnaryIsTrueExpression(t *testing.T) {
err := notExpression.Serialize(&out) err := notExpression.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), " IS TRUE $1 = $2") assert.Equal(t, out.buff.String(), "IS TRUE $1 = $2")
t.Run("and", func(t *testing.T) { t.Run("and", func(t *testing.T) {
exp := notExpression.And(Eq(Literal(4), Literal(5))) exp := notExpression.AND(Eq(Literal(4), Literal(5)))
out := queryData{} out := queryData{}
err := exp.Serialize(&out) err := exp.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `( IS TRUE $1 = $2 AND $3 = $4)`) assert.Equal(t, out.buff.String(), `(IS TRUE $1 = $2 AND $3 = $4)`)
}) })
} }

View file

@ -10,8 +10,6 @@ type serializeOption int
const ( const (
FOR_PROJECTION = iota FOR_PROJECTION = iota
UNION_ORDER_BY
NO_TABLE_NAME
) )
type Clause interface { type Clause interface {

View file

@ -487,7 +487,14 @@ func initializePtrValue(value reflect.Value) {
} }
func getCellValue(scanContext *scanContext, tableName, fieldName string) interface{} { func getCellValue(scanContext *scanContext, tableName, fieldName string) interface{} {
columnName := tableName + "." + snaker.CamelToSnake(fieldName) columnName := ""
if tableName == "" {
columnName = snaker.CamelToSnake(fieldName)
} else {
columnName = tableName + "." + snaker.CamelToSnake(fieldName)
}
//columnName := snaker.CamelToSnake(fieldName) //columnName := snaker.CamelToSnake(fieldName)
////fmt.Println(columnName) ////fmt.Println(columnName)

View file

@ -24,11 +24,11 @@ type expressionInterfaceImpl struct {
} }
func (e *expressionInterfaceImpl) IN(subQuery SelectStatement) BoolExpression { func (e *expressionInterfaceImpl) IN(subQuery SelectStatement) BoolExpression {
return newBinaryBoolExpression(e.parent, subQuery, " IN ") return newBinaryBoolExpression(e.parent, subQuery, "IN")
} }
func (e *expressionInterfaceImpl) NOT_IN(subQuery SelectStatement) BoolExpression { func (e *expressionInterfaceImpl) NOT_IN(subQuery SelectStatement) BoolExpression {
return newBinaryBoolExpression(e.parent, subQuery, " NOT_IN ") return newBinaryBoolExpression(e.parent, subQuery, "NOT_IN")
} }
func (e *expressionInterfaceImpl) AS(alias string) Projection { func (e *expressionInterfaceImpl) AS(alias string) Projection {
@ -103,7 +103,7 @@ func (c *binaryExpression) Serialize(out *queryData, options ...serializeOption)
return err return err
} }
out.WriteString(c.operator) out.WriteString(" " + c.operator + " ")
if err := c.rhs.Serialize(out); err != nil { if err := c.rhs.Serialize(out); err != nil {
return err return err

View file

@ -7,69 +7,6 @@ import (
"time" "time"
) )
// Representation of a tuple enclosed, comma separated list of clauses
//type listClause struct {
// clauses []Clause
// includeParentheses bool
//}
//
//func (list *listClause) Serialize(out *queryData, options ...serializeOption) error {
// if list.includeParentheses {
// out.WriteByte('(')
// }
//
// if err := serializeClauseList(list.clauses, out); err != nil {
// return err
// }
//
// if list.includeParentheses {
// out.WriteByte(')')
// }
// return nil
//}
//
//type funcExpression struct {
// expressionInterfaceImpl
// funcName string
// args *listClause
//}
//
//func (c *funcExpression) Serialize(out *queryData, options ...serializeOption) error {
// if !validIdentifierName(c.funcName) {
// return errors.Newf(
// "Invalid function name: %s.",
// c.funcName,
// out.String())
// }
// _, _ = out.WriteString(c.funcName)
// if c.args == nil {
// _, _ = out.WriteString("()")
// } else {
// return c.args.Serialize(out)
// }
// return nil
//}
//
//// Returns a representation of sql function call "func_call(c[0], ..., c[n-1])
//func SqlFunc(funcName string, expressions ...Expression) Expression {
// f := &funcExpression{
// funcName: funcName,
// }
// if len(expressions) > 0 {
// args := make([]Clause, len(expressions), len(expressions))
// for i, expr := range expressions {
// args[i] = expr
// }
//
// f.args = &listClause{
// clauses: args,
// includeParentheses: true,
// }
// }
// return f
//}
type intervalExpression struct { type intervalExpression struct {
expressionInterfaceImpl expressionInterfaceImpl
duration time.Duration duration time.Duration
@ -118,137 +55,3 @@ var likeEscaper = strings.NewReplacer("_", "\\_", "%", "\\%")
func EscapeForLike(s string) string { func EscapeForLike(s string) string {
return likeEscaper.Replace(s) return likeEscaper.Replace(s)
} }
// Returns an escaped literal string
//func Literal(v interface{}) Expression {
// value, err := sqltypes.BuildValue(v)
// if err != nil {
// panic(errors.Wrap(err, "Invalid literal value"))
// }
// return NewLiteralExpression(value)
//}
//
//// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses
//func Add(expressions ...Expression) Expression {
// return &arithmeticExpression{
// expressions: expressions,
// operator: []byte(" + "),
// }
//}
//
//// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses
//func Sub(expressions ...Expression) Expression {
// return &arithmeticExpression{
// expressions: expressions,
// operator: []byte(" - "),
// }
//}
//
//// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses
//func Mul(expressions ...Expression) Expression {
// return &arithmeticExpression{
// expressions: expressions,
// operator: []byte(" * "),
// }
//}
//
//// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses
//func Div(expressions ...Expression) Expression {
// return &arithmeticExpression{
// expressions: expressions,
// operator: []byte(" / "),
// }
//}
//TODO: Uncomment
//
//func BitOr(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" | "),
// }
//}
//
//func BitAnd(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" & "),
// }
//}
//
//func BitXor(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" ^ "),
// }
//}
//
//func Plus(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" + "),
// }
//}
//
//func Minus(lhs, rhs Expression) Expression {
// return &binaryExpression{
// lhs: lhs,
// rhs: rhs,
// operator: []byte(" - "),
// }
//}
type ifExpression struct {
expressionInterfaceImpl
conditional BoolExpression
trueExpression Expression
falseExpression Expression
}
func (exp *ifExpression) Serialize(out *queryData, options ...serializeOption) error {
out.WriteString("IF(")
_ = exp.conditional.Serialize(out)
out.WriteString(",")
_ = exp.trueExpression.Serialize(out)
out.WriteString(",")
_ = exp.falseExpression.Serialize(out)
out.WriteString(")")
return nil
}
// Returns a representation of an if-expression, of the form:
// IF (BOOLEAN TEST, VALUE-IF-TRUE, VALUE-IF-FALSE)
func If(conditional BoolExpression,
trueExpression Expression,
falseExpression Expression) Expression {
return &ifExpression{
conditional: conditional,
trueExpression: trueExpression,
falseExpression: falseExpression,
}
}
//TODO: Uncomment
//type columnValueExpression struct {
// isExpression
// column NonAliasColumn
//}
//
//func ColumnValue(col NonAliasColumn) Expression {
// return &columnValueExpression{
// column: col,
// }
//}
//
//func (cv *columnValueExpression) Serialize(out *bytes.Buffer) error {
// _, _ = out.WriteString("VALUES(")
// _ = cv.column.SerializeSqlForColumnList(out)
// _ = out.WriteByte(')')
// return nil
//}

View file

@ -375,22 +375,6 @@ func (s *ExprSuite) TestDesc(c *gc.C) {
c.Assert(sql, gc.Equals, "table1.col1 DESC") c.Assert(sql, gc.Equals, "table1.col1 DESC")
} }
func (s *ExprSuite) TestIf(c *gc.C) {
test := GtL(table1Col1, 1.1)
clause := If(test, table1Col1, table1Col2)
buf := &bytes.Buffer{}
err := clause.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
c.Assert(
sql,
gc.Equals,
"IF(table1.col1>1.1,table1.col1,table1.col2)")
}
func (s *ExprSuite) TestColumnValue(c *gc.C) { func (s *ExprSuite) TestColumnValue(c *gc.C) {
clause := ColumnValue(table1Col1) clause := ColumnValue(table1Col1)

View file

@ -1,5 +1,7 @@
package sqlbuilder package sqlbuilder
import "errors"
type funcExpressionImpl struct { type funcExpressionImpl struct {
expressionInterfaceImpl expressionInterfaceImpl
@ -63,3 +65,98 @@ func MAX(expression NumericExpression) NumericExpression {
func SUM(expression NumericExpression) NumericExpression { func SUM(expression NumericExpression) NumericExpression {
return NewNumericFunc("SUM", expression) return NewNumericFunc("SUM", expression)
} }
type caseInterface interface {
Expression
WHEN(condition Expression) caseInterface
THEN(then Expression) caseInterface
ELSE(els Expression) caseInterface
}
type caseExpression struct {
expressionInterfaceImpl
expression Expression
when []Expression
then []Expression
els Expression
}
func CASE(expression ...Expression) caseInterface {
caseExp := &caseExpression{}
if len(expression) == 1 {
caseExp.expression = expression[0]
}
caseExp.expressionInterfaceImpl.parent = caseExp
return caseExp
}
func (c *caseExpression) WHEN(when Expression) caseInterface {
c.when = append(c.when, when)
return c
}
func (c *caseExpression) THEN(then Expression) caseInterface {
c.then = append(c.then, then)
return c
}
func (c *caseExpression) ELSE(els Expression) caseInterface {
c.els = els
return c
}
func (c *caseExpression) Serialize(out *queryData, options ...serializeOption) error {
out.WriteString("(CASE")
if c.expression != nil {
out.WriteString(" ")
err := c.expression.Serialize(out)
if err != nil {
return err
}
}
if len(c.when) == 0 || len(c.then) == 0 {
return errors.New("Invalid case statement. There should be at least one when/then expression pair. ")
}
if len(c.when) != len(c.then) {
return errors.New("When and then expression count mismatch. ")
}
for i, when := range c.when {
out.WriteString(" WHEN ")
err := when.Serialize(out)
if err != nil {
return err
}
out.WriteString(" THEN ")
err = c.then[i].Serialize(out)
if err != nil {
return err
}
}
if c.els != nil {
out.WriteString(" ELSE ")
err := c.els.Serialize(out)
if err != nil {
return err
}
}
out.WriteString(" END)")
return nil
}

View file

@ -0,0 +1,33 @@
package sqlbuilder
import (
"gotest.tools/assert"
"testing"
)
func TestCase1(t *testing.T) {
query := CASE().
WHEN(table3Col1.EqL(1)).THEN(table3Col1.Add(IntLiteral(1))).
WHEN(table3Col1.EqL(2)).THEN(table3Col1.Add(IntLiteral(2)))
queryData := &queryData{}
err := query.Serialize(queryData)
assert.NilError(t, err)
assert.Equal(t, queryData.buff.String(), `(CASE WHEN table3.col1 = $1 THEN table3.col1 + $2 WHEN table3.col1 = $3 THEN table3.col1 + $4 END)`)
}
func TestCase2(t *testing.T) {
query := CASE(table3Col1).
WHEN(IntLiteral(1)).THEN(table3Col1.Add(IntLiteral(1))).
WHEN(IntLiteral(2)).THEN(table3Col1.Add(IntLiteral(2))).
ELSE(IntLiteral(0))
queryData := &queryData{}
err := query.Serialize(queryData)
assert.NilError(t, err)
assert.Equal(t, queryData.buff.String(), `(CASE table3.col1 WHEN $1 THEN table3.col1 + $2 WHEN $3 THEN table3.col1 + $4 ELSE $5 END)`)
}

View file

@ -20,3 +20,19 @@ func (l literalExpression) Serialize(out *queryData, options ...serializeOption)
return nil return nil
} }
type numLiteralExpression struct {
literalExpression
numericInterfaceImpl
}
func IntLiteral(value int) NumericExpression {
numLiteral := &numLiteralExpression{}
numLiteral.literalExpression = *Literal(value)
numLiteral.literalExpression.parent = numLiteral
numLiteral.numericInterfaceImpl.parent = numLiteral
return numLiteral
}

View file

@ -62,19 +62,19 @@ func (n *numericInterfaceImpl) LtEqL(literal interface{}) BoolExpression {
} }
func (n *numericInterfaceImpl) Add(expression NumericExpression) NumericExpression { func (n *numericInterfaceImpl) Add(expression NumericExpression) NumericExpression {
return newBinaryNumericExpression(n.parent, expression, " + ") return newBinaryNumericExpression(n.parent, expression, "+")
} }
func (n *numericInterfaceImpl) Sub(expression NumericExpression) NumericExpression { func (n *numericInterfaceImpl) Sub(expression NumericExpression) NumericExpression {
return newBinaryNumericExpression(n.parent, expression, " - ") return newBinaryNumericExpression(n.parent, expression, "-")
} }
func (n *numericInterfaceImpl) Mul(expression NumericExpression) NumericExpression { func (n *numericInterfaceImpl) Mul(expression NumericExpression) NumericExpression {
return newBinaryNumericExpression(n.parent, expression, " * ") return newBinaryNumericExpression(n.parent, expression, "*")
} }
func (n *numericInterfaceImpl) Div(expression NumericExpression) NumericExpression { func (n *numericInterfaceImpl) Div(expression NumericExpression) NumericExpression {
return newBinaryNumericExpression(n.parent, expression, " / ") return newBinaryNumericExpression(n.parent, expression, "/")
} }
//---------------------------------------------------// //---------------------------------------------------//

View file

@ -33,7 +33,7 @@ func (o *orderByClause) Serialize(out *queryData, options ...serializeOption) er
return errors.Newf("nil orderBy by clause.") return errors.Newf("nil orderBy by clause.")
} }
if err := o.expression.Serialize(out, UNION_ORDER_BY); err != nil { if err := o.expression.Serialize(out); err != nil {
return err return err
} }

View file

@ -482,7 +482,7 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) {
func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) { func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) {
// tests on outer statement: Group By, Order By, LIMIT // tests on outer statement: Group By, Order By, LIMIT
// on inner statement: AndWhere, WHERE (with And), Order By, LIMIT // on inner statement: AndWhere, WHERE (with AND), Order By, LIMIT
select_queries := make([]SelectStatement, 0, 3) select_queries := make([]SelectStatement, 0, 3)
// We're not trying to write a SQL parser, so we won't warn if you do something silly like // We're not trying to write a SQL parser, so we won't warn if you do something silly like

View file

@ -4,9 +4,9 @@ type StringExpression interface {
Expression Expression
Eq(expression StringExpression) BoolExpression Eq(expression StringExpression) BoolExpression
EqL(value string) BoolExpression EqString(value string) BoolExpression
NotEq(expression StringExpression) BoolExpression NotEq(expression StringExpression) BoolExpression
NotEqL(value string) BoolExpression NotEqString(value string) BoolExpression
} }
type stringInterfaceImpl struct { type stringInterfaceImpl struct {
@ -14,17 +14,17 @@ type stringInterfaceImpl struct {
} }
func (b *stringInterfaceImpl) Eq(expression StringExpression) BoolExpression { func (b *stringInterfaceImpl) Eq(expression StringExpression) BoolExpression {
return newBinaryBoolExpression(b.parent, expression, " = ") return Eq(b.parent, expression)
} }
func (b *stringInterfaceImpl) EqL(value string) BoolExpression { func (b *stringInterfaceImpl) EqString(value string) BoolExpression {
return newBinaryBoolExpression(b.parent, Literal(value), " = ") return EqL(b.parent, value)
} }
func (b *stringInterfaceImpl) NotEq(expression StringExpression) BoolExpression { func (b *stringInterfaceImpl) NotEq(expression StringExpression) BoolExpression {
return newBinaryBoolExpression(b.parent, expression, " != ") return NotEq(b.parent, expression)
} }
func (b *stringInterfaceImpl) NotEqL(value string) BoolExpression { func (b *stringInterfaceImpl) NotEqString(value string) BoolExpression {
return newBinaryBoolExpression(b.parent, Literal(value), " != ") return NotEq(b.parent, Literal(value))
} }

View file

@ -42,7 +42,7 @@ func TestUpdate(t *testing.T) {
// //
//func (s *StmtSuite) TestUpdateSingleValue(c *gc.C) { //func (s *StmtSuite) TestUpdateSingleValue(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, Literal(1)) // stmt := table1.UPDATE().SET(table1Col1, Literal(1))
// stmt.WHERE(EqL(table1Col2, 2)) // stmt.WHERE(EqString(table1Col2, 2))
// sql, err := stmt.String() // sql, err := stmt.String()
// c.Assert(err, gc.IsNil) // c.Assert(err, gc.IsNil)
// //
@ -54,7 +54,7 @@ func TestUpdate(t *testing.T) {
// //
//func (s *StmtSuite) TestUpdateUsingDeferredLookupColumns(c *gc.C) { //func (s *StmtSuite) TestUpdateUsingDeferredLookupColumns(c *gc.C) {
// stmt := table1.UPDATE().SET(table1.C("col1"), Literal(1)) // stmt := table1.UPDATE().SET(table1.C("col1"), Literal(1))
// stmt.WHERE(EqL(table1Col2, 2)) // stmt.WHERE(EqString(table1Col2, 2))
// sql, err := stmt.String() // sql, err := stmt.String()
// c.Assert(err, gc.IsNil) // c.Assert(err, gc.IsNil)
// //
@ -68,7 +68,7 @@ func TestUpdate(t *testing.T) {
// stmt := table1.UPDATE() // stmt := table1.UPDATE()
// stmt.SET(table1Col1, Literal(1)) // stmt.SET(table1Col1, Literal(1))
// stmt.SET(table1Col2, Literal(2)) // stmt.SET(table1Col2, Literal(2))
// stmt.WHERE(EqL(table1Col2, 3)) // stmt.WHERE(EqString(table1Col2, 3))
// sql, err := stmt.String() // sql, err := stmt.String()
// c.Assert(err, gc.IsNil) // c.Assert(err, gc.IsNil)
// //
@ -82,7 +82,7 @@ func TestUpdate(t *testing.T) {
// //
//func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) { //func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, Literal(1)) // stmt := table1.UPDATE().SET(table1Col1, Literal(1))
// stmt.WHERE(EqL(table1Col2, 2)) // stmt.WHERE(EqString(table1Col2, 2))
// stmt.ORDER_BY(table1Col2) // stmt.ORDER_BY(table1Col2)
// sql, err := stmt.String() // sql, err := stmt.String()
// c.Assert(err, gc.IsNil) // c.Assert(err, gc.IsNil)
@ -98,7 +98,7 @@ func TestUpdate(t *testing.T) {
// //
//func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) { //func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, Literal(1)) // stmt := table1.UPDATE().SET(table1Col1, Literal(1))
// stmt.WHERE(EqL(table1Col2, 2)) // stmt.WHERE(EqString(table1Col2, 2))
// stmt.LIMIT(5) // stmt.LIMIT(5)
// sql, err := stmt.String() // sql, err := stmt.String()
// c.Assert(err, gc.IsNil) // c.Assert(err, gc.IsNil)

View file

@ -12,7 +12,7 @@ import (
func TestUUIDType(t *testing.T) { func TestUUIDType(t *testing.T) {
query := table.AllTypes. query := table.AllTypes.
SELECT(table.AllTypes.AllColumns). SELECT(table.AllTypes.AllColumns).
WHERE(table.AllTypes.UUID.EqL("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) WHERE(table.AllTypes.UUID.EqString("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
queryStr, args, err := query.Sql() queryStr, args, err := query.Sql()

View file

@ -2,6 +2,7 @@ package tests
import ( import (
"fmt" "fmt"
"github.com/davecgh/go-spew/spew"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder" "github.com/sub0zero/go-sqlbuilder/sqlbuilder"
"github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model" "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model"
. "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table" . "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table"
@ -101,7 +102,7 @@ func TestSelectAndUnionInProjection(t *testing.T) {
// INNER_JOIN(Film, FilmActor.FilmID.Eq(Film.FilmID)). // INNER_JOIN(Film, FilmActor.FilmID.Eq(Film.FilmID)).
// INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). // INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
// SELECT(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns). // SELECT(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns).
// WHERE(FilmActor.ActorID.GtEq(1).And(FilmActor.ActorID.LteLiteral(2))) // WHERE(FilmActor.ActorID.GtEq(1).AND(FilmActor.ActorID.LteLiteral(2)))
// //
// queryStr, args, err := query.Sql() // queryStr, args, err := query.Sql()
// assert.NilError(t, err) // assert.NilError(t, err)
@ -131,7 +132,7 @@ func TestJoinQuerySlice(t *testing.T) {
query := Film. query := Film.
INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
SELECT(Language.AllColumns, Film.AllColumns). SELECT(Language.AllColumns, Film.AllColumns).
WHERE(Film.Rating.EqL(string(model.MpaaRating_NC17))). WHERE(Film.Rating.EqString(string(model.MpaaRating_NC17))).
LIMIT(15) LIMIT(15)
queryStr, args, err := query.Sql() queryStr, args, err := query.Sql()
@ -317,7 +318,7 @@ func TestSelectSelfJoin(t *testing.T) {
f2 := Film.AS("f2") f2 := Film.AS("f2")
query := f1. query := f1.
INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).And(f1.Length.Eq(f2.Length))). INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).AND(f1.Length.Eq(f2.Length))).
SELECT(f1.AllColumns, f2.AllColumns). SELECT(f1.AllColumns, f2.AllColumns).
ORDER_BY(f1.FilmID.ASC()) ORDER_BY(f1.FilmID.ASC())
@ -356,7 +357,7 @@ func TestSelectAliasColumn(t *testing.T) {
} }
query := f1. query := f1.
INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).And(f1.Length.Eq(f2.Length))). INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).AND(f1.Length.Eq(f2.Length))).
SELECT(f1.Title.AS("thesame_length_films.title1"), SELECT(f1.Title.AS("thesame_length_films.title1"),
f2.Title.AS("thesame_length_films.title2"), f2.Title.AS("thesame_length_films.title2"),
f1.Length.AS("thesame_length_films.length")). f1.Length.AS("thesame_length_films.length")).
@ -443,7 +444,7 @@ func TestSubQuery(t *testing.T) {
// WHERE(Actor.LastName.Neq(avrgCustomer)) // WHERE(Actor.LastName.Neq(avrgCustomer))
rFilmsOnly := Film.SELECT(Film.FilmID, Film.Title, Film.Rating). rFilmsOnly := Film.SELECT(Film.FilmID, Film.Title, Film.Rating).
WHERE(Film.Rating.EqL("R")). WHERE(Film.Rating.EqString("R")).
AsTable("films") AsTable("films")
query := Actor.INNER_JOIN(FilmActor, Actor.ActorID.Eq(FilmActor.FilmID)). query := Actor.INNER_JOIN(FilmActor, Actor.ActorID.Eq(FilmActor.FilmID)).
@ -532,7 +533,7 @@ func TestSelectGroupByHaving(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
fmt.Println(queryStr) fmt.Println(queryStr)
assert.Equal(t, len(args), 1) assert.Equal(t, len(args), 1)
assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id", SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount)>$1 ORDER BY SUM(payment.amount) ASC`) assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id", SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount) > $1 ORDER BY SUM(payment.amount) ASC`)
type CustomerPaymentSum struct { type CustomerPaymentSum struct {
CustomerID int16 CustomerID int16
@ -666,6 +667,38 @@ func TestUnion(t *testing.T) {
}) })
} }
func TestSelectWithCase(t *testing.T) {
query := Payment.SELECT(
sqlbuilder.CASE(Payment.StaffID).
WHEN(sqlbuilder.IntLiteral(1)).THEN(sqlbuilder.Literal("ONE")).
WHEN(sqlbuilder.IntLiteral(2)).THEN(sqlbuilder.Literal("TWO")).
WHEN(sqlbuilder.IntLiteral(3)).THEN(sqlbuilder.Literal("THREE")).
ELSE(sqlbuilder.Literal("OTHER")).AS("staff_id_num"),
).
ORDER_BY(Payment.PaymentID.ASC()).
LIMIT(20)
queryStr, args, err := query.Sql()
assert.NilError(t, err)
fmt.Println(queryStr)
fmt.Println(args)
dest := []struct {
StaffIdNum string
}{}
err = query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 20)
assert.Equal(t, dest[0].StaffIdNum, "TWO")
assert.Equal(t, dest[1].StaffIdNum, "ONE")
spew.Dump(dest)
}
func int16Ptr(i int16) *int16 { func int16Ptr(i int16) *int16 {
return &i return &i
} }

View file

@ -21,7 +21,7 @@ func TestUpdateValues(t *testing.T) {
query := table.Link. query := table.Link.
UPDATE(table.Link.Name, table.Link.URL). UPDATE(table.Link.Name, table.Link.URL).
SET("Bong", "http://bong.com"). SET("Bong", "http://bong.com").
WHERE(table.Link.Name.EqL("Bing")) WHERE(table.Link.Name.EqString("Bing"))
queryStr, args, err := query.Sql() queryStr, args, err := query.Sql()
@ -38,7 +38,7 @@ func TestUpdateValues(t *testing.T) {
links := []model.Link{} links := []model.Link{}
err = table.Link.SELECT(table.Link.AllColumns). err = table.Link.SELECT(table.Link.AllColumns).
WHERE(table.Link.Name.EqL("Bong")). WHERE(table.Link.Name.EqString("Bong")).
Query(db, &links) Query(db, &links)
assert.NilError(t, err) assert.NilError(t, err)
@ -60,7 +60,7 @@ func TestUpdateAndReturning(t *testing.T) {
stmt := table.Link. stmt := table.Link.
UPDATE(table.Link.Name, table.Link.URL). UPDATE(table.Link.Name, table.Link.URL).
SET("DuckDuckGo", "http://www.duckduckgo.com"). SET("DuckDuckGo", "http://www.duckduckgo.com").
WHERE(table.Link.Name.EqL("Ask")). WHERE(table.Link.Name.EqString("Ask")).
RETURNING(table.Link.AllColumns) RETURNING(table.Link.AllColumns)
stmtStr, args, err := stmt.Sql() stmtStr, args, err := stmt.Sql()