Add WITH RECURSIVE statement support

This commit is contained in:
go-jet 2021-12-29 19:07:59 +01:00
parent 001d64f1dc
commit 038a32b032
17 changed files with 695 additions and 91 deletions

View file

@ -2,38 +2,41 @@ package jet
// SelectTable is interface for SELECT sub-queries // SelectTable is interface for SELECT sub-queries
type SelectTable interface { type SelectTable interface {
Serializer SerializerHasProjections
Alias() string Alias() string
AllColumns() ProjectionList AllColumns() ProjectionList
} }
type selectTableImpl struct { type selectTableImpl struct {
selectStmt SerializerStatement Statement SerializerHasProjections
alias string alias string
} }
// NewSelectTable func // NewSelectTable func
func NewSelectTable(selectStmt SerializerStatement, alias string) selectTableImpl { func NewSelectTable(selectStmt SerializerHasProjections, alias string) selectTableImpl {
selectTable := selectTableImpl{selectStmt: selectStmt, alias: alias} selectTable := selectTableImpl{
Statement: selectStmt,
alias: alias,
}
return selectTable return selectTable
} }
func (s selectTableImpl) projections() ProjectionList {
return s.Statement.projections()
}
func (s selectTableImpl) Alias() string { func (s selectTableImpl) Alias() string {
return s.alias return s.alias
} }
func (s selectTableImpl) AllColumns() ProjectionList { func (s selectTableImpl) AllColumns() ProjectionList {
statementWithProjections, ok := s.selectStmt.(HasProjections) projectionList := s.projections().fromImpl(s)
if !ok {
return ProjectionList{}
}
projectionList := statementWithProjections.projections().fromImpl(s)
return projectionList.(ProjectionList) return projectionList.(ProjectionList)
} }
func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
s.selectStmt.serialize(statement, out) s.Statement.serialize(statement, out)
out.WriteString("AS") out.WriteString("AS")
out.WriteIdentifier(s.alias) out.WriteIdentifier(s.alias)
@ -52,7 +55,7 @@ func NewLateral(selectStmt SerializerStatement, alias string) SelectTable {
func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("LATERAL") out.WriteString("LATERAL")
s.selectStmt.serialize(statement, out) s.Statement.serialize(statement, out)
out.WriteString("AS") out.WriteString("AS")
out.WriteIdentifier(s.alias) out.WriteIdentifier(s.alias)

View file

@ -51,6 +51,12 @@ type HasProjections interface {
projections() ProjectionList projections() ProjectionList
} }
// SerializerHasProjections interface is combination of Serializer and HasProjections interface
type SerializerHasProjections interface {
Serializer
HasProjections
}
// serializerStatementInterfaceImpl struct // serializerStatementInterfaceImpl struct
type serializerStatementInterfaceImpl struct { type serializerStatementInterfaceImpl struct {
dialect Dialect dialect Dialect
@ -200,7 +206,7 @@ func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, opti
} }
for _, clause := range s.Clauses { for _, clause := range s.Clauses {
clause.Serialize(statement, out, FallTrough(options)...) clause.Serialize(s.statementType, out, FallTrough(options)...)
} }
if contains(options, Ident) { if contains(options, Ident) {

View file

@ -68,8 +68,8 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) {
} }
} }
// SerializeColumnExpressionNames func // SerializeColumnExpressions func
func SerializeColumnExpressionNames(columns []ColumnExpression, statementType StatementType, func SerializeColumnExpressions(columns []ColumnExpression, statementType StatementType,
out *SQLBuilder, options ...SerializeOption) { out *SQLBuilder, options ...SerializeOption) {
for i, col := range columns { for i, col := range columns {
if i > 0 { if i > 0 {
@ -84,6 +84,21 @@ func SerializeColumnExpressionNames(columns []ColumnExpression, statementType St
} }
} }
// SerializeColumnExpressionNames func
func SerializeColumnExpressionNames(columns []ColumnExpression, out *SQLBuilder) {
for i, col := range columns {
if i > 0 {
out.WriteString(", ")
}
if col == nil {
panic("jet: nil column in columns list")
}
out.WriteIdentifier(col.Name())
}
}
// ExpressionListToSerializerList converts list of expressions to list of serializers // ExpressionListToSerializerList converts list of expressions to list of serializers
func ExpressionListToSerializerList(expressions []Expression) []Serializer { func ExpressionListToSerializerList(expressions []Expression) []Serializer {
var ret []Serializer var ret []Serializer

View file

@ -1,7 +1,9 @@
package jet package jet
import "fmt"
// WITH function creates new with statement from list of common table expressions for specified dialect // WITH function creates new with statement from list of common table expressions for specified dialect
func WITH(dialect Dialect, recursive bool, cte ...CommonTableExpressionDefinition) func(statement Statement) Statement { func WITH(dialect Dialect, recursive bool, cte ...*CommonTableExpression) func(statement Statement) Statement {
newWithImpl := &withImpl{ newWithImpl := &withImpl{
recursive: recursive, recursive: recursive,
ctes: cte, ctes: cte,
@ -25,7 +27,7 @@ func WITH(dialect Dialect, recursive bool, cte ...CommonTableExpressionDefinitio
type withImpl struct { type withImpl struct {
serializerStatementInterfaceImpl serializerStatementInterfaceImpl
recursive bool recursive bool
ctes []CommonTableExpressionDefinition ctes []*CommonTableExpression
primaryStatement SerializerStatement primaryStatement SerializerStatement
} }
@ -54,35 +56,55 @@ func (w withImpl) projections() ProjectionList {
// CommonTableExpression contains information about a CTE. // CommonTableExpression contains information about a CTE.
type CommonTableExpression struct { type CommonTableExpression struct {
selectTableImpl selectTableImpl
NotMaterialized bool
Columns []ColumnExpression
} }
// CTE creates new named CommonTableExpression // CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression { func CTE(name string, columns ...ColumnExpression) CommonTableExpression {
return CommonTableExpression{ cte := CommonTableExpression{
selectTableImpl: selectTableImpl{ selectTableImpl: NewSelectTable(nil, name),
selectStmt: nil, Columns: columns,
alias: name,
},
} }
for _, column := range cte.Columns {
column.setSubQuery(cte)
}
return cte
} }
func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if statement == WithStatementType { // serialize CTE definition
out.WriteIdentifier(c.alias)
if len(c.Columns) > 0 {
out.WriteByte('(')
SerializeColumnExpressionNames(c.Columns, out)
out.WriteByte(')')
}
out.WriteString("AS")
if c.NotMaterialized {
out.WriteString("NOT MATERIALIZED")
}
if c.Statement == nil {
panic(fmt.Sprintf("jet: '%s' CTE is not defined", c.alias))
}
c.Statement.serialize(statement, out, FallTrough(options)...)
} else { // serialize CTE in FROM clause
out.WriteIdentifier(c.alias) out.WriteIdentifier(c.alias)
} }
// AS returns sets definition for a CTE
func (c *CommonTableExpression) AS(statement SerializerStatement) CommonTableExpressionDefinition {
c.selectStmt = statement
return CommonTableExpressionDefinition{cte: c}
} }
// CommonTableExpressionDefinition contains implementation details of CTE // AllColumns returns list of all projections in the CTE
type CommonTableExpressionDefinition struct { func (c CommonTableExpression) AllColumns() ProjectionList {
cte *CommonTableExpression if len(c.Columns) > 0 {
return ColumnListToProjectionList(c.Columns)
} }
func (c CommonTableExpressionDefinition) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { return c.selectTableImpl.AllColumns()
out.WriteIdentifier(c.cte.alias)
out.WriteString("AS")
c.cte.selectStmt.serialize(statement, out, FallTrough(options)...)
} }

View file

@ -13,7 +13,7 @@ type selectTableImpl struct {
readableTableInterfaceImpl readableTableInterfaceImpl
} }
func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable { func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable {
subQuery := &selectTableImpl{ subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias), SelectTable: jet.NewSelectTable(selectStmt, alias),
} }

View file

@ -2,30 +2,65 @@ package mysql
import "github.com/go-jet/jet/v2/internal/jet" import "github.com/go-jet/jet/v2/internal/jet"
// CommonTableExpression contains information about a CTE. // CommonTableExpression defines set of interface methods for postgres CTEs
type CommonTableExpression struct { type CommonTableExpression interface {
SelectTable
AS(statement jet.SerializerStatement) CommonTableExpression
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
ALIAS(alias string) SelectTable
internalCTE() *jet.CommonTableExpression
}
type commonTableExpression struct {
readableTableInterfaceImpl readableTableInterfaceImpl
jet.CommonTableExpression jet.CommonTableExpression
} }
// WITH function creates new WITH statement from list of common table expressions // WITH function creates new WITH statement from list of common table expressions
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { func WITH(cte ...CommonTableExpression) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, false, cte...) return jet.WITH(Dialect, false, toInternalCTE(cte)...)
} }
// WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions // WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions
func WITH_RECURSIVE(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { func WITH_RECURSIVE(cte ...CommonTableExpression) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, true, cte...) return jet.WITH(Dialect, true, toInternalCTE(cte)...)
} }
// CTE creates new named CommonTableExpression // CTE creates new named commonTableExpression
func CTE(name string) CommonTableExpression { func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression {
cte := CommonTableExpression{ cte := &commonTableExpression{
readableTableInterfaceImpl: readableTableInterfaceImpl{}, readableTableInterfaceImpl: readableTableInterfaceImpl{},
CommonTableExpression: jet.CTE(name), CommonTableExpression: jet.CTE(name, columns...),
} }
cte.parent = &cte cte.parent = cte
return cte return cte
} }
// AS is used to define a CTE query
func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression {
c.CommonTableExpression.Statement = statement
return c
}
func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression {
return &c.CommonTableExpression
}
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
func (c *commonTableExpression) ALIAS(name string) SelectTable {
return newSelectTable(c, name)
}
func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression {
var ret []*jet.CommonTableExpression
for _, cte := range ctes {
ret = append(ret, cte.internalCTE())
}
return ret
}

View file

@ -52,7 +52,7 @@ func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.S
out.WriteString("ON CONFLICT") out.WriteString("ON CONFLICT")
if len(o.indexExpressions) > 0 { if len(o.indexExpressions) > 0 {
out.WriteString("(") out.WriteString("(")
jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName) jet.SerializeColumnExpressions(o.indexExpressions, statementType, out, jet.ShortName)
out.WriteString(")") out.WriteString(")")
} }

View file

@ -6,6 +6,7 @@ import (
"github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/jet"
) )
// Bool is boolean literal constructor
func Bool(value bool) BoolExpression { func Bool(value bool) BoolExpression {
return CAST(jet.Bool(value)).AS_BOOL() return CAST(jet.Bool(value)).AS_BOOL()
} }

View file

@ -2,7 +2,7 @@ package postgres
import "github.com/go-jet/jet/v2/internal/jet" import "github.com/go-jet/jet/v2/internal/jet"
// SelectTable is interface for MySQL sub-queries // SelectTable is interface for postgres sub-queries
type SelectTable interface { type SelectTable interface {
readableTable readableTable
jet.SelectTable jet.SelectTable
@ -13,7 +13,7 @@ type selectTableImpl struct {
readableTableInterfaceImpl readableTableInterfaceImpl
} }
func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable { func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable {
subQuery := &selectTableImpl{ subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias), SelectTable: jet.NewSelectTable(selectStmt, alias),
} }

View file

@ -2,30 +2,73 @@ package postgres
import "github.com/go-jet/jet/v2/internal/jet" import "github.com/go-jet/jet/v2/internal/jet"
// CommonTableExpression contains information about a CTE. // CommonTableExpression defines set of interface methods for postgres CTEs
type CommonTableExpression struct { type CommonTableExpression interface {
SelectTable
AS(statement jet.SerializerStatement) CommonTableExpression
AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
ALIAS(alias string) SelectTable
internalCTE() *jet.CommonTableExpression
}
type commonTableExpression struct {
readableTableInterfaceImpl readableTableInterfaceImpl
jet.CommonTableExpression jet.CommonTableExpression
} }
// WITH function creates new WITH statement from list of common table expressions // WITH function creates new WITH statement from list of common table expressions
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { func WITH(cte ...CommonTableExpression) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, false, cte...) return jet.WITH(Dialect, false, toInternalCTE(cte)...)
} }
// WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions // WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions
func WITH_RECURSIVE(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { func WITH_RECURSIVE(cte ...CommonTableExpression) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, true, cte...) return jet.WITH(Dialect, true, toInternalCTE(cte)...)
} }
// CTE creates new named CommonTableExpression // CTE creates new named commonTableExpression
func CTE(name string) CommonTableExpression { func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression {
cte := CommonTableExpression{ cte := &commonTableExpression{
readableTableInterfaceImpl: readableTableInterfaceImpl{}, readableTableInterfaceImpl: readableTableInterfaceImpl{},
CommonTableExpression: jet.CTE(name), CommonTableExpression: jet.CTE(name, columns...),
} }
cte.parent = &cte cte.parent = cte
return cte return cte
} }
// AS is used to define a CTE query
func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression {
c.CommonTableExpression.Statement = statement
return c
}
// AS_NOT_MATERIALIZED is used to define not materialized CTE query
func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression {
c.CommonTableExpression.NotMaterialized = true
c.CommonTableExpression.Statement = statement
return c
}
func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression {
return &c.CommonTableExpression
}
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
func (c *commonTableExpression) ALIAS(name string) SelectTable {
return newSelectTable(c, name)
}
func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression {
var ret []*jet.CommonTableExpression
for _, cte := range ctes {
ret = append(ret, cte.internalCTE())
}
return ret
}

