From e626384d0be3a1ec86b4f3055ce5b42cd1d7cdb8 Mon Sep 17 00:00:00 2001 From: go-jet Date: Tue, 3 Mar 2020 17:16:42 +0100 Subject: [PATCH 01/23] Remove timestamp from generated files. --- generator/internal/template/templates.go | 1 - 1 file changed, 1 deletion(-) diff --git a/generator/internal/template/templates.go b/generator/internal/template/templates.go index 40e4773..938dab9 100644 --- a/generator/internal/template/templates.go +++ b/generator/internal/template/templates.go @@ -3,7 +3,6 @@ package template var autoGenWarningTemplate = ` // // Code generated by go-jet DO NOT EDIT. -// Generated at {{now}} // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated From 14e18634566e8726cc3b8759125487dea28d4dac Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 12 Apr 2020 18:53:57 +0200 Subject: [PATCH 02/23] [postgres] Add support for ON CONFLICT clause --- README.md | 2 +- .../internal/metadata/table_meta_data.go | 9 +- generator/internal/template/templates.go | 27 ++- internal/jet/cast.go | 4 +- internal/jet/clause.go | 120 ++++++++----- internal/jet/column.go | 19 ++- internal/jet/expression.go | 18 +- internal/jet/func_expression.go | 2 +- internal/jet/interval.go | 2 +- internal/jet/keyword.go | 14 +- internal/jet/literal_expression.go | 8 + internal/jet/operators.go | 8 +- internal/jet/serializer.go | 41 ++++- internal/jet/sql_builder.go | 6 +- internal/jet/sql_builder_test.go | 6 + internal/jet/statement.go | 10 +- internal/jet/table.go | 10 +- internal/jet/utils.go | 21 ++- internal/jet/window_expression.go | 10 +- internal/jet/window_func.go | 14 +- internal/testutils/test_utils.go | 38 +++-- mysql/insert_statement_test.go | 38 ++--- mysql/table.go | 4 +- mysql/table_test.go | 4 +- mysql/utils_test.go | 8 +- postgres/clause.go | 87 ++++++++++ postgres/clause_test.go | 34 ++++ postgres/clauses.go | 20 --- postgres/conflict_action.go | 36 ++++ postgres/insert_statement.go | 17 +- postgres/insert_statement_test.go | 97 ++++++++--- postgres/table.go | 4 +- postgres/table_test.go | 4 +- postgres/update_statement.go | 2 +- postgres/utils_test.go | 13 +- tests/mysql/generator_test.go | 58 +++++-- tests/mysql/insert_test.go | 38 ++--- tests/postgres/generator_test.go | 81 ++++++--- tests/postgres/insert_test.go | 157 +++++++++++++++--- tests/postgres/main_test.go | 3 + tests/postgres/sample_test.go | 7 +- tests/postgres/util_test.go | 3 +- 42 files changed, 827 insertions(+), 277 deletions(-) create mode 100644 postgres/clause.go create mode 100644 postgres/clause_test.go delete mode 100644 postgres/clauses.go create mode 100644 postgres/conflict_action.go diff --git a/README.md b/README.md index 721219a..932fff5 100644 --- a/README.md +++ b/README.md @@ -568,5 +568,5 @@ To run the tests, additional dependencies are required: ## License -Copyright 2019 Goran Bjelanovic +Copyright 2019-2020 Goran Bjelanovic Licensed under the Apache License, Version 2.0. diff --git a/generator/internal/metadata/table_meta_data.go b/generator/internal/metadata/table_meta_data.go index cb738fa..bab1088 100644 --- a/generator/internal/metadata/table_meta_data.go +++ b/generator/internal/metadata/table_meta_data.go @@ -3,6 +3,7 @@ package metadata import ( "database/sql" "github.com/go-jet/jet/internal/utils" + "strings" ) // TableMetaData metadata struct @@ -67,15 +68,19 @@ func (t TableMetaData) GoStructName() string { return utils.ToGoIdentifier(t.name) + "Table" } +// GoStructImplName returns go struct impl name for sql builder +func (t TableMetaData) GoStructImplName() string { + name := utils.ToGoIdentifier(t.name) + "Table" + return string(strings.ToLower(name)[0]) + name[1:] +} + // GetTableMetaData returns table info metadata func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData) { - tableInfo.SchemaName = schemaName tableInfo.name = tableName tableInfo.PrimaryKeys = getPrimaryKeys(db, querySet, schemaName, tableName) tableInfo.Columns = getColumnsMetaData(db, querySet, schemaName, tableName) - return } diff --git a/generator/internal/template/templates.go b/generator/internal/template/templates.go index 40e4773..d8e302a 100644 --- a/generator/internal/template/templates.go +++ b/generator/internal/template/templates.go @@ -26,7 +26,7 @@ import ( var {{ToGoIdentifier .Name}} = new{{.GoStructName}}() -type {{.GoStructName}} struct { +type {{.GoStructImplName}} struct { {{dialect.PackageName}}.Table //Columns @@ -38,32 +38,45 @@ type {{.GoStructName}} struct { MutableColumns {{dialect.PackageName}}.ColumnList } +type {{.GoStructName}} struct { + {{.GoStructImplName}} + + EXCLUDED {{.GoStructImplName}} +} + // creates new {{.GoStructName}} with assigned alias func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} { aliasTable := new{{.GoStructName}}() - aliasTable.Table.AS(alias) - return aliasTable } func new{{.GoStructName}}() *{{.GoStructName}} { + return &{{.GoStructName}}{ + {{.GoStructImplName}}: new{{.GoStructName}}Impl("{{.SchemaName}}", "{{.Name}}"), + EXCLUDED: new{{.GoStructName}}Impl("", "excluded"), + } +} + +func new{{.GoStructName}}Impl(schemaName, tableName string) {{.GoStructImplName}} { var ( {{- range .Columns}} {{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}") {{- end}} + allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} } + mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} } ) - return &{{.GoStructName}}{ - Table: {{dialect.PackageName}}.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}), + return {{.GoStructImplName}}{ + Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, allColumns...), //Columns {{- range .Columns}} {{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column, {{- end}} - AllColumns: {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }, - MutableColumns: {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } diff --git a/internal/jet/cast.go b/internal/jet/cast.go index c5fe9a7..e8045e7 100644 --- a/internal/jet/cast.go +++ b/internal/jet/cast.go @@ -42,12 +42,12 @@ func (b *castExpression) serialize(statement StatementType, out *SQLBuilder, opt castType := b.cast if castOverride := out.Dialect.OperatorSerializeOverride("CAST"); castOverride != nil { - castOverride(expression, String(castType))(statement, out, options...) + castOverride(expression, String(castType))(statement, out, FallTrough(options)...) return } out.WriteString("CAST(") - expression.serialize(statement, out, options...) + expression.serialize(statement, out, FallTrough(options)...) out.WriteString("AS") out.WriteString(castType + ")") } diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 738074b..90d194e 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -6,7 +6,7 @@ import ( // Clause interface type Clause interface { - Serialize(statementType StatementType, out *SQLBuilder) + Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) } // ClauseWithProjections interface @@ -27,7 +27,7 @@ func (s *ClauseSelect) projections() ProjectionList { } // Serialize serializes clause into SQLBuilder -func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder) { +func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { out.NewLine() out.WriteString("SELECT") @@ -48,7 +48,7 @@ type ClauseFrom struct { } // Serialize serializes clause into SQLBuilder -func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder) { +func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if f.Table == nil { return } @@ -56,7 +56,7 @@ func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder) { out.WriteString("FROM") out.IncreaseIdent() - f.Table.serialize(statementType, out) + f.Table.serialize(statementType, out, FallTrough(options)...) out.DecreaseIdent() } @@ -67,18 +67,20 @@ type ClauseWhere struct { } // Serialize serializes clause into SQLBuilder -func (c *ClauseWhere) Serialize(statementType StatementType, out *SQLBuilder) { +func (c *ClauseWhere) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if c.Condition == nil { if c.Mandatory { panic("jet: WHERE clause not set") } return } - out.NewLine() + if !contains(options, SkipNewLine) { + out.NewLine() + } out.WriteString("WHERE") out.IncreaseIdent() - c.Condition.serialize(statementType, out, noWrap) + c.Condition.serialize(statementType, out, NoWrap.WithFallTrough(options)...) out.DecreaseIdent() } @@ -88,7 +90,7 @@ type ClauseGroupBy struct { } // Serialize serializes clause into SQLBuilder -func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SQLBuilder) { +func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if len(c.List) == 0 { return } @@ -119,7 +121,7 @@ type ClauseHaving struct { } // Serialize serializes clause into SQLBuilder -func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder) { +func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if c.Condition == nil { return } @@ -128,7 +130,7 @@ func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder) { out.WriteString("HAVING") out.IncreaseIdent() - c.Condition.serialize(statementType, out, noWrap) + c.Condition.serialize(statementType, out, NoWrap.WithFallTrough(options)...) out.DecreaseIdent() } @@ -139,7 +141,7 @@ type ClauseOrderBy struct { } // Serialize serializes clause into SQLBuilder -func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SQLBuilder) { +func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if o.List == nil { return } @@ -168,7 +170,7 @@ type ClauseLimit struct { } // Serialize serializes clause into SQLBuilder -func (l *ClauseLimit) Serialize(statementType StatementType, out *SQLBuilder) { +func (l *ClauseLimit) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if l.Count >= 0 { out.NewLine() out.WriteString("LIMIT") @@ -182,7 +184,7 @@ type ClauseOffset struct { } // Serialize serializes clause into SQLBuilder -func (o *ClauseOffset) Serialize(statementType StatementType, out *SQLBuilder) { +func (o *ClauseOffset) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if o.Count >= 0 { out.NewLine() out.WriteString("OFFSET") @@ -196,14 +198,14 @@ type ClauseFor struct { } // Serialize serializes clause into SQLBuilder -func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder) { +func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if f.Lock == nil { return } out.NewLine() out.WriteString("FOR") - f.Lock.serialize(statementType, out) + f.Lock.serialize(statementType, out, FallTrough(options)...) } // ClauseSetStmtOperator struct @@ -224,7 +226,7 @@ func (s *ClauseSetStmtOperator) projections() ProjectionList { } // Serialize serializes clause into SQLBuilder -func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLBuilder) { +func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if len(s.Selects) < 2 { panic("jet: UNION Statement must contain at least two SELECT statements") } @@ -244,7 +246,7 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLB panic("jet: select statement of '" + s.Operator + "' is nil") } - selectStmt.serialize(statementType, out) + selectStmt.serialize(statementType, out, FallTrough(options)...) } s.OrderBy.Serialize(statementType, out) @@ -258,7 +260,7 @@ type ClauseUpdate struct { } // Serialize serializes clause into SQLBuilder -func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder) { +func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { out.NewLine() out.WriteString("UPDATE") @@ -266,7 +268,7 @@ func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder) { panic("jet: table to update is nil") } - u.Table.serialize(statementType, out) + u.Table.serialize(statementType, out, FallTrough(options)...) } // ClauseSet struct @@ -276,7 +278,7 @@ type ClauseSet struct { } // Serialize serializes clause into SQLBuilder -func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder) { +func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { out.NewLine() out.WriteString("SET") @@ -299,7 +301,7 @@ func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder) { out.WriteString(" = ") - s.Values[i].serialize(UpdateStatementType, out) + s.Values[i].serialize(UpdateStatementType, out, FallTrough(options)...) } out.DecreaseIdent(4) } @@ -320,7 +322,7 @@ func (i *ClauseInsert) GetColumns() []Column { } // Serialize serializes clause into SQLBuilder -func (i *ClauseInsert) Serialize(statementType StatementType, out *SQLBuilder) { +func (i *ClauseInsert) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { out.NewLine() out.WriteString("INSERT INTO") @@ -346,7 +348,7 @@ type ClauseValuesQuery struct { } // Serialize serializes clause into SQLBuilder -func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuilder) { +func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if len(v.Rows) == 0 && v.Query == nil { panic("jet: VALUES or QUERY has to be specified for INSERT statement") } @@ -355,8 +357,8 @@ func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuild panic("jet: VALUES or QUERY has to be specified for INSERT statement") } - v.ClauseValues.Serialize(statementType, out) - v.ClauseQuery.Serialize(statementType, out) + v.ClauseValues.Serialize(statementType, out, FallTrough(options)...) + v.ClauseQuery.Serialize(statementType, out, FallTrough(options)...) } // ClauseValues struct @@ -365,27 +367,29 @@ type ClauseValues struct { } // Serialize serializes clause into SQLBuilder -func (v *ClauseValues) Serialize(statementType StatementType, out *SQLBuilder) { +func (v *ClauseValues) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if len(v.Rows) == 0 { return } + out.NewLine() out.WriteString("VALUES") for rowIndex, row := range v.Rows { if rowIndex > 0 { out.WriteString(",") + out.NewLine() + } else { + out.IncreaseIdent(7) } - out.IncreaseIdent() - out.NewLine() out.WriteString("(") SerializeClauseList(statementType, row, out) out.WriteByte(')') - out.DecreaseIdent() } + out.DecreaseIdent(7) } // ClauseQuery struct @@ -394,12 +398,12 @@ type ClauseQuery struct { } // Serialize serializes clause into SQLBuilder -func (v *ClauseQuery) Serialize(statementType StatementType, out *SQLBuilder) { +func (v *ClauseQuery) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if v.Query == nil { return } - v.Query.serialize(statementType, out) + v.Query.serialize(statementType, out, FallTrough(options)...) } // ClauseDelete struct @@ -408,7 +412,7 @@ type ClauseDelete struct { } // Serialize serializes clause into SQLBuilder -func (d *ClauseDelete) Serialize(statementType StatementType, out *SQLBuilder) { +func (d *ClauseDelete) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { out.NewLine() out.WriteString("DELETE FROM") @@ -416,7 +420,7 @@ func (d *ClauseDelete) Serialize(statementType StatementType, out *SQLBuilder) { panic("jet: nil table in DELETE clause") } - d.Table.serialize(statementType, out) + d.Table.serialize(statementType, out, FallTrough(options)...) } // ClauseStatementBegin struct @@ -426,7 +430,7 @@ type ClauseStatementBegin struct { } // Serialize serializes clause into SQLBuilder -func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SQLBuilder) { +func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { out.NewLine() out.WriteString(d.Name) @@ -435,7 +439,7 @@ func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SQLBu out.WriteString(", ") } - table.serialize(statementType, out) + table.serialize(statementType, out, FallTrough(options)...) } } @@ -447,7 +451,7 @@ type ClauseOptional struct { } // Serialize serializes clause into SQLBuilder -func (d *ClauseOptional) Serialize(statementType StatementType, out *SQLBuilder) { +func (d *ClauseOptional) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if !d.Show { return } @@ -463,7 +467,7 @@ type ClauseIn struct { } // Serialize serializes clause into SQLBuilder -func (i *ClauseIn) Serialize(statementType StatementType, out *SQLBuilder) { +func (i *ClauseIn) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if i.LockMode == "" { return } @@ -485,7 +489,7 @@ type ClauseWindow struct { } // Serialize serializes clause into SQLBuilder -func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder) { +func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { if len(i.Definitions) == 0 { return } @@ -503,6 +507,44 @@ func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder) { out.WriteString("()") continue } - def.Window.serialize(statementType, out) + def.Window.serialize(statementType, out, FallTrough(options)...) } } + +// SetPair clause +type SetPair struct { + Column ColumnSerializer + Value Serializer +} + +// SetClause clause +type SetClause []SetPair + +// Serialize for SetClause +func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { + out.NewLine() + out.WriteString("SET") + out.IncreaseIdent(4) + + for i, pair := range s { + if i > 0 { + out.WriteString(",") + out.NewLine() + } + + pair.Column.serialize(statementType, out, ShortName.WithFallTrough(options)...) + out.WriteString("=") + pair.Value.serialize(statementType, out, FallTrough(options)...) + } + out.DecreaseIdent(4) +} + +// KeywordClause type +type KeywordClause struct { + Keyword +} + +// Serialize for KeywordClause +func (k KeywordClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { + k.serialize(statementType, out, FallTrough(options)...) +} diff --git a/internal/jet/column.go b/internal/jet/column.go index 85c053e..0fd59be 100644 --- a/internal/jet/column.go +++ b/internal/jet/column.go @@ -12,6 +12,12 @@ type Column interface { defaultAlias() string } +// ColumnSerializer is interface for all serializable columns +type ColumnSerializer interface { + Serializer + Column +} + // ColumnExpression interface type ColumnExpression interface { Column @@ -101,7 +107,7 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder out.WriteByte('.') out.WriteIdentifier(c.defaultAlias(), true) } else { - if c.tableName != "" { + if c.tableName != "" && !contains(options, ShortName) { out.WriteIdentifier(c.tableName) out.WriteByte('.') } @@ -125,6 +131,17 @@ func (cl ColumnList) fromImpl(subQuery SelectTable) Projection { return newProjectionList } +func (cl ColumnList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteString("(") + for i, column := range cl { + if i > 0 { + out.WriteString(", ") + } + column.serialize(statement, out, FallTrough(options)...) + } + out.WriteString(")") +} + func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) { projections := ColumnListToProjectionList(cl) diff --git a/internal/jet/expression.go b/internal/jet/expression.go index a463b76..d807534 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -72,15 +72,15 @@ func (e *ExpressionInterfaceImpl) DESC() OrderByClause { } func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SQLBuilder) { - e.Parent.serialize(statement, out, noWrap) + e.Parent.serialize(statement, out, NoWrap) } func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType, out *SQLBuilder) { - e.Parent.serialize(statement, out, noWrap) + e.Parent.serialize(statement, out, NoWrap) } func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { - e.Parent.serialize(statement, out, noWrap) + e.Parent.serialize(statement, out, NoWrap) } // Representation of binary operations (e.g. comparisons, arithmetic) @@ -117,7 +117,7 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu panic("jet: rhs is nil for '" + c.operator + "' operator") } - wrap := !contains(options, noWrap) + wrap := !contains(options, NoWrap) if wrap { out.WriteString("(") @@ -125,11 +125,11 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil { serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam) - serializeOverrideFunc(statement, out, options...) + serializeOverrideFunc(statement, out, FallTrough(options)...) } else { - c.lhs.serialize(statement, out) + c.lhs.serialize(statement, out, FallTrough(options)...) out.WriteString(c.operator) - c.rhs.serialize(statement, out) + c.rhs.serialize(statement, out, FallTrough(options)...) } if wrap { @@ -163,7 +163,7 @@ func (p *prefixExpression) serialize(statement StatementType, out *SQLBuilder, o panic("jet: nil prefix expression in prefix operator " + p.operator) } - p.expression.serialize(statement, out) + p.expression.serialize(statement, out, FallTrough(options)...) out.WriteString(")") } @@ -192,7 +192,7 @@ func (p *postfixOpExpression) serialize(statement StatementType, out *SQLBuilder panic("jet: nil prefix expression in postfix operator " + p.operator) } - p.expression.serialize(statement, out) + p.expression.serialize(statement, out, FallTrough(options)...) out.WriteString(p.operator) } diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index f38c9a2..e95bece 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -613,7 +613,7 @@ func newWindowFunc(name string, expressions ...Expression) windowExpression { func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil { serializeOverrideFunc := serializeOverride(ExpressionListToSerializerList(f.expressions)...) - serializeOverrideFunc(statement, out, options...) + serializeOverrideFunc(statement, out, FallTrough(options)...) return } diff --git a/internal/jet/interval.go b/internal/jet/interval.go index fab84e0..5b371e1 100644 --- a/internal/jet/interval.go +++ b/internal/jet/interval.go @@ -33,5 +33,5 @@ type IntervalImpl struct { func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString("INTERVAL") - i.interval.serialize(statement, out, options...) + i.interval.serialize(statement, out, FallTrough(options)...) } diff --git a/internal/jet/keyword.go b/internal/jet/keyword.go index c903ad6..a40538e 100644 --- a/internal/jet/keyword.go +++ b/internal/jet/keyword.go @@ -2,18 +2,12 @@ package jet const ( // DEFAULT is jet equivalent of SQL DEFAULT - DEFAULT keywordClause = "DEFAULT" + DEFAULT Keyword = "DEFAULT" ) -var ( - // NULL is jet equivalent of SQL NULL - NULL = newNullLiteral() - // STAR is jet equivalent of SQL * - STAR = newStarLiteral() -) +// Keyword type +type Keyword string -type keywordClause string - -func (k keywordClause) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (k Keyword) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString(string(k)) } diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 15c8a4a..8cdb3d7 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -278,6 +278,14 @@ func formatNanoseconds(nanoseconds ...time.Duration) string { } //--------------------------------------------------// + +var ( + // NULL is jet equivalent of SQL NULL + NULL = newNullLiteral() + // STAR is jet equivalent of SQL * + STAR = newStarLiteral() +) + type nullLiteral struct { ExpressionInterfaceImpl } diff --git a/internal/jet/operators.go b/internal/jet/operators.go index d17081c..19173a6 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -147,7 +147,7 @@ func (c *caseOperatorImpl) serialize(statement StatementType, out *SQLBuilder, o out.WriteString("(CASE") if c.expression != nil { - c.expression.serialize(statement, out) + c.expression.serialize(statement, out, FallTrough(options)...) } if len(c.when) == 0 || len(c.then) == 0 { @@ -160,15 +160,15 @@ func (c *caseOperatorImpl) serialize(statement StatementType, out *SQLBuilder, o for i, when := range c.when { out.WriteString("WHEN") - when.serialize(statement, out, noWrap) + when.serialize(statement, out, NoWrap) out.WriteString("THEN") - c.then[i].serialize(statement, out, noWrap) + c.then[i].serialize(statement, out, NoWrap) } if c.els != nil { out.WriteString("ELSE") - c.els.serialize(statement, out, noWrap) + c.els.serialize(statement, out, NoWrap) } out.WriteString("END)") diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index dc661d7..2f014cc 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -5,9 +5,18 @@ type SerializeOption int // Serialize options const ( - noWrap SerializeOption = iota + NoWrap SerializeOption = iota + SkipNewLine + + fallTroughOptions // fall trough options + ShortName ) +// WithFallTrough extends existing serialize options with additional +func (s SerializeOption) WithFallTrough(options []SerializeOption) []SerializeOption { + return append(FallTrough(options), s) +} + // StatementType is type of the SQL statement type StatementType string @@ -42,6 +51,19 @@ func contains(options []SerializeOption, option SerializeOption) bool { return false } +// FallTrough filters fall-trough options from the list +func FallTrough(options []SerializeOption) []SerializeOption { + var ret []SerializeOption + + for _, option := range options { + if option > fallTroughOptions { + ret = append(ret, option) + } + } + + return ret +} + // ListSerializer serializes list of serializers with separator type ListSerializer struct { Serializers []Serializer @@ -53,6 +75,21 @@ func (s ListSerializer) serialize(statement StatementType, out *SQLBuilder, opti if i > 0 { out.WriteString(s.Separator) } - ser.serialize(statement, out) + ser.serialize(statement, out, FallTrough(options)...) + } +} + +// NewSerializerClauseImpl is constructor for Seralizer with list of clauses +func NewSerializerClauseImpl(clauses ...Clause) Serializer { + return &serializerImpl{Clauses: clauses} +} + +type serializerImpl struct { + Clauses []Clause +} + +func (s serializerImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + for _, clause := range s.Clauses { + clause.Serialize(statement, out, FallTrough(options)...) } } diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index ef7f801..642f0cd 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -98,7 +98,7 @@ func (s *SQLBuilder) WriteString(str string) { // WriteIdentifier adds identifier to output SQL func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) { - if s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 { + if s.shouldQuote(name, alwaysQuote...) { identQuoteChar := string(s.Dialect.IdentifierQuoteChar()) s.WriteString(identQuoteChar + name + identQuoteChar) } else { @@ -106,6 +106,10 @@ func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) { } } +func (s *SQLBuilder) shouldQuote(name string, alwaysQuote ...bool) bool { + return s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 +} + // WriteByte writes byte to output SQL func (s *SQLBuilder) WriteByte(b byte) { s.write([]byte{b}) diff --git a/internal/jet/sql_builder_test.go b/internal/jet/sql_builder_test.go index f7b5ade..911df27 100644 --- a/internal/jet/sql_builder_test.go +++ b/internal/jet/sql_builder_test.go @@ -41,3 +41,9 @@ func TestArgToString(t *testing.T) { argToString(map[string]bool{}) }() } + +func TestFallTrough(t *testing.T) { + assert.Equal(t, FallTrough([]SerializeOption{ShortName}), []SerializeOption{ShortName}) + assert.Equal(t, FallTrough([]SerializeOption{SkipNewLine}), []SerializeOption(nil)) + assert.Equal(t, FallTrough([]SerializeOption{ShortName, SkipNewLine}), []SerializeOption{ShortName}) +} diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 3b0638d..beb52f1 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -58,7 +58,7 @@ func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface queryData := &SQLBuilder{Dialect: s.dialect} - s.parent.serialize(s.statementType, queryData, noWrap) + s.parent.serialize(s.statementType, queryData, NoWrap) query, args = queryData.finalize() return @@ -67,7 +67,7 @@ func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { sqlBuilder := &SQLBuilder{Dialect: s.dialect, Debug: true} - s.parent.serialize(s.statementType, sqlBuilder, noWrap) + s.parent.serialize(s.statementType, sqlBuilder, NoWrap) query, _ = sqlBuilder.finalize() return @@ -157,16 +157,16 @@ func (s *statementImpl) projections() ProjectionList { func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if !contains(options, noWrap) { + if !contains(options, NoWrap) { out.WriteString("(") out.IncreaseIdent() } for _, clause := range s.Clauses { - clause.Serialize(statement, out) + clause.Serialize(statement, out, FallTrough(options)...) } - if !contains(options, noWrap) { + if !contains(options, NoWrap) { out.DecreaseIdent() out.NewLine() out.WriteString(")") diff --git a/internal/jet/table.go b/internal/jet/table.go index 4379000..ca76e3a 100644 --- a/internal/jet/table.go +++ b/internal/jet/table.go @@ -19,17 +19,15 @@ type Table interface { } // NewTable creates new table with schema Name, table Name and list of columns -func NewTable(schemaName, name string, column ColumnExpression, columns ...ColumnExpression) SerializerTable { - - columnList := append([]ColumnExpression{column}, columns...) +func NewTable(schemaName, name string, columns ...ColumnExpression) SerializerTable { t := tableImpl{ schemaName: schemaName, name: name, - columnList: columnList, + columnList: columns, } - for _, c := range columnList { + for _, c := range columns { c.setTableName(name) } @@ -156,7 +154,7 @@ func (t *joinTableImpl) serialize(statement StatementType, out *SQLBuilder, opti panic("jet: left hand side of join operation is nil table") } - t.lhs.serialize(statement, out) + t.lhs.serialize(statement, out, FallTrough(options)...) out.NewLine() diff --git a/internal/jet/utils.go b/internal/jet/utils.go index 126ce32..50f4e13 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -63,6 +63,22 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) { } } +// SerializeColumnExpressionNames func +func SerializeColumnExpressionNames(columns []ColumnExpression, statementType StatementType, + out *SQLBuilder, options ...SerializeOption) { + for i, col := range columns { + if i > 0 { + out.WriteString(", ") + } + + if col == nil { + panic("jet: nil column in columns list") + } + + col.serialize(statementType, out, options...) + } +} + // ExpressionListToSerializerList converts list of expressions to list of serializers func ExpressionListToSerializerList(expressions []Expression) []Serializer { var ret []Serializer @@ -85,7 +101,8 @@ func ColumnListToProjectionList(columns []ColumnExpression) []Projection { return ret } -func valueToClause(value interface{}) Serializer { +// ToSerializerValue creates Serializer type from the value +func ToSerializerValue(value interface{}) Serializer { if clause, ok := value.(Serializer); ok { return clause } @@ -148,7 +165,7 @@ func UnwindRowFromValues(value interface{}, values []interface{}) []Serializer { allValues := append([]interface{}{value}, values...) for _, val := range allValues { - row = append(row, valueToClause(val)) + row = append(row, ToSerializerValue(val)) } return row diff --git a/internal/jet/window_expression.go b/internal/jet/window_expression.go index 3e7f1c7..e5f18c1 100644 --- a/internal/jet/window_expression.go +++ b/internal/jet/window_expression.go @@ -17,7 +17,7 @@ func (w *commonWindowImpl) serialize(statement StatementType, out *SQLBuilder, o w.expression.serialize(statement, out) if w.window != nil { out.WriteString("OVER") - w.window.serialize(statement, out) + w.window.serialize(statement, out, FallTrough(options)...) } } @@ -49,7 +49,7 @@ func (f *windowExpressionImpl) OVER(window ...Window) Expression { } func (f *windowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - f.commonWindowImpl.serialize(statement, out) + f.commonWindowImpl.serialize(statement, out, FallTrough(options)...) } // ----------------------------------------------------- @@ -80,7 +80,7 @@ func (f *floatWindowExpressionImpl) OVER(window ...Window) FloatExpression { } func (f *floatWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - f.commonWindowImpl.serialize(statement, out) + f.commonWindowImpl.serialize(statement, out, FallTrough(options)...) } // ------------------------------------------------ @@ -111,7 +111,7 @@ func (f *integerWindowExpressionImpl) OVER(window ...Window) IntegerExpression { } func (f *integerWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - f.commonWindowImpl.serialize(statement, out) + f.commonWindowImpl.serialize(statement, out, FallTrough(options)...) } // ------------------------------------------------ @@ -142,5 +142,5 @@ func (f *boolWindowExpressionImpl) OVER(window ...Window) BoolExpression { } func (f *boolWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - f.commonWindowImpl.serialize(statement, out) + f.commonWindowImpl.serialize(statement, out, FallTrough(options)...) } diff --git a/internal/jet/window_func.go b/internal/jet/window_func.go index 7f4d1b7..6602495 100644 --- a/internal/jet/window_func.go +++ b/internal/jet/window_func.go @@ -30,7 +30,7 @@ func newWindowImpl(parent Window) *windowImpl { } func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if !contains(options, noWrap) { + if !contains(options, NoWrap) { out.WriteByte('(') } @@ -40,7 +40,7 @@ func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options serializeExpressionList(statement, w.partitionBy, ", ", out) } w.orderBy.SkipNewLine = true - w.orderBy.Serialize(statement, out) + w.orderBy.Serialize(statement, out, FallTrough(options)...) if w.frameUnits != "" { out.WriteString(w.frameUnits) @@ -55,7 +55,7 @@ func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options } } - if !contains(options, noWrap) { + if !contains(options, NoWrap) { out.WriteByte(')') } } @@ -139,7 +139,7 @@ func (f *frameExtentImpl) serialize(statement StatementType, out *SQLBuilder, op if f == nil { return } - f.offset.serialize(statement, out) + f.offset.serialize(statement, out, FallTrough(options)...) if f.preceding { out.WriteString("PRECEDING") @@ -152,12 +152,12 @@ func (f *frameExtentImpl) serialize(statement StatementType, out *SQLBuilder, op // Window function keywords var ( - UNBOUNDED = keywordClause("UNBOUNDED") + UNBOUNDED = Keyword("UNBOUNDED") CURRENT_ROW = frameExtentKeyword{"CURRENT ROW"} ) type frameExtentKeyword struct { - keywordClause + Keyword } func (f frameExtentKeyword) isFrameExtent() {} @@ -180,7 +180,7 @@ func (w windowName) serialize(statement StatementType, out *SQLBuilder, options out.WriteByte('(') out.WriteString(w.name) - w.windowImpl.serialize(statement, out, noWrap) + w.windowImpl.serialize(statement, out, NoWrap.WithFallTrough(options)...) out.WriteByte(')') } diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index f0c18e9..942dab7 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -8,6 +8,7 @@ import ( "github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/qrm" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "io/ioutil" "os" "path/filepath" @@ -110,17 +111,17 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st _, args := query.Sql() if len(expectedArgs) > 0 { - AssertDeepEqual(t, args, expectedArgs) + AssertDeepEqual(t, args, expectedArgs, "arguments are not equal") } debuqSql := query.DebugSql() assert.Equal(t, debuqSql, expectedQuery) } -// AssertClauseSerialize checks if clause serialize produces expected query and args -func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { +// AssertSerialize checks if clause serialize produces expected query and args +func AssertSerialize(t *testing.T, dialect jet.Dialect, serializer jet.Serializer, query string, args ...interface{}) { out := jet.SQLBuilder{Dialect: dialect} - jet.Serialize(clause, jet.SelectStatementType, &out) + jet.Serialize(serializer, jet.SelectStatementType, &out) //fmt.Println(out.Buff.String()) @@ -131,8 +132,20 @@ func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Seriali } } -// AssertDebugClauseSerialize checks if clause serialize produces expected debug query and args -func AssertDebugClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { +// AssertClauseSerialize checks if clause serialize produces expected query and args +func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Clause, query string, args ...interface{}) { + out := jet.SQLBuilder{Dialect: dialect} + clause.Serialize(jet.SelectStatementType, &out) + + require.Equal(t, out.Buff.String(), query) + + if len(args) > 0 { + AssertDeepEqual(t, out.Args, args) + } +} + +// AssertDebugSerialize checks if clause serialize produces expected debug query and args +func AssertDebugSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { out := jet.SQLBuilder{Dialect: dialect, Debug: true} jet.Serialize(clause, jet.SelectStatementType, &out) @@ -153,8 +166,8 @@ func AssertPanicErr(t *testing.T, fun func(), errorStr string) { fun() } -// AssertClauseSerializeErr check if clause serialize panics with errString -func AssertClauseSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) { +// AssertSerializeErr check if clause serialize panics with errString +func AssertSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) { defer func() { r := recover() assert.Equal(t, r, errString) @@ -191,9 +204,8 @@ func AssertFileContent(t *testing.T, filePath string, contentBegin string, expec beginIndex := bytes.Index(enumFileData, []byte(contentBegin)) - //fmt.Println("-"+string(enumFileData[beginIndex:])+"-") - - AssertDeepEqual(t, string(enumFileData[beginIndex:]), expectedContent) + //AssertDeepEqual(t, string(enumFileData[beginIndex:]), expectedContent) + require.Equal(t, string(enumFileData[beginIndex:]), expectedContent) } // AssertFileNamesEqual check if all filesInfos are contained in fileNames @@ -212,6 +224,6 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st } // AssertDeepEqual checks if actual and expected objects are deeply equal. -func AssertDeepEqual(t *testing.T, actual, expected interface{}) { - assert.True(t, cmp.Equal(actual, expected)) +func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) { + assert.True(t, cmp.Equal(actual, expected), msg) } diff --git a/mysql/insert_statement_test.go b/mysql/insert_statement_test.go index 6faf313..65c8fba 100644 --- a/mysql/insert_statement_test.go +++ b/mysql/insert_statement_test.go @@ -13,15 +13,15 @@ func TestInvalidInsert(t *testing.T) { func TestInsertNilValue(t *testing.T) { assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), ` -INSERT INTO db.table1 (col1) VALUES - (?); +INSERT INTO db.table1 (col1) +VALUES (?); `, nil) } func TestInsertSingleValue(t *testing.T) { assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), ` -INSERT INTO db.table1 (col1) VALUES - (?); +INSERT INTO db.table1 (col1) +VALUES (?); `, int(1)) } @@ -31,8 +31,8 @@ func TestInsertWithColumnList(t *testing.T) { columnList = append(columnList, table3StrCol) assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), ` -INSERT INTO db.table3 (col_int, col2) VALUES - (?, ?); +INSERT INTO db.table3 (col_int, col2) +VALUES (?, ?); `, 1, 3) } @@ -40,15 +40,15 @@ func TestInsertDate(t *testing.T) { date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) assertStatementSql(t, table1.INSERT(table1ColTimestamp).VALUES(date), ` -INSERT INTO db.table1 (col_timestamp) VALUES - (?); +INSERT INTO db.table1 (col_timestamp) +VALUES (?); `, date) } func TestInsertMultipleValues(t *testing.T) { assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` -INSERT INTO db.table1 (col1, col_float, col3) VALUES - (?, ?, ?); +INSERT INTO db.table1 (col1, col_float, col3) +VALUES (?, ?, ?); `, 1, 2, 3) } @@ -59,10 +59,10 @@ func TestInsertMultipleRows(t *testing.T) { VALUES(111, 222) assertStatementSql(t, stmt, ` -INSERT INTO db.table1 (col1, col_float) VALUES - (?, ?), - (?, ?), - (?, ?); +INSERT INTO db.table1 (col1, col_float) +VALUES (?, ?), + (?, ?), + (?, ?); `, 1, 2, 11, 22, 111, 222) } @@ -84,9 +84,9 @@ func TestInsertValuesFromModel(t *testing.T) { MODEL(&toInsert) expectedSQL := ` -INSERT INTO db.table1 (col1, col_float) VALUES - (?, ?), - (?, ?); +INSERT INTO db.table1 (col1, col_float) +VALUES (?, ?), + (?, ?); ` assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) @@ -127,8 +127,8 @@ func TestInsertDefaultValue(t *testing.T) { VALUES(DEFAULT, "two") var expectedSQL = ` -INSERT INTO db.table1 (col1, col_float) VALUES - (DEFAULT, ?); +INSERT INTO db.table1 (col1, col_float) +VALUES (DEFAULT, ?); ` assertStatementSql(t, stmt, expectedSQL, "two") diff --git a/mysql/table.go b/mysql/table.go index 6d414a2..a4cf042 100644 --- a/mysql/table.go +++ b/mysql/table.go @@ -77,9 +77,9 @@ func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectU } // NewTable creates new table with schema Name, table Name and list of columns -func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table { +func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table { t := &tableImpl{ - SerializerTable: jet.NewTable(schemaName, name, column, columns...), + SerializerTable: jet.NewTable(schemaName, name, columns...), } t.readableTableInterfaceImpl.parent = t diff --git a/mysql/table_test.go b/mysql/table_test.go index 3894378..3bc79f9 100644 --- a/mysql/table_test.go +++ b/mysql/table_test.go @@ -5,9 +5,9 @@ import ( ) func TestJoinNilInputs(t *testing.T) { - assertClauseSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)), + assertSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)), "jet: right hand side of join operation is nil table") - assertClauseSerializeErr(t, table2.INNER_JOIN(table1, nil), + assertSerializeErr(t, table2.INNER_JOIN(table1, nil), "jet: join condition is nil") } diff --git a/mysql/utils_test.go b/mysql/utils_test.go index 1cc42f1..709097d 100644 --- a/mysql/utils_test.go +++ b/mysql/utils_test.go @@ -59,15 +59,15 @@ var table3 = NewTable( table3StrCol) func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { - testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) + testutils.AssertSerialize(t, Dialect, clause, query, args...) } func assertDebugSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { - testutils.AssertDebugClauseSerialize(t, Dialect, clause, query, args...) + testutils.AssertDebugSerialize(t, Dialect, clause, query, args...) } -func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) { - testutils.AssertClauseSerializeErr(t, Dialect, clause, errString) +func assertSerializeErr(t *testing.T, clause jet.Serializer, errString string) { + testutils.AssertSerializeErr(t, Dialect, clause, errString) } func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) { diff --git a/postgres/clause.go b/postgres/clause.go new file mode 100644 index 0000000..ad61b9a --- /dev/null +++ b/postgres/clause.go @@ -0,0 +1,87 @@ +package postgres + +import ( + "github.com/go-jet/jet/internal/jet" +) + +type clauseReturning struct { + Projections []jet.Projection +} + +func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(r.Projections) == 0 { + return + } + + out.NewLine() + out.WriteString("RETURNING") + out.IncreaseIdent() + out.WriteProjections(statementType, r.Projections) +} + +// ========================================== // + +type onConflict interface { + ON_CONSTRAINT(name string) conflictTarget + WHERE(indexPredicate BoolExpression) conflictTarget + DO_NOTHING() InsertStatement + DO_UPDATE(action conflictAction) InsertStatement +} + +type conflictTarget interface { + DO_NOTHING() InsertStatement + DO_UPDATE(action conflictAction) InsertStatement +} + +type onConflictClause struct { + insertStatement InsertStatement + constraint string + indexExpressions []jet.ColumnExpression + whereClause jet.ClauseWhere + do jet.Serializer +} + +func (o *onConflictClause) ON_CONSTRAINT(name string) conflictTarget { + o.constraint = name + return o +} + +func (o *onConflictClause) WHERE(indexPredicate BoolExpression) conflictTarget { + o.whereClause.Condition = indexPredicate + return o +} + +func (o *onConflictClause) DO_NOTHING() InsertStatement { + o.do = jet.Keyword("DO NOTHING") + return o.insertStatement +} + +func (o *onConflictClause) DO_UPDATE(action conflictAction) InsertStatement { + o.do = action + return o.insertStatement +} + +func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(o.indexExpressions) == 0 && o.constraint == "" { + return + } + + out.NewLine() + out.WriteString("ON CONFLICT") + if len(o.indexExpressions) > 0 { + out.WriteString("(") + jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName) + out.WriteString(")") + } + + if o.constraint != "" { + out.WriteString("ON CONSTRAINT") + out.WriteString(o.constraint) + } + + o.whereClause.Serialize(statementType, out, jet.SkipNewLine, jet.ShortName) + + out.IncreaseIdent(7) + jet.Serialize(o.do, statementType, out) + out.DecreaseIdent(7) +} diff --git a/postgres/clause_test.go b/postgres/clause_test.go new file mode 100644 index 0000000..7f64c61 --- /dev/null +++ b/postgres/clause_test.go @@ -0,0 +1,34 @@ +package postgres + +import "testing" + +func TestOnConflict(t *testing.T) { + + assertClauseSerialize(t, &onConflictClause{}, "") + + onConflict := &onConflictClause{} + onConflict.DO_NOTHING() + assertClauseSerialize(t, onConflict, "") + + onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool}} + onConflict.DO_NOTHING() + assertClauseSerialize(t, onConflict, ` +ON CONFLICT (col_bool) DO NOTHING`) + + onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool}} + onConflict.ON_CONSTRAINT("table_pkey").DO_NOTHING() + assertClauseSerialize(t, onConflict, ` +ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`) + + onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool, table2ColFloat}} + onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)).DO_UPDATE( + SET(table1ColBool, Bool(true)). + SET(table1ColInt, Int(1)). + WHERE(table2ColFloat.GT(Float(11.1))), + ) + assertClauseSerialize(t, onConflict, ` +ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE + SET col_bool = $1, + col_int = $2 + WHERE table2.col_float > $3`) +} diff --git a/postgres/clauses.go b/postgres/clauses.go deleted file mode 100644 index a882b64..0000000 --- a/postgres/clauses.go +++ /dev/null @@ -1,20 +0,0 @@ -package postgres - -import ( - "github.com/go-jet/jet/internal/jet" -) - -type clauseReturning struct { - Projections []jet.Projection -} - -func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder) { - if len(r.Projections) == 0 { - return - } - - out.NewLine() - out.WriteString("RETURNING") - out.IncreaseIdent() - out.WriteProjections(statementType, r.Projections) -} diff --git a/postgres/conflict_action.go b/postgres/conflict_action.go new file mode 100644 index 0000000..5ff7a19 --- /dev/null +++ b/postgres/conflict_action.go @@ -0,0 +1,36 @@ +package postgres + +import "github.com/go-jet/jet/internal/jet" + +type conflictAction interface { + jet.Serializer + SET(column jet.ColumnSerializer, expression interface{}) conflictAction + WHERE(condition BoolExpression) conflictAction +} + +// SET creates conflict action for ON_CONFLICT clause +func SET(column jet.ColumnSerializer, expression interface{}) conflictAction { + conflictAction := updateConflictActionImpl{} + conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"} + conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where) + conflictAction.SET(column, expression) + return &conflictAction +} + +type updateConflictActionImpl struct { + jet.Serializer + + doUpdate jet.KeywordClause + set jet.SetClause + where jet.ClauseWhere +} + +func (u *updateConflictActionImpl) SET(column jet.ColumnSerializer, expression interface{}) conflictAction { + u.set = append(u.set, jet.SetPair{Column: column, Value: jet.ToSerializerValue(expression)}) + return u +} + +func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction { + u.where.Condition = condition + return u +} diff --git a/postgres/insert_statement.go b/postgres/insert_statement.go index a9dee25..e4c72f5 100644 --- a/postgres/insert_statement.go +++ b/postgres/insert_statement.go @@ -11,18 +11,18 @@ type InsertStatement interface { // Insert row of values, where value for each column is extracted from filed of structure data. // If data is not struct or there is no field for every column selected, this method will panic. MODEL(data interface{}) InsertStatement - MODELS(data interface{}) InsertStatement - QUERY(selectStatement SelectStatement) InsertStatement - RETURNING(projections ...jet.Projection) InsertStatement + ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict + + RETURNING(projections ...Projection) InsertStatement } func newInsertStatement(table WritableTable, columns []jet.Column) InsertStatement { newInsert := &insertStatementImpl{} newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, - &newInsert.Insert, &newInsert.ValuesQuery, &newInsert.Returning) + &newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnConflict, &newInsert.Returning) newInsert.Insert.Table = table newInsert.Insert.Columns = columns @@ -36,6 +36,7 @@ type insertStatementImpl struct { Insert jet.ClauseInsert ValuesQuery jet.ClauseValuesQuery Returning clauseReturning + OnConflict onConflictClause } func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { @@ -62,3 +63,11 @@ func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertState i.ValuesQuery.Query = selectStatement return i } + +func (i *insertStatementImpl) ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict { + i.OnConflict = onConflictClause{ + insertStatement: i, + indexExpressions: indexExpressions, + } + return &i.OnConflict +} diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index ccb1404..d80c1a1 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -1,6 +1,7 @@ package postgres import ( + "github.com/go-jet/jet/internal/jet" "github.com/stretchr/testify/assert" "testing" "time" @@ -13,15 +14,15 @@ func TestInvalidInsert(t *testing.T) { func TestInsertNilValue(t *testing.T) { assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), ` -INSERT INTO db.table1 (col1) VALUES - ($1); +INSERT INTO db.table1 (col1) +VALUES ($1); `, nil) } func TestInsertSingleValue(t *testing.T) { assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), ` -INSERT INTO db.table1 (col1) VALUES - ($1); +INSERT INTO db.table1 (col1) +VALUES ($1); `, int(1)) } @@ -29,8 +30,8 @@ func TestInsertWithColumnList(t *testing.T) { columnList := ColumnList{table3ColInt, table3StrCol} assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), ` -INSERT INTO db.table3 (col_int, col2) VALUES - ($1, $2); +INSERT INTO db.table3 (col_int, col2) +VALUES ($1, $2); `, 1, 3) } @@ -38,15 +39,15 @@ func TestInsertDate(t *testing.T) { date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) assertStatementSql(t, table1.INSERT(table1ColTime).VALUES(date), ` -INSERT INTO db.table1 (col_time) VALUES - ($1); +INSERT INTO db.table1 (col_time) +VALUES ($1); `, date) } func TestInsertMultipleValues(t *testing.T) { - assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` -INSERT INTO db.table1 (col1, col_float, col3) VALUES - ($1, $2, $3); + assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1ColBool).VALUES(1, 2, 3), ` +INSERT INTO db.table1 (col1, col_float, col_bool) +VALUES ($1, $2, $3); `, 1, 2, 3) } @@ -57,10 +58,10 @@ func TestInsertMultipleRows(t *testing.T) { VALUES(111, 222) assertStatementSql(t, stmt, ` -INSERT INTO db.table1 (col1, col_float) VALUES - ($1, $2), - ($3, $4), - ($5, $6); +INSERT INTO db.table1 (col1, col_float) +VALUES ($1, $2), + ($3, $4), + ($5, $6); `, 1, 2, 11, 22, 111, 222) } @@ -82,12 +83,12 @@ func TestInsertValuesFromModel(t *testing.T) { MODEL(&toInsert) expectedSQL := ` -INSERT INTO db.table1 (col1, col_float) VALUES - ($1, $2), - ($3, $4); +INSERT INTO db.table1 (col1, col_float) +VALUES ($1, $2), + ($3, $4); ` - assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) + assertStatementSql(t, stmt, expectedSQL, 1, float64(1.11), 1, float64(1.11)) } func TestInsertValuesFromModelColumnMismatch(t *testing.T) { @@ -139,9 +140,63 @@ func TestInsertDefaultValue(t *testing.T) { VALUES(DEFAULT, "two") var expectedSQL = ` -INSERT INTO db.table1 (col1, col_float) VALUES - (DEFAULT, $1); +INSERT INTO db.table1 (col1, col_float) +VALUES (DEFAULT, $1); ` assertStatementSql(t, stmt, expectedSQL, "two") } + +func TestInsert_ON_CONFLICT(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1ColBool). + VALUES("one", "two"). + VALUES("1", "2"). + VALUES("theta", "beta"). + ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE( + SET(table1ColBool, "12"). + SET(table2ColInt, 1). + SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))). + WHERE(table1Col1.GT(Int(2))), + ). + RETURNING(table1Col1, table1ColBool) + + assertDebugStatementSql(t, stmt, ` +INSERT INTO db.table1 (col1, col_bool) +VALUES ('one', 'two'), + ('1', '2'), + ('theta', 'beta') +ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE + SET col_bool = '12', + col_int = 1, + (col1, col_bool) = ROW(2, 'two') + WHERE table1.col1 > 2 +RETURNING table1.col1 AS "table1.col1", + table1.col_bool AS "table1.col_bool"; +`) +} + +func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1ColBool). + VALUES("one", "two"). + VALUES("1", "2"). + ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE( + SET(table1ColBool, "12"). + SET(table2ColInt, 1). + SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))). + WHERE(table1Col1.GT(Int(2))), + ). + RETURNING(table1Col1, table1ColBool) + + assertDebugStatementSql(t, stmt, ` +INSERT INTO db.table1 (col1, col_bool) +VALUES ('one', 'two'), + ('1', '2') +ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE + SET col_bool = '12', + col_int = 1, + (col1, col_bool) = ROW(2, 'two') + WHERE table1.col1 > 2 +RETURNING table1.col1 AS "table1.col1", + table1.col_bool AS "table1.col_bool"; +`) +} diff --git a/postgres/table.go b/postgres/table.go index dc0b266..bc2f5c2 100644 --- a/postgres/table.go +++ b/postgres/table.go @@ -109,10 +109,10 @@ type tableImpl struct { } // NewTable creates new table with schema Name, table Name and list of columns -func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table { +func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table { t := &tableImpl{ - SerializerTable: jet.NewTable(schemaName, name, column, columns...), + SerializerTable: jet.NewTable(schemaName, name, columns...), } t.readableTableInterfaceImpl.parent = t diff --git a/postgres/table_test.go b/postgres/table_test.go index 43aa096..3b6f498 100644 --- a/postgres/table_test.go +++ b/postgres/table_test.go @@ -5,9 +5,9 @@ import ( ) func TestJoinNilInputs(t *testing.T) { - assertClauseSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)), + assertSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)), "jet: right hand side of join operation is nil table") - assertClauseSerializeErr(t, table2.INNER_JOIN(table1, nil), + assertSerializeErr(t, table2.INNER_JOIN(table1, nil), "jet: join condition is nil") } diff --git a/postgres/update_statement.go b/postgres/update_statement.go index 6d54eb5..d96e1e9 100644 --- a/postgres/update_statement.go +++ b/postgres/update_statement.go @@ -61,7 +61,7 @@ type clauseSet struct { Values []jet.Serializer } -func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder) { +func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { out.NewLine() out.WriteString("SET") diff --git a/postgres/utils_test.go b/postgres/utils_test.go index 38c429c..cbca2bc 100644 --- a/postgres/utils_test.go +++ b/postgres/utils_test.go @@ -10,7 +10,6 @@ import ( var table1Col1 = IntegerColumn("col1") var table1ColInt = IntegerColumn("col_int") var table1ColFloat = FloatColumn("col_float") -var table1Col3 = IntegerColumn("col3") var table1ColTime = TimeColumn("col_time") var table1ColTimez = TimezColumn("col_timez") var table1ColTimestamp = TimestampColumn("col_timestamp") @@ -25,7 +24,6 @@ var table1 = NewTable( table1Col1, table1ColInt, table1ColFloat, - table1Col3, table1ColTime, table1ColTimez, table1ColBool, @@ -75,12 +73,16 @@ var table3 = NewTable( table3ColInt, table3StrCol) -func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { +func assertSerialize(t *testing.T, serializer jet.Serializer, query string, args ...interface{}) { + testutils.AssertSerialize(t, Dialect, serializer, query, args...) +} + +func assertClauseSerialize(t *testing.T, clause jet.Clause, query string, args ...interface{}) { testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) } -func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) { - testutils.AssertClauseSerializeErr(t, Dialect, clause, errString) +func assertSerializeErr(t *testing.T, serializer jet.Serializer, errString string) { + testutils.AssertSerializeErr(t, Dialect, serializer, errString) } func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) { @@ -88,5 +90,6 @@ func assertProjectionSerialize(t *testing.T, projection jet.Projection, query st } var assertStatementSql = testutils.AssertStatementSql +var assertDebugStatementSql = testutils.AssertDebugStatementSql var assertStatementSqlErr = testutils.AssertStatementSqlErr var assertPanicErr = testutils.AssertPanicErr diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index 8ae9347..59df9a1 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -76,7 +76,7 @@ func assertGeneratedFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") - testutils.AssertFileContent(t, genTestDir3+"/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilerFile) + testutils.AssertFileContent(t, genTestDir3+"/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilderFile) // Enums SQL Builder files enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum") @@ -128,7 +128,7 @@ import ( var Actor = newActorTable() -type ActorTable struct { +type actorTable struct { mysql.Table //Columns @@ -141,25 +141,38 @@ type ActorTable struct { MutableColumns mysql.ColumnList } +type ActorTable struct { + actorTable + + EXCLUDED actorTable +} + // creates new ActorTable with assigned alias func (a *ActorTable) AS(alias string) *ActorTable { aliasTable := newActorTable() - aliasTable.Table.AS(alias) - return aliasTable } func newActorTable() *ActorTable { + return &ActorTable{ + actorTable: newActorTableImpl("dvds", "actor"), + EXCLUDED: newActorTableImpl("", "excluded"), + } +} + +func newActorTableImpl(schemaName, tableName string) actorTable { var ( ActorIDColumn = mysql.IntegerColumn("actor_id") FirstNameColumn = mysql.StringColumn("first_name") LastNameColumn = mysql.StringColumn("last_name") LastUpdateColumn = mysql.TimestampColumn("last_update") + allColumns = mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn} + mutableColumns = mysql.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn} ) - return &ActorTable{ - Table: mysql.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn), + return actorTable{ + Table: mysql.NewTable(schemaName, tableName, allColumns...), //Columns ActorID: ActorIDColumn, @@ -167,8 +180,8 @@ func newActorTable() *ActorTable { LastName: LastNameColumn, LastUpdate: LastUpdateColumn, - AllColumns: mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn}, - MutableColumns: mysql.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } ` @@ -188,7 +201,7 @@ type Actor struct { } ` -var actorInfoSQLBuilerFile = ` +var actorInfoSQLBuilderFile = ` package view import ( @@ -197,7 +210,7 @@ import ( var ActorInfo = newActorInfoTable() -type ActorInfoTable struct { +type actorInfoTable struct { mysql.Table //Columns @@ -210,25 +223,38 @@ type ActorInfoTable struct { MutableColumns mysql.ColumnList } +type ActorInfoTable struct { + actorInfoTable + + EXCLUDED actorInfoTable +} + // creates new ActorInfoTable with assigned alias func (a *ActorInfoTable) AS(alias string) *ActorInfoTable { aliasTable := newActorInfoTable() - aliasTable.Table.AS(alias) - return aliasTable } func newActorInfoTable() *ActorInfoTable { + return &ActorInfoTable{ + actorInfoTable: newActorInfoTableImpl("dvds", "actor_info"), + EXCLUDED: newActorInfoTableImpl("", "excluded"), + } +} + +func newActorInfoTableImpl(schemaName, tableName string) actorInfoTable { var ( ActorIDColumn = mysql.IntegerColumn("actor_id") FirstNameColumn = mysql.StringColumn("first_name") LastNameColumn = mysql.StringColumn("last_name") FilmInfoColumn = mysql.StringColumn("film_info") + allColumns = mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn} + mutableColumns = mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn} ) - return &ActorInfoTable{ - Table: mysql.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + return actorInfoTable{ + Table: mysql.NewTable(schemaName, tableName, allColumns...), //Columns ActorID: ActorIDColumn, @@ -236,8 +262,8 @@ func newActorInfoTable() *ActorInfoTable { LastName: LastNameColumn, FilmInfo: FilmInfoColumn, - AllColumns: mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}, - MutableColumns: mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } ` diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index 4c86012..cb27f61 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -15,10 +15,10 @@ func TestInsertValues(t *testing.T) { cleanUpLinkTable(t) var expectedSQL = ` -INSERT INTO test_sample.link (id, url, name, description) VALUES - (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - (101, 'http://www.google.com', 'Google', DEFAULT), - (102, 'http://www.yahoo.com', 'Yahoo', NULL); +INSERT INTO test_sample.link (id, url, name, description) +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + (101, 'http://www.google.com', 'Google', DEFAULT), + (102, 'http://www.yahoo.com', 'Yahoo', NULL); ` insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). @@ -69,8 +69,8 @@ func TestInsertEmptyColumnList(t *testing.T) { cleanUpLinkTable(t) expectedSQL := ` -INSERT INTO test_sample.link VALUES - (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); +INSERT INTO test_sample.link +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); ` stmt := Link.INSERT(). @@ -97,8 +97,8 @@ INSERT INTO test_sample.link VALUES func TestInsertModelObject(t *testing.T) { cleanUpLinkTable(t) var expectedSQL = ` -INSERT INTO test_sample.link (url, name) VALUES - ('http://www.duckduckgo.com', 'Duck Duck go'); +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); ` linkData := model.Link{ @@ -119,8 +119,8 @@ INSERT INTO test_sample.link (url, name) VALUES func TestInsertModelObjectEmptyColumnList(t *testing.T) { cleanUpLinkTable(t) var expectedSQL = ` -INSERT INTO test_sample.link VALUES - (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); +INSERT INTO test_sample.link +VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); ` linkData := model.Link{ @@ -141,10 +141,10 @@ INSERT INTO test_sample.link VALUES func TestInsertModelsObject(t *testing.T) { expectedSQL := ` -INSERT INTO test_sample.link (url, name) VALUES - ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), - ('http://www.google.com', 'Google'), - ('http://www.yahoo.com', 'Yahoo'); +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), + ('http://www.google.com', 'Google'), + ('http://www.yahoo.com', 'Yahoo'); ` tutorial := model.Link{ @@ -177,11 +177,11 @@ INSERT INTO test_sample.link (url, name) VALUES func TestInsertUsingMutableColumns(t *testing.T) { var expectedSQL = ` -INSERT INTO test_sample.link (url, name, description) VALUES - ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - ('http://www.google.com', 'Google', NULL), - ('http://www.google.com', 'Google', NULL), - ('http://www.yahoo.com', 'Yahoo', NULL); +INSERT INTO test_sample.link (url, name, description) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + ('http://www.google.com', 'Google', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.yahoo.com', 'Yahoo', NULL); ` google := model.Link{ diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 54d7cc4..71f4f8f 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -161,7 +161,7 @@ import ( var Actor = newActorTable() -type ActorTable struct { +type actorTable struct { postgres.Table //Columns @@ -174,25 +174,38 @@ type ActorTable struct { MutableColumns postgres.ColumnList } +type ActorTable struct { + actorTable + + EXCLUDED actorTable +} + // creates new ActorTable with assigned alias func (a *ActorTable) AS(alias string) *ActorTable { aliasTable := newActorTable() - aliasTable.Table.AS(alias) - return aliasTable } func newActorTable() *ActorTable { + return &ActorTable{ + actorTable: newActorTableImpl("dvds", "actor"), + EXCLUDED: newActorTableImpl("", "excluded"), + } +} + +func newActorTableImpl(schemaName, tableName string) actorTable { var ( ActorIDColumn = postgres.IntegerColumn("actor_id") FirstNameColumn = postgres.StringColumn("first_name") LastNameColumn = postgres.StringColumn("last_name") LastUpdateColumn = postgres.TimestampColumn("last_update") + allColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn} + mutableColumns = postgres.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn} ) - return &ActorTable{ - Table: postgres.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn), + return actorTable{ + Table: postgres.NewTable(schemaName, tableName, allColumns...), //Columns ActorID: ActorIDColumn, @@ -200,8 +213,8 @@ func newActorTable() *ActorTable { LastName: LastNameColumn, LastUpdate: LastUpdateColumn, - AllColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn}, - MutableColumns: postgres.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } ` @@ -230,7 +243,7 @@ import ( var ActorInfo = newActorInfoTable() -type ActorInfoTable struct { +type actorInfoTable struct { postgres.Table //Columns @@ -243,25 +256,38 @@ type ActorInfoTable struct { MutableColumns postgres.ColumnList } +type ActorInfoTable struct { + actorInfoTable + + EXCLUDED actorInfoTable +} + // creates new ActorInfoTable with assigned alias func (a *ActorInfoTable) AS(alias string) *ActorInfoTable { aliasTable := newActorInfoTable() - aliasTable.Table.AS(alias) - return aliasTable } func newActorInfoTable() *ActorInfoTable { + return &ActorInfoTable{ + actorInfoTable: newActorInfoTableImpl("dvds", "actor_info"), + EXCLUDED: newActorInfoTableImpl("", "excluded"), + } +} + +func newActorInfoTableImpl(schemaName, tableName string) actorInfoTable { var ( ActorIDColumn = postgres.IntegerColumn("actor_id") FirstNameColumn = postgres.StringColumn("first_name") LastNameColumn = postgres.StringColumn("last_name") FilmInfoColumn = postgres.StringColumn("film_info") + allColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn} + mutableColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn} ) - return &ActorInfoTable{ - Table: postgres.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + return actorInfoTable{ + Table: postgres.NewTable(schemaName, tableName, allColumns...), //Columns ActorID: ActorIDColumn, @@ -269,8 +295,8 @@ func newActorInfoTable() *ActorInfoTable { LastName: LastNameColumn, FilmInfo: FilmInfoColumn, - AllColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}, - MutableColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } ` @@ -422,7 +448,7 @@ import ( var AllTypes = newAllTypesTable() -type AllTypesTable struct { +type allTypesTable struct { postgres.Table //Columns @@ -492,16 +518,27 @@ type AllTypesTable struct { MutableColumns postgres.ColumnList } +type AllTypesTable struct { + allTypesTable + + EXCLUDED allTypesTable +} + // creates new AllTypesTable with assigned alias func (a *AllTypesTable) AS(alias string) *AllTypesTable { aliasTable := newAllTypesTable() - aliasTable.Table.AS(alias) - return aliasTable } func newAllTypesTable() *AllTypesTable { + return &AllTypesTable{ + allTypesTable: newAllTypesTableImpl("test_sample", "all_types"), + EXCLUDED: newAllTypesTableImpl("", "excluded"), + } +} + +func newAllTypesTableImpl(schemaName, tableName string) allTypesTable { var ( SmallIntPtrColumn = postgres.IntegerColumn("small_int_ptr") SmallIntColumn = postgres.IntegerColumn("small_int") @@ -564,10 +601,12 @@ func newAllTypesTable() *AllTypesTable { JsonbArrayColumn = postgres.StringColumn("jsonb_array") TextMultiDimArrayPtrColumn = postgres.StringColumn("text_multi_dim_array_ptr") TextMultiDimArrayColumn = postgres.StringColumn("text_multi_dim_array") + allColumns = postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn} + mutableColumns = postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn} ) - return &AllTypesTable{ - Table: postgres.NewTable("test_sample", "all_types", SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn), + return allTypesTable{ + Table: postgres.NewTable(schemaName, tableName, allColumns...), //Columns SmallIntPtr: SmallIntPtrColumn, @@ -632,8 +671,8 @@ func newAllTypesTable() *AllTypesTable { TextMultiDimArrayPtr: TextMultiDimArrayPtrColumn, TextMultiDimArray: TextMultiDimArrayColumn, - AllColumns: postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn}, - MutableColumns: postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } ` diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index 38c135e..ecbd3c5 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -2,11 +2,13 @@ package postgres import ( "context" + "github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" "github.com/stretchr/testify/assert" + "math/rand" "testing" "time" ) @@ -15,10 +17,10 @@ func TestInsertValues(t *testing.T) { cleanUpLinkTable(t) var expectedSQL = ` -INSERT INTO test_sample.link (id, url, name, description) VALUES - (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - (101, 'http://www.google.com', 'Google', DEFAULT), - (102, 'http://www.yahoo.com', 'Yahoo', NULL) +INSERT INTO test_sample.link (id, url, name, description) +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + (101, 'http://www.google.com', 'Google', DEFAULT), + (102, 'http://www.yahoo.com', 'Yahoo', NULL) RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", @@ -77,8 +79,8 @@ func TestInsertEmptyColumnList(t *testing.T) { cleanUpLinkTable(t) expectedSQL := ` -INSERT INTO test_sample.link VALUES - (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); +INSERT INTO test_sample.link +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); ` stmt := Link.INSERT(). @@ -90,11 +92,128 @@ INSERT INTO test_sample.link VALUES AssertExec(t, stmt, 1) } +func TestInsertOnConflict(t *testing.T) { + + t.Run("do nothing", func(t *testing.T) { + employee := model.Employee{EmployeeID: rand.Int31()} + + stmt := Employee.INSERT(Employee.AllColumns). + MODEL(employee). + MODEL(employee). + ON_CONFLICT(Employee.EmployeeID).DO_NOTHING() + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) +VALUES ($1, $2, $3, $4, $5), + ($6, $7, $8, $9, $10) +ON CONFLICT (employee_id) DO NOTHING; +`) + AssertExec(t, stmt, 1) + }) + + t.Run("on constraint do nothing", func(t *testing.T) { + employee := model.Employee{EmployeeID: rand.Int31()} + + stmt := Employee.INSERT(Employee.AllColumns). + MODEL(employee). + MODEL(employee). + ON_CONFLICT().ON_CONSTRAINT("employee_pkey").DO_NOTHING() + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) +VALUES ($1, $2, $3, $4, $5), + ($6, $7, $8, $9, $10) +ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; +`) + AssertExec(t, stmt, 1) + }) + + t.Run("do update", func(t *testing.T) { + cleanUpLinkTable(t) + stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). + VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + ON_CONFLICT(Link.ID).DO_UPDATE( + SET(Link.ID, Link.EXCLUDED.ID). + SET(Link.URL, "http://www.postgresqltutorial2.com"), + ). + RETURNING(Link.AllColumns) + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO test_sample.link (id, url, name, description) +VALUES ($1, $2, $3, DEFAULT), + ($4, $5, $6, DEFAULT) +ON CONFLICT (id) DO UPDATE + SET id = excluded.id, + url = $7 +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"; +`) + + AssertExec(t, stmt, 2) + }) + + t.Run("on constraint do update", func(t *testing.T) { + cleanUpLinkTable(t) + stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). + VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + ON_CONFLICT().ON_CONSTRAINT("link_pkey").DO_UPDATE( + SET(Link.ID, Link.EXCLUDED.ID). + SET(Link.URL, "http://www.postgresqltutorial2.com"), + ). + RETURNING(Link.AllColumns) + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO test_sample.link (id, url, name, description) +VALUES ($1, $2, $3, DEFAULT), + ($4, $5, $6, DEFAULT) +ON CONFLICT ON CONSTRAINT link_pkey DO UPDATE + SET id = excluded.id, + url = $7 +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"; +`) + + AssertExec(t, stmt, 2) + }) + + t.Run("do update complex", func(t *testing.T) { + cleanUpLinkTable(t) + + stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). + VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + ON_CONFLICT(Link.ID).WHERE(Link.ID.MUL(Int(2)).GT(Int(10))).DO_UPDATE( + SET(Link.ID, SELECT(MAXi(Link.ID).ADD(Int(1))).FROM(Link)). + SET(ColumnList{Link.Name, Link.Description}, jet.ROW(Link.EXCLUDED.Name, String("new description"))). + WHERE(Link.Description.IS_NOT_NULL()), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO test_sample.link (id, url, name, description) +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT) +ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE + SET id = ( + SELECT MAX(link.id) + 1 + FROM test_sample.link + ), + (name, description) = ROW(excluded.name, 'new description') + WHERE link.description IS NOT NULL; +`) + + AssertExec(t, stmt, 1) + }) +} + func TestInsertModelObject(t *testing.T) { cleanUpLinkTable(t) var expectedSQL = ` -INSERT INTO test_sample.link (url, name) VALUES - ('http://www.duckduckgo.com', 'Duck Duck go'); +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); ` linkData := model.Link{ @@ -114,8 +233,8 @@ INSERT INTO test_sample.link (url, name) VALUES func TestInsertModelObjectEmptyColumnList(t *testing.T) { cleanUpLinkTable(t) var expectedSQL = ` -INSERT INTO test_sample.link VALUES - (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); +INSERT INTO test_sample.link +VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); ` linkData := model.Link{ @@ -135,10 +254,10 @@ INSERT INTO test_sample.link VALUES func TestInsertModelsObject(t *testing.T) { expectedSQL := ` -INSERT INTO test_sample.link (url, name) VALUES - ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), - ('http://www.google.com', 'Google'), - ('http://www.yahoo.com', 'Yahoo'); +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), + ('http://www.google.com', 'Google'), + ('http://www.yahoo.com', 'Yahoo'); ` tutorial := model.Link{ @@ -170,11 +289,11 @@ INSERT INTO test_sample.link (url, name) VALUES func TestInsertUsingMutableColumns(t *testing.T) { var expectedSQL = ` -INSERT INTO test_sample.link (url, name, description) VALUES - ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - ('http://www.google.com', 'Google', NULL), - ('http://www.google.com', 'Google', NULL), - ('http://www.yahoo.com', 'Yahoo', NULL); +INSERT INTO test_sample.link (url, name, description) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + ('http://www.google.com', 'Google', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.yahoo.com', 'Yahoo', NULL); ` google := model.Link{ diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index 49ab868..fd538d6 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -5,16 +5,19 @@ import ( "github.com/go-jet/jet/tests/dbconfig" _ "github.com/lib/pq" "github.com/pkg/profile" + "math/rand" "os" "os/exec" "strings" "testing" + "time" ) var db *sql.DB var testRoot string func TestMain(m *testing.M) { + rand.Seed(time.Now().Unix()) defer profile.Start().Stop() setTestRoot() diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 2cfbc85..b2bed1d 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -7,6 +7,7 @@ import ( . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" ) @@ -221,6 +222,10 @@ FROM test_sample.person; func TestSelecSelfJoin1(t *testing.T) { + // clean up + _, err := Employee.DELETE().WHERE(Employee.EmployeeID.GT(Int(100))).Exec(db) + require.NoError(t, err) + var expectedSQL = ` SELECT employee.employee_id AS "employee.employee_id", employee.first_name AS "employee.first_name", @@ -256,7 +261,7 @@ ORDER BY employee.employee_id; Manager *Manager } - err := query.Query(db, &dest) + err = query.Query(db, &dest) assert.NoError(t, err) assert.Equal(t, len(dest), 8) diff --git a/tests/postgres/util_test.go b/tests/postgres/util_test.go index e7a7816..c30a365 100644 --- a/tests/postgres/util_test.go +++ b/tests/postgres/util_test.go @@ -6,13 +6,14 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" ) func AssertExec(t *testing.T, stmt jet.Statement, rowsAffected int64) { res, err := stmt.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) rows, err := res.RowsAffected() assert.NoError(t, err) assert.Equal(t, rows, rowsAffected) From 241ea0d6d6d0779bc20b36ccd3b53604b1c65c8f Mon Sep 17 00:00:00 2001 From: go-jet Date: Mon, 13 Apr 2020 10:18:14 +0200 Subject: [PATCH 03/23] Gen files idempotence test clean up. --- internal/testutils/test_utils.go | 9 ++-- tests/mysql/generator_test.go | 36 ++++++++++++++-- tests/postgres/generator_test.go | 72 ++++++++++++++++++++++++++++---- 3 files changed, 99 insertions(+), 18 deletions(-) diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 942dab7..8673ff4 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -197,15 +197,12 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest inter } // AssertFileContent check if file content at filePath contains expectedContent text. -func AssertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) { +func AssertFileContent(t *testing.T, filePath string, expectedContent string) { enumFileData, err := ioutil.ReadFile(filePath) - assert.NoError(t, err) + require.NoError(t, err) - beginIndex := bytes.Index(enumFileData, []byte(contentBegin)) - - //AssertDeepEqual(t, string(enumFileData[beginIndex:]), expectedContent) - require.Equal(t, string(enumFileData[beginIndex:]), expectedContent) + require.Equal(t, "\n"+string(enumFileData), expectedContent) } // AssertFileNamesEqual check if all filesInfos are contained in fileNames diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index 59df9a1..f28c5c1 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -67,7 +67,7 @@ func assertGeneratedFiles(t *testing.T) { "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", "payment.go", "rental.go", "staff.go", "store.go") - testutils.AssertFileContent(t, genTestDir3+"/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile) + testutils.AssertFileContent(t, genTestDir3+"/dvds/table/actor.go", actorSQLBuilderFile) // View SQL Builder files viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view") @@ -76,14 +76,14 @@ func assertGeneratedFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") - testutils.AssertFileContent(t, genTestDir3+"/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilderFile) + testutils.AssertFileContent(t, genTestDir3+"/dvds/view/actor_info.go", actorInfoSQLBuilderFile) // Enums SQL Builder files enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum") assert.NoError(t, err) testutils.AssertFileNamesEqual(t, enumFiles, "film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go") - testutils.AssertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", "\npackage enum", mpaaRatingEnumFile) + testutils.AssertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", mpaaRatingEnumFile) // Model files modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model") @@ -96,10 +96,17 @@ func assertGeneratedFiles(t *testing.T) { "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") - testutils.AssertFileContent(t, genTestDir3+"/dvds/model/actor.go", "\npackage model", actorModelFile) + testutils.AssertFileContent(t, genTestDir3+"/dvds/model/actor.go", actorModelFile) } var mpaaRatingEnumFile = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + package enum import "github.com/go-jet/jet/mysql" @@ -120,6 +127,13 @@ var FilmRating = &struct { ` var actorSQLBuilderFile = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + package table import ( @@ -187,6 +201,13 @@ func newActorTableImpl(schemaName, tableName string) actorTable { ` var actorModelFile = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + package model import ( @@ -202,6 +223,13 @@ type Actor struct { ` var actorInfoSQLBuilderFile = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + package view import ( diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 71f4f8f..a9e771d 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -101,7 +101,7 @@ func assertGeneratedFiles(t *testing.T) { "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", "payment.go", "rental.go", "staff.go", "store.go") - testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile) + testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/table/actor.go", actorSQLBuilderFile) // View SQL Builder files viewSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/view") @@ -110,14 +110,14 @@ func assertGeneratedFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") - testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilderFile) + testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/view/actor_info.go", actorInfoSQLBuilderFile) // Enums SQL Builder files enumFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum") assert.NoError(t, err) testutils.AssertFileNamesEqual(t, enumFiles, "mpaa_rating.go") - testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", "\npackage enum", mpaaRatingEnumFile) + testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", mpaaRatingEnumFile) // Model files modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model") @@ -129,10 +129,17 @@ func assertGeneratedFiles(t *testing.T) { "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") - testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", "\npackage model", actorModelFile) + testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", actorModelFile) } var mpaaRatingEnumFile = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + package enum import "github.com/go-jet/jet/postgres" @@ -153,6 +160,13 @@ var MpaaRating = &struct { ` var actorSQLBuilderFile = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + package table import ( @@ -220,6 +234,13 @@ func newActorTableImpl(schemaName, tableName string) actorTable { ` var actorModelFile = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + package model import ( @@ -235,6 +256,13 @@ type Actor struct { ` var actorInfoSQLBuilderFile = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + package view import ( @@ -310,8 +338,8 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { assert.NoError(t, err) testutils.AssertFileNamesEqual(t, enumFiles, "mood.go", "level.go") - testutils.AssertFileContent(t, enumDir+"mood.go", "\npackage enum", moodEnumContent) - testutils.AssertFileContent(t, enumDir+"level.go", "\npackage enum", levelEnumContent) + testutils.AssertFileContent(t, enumDir+"mood.go", moodEnumContent) + testutils.AssertFileContent(t, enumDir+"level.go", levelEnumContent) modelFiles, err := ioutil.ReadDir(modelDir) assert.NoError(t, err) @@ -319,7 +347,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go", "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go") - testutils.AssertFileContent(t, modelDir+"all_types.go", "\npackage model", allTypesModelContent) + testutils.AssertFileContent(t, modelDir+"all_types.go", allTypesModelContent) tableFiles, err := ioutil.ReadDir(tableDir) assert.NoError(t, err) @@ -327,10 +355,17 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go", "person.go", "person_phone.go", "weird_names_table.go", "user.go") - testutils.AssertFileContent(t, tableDir+"all_types.go", "\npackage table", allTypesTableContent) + testutils.AssertFileContent(t, tableDir+"all_types.go", allTypesTableContent) } var moodEnumContent = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + package enum import "github.com/go-jet/jet/postgres" @@ -347,6 +382,13 @@ var Mood = &struct { ` var levelEnumContent = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + package enum import "github.com/go-jet/jet/postgres" @@ -367,6 +409,13 @@ var Level = &struct { ` var allTypesModelContent = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + package model import ( @@ -440,6 +489,13 @@ type AllTypes struct { ` var allTypesTableContent = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + package table import ( From 926b88ed40a6f8ea0de1cd30fd4f0bbbae791c5d Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 2 May 2020 22:15:38 +0200 Subject: [PATCH 04/23] Add reserved words for MySQL. --- internal/jet/sql_builder.go | 7 + internal/jet/utils.go | 2 +- internal/testutils/test_utils.go | 79 +++++++++ mysql/dialect.go | 265 +++++++++++++++++++++++++++++++ tests/mysql/alltypes_test.go | 228 ++++++++++++++++++-------- tests/postgres/alltypes_test.go | 48 +++--- tests/postgres/sample_test.go | 6 +- tests/postgres/scan_test.go | 38 ++--- tests/postgres/select_test.go | 16 +- tests/postgres/update_test.go | 4 +- tests/postgres/util_test.go | 56 +------ tests/testdata | 2 +- 12 files changed, 573 insertions(+), 178 deletions(-) diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 642f0cd..59b776f 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -163,10 +163,17 @@ func argToString(value interface{}) string { case time.Time: return stringQuote(string(pq.FormatTimestamp(bindVal))) default: + if strBindValue, ok := bindVal.(toStringInterface); ok { + return stringQuote(strBindValue.String()) + } panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String())) } } +type toStringInterface interface { + String() string +} + func integerTypesToString(value interface{}) string { switch bindVal := value.(type) { case int: diff --git a/internal/jet/utils.go b/internal/jet/utils.go index 50f4e13..2f301be 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -59,7 +59,7 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) { panic("jet: nil column in columns list") } - out.WriteString(col.Name()) + out.WriteIdentifier(col.Name()) } } diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 8673ff4..aa53b1e 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -7,6 +7,7 @@ import ( "github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/qrm" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "io/ioutil" @@ -14,6 +15,7 @@ import ( "path/filepath" "runtime" "testing" + "time" "github.com/google/go-cmp/cmp" ) @@ -224,3 +226,80 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) { assert.True(t, cmp.Equal(actual, expected), msg) } + +// BoolPtr returns address of bool parameter +func BoolPtr(b bool) *bool { + return &b +} + +// Int8Ptr returns address of int8 parameter +func Int8Ptr(i int8) *int8 { + return &i +} + +// UInt8Ptr returns address of uint8 parameter +func UInt8Ptr(i uint8) *uint8 { + return &i +} + +// Int16Ptr returns address of int16 parameter +func Int16Ptr(i int16) *int16 { + return &i +} + +// UInt16Ptr returns address of uint16 parameter +func UInt16Ptr(i uint16) *uint16 { + return &i +} + +// Int32Ptr returns address of int32 parameter +func Int32Ptr(i int32) *int32 { + return &i +} + +// UInt32Ptr returns address of uint32 parameter +func UInt32Ptr(i uint32) *uint32 { + return &i +} + +// Int64Ptr returns address of int64 parameter +func Int64Ptr(i int64) *int64 { + return &i +} + +// UInt64Ptr returns address of uint64 parameter +func UInt64Ptr(i uint64) *uint64 { + return &i +} + +// StringPtr returns address of string parameter +func StringPtr(s string) *string { + return &s +} + +// TimePtr returns address of time.Time parameter +func TimePtr(t time.Time) *time.Time { + return &t +} + +// ByteArrayPtr returns address of []byte parameter +func ByteArrayPtr(arr []byte) *[]byte { + return &arr +} + +// Float32Ptr returns address of float32 parameter +func Float32Ptr(f float32) *float32 { + return &f +} + +// Float64Ptr returns address of float64 parameter +func Float64Ptr(f float64) *float64 { + return &f +} + +// UUIDPtr returns address of uuid.UUID +func UUIDPtr(u string) *uuid.UUID { + newUUID := uuid.MustParse(u) + + return &newUUID +} diff --git a/mysql/dialect.go b/mysql/dialect.go index cfd452a..55862f9 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -26,6 +26,7 @@ func newDialect() jet.Dialect { ArgumentPlaceholder: func(int) string { return "?" }, + ReservedWords: reservedWords, } return jet.NewDialect(mySQLDialectParams) @@ -160,3 +161,267 @@ func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFun jet.Serialize(expressions[1], statement, out, options...) } } + +var reservedWords = []string{ + "ACCESSIBLE", + "ADD", + "ALL", + "ALTER", + "ANALYZE", + "AND", + "AS", + "ASC", + "ASENSITIVE", + "BEFORE", + "BETWEEN", + "BIGINT", + "BINARY", + "BLOB", + "BOTH", + "BY", + "CALL", + "CASCADE", + "CASE", + "CHANGE", + "CHAR", + "CHARACTER", + "CHECK", + "COLLATE", + "COLUMN", + "CONDITION", + "CONSTRAINT", + "CONTINUE", + "CONVERT", + "CREATE", + "CROSS", + "CUBE", + "CUME_DIST", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "CURSOR", + "DATABASE", + "DATABASES", + "DAY_HOUR", + "DAY_MICROSECOND", + "DAY_MINUTE", + "DAY_SECOND", + "DEC", + "DECIMAL", + "DECLARE", + "DEFAULT", + "DELAYED", + "DELETE", + "DENSE_RANK", + "DESC", + "DESCRIBE", + "DETERMINISTIC", + "DISTINCT", + "DISTINCTROW", + "DIV", + "DOUBLE", + "DROP", + "DUAL", + "EACH", + "ELSE", + "ELSEIF", + "EMPTY", + "ENCLOSED", + "ESCAPED", + "EXCEPT", + "EXISTS", + "EXIT", + "EXPLAIN", + "FALSE", + "FETCH", + "FIRST_VALUE", + "FLOAT", + "FLOAT4", + "FLOAT8", + "FOR", + "FORCE", + "FOREIGN", + "FROM", + "FULLTEXT", + "FUNCTION", + "GENERATED", + "GET", + "GRANT", + "GROUP", + "GROUPING", + "GROUPS", + "HAVING", + "HIGH_PRIORITY", + "HOUR_MICROSECOND", + "HOUR_MINUTE", + "HOUR_SECOND", + "IF", + "IGNORE", + "IN", + "INDEX", + "INFILE", + "INNER", + "INOUT", + "INSENSITIVE", + "INSERT", + "INT", + "INT1", + "INT2", + "INT3", + "INT4", + "INT8", + "INTEGER", + "INTERVAL", + "INTO", + "IO_AFTER_GTIDS", + "IO_BEFORE_GTIDS", + "IS", + "ITERATE", + "JOIN", + "JSON_TABLE", + "KEY", + "KEYS", + "KILL", + "LAG", + "LAST_VALUE", + "LATERAL", + "LEAD", + "LEADING", + "LEAVE", + "LEFT", + "LIKE", + "LIMIT", + "LINEAR", + "LINES", + "LOAD", + "LOCALTIME", + "LOCALTIMESTAMP", + "LOCK", + "LONG", + "LONGBLOB", + "LONGTEXT", + "LOOP", + "LOW_PRIORITY", + "MASTER_BIND", + "MASTER_SSL_VERIFY_SERVER_CERT", + "MATCH", + "MAXVALUE", + "MEDIUMBLOB", + "MEDIUMINT", + "MEDIUMTEXT", + "MIDDLEINT", + "MINUTE_MICROSECOND", + "MINUTE_SECOND", + "MOD", + "MODIFIES", + "NATURAL", + "NOT", + "NO_WRITE_TO_BINLOG", + "NTH_VALUE", + "NTILE", + "NULL", + "NUMERIC", + "OF", + "ON", + "OPTIMIZE", + "OPTIMIZER_COSTS", + "OPTION", + "OPTIONALLY", + "OR", + "ORDER", + "OUT", + "OUTER", + "OUTFILE", + "OVER", + "PARTITION", + "PERCENT_RANK", + "PRECISION", + "PRIMARY", + "PROCEDURE", + "PURGE", + "RANGE", + "RANK", + "READ", + "READS", + "READ_WRITE", + "REAL", + "RECURSIVE", + "REFERENCES", + "REGEXP", + "RELEASE", + "RENAME", + "REPEAT", + "REPLACE", + "REQUIRE", + "RESIGNAL", + "RESTRICT", + "RETURN", + "REVOKE", + "RIGHT", + "RLIKE", + "ROW", + "ROWS", + "ROW_NUMBER", + "SCHEMA", + "SCHEMAS", + "SECOND_MICROSECOND", + "SELECT", + "SENSITIVE", + "SEPARATOR", + "SET", + "SHOW", + "SIGNAL", + "SMALLINT", + "SPATIAL", + "SPECIFIC", + "SQL", + "SQLEXCEPTION", + "SQLSTATE", + "SQLWARNING", + "SQL_BIG_RESULT", + "SQL_CALC_FOUND_ROWS", + "SQL_SMALL_RESULT", + "SSL", + "STARTING", + "STORED", + "STRAIGHT_JOIN", + "SYSTEM", + "TABLE", + "TERMINATED", + "THEN", + "TINYBLOB", + "TINYINT", + "TINYTEXT", + "TO", + "TRAILING", + "TRIGGER", + "TRUE", + "UNDO", + "UNION", + "UNIQUE", + "UNLOCK", + "UNSIGNED", + "UPDATE", + "USAGE", + "USE", + "USING", + "UTC_DATE", + "UTC_TIME", + "UTC_TIMESTAMP", + "VALUES", + "VARBINARY", + "VARCHAR", + "VARCHARACTER", + "VARYING", + "VIRTUAL", + "WHEN", + "WHERE", + "WHILE", + "WINDOW", + "WITH", + "WRITE", + "XOR", + "YEAR_MONTH", + "ZEROFILL", +} diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 2c6768e..636359d 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -1,6 +1,9 @@ package mysql import ( + "fmt" + "github.com/stretchr/testify/require" + "strings" "testing" "time" @@ -95,23 +98,23 @@ func TestExpressionOperators(t *testing.T) { //fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, ` -SELECT all_types.integer IS NULL AS "result.is_null", + testutils.AssertStatementSql(t, query, strings.Replace(` +SELECT all_types.'integer' IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN (?, ?)) AS "result.in", (all_types.small_int_ptr IN (( - SELECT all_types.integer AS "all_types.integer" + SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types ))) AS "result.in_select", (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", (all_types.small_int_ptr NOT IN (( - SELECT all_types.integer AS "all_types.integer" + SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types ))) AS "result.not_in_select", DATABASE() FROM test_sample.all_types LIMIT ?; -`, int64(11), int64(22), int64(11), int64(22), int64(2)) +`, "'", "`", -1), int64(11), int64(22), int64(11), int64(22), int64(2)) var dest []struct { common.ExpressionTestResult `alias:"result.*"` @@ -261,45 +264,47 @@ func TestFloatOperators(t *testing.T) { queryStr, _ := query.Sql() - assert.Equal(t, queryStr, ` -SELECT (all_types.numeric = all_types.numeric) AS "eq1", - (all_types.decimal = ?) AS "eq2", - (all_types.real = ?) AS "eq3", - (NOT(all_types.numeric <=> all_types.numeric)) AS "distinct1", - (NOT(all_types.decimal <=> ?)) AS "distinct2", - (NOT(all_types.real <=> ?)) AS "distinct3", - (all_types.numeric <=> all_types.numeric) AS "not_distinct1", - (all_types.decimal <=> ?) AS "not_distinct2", - (all_types.real <=> ?) AS "not_distinct3", - (all_types.numeric < ?) AS "lt1", - (all_types.numeric < ?) AS "lt2", - (all_types.numeric > ?) AS "gt1", - (all_types.numeric > ?) AS "gt2", - TRUNCATE((all_types.decimal + all_types.decimal), ?) AS "add1", - TRUNCATE((all_types.decimal + ?), ?) AS "add2", - TRUNCATE((all_types.decimal - all_types.decimal_ptr), ?) AS "sub1", - TRUNCATE((all_types.decimal - ?), ?) AS "sub2", - TRUNCATE((all_types.decimal * all_types.decimal_ptr), ?) AS "mul1", - TRUNCATE((all_types.decimal * ?), ?) AS "mul2", - TRUNCATE((all_types.decimal / all_types.decimal_ptr), ?) AS "div1", - TRUNCATE((all_types.decimal / ?), ?) AS "div2", - TRUNCATE((all_types.decimal % all_types.decimal_ptr), ?) AS "mod1", - TRUNCATE((all_types.decimal % ?), ?) AS "mod2", - TRUNCATE(POW(all_types.decimal, all_types.decimal_ptr), ?) AS "pow1", - TRUNCATE(POW(all_types.decimal, ?), ?) AS "pow2", - TRUNCATE(ABS(all_types.decimal), ?) AS "abs", - TRUNCATE(POWER(all_types.decimal, ?), ?) AS "power", - TRUNCATE(SQRT(all_types.decimal), ?) AS "sqrt", - TRUNCATE(POWER(all_types.decimal, (? / ?)), ?) AS "cbrt", - CEIL(all_types.real) AS "ceil", - FLOOR(all_types.real) AS "floor", - ROUND(all_types.decimal) AS "round1", - ROUND(all_types.decimal, ?) AS "round2", - SIGN(all_types.real) AS "sign", - TRUNCATE(all_types.decimal, ?) AS "trunc" + //fmt.Println(queryStr) + + assert.Equal(t, queryStr, strings.Replace(` +SELECT (all_types.'numeric' = all_types.'numeric') AS "eq1", + (all_types.'decimal' = ?) AS "eq2", + (all_types.'real' = ?) AS "eq3", + (NOT(all_types.'numeric' <=> all_types.'numeric')) AS "distinct1", + (NOT(all_types.'decimal' <=> ?)) AS "distinct2", + (NOT(all_types.'real' <=> ?)) AS "distinct3", + (all_types.'numeric' <=> all_types.'numeric') AS "not_distinct1", + (all_types.'decimal' <=> ?) AS "not_distinct2", + (all_types.'real' <=> ?) AS "not_distinct3", + (all_types.'numeric' < ?) AS "lt1", + (all_types.'numeric' < ?) AS "lt2", + (all_types.'numeric' > ?) AS "gt1", + (all_types.'numeric' > ?) AS "gt2", + TRUNCATE((all_types.'decimal' + all_types.'decimal'), ?) AS "add1", + TRUNCATE((all_types.'decimal' + ?), ?) AS "add2", + TRUNCATE((all_types.'decimal' - all_types.decimal_ptr), ?) AS "sub1", + TRUNCATE((all_types.'decimal' - ?), ?) AS "sub2", + TRUNCATE((all_types.'decimal' * all_types.decimal_ptr), ?) AS "mul1", + TRUNCATE((all_types.'decimal' * ?), ?) AS "mul2", + TRUNCATE((all_types.'decimal' / all_types.decimal_ptr), ?) AS "div1", + TRUNCATE((all_types.'decimal' / ?), ?) AS "div2", + TRUNCATE((all_types.'decimal' % all_types.decimal_ptr), ?) AS "mod1", + TRUNCATE((all_types.'decimal' % ?), ?) AS "mod2", + TRUNCATE(POW(all_types.'decimal', all_types.decimal_ptr), ?) AS "pow1", + TRUNCATE(POW(all_types.'decimal', ?), ?) AS "pow2", + TRUNCATE(ABS(all_types.'decimal'), ?) AS "abs", + TRUNCATE(POWER(all_types.'decimal', ?), ?) AS "power", + TRUNCATE(SQRT(all_types.'decimal'), ?) AS "sqrt", + TRUNCATE(POWER(all_types.'decimal', (? / ?)), ?) AS "cbrt", + CEIL(all_types.'real') AS "ceil", + FLOOR(all_types.'real') AS "floor", + ROUND(all_types.'decimal') AS "round1", + ROUND(all_types.'decimal', ?) AS "round2", + SIGN(all_types.'real') AS "sign", + TRUNCATE(all_types.'decimal', ?) AS "trunc" FROM test_sample.all_types LIMIT ?; -`) +`, "'", "`", -1)) var dest []struct { common.FloatExpressionTestResult `alias:"."` @@ -568,7 +573,7 @@ func TestTimeExpressions(t *testing.T) { //fmt.Println(query.DebugSql()) - testutils.AssertDebugStatementSql(t, query, ` + testutils.AssertDebugStatementSql(t, query, strings.Replace(` SELECT CAST('20:34:58' AS TIME), all_types.time = all_types.time, all_types.time = CAST('23:06:06' AS TIME), @@ -589,7 +594,7 @@ SELECT CAST('20:34:58' AS TIME), all_types.time >= all_types.time, all_types.time >= CAST('14:26:36' AS TIME), all_types.time + INTERVAL 10 MINUTE, - all_types.time + INTERVAL all_types.integer MINUTE, + all_types.time + INTERVAL all_types.''integer'' MINUTE, all_types.time + INTERVAL 3 HOUR, all_types.time - INTERVAL 20 MINUTE, all_types.time - INTERVAL all_types.small_int MINUTE, @@ -598,7 +603,7 @@ SELECT CAST('20:34:58' AS TIME), CURRENT_TIME, CURRENT_TIME(3) FROM test_sample.all_types; -`, "20:34:58", "23:06:06", "22:06:06.011", "21:06:06.011111", "20:16:06", +`, "''", "`", -1), "20:34:58", "23:06:06", "22:06:06.011", "21:06:06.011111", "20:16:06", "19:26:06", "18:36:06", "17:46:06", "16:56:56", "15:16:46", "14:26:36") dest := []struct{}{} @@ -648,25 +653,25 @@ func TestDateExpressions(t *testing.T) { //fmt.Println(query.DebugSql()) - testutils.AssertStatementSql(t, query, ` -SELECT CAST(? AS DATE), + testutils.AssertDebugStatementSql(t, query, ` +SELECT CAST('2009-11-17' AS DATE), all_types.date = all_types.date, - all_types.date = CAST(? AS DATE), + all_types.date = CAST('2019-06-06' AS DATE), all_types.date_ptr != all_types.date, - all_types.date_ptr != CAST(? AS DATE), + all_types.date_ptr != CAST('2019-01-06' AS DATE), NOT(all_types.date <=> all_types.date), - NOT(all_types.date <=> CAST(? AS DATE)), + NOT(all_types.date <=> CAST('2019-02-06' AS DATE)), all_types.date <=> all_types.date, - all_types.date <=> CAST(? AS DATE), + all_types.date <=> CAST('2019-03-06' AS DATE), all_types.date < all_types.date, - all_types.date < CAST(? AS DATE), + all_types.date < CAST('2019-04-06' AS DATE), all_types.date <= all_types.date, - all_types.date <= CAST(? AS DATE), + all_types.date <= CAST('2019-05-05' AS DATE), all_types.date > all_types.date, - all_types.date > CAST(? AS DATE), + all_types.date > CAST('2019-01-04' AS DATE), all_types.date >= all_types.date, - all_types.date >= CAST(? AS DATE), - all_types.date + INTERVAL ? MINUTE_MICROSECOND, + all_types.date >= CAST('2019-02-03' AS DATE), + all_types.date + INTERVAL '10:20.000100' MINUTE_MICROSECOND, all_types.date + INTERVAL all_types.big_int MINUTE, all_types.date + INTERVAL 15 HOUR, all_types.date - INTERVAL 20 MINUTE, @@ -963,6 +968,91 @@ func TestINTERVAL(t *testing.T) { assert.NoError(t, err) } +func TestAllTypesInsert(t *testing.T) { + tx, err := db.Begin() + require.NoError(t, err) + + stmt := AllTypes.INSERT(AllTypes.AllColumns). + MODEL(toInsert) + + fmt.Println(stmt.DebugSql()) + + testutils.AssertExec(t, stmt, tx, 1) + + var dest model.AllTypes + err = AllTypes.SELECT(AllTypes.AllColumns). + WHERE(AllTypes.BigInt.EQ(Int(toInsert.BigInt))). + Query(tx, &dest) + + require.NoError(t, err) + require.Equal(t, toInsert.TinyInt, dest.TinyInt) + + err = tx.Rollback() + require.NoError(t, err) +} + +var toInsert = model.AllTypes{ + Boolean: false, + BooleanPtr: testutils.BoolPtr(true), + TinyInt: 1, + UTinyInt: 2, + SmallInt: 3, + USmallInt: 4, + MediumInt: 5, + UMediumInt: 6, + Integer: 7, + UInteger: 8, + BigInt: 9, + UBigInt: 1122334455, + TinyIntPtr: testutils.Int8Ptr(11), + UTinyIntPtr: testutils.UInt8Ptr(22), + SmallIntPtr: testutils.Int16Ptr(33), + USmallIntPtr: testutils.UInt16Ptr(44), + MediumIntPtr: testutils.Int32Ptr(55), + UMediumIntPtr: testutils.UInt32Ptr(66), + IntegerPtr: testutils.Int32Ptr(77), + UIntegerPtr: testutils.UInt32Ptr(88), + BigIntPtr: testutils.Int64Ptr(99), + UBigIntPtr: testutils.UInt64Ptr(111), + Decimal: 11.22, + DecimalPtr: testutils.Float64Ptr(33.44), + Numeric: 55.66, + NumericPtr: testutils.Float64Ptr(77.88), + Float: 99.00, + FloatPtr: testutils.Float64Ptr(11.22), + Double: 33.44, + DoublePtr: testutils.Float64Ptr(55.66), + Real: 77.88, + RealPtr: testutils.Float64Ptr(99.00), + Bit: "1", + BitPtr: testutils.StringPtr("0"), + Time: time.Date(0, 0, 0, 10, 11, 12, 100, &time.Location{}), + TimePtr: testutils.TimePtr(time.Date(0, 0, 0, 10, 11, 12, 100, time.UTC)), + Date: time.Now(), + DatePtr: testutils.TimePtr(time.Now()), + DateTime: time.Now(), + DateTimePtr: testutils.TimePtr(time.Now()), + Timestamp: time.Now(), + TimestampPtr: testutils.TimePtr(time.Now()), + Year: 2000, + YearPtr: testutils.Int16Ptr(2001), + Char: "abcd", + CharPtr: testutils.StringPtr("absd"), + VarChar: "abcd", + VarCharPtr: testutils.StringPtr("absd"), + Binary: []byte("1010"), + BinaryPtr: testutils.ByteArrayPtr([]byte("100001")), + VarBinary: []byte("1010"), + VarBinaryPtr: testutils.ByteArrayPtr([]byte("100001")), + Blob: []byte("large file"), + BlobPtr: testutils.ByteArrayPtr([]byte("very large file")), + Text: "some text", + TextPtr: testutils.StringPtr("text"), + Enum: model.AllTypesEnum_Value1, + JSON: "{}", + JSONPtr: testutils.StringPtr(`{"a": 1}`), +} + var allTypesJson = ` [ { @@ -1100,24 +1190,22 @@ func TestReservedWord(t *testing.T) { stmt := SELECT(User.AllColumns). FROM(User) - // NOTE: A word that follows a period in a qualified name must be an identifier, so it - // need not be quoted even if it is reserved - testutils.AssertDebugStatementSql(t, stmt, ` -SELECT user.column AS "user.column", - user.use AS "user.use", + testutils.AssertDebugStatementSql(t, stmt, strings.Replace(` +SELECT user.''column'' AS "user.column", + user.''use'' AS "user.use", user.ceil AS "user.ceil", user.commit AS "user.commit", - user.create AS "user.create", - user.default AS "user.default", - user.desc AS "user.desc", - user.empty AS "user.empty", - user.float AS "user.float", - user.join AS "user.join", - user.like AS "user.like", + user.''create'' AS "user.create", + user.''default'' AS "user.default", + user.''desc'' AS "user.desc", + user.''empty'' AS "user.empty", + user.''float'' AS "user.float", + user.''join'' AS "user.join", + user.''like'' AS "user.like", user.max AS "user.max", - user.rank AS "user.rank" + user.''rank'' AS "user.rank" FROM test_sample.user; -`) +`, "''", "`", -1)) var dest []model.User err := stmt.Query(db, &dest) diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 0b30080..f03f38b 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -1066,32 +1066,32 @@ LIMIT $6; } var allTypesRow0 = model.AllTypes{ - SmallIntPtr: Int16Ptr(14), + SmallIntPtr: testutils.Int16Ptr(14), SmallInt: 14, - IntegerPtr: Int32Ptr(300), + IntegerPtr: testutils.Int32Ptr(300), Integer: 300, - BigIntPtr: Int64Ptr(50000), + BigIntPtr: testutils.Int64Ptr(50000), BigInt: 5000, - DecimalPtr: Float64Ptr(1.11), + DecimalPtr: testutils.Float64Ptr(1.11), Decimal: 1.11, - NumericPtr: Float64Ptr(2.22), + NumericPtr: testutils.Float64Ptr(2.22), Numeric: 2.22, - RealPtr: Float32Ptr(5.55), + RealPtr: testutils.Float32Ptr(5.55), Real: 5.55, - DoublePrecisionPtr: Float64Ptr(11111111.22), + DoublePrecisionPtr: testutils.Float64Ptr(11111111.22), DoublePrecision: 11111111.22, Smallserial: 1, Serial: 1, Bigserial: 1, //MoneyPtr: nil, //Money: - VarCharPtr: StringPtr("ABBA"), + VarCharPtr: testutils.StringPtr("ABBA"), VarChar: "ABBA", - CharPtr: StringPtr("JOHN "), + CharPtr: testutils.StringPtr("JOHN "), Char: "JOHN ", - TextPtr: StringPtr("Some text"), + TextPtr: testutils.StringPtr("Some text"), Text: "Some text", - ByteaPtr: ByteArrayPtr([]byte("bytea")), + ByteaPtr: testutils.ByteArrayPtr([]byte("bytea")), Bytea: []byte("bytea"), TimestampzPtr: testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), Timestampz: *testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), @@ -1103,31 +1103,31 @@ var allTypesRow0 = model.AllTypes{ Timez: *testutils.TimeWithTimeZone("04:05:06 -0800"), TimePtr: testutils.TimeWithoutTimeZone("04:05:06"), Time: *testutils.TimeWithoutTimeZone("04:05:06"), - IntervalPtr: StringPtr("3 days 04:05:06"), + IntervalPtr: testutils.StringPtr("3 days 04:05:06"), Interval: "3 days 04:05:06", - BooleanPtr: BoolPtr(true), + BooleanPtr: testutils.BoolPtr(true), Boolean: false, - PointPtr: StringPtr("(2,3)"), - BitPtr: StringPtr("101"), + PointPtr: testutils.StringPtr("(2,3)"), + BitPtr: testutils.StringPtr("101"), Bit: "101", - BitVaryingPtr: StringPtr("101111"), + BitVaryingPtr: testutils.StringPtr("101111"), BitVarying: "101111", - TsvectorPtr: StringPtr("'supernova':1"), + TsvectorPtr: testutils.StringPtr("'supernova':1"), Tsvector: "'supernova':1", - UUIDPtr: UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), + UUIDPtr: testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), UUID: uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), - XMLPtr: StringPtr("abc"), + XMLPtr: testutils.StringPtr("abc"), XML: "abc", - JSONPtr: StringPtr(`{"a": 1, "b": 3}`), + JSONPtr: testutils.StringPtr(`{"a": 1, "b": 3}`), JSON: `{"a": 1, "b": 3}`, - JsonbPtr: StringPtr(`{"a": 1, "b": 3}`), + JsonbPtr: testutils.StringPtr(`{"a": 1, "b": 3}`), Jsonb: `{"a": 1, "b": 3}`, - IntegerArrayPtr: StringPtr("{1,2,3}"), + IntegerArrayPtr: testutils.StringPtr("{1,2,3}"), IntegerArray: "{1,2,3}", - TextArrayPtr: StringPtr("{breakfast,consulting}"), + TextArrayPtr: testutils.StringPtr("{breakfast,consulting}"), TextArray: "{breakfast,consulting}", JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`, - TextMultiDimArrayPtr: StringPtr("{{meeting,lunch},{training,presentation}}"), + TextMultiDimArrayPtr: testutils.StringPtr("{{meeting,lunch},{training,presentation}}"), TextMultiDimArray: "{{meeting,lunch},{training,presentation}}", } diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index b2bed1d..ea35c16 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -28,7 +28,7 @@ WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'; err := query.Query(db, &result) assert.NoError(t, err) assert.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) - testutils.AssertDeepEqual(t, result.UUIDPtr, UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) + testutils.AssertDeepEqual(t, result.UUIDPtr, testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) } func TestUUIDComplex(t *testing.T) { @@ -280,7 +280,7 @@ ORDER BY employee.employee_id; FirstName: "Salley", LastName: "Lester", EmploymentDate: testutils.TimestampWithTimeZone("1999-01-08 04:05:06 +0100 CET", 1), - ManagerID: Int32Ptr(3), + ManagerID: testutils.Int32Ptr(3), }) } @@ -322,7 +322,7 @@ FROM test_sample."WEIRD NAMES TABLE"; WeirdColumnName5: "Doe", WeirdColumnName6: "Doe", WeirdColumnName7: "Doe", - Weirdcolumnname8: StringPtr("Doe"), + Weirdcolumnname8: testutils.StringPtr("Doe"), WeirdColName9: "Doe", WeirdColuName10: "Doe", WeirdColuName11: "Doe", diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index 40e18f8..c5b6d3e 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -505,10 +505,10 @@ func TestScanToSlice(t *testing.T) { assert.NoError(t, err) assert.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0].Film, film1) - testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{Int32Ptr(1), Int32Ptr(2), Int32Ptr(3), Int32Ptr(4), - Int32Ptr(5), Int32Ptr(6), Int32Ptr(7), Int32Ptr(8)}) + testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{testutils.Int32Ptr(1), testutils.Int32Ptr(2), testutils.Int32Ptr(3), testutils.Int32Ptr(4), + testutils.Int32Ptr(5), testutils.Int32Ptr(6), testutils.Int32Ptr(7), testutils.Int32Ptr(8)}) testutils.AssertDeepEqual(t, dest[1].Film, film2) - testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{Int32Ptr(9), Int32Ptr(10)}) + testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{testutils.Int32Ptr(9), testutils.Int32Ptr(10)}) }) t.Run("complex struct 1", func(t *testing.T) { @@ -726,10 +726,10 @@ func TestStructScanAllNull(t *testing.T) { var address256 = model.Address{ AddressID: 256, Address: "1497 Yuzhou Drive", - Address2: StringPtr(""), + Address2: testutils.StringPtr(""), District: "England", CityID: 312, - PostalCode: StringPtr("3433"), + PostalCode: testutils.StringPtr("3433"), Phone: "246810237916", LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 09:45:30", 0), } @@ -737,10 +737,10 @@ var address256 = model.Address{ var addres517 = model.Address{ AddressID: 517, Address: "548 Uruapan Street", - Address2: StringPtr(""), + Address2: testutils.StringPtr(""), District: "Ontario", CityID: 312, - PostalCode: StringPtr("35653"), + PostalCode: testutils.StringPtr("35653"), Phone: "879347453467", LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 09:45:30", 0), } @@ -750,12 +750,12 @@ var customer256 = model.Customer{ StoreID: 2, FirstName: "Mattie", LastName: "Hoffman", - Email: StringPtr("mattie.hoffman@sakilacustomer.org"), + Email: testutils.StringPtr("mattie.hoffman@sakilacustomer.org"), AddressID: 256, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 0), - Active: Int32Ptr(1), + Active: testutils.Int32Ptr(1), } var customer512 = model.Customer{ @@ -763,12 +763,12 @@ var customer512 = model.Customer{ StoreID: 1, FirstName: "Cecil", LastName: "Vines", - Email: StringPtr("cecil.vines@sakilacustomer.org"), + Email: testutils.StringPtr("cecil.vines@sakilacustomer.org"), AddressID: 517, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 0), - Active: Int32Ptr(1), + Active: testutils.Int32Ptr(1), } var countryUk = model.Country{ @@ -801,32 +801,32 @@ var inventory2 = model.Inventory{ var film1 = model.Film{ FilmID: 1, Title: "Academy Dinosaur", - Description: StringPtr("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"), - ReleaseYear: Int32Ptr(2006), + Description: testutils.StringPtr("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"), + ReleaseYear: testutils.Int32Ptr(2006), LanguageID: 1, RentalDuration: 6, RentalRate: 0.99, - Length: Int16Ptr(86), + Length: testutils.Int16Ptr(86), ReplacementCost: 20.99, Rating: &pgRating, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: StringPtr("{\"Deleted Scenes\",\"Behind the Scenes\"}"), + SpecialFeatures: testutils.StringPtr("{\"Deleted Scenes\",\"Behind the Scenes\"}"), Fulltext: "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", } var film2 = model.Film{ FilmID: 2, Title: "Ace Goldfinger", - Description: StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), - ReleaseYear: Int32Ptr(2006), + Description: testutils.StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), + ReleaseYear: testutils.Int32Ptr(2006), LanguageID: 1, RentalDuration: 3, RentalRate: 4.99, - Length: Int16Ptr(48), + Length: testutils.Int16Ptr(48), ReplacementCost: 12.99, Rating: &gRating, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: StringPtr(`{Trailers,"Deleted Scenes"}`), + SpecialFeatures: testutils.StringPtr(`{Trailers,"Deleted Scenes"}`), Fulltext: `'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14`, } diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index b641ec0..28afd8c 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -982,16 +982,16 @@ ORDER BY film.film_id ASC; testutils.AssertDeepEqual(t, maxRentalRateFilms[0], model.Film{ FilmID: 2, Title: "Ace Goldfinger", - Description: StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), - ReleaseYear: Int32Ptr(2006), + Description: testutils.StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), + ReleaseYear: testutils.Int32Ptr(2006), LanguageID: 1, RentalRate: 4.99, - Length: Int16Ptr(48), + Length: testutils.Int16Ptr(48), ReplacementCost: 12.99, Rating: &gRating, RentalDuration: 3, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: StringPtr("{Trailers,\"Deleted Scenes\"}"), + SpecialFeatures: testutils.StringPtr("{Trailers,\"Deleted Scenes\"}"), Fulltext: "'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14", }) } @@ -1130,11 +1130,11 @@ ORDER BY customer_payment_sum."amount_sum" ASC; FirstName: "Brian", LastName: "Wyman", AddressID: 323, - Email: StringPtr("brian.wyman@sakilacustomer.org"), + Email: testutils.StringPtr("brian.wyman@sakilacustomer.org"), Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: Int32Ptr(1), + Active: testutils.Int32Ptr(1), }) assert.Equal(t, customersWithAmounts[0].AmountSum, 27.93) @@ -1846,8 +1846,8 @@ func TestDynamicCondition(t *testing.T) { Active *bool } - request.CustomerID = Int64Ptr(1) - request.Active = BoolPtr(true) + request.CustomerID = testutils.Int64Ptr(1) + request.Active = testutils.BoolPtr(true) // ... diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index 57f579e..ca07332 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -154,7 +154,7 @@ WHERE link.id = 0; ` testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) - assertExecErr(t, stmt, "pq: number of columns does not match number of values") + testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") } func TestUpdateWithModelData(t *testing.T) { @@ -241,7 +241,7 @@ WHERE link.id = 201; ` testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, nil, int64(201)) - assertExecErr(t, stmt, "pq: number of columns does not match number of values") + testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") } func TestUpdateQueryContext(t *testing.T) { diff --git a/tests/postgres/util_test.go b/tests/postgres/util_test.go index c30a365..b2d5452 100644 --- a/tests/postgres/util_test.go +++ b/tests/postgres/util_test.go @@ -4,7 +4,6 @@ import ( "github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "testing" @@ -19,60 +18,17 @@ func AssertExec(t *testing.T, stmt jet.Statement, rowsAffected int64) { assert.Equal(t, rows, rowsAffected) } -func assertExecErr(t *testing.T, stmt jet.Statement, errorStr string) { - _, err := stmt.Exec(db) - - assert.Error(t, err, errorStr) -} - -func BoolPtr(b bool) *bool { - return &b -} - -func Int16Ptr(i int16) *int16 { - return &i -} - -func Int32Ptr(i int32) *int32 { - return &i -} - -func Int64Ptr(i int64) *int64 { - return &i -} - -func StringPtr(s string) *string { - return &s -} - -func ByteArrayPtr(arr []byte) *[]byte { - return &arr -} - -func Float32Ptr(f float32) *float32 { - return &f -} -func Float64Ptr(f float64) *float64 { - return &f -} - -func UUIDPtr(u string) *uuid.UUID { - newUUID := uuid.MustParse(u) - - return &newUUID -} - var customer0 = model.Customer{ CustomerID: 1, StoreID: 1, FirstName: "Mary", LastName: "Smith", - Email: StringPtr("mary.smith@sakilacustomer.org"), + Email: testutils.StringPtr("mary.smith@sakilacustomer.org"), AddressID: 5, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: Int32Ptr(1), + Active: testutils.Int32Ptr(1), } var customer1 = model.Customer{ @@ -80,12 +36,12 @@ var customer1 = model.Customer{ StoreID: 1, FirstName: "Patricia", LastName: "Johnson", - Email: StringPtr("patricia.johnson@sakilacustomer.org"), + Email: testutils.StringPtr("patricia.johnson@sakilacustomer.org"), AddressID: 6, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: Int32Ptr(1), + Active: testutils.Int32Ptr(1), } var lastCustomer = model.Customer{ @@ -93,10 +49,10 @@ var lastCustomer = model.Customer{ StoreID: 2, FirstName: "Austin", LastName: "Cintron", - Email: StringPtr("austin.cintron@sakilacustomer.org"), + Email: testutils.StringPtr("austin.cintron@sakilacustomer.org"), AddressID: 605, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: Int32Ptr(1), + Active: testutils.Int32Ptr(1), } diff --git a/tests/testdata b/tests/testdata index 889e07c..1745be3 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 889e07c0ebaf6b4021e31cce29b5861eb5c8cc17 +Subproject commit 1745be34a649c0f37d0d31d7c0352a1248ace2dc From 30284af33ec8b1e8e17317af27da6bc8b64b6443 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 2 May 2020 22:26:08 +0200 Subject: [PATCH 05/23] Fix MariaDB build. --- tests/mysql/alltypes_test.go | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 636359d..2f269ac 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -1033,24 +1033,24 @@ var toInsert = model.AllTypes{ DateTime: time.Now(), DateTimePtr: testutils.TimePtr(time.Now()), Timestamp: time.Now(), - TimestampPtr: testutils.TimePtr(time.Now()), - Year: 2000, - YearPtr: testutils.Int16Ptr(2001), - Char: "abcd", - CharPtr: testutils.StringPtr("absd"), - VarChar: "abcd", - VarCharPtr: testutils.StringPtr("absd"), - Binary: []byte("1010"), - BinaryPtr: testutils.ByteArrayPtr([]byte("100001")), - VarBinary: []byte("1010"), - VarBinaryPtr: testutils.ByteArrayPtr([]byte("100001")), - Blob: []byte("large file"), - BlobPtr: testutils.ByteArrayPtr([]byte("very large file")), - Text: "some text", - TextPtr: testutils.StringPtr("text"), - Enum: model.AllTypesEnum_Value1, - JSON: "{}", - JSONPtr: testutils.StringPtr(`{"a": 1}`), + //TimestampPtr: testutils.TimePtr(time.Now()), // TODO: build fails for MariaDB + Year: 2000, + YearPtr: testutils.Int16Ptr(2001), + Char: "abcd", + CharPtr: testutils.StringPtr("absd"), + VarChar: "abcd", + VarCharPtr: testutils.StringPtr("absd"), + Binary: []byte("1010"), + BinaryPtr: testutils.ByteArrayPtr([]byte("100001")), + VarBinary: []byte("1010"), + VarBinaryPtr: testutils.ByteArrayPtr([]byte("100001")), + Blob: []byte("large file"), + BlobPtr: testutils.ByteArrayPtr([]byte("very large file")), + Text: "some text", + TextPtr: testutils.StringPtr("text"), + Enum: model.AllTypesEnum_Value1, + JSON: "{}", + JSONPtr: testutils.StringPtr(`{"a": 1}`), } var allTypesJson = ` From 980b9b6aac752282fd2310e85fc83c8390efc20c Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 3 May 2020 20:46:21 +0200 Subject: [PATCH 06/23] Add ON DUPLICATE KEY UPDATE support (MySQL). --- internal/jet/clause.go | 9 ++-- internal/jet/column.go | 53 ----------------------- internal/jet/column_assigment.go | 20 +++++++++ internal/jet/column_list.go | 60 ++++++++++++++++++++++++++ internal/jet/column_types.go | 71 +++++++++++++++++++++++++++++++ internal/testutils/test_utils.go | 8 ++-- mysql/insert_statement.go | 61 +++++++++++++++++++------- mysql/insert_statement_test.go | 47 ++++++++++++++++++++ mysql/types.go | 3 ++ mysql/utils_test.go | 6 ++- postgres/clause_test.go | 11 ++--- postgres/conflict_action.go | 10 +---- postgres/insert_statement_test.go | 20 ++++----- postgres/types.go | 3 ++ tests/mysql/alltypes_test.go | 47 ++++++++++++++++++++ tests/mysql/insert_test.go | 42 ++++++++++++++++++ tests/mysql/main_test.go | 3 ++ tests/postgres/insert_test.go | 23 ++++++---- 18 files changed, 388 insertions(+), 109 deletions(-) create mode 100644 internal/jet/column_assigment.go create mode 100644 internal/jet/column_list.go diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 90d194e..349c6dc 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -518,7 +518,7 @@ type SetPair struct { } // SetClause clause -type SetClause []SetPair +type SetClause []ColumnAssigment // Serialize for SetClause func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { @@ -526,16 +526,15 @@ func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, optio out.WriteString("SET") out.IncreaseIdent(4) - for i, pair := range s { + for i, assigment := range s { if i > 0 { out.WriteString(",") out.NewLine() } - pair.Column.serialize(statementType, out, ShortName.WithFallTrough(options)...) - out.WriteString("=") - pair.Value.serialize(statementType, out, FallTrough(options)...) + assigment.serialize(statementType, out, FallTrough(options)...) } + out.DecreaseIdent(4) } diff --git a/internal/jet/column.go b/internal/jet/column.go index 0fd59be..3e4c300 100644 --- a/internal/jet/column.go +++ b/internal/jet/column.go @@ -115,56 +115,3 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder out.WriteIdentifier(c.name) } } - -//------------------------------------------------------// - -// ColumnList is a helper type to support list of columns as single projection -type ColumnList []ColumnExpression - -func (cl ColumnList) fromImpl(subQuery SelectTable) Projection { - newProjectionList := ProjectionList{} - - for _, column := range cl { - newProjectionList = append(newProjectionList, column.fromImpl(subQuery)) - } - - return newProjectionList -} - -func (cl ColumnList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString("(") - for i, column := range cl { - if i > 0 { - out.WriteString(", ") - } - column.serialize(statement, out, FallTrough(options)...) - } - out.WriteString(")") -} - -func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) { - projections := ColumnListToProjectionList(cl) - - SerializeProjectionList(statement, projections, out) -} - -// dummy column interface implementation - -// Name is placeholder for ColumnList to implement Column interface -func (cl ColumnList) Name() string { return "" } - -// TableName is placeholder for ColumnList to implement Column interface -func (cl ColumnList) TableName() string { return "" } -func (cl ColumnList) setTableName(name string) {} -func (cl ColumnList) setSubQuery(subQuery SelectTable) {} -func (cl ColumnList) defaultAlias() string { return "" } - -// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName -func SetTableName(columnExpression ColumnExpression, tableName string) { - columnExpression.setTableName(tableName) -} - -// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery -func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) { - columnExpression.setSubQuery(subQuery) -} diff --git a/internal/jet/column_assigment.go b/internal/jet/column_assigment.go new file mode 100644 index 0000000..440f3eb --- /dev/null +++ b/internal/jet/column_assigment.go @@ -0,0 +1,20 @@ +package jet + +// ColumnAssigment is interface wrapper around column assigment +type ColumnAssigment interface { + Serializer + isColumnAssigment() +} + +type columnAssigmentImpl struct { + column ColumnSerializer + expression Expression +} + +func (a columnAssigmentImpl) isColumnAssigment() {} + +func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + a.column.serialize(statement, out, ShortName.WithFallTrough(options)...) + out.WriteString("=") + a.expression.serialize(statement, out, FallTrough(options)...) +} diff --git a/internal/jet/column_list.go b/internal/jet/column_list.go new file mode 100644 index 0000000..8483c76 --- /dev/null +++ b/internal/jet/column_list.go @@ -0,0 +1,60 @@ +package jet + +// ColumnList is a helper type to support list of columns as single projection +type ColumnList []ColumnExpression + +// SET creates column assigment for each column in column list. expression should be created by ROW function +func (cl ColumnList) SET(expression Expression) ColumnAssigment { + return columnAssigmentImpl{ + column: cl, + expression: expression, + } +} + +func (cl ColumnList) fromImpl(subQuery SelectTable) Projection { + newProjectionList := ProjectionList{} + + for _, column := range cl { + newProjectionList = append(newProjectionList, column.fromImpl(subQuery)) + } + + return newProjectionList +} + +func (cl ColumnList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteString("(") + for i, column := range cl { + if i > 0 { + out.WriteString(", ") + } + column.serialize(statement, out, FallTrough(options)...) + } + out.WriteString(")") +} + +func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) { + projections := ColumnListToProjectionList(cl) + + SerializeProjectionList(statement, projections, out) +} + +// dummy column interface implementation + +// Name is placeholder for ColumnList to implement Column interface +func (cl ColumnList) Name() string { return "" } + +// TableName is placeholder for ColumnList to implement Column interface +func (cl ColumnList) TableName() string { return "" } +func (cl ColumnList) setTableName(name string) {} +func (cl ColumnList) setSubQuery(subQuery SelectTable) {} +func (cl ColumnList) defaultAlias() string { return "" } + +// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName +func SetTableName(columnExpression ColumnExpression, tableName string) { + columnExpression.setTableName(tableName) +} + +// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery +func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) { + columnExpression.setSubQuery(subQuery) +} diff --git a/internal/jet/column_types.go b/internal/jet/column_types.go index 58f6751..a606a4e 100644 --- a/internal/jet/column_types.go +++ b/internal/jet/column_types.go @@ -6,6 +6,7 @@ type ColumnBool interface { Column From(subQuery SelectTable) ColumnBool + SET(boolExp BoolExpression) ColumnAssigment } type boolColumnImpl struct { @@ -21,6 +22,13 @@ func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool { return newBoolColumn } +func (i *boolColumnImpl) SET(boolExp BoolExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: boolExp, + } +} + // BoolColumn creates named bool column. func BoolColumn(name string) ColumnBool { boolColumn := &boolColumnImpl{} @@ -38,6 +46,7 @@ type ColumnFloat interface { Column From(subQuery SelectTable) ColumnFloat + SET(floatExp FloatExpression) ColumnAssigment } type floatColumnImpl struct { @@ -53,6 +62,13 @@ func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat { return newFloatColumn } +func (i *floatColumnImpl) SET(floatExp FloatExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: floatExp, + } +} + // FloatColumn creates named float column. func FloatColumn(name string) ColumnFloat { floatColumn := &floatColumnImpl{} @@ -70,6 +86,7 @@ type ColumnInteger interface { Column From(subQuery SelectTable) ColumnInteger + SET(intExp IntegerExpression) ColumnAssigment } type integerColumnImpl struct { @@ -86,6 +103,13 @@ func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger { return newIntColumn } +func (i *integerColumnImpl) SET(intExp IntegerExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: intExp, + } +} + // IntegerColumn creates named integer column. func IntegerColumn(name string) ColumnInteger { integerColumn := &integerColumnImpl{} @@ -104,6 +128,7 @@ type ColumnString interface { Column From(subQuery SelectTable) ColumnString + SET(stringExp StringExpression) ColumnAssigment } type stringColumnImpl struct { @@ -120,6 +145,13 @@ func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString { return newStrColumn } +func (i *stringColumnImpl) SET(stringExp StringExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: stringExp, + } +} + // StringColumn creates named string column. func StringColumn(name string) ColumnString { stringColumn := &stringColumnImpl{} @@ -137,6 +169,7 @@ type ColumnTime interface { Column From(subQuery SelectTable) ColumnTime + SET(timeExp TimeExpression) ColumnAssigment } type timeColumnImpl struct { @@ -152,6 +185,13 @@ func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime { return newTimeColumn } +func (i *timeColumnImpl) SET(timeExp TimeExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: timeExp, + } +} + // TimeColumn creates named time column func TimeColumn(name string) ColumnTime { timeColumn := &timeColumnImpl{} @@ -183,6 +223,13 @@ func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez { return newTimezColumn } +func (i *timezColumnImpl) SET(timezExp TimezExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: timezExp, + } +} + // TimezColumn creates named time with time zone column. func TimezColumn(name string) ColumnTimez { timezColumn := &timezColumnImpl{} @@ -200,6 +247,7 @@ type ColumnTimestamp interface { Column From(subQuery SelectTable) ColumnTimestamp + SET(timestampExp TimestampExpression) ColumnAssigment } type timestampColumnImpl struct { @@ -215,6 +263,13 @@ func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp { return newTimestampColumn } +func (i *timestampColumnImpl) SET(timestampExp TimestampExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: timestampExp, + } +} + // TimestampColumn creates named timestamp column func TimestampColumn(name string) ColumnTimestamp { timestampColumn := ×tampColumnImpl{} @@ -232,6 +287,7 @@ type ColumnTimestampz interface { Column From(subQuery SelectTable) ColumnTimestampz + SET(timestampzExp TimestampzExpression) ColumnAssigment } type timestampzColumnImpl struct { @@ -247,6 +303,13 @@ func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz { return newTimestampzColumn } +func (i *timestampzColumnImpl) SET(timestampzExp TimestampzExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: timestampzExp, + } +} + // TimestampzColumn creates named timestamp with time zone column. func TimestampzColumn(name string) ColumnTimestampz { timestampzColumn := ×tampzColumnImpl{} @@ -264,6 +327,7 @@ type ColumnDate interface { Column From(subQuery SelectTable) ColumnDate + SET(dateExp DateExpression) ColumnAssigment } type dateColumnImpl struct { @@ -279,6 +343,13 @@ func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate { return newDateColumn } +func (i *dateColumnImpl) SET(dateExp DateExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: dateExp, + } +} + // DateColumn creates named date column. func DateColumn(name string) ColumnDate { dateColumn := &dateColumnImpl{} diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index aa53b1e..06035dd 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -24,12 +24,12 @@ import ( func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) { res, err := stmt.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) rows, err := res.RowsAffected() - assert.NoError(t, err) + require.NoError(t, err) if len(rowsAffected) > 0 { - assert.Equal(t, rows, rowsAffected[0]) + require.Equal(t, rows, rowsAffected[0]) } } @@ -224,7 +224,7 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st // AssertDeepEqual checks if actual and expected objects are deeply equal. func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) { - assert.True(t, cmp.Equal(actual, expected), msg) + require.True(t, cmp.Equal(actual, expected), msg) } // BoolPtr returns address of bool parameter diff --git a/mysql/insert_statement.go b/mysql/insert_statement.go index a21089c..a4ecc94 100644 --- a/mysql/insert_statement.go +++ b/mysql/insert_statement.go @@ -13,13 +13,15 @@ type InsertStatement interface { MODEL(data interface{}) InsertStatement MODELS(data interface{}) InsertStatement + ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement + QUERY(selectStatement SelectStatement) InsertStatement } func newInsertStatement(table Table, columns []jet.Column) InsertStatement { newInsert := &insertStatementImpl{} newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, - &newInsert.Insert, &newInsert.ValuesQuery) + &newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnDuplicateKey) newInsert.Insert.Table = table newInsert.Insert.Columns = columns @@ -30,26 +32,55 @@ func newInsertStatement(table Table, columns []jet.Column) InsertStatement { type insertStatementImpl struct { jet.SerializerStatement - Insert jet.ClauseInsert - ValuesQuery jet.ClauseValuesQuery + Insert jet.ClauseInsert + ValuesQuery jet.ClauseValuesQuery + OnDuplicateKey onDuplicateKeyUpdateClause } -func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { - i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values)) - return i +func (is *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { + is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values)) + return is } -func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement { - i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowFromModel(i.Insert.GetColumns(), data)) - return i +func (is *insertStatementImpl) MODEL(data interface{}) InsertStatement { + is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromModel(is.Insert.GetColumns(), data)) + return is } -func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement { - i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowsFromModels(i.Insert.GetColumns(), data)...) - return i +func (is *insertStatementImpl) MODELS(data interface{}) InsertStatement { + is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowsFromModels(is.Insert.GetColumns(), data)...) + return is } -func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement { - i.ValuesQuery.Query = selectStatement - return i +func (is *insertStatementImpl) ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement { + is.OnDuplicateKey = assigments + return is +} + +func (is *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement { + is.ValuesQuery.Query = selectStatement + return is +} + +type onDuplicateKeyUpdateClause []jet.ColumnAssigment + +// Serialize for SetClause +func (s onDuplicateKeyUpdateClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(s) == 0 { + return + } + out.NewLine() + out.WriteString("ON DUPLICATE KEY UPDATE") + out.IncreaseIdent(24) + + for i, assigment := range s { + if i > 0 { + out.WriteString(",") + out.NewLine() + } + + jet.Serialize(assigment, statementType, out, jet.ShortName.WithFallTrough(options)...) + } + + out.DecreaseIdent(24) } diff --git a/mysql/insert_statement_test.go b/mysql/insert_statement_test.go index 65c8fba..95814d2 100644 --- a/mysql/insert_statement_test.go +++ b/mysql/insert_statement_test.go @@ -133,3 +133,50 @@ VALUES (DEFAULT, ?); assertStatementSql(t, stmt, expectedSQL, "two") } + +func TestInsertOnDuplicateKeyUpdate(t *testing.T) { + stmt := func() InsertStatement { + return table1.INSERT(table1Col1, table1ColFloat). + VALUES(DEFAULT, "two") + } + + t.Run("empty list", func(t *testing.T) { + stmt := stmt().ON_DUPLICATE_KEY_UPDATE() + assertStatementSql(t, stmt, ` +INSERT INTO db.table1 (col1, col_float) +VALUES (DEFAULT, ?); +`, "two") + }) + + t.Run("one set", func(t *testing.T) { + stmt := stmt().ON_DUPLICATE_KEY_UPDATE(table1ColFloat.SET(Float(11.1))) + assertStatementSql(t, stmt, ` +INSERT INTO db.table1 (col1, col_float) +VALUES (DEFAULT, ?) +ON DUPLICATE KEY UPDATE col_float = ?; +`, "two", 11.1) + }) + + t.Run("all types set", func(t *testing.T) { + stmt := stmt().ON_DUPLICATE_KEY_UPDATE( + table1ColBool.SET(Bool(true)), + table1ColInt.SET(Int(11)), + table1ColFloat.SET(Float(11.1)), + table1ColString.SET(String("str")), + table1ColTime.SET(Time(11, 23, 11)), + table1ColTimestamp.SET(Timestamp(2020, 1, 22, 3, 4, 5)), + table1ColDate.SET(Date(2020, 12, 1)), + ) + assertStatementSql(t, stmt, ` +INSERT INTO db.table1 (col1, col_float) +VALUES (DEFAULT, ?) +ON DUPLICATE KEY UPDATE col_bool = ?, + col_int = ?, + col_float = ?, + col_string = ?, + col_time = CAST(? AS TIME), + col_timestamp = TIMESTAMP(?), + col_date = CAST(? AS DATE); +`, "two", true, int64(11), 11.1, "str", "11:23:11", "2020-01-22 03:04:05", "2020-12-01") + }) +} diff --git a/mysql/types.go b/mysql/types.go index 4ef84b4..908fce5 100644 --- a/mysql/types.go +++ b/mysql/types.go @@ -10,3 +10,6 @@ type Projection = jet.Projection // ProjectionList can be used to create conditional constructed projection list. type ProjectionList = jet.ProjectionList + +// ColumnAssigment is interface wrapper around column assigment +type ColumnAssigment = jet.ColumnAssigment diff --git a/mysql/utils_test.go b/mysql/utils_test.go index 709097d..584bfee 100644 --- a/mysql/utils_test.go +++ b/mysql/utils_test.go @@ -7,12 +7,14 @@ import ( ) var table1Col1 = IntegerColumn("col1") +var table1ColBool = BoolColumn("col_bool") var table1ColInt = IntegerColumn("col_int") var table1ColFloat = FloatColumn("col_float") +var table1ColString = StringColumn("col_string") var table1Col3 = IntegerColumn("col3") var table1ColTimestamp = TimestampColumn("col_timestamp") -var table1ColBool = BoolColumn("col_bool") var table1ColDate = DateColumn("col_date") +var table1ColTime = TimeColumn("col_time") var table1 = NewTable( "db", @@ -20,10 +22,12 @@ var table1 = NewTable( table1Col1, table1ColInt, table1ColFloat, + table1ColString, table1Col3, table1ColBool, table1ColDate, table1ColTimestamp, + table1ColTime, ) var table2Col3 = IntegerColumn("col3") diff --git a/postgres/clause_test.go b/postgres/clause_test.go index 7f64c61..5602505 100644 --- a/postgres/clause_test.go +++ b/postgres/clause_test.go @@ -21,11 +21,12 @@ ON CONFLICT (col_bool) DO NOTHING`) ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`) onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool, table2ColFloat}} - onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)).DO_UPDATE( - SET(table1ColBool, Bool(true)). - SET(table1ColInt, Int(1)). - WHERE(table2ColFloat.GT(Float(11.1))), - ) + onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)). + DO_UPDATE( + SET(table1ColBool.SET(Bool(true)), + table1ColInt.SET(Int(11))). + WHERE(table2ColFloat.GT(Float(11.1))), + ) assertClauseSerialize(t, onConflict, ` ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE SET col_bool = $1, diff --git a/postgres/conflict_action.go b/postgres/conflict_action.go index 5ff7a19..b7e9e2e 100644 --- a/postgres/conflict_action.go +++ b/postgres/conflict_action.go @@ -4,16 +4,15 @@ import "github.com/go-jet/jet/internal/jet" type conflictAction interface { jet.Serializer - SET(column jet.ColumnSerializer, expression interface{}) conflictAction WHERE(condition BoolExpression) conflictAction } // SET creates conflict action for ON_CONFLICT clause -func SET(column jet.ColumnSerializer, expression interface{}) conflictAction { +func SET(assigments ...ColumnAssigment) conflictAction { conflictAction := updateConflictActionImpl{} conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"} conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where) - conflictAction.SET(column, expression) + conflictAction.set = assigments return &conflictAction } @@ -25,11 +24,6 @@ type updateConflictActionImpl struct { where jet.ClauseWhere } -func (u *updateConflictActionImpl) SET(column jet.ColumnSerializer, expression interface{}) conflictAction { - u.set = append(u.set, jet.SetPair{Column: column, Value: jet.ToSerializerValue(expression)}) - return u -} - func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction { u.where.Condition = condition return u diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index d80c1a1..0761cf3 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -153,10 +153,10 @@ func TestInsert_ON_CONFLICT(t *testing.T) { VALUES("1", "2"). VALUES("theta", "beta"). ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE( - SET(table1ColBool, "12"). - SET(table2ColInt, 1). - SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))). - WHERE(table1Col1.GT(Int(2))), + SET(table1ColBool.SET(Bool(true)), + table2ColInt.SET(Int(1)), + ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))), + ).WHERE(table1Col1.GT(Int(2))), ). RETURNING(table1Col1, table1ColBool) @@ -166,7 +166,7 @@ VALUES ('one', 'two'), ('1', '2'), ('theta', 'beta') ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE - SET col_bool = '12', + SET col_bool = TRUE, col_int = 1, (col1, col_bool) = ROW(2, 'two') WHERE table1.col1 > 2 @@ -180,10 +180,10 @@ func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) { VALUES("one", "two"). VALUES("1", "2"). ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE( - SET(table1ColBool, "12"). - SET(table2ColInt, 1). - SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))). - WHERE(table1Col1.GT(Int(2))), + SET(table1ColBool.SET(Bool(false)), + table2ColInt.SET(Int(1)), + ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))), + ).WHERE(table1Col1.GT(Int(2))), ). RETURNING(table1Col1, table1ColBool) @@ -192,7 +192,7 @@ INSERT INTO db.table1 (col1, col_bool) VALUES ('one', 'two'), ('1', '2') ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE - SET col_bool = '12', + SET col_bool = FALSE, col_int = 1, (col1, col_bool) = ROW(2, 'two') WHERE table1.col1 > 2 diff --git a/postgres/types.go b/postgres/types.go index 58a8ae9..48de455 100644 --- a/postgres/types.go +++ b/postgres/types.go @@ -10,3 +10,6 @@ type Projection = jet.Projection // ProjectionList can be used to create conditional constructed projection list. type ProjectionList = jet.ProjectionList + +// ColumnAssigment is interface wrapper around column assigment +type ColumnAssigment = jet.ColumnAssigment diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 2f269ac..42e8aa3 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -991,6 +991,53 @@ func TestAllTypesInsert(t *testing.T) { require.NoError(t, err) } +func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) { + tx, err := db.Begin() + require.NoError(t, err) + + toInsert := model.AllTypes{ + Boolean: true, + Integer: 124, + Float: 45.67, + Blob: []byte("blob"), + Text: "text", + JSON: "{}", + Time: time.Now(), + Timestamp: time.Now(), + Date: time.Now(), + } + + stmt := AllTypes.INSERT( + AllTypes.Boolean, + AllTypes.Integer, + AllTypes.Float, + AllTypes.Blob, + AllTypes.Text, + AllTypes.JSON, + AllTypes.Time, + AllTypes.Timestamp, + AllTypes.Date, + ). + MODEL(toInsert). + ON_DUPLICATE_KEY_UPDATE( + AllTypes.Boolean.SET(Bool(false)), + AllTypes.Integer.SET(Int(4)), + AllTypes.Float.SET(Float(0.67)), + AllTypes.Text.SET(String("new text")), + AllTypes.Time.SET(TimeT(time.Now())), + AllTypes.Timestamp.SET(TimestampT(time.Now())), + AllTypes.Date.SET(DateT(time.Now())), + ) + + fmt.Println(stmt.DebugSql()) + + _, err = stmt.Exec(tx) + assert.NoError(t, err) + + err = tx.Rollback() + require.NoError(t, err) +} + var toInsert = model.AllTypes{ Boolean: false, BooleanPtr: testutils.BoolPtr(true), diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index cb27f61..0b1fd2e 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -7,6 +7,8 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "math/rand" "testing" "time" ) @@ -248,6 +250,46 @@ INSERT INTO test_sample.link (url, name) ( assert.Equal(t, len(youtubeLinks), 2) } +func TestInsertOnDuplicateKey(t *testing.T) { + randId := rand.Int31() + + stmt := Link.INSERT(). + VALUES(randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + VALUES(randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + ON_DUPLICATE_KEY_UPDATE( + Link.ID.SET(Link.ID.ADD(Int(11))), + Link.Name.SET(String("PostgreSQL Tutorial 2")), + ) + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO test_sample.link +VALUES (?, ?, ?, DEFAULT), + (?, ?, ?, DEFAULT) +ON DUPLICATE KEY UPDATE id = (id + ?), + name = ?; +`, randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", + randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", + int64(11), "PostgreSQL Tutorial 2") + + testutils.AssertExec(t, stmt, db, 3) + + newLinks := []model.Link{} + + err := SELECT(Link.AllColumns). + FROM(Link). + WHERE(Link.ID.EQ(Int(int64(randId)).ADD(Int(11)))). + Query(db, &newLinks) + + require.NoError(t, err) + require.Len(t, newLinks, 1) + require.Equal(t, newLinks[0], model.Link{ + ID: randId + 11, + URL: "http://www.postgresqltutorial.com", + Name: "PostgreSQL Tutorial 2", + Description: nil, + }) +} + func TestInsertWithQueryContext(t *testing.T) { cleanUpLinkTable(t) diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index eb19a68..c7db884 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -4,6 +4,8 @@ import ( "database/sql" "flag" "github.com/go-jet/jet/tests/dbconfig" + "math/rand" + "time" _ "github.com/go-sql-driver/mysql" @@ -28,6 +30,7 @@ func sourceIsMariaDB() bool { } func TestMain(m *testing.M) { + rand.Seed(time.Now().Unix()) defer profile.Start().Stop() var err error diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index ecbd3c5..7367e1d 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -2,7 +2,6 @@ package postgres import ( "context" - "github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" @@ -134,8 +133,10 @@ ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). ON_CONFLICT(Link.ID).DO_UPDATE( - SET(Link.ID, Link.EXCLUDED.ID). - SET(Link.URL, "http://www.postgresqltutorial2.com"), + SET( + Link.ID.SET(Link.EXCLUDED.ID), + Link.URL.SET(String("http://www.postgresqltutorial2.com")), + ), ). RETURNING(Link.AllColumns) @@ -161,8 +162,10 @@ RETURNING link.id AS "link.id", VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). ON_CONFLICT().ON_CONSTRAINT("link_pkey").DO_UPDATE( - SET(Link.ID, Link.EXCLUDED.ID). - SET(Link.URL, "http://www.postgresqltutorial2.com"), + SET( + Link.ID.SET(Link.EXCLUDED.ID), + Link.URL.SET(String("http://www.postgresqltutorial2.com")), + ), ). RETURNING(Link.AllColumns) @@ -188,9 +191,13 @@ RETURNING link.id AS "link.id", stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). ON_CONFLICT(Link.ID).WHERE(Link.ID.MUL(Int(2)).GT(Int(10))).DO_UPDATE( - SET(Link.ID, SELECT(MAXi(Link.ID).ADD(Int(1))).FROM(Link)). - SET(ColumnList{Link.Name, Link.Description}, jet.ROW(Link.EXCLUDED.Name, String("new description"))). - WHERE(Link.Description.IS_NOT_NULL()), + SET( + Link.ID.SET( + IntExp(SELECT(MAXi(Link.ID).ADD(Int(1))). + FROM(Link)), + ), + ColumnList{Link.Name, Link.Description}.SET(ROW(Link.EXCLUDED.Name, String("new description"))), + ).WHERE(Link.Description.IS_NOT_NULL()), ) testutils.AssertDebugStatementSql(t, stmt, ` From a4b47106376c85f83b6acde49640f592895c82b1 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 3 May 2020 21:30:57 +0200 Subject: [PATCH 07/23] Generate different sql builder files for MySQL and PostgreSQL. --- generator/internal/template/generate.go | 9 ++++ generator/internal/template/templates.go | 60 +++++++++++++++++++++++- tests/mysql/generator_test.go | 50 +++++--------------- tests/postgres/generator_test.go | 6 +-- 4 files changed, 82 insertions(+), 43 deletions(-) diff --git a/generator/internal/template/generate.go b/generator/internal/template/generate.go index 3872f76..46e72f9 100644 --- a/generator/internal/template/generate.go +++ b/generator/internal/template/generate.go @@ -22,6 +22,7 @@ func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect j err := utils.CleanUpGeneratedFiles(destDir) utils.PanicOnError(err) + tableSQLBuilderTemplate := getTableSQLBuilderTemplate(dialect) generateSQLBuilderFiles(destDir, "table", tableSQLBuilderTemplate, schemaInfo.TablesMetaData, dialect) generateSQLBuilderFiles(destDir, "view", tableSQLBuilderTemplate, schemaInfo.ViewsMetaData, dialect) generateSQLBuilderFiles(destDir, "enum", enumSQLBuilderTemplate, schemaInfo.EnumsMetaData, dialect) @@ -33,6 +34,14 @@ func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect j fmt.Println("Done") } +func getTableSQLBuilderTemplate(dialect jet.Dialect) string { + if dialect.Name() == "PostgreSQL" { + return tablePostgreSQLBuilderTemplate + } + + return tableSQLBuilderTemplate +} + func generateSQLBuilderFiles(destDir, fileTypes, sqlBuilderTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) { if len(metaData) == 0 { return diff --git a/generator/internal/template/templates.go b/generator/internal/template/templates.go index 353e6e1..7f432f1 100644 --- a/generator/internal/template/templates.go +++ b/generator/internal/template/templates.go @@ -25,6 +25,63 @@ import ( var {{ToGoIdentifier .Name}} = new{{.GoStructName}}() +type {{.GoStructName}} struct { + {{dialect.PackageName}}.Table + + //Columns +{{- range .Columns}} + {{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}} +{{- end}} + + AllColumns {{dialect.PackageName}}.ColumnList + MutableColumns {{dialect.PackageName}}.ColumnList +} + +// AS creates new {{.GoStructName}} with assigned alias +func (a *{{.GoStructName}}) AS(alias string) {{.GoStructName}} { + aliasTable := new{{.GoStructName}}() + aliasTable.Table.AS(alias) + return aliasTable +} + +func new{{.GoStructName}}() {{.GoStructName}} { + var ( + {{- range .Columns}} + {{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}") + {{- end}} + allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} } + mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} } + ) + + return {{.GoStructName}}{ + Table: {{dialect.PackageName}}.NewTable("{{.SchemaName}}", "{{.Name}}", allColumns...), + + //Columns +{{- range .Columns}} + {{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column, +{{- end}} + + AllColumns: allColumns, + MutableColumns: mutableColumns, + } +} +` + +var tablePostgreSQLBuilderTemplate = ` +{{define "column-list" -}} + {{- range $i, $c := . }} + {{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column + {{- end}} +{{- end}} + +package {{param "package"}} + +import ( + "github.com/go-jet/jet/{{dialect.PackageName}}" +) + +var {{ToGoIdentifier .Name}} = new{{.GoStructName}}() + type {{.GoStructImplName}} struct { {{dialect.PackageName}}.Table @@ -43,7 +100,7 @@ type {{.GoStructName}} struct { EXCLUDED {{.GoStructImplName}} } -// creates new {{.GoStructName}} with assigned alias +// AS creates new {{.GoStructName}} with assigned alias func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} { aliasTable := new{{.GoStructName}}() aliasTable.Table.AS(alias) @@ -78,7 +135,6 @@ func new{{.GoStructName}}Impl(schemaName, tableName string) {{.GoStructImplName} MutableColumns: mutableColumns, } } - ` var tableModelTemplate = `package model diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index f28c5c1..23fc9ad 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -142,7 +142,7 @@ import ( var Actor = newActorTable() -type actorTable struct { +type ActorTable struct { mysql.Table //Columns @@ -155,27 +155,14 @@ type actorTable struct { MutableColumns mysql.ColumnList } -type ActorTable struct { - actorTable - - EXCLUDED actorTable -} - -// creates new ActorTable with assigned alias -func (a *ActorTable) AS(alias string) *ActorTable { +// AS creates new ActorTable with assigned alias +func (a *ActorTable) AS(alias string) ActorTable { aliasTable := newActorTable() aliasTable.Table.AS(alias) return aliasTable } -func newActorTable() *ActorTable { - return &ActorTable{ - actorTable: newActorTableImpl("dvds", "actor"), - EXCLUDED: newActorTableImpl("", "excluded"), - } -} - -func newActorTableImpl(schemaName, tableName string) actorTable { +func newActorTable() ActorTable { var ( ActorIDColumn = mysql.IntegerColumn("actor_id") FirstNameColumn = mysql.StringColumn("first_name") @@ -185,8 +172,8 @@ func newActorTableImpl(schemaName, tableName string) actorTable { mutableColumns = mysql.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn} ) - return actorTable{ - Table: mysql.NewTable(schemaName, tableName, allColumns...), + return ActorTable{ + Table: mysql.NewTable("dvds", "actor", allColumns...), //Columns ActorID: ActorIDColumn, @@ -238,7 +225,7 @@ import ( var ActorInfo = newActorInfoTable() -type actorInfoTable struct { +type ActorInfoTable struct { mysql.Table //Columns @@ -251,27 +238,14 @@ type actorInfoTable struct { MutableColumns mysql.ColumnList } -type ActorInfoTable struct { - actorInfoTable - - EXCLUDED actorInfoTable -} - -// creates new ActorInfoTable with assigned alias -func (a *ActorInfoTable) AS(alias string) *ActorInfoTable { +// AS creates new ActorInfoTable with assigned alias +func (a *ActorInfoTable) AS(alias string) ActorInfoTable { aliasTable := newActorInfoTable() aliasTable.Table.AS(alias) return aliasTable } -func newActorInfoTable() *ActorInfoTable { - return &ActorInfoTable{ - actorInfoTable: newActorInfoTableImpl("dvds", "actor_info"), - EXCLUDED: newActorInfoTableImpl("", "excluded"), - } -} - -func newActorInfoTableImpl(schemaName, tableName string) actorInfoTable { +func newActorInfoTable() ActorInfoTable { var ( ActorIDColumn = mysql.IntegerColumn("actor_id") FirstNameColumn = mysql.StringColumn("first_name") @@ -281,8 +255,8 @@ func newActorInfoTableImpl(schemaName, tableName string) actorInfoTable { mutableColumns = mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn} ) - return actorInfoTable{ - Table: mysql.NewTable(schemaName, tableName, allColumns...), + return ActorInfoTable{ + Table: mysql.NewTable("dvds", "actor_info", allColumns...), //Columns ActorID: ActorIDColumn, diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index a9e771d..603fb0f 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -194,7 +194,7 @@ type ActorTable struct { EXCLUDED actorTable } -// creates new ActorTable with assigned alias +// AS creates new ActorTable with assigned alias func (a *ActorTable) AS(alias string) *ActorTable { aliasTable := newActorTable() aliasTable.Table.AS(alias) @@ -290,7 +290,7 @@ type ActorInfoTable struct { EXCLUDED actorInfoTable } -// creates new ActorInfoTable with assigned alias +// AS creates new ActorInfoTable with assigned alias func (a *ActorInfoTable) AS(alias string) *ActorInfoTable { aliasTable := newActorInfoTable() aliasTable.Table.AS(alias) @@ -580,7 +580,7 @@ type AllTypesTable struct { EXCLUDED allTypesTable } -// creates new AllTypesTable with assigned alias +// AS creates new AllTypesTable with assigned alias func (a *AllTypesTable) AS(alias string) *AllTypesTable { aliasTable := newAllTypesTable() aliasTable.Table.AS(alias) From ebcbadef243f086728ce9462d58450d6af90fbb7 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 9 May 2020 10:49:09 +0200 Subject: [PATCH 08/23] Add new typesafe SET operator for UPDATE statement. --- internal/jet/clause.go | 22 ++-- internal/testutils/test_utils.go | 2 +- mysql/table.go | 8 +- mysql/update_statement.go | 22 +++- mysql/update_statement_test.go | 2 +- postgres/conflict_action.go | 2 +- postgres/table.go | 6 +- postgres/update_statement.go | 24 +++- tests/mysql/update_test.go | 90 +++++++++------ tests/postgres/scan_test.go | 30 ++--- tests/postgres/update_test.go | 184 ++++++++++++++++++++++--------- 11 files changed, 269 insertions(+), 123 deletions(-) diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 349c6dc..6091986 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -271,14 +271,17 @@ func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder, o u.Table.serialize(statementType, out, FallTrough(options)...) } -// ClauseSet struct -type ClauseSet struct { +// SetClause struct +type SetClause struct { Columns []Column Values []Serializer } // Serialize serializes clause into SQLBuilder -func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { +func (s *SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { + if len(s.Values) == 0 { + return + } out.NewLine() out.WriteString("SET") @@ -289,7 +292,7 @@ func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder, opti out.IncreaseIdent(4) for i, column := range s.Columns { if i > 0 { - out.WriteString(", ") + out.WriteString(",") out.NewLine() } @@ -517,11 +520,14 @@ type SetPair struct { Value Serializer } -// SetClause clause -type SetClause []ColumnAssigment +// SetClauseNew clause +type SetClauseNew []ColumnAssigment -// Serialize for SetClause -func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { +// Serialize for SetClauseNew +func (s SetClauseNew) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { + if len(s) == 0 { + return + } out.NewLine() out.WriteString("SET") out.IncreaseIdent(4) diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 06035dd..ab36103 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -117,7 +117,7 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st } debuqSql := query.DebugSql() - assert.Equal(t, debuqSql, expectedQuery) + require.Equal(t, debuqSql, expectedQuery) } // AssertSerialize checks if clause serialize produces expected query and args diff --git a/mysql/table.go b/mysql/table.go index a4cf042..8287159 100644 --- a/mysql/table.go +++ b/mysql/table.go @@ -8,7 +8,7 @@ type Table interface { readableTable INSERT(columns ...jet.Column) InsertStatement - UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement + UPDATE(columns ...jet.Column) UpdateStatement DELETE() DeleteStatement LOCK() LockStatement } @@ -35,7 +35,7 @@ type readableTable interface { type joinSelectUpdateTable interface { ReadableTable - UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement + UPDATE(columns ...jet.Column) UpdateStatement } // ReadableTable interface @@ -98,8 +98,8 @@ func (t *tableImpl) INSERT(columns ...jet.Column) InsertStatement { return newInsertStatement(t.parent, jet.UnwidColumnList(columns)) } -func (t *tableImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement { - return newUpdateStatement(t.parent, jet.UnwindColumns(column, columns...)) +func (t *tableImpl) UPDATE(columns ...jet.Column) UpdateStatement { + return newUpdateStatement(t.parent, jet.UnwidColumnList(columns)) } func (t *tableImpl) DELETE() DeleteStatement { diff --git a/mysql/update_statement.go b/mysql/update_statement.go index ce4498f..ed8d515 100644 --- a/mysql/update_statement.go +++ b/mysql/update_statement.go @@ -16,14 +16,18 @@ type updateStatementImpl struct { jet.SerializerStatement Update jet.ClauseUpdate - Set jet.ClauseSet + Set jet.SetClause + SetNew jet.SetClauseNew Where jet.ClauseWhere } func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement { update := &updateStatementImpl{} - update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update, - &update.Set, &update.Where) + update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, + &update.Update, + &update.Set, + &update.SetNew, + &update.Where) update.Update.Table = table update.Set.Columns = columns @@ -33,7 +37,17 @@ func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement { } func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { - u.Set.Values = jet.UnwindRowFromValues(value, values) + columnAssigment, isColumnAssigment := value.(ColumnAssigment) + + if isColumnAssigment { + u.SetNew = []ColumnAssigment{columnAssigment} + for _, value := range values { + u.SetNew = append(u.SetNew, value.(ColumnAssigment)) + } + } else { + u.Set.Values = jet.UnwindRowFromValues(value, values) + } + return u } diff --git a/mysql/update_statement_test.go b/mysql/update_statement_test.go index fc933aa..fe3be01 100644 --- a/mysql/update_statement_test.go +++ b/mysql/update_statement_test.go @@ -23,7 +23,7 @@ WHERE table1.col_int >= ?; func TestUpdateWithValues(t *testing.T) { expectedSQL := ` UPDATE db.table1 -SET col_int = ?, +SET col_int = ?, col_float = ? WHERE table1.col_int >= ?; ` diff --git a/postgres/conflict_action.go b/postgres/conflict_action.go index b7e9e2e..55c9440 100644 --- a/postgres/conflict_action.go +++ b/postgres/conflict_action.go @@ -20,7 +20,7 @@ type updateConflictActionImpl struct { jet.Serializer doUpdate jet.KeywordClause - set jet.SetClause + set jet.SetClauseNew where jet.ClauseWhere } diff --git a/postgres/table.go b/postgres/table.go index bc2f5c2..c82a8f7 100644 --- a/postgres/table.go +++ b/postgres/table.go @@ -31,7 +31,7 @@ type readableTable interface { type writableTable interface { INSERT(columns ...jet.Column) InsertStatement - UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement + UPDATE(columns ...jet.Column) UpdateStatement DELETE() DeleteStatement LOCK() LockStatement } @@ -89,8 +89,8 @@ func (w *writableTableInterfaceImpl) INSERT(columns ...jet.Column) InsertStateme return newInsertStatement(w.parent, jet.UnwidColumnList(columns)) } -func (w *writableTableInterfaceImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement { - return newUpdateStatement(w.parent, jet.UnwindColumns(column, columns...)) +func (w *writableTableInterfaceImpl) UPDATE(columns ...jet.Column) UpdateStatement { + return newUpdateStatement(w.parent, jet.UnwidColumnList(columns)) } func (w *writableTableInterfaceImpl) DELETE() DeleteStatement { diff --git a/postgres/update_statement.go b/postgres/update_statement.go index d96e1e9..9c56012 100644 --- a/postgres/update_statement.go +++ b/postgres/update_statement.go @@ -20,14 +20,19 @@ type updateStatementImpl struct { Update jet.ClauseUpdate Set clauseSet + SetNew jet.SetClauseNew Where jet.ClauseWhere Returning clauseReturning } func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStatement { update := &updateStatementImpl{} - update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update, - &update.Set, &update.Where, &update.Returning) + update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, + &update.Update, + &update.Set, + &update.SetNew, + &update.Where, + &update.Returning) update.Update.Table = table update.Set.Columns = columns @@ -37,7 +42,17 @@ func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStateme } func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { - u.Set.Values = jet.UnwindRowFromValues(value, values) + columnAssigment, isColumnAssigment := value.(ColumnAssigment) + + if isColumnAssigment { + u.SetNew = []ColumnAssigment{columnAssigment} + for _, value := range values { + u.SetNew = append(u.SetNew, value.(ColumnAssigment)) + } + } else { + u.Set.Values = jet.UnwindRowFromValues(value, values) + } + return u } @@ -62,6 +77,9 @@ type clauseSet struct { } func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(s.Values) == 0 { + return + } out.NewLine() out.WriteString("SET") diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index c1e3f19..2114fe2 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -16,22 +16,33 @@ import ( func TestUpdateValues(t *testing.T) { setupLinkTableForUpdateTest(t) - query := Link. - UPDATE(Link.Name, Link.URL). - SET("Bong", "http://bong.com"). - WHERE(Link.Name.EQ(String("Bing"))) - var expectedSQL = ` UPDATE test_sample.link -SET name = 'Bong', +SET name = 'Bong', url = 'http://bong.com' WHERE link.name = 'Bing'; ` + t.Run("old version", func(t *testing.T) { + query := Link. + UPDATE(Link.Name, Link.URL). + SET("Bong", "http://bong.com"). + WHERE(Link.Name.EQ(String("Bing"))) - fmt.Println(query.DebugSql()) - testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") + testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") + testutils.AssertExec(t, query, db) + }) - testutils.AssertExec(t, query, db) + t.Run("new version", func(t *testing.T) { + stmt := Link.UPDATE(). + SET( + Link.Name.SET(String("Bong")), + Link.URL.SET(String("http://bong.com")), + ). + WHERE(Link.Name.EQ(String("Bing"))) + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "Bong", "http://bong.com", "Bing") + testutils.AssertExec(t, stmt, db) + }) links := []model.Link{} @@ -52,21 +63,11 @@ WHERE link.name = 'Bing'; func TestUpdateWithSubQueries(t *testing.T) { setupLinkTableForUpdateTest(t) - query := Link. - UPDATE(Link.Name, Link.URL). - SET( - SELECT(String("Bong")), - SELECT(Link2.URL). - FROM(Link2). - WHERE(Link2.Name.EQ(String("Youtube"))), - ). - WHERE(Link.Name.EQ(String("Bing"))) - expectedSQL := ` UPDATE test_sample.link SET name = ( SELECT ? - ), + ), url = ( SELECT link2.url AS "link2.url" FROM test_sample.link2 @@ -74,10 +75,37 @@ SET name = ( ) WHERE link.name = ?; ` - fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") + t.Run("old version", func(t *testing.T) { + query := Link. + UPDATE(Link.Name, Link.URL). + SET( + SELECT(String("Bong")), + SELECT(Link2.URL). + FROM(Link2). + WHERE(Link2.Name.EQ(String("Youtube"))), + ). + WHERE(Link.Name.EQ(String("Bing"))) - testutils.AssertExec(t, query, db) + testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") + testutils.AssertExec(t, query, db) + }) + + t.Run("new version", func(t *testing.T) { + query := Link. + UPDATE(). + SET( + Link.Name.SET(StringExp(SELECT(String("Bong")))), + Link.URL.SET(StringExp( + SELECT(Link2.URL). + FROM(Link2). + WHERE(Link2.Name.EQ(String("Youtube"))), + )), + ). + WHERE(Link.Name.EQ(String("Bing"))) + + testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") + testutils.AssertExec(t, query, db) + }) } func TestUpdateWithModelData(t *testing.T) { @@ -96,9 +124,9 @@ func TestUpdateWithModelData(t *testing.T) { expectedSQL := ` UPDATE test_sample.link -SET id = ?, - url = ?, - name = ?, +SET id = ?, + url = ?, + name = ?, description = ? WHERE link.id = ?; ` @@ -127,8 +155,8 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { var expectedSQL = ` UPDATE test_sample.link -SET description = NULL, - name = 'DuckDuckGo', +SET description = NULL, + name = 'DuckDuckGo', url = 'http://www.duckduckgo.com' WHERE link.id = 201; ` @@ -156,22 +184,20 @@ func TestUpdateWithModelDataAndMutableColumns(t *testing.T) { var expectedSQL = ` UPDATE test_sample.link -SET url = 'http://www.duckduckgo.com', - name = 'DuckDuckGo', +SET url = 'http://www.duckduckgo.com', + name = 'DuckDuckGo', description = NULL WHERE link.id = 201; ` fmt.Println(stmt.DebugSql()) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) - testutils.AssertExec(t, stmt, db) } func TestUpdateWithInvalidModelData(t *testing.T) { defer func() { r := recover() - assert.Equal(t, r, "missing struct field for column : id") }() diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index c5b6d3e..92251c3 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -12,7 +12,7 @@ import ( "testing" ) -var query = Inventory. +var oneInventoryQuery = Inventory. SELECT(Inventory.AllColumns). LIMIT(1). ORDER_BY(Inventory.InventoryID) @@ -20,69 +20,69 @@ var query = Inventory. func TestScanToInvalidDestination(t *testing.T) { t.Run("nil dest", func(t *testing.T) { - testutils.AssertQueryPanicErr(t, query, db, nil, "jet: destination is nil") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, nil, "jet: destination is nil") }) t.Run("struct dest", func(t *testing.T) { - testutils.AssertQueryPanicErr(t, query, db, struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") }) t.Run("slice dest", func(t *testing.T) { - testutils.AssertQueryPanicErr(t, query, db, []struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, []struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") }) t.Run("slice of pointers to pointer dest", func(t *testing.T) { - testutils.AssertQueryPanicErr(t, query, db, []**struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, []**struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") }) t.Run("map dest", func(t *testing.T) { - testutils.AssertQueryPanicErr(t, query, db, &map[string]string{}, "jet: destination has to be a pointer to slice or pointer to struct") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, &map[string]string{}, "jet: destination has to be a pointer to slice or pointer to struct") }) t.Run("map dest", func(t *testing.T) { - testutils.AssertQueryPanicErr(t, query, db, []map[string]string{}, "jet: destination has to be a pointer to slice or pointer to struct") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, []map[string]string{}, "jet: destination has to be a pointer to slice or pointer to struct") }) t.Run("map dest", func(t *testing.T) { - testutils.AssertQueryPanicErr(t, query, db, &[]map[string]string{}, "jet: unsupported slice element type") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, &[]map[string]string{}, "jet: unsupported slice element type") }) } func TestScanToValidDestination(t *testing.T) { t.Run("pointer to struct", func(t *testing.T) { dest := []struct{}{} - err := query.Query(db, &dest) + err := oneInventoryQuery.Query(db, &dest) assert.NoError(t, err) }) t.Run("global query function scan", func(t *testing.T) { - queryStr, args := query.Sql() + queryStr, args := oneInventoryQuery.Sql() dest := []struct{}{} err := qrm.Query(nil, db, queryStr, args, &dest) assert.NoError(t, err) }) t.Run("pointer to slice", func(t *testing.T) { - err := query.Query(db, &[]struct{}{}) + err := oneInventoryQuery.Query(db, &[]struct{}{}) assert.NoError(t, err) }) t.Run("pointer to slice of pointer to structs", func(t *testing.T) { - err := query.Query(db, &[]*struct{}{}) + err := oneInventoryQuery.Query(db, &[]*struct{}{}) assert.NoError(t, err) }) t.Run("pointer to slice of strings", func(t *testing.T) { - err := query.Query(db, &[]int32{}) + err := oneInventoryQuery.Query(db, &[]int32{}) assert.NoError(t, err) }) t.Run("pointer to slice of strings", func(t *testing.T) { - err := query.Query(db, &[]*int32{}) + err := oneInventoryQuery.Query(db, &[]*int32{}) assert.NoError(t, err) }) @@ -690,7 +690,7 @@ func TestScanToSlice(t *testing.T) { } } - testutils.AssertQueryPanicErr(t, query, db, &dest, "jet: unsupported slice element type at 'Cities []**struct { *model.City }'") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, &dest, "jet: unsupported slice element type at 'Cities []**struct { *model.City }'") }) } diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index ca07332..43e64fb 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -7,6 +7,7 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "time" ) @@ -14,50 +15,69 @@ import ( func TestUpdateValues(t *testing.T) { setupLinkTableForUpdateTest(t) - query := Link. - UPDATE(Link.Name, Link.URL). - SET("Bong", "http://bong.com"). - WHERE(Link.Name.EQ(String("Bing"))) + t.Run("deprecated version", func(t *testing.T) { + query := Link. + UPDATE(Link.Name, Link.URL). + SET("Bong", "http://bong.com"). + WHERE(Link.Name.EQ(String("Bing"))) - var expectedSQL = ` + testutils.AssertDebugStatementSql(t, query, ` UPDATE test_sample.link SET (name, url) = ('Bong', 'http://bong.com') WHERE link.name = 'Bing'; -` - testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") +`, "Bong", "http://bong.com", "Bing") - AssertExec(t, query, 1) + testutils.AssertExec(t, query, db, 1) - links := []model.Link{} + links := []model.Link{} - err := Link. - SELECT(Link.AllColumns). - WHERE(Link.Name.EQ(String("Bong"))). - Query(db, &links) + err := Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.IN(String("Bong"))). + Query(db, &links) - assert.NoError(t, err) - assert.Equal(t, len(links), 1) - testutils.AssertDeepEqual(t, links[0], model.Link{ - ID: 204, - URL: "http://bong.com", - Name: "Bong", + require.NoError(t, err) + require.Equal(t, len(links), 1) + testutils.AssertDeepEqual(t, links[0], model.Link{ + ID: 204, + URL: "http://bong.com", + Name: "Bong", + }) + }) + + t.Run("new version", func(t *testing.T) { + stmt := Link.UPDATE(). + SET( + Link.Name.SET(String("DuckDuckGo")), + Link.URL.SET(String("www.duckduckgo.com")), + ). + WHERE(Link.Name.EQ(String("Yahoo"))) + + testutils.AssertDebugStatementSql(t, stmt, ` +UPDATE test_sample.link +SET name = 'DuckDuckGo', + url = 'www.duckduckgo.com' +WHERE link.name = 'Yahoo'; +`) + testutils.AssertExec(t, stmt, db, 1) }) } func TestUpdateWithSubQueries(t *testing.T) { setupLinkTableForUpdateTest(t) - query := Link. - UPDATE(Link.Name, Link.URL). - SET( - SELECT(String("Bong")), - SELECT(Link.URL). - FROM(Link). - WHERE(Link.Name.EQ(String("Bing"))), - ). - WHERE(Link.Name.EQ(String("Bing"))) + t.Run("deprecated version", func(t *testing.T) { + query := Link. + UPDATE(Link.Name, Link.URL). + SET( + SELECT(String("Bong")), + SELECT(Link.URL). + FROM(Link). + WHERE(Link.Name.EQ(String("Bing"))), + ). + WHERE(Link.Name.EQ(String("Bing"))) - expectedSQL := ` + expectedSQL := ` UPDATE test_sample.link SET (name, url) = (( SELECT 'Bong' @@ -68,10 +88,34 @@ SET (name, url) = (( )) WHERE link.name = 'Bing'; ` + testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "Bing", "Bing") - testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "Bing", "Bing") + AssertExec(t, query, 1) + }) - AssertExec(t, query, 1) + t.Run("new version", func(t *testing.T) { + query := Link.UPDATE(). + SET( + Link.Name.SET(String("Bong")), + Link.URL.SET(StringExp( + SELECT(Link.URL). + FROM(Link). + WHERE(Link.Name.EQ(String("Bing")))), + ), + ). + WHERE(Link.Name.EQ(String("Bing"))) + + testutils.AssertStatementSql(t, query, ` +UPDATE test_sample.link +SET name = $1, + url = ( + SELECT link.url AS "link.url" + FROM test_sample.link + WHERE link.name = $2 + ) +WHERE link.name = $3; +`, "Bong", "Bing", "Bing") + }) } func TestUpdateAndReturning(t *testing.T) { @@ -107,15 +151,16 @@ RETURNING link.id AS "link.id", func TestUpdateWithSelect(t *testing.T) { - stmt := Link.UPDATE(Link.AllColumns). - SET( - Link. - SELECT(Link.AllColumns). - WHERE(Link.ID.EQ(Int(0))), - ). - WHERE(Link.ID.EQ(Int(0))) + t.Run("deprecated version", func(t *testing.T) { + stmt := Link.UPDATE(Link.AllColumns). + SET( + Link. + SELECT(Link.AllColumns). + WHERE(Link.ID.EQ(Int(0))), + ). + WHERE(Link.ID.EQ(Int(0))) - expectedSQL := ` + expectedSQL := ` UPDATE test_sample.link SET (id, url, name, description) = ( SELECT link.id AS "link.id", @@ -127,22 +172,50 @@ SET (id, url, name, description) = ( ) WHERE link.id = 0; ` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) - AssertExec(t, stmt, 1) + AssertExec(t, stmt, 1) + }) + + t.Run("new version", func(t *testing.T) { + stmt := Link.UPDATE(). + SET( + Link.MutableColumns.SET( + SELECT(Link.MutableColumns). + FROM(Link). + WHERE(Link.ID.EQ(Int(0))), + ), + ). + WHERE(Link.ID.EQ(Int(0))) + + testutils.AssertDebugStatementSql(t, stmt, ` +UPDATE test_sample.link +SET (url, name, description) = ( + SELECT link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description" + FROM test_sample.link + WHERE link.id = 0 + ) +WHERE link.id = 0; +`, int64(0), int64(0)) + + AssertExec(t, stmt, 1) + }) } func TestUpdateWithInvalidSelect(t *testing.T) { - stmt := Link.UPDATE(Link.AllColumns). - SET( - Link. - SELECT(Link.ID, Link.Name). - WHERE(Link.ID.EQ(Int(0))), - ). - WHERE(Link.ID.EQ(Int(0))) + t.Run("deprecated version", func(t *testing.T) { + stmt := Link.UPDATE(Link.AllColumns). + SET( + Link. + SELECT(Link.ID, Link.Name). + WHERE(Link.ID.EQ(Int(0))), + ). + WHERE(Link.ID.EQ(Int(0))) - var expectedSQL = ` + var expectedSQL = ` UPDATE test_sample.link SET (id, url, name, description) = ( SELECT link.id AS "link.id", @@ -152,9 +225,18 @@ SET (id, url, name, description) = ( ) WHERE link.id = 0; ` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) - testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") + testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") + }) + + t.Run("new version", func(t *testing.T) { + stmt := Link.UPDATE(). + SET(Link.AllColumns.SET(Link.SELECT(Link.MutableColumns))). + WHERE(Link.ID.EQ(Int(0))) + + testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") + }) } func TestUpdateWithModelData(t *testing.T) { From 5d742837f1942af4155041823eb1dd8957e1580c Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 9 May 2020 11:00:22 +0200 Subject: [PATCH 09/23] Use testify/require instead of testify/assert for tests. --- internal/3rdparty/snaker/snaker_test.go | 16 +- internal/jet/clause_test.go | 4 +- internal/jet/sql_builder_test.go | 50 ++--- internal/jet/table_test.go | 20 +- internal/jet/testutils.go | 16 +- internal/jet/utils_test.go | 10 +- internal/testutils/test_utils.go | 27 ++- internal/utils/utils_test.go | 38 ++-- mysql/insert_statement_test.go | 6 +- postgres/insert_statement_test.go | 6 +- qrm/internal/null_types_test.go | 140 +++++++------- qrm/utill_test.go | 36 ++-- tests/mysql/alltypes_test.go | 50 +++-- tests/mysql/cast_test.go | 4 +- tests/mysql/delete_test.go | 6 +- tests/mysql/generator_test.go | 22 +-- tests/mysql/insert_test.go | 35 ++-- tests/mysql/lock_test.go | 8 +- tests/mysql/select_test.go | 64 +++---- tests/mysql/update_test.go | 16 +- tests/postgres/alltypes_test.go | 59 +++--- tests/postgres/chinook_db_test.go | 24 +-- tests/postgres/delete_test.go | 10 +- tests/postgres/generator_test.go | 56 +++--- tests/postgres/insert_test.go | 20 +- tests/postgres/lock_test.go | 12 +- tests/postgres/northwind_test.go | 4 +- tests/postgres/sample_test.go | 27 ++- tests/postgres/scan_test.go | 128 ++++++------- tests/postgres/select_test.go | 232 ++++++++++++------------ tests/postgres/update_test.go | 19 +- tests/postgres/util_test.go | 5 +- 32 files changed, 581 insertions(+), 589 deletions(-) diff --git a/internal/3rdparty/snaker/snaker_test.go b/internal/3rdparty/snaker/snaker_test.go index 83ae867..f828a91 100644 --- a/internal/3rdparty/snaker/snaker_test.go +++ b/internal/3rdparty/snaker/snaker_test.go @@ -1,16 +1,16 @@ package snaker import ( - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" ) func TestSnakeToCamel(t *testing.T) { - assert.Equal(t, SnakeToCamel(""), "") - assert.Equal(t, SnakeToCamel("potato_"), "Potato") - assert.Equal(t, SnakeToCamel("this_has_to_be_uppercased"), "ThisHasToBeUppercased") - assert.Equal(t, SnakeToCamel("this_is_an_id"), "ThisIsAnID") - assert.Equal(t, SnakeToCamel("this_is_an_identifier"), "ThisIsAnIdentifier") - assert.Equal(t, SnakeToCamel("id"), "ID") - assert.Equal(t, SnakeToCamel("oauth_client"), "OAuthClient") + require.Equal(t, SnakeToCamel(""), "") + require.Equal(t, SnakeToCamel("potato_"), "Potato") + require.Equal(t, SnakeToCamel("this_has_to_be_uppercased"), "ThisHasToBeUppercased") + require.Equal(t, SnakeToCamel("this_is_an_id"), "ThisIsAnID") + require.Equal(t, SnakeToCamel("this_is_an_identifier"), "ThisIsAnIdentifier") + require.Equal(t, SnakeToCamel("id"), "ID") + require.Equal(t, SnakeToCamel("oauth_client"), "OAuthClient") } diff --git a/internal/jet/clause_test.go b/internal/jet/clause_test.go index 9a86597..a37bf9e 100644 --- a/internal/jet/clause_test.go +++ b/internal/jet/clause_test.go @@ -1,14 +1,14 @@ package jet import ( - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" ) func TestClauseSelect_Serialize(t *testing.T) { defer func() { r := recover() - assert.Equal(t, r, "jet: SELECT clause has to have at least one projection") + require.Equal(t, r, "jet: SELECT clause has to have at least one projection") }() selectClause := &ClauseSelect{} diff --git a/internal/jet/sql_builder_test.go b/internal/jet/sql_builder_test.go index 911df27..2aad3aa 100644 --- a/internal/jet/sql_builder_test.go +++ b/internal/jet/sql_builder_test.go @@ -2,40 +2,40 @@ package jet import ( "github.com/google/uuid" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "time" ) func TestArgToString(t *testing.T) { - assert.Equal(t, argToString(true), "TRUE") - assert.Equal(t, argToString(false), "FALSE") + require.Equal(t, argToString(true), "TRUE") + require.Equal(t, argToString(false), "FALSE") - assert.Equal(t, argToString(int(-32)), "-32") - assert.Equal(t, argToString(uint(32)), "32") - assert.Equal(t, argToString(int8(-43)), "-43") - assert.Equal(t, argToString(uint8(43)), "43") - assert.Equal(t, argToString(int16(-54)), "-54") - assert.Equal(t, argToString(uint16(54)), "54") - assert.Equal(t, argToString(int32(-65)), "-65") - assert.Equal(t, argToString(uint32(65)), "65") - assert.Equal(t, argToString(int64(-64)), "-64") - assert.Equal(t, argToString(uint64(64)), "64") - assert.Equal(t, argToString(float32(2.0)), "2") - assert.Equal(t, argToString(float64(1.11)), "1.11") + require.Equal(t, argToString(int(-32)), "-32") + require.Equal(t, argToString(uint(32)), "32") + require.Equal(t, argToString(int8(-43)), "-43") + require.Equal(t, argToString(uint8(43)), "43") + require.Equal(t, argToString(int16(-54)), "-54") + require.Equal(t, argToString(uint16(54)), "54") + require.Equal(t, argToString(int32(-65)), "-65") + require.Equal(t, argToString(uint32(65)), "65") + require.Equal(t, argToString(int64(-64)), "-64") + require.Equal(t, argToString(uint64(64)), "64") + require.Equal(t, argToString(float32(2.0)), "2") + require.Equal(t, argToString(float64(1.11)), "1.11") - assert.Equal(t, argToString("john"), "'john'") - assert.Equal(t, argToString("It's text"), "'It''s text'") - assert.Equal(t, argToString([]byte("john")), "'john'") - assert.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'") + require.Equal(t, argToString("john"), "'john'") + require.Equal(t, argToString("It's text"), "'It''s text'") + require.Equal(t, argToString([]byte("john")), "'john'") + require.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'") time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006") - assert.NoError(t, err) - assert.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'") + require.NoError(t, err) + require.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'") func() { defer func() { - assert.Equal(t, recover().(string), "jet: map[string]bool type can not be used as SQL query parameter") + require.Equal(t, recover().(string), "jet: map[string]bool type can not be used as SQL query parameter") }() argToString(map[string]bool{}) @@ -43,7 +43,7 @@ func TestArgToString(t *testing.T) { } func TestFallTrough(t *testing.T) { - assert.Equal(t, FallTrough([]SerializeOption{ShortName}), []SerializeOption{ShortName}) - assert.Equal(t, FallTrough([]SerializeOption{SkipNewLine}), []SerializeOption(nil)) - assert.Equal(t, FallTrough([]SerializeOption{ShortName, SkipNewLine}), []SerializeOption{ShortName}) + require.Equal(t, FallTrough([]SerializeOption{ShortName}), []SerializeOption{ShortName}) + require.Equal(t, FallTrough([]SerializeOption{SkipNewLine}), []SerializeOption(nil)) + require.Equal(t, FallTrough([]SerializeOption{ShortName, SkipNewLine}), []SerializeOption{ShortName}) } diff --git a/internal/jet/table_test.go b/internal/jet/table_test.go index d899a8f..66646b2 100644 --- a/internal/jet/table_test.go +++ b/internal/jet/table_test.go @@ -1,18 +1,18 @@ package jet import ( - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" ) func TestNewTable(t *testing.T) { newTable := NewTable("schema", "table", IntegerColumn("intCol")) - assert.Equal(t, newTable.SchemaName(), "schema") - assert.Equal(t, newTable.TableName(), "table") + require.Equal(t, newTable.SchemaName(), "schema") + require.Equal(t, newTable.TableName(), "table") - assert.Equal(t, len(newTable.columns()), 1) - assert.Equal(t, newTable.columns()[0].Name(), "intCol") + require.Equal(t, len(newTable.columns()), 1) + require.Equal(t, newTable.columns()[0].Name(), "intCol") } func TestNewJoinTable(t *testing.T) { @@ -24,10 +24,10 @@ func TestNewJoinTable(t *testing.T) { assertClauseSerialize(t, joinTable, `schema.table INNER JOIN schema.table2 ON ("intCol1" = "intCol2")`) - assert.Equal(t, joinTable.SchemaName(), "schema") - assert.Equal(t, joinTable.TableName(), "") + require.Equal(t, joinTable.SchemaName(), "schema") + require.Equal(t, joinTable.TableName(), "") - assert.Equal(t, len(joinTable.columns()), 2) - assert.Equal(t, joinTable.columns()[0].Name(), "intCol1") - assert.Equal(t, joinTable.columns()[1].Name(), "intCol2") + require.Equal(t, len(joinTable.columns()), 2) + require.Equal(t, joinTable.columns()[0].Name(), "intCol1") + require.Equal(t, joinTable.columns()[1].Name(), "intCol2") } diff --git a/internal/jet/testutils.go b/internal/jet/testutils.go index 1d5009e..268ae06 100644 --- a/internal/jet/testutils.go +++ b/internal/jet/testutils.go @@ -1,7 +1,7 @@ package jet import ( - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "strconv" "testing" ) @@ -56,14 +56,14 @@ func assertClauseSerialize(t *testing.T, clause Serializer, query string, args . //fmt.Println(out.Buff.String()) - assert.Equal(t, out.Buff.String(), query) - assert.Equal(t, out.Args, args) + require.Equal(t, out.Buff.String(), query) + require.Equal(t, out.Args, args) } func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) { defer func() { r := recover() - assert.Equal(t, r, errString) + require.Equal(t, r, errString) }() out := SQLBuilder{Dialect: defaultDialect} @@ -76,14 +76,14 @@ func assertClauseDebugSerialize(t *testing.T, clause Serializer, query string, a //fmt.Println(out.Buff.String()) - assert.Equal(t, out.Buff.String(), query) - assert.Equal(t, out.Args, args) + require.Equal(t, out.Buff.String(), query) + require.Equal(t, out.Args, args) } func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) { out := SQLBuilder{Dialect: defaultDialect} projection.serializeForProjection(SelectStatementType, &out) - assert.Equal(t, out.Buff.String(), query) - assert.Equal(t, out.Args, args) + require.Equal(t, out.Buff.String(), query) + require.Equal(t, out.Args, args) } diff --git a/internal/jet/utils_test.go b/internal/jet/utils_test.go index e13d7ff..b907760 100644 --- a/internal/jet/utils_test.go +++ b/internal/jet/utils_test.go @@ -1,19 +1,19 @@ package jet import ( - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" ) func TestOptionalOrDefaultString(t *testing.T) { - assert.Equal(t, OptionalOrDefaultString("default"), "default") - assert.Equal(t, OptionalOrDefaultString("default", "optional"), "optional") + require.Equal(t, OptionalOrDefaultString("default"), "default") + require.Equal(t, OptionalOrDefaultString("default", "optional"), "optional") } func TestOptionalOrDefaultExpression(t *testing.T) { defaultExpression := table2ColFloat optionalExpression := table1Col1 - assert.Equal(t, OptionalOrDefaultExpression(defaultExpression), defaultExpression) - assert.Equal(t, OptionalOrDefaultExpression(defaultExpression, optionalExpression), optionalExpression) + require.Equal(t, OptionalOrDefaultExpression(defaultExpression), defaultExpression) + require.Equal(t, OptionalOrDefaultExpression(defaultExpression, optionalExpression), optionalExpression) } diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index ab36103..3cff7ab 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -8,7 +8,6 @@ import ( "github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/qrm" "github.com/google/uuid" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "io/ioutil" "os" @@ -37,7 +36,7 @@ func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int func AssertExecErr(t *testing.T, stmt jet.Statement, db qrm.DB, errorStr string) { _, err := stmt.Exec(db) - assert.Error(t, err, errorStr) + require.Error(t, err, errorStr) } func getFullPath(relativePath string) string { @@ -54,9 +53,9 @@ func PrintJson(v interface{}) { // AssertJSON check if data json output is the same as expectedJSON func AssertJSON(t *testing.T, data interface{}, expectedJSON string) { jsonData, err := json.MarshalIndent(data, "", "\t") - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, "\n"+string(jsonData)+"\n", expectedJSON) + require.Equal(t, "\n"+string(jsonData)+"\n", expectedJSON) } // SaveJSONFile saves v as json at testRelativePath @@ -74,23 +73,23 @@ func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) { filePath := getFullPath(testRelativePath) fileJSONData, err := ioutil.ReadFile(filePath) - assert.NoError(t, err) + require.NoError(t, err) if runtime.GOOS == "windows" { fileJSONData = bytes.Replace(fileJSONData, []byte("\r\n"), []byte("\n"), -1) } jsonData, err := json.MarshalIndent(data, "", "\t") - assert.NoError(t, err) + require.NoError(t, err) - assert.True(t, string(fileJSONData) == string(jsonData)) + require.True(t, string(fileJSONData) == string(jsonData)) //AssertDeepEqual(t, string(fileJSONData), string(jsonData)) } // AssertStatementSql check if statement Sql() is the same as expectedQuery and expectedArgs func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) { queryStr, args := query.Sql() - assert.Equal(t, queryStr, expectedQuery) + require.Equal(t, queryStr, expectedQuery) if len(expectedArgs) == 0 { return @@ -102,7 +101,7 @@ func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, func AssertStatementSqlErr(t *testing.T, stmt jet.Statement, errorStr string) { defer func() { r := recover() - assert.Equal(t, r, errorStr) + require.Equal(t, r, errorStr) }() stmt.Sql() @@ -162,7 +161,7 @@ func AssertDebugSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializ func AssertPanicErr(t *testing.T, fun func(), errorStr string) { defer func() { r := recover() - assert.Equal(t, r, errorStr) + require.Equal(t, r, errorStr) }() fun() @@ -172,7 +171,7 @@ func AssertPanicErr(t *testing.T, fun func(), errorStr string) { func AssertSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) { defer func() { r := recover() - assert.Equal(t, r, errString) + require.Equal(t, r, errString) }() out := jet.SQLBuilder{Dialect: dialect} @@ -192,7 +191,7 @@ func AssertProjectionSerialize(t *testing.T, dialect jet.Dialect, projection jet func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest interface{}, errString string) { defer func() { r := recover() - assert.Equal(t, r, errString) + require.Equal(t, r, errString) }() stmt.Query(db, dest) @@ -209,7 +208,7 @@ func AssertFileContent(t *testing.T, filePath string, expectedContent string) { // AssertFileNamesEqual check if all filesInfos are contained in fileNames func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) { - assert.Equal(t, len(fileInfos), len(fileNames)) + require.Equal(t, len(fileInfos), len(fileNames)) fileNamesMap := map[string]bool{} @@ -218,7 +217,7 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st } for _, fileName := range fileNames { - assert.True(t, fileNamesMap[fileName], fileName+" does not exist.") + require.True(t, fileNamesMap[fileName], fileName+" does not exist.") } } diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go index 787b14b..f2b4f84 100644 --- a/internal/utils/utils_test.go +++ b/internal/utils/utils_test.go @@ -2,32 +2,32 @@ package utils import ( "fmt" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" ) func TestToGoIdentifier(t *testing.T) { - assert.Equal(t, ToGoIdentifier(""), "") - assert.Equal(t, ToGoIdentifier("uuid"), "UUID") - assert.Equal(t, ToGoIdentifier("col1"), "Col1") - assert.Equal(t, ToGoIdentifier("PG-13"), "Pg13") - assert.Equal(t, ToGoIdentifier("13_pg"), "13Pg") + require.Equal(t, ToGoIdentifier(""), "") + require.Equal(t, ToGoIdentifier("uuid"), "UUID") + require.Equal(t, ToGoIdentifier("col1"), "Col1") + require.Equal(t, ToGoIdentifier("PG-13"), "Pg13") + require.Equal(t, ToGoIdentifier("13_pg"), "13Pg") - assert.Equal(t, ToGoIdentifier("mytable"), "Mytable") - assert.Equal(t, ToGoIdentifier("MYTABLE"), "Mytable") - assert.Equal(t, ToGoIdentifier("MyTaBlE"), "MyTaBlE") - assert.Equal(t, ToGoIdentifier("myTaBlE"), "MyTaBlE") + require.Equal(t, ToGoIdentifier("mytable"), "Mytable") + require.Equal(t, ToGoIdentifier("MYTABLE"), "Mytable") + require.Equal(t, ToGoIdentifier("MyTaBlE"), "MyTaBlE") + require.Equal(t, ToGoIdentifier("myTaBlE"), "MyTaBlE") - assert.Equal(t, ToGoIdentifier("my_table"), "MyTable") - assert.Equal(t, ToGoIdentifier("MY_TABLE"), "MyTable") - assert.Equal(t, ToGoIdentifier("My_Table"), "MyTable") - assert.Equal(t, ToGoIdentifier("My Table"), "MyTable") - assert.Equal(t, ToGoIdentifier("My-Table"), "MyTable") + require.Equal(t, ToGoIdentifier("my_table"), "MyTable") + require.Equal(t, ToGoIdentifier("MY_TABLE"), "MyTable") + require.Equal(t, ToGoIdentifier("My_Table"), "MyTable") + require.Equal(t, ToGoIdentifier("My Table"), "MyTable") + require.Equal(t, ToGoIdentifier("My-Table"), "MyTable") } func TestToGoEnumValueIdentifier(t *testing.T) { - assert.Equal(t, ToGoEnumValueIdentifier("enum_name", "enum_value"), "EnumValue") - assert.Equal(t, ToGoEnumValueIdentifier("NumEnum", "100"), "NumEnum100") + require.Equal(t, ToGoEnumValueIdentifier("enum_name", "enum_value"), "EnumValue") + require.Equal(t, ToGoEnumValueIdentifier("NumEnum", "100"), "NumEnum100") } func TestErrorCatchErr(t *testing.T) { @@ -39,7 +39,7 @@ func TestErrorCatchErr(t *testing.T) { panic(fmt.Errorf("newError")) }() - assert.Error(t, err, "newError") + require.Error(t, err, "newError") } func TestErrorCatchNonErr(t *testing.T) { @@ -51,5 +51,5 @@ func TestErrorCatchNonErr(t *testing.T) { panic(11) }() - assert.Error(t, err, "11") + require.Error(t, err, "11") } diff --git a/mysql/insert_statement_test.go b/mysql/insert_statement_test.go index 95814d2..dbabc3f 100644 --- a/mysql/insert_statement_test.go +++ b/mysql/insert_statement_test.go @@ -1,7 +1,7 @@ package mysql import ( - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "time" ) @@ -95,7 +95,7 @@ VALUES (?, ?), func TestInsertValuesFromModelColumnMismatch(t *testing.T) { defer func() { r := recover() - assert.Equal(t, r, "missing struct field for column : col1") + require.Equal(t, r, "missing struct field for column : col1") }() type Table1Model struct { Col1Prim int @@ -116,7 +116,7 @@ func TestInsertFromNonStructModel(t *testing.T) { defer func() { r := recover() - assert.Equal(t, r, "jet: data has to be a struct") + require.Equal(t, r, "jet: data has to be a struct") }() table2.INSERT(table2ColInt).MODEL([]int{}) diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index 0761cf3..beb8af2 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -2,7 +2,7 @@ package postgres import ( "github.com/go-jet/jet/internal/jet" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "time" ) @@ -94,7 +94,7 @@ VALUES ($1, $2), func TestInsertValuesFromModelColumnMismatch(t *testing.T) { defer func() { r := recover() - assert.Equal(t, r, "missing struct field for column : col1") + require.Equal(t, r, "missing struct field for column : col1") }() type Table1Model struct { Col1Prim int @@ -115,7 +115,7 @@ func TestInsertFromNonStructModel(t *testing.T) { defer func() { r := recover() - assert.Equal(t, r, "jet: data has to be a struct") + require.Equal(t, r, "jet: data has to be a struct") }() table2.INSERT(table2ColInt).MODEL([]int{}) diff --git a/qrm/internal/null_types_test.go b/qrm/internal/null_types_test.go index f03d3ab..8f4adde 100644 --- a/qrm/internal/null_types_test.go +++ b/qrm/internal/null_types_test.go @@ -2,7 +2,7 @@ package internal import ( "fmt" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "time" ) @@ -10,138 +10,138 @@ import ( func TestNullByteArray(t *testing.T) { var array NullByteArray - assert.NoError(t, array.Scan(nil)) - assert.Equal(t, array.Valid, false) + require.NoError(t, array.Scan(nil)) + require.Equal(t, array.Valid, false) - assert.NoError(t, array.Scan([]byte("bytea"))) - assert.Equal(t, array.Valid, true) - assert.Equal(t, string(array.ByteArray), string([]byte("bytea"))) + require.NoError(t, array.Scan([]byte("bytea"))) + require.Equal(t, array.Valid, true) + require.Equal(t, string(array.ByteArray), string([]byte("bytea"))) - assert.Error(t, array.Scan(12), "can't scan []byte from 12") + require.Error(t, array.Scan(12), "can't scan []byte from 12") } func TestNullTime(t *testing.T) { var array NullTime - assert.NoError(t, array.Scan(nil)) - assert.Equal(t, array.Valid, false) + require.NoError(t, array.Scan(nil)) + require.Equal(t, array.Valid, false) time := time.Now() - assert.NoError(t, array.Scan(time)) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan(time)) + require.Equal(t, array.Valid, true) value, _ := array.Value() - assert.Equal(t, value, time) + require.Equal(t, value, time) - assert.NoError(t, array.Scan([]byte("13:10:11"))) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan([]byte("13:10:11"))) + require.Equal(t, array.Valid, true) value, _ = array.Value() - assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") + require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") - assert.NoError(t, array.Scan("13:10:11")) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan("13:10:11")) + require.Equal(t, array.Valid, true) value, _ = array.Value() - assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") + require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") - assert.Error(t, array.Scan(12), "can't scan time.Time from 12") + require.Error(t, array.Scan(12), "can't scan time.Time from 12") } func TestNullInt8(t *testing.T) { var array NullInt8 - assert.NoError(t, array.Scan(nil)) - assert.Equal(t, array.Valid, false) + require.NoError(t, array.Scan(nil)) + require.Equal(t, array.Valid, false) - assert.NoError(t, array.Scan(int64(11))) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan(int64(11))) + require.Equal(t, array.Valid, true) value, _ := array.Value() - assert.Equal(t, value, int8(11)) + require.Equal(t, value, int8(11)) - assert.Error(t, array.Scan("text"), "can't scan int8 from text") + require.Error(t, array.Scan("text"), "can't scan int8 from text") } func TestNullInt16(t *testing.T) { var array NullInt16 - assert.NoError(t, array.Scan(nil)) - assert.Equal(t, array.Valid, false) + require.NoError(t, array.Scan(nil)) + require.Equal(t, array.Valid, false) - assert.NoError(t, array.Scan(int64(11))) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan(int64(11))) + require.Equal(t, array.Valid, true) value, _ := array.Value() - assert.Equal(t, value, int16(11)) + require.Equal(t, value, int16(11)) - assert.NoError(t, array.Scan(int16(20))) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan(int16(20))) + require.Equal(t, array.Valid, true) value, _ = array.Value() - assert.Equal(t, value, int16(20)) + require.Equal(t, value, int16(20)) - assert.NoError(t, array.Scan(int8(30))) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan(int8(30))) + require.Equal(t, array.Valid, true) value, _ = array.Value() - assert.Equal(t, value, int16(30)) + require.Equal(t, value, int16(30)) - assert.NoError(t, array.Scan(uint8(30))) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan(uint8(30))) + require.Equal(t, array.Valid, true) value, _ = array.Value() - assert.Equal(t, value, int16(30)) + require.Equal(t, value, int16(30)) - assert.Error(t, array.Scan("text"), "can't scan int16 from text") + require.Error(t, array.Scan("text"), "can't scan int16 from text") } func TestNullInt32(t *testing.T) { var array NullInt32 - assert.NoError(t, array.Scan(nil)) - assert.Equal(t, array.Valid, false) + require.NoError(t, array.Scan(nil)) + require.Equal(t, array.Valid, false) - assert.NoError(t, array.Scan(int64(11))) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan(int64(11))) + require.Equal(t, array.Valid, true) value, _ := array.Value() - assert.Equal(t, value, int32(11)) + require.Equal(t, value, int32(11)) - assert.NoError(t, array.Scan(int32(32))) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan(int32(32))) + require.Equal(t, array.Valid, true) value, _ = array.Value() - assert.Equal(t, value, int32(32)) + require.Equal(t, value, int32(32)) - assert.NoError(t, array.Scan(int16(20))) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan(int16(20))) + require.Equal(t, array.Valid, true) value, _ = array.Value() - assert.Equal(t, value, int32(20)) + require.Equal(t, value, int32(20)) - assert.NoError(t, array.Scan(uint16(16))) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan(uint16(16))) + require.Equal(t, array.Valid, true) value, _ = array.Value() - assert.Equal(t, value, int32(16)) + require.Equal(t, value, int32(16)) - assert.NoError(t, array.Scan(int8(30))) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan(int8(30))) + require.Equal(t, array.Valid, true) value, _ = array.Value() - assert.Equal(t, value, int32(30)) + require.Equal(t, value, int32(30)) - assert.NoError(t, array.Scan(uint8(30))) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan(uint8(30))) + require.Equal(t, array.Valid, true) value, _ = array.Value() - assert.Equal(t, value, int32(30)) + require.Equal(t, value, int32(30)) - assert.Error(t, array.Scan("text"), "can't scan int32 from text") + require.Error(t, array.Scan("text"), "can't scan int32 from text") } func TestNullFloat32(t *testing.T) { var array NullFloat32 - assert.NoError(t, array.Scan(nil)) - assert.Equal(t, array.Valid, false) + require.NoError(t, array.Scan(nil)) + require.Equal(t, array.Valid, false) - assert.NoError(t, array.Scan(float64(64))) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan(float64(64))) + require.Equal(t, array.Valid, true) value, _ := array.Value() - assert.Equal(t, value, float32(64)) + require.Equal(t, value, float32(64)) - assert.NoError(t, array.Scan(float32(32))) - assert.Equal(t, array.Valid, true) + require.NoError(t, array.Scan(float32(32))) + require.Equal(t, array.Valid, true) value, _ = array.Value() - assert.Equal(t, value, float32(32)) + require.Equal(t, value, float32(32)) - assert.Error(t, array.Scan(12), "can't scan float32 from 12") + require.Error(t, array.Scan(12), "can't scan float32 from 12") } diff --git a/qrm/utill_test.go b/qrm/utill_test.go index 045e2b9..897bb2c 100644 --- a/qrm/utill_test.go +++ b/qrm/utill_test.go @@ -2,36 +2,36 @@ package qrm import ( "github.com/google/uuid" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "reflect" "testing" "time" ) func TestIsSimpleModelType(t *testing.T) { - assert.True(t, isSimpleModelType(reflect.TypeOf(int8(11)))) - assert.True(t, isSimpleModelType(reflect.TypeOf(int16(11)))) - assert.True(t, isSimpleModelType(reflect.TypeOf(int32(11)))) - assert.True(t, isSimpleModelType(reflect.TypeOf(int64(11)))) - assert.True(t, isSimpleModelType(reflect.TypeOf(uint8(11)))) - assert.True(t, isSimpleModelType(reflect.TypeOf(uint16(11)))) - assert.True(t, isSimpleModelType(reflect.TypeOf(uint32(11)))) - assert.True(t, isSimpleModelType(reflect.TypeOf(uint64(11)))) + require.True(t, isSimpleModelType(reflect.TypeOf(int8(11)))) + require.True(t, isSimpleModelType(reflect.TypeOf(int16(11)))) + require.True(t, isSimpleModelType(reflect.TypeOf(int32(11)))) + require.True(t, isSimpleModelType(reflect.TypeOf(int64(11)))) + require.True(t, isSimpleModelType(reflect.TypeOf(uint8(11)))) + require.True(t, isSimpleModelType(reflect.TypeOf(uint16(11)))) + require.True(t, isSimpleModelType(reflect.TypeOf(uint32(11)))) + require.True(t, isSimpleModelType(reflect.TypeOf(uint64(11)))) - assert.True(t, isSimpleModelType(reflect.TypeOf(float32(123.46)))) - assert.True(t, isSimpleModelType(reflect.TypeOf(float64(123.46)))) + require.True(t, isSimpleModelType(reflect.TypeOf(float32(123.46)))) + require.True(t, isSimpleModelType(reflect.TypeOf(float64(123.46)))) - assert.True(t, isSimpleModelType(reflect.TypeOf([]byte("Text")))) - assert.True(t, isSimpleModelType(reflect.TypeOf(time.Now()))) - assert.True(t, isSimpleModelType(reflect.TypeOf(uuid.New()))) + require.True(t, isSimpleModelType(reflect.TypeOf([]byte("Text")))) + require.True(t, isSimpleModelType(reflect.TypeOf(time.Now()))) + require.True(t, isSimpleModelType(reflect.TypeOf(uuid.New()))) complexModelType := struct { Field1 string Field2 string }{} - assert.Equal(t, isSimpleModelType(reflect.TypeOf(complexModelType)), false) - assert.Equal(t, isSimpleModelType(reflect.TypeOf(&complexModelType)), false) - assert.Equal(t, isSimpleModelType(reflect.TypeOf([]string{"str"})), false) - assert.Equal(t, isSimpleModelType(reflect.TypeOf([]int{1, 2})), false) + require.Equal(t, isSimpleModelType(reflect.TypeOf(complexModelType)), false) + require.Equal(t, isSimpleModelType(reflect.TypeOf(&complexModelType)), false) + require.Equal(t, isSimpleModelType(reflect.TypeOf([]string{"str"})), false) + require.Equal(t, isSimpleModelType(reflect.TypeOf([]int{1, 2})), false) } diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 42e8aa3..7d9e17e 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -16,8 +16,6 @@ import ( "github.com/go-jet/jet/tests/testdata/results/common" . "github.com/go-jet/jet/mysql" - - "github.com/stretchr/testify/assert" ) func TestAllTypes(t *testing.T) { @@ -29,9 +27,9 @@ func TestAllTypes(t *testing.T) { LIMIT(2). Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 2) + require.Equal(t, len(dest), 2) if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert return @@ -48,8 +46,8 @@ func TestAllTypesViewSelect(t *testing.T) { dest := []AllTypesView{} err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 2) + require.NoError(t, err) + require.Equal(t, len(dest), 2) if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert return @@ -77,11 +75,11 @@ func TestUUID(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) - assert.True(t, dest.StrUUID != nil) - assert.True(t, dest.UUID.String() != uuid.UUID{}.String()) - assert.True(t, dest.StrUUID.String() != uuid.UUID{}.String()) - assert.Equal(t, dest.StrUUID.String(), dest.BinUUID.String()) + require.NoError(t, err) + require.True(t, dest.StrUUID != nil) + require.True(t, dest.UUID.String() != uuid.UUID{}.String()) + require.True(t, dest.StrUUID.String() != uuid.UUID{}.String()) + require.Equal(t, dest.StrUUID.String(), dest.BinUUID.String()) } func TestExpressionOperators(t *testing.T) { @@ -122,7 +120,7 @@ LIMIT ?; err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) //testutils.PrintJson(dest) @@ -213,7 +211,7 @@ FROM test_sample.all_types; err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json") } @@ -266,7 +264,7 @@ func TestFloatOperators(t *testing.T) { //fmt.Println(queryStr) - assert.Equal(t, queryStr, strings.Replace(` + require.Equal(t, queryStr, strings.Replace(` SELECT (all_types.'numeric' = all_types.'numeric') AS "eq1", (all_types.'decimal' = ?) AS "eq2", (all_types.'real' = ?) AS "eq3", @@ -312,7 +310,7 @@ LIMIT ?; err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertJSONFile(t, dest, "./testdata/results/common/float_operators.json") } @@ -449,7 +447,7 @@ LIMIT ?; err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) //testutils.PrintJson(dest) @@ -521,7 +519,7 @@ func TestStringOperators(t *testing.T) { dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC) @@ -609,7 +607,7 @@ FROM test_sample.all_types; dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestDateExpressions(t *testing.T) { @@ -684,7 +682,7 @@ FROM test_sample.all_types; dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestDateTimeExpressions(t *testing.T) { @@ -761,7 +759,7 @@ FROM test_sample.all_types; dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestTimestampExpressions(t *testing.T) { @@ -837,13 +835,13 @@ FROM test_sample.all_types; dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestTimeLiterals(t *testing.T) { loc, err := time.LoadLocation("Europe/Berlin") - assert.NoError(t, err) + require.NoError(t, err) var timeT = time.Date(2009, 11, 17, 20, 34, 58, 351387237, loc) @@ -882,7 +880,7 @@ LIMIT ?; } err = query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) //testutils.PrintJson(dest) @@ -965,7 +963,7 @@ func TestINTERVAL(t *testing.T) { //fmt.Println(query.DebugSql()) err := query.Query(db, &struct{}{}) - assert.NoError(t, err) + require.NoError(t, err) } func TestAllTypesInsert(t *testing.T) { @@ -1032,7 +1030,7 @@ func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) { fmt.Println(stmt.DebugSql()) _, err = stmt.Exec(tx) - assert.NoError(t, err) + require.NoError(t, err) err = tx.Rollback() require.NoError(t, err) @@ -1256,7 +1254,7 @@ FROM test_sample.user; var dest []model.User err := stmt.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.PrintJson(dest) diff --git a/tests/mysql/cast_test.go b/tests/mysql/cast_test.go index 3ab1914..218665e 100644 --- a/tests/mysql/cast_test.go +++ b/tests/mysql/cast_test.go @@ -4,7 +4,7 @@ import ( "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "time" ) @@ -55,7 +55,7 @@ FROM test_sample.all_types; err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest, Result{ As1: "test", diff --git a/tests/mysql/delete_test.go b/tests/mysql/delete_test.go index 2e06d88..da91e97 100644 --- a/tests/mysql/delete_test.go +++ b/tests/mysql/delete_test.go @@ -6,7 +6,7 @@ import ( . "github.com/go-jet/jet/mysql" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "time" ) @@ -60,7 +60,7 @@ func TestDeleteQueryContext(t *testing.T) { dest := []model.Link{} err := deleteStmt.QueryContext(ctx, db, &dest) - assert.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") } func TestDeleteExecContext(t *testing.T) { @@ -77,7 +77,7 @@ func TestDeleteExecContext(t *testing.T) { _, err := deleteStmt.ExecContext(ctx, db) - assert.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") } func initForDeleteTest(t *testing.T) { diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index 23fc9ad..b1ca685 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -4,7 +4,7 @@ import ( "github.com/go-jet/jet/generator/mysql" "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/dbconfig" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "io/ioutil" "os" "os/exec" @@ -25,23 +25,23 @@ func TestGenerator(t *testing.T) { DBName: "dvds", }) - assert.NoError(t, err) + require.NoError(t, err) assertGeneratedFiles(t) } err := os.RemoveAll(genTestDirRoot) - assert.NoError(t, err) + require.NoError(t, err) } func TestCmdGenerator(t *testing.T) { goInstallJet := exec.Command("sh", "-c", "go install github.com/go-jet/jet/cmd/jet") goInstallJet.Stderr = os.Stderr err := goInstallJet.Run() - assert.NoError(t, err) + require.NoError(t, err) err = os.RemoveAll(genTestDir3) - assert.NoError(t, err) + require.NoError(t, err) cmd := exec.Command("jet", "-source=MySQL", "-dbname=dvds", "-host=localhost", "-port=3306", "-user=jet", "-password=jet", "-path="+genTestDir3) @@ -50,18 +50,18 @@ func TestCmdGenerator(t *testing.T) { cmd.Stdout = os.Stdout err = cmd.Run() - assert.NoError(t, err) + require.NoError(t, err) assertGeneratedFiles(t) err = os.RemoveAll(genTestDirRoot) - assert.NoError(t, err) + require.NoError(t, err) } func assertGeneratedFiles(t *testing.T) { // Table SQL Builder files tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table") - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", @@ -71,7 +71,7 @@ func assertGeneratedFiles(t *testing.T) { // View SQL Builder files viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view") - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") @@ -80,14 +80,14 @@ func assertGeneratedFiles(t *testing.T) { // Enums SQL Builder files enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum") - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertFileNamesEqual(t, enumFiles, "film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go") testutils.AssertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", mpaaRatingEnumFile) // Model files modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model") - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index 0b1fd2e..613a655 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -6,7 +6,6 @@ import ( . "github.com/go-jet/jet/mysql" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "math/rand" "testing" @@ -34,7 +33,7 @@ VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT 102, "http://www.yahoo.com", "Yahoo", nil) _, err := insertQuery.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) insertedLinks := []model.Link{} @@ -43,8 +42,8 @@ VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT ORDER_BY(Link.ID). Query(db, &insertedLinks) - assert.NoError(t, err) - assert.Equal(t, len(insertedLinks), 3) + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 3) testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) @@ -82,7 +81,7 @@ VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial") _, err := stmt.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) insertedLinks := []model.Link{} @@ -91,8 +90,8 @@ VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT ORDER_BY(Link.ID). Query(db, &insertedLinks) - assert.NoError(t, err) - assert.Equal(t, len(insertedLinks), 1) + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 1) testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) } @@ -115,7 +114,7 @@ VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); testutils.AssertDebugStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go") _, err := query.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) } func TestInsertModelObjectEmptyColumnList(t *testing.T) { @@ -138,7 +137,7 @@ VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) _, err := query.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) } func TestInsertModelsObject(t *testing.T) { @@ -174,7 +173,7 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), "http://www.yahoo.com", "Yahoo") _, err := query.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) } func TestInsertUsingMutableColumns(t *testing.T) { @@ -209,14 +208,14 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), "http://www.yahoo.com", "Yahoo", nil) _, err := stmt.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) } func TestInsertQuery(t *testing.T) { _, err := Link.DELETE(). WHERE(Link.ID.NOT_EQ(Int(1)).AND(Link.Name.EQ(String("Youtube")))). Exec(db) - assert.NoError(t, err) + require.NoError(t, err) var expectedSQL = ` INSERT INTO test_sample.link (url, name) ( @@ -238,7 +237,7 @@ INSERT INTO test_sample.link (url, name) ( testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1)) _, err = query.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) youtubeLinks := []model.Link{} err = Link. @@ -246,8 +245,8 @@ INSERT INTO test_sample.link (url, name) ( WHERE(Link.Name.EQ(String("Youtube"))). Query(db, &youtubeLinks) - assert.NoError(t, err) - assert.Equal(t, len(youtubeLinks), 2) + require.NoError(t, err) + require.Equal(t, len(youtubeLinks), 2) } func TestInsertOnDuplicateKey(t *testing.T) { @@ -304,7 +303,7 @@ func TestInsertWithQueryContext(t *testing.T) { dest := []model.Link{} err := stmt.QueryContext(ctx, db, &dest) - assert.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") } func TestInsertWithExecContext(t *testing.T) { @@ -320,10 +319,10 @@ func TestInsertWithExecContext(t *testing.T) { _, err := stmt.ExecContext(ctx, db) - assert.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") } func cleanUpLinkTable(t *testing.T) { _, err := Link.DELETE().WHERE(Link.ID.GT(Int(1))).Exec(db) - assert.NoError(t, err) + require.NoError(t, err) } diff --git a/tests/mysql/lock_test.go b/tests/mysql/lock_test.go index e3d5749..8aed571 100644 --- a/tests/mysql/lock_test.go +++ b/tests/mysql/lock_test.go @@ -4,7 +4,7 @@ import ( "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" ) @@ -16,7 +16,7 @@ LOCK TABLES dvds.customer READ; `) _, err := query.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) } func TestLockWrite(t *testing.T) { @@ -27,7 +27,7 @@ LOCK TABLES dvds.customer WRITE; `) _, err := query.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) } func TestUnlockTables(t *testing.T) { @@ -38,5 +38,5 @@ UNLOCK TABLES; `) _, err := query.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) } diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index c47c6d4..e5b748e 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -7,7 +7,7 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/view" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" ) @@ -30,7 +30,7 @@ WHERE actor.actor_id = ?; actor := model.Actor{} err := query.Query(db, &actor) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, actor, actor2) } @@ -59,9 +59,9 @@ ORDER BY actor.actor_id; err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 200) + require.Equal(t, len(dest), 200) testutils.AssertDeepEqual(t, dest[1], actor2) //testutils.PrintJson(dest) @@ -136,11 +136,11 @@ ORDER BY payment.customer_id, SUM(payment.amount) ASC; err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) //testutils.PrintJson(dest) - assert.Equal(t, len(dest), 174) + require.Equal(t, len(dest), 174) //testutils.SaveJsonFile(dest, "mysql/testdata/customer_payment_sum.json") testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/customer_payment_sum.json") @@ -176,7 +176,7 @@ func TestSubQuery(t *testing.T) { } err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) //testutils.SaveJsonFile(dest, "mysql/testdata/r_rating_films.json") testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/r_rating_films.json") @@ -229,7 +229,7 @@ LIMIT ?; dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestSelectUNION(t *testing.T) { @@ -265,7 +265,7 @@ LIMIT ?; dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestSelectUNION_ALL(t *testing.T) { @@ -308,7 +308,7 @@ OFFSET ?; dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestJoinQueryStruct(t *testing.T) { @@ -406,10 +406,10 @@ LIMIT ?; err := query.Query(db, &dest) - assert.NoError(t, err) - //assert.Equal(t, len(dest), 1) - //assert.Equal(t, len(dest[0].Films), 10) - //assert.Equal(t, len(dest[0].Films[0].Actors), 10) + require.NoError(t, err) + //require.Equal(t, len(dest), 1) + //require.Equal(t, len(dest[0].Films), 10) + //require.Equal(t, len(dest[0].Films[0].Actors), 10) //testutils.SaveJsonFile(dest, "./mysql/testdata/lang_film_actor_inventory_rental.json") @@ -450,10 +450,10 @@ FOR` tx, _ := db.Begin() _, err := query.Exec(tx) - assert.NoError(t, err) + require.NoError(t, err) err = tx.Rollback() - assert.NoError(t, err) + require.NoError(t, err) } for lockType, lockTypeStr := range getRowLockTestData() { @@ -464,10 +464,10 @@ FOR` tx, _ := db.Begin() _, err := query.Exec(tx) - assert.NoError(t, err) + require.NoError(t, err) err = tx.Rollback() - assert.NoError(t, err) + require.NoError(t, err) } if sourceIsMariaDB() { @@ -482,10 +482,10 @@ FOR` tx, _ := db.Begin() _, err := query.Exec(tx) - assert.NoError(t, err) + require.NoError(t, err) err = tx.Rollback() - assert.NoError(t, err) + require.NoError(t, err) } } @@ -514,7 +514,7 @@ SELECT true, dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestLockInShareMode(t *testing.T) { @@ -535,7 +535,7 @@ LOCK IN SHARE MODE; dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestWindowFunction(t *testing.T) { @@ -612,7 +612,7 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date; dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestWindowClause(t *testing.T) { @@ -649,7 +649,7 @@ ORDER BY payment.customer_id; dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestSimpleView(t *testing.T) { @@ -670,9 +670,9 @@ func TestSimpleView(t *testing.T) { var dest []ActorInfo err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 10) + require.Equal(t, len(dest), 10) testutils.AssertJSON(t, dest[1:2], ` [ { @@ -702,11 +702,11 @@ func TestJoinViewWithTable(t *testing.T) { } err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 2) - assert.Equal(t, len(dest[0].Rentals), 32) - assert.Equal(t, len(dest[1].Rentals), 27) + require.Equal(t, len(dest), 2) + require.Equal(t, len(dest[0].Rentals), 32) + require.Equal(t, len(dest[1].Rentals), 27) } func TestConditionalProjectionList(t *testing.T) { @@ -737,7 +737,7 @@ LIMIT 3; `) var dest []model.Customer err := stmt.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 3) + require.Equal(t, len(dest), 3) } diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index 2114fe2..a689584 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -8,7 +8,7 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "time" ) @@ -51,8 +51,8 @@ WHERE link.name = 'Bing'; WHERE(Link.Name.EQ(String("Bong"))). Query(db, &links) - assert.NoError(t, err) - assert.Equal(t, len(links), 1) + require.NoError(t, err) + require.Equal(t, len(links), 1) testutils.AssertDeepEqual(t, links[0], model.Link{ ID: 204, URL: "http://bong.com", @@ -198,7 +198,7 @@ WHERE link.id = 201; func TestUpdateWithInvalidModelData(t *testing.T) { defer func() { r := recover() - assert.Equal(t, r, "missing struct field for column : id") + require.Equal(t, r, "missing struct field for column : id") }() setupLinkTableForUpdateTest(t) @@ -239,7 +239,7 @@ func TestUpdateQueryContext(t *testing.T) { dest := []model.Link{} err := updateStmt.QueryContext(ctx, db, &dest) - assert.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") } func TestUpdateExecContext(t *testing.T) { @@ -257,7 +257,7 @@ func TestUpdateExecContext(t *testing.T) { _, err := updateStmt.ExecContext(ctx, db) - assert.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") } func TestUpdateWithJoin(t *testing.T) { @@ -270,7 +270,7 @@ func TestUpdateWithJoin(t *testing.T) { //fmt.Println(query.DebugSql()) _, err := query.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) } func setupLinkTableForUpdateTest(t *testing.T) { @@ -285,5 +285,5 @@ func setupLinkTableForUpdateTest(t *testing.T) { VALUES(204, "http://www.bing.com", "Bing", DEFAULT). Exec(db) - assert.NoError(t, err) + require.NoError(t, err) } diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index f03f38b..d4e6ab4 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -1,25 +1,24 @@ package postgres import ( + "github.com/stretchr/testify/require" "testing" "time" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/view" "github.com/go-jet/jet/tests/testdata/results/common" + "github.com/google/uuid" ) func TestAllTypesSelect(t *testing.T) { dest := []model.AllTypes{} err := AllTypes.SELECT(AllTypes.AllColumns).Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest[0], allTypesRow0) testutils.AssertDeepEqual(t, dest[1], allTypesRow1) @@ -31,7 +30,7 @@ func TestAllTypesViewSelect(t *testing.T) { dest := []AllTypesView{} err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest[0], AllTypesView(allTypesRow0)) testutils.AssertDeepEqual(t, dest[1], AllTypesView(allTypesRow1)) @@ -45,9 +44,9 @@ func TestAllTypesInsertModel(t *testing.T) { dest := []model.AllTypes{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 2) + require.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0], allTypesRow0) testutils.AssertDeepEqual(t, dest[1], allTypesRow1) } @@ -64,8 +63,8 @@ func TestAllTypesInsertQuery(t *testing.T) { dest := []model.AllTypes{} err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 2) + require.NoError(t, err) + require.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0], allTypesRow0) testutils.AssertDeepEqual(t, dest[1], allTypesRow1) } @@ -80,7 +79,7 @@ func TestAllTypesFromSubQuery(t *testing.T) { FROM(subQuery). LIMIT(2) - assert.Equal(t, mainQuery.DebugSql(), ` + require.Equal(t, mainQuery.DebugSql(), ` SELECT "allTypesSubQuery"."all_types.small_int_ptr" AS "all_types.small_int_ptr", "allTypesSubQuery"."all_types.small_int" AS "all_types.small_int", "allTypesSubQuery"."all_types.integer_ptr" AS "all_types.integer_ptr", @@ -212,8 +211,8 @@ LIMIT 2; dest := []model.AllTypes{} err := mainQuery.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 2) + require.NoError(t, err) + require.Equal(t, len(dest), 2) } func TestExpressionOperators(t *testing.T) { @@ -251,7 +250,7 @@ LIMIT $5; err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) //testutils.PrintJson(dest) @@ -320,7 +319,7 @@ func TestExpressionCast(t *testing.T) { dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestStringOperators(t *testing.T) { @@ -400,7 +399,7 @@ func TestStringOperators(t *testing.T) { dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestBoolOperators(t *testing.T) { @@ -469,7 +468,7 @@ LIMIT $5; err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json") } @@ -519,7 +518,7 @@ func TestFloatOperators(t *testing.T) { queryStr, _ := query.Sql() - assert.Equal(t, queryStr, ` + require.Equal(t, queryStr, ` SELECT (all_types.numeric = all_types.numeric) AS "eq1", (all_types.decimal = $1) AS "eq2", (all_types.real = $2) AS "eq3", @@ -565,7 +564,7 @@ LIMIT $35; err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) //testutils.PrintJson(dest) @@ -704,7 +703,7 @@ LIMIT $23; err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) //testutils.SaveJsonFile("./testdata/common/int_operators.json", dest) //testutils.PrintJson(dest) @@ -783,7 +782,7 @@ func TestTimeExpression(t *testing.T) { dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestInterval(t *testing.T) { @@ -834,7 +833,7 @@ func TestInterval(t *testing.T) { //fmt.Println(stmt.DebugSql()) err := stmt.Query(db, &struct{}{}) - assert.NoError(t, err) + require.NoError(t, err) } func TestSubQueryColumnReference(t *testing.T) { @@ -986,12 +985,12 @@ FROM` dest1 := []model.AllTypes{} err := stmt1.Query(db, &dest1) - assert.NoError(t, err) - assert.Equal(t, len(dest1), 2) - assert.Equal(t, dest1[0].Boolean, allTypesRow0.Boolean) - assert.Equal(t, dest1[0].Integer, allTypesRow0.Integer) - assert.Equal(t, dest1[0].Real, allTypesRow0.Real) - assert.Equal(t, dest1[0].Text, allTypesRow0.Text) + require.NoError(t, err) + require.Equal(t, len(dest1), 2) + require.Equal(t, dest1[0].Boolean, allTypesRow0.Boolean) + require.Equal(t, dest1[0].Integer, allTypesRow0.Integer) + require.Equal(t, dest1[0].Real, allTypesRow0.Real) + require.Equal(t, dest1[0].Text, allTypesRow0.Text) testutils.AssertDeepEqual(t, dest1[0].Time, allTypesRow0.Time) testutils.AssertDeepEqual(t, dest1[0].Timez, allTypesRow0.Timez) testutils.AssertDeepEqual(t, dest1[0].Timestamp, allTypesRow0.Timestamp) @@ -1008,7 +1007,7 @@ FROM` dest2 := []model.AllTypes{} err = stmt2.Query(db, &dest2) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest1, dest2) } } @@ -1016,7 +1015,7 @@ FROM` func TestTimeLiterals(t *testing.T) { loc, err := time.LoadLocation("Europe/Berlin") - assert.NoError(t, err) + require.NoError(t, err) var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, loc) @@ -1051,7 +1050,7 @@ LIMIT $6; err = query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) //testutils.PrintJson(dest) diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index b0b5efe..2695981 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -7,7 +7,7 @@ import ( . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/table" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "time" ) @@ -30,8 +30,8 @@ ORDER BY "Album"."AlbumId" ASC; err := stmt.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 347) + require.NoError(t, err) + require.Equal(t, len(dest), 347) testutils.AssertDeepEqual(t, dest[0], album1) testutils.AssertDeepEqual(t, dest[1], album2) testutils.AssertDeepEqual(t, dest[len(dest)-1], album347) @@ -103,8 +103,8 @@ func TestJoinEverything(t *testing.T) { err := stmt.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 275) + require.NoError(t, err) + require.Equal(t, len(dest), 275) testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/joined_everything.json") } @@ -143,8 +143,8 @@ ORDER BY "Employee"."EmployeeId"; err := stmt.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 8) + require.NoError(t, err) + require.Equal(t, len(dest), 8) testutils.AssertJSON(t, dest[0:2], ` [ { @@ -236,9 +236,9 @@ ORDER BY "Album.AlbumId"; err := stmt.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 2) + require.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0], album1) testutils.AssertDeepEqual(t, dest[1], album2) } @@ -256,7 +256,7 @@ func TestQueryWithContext(t *testing.T) { SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns). QueryContext(ctx, db, &dest) - assert.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") } func TestExecWithContext(t *testing.T) { @@ -270,7 +270,7 @@ func TestExecWithContext(t *testing.T) { SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns). ExecContext(ctx, db) - assert.Error(t, err, "pq: canceling statement due to user request") + require.Error(t, err, "pq: canceling statement due to user request") } func TestSubQueriesForQuotedNames(t *testing.T) { @@ -327,7 +327,7 @@ ORDER BY "first10Artist"."Artist.ArtistId"; err := stmt.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) //spew.Dump(dest) } diff --git a/tests/postgres/delete_test.go b/tests/postgres/delete_test.go index 855f6dc..18080fc 100644 --- a/tests/postgres/delete_test.go +++ b/tests/postgres/delete_test.go @@ -6,7 +6,7 @@ import ( . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "time" ) @@ -48,9 +48,9 @@ RETURNING link.id AS "link.id", err := deleteStmt.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 2) + require.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0].Name, "Gmail") testutils.AssertDeepEqual(t, dest[1].Name, "Outlook") } @@ -79,7 +79,7 @@ func TestDeleteQueryContext(t *testing.T) { dest := []model.Link{} err := deleteStmt.QueryContext(ctx, db, &dest) - assert.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") } func TestDeleteExecContext(t *testing.T) { @@ -98,5 +98,5 @@ func TestDeleteExecContext(t *testing.T) { _, err := deleteStmt.ExecContext(ctx, db) - assert.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") } diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 603fb0f..a27f595 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -4,7 +4,7 @@ import ( "github.com/go-jet/jet/generator/postgres" "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/dbconfig" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "io/ioutil" "os" "os/exec" @@ -17,30 +17,30 @@ import ( func TestGeneratedModel(t *testing.T) { actor := model.Actor{} - assert.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32") + require.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32") actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID") - assert.True(t, ok) - assert.Equal(t, actorIDField.Tag.Get("sql"), "primary_key") - assert.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string") - assert.Equal(t, reflect.TypeOf(actor.LastName).String(), "string") - assert.Equal(t, reflect.TypeOf(actor.LastUpdate).String(), "time.Time") + require.True(t, ok) + require.Equal(t, actorIDField.Tag.Get("sql"), "primary_key") + require.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string") + require.Equal(t, reflect.TypeOf(actor.LastName).String(), "string") + require.Equal(t, reflect.TypeOf(actor.LastUpdate).String(), "time.Time") filmActor := model.FilmActor{} - assert.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int16") + require.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int16") filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID") - assert.True(t, ok) - assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key") + require.True(t, ok) + require.Equal(t, filmIDField.Tag.Get("sql"), "primary_key") - assert.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int16") + require.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int16") actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID") - assert.True(t, ok) - assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key") + require.True(t, ok) + require.Equal(t, filmIDField.Tag.Get("sql"), "primary_key") staff := model.Staff{} - assert.Equal(t, reflect.TypeOf(staff.Email).String(), "*string") - assert.Equal(t, reflect.TypeOf(staff.Picture).String(), "*[]uint8") + require.Equal(t, reflect.TypeOf(staff.Email).String(), "*string") + require.Equal(t, reflect.TypeOf(staff.Picture).String(), "*[]uint8") } const genTestDir2 = "./.gentestdata2" @@ -49,10 +49,10 @@ func TestCmdGenerator(t *testing.T) { goInstallJet := exec.Command("sh", "-c", "go install github.com/go-jet/jet/cmd/jet") goInstallJet.Stderr = os.Stderr err := goInstallJet.Run() - assert.NoError(t, err) + require.NoError(t, err) err = os.RemoveAll(genTestDir2) - assert.NoError(t, err) + require.NoError(t, err) cmd := exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost", "-port=5432", "-user=jet", "-password=jet", "-schema=dvds", "-path="+genTestDir2) @@ -60,12 +60,12 @@ func TestCmdGenerator(t *testing.T) { cmd.Stdout = os.Stdout err = cmd.Run() - assert.NoError(t, err) + require.NoError(t, err) assertGeneratedFiles(t) err = os.RemoveAll(genTestDir2) - assert.NoError(t, err) + require.NoError(t, err) } func TestGenerator(t *testing.T) { @@ -83,19 +83,19 @@ func TestGenerator(t *testing.T) { SchemaName: "dvds", }) - assert.NoError(t, err) + require.NoError(t, err) assertGeneratedFiles(t) } err := os.RemoveAll(genTestDir2) - assert.NoError(t, err) + require.NoError(t, err) } func assertGeneratedFiles(t *testing.T) { // Table SQL Builder files tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table") - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", @@ -105,7 +105,7 @@ func assertGeneratedFiles(t *testing.T) { // View SQL Builder files viewSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/view") - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") @@ -114,14 +114,14 @@ func assertGeneratedFiles(t *testing.T) { // Enums SQL Builder files enumFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum") - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertFileNamesEqual(t, enumFiles, "mpaa_rating.go") testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", mpaaRatingEnumFile) // Model files modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model") - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", @@ -335,14 +335,14 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { tableDir := testRoot + ".gentestdata/jetdb/test_sample/table/" enumFiles, err := ioutil.ReadDir(enumDir) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertFileNamesEqual(t, enumFiles, "mood.go", "level.go") testutils.AssertFileContent(t, enumDir+"mood.go", moodEnumContent) testutils.AssertFileContent(t, enumDir+"level.go", levelEnumContent) modelFiles, err := ioutil.ReadDir(modelDir) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go", "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go") @@ -350,7 +350,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileContent(t, modelDir+"all_types.go", allTypesModelContent) tableFiles, err := ioutil.ReadDir(tableDir) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go", "person.go", "person_phone.go", "weird_names_table.go", "user.go") diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index 7367e1d..a7facea 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -6,7 +6,7 @@ import ( . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "math/rand" "testing" "time" @@ -40,9 +40,9 @@ RETURNING link.id AS "link.id", err := insertQuery.Query(db, &insertedLinks) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(insertedLinks), 3) + require.Equal(t, len(insertedLinks), 3) testutils.AssertDeepEqual(t, insertedLinks[0], model.Link{ ID: 100, @@ -69,7 +69,7 @@ RETURNING link.id AS "link.id", ORDER_BY(Link.ID). Query(db, &allLinks) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, insertedLinks, allLinks) } @@ -332,7 +332,7 @@ func TestInsertQuery(t *testing.T) { _, err := Link.DELETE(). WHERE(Link.ID.NOT_EQ(Int(0)).AND(Link.Name.EQ(String("Youtube")))). Exec(db) - assert.NoError(t, err) + require.NoError(t, err) var expectedSQL = ` INSERT INTO test_sample.link (url, name) ( @@ -362,7 +362,7 @@ RETURNING link.id AS "link.id", err = query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) youtubeLinks := []model.Link{} err = Link. @@ -370,8 +370,8 @@ RETURNING link.id AS "link.id", WHERE(Link.Name.EQ(String("Youtube"))). Query(db, &youtubeLinks) - assert.NoError(t, err) - assert.Equal(t, len(youtubeLinks), 2) + require.NoError(t, err) + require.Equal(t, len(youtubeLinks), 2) } func TestInsertWithQueryContext(t *testing.T) { @@ -389,7 +389,7 @@ func TestInsertWithQueryContext(t *testing.T) { dest := []model.Link{} err := stmt.QueryContext(ctx, db, &dest) - assert.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") } func TestInsertWithExecContext(t *testing.T) { @@ -405,5 +405,5 @@ func TestInsertWithExecContext(t *testing.T) { _, err := stmt.ExecContext(ctx, db) - assert.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") } diff --git a/tests/postgres/lock_test.go b/tests/postgres/lock_test.go index acfb852..ce55874 100644 --- a/tests/postgres/lock_test.go +++ b/tests/postgres/lock_test.go @@ -3,7 +3,7 @@ package postgres import ( "context" "github.com/go-jet/jet/internal/testutils" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "time" @@ -35,11 +35,11 @@ LOCK TABLE dvds.address IN` _, err := query.Exec(tx) - assert.NoError(t, err) + require.NoError(t, err) err = tx.Rollback() - assert.NoError(t, err) + require.NoError(t, err) } for _, lockMode := range testData { @@ -51,11 +51,11 @@ LOCK TABLE dvds.address IN` _, err := query.Exec(tx) - assert.NoError(t, err) + require.NoError(t, err) err = tx.Rollback() - assert.NoError(t, err) + require.NoError(t, err) } } @@ -70,5 +70,5 @@ func TestLockExecContext(t *testing.T) { _, err := Address.LOCK().IN(LOCK_ACCESS_SHARE).ExecContext(ctx, tx) - assert.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") } diff --git a/tests/postgres/northwind_test.go b/tests/postgres/northwind_test.go index a50122b..80ab589 100644 --- a/tests/postgres/northwind_test.go +++ b/tests/postgres/northwind_test.go @@ -4,7 +4,7 @@ import ( "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/.gentestdata/jetdb/northwind/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/northwind/table" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" ) @@ -59,7 +59,7 @@ func TestNorthwindJoinEverything(t *testing.T) { } err := stmt.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) //jsonSave("./testdata/northwind-all.json", dest) testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/northwind-all.json") diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index ea35c16..698e648 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -6,7 +6,6 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" "github.com/google/uuid" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "testing" ) @@ -26,8 +25,8 @@ WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'; result := model.AllTypes{} err := query.Query(db, &result) - assert.NoError(t, err) - assert.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) + require.NoError(t, err) + require.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) testutils.AssertDeepEqual(t, result.UUIDPtr, testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) } @@ -47,8 +46,8 @@ func TestUUIDComplex(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 2) + require.NoError(t, err) + require.Equal(t, len(dest), 2) testutils.AssertJSON(t, dest, ` [ { @@ -97,7 +96,7 @@ func TestUUIDComplex(t *testing.T) { } } err := singleQuery.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertJSON(t, dest, ` { @@ -133,7 +132,7 @@ func TestUUIDComplex(t *testing.T) { } err := leftQuery.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertJSON(t, dest, ` [ { @@ -195,7 +194,7 @@ FROM test_sample.person; err := query.Query(db, &result) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertJSON(t, result, ` [ { @@ -263,8 +262,8 @@ ORDER BY employee.employee_id; err = query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 8) + require.NoError(t, err) + require.Equal(t, len(dest), 8) testutils.AssertDeepEqual(t, dest[0].Employee, model.Employee{ EmployeeID: 1, FirstName: "Windy", @@ -273,7 +272,7 @@ ORDER BY employee.employee_id; ManagerID: nil, }) - assert.True(t, dest[0].Manager == nil) + require.True(t, dest[0].Manager == nil) testutils.AssertDeepEqual(t, dest[7].Employee, model.Employee{ EmployeeID: 8, @@ -311,9 +310,9 @@ FROM test_sample."WEIRD NAMES TABLE"; err := stmt.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 1) + require.Equal(t, len(dest), 1) testutils.AssertDeepEqual(t, dest[0], model.WeirdNamesTable{ WeirdColumnName1: "Doe", WeirdColumnName2: "Doe", @@ -360,7 +359,7 @@ FROM test_sample."User"; var dest []model.User err := stmt.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.PrintJson(dest) diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index 92251c3..f6eb9f4 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -8,7 +8,7 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table" "github.com/google/uuid" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" ) @@ -53,38 +53,38 @@ func TestScanToValidDestination(t *testing.T) { dest := []struct{}{} err := oneInventoryQuery.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) }) t.Run("global query function scan", func(t *testing.T) { queryStr, args := oneInventoryQuery.Sql() dest := []struct{}{} err := qrm.Query(nil, db, queryStr, args, &dest) - assert.NoError(t, err) + require.NoError(t, err) }) t.Run("pointer to slice", func(t *testing.T) { err := oneInventoryQuery.Query(db, &[]struct{}{}) - assert.NoError(t, err) + require.NoError(t, err) }) t.Run("pointer to slice of pointer to structs", func(t *testing.T) { err := oneInventoryQuery.Query(db, &[]*struct{}{}) - assert.NoError(t, err) + require.NoError(t, err) }) t.Run("pointer to slice of strings", func(t *testing.T) { err := oneInventoryQuery.Query(db, &[]int32{}) - assert.NoError(t, err) + require.NoError(t, err) }) t.Run("pointer to slice of strings", func(t *testing.T) { err := oneInventoryQuery.Query(db, &[]*int32{}) - assert.NoError(t, err) + require.NoError(t, err) }) } @@ -99,7 +99,7 @@ func TestScanToStruct(t *testing.T) { dest := model.Inventory{} err := query.LIMIT(1).Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, inventory1, dest) }) @@ -107,7 +107,7 @@ func TestScanToStruct(t *testing.T) { dest := model.Inventory{} err := query.LIMIT(10).Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, inventory1, dest) }) @@ -117,7 +117,7 @@ func TestScanToStruct(t *testing.T) { }{} err := query.LIMIT(1).Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, inventory1, dest.Inventory) }) @@ -127,7 +127,7 @@ func TestScanToStruct(t *testing.T) { }{} err := query.LIMIT(1).Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, inventory1, *dest.Inventory) }) @@ -158,11 +158,11 @@ func TestScanToStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, *dest.InventoryID, int32(1)) - assert.Equal(t, dest.FilmID, int16(1)) - assert.Equal(t, *dest.StoreID, int16(1)) + require.Equal(t, *dest.InventoryID, int32(1)) + require.Equal(t, dest.FilmID, int16(1)) + require.Equal(t, *dest.StoreID, int16(1)) }) t.Run("type convert int32 to int", func(t *testing.T) { @@ -175,7 +175,7 @@ func TestScanToStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) }) t.Run("type mismatch scanner type", func(t *testing.T) { @@ -217,7 +217,7 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Film, film1) testutils.AssertDeepEqual(t, dest.Store, store1) @@ -232,7 +232,7 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, *dest.Inventory, inventory1) testutils.AssertDeepEqual(t, *dest.Film, film1) testutils.AssertDeepEqual(t, *dest.Store, store1) @@ -246,7 +246,7 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Actor, model.Actor{}) }) @@ -259,7 +259,7 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Actor, (*model.Actor)(nil)) }) @@ -272,7 +272,7 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Actor, (*model.Actor)(nil)) }) @@ -291,9 +291,9 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest.Inventory, inventory1) - assert.True(t, dest.Actor != nil) + require.True(t, dest.Actor != nil) }) t.Run("struct embedded unused pointer", func(t *testing.T) { @@ -306,7 +306,7 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Actor, (*struct{ model.Actor })(nil)) }) @@ -322,7 +322,7 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Actor, (*struct { model.Actor @@ -341,9 +341,9 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest.Inventory, inventory1) - assert.True(t, dest.Actor != nil) + require.True(t, dest.Actor != nil) testutils.AssertDeepEqual(t, dest.Actor.Actor, model.Actor{}) testutils.AssertDeepEqual(t, dest.Actor.Film, film1) }) @@ -361,10 +361,10 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest.Inventory, inventory1) - assert.True(t, dest.Actor != nil) - assert.True(t, dest.Actor.Film != nil) + require.True(t, dest.Actor != nil) + require.True(t, dest.Actor.Film != nil) testutils.AssertDeepEqual(t, dest.Actor.Film.Film, &film1) }) @@ -398,7 +398,7 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Film.Film, film1) testutils.AssertDeepEqual(t, dest.Store, store1) @@ -423,8 +423,8 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 10) + require.NoError(t, err) + require.Equal(t, len(dest), 10) testutils.AssertDeepEqual(t, dest[0], inventory1) testutils.AssertDeepEqual(t, dest[1], inventory2) }) @@ -433,7 +433,7 @@ func TestScanToSlice(t *testing.T) { var dest []int32 err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest, []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) }) @@ -442,14 +442,14 @@ func TestScanToSlice(t *testing.T) { var dest []int err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) }) t.Run("slice type mismatch", func(t *testing.T) { var dest []bool testutils.AssertQueryPanicErr(t, query, db, &dest, `jet: can't append int32 to []bool slice`) - //assert.Error(t, err, `jet: can't append int32 to []bool slice `) + //require.Error(t, err, `jet: can't append int32 to []bool slice `) }) }) @@ -473,7 +473,7 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest.Film, film1) testutils.AssertDeepEqual(t, dest.IDs, []int32{1, 2, 3, 4, 5, 6, 7, 8}) }) @@ -486,8 +486,8 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 2) + require.NoError(t, err) + require.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0].Film, film1) testutils.AssertDeepEqual(t, dest[0].IDs, []int32{1, 2, 3, 4, 5, 6, 7, 8}) testutils.AssertDeepEqual(t, dest[1].Film, film2) @@ -502,8 +502,8 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 2) + require.NoError(t, err) + require.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0].Film, film1) testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{testutils.Int32Ptr(1), testutils.Int32Ptr(2), testutils.Int32Ptr(3), testutils.Int32Ptr(4), testutils.Int32Ptr(5), testutils.Int32Ptr(6), testutils.Int32Ptr(7), testutils.Int32Ptr(8)}) @@ -520,8 +520,8 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 10) + require.NoError(t, err) + require.Equal(t, len(dest), 10) testutils.AssertDeepEqual(t, dest[0].Inventory, inventory1) testutils.AssertDeepEqual(t, dest[0].Film, film1) testutils.AssertDeepEqual(t, dest[0].Store, store1) @@ -538,8 +538,8 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 10) + require.NoError(t, err) + require.Equal(t, len(dest), 10) testutils.AssertDeepEqual(t, dest[0].Inventory, &inventory1) testutils.AssertDeepEqual(t, dest[0].Film, film1) testutils.AssertDeepEqual(t, dest[0].Store, &store1) @@ -558,8 +558,8 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 10) + require.NoError(t, err) + require.Equal(t, len(dest), 10) testutils.AssertDeepEqual(t, dest[0].Inventory, inventory1) testutils.AssertDeepEqual(t, dest[0].Film, &film1) testutils.AssertDeepEqual(t, dest[0].Store.Store, &store1) @@ -579,8 +579,8 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 2) + require.NoError(t, err) + require.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0].Film, film1) testutils.AssertDeepEqual(t, len(dest[0].Inventories), 8) testutils.AssertDeepEqual(t, dest[0].Inventories[0].Inventory, inventory1) @@ -601,14 +601,14 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 2) + require.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0].Film, film1) - assert.Equal(t, len(dest[0].Inventories), 8) + require.Equal(t, len(dest[0].Inventories), 8) testutils.AssertDeepEqual(t, dest[0].Inventories[0].Inventory, inventory1) - assert.True(t, dest[0].Inventories[0].Rentals == nil) - assert.True(t, dest[0].Inventories[0].Rentals2 == nil) + require.True(t, dest[0].Inventories[0].Rentals == nil) + require.True(t, dest[0].Inventories[0].Rentals2 == nil) }) }) @@ -638,12 +638,12 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 108) + require.NoError(t, err) + require.Equal(t, len(dest), 108) testutils.AssertDeepEqual(t, dest[100].Country, countryUk) - assert.Equal(t, len(dest[100].Cities), 8) + require.Equal(t, len(dest[100].Cities), 8) testutils.AssertDeepEqual(t, dest[100].Cities[2].City, cityLondon) - assert.Equal(t, len(dest[100].Cities[2].Adresses), 2) + require.Equal(t, len(dest[100].Cities[2].Adresses), 2) testutils.AssertDeepEqual(t, dest[100].Cities[2].Adresses[0].Address, address256) testutils.AssertDeepEqual(t, dest[100].Cities[2].Adresses[0].Customer, customer256) testutils.AssertDeepEqual(t, dest[100].Cities[2].Adresses[1].Address, addres517) @@ -667,12 +667,12 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 108) + require.NoError(t, err) + require.Equal(t, len(dest), 108) testutils.AssertDeepEqual(t, dest[100].Country, &countryUk) - assert.Equal(t, len(dest[100].Cities), 8) + require.Equal(t, len(dest[100].Cities), 8) testutils.AssertDeepEqual(t, dest[100].Cities[2].City, &cityLondon) - assert.Equal(t, len(*dest[100].Cities[2].Adresses), 2) + require.Equal(t, len(*dest[100].Cities[2].Adresses), 2) testutils.AssertDeepEqual(t, (*dest[100].Cities[2].Adresses)[0].Address, &address256) testutils.AssertDeepEqual(t, (*dest[100].Cities[2].Adresses)[0].Customer, &customer256) testutils.AssertDeepEqual(t, (*dest[100].Cities[2].Adresses)[1].Address, &addres517) @@ -703,7 +703,7 @@ func TestStructScanErrNoRows(t *testing.T) { err := query.Query(db, &customer) - assert.Error(t, err, qrm.ErrNoRows.Error()) + require.Error(t, err, qrm.ErrNoRows.Error()) } func TestStructScanAllNull(t *testing.T) { @@ -716,7 +716,7 @@ func TestStructScanAllNull(t *testing.T) { err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest, struct { Null1 *int Null2 *int diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 28afd8c..5c52e71 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -8,7 +8,7 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/view" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "time" ) @@ -33,7 +33,7 @@ WHERE actor.actor_id = 2; actor := model.Actor{} err := query.Query(db, &actor) - assert.NoError(t, err) + require.NoError(t, err) expectedActor := model.Actor{ ActorID: 2, @@ -84,8 +84,8 @@ LIMIT 30; err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 30) + require.NoError(t, err) + require.Equal(t, len(dest), 30) } func TestSelect_ScanToSlice(t *testing.T) { @@ -110,9 +110,9 @@ ORDER BY customer.customer_id ASC; testutils.AssertDebugStatementSql(t, query, expectedSQL) err := query.Query(db, &customers) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(customers), 599) + require.Equal(t, len(customers), 599) testutils.AssertDeepEqual(t, customer0, customers[0]) testutils.AssertDeepEqual(t, customer1, customers[1]) @@ -164,7 +164,7 @@ LIMIT 12; dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestJoinQueryStruct(t *testing.T) { @@ -253,10 +253,10 @@ LIMIT 1000; err := query.Query(db, &languageActorFilm) - assert.NoError(t, err) - assert.Equal(t, len(languageActorFilm), 1) - assert.Equal(t, len(languageActorFilm[0].Films), 10) - assert.Equal(t, len(languageActorFilm[0].Films[0].Actors), 10) + require.NoError(t, err) + require.Equal(t, len(languageActorFilm), 1) + require.Equal(t, len(languageActorFilm[0].Films), 10) + require.Equal(t, len(languageActorFilm[0].Films[0].Actors), 10) } } @@ -302,20 +302,20 @@ LIMIT 15; err := query.Query(db, &filmsPerLanguage) - assert.NoError(t, err) - assert.Equal(t, len(filmsPerLanguage), 1) - assert.Equal(t, len(filmsPerLanguage[0].Film), limit) + require.NoError(t, err) + require.Equal(t, len(filmsPerLanguage), 1) + require.Equal(t, len(filmsPerLanguage[0].Film), limit) englishFilms := filmsPerLanguage[0] - assert.Equal(t, *englishFilms.Film[0].Rating, model.MpaaRating_Nc17) + require.Equal(t, *englishFilms.Film[0].Rating, model.MpaaRating_Nc17) filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} err = query.Query(db, &filmsPerLanguageWithPtrs) - assert.NoError(t, err) - assert.Equal(t, len(filmsPerLanguage), 1) - assert.Equal(t, len(filmsPerLanguage[0].Film), limit) + require.NoError(t, err) + require.Equal(t, len(filmsPerLanguage), 1) + require.Equal(t, len(filmsPerLanguage[0].Film), limit) } func TestExecution1(t *testing.T) { @@ -359,14 +359,14 @@ ORDER BY city.city_id, address.address_id, customer.customer_id; err := stmt.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 2) - assert.Equal(t, dest[0].City.City, "London") - assert.Equal(t, dest[1].City.City, "York") - assert.Equal(t, len(dest[0].Customers), 2) - assert.Equal(t, dest[0].Customers[0].LastName, "Hoffman") - assert.Equal(t, dest[0].Customers[1].LastName, "Vines") + require.Equal(t, len(dest), 2) + require.Equal(t, dest[0].City.City, "London") + require.Equal(t, dest[1].City.City, "York") + require.Equal(t, len(dest[0].Customers), 2) + require.Equal(t, dest[0].Customers[0].LastName, "Hoffman") + require.Equal(t, dest[0].Customers[1].LastName, "Vines") } @@ -423,14 +423,14 @@ ORDER BY city.city_id, address.address_id, customer.customer_id; err := stmt.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 2) - assert.Equal(t, dest[0].Name, "London") - assert.Equal(t, dest[1].Name, "York") - assert.Equal(t, len(dest[0].Customers), 2) - assert.Equal(t, *dest[0].Customers[0].LastName, "Hoffman") - assert.Equal(t, *dest[0].Customers[1].LastName, "Vines") + require.Equal(t, len(dest), 2) + require.Equal(t, dest[0].Name, "London") + require.Equal(t, dest[1].Name, "York") + require.Equal(t, len(dest[0].Customers), 2) + require.Equal(t, *dest[0].Customers[0].LastName, "Hoffman") + require.Equal(t, *dest[0].Customers[1].LastName, "Vines") } @@ -481,14 +481,14 @@ ORDER BY city.city_id, address.address_id, customer.customer_id; err := stmt.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 2) - assert.Equal(t, dest[0].CityName, "London") - assert.Equal(t, dest[1].CityName, "York") - assert.Equal(t, len(dest[0].Customers), 2) - assert.Equal(t, *dest[0].Customers[0].LastName, "Hoffman") - assert.Equal(t, *dest[0].Customers[1].LastName, "Vines") + require.Equal(t, len(dest), 2) + require.Equal(t, dest[0].CityName, "London") + require.Equal(t, dest[1].CityName, "York") + require.Equal(t, len(dest[0].Customers), 2) + require.Equal(t, *dest[0].Customers[0].LastName, "Hoffman") + require.Equal(t, *dest[0].Customers[1].LastName, "Vines") } func TestExecution4(t *testing.T) { @@ -538,8 +538,8 @@ ORDER BY city.city_id, address.address_id, customer.customer_id; err := stmt.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 2) + require.NoError(t, err) + require.Equal(t, len(dest), 2) testutils.AssertJSON(t, dest, ` [ { @@ -597,9 +597,9 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) { filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} err := query.Query(db, &filmsPerLanguageWithPtrs) - assert.NoError(t, err) - assert.Equal(t, len(filmsPerLanguageWithPtrs), 1) - assert.Equal(t, len(*filmsPerLanguageWithPtrs[0].Film), int(limit)) + require.NoError(t, err) + require.Equal(t, len(filmsPerLanguageWithPtrs), 1) + require.Equal(t, len(*filmsPerLanguageWithPtrs[0].Film), int(limit)) } func TestSelect_WithoutUniqueColumnSelected(t *testing.T) { @@ -609,11 +609,11 @@ func TestSelect_WithoutUniqueColumnSelected(t *testing.T) { err := query.Query(db, &customers) - assert.NoError(t, err) + require.NoError(t, err) //spew.Dump(customers) - assert.Equal(t, len(customers), 599) + require.Equal(t, len(customers), 599) } func TestSelectOrderByAscDesc(t *testing.T) { @@ -623,7 +623,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { ORDER_BY(Customer.FirstName.ASC()). Query(db, &customersAsc) - assert.NoError(t, err) + require.NoError(t, err) firstCustomerAsc := customersAsc[0] lastCustomerAsc := customersAsc[len(customersAsc)-1] @@ -633,7 +633,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { ORDER_BY(Customer.FirstName.DESC()). Query(db, &customersDesc) - assert.NoError(t, err) + require.NoError(t, err) firstCustomerDesc := customersDesc[0] lastCustomerDesc := customersDesc[len(customersAsc)-1] @@ -646,7 +646,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { ORDER_BY(Customer.FirstName.ASC(), Customer.LastName.DESC()). Query(db, &customersAscDesc) - assert.NoError(t, err) + require.NoError(t, err) customerAscDesc326 := model.Customer{ CustomerID: 67, @@ -702,16 +702,16 @@ ORDER BY customer.customer_id ASC; err := query.Query(db, &allCustomersAndAddress) - assert.NoError(t, err) - assert.Equal(t, len(allCustomersAndAddress), 603) + require.NoError(t, err) + require.Equal(t, len(allCustomersAndAddress), 603) testutils.AssertDeepEqual(t, allCustomersAndAddress[0].Customer, &customer0) - assert.True(t, allCustomersAndAddress[0].Address != nil) + require.True(t, allCustomersAndAddress[0].Address != nil) lastCustomerAddress := allCustomersAndAddress[len(allCustomersAndAddress)-1] - assert.True(t, lastCustomerAddress.Customer == nil) - assert.True(t, lastCustomerAddress.Address != nil) + require.True(t, lastCustomerAddress.Customer == nil) + require.True(t, lastCustomerAddress.Address != nil) } @@ -755,9 +755,9 @@ LIMIT 1000; err := query.Query(db, &customerAddresCrosJoined) - assert.Equal(t, len(customerAddresCrosJoined), 1000) + require.Equal(t, len(customerAddresCrosJoined), 1000) - assert.NoError(t, err) + require.NoError(t, err) } func TestSelectSelfJoin(t *testing.T) { @@ -813,11 +813,11 @@ ORDER BY f1.film_id ASC; err := query.Query(db, &theSameLengthFilms) - assert.NoError(t, err) + require.NoError(t, err) //spew.Dump(theSameLengthFilms) - //assert.Equal(t, len(theSameLengthFilms), 100) + //require.Equal(t, len(theSameLengthFilms), 100) } func TestSelectAliasColumn(t *testing.T) { @@ -854,11 +854,11 @@ LIMIT 1000; err := query.Query(db, &films) - assert.NoError(t, err) + require.NoError(t, err) //spew.Dump(films) - assert.Equal(t, len(films), 1000) + require.Equal(t, len(films), 1000) testutils.AssertDeepEqual(t, films[0], thesameLengthFilms{"Alien Center", "Iron Moon", 46}) } @@ -911,7 +911,7 @@ FROM dvds.actor err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestSelectFunctions(t *testing.T) { @@ -931,8 +931,8 @@ FROM dvds.film; err := query.Query(db, &ret) - assert.NoError(t, err) - assert.Equal(t, ret.MaxFilmRate, 4.99) + require.NoError(t, err) + require.Equal(t, ret.MaxFilmRate, 4.99) } func TestSelectQueryScalar(t *testing.T) { @@ -973,9 +973,9 @@ ORDER BY film.film_id ASC; maxRentalRateFilms := []model.Film{} err := query.Query(db, &maxRentalRateFilms) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(maxRentalRateFilms), 336) + require.Equal(t, len(maxRentalRateFilms), 336) gRating := model.MpaaRating_G @@ -1060,11 +1060,11 @@ ORDER BY customer.customer_id, SUM(payment.amount) ASC; err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) //testutils.PrintJson(dest) - assert.Equal(t, len(dest), 104) + require.Equal(t, len(dest), 104) //testutils.SaveJsonFile(dest, "postgres/testdata/customer_payment_sum.json") testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/customer_payment_sum.json") @@ -1121,8 +1121,8 @@ ORDER BY customer_payment_sum."amount_sum" ASC; customersWithAmounts := []CustomerWithAmounts{} err := query.Query(db, &customersWithAmounts) - assert.NoError(t, err) - assert.Equal(t, len(customersWithAmounts), 599) + require.NoError(t, err) + require.Equal(t, len(customersWithAmounts), 599) testutils.AssertDeepEqual(t, customersWithAmounts[0].Customer, &model.Customer{ CustomerID: 318, @@ -1137,7 +1137,7 @@ ORDER BY customer_payment_sum."amount_sum" ASC; Active: testutils.Int32Ptr(1), }) - assert.Equal(t, customersWithAmounts[0].AmountSum, 27.93) + require.Equal(t, customersWithAmounts[0].AmountSum, 27.93) } func TestSelectStaff(t *testing.T) { @@ -1145,7 +1145,7 @@ func TestSelectStaff(t *testing.T) { err := Staff.SELECT(Staff.AllColumns).Query(db, &staffs) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertJSON(t, staffs, ` [ @@ -1203,11 +1203,11 @@ ORDER BY payment.payment_date ASC; err := query.Query(db, &payments) - assert.NoError(t, err) + require.NoError(t, err) //spew.Dump(payments) - assert.Equal(t, len(payments), 9) + require.Equal(t, len(payments), 9) testutils.AssertDeepEqual(t, payments[0], model.Payment{ PaymentID: 17793, CustomerID: 416, @@ -1257,8 +1257,8 @@ OFFSET 20; err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 10) + require.NoError(t, err) + require.Equal(t, len(dest), 10) testutils.AssertDeepEqual(t, dest[0], model.Payment{ PaymentID: 17523, Amount: 4.99, @@ -1283,8 +1283,8 @@ func TestAllSetOperators(t *testing.T) { dest := []model.Payment{} err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 20) + require.NoError(t, err) + require.Equal(t, len(dest), 20) }) t.Run("UNION_ALL", func(t *testing.T) { @@ -1293,8 +1293,8 @@ func TestAllSetOperators(t *testing.T) { dest := []model.Payment{} err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 20) + require.NoError(t, err) + require.Equal(t, len(dest), 20) }) t.Run("INTERSECT", func(t *testing.T) { @@ -1303,8 +1303,8 @@ func TestAllSetOperators(t *testing.T) { dest := []model.Payment{} err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 0) + require.NoError(t, err) + require.Equal(t, len(dest), 0) }) t.Run("INTERSECT_ALL", func(t *testing.T) { @@ -1313,8 +1313,8 @@ func TestAllSetOperators(t *testing.T) { dest := []model.Payment{} err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 0) + require.NoError(t, err) + require.Equal(t, len(dest), 0) }) t.Run("EXCEPT", func(t *testing.T) { @@ -1323,8 +1323,8 @@ func TestAllSetOperators(t *testing.T) { dest := []model.Payment{} err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 10) + require.NoError(t, err) + require.Equal(t, len(dest), 10) }) t.Run("EXCEPT_ALL", func(t *testing.T) { @@ -1333,8 +1333,8 @@ func TestAllSetOperators(t *testing.T) { dest := []model.Payment{} err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 10) + require.NoError(t, err) + require.Equal(t, len(dest), 10) }) } @@ -1363,10 +1363,10 @@ LIMIT 20; err := query.Query(db, &dest) - assert.NoError(t, err) - assert.Equal(t, len(dest), 20) - assert.Equal(t, dest[0].StaffIDNum, "TWO") - assert.Equal(t, dest[1].StaffIDNum, "ONE") + require.NoError(t, err) + require.Equal(t, len(dest), 20) + require.Equal(t, dest[0].StaffIDNum, "TWO") + require.Equal(t, dest[1].StaffIDNum, "ONE") } func getRowLockTestData() map[RowLock]string { @@ -1396,12 +1396,12 @@ FOR` tx, _ := db.Begin() res, err := query.Exec(tx) - assert.NoError(t, err) + require.NoError(t, err) rowsAffected, _ := res.RowsAffected() - assert.Equal(t, rowsAffected, int64(3)) + require.Equal(t, rowsAffected, int64(3)) err = tx.Rollback() - assert.NoError(t, err) + require.NoError(t, err) } for lockType, lockTypeStr := range getRowLockTestData() { @@ -1412,12 +1412,12 @@ FOR` tx, _ := db.Begin() res, err := query.Exec(tx) - assert.NoError(t, err) + require.NoError(t, err) rowsAffected, _ := res.RowsAffected() - assert.Equal(t, rowsAffected, int64(3)) + require.Equal(t, rowsAffected, int64(3)) err = tx.Rollback() - assert.NoError(t, err) + require.NoError(t, err) } for lockType, lockTypeStr := range getRowLockTestData() { @@ -1428,12 +1428,12 @@ FOR` tx, _ := db.Begin() res, err := query.Exec(tx) - assert.NoError(t, err) + require.NoError(t, err) rowsAffected, _ := res.RowsAffected() - assert.Equal(t, rowsAffected, int64(3)) + require.Equal(t, rowsAffected, int64(3)) err = tx.Rollback() - assert.NoError(t, err) + require.NoError(t, err) } } @@ -1509,7 +1509,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC; } err := stmt.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) //jsonSave("./testdata/quick-start-dest.json", dest) testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/quick-start-dest.json") @@ -1522,7 +1522,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC; } err = stmt.Query(db, &dest2) - assert.NoError(t, err) + require.NoError(t, err) //jsonSave("./testdata/quick-start-dest2.json", dest2) testutils.AssertJSONFile(t, dest2, "./testdata/results/postgres/quick-start-dest2.json") @@ -1574,7 +1574,7 @@ func TestQuickStartWithSubQueries(t *testing.T) { } err := stmt.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) //jsonSave("./testdata/quick-start-dest.json", dest) testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/quick-start-dest.json") @@ -1587,7 +1587,7 @@ func TestQuickStartWithSubQueries(t *testing.T) { } err = stmt.Query(db, &dest2) - assert.NoError(t, err) + require.NoError(t, err) //jsonSave("./testdata/quick-start-dest2.json", dest2) testutils.AssertJSONFile(t, dest2, "./testdata/results/postgres/quick-start-dest2.json") @@ -1620,7 +1620,7 @@ SELECT true, dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestWindowFunction(t *testing.T) { @@ -1692,7 +1692,7 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date; dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestWindowClause(t *testing.T) { @@ -1729,7 +1729,7 @@ ORDER BY payment.customer_id; dest := []struct{}{} err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) } func TestSimpleView(t *testing.T) { @@ -1751,7 +1751,7 @@ func TestSimpleView(t *testing.T) { var dest []ActorInfo err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) testutils.AssertJSON(t, dest[1:2], ` [ @@ -1785,11 +1785,11 @@ func TestJoinViewWithTable(t *testing.T) { fmt.Println(query.DebugSql()) err := query.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 2) - assert.Equal(t, len(dest[0].Rentals), 32) - assert.Equal(t, len(dest[1].Rentals), 27) + require.Equal(t, len(dest), 2) + require.Equal(t, len(dest[0].Rentals), 32) + require.Equal(t, len(dest[1].Rentals), 27) } func TestDynamicProjectionList(t *testing.T) { @@ -1834,9 +1834,9 @@ LIMIT 3; `) var dest []model.Customer err := stmt.Query(db, &dest) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, len(dest), 3) + require.Equal(t, len(dest), 3) } func TestDynamicCondition(t *testing.T) { @@ -1884,7 +1884,7 @@ WHERE ($1 AND (customer.customer_id = $2)) AND (customer.activebool = $3); dest := []model.Customer{} err := stmt.Query(db, &dest) - assert.NoError(t, err) - assert.Len(t, dest, 1) + require.NoError(t, err) + require.Len(t, dest, 1) testutils.AssertDeepEqual(t, dest[0], customer0) } diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index 43e64fb..0862fb3 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -6,7 +6,6 @@ import ( . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "testing" "time" @@ -143,10 +142,10 @@ RETURNING link.id AS "link.id", err := stmt.Query(db, &links) - assert.NoError(t, err) - assert.Equal(t, len(links), 2) - assert.Equal(t, links[0].Name, "DuckDuckGo") - assert.Equal(t, links[1].Name, "DuckDuckGo") + require.NoError(t, err) + require.Equal(t, len(links), 2) + require.Equal(t, links[0].Name, "DuckDuckGo") + require.Equal(t, links[1].Name, "DuckDuckGo") } func TestUpdateWithSelect(t *testing.T) { @@ -294,7 +293,7 @@ func TestUpdateWithInvalidModelData(t *testing.T) { defer func() { r := recover() - assert.Equal(t, r, "missing struct field for column : id") + require.Equal(t, r, "missing struct field for column : id") }() setupLinkTableForUpdateTest(t) @@ -342,7 +341,7 @@ func TestUpdateQueryContext(t *testing.T) { dest := []model.Link{} err := updateStmt.QueryContext(ctx, db, &dest) - assert.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") } func TestUpdateExecContext(t *testing.T) { @@ -360,7 +359,7 @@ func TestUpdateExecContext(t *testing.T) { _, err := updateStmt.ExecContext(ctx, db) - assert.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") } func setupLinkTableForUpdateTest(t *testing.T) { @@ -375,10 +374,10 @@ func setupLinkTableForUpdateTest(t *testing.T) { VALUES(204, "http://www.bing.com", "Bing", DEFAULT). Exec(db) - assert.NoError(t, err) + require.NoError(t, err) } func cleanUpLinkTable(t *testing.T) { _, err := Link.DELETE().WHERE(Link.ID.GT(Int(0))).Exec(db) - assert.NoError(t, err) + require.NoError(t, err) } diff --git a/tests/postgres/util_test.go b/tests/postgres/util_test.go index b2d5452..4d6f478 100644 --- a/tests/postgres/util_test.go +++ b/tests/postgres/util_test.go @@ -4,7 +4,6 @@ import ( "github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "testing" ) @@ -14,8 +13,8 @@ func AssertExec(t *testing.T, stmt jet.Statement, rowsAffected int64) { require.NoError(t, err) rows, err := res.RowsAffected() - assert.NoError(t, err) - assert.Equal(t, rows, rowsAffected) + require.NoError(t, err) + require.Equal(t, rows, rowsAffected) } var customer0 = model.Customer{ From 0d3ec872d6b7351aa890028c0fdba710c9e2c2f8 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 10 May 2020 11:41:07 +0200 Subject: [PATCH 10/23] Add support for automatic query logging. --- internal/jet/logger.go | 19 +++++++++++++++++++ internal/jet/statement.go | 31 +++++++++++++++++++++++-------- mysql/types.go | 6 ++++++ postgres/types.go | 6 ++++++ tests/mysql/alltypes_test.go | 1 + tests/mysql/cast_test.go | 2 ++ tests/mysql/delete_test.go | 3 +++ tests/mysql/insert_test.go | 2 ++ tests/mysql/lock_test.go | 3 +++ tests/mysql/main_test.go | 22 ++++++++++++++++++++++ tests/mysql/select_test.go | 3 +++ tests/mysql/update_test.go | 8 ++++++-- tests/postgres/alltypes_test.go | 3 +++ tests/postgres/chinook_db_test.go | 2 ++ tests/postgres/delete_test.go | 3 +++ tests/postgres/insert_test.go | 3 +++ tests/postgres/lock_test.go | 2 ++ tests/postgres/main_test.go | 21 +++++++++++++++++++++ tests/postgres/northwind_test.go | 1 + tests/postgres/sample_test.go | 3 +++ tests/postgres/select_test.go | 6 ++++++ tests/postgres/update_test.go | 15 ++++++++++++--- 22 files changed, 152 insertions(+), 13 deletions(-) create mode 100644 internal/jet/logger.go diff --git a/internal/jet/logger.go b/internal/jet/logger.go new file mode 100644 index 0000000..90818b0 --- /dev/null +++ b/internal/jet/logger.go @@ -0,0 +1,19 @@ +package jet + +import "context" + +// LoggableStatement is a statement which sql query can be logged +type LoggableStatement interface { + Sql() (query string, args []interface{}) + DebugSql() (query string) +} + +// LoggerFunc is a definition of a function user can implement to support automatic statement logging. +type LoggerFunc func(ctx context.Context, statement LoggableStatement) + +var logger LoggerFunc + +// SetLoggerFunc sets automatic statement logging +func SetLoggerFunc(loggerFunc LoggerFunc) { + logger = loggerFunc +} diff --git a/internal/jet/statement.go b/internal/jet/statement.go index beb52f1..37b2077 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -13,7 +13,6 @@ type Statement interface { // DebugSql returns debug query where every parametrized placeholder is replaced with its argument. // Do not use it in production. Use it only for debug purposes. DebugSql() (query string) - // Query executes statement over database connection db and stores row result in destination. // Destination can be either pointer to struct or pointer to a slice. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. @@ -21,12 +20,12 @@ type Statement interface { // QueryContext executes statement with a context over database connection db and stores row result in destination. // Destination can be either pointer to struct or pointer to a slice. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. - QueryContext(context context.Context, db qrm.DB, destination interface{}) error + QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error //Exec executes statement over db connection without returning any rows. Exec(db qrm.DB) (sql.Result, error) //Exec executes statement with context over db connection without returning any rows. - ExecContext(context context.Context, db qrm.DB) (sql.Result, error) + ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error) } // SerializerStatement interface @@ -75,25 +74,41 @@ func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { func (s *serializerStatementInterfaceImpl) Query(db qrm.DB, destination interface{}) error { query, args := s.Sql() + ctx := context.Background() - return qrm.Query(context.Background(), db, query, args, destination) + callLogger(ctx, s) + + return qrm.Query(ctx, db, query, args, destination) } -func (s *serializerStatementInterfaceImpl) QueryContext(context context.Context, db qrm.DB, destination interface{}) error { +func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error { query, args := s.Sql() - return qrm.Query(context, db, query, args, destination) + callLogger(ctx, s) + + return qrm.Query(ctx, db, query, args, destination) } func (s *serializerStatementInterfaceImpl) Exec(db qrm.DB) (res sql.Result, err error) { query, args := s.Sql() + + callLogger(context.Background(), s) + return db.Exec(query, args...) } -func (s *serializerStatementInterfaceImpl) ExecContext(context context.Context, db qrm.DB) (res sql.Result, err error) { +func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.DB) (res sql.Result, err error) { query, args := s.Sql() - return db.ExecContext(context, query, args...) + callLogger(ctx, s) + + return db.ExecContext(ctx, query, args...) +} + +func callLogger(ctx context.Context, statement Statement) { + if logger != nil { + logger(ctx, statement) + } } // ExpressionStatement interfacess diff --git a/mysql/types.go b/mysql/types.go index 908fce5..7e1424f 100644 --- a/mysql/types.go +++ b/mysql/types.go @@ -13,3 +13,9 @@ type ProjectionList = jet.ProjectionList // ColumnAssigment is interface wrapper around column assigment type ColumnAssigment = jet.ColumnAssigment + +// LoggableStatement is a statement which sql query can be logged +type LoggableStatement = jet.LoggableStatement + +// SetLogger sets automatic statement logging +var SetLogger = jet.SetLoggerFunc diff --git a/postgres/types.go b/postgres/types.go index 48de455..cfb52ec 100644 --- a/postgres/types.go +++ b/postgres/types.go @@ -13,3 +13,9 @@ type ProjectionList = jet.ProjectionList // ColumnAssigment is interface wrapper around column assigment type ColumnAssigment = jet.ColumnAssigment + +// LoggableStatement is a statement which sql query can be logged +type LoggableStatement = jet.LoggableStatement + +// SetLogger sets automatic statement logging +var SetLogger = jet.SetLoggerFunc diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 7d9e17e..d791d42 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -80,6 +80,7 @@ func TestUUID(t *testing.T) { require.True(t, dest.UUID.String() != uuid.UUID{}.String()) require.True(t, dest.StrUUID.String() != uuid.UUID{}.String()) require.Equal(t, dest.StrUUID.String(), dest.BinUUID.String()) + requireLogged(t, query) } func TestExpressionOperators(t *testing.T) { diff --git a/tests/mysql/cast_test.go b/tests/mysql/cast_test.go index 218665e..fda79e7 100644 --- a/tests/mysql/cast_test.go +++ b/tests/mysql/cast_test.go @@ -68,4 +68,6 @@ FROM test_sample.all_types; Unsigned: 15, Binary: "Some text", }) + + requireLogged(t, query) } diff --git a/tests/mysql/delete_test.go b/tests/mysql/delete_test.go index da91e97..90d15cc 100644 --- a/tests/mysql/delete_test.go +++ b/tests/mysql/delete_test.go @@ -24,6 +24,7 @@ WHERE link.name IN ('Gmail', 'Outlook'); testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") testutils.AssertExec(t, deleteStmt, db, 2) + requireLogged(t, deleteStmt) } func TestDeleteWithWhereOrderByLimit(t *testing.T) { @@ -43,6 +44,7 @@ LIMIT 1; testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook", int64(1)) testutils.AssertExec(t, deleteStmt, db, 1) + requireLogged(t, deleteStmt) } func TestDeleteQueryContext(t *testing.T) { @@ -61,6 +63,7 @@ func TestDeleteQueryContext(t *testing.T) { err := deleteStmt.QueryContext(ctx, db, &dest) require.Error(t, err, "context deadline exceeded") + requireLogged(t, deleteStmt) } func TestDeleteExecContext(t *testing.T) { diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index 613a655..43091b2 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -34,6 +34,7 @@ VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT _, err := insertQuery.Exec(db) require.NoError(t, err) + requireLogged(t, insertQuery) insertedLinks := []model.Link{} @@ -82,6 +83,7 @@ VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT _, err := stmt.Exec(db) require.NoError(t, err) + requireLogged(t, stmt) insertedLinks := []model.Link{} diff --git a/tests/mysql/lock_test.go b/tests/mysql/lock_test.go index 8aed571..c44c436 100644 --- a/tests/mysql/lock_test.go +++ b/tests/mysql/lock_test.go @@ -17,6 +17,7 @@ LOCK TABLES dvds.customer READ; _, err := query.Exec(db) require.NoError(t, err) + requireLogged(t, query) } func TestLockWrite(t *testing.T) { @@ -28,6 +29,7 @@ LOCK TABLES dvds.customer WRITE; _, err := query.Exec(db) require.NoError(t, err) + requireLogged(t, query) } func TestUnlockTables(t *testing.T) { @@ -39,4 +41,5 @@ UNLOCK TABLES; _, err := query.Exec(db) require.NoError(t, err) + requireLogged(t, query) } diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index c7db884..0f51875 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -1,9 +1,13 @@ package mysql import ( + "context" "database/sql" "flag" + jetmysql "github.com/go-jet/jet/mysql" + "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/dbconfig" + "github.com/stretchr/testify/require" "math/rand" "time" @@ -44,3 +48,21 @@ func TestMain(m *testing.M) { os.Exit(ret) } + +var loggedSQL string +var loggedSQLArgs []interface{} +var loggedDebugSQL string + +func init() { + jetmysql.SetLogger(func(ctx context.Context, statement jetmysql.LoggableStatement) { + loggedSQL, loggedSQLArgs = statement.Sql() + loggedDebugSQL = statement.DebugSql() + }) +} + +func requireLogged(t *testing.T, statement postgres.Statement) { + query, args := statement.Sql() + require.Equal(t, loggedSQL, query) + require.Equal(t, loggedSQLArgs, args) + require.Equal(t, loggedDebugSQL, statement.DebugSql()) +} diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index e5b748e..5fd8fdc 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -33,6 +33,7 @@ WHERE actor.actor_id = ?; require.NoError(t, err) testutils.AssertDeepEqual(t, actor, actor2) + requireLogged(t, query) } var actor2 = model.Actor{ @@ -67,6 +68,7 @@ ORDER BY actor.actor_id; //testutils.PrintJson(dest) //testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json") testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/all_actors.json") + requireLogged(t, query) } func TestSelectGroupByHaving(t *testing.T) { @@ -144,6 +146,7 @@ ORDER BY payment.customer_id, SUM(payment.amount) ASC; //testutils.SaveJsonFile(dest, "mysql/testdata/customer_payment_sum.json") testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/customer_payment_sum.json") + requireLogged(t, query) } func TestSubQuery(t *testing.T) { diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index a689584..94a6716 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -30,6 +30,7 @@ WHERE link.name = 'Bing'; testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") testutils.AssertExec(t, query, db) + requireLogged(t, query) }) t.Run("new version", func(t *testing.T) { @@ -42,6 +43,7 @@ WHERE link.name = 'Bing'; testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "Bong", "http://bong.com", "Bing") testutils.AssertExec(t, stmt, db) + requireLogged(t, stmt) }) links := []model.Link{} @@ -88,6 +90,7 @@ WHERE link.name = ?; testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") testutils.AssertExec(t, query, db) + requireLogged(t, query) }) t.Run("new version", func(t *testing.T) { @@ -105,6 +108,7 @@ WHERE link.name = ?; testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") testutils.AssertExec(t, query, db) + requireLogged(t, query) }) } @@ -130,10 +134,10 @@ SET id = ?, description = ? WHERE link.id = ?; ` - fmt.Println(stmt.Sql()) testutils.AssertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) testutils.AssertExec(t, stmt, db) + requireLogged(t, stmt) } func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { @@ -165,10 +169,10 @@ WHERE link.id = 201; testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201)) testutils.AssertExec(t, stmt, db) + requireLogged(t, stmt) } func TestUpdateWithModelDataAndMutableColumns(t *testing.T) { - setupLinkTableForUpdateTest(t) link := model.Link{ diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index d4e6ab4..3fa5543 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -834,6 +834,7 @@ func TestInterval(t *testing.T) { err := stmt.Query(db, &struct{}{}) require.NoError(t, err) + requireLogged(t, stmt) } func TestSubQueryColumnReference(t *testing.T) { @@ -1009,6 +1010,7 @@ FROM` require.NoError(t, err) testutils.AssertDeepEqual(t, dest1, dest2) + requireLogged(t, stmt2) } } @@ -1062,6 +1064,7 @@ LIMIT $6; "Timestamp": "2009-11-17T20:34:58.651387Z" } `) + requireLogged(t, query) } var allTypesRow0 = model.AllTypes{ diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index 2695981..5c12010 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -35,6 +35,7 @@ ORDER BY "Album"."AlbumId" ASC; testutils.AssertDeepEqual(t, dest[0], album1) testutils.AssertDeepEqual(t, dest[1], album2) testutils.AssertDeepEqual(t, dest[len(dest)-1], album347) + requireLogged(t, stmt) } func TestJoinEverything(t *testing.T) { @@ -106,6 +107,7 @@ func TestJoinEverything(t *testing.T) { require.NoError(t, err) require.Equal(t, len(dest), 275) testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/joined_everything.json") + requireLogged(t, stmt) } func TestSelfJoin(t *testing.T) { diff --git a/tests/postgres/delete_test.go b/tests/postgres/delete_test.go index 18080fc..01104d4 100644 --- a/tests/postgres/delete_test.go +++ b/tests/postgres/delete_test.go @@ -53,6 +53,7 @@ RETURNING link.id AS "link.id", require.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0].Name, "Gmail") testutils.AssertDeepEqual(t, dest[1].Name, "Outlook") + requireLogged(t, deleteStmt) } func initForDeleteTest(t *testing.T) { @@ -80,6 +81,7 @@ func TestDeleteQueryContext(t *testing.T) { err := deleteStmt.QueryContext(ctx, db, &dest) require.Error(t, err, "context deadline exceeded") + requireLogged(t, deleteStmt) } func TestDeleteExecContext(t *testing.T) { @@ -99,4 +101,5 @@ func TestDeleteExecContext(t *testing.T) { _, err := deleteStmt.ExecContext(ctx, db) require.Error(t, err, "context deadline exceeded") + requireLogged(t, deleteStmt) } diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index a7facea..e7dac15 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -89,6 +89,7 @@ VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial") AssertExec(t, stmt, 1) + requireLogged(t, stmt) } func TestInsertOnConflict(t *testing.T) { @@ -108,6 +109,7 @@ VALUES ($1, $2, $3, $4, $5), ON CONFLICT (employee_id) DO NOTHING; `) AssertExec(t, stmt, 1) + requireLogged(t, stmt) }) t.Run("on constraint do nothing", func(t *testing.T) { @@ -125,6 +127,7 @@ VALUES ($1, $2, $3, $4, $5), ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; `) AssertExec(t, stmt, 1) + requireLogged(t, stmt) }) t.Run("do update", func(t *testing.T) { diff --git a/tests/postgres/lock_test.go b/tests/postgres/lock_test.go index ce55874..c27adf3 100644 --- a/tests/postgres/lock_test.go +++ b/tests/postgres/lock_test.go @@ -40,6 +40,7 @@ LOCK TABLE dvds.address IN` err = tx.Rollback() require.NoError(t, err) + requireLogged(t, query) } for _, lockMode := range testData { @@ -56,6 +57,7 @@ LOCK TABLE dvds.address IN` err = tx.Rollback() require.NoError(t, err) + requireLogged(t, query) } } diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index fd538d6..5fa23d5 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -1,10 +1,13 @@ package postgres import ( + "context" "database/sql" + "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/dbconfig" _ "github.com/lib/pq" "github.com/pkg/profile" + "github.com/stretchr/testify/require" "math/rand" "os" "os/exec" @@ -43,3 +46,21 @@ func setTestRoot() { testRoot = strings.TrimSpace(string(byteArr)) + "/tests/" } + +var loggedSQL string +var loggedSQLArgs []interface{} +var loggedDebugSQL string + +func init() { + postgres.SetLogger(func(ctx context.Context, statement postgres.LoggableStatement) { + loggedSQL, loggedSQLArgs = statement.Sql() + loggedDebugSQL = statement.DebugSql() + }) +} + +func requireLogged(t *testing.T, statement postgres.Statement) { + query, args := statement.Sql() + require.Equal(t, loggedSQL, query) + require.Equal(t, loggedSQLArgs, args) + require.Equal(t, loggedDebugSQL, statement.DebugSql()) +} diff --git a/tests/postgres/northwind_test.go b/tests/postgres/northwind_test.go index 80ab589..e45661a 100644 --- a/tests/postgres/northwind_test.go +++ b/tests/postgres/northwind_test.go @@ -63,4 +63,5 @@ func TestNorthwindJoinEverything(t *testing.T) { //jsonSave("./testdata/northwind-all.json", dest) testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/northwind-all.json") + requireLogged(t, stmt) } diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 698e648..b3429fc 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -28,6 +28,7 @@ WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'; require.NoError(t, err) require.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) testutils.AssertDeepEqual(t, result.UUIDPtr, testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) + requireLogged(t, query) } func TestUUIDComplex(t *testing.T) { @@ -118,6 +119,7 @@ func TestUUIDComplex(t *testing.T) { ] } `) + requireLogged(t, query) }) t.Run("slice of structs left join", func(t *testing.T) { @@ -175,6 +177,7 @@ func TestUUIDComplex(t *testing.T) { } ] `) + requireLogged(t, leftQuery) }) } diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 5c52e71..35f9803 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -43,6 +43,8 @@ WHERE actor.actor_id = 2; } testutils.AssertDeepEqual(t, actor, expectedActor) + + requireLogged(t, query) } func TestClassicSelect(t *testing.T) { @@ -86,6 +88,8 @@ LIMIT 30; require.NoError(t, err) require.Equal(t, len(dest), 30) + + requireLogged(t, query) } func TestSelect_ScanToSlice(t *testing.T) { @@ -117,6 +121,8 @@ ORDER BY customer.customer_id ASC; testutils.AssertDeepEqual(t, customer0, customers[0]) testutils.AssertDeepEqual(t, customer1, customers[1]) testutils.AssertDeepEqual(t, lastCustomer, customers[598]) + + requireLogged(t, query) } func TestSelectAndUnionInProjection(t *testing.T) { diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index 0862fb3..5a7cdcd 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -27,13 +27,15 @@ WHERE link.name = 'Bing'; `, "Bong", "http://bong.com", "Bing") testutils.AssertExec(t, query, db, 1) + requireLogged(t, query) links := []model.Link{} - err := Link. + selQuery := Link. SELECT(Link.AllColumns). - WHERE(Link.Name.IN(String("Bong"))). - Query(db, &links) + WHERE(Link.Name.IN(String("Bong"))) + + err := selQuery.Query(db, &links) require.NoError(t, err) require.Equal(t, len(links), 1) @@ -42,6 +44,7 @@ WHERE link.name = 'Bing'; URL: "http://bong.com", Name: "Bong", }) + requireLogged(t, selQuery) }) t.Run("new version", func(t *testing.T) { @@ -59,6 +62,7 @@ SET name = 'DuckDuckGo', WHERE link.name = 'Yahoo'; `) testutils.AssertExec(t, stmt, db, 1) + requireLogged(t, stmt) }) } @@ -90,6 +94,7 @@ WHERE link.name = 'Bing'; testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "Bing", "Bing") AssertExec(t, query, 1) + requireLogged(t, query) }) t.Run("new version", func(t *testing.T) { @@ -114,6 +119,9 @@ SET name = $1, ) WHERE link.name = $3; `, "Bong", "Bing", "Bing") + _, err := query.Exec(db) + require.NoError(t, err) + requireLogged(t, query) }) } @@ -146,6 +154,7 @@ RETURNING link.id AS "link.id", require.Equal(t, len(links), 2) require.Equal(t, links[0].Name, "DuckDuckGo") require.Equal(t, links[1].Name, "DuckDuckGo") + requireLogged(t, stmt) } func TestUpdateWithSelect(t *testing.T) { From fb8607da29c7fb4cfe6005921ffe0e309323fb9b Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 24 May 2020 17:55:28 +0200 Subject: [PATCH 11/23] Add support for WITH statements and Common Table Expressions. --- internal/jet/clause.go | 20 +-- internal/jet/column.go | 2 +- internal/jet/func_expression.go | 5 + internal/jet/select_table.go | 32 ++--- internal/jet/serializer.go | 1 + internal/jet/sql_builder.go | 7 + internal/jet/sql_builder_test.go | 10 ++ internal/jet/statement.go | 9 +- internal/jet/with_statement.go | 78 +++++++++++ tests/mysql/with_test.go | 61 +++++++++ tests/postgres/select_test.go | 4 +- tests/postgres/with_test.go | 214 +++++++++++++++++++++++++++++++ tests/testdata | 2 +- 13 files changed, 406 insertions(+), 39 deletions(-) create mode 100644 internal/jet/with_statement.go create mode 100644 tests/mysql/with_test.go create mode 100644 tests/postgres/with_test.go diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 6091986..7b7e27b 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -13,17 +13,18 @@ type Clause interface { type ClauseWithProjections interface { Clause - projections() ProjectionList + Projections() ProjectionList } // ClauseSelect struct type ClauseSelect struct { - Distinct bool - Projections []Projection + Distinct bool + ProjectionList []Projection } -func (s *ClauseSelect) projections() ProjectionList { - return s.Projections +// Projections returns list of projections for select clause +func (s *ClauseSelect) Projections() ProjectionList { + return s.ProjectionList } // Serialize serializes clause into SQLBuilder @@ -35,11 +36,11 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o out.WriteString("DISTINCT") } - if len(s.Projections) == 0 { + if len(s.ProjectionList) == 0 { panic("jet: SELECT clause has to have at least one projection") } - out.WriteProjections(statementType, s.Projections) + out.WriteProjections(statementType, s.ProjectionList) } // ClauseFrom struct @@ -212,13 +213,14 @@ func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder, opti type ClauseSetStmtOperator struct { Operator string All bool - Selects []StatementWithProjections + Selects []SerializerStatement OrderBy ClauseOrderBy Limit ClauseLimit Offset ClauseOffset } -func (s *ClauseSetStmtOperator) projections() ProjectionList { +// Projections returns set of projections for ClauseSetStmtOperator +func (s *ClauseSetStmtOperator) Projections() ProjectionList { if len(s.Selects) > 0 { return s.Selects[0].projections() } diff --git a/internal/jet/column.go b/internal/jet/column.go index 3e4c300..2b1b930 100644 --- a/internal/jet/column.go +++ b/internal/jet/column.go @@ -105,7 +105,7 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder if c.subQuery != nil { out.WriteIdentifier(c.subQuery.Alias()) out.WriteByte('.') - out.WriteIdentifier(c.defaultAlias(), true) + out.WriteIdentifier(c.defaultAlias()) } else { if c.tableName != "" && !contains(options, ShortName) { out.WriteIdentifier(c.tableName) diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index e95bece..d4eaa57 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -145,6 +145,11 @@ func MINi(integerExpression IntegerExpression) integerWindowExpression { return newIntegerWindowFunc("MIN", integerExpression) } +// SUM is aggregate function. Returns sum of all expressions +func SUM(expression Expression) Expression { + return newWindowFunc("SUM", expression) +} + // SUMf is aggregate function. Returns sum of expression across all float expressions func SUMf(floatExpression FloatExpression) floatWindowExpression { return NewFloatWindowFunc("SUM", floatExpression) diff --git a/internal/jet/select_table.go b/internal/jet/select_table.go index 9acd8a3..52689d4 100644 --- a/internal/jet/select_table.go +++ b/internal/jet/select_table.go @@ -8,35 +8,31 @@ type SelectTable interface { } type selectTableImpl struct { - selectStmt StatementWithProjections + selectStmt SerializerStatement alias string - - projections ProjectionList } // NewSelectTable func -func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTable { - selectTable := selectTableImpl{selectStmt: selectStmt, alias: alias} - - projectionList := selectStmt.projections().fromImpl(&selectTable) - selectTable.projections = projectionList.(ProjectionList) - - return &selectTable +func NewSelectTable(selectStmt SerializerStatement, alias string) SelectTable { + selectTable := &selectTableImpl{selectStmt: selectStmt, alias: alias} + return selectTable } -func (s *selectTableImpl) Alias() string { +func (s selectTableImpl) Alias() string { return s.alias } -func (s *selectTableImpl) AllColumns() ProjectionList { - return s.projections -} - -func (s *selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if s == nil { - panic("jet: expression table is nil. ") +func (s selectTableImpl) AllColumns() ProjectionList { + statementWithProjections, ok := s.selectStmt.(HasProjections) + if !ok { + return ProjectionList{} } + projectionList := statementWithProjections.projections().fromImpl(s) + return projectionList.(ProjectionList) +} + +func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { s.selectStmt.serialize(statement, out) out.WriteString("AS") diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index 2f014cc..b8cf04a 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -29,6 +29,7 @@ const ( SetStatementType StatementType = "SET" LockStatementType StatementType = "LOCK" UnLockStatementType StatementType = "UNLOCK" + WithStatementType StatementType = "WITH" ) // Serializer interface diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 59b776f..546759b 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -201,6 +201,13 @@ func integerTypesToString(value interface{}) string { } func shouldQuoteIdentifier(identifier string) bool { + _, err := strconv.ParseInt(identifier, 10, 64) + + if err == nil { // if it is a number we should quote it + return true + } + + // check if contains non ascii characters for _, c := range identifier { if unicode.IsNumber(c) || c == '_' { continue diff --git a/internal/jet/sql_builder_test.go b/internal/jet/sql_builder_test.go index 2aad3aa..3356e6e 100644 --- a/internal/jet/sql_builder_test.go +++ b/internal/jet/sql_builder_test.go @@ -47,3 +47,13 @@ func TestFallTrough(t *testing.T) { require.Equal(t, FallTrough([]SerializeOption{SkipNewLine}), []SerializeOption(nil)) require.Equal(t, FallTrough([]SerializeOption{ShortName, SkipNewLine}), []SerializeOption{ShortName}) } + +func TestShouldQuote(t *testing.T) { + require.Equal(t, shouldQuoteIdentifier("123"), true) + require.Equal(t, shouldQuoteIdentifier("123.235"), true) + require.Equal(t, shouldQuoteIdentifier("abc123"), false) + require.Equal(t, shouldQuoteIdentifier("abc.123"), true) + require.Equal(t, shouldQuoteIdentifier("abc_123"), false) + require.Equal(t, shouldQuoteIdentifier("Abc_123"), true) + require.Equal(t, shouldQuoteIdentifier("DŽƜĐǶ"), true) +} diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 37b2077..23ae76c 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -32,13 +32,7 @@ type Statement interface { type SerializerStatement interface { Serializer Statement -} - -// StatementWithProjections interface -type StatementWithProjections interface { - Statement HasProjections - Serializer } // HasProjections interface @@ -163,7 +157,7 @@ type statementImpl struct { func (s *statementImpl) projections() ProjectionList { for _, clause := range s.Clauses { if selectClause, ok := clause.(ClauseWithProjections); ok { - return selectClause.projections() + return selectClause.Projections() } } @@ -171,7 +165,6 @@ func (s *statementImpl) projections() ProjectionList { } func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if !contains(options, NoWrap) { out.WriteString("(") out.IncreaseIdent() diff --git a/internal/jet/with_statement.go b/internal/jet/with_statement.go new file mode 100644 index 0000000..6131b35 --- /dev/null +++ b/internal/jet/with_statement.go @@ -0,0 +1,78 @@ +package jet + +// WITH function creates new with statement from list of common table expressions for specified dialect +func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statement SerializerStatement) Statement { + newWithImpl := &withImpl{ + ctes: cte, + serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ + dialect: dialect, + statementType: WithStatementType, + }, + } + newWithImpl.parent = newWithImpl + + return func(primaryStatement SerializerStatement) Statement { + newWithImpl.primaryStatement = primaryStatement + return newWithImpl + } +} + +type withImpl struct { + serializerStatementInterfaceImpl + ctes []CommonTableExpressionDefinition + primaryStatement SerializerStatement +} + +func (w withImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.NewLine() + out.WriteString("WITH") + + for i, cte := range w.ctes { + if i > 0 { + out.WriteString(",") + } + + cte.serialize(statement, out, FallTrough(options)...) + } + w.primaryStatement.serialize(statement, out, NoWrap.WithFallTrough(options)...) +} + +func (w withImpl) projections() ProjectionList { + return ProjectionList{} +} + +// CommonTableExpression contains information about a CTE. +type CommonTableExpression struct { + selectTableImpl +} + +// CTE creates new named CommonTableExpression +func CTE(name string) CommonTableExpression { + return CommonTableExpression{ + selectTableImpl: selectTableImpl{ + selectStmt: nil, + alias: name, + }, + } +} + +func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteIdentifier(c.alias) +} + +// AS returns sets definition for a CTE +func (c *CommonTableExpression) AS(statement SerializerStatement) CommonTableExpressionDefinition { + c.selectStmt = statement + return CommonTableExpressionDefinition{cte: c} +} + +// CommonTableExpressionDefinition contains implementation details of CTE +type CommonTableExpressionDefinition struct { + cte *CommonTableExpression +} + +func (c CommonTableExpressionDefinition) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteIdentifier(c.cte.alias) + out.WriteString("AS") + c.cte.selectStmt.serialize(statement, out, FallTrough(options)...) +} diff --git a/tests/mysql/with_test.go b/tests/mysql/with_test.go new file mode 100644 index 0000000..7e3a8dd --- /dev/null +++ b/tests/mysql/with_test.go @@ -0,0 +1,61 @@ +package mysql + +import ( + "github.com/go-jet/jet/internal/testutils" + . "github.com/go-jet/jet/mysql" + . "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" + "github.com/stretchr/testify/require" + "strings" + "testing" +) + +func TestWITH_SELECT(t *testing.T) { + salesRep := CTE("sales_rep") + salesRepStaffID := Staff.StaffID.From(salesRep) + salesRepFullName := StringColumn("sales_rep_full_name").From(salesRep) + customerSalesRep := CTE("customer_sales_rep") + + stmt := WITH( + salesRep.AS( + SELECT( + Staff.StaffID, + Staff.FirstName.CONCAT(Staff.LastName).AS(salesRepFullName.Name()), + ).FROM(Staff), + ), + customerSalesRep.AS( + SELECT( + Customer.FirstName.CONCAT(Customer.LastName).AS("customer_name"), + salesRepFullName, + ).FROM( + salesRep. + INNER_JOIN(Store, Store.ManagerStaffID.EQ(salesRepStaffID)). + INNER_JOIN(Customer, Customer.StoreID.EQ(Store.StoreID)), + ), + ), + )( + SELECT(customerSalesRep.AllColumns()). + FROM(customerSalesRep), + ) + + //fmt.Println(stmt.DebugSql()) + + testutils.AssertStatementSql(t, stmt, strings.Replace(` +WITH sales_rep AS ( + SELECT staff.staff_id AS "staff.staff_id", + (CONCAT(staff.first_name, staff.last_name)) AS "sales_rep_full_name" + FROM dvds.staff +),customer_sales_rep AS ( + SELECT (CONCAT(customer.first_name, customer.last_name)) AS "customer_name", + sales_rep.sales_rep_full_name AS "sales_rep_full_name" + FROM sales_rep + INNER JOIN dvds.store ON (store.manager_staff_id = sales_rep.''staff.staff_id'') + INNER JOIN dvds.customer ON (customer.store_id = store.store_id) +) +SELECT customer_sales_rep.customer_name AS "customer_name", + customer_sales_rep.sales_rep_full_name AS "sales_rep_full_name" +FROM customer_sales_rep; +`, "''", "`", -1)) + + _, err := stmt.Exec(db) + require.NoError(t, err) +} diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 35f9803..b8e759d 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1088,7 +1088,7 @@ SELECT customer.customer_id AS "customer.customer_id", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active", - customer_payment_sum."amount_sum" AS "CustomerWithAmounts.AmountSum" + customer_payment_sum.amount_sum AS "CustomerWithAmounts.AmountSum" FROM dvds.customer INNER JOIN ( SELECT payment.customer_id AS "payment.customer_id", @@ -1096,7 +1096,7 @@ FROM dvds.customer FROM dvds.payment GROUP BY payment.customer_id ) AS customer_payment_sum ON (customer.customer_id = customer_payment_sum."payment.customer_id") -ORDER BY customer_payment_sum."amount_sum" ASC; +ORDER BY customer_payment_sum.amount_sum ASC; ` customersPayments := Payment. diff --git a/tests/postgres/with_test.go b/tests/postgres/with_test.go new file mode 100644 index 0000000..3a21c63 --- /dev/null +++ b/tests/postgres/with_test.go @@ -0,0 +1,214 @@ +package postgres + +import ( + "github.com/go-jet/jet/internal/testutils" + . "github.com/go-jet/jet/postgres" + "github.com/go-jet/jet/tests/.gentestdata/jetdb/northwind/model" + . "github.com/go-jet/jet/tests/.gentestdata/jetdb/northwind/table" + "github.com/stretchr/testify/require" + "testing" +) + +func TestWithRegionalSales(t *testing.T) { + regionalSales := CTE("regional_sales") + topRegion := CTE("top_region") + + regionalSalesTotalSales := IntegerColumn("total_sales").From(regionalSales) + regionalSalesShipRegion := Orders.ShipRegion.From(regionalSales) + topRegionShipRegion := regionalSalesShipRegion.From(topRegion) + + stmt := WITH( + regionalSales.AS( + SELECT( + Orders.ShipRegion, + SUM(OrderDetails.Quantity).AS(regionalSalesTotalSales.Name()), + ). + FROM(Orders.INNER_JOIN(OrderDetails, OrderDetails.OrderID.EQ(Orders.OrderID))). + GROUP_BY(Orders.ShipRegion), + ), + topRegion.AS( + SELECT(regionalSalesShipRegion). + FROM(regionalSales). + WHERE(regionalSalesTotalSales.GT( + IntExp( + SELECT(SUM(regionalSalesTotalSales)). + FROM(regionalSales), + ).DIV(Int(50)), + ), + ), + ), + )( + SELECT( + Orders.ShipRegion, + OrderDetails.ProductID, + COUNT(STAR).AS("product_units"), + SUM(OrderDetails.Quantity).AS("product_sales"), + ). + FROM(Orders.INNER_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID))). + WHERE(Orders.ShipRegion.IN( + topRegion.SELECT(topRegionShipRegion)), + ). + GROUP_BY(Orders.ShipRegion, OrderDetails.ProductID). + ORDER_BY(SUM(OrderDetails.Quantity).DESC()), + ) + + //fmt.Println(stmt.DebugSql()) + + testutils.AssertDebugStatementSql(t, stmt, ` +WITH regional_sales AS ( + SELECT orders.ship_region AS "orders.ship_region", + SUM(order_details.quantity) AS "total_sales" + FROM northwind.orders + INNER JOIN northwind.order_details ON (order_details.order_id = orders.order_id) + GROUP BY orders.ship_region +),top_region AS ( + SELECT regional_sales."orders.ship_region" AS "orders.ship_region" + FROM regional_sales + WHERE regional_sales.total_sales > (( + SELECT SUM(regional_sales.total_sales) + FROM regional_sales + ) / 50) +) +SELECT orders.ship_region AS "orders.ship_region", + order_details.product_id AS "order_details.product_id", + COUNT(*) AS "product_units", + SUM(order_details.quantity) AS "product_sales" +FROM northwind.orders + INNER JOIN northwind.order_details ON (orders.order_id = order_details.order_id) +WHERE orders.ship_region IN (( + SELECT top_region."orders.ship_region" AS "orders.ship_region" + FROM top_region + )) +GROUP BY orders.ship_region, order_details.product_id +ORDER BY SUM(order_details.quantity) DESC; +`) + + _, err := stmt.Exec(db) + require.NoError(t, err) +} + +func TestWithStatementDeleteAndInsert(t *testing.T) { + removeDiscontinuedOrders := CTE("remove_discontinued_orders") + updateDiscontinuedPrice := CTE("update_discontinued_price") + logDiscontinuedProducts := CTE("log_discontinued") + + discontinuedProductID := OrderDetails.ProductID.From(removeDiscontinuedOrders) + + stmt := WITH( + removeDiscontinuedOrders.AS( + OrderDetails.DELETE(). + WHERE(OrderDetails.ProductID.IN( + SELECT(Products.ProductID). + FROM(Products). + WHERE(Products.Discontinued.EQ(Int(1)))), + ).RETURNING(OrderDetails.ProductID), + ), + updateDiscontinuedPrice.AS( + Products.UPDATE(). + SET( + Products.UnitPrice.SET(Float(0.0)), + ). + WHERE(Products.ProductID.IN(removeDiscontinuedOrders.SELECT(discontinuedProductID))). + RETURNING(Products.AllColumns), + ), + logDiscontinuedProducts.AS( + ProductLogs.INSERT(ProductLogs.AllColumns). + QUERY(SELECT(updateDiscontinuedPrice.AllColumns()).FROM(updateDiscontinuedPrice)). + RETURNING( + ProductLogs.ProductID, + ProductLogs.ProductName, + ProductLogs.SupplierID, + ProductLogs.CategoryID, + ProductLogs.QuantityPerUnit, + ProductLogs.UnitPrice, + ProductLogs.UnitsInStock, + ProductLogs.UnitsOnOrder, + ProductLogs.ReorderLevel, + ProductLogs.Discontinued, + ), + ), + )( + SELECT(logDiscontinuedProducts.AllColumns()). + FROM(logDiscontinuedProducts), + ) + + require.Equal(t, len(removeDiscontinuedOrders.AllColumns()), 1) + require.Equal(t, len(updateDiscontinuedPrice.AllColumns()[0].(ProjectionList)), 10) + require.Equal(t, len(logDiscontinuedProducts.AllColumns()), 10) + + //fmt.Println(stmt.Sql()) + + testutils.AssertStatementSql(t, stmt, ` +WITH remove_discontinued_orders AS ( + DELETE FROM northwind.order_details + WHERE order_details.product_id IN (( + SELECT products.product_id AS "products.product_id" + FROM northwind.products + WHERE products.discontinued = $1 + )) + RETURNING order_details.product_id AS "order_details.product_id" +),update_discontinued_price AS ( + UPDATE northwind.products + SET unit_price = $2 + WHERE products.product_id IN (( + SELECT remove_discontinued_orders."order_details.product_id" AS "order_details.product_id" + FROM remove_discontinued_orders + )) + RETURNING products.product_id AS "products.product_id", + products.product_name AS "products.product_name", + products.supplier_id AS "products.supplier_id", + products.category_id AS "products.category_id", + products.quantity_per_unit AS "products.quantity_per_unit", + products.unit_price AS "products.unit_price", + products.units_in_stock AS "products.units_in_stock", + products.units_on_order AS "products.units_on_order", + products.reorder_level AS "products.reorder_level", + products.discontinued AS "products.discontinued" +),log_discontinued AS ( + INSERT INTO northwind.product_logs (product_id, product_name, supplier_id, category_id, quantity_per_unit, unit_price, units_in_stock, units_on_order, reorder_level, discontinued) ( + SELECT update_discontinued_price."products.product_id" AS "products.product_id", + update_discontinued_price."products.product_name" AS "products.product_name", + update_discontinued_price."products.supplier_id" AS "products.supplier_id", + update_discontinued_price."products.category_id" AS "products.category_id", + update_discontinued_price."products.quantity_per_unit" AS "products.quantity_per_unit", + update_discontinued_price."products.unit_price" AS "products.unit_price", + update_discontinued_price."products.units_in_stock" AS "products.units_in_stock", + update_discontinued_price."products.units_on_order" AS "products.units_on_order", + update_discontinued_price."products.reorder_level" AS "products.reorder_level", + update_discontinued_price."products.discontinued" AS "products.discontinued" + FROM update_discontinued_price + ) + RETURNING product_logs.product_id AS "product_logs.product_id", + product_logs.product_name AS "product_logs.product_name", + product_logs.supplier_id AS "product_logs.supplier_id", + product_logs.category_id AS "product_logs.category_id", + product_logs.quantity_per_unit AS "product_logs.quantity_per_unit", + product_logs.unit_price AS "product_logs.unit_price", + product_logs.units_in_stock AS "product_logs.units_in_stock", + product_logs.units_on_order AS "product_logs.units_on_order", + product_logs.reorder_level AS "product_logs.reorder_level", + product_logs.discontinued AS "product_logs.discontinued" +) +SELECT log_discontinued."product_logs.product_id" AS "product_logs.product_id", + log_discontinued."product_logs.product_name" AS "product_logs.product_name", + log_discontinued."product_logs.supplier_id" AS "product_logs.supplier_id", + log_discontinued."product_logs.category_id" AS "product_logs.category_id", + log_discontinued."product_logs.quantity_per_unit" AS "product_logs.quantity_per_unit", + log_discontinued."product_logs.unit_price" AS "product_logs.unit_price", + log_discontinued."product_logs.units_in_stock" AS "product_logs.units_in_stock", + log_discontinued."product_logs.units_on_order" AS "product_logs.units_on_order", + log_discontinued."product_logs.reorder_level" AS "product_logs.reorder_level", + log_discontinued."product_logs.discontinued" AS "product_logs.discontinued" +FROM log_discontinued; +`, int64(1), 0.0) + + var resp []model.ProductLogs + + tx, err := db.Begin() + require.NoError(t, err) + defer tx.Rollback() + + err = stmt.Query(tx, &resp) + require.NoError(t, err) + +} diff --git a/tests/testdata b/tests/testdata index 1745be3..ed53a50 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 1745be34a649c0f37d0d31d7c0352a1248ace2dc +Subproject commit ed53a505eb738d1be457877eee251f9ba0418df1 From 8aa894730c170673f095e1000a05ea6a8d62b1b4 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 24 May 2020 17:56:17 +0200 Subject: [PATCH 12/23] [PostgreSQL] Add support for WITH statements and Common Table Expressions. --- postgres/clause.go | 11 ++++++++--- postgres/delete_statement.go | 4 ++-- postgres/functions.go | 3 +++ postgres/insert_statement.go | 4 ++-- postgres/select_statement.go | 2 +- postgres/select_table.go | 2 +- postgres/set_statement.go | 18 +++++++++--------- postgres/table.go | 12 ++++++------ postgres/update_statement.go | 4 ++-- postgres/with_statement.go | 26 ++++++++++++++++++++++++++ 10 files changed, 60 insertions(+), 26 deletions(-) create mode 100644 postgres/with_statement.go diff --git a/postgres/clause.go b/postgres/clause.go index ad61b9a..3812667 100644 --- a/postgres/clause.go +++ b/postgres/clause.go @@ -5,18 +5,23 @@ import ( ) type clauseReturning struct { - Projections []jet.Projection + ProjectionList []jet.Projection } func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { - if len(r.Projections) == 0 { + if len(r.ProjectionList) == 0 { return } out.NewLine() out.WriteString("RETURNING") out.IncreaseIdent() - out.WriteProjections(statementType, r.Projections) + out.WriteProjections(statementType, r.ProjectionList) + out.DecreaseIdent() +} + +func (r clauseReturning) Projections() ProjectionList { + return r.ProjectionList } // ========================================== // diff --git a/postgres/delete_statement.go b/postgres/delete_statement.go index b39b7c1..ff62710 100644 --- a/postgres/delete_statement.go +++ b/postgres/delete_statement.go @@ -4,7 +4,7 @@ import "github.com/go-jet/jet/internal/jet" // DeleteStatement is interface for PostgreSQL DELETE statement type DeleteStatement interface { - Statement + jet.SerializerStatement WHERE(expression BoolExpression) DeleteStatement @@ -37,6 +37,6 @@ func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement { } func (d *deleteStatementImpl) RETURNING(projections ...jet.Projection) DeleteStatement { - d.Returning.Projections = projections + d.Returning.ProjectionList = projections return d } diff --git a/postgres/functions.go b/postgres/functions.go index ddd01db..b97d25a 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -87,6 +87,9 @@ var MINf = jet.MINf // MINi is aggregate function. Returns minimum value of int expression across all input values var MINi = jet.MINi +// SUM is aggregate function. Returns sum of all expressions +var SUM = jet.SUM + // SUMf is aggregate function. Returns sum of expression across all float expressions var SUMf = jet.SUMf diff --git a/postgres/insert_statement.go b/postgres/insert_statement.go index e4c72f5..da13370 100644 --- a/postgres/insert_statement.go +++ b/postgres/insert_statement.go @@ -4,7 +4,7 @@ import "github.com/go-jet/jet/internal/jet" // InsertStatement is interface for SQL INSERT statements type InsertStatement interface { - Statement + jet.SerializerStatement // Insert row of values VALUES(value interface{}, values ...interface{}) InsertStatement @@ -55,7 +55,7 @@ func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement { } func (i *insertStatementImpl) RETURNING(projections ...jet.Projection) InsertStatement { - i.Returning.Projections = projections + i.Returning.ProjectionList = projections return i } diff --git a/postgres/select_statement.go b/postgres/select_statement.go index c001a57..3e49534 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -75,7 +75,7 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy, &newSelect.Limit, &newSelect.Offset, &newSelect.For) - newSelect.Select.Projections = projections + newSelect.Select.ProjectionList = projections newSelect.From.Table = table newSelect.Limit.Count = -1 newSelect.Offset.Count = -1 diff --git a/postgres/select_table.go b/postgres/select_table.go index 8dea4cc..fe96bbe 100644 --- a/postgres/select_table.go +++ b/postgres/select_table.go @@ -13,7 +13,7 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.StatementWithProjections, alias string) SelectTable { +func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable { subQuery := &selectTableImpl{ SelectTable: jet.NewSelectTable(selectStmt, alias), } diff --git a/postgres/set_statement.go b/postgres/set_statement.go index cc6d1fe..4c83ecf 100644 --- a/postgres/set_statement.go +++ b/postgres/set_statement.go @@ -4,37 +4,37 @@ import "github.com/go-jet/jet/internal/jet" // UNION effectively appends the result of sub-queries(select statements) into single query. // It eliminates duplicate rows from its result. -func UNION(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement { +func UNION(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement { return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...)) } // UNION_ALL effectively appends the result of sub-queries(select statements) into single query. // It does not eliminates duplicate rows from its result. -func UNION_ALL(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement { +func UNION_ALL(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement { return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...)) } // INTERSECT returns all rows that are in query results. // It eliminates duplicate rows from its result. -func INTERSECT(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement { +func INTERSECT(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement { return newSetStatementImpl(intersect, false, toSelectList(lhs, rhs, selects...)) } // INTERSECT_ALL returns all rows that are in query results. // It does not eliminates duplicate rows from its result. -func INTERSECT_ALL(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement { +func INTERSECT_ALL(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement { return newSetStatementImpl(intersect, true, toSelectList(lhs, rhs, selects...)) } // EXCEPT returns all rows that are in the result of query lhs but not in the result of query rhs. // It eliminates duplicate rows from its result. -func EXCEPT(lhs, rhs jet.StatementWithProjections) setStatement { +func EXCEPT(lhs, rhs jet.SerializerStatement) setStatement { return newSetStatementImpl(except, false, toSelectList(lhs, rhs)) } // EXCEPT_ALL returns all rows that are in the result of query lhs but not in the result of query rhs. // It does not eliminates duplicate rows from its result. -func EXCEPT_ALL(lhs, rhs jet.StatementWithProjections) setStatement { +func EXCEPT_ALL(lhs, rhs jet.SerializerStatement) setStatement { return newSetStatementImpl(except, true, toSelectList(lhs, rhs)) } @@ -98,7 +98,7 @@ type setStatementImpl struct { setOperator jet.ClauseSetStmtOperator } -func newSetStatementImpl(operator string, all bool, selects []jet.StatementWithProjections) setStatement { +func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStatement) setStatement { newSetStatement := &setStatementImpl{} newSetStatement.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SetStatementType, newSetStatement, &newSetStatement.setOperator) @@ -139,6 +139,6 @@ const ( except = "EXCEPT" ) -func toSelectList(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) []jet.StatementWithProjections { - return append([]jet.StatementWithProjections{lhs, rhs}, selects...) +func toSelectList(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) []jet.SerializerStatement { + return append([]jet.SerializerStatement{lhs, rhs}, selects...) } diff --git a/postgres/table.go b/postgres/table.go index c82a8f7..e928004 100644 --- a/postgres/table.go +++ b/postgres/table.go @@ -54,30 +54,30 @@ type readableTableInterfaceImpl struct { } // Generates a select query on the current tableName. -func (r *readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { +func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { return newSelectStatement(r.parent, append([]Projection{projection1}, projections...)) } // Creates a inner join tableName Expression using onCondition. -func (r *readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { +func (r readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { return newJoinTable(r.parent, table, jet.InnerJoin, onCondition) } // Creates a left join tableName Expression using onCondition. -func (r *readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { +func (r readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { return newJoinTable(r.parent, table, jet.LeftJoin, onCondition) } // Creates a right join tableName Expression using onCondition. -func (r *readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { +func (r readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { return newJoinTable(r.parent, table, jet.RightJoin, onCondition) } -func (r *readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { +func (r readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { return newJoinTable(r.parent, table, jet.FullJoin, onCondition) } -func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) ReadableTable { +func (r readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) ReadableTable { return newJoinTable(r.parent, table, jet.CrossJoin, nil) } diff --git a/postgres/update_statement.go b/postgres/update_statement.go index 9c56012..29fd9c8 100644 --- a/postgres/update_statement.go +++ b/postgres/update_statement.go @@ -6,7 +6,7 @@ import ( // UpdateStatement is interface of SQL UPDATE statement type UpdateStatement interface { - Statement + jet.SerializerStatement SET(value interface{}, values ...interface{}) UpdateStatement MODEL(data interface{}) UpdateStatement @@ -67,7 +67,7 @@ func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement { } func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateStatement { - u.Returning.Projections = projections + u.Returning.ProjectionList = projections return u } diff --git a/postgres/with_statement.go b/postgres/with_statement.go new file mode 100644 index 0000000..caa7100 --- /dev/null +++ b/postgres/with_statement.go @@ -0,0 +1,26 @@ +package postgres + +import "github.com/go-jet/jet/internal/jet" + +// CommonTableExpression contains information about a CTE. +type CommonTableExpression struct { + readableTableInterfaceImpl + jet.CommonTableExpression +} + +// WITH function creates new WITH statement from list of common table expressions +func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.SerializerStatement) Statement { + return jet.WITH(Dialect, cte...) +} + +// CTE creates new named CommonTableExpression +func CTE(name string) CommonTableExpression { + cte := CommonTableExpression{ + readableTableInterfaceImpl: readableTableInterfaceImpl{}, + CommonTableExpression: jet.CTE(name), + } + + cte.parent = &cte + + return cte +} From f5fae577d7fd9a0fb222a8644b727f36dd45dac7 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 24 May 2020 17:56:35 +0200 Subject: [PATCH 13/23] [MySQL] Add support for WITH statements and Common Table Expressions. --- mysql/functions.go | 3 +++ mysql/select_statement.go | 2 +- mysql/select_table.go | 2 +- mysql/set_statement.go | 10 +++++----- mysql/table.go | 12 ++++++------ mysql/with_statement.go | 26 ++++++++++++++++++++++++++ 6 files changed, 42 insertions(+), 13 deletions(-) create mode 100644 mysql/with_statement.go diff --git a/mysql/functions.go b/mysql/functions.go index 17702b7..1ee5a5d 100644 --- a/mysql/functions.go +++ b/mysql/functions.go @@ -85,6 +85,9 @@ var MINi = jet.MINi // MINf is aggregate function. Returns minimum value of float expression across all input values var MINf = jet.MINf +// SUM is aggregate function. Returns sum of all expressions +var SUM = jet.SUM + // SUMi is aggregate function. Returns sum of integer expression. var SUMi = jet.SUMi diff --git a/mysql/select_statement.go b/mysql/select_statement.go index 720ae57..4a7f275 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -69,7 +69,7 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy, &newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock) - newSelect.Select.Projections = projections + newSelect.Select.ProjectionList = projections newSelect.From.Table = table newSelect.Limit.Count = -1 newSelect.Offset.Count = -1 diff --git a/mysql/select_table.go b/mysql/select_table.go index 1ac02a8..9d45ba3 100644 --- a/mysql/select_table.go +++ b/mysql/select_table.go @@ -13,7 +13,7 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.StatementWithProjections, alias string) SelectTable { +func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable { subQuery := &selectTableImpl{ SelectTable: jet.NewSelectTable(selectStmt, alias), } diff --git a/mysql/set_statement.go b/mysql/set_statement.go index ba14ece..596741d 100644 --- a/mysql/set_statement.go +++ b/mysql/set_statement.go @@ -4,13 +4,13 @@ import "github.com/go-jet/jet/internal/jet" // UNION effectively appends the result of sub-queries(select statements) into single query. // It eliminates duplicate rows from its result. -func UNION(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement { +func UNION(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement { return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...)) } // UNION_ALL effectively appends the result of sub-queries(select statements) into single query. // It does not eliminates duplicate rows from its result. -func UNION_ALL(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement { +func UNION_ALL(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement { return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...)) } @@ -54,7 +54,7 @@ type setStatementImpl struct { setOperator jet.ClauseSetStmtOperator } -func newSetStatementImpl(operator string, all bool, selects []jet.StatementWithProjections) setStatement { +func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStatement) setStatement { newSetStatement := &setStatementImpl{} newSetStatement.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SetStatementType, newSetStatement, &newSetStatement.setOperator) @@ -93,6 +93,6 @@ const ( union = "UNION" ) -func toSelectList(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) []jet.StatementWithProjections { - return append([]jet.StatementWithProjections{lhs, rhs}, selects...) +func toSelectList(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) []jet.SerializerStatement { + return append([]jet.SerializerStatement{lhs, rhs}, selects...) } diff --git a/mysql/table.go b/mysql/table.go index 8287159..ee798e2 100644 --- a/mysql/table.go +++ b/mysql/table.go @@ -49,30 +49,30 @@ type readableTableInterfaceImpl struct { } // Generates a select query on the current tableName. -func (r *readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { +func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { return newSelectStatement(r.parent, append([]Projection{projection1}, projections...)) } // Creates a inner join tableName Expression using onCondition. -func (r *readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { +func (r readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { return newJoinTable(r.parent, table, jet.InnerJoin, onCondition) } // Creates a left join tableName Expression using onCondition. -func (r *readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { +func (r readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { return newJoinTable(r.parent, table, jet.LeftJoin, onCondition) } // Creates a right join tableName Expression using onCondition. -func (r *readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { +func (r readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { return newJoinTable(r.parent, table, jet.RightJoin, onCondition) } -func (r *readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { +func (r readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { return newJoinTable(r.parent, table, jet.FullJoin, onCondition) } -func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectUpdateTable { +func (r readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectUpdateTable { return newJoinTable(r.parent, table, jet.CrossJoin, nil) } diff --git a/mysql/with_statement.go b/mysql/with_statement.go new file mode 100644 index 0000000..5991287 --- /dev/null +++ b/mysql/with_statement.go @@ -0,0 +1,26 @@ +package mysql + +import "github.com/go-jet/jet/internal/jet" + +// CommonTableExpression contains information about a CTE. +type CommonTableExpression struct { + readableTableInterfaceImpl + jet.CommonTableExpression +} + +// WITH function creates new WITH statement from list of common table expressions +func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.SerializerStatement) Statement { + return jet.WITH(Dialect, cte...) +} + +// CTE creates new named CommonTableExpression +func CTE(name string) CommonTableExpression { + cte := CommonTableExpression{ + readableTableInterfaceImpl: readableTableInterfaceImpl{}, + CommonTableExpression: jet.CTE(name), + } + + cte.parent = &cte + + return cte +} From ac0fd9a6f62058ec6e346d957747c483cda29110 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 24 May 2020 18:09:16 +0200 Subject: [PATCH 14/23] Fix unit tests. --- internal/jet/column_types_test.go | 42 +++++++++++++++---------------- postgres/columns_test.go | 6 ++--- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/internal/jet/column_types_test.go b/internal/jet/column_types_test.go index 9cff309..059d722 100644 --- a/internal/jet/column_types_test.go +++ b/internal/jet/column_types_test.go @@ -22,9 +22,9 @@ func TestNewBoolColumn(t *testing.T) { func TestNewIntColumn(t *testing.T) { intColumn := IntegerColumn("col_int").From(subQuery) - assertClauseSerialize(t, intColumn, `sub_query."col_int"`) - assertClauseSerialize(t, intColumn.EQ(Int(12)), `(sub_query."col_int" = $1)`, int64(12)) - assertProjectionSerialize(t, intColumn, `sub_query."col_int" AS "col_int"`) + assertClauseSerialize(t, intColumn, `sub_query.col_int`) + assertClauseSerialize(t, intColumn.EQ(Int(12)), `(sub_query.col_int = $1)`, int64(12)) + assertProjectionSerialize(t, intColumn, `sub_query.col_int AS "col_int"`) intColumn2 := table1ColInt.From(subQuery) assertClauseSerialize(t, intColumn2, `sub_query."table1.col_int"`) @@ -35,9 +35,9 @@ func TestNewIntColumn(t *testing.T) { func TestNewFloatColumnColumn(t *testing.T) { floatColumn := FloatColumn("col_float").From(subQuery) - assertClauseSerialize(t, floatColumn, `sub_query."col_float"`) - assertClauseSerialize(t, floatColumn.EQ(Float(1.11)), `(sub_query."col_float" = $1)`, float64(1.11)) - assertProjectionSerialize(t, floatColumn, `sub_query."col_float" AS "col_float"`) + assertClauseSerialize(t, floatColumn, `sub_query.col_float`) + assertClauseSerialize(t, floatColumn.EQ(Float(1.11)), `(sub_query.col_float = $1)`, float64(1.11)) + assertProjectionSerialize(t, floatColumn, `sub_query.col_float AS "col_float"`) floatColumn2 := table1ColFloat.From(subQuery) assertClauseSerialize(t, floatColumn2, `sub_query."table1.col_float"`) @@ -47,10 +47,10 @@ func TestNewFloatColumnColumn(t *testing.T) { func TestNewDateColumnColumn(t *testing.T) { dateColumn := DateColumn("col_date").From(subQuery) - assertClauseSerialize(t, dateColumn, `sub_query."col_date"`) + assertClauseSerialize(t, dateColumn, `sub_query.col_date`) assertClauseSerialize(t, dateColumn.EQ(Date(2002, 2, 3)), - `(sub_query."col_date" = $1)`, "2002-02-03") - assertProjectionSerialize(t, dateColumn, `sub_query."col_date" AS "col_date"`) + `(sub_query.col_date = $1)`, "2002-02-03") + assertProjectionSerialize(t, dateColumn, `sub_query.col_date AS "col_date"`) dateColumn2 := table1ColDate.From(subQuery) assertClauseSerialize(t, dateColumn2, `sub_query."table1.col_date"`) @@ -61,10 +61,10 @@ func TestNewDateColumnColumn(t *testing.T) { func TestNewTimeColumnColumn(t *testing.T) { timeColumn := TimeColumn("col_time").From(subQuery) - assertClauseSerialize(t, timeColumn, `sub_query."col_time"`) + assertClauseSerialize(t, timeColumn, `sub_query.col_time`) assertClauseSerialize(t, timeColumn.EQ(Time(1, 1, 1, 1)), - `(sub_query."col_time" = $1)`, "01:01:01.000000001") - assertProjectionSerialize(t, timeColumn, `sub_query."col_time" AS "col_time"`) + `(sub_query.col_time = $1)`, "01:01:01.000000001") + assertProjectionSerialize(t, timeColumn, `sub_query.col_time AS "col_time"`) timeColumn2 := table1ColTime.From(subQuery) assertClauseSerialize(t, timeColumn2, `sub_query."table1.col_time"`) @@ -75,10 +75,10 @@ func TestNewTimeColumnColumn(t *testing.T) { func TestNewTimezColumnColumn(t *testing.T) { timezColumn := TimezColumn("col_timez").From(subQuery) - assertClauseSerialize(t, timezColumn, `sub_query."col_timez"`) + assertClauseSerialize(t, timezColumn, `sub_query.col_timez`) assertClauseSerialize(t, timezColumn.EQ(Timez(1, 1, 1, 1, "UTC")), - `(sub_query."col_timez" = $1)`, "01:01:01.000000001 UTC") - assertProjectionSerialize(t, timezColumn, `sub_query."col_timez" AS "col_timez"`) + `(sub_query.col_timez = $1)`, "01:01:01.000000001 UTC") + assertProjectionSerialize(t, timezColumn, `sub_query.col_timez AS "col_timez"`) timezColumn2 := table1ColTimez.From(subQuery) assertClauseSerialize(t, timezColumn2, `sub_query."table1.col_timez"`) @@ -89,10 +89,10 @@ func TestNewTimezColumnColumn(t *testing.T) { func TestNewTimestampColumnColumn(t *testing.T) { timestampColumn := TimestampColumn("col_timestamp").From(subQuery) - assertClauseSerialize(t, timestampColumn, `sub_query."col_timestamp"`) + assertClauseSerialize(t, timestampColumn, `sub_query.col_timestamp`) assertClauseSerialize(t, timestampColumn.EQ(Timestamp(1, 1, 1, 1, 1, 1)), - `(sub_query."col_timestamp" = $1)`, "0001-01-01 01:01:01") - assertProjectionSerialize(t, timestampColumn, `sub_query."col_timestamp" AS "col_timestamp"`) + `(sub_query.col_timestamp = $1)`, "0001-01-01 01:01:01") + assertProjectionSerialize(t, timestampColumn, `sub_query.col_timestamp AS "col_timestamp"`) timestampColumn2 := table1ColTimestamp.From(subQuery) assertClauseSerialize(t, timestampColumn2, `sub_query."table1.col_timestamp"`) @@ -103,10 +103,10 @@ func TestNewTimestampColumnColumn(t *testing.T) { func TestNewTimestampzColumnColumn(t *testing.T) { timestampzColumn := TimestampzColumn("col_timestampz").From(subQuery) - assertClauseSerialize(t, timestampzColumn, `sub_query."col_timestampz"`) + assertClauseSerialize(t, timestampzColumn, `sub_query.col_timestampz`) assertClauseSerialize(t, timestampzColumn.EQ(Timestampz(1, 1, 1, 1, 1, 1, 0, "UTC")), - `(sub_query."col_timestampz" = $1)`, "0001-01-01 01:01:01 UTC") - assertProjectionSerialize(t, timestampzColumn, `sub_query."col_timestampz" AS "col_timestampz"`) + `(sub_query.col_timestampz = $1)`, "0001-01-01 01:01:01 UTC") + assertProjectionSerialize(t, timestampzColumn, `sub_query.col_timestampz AS "col_timestampz"`) timestampzColumn2 := table1ColTimestampz.From(subQuery) assertClauseSerialize(t, timestampzColumn2, `sub_query."table1.col_timestampz"`) diff --git a/postgres/columns_test.go b/postgres/columns_test.go index b00c4c6..f53303b 100644 --- a/postgres/columns_test.go +++ b/postgres/columns_test.go @@ -8,10 +8,10 @@ func TestNewIntervalColumn(t *testing.T) { subQuery := SELECT(Int(1)).AsTable("sub_query") subQueryIntervalColumn := IntervalColumn("col_interval").From(subQuery) - assertSerialize(t, subQueryIntervalColumn, `sub_query."col_interval"`) + assertSerialize(t, subQueryIntervalColumn, `sub_query.col_interval`) assertSerialize(t, subQueryIntervalColumn.EQ(INTERVAL(2, HOUR, 10, MINUTE)), - `(sub_query."col_interval" = INTERVAL '2 HOUR 10 MINUTE')`) - assertProjectionSerialize(t, subQueryIntervalColumn, `sub_query."col_interval" AS "col_interval"`) + `(sub_query.col_interval = INTERVAL '2 HOUR 10 MINUTE')`) + assertProjectionSerialize(t, subQueryIntervalColumn, `sub_query.col_interval AS "col_interval"`) subQueryIntervalColumn2 := table1ColInterval.From(subQuery) assertSerialize(t, subQueryIntervalColumn2, `sub_query."table1.col_interval"`) From 196989ab68ad14cdb99a3baceb0a77c6280f213b Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 24 May 2020 18:20:04 +0200 Subject: [PATCH 15/23] Update README.md. --- README.md | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 932fff5..f6840ab 100644 --- a/README.md +++ b/README.md @@ -35,17 +35,19 @@ https://medium.com/@go.jet/jet-5f3667efa0cc ## Features 1) Auto-generated type-safe SQL Builder - PostgreSQL: - * SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, WINDOW, sub-queries)` - * INSERT `(VALUES, query, RETURNING)`, - * UPDATE `(SET, WHERE, RETURNING)`, - * DELETE `(WHERE, RETURNING)`, - * LOCK `(IN, NOWAIT)` + * [SELECT](https://github.com/go-jet/jet/wiki/SELECT) `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, WINDOW, sub-queries)` + * [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, query, RETURNING)`, + * [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, WHERE, RETURNING)`, + * [DELETE](https://github.com/go-jet/jet/wiki/DELETE) `(WHERE, RETURNING)`, + * [LOCK](https://github.com/go-jet/jet/wiki/LOCK) `(IN, NOWAIT)` + * [WITH](https://github.com/go-jet/jet/wiki/WITH) - MySQL and MariaDB: - * SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, LOCK_IN_SHARE_MODE, WINDOW, sub-queries)` - * INSERT `(VALUES, query)`, - * UPDATE `(SET, WHERE)`, - * DELETE `(WHERE, ORDER_BY, LIMIT)`, - * LOCK `(READ, WRITE)` + * [SELECT](https://github.com/go-jet/jet/wiki/SELECT) `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, LOCK_IN_SHARE_MODE, WINDOW, sub-queries)` + * [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, query)`, + * [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, WHERE)`, + * [DELETE](https://github.com/go-jet/jet/wiki/DELETE) `(WHERE, ORDER_BY, LIMIT)`, + * [LOCK](https://github.com/go-jet/jet/wiki/LOCK) `(READ, WRITE)` + * [WITH](https://github.com/go-jet/jet/wiki/WITH) 2) Auto-generated Data Model types - Go types mapped to database type (table, view or enum), used to store result of database queries. Can be combined to create desired query result destination. 3) Query execution with result mapping to arbitrary destination structure. From 0183117b723b07314dbc937e0f01a9eebd2c3117 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 31 May 2020 10:42:55 +0200 Subject: [PATCH 16/23] Update quick start example. --- .../.gen/jetdb/dvds/enum/mpaa_rating.go | 1 - .../.gen/jetdb/dvds/model/actor.go | 1 - .../.gen/jetdb/dvds/model/category.go | 1 - .../quick-start/.gen/jetdb/dvds/model/film.go | 1 - .../.gen/jetdb/dvds/model/film_actor.go | 1 - .../.gen/jetdb/dvds/model/film_category.go | 1 - .../.gen/jetdb/dvds/model/language.go | 1 - .../.gen/jetdb/dvds/model/mpaa_rating.go | 1 - .../.gen/jetdb/dvds/table/actor.go | 30 ++++++++---- .../.gen/jetdb/dvds/table/category.go | 30 ++++++++---- .../quick-start/.gen/jetdb/dvds/table/film.go | 30 ++++++++---- .../.gen/jetdb/dvds/table/film_actor.go | 30 ++++++++---- .../.gen/jetdb/dvds/table/film_category.go | 30 ++++++++---- .../.gen/jetdb/dvds/table/language.go | 30 ++++++++---- .../.gen/jetdb/dvds/view/actor_info.go | 30 ++++++++---- .../.gen/jetdb/dvds/view/customer_list.go | 48 ++++++++++++------- examples/quick-start/README.md | 8 ++-- examples/quick-start/quick-start.go | 6 +-- 18 files changed, 184 insertions(+), 96 deletions(-) diff --git a/examples/quick-start/.gen/jetdb/dvds/enum/mpaa_rating.go b/examples/quick-start/.gen/jetdb/dvds/enum/mpaa_rating.go index e1dd269..60cb23b 100644 --- a/examples/quick-start/.gen/jetdb/dvds/enum/mpaa_rating.go +++ b/examples/quick-start/.gen/jetdb/dvds/enum/mpaa_rating.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated diff --git a/examples/quick-start/.gen/jetdb/dvds/model/actor.go b/examples/quick-start/.gen/jetdb/dvds/model/actor.go index 56222ed..922dff5 100644 --- a/examples/quick-start/.gen/jetdb/dvds/model/actor.go +++ b/examples/quick-start/.gen/jetdb/dvds/model/actor.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated diff --git a/examples/quick-start/.gen/jetdb/dvds/model/category.go b/examples/quick-start/.gen/jetdb/dvds/model/category.go index 354d71b..a447fba 100644 --- a/examples/quick-start/.gen/jetdb/dvds/model/category.go +++ b/examples/quick-start/.gen/jetdb/dvds/model/category.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated diff --git a/examples/quick-start/.gen/jetdb/dvds/model/film.go b/examples/quick-start/.gen/jetdb/dvds/model/film.go index 0ccaa83..04a1ae5 100644 --- a/examples/quick-start/.gen/jetdb/dvds/model/film.go +++ b/examples/quick-start/.gen/jetdb/dvds/model/film.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated diff --git a/examples/quick-start/.gen/jetdb/dvds/model/film_actor.go b/examples/quick-start/.gen/jetdb/dvds/model/film_actor.go index 7e0aa87..f0fe9f4 100644 --- a/examples/quick-start/.gen/jetdb/dvds/model/film_actor.go +++ b/examples/quick-start/.gen/jetdb/dvds/model/film_actor.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated diff --git a/examples/quick-start/.gen/jetdb/dvds/model/film_category.go b/examples/quick-start/.gen/jetdb/dvds/model/film_category.go index 846e554..97fe013 100644 --- a/examples/quick-start/.gen/jetdb/dvds/model/film_category.go +++ b/examples/quick-start/.gen/jetdb/dvds/model/film_category.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated diff --git a/examples/quick-start/.gen/jetdb/dvds/model/language.go b/examples/quick-start/.gen/jetdb/dvds/model/language.go index 3e4bc17..ef0e592 100644 --- a/examples/quick-start/.gen/jetdb/dvds/model/language.go +++ b/examples/quick-start/.gen/jetdb/dvds/model/language.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated diff --git a/examples/quick-start/.gen/jetdb/dvds/model/mpaa_rating.go b/examples/quick-start/.gen/jetdb/dvds/model/mpaa_rating.go index 821a11f..c802aa9 100644 --- a/examples/quick-start/.gen/jetdb/dvds/model/mpaa_rating.go +++ b/examples/quick-start/.gen/jetdb/dvds/model/mpaa_rating.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated diff --git a/examples/quick-start/.gen/jetdb/dvds/table/actor.go b/examples/quick-start/.gen/jetdb/dvds/table/actor.go index d86132c..4332015 100644 --- a/examples/quick-start/.gen/jetdb/dvds/table/actor.go +++ b/examples/quick-start/.gen/jetdb/dvds/table/actor.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated @@ -14,7 +13,7 @@ import ( var Actor = newActorTable() -type ActorTable struct { +type actorTable struct { postgres.Table //Columns @@ -27,25 +26,38 @@ type ActorTable struct { MutableColumns postgres.ColumnList } -// creates new ActorTable with assigned alias +type ActorTable struct { + actorTable + + EXCLUDED actorTable +} + +// AS creates new ActorTable with assigned alias func (a *ActorTable) AS(alias string) *ActorTable { aliasTable := newActorTable() - aliasTable.Table.AS(alias) - return aliasTable } func newActorTable() *ActorTable { + return &ActorTable{ + actorTable: newActorTableImpl("dvds", "actor"), + EXCLUDED: newActorTableImpl("", "excluded"), + } +} + +func newActorTableImpl(schemaName, tableName string) actorTable { var ( ActorIDColumn = postgres.IntegerColumn("actor_id") FirstNameColumn = postgres.StringColumn("first_name") LastNameColumn = postgres.StringColumn("last_name") LastUpdateColumn = postgres.TimestampColumn("last_update") + allColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn} + mutableColumns = postgres.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn} ) - return &ActorTable{ - Table: postgres.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn), + return actorTable{ + Table: postgres.NewTable(schemaName, tableName, allColumns...), //Columns ActorID: ActorIDColumn, @@ -53,7 +65,7 @@ func newActorTable() *ActorTable { LastName: LastNameColumn, LastUpdate: LastUpdateColumn, - AllColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn}, - MutableColumns: postgres.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } diff --git a/examples/quick-start/.gen/jetdb/dvds/table/category.go b/examples/quick-start/.gen/jetdb/dvds/table/category.go index 8e42de9..6b34fdf 100644 --- a/examples/quick-start/.gen/jetdb/dvds/table/category.go +++ b/examples/quick-start/.gen/jetdb/dvds/table/category.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated @@ -14,7 +13,7 @@ import ( var Category = newCategoryTable() -type CategoryTable struct { +type categoryTable struct { postgres.Table //Columns @@ -26,31 +25,44 @@ type CategoryTable struct { MutableColumns postgres.ColumnList } -// creates new CategoryTable with assigned alias +type CategoryTable struct { + categoryTable + + EXCLUDED categoryTable +} + +// AS creates new CategoryTable with assigned alias func (a *CategoryTable) AS(alias string) *CategoryTable { aliasTable := newCategoryTable() - aliasTable.Table.AS(alias) - return aliasTable } func newCategoryTable() *CategoryTable { + return &CategoryTable{ + categoryTable: newCategoryTableImpl("dvds", "category"), + EXCLUDED: newCategoryTableImpl("", "excluded"), + } +} + +func newCategoryTableImpl(schemaName, tableName string) categoryTable { var ( CategoryIDColumn = postgres.IntegerColumn("category_id") NameColumn = postgres.StringColumn("name") LastUpdateColumn = postgres.TimestampColumn("last_update") + allColumns = postgres.ColumnList{CategoryIDColumn, NameColumn, LastUpdateColumn} + mutableColumns = postgres.ColumnList{NameColumn, LastUpdateColumn} ) - return &CategoryTable{ - Table: postgres.NewTable("dvds", "category", CategoryIDColumn, NameColumn, LastUpdateColumn), + return categoryTable{ + Table: postgres.NewTable(schemaName, tableName, allColumns...), //Columns CategoryID: CategoryIDColumn, Name: NameColumn, LastUpdate: LastUpdateColumn, - AllColumns: postgres.ColumnList{CategoryIDColumn, NameColumn, LastUpdateColumn}, - MutableColumns: postgres.ColumnList{NameColumn, LastUpdateColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } diff --git a/examples/quick-start/.gen/jetdb/dvds/table/film.go b/examples/quick-start/.gen/jetdb/dvds/table/film.go index 6c8a8c2..3550465 100644 --- a/examples/quick-start/.gen/jetdb/dvds/table/film.go +++ b/examples/quick-start/.gen/jetdb/dvds/table/film.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated @@ -14,7 +13,7 @@ import ( var Film = newFilmTable() -type FilmTable struct { +type filmTable struct { postgres.Table //Columns @@ -36,16 +35,27 @@ type FilmTable struct { MutableColumns postgres.ColumnList } -// creates new FilmTable with assigned alias +type FilmTable struct { + filmTable + + EXCLUDED filmTable +} + +// AS creates new FilmTable with assigned alias func (a *FilmTable) AS(alias string) *FilmTable { aliasTable := newFilmTable() - aliasTable.Table.AS(alias) - return aliasTable } func newFilmTable() *FilmTable { + return &FilmTable{ + filmTable: newFilmTableImpl("dvds", "film"), + EXCLUDED: newFilmTableImpl("", "excluded"), + } +} + +func newFilmTableImpl(schemaName, tableName string) filmTable { var ( FilmIDColumn = postgres.IntegerColumn("film_id") TitleColumn = postgres.StringColumn("title") @@ -60,10 +70,12 @@ func newFilmTable() *FilmTable { LastUpdateColumn = postgres.TimestampColumn("last_update") SpecialFeaturesColumn = postgres.StringColumn("special_features") FulltextColumn = postgres.StringColumn("fulltext") + allColumns = postgres.ColumnList{FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn} + mutableColumns = postgres.ColumnList{TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn} ) - return &FilmTable{ - Table: postgres.NewTable("dvds", "film", FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn), + return filmTable{ + Table: postgres.NewTable(schemaName, tableName, allColumns...), //Columns FilmID: FilmIDColumn, @@ -80,7 +92,7 @@ func newFilmTable() *FilmTable { SpecialFeatures: SpecialFeaturesColumn, Fulltext: FulltextColumn, - AllColumns: postgres.ColumnList{FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn}, - MutableColumns: postgres.ColumnList{TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } diff --git a/examples/quick-start/.gen/jetdb/dvds/table/film_actor.go b/examples/quick-start/.gen/jetdb/dvds/table/film_actor.go index b30e524..89622ec 100644 --- a/examples/quick-start/.gen/jetdb/dvds/table/film_actor.go +++ b/examples/quick-start/.gen/jetdb/dvds/table/film_actor.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated @@ -14,7 +13,7 @@ import ( var FilmActor = newFilmActorTable() -type FilmActorTable struct { +type filmActorTable struct { postgres.Table //Columns @@ -26,31 +25,44 @@ type FilmActorTable struct { MutableColumns postgres.ColumnList } -// creates new FilmActorTable with assigned alias +type FilmActorTable struct { + filmActorTable + + EXCLUDED filmActorTable +} + +// AS creates new FilmActorTable with assigned alias func (a *FilmActorTable) AS(alias string) *FilmActorTable { aliasTable := newFilmActorTable() - aliasTable.Table.AS(alias) - return aliasTable } func newFilmActorTable() *FilmActorTable { + return &FilmActorTable{ + filmActorTable: newFilmActorTableImpl("dvds", "film_actor"), + EXCLUDED: newFilmActorTableImpl("", "excluded"), + } +} + +func newFilmActorTableImpl(schemaName, tableName string) filmActorTable { var ( ActorIDColumn = postgres.IntegerColumn("actor_id") FilmIDColumn = postgres.IntegerColumn("film_id") LastUpdateColumn = postgres.TimestampColumn("last_update") + allColumns = postgres.ColumnList{ActorIDColumn, FilmIDColumn, LastUpdateColumn} + mutableColumns = postgres.ColumnList{LastUpdateColumn} ) - return &FilmActorTable{ - Table: postgres.NewTable("dvds", "film_actor", ActorIDColumn, FilmIDColumn, LastUpdateColumn), + return filmActorTable{ + Table: postgres.NewTable(schemaName, tableName, allColumns...), //Columns ActorID: ActorIDColumn, FilmID: FilmIDColumn, LastUpdate: LastUpdateColumn, - AllColumns: postgres.ColumnList{ActorIDColumn, FilmIDColumn, LastUpdateColumn}, - MutableColumns: postgres.ColumnList{LastUpdateColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } diff --git a/examples/quick-start/.gen/jetdb/dvds/table/film_category.go b/examples/quick-start/.gen/jetdb/dvds/table/film_category.go index 3605fe1..eb932c4 100644 --- a/examples/quick-start/.gen/jetdb/dvds/table/film_category.go +++ b/examples/quick-start/.gen/jetdb/dvds/table/film_category.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated @@ -14,7 +13,7 @@ import ( var FilmCategory = newFilmCategoryTable() -type FilmCategoryTable struct { +type filmCategoryTable struct { postgres.Table //Columns @@ -26,31 +25,44 @@ type FilmCategoryTable struct { MutableColumns postgres.ColumnList } -// creates new FilmCategoryTable with assigned alias +type FilmCategoryTable struct { + filmCategoryTable + + EXCLUDED filmCategoryTable +} + +// AS creates new FilmCategoryTable with assigned alias func (a *FilmCategoryTable) AS(alias string) *FilmCategoryTable { aliasTable := newFilmCategoryTable() - aliasTable.Table.AS(alias) - return aliasTable } func newFilmCategoryTable() *FilmCategoryTable { + return &FilmCategoryTable{ + filmCategoryTable: newFilmCategoryTableImpl("dvds", "film_category"), + EXCLUDED: newFilmCategoryTableImpl("", "excluded"), + } +} + +func newFilmCategoryTableImpl(schemaName, tableName string) filmCategoryTable { var ( FilmIDColumn = postgres.IntegerColumn("film_id") CategoryIDColumn = postgres.IntegerColumn("category_id") LastUpdateColumn = postgres.TimestampColumn("last_update") + allColumns = postgres.ColumnList{FilmIDColumn, CategoryIDColumn, LastUpdateColumn} + mutableColumns = postgres.ColumnList{LastUpdateColumn} ) - return &FilmCategoryTable{ - Table: postgres.NewTable("dvds", "film_category", FilmIDColumn, CategoryIDColumn, LastUpdateColumn), + return filmCategoryTable{ + Table: postgres.NewTable(schemaName, tableName, allColumns...), //Columns FilmID: FilmIDColumn, CategoryID: CategoryIDColumn, LastUpdate: LastUpdateColumn, - AllColumns: postgres.ColumnList{FilmIDColumn, CategoryIDColumn, LastUpdateColumn}, - MutableColumns: postgres.ColumnList{LastUpdateColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } diff --git a/examples/quick-start/.gen/jetdb/dvds/table/language.go b/examples/quick-start/.gen/jetdb/dvds/table/language.go index db8f513..c68a6ce 100644 --- a/examples/quick-start/.gen/jetdb/dvds/table/language.go +++ b/examples/quick-start/.gen/jetdb/dvds/table/language.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated @@ -14,7 +13,7 @@ import ( var Language = newLanguageTable() -type LanguageTable struct { +type languageTable struct { postgres.Table //Columns @@ -26,31 +25,44 @@ type LanguageTable struct { MutableColumns postgres.ColumnList } -// creates new LanguageTable with assigned alias +type LanguageTable struct { + languageTable + + EXCLUDED languageTable +} + +// AS creates new LanguageTable with assigned alias func (a *LanguageTable) AS(alias string) *LanguageTable { aliasTable := newLanguageTable() - aliasTable.Table.AS(alias) - return aliasTable } func newLanguageTable() *LanguageTable { + return &LanguageTable{ + languageTable: newLanguageTableImpl("dvds", "language"), + EXCLUDED: newLanguageTableImpl("", "excluded"), + } +} + +func newLanguageTableImpl(schemaName, tableName string) languageTable { var ( LanguageIDColumn = postgres.IntegerColumn("language_id") NameColumn = postgres.StringColumn("name") LastUpdateColumn = postgres.TimestampColumn("last_update") + allColumns = postgres.ColumnList{LanguageIDColumn, NameColumn, LastUpdateColumn} + mutableColumns = postgres.ColumnList{NameColumn, LastUpdateColumn} ) - return &LanguageTable{ - Table: postgres.NewTable("dvds", "language", LanguageIDColumn, NameColumn, LastUpdateColumn), + return languageTable{ + Table: postgres.NewTable(schemaName, tableName, allColumns...), //Columns LanguageID: LanguageIDColumn, Name: NameColumn, LastUpdate: LastUpdateColumn, - AllColumns: postgres.ColumnList{LanguageIDColumn, NameColumn, LastUpdateColumn}, - MutableColumns: postgres.ColumnList{NameColumn, LastUpdateColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } diff --git a/examples/quick-start/.gen/jetdb/dvds/view/actor_info.go b/examples/quick-start/.gen/jetdb/dvds/view/actor_info.go index 697966d..21e901a 100644 --- a/examples/quick-start/.gen/jetdb/dvds/view/actor_info.go +++ b/examples/quick-start/.gen/jetdb/dvds/view/actor_info.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated @@ -14,7 +13,7 @@ import ( var ActorInfo = newActorInfoTable() -type ActorInfoTable struct { +type actorInfoTable struct { postgres.Table //Columns @@ -27,25 +26,38 @@ type ActorInfoTable struct { MutableColumns postgres.ColumnList } -// creates new ActorInfoTable with assigned alias +type ActorInfoTable struct { + actorInfoTable + + EXCLUDED actorInfoTable +} + +// AS creates new ActorInfoTable with assigned alias func (a *ActorInfoTable) AS(alias string) *ActorInfoTable { aliasTable := newActorInfoTable() - aliasTable.Table.AS(alias) - return aliasTable } func newActorInfoTable() *ActorInfoTable { + return &ActorInfoTable{ + actorInfoTable: newActorInfoTableImpl("dvds", "actor_info"), + EXCLUDED: newActorInfoTableImpl("", "excluded"), + } +} + +func newActorInfoTableImpl(schemaName, tableName string) actorInfoTable { var ( ActorIDColumn = postgres.IntegerColumn("actor_id") FirstNameColumn = postgres.StringColumn("first_name") LastNameColumn = postgres.StringColumn("last_name") FilmInfoColumn = postgres.StringColumn("film_info") + allColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn} + mutableColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn} ) - return &ActorInfoTable{ - Table: postgres.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), + return actorInfoTable{ + Table: postgres.NewTable(schemaName, tableName, allColumns...), //Columns ActorID: ActorIDColumn, @@ -53,7 +65,7 @@ func newActorInfoTable() *ActorInfoTable { LastName: LastNameColumn, FilmInfo: FilmInfoColumn, - AllColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}, - MutableColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } diff --git a/examples/quick-start/.gen/jetdb/dvds/view/customer_list.go b/examples/quick-start/.gen/jetdb/dvds/view/customer_list.go index 4766158..08b45d1 100644 --- a/examples/quick-start/.gen/jetdb/dvds/view/customer_list.go +++ b/examples/quick-start/.gen/jetdb/dvds/view/customer_list.go @@ -1,6 +1,5 @@ // // Code generated by go-jet DO NOT EDIT. -// Generated at Thursday, 26-Sep-19 12:02:13 CEST // // WARNING: Changes to this file may cause incorrect behavior // and will be lost if the code is regenerated @@ -14,7 +13,7 @@ import ( var CustomerList = newCustomerListTable() -type CustomerListTable struct { +type customerListTable struct { postgres.Table //Columns @@ -32,30 +31,43 @@ type CustomerListTable struct { MutableColumns postgres.ColumnList } -// creates new CustomerListTable with assigned alias +type CustomerListTable struct { + customerListTable + + EXCLUDED customerListTable +} + +// AS creates new CustomerListTable with assigned alias func (a *CustomerListTable) AS(alias string) *CustomerListTable { aliasTable := newCustomerListTable() - aliasTable.Table.AS(alias) - return aliasTable } func newCustomerListTable() *CustomerListTable { + return &CustomerListTable{ + customerListTable: newCustomerListTableImpl("dvds", "customer_list"), + EXCLUDED: newCustomerListTableImpl("", "excluded"), + } +} + +func newCustomerListTableImpl(schemaName, tableName string) customerListTable { var ( - IDColumn = postgres.IntegerColumn("id") - NameColumn = postgres.StringColumn("name") - AddressColumn = postgres.StringColumn("address") - ZipCodeColumn = postgres.StringColumn("zip code") - PhoneColumn = postgres.StringColumn("phone") - CityColumn = postgres.StringColumn("city") - CountryColumn = postgres.StringColumn("country") - NotesColumn = postgres.StringColumn("notes") - SidColumn = postgres.IntegerColumn("sid") + IDColumn = postgres.IntegerColumn("id") + NameColumn = postgres.StringColumn("name") + AddressColumn = postgres.StringColumn("address") + ZipCodeColumn = postgres.StringColumn("zip code") + PhoneColumn = postgres.StringColumn("phone") + CityColumn = postgres.StringColumn("city") + CountryColumn = postgres.StringColumn("country") + NotesColumn = postgres.StringColumn("notes") + SidColumn = postgres.IntegerColumn("sid") + allColumns = postgres.ColumnList{IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn} + mutableColumns = postgres.ColumnList{IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn} ) - return &CustomerListTable{ - Table: postgres.NewTable("dvds", "customer_list", IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn), + return customerListTable{ + Table: postgres.NewTable(schemaName, tableName, allColumns...), //Columns ID: IDColumn, @@ -68,7 +80,7 @@ func newCustomerListTable() *CustomerListTable { Notes: NotesColumn, Sid: SidColumn, - AllColumns: postgres.ColumnList{IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn}, - MutableColumns: postgres.ColumnList{IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn}, + AllColumns: allColumns, + MutableColumns: mutableColumns, } } diff --git a/examples/quick-start/README.md b/examples/quick-start/README.md index 0417515..2e5adc5 100644 --- a/examples/quick-start/README.md +++ b/examples/quick-start/README.md @@ -3,10 +3,10 @@ This package contains sample usage for Jet framework. -Jet generated files of interest are in ./gen folder. +Jet generated files of interest are in `./gen` folder. -quick-start.go contains code explained at [README.md](../../README.md#quick-start), -with difference of redirecting json output to files(dest.json and dest2.json) rather then to a +`quick-start.go` - contains code explained at main [README.md](../../README.md#quick-start), +with a difference of redirecting json output to files(`dest.json` and `dest2.json`) rather then to a standard output. -./gen, dest.json and dest2.json - added to git for presentation purposes. +`./gen`, `dest.json` and `dest2.json` - added to git for presentation purposes. diff --git a/examples/quick-start/quick-start.go b/examples/quick-start/quick-start.go index 0fe031a..693edf7 100644 --- a/examples/quick-start/quick-start.go +++ b/examples/quick-start/quick-start.go @@ -7,7 +7,7 @@ import ( _ "github.com/lib/pq" "io/ioutil" - // dot import so go code would resemble as much as native SQL + // dot import so that jet go code would resemble as much as native SQL // dot import is not mandatory . "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/table" . "github.com/go-jet/jet/postgres" @@ -98,15 +98,15 @@ func printStatementInfo(stmt SelectStatement) { query, args := stmt.Sql() fmt.Println("Parameterized query: ") + fmt.Println("==============================") fmt.Println(query) fmt.Println("Arguments: ") fmt.Println(args) debugSQL := stmt.DebugSql() - fmt.Println("\n\n==============================") - fmt.Println("\n\nDebug sql: ") + fmt.Println("==============================") fmt.Println(debugSQL) } From 07251841aa5c077ea3e8eb1635310890f037d5b0 Mon Sep 17 00:00:00 2001 From: go-jet <47941548+go-jet@users.noreply.github.com> Date: Sun, 31 May 2020 11:16:26 +0200 Subject: [PATCH 17/23] Update README.md --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index f6840ab..addaf15 100644 --- a/README.md +++ b/README.md @@ -36,15 +36,15 @@ https://medium.com/@go.jet/jet-5f3667efa0cc 1) Auto-generated type-safe SQL Builder - PostgreSQL: * [SELECT](https://github.com/go-jet/jet/wiki/SELECT) `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, WINDOW, sub-queries)` - * [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, query, RETURNING)`, - * [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, WHERE, RETURNING)`, + * [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, MODEL, MODELS, QUERY, ON_CONFLICT, RETURNING)`, + * [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, MODEL, WHERE, RETURNING)`, * [DELETE](https://github.com/go-jet/jet/wiki/DELETE) `(WHERE, RETURNING)`, * [LOCK](https://github.com/go-jet/jet/wiki/LOCK) `(IN, NOWAIT)` * [WITH](https://github.com/go-jet/jet/wiki/WITH) - MySQL and MariaDB: * [SELECT](https://github.com/go-jet/jet/wiki/SELECT) `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, LOCK_IN_SHARE_MODE, WINDOW, sub-queries)` - * [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, query)`, - * [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, WHERE)`, + * [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, MODEL, MODELS, ON_DUPLICATE_KEY_UPDATE, query)`, + * [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, MODEL, WHERE)`, * [DELETE](https://github.com/go-jet/jet/wiki/DELETE) `(WHERE, ORDER_BY, LIMIT)`, * [LOCK](https://github.com/go-jet/jet/wiki/LOCK) `(READ, WRITE)` * [WITH](https://github.com/go-jet/jet/wiki/WITH) From aefa4a2ff620a46fe6b9f5aea86ef90c9ae02a2c Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 31 May 2020 12:27:11 +0200 Subject: [PATCH 18/23] Add go mod support. --- README.md | 1 + go.mod | 10 ++++++++++ go.sum | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+) create mode 100644 go.mod create mode 100644 go.sum diff --git a/README.md b/README.md index addaf15..cdebfe0 100644 --- a/README.md +++ b/README.md @@ -563,6 +563,7 @@ At the moment Jet dependence only of: To run the tests, additional dependencies are required: - `github.com/pkg/profile` - `github.com/stretchr/testify` +- `github.com/google/go-cmp` ## Versioning diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..2f9c17d --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module github.com/go-jet/jet + +require ( + github.com/go-sql-driver/mysql v1.5.0 + github.com/google/go-cmp v0.4.1 + github.com/google/uuid v1.1.1 + github.com/lib/pq v1.6.0 + github.com/pkg/profile v1.5.0 + github.com/stretchr/testify v1.6.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7afed30 --- /dev/null +++ b/go.sum @@ -0,0 +1,60 @@ +github.com/alexbrainman/sspi v0.0.0-20180613141037-e580b900e9f5 h1:P5U+E4x5OkVEKQDklVPmzs71WM56RTTRqV4OrDC//Y4= +github.com/alexbrainman/sspi v0.0.0-20180613141037-e580b900e9f5/go.mod h1:976q2ETgjT2snVCf2ZaBnyBbVoPERGjUz+0sofzEfro= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/google/go-cmp v0.4.1 h1:/exdXoGamhu5ONeUJH0deniYLWYvQwW66yvlfiiKTu0= +github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/sessions v1.2.0 h1:S7P+1Hm5V/AT9cjEcUD5uDaQSX0OE577aCXgoaKpYbQ= +github.com/gorilla/sessions v1.2.0/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= +github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= +github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= +github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= +github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= +github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= +github.com/jcmturner/gofork v1.0.0 h1:J7uCkflzTEhUZ64xqKnkDxq3kzc96ajM1Gli5ktUem8= +github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= +github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= +github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= +github.com/jcmturner/gokrb5/v8 v8.2.0 h1:lzPl/30ZLkTveYsYZPKMcgXc8MbnE6RsTd4F9KgiLtk= +github.com/jcmturner/gokrb5/v8 v8.2.0/go.mod h1:T1hnNppQsBtxW0tCHMHTkAt8n/sABdzZgZdoFrZaZNM= +github.com/jcmturner/rpc/v2 v2.0.2 h1:gMB4IwRXYsWw4Bc6o/az2HJgFUA1ffSh90i26ZJ6Xl0= +github.com/jcmturner/rpc/v2 v2.0.2/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= +github.com/lib/pq v1.6.0 h1:I5DPxhYJChW9KYc66se+oKFFQX6VuQrKiprsX6ivRZc= +github.com/lib/pq v1.6.0/go.mod h1:4vXEAYvW1fRQ2/FhZ78H73A60MHw1geSm145z2mdY1g= +github.com/pkg/profile v1.5.0 h1:042Buzk+NhDI+DeSAA62RwJL8VAuZUMQZUjCsRz1Mug= +github.com/pkg/profile v1.5.0/go.mod h1:qBsxPvzyUincmltOk6iyRVxHYg4adc0OFOv72ZdLa18= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.0 h1:jlIyCplCJFULU/01vCkhKuTyc3OorI3bJFuw6obfgho= +github.com/stretchr/testify v1.6.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200117160349-530e935923ad/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4 h1:QmwruyY+bKbDDL0BaglrbZABEali68eoMFhTZpCjYVA= +golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA= +golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/jcmturner/aescts.v1 v1.0.1/go.mod h1:nsR8qBOg+OucoIW+WMhB3GspUQXq9XorLnQb9XtvcOo= +gopkg.in/jcmturner/dnsutils.v1 v1.0.1/go.mod h1:m3v+5svpVOhtFAP/wSz+yzh4Mc0Fg7eRhxkJMWSIz9Q= +gopkg.in/jcmturner/goidentity.v3 v3.0.0/go.mod h1:oG2kH0IvSYNIu80dVAyu/yoefjq1mNfM5bm88whjWx4= +gopkg.in/jcmturner/gokrb5.v7 v7.5.0/go.mod h1:l8VISx+WGYp+Fp7KRbsiUuXTTOnxIc3Tuvyavf11/WM= +gopkg.in/jcmturner/rpc.v1 v1.1.0/go.mod h1:YIdkC4XfD6GXbzje11McwsDuOlZQSb9W4vfLvuNnlv8= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From b8d1f97cf538ec078e95fdebb2cc6ce8de38d4e1 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 31 May 2020 20:15:17 +0200 Subject: [PATCH 19/23] Update generator version. --- cmd/jet/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/jet/main.go b/cmd/jet/main.go index af66916..8016545 100644 --- a/cmd/jet/main.go +++ b/cmd/jet/main.go @@ -47,7 +47,7 @@ func main() { flag.Usage = func() { _, _ = fmt.Fprint(os.Stdout, ` -Jet generator 2.0.0 +Jet generator 2.3.0 Usage: -source string From cd3325054beb24d9d93c1f4c494ce9c381a6684c Mon Sep 17 00:00:00 2001 From: go-jet Date: Mon, 1 Jun 2020 17:21:52 +0200 Subject: [PATCH 20/23] Remove unused template function. --- generator/internal/template/generate.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/generator/internal/template/generate.go b/generator/internal/template/generate.go index 46e72f9..eb52b1f 100644 --- a/generator/internal/template/generate.go +++ b/generator/internal/template/generate.go @@ -8,7 +8,6 @@ import ( "github.com/go-jet/jet/internal/utils" "path/filepath" "text/template" - "time" ) // GenerateFiles generates Go files from tables and enums metadata @@ -84,9 +83,6 @@ func GenerateTemplate(templateText string, templateData interface{}, dialect jet t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{ "ToGoIdentifier": utils.ToGoIdentifier, "ToGoEnumValueIdentifier": utils.ToGoEnumValueIdentifier, - "now": func() string { - return time.Now().Format(time.RFC850) - }, "dialect": func() jet.Dialect { return dialect }, From e54e8fcabf09c17b565177b76b33bad44c316ac5 Mon Sep 17 00:00:00 2001 From: go-jet Date: Mon, 1 Jun 2020 18:22:24 +0200 Subject: [PATCH 21/23] Rename LoggableStatement to PrintableStatement. --- internal/jet/logger.go | 6 +++--- mysql/types.go | 4 ++-- postgres/types.go | 4 ++-- tests/mysql/main_test.go | 2 +- tests/postgres/main_test.go | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/jet/logger.go b/internal/jet/logger.go index 90818b0..c900fc0 100644 --- a/internal/jet/logger.go +++ b/internal/jet/logger.go @@ -2,14 +2,14 @@ package jet import "context" -// LoggableStatement is a statement which sql query can be logged -type LoggableStatement interface { +// PrintableStatement is a statement which sql query can be logged +type PrintableStatement interface { Sql() (query string, args []interface{}) DebugSql() (query string) } // LoggerFunc is a definition of a function user can implement to support automatic statement logging. -type LoggerFunc func(ctx context.Context, statement LoggableStatement) +type LoggerFunc func(ctx context.Context, statement PrintableStatement) var logger LoggerFunc diff --git a/mysql/types.go b/mysql/types.go index 7e1424f..08ae20a 100644 --- a/mysql/types.go +++ b/mysql/types.go @@ -14,8 +14,8 @@ type ProjectionList = jet.ProjectionList // ColumnAssigment is interface wrapper around column assigment type ColumnAssigment = jet.ColumnAssigment -// LoggableStatement is a statement which sql query can be logged -type LoggableStatement = jet.LoggableStatement +// PrintableStatement is a statement which sql query can be logged +type PrintableStatement = jet.PrintableStatement // SetLogger sets automatic statement logging var SetLogger = jet.SetLoggerFunc diff --git a/postgres/types.go b/postgres/types.go index cfb52ec..fb7b8a0 100644 --- a/postgres/types.go +++ b/postgres/types.go @@ -14,8 +14,8 @@ type ProjectionList = jet.ProjectionList // ColumnAssigment is interface wrapper around column assigment type ColumnAssigment = jet.ColumnAssigment -// LoggableStatement is a statement which sql query can be logged -type LoggableStatement = jet.LoggableStatement +// PrintableStatement is a statement which sql query can be logged +type PrintableStatement = jet.PrintableStatement // SetLogger sets automatic statement logging var SetLogger = jet.SetLoggerFunc diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index 0f51875..fd513ac 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -54,7 +54,7 @@ var loggedSQLArgs []interface{} var loggedDebugSQL string func init() { - jetmysql.SetLogger(func(ctx context.Context, statement jetmysql.LoggableStatement) { + jetmysql.SetLogger(func(ctx context.Context, statement jetmysql.PrintableStatement) { loggedSQL, loggedSQLArgs = statement.Sql() loggedDebugSQL = statement.DebugSql() }) diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index 5fa23d5..2d5edf4 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -52,7 +52,7 @@ var loggedSQLArgs []interface{} var loggedDebugSQL string func init() { - postgres.SetLogger(func(ctx context.Context, statement postgres.LoggableStatement) { + postgres.SetLogger(func(ctx context.Context, statement postgres.PrintableStatement) { loggedSQL, loggedSQLArgs = statement.Sql() loggedDebugSQL = statement.DebugSql() }) From d19fdea86d24bd0d5b7a4c7a7caa56f4cb3471c7 Mon Sep 17 00:00:00 2001 From: go-jet Date: Mon, 1 Jun 2020 20:30:09 +0200 Subject: [PATCH 22/23] Additional MySQL WITH statement tests. --- internal/jet/with_statement.go | 10 +++- internal/testutils/test_utils.go | 2 +- mysql/with_statement.go | 2 +- postgres/with_statement.go | 2 +- tests/mysql/with_test.go | 95 +++++++++++++++++++++++++++++++- 5 files changed, 103 insertions(+), 8 deletions(-) diff --git a/internal/jet/with_statement.go b/internal/jet/with_statement.go index 6131b35..ab57067 100644 --- a/internal/jet/with_statement.go +++ b/internal/jet/with_statement.go @@ -1,7 +1,7 @@ package jet // WITH function creates new with statement from list of common table expressions for specified dialect -func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statement SerializerStatement) Statement { +func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statement Statement) Statement { newWithImpl := &withImpl{ ctes: cte, serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ @@ -11,8 +11,12 @@ func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statemen } newWithImpl.parent = newWithImpl - return func(primaryStatement SerializerStatement) Statement { - newWithImpl.primaryStatement = primaryStatement + return func(primaryStatement Statement) Statement { + serializerStatement, ok := primaryStatement.(SerializerStatement) + if !ok { + panic("jet: unsupported main WITH statement.") + } + newWithImpl.primaryStatement = serializerStatement return newWithImpl } } diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 3cff7ab..e5d7a09 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -28,7 +28,7 @@ func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int require.NoError(t, err) if len(rowsAffected) > 0 { - require.Equal(t, rows, rowsAffected[0]) + require.Equal(t, rowsAffected[0], rows) } } diff --git a/mysql/with_statement.go b/mysql/with_statement.go index 5991287..35066f7 100644 --- a/mysql/with_statement.go +++ b/mysql/with_statement.go @@ -9,7 +9,7 @@ type CommonTableExpression struct { } // WITH function creates new WITH statement from list of common table expressions -func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.SerializerStatement) Statement { +func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { return jet.WITH(Dialect, cte...) } diff --git a/postgres/with_statement.go b/postgres/with_statement.go index caa7100..c1f7a7b 100644 --- a/postgres/with_statement.go +++ b/postgres/with_statement.go @@ -9,7 +9,7 @@ type CommonTableExpression struct { } // WITH function creates new WITH statement from list of common table expressions -func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.SerializerStatement) Statement { +func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { return jet.WITH(Dialect, cte...) } diff --git a/tests/mysql/with_test.go b/tests/mysql/with_test.go index 7e3a8dd..fa53fad 100644 --- a/tests/mysql/with_test.go +++ b/tests/mysql/with_test.go @@ -9,7 +9,7 @@ import ( "testing" ) -func TestWITH_SELECT(t *testing.T) { +func TestWITH_And_SELECT(t *testing.T) { salesRep := CTE("sales_rep") salesRepStaffID := Staff.StaffID.From(salesRep) salesRepFullName := StringColumn("sales_rep_full_name").From(salesRep) @@ -56,6 +56,97 @@ SELECT customer_sales_rep.customer_name AS "customer_name", FROM customer_sales_rep; `, "''", "`", -1)) - _, err := stmt.Exec(db) + var dest []struct { + CustomerName string + SalesRepFullName string + } + err := stmt.Query(db, &dest) + + require.Equal(t, len(dest), 599) require.NoError(t, err) } + +//func TestWITH_And_INSERT(t *testing.T) { +// paymentsToInsert := CTE("payments_to_insert") +// +// stmt := WITH( +// paymentsToInsert.AS( +// SELECT(Payment.AllColumns). +// FROM(Payment). +// WHERE(Payment.Amount.LT(Float(0.5))), +// ), +// )( +// Payment.INSERT(Payment.AllColumns). +// QUERY( +// SELECT(paymentsToInsert.AllColumns()). +// FROM(paymentsToInsert), +// ).ON_DUPLICATE_KEY_UPDATE( +// Payment.PaymentID.SET(Payment.PaymentID.ADD(Int(100000))), +// ), +// ) +// +// //fmt.Println(stmt.DebugSql()) +// +// tx, err := db.Begin() +// require.NoError(t, err) +// defer tx.Rollback() +// +// testutils.AssertExec(t, stmt, tx, 24) +//} + +func TestWITH_And_UPDATE(t *testing.T) { + paymentsToUpdate := CTE("payments_to_update") + paymentsToDeleteID := Payment.PaymentID.From(paymentsToUpdate) + + stmt := WITH( + paymentsToUpdate.AS( + SELECT(Payment.AllColumns). + FROM(Payment). + WHERE(Payment.Amount.LT(Float(0.5))), + ), + )( + Payment.UPDATE(). + SET(Payment.Amount.SET(Float(0.0))). + WHERE(Payment.PaymentID.IN( + SELECT(paymentsToDeleteID). + FROM(paymentsToUpdate), + ), + ), + ) + + //fmt.Println(stmt.DebugSql()) + + tx, err := db.Begin() + require.NoError(t, err) + defer tx.Rollback() + + testutils.AssertExec(t, stmt, tx) +} + +func TestWITH_And_DELETE(t *testing.T) { + paymentsToDelete := CTE("payments_to_delete") + paymentsToDeleteID := Payment.PaymentID.From(paymentsToDelete) + + stmt := WITH( + paymentsToDelete.AS( + SELECT(Payment.AllColumns). + FROM(Payment). + WHERE(Payment.Amount.LT(Float(0.5))), + ), + )( + Payment.DELETE(). + WHERE(Payment.PaymentID.IN( + SELECT(paymentsToDeleteID). + FROM(paymentsToDelete), + ), + ), + ) + + //fmt.Println(stmt.DebugSql()) + + tx, err := db.Begin() + require.NoError(t, err) + defer tx.Rollback() + + testutils.AssertExec(t, stmt, tx, 24) +} From 63aa31925aaada131d545581d9b19277f4b1e4a3 Mon Sep 17 00:00:00 2001 From: go-jet Date: Mon, 1 Jun 2020 20:35:07 +0200 Subject: [PATCH 23/23] Disable unsupported MariaDB tests. --- tests/mysql/with_test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/mysql/with_test.go b/tests/mysql/with_test.go index fa53fad..0a82f2e 100644 --- a/tests/mysql/with_test.go +++ b/tests/mysql/with_test.go @@ -95,6 +95,9 @@ FROM customer_sales_rep; //} func TestWITH_And_UPDATE(t *testing.T) { + if sourceIsMariaDB() { + return + } paymentsToUpdate := CTE("payments_to_update") paymentsToDeleteID := Payment.PaymentID.From(paymentsToUpdate) @@ -124,6 +127,10 @@ func TestWITH_And_UPDATE(t *testing.T) { } func TestWITH_And_DELETE(t *testing.T) { + if sourceIsMariaDB() { + return + } + paymentsToDelete := CTE("payments_to_delete") paymentsToDeleteID := Payment.PaymentID.From(paymentsToDelete)