Set statement refactor.

This commit is contained in:
go-jet 2019-07-01 19:41:49 +02:00
parent 461911889a
commit ab0f790bc3
5 changed files with 239 additions and 194 deletions

View file

@ -1,23 +1,35 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
)
type SetStatement interface {
Statement
Expression
func UNION(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...))
}
ORDER_BY(clauses ...OrderByClause) SetStatement
LIMIT(limit int64) SetStatement
OFFSET(offset int64) SetStatement
func UNION_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...))
}
AsTable(alias string) ExpressionTable
func INTERSECT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(intersect, false, toSelectList(lhs, rhs, selects...))
}
projections() []projection
func INTERSECT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(intersect, true, toSelectList(lhs, rhs, selects...))
}
func EXCEPT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(except, false, toSelectList(lhs, rhs, selects...))
}
func EXCEPT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(except, true, toSelectList(lhs, rhs, selects...))
}
func toSelectList(lhs, rhs SelectStatement, selects ...SelectStatement) []SelectStatement {
return append([]SelectStatement{lhs, rhs}, selects...)
}
const (
@ -26,71 +38,30 @@ const (
except = "EXCEPT"
)
func UNION(selects ...rowsType) SetStatement {
return newSetStatementImpl(union, false, selects...)
}
func UNION_ALL(selects ...rowsType) SetStatement {
return newSetStatementImpl(union, true, selects...)
}
func INTERSECT(selects ...rowsType) SetStatement {
return newSetStatementImpl(intersect, false, selects...)
}
func INTERSECT_ALL(selects ...rowsType) SetStatement {
return newSetStatementImpl(intersect, true, selects...)
}
func EXCEPT(selects ...rowsType) SetStatement {
return newSetStatementImpl(except, false, selects...)
}
func EXCEPT_ALL(selects ...rowsType) SetStatement {
return newSetStatementImpl(except, true, selects...)
}
// Similar to selectStatementImpl, but less complete
type setStatementImpl struct {
expressionInterfaceImpl
selectStatementImpl
operator string
selects []rowsType
orderBy []OrderByClause
limit, offset int64
all bool
operator string
all bool
selects []SelectStatement
}
func newSetStatementImpl(operator string, all bool, selects ...rowsType) SetStatement {
func newSetStatementImpl(operator string, all bool, selects []SelectStatement) SelectStatement {
setStatement := &setStatementImpl{
operator: operator,
selects: selects,
limit: -1,
offset: -1,
all: all,
selects: selects,
}
setStatement.expressionInterfaceImpl.parent = setStatement
setStatement.selectStatementImpl.expressionInterfaceImpl.parent = setStatement
setStatement.selectStatementImpl.parent = setStatement
setStatement.limit = -1
setStatement.offset = -1
return setStatement
}
func (s *setStatementImpl) ORDER_BY(orderBy ...OrderByClause) SetStatement {
s.orderBy = orderBy
return s
}
func (s *setStatementImpl) LIMIT(limit int64) SetStatement {
s.limit = limit
return s
}
func (s *setStatementImpl) OFFSET(offset int64) SetStatement {
s.offset = offset
return s
}
func (s *setStatementImpl) projections() []projection {
if len(s.selects) > 0 {
return s.selects[0].projections()
@ -98,10 +69,6 @@ func (s *setStatementImpl) projections() []projection {
return []projection{}
}
func (s *setStatementImpl) AsTable(alias string) ExpressionTable {
return newExpressionTable(s.parent, alias, s.projections())
}
func (s *setStatementImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error {
if s == nil {
return errors.New("Set expression is nil. ")
@ -153,6 +120,10 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error {
out.newLine()
}
if selectStmt == nil {
return errors.New("select statement is nil")
}
err := selectStmt.serialize(set_statement, out)
if err != nil {
@ -198,23 +169,3 @@ func (s *setStatementImpl) Sql() (query string, args []interface{}, err error) {
query, args = queryData.finalize()
return
}
func (s *setStatementImpl) DebugSql() (query string, err error) {
return debugSql(s)
}
func (s *setStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(s, db, destination)
}
func (s *setStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error {
return queryContext(s, db, context, destination)
}
func (s *setStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(s, db)
}
func (s *setStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) {
return execContext(s, db, context)
}