From 038a32b0324bd6442b2630e93727d4e3537ecd59 Mon Sep 17 00:00:00 2001 From: go-jet Date: Wed, 29 Dec 2021 19:07:59 +0100 Subject: [PATCH] Add WITH RECURSIVE statement support --- internal/jet/select_table.go | 29 ++--- internal/jet/statement.go | 8 +- internal/jet/utils.go | 19 ++- internal/jet/with_statement.go | 68 +++++++---- mysql/select_table.go | 2 +- mysql/with_statement.go | 57 +++++++-- postgres/clause.go | 2 +- postgres/literal.go | 1 + postgres/select_table.go | 4 +- postgres/with_statement.go | 65 +++++++++-- sqlite/on_conflict_clause.go | 2 +- sqlite/select_table.go | 2 +- sqlite/with_statement.go | 65 +++++++++-- tests/mysql/with_test.go | 137 +++++++++++++++++++++- tests/postgres/select_test.go | 6 +- tests/postgres/with_test.go | 205 +++++++++++++++++++++++++++++++-- tests/sqlite/with_test.go | 114 ++++++++++++++++++ 17 files changed, 695 insertions(+), 91 deletions(-) diff --git a/internal/jet/select_table.go b/internal/jet/select_table.go index 541992f..c25fba3 100644 --- a/internal/jet/select_table.go +++ b/internal/jet/select_table.go @@ -2,38 +2,41 @@ package jet // SelectTable is interface for SELECT sub-queries type SelectTable interface { - Serializer + SerializerHasProjections Alias() string AllColumns() ProjectionList } type selectTableImpl struct { - selectStmt SerializerStatement - alias string + Statement SerializerHasProjections + alias string } // NewSelectTable func -func NewSelectTable(selectStmt SerializerStatement, alias string) selectTableImpl { - selectTable := selectTableImpl{selectStmt: selectStmt, alias: alias} +func NewSelectTable(selectStmt SerializerHasProjections, alias string) selectTableImpl { + selectTable := selectTableImpl{ + Statement: selectStmt, + alias: alias, + } + return selectTable } +func (s selectTableImpl) projections() ProjectionList { + return s.Statement.projections() +} + func (s selectTableImpl) Alias() string { return s.alias } func (s selectTableImpl) AllColumns() ProjectionList { - statementWithProjections, ok := s.selectStmt.(HasProjections) - if !ok { - return ProjectionList{} - } - - projectionList := statementWithProjections.projections().fromImpl(s) + projectionList := s.projections().fromImpl(s) return projectionList.(ProjectionList) } func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - s.selectStmt.serialize(statement, out) + s.Statement.serialize(statement, out) out.WriteString("AS") out.WriteIdentifier(s.alias) @@ -52,7 +55,7 @@ func NewLateral(selectStmt SerializerStatement, alias string) SelectTable { func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString("LATERAL") - s.selectStmt.serialize(statement, out) + s.Statement.serialize(statement, out) out.WriteString("AS") out.WriteIdentifier(s.alias) diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 1d05045..a5ae83b 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -51,6 +51,12 @@ type HasProjections interface { projections() ProjectionList } +// SerializerHasProjections interface is combination of Serializer and HasProjections interface +type SerializerHasProjections interface { + Serializer + HasProjections +} + // serializerStatementInterfaceImpl struct type serializerStatementInterfaceImpl struct { dialect Dialect @@ -200,7 +206,7 @@ func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, opti } for _, clause := range s.Clauses { - clause.Serialize(statement, out, FallTrough(options)...) + clause.Serialize(s.statementType, out, FallTrough(options)...) } if contains(options, Ident) { diff --git a/internal/jet/utils.go b/internal/jet/utils.go index eab4403..113a396 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -68,8 +68,8 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) { } } -// SerializeColumnExpressionNames func -func SerializeColumnExpressionNames(columns []ColumnExpression, statementType StatementType, +// SerializeColumnExpressions func +func SerializeColumnExpressions(columns []ColumnExpression, statementType StatementType, out *SQLBuilder, options ...SerializeOption) { for i, col := range columns { if i > 0 { @@ -84,6 +84,21 @@ func SerializeColumnExpressionNames(columns []ColumnExpression, statementType St } } +// SerializeColumnExpressionNames func +func SerializeColumnExpressionNames(columns []ColumnExpression, out *SQLBuilder) { + for i, col := range columns { + if i > 0 { + out.WriteString(", ") + } + + if col == nil { + panic("jet: nil column in columns list") + } + + out.WriteIdentifier(col.Name()) + } +} + // ExpressionListToSerializerList converts list of expressions to list of serializers func ExpressionListToSerializerList(expressions []Expression) []Serializer { var ret []Serializer diff --git a/internal/jet/with_statement.go b/internal/jet/with_statement.go index becfee1..783fa27 100644 --- a/internal/jet/with_statement.go +++ b/internal/jet/with_statement.go @@ -1,7 +1,9 @@ package jet +import "fmt" + // WITH function creates new with statement from list of common table expressions for specified dialect -func WITH(dialect Dialect, recursive bool, cte ...CommonTableExpressionDefinition) func(statement Statement) Statement { +func WITH(dialect Dialect, recursive bool, cte ...*CommonTableExpression) func(statement Statement) Statement { newWithImpl := &withImpl{ recursive: recursive, ctes: cte, @@ -25,7 +27,7 @@ func WITH(dialect Dialect, recursive bool, cte ...CommonTableExpressionDefinitio type withImpl struct { serializerStatementInterfaceImpl recursive bool - ctes []CommonTableExpressionDefinition + ctes []*CommonTableExpression primaryStatement SerializerStatement } @@ -54,35 +56,55 @@ func (w withImpl) projections() ProjectionList { // CommonTableExpression contains information about a CTE. type CommonTableExpression struct { selectTableImpl + + NotMaterialized bool + Columns []ColumnExpression } // CTE creates new named CommonTableExpression -func CTE(name string) CommonTableExpression { - return CommonTableExpression{ - selectTableImpl: selectTableImpl{ - selectStmt: nil, - alias: name, - }, +func CTE(name string, columns ...ColumnExpression) CommonTableExpression { + cte := CommonTableExpression{ + selectTableImpl: NewSelectTable(nil, name), + Columns: columns, } + + for _, column := range cte.Columns { + column.setSubQuery(cte) + } + + return cte } func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteIdentifier(c.alias) + if statement == WithStatementType { // serialize CTE definition + out.WriteIdentifier(c.alias) + if len(c.Columns) > 0 { + out.WriteByte('(') + SerializeColumnExpressionNames(c.Columns, out) + out.WriteByte(')') + } + out.WriteString("AS") + + if c.NotMaterialized { + out.WriteString("NOT MATERIALIZED") + } + + if c.Statement == nil { + panic(fmt.Sprintf("jet: '%s' CTE is not defined", c.alias)) + } + + c.Statement.serialize(statement, out, FallTrough(options)...) + + } else { // serialize CTE in FROM clause + out.WriteIdentifier(c.alias) + } } -// AS returns sets definition for a CTE -func (c *CommonTableExpression) AS(statement SerializerStatement) CommonTableExpressionDefinition { - c.selectStmt = statement - return CommonTableExpressionDefinition{cte: c} -} +// AllColumns returns list of all projections in the CTE +func (c CommonTableExpression) AllColumns() ProjectionList { + if len(c.Columns) > 0 { + return ColumnListToProjectionList(c.Columns) + } -// CommonTableExpressionDefinition contains implementation details of CTE -type CommonTableExpressionDefinition struct { - cte *CommonTableExpression -} - -func (c CommonTableExpressionDefinition) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteIdentifier(c.cte.alias) - out.WriteString("AS") - c.cte.selectStmt.serialize(statement, out, FallTrough(options)...) + return c.selectTableImpl.AllColumns() } diff --git a/mysql/select_table.go b/mysql/select_table.go index af9de27..ad22193 100644 --- a/mysql/select_table.go +++ b/mysql/select_table.go @@ -13,7 +13,7 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable { +func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { subQuery := &selectTableImpl{ SelectTable: jet.NewSelectTable(selectStmt, alias), } diff --git a/mysql/with_statement.go b/mysql/with_statement.go index 2487ead..ca608cb 100644 --- a/mysql/with_statement.go +++ b/mysql/with_statement.go @@ -2,30 +2,65 @@ package mysql import "github.com/go-jet/jet/v2/internal/jet" -// CommonTableExpression contains information about a CTE. -type CommonTableExpression struct { +// CommonTableExpression defines set of interface methods for postgres CTEs +type CommonTableExpression interface { + SelectTable + + AS(statement jet.SerializerStatement) CommonTableExpression + // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. + ALIAS(alias string) SelectTable + + internalCTE() *jet.CommonTableExpression +} + +type commonTableExpression struct { readableTableInterfaceImpl jet.CommonTableExpression } // WITH function creates new WITH statement from list of common table expressions -func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { - return jet.WITH(Dialect, false, cte...) +func WITH(cte ...CommonTableExpression) func(statement jet.Statement) Statement { + return jet.WITH(Dialect, false, toInternalCTE(cte)...) } // WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions -func WITH_RECURSIVE(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { - return jet.WITH(Dialect, true, cte...) +func WITH_RECURSIVE(cte ...CommonTableExpression) func(statement jet.Statement) Statement { + return jet.WITH(Dialect, true, toInternalCTE(cte)...) } -// CTE creates new named CommonTableExpression -func CTE(name string) CommonTableExpression { - cte := CommonTableExpression{ +// CTE creates new named commonTableExpression +func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { + cte := &commonTableExpression{ readableTableInterfaceImpl: readableTableInterfaceImpl{}, - CommonTableExpression: jet.CTE(name), + CommonTableExpression: jet.CTE(name, columns...), } - cte.parent = &cte + cte.parent = cte return cte } + +// AS is used to define a CTE query +func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { + c.CommonTableExpression.Statement = statement + return c +} + +func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { + return &c.CommonTableExpression +} + +// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. +func (c *commonTableExpression) ALIAS(name string) SelectTable { + return newSelectTable(c, name) +} + +func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { + var ret []*jet.CommonTableExpression + + for _, cte := range ctes { + ret = append(ret, cte.internalCTE()) + } + + return ret +} diff --git a/postgres/clause.go b/postgres/clause.go index 3a23fd0..0953e26 100644 --- a/postgres/clause.go +++ b/postgres/clause.go @@ -52,7 +52,7 @@ func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.S out.WriteString("ON CONFLICT") if len(o.indexExpressions) > 0 { out.WriteString("(") - jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName) + jet.SerializeColumnExpressions(o.indexExpressions, statementType, out, jet.ShortName) out.WriteString(")") } diff --git a/postgres/literal.go b/postgres/literal.go index 524b251..e46b874 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -6,6 +6,7 @@ import ( "github.com/go-jet/jet/v2/internal/jet" ) +// Bool is boolean literal constructor func Bool(value bool) BoolExpression { return CAST(jet.Bool(value)).AS_BOOL() } diff --git a/postgres/select_table.go b/postgres/select_table.go index e11b7cd..f3d680d 100644 --- a/postgres/select_table.go +++ b/postgres/select_table.go @@ -2,7 +2,7 @@ package postgres import "github.com/go-jet/jet/v2/internal/jet" -// SelectTable is interface for MySQL sub-queries +// SelectTable is interface for postgres sub-queries type SelectTable interface { readableTable jet.SelectTable @@ -13,7 +13,7 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable { +func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { subQuery := &selectTableImpl{ SelectTable: jet.NewSelectTable(selectStmt, alias), } diff --git a/postgres/with_statement.go b/postgres/with_statement.go index 6ee5035..698d6e3 100644 --- a/postgres/with_statement.go +++ b/postgres/with_statement.go @@ -2,30 +2,73 @@ package postgres import "github.com/go-jet/jet/v2/internal/jet" -// CommonTableExpression contains information about a CTE. -type CommonTableExpression struct { +// CommonTableExpression defines set of interface methods for postgres CTEs +type CommonTableExpression interface { + SelectTable + + AS(statement jet.SerializerStatement) CommonTableExpression + AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression + // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. + ALIAS(alias string) SelectTable + + internalCTE() *jet.CommonTableExpression +} + +type commonTableExpression struct { readableTableInterfaceImpl jet.CommonTableExpression } // WITH function creates new WITH statement from list of common table expressions -func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { - return jet.WITH(Dialect, false, cte...) +func WITH(cte ...CommonTableExpression) func(statement jet.Statement) Statement { + return jet.WITH(Dialect, false, toInternalCTE(cte)...) } // WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions -func WITH_RECURSIVE(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { - return jet.WITH(Dialect, true, cte...) +func WITH_RECURSIVE(cte ...CommonTableExpression) func(statement jet.Statement) Statement { + return jet.WITH(Dialect, true, toInternalCTE(cte)...) } -// CTE creates new named CommonTableExpression -func CTE(name string) CommonTableExpression { - cte := CommonTableExpression{ +// CTE creates new named commonTableExpression +func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { + cte := &commonTableExpression{ readableTableInterfaceImpl: readableTableInterfaceImpl{}, - CommonTableExpression: jet.CTE(name), + CommonTableExpression: jet.CTE(name, columns...), } - cte.parent = &cte + cte.parent = cte return cte } + +// AS is used to define a CTE query +func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { + c.CommonTableExpression.Statement = statement + return c +} + +// AS_NOT_MATERIALIZED is used to define not materialized CTE query +func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression { + c.CommonTableExpression.NotMaterialized = true + c.CommonTableExpression.Statement = statement + return c +} + +func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { + return &c.CommonTableExpression +} + +// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. +func (c *commonTableExpression) ALIAS(name string) SelectTable { + return newSelectTable(c, name) +} + +func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { + var ret []*jet.CommonTableExpression + + for _, cte := range ctes { + ret = append(ret, cte.internalCTE()) + } + + return ret +} diff --git a/sqlite/on_conflict_clause.go b/sqlite/on_conflict_clause.go index d131b9e..1e2ec8f 100644 --- a/sqlite/on_conflict_clause.go +++ b/sqlite/on_conflict_clause.go @@ -45,7 +45,7 @@ func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.S out.WriteString("ON CONFLICT") if len(o.indexExpressions) > 0 { out.WriteString("(") - jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName) + jet.SerializeColumnExpressions(o.indexExpressions, statementType, out, jet.ShortName) out.WriteString(")") } diff --git a/sqlite/select_table.go b/sqlite/select_table.go index 4117e06..9ac7f72 100644 --- a/sqlite/select_table.go +++ b/sqlite/select_table.go @@ -13,7 +13,7 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable { +func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { subQuery := &selectTableImpl{ SelectTable: jet.NewSelectTable(selectStmt, alias), } diff --git a/sqlite/with_statement.go b/sqlite/with_statement.go index 1e646e8..5375fff 100644 --- a/sqlite/with_statement.go +++ b/sqlite/with_statement.go @@ -2,30 +2,73 @@ package sqlite import "github.com/go-jet/jet/v2/internal/jet" -// CommonTableExpression contains information about a CTE. -type CommonTableExpression struct { +// CommonTableExpression defines set of interface methods for postgres CTEs +type CommonTableExpression interface { + SelectTable + + AS(statement jet.SerializerStatement) CommonTableExpression + AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression + // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. + ALIAS(alias string) SelectTable + + internalCTE() *jet.CommonTableExpression +} + +type commonTableExpression struct { readableTableInterfaceImpl jet.CommonTableExpression } // WITH function creates new WITH statement from list of common table expressions -func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { - return jet.WITH(Dialect, false, cte...) +func WITH(cte ...CommonTableExpression) func(statement jet.Statement) Statement { + return jet.WITH(Dialect, false, toInternalCTE(cte)...) } // WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions -func WITH_RECURSIVE(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { - return jet.WITH(Dialect, true, cte...) +func WITH_RECURSIVE(cte ...CommonTableExpression) func(statement jet.Statement) Statement { + return jet.WITH(Dialect, true, toInternalCTE(cte)...) } -// CTE creates new named CommonTableExpression -func CTE(name string) CommonTableExpression { - cte := CommonTableExpression{ +// CTE creates new named commonTableExpression +func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { + cte := &commonTableExpression{ readableTableInterfaceImpl: readableTableInterfaceImpl{}, - CommonTableExpression: jet.CTE(name), + CommonTableExpression: jet.CTE(name, columns...), } - cte.parent = &cte + cte.parent = cte return cte } + +// AS is used to define a CTE query +func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { + c.CommonTableExpression.Statement = statement + return c +} + +// AS_NOT_MATERIALIZED is used to define not materialized CTE query +func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression { + c.CommonTableExpression.NotMaterialized = true + c.CommonTableExpression.Statement = statement + return c +} + +func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { + return &c.CommonTableExpression +} + +// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. +func (c *commonTableExpression) ALIAS(name string) SelectTable { + return newSelectTable(c, name) +} + +func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { + var ret []*jet.CommonTableExpression + + for _, cte := range ctes { + ret = append(ret, cte.internalCTE()) + } + + return ret +} diff --git a/tests/mysql/with_test.go b/tests/mysql/with_test.go index ddc1d12..cc8dfd6 100644 --- a/tests/mysql/with_test.go +++ b/tests/mysql/with_test.go @@ -149,7 +149,26 @@ func TestWITH_And_DELETE(t *testing.T) { ), ) - //fmt.Println(stmt.DebugSql()) + // fmt.Println(stmt.DebugSql()) + + testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(` +WITH payments_to_delete AS ( + SELECT payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + payment.last_update AS "payment.last_update" + FROM dvds.payment + WHERE payment.amount < 0.5 +) +DELETE FROM dvds.payment +WHERE payment.payment_id IN ( + SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id" + FROM payments_to_delete + ); +`, "''", "`")) tx, err := db.Begin() require.NoError(t, err) @@ -157,3 +176,119 @@ func TestWITH_And_DELETE(t *testing.T) { testutils.AssertExec(t, stmt, tx, 24) } + +func TestRecursiveWithStatement_Fibonacci(t *testing.T) { + // CTE columns are listed as part of CTE definition + n1 := IntegerColumn("n1") + fibN1 := IntegerColumn("fibN1") + nextFibN1 := IntegerColumn("nextFibN1") + fibonacci1 := CTE("fibonacci1", n1, fibN1, nextFibN1) + + // CTE columns are columns from non-recursive select + fibonacci2 := CTE("fibonacci2") + n2 := IntegerColumn("n2").From(fibonacci2) + fibN2 := IntegerColumn("fibN2").From(fibonacci2) + nextFibN2 := IntegerColumn("nextFibN2").From(fibonacci2) + + stmt := WITH_RECURSIVE( + fibonacci1.AS( + SELECT( + Int32(1), Int32(0), Int32(1), + ).UNION_ALL( + SELECT( + n1.ADD(Int(1)), nextFibN1, fibN1.ADD(nextFibN1), + ).FROM( + fibonacci1, + ).WHERE( + n1.LT(Int(20)), + ), + ), + ), + fibonacci2.AS( + SELECT( + Int32(1).AS(n2.Name()), + Int32(0).AS(fibN2.Name()), + Int32(1).AS(nextFibN2.Name()), + ).UNION_ALL( + SELECT( + n2.ADD(Int(1)), nextFibN2, fibN2.ADD(nextFibN2), + ).FROM( + fibonacci2, + ).WHERE( + n2.LT(Int(20)), + ), + ), + ), + )( + SELECT( + fibonacci1.AllColumns(), + fibonacci2.AllColumns(), + ).FROM( + fibonacci1.INNER_JOIN(fibonacci2, n1.EQ(n2)), + ).WHERE( + n1.EQ(Int(20)), + ), + ) + + // fmt.Println(stmt.Sql()) + + testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(` +WITH RECURSIVE fibonacci1 (n1, ''fibN1'', ''nextFibN1'') AS ( + ( + SELECT ?, + ?, + ? + ) + UNION ALL + ( + SELECT fibonacci1.n1 + ?, + fibonacci1.''nextFibN1'' AS "nextFibN1", + fibonacci1.''fibN1'' + fibonacci1.''nextFibN1'' + FROM fibonacci1 + WHERE fibonacci1.n1 < ? + ) +),fibonacci2 AS ( + ( + SELECT ? AS "n2", + ? AS "fibN2", + ? AS "nextFibN2" + ) + UNION ALL + ( + SELECT fibonacci2.n2 + ?, + fibonacci2.''nextFibN2'' AS "nextFibN2", + fibonacci2.''fibN2'' + fibonacci2.''nextFibN2'' + FROM fibonacci2 + WHERE fibonacci2.n2 < ? + ) +) +SELECT fibonacci1.n1 AS "n1", + fibonacci1.''fibN1'' AS "fibN1", + fibonacci1.''nextFibN1'' AS "nextFibN1", + fibonacci2.n2 AS "n2", + fibonacci2.''fibN2'' AS "fibN2", + fibonacci2.''nextFibN2'' AS "nextFibN2" +FROM fibonacci1 + INNER JOIN fibonacci2 ON (fibonacci1.n1 = fibonacci2.n2) +WHERE fibonacci1.n1 = ?; +`, "''", "`")) + + var dest struct { + N1 int + FibN1 int + NextFibN1 int + + N2 int + FibN2 int + NextFibN2 int + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Equal(t, dest.N1, 20) + require.Equal(t, dest.FibN1, 4181) + require.Equal(t, dest.NextFibN1, 6765) + require.Equal(t, dest.N2, 20) + require.Equal(t, dest.FibN2, 4181) + require.Equal(t, dest.NextFibN2, 6765) +} diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 90bd8f9..37ca365 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -814,10 +814,10 @@ ORDER BY f1.film_id ASC; type F1 model.Film type F2 model.Film - theSameLengthFilms := []struct { + var theSameLengthFilms []struct { F1 F1 F2 F2 - }{} + } err := query.Query(db, &theSameLengthFilms) @@ -858,7 +858,7 @@ LIMIT 1000; Title2 string Length int16 } - films := []thesameLengthFilms{} + var films []thesameLengthFilms err := query.Query(db, &films) diff --git a/tests/postgres/with_test.go b/tests/postgres/with_test.go index 345320b..27d5ce0 100644 --- a/tests/postgres/with_test.go +++ b/tests/postgres/with_test.go @@ -144,7 +144,7 @@ func TestWithStatementDeleteAndInsert(t *testing.T) { require.Equal(t, len(updateDiscontinuedPrice.AllColumns()[0].(ProjectionList)), 10) require.Equal(t, len(logDiscontinuedProducts.AllColumns()), 10) - //fmt.Println(stmt.Sql()) + // fmt.Println(stmt.Sql()) testutils.AssertStatementSql(t, stmt, ` WITH remove_discontinued_orders AS ( @@ -218,7 +218,122 @@ FROM log_discontinued; err = stmt.Query(tx, &resp) require.NoError(t, err) +} +func TestRecursiveWithStatement_Fibonacci(t *testing.T) { + // CTE columns are listed as part of CTE definition + n1 := IntegerColumn("n1") + fibN1 := IntegerColumn("fibN1") + nextFibN1 := IntegerColumn("nextFibN1") + fibonacci1 := CTE("fibonacci1", n1, fibN1, nextFibN1) + + // CTE columns are columns from non-recursive select + fibonacci2 := CTE("fibonacci2") + n2 := IntegerColumn("n2").From(fibonacci2) + fibN2 := IntegerColumn("fibN2").From(fibonacci2) + nextFibN2 := IntegerColumn("nextFibN2").From(fibonacci2) + + stmt := WITH_RECURSIVE( + fibonacci1.AS( + SELECT( + Int32(1), Int32(0), Int32(1), + ).UNION_ALL( + SELECT( + n1.ADD(Int(1)), nextFibN1, fibN1.ADD(nextFibN1), + ).FROM( + fibonacci1, + ).WHERE( + n1.LT(Int(20)), + ), + ), + ), + fibonacci2.AS( + SELECT( + Int32(1).AS(n2.Name()), + Int32(0).AS(fibN2.Name()), + Int32(1).AS(nextFibN2.Name()), + ).UNION_ALL( + SELECT( + n2.ADD(Int(1)), nextFibN2, fibN2.ADD(nextFibN2), + ).FROM( + fibonacci2, + ).WHERE( + n2.LT(Int(20)), + ), + ), + ), + )( + SELECT( + fibonacci1.AllColumns(), + fibonacci2.AllColumns(), + ).FROM( + fibonacci1.INNER_JOIN(fibonacci2, n1.EQ(n2)), + ).WHERE( + n1.EQ(Int(20)), + ), + ) + + //fmt.Println(stmt.Sql()) + + testutils.AssertStatementSql(t, stmt, ` +WITH RECURSIVE fibonacci1 (n1, "fibN1", "nextFibN1") AS ( + ( + SELECT $1::integer, + $2::integer, + $3::integer + ) + UNION ALL + ( + SELECT fibonacci1.n1 + $4, + fibonacci1."nextFibN1" AS "nextFibN1", + fibonacci1."fibN1" + fibonacci1."nextFibN1" + FROM fibonacci1 + WHERE fibonacci1.n1 < $5 + ) +),fibonacci2 AS ( + ( + SELECT $6::integer AS "n2", + $7::integer AS "fibN2", + $8::integer AS "nextFibN2" + ) + UNION ALL + ( + SELECT fibonacci2.n2 + $9, + fibonacci2."nextFibN2" AS "nextFibN2", + fibonacci2."fibN2" + fibonacci2."nextFibN2" + FROM fibonacci2 + WHERE fibonacci2.n2 < $10 + ) +) +SELECT fibonacci1.n1 AS "n1", + fibonacci1."fibN1" AS "fibN1", + fibonacci1."nextFibN1" AS "nextFibN1", + fibonacci2.n2 AS "n2", + fibonacci2."fibN2" AS "fibN2", + fibonacci2."nextFibN2" AS "nextFibN2" +FROM fibonacci1 + INNER JOIN fibonacci2 ON (fibonacci1.n1 = fibonacci2.n2) +WHERE fibonacci1.n1 = $11; +`) + + var dest struct { + N1 int + FibN1 int + NextFibN1 int + + N2 int + FibN2 int + NextFibN2 int + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Equal(t, dest.N1, 20) + require.Equal(t, dest.FibN1, 4181) + require.Equal(t, dest.NextFibN1, 6765) + require.Equal(t, dest.N2, 20) + require.Equal(t, dest.FibN2, 4181) + require.Equal(t, dest.NextFibN2, 6765) } // default column aliases from sub-queries are bubbled up to the main query, @@ -298,13 +413,7 @@ FROM cte2; require.Equal(t, dest[0].CustomColumn2, "custom_column_2") } -type EmployeeWrap struct { - model.Employees - - Subordinates []*EmployeeWrap -} - -func TestWithRecursive(t *testing.T) { +func TestRecursiveWithStatement(t *testing.T) { subordinates := CTE("subordinates") @@ -333,6 +442,14 @@ func TestWithRecursive(t *testing.T) { ), ) + //fmt.Println(stmt.DebugSql()) + + type EmployeeWrap struct { + model.Employees + + Subordinates []*EmployeeWrap + } + type employeeID = int16 employeeMap := make(map[employeeID]*EmployeeWrap) @@ -352,7 +469,7 @@ func TestWithRecursive(t *testing.T) { employeeMap[employeeModel.EmployeeID] = newEmployeeWrap - if employeeModel.ReportsTo == nil { // top manager(always first row in the result) + if result == nil { // top manager(always first row in the result) result = newEmployeeWrap continue } @@ -559,3 +676,73 @@ func TestWithRecursive(t *testing.T) { } `) } + +var suppliersWithFax = CTE("suppliers_fax").AS( + SELECT( + Suppliers.SupplierID, + Suppliers.ContactName, + Suppliers.Country, + ).FROM( + Suppliers, + ).WHERE(Suppliers.Fax.IS_NOT_NULL()), +) + +func SuppliersNotFromUSorAUS(suppliersCTE CommonTableExpression) CommonTableExpression { + return CTE("not_from_us_or_aus").AS( + SELECT( + suppliersCTE.AllColumns(), + ).FROM( + suppliersCTE, + ).WHERE( + Suppliers.Country.From(suppliersCTE).NOT_IN(String("US"), String("Australia")), + ), + ) +} + +func TestCTEReuse(t *testing.T) { + suppliersFilteredByCountry := SuppliersNotFromUSorAUS(suppliersWithFax) + supplierContactName := Suppliers.ContactName.From(suppliersFilteredByCountry) + + stmt := WITH( + suppliersWithFax, + suppliersFilteredByCountry, + )( + SELECT( + suppliersFilteredByCountry.AllColumns(), + ).FROM( + suppliersFilteredByCountry, + ).WHERE( + supplierContactName.NOT_EQ(String("John")), + ), + ) + + // fmt.Println(stmt.DebugSql()) + + testutils.AssertDebugStatementSql(t, stmt, ` +WITH suppliers_fax AS ( + SELECT suppliers.supplier_id AS "suppliers.supplier_id", + suppliers.contact_name AS "suppliers.contact_name", + suppliers.country AS "suppliers.country" + FROM northwind.suppliers + WHERE suppliers.fax IS NOT NULL +),not_from_us_or_aus AS ( + SELECT suppliers_fax."suppliers.supplier_id" AS "suppliers.supplier_id", + suppliers_fax."suppliers.contact_name" AS "suppliers.contact_name", + suppliers_fax."suppliers.country" AS "suppliers.country" + FROM suppliers_fax + WHERE suppliers_fax."suppliers.country" NOT IN ('US', 'Australia') +) +SELECT not_from_us_or_aus."suppliers.supplier_id" AS "suppliers.supplier_id", + not_from_us_or_aus."suppliers.contact_name" AS "suppliers.contact_name", + not_from_us_or_aus."suppliers.country" AS "suppliers.country" +FROM not_from_us_or_aus +WHERE not_from_us_or_aus."suppliers.contact_name" != 'John'; +`) + + var dest []model.Suppliers + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + require.Len(t, dest, 11) +} diff --git a/tests/sqlite/with_test.go b/tests/sqlite/with_test.go index f2b623a..92cd331 100644 --- a/tests/sqlite/with_test.go +++ b/tests/sqlite/with_test.go @@ -232,3 +232,117 @@ FROM payment; err := stmt.Query(db, &dest) require.NoError(t, err) } + +func TestRecursiveWithStatement_Fibonacci(t *testing.T) { + // CTE columns are listed as part of CTE definition + n1 := IntegerColumn("n1") + fibN1 := IntegerColumn("fibN1") + nextFibN1 := IntegerColumn("nextFibN1") + fibonacci1 := CTE("fibonacci1", n1, fibN1, nextFibN1) + + // CTE columns are columns from non-recursive select + fibonacci2 := CTE("fibonacci2") + n2 := IntegerColumn("n2").From(fibonacci2) + fibN2 := IntegerColumn("fibN2").From(fibonacci2) + nextFibN2 := IntegerColumn("nextFibN2").From(fibonacci2) + + stmt := WITH_RECURSIVE( + fibonacci1.AS( + SELECT( + Int32(1), Int32(0), Int32(1), + ).UNION_ALL( + SELECT( + n1.ADD(Int(1)), nextFibN1, fibN1.ADD(nextFibN1), + ).FROM( + fibonacci1, + ).WHERE( + n1.LT(Int(20)), + ), + ), + ), + fibonacci2.AS( + SELECT( + Int32(1).AS(n2.Name()), + Int32(0).AS(fibN2.Name()), + Int32(1).AS(nextFibN2.Name()), + ).UNION_ALL( + SELECT( + n2.ADD(Int(1)), nextFibN2, fibN2.ADD(nextFibN2), + ).FROM( + fibonacci2, + ).WHERE( + n2.LT(Int(20)), + ), + ), + ), + )( + SELECT( + fibonacci1.AllColumns(), + fibonacci2.AllColumns(), + ).FROM( + fibonacci1.INNER_JOIN(fibonacci2, n1.EQ(n2)), + ).WHERE( + n1.EQ(Int(20)), + ), + ) + + //fmt.Println(stmt.Sql()) + + testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(` +WITH RECURSIVE fibonacci1 (n1, ''fibN1'', ''nextFibN1'') AS ( + + SELECT ?, + ?, + ? + + UNION ALL + + SELECT fibonacci1.n1 + ?, + fibonacci1.''nextFibN1'' AS "nextFibN1", + fibonacci1.''fibN1'' + fibonacci1.''nextFibN1'' + FROM fibonacci1 + WHERE fibonacci1.n1 < ? +),fibonacci2 AS ( + + SELECT ? AS "n2", + ? AS "fibN2", + ? AS "nextFibN2" + + UNION ALL + + SELECT fibonacci2.n2 + ?, + fibonacci2.''nextFibN2'' AS "nextFibN2", + fibonacci2.''fibN2'' + fibonacci2.''nextFibN2'' + FROM fibonacci2 + WHERE fibonacci2.n2 < ? +) +SELECT fibonacci1.n1 AS "n1", + fibonacci1.''fibN1'' AS "fibN1", + fibonacci1.''nextFibN1'' AS "nextFibN1", + fibonacci2.n2 AS "n2", + fibonacci2.''fibN2'' AS "fibN2", + fibonacci2.''nextFibN2'' AS "nextFibN2" +FROM fibonacci1 + INNER JOIN fibonacci2 ON (fibonacci1.n1 = fibonacci2.n2) +WHERE fibonacci1.n1 = ?; +`, "''", "`")) + + var dest struct { + N1 int + FibN1 int + NextFibN1 int + + N2 int + FibN2 int + NextFibN2 int + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Equal(t, dest.N1, 20) + require.Equal(t, dest.FibN1, 4181) + require.Equal(t, dest.NextFibN1, 6765) + require.Equal(t, dest.N2, 20) + require.Equal(t, dest.FibN2, 4181) + require.Equal(t, dest.NextFibN2, 6765) +}