From 4f9323ddcaa4cb4424983a0f44d0e86f4e75f6fa Mon Sep 17 00:00:00 2001 From: zer0sub Date: Mon, 6 May 2019 12:42:15 +0200 Subject: [PATCH] Add support for CASE operator. --- sqlbuilder/bool_expresion.go | 30 ++--- sqlbuilder/bool_expression_test.go | 18 +-- sqlbuilder/clause.go | 2 - sqlbuilder/execution/execution.go | 9 +- sqlbuilder/expression.go | 6 +- sqlbuilder/expression_old.go | 197 ---------------------------- sqlbuilder/expression_old_test.go | 16 --- sqlbuilder/func_expression.go | 97 ++++++++++++++ sqlbuilder/func_expression_test.go | 33 +++++ sqlbuilder/literal_expression.go | 16 +++ sqlbuilder/numeric_expression.go | 8 +- sqlbuilder/order_by_clause.go | 2 +- sqlbuilder/statement_test.go | 2 +- sqlbuilder/string_expression.go | 16 +-- sqlbuilder/update_statement_test.go | 10 +- tests/sample_test.go | 2 +- tests/select_test.go | 45 ++++++- tests/update_test.go | 6 +- 18 files changed, 243 insertions(+), 272 deletions(-) create mode 100644 sqlbuilder/func_expression_test.go diff --git a/sqlbuilder/bool_expresion.go b/sqlbuilder/bool_expresion.go index 9869f0e..ec58b6e 100644 --- a/sqlbuilder/bool_expresion.go +++ b/sqlbuilder/bool_expresion.go @@ -8,10 +8,10 @@ type BoolExpression interface { GtEq(rhs Expression) BoolExpression LtEq(rhs Expression) BoolExpression - And(expression BoolExpression) BoolExpression - Or(expression BoolExpression) BoolExpression - IsTrue() BoolExpression - IsFalse() BoolExpression + AND(expression BoolExpression) BoolExpression + OR(expression BoolExpression) BoolExpression + IS_TRUE() BoolExpression + IS_FALSE() BoolExpression } type boolInterfaceImpl struct { @@ -34,18 +34,18 @@ func (b *boolInterfaceImpl) LtEq(rhs Expression) BoolExpression { return LtEq(b.parent, rhs) } -func (b *boolInterfaceImpl) And(expression BoolExpression) BoolExpression { +func (b *boolInterfaceImpl) AND(expression BoolExpression) BoolExpression { return And(b.parent, expression) } -func (b *boolInterfaceImpl) Or(expression BoolExpression) BoolExpression { +func (b *boolInterfaceImpl) OR(expression BoolExpression) BoolExpression { return Or(b.parent, expression) } -func (b *boolInterfaceImpl) IsTrue() BoolExpression { +func (b *boolInterfaceImpl) IS_TRUE() BoolExpression { return IsTrue(b.parent) } -func (b *boolInterfaceImpl) IsFalse() BoolExpression { +func (b *boolInterfaceImpl) IS_FALSE() BoolExpression { return nil } @@ -106,7 +106,7 @@ func EXISTS(subQuery SelectStatement) BoolExpression { // Returns a representation of "a=b" func Eq(lhs, rhs Expression) BoolExpression { - return newBinaryBoolExpression(lhs, rhs, " = ") + return newBinaryBoolExpression(lhs, rhs, "=") } // Returns a representation of "a=b", where b is a literal @@ -166,24 +166,24 @@ func GteL(lhs Expression, val interface{}) BoolExpression { // Returns a representation of "not expr" func Not(expr BoolExpression) BoolExpression { - return newPrefixBoolExpression(expr, " NOT") + return newPrefixBoolExpression(expr, "NOT") } func IsTrue(expr BoolExpression) BoolExpression { - return newPrefixBoolExpression(expr, " IS TRUE") + return newPrefixBoolExpression(expr, "IS TRUE") } func And(lhs, rhs Expression) BoolExpression { - return newBinaryBoolExpression(lhs, rhs, " AND ") + return newBinaryBoolExpression(lhs, rhs, "AND") } // Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses func Or(lhs, rhs Expression) BoolExpression { - return newBinaryBoolExpression(lhs, rhs, " OR ") + return newBinaryBoolExpression(lhs, rhs, "OR") } func Like(lhs, rhs Expression) BoolExpression { - return newBinaryBoolExpression(lhs, rhs, " LIKE ") + return newBinaryBoolExpression(lhs, rhs, "LIKE") } func LikeL(lhs Expression, val string) BoolExpression { @@ -191,7 +191,7 @@ func LikeL(lhs Expression, val string) BoolExpression { } func Regexp(lhs, rhs Expression) BoolExpression { - return newBinaryBoolExpression(lhs, rhs, " REGEXP ") + return newBinaryBoolExpression(lhs, rhs, "REGEXP") } func RegexpL(lhs Expression, val string) BoolExpression { diff --git a/sqlbuilder/bool_expression_test.go b/sqlbuilder/bool_expression_test.go index 8c60595..4a41478 100644 --- a/sqlbuilder/bool_expression_test.go +++ b/sqlbuilder/bool_expression_test.go @@ -27,7 +27,7 @@ func TestBinaryExpression(t *testing.T) { }) t.Run("and", func(t *testing.T) { - exp := boolExpression.And(Eq(Literal(4), Literal(5))) + exp := boolExpression.AND(Eq(Literal(4), Literal(5))) out := queryData{} err := exp.Serialize(&out) @@ -37,7 +37,7 @@ func TestBinaryExpression(t *testing.T) { }) t.Run("or", func(t *testing.T) { - exp := boolExpression.Or(Eq(Literal(4), Literal(5))) + exp := boolExpression.OR(Eq(Literal(4), Literal(5))) out := queryData{} err := exp.Serialize(&out) @@ -54,7 +54,7 @@ func TestUnaryExpression(t *testing.T) { err := notExpression.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.buff.String(), " NOT $1 = $2") + assert.Equal(t, out.buff.String(), "NOT $1 = $2") t.Run("alias", func(t *testing.T) { alias := notExpression.AS("alias_not_expression") @@ -63,17 +63,17 @@ func TestUnaryExpression(t *testing.T) { err := alias.SerializeForProjection(&out) assert.NilError(t, err) - assert.Equal(t, out.buff.String(), ` NOT $1 = $2 AS "alias_not_expression"`) + assert.Equal(t, out.buff.String(), `NOT $1 = $2 AS "alias_not_expression"`) }) t.Run("and", func(t *testing.T) { - exp := notExpression.And(Eq(Literal(4), Literal(5))) + exp := notExpression.AND(Eq(Literal(4), Literal(5))) out := queryData{} err := exp.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.buff.String(), `( NOT $1 = $2 AND $3 = $4)`) + assert.Equal(t, out.buff.String(), `(NOT $1 = $2 AND $3 = $4)`) }) } @@ -84,16 +84,16 @@ func TestUnaryIsTrueExpression(t *testing.T) { err := notExpression.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.buff.String(), " IS TRUE $1 = $2") + assert.Equal(t, out.buff.String(), "IS TRUE $1 = $2") t.Run("and", func(t *testing.T) { - exp := notExpression.And(Eq(Literal(4), Literal(5))) + exp := notExpression.AND(Eq(Literal(4), Literal(5))) out := queryData{} err := exp.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.buff.String(), `( IS TRUE $1 = $2 AND $3 = $4)`) + assert.Equal(t, out.buff.String(), `(IS TRUE $1 = $2 AND $3 = $4)`) }) } diff --git a/sqlbuilder/clause.go b/sqlbuilder/clause.go index 65253dc..4fa45e4 100644 --- a/sqlbuilder/clause.go +++ b/sqlbuilder/clause.go @@ -10,8 +10,6 @@ type serializeOption int const ( FOR_PROJECTION = iota - UNION_ORDER_BY - NO_TABLE_NAME ) type Clause interface { diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index e9286a8..7cb44b1 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -487,7 +487,14 @@ func initializePtrValue(value reflect.Value) { } func getCellValue(scanContext *scanContext, tableName, fieldName string) interface{} { - columnName := tableName + "." + snaker.CamelToSnake(fieldName) + columnName := "" + + if tableName == "" { + columnName = snaker.CamelToSnake(fieldName) + } else { + columnName = tableName + "." + snaker.CamelToSnake(fieldName) + } + //columnName := snaker.CamelToSnake(fieldName) ////fmt.Println(columnName) diff --git a/sqlbuilder/expression.go b/sqlbuilder/expression.go index f1bc548..7d48a39 100644 --- a/sqlbuilder/expression.go +++ b/sqlbuilder/expression.go @@ -24,11 +24,11 @@ type expressionInterfaceImpl struct { } func (e *expressionInterfaceImpl) IN(subQuery SelectStatement) BoolExpression { - return newBinaryBoolExpression(e.parent, subQuery, " IN ") + return newBinaryBoolExpression(e.parent, subQuery, "IN") } func (e *expressionInterfaceImpl) NOT_IN(subQuery SelectStatement) BoolExpression { - return newBinaryBoolExpression(e.parent, subQuery, " NOT_IN ") + return newBinaryBoolExpression(e.parent, subQuery, "NOT_IN") } func (e *expressionInterfaceImpl) AS(alias string) Projection { @@ -103,7 +103,7 @@ func (c *binaryExpression) Serialize(out *queryData, options ...serializeOption) return err } - out.WriteString(c.operator) + out.WriteString(" " + c.operator + " ") if err := c.rhs.Serialize(out); err != nil { return err diff --git a/sqlbuilder/expression_old.go b/sqlbuilder/expression_old.go index 16edc5d..2c1499f 100644 --- a/sqlbuilder/expression_old.go +++ b/sqlbuilder/expression_old.go @@ -7,69 +7,6 @@ import ( "time" ) -// Representation of a tuple enclosed, comma separated list of clauses -//type listClause struct { -// clauses []Clause -// includeParentheses bool -//} -// -//func (list *listClause) Serialize(out *queryData, options ...serializeOption) error { -// if list.includeParentheses { -// out.WriteByte('(') -// } -// -// if err := serializeClauseList(list.clauses, out); err != nil { -// return err -// } -// -// if list.includeParentheses { -// out.WriteByte(')') -// } -// return nil -//} - -// -//type funcExpression struct { -// expressionInterfaceImpl -// funcName string -// args *listClause -//} -// -//func (c *funcExpression) Serialize(out *queryData, options ...serializeOption) error { -// if !validIdentifierName(c.funcName) { -// return errors.Newf( -// "Invalid function name: %s.", -// c.funcName, -// out.String()) -// } -// _, _ = out.WriteString(c.funcName) -// if c.args == nil { -// _, _ = out.WriteString("()") -// } else { -// return c.args.Serialize(out) -// } -// return nil -//} -// -//// Returns a representation of sql function call "func_call(c[0], ..., c[n-1]) -//func SqlFunc(funcName string, expressions ...Expression) Expression { -// f := &funcExpression{ -// funcName: funcName, -// } -// if len(expressions) > 0 { -// args := make([]Clause, len(expressions), len(expressions)) -// for i, expr := range expressions { -// args[i] = expr -// } -// -// f.args = &listClause{ -// clauses: args, -// includeParentheses: true, -// } -// } -// return f -//} - type intervalExpression struct { expressionInterfaceImpl duration time.Duration @@ -118,137 +55,3 @@ var likeEscaper = strings.NewReplacer("_", "\\_", "%", "\\%") func EscapeForLike(s string) string { return likeEscaper.Replace(s) } - -// Returns an escaped literal string -//func Literal(v interface{}) Expression { -// value, err := sqltypes.BuildValue(v) -// if err != nil { -// panic(errors.Wrap(err, "Invalid literal value")) -// } -// return NewLiteralExpression(value) -//} -// -//// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses -//func Add(expressions ...Expression) Expression { -// return &arithmeticExpression{ -// expressions: expressions, -// operator: []byte(" + "), -// } -//} -// -//// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses -//func Sub(expressions ...Expression) Expression { -// return &arithmeticExpression{ -// expressions: expressions, -// operator: []byte(" - "), -// } -//} -// -//// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses -//func Mul(expressions ...Expression) Expression { -// return &arithmeticExpression{ -// expressions: expressions, -// operator: []byte(" * "), -// } -//} -// -//// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses -//func Div(expressions ...Expression) Expression { -// return &arithmeticExpression{ -// expressions: expressions, -// operator: []byte(" / "), -// } -//} - -//TODO: Uncomment -// -//func BitOr(lhs, rhs Expression) Expression { -// return &binaryExpression{ -// lhs: lhs, -// rhs: rhs, -// operator: []byte(" | "), -// } -//} -// -//func BitAnd(lhs, rhs Expression) Expression { -// return &binaryExpression{ -// lhs: lhs, -// rhs: rhs, -// operator: []byte(" & "), -// } -//} -// -//func BitXor(lhs, rhs Expression) Expression { -// return &binaryExpression{ -// lhs: lhs, -// rhs: rhs, -// operator: []byte(" ^ "), -// } -//} -// -//func Plus(lhs, rhs Expression) Expression { -// return &binaryExpression{ -// lhs: lhs, -// rhs: rhs, -// operator: []byte(" + "), -// } -//} -// -//func Minus(lhs, rhs Expression) Expression { -// return &binaryExpression{ -// lhs: lhs, -// rhs: rhs, -// operator: []byte(" - "), -// } -//} - -type ifExpression struct { - expressionInterfaceImpl - - conditional BoolExpression - trueExpression Expression - falseExpression Expression -} - -func (exp *ifExpression) Serialize(out *queryData, options ...serializeOption) error { - out.WriteString("IF(") - _ = exp.conditional.Serialize(out) - out.WriteString(",") - _ = exp.trueExpression.Serialize(out) - out.WriteString(",") - _ = exp.falseExpression.Serialize(out) - out.WriteString(")") - - return nil -} - -// Returns a representation of an if-expression, of the form: -// IF (BOOLEAN TEST, VALUE-IF-TRUE, VALUE-IF-FALSE) -func If(conditional BoolExpression, - trueExpression Expression, - falseExpression Expression) Expression { - return &ifExpression{ - conditional: conditional, - trueExpression: trueExpression, - falseExpression: falseExpression, - } -} - -//TODO: Uncomment -//type columnValueExpression struct { -// isExpression -// column NonAliasColumn -//} -// -//func ColumnValue(col NonAliasColumn) Expression { -// return &columnValueExpression{ -// column: col, -// } -//} -// -//func (cv *columnValueExpression) Serialize(out *bytes.Buffer) error { -// _, _ = out.WriteString("VALUES(") -// _ = cv.column.SerializeSqlForColumnList(out) -// _ = out.WriteByte(')') -// return nil -//} diff --git a/sqlbuilder/expression_old_test.go b/sqlbuilder/expression_old_test.go index 355a1ff..ff020f1 100644 --- a/sqlbuilder/expression_old_test.go +++ b/sqlbuilder/expression_old_test.go @@ -375,22 +375,6 @@ func (s *ExprSuite) TestDesc(c *gc.C) { c.Assert(sql, gc.Equals, "table1.col1 DESC") } -func (s *ExprSuite) TestIf(c *gc.C) { - test := GtL(table1Col1, 1.1) - clause := If(test, table1Col1, table1Col2) - - buf := &bytes.Buffer{} - - err := clause.Serialize(buf) - c.Assert(err, gc.IsNil) - - sql := buf.String() - c.Assert( - sql, - gc.Equals, - "IF(table1.col1>1.1,table1.col1,table1.col2)") -} - func (s *ExprSuite) TestColumnValue(c *gc.C) { clause := ColumnValue(table1Col1) diff --git a/sqlbuilder/func_expression.go b/sqlbuilder/func_expression.go index 63a40da..f8929ba 100644 --- a/sqlbuilder/func_expression.go +++ b/sqlbuilder/func_expression.go @@ -1,5 +1,7 @@ package sqlbuilder +import "errors" + type funcExpressionImpl struct { expressionInterfaceImpl @@ -63,3 +65,98 @@ func MAX(expression NumericExpression) NumericExpression { func SUM(expression NumericExpression) NumericExpression { return NewNumericFunc("SUM", expression) } + +type caseInterface interface { + Expression + + WHEN(condition Expression) caseInterface + THEN(then Expression) caseInterface + ELSE(els Expression) caseInterface +} + +type caseExpression struct { + expressionInterfaceImpl + + expression Expression + when []Expression + then []Expression + els Expression +} + +func CASE(expression ...Expression) caseInterface { + caseExp := &caseExpression{} + + if len(expression) == 1 { + caseExp.expression = expression[0] + } + + caseExp.expressionInterfaceImpl.parent = caseExp + + return caseExp +} + +func (c *caseExpression) WHEN(when Expression) caseInterface { + c.when = append(c.when, when) + return c +} + +func (c *caseExpression) THEN(then Expression) caseInterface { + c.then = append(c.then, then) + return c +} + +func (c *caseExpression) ELSE(els Expression) caseInterface { + c.els = els + + return c +} + +func (c *caseExpression) Serialize(out *queryData, options ...serializeOption) error { + out.WriteString("(CASE") + + if c.expression != nil { + out.WriteString(" ") + err := c.expression.Serialize(out) + + if err != nil { + return err + } + } + + if len(c.when) == 0 || len(c.then) == 0 { + return errors.New("Invalid case statement. There should be at least one when/then expression pair. ") + } + + if len(c.when) != len(c.then) { + return errors.New("When and then expression count mismatch. ") + } + + for i, when := range c.when { + out.WriteString(" WHEN ") + err := when.Serialize(out) + + if err != nil { + return err + } + + out.WriteString(" THEN ") + err = c.then[i].Serialize(out) + + if err != nil { + return err + } + } + + if c.els != nil { + out.WriteString(" ELSE ") + err := c.els.Serialize(out) + + if err != nil { + return err + } + } + + out.WriteString(" END)") + + return nil +} diff --git a/sqlbuilder/func_expression_test.go b/sqlbuilder/func_expression_test.go new file mode 100644 index 0000000..7e4b9f0 --- /dev/null +++ b/sqlbuilder/func_expression_test.go @@ -0,0 +1,33 @@ +package sqlbuilder + +import ( + "gotest.tools/assert" + "testing" +) + +func TestCase1(t *testing.T) { + query := CASE(). + WHEN(table3Col1.EqL(1)).THEN(table3Col1.Add(IntLiteral(1))). + WHEN(table3Col1.EqL(2)).THEN(table3Col1.Add(IntLiteral(2))) + + queryData := &queryData{} + + err := query.Serialize(queryData) + + assert.NilError(t, err) + assert.Equal(t, queryData.buff.String(), `(CASE WHEN table3.col1 = $1 THEN table3.col1 + $2 WHEN table3.col1 = $3 THEN table3.col1 + $4 END)`) +} + +func TestCase2(t *testing.T) { + query := CASE(table3Col1). + WHEN(IntLiteral(1)).THEN(table3Col1.Add(IntLiteral(1))). + WHEN(IntLiteral(2)).THEN(table3Col1.Add(IntLiteral(2))). + ELSE(IntLiteral(0)) + + queryData := &queryData{} + + err := query.Serialize(queryData) + + assert.NilError(t, err) + assert.Equal(t, queryData.buff.String(), `(CASE table3.col1 WHEN $1 THEN table3.col1 + $2 WHEN $3 THEN table3.col1 + $4 ELSE $5 END)`) +} diff --git a/sqlbuilder/literal_expression.go b/sqlbuilder/literal_expression.go index 6a408db..1790036 100644 --- a/sqlbuilder/literal_expression.go +++ b/sqlbuilder/literal_expression.go @@ -20,3 +20,19 @@ func (l literalExpression) Serialize(out *queryData, options ...serializeOption) return nil } + +type numLiteralExpression struct { + literalExpression + numericInterfaceImpl +} + +func IntLiteral(value int) NumericExpression { + numLiteral := &numLiteralExpression{} + + numLiteral.literalExpression = *Literal(value) + numLiteral.literalExpression.parent = numLiteral + + numLiteral.numericInterfaceImpl.parent = numLiteral + + return numLiteral +} diff --git a/sqlbuilder/numeric_expression.go b/sqlbuilder/numeric_expression.go index 4281627..9b6b451 100644 --- a/sqlbuilder/numeric_expression.go +++ b/sqlbuilder/numeric_expression.go @@ -62,19 +62,19 @@ func (n *numericInterfaceImpl) LtEqL(literal interface{}) BoolExpression { } func (n *numericInterfaceImpl) Add(expression NumericExpression) NumericExpression { - return newBinaryNumericExpression(n.parent, expression, " + ") + return newBinaryNumericExpression(n.parent, expression, "+") } func (n *numericInterfaceImpl) Sub(expression NumericExpression) NumericExpression { - return newBinaryNumericExpression(n.parent, expression, " - ") + return newBinaryNumericExpression(n.parent, expression, "-") } func (n *numericInterfaceImpl) Mul(expression NumericExpression) NumericExpression { - return newBinaryNumericExpression(n.parent, expression, " * ") + return newBinaryNumericExpression(n.parent, expression, "*") } func (n *numericInterfaceImpl) Div(expression NumericExpression) NumericExpression { - return newBinaryNumericExpression(n.parent, expression, " / ") + return newBinaryNumericExpression(n.parent, expression, "/") } //---------------------------------------------------// diff --git a/sqlbuilder/order_by_clause.go b/sqlbuilder/order_by_clause.go index 239897f..61c0dab 100644 --- a/sqlbuilder/order_by_clause.go +++ b/sqlbuilder/order_by_clause.go @@ -33,7 +33,7 @@ func (o *orderByClause) Serialize(out *queryData, options ...serializeOption) er return errors.Newf("nil orderBy by clause.") } - if err := o.expression.Serialize(out, UNION_ORDER_BY); err != nil { + if err := o.expression.Serialize(out); err != nil { return err } diff --git a/sqlbuilder/statement_test.go b/sqlbuilder/statement_test.go index 0e84eaa..61d9f87 100644 --- a/sqlbuilder/statement_test.go +++ b/sqlbuilder/statement_test.go @@ -482,7 +482,7 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) { func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) { // tests on outer statement: Group By, Order By, LIMIT - // on inner statement: AndWhere, WHERE (with And), Order By, LIMIT + // on inner statement: AndWhere, WHERE (with AND), Order By, LIMIT select_queries := make([]SelectStatement, 0, 3) // We're not trying to write a SQL parser, so we won't warn if you do something silly like diff --git a/sqlbuilder/string_expression.go b/sqlbuilder/string_expression.go index 5ad82bf..6682d93 100644 --- a/sqlbuilder/string_expression.go +++ b/sqlbuilder/string_expression.go @@ -4,9 +4,9 @@ type StringExpression interface { Expression Eq(expression StringExpression) BoolExpression - EqL(value string) BoolExpression + EqString(value string) BoolExpression NotEq(expression StringExpression) BoolExpression - NotEqL(value string) BoolExpression + NotEqString(value string) BoolExpression } type stringInterfaceImpl struct { @@ -14,17 +14,17 @@ type stringInterfaceImpl struct { } func (b *stringInterfaceImpl) Eq(expression StringExpression) BoolExpression { - return newBinaryBoolExpression(b.parent, expression, " = ") + return Eq(b.parent, expression) } -func (b *stringInterfaceImpl) EqL(value string) BoolExpression { - return newBinaryBoolExpression(b.parent, Literal(value), " = ") +func (b *stringInterfaceImpl) EqString(value string) BoolExpression { + return EqL(b.parent, value) } func (b *stringInterfaceImpl) NotEq(expression StringExpression) BoolExpression { - return newBinaryBoolExpression(b.parent, expression, " != ") + return NotEq(b.parent, expression) } -func (b *stringInterfaceImpl) NotEqL(value string) BoolExpression { - return newBinaryBoolExpression(b.parent, Literal(value), " != ") +func (b *stringInterfaceImpl) NotEqString(value string) BoolExpression { + return NotEq(b.parent, Literal(value)) } diff --git a/sqlbuilder/update_statement_test.go b/sqlbuilder/update_statement_test.go index cc1be4d..bb6dcd9 100644 --- a/sqlbuilder/update_statement_test.go +++ b/sqlbuilder/update_statement_test.go @@ -42,7 +42,7 @@ func TestUpdate(t *testing.T) { // //func (s *StmtSuite) TestUpdateSingleValue(c *gc.C) { // stmt := table1.UPDATE().SET(table1Col1, Literal(1)) -// stmt.WHERE(EqL(table1Col2, 2)) +// stmt.WHERE(EqString(table1Col2, 2)) // sql, err := stmt.String() // c.Assert(err, gc.IsNil) // @@ -54,7 +54,7 @@ func TestUpdate(t *testing.T) { // //func (s *StmtSuite) TestUpdateUsingDeferredLookupColumns(c *gc.C) { // stmt := table1.UPDATE().SET(table1.C("col1"), Literal(1)) -// stmt.WHERE(EqL(table1Col2, 2)) +// stmt.WHERE(EqString(table1Col2, 2)) // sql, err := stmt.String() // c.Assert(err, gc.IsNil) // @@ -68,7 +68,7 @@ func TestUpdate(t *testing.T) { // stmt := table1.UPDATE() // stmt.SET(table1Col1, Literal(1)) // stmt.SET(table1Col2, Literal(2)) -// stmt.WHERE(EqL(table1Col2, 3)) +// stmt.WHERE(EqString(table1Col2, 3)) // sql, err := stmt.String() // c.Assert(err, gc.IsNil) // @@ -82,7 +82,7 @@ func TestUpdate(t *testing.T) { // //func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) { // stmt := table1.UPDATE().SET(table1Col1, Literal(1)) -// stmt.WHERE(EqL(table1Col2, 2)) +// stmt.WHERE(EqString(table1Col2, 2)) // stmt.ORDER_BY(table1Col2) // sql, err := stmt.String() // c.Assert(err, gc.IsNil) @@ -98,7 +98,7 @@ func TestUpdate(t *testing.T) { // //func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) { // stmt := table1.UPDATE().SET(table1Col1, Literal(1)) -// stmt.WHERE(EqL(table1Col2, 2)) +// stmt.WHERE(EqString(table1Col2, 2)) // stmt.LIMIT(5) // sql, err := stmt.String() // c.Assert(err, gc.IsNil) diff --git a/tests/sample_test.go b/tests/sample_test.go index f4ef72d..37f6233 100644 --- a/tests/sample_test.go +++ b/tests/sample_test.go @@ -12,7 +12,7 @@ import ( func TestUUIDType(t *testing.T) { query := table.AllTypes. SELECT(table.AllTypes.AllColumns). - WHERE(table.AllTypes.UUID.EqL("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) + WHERE(table.AllTypes.UUID.EqString("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) queryStr, args, err := query.Sql() diff --git a/tests/select_test.go b/tests/select_test.go index cb85648..93558cb 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -2,6 +2,7 @@ package tests import ( "fmt" + "github.com/davecgh/go-spew/spew" "github.com/sub0zero/go-sqlbuilder/sqlbuilder" "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model" . "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table" @@ -101,7 +102,7 @@ func TestSelectAndUnionInProjection(t *testing.T) { // INNER_JOIN(Film, FilmActor.FilmID.Eq(Film.FilmID)). // INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). // SELECT(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns). -// WHERE(FilmActor.ActorID.GtEq(1).And(FilmActor.ActorID.LteLiteral(2))) +// WHERE(FilmActor.ActorID.GtEq(1).AND(FilmActor.ActorID.LteLiteral(2))) // // queryStr, args, err := query.Sql() // assert.NilError(t, err) @@ -131,7 +132,7 @@ func TestJoinQuerySlice(t *testing.T) { query := Film. INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). SELECT(Language.AllColumns, Film.AllColumns). - WHERE(Film.Rating.EqL(string(model.MpaaRating_NC17))). + WHERE(Film.Rating.EqString(string(model.MpaaRating_NC17))). LIMIT(15) queryStr, args, err := query.Sql() @@ -317,7 +318,7 @@ func TestSelectSelfJoin(t *testing.T) { f2 := Film.AS("f2") query := f1. - INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).And(f1.Length.Eq(f2.Length))). + INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).AND(f1.Length.Eq(f2.Length))). SELECT(f1.AllColumns, f2.AllColumns). ORDER_BY(f1.FilmID.ASC()) @@ -356,7 +357,7 @@ func TestSelectAliasColumn(t *testing.T) { } query := f1. - INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).And(f1.Length.Eq(f2.Length))). + INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).AND(f1.Length.Eq(f2.Length))). SELECT(f1.Title.AS("thesame_length_films.title1"), f2.Title.AS("thesame_length_films.title2"), f1.Length.AS("thesame_length_films.length")). @@ -443,7 +444,7 @@ func TestSubQuery(t *testing.T) { // WHERE(Actor.LastName.Neq(avrgCustomer)) rFilmsOnly := Film.SELECT(Film.FilmID, Film.Title, Film.Rating). - WHERE(Film.Rating.EqL("R")). + WHERE(Film.Rating.EqString("R")). AsTable("films") query := Actor.INNER_JOIN(FilmActor, Actor.ActorID.Eq(FilmActor.FilmID)). @@ -532,7 +533,7 @@ func TestSelectGroupByHaving(t *testing.T) { assert.NilError(t, err) fmt.Println(queryStr) assert.Equal(t, len(args), 1) - assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id", SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount)>$1 ORDER BY SUM(payment.amount) ASC`) + assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id", SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount) > $1 ORDER BY SUM(payment.amount) ASC`) type CustomerPaymentSum struct { CustomerID int16 @@ -666,6 +667,38 @@ func TestUnion(t *testing.T) { }) } +func TestSelectWithCase(t *testing.T) { + query := Payment.SELECT( + sqlbuilder.CASE(Payment.StaffID). + WHEN(sqlbuilder.IntLiteral(1)).THEN(sqlbuilder.Literal("ONE")). + WHEN(sqlbuilder.IntLiteral(2)).THEN(sqlbuilder.Literal("TWO")). + WHEN(sqlbuilder.IntLiteral(3)).THEN(sqlbuilder.Literal("THREE")). + ELSE(sqlbuilder.Literal("OTHER")).AS("staff_id_num"), + ). + ORDER_BY(Payment.PaymentID.ASC()). + LIMIT(20) + + queryStr, args, err := query.Sql() + + assert.NilError(t, err) + fmt.Println(queryStr) + fmt.Println(args) + + dest := []struct { + StaffIdNum string + }{} + + err = query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 20) + assert.Equal(t, dest[0].StaffIdNum, "TWO") + assert.Equal(t, dest[1].StaffIdNum, "ONE") + + spew.Dump(dest) + +} + func int16Ptr(i int16) *int16 { return &i } diff --git a/tests/update_test.go b/tests/update_test.go index 77b5657..e5bcf3d 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -21,7 +21,7 @@ func TestUpdateValues(t *testing.T) { query := table.Link. UPDATE(table.Link.Name, table.Link.URL). SET("Bong", "http://bong.com"). - WHERE(table.Link.Name.EqL("Bing")) + WHERE(table.Link.Name.EqString("Bing")) queryStr, args, err := query.Sql() @@ -38,7 +38,7 @@ func TestUpdateValues(t *testing.T) { links := []model.Link{} err = table.Link.SELECT(table.Link.AllColumns). - WHERE(table.Link.Name.EqL("Bong")). + WHERE(table.Link.Name.EqString("Bong")). Query(db, &links) assert.NilError(t, err) @@ -60,7 +60,7 @@ func TestUpdateAndReturning(t *testing.T) { stmt := table.Link. UPDATE(table.Link.Name, table.Link.URL). SET("DuckDuckGo", "http://www.duckduckgo.com"). - WHERE(table.Link.Name.EqL("Ask")). + WHERE(table.Link.Name.EqString("Ask")). RETURNING(table.Link.AllColumns) stmtStr, args, err := stmt.Sql()