From 5c05214ba1510295792bdd2d26d541ebb4258117 Mon Sep 17 00:00:00 2001 From: zer0sub Date: Wed, 1 May 2019 18:23:19 +0200 Subject: [PATCH] Fix integration tests. --- sqlbuilder/expression.go | 21 +++++-- sqlbuilder/set_statement.go | 4 +- sqlbuilder/set_statement_test.go | 70 +++++++++++++++++++++ tests/main_test.go | 13 ++++ tests/{generator_test.go => select_test.go} | 12 ---- 5 files changed, 102 insertions(+), 18 deletions(-) rename tests/{generator_test.go => select_test.go} (98%) diff --git a/sqlbuilder/expression.go b/sqlbuilder/expression.go index 2fb127d..dd43799 100644 --- a/sqlbuilder/expression.go +++ b/sqlbuilder/expression.go @@ -60,6 +60,20 @@ func newBinaryExpression(lhs, rhs Expression, operator []byte, parent ...Express return binaryExpression } +func isSimpleOperand(expression Expression) bool { + if _, ok := expression.(*literalExpression); ok { + return true + } + if _, ok := expression.(Column); ok { + return true + } + if _, ok := expression.(FuncExpression); ok { + return true + } + + return false +} + func (c *binaryExpression) Serialize(out *queryData, options ...serializeOption) error { if c.lhs == nil { return errors.Newf("nil lhs.") @@ -68,10 +82,9 @@ func (c *binaryExpression) Serialize(out *queryData, options ...serializeOption) return errors.Newf("nil rhs.") } - _, literalLeft := c.lhs.(*literalExpression) - _, literalRight := c.rhs.(*literalExpression) + wrap := !isSimpleOperand(c.lhs) && !isSimpleOperand(c.rhs) - if !literalLeft && !literalRight { + if wrap { out.WriteString("(") } @@ -85,7 +98,7 @@ func (c *binaryExpression) Serialize(out *queryData, options ...serializeOption) return err } - if !literalLeft && !literalRight { + if wrap { out.WriteString(")") } diff --git a/sqlbuilder/set_statement.go b/sqlbuilder/set_statement.go index 16fcec8..52f9ace 100644 --- a/sqlbuilder/set_statement.go +++ b/sqlbuilder/set_statement.go @@ -81,8 +81,8 @@ func (us *setStatementImpl) OFFSET(offset int64) SetStatement { } func (us *setStatementImpl) Serialize(out *queryData, options ...serializeOption) error { - if len(us.selects) == 0 { - return errors.Newf("UNION statement must have at least one SELECT") + if len(us.selects) < 2 { + return errors.Newf("UNION statement must have at least two SELECT statements.") } out.WriteString("(") diff --git a/sqlbuilder/set_statement_test.go b/sqlbuilder/set_statement_test.go index 3896e39..2136f40 100644 --- a/sqlbuilder/set_statement_test.go +++ b/sqlbuilder/set_statement_test.go @@ -1 +1,71 @@ package sqlbuilder + +import ( + "fmt" + "gotest.tools/assert" + "testing" +) + +func TestUnionNoSelect(t *testing.T) { + query, args, err := UNION().Sql() + + assert.Assert(t, err != nil) + //fmt.Println(err.Error()) + fmt.Print(query, args) +} + +func TestUnionOneSelect(t *testing.T) { + query, args, err := UNION( + table1.SELECT(table1Col1), + ).Sql() + + assert.Assert(t, err != nil) + fmt.Println(err.Error()) + fmt.Println(query) + fmt.Println(args) +} + +func TestUnionTwoSelect(t *testing.T) { + query, args, err := UNION( + table1.SELECT(table1Col1), + table2.SELECT(table2Col3), + ).Sql() + + assert.NilError(t, err) + assert.Equal(t, query, `((SELECT table1.col1 AS "table1.col1" FROM db.table1) UNION (SELECT table2.col3 AS "table2.col3" FROM db.table2))`) + assert.Equal(t, len(args), 0) +} + +func TestUnionThreeSelect(t *testing.T) { + query, args, err := UNION( + table1.SELECT(table1Col1), + table2.SELECT(table2Col3), + table3.SELECT(table3Col1), + ).Sql() + + assert.NilError(t, err) + assert.Equal(t, query, `((SELECT table1.col1 AS "table1.col1" FROM db.table1) UNION (SELECT table2.col3 AS "table2.col3" FROM db.table2) UNION (SELECT table3.col1 AS "table3.col1" FROM db.table3))`) + assert.Equal(t, len(args), 0) +} + +func TestUnionWithOrderBy(t *testing.T) { + query, args, err := UNION( + table1.SELECT(table1Col1), + table2.SELECT(table2Col3), + ).ORDER_BY(table1Col1.Asc()).Sql() + + assert.NilError(t, err) + assert.Equal(t, query, `((SELECT table1.col1 AS "table1.col1" FROM db.table1) UNION (SELECT table2.col3 AS "table2.col3" FROM db.table2)) ORDER BY table1.col1 ASC`) + assert.Equal(t, len(args), 0) +} + +func TestUnionWithLimit(t *testing.T) { + query, args, err := UNION( + table1.SELECT(table1Col1), + table2.SELECT(table2Col3), + ).LIMIT(10).OFFSET(11).Sql() + + assert.NilError(t, err) + assert.Equal(t, query, `((SELECT table1.col1 AS "table1.col1" FROM db.table1) UNION (SELECT table2.col3 AS "table2.col3" FROM db.table2)) LIMIT $1 OFFSET $2`) + assert.Equal(t, len(args), 2) +} diff --git a/tests/main_test.go b/tests/main_test.go index 95c6d3e..e4b6e62 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -4,6 +4,8 @@ import ( "database/sql" "fmt" _ "github.com/lib/pq" + "github.com/sub0zero/go-sqlbuilder/generator" + "gotest.tools/assert" "os" "testing" ) @@ -75,3 +77,14 @@ CREATE TABLE IF NOT EXISTS test_sample.link ( fmt.Println(result) } + +func TestGenerateModel(t *testing.T) { + + err := generator.Generate(folderPath, connectString, dbname, schemaName) + + assert.NilError(t, err) + + //err = generator.Generate(folderPath, connectString, dbname, "sport") + // + //assert.NilError(t, err) +} diff --git a/tests/generator_test.go b/tests/select_test.go similarity index 98% rename from tests/generator_test.go rename to tests/select_test.go index ee1be4e..77b0318 100644 --- a/tests/generator_test.go +++ b/tests/select_test.go @@ -2,7 +2,6 @@ package tests import ( "fmt" - "github.com/sub0zero/go-sqlbuilder/generator" "github.com/sub0zero/go-sqlbuilder/sqlbuilder" "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model" . "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table" @@ -12,17 +11,6 @@ import ( "time" ) -func TestGenerateModel(t *testing.T) { - - err := generator.Generate(folderPath, connectString, dbname, schemaName) - - assert.NilError(t, err) - - //err = generator.Generate(folderPath, connectString, dbname, "sport") - // - //assert.NilError(t, err) -} - func TestSelect_ScanToStruct(t *testing.T) { actor := model.Actor{} query := Actor.SELECT(Actor.AllColumns).ORDER_BY(Actor.ActorID.Asc())