Support for subqueries, Group By and Having clause.

This commit is contained in:
zer0sub 2019-03-30 10:17:32 +01:00
parent ddf816c998
commit 5a7563d4af
12 changed files with 674 additions and 305 deletions

View file

@ -4,8 +4,6 @@ import (
"bytes"
"database/sql"
"fmt"
"github.com/sub0Zero/go-sqlbuilder/sqlbuilder/execution"
"reflect"
"regexp"
"github.com/dropbox/godropbox/errors"
@ -17,22 +15,6 @@ type Statement interface {
Execute(db *sql.DB, destination interface{}) error
}
type SelectStatement interface {
Statement
Where(expression BoolExpression) SelectStatement
AndWhere(expression BoolExpression) SelectStatement
GroupBy(expressions ...Expression) SelectStatement
OrderBy(clauses ...OrderByClause) SelectStatement
Limit(limit int64) SelectStatement
Offset(offset int64) SelectStatement
Distinct() SelectStatement
WithSharedLock() SelectStatement
ForUpdate() SelectStatement
Comment(comment string) SelectStatement
Copy() SelectStatement
}
type InsertStatement interface {
Statement
@ -50,7 +32,7 @@ type InsertStatement interface {
type UnionStatement interface {
Statement
// Warning! You cannot include table names for the next 4 clauses, or
// Warning! You cannot include tableName names for the next 4 clauses, or
// you'll get errors like:
// Table 'server_file_journal' from one of the SELECTs cannot be used in
// global ORDER clause
@ -91,9 +73,9 @@ type LockStatement interface {
AddWriteLock(table *Table) LockStatement
}
// UnlockStatement can be used to release table locks taken using LockStatement.
// UnlockStatement can be used to release tableName locks taken using LockStatement.
// NOTE: You can not selectively release a lock and continue to hold lock on
// another table. UnlockStatement releases all the lock held in the current
// another tableName. UnlockStatement releases all the lock held in the current
// session.
type UnlockStatement interface {
Statement
@ -222,7 +204,7 @@ func (us *unionStatementImpl) String() (sql string, err error) {
return "", errors.Newf(
"All inner selects in Union statement must select the " +
"same number of columns. For sanity, you probably " +
"want to select the same table columns in the same " +
"want to select the same tableName columns in the same " +
"order. If you are selecting on multiple tables, " +
"use Null to pad to the right number of fields.")
}
@ -279,212 +261,6 @@ func (us *unionStatementImpl) String() (sql string, err error) {
return buf.String(), nil
}
//
// SELECT Statement ============================================================
//
func newSelectStatement(
table ReadableTable,
projections []Projection) SelectStatement {
return &selectStatementImpl{
table: table,
projections: projections,
limit: -1,
offset: -1,
withSharedLock: false,
forUpdate: false,
distinct: false,
}
}
// NOTE: SelectStatement purposely does not implement the Table interface since
// mysql's subquery performance is horrible.
type selectStatementImpl struct {
table ReadableTable
projections []Projection
where BoolExpression
group *listClause
order *listClause
comment string
limit, offset int64
withSharedLock bool
forUpdate bool
distinct bool
}
func (s *selectStatementImpl) Execute(db *sql.DB, destination interface{}) error {
destinationType := reflect.TypeOf(destination)
if destinationType.Kind() == reflect.Ptr && destinationType.Elem().Kind() == reflect.Struct {
s.Limit(1)
}
query, err := s.String()
if err != nil {
return err
}
return execution.Execute(db, query, destination)
}
func (s *selectStatementImpl) Copy() SelectStatement {
ret := *s
return &ret
}
// Further filter the query, instead of replacing the filter
func (q *selectStatementImpl) AndWhere(
expression BoolExpression) SelectStatement {
if q.where == nil {
return q.Where(expression)
}
q.where = And(q.where, expression)
return q
}
func (q *selectStatementImpl) Where(expression BoolExpression) SelectStatement {
q.where = expression
return q
}
func (q *selectStatementImpl) GroupBy(
expressions ...Expression) SelectStatement {
q.group = &listClause{
clauses: make([]Clause, len(expressions), len(expressions)),
includeParentheses: false,
}
for i, e := range expressions {
q.group.clauses[i] = e
}
return q
}
func (q *selectStatementImpl) OrderBy(
clauses ...OrderByClause) SelectStatement {
q.order = newOrderByListClause(clauses...)
return q
}
func (q *selectStatementImpl) Limit(limit int64) SelectStatement {
q.limit = limit
return q
}
func (q *selectStatementImpl) Distinct() SelectStatement {
q.distinct = true
return q
}
func (q *selectStatementImpl) WithSharedLock() SelectStatement {
// We don't need to grab a read lock if we're going to grab a write one
if !q.forUpdate {
q.withSharedLock = true
}
return q
}
func (q *selectStatementImpl) ForUpdate() SelectStatement {
// Clear a request for a shared lock if we're asking for a write one
q.withSharedLock = false
q.forUpdate = true
return q
}
func (q *selectStatementImpl) Offset(offset int64) SelectStatement {
q.offset = offset
return q
}
func (q *selectStatementImpl) Comment(comment string) SelectStatement {
q.comment = comment
return q
}
// Return the properly escaped SQL statement, against the specified database
func (q *selectStatementImpl) String() (sql string, err error) {
buf := new(bytes.Buffer)
_, _ = buf.WriteString("SELECT ")
if err = writeComment(q.comment, buf); err != nil {
return
}
if q.distinct {
_, _ = buf.WriteString("DISTINCT ")
}
if q.projections == nil || len(q.projections) == 0 {
return "", errors.Newf(
"No column selected. Generated sql: %s",
buf.String())
}
for i, col := range q.projections {
if i > 0 {
_ = buf.WriteByte(',')
}
if col == nil {
return "", errors.Newf(
"nil column selected. Generated sql: %s",
buf.String())
}
if err = col.SerializeSqlForColumnList(buf); err != nil {
return
}
}
_, _ = buf.WriteString(" FROM ")
if q.table == nil {
return "", errors.Newf("nil table. Generated sql: %s", buf.String())
}
if err = q.table.SerializeSql(buf); err != nil {
return
}
if q.where != nil {
_, _ = buf.WriteString(" WHERE ")
if err = q.where.SerializeSql(buf); err != nil {
return
}
}
if q.group != nil {
_, _ = buf.WriteString(" GROUP BY ")
if err = q.group.SerializeSql(buf); err != nil {
return
}
}
if q.order != nil {
_, _ = buf.WriteString(" ORDER BY ")
if err = q.order.SerializeSql(buf); err != nil {
return
}
}
if q.limit >= 0 {
if q.offset >= 0 {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d, %d", q.offset, q.limit))
} else {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", q.limit))
}
}
if q.forUpdate {
_, _ = buf.WriteString(" FOR UPDATE")
} else if q.withSharedLock {
_, _ = buf.WriteString(" LOCK IN SHARE MODE")
}
return buf.String(), nil
}
//
// INSERT Statement ============================================================
//
@ -560,7 +336,7 @@ func (s *insertStatementImpl) String() (sql string, err error) {
}
if s.table == nil {
return "", errors.Newf("nil table. Generated sql: %s", buf.String())
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
}
if err = s.table.SerializeSql(buf); err != nil {
@ -728,7 +504,7 @@ func (u *updateStatementImpl) String() (sql string, err error) {
}
if u.table == nil {
return "", errors.Newf("nil table. Generated sql: %s", buf.String())
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
}
if err = u.table.SerializeSql(buf); err != nil {
@ -863,7 +639,7 @@ func (d *deleteStatementImpl) String() (sql string, err error) {
}
if d.table == nil {
return "", errors.Newf("nil table. Generated sql: %s", buf.String())
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
}
if err = d.table.SerializeSql(buf); err != nil {
@ -919,13 +695,13 @@ func (l *lockStatementImpl) Execute(db *sql.DB, data interface{}) error {
return nil
}
// AddReadLock takes read lock on the table.
// AddReadLock takes read lock on the tableName.
func (s *lockStatementImpl) AddReadLock(t *Table) LockStatement {
s.locks = append(s.locks, tableLock{t: t, w: false})
return s
}
// AddWriteLock takes write lock on the table.
// AddWriteLock takes write lock on the tableName.
func (s *lockStatementImpl) AddWriteLock(t *Table) LockStatement {
s.locks = append(s.locks, tableLock{t: t, w: true})
return s
@ -941,7 +717,7 @@ func (s *lockStatementImpl) String() (sql string, err error) {
for idx, lock := range s.locks {
if lock.t == nil {
return "", errors.Newf("nil table. Generated sql: %s", buf.String())
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
}
if err = lock.t.SerializeSql(buf); err != nil {
@ -962,7 +738,7 @@ func (s *lockStatementImpl) String() (sql string, err error) {
return buf.String(), nil
}
// NewUnlockStatement returns SQL statement that can be used to release table locks
// NewUnlockStatement returns SQL statement that can be used to release tableName locks
// grabbed by the current session.
func NewUnlockStatement() UnlockStatement {
return &unlockStatementImpl{}