Add WITH RECURSIVE statement support

This commit is contained in:
go-jet 2021-12-29 19:07:59 +01:00
parent 001d64f1dc
commit 038a32b032
17 changed files with 695 additions and 91 deletions

View file

@ -2,38 +2,41 @@ package jet
// SelectTable is interface for SELECT sub-queries
type SelectTable interface {
Serializer
SerializerHasProjections
Alias() string
AllColumns() ProjectionList
}
type selectTableImpl struct {
selectStmt SerializerStatement
alias string
Statement SerializerHasProjections
alias string
}
// NewSelectTable func
func NewSelectTable(selectStmt SerializerStatement, alias string) selectTableImpl {
selectTable := selectTableImpl{selectStmt: selectStmt, alias: alias}
func NewSelectTable(selectStmt SerializerHasProjections, alias string) selectTableImpl {
selectTable := selectTableImpl{
Statement: selectStmt,
alias: alias,
}
return selectTable
}
func (s selectTableImpl) projections() ProjectionList {
return s.Statement.projections()
}
func (s selectTableImpl) Alias() string {
return s.alias
}
func (s selectTableImpl) AllColumns() ProjectionList {
statementWithProjections, ok := s.selectStmt.(HasProjections)
if !ok {
return ProjectionList{}
}
projectionList := statementWithProjections.projections().fromImpl(s)
projectionList := s.projections().fromImpl(s)
return projectionList.(ProjectionList)
}
func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
s.selectStmt.serialize(statement, out)
s.Statement.serialize(statement, out)
out.WriteString("AS")
out.WriteIdentifier(s.alias)
@ -52,7 +55,7 @@ func NewLateral(selectStmt SerializerStatement, alias string) SelectTable {
func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("LATERAL")
s.selectStmt.serialize(statement, out)
s.Statement.serialize(statement, out)
out.WriteString("AS")
out.WriteIdentifier(s.alias)

View file

@ -51,6 +51,12 @@ type HasProjections interface {
projections() ProjectionList
}
// SerializerHasProjections interface is combination of Serializer and HasProjections interface
type SerializerHasProjections interface {
Serializer
HasProjections
}
// serializerStatementInterfaceImpl struct
type serializerStatementInterfaceImpl struct {
dialect Dialect
@ -200,7 +206,7 @@ func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, opti
}
for _, clause := range s.Clauses {
clause.Serialize(statement, out, FallTrough(options)...)
clause.Serialize(s.statementType, out, FallTrough(options)...)
}
if contains(options, Ident) {

View file

@ -68,8 +68,8 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) {
}
}
// SerializeColumnExpressionNames func
func SerializeColumnExpressionNames(columns []ColumnExpression, statementType StatementType,
// SerializeColumnExpressions func
func SerializeColumnExpressions(columns []ColumnExpression, statementType StatementType,
out *SQLBuilder, options ...SerializeOption) {
for i, col := range columns {
if i > 0 {
@ -84,6 +84,21 @@ func SerializeColumnExpressionNames(columns []ColumnExpression, statementType St
}
}
// SerializeColumnExpressionNames func
func SerializeColumnExpressionNames(columns []ColumnExpression, out *SQLBuilder) {
for i, col := range columns {
if i > 0 {
out.WriteString(", ")
}
if col == nil {
panic("jet: nil column in columns list")
}
out.WriteIdentifier(col.Name())
}
}
// ExpressionListToSerializerList converts list of expressions to list of serializers
func ExpressionListToSerializerList(expressions []Expression) []Serializer {
var ret []Serializer

View file

@ -1,7 +1,9 @@
package jet
import "fmt"
// WITH function creates new with statement from list of common table expressions for specified dialect
func WITH(dialect Dialect, recursive bool, cte ...CommonTableExpressionDefinition) func(statement Statement) Statement {
func WITH(dialect Dialect, recursive bool, cte ...*CommonTableExpression) func(statement Statement) Statement {
newWithImpl := &withImpl{
recursive: recursive,
ctes: cte,
@ -25,7 +27,7 @@ func WITH(dialect Dialect, recursive bool, cte ...CommonTableExpressionDefinitio
type withImpl struct {
serializerStatementInterfaceImpl
recursive bool
ctes []CommonTableExpressionDefinition
ctes []*CommonTableExpression
primaryStatement SerializerStatement
}
@ -54,35 +56,55 @@ func (w withImpl) projections() ProjectionList {
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
selectTableImpl
NotMaterialized bool
Columns []ColumnExpression
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
return CommonTableExpression{
selectTableImpl: selectTableImpl{
selectStmt: nil,
alias: name,
},
func CTE(name string, columns ...ColumnExpression) CommonTableExpression {
cte := CommonTableExpression{
selectTableImpl: NewSelectTable(nil, name),
Columns: columns,
}
for _, column := range cte.Columns {
column.setSubQuery(cte)
}
return cte
}
func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteIdentifier(c.alias)
if statement == WithStatementType { // serialize CTE definition
out.WriteIdentifier(c.alias)
if len(c.Columns) > 0 {
out.WriteByte('(')
SerializeColumnExpressionNames(c.Columns, out)
out.WriteByte(')')
}
out.WriteString("AS")
if c.NotMaterialized {
out.WriteString("NOT MATERIALIZED")
}
if c.Statement == nil {
panic(fmt.Sprintf("jet: '%s' CTE is not defined", c.alias))
}
c.Statement.serialize(statement, out, FallTrough(options)...)
} else { // serialize CTE in FROM clause
out.WriteIdentifier(c.alias)
}
}
// AS returns sets definition for a CTE
func (c *CommonTableExpression) AS(statement SerializerStatement) CommonTableExpressionDefinition {
c.selectStmt = statement
return CommonTableExpressionDefinition{cte: c}
}
// AllColumns returns list of all projections in the CTE
func (c CommonTableExpression) AllColumns() ProjectionList {
if len(c.Columns) > 0 {
return ColumnListToProjectionList(c.Columns)
}
// CommonTableExpressionDefinition contains implementation details of CTE
type CommonTableExpressionDefinition struct {
cte *CommonTableExpression
}
func (c CommonTableExpressionDefinition) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteIdentifier(c.cte.alias)
out.WriteString("AS")
c.cte.selectStmt.serialize(statement, out, FallTrough(options)...)
return c.selectTableImpl.AllColumns()
}