Set statement refactor.

This commit is contained in:
go-jet 2019-07-01 19:41:49 +02:00
parent 461911889a
commit ab0f790bc3
5 changed files with 239 additions and 194 deletions

View file

@ -1,6 +0,0 @@
package jet
type rowsType interface {
clause
projections() []projection
}

View file

@ -28,6 +28,13 @@ type SelectStatement interface {
OFFSET(offset int64) SelectStatement
FOR(lock SelectLock) SelectStatement
UNION(rhs SelectStatement) SelectStatement
UNION_ALL(rhs SelectStatement) SelectStatement
INTERSECT(rhs SelectStatement) SelectStatement
INTERSECT_ALL(rhs SelectStatement) SelectStatement
EXCEPT(rhs SelectStatement) SelectStatement
EXCEPT_ALL(rhs SelectStatement) SelectStatement
AsTable(alias string) ExpressionTable
projections() []projection
@ -39,6 +46,7 @@ func SELECT(projection1 projection, projections ...projection) SelectStatement {
type selectStatementImpl struct {
expressionInterfaceImpl
parent SelectStatement
table ReadableTable
distinct bool
@ -46,8 +54,8 @@ type selectStatementImpl struct {
where BoolExpression
groupBy []groupByClause
having BoolExpression
orderBy []OrderByClause
orderBy []OrderByClause
limit, offset int64
lockFor SelectLock
@ -63,13 +71,86 @@ func newSelectStatement(table ReadableTable, projections []projection) SelectSta
}
newSelect.expressionInterfaceImpl.parent = newSelect
newSelect.parent = newSelect
return newSelect
}
func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement {
s.table = table
return s
return s.parent
}
func (s *selectStatementImpl) AsTable(alias string) ExpressionTable {
return newExpressionTable(s.parent, alias, s.parent.projections())
}
func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement {
s.where = expression
return s.parent
}
func (s *selectStatementImpl) GROUP_BY(groupByClauses ...groupByClause) SelectStatement {
s.groupBy = groupByClauses
return s.parent
}
func (s *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement {
s.having = expression
return s.parent
}
func (s *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement {
s.orderBy = clauses
return s.parent
}
func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement {
s.offset = offset
return s.parent
}
func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement {
s.limit = limit
return s.parent
}
func (s *selectStatementImpl) DISTINCT() SelectStatement {
s.distinct = true
return s.parent
}
func (s *selectStatementImpl) FOR(lock SelectLock) SelectStatement {
s.lockFor = lock
return s.parent
}
func (s *selectStatementImpl) UNION(rhs SelectStatement) SelectStatement {
return UNION(s.parent, rhs)
}
func (s *selectStatementImpl) UNION_ALL(rhs SelectStatement) SelectStatement {
return UNION_ALL(s.parent, rhs)
}
func (s *selectStatementImpl) INTERSECT(rhs SelectStatement) SelectStatement {
return INTERSECT(s.parent, rhs)
}
func (s *selectStatementImpl) INTERSECT_ALL(rhs SelectStatement) SelectStatement {
return INTERSECT_ALL(s.parent, rhs)
}
func (s *selectStatementImpl) EXCEPT(rhs SelectStatement) SelectStatement {
return EXCEPT(s.parent, rhs)
}
func (s *selectStatementImpl) EXCEPT_ALL(rhs SelectStatement) SelectStatement {
return EXCEPT_ALL(s.parent, rhs)
}
func (s *selectStatementImpl) projections() []projection {
return s.projectionList
}
func (s *selectStatementImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error {
@ -192,56 +273,26 @@ func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error
}
func (s *selectStatementImpl) DebugSql() (query string, err error) {
return debugSql(s)
return debugSql(s.parent)
}
func (s *selectStatementImpl) projections() []projection {
return s.projectionList
func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(s.parent, db, destination)
}
func (s *selectStatementImpl) AsTable(alias string) ExpressionTable {
return newExpressionTable(s.parent, alias, s.projectionList)
func (s *selectStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error {
return queryContext(s.parent, db, context, destination)
}
func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement {
s.where = expression
return s
func (s *selectStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(s.parent, db)
}
func (s *selectStatementImpl) GROUP_BY(groupByClauses ...groupByClause) SelectStatement {
s.groupBy = groupByClauses
return s
func (s *selectStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) {
return execContext(s.parent, db, context)
}
func (s *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement {
s.having = expression
return s
}
func (s *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement {
s.orderBy = clauses
return s
}
func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement {
s.offset = offset
return s
}
func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement {
s.limit = limit
return s
}
func (s *selectStatementImpl) DISTINCT() SelectStatement {
s.distinct = true
return s
}
func (s *selectStatementImpl) FOR(lock SelectLock) SelectStatement {
s.lockFor = lock
return s
}
// SelectLock
type SelectLock interface {
clause
@ -288,19 +339,3 @@ func (s *selectLockImpl) serialize(statement statementType, out *queryData, opti
return nil
}
func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(s, db, destination)
}
func (s *selectStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error {
return queryContext(s, db, context, destination)
}
func (s *selectStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(s, db)
}
func (s *selectStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) {
return execContext(s, db, context)
}

View file

@ -1,23 +1,35 @@
package jet
import (
"context"
"database/sql"
"errors"
"github.com/go-jet/jet/execution"
)
type SetStatement interface {
Statement
Expression
func UNION(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...))
}
ORDER_BY(clauses ...OrderByClause) SetStatement
LIMIT(limit int64) SetStatement
OFFSET(offset int64) SetStatement
func UNION_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...))
}
AsTable(alias string) ExpressionTable
func INTERSECT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(intersect, false, toSelectList(lhs, rhs, selects...))
}
projections() []projection
func INTERSECT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(intersect, true, toSelectList(lhs, rhs, selects...))
}
func EXCEPT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(except, false, toSelectList(lhs, rhs, selects...))
}
func EXCEPT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement {
return newSetStatementImpl(except, true, toSelectList(lhs, rhs, selects...))
}
func toSelectList(lhs, rhs SelectStatement, selects ...SelectStatement) []SelectStatement {
return append([]SelectStatement{lhs, rhs}, selects...)
}
const (
@ -26,71 +38,30 @@ const (
except = "EXCEPT"
)
func UNION(selects ...rowsType) SetStatement {
return newSetStatementImpl(union, false, selects...)
}
func UNION_ALL(selects ...rowsType) SetStatement {
return newSetStatementImpl(union, true, selects...)
}
func INTERSECT(selects ...rowsType) SetStatement {
return newSetStatementImpl(intersect, false, selects...)
}
func INTERSECT_ALL(selects ...rowsType) SetStatement {
return newSetStatementImpl(intersect, true, selects...)
}
func EXCEPT(selects ...rowsType) SetStatement {
return newSetStatementImpl(except, false, selects...)
}
func EXCEPT_ALL(selects ...rowsType) SetStatement {
return newSetStatementImpl(except, true, selects...)
}
// Similar to selectStatementImpl, but less complete
type setStatementImpl struct {
expressionInterfaceImpl
selectStatementImpl
operator string
selects []rowsType
orderBy []OrderByClause
limit, offset int64
all bool
operator string
all bool
selects []SelectStatement
}
func newSetStatementImpl(operator string, all bool, selects ...rowsType) SetStatement {
func newSetStatementImpl(operator string, all bool, selects []SelectStatement) SelectStatement {
setStatement := &setStatementImpl{
operator: operator,
selects: selects,
limit: -1,
offset: -1,
all: all,
selects: selects,
}
setStatement.expressionInterfaceImpl.parent = setStatement
setStatement.selectStatementImpl.expressionInterfaceImpl.parent = setStatement
setStatement.selectStatementImpl.parent = setStatement
setStatement.limit = -1
setStatement.offset = -1
return setStatement
}
func (s *setStatementImpl) ORDER_BY(orderBy ...OrderByClause) SetStatement {
s.orderBy = orderBy
return s
}
func (s *setStatementImpl) LIMIT(limit int64) SetStatement {
s.limit = limit
return s
}
func (s *setStatementImpl) OFFSET(offset int64) SetStatement {
s.offset = offset
return s
}
func (s *setStatementImpl) projections() []projection {
if len(s.selects) > 0 {
return s.selects[0].projections()
@ -98,10 +69,6 @@ func (s *setStatementImpl) projections() []projection {
return []projection{}
}
func (s *setStatementImpl) AsTable(alias string) ExpressionTable {
return newExpressionTable(s.parent, alias, s.projections())
}
func (s *setStatementImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error {
if s == nil {
return errors.New("Set expression is nil. ")
@ -153,6 +120,10 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error {
out.newLine()
}
if selectStmt == nil {
return errors.New("select statement is nil")
}
err := selectStmt.serialize(set_statement, out)
if err != nil {
@ -198,23 +169,3 @@ func (s *setStatementImpl) Sql() (query string, args []interface{}, err error) {
query, args = queryData.finalize()
return
}
func (s *setStatementImpl) DebugSql() (query string, err error) {
return debugSql(s)
}
func (s *setStatementImpl) Query(db execution.DB, destination interface{}) error {
return query(s, db, destination)
}
func (s *setStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error {
return queryContext(s, db, context, destination)
}
func (s *setStatementImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(s, db)
}
func (s *setStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) {
return execContext(s, db, context)
}

View file

@ -1,34 +1,13 @@
package jet
import (
"fmt"
"gotest.tools/assert"
"testing"
)
func TestUnionNoSelect(t *testing.T) {
_, _, err := UNION().Sql()
assert.Assert(t, err != nil)
//fmt.Println(err.Error())
//fmt.Print(query, args)
}
func TestUnionOneSelect(t *testing.T) {
_, _, err := UNION(
table1.SELECT(table1Col1),
).Sql()
assert.Assert(t, err != nil)
}
func TestUnionTwoSelect(t *testing.T) {
query, args, err := UNION(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `
var expectedSql = `
(
(
SELECT table1.col1 AS "table1.col1"
@ -40,19 +19,71 @@ func TestUnionTwoSelect(t *testing.T) {
FROM db.table2
)
);
`)
assert.Equal(t, len(args), 0)
`
unionStmt1 := table1.
SELECT(table1Col1).
UNION(
table2.SELECT(table2Col3),
)
unionStmt2 := UNION(table1.SELECT(table1Col1), table2.SELECT(table2Col3))
assertStatement(t, unionStmt1, expectedSql)
assertStatement(t, unionStmt2, expectedSql)
}
func TestUnionThreeSelect(t *testing.T) {
query, args, err := UNION(
func TestUnionNilSelect(t *testing.T) {
unionStmt := table1.
SELECT(table1Col1).
UNION(nil)
assertStatementErr(t, unionStmt, "select statement is nil")
}
func TestUnionThreeSelect1(t *testing.T) {
unionStmt1 := table1.SELECT(table1Col1).
UNION(
table2.SELECT(table2Col3),
).
UNION(
table3.SELECT(table3Col1),
)
var expectedSql = `
(
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
)
UNION
(
SELECT table3.col1 AS "table3.col1"
FROM db.table3
)
);
`
assertStatement(t, unionStmt1, expectedSql)
}
func TestUnionThreeSelect2(t *testing.T) {
unionStmt2 := UNION(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
table3.SELECT(table3Col1),
).Sql()
)
assert.NilError(t, err)
assert.Equal(t, query, `
var expectedSql = `
(
(
SELECT table1.col1 AS "table1.col1"
@ -69,18 +100,19 @@ func TestUnionThreeSelect(t *testing.T) {
FROM db.table3
)
);
`)
assert.Equal(t, len(args), 0)
`
assertStatement(t, unionStmt2, expectedSql)
}
func TestUnionWithOrderBy(t *testing.T) {
query, args, err := UNION(
unionStmt := UNION(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).ORDER_BY(table1Col1.ASC()).Sql()
).
ORDER_BY(table1Col1.ASC())
assert.NilError(t, err)
assert.Equal(t, query, `
assertStatement(t, unionStmt, `
(
(
SELECT table1.col1 AS "table1.col1"
@ -94,14 +126,15 @@ func TestUnionWithOrderBy(t *testing.T) {
)
ORDER BY "table1.col1" ASC;
`)
assert.Equal(t, len(args), 0)
}
func TestUnionWithLimit(t *testing.T) {
func TestUnionWithLimitAndOffset(t *testing.T) {
query, args, err := UNION(
table1.SELECT(table1Col1),
table2.SELECT(table2Col3),
).LIMIT(10).OFFSET(11).Sql()
).
LIMIT(10).
OFFSET(11).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `
@ -150,11 +183,8 @@ func TestUnionInUnion(t *testing.T) {
UNION_ALL(table1.SELECT(table1Col1), table2.SELECT(table2Col3)),
)
queryStr, args, err := query.Sql()
assert.NilError(t, err)
assert.Equal(t, len(args), 0)
assert.Equal(t, queryStr, expectedSql)
fmt.Println(query.Sql())
assertStatement(t, query, expectedSql)
}
func TestUnionALL(t *testing.T) {

View file

@ -963,6 +963,41 @@ OFFSET 20;
})
}
func TestAllSetOperators(t *testing.T) {
select1 := Payment.SELECT(Payment.AllColumns).WHERE(Payment.PaymentID.GT_EQ(Int(17600)).AND(Payment.PaymentID.LT(Int(17610))))
select2 := Payment.SELECT(Payment.AllColumns).WHERE(Payment.PaymentID.GT_EQ(Int(17620)).AND(Payment.PaymentID.LT(Int(17630))))
type setOperator func(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement
operators := []setOperator{
UNION,
UNION_ALL,
INTERSECT,
INTERSECT_ALL,
EXCEPT,
EXCEPT_ALL,
}
expectedDestLen := []int{
20,
20,
0,
0,
10,
10,
}
for i, operator := range operators {
query := operator(select1, select2)
dest := []model.Payment{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), expectedDestLen[i])
}
}
func TestSelectWithCase(t *testing.T) {
expectedQuery := `
SELECT (CASE payment.staff_id WHEN 1 THEN 'ONE' WHEN 2 THEN 'TWO' WHEN 3 THEN 'THREE' ELSE 'OTHER' END) AS "staff_id_num"