Postgres refactor.

This commit is contained in:
go-jet 2019-08-11 09:52:02 +02:00
parent d00167cbba
commit 8519ccbdd0
57 changed files with 2451 additions and 598 deletions

View file

@ -1,270 +1,702 @@
package jet
import (
"bytes"
"errors"
"github.com/go-jet/jet/internal/utils"
"github.com/google/uuid"
"strconv"
"strings"
"time"
)
type SerializeOption int
const (
noWrap SerializeOption = iota
)
type Clause interface {
serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error
Serialize(statementType StatementType, out *SqlBuilder) error
}
func Serialize(exp Clause, statementType StatementType, out *SqlBuilder, options ...SerializeOption) error {
return exp.serialize(statementType, out, options...)
type ClauseWithProjections interface {
Clause
projections() []Projection
}
func contains(options []SerializeOption, option SerializeOption) bool {
for _, opt := range options {
if opt == option {
return true
type ClauseSelect struct {
Distinct bool
Projections []Projection
}
func (s *ClauseSelect) projections() []Projection {
return s.Projections
}
func (s *ClauseSelect) Serialize(statementType StatementType, out *SqlBuilder) error {
out.newLine()
out.WriteString("SELECT")
if s.Distinct {
out.WriteString("DISTINCT")
}
if len(s.Projections) == 0 {
return errors.New("jet: no column selected for Projection")
}
return out.writeProjections(statementType, s.Projections)
}
type ClauseFrom struct {
Table Serializer
}
func (f *ClauseFrom) Serialize(statementType StatementType, out *SqlBuilder) error {
if f.Table == nil {
return nil
}
return out.writeFrom(statementType, f.Table)
}
type ClauseWhere struct {
Condition BoolExpression
Mandatory bool
}
func (c *ClauseWhere) Serialize(statementType StatementType, out *SqlBuilder) error {
if c.Condition == nil {
if c.Mandatory {
return errors.New("jet: WHERE clause not set")
}
return nil
}
return false
return out.writeWhere(statementType, c.Condition)
}
type SqlBuilder struct {
Dialect Dialect
Buff bytes.Buffer
Args []interface{}
lastChar byte
ident int
type ClauseGroupBy struct {
List []GroupByClause
}
func (s *SqlBuilder) DebugSQL() string {
return queryStringToDebugString(s.Buff.String(), s.Args, s.Dialect)
}
type StatementType string
const (
SelectStatementType StatementType = "SELECT"
InsertStatementType StatementType = "INSERT"
UpdateStatementType StatementType = "UPDATE"
DeleteStatementType StatementType = "DELETE"
SetStatementType StatementType = "SET"
LockStatementType StatementType = "LOCK"
)
const defaultIdent = 5
func (q *SqlBuilder) increaseIdent() {
q.ident += defaultIdent
}
func (q *SqlBuilder) decreaseIdent() {
if q.ident < defaultIdent {
q.ident = 0
}
q.ident -= defaultIdent
}
func (q *SqlBuilder) writeProjections(statement StatementType, projections []Projection) error {
q.increaseIdent()
err := SerializeProjectionList(statement, projections, q)
q.decreaseIdent()
return err
}
func (q *SqlBuilder) writeFrom(statement StatementType, table ReadableTable) error {
q.newLine()
q.WriteString("FROM")
q.increaseIdent()
err := table.serialize(statement, q)
q.decreaseIdent()
return err
}
func (q *SqlBuilder) writeWhere(statement StatementType, where Expression) error {
q.newLine()
q.WriteString("WHERE")
q.increaseIdent()
err := where.serialize(statement, q, noWrap)
q.decreaseIdent()
return err
}
func (q *SqlBuilder) writeGroupBy(statement StatementType, groupBy []groupByClause) error {
q.newLine()
q.WriteString("GROUP BY")
q.increaseIdent()
err := serializeGroupByClauseList(statement, groupBy, q)
q.decreaseIdent()
return err
}
func (q *SqlBuilder) writeOrderBy(statement StatementType, orderBy []orderByClause) error {
q.newLine()
q.WriteString("ORDER BY")
q.increaseIdent()
err := serializeOrderByClauseList(statement, orderBy, q)
q.decreaseIdent()
return err
}
func (q *SqlBuilder) writeHaving(statement StatementType, having Expression) error {
q.newLine()
q.WriteString("HAVING")
q.increaseIdent()
err := having.serialize(statement, q, noWrap)
q.decreaseIdent()
return err
}
func (q *SqlBuilder) writeReturning(statement StatementType, returning []Projection) error {
if len(returning) == 0 {
func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SqlBuilder) error {
if len(c.List) == 0 {
return nil
}
if !q.Dialect.SupportsReturning() {
panic("jet: " + q.Dialect.Name() + " dialect does not support RETURNING.")
out.newLine()
out.WriteString("GROUP BY")
out.increaseIdent()
err := serializeGroupByClauseList(statementType, c.List, out)
out.decreaseIdent()
return err
}
type ClauseHaving struct {
Condition BoolExpression
}
func (c *ClauseHaving) Serialize(statementType StatementType, out *SqlBuilder) error {
if c.Condition == nil {
return nil
}
q.newLine()
q.WriteString("RETURNING")
q.increaseIdent()
return q.writeProjections(statement, returning)
return out.writeHaving(statementType, c.Condition)
}
func (q *SqlBuilder) newLine() {
q.write([]byte{'\n'})
q.write(bytes.Repeat([]byte{' '}, q.ident))
type ClauseOrderBy struct {
List []OrderByClause
}
func (q *SqlBuilder) write(data []byte) {
if len(data) == 0 {
func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SqlBuilder) error {
if o.List == nil {
return nil
}
return out.writeOrderBy(statementType, o.List)
}
type ClauseLimit struct {
Count int64
}
func (l *ClauseLimit) Serialize(statementType StatementType, out *SqlBuilder) error {
if l.Count >= 0 {
out.newLine()
out.WriteString("LIMIT")
out.insertParametrizedArgument(l.Count)
}
return nil
}
type ClauseOffset struct {
Count int64
}
func (o *ClauseOffset) Serialize(statementType StatementType, out *SqlBuilder) error {
if o.Count >= 0 {
out.newLine()
out.WriteString("OFFSET")
out.insertParametrizedArgument(o.Count)
}
return nil
}
type ClauseFor struct {
Lock SelectLock
}
func (f *ClauseFor) Serialize(statementType StatementType, out *SqlBuilder) error {
if f.Lock == nil {
return nil
}
out.newLine()
out.WriteString("FOR")
return f.Lock.serialize(statementType, out)
}
type ClauseSetStmtOperator struct {
Operator string
All bool
Selects []StatementWithProjections
OrderBy ClauseOrderBy
Limit ClauseLimit
Offset ClauseOffset
}
func (s *ClauseSetStmtOperator) projections() []Projection {
if len(s.Selects) > 0 {
return s.Selects[0].projections()
}
return nil
}
func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlBuilder) error {
if len(s.Selects) < 2 {
return errors.New("jet: UNION Statement must have at least two SELECT statements")
}
wrap := s.OrderBy.List != nil || s.Limit.Count >= 0 || s.Offset.Count >= 0
//if wrap {
// out.WriteString("(")
// out.increaseIdent()
//}
if wrap {
out.newLine()
out.WriteString("(")
out.increaseIdent()
}
for i, selectStmt := range s.Selects {
out.newLine()
if i > 0 {
out.WriteString(s.Operator)
if s.All {
out.WriteString("ALL")
}
out.newLine()
}
if selectStmt == nil {
return errors.New("jet: select statement is nil")
}
err := selectStmt.serialize(statementType, out)
if err != nil {
return err
}
}
if wrap {
out.decreaseIdent()
out.newLine()
out.WriteString(")")
}
if err := s.OrderBy.Serialize(statementType, out); err != nil {
return err
}
if err := s.Limit.Serialize(statementType, out); err != nil {
return err
}
if err := s.Offset.Serialize(statementType, out); err != nil {
return err
}
//if wrap {
// out.decreaseIdent()
// out.newLine()
// out.WriteString(")")
//}
return nil
}
type ClauseUpdate struct {
Table SerializerTable
}
func (u *ClauseUpdate) Serialize(statementType StatementType, out *SqlBuilder) error {
out.newLine()
out.WriteString("UPDATE")
if utils.IsNil(u.Table) {
return errors.New("jet: table to update is nil")
}
if err := u.Table.serialize(statementType, out); err != nil {
return err
}
return nil
}
type ClauseSet struct {
Columns []IColumn
Values []Serializer
}
func (s *ClauseSet) Serialize(statementType StatementType, out *SqlBuilder) error {
out.newLine()
out.WriteString("SET")
if len(s.Columns) == 0 {
return errors.New("jet: no columns selected")
}
if len(s.Columns) > 1 {
out.WriteString("(")
}
err := SerializeColumnNames(s.Columns, out)
if err != nil {
return err
}
if len(s.Columns) > 1 {
out.WriteString(")")
}
out.WriteString("=")
if len(s.Values) > 1 {
out.WriteString("(")
}
err = SerializeClauseList(statementType, s.Values, out)
if err != nil {
return err
}
if len(s.Values) > 1 {
out.WriteString(")")
}
return nil
}
type ClauseReturning struct {
Projections []Projection
}
func (r *ClauseReturning) Serialize(statementType StatementType, out *SqlBuilder) error {
return out.WriteReturning(statementType, r.Projections)
}
type ClauseInsert struct {
Table SerializerTable
Columns []IColumn
}
func (i *ClauseInsert) Serialize(statementType StatementType, out *SqlBuilder) error {
out.newLine()
out.WriteString("INSERT INTO")
if utils.IsNil(i.Table) {
return errors.New("jet: table is nil")
}
err := i.Table.serialize(statementType, out)
if err != nil {
return err
}
if len(i.Columns) > 0 {
out.WriteString("(")
err = SerializeColumnNames(i.Columns, out)
if err != nil {
return err
}
out.WriteString(")")
}
return nil
}
type ClauseValues struct {
Rows [][]Serializer
}
func (v *ClauseValues) Serialize(statementType StatementType, out *SqlBuilder) error {
if len(v.Rows) == 0 {
return nil
}
out.WriteString("VALUES")
for rowIndex, row := range v.Rows {
if rowIndex > 0 {
out.WriteString(",")
}
out.increaseIdent()
out.newLine()
out.WriteString("(")
err := SerializeClauseList(statementType, row, out)
if err != nil {
return err
}
out.writeByte(')')
out.decreaseIdent()
}
return nil
}
type ClauseQuery struct {
Query SerializerStatement
}
func (v *ClauseQuery) Serialize(statementType StatementType, out *SqlBuilder) error {
if v.Query == nil {
return nil
}
return v.Query.serialize(statementType, out)
}
type ClauseDelete struct {
Table SerializerTable
}
func (d *ClauseDelete) Serialize(statementType StatementType, out *SqlBuilder) error {
out.newLine()
out.WriteString("DELETE FROM")
if d.Table == nil {
return errors.New("jet: nil tableName")
}
if err := d.Table.serialize(statementType, out); err != nil {
return err
}
return nil
}
type ClauseStatementBegin struct {
Name string
Tables []SerializerTable
}
func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SqlBuilder) error {
out.newLine()
out.WriteString(d.Name)
for i, table := range d.Tables {
if i > 0 {
out.WriteString(", ")
}
err := table.serialize(statementType, out)
if err != nil {
return err
}
}
return nil
}
type ClauseString struct {
Name string
Data string
}
func (d *ClauseString) Serialize(statementType StatementType, out *SqlBuilder) error {
out.newLine()
out.WriteString(d.Name)
out.WriteString(d.Data)
return nil
}
type ClauseOptional struct {
Name string
Show bool
}
func (d *ClauseOptional) Serialize(statementType StatementType, out *SqlBuilder) error {
if !d.Show {
return nil
}
//out.newLine()
out.WriteString(d.Name)
return nil
}
type ClauseIn struct {
LockMode string
}
func (i *ClauseIn) Serialize(statementType StatementType, out *SqlBuilder) error {
if i.LockMode == "" {
return nil
}
out.WriteString("IN")
out.WriteString(string(i.LockMode))
out.WriteString("MODE")
return nil
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable2(Dialect Dialect, schemaName, name string, columns ...Column) TableImpl2 {
t := TableImpl2{
Dialect: Dialect,
schemaName: schemaName,
name: name,
columnList: columns,
}
for _, c := range columns {
c.SetTableName(name)
}
return t
}
type TableImpl2 struct {
Dialect Dialect
schemaName string
name string
alias string
columnList []Column
}
func (t *TableImpl2) AS(alias string) {
t.alias = alias
for _, c := range t.columnList {
c.SetTableName(alias)
}
}
func (t *TableImpl2) SchemaName() string {
return t.schemaName
}
func (t *TableImpl2) TableName() string {
return t.name
}
func (t *TableImpl2) Columns() []IColumn {
ret := []IColumn{}
for _, col := range t.columnList {
ret = append(ret, col)
}
return ret
}
func (t *TableImpl2) dialect() Dialect {
return t.Dialect
}
func (t *TableImpl2) accept(visitor visitor) {
visitor.visit(t)
}
func (t *TableImpl2) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
if t == nil {
return errors.New("jet: tableImpl is nil. ")
}
out.writeIdentifier(t.schemaName)
out.WriteString(".")
out.writeIdentifier(t.name)
if len(t.alias) > 0 {
out.WriteString("AS")
out.writeIdentifier(t.alias)
}
return nil
}
// Join expressions are pseudo readable tables.
type JoinTableImpl struct {
lhs Serializer
rhs Serializer
joinType JoinType
onCondition BoolExpression
}
func NewJoinTableImpl(lhs Serializer, rhs Serializer, joinType JoinType, onCondition BoolExpression) JoinTableImpl {
joinTable := JoinTableImpl{
lhs: lhs,
rhs: rhs,
joinType: joinType,
onCondition: onCondition,
}
return joinTable
}
func (t *JoinTableImpl) SchemaName() string {
return ""
}
func (t *JoinTableImpl) TableName() string {
return ""
}
func (t *JoinTableImpl) columns() []IColumn {
//return append(t.lhs.columns(), t.rhs.columns()...)
panic("Unimplemented")
}
func (t *JoinTableImpl) accept(visitor visitor) {
//t.lhs.accept(visitor)
//t.rhs.accept(visitor)
//TODO: uncoment
}
func (t *JoinTableImpl) dialect() Dialect {
return detectDialect(t)
}
func (t *JoinTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) (err error) {
if t == nil {
return errors.New("jet: Join table is nil. ")
}
if utils.IsNil(t.lhs) {
return errors.New("jet: left hand side of join operation is nil table")
}
if err = t.lhs.serialize(statement, out); err != nil {
return
}
if !isPreSeparator(q.lastChar) && !isPostSeparator(data[0]) && q.Buff.Len() > 0 {
q.Buff.WriteByte(' ')
out.newLine()
switch t.joinType {
case InnerJoin:
out.WriteString("INNER JOIN")
case LeftJoin:
out.WriteString("LEFT JOIN")
case RightJoin:
out.WriteString("RIGHT JOIN")
case FullJoin:
out.WriteString("FULL JOIN")
case CrossJoin:
out.WriteString("CROSS JOIN")
}
q.Buff.Write(data)
q.lastChar = data[len(data)-1]
}
func isPreSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':'
}
func isPostSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
}
func (q *SqlBuilder) writeAlias(str string) {
aliasQuoteChar := string(q.Dialect.AliasQuoteChar())
q.WriteString(aliasQuoteChar + str + aliasQuoteChar)
}
func (q *SqlBuilder) WriteString(str string) {
q.write([]byte(str))
}
func (q *SqlBuilder) writeIdentifier(name string, alwaysQuote ...bool) {
quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -")
if quoteWrap || len(alwaysQuote) > 0 {
identQuoteChar := string(q.Dialect.IdentifierQuoteChar())
q.WriteString(identQuoteChar + name + identQuoteChar)
} else {
q.WriteString(name)
}
}
func (q *SqlBuilder) writeByte(b byte) {
q.write([]byte{b})
}
func (q *SqlBuilder) finalize() (string, []interface{}) {
return q.Buff.String() + ";\n", q.Args
}
func (q *SqlBuilder) insertConstantArgument(arg interface{}) {
q.WriteString(argToString(arg))
}
func (q *SqlBuilder) insertParametrizedArgument(arg interface{}) {
q.Args = append(q.Args, arg)
argPlaceholder := q.Dialect.ArgumentPlaceholder()(len(q.Args))
q.WriteString(argPlaceholder)
}
func argToString(value interface{}) string {
if utils.IsNil(value) {
return "NULL"
if utils.IsNil(t.rhs) {
return errors.New("jet: right hand side of join operation is nil table")
}
switch bindVal := value.(type) {
case bool:
if bindVal {
return "TRUE"
if err = t.rhs.serialize(statement, out); err != nil {
return
}
if t.onCondition == nil && t.joinType != CrossJoin {
return errors.New("jet: join condition is nil")
}
if t.onCondition != nil {
out.WriteString("ON")
if err = t.onCondition.serialize(statement, out); err != nil {
return
}
return "FALSE"
case int8:
return strconv.FormatInt(int64(bindVal), 10)
case int:
return strconv.FormatInt(int64(bindVal), 10)
case int16:
return strconv.FormatInt(int64(bindVal), 10)
case int32:
return strconv.FormatInt(int64(bindVal), 10)
case int64:
return strconv.FormatInt(int64(bindVal), 10)
case uint8:
return strconv.FormatUint(uint64(bindVal), 10)
case uint:
return strconv.FormatUint(uint64(bindVal), 10)
case uint16:
return strconv.FormatUint(uint64(bindVal), 10)
case uint32:
return strconv.FormatUint(uint64(bindVal), 10)
case uint64:
return strconv.FormatUint(uint64(bindVal), 10)
case float32:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case float64:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
case string:
return stringQuote(bindVal)
case []byte:
return stringQuote(string(bindVal))
case uuid.UUID:
return stringQuote(bindVal.String())
case time.Time:
return stringQuote(string(utils.FormatTimestamp(bindVal)))
default:
return "[Unsupported type]"
}
return nil
}
func stringQuote(value string) string {
return `'` + strings.Replace(value, "'", "''", -1) + `'`
// SelectTable is interface for SELECT sub-queries
type SelectTable interface {
Alias() string
AllColumns() ProjectionList
}
type SelectTableImpl2 struct {
selectStmt StatementWithProjections
alias string
projections []Projection
}
func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTableImpl2 {
selectTable := SelectTableImpl2{selectStmt: selectStmt, alias: alias}
for _, projection := range selectStmt.projections() {
newProjection := projection.fromImpl(&selectTable)
selectTable.projections = append(selectTable.projections, newProjection)
}
return selectTable
}
func (s *SelectTableImpl2) Alias() string {
return s.alias
}
func (s *SelectTableImpl2) columns() []IColumn {
return nil
}
func (s *SelectTableImpl2) accept(visitor visitor) {
visitor.visit(s)
s.selectStmt.accept(visitor)
}
func (s *SelectTableImpl2) dialect() Dialect {
return detectDialect(s.selectStmt)
}
func (s *SelectTableImpl2) AllColumns() ProjectionList {
return s.projections
}
func (s *SelectTableImpl2) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error {
if s == nil {
return errors.New("jet: Expression table is nil. ")
}
err := s.selectStmt.serialize(statement, out)
if err != nil {
return err
}
out.WriteString("AS")
out.writeIdentifier(s.alias)
return nil
}