Add support for UNION statements

This commit is contained in:
zer0sub 2019-05-01 14:42:46 +02:00
parent fef8f0ef83
commit 8a3521a016
8 changed files with 140 additions and 219 deletions

View file

@ -11,6 +11,7 @@ type serializeOption int
const ( const (
SKIP_DEFAULT_ALIASING = iota SKIP_DEFAULT_ALIASING = iota
FOR_PROJECTION FOR_PROJECTION
NO_TABLE_NAME
) )
type Clause interface { type Clause interface {
@ -18,27 +19,27 @@ type Clause interface {
} }
type queryData struct { type queryData struct {
queryBuff bytes.Buffer buff bytes.Buffer
args []interface{} args []interface{}
} }
func (q *queryData) Write(data []byte) { func (q *queryData) Write(data []byte) {
q.queryBuff.Write(data) q.buff.Write(data)
} }
func (q *queryData) WriteString(str string) { func (q *queryData) WriteString(str string) {
q.queryBuff.WriteString(str) q.buff.WriteString(str)
} }
func (q *queryData) WriteByte(b byte) { func (q *queryData) WriteByte(b byte) {
q.queryBuff.WriteByte(b) q.buff.WriteByte(b)
} }
func (q *queryData) InsertArgument(arg interface{}) { func (q *queryData) InsertArgument(arg interface{}) {
q.args = append(q.args, arg) q.args = append(q.args, arg)
argPlaceholder := "$" + strconv.Itoa(len(q.args)) argPlaceholder := "$" + strconv.Itoa(len(q.args))
q.queryBuff.WriteString(argPlaceholder) q.buff.WriteString(argPlaceholder)
} }
func argToString(value interface{}) (string, error) { func argToString(value interface{}) (string, error) {

View file

@ -73,7 +73,7 @@ func (c *baseColumn) setTableName(table string) error {
} }
func (c baseColumn) Serialize(out *queryData, options ...serializeOption) error { func (c baseColumn) Serialize(out *queryData, options ...serializeOption) error {
if c.tableName != "" { if c.tableName != "" && !contains(options, NO_TABLE_NAME) {
out.WriteString(c.tableName) out.WriteString(c.tableName)
out.WriteString(".") out.WriteString(".")
} }

View file

@ -67,5 +67,5 @@ func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error
} }
} }
return queryData.queryBuff.String() + ";", queryData.args, nil return queryData.buff.String() + ";", queryData.args, nil
} }

View file

@ -226,5 +226,5 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
queryData.WriteByte(';') queryData.WriteByte(';')
return queryData.queryBuff.String(), queryData.args, nil return queryData.buff.String(), queryData.args, nil
} }

View file

