diff --git a/sqlbuilder/clause.go b/sqlbuilder/clause.go index dad2f57..224c419 100644 --- a/sqlbuilder/clause.go +++ b/sqlbuilder/clause.go @@ -11,6 +11,7 @@ type serializeOption int const ( SKIP_DEFAULT_ALIASING = iota FOR_PROJECTION + NO_TABLE_NAME ) type Clause interface { @@ -18,27 +19,27 @@ type Clause interface { } type queryData struct { - queryBuff bytes.Buffer - args []interface{} + buff bytes.Buffer + args []interface{} } func (q *queryData) Write(data []byte) { - q.queryBuff.Write(data) + q.buff.Write(data) } func (q *queryData) WriteString(str string) { - q.queryBuff.WriteString(str) + q.buff.WriteString(str) } func (q *queryData) WriteByte(b byte) { - q.queryBuff.WriteByte(b) + q.buff.WriteByte(b) } func (q *queryData) InsertArgument(arg interface{}) { q.args = append(q.args, arg) argPlaceholder := "$" + strconv.Itoa(len(q.args)) - q.queryBuff.WriteString(argPlaceholder) + q.buff.WriteString(argPlaceholder) } func argToString(value interface{}) (string, error) { diff --git a/sqlbuilder/column.go b/sqlbuilder/column.go index 5416673..c0a75ea 100644 --- a/sqlbuilder/column.go +++ b/sqlbuilder/column.go @@ -73,7 +73,7 @@ func (c *baseColumn) setTableName(table string) error { } func (c baseColumn) Serialize(out *queryData, options ...serializeOption) error { - if c.tableName != "" { + if c.tableName != "" && !contains(options, NO_TABLE_NAME) { out.WriteString(c.tableName) out.WriteString(".") } diff --git a/sqlbuilder/delete_statement.go b/sqlbuilder/delete_statement.go index 8bfccce..4617f8f 100644 --- a/sqlbuilder/delete_statement.go +++ b/sqlbuilder/delete_statement.go @@ -67,5 +67,5 @@ func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error } } - return queryData.queryBuff.String() + ";", queryData.args, nil + return queryData.buff.String() + ";", queryData.args, nil } diff --git a/sqlbuilder/insert_statement.go b/sqlbuilder/insert_statement.go index 6607b3b..48ed991 100644 --- a/sqlbuilder/insert_statement.go +++ b/sqlbuilder/insert_statement.go @@ -226,5 +226,5 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) queryData.WriteByte(';') - return queryData.queryBuff.String(), queryData.args, nil + return queryData.buff.String(), queryData.args, nil } diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 9fc96a8..14a3fbc 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -21,8 +21,6 @@ type SelectStatement interface { FOR_UPDATE() SelectStatement - Copy() SelectStatement - AsTable(alias string) *SelectStatementTable } @@ -159,7 +157,7 @@ func (q *selectStatementImpl) Sql() (query string, args []interface{}, err error return "", nil, err } - return queryData.queryBuff.String(), queryData.args, nil + return queryData.buff.String(), queryData.args, nil } func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable { @@ -169,19 +167,6 @@ func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable { } } -func (s *selectStatementImpl) Query(db types.Db, destination interface{}) error { - return Query(s, db, destination) -} - -func (u *selectStatementImpl) Execute(db types.Db) (res sql.Result, err error) { - return Execute(u, db) -} - -func (s *selectStatementImpl) Copy() SelectStatement { - ret := *s - return &ret -} - func (q *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement { q.where = expression return q @@ -224,6 +209,14 @@ func (q *selectStatementImpl) FOR_UPDATE() SelectStatement { return q } +func (s *selectStatementImpl) Query(db types.Db, destination interface{}) error { + return Query(s, db, destination) +} + +func (u *selectStatementImpl) Execute(db types.Db) (res sql.Result, err error) { + return Execute(u, db) +} + func NumExp(statement SelectStatement) NumericExpression { return newNumericExpressionWrap(statement) } diff --git a/sqlbuilder/statement_test.go b/sqlbuilder/statement_test.go index ba0d431..cc8239c 100644 --- a/sqlbuilder/statement_test.go +++ b/sqlbuilder/statement_test.go @@ -463,8 +463,8 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) { q := Union(select_queries...) q = q.Where(And(LtL(table1Col1, 1000), GtL(table1Col1, 15))) - q = q.OrderBy(Desc(table1Col4), Asc(table1Col3)) - q = q.Limit(5) + q = q.ORDER_BY(Desc(table1Col4), Asc(table1Col3)) + q = q.LIMIT(5) _, err := q.String() @@ -502,8 +502,8 @@ func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) { q := Union(select_queries...) q = q.Where(And(LtL(table1Col1, 1000), GtL(table1Col1, 15))) - q = q.OrderBy(Desc(table1Col4), Asc(table1Col3)) - q = q.Limit(5) + q = q.ORDER_BY(Desc(table1Col4), Asc(table1Col3)) + q = q.LIMIT(5) q = q.GroupBy(table1Col4) sql, err := q.String() diff --git a/sqlbuilder/union_statement.go b/sqlbuilder/union_statement.go index 41c26f0..1555ca6 100644 --- a/sqlbuilder/union_statement.go +++ b/sqlbuilder/union_statement.go @@ -1,196 +1,123 @@ package sqlbuilder -// By default, rows selected by a UNION statement are out-of-orderBy -// If you have an ORDER BY on an inner SELECT statement, the only thing -// it affects is the LIMIT clause on that inner statement (the ordering will -// still be out-of-orderBy). +import ( + "database/sql" + "github.com/dropbox/godropbox/errors" + "github.com/sub0zero/go-sqlbuilder/types" +) + type UnionStatement interface { Statement - // Warning! You cannot include tableName names for the next 4 clauses, or - // you'll get errors like: - // Table 'server_file_journal' from one of the SELECTs cannot be used in - // global ORDER clause - Where(expression BoolExpression) UnionStatement - GroupBy(expressions ...Expression) UnionStatement - OrderBy(clauses ...OrderByClause) UnionStatement - - Limit(limit int64) UnionStatement - Offset(offset int64) UnionStatement + ORDER_BY(clauses ...OrderByClause) UnionStatement + LIMIT(limit int64) UnionStatement + OFFSET(offset int64) UnionStatement } -// -//func Union(selects ...SelectStatement) UnionStatement { -// return &unionStatementImpl{ -// selects: selects, -// limit: -1, -// offset: -1, -// unique: true, -// } -//} -// -//func UnionAll(selects ...SelectStatement) UnionStatement { -// return &unionStatementImpl{ -// selects: selects, -// limit: -1, -// offset: -1, -// unique: false, -// } -//} -// -//// Similar to selectStatementImpl, but less complete -//type unionStatementImpl struct { -// selects []SelectStatement -// where BoolExpression -// group *listClause -// order *listClause -// limit, offset int64 -// // True if results of the union should be deduped. -// unique bool -//} -// -//func (s *unionStatementImpl) Query(db types.Db, destination interface{}) error { -// return Query(s, db, destination) -//} -// -//func (u *unionStatementImpl) Execute(db types.Db) (res sql.Result, err error) { -// return Execute(u, db) -//} -// -//func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement { -// us.where = expression -// return us -//} -// -//// Further filter the query, instead of replacing the filter -//func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement { -// if us.where == nil { -// return us.Where(expression) -// } -// us.where = And(us.where, expression) -// return us -//} -// -//func (us *unionStatementImpl) GroupBy( -// expressions ...Expression) UnionStatement { -// -// us.group = &listClause{ -// clauses: make([]Clause, len(expressions), len(expressions)), -// includeParentheses: false, -// } -// -// for i, e := range expressions { -// us.group.clauses[i] = e -// } -// return us -//} -// -//func (us *unionStatementImpl) OrderBy( -// clauses ...OrderByClause) UnionStatement { -// -// us.order = newOrderByListClause(clauses...) -// return us -//} -// -//func (us *unionStatementImpl) Limit(limit int64) UnionStatement { -// us.limit = limit -// return us -//} -// -//func (us *unionStatementImpl) Offset(offset int64) UnionStatement { -// us.offset = offset -// return us -//} -// -//func (us *unionStatementImpl) String() (sql string, err error) { -// if len(us.selects) == 0 { -// return "", errors.Newf("Union statement must have at least one SELECT") -// } -// -// if len(us.selects) == 1 { -// return us.selects[0].String() -// } -// -// // Union statements in MySQL require that the same number of columns in each subquery -// var projections []Projection -// -// for _, statement := range us.selects { -// // do a type assertion to get at the underlying struct -// statementImpl, ok := statement.(*selectStatementImpl) -// if !ok { -// return "", errors.Newf( -// "Expected inner select statement to be of type " + -// "selectStatementImpl") -// } -// -// // check that for limit for statements with orderBy by clauses -// if statementImpl.orderBy != nil && statementImpl.limit < 0 { -// return "", errors.Newf( -// "All inner selects in Union statement must have LIMIT if " + -// "they have ORDER BY") -// } -// -// // check number of projections -// if projections == nil { -// projections = statementImpl.projections -// } else { -// if len(projections) != len(statementImpl.projections) { -// return "", errors.Newf( -// "All inner selects in Union statement must select the " + -// "same number of columns. For sanity, you probably " + -// "want to select the same tableName columns in the same " + -// "orderBy. If you are selecting on multiple tables, " + -// "use Null to pad to the right number of fields.") -// } -// } -// } -// -// buf := new(bytes.Buffer) -// for i, statement := range us.selects { -// if i != 0 { -// if us.unique { -// _, _ = buf.WriteString(" UNION ") -// } else { -// _, _ = buf.WriteString(" UNION ALL ") -// } -// } -// _, _ = buf.WriteString("(") -// selectSql, err := statement.String() -// if err != nil { -// return "", err -// } -// _, _ = buf.WriteString(selectSql) -// _, _ = buf.WriteString(")") -// } -// -// if us.where != nil { -// _, _ = buf.WriteString(" WHERE ") -// if err = us.where.Serialize(buf); err != nil { -// return -// } -// } -// -// if us.group != nil { -// _, _ = buf.WriteString(" GROUP BY ") -// if err = us.group.Serialize(buf); err != nil { -// return -// } -// } -// -// if us.order != nil { -// _, _ = buf.WriteString(" ORDER BY ") -// if err = us.order.Serialize(buf); err != nil { -// return -// } -// } -// -// if us.limit >= 0 { -// if us.offset >= 0 { -// _, _ = buf.WriteString( -// fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit)) -// } else { -// _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit)) -// } -// } -// return buf.String(), nil -//} +func Union(selects ...SelectStatement) UnionStatement { + return &unionStatementImpl{ + selects: selects, + limit: -1, + offset: -1, + all: true, + } +} + +func UnionAll(selects ...SelectStatement) UnionStatement { + return &unionStatementImpl{ + selects: selects, + limit: -1, + offset: -1, + all: false, + } +} + +// Similar to selectStatementImpl, but less complete +type unionStatementImpl struct { + selects []SelectStatement + order *listClause + limit, offset int64 + // True if results of the union should be deduped. + all bool +} + +func (us *unionStatementImpl) Serialize(out *queryData, options ...serializeOption) error { + if len(us.selects) == 0 { + return errors.Newf("Union statement must have at least one SELECT") + } + + out.WriteString("(") + + for i, selectStmt := range us.selects { + if i > 0 { + out.WriteString(" UNION ") + + if us.all { + out.WriteString(" ALL ") + } + } + + err := selectStmt.Serialize(out, options...) + + if err != nil { + return err + } + } + + out.WriteString(")") + + if us.order != nil { + out.WriteString(" ORDER BY ") + if err := us.order.Serialize(out, NO_TABLE_NAME); err != nil { + return err + } + } + + if us.limit >= 0 { + out.WriteString(" LIMIT ") + out.InsertArgument(us.limit) + } + + if us.offset >= 0 { + out.WriteString(" OFFSET ") + out.InsertArgument(us.offset) + } + + return nil +} + +func (us *unionStatementImpl) ORDER_BY(clauses ...OrderByClause) UnionStatement { + + us.order = newOrderByListClause(clauses...) + return us +} + +func (us *unionStatementImpl) LIMIT(limit int64) UnionStatement { + us.limit = limit + return us +} + +func (us *unionStatementImpl) OFFSET(offset int64) UnionStatement { + us.offset = offset + return us +} + +func (us *unionStatementImpl) Sql() (query string, args []interface{}, err error) { + queryData := &queryData{} + + err = us.Serialize(queryData) + + if err != nil { + return + } + + return queryData.buff.String(), queryData.args, nil +} + +func (s *unionStatementImpl) Query(db types.Db, destination interface{}) error { + return Query(s, db, destination) +} + +func (u *unionStatementImpl) Execute(db types.Db) (res sql.Result, err error) { + return Execute(u, db) +} diff --git a/sqlbuilder/update_statement.go b/sqlbuilder/update_statement.go index 41f6687..662f9f6 100644 --- a/sqlbuilder/update_statement.go +++ b/sqlbuilder/update_statement.go @@ -147,5 +147,5 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) } } - return out.queryBuff.String(), out.args, nil + return out.buff.String(), out.args, nil }