diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 6b39af4..6b20e2c 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -9,6 +9,7 @@ import ( type selectStatement interface { Statement expression + hasRows() DISTINCT() selectStatement FROM(table readableTable) selectStatement @@ -33,6 +34,7 @@ func SELECT(projection ...projection) selectStatement { // mysql's subquery performance is horrible. type selectStatementImpl struct { expressionInterfaceImpl + isRowsType table readableTable distinct bool diff --git a/sqlbuilder/set_statement.go b/sqlbuilder/set_statement.go index a18ccee..d58e4f3 100644 --- a/sqlbuilder/set_statement.go +++ b/sqlbuilder/set_statement.go @@ -15,6 +15,7 @@ const ( type setStatement interface { Statement expression + hasRows() ORDER_BY(clauses ...orderByClause) setStatement LIMIT(limit int64) setStatement @@ -23,43 +24,44 @@ type setStatement interface { AsTable(alias string) expressionTable } -func UNION(selects ...selectStatement) setStatement { +func UNION(selects ...rowsType) setStatement { return newSetStatementImpl(union, false, selects...) } -func UNION_ALL(selects ...selectStatement) setStatement { +func UNION_ALL(selects ...rowsType) setStatement { return newSetStatementImpl(union, true, selects...) } -func INTERSECT(selects ...selectStatement) setStatement { +func INTERSECT(selects ...rowsType) setStatement { return newSetStatementImpl(intersect, false, selects...) } -func INTERSECT_ALL(selects ...selectStatement) setStatement { +func INTERSECT_ALL(selects ...rowsType) setStatement { return newSetStatementImpl(intersect, true, selects...) } -func EXCEPT(selects ...selectStatement) setStatement { +func EXCEPT(selects ...rowsType) setStatement { return newSetStatementImpl(except, false, selects...) } -func EXCEPT_ALL(selects ...selectStatement) setStatement { +func EXCEPT_ALL(selects ...rowsType) setStatement { return newSetStatementImpl(except, true, selects...) } // Similar to selectStatementImpl, but less complete type setStatementImpl struct { expressionInterfaceImpl + isRowsType operator string - selects []selectStatement + selects []rowsType orderBy []orderByClause limit, offset int64 // True if results of the union should be deduped. all bool } -func newSetStatementImpl(operator string, all bool, selects ...selectStatement) setStatement { +func newSetStatementImpl(operator string, all bool, selects ...rowsType) setStatement { setStatement := &setStatementImpl{ operator: operator, selects: selects, diff --git a/sqlbuilder/set_statement_test.go b/sqlbuilder/set_statement_test.go index 88caebe..8fb38b8 100644 --- a/sqlbuilder/set_statement_test.go +++ b/sqlbuilder/set_statement_test.go @@ -126,3 +126,39 @@ OFFSET $2; `) assert.Equal(t, len(args), 2) } + +func TestUnionInUnion(t *testing.T) { + expectedSql := ` +( + ( + SELECT table2.col3 AS "table2.col3", + table2.col3 AS "table2.col3" + FROM db.table2 + ) + UNION + + ( + ( + SELECT table1.col1 AS "table1.col1" + FROM db.table1 + ) + UNION ALL + ( + SELECT table2.col3 AS "table2.col3" + FROM db.table2 + ) + ) +); +` + query := UNION( + SELECT(table2Col3, table2Col3).FROM(table2), + UNION_ALL(table1.SELECT(table1Col1), table2.SELECT(table2Col3)), + ) + + queryStr, args, err := query.Sql() + + fmt.Println(queryStr) + assert.NilError(t, err) + assert.Equal(t, len(args), 0) + assert.Equal(t, queryStr, expectedSql) +} diff --git a/sqlbuilder/types.go b/sqlbuilder/types.go index cbe7cce..5d2f136 100644 --- a/sqlbuilder/types.go +++ b/sqlbuilder/types.go @@ -1,5 +1,14 @@ package sqlbuilder +type rowsType interface { + clause + hasRows() +} + +type isRowsType struct{} + +func (i *isRowsType) hasRows() {} + // A clause that can be used in orderBy by // A clause that is selectable.