diff --git a/sqlbuilder/column.go b/sqlbuilder/column.go index e6ebf6c..e9c89d4 100644 --- a/sqlbuilder/column.go +++ b/sqlbuilder/column.go @@ -5,26 +5,29 @@ package sqlbuilder import ( "bytes" "regexp" + "strings" "github.com/dropbox/godropbox/errors" ) // XXX: Maybe add UIntColumn -// Representation of a table for query generation +// Representation of a tableName for query generation type Column interface { isProjectionInterface isExpressionInterface - As(alias string) Column + As(alias string) Projection Name() string + + TableName() string // Serialization for use in column lists SerializeSqlForColumnList(out *bytes.Buffer) error // Serialization for use in an expression (Clause) SerializeSql(out *bytes.Buffer) error - // Internal function for tracking table that a column belongs to + // Internal function for tracking tableName that a column belongs to // for the purpose of serialization setTableName(table string) error @@ -73,13 +76,13 @@ const ( type baseColumn struct { isProjection isExpression - name string - nullable NullableColumn - table string - alias string + name string + nullable NullableColumn + tableName string + alias string } -func (c *baseColumn) As(alias string) Column { +func (c *baseColumn) As(alias string) Projection { newBaseColumn := *c newBaseColumn.alias = alias @@ -90,8 +93,12 @@ func (c *baseColumn) Name() string { return c.name } +func (c *baseColumn) TableName() string { + return c.tableName +} + func (c *baseColumn) setTableName(table string) error { - c.table = table + c.tableName = table return nil } @@ -101,19 +108,27 @@ func (c *baseColumn) SerializeSqlForColumnList(out *bytes.Buffer) error { if c.alias != "" { _, _ = out.WriteString(" AS \"" + c.alias + "\"") - } else if c.table != "" { - _, _ = out.WriteString(" AS \"" + c.table + "." + c.name + "\"") + } else if c.tableName != "" { + _, _ = out.WriteString(" AS \"" + c.tableName + "." + c.name + "\"") } return nil } func (c baseColumn) SerializeSql(out *bytes.Buffer) error { - if c.table != "" { - _, _ = out.WriteString(c.table) + if c.tableName != "" { + _, _ = out.WriteString(c.tableName) _, _ = out.WriteString(".") } + containsDot := strings.Contains(c.name, ".") + + if containsDot { + out.WriteString("\"") + } _, _ = out.WriteString(c.name) + if containsDot { + out.WriteString("\"") + } return nil } @@ -323,11 +338,11 @@ func validIdentifierName(name string) bool { } // -//// Pseudo Column type returned by table.C(name) +//// Pseudo Column type returned by tableName.C(name) //type deferredLookupColumn struct { // isProjection // isExpression -// table *Table +// tableName *Table // colName string // // cachedColumn NonAliasColumn @@ -348,7 +363,7 @@ func validIdentifierName(name string) bool { // return c.cachedColumn.SerializeSql(out) // } // -// col, err := c.table.getColumn(c.colName) +// col, err := c.tableName.getColumn(c.colName) // if err != nil { // return err // } @@ -357,7 +372,7 @@ func validIdentifierName(name string) bool { // return col.SerializeSql(out) //} // -//func (c *deferredLookupColumn) setTableName(table string) error { +//func (c *deferredLookupColumn) setTableName(tableName string) error { // return errors.Newf( // "Lookup column '%s' should never have setTableName called on it", // c.colName) diff --git a/sqlbuilder/column_test.go b/sqlbuilder/column_test.go index 845091f..96c31e1 100644 --- a/sqlbuilder/column_test.go +++ b/sqlbuilder/column_test.go @@ -29,7 +29,7 @@ func (s *ColumnSuite) TestRealColumnName(c *gc.C) { func (s *ColumnSuite) TestRealColumnSerializeSqlForColumnList(c *gc.C) { col := IntColumn("col", Nullable) - // Without table name + // Without tableName name buf := &bytes.Buffer{} err := col.SerializeSqlForColumnList(buf) @@ -38,7 +38,7 @@ func (s *ColumnSuite) TestRealColumnSerializeSqlForColumnList(c *gc.C) { sql := buf.String() c.Assert(sql, gc.Equals, "col") - // With table name + // With tableName name err = col.setTableName("foo") c.Assert(err, gc.IsNil) @@ -54,7 +54,7 @@ func (s *ColumnSuite) TestRealColumnSerializeSqlForColumnList(c *gc.C) { func (s *ColumnSuite) TestRealColumnSerializeSql(c *gc.C) { col := IntColumn("col", Nullable) - // Without table name + // Without tableName name buf := &bytes.Buffer{} err := col.SerializeSql(buf) @@ -63,7 +63,7 @@ func (s *ColumnSuite) TestRealColumnSerializeSql(c *gc.C) { sql := buf.String() c.Assert(sql, gc.Equals, "col") - // With table name + // With tableName name err = col.setTableName("foo") c.Assert(err, gc.IsNil) diff --git a/sqlbuilder/doc.go b/sqlbuilder/doc.go index 3f9170a..6d861f4 100644 --- a/sqlbuilder/doc.go +++ b/sqlbuilder/doc.go @@ -9,7 +9,7 @@ // // Known limitations for SELECT queries: // - does not support subqueries (since mysql is bad at it) -// - does not currently support join table alias (and hence self join) +// - does not currently support join tableName alias (and hence self join) // - does not support NATURAL joins and join USING // // Known limitation for INSERT statements: @@ -17,9 +17,9 @@ // // Known limitation for UPDATE statements: // - does not support update without a WHERE clause (since it is dangerous) -// - does not support multi-table update +// - does not support multi-tableName update // // Known limitation for DELETE statements: // - does not support delete without a WHERE clause (since it is dangerous) -// - does not support multi-table delete +// - does not support multi-tableName delete package sqlbuilder diff --git a/sqlbuilder/expression_test.go b/sqlbuilder/expression_test.go index d916a57..e79400b 100644 --- a/sqlbuilder/expression_test.go +++ b/sqlbuilder/expression_test.go @@ -262,7 +262,7 @@ func (s *ExprSuite) TestLtExpr(c *gc.C) { } func (s *ExprSuite) TestLteExpr(c *gc.C) { - expr := LteL(table1Col1, "foo\"';drop user table;") + expr := LteL(table1Col1, "foo\"';drop user tableName;") buf := &bytes.Buffer{} @@ -273,7 +273,7 @@ func (s *ExprSuite) TestLteExpr(c *gc.C) { c.Assert( sql, gc.Equals, - "table1.col1<='foo\\\"\\';drop user table;'") + "table1.col1<='foo\\\"\\';drop user tableName;'") } func (s *ExprSuite) TestGtExpr(c *gc.C) { diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go new file mode 100644 index 0000000..c7186b8 --- /dev/null +++ b/sqlbuilder/select_statement.go @@ -0,0 +1,268 @@ +package sqlbuilder + +import ( + "bytes" + "database/sql" + "fmt" + "github.com/dropbox/godropbox/errors" + "github.com/sub0Zero/go-sqlbuilder/sqlbuilder/execution" + "reflect" +) + +type SelectStatement interface { + Statement + Expression + + Where(expression BoolExpression) SelectStatement + AndWhere(expression BoolExpression) SelectStatement + GroupBy(expressions ...Expression) SelectStatement + HAVING(expressions BoolExpression) SelectStatement + + OrderBy(clauses ...OrderByClause) SelectStatement + Limit(limit int64) SelectStatement + Offset(offset int64) SelectStatement + Distinct() SelectStatement + WithSharedLock() SelectStatement + ForUpdate() SelectStatement + Comment(comment string) SelectStatement + Copy() SelectStatement + + AsTable(alias string) *SelectStatementTable +} + +// NOTE: SelectStatement purposely does not implement the Table interface since +// mysql's subquery performance is horrible. +type selectStatementImpl struct { + isExpression + table ReadableTable + projections []Projection + where BoolExpression + group *listClause + having BoolExpression + order *listClause + comment string + limit, offset int64 + withSharedLock bool + forUpdate bool + distinct bool +} + +func newSelectStatement( + table ReadableTable, + projections []Projection) SelectStatement { + + return &selectStatementImpl{ + table: table, + projections: projections, + limit: -1, + offset: -1, + withSharedLock: false, + forUpdate: false, + distinct: false, + } +} + +func (s *selectStatementImpl) SerializeSql(out *bytes.Buffer) error { + str, err := s.String() + + if err != nil { + return err + } + + out.WriteString("( ") + out.WriteString(str) + out.WriteString(")") + + return nil +} + +func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable { + return &SelectStatementTable{ + statement: s, + alias: alias, + } +} + +func (s *selectStatementImpl) Execute(db *sql.DB, destination interface{}) error { + destinationType := reflect.TypeOf(destination) + + if destinationType.Kind() == reflect.Ptr && destinationType.Elem().Kind() == reflect.Struct { + s.Limit(1) + } + + query, err := s.String() + + if err != nil { + return err + } + + return execution.Execute(db, query, destination) +} + +func (s *selectStatementImpl) Copy() SelectStatement { + ret := *s + return &ret +} + +// Further filter the query, instead of replacing the filter +func (q *selectStatementImpl) AndWhere( + expression BoolExpression) SelectStatement { + + if q.where == nil { + return q.Where(expression) + } + q.where = And(q.where, expression) + return q +} + +func (q *selectStatementImpl) Where(expression BoolExpression) SelectStatement { + q.where = expression + return q +} + +func (q *selectStatementImpl) GroupBy( + expressions ...Expression) SelectStatement { + + q.group = &listClause{ + clauses: make([]Clause, len(expressions), len(expressions)), + includeParentheses: false, + } + + for i, e := range expressions { + q.group.clauses[i] = e + } + return q +} + +func (q *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement { + q.having = expression + return q +} + +func (q *selectStatementImpl) OrderBy( + clauses ...OrderByClause) SelectStatement { + + q.order = newOrderByListClause(clauses...) + return q +} + +func (q *selectStatementImpl) Limit(limit int64) SelectStatement { + q.limit = limit + return q +} + +func (q *selectStatementImpl) Distinct() SelectStatement { + q.distinct = true + return q +} + +func (q *selectStatementImpl) WithSharedLock() SelectStatement { + // We don't need to grab a read lock if we're going to grab a write one + if !q.forUpdate { + q.withSharedLock = true + } + return q +} + +func (q *selectStatementImpl) ForUpdate() SelectStatement { + // Clear a request for a shared lock if we're asking for a write one + q.withSharedLock = false + q.forUpdate = true + return q +} + +func (q *selectStatementImpl) Offset(offset int64) SelectStatement { + q.offset = offset + return q +} + +func (q *selectStatementImpl) Comment(comment string) SelectStatement { + q.comment = comment + return q +} + +// Return the properly escaped SQL statement, against the specified database +func (q *selectStatementImpl) String() (sql string, err error) { + buf := new(bytes.Buffer) + _, _ = buf.WriteString("SELECT ") + + if err = writeComment(q.comment, buf); err != nil { + return + } + + if q.distinct { + _, _ = buf.WriteString("DISTINCT ") + } + + if q.projections == nil || len(q.projections) == 0 { + return "", errors.Newf( + "No column selected. Generated sql: %s", + buf.String()) + } + + for i, col := range q.projections { + if i > 0 { + _ = buf.WriteByte(',') + } + if col == nil { + return "", errors.Newf( + "nil column selected. Generated sql: %s", + buf.String()) + } + if err = col.SerializeSqlForColumnList(buf); err != nil { + return + } + } + + _, _ = buf.WriteString(" FROM ") + if q.table == nil { + return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) + } + if err = q.table.SerializeSql(buf); err != nil { + return + } + + if q.where != nil { + _, _ = buf.WriteString(" WHERE ") + if err = q.where.SerializeSql(buf); err != nil { + return + } + } + + if q.group != nil { + _, _ = buf.WriteString(" GROUP BY ") + if err = q.group.SerializeSql(buf); err != nil { + return + } + } + + if q.having != nil { + buf.WriteString(" HAVING ") + if err = q.having.SerializeSql(buf); err != nil { + return + } + } + + if q.order != nil { + _, _ = buf.WriteString(" ORDER BY ") + if err = q.order.SerializeSql(buf); err != nil { + return + } + } + + if q.limit >= 0 { + if q.offset >= 0 { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d, %d", q.offset, q.limit)) + } else { + _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", q.limit)) + } + } + + if q.forUpdate { + _, _ = buf.WriteString(" FOR UPDATE") + } else if q.withSharedLock { + _, _ = buf.WriteString(" LOCK IN SHARE MODE") + } + + return buf.String(), nil +} diff --git a/sqlbuilder/select_statement_table.go b/sqlbuilder/select_statement_table.go new file mode 100644 index 0000000..8dcf5d4 --- /dev/null +++ b/sqlbuilder/select_statement_table.go @@ -0,0 +1,75 @@ +package sqlbuilder + +import "bytes" + +type SelectStatementTable struct { + statement SelectStatement + columns []NonAliasColumn + alias string +} + +func (s *SelectStatementTable) Columns() []NonAliasColumn { + return s.columns +} + +func (s *SelectStatementTable) Column(name string) NonAliasColumn { + return &baseColumn{ + name: name, + tableName: s.alias, + } +} + +func (s *SelectStatementTable) ColumnFrom(column NonAliasColumn) NonAliasColumn { + return &baseColumn{ + name: column.TableName() + "." + column.Name(), + tableName: s.alias, + } +} + +func (s *SelectStatementTable) SerializeSql(out *bytes.Buffer) error { + out.WriteString("( ") + statementStr, err := s.statement.String() + + if err != nil { + return err + } + + out.WriteString(statementStr) + + out.WriteString(" ) AS ") + out.WriteString(s.alias) + + return nil +} + +// Generates a select query on the current tableName. +func (s *SelectStatementTable) Select(projections ...Projection) SelectStatement { + return newSelectStatement(s, projections) +} + +// Creates a inner join tableName expression using onCondition. +func (s *SelectStatementTable) InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable { + return InnerJoinOn(s, table, onCondition) +} + +func (s *SelectStatementTable) InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable { + return InnerJoinOn(s, table, col1.Eq(col2)) +} + +// Creates a left join tableName expression using onCondition. +func (s *SelectStatementTable) LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable { + return LeftJoinOn(s, table, onCondition) +} + +// Creates a right join tableName expression using onCondition. +func (s *SelectStatementTable) RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable { + return RightJoinOn(s, table, onCondition) +} + +func (s *SelectStatementTable) FullJoin(table ReadableTable, col1 Column, col2 Column) ReadableTable { + return FullJoin(s, table, col1.Eq(col2)) +} + +func (s *SelectStatementTable) CrossJoin(table ReadableTable) ReadableTable { + return CrossJoin(s, table) +} diff --git a/sqlbuilder/statement.go b/sqlbuilder/statement.go index 471c0bc..1c244a3 100644 --- a/sqlbuilder/statement.go +++ b/sqlbuilder/statement.go @@ -4,8 +4,6 @@ import ( "bytes" "database/sql" "fmt" - "github.com/sub0Zero/go-sqlbuilder/sqlbuilder/execution" - "reflect" "regexp" "github.com/dropbox/godropbox/errors" @@ -17,22 +15,6 @@ type Statement interface { Execute(db *sql.DB, destination interface{}) error } -type SelectStatement interface { - Statement - - Where(expression BoolExpression) SelectStatement - AndWhere(expression BoolExpression) SelectStatement - GroupBy(expressions ...Expression) SelectStatement - OrderBy(clauses ...OrderByClause) SelectStatement - Limit(limit int64) SelectStatement - Offset(offset int64) SelectStatement - Distinct() SelectStatement - WithSharedLock() SelectStatement - ForUpdate() SelectStatement - Comment(comment string) SelectStatement - Copy() SelectStatement -} - type InsertStatement interface { Statement @@ -50,7 +32,7 @@ type InsertStatement interface { type UnionStatement interface { Statement - // Warning! You cannot include table names for the next 4 clauses, or + // Warning! You cannot include tableName names for the next 4 clauses, or // you'll get errors like: // Table 'server_file_journal' from one of the SELECTs cannot be used in // global ORDER clause @@ -91,9 +73,9 @@ type LockStatement interface { AddWriteLock(table *Table) LockStatement } -// UnlockStatement can be used to release table locks taken using LockStatement. +// UnlockStatement can be used to release tableName locks taken using LockStatement. // NOTE: You can not selectively release a lock and continue to hold lock on -// another table. UnlockStatement releases all the lock held in the current +// another tableName. UnlockStatement releases all the lock held in the current // session. type UnlockStatement interface { Statement @@ -222,7 +204,7 @@ func (us *unionStatementImpl) String() (sql string, err error) { return "", errors.Newf( "All inner selects in Union statement must select the " + "same number of columns. For sanity, you probably " + - "want to select the same table columns in the same " + + "want to select the same tableName columns in the same " + "order. If you are selecting on multiple tables, " + "use Null to pad to the right number of fields.") } @@ -279,212 +261,6 @@ func (us *unionStatementImpl) String() (sql string, err error) { return buf.String(), nil } -// -// SELECT Statement ============================================================ -// - -func newSelectStatement( - table ReadableTable, - projections []Projection) SelectStatement { - - return &selectStatementImpl{ - table: table, - projections: projections, - limit: -1, - offset: -1, - withSharedLock: false, - forUpdate: false, - distinct: false, - } -} - -// NOTE: SelectStatement purposely does not implement the Table interface since -// mysql's subquery performance is horrible. -type selectStatementImpl struct { - table ReadableTable - projections []Projection - where BoolExpression - group *listClause - order *listClause - comment string - limit, offset int64 - withSharedLock bool - forUpdate bool - distinct bool -} - -func (s *selectStatementImpl) Execute(db *sql.DB, destination interface{}) error { - destinationType := reflect.TypeOf(destination) - - if destinationType.Kind() == reflect.Ptr && destinationType.Elem().Kind() == reflect.Struct { - s.Limit(1) - } - - query, err := s.String() - - if err != nil { - return err - } - - return execution.Execute(db, query, destination) -} - -func (s *selectStatementImpl) Copy() SelectStatement { - ret := *s - return &ret -} - -// Further filter the query, instead of replacing the filter -func (q *selectStatementImpl) AndWhere( - expression BoolExpression) SelectStatement { - - if q.where == nil { - return q.Where(expression) - } - q.where = And(q.where, expression) - return q -} - -func (q *selectStatementImpl) Where(expression BoolExpression) SelectStatement { - q.where = expression - return q -} - -func (q *selectStatementImpl) GroupBy( - expressions ...Expression) SelectStatement { - - q.group = &listClause{ - clauses: make([]Clause, len(expressions), len(expressions)), - includeParentheses: false, - } - - for i, e := range expressions { - q.group.clauses[i] = e - } - return q -} - -func (q *selectStatementImpl) OrderBy( - clauses ...OrderByClause) SelectStatement { - - q.order = newOrderByListClause(clauses...) - return q -} - -func (q *selectStatementImpl) Limit(limit int64) SelectStatement { - q.limit = limit - return q -} - -func (q *selectStatementImpl) Distinct() SelectStatement { - q.distinct = true - return q -} - -func (q *selectStatementImpl) WithSharedLock() SelectStatement { - // We don't need to grab a read lock if we're going to grab a write one - if !q.forUpdate { - q.withSharedLock = true - } - return q -} - -func (q *selectStatementImpl) ForUpdate() SelectStatement { - // Clear a request for a shared lock if we're asking for a write one - q.withSharedLock = false - q.forUpdate = true - return q -} - -func (q *selectStatementImpl) Offset(offset int64) SelectStatement { - q.offset = offset - return q -} - -func (q *selectStatementImpl) Comment(comment string) SelectStatement { - q.comment = comment - return q -} - -// Return the properly escaped SQL statement, against the specified database -func (q *selectStatementImpl) String() (sql string, err error) { - buf := new(bytes.Buffer) - _, _ = buf.WriteString("SELECT ") - - if err = writeComment(q.comment, buf); err != nil { - return - } - - if q.distinct { - _, _ = buf.WriteString("DISTINCT ") - } - - if q.projections == nil || len(q.projections) == 0 { - return "", errors.Newf( - "No column selected. Generated sql: %s", - buf.String()) - } - - for i, col := range q.projections { - if i > 0 { - _ = buf.WriteByte(',') - } - if col == nil { - return "", errors.Newf( - "nil column selected. Generated sql: %s", - buf.String()) - } - if err = col.SerializeSqlForColumnList(buf); err != nil { - return - } - } - - _, _ = buf.WriteString(" FROM ") - if q.table == nil { - return "", errors.Newf("nil table. Generated sql: %s", buf.String()) - } - if err = q.table.SerializeSql(buf); err != nil { - return - } - - if q.where != nil { - _, _ = buf.WriteString(" WHERE ") - if err = q.where.SerializeSql(buf); err != nil { - return - } - } - - if q.group != nil { - _, _ = buf.WriteString(" GROUP BY ") - if err = q.group.SerializeSql(buf); err != nil { - return - } - } - - if q.order != nil { - _, _ = buf.WriteString(" ORDER BY ") - if err = q.order.SerializeSql(buf); err != nil { - return - } - } - - if q.limit >= 0 { - if q.offset >= 0 { - _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d, %d", q.offset, q.limit)) - } else { - _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", q.limit)) - } - } - - if q.forUpdate { - _, _ = buf.WriteString(" FOR UPDATE") - } else if q.withSharedLock { - _, _ = buf.WriteString(" LOCK IN SHARE MODE") - } - - return buf.String(), nil -} - // // INSERT Statement ============================================================ // @@ -560,7 +336,7 @@ func (s *insertStatementImpl) String() (sql string, err error) { } if s.table == nil { - return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) } if err = s.table.SerializeSql(buf); err != nil { @@ -728,7 +504,7 @@ func (u *updateStatementImpl) String() (sql string, err error) { } if u.table == nil { - return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) } if err = u.table.SerializeSql(buf); err != nil { @@ -863,7 +639,7 @@ func (d *deleteStatementImpl) String() (sql string, err error) { } if d.table == nil { - return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) } if err = d.table.SerializeSql(buf); err != nil { @@ -919,13 +695,13 @@ func (l *lockStatementImpl) Execute(db *sql.DB, data interface{}) error { return nil } -// AddReadLock takes read lock on the table. +// AddReadLock takes read lock on the tableName. func (s *lockStatementImpl) AddReadLock(t *Table) LockStatement { s.locks = append(s.locks, tableLock{t: t, w: false}) return s } -// AddWriteLock takes write lock on the table. +// AddWriteLock takes write lock on the tableName. func (s *lockStatementImpl) AddWriteLock(t *Table) LockStatement { s.locks = append(s.locks, tableLock{t: t, w: true}) return s @@ -941,7 +717,7 @@ func (s *lockStatementImpl) String() (sql string, err error) { for idx, lock := range s.locks { if lock.t == nil { - return "", errors.Newf("nil table. Generated sql: %s", buf.String()) + return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) } if err = lock.t.SerializeSql(buf); err != nil { @@ -962,7 +738,7 @@ func (s *lockStatementImpl) String() (sql string, err error) { return buf.String(), nil } -// NewUnlockStatement returns SQL statement that can be used to release table locks +// NewUnlockStatement returns SQL statement that can be used to release tableName locks // grabbed by the current session. func NewUnlockStatement() UnlockStatement { return &unlockStatementImpl{} diff --git a/sqlbuilder/statement_test.go b/sqlbuilder/statement_test.go index ba17324..8ebeda4 100644 --- a/sqlbuilder/statement_test.go +++ b/sqlbuilder/statement_test.go @@ -609,7 +609,7 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) { gc.Equals, "All inner selects in Union statement must select the "+ "same number of columns. For sanity, you probably "+ - "want to select the same table columns in the same "+ + "want to select the same tableName columns in the same "+ "order. If you are selecting on multiple tables, "+ "use Null to pad to the right number of fields.") } diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index 1e40f35..ddfdbe6 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -8,29 +8,31 @@ import ( "github.com/dropbox/godropbox/errors" ) -// The sql table read interface. NOTE: NATURAL JOINs, and join "USING" clause +// The sql tableName read interface. NOTE: NATURAL JOINs, and join "USING" clause // are not supported. type ReadableTable interface { - // Returns the list of columns that are in the current table expression. + // Returns the list of columns that are in the current tableName expression. Columns() []NonAliasColumn - // Generates the sql string for the current table expression. Note: the + Column(name string) NonAliasColumn + + // Generates the sql string for the current tableName expression. Note: the // generated string may not be a valid/executable sql statement. - // The database is the name of the database the table is on + // The database is the name of the database the tableName is on SerializeSql(out *bytes.Buffer) error - // Generates a select query on the current table. + // Generates a select query on the current tableName. Select(projections ...Projection) SelectStatement - // Creates a inner join table expression using onCondition. + // Creates a inner join tableName expression using onCondition. InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable - // Creates a left join table expression using onCondition. + // Creates a left join tableName expression using onCondition. LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable - // Creates a right join table expression using onCondition. + // Creates a right join tableName expression using onCondition. RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable FullJoin(table ReadableTable, col1 Column, col2 Column) ReadableTable @@ -38,14 +40,14 @@ type ReadableTable interface { CrossJoin(table ReadableTable) ReadableTable } -// The sql table write interface. +// The sql tableName write interface. type WritableTable interface { - // Returns the list of columns that are in the table. + // Returns the list of columns that are in the tableName. Columns() []NonAliasColumn - // Generates the sql string for the current table expression. Note: the + // Generates the sql string for the current tableName expression. Note: the // generated string may not be a valid/executable sql statement. - // The database is the name of the database the table is on + // The database is the name of the database the tableName is on SerializeSql(out *bytes.Buffer) error Insert(columns ...NonAliasColumn) InsertStatement @@ -53,11 +55,11 @@ type WritableTable interface { Delete() DeleteStatement } -// Defines a physical table in the database that is both readable and writable. +// Defines a physical tableName in the database that is both readable and writable. // This function will panic if name is not valid func NewTable(schemaName, name string, columns ...NonAliasColumn) *Table { if !validIdentifierName(name) { - panic("Invalid table name") + panic("Invalid tableName name") } t := &Table{ @@ -91,28 +93,28 @@ type Table struct { forcedIndex string } -// Returns the specified column, or errors if it doesn't exist in the table +// Returns the specified column, or errors if it doesn't exist in the tableName func (t *Table) getColumn(name string) (NonAliasColumn, error) { if c, ok := t.columnLookup[name]; ok { return c, nil } - return nil, errors.Newf("No such column '%s' in table '%s'", name, t.name) + return nil, errors.Newf("No such column '%s' in tableName '%s'", name, t.name) } -// Returns a pseudo column representation of the column name. Error checking -// is deferred to SerializeSql. -//func (t *Table) C(name string) NonAliasColumn { -// return &deferredLookupColumn{ -// table: t, -// colName: name, -// } -//} +func (t *Table) Column(name string) NonAliasColumn { + return &baseColumn{ + name: name, + nullable: NotNullable, + tableName: t.name, + } +} -// Returns all columns for a table as a slice of projections +// Returns all columns for a tableName as a slice of projections func (t *Table) Projections() []Projection { result := make([]Projection, 0) for _, col := range t.columns { + col.Asc() result = append(result, col) } @@ -130,7 +132,7 @@ func (t *Table) SetAlias(alias string) { } } -// Returns the table's name in the database +// Returns the tableName's name in the database func (t *Table) Name() string { return t.name } @@ -139,19 +141,19 @@ func (t *Table) SchemaName() string { return t.schemaName } -// Returns a list of the table's columns +// Returns a list of the tableName's columns func (t *Table) Columns() []NonAliasColumn { return t.columns } -// Returns a copy of this table, but with the specified index forced. +// Returns a copy of this tableName, but with the specified index forced. func (t *Table) ForceIndex(index string) *Table { newTable := *t newTable.forcedIndex = index return &newTable } -// Generates the sql string for the current table expression. Note: the +// Generates the sql string for the current tableName expression. Note: the // generated string may not be a valid/executable sql statement. func (t *Table) SerializeSql(out *bytes.Buffer) error { if !validIdentifierName(t.schemaName) { @@ -179,12 +181,12 @@ func (t *Table) SerializeSql(out *bytes.Buffer) error { return nil } -// Generates a select query on the current table. +// Generates a select query on the current tableName. func (t *Table) Select(projections ...Projection) SelectStatement { return newSelectStatement(t, projections) } -// Creates a inner join table expression using onCondition. +// Creates a inner join tableName expression using onCondition. func (t *Table) InnerJoinOn( table ReadableTable, onCondition BoolExpression) ReadableTable { @@ -200,7 +202,7 @@ func (t *Table) InnerJoinUsing( return InnerJoinOn(t, table, col1.Eq(col2)) } -// Creates a left join table expression using onCondition. +// Creates a left join tableName expression using onCondition. func (t *Table) LeftJoinOn( table ReadableTable, onCondition BoolExpression) ReadableTable { @@ -208,7 +210,7 @@ func (t *Table) LeftJoinOn( return LeftJoinOn(t, table, onCondition) } -// Creates a right join table expression using onCondition. +// Creates a right join tableName expression using onCondition. func (t *Table) RightJoinOn( table ReadableTable, onCondition BoolExpression) ReadableTable { @@ -315,6 +317,10 @@ func (t *joinTable) Columns() []NonAliasColumn { return columns } +func (t *joinTable) Column(name string) NonAliasColumn { + panic("Not implemented") +} + func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) { if t.lhs == nil { diff --git a/sqlbuilder/table_test.go b/sqlbuilder/table_test.go index da55cad..7a5a7af 100644 --- a/sqlbuilder/table_test.go +++ b/sqlbuilder/table_test.go @@ -52,7 +52,7 @@ func (s *TableSuite) TestValidForcedIndex(c *gc.C) { sql := buf.String() c.Assert(sql, gc.Equals, "db.table1 FORCE INDEX (foo)") - // Ensure the original table is unchanged + // Ensure the original tableName is unchanged buf = &bytes.Buffer{} err = table1.SerializeSql(buf) c.Assert(err, gc.IsNil) diff --git a/sqlbuilder/types.go b/sqlbuilder/types.go index 271f130..3aa0b4b 100644 --- a/sqlbuilder/types.go +++ b/sqlbuilder/types.go @@ -32,6 +32,8 @@ type BoolExpression interface { type Projection interface { Clause isProjectionInterface + + As(alias string) Projection SerializeSqlForColumnList(out *bytes.Buffer) error } @@ -51,6 +53,10 @@ func (cl ColumnList) SerializeSql(out *bytes.Buffer) error { func (cl ColumnList) isProjectionType() { } +func (cl ColumnList) As(name string) Projection { + panic("Unallowed operation ") +} + func (cl ColumnList) SerializeSqlForColumnList(out *bytes.Buffer) error { for i, column := range cl { column.SerializeSqlForColumnList(out) diff --git a/tests/generator_test.go b/tests/generator_test.go index 9a4450e..bd47281 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -3,8 +3,8 @@ package tests import ( "database/sql" "fmt" - "github.com/davecgh/go-spew/spew" "github.com/sub0Zero/go-sqlbuilder/generator" + "github.com/sub0Zero/go-sqlbuilder/sqlbuilder" "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model" . "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table" "gotest.tools/assert" @@ -79,6 +79,7 @@ func TestSelect_ScanToSlice(t *testing.T) { queryStr, err := query.String() assert.NilError(t, err) + fmt.Println(queryStr) assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active" FROM dvds.customer ORDER BY customer.customer_id ASC`) err = query.Execute(db, &customers) @@ -119,7 +120,7 @@ func TestJoinQueryStruct(t *testing.T) { func TestJoinQuerySlice(t *testing.T) { type FilmsPerLanguage struct { Language *model.Language - Films *[]model.Film + Film *[]model.Film } filmsPerLanguage := []FilmsPerLanguage{} @@ -143,8 +144,10 @@ func TestJoinQuerySlice(t *testing.T) { //fmt.Println("--------------- result --------------- ") //spew.Dump(filmsPerLanguage) + //spew.Dump(filmsPerLanguage) + assert.Equal(t, len(filmsPerLanguage), 1) - assert.Equal(t, len(*filmsPerLanguage[0].Films), limit) + assert.Equal(t, len(*filmsPerLanguage[0].Film), limit) //spew.Dump(filmsPerLanguage) @@ -153,13 +156,13 @@ func TestJoinQuerySlice(t *testing.T) { assert.NilError(t, err) assert.Equal(t, len(filmsPerLanguage), 1) - assert.Equal(t, len(*filmsPerLanguage[0].Films), limit) + assert.Equal(t, len(*filmsPerLanguage[0].Film), limit) } func TestJoinQuerySliceWithPtrs(t *testing.T) { type FilmsPerLanguage struct { Language model.Language - Films *[]*model.Film + Film *[]*model.Film } limit := int64(3) @@ -175,7 +178,7 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) { assert.NilError(t, err) assert.Equal(t, len(filmsPerLanguageWithPtrs), 1) - assert.Equal(t, len(*filmsPerLanguageWithPtrs[0].Films), int(limit)) + assert.Equal(t, len(*filmsPerLanguageWithPtrs[0].Film), int(limit)) } func TestSelect_WithoutUniqueColumnSelected(t *testing.T) { @@ -323,7 +326,7 @@ func TestSelectSelfJoin(t *testing.T) { assert.NilError(t, err) - spew.Dump(theSameLengthFilms[0]) + //spew.Dump(theSameLengthFilms[0]) assert.Equal(t, len(theSameLengthFilms), 6972) } @@ -343,7 +346,7 @@ func TestSelectAliasColumn(t *testing.T) { Select(f1.Title.As("thesame_length_films.title1"), f2.Title.As("thesame_length_films.title2"), f1.Length.As("thesame_length_films.length")). - OrderBy(f1.Length.Asc()). + OrderBy(f1.Length.Asc(), f1.Title.Asc(), f2.Title.Asc()). Limit(1000) queryStr, err := query.String() @@ -361,7 +364,227 @@ func TestSelectAliasColumn(t *testing.T) { //spew.Dump(films) assert.Equal(t, len(films), 1000) - assert.DeepEqual(t, films[0], thesameLengthFilms{"Ridgemont Submarine", "Iron Moon", 46}) + assert.DeepEqual(t, films[0], thesameLengthFilms{"Alien Center", "Iron Moon", 46}) +} + +type Manager staff + +type staff struct { + StaffID int32 `sql:"unique"` + FirstName string + LastName string + //Address *model.Address + //Email *string + //StoreID int16 + //Active bool + //Username string + //Password *string + //LastUpdate time.Time + *Manager //`sqlbuilder:"manager"` +} + +func TestSelectSelfReferenceType(t *testing.T) { + + manager := Staff.As("manager") + + query := Staff. + InnerJoinUsing(Address, Staff.AddressID, Address.AddressID). + InnerJoinUsing(manager, Staff.StaffID, manager.StaffID). + Select(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName) + + queryStr, err := query.String() + assert.NilError(t, err) + fmt.Println(queryStr) + + staffs := []staff{} + + err = query.Execute(db, &staffs) + + assert.NilError(t, err) + + //spew.Dump(staffs) +} + +func TestSubQuery(t *testing.T) { + + //selectStmtTable := Actor.Select(Actor.FirstName, Actor.LastName).AsTable("table_expression") + // + //query := selectStmtTable.Select( + // selectStmtTable.ColumnFrom(Actor.FirstName).As("nesto"), + // selectStmtTable.Column("actor.last_name").As("nesto2"), + // ) + // + //queryStr, err := query.String() + // + //assert.NilError(t, err) + // + //fmt.Println(queryStr) + + //avrgCustomer := Customer.Select(Customer.LastName).Limit(1).AsExpression() + // + //Customer. + // InnerJoinUsing(selectStmtTable, Customer.LastName, selectStmtTable.Column("first_name")). + // Select(Customer.AllColumns, selectStmtTable.Column("first_name")). + // Where(Actor.LastName.Neq(avrgCustomer)) + + rFilmsOnly := Film.Select(Film.FilmID, Film.Title, Film.Rating). + Where(Film.Rating.Eq(sqlbuilder.Literal("R"))). + AsTable("films") + + query := Actor.InnerJoinUsing(FilmActor, Actor.ActorID, FilmActor.FilmID). + InnerJoinUsing(rFilmsOnly, FilmActor.FilmID, rFilmsOnly.ColumnFrom(Film.FilmID)). + Select( + Actor.AllColumns, + FilmActor.AllColumns, + rFilmsOnly.ColumnFrom(Film.Title).As("film.title"), + rFilmsOnly.ColumnFrom(Film.Rating).As("film.rating"), + ) + + queryStr, err := query.String() + + assert.NilError(t, err) + + fmt.Println(queryStr) + +} + +func TestSelectFunctions(t *testing.T) { + query := Film.Select(sqlbuilder.MAX(Film.RentalRate).As("max_film_rate")) + + str, err := query.String() + + assert.NilError(t, err) + + assert.Equal(t, str, `SELECT MAX(film.rental_rate) AS "max_film_rate" FROM dvds.film`) + + fmt.Println(str) +} + +func TestSelectQueryScalar(t *testing.T) { + + maxFilmRentalRate := Film.Select(sqlbuilder.MAX(Film.RentalRate)) + + query := Film.Select(Film.AllColumns). + Where(Film.RentalRate.Eq(maxFilmRentalRate)). + OrderBy(Film.FilmID) + + queryStr, err := query.String() + + assert.NilError(t, err) + + fmt.Println(queryStr) + + maxRentalRateFilms := []model.Film{} + err = query.Execute(db, &maxRentalRateFilms) + + assert.NilError(t, err) + + assert.Equal(t, len(maxRentalRateFilms), 336) + + assert.DeepEqual(t, maxRentalRateFilms[0], model.Film{ + FilmID: 2, + Title: "Ace Goldfinger", + Description: stringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), + ReleaseYear: int32Ptr(2006), + Language: nil, + RentalRate: 4.99, + Length: int16Ptr(48), + ReplacementCost: 12.99, + Rating: stringPtr("G"), + RentalDuration: 3, + LastUpdate: *timeWithoutTimeZone("2013-05-26 14:50:58.951 +0000"), + SpecialFeatures: stringPtr("{Trailers,\"Deleted Scenes\"}"), + Fulltext: "'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14", + }) + + //spew.Dump(maxRentalRateFilms[0]) +} + +func TestSelectGroupByHaving(t *testing.T) { + customersPaymentQuery := Payment. + Select( + Payment.CustomerID.As("customer_payment_sum.customer_id"), + sqlbuilder.SUM(Payment.Amount).As("customer_payment_sum.amount_sum"), + ). + GroupBy(Payment.CustomerID). + OrderBy(sqlbuilder.SUM(Payment.Amount)). + HAVING(sqlbuilder.Gt(sqlbuilder.SUM(Payment.Amount), sqlbuilder.Literal(100))) + + queryStr, err := customersPaymentQuery.String() + + assert.NilError(t, err) + fmt.Println(queryStr) + + assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id",SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount)>100 ORDER BY SUM(payment.amount)`) + + type CustomerPaymentSum struct { + CustomerID int16 + AmountSum float64 + } + + customerPaymentSum := []CustomerPaymentSum{} + + err = customersPaymentQuery.Execute(db, &customerPaymentSum) + + assert.NilError(t, err) + + assert.Equal(t, len(customerPaymentSum), 296) + assert.DeepEqual(t, customerPaymentSum[0], CustomerPaymentSum{ + CustomerID: 135, + AmountSum: 100.72, + }) +} + +func TestSelectGroupBy2(t *testing.T) { + type CustomerWithAmounts struct { + Customer *model.Customer + AmountSum float64 + } + customersWithAmounts := []CustomerWithAmounts{} + + customersPaymentSubQuery := Payment. + Select( + Payment.CustomerID, + sqlbuilder.SUM(Payment.Amount).As("amount_sum"), + ). + GroupBy(Payment.CustomerID) + + customersPaymentTable := customersPaymentSubQuery.AsTable("customer_payment_sum") + amountSumColumn := customersPaymentTable.Column("amount_sum") + + query := Customer. + InnerJoinUsing(customersPaymentTable, Customer.CustomerID, customersPaymentTable.ColumnFrom(Payment.CustomerID)). + Select(Customer.AllColumns, amountSumColumn.As("customer_with_amounts.amount_sum")). + OrderBy(amountSumColumn) + + queryStr, err := query.String() + assert.NilError(t, err) + fmt.Println(queryStr) + + err = query.Execute(db, &customersWithAmounts) + assert.NilError(t, err) + //spew.Dump(customersWithAmounts) + + assert.Equal(t, len(customersWithAmounts), 599) + + assert.DeepEqual(t, customersWithAmounts[0].Customer, &model.Customer{ + CustomerID: 318, + StoreID: 1, + FirstName: "Brian", + LastName: "Wyman", + Email: stringPtr("brian.wyman@sakilacustomer.org"), + Activebool: true, + CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00 +0000"), + LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738 +0000"), + Active: int32Ptr(1), + }) + + assert.Equal(t, customersWithAmounts[0].AmountSum, 27.93) + +} + +func int16Ptr(i int16) *int16 { + return &i } func int32Ptr(i int32) *int32 {