Add support for WITH statements and Common Table Expressions.

This commit is contained in:
go-jet 2020-05-24 17:55:28 +02:00
parent 0d3ec872d6
commit fb8607da29
13 changed files with 406 additions and 39 deletions

View file

@ -13,17 +13,18 @@ type Clause interface {
type ClauseWithProjections interface {
Clause
projections() ProjectionList
Projections() ProjectionList
}
// ClauseSelect struct
type ClauseSelect struct {
Distinct bool
Projections []Projection
Distinct bool
ProjectionList []Projection
}
func (s *ClauseSelect) projections() ProjectionList {
return s.Projections
// Projections returns list of projections for select clause
func (s *ClauseSelect) Projections() ProjectionList {
return s.ProjectionList
}
// Serialize serializes clause into SQLBuilder
@ -35,11 +36,11 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o
out.WriteString("DISTINCT")
}
if len(s.Projections) == 0 {
if len(s.ProjectionList) == 0 {
panic("jet: SELECT clause has to have at least one projection")
}
out.WriteProjections(statementType, s.Projections)
out.WriteProjections(statementType, s.ProjectionList)
}
// ClauseFrom struct
@ -212,13 +213,14 @@ func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder, opti
type ClauseSetStmtOperator struct {
Operator string
All bool
Selects []StatementWithProjections
Selects []SerializerStatement
OrderBy ClauseOrderBy
Limit ClauseLimit
Offset ClauseOffset
}
func (s *ClauseSetStmtOperator) projections() ProjectionList {
// Projections returns set of projections for ClauseSetStmtOperator
func (s *ClauseSetStmtOperator) Projections() ProjectionList {
if len(s.Selects) > 0 {
return s.Selects[0].projections()
}

View file

@ -105,7 +105,7 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder
if c.subQuery != nil {
out.WriteIdentifier(c.subQuery.Alias())
out.WriteByte('.')
out.WriteIdentifier(c.defaultAlias(), true)
out.WriteIdentifier(c.defaultAlias())
} else {
if c.tableName != "" && !contains(options, ShortName) {
out.WriteIdentifier(c.tableName)

View file

@ -145,6 +145,11 @@ func MINi(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerWindowFunc("MIN", integerExpression)
}
// SUM is aggregate function. Returns sum of all expressions
func SUM(expression Expression) Expression {
return newWindowFunc("SUM", expression)
}
// SUMf is aggregate function. Returns sum of expression across all float expressions
func SUMf(floatExpression FloatExpression) floatWindowExpression {
return NewFloatWindowFunc("SUM", floatExpression)

View file

@ -8,35 +8,31 @@ type SelectTable interface {
}
type selectTableImpl struct {
selectStmt StatementWithProjections
selectStmt SerializerStatement
alias string
projections ProjectionList
}
// NewSelectTable func
func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTable {
selectTable := selectTableImpl{selectStmt: selectStmt, alias: alias}
projectionList := selectStmt.projections().fromImpl(&selectTable)
selectTable.projections = projectionList.(ProjectionList)
return &selectTable
func NewSelectTable(selectStmt SerializerStatement, alias string) SelectTable {
selectTable := &selectTableImpl{selectStmt: selectStmt, alias: alias}
return selectTable
}
func (s *selectTableImpl) Alias() string {
func (s selectTableImpl) Alias() string {
return s.alias
}
func (s *selectTableImpl) AllColumns() ProjectionList {
return s.projections
}
func (s *selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if s == nil {
panic("jet: expression table is nil. ")
func (s selectTableImpl) AllColumns() ProjectionList {
statementWithProjections, ok := s.selectStmt.(HasProjections)
if !ok {
return ProjectionList{}
}
projectionList := statementWithProjections.projections().fromImpl(s)
return projectionList.(ProjectionList)
}
func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
s.selectStmt.serialize(statement, out)
out.WriteString("AS")

View file

@ -29,6 +29,7 @@ const (
SetStatementType StatementType = "SET"
LockStatementType StatementType = "LOCK"
UnLockStatementType StatementType = "UNLOCK"
WithStatementType StatementType = "WITH"
)
// Serializer interface

View file

@ -201,6 +201,13 @@ func integerTypesToString(value interface{}) string {
}
func shouldQuoteIdentifier(identifier string) bool {
_, err := strconv.ParseInt(identifier, 10, 64)
if err == nil { // if it is a number we should quote it
return true
}
// check if contains non ascii characters
for _, c := range identifier {
if unicode.IsNumber(c) || c == '_' {
continue

View file

@ -47,3 +47,13 @@ func TestFallTrough(t *testing.T) {
require.Equal(t, FallTrough([]SerializeOption{SkipNewLine}), []SerializeOption(nil))
require.Equal(t, FallTrough([]SerializeOption{ShortName, SkipNewLine}), []SerializeOption{ShortName})
}
func TestShouldQuote(t *testing.T) {
require.Equal(t, shouldQuoteIdentifier("123"), true)
require.Equal(t, shouldQuoteIdentifier("123.235"), true)
require.Equal(t, shouldQuoteIdentifier("abc123"), false)
require.Equal(t, shouldQuoteIdentifier("abc.123"), true)
require.Equal(t, shouldQuoteIdentifier("abc_123"), false)
require.Equal(t, shouldQuoteIdentifier("Abc_123"), true)
require.Equal(t, shouldQuoteIdentifier("DŽƜĐǶ"), true)
}

View file

@ -32,13 +32,7 @@ type Statement interface {
type SerializerStatement interface {
Serializer
Statement
}
// StatementWithProjections interface
type StatementWithProjections interface {
Statement
HasProjections
Serializer
}
// HasProjections interface
@ -163,7 +157,7 @@ type statementImpl struct {
func (s *statementImpl) projections() ProjectionList {
for _, clause := range s.Clauses {
if selectClause, ok := clause.(ClauseWithProjections); ok {
return selectClause.projections()
return selectClause.Projections()
}
}
@ -171,7 +165,6 @@ func (s *statementImpl) projections() ProjectionList {
}
func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, NoWrap) {
out.WriteString("(")
out.IncreaseIdent()

View file

@ -0,0 +1,78 @@
package jet
// WITH function creates new with statement from list of common table expressions for specified dialect
func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statement SerializerStatement) Statement {
newWithImpl := &withImpl{
ctes: cte,
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
dialect: dialect,
statementType: WithStatementType,
},
}
newWithImpl.parent = newWithImpl
return func(primaryStatement SerializerStatement) Statement {
newWithImpl.primaryStatement = primaryStatement
return newWithImpl
}
}
type withImpl struct {
serializerStatementInterfaceImpl
ctes []CommonTableExpressionDefinition
primaryStatement SerializerStatement
}
func (w withImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.NewLine()
out.WriteString("WITH")
for i, cte := range w.ctes {
if i > 0 {
out.WriteString(",")
}
cte.serialize(statement, out, FallTrough(options)...)
}
w.primaryStatement.serialize(statement, out, NoWrap.WithFallTrough(options)...)
}
func (w withImpl) projections() ProjectionList {
return ProjectionList{}
}
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
selectTableImpl
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
return CommonTableExpression{
selectTableImpl: selectTableImpl{
selectStmt: nil,
alias: name,
},
}
}
func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteIdentifier(c.alias)
}
// AS returns sets definition for a CTE
func (c *CommonTableExpression) AS(statement SerializerStatement) CommonTableExpressionDefinition {
c.selectStmt = statement
return CommonTableExpressionDefinition{cte: c}
}
// CommonTableExpressionDefinition contains implementation details of CTE
type CommonTableExpressionDefinition struct {
cte *CommonTableExpression
}
func (c CommonTableExpressionDefinition) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteIdentifier(c.cte.alias)
out.WriteString("AS")
c.cte.selectStmt.serialize(statement, out, FallTrough(options)...)
}

61
tests/mysql/with_test.go Normal file
View file

@ -0,0 +1,61 @@
package mysql
import (
"github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/mysql"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table"
"github.com/stretchr/testify/require"
"strings"
"testing"
)
func TestWITH_SELECT(t *testing.T) {
salesRep := CTE("sales_rep")
salesRepStaffID := Staff.StaffID.From(salesRep)
salesRepFullName := StringColumn("sales_rep_full_name").From(salesRep)
customerSalesRep := CTE("customer_sales_rep")
stmt := WITH(
salesRep.AS(
SELECT(
Staff.StaffID,
Staff.FirstName.CONCAT(Staff.LastName).AS(salesRepFullName.Name()),
).FROM(Staff),
),
customerSalesRep.AS(
SELECT(
Customer.FirstName.CONCAT(Customer.LastName).AS("customer_name"),
salesRepFullName,
).FROM(
salesRep.
INNER_JOIN(Store, Store.ManagerStaffID.EQ(salesRepStaffID)).
INNER_JOIN(Customer, Customer.StoreID.EQ(Store.StoreID)),
),
),
)(
SELECT(customerSalesRep.AllColumns()).
FROM(customerSalesRep),
)
//fmt.Println(stmt.DebugSql())
testutils.AssertStatementSql(t, stmt, strings.Replace(`
WITH sales_rep AS (
SELECT staff.staff_id AS "staff.staff_id",
(CONCAT(staff.first_name, staff.last_name)) AS "sales_rep_full_name"
FROM dvds.staff
),customer_sales_rep AS (
SELECT (CONCAT(customer.first_name, customer.last_name)) AS "customer_name",
sales_rep.sales_rep_full_name AS "sales_rep_full_name"
FROM sales_rep
INNER JOIN dvds.store ON (store.manager_staff_id = sales_rep.''staff.staff_id'')
INNER JOIN dvds.customer ON (customer.store_id = store.store_id)
)
SELECT customer_sales_rep.customer_name AS "customer_name",
customer_sales_rep.sales_rep_full_name AS "sales_rep_full_name"
FROM customer_sales_rep;
`, "''", "`", -1))
_, err := stmt.Exec(db)
require.NoError(t, err)
}

View file

@ -1088,7 +1088,7 @@ SELECT customer.customer_id AS "customer.customer_id",
customer.create_date AS "customer.create_date",
customer.last_update AS "customer.last_update",
customer.active AS "customer.active",
customer_payment_sum."amount_sum" AS "CustomerWithAmounts.AmountSum"
customer_payment_sum.amount_sum AS "CustomerWithAmounts.AmountSum"
FROM dvds.customer
INNER JOIN (
SELECT payment.customer_id AS "payment.customer_id",
@ -1096,7 +1096,7 @@ FROM dvds.customer
FROM dvds.payment
GROUP BY payment.customer_id
) AS customer_payment_sum ON (customer.customer_id = customer_payment_sum."payment.customer_id")
ORDER BY customer_payment_sum."amount_sum" ASC;
ORDER BY customer_payment_sum.amount_sum ASC;
`
customersPayments := Payment.

214
tests/postgres/with_test.go Normal file
View file

@ -0,0 +1,214 @@
package postgres
import (
"github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/northwind/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/northwind/table"
"github.com/stretchr/testify/require"
"testing"
)
func TestWithRegionalSales(t *testing.T) {
regionalSales := CTE("regional_sales")
topRegion := CTE("top_region")
regionalSalesTotalSales := IntegerColumn("total_sales").From(regionalSales)
regionalSalesShipRegion := Orders.ShipRegion.From(regionalSales)
topRegionShipRegion := regionalSalesShipRegion.From(topRegion)
stmt := WITH(
regionalSales.AS(
SELECT(
Orders.ShipRegion,
SUM(OrderDetails.Quantity).AS(regionalSalesTotalSales.Name()),
).
FROM(Orders.INNER_JOIN(OrderDetails, OrderDetails.OrderID.EQ(Orders.OrderID))).
GROUP_BY(Orders.ShipRegion),
),
topRegion.AS(
SELECT(regionalSalesShipRegion).
FROM(regionalSales).
WHERE(regionalSalesTotalSales.GT(
IntExp(
SELECT(SUM(regionalSalesTotalSales)).
FROM(regionalSales),
).DIV(Int(50)),
),
),
),
)(
SELECT(
Orders.ShipRegion,
OrderDetails.ProductID,
COUNT(STAR).AS("product_units"),
SUM(OrderDetails.Quantity).AS("product_sales"),
).
FROM(Orders.INNER_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID))).
WHERE(Orders.ShipRegion.IN(
topRegion.SELECT(topRegionShipRegion)),
).
GROUP_BY(Orders.ShipRegion, OrderDetails.ProductID).
ORDER_BY(SUM(OrderDetails.Quantity).DESC()),
)
//fmt.Println(stmt.DebugSql())
testutils.AssertDebugStatementSql(t, stmt, `
WITH regional_sales AS (
SELECT orders.ship_region AS "orders.ship_region",
SUM(order_details.quantity) AS "total_sales"
FROM northwind.orders
INNER JOIN northwind.order_details ON (order_details.order_id = orders.order_id)
GROUP BY orders.ship_region
),top_region AS (
SELECT regional_sales."orders.ship_region" AS "orders.ship_region"
FROM regional_sales
WHERE regional_sales.total_sales > ((
SELECT SUM(regional_sales.total_sales)
FROM regional_sales
) / 50)
)
SELECT orders.ship_region AS "orders.ship_region",
order_details.product_id AS "order_details.product_id",
COUNT(*) AS "product_units",
SUM(order_details.quantity) AS "product_sales"
FROM northwind.orders
INNER JOIN northwind.order_details ON (orders.order_id = order_details.order_id)
WHERE orders.ship_region IN ((
SELECT top_region."orders.ship_region" AS "orders.ship_region"
FROM top_region
))
GROUP BY orders.ship_region, order_details.product_id
ORDER BY SUM(order_details.quantity) DESC;
`)
_, err := stmt.Exec(db)
require.NoError(t, err)
}
func TestWithStatementDeleteAndInsert(t *testing.T) {
removeDiscontinuedOrders := CTE("remove_discontinued_orders")
updateDiscontinuedPrice := CTE("update_discontinued_price")
logDiscontinuedProducts := CTE("log_discontinued")
discontinuedProductID := OrderDetails.ProductID.From(removeDiscontinuedOrders)
stmt := WITH(
removeDiscontinuedOrders.AS(
OrderDetails.DELETE().
WHERE(OrderDetails.ProductID.IN(
SELECT(Products.ProductID).
FROM(Products).
WHERE(Products.Discontinued.EQ(Int(1)))),
).RETURNING(OrderDetails.ProductID),
),
updateDiscontinuedPrice.AS(
Products.UPDATE().
SET(
Products.UnitPrice.SET(Float(0.0)),
).
WHERE(Products.ProductID.IN(removeDiscontinuedOrders.SELECT(discontinuedProductID))).
RETURNING(Products.AllColumns),
),
logDiscontinuedProducts.AS(
ProductLogs.INSERT(ProductLogs.AllColumns).
QUERY(SELECT(updateDiscontinuedPrice.AllColumns()).FROM(updateDiscontinuedPrice)).
RETURNING(
ProductLogs.ProductID,
ProductLogs.ProductName,
ProductLogs.SupplierID,
ProductLogs.CategoryID,
ProductLogs.QuantityPerUnit,
ProductLogs.UnitPrice,
ProductLogs.UnitsInStock,
ProductLogs.UnitsOnOrder,
ProductLogs.ReorderLevel,
ProductLogs.Discontinued,
),
),
)(
SELECT(logDiscontinuedProducts.AllColumns()).
FROM(logDiscontinuedProducts),
)
require.Equal(t, len(removeDiscontinuedOrders.AllColumns()), 1)
require.Equal(t, len(updateDiscontinuedPrice.AllColumns()[0].(ProjectionList)), 10)
require.Equal(t, len(logDiscontinuedProducts.AllColumns()), 10)
//fmt.Println(stmt.Sql())
testutils.AssertStatementSql(t, stmt, `
WITH remove_discontinued_orders AS (
DELETE FROM northwind.order_details
WHERE order_details.product_id IN ((
SELECT products.product_id AS "products.product_id"
FROM northwind.products
WHERE products.discontinued = $1
))
RETURNING order_details.product_id AS "order_details.product_id"
),update_discontinued_price AS (
UPDATE northwind.products
SET unit_price = $2
WHERE products.product_id IN ((
SELECT remove_discontinued_orders."order_details.product_id" AS "order_details.product_id"
FROM remove_discontinued_orders
))
RETURNING products.product_id AS "products.product_id",
products.product_name AS "products.product_name",
products.supplier_id AS "products.supplier_id",
products.category_id AS "products.category_id",
products.quantity_per_unit AS "products.quantity_per_unit",
products.unit_price AS "products.unit_price",
products.units_in_stock AS "products.units_in_stock",
products.units_on_order AS "products.units_on_order",
products.reorder_level AS "products.reorder_level",
products.discontinued AS "products.discontinued"
),log_discontinued AS (
INSERT INTO northwind.product_logs (product_id, product_name, supplier_id, category_id, quantity_per_unit, unit_price, units_in_stock, units_on_order, reorder_level, discontinued) (
SELECT update_discontinued_price."products.product_id" AS "products.product_id",
update_discontinued_price."products.product_name" AS "products.product_name",
update_discontinued_price."products.supplier_id" AS "products.supplier_id",
update_discontinued_price."products.category_id" AS "products.category_id",
update_discontinued_price."products.quantity_per_unit" AS "products.quantity_per_unit",
update_discontinued_price."products.unit_price" AS "products.unit_price",
update_discontinued_price."products.units_in_stock" AS "products.units_in_stock",
update_discontinued_price."products.units_on_order" AS "products.units_on_order",
update_discontinued_price."products.reorder_level" AS "products.reorder_level",
update_discontinued_price."products.discontinued" AS "products.discontinued"
FROM update_discontinued_price
)
RETURNING product_logs.product_id AS "product_logs.product_id",
product_logs.product_name AS "product_logs.product_name",
product_logs.supplier_id AS "product_logs.supplier_id",
product_logs.category_id AS "product_logs.category_id",
product_logs.quantity_per_unit AS "product_logs.quantity_per_unit",
product_logs.unit_price AS "product_logs.unit_price",
product_logs.units_in_stock AS "product_logs.units_in_stock",
product_logs.units_on_order AS "product_logs.units_on_order",
product_logs.reorder_level AS "product_logs.reorder_level",
product_logs.discontinued AS "product_logs.discontinued"
)
SELECT log_discontinued."product_logs.product_id" AS "product_logs.product_id",
log_discontinued."product_logs.product_name" AS "product_logs.product_name",
log_discontinued."product_logs.supplier_id" AS "product_logs.supplier_id",
log_discontinued."product_logs.category_id" AS "product_logs.category_id",
log_discontinued."product_logs.quantity_per_unit" AS "product_logs.quantity_per_unit",
log_discontinued."product_logs.unit_price" AS "product_logs.unit_price",
log_discontinued."product_logs.units_in_stock" AS "product_logs.units_in_stock",
log_discontinued."product_logs.units_on_order" AS "product_logs.units_on_order",
log_discontinued."product_logs.reorder_level" AS "product_logs.reorder_level",
log_discontinued."product_logs.discontinued" AS "product_logs.discontinued"
FROM log_discontinued;
`, int64(1), 0.0)
var resp []model.ProductLogs
tx, err := db.Begin()
require.NoError(t, err)
defer tx.Rollback()
err = stmt.Query(tx, &resp)
require.NoError(t, err)
}

@ -1 +1 @@
Subproject commit 1745be34a649c0f37d0d31d7c0352a1248ace2dc
Subproject commit ed53a505eb738d1be457877eee251f9ba0418df1