From fb8607da29c7fb4cfe6005921ffe0e309323fb9b Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 24 May 2020 17:55:28 +0200 Subject: [PATCH] Add support for WITH statements and Common Table Expressions. --- internal/jet/clause.go | 20 +-- internal/jet/column.go | 2 +- internal/jet/func_expression.go | 5 + internal/jet/select_table.go | 32 ++--- internal/jet/serializer.go | 1 + internal/jet/sql_builder.go | 7 + internal/jet/sql_builder_test.go | 10 ++ internal/jet/statement.go | 9 +- internal/jet/with_statement.go | 78 +++++++++++ tests/mysql/with_test.go | 61 +++++++++ tests/postgres/select_test.go | 4 +- tests/postgres/with_test.go | 214 +++++++++++++++++++++++++++++++ tests/testdata | 2 +- 13 files changed, 406 insertions(+), 39 deletions(-) create mode 100644 internal/jet/with_statement.go create mode 100644 tests/mysql/with_test.go create mode 100644 tests/postgres/with_test.go diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 6091986..7b7e27b 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -13,17 +13,18 @@ type Clause interface { type ClauseWithProjections interface { Clause - projections() ProjectionList + Projections() ProjectionList } // ClauseSelect struct type ClauseSelect struct { - Distinct bool - Projections []Projection + Distinct bool + ProjectionList []Projection } -func (s *ClauseSelect) projections() ProjectionList { - return s.Projections +// Projections returns list of projections for select clause +func (s *ClauseSelect) Projections() ProjectionList { + return s.ProjectionList } // Serialize serializes clause into SQLBuilder @@ -35,11 +36,11 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o out.WriteString("DISTINCT") } - if len(s.Projections) == 0 { + if len(s.ProjectionList) == 0 { panic("jet: SELECT clause has to have at least one projection") } - out.WriteProjections(statementType, s.Projections) + out.WriteProjections(statementType, s.ProjectionList) } // ClauseFrom struct @@ -212,13 +213,14 @@ func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder, opti type ClauseSetStmtOperator struct { Operator string All bool - Selects []StatementWithProjections + Selects []SerializerStatement OrderBy ClauseOrderBy Limit ClauseLimit Offset ClauseOffset } -func (s *ClauseSetStmtOperator) projections() ProjectionList { +// Projections returns set of projections for ClauseSetStmtOperator +func (s *ClauseSetStmtOperator) Projections() ProjectionList { if len(s.Selects) > 0 { return s.Selects[0].projections() } diff --git a/internal/jet/column.go b/internal/jet/column.go index 3e4c300..2b1b930 100644 --- a/internal/jet/column.go +++ b/internal/jet/column.go @@ -105,7 +105,7 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder if c.subQuery != nil { out.WriteIdentifier(c.subQuery.Alias()) out.WriteByte('.') - out.WriteIdentifier(c.defaultAlias(), true) + out.WriteIdentifier(c.defaultAlias()) } else { if c.tableName != "" && !contains(options, ShortName) { out.WriteIdentifier(c.tableName) diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index e95bece..d4eaa57 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -145,6 +145,11 @@ func MINi(integerExpression IntegerExpression) integerWindowExpression { return newIntegerWindowFunc("MIN", integerExpression) } +// SUM is aggregate function. Returns sum of all expressions +func SUM(expression Expression) Expression { + return newWindowFunc("SUM", expression) +} + // SUMf is aggregate function. Returns sum of expression across all float expressions func SUMf(floatExpression FloatExpression) floatWindowExpression { return NewFloatWindowFunc("SUM", floatExpression) diff --git a/internal/jet/select_table.go b/internal/jet/select_table.go index 9acd8a3..52689d4 100644 --- a/internal/jet/select_table.go +++ b/internal/jet/select_table.go @@ -8,35 +8,31 @@ type SelectTable interface { } type selectTableImpl struct { - selectStmt StatementWithProjections + selectStmt SerializerStatement alias string - - projections ProjectionList } // NewSelectTable func -func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTable { - selectTable := selectTableImpl{selectStmt: selectStmt, alias: alias} - - projectionList := selectStmt.projections().fromImpl(&selectTable) - selectTable.projections = projectionList.(ProjectionList) - - return &selectTable +func NewSelectTable(selectStmt SerializerStatement, alias string) SelectTable { + selectTable := &selectTableImpl{selectStmt: selectStmt, alias: alias} + return selectTable } -func (s *selectTableImpl) Alias() string { +func (s selectTableImpl) Alias() string { return s.alias } -func (s *selectTableImpl) AllColumns() ProjectionList { - return s.projections -} - -func (s *selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if s == nil { - panic("jet: expression table is nil. ") +func (s selectTableImpl) AllColumns() ProjectionList { + statementWithProjections, ok := s.selectStmt.(HasProjections) + if !ok { + return ProjectionList{} } + projectionList := statementWithProjections.projections().fromImpl(s) + return projectionList.(ProjectionList) +} + +func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { s.selectStmt.serialize(statement, out) out.WriteString("AS") diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index 2f014cc..b8cf04a 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -29,6 +29,7 @@ const ( SetStatementType StatementType = "SET" LockStatementType StatementType = "LOCK" UnLockStatementType StatementType = "UNLOCK" + WithStatementType StatementType = "WITH" ) // Serializer interface diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 59b776f..546759b 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -201,6 +201,13 @@ func integerTypesToString(value interface{}) string { } func shouldQuoteIdentifier(identifier string) bool { + _, err := strconv.ParseInt(identifier, 10, 64) + + if err == nil { // if it is a number we should quote it + return true + } + + // check if contains non ascii characters for _, c := range identifier { if unicode.IsNumber(c) || c == '_' { continue diff --git a/internal/jet/sql_builder_test.go b/internal/jet/sql_builder_test.go index 2aad3aa..3356e6e 100644 --- a/internal/jet/sql_builder_test.go +++ b/internal/jet/sql_builder_test.go @@ -47,3 +47,13 @@ func TestFallTrough(t *testing.T) { require.Equal(t, FallTrough([]SerializeOption{SkipNewLine}), []SerializeOption(nil)) require.Equal(t, FallTrough([]SerializeOption{ShortName, SkipNewLine}), []SerializeOption{ShortName}) } + +func TestShouldQuote(t *testing.T) { + require.Equal(t, shouldQuoteIdentifier("123"), true) + require.Equal(t, shouldQuoteIdentifier("123.235"), true) + require.Equal(t, shouldQuoteIdentifier("abc123"), false) + require.Equal(t, shouldQuoteIdentifier("abc.123"), true) + require.Equal(t, shouldQuoteIdentifier("abc_123"), false) + require.Equal(t, shouldQuoteIdentifier("Abc_123"), true) + require.Equal(t, shouldQuoteIdentifier("DŽƜĐǶ"), true) +} diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 37b2077..23ae76c 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -32,13 +32,7 @@ type Statement interface { type SerializerStatement interface { Serializer Statement -} - -// StatementWithProjections interface -type StatementWithProjections interface { - Statement HasProjections - Serializer } // HasProjections interface @@ -163,7 +157,7 @@ type statementImpl struct { func (s *statementImpl) projections() ProjectionList { for _, clause := range s.Clauses { if selectClause, ok := clause.(ClauseWithProjections); ok { - return selectClause.projections() + return selectClause.Projections() } } @@ -171,7 +165,6 @@ func (s *statementImpl) projections() ProjectionList { } func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if !contains(options, NoWrap) { out.WriteString("(") out.IncreaseIdent() diff --git a/internal/jet/with_statement.go b/internal/jet/with_statement.go new file mode 100644 index 0000000..6131b35 --- /dev/null +++ b/internal/jet/with_statement.go @@ -0,0 +1,78 @@ +package jet + +// WITH function creates new with statement from list of common table expressions for specified dialect +func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statement SerializerStatement) Statement { + newWithImpl := &withImpl{ + ctes: cte, + serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ + dialect: dialect, + statementType: WithStatementType, + }, + } + newWithImpl.parent = newWithImpl + + return func(primaryStatement SerializerStatement) Statement { + newWithImpl.primaryStatement = primaryStatement + return newWithImpl + } +} + +type withImpl struct { + serializerStatementInterfaceImpl + ctes []CommonTableExpressionDefinition + primaryStatement SerializerStatement +} + +func (w withImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.NewLine() + out.WriteString("WITH") + + for i, cte := range w.ctes { + if i > 0 { + out.WriteString(",") + } + + cte.serialize(statement, out, FallTrough(options)...) + } + w.primaryStatement.serialize(statement, out, NoWrap.WithFallTrough(options)...) +} + +func (w withImpl) projections() ProjectionList { + return ProjectionList{} +} + +// CommonTableExpression contains information about a CTE. +type CommonTableExpression struct { + selectTableImpl +} + +// CTE creates new named CommonTableExpression +func CTE(name string) CommonTableExpression { + return CommonTableExpression{ + selectTableImpl: selectTableImpl{ + selectStmt: nil, + alias: name, + }, + } +} + +func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteIdentifier(c.alias) +} + +// AS returns sets definition for a CTE +func (c *CommonTableExpression) AS(statement SerializerStatement) CommonTableExpressionDefinition { + c.selectStmt = statement + return CommonTableExpressionDefinition{cte: c} +} + +// CommonTableExpressionDefinition contains implementation details of CTE +type CommonTableExpressionDefinition struct { + cte *CommonTableExpression +} + +func (c CommonTableExpressionDefinition) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteIdentifier(c.cte.alias) + out.WriteString("AS") + c.cte.selectStmt.serialize(statement, out, FallTrough(options)...) +} diff --git a/tests/mysql/with_test.go b/tests/mysql/with_test.go new file mode 100644 index 0000000..7e3a8dd --- /dev/null +++ b/tests/mysql/with_test.go @@ -0,0 +1,61 @@ +package mysql + +import ( + "github.com/go-jet/jet/internal/testutils" + . "github.com/go-jet/jet/mysql" + . "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" + "github.com/stretchr/testify/require" + "strings" + "testing" +) + +func TestWITH_SELECT(t *testing.T) { + salesRep := CTE("sales_rep") + salesRepStaffID := Staff.StaffID.From(salesRep) + salesRepFullName := StringColumn("sales_rep_full_name").From(salesRep) + customerSalesRep := CTE("customer_sales_rep") + + stmt := WITH( + salesRep.AS( + SELECT( + Staff.StaffID, + Staff.FirstName.CONCAT(Staff.LastName).AS(salesRepFullName.Name()), + ).FROM(Staff), + ), + customerSalesRep.AS( + SELECT( + Customer.FirstName.CONCAT(Customer.LastName).AS("customer_name"), + salesRepFullName, + ).FROM( + salesRep. + INNER_JOIN(Store, Store.ManagerStaffID.EQ(salesRepStaffID)). + INNER_JOIN(Customer, Customer.StoreID.EQ(Store.StoreID)), + ), + ), + )( + SELECT(customerSalesRep.AllColumns()). + FROM(customerSalesRep), + ) + + //fmt.Println(stmt.DebugSql()) + + testutils.AssertStatementSql(t, stmt, strings.Replace(` +WITH sales_rep AS ( + SELECT staff.staff_id AS "staff.staff_id", + (CONCAT(staff.first_name, staff.last_name)) AS "sales_rep_full_name" + FROM dvds.staff +),customer_sales_rep AS ( + SELECT (CONCAT(customer.first_name, customer.last_name)) AS "customer_name", + sales_rep.sales_rep_full_name AS "sales_rep_full_name" + FROM sales_rep + INNER JOIN dvds.store ON (store.manager_staff_id = sales_rep.''staff.staff_id'') + INNER JOIN dvds.customer ON (customer.store_id = store.store_id) +) +SELECT customer_sales_rep.customer_name AS "customer_name", + customer_sales_rep.sales_rep_full_name AS "sales_rep_full_name" +FROM customer_sales_rep; +`, "''", "`", -1)) + + _, err := stmt.Exec(db) + require.NoError(t, err) +} diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 35f9803..b8e759d 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1088,7 +1088,7 @@ SELECT customer.customer_id AS "customer.customer_id", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active", - customer_payment_sum."amount_sum" AS "CustomerWithAmounts.AmountSum" + customer_payment_sum.amount_sum AS "CustomerWithAmounts.AmountSum" FROM dvds.customer INNER JOIN ( SELECT payment.customer_id AS "payment.customer_id", @@ -1096,7 +1096,7 @@ FROM dvds.customer FROM dvds.payment GROUP BY payment.customer_id ) AS customer_payment_sum ON (customer.customer_id = customer_payment_sum."payment.customer_id") -ORDER BY customer_payment_sum."amount_sum" ASC; +ORDER BY customer_payment_sum.amount_sum ASC; ` customersPayments := Payment. diff --git a/tests/postgres/with_test.go b/tests/postgres/with_test.go new file mode 100644 index 0000000..3a21c63 --- /dev/null +++ b/tests/postgres/with_test.go @@ -0,0 +1,214 @@ +package postgres + +import ( + "github.com/go-jet/jet/internal/testutils" + . "github.com/go-jet/jet/postgres" + "github.com/go-jet/jet/tests/.gentestdata/jetdb/northwind/model" + . "github.com/go-jet/jet/tests/.gentestdata/jetdb/northwind/table" + "github.com/stretchr/testify/require" + "testing" +) + +func TestWithRegionalSales(t *testing.T) { + regionalSales := CTE("regional_sales") + topRegion := CTE("top_region") + + regionalSalesTotalSales := IntegerColumn("total_sales").From(regionalSales) + regionalSalesShipRegion := Orders.ShipRegion.From(regionalSales) + topRegionShipRegion := regionalSalesShipRegion.From(topRegion) + + stmt := WITH( + regionalSales.AS( + SELECT( + Orders.ShipRegion, + SUM(OrderDetails.Quantity).AS(regionalSalesTotalSales.Name()), + ). + FROM(Orders.INNER_JOIN(OrderDetails, OrderDetails.OrderID.EQ(Orders.OrderID))). + GROUP_BY(Orders.ShipRegion), + ), + topRegion.AS( + SELECT(regionalSalesShipRegion). + FROM(regionalSales). + WHERE(regionalSalesTotalSales.GT( + IntExp( + SELECT(SUM(regionalSalesTotalSales)). + FROM(regionalSales), + ).DIV(Int(50)), + ), + ), + ), + )( + SELECT( + Orders.ShipRegion, + OrderDetails.ProductID, + COUNT(STAR).AS("product_units"), + SUM(OrderDetails.Quantity).AS("product_sales"), + ). + FROM(Orders.INNER_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID))). + WHERE(Orders.ShipRegion.IN( + topRegion.SELECT(topRegionShipRegion)), + ). + GROUP_BY(Orders.ShipRegion, OrderDetails.ProductID). + ORDER_BY(SUM(OrderDetails.Quantity).DESC()), + ) + + //fmt.Println(stmt.DebugSql()) + + testutils.AssertDebugStatementSql(t, stmt, ` +WITH regional_sales AS ( + SELECT orders.ship_region AS "orders.ship_region", + SUM(order_details.quantity) AS "total_sales" + FROM northwind.orders + INNER JOIN northwind.order_details ON (order_details.order_id = orders.order_id) + GROUP BY orders.ship_region +),top_region AS ( + SELECT regional_sales."orders.ship_region" AS "orders.ship_region" + FROM regional_sales + WHERE regional_sales.total_sales > (( + SELECT SUM(regional_sales.total_sales) + FROM regional_sales + ) / 50) +) +SELECT orders.ship_region AS "orders.ship_region", + order_details.product_id AS "order_details.product_id", + COUNT(*) AS "product_units", + SUM(order_details.quantity) AS "product_sales" +FROM northwind.orders + INNER JOIN northwind.order_details ON (orders.order_id = order_details.order_id) +WHERE orders.ship_region IN (( + SELECT top_region."orders.ship_region" AS "orders.ship_region" + FROM top_region + )) +GROUP BY orders.ship_region, order_details.product_id +ORDER BY SUM(order_details.quantity) DESC; +`) + + _, err := stmt.Exec(db) + require.NoError(t, err) +} + +func TestWithStatementDeleteAndInsert(t *testing.T) { + removeDiscontinuedOrders := CTE("remove_discontinued_orders") + updateDiscontinuedPrice := CTE("update_discontinued_price") + logDiscontinuedProducts := CTE("log_discontinued") + + discontinuedProductID := OrderDetails.ProductID.From(removeDiscontinuedOrders) + + stmt := WITH( + removeDiscontinuedOrders.AS( + OrderDetails.DELETE(). + WHERE(OrderDetails.ProductID.IN( + SELECT(Products.ProductID). + FROM(Products). + WHERE(Products.Discontinued.EQ(Int(1)))), + ).RETURNING(OrderDetails.ProductID), + ), + updateDiscontinuedPrice.AS( + Products.UPDATE(). + SET( + Products.UnitPrice.SET(Float(0.0)), + ). + WHERE(Products.ProductID.IN(removeDiscontinuedOrders.SELECT(discontinuedProductID))). + RETURNING(Products.AllColumns), + ), + logDiscontinuedProducts.AS( + ProductLogs.INSERT(ProductLogs.AllColumns). + QUERY(SELECT(updateDiscontinuedPrice.AllColumns()).FROM(updateDiscontinuedPrice)). + RETURNING( + ProductLogs.ProductID, + ProductLogs.ProductName, + ProductLogs.SupplierID, + ProductLogs.CategoryID, + ProductLogs.QuantityPerUnit, + ProductLogs.UnitPrice, + ProductLogs.UnitsInStock, + ProductLogs.UnitsOnOrder, + ProductLogs.ReorderLevel, + ProductLogs.Discontinued, + ), + ), + )( + SELECT(logDiscontinuedProducts.AllColumns()). + FROM(logDiscontinuedProducts), + ) + + require.Equal(t, len(removeDiscontinuedOrders.AllColumns()), 1) + require.Equal(t, len(updateDiscontinuedPrice.AllColumns()[0].(ProjectionList)), 10) + require.Equal(t, len(logDiscontinuedProducts.AllColumns()), 10) + + //fmt.Println(stmt.Sql()) + + testutils.AssertStatementSql(t, stmt, ` +WITH remove_discontinued_orders AS ( + DELETE FROM northwind.order_details + WHERE order_details.product_id IN (( + SELECT products.product_id AS "products.product_id" + FROM northwind.products + WHERE products.discontinued = $1 + )) + RETURNING order_details.product_id AS "order_details.product_id" +),update_discontinued_price AS ( + UPDATE northwind.products + SET unit_price = $2 + WHERE products.product_id IN (( + SELECT remove_discontinued_orders."order_details.product_id" AS "order_details.product_id" + FROM remove_discontinued_orders + )) + RETURNING products.product_id AS "products.product_id", + products.product_name AS "products.product_name", + products.supplier_id AS "products.supplier_id", + products.category_id AS "products.category_id", + products.quantity_per_unit AS "products.quantity_per_unit", + products.unit_price AS "products.unit_price", + products.units_in_stock AS "products.units_in_stock", + products.units_on_order AS "products.units_on_order", + products.reorder_level AS "products.reorder_level", + products.discontinued AS "products.discontinued" +),log_discontinued AS ( + INSERT INTO northwind.product_logs (product_id, product_name, supplier_id, category_id, quantity_per_unit, unit_price, units_in_stock, units_on_order, reorder_level, discontinued) ( + SELECT update_discontinued_price."products.product_id" AS "products.product_id", + update_discontinued_price."products.product_name" AS "products.product_name", + update_discontinued_price."products.supplier_id" AS "products.supplier_id", + update_discontinued_price."products.category_id" AS "products.category_id", + update_discontinued_price."products.quantity_per_unit" AS "products.quantity_per_unit", + update_discontinued_price."products.unit_price" AS "products.unit_price", + update_discontinued_price."products.units_in_stock" AS "products.units_in_stock", + update_discontinued_price."products.units_on_order" AS "products.units_on_order", + update_discontinued_price."products.reorder_level" AS "products.reorder_level", + update_discontinued_price."products.discontinued" AS "products.discontinued" + FROM update_discontinued_price + ) + RETURNING product_logs.product_id AS "product_logs.product_id", + product_logs.product_name AS "product_logs.product_name", + product_logs.supplier_id AS "product_logs.supplier_id", + product_logs.category_id AS "product_logs.category_id", + product_logs.quantity_per_unit AS "product_logs.quantity_per_unit", + product_logs.unit_price AS "product_logs.unit_price", + product_logs.units_in_stock AS "product_logs.units_in_stock", + product_logs.units_on_order AS "product_logs.units_on_order", + product_logs.reorder_level AS "product_logs.reorder_level", + product_logs.discontinued AS "product_logs.discontinued" +) +SELECT log_discontinued."product_logs.product_id" AS "product_logs.product_id", + log_discontinued."product_logs.product_name" AS "product_logs.product_name", + log_discontinued."product_logs.supplier_id" AS "product_logs.supplier_id", + log_discontinued."product_logs.category_id" AS "product_logs.category_id", + log_discontinued."product_logs.quantity_per_unit" AS "product_logs.quantity_per_unit", + log_discontinued."product_logs.unit_price" AS "product_logs.unit_price", + log_discontinued."product_logs.units_in_stock" AS "product_logs.units_in_stock", + log_discontinued."product_logs.units_on_order" AS "product_logs.units_on_order", + log_discontinued."product_logs.reorder_level" AS "product_logs.reorder_level", + log_discontinued."product_logs.discontinued" AS "product_logs.discontinued" +FROM log_discontinued; +`, int64(1), 0.0) + + var resp []model.ProductLogs + + tx, err := db.Begin() + require.NoError(t, err) + defer tx.Rollback() + + err = stmt.Query(tx, &resp) + require.NoError(t, err) + +} diff --git a/tests/testdata b/tests/testdata index 1745be3..ed53a50 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 1745be34a649c0f37d0d31d7c0352a1248ace2dc +Subproject commit ed53a505eb738d1be457877eee251f9ba0418df1