@ -21,8 +21,6 @@ type SelectStatement interface {
FOR_UPDATE() SelectStatement FOR_UPDATE() SelectStatement
Copy() SelectStatement
AsTable(alias string) *SelectStatementTable AsTable(alias string) *SelectStatementTable
} }
@ -159,7 +157,7 @@ func (q *selectStatementImpl) Sql() (query string, args []interface{}, err error
return "", nil, err return "", nil, err
} }
return queryData.queryBuff.String(), queryData.args, nil return queryData.buff.String(), queryData.args, nil
} }
func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable { func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable {
@ -169,19 +167,6 @@ func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable {
} }
} }
func (s *selectStatementImpl) Query(db types.Db, destination interface{}) error {
return Query(s, db, destination)
}
func (u *selectStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return Execute(u, db)
}
func (s *selectStatementImpl) Copy() SelectStatement {
ret := *s
return &ret
}
func (q *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement { func (q *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement {
q.where = expression q.where = expression
return q return q
@ -224,6 +209,14 @@ func (q *selectStatementImpl) FOR_UPDATE() SelectStatement {
return q return q
} }
func (s *selectStatementImpl) Query(db types.Db, destination interface{}) error {
return Query(s, db, destination)
}
func (u *selectStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return Execute(u, db)
}
func NumExp(statement SelectStatement) NumericExpression { func NumExp(statement SelectStatement) NumericExpression {
return newNumericExpressionWrap(statement) return newNumericExpressionWrap(statement)
} }

View file

@ -463,8 +463,8 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) {
q := Union(select_queries...) q := Union(select_queries...)
q = q.Where(And(LtL(table1Col1, 1000), GtL(table1Col1, 15))) q = q.Where(And(LtL(table1Col1, 1000), GtL(table1Col1, 15)))
q = q.OrderBy(Desc(table1Col4), Asc(table1Col3)) q = q.ORDER_BY(Desc(table1Col4), Asc(table1Col3))
q = q.Limit(5) q = q.LIMIT(5)
_, err := q.String() _, err := q.String()
@ -502,8 +502,8 @@ func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) {
q := Union(select_queries...) q := Union(select_queries...)
q = q.Where(And(LtL(table1Col1, 1000), GtL(table1Col1, 15))) q = q.Where(And(LtL(table1Col1, 1000), GtL(table1Col1, 15)))
q = q.OrderBy(Desc(table1Col4), Asc(table1Col3)) q = q.ORDER_BY(Desc(table1Col4), Asc(table1Col3))
q = q.Limit(5) q = q.LIMIT(5)
q = q.GroupBy(table1Col4) q = q.GroupBy(table1Col4)
sql, err := q.String() sql, err := q.String()

View file

@ -1,196 +1,123 @@
package sqlbuilder package sqlbuilder
// By default, rows selected by a UNION statement are out-of-orderBy import (
// If you have an ORDER BY on an inner SELECT statement, the only thing "database/sql"
// it affects is the LIMIT clause on that inner statement (the ordering will "github.com/dropbox/godropbox/errors"
// still be out-of-orderBy). "github.com/sub0zero/go-sqlbuilder/types"
)
type UnionStatement interface { type UnionStatement interface {
Statement Statement
// Warning! You cannot include tableName names for the next 4 clauses, or ORDER_BY(clauses ...OrderByClause) UnionStatement
// you'll get errors like: LIMIT(limit int64) UnionStatement
// Table 'server_file_journal' from one of the SELECTs cannot be used in OFFSET(offset int64) UnionStatement
// global ORDER clause
Where(expression BoolExpression) UnionStatement
GroupBy(expressions ...Expression) UnionStatement
OrderBy(clauses ...OrderByClause) UnionStatement
Limit(limit int64) UnionStatement
Offset(offset int64) UnionStatement
} }
// func Union(selects ...SelectStatement) UnionStatement {
//func Union(selects ...SelectStatement) UnionStatement { return &unionStatementImpl{
// return &unionStatementImpl{ selects: selects,
// selects: selects, limit: -1,
// limit: -1, offset: -1,
// offset: -1, all: true,
// unique: true, }
// } }
//}
// func UnionAll(selects ...SelectStatement) UnionStatement {
//func UnionAll(selects ...SelectStatement) UnionStatement { return &unionStatementImpl{
// return &unionStatementImpl{ selects: selects,
// selects: selects, limit: -1,
// limit: -1, offset: -1,
// offset: -1, all: false,
// unique: false, }
// } }
//}
// // Similar to selectStatementImpl, but less complete
//// Similar to selectStatementImpl, but less complete type unionStatementImpl struct {
//type unionStatementImpl struct { selects []SelectStatement
// selects []SelectStatement order *listClause
// where BoolExpression limit, offset int64
// group *listClause // True if results of the union should be deduped.
// order *listClause all bool
// limit, offset int64 }
// // True if results of the union should be deduped.
// unique bool func (us *unionStatementImpl) Serialize(out *queryData, options ...serializeOption) error {
//} if len(us.selects) == 0 {
// return errors.Newf("Union statement must have at least one SELECT")
//func (s *unionStatementImpl) Query(db types.Db, destination interface{}) error { }
// return Query(s, db, destination)
//} out.WriteString("(")
//
//func (u *unionStatementImpl) Execute(db types.Db) (res sql.Result, err error) { for i, selectStmt := range us.selects {
// return Execute(u, db) if i > 0 {
//} out.WriteString(" UNION ")
//
//func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement { if us.all {
// us.where = expression out.WriteString(" ALL ")
// return us }
//} }
//
//// Further filter the query, instead of replacing the filter err := selectStmt.Serialize(out, options...)
//func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement {
// if us.where == nil { if err != nil {
// return us.Where(expression) return err
// } }
// us.where = And(us.where, expression) }
// return us
//} out.WriteString(")")
//
//func (us *unionStatementImpl) GroupBy( if us.order != nil {
// expressions ...Expression) UnionStatement { out.WriteString(" ORDER BY ")
// if err := us.order.Serialize(out, NO_TABLE_NAME); err != nil {
// us.group = &listClause{ return err
// clauses: make([]Clause, len(expressions), len(expressions)), }
// includeParentheses: false, }
// }
// if us.limit >= 0 {
// for i, e := range expressions { out.WriteString(" LIMIT ")
// us.group.clauses[i] = e out.InsertArgument(us.limit)
// } }
// return us
//} if us.offset >= 0 {
// out.WriteString(" OFFSET ")
//func (us *unionStatementImpl) OrderBy( out.InsertArgument(us.offset)
// clauses ...OrderByClause) UnionStatement { }
//
// us.order = newOrderByListClause(clauses...) return nil
// return us }
//}
// func (us *unionStatementImpl) ORDER_BY(clauses ...OrderByClause) UnionStatement {
//func (us *unionStatementImpl) Limit(limit int64) UnionStatement {
// us.limit = limit us.order = newOrderByListClause(clauses...)
// return us return us
//} }
//
//func (us *unionStatementImpl) Offset(offset int64) UnionStatement { func (us *unionStatementImpl) LIMIT(limit int64) UnionStatement {
// us.offset = offset us.limit = limit
// return us return us
//} }
//
//func (us *unionStatementImpl) String() (sql string, err error) { func (us *unionStatementImpl) OFFSET(offset int64) UnionStatement {
// if len(us.selects) == 0 { us.offset = offset
// return "", errors.Newf("Union statement must have at least one SELECT") return us
// } }
//
// if len(us.selects) == 1 { func (us *unionStatementImpl) Sql() (query string, args []interface{}, err error) {
// return us.selects[0].String() queryData := &queryData{}
// }
// err = us.Serialize(queryData)
// // Union statements in MySQL require that the same number of columns in each subquery
// var projections []Projection if err != nil {
// return
// for _, statement := range us.selects { }
// // do a type assertion to get at the underlying struct
// statementImpl, ok := statement.(*selectStatementImpl) return queryData.buff.String(), queryData.args, nil
// if !ok { }
// return "", errors.Newf(
// "Expected inner select statement to be of type " + func (s *unionStatementImpl) Query(db types.Db, destination interface{}) error {
// "selectStatementImpl") return Query(s, db, destination)
// } }
//
// // check that for limit for statements with orderBy by clauses func (u *unionStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
// if statementImpl.orderBy != nil && statementImpl.limit < 0 { return Execute(u, db)
// return "", errors.Newf( }
// "All inner selects in Union statement must have LIMIT if " +
// "they have ORDER BY")
// }
//
// // check number of projections
// if projections == nil {
// projections = statementImpl.projections
// } else {
// if len(projections) != len(statementImpl.projections) {
// return "", errors.Newf(
// "All inner selects in Union statement must select the " +
// "same number of columns. For sanity, you probably " +
// "want to select the same tableName columns in the same " +
// "orderBy. If you are selecting on multiple tables, " +
// "use Null to pad to the right number of fields.")
// }
// }
// }
//
// buf := new(bytes.Buffer)
// for i, statement := range us.selects {
// if i != 0 {
// if us.unique {
// _, _ = buf.WriteString(" UNION ")
// } else {
// _, _ = buf.WriteString(" UNION ALL ")
// }
// }
// _, _ = buf.WriteString("(")
// selectSql, err := statement.String()
// if err != nil {
// return "", err
// }
// _, _ = buf.WriteString(selectSql)
// _, _ = buf.WriteString(")")
// }
//
// if us.where != nil {
// _, _ = buf.WriteString(" WHERE ")
// if err = us.where.Serialize(buf); err != nil {
// return
// }
// }
//
// if us.group != nil {
// _, _ = buf.WriteString(" GROUP BY ")
// if err = us.group.Serialize(buf); err != nil {
// return
// }
// }
//
// if us.order != nil {
// _, _ = buf.WriteString(" ORDER BY ")
// if err = us.order.Serialize(buf); err != nil {
// return
// }
// }
//
// if us.limit >= 0 {
// if us.offset >= 0 {
// _, _ = buf.WriteString(
// fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit))
// } else {
// _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit))
// }
// }
// return buf.String(), nil
//}

View file

@ -147,5 +147,5 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
} }
} }
return out.queryBuff.String(), out.args, nil return out.buff.String(), out.args, nil
} }