From fa69565dbfceffa7d31f347d4bb8757017442954 Mon Sep 17 00:00:00 2001 From: go-jet Date: Tue, 28 Mar 2023 13:16:57 +0200 Subject: [PATCH] Add support for postgres GROUPING SET, ROLLUP and CUBE grouping operators Add support for mysql WITH ROLLUP grouping operator Add support for GROUPING operator --- generator/mysql/mysql_generator.go | 2 +- generator/postgres/postgres_generator.go | 2 +- internal/jet/func_expression.go | 79 ++++--- internal/jet/group_by_clause.go | 35 +++ internal/jet/literal_expression.go | 2 +- mysql/functions.go | 14 +- postgres/functions.go | 33 ++- tests/mysql/select_test.go | 95 +++++++- tests/postgres/select_test.go | 284 +++++++++++++++++++++-- 9 files changed, 476 insertions(+), 70 deletions(-) diff --git a/generator/mysql/mysql_generator.go b/generator/mysql/mysql_generator.go index b9280d2..cfcbae7 100644 --- a/generator/mysql/mysql_generator.go +++ b/generator/mysql/mysql_generator.go @@ -68,7 +68,7 @@ func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error } func openConnection(connectionString string) *sql.DB { - fmt.Println("Connecting to MySQL database") + fmt.Println("Connecting to MySQL database...") db, err := sql.Open("mysql", connectionString) throw.OnError(err) diff --git a/generator/postgres/postgres_generator.go b/generator/postgres/postgres_generator.go index 3298376..54583ea 100644 --- a/generator/postgres/postgres_generator.go +++ b/generator/postgres/postgres_generator.go @@ -69,7 +69,7 @@ func GenerateDSN(dsn, schema, destDir string, templates ...template.Template) (e } func openConnection(dsn string) *sql.DB { - fmt.Printf("Connecting to postgres database") + fmt.Println("Connecting to postgres database...") db, err := sql.Open("postgres", dsn) throw.OnError(err) diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 6900457..4ff3791 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -12,7 +12,7 @@ func OR(expressions ...BoolExpression) BoolExpression { return newBoolExpressionListOperator("OR", expressions...) } -// ROW is construct one table row from list of expressions. +// ROW function is used to create a tuple value that consists of a set of expressions or column values. func ROW(expressions ...Expression) Expression { return NewFunc("ROW", expressions, nil) } @@ -602,16 +602,16 @@ func LEAST(value Expression, values ...Expression) Expression { type funcExpressionImpl struct { ExpressionInterfaceImpl - name string - expressions []Expression - noBrackets bool + name string + parameters parametersSerializer + noBrackets bool } // NewFunc creates new function with name and expressions parameters func NewFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl { funcExp := &funcExpressionImpl{ - name: name, - expressions: parameters(expressions), + name: name, + parameters: parametersSerializer(expressions), } if parent != nil { @@ -623,18 +623,43 @@ func NewFunc(name string, expressions []Expression, parent Expression) *funcExpr return funcExp } -func parameters(expressions []Expression) []Expression { - var ret []Expression - - for _, expression := range expressions { - if _, isStatement := expression.(Statement); isStatement { - ret = append(ret, expression) - } else { - ret = append(ret, skipWrap(expression)) - } +func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil { + serializeOverrideFunc := serializeOverride(ExpressionListToSerializerList(f.parameters)...) + serializeOverrideFunc(statement, out, FallTrough(options)...) + return } - return ret + addBrackets := !f.noBrackets || len(f.parameters) > 0 + + if addBrackets { + out.WriteString(f.name + "(") + } else { + out.WriteString(f.name) + } + + f.parameters.serialize(statement, out, options...) + + if addBrackets { + out.WriteString(")") + } +} + +type parametersSerializer []Expression + +func (p parametersSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + + for i, expression := range p { + if i > 0 { + out.WriteString(", ") + } + + if _, isStatement := expression.(Statement); isStatement { + expression.serialize(statement, out, options...) + } else { + skipWrap(expression).serialize(statement, out, options...) + } + } } // NewFloatWindowFunc creates new float function with name and expressions @@ -646,28 +671,6 @@ func newWindowFunc(name string, expressions ...Expression) windowExpression { return windowExpr } -func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil { - serializeOverrideFunc := serializeOverride(ExpressionListToSerializerList(f.expressions)...) - serializeOverrideFunc(statement, out, FallTrough(options)...) - return - } - - addBrackets := !f.noBrackets || len(f.expressions) > 0 - - if addBrackets { - out.WriteString(f.name + "(") - } else { - out.WriteString(f.name) - } - - serializeExpressionList(statement, f.expressions, ", ", out) - - if addBrackets { - out.WriteString(")") - } -} - type boolFunc struct { funcExpressionImpl boolInterfaceImpl diff --git a/internal/jet/group_by_clause.go b/internal/jet/group_by_clause.go index 65687c6..5f0c4b1 100644 --- a/internal/jet/group_by_clause.go +++ b/internal/jet/group_by_clause.go @@ -4,3 +4,38 @@ package jet type GroupByClause interface { serializeForGroupBy(statement StatementType, out *SQLBuilder) } + +// GROUPING_SETS operator allows grouping of the rows in a table by multiple sets of columns in a single query. +// This can be useful when we want to analyze data by different combinations of columns, without having to write separate +// queries for each combination. +func GROUPING_SETS(expressions ...Expression) GroupByClause { + return Func("GROUPING SETS", expressions...) +} + +// ROLLUP operator is used with the GROUP BY clause to generate all prefixes of a group of columns including the empty list. +// It creates extra rows in the result set that represent the subtotal values for each combination of columns. +func ROLLUP(expressions ...Expression) GroupByClause { + return Func("ROLLUP", expressions...) +} + +// CUBE operator is used with the GROUP BY clause to generate subtotals for all possible combinations of a group of columns. +// It creates extra rows in the result set that represent the subtotal values for each combination of columns. +func CUBE(expressions ...Expression) GroupByClause { + return Func("CUBE", expressions...) +} + +// GROUPING function is used to identify which columns are included in a grouping set or a subtotal row. It takes as input +// the name of a column and returns 1 if the column is not included in the current grouping set, and 0 otherwise. +// It can be also used with multiple parameters to check if a set of columns is included in the current grouping set. The result +// of the GROUPING function would then be an integer bit mask having 1’s for the arguments which have GROUPING(argument) as 1. +func GROUPING(expressions ...Expression) IntegerExpression { + return IntExp(Func("GROUPING", expressions...)) +} + +// WITH_ROLLUP operator is used with the GROUP BY clause to generate all prefixes of a group of columns including the empty list. +// It creates extra rows in the result set that represent the subtotal values for each combination of columns. +func WITH_ROLLUP(expressions ...Expression) GroupByClause { + return newCustomExpression( + parametersSerializer(expressions), Token("WITH ROLLUP"), + ) +} diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index c71cc15..3d89dc7 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -386,7 +386,7 @@ func (n *wrap) serialize(statementType StatementType, out *SQLBuilder, options . out.WriteString(")") } -// WRAP wraps list of expressions with brackets '(' and ')' +// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... ) func WRAP(expression ...Expression) Expression { wrap := &wrap{expressions: expression} wrap.ExpressionInterfaceImpl.Parent = wrap diff --git a/mysql/functions.go b/mysql/functions.go index e2e6776..4eba942 100644 --- a/mysql/functions.go +++ b/mysql/functions.go @@ -11,7 +11,7 @@ var ( OR = jet.OR ) -// ROW is construct one table row from list of expressions. +// ROW function is used to create a tuple value that consists of a set of expressions or column values. var ROW = jet.ROW // ------------------ Mathematical functions ---------------// @@ -281,3 +281,15 @@ var GREATEST = jet.GREATEST // LEAST selects the smallest value from a list of expressions, or null if any of the expressions is null. var LEAST = jet.LEAST + +// ----------------------- Group By operators ----------------------------// + +// WITH_ROLLUP operator is used with the GROUP BY clause to generate all prefixes of a group of columns including the empty list. +// It creates extra rows in the result set that represent the subtotal values for each combination of columns. +var WITH_ROLLUP = jet.WITH_ROLLUP + +// GROUPING function is used to identify which columns are included in a grouping set or a subtotal row. It takes as input +// the name of a column and returns 1 if the column is not included in the current grouping set, and 0 otherwise. +// It can be also used with multiple parameters to check if a set of columns is included in the current grouping set. The result +// of the GROUPING function would then be an integer bit mask having 1’s for the arguments which have GROUPING(argument) as 1. +var GROUPING = jet.GROUPING diff --git a/postgres/functions.go b/postgres/functions.go index 5b07b45..43a5f39 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -13,7 +13,7 @@ var ( OR = jet.OR ) -// ROW is construct one table row from list of expressions. +// ROW function is used to create a tuple value that consists of a set of expressions or column values. var ROW = jet.ROW // ------------------ Mathematical functions ---------------// @@ -390,3 +390,34 @@ func castFloatLiteral(fraction FloatExpression) FloatExpression { } return fraction } + +// ----------------- Group By operators --------------------------// + +// GROUPING_SETS operator allows grouping of the rows in a table by multiple sets of columns(or expressions) in a single query. +// This can be useful when we want to analyze data by different combinations of columns, without having to write separate +// queries for each combination. GROUPING_SETS sets of columns are constructed with WRAP method. +// +// GROUPING_SETS( +// WRAP(Inventory.FilmID, Inventory.StoreID), +// WRAP(), +// ), +var GROUPING_SETS = jet.GROUPING_SETS + +// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... ) +// The construct (a, b) is normally recognized in expressions as a row constructor. WRAP and ROW method behave exactly the same, +// except when used in GROUPING_SETS. For top level GROUPING SETS expression lists WRAP has to be used. +var WRAP = jet.WRAP + +// ROLLUP operator is used with the GROUP BY clause to generate all prefixes of a group of columns including the empty list. +// It creates extra rows in the result set that represent the subtotal values for each combination of columns. +var ROLLUP = jet.ROLLUP + +// CUBE operator is used with the GROUP BY clause to generate subtotals for all possible combinations of a group of columns. +// It creates extra rows in the result set that represent the subtotal values for each combination of columns. +var CUBE = jet.CUBE + +// GROUPING function is used to identify which columns are included in a grouping set or a subtotal row. It takes as input +// the name of a column and returns 1 if the column is not included in the current grouping set, and 0 otherwise. +// It can be also used with multiple parameters to check if a set of columns is included in the current grouping set. The result +// of the GROUPING function would then be an integer bit mask having 1’s for the arguments which have GROUPING(argument) as 1. +var GROUPING = jet.GROUPING diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index f07f810..3c2b949 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -217,16 +217,95 @@ GROUP BY payment.customer_id; } +func TestGroupByWithRollup(t *testing.T) { + skipForMariaDB(t) + + stmt := SELECT( + Inventory.FilmID.AS("film_id"), + Inventory.StoreID.AS("store_id"), + GROUPING(Inventory.FilmID).AS("grouping_film_id"), + GROUPING(Inventory.FilmID, Inventory.StoreID).AS("grouping_film_id_store_id"), + COUNT(STAR).AS("count"), + ).FROM( + Inventory, + ).WHERE( + Inventory.FilmID.IN(Int(2), Int(3)), + ).GROUP_BY( + WITH_ROLLUP(Inventory.FilmID, Inventory.StoreID), + ).ORDER_BY( + Inventory.FilmID, + Inventory.StoreID, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT inventory.film_id AS "film_id", + inventory.store_id AS "store_id", + GROUPING(inventory.film_id) AS "grouping_film_id", + GROUPING(inventory.film_id, inventory.store_id) AS "grouping_film_id_store_id", + COUNT(*) AS "count" +FROM dvds.inventory +WHERE inventory.film_id IN (2, 3) +GROUP BY inventory.film_id, inventory.store_id WITH ROLLUP +ORDER BY inventory.film_id, inventory.store_id; +`) + + var dest []struct { + FilmID int + StoreID int + GroupingFilmID int + GroupingFilmIDStoreID int + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + + testutils.AssertJSON(t, dest, ` +[ + { + "FilmID": 0, + "StoreID": 0, + "GroupingFilmID": 1, + "GroupingFilmIDStoreID": 3 + }, + { + "FilmID": 2, + "StoreID": 0, + "GroupingFilmID": 0, + "GroupingFilmIDStoreID": 1 + }, + { + "FilmID": 2, + "StoreID": 2, + "GroupingFilmID": 0, + "GroupingFilmIDStoreID": 0 + }, + { + "FilmID": 3, + "StoreID": 0, + "GroupingFilmID": 0, + "GroupingFilmIDStoreID": 1 + }, + { + "FilmID": 3, + "StoreID": 2, + "GroupingFilmID": 0, + "GroupingFilmIDStoreID": 0 + } +] +`) +} + func TestSubQuery(t *testing.T) { - rRatingFilms := Film. - SELECT( - Film.FilmID, - Film.Title, - Film.Rating, - ). - WHERE(Film.Rating.EQ(enum.FilmRating.R)). - AsTable("rFilms") + rRatingFilms := SELECT( + Film.FilmID, + Film.Title, + Film.Rating, + ).FROM( + Film, + ).WHERE( + Film.Rating.EQ(enum.FilmRating.R), + ).AsTable("rFilms") rFilmID := Film.FilmID.From(rRatingFilms) diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 96311a0..9b1164b 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1370,26 +1370,26 @@ GROUP BY customer.customer_id HAVING SUM(payment.amount) > 125.6 ORDER BY customer.customer_id, SUM(payment.amount) ASC; ` - query := Payment. - INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)). - SELECT( - Customer.AllColumns, + query := SELECT( + Customer.AllColumns, - SUMf(Payment.Amount).AS("amount.sum"), - AVG(Payment.Amount).AS("amount.avg"), - MAX(Payment.PaymentDate).AS("amount.max_date"), - MAXf(Payment.Amount).AS("amount.max"), - MIN(Payment.PaymentDate).AS("amount.min_date"), - MINf(Payment.Amount).AS("amount.min"), - COUNT(Payment.Amount).AS("amount.count"), - ). - GROUP_BY(Customer.CustomerID). - HAVING( - SUMf(Payment.Amount).GT(Float(125.6)), - ). - ORDER_BY( - Customer.CustomerID, SUMf(Payment.Amount).ASC(), - ) + SUMf(Payment.Amount).AS("amount.sum"), + AVG(Payment.Amount).AS("amount.avg"), + MAX(Payment.PaymentDate).AS("amount.max_date"), + MAXf(Payment.Amount).AS("amount.max"), + MIN(Payment.PaymentDate).AS("amount.min_date"), + MINf(Payment.Amount).AS("amount.min"), + COUNT(Payment.Amount).AS("amount.count"), + ).FROM( + Payment. + INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)), + ).GROUP_BY( + Customer.CustomerID, + ).HAVING( + SUMf(Payment.Amount).GT(Float(125.6)), + ).ORDER_BY( + Customer.CustomerID, SUMf(Payment.Amount).ASC(), + ) //fmt.Println(query.DebugSql()) @@ -1422,6 +1422,252 @@ ORDER BY customer.customer_id, SUM(payment.amount) ASC; testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/customer_payment_sum.json") } +func TestGroupByGroupingSets(t *testing.T) { + skipForCockroachDB(t) + + stmt := SELECT( + GROUPING(Inventory.FilmID, Inventory.StoreID).AS("grouping_filmId_store_id"), + Inventory.FilmID.AS("film_id"), + Inventory.StoreID.AS("store_id"), + COUNT(Inventory.InventoryID).AS("count"), + ).FROM( + Inventory, + ).WHERE( + Inventory.FilmID.IN(Int(2), Int(3)), + ).GROUP_BY( + GROUPING_SETS( + WRAP(Inventory.FilmID, Inventory.StoreID), + WRAP(Inventory.FilmID), + WRAP(), + ), + ).ORDER_BY( + Inventory.FilmID, + Inventory.StoreID, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT GROUPING(inventory.film_id, inventory.store_id) AS "grouping_filmId_store_id", + inventory.film_id AS "film_id", + inventory.store_id AS "store_id", + COUNT(inventory.inventory_id) AS "count" +FROM dvds.inventory +WHERE inventory.film_id IN (2, 3) +GROUP BY GROUPING SETS((inventory.film_id, inventory.store_id), (inventory.film_id), ()) +ORDER BY inventory.film_id, inventory.store_id; +`) + + var dest []struct { + GroupingFilmIDStoreID int + FilmID *int + StoreID *int + Count int + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + + //testutils.PrintJson(dest) + + testutils.AssertJSON(t, dest, ` +[ + { + "GroupingFilmIDStoreID": 0, + "FilmID": 2, + "StoreID": 2, + "Count": 3 + }, + { + "GroupingFilmIDStoreID": 1, + "FilmID": 2, + "StoreID": null, + "Count": 3 + }, + { + "GroupingFilmIDStoreID": 0, + "FilmID": 3, + "StoreID": 2, + "Count": 4 + }, + { + "GroupingFilmIDStoreID": 1, + "FilmID": 3, + "StoreID": null, + "Count": 4 + }, + { + "GroupingFilmIDStoreID": 3, + "FilmID": null, + "StoreID": null, + "Count": 7 + } +] +`) +} + +func TestGroupByCube(t *testing.T) { + skipForCockroachDB(t) + + stmt := SELECT( + Country.Country.AS("country"), + City.City.AS("city"), + COUNT(City.CityID).AS("count"), + ).FROM( + City.INNER_JOIN( + Country, + Country.CountryID.EQ(City.CountryID), + ), + ).WHERE( + Country.Country.EQ(String("Belarus")), + ).GROUP_BY( + CUBE(Country.Country, City.City), + ).ORDER_BY( + Country.Country, + City.City, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT country.country AS "country", + city.city AS "city", + COUNT(city.city_id) AS "count" +FROM dvds.city + INNER JOIN dvds.country ON (country.country_id = city.country_id) +WHERE country.country = 'Belarus'::text +GROUP BY CUBE(country.country, city.city) +ORDER BY country.country, city.city; +`) + + var dest []struct { + Country string + City string + Count int + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + + testutils.AssertJSON(t, dest, ` +[ + { + "Country": "Belarus", + "City": "Mogiljov", + "Count": 1 + }, + { + "Country": "", + "City": "Mogiljov", + "Count": 1 + }, + { + "Country": "Belarus", + "City": "Molodetno", + "Count": 1 + }, + { + "Country": "", + "City": "Molodetno", + "Count": 1 + }, + { + "Country": "", + "City": "", + "Count": 2 + }, + { + "Country": "Belarus", + "City": "", + "Count": 2 + } +] +`) +} + +func TestGroupByRollup(t *testing.T) { + skipForCockroachDB(t) + + stmt := SELECT( + EXTRACT(YEAR, Rental.RentalDate).AS("year"), + EXTRACT(MONTH, Rental.RentalDate).AS("month"), + EXTRACT(DAY, Rental.RentalDate).AS("day"), + COUNT(Rental.RentalID).AS("count"), + ).FROM( + Rental, + ).WHERE( + Rental.RentalDate.LT(Timestamp(2005, 5, 26, 1, 1, 1)), + ).GROUP_BY( + ROLLUP( + EXTRACT(YEAR, Rental.RentalDate), + EXTRACT(MONTH, Rental.RentalDate), + EXTRACT(DAY, Rental.RentalDate), + ), + ).ORDER_BY( + IntegerColumn("year").ASC(), + EXTRACT(MONTH, Rental.RentalDate).ASC(), + IntegerColumn("day").ASC(), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT EXTRACT(YEAR FROM rental.rental_date) AS "year", + EXTRACT(MONTH FROM rental.rental_date) AS "month", + EXTRACT(DAY FROM rental.rental_date) AS "day", + COUNT(rental.rental_id) AS "count" +FROM dvds.rental +WHERE rental.rental_date < '2005-05-26 01:01:01'::timestamp without time zone +GROUP BY ROLLUP(EXTRACT(YEAR FROM rental.rental_date), EXTRACT(MONTH FROM rental.rental_date), EXTRACT(DAY FROM rental.rental_date)) +ORDER BY year ASC, EXTRACT(MONTH FROM rental.rental_date) ASC, day ASC; +`) + + var dest []struct { + Year *int + Month *int + Day *int + Count int + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + + testutils.AssertJSON(t, dest, ` +[ + { + "Year": 2005, + "Month": 5, + "Day": 24, + "Count": 8 + }, + { + "Year": 2005, + "Month": 5, + "Day": 25, + "Count": 137 + }, + { + "Year": 2005, + "Month": 5, + "Day": 26, + "Count": 9 + }, + { + "Year": 2005, + "Month": 5, + "Day": null, + "Count": 154 + }, + { + "Year": 2005, + "Month": null, + "Day": null, + "Count": 154 + }, + { + "Year": null, + "Month": null, + "Day": null, + "Count": 154 + } +] +`) +} + func TestAggregateFunctionDistinct(t *testing.T) { stmt := SELECT( Payment.CustomerID,