From 14e18634566e8726cc3b8759125487dea28d4dac Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 12 Apr 2020 18:53:57 +0200 Subject: [PATCH] [postgres] Add support for ON CONFLICT clause --- README.md | 2 +- .../internal/metadata/table_meta_data.go | 9 +- generator/internal/template/templates.go | 27 ++- internal/jet/cast.go | 4 +- internal/jet/clause.go | 120 ++++++++----- internal/jet/column.go | 19 ++- internal/jet/expression.go | 18 +- internal/jet/func_expression.go | 2 +- internal/jet/interval.go | 2 +- internal/jet/keyword.go | 14 +- internal/jet/literal_expression.go | 8 + internal/jet/operators.go | 8 +- internal/jet/serializer.go | 41 ++++- internal/jet/sql_builder.go | 6 +- internal/jet/sql_builder_test.go | 6 + internal/jet/statement.go | 10 +- internal/jet/table.go | 10 +- internal/jet/utils.go | 21 ++- internal/jet/window_expression.go | 10 +- internal/jet/window_func.go | 14 +- internal/testutils/test_utils.go | 38 +++-- mysql/insert_statement_test.go | 38 ++--- mysql/table.go | 4 +- mysql/table_test.go | 4 +- mysql/utils_test.go | 8 +- postgres/clause.go | 87 ++++++++++ postgres/clause_test.go | 34 ++++ postgres/clauses.go | 20 --- postgres/conflict_action.go | 36 ++++ postgres/insert_statement.go | 17 +- postgres/insert_statement_test.go | 97 ++++++++--- postgres/table.go | 4 +- postgres/table_test.go | 4 +- postgres/update_statement.go | 2 +- postgres/utils_test.go | 13 +- tests/mysql/generator_test.go | 58 +++++-- tests/mysql/insert_test.go | 38 ++--- tests/postgres/generator_test.go | 81 ++++++--- tests/postgres/insert_test.go | 157 +++++++++++++++--- tests/postgres/main_test.go | 3 + tests/postgres/sample_test.go | 7 +- tests/postgres/util_test.go | 3 +- 42 files changed, 827 insertions(+), 277 deletions(-) create mode 100644 postgres/clause.go create mode 100644 postgres/clause_test.go delete mode 100644 postgres/clauses.go create mode 100644 postgres/conflict_action.go diff --git a/README.md b/README.md index 721219a..932fff5 100644 --- a/README.md +++ b/README.md @@ -568,5 +568,5 @@ To run the tests, additional dependencies are required: ## License -Copyright 2019 Goran Bjelanovic +Copyright 2019-2020 Goran Bjelanovic Licensed under the Apache License, Version 2.0. diff --git a/generator/internal/metadata/table_meta_data.go b/generator/internal/metadata/table_meta_data.go index cb738fa..bab1088 100644 --- a/generator/internal/metadata/table_meta_data.go +++ b/generator/internal/metadata/table_meta_data.go @@ -3,6 +3,7 @@ package metadata import ( "database/sql" "github.com/go-jet/jet/internal/utils" + "strings" ) // TableMetaData metadata struct @@ -67,15 +68,19 @@ func (t TableMetaData) GoStructName() string { return utils.ToGoIdentifier(t.name) + "Table" } +// GoStructImplName returns go struct impl name for sql builder +func (t TableMetaData) GoStructImplName() string { + name := utils.ToGoIdentifier(t.name) + "Table" + return string(strings.ToLower(name)[0]) + name[1:] +} + // GetTableMetaData returns table info metadata func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData) { - tableInfo.SchemaName = schemaName tableInfo.name = tableName tableInfo.PrimaryKeys = getPrimaryKeys(db, querySet, schemaName, tableName) tableInfo.Columns = getColumnsMetaData(db, querySet, schemaName, tableName) - return } diff --git a/generator/internal/template/templates.go b/generator/internal/template/templates.go index 40e4773..d8e302a 100644 --- a/generator/internal/template/templates.go +++ b/generator/internal/template/templates.go @@ -26,7 +26,7 @@ import ( var {{ToGoIdentifier .Name}} = new{{.GoStructName}}() -type {{.GoStructName}} struct { +type {{.GoStructImplName}} struct { {{dialect.PackageName}}.Table //Columns @@ -38,32 +38,45 @@ type {{.GoStructName}} struct { MutableColumns {{dialect.PackageName}}.ColumnList } +type {{.GoStructName}} struct { + {{.GoStructImplName}} + + EXCLUDED {{.GoStructImplName}} +} + // creates new {{.GoStructName}} with assigned alias func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} { aliasTable := new{{.GoStructName}}() - aliasTable.Table.AS(alias) - return aliasTable } func new{{.GoStructName}}() *{{.GoStructName}} { + return &{{.GoStructName}}{ + {{.GoStructImplName}}: new{{.GoStructName}}Impl("{{.SchemaName}}", "{{.Name}}"), + EXCLUDED: new{{.GoStructName}}Impl("", "excluded"), + } +} + +func new{{.GoStructName}}Impl(schemaName, tableName string) {{.GoStructImplName}} { var ( {{- range .Columns}} {{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}") {{- end}} + allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} } + mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} } ) - return &{{.GoStructName}}{ - Table: {{dialect.PackageName}}.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}), + return {{.GoStructImplName}}{ + Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, allColumns...), //Columns {{- range .Columns}} {{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column, {{- end}} - AllColumns: {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }, - MutableColumns: {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } diff --git a/internal/jet/cast.go b/internal/jet/cast.go index c5fe9a7..e8045e7 100644 --- a/internal/jet/cast.go +++ b/internal/jet/cast.go @@ -42,12 +42,12 @@ func (b *castExpression) serialize(statement StatementType, out *SQLBuilder, opt castType := b.cast if castOverride := out.Dialect.OperatorSerializeOverride("CAST"); castOverride != nil { - castOverride(expression, String(castType))(statement, out, options...) + castOverride(expression, String(castType))(statement, out, FallTrough(options)...) return } out.WriteString("CAST(") - expression.serialize(statement, out, options...) + expression.serialize(statement, out, FallTrough(options)...) out.WriteString("AS") out.WriteString(castType + ")") } diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 738074b..90d194e 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -6,7 +6,7 @@ import ( // Clause interface type Clause interface { - Serialize(statementType StatementType, out *SQLBuilder) + Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) } // ClauseWithProjections interface @@ -27,7 +27,7 @@ func (s *ClauseSelect) projections() ProjectionList { } // Serialize serializes clause into SQLBuilder -func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder) { +func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { out.NewLine() out.WriteString("SELECT") @@ -48,7 +48,7 @@ type ClauseFrom struct { } // Serialize serializes clause into SQLBuilder -func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder) { +func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if f.Table == nil { return } @@ -56,7 +56,7 @@ func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder) { out.WriteString("FROM") out.IncreaseIdent() - f.Table.serialize(statementType, out) + f.Table.serialize(statementType, out, FallTrough(options)...) out.DecreaseIdent() } @@ -67,18 +67,20 @@ type ClauseWhere struct { } // Serialize serializes clause into SQLBuilder -func (c *ClauseWhere) Serialize(statementType StatementType, out *SQLBuilder) { +func (c *ClauseWhere) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if c.Condition == nil { if c.Mandatory { panic("jet: WHERE clause not set") } return } - out.NewLine() + if !contains(options, SkipNewLine) { + out.NewLine() + } out.WriteString("WHERE") out.IncreaseIdent() - c.Condition.serialize(statementType, out, noWrap) + c.Condition.serialize(statementType, out, NoWrap.WithFallTrough(options)...) out.DecreaseIdent() } @@ -88,7 +90,7 @@ type ClauseGroupBy struct { } // Serialize serializes clause into SQLBuilder -func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SQLBuilder) { +func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if len(c.List) == 0 { return } @@ -119,7 +121,7 @@ type ClauseHaving struct { } // Serialize serializes clause into SQLBuilder -func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder) { +func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if c.Condition == nil { return } @@ -128,7 +130,7 @@ func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder) { out.WriteString("HAVING") out.IncreaseIdent() - c.Condition.serialize(statementType, out, noWrap) + c.Condition.serialize(statementType, out, NoWrap.WithFallTrough(options)...) out.DecreaseIdent() } @@ -139,7 +141,7 @@ type ClauseOrderBy struct { } // Serialize serializes clause into SQLBuilder -func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SQLBuilder) { +func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if o.List == nil { return } @@ -168,7 +170,7 @@ type ClauseLimit struct { } // Serialize serializes clause into SQLBuilder -func (l *ClauseLimit) Serialize(statementType StatementType, out *SQLBuilder) { +func (l *ClauseLimit) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if l.Count >= 0 { out.NewLine() out.WriteString("LIMIT") @@ -182,7 +184,7 @@ type ClauseOffset struct { } // Serialize serializes clause into SQLBuilder -func (o *ClauseOffset) Serialize(statementType StatementType, out *SQLBuilder) { +func (o *ClauseOffset) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if o.Count >= 0 { out.NewLine() out.WriteString("OFFSET") @@ -196,14 +198,14 @@ type ClauseFor struct { } // Serialize serializes clause into SQLBuilder -func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder) { +func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if f.Lock == nil { return } out.NewLine() out.WriteString("FOR") - f.Lock.serialize(statementType, out) + f.Lock.serialize(statementType, out, FallTrough(options)...) } // ClauseSetStmtOperator struct @@ -224,7 +226,7 @@ func (s *ClauseSetStmtOperator) projections() ProjectionList { } // Serialize serializes clause into SQLBuilder -func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLBuilder) { +func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if len(s.Selects) < 2 { panic("jet: UNION Statement must contain at least two SELECT statements") } @@ -244,7 +246,7 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLB panic("jet: select statement of '" + s.Operator + "' is nil") } - selectStmt.serialize(statementType, out) + selectStmt.serialize(statementType, out, FallTrough(options)...) } s.OrderBy.Serialize(statementType, out) @@ -258,7 +260,7 @@ type ClauseUpdate struct { } // Serialize serializes clause into SQLBuilder -func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder) { +func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { out.NewLine() out.WriteString("UPDATE") @@ -266,7 +268,7 @@ func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder) { panic("jet: table to update is nil") } - u.Table.serialize(statementType, out) + u.Table.serialize(statementType, out, FallTrough(options)...) } // ClauseSet struct @@ -276,7 +278,7 @@ type ClauseSet struct { } // Serialize serializes clause into SQLBuilder -func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder) { +func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { out.NewLine() out.WriteString("SET") @@ -299,7 +301,7 @@ func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder) { out.WriteString(" = ") - s.Values[i].serialize(UpdateStatementType, out) + s.Values[i].serialize(UpdateStatementType, out, FallTrough(options)...) } out.DecreaseIdent(4) } @@ -320,7 +322,7 @@ func (i *ClauseInsert) GetColumns() []Column { } // Serialize serializes clause into SQLBuilder -func (i *ClauseInsert) Serialize(statementType StatementType, out *SQLBuilder) { +func (i *ClauseInsert) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { out.NewLine() out.WriteString("INSERT INTO") @@ -346,7 +348,7 @@ type ClauseValuesQuery struct { } // Serialize serializes clause into SQLBuilder -func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuilder) { +func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if len(v.Rows) == 0 && v.Query == nil { panic("jet: VALUES or QUERY has to be specified for INSERT statement") } @@ -355,8 +357,8 @@ func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuild panic("jet: VALUES or QUERY has to be specified for INSERT statement") } - v.ClauseValues.Serialize(statementType, out) - v.ClauseQuery.Serialize(statementType, out) + v.ClauseValues.Serialize(statementType, out, FallTrough(options)...) + v.ClauseQuery.Serialize(statementType, out, FallTrough(options)...) } // ClauseValues struct @@ -365,27 +367,29 @@ type ClauseValues struct { } // Serialize serializes clause into SQLBuilder -func (v *ClauseValues) Serialize(statementType StatementType, out *SQLBuilder) { +func (v *ClauseValues) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if len(v.Rows) == 0 { return } + out.NewLine() out.WriteString("VALUES") for rowIndex, row := range v.Rows { if rowIndex > 0 { out.WriteString(",") + out.NewLine() + } else { + out.IncreaseIdent(7) } - out.IncreaseIdent() - out.NewLine() out.WriteString("(") SerializeClauseList(statementType, row, out) out.WriteByte(')') - out.DecreaseIdent() } + out.DecreaseIdent(7) } // ClauseQuery struct @@ -394,12 +398,12 @@ type ClauseQuery struct { } // Serialize serializes clause into SQLBuilder -func (v *ClauseQuery) Serialize(statementType StatementType, out *SQLBuilder) { +func (v *ClauseQuery) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if v.Query == nil { return } - v.Query.serialize(statementType, out) + v.Query.serialize(statementType, out, FallTrough(options)...) } // ClauseDelete struct @@ -408,7 +412,7 @@ type ClauseDelete struct { } // Serialize serializes clause into SQLBuilder -func (d *ClauseDelete) Serialize(statementType StatementType, out *SQLBuilder) { +func (d *ClauseDelete) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { out.NewLine() out.WriteString("DELETE FROM") @@ -416,7 +420,7 @@ func (d *ClauseDelete) Serialize(statementType StatementType, out *SQLBuilder) { panic("jet: nil table in DELETE clause") } - d.Table.serialize(statementType, out) + d.Table.serialize(statementType, out, FallTrough(options)...) } // ClauseStatementBegin struct @@ -426,7 +430,7 @@ type ClauseStatementBegin struct { } // Serialize serializes clause into SQLBuilder -func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SQLBuilder) { +func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { out.NewLine() out.WriteString(d.Name) @@ -435,7 +439,7 @@ func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SQLBu out.WriteString(", ") } - table.serialize(statementType, out) + table.serialize(statementType, out, FallTrough(options)...) } } @@ -447,7 +451,7 @@ type ClauseOptional struct { } // Serialize serializes clause into SQLBuilder -func (d *ClauseOptional) Serialize(statementType StatementType, out *SQLBuilder) { +func (d *ClauseOptional) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if !d.Show { return } @@ -463,7 +467,7 @@ type ClauseIn struct { } // Serialize serializes clause into SQLBuilder -func (i *ClauseIn) Serialize(statementType StatementType, out *SQLBuilder) { +func (i *ClauseIn) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if i.LockMode == "" { return } @@ -485,7 +489,7 @@ type ClauseWindow struct { } // Serialize serializes clause into SQLBuilder -func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder) { +func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if len(i.Definitions) == 0 { return } @@ -503,6 +507,44 @@ func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder) { out.WriteString("()") continue } - def.Window.serialize(statementType, out) + def.Window.serialize(statementType, out, FallTrough(options)...) } } + +// SetPair clause +type SetPair struct { + Column ColumnSerializer + Value Serializer +} + +// SetClause clause +type SetClause []SetPair + +// Serialize for SetClause +func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { + out.NewLine() + out.WriteString("SET") + out.IncreaseIdent(4) + + for i, pair := range s { + if i > 0 { + out.WriteString(",") + out.NewLine() + } + + pair.Column.serialize(statementType, out, ShortName.WithFallTrough(options)...) + out.WriteString("=") + pair.Value.serialize(statementType, out, FallTrough(options)...) + } + out.DecreaseIdent(4) +} + +// KeywordClause type +type KeywordClause struct { + Keyword +} + +// Serialize for KeywordClause +func (k KeywordClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { + k.serialize(statementType, out, FallTrough(options)...) +} diff --git a/internal/jet/column.go b/internal/jet/column.go index 85c053e..0fd59be 100644 --- a/internal/jet/column.go +++ b/internal/jet/column.go @@ -12,6 +12,12 @@ type Column interface { defaultAlias() string } +// ColumnSerializer is interface for all serializable columns +type ColumnSerializer interface { + Serializer + Column +} + // ColumnExpression interface type ColumnExpression interface { Column @@ -101,7 +107,7 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder out.WriteByte('.') out.WriteIdentifier(c.defaultAlias(), true) } else { - if c.tableName != "" { + if c.tableName != "" && !contains(options, ShortName) { out.WriteIdentifier(c.tableName) out.WriteByte('.') } @@ -125,6 +131,17 @@ func (cl ColumnList) fromImpl(subQuery SelectTable) Projection { return newProjectionList } +func (cl ColumnList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteString("(") + for i, column := range cl { + if i > 0 { + out.WriteString(", ") + } + column.serialize(statement, out, FallTrough(options)...) + } + out.WriteString(")") +} + func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) { projections := ColumnListToProjectionList(cl) diff --git a/internal/jet/expression.go b/internal/jet/expression.go index a463b76..d807534 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -72,15 +72,15 @@ func (e *ExpressionInterfaceImpl) DESC() OrderByClause { } func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SQLBuilder) { - e.Parent.serialize(statement, out, noWrap) + e.Parent.serialize(statement, out, NoWrap) } func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType, out *SQLBuilder) { - e.Parent.serialize(statement, out, noWrap) + e.Parent.serialize(statement, out, NoWrap) } func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { - e.Parent.serialize(statement, out, noWrap) + e.Parent.serialize(statement, out, NoWrap) } // Representation of binary operations (e.g. comparisons, arithmetic) @@ -117,7 +117,7 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu panic("jet: rhs is nil for '" + c.operator + "' operator") } - wrap := !contains(options, noWrap) + wrap := !contains(options, NoWrap) if wrap { out.WriteString("(") @@ -125,11 +125,11 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil { serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam) - serializeOverrideFunc(statement, out, options...) + serializeOverrideFunc(statement, out, FallTrough(options)...) } else { - c.lhs.serialize(statement, out) + c.lhs.serialize(statement, out, FallTrough(options)...) out.WriteString(c.operator) - c.rhs.serialize(statement, out) + c.rhs.serialize(statement, out, FallTrough(options)...) } if wrap { @@ -163,7 +163,7 @@ func (p *prefixExpression) serialize(statement StatementType, out *SQLBuilder, o panic("jet: nil prefix expression in prefix operator " + p.operator) } - p.expression.serialize(statement, out) + p.expression.serialize(statement, out, FallTrough(options)...) out.WriteString(")") } @@ -192,7 +192,7 @@ func (p *postfixOpExpression) serialize(statement StatementType, out *SQLBuilder panic("jet: nil prefix expression in postfix operator " + p.operator) } - p.expression.serialize(statement, out) + p.expression.serialize(statement, out, FallTrough(options)...) out.WriteString(p.operator) } diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index f38c9a2..e95bece 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -613,7 +613,7 @@ func newWindowFunc(name string, expressions ...Expression) windowExpression { func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil { serializeOverrideFunc := serializeOverride(ExpressionListToSerializerList(f.expressions)...) - serializeOverrideFunc(statement, out, options...) + serializeOverrideFunc(statement, out, FallTrough(options)...) return } diff --git a/internal/jet/interval.go b/internal/jet/interval.go index fab84e0..5b371e1 100644 --- a/internal/jet/interval.go +++ b/internal/jet/interval.go @@ -33,5 +33,5 @@ type IntervalImpl struct { func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString("INTERVAL") - i.interval.serialize(statement, out, options...) + i.interval.serialize(statement, out, FallTrough(options)...) } diff --git a/internal/jet/keyword.go b/internal/jet/keyword.go index c903ad6..a40538e 100644 --- a/internal/jet/keyword.go +++ b/internal/jet/keyword.go @@ -2,18 +2,12 @@ package jet const ( // DEFAULT is jet equivalent of SQL DEFAULT - DEFAULT keywordClause = "DEFAULT" + DEFAULT Keyword = "DEFAULT" ) -var ( - // NULL is jet equivalent of SQL NULL - NULL = newNullLiteral() - // STAR is jet equivalent of SQL * - STAR = newStarLiteral() -) +// Keyword type +type Keyword string -type keywordClause string - -func (k keywordClause) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (k Keyword) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString(string(k)) } diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 15c8a4a..8cdb3d7 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -278,6 +278,14 @@ func formatNanoseconds(nanoseconds ...time.Duration) string { } //--------------------------------------------------// + +var ( + // NULL is jet equivalent of SQL NULL + NULL = newNullLiteral() + // STAR is jet equivalent of SQL * + STAR = newStarLiteral() +) + type nullLiteral struct { ExpressionInterfaceImpl } diff --git a/internal/jet/operators.go b/internal/jet/operators.go index d17081c..19173a6 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -147,7 +147,7 @@ func (c *caseOperatorImpl) serialize(statement StatementType, out *SQLBuilder, o out.WriteString("(CASE") if c.expression != nil { - c.expression.serialize(statement, out) + c.expression.serialize(statement, out, FallTrough(options)...) } if len(c.when) == 0 || len(c.then) == 0 { @@ -160,15 +160,15 @@ func (c *caseOperatorImpl) serialize(statement StatementType, out *SQLBuilder, o for i, when := range c.when { out.WriteString("WHEN") - when.serialize(statement, out, noWrap) + when.serialize(statement, out, NoWrap) out.WriteString("THEN") - c.then[i].serialize(statement, out, noWrap) + c.then[i].serialize(statement, out, NoWrap) } if c.els != nil { out.WriteString("ELSE") - c.els.serialize(statement, out, noWrap) + c.els.serialize(statement, out, NoWrap) } out.WriteString("END)") diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index dc661d7..2f014cc 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -5,9 +5,18 @@ type SerializeOption int // Serialize options const ( - noWrap SerializeOption = iota + NoWrap SerializeOption = iota + SkipNewLine + + fallTroughOptions // fall trough options + ShortName ) +// WithFallTrough extends existing serialize options with additional +func (s SerializeOption) WithFallTrough(options []SerializeOption) []SerializeOption { + return append(FallTrough(options), s) +} + // StatementType is type of the SQL statement type StatementType string @@ -42,6 +51,19 @@ func contains(options []SerializeOption, option SerializeOption) bool { return false } +// FallTrough filters fall-trough options from the list +func FallTrough(options []SerializeOption) []SerializeOption { + var ret []SerializeOption + + for _, option := range options { + if option > fallTroughOptions { + ret = append(ret, option) + } + } + + return ret +} + // ListSerializer serializes list of serializers with separator type ListSerializer struct { Serializers []Serializer @@ -53,6 +75,21 @@ func (s ListSerializer) serialize(statement StatementType, out *SQLBuilder, opti if i > 0 { out.WriteString(s.Separator) } - ser.serialize(statement, out) + ser.serialize(statement, out, FallTrough(options)...) + } +} + +// NewSerializerClauseImpl is constructor for Seralizer with list of clauses +func NewSerializerClauseImpl(clauses ...Clause) Serializer { + return &serializerImpl{Clauses: clauses} +} + +type serializerImpl struct { + Clauses []Clause +} + +func (s serializerImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + for _, clause := range s.Clauses { + clause.Serialize(statement, out, FallTrough(options)...) } } diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index ef7f801..642f0cd 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -98,7 +98,7 @@ func (s *SQLBuilder) WriteString(str string) { // WriteIdentifier adds identifier to output SQL func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) { - if s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 { + if s.shouldQuote(name, alwaysQuote...) { identQuoteChar := string(s.Dialect.IdentifierQuoteChar()) s.WriteString(identQuoteChar + name + identQuoteChar) } else { @@ -106,6 +106,10 @@ func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) { } } +func (s *SQLBuilder) shouldQuote(name string, alwaysQuote ...bool) bool { + return s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 +} + // WriteByte writes byte to output SQL func (s *SQLBuilder) WriteByte(b byte) { s.write([]byte{b}) diff --git a/internal/jet/sql_builder_test.go b/internal/jet/sql_builder_test.go index f7b5ade..911df27 100644 --- a/internal/jet/sql_builder_test.go +++ b/internal/jet/sql_builder_test.go @@ -41,3 +41,9 @@ func TestArgToString(t *testing.T) { argToString(map[string]bool{}) }() } + +func TestFallTrough(t *testing.T) { + assert.Equal(t, FallTrough([]SerializeOption{ShortName}), []SerializeOption{ShortName}) + assert.Equal(t, FallTrough([]SerializeOption{SkipNewLine}), []SerializeOption(nil)) + assert.Equal(t, FallTrough([]SerializeOption{ShortName, SkipNewLine}), []SerializeOption{ShortName}) +} diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 3b0638d..beb52f1 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -58,7 +58,7 @@ func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface queryData := &SQLBuilder{Dialect: s.dialect} - s.parent.serialize(s.statementType, queryData, noWrap) + s.parent.serialize(s.statementType, queryData, NoWrap) query, args = queryData.finalize() return @@ -67,7 +67,7 @@ func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { sqlBuilder := &SQLBuilder{Dialect: s.dialect, Debug: true} - s.parent.serialize(s.statementType, sqlBuilder, noWrap) + s.parent.serialize(s.statementType, sqlBuilder, NoWrap) query, _ = sqlBuilder.finalize() return @@ -157,16 +157,16 @@ func (s *statementImpl) projections() ProjectionList { func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if !contains(options, noWrap) { + if !contains(options, NoWrap) { out.WriteString("(") out.IncreaseIdent() } for _, clause := range s.Clauses { - clause.Serialize(statement, out) + clause.Serialize(statement, out, FallTrough(options)...) } - if !contains(options, noWrap) { + if !contains(options, NoWrap) { out.DecreaseIdent() out.NewLine() out.WriteString(")") diff --git a/internal/jet/table.go b/internal/jet/table.go index 4379000..ca76e3a 100644 --- a/internal/jet/table.go +++ b/internal/jet/table.go @@ -19,17 +19,15 @@ type Table interface { } // NewTable creates new table with schema Name, table Name and list of columns -func NewTable(schemaName, name string, column ColumnExpression, columns ...ColumnExpression) SerializerTable { - - columnList := append([]ColumnExpression{column}, columns...) +func NewTable(schemaName, name string, columns ...ColumnExpression) SerializerTable { t := tableImpl{ schemaName: schemaName, name: name, - columnList: columnList, + columnList: columns, } - for _, c := range columnList { + for _, c := range columns { c.setTableName(name) } @@ -156,7 +154,7 @@ func (t *joinTableImpl) serialize(statement StatementType, out *SQLBuilder, opti panic("jet: left hand side of join operation is nil table") } - t.lhs.serialize(statement, out) + t.lhs.serialize(statement, out, FallTrough(options)...) out.NewLine() diff --git a/internal/jet/utils.go b/internal/jet/utils.go index 126ce32..50f4e13 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -63,6 +63,22 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) { } } +// SerializeColumnExpressionNames func +func SerializeColumnExpressionNames(columns []ColumnExpression, statementType StatementType, + out *SQLBuilder, options ...SerializeOption) { + for i, col := range columns { + if i > 0 { + out.WriteString(", ") + } + + if col == nil { + panic("jet: nil column in columns list") + } + + col.serialize(statementType, out, options...) + } +} + // ExpressionListToSerializerList converts list of expressions to list of serializers func ExpressionListToSerializerList(expressions []Expression) []Serializer { var ret []Serializer @@ -85,7 +101,8 @@ func ColumnListToProjectionList(columns []ColumnExpression) []Projection { return ret } -func valueToClause(value interface{}) Serializer { +// ToSerializerValue creates Serializer type from the value +func ToSerializerValue(value interface{}) Serializer { if clause, ok := value.(Serializer); ok { return clause } @@ -148,7 +165,7 @@ func UnwindRowFromValues(value interface{}, values []interface{}) []Serializer { allValues := append([]interface{}{value}, values...) for _, val := range allValues { - row = append(row, valueToClause(val)) + row = append(row, ToSerializerValue(val)) } return row diff --git a/internal/jet/window_expression.go b/internal/jet/window_expression.go index 3e7f1c7..e5f18c1 100644 --- a/internal/jet/window_expression.go +++ b/internal/jet/window_expression.go @@ -17,7 +17,7 @@ func (w *commonWindowImpl) serialize(statement StatementType, out *SQLBuilder, o w.expression.serialize(statement, out) if w.window != nil { out.WriteString("OVER") - w.window.serialize(statement, out) + w.window.serialize(statement, out, FallTrough(options)...) } } @@ -49,7 +49,7 @@ func (f *windowExpressionImpl) OVER(window ...Window) Expression { } func (f *windowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - f.commonWindowImpl.serialize(statement, out) + f.commonWindowImpl.serialize(statement, out, FallTrough(options)...) } // ----------------------------------------------------- @@ -80,7 +80,7 @@ func (f *floatWindowExpressionImpl) OVER(window ...Window) FloatExpression { } func (f *floatWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - f.commonWindowImpl.serialize(statement, out) + f.commonWindowImpl.serialize(statement, out, FallTrough(options)...) } // ------------------------------------------------ @@ -111,7 +111,7 @@ func (f *integerWindowExpressionImpl) OVER(window ...Window) IntegerExpression { } func (f *integerWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - f.commonWindowImpl.serialize(statement, out) + f.commonWindowImpl.serialize(statement, out, FallTrough(options)...) } // ------------------------------------------------ @@ -142,5 +142,5 @@ func (f *boolWindowExpressionImpl) OVER(window ...Window) BoolExpression { } func (f *boolWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - f.commonWindowImpl.serialize(statement, out) + f.commonWindowImpl.serialize(statement, out, FallTrough(options)...) } diff --git a/internal/jet/window_func.go b/internal/jet/window_func.go index 7f4d1b7..6602495 100644 --- a/internal/jet/window_func.go +++ b/internal/jet/window_func.go @@ -30,7 +30,7 @@ func newWindowImpl(parent Window) *windowImpl { } func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if !contains(options, noWrap) { + if !contains(options, NoWrap) { out.WriteByte('(') } @@ -40,7 +40,7 @@ func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options serializeExpressionList(statement, w.partitionBy, ", ", out) } w.orderBy.SkipNewLine = true - w.orderBy.Serialize(statement, out) + w.orderBy.Serialize(statement, out, FallTrough(options)...) if w.frameUnits != "" { out.WriteString(w.frameUnits) @@ -55,7 +55,7 @@ func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options } } - if !contains(options, noWrap) { + if !contains(options, NoWrap) { out.WriteByte(')') } } @@ -139,7 +139,7 @@ func (f *frameExtentImpl) serialize(statement StatementType, out *SQLBuilder, op if f == nil { return } - f.offset.serialize(statement, out) + f.offset.serialize(statement, out, FallTrough(options)...) if f.preceding { out.WriteString("PRECEDING") @@ -152,12 +152,12 @@ func (f *frameExtentImpl) serialize(statement StatementType, out *SQLBuilder, op // Window function keywords var ( - UNBOUNDED = keywordClause("UNBOUNDED") + UNBOUNDED = Keyword("UNBOUNDED") CURRENT_ROW = frameExtentKeyword{"CURRENT ROW"} ) type frameExtentKeyword struct { - keywordClause + Keyword } func (f frameExtentKeyword) isFrameExtent() {} @@ -180,7 +180,7 @@ func (w windowName) serialize(statement StatementType, out *SQLBuilder, options out.WriteByte('(') out.WriteString(w.name) - w.windowImpl.serialize(statement, out, noWrap) + w.windowImpl.serialize(statement, out, NoWrap.WithFallTrough(options)...) out.WriteByte(')') } diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index f0c18e9..942dab7 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -8,6 +8,7 @@ import ( "github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/qrm" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "io/ioutil" "os" "path/filepath" @@ -110,17 +111,17 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st _, args := query.Sql() if len(expectedArgs) > 0 { - AssertDeepEqual(t, args, expectedArgs) + AssertDeepEqual(t, args, expectedArgs, "arguments are not equal") } debuqSql := query.DebugSql() assert.Equal(t, debuqSql, expectedQuery) } -// AssertClauseSerialize checks if clause serialize produces expected query and args -func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { +// AssertSerialize checks if clause serialize produces expected query and args +func AssertSerialize(t *testing.T, dialect jet.Dialect, serializer jet.Serializer, query string, args ...interface{}) { out := jet.SQLBuilder{Dialect: dialect} - jet.Serialize(clause, jet.SelectStatementType, &out) + jet.Serialize(serializer, jet.SelectStatementType, &out) //fmt.Println(out.Buff.String()) @@ -131,8 +132,20 @@ func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Seriali } } -// AssertDebugClauseSerialize checks if clause serialize produces expected debug query and args -func AssertDebugClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { +// AssertClauseSerialize checks if clause serialize produces expected query and args +func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Clause, query string, args ...interface{}) { + out := jet.SQLBuilder{Dialect: dialect} + clause.Serialize(jet.SelectStatementType, &out) + + require.Equal(t, out.Buff.String(), query) + + if len(args) > 0 { + AssertDeepEqual(t, out.Args, args) + } +} + +// AssertDebugSerialize checks if clause serialize produces expected debug query and args +func AssertDebugSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { out := jet.SQLBuilder{Dialect: dialect, Debug: true} jet.Serialize(clause, jet.SelectStatementType, &out) @@ -153,8 +166,8 @@ func AssertPanicErr(t *testing.T, fun func(), errorStr string) { fun() } -// AssertClauseSerializeErr check if clause serialize panics with errString -func AssertClauseSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) { +// AssertSerializeErr check if clause serialize panics with errString +func AssertSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) { defer func() { r := recover() assert.Equal(t, r, errString) @@ -191,9 +204,8 @@ func AssertFileContent(t *testing.T, filePath string, contentBegin string, expec beginIndex := bytes.Index(enumFileData, []byte(contentBegin)) - //fmt.Println("-"+string(enumFileData[beginIndex:])+"-") - - AssertDeepEqual(t, string(enumFileData[beginIndex:]), expectedContent) + //AssertDeepEqual(t, string(enumFileData[beginIndex:]), expectedContent) + require.Equal(t, string(enumFileData[beginIndex:]), expectedContent) } // AssertFileNamesEqual check if all filesInfos are contained in fileNames @@ -212,6 +224,6 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st } // AssertDeepEqual checks if actual and expected objects are deeply equal. -func AssertDeepEqual(t *testing.T, actual, expected interface{}) { - assert.True(t, cmp.Equal(actual, expected)) +func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) { + assert.True(t, cmp.Equal(actual, expected), msg) } diff --git a/mysql/insert_statement_test.go b/mysql/insert_statement_test.go index 6faf313..65c8fba 100644 --- a/mysql/insert_statement_test.go +++ b/mysql/insert_statement_test.go @@ -13,15 +13,15 @@ func TestInvalidInsert(t *testing.T) { func TestInsertNilValue(t *testing.T) { assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), ` -INSERT INTO db.table1 (col1) VALUES - (?); +INSERT INTO db.table1 (col1) +VALUES (?); `, nil) } func TestInsertSingleValue(t *testing.T) { assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), ` -INSERT INTO db.table1 (col1) VALUES - (?); +INSERT INTO db.table1 (col1) +VALUES (?); `, int(1)) } @@ -31,8 +31,8 @@ func TestInsertWithColumnList(t *testing.T) { columnList = append(columnList, table3StrCol) assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), ` -INSERT INTO db.table3 (col_int, col2) VALUES - (?, ?); +INSERT INTO db.table3 (col_int, col2) +VALUES (?, ?); `, 1, 3) } @@ -40,15 +40,15 @@ func TestInsertDate(t *testing.T) { date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) assertStatementSql(t, table1.INSERT(table1ColTimestamp).VALUES(date), ` -INSERT INTO db.table1 (col_timestamp) VALUES - (?); +INSERT INTO db.table1 (col_timestamp) +VALUES (?); `, date) } func TestInsertMultipleValues(t *testing.T) { assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` -INSERT INTO db.table1 (col1, col_float, col3) VALUES - (?, ?, ?); +INSERT INTO db.table1 (col1, col_float, col3) +VALUES (?, ?, ?); `, 1, 2, 3) } @@ -59,10 +59,10 @@ func TestInsertMultipleRows(t *testing.T) { VALUES(111, 222) assertStatementSql(t, stmt, ` -INSERT INTO db.table1 (col1, col_float) VALUES - (?, ?), - (?, ?), - (?, ?); +INSERT INTO db.table1 (col1, col_float) +VALUES (?, ?), + (?, ?), + (?, ?); `, 1, 2, 11, 22, 111, 222) } @@ -84,9 +84,9 @@ func TestInsertValuesFromModel(t *testing.T) { MODEL(&toInsert) expectedSQL := ` -INSERT INTO db.table1 (col1, col_float) VALUES - (?, ?), - (?, ?); +INSERT INTO db.table1 (col1, col_float) +VALUES (?, ?), + (?, ?); ` assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) @@ -127,8 +127,8 @@ func TestInsertDefaultValue(t *testing.T) { VALUES(DEFAULT, "two") var expectedSQL = ` -INSERT INTO db.table1 (col1, col_float) VALUES - (DEFAULT, ?); +INSERT INTO db.table1 (col1, col_float) +VALUES (DEFAULT, ?); ` assertStatementSql(t, stmt, expectedSQL, "two") diff --git a/mysql/table.go b/mysql/table.go index 6d414a2..a4cf042 100644 --- a/mysql/table.go +++ b/mysql/table.go @@ -77,9 +77,9 @@ func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectU } // NewTable creates new table with schema Name, table Name and list of columns -func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table { +func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table { t := &tableImpl{ - SerializerTable: jet.NewTable(schemaName, name, column, columns...), + SerializerTable: jet.NewTable(schemaName, name, columns...), } t.readableTableInterfaceImpl.parent = t diff --git a/mysql/table_test.go b/mysql/table_test.go index 3894378..3bc79f9 100644 --- a/mysql/table_test.go +++ b/mysql/table_test.go @@ -5,9 +5,9 @@ import ( ) func TestJoinNilInputs(t *testing.T) { - assertClauseSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)), + assertSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)), "jet: right hand side of join operation is nil table") - assertClauseSerializeErr(t, table2.INNER_JOIN(table1, nil), + assertSerializeErr(t, table2.INNER_JOIN(table1, nil), "jet: join condition is nil") } diff --git a/mysql/utils_test.go b/mysql/utils_test.go index 1cc42f1..709097d 100644 --- a/mysql/utils_test.go +++ b/mysql/utils_test.go @@ -59,15 +59,15 @@ var table3 = NewTable( table3StrCol) func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { - testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) + testutils.AssertSerialize(t, Dialect, clause, query, args...) } func assertDebugSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { - testutils.AssertDebugClauseSerialize(t, Dialect, clause, query, args...) + testutils.AssertDebugSerialize(t, Dialect, clause, query, args...) } -func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) { - testutils.AssertClauseSerializeErr(t, Dialect, clause, errString) +func assertSerializeErr(t *testing.T, clause jet.Serializer, errString string) { + testutils.AssertSerializeErr(t, Dialect, clause, errString) } func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) { diff --git a/postgres/clause.go b/postgres/clause.go new file mode 100644 index 0000000..ad61b9a --- /dev/null +++ b/postgres/clause.go @@ -0,0 +1,87 @@ +package postgres + +import ( + "github.com/go-jet/jet/internal/jet" +) + +type clauseReturning struct { + Projections []jet.Projection +} + +func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(r.Projections) == 0 { + return + } + + out.NewLine() + out.WriteString("RETURNING") + out.IncreaseIdent() + out.WriteProjections(statementType, r.Projections) +} + +// ========================================== // + +type onConflict interface { + ON_CONSTRAINT(name string) conflictTarget + WHERE(indexPredicate BoolExpression) conflictTarget + DO_NOTHING() InsertStatement + DO_UPDATE(action conflictAction) InsertStatement +} + +type conflictTarget interface { + DO_NOTHING() InsertStatement + DO_UPDATE(action conflictAction) InsertStatement +} + +type onConflictClause struct { + insertStatement InsertStatement + constraint string + indexExpressions []jet.ColumnExpression + whereClause jet.ClauseWhere + do jet.Serializer +} + +func (o *onConflictClause) ON_CONSTRAINT(name string) conflictTarget { + o.constraint = name + return o +} + +func (o *onConflictClause) WHERE(indexPredicate BoolExpression) conflictTarget { + o.whereClause.Condition = indexPredicate + return o +} + +func (o *onConflictClause) DO_NOTHING() InsertStatement { + o.do = jet.Keyword("DO NOTHING") + return o.insertStatement +} + +func (o *onConflictClause) DO_UPDATE(action conflictAction) InsertStatement { + o.do = action + return o.insertStatement +} + +func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(o.indexExpressions) == 0 && o.constraint == "" { + return + } + + out.NewLine() + out.WriteString("ON CONFLICT") + if len(o.indexExpressions) > 0 { + out.WriteString("(") + jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName) + out.WriteString(")") + } + + if o.constraint != "" { + out.WriteString("ON CONSTRAINT") + out.WriteString(o.constraint) + } + + o.whereClause.Serialize(statementType, out, jet.SkipNewLine, jet.ShortName) + + out.IncreaseIdent(7) + jet.Serialize(o.do, statementType, out) + out.DecreaseIdent(7) +} diff --git a/postgres/clause_test.go b/postgres/clause_test.go new file mode 100644 index 0000000..7f64c61 --- /dev/null +++ b/postgres/clause_test.go @@ -0,0 +1,34 @@ +package postgres + +import "testing" + +func TestOnConflict(t *testing.T) { + + assertClauseSerialize(t, &onConflictClause{}, "") + + onConflict := &onConflictClause{} + onConflict.DO_NOTHING() + assertClauseSerialize(t, onConflict, "") + + onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool}} + onConflict.DO_NOTHING() + assertClauseSerialize(t, onConflict, ` +ON CONFLICT (col_bool) DO NOTHING`) + + onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool}} + onConflict.ON_CONSTRAINT("table_pkey").DO_NOTHING() + assertClauseSerialize(t, onConflict, ` +ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`) + + onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool, table2ColFloat}} + onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)).DO_UPDATE( + SET(table1ColBool, Bool(true)). + SET(table1ColInt, Int(1)). + WHERE(table2ColFloat.GT(Float(11.1))), + ) + assertClauseSerialize(t, onConflict, ` +ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE + SET col_bool = $1, + col_int = $2 + WHERE table2.col_float > $3`) +} diff --git a/postgres/clauses.go b/postgres/clauses.go deleted file mode 100644 index a882b64..0000000 --- a/postgres/clauses.go +++ /dev/null @@ -1,20 +0,0 @@ -package postgres - -import ( - "github.com/go-jet/jet/internal/jet" -) - -type clauseReturning struct { - Projections []jet.Projection -} - -func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder) { - if len(r.Projections) == 0 { - return - } - - out.NewLine() - out.WriteString("RETURNING") - out.IncreaseIdent() - out.WriteProjections(statementType, r.Projections) -} diff --git a/postgres/conflict_action.go b/postgres/conflict_action.go new file mode 100644 index 0000000..5ff7a19 --- /dev/null +++ b/postgres/conflict_action.go @@ -0,0 +1,36 @@ +package postgres + +import "github.com/go-jet/jet/internal/jet" + +type conflictAction interface { + jet.Serializer + SET(column jet.ColumnSerializer, expression interface{}) conflictAction + WHERE(condition BoolExpression) conflictAction +} + +// SET creates conflict action for ON_CONFLICT clause +func SET(column jet.ColumnSerializer, expression interface{}) conflictAction { + conflictAction := updateConflictActionImpl{} + conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"} + conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where) + conflictAction.SET(column, expression) + return &conflictAction +} + +type updateConflictActionImpl struct { + jet.Serializer + + doUpdate jet.KeywordClause + set jet.SetClause + where jet.ClauseWhere +} + +func (u *updateConflictActionImpl) SET(column jet.ColumnSerializer, expression interface{}) conflictAction { + u.set = append(u.set, jet.SetPair{Column: column, Value: jet.ToSerializerValue(expression)}) + return u +} + +func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction { + u.where.Condition = condition + return u +} diff --git a/postgres/insert_statement.go b/postgres/insert_statement.go index a9dee25..e4c72f5 100644 --- a/postgres/insert_statement.go +++ b/postgres/insert_statement.go @@ -11,18 +11,18 @@ type InsertStatement interface { // Insert row of values, where value for each column is extracted from filed of structure data. // If data is not struct or there is no field for every column selected, this method will panic. MODEL(data interface{}) InsertStatement - MODELS(data interface{}) InsertStatement - QUERY(selectStatement SelectStatement) InsertStatement - RETURNING(projections ...jet.Projection) InsertStatement + ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict + + RETURNING(projections ...Projection) InsertStatement } func newInsertStatement(table WritableTable, columns []jet.Column) InsertStatement { newInsert := &insertStatementImpl{} newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, - &newInsert.Insert, &newInsert.ValuesQuery, &newInsert.Returning) + &newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnConflict, &newInsert.Returning) newInsert.Insert.Table = table newInsert.Insert.Columns = columns @@ -36,6 +36,7 @@ type insertStatementImpl struct { Insert jet.ClauseInsert ValuesQuery jet.ClauseValuesQuery Returning clauseReturning + OnConflict onConflictClause } func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { @@ -62,3 +63,11 @@ func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertState i.ValuesQuery.Query = selectStatement return i } + +func (i *insertStatementImpl) ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict { + i.OnConflict = onConflictClause{ + insertStatement: i, + indexExpressions: indexExpressions, + } + return &i.OnConflict +} diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index ccb1404..d80c1a1 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -1,6 +1,7 @@ package postgres import ( + "github.com/go-jet/jet/internal/jet" "github.com/stretchr/testify/assert" "testing" "time" @@ -13,15 +14,15 @@ func TestInvalidInsert(t *testing.T) { func TestInsertNilValue(t *testing.T) { assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), ` -INSERT INTO db.table1 (col1) VALUES - ($1); +INSERT INTO db.table1 (col1) +VALUES ($1); `, nil) } func TestInsertSingleValue(t *testing.T) { assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), ` -INSERT INTO db.table1 (col1) VALUES - ($1); +INSERT INTO db.table1 (col1) +VALUES ($1); `, int(1)) } @@ -29,8 +30,8 @@ func TestInsertWithColumnList(t *testing.T) { columnList := ColumnList{table3ColInt, table3StrCol} assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), ` -INSERT INTO db.table3 (col_int, col2) VALUES - ($1, $2); +INSERT INTO db.table3 (col_int, col2) +VALUES ($1, $2); `, 1, 3) } @@ -38,15 +39,15 @@ func TestInsertDate(t *testing.T) { date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) assertStatementSql(t, table1.INSERT(table1ColTime).VALUES(date), ` -INSERT INTO db.table1 (col_time) VALUES - ($1); +INSERT INTO db.table1 (col_time) +VALUES ($1); `, date) } func TestInsertMultipleValues(t *testing.T) { - assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` -INSERT INTO db.table1 (col1, col_float, col3) VALUES - ($1, $2, $3); + assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1ColBool).VALUES(1, 2, 3), ` +INSERT INTO db.table1 (col1, col_float, col_bool) +VALUES ($1, $2, $3); `, 1, 2, 3) } @@ -57,10 +58,10 @@ func TestInsertMultipleRows(t *testing.T) { VALUES(111, 222) assertStatementSql(t, stmt, ` -INSERT INTO db.table1 (col1, col_float) VALUES - ($1, $2), - ($3, $4), - ($5, $6); +INSERT INTO db.table1 (col1, col_float) +VALUES ($1, $2), + ($3, $4), + ($5, $6); `, 1, 2, 11, 22, 111, 222) } @@ -82,12 +83,12 @@ func TestInsertValuesFromModel(t *testing.T) { MODEL(&toInsert) expectedSQL := ` -INSERT INTO db.table1 (col1, col_float) VALUES - ($1, $2), - ($3, $4); +INSERT INTO db.table1 (col1, col_float) +VALUES ($1, $2), + ($3, $4); ` - assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) + assertStatementSql(t, stmt, expectedSQL, 1, float64(1.11), 1, float64(1.11)) } func TestInsertValuesFromModelColumnMismatch(t *testing.T) { @@ -139,9 +140,63 @@ func TestInsertDefaultValue(t *testing.T) { VALUES(DEFAULT, "two") var expectedSQL = ` -INSERT INTO db.table1 (col1, col_float) VALUES - (DEFAULT, $1); +INSERT INTO db.table1 (col1, col_float) +VALUES (DEFAULT, $1); ` assertStatementSql(t, stmt, expectedSQL, "two") } + +func TestInsert_ON_CONFLICT(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1ColBool). + VALUES("one", "two"). + VALUES("1", "2"). + VALUES("theta", "beta"). + ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE( + SET(table1ColBool, "12"). + SET(table2ColInt, 1). + SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))). + WHERE(table1Col1.GT(Int(2))), + ). + RETURNING(table1Col1, table1ColBool) + + assertDebugStatementSql(t, stmt, ` +INSERT INTO db.table1 (col1, col_bool) +VALUES ('one', 'two'), + ('1', '2'), + ('theta', 'beta') +ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE + SET col_bool = '12', + col_int = 1, + (col1, col_bool) = ROW(2, 'two') + WHERE table1.col1 > 2 +RETURNING table1.col1 AS "table1.col1", + table1.col_bool AS "table1.col_bool"; +`) +} + +func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1ColBool). + VALUES("one", "two"). + VALUES("1", "2"). + ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE( + SET(table1ColBool, "12"). + SET(table2ColInt, 1). + SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))). + WHERE(table1Col1.GT(Int(2))), + ). + RETURNING(table1Col1, table1ColBool) + + assertDebugStatementSql(t, stmt, ` +INSERT INTO db.table1 (col1, col_bool) +VALUES ('one', 'two'), + ('1', '2') +ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE + SET col_bool = '12', + col_int = 1, + (col1, col_bool) = ROW(2, 'two') + WHERE table1.col1 > 2 +RETURNING table1.col1 AS "table1.col1", + table1.col_bool AS "table1.col_bool"; +`) +} diff --git a/postgres/table.go b/postgres/table.go index dc0b266..bc2f5c2 100644 --- a/postgres/table.go +++ b/postgres/table.go @@ -109,10 +109,10 @@ type tableImpl struct { } // NewTable creates new table with schema Name, table Name and list of columns -func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table { +func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table { t := &tableImpl{ - SerializerTable: jet.NewTable(schemaName, name, column, columns...), + SerializerTable: jet.NewTable(schemaName, name, columns...), } t.readableTableInterfaceImpl.parent = t diff --git a/postgres/table_test.go b/postgres/table_test.go index 43aa096..3b6f498 100644 --- a/postgres/table_test.go +++ b/postgres/table_test.go @@ -5,9 +5,9 @@ import ( ) func TestJoinNilInputs(t *testing.T) { - assertClauseSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)), + assertSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)), "jet: right hand side of join operation is nil table") - assertClauseSerializeErr(t, table2.INNER_JOIN(table1, nil), + assertSerializeErr(t, table2.INNER_JOIN(table1, nil), "jet: join condition is nil") } diff --git a/postgres/update_statement.go b/postgres/update_statement.go index 6d54eb5..d96e1e9 100644 --- a/postgres/update_statement.go +++ b/postgres/update_statement.go @@ -61,7 +61,7 @@ type clauseSet struct { Values []jet.Serializer } -func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder) { +func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { out.NewLine() out.WriteString("SET") diff --git a/postgres/utils_test.go b/postgres/utils_test.go index 38c429c..cbca2bc 100644 --- a/postgres/utils_test.go +++ b/postgres/utils_test.go @@ -10,7 +10,6 @@ import ( var table1Col1 = IntegerColumn("col1") var table1ColInt = IntegerColumn("col_int") var table1ColFloat = FloatColumn("col_float") -var table1Col3 = IntegerColumn("col3") var table1ColTime = TimeColumn("col_time") var table1ColTimez = TimezColumn("col_timez") var table1ColTimestamp = TimestampColumn("col_timestamp") @@ -25,7 +24,6 @@ var table1 = NewTable( table1Col1, table1ColInt, table1ColFloat, - table1Col3, table1ColTime, table1ColTimez, table1ColBool, @@ -75,12 +73,16 @@ var table3 = NewTable( table3ColInt, table3StrCol) -func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { +func assertSerialize(t *testing.T, serializer jet.Serializer, query string, args ...interface{}) { + testutils.AssertSerialize(t, Dialect, serializer, query, args...) +} + +func assertClauseSerialize(t *testing.T, clause jet.Clause, query string, args ...interface{}) { testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) } -func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) { - testutils.AssertClauseSerializeErr(t, Dialect, clause, errString) +func assertSerializeErr(t *testing.T, serializer jet.Serializer, errString string) { + testutils.AssertSerializeErr(t, Dialect, serializer, errString) } func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) { @@ -88,5 +90,6 @@ func assertProjectionSerialize(t *testing.T, projection jet.Projection, query st } var assertStatementSql = testutils.AssertStatementSql +var assertDebugStatementSql = testutils.AssertDebugStatementSql var assertStatementSqlErr = testutils.AssertStatementSqlErr var assertPanicErr = testutils.AssertPanicErr diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index 8ae9347..59df9a1 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -76,7 +76,7 @@ func assertGeneratedFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") - testutils.AssertFileContent(t, genTestDir3+"/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilerFile) + testutils.AssertFileContent(t, genTestDir3+"/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilderFile) // Enums SQL Builder files enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum") @@ -128,7 +128,7 @@ import ( var Actor = newActorTable() -type ActorTable struct { +type actorTable struct { mysql.Table //Columns @@ -141,25 +141,38 @@ type ActorTable struct { MutableColumns mysql.ColumnList } +type ActorTable struct { + actorTable + + EXCLUDED actorTable +} + // creates new ActorTable with assigned alias func (a *ActorTable) AS(alias string) *ActorTable { aliasTable := newActorTable() - aliasTable.Table.AS(alias) - return aliasTable } func newActorTable() *ActorTable { + return &ActorTable{ + actorTable: newActorTableImpl("dvds", "actor"), + EXCLUDED: newActorTableImpl("", "excluded"), + } +} + +func newActorTableImpl(schemaName, tableName string) actorTable { var ( ActorIDColumn = mysql.IntegerColumn("actor_id") FirstNameColumn = mysql.StringColumn("first_name") LastNameColumn = mysql.StringColumn("last_name") LastUpdateColumn = mysql.TimestampColumn("last_update") + allColumns = mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn} + mutableColumns = mysql.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn} ) - return &ActorTable{ - Table: mysql.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn), + return actorTable{ + Table: mysql.NewTable(schemaName, tableName, allColumns...), //Columns ActorID: ActorIDColumn, @@ -167,8 +180,8 @@ func newActorTable() *ActorTable { LastName: LastNameColumn, LastUpdate: LastUpdateColumn, - AllColumns: mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn}, - MutableColumns: mysql.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } ` @@ -188,7 +201,7 @@ type Actor struct { } ` -var actorInfoSQLBuilerFile = ` +var actorInfoSQLBuilderFile = ` package view import ( @@ -197,7 +210,7 @@ import ( var ActorInfo = newActorInfoTable() -type ActorInfoTable struct { +type actorInfoTable struct { mysql.Table //Columns @@ -210,25 +223,38 @@ type ActorInfoTable struct { MutableColumns mysql.ColumnList } +type ActorInfoTable struct { + actorInfoTable + + EXCLUDED actorInfoTable +} + // creates new ActorInfoTable with assigned alias func (a *ActorInfoTable) AS(alias string) *ActorInfoTable { aliasTable := newActorInfoTable() - aliasTable.Table.AS(alias) - return aliasTable } func newActorInfoTable() *ActorInfoTable { + return &ActorInfoTable{ + actorInfoTable: newActorInfoTableImpl("dvds", "actor_info"), + EXCLUDED: newActorInfoTableImpl("", "excluded"), + } +} + +func newActorInfoTableImpl(schemaName, tableName string) actorInfoTable { var ( ActorIDColumn = mysql.IntegerColumn("actor_id") FirstNameColumn = mysql.StringColumn("first_name") LastNameColumn = mysql.StringColumn("last_name") FilmInfoColumn = mysql.StringColumn("film_info") + allColumns = mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn} + mutableColumns = mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn} ) - return &ActorInfoTable{ - Table: mysql.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + return actorInfoTable{ + Table: mysql.NewTable(schemaName, tableName, allColumns...), //Columns ActorID: ActorIDColumn, @@ -236,8 +262,8 @@ func newActorInfoTable() *ActorInfoTable { LastName: LastNameColumn, FilmInfo: FilmInfoColumn, - AllColumns: mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}, - MutableColumns: mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } ` diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index 4c86012..cb27f61 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -15,10 +15,10 @@ func TestInsertValues(t *testing.T) { cleanUpLinkTable(t) var expectedSQL = ` -INSERT INTO test_sample.link (id, url, name, description) VALUES - (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - (101, 'http://www.google.com', 'Google', DEFAULT), - (102, 'http://www.yahoo.com', 'Yahoo', NULL); +INSERT INTO test_sample.link (id, url, name, description) +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + (101, 'http://www.google.com', 'Google', DEFAULT), + (102, 'http://www.yahoo.com', 'Yahoo', NULL); ` insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). @@ -69,8 +69,8 @@ func TestInsertEmptyColumnList(t *testing.T) { cleanUpLinkTable(t) expectedSQL := ` -INSERT INTO test_sample.link VALUES - (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); +INSERT INTO test_sample.link +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); ` stmt := Link.INSERT(). @@ -97,8 +97,8 @@ INSERT INTO test_sample.link VALUES func TestInsertModelObject(t *testing.T) { cleanUpLinkTable(t) var expectedSQL = ` -INSERT INTO test_sample.link (url, name) VALUES - ('http://www.duckduckgo.com', 'Duck Duck go'); +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); ` linkData := model.Link{ @@ -119,8 +119,8 @@ INSERT INTO test_sample.link (url, name) VALUES func TestInsertModelObjectEmptyColumnList(t *testing.T) { cleanUpLinkTable(t) var expectedSQL = ` -INSERT INTO test_sample.link VALUES - (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); +INSERT INTO test_sample.link +VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); ` linkData := model.Link{ @@ -141,10 +141,10 @@ INSERT INTO test_sample.link VALUES func TestInsertModelsObject(t *testing.T) { expectedSQL := ` -INSERT INTO test_sample.link (url, name) VALUES - ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), - ('http://www.google.com', 'Google'), - ('http://www.yahoo.com', 'Yahoo'); +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), + ('http://www.google.com', 'Google'), + ('http://www.yahoo.com', 'Yahoo'); ` tutorial := model.Link{ @@ -177,11 +177,11 @@ INSERT INTO test_sample.link (url, name) VALUES func TestInsertUsingMutableColumns(t *testing.T) { var expectedSQL = ` -INSERT INTO test_sample.link (url, name, description) VALUES - ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - ('http://www.google.com', 'Google', NULL), - ('http://www.google.com', 'Google', NULL), - ('http://www.yahoo.com', 'Yahoo', NULL); +INSERT INTO test_sample.link (url, name, description) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + ('http://www.google.com', 'Google', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.yahoo.com', 'Yahoo', NULL); ` google := model.Link{ diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 54d7cc4..71f4f8f 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -161,7 +161,7 @@ import ( var Actor = newActorTable() -type ActorTable struct { +type actorTable struct { postgres.Table //Columns @@ -174,25 +174,38 @@ type ActorTable struct { MutableColumns postgres.ColumnList } +type ActorTable struct { + actorTable + + EXCLUDED actorTable +} + // creates new ActorTable with assigned alias func (a *ActorTable) AS(alias string) *ActorTable { aliasTable := newActorTable() - aliasTable.Table.AS(alias) - return aliasTable } func newActorTable() *ActorTable { + return &ActorTable{ + actorTable: newActorTableImpl("dvds", "actor"), + EXCLUDED: newActorTableImpl("", "excluded"), + } +} + +func newActorTableImpl(schemaName, tableName string) actorTable { var ( ActorIDColumn = postgres.IntegerColumn("actor_id") FirstNameColumn = postgres.StringColumn("first_name") LastNameColumn = postgres.StringColumn("last_name") LastUpdateColumn = postgres.TimestampColumn("last_update") + allColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn} + mutableColumns = postgres.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn} ) - return &ActorTable{ - Table: postgres.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn), + return actorTable{ + Table: postgres.NewTable(schemaName, tableName, allColumns...), //Columns ActorID: ActorIDColumn, @@ -200,8 +213,8 @@ func newActorTable() *ActorTable { LastName: LastNameColumn, LastUpdate: LastUpdateColumn, - AllColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn}, - MutableColumns: postgres.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } ` @@ -230,7 +243,7 @@ import ( var ActorInfo = newActorInfoTable() -type ActorInfoTable struct { +type actorInfoTable struct { postgres.Table //Columns @@ -243,25 +256,38 @@ type ActorInfoTable struct { MutableColumns postgres.ColumnList } +type ActorInfoTable struct { + actorInfoTable + + EXCLUDED actorInfoTable +} + // creates new ActorInfoTable with assigned alias func (a *ActorInfoTable) AS(alias string) *ActorInfoTable { aliasTable := newActorInfoTable() - aliasTable.Table.AS(alias) - return aliasTable } func newActorInfoTable() *ActorInfoTable { + return &ActorInfoTable{ + actorInfoTable: newActorInfoTableImpl("dvds", "actor_info"), + EXCLUDED: newActorInfoTableImpl("", "excluded"), + } +} + +func newActorInfoTableImpl(schemaName, tableName string) actorInfoTable { var ( ActorIDColumn = postgres.IntegerColumn("actor_id") FirstNameColumn = postgres.StringColumn("first_name") LastNameColumn = postgres.StringColumn("last_name") FilmInfoColumn = postgres.StringColumn("film_info") + allColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn} + mutableColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn} ) - return &ActorInfoTable{ - Table: postgres.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + return actorInfoTable{ + Table: postgres.NewTable(schemaName, tableName, allColumns...), //Columns ActorID: ActorIDColumn, @@ -269,8 +295,8 @@ func newActorInfoTable() *ActorInfoTable { LastName: LastNameColumn, FilmInfo: FilmInfoColumn, - AllColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}, - MutableColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } ` @@ -422,7 +448,7 @@ import ( var AllTypes = newAllTypesTable() -type AllTypesTable struct { +type allTypesTable struct { postgres.Table //Columns @@ -492,16 +518,27 @@ type AllTypesTable struct { MutableColumns postgres.ColumnList } +type AllTypesTable struct { + allTypesTable + + EXCLUDED allTypesTable +} + // creates new AllTypesTable with assigned alias func (a *AllTypesTable) AS(alias string) *AllTypesTable { aliasTable := newAllTypesTable() - aliasTable.Table.AS(alias) - return aliasTable } func newAllTypesTable() *AllTypesTable { + return &AllTypesTable{ + allTypesTable: newAllTypesTableImpl("test_sample", "all_types"), + EXCLUDED: newAllTypesTableImpl("", "excluded"), + } +} + +func newAllTypesTableImpl(schemaName, tableName string) allTypesTable { var ( SmallIntPtrColumn = postgres.IntegerColumn("small_int_ptr") SmallIntColumn = postgres.IntegerColumn("small_int") @@ -564,10 +601,12 @@ func newAllTypesTable() *AllTypesTable { JsonbArrayColumn = postgres.StringColumn("jsonb_array") TextMultiDimArrayPtrColumn = postgres.StringColumn("text_multi_dim_array_ptr") TextMultiDimArrayColumn = postgres.StringColumn("text_multi_dim_array") + allColumns = postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn} + mutableColumns = postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn} ) - return &AllTypesTable{ - Table: postgres.NewTable("test_sample", "all_types", SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn), + return allTypesTable{ + Table: postgres.NewTable(schemaName, tableName, allColumns...), //Columns SmallIntPtr: SmallIntPtrColumn, @@ -632,8 +671,8 @@ func newAllTypesTable() *AllTypesTable { TextMultiDimArrayPtr: TextMultiDimArrayPtrColumn, TextMultiDimArray: TextMultiDimArrayColumn, - AllColumns: postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn}, - MutableColumns: postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } ` diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index 38c135e..ecbd3c5 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -2,11 +2,13 @@ package postgres import ( "context" + "github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" "github.com/stretchr/testify/assert" + "math/rand" "testing" "time" ) @@ -15,10 +17,10 @@ func TestInsertValues(t *testing.T) { cleanUpLinkTable(t) var expectedSQL = ` -INSERT INTO test_sample.link (id, url, name, description) VALUES - (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - (101, 'http://www.google.com', 'Google', DEFAULT), - (102, 'http://www.yahoo.com', 'Yahoo', NULL) +INSERT INTO test_sample.link (id, url, name, description) +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + (101, 'http://www.google.com', 'Google', DEFAULT), + (102, 'http://www.yahoo.com', 'Yahoo', NULL) RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", @@ -77,8 +79,8 @@ func TestInsertEmptyColumnList(t *testing.T) { cleanUpLinkTable(t) expectedSQL := ` -INSERT INTO test_sample.link VALUES - (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); +INSERT INTO test_sample.link +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); ` stmt := Link.INSERT(). @@ -90,11 +92,128 @@ INSERT INTO test_sample.link VALUES AssertExec(t, stmt, 1) } +func TestInsertOnConflict(t *testing.T) { + + t.Run("do nothing", func(t *testing.T) { + employee := model.Employee{EmployeeID: rand.Int31()} + + stmt := Employee.INSERT(Employee.AllColumns). + MODEL(employee). + MODEL(employee). + ON_CONFLICT(Employee.EmployeeID).DO_NOTHING() + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) +VALUES ($1, $2, $3, $4, $5), + ($6, $7, $8, $9, $10) +ON CONFLICT (employee_id) DO NOTHING; +`) + AssertExec(t, stmt, 1) + }) + + t.Run("on constraint do nothing", func(t *testing.T) { + employee := model.Employee{EmployeeID: rand.Int31()} + + stmt := Employee.INSERT(Employee.AllColumns). + MODEL(employee). + MODEL(employee). + ON_CONFLICT().ON_CONSTRAINT("employee_pkey").DO_NOTHING() + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) +VALUES ($1, $2, $3, $4, $5), + ($6, $7, $8, $9, $10) +ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; +`) + AssertExec(t, stmt, 1) + }) + + t.Run("do update", func(t *testing.T) { + cleanUpLinkTable(t) + stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). + VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + ON_CONFLICT(Link.ID).DO_UPDATE( + SET(Link.ID, Link.EXCLUDED.ID). + SET(Link.URL, "http://www.postgresqltutorial2.com"), + ). + RETURNING(Link.AllColumns) + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO test_sample.link (id, url, name, description) +VALUES ($1, $2, $3, DEFAULT), + ($4, $5, $6, DEFAULT) +ON CONFLICT (id) DO UPDATE + SET id = excluded.id, + url = $7 +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"; +`) + + AssertExec(t, stmt, 2) + }) + + t.Run("on constraint do update", func(t *testing.T) { + cleanUpLinkTable(t) + stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). + VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + ON_CONFLICT().ON_CONSTRAINT("link_pkey").DO_UPDATE( + SET(Link.ID, Link.EXCLUDED.ID). + SET(Link.URL, "http://www.postgresqltutorial2.com"), + ). + RETURNING(Link.AllColumns) + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO test_sample.link (id, url, name, description) +VALUES ($1, $2, $3, DEFAULT), + ($4, $5, $6, DEFAULT) +ON CONFLICT ON CONSTRAINT link_pkey DO UPDATE + SET id = excluded.id, + url = $7 +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"; +`) + + AssertExec(t, stmt, 2) + }) + + t.Run("do update complex", func(t *testing.T) { + cleanUpLinkTable(t) + + stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). + VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + ON_CONFLICT(Link.ID).WHERE(Link.ID.MUL(Int(2)).GT(Int(10))).DO_UPDATE( + SET(Link.ID, SELECT(MAXi(Link.ID).ADD(Int(1))).FROM(Link)). + SET(ColumnList{Link.Name, Link.Description}, jet.ROW(Link.EXCLUDED.Name, String("new description"))). + WHERE(Link.Description.IS_NOT_NULL()), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO test_sample.link (id, url, name, description) +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT) +ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE + SET id = ( + SELECT MAX(link.id) + 1 + FROM test_sample.link + ), + (name, description) = ROW(excluded.name, 'new description') + WHERE link.description IS NOT NULL; +`) + + AssertExec(t, stmt, 1) + }) +} + func TestInsertModelObject(t *testing.T) { cleanUpLinkTable(t) var expectedSQL = ` -INSERT INTO test_sample.link (url, name) VALUES - ('http://www.duckduckgo.com', 'Duck Duck go'); +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); ` linkData := model.Link{ @@ -114,8 +233,8 @@ INSERT INTO test_sample.link (url, name) VALUES func TestInsertModelObjectEmptyColumnList(t *testing.T) { cleanUpLinkTable(t) var expectedSQL = ` -INSERT INTO test_sample.link VALUES - (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); +INSERT INTO test_sample.link +VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); ` linkData := model.Link{ @@ -135,10 +254,10 @@ INSERT INTO test_sample.link VALUES func TestInsertModelsObject(t *testing.T) { expectedSQL := ` -INSERT INTO test_sample.link (url, name) VALUES - ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), - ('http://www.google.com', 'Google'), - ('http://www.yahoo.com', 'Yahoo'); +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), + ('http://www.google.com', 'Google'), + ('http://www.yahoo.com', 'Yahoo'); ` tutorial := model.Link{ @@ -170,11 +289,11 @@ INSERT INTO test_sample.link (url, name) VALUES func TestInsertUsingMutableColumns(t *testing.T) { var expectedSQL = ` -INSERT INTO test_sample.link (url, name, description) VALUES - ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - ('http://www.google.com', 'Google', NULL), - ('http://www.google.com', 'Google', NULL), - ('http://www.yahoo.com', 'Yahoo', NULL); +INSERT INTO test_sample.link (url, name, description) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + ('http://www.google.com', 'Google', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.yahoo.com', 'Yahoo', NULL); ` google := model.Link{ diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index 49ab868..fd538d6 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -5,16 +5,19 @@ import ( "github.com/go-jet/jet/tests/dbconfig" _ "github.com/lib/pq" "github.com/pkg/profile" + "math/rand" "os" "os/exec" "strings" "testing" + "time" ) var db *sql.DB var testRoot string func TestMain(m *testing.M) { + rand.Seed(time.Now().Unix()) defer profile.Start().Stop() setTestRoot() diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 2cfbc85..b2bed1d 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -7,6 +7,7 @@ import ( . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" ) @@ -221,6 +222,10 @@ FROM test_sample.person; func TestSelecSelfJoin1(t *testing.T) { + // clean up + _, err := Employee.DELETE().WHERE(Employee.EmployeeID.GT(Int(100))).Exec(db) + require.NoError(t, err) + var expectedSQL = ` SELECT employee.employee_id AS "employee.employee_id", employee.first_name AS "employee.first_name", @@ -256,7 +261,7 @@ ORDER BY employee.employee_id; Manager *Manager } - err := query.Query(db, &dest) + err = query.Query(db, &dest) assert.NoError(t, err) assert.Equal(t, len(dest), 8) diff --git a/tests/postgres/util_test.go b/tests/postgres/util_test.go index e7a7816..c30a365 100644 --- a/tests/postgres/util_test.go +++ b/tests/postgres/util_test.go @@ -6,13 +6,14 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" ) func AssertExec(t *testing.T, stmt jet.Statement, rowsAffected int64) { res, err := stmt.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) rows, err := res.RowsAffected() assert.NoError(t, err) assert.Equal(t, rows, rowsAffected)