Support for subqueries, Group By and Having clause.

This commit is contained in:
zer0sub 2019-03-30 10:17:32 +01:00
parent ddf816c998
commit 5a7563d4af
12 changed files with 674 additions and 305 deletions

View file

@ -5,26 +5,29 @@ package sqlbuilder
import ( import (
"bytes" "bytes"
"regexp" "regexp"
"strings"
"github.com/dropbox/godropbox/errors" "github.com/dropbox/godropbox/errors"
) )
// XXX: Maybe add UIntColumn // XXX: Maybe add UIntColumn
// Representation of a table for query generation // Representation of a tableName for query generation
type Column interface { type Column interface {
isProjectionInterface isProjectionInterface
isExpressionInterface isExpressionInterface
As(alias string) Column As(alias string) Projection
Name() string Name() string
TableName() string
// Serialization for use in column lists // Serialization for use in column lists
SerializeSqlForColumnList(out *bytes.Buffer) error SerializeSqlForColumnList(out *bytes.Buffer) error
// Serialization for use in an expression (Clause) // Serialization for use in an expression (Clause)
SerializeSql(out *bytes.Buffer) error SerializeSql(out *bytes.Buffer) error
// Internal function for tracking table that a column belongs to // Internal function for tracking tableName that a column belongs to
// for the purpose of serialization // for the purpose of serialization
setTableName(table string) error setTableName(table string) error
@ -73,13 +76,13 @@ const (
type baseColumn struct { type baseColumn struct {
isProjection isProjection
isExpression isExpression
name string name string
nullable NullableColumn nullable NullableColumn
table string tableName string
alias string alias string
} }
func (c *baseColumn) As(alias string) Column { func (c *baseColumn) As(alias string) Projection {
newBaseColumn := *c newBaseColumn := *c
newBaseColumn.alias = alias newBaseColumn.alias = alias
@ -90,8 +93,12 @@ func (c *baseColumn) Name() string {
return c.name return c.name
} }
func (c *baseColumn) TableName() string {
return c.tableName
}
func (c *baseColumn) setTableName(table string) error { func (c *baseColumn) setTableName(table string) error {
c.table = table c.tableName = table
return nil return nil
} }
@ -101,19 +108,27 @@ func (c *baseColumn) SerializeSqlForColumnList(out *bytes.Buffer) error {
if c.alias != "" { if c.alias != "" {
_, _ = out.WriteString(" AS \"" + c.alias + "\"") _, _ = out.WriteString(" AS \"" + c.alias + "\"")
} else if c.table != "" { } else if c.tableName != "" {
_, _ = out.WriteString(" AS \"" + c.table + "." + c.name + "\"") _, _ = out.WriteString(" AS \"" + c.tableName + "." + c.name + "\"")
} }
return nil return nil
} }
func (c baseColumn) SerializeSql(out *bytes.Buffer) error { func (c baseColumn) SerializeSql(out *bytes.Buffer) error {
if c.table != "" { if c.tableName != "" {
_, _ = out.WriteString(c.table) _, _ = out.WriteString(c.tableName)
_, _ = out.WriteString(".") _, _ = out.WriteString(".")
} }
containsDot := strings.Contains(c.name, ".")
if containsDot {
out.WriteString("\"")
}
_, _ = out.WriteString(c.name) _, _ = out.WriteString(c.name)
if containsDot {
out.WriteString("\"")
}
return nil return nil
} }
@ -323,11 +338,11 @@ func validIdentifierName(name string) bool {
} }
// //
//// Pseudo Column type returned by table.C(name) //// Pseudo Column type returned by tableName.C(name)
//type deferredLookupColumn struct { //type deferredLookupColumn struct {
// isProjection // isProjection
// isExpression // isExpression
// table *Table // tableName *Table
// colName string // colName string
// //
// cachedColumn NonAliasColumn // cachedColumn NonAliasColumn
@ -348,7 +363,7 @@ func validIdentifierName(name string) bool {
// return c.cachedColumn.SerializeSql(out) // return c.cachedColumn.SerializeSql(out)
// } // }
// //
// col, err := c.table.getColumn(c.colName) // col, err := c.tableName.getColumn(c.colName)
// if err != nil { // if err != nil {
// return err // return err
// } // }
@ -357,7 +372,7 @@ func validIdentifierName(name string) bool {
// return col.SerializeSql(out) // return col.SerializeSql(out)
//} //}
// //
//func (c *deferredLookupColumn) setTableName(table string) error { //func (c *deferredLookupColumn) setTableName(tableName string) error {
// return errors.Newf( // return errors.Newf(
// "Lookup column '%s' should never have setTableName called on it", // "Lookup column '%s' should never have setTableName called on it",
// c.colName) // c.colName)

View file

@ -29,7 +29,7 @@ func (s *ColumnSuite) TestRealColumnName(c *gc.C) {
func (s *ColumnSuite) TestRealColumnSerializeSqlForColumnList(c *gc.C) { func (s *ColumnSuite) TestRealColumnSerializeSqlForColumnList(c *gc.C) {
col := IntColumn("col", Nullable) col := IntColumn("col", Nullable)
// Without table name // Without tableName name
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := col.SerializeSqlForColumnList(buf) err := col.SerializeSqlForColumnList(buf)
@ -38,7 +38,7 @@ func (s *ColumnSuite) TestRealColumnSerializeSqlForColumnList(c *gc.C) {
sql := buf.String() sql := buf.String()
c.Assert(sql, gc.Equals, "col") c.Assert(sql, gc.Equals, "col")
// With table name // With tableName name
err = col.setTableName("foo") err = col.setTableName("foo")
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
@ -54,7 +54,7 @@ func (s *ColumnSuite) TestRealColumnSerializeSqlForColumnList(c *gc.C) {
func (s *ColumnSuite) TestRealColumnSerializeSql(c *gc.C) { func (s *ColumnSuite) TestRealColumnSerializeSql(c *gc.C) {
col := IntColumn("col", Nullable) col := IntColumn("col", Nullable)
// Without table name // Without tableName name
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := col.SerializeSql(buf) err := col.SerializeSql(buf)
@ -63,7 +63,7 @@ func (s *ColumnSuite) TestRealColumnSerializeSql(c *gc.C) {
sql := buf.String() sql := buf.String()
c.Assert(sql, gc.Equals, "col") c.Assert(sql, gc.Equals, "col")
// With table name // With tableName name
err = col.setTableName("foo") err = col.setTableName("foo")
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)

View file

@ -9,7 +9,7 @@
// //
// Known limitations for SELECT queries: // Known limitations for SELECT queries:
// - does not support subqueries (since mysql is bad at it) // - does not support subqueries (since mysql is bad at it)
// - does not currently support join table alias (and hence self join) // - does not currently support join tableName alias (and hence self join)
// - does not support NATURAL joins and join USING // - does not support NATURAL joins and join USING
// //
// Known limitation for INSERT statements: // Known limitation for INSERT statements:
@ -17,9 +17,9 @@
// //
// Known limitation for UPDATE statements: // Known limitation for UPDATE statements:
// - does not support update without a WHERE clause (since it is dangerous) // - does not support update without a WHERE clause (since it is dangerous)
// - does not support multi-table update // - does not support multi-tableName update
// //
// Known limitation for DELETE statements: // Known limitation for DELETE statements:
// - does not support delete without a WHERE clause (since it is dangerous) // - does not support delete without a WHERE clause (since it is dangerous)
// - does not support multi-table delete // - does not support multi-tableName delete
package sqlbuilder package sqlbuilder

View file

@ -262,7 +262,7 @@ func (s *ExprSuite) TestLtExpr(c *gc.C) {
} }
func (s *ExprSuite) TestLteExpr(c *gc.C) { func (s *ExprSuite) TestLteExpr(c *gc.C) {
expr := LteL(table1Col1, "foo\"';drop user table;") expr := LteL(table1Col1, "foo\"';drop user tableName;")
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
@ -273,7 +273,7 @@ func (s *ExprSuite) TestLteExpr(c *gc.C) {
c.Assert( c.Assert(
sql, sql,
gc.Equals, gc.Equals,
"table1.col1<='foo\\\"\\';drop user table;'") "table1.col1<='foo\\\"\\';drop user tableName;'")
} }
func (s *ExprSuite) TestGtExpr(c *gc.C) { func (s *ExprSuite) TestGtExpr(c *gc.C) {

View file

@ -0,0 +1,268 @@
package sqlbuilder
import (
"bytes"
"database/sql"
"fmt"
"github.com/dropbox/godropbox/errors"
"github.com/sub0Zero/go-sqlbuilder/sqlbuilder/execution"
"reflect"
)
type SelectStatement interface {
Statement
Expression
Where(expression BoolExpression) SelectStatement
AndWhere(expression BoolExpression) SelectStatement
GroupBy(expressions ...Expression) SelectStatement
HAVING(expressions BoolExpression) SelectStatement
OrderBy(clauses ...OrderByClause) SelectStatement
Limit(limit int64) SelectStatement
Offset(offset int64) SelectStatement
Distinct() SelectStatement
WithSharedLock() SelectStatement
ForUpdate() SelectStatement
Comment(comment string) SelectStatement
Copy() SelectStatement
AsTable(alias string) *SelectStatementTable
}
// NOTE: SelectStatement purposely does not implement the Table interface since
// mysql's subquery performance is horrible.
type selectStatementImpl struct {
isExpression
table ReadableTable
projections []Projection
where BoolExpression
group *listClause
having BoolExpression
order *listClause
comment string
limit, offset int64
withSharedLock bool
forUpdate bool
distinct bool
}
func newSelectStatement(
table ReadableTable,
projections []Projection) SelectStatement {
return &selectStatementImpl{
table: table,
projections: projections,
limit: -1,
offset: -1,
withSharedLock: false,
forUpdate: false,
distinct: false,
}
}
func (s *selectStatementImpl) SerializeSql(out *bytes.Buffer) error {
str, err := s.String()
if err != nil {
return err
}
out.WriteString("( ")
out.WriteString(str)
out.WriteString(")")
return nil
}
func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable {
return &SelectStatementTable{
statement: s,
alias: alias,
}
}
func (s *selectStatementImpl) Execute(db *sql.DB, destination interface{}) error {
destinationType := reflect.TypeOf(destination)
if destinationType.Kind() == reflect.Ptr && destinationType.Elem().Kind() == reflect.Struct {
s.Limit(1)
}
query, err := s.String()
if err != nil {
return err
}
return execution.Execute(db, query, destination)
}
func (s *selectStatementImpl) Copy() SelectStatement {
ret := *s
return &ret
}
// Further filter the query, instead of replacing the filter
func (q *selectStatementImpl) AndWhere(
expression BoolExpression) SelectStatement {
if q.where == nil {
return q.Where(expression)
}
q.where = And(q.where, expression)
return q
}
func (q *selectStatementImpl) Where(expression BoolExpression) SelectStatement {
q.where = expression
return q
}
func (q *selectStatementImpl) GroupBy(
expressions ...Expression) SelectStatement {
q.group = &listClause{
clauses: make([]Clause, len(expressions), len(expressions)),
includeParentheses: false,
}
for i, e := range expressions {
q.group.clauses[i] = e
}
return q
}
func (q *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement {
q.having = expression
return q
}
func (q *selectStatementImpl) OrderBy(
clauses ...OrderByClause) SelectStatement {
q.order = newOrderByListClause(clauses...)
return q
}
func (q *selectStatementImpl) Limit(limit int64) SelectStatement {
q.limit = limit
return q
}
func (q *selectStatementImpl) Distinct() SelectStatement {
q.distinct = true
return q
}
func (q *selectStatementImpl) WithSharedLock() SelectStatement {
// We don't need to grab a read lock if we're going to grab a write one
if !q.forUpdate {
q.withSharedLock = true
}
return q
}
func (q *selectStatementImpl) ForUpdate() SelectStatement {
// Clear a request for a shared lock if we're asking for a write one
q.withSharedLock = false
q.forUpdate = true
return q
}
func (q *selectStatementImpl) Offset(offset int64) SelectStatement {
q.offset = offset
return q
}
func (q *selectStatementImpl) Comment(comment string) SelectStatement {
q.comment = comment
return q
}
// Return the properly escaped SQL statement, against the specified database
func (q *selectStatementImpl) String() (sql string, err error) {
buf := new(bytes.Buffer)
_, _ = buf.WriteString("SELECT ")
if err = writeComment(q.comment, buf); err != nil {
return
}
if q.distinct {
_, _ = buf.WriteString("DISTINCT ")
}
if q.projections == nil || len(q.projections) == 0 {
return "", errors.Newf(
"No column selected. Generated sql: %s",
buf.String())
}
for i, col := range q.projections {
if i > 0 {
_ = buf.WriteByte(',')
}
if col == nil {
return "", errors.Newf(
"nil column selected. Generated sql: %s",
buf.String())
}
if err = col.SerializeSqlForColumnList(buf); err != nil {
return
}
}
_, _ = buf.WriteString(" FROM ")
if q.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
}
if err = q.table.SerializeSql(buf); err != nil {
return
}
if q.where != nil {
_, _ = buf.WriteString(" WHERE ")
if err = q.where.SerializeSql(buf); err != nil {
return
}
}
if q.group != nil {
_, _ = buf.WriteString(" GROUP BY ")
if err = q.group.SerializeSql(buf); err != nil {
return
}
}
if q.having != nil {
buf.WriteString(" HAVING ")
if err = q.having.SerializeSql(buf); err != nil {
return
}
}
if q.order != nil {
_, _ = buf.WriteString(" ORDER BY ")
if err = q.order.SerializeSql(buf); err != nil {
return
}
}
if q.limit >= 0 {
if q.offset >= 0 {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d, %d", q.offset, q.limit))
} else {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", q.limit))
}
}
if q.forUpdate {
_, _ = buf.WriteString(" FOR UPDATE")
} else if q.withSharedLock {
_, _ = buf.WriteString(" LOCK IN SHARE MODE")
}
return buf.String(), nil
}

View file

@ -0,0 +1,75 @@
package sqlbuilder
import "bytes"
type SelectStatementTable struct {
statement SelectStatement
columns []NonAliasColumn
alias string
}
func (s *SelectStatementTable) Columns() []NonAliasColumn {
return s.columns
}
func (s *SelectStatementTable) Column(name string) NonAliasColumn {
return &baseColumn{
name: name,
tableName: s.alias,
}
}
func (s *SelectStatementTable) ColumnFrom(column NonAliasColumn) NonAliasColumn {
return &baseColumn{
name: column.TableName() + "." + column.Name(),
tableName: s.alias,
}
}
func (s *SelectStatementTable) SerializeSql(out *bytes.Buffer) error {
out.WriteString("( ")
statementStr, err := s.statement.String()
if err != nil {
return err
}
out.WriteString(statementStr)
out.WriteString(" ) AS ")
out.WriteString(s.alias)
return nil
}
// Generates a select query on the current tableName.
func (s *SelectStatementTable) Select(projections ...Projection) SelectStatement {
return newSelectStatement(s, projections)
}
// Creates a inner join tableName expression using onCondition.
func (s *SelectStatementTable) InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable {
return InnerJoinOn(s, table, onCondition)
}
func (s *SelectStatementTable) InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable {
return InnerJoinOn(s, table, col1.Eq(col2))
}
// Creates a left join tableName expression using onCondition.
func (s *SelectStatementTable) LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable {
return LeftJoinOn(s, table, onCondition)
}
// Creates a right join tableName expression using onCondition.
func (s *SelectStatementTable) RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable {
return RightJoinOn(s, table, onCondition)
}
func (s *SelectStatementTable) FullJoin(table ReadableTable, col1 Column, col2 Column) ReadableTable {
return FullJoin(s, table, col1.Eq(col2))
}
func (s *SelectStatementTable) CrossJoin(table ReadableTable) ReadableTable {
return CrossJoin(s, table)
}

View file

@ -4,8 +4,6 @@ import (
"bytes" "bytes"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/sub0Zero/go-sqlbuilder/sqlbuilder/execution"
"reflect"
"regexp" "regexp"
"github.com/dropbox/godropbox/errors" "github.com/dropbox/godropbox/errors"
@ -17,22 +15,6 @@ type Statement interface {
Execute(db *sql.DB, destination interface{}) error Execute(db *sql.DB, destination interface{}) error
} }
type SelectStatement interface {
Statement
Where(expression BoolExpression) SelectStatement
AndWhere(expression BoolExpression) SelectStatement
GroupBy(expressions ...Expression) SelectStatement
OrderBy(clauses ...OrderByClause) SelectStatement
Limit(limit int64) SelectStatement
Offset(offset int64) SelectStatement
Distinct() SelectStatement
WithSharedLock() SelectStatement
ForUpdate() SelectStatement
Comment(comment string) SelectStatement
Copy() SelectStatement
}
type InsertStatement interface { type InsertStatement interface {
Statement Statement
@ -50,7 +32,7 @@ type InsertStatement interface {
type UnionStatement interface { type UnionStatement interface {
Statement Statement
// Warning! You cannot include table names for the next 4 clauses, or // Warning! You cannot include tableName names for the next 4 clauses, or
// you'll get errors like: // you'll get errors like:
// Table 'server_file_journal' from one of the SELECTs cannot be used in // Table 'server_file_journal' from one of the SELECTs cannot be used in
// global ORDER clause // global ORDER clause
@ -91,9 +73,9 @@ type LockStatement interface {
AddWriteLock(table *Table) LockStatement AddWriteLock(table *Table) LockStatement
} }
// UnlockStatement can be used to release table locks taken using LockStatement. // UnlockStatement can be used to release tableName locks taken using LockStatement.
// NOTE: You can not selectively release a lock and continue to hold lock on // NOTE: You can not selectively release a lock and continue to hold lock on
// another table. UnlockStatement releases all the lock held in the current // another tableName. UnlockStatement releases all the lock held in the current
// session. // session.
type UnlockStatement interface { type UnlockStatement interface {
Statement Statement
@ -222,7 +204,7 @@ func (us *unionStatementImpl) String() (sql string, err error) {
return "", errors.Newf( return "", errors.Newf(
"All inner selects in Union statement must select the " + "All inner selects in Union statement must select the " +
"same number of columns. For sanity, you probably " + "same number of columns. For sanity, you probably " +
"want to select the same table columns in the same " + "want to select the same tableName columns in the same " +
"order. If you are selecting on multiple tables, " + "order. If you are selecting on multiple tables, " +
"use Null to pad to the right number of fields.") "use Null to pad to the right number of fields.")
} }
@ -279,212 +261,6 @@ func (us *unionStatementImpl) String() (sql string, err error) {
return buf.String(), nil return buf.String(), nil
} }
//
// SELECT Statement ============================================================
//
func newSelectStatement(
table ReadableTable,
projections []Projection) SelectStatement {
return &selectStatementImpl{
table: table,
projections: projections,
limit: -1,
offset: -1,
withSharedLock: false,
forUpdate: false,
distinct: false,
}
}
// NOTE: SelectStatement purposely does not implement the Table interface since
// mysql's subquery performance is horrible.
type selectStatementImpl struct {
table ReadableTable
projections []Projection
where BoolExpression
group *listClause
order *listClause
comment string
limit, offset int64
withSharedLock bool
forUpdate bool
distinct bool
}
func (s *selectStatementImpl) Execute(db *sql.DB, destination interface{}) error {
destinationType := reflect.TypeOf(destination)
if destinationType.Kind() == reflect.Ptr && destinationType.Elem().Kind() == reflect.Struct {
s.Limit(1)
}
query, err := s.String()
if err != nil {
return err
}
return execution.Execute(db, query, destination)
}
func (s *selectStatementImpl) Copy() SelectStatement {
ret := *s
return &ret
}
// Further filter the query, instead of replacing the filter
func (q *selectStatementImpl) AndWhere(
expression BoolExpression) SelectStatement {
if q.where == nil {
return q.Where(expression)
}
q.where = And(q.where, expression)
return q
}
func (q *selectStatementImpl) Where(expression BoolExpression) SelectStatement {
q.where = expression
return q
}
func (q *selectStatementImpl) GroupBy(
expressions ...Expression) SelectStatement {
q.group = &listClause{
clauses: make([]Clause, len(expressions), len(expressions)),
includeParentheses: false,
}
for i, e := range expressions {
q.group.clauses[i] = e
}
return q
}
func (q *selectStatementImpl) OrderBy(
clauses ...OrderByClause) SelectStatement {
q.order = newOrderByListClause(clauses...)
return q
}
func (q *selectStatementImpl) Limit(limit int64) SelectStatement {
q.limit = limit
return q
}
func (q *selectStatementImpl) Distinct() SelectStatement {
q.distinct = true
return q
}
func (q *selectStatementImpl) WithSharedLock() SelectStatement {
// We don't need to grab a read lock if we're going to grab a write one
if !q.forUpdate {
q.withSharedLock = true
}
return q
}
func (q *selectStatementImpl) ForUpdate() SelectStatement {
// Clear a request for a shared lock if we're asking for a write one
q.withSharedLock = false
q.forUpdate = true
return q
}
func (q *selectStatementImpl) Offset(offset int64) SelectStatement {
q.offset = offset
return q
}
func (q *selectStatementImpl) Comment(comment string) SelectStatement {
q.comment = comment
return q
}
// Return the properly escaped SQL statement, against the specified database
func (q *selectStatementImpl) String() (sql string, err error) {
buf := new(bytes.Buffer)
_, _ = buf.WriteString("SELECT ")
if err = writeComment(q.comment, buf); err != nil {
return
}
if q.distinct {
_, _ = buf.WriteString("DISTINCT ")
}
if q.projections == nil || len(q.projections) == 0 {
return "", errors.Newf(
"No column selected. Generated sql: %s",
buf.String())
}
for i, col := range q.projections {
if i > 0 {
_ = buf.WriteByte(',')
}
if col == nil {
return "", errors.Newf(
"nil column selected. Generated sql: %s",
buf.String())
}
if err = col.SerializeSqlForColumnList(buf); err != nil {
return
}
}
_, _ = buf.WriteString(" FROM ")
if q.table == nil {
return "", errors.Newf("nil table. Generated sql: %s", buf.String())
}
if err = q.table.SerializeSql(buf); err != nil {
return
}
if q.where != nil {
_, _ = buf.WriteString(" WHERE ")
if err = q.where.SerializeSql(buf); err != nil {
return
}
}
if q.group != nil {
_, _ = buf.WriteString(" GROUP BY ")
if err = q.group.SerializeSql(buf); err != nil {
return
}
}
if q.order != nil {
_, _ = buf.WriteString(" ORDER BY ")
if err = q.order.SerializeSql(buf); err != nil {
return
}
}
if q.limit >= 0 {
if q.offset >= 0 {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d, %d", q.offset, q.limit))
} else {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", q.limit))
}
}
if q.forUpdate {
_, _ = buf.WriteString(" FOR UPDATE")
} else if q.withSharedLock {
_, _ = buf.WriteString(" LOCK IN SHARE MODE")
}
return buf.String(), nil
}
// //
// INSERT Statement ============================================================ // INSERT Statement ============================================================
// //
@ -560,7 +336,7 @@ func (s *insertStatementImpl) String() (sql string, err error) {
} }
if s.table == nil { if s.table == nil {
return "", errors.Newf("nil table. Generated sql: %s", buf.String()) return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
} }
if err = s.table.SerializeSql(buf); err != nil { if err = s.table.SerializeSql(buf); err != nil {
@ -728,7 +504,7 @@ func (u *updateStatementImpl) String() (sql string, err error) {
} }
if u.table == nil { if u.table == nil {
return "", errors.Newf("nil table. Generated sql: %s", buf.String()) return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
} }
if err = u.table.SerializeSql(buf); err != nil { if err = u.table.SerializeSql(buf); err != nil {
@ -863,7 +639,7 @@ func (d *deleteStatementImpl) String() (sql string, err error) {
} }
if d.table == nil { if d.table == nil {
return "", errors.Newf("nil table. Generated sql: %s", buf.String()) return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
} }
if err = d.table.SerializeSql(buf); err != nil { if err = d.table.SerializeSql(buf); err != nil {
@ -919,13 +695,13 @@ func (l *lockStatementImpl) Execute(db *sql.DB, data interface{}) error {
return nil return nil
} }
// AddReadLock takes read lock on the table. // AddReadLock takes read lock on the tableName.
func (s *lockStatementImpl) AddReadLock(t *Table) LockStatement { func (s *lockStatementImpl) AddReadLock(t *Table) LockStatement {
s.locks = append(s.locks, tableLock{t: t, w: false}) s.locks = append(s.locks, tableLock{t: t, w: false})
return s return s
} }
// AddWriteLock takes write lock on the table. // AddWriteLock takes write lock on the tableName.
func (s *lockStatementImpl) AddWriteLock(t *Table) LockStatement { func (s *lockStatementImpl) AddWriteLock(t *Table) LockStatement {
s.locks = append(s.locks, tableLock{t: t, w: true}) s.locks = append(s.locks, tableLock{t: t, w: true})
return s return s
@ -941,7 +717,7 @@ func (s *lockStatementImpl) String() (sql string, err error) {
for idx, lock := range s.locks { for idx, lock := range s.locks {
if lock.t == nil { if lock.t == nil {
return "", errors.Newf("nil table. Generated sql: %s", buf.String()) return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
} }
if err = lock.t.SerializeSql(buf); err != nil { if err = lock.t.SerializeSql(buf); err != nil {
@ -962,7 +738,7 @@ func (s *lockStatementImpl) String() (sql string, err error) {
return buf.String(), nil return buf.String(), nil
} }
// NewUnlockStatement returns SQL statement that can be used to release table locks // NewUnlockStatement returns SQL statement that can be used to release tableName locks
// grabbed by the current session. // grabbed by the current session.
func NewUnlockStatement() UnlockStatement { func NewUnlockStatement() UnlockStatement {
return &unlockStatementImpl{} return &unlockStatementImpl{}

View file

@ -609,7 +609,7 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) {
gc.Equals, gc.Equals,
"All inner selects in Union statement must select the "+ "All inner selects in Union statement must select the "+
"same number of columns. For sanity, you probably "+ "same number of columns. For sanity, you probably "+
"want to select the same table columns in the same "+ "want to select the same tableName columns in the same "+
"order. If you are selecting on multiple tables, "+ "order. If you are selecting on multiple tables, "+
"use Null to pad to the right number of fields.") "use Null to pad to the right number of fields.")
} }

View file

@ -8,29 +8,31 @@ import (
"github.com/dropbox/godropbox/errors" "github.com/dropbox/godropbox/errors"
) )
// The sql table read interface. NOTE: NATURAL JOINs, and join "USING" clause // The sql tableName read interface. NOTE: NATURAL JOINs, and join "USING" clause
// are not supported. // are not supported.
type ReadableTable interface { type ReadableTable interface {
// Returns the list of columns that are in the current table expression. // Returns the list of columns that are in the current tableName expression.
Columns() []NonAliasColumn Columns() []NonAliasColumn
// Generates the sql string for the current table expression. Note: the Column(name string) NonAliasColumn
// Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement. // generated string may not be a valid/executable sql statement.
// The database is the name of the database the table is on // The database is the name of the database the tableName is on
SerializeSql(out *bytes.Buffer) error SerializeSql(out *bytes.Buffer) error
// Generates a select query on the current table. // Generates a select query on the current tableName.
Select(projections ...Projection) SelectStatement Select(projections ...Projection) SelectStatement
// Creates a inner join table expression using onCondition. // Creates a inner join tableName expression using onCondition.
InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable
InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable
// Creates a left join table expression using onCondition. // Creates a left join tableName expression using onCondition.
LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a right join table expression using onCondition. // Creates a right join tableName expression using onCondition.
RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable
FullJoin(table ReadableTable, col1 Column, col2 Column) ReadableTable FullJoin(table ReadableTable, col1 Column, col2 Column) ReadableTable
@ -38,14 +40,14 @@ type ReadableTable interface {
CrossJoin(table ReadableTable) ReadableTable CrossJoin(table ReadableTable) ReadableTable
} }
// The sql table write interface. // The sql tableName write interface.
type WritableTable interface { type WritableTable interface {
// Returns the list of columns that are in the table. // Returns the list of columns that are in the tableName.
Columns() []NonAliasColumn Columns() []NonAliasColumn
// Generates the sql string for the current table expression. Note: the // Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement. // generated string may not be a valid/executable sql statement.
// The database is the name of the database the table is on // The database is the name of the database the tableName is on
SerializeSql(out *bytes.Buffer) error SerializeSql(out *bytes.Buffer) error
Insert(columns ...NonAliasColumn) InsertStatement Insert(columns ...NonAliasColumn) InsertStatement
@ -53,11 +55,11 @@ type WritableTable interface {
Delete() DeleteStatement Delete() DeleteStatement
} }
// Defines a physical table in the database that is both readable and writable. // Defines a physical tableName in the database that is both readable and writable.
// This function will panic if name is not valid // This function will panic if name is not valid
func NewTable(schemaName, name string, columns ...NonAliasColumn) *Table { func NewTable(schemaName, name string, columns ...NonAliasColumn) *Table {
if !validIdentifierName(name) { if !validIdentifierName(name) {
panic("Invalid table name") panic("Invalid tableName name")
} }
t := &Table{ t := &Table{
@ -91,28 +93,28 @@ type Table struct {
forcedIndex string forcedIndex string
} }
// Returns the specified column, or errors if it doesn't exist in the table // Returns the specified column, or errors if it doesn't exist in the tableName
func (t *Table) getColumn(name string) (NonAliasColumn, error) { func (t *Table) getColumn(name string) (NonAliasColumn, error) {
if c, ok := t.columnLookup[name]; ok { if c, ok := t.columnLookup[name]; ok {
return c, nil return c, nil
} }
return nil, errors.Newf("No such column '%s' in table '%s'", name, t.name) return nil, errors.Newf("No such column '%s' in tableName '%s'", name, t.name)
} }
// Returns a pseudo column representation of the column name. Error checking func (t *Table) Column(name string) NonAliasColumn {
// is deferred to SerializeSql. return &baseColumn{
//func (t *Table) C(name string) NonAliasColumn { name: name,
// return &deferredLookupColumn{ nullable: NotNullable,
// table: t, tableName: t.name,
// colName: name, }
// } }
//}
// Returns all columns for a table as a slice of projections // Returns all columns for a tableName as a slice of projections
func (t *Table) Projections() []Projection { func (t *Table) Projections() []Projection {
result := make([]Projection, 0) result := make([]Projection, 0)
for _, col := range t.columns { for _, col := range t.columns {
col.Asc()
result = append(result, col) result = append(result, col)
} }
@ -130,7 +132,7 @@ func (t *Table) SetAlias(alias string) {
} }
} }
// Returns the table's name in the database // Returns the tableName's name in the database
func (t *Table) Name() string { func (t *Table) Name() string {
return t.name return t.name
} }
@ -139,19 +141,19 @@ func (t *Table) SchemaName() string {
return t.schemaName return t.schemaName
} }
// Returns a list of the table's columns // Returns a list of the tableName's columns
func (t *Table) Columns() []NonAliasColumn { func (t *Table) Columns() []NonAliasColumn {
return t.columns return t.columns
} }
// Returns a copy of this table, but with the specified index forced. // Returns a copy of this tableName, but with the specified index forced.
func (t *Table) ForceIndex(index string) *Table { func (t *Table) ForceIndex(index string) *Table {
newTable := *t newTable := *t
newTable.forcedIndex = index newTable.forcedIndex = index
return &newTable return &newTable
} }
// Generates the sql string for the current table expression. Note: the // Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement. // generated string may not be a valid/executable sql statement.
func (t *Table) SerializeSql(out *bytes.Buffer) error { func (t *Table) SerializeSql(out *bytes.Buffer) error {
if !validIdentifierName(t.schemaName) { if !validIdentifierName(t.schemaName) {
@ -179,12 +181,12 @@ func (t *Table) SerializeSql(out *bytes.Buffer) error {
return nil return nil
} }
// Generates a select query on the current table. // Generates a select query on the current tableName.
func (t *Table) Select(projections ...Projection) SelectStatement { func (t *Table) Select(projections ...Projection) SelectStatement {
return newSelectStatement(t, projections) return newSelectStatement(t, projections)
} }
// Creates a inner join table expression using onCondition. // Creates a inner join tableName expression using onCondition.
func (t *Table) InnerJoinOn( func (t *Table) InnerJoinOn(
table ReadableTable, table ReadableTable,
onCondition BoolExpression) ReadableTable { onCondition BoolExpression) ReadableTable {
@ -200,7 +202,7 @@ func (t *Table) InnerJoinUsing(
return InnerJoinOn(t, table, col1.Eq(col2)) return InnerJoinOn(t, table, col1.Eq(col2))
} }
// Creates a left join table expression using onCondition. // Creates a left join tableName expression using onCondition.
func (t *Table) LeftJoinOn( func (t *Table) LeftJoinOn(
table ReadableTable, table ReadableTable,
onCondition BoolExpression) ReadableTable { onCondition BoolExpression) ReadableTable {
@ -208,7 +210,7 @@ func (t *Table) LeftJoinOn(
return LeftJoinOn(t, table, onCondition) return LeftJoinOn(t, table, onCondition)
} }
// Creates a right join table expression using onCondition. // Creates a right join tableName expression using onCondition.
func (t *Table) RightJoinOn( func (t *Table) RightJoinOn(
table ReadableTable, table ReadableTable,
onCondition BoolExpression) ReadableTable { onCondition BoolExpression) ReadableTable {
@ -315,6 +317,10 @@ func (t *joinTable) Columns() []NonAliasColumn {
return columns return columns
} }
func (t *joinTable) Column(name string) NonAliasColumn {
panic("Not implemented")
}
func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) { func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) {
if t.lhs == nil { if t.lhs == nil {

View file

@ -52,7 +52,7 @@ func (s *TableSuite) TestValidForcedIndex(c *gc.C) {
sql := buf.String() sql := buf.String()
c.Assert(sql, gc.Equals, "db.table1 FORCE INDEX (foo)") c.Assert(sql, gc.Equals, "db.table1 FORCE INDEX (foo)")
// Ensure the original table is unchanged // Ensure the original tableName is unchanged
buf = &bytes.Buffer{} buf = &bytes.Buffer{}
err = table1.SerializeSql(buf) err = table1.SerializeSql(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)

View file

@ -32,6 +32,8 @@ type BoolExpression interface {
type Projection interface { type Projection interface {
Clause Clause
isProjectionInterface isProjectionInterface
As(alias string) Projection
SerializeSqlForColumnList(out *bytes.Buffer) error SerializeSqlForColumnList(out *bytes.Buffer) error
} }
@ -51,6 +53,10 @@ func (cl ColumnList) SerializeSql(out *bytes.Buffer) error {
func (cl ColumnList) isProjectionType() { func (cl ColumnList) isProjectionType() {
} }
func (cl ColumnList) As(name string) Projection {
panic("Unallowed operation ")
}
func (cl ColumnList) SerializeSqlForColumnList(out *bytes.Buffer) error { func (cl ColumnList) SerializeSqlForColumnList(out *bytes.Buffer) error {
for i, column := range cl { for i, column := range cl {
column.SerializeSqlForColumnList(out) column.SerializeSqlForColumnList(out)

View file

@ -3,8 +3,8 @@ package tests
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/davecgh/go-spew/spew"
"github.com/sub0Zero/go-sqlbuilder/generator" "github.com/sub0Zero/go-sqlbuilder/generator"
"github.com/sub0Zero/go-sqlbuilder/sqlbuilder"
"github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model" "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model"
. "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table" . "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table"
"gotest.tools/assert" "gotest.tools/assert"
@ -79,6 +79,7 @@ func TestSelect_ScanToSlice(t *testing.T) {
queryStr, err := query.String() queryStr, err := query.String()
assert.NilError(t, err) assert.NilError(t, err)
fmt.Println(queryStr)
assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active" FROM dvds.customer ORDER BY customer.customer_id ASC`) assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active" FROM dvds.customer ORDER BY customer.customer_id ASC`)
err = query.Execute(db, &customers) err = query.Execute(db, &customers)
@ -119,7 +120,7 @@ func TestJoinQueryStruct(t *testing.T) {
func TestJoinQuerySlice(t *testing.T) { func TestJoinQuerySlice(t *testing.T) {
type FilmsPerLanguage struct { type FilmsPerLanguage struct {
Language *model.Language Language *model.Language
Films *[]model.Film Film *[]model.Film
} }
filmsPerLanguage := []FilmsPerLanguage{} filmsPerLanguage := []FilmsPerLanguage{}
@ -143,8 +144,10 @@ func TestJoinQuerySlice(t *testing.T) {
//fmt.Println("--------------- result --------------- ") //fmt.Println("--------------- result --------------- ")
//spew.Dump(filmsPerLanguage) //spew.Dump(filmsPerLanguage)
//spew.Dump(filmsPerLanguage)
assert.Equal(t, len(filmsPerLanguage), 1) assert.Equal(t, len(filmsPerLanguage), 1)
assert.Equal(t, len(*filmsPerLanguage[0].Films), limit) assert.Equal(t, len(*filmsPerLanguage[0].Film), limit)
//spew.Dump(filmsPerLanguage) //spew.Dump(filmsPerLanguage)
@ -153,13 +156,13 @@ func TestJoinQuerySlice(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(filmsPerLanguage), 1) assert.Equal(t, len(filmsPerLanguage), 1)
assert.Equal(t, len(*filmsPerLanguage[0].Films), limit) assert.Equal(t, len(*filmsPerLanguage[0].Film), limit)
} }
func TestJoinQuerySliceWithPtrs(t *testing.T) { func TestJoinQuerySliceWithPtrs(t *testing.T) {
type FilmsPerLanguage struct { type FilmsPerLanguage struct {
Language model.Language Language model.Language
Films *[]*model.Film Film *[]*model.Film
} }
limit := int64(3) limit := int64(3)
@ -175,7 +178,7 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(filmsPerLanguageWithPtrs), 1) assert.Equal(t, len(filmsPerLanguageWithPtrs), 1)
assert.Equal(t, len(*filmsPerLanguageWithPtrs[0].Films), int(limit)) assert.Equal(t, len(*filmsPerLanguageWithPtrs[0].Film), int(limit))
} }
func TestSelect_WithoutUniqueColumnSelected(t *testing.T) { func TestSelect_WithoutUniqueColumnSelected(t *testing.T) {
@ -323,7 +326,7 @@ func TestSelectSelfJoin(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
spew.Dump(theSameLengthFilms[0]) //spew.Dump(theSameLengthFilms[0])
assert.Equal(t, len(theSameLengthFilms), 6972) assert.Equal(t, len(theSameLengthFilms), 6972)
} }
@ -343,7 +346,7 @@ func TestSelectAliasColumn(t *testing.T) {
Select(f1.Title.As("thesame_length_films.title1"), Select(f1.Title.As("thesame_length_films.title1"),
f2.Title.As("thesame_length_films.title2"), f2.Title.As("thesame_length_films.title2"),
f1.Length.As("thesame_length_films.length")). f1.Length.As("thesame_length_films.length")).
OrderBy(f1.Length.Asc()). OrderBy(f1.Length.Asc(), f1.Title.Asc(), f2.Title.Asc()).
Limit(1000) Limit(1000)
queryStr, err := query.String() queryStr, err := query.String()
@ -361,7 +364,227 @@ func TestSelectAliasColumn(t *testing.T) {
//spew.Dump(films) //spew.Dump(films)
assert.Equal(t, len(films), 1000) assert.Equal(t, len(films), 1000)
assert.DeepEqual(t, films[0], thesameLengthFilms{"Ridgemont Submarine", "Iron Moon", 46}) assert.DeepEqual(t, films[0], thesameLengthFilms{"Alien Center", "Iron Moon", 46})
}
type Manager staff
type staff struct {
StaffID int32 `sql:"unique"`
FirstName string
LastName string
//Address *model.Address
//Email *string
//StoreID int16
//Active bool
//Username string
//Password *string
//LastUpdate time.Time
*Manager //`sqlbuilder:"manager"`
}
func TestSelectSelfReferenceType(t *testing.T) {
manager := Staff.As("manager")
query := Staff.
InnerJoinUsing(Address, Staff.AddressID, Address.AddressID).
InnerJoinUsing(manager, Staff.StaffID, manager.StaffID).
Select(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName)
queryStr, err := query.String()
assert.NilError(t, err)
fmt.Println(queryStr)
staffs := []staff{}
err = query.Execute(db, &staffs)
assert.NilError(t, err)
//spew.Dump(staffs)
}
func TestSubQuery(t *testing.T) {
//selectStmtTable := Actor.Select(Actor.FirstName, Actor.LastName).AsTable("table_expression")
//
//query := selectStmtTable.Select(
// selectStmtTable.ColumnFrom(Actor.FirstName).As("nesto"),
// selectStmtTable.Column("actor.last_name").As("nesto2"),
// )
//
//queryStr, err := query.String()
//
//assert.NilError(t, err)
//
//fmt.Println(queryStr)
//avrgCustomer := Customer.Select(Customer.LastName).Limit(1).AsExpression()
//
//Customer.
// InnerJoinUsing(selectStmtTable, Customer.LastName, selectStmtTable.Column("first_name")).
// Select(Customer.AllColumns, selectStmtTable.Column("first_name")).
// Where(Actor.LastName.Neq(avrgCustomer))
rFilmsOnly := Film.Select(Film.FilmID, Film.Title, Film.Rating).
Where(Film.Rating.Eq(sqlbuilder.Literal("R"))).
AsTable("films")
query := Actor.InnerJoinUsing(FilmActor, Actor.ActorID, FilmActor.FilmID).
InnerJoinUsing(rFilmsOnly, FilmActor.FilmID, rFilmsOnly.ColumnFrom(Film.FilmID)).
Select(
Actor.AllColumns,
FilmActor.AllColumns,
rFilmsOnly.ColumnFrom(Film.Title).As("film.title"),
rFilmsOnly.ColumnFrom(Film.Rating).As("film.rating"),
)
queryStr, err := query.String()
assert.NilError(t, err)
fmt.Println(queryStr)
}
func TestSelectFunctions(t *testing.T) {
query := Film.Select(sqlbuilder.MAX(Film.RentalRate).As("max_film_rate"))
str, err := query.String()
assert.NilError(t, err)
assert.Equal(t, str, `SELECT MAX(film.rental_rate) AS "max_film_rate" FROM dvds.film`)
fmt.Println(str)
}
func TestSelectQueryScalar(t *testing.T) {
maxFilmRentalRate := Film.Select(sqlbuilder.MAX(Film.RentalRate))
query := Film.Select(Film.AllColumns).
Where(Film.RentalRate.Eq(maxFilmRentalRate)).
OrderBy(Film.FilmID)
queryStr, err := query.String()
assert.NilError(t, err)
fmt.Println(queryStr)
maxRentalRateFilms := []model.Film{}
err = query.Execute(db, &maxRentalRateFilms)
assert.NilError(t, err)
assert.Equal(t, len(maxRentalRateFilms), 336)
assert.DeepEqual(t, maxRentalRateFilms[0], model.Film{
FilmID: 2,
Title: "Ace Goldfinger",
Description: stringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"),
ReleaseYear: int32Ptr(2006),
Language: nil,
RentalRate: 4.99,
Length: int16Ptr(48),
ReplacementCost: 12.99,
Rating: stringPtr("G"),
RentalDuration: 3,
LastUpdate: *timeWithoutTimeZone("2013-05-26 14:50:58.951 +0000"),
SpecialFeatures: stringPtr("{Trailers,\"Deleted Scenes\"}"),
Fulltext: "'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14",
})
//spew.Dump(maxRentalRateFilms[0])
}
func TestSelectGroupByHaving(t *testing.T) {
customersPaymentQuery := Payment.
Select(
Payment.CustomerID.As("customer_payment_sum.customer_id"),
sqlbuilder.SUM(Payment.Amount).As("customer_payment_sum.amount_sum"),
).
GroupBy(Payment.CustomerID).
OrderBy(sqlbuilder.SUM(Payment.Amount)).
HAVING(sqlbuilder.Gt(sqlbuilder.SUM(Payment.Amount), sqlbuilder.Literal(100)))
queryStr, err := customersPaymentQuery.String()
assert.NilError(t, err)
fmt.Println(queryStr)
assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id",SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount)>100 ORDER BY SUM(payment.amount)`)
type CustomerPaymentSum struct {
CustomerID int16
AmountSum float64
}
customerPaymentSum := []CustomerPaymentSum{}
err = customersPaymentQuery.Execute(db, &customerPaymentSum)
assert.NilError(t, err)
assert.Equal(t, len(customerPaymentSum), 296)
assert.DeepEqual(t, customerPaymentSum[0], CustomerPaymentSum{
CustomerID: 135,
AmountSum: 100.72,
})
}
func TestSelectGroupBy2(t *testing.T) {
type CustomerWithAmounts struct {
Customer *model.Customer
AmountSum float64
}
customersWithAmounts := []CustomerWithAmounts{}
customersPaymentSubQuery := Payment.
Select(
Payment.CustomerID,
sqlbuilder.SUM(Payment.Amount).As("amount_sum"),
).
GroupBy(Payment.CustomerID)
customersPaymentTable := customersPaymentSubQuery.AsTable("customer_payment_sum")
amountSumColumn := customersPaymentTable.Column("amount_sum")
query := Customer.
InnerJoinUsing(customersPaymentTable, Customer.CustomerID, customersPaymentTable.ColumnFrom(Payment.CustomerID)).
Select(Customer.AllColumns, amountSumColumn.As("customer_with_amounts.amount_sum")).
OrderBy(amountSumColumn)
queryStr, err := query.String()
assert.NilError(t, err)
fmt.Println(queryStr)
err = query.Execute(db, &customersWithAmounts)
assert.NilError(t, err)
//spew.Dump(customersWithAmounts)
assert.Equal(t, len(customersWithAmounts), 599)
assert.DeepEqual(t, customersWithAmounts[0].Customer, &model.Customer{
CustomerID: 318,
StoreID: 1,
FirstName: "Brian",
LastName: "Wyman",
Email: stringPtr("brian.wyman@sakilacustomer.org"),
Activebool: true,
CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00 +0000"),
LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738 +0000"),
Active: int32Ptr(1),
})
assert.Equal(t, customersWithAmounts[0].AmountSum, 27.93)
}
func int16Ptr(i int16) *int16 {
return &i
} }
func int32Ptr(i int32) *int32 { func int32Ptr(i int32) *int32 {