View file

@ -45,7 +45,7 @@ func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.S
out.WriteString("ON CONFLICT") out.WriteString("ON CONFLICT")
if len(o.indexExpressions) > 0 { if len(o.indexExpressions) > 0 {
out.WriteString("(") out.WriteString("(")
jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName) jet.SerializeColumnExpressions(o.indexExpressions, statementType, out, jet.ShortName)
out.WriteString(")") out.WriteString(")")
} }

View file

@ -13,7 +13,7 @@ type selectTableImpl struct {
readableTableInterfaceImpl readableTableInterfaceImpl
} }
func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable { func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable {
subQuery := &selectTableImpl{ subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias), SelectTable: jet.NewSelectTable(selectStmt, alias),
} }

View file

@ -2,30 +2,73 @@ package sqlite
import "github.com/go-jet/jet/v2/internal/jet" import "github.com/go-jet/jet/v2/internal/jet"
// CommonTableExpression contains information about a CTE. // CommonTableExpression defines set of interface methods for postgres CTEs
type CommonTableExpression struct { type CommonTableExpression interface {
SelectTable
AS(statement jet.SerializerStatement) CommonTableExpression
AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
ALIAS(alias string) SelectTable
internalCTE() *jet.CommonTableExpression
}
type commonTableExpression struct {
readableTableInterfaceImpl readableTableInterfaceImpl
jet.CommonTableExpression jet.CommonTableExpression
} }
// WITH function creates new WITH statement from list of common table expressions // WITH function creates new WITH statement from list of common table expressions
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { func WITH(cte ...CommonTableExpression) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, false, cte...) return jet.WITH(Dialect, false, toInternalCTE(cte)...)
} }
// WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions // WITH_RECURSIVE function creates new WITH RECURSIVE statement from list of common table expressions
func WITH_RECURSIVE(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { func WITH_RECURSIVE(cte ...CommonTableExpression) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, true, cte...) return jet.WITH(Dialect, true, toInternalCTE(cte)...)
} }
// CTE creates new named CommonTableExpression // CTE creates new named commonTableExpression
func CTE(name string) CommonTableExpression { func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression {
cte := CommonTableExpression{ cte := &commonTableExpression{
readableTableInterfaceImpl: readableTableInterfaceImpl{}, readableTableInterfaceImpl: readableTableInterfaceImpl{},
CommonTableExpression: jet.CTE(name), CommonTableExpression: jet.CTE(name, columns...),
} }
cte.parent = &cte cte.parent = cte
return cte return cte
} }
// AS is used to define a CTE query
func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression {
c.CommonTableExpression.Statement = statement
return c
}
// AS_NOT_MATERIALIZED is used to define not materialized CTE query
func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression {
c.CommonTableExpression.NotMaterialized = true
c.CommonTableExpression.Statement = statement
return c
}
func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression {
return &c.CommonTableExpression
}
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
func (c *commonTableExpression) ALIAS(name string) SelectTable {
return newSelectTable(c, name)
}
func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression {
var ret []*jet.CommonTableExpression
for _, cte := range ctes {
ret = append(ret, cte.internalCTE())
}
return ret
}

