From 8519ccbdd0c78bcda289011857d9f5df6d107194 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 11 Aug 2019 09:52:02 +0200 Subject: [PATCH] Postgres refactor. --- internal/jet/alias.go | 2 +- internal/jet/bool_expression.go | 12 +- internal/jet/cast.go | 4 +- internal/jet/clause.go | 896 +++++++++++++----- internal/jet/column.go | 4 +- internal/jet/column_test.go | 2 +- internal/jet/delete_statement.go | 2 +- internal/jet/dialects.go | 10 +- internal/jet/enum_value.go | 4 +- internal/jet/expression.go | 62 +- internal/jet/float_expression.go | 4 +- internal/jet/func_expression.go | 6 +- internal/jet/group_by_clause.go | 2 +- internal/jet/insert_statement.go | 12 +- internal/jet/insert_statement_test.go | 6 +- internal/jet/integer_expression.go | 12 +- internal/jet/literal_expression.go | 22 +- internal/jet/operators.go | 4 +- internal/jet/order_by_clause.go | 4 +- internal/jet/select_statement.go | 23 +- internal/jet/select_table.go | 138 ++- internal/jet/serializer.go | 36 + .../{clause_test.go => serializer_test.go} | 0 internal/jet/set_statement.go | 2 +- internal/jet/sql_builder.go | 235 +++++ internal/jet/statement.go | 125 ++- internal/jet/string_expression.go | 4 +- internal/jet/table.go | 74 +- internal/jet/testutils.go | 4 +- internal/jet/time_expression.go | 6 +- internal/jet/timestampz_expression.go | 4 +- internal/jet/timez_expression.go | 6 +- internal/jet/update_statement.go | 10 +- internal/jet/utils.go | 24 +- internal/jet/visitor.go | 2 +- mysql/utils_test.go | 4 +- postgres/delete_statement.go | 41 + postgres/delete_statement_test.go | 25 + postgres/dialect.go | 36 +- postgres/functions.go | 12 +- postgres/insert_statement.go | 73 ++ postgres/insert_statement_test.go | 148 +++ postgres/lock_statement.go | 54 +- postgres/lock_statement_test.go | 32 + postgres/select_statement.go | 144 +++ postgres/select_statement_test.go | 137 +++ postgres/set_statement.go | 147 +++ postgres/set_statement_test.go | 81 ++ postgres/statements.go | 42 - postgres/statements_test.go | 30 - postgres/table.go | 141 ++- postgres/update_statement.go | 55 ++ postgres/update_statement_test.go | 62 ++ postgres/utils_test.go | 4 +- tests/postgres/alltypes_test.go | 5 +- tests/postgres/lock_test.go | 2 +- tests/postgres/select_test.go | 11 +- 57 files changed, 2451 insertions(+), 598 deletions(-) create mode 100644 internal/jet/serializer.go rename internal/jet/{clause_test.go => serializer_test.go} (100%) create mode 100644 internal/jet/sql_builder.go create mode 100644 postgres/delete_statement.go create mode 100644 postgres/delete_statement_test.go create mode 100644 postgres/insert_statement.go create mode 100644 postgres/insert_statement_test.go create mode 100644 postgres/lock_statement_test.go create mode 100644 postgres/select_statement.go create mode 100644 postgres/select_statement_test.go create mode 100644 postgres/set_statement.go create mode 100644 postgres/set_statement_test.go delete mode 100644 postgres/statements.go delete mode 100644 postgres/statements_test.go create mode 100644 postgres/update_statement.go create mode 100644 postgres/update_statement_test.go diff --git a/internal/jet/alias.go b/internal/jet/alias.go index 4ec7948..108cffd 100644 --- a/internal/jet/alias.go +++ b/internal/jet/alias.go @@ -14,7 +14,7 @@ func newAlias(expression Expression, aliasName string) Projection { func (a *alias) fromImpl(subQuery SelectTable) Projection { column := newColumn(a.alias, "", nil) - column.parent = &column + column.Parent = &column column.subQuery = subQuery return &column diff --git a/internal/jet/bool_expression.go b/internal/jet/bool_expression.go index efbae3c..f4ca0a6 100644 --- a/internal/jet/bool_expression.go +++ b/internal/jet/bool_expression.go @@ -86,7 +86,7 @@ func (b *boolInterfaceImpl) IS_NOT_UNKNOWN() BoolExpression { //---------------------------------------------------// type binaryBoolExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl boolInterfaceImpl binaryOpExpression @@ -96,7 +96,7 @@ func newBinaryBoolOperator(lhs, rhs Expression, operator string) BoolExpression binaryBoolExpression := binaryBoolExpression{} binaryBoolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) - binaryBoolExpression.expressionInterfaceImpl.parent = &binaryBoolExpression + binaryBoolExpression.ExpressionInterfaceImpl.Parent = &binaryBoolExpression binaryBoolExpression.boolInterfaceImpl.parent = &binaryBoolExpression return &binaryBoolExpression @@ -104,7 +104,7 @@ func newBinaryBoolOperator(lhs, rhs Expression, operator string) BoolExpression //---------------------------------------------------// type prefixBoolExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl boolInterfaceImpl prefixOpExpression @@ -114,7 +114,7 @@ func newPrefixBoolOperator(expression Expression, operator string) BoolExpressio exp := prefixBoolExpression{} exp.prefixOpExpression = newPrefixExpression(expression, operator) - exp.expressionInterfaceImpl.parent = &exp + exp.ExpressionInterfaceImpl.Parent = &exp exp.boolInterfaceImpl.parent = &exp return &exp @@ -122,7 +122,7 @@ func newPrefixBoolOperator(expression Expression, operator string) BoolExpressio //---------------------------------------------------// type postfixBoolOpExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl boolInterfaceImpl postfixOpExpression @@ -132,7 +132,7 @@ func newPostifxBoolExpression(expression Expression, operator string) BoolExpres exp := postfixBoolOpExpression{} exp.postfixOpExpression = newPostfixOpExpression(expression, operator) - exp.expressionInterfaceImpl.parent = &exp + exp.ExpressionInterfaceImpl.Parent = &exp exp.boolInterfaceImpl.parent = &exp return &exp diff --git a/internal/jet/cast.go b/internal/jet/cast.go index 9e7f277..886257e 100644 --- a/internal/jet/cast.go +++ b/internal/jet/cast.go @@ -32,7 +32,7 @@ func (b *CastImpl) AS(castType string) Expression { cast: string(castType), } - castExp.expressionInterfaceImpl.parent = castExp + castExp.ExpressionInterfaceImpl.Parent = castExp return castExp } @@ -61,7 +61,7 @@ func (b *CastImpl) AS_TIME() TimeExpression { } type castExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl expression Expression cast string diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 6e3388a..b6c4e37 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -1,270 +1,702 @@ package jet import ( - "bytes" + "errors" "github.com/go-jet/jet/internal/utils" - "github.com/google/uuid" - "strconv" - "strings" - "time" -) - -type SerializeOption int - -const ( - noWrap SerializeOption = iota ) type Clause interface { - serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error + Serialize(statementType StatementType, out *SqlBuilder) error } -func Serialize(exp Clause, statementType StatementType, out *SqlBuilder, options ...SerializeOption) error { - return exp.serialize(statementType, out, options...) +type ClauseWithProjections interface { + Clause + + projections() []Projection } -func contains(options []SerializeOption, option SerializeOption) bool { - for _, opt := range options { - if opt == option { - return true +type ClauseSelect struct { + Distinct bool + Projections []Projection +} + +func (s *ClauseSelect) projections() []Projection { + return s.Projections +} + +func (s *ClauseSelect) Serialize(statementType StatementType, out *SqlBuilder) error { + out.newLine() + out.WriteString("SELECT") + + if s.Distinct { + out.WriteString("DISTINCT") + } + + if len(s.Projections) == 0 { + return errors.New("jet: no column selected for Projection") + } + + return out.writeProjections(statementType, s.Projections) +} + +type ClauseFrom struct { + Table Serializer +} + +func (f *ClauseFrom) Serialize(statementType StatementType, out *SqlBuilder) error { + if f.Table == nil { + return nil + } + return out.writeFrom(statementType, f.Table) +} + +type ClauseWhere struct { + Condition BoolExpression + Mandatory bool +} + +func (c *ClauseWhere) Serialize(statementType StatementType, out *SqlBuilder) error { + if c.Condition == nil { + if c.Mandatory { + return errors.New("jet: WHERE clause not set") } + return nil } - - return false + return out.writeWhere(statementType, c.Condition) } -type SqlBuilder struct { - Dialect Dialect - Buff bytes.Buffer - Args []interface{} - - lastChar byte - ident int +type ClauseGroupBy struct { + List []GroupByClause } -func (s *SqlBuilder) DebugSQL() string { - return queryStringToDebugString(s.Buff.String(), s.Args, s.Dialect) -} - -type StatementType string - -const ( - SelectStatementType StatementType = "SELECT" - InsertStatementType StatementType = "INSERT" - UpdateStatementType StatementType = "UPDATE" - DeleteStatementType StatementType = "DELETE" - SetStatementType StatementType = "SET" - LockStatementType StatementType = "LOCK" -) - -const defaultIdent = 5 - -func (q *SqlBuilder) increaseIdent() { - q.ident += defaultIdent -} - -func (q *SqlBuilder) decreaseIdent() { - if q.ident < defaultIdent { - q.ident = 0 - } - - q.ident -= defaultIdent -} - -func (q *SqlBuilder) writeProjections(statement StatementType, projections []Projection) error { - q.increaseIdent() - err := SerializeProjectionList(statement, projections, q) - q.decreaseIdent() - return err -} - -func (q *SqlBuilder) writeFrom(statement StatementType, table ReadableTable) error { - q.newLine() - q.WriteString("FROM") - - q.increaseIdent() - err := table.serialize(statement, q) - q.decreaseIdent() - - return err -} - -func (q *SqlBuilder) writeWhere(statement StatementType, where Expression) error { - q.newLine() - q.WriteString("WHERE") - - q.increaseIdent() - err := where.serialize(statement, q, noWrap) - q.decreaseIdent() - - return err -} - -func (q *SqlBuilder) writeGroupBy(statement StatementType, groupBy []groupByClause) error { - q.newLine() - q.WriteString("GROUP BY") - - q.increaseIdent() - err := serializeGroupByClauseList(statement, groupBy, q) - q.decreaseIdent() - - return err -} - -func (q *SqlBuilder) writeOrderBy(statement StatementType, orderBy []orderByClause) error { - q.newLine() - q.WriteString("ORDER BY") - - q.increaseIdent() - err := serializeOrderByClauseList(statement, orderBy, q) - q.decreaseIdent() - - return err -} - -func (q *SqlBuilder) writeHaving(statement StatementType, having Expression) error { - q.newLine() - q.WriteString("HAVING") - - q.increaseIdent() - err := having.serialize(statement, q, noWrap) - q.decreaseIdent() - - return err -} - -func (q *SqlBuilder) writeReturning(statement StatementType, returning []Projection) error { - if len(returning) == 0 { +func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SqlBuilder) error { + if len(c.List) == 0 { return nil } - if !q.Dialect.SupportsReturning() { - panic("jet: " + q.Dialect.Name() + " dialect does not support RETURNING.") + out.newLine() + out.WriteString("GROUP BY") + + out.increaseIdent() + err := serializeGroupByClauseList(statementType, c.List, out) + out.decreaseIdent() + + return err +} + +type ClauseHaving struct { + Condition BoolExpression +} + +func (c *ClauseHaving) Serialize(statementType StatementType, out *SqlBuilder) error { + if c.Condition == nil { + return nil } - q.newLine() - q.WriteString("RETURNING") - q.increaseIdent() - - return q.writeProjections(statement, returning) + return out.writeHaving(statementType, c.Condition) } -func (q *SqlBuilder) newLine() { - q.write([]byte{'\n'}) - q.write(bytes.Repeat([]byte{' '}, q.ident)) +type ClauseOrderBy struct { + List []OrderByClause } -func (q *SqlBuilder) write(data []byte) { - if len(data) == 0 { +func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SqlBuilder) error { + if o.List == nil { + return nil + } + + return out.writeOrderBy(statementType, o.List) +} + +type ClauseLimit struct { + Count int64 +} + +func (l *ClauseLimit) Serialize(statementType StatementType, out *SqlBuilder) error { + if l.Count >= 0 { + out.newLine() + out.WriteString("LIMIT") + out.insertParametrizedArgument(l.Count) + } + + return nil +} + +type ClauseOffset struct { + Count int64 +} + +func (o *ClauseOffset) Serialize(statementType StatementType, out *SqlBuilder) error { + if o.Count >= 0 { + out.newLine() + out.WriteString("OFFSET") + out.insertParametrizedArgument(o.Count) + } + + return nil +} + +type ClauseFor struct { + Lock SelectLock +} + +func (f *ClauseFor) Serialize(statementType StatementType, out *SqlBuilder) error { + if f.Lock == nil { + return nil + } + + out.newLine() + out.WriteString("FOR") + return f.Lock.serialize(statementType, out) +} + +type ClauseSetStmtOperator struct { + Operator string + All bool + Selects []StatementWithProjections + OrderBy ClauseOrderBy + Limit ClauseLimit + Offset ClauseOffset +} + +func (s *ClauseSetStmtOperator) projections() []Projection { + if len(s.Selects) > 0 { + return s.Selects[0].projections() + } + return nil +} + +func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlBuilder) error { + if len(s.Selects) < 2 { + return errors.New("jet: UNION Statement must have at least two SELECT statements") + } + + wrap := s.OrderBy.List != nil || s.Limit.Count >= 0 || s.Offset.Count >= 0 + + //if wrap { + // out.WriteString("(") + // out.increaseIdent() + //} + + if wrap { + out.newLine() + out.WriteString("(") + out.increaseIdent() + } + + for i, selectStmt := range s.Selects { + out.newLine() + if i > 0 { + out.WriteString(s.Operator) + + if s.All { + out.WriteString("ALL") + } + out.newLine() + } + + if selectStmt == nil { + return errors.New("jet: select statement is nil") + } + + err := selectStmt.serialize(statementType, out) + + if err != nil { + return err + } + } + + if wrap { + out.decreaseIdent() + out.newLine() + out.WriteString(")") + } + + if err := s.OrderBy.Serialize(statementType, out); err != nil { + return err + } + + if err := s.Limit.Serialize(statementType, out); err != nil { + return err + } + + if err := s.Offset.Serialize(statementType, out); err != nil { + return err + } + + //if wrap { + // out.decreaseIdent() + // out.newLine() + // out.WriteString(")") + //} + + return nil +} + +type ClauseUpdate struct { + Table SerializerTable +} + +func (u *ClauseUpdate) Serialize(statementType StatementType, out *SqlBuilder) error { + out.newLine() + out.WriteString("UPDATE") + + if utils.IsNil(u.Table) { + return errors.New("jet: table to update is nil") + } + + if err := u.Table.serialize(statementType, out); err != nil { + return err + } + + return nil +} + +type ClauseSet struct { + Columns []IColumn + Values []Serializer +} + +func (s *ClauseSet) Serialize(statementType StatementType, out *SqlBuilder) error { + out.newLine() + out.WriteString("SET") + + if len(s.Columns) == 0 { + return errors.New("jet: no columns selected") + } + + if len(s.Columns) > 1 { + out.WriteString("(") + } + + err := SerializeColumnNames(s.Columns, out) + + if err != nil { + return err + } + + if len(s.Columns) > 1 { + out.WriteString(")") + } + + out.WriteString("=") + + if len(s.Values) > 1 { + out.WriteString("(") + } + + err = SerializeClauseList(statementType, s.Values, out) + + if err != nil { + return err + } + + if len(s.Values) > 1 { + out.WriteString(")") + } + + return nil +} + +type ClauseReturning struct { + Projections []Projection +} + +func (r *ClauseReturning) Serialize(statementType StatementType, out *SqlBuilder) error { + return out.WriteReturning(statementType, r.Projections) +} + +type ClauseInsert struct { + Table SerializerTable + Columns []IColumn +} + +func (i *ClauseInsert) Serialize(statementType StatementType, out *SqlBuilder) error { + out.newLine() + out.WriteString("INSERT INTO") + + if utils.IsNil(i.Table) { + return errors.New("jet: table is nil") + } + + err := i.Table.serialize(statementType, out) + + if err != nil { + return err + } + + if len(i.Columns) > 0 { + out.WriteString("(") + + err = SerializeColumnNames(i.Columns, out) + + if err != nil { + return err + } + + out.WriteString(")") + } + + return nil +} + +type ClauseValues struct { + Rows [][]Serializer +} + +func (v *ClauseValues) Serialize(statementType StatementType, out *SqlBuilder) error { + if len(v.Rows) == 0 { + return nil + } + + out.WriteString("VALUES") + + for rowIndex, row := range v.Rows { + if rowIndex > 0 { + out.WriteString(",") + } + + out.increaseIdent() + out.newLine() + out.WriteString("(") + + err := SerializeClauseList(statementType, row, out) + + if err != nil { + return err + } + + out.writeByte(')') + out.decreaseIdent() + } + return nil +} + +type ClauseQuery struct { + Query SerializerStatement +} + +func (v *ClauseQuery) Serialize(statementType StatementType, out *SqlBuilder) error { + if v.Query == nil { + return nil + } + + return v.Query.serialize(statementType, out) +} + +type ClauseDelete struct { + Table SerializerTable +} + +func (d *ClauseDelete) Serialize(statementType StatementType, out *SqlBuilder) error { + out.newLine() + out.WriteString("DELETE FROM") + + if d.Table == nil { + return errors.New("jet: nil tableName") + } + + if err := d.Table.serialize(statementType, out); err != nil { + return err + } + + return nil +} + +type ClauseStatementBegin struct { + Name string + Tables []SerializerTable +} + +func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SqlBuilder) error { + out.newLine() + out.WriteString(d.Name) + + for i, table := range d.Tables { + if i > 0 { + out.WriteString(", ") + } + + err := table.serialize(statementType, out) + + if err != nil { + return err + } + } + + return nil +} + +type ClauseString struct { + Name string + Data string +} + +func (d *ClauseString) Serialize(statementType StatementType, out *SqlBuilder) error { + out.newLine() + out.WriteString(d.Name) + out.WriteString(d.Data) + return nil +} + +type ClauseOptional struct { + Name string + Show bool +} + +func (d *ClauseOptional) Serialize(statementType StatementType, out *SqlBuilder) error { + if !d.Show { + return nil + } + //out.newLine() + out.WriteString(d.Name) + return nil +} + +type ClauseIn struct { + LockMode string +} + +func (i *ClauseIn) Serialize(statementType StatementType, out *SqlBuilder) error { + if i.LockMode == "" { + return nil + } + + out.WriteString("IN") + out.WriteString(string(i.LockMode)) + out.WriteString("MODE") + + return nil +} + +// NewTable creates new table with schema Name, table Name and list of columns +func NewTable2(Dialect Dialect, schemaName, name string, columns ...Column) TableImpl2 { + + t := TableImpl2{ + Dialect: Dialect, + schemaName: schemaName, + name: name, + columnList: columns, + } + + for _, c := range columns { + c.SetTableName(name) + } + + return t +} + +type TableImpl2 struct { + Dialect Dialect + schemaName string + name string + alias string + columnList []Column +} + +func (t *TableImpl2) AS(alias string) { + t.alias = alias + + for _, c := range t.columnList { + c.SetTableName(alias) + } +} + +func (t *TableImpl2) SchemaName() string { + return t.schemaName +} + +func (t *TableImpl2) TableName() string { + return t.name +} + +func (t *TableImpl2) Columns() []IColumn { + ret := []IColumn{} + + for _, col := range t.columnList { + ret = append(ret, col) + } + + return ret +} + +func (t *TableImpl2) dialect() Dialect { + return t.Dialect +} + +func (t *TableImpl2) accept(visitor visitor) { + visitor.visit(t) +} + +func (t *TableImpl2) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { + if t == nil { + return errors.New("jet: tableImpl is nil. ") + } + + out.writeIdentifier(t.schemaName) + out.WriteString(".") + out.writeIdentifier(t.name) + + if len(t.alias) > 0 { + out.WriteString("AS") + out.writeIdentifier(t.alias) + } + + return nil +} + +// Join expressions are pseudo readable tables. +type JoinTableImpl struct { + lhs Serializer + rhs Serializer + joinType JoinType + onCondition BoolExpression +} + +func NewJoinTableImpl(lhs Serializer, rhs Serializer, joinType JoinType, onCondition BoolExpression) JoinTableImpl { + + joinTable := JoinTableImpl{ + lhs: lhs, + rhs: rhs, + joinType: joinType, + onCondition: onCondition, + } + + return joinTable +} + +func (t *JoinTableImpl) SchemaName() string { + return "" +} + +func (t *JoinTableImpl) TableName() string { + return "" +} + +func (t *JoinTableImpl) columns() []IColumn { + //return append(t.lhs.columns(), t.rhs.columns()...) + panic("Unimplemented") +} + +func (t *JoinTableImpl) accept(visitor visitor) { + //t.lhs.accept(visitor) + //t.rhs.accept(visitor) + //TODO: uncoment +} + +func (t *JoinTableImpl) dialect() Dialect { + return detectDialect(t) +} + +func (t *JoinTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) (err error) { + if t == nil { + return errors.New("jet: Join table is nil. ") + } + + if utils.IsNil(t.lhs) { + return errors.New("jet: left hand side of join operation is nil table") + } + + if err = t.lhs.serialize(statement, out); err != nil { return } - if !isPreSeparator(q.lastChar) && !isPostSeparator(data[0]) && q.Buff.Len() > 0 { - q.Buff.WriteByte(' ') + out.newLine() + + switch t.joinType { + case InnerJoin: + out.WriteString("INNER JOIN") + case LeftJoin: + out.WriteString("LEFT JOIN") + case RightJoin: + out.WriteString("RIGHT JOIN") + case FullJoin: + out.WriteString("FULL JOIN") + case CrossJoin: + out.WriteString("CROSS JOIN") } - q.Buff.Write(data) - q.lastChar = data[len(data)-1] -} - -func isPreSeparator(b byte) bool { - return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':' -} - -func isPostSeparator(b byte) bool { - return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':' -} - -func (q *SqlBuilder) writeAlias(str string) { - aliasQuoteChar := string(q.Dialect.AliasQuoteChar()) - q.WriteString(aliasQuoteChar + str + aliasQuoteChar) -} - -func (q *SqlBuilder) WriteString(str string) { - q.write([]byte(str)) -} - -func (q *SqlBuilder) writeIdentifier(name string, alwaysQuote ...bool) { - quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -") - - if quoteWrap || len(alwaysQuote) > 0 { - identQuoteChar := string(q.Dialect.IdentifierQuoteChar()) - q.WriteString(identQuoteChar + name + identQuoteChar) - } else { - q.WriteString(name) - } -} - -func (q *SqlBuilder) writeByte(b byte) { - q.write([]byte{b}) -} - -func (q *SqlBuilder) finalize() (string, []interface{}) { - return q.Buff.String() + ";\n", q.Args -} - -func (q *SqlBuilder) insertConstantArgument(arg interface{}) { - q.WriteString(argToString(arg)) -} - -func (q *SqlBuilder) insertParametrizedArgument(arg interface{}) { - q.Args = append(q.Args, arg) - argPlaceholder := q.Dialect.ArgumentPlaceholder()(len(q.Args)) - - q.WriteString(argPlaceholder) -} - -func argToString(value interface{}) string { - if utils.IsNil(value) { - return "NULL" + if utils.IsNil(t.rhs) { + return errors.New("jet: right hand side of join operation is nil table") } - switch bindVal := value.(type) { - case bool: - if bindVal { - return "TRUE" + if err = t.rhs.serialize(statement, out); err != nil { + return + } + + if t.onCondition == nil && t.joinType != CrossJoin { + return errors.New("jet: join condition is nil") + } + + if t.onCondition != nil { + out.WriteString("ON") + if err = t.onCondition.serialize(statement, out); err != nil { + return } - return "FALSE" - case int8: - return strconv.FormatInt(int64(bindVal), 10) - case int: - return strconv.FormatInt(int64(bindVal), 10) - case int16: - return strconv.FormatInt(int64(bindVal), 10) - case int32: - return strconv.FormatInt(int64(bindVal), 10) - case int64: - return strconv.FormatInt(int64(bindVal), 10) - - case uint8: - return strconv.FormatUint(uint64(bindVal), 10) - case uint: - return strconv.FormatUint(uint64(bindVal), 10) - case uint16: - return strconv.FormatUint(uint64(bindVal), 10) - case uint32: - return strconv.FormatUint(uint64(bindVal), 10) - case uint64: - return strconv.FormatUint(uint64(bindVal), 10) - - case float32: - return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) - case float64: - return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) - - case string: - return stringQuote(bindVal) - case []byte: - return stringQuote(string(bindVal)) - case uuid.UUID: - return stringQuote(bindVal.String()) - case time.Time: - return stringQuote(string(utils.FormatTimestamp(bindVal))) - default: - return "[Unsupported type]" } + + return nil } -func stringQuote(value string) string { - return `'` + strings.Replace(value, "'", "''", -1) + `'` +// SelectTable is interface for SELECT sub-queries +type SelectTable interface { + Alias() string + AllColumns() ProjectionList +} + +type SelectTableImpl2 struct { + selectStmt StatementWithProjections + alias string + + projections []Projection +} + +func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTableImpl2 { + selectTable := SelectTableImpl2{selectStmt: selectStmt, alias: alias} + + for _, projection := range selectStmt.projections() { + newProjection := projection.fromImpl(&selectTable) + + selectTable.projections = append(selectTable.projections, newProjection) + } + + return selectTable +} + +func (s *SelectTableImpl2) Alias() string { + return s.alias +} + +func (s *SelectTableImpl2) columns() []IColumn { + return nil +} + +func (s *SelectTableImpl2) accept(visitor visitor) { + visitor.visit(s) + s.selectStmt.accept(visitor) +} + +func (s *SelectTableImpl2) dialect() Dialect { + return detectDialect(s.selectStmt) +} + +func (s *SelectTableImpl2) AllColumns() ProjectionList { + return s.projections +} + +func (s *SelectTableImpl2) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { + if s == nil { + return errors.New("jet: Expression table is nil. ") + } + + err := s.selectStmt.serialize(statement, out) + + if err != nil { + return err + } + + out.WriteString("AS") + out.writeIdentifier(s.alias) + + return nil } diff --git a/internal/jet/column.go b/internal/jet/column.go index d634132..0783d5a 100644 --- a/internal/jet/column.go +++ b/internal/jet/column.go @@ -19,7 +19,7 @@ type Column interface { // The base type for real materialized columns. type columnImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl noOpVisitorImpl name string @@ -34,7 +34,7 @@ func newColumn(name string, tableName string, parent Column) columnImpl { tableName: tableName, } - bc.expressionInterfaceImpl.parent = parent + bc.ExpressionInterfaceImpl.Parent = parent return bc } diff --git a/internal/jet/column_test.go b/internal/jet/column_test.go index c6ded56..2c766d9 100644 --- a/internal/jet/column_test.go +++ b/internal/jet/column_test.go @@ -4,7 +4,7 @@ import "testing" func TestColumn(t *testing.T) { column := newColumn("col", "", nil) - column.expressionInterfaceImpl.parent = &column + column.ExpressionInterfaceImpl.Parent = &column assertClauseSerialize(t, column, "col") column.SetTableName("table1") diff --git a/internal/jet/delete_statement.go b/internal/jet/delete_statement.go index 1efb6ed..c2108a9 100644 --- a/internal/jet/delete_statement.go +++ b/internal/jet/delete_statement.go @@ -67,7 +67,7 @@ func (d *deleteStatementImpl) serializeImpl(out *SqlBuilder) error { return err } - if err := out.writeReturning(DeleteStatementType, d.returning); err != nil { + if err := out.WriteReturning(DeleteStatementType, d.returning); err != nil { return err } diff --git a/internal/jet/dialects.go b/internal/jet/dialects.go index fcdbd7a..56a2c4b 100644 --- a/internal/jet/dialects.go +++ b/internal/jet/dialects.go @@ -21,7 +21,7 @@ type Dialect interface { AliasQuoteChar() byte IdentifierQuoteChar() byte ArgumentPlaceholder() QueryPlaceholderFunc - SetClause() func(columns []IColumn, values []Clause, out *SqlBuilder) (err error) + SetClause() func(columns []IColumn, values []Serializer, out *SqlBuilder) (err error) SupportsReturning() bool } @@ -29,7 +29,7 @@ type SerializeFunc func(statement StatementType, out *SqlBuilder, options ...Ser type SerializeOverride func(expressions ...Expression) SerializeFunc type QueryPlaceholderFunc func(ord int) string -type UpdateAssigmentFunc func(columns []IColumn, values []Clause, out *SqlBuilder) (err error) +type UpdateAssigmentFunc func(columns []IColumn, values []Serializer, out *SqlBuilder) (err error) type DialectParams struct { Name string @@ -38,7 +38,7 @@ type DialectParams struct { AliasQuoteChar byte IdentifierQuoteChar byte ArgumentPlaceholder QueryPlaceholderFunc - SetClause func(columns []IColumn, values []Clause, out *SqlBuilder) (err error) + SetClause func(columns []IColumn, values []Serializer, out *SqlBuilder) (err error) SupportsReturning bool } @@ -92,7 +92,7 @@ func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc { return d.argumentPlaceholder } -func (d *dialectImpl) SetClause() func(columns []IColumn, values []Clause, out *SqlBuilder) (err error) { +func (d *dialectImpl) SetClause() func(columns []IColumn, values []Serializer, out *SqlBuilder) (err error) { if d.setClause != nil { return d.setClause } @@ -103,7 +103,7 @@ func (d *dialectImpl) SupportsReturning() bool { return d.supportsReturning } -func setClause(columns []IColumn, values []Clause, out *SqlBuilder) (err error) { +func setClause(columns []IColumn, values []Serializer, out *SqlBuilder) (err error) { if len(columns) != len(values) { return errors.New("jet: mismatch in numers of columns and values") diff --git a/internal/jet/enum_value.go b/internal/jet/enum_value.go index ba598b1..7480b6a 100644 --- a/internal/jet/enum_value.go +++ b/internal/jet/enum_value.go @@ -1,7 +1,7 @@ package jet type enumValue struct { - expressionInterfaceImpl + ExpressionInterfaceImpl stringInterfaceImpl noOpVisitorImpl @@ -12,7 +12,7 @@ type enumValue struct { func NewEnumValue(name string) StringExpression { enumValue := &enumValue{name: name} - enumValue.expressionInterfaceImpl.parent = enumValue + enumValue.ExpressionInterfaceImpl.Parent = enumValue enumValue.stringInterfaceImpl.parent = enumValue return enumValue diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 791d3d9..84c0936 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -9,14 +9,14 @@ import ( type Expression interface { acceptsVisitor - expression + IExpression } -type expression interface { - Clause +type IExpression interface { + Serializer Projection - groupByClause - orderByClause + GroupByClause + OrderByClause // Test expression whether it is a NULL value. IS_NULL() BoolExpression @@ -32,57 +32,57 @@ type expression interface { AS(alias string) Projection // Expression will be used to sort query result in ascending order - ASC() orderByClause + ASC() OrderByClause // Expression will be used to sort query result in ascending order - DESC() orderByClause + DESC() OrderByClause } -type expressionInterfaceImpl struct { - parent Expression +type ExpressionInterfaceImpl struct { + Parent Expression } -func (e *expressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection { - return e.parent +func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection { + return e.Parent } -func (e *expressionInterfaceImpl) IS_NULL() BoolExpression { - return newPostifxBoolExpression(e.parent, "IS NULL") +func (e *ExpressionInterfaceImpl) IS_NULL() BoolExpression { + return newPostifxBoolExpression(e.Parent, "IS NULL") } -func (e *expressionInterfaceImpl) IS_NOT_NULL() BoolExpression { - return newPostifxBoolExpression(e.parent, "IS NOT NULL") +func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression { + return newPostifxBoolExpression(e.Parent, "IS NOT NULL") } -func (e *expressionInterfaceImpl) IN(expressions ...Expression) BoolExpression { - return newBinaryBoolOperator(e.parent, WRAP(expressions...), "IN") +func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression { + return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "IN") } -func (e *expressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression { - return newBinaryBoolOperator(e.parent, WRAP(expressions...), "NOT IN") +func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression { + return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "NOT IN") } -func (e *expressionInterfaceImpl) AS(alias string) Projection { - return newAlias(e.parent, alias) +func (e *ExpressionInterfaceImpl) AS(alias string) Projection { + return newAlias(e.Parent, alias) } -func (e *expressionInterfaceImpl) ASC() orderByClause { - return newOrderByClause(e.parent, true) +func (e *ExpressionInterfaceImpl) ASC() OrderByClause { + return newOrderByClause(e.Parent, true) } -func (e *expressionInterfaceImpl) DESC() orderByClause { - return newOrderByClause(e.parent, false) +func (e *ExpressionInterfaceImpl) DESC() OrderByClause { + return newOrderByClause(e.Parent, false) } -func (e *expressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SqlBuilder) error { - return e.parent.serialize(statement, out, noWrap) +func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SqlBuilder) error { + return e.Parent.serialize(statement, out, noWrap) } -func (e *expressionInterfaceImpl) serializeForProjection(statement StatementType, out *SqlBuilder) error { - return e.parent.serialize(statement, out, noWrap) +func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType, out *SqlBuilder) error { + return e.Parent.serialize(statement, out, noWrap) } -func (e *expressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SqlBuilder) error { - return e.parent.serialize(statement, out, noWrap) +func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SqlBuilder) error { + return e.Parent.serialize(statement, out, noWrap) } // Representation of binary operations (e.g. comparisons, arithmetic) diff --git a/internal/jet/float_expression.go b/internal/jet/float_expression.go index f61cdc0..bc69a62 100644 --- a/internal/jet/float_expression.go +++ b/internal/jet/float_expression.go @@ -86,7 +86,7 @@ func (n *floatInterfaceImpl) POW(expression NumericExpression) FloatExpression { //---------------------------------------------------// type binaryFloatExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl floatInterfaceImpl binaryOpExpression @@ -97,7 +97,7 @@ func newBinaryFloatExpression(lhs, rhs Expression, operator string) FloatExpress floatExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) - floatExpression.expressionInterfaceImpl.parent = &floatExpression + floatExpression.ExpressionInterfaceImpl.Parent = &floatExpression floatExpression.floatInterfaceImpl.parent = &floatExpression return &floatExpression diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 1a94dbf..789b3ca 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -478,7 +478,7 @@ func LEAST(value Expression, values ...Expression) Expression { //--------------------------------------------------------------------// type funcExpressionImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl name string expressions []Expression @@ -492,9 +492,9 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr } if parent != nil { - funcExp.expressionInterfaceImpl.parent = parent + funcExp.ExpressionInterfaceImpl.Parent = parent } else { - funcExp.expressionInterfaceImpl.parent = funcExp + funcExp.ExpressionInterfaceImpl.Parent = funcExp } return funcExp diff --git a/internal/jet/group_by_clause.go b/internal/jet/group_by_clause.go index a8602b8..7a629f6 100644 --- a/internal/jet/group_by_clause.go +++ b/internal/jet/group_by_clause.go @@ -1,5 +1,5 @@ package jet -type groupByClause interface { +type GroupByClause interface { serializeForGroupBy(statement StatementType, out *SqlBuilder) error } diff --git a/internal/jet/insert_statement.go b/internal/jet/insert_statement.go index 8dae633..0ff6db8 100644 --- a/internal/jet/insert_statement.go +++ b/internal/jet/insert_statement.go @@ -35,23 +35,23 @@ func newInsertStatement(t WritableTable, columns []IColumn) InsertStatement { type insertStatementImpl struct { table WritableTable columns []IColumn - rows [][]Clause + rows [][]Serializer query SelectStatement returning []Projection } func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { - i.rows = append(i.rows, unwindRowFromValues(value, values)) + i.rows = append(i.rows, UnwindRowFromValues(value, values)) return i } func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement { - i.rows = append(i.rows, unwindRowFromModel(i.getColumns(), data)) + i.rows = append(i.rows, UnwindRowFromModel(i.getColumns(), data)) return i } func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement { - i.rows = append(i.rows, unwindRowsFromModels(i.getColumns(), data)...) + i.rows = append(i.rows, UnwindRowsFromModels(i.getColumns(), data)...) return i } @@ -113,6 +113,8 @@ func (i *insertStatementImpl) Sql(dialect ...Dialect) (query string, args []inte out.WriteString(")") } + //TODO: + if len(i.rows) == 0 && i.query == nil { return "", nil, errors.New("jet: no row values or query specified") } @@ -152,7 +154,7 @@ func (i *insertStatementImpl) Sql(dialect ...Dialect) (query string, args []inte } } - if err = out.writeReturning(InsertStatementType, i.returning); err != nil { + if err = out.WriteReturning(InsertStatementType, i.returning); err != nil { return } diff --git a/internal/jet/insert_statement_test.go b/internal/jet/insert_statement_test.go index 414d323..95679f5 100644 --- a/internal/jet/insert_statement_test.go +++ b/internal/jet/insert_statement_test.go @@ -82,9 +82,9 @@ func TestInsertValuesFromModel(t *testing.T) { MODEL(&toInsert) expectedSQL := ` -INSERT INTO db.table1 (col1, col_float) VALUES - ($1, $2), - ($3, $4); +INSERT INTO db.table1 (col1, col_float) +VALUES ($1, $2), + ($3, $4); ` assertStatement(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) diff --git a/internal/jet/integer_expression.go b/internal/jet/integer_expression.go index 7018ece..3394c16 100644 --- a/internal/jet/integer_expression.go +++ b/internal/jet/integer_expression.go @@ -131,7 +131,7 @@ func (i *integerInterfaceImpl) BIT_SHIFT_RIGHT(intExpression IntegerExpression) //---------------------------------------------------// type binaryIntegerExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl integerInterfaceImpl binaryOpExpression @@ -140,7 +140,7 @@ type binaryIntegerExpression struct { func newBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression { integerExpression := binaryIntegerExpression{} - integerExpression.expressionInterfaceImpl.parent = &integerExpression + integerExpression.ExpressionInterfaceImpl.Parent = &integerExpression integerExpression.integerInterfaceImpl.parent = &integerExpression integerExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) @@ -150,7 +150,7 @@ func newBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) Int //---------------------------------------------------// type prefixIntegerOpExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl integerInterfaceImpl prefixOpExpression @@ -160,7 +160,7 @@ func newPrefixIntegerOperator(expression IntegerExpression, operator string) Int integerExpression := prefixIntegerOpExpression{} integerExpression.prefixOpExpression = newPrefixExpression(expression, operator) - integerExpression.expressionInterfaceImpl.parent = &integerExpression + integerExpression.ExpressionInterfaceImpl.Parent = &integerExpression integerExpression.integerInterfaceImpl.parent = &integerExpression return &integerExpression @@ -168,7 +168,7 @@ func newPrefixIntegerOperator(expression IntegerExpression, operator string) Int //---------------------------------------------------// type prefixFloatOpExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl floatInterfaceImpl prefixOpExpression @@ -178,7 +178,7 @@ func newPrefixFloatOperator(expression FloatExpression, operator string) FloatEx floatOpExpression := prefixFloatOpExpression{} floatOpExpression.prefixOpExpression = newPrefixExpression(expression, operator) - floatOpExpression.expressionInterfaceImpl.parent = &floatOpExpression + floatOpExpression.ExpressionInterfaceImpl.Parent = &floatOpExpression floatOpExpression.floatInterfaceImpl.parent = &floatOpExpression return &floatOpExpression diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index fba2644..412e017 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -15,7 +15,7 @@ type LiteralExpression interface { } type literalExpressionImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl noOpVisitorImpl value interface{} @@ -29,7 +29,7 @@ func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl exp.constant = optionalConstant[0] } - exp.expressionInterfaceImpl.parent = &exp + exp.ExpressionInterfaceImpl.Parent = &exp return &exp } @@ -73,7 +73,7 @@ func Int(value int64, constant ...bool) IntegerExpression { numLiteral.constant = true } - numLiteral.literalExpressionImpl.parent = numLiteral + numLiteral.literalExpressionImpl.Parent = numLiteral numLiteral.integerInterfaceImpl.parent = numLiteral return numLiteral @@ -204,14 +204,14 @@ func DateT(t time.Time) DateExpression { //--------------------------------------------------// type nullLiteral struct { - expressionInterfaceImpl + ExpressionInterfaceImpl noOpVisitorImpl } func newNullLiteral() Expression { nullExpression := &nullLiteral{} - nullExpression.expressionInterfaceImpl.parent = nullExpression + nullExpression.ExpressionInterfaceImpl.Parent = nullExpression return nullExpression } @@ -223,14 +223,14 @@ func (n *nullLiteral) serialize(statement StatementType, out *SqlBuilder, option //--------------------------------------------------// type starLiteral struct { - expressionInterfaceImpl + ExpressionInterfaceImpl noOpVisitorImpl } func newStarLiteral() Expression { starExpression := &starLiteral{} - starExpression.expressionInterfaceImpl.parent = starExpression + starExpression.ExpressionInterfaceImpl.Parent = starExpression return starExpression } @@ -243,7 +243,7 @@ func (n *starLiteral) serialize(statement StatementType, out *SqlBuilder, option //---------------------------------------------------// type wrap struct { - expressionInterfaceImpl + ExpressionInterfaceImpl expressions []Expression } @@ -263,7 +263,7 @@ func (n *wrap) serialize(statement StatementType, out *SqlBuilder, options ...Se // WRAP wraps list of expressions with brackets '(' and ')' func WRAP(expression ...Expression) Expression { wrap := &wrap{expressions: expression} - wrap.expressionInterfaceImpl.parent = wrap + wrap.ExpressionInterfaceImpl.Parent = wrap return wrap } @@ -271,7 +271,7 @@ func WRAP(expression ...Expression) Expression { //---------------------------------------------------// type rawExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl noOpVisitorImpl raw string @@ -286,7 +286,7 @@ func (n *rawExpression) serialize(statement StatementType, out *SqlBuilder, opti // For example: Raw("current_database()") func Raw(raw string) Expression { rawExp := &rawExpression{raw: raw} - rawExp.expressionInterfaceImpl.parent = rawExp + rawExp.ExpressionInterfaceImpl.Parent = rawExp return rawExp } diff --git a/internal/jet/operators.go b/internal/jet/operators.go index 0b43431..ff0b301 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -71,7 +71,7 @@ type CaseOperator interface { } type caseOperatorImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl expression Expression when []Expression @@ -87,7 +87,7 @@ func CASE(expression ...Expression) CaseOperator { caseExp.expression = expression[0] } - caseExp.expressionInterfaceImpl.parent = caseExp + caseExp.ExpressionInterfaceImpl.Parent = caseExp return caseExp } diff --git a/internal/jet/order_by_clause.go b/internal/jet/order_by_clause.go index 32fdb2e..782f366 100644 --- a/internal/jet/order_by_clause.go +++ b/internal/jet/order_by_clause.go @@ -3,7 +3,7 @@ package jet import "errors" // OrderByClause -type orderByClause interface { +type OrderByClause interface { serializeForOrderBy(statement StatementType, out *SqlBuilder) error } @@ -30,6 +30,6 @@ func (o *orderByClauseImpl) serializeForOrderBy(statement StatementType, out *Sq return nil } -func newOrderByClause(expression Expression, ascent bool) orderByClause { +func newOrderByClause(expression Expression, ascent bool) OrderByClause { return &orderByClauseImpl{expression: expression, ascent: ascent} } diff --git a/internal/jet/select_statement.go b/internal/jet/select_statement.go index 6923034..6162ebc 100644 --- a/internal/jet/select_statement.go +++ b/internal/jet/select_statement.go @@ -10,14 +10,14 @@ import ( // SelectStatement is interface for SQL SELECT statements type SelectStatement interface { Statement - expression + IExpression DISTINCT() SelectStatement FROM(table ReadableTable) SelectStatement WHERE(expression BoolExpression) SelectStatement - GROUP_BY(groupByClauses ...groupByClause) SelectStatement + GROUP_BY(groupByClauses ...GroupByClause) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement - ORDER_BY(orderByClauses ...orderByClause) SelectStatement + ORDER_BY(orderByClauses ...OrderByClause) SelectStatement LIMIT(limit int64) SelectStatement OFFSET(offset int64) SelectStatement FOR(lock SelectLock) SelectStatement @@ -40,17 +40,17 @@ func SELECT(projection1 Projection, projections ...Projection) SelectStatement { } type selectStatementImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl parent SelectStatement table ReadableTable distinct bool projectionList []Projection where BoolExpression - groupBy []groupByClause + groupBy []GroupByClause having BoolExpression - orderBy []orderByClause + orderBy []OrderByClause limit, offset int64 lockFor SelectLock @@ -65,7 +65,7 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta distinct: false, } - newSelect.expressionInterfaceImpl.parent = newSelect + newSelect.ExpressionInterfaceImpl.Parent = newSelect newSelect.parent = newSelect return newSelect @@ -77,7 +77,8 @@ func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement { } func (s *selectStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s.parent, alias) + //return newSelectTable(s.parent, alias) + panic("UNimplemented.") } func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement { @@ -85,7 +86,7 @@ func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement { return s.parent } -func (s *selectStatementImpl) GROUP_BY(groupByClauses ...groupByClause) SelectStatement { +func (s *selectStatementImpl) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement { s.groupBy = groupByClauses return s.parent } @@ -95,7 +96,7 @@ func (s *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement return s.parent } -func (s *selectStatementImpl) ORDER_BY(clauses ...orderByClause) SelectStatement { +func (s *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement { s.orderBy = clauses return s.parent } @@ -308,7 +309,7 @@ func (s *selectStatementImpl) ExecContext(context context.Context, db execution. // SelectLock is interface for SELECT statement locks type SelectLock interface { - Clause + Serializer NOWAIT() SelectLock SKIP_LOCKED() SelectLock diff --git a/internal/jet/select_table.go b/internal/jet/select_table.go index 2d4113b..f4db54e 100644 --- a/internal/jet/select_table.go +++ b/internal/jet/select_table.go @@ -1,72 +1,70 @@ package jet -import "errors" - -// SelectTable is interface for SELECT sub-queries -type SelectTable interface { - ReadableTable - - Alias() string - - AllColumns() ProjectionList -} - -type selectTableImpl struct { - readableTableInterfaceImpl - selectStmt SelectStatement - alias string - - projections []Projection -} - -func newSelectTable(selectStmt SelectStatement, alias string) SelectTable { - expTable := &selectTableImpl{selectStmt: selectStmt, alias: alias} - - expTable.readableTableInterfaceImpl.parent = expTable - - for _, projection := range selectStmt.projections() { - newProjection := projection.fromImpl(expTable) - - expTable.projections = append(expTable.projections, newProjection) - } - - return expTable -} - -func (s *selectTableImpl) Alias() string { - return s.alias -} - -func (s *selectTableImpl) columns() []IColumn { - return nil -} - -func (s *selectTableImpl) accept(visitor visitor) { - visitor.visit(s) - s.selectStmt.accept(visitor) -} - -func (s *selectTableImpl) dialect() Dialect { - return detectDialect(s.selectStmt) -} - -func (s *selectTableImpl) AllColumns() ProjectionList { - return s.projections -} - -func (s *selectTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { - if s == nil { - return errors.New("jet: Expression table is nil. ") - } - - err := s.selectStmt.serialize(statement, out) - - if err != nil { - return err - } - - out.WriteString("AS") - out.writeIdentifier(s.alias) - - return nil -} +//// SelectTable is interface for SELECT sub-queries +//type SelectTable interface { +// ReadableTable +// +// Alias() string +// +// AllColumns() ProjectionList +//} +// +//type selectTableImpl struct { +// readableTableInterfaceImpl +// selectStmt SelectStatement +// alias string +// +// projections []Projection +//} +// +//func newSelectTable(selectStmt SelectStatement, alias string) SelectTable { +// expTable := &selectTableImpl{selectStmt: selectStmt, alias: alias} +// +// expTable.readableTableInterfaceImpl.parent = expTable +// +// for _, projection := range selectStmt.projections() { +// newProjection := projection.fromImpl(expTable) +// +// expTable.projections = append(expTable.projections, newProjection) +// } +// +// return expTable +//} +// +//func (s *selectTableImpl) Alias() string { +// return s.alias +//} +// +//func (s *selectTableImpl) columns() []IColumn { +// return nil +//} +// +//func (s *selectTableImpl) accept(visitor visitor) { +// visitor.visit(s) +// s.selectStmt.accept(visitor) +//} +// +//func (s *selectTableImpl) dialect() Dialect { +// return detectDialect(s.selectStmt) +//} +// +//func (s *selectTableImpl) AllColumns() ProjectionList { +// return s.projections +//} +// +//func (s *selectTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { +// if s == nil { +// return errors.New("jet: Expression table is nil. ") +// } +// +// err := s.selectStmt.serialize(statement, out) +// +// if err != nil { +// return err +// } +// +// out.WriteString("AS") +// out.writeIdentifier(s.alias) +// +// return nil +//} diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go new file mode 100644 index 0000000..3145078 --- /dev/null +++ b/internal/jet/serializer.go @@ -0,0 +1,36 @@ +package jet + +type SerializeOption int + +const ( + noWrap SerializeOption = iota +) + +type StatementType string + +const ( + SelectStatementType StatementType = "SELECT" + InsertStatementType StatementType = "INSERT" + UpdateStatementType StatementType = "UPDATE" + DeleteStatementType StatementType = "DELETE" + SetStatementType StatementType = "SET" + LockStatementType StatementType = "LOCK" +) + +type Serializer interface { + serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error +} + +func Serialize(exp Serializer, statementType StatementType, out *SqlBuilder, options ...SerializeOption) error { + return exp.serialize(statementType, out, options...) +} + +func contains(options []SerializeOption, option SerializeOption) bool { + for _, opt := range options { + if opt == option { + return true + } + } + + return false +} diff --git a/internal/jet/clause_test.go b/internal/jet/serializer_test.go similarity index 100% rename from internal/jet/clause_test.go rename to internal/jet/serializer_test.go diff --git a/internal/jet/set_statement.go b/internal/jet/set_statement.go index 5bdaaae..9d2b83a 100644 --- a/internal/jet/set_statement.go +++ b/internal/jet/set_statement.go @@ -66,7 +66,7 @@ func newSetStatementImpl(operator string, all bool, selects []SelectStatement) S selects: selects, } - setStatement.selectStatementImpl.expressionInterfaceImpl.parent = setStatement + setStatement.selectStatementImpl.ExpressionInterfaceImpl.Parent = setStatement setStatement.selectStatementImpl.parent = setStatement setStatement.limit = -1 setStatement.offset = -1 diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go new file mode 100644 index 0000000..4bb4b5c --- /dev/null +++ b/internal/jet/sql_builder.go @@ -0,0 +1,235 @@ +package jet + +import ( + "bytes" + "github.com/go-jet/jet/internal/utils" + "github.com/google/uuid" + "strconv" + "strings" + "time" +) + +type SqlBuilder struct { + Dialect Dialect + Buff bytes.Buffer + Args []interface{} + + lastChar byte + ident int +} + +func (s *SqlBuilder) DebugSQL() string { + return queryStringToDebugString(s.Buff.String(), s.Args, s.Dialect) +} + +const defaultIdent = 5 + +func (q *SqlBuilder) increaseIdent() { + q.ident += defaultIdent +} + +func (q *SqlBuilder) decreaseIdent() { + if q.ident < defaultIdent { + q.ident = 0 + } + + q.ident -= defaultIdent +} + +func (q *SqlBuilder) writeProjections(statement StatementType, projections []Projection) error { + q.increaseIdent() + err := SerializeProjectionList(statement, projections, q) + q.decreaseIdent() + return err +} + +func (q *SqlBuilder) writeFrom(statement StatementType, table Serializer) error { + q.newLine() + q.WriteString("FROM") + + q.increaseIdent() + err := table.serialize(statement, q) + q.decreaseIdent() + + return err +} + +func (q *SqlBuilder) writeWhere(statement StatementType, where Expression) error { + q.newLine() + q.WriteString("WHERE") + + q.increaseIdent() + err := where.serialize(statement, q, noWrap) + q.decreaseIdent() + + return err +} + +func (q *SqlBuilder) writeGroupBy(statement StatementType, groupBy []GroupByClause) error { + q.newLine() + q.WriteString("GROUP BY") + + q.increaseIdent() + err := serializeGroupByClauseList(statement, groupBy, q) + q.decreaseIdent() + + return err +} + +func (q *SqlBuilder) writeOrderBy(statement StatementType, orderBy []OrderByClause) error { + q.newLine() + q.WriteString("ORDER BY") + + q.increaseIdent() + err := serializeOrderByClauseList(statement, orderBy, q) + q.decreaseIdent() + + return err +} + +func (q *SqlBuilder) writeHaving(statement StatementType, having Expression) error { + q.newLine() + q.WriteString("HAVING") + + q.increaseIdent() + err := having.serialize(statement, q, noWrap) + q.decreaseIdent() + + return err +} + +func (q *SqlBuilder) WriteReturning(statement StatementType, returning []Projection) error { + if len(returning) == 0 { + return nil + } + + if !q.Dialect.SupportsReturning() { + panic("jet: " + q.Dialect.Name() + " dialect does not support RETURNING.") + } + + q.newLine() + q.WriteString("RETURNING") + q.increaseIdent() + + return q.writeProjections(statement, returning) +} + +func (q *SqlBuilder) newLine() { + q.write([]byte{'\n'}) + q.write(bytes.Repeat([]byte{' '}, q.ident)) +} + +func (q *SqlBuilder) write(data []byte) { + if len(data) == 0 { + return + } + + if !isPreSeparator(q.lastChar) && !isPostSeparator(data[0]) && q.Buff.Len() > 0 { + q.Buff.WriteByte(' ') + } + + q.Buff.Write(data) + q.lastChar = data[len(data)-1] +} + +func isPreSeparator(b byte) bool { + return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' || b == ':' +} + +func isPostSeparator(b byte) bool { + return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':' +} + +func (q *SqlBuilder) writeAlias(str string) { + aliasQuoteChar := string(q.Dialect.AliasQuoteChar()) + q.WriteString(aliasQuoteChar + str + aliasQuoteChar) +} + +func (q *SqlBuilder) WriteString(str string) { + q.write([]byte(str)) +} + +func (q *SqlBuilder) writeIdentifier(name string, alwaysQuote ...bool) { + quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -") + + if quoteWrap || len(alwaysQuote) > 0 { + identQuoteChar := string(q.Dialect.IdentifierQuoteChar()) + q.WriteString(identQuoteChar + name + identQuoteChar) + } else { + q.WriteString(name) + } +} + +func (q *SqlBuilder) writeByte(b byte) { + q.write([]byte{b}) +} + +func (q *SqlBuilder) finalize() (string, []interface{}) { + return q.Buff.String() + ";\n", q.Args +} + +func (q *SqlBuilder) insertConstantArgument(arg interface{}) { + q.WriteString(argToString(arg)) +} + +func (q *SqlBuilder) insertParametrizedArgument(arg interface{}) { + q.Args = append(q.Args, arg) + argPlaceholder := q.Dialect.ArgumentPlaceholder()(len(q.Args)) + + q.WriteString(argPlaceholder) +} + +func argToString(value interface{}) string { + if utils.IsNil(value) { + return "NULL" + } + + switch bindVal := value.(type) { + case bool: + if bindVal { + return "TRUE" + } + return "FALSE" + case int8: + return strconv.FormatInt(int64(bindVal), 10) + case int: + return strconv.FormatInt(int64(bindVal), 10) + case int16: + return strconv.FormatInt(int64(bindVal), 10) + case int32: + return strconv.FormatInt(int64(bindVal), 10) + case int64: + return strconv.FormatInt(int64(bindVal), 10) + + case uint8: + return strconv.FormatUint(uint64(bindVal), 10) + case uint: + return strconv.FormatUint(uint64(bindVal), 10) + case uint16: + return strconv.FormatUint(uint64(bindVal), 10) + case uint32: + return strconv.FormatUint(uint64(bindVal), 10) + case uint64: + return strconv.FormatUint(uint64(bindVal), 10) + + case float32: + return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) + case float64: + return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) + + case string: + return stringQuote(bindVal) + case []byte: + return stringQuote(string(bindVal)) + case uuid.UUID: + return stringQuote(bindVal.String()) + case time.Time: + return stringQuote(string(utils.FormatTimestamp(bindVal))) + default: + return "[Unsupported type]" + } +} + +func stringQuote(value string) string { + return `'` + strings.Replace(value, "'", "''", -1) + `'` +} diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 97f4cf9..3e77932 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -3,6 +3,7 @@ package jet import ( "context" "database/sql" + "errors" "github.com/go-jet/jet/execution" "strings" ) @@ -31,9 +32,66 @@ type Statement interface { ExecContext(context context.Context, db execution.DB) (sql.Result, error) } +type SerializerStatement interface { + Serializer + Statement +} + +type StatementWithProjections interface { + Statement + HasProjections + Serializer +} + +type HasProjections interface { + projections() []Projection +} + +type SerializerStatementInterfaceImpl struct { + noOpVisitorImpl + Parent SerializerStatement + Dialect Dialect + StatementType StatementType +} + +func (s *SerializerStatementInterfaceImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { + + queryData := &SqlBuilder{Dialect: s.Dialect} + + err = s.Parent.serialize(s.StatementType, queryData, noWrap) + + if err != nil { + return "", nil, err + } + + query, args = queryData.finalize() + + return +} + +func (s *SerializerStatementInterfaceImpl) DebugSql(dialect ...Dialect) (query string, err error) { + return debugSql(s.Parent, s.Dialect) +} + +func (s *SerializerStatementInterfaceImpl) Query(db execution.DB, destination interface{}) error { + return query(s.Parent, db, destination) +} + +func (s *SerializerStatementInterfaceImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error { + return queryContext(context, s.Parent, db, destination) +} + +func (s *SerializerStatementInterfaceImpl) Exec(db execution.DB) (res sql.Result, err error) { + return exec(s.Parent, db) +} + +func (s *SerializerStatementInterfaceImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) { + return execContext(context, s.Parent, db) +} + func debugSql(statement Statement, overrideDialect ...Dialect) (string, error) { dialect := detectDialect(statement, overrideDialect...) - sqlQuery, args, err := statement.Sql() + sqlQuery, args, err := statement.Sql(dialect) if err != nil { return "", err @@ -100,3 +158,68 @@ func execContext(context context.Context, statement Statement, db execution.DB) return db.ExecContext(context, query, args...) } + +type ExpressionStatementImpl struct { + ExpressionInterfaceImpl + StatementImpl +} + +func (s *ExpressionStatementImpl) serializeForProjection(statement StatementType, out *SqlBuilder) error { + return s.serialize(statement, out) +} + +func NewStatementImpl(Dialect Dialect, statementType StatementType, parent SerializerStatement, clauses ...Clause) StatementImpl { + return StatementImpl{ + SerializerStatementInterfaceImpl: SerializerStatementInterfaceImpl{ + Parent: parent, + Dialect: Dialect, + StatementType: statementType, + }, + Clauses: clauses, + } +} + +type StatementImpl struct { + SerializerStatementInterfaceImpl + acceptsVisitor + + Clauses []Clause +} + +func (s *StatementImpl) projections() []Projection { + for _, clause := range s.Clauses { + if selectClause, ok := clause.(ClauseWithProjections); ok { + return selectClause.projections() + } + } + + return nil +} + +func (s *StatementImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { + if s == nil { + return errors.New("jet: Select expression is nil. ") + } + + if !contains(options, noWrap) { + out.WriteString("(") + + out.increaseIdent() + } + + for _, clause := range s.Clauses { + err := clause.Serialize(statement, out) + + if err != nil { + return err + } + } + + if !contains(options, noWrap) { + out.decreaseIdent() + out.newLine() + out.WriteString(")") + } + + return nil +} diff --git a/internal/jet/string_expression.go b/internal/jet/string_expression.go index 74ffecd..0de3063 100644 --- a/internal/jet/string_expression.go +++ b/internal/jet/string_expression.go @@ -80,7 +80,7 @@ func (s *stringInterfaceImpl) REGEXP_LIKE(pattern StringExpression, matchType .. //---------------------------------------------------// type binaryStringExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl stringInterfaceImpl binaryOpExpression @@ -90,7 +90,7 @@ func newBinaryStringExpression(lhs, rhs Expression, operator string) StringExpre boolExpression := binaryStringExpression{} boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) - boolExpression.expressionInterfaceImpl.parent = &boolExpression + boolExpression.ExpressionInterfaceImpl.Parent = &boolExpression boolExpression.stringInterfaceImpl.parent = &boolExpression return &boolExpression diff --git a/internal/jet/table.go b/internal/jet/table.go index 7d5b46c..923a3ad 100644 --- a/internal/jet/table.go +++ b/internal/jet/table.go @@ -5,7 +5,19 @@ import ( "github.com/go-jet/jet/internal/utils" ) -type table interface { +type SerializerTable interface { + Serializer + Columns() []IColumn + //SchemaName() string + //TableName() string + //AS(alias string) +} + +type TableInterface interface { + Columns() []IColumn +} + +type TableBase interface { dialect() Dialect columns() []IColumn } @@ -40,26 +52,26 @@ type writableTable interface { // ReadableTable interface type ReadableTable interface { - table + TableBase readableTable - Clause + Serializer acceptsVisitor } // WritableTable interface type WritableTable interface { - table + TableBase writableTable - Clause + Serializer acceptsVisitor } // Table interface type Table interface { - table + TableBase readableTable writableTable - Clause + Serializer acceptsVisitor SchemaName() string @@ -78,25 +90,25 @@ func (r *readableTableInterfaceImpl) SELECT(projection1 Projection, projections // Creates a inner join tableName Expression using onCondition. func (r *readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { - return newJoinTable(r.parent, table, innerJoin, onCondition) + return newJoinTable(r.parent, table, InnerJoin, onCondition) } // Creates a left join tableName Expression using onCondition. func (r *readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { - return newJoinTable(r.parent, table, leftJoin, onCondition) + return newJoinTable(r.parent, table, LeftJoin, onCondition) } // Creates a right join tableName Expression using onCondition. func (r *readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { - return newJoinTable(r.parent, table, rightJoin, onCondition) + return newJoinTable(r.parent, table, RightJoin, onCondition) } func (r *readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { - return newJoinTable(r.parent, table, fullJoin, onCondition) + return newJoinTable(r.parent, table, FullJoin, onCondition) } func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) ReadableTable { - return newJoinTable(r.parent, table, crossJoin, nil) + return newJoinTable(r.parent, table, CrossJoin, nil) } type writableTableInterfaceImpl struct { @@ -104,11 +116,11 @@ type writableTableInterfaceImpl struct { } func (w *writableTableInterfaceImpl) INSERT(columns ...IColumn) InsertStatement { - return newInsertStatement(w.parent, unwidColumnList(columns)) + return newInsertStatement(w.parent, UnwidColumnList(columns)) } func (w *writableTableInterfaceImpl) UPDATE(column IColumn, columns ...IColumn) UpdateStatement { - return newUpdateStatement(w.parent, unwindColumns(column, columns...)) + return newUpdateStatement(w.parent, UnwindColumns(column, columns...)) } func (w *writableTableInterfaceImpl) DELETE() DeleteStatement { @@ -200,14 +212,14 @@ func (t *tableImpl) serialize(statement StatementType, out *SqlBuilder, options return nil } -type joinType int +type JoinType int const ( - innerJoin joinType = iota - leftJoin - rightJoin - fullJoin - crossJoin + InnerJoin JoinType = iota + LeftJoin + RightJoin + FullJoin + CrossJoin ) // Join expressions are pseudo readable tables. @@ -216,15 +228,15 @@ type joinTable struct { lhs ReadableTable rhs ReadableTable - joinType joinType + joinType JoinType onCondition BoolExpression } func newJoinTable( lhs ReadableTable, rhs ReadableTable, - joinType joinType, - onCondition BoolExpression) ReadableTable { + joinType JoinType, + onCondition BoolExpression) *joinTable { joinTable := &joinTable{ lhs: lhs, @@ -275,15 +287,15 @@ func (t *joinTable) serialize(statement StatementType, out *SqlBuilder, options out.newLine() switch t.joinType { - case innerJoin: + case InnerJoin: out.WriteString("INNER JOIN") - case leftJoin: + case LeftJoin: out.WriteString("LEFT JOIN") - case rightJoin: + case RightJoin: out.WriteString("RIGHT JOIN") - case fullJoin: + case FullJoin: out.WriteString("FULL JOIN") - case crossJoin: + case CrossJoin: out.WriteString("CROSS JOIN") } @@ -295,7 +307,7 @@ func (t *joinTable) serialize(statement StatementType, out *SqlBuilder, options return } - if t.onCondition == nil && t.joinType != crossJoin { + if t.onCondition == nil && t.joinType != CrossJoin { return errors.New("jet: join condition is nil") } @@ -309,7 +321,7 @@ func (t *joinTable) serialize(statement StatementType, out *SqlBuilder, options return nil } -func unwindColumns(column1 IColumn, columns ...IColumn) []IColumn { +func UnwindColumns(column1 IColumn, columns ...IColumn) []IColumn { columnList := []IColumn{} if val, ok := column1.(IColumnList); ok { @@ -325,7 +337,7 @@ func unwindColumns(column1 IColumn, columns ...IColumn) []IColumn { return columnList } -func unwidColumnList(columns []IColumn) []IColumn { +func UnwidColumnList(columns []IColumn) []IColumn { ret := []IColumn{} for _, col := range columns { diff --git a/internal/jet/testutils.go b/internal/jet/testutils.go index e73e117..97f4a80 100644 --- a/internal/jet/testutils.go +++ b/internal/jet/testutils.go @@ -72,7 +72,7 @@ var table3 = NewTable( table3ColInt, table3StrCol) -func assertClauseSerialize(t *testing.T, clause Clause, query string, args ...interface{}) { +func assertClauseSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) { out := SqlBuilder{Dialect: ANSII} err := clause.serialize(SelectStatementType, &out) @@ -82,7 +82,7 @@ func assertClauseSerialize(t *testing.T, clause Clause, query string, args ...in assert.DeepEqual(t, out.Args, args) } -func assertClauseSerializeErr(t *testing.T, clause Clause, errString string) { +func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) { out := SqlBuilder{Dialect: ANSII} err := clause.serialize(SelectStatementType, &out) diff --git a/internal/jet/time_expression.go b/internal/jet/time_expression.go index 779d37f..03ce827 100644 --- a/internal/jet/time_expression.go +++ b/internal/jet/time_expression.go @@ -53,7 +53,7 @@ func (t *timeInterfaceImpl) GT_EQ(rhs TimeExpression) BoolExpression { //---------------------------------------------------// type prefixTimeExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl timeInterfaceImpl prefixOpExpression @@ -63,8 +63,8 @@ type prefixTimeExpression struct { // timeExpr := prefixTimeExpression{} // timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) // -// timeExpr.expressionInterfaceImpl.parent = &timeExpr -// timeExpr.timeInterfaceImpl.parent = &timeExpr +// timeExpr.ExpressionInterfaceImpl.Parent = &timeExpr +// timeExpr.timeInterfaceImpl.Parent = &timeExpr // // return &timeExpr //} diff --git a/internal/jet/timestampz_expression.go b/internal/jet/timestampz_expression.go index dacfc66..1d83523 100644 --- a/internal/jet/timestampz_expression.go +++ b/internal/jet/timestampz_expression.go @@ -54,7 +54,7 @@ func (t *timestampzInterfaceImpl) GT_EQ(rhs TimestampzExpression) BoolExpression //---------------------------------------------------// type prefixTimestampzOperator struct { - expressionInterfaceImpl + ExpressionInterfaceImpl timestampzInterfaceImpl prefixOpExpression @@ -64,7 +64,7 @@ func NewPrefixTimestampOperator(operator string, expression Expression) Timestam timeExpr := prefixTimestampzOperator{} timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) - timeExpr.expressionInterfaceImpl.parent = &timeExpr + timeExpr.ExpressionInterfaceImpl.Parent = &timeExpr timeExpr.timestampzInterfaceImpl.parent = &timeExpr return &timeExpr diff --git a/internal/jet/timez_expression.go b/internal/jet/timez_expression.go index d860703..529c92e 100644 --- a/internal/jet/timez_expression.go +++ b/internal/jet/timez_expression.go @@ -61,7 +61,7 @@ func (t *timezInterfaceImpl) GT_EQ(rhs TimezExpression) BoolExpression { //---------------------------------------------------// type prefixTimezExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl timezInterfaceImpl prefixOpExpression @@ -71,8 +71,8 @@ type prefixTimezExpression struct { // timeExpr := prefixTimezExpression{} // timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) // -// timeExpr.expressionInterfaceImpl.parent = &timeExpr -// timeExpr.timezInterfaceImpl.parent = &timeExpr +// timeExpr.ExpressionInterfaceImpl.Parent = &timeExpr +// timeExpr.timezInterfaceImpl.Parent = &timeExpr // // return &timeExpr //} diff --git a/internal/jet/update_statement.go b/internal/jet/update_statement.go index b4e20b4..5a7b235 100644 --- a/internal/jet/update_statement.go +++ b/internal/jet/update_statement.go @@ -23,26 +23,26 @@ func newUpdateStatement(table WritableTable, columns []IColumn) UpdateStatement return &updateStatementImpl{ table: table, columns: columns, - values: make([]Clause, 0, len(columns)), + values: make([]Serializer, 0, len(columns)), } } type updateStatementImpl struct { table WritableTable columns []IColumn - values []Clause + values []Serializer where BoolExpression returning []Projection } func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { - u.values = unwindRowFromValues(value, values) + u.values = UnwindRowFromValues(value, values) return u } func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement { - u.values = unwindRowFromModel(u.columns, data) + u.values = UnwindRowFromModel(u.columns, data) return u } @@ -101,7 +101,7 @@ func (u *updateStatementImpl) Sql(dialect ...Dialect) (query string, args []inte return } - if err = out.writeReturning(UpdateStatementType, u.returning); err != nil { + if err = out.WriteReturning(UpdateStatementType, u.returning); err != nil { return } diff --git a/internal/jet/utils.go b/internal/jet/utils.go index 3afdeaa..2945377 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -7,7 +7,7 @@ import ( "strings" ) -func serializeOrderByClauseList(statement StatementType, orderByClauses []orderByClause, out *SqlBuilder) error { +func serializeOrderByClauseList(statement StatementType, orderByClauses []OrderByClause, out *SqlBuilder) error { for i, value := range orderByClauses { if i > 0 { @@ -24,7 +24,7 @@ func serializeOrderByClauseList(statement StatementType, orderByClauses []orderB return nil } -func serializeGroupByClauseList(statement StatementType, clauses []groupByClause, out *SqlBuilder) (err error) { +func serializeGroupByClauseList(statement StatementType, clauses []GroupByClause, out *SqlBuilder) (err error) { for i, c := range clauses { if i > 0 { @@ -43,7 +43,7 @@ func serializeGroupByClauseList(statement StatementType, clauses []groupByClause return nil } -func SerializeClauseList(statement StatementType, clauses []Clause, out *SqlBuilder) (err error) { +func SerializeClauseList(statement StatementType, clauses []Serializer, out *SqlBuilder) (err error) { for i, c := range clauses { if i > 0 { @@ -124,18 +124,18 @@ func ColumnListToProjectionList(columns []Column) []Projection { return ret } -func valueToClause(value interface{}) Clause { - if clause, ok := value.(Clause); ok { +func valueToClause(value interface{}) Serializer { + if clause, ok := value.(Serializer); ok { return clause } return literal(value) } -func unwindRowFromModel(columns []IColumn, data interface{}) []Clause { +func UnwindRowFromModel(columns []IColumn, data interface{}) []Serializer { structValue := reflect.Indirect(reflect.ValueOf(data)) - row := []Clause{} + row := []Serializer{} mustBe(structValue, reflect.Struct) @@ -163,23 +163,23 @@ func unwindRowFromModel(columns []IColumn, data interface{}) []Clause { return row } -func unwindRowsFromModels(columns []IColumn, data interface{}) [][]Clause { +func UnwindRowsFromModels(columns []IColumn, data interface{}) [][]Serializer { sliceValue := reflect.Indirect(reflect.ValueOf(data)) mustBe(sliceValue, reflect.Slice) - rows := [][]Clause{} + rows := [][]Serializer{} for i := 0; i < sliceValue.Len(); i++ { structValue := sliceValue.Index(i) - rows = append(rows, unwindRowFromModel(columns, structValue.Interface())) + rows = append(rows, UnwindRowFromModel(columns, structValue.Interface())) } return rows } -func unwindRowFromValues(value interface{}, values []interface{}) []Clause { - row := []Clause{} +func UnwindRowFromValues(value interface{}, values []interface{}) []Serializer { + row := []Serializer{} allValues := append([]interface{}{value}, values...) diff --git a/internal/jet/visitor.go b/internal/jet/visitor.go index c513569..ff6172b 100644 --- a/internal/jet/visitor.go +++ b/internal/jet/visitor.go @@ -45,7 +45,7 @@ func (f *DialectFinder) mustGetDialect() Dialect { func (f *DialectFinder) visit(element acceptsVisitor) { - if table, ok := element.(table); ok { + if table, ok := element.(TableBase); ok { dialect := table.dialect() f.dialects[dialect.Name()] = dialect } diff --git a/mysql/utils_test.go b/mysql/utils_test.go index 3ac9749..0a91e62 100644 --- a/mysql/utils_test.go +++ b/mysql/utils_test.go @@ -58,7 +58,7 @@ var table3 = NewTable( table3ColInt, table3StrCol) -func assertClauseSerialize(t *testing.T, clause jet.Clause, query string, args ...interface{}) { +func assertClauseSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { out := jet.SqlBuilder{Dialect: Dialect} err := jet.Serialize(clause, jet.SelectStatementType, &out) @@ -68,7 +68,7 @@ func assertClauseSerialize(t *testing.T, clause jet.Clause, query string, args . assert.DeepEqual(t, out.Args, args) } -func assertClauseSerializeErr(t *testing.T, clause jet.Clause, errString string) { +func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) { out := jet.SqlBuilder{Dialect: Dialect} err := jet.Serialize(clause, jet.SelectStatementType, &out) diff --git a/postgres/delete_statement.go b/postgres/delete_statement.go new file mode 100644 index 0000000..0fcfcfd --- /dev/null +++ b/postgres/delete_statement.go @@ -0,0 +1,41 @@ +package postgres + +import "github.com/go-jet/jet/internal/jet" + +type DeleteStatement interface { + jet.Statement + + WHERE(expression BoolExpression) DeleteStatement + + RETURNING(projections ...jet.Projection) DeleteStatement +} + +type deleteStatementImpl struct { + jet.StatementImpl + + Delete jet.ClauseStatementBegin + Where jet.ClauseWhere + Returning jet.ClauseReturning +} + +func newDeleteStatement(table WritableTable) DeleteStatement { + newDelete := &deleteStatementImpl{} + newDelete.StatementImpl = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newDelete, &newDelete.Delete, + &newDelete.Where, &newDelete.Returning) + + newDelete.Delete.Name = "DELETE FROM" + newDelete.Delete.Tables = append(newDelete.Delete.Tables, table) + newDelete.Where.Mandatory = true + + return newDelete +} + +func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement { + d.Where.Condition = expression + return d +} + +func (d *deleteStatementImpl) RETURNING(projections ...jet.Projection) DeleteStatement { + d.Returning.Projections = projections + return d +} diff --git a/postgres/delete_statement_test.go b/postgres/delete_statement_test.go new file mode 100644 index 0000000..ebb0cc4 --- /dev/null +++ b/postgres/delete_statement_test.go @@ -0,0 +1,25 @@ +package postgres + +import ( + "testing" +) + +func TestDeleteUnconditionally(t *testing.T) { + assertStatementErr(t, table1.DELETE(), `jet: WHERE clause not set`) + assertStatementErr(t, table1.DELETE().WHERE(nil), `jet: WHERE clause not set`) +} + +func TestDeleteWithWhere(t *testing.T) { + assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), ` +DELETE FROM db.table1 +WHERE table1.col1 = $1; +`, int64(1)) +} + +func TestDeleteWithWhereAndReturning(t *testing.T) { + assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).RETURNING(table1Col1), ` +DELETE FROM db.table1 +WHERE table1.col1 = $1 +RETURNING table1.col1 AS "table1.col1"; +`, int64(1)) +} diff --git a/postgres/dialect.go b/postgres/dialect.go index 0fbfba1..81ddc76 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -24,7 +24,7 @@ func NewDialect() jet.Dialect { ArgumentPlaceholder: func(ord int) string { return "$" + strconv.Itoa(ord) }, - SetClause: postgresSetClause, + //SetClause: postgresSetClause, SupportsReturning: true, } @@ -59,40 +59,6 @@ func postgresCAST(expressions ...jet.Expression) jet.SerializeFunc { } } -func postgresSetClause(columns []jet.IColumn, values []jet.Clause, out *jet.SqlBuilder) (err error) { - if len(columns) > 1 { - out.WriteString("(") - } - - err = jet.SerializeColumnNames(columns, out) - - if err != nil { - return - } - - if len(columns) > 1 { - out.WriteString(")") - } - - out.WriteString("=") - - if len(values) > 1 { - out.WriteString("(") - } - - err = jet.SerializeClauseList(jet.UpdateStatementType, values, out) - - if err != nil { - return - } - - if len(values) > 1 { - out.WriteString(")") - } - - return -} - func postgres_REGEXP_LIKE_function(expressions ...jet.Expression) jet.SerializeFunc { return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) error { if len(expressions) < 2 { diff --git a/postgres/functions.go b/postgres/functions.go index 2865f06..fbded66 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -50,11 +50,11 @@ var RTRIM = jet.RTRIM var CHR = jet.CHR var CONCAT = func(expressions ...Expression) StringExpression { - return jet.CONCAT(explicitCasts(expressions...)...) + return jet.CONCAT(explicitLiteralCasts(expressions...)...) } func CONCAT_WS(expressions ...Expression) StringExpression { - return jet.CONCAT_WS(explicitCasts(expressions...)...) + return jet.CONCAT_WS(explicitLiteralCasts(expressions...)...) } var CONVERT = jet.CONVERT @@ -64,7 +64,7 @@ var ENCODE = jet.ENCODE var DECODE = jet.DECODE func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression { - return jet.FORMAT(formatStr, explicitCasts(formatArgs...)...) + return jet.FORMAT(formatStr, explicitLiteralCasts(formatArgs...)...) } var INITCAP = jet.INITCAP @@ -107,17 +107,17 @@ var LEAST = jet.LEAST var EXISTS = jet.EXISTS var CASE = jet.CASE -func explicitCasts(expressions ...Expression) []jet.Expression { +func explicitLiteralCasts(expressions ...Expression) []jet.Expression { ret := []jet.Expression{} for _, exp := range expressions { - ret = append(ret, explicitCast(exp)) + ret = append(ret, explicitLiteralCast(exp)) } return ret } -func explicitCast(expresion Expression) jet.Expression { +func explicitLiteralCast(expresion Expression) jet.Expression { if _, ok := expresion.(jet.LiteralExpression); !ok { return expresion } diff --git a/postgres/insert_statement.go b/postgres/insert_statement.go new file mode 100644 index 0000000..c93bc9a --- /dev/null +++ b/postgres/insert_statement.go @@ -0,0 +1,73 @@ +package postgres + +import "github.com/go-jet/jet/internal/jet" + +// InsertStatement is interface for SQL INSERT statements +type InsertStatement interface { + jet.Statement + + // Insert row of values + VALUES(value interface{}, values ...interface{}) InsertStatement + // Insert row of values, where value for each column is extracted from filed of structure data. + // If data is not struct or there is no field for every column selected, this method will panic. + MODEL(data interface{}) InsertStatement + + MODELS(data interface{}) InsertStatement + + QUERY(selectStatement SelectStatement) InsertStatement + + RETURNING(projections ...jet.Projection) InsertStatement +} + +func newInsertStatement(table WritableTable, columns []jet.IColumn) InsertStatement { + newInsert := &insertStatementImpl{} + newInsert.StatementImpl = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newInsert, + &newInsert.Insert, &newInsert.Values, &newInsert.Select, &newInsert.Returning) + + newInsert.Insert.Table = table + newInsert.Insert.Columns = columns + + return newInsert +} + +type insertStatementImpl struct { + jet.StatementImpl + + Insert jet.ClauseInsert + Values jet.ClauseValues + Select jet.ClauseQuery + Returning jet.ClauseReturning +} + +func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { + i.Values.Rows = append(i.Values.Rows, jet.UnwindRowFromValues(value, values)) + return i +} + +func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement { + i.Values.Rows = append(i.Values.Rows, jet.UnwindRowFromModel(i.getColumns(), data)) + return i +} + +func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement { + i.Values.Rows = append(i.Values.Rows, jet.UnwindRowsFromModels(i.getColumns(), data)...) + return i +} + +func (i *insertStatementImpl) RETURNING(projections ...jet.Projection) InsertStatement { + i.Returning.Projections = projections + return i +} + +func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement { + i.Select.Query = selectStatement + return i +} + +func (i *insertStatementImpl) getColumns() []jet.IColumn { + if len(i.Insert.Columns) > 0 { + return i.Insert.Columns + } + + return i.Insert.Table.Columns() +} diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go new file mode 100644 index 0000000..7fd84ba --- /dev/null +++ b/postgres/insert_statement_test.go @@ -0,0 +1,148 @@ +package postgres + +import ( + "gotest.tools/assert" + "testing" + "time" +) + +//TODO: +//func TestInvalidInsert(t *testing.T) { +// assertStatementErr(t, table1.INSERT(table1Col1), "jet: no row values or query specified") +// assertStatementErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") +//} + +func TestInsertNilValue(t *testing.T) { + assertStatement(t, table1.INSERT(table1Col1).VALUES(nil), ` +INSERT INTO db.table1 (col1) VALUES + ($1); +`, nil) +} + +func TestInsertSingleValue(t *testing.T) { + assertStatement(t, table1.INSERT(table1Col1).VALUES(1), ` +INSERT INTO db.table1 (col1) VALUES + ($1); +`, int(1)) +} + +func TestInsertWithColumnList(t *testing.T) { + columnList := ColumnList(table3ColInt, table3StrCol) + + assertStatement(t, table3.INSERT(columnList).VALUES(1, 3), ` +INSERT INTO db.table3 (col_int, col2) VALUES + ($1, $2); +`, 1, 3) +} + +func TestInsertDate(t *testing.T) { + date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) + + assertStatement(t, table1.INSERT(table1ColTime).VALUES(date), ` +INSERT INTO db.table1 (col_time) VALUES + ($1); +`, date) +} + +func TestInsertMultipleValues(t *testing.T) { + assertStatement(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` +INSERT INTO db.table1 (col1, col_float, col3) VALUES + ($1, $2, $3); +`, 1, 2, 3) +} + +func TestInsertMultipleRows(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1ColFloat). + VALUES(1, 2). + VALUES(11, 22). + VALUES(111, 222) + + assertStatement(t, stmt, ` +INSERT INTO db.table1 (col1, col_float) VALUES + ($1, $2), + ($3, $4), + ($5, $6); +`, 1, 2, 11, 22, 111, 222) +} + +func TestInsertValuesFromModel(t *testing.T) { + type Table1Model struct { + Col1 *int + ColFloat float64 + } + + one := 1 + + toInsert := Table1Model{ + Col1: &one, + ColFloat: 1.11, + } + + stmt := table1.INSERT(table1Col1, table1ColFloat). + MODEL(toInsert). + MODEL(&toInsert) + + expectedSQL := ` +INSERT INTO db.table1 (col1, col_float) VALUES + ($1, $2), + ($3, $4); +` + + assertStatement(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) +} + +func TestInsertValuesFromModelColumnMismatch(t *testing.T) { + defer func() { + r := recover() + assert.Equal(t, r, "missing struct field for column : col1") + }() + type Table1Model struct { + Col1Prim int + Col2 string + } + + newData := Table1Model{ + Col1Prim: 1, + Col2: "one", + } + + table1. + INSERT(table1Col1, table1ColFloat). + MODEL(newData) +} + +func TestInsertFromNonStructModel(t *testing.T) { + + defer func() { + r := recover() + assert.Equal(t, r, "argument mismatch: expected struct, got []int") + }() + + table2.INSERT(table2ColInt).MODEL([]int{}) +} + +func TestInsertQuery(t *testing.T) { + + stmt := table1.INSERT(table1Col1). + QUERY(table1.SELECT(table1Col1)) + + var expectedSQL = ` +INSERT INTO db.table1 (col1) ( + SELECT table1.col1 AS "table1.col1" + FROM db.table1 +); +` + assertStatement(t, stmt, expectedSQL) +} + +func TestInsertDefaultValue(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1ColFloat). + VALUES(DEFAULT, "two") + + var expectedSQL = ` +INSERT INTO db.table1 (col1, col_float) VALUES + (DEFAULT, $1); +` + + assertStatement(t, stmt, expectedSQL, "two") +} diff --git a/postgres/lock_statement.go b/postgres/lock_statement.go index 81f95d8..0e31f14 100644 --- a/postgres/lock_statement.go +++ b/postgres/lock_statement.go @@ -2,20 +2,52 @@ package postgres import "github.com/go-jet/jet/internal/jet" -type TableLockMode jet.TableLockMode +type TableLockMode string // Lock types for LockStatement. const ( - LOCK_ACCESS_SHARE = "ACCESS SHARE" - LOCK_ROW_SHARE = "ROW SHARE" - LOCK_ROW_EXCLUSIVE = "ROW EXCLUSIVE" - LOCK_SHARE_UPDATE_EXCLUSIVE = "SHARE UPDATE EXCLUSIVE" - LOCK_SHARE = "SHARE" - LOCK_SHARE_ROW_EXCLUSIVE = "SHARE ROW EXCLUSIVE" - LOCK_EXCLUSIVE = "EXCLUSIVE" - LOCK_ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE" + LOCK_ACCESS_SHARE TableLockMode = "ACCESS SHARE" + LOCK_ROW_SHARE TableLockMode = "ROW SHARE" + LOCK_ROW_EXCLUSIVE TableLockMode = "ROW EXCLUSIVE" + LOCK_SHARE_UPDATE_EXCLUSIVE TableLockMode = "SHARE UPDATE EXCLUSIVE" + LOCK_SHARE TableLockMode = "SHARE" + LOCK_SHARE_ROW_EXCLUSIVE TableLockMode = "SHARE ROW EXCLUSIVE" + LOCK_EXCLUSIVE TableLockMode = "EXCLUSIVE" + LOCK_ACCESS_EXCLUSIVE TableLockMode = "ACCESS EXCLUSIVE" ) -type LockStatement jet.LockStatement +type LockStatement interface { + jet.Statement -var LOCK = jet.LOCK + IN(lockMode TableLockMode) LockStatement + NOWAIT() LockStatement +} + +func LOCK(tables ...jet.SerializerTable) LockStatement { + newLock := &lockStatementImpl{} + newLock.StatementImpl = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newLock, + &newLock.StatementBegin, &newLock.In, &newLock.NoWait) + + newLock.StatementBegin.Name = "LOCK TABLE" + newLock.StatementBegin.Tables = tables + newLock.NoWait.Name = "NOWAIT" + return newLock +} + +type lockStatementImpl struct { + jet.StatementImpl + + StatementBegin jet.ClauseStatementBegin + In jet.ClauseIn + NoWait jet.ClauseOptional +} + +func (l *lockStatementImpl) IN(lockMode TableLockMode) LockStatement { + l.In.LockMode = string(lockMode) + return l +} + +func (l *lockStatementImpl) NOWAIT() LockStatement { + l.NoWait.Show = true + return l +} diff --git a/postgres/lock_statement_test.go b/postgres/lock_statement_test.go new file mode 100644 index 0000000..de0dddb --- /dev/null +++ b/postgres/lock_statement_test.go @@ -0,0 +1,32 @@ +package postgres + +import ( + "testing" +) + +func TestLockTable(t *testing.T) { + assertStatement(t, table1.LOCK().IN(LOCK_ACCESS_SHARE), ` +LOCK TABLE db.table1 IN ACCESS SHARE MODE; +`) + assertStatement(t, table1.LOCK().IN(LOCK_ROW_SHARE), ` +LOCK TABLE db.table1 IN ROW SHARE MODE; +`) + assertStatement(t, table1.LOCK().IN(LOCK_ROW_EXCLUSIVE), ` +LOCK TABLE db.table1 IN ROW EXCLUSIVE MODE; +`) + assertStatement(t, table1.LOCK().IN(LOCK_SHARE_UPDATE_EXCLUSIVE), ` +LOCK TABLE db.table1 IN SHARE UPDATE EXCLUSIVE MODE; +`) + assertStatement(t, table1.LOCK().IN(LOCK_SHARE), ` +LOCK TABLE db.table1 IN SHARE MODE; +`) + assertStatement(t, table1.LOCK().IN(LOCK_SHARE_ROW_EXCLUSIVE), ` +LOCK TABLE db.table1 IN SHARE ROW EXCLUSIVE MODE; +`) + assertStatement(t, table1.LOCK().IN(LOCK_EXCLUSIVE), ` +LOCK TABLE db.table1 IN EXCLUSIVE MODE; +`) + assertStatement(t, table1.LOCK().IN(LOCK_ACCESS_EXCLUSIVE).NOWAIT(), ` +LOCK TABLE db.table1 IN ACCESS EXCLUSIVE MODE NOWAIT; +`) +} diff --git a/postgres/select_statement.go b/postgres/select_statement.go new file mode 100644 index 0000000..ee3ed77 --- /dev/null +++ b/postgres/select_statement.go @@ -0,0 +1,144 @@ +package postgres + +import "github.com/go-jet/jet/internal/jet" + +type SelectLock = jet.SelectLock + +var ( + UPDATE = jet.NewSelectLock("UPDATE") + NO_KEY_UPDATE = jet.NewSelectLock("NO KEY UPDATE") + SHARE = jet.NewSelectLock("SHARE") + KEY_SHARE = jet.NewSelectLock("KEY SHARE") +) + +type SelectStatement interface { + jet.Statement + jet.HasProjections + jet.IExpression + + DISTINCT() SelectStatement + FROM(table ReadableTable) SelectStatement + WHERE(expression BoolExpression) SelectStatement + GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement + HAVING(boolExpression BoolExpression) SelectStatement + ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement + LIMIT(limit int64) SelectStatement + OFFSET(offset int64) SelectStatement + FOR(lock SelectLock) SelectStatement + + UNION(rhs SelectStatement) SetStatement + UNION_ALL(rhs SelectStatement) SetStatement + INTERSECT(rhs SelectStatement) SetStatement + INTERSECT_ALL(rhs SelectStatement) SetStatement + EXCEPT(rhs SelectStatement) SetStatement + EXCEPT_ALL(rhs SelectStatement) SetStatement + + AsTable(alias string) SelectTable +} + +//SELECT creates new SelectStatement with list of projections +func SELECT(projection jet.Projection, projections ...jet.Projection) SelectStatement { + return newSelectStatement(nil, append([]jet.Projection{projection}, projections...)) +} + +func newSelectStatement(table ReadableTable, projections []jet.Projection) SelectStatement { + newSelect := &selectStatementImpl{} + newSelect.ExpressionStatementImpl.StatementImpl = jet.NewStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, + &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy, + &newSelect.Limit, &newSelect.Offset, &newSelect.For) + + newSelect.ExpressionStatementImpl.ExpressionInterfaceImpl.Parent = newSelect + + newSelect.Select.Projections = projections + newSelect.From.Table = table + newSelect.Limit.Count = -1 + newSelect.Offset.Count = -1 + + newSelect.setOperatorsImpl.parent = newSelect + + return newSelect +} + +type selectStatementImpl struct { + jet.ExpressionStatementImpl + setOperatorsImpl + + Select jet.ClauseSelect + From jet.ClauseFrom + Where jet.ClauseWhere + GroupBy jet.ClauseGroupBy + Having jet.ClauseHaving + OrderBy jet.ClauseOrderBy + Limit jet.ClauseLimit + Offset jet.ClauseOffset + For jet.ClauseFor +} + +func (s *selectStatementImpl) DISTINCT() SelectStatement { + s.Select.Distinct = true + return s +} + +func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement { + s.From.Table = table + return s +} + +func (s *selectStatementImpl) WHERE(condition BoolExpression) SelectStatement { + s.Where.Condition = condition + return s +} + +func (s *selectStatementImpl) GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement { + s.GroupBy.List = groupByClauses + return s +} + +func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatement { + s.Having.Condition = boolExpression + return s +} + +func (s *selectStatementImpl) ORDER_BY(orderByClauses ...jet.OrderByClause) SelectStatement { + s.OrderBy.List = orderByClauses + return s +} + +func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement { + s.Limit.Count = limit + return s +} + +func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement { + s.Offset.Count = offset + return s +} + +func (s *selectStatementImpl) FOR(lock SelectLock) SelectStatement { + s.For.Lock = lock + return s +} + +func (s *selectStatementImpl) AsTable(alias string) SelectTable { + return newSelectTable(s, alias) +} + +type SelectTable interface { + ReadableTable + jet.SelectTable +} + +type selectTableImpl struct { + jet.SelectTableImpl2 + readableTableInterfaceImpl +} + +func newSelectTable(selectStmt jet.StatementWithProjections, alias string) SelectTable { + subQuery := &selectTableImpl{ + SelectTableImpl2: jet.NewSelectTable(selectStmt, alias), + } + + subQuery.readableTableInterfaceImpl.parent = subQuery + + return subQuery +} diff --git a/postgres/select_statement_test.go b/postgres/select_statement_test.go new file mode 100644 index 0000000..8189644 --- /dev/null +++ b/postgres/select_statement_test.go @@ -0,0 +1,137 @@ +package postgres + +import ( + "github.com/go-jet/jet/internal/testutils" + "testing" +) + +func TestInvalidSelect(t *testing.T) { + assertStatementErr(t, SELECT(nil), "jet: Projection is nil") +} + +func TestSelectColumnList(t *testing.T) { + columnList := ColumnList(table2ColInt, table2ColFloat, table3ColInt) + + assertStatement(t, SELECT(columnList).FROM(table2), ` +SELECT table2.col_int AS "table2.col_int", + table2.col_float AS "table2.col_float", + table3.col_int AS "table3.col_int" +FROM db.table2; +`) +} + +func TestSelectLiterals(t *testing.T) { + assertStatement(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), ` +SELECT $1, + $2, + $3 +FROM db.table1; +`, int64(1), 2.2, false) +} + +func TestSelectDistinct(t *testing.T) { + assertStatement(t, SELECT(table1ColBool).DISTINCT().FROM(table1), ` +SELECT DISTINCT table1.col_bool AS "table1.col_bool" +FROM db.table1; +`) +} + +func TestSelectFrom(t *testing.T) { + assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), ` +SELECT table1.col_int AS "table1.col_int", + table2.col_float AS "table2.col_float" +FROM db.table1; +`) + assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), ` +SELECT table1.col_int AS "table1.col_int", + table2.col_float AS "table2.col_float" +FROM db.table1 + INNER JOIN db.table2 ON (table1.col_int = table2.col_int); +`) + assertStatement(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), ` +SELECT table1.col_int AS "table1.col_int", + table2.col_float AS "table2.col_float" +FROM db.table1 + INNER JOIN db.table2 ON (table1.col_int = table2.col_int); +`) +} + +func TestSelectWhere(t *testing.T) { + assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), ` +SELECT table1.col_int AS "table1.col_int" +FROM db.table1 +WHERE $1; +`, true) + assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), ` +SELECT table1.col_int AS "table1.col_int" +FROM db.table1 +WHERE table1.col_int >= $1; +`, int64(10)) +} + +func TestSelectGroupBy(t *testing.T) { + assertStatement(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), ` +SELECT table2.col_int AS "table2.col_int" +FROM db.table2 +GROUP BY table2.col_float; +`) +} + +func TestSelectHaving(t *testing.T) { + assertStatement(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), ` +SELECT table3.col_int AS "table3.col_int" +FROM db.table3 +HAVING table1.col_bool = $1; +`, true) +} + +func TestSelectOrderBy(t *testing.T) { + assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), ` +SELECT table2.col_float AS "table2.col_float" +FROM db.table2 +ORDER BY table2.col_int DESC; +`) + assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC(), table2ColInt.ASC()), ` +SELECT table2.col_float AS "table2.col_float" +FROM db.table2 +ORDER BY table2.col_int DESC, table2.col_int ASC; +`) +} + +func TestSelectLimitOffset(t *testing.T) { + assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), ` +SELECT table2.col_int AS "table2.col_int" +FROM db.table2 +LIMIT $1; +`, int64(10)) + assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), ` +SELECT table2.col_int AS "table2.col_int" +FROM db.table2 +LIMIT $1 +OFFSET $2; +`, int64(10), int64(2)) +} + +func TestSelectLock(t *testing.T) { + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE()), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +FOR UPDATE; +`) + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(SHARE().NOWAIT()), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +FOR SHARE NOWAIT; +`) + + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(KEY_SHARE().NOWAIT()), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +FOR KEY SHARE NOWAIT; +`) + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(NO_KEY_UPDATE().SKIP_LOCKED()), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +FOR NO KEY UPDATE SKIP LOCKED; +`) +} diff --git a/postgres/set_statement.go b/postgres/set_statement.go new file mode 100644 index 0000000..277a2bb --- /dev/null +++ b/postgres/set_statement.go @@ -0,0 +1,147 @@ +package postgres + +import "github.com/go-jet/jet/internal/jet" + +// UNION effectively appends the result of sub-queries(select statements) into single query. +// It eliminates duplicate rows from its result. +func UNION(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) SetStatement { + return newSetStatementImpl(Union, false, toSelectList(lhs, rhs, selects...)) +} + +// UNION_ALL effectively appends the result of sub-queries(select statements) into single query. +// It does not eliminates duplicate rows from its result. +func UNION_ALL(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) SetStatement { + return newSetStatementImpl(Union, true, toSelectList(lhs, rhs, selects...)) +} + +// INTERSECT returns all rows that are in query results. +// It eliminates duplicate rows from its result. +func INTERSECT(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) SetStatement { + return newSetStatementImpl(Intersect, false, toSelectList(lhs, rhs, selects...)) +} + +// INTERSECT_ALL returns all rows that are in query results. +// It does not eliminates duplicate rows from its result. +func INTERSECT_ALL(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) SetStatement { + return newSetStatementImpl(Intersect, true, toSelectList(lhs, rhs, selects...)) +} + +// EXCEPT returns all rows that are in the result of query lhs but not in the result of query rhs. +// It eliminates duplicate rows from its result. +func EXCEPT(lhs, rhs jet.StatementWithProjections) SetStatement { + return newSetStatementImpl(Except, false, toSelectList(lhs, rhs)) +} + +// EXCEPT_ALL returns all rows that are in the result of query lhs but not in the result of query rhs. +// It does not eliminates duplicate rows from its result. +func EXCEPT_ALL(lhs, rhs jet.StatementWithProjections) SetStatement { + return newSetStatementImpl(Except, true, toSelectList(lhs, rhs)) +} + +type SetStatement interface { + SetOperators + + ORDER_BY(orderByClauses ...jet.OrderByClause) SetStatement + + LIMIT(limit int64) SetStatement + OFFSET(offset int64) SetStatement + + AsTable(alias string) SelectTable +} + +type SetOperators interface { + jet.Statement + jet.HasProjections + jet.IExpression + + UNION(rhs SelectStatement) SetStatement + UNION_ALL(rhs SelectStatement) SetStatement + INTERSECT(rhs SelectStatement) SetStatement + INTERSECT_ALL(rhs SelectStatement) SetStatement + EXCEPT(rhs SelectStatement) SetStatement + EXCEPT_ALL(rhs SelectStatement) SetStatement +} + +type setOperatorsImpl struct { + parent SetOperators +} + +func (s *setOperatorsImpl) UNION(rhs SelectStatement) SetStatement { + return UNION(s.parent, rhs) +} + +func (s *setOperatorsImpl) UNION_ALL(rhs SelectStatement) SetStatement { + return UNION_ALL(s.parent, rhs) +} + +func (s *setOperatorsImpl) INTERSECT(rhs SelectStatement) SetStatement { + return INTERSECT(s.parent, rhs) +} + +func (s *setOperatorsImpl) INTERSECT_ALL(rhs SelectStatement) SetStatement { + return INTERSECT_ALL(s.parent, rhs) +} + +func (s *setOperatorsImpl) EXCEPT(rhs SelectStatement) SetStatement { + return EXCEPT(s.parent, rhs) +} + +func (s *setOperatorsImpl) EXCEPT_ALL(rhs SelectStatement) SetStatement { + return EXCEPT_ALL(s.parent, rhs) +} + +type setStatementImpl struct { + jet.ExpressionStatementImpl + + setOperatorsImpl + + setOperator jet.ClauseSetStmtOperator +} + +func newSetStatementImpl(operator string, all bool, selects []jet.StatementWithProjections) SetStatement { + newSetStatement := &setStatementImpl{} + newSetStatement.ExpressionStatementImpl.StatementImpl = jet.NewStatementImpl(Dialect, jet.SetStatementType, newSetStatement, + &newSetStatement.setOperator) + newSetStatement.ExpressionStatementImpl.ExpressionInterfaceImpl.Parent = newSetStatement + + newSetStatement.setOperator.Operator = operator + newSetStatement.setOperator.All = all + newSetStatement.setOperator.Selects = selects + newSetStatement.setOperator.Limit.Count = -1 + newSetStatement.setOperator.Offset.Count = -1 + + newSetStatement.setOperatorsImpl.parent = newSetStatement + + newSetStatement.Clauses = []jet.Clause{&newSetStatement.setOperator} + + return newSetStatement +} + +func (s *setStatementImpl) ORDER_BY(orderByClauses ...jet.OrderByClause) SetStatement { + s.setOperator.OrderBy.List = orderByClauses + return s +} + +func (s *setStatementImpl) LIMIT(limit int64) SetStatement { + s.setOperator.Limit.Count = limit + return s +} + +func (s *setStatementImpl) OFFSET(offset int64) SetStatement { + s.setOperator.Offset.Count = offset + return s +} + +func (s *setStatementImpl) AsTable(alias string) SelectTable { + return newSelectTable(s, alias) +} + +const ( + Union = "UNION" + Intersect = "INTERSECT" + Except = "EXCEPT" +) + +func toSelectList(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) []jet.StatementWithProjections { + return append([]jet.StatementWithProjections{lhs, rhs}, selects...) +} diff --git a/postgres/set_statement_test.go b/postgres/set_statement_test.go new file mode 100644 index 0000000..0ca3bd8 --- /dev/null +++ b/postgres/set_statement_test.go @@ -0,0 +1,81 @@ +package postgres + +import ( + "testing" +) + +func TestSelectSets(t *testing.T) { + select1 := SELECT(table1ColBool).FROM(table1) + select2 := SELECT(table2ColBool).FROM(table2) + + assertStatement(t, select1.UNION(select2), ` +( + SELECT table1.col_bool AS "table1.col_bool" + FROM db.table1 +) +UNION +( + SELECT table2.col_bool AS "table2.col_bool" + FROM db.table2 +); +`) + assertStatement(t, select1.UNION_ALL(select2), ` +( + SELECT table1.col_bool AS "table1.col_bool" + FROM db.table1 +) +UNION ALL +( + SELECT table2.col_bool AS "table2.col_bool" + FROM db.table2 +); +`) + + assertStatement(t, select1.INTERSECT(select2), ` +( + SELECT table1.col_bool AS "table1.col_bool" + FROM db.table1 +) +INTERSECT +( + SELECT table2.col_bool AS "table2.col_bool" + FROM db.table2 +); +`) + + assertStatement(t, select1.INTERSECT_ALL(select2), ` +( + SELECT table1.col_bool AS "table1.col_bool" + FROM db.table1 +) +INTERSECT ALL +( + SELECT table2.col_bool AS "table2.col_bool" + FROM db.table2 +); +`) + assertStatement(t, select1.EXCEPT(select2), ` +( + SELECT table1.col_bool AS "table1.col_bool" + FROM db.table1 +) +EXCEPT +( + SELECT table2.col_bool AS "table2.col_bool" + FROM db.table2 +); +`) + + assertStatement(t, select1.EXCEPT_ALL(select2), ` +( + SELECT table1.col_bool AS "table1.col_bool" + FROM db.table1 +) +EXCEPT ALL +( + SELECT table2.col_bool AS "table2.col_bool" + FROM db.table2 +); +`) + +} diff --git a/postgres/statements.go b/postgres/statements.go deleted file mode 100644 index ac74830..0000000 --- a/postgres/statements.go +++ /dev/null @@ -1,42 +0,0 @@ -package postgres - -import "github.com/go-jet/jet/internal/jet" - -type SelectStatement jet.SelectStatement -type SelectTable jet.SelectTable -type SelectLock jet.SelectLock - -var ( - UPDATE = jet.NewSelectLock("UPDATE") - NO_KEY_UPDATE = jet.NewSelectLock("NO KEY UPDATE") - SHARE = jet.NewSelectLock("SHARE") - KEY_SHARE = jet.NewSelectLock("KEY SHARE") -) - -var SELECT = jet.SELECT - -func UNION(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { - return jet.UNION(lhs, rhs, toJetSelects(selects...)...) -} - -func UNION_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { - return jet.UNION_ALL(lhs, rhs, toJetSelects(selects...)...) -} - -func INTERSECT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { - return jet.INTERSECT(lhs, rhs, toJetSelects(selects...)...) -} - -func INTERSECT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { - return jet.INTERSECT_ALL(lhs, rhs, toJetSelects(selects...)...) -} - -func toJetSelects(selects ...SelectStatement) []jet.SelectStatement { - ret := []jet.SelectStatement{} - - for _, sel := range selects { - ret = append(ret, sel) - } - - return ret -} diff --git a/postgres/statements_test.go b/postgres/statements_test.go deleted file mode 100644 index d3cc49d..0000000 --- a/postgres/statements_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package postgres - -import ( - "github.com/go-jet/jet/internal/testutils" - "testing" -) - -func TestSelectLock(t *testing.T) { - testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE()), ` -SELECT table1.col_bool AS "table1.col_bool" -FROM db.table1 -FOR UPDATE; -`) - testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(SHARE().NOWAIT()), ` -SELECT table1.col_bool AS "table1.col_bool" -FROM db.table1 -FOR SHARE NOWAIT; -`) - - testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(KEY_SHARE().NOWAIT()), ` -SELECT table1.col_bool AS "table1.col_bool" -FROM db.table1 -FOR KEY SHARE NOWAIT; -`) - testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(NO_KEY_UPDATE().SKIP_LOCKED()), ` -SELECT table1.col_bool AS "table1.col_bool" -FROM db.table1 -FOR NO KEY UPDATE SKIP LOCKED; -`) -} diff --git a/postgres/table.go b/postgres/table.go index 3f5bab8..f7423dd 100644 --- a/postgres/table.go +++ b/postgres/table.go @@ -2,8 +2,145 @@ package postgres import "github.com/go-jet/jet/internal/jet" -type Table jet.Table +type readableTable interface { + // Generates a select query on the current tableName. + SELECT(projection jet.Projection, projections ...jet.Projection) SelectStatement + + // Creates a inner join tableName Expression using onCondition. + INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable + + // Creates a left join tableName Expression using onCondition. + LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable + + // Creates a right join tableName Expression using onCondition. + RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable + + // Creates a full join tableName Expression using onCondition. + FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable + + // Creates a cross join tableName Expression using onCondition. + CROSS_JOIN(table ReadableTable) ReadableTable +} + +type writableTable interface { + INSERT(columns ...jet.IColumn) InsertStatement + UPDATE(column jet.IColumn, columns ...jet.IColumn) UpdateStatement + DELETE() DeleteStatement + LOCK() LockStatement +} + +// ReadableTable interface +type ReadableTable interface { + //table + readableTable + jet.Serializer + //acceptsVisitor +} + +type WritableTable interface { + jet.TableInterface + writableTable + jet.Serializer +} + +type Table interface { + //table + readableTable + writableTable + jet.Serializer + //acceptsVisitor + + SchemaName() string + TableName() string + AS(alias string) +} + +type readableTableInterfaceImpl struct { + parent ReadableTable +} + +// Generates a select query on the current tableName. +func (r *readableTableInterfaceImpl) SELECT(projection1 jet.Projection, projections ...jet.Projection) SelectStatement { + return newSelectStatement(r.parent, append([]jet.Projection{projection1}, projections...)) +} + +// Creates a inner join tableName Expression using onCondition. +func (r *readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { + return newJoinTable(r.parent, table, jet.InnerJoin, onCondition) +} + +// Creates a left join tableName Expression using onCondition. +func (r *readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { + return newJoinTable(r.parent, table, jet.LeftJoin, onCondition) +} + +// Creates a right join tableName Expression using onCondition. +func (r *readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { + return newJoinTable(r.parent, table, jet.RightJoin, onCondition) +} + +func (r *readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { + return newJoinTable(r.parent, table, jet.FullJoin, onCondition) +} + +func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) ReadableTable { + return newJoinTable(r.parent, table, jet.CrossJoin, nil) +} + +type writableTableInterfaceImpl struct { + parent WritableTable +} + +func (w *writableTableInterfaceImpl) INSERT(columns ...jet.IColumn) InsertStatement { + return newInsertStatement(w.parent, jet.UnwidColumnList(columns)) +} + +func (w *writableTableInterfaceImpl) UPDATE(column jet.IColumn, columns ...jet.IColumn) UpdateStatement { + return newUpdateStatement(w.parent, jet.UnwindColumns(column, columns...)) +} + +func (w *writableTableInterfaceImpl) DELETE() DeleteStatement { + return newDeleteStatement(w.parent) +} + +func (w *writableTableInterfaceImpl) LOCK() LockStatement { + return LOCK(w.parent) +} + +type table2Impl struct { + readableTableInterfaceImpl + writableTableInterfaceImpl + + jet.TableImpl2 +} func NewTable(schemaName, name string, columns ...jet.Column) Table { - return jet.NewTable(Dialect, schemaName, name, columns...) + + t := &table2Impl{ + TableImpl2: jet.NewTable2(Dialect, schemaName, name, columns...), + } + + for _, c := range columns { + c.SetTableName(name) + } + + t.readableTableInterfaceImpl.parent = t + t.writableTableInterfaceImpl.parent = t + + return t +} + +type joinTable2 struct { + readableTableInterfaceImpl + jet.JoinTableImpl +} + +func newJoinTable(lhs jet.Serializer, rhs jet.Serializer, joinType jet.JoinType, onCondition BoolExpression) ReadableTable { + newJoinTable := &joinTable2{ + JoinTableImpl: jet.NewJoinTableImpl(lhs, rhs, joinType, onCondition), + } + + newJoinTable.readableTableInterfaceImpl.parent = newJoinTable + + return newJoinTable } diff --git a/postgres/update_statement.go b/postgres/update_statement.go new file mode 100644 index 0000000..2ec984c --- /dev/null +++ b/postgres/update_statement.go @@ -0,0 +1,55 @@ +package postgres + +import "github.com/go-jet/jet/internal/jet" + +// UpdateStatement is interface of SQL UPDATE statement +type UpdateStatement interface { + jet.Statement + + SET(value interface{}, values ...interface{}) UpdateStatement + MODEL(data interface{}) UpdateStatement + + WHERE(expression BoolExpression) UpdateStatement + RETURNING(projections ...jet.Projection) UpdateStatement +} + +type updateStatementImpl struct { + jet.StatementImpl + + Update jet.ClauseUpdate + Set jet.ClauseSet + Where jet.ClauseWhere + Returning jet.ClauseReturning +} + +func newUpdateStatement(table WritableTable, columns []jet.IColumn) UpdateStatement { + update := &updateStatementImpl{} + update.StatementImpl = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update, + &update.Set, &update.Where, &update.Returning) + + update.Update.Table = table + update.Set.Columns = columns + update.Where.Mandatory = true + + return update +} + +func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { + u.Set.Values = jet.UnwindRowFromValues(value, values) + return u +} + +func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement { + u.Set.Values = jet.UnwindRowFromModel(u.Set.Columns, data) + return u +} + +func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement { + u.Where.Condition = expression + return u +} + +func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateStatement { + u.Returning.Projections = projections + return u +} diff --git a/postgres/update_statement_test.go b/postgres/update_statement_test.go new file mode 100644 index 0000000..1986806 --- /dev/null +++ b/postgres/update_statement_test.go @@ -0,0 +1,62 @@ +package postgres + +import ( + "fmt" + "testing" +) + +func TestUpdateWithOneValue(t *testing.T) { + expectedSQL := ` +UPDATE db.table1 +SET col_int = $1 +WHERE table1.col_int >= $2; +` + stmt := table1.UPDATE(table1ColInt). + SET(1). + WHERE(table1ColInt.GT_EQ(Int(33))) + + fmt.Println(stmt.Sql()) + + assertStatement(t, stmt, expectedSQL, 1, int64(33)) +} + +func TestUpdateWithValues(t *testing.T) { + expectedSQL := ` +UPDATE db.table1 +SET (col_int, col_float) = ($1, $2) +WHERE table1.col_int >= $3; +` + stmt := table1.UPDATE(table1ColInt, table1ColFloat). + SET(1, 22.2). + WHERE(table1ColInt.GT_EQ(Int(33))) + + fmt.Println(stmt.Sql()) + + assertStatement(t, stmt, expectedSQL, 1, 22.2, int64(33)) +} + +func TestUpdateOneColumnWithSelect(t *testing.T) { + expectedSQL := ` +UPDATE db.table1 +SET col_float = ( + SELECT table1.col_float AS "table1.col_float" + FROM db.table1 +) +WHERE table1.col1 = $1 +RETURNING table1.col1 AS "table1.col1"; +` + stmt := table1. + UPDATE(table1ColFloat). + SET( + table1.SELECT(table1ColFloat), + ). + WHERE(table1Col1.EQ(Int(2))). + RETURNING(table1Col1) + + assertStatement(t, stmt, expectedSQL, int64(2)) +} + +func TestInvalidInputs(t *testing.T) { + assertStatementErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set") + assertStatementErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list") +} diff --git a/postgres/utils_test.go b/postgres/utils_test.go index d9b6a33..78d66e0 100644 --- a/postgres/utils_test.go +++ b/postgres/utils_test.go @@ -70,7 +70,7 @@ var table3 = NewTable( table3ColInt, table3StrCol) -func assertClauseSerialize(t *testing.T, clause jet.Clause, query string, args ...interface{}) { +func assertClauseSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { out := jet.SqlBuilder{Dialect: Dialect} err := jet.Serialize(clause, jet.SelectStatementType, &out) @@ -80,7 +80,7 @@ func assertClauseSerialize(t *testing.T, clause jet.Clause, query string, args . assert.DeepEqual(t, out.Args, args) } -func assertClauseSerializeErr(t *testing.T, clause jet.Clause, errString string) { +func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) { out := jet.SqlBuilder{Dialect: Dialect} err := jet.Serialize(clause, jet.SelectStatementType, &out) diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 7573c35..5c9111f 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -678,8 +678,7 @@ func TestSubQueryColumnReference(t *testing.T) { ). AsTable("subQuery") - unionexpectedSQL := ` - ( + unionexpectedSQL := ` ( ( SELECT all_types.boolean AS "all_types.boolean", all_types.integer AS "all_types.integer", @@ -775,8 +774,6 @@ FROM` ). FROM(subQuery) - //fmt.Println(stmt2.DebugSql()) - testutils.AssertDebugStatementSql(t, stmt2, expectedSQL+expected.sql+";\n", expected.args...) dest2 := []model.AllTypes{} diff --git a/tests/postgres/lock_test.go b/tests/postgres/lock_test.go index c62a169..e5ace94 100644 --- a/tests/postgres/lock_test.go +++ b/tests/postgres/lock_test.go @@ -15,7 +15,7 @@ func TestLockTable(t *testing.T) { expectedSQL := ` LOCK TABLE dvds.address IN` - var testData = []string{ + var testData = []TableLockMode{ LOCK_ACCESS_SHARE, LOCK_ROW_SHARE, LOCK_ROW_EXCLUSIVE, diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index cfd522d..833a7b5 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1,6 +1,8 @@ package postgres import ( + "fmt" + "github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/enum" @@ -157,7 +159,12 @@ LIMIT 12; ). LIMIT(12) + fmt.Println(query.DebugSql()) + testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) } func TestJoinQueryStruct(t *testing.T) { @@ -1240,6 +1247,8 @@ OFFSET 20; LIMIT(10). OFFSET(20) + fmt.Println(query.DebugSql()) + testutils.AssertDebugStatementSql(t, query, expectedQuery, float64(100), float64(200), int64(10), int64(20)) dest := []model.Payment{} @@ -1267,7 +1276,7 @@ func TestAllSetOperators(t *testing.T) { select1 := Payment.SELECT(Payment.AllColumns).WHERE(Payment.PaymentID.GT_EQ(Int(17600)).AND(Payment.PaymentID.LT(Int(17610)))) select2 := Payment.SELECT(Payment.AllColumns).WHERE(Payment.PaymentID.GT_EQ(Int(17620)).AND(Payment.PaymentID.LT(Int(17630)))) - type setOperator func(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement + type setOperator func(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) SetStatement operators := []setOperator{ UNION, UNION_ALL,