Add support for DELETE statements.

This commit is contained in:
zer0sub 2019-04-20 19:49:29 +02:00
parent 70d6f84375
commit bc6a2bbcac
14 changed files with 492 additions and 543 deletions

View file

@ -0,0 +1,70 @@
package sqlbuilder
import (
"bytes"
"database/sql"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types"
)
type DeleteStatement interface {
Statement
WHERE(expression BoolExpression) DeleteStatement
}
func newDeleteStatement(table WritableTable) DeleteStatement {
return &deleteStatementImpl{
table: table,
}
}
type deleteStatementImpl struct {
table WritableTable
where BoolExpression
order *listClause
}
func (u *deleteStatementImpl) Query(db types.Db, destination interface{}) error {
return Query(u, db, destination)
}
func (u *deleteStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return Execute(u, db)
}
func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
d.where = expression
return d
}
func (d *deleteStatementImpl) String() (sql string, err error) {
buf := new(bytes.Buffer)
_, _ = buf.WriteString("DELETE FROM ")
if d.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
}
if err = d.table.SerializeSql(buf); err != nil {
return
}
if d.where == nil {
return "", errors.Newf("Deleting without a WHERE clause. Generated sql: %s", buf.String())
}
_, _ = buf.WriteString(" WHERE ")
if err = d.where.SerializeSql(buf); err != nil {
return
}
if d.order != nil {
_, _ = buf.WriteString(" ORDER BY ")
if err = d.order.SerializeSql(buf); err != nil {
return
}
}
return buf.String() + ";", nil
}

View file

@ -0,0 +1,18 @@
package sqlbuilder
import (
"gotest.tools/assert"
"testing"
)
func TestDeleteUnconditionally(t *testing.T) {
_, err := table1.Delete().String()
assert.Assert(t, err != nil)
}
func TestDeleteWithWhere(t *testing.T) {
sql, err := table1.Delete().WHERE(table1Col1.EqL(1)).String()
assert.NilError(t, err)
assert.Equal(t, sql, "DELETE FROM db.table1 WHERE table1.col1 = 1;")
}

View file

