From bc6a2bbcacd4676e30b0fe8e0c419988377a5595 Mon Sep 17 00:00:00 2001 From: zer0sub Date: Sat, 20 Apr 2019 19:49:29 +0200 Subject: [PATCH] Add support for DELETE statements. --- sqlbuilder/delete_statement.go | 70 ++++ sqlbuilder/delete_statement_test.go | 18 + sqlbuilder/insert_statement.go | 16 +- sqlbuilder/select_statement.go | 29 +- sqlbuilder/statement.go | 552 +++++++--------------------- sqlbuilder/statement_test.go | 43 --- sqlbuilder/table.go | 5 +- sqlbuilder/union_statement.go | 203 ++++++++++ sqlbuilder/update_statement.go | 22 +- sqlbuilder/utils.go | 29 +- tests/generator_test.go | 36 +- tests/insert_test.go | 4 +- tests/sample_test.go | 6 +- tests/update_test.go | 2 +- 14 files changed, 492 insertions(+), 543 deletions(-) create mode 100644 sqlbuilder/delete_statement.go create mode 100644 sqlbuilder/delete_statement_test.go create mode 100644 sqlbuilder/union_statement.go diff --git a/sqlbuilder/delete_statement.go b/sqlbuilder/delete_statement.go new file mode 100644 index 0000000..d978996 --- /dev/null +++ b/sqlbuilder/delete_statement.go @@ -0,0 +1,70 @@ +package sqlbuilder + +import ( + "bytes" + "database/sql" + "github.com/dropbox/godropbox/errors" + "github.com/sub0zero/go-sqlbuilder/types" +) + +type DeleteStatement interface { + Statement + + WHERE(expression BoolExpression) DeleteStatement +} + +func newDeleteStatement(table WritableTable) DeleteStatement { + return &deleteStatementImpl{ + table: table, + } +} + +type deleteStatementImpl struct { + table WritableTable + where BoolExpression + order *listClause +} + +func (u *deleteStatementImpl) Query(db types.Db, destination interface{}) error { + return Query(u, db, destination) +} + +func (u *deleteStatementImpl) Execute(db types.Db) (res sql.Result, err error) { + return Execute(u, db) +} + +func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement { + d.where = expression + return d +} + +func (d *deleteStatementImpl) String() (sql string, err error) { + buf := new(bytes.Buffer) + _, _ = buf.WriteString("DELETE FROM ") + + if d.table == nil { + return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) + } + + if err = d.table.SerializeSql(buf); err != nil { + return + } + + if d.where == nil { + return "", errors.Newf("Deleting without a WHERE clause. Generated sql: %s", buf.String()) + } + + _, _ = buf.WriteString(" WHERE ") + if err = d.where.SerializeSql(buf); err != nil { + return + } + + if d.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = d.order.SerializeSql(buf); err != nil { + return + } + } + + return buf.String() + ";", nil +} diff --git a/sqlbuilder/delete_statement_test.go b/sqlbuilder/delete_statement_test.go new file mode 100644 index 0000000..0ef31ae --- /dev/null +++ b/sqlbuilder/delete_statement_test.go @@ -0,0 +1,18 @@ +package sqlbuilder + +import ( + "gotest.tools/assert" + "testing" +) + +func TestDeleteUnconditionally(t *testing.T) { + _, err := table1.Delete().String() + assert.Assert(t, err != nil) +} + +func TestDeleteWithWhere(t *testing.T) { + sql, err := table1.Delete().WHERE(table1Col1.EqL(1)).String() + assert.NilError(t, err) + + assert.Equal(t, sql, "DELETE FROM db.table1 WHERE table1.col1 = 1;") +} diff --git a/sqlbuilder/insert_statement.go b/sqlbuilder/insert_statement.go index 0d9ebcf..a02a6fe 100644 --- a/sqlbuilder/insert_statement.go +++ b/sqlbuilder/insert_statement.go @@ -21,8 +21,6 @@ type InsertStatement interface { RETURNING(projections ...Projection) InsertStatement QUERY(selectStatement SelectStatement) InsertStatement - - Execute(db types.Db) (sql.Result, error) } func newInsertStatement(t WritableTable, columns ...Column) InsertStatement { @@ -47,16 +45,12 @@ type insertStatementImpl struct { errors []string } -func (i *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) { - query, err := i.String() +func (s *insertStatementImpl) Query(db types.Db, destination interface{}) error { + return Query(s, db, destination) +} - if err != nil { - return - } - - res, err = db.Exec(query) - - return +func (u *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) { + return Execute(u, db) } func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement { diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 03b7898..33fa57a 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -2,9 +2,9 @@ package sqlbuilder import ( "bytes" + "database/sql" "fmt" "github.com/dropbox/godropbox/errors" - "github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution" "github.com/sub0zero/go-sqlbuilder/types" ) @@ -13,7 +13,6 @@ type SelectStatement interface { Expression Where(expression BoolExpression) SelectStatement - AndWhere(expression BoolExpression) SelectStatement GroupBy(expressions ...Expression) SelectStatement HAVING(expressions BoolExpression) SelectStatement @@ -27,9 +26,6 @@ type SelectStatement interface { Copy() SelectStatement AsTable(alias string) *SelectStatementTable - - Execute(db types.Db, destination interface{}) error - //ExecuteInTx(tx *sql.Tx, destination interface{}) error } // NOTE: SelectStatement purposely does not implement the Table interface since @@ -86,14 +82,12 @@ func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable { } } -func (s *selectStatementImpl) Execute(db types.Db, destination interface{}) error { - query, err := s.String() +func (s *selectStatementImpl) Query(db types.Db, destination interface{}) error { + return Query(s, db, destination) +} - if err != nil { - return err - } - - return execution.Execute(db, query, destination) +func (u *selectStatementImpl) Execute(db types.Db) (res sql.Result, err error) { + return Execute(u, db) } func (s *selectStatementImpl) Copy() SelectStatement { @@ -101,17 +95,6 @@ func (s *selectStatementImpl) Copy() SelectStatement { return &ret } -// Further filter the query, instead of replacing the filter -func (q *selectStatementImpl) AndWhere( - expression BoolExpression) SelectStatement { - - if q.where == nil { - return q.Where(expression) - } - q.where = And(q.where, expression) - return q -} - func (q *selectStatementImpl) Where(expression BoolExpression) SelectStatement { q.where = expression return q diff --git a/sqlbuilder/statement.go b/sqlbuilder/statement.go index 952d377..74caa30 100644 --- a/sqlbuilder/statement.go +++ b/sqlbuilder/statement.go @@ -3,7 +3,7 @@ package sqlbuilder import ( "bytes" "database/sql" - "fmt" + "github.com/sub0zero/go-sqlbuilder/types" "regexp" "github.com/dropbox/godropbox/errors" @@ -12,434 +12,150 @@ import ( type Statement interface { // String returns generated SQL as string. String() (sql string, err error) -} -// By default, rows selected by a UNION statement are out-of-order -// If you have an ORDER BY on an inner SELECT statement, the only thing -// it affects is the LIMIT clause on that inner statement (the ordering will -// still be out-of-order). -type UnionStatement interface { - Statement - - // Warning! You cannot include tableName names for the next 4 clauses, or - // you'll get errors like: - // Table 'server_file_journal' from one of the SELECTs cannot be used in - // global ORDER clause - Where(expression BoolExpression) UnionStatement - AndWhere(expression BoolExpression) UnionStatement - GroupBy(expressions ...Expression) UnionStatement - OrderBy(clauses ...OrderByClause) UnionStatement - - Limit(limit int64) UnionStatement - Offset(offset int64) UnionStatement -} - -type DeleteStatement interface { - Statement - - Where(expression BoolExpression) DeleteStatement - OrderBy(clauses ...OrderByClause) DeleteStatement - Limit(limit int64) DeleteStatement - Comment(comment string) DeleteStatement + Query(db types.Db, destination interface{}) error + Execute(db types.Db) (sql.Result, error) } // LockStatement is used to take Read/Write lock on tables. // See http://dev.mysql.com/doc/refman/5.0/en/lock-tables.html -type LockStatement interface { - Statement - - AddReadLock(table *Table) LockStatement - AddWriteLock(table *Table) LockStatement -} - -// UnlockStatement can be used to release tableName locks taken using LockStatement. -// NOTE: You can not selectively release a lock and continue to hold lock on -// another tableName. UnlockStatement releases all the lock held in the current -// session. -type UnlockStatement interface { - Statement -} - -// SetGtidNextStatement returns a SQL statement that can be used to explicitly set the next GTID. -type GtidNextStatement interface { - Statement -} - +//type LockStatement interface { +// Statement // -// UNION SELECT Statement ====================================================== +// AddReadLock(table *Table) LockStatement +// AddWriteLock(table *Table) LockStatement +//} + +//// UnlockStatement can be used to release tableName locks taken using LockStatement. +//// NOTE: You can not selectively release a lock and continue to hold lock on +//// another tableName. UnlockStatement releases all the lock held in the current +//// session. +//type UnlockStatement interface { +// Statement +//} // - -func Union(selects ...SelectStatement) UnionStatement { - return &unionStatementImpl{ - selects: selects, - limit: -1, - offset: -1, - unique: true, - } -} - -func UnionAll(selects ...SelectStatement) UnionStatement { - return &unionStatementImpl{ - selects: selects, - limit: -1, - offset: -1, - unique: false, - } -} - -// Similar to selectStatementImpl, but less complete -type unionStatementImpl struct { - selects []SelectStatement - where BoolExpression - group *listClause - order *listClause - limit, offset int64 - // True if results of the union should be deduped. - unique bool -} - -func (us *unionStatementImpl) Execute(db *sql.DB, data interface{}) error { - return nil -} - -func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement { - us.where = expression - return us -} - -// Further filter the query, instead of replacing the filter -func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement { - if us.where == nil { - return us.Where(expression) - } - us.where = And(us.where, expression) - return us -} - -func (us *unionStatementImpl) GroupBy( - expressions ...Expression) UnionStatement { - - us.group = &listClause{ - clauses: make([]Clause, len(expressions), len(expressions)), - includeParentheses: false, - } - - for i, e := range expressions { - us.group.clauses[i] = e - } - return us -} - -func (us *unionStatementImpl) OrderBy( - clauses ...OrderByClause) UnionStatement { - - us.order = newOrderByListClause(clauses...) - return us -} - -func (us *unionStatementImpl) Limit(limit int64) UnionStatement { - us.limit = limit - return us -} - -func (us *unionStatementImpl) Offset(offset int64) UnionStatement { - us.offset = offset - return us -} - -func (us *unionStatementImpl) String() (sql string, err error) { - if len(us.selects) == 0 { - return "", errors.Newf("Union statement must have at least one SELECT") - } - - if len(us.selects) == 1 { - return us.selects[0].String() - } - - // Union statements in MySQL require that the same number of columns in each subquery - var projections []Projection - - for _, statement := range us.selects { - // do a type assertion to get at the underlying struct - statementImpl, ok := statement.(*selectStatementImpl) - if !ok { - return "", errors.Newf( - "Expected inner select statement to be of type " + - "selectStatementImpl") - } - - // check that for limit for statements with order by clauses - if statementImpl.order != nil && statementImpl.limit < 0 { - return "", errors.Newf( - "All inner selects in Union statement must have LIMIT if " + - "they have ORDER BY") - } - - // check number of projections - if projections == nil { - projections = statementImpl.projections - } else { - if len(projections) != len(statementImpl.projections) { - return "", errors.Newf( - "All inner selects in Union statement must select the " + - "same number of columns. For sanity, you probably " + - "want to select the same tableName columns in the same " + - "order. If you are selecting on multiple tables, " + - "use Null to pad to the right number of fields.") - } - } - } - - buf := new(bytes.Buffer) - for i, statement := range us.selects { - if i != 0 { - if us.unique { - _, _ = buf.WriteString(" UNION ") - } else { - _, _ = buf.WriteString(" UNION ALL ") - } - } - _, _ = buf.WriteString("(") - selectSql, err := statement.String() - if err != nil { - return "", err - } - _, _ = buf.WriteString(selectSql) - _, _ = buf.WriteString(")") - } - - if us.where != nil { - _, _ = buf.WriteString(" WHERE ") - if err = us.where.SerializeSql(buf); err != nil { - return - } - } - - if us.group != nil { - _, _ = buf.WriteString(" GROUP BY ") - if err = us.group.SerializeSql(buf); err != nil { - return - } - } - - if us.order != nil { - _, _ = buf.WriteString(" ORDER BY ") - if err = us.order.SerializeSql(buf); err != nil { - return - } - } - - if us.limit >= 0 { - if us.offset >= 0 { - _, _ = buf.WriteString( - fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit)) - } else { - _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit)) - } - } - return buf.String(), nil -} - +//// SetGtidNextStatement returns a SQL statement that can be used to explicitly set the next GTID. +//type GtidNextStatement interface { +// Statement +//} // -// DELETE statement =========================================================== +//// +//// UNION SELECT Statement ====================================================== +//// +//// +//// LOCK statement =========================================================== +//// // - -func newDeleteStatement(table WritableTable) DeleteStatement { - return &deleteStatementImpl{ - table: table, - limit: -1, - } -} - -type deleteStatementImpl struct { - table WritableTable - where BoolExpression - order *listClause - limit int64 - comment string -} - -func (d *deleteStatementImpl) Execute(db *sql.DB, data interface{}) error { - return nil -} - -func (d *deleteStatementImpl) Where(expression BoolExpression) DeleteStatement { - d.where = expression - return d -} - -func (d *deleteStatementImpl) OrderBy( - clauses ...OrderByClause) DeleteStatement { - - d.order = newOrderByListClause(clauses...) - return d -} - -func (d *deleteStatementImpl) Limit(limit int64) DeleteStatement { - d.limit = limit - return d -} - -func (d *deleteStatementImpl) Comment(comment string) DeleteStatement { - d.comment = comment - return d -} - -func (d *deleteStatementImpl) String() (sql string, err error) { - buf := new(bytes.Buffer) - _, _ = buf.WriteString("DELETE FROM ") - - if err = writeComment(d.comment, buf); err != nil { - return - } - - if d.table == nil { - return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) - } - - if err = d.table.SerializeSql(buf); err != nil { - return - } - - if d.where == nil { - return "", errors.Newf( - "Deleting without a WHERE clause. Generated sql: %s", - buf.String()) - } - - _, _ = buf.WriteString(" WHERE ") - if err = d.where.SerializeSql(buf); err != nil { - return - } - - if d.order != nil { - _, _ = buf.WriteString(" ORDER BY ") - if err = d.order.SerializeSql(buf); err != nil { - return - } - } - - if d.limit >= 0 { - _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", d.limit)) - } - - return buf.String(), nil -} - +//// NewLockStatement returns a SQL representing empty set of locks. You need to use +//// AddReadLock/AddWriteLock to add tables that need to be locked. +//// NOTE: You need at least one lock in the set for it to be a valid statement. +//func NewLockStatement() LockStatement { +// return &lockStatementImpl{} +//} // -// LOCK statement =========================================================== +//type lockStatementImpl struct { +// locks []tableLock +//} // - -// NewLockStatement returns a SQL representing empty set of locks. You need to use -// AddReadLock/AddWriteLock to add tables that need to be locked. -// NOTE: You need at least one lock in the set for it to be a valid statement. -func NewLockStatement() LockStatement { - return &lockStatementImpl{} -} - -type lockStatementImpl struct { - locks []tableLock -} - -type tableLock struct { - t *Table - w bool -} - -func (l *lockStatementImpl) Execute(db *sql.DB, data interface{}) error { - return nil -} - -// AddReadLock takes read lock on the tableName. -func (s *lockStatementImpl) AddReadLock(t *Table) LockStatement { - s.locks = append(s.locks, tableLock{t: t, w: false}) - return s -} - -// AddWriteLock takes write lock on the tableName. -func (s *lockStatementImpl) AddWriteLock(t *Table) LockStatement { - s.locks = append(s.locks, tableLock{t: t, w: true}) - return s -} - -func (s *lockStatementImpl) String() (sql string, err error) { - if len(s.locks) == 0 { - return "", errors.New("No locks added") - } - - buf := new(bytes.Buffer) - _, _ = buf.WriteString("LOCK TABLES ") - - for idx, lock := range s.locks { - if lock.t == nil { - return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) - } - - if err = lock.t.SerializeSql(buf); err != nil { - return - } - - if lock.w { - _, _ = buf.WriteString(" WRITE") - } else { - _, _ = buf.WriteString(" READ") - } - - if idx != len(s.locks)-1 { - _, _ = buf.WriteString(", ") - } - } - - return buf.String(), nil -} - -// NewUnlockStatement returns SQL statement that can be used to release tableName locks -// grabbed by the current session. -func NewUnlockStatement() UnlockStatement { - return &unlockStatementImpl{} -} - -type unlockStatementImpl struct { -} - -func (u *unlockStatementImpl) Execute(db *sql.DB, data interface{}) error { - return nil -} - -func (s *unlockStatementImpl) String() (sql string, err error) { - return "UNLOCK TABLES", nil -} - -// SET GTID_NEXT statement returns a SQL statement that can be used to explicitly set the next GTID. -func NewGtidNextStatement(sid []byte, gno uint64) GtidNextStatement { - return >idNextStatementImpl{ - sid: sid, - gno: gno, - } -} - -type gtidNextStatementImpl struct { - sid []byte - gno uint64 -} - -func (g *gtidNextStatementImpl) Execute(db *sql.DB, data interface{}) error { - return nil -} - -func (s *gtidNextStatementImpl) String() (sql string, err error) { - // This statement sets a session local variable defining what the next transaction ID is. It - // does not interact with other MySQL sessions. It is neither a DDL nor DML statement, so we - // don't have to worry about data corruption. - // Because of the string formatting (hex plus an integer), can't morph into another statement. - // See: https://dev.mysql.com/doc/refman/5.7/en/replication-options-gtids.html - const gtidFormatString = "SET GTID_NEXT=\"%x-%x-%x-%x-%x:%d\"" - - buf := new(bytes.Buffer) - _, _ = buf.WriteString(fmt.Sprintf(gtidFormatString, - s.sid[:4], s.sid[4:6], s.sid[6:8], s.sid[8:10], s.sid[10:], s.gno)) - return buf.String(), nil -} +//type tableLock struct { +// t *Table +// w bool +//} +// +//func (l *lockStatementImpl) Execute(db *sql.DB, data interface{}) error { +// return nil +//} +// +//// AddReadLock takes read lock on the tableName. +//func (s *lockStatementImpl) AddReadLock(t *Table) LockStatement { +// s.locks = append(s.locks, tableLock{t: t, w: false}) +// return s +//} +// +//// AddWriteLock takes write lock on the tableName. +//func (s *lockStatementImpl) AddWriteLock(t *Table) LockStatement { +// s.locks = append(s.locks, tableLock{t: t, w: true}) +// return s +//} +// +//func (s *lockStatementImpl) String() (sql string, err error) { +// if len(s.locks) == 0 { +// return "", errors.New("No locks added") +// } +// +// buf := new(bytes.Buffer) +// _, _ = buf.WriteString("LOCK TABLES ") +// +// for idx, lock := range s.locks { +// if lock.t == nil { +// return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) +// } +// +// if err = lock.t.SerializeSql(buf); err != nil { +// return +// } +// +// if lock.w { +// _, _ = buf.WriteString(" WRITE") +// } else { +// _, _ = buf.WriteString(" READ") +// } +// +// if idx != len(s.locks)-1 { +// _, _ = buf.WriteString(", ") +// } +// } +// +// return buf.String(), nil +//} +// +//// NewUnlockStatement returns SQL statement that can be used to release tableName locks +//// grabbed by the current session. +//func NewUnlockStatement() UnlockStatement { +// return &unlockStatementImpl{} +//} +// +//type unlockStatementImpl struct { +//} +// +//func (u *unlockStatementImpl) Execute(db *sql.DB, data interface{}) error { +// return nil +//} +// +//func (s *unlockStatementImpl) String() (sql string, err error) { +// return "UNLOCK TABLES", nil +//} +// +//// SET GTID_NEXT statement returns a SQL statement that can be used to explicitly set the next GTID. +//func NewGtidNextStatement(sid []byte, gno uint64) GtidNextStatement { +// return >idNextStatementImpl{ +// sid: sid, +// gno: gno, +// } +//} +// +//type gtidNextStatementImpl struct { +// sid []byte +// gno uint64 +//} +// +//func (g *gtidNextStatementImpl) Execute(db *sql.DB, data interface{}) error { +// return nil +//} +// +//func (s *gtidNextStatementImpl) String() (sql string, err error) { +// // This statement sets a session local variable defining what the next transaction ID is. It +// // does not interact with other MySQL sessions. It is neither a DDL nor DML statement, so we +// // don't have to worry about data corruption. +// // Because of the string formatting (hex plus an integer), can't morph into another statement. +// // See: https://dev.mysql.com/doc/refman/5.7/en/replication-options-gtids.html +// const gtidFormatString = "SET GTID_NEXT=\"%x-%x-%x-%x-%x:%d\"" +// +// buf := new(bytes.Buffer) +// _, _ = buf.WriteString(fmt.Sprintf(gtidFormatString, +// s.sid[:4], s.sid[4:6], s.sid[6:8], s.sid[8:10], s.sid[10:], s.gno)) +// return buf.String(), nil +//} // // Util functions ============================================================= diff --git a/sqlbuilder/statement_test.go b/sqlbuilder/statement_test.go index 2afc4d6..86765a2 100644 --- a/sqlbuilder/statement_test.go +++ b/sqlbuilder/statement_test.go @@ -385,49 +385,6 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateMulti(c *gc.C) { "ON DUPLICATE KEY UPDATE table1.col3=3, table1.col2=4") } -// -// DELETE statement tests ===================================================== -// - -func (s *StmtSuite) TestDeleteUnconditionally(c *gc.C) { - _, err := table1.Delete().String() - c.Assert(err, gc.NotNil) -} - -func (s *StmtSuite) TestDeleteWithWhere(c *gc.C) { - sql, err := table1.Delete().Where(EqL(table1Col1, 1)).String() - c.Assert(err, gc.IsNil) - - c.Assert( - sql, - gc.Equals, - "DELETE FROM db.table1 WHERE table1.col1=1") -} - -func (s *StmtSuite) TestDeleteWithOrderBy(c *gc.C) { - stmt := table1.Delete().Where(EqL(table1Col1, 1)).OrderBy(table1Col1) - sql, err := stmt.String() - c.Assert(err, gc.IsNil) - - c.Assert( - sql, - gc.Equals, - "DELETE FROM db.table1 "+ - "WHERE table1.col1=1 "+ - "ORDER BY table1.col1") -} - -func (s *StmtSuite) TestDeleteWithLimit(c *gc.C) { - stmt := table1.Delete().Where(EqL(table1Col1, 1)).Limit(5) - sql, err := stmt.String() - c.Assert(err, gc.IsNil) - - c.Assert( - sql, - gc.Equals, - "DELETE FROM db.table1 WHERE table1.col1=1 LIMIT 5") -} - // // LOCK/UNLOCK statement tests ================================================ // diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index 0602751..5a48202 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -155,10 +155,9 @@ func (t *Table) ForceIndex(index string) *Table { // Generates the sql string for the current tableName expression. Note: the // generated string may not be a valid/executable sql statement. func (t *Table) SerializeSql(out *bytes.Buffer) error { - if !validIdentifierName(t.schemaName) { - return errors.New("Invalid database name specified") + if t == nil { + return errors.Newf("nil tableName. Generated sql: %s", out.String()) } - _, _ = out.WriteString(t.schemaName) _, _ = out.WriteString(".") _, _ = out.WriteString(t.TableName()) diff --git a/sqlbuilder/union_statement.go b/sqlbuilder/union_statement.go new file mode 100644 index 0000000..462c125 --- /dev/null +++ b/sqlbuilder/union_statement.go @@ -0,0 +1,203 @@ +package sqlbuilder + +import ( + "bytes" + "database/sql" + "fmt" + "github.com/dropbox/godropbox/errors" + "github.com/sub0zero/go-sqlbuilder/types" +) + +// By default, rows selected by a UNION statement are out-of-order +// If you have an ORDER BY on an inner SELECT statement, the only thing +// it affects is the LIMIT clause on that inner statement (the ordering will +// still be out-of-order). +type UnionStatement interface { + Statement + + // Warning! You cannot include tableName names for the next 4 clauses, or + // you'll get errors like: + // Table 'server_file_journal' from one of the SELECTs cannot be used in + // global ORDER clause + Where(expression BoolExpression) UnionStatement + GroupBy(expressions ...Expression) UnionStatement + OrderBy(clauses ...OrderByClause) UnionStatement + + Limit(limit int64) UnionStatement + Offset(offset int64) UnionStatement +} + +func Union(selects ...SelectStatement) UnionStatement { + return &unionStatementImpl{ + selects: selects, + limit: -1, + offset: -1, + unique: true, + } +} + +func UnionAll(selects ...SelectStatement) UnionStatement { + return &unionStatementImpl{ + selects: selects, + limit: -1, + offset: -1, + unique: false, + } +} + +// Similar to selectStatementImpl, but less complete +type unionStatementImpl struct { + selects []SelectStatement + where BoolExpression + group *listClause + order *listClause + limit, offset int64 + // True if results of the union should be deduped. + unique bool +} + +func (s *unionStatementImpl) Query(db types.Db, destination interface{}) error { + return Query(s, db, destination) +} + +func (u *unionStatementImpl) Execute(db types.Db) (res sql.Result, err error) { + return Execute(u, db) +} + +func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement { + us.where = expression + return us +} + +// Further filter the query, instead of replacing the filter +func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement { + if us.where == nil { + return us.Where(expression) + } + us.where = And(us.where, expression) + return us +} + +func (us *unionStatementImpl) GroupBy( + expressions ...Expression) UnionStatement { + + us.group = &listClause{ + clauses: make([]Clause, len(expressions), len(expressions)), + includeParentheses: false, + } + + for i, e := range expressions { + us.group.clauses[i] = e + } + return us +} + +func (us *unionStatementImpl) OrderBy( + clauses ...OrderByClause) UnionStatement { + + us.order = newOrderByListClause(clauses...) + return us +} + +func (us *unionStatementImpl) Limit(limit int64) UnionStatement { + us.limit = limit + return us +} + +func (us *unionStatementImpl) Offset(offset int64) UnionStatement { + us.offset = offset + return us +} + +func (us *unionStatementImpl) String() (sql string, err error) { + if len(us.selects) == 0 { + return "", errors.Newf("Union statement must have at least one SELECT") + } + + if len(us.selects) == 1 { + return us.selects[0].String() + } + + // Union statements in MySQL require that the same number of columns in each subquery + var projections []Projection + + for _, statement := range us.selects { + // do a type assertion to get at the underlying struct + statementImpl, ok := statement.(*selectStatementImpl) + if !ok { + return "", errors.Newf( + "Expected inner select statement to be of type " + + "selectStatementImpl") + } + + // check that for limit for statements with order by clauses + if statementImpl.order != nil && statementImpl.limit < 0 { + return "", errors.Newf( + "All inner selects in Union statement must have LIMIT if " + + "they have ORDER BY") + } + + // check number of projections + if projections == nil { + projections = statementImpl.projections + } else { + if len(projections) != len(statementImpl.projections) { + return "", errors.Newf( + "All inner selects in Union statement must select the " + + "same number of columns. For sanity, you probably " + + "want to select the same tableName columns in the same " + + "order. If you are selecting on multiple tables, " + + "use Null to pad to the right number of fields.") + } + } + } + + buf := new(bytes.Buffer) + for i, statement := range us.selects { + if i != 0 { + if us.unique { + _, _ = buf.WriteString(" UNION ") + } else { + _, _ = buf.WriteString(" UNION ALL ") + } + } + _, _ = buf.WriteString("(") + selectSql, err := statement.String() + if err != nil { + return "", err + } + _, _ = buf.WriteString(selectSql) + _, _ = buf.WriteString(")") + } + + if us.where != nil { + _, _ = buf.WriteString(" WHERE ") + if err = us.where.SerializeSql(buf); err != nil { + return + } + } + + if us.group != nil { + _, _ = buf.WriteString(" GROUP BY ") + if err = us.group.SerializeSql(buf); err != nil { + return + } + } + + if us.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = us.order.SerializeSql(buf); err != nil { + return + } + } + + if us.limit >= 0 { + if us.offset >= 0 { + _, _ = buf.WriteString( + fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit)) + } else { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit)) + } + } + return buf.String(), nil +} diff --git a/sqlbuilder/update_statement.go b/sqlbuilder/update_statement.go index 5536b90..a35f731 100644 --- a/sqlbuilder/update_statement.go +++ b/sqlbuilder/update_statement.go @@ -4,7 +4,6 @@ import ( "bytes" "database/sql" "github.com/dropbox/godropbox/errors" - "github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution" "github.com/sub0zero/go-sqlbuilder/types" ) @@ -14,9 +13,6 @@ type UpdateStatement interface { SET(values ...interface{}) UpdateStatement WHERE(expression BoolExpression) UpdateStatement RETURNING(projections ...Projection) UpdateStatement - - Query(db types.Db, destination interface{}) error - Execute(db types.Db) (sql.Result, error) } func newUpdateStatement(table WritableTable, columns []Column) UpdateStatement { @@ -35,25 +31,11 @@ type updateStatementImpl struct { } func (u *updateStatementImpl) Query(db types.Db, destination interface{}) error { - query, err := u.String() - - if err != nil { - return err - } - - return execution.Execute(db, query, destination) + return Query(u, db, destination) } func (u *updateStatementImpl) Execute(db types.Db) (res sql.Result, err error) { - query, err := u.String() - - if err != nil { - return - } - - res, err = db.Exec(query) - - return + return Execute(u, db) } func (u *updateStatementImpl) SET(values ...interface{}) UpdateStatement { diff --git a/sqlbuilder/utils.go b/sqlbuilder/utils.go index b6db6cb..105665b 100644 --- a/sqlbuilder/utils.go +++ b/sqlbuilder/utils.go @@ -1,6 +1,11 @@ package sqlbuilder -import "bytes" +import ( + "bytes" + "database/sql" + "github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution" + "github.com/sub0zero/go-sqlbuilder/types" +) func serializeExpressionList(expressions []Expression, buf *bytes.Buffer) error { for i, value := range expressions { @@ -33,3 +38,25 @@ func serializeProjectionList(projections []Projection, buf *bytes.Buffer) error return nil } + +func Query(statement Statement, db types.Db, destination interface{}) error { + query, err := statement.String() + + if err != nil { + return err + } + + return execution.Execute(db, query, destination) +} + +func Execute(statement Statement, db types.Db) (res sql.Result, err error) { + query, err := statement.String() + + if err != nil { + return + } + + res, err = db.Exec(query) + + return +} diff --git a/tests/generator_test.go b/tests/generator_test.go index c600cef..23103eb 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -33,7 +33,7 @@ func TestSelect_ScanToStruct(t *testing.T) { assert.Equal(t, queryStr, `SELECT actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor ORDER BY actor.actor_id ASC`) - err = query.Execute(db, &actor) + err = query.Query(db, &actor) assert.NilError(t, err) @@ -57,7 +57,7 @@ func TestSelect_ScanToSlice(t *testing.T) { fmt.Println(queryStr) assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active" FROM dvds.customer ORDER BY customer.customer_id ASC`) - err = query.Execute(db, &customers) + err = query.Query(db, &customers) assert.NilError(t, err) assert.Equal(t, len(customers), 599) @@ -113,7 +113,7 @@ func TestJoinQuerySlice(t *testing.T) { assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id WHERE film.rating = 'NC-17' LIMIT 15`) //fmt.Println(queryStr) - err = query.Execute(db, &filmsPerLanguage) + err = query.Query(db, &filmsPerLanguage) assert.NilError(t, err) @@ -132,7 +132,7 @@ func TestJoinQuerySlice(t *testing.T) { //spew.Dump(filmsPerLanguage) filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} - err = query.Execute(db, &filmsPerLanguageWithPtrs) + err = query.Query(db, &filmsPerLanguageWithPtrs) assert.NilError(t, err) assert.Equal(t, len(filmsPerLanguage), 1) @@ -152,7 +152,7 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) { Limit(limit) filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} - err := query.Execute(db, &filmsPerLanguageWithPtrs) + err := query.Query(db, &filmsPerLanguageWithPtrs) //spew.Dump(filmsPerLanguageWithPtrs) @@ -166,7 +166,7 @@ func TestSelect_WithoutUniqueColumnSelected(t *testing.T) { customers := []model.Customer{} - err := query.Execute(db, &customers) + err := query.Query(db, &customers) assert.NilError(t, err) @@ -180,7 +180,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { err := Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). OrderBy(Customer.FirstName.Asc()). - Execute(db, &customersAsc) + Query(db, &customersAsc) assert.NilError(t, err) @@ -190,7 +190,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { customersDesc := []model.Customer{} err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). OrderBy(Customer.FirstName.Desc()). - Execute(db, &customersDesc) + Query(db, &customersDesc) assert.NilError(t, err) @@ -203,7 +203,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { customersAscDesc := []model.Customer{} err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). OrderBy(Customer.FirstName.Asc(), Customer.LastName.Desc()). - Execute(db, &customersAscDesc) + Query(db, &customersAscDesc) assert.NilError(t, err) @@ -240,7 +240,7 @@ func TestSelectFullJoin(t *testing.T) { Customer *model.Customer }{} - err = query.Execute(db, &allCustomersAndAddress) + err = query.Query(db, &allCustomersAndAddress) assert.NilError(t, err) assert.Equal(t, len(allCustomersAndAddress), 603) @@ -269,7 +269,7 @@ func TestSelectFullCrossJoin(t *testing.T) { customerAddresCrosJoined := []model.Customer{} - err = query.Execute(db, &customerAddresCrosJoined) + err = query.Query(db, &customerAddresCrosJoined) assert.Equal(t, len(customerAddresCrosJoined), 1000) @@ -302,7 +302,7 @@ func TestSelectSelfJoin(t *testing.T) { F2 F2 }{} - err = query.Execute(db, &theSameLengthFilms) + err = query.Query(db, &theSameLengthFilms) assert.NilError(t, err) @@ -337,7 +337,7 @@ func TestSelectAliasColumn(t *testing.T) { films := []thesameLengthFilms{} - err = query.Execute(db, &films) + err = query.Query(db, &films) assert.NilError(t, err) @@ -378,7 +378,7 @@ func TestSelectSelfReferenceType(t *testing.T) { staffs := []staff{} - err = query.Execute(db, &staffs) + err = query.Query(db, &staffs) assert.NilError(t, err) @@ -455,7 +455,7 @@ func TestSelectQueryScalar(t *testing.T) { fmt.Println(queryStr) maxRentalRateFilms := []model.Film{} - err = query.Execute(db, &maxRentalRateFilms) + err = query.Query(db, &maxRentalRateFilms) assert.NilError(t, err) @@ -505,7 +505,7 @@ func TestSelectGroupByHaving(t *testing.T) { customerPaymentSum := []CustomerPaymentSum{} - err = customersPaymentQuery.Execute(db, &customerPaymentSum) + err = customersPaymentQuery.Query(db, &customerPaymentSum) assert.NilError(t, err) @@ -542,7 +542,7 @@ func TestSelectGroupBy2(t *testing.T) { assert.NilError(t, err) fmt.Println(queryStr) - err = query.Execute(db, &customersWithAmounts) + err = query.Query(db, &customersWithAmounts) assert.NilError(t, err) //spew.Dump(customersWithAmounts) @@ -576,7 +576,7 @@ func TestSelectTimeColumns(t *testing.T) { payments := []model.Payment{} - err = query.Execute(db, &payments) + err = query.Query(db, &payments) assert.NilError(t, err) diff --git a/tests/insert_test.go b/tests/insert_test.go index 499db50..2695d97 100644 --- a/tests/insert_test.go +++ b/tests/insert_test.go @@ -36,7 +36,7 @@ func TestInsertValues(t *testing.T) { link := []model.Link{} - err = table.Link.SELECT(table.Link.AllColumns).Execute(db, &link) + err = table.Link.SELECT(table.Link.AllColumns).Query(db, &link) assert.NilError(t, err) @@ -103,7 +103,7 @@ func TestInsertQuery(t *testing.T) { assert.NilError(t, err) allLinks := []model.Link{} - err = table.Link.SELECT(table.Link.AllColumns).Execute(db, &allLinks) + err = table.Link.SELECT(table.Link.AllColumns).Query(db, &allLinks) assert.NilError(t, err) spew.Dump(allLinks) diff --git a/tests/sample_test.go b/tests/sample_test.go index f303a4b..f8ad90c 100644 --- a/tests/sample_test.go +++ b/tests/sample_test.go @@ -21,7 +21,7 @@ func TestUUIDType(t *testing.T) { //assert.Equal(t, queryStr, `SELECT all_types.character AS "all_types.character", all_types.character_varying AS "all_types.character_varying", all_types.text AS "all_types.text", all_types.bytea AS "all_types.bytea", all_types.timestamp_without_time_zone AS "all_types.timestamp_without_time_zone", all_types.timestamp_with_time_zone AS "all_types.timestamp_with_time_zone", all_types.uuid AS "all_types.uuid", all_types.json AS "all_types.json", all_types.jsonb AS "all_types.jsonb" FROM test_sample.all_types WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11`) result := model.AllTypes{} - err = query.Execute(db, &result) + err = query.Query(db, &result) spew.Dump(result) } @@ -36,7 +36,7 @@ func TestEnumType(t *testing.T) { result := []model.Person{} - err = query.Execute(db, &result) + err = query.Query(db, &result) assert.NilError(t, err) //spew.Dump(result) @@ -48,7 +48,7 @@ func TestEnumType(t *testing.T) { result2 := []Person{} - err = query.Execute(db, &result2) + err = query.Query(db, &result2) assert.NilError(t, err) diff --git a/tests/update_test.go b/tests/update_test.go index 73e8563..add2fb3 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -39,7 +39,7 @@ func TestUpdateValues(t *testing.T) { err = table.Link.SELECT(table.Link.AllColumns). Where(table.Link.Name.EqL("Bong")). - Execute(db, &links) + Query(db, &links) assert.NilError(t, err)