From d197956271a3e567d3f9d0f7ec390bc753e477aa Mon Sep 17 00:00:00 2001 From: go-jet Date: Thu, 21 Oct 2021 13:35:37 +0200 Subject: [PATCH] Avoid unnecessary double wrapping of SELECT statement when used as single function parameter. --- internal/jet/literal_expression.go | 9 ++++-- internal/jet/serializer.go | 2 ++ internal/jet/statement.go | 9 ++++++ internal/jet/utils.go | 11 +++++-- mysql/insert_statement_test.go | 1 - postgres/dialect_test.go | 16 +++++----- tests/mysql/alltypes_test.go | 8 ++--- tests/mysql/insert_test.go | 2 +- tests/mysql/update_test.go | 13 ++++----- tests/postgres/alltypes_test.go | 8 ++--- tests/postgres/generator_test.go | 15 +++++----- tests/postgres/insert_test.go | 3 +- tests/postgres/main_test.go | 11 ++----- tests/postgres/select_test.go | 11 +++---- tests/postgres/update_test.go | 8 ++--- tests/postgres/with_test.go | 47 +++++++++++++++++------------- 16 files changed, 97 insertions(+), 77 deletions(-) diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index d7cf47a..450b0ab 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -375,9 +375,14 @@ type wrap struct { expressions []Expression } -func (n *wrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (n *wrap) serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString("(") - serializeExpressionList(statement, n.expressions, ", ", out) + + if len(n.expressions) == 1 { + options = append(options, NoWrap, Ident) + } + serializeExpressionList(statementType, n.expressions, ", ", out, options...) + out.WriteString(")") } diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index b8cf04a..866d60e 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -7,8 +7,10 @@ type SerializeOption int const ( NoWrap SerializeOption = iota SkipNewLine + Ident fallTroughOptions // fall trough options + ShortName ) diff --git a/internal/jet/statement.go b/internal/jet/statement.go index da3650d..1d05045 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -195,10 +195,19 @@ func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, opti out.IncreaseIdent() } + if contains(options, Ident) { + out.IncreaseIdent() + } + for _, clause := range s.Clauses { clause.Serialize(statement, out, FallTrough(options)...) } + if contains(options, Ident) { + out.DecreaseIdent() + out.NewLine() + } + if !contains(options, NoWrap) { out.DecreaseIdent() out.NewLine() diff --git a/internal/jet/utils.go b/internal/jet/utils.go index b2fff48..eab4403 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -21,14 +21,19 @@ func SerializeClauseList(statement StatementType, clauses []Serializer, out *SQL } } -func serializeExpressionList(statement StatementType, expressions []Expression, separator string, out *SQLBuilder) { +func serializeExpressionList( + statement StatementType, + expressions []Expression, + separator string, + out *SQLBuilder, + options ...SerializeOption) { - for i, value := range expressions { + for i, expression := range expressions { if i > 0 { out.WriteString(separator) } - value.serialize(statement, out) + expression.serialize(statement, out, options...) } } diff --git a/mysql/insert_statement_test.go b/mysql/insert_statement_test.go index dbabc3f..7b396d0 100644 --- a/mysql/insert_statement_test.go +++ b/mysql/insert_statement_test.go @@ -7,7 +7,6 @@ import ( ) func TestInvalidInsert(t *testing.T) { - assertStatementSqlErr(t, table1.INSERT(table1Col1), "jet: VALUES or QUERY has to be specified for INSERT statement") assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") } diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index 9b7b3d1..d98d8f3 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -46,33 +46,33 @@ func TestExists(t *testing.T) { func TestIN(t *testing.T) { assertSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), - `($1 IN (( + `($1 IN ( SELECT table1.col1 AS "table1.col1" FROM db.table1 -)))`, float64(1.11)) +))`, float64(1.11)) assertSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), - `(ROW($1, table1.col1) IN (( + `(ROW($1, table1.col1) IN ( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2 -)))`, int64(12)) +))`, int64(12)) } func TestNOT_IN(t *testing.T) { assertSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), - `($1 NOT IN (( + `($1 NOT IN ( SELECT table1.col1 AS "table1.col1" FROM db.table1 -)))`, float64(1.11)) +))`, float64(1.11)) assertSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), - `(ROW($1, table1.col1) NOT IN (( + `(ROW($1, table1.col1) NOT IN ( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2 -)))`, int64(12)) +))`, int64(12)) } func TestReservedWordEscaped(t *testing.T) { diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 5a5012d..2132d7a 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -104,18 +104,18 @@ func TestExpressionOperators(t *testing.T) { SELECT all_types.'integer' IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN (?, ?)) AS "result.in", - (all_types.small_int_ptr IN (( + (all_types.small_int_ptr IN ( SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types - ))) AS "result.in_select", + )) AS "result.in_select", (CURRENT_USER()) AS "result.raw", (? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg", (? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2", (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", - (all_types.small_int_ptr NOT IN (( + (all_types.small_int_ptr NOT IN ( SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types - ))) AS "result.not_in_select" + )) AS "result.not_in_select" FROM test_sample.all_types LIMIT ?; `, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2)) diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index 4f39d6c..55fc706 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -278,7 +278,7 @@ ON DUPLICATE KEY UPDATE id = (id + ?), err := SELECT(Link.AllColumns). FROM(Link). - WHERE(Link.ID.EQ(Int(int64(randId)).ADD(Int(11)))). + WHERE(Link.ID.EQ(Int32(randId).ADD(Int(11)))). Query(db, &newLinks) require.NoError(t, err) diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index 281e17b..dc28924 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -121,7 +121,7 @@ func TestUpdateWithModelData(t *testing.T) { stmt := Link. UPDATE(Link.AllColumns). MODEL(link). - WHERE(Link.ID.EQ(Int(int64(link.ID)))) + WHERE(Link.ID.EQ(Int32(link.ID))) expectedSQL := ` UPDATE test_sample.link @@ -131,7 +131,7 @@ SET id = ?, description = ? WHERE link.id = ?; ` - testutils.AssertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) + testutils.AssertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) testutils.AssertExec(t, stmt, db) requireLogged(t, stmt) @@ -152,7 +152,7 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { stmt := Link. UPDATE(updateColumnList). MODEL(link). - WHERE(Link.ID.EQ(Int(int64(link.ID)))) + WHERE(Link.ID.EQ(Int32(link.ID))) var expectedSQL = ` UPDATE test_sample.link @@ -161,9 +161,8 @@ SET description = NULL, url = 'http://www.duckduckgo.com' WHERE link.id = 201; ` - //fmt.Println(stmt.DebugSql()) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201)) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201)) testutils.AssertExec(t, stmt, db) requireLogged(t, stmt) @@ -181,7 +180,7 @@ func TestUpdateWithModelDataAndMutableColumns(t *testing.T) { stmt := Link. UPDATE(Link.MutableColumns). MODEL(link). - WHERE(Link.ID.EQ(Int(int64(link.ID)))) + WHERE(Link.ID.EQ(Int32(link.ID))) var expectedSQL = ` UPDATE test_sample.link @@ -192,7 +191,7 @@ WHERE link.id = 201; ` //fmt.Println(stmt.DebugSql()) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) testutils.AssertExec(t, stmt, db) } diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index ea66519..82ac82b 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -241,18 +241,18 @@ func TestExpressionOperators(t *testing.T) { SELECT all_types.integer IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN ($1, $2)) AS "result.in", - (all_types.small_int_ptr IN (( + (all_types.small_int_ptr IN ( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types - ))) AS "result.in_select", + )) AS "result.in_select", (CURRENT_USER) AS "result.raw", ($3 + COALESCE(all_types.small_int_ptr, 0) + $4) AS "result.raw_arg", ($5 + all_types.integer + $6 + $5 + $7 + $8) AS "result.raw_arg2", (all_types.small_int_ptr NOT IN ($9, $10, NULL)) AS "result.not_in", - (all_types.small_int_ptr NOT IN (( + (all_types.small_int_ptr NOT IN ( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types - ))) AS "result.not_in_select" + )) AS "result.not_in_select" FROM test_sample.all_types LIMIT $11; `, int64(11), int64(22), 78, 56, 11, 22, 33, 44, int64(11), int64(22), int64(2)) diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index b1b733e..a839987 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "os" "os/exec" + "path/filepath" "reflect" "testing" @@ -368,16 +369,16 @@ func newActorInfoTableImpl(schemaName, tableName, alias string) actorInfoTable { ` func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { - enumDir := testRoot + ".gentestdata/jetdb/test_sample/enum/" - modelDir := testRoot + ".gentestdata/jetdb/test_sample/model/" - tableDir := testRoot + ".gentestdata/jetdb/test_sample/table/" + enumDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/enum/") + modelDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/model/") + tableDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/table/") enumFiles, err := ioutil.ReadDir(enumDir) require.NoError(t, err) testutils.AssertFileNamesEqual(t, enumFiles, "mood.go", "level.go") - testutils.AssertFileContent(t, enumDir+"mood.go", moodEnumContent) - testutils.AssertFileContent(t, enumDir+"level.go", levelEnumContent) + testutils.AssertFileContent(t, enumDir+"/mood.go", moodEnumContent) + testutils.AssertFileContent(t, enumDir+"/level.go", levelEnumContent) modelFiles, err := ioutil.ReadDir(modelDir) require.NoError(t, err) @@ -385,7 +386,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go", "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go", "floats.go") - testutils.AssertFileContent(t, modelDir+"all_types.go", allTypesModelContent) + testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent) tableFiles, err := ioutil.ReadDir(tableDir) require.NoError(t, err) @@ -393,7 +394,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go", "person.go", "person_phone.go", "weird_names_table.go", "user.go", "floats.go") - testutils.AssertFileContent(t, tableDir+"all_types.go", allTypesTableContent) + testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent) } var moodEnumContent = ` diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index 9c9875c..8a50e02 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -140,8 +140,7 @@ ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; Link.ID.SET(Link.EXCLUDED.ID), Link.URL.SET(String("http://www.postgresqltutorial2.com")), ), - ). - RETURNING(Link.AllColumns) + ).RETURNING(Link.AllColumns) testutils.AssertStatementSql(t, stmt, ` INSERT INTO test_sample.link (id, url, name, description) diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index 541747c..4e8aade 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -4,10 +4,9 @@ import ( "context" "database/sql" "fmt" + "github.com/go-jet/jet/v2/tests/internal/utils/repo" "math/rand" "os" - "os/exec" - "strings" "testing" "time" @@ -53,13 +52,7 @@ func TestMain(m *testing.M) { } func setTestRoot() { - cmd := exec.Command("git", "rev-parse", "--show-toplevel") - byteArr, err := cmd.Output() - if err != nil { - panic(err) - } - - testRoot = strings.TrimSpace(string(byteArr)) + "/tests/" + testRoot = repo.GetTestsDirPath() } var loggedSQL string diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 7b7bf00..9635929 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -75,11 +75,12 @@ LIMIT 30; query := SELECT( Payment.AllColumns, Customer.AllColumns, - ). - FROM(Payment. - INNER_JOIN(Customer, Payment.CustomerID.EQ(Customer.CustomerID))). - ORDER_BY(Payment.PaymentID.ASC()). - LIMIT(30) + ).FROM( + Payment. + INNER_JOIN(Customer, Payment.CustomerID.EQ(Customer.CustomerID)), + ).ORDER_BY( + Payment.PaymentID.ASC(), + ).LIMIT(30) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(30)) diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index 043bf78..5ec44a1 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -259,14 +259,14 @@ func TestUpdateWithModelData(t *testing.T) { stmt := Link. UPDATE(Link.AllColumns). MODEL(link). - WHERE(Link.ID.EQ(Int(int64(link.ID)))) + WHERE(Link.ID.EQ(Int32(link.ID))) expectedSQL := ` UPDATE test_sample.link SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL) WHERE link.id = 201; ` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) AssertExec(t, stmt, 1) } @@ -286,14 +286,14 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { stmt := Link. UPDATE(updateColumnList). MODEL(link). - WHERE(Link.ID.EQ(Int(int64(link.ID)))) + WHERE(Link.ID.EQ(Int32(link.ID))) var expectedSQL = ` UPDATE test_sample.link SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com') WHERE link.id = 201; ` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201)) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201)) AssertExec(t, stmt, 1) } diff --git a/tests/postgres/with_test.go b/tests/postgres/with_test.go index 8eadf21..8a16fd4 100644 --- a/tests/postgres/with_test.go +++ b/tests/postgres/with_test.go @@ -22,20 +22,23 @@ func TestWithRegionalSales(t *testing.T) { SELECT( Orders.ShipRegion, SUM(OrderDetails.Quantity).AS(regionalSalesTotalSales.Name()), - ). - FROM(Orders.INNER_JOIN(OrderDetails, OrderDetails.OrderID.EQ(Orders.OrderID))). - GROUP_BY(Orders.ShipRegion), + ).FROM( + Orders.INNER_JOIN(OrderDetails, OrderDetails.OrderID.EQ(Orders.OrderID)), + ).GROUP_BY(Orders.ShipRegion), ), topRegion.AS( - SELECT(regionalSalesShipRegion). - FROM(regionalSales). - WHERE(regionalSalesTotalSales.GT( + SELECT( + regionalSalesShipRegion, + ).FROM( + regionalSales, + ).WHERE( + regionalSalesTotalSales.GT( IntExp( SELECT(SUM(regionalSalesTotalSales)). FROM(regionalSales), ).DIV(Int(50)), ), - ), + ), ), )( SELECT( @@ -43,13 +46,17 @@ func TestWithRegionalSales(t *testing.T) { OrderDetails.ProductID, COUNT(STAR).AS("product_units"), SUM(OrderDetails.Quantity).AS("product_sales"), - ). - FROM(Orders.INNER_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID))). - WHERE(Orders.ShipRegion.IN( - topRegion.SELECT(topRegionShipRegion)), - ). - GROUP_BY(Orders.ShipRegion, OrderDetails.ProductID). - ORDER_BY(SUM(OrderDetails.Quantity).DESC()), + ).FROM( + Orders. + INNER_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID)), + ).WHERE( + Orders.ShipRegion.IN(topRegion.SELECT(topRegionShipRegion)), + ).GROUP_BY( + Orders.ShipRegion, + OrderDetails.ProductID, + ).ORDER_BY( + SUM(OrderDetails.Quantity).DESC(), + ), ) //fmt.Println(stmt.DebugSql()) @@ -75,10 +82,10 @@ SELECT orders.ship_region AS "orders.ship_region", SUM(order_details.quantity) AS "product_sales" FROM northwind.orders INNER JOIN northwind.order_details ON (orders.order_id = order_details.order_id) -WHERE orders.ship_region IN (( +WHERE orders.ship_region IN ( SELECT top_region."orders.ship_region" AS "orders.ship_region" FROM top_region - )) + ) GROUP BY orders.ship_region, order_details.product_id ORDER BY SUM(order_details.quantity) DESC; `) @@ -141,19 +148,19 @@ func TestWithStatementDeleteAndInsert(t *testing.T) { testutils.AssertStatementSql(t, stmt, ` WITH remove_discontinued_orders AS ( DELETE FROM northwind.order_details - WHERE order_details.product_id IN (( + WHERE order_details.product_id IN ( SELECT products.product_id AS "products.product_id" FROM northwind.products WHERE products.discontinued = $1 - )) + ) RETURNING order_details.product_id AS "order_details.product_id" ),update_discontinued_price AS ( UPDATE northwind.products SET unit_price = $2 - WHERE products.product_id IN (( + WHERE products.product_id IN ( SELECT remove_discontinued_orders."order_details.product_id" AS "order_details.product_id" FROM remove_discontinued_orders - )) + ) RETURNING products.product_id AS "products.product_id", products.product_name AS "products.product_name", products.supplier_id AS "products.supplier_id",