@ -21,8 +21,6 @@ type InsertStatement interface {
RETURNING(projections ...Projection) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
Execute(db types.Db) (sql.Result, error)
}
func newInsertStatement(t WritableTable, columns ...Column) InsertStatement {
@ -47,16 +45,12 @@ type insertStatementImpl struct {
errors []string
}
func (i *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
query, err := i.String()
func (s *insertStatementImpl) Query(db types.Db, destination interface{}) error {
return Query(s, db, destination)
}
if err != nil {
return
}
res, err = db.Exec(query)
return
func (u *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return Execute(u, db)
}
func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement {

View file

@ -2,9 +2,9 @@ package sqlbuilder
import (
"bytes"
"database/sql"
"fmt"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
"github.com/sub0zero/go-sqlbuilder/types"
)
@ -13,7 +13,6 @@ type SelectStatement interface {
Expression
Where(expression BoolExpression) SelectStatement
AndWhere(expression BoolExpression) SelectStatement
GroupBy(expressions ...Expression) SelectStatement
HAVING(expressions BoolExpression) SelectStatement
@ -27,9 +26,6 @@ type SelectStatement interface {
Copy() SelectStatement
AsTable(alias string) *SelectStatementTable
Execute(db types.Db, destination interface{}) error
//ExecuteInTx(tx *sql.Tx, destination interface{}) error
}
// NOTE: SelectStatement purposely does not implement the Table interface since
@ -86,14 +82,12 @@ func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable {
}
}
func (s *selectStatementImpl) Execute(db types.Db, destination interface{}) error {
query, err := s.String()
func (s *selectStatementImpl) Query(db types.Db, destination interface{}) error {
return Query(s, db, destination)
}
if err != nil {
return err
}
return execution.Execute(db, query, destination)
func (u *selectStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return Execute(u, db)
}
func (s *selectStatementImpl) Copy() SelectStatement {
@ -101,17 +95,6 @@ func (s *selectStatementImpl) Copy() SelectStatement {
return &ret
}
// Further filter the query, instead of replacing the filter
func (q *selectStatementImpl) AndWhere(
expression BoolExpression) SelectStatement {
if q.where == nil {
return q.Where(expression)
}
q.where = And(q.where, expression)
return q
}
func (q *selectStatementImpl) Where(expression BoolExpression) SelectStatement {
q.where = expression
return q

View file

@ -3,7 +3,7 @@ package sqlbuilder
import (
"bytes"
"database/sql"
"fmt"
"github.com/sub0zero/go-sqlbuilder/types"
"regexp"
"github.com/dropbox/godropbox/errors"
@ -12,434 +12,150 @@ import (
type Statement interface {
// String returns generated SQL as string.
String() (sql string, err error)
}
// By default, rows selected by a UNION statement are out-of-order
// If you have an ORDER BY on an inner SELECT statement, the only thing
// it affects is the LIMIT clause on that inner statement (the ordering will
// still be out-of-order).
type UnionStatement interface {
Statement
// Warning! You cannot include tableName names for the next 4 clauses, or
// you'll get errors like:
// Table 'server_file_journal' from one of the SELECTs cannot be used in
// global ORDER clause
Where(expression BoolExpression) UnionStatement
AndWhere(expression BoolExpression) UnionStatement
GroupBy(expressions ...Expression) UnionStatement
OrderBy(clauses ...OrderByClause) UnionStatement
Limit(limit int64) UnionStatement
Offset(offset int64) UnionStatement
}
type DeleteStatement interface {
Statement
Where(expression BoolExpression) DeleteStatement
OrderBy(clauses ...OrderByClause) DeleteStatement
Limit(limit int64) DeleteStatement
Comment(comment string) DeleteStatement
Query(db types.Db, destination interface{}) error
Execute(db types.Db) (sql.Result, error)
}
// LockStatement is used to take Read/Write lock on tables.
// See http://dev.mysql.com/doc/refman/5.0/en/lock-tables.html
type LockStatement interface {
Statement
AddReadLock(table *Table) LockStatement
AddWriteLock(table *Table) LockStatement
}
// UnlockStatement can be used to release tableName locks taken using LockStatement.
// NOTE: You can not selectively release a lock and continue to hold lock on
// another tableName. UnlockStatement releases all the lock held in the current
// session.
type UnlockStatement interface {
Statement
}
// SetGtidNextStatement returns a SQL statement that can be used to explicitly set the next GTID.
type GtidNextStatement interface {
Statement
}
//type LockStatement interface {
// Statement
//
// UNION SELECT Statement ======================================================
// AddReadLock(table *Table) LockStatement
// AddWriteLock(table *Table) LockStatement
//}
//// UnlockStatement can be used to release tableName locks taken using LockStatement.
//// NOTE: You can not selectively release a lock and continue to hold lock on
//// another tableName. UnlockStatement releases all the lock held in the current
//// session.
//type UnlockStatement interface {
// Statement
//}
//
func Union(selects ...SelectStatement) UnionStatement {
return &unionStatementImpl{
selects: selects,
limit: -1,
offset: -1,
unique: true,
}
}
func UnionAll(selects ...SelectStatement) UnionStatement {
return &unionStatementImpl{
selects: selects,
limit: -1,
offset: -1,
unique: false,
}
}
// Similar to selectStatementImpl, but less complete
type unionStatementImpl struct {
selects []SelectStatement
where BoolExpression
group *listClause
order *listClause
limit, offset int64
// True if results of the union should be deduped.
unique bool
}
func (us *unionStatementImpl) Execute(db *sql.DB, data interface{}) error {
return nil
}
func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement {
us.where = expression
return us
}
// Further filter the query, instead of replacing the filter
func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement {
if us.where == nil {
return us.Where(expression)
}
us.where = And(us.where, expression)
return us
}
func (us *unionStatementImpl) GroupBy(
expressions ...Expression) UnionStatement {
us.group = &listClause{
clauses: make([]Clause, len(expressions), len(expressions)),
includeParentheses: false,
}
for i, e := range expressions {
us.group.clauses[i] = e
}
return us
}
func (us *unionStatementImpl) OrderBy(
clauses ...OrderByClause) UnionStatement {
us.order = newOrderByListClause(clauses...)
return us
}
func (us *unionStatementImpl) Limit(limit int64) UnionStatement {
us.limit = limit
return us
}
func (us *unionStatementImpl) Offset(offset int64) UnionStatement {
us.offset = offset
return us
}
func (us *unionStatementImpl) String() (sql string, err error) {
if len(us.selects) == 0 {
return "", errors.Newf("Union statement must have at least one SELECT")
}
if len(us.selects) == 1 {
return us.selects[0].String()
}
// Union statements in MySQL require that the same number of columns in each subquery
var projections []Projection
for _, statement := range us.selects {
// do a type assertion to get at the underlying struct
statementImpl, ok := statement.(*selectStatementImpl)
if !ok {
return "", errors.Newf(
"Expected inner select statement to be of type " +
"selectStatementImpl")
}
// check that for limit for statements with order by clauses
if statementImpl.order != nil && statementImpl.limit < 0 {
return "", errors.Newf(
"All inner selects in Union statement must have LIMIT if " +
"they have ORDER BY")
}
// check number of projections
if projections == nil {
projections = statementImpl.projections
} else {
if len(projections) != len(statementImpl.projections) {
return "", errors.Newf(
"All inner selects in Union statement must select the " +
"same number of columns. For sanity, you probably " +
"want to select the same tableName columns in the same " +
"order. If you are selecting on multiple tables, " +
"use Null to pad to the right number of fields.")
}
}
}
buf := new(bytes.Buffer)
for i, statement := range us.selects {
if i != 0 {
if us.unique {
_, _ = buf.WriteString(" UNION ")
} else {
_, _ = buf.WriteString(" UNION ALL ")
}
}
_, _ = buf.WriteString("(")
selectSql, err := statement.String()
if err != nil {
return "", err
}
_, _ = buf.WriteString(selectSql)
_, _ = buf.WriteString(")")
}
if us.where != nil {
_, _ = buf.WriteString(" WHERE ")
if err = us.where.SerializeSql(buf); err != nil {
return
}
}
if us.group != nil {
_, _ = buf.WriteString(" GROUP BY ")
if err = us.group.SerializeSql(buf); err != nil {
return
}
}
if us.order != nil {
_, _ = buf.WriteString(" ORDER BY ")
if err = us.order.SerializeSql(buf); err != nil {
return
}
}
if us.limit >= 0 {
if us.offset >= 0 {
_, _ = buf.WriteString(
fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit))
} else {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit))
}
}
return buf.String(), nil
}
//// SetGtidNextStatement returns a SQL statement that can be used to explicitly set the next GTID.
//type GtidNextStatement interface {
// Statement
//}
//
// DELETE statement ===========================================================
////
//// UNION SELECT Statement ======================================================
////
////
//// LOCK statement ===========================================================
////
//
func newDeleteStatement(table WritableTable) DeleteStatement {
return &deleteStatementImpl{
table: table,
limit: -1,
}
}
type deleteStatementImpl struct {
table WritableTable
where BoolExpression
order *listClause
limit int64
comment string
}
func (d *deleteStatementImpl) Execute(db *sql.DB, data interface{}) error {
return nil
}
func (d *deleteStatementImpl) Where(expression BoolExpression) DeleteStatement {
d.where = expression
return d
}
func (d *deleteStatementImpl) OrderBy(
clauses ...OrderByClause) DeleteStatement {
d.order = newOrderByListClause(clauses...)
return d
}
func (d *deleteStatementImpl) Limit(limit int64) DeleteStatement {
d.limit = limit
return d
}
func (d *deleteStatementImpl) Comment(comment string) DeleteStatement {
d.comment = comment
return d
}
func (d *deleteStatementImpl) String() (sql string, err error) {
buf := new(bytes.Buffer)
_, _ = buf.WriteString("DELETE FROM ")
if err = writeComment(d.comment, buf); err != nil {
return
}
if d.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
}
if err = d.table.SerializeSql(buf); err != nil {
return
}
if d.where == nil {
return "", errors.Newf(
"Deleting without a WHERE clause. Generated sql: %s",
buf.String())
}
_, _ = buf.WriteString(" WHERE ")
if err = d.where.SerializeSql(buf); err != nil {
return
}
if d.order != nil {
_, _ = buf.WriteString(" ORDER BY ")
if err = d.order.SerializeSql(buf); err != nil {
return
}
}
if d.limit >= 0 {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", d.limit))
}
return buf.String(), nil
}
//// NewLockStatement returns a SQL representing empty set of locks. You need to use
//// AddReadLock/AddWriteLock to add tables that need to be locked.
//// NOTE: You need at least one lock in the set for it to be a valid statement.
//func NewLockStatement() LockStatement {
// return &lockStatementImpl{}
//}
//
// LOCK statement ===========================================================
//type lockStatementImpl struct {
// locks []tableLock
//}
//
// NewLockStatement returns a SQL representing empty set of locks. You need to use
// AddReadLock/AddWriteLock to add tables that need to be locked.
// NOTE: You need at least one lock in the set for it to be a valid statement.
func NewLockStatement() LockStatement {
return &lockStatementImpl{}
}
type lockStatementImpl struct {
locks []tableLock
}
type tableLock struct {
t *Table
w bool
}
func (l *lockStatementImpl) Execute(db *sql.DB, data interface{}) error {
return nil
}
// AddReadLock takes read lock on the tableName.
func (s *lockStatementImpl) AddReadLock(t *Table) LockStatement {
s.locks = append(s.locks, tableLock{t: t, w: false})
return s
}
// AddWriteLock takes write lock on the tableName.
func (s *lockStatementImpl) AddWriteLock(t *Table) LockStatement {
s.locks = append(s.locks, tableLock{t: t, w: true})
return s
}
func (s *lockStatementImpl) String() (sql string, err error) {
if len(s.locks) == 0 {
return "", errors.New("No locks added")
}
buf := new(bytes.Buffer)
_, _ = buf.WriteString("LOCK TABLES ")
for idx, lock := range s.locks {
if lock.t == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
}
if err = lock.t.SerializeSql(buf); err != nil {
return
}
if lock.w {
_, _ = buf.WriteString(" WRITE")
} else {
_, _ = buf.WriteString(" READ")
}
if idx != len(s.locks)-1 {
_, _ = buf.WriteString(", ")
}
}
return buf.String(), nil
}
// NewUnlockStatement returns SQL statement that can be used to release tableName locks
// grabbed by the current session.
func NewUnlockStatement() UnlockStatement {
return &unlockStatementImpl{}
}
type unlockStatementImpl struct {
}
func (u *unlockStatementImpl) Execute(db *sql.DB, data interface{}) error {
return nil
}
func (s *unlockStatementImpl) String() (sql string, err error) {
return "UNLOCK TABLES", nil
}
// SET GTID_NEXT statement returns a SQL statement that can be used to explicitly set the next GTID.
func NewGtidNextStatement(sid []byte, gno uint64) GtidNextStatement {
return &gtidNextStatementImpl{
sid: sid,
gno: gno,
}
}
type gtidNextStatementImpl struct {
sid []byte
gno uint64
}
func (g *gtidNextStatementImpl) Execute(db *sql.DB, data interface{}) error {
return nil
}
func (s *gtidNextStatementImpl) String() (sql string, err error) {
// This statement sets a session local variable defining what the next transaction ID is. It
// does not interact with other MySQL sessions. It is neither a DDL nor DML statement, so we
// don't have to worry about data corruption.
// Because of the string formatting (hex plus an integer), can't morph into another statement.
// See: https://dev.mysql.com/doc/refman/5.7/en/replication-options-gtids.html
const gtidFormatString = "SET GTID_NEXT=\"%x-%x-%x-%x-%x:%d\""
buf := new(bytes.Buffer)
_, _ = buf.WriteString(fmt.Sprintf(gtidFormatString,
s.sid[:4], s.sid[4:6], s.sid[6:8], s.sid[8:10], s.sid[10:], s.gno))
return buf.String(), nil
}
//type tableLock struct {
// t *Table
// w bool
//}
//
//func (l *lockStatementImpl) Execute(db *sql.DB, data interface{}) error {
// return nil
//}
//
//// AddReadLock takes read lock on the tableName.
//func (s *lockStatementImpl) AddReadLock(t *Table) LockStatement {
// s.locks = append(s.locks, tableLock{t: t, w: false})
// return s
//}
//
//// AddWriteLock takes write lock on the tableName.
//func (s *lockStatementImpl) AddWriteLock(t *Table) LockStatement {
// s.locks = append(s.locks, tableLock{t: t, w: true})
// return s
//}
//
//func (s *lockStatementImpl) String() (sql string, err error) {
// if len(s.locks) == 0 {
// return "", errors.New("No locks added")
// }
//
// buf := new(bytes.Buffer)
// _, _ = buf.WriteString("LOCK TABLES ")
//
// for idx, lock := range s.locks {
// if lock.t == nil {
// return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
// }
//
// if err = lock.t.SerializeSql(buf); err != nil {
// return
// }
//
// if lock.w {
// _, _ = buf.WriteString(" WRITE")
// } else {
// _, _ = buf.WriteString(" READ")
// }
//
// if idx != len(s.locks)-1 {
// _, _ = buf.WriteString(", ")
// }
// }
//
// return buf.String(), nil
//}
//
//// NewUnlockStatement returns SQL statement that can be used to release tableName locks
//// grabbed by the current session.
//func NewUnlockStatement() UnlockStatement {
// return &unlockStatementImpl{}
//}
//
//type unlockStatementImpl struct {
//}
//
//func (u *unlockStatementImpl) Execute(db *sql.DB, data interface{}) error {
// return nil
//}
//
//func (s *unlockStatementImpl) String() (sql string, err error) {
// return "UNLOCK TABLES", nil
//}
//
//// SET GTID_NEXT statement returns a SQL statement that can be used to explicitly set the next GTID.
//func NewGtidNextStatement(sid []byte, gno uint64) GtidNextStatement {
// return &gtidNextStatementImpl{
// sid: sid,
// gno: gno,
// }
//}
//
//type gtidNextStatementImpl struct {
// sid []byte
// gno uint64
//}
//
//func (g *gtidNextStatementImpl) Execute(db *sql.DB, data interface{}) error {
// return nil
//}
//
//func (s *gtidNextStatementImpl) String() (sql string, err error) {
// // This statement sets a session local variable defining what the next transaction ID is. It
// // does not interact with other MySQL sessions. It is neither a DDL nor DML statement, so we
// // don't have to worry about data corruption.
// // Because of the string formatting (hex plus an integer), can't morph into another statement.
// // See: https://dev.mysql.com/doc/refman/5.7/en/replication-options-gtids.html
// const gtidFormatString = "SET GTID_NEXT=\"%x-%x-%x-%x-%x:%d\""
//
// buf := new(bytes.Buffer)
// _, _ = buf.WriteString(fmt.Sprintf(gtidFormatString,
// s.sid[:4], s.sid[4:6], s.sid[6:8], s.sid[8:10], s.sid[10:], s.gno))
// return buf.String(), nil
//}
//
// Util functions =============================================================

View file

@ -385,49 +385,6 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateMulti(c *gc.C) {
"ON DUPLICATE KEY UPDATE table1.col3=3, table1.col2=4")
}
//
// DELETE statement tests =====================================================
//
func (s *StmtSuite) TestDeleteUnconditionally(c *gc.C) {
_, err := table1.Delete().String()
c.Assert(err, gc.NotNil)
}
func (s *StmtSuite) TestDeleteWithWhere(c *gc.C) {
sql, err := table1.Delete().Where(EqL(table1Col1, 1)).String()
c.Assert(err, gc.IsNil)
c.Assert(
sql,
gc.Equals,
"DELETE FROM db.table1 WHERE table1.col1=1")
}
func (s *StmtSuite) TestDeleteWithOrderBy(c *gc.C) {
stmt := table1.Delete().Where(EqL(table1Col1, 1)).OrderBy(table1Col1)
sql, err := stmt.String()
c.Assert(err, gc.IsNil)
c.Assert(
sql,
gc.Equals,
"DELETE FROM db.table1 "+
"WHERE table1.col1=1 "+
"ORDER BY table1.col1")
}
func (s *StmtSuite) TestDeleteWithLimit(c *gc.C) {
stmt := table1.Delete().Where(EqL(table1Col1, 1)).Limit(5)
sql, err := stmt.String()
c.Assert(err, gc.IsNil)
c.Assert(
sql,
gc.Equals,
"DELETE FROM db.table1 WHERE table1.col1=1 LIMIT 5")
}
//
// LOCK/UNLOCK statement tests ================================================
//

View file

@ -155,10 +155,9 @@ func (t *Table) ForceIndex(index string) *Table {
// Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement.
func (t *Table) SerializeSql(out *bytes.Buffer) error {
if !validIdentifierName(t.schemaName) {
return errors.New("Invalid database name specified")
if t == nil {
return errors.Newf("nil tableName. Generated sql: %s", out.String())
}
_, _ = out.WriteString(t.schemaName)
_, _ = out.WriteString(".")
_, _ = out.WriteString(t.TableName())

View file

@ -0,0 +1,203 @@
package sqlbuilder
import (
"bytes"
"database/sql"
"fmt"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types"
)
// By default, rows selected by a UNION statement are out-of-order
// If you have an ORDER BY on an inner SELECT statement, the only thing
// it affects is the LIMIT clause on that inner statement (the ordering will
// still be out-of-order).
type UnionStatement interface {
Statement
// Warning! You cannot include tableName names for the next 4 clauses, or
// you'll get errors like:
// Table 'server_file_journal' from one of the SELECTs cannot be used in
// global ORDER clause
Where(expression BoolExpression) UnionStatement
GroupBy(expressions ...Expression) UnionStatement
OrderBy(clauses ...OrderByClause) UnionStatement
Limit(limit int64) UnionStatement
Offset(offset int64) UnionStatement
}
func Union(selects ...SelectStatement) UnionStatement {
return &unionStatementImpl{
selects: selects,
limit: -1,
offset: -1,
unique: true,
}
}
func UnionAll(selects ...SelectStatement) UnionStatement {
return &unionStatementImpl{
selects: selects,
limit: -1,
offset: -1,
unique: false,
}
}
// Similar to selectStatementImpl, but less complete
type unionStatementImpl struct {
selects []SelectStatement
where BoolExpression
group *listClause
order *listClause
limit, offset int64
// True if results of the union should be deduped.
unique bool
}
func (s *unionStatementImpl) Query(db types.Db, destination interface{}) error {
return Query(s, db, destination)
}
func (u *unionStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return Execute(u, db)
}
func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement {
us.where = expression
return us
}
// Further filter the query, instead of replacing the filter
func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement {
if us.where == nil {
return us.Where(expression)
}
us.where = And(us.where, expression)
return us
}
func (us *unionStatementImpl) GroupBy(
expressions ...Expression) UnionStatement {
us.group = &listClause{
clauses: make([]Clause, len(expressions), len(expressions)),
includeParentheses: false,
}
for i, e := range expressions {
us.group.clauses[i] = e
}
return us
}
func (us *unionStatementImpl) OrderBy(
clauses ...OrderByClause) UnionStatement {
us.order = newOrderByListClause(clauses...)
return us
}
func (us *unionStatementImpl) Limit(limit int64) UnionStatement {
us.limit = limit
return us
}
func (us *unionStatementImpl) Offset(offset int64) UnionStatement {
us.offset = offset
return us
}
func (us *unionStatementImpl) String() (sql string, err error) {
if len(us.selects) == 0 {
return "", errors.Newf("Union statement must have at least one SELECT")
}
if len(us.selects) == 1 {
return us.selects[0].String()
}
// Union statements in MySQL require that the same number of columns in each subquery
var projections []Projection
for _, statement := range us.selects {
// do a type assertion to get at the underlying struct
statementImpl, ok := statement.(*selectStatementImpl)
if !ok {
return "", errors.Newf(
"Expected inner select statement to be of type " +
"selectStatementImpl")
}
// check that for limit for statements with order by clauses
if statementImpl.order != nil && statementImpl.limit < 0 {
return "", errors.Newf(
"All inner selects in Union statement must have LIMIT if " +
"they have ORDER BY")
}
// check number of projections
if projections == nil {
projections = statementImpl.projections
} else {
if len(projections) != len(statementImpl.projections) {
return "", errors.Newf(
"All inner selects in Union statement must select the " +
"same number of columns. For sanity, you probably " +
"want to select the same tableName columns in the same " +
"order. If you are selecting on multiple tables, " +
"use Null to pad to the right number of fields.")
}
}
}
buf := new(bytes.Buffer)
for i, statement := range us.selects {
if i != 0 {
if us.unique {
_, _ = buf.WriteString(" UNION ")
} else {
_, _ = buf.WriteString(" UNION ALL ")
}
}
_, _ = buf.WriteString("(")
selectSql, err := statement.String()
if err != nil {
return "", err
}
_, _ = buf.WriteString(selectSql)
_, _ = buf.WriteString(")")
}
if us.where != nil {
_, _ = buf.WriteString(" WHERE ")
if err = us.where.SerializeSql(buf); err != nil {
return
}
}
if us.group != nil {
_, _ = buf.WriteString(" GROUP BY ")
if err = us.group.SerializeSql(buf); err != nil {
return
}
}
if us.order != nil {
_, _ = buf.WriteString(" ORDER BY ")
if err = us.order.SerializeSql(buf); err != nil {
return
}
}
if us.limit >= 0 {
if us.offset >= 0 {
_, _ = buf.WriteString(
fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit))
} else {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit))
}
}
return buf.String(), nil
}

View file

@ -4,7 +4,6 @@ import (
"bytes"
"database/sql"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
"github.com/sub0zero/go-sqlbuilder/types"
)
@ -14,9 +13,6 @@ type UpdateStatement interface {
SET(values ...interface{}) UpdateStatement
WHERE(expression BoolExpression) UpdateStatement
RETURNING(projections ...Projection) UpdateStatement
Query(db types.Db, destination interface{}) error
Execute(db types.Db) (sql.Result, error)
}
func newUpdateStatement(table WritableTable, columns []Column) UpdateStatement {
@ -35,25 +31,11 @@ type updateStatementImpl struct {
}
func (u *updateStatementImpl) Query(db types.Db, destination interface{}) error {
query, err := u.String()
if err != nil {
return err
}
return execution.Execute(db, query, destination)
return Query(u, db, destination)
}
func (u *updateStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
query, err := u.String()
if err != nil {
return
}
res, err = db.Exec(query)
return
return Execute(u, db)
}
func (u *updateStatementImpl) SET(values ...interface{}) UpdateStatement {

View file

@ -1,6 +1,11 @@
package sqlbuilder
import "bytes"
import (
"bytes"
"database/sql"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
"github.com/sub0zero/go-sqlbuilder/types"
)
func serializeExpressionList(expressions []Expression, buf *bytes.Buffer) error {
for i, value := range expressions {
@ -33,3 +38,25 @@ func serializeProjectionList(projections []Projection, buf *bytes.Buffer) error
return nil
}
func Query(statement Statement, db types.Db, destination interface{}) error {
query, err := statement.String()
if err != nil {
return err
}
return execution.Execute(db, query, destination)
}
func Execute(statement Statement, db types.Db) (res sql.Result, err error) {
query, err := statement.String()
if err != nil {
return
}
res, err = db.Exec(query)
return
}