diff --git a/sqlbuilder/bool_expression_test.go b/sqlbuilder/bool_expression_test.go index c86ba21..de5e2bc 100644 --- a/sqlbuilder/bool_expression_test.go +++ b/sqlbuilder/bool_expression_test.go @@ -1,8 +1,6 @@ package sqlbuilder import ( - "fmt" - "gotest.tools/assert" "testing" ) @@ -18,6 +16,10 @@ func TestBoolExpressionNOT_EQ(t *testing.T) { func TestBoolExpressionIS_TRUE(t *testing.T) { assertExpressionSerialize(t, table1ColBool.IS_TRUE(), "table1.colBool IS TRUE") + assertExpressionSerialize(t, (Int(2).EQ(table1ColInt)).IS_TRUE(), + `($1 = table1.colInt) IS TRUE`, int64(2)) + assertExpressionSerialize(t, (Int(2).EQ(table1ColInt)).IS_TRUE().AND(Int(4).EQ(table2ColInt)), + `(($1 = table1.colInt) IS TRUE AND ($2 = table2.colInt))`, int64(2), int64(4)) } func TestBoolExpressionIS_NOT_TRUE(t *testing.T) { @@ -40,155 +42,32 @@ func TestBoolExpressionIS_NOT_UNKNOWN(t *testing.T) { assertExpressionSerialize(t, table1ColBool.IS_NOT_UNKNOWN(), "table1.colBool IS NOT UNKNOWN") } -func TestBinaryExpression(t *testing.T) { - boolExpression := EQ(Literal(2), Literal(3)) +func TestBinaryBoolExpression(t *testing.T) { + boolExpression := Int(2).EQ(Int(3)) - out := queryData{} - err := boolExpression.serialize(select_statement, &out) - - assert.NilError(t, err) - - assert.Equal(t, out.buff.String(), "($1 = $2)") - assert.Equal(t, len(out.args), 2) - - t.Run("alias", func(t *testing.T) { - alias := boolExpression.AS("alias_eq_expression") - - out := queryData{} - err := alias.serializeForProjection(select_statement, &out) - - assert.NilError(t, err) - assert.Equal(t, out.buff.String(), `$1 = $2 AS "alias_eq_expression"`) - }) - - t.Run("and", func(t *testing.T) { - exp := boolExpression.AND(EQ(Literal(4), Literal(5))) - - out := queryData{} - err := exp.serialize(select_statement, &out) - - assert.NilError(t, err) - assert.Equal(t, out.buff.String(), `(($1 = $2) AND ($3 = $4))`) - }) - - t.Run("or", func(t *testing.T) { - exp := boolExpression.OR(EQ(Literal(4), Literal(5))) - - out := queryData{} - err := exp.serialize(select_statement, &out) - - assert.NilError(t, err) - assert.Equal(t, out.buff.String(), `(($1 = $2) OR ($3 = $4))`) - }) + assertExpressionSerialize(t, boolExpression, "($1 = $2)", int64(2), int64(3)) + assertProjectionSerialize(t, boolExpression.AS("alias_eq_expression"), + `$1 = $2 AS "alias_eq_expression"`, int64(2), int64(3)) + assertExpressionSerialize(t, boolExpression.AND(Int(4).EQ(Int(5))), + "(($1 = $2) AND ($3 = $4))", int64(2), int64(3), int64(4), int64(5)) + assertExpressionSerialize(t, boolExpression.OR(Int(4).EQ(Int(5))), + "(($1 = $2) OR ($3 = $4))", int64(2), int64(3), int64(4), int64(5)) } - -func TestUnaryExpression(t *testing.T) { - notExpression := NOT(EQ(Literal(2), Literal(1))) - - out := queryData{} - err := notExpression.serialize(select_statement, &out) - - assert.NilError(t, err) - assert.Equal(t, out.buff.String(), "NOT ($1 = $2)") - - t.Run("alias", func(t *testing.T) { - alias := notExpression.AS("alias_not_expression") - - out := queryData{} - err := alias.serializeForProjection(select_statement, &out) - - assert.NilError(t, err) - assert.Equal(t, out.buff.String(), `NOT ($1 = $2) AS "alias_not_expression"`) - }) - - t.Run("and", func(t *testing.T) { - exp := notExpression.AND(EQ(Literal(4), Literal(5))) - - out := queryData{} - err := exp.serialize(select_statement, &out) - - assert.NilError(t, err) - assert.Equal(t, out.buff.String(), `(NOT ($1 = $2) AND ($3 = $4))`) - }) -} - -func TestUnaryIsTrueExpression(t *testing.T) { - exp := IS_TRUE(EQ(Literal(2), Literal(1))) - - out := queryData{} - err := exp.serialize(select_statement, &out) - - assert.NilError(t, err) - assert.Equal(t, out.buff.String(), "($1 = $2) IS TRUE") - - t.Run("and", func(t *testing.T) { - exp := exp.AND(EQ(Literal(4), Literal(5))) - - out := queryData{} - err := exp.serialize(select_statement, &out) - - assert.NilError(t, err) - assert.Equal(t, out.buff.String(), `(($1 = $2) IS TRUE AND ($3 = $4))`) - }) -} - func TestBoolLiteral(t *testing.T) { - literal := Bool(true) - - out := queryData{} - err := literal.serialize(select_statement, &out) - - assert.NilError(t, err) - - assert.Equal(t, out.buff.String(), "$1") + assertExpressionSerialize(t, Bool(true), "$1", true) + assertExpressionSerialize(t, Bool(false), "$1", false) } func TestExists(t *testing.T) { - query := EXISTS( + + assertExpressionSerialize(t, EXISTS( table2. - SELECT(Literal(1)). + SELECT(Int(1)). WHERE(table1Col1.EQ(table2Col3)), - ) - - out := queryData{} - err := query.serialize(select_statement, &out) - - fmt.Println(out.buff.String()) - - assert.NilError(t, err) - - expectedSql := + ), `EXISTS ( SELECT $1 FROM db.table2 WHERE table1.col1 = table2.col3 -)` - assert.Equal(t, out.buff.String(), expectedSql) -} - -func TestIn(t *testing.T) { - query := Literal(1.11).IN(table1.SELECT(table1Col1)) - - out := queryData{} - err := query.serialize(select_statement, &out, NO_WRAP) - - assert.NilError(t, err) - fmt.Println(out.buff.String()) - assert.Equal(t, out.buff.String(), `$1 IN ( - SELECT table1.col1 AS "table1.col1" - FROM db.table1 -)`) - - query2 := ROW(Literal(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)) - - out = queryData{} - err = query2.serialize(select_statement, &out) - - assert.NilError(t, err) - fmt.Println(out.buff.String()) - assert.Equal(t, out.buff.String(), `(ROW($1, table1.col1) IN ( - SELECT table2.col3 AS "table2.col3", - table3.col1 AS "table3.col1" - FROM db.table2 -))`) +)`, int64(1)) } diff --git a/sqlbuilder/delete_statement_test.go b/sqlbuilder/delete_statement_test.go index a81aa9f..ebaff1a 100644 --- a/sqlbuilder/delete_statement_test.go +++ b/sqlbuilder/delete_statement_test.go @@ -1,7 +1,6 @@ package sqlbuilder import ( - "fmt" "gotest.tools/assert" "testing" ) @@ -15,7 +14,6 @@ func TestDeleteWithWhere(t *testing.T) { sql, _, err := table1.DELETE().WHERE(table1Col1.EQ(Int(1))).Sql() assert.NilError(t, err) - fmt.Println(sql) expectedSql := ` DELETE FROM db.table1 WHERE table1.col1 = $1; diff --git a/sqlbuilder/expression.go b/sqlbuilder/expression.go index 00b0d00..1f4f3bc 100644 --- a/sqlbuilder/expression.go +++ b/sqlbuilder/expression.go @@ -14,8 +14,8 @@ type expression interface { IS_NULL() BoolExpression IS_NOT_NULL() BoolExpression - IN(subQuery selectStatement) BoolExpression - NOT_IN(subQuery selectStatement) BoolExpression + IN(expressions ...expression) BoolExpression + NOT_IN(expressions ...expression) BoolExpression AS(alias string) projection @@ -46,12 +46,12 @@ func (e *expressionInterfaceImpl) IS_NOT_NULL() BoolExpression { return newPostifxBoolExpression(e.parent, "IS NOT NULL") } -func (e *expressionInterfaceImpl) IN(subQuery selectStatement) BoolExpression { - return newBinaryBoolExpression(e.parent, subQuery, "IN") +func (e *expressionInterfaceImpl) IN(expressions ...expression) BoolExpression { + return newBinaryBoolExpression(e.parent, WRAP(expressions...), "IN") } -func (e *expressionInterfaceImpl) NOT_IN(subQuery selectStatement) BoolExpression { - return newBinaryBoolExpression(e.parent, subQuery, "NOT IN") +func (e *expressionInterfaceImpl) NOT_IN(expressions ...expression) BoolExpression { + return newBinaryBoolExpression(e.parent, WRAP(expressions...), "NOT IN") } func (e *expressionInterfaceImpl) AS(alias string) projection { diff --git a/sqlbuilder/expression_old_test.go b/sqlbuilder/expression_old_test.go index 4077a2c..f4cdcc3 100644 --- a/sqlbuilder/expression_old_test.go +++ b/sqlbuilder/expression_old_test.go @@ -107,7 +107,7 @@ func (s *ExprSuite) TestOrExpr(c *gc.C) { } func (s *ExprSuite) TestAddExpr(c *gc.C) { - expr := Add(Literal(1), Literal(2), Literal(3)) + expr := Add(literal(1), literal(2), literal(3)) buf := &bytes.Buffer{} @@ -119,7 +119,7 @@ func (s *ExprSuite) TestAddExpr(c *gc.C) { } func (s *ExprSuite) TestSubExpr(c *gc.C) { - expr := Sub(Literal(1), Literal(2), Literal(3)) + expr := Sub(literal(1), literal(2), literal(3)) buf := &bytes.Buffer{} @@ -131,7 +131,7 @@ func (s *ExprSuite) TestSubExpr(c *gc.C) { } func (s *ExprSuite) TestMulExpr(c *gc.C) { - expr := Mul(Literal(1), Literal(2), Literal(3)) + expr := Mul(literal(1), literal(2), literal(3)) buf := &bytes.Buffer{} @@ -143,7 +143,7 @@ func (s *ExprSuite) TestMulExpr(c *gc.C) { } func (s *ExprSuite) TestDivExpr(c *gc.C) { - expr := Div(Literal(1), Literal(2), Literal(3)) + expr := Div(literal(1), literal(2), literal(3)) buf := &bytes.Buffer{} @@ -388,7 +388,7 @@ func (s *ExprSuite) TestColumnValue(c *gc.C) { } func (s *ExprSuite) TestBitwiseOr(c *gc.C) { - clause := BitOr(Literal(1), Literal(2)) + clause := BitOr(literal(1), literal(2)) buf := &bytes.Buffer{} @@ -400,7 +400,7 @@ func (s *ExprSuite) TestBitwiseOr(c *gc.C) { } func (s *ExprSuite) TestBitwiseAnd(c *gc.C) { - clause := BitAnd(Literal(1), Literal(2)) + clause := BitAnd(literal(1), literal(2)) buf := &bytes.Buffer{} @@ -412,7 +412,7 @@ func (s *ExprSuite) TestBitwiseAnd(c *gc.C) { } func (s *ExprSuite) TestBitwiseXor(c *gc.C) { - clause := BitXor(Literal(1), Literal(2)) + clause := BitXor(literal(1), literal(2)) buf := &bytes.Buffer{} @@ -424,7 +424,7 @@ func (s *ExprSuite) TestBitwiseXor(c *gc.C) { } func (s *ExprSuite) TestPlus(c *gc.C) { - clause := Plus(Literal(1), Literal(2)) + clause := Plus(literal(1), literal(2)) buf := &bytes.Buffer{} @@ -436,7 +436,7 @@ func (s *ExprSuite) TestPlus(c *gc.C) { } func (s *ExprSuite) TestMinus(c *gc.C) { - clause := Minus(Literal(1), Literal(2)) + clause := Minus(literal(1), literal(2)) buf := &bytes.Buffer{} diff --git a/sqlbuilder/expression_test.go b/sqlbuilder/expression_test.go index 429410d..59289e5 100644 --- a/sqlbuilder/expression_test.go +++ b/sqlbuilder/expression_test.go @@ -60,3 +60,35 @@ func TestExpressionCAST_TO_TIMESTAMP(t *testing.T) { func TestExpressionCAST_TO_TIMESTAMPZ(t *testing.T) { assertExpressionSerialize(t, table2Col3.CAST_TO_TIMESTAMPZ(), "table2.col3::timestamp with time zone") } + +func TestIN(t *testing.T) { + + assertExpressionSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), + `($1 IN (( + SELECT table1.col1 AS "table1.col1" + FROM db.table1 +)))`, float64(1.11)) + + assertExpressionSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), + `(ROW($1, table1.col1) IN (( + SELECT table2.col3 AS "table2.col3", + table3.col1 AS "table3.col1" + FROM db.table2 +)))`, int64(12)) +} + +func TestNOT_IN(t *testing.T) { + + assertExpressionSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), + `($1 NOT IN (( + SELECT table1.col1 AS "table1.col1" + FROM db.table1 +)))`, float64(1.11)) + + assertExpressionSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), + `(ROW($1, table1.col1) NOT IN (( + SELECT table2.col3 AS "table2.col3", + table3.col1 AS "table3.col1" + FROM db.table2 +)))`, int64(12)) +} diff --git a/sqlbuilder/func_expression.go b/sqlbuilder/func_expression.go index b2453ad..ab0f523 100644 --- a/sqlbuilder/func_expression.go +++ b/sqlbuilder/func_expression.go @@ -49,9 +49,6 @@ func (f *funcExpressionImpl) serialize(statement statementType, out *queryData, return nil } -func ROW(expressions ...expression) expression { - return newFunc("ROW", expressions, nil) -} type boolFunc struct { funcExpressionImpl @@ -179,6 +176,10 @@ func newTimestampzFunc(name string, expressions ...expression) *timestampzFunc { return timestampzFunc } +func ROW(expressions ...expression) expression { + return newFunc("ROW", expressions, nil) +} + // ------------------ Mathematical functions ---------------// func ABSf(floatExpression FloatExpression) FloatExpression { @@ -370,8 +371,8 @@ func DECODE(data StringExpression, format StringExpression) StringExpression { return newStringFunc("DECODE", data, format) } -//func FORMAT(formatStr StringExpression, formatArgs ...expression) StringExpression { -// args := []expression{formatStr} +//func FORMAT(formatStr StringExpression, formatArgs ...expressions) StringExpression { +// args := []expressions{formatStr} // args = append(args, formatArgs...) // return newStringFunc("FORMAT", args...) //} @@ -479,7 +480,7 @@ func CURRENT_TIME(precision ...int) TimezExpression { var timezFunc *timezFunc if len(precision) > 0 { - timezFunc = newTimezFunc("CURRENT_TIME", ConstantLiteral(precision[0])) + timezFunc = newTimezFunc("CURRENT_TIME", constLiteral(precision[0])) } else { timezFunc = newTimezFunc("CURRENT_TIME") } @@ -493,7 +494,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression { var timestampzFunc *timestampzFunc if len(precision) > 0 { - timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", ConstantLiteral(precision[0])) + timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", constLiteral(precision[0])) } else { timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP") } @@ -507,7 +508,7 @@ func LOCALTIME(precision ...int) TimeExpression { var timeFunc *timeFunc if len(precision) > 0 { - timeFunc = newTimeFunc("LOCALTIME", ConstantLiteral(precision[0])) + timeFunc = newTimeFunc("LOCALTIME", constLiteral(precision[0])) } else { timeFunc = newTimeFunc("LOCALTIME") } @@ -521,7 +522,7 @@ func LOCALTIMESTAMP(precision ...int) TimestampExpression { var timestampFunc *timestampFunc if len(precision) > 0 { - timestampFunc = newTimestampFunc("LOCALTIMESTAMP", ConstantLiteral(precision[0])) + timestampFunc = newTimestampFunc("LOCALTIMESTAMP", constLiteral(precision[0])) } else { timestampFunc = newTimestampFunc("LOCALTIMESTAMP") } diff --git a/sqlbuilder/insert_statement.go b/sqlbuilder/insert_statement.go index 19f9104..a008bce 100644 --- a/sqlbuilder/insert_statement.go +++ b/sqlbuilder/insert_statement.go @@ -59,7 +59,7 @@ func (i *insertStatementImpl) VALUES(values ...interface{}) insertStatement { if clause, ok := value.(clause); ok { literalRow = append(literalRow, clause) } else { - literalRow = append(literalRow, Literal(value)) + literalRow = append(literalRow, literal(value)) } } @@ -97,7 +97,7 @@ func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) insertStatement { return i } - rowValues = append(rowValues, Literal(structField.Interface())) + rowValues = append(rowValues, literal(structField.Interface())) } i.rows = append(i.rows, rowValues) diff --git a/sqlbuilder/literal_expression.go b/sqlbuilder/literal_expression.go index 54f4c8f..8bd2feb 100644 --- a/sqlbuilder/literal_expression.go +++ b/sqlbuilder/literal_expression.go @@ -9,15 +9,15 @@ type literalExpression struct { constant bool } -func Literal(value interface{}) *literalExpression { +func literal(value interface{}) *literalExpression { exp := literalExpression{value: value} exp.expressionInterfaceImpl.parent = &exp return &exp } -func ConstantLiteral(value interface{}) *literalExpression { - exp := Literal(value) +func constLiteral(value interface{}) *literalExpression { + exp := literal(value) exp.constant = true return exp @@ -41,7 +41,7 @@ type integerLiteralExpression struct { func Int(value int64) IntegerExpression { numLiteral := &integerLiteralExpression{} - numLiteral.literalExpression = *Literal(value) + numLiteral.literalExpression = *literal(value) numLiteral.literalExpression.parent = numLiteral numLiteral.integerInterfaceImpl.parent = numLiteral @@ -58,7 +58,7 @@ type boolLiteralExpression struct { func Bool(value bool) BoolExpression { boolLiteralExpression := boolLiteralExpression{} - boolLiteralExpression.literalExpression = *Literal(value) + boolLiteralExpression.literalExpression = *literal(value) boolLiteralExpression.boolInterfaceImpl.parent = &boolLiteralExpression return &boolLiteralExpression @@ -72,7 +72,7 @@ type floatLiteral struct { func Float(value float64) FloatExpression { floatLiteral := floatLiteral{} - floatLiteral.literalExpression = *Literal(value) + floatLiteral.literalExpression = *literal(value) floatLiteral.floatInterfaceImpl.parent = &floatLiteral @@ -87,7 +87,7 @@ type stringLiteral struct { func String(value string) StringExpression { stringLiteral := stringLiteral{} - stringLiteral.literalExpression = *Literal(value) + stringLiteral.literalExpression = *literal(value) stringLiteral.stringInterfaceImpl.parent = &stringLiteral @@ -103,7 +103,7 @@ type timeLiteral struct { func Time(hour, minute, second, milliseconds int) TimeExpression { timeLiteral := timeLiteral{} timeStr := fmt.Sprintf("%02d:%02d:%02d.%03d", hour, minute, second, milliseconds) - timeLiteral.literalExpression = *Literal(timeStr) + timeLiteral.literalExpression = *literal(timeStr) timeLiteral.timeInterfaceImpl.parent = &timeLiteral @@ -119,7 +119,7 @@ type timezLiteral struct { func Timez(hour, minute, second, milliseconds, timezone int) TimezExpression { timezLiteral := timezLiteral{} timeStr := fmt.Sprintf("%02d:%02d:%02d.%03d %+03d", hour, minute, second, milliseconds, timezone) - timezLiteral.literalExpression = *Literal(timeStr) + timezLiteral.literalExpression = *literal(timeStr) timezLiteral.timezInterfaceImpl.parent = &timezLiteral @@ -135,7 +135,7 @@ type timestampLiteral struct { func Timestamp(year, month, day, hour, minute, second, milliseconds int) TimestampExpression { timestampLiteral := timestampLiteral{} timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%03d", year, month, day, hour, minute, second, milliseconds) - timestampLiteral.literalExpression = *Literal(timeStr) + timestampLiteral.literalExpression = *literal(timeStr) timestampLiteral.timestampInterfaceImpl.parent = ×tampLiteral @@ -153,7 +153,7 @@ func Timestampz(year, month, day, hour, minute, second, milliseconds, timezone i timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d.%03d %+04d", year, month, day, hour, minute, second, milliseconds, timezone) - timestampzLiteral.literalExpression = *Literal(timeStr) + timestampzLiteral.literalExpression = *literal(timeStr) timestampzLiteral.timestampzInterfaceImpl.parent = ×tampzLiteral @@ -170,7 +170,7 @@ func Date(year, month, day int) DateExpression { dateLiteral := dateLiteral{} timeStr := fmt.Sprintf("%04d-%02d-%02d", year, month, day) - dateLiteral.literalExpression = *Literal(timeStr) + dateLiteral.literalExpression = *literal(timeStr) dateLiteral.dateInterfaceImpl.parent = &dateLiteral return dateLiteral.CAST_TO_DATE() @@ -211,3 +211,24 @@ func (n *starLiteral) serialize(statement statementType, out *queryData, options out.writeString("*") return nil } + +//---------------------------------------------------// + +type wrap struct { + expressionInterfaceImpl + expressions []expression +} + +func (n *wrap) serialize(statement statementType, out *queryData, options ...serializeOption) error { + out.writeString("(") + err := serializeExpressionList(statement, n.expressions, ", ", out) + out.writeString(")") + return err +} + +func WRAP(expression ...expression) expression { + wrap := &wrap{expressions: expression} + wrap.expressionInterfaceImpl.parent = wrap + + return wrap +} diff --git a/sqlbuilder/operators.go b/sqlbuilder/operators.go index c224e75..c7ddf6a 100644 --- a/sqlbuilder/operators.go +++ b/sqlbuilder/operators.go @@ -82,20 +82,12 @@ func Or(lhs, rhs expression) BoolExpression { return newBinaryBoolExpression(lhs, rhs, "OR") } -func Like(lhs, rhs expression) BoolExpression { - return newBinaryBoolExpression(lhs, rhs, "LIKE") -} - -func LikeL(lhs expression, val string) BoolExpression { - return Like(lhs, Literal(val)) -} - func Regexp(lhs, rhs expression) BoolExpression { return newBinaryBoolExpression(lhs, rhs, "REGEXP") } func RegexpL(lhs expression, val string) BoolExpression { - return Regexp(lhs, Literal(val)) + return Regexp(lhs, literal(val)) } func EXISTS(subQuery selectStatement) BoolExpression { diff --git a/sqlbuilder/operators_test.go b/sqlbuilder/operators_test.go index 568639b..9b13e73 100644 --- a/sqlbuilder/operators_test.go +++ b/sqlbuilder/operators_test.go @@ -2,6 +2,14 @@ package sqlbuilder import "testing" +func TestOperatorNOT(t *testing.T) { + notExpression := NOT(Int(2).EQ(Int(1))) + + assertExpressionSerialize(t, notExpression, "NOT ($1 = $2)", int64(2), int64(1)) + assertProjectionSerialize(t, notExpression.AS("alias_not_expression"), `NOT ($1 = $2) AS "alias_not_expression"`, int64(2), int64(1)) + assertExpressionSerialize(t, notExpression.AND(Int(4).EQ(Int(5))), `(NOT ($1 = $2) AND ($3 = $4))`, int64(2), int64(1), int64(4), int64(5)) +} + func TestCase1(t *testing.T) { query := CASE(). WHEN(table3Col1.EQ(Int(1))).THEN(table3Col1.ADD(Int(1))). diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index d2ec6ed..8555ba1 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -127,18 +127,12 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error { return err } - if s.table == nil { - return errors.Newf("nil tableName.") + if s.table != nil { + if err := out.writeFrom(select_statement, s.table); err != nil { + return err + } } - if err := out.writeFrom(select_statement, s.table); err != nil { - return err - } - - //if err := s.table.serialize(select_statement, out); err != nil { - // return err - //} - if s.where != nil { err := out.writeWhere(select_statement, s.where) diff --git a/sqlbuilder/statement_test.go b/sqlbuilder/statement_test.go index 095603d..7574fe0 100644 --- a/sqlbuilder/statement_test.go +++ b/sqlbuilder/statement_test.go @@ -261,13 +261,13 @@ func (s *StmtSuite) TestInsertNilValue(c *gc.C) { } func (s *StmtSuite) TestInsertNilColumn(c *gc.C) { - _, err := table1.INSERT(nil).Add(Literal(1)).String() + _, err := table1.INSERT(nil).Add(literal(1)).String() c.Assert(err, gc.NotNil) } func (s *StmtSuite) TestInsertSingleValue(c *gc.C) { - sql, err := table1.INSERT(table1Col1).Add(Literal(1)).String() + sql, err := table1.INSERT(table1Col1).Add(literal(1)).String() c.Assert(err, gc.IsNil) c.Assert( @@ -279,7 +279,7 @@ func (s *StmtSuite) TestInsertSingleValue(c *gc.C) { func (s *StmtSuite) TestInsertDate(c *gc.C) { date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) - sql, err := table1.INSERT(table1ColTime).Add(Literal(date)).String() + sql, err := table1.INSERT(table1ColTime).Add(literal(date)).String() c.Assert(err, gc.IsNil) c.Assert( @@ -290,7 +290,7 @@ func (s *StmtSuite) TestInsertDate(c *gc.C) { } func (s *StmtSuite) TestInsertIgnore(c *gc.C) { - stmt := table1.INSERT(table1Col1).Add(Literal(1)).IgnoreDuplicates(true) + stmt := table1.INSERT(table1Col1).Add(literal(1)).IgnoreDuplicates(true) sql, err := stmt.String() c.Assert(err, gc.IsNil) @@ -302,7 +302,7 @@ func (s *StmtSuite) TestInsertIgnore(c *gc.C) { func (s *StmtSuite) TestInsertMultipleValues(c *gc.C) { stmt := table1.INSERT(table1Col1, table1ColFloat, table1Col3) - stmt.Add(Literal(1), Literal(2), Literal(3)) + stmt.Add(literal(1), literal(2), literal(3)) sql, err := stmt.String() c.Assert(err, gc.IsNil) @@ -317,9 +317,9 @@ func (s *StmtSuite) TestInsertMultipleValues(c *gc.C) { func (s *StmtSuite) TestInsertMultipleRows(c *gc.C) { stmt := table1.INSERT(table1Col1, table1ColFloat) - stmt.Add(Literal(1), Literal(2)) - stmt.Add(Literal(11), Literal(22)) - stmt.Add(Literal(111), Literal(222)) + stmt.Add(literal(1), literal(2)) + stmt.Add(literal(11), literal(22)) + stmt.Add(literal(111), literal(222)) sql, err := stmt.String() c.Assert(err, gc.IsNil) @@ -334,8 +334,8 @@ func (s *StmtSuite) TestInsertMultipleRows(c *gc.C) { func (s *StmtSuite) TestOnDuplicateKeyUpdateNilCol(c *gc.C) { stmt := table1.INSERT(table1Col1, table1ColFloat) - stmt.Add(Literal(1), Literal(2)) - stmt.AddOnDuplicateKeyUpdate(nil, Literal(3)) + stmt.Add(literal(1), literal(2)) + stmt.AddOnDuplicateKeyUpdate(nil, literal(3)) _, err := stmt.String() c.Assert(err, gc.NotNil) @@ -343,7 +343,7 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateNilCol(c *gc.C) { func (s *StmtSuite) TestOnDuplicateKeyUpdateNilExpr(c *gc.C) { stmt := table1.INSERT(table1Col1, table1ColFloat) - stmt.Add(Literal(1), Literal(2)) + stmt.Add(literal(1), literal(2)) stmt.AddOnDuplicateKeyUpdate(table1Col1, nil) _, err := stmt.String() @@ -352,8 +352,8 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateNilExpr(c *gc.C) { func (s *StmtSuite) TestOnDuplicateKeyUpdateSingle(c *gc.C) { stmt := table1.INSERT(table1Col1, table1ColFloat) - stmt.Add(Literal(1), Literal(2)) - stmt.AddOnDuplicateKeyUpdate(table1Col3, Literal(3)) + stmt.Add(literal(1), literal(2)) + stmt.AddOnDuplicateKeyUpdate(table1Col3, literal(3)) sql, err := stmt.String() c.Assert(err, gc.IsNil) @@ -369,9 +369,9 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateSingle(c *gc.C) { func (s *StmtSuite) TestOnDuplicateKeyUpdateMulti(c *gc.C) { stmt := table1.INSERT(table1Col1, table1ColFloat) - stmt.Add(Literal(1), Literal(2)) - stmt.AddOnDuplicateKeyUpdate(table1Col3, Literal(3)) - stmt.AddOnDuplicateKeyUpdate(table1ColFloat, Literal(4)) + stmt.Add(literal(1), literal(2)) + stmt.AddOnDuplicateKeyUpdate(table1Col3, literal(3)) + stmt.AddOnDuplicateKeyUpdate(table1ColFloat, literal(4)) sql, err := stmt.String() c.Assert(err, gc.IsNil) diff --git a/sqlbuilder/test_utils.go b/sqlbuilder/test_utils.go index 9fc3d1a..34ef813 100644 --- a/sqlbuilder/test_utils.go +++ b/sqlbuilder/test_utils.go @@ -58,3 +58,13 @@ func assertExpressionSerialize(t *testing.T, expression expression, query string assert.DeepEqual(t, out.buff.String(), query) assert.DeepEqual(t, out.args, args) } + +func assertProjectionSerialize(t *testing.T, projection projection, query string, args ...interface{}) { + out := queryData{} + err := projection.serializeForProjection(select_statement, &out) + + assert.NilError(t, err) + + assert.DeepEqual(t, out.buff.String(), query) + assert.DeepEqual(t, out.args, args) +} diff --git a/sqlbuilder/time_expression.go b/sqlbuilder/time_expression.go index d4c4e10..bf81e98 100644 --- a/sqlbuilder/time_expression.go +++ b/sqlbuilder/time_expression.go @@ -69,5 +69,5 @@ func newPrefixTimeExpression(operator string, expression expression) TimeExpress } func INTERVAL(interval string) expression { - return newPrefixTimeExpression("INTERVAL", Literal(interval)) + return newPrefixTimeExpression("INTERVAL", literal(interval)) } diff --git a/sqlbuilder/update_statement.go b/sqlbuilder/update_statement.go index 9a60814..b54216d 100644 --- a/sqlbuilder/update_statement.go +++ b/sqlbuilder/update_statement.go @@ -35,7 +35,7 @@ func (u *updateStatementImpl) SET(values ...interface{}) updateStatement { if clause, ok := value.(clause); ok { u.updateValues = append(u.updateValues, clause) } else { - u.updateValues = append(u.updateValues, Literal(value)) + u.updateValues = append(u.updateValues, literal(value)) } } diff --git a/sqlbuilder/update_statement_test.go b/sqlbuilder/update_statement_test.go index 1ff95ce..7b4ba79 100644 --- a/sqlbuilder/update_statement_test.go +++ b/sqlbuilder/update_statement_test.go @@ -34,7 +34,7 @@ RETURNING table1.col1 AS "table1.col1"; } //func (s *StmtSuite) TestUpdateNilColumn(c *gc.C) { -// stmt := table1.UPDATE().SET(nil, Literal(1)) +// stmt := table1.UPDATE().SET(nil, literal(1)) // _, err := stmt.String() // c.Assert(err, gc.NotNil) //} @@ -46,13 +46,13 @@ RETURNING table1.col1 AS "table1.col1"; //} // //func (s *StmtSuite) TestUpdateUnconditionally(c *gc.C) { -// stmt := table1.UPDATE().SET(table1Col1, Literal(1)) +// stmt := table1.UPDATE().SET(table1Col1, literal(1)) // _, err := stmt.String() // c.Assert(err, gc.NotNil) //} // //func (s *StmtSuite) TestUpdateSingleValue(c *gc.C) { -// stmt := table1.UPDATE().SET(table1Col1, Literal(1)) +// stmt := table1.UPDATE().SET(table1Col1, literal(1)) // stmt.WHERE(EqString(table1ColFloat, 2)) // sql, err := stmt.String() // c.Assert(err, gc.IsNil) @@ -64,7 +64,7 @@ RETURNING table1.col1 AS "table1.col1"; //} // //func (s *StmtSuite) TestUpdateUsingDeferredLookupColumns(c *gc.C) { -// stmt := table1.UPDATE().SET(table1.C("col1"), Literal(1)) +// stmt := table1.UPDATE().SET(table1.C("col1"), literal(1)) // stmt.WHERE(EqString(table1ColFloat, 2)) // sql, err := stmt.String() // c.Assert(err, gc.IsNil) @@ -77,8 +77,8 @@ RETURNING table1.col1 AS "table1.col1"; // //func (s *StmtSuite) TestUpdateMultiValues(c *gc.C) { // stmt := table1.UPDATE() -// stmt.SET(table1Col1, Literal(1)) -// stmt.SET(table1ColFloat, Literal(2)) +// stmt.SET(table1Col1, literal(1)) +// stmt.SET(table1ColFloat, literal(2)) // stmt.WHERE(EqString(table1ColFloat, 3)) // sql, err := stmt.String() // c.Assert(err, gc.IsNil) @@ -92,7 +92,7 @@ RETURNING table1.col1 AS "table1.col1"; //} // //func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) { -// stmt := table1.UPDATE().SET(table1Col1, Literal(1)) +// stmt := table1.UPDATE().SET(table1Col1, literal(1)) // stmt.WHERE(EqString(table1ColFloat, 2)) // stmt.ORDER_BY(table1ColFloat) // sql, err := stmt.String() @@ -108,7 +108,7 @@ RETURNING table1.col1 AS "table1.col1"; //} // //func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) { -// stmt := table1.UPDATE().SET(table1Col1, Literal(1)) +// stmt := table1.UPDATE().SET(table1Col1, literal(1)) // stmt.WHERE(EqString(table1ColFloat, 2)) // stmt.LIMIT(5) // sql, err := stmt.String() diff --git a/tests/scan_test.go b/tests/scan_test.go index 201f3f1..0899cd9 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -269,7 +269,7 @@ func TestScanToNestedStruct(t *testing.T) { query := Inventory. INNER_JOIN(Film, Inventory.FilmID.EQ(Film.FilmID)). INNER_JOIN(Store, Inventory.StoreID.EQ(Store.StoreID)). - SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns, Literal("").AS("actor.first_name")). + SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns, String("").AS("actor.first_name")). WHERE(Inventory.InventoryID.EQ(Int(1))) dest := struct { diff --git a/tests/select_test.go b/tests/select_test.go index da80a48..b1fed8f 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -1086,10 +1086,10 @@ LIMIT 20; ` query := Payment.SELECT( CASE(Payment.StaffID). - WHEN(Int(1)).THEN(Literal("ONE")). - WHEN(Int(2)).THEN(Literal("TWO")). - WHEN(Int(3)).THEN(Literal("THREE")). - ELSE(Literal("OTHER")).AS("staff_id_num"), + WHEN(Int(1)).THEN(String("ONE")). + WHEN(Int(2)).THEN(String("TWO")). + WHEN(Int(3)).THEN(String("THREE")). + ELSE(String("OTHER")).AS("staff_id_num"), ). ORDER_BY(Payment.PaymentID.ASC()). LIMIT(20) diff --git a/tests/types_test.go b/tests/types_test.go index 0687e44..de7bfee 100644 --- a/tests/types_test.go +++ b/tests/types_test.go @@ -29,6 +29,10 @@ func TestExpressionOperators(t *testing.T) { query := AllTypes.SELECT( AllTypes.Integer.IS_NULL(), AllTypes.Timestamp.IS_NOT_NULL(), + AllTypes.SmallintPtr.IN(Int(11), Int(22), NULL), + AllTypes.SmallintPtr.IN(AllTypes.SELECT(AllTypes.IntegerPtr)), + AllTypes.SmallintPtr.NOT_IN(Int(11), Int(22), NULL), + AllTypes.SmallintPtr.NOT_IN(AllTypes.SELECT(AllTypes.IntegerPtr)), String("TRUE").CAST_TO_BOOL(), String("111").CAST_TO_INTEGER(),