Add support for WITH statements and Common Table Expressions.

This commit is contained in:
go-jet 2020-05-24 17:55:28 +02:00
parent 0d3ec872d6
commit fb8607da29
13 changed files with 406 additions and 39 deletions

View file

@ -13,17 +13,18 @@ type Clause interface {
type ClauseWithProjections interface {
Clause
projections() ProjectionList
Projections() ProjectionList
}
// ClauseSelect struct
type ClauseSelect struct {
Distinct bool
Projections []Projection
Distinct bool
ProjectionList []Projection
}
func (s *ClauseSelect) projections() ProjectionList {
return s.Projections
// Projections returns list of projections for select clause
func (s *ClauseSelect) Projections() ProjectionList {
return s.ProjectionList
}
// Serialize serializes clause into SQLBuilder
@ -35,11 +36,11 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o
out.WriteString("DISTINCT")
}
if len(s.Projections) == 0 {
if len(s.ProjectionList) == 0 {
panic("jet: SELECT clause has to have at least one projection")
}
out.WriteProjections(statementType, s.Projections)
out.WriteProjections(statementType, s.ProjectionList)
}
// ClauseFrom struct
@ -212,13 +213,14 @@ func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder, opti
type ClauseSetStmtOperator struct {
Operator string
All bool
Selects []StatementWithProjections
Selects []SerializerStatement
OrderBy ClauseOrderBy
Limit ClauseLimit
Offset ClauseOffset
}
func (s *ClauseSetStmtOperator) projections() ProjectionList {
// Projections returns set of projections for ClauseSetStmtOperator
func (s *ClauseSetStmtOperator) Projections() ProjectionList {
if len(s.Selects) > 0 {
return s.Selects[0].projections()
}

View file

@ -105,7 +105,7 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder
if c.subQuery != nil {
out.WriteIdentifier(c.subQuery.Alias())
out.WriteByte('.')
out.WriteIdentifier(c.defaultAlias(), true)
out.WriteIdentifier(c.defaultAlias())
} else {
if c.tableName != "" && !contains(options, ShortName) {
out.WriteIdentifier(c.tableName)

View file

@ -145,6 +145,11 @@ func MINi(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerWindowFunc("MIN", integerExpression)
}
// SUM is aggregate function. Returns sum of all expressions
func SUM(expression Expression) Expression {
return newWindowFunc("SUM", expression)
}
// SUMf is aggregate function. Returns sum of expression across all float expressions
func SUMf(floatExpression FloatExpression) floatWindowExpression {
return NewFloatWindowFunc("SUM", floatExpression)

View file

@ -8,35 +8,31 @@ type SelectTable interface {
}
type selectTableImpl struct {
selectStmt StatementWithProjections
selectStmt SerializerStatement
alias string
projections ProjectionList
}
// NewSelectTable func
func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTable {
selectTable := selectTableImpl{selectStmt: selectStmt, alias: alias}
projectionList := selectStmt.projections().fromImpl(&selectTable)
selectTable.projections = projectionList.(ProjectionList)
return &selectTable
func NewSelectTable(selectStmt SerializerStatement, alias string) SelectTable {
selectTable := &selectTableImpl{selectStmt: selectStmt, alias: alias}
return selectTable
}
func (s *selectTableImpl) Alias() string {
func (s selectTableImpl) Alias() string {
return s.alias
}
func (s *selectTableImpl) AllColumns() ProjectionList {
return s.projections
}
func (s *selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if s == nil {
panic("jet: expression table is nil. ")
func (s selectTableImpl) AllColumns() ProjectionList {
statementWithProjections, ok := s.selectStmt.(HasProjections)
if !ok {
return ProjectionList{}
}
projectionList := statementWithProjections.projections().fromImpl(s)
return projectionList.(ProjectionList)
}
func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
s.selectStmt.serialize(statement, out)
out.WriteString("AS")

View file

@ -29,6 +29,7 @@ const (
SetStatementType StatementType = "SET"
LockStatementType StatementType = "LOCK"
UnLockStatementType StatementType = "UNLOCK"
WithStatementType StatementType = "WITH"
)
// Serializer interface

View file

@ -201,6 +201,13 @@ func integerTypesToString(value interface{}) string {
}
func shouldQuoteIdentifier(identifier string) bool {
_, err := strconv.ParseInt(identifier, 10, 64)
if err == nil { // if it is a number we should quote it
return true
}
// check if contains non ascii characters
for _, c := range identifier {
if unicode.IsNumber(c) || c == '_' {
continue

View file

@ -47,3 +47,13 @@ func TestFallTrough(t *testing.T) {
require.Equal(t, FallTrough([]SerializeOption{SkipNewLine}), []SerializeOption(nil))
require.Equal(t, FallTrough([]SerializeOption{ShortName, SkipNewLine}), []SerializeOption{ShortName})
}
func TestShouldQuote(t *testing.T) {
require.Equal(t, shouldQuoteIdentifier("123"), true)
require.Equal(t, shouldQuoteIdentifier("123.235"), true)
require.Equal(t, shouldQuoteIdentifier("abc123"), false)
require.Equal(t, shouldQuoteIdentifier("abc.123"), true)
require.Equal(t, shouldQuoteIdentifier("abc_123"), false)
require.Equal(t, shouldQuoteIdentifier("Abc_123"), true)
require.Equal(t, shouldQuoteIdentifier("DŽƜĐǶ"), true)
}

View file

@ -32,13 +32,7 @@ type Statement interface {
type SerializerStatement interface {
Serializer
Statement
}
// StatementWithProjections interface
type StatementWithProjections interface {
Statement
HasProjections
Serializer
}
// HasProjections interface
@ -163,7 +157,7 @@ type statementImpl struct {
func (s *statementImpl) projections() ProjectionList {
for _, clause := range s.Clauses {
if selectClause, ok := clause.(ClauseWithProjections); ok {
return selectClause.projections()
return selectClause.Projections()
}
}
@ -171,7 +165,6 @@ func (s *statementImpl) projections() ProjectionList {
}
func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, NoWrap) {
out.WriteString("(")
out.IncreaseIdent()

View file

@ -0,0 +1,78 @@
package jet
// WITH function creates new with statement from list of common table expressions for specified dialect
func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statement SerializerStatement) Statement {
newWithImpl := &withImpl{
ctes: cte,
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
dialect: dialect,
statementType: WithStatementType,
},
}
newWithImpl.parent = newWithImpl
return func(primaryStatement SerializerStatement) Statement {
newWithImpl.primaryStatement = primaryStatement
return newWithImpl
}
}
type withImpl struct {
serializerStatementInterfaceImpl
ctes []CommonTableExpressionDefinition
primaryStatement SerializerStatement
}
func (w withImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.NewLine()
out.WriteString("WITH")
for i, cte := range w.ctes {
if i > 0 {
out.WriteString(",")
}
cte.serialize(statement, out, FallTrough(options)...)
}
w.primaryStatement.serialize(statement, out, NoWrap.WithFallTrough(options)...)
}
func (w withImpl) projections() ProjectionList {
return ProjectionList{}
}
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
selectTableImpl
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
return CommonTableExpression{
selectTableImpl: selectTableImpl{
selectStmt: nil,
alias: name,
},
}
}
func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteIdentifier(c.alias)
}
// AS returns sets definition for a CTE
func (c *CommonTableExpression) AS(statement SerializerStatement) CommonTableExpressionDefinition {
c.selectStmt = statement
return CommonTableExpressionDefinition{cte: c}
}
// CommonTableExpressionDefinition contains implementation details of CTE
type CommonTableExpressionDefinition struct {
cte *CommonTableExpression
}
func (c CommonTableExpressionDefinition) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteIdentifier(c.cte.alias)
out.WriteString("AS")
c.cte.selectStmt.serialize(statement, out, FallTrough(options)...)
}