Extend set operator to accept sets as parameters.

This commit is contained in:
zer0sub 2019-05-13 11:48:58 +02:00
parent 240ddd65e6
commit 1fd63b8783
4 changed files with 57 additions and 8 deletions

View file

@ -9,6 +9,7 @@ import (
type selectStatement interface { type selectStatement interface {
Statement Statement
expression expression
hasRows()
DISTINCT() selectStatement DISTINCT() selectStatement
FROM(table readableTable) selectStatement FROM(table readableTable) selectStatement
@ -33,6 +34,7 @@ func SELECT(projection ...projection) selectStatement {
// mysql's subquery performance is horrible. // mysql's subquery performance is horrible.
type selectStatementImpl struct { type selectStatementImpl struct {
expressionInterfaceImpl expressionInterfaceImpl
isRowsType
table readableTable table readableTable
distinct bool distinct bool

View file

@ -15,6 +15,7 @@ const (
type setStatement interface { type setStatement interface {
Statement Statement
expression expression
hasRows()
ORDER_BY(clauses ...orderByClause) setStatement ORDER_BY(clauses ...orderByClause) setStatement
LIMIT(limit int64) setStatement LIMIT(limit int64) setStatement
@ -23,43 +24,44 @@ type setStatement interface {
AsTable(alias string) expressionTable AsTable(alias string) expressionTable
} }
func UNION(selects ...selectStatement) setStatement { func UNION(selects ...rowsType) setStatement {
return newSetStatementImpl(union, false, selects...) return newSetStatementImpl(union, false, selects...)
} }
func UNION_ALL(selects ...selectStatement) setStatement { func UNION_ALL(selects ...rowsType) setStatement {
return newSetStatementImpl(union, true, selects...) return newSetStatementImpl(union, true, selects...)
} }
func INTERSECT(selects ...selectStatement) setStatement { func INTERSECT(selects ...rowsType) setStatement {
return newSetStatementImpl(intersect, false, selects...) return newSetStatementImpl(intersect, false, selects...)
} }
func INTERSECT_ALL(selects ...selectStatement) setStatement { func INTERSECT_ALL(selects ...rowsType) setStatement {
return newSetStatementImpl(intersect, true, selects...) return newSetStatementImpl(intersect, true, selects...)
} }
func EXCEPT(selects ...selectStatement) setStatement { func EXCEPT(selects ...rowsType) setStatement {
return newSetStatementImpl(except, false, selects...) return newSetStatementImpl(except, false, selects...)
} }
func EXCEPT_ALL(selects ...selectStatement) setStatement { func EXCEPT_ALL(selects ...rowsType) setStatement {
return newSetStatementImpl(except, true, selects...) return newSetStatementImpl(except, true, selects...)
} }
// Similar to selectStatementImpl, but less complete // Similar to selectStatementImpl, but less complete
type setStatementImpl struct { type setStatementImpl struct {
expressionInterfaceImpl expressionInterfaceImpl
isRowsType
operator string operator string
selects []selectStatement selects []rowsType
orderBy []orderByClause orderBy []orderByClause
limit, offset int64 limit, offset int64
// True if results of the union should be deduped. // True if results of the union should be deduped.
all bool all bool
} }
func newSetStatementImpl(operator string, all bool, selects ...selectStatement) setStatement { func newSetStatementImpl(operator string, all bool, selects ...rowsType) setStatement {
setStatement := &setStatementImpl{ setStatement := &setStatementImpl{
operator: operator, operator: operator,
selects: selects, selects: selects,

View file

@ -126,3 +126,39 @@ OFFSET $2;
`) `)
assert.Equal(t, len(args), 2) assert.Equal(t, len(args), 2)
} }
func TestUnionInUnion(t *testing.T) {
expectedSql := `
(
(
SELECT table2.col3 AS "table2.col3",
table2.col3 AS "table2.col3"
FROM db.table2
)
UNION
(
(
SELECT table1.col1 AS "table1.col1"
FROM db.table1
)
UNION ALL
(
SELECT table2.col3 AS "table2.col3"
FROM db.table2
)
)
);
`
query := UNION(
SELECT(table2Col3, table2Col3).FROM(table2),
UNION_ALL(table1.SELECT(table1Col1), table2.SELECT(table2Col3)),
)
queryStr, args, err := query.Sql()
fmt.Println(queryStr)
assert.NilError(t, err)
assert.Equal(t, len(args), 0)
assert.Equal(t, queryStr, expectedSql)
}

View file

@ -1,5 +1,14 @@
package sqlbuilder package sqlbuilder
type rowsType interface {
clause
hasRows()
}
type isRowsType struct{}
func (i *isRowsType) hasRows() {}
// A clause that can be used in orderBy by // A clause that can be used in orderBy by
// A clause that is selectable. // A clause that is selectable.