View file

@ -151,9 +151,144 @@ func TestWITH_And_DELETE(t *testing.T) {
// fmt.Println(stmt.DebugSql()) // fmt.Println(stmt.DebugSql())
testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(`
WITH payments_to_delete AS (
SELECT payment.payment_id AS "payment.payment_id",
payment.customer_id AS "payment.customer_id",
payment.staff_id AS "payment.staff_id",
payment.rental_id AS "payment.rental_id",
payment.amount AS "payment.amount",
payment.payment_date AS "payment.payment_date",
payment.last_update AS "payment.last_update"
FROM dvds.payment
WHERE payment.amount < 0.5
)
DELETE FROM dvds.payment
WHERE payment.payment_id IN (
SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id"
FROM payments_to_delete
);
`, "''", "`"))
tx, err := db.Begin() tx, err := db.Begin()
require.NoError(t, err) require.NoError(t, err)
defer tx.Rollback() defer tx.Rollback()
testutils.AssertExec(t, stmt, tx, 24) testutils.AssertExec(t, stmt, tx, 24)
} }
func TestRecursiveWithStatement_Fibonacci(t *testing.T) {
// CTE columns are listed as part of CTE definition
n1 := IntegerColumn("n1")
fibN1 := IntegerColumn("fibN1")
nextFibN1 := IntegerColumn("nextFibN1")
fibonacci1 := CTE("fibonacci1", n1, fibN1, nextFibN1)
// CTE columns are columns from non-recursive select
fibonacci2 := CTE("fibonacci2")
n2 := IntegerColumn("n2").From(fibonacci2)
fibN2 := IntegerColumn("fibN2").From(fibonacci2)
nextFibN2 := IntegerColumn("nextFibN2").From(fibonacci2)
stmt := WITH_RECURSIVE(
fibonacci1.AS(
SELECT(
Int32(1), Int32(0), Int32(1),
).UNION_ALL(
SELECT(
n1.ADD(Int(1)), nextFibN1, fibN1.ADD(nextFibN1),
).FROM(
fibonacci1,
).WHERE(
n1.LT(Int(20)),
),
),
),
fibonacci2.AS(
SELECT(
Int32(1).AS(n2.Name()),
Int32(0).AS(fibN2.Name()),
Int32(1).AS(nextFibN2.Name()),
).UNION_ALL(
SELECT(
n2.ADD(Int(1)), nextFibN2, fibN2.ADD(nextFibN2),
).FROM(
fibonacci2,
).WHERE(
n2.LT(Int(20)),
),
),
),
)(
SELECT(
fibonacci1.AllColumns(),
fibonacci2.AllColumns(),
).FROM(
fibonacci1.INNER_JOIN(fibonacci2, n1.EQ(n2)),
).WHERE(
n1.EQ(Int(20)),
),
)
// fmt.Println(stmt.Sql())
testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(`
WITH RECURSIVE fibonacci1 (n1, ''fibN1'', ''nextFibN1'') AS (
(
SELECT ?,
?,
?
)
UNION ALL
(
SELECT fibonacci1.n1 + ?,
fibonacci1.''nextFibN1'' AS "nextFibN1",
fibonacci1.''fibN1'' + fibonacci1.''nextFibN1''
FROM fibonacci1
WHERE fibonacci1.n1 < ?
)
),fibonacci2 AS (
(
SELECT ? AS "n2",
? AS "fibN2",
? AS "nextFibN2"
)
UNION ALL
(
SELECT fibonacci2.n2 + ?,
fibonacci2.''nextFibN2'' AS "nextFibN2",
fibonacci2.''fibN2'' + fibonacci2.''nextFibN2''
FROM fibonacci2
WHERE fibonacci2.n2 < ?
)
)
SELECT fibonacci1.n1 AS "n1",
fibonacci1.''fibN1'' AS "fibN1",
fibonacci1.''nextFibN1'' AS "nextFibN1",
fibonacci2.n2 AS "n2",
fibonacci2.''fibN2'' AS "fibN2",
fibonacci2.''nextFibN2'' AS "nextFibN2"
FROM fibonacci1
INNER JOIN fibonacci2 ON (fibonacci1.n1 = fibonacci2.n2)
WHERE fibonacci1.n1 = ?;
`, "''", "`"))
var dest struct {
N1 int
FibN1 int
NextFibN1 int
N2 int
FibN2 int
NextFibN2 int
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest.N1, 20)
require.Equal(t, dest.FibN1, 4181)
require.Equal(t, dest.NextFibN1, 6765)
require.Equal(t, dest.N2, 20)
require.Equal(t, dest.FibN2, 4181)
require.Equal(t, dest.NextFibN2, 6765)
}

