From 9b826fff6e92d9478275631593293db27e3cc683 Mon Sep 17 00:00:00 2001 From: zer0sub Date: Wed, 1 May 2019 16:34:03 +0200 Subject: [PATCH] Add support for INTERSECT and EXCEPT statements. --- sqlbuilder/set_statement.go | 146 ++++++++++++++++++++++++++++++++++ sqlbuilder/statement_test.go | 12 +-- sqlbuilder/union_statement.go | 123 ---------------------------- 3 files changed, 152 insertions(+), 129 deletions(-) create mode 100644 sqlbuilder/set_statement.go delete mode 100644 sqlbuilder/union_statement.go diff --git a/sqlbuilder/set_statement.go b/sqlbuilder/set_statement.go new file mode 100644 index 0000000..16fcec8 --- /dev/null +++ b/sqlbuilder/set_statement.go @@ -0,0 +1,146 @@ +package sqlbuilder + +import ( + "database/sql" + "github.com/dropbox/godropbox/errors" + "github.com/sub0zero/go-sqlbuilder/types" +) + +const ( + union = "UNION" + intersect = "INTERSECT" + except = "EXCEPT" +) + +type SetStatement interface { + Statement + + ORDER_BY(clauses ...OrderByClause) SetStatement + LIMIT(limit int64) SetStatement + OFFSET(offset int64) SetStatement +} + +func UNION(selects ...SelectStatement) SetStatement { + return newSetStatementImpl(union, false, selects...) +} + +func UNION_ALL(selects ...SelectStatement) SetStatement { + return newSetStatementImpl(union, true, selects...) +} + +func INTERSECT(selects ...SelectStatement) SetStatement { + return newSetStatementImpl(intersect, false, selects...) +} + +func INTERSECT_ALL(selects ...SelectStatement) SetStatement { + return newSetStatementImpl(intersect, true, selects...) +} + +func EXCEPT(selects ...SelectStatement) SetStatement { + return newSetStatementImpl(except, false, selects...) +} + +func EXCEPT_ALL(selects ...SelectStatement) SetStatement { + return newSetStatementImpl(except, true, selects...) +} + +// Similar to selectStatementImpl, but less complete +type setStatementImpl struct { + operator string + selects []SelectStatement + order *listClause + limit, offset int64 + // True if results of the union should be deduped. + all bool +} + +func newSetStatementImpl(operator string, all bool, selects ...SelectStatement) *setStatementImpl { + return &setStatementImpl{ + operator: operator, + selects: selects, + limit: -1, + offset: -1, + all: all, + } +} + +func (us *setStatementImpl) ORDER_BY(clauses ...OrderByClause) SetStatement { + + us.order = newOrderByListClause(clauses...) + return us +} + +func (us *setStatementImpl) LIMIT(limit int64) SetStatement { + us.limit = limit + return us +} + +func (us *setStatementImpl) OFFSET(offset int64) SetStatement { + us.offset = offset + return us +} + +func (us *setStatementImpl) Serialize(out *queryData, options ...serializeOption) error { + if len(us.selects) == 0 { + return errors.Newf("UNION statement must have at least one SELECT") + } + + out.WriteString("(") + + for i, selectStmt := range us.selects { + if i > 0 { + out.WriteString(" " + us.operator + " ") + + if us.all { + out.WriteString(" ALL ") + } + } + + err := selectStmt.Serialize(out, options...) + + if err != nil { + return err + } + } + + out.WriteString(")") + + if us.order != nil { + out.WriteString(" ORDER BY ") + if err := us.order.Serialize(out, NO_TABLE_NAME); err != nil { + return err + } + } + + if us.limit >= 0 { + out.WriteString(" LIMIT ") + out.InsertArgument(us.limit) + } + + if us.offset >= 0 { + out.WriteString(" OFFSET ") + out.InsertArgument(us.offset) + } + + return nil +} + +func (us *setStatementImpl) Sql() (query string, args []interface{}, err error) { + queryData := &queryData{} + + err = us.Serialize(queryData) + + if err != nil { + return + } + + return queryData.buff.String(), queryData.args, nil +} + +func (s *setStatementImpl) Query(db types.Db, destination interface{}) error { + return Query(s, db, destination) +} + +func (u *setStatementImpl) Execute(db types.Db) (res sql.Result, err error) { + return Execute(u, db) +} diff --git a/sqlbuilder/statement_test.go b/sqlbuilder/statement_test.go index cc8239c..0e84eaa 100644 --- a/sqlbuilder/statement_test.go +++ b/sqlbuilder/statement_test.go @@ -414,7 +414,7 @@ func (s *StmtSuite) TestUnionSelectStatement(c *gc.C) { table1.Select(table1Col1).Where(LtL(table1Col1, 23)), ) - q := Union(select_queries...) + q := UNION(select_queries...) sql, err := q.String() @@ -436,7 +436,7 @@ func (s *StmtSuite) TestUnionLimitWithoutOrderBy(c *gc.C) { table1.Select(table1Col1).Where(LtL(table1Col1, 23)), ) - q := Union(select_queries...) + q := UNION(select_queries...) _, err := q.String() @@ -444,7 +444,7 @@ func (s *StmtSuite) TestUnionLimitWithoutOrderBy(c *gc.C) { c.Assert( errors.GetMessage(err), gc.Equals, - "All inner selects in Union statement must have LIMIT if they have ORDER BY") + "All inner selects in UNION statement must have LIMIT if they have ORDER BY") } func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) { @@ -461,7 +461,7 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) { table1.Select(table1Col1).Where(LtL(table1Col1, 23)).OrderBy(table1Col4).Limit(20), ) - q := Union(select_queries...) + q := UNION(select_queries...) q = q.Where(And(LtL(table1Col1, 1000), GtL(table1Col1, 15))) q = q.ORDER_BY(Desc(table1Col4), Asc(table1Col3)) q = q.LIMIT(5) @@ -472,7 +472,7 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) { c.Assert( errors.GetMessage(err), gc.Equals, - "All inner selects in Union statement must select the "+ + "All inner selects in UNION statement must select the "+ "same number of columns. For sanity, you probably "+ "want to select the same tableName columns in the same "+ "orderBy. If you are selecting on multiple tables, "+ @@ -499,7 +499,7 @@ func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) { ).Where(LtL(table1Col1, 23)).OrderBy(table1Col4).Limit(20), ) - q := Union(select_queries...) + q := UNION(select_queries...) q = q.Where(And(LtL(table1Col1, 1000), GtL(table1Col1, 15))) q = q.ORDER_BY(Desc(table1Col4), Asc(table1Col3)) diff --git a/sqlbuilder/union_statement.go b/sqlbuilder/union_statement.go deleted file mode 100644 index 1555ca6..0000000 --- a/sqlbuilder/union_statement.go +++ /dev/null @@ -1,123 +0,0 @@ -package sqlbuilder - -import ( - "database/sql" - "github.com/dropbox/godropbox/errors" - "github.com/sub0zero/go-sqlbuilder/types" -) - -type UnionStatement interface { - Statement - - ORDER_BY(clauses ...OrderByClause) UnionStatement - LIMIT(limit int64) UnionStatement - OFFSET(offset int64) UnionStatement -} - -func Union(selects ...SelectStatement) UnionStatement { - return &unionStatementImpl{ - selects: selects, - limit: -1, - offset: -1, - all: true, - } -} - -func UnionAll(selects ...SelectStatement) UnionStatement { - return &unionStatementImpl{ - selects: selects, - limit: -1, - offset: -1, - all: false, - } -} - -// Similar to selectStatementImpl, but less complete -type unionStatementImpl struct { - selects []SelectStatement - order *listClause - limit, offset int64 - // True if results of the union should be deduped. - all bool -} - -func (us *unionStatementImpl) Serialize(out *queryData, options ...serializeOption) error { - if len(us.selects) == 0 { - return errors.Newf("Union statement must have at least one SELECT") - } - - out.WriteString("(") - - for i, selectStmt := range us.selects { - if i > 0 { - out.WriteString(" UNION ") - - if us.all { - out.WriteString(" ALL ") - } - } - - err := selectStmt.Serialize(out, options...) - - if err != nil { - return err - } - } - - out.WriteString(")") - - if us.order != nil { - out.WriteString(" ORDER BY ") - if err := us.order.Serialize(out, NO_TABLE_NAME); err != nil { - return err - } - } - - if us.limit >= 0 { - out.WriteString(" LIMIT ") - out.InsertArgument(us.limit) - } - - if us.offset >= 0 { - out.WriteString(" OFFSET ") - out.InsertArgument(us.offset) - } - - return nil -} - -func (us *unionStatementImpl) ORDER_BY(clauses ...OrderByClause) UnionStatement { - - us.order = newOrderByListClause(clauses...) - return us -} - -func (us *unionStatementImpl) LIMIT(limit int64) UnionStatement { - us.limit = limit - return us -} - -func (us *unionStatementImpl) OFFSET(offset int64) UnionStatement { - us.offset = offset - return us -} - -func (us *unionStatementImpl) Sql() (query string, args []interface{}, err error) { - queryData := &queryData{} - - err = us.Serialize(queryData) - - if err != nil { - return - } - - return queryData.buff.String(), queryData.args, nil -} - -func (s *unionStatementImpl) Query(db types.Db, destination interface{}) error { - return Query(s, db, destination) -} - -func (u *unionStatementImpl) Execute(db types.Db) (res sql.Result, err error) { - return Execute(u, db) -}