View file

@ -814,10 +814,10 @@ ORDER BY f1.film_id ASC;
type F1 model.Film type F1 model.Film
type F2 model.Film type F2 model.Film
theSameLengthFilms := []struct { var theSameLengthFilms []struct {
F1 F1 F1 F1
F2 F2 F2 F2
}{} }
err := query.Query(db, &theSameLengthFilms) err := query.Query(db, &theSameLengthFilms)
@ -858,7 +858,7 @@ LIMIT 1000;
Title2 string Title2 string
Length int16 Length int16
} }
films := []thesameLengthFilms{} var films []thesameLengthFilms
err := query.Query(db, &films) err := query.Query(db, &films)

View file

@ -218,7 +218,122 @@ FROM log_discontinued;
err = stmt.Query(tx, &resp) err = stmt.Query(tx, &resp)
require.NoError(t, err) require.NoError(t, err)
}
func TestRecursiveWithStatement_Fibonacci(t *testing.T) {
// CTE columns are listed as part of CTE definition
n1 := IntegerColumn("n1")
fibN1 := IntegerColumn("fibN1")
nextFibN1 := IntegerColumn("nextFibN1")
fibonacci1 := CTE("fibonacci1", n1, fibN1, nextFibN1)
// CTE columns are columns from non-recursive select
fibonacci2 := CTE("fibonacci2")
n2 := IntegerColumn("n2").From(fibonacci2)
fibN2 := IntegerColumn("fibN2").From(fibonacci2)
nextFibN2 := IntegerColumn("nextFibN2").From(fibonacci2)
stmt := WITH_RECURSIVE(
fibonacci1.AS(
SELECT(
Int32(1), Int32(0), Int32(1),
).UNION_ALL(
SELECT(
n1.ADD(Int(1)), nextFibN1, fibN1.ADD(nextFibN1),
).FROM(
fibonacci1,
).WHERE(
n1.LT(Int(20)),
),
),
),
fibonacci2.AS(
SELECT(
Int32(1).AS(n2.Name()),
Int32(0).AS(fibN2.Name()),
Int32(1).AS(nextFibN2.Name()),
).UNION_ALL(
SELECT(
n2.ADD(Int(1)), nextFibN2, fibN2.ADD(nextFibN2),
).FROM(
fibonacci2,
).WHERE(
n2.LT(Int(20)),
),
),
),
)(
SELECT(
fibonacci1.AllColumns(),
fibonacci2.AllColumns(),
).FROM(
fibonacci1.INNER_JOIN(fibonacci2, n1.EQ(n2)),
).WHERE(
n1.EQ(Int(20)),
),
)
//fmt.Println(stmt.Sql())
testutils.AssertStatementSql(t, stmt, `
WITH RECURSIVE fibonacci1 (n1, "fibN1", "nextFibN1") AS (
(
SELECT $1::integer,
$2::integer,
$3::integer
)
UNION ALL
(
SELECT fibonacci1.n1 + $4,
fibonacci1."nextFibN1" AS "nextFibN1",
fibonacci1."fibN1" + fibonacci1."nextFibN1"
FROM fibonacci1
WHERE fibonacci1.n1 < $5
)
),fibonacci2 AS (
(
SELECT $6::integer AS "n2",
$7::integer AS "fibN2",
$8::integer AS "nextFibN2"
)
UNION ALL
(
SELECT fibonacci2.n2 + $9,
fibonacci2."nextFibN2" AS "nextFibN2",
fibonacci2."fibN2" + fibonacci2."nextFibN2"
FROM fibonacci2
WHERE fibonacci2.n2 < $10
)
)
SELECT fibonacci1.n1 AS "n1",
fibonacci1."fibN1" AS "fibN1",
fibonacci1."nextFibN1" AS "nextFibN1",
fibonacci2.n2 AS "n2",
fibonacci2."fibN2" AS "fibN2",
fibonacci2."nextFibN2" AS "nextFibN2"
FROM fibonacci1
INNER JOIN fibonacci2 ON (fibonacci1.n1 = fibonacci2.n2)
WHERE fibonacci1.n1 = $11;
`)
var dest struct {
N1 int
FibN1 int
NextFibN1 int
N2 int
FibN2 int
NextFibN2 int
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest.N1, 20)
require.Equal(t, dest.FibN1, 4181)
require.Equal(t, dest.NextFibN1, 6765)
require.Equal(t, dest.N2, 20)
require.Equal(t, dest.FibN2, 4181)
require.Equal(t, dest.NextFibN2, 6765)
} }
// default column aliases from sub-queries are bubbled up to the main query, // default column aliases from sub-queries are bubbled up to the main query,
@ -298,13 +413,7 @@ FROM cte2;
require.Equal(t, dest[0].CustomColumn2, "custom_column_2") require.Equal(t, dest[0].CustomColumn2, "custom_column_2")
} }
type EmployeeWrap struct { func TestRecursiveWithStatement(t *testing.T) {
model.Employees
Subordinates []*EmployeeWrap
}
func TestWithRecursive(t *testing.T) {
subordinates := CTE("subordinates") subordinates := CTE("subordinates")
@ -333,6 +442,14 @@ func TestWithRecursive(t *testing.T) {
), ),
) )
//fmt.Println(stmt.DebugSql())
type EmployeeWrap struct {
model.Employees
Subordinates []*EmployeeWrap
}
type employeeID = int16 type employeeID = int16
employeeMap := make(map[employeeID]*EmployeeWrap) employeeMap := make(map[employeeID]*EmployeeWrap)
@ -352,7 +469,7 @@ func TestWithRecursive(t *testing.T) {
employeeMap[employeeModel.EmployeeID] = newEmployeeWrap employeeMap[employeeModel.EmployeeID] = newEmployeeWrap
if employeeModel.ReportsTo == nil { // top manager(always first row in the result) if result == nil { // top manager(always first row in the result)
result = newEmployeeWrap result = newEmployeeWrap
continue continue
} }
@ -559,3 +676,73 @@ func TestWithRecursive(t *testing.T) {
} }
`) `)
} }
var suppliersWithFax = CTE("suppliers_fax").AS(
SELECT(
Suppliers.SupplierID,
Suppliers.ContactName,
Suppliers.Country,
).FROM(
Suppliers,
).WHERE(Suppliers.Fax.IS_NOT_NULL()),
)
func SuppliersNotFromUSorAUS(suppliersCTE CommonTableExpression) CommonTableExpression {
return CTE("not_from_us_or_aus").AS(
SELECT(
suppliersCTE.AllColumns(),
).FROM(
suppliersCTE,
).WHERE(
Suppliers.Country.From(suppliersCTE).NOT_IN(String("US"), String("Australia")),
),
)
}
func TestCTEReuse(t *testing.T) {
suppliersFilteredByCountry := SuppliersNotFromUSorAUS(suppliersWithFax)
supplierContactName := Suppliers.ContactName.From(suppliersFilteredByCountry)
stmt := WITH(
suppliersWithFax,
suppliersFilteredByCountry,
)(
SELECT(
suppliersFilteredByCountry.AllColumns(),
).FROM(
suppliersFilteredByCountry,
).WHERE(
supplierContactName.NOT_EQ(String("John")),
),
)
// fmt.Println(stmt.DebugSql())
testutils.AssertDebugStatementSql(t, stmt, `
WITH suppliers_fax AS (
SELECT suppliers.supplier_id AS "suppliers.supplier_id",
suppliers.contact_name AS "suppliers.contact_name",
suppliers.country AS "suppliers.country"
FROM northwind.suppliers
WHERE suppliers.fax IS NOT NULL
),not_from_us_or_aus AS (
SELECT suppliers_fax."suppliers.supplier_id" AS "suppliers.supplier_id",
suppliers_fax."suppliers.contact_name" AS "suppliers.contact_name",
suppliers_fax."suppliers.country" AS "suppliers.country"
FROM suppliers_fax
WHERE suppliers_fax."suppliers.country" NOT IN ('US', 'Australia')
)
SELECT not_from_us_or_aus."suppliers.supplier_id" AS "suppliers.supplier_id",
not_from_us_or_aus."suppliers.contact_name" AS "suppliers.contact_name",
not_from_us_or_aus."suppliers.country" AS "suppliers.country"
FROM not_from_us_or_aus
WHERE not_from_us_or_aus."suppliers.contact_name" != 'John';
`)
var dest []model.Suppliers
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Len(t, dest, 11)
}

View file

@ -232,3 +232,117 @@ FROM payment;
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
} }
func TestRecursiveWithStatement_Fibonacci(t *testing.T) {
// CTE columns are listed as part of CTE definition
n1 := IntegerColumn("n1")
fibN1 := IntegerColumn("fibN1")
nextFibN1 := IntegerColumn("nextFibN1")
fibonacci1 := CTE("fibonacci1", n1, fibN1, nextFibN1)
// CTE columns are columns from non-recursive select
fibonacci2 := CTE("fibonacci2")
n2 := IntegerColumn("n2").From(fibonacci2)
fibN2 := IntegerColumn("fibN2").From(fibonacci2)
nextFibN2 := IntegerColumn("nextFibN2").From(fibonacci2)
stmt := WITH_RECURSIVE(
fibonacci1.AS(
SELECT(
Int32(1), Int32(0), Int32(1),
).UNION_ALL(
SELECT(
n1.ADD(Int(1)), nextFibN1, fibN1.ADD(nextFibN1),
).FROM(
fibonacci1,
).WHERE(
n1.LT(Int(20)),
),
),
),
fibonacci2.AS(
SELECT(
Int32(1).AS(n2.Name()),
Int32(0).AS(fibN2.Name()),
Int32(1).AS(nextFibN2.Name()),
).UNION_ALL(
SELECT(
n2.ADD(Int(1)), nextFibN2, fibN2.ADD(nextFibN2),
).FROM(
fibonacci2,
).WHERE(
n2.LT(Int(20)),
),
),
),
)(
SELECT(
fibonacci1.AllColumns(),
fibonacci2.AllColumns(),
).FROM(
fibonacci1.INNER_JOIN(fibonacci2, n1.EQ(n2)),
).WHERE(
n1.EQ(Int(20)),
),
)
//fmt.Println(stmt.Sql())
testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(`
WITH RECURSIVE fibonacci1 (n1, ''fibN1'', ''nextFibN1'') AS (
SELECT ?,
?,
?
UNION ALL
SELECT fibonacci1.n1 + ?,
fibonacci1.''nextFibN1'' AS "nextFibN1",
fibonacci1.''fibN1'' + fibonacci1.''nextFibN1''
FROM fibonacci1
WHERE fibonacci1.n1 < ?
),fibonacci2 AS (
SELECT ? AS "n2",
? AS "fibN2",
? AS "nextFibN2"
UNION ALL
SELECT fibonacci2.n2 + ?,
fibonacci2.''nextFibN2'' AS "nextFibN2",
fibonacci2.''fibN2'' + fibonacci2.''nextFibN2''
FROM fibonacci2
WHERE fibonacci2.n2 < ?
)
SELECT fibonacci1.n1 AS "n1",
fibonacci1.''fibN1'' AS "fibN1",
fibonacci1.''nextFibN1'' AS "nextFibN1",
fibonacci2.n2 AS "n2",
fibonacci2.''fibN2'' AS "fibN2",
fibonacci2.''nextFibN2'' AS "nextFibN2"
FROM fibonacci1
INNER JOIN fibonacci2 ON (fibonacci1.n1 = fibonacci2.n2)
WHERE fibonacci1.n1 = ?;
`, "''", "`"))
var dest struct {
N1 int
FibN1 int
NextFibN1 int
N2 int
FibN2 int
NextFibN2 int
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest.N1, 20)
require.Equal(t, dest.FibN1, 4181)
require.Equal(t, dest.NextFibN1, 6765)
require.Equal(t, dest.N2, 20)
require.Equal(t, dest.FibN2, 4181)
require.Equal(t, dest.NextFibN2, 6765)
}