Merge pull request #40 from go-jet/develop

Merge develop to master for 2.3.0 release
This commit is contained in:
go-jet 2020-06-03 07:28:40 +02:00 committed by GitHub
commit c3903948c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
115 changed files with 3769 additions and 1460 deletions

View file

@ -35,17 +35,19 @@ https://medium.com/@go.jet/jet-5f3667efa0cc
## Features ## Features
1) Auto-generated type-safe SQL Builder 1) Auto-generated type-safe SQL Builder
- PostgreSQL: - PostgreSQL:
* SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, WINDOW, sub-queries)` * [SELECT](https://github.com/go-jet/jet/wiki/SELECT) `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, WINDOW, sub-queries)`
* INSERT `(VALUES, query, RETURNING)`, * [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, MODEL, MODELS, QUERY, ON_CONFLICT, RETURNING)`,
* UPDATE `(SET, WHERE, RETURNING)`, * [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, MODEL, WHERE, RETURNING)`,
* DELETE `(WHERE, RETURNING)`, * [DELETE](https://github.com/go-jet/jet/wiki/DELETE) `(WHERE, RETURNING)`,
* LOCK `(IN, NOWAIT)` * [LOCK](https://github.com/go-jet/jet/wiki/LOCK) `(IN, NOWAIT)`
* [WITH](https://github.com/go-jet/jet/wiki/WITH)
- MySQL and MariaDB: - MySQL and MariaDB:
* SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, LOCK_IN_SHARE_MODE, WINDOW, sub-queries)` * [SELECT](https://github.com/go-jet/jet/wiki/SELECT) `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, LOCK_IN_SHARE_MODE, WINDOW, sub-queries)`
* INSERT `(VALUES, query)`, * [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, MODEL, MODELS, ON_DUPLICATE_KEY_UPDATE, query)`,
* UPDATE `(SET, WHERE)`, * [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, MODEL, WHERE)`,
* DELETE `(WHERE, ORDER_BY, LIMIT)`, * [DELETE](https://github.com/go-jet/jet/wiki/DELETE) `(WHERE, ORDER_BY, LIMIT)`,
* LOCK `(READ, WRITE)` * [LOCK](https://github.com/go-jet/jet/wiki/LOCK) `(READ, WRITE)`
* [WITH](https://github.com/go-jet/jet/wiki/WITH)
2) Auto-generated Data Model types - Go types mapped to database type (table, view or enum), used to store 2) Auto-generated Data Model types - Go types mapped to database type (table, view or enum), used to store
result of database queries. Can be combined to create desired query result destination. result of database queries. Can be combined to create desired query result destination.
3) Query execution with result mapping to arbitrary destination structure. 3) Query execution with result mapping to arbitrary destination structure.
@ -561,6 +563,7 @@ At the moment Jet dependence only of:
To run the tests, additional dependencies are required: To run the tests, additional dependencies are required:
- `github.com/pkg/profile` - `github.com/pkg/profile`
- `github.com/stretchr/testify` - `github.com/stretchr/testify`
- `github.com/google/go-cmp`
## Versioning ## Versioning
@ -568,5 +571,5 @@ To run the tests, additional dependencies are required:
## License ## License
Copyright 2019 Goran Bjelanovic Copyright 2019-2020 Goran Bjelanovic
Licensed under the Apache License, Version 2.0. Licensed under the Apache License, Version 2.0.

View file

@ -47,7 +47,7 @@ func main() {
flag.Usage = func() { flag.Usage = func() {
_, _ = fmt.Fprint(os.Stdout, ` _, _ = fmt.Fprint(os.Stdout, `
Jet generator 2.0.0 Jet generator 2.3.0
Usage: Usage:
-source string -source string

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var Actor = newActorTable() var Actor = newActorTable()
type ActorTable struct { type actorTable struct {
postgres.Table postgres.Table
//Columns //Columns
@ -27,25 +26,38 @@ type ActorTable struct {
MutableColumns postgres.ColumnList MutableColumns postgres.ColumnList
} }
// creates new ActorTable with assigned alias type ActorTable struct {
actorTable
EXCLUDED actorTable
}
// AS creates new ActorTable with assigned alias
func (a *ActorTable) AS(alias string) *ActorTable { func (a *ActorTable) AS(alias string) *ActorTable {
aliasTable := newActorTable() aliasTable := newActorTable()
aliasTable.Table.AS(alias) aliasTable.Table.AS(alias)
return aliasTable return aliasTable
} }
func newActorTable() *ActorTable { func newActorTable() *ActorTable {
return &ActorTable{
actorTable: newActorTableImpl("dvds", "actor"),
EXCLUDED: newActorTableImpl("", "excluded"),
}
}
func newActorTableImpl(schemaName, tableName string) actorTable {
var ( var (
ActorIDColumn = postgres.IntegerColumn("actor_id") ActorIDColumn = postgres.IntegerColumn("actor_id")
FirstNameColumn = postgres.StringColumn("first_name") FirstNameColumn = postgres.StringColumn("first_name")
LastNameColumn = postgres.StringColumn("last_name") LastNameColumn = postgres.StringColumn("last_name")
LastUpdateColumn = postgres.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
allColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn}
mutableColumns = postgres.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn}
) )
return &ActorTable{ return actorTable{
Table: postgres.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn), Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns //Columns
ActorID: ActorIDColumn, ActorID: ActorIDColumn,
@ -53,7 +65,7 @@ func newActorTable() *ActorTable {
LastName: LastNameColumn, LastName: LastNameColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn}, AllColumns: allColumns,
MutableColumns: postgres.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn}, MutableColumns: mutableColumns,
} }
} }

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var Category = newCategoryTable() var Category = newCategoryTable()
type CategoryTable struct { type categoryTable struct {
postgres.Table postgres.Table
//Columns //Columns
@ -26,31 +25,44 @@ type CategoryTable struct {
MutableColumns postgres.ColumnList MutableColumns postgres.ColumnList
} }
// creates new CategoryTable with assigned alias type CategoryTable struct {
categoryTable
EXCLUDED categoryTable
}
// AS creates new CategoryTable with assigned alias
func (a *CategoryTable) AS(alias string) *CategoryTable { func (a *CategoryTable) AS(alias string) *CategoryTable {
aliasTable := newCategoryTable() aliasTable := newCategoryTable()
aliasTable.Table.AS(alias) aliasTable.Table.AS(alias)
return aliasTable return aliasTable
} }
func newCategoryTable() *CategoryTable { func newCategoryTable() *CategoryTable {
return &CategoryTable{
categoryTable: newCategoryTableImpl("dvds", "category"),
EXCLUDED: newCategoryTableImpl("", "excluded"),
}
}
func newCategoryTableImpl(schemaName, tableName string) categoryTable {
var ( var (
CategoryIDColumn = postgres.IntegerColumn("category_id") CategoryIDColumn = postgres.IntegerColumn("category_id")
NameColumn = postgres.StringColumn("name") NameColumn = postgres.StringColumn("name")
LastUpdateColumn = postgres.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
allColumns = postgres.ColumnList{CategoryIDColumn, NameColumn, LastUpdateColumn}
mutableColumns = postgres.ColumnList{NameColumn, LastUpdateColumn}
) )
return &CategoryTable{ return categoryTable{
Table: postgres.NewTable("dvds", "category", CategoryIDColumn, NameColumn, LastUpdateColumn), Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns //Columns
CategoryID: CategoryIDColumn, CategoryID: CategoryIDColumn,
Name: NameColumn, Name: NameColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: postgres.ColumnList{CategoryIDColumn, NameColumn, LastUpdateColumn}, AllColumns: allColumns,
MutableColumns: postgres.ColumnList{NameColumn, LastUpdateColumn}, MutableColumns: mutableColumns,
} }
} }

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var Film = newFilmTable() var Film = newFilmTable()
type FilmTable struct { type filmTable struct {
postgres.Table postgres.Table
//Columns //Columns
@ -36,16 +35,27 @@ type FilmTable struct {
MutableColumns postgres.ColumnList MutableColumns postgres.ColumnList
} }
// creates new FilmTable with assigned alias type FilmTable struct {
filmTable
EXCLUDED filmTable
}
// AS creates new FilmTable with assigned alias
func (a *FilmTable) AS(alias string) *FilmTable { func (a *FilmTable) AS(alias string) *FilmTable {
aliasTable := newFilmTable() aliasTable := newFilmTable()
aliasTable.Table.AS(alias) aliasTable.Table.AS(alias)
return aliasTable return aliasTable
} }
func newFilmTable() *FilmTable { func newFilmTable() *FilmTable {
return &FilmTable{
filmTable: newFilmTableImpl("dvds", "film"),
EXCLUDED: newFilmTableImpl("", "excluded"),
}
}
func newFilmTableImpl(schemaName, tableName string) filmTable {
var ( var (
FilmIDColumn = postgres.IntegerColumn("film_id") FilmIDColumn = postgres.IntegerColumn("film_id")
TitleColumn = postgres.StringColumn("title") TitleColumn = postgres.StringColumn("title")
@ -60,10 +70,12 @@ func newFilmTable() *FilmTable {
LastUpdateColumn = postgres.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
SpecialFeaturesColumn = postgres.StringColumn("special_features") SpecialFeaturesColumn = postgres.StringColumn("special_features")
FulltextColumn = postgres.StringColumn("fulltext") FulltextColumn = postgres.StringColumn("fulltext")
allColumns = postgres.ColumnList{FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn}
mutableColumns = postgres.ColumnList{TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn}
) )
return &FilmTable{ return filmTable{
Table: postgres.NewTable("dvds", "film", FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn), Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns //Columns
FilmID: FilmIDColumn, FilmID: FilmIDColumn,
@ -80,7 +92,7 @@ func newFilmTable() *FilmTable {
SpecialFeatures: SpecialFeaturesColumn, SpecialFeatures: SpecialFeaturesColumn,
Fulltext: FulltextColumn, Fulltext: FulltextColumn,
AllColumns: postgres.ColumnList{FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn}, AllColumns: allColumns,
MutableColumns: postgres.ColumnList{TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn}, MutableColumns: mutableColumns,
} }
} }

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var FilmActor = newFilmActorTable() var FilmActor = newFilmActorTable()
type FilmActorTable struct { type filmActorTable struct {
postgres.Table postgres.Table
//Columns //Columns
@ -26,31 +25,44 @@ type FilmActorTable struct {
MutableColumns postgres.ColumnList MutableColumns postgres.ColumnList
} }
// creates new FilmActorTable with assigned alias type FilmActorTable struct {
filmActorTable
EXCLUDED filmActorTable
}
// AS creates new FilmActorTable with assigned alias
func (a *FilmActorTable) AS(alias string) *FilmActorTable { func (a *FilmActorTable) AS(alias string) *FilmActorTable {
aliasTable := newFilmActorTable() aliasTable := newFilmActorTable()
aliasTable.Table.AS(alias) aliasTable.Table.AS(alias)
return aliasTable return aliasTable
} }
func newFilmActorTable() *FilmActorTable { func newFilmActorTable() *FilmActorTable {
return &FilmActorTable{
filmActorTable: newFilmActorTableImpl("dvds", "film_actor"),
EXCLUDED: newFilmActorTableImpl("", "excluded"),
}
}
func newFilmActorTableImpl(schemaName, tableName string) filmActorTable {
var ( var (
ActorIDColumn = postgres.IntegerColumn("actor_id") ActorIDColumn = postgres.IntegerColumn("actor_id")
FilmIDColumn = postgres.IntegerColumn("film_id") FilmIDColumn = postgres.IntegerColumn("film_id")
LastUpdateColumn = postgres.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
allColumns = postgres.ColumnList{ActorIDColumn, FilmIDColumn, LastUpdateColumn}
mutableColumns = postgres.ColumnList{LastUpdateColumn}
) )
return &FilmActorTable{ return filmActorTable{
Table: postgres.NewTable("dvds", "film_actor", ActorIDColumn, FilmIDColumn, LastUpdateColumn), Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns //Columns
ActorID: ActorIDColumn, ActorID: ActorIDColumn,
FilmID: FilmIDColumn, FilmID: FilmIDColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: postgres.ColumnList{ActorIDColumn, FilmIDColumn, LastUpdateColumn}, AllColumns: allColumns,
MutableColumns: postgres.ColumnList{LastUpdateColumn}, MutableColumns: mutableColumns,
} }
} }

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var FilmCategory = newFilmCategoryTable() var FilmCategory = newFilmCategoryTable()
type FilmCategoryTable struct { type filmCategoryTable struct {
postgres.Table postgres.Table
//Columns //Columns
@ -26,31 +25,44 @@ type FilmCategoryTable struct {
MutableColumns postgres.ColumnList MutableColumns postgres.ColumnList
} }
// creates new FilmCategoryTable with assigned alias type FilmCategoryTable struct {
filmCategoryTable
EXCLUDED filmCategoryTable
}
// AS creates new FilmCategoryTable with assigned alias
func (a *FilmCategoryTable) AS(alias string) *FilmCategoryTable { func (a *FilmCategoryTable) AS(alias string) *FilmCategoryTable {
aliasTable := newFilmCategoryTable() aliasTable := newFilmCategoryTable()
aliasTable.Table.AS(alias) aliasTable.Table.AS(alias)
return aliasTable return aliasTable
} }
func newFilmCategoryTable() *FilmCategoryTable { func newFilmCategoryTable() *FilmCategoryTable {
return &FilmCategoryTable{
filmCategoryTable: newFilmCategoryTableImpl("dvds", "film_category"),
EXCLUDED: newFilmCategoryTableImpl("", "excluded"),
}
}
func newFilmCategoryTableImpl(schemaName, tableName string) filmCategoryTable {
var ( var (
FilmIDColumn = postgres.IntegerColumn("film_id") FilmIDColumn = postgres.IntegerColumn("film_id")
CategoryIDColumn = postgres.IntegerColumn("category_id") CategoryIDColumn = postgres.IntegerColumn("category_id")
LastUpdateColumn = postgres.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
allColumns = postgres.ColumnList{FilmIDColumn, CategoryIDColumn, LastUpdateColumn}
mutableColumns = postgres.ColumnList{LastUpdateColumn}
) )
return &FilmCategoryTable{ return filmCategoryTable{
Table: postgres.NewTable("dvds", "film_category", FilmIDColumn, CategoryIDColumn, LastUpdateColumn), Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns //Columns
FilmID: FilmIDColumn, FilmID: FilmIDColumn,
CategoryID: CategoryIDColumn, CategoryID: CategoryIDColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: postgres.ColumnList{FilmIDColumn, CategoryIDColumn, LastUpdateColumn}, AllColumns: allColumns,
MutableColumns: postgres.ColumnList{LastUpdateColumn}, MutableColumns: mutableColumns,
} }
} }

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var Language = newLanguageTable() var Language = newLanguageTable()
type LanguageTable struct { type languageTable struct {
postgres.Table postgres.Table
//Columns //Columns
@ -26,31 +25,44 @@ type LanguageTable struct {
MutableColumns postgres.ColumnList MutableColumns postgres.ColumnList
} }
// creates new LanguageTable with assigned alias type LanguageTable struct {
languageTable
EXCLUDED languageTable
}
// AS creates new LanguageTable with assigned alias
func (a *LanguageTable) AS(alias string) *LanguageTable { func (a *LanguageTable) AS(alias string) *LanguageTable {
aliasTable := newLanguageTable() aliasTable := newLanguageTable()
aliasTable.Table.AS(alias) aliasTable.Table.AS(alias)
return aliasTable return aliasTable
} }
func newLanguageTable() *LanguageTable { func newLanguageTable() *LanguageTable {
return &LanguageTable{
languageTable: newLanguageTableImpl("dvds", "language"),
EXCLUDED: newLanguageTableImpl("", "excluded"),
}
}
func newLanguageTableImpl(schemaName, tableName string) languageTable {
var ( var (
LanguageIDColumn = postgres.IntegerColumn("language_id") LanguageIDColumn = postgres.IntegerColumn("language_id")
NameColumn = postgres.StringColumn("name") NameColumn = postgres.StringColumn("name")
LastUpdateColumn = postgres.TimestampColumn("last_update") LastUpdateColumn = postgres.TimestampColumn("last_update")
allColumns = postgres.ColumnList{LanguageIDColumn, NameColumn, LastUpdateColumn}
mutableColumns = postgres.ColumnList{NameColumn, LastUpdateColumn}
) )
return &LanguageTable{ return languageTable{
Table: postgres.NewTable("dvds", "language", LanguageIDColumn, NameColumn, LastUpdateColumn), Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns //Columns
LanguageID: LanguageIDColumn, LanguageID: LanguageIDColumn,
Name: NameColumn, Name: NameColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: postgres.ColumnList{LanguageIDColumn, NameColumn, LastUpdateColumn}, AllColumns: allColumns,
MutableColumns: postgres.ColumnList{NameColumn, LastUpdateColumn}, MutableColumns: mutableColumns,
} }
} }

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var ActorInfo = newActorInfoTable() var ActorInfo = newActorInfoTable()
type ActorInfoTable struct { type actorInfoTable struct {
postgres.Table postgres.Table
//Columns //Columns
@ -27,25 +26,38 @@ type ActorInfoTable struct {
MutableColumns postgres.ColumnList MutableColumns postgres.ColumnList
} }
// creates new ActorInfoTable with assigned alias type ActorInfoTable struct {
actorInfoTable
EXCLUDED actorInfoTable
}
// AS creates new ActorInfoTable with assigned alias
func (a *ActorInfoTable) AS(alias string) *ActorInfoTable { func (a *ActorInfoTable) AS(alias string) *ActorInfoTable {
aliasTable := newActorInfoTable() aliasTable := newActorInfoTable()
aliasTable.Table.AS(alias) aliasTable.Table.AS(alias)
return aliasTable return aliasTable
} }
func newActorInfoTable() *ActorInfoTable { func newActorInfoTable() *ActorInfoTable {
return &ActorInfoTable{
actorInfoTable: newActorInfoTableImpl("dvds", "actor_info"),
EXCLUDED: newActorInfoTableImpl("", "excluded"),
}
}
func newActorInfoTableImpl(schemaName, tableName string) actorInfoTable {
var ( var (
ActorIDColumn = postgres.IntegerColumn("actor_id") ActorIDColumn = postgres.IntegerColumn("actor_id")
FirstNameColumn = postgres.StringColumn("first_name") FirstNameColumn = postgres.StringColumn("first_name")
LastNameColumn = postgres.StringColumn("last_name") LastNameColumn = postgres.StringColumn("last_name")
FilmInfoColumn = postgres.StringColumn("film_info") FilmInfoColumn = postgres.StringColumn("film_info")
allColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}
mutableColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}
) )
return &ActorInfoTable{ return actorInfoTable{
Table: postgres.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns //Columns
ActorID: ActorIDColumn, ActorID: ActorIDColumn,
@ -53,7 +65,7 @@ func newActorInfoTable() *ActorInfoTable {
LastName: LastNameColumn, LastName: LastNameColumn,
FilmInfo: FilmInfoColumn, FilmInfo: FilmInfoColumn,
AllColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}, AllColumns: allColumns,
MutableColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}, MutableColumns: mutableColumns,
} }
} }

View file

@ -1,6 +1,5 @@
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var CustomerList = newCustomerListTable() var CustomerList = newCustomerListTable()
type CustomerListTable struct { type customerListTable struct {
postgres.Table postgres.Table
//Columns //Columns
@ -32,16 +31,27 @@ type CustomerListTable struct {
MutableColumns postgres.ColumnList MutableColumns postgres.ColumnList
} }
// creates new CustomerListTable with assigned alias type CustomerListTable struct {
customerListTable
EXCLUDED customerListTable
}
// AS creates new CustomerListTable with assigned alias
func (a *CustomerListTable) AS(alias string) *CustomerListTable { func (a *CustomerListTable) AS(alias string) *CustomerListTable {
aliasTable := newCustomerListTable() aliasTable := newCustomerListTable()
aliasTable.Table.AS(alias) aliasTable.Table.AS(alias)
return aliasTable return aliasTable
} }
func newCustomerListTable() *CustomerListTable { func newCustomerListTable() *CustomerListTable {
return &CustomerListTable{
customerListTable: newCustomerListTableImpl("dvds", "customer_list"),
EXCLUDED: newCustomerListTableImpl("", "excluded"),
}
}
func newCustomerListTableImpl(schemaName, tableName string) customerListTable {
var ( var (
IDColumn = postgres.IntegerColumn("id") IDColumn = postgres.IntegerColumn("id")
NameColumn = postgres.StringColumn("name") NameColumn = postgres.StringColumn("name")
@ -52,10 +62,12 @@ func newCustomerListTable() *CustomerListTable {
CountryColumn = postgres.StringColumn("country") CountryColumn = postgres.StringColumn("country")
NotesColumn = postgres.StringColumn("notes") NotesColumn = postgres.StringColumn("notes")
SidColumn = postgres.IntegerColumn("sid") SidColumn = postgres.IntegerColumn("sid")
allColumns = postgres.ColumnList{IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn}
mutableColumns = postgres.ColumnList{IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn}
) )
return &CustomerListTable{ return customerListTable{
Table: postgres.NewTable("dvds", "customer_list", IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn), Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns //Columns
ID: IDColumn, ID: IDColumn,
@ -68,7 +80,7 @@ func newCustomerListTable() *CustomerListTable {
Notes: NotesColumn, Notes: NotesColumn,
Sid: SidColumn, Sid: SidColumn,
AllColumns: postgres.ColumnList{IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn}, AllColumns: allColumns,
MutableColumns: postgres.ColumnList{IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn}, MutableColumns: mutableColumns,
} }
} }

View file

@ -3,10 +3,10 @@
This package contains sample usage for Jet framework. This package contains sample usage for Jet framework.
Jet generated files of interest are in ./gen folder. Jet generated files of interest are in `./gen` folder.
quick-start.go contains code explained at [README.md](../../README.md#quick-start), `quick-start.go` - contains code explained at main [README.md](../../README.md#quick-start),
with difference of redirecting json output to files(dest.json and dest2.json) rather then to a with a difference of redirecting json output to files(`dest.json` and `dest2.json`) rather then to a
standard output. standard output.
./gen, dest.json and dest2.json - added to git for presentation purposes. `./gen`, `dest.json` and `dest2.json` - added to git for presentation purposes.

View file

@ -7,7 +7,7 @@ import (
_ "github.com/lib/pq" _ "github.com/lib/pq"
"io/ioutil" "io/ioutil"
// dot import so go code would resemble as much as native SQL // dot import so that jet go code would resemble as much as native SQL
// dot import is not mandatory // dot import is not mandatory
. "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/table" . "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/table"
. "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres"
@ -98,15 +98,15 @@ func printStatementInfo(stmt SelectStatement) {
query, args := stmt.Sql() query, args := stmt.Sql()
fmt.Println("Parameterized query: ") fmt.Println("Parameterized query: ")
fmt.Println("==============================")
fmt.Println(query) fmt.Println(query)
fmt.Println("Arguments: ") fmt.Println("Arguments: ")
fmt.Println(args) fmt.Println(args)
debugSQL := stmt.DebugSql() debugSQL := stmt.DebugSql()
fmt.Println("\n\n==============================")
fmt.Println("\n\nDebug sql: ") fmt.Println("\n\nDebug sql: ")
fmt.Println("==============================")
fmt.Println(debugSQL) fmt.Println(debugSQL)
} }

View file

@ -3,6 +3,7 @@ package metadata
import ( import (
"database/sql" "database/sql"
"github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/internal/utils"
"strings"
) )
// TableMetaData metadata struct // TableMetaData metadata struct
@ -67,15 +68,19 @@ func (t TableMetaData) GoStructName() string {
return utils.ToGoIdentifier(t.name) + "Table" return utils.ToGoIdentifier(t.name) + "Table"
} }
// GoStructImplName returns go struct impl name for sql builder
func (t TableMetaData) GoStructImplName() string {
name := utils.ToGoIdentifier(t.name) + "Table"
return string(strings.ToLower(name)[0]) + name[1:]
}
// GetTableMetaData returns table info metadata // GetTableMetaData returns table info metadata
func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData) { func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData) {
tableInfo.SchemaName = schemaName tableInfo.SchemaName = schemaName
tableInfo.name = tableName tableInfo.name = tableName
tableInfo.PrimaryKeys = getPrimaryKeys(db, querySet, schemaName, tableName) tableInfo.PrimaryKeys = getPrimaryKeys(db, querySet, schemaName, tableName)
tableInfo.Columns = getColumnsMetaData(db, querySet, schemaName, tableName) tableInfo.Columns = getColumnsMetaData(db, querySet, schemaName, tableName)
return return
} }

View file

@ -8,7 +8,6 @@ import (
"github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/internal/utils"
"path/filepath" "path/filepath"
"text/template" "text/template"
"time"
) )
// GenerateFiles generates Go files from tables and enums metadata // GenerateFiles generates Go files from tables and enums metadata
@ -22,6 +21,7 @@ func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect j
err := utils.CleanUpGeneratedFiles(destDir) err := utils.CleanUpGeneratedFiles(destDir)
utils.PanicOnError(err) utils.PanicOnError(err)
tableSQLBuilderTemplate := getTableSQLBuilderTemplate(dialect)
generateSQLBuilderFiles(destDir, "table", tableSQLBuilderTemplate, schemaInfo.TablesMetaData, dialect) generateSQLBuilderFiles(destDir, "table", tableSQLBuilderTemplate, schemaInfo.TablesMetaData, dialect)
generateSQLBuilderFiles(destDir, "view", tableSQLBuilderTemplate, schemaInfo.ViewsMetaData, dialect) generateSQLBuilderFiles(destDir, "view", tableSQLBuilderTemplate, schemaInfo.ViewsMetaData, dialect)
generateSQLBuilderFiles(destDir, "enum", enumSQLBuilderTemplate, schemaInfo.EnumsMetaData, dialect) generateSQLBuilderFiles(destDir, "enum", enumSQLBuilderTemplate, schemaInfo.EnumsMetaData, dialect)
@ -33,6 +33,14 @@ func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect j
fmt.Println("Done") fmt.Println("Done")
} }
func getTableSQLBuilderTemplate(dialect jet.Dialect) string {
if dialect.Name() == "PostgreSQL" {
return tablePostgreSQLBuilderTemplate
}
return tableSQLBuilderTemplate
}
func generateSQLBuilderFiles(destDir, fileTypes, sqlBuilderTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) { func generateSQLBuilderFiles(destDir, fileTypes, sqlBuilderTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) {
if len(metaData) == 0 { if len(metaData) == 0 {
return return
@ -75,9 +83,6 @@ func GenerateTemplate(templateText string, templateData interface{}, dialect jet
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{ t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
"ToGoIdentifier": utils.ToGoIdentifier, "ToGoIdentifier": utils.ToGoIdentifier,
"ToGoEnumValueIdentifier": utils.ToGoEnumValueIdentifier, "ToGoEnumValueIdentifier": utils.ToGoEnumValueIdentifier,
"now": func() string {
return time.Now().Format(time.RFC850)
},
"dialect": func() jet.Dialect { "dialect": func() jet.Dialect {
return dialect return dialect
}, },

View file

@ -3,7 +3,6 @@ package template
var autoGenWarningTemplate = ` var autoGenWarningTemplate = `
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.
// Generated at {{now}}
// //
// WARNING: Changes to this file may cause incorrect behavior // WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated // and will be lost if the code is regenerated
@ -38,35 +37,104 @@ type {{.GoStructName}} struct {
MutableColumns {{dialect.PackageName}}.ColumnList MutableColumns {{dialect.PackageName}}.ColumnList
} }
// creates new {{.GoStructName}} with assigned alias // AS creates new {{.GoStructName}} with assigned alias
func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} { func (a *{{.GoStructName}}) AS(alias string) {{.GoStructName}} {
aliasTable := new{{.GoStructName}}() aliasTable := new{{.GoStructName}}()
aliasTable.Table.AS(alias) aliasTable.Table.AS(alias)
return aliasTable return aliasTable
} }
func new{{.GoStructName}}() *{{.GoStructName}} { func new{{.GoStructName}}() {{.GoStructName}} {
var ( var (
{{- range .Columns}} {{- range .Columns}}
{{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}") {{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
{{- end}} {{- end}}
allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }
mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }
) )
return &{{.GoStructName}}{ return {{.GoStructName}}{
Table: {{dialect.PackageName}}.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}), Table: {{dialect.PackageName}}.NewTable("{{.SchemaName}}", "{{.Name}}", allColumns...),
//Columns //Columns
{{- range .Columns}} {{- range .Columns}}
{{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column, {{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column,
{{- end}} {{- end}}
AllColumns: {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }, AllColumns: allColumns,
MutableColumns: {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }, MutableColumns: mutableColumns,
}
}
`
var tablePostgreSQLBuilderTemplate = `
{{define "column-list" -}}
{{- range $i, $c := . }}
{{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column
{{- end}}
{{- end}}
package {{param "package"}}
import (
"github.com/go-jet/jet/{{dialect.PackageName}}"
)
var {{ToGoIdentifier .Name}} = new{{.GoStructName}}()
type {{.GoStructImplName}} struct {
{{dialect.PackageName}}.Table
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}}
{{- end}}
AllColumns {{dialect.PackageName}}.ColumnList
MutableColumns {{dialect.PackageName}}.ColumnList
}
type {{.GoStructName}} struct {
{{.GoStructImplName}}
EXCLUDED {{.GoStructImplName}}
}
// AS creates new {{.GoStructName}} with assigned alias
func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} {
aliasTable := new{{.GoStructName}}()
aliasTable.Table.AS(alias)
return aliasTable
}
func new{{.GoStructName}}() *{{.GoStructName}} {
return &{{.GoStructName}}{
{{.GoStructImplName}}: new{{.GoStructName}}Impl("{{.SchemaName}}", "{{.Name}}"),
EXCLUDED: new{{.GoStructName}}Impl("", "excluded"),
} }
} }
func new{{.GoStructName}}Impl(schemaName, tableName string) {{.GoStructImplName}} {
var (
{{- range .Columns}}
{{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
{{- end}}
allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }
mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }
)
return {{.GoStructImplName}}{
Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, allColumns...),
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column,
{{- end}}
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
` `
var tableModelTemplate = `package model var tableModelTemplate = `package model

10
go.mod Normal file
View file

@ -0,0 +1,10 @@
module github.com/go-jet/jet
require (
github.com/go-sql-driver/mysql v1.5.0
github.com/google/go-cmp v0.4.1
github.com/google/uuid v1.1.1
github.com/lib/pq v1.6.0
github.com/pkg/profile v1.5.0
github.com/stretchr/testify v1.6.0
)

60
go.sum Normal file
View file

@ -0,0 +1,60 @@
github.com/alexbrainman/sspi v0.0.0-20180613141037-e580b900e9f5 h1:P5U+E4x5OkVEKQDklVPmzs71WM56RTTRqV4OrDC//Y4=
github.com/alexbrainman/sspi v0.0.0-20180613141037-e580b900e9f5/go.mod h1:976q2ETgjT2snVCf2ZaBnyBbVoPERGjUz+0sofzEfro=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/google/go-cmp v0.4.1 h1:/exdXoGamhu5ONeUJH0deniYLWYvQwW66yvlfiiKTu0=
github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.0 h1:S7P+1Hm5V/AT9cjEcUD5uDaQSX0OE577aCXgoaKpYbQ=
github.com/gorilla/sessions v1.2.0/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM=
github.com/jcmturner/gofork v1.0.0 h1:J7uCkflzTEhUZ64xqKnkDxq3kzc96ajM1Gli5ktUem8=
github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o=
github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o=
github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg=
github.com/jcmturner/gokrb5/v8 v8.2.0 h1:lzPl/30ZLkTveYsYZPKMcgXc8MbnE6RsTd4F9KgiLtk=
github.com/jcmturner/gokrb5/v8 v8.2.0/go.mod h1:T1hnNppQsBtxW0tCHMHTkAt8n/sABdzZgZdoFrZaZNM=
github.com/jcmturner/rpc/v2 v2.0.2 h1:gMB4IwRXYsWw4Bc6o/az2HJgFUA1ffSh90i26ZJ6Xl0=
github.com/jcmturner/rpc/v2 v2.0.2/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
github.com/lib/pq v1.6.0 h1:I5DPxhYJChW9KYc66se+oKFFQX6VuQrKiprsX6ivRZc=
github.com/lib/pq v1.6.0/go.mod h1:4vXEAYvW1fRQ2/FhZ78H73A60MHw1geSm145z2mdY1g=
github.com/pkg/profile v1.5.0 h1:042Buzk+NhDI+DeSAA62RwJL8VAuZUMQZUjCsRz1Mug=
github.com/pkg/profile v1.5.0/go.mod h1:qBsxPvzyUincmltOk6iyRVxHYg4adc0OFOv72ZdLa18=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.0 h1:jlIyCplCJFULU/01vCkhKuTyc3OorI3bJFuw6obfgho=
github.com/stretchr/testify v1.6.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200117160349-530e935923ad/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4 h1:QmwruyY+bKbDDL0BaglrbZABEali68eoMFhTZpCjYVA=
golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/jcmturner/aescts.v1 v1.0.1/go.mod h1:nsR8qBOg+OucoIW+WMhB3GspUQXq9XorLnQb9XtvcOo=
gopkg.in/jcmturner/dnsutils.v1 v1.0.1/go.mod h1:m3v+5svpVOhtFAP/wSz+yzh4Mc0Fg7eRhxkJMWSIz9Q=
gopkg.in/jcmturner/goidentity.v3 v3.0.0/go.mod h1:oG2kH0IvSYNIu80dVAyu/yoefjq1mNfM5bm88whjWx4=
gopkg.in/jcmturner/gokrb5.v7 v7.5.0/go.mod h1:l8VISx+WGYp+Fp7KRbsiUuXTTOnxIc3Tuvyavf11/WM=
gopkg.in/jcmturner/rpc.v1 v1.1.0/go.mod h1:YIdkC4XfD6GXbzje11McwsDuOlZQSb9W4vfLvuNnlv8=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -1,16 +1,16 @@
package snaker package snaker
import ( import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"testing" "testing"
) )
func TestSnakeToCamel(t *testing.T) { func TestSnakeToCamel(t *testing.T) {
assert.Equal(t, SnakeToCamel(""), "") require.Equal(t, SnakeToCamel(""), "")
assert.Equal(t, SnakeToCamel("potato_"), "Potato") require.Equal(t, SnakeToCamel("potato_"), "Potato")
assert.Equal(t, SnakeToCamel("this_has_to_be_uppercased"), "ThisHasToBeUppercased") require.Equal(t, SnakeToCamel("this_has_to_be_uppercased"), "ThisHasToBeUppercased")
assert.Equal(t, SnakeToCamel("this_is_an_id"), "ThisIsAnID") require.Equal(t, SnakeToCamel("this_is_an_id"), "ThisIsAnID")
assert.Equal(t, SnakeToCamel("this_is_an_identifier"), "ThisIsAnIdentifier") require.Equal(t, SnakeToCamel("this_is_an_identifier"), "ThisIsAnIdentifier")
assert.Equal(t, SnakeToCamel("id"), "ID") require.Equal(t, SnakeToCamel("id"), "ID")
assert.Equal(t, SnakeToCamel("oauth_client"), "OAuthClient") require.Equal(t, SnakeToCamel("oauth_client"), "OAuthClient")
} }

View file

@ -42,12 +42,12 @@ func (b *castExpression) serialize(statement StatementType, out *SQLBuilder, opt
castType := b.cast castType := b.cast
if castOverride := out.Dialect.OperatorSerializeOverride("CAST"); castOverride != nil { if castOverride := out.Dialect.OperatorSerializeOverride("CAST"); castOverride != nil {
castOverride(expression, String(castType))(statement, out, options...) castOverride(expression, String(castType))(statement, out, FallTrough(options)...)
return return
} }
out.WriteString("CAST(") out.WriteString("CAST(")
expression.serialize(statement, out, options...) expression.serialize(statement, out, FallTrough(options)...)
out.WriteString("AS") out.WriteString("AS")
out.WriteString(castType + ")") out.WriteString(castType + ")")
} }

View file

@ -6,28 +6,29 @@ import (
// Clause interface // Clause interface
type Clause interface { type Clause interface {
Serialize(statementType StatementType, out *SQLBuilder) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption)
} }
// ClauseWithProjections interface // ClauseWithProjections interface
type ClauseWithProjections interface { type ClauseWithProjections interface {
Clause Clause
projections() ProjectionList Projections() ProjectionList
} }
// ClauseSelect struct // ClauseSelect struct
type ClauseSelect struct { type ClauseSelect struct {
Distinct bool Distinct bool
Projections []Projection ProjectionList []Projection
} }
func (s *ClauseSelect) projections() ProjectionList { // Projections returns list of projections for select clause
return s.Projections func (s *ClauseSelect) Projections() ProjectionList {
return s.ProjectionList
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder) { func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.NewLine() out.NewLine()
out.WriteString("SELECT") out.WriteString("SELECT")
@ -35,11 +36,11 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder) {
out.WriteString("DISTINCT") out.WriteString("DISTINCT")
} }
if len(s.Projections) == 0 { if len(s.ProjectionList) == 0 {
panic("jet: SELECT clause has to have at least one projection") panic("jet: SELECT clause has to have at least one projection")
} }
out.WriteProjections(statementType, s.Projections) out.WriteProjections(statementType, s.ProjectionList)
} }
// ClauseFrom struct // ClauseFrom struct
@ -48,7 +49,7 @@ type ClauseFrom struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder) { func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if f.Table == nil { if f.Table == nil {
return return
} }
@ -56,7 +57,7 @@ func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder) {
out.WriteString("FROM") out.WriteString("FROM")
out.IncreaseIdent() out.IncreaseIdent()
f.Table.serialize(statementType, out) f.Table.serialize(statementType, out, FallTrough(options)...)
out.DecreaseIdent() out.DecreaseIdent()
} }
@ -67,18 +68,20 @@ type ClauseWhere struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (c *ClauseWhere) Serialize(statementType StatementType, out *SQLBuilder) { func (c *ClauseWhere) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if c.Condition == nil { if c.Condition == nil {
if c.Mandatory { if c.Mandatory {
panic("jet: WHERE clause not set") panic("jet: WHERE clause not set")
} }
return return
} }
if !contains(options, SkipNewLine) {
out.NewLine() out.NewLine()
}
out.WriteString("WHERE") out.WriteString("WHERE")
out.IncreaseIdent() out.IncreaseIdent()
c.Condition.serialize(statementType, out, noWrap) c.Condition.serialize(statementType, out, NoWrap.WithFallTrough(options)...)
out.DecreaseIdent() out.DecreaseIdent()
} }
@ -88,7 +91,7 @@ type ClauseGroupBy struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SQLBuilder) { func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(c.List) == 0 { if len(c.List) == 0 {
return return
} }
@ -119,7 +122,7 @@ type ClauseHaving struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder) { func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if c.Condition == nil { if c.Condition == nil {
return return
} }
@ -128,7 +131,7 @@ func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder) {
out.WriteString("HAVING") out.WriteString("HAVING")
out.IncreaseIdent() out.IncreaseIdent()
c.Condition.serialize(statementType, out, noWrap) c.Condition.serialize(statementType, out, NoWrap.WithFallTrough(options)...)
out.DecreaseIdent() out.DecreaseIdent()
} }
@ -139,7 +142,7 @@ type ClauseOrderBy struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SQLBuilder) { func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if o.List == nil { if o.List == nil {
return return
} }
@ -168,7 +171,7 @@ type ClauseLimit struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (l *ClauseLimit) Serialize(statementType StatementType, out *SQLBuilder) { func (l *ClauseLimit) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if l.Count >= 0 { if l.Count >= 0 {
out.NewLine() out.NewLine()
out.WriteString("LIMIT") out.WriteString("LIMIT")
@ -182,7 +185,7 @@ type ClauseOffset struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (o *ClauseOffset) Serialize(statementType StatementType, out *SQLBuilder) { func (o *ClauseOffset) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if o.Count >= 0 { if o.Count >= 0 {
out.NewLine() out.NewLine()
out.WriteString("OFFSET") out.WriteString("OFFSET")
@ -196,27 +199,28 @@ type ClauseFor struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder) { func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if f.Lock == nil { if f.Lock == nil {
return return
} }
out.NewLine() out.NewLine()
out.WriteString("FOR") out.WriteString("FOR")
f.Lock.serialize(statementType, out) f.Lock.serialize(statementType, out, FallTrough(options)...)
} }
// ClauseSetStmtOperator struct // ClauseSetStmtOperator struct
type ClauseSetStmtOperator struct { type ClauseSetStmtOperator struct {
Operator string Operator string
All bool All bool
Selects []StatementWithProjections Selects []SerializerStatement
OrderBy ClauseOrderBy OrderBy ClauseOrderBy
Limit ClauseLimit Limit ClauseLimit
Offset ClauseOffset Offset ClauseOffset
} }
func (s *ClauseSetStmtOperator) projections() ProjectionList { // Projections returns set of projections for ClauseSetStmtOperator
func (s *ClauseSetStmtOperator) Projections() ProjectionList {
if len(s.Selects) > 0 { if len(s.Selects) > 0 {
return s.Selects[0].projections() return s.Selects[0].projections()
} }
@ -224,7 +228,7 @@ func (s *ClauseSetStmtOperator) projections() ProjectionList {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLBuilder) { func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(s.Selects) < 2 { if len(s.Selects) < 2 {
panic("jet: UNION Statement must contain at least two SELECT statements") panic("jet: UNION Statement must contain at least two SELECT statements")
} }
@ -244,7 +248,7 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLB
panic("jet: select statement of '" + s.Operator + "' is nil") panic("jet: select statement of '" + s.Operator + "' is nil")
} }
selectStmt.serialize(statementType, out) selectStmt.serialize(statementType, out, FallTrough(options)...)
} }
s.OrderBy.Serialize(statementType, out) s.OrderBy.Serialize(statementType, out)
@ -258,7 +262,7 @@ type ClauseUpdate struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder) { func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.NewLine() out.NewLine()
out.WriteString("UPDATE") out.WriteString("UPDATE")
@ -266,17 +270,20 @@ func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder) {
panic("jet: table to update is nil") panic("jet: table to update is nil")
} }
u.Table.serialize(statementType, out) u.Table.serialize(statementType, out, FallTrough(options)...)
} }
// ClauseSet struct // SetClause struct
type ClauseSet struct { type SetClause struct {
Columns []Column Columns []Column
Values []Serializer Values []Serializer
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder) { func (s *SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(s.Values) == 0 {
return
}
out.NewLine() out.NewLine()
out.WriteString("SET") out.WriteString("SET")
@ -299,7 +306,7 @@ func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder) {
out.WriteString(" = ") out.WriteString(" = ")
s.Values[i].serialize(UpdateStatementType, out) s.Values[i].serialize(UpdateStatementType, out, FallTrough(options)...)
} }
out.DecreaseIdent(4) out.DecreaseIdent(4)
} }
@ -320,7 +327,7 @@ func (i *ClauseInsert) GetColumns() []Column {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (i *ClauseInsert) Serialize(statementType StatementType, out *SQLBuilder) { func (i *ClauseInsert) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.NewLine() out.NewLine()
out.WriteString("INSERT INTO") out.WriteString("INSERT INTO")
@ -346,7 +353,7 @@ type ClauseValuesQuery struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuilder) { func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(v.Rows) == 0 && v.Query == nil { if len(v.Rows) == 0 && v.Query == nil {
panic("jet: VALUES or QUERY has to be specified for INSERT statement") panic("jet: VALUES or QUERY has to be specified for INSERT statement")
} }
@ -355,8 +362,8 @@ func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuild
panic("jet: VALUES or QUERY has to be specified for INSERT statement") panic("jet: VALUES or QUERY has to be specified for INSERT statement")
} }
v.ClauseValues.Serialize(statementType, out) v.ClauseValues.Serialize(statementType, out, FallTrough(options)...)
v.ClauseQuery.Serialize(statementType, out) v.ClauseQuery.Serialize(statementType, out, FallTrough(options)...)
} }
// ClauseValues struct // ClauseValues struct
@ -365,27 +372,29 @@ type ClauseValues struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (v *ClauseValues) Serialize(statementType StatementType, out *SQLBuilder) { func (v *ClauseValues) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(v.Rows) == 0 { if len(v.Rows) == 0 {
return return
} }
out.NewLine()
out.WriteString("VALUES") out.WriteString("VALUES")
for rowIndex, row := range v.Rows { for rowIndex, row := range v.Rows {
if rowIndex > 0 { if rowIndex > 0 {
out.WriteString(",") out.WriteString(",")
out.NewLine()
} else {
out.IncreaseIdent(7)
} }
out.IncreaseIdent()
out.NewLine()
out.WriteString("(") out.WriteString("(")
SerializeClauseList(statementType, row, out) SerializeClauseList(statementType, row, out)
out.WriteByte(')') out.WriteByte(')')
out.DecreaseIdent()
} }
out.DecreaseIdent(7)
} }
// ClauseQuery struct // ClauseQuery struct
@ -394,12 +403,12 @@ type ClauseQuery struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (v *ClauseQuery) Serialize(statementType StatementType, out *SQLBuilder) { func (v *ClauseQuery) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if v.Query == nil { if v.Query == nil {
return return
} }
v.Query.serialize(statementType, out) v.Query.serialize(statementType, out, FallTrough(options)...)
} }
// ClauseDelete struct // ClauseDelete struct
@ -408,7 +417,7 @@ type ClauseDelete struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (d *ClauseDelete) Serialize(statementType StatementType, out *SQLBuilder) { func (d *ClauseDelete) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.NewLine() out.NewLine()
out.WriteString("DELETE FROM") out.WriteString("DELETE FROM")
@ -416,7 +425,7 @@ func (d *ClauseDelete) Serialize(statementType StatementType, out *SQLBuilder) {
panic("jet: nil table in DELETE clause") panic("jet: nil table in DELETE clause")
} }
d.Table.serialize(statementType, out) d.Table.serialize(statementType, out, FallTrough(options)...)
} }
// ClauseStatementBegin struct // ClauseStatementBegin struct
@ -426,7 +435,7 @@ type ClauseStatementBegin struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SQLBuilder) { func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.NewLine() out.NewLine()
out.WriteString(d.Name) out.WriteString(d.Name)
@ -435,7 +444,7 @@ func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SQLBu
out.WriteString(", ") out.WriteString(", ")
} }
table.serialize(statementType, out) table.serialize(statementType, out, FallTrough(options)...)
} }
} }
@ -447,7 +456,7 @@ type ClauseOptional struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (d *ClauseOptional) Serialize(statementType StatementType, out *SQLBuilder) { func (d *ClauseOptional) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if !d.Show { if !d.Show {
return return
} }
@ -463,7 +472,7 @@ type ClauseIn struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (i *ClauseIn) Serialize(statementType StatementType, out *SQLBuilder) { func (i *ClauseIn) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if i.LockMode == "" { if i.LockMode == "" {
return return
} }
@ -485,7 +494,7 @@ type ClauseWindow struct {
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder) { func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(i.Definitions) == 0 { if len(i.Definitions) == 0 {
return return
} }
@ -503,6 +512,46 @@ func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder) {
out.WriteString("()") out.WriteString("()")
continue continue
} }
def.Window.serialize(statementType, out) def.Window.serialize(statementType, out, FallTrough(options)...)
} }
} }
// SetPair clause
type SetPair struct {
Column ColumnSerializer
Value Serializer
}
// SetClauseNew clause
type SetClauseNew []ColumnAssigment
// Serialize for SetClauseNew
func (s SetClauseNew) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(s) == 0 {
return
}
out.NewLine()
out.WriteString("SET")
out.IncreaseIdent(4)
for i, assigment := range s {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
assigment.serialize(statementType, out, FallTrough(options)...)
}
out.DecreaseIdent(4)
}
// KeywordClause type
type KeywordClause struct {
Keyword
}
// Serialize for KeywordClause
func (k KeywordClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
k.serialize(statementType, out, FallTrough(options)...)
}

View file

@ -1,14 +1,14 @@
package jet package jet
import ( import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"testing" "testing"
) )
func TestClauseSelect_Serialize(t *testing.T) { func TestClauseSelect_Serialize(t *testing.T) {
defer func() { defer func() {
r := recover() r := recover()
assert.Equal(t, r, "jet: SELECT clause has to have at least one projection") require.Equal(t, r, "jet: SELECT clause has to have at least one projection")
}() }()
selectClause := &ClauseSelect{} selectClause := &ClauseSelect{}

View file

@ -12,6 +12,12 @@ type Column interface {
defaultAlias() string defaultAlias() string
} }
// ColumnSerializer is interface for all serializable columns
type ColumnSerializer interface {
Serializer
Column
}
// ColumnExpression interface // ColumnExpression interface
type ColumnExpression interface { type ColumnExpression interface {
Column Column
@ -99,9 +105,9 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder
if c.subQuery != nil { if c.subQuery != nil {
out.WriteIdentifier(c.subQuery.Alias()) out.WriteIdentifier(c.subQuery.Alias())
out.WriteByte('.') out.WriteByte('.')
out.WriteIdentifier(c.defaultAlias(), true) out.WriteIdentifier(c.defaultAlias())
} else { } else {
if c.tableName != "" { if c.tableName != "" && !contains(options, ShortName) {
out.WriteIdentifier(c.tableName) out.WriteIdentifier(c.tableName)
out.WriteByte('.') out.WriteByte('.')
} }
@ -109,45 +115,3 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder
out.WriteIdentifier(c.name) out.WriteIdentifier(c.name)
} }
} }
//------------------------------------------------------//
// ColumnList is a helper type to support list of columns as single projection
type ColumnList []ColumnExpression
func (cl ColumnList) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.fromImpl(subQuery))
}
return newProjectionList
}
func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) {
projections := ColumnListToProjectionList(cl)
SerializeProjectionList(statement, projections, out)
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl ColumnList) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
func (cl ColumnList) defaultAlias() string { return "" }
// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName
func SetTableName(columnExpression ColumnExpression, tableName string) {
columnExpression.setTableName(tableName)
}
// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery
func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) {
columnExpression.setSubQuery(subQuery)
}

View file

@ -0,0 +1,20 @@
package jet
// ColumnAssigment is interface wrapper around column assigment
type ColumnAssigment interface {
Serializer
isColumnAssigment()
}
type columnAssigmentImpl struct {
column ColumnSerializer
expression Expression
}
func (a columnAssigmentImpl) isColumnAssigment() {}
func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
a.column.serialize(statement, out, ShortName.WithFallTrough(options)...)
out.WriteString("=")
a.expression.serialize(statement, out, FallTrough(options)...)
}

View file

@ -0,0 +1,60 @@
package jet
// ColumnList is a helper type to support list of columns as single projection
type ColumnList []ColumnExpression
// SET creates column assigment for each column in column list. expression should be created by ROW function
func (cl ColumnList) SET(expression Expression) ColumnAssigment {
return columnAssigmentImpl{
column: cl,
expression: expression,
}
}
func (cl ColumnList) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.fromImpl(subQuery))
}
return newProjectionList
}
func (cl ColumnList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")
for i, column := range cl {
if i > 0 {
out.WriteString(", ")
}
column.serialize(statement, out, FallTrough(options)...)
}
out.WriteString(")")
}
func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) {
projections := ColumnListToProjectionList(cl)
SerializeProjectionList(statement, projections, out)
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl ColumnList) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
func (cl ColumnList) defaultAlias() string { return "" }
// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName
func SetTableName(columnExpression ColumnExpression, tableName string) {
columnExpression.setTableName(tableName)
}
// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery
func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) {
columnExpression.setSubQuery(subQuery)
}

View file

@ -6,6 +6,7 @@ type ColumnBool interface {
Column Column
From(subQuery SelectTable) ColumnBool From(subQuery SelectTable) ColumnBool
SET(boolExp BoolExpression) ColumnAssigment
} }
type boolColumnImpl struct { type boolColumnImpl struct {
@ -21,6 +22,13 @@ func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
return newBoolColumn return newBoolColumn
} }
func (i *boolColumnImpl) SET(boolExp BoolExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: boolExp,
}
}
// BoolColumn creates named bool column. // BoolColumn creates named bool column.
func BoolColumn(name string) ColumnBool { func BoolColumn(name string) ColumnBool {
boolColumn := &boolColumnImpl{} boolColumn := &boolColumnImpl{}
@ -38,6 +46,7 @@ type ColumnFloat interface {
Column Column
From(subQuery SelectTable) ColumnFloat From(subQuery SelectTable) ColumnFloat
SET(floatExp FloatExpression) ColumnAssigment
} }
type floatColumnImpl struct { type floatColumnImpl struct {
@ -53,6 +62,13 @@ func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
return newFloatColumn return newFloatColumn
} }
func (i *floatColumnImpl) SET(floatExp FloatExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: floatExp,
}
}
// FloatColumn creates named float column. // FloatColumn creates named float column.
func FloatColumn(name string) ColumnFloat { func FloatColumn(name string) ColumnFloat {
floatColumn := &floatColumnImpl{} floatColumn := &floatColumnImpl{}
@ -70,6 +86,7 @@ type ColumnInteger interface {
Column Column
From(subQuery SelectTable) ColumnInteger From(subQuery SelectTable) ColumnInteger
SET(intExp IntegerExpression) ColumnAssigment
} }
type integerColumnImpl struct { type integerColumnImpl struct {
@ -86,6 +103,13 @@ func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
return newIntColumn return newIntColumn
} }
func (i *integerColumnImpl) SET(intExp IntegerExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: intExp,
}
}
// IntegerColumn creates named integer column. // IntegerColumn creates named integer column.
func IntegerColumn(name string) ColumnInteger { func IntegerColumn(name string) ColumnInteger {
integerColumn := &integerColumnImpl{} integerColumn := &integerColumnImpl{}
@ -104,6 +128,7 @@ type ColumnString interface {
Column Column
From(subQuery SelectTable) ColumnString From(subQuery SelectTable) ColumnString
SET(stringExp StringExpression) ColumnAssigment
} }
type stringColumnImpl struct { type stringColumnImpl struct {
@ -120,6 +145,13 @@ func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
return newStrColumn return newStrColumn
} }
func (i *stringColumnImpl) SET(stringExp StringExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: stringExp,
}
}
// StringColumn creates named string column. // StringColumn creates named string column.
func StringColumn(name string) ColumnString { func StringColumn(name string) ColumnString {
stringColumn := &stringColumnImpl{} stringColumn := &stringColumnImpl{}
@ -137,6 +169,7 @@ type ColumnTime interface {
Column Column
From(subQuery SelectTable) ColumnTime From(subQuery SelectTable) ColumnTime
SET(timeExp TimeExpression) ColumnAssigment
} }
type timeColumnImpl struct { type timeColumnImpl struct {
@ -152,6 +185,13 @@ func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
return newTimeColumn return newTimeColumn
} }
func (i *timeColumnImpl) SET(timeExp TimeExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timeExp,
}
}
// TimeColumn creates named time column // TimeColumn creates named time column
func TimeColumn(name string) ColumnTime { func TimeColumn(name string) ColumnTime {
timeColumn := &timeColumnImpl{} timeColumn := &timeColumnImpl{}
@ -183,6 +223,13 @@ func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
return newTimezColumn return newTimezColumn
} }
func (i *timezColumnImpl) SET(timezExp TimezExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timezExp,
}
}
// TimezColumn creates named time with time zone column. // TimezColumn creates named time with time zone column.
func TimezColumn(name string) ColumnTimez { func TimezColumn(name string) ColumnTimez {
timezColumn := &timezColumnImpl{} timezColumn := &timezColumnImpl{}
@ -200,6 +247,7 @@ type ColumnTimestamp interface {
Column Column
From(subQuery SelectTable) ColumnTimestamp From(subQuery SelectTable) ColumnTimestamp
SET(timestampExp TimestampExpression) ColumnAssigment
} }
type timestampColumnImpl struct { type timestampColumnImpl struct {
@ -215,6 +263,13 @@ func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
return newTimestampColumn return newTimestampColumn
} }
func (i *timestampColumnImpl) SET(timestampExp TimestampExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timestampExp,
}
}
// TimestampColumn creates named timestamp column // TimestampColumn creates named timestamp column
func TimestampColumn(name string) ColumnTimestamp { func TimestampColumn(name string) ColumnTimestamp {
timestampColumn := &timestampColumnImpl{} timestampColumn := &timestampColumnImpl{}
@ -232,6 +287,7 @@ type ColumnTimestampz interface {
Column Column
From(subQuery SelectTable) ColumnTimestampz From(subQuery SelectTable) ColumnTimestampz
SET(timestampzExp TimestampzExpression) ColumnAssigment
} }
type timestampzColumnImpl struct { type timestampzColumnImpl struct {
@ -247,6 +303,13 @@ func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
return newTimestampzColumn return newTimestampzColumn
} }
func (i *timestampzColumnImpl) SET(timestampzExp TimestampzExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timestampzExp,
}
}
// TimestampzColumn creates named timestamp with time zone column. // TimestampzColumn creates named timestamp with time zone column.
func TimestampzColumn(name string) ColumnTimestampz { func TimestampzColumn(name string) ColumnTimestampz {
timestampzColumn := &timestampzColumnImpl{} timestampzColumn := &timestampzColumnImpl{}
@ -264,6 +327,7 @@ type ColumnDate interface {
Column Column
From(subQuery SelectTable) ColumnDate From(subQuery SelectTable) ColumnDate
SET(dateExp DateExpression) ColumnAssigment
} }
type dateColumnImpl struct { type dateColumnImpl struct {
@ -279,6 +343,13 @@ func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
return newDateColumn return newDateColumn
} }
func (i *dateColumnImpl) SET(dateExp DateExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: dateExp,
}
}
// DateColumn creates named date column. // DateColumn creates named date column.
func DateColumn(name string) ColumnDate { func DateColumn(name string) ColumnDate {
dateColumn := &dateColumnImpl{} dateColumn := &dateColumnImpl{}

View file

@ -22,9 +22,9 @@ func TestNewBoolColumn(t *testing.T) {
func TestNewIntColumn(t *testing.T) { func TestNewIntColumn(t *testing.T) {
intColumn := IntegerColumn("col_int").From(subQuery) intColumn := IntegerColumn("col_int").From(subQuery)
assertClauseSerialize(t, intColumn, `sub_query."col_int"`) assertClauseSerialize(t, intColumn, `sub_query.col_int`)
assertClauseSerialize(t, intColumn.EQ(Int(12)), `(sub_query."col_int" = $1)`, int64(12)) assertClauseSerialize(t, intColumn.EQ(Int(12)), `(sub_query.col_int = $1)`, int64(12))
assertProjectionSerialize(t, intColumn, `sub_query."col_int" AS "col_int"`) assertProjectionSerialize(t, intColumn, `sub_query.col_int AS "col_int"`)
intColumn2 := table1ColInt.From(subQuery) intColumn2 := table1ColInt.From(subQuery)
assertClauseSerialize(t, intColumn2, `sub_query."table1.col_int"`) assertClauseSerialize(t, intColumn2, `sub_query."table1.col_int"`)
@ -35,9 +35,9 @@ func TestNewIntColumn(t *testing.T) {
func TestNewFloatColumnColumn(t *testing.T) { func TestNewFloatColumnColumn(t *testing.T) {
floatColumn := FloatColumn("col_float").From(subQuery) floatColumn := FloatColumn("col_float").From(subQuery)
assertClauseSerialize(t, floatColumn, `sub_query."col_float"`) assertClauseSerialize(t, floatColumn, `sub_query.col_float`)
assertClauseSerialize(t, floatColumn.EQ(Float(1.11)), `(sub_query."col_float" = $1)`, float64(1.11)) assertClauseSerialize(t, floatColumn.EQ(Float(1.11)), `(sub_query.col_float = $1)`, float64(1.11))
assertProjectionSerialize(t, floatColumn, `sub_query."col_float" AS "col_float"`) assertProjectionSerialize(t, floatColumn, `sub_query.col_float AS "col_float"`)
floatColumn2 := table1ColFloat.From(subQuery) floatColumn2 := table1ColFloat.From(subQuery)
assertClauseSerialize(t, floatColumn2, `sub_query."table1.col_float"`) assertClauseSerialize(t, floatColumn2, `sub_query."table1.col_float"`)
@ -47,10 +47,10 @@ func TestNewFloatColumnColumn(t *testing.T) {
func TestNewDateColumnColumn(t *testing.T) { func TestNewDateColumnColumn(t *testing.T) {
dateColumn := DateColumn("col_date").From(subQuery) dateColumn := DateColumn("col_date").From(subQuery)
assertClauseSerialize(t, dateColumn, `sub_query."col_date"`) assertClauseSerialize(t, dateColumn, `sub_query.col_date`)
assertClauseSerialize(t, dateColumn.EQ(Date(2002, 2, 3)), assertClauseSerialize(t, dateColumn.EQ(Date(2002, 2, 3)),
`(sub_query."col_date" = $1)`, "2002-02-03") `(sub_query.col_date = $1)`, "2002-02-03")
assertProjectionSerialize(t, dateColumn, `sub_query."col_date" AS "col_date"`) assertProjectionSerialize(t, dateColumn, `sub_query.col_date AS "col_date"`)
dateColumn2 := table1ColDate.From(subQuery) dateColumn2 := table1ColDate.From(subQuery)
assertClauseSerialize(t, dateColumn2, `sub_query."table1.col_date"`) assertClauseSerialize(t, dateColumn2, `sub_query."table1.col_date"`)
@ -61,10 +61,10 @@ func TestNewDateColumnColumn(t *testing.T) {
func TestNewTimeColumnColumn(t *testing.T) { func TestNewTimeColumnColumn(t *testing.T) {
timeColumn := TimeColumn("col_time").From(subQuery) timeColumn := TimeColumn("col_time").From(subQuery)
assertClauseSerialize(t, timeColumn, `sub_query."col_time"`) assertClauseSerialize(t, timeColumn, `sub_query.col_time`)
assertClauseSerialize(t, timeColumn.EQ(Time(1, 1, 1, 1)), assertClauseSerialize(t, timeColumn.EQ(Time(1, 1, 1, 1)),
`(sub_query."col_time" = $1)`, "01:01:01.000000001") `(sub_query.col_time = $1)`, "01:01:01.000000001")
assertProjectionSerialize(t, timeColumn, `sub_query."col_time" AS "col_time"`) assertProjectionSerialize(t, timeColumn, `sub_query.col_time AS "col_time"`)
timeColumn2 := table1ColTime.From(subQuery) timeColumn2 := table1ColTime.From(subQuery)
assertClauseSerialize(t, timeColumn2, `sub_query."table1.col_time"`) assertClauseSerialize(t, timeColumn2, `sub_query."table1.col_time"`)
@ -75,10 +75,10 @@ func TestNewTimeColumnColumn(t *testing.T) {
func TestNewTimezColumnColumn(t *testing.T) { func TestNewTimezColumnColumn(t *testing.T) {
timezColumn := TimezColumn("col_timez").From(subQuery) timezColumn := TimezColumn("col_timez").From(subQuery)
assertClauseSerialize(t, timezColumn, `sub_query."col_timez"`) assertClauseSerialize(t, timezColumn, `sub_query.col_timez`)
assertClauseSerialize(t, timezColumn.EQ(Timez(1, 1, 1, 1, "UTC")), assertClauseSerialize(t, timezColumn.EQ(Timez(1, 1, 1, 1, "UTC")),
`(sub_query."col_timez" = $1)`, "01:01:01.000000001 UTC") `(sub_query.col_timez = $1)`, "01:01:01.000000001 UTC")
assertProjectionSerialize(t, timezColumn, `sub_query."col_timez" AS "col_timez"`) assertProjectionSerialize(t, timezColumn, `sub_query.col_timez AS "col_timez"`)
timezColumn2 := table1ColTimez.From(subQuery) timezColumn2 := table1ColTimez.From(subQuery)
assertClauseSerialize(t, timezColumn2, `sub_query."table1.col_timez"`) assertClauseSerialize(t, timezColumn2, `sub_query."table1.col_timez"`)
@ -89,10 +89,10 @@ func TestNewTimezColumnColumn(t *testing.T) {
func TestNewTimestampColumnColumn(t *testing.T) { func TestNewTimestampColumnColumn(t *testing.T) {
timestampColumn := TimestampColumn("col_timestamp").From(subQuery) timestampColumn := TimestampColumn("col_timestamp").From(subQuery)
assertClauseSerialize(t, timestampColumn, `sub_query."col_timestamp"`) assertClauseSerialize(t, timestampColumn, `sub_query.col_timestamp`)
assertClauseSerialize(t, timestampColumn.EQ(Timestamp(1, 1, 1, 1, 1, 1)), assertClauseSerialize(t, timestampColumn.EQ(Timestamp(1, 1, 1, 1, 1, 1)),
`(sub_query."col_timestamp" = $1)`, "0001-01-01 01:01:01") `(sub_query.col_timestamp = $1)`, "0001-01-01 01:01:01")
assertProjectionSerialize(t, timestampColumn, `sub_query."col_timestamp" AS "col_timestamp"`) assertProjectionSerialize(t, timestampColumn, `sub_query.col_timestamp AS "col_timestamp"`)
timestampColumn2 := table1ColTimestamp.From(subQuery) timestampColumn2 := table1ColTimestamp.From(subQuery)
assertClauseSerialize(t, timestampColumn2, `sub_query."table1.col_timestamp"`) assertClauseSerialize(t, timestampColumn2, `sub_query."table1.col_timestamp"`)
@ -103,10 +103,10 @@ func TestNewTimestampColumnColumn(t *testing.T) {
func TestNewTimestampzColumnColumn(t *testing.T) { func TestNewTimestampzColumnColumn(t *testing.T) {
timestampzColumn := TimestampzColumn("col_timestampz").From(subQuery) timestampzColumn := TimestampzColumn("col_timestampz").From(subQuery)
assertClauseSerialize(t, timestampzColumn, `sub_query."col_timestampz"`) assertClauseSerialize(t, timestampzColumn, `sub_query.col_timestampz`)
assertClauseSerialize(t, timestampzColumn.EQ(Timestampz(1, 1, 1, 1, 1, 1, 0, "UTC")), assertClauseSerialize(t, timestampzColumn.EQ(Timestampz(1, 1, 1, 1, 1, 1, 0, "UTC")),
`(sub_query."col_timestampz" = $1)`, "0001-01-01 01:01:01 UTC") `(sub_query.col_timestampz = $1)`, "0001-01-01 01:01:01 UTC")
assertProjectionSerialize(t, timestampzColumn, `sub_query."col_timestampz" AS "col_timestampz"`) assertProjectionSerialize(t, timestampzColumn, `sub_query.col_timestampz AS "col_timestampz"`)
timestampzColumn2 := table1ColTimestampz.From(subQuery) timestampzColumn2 := table1ColTimestampz.From(subQuery)
assertClauseSerialize(t, timestampzColumn2, `sub_query."table1.col_timestampz"`) assertClauseSerialize(t, timestampzColumn2, `sub_query."table1.col_timestampz"`)

View file

@ -72,15 +72,15 @@ func (e *ExpressionInterfaceImpl) DESC() OrderByClause {
} }
func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SQLBuilder) { func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SQLBuilder) {
e.Parent.serialize(statement, out, noWrap) e.Parent.serialize(statement, out, NoWrap)
} }
func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType, out *SQLBuilder) { func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType, out *SQLBuilder) {
e.Parent.serialize(statement, out, noWrap) e.Parent.serialize(statement, out, NoWrap)
} }
func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) {
e.Parent.serialize(statement, out, noWrap) e.Parent.serialize(statement, out, NoWrap)
} }
// Representation of binary operations (e.g. comparisons, arithmetic) // Representation of binary operations (e.g. comparisons, arithmetic)
@ -117,7 +117,7 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu
panic("jet: rhs is nil for '" + c.operator + "' operator") panic("jet: rhs is nil for '" + c.operator + "' operator")
} }
wrap := !contains(options, noWrap) wrap := !contains(options, NoWrap)
if wrap { if wrap {
out.WriteString("(") out.WriteString("(")
@ -125,11 +125,11 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu
if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil { if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam) serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam)
serializeOverrideFunc(statement, out, options...) serializeOverrideFunc(statement, out, FallTrough(options)...)
} else { } else {
c.lhs.serialize(statement, out) c.lhs.serialize(statement, out, FallTrough(options)...)
out.WriteString(c.operator) out.WriteString(c.operator)
c.rhs.serialize(statement, out) c.rhs.serialize(statement, out, FallTrough(options)...)
} }
if wrap { if wrap {
@ -163,7 +163,7 @@ func (p *prefixExpression) serialize(statement StatementType, out *SQLBuilder, o
panic("jet: nil prefix expression in prefix operator " + p.operator) panic("jet: nil prefix expression in prefix operator " + p.operator)
} }
p.expression.serialize(statement, out) p.expression.serialize(statement, out, FallTrough(options)...)
out.WriteString(")") out.WriteString(")")
} }
@ -192,7 +192,7 @@ func (p *postfixOpExpression) serialize(statement StatementType, out *SQLBuilder
panic("jet: nil prefix expression in postfix operator " + p.operator) panic("jet: nil prefix expression in postfix operator " + p.operator)
} }
p.expression.serialize(statement, out) p.expression.serialize(statement, out, FallTrough(options)...)
out.WriteString(p.operator) out.WriteString(p.operator)
} }

View file

@ -145,6 +145,11 @@ func MINi(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerWindowFunc("MIN", integerExpression) return newIntegerWindowFunc("MIN", integerExpression)
} }
// SUM is aggregate function. Returns sum of all expressions
func SUM(expression Expression) Expression {
return newWindowFunc("SUM", expression)
}
// SUMf is aggregate function. Returns sum of expression across all float expressions // SUMf is aggregate function. Returns sum of expression across all float expressions
func SUMf(floatExpression FloatExpression) floatWindowExpression { func SUMf(floatExpression FloatExpression) floatWindowExpression {
return NewFloatWindowFunc("SUM", floatExpression) return NewFloatWindowFunc("SUM", floatExpression)
@ -613,7 +618,7 @@ func newWindowFunc(name string, expressions ...Expression) windowExpression {
func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil { if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(ExpressionListToSerializerList(f.expressions)...) serializeOverrideFunc := serializeOverride(ExpressionListToSerializerList(f.expressions)...)
serializeOverrideFunc(statement, out, options...) serializeOverrideFunc(statement, out, FallTrough(options)...)
return return
} }

View file

@ -33,5 +33,5 @@ type IntervalImpl struct {
func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("INTERVAL") out.WriteString("INTERVAL")
i.interval.serialize(statement, out, options...) i.interval.serialize(statement, out, FallTrough(options)...)
} }

View file

@ -2,18 +2,12 @@ package jet
const ( const (
// DEFAULT is jet equivalent of SQL DEFAULT // DEFAULT is jet equivalent of SQL DEFAULT
DEFAULT keywordClause = "DEFAULT" DEFAULT Keyword = "DEFAULT"
) )
var ( // Keyword type
// NULL is jet equivalent of SQL NULL type Keyword string
NULL = newNullLiteral()
// STAR is jet equivalent of SQL *
STAR = newStarLiteral()
)
type keywordClause string func (k Keyword) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
func (k keywordClause) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString(string(k)) out.WriteString(string(k))
} }

View file

@ -278,6 +278,14 @@ func formatNanoseconds(nanoseconds ...time.Duration) string {
} }
//--------------------------------------------------// //--------------------------------------------------//
var (
// NULL is jet equivalent of SQL NULL
NULL = newNullLiteral()
// STAR is jet equivalent of SQL *
STAR = newStarLiteral()
)
type nullLiteral struct { type nullLiteral struct {
ExpressionInterfaceImpl ExpressionInterfaceImpl
} }

19
internal/jet/logger.go Normal file
View file

@ -0,0 +1,19 @@
package jet
import "context"
// PrintableStatement is a statement which sql query can be logged
type PrintableStatement interface {
Sql() (query string, args []interface{})
DebugSql() (query string)
}
// LoggerFunc is a definition of a function user can implement to support automatic statement logging.
type LoggerFunc func(ctx context.Context, statement PrintableStatement)
var logger LoggerFunc
// SetLoggerFunc sets automatic statement logging
func SetLoggerFunc(loggerFunc LoggerFunc) {
logger = loggerFunc
}

View file

@ -147,7 +147,7 @@ func (c *caseOperatorImpl) serialize(statement StatementType, out *SQLBuilder, o
out.WriteString("(CASE") out.WriteString("(CASE")
if c.expression != nil { if c.expression != nil {
c.expression.serialize(statement, out) c.expression.serialize(statement, out, FallTrough(options)...)
} }
if len(c.when) == 0 || len(c.then) == 0 { if len(c.when) == 0 || len(c.then) == 0 {
@ -160,15 +160,15 @@ func (c *caseOperatorImpl) serialize(statement StatementType, out *SQLBuilder, o
for i, when := range c.when { for i, when := range c.when {
out.WriteString("WHEN") out.WriteString("WHEN")
when.serialize(statement, out, noWrap) when.serialize(statement, out, NoWrap)
out.WriteString("THEN") out.WriteString("THEN")
c.then[i].serialize(statement, out, noWrap) c.then[i].serialize(statement, out, NoWrap)
} }
if c.els != nil { if c.els != nil {
out.WriteString("ELSE") out.WriteString("ELSE")
c.els.serialize(statement, out, noWrap) c.els.serialize(statement, out, NoWrap)
} }
out.WriteString("END)") out.WriteString("END)")

View file

@ -8,35 +8,31 @@ type SelectTable interface {
} }
type selectTableImpl struct { type selectTableImpl struct {
selectStmt StatementWithProjections selectStmt SerializerStatement
alias string alias string
projections ProjectionList
} }
// NewSelectTable func // NewSelectTable func
func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTable { func NewSelectTable(selectStmt SerializerStatement, alias string) SelectTable {
selectTable := selectTableImpl{selectStmt: selectStmt, alias: alias} selectTable := &selectTableImpl{selectStmt: selectStmt, alias: alias}
return selectTable
projectionList := selectStmt.projections().fromImpl(&selectTable)
selectTable.projections = projectionList.(ProjectionList)
return &selectTable
} }
func (s *selectTableImpl) Alias() string { func (s selectTableImpl) Alias() string {
return s.alias return s.alias
} }
func (s *selectTableImpl) AllColumns() ProjectionList { func (s selectTableImpl) AllColumns() ProjectionList {
return s.projections statementWithProjections, ok := s.selectStmt.(HasProjections)
if !ok {
return ProjectionList{}
} }
func (s *selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { projectionList := statementWithProjections.projections().fromImpl(s)
if s == nil { return projectionList.(ProjectionList)
panic("jet: expression table is nil. ")
} }
func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
s.selectStmt.serialize(statement, out) s.selectStmt.serialize(statement, out)
out.WriteString("AS") out.WriteString("AS")

View file

@ -5,9 +5,18 @@ type SerializeOption int
// Serialize options // Serialize options
const ( const (
noWrap SerializeOption = iota NoWrap SerializeOption = iota
SkipNewLine
fallTroughOptions // fall trough options
ShortName
) )
// WithFallTrough extends existing serialize options with additional
func (s SerializeOption) WithFallTrough(options []SerializeOption) []SerializeOption {
return append(FallTrough(options), s)
}
// StatementType is type of the SQL statement // StatementType is type of the SQL statement
type StatementType string type StatementType string
@ -20,6 +29,7 @@ const (
SetStatementType StatementType = "SET" SetStatementType StatementType = "SET"
LockStatementType StatementType = "LOCK" LockStatementType StatementType = "LOCK"
UnLockStatementType StatementType = "UNLOCK" UnLockStatementType StatementType = "UNLOCK"
WithStatementType StatementType = "WITH"
) )
// Serializer interface // Serializer interface
@ -42,6 +52,19 @@ func contains(options []SerializeOption, option SerializeOption) bool {
return false return false
} }
// FallTrough filters fall-trough options from the list
func FallTrough(options []SerializeOption) []SerializeOption {
var ret []SerializeOption
for _, option := range options {
if option > fallTroughOptions {
ret = append(ret, option)
}
}
return ret
}
// ListSerializer serializes list of serializers with separator // ListSerializer serializes list of serializers with separator
type ListSerializer struct { type ListSerializer struct {
Serializers []Serializer Serializers []Serializer
@ -53,6 +76,21 @@ func (s ListSerializer) serialize(statement StatementType, out *SQLBuilder, opti
if i > 0 { if i > 0 {
out.WriteString(s.Separator) out.WriteString(s.Separator)
} }
ser.serialize(statement, out) ser.serialize(statement, out, FallTrough(options)...)
}
}
// NewSerializerClauseImpl is constructor for Seralizer with list of clauses
func NewSerializerClauseImpl(clauses ...Clause) Serializer {
return &serializerImpl{Clauses: clauses}
}
type serializerImpl struct {
Clauses []Clause
}
func (s serializerImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
for _, clause := range s.Clauses {
clause.Serialize(statement, out, FallTrough(options)...)
} }
} }

View file

@ -98,7 +98,7 @@ func (s *SQLBuilder) WriteString(str string) {
// WriteIdentifier adds identifier to output SQL // WriteIdentifier adds identifier to output SQL
func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) { func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
if s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 { if s.shouldQuote(name, alwaysQuote...) {
identQuoteChar := string(s.Dialect.IdentifierQuoteChar()) identQuoteChar := string(s.Dialect.IdentifierQuoteChar())
s.WriteString(identQuoteChar + name + identQuoteChar) s.WriteString(identQuoteChar + name + identQuoteChar)
} else { } else {
@ -106,6 +106,10 @@ func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
} }
} }
func (s *SQLBuilder) shouldQuote(name string, alwaysQuote ...bool) bool {
return s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0
}
// WriteByte writes byte to output SQL // WriteByte writes byte to output SQL
func (s *SQLBuilder) WriteByte(b byte) { func (s *SQLBuilder) WriteByte(b byte) {
s.write([]byte{b}) s.write([]byte{b})
@ -159,10 +163,17 @@ func argToString(value interface{}) string {
case time.Time: case time.Time:
return stringQuote(string(pq.FormatTimestamp(bindVal))) return stringQuote(string(pq.FormatTimestamp(bindVal)))
default: default:
if strBindValue, ok := bindVal.(toStringInterface); ok {
return stringQuote(strBindValue.String())
}
panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String())) panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String()))
} }
} }
type toStringInterface interface {
String() string
}
func integerTypesToString(value interface{}) string { func integerTypesToString(value interface{}) string {
switch bindVal := value.(type) { switch bindVal := value.(type) {
case int: case int:
@ -190,6 +201,13 @@ func integerTypesToString(value interface{}) string {
} }
func shouldQuoteIdentifier(identifier string) bool { func shouldQuoteIdentifier(identifier string) bool {
_, err := strconv.ParseInt(identifier, 10, 64)
if err == nil { // if it is a number we should quote it
return true
}
// check if contains non ascii characters
for _, c := range identifier { for _, c := range identifier {
if unicode.IsNumber(c) || c == '_' { if unicode.IsNumber(c) || c == '_' {
continue continue

View file

@ -2,42 +2,58 @@ package jet
import ( import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"testing" "testing"
"time" "time"
) )
func TestArgToString(t *testing.T) { func TestArgToString(t *testing.T) {
assert.Equal(t, argToString(true), "TRUE") require.Equal(t, argToString(true), "TRUE")
assert.Equal(t, argToString(false), "FALSE") require.Equal(t, argToString(false), "FALSE")
assert.Equal(t, argToString(int(-32)), "-32") require.Equal(t, argToString(int(-32)), "-32")
assert.Equal(t, argToString(uint(32)), "32") require.Equal(t, argToString(uint(32)), "32")
assert.Equal(t, argToString(int8(-43)), "-43") require.Equal(t, argToString(int8(-43)), "-43")
assert.Equal(t, argToString(uint8(43)), "43") require.Equal(t, argToString(uint8(43)), "43")
assert.Equal(t, argToString(int16(-54)), "-54") require.Equal(t, argToString(int16(-54)), "-54")
assert.Equal(t, argToString(uint16(54)), "54") require.Equal(t, argToString(uint16(54)), "54")
assert.Equal(t, argToString(int32(-65)), "-65") require.Equal(t, argToString(int32(-65)), "-65")
assert.Equal(t, argToString(uint32(65)), "65") require.Equal(t, argToString(uint32(65)), "65")
assert.Equal(t, argToString(int64(-64)), "-64") require.Equal(t, argToString(int64(-64)), "-64")
assert.Equal(t, argToString(uint64(64)), "64") require.Equal(t, argToString(uint64(64)), "64")
assert.Equal(t, argToString(float32(2.0)), "2") require.Equal(t, argToString(float32(2.0)), "2")
assert.Equal(t, argToString(float64(1.11)), "1.11") require.Equal(t, argToString(float64(1.11)), "1.11")
assert.Equal(t, argToString("john"), "'john'") require.Equal(t, argToString("john"), "'john'")
assert.Equal(t, argToString("It's text"), "'It''s text'") require.Equal(t, argToString("It's text"), "'It''s text'")
assert.Equal(t, argToString([]byte("john")), "'john'") require.Equal(t, argToString([]byte("john")), "'john'")
assert.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'") require.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'")
time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006") time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'") require.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'")
func() { func() {
defer func() { defer func() {
assert.Equal(t, recover().(string), "jet: map[string]bool type can not be used as SQL query parameter") require.Equal(t, recover().(string), "jet: map[string]bool type can not be used as SQL query parameter")
}() }()
argToString(map[string]bool{}) argToString(map[string]bool{})
}() }()
} }
func TestFallTrough(t *testing.T) {
require.Equal(t, FallTrough([]SerializeOption{ShortName}), []SerializeOption{ShortName})
require.Equal(t, FallTrough([]SerializeOption{SkipNewLine}), []SerializeOption(nil))
require.Equal(t, FallTrough([]SerializeOption{ShortName, SkipNewLine}), []SerializeOption{ShortName})
}
func TestShouldQuote(t *testing.T) {
require.Equal(t, shouldQuoteIdentifier("123"), true)
require.Equal(t, shouldQuoteIdentifier("123.235"), true)
require.Equal(t, shouldQuoteIdentifier("abc123"), false)
require.Equal(t, shouldQuoteIdentifier("abc.123"), true)
require.Equal(t, shouldQuoteIdentifier("abc_123"), false)
require.Equal(t, shouldQuoteIdentifier("Abc_123"), true)
require.Equal(t, shouldQuoteIdentifier("DŽƜĐǶ"), true)
}

View file

@ -13,7 +13,6 @@ type Statement interface {
// DebugSql returns debug query where every parametrized placeholder is replaced with its argument. // DebugSql returns debug query where every parametrized placeholder is replaced with its argument.
// Do not use it in production. Use it only for debug purposes. // Do not use it in production. Use it only for debug purposes.
DebugSql() (query string) DebugSql() (query string)
// Query executes statement over database connection db and stores row result in destination. // Query executes statement over database connection db and stores row result in destination.
// Destination can be either pointer to struct or pointer to a slice. // Destination can be either pointer to struct or pointer to a slice.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
@ -21,25 +20,19 @@ type Statement interface {
// QueryContext executes statement with a context over database connection db and stores row result in destination. // QueryContext executes statement with a context over database connection db and stores row result in destination.
// Destination can be either pointer to struct or pointer to a slice. // Destination can be either pointer to struct or pointer to a slice.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
QueryContext(context context.Context, db qrm.DB, destination interface{}) error QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error
//Exec executes statement over db connection without returning any rows. //Exec executes statement over db connection without returning any rows.
Exec(db qrm.DB) (sql.Result, error) Exec(db qrm.DB) (sql.Result, error)
//Exec executes statement with context over db connection without returning any rows. //Exec executes statement with context over db connection without returning any rows.
ExecContext(context context.Context, db qrm.DB) (sql.Result, error) ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error)
} }
// SerializerStatement interface // SerializerStatement interface
type SerializerStatement interface { type SerializerStatement interface {
Serializer Serializer
Statement Statement
}
// StatementWithProjections interface
type StatementWithProjections interface {
Statement
HasProjections HasProjections
Serializer
} }
// HasProjections interface // HasProjections interface
@ -58,7 +51,7 @@ func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface
queryData := &SQLBuilder{Dialect: s.dialect} queryData := &SQLBuilder{Dialect: s.dialect}
s.parent.serialize(s.statementType, queryData, noWrap) s.parent.serialize(s.statementType, queryData, NoWrap)
query, args = queryData.finalize() query, args = queryData.finalize()
return return
@ -67,7 +60,7 @@ func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface
func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { func (s *serializerStatementInterfaceImpl) DebugSql() (query string) {
sqlBuilder := &SQLBuilder{Dialect: s.dialect, Debug: true} sqlBuilder := &SQLBuilder{Dialect: s.dialect, Debug: true}
s.parent.serialize(s.statementType, sqlBuilder, noWrap) s.parent.serialize(s.statementType, sqlBuilder, NoWrap)
query, _ = sqlBuilder.finalize() query, _ = sqlBuilder.finalize()
return return
@ -75,25 +68,41 @@ func (s *serializerStatementInterfaceImpl) DebugSql() (query string) {
func (s *serializerStatementInterfaceImpl) Query(db qrm.DB, destination interface{}) error { func (s *serializerStatementInterfaceImpl) Query(db qrm.DB, destination interface{}) error {
query, args := s.Sql() query, args := s.Sql()
ctx := context.Background()
return qrm.Query(context.Background(), db, query, args, destination) callLogger(ctx, s)
return qrm.Query(ctx, db, query, args, destination)
} }
func (s *serializerStatementInterfaceImpl) QueryContext(context context.Context, db qrm.DB, destination interface{}) error { func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error {
query, args := s.Sql() query, args := s.Sql()
return qrm.Query(context, db, query, args, destination) callLogger(ctx, s)
return qrm.Query(ctx, db, query, args, destination)
} }
func (s *serializerStatementInterfaceImpl) Exec(db qrm.DB) (res sql.Result, err error) { func (s *serializerStatementInterfaceImpl) Exec(db qrm.DB) (res sql.Result, err error) {
query, args := s.Sql() query, args := s.Sql()
callLogger(context.Background(), s)
return db.Exec(query, args...) return db.Exec(query, args...)
} }
func (s *serializerStatementInterfaceImpl) ExecContext(context context.Context, db qrm.DB) (res sql.Result, err error) { func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.DB) (res sql.Result, err error) {
query, args := s.Sql() query, args := s.Sql()
return db.ExecContext(context, query, args...) callLogger(ctx, s)
return db.ExecContext(ctx, query, args...)
}
func callLogger(ctx context.Context, statement Statement) {
if logger != nil {
logger(ctx, statement)
}
} }
// ExpressionStatement interfacess // ExpressionStatement interfacess
@ -148,7 +157,7 @@ type statementImpl struct {
func (s *statementImpl) projections() ProjectionList { func (s *statementImpl) projections() ProjectionList {
for _, clause := range s.Clauses { for _, clause := range s.Clauses {
if selectClause, ok := clause.(ClauseWithProjections); ok { if selectClause, ok := clause.(ClauseWithProjections); ok {
return selectClause.projections() return selectClause.Projections()
} }
} }
@ -156,17 +165,16 @@ func (s *statementImpl) projections() ProjectionList {
} }
func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, NoWrap) {
if !contains(options, noWrap) {
out.WriteString("(") out.WriteString("(")
out.IncreaseIdent() out.IncreaseIdent()
} }
for _, clause := range s.Clauses { for _, clause := range s.Clauses {
clause.Serialize(statement, out) clause.Serialize(statement, out, FallTrough(options)...)
} }
if !contains(options, noWrap) { if !contains(options, NoWrap) {
out.DecreaseIdent() out.DecreaseIdent()
out.NewLine() out.NewLine()
out.WriteString(")") out.WriteString(")")

View file

@ -19,17 +19,15 @@ type Table interface {
} }
// NewTable creates new table with schema Name, table Name and list of columns // NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, column ColumnExpression, columns ...ColumnExpression) SerializerTable { func NewTable(schemaName, name string, columns ...ColumnExpression) SerializerTable {
columnList := append([]ColumnExpression{column}, columns...)
t := tableImpl{ t := tableImpl{
schemaName: schemaName, schemaName: schemaName,
name: name, name: name,
columnList: columnList, columnList: columns,
} }
for _, c := range columnList { for _, c := range columns {
c.setTableName(name) c.setTableName(name)
} }
@ -156,7 +154,7 @@ func (t *joinTableImpl) serialize(statement StatementType, out *SQLBuilder, opti
panic("jet: left hand side of join operation is nil table") panic("jet: left hand side of join operation is nil table")
} }
t.lhs.serialize(statement, out) t.lhs.serialize(statement, out, FallTrough(options)...)
out.NewLine() out.NewLine()

View file

@ -1,18 +1,18 @@
package jet package jet
import ( import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"testing" "testing"
) )
func TestNewTable(t *testing.T) { func TestNewTable(t *testing.T) {
newTable := NewTable("schema", "table", IntegerColumn("intCol")) newTable := NewTable("schema", "table", IntegerColumn("intCol"))
assert.Equal(t, newTable.SchemaName(), "schema") require.Equal(t, newTable.SchemaName(), "schema")
assert.Equal(t, newTable.TableName(), "table") require.Equal(t, newTable.TableName(), "table")
assert.Equal(t, len(newTable.columns()), 1) require.Equal(t, len(newTable.columns()), 1)
assert.Equal(t, newTable.columns()[0].Name(), "intCol") require.Equal(t, newTable.columns()[0].Name(), "intCol")
} }
func TestNewJoinTable(t *testing.T) { func TestNewJoinTable(t *testing.T) {
@ -24,10 +24,10 @@ func TestNewJoinTable(t *testing.T) {
assertClauseSerialize(t, joinTable, `schema.table assertClauseSerialize(t, joinTable, `schema.table
INNER JOIN schema.table2 ON ("intCol1" = "intCol2")`) INNER JOIN schema.table2 ON ("intCol1" = "intCol2")`)
assert.Equal(t, joinTable.SchemaName(), "schema") require.Equal(t, joinTable.SchemaName(), "schema")
assert.Equal(t, joinTable.TableName(), "") require.Equal(t, joinTable.TableName(), "")
assert.Equal(t, len(joinTable.columns()), 2) require.Equal(t, len(joinTable.columns()), 2)
assert.Equal(t, joinTable.columns()[0].Name(), "intCol1") require.Equal(t, joinTable.columns()[0].Name(), "intCol1")
assert.Equal(t, joinTable.columns()[1].Name(), "intCol2") require.Equal(t, joinTable.columns()[1].Name(), "intCol2")
} }

View file

@ -1,7 +1,7 @@
package jet package jet
import ( import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"strconv" "strconv"
"testing" "testing"
) )
@ -56,14 +56,14 @@ func assertClauseSerialize(t *testing.T, clause Serializer, query string, args .
//fmt.Println(out.Buff.String()) //fmt.Println(out.Buff.String())
assert.Equal(t, out.Buff.String(), query) require.Equal(t, out.Buff.String(), query)
assert.Equal(t, out.Args, args) require.Equal(t, out.Args, args)
} }
func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) { func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) {
defer func() { defer func() {
r := recover() r := recover()
assert.Equal(t, r, errString) require.Equal(t, r, errString)
}() }()
out := SQLBuilder{Dialect: defaultDialect} out := SQLBuilder{Dialect: defaultDialect}
@ -76,14 +76,14 @@ func assertClauseDebugSerialize(t *testing.T, clause Serializer, query string, a
//fmt.Println(out.Buff.String()) //fmt.Println(out.Buff.String())
assert.Equal(t, out.Buff.String(), query) require.Equal(t, out.Buff.String(), query)
assert.Equal(t, out.Args, args) require.Equal(t, out.Args, args)
} }
func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) { func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) {
out := SQLBuilder{Dialect: defaultDialect} out := SQLBuilder{Dialect: defaultDialect}
projection.serializeForProjection(SelectStatementType, &out) projection.serializeForProjection(SelectStatementType, &out)
assert.Equal(t, out.Buff.String(), query) require.Equal(t, out.Buff.String(), query)
assert.Equal(t, out.Args, args) require.Equal(t, out.Args, args)
} }

View file

@ -59,7 +59,23 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) {
panic("jet: nil column in columns list") panic("jet: nil column in columns list")
} }
out.WriteString(col.Name()) out.WriteIdentifier(col.Name())
}
}
// SerializeColumnExpressionNames func
func SerializeColumnExpressionNames(columns []ColumnExpression, statementType StatementType,
out *SQLBuilder, options ...SerializeOption) {
for i, col := range columns {
if i > 0 {
out.WriteString(", ")
}
if col == nil {
panic("jet: nil column in columns list")
}
col.serialize(statementType, out, options...)
} }
} }
@ -85,7 +101,8 @@ func ColumnListToProjectionList(columns []ColumnExpression) []Projection {
return ret return ret
} }
func valueToClause(value interface{}) Serializer { // ToSerializerValue creates Serializer type from the value
func ToSerializerValue(value interface{}) Serializer {
if clause, ok := value.(Serializer); ok { if clause, ok := value.(Serializer); ok {
return clause return clause
} }
@ -148,7 +165,7 @@ func UnwindRowFromValues(value interface{}, values []interface{}) []Serializer {
allValues := append([]interface{}{value}, values...) allValues := append([]interface{}{value}, values...)
for _, val := range allValues { for _, val := range allValues {
row = append(row, valueToClause(val)) row = append(row, ToSerializerValue(val))
} }
return row return row

View file

@ -1,19 +1,19 @@
package jet package jet
import ( import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"testing" "testing"
) )
func TestOptionalOrDefaultString(t *testing.T) { func TestOptionalOrDefaultString(t *testing.T) {
assert.Equal(t, OptionalOrDefaultString("default"), "default") require.Equal(t, OptionalOrDefaultString("default"), "default")
assert.Equal(t, OptionalOrDefaultString("default", "optional"), "optional") require.Equal(t, OptionalOrDefaultString("default", "optional"), "optional")
} }
func TestOptionalOrDefaultExpression(t *testing.T) { func TestOptionalOrDefaultExpression(t *testing.T) {
defaultExpression := table2ColFloat defaultExpression := table2ColFloat
optionalExpression := table1Col1 optionalExpression := table1Col1
assert.Equal(t, OptionalOrDefaultExpression(defaultExpression), defaultExpression) require.Equal(t, OptionalOrDefaultExpression(defaultExpression), defaultExpression)
assert.Equal(t, OptionalOrDefaultExpression(defaultExpression, optionalExpression), optionalExpression) require.Equal(t, OptionalOrDefaultExpression(defaultExpression, optionalExpression), optionalExpression)
} }

View file

@ -17,7 +17,7 @@ func (w *commonWindowImpl) serialize(statement StatementType, out *SQLBuilder, o
w.expression.serialize(statement, out) w.expression.serialize(statement, out)
if w.window != nil { if w.window != nil {
out.WriteString("OVER") out.WriteString("OVER")
w.window.serialize(statement, out) w.window.serialize(statement, out, FallTrough(options)...)
} }
} }
@ -49,7 +49,7 @@ func (f *windowExpressionImpl) OVER(window ...Window) Expression {
} }
func (f *windowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (f *windowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out) f.commonWindowImpl.serialize(statement, out, FallTrough(options)...)
} }
// ----------------------------------------------------- // -----------------------------------------------------
@ -80,7 +80,7 @@ func (f *floatWindowExpressionImpl) OVER(window ...Window) FloatExpression {
} }
func (f *floatWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (f *floatWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out) f.commonWindowImpl.serialize(statement, out, FallTrough(options)...)
} }
// ------------------------------------------------ // ------------------------------------------------
@ -111,7 +111,7 @@ func (f *integerWindowExpressionImpl) OVER(window ...Window) IntegerExpression {
} }
func (f *integerWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (f *integerWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out) f.commonWindowImpl.serialize(statement, out, FallTrough(options)...)
} }
// ------------------------------------------------ // ------------------------------------------------
@ -142,5 +142,5 @@ func (f *boolWindowExpressionImpl) OVER(window ...Window) BoolExpression {
} }
func (f *boolWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (f *boolWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out) f.commonWindowImpl.serialize(statement, out, FallTrough(options)...)
} }

View file

@ -30,7 +30,7 @@ func newWindowImpl(parent Window) *windowImpl {
} }
func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, noWrap) { if !contains(options, NoWrap) {
out.WriteByte('(') out.WriteByte('(')
} }
@ -40,7 +40,7 @@ func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options
serializeExpressionList(statement, w.partitionBy, ", ", out) serializeExpressionList(statement, w.partitionBy, ", ", out)
} }
w.orderBy.SkipNewLine = true w.orderBy.SkipNewLine = true
w.orderBy.Serialize(statement, out) w.orderBy.Serialize(statement, out, FallTrough(options)...)
if w.frameUnits != "" { if w.frameUnits != "" {
out.WriteString(w.frameUnits) out.WriteString(w.frameUnits)
@ -55,7 +55,7 @@ func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options
} }
} }
if !contains(options, noWrap) { if !contains(options, NoWrap) {
out.WriteByte(')') out.WriteByte(')')
} }
} }
@ -139,7 +139,7 @@ func (f *frameExtentImpl) serialize(statement StatementType, out *SQLBuilder, op
if f == nil { if f == nil {
return return
} }
f.offset.serialize(statement, out) f.offset.serialize(statement, out, FallTrough(options)...)
if f.preceding { if f.preceding {
out.WriteString("PRECEDING") out.WriteString("PRECEDING")
@ -152,12 +152,12 @@ func (f *frameExtentImpl) serialize(statement StatementType, out *SQLBuilder, op
// Window function keywords // Window function keywords
var ( var (
UNBOUNDED = keywordClause("UNBOUNDED") UNBOUNDED = Keyword("UNBOUNDED")
CURRENT_ROW = frameExtentKeyword{"CURRENT ROW"} CURRENT_ROW = frameExtentKeyword{"CURRENT ROW"}
) )
type frameExtentKeyword struct { type frameExtentKeyword struct {
keywordClause Keyword
} }
func (f frameExtentKeyword) isFrameExtent() {} func (f frameExtentKeyword) isFrameExtent() {}
@ -180,7 +180,7 @@ func (w windowName) serialize(statement StatementType, out *SQLBuilder, options
out.WriteByte('(') out.WriteByte('(')
out.WriteString(w.name) out.WriteString(w.name)
w.windowImpl.serialize(statement, out, noWrap) w.windowImpl.serialize(statement, out, NoWrap.WithFallTrough(options)...)
out.WriteByte(')') out.WriteByte(')')
} }

View file

@ -0,0 +1,82 @@
package jet
// WITH function creates new with statement from list of common table expressions for specified dialect
func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statement Statement) Statement {
newWithImpl := &withImpl{
ctes: cte,
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
dialect: dialect,
statementType: WithStatementType,
},
}
newWithImpl.parent = newWithImpl
return func(primaryStatement Statement) Statement {
serializerStatement, ok := primaryStatement.(SerializerStatement)
if !ok {
panic("jet: unsupported main WITH statement.")
}
newWithImpl.primaryStatement = serializerStatement
return newWithImpl
}
}
type withImpl struct {
serializerStatementInterfaceImpl
ctes []CommonTableExpressionDefinition
primaryStatement SerializerStatement
}
func (w withImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.NewLine()
out.WriteString("WITH")
for i, cte := range w.ctes {
if i > 0 {
out.WriteString(",")
}
cte.serialize(statement, out, FallTrough(options)...)
}
w.primaryStatement.serialize(statement, out, NoWrap.WithFallTrough(options)...)
}
func (w withImpl) projections() ProjectionList {
return ProjectionList{}
}
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
selectTableImpl
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
return CommonTableExpression{
selectTableImpl: selectTableImpl{
selectStmt: nil,
alias: name,
},
}
}
func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteIdentifier(c.alias)
}
// AS returns sets definition for a CTE
func (c *CommonTableExpression) AS(statement SerializerStatement) CommonTableExpressionDefinition {
c.selectStmt = statement
return CommonTableExpressionDefinition{cte: c}
}
// CommonTableExpressionDefinition contains implementation details of CTE
type CommonTableExpressionDefinition struct {
cte *CommonTableExpression
}
func (c CommonTableExpressionDefinition) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteIdentifier(c.cte.alias)
out.WriteString("AS")
c.cte.selectStmt.serialize(statement, out, FallTrough(options)...)
}

View file

@ -7,12 +7,14 @@ import (
"github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/jet"
"github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/internal/utils"
"github.com/go-jet/jet/qrm" "github.com/go-jet/jet/qrm"
"github.com/stretchr/testify/assert" "github.com/google/uuid"
"github.com/stretchr/testify/require"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"testing" "testing"
"time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
) )
@ -21,12 +23,12 @@ import (
func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) { func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) {
res, err := stmt.Exec(db) res, err := stmt.Exec(db)
assert.NoError(t, err) require.NoError(t, err)
rows, err := res.RowsAffected() rows, err := res.RowsAffected()
assert.NoError(t, err) require.NoError(t, err)
if len(rowsAffected) > 0 { if len(rowsAffected) > 0 {
assert.Equal(t, rows, rowsAffected[0]) require.Equal(t, rowsAffected[0], rows)
} }
} }
@ -34,7 +36,7 @@ func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int
func AssertExecErr(t *testing.T, stmt jet.Statement, db qrm.DB, errorStr string) { func AssertExecErr(t *testing.T, stmt jet.Statement, db qrm.DB, errorStr string) {
_, err := stmt.Exec(db) _, err := stmt.Exec(db)
assert.Error(t, err, errorStr) require.Error(t, err, errorStr)
} }
func getFullPath(relativePath string) string { func getFullPath(relativePath string) string {
@ -51,9 +53,9 @@ func PrintJson(v interface{}) {
// AssertJSON check if data json output is the same as expectedJSON // AssertJSON check if data json output is the same as expectedJSON
func AssertJSON(t *testing.T, data interface{}, expectedJSON string) { func AssertJSON(t *testing.T, data interface{}, expectedJSON string) {
jsonData, err := json.MarshalIndent(data, "", "\t") jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "\n"+string(jsonData)+"\n", expectedJSON) require.Equal(t, "\n"+string(jsonData)+"\n", expectedJSON)
} }
// SaveJSONFile saves v as json at testRelativePath // SaveJSONFile saves v as json at testRelativePath
@ -71,23 +73,23 @@ func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) {
filePath := getFullPath(testRelativePath) filePath := getFullPath(testRelativePath)
fileJSONData, err := ioutil.ReadFile(filePath) fileJSONData, err := ioutil.ReadFile(filePath)
assert.NoError(t, err) require.NoError(t, err)
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
fileJSONData = bytes.Replace(fileJSONData, []byte("\r\n"), []byte("\n"), -1) fileJSONData = bytes.Replace(fileJSONData, []byte("\r\n"), []byte("\n"), -1)
} }
jsonData, err := json.MarshalIndent(data, "", "\t") jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NoError(t, err) require.NoError(t, err)
assert.True(t, string(fileJSONData) == string(jsonData)) require.True(t, string(fileJSONData) == string(jsonData))
//AssertDeepEqual(t, string(fileJSONData), string(jsonData)) //AssertDeepEqual(t, string(fileJSONData), string(jsonData))
} }
// AssertStatementSql check if statement Sql() is the same as expectedQuery and expectedArgs // AssertStatementSql check if statement Sql() is the same as expectedQuery and expectedArgs
func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) { func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) {
queryStr, args := query.Sql() queryStr, args := query.Sql()
assert.Equal(t, queryStr, expectedQuery) require.Equal(t, queryStr, expectedQuery)
if len(expectedArgs) == 0 { if len(expectedArgs) == 0 {
return return
@ -99,7 +101,7 @@ func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string,
func AssertStatementSqlErr(t *testing.T, stmt jet.Statement, errorStr string) { func AssertStatementSqlErr(t *testing.T, stmt jet.Statement, errorStr string) {
defer func() { defer func() {
r := recover() r := recover()
assert.Equal(t, r, errorStr) require.Equal(t, r, errorStr)
}() }()
stmt.Sql() stmt.Sql()
@ -110,17 +112,17 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st
_, args := query.Sql() _, args := query.Sql()
if len(expectedArgs) > 0 { if len(expectedArgs) > 0 {
AssertDeepEqual(t, args, expectedArgs) AssertDeepEqual(t, args, expectedArgs, "arguments are not equal")
} }
debuqSql := query.DebugSql() debuqSql := query.DebugSql()
assert.Equal(t, debuqSql, expectedQuery) require.Equal(t, debuqSql, expectedQuery)
} }
// AssertClauseSerialize checks if clause serialize produces expected query and args // AssertSerialize checks if clause serialize produces expected query and args
func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { func AssertSerialize(t *testing.T, dialect jet.Dialect, serializer jet.Serializer, query string, args ...interface{}) {
out := jet.SQLBuilder{Dialect: dialect} out := jet.SQLBuilder{Dialect: dialect}
jet.Serialize(clause, jet.SelectStatementType, &out) jet.Serialize(serializer, jet.SelectStatementType, &out)
//fmt.Println(out.Buff.String()) //fmt.Println(out.Buff.String())
@ -131,8 +133,20 @@ func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Seriali
} }
} }
// AssertDebugClauseSerialize checks if clause serialize produces expected debug query and args // AssertClauseSerialize checks if clause serialize produces expected query and args
func AssertDebugClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Clause, query string, args ...interface{}) {
out := jet.SQLBuilder{Dialect: dialect}
clause.Serialize(jet.SelectStatementType, &out)
require.Equal(t, out.Buff.String(), query)
if len(args) > 0 {
AssertDeepEqual(t, out.Args, args)
}
}
// AssertDebugSerialize checks if clause serialize produces expected debug query and args
func AssertDebugSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) {
out := jet.SQLBuilder{Dialect: dialect, Debug: true} out := jet.SQLBuilder{Dialect: dialect, Debug: true}
jet.Serialize(clause, jet.SelectStatementType, &out) jet.Serialize(clause, jet.SelectStatementType, &out)
@ -147,17 +161,17 @@ func AssertDebugClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Se
func AssertPanicErr(t *testing.T, fun func(), errorStr string) { func AssertPanicErr(t *testing.T, fun func(), errorStr string) {
defer func() { defer func() {
r := recover() r := recover()
assert.Equal(t, r, errorStr) require.Equal(t, r, errorStr)
}() }()
fun() fun()
} }
// AssertClauseSerializeErr check if clause serialize panics with errString // AssertSerializeErr check if clause serialize panics with errString
func AssertClauseSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) { func AssertSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) {
defer func() { defer func() {
r := recover() r := recover()
assert.Equal(t, r, errString) require.Equal(t, r, errString)
}() }()
out := jet.SQLBuilder{Dialect: dialect} out := jet.SQLBuilder{Dialect: dialect}
@ -177,28 +191,24 @@ func AssertProjectionSerialize(t *testing.T, dialect jet.Dialect, projection jet
func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest interface{}, errString string) { func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest interface{}, errString string) {
defer func() { defer func() {
r := recover() r := recover()
assert.Equal(t, r, errString) require.Equal(t, r, errString)
}() }()
stmt.Query(db, dest) stmt.Query(db, dest)
} }
// AssertFileContent check if file content at filePath contains expectedContent text. // AssertFileContent check if file content at filePath contains expectedContent text.
func AssertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) { func AssertFileContent(t *testing.T, filePath string, expectedContent string) {
enumFileData, err := ioutil.ReadFile(filePath) enumFileData, err := ioutil.ReadFile(filePath)
assert.NoError(t, err) require.NoError(t, err)
beginIndex := bytes.Index(enumFileData, []byte(contentBegin)) require.Equal(t, "\n"+string(enumFileData), expectedContent)
//fmt.Println("-"+string(enumFileData[beginIndex:])+"-")
AssertDeepEqual(t, string(enumFileData[beginIndex:]), expectedContent)
} }
// AssertFileNamesEqual check if all filesInfos are contained in fileNames // AssertFileNamesEqual check if all filesInfos are contained in fileNames
func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) { func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) {
assert.Equal(t, len(fileInfos), len(fileNames)) require.Equal(t, len(fileInfos), len(fileNames))
fileNamesMap := map[string]bool{} fileNamesMap := map[string]bool{}
@ -207,11 +217,88 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st
} }
for _, fileName := range fileNames { for _, fileName := range fileNames {
assert.True(t, fileNamesMap[fileName], fileName+" does not exist.") require.True(t, fileNamesMap[fileName], fileName+" does not exist.")
} }
} }
// AssertDeepEqual checks if actual and expected objects are deeply equal. // AssertDeepEqual checks if actual and expected objects are deeply equal.
func AssertDeepEqual(t *testing.T, actual, expected interface{}) { func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) {
assert.True(t, cmp.Equal(actual, expected)) require.True(t, cmp.Equal(actual, expected), msg)
}
// BoolPtr returns address of bool parameter
func BoolPtr(b bool) *bool {
return &b
}
// Int8Ptr returns address of int8 parameter
func Int8Ptr(i int8) *int8 {
return &i
}
// UInt8Ptr returns address of uint8 parameter
func UInt8Ptr(i uint8) *uint8 {
return &i
}
// Int16Ptr returns address of int16 parameter
func Int16Ptr(i int16) *int16 {
return &i
}
// UInt16Ptr returns address of uint16 parameter
func UInt16Ptr(i uint16) *uint16 {
return &i
}
// Int32Ptr returns address of int32 parameter
func Int32Ptr(i int32) *int32 {
return &i
}
// UInt32Ptr returns address of uint32 parameter
func UInt32Ptr(i uint32) *uint32 {
return &i
}
// Int64Ptr returns address of int64 parameter
func Int64Ptr(i int64) *int64 {
return &i
}
// UInt64Ptr returns address of uint64 parameter
func UInt64Ptr(i uint64) *uint64 {
return &i
}
// StringPtr returns address of string parameter
func StringPtr(s string) *string {
return &s
}
// TimePtr returns address of time.Time parameter
func TimePtr(t time.Time) *time.Time {
return &t
}
// ByteArrayPtr returns address of []byte parameter
func ByteArrayPtr(arr []byte) *[]byte {
return &arr
}
// Float32Ptr returns address of float32 parameter
func Float32Ptr(f float32) *float32 {
return &f
}
// Float64Ptr returns address of float64 parameter
func Float64Ptr(f float64) *float64 {
return &f
}
// UUIDPtr returns address of uuid.UUID
func UUIDPtr(u string) *uuid.UUID {
newUUID := uuid.MustParse(u)
return &newUUID
} }

View file

@ -2,32 +2,32 @@ package utils
import ( import (
"fmt" "fmt"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"testing" "testing"
) )
func TestToGoIdentifier(t *testing.T) { func TestToGoIdentifier(t *testing.T) {
assert.Equal(t, ToGoIdentifier(""), "") require.Equal(t, ToGoIdentifier(""), "")
assert.Equal(t, ToGoIdentifier("uuid"), "UUID") require.Equal(t, ToGoIdentifier("uuid"), "UUID")
assert.Equal(t, ToGoIdentifier("col1"), "Col1") require.Equal(t, ToGoIdentifier("col1"), "Col1")
assert.Equal(t, ToGoIdentifier("PG-13"), "Pg13") require.Equal(t, ToGoIdentifier("PG-13"), "Pg13")
assert.Equal(t, ToGoIdentifier("13_pg"), "13Pg") require.Equal(t, ToGoIdentifier("13_pg"), "13Pg")
assert.Equal(t, ToGoIdentifier("mytable"), "Mytable") require.Equal(t, ToGoIdentifier("mytable"), "Mytable")
assert.Equal(t, ToGoIdentifier("MYTABLE"), "Mytable") require.Equal(t, ToGoIdentifier("MYTABLE"), "Mytable")
assert.Equal(t, ToGoIdentifier("MyTaBlE"), "MyTaBlE") require.Equal(t, ToGoIdentifier("MyTaBlE"), "MyTaBlE")
assert.Equal(t, ToGoIdentifier("myTaBlE"), "MyTaBlE") require.Equal(t, ToGoIdentifier("myTaBlE"), "MyTaBlE")
assert.Equal(t, ToGoIdentifier("my_table"), "MyTable") require.Equal(t, ToGoIdentifier("my_table"), "MyTable")
assert.Equal(t, ToGoIdentifier("MY_TABLE"), "MyTable") require.Equal(t, ToGoIdentifier("MY_TABLE"), "MyTable")
assert.Equal(t, ToGoIdentifier("My_Table"), "MyTable") require.Equal(t, ToGoIdentifier("My_Table"), "MyTable")
assert.Equal(t, ToGoIdentifier("My Table"), "MyTable") require.Equal(t, ToGoIdentifier("My Table"), "MyTable")
assert.Equal(t, ToGoIdentifier("My-Table"), "MyTable") require.Equal(t, ToGoIdentifier("My-Table"), "MyTable")
} }
func TestToGoEnumValueIdentifier(t *testing.T) { func TestToGoEnumValueIdentifier(t *testing.T) {
assert.Equal(t, ToGoEnumValueIdentifier("enum_name", "enum_value"), "EnumValue") require.Equal(t, ToGoEnumValueIdentifier("enum_name", "enum_value"), "EnumValue")
assert.Equal(t, ToGoEnumValueIdentifier("NumEnum", "100"), "NumEnum100") require.Equal(t, ToGoEnumValueIdentifier("NumEnum", "100"), "NumEnum100")
} }
func TestErrorCatchErr(t *testing.T) { func TestErrorCatchErr(t *testing.T) {
@ -39,7 +39,7 @@ func TestErrorCatchErr(t *testing.T) {
panic(fmt.Errorf("newError")) panic(fmt.Errorf("newError"))
}() }()
assert.Error(t, err, "newError") require.Error(t, err, "newError")
} }
func TestErrorCatchNonErr(t *testing.T) { func TestErrorCatchNonErr(t *testing.T) {
@ -51,5 +51,5 @@ func TestErrorCatchNonErr(t *testing.T) {
panic(11) panic(11)
}() }()
assert.Error(t, err, "11") require.Error(t, err, "11")
} }

View file

@ -26,6 +26,7 @@ func newDialect() jet.Dialect {
ArgumentPlaceholder: func(int) string { ArgumentPlaceholder: func(int) string {
return "?" return "?"
}, },
ReservedWords: reservedWords,
} }
return jet.NewDialect(mySQLDialectParams) return jet.NewDialect(mySQLDialectParams)
@ -160,3 +161,267 @@ func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFun
jet.Serialize(expressions[1], statement, out, options...) jet.Serialize(expressions[1], statement, out, options...)
} }
} }
var reservedWords = []string{
"ACCESSIBLE",
"ADD",
"ALL",
"ALTER",
"ANALYZE",
"AND",
"AS",
"ASC",
"ASENSITIVE",
"BEFORE",
"BETWEEN",
"BIGINT",
"BINARY",
"BLOB",
"BOTH",
"BY",
"CALL",
"CASCADE",
"CASE",
"CHANGE",
"CHAR",
"CHARACTER",
"CHECK",
"COLLATE",
"COLUMN",
"CONDITION",
"CONSTRAINT",
"CONTINUE",
"CONVERT",
"CREATE",
"CROSS",
"CUBE",
"CUME_DIST",
"CURRENT_DATE",
"CURRENT_TIME",
"CURRENT_TIMESTAMP",
"CURRENT_USER",
"CURSOR",
"DATABASE",
"DATABASES",
"DAY_HOUR",
"DAY_MICROSECOND",
"DAY_MINUTE",
"DAY_SECOND",
"DEC",
"DECIMAL",
"DECLARE",
"DEFAULT",
"DELAYED",
"DELETE",
"DENSE_RANK",
"DESC",
"DESCRIBE",
"DETERMINISTIC",
"DISTINCT",
"DISTINCTROW",
"DIV",
"DOUBLE",
"DROP",
"DUAL",
"EACH",
"ELSE",
"ELSEIF",
"EMPTY",
"ENCLOSED",
"ESCAPED",
"EXCEPT",
"EXISTS",
"EXIT",
"EXPLAIN",
"FALSE",
"FETCH",
"FIRST_VALUE",
"FLOAT",
"FLOAT4",
"FLOAT8",
"FOR",
"FORCE",
"FOREIGN",
"FROM",
"FULLTEXT",
"FUNCTION",
"GENERATED",
"GET",
"GRANT",
"GROUP",
"GROUPING",
"GROUPS",
"HAVING",
"HIGH_PRIORITY",
"HOUR_MICROSECOND",
"HOUR_MINUTE",
"HOUR_SECOND",
"IF",
"IGNORE",
"IN",
"INDEX",
"INFILE",
"INNER",
"INOUT",
"INSENSITIVE",
"INSERT",
"INT",
"INT1",
"INT2",
"INT3",
"INT4",
"INT8",
"INTEGER",
"INTERVAL",
"INTO",
"IO_AFTER_GTIDS",
"IO_BEFORE_GTIDS",
"IS",
"ITERATE",
"JOIN",
"JSON_TABLE",
"KEY",
"KEYS",
"KILL",
"LAG",
"LAST_VALUE",
"LATERAL",
"LEAD",
"LEADING",
"LEAVE",
"LEFT",
"LIKE",
"LIMIT",
"LINEAR",
"LINES",
"LOAD",
"LOCALTIME",
"LOCALTIMESTAMP",
"LOCK",
"LONG",
"LONGBLOB",
"LONGTEXT",
"LOOP",
"LOW_PRIORITY",
"MASTER_BIND",
"MASTER_SSL_VERIFY_SERVER_CERT",
"MATCH",
"MAXVALUE",
"MEDIUMBLOB",
"MEDIUMINT",
"MEDIUMTEXT",
"MIDDLEINT",
"MINUTE_MICROSECOND",
"MINUTE_SECOND",
"MOD",
"MODIFIES",
"NATURAL",
"NOT",
"NO_WRITE_TO_BINLOG",
"NTH_VALUE",
"NTILE",
"NULL",
"NUMERIC",
"OF",
"ON",
"OPTIMIZE",
"OPTIMIZER_COSTS",
"OPTION",
"OPTIONALLY",
"OR",
"ORDER",
"OUT",
"OUTER",
"OUTFILE",
"OVER",
"PARTITION",
"PERCENT_RANK",
"PRECISION",
"PRIMARY",
"PROCEDURE",
"PURGE",
"RANGE",
"RANK",
"READ",
"READS",
"READ_WRITE",
"REAL",
"RECURSIVE",
"REFERENCES",
"REGEXP",
"RELEASE",
"RENAME",
"REPEAT",
"REPLACE",
"REQUIRE",
"RESIGNAL",
"RESTRICT",
"RETURN",
"REVOKE",
"RIGHT",
"RLIKE",
"ROW",
"ROWS",
"ROW_NUMBER",
"SCHEMA",
"SCHEMAS",
"SECOND_MICROSECOND",
"SELECT",
"SENSITIVE",
"SEPARATOR",
"SET",
"SHOW",
"SIGNAL",
"SMALLINT",
"SPATIAL",
"SPECIFIC",
"SQL",
"SQLEXCEPTION",
"SQLSTATE",
"SQLWARNING",
"SQL_BIG_RESULT",
"SQL_CALC_FOUND_ROWS",
"SQL_SMALL_RESULT",
"SSL",
"STARTING",
"STORED",
"STRAIGHT_JOIN",
"SYSTEM",
"TABLE",
"TERMINATED",
"THEN",
"TINYBLOB",
"TINYINT",
"TINYTEXT",
"TO",
"TRAILING",
"TRIGGER",
"TRUE",
"UNDO",
"UNION",
"UNIQUE",
"UNLOCK",
"UNSIGNED",
"UPDATE",
"USAGE",
"USE",
"USING",
"UTC_DATE",
"UTC_TIME",
"UTC_TIMESTAMP",
"VALUES",
"VARBINARY",
"VARCHAR",
"VARCHARACTER",
"VARYING",
"VIRTUAL",
"WHEN",
"WHERE",
"WHILE",
"WINDOW",
"WITH",
"WRITE",
"XOR",
"YEAR_MONTH",
"ZEROFILL",
}

View file

@ -85,6 +85,9 @@ var MINi = jet.MINi
// MINf is aggregate function. Returns minimum value of float expression across all input values // MINf is aggregate function. Returns minimum value of float expression across all input values
var MINf = jet.MINf var MINf = jet.MINf
// SUM is aggregate function. Returns sum of all expressions
var SUM = jet.SUM
// SUMi is aggregate function. Returns sum of integer expression. // SUMi is aggregate function. Returns sum of integer expression.
var SUMi = jet.SUMi var SUMi = jet.SUMi

View file

@ -13,13 +13,15 @@ type InsertStatement interface {
MODEL(data interface{}) InsertStatement MODEL(data interface{}) InsertStatement
MODELS(data interface{}) InsertStatement MODELS(data interface{}) InsertStatement
ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement QUERY(selectStatement SelectStatement) InsertStatement
} }
func newInsertStatement(table Table, columns []jet.Column) InsertStatement { func newInsertStatement(table Table, columns []jet.Column) InsertStatement {
newInsert := &insertStatementImpl{} newInsert := &insertStatementImpl{}
newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
&newInsert.Insert, &newInsert.ValuesQuery) &newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnDuplicateKey)
newInsert.Insert.Table = table newInsert.Insert.Table = table
newInsert.Insert.Columns = columns newInsert.Insert.Columns = columns
@ -32,24 +34,53 @@ type insertStatementImpl struct {
Insert jet.ClauseInsert Insert jet.ClauseInsert
ValuesQuery jet.ClauseValuesQuery ValuesQuery jet.ClauseValuesQuery
OnDuplicateKey onDuplicateKeyUpdateClause
} }
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { func (is *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values)) is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values))
return i return is
} }
func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement { func (is *insertStatementImpl) MODEL(data interface{}) InsertStatement {
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowFromModel(i.Insert.GetColumns(), data)) is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromModel(is.Insert.GetColumns(), data))
return i return is
} }
func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement { func (is *insertStatementImpl) MODELS(data interface{}) InsertStatement {
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowsFromModels(i.Insert.GetColumns(), data)...) is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowsFromModels(is.Insert.GetColumns(), data)...)
return i return is
} }
func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement { func (is *insertStatementImpl) ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement {
i.ValuesQuery.Query = selectStatement is.OnDuplicateKey = assigments
return i return is
}
func (is *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
is.ValuesQuery.Query = selectStatement
return is
}
type onDuplicateKeyUpdateClause []jet.ColumnAssigment
// Serialize for SetClause
func (s onDuplicateKeyUpdateClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(s) == 0 {
return
}
out.NewLine()
out.WriteString("ON DUPLICATE KEY UPDATE")
out.IncreaseIdent(24)
for i, assigment := range s {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
jet.Serialize(assigment, statementType, out, jet.ShortName.WithFallTrough(options)...)
}
out.DecreaseIdent(24)
} }

View file

@ -1,7 +1,7 @@
package mysql package mysql
import ( import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"testing" "testing"
"time" "time"
) )
@ -13,15 +13,15 @@ func TestInvalidInsert(t *testing.T) {
func TestInsertNilValue(t *testing.T) { func TestInsertNilValue(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), ` assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), `
INSERT INTO db.table1 (col1) VALUES INSERT INTO db.table1 (col1)
(?); VALUES (?);
`, nil) `, nil)
} }
func TestInsertSingleValue(t *testing.T) { func TestInsertSingleValue(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), ` assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), `
INSERT INTO db.table1 (col1) VALUES INSERT INTO db.table1 (col1)
(?); VALUES (?);
`, int(1)) `, int(1))
} }
@ -31,8 +31,8 @@ func TestInsertWithColumnList(t *testing.T) {
columnList = append(columnList, table3StrCol) columnList = append(columnList, table3StrCol)
assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), ` assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), `
INSERT INTO db.table3 (col_int, col2) VALUES INSERT INTO db.table3 (col_int, col2)
(?, ?); VALUES (?, ?);
`, 1, 3) `, 1, 3)
} }
@ -40,15 +40,15 @@ func TestInsertDate(t *testing.T) {
date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC)
assertStatementSql(t, table1.INSERT(table1ColTimestamp).VALUES(date), ` assertStatementSql(t, table1.INSERT(table1ColTimestamp).VALUES(date), `
INSERT INTO db.table1 (col_timestamp) VALUES INSERT INTO db.table1 (col_timestamp)
(?); VALUES (?);
`, date) `, date)
} }
func TestInsertMultipleValues(t *testing.T) { func TestInsertMultipleValues(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), `
INSERT INTO db.table1 (col1, col_float, col3) VALUES INSERT INTO db.table1 (col1, col_float, col3)
(?, ?, ?); VALUES (?, ?, ?);
`, 1, 2, 3) `, 1, 2, 3)
} }
@ -59,8 +59,8 @@ func TestInsertMultipleRows(t *testing.T) {
VALUES(111, 222) VALUES(111, 222)
assertStatementSql(t, stmt, ` assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float) VALUES INSERT INTO db.table1 (col1, col_float)
(?, ?), VALUES (?, ?),
(?, ?), (?, ?),
(?, ?); (?, ?);
`, 1, 2, 11, 22, 111, 222) `, 1, 2, 11, 22, 111, 222)
@ -84,8 +84,8 @@ func TestInsertValuesFromModel(t *testing.T) {
MODEL(&toInsert) MODEL(&toInsert)
expectedSQL := ` expectedSQL := `
INSERT INTO db.table1 (col1, col_float) VALUES INSERT INTO db.table1 (col1, col_float)
(?, ?), VALUES (?, ?),
(?, ?); (?, ?);
` `
@ -95,7 +95,7 @@ INSERT INTO db.table1 (col1, col_float) VALUES
func TestInsertValuesFromModelColumnMismatch(t *testing.T) { func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
defer func() { defer func() {
r := recover() r := recover()
assert.Equal(t, r, "missing struct field for column : col1") require.Equal(t, r, "missing struct field for column : col1")
}() }()
type Table1Model struct { type Table1Model struct {
Col1Prim int Col1Prim int
@ -116,7 +116,7 @@ func TestInsertFromNonStructModel(t *testing.T) {
defer func() { defer func() {
r := recover() r := recover()
assert.Equal(t, r, "jet: data has to be a struct") require.Equal(t, r, "jet: data has to be a struct")
}() }()
table2.INSERT(table2ColInt).MODEL([]int{}) table2.INSERT(table2ColInt).MODEL([]int{})
@ -127,9 +127,56 @@ func TestInsertDefaultValue(t *testing.T) {
VALUES(DEFAULT, "two") VALUES(DEFAULT, "two")
var expectedSQL = ` var expectedSQL = `
INSERT INTO db.table1 (col1, col_float) VALUES INSERT INTO db.table1 (col1, col_float)
(DEFAULT, ?); VALUES (DEFAULT, ?);
` `
assertStatementSql(t, stmt, expectedSQL, "two") assertStatementSql(t, stmt, expectedSQL, "two")
} }
func TestInsertOnDuplicateKeyUpdate(t *testing.T) {
stmt := func() InsertStatement {
return table1.INSERT(table1Col1, table1ColFloat).
VALUES(DEFAULT, "two")
}
t.Run("empty list", func(t *testing.T) {
stmt := stmt().ON_DUPLICATE_KEY_UPDATE()
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, ?);
`, "two")
})
t.Run("one set", func(t *testing.T) {
stmt := stmt().ON_DUPLICATE_KEY_UPDATE(table1ColFloat.SET(Float(11.1)))
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, ?)
ON DUPLICATE KEY UPDATE col_float = ?;
`, "two", 11.1)
})
t.Run("all types set", func(t *testing.T) {
stmt := stmt().ON_DUPLICATE_KEY_UPDATE(
table1ColBool.SET(Bool(true)),
table1ColInt.SET(Int(11)),
table1ColFloat.SET(Float(11.1)),
table1ColString.SET(String("str")),
table1ColTime.SET(Time(11, 23, 11)),
table1ColTimestamp.SET(Timestamp(2020, 1, 22, 3, 4, 5)),
table1ColDate.SET(Date(2020, 12, 1)),
)
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, ?)
ON DUPLICATE KEY UPDATE col_bool = ?,
col_int = ?,
col_float = ?,
col_string = ?,
col_time = CAST(? AS TIME),
col_timestamp = TIMESTAMP(?),
col_date = CAST(? AS DATE);
`, "two", true, int64(11), 11.1, "str", "11:23:11", "2020-01-22 03:04:05", "2020-12-01")
})
}

View file

@ -69,7 +69,7 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta
&newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy, &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy,
&newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock) &newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock)
newSelect.Select.Projections = projections newSelect.Select.ProjectionList = projections
newSelect.From.Table = table newSelect.From.Table = table
newSelect.Limit.Count = -1 newSelect.Limit.Count = -1
newSelect.Offset.Count = -1 newSelect.Offset.Count = -1

View file

@ -13,7 +13,7 @@ type selectTableImpl struct {
readableTableInterfaceImpl readableTableInterfaceImpl
} }
func newSelectTable(selectStmt jet.StatementWithProjections, alias string) SelectTable { func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable {
subQuery := &selectTableImpl{ subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias), SelectTable: jet.NewSelectTable(selectStmt, alias),
} }

View file

@ -4,13 +4,13 @@ import "github.com/go-jet/jet/internal/jet"
// UNION effectively appends the result of sub-queries(select statements) into single query. // UNION effectively appends the result of sub-queries(select statements) into single query.
// It eliminates duplicate rows from its result. // It eliminates duplicate rows from its result.
func UNION(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement { func UNION(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...)) return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...))
} }
// UNION_ALL effectively appends the result of sub-queries(select statements) into single query. // UNION_ALL effectively appends the result of sub-queries(select statements) into single query.
// It does not eliminates duplicate rows from its result. // It does not eliminates duplicate rows from its result.
func UNION_ALL(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement { func UNION_ALL(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...)) return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...))
} }
@ -54,7 +54,7 @@ type setStatementImpl struct {
setOperator jet.ClauseSetStmtOperator setOperator jet.ClauseSetStmtOperator
} }
func newSetStatementImpl(operator string, all bool, selects []jet.StatementWithProjections) setStatement { func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStatement) setStatement {
newSetStatement := &setStatementImpl{} newSetStatement := &setStatementImpl{}
newSetStatement.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SetStatementType, newSetStatement, newSetStatement.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SetStatementType, newSetStatement,
&newSetStatement.setOperator) &newSetStatement.setOperator)
@ -93,6 +93,6 @@ const (
union = "UNION" union = "UNION"
) )
func toSelectList(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) []jet.StatementWithProjections { func toSelectList(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) []jet.SerializerStatement {
return append([]jet.StatementWithProjections{lhs, rhs}, selects...) return append([]jet.SerializerStatement{lhs, rhs}, selects...)
} }

View file

@ -8,7 +8,7 @@ type Table interface {
readableTable readableTable
INSERT(columns ...jet.Column) InsertStatement INSERT(columns ...jet.Column) InsertStatement
UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement UPDATE(columns ...jet.Column) UpdateStatement
DELETE() DeleteStatement DELETE() DeleteStatement
LOCK() LockStatement LOCK() LockStatement
} }
@ -35,7 +35,7 @@ type readableTable interface {
type joinSelectUpdateTable interface { type joinSelectUpdateTable interface {
ReadableTable ReadableTable
UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement UPDATE(columns ...jet.Column) UpdateStatement
} }
// ReadableTable interface // ReadableTable interface
@ -49,37 +49,37 @@ type readableTableInterfaceImpl struct {
} }
// Generates a select query on the current tableName. // Generates a select query on the current tableName.
func (r *readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement {
return newSelectStatement(r.parent, append([]Projection{projection1}, projections...)) return newSelectStatement(r.parent, append([]Projection{projection1}, projections...))
} }
// Creates a inner join tableName Expression using onCondition. // Creates a inner join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { func (r readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.InnerJoin, onCondition) return newJoinTable(r.parent, table, jet.InnerJoin, onCondition)
} }
// Creates a left join tableName Expression using onCondition. // Creates a left join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { func (r readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.LeftJoin, onCondition) return newJoinTable(r.parent, table, jet.LeftJoin, onCondition)
} }
// Creates a right join tableName Expression using onCondition. // Creates a right join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { func (r readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.RightJoin, onCondition) return newJoinTable(r.parent, table, jet.RightJoin, onCondition)
} }
func (r *readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { func (r readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.FullJoin, onCondition) return newJoinTable(r.parent, table, jet.FullJoin, onCondition)
} }
func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectUpdateTable { func (r readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.CrossJoin, nil) return newJoinTable(r.parent, table, jet.CrossJoin, nil)
} }
// NewTable creates new table with schema Name, table Name and list of columns // NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table { func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table {
t := &tableImpl{ t := &tableImpl{
SerializerTable: jet.NewTable(schemaName, name, column, columns...), SerializerTable: jet.NewTable(schemaName, name, columns...),
} }
t.readableTableInterfaceImpl.parent = t t.readableTableInterfaceImpl.parent = t
@ -98,8 +98,8 @@ func (t *tableImpl) INSERT(columns ...jet.Column) InsertStatement {
return newInsertStatement(t.parent, jet.UnwidColumnList(columns)) return newInsertStatement(t.parent, jet.UnwidColumnList(columns))
} }
func (t *tableImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement { func (t *tableImpl) UPDATE(columns ...jet.Column) UpdateStatement {
return newUpdateStatement(t.parent, jet.UnwindColumns(column, columns...)) return newUpdateStatement(t.parent, jet.UnwidColumnList(columns))
} }
func (t *tableImpl) DELETE() DeleteStatement { func (t *tableImpl) DELETE() DeleteStatement {

View file

@ -5,9 +5,9 @@ import (
) )
func TestJoinNilInputs(t *testing.T) { func TestJoinNilInputs(t *testing.T) {
assertClauseSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)), assertSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)),
"jet: right hand side of join operation is nil table") "jet: right hand side of join operation is nil table")
assertClauseSerializeErr(t, table2.INNER_JOIN(table1, nil), assertSerializeErr(t, table2.INNER_JOIN(table1, nil),
"jet: join condition is nil") "jet: join condition is nil")
} }

View file

@ -10,3 +10,12 @@ type Projection = jet.Projection
// ProjectionList can be used to create conditional constructed projection list. // ProjectionList can be used to create conditional constructed projection list.
type ProjectionList = jet.ProjectionList type ProjectionList = jet.ProjectionList
// ColumnAssigment is interface wrapper around column assigment
type ColumnAssigment = jet.ColumnAssigment
// PrintableStatement is a statement which sql query can be logged
type PrintableStatement = jet.PrintableStatement
// SetLogger sets automatic statement logging
var SetLogger = jet.SetLoggerFunc

View file

@ -16,14 +16,18 @@ type updateStatementImpl struct {
jet.SerializerStatement jet.SerializerStatement
Update jet.ClauseUpdate Update jet.ClauseUpdate
Set jet.ClauseSet Set jet.SetClause
SetNew jet.SetClauseNew
Where jet.ClauseWhere Where jet.ClauseWhere
} }
func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement { func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
update := &updateStatementImpl{} update := &updateStatementImpl{}
update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update, update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update,
&update.Set, &update.Where) &update.Update,
&update.Set,
&update.SetNew,
&update.Where)
update.Update.Table = table update.Update.Table = table
update.Set.Columns = columns update.Set.Columns = columns
@ -33,7 +37,17 @@ func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
} }
func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement {
columnAssigment, isColumnAssigment := value.(ColumnAssigment)
if isColumnAssigment {
u.SetNew = []ColumnAssigment{columnAssigment}
for _, value := range values {
u.SetNew = append(u.SetNew, value.(ColumnAssigment))
}
} else {
u.Set.Values = jet.UnwindRowFromValues(value, values) u.Set.Values = jet.UnwindRowFromValues(value, values)
}
return u return u
} }

View file

@ -7,12 +7,14 @@ import (
) )
var table1Col1 = IntegerColumn("col1") var table1Col1 = IntegerColumn("col1")
var table1ColBool = BoolColumn("col_bool")
var table1ColInt = IntegerColumn("col_int") var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float") var table1ColFloat = FloatColumn("col_float")
var table1ColString = StringColumn("col_string")
var table1Col3 = IntegerColumn("col3") var table1Col3 = IntegerColumn("col3")
var table1ColTimestamp = TimestampColumn("col_timestamp") var table1ColTimestamp = TimestampColumn("col_timestamp")
var table1ColBool = BoolColumn("col_bool")
var table1ColDate = DateColumn("col_date") var table1ColDate = DateColumn("col_date")
var table1ColTime = TimeColumn("col_time")
var table1 = NewTable( var table1 = NewTable(
"db", "db",
@ -20,10 +22,12 @@ var table1 = NewTable(
table1Col1, table1Col1,
table1ColInt, table1ColInt,
table1ColFloat, table1ColFloat,
table1ColString,
table1Col3, table1Col3,
table1ColBool, table1ColBool,
table1ColDate, table1ColDate,
table1ColTimestamp, table1ColTimestamp,
table1ColTime,
) )
var table2Col3 = IntegerColumn("col3") var table2Col3 = IntegerColumn("col3")
@ -59,15 +63,15 @@ var table3 = NewTable(
table3StrCol) table3StrCol)
func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) {
testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) testutils.AssertSerialize(t, Dialect, clause, query, args...)
} }
func assertDebugSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { func assertDebugSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) {
testutils.AssertDebugClauseSerialize(t, Dialect, clause, query, args...) testutils.AssertDebugSerialize(t, Dialect, clause, query, args...)
} }
func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) { func assertSerializeErr(t *testing.T, clause jet.Serializer, errString string) {
testutils.AssertClauseSerializeErr(t, Dialect, clause, errString) testutils.AssertSerializeErr(t, Dialect, clause, errString)
} }
func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) { func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) {

26
mysql/with_statement.go Normal file
View file

@ -0,0 +1,26 @@
package mysql
import "github.com/go-jet/jet/internal/jet"
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
readableTableInterfaceImpl
jet.CommonTableExpression
}
// WITH function creates new WITH statement from list of common table expressions
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, cte...)
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
cte := CommonTableExpression{
readableTableInterfaceImpl: readableTableInterfaceImpl{},
CommonTableExpression: jet.CTE(name),
}
cte.parent = &cte
return cte
}

92
postgres/clause.go Normal file
View file

@ -0,0 +1,92 @@
package postgres
import (
"github.com/go-jet/jet/internal/jet"
)
type clauseReturning struct {
ProjectionList []jet.Projection
}
func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(r.ProjectionList) == 0 {
return
}
out.NewLine()
out.WriteString("RETURNING")
out.IncreaseIdent()
out.WriteProjections(statementType, r.ProjectionList)
out.DecreaseIdent()
}
func (r clauseReturning) Projections() ProjectionList {
return r.ProjectionList
}
// ========================================== //
type onConflict interface {
ON_CONSTRAINT(name string) conflictTarget
WHERE(indexPredicate BoolExpression) conflictTarget
DO_NOTHING() InsertStatement
DO_UPDATE(action conflictAction) InsertStatement
}
type conflictTarget interface {
DO_NOTHING() InsertStatement
DO_UPDATE(action conflictAction) InsertStatement
}
type onConflictClause struct {
insertStatement InsertStatement
constraint string
indexExpressions []jet.ColumnExpression
whereClause jet.ClauseWhere
do jet.Serializer
}
func (o *onConflictClause) ON_CONSTRAINT(name string) conflictTarget {
o.constraint = name
return o
}
func (o *onConflictClause) WHERE(indexPredicate BoolExpression) conflictTarget {
o.whereClause.Condition = indexPredicate
return o
}
func (o *onConflictClause) DO_NOTHING() InsertStatement {
o.do = jet.Keyword("DO NOTHING")
return o.insertStatement
}
func (o *onConflictClause) DO_UPDATE(action conflictAction) InsertStatement {
o.do = action
return o.insertStatement
}
func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(o.indexExpressions) == 0 && o.constraint == "" {
return
}
out.NewLine()
out.WriteString("ON CONFLICT")
if len(o.indexExpressions) > 0 {
out.WriteString("(")
jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName)
out.WriteString(")")
}
if o.constraint != "" {
out.WriteString("ON CONSTRAINT")
out.WriteString(o.constraint)
}
o.whereClause.Serialize(statementType, out, jet.SkipNewLine, jet.ShortName)
out.IncreaseIdent(7)
jet.Serialize(o.do, statementType, out)
out.DecreaseIdent(7)
}

35
postgres/clause_test.go Normal file
View file

@ -0,0 +1,35 @@
package postgres
import "testing"
func TestOnConflict(t *testing.T) {
assertClauseSerialize(t, &onConflictClause{}, "")
onConflict := &onConflictClause{}
onConflict.DO_NOTHING()
assertClauseSerialize(t, onConflict, "")
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool}}
onConflict.DO_NOTHING()
assertClauseSerialize(t, onConflict, `
ON CONFLICT (col_bool) DO NOTHING`)
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool}}
onConflict.ON_CONSTRAINT("table_pkey").DO_NOTHING()
assertClauseSerialize(t, onConflict, `
ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`)
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool, table2ColFloat}}
onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)).
DO_UPDATE(
SET(table1ColBool.SET(Bool(true)),
table1ColInt.SET(Int(11))).
WHERE(table2ColFloat.GT(Float(11.1))),
)
assertClauseSerialize(t, onConflict, `
ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE
SET col_bool = $1,
col_int = $2
WHERE table2.col_float > $3`)
}

View file

@ -1,20 +0,0 @@
package postgres
import (
"github.com/go-jet/jet/internal/jet"
)
type clauseReturning struct {
Projections []jet.Projection
}
func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder) {
if len(r.Projections) == 0 {
return
}
out.NewLine()
out.WriteString("RETURNING")
out.IncreaseIdent()
out.WriteProjections(statementType, r.Projections)
}

View file

@ -8,10 +8,10 @@ func TestNewIntervalColumn(t *testing.T) {
subQuery := SELECT(Int(1)).AsTable("sub_query") subQuery := SELECT(Int(1)).AsTable("sub_query")
subQueryIntervalColumn := IntervalColumn("col_interval").From(subQuery) subQueryIntervalColumn := IntervalColumn("col_interval").From(subQuery)
assertSerialize(t, subQueryIntervalColumn, `sub_query."col_interval"`) assertSerialize(t, subQueryIntervalColumn, `sub_query.col_interval`)
assertSerialize(t, subQueryIntervalColumn.EQ(INTERVAL(2, HOUR, 10, MINUTE)), assertSerialize(t, subQueryIntervalColumn.EQ(INTERVAL(2, HOUR, 10, MINUTE)),
`(sub_query."col_interval" = INTERVAL '2 HOUR 10 MINUTE')`) `(sub_query.col_interval = INTERVAL '2 HOUR 10 MINUTE')`)
assertProjectionSerialize(t, subQueryIntervalColumn, `sub_query."col_interval" AS "col_interval"`) assertProjectionSerialize(t, subQueryIntervalColumn, `sub_query.col_interval AS "col_interval"`)
subQueryIntervalColumn2 := table1ColInterval.From(subQuery) subQueryIntervalColumn2 := table1ColInterval.From(subQuery)
assertSerialize(t, subQueryIntervalColumn2, `sub_query."table1.col_interval"`) assertSerialize(t, subQueryIntervalColumn2, `sub_query."table1.col_interval"`)

View file

@ -0,0 +1,30 @@
package postgres
import "github.com/go-jet/jet/internal/jet"
type conflictAction interface {
jet.Serializer
WHERE(condition BoolExpression) conflictAction
}
// SET creates conflict action for ON_CONFLICT clause
func SET(assigments ...ColumnAssigment) conflictAction {
conflictAction := updateConflictActionImpl{}
conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"}
conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where)
conflictAction.set = assigments
return &conflictAction
}
type updateConflictActionImpl struct {
jet.Serializer
doUpdate jet.KeywordClause
set jet.SetClauseNew
where jet.ClauseWhere
}
func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction {
u.where.Condition = condition
return u
}

View file

@ -4,7 +4,7 @@ import "github.com/go-jet/jet/internal/jet"
// DeleteStatement is interface for PostgreSQL DELETE statement // DeleteStatement is interface for PostgreSQL DELETE statement
type DeleteStatement interface { type DeleteStatement interface {
Statement jet.SerializerStatement
WHERE(expression BoolExpression) DeleteStatement WHERE(expression BoolExpression) DeleteStatement
@ -37,6 +37,6 @@ func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
} }
func (d *deleteStatementImpl) RETURNING(projections ...jet.Projection) DeleteStatement { func (d *deleteStatementImpl) RETURNING(projections ...jet.Projection) DeleteStatement {
d.Returning.Projections = projections d.Returning.ProjectionList = projections
return d return d
} }

View file

@ -87,6 +87,9 @@ var MINf = jet.MINf
// MINi is aggregate function. Returns minimum value of int expression across all input values // MINi is aggregate function. Returns minimum value of int expression across all input values
var MINi = jet.MINi var MINi = jet.MINi
// SUM is aggregate function. Returns sum of all expressions
var SUM = jet.SUM
// SUMf is aggregate function. Returns sum of expression across all float expressions // SUMf is aggregate function. Returns sum of expression across all float expressions
var SUMf = jet.SUMf var SUMf = jet.SUMf

View file

@ -4,25 +4,25 @@ import "github.com/go-jet/jet/internal/jet"
// InsertStatement is interface for SQL INSERT statements // InsertStatement is interface for SQL INSERT statements
type InsertStatement interface { type InsertStatement interface {
Statement jet.SerializerStatement
// Insert row of values // Insert row of values
VALUES(value interface{}, values ...interface{}) InsertStatement VALUES(value interface{}, values ...interface{}) InsertStatement
// Insert row of values, where value for each column is extracted from filed of structure data. // Insert row of values, where value for each column is extracted from filed of structure data.
// If data is not struct or there is no field for every column selected, this method will panic. // If data is not struct or there is no field for every column selected, this method will panic.
MODEL(data interface{}) InsertStatement MODEL(data interface{}) InsertStatement
MODELS(data interface{}) InsertStatement MODELS(data interface{}) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement QUERY(selectStatement SelectStatement) InsertStatement
RETURNING(projections ...jet.Projection) InsertStatement ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict
RETURNING(projections ...Projection) InsertStatement
} }
func newInsertStatement(table WritableTable, columns []jet.Column) InsertStatement { func newInsertStatement(table WritableTable, columns []jet.Column) InsertStatement {
newInsert := &insertStatementImpl{} newInsert := &insertStatementImpl{}
newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
&newInsert.Insert, &newInsert.ValuesQuery, &newInsert.Returning) &newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnConflict, &newInsert.Returning)
newInsert.Insert.Table = table newInsert.Insert.Table = table
newInsert.Insert.Columns = columns newInsert.Insert.Columns = columns
@ -36,6 +36,7 @@ type insertStatementImpl struct {
Insert jet.ClauseInsert Insert jet.ClauseInsert
ValuesQuery jet.ClauseValuesQuery ValuesQuery jet.ClauseValuesQuery
Returning clauseReturning Returning clauseReturning
OnConflict onConflictClause
} }
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
@ -54,7 +55,7 @@ func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement {
} }
func (i *insertStatementImpl) RETURNING(projections ...jet.Projection) InsertStatement { func (i *insertStatementImpl) RETURNING(projections ...jet.Projection) InsertStatement {
i.Returning.Projections = projections i.Returning.ProjectionList = projections
return i return i
} }
@ -62,3 +63,11 @@ func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertState
i.ValuesQuery.Query = selectStatement i.ValuesQuery.Query = selectStatement
return i return i
} }
func (i *insertStatementImpl) ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict {
i.OnConflict = onConflictClause{
insertStatement: i,
indexExpressions: indexExpressions,
}
return &i.OnConflict
}

View file

@ -1,7 +1,8 @@
package postgres package postgres
import ( import (
"github.com/stretchr/testify/assert" "github.com/go-jet/jet/internal/jet"
"github.com/stretchr/testify/require"
"testing" "testing"
"time" "time"
) )
@ -13,15 +14,15 @@ func TestInvalidInsert(t *testing.T) {
func TestInsertNilValue(t *testing.T) { func TestInsertNilValue(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), ` assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), `
INSERT INTO db.table1 (col1) VALUES INSERT INTO db.table1 (col1)
($1); VALUES ($1);
`, nil) `, nil)
} }
func TestInsertSingleValue(t *testing.T) { func TestInsertSingleValue(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), ` assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), `
INSERT INTO db.table1 (col1) VALUES INSERT INTO db.table1 (col1)
($1); VALUES ($1);
`, int(1)) `, int(1))
} }
@ -29,8 +30,8 @@ func TestInsertWithColumnList(t *testing.T) {
columnList := ColumnList{table3ColInt, table3StrCol} columnList := ColumnList{table3ColInt, table3StrCol}
assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), ` assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), `
INSERT INTO db.table3 (col_int, col2) VALUES INSERT INTO db.table3 (col_int, col2)
($1, $2); VALUES ($1, $2);
`, 1, 3) `, 1, 3)
} }
@ -38,15 +39,15 @@ func TestInsertDate(t *testing.T) {
date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC)
assertStatementSql(t, table1.INSERT(table1ColTime).VALUES(date), ` assertStatementSql(t, table1.INSERT(table1ColTime).VALUES(date), `
INSERT INTO db.table1 (col_time) VALUES INSERT INTO db.table1 (col_time)
($1); VALUES ($1);
`, date) `, date)
} }
func TestInsertMultipleValues(t *testing.T) { func TestInsertMultipleValues(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1ColBool).VALUES(1, 2, 3), `
INSERT INTO db.table1 (col1, col_float, col3) VALUES INSERT INTO db.table1 (col1, col_float, col_bool)
($1, $2, $3); VALUES ($1, $2, $3);
`, 1, 2, 3) `, 1, 2, 3)
} }
@ -57,8 +58,8 @@ func TestInsertMultipleRows(t *testing.T) {
VALUES(111, 222) VALUES(111, 222)
assertStatementSql(t, stmt, ` assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float) VALUES INSERT INTO db.table1 (col1, col_float)
($1, $2), VALUES ($1, $2),
($3, $4), ($3, $4),
($5, $6); ($5, $6);
`, 1, 2, 11, 22, 111, 222) `, 1, 2, 11, 22, 111, 222)
@ -82,18 +83,18 @@ func TestInsertValuesFromModel(t *testing.T) {
MODEL(&toInsert) MODEL(&toInsert)
expectedSQL := ` expectedSQL := `
INSERT INTO db.table1 (col1, col_float) VALUES INSERT INTO db.table1 (col1, col_float)
($1, $2), VALUES ($1, $2),
($3, $4); ($3, $4);
` `
assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) assertStatementSql(t, stmt, expectedSQL, 1, float64(1.11), 1, float64(1.11))
} }
func TestInsertValuesFromModelColumnMismatch(t *testing.T) { func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
defer func() { defer func() {
r := recover() r := recover()
assert.Equal(t, r, "missing struct field for column : col1") require.Equal(t, r, "missing struct field for column : col1")
}() }()
type Table1Model struct { type Table1Model struct {
Col1Prim int Col1Prim int
@ -114,7 +115,7 @@ func TestInsertFromNonStructModel(t *testing.T) {
defer func() { defer func() {
r := recover() r := recover()
assert.Equal(t, r, "jet: data has to be a struct") require.Equal(t, r, "jet: data has to be a struct")
}() }()
table2.INSERT(table2ColInt).MODEL([]int{}) table2.INSERT(table2ColInt).MODEL([]int{})
@ -139,9 +140,63 @@ func TestInsertDefaultValue(t *testing.T) {
VALUES(DEFAULT, "two") VALUES(DEFAULT, "two")
var expectedSQL = ` var expectedSQL = `
INSERT INTO db.table1 (col1, col_float) VALUES INSERT INTO db.table1 (col1, col_float)
(DEFAULT, $1); VALUES (DEFAULT, $1);
` `
assertStatementSql(t, stmt, expectedSQL, "two") assertStatementSql(t, stmt, expectedSQL, "two")
} }
func TestInsert_ON_CONFLICT(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColBool).
VALUES("one", "two").
VALUES("1", "2").
VALUES("theta", "beta").
ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE(
SET(table1ColBool.SET(Bool(true)),
table2ColInt.SET(Int(1)),
ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))),
).WHERE(table1Col1.GT(Int(2))),
).
RETURNING(table1Col1, table1ColBool)
assertDebugStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_bool)
VALUES ('one', 'two'),
('1', '2'),
('theta', 'beta')
ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE
SET col_bool = TRUE,
col_int = 1,
(col1, col_bool) = ROW(2, 'two')
WHERE table1.col1 > 2
RETURNING table1.col1 AS "table1.col1",
table1.col_bool AS "table1.col_bool";
`)
}
func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColBool).
VALUES("one", "two").
VALUES("1", "2").
ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE(
SET(table1ColBool.SET(Bool(false)),
table2ColInt.SET(Int(1)),
ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))),
).WHERE(table1Col1.GT(Int(2))),
).
RETURNING(table1Col1, table1ColBool)
assertDebugStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_bool)
VALUES ('one', 'two'),
('1', '2')
ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE
SET col_bool = FALSE,
col_int = 1,
(col1, col_bool) = ROW(2, 'two')
WHERE table1.col1 > 2
RETURNING table1.col1 AS "table1.col1",
table1.col_bool AS "table1.col_bool";
`)
}

View file

@ -75,7 +75,7 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta
&newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy, &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy,
&newSelect.Limit, &newSelect.Offset, &newSelect.For) &newSelect.Limit, &newSelect.Offset, &newSelect.For)
newSelect.Select.Projections = projections newSelect.Select.ProjectionList = projections
newSelect.From.Table = table newSelect.From.Table = table
newSelect.Limit.Count = -1 newSelect.Limit.Count = -1
newSelect.Offset.Count = -1 newSelect.Offset.Count = -1

View file

@ -13,7 +13,7 @@ type selectTableImpl struct {
readableTableInterfaceImpl readableTableInterfaceImpl
} }
func newSelectTable(selectStmt jet.StatementWithProjections, alias string) SelectTable { func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable {
subQuery := &selectTableImpl{ subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias), SelectTable: jet.NewSelectTable(selectStmt, alias),
} }

View file

@ -4,37 +4,37 @@ import "github.com/go-jet/jet/internal/jet"
// UNION effectively appends the result of sub-queries(select statements) into single query. // UNION effectively appends the result of sub-queries(select statements) into single query.
// It eliminates duplicate rows from its result. // It eliminates duplicate rows from its result.
func UNION(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement { func UNION(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...)) return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...))
} }
// UNION_ALL effectively appends the result of sub-queries(select statements) into single query. // UNION_ALL effectively appends the result of sub-queries(select statements) into single query.
// It does not eliminates duplicate rows from its result. // It does not eliminates duplicate rows from its result.
func UNION_ALL(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement { func UNION_ALL(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...)) return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...))
} }
// INTERSECT returns all rows that are in query results. // INTERSECT returns all rows that are in query results.
// It eliminates duplicate rows from its result. // It eliminates duplicate rows from its result.
func INTERSECT(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement { func INTERSECT(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(intersect, false, toSelectList(lhs, rhs, selects...)) return newSetStatementImpl(intersect, false, toSelectList(lhs, rhs, selects...))
} }
// INTERSECT_ALL returns all rows that are in query results. // INTERSECT_ALL returns all rows that are in query results.
// It does not eliminates duplicate rows from its result. // It does not eliminates duplicate rows from its result.
func INTERSECT_ALL(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement { func INTERSECT_ALL(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(intersect, true, toSelectList(lhs, rhs, selects...)) return newSetStatementImpl(intersect, true, toSelectList(lhs, rhs, selects...))
} }
// EXCEPT returns all rows that are in the result of query lhs but not in the result of query rhs. // EXCEPT returns all rows that are in the result of query lhs but not in the result of query rhs.
// It eliminates duplicate rows from its result. // It eliminates duplicate rows from its result.
func EXCEPT(lhs, rhs jet.StatementWithProjections) setStatement { func EXCEPT(lhs, rhs jet.SerializerStatement) setStatement {
return newSetStatementImpl(except, false, toSelectList(lhs, rhs)) return newSetStatementImpl(except, false, toSelectList(lhs, rhs))
} }
// EXCEPT_ALL returns all rows that are in the result of query lhs but not in the result of query rhs. // EXCEPT_ALL returns all rows that are in the result of query lhs but not in the result of query rhs.
// It does not eliminates duplicate rows from its result. // It does not eliminates duplicate rows from its result.
func EXCEPT_ALL(lhs, rhs jet.StatementWithProjections) setStatement { func EXCEPT_ALL(lhs, rhs jet.SerializerStatement) setStatement {
return newSetStatementImpl(except, true, toSelectList(lhs, rhs)) return newSetStatementImpl(except, true, toSelectList(lhs, rhs))
} }
@ -98,7 +98,7 @@ type setStatementImpl struct {
setOperator jet.ClauseSetStmtOperator setOperator jet.ClauseSetStmtOperator
} }
func newSetStatementImpl(operator string, all bool, selects []jet.StatementWithProjections) setStatement { func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStatement) setStatement {
newSetStatement := &setStatementImpl{} newSetStatement := &setStatementImpl{}
newSetStatement.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SetStatementType, newSetStatement, newSetStatement.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SetStatementType, newSetStatement,
&newSetStatement.setOperator) &newSetStatement.setOperator)
@ -139,6 +139,6 @@ const (
except = "EXCEPT" except = "EXCEPT"
) )
func toSelectList(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) []jet.StatementWithProjections { func toSelectList(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) []jet.SerializerStatement {
return append([]jet.StatementWithProjections{lhs, rhs}, selects...) return append([]jet.SerializerStatement{lhs, rhs}, selects...)
} }

View file

@ -31,7 +31,7 @@ type readableTable interface {
type writableTable interface { type writableTable interface {
INSERT(columns ...jet.Column) InsertStatement INSERT(columns ...jet.Column) InsertStatement
UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement UPDATE(columns ...jet.Column) UpdateStatement
DELETE() DeleteStatement DELETE() DeleteStatement
LOCK() LockStatement LOCK() LockStatement
} }
@ -54,30 +54,30 @@ type readableTableInterfaceImpl struct {
} }
// Generates a select query on the current tableName. // Generates a select query on the current tableName.
func (r *readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement {
return newSelectStatement(r.parent, append([]Projection{projection1}, projections...)) return newSelectStatement(r.parent, append([]Projection{projection1}, projections...))
} }
// Creates a inner join tableName Expression using onCondition. // Creates a inner join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { func (r readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, jet.InnerJoin, onCondition) return newJoinTable(r.parent, table, jet.InnerJoin, onCondition)
} }
// Creates a left join tableName Expression using onCondition. // Creates a left join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { func (r readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, jet.LeftJoin, onCondition) return newJoinTable(r.parent, table, jet.LeftJoin, onCondition)
} }
// Creates a right join tableName Expression using onCondition. // Creates a right join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { func (r readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, jet.RightJoin, onCondition) return newJoinTable(r.parent, table, jet.RightJoin, onCondition)
} }
func (r *readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { func (r readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, jet.FullJoin, onCondition) return newJoinTable(r.parent, table, jet.FullJoin, onCondition)
} }
func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) ReadableTable { func (r readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) ReadableTable {
return newJoinTable(r.parent, table, jet.CrossJoin, nil) return newJoinTable(r.parent, table, jet.CrossJoin, nil)
} }
@ -89,8 +89,8 @@ func (w *writableTableInterfaceImpl) INSERT(columns ...jet.Column) InsertStateme
return newInsertStatement(w.parent, jet.UnwidColumnList(columns)) return newInsertStatement(w.parent, jet.UnwidColumnList(columns))
} }
func (w *writableTableInterfaceImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement { func (w *writableTableInterfaceImpl) UPDATE(columns ...jet.Column) UpdateStatement {
return newUpdateStatement(w.parent, jet.UnwindColumns(column, columns...)) return newUpdateStatement(w.parent, jet.UnwidColumnList(columns))
} }
func (w *writableTableInterfaceImpl) DELETE() DeleteStatement { func (w *writableTableInterfaceImpl) DELETE() DeleteStatement {
@ -109,10 +109,10 @@ type tableImpl struct {
} }
// NewTable creates new table with schema Name, table Name and list of columns // NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table { func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table {
t := &tableImpl{ t := &tableImpl{
SerializerTable: jet.NewTable(schemaName, name, column, columns...), SerializerTable: jet.NewTable(schemaName, name, columns...),
} }
t.readableTableInterfaceImpl.parent = t t.readableTableInterfaceImpl.parent = t

View file

@ -5,9 +5,9 @@ import (
) )
func TestJoinNilInputs(t *testing.T) { func TestJoinNilInputs(t *testing.T) {
assertClauseSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)), assertSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)),
"jet: right hand side of join operation is nil table") "jet: right hand side of join operation is nil table")
assertClauseSerializeErr(t, table2.INNER_JOIN(table1, nil), assertSerializeErr(t, table2.INNER_JOIN(table1, nil),
"jet: join condition is nil") "jet: join condition is nil")
} }

View file

@ -10,3 +10,12 @@ type Projection = jet.Projection
// ProjectionList can be used to create conditional constructed projection list. // ProjectionList can be used to create conditional constructed projection list.
type ProjectionList = jet.ProjectionList type ProjectionList = jet.ProjectionList
// ColumnAssigment is interface wrapper around column assigment
type ColumnAssigment = jet.ColumnAssigment
// PrintableStatement is a statement which sql query can be logged
type PrintableStatement = jet.PrintableStatement
// SetLogger sets automatic statement logging
var SetLogger = jet.SetLoggerFunc

View file

@ -6,7 +6,7 @@ import (
// UpdateStatement is interface of SQL UPDATE statement // UpdateStatement is interface of SQL UPDATE statement
type UpdateStatement interface { type UpdateStatement interface {
Statement jet.SerializerStatement
SET(value interface{}, values ...interface{}) UpdateStatement SET(value interface{}, values ...interface{}) UpdateStatement
MODEL(data interface{}) UpdateStatement MODEL(data interface{}) UpdateStatement
@ -20,14 +20,19 @@ type updateStatementImpl struct {
Update jet.ClauseUpdate Update jet.ClauseUpdate
Set clauseSet Set clauseSet
SetNew jet.SetClauseNew
Where jet.ClauseWhere Where jet.ClauseWhere
Returning clauseReturning Returning clauseReturning
} }
func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStatement { func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStatement {
update := &updateStatementImpl{} update := &updateStatementImpl{}
update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update, update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update,
&update.Set, &update.Where, &update.Returning) &update.Update,
&update.Set,
&update.SetNew,
&update.Where,
&update.Returning)
update.Update.Table = table update.Update.Table = table
update.Set.Columns = columns update.Set.Columns = columns
@ -37,7 +42,17 @@ func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStateme
} }
func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement {
columnAssigment, isColumnAssigment := value.(ColumnAssigment)
if isColumnAssigment {
u.SetNew = []ColumnAssigment{columnAssigment}
for _, value := range values {
u.SetNew = append(u.SetNew, value.(ColumnAssigment))
}
} else {
u.Set.Values = jet.UnwindRowFromValues(value, values) u.Set.Values = jet.UnwindRowFromValues(value, values)
}
return u return u
} }
@ -52,7 +67,7 @@ func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
} }
func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateStatement { func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateStatement {
u.Returning.Projections = projections u.Returning.ProjectionList = projections
return u return u
} }
@ -61,7 +76,10 @@ type clauseSet struct {
Values []jet.Serializer Values []jet.Serializer
} }
func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder) { func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(s.Values) == 0 {
return
}
out.NewLine() out.NewLine()
out.WriteString("SET") out.WriteString("SET")

View file

@ -10,7 +10,6 @@ import (
var table1Col1 = IntegerColumn("col1") var table1Col1 = IntegerColumn("col1")
var table1ColInt = IntegerColumn("col_int") var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float") var table1ColFloat = FloatColumn("col_float")
var table1Col3 = IntegerColumn("col3")
var table1ColTime = TimeColumn("col_time") var table1ColTime = TimeColumn("col_time")
var table1ColTimez = TimezColumn("col_timez") var table1ColTimez = TimezColumn("col_timez")
var table1ColTimestamp = TimestampColumn("col_timestamp") var table1ColTimestamp = TimestampColumn("col_timestamp")
@ -25,7 +24,6 @@ var table1 = NewTable(
table1Col1, table1Col1,
table1ColInt, table1ColInt,
table1ColFloat, table1ColFloat,
table1Col3,
table1ColTime, table1ColTime,
table1ColTimez, table1ColTimez,
table1ColBool, table1ColBool,
@ -75,12 +73,16 @@ var table3 = NewTable(
table3ColInt, table3ColInt,
table3StrCol) table3StrCol)
func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { func assertSerialize(t *testing.T, serializer jet.Serializer, query string, args ...interface{}) {
testutils.AssertSerialize(t, Dialect, serializer, query, args...)
}
func assertClauseSerialize(t *testing.T, clause jet.Clause, query string, args ...interface{}) {
testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) testutils.AssertClauseSerialize(t, Dialect, clause, query, args...)
} }
func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) { func assertSerializeErr(t *testing.T, serializer jet.Serializer, errString string) {
testutils.AssertClauseSerializeErr(t, Dialect, clause, errString) testutils.AssertSerializeErr(t, Dialect, serializer, errString)
} }
func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) { func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) {
@ -88,5 +90,6 @@ func assertProjectionSerialize(t *testing.T, projection jet.Projection, query st
} }
var assertStatementSql = testutils.AssertStatementSql var assertStatementSql = testutils.AssertStatementSql
var assertDebugStatementSql = testutils.AssertDebugStatementSql
var assertStatementSqlErr = testutils.AssertStatementSqlErr var assertStatementSqlErr = testutils.AssertStatementSqlErr
var assertPanicErr = testutils.AssertPanicErr var assertPanicErr = testutils.AssertPanicErr

View file

@ -0,0 +1,26 @@
package postgres
import "github.com/go-jet/jet/internal/jet"
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
readableTableInterfaceImpl
jet.CommonTableExpression
}
// WITH function creates new WITH statement from list of common table expressions
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, cte...)
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
cte := CommonTableExpression{
readableTableInterfaceImpl: readableTableInterfaceImpl{},
CommonTableExpression: jet.CTE(name),
}
cte.parent = &cte
return cte
}

View file

@ -2,7 +2,7 @@ package internal
import ( import (
"fmt" "fmt"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"testing" "testing"
"time" "time"
) )
@ -10,138 +10,138 @@ import (
func TestNullByteArray(t *testing.T) { func TestNullByteArray(t *testing.T) {
var array NullByteArray var array NullByteArray
assert.NoError(t, array.Scan(nil)) require.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false) require.Equal(t, array.Valid, false)
assert.NoError(t, array.Scan([]byte("bytea"))) require.NoError(t, array.Scan([]byte("bytea")))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
assert.Equal(t, string(array.ByteArray), string([]byte("bytea"))) require.Equal(t, string(array.ByteArray), string([]byte("bytea")))
assert.Error(t, array.Scan(12), "can't scan []byte from 12") require.Error(t, array.Scan(12), "can't scan []byte from 12")
} }
func TestNullTime(t *testing.T) { func TestNullTime(t *testing.T) {
var array NullTime var array NullTime
assert.NoError(t, array.Scan(nil)) require.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false) require.Equal(t, array.Valid, false)
time := time.Now() time := time.Now()
assert.NoError(t, array.Scan(time)) require.NoError(t, array.Scan(time))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ := array.Value() value, _ := array.Value()
assert.Equal(t, value, time) require.Equal(t, value, time)
assert.NoError(t, array.Scan([]byte("13:10:11"))) require.NoError(t, array.Scan([]byte("13:10:11")))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
assert.NoError(t, array.Scan("13:10:11")) require.NoError(t, array.Scan("13:10:11"))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
assert.Error(t, array.Scan(12), "can't scan time.Time from 12") require.Error(t, array.Scan(12), "can't scan time.Time from 12")
} }
func TestNullInt8(t *testing.T) { func TestNullInt8(t *testing.T) {
var array NullInt8 var array NullInt8
assert.NoError(t, array.Scan(nil)) require.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false) require.Equal(t, array.Valid, false)
assert.NoError(t, array.Scan(int64(11))) require.NoError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ := array.Value() value, _ := array.Value()
assert.Equal(t, value, int8(11)) require.Equal(t, value, int8(11))
assert.Error(t, array.Scan("text"), "can't scan int8 from text") require.Error(t, array.Scan("text"), "can't scan int8 from text")
} }
func TestNullInt16(t *testing.T) { func TestNullInt16(t *testing.T) {
var array NullInt16 var array NullInt16
assert.NoError(t, array.Scan(nil)) require.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false) require.Equal(t, array.Valid, false)
assert.NoError(t, array.Scan(int64(11))) require.NoError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ := array.Value() value, _ := array.Value()
assert.Equal(t, value, int16(11)) require.Equal(t, value, int16(11))
assert.NoError(t, array.Scan(int16(20))) require.NoError(t, array.Scan(int16(20)))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int16(20)) require.Equal(t, value, int16(20))
assert.NoError(t, array.Scan(int8(30))) require.NoError(t, array.Scan(int8(30)))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int16(30)) require.Equal(t, value, int16(30))
assert.NoError(t, array.Scan(uint8(30))) require.NoError(t, array.Scan(uint8(30)))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int16(30)) require.Equal(t, value, int16(30))
assert.Error(t, array.Scan("text"), "can't scan int16 from text") require.Error(t, array.Scan("text"), "can't scan int16 from text")
} }
func TestNullInt32(t *testing.T) { func TestNullInt32(t *testing.T) {
var array NullInt32 var array NullInt32
assert.NoError(t, array.Scan(nil)) require.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false) require.Equal(t, array.Valid, false)
assert.NoError(t, array.Scan(int64(11))) require.NoError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ := array.Value() value, _ := array.Value()
assert.Equal(t, value, int32(11)) require.Equal(t, value, int32(11))
assert.NoError(t, array.Scan(int32(32))) require.NoError(t, array.Scan(int32(32)))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int32(32)) require.Equal(t, value, int32(32))
assert.NoError(t, array.Scan(int16(20))) require.NoError(t, array.Scan(int16(20)))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int32(20)) require.Equal(t, value, int32(20))
assert.NoError(t, array.Scan(uint16(16))) require.NoError(t, array.Scan(uint16(16)))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int32(16)) require.Equal(t, value, int32(16))
assert.NoError(t, array.Scan(int8(30))) require.NoError(t, array.Scan(int8(30)))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int32(30)) require.Equal(t, value, int32(30))
assert.NoError(t, array.Scan(uint8(30))) require.NoError(t, array.Scan(uint8(30)))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int32(30)) require.Equal(t, value, int32(30))
assert.Error(t, array.Scan("text"), "can't scan int32 from text") require.Error(t, array.Scan("text"), "can't scan int32 from text")
} }
func TestNullFloat32(t *testing.T) { func TestNullFloat32(t *testing.T) {
var array NullFloat32 var array NullFloat32
assert.NoError(t, array.Scan(nil)) require.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false) require.Equal(t, array.Valid, false)
assert.NoError(t, array.Scan(float64(64))) require.NoError(t, array.Scan(float64(64)))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ := array.Value() value, _ := array.Value()
assert.Equal(t, value, float32(64)) require.Equal(t, value, float32(64))
assert.NoError(t, array.Scan(float32(32))) require.NoError(t, array.Scan(float32(32)))
assert.Equal(t, array.Valid, true) require.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, float32(32)) require.Equal(t, value, float32(32))
assert.Error(t, array.Scan(12), "can't scan float32 from 12") require.Error(t, array.Scan(12), "can't scan float32 from 12")
} }

View file

@ -2,36 +2,36 @@ package qrm
import ( import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"reflect" "reflect"
"testing" "testing"
"time" "time"
) )
func TestIsSimpleModelType(t *testing.T) { func TestIsSimpleModelType(t *testing.T) {
assert.True(t, isSimpleModelType(reflect.TypeOf(int8(11)))) require.True(t, isSimpleModelType(reflect.TypeOf(int8(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(int16(11)))) require.True(t, isSimpleModelType(reflect.TypeOf(int16(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(int32(11)))) require.True(t, isSimpleModelType(reflect.TypeOf(int32(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(int64(11)))) require.True(t, isSimpleModelType(reflect.TypeOf(int64(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(uint8(11)))) require.True(t, isSimpleModelType(reflect.TypeOf(uint8(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(uint16(11)))) require.True(t, isSimpleModelType(reflect.TypeOf(uint16(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(uint32(11)))) require.True(t, isSimpleModelType(reflect.TypeOf(uint32(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(uint64(11)))) require.True(t, isSimpleModelType(reflect.TypeOf(uint64(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(float32(123.46)))) require.True(t, isSimpleModelType(reflect.TypeOf(float32(123.46))))
assert.True(t, isSimpleModelType(reflect.TypeOf(float64(123.46)))) require.True(t, isSimpleModelType(reflect.TypeOf(float64(123.46))))
assert.True(t, isSimpleModelType(reflect.TypeOf([]byte("Text")))) require.True(t, isSimpleModelType(reflect.TypeOf([]byte("Text"))))
assert.True(t, isSimpleModelType(reflect.TypeOf(time.Now()))) require.True(t, isSimpleModelType(reflect.TypeOf(time.Now())))
assert.True(t, isSimpleModelType(reflect.TypeOf(uuid.New()))) require.True(t, isSimpleModelType(reflect.TypeOf(uuid.New())))
complexModelType := struct { complexModelType := struct {
Field1 string Field1 string
Field2 string Field2 string
}{} }{}
assert.Equal(t, isSimpleModelType(reflect.TypeOf(complexModelType)), false) require.Equal(t, isSimpleModelType(reflect.TypeOf(complexModelType)), false)
assert.Equal(t, isSimpleModelType(reflect.TypeOf(&complexModelType)), false) require.Equal(t, isSimpleModelType(reflect.TypeOf(&complexModelType)), false)
assert.Equal(t, isSimpleModelType(reflect.TypeOf([]string{"str"})), false) require.Equal(t, isSimpleModelType(reflect.TypeOf([]string{"str"})), false)
assert.Equal(t, isSimpleModelType(reflect.TypeOf([]int{1, 2})), false) require.Equal(t, isSimpleModelType(reflect.TypeOf([]int{1, 2})), false)
} }

View file

@ -1,6 +1,9 @@
package mysql package mysql
import ( import (
"fmt"
"github.com/stretchr/testify/require"
"strings"
"testing" "testing"
"time" "time"
@ -13,8 +16,6 @@ import (
"github.com/go-jet/jet/tests/testdata/results/common" "github.com/go-jet/jet/tests/testdata/results/common"
. "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/mysql"
"github.com/stretchr/testify/assert"
) )
func TestAllTypes(t *testing.T) { func TestAllTypes(t *testing.T) {
@ -26,9 +27,9 @@ func TestAllTypes(t *testing.T) {
LIMIT(2). LIMIT(2).
Query(db, &dest) Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(dest), 2) require.Equal(t, len(dest), 2)
if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert
return return
@ -45,8 +46,8 @@ func TestAllTypesViewSelect(t *testing.T) {
dest := []AllTypesView{} dest := []AllTypesView{}
err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(dest), 2) require.Equal(t, len(dest), 2)
if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert
return return
@ -74,11 +75,12 @@ func TestUUID(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
assert.True(t, dest.StrUUID != nil) require.True(t, dest.StrUUID != nil)
assert.True(t, dest.UUID.String() != uuid.UUID{}.String()) require.True(t, dest.UUID.String() != uuid.UUID{}.String())
assert.True(t, dest.StrUUID.String() != uuid.UUID{}.String()) require.True(t, dest.StrUUID.String() != uuid.UUID{}.String())
assert.Equal(t, dest.StrUUID.String(), dest.BinUUID.String()) require.Equal(t, dest.StrUUID.String(), dest.BinUUID.String())
requireLogged(t, query)
} }
func TestExpressionOperators(t *testing.T) { func TestExpressionOperators(t *testing.T) {
@ -95,23 +97,23 @@ func TestExpressionOperators(t *testing.T) {
//fmt.Println(query.Sql()) //fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, ` testutils.AssertStatementSql(t, query, strings.Replace(`
SELECT all_types.integer IS NULL AS "result.is_null", SELECT all_types.'integer' IS NULL AS "result.is_null",
all_types.date_ptr IS NOT NULL AS "result.is_not_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null",
(all_types.small_int_ptr IN (?, ?)) AS "result.in", (all_types.small_int_ptr IN (?, ?)) AS "result.in",
(all_types.small_int_ptr IN (( (all_types.small_int_ptr IN ((
SELECT all_types.integer AS "all_types.integer" SELECT all_types.'integer' AS "all_types.integer"
FROM test_sample.all_types FROM test_sample.all_types
))) AS "result.in_select", ))) AS "result.in_select",
(all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in",
(all_types.small_int_ptr NOT IN (( (all_types.small_int_ptr NOT IN ((
SELECT all_types.integer AS "all_types.integer" SELECT all_types.'integer' AS "all_types.integer"
FROM test_sample.all_types FROM test_sample.all_types
))) AS "result.not_in_select", ))) AS "result.not_in_select",
DATABASE() DATABASE()
FROM test_sample.all_types FROM test_sample.all_types
LIMIT ?; LIMIT ?;
`, int64(11), int64(22), int64(11), int64(22), int64(2)) `, "'", "`", -1), int64(11), int64(22), int64(11), int64(22), int64(2))
var dest []struct { var dest []struct {
common.ExpressionTestResult `alias:"result.*"` common.ExpressionTestResult `alias:"result.*"`
@ -119,7 +121,7 @@ LIMIT ?;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -210,7 +212,7 @@ FROM test_sample.all_types;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json") testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json")
} }
@ -261,45 +263,47 @@ func TestFloatOperators(t *testing.T) {
queryStr, _ := query.Sql() queryStr, _ := query.Sql()
assert.Equal(t, queryStr, ` //fmt.Println(queryStr)
SELECT (all_types.numeric = all_types.numeric) AS "eq1",
(all_types.decimal = ?) AS "eq2", require.Equal(t, queryStr, strings.Replace(`
(all_types.real = ?) AS "eq3", SELECT (all_types.'numeric' = all_types.'numeric') AS "eq1",
(NOT(all_types.numeric <=> all_types.numeric)) AS "distinct1", (all_types.'decimal' = ?) AS "eq2",
(NOT(all_types.decimal <=> ?)) AS "distinct2", (all_types.'real' = ?) AS "eq3",
(NOT(all_types.real <=> ?)) AS "distinct3", (NOT(all_types.'numeric' <=> all_types.'numeric')) AS "distinct1",
(all_types.numeric <=> all_types.numeric) AS "not_distinct1", (NOT(all_types.'decimal' <=> ?)) AS "distinct2",
(all_types.decimal <=> ?) AS "not_distinct2", (NOT(all_types.'real' <=> ?)) AS "distinct3",
(all_types.real <=> ?) AS "not_distinct3", (all_types.'numeric' <=> all_types.'numeric') AS "not_distinct1",
(all_types.numeric < ?) AS "lt1", (all_types.'decimal' <=> ?) AS "not_distinct2",
(all_types.numeric < ?) AS "lt2", (all_types.'real' <=> ?) AS "not_distinct3",
(all_types.numeric > ?) AS "gt1", (all_types.'numeric' < ?) AS "lt1",
(all_types.numeric > ?) AS "gt2", (all_types.'numeric' < ?) AS "lt2",
TRUNCATE((all_types.decimal + all_types.decimal), ?) AS "add1", (all_types.'numeric' > ?) AS "gt1",
TRUNCATE((all_types.decimal + ?), ?) AS "add2", (all_types.'numeric' > ?) AS "gt2",
TRUNCATE((all_types.decimal - all_types.decimal_ptr), ?) AS "sub1", TRUNCATE((all_types.'decimal' + all_types.'decimal'), ?) AS "add1",
TRUNCATE((all_types.decimal - ?), ?) AS "sub2", TRUNCATE((all_types.'decimal' + ?), ?) AS "add2",
TRUNCATE((all_types.decimal * all_types.decimal_ptr), ?) AS "mul1", TRUNCATE((all_types.'decimal' - all_types.decimal_ptr), ?) AS "sub1",
TRUNCATE((all_types.decimal * ?), ?) AS "mul2", TRUNCATE((all_types.'decimal' - ?), ?) AS "sub2",
TRUNCATE((all_types.decimal / all_types.decimal_ptr), ?) AS "div1", TRUNCATE((all_types.'decimal' * all_types.decimal_ptr), ?) AS "mul1",
TRUNCATE((all_types.decimal / ?), ?) AS "div2", TRUNCATE((all_types.'decimal' * ?), ?) AS "mul2",
TRUNCATE((all_types.decimal % all_types.decimal_ptr), ?) AS "mod1", TRUNCATE((all_types.'decimal' / all_types.decimal_ptr), ?) AS "div1",
TRUNCATE((all_types.decimal % ?), ?) AS "mod2", TRUNCATE((all_types.'decimal' / ?), ?) AS "div2",
TRUNCATE(POW(all_types.decimal, all_types.decimal_ptr), ?) AS "pow1", TRUNCATE((all_types.'decimal' % all_types.decimal_ptr), ?) AS "mod1",
TRUNCATE(POW(all_types.decimal, ?), ?) AS "pow2", TRUNCATE((all_types.'decimal' % ?), ?) AS "mod2",
TRUNCATE(ABS(all_types.decimal), ?) AS "abs", TRUNCATE(POW(all_types.'decimal', all_types.decimal_ptr), ?) AS "pow1",
TRUNCATE(POWER(all_types.decimal, ?), ?) AS "power", TRUNCATE(POW(all_types.'decimal', ?), ?) AS "pow2",
TRUNCATE(SQRT(all_types.decimal), ?) AS "sqrt", TRUNCATE(ABS(all_types.'decimal'), ?) AS "abs",
TRUNCATE(POWER(all_types.decimal, (? / ?)), ?) AS "cbrt", TRUNCATE(POWER(all_types.'decimal', ?), ?) AS "power",
CEIL(all_types.real) AS "ceil", TRUNCATE(SQRT(all_types.'decimal'), ?) AS "sqrt",
FLOOR(all_types.real) AS "floor", TRUNCATE(POWER(all_types.'decimal', (? / ?)), ?) AS "cbrt",
ROUND(all_types.decimal) AS "round1", CEIL(all_types.'real') AS "ceil",
ROUND(all_types.decimal, ?) AS "round2", FLOOR(all_types.'real') AS "floor",
SIGN(all_types.real) AS "sign", ROUND(all_types.'decimal') AS "round1",
TRUNCATE(all_types.decimal, ?) AS "trunc" ROUND(all_types.'decimal', ?) AS "round2",
SIGN(all_types.'real') AS "sign",
TRUNCATE(all_types.'decimal', ?) AS "trunc"
FROM test_sample.all_types FROM test_sample.all_types
LIMIT ?; LIMIT ?;
`) `, "'", "`", -1))
var dest []struct { var dest []struct {
common.FloatExpressionTestResult `alias:"."` common.FloatExpressionTestResult `alias:"."`
@ -307,7 +311,7 @@ LIMIT ?;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
testutils.AssertJSONFile(t, dest, "./testdata/results/common/float_operators.json") testutils.AssertJSONFile(t, dest, "./testdata/results/common/float_operators.json")
} }
@ -444,7 +448,7 @@ LIMIT ?;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -516,7 +520,7 @@ func TestStringOperators(t *testing.T) {
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC) var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
@ -568,7 +572,7 @@ func TestTimeExpressions(t *testing.T) {
//fmt.Println(query.DebugSql()) //fmt.Println(query.DebugSql())
testutils.AssertDebugStatementSql(t, query, ` testutils.AssertDebugStatementSql(t, query, strings.Replace(`
SELECT CAST('20:34:58' AS TIME), SELECT CAST('20:34:58' AS TIME),
all_types.time = all_types.time, all_types.time = all_types.time,
all_types.time = CAST('23:06:06' AS TIME), all_types.time = CAST('23:06:06' AS TIME),
@ -589,7 +593,7 @@ SELECT CAST('20:34:58' AS TIME),
all_types.time >= all_types.time, all_types.time >= all_types.time,
all_types.time >= CAST('14:26:36' AS TIME), all_types.time >= CAST('14:26:36' AS TIME),
all_types.time + INTERVAL 10 MINUTE, all_types.time + INTERVAL 10 MINUTE,
all_types.time + INTERVAL all_types.integer MINUTE, all_types.time + INTERVAL all_types.''integer'' MINUTE,
all_types.time + INTERVAL 3 HOUR, all_types.time + INTERVAL 3 HOUR,
all_types.time - INTERVAL 20 MINUTE, all_types.time - INTERVAL 20 MINUTE,
all_types.time - INTERVAL all_types.small_int MINUTE, all_types.time - INTERVAL all_types.small_int MINUTE,
@ -598,13 +602,13 @@ SELECT CAST('20:34:58' AS TIME),
CURRENT_TIME, CURRENT_TIME,
CURRENT_TIME(3) CURRENT_TIME(3)
FROM test_sample.all_types; FROM test_sample.all_types;
`, "20:34:58", "23:06:06", "22:06:06.011", "21:06:06.011111", "20:16:06", `, "''", "`", -1), "20:34:58", "23:06:06", "22:06:06.011", "21:06:06.011111", "20:16:06",
"19:26:06", "18:36:06", "17:46:06", "16:56:56", "15:16:46", "14:26:36") "19:26:06", "18:36:06", "17:46:06", "16:56:56", "15:16:46", "14:26:36")
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestDateExpressions(t *testing.T) { func TestDateExpressions(t *testing.T) {
@ -648,25 +652,25 @@ func TestDateExpressions(t *testing.T) {
//fmt.Println(query.DebugSql()) //fmt.Println(query.DebugSql())
testutils.AssertStatementSql(t, query, ` testutils.AssertDebugStatementSql(t, query, `
SELECT CAST(? AS DATE), SELECT CAST('2009-11-17' AS DATE),
all_types.date = all_types.date, all_types.date = all_types.date,
all_types.date = CAST(? AS DATE), all_types.date = CAST('2019-06-06' AS DATE),
all_types.date_ptr != all_types.date, all_types.date_ptr != all_types.date,
all_types.date_ptr != CAST(? AS DATE), all_types.date_ptr != CAST('2019-01-06' AS DATE),
NOT(all_types.date <=> all_types.date), NOT(all_types.date <=> all_types.date),
NOT(all_types.date <=> CAST(? AS DATE)), NOT(all_types.date <=> CAST('2019-02-06' AS DATE)),
all_types.date <=> all_types.date, all_types.date <=> all_types.date,
all_types.date <=> CAST(? AS DATE), all_types.date <=> CAST('2019-03-06' AS DATE),
all_types.date < all_types.date, all_types.date < all_types.date,
all_types.date < CAST(? AS DATE), all_types.date < CAST('2019-04-06' AS DATE),
all_types.date <= all_types.date, all_types.date <= all_types.date,
all_types.date <= CAST(? AS DATE), all_types.date <= CAST('2019-05-05' AS DATE),
all_types.date > all_types.date, all_types.date > all_types.date,
all_types.date > CAST(? AS DATE), all_types.date > CAST('2019-01-04' AS DATE),
all_types.date >= all_types.date, all_types.date >= all_types.date,
all_types.date >= CAST(? AS DATE), all_types.date >= CAST('2019-02-03' AS DATE),
all_types.date + INTERVAL ? MINUTE_MICROSECOND, all_types.date + INTERVAL '10:20.000100' MINUTE_MICROSECOND,
all_types.date + INTERVAL all_types.big_int MINUTE, all_types.date + INTERVAL all_types.big_int MINUTE,
all_types.date + INTERVAL 15 HOUR, all_types.date + INTERVAL 15 HOUR,
all_types.date - INTERVAL 20 MINUTE, all_types.date - INTERVAL 20 MINUTE,
@ -679,7 +683,7 @@ FROM test_sample.all_types;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestDateTimeExpressions(t *testing.T) { func TestDateTimeExpressions(t *testing.T) {
@ -756,7 +760,7 @@ FROM test_sample.all_types;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestTimestampExpressions(t *testing.T) { func TestTimestampExpressions(t *testing.T) {
@ -832,13 +836,13 @@ FROM test_sample.all_types;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestTimeLiterals(t *testing.T) { func TestTimeLiterals(t *testing.T) {
loc, err := time.LoadLocation("Europe/Berlin") loc, err := time.LoadLocation("Europe/Berlin")
assert.NoError(t, err) require.NoError(t, err)
var timeT = time.Date(2009, 11, 17, 20, 34, 58, 351387237, loc) var timeT = time.Date(2009, 11, 17, 20, 34, 58, 351387237, loc)
@ -877,7 +881,7 @@ LIMIT ?;
} }
err = query.Query(db, &dest) err = query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -960,7 +964,139 @@ func TestINTERVAL(t *testing.T) {
//fmt.Println(query.DebugSql()) //fmt.Println(query.DebugSql())
err := query.Query(db, &struct{}{}) err := query.Query(db, &struct{}{})
assert.NoError(t, err) require.NoError(t, err)
}
func TestAllTypesInsert(t *testing.T) {
tx, err := db.Begin()
require.NoError(t, err)
stmt := AllTypes.INSERT(AllTypes.AllColumns).
MODEL(toInsert)
fmt.Println(stmt.DebugSql())
testutils.AssertExec(t, stmt, tx, 1)
var dest model.AllTypes
err = AllTypes.SELECT(AllTypes.AllColumns).
WHERE(AllTypes.BigInt.EQ(Int(toInsert.BigInt))).
Query(tx, &dest)
require.NoError(t, err)
require.Equal(t, toInsert.TinyInt, dest.TinyInt)
err = tx.Rollback()
require.NoError(t, err)
}
func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) {
tx, err := db.Begin()
require.NoError(t, err)
toInsert := model.AllTypes{
Boolean: true,
Integer: 124,
Float: 45.67,
Blob: []byte("blob"),
Text: "text",
JSON: "{}",
Time: time.Now(),
Timestamp: time.Now(),
Date: time.Now(),
}
stmt := AllTypes.INSERT(
AllTypes.Boolean,
AllTypes.Integer,
AllTypes.Float,
AllTypes.Blob,
AllTypes.Text,
AllTypes.JSON,
AllTypes.Time,
AllTypes.Timestamp,
AllTypes.Date,
).
MODEL(toInsert).
ON_DUPLICATE_KEY_UPDATE(
AllTypes.Boolean.SET(Bool(false)),
AllTypes.Integer.SET(Int(4)),
AllTypes.Float.SET(Float(0.67)),
AllTypes.Text.SET(String("new text")),
AllTypes.Time.SET(TimeT(time.Now())),
AllTypes.Timestamp.SET(TimestampT(time.Now())),
AllTypes.Date.SET(DateT(time.Now())),
)
fmt.Println(stmt.DebugSql())
_, err = stmt.Exec(tx)
require.NoError(t, err)
err = tx.Rollback()
require.NoError(t, err)
}
var toInsert = model.AllTypes{
Boolean: false,
BooleanPtr: testutils.BoolPtr(true),
TinyInt: 1,
UTinyInt: 2,
SmallInt: 3,
USmallInt: 4,
MediumInt: 5,
UMediumInt: 6,
Integer: 7,
UInteger: 8,
BigInt: 9,
UBigInt: 1122334455,
TinyIntPtr: testutils.Int8Ptr(11),
UTinyIntPtr: testutils.UInt8Ptr(22),
SmallIntPtr: testutils.Int16Ptr(33),
USmallIntPtr: testutils.UInt16Ptr(44),
MediumIntPtr: testutils.Int32Ptr(55),
UMediumIntPtr: testutils.UInt32Ptr(66),
IntegerPtr: testutils.Int32Ptr(77),
UIntegerPtr: testutils.UInt32Ptr(88),
BigIntPtr: testutils.Int64Ptr(99),
UBigIntPtr: testutils.UInt64Ptr(111),
Decimal: 11.22,
DecimalPtr: testutils.Float64Ptr(33.44),
Numeric: 55.66,
NumericPtr: testutils.Float64Ptr(77.88),
Float: 99.00,
FloatPtr: testutils.Float64Ptr(11.22),
Double: 33.44,
DoublePtr: testutils.Float64Ptr(55.66),
Real: 77.88,
RealPtr: testutils.Float64Ptr(99.00),
Bit: "1",
BitPtr: testutils.StringPtr("0"),
Time: time.Date(0, 0, 0, 10, 11, 12, 100, &time.Location{}),
TimePtr: testutils.TimePtr(time.Date(0, 0, 0, 10, 11, 12, 100, time.UTC)),
Date: time.Now(),
DatePtr: testutils.TimePtr(time.Now()),
DateTime: time.Now(),
DateTimePtr: testutils.TimePtr(time.Now()),
Timestamp: time.Now(),
//TimestampPtr: testutils.TimePtr(time.Now()), // TODO: build fails for MariaDB
Year: 2000,
YearPtr: testutils.Int16Ptr(2001),
Char: "abcd",
CharPtr: testutils.StringPtr("absd"),
VarChar: "abcd",
VarCharPtr: testutils.StringPtr("absd"),
Binary: []byte("1010"),
BinaryPtr: testutils.ByteArrayPtr([]byte("100001")),
VarBinary: []byte("1010"),
VarBinaryPtr: testutils.ByteArrayPtr([]byte("100001")),
Blob: []byte("large file"),
BlobPtr: testutils.ByteArrayPtr([]byte("very large file")),
Text: "some text",
TextPtr: testutils.StringPtr("text"),
Enum: model.AllTypesEnum_Value1,
JSON: "{}",
JSONPtr: testutils.StringPtr(`{"a": 1}`),
} }
var allTypesJson = ` var allTypesJson = `
@ -1100,28 +1236,26 @@ func TestReservedWord(t *testing.T) {
stmt := SELECT(User.AllColumns). stmt := SELECT(User.AllColumns).
FROM(User) FROM(User)
// NOTE: A word that follows a period in a qualified name must be an identifier, so it testutils.AssertDebugStatementSql(t, stmt, strings.Replace(`
// need not be quoted even if it is reserved SELECT user.''column'' AS "user.column",
testutils.AssertDebugStatementSql(t, stmt, ` user.''use'' AS "user.use",
SELECT user.column AS "user.column",
user.use AS "user.use",
user.ceil AS "user.ceil", user.ceil AS "user.ceil",
user.commit AS "user.commit", user.commit AS "user.commit",
user.create AS "user.create", user.''create'' AS "user.create",
user.default AS "user.default", user.''default'' AS "user.default",
user.desc AS "user.desc", user.''desc'' AS "user.desc",
user.empty AS "user.empty", user.''empty'' AS "user.empty",
user.float AS "user.float", user.''float'' AS "user.float",
user.join AS "user.join", user.''join'' AS "user.join",
user.like AS "user.like", user.''like'' AS "user.like",
user.max AS "user.max", user.max AS "user.max",
user.rank AS "user.rank" user.''rank'' AS "user.rank"
FROM test_sample.user; FROM test_sample.user;
`) `, "''", "`", -1))
var dest []model.User var dest []model.User
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
testutils.PrintJson(dest) testutils.PrintJson(dest)

View file

@ -4,7 +4,7 @@ import (
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/mysql"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"testing" "testing"
"time" "time"
) )
@ -55,7 +55,7 @@ FROM test_sample.all_types;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
testutils.AssertDeepEqual(t, dest, Result{ testutils.AssertDeepEqual(t, dest, Result{
As1: "test", As1: "test",
@ -68,4 +68,6 @@ FROM test_sample.all_types;
Unsigned: 15, Unsigned: 15,
Binary: "Some text", Binary: "Some text",
}) })
requireLogged(t, query)
} }

View file

@ -6,7 +6,7 @@ import (
. "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/mysql"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"testing" "testing"
"time" "time"
) )
@ -24,6 +24,7 @@ WHERE link.name IN ('Gmail', 'Outlook');
testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook")
testutils.AssertExec(t, deleteStmt, db, 2) testutils.AssertExec(t, deleteStmt, db, 2)
requireLogged(t, deleteStmt)
} }
func TestDeleteWithWhereOrderByLimit(t *testing.T) { func TestDeleteWithWhereOrderByLimit(t *testing.T) {
@ -43,6 +44,7 @@ LIMIT 1;
testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook", int64(1)) testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook", int64(1))
testutils.AssertExec(t, deleteStmt, db, 1) testutils.AssertExec(t, deleteStmt, db, 1)
requireLogged(t, deleteStmt)
} }
func TestDeleteQueryContext(t *testing.T) { func TestDeleteQueryContext(t *testing.T) {
@ -60,7 +62,8 @@ func TestDeleteQueryContext(t *testing.T) {
dest := []model.Link{} dest := []model.Link{}
err := deleteStmt.QueryContext(ctx, db, &dest) err := deleteStmt.QueryContext(ctx, db, &dest)
assert.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")
requireLogged(t, deleteStmt)
} }
func TestDeleteExecContext(t *testing.T) { func TestDeleteExecContext(t *testing.T) {
@ -77,7 +80,7 @@ func TestDeleteExecContext(t *testing.T) {
_, err := deleteStmt.ExecContext(ctx, db) _, err := deleteStmt.ExecContext(ctx, db)
assert.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")
} }
func initForDeleteTest(t *testing.T) { func initForDeleteTest(t *testing.T) {

View file

@ -4,7 +4,7 @@ import (
"github.com/go-jet/jet/generator/mysql" "github.com/go-jet/jet/generator/mysql"
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/tests/dbconfig" "github.com/go-jet/jet/tests/dbconfig"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"io/ioutil" "io/ioutil"
"os" "os"
"os/exec" "os/exec"
@ -25,23 +25,23 @@ func TestGenerator(t *testing.T) {
DBName: "dvds", DBName: "dvds",
}) })
assert.NoError(t, err) require.NoError(t, err)
assertGeneratedFiles(t) assertGeneratedFiles(t)
} }
err := os.RemoveAll(genTestDirRoot) err := os.RemoveAll(genTestDirRoot)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestCmdGenerator(t *testing.T) { func TestCmdGenerator(t *testing.T) {
goInstallJet := exec.Command("sh", "-c", "go install github.com/go-jet/jet/cmd/jet") goInstallJet := exec.Command("sh", "-c", "go install github.com/go-jet/jet/cmd/jet")
goInstallJet.Stderr = os.Stderr goInstallJet.Stderr = os.Stderr
err := goInstallJet.Run() err := goInstallJet.Run()
assert.NoError(t, err) require.NoError(t, err)
err = os.RemoveAll(genTestDir3) err = os.RemoveAll(genTestDir3)
assert.NoError(t, err) require.NoError(t, err)
cmd := exec.Command("jet", "-source=MySQL", "-dbname=dvds", "-host=localhost", "-port=3306", cmd := exec.Command("jet", "-source=MySQL", "-dbname=dvds", "-host=localhost", "-port=3306",
"-user=jet", "-password=jet", "-path="+genTestDir3) "-user=jet", "-password=jet", "-path="+genTestDir3)
@ -50,44 +50,44 @@ func TestCmdGenerator(t *testing.T) {
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
err = cmd.Run() err = cmd.Run()
assert.NoError(t, err) require.NoError(t, err)
assertGeneratedFiles(t) assertGeneratedFiles(t)
err = os.RemoveAll(genTestDirRoot) err = os.RemoveAll(genTestDirRoot)
assert.NoError(t, err) require.NoError(t, err)
} }
func assertGeneratedFiles(t *testing.T) { func assertGeneratedFiles(t *testing.T) {
// Table SQL Builder files // Table SQL Builder files
tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table") tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table")
assert.NoError(t, err) require.NoError(t, err)
testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go") "payment.go", "rental.go", "staff.go", "store.go")
testutils.AssertFileContent(t, genTestDir3+"/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile) testutils.AssertFileContent(t, genTestDir3+"/dvds/table/actor.go", actorSQLBuilderFile)
// View SQL Builder files // View SQL Builder files
viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view") viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view")
assert.NoError(t, err) require.NoError(t, err)
testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go",
"sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go")
testutils.AssertFileContent(t, genTestDir3+"/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilerFile) testutils.AssertFileContent(t, genTestDir3+"/dvds/view/actor_info.go", actorInfoSQLBuilderFile)
// Enums SQL Builder files // Enums SQL Builder files
enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum") enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum")
assert.NoError(t, err) require.NoError(t, err)
testutils.AssertFileNamesEqual(t, enumFiles, "film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go") testutils.AssertFileNamesEqual(t, enumFiles, "film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go")
testutils.AssertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", "\npackage enum", mpaaRatingEnumFile) testutils.AssertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", mpaaRatingEnumFile)
// Model files // Model files
modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model") modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model")
assert.NoError(t, err) require.NoError(t, err)
testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go",
@ -96,10 +96,17 @@ func assertGeneratedFiles(t *testing.T) {
"actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go",
"customer_list.go", "sales_by_store.go", "staff_list.go") "customer_list.go", "sales_by_store.go", "staff_list.go")
testutils.AssertFileContent(t, genTestDir3+"/dvds/model/actor.go", "\npackage model", actorModelFile) testutils.AssertFileContent(t, genTestDir3+"/dvds/model/actor.go", actorModelFile)
} }
var mpaaRatingEnumFile = ` var mpaaRatingEnumFile = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package enum package enum
import "github.com/go-jet/jet/mysql" import "github.com/go-jet/jet/mysql"
@ -120,6 +127,13 @@ var FilmRating = &struct {
` `
var actorSQLBuilderFile = ` var actorSQLBuilderFile = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table package table
import ( import (
@ -141,25 +155,25 @@ type ActorTable struct {
MutableColumns mysql.ColumnList MutableColumns mysql.ColumnList
} }
// creates new ActorTable with assigned alias // AS creates new ActorTable with assigned alias
func (a *ActorTable) AS(alias string) *ActorTable { func (a *ActorTable) AS(alias string) ActorTable {
aliasTable := newActorTable() aliasTable := newActorTable()
aliasTable.Table.AS(alias) aliasTable.Table.AS(alias)
return aliasTable return aliasTable
} }
func newActorTable() *ActorTable { func newActorTable() ActorTable {
var ( var (
ActorIDColumn = mysql.IntegerColumn("actor_id") ActorIDColumn = mysql.IntegerColumn("actor_id")
FirstNameColumn = mysql.StringColumn("first_name") FirstNameColumn = mysql.StringColumn("first_name")
LastNameColumn = mysql.StringColumn("last_name") LastNameColumn = mysql.StringColumn("last_name")
LastUpdateColumn = mysql.TimestampColumn("last_update") LastUpdateColumn = mysql.TimestampColumn("last_update")
allColumns = mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn}
mutableColumns = mysql.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn}
) )
return &ActorTable{ return ActorTable{
Table: mysql.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn), Table: mysql.NewTable("dvds", "actor", allColumns...),
//Columns //Columns
ActorID: ActorIDColumn, ActorID: ActorIDColumn,
@ -167,13 +181,20 @@ func newActorTable() *ActorTable {
LastName: LastNameColumn, LastName: LastNameColumn,
LastUpdate: LastUpdateColumn, LastUpdate: LastUpdateColumn,
AllColumns: mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn}, AllColumns: allColumns,
MutableColumns: mysql.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn}, MutableColumns: mutableColumns,
} }
} }
` `
var actorModelFile = ` var actorModelFile = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model package model
import ( import (
@ -188,7 +209,14 @@ type Actor struct {
} }
` `
var actorInfoSQLBuilerFile = ` var actorInfoSQLBuilderFile = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package view package view
import ( import (
@ -210,25 +238,25 @@ type ActorInfoTable struct {
MutableColumns mysql.ColumnList MutableColumns mysql.ColumnList
} }
// creates new ActorInfoTable with assigned alias // AS creates new ActorInfoTable with assigned alias
func (a *ActorInfoTable) AS(alias string) *ActorInfoTable { func (a *ActorInfoTable) AS(alias string) ActorInfoTable {
aliasTable := newActorInfoTable() aliasTable := newActorInfoTable()
aliasTable.Table.AS(alias) aliasTable.Table.AS(alias)
return aliasTable return aliasTable
} }
func newActorInfoTable() *ActorInfoTable { func newActorInfoTable() ActorInfoTable {
var ( var (
ActorIDColumn = mysql.IntegerColumn("actor_id") ActorIDColumn = mysql.IntegerColumn("actor_id")
FirstNameColumn = mysql.StringColumn("first_name") FirstNameColumn = mysql.StringColumn("first_name")
LastNameColumn = mysql.StringColumn("last_name") LastNameColumn = mysql.StringColumn("last_name")
FilmInfoColumn = mysql.StringColumn("film_info") FilmInfoColumn = mysql.StringColumn("film_info")
allColumns = mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}
mutableColumns = mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}
) )
return &ActorInfoTable{ return ActorInfoTable{
Table: mysql.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn), Table: mysql.NewTable("dvds", "actor_info", allColumns...),
//Columns //Columns
ActorID: ActorIDColumn, ActorID: ActorIDColumn,
@ -236,8 +264,8 @@ func newActorInfoTable() *ActorInfoTable {
LastName: LastNameColumn, LastName: LastNameColumn,
FilmInfo: FilmInfoColumn, FilmInfo: FilmInfoColumn,
AllColumns: mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}, AllColumns: allColumns,
MutableColumns: mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}, MutableColumns: mutableColumns,
} }
} }
` `

View file

@ -6,7 +6,8 @@ import (
. "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/mysql"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"math/rand"
"testing" "testing"
"time" "time"
) )
@ -15,8 +16,8 @@ func TestInsertValues(t *testing.T) {
cleanUpLinkTable(t) cleanUpLinkTable(t)
var expectedSQL = ` var expectedSQL = `
INSERT INTO test_sample.link (id, url, name, description) VALUES INSERT INTO test_sample.link (id, url, name, description)
(100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
(101, 'http://www.google.com', 'Google', DEFAULT), (101, 'http://www.google.com', 'Google', DEFAULT),
(102, 'http://www.yahoo.com', 'Yahoo', NULL); (102, 'http://www.yahoo.com', 'Yahoo', NULL);
` `
@ -32,7 +33,8 @@ INSERT INTO test_sample.link (id, url, name, description) VALUES
102, "http://www.yahoo.com", "Yahoo", nil) 102, "http://www.yahoo.com", "Yahoo", nil)
_, err := insertQuery.Exec(db) _, err := insertQuery.Exec(db)
assert.NoError(t, err) require.NoError(t, err)
requireLogged(t, insertQuery)
insertedLinks := []model.Link{} insertedLinks := []model.Link{}
@ -41,8 +43,8 @@ INSERT INTO test_sample.link (id, url, name, description) VALUES
ORDER_BY(Link.ID). ORDER_BY(Link.ID).
Query(db, &insertedLinks) Query(db, &insertedLinks)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(insertedLinks), 3) require.Equal(t, len(insertedLinks), 3)
testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial)
@ -69,8 +71,8 @@ func TestInsertEmptyColumnList(t *testing.T) {
cleanUpLinkTable(t) cleanUpLinkTable(t)
expectedSQL := ` expectedSQL := `
INSERT INTO test_sample.link VALUES INSERT INTO test_sample.link
(100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT);
` `
stmt := Link.INSERT(). stmt := Link.INSERT().
@ -80,7 +82,8 @@ INSERT INTO test_sample.link VALUES
100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial") 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial")
_, err := stmt.Exec(db) _, err := stmt.Exec(db)
assert.NoError(t, err) require.NoError(t, err)
requireLogged(t, stmt)
insertedLinks := []model.Link{} insertedLinks := []model.Link{}
@ -89,16 +92,16 @@ INSERT INTO test_sample.link VALUES
ORDER_BY(Link.ID). ORDER_BY(Link.ID).
Query(db, &insertedLinks) Query(db, &insertedLinks)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(insertedLinks), 1) require.Equal(t, len(insertedLinks), 1)
testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial)
} }
func TestInsertModelObject(t *testing.T) { func TestInsertModelObject(t *testing.T) {
cleanUpLinkTable(t) cleanUpLinkTable(t)
var expectedSQL = ` var expectedSQL = `
INSERT INTO test_sample.link (url, name) VALUES INSERT INTO test_sample.link (url, name)
('http://www.duckduckgo.com', 'Duck Duck go'); VALUES ('http://www.duckduckgo.com', 'Duck Duck go');
` `
linkData := model.Link{ linkData := model.Link{
@ -113,14 +116,14 @@ INSERT INTO test_sample.link (url, name) VALUES
testutils.AssertDebugStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go") testutils.AssertDebugStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go")
_, err := query.Exec(db) _, err := query.Exec(db)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestInsertModelObjectEmptyColumnList(t *testing.T) { func TestInsertModelObjectEmptyColumnList(t *testing.T) {
cleanUpLinkTable(t) cleanUpLinkTable(t)
var expectedSQL = ` var expectedSQL = `
INSERT INTO test_sample.link VALUES INSERT INTO test_sample.link
(1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
` `
linkData := model.Link{ linkData := model.Link{
@ -136,13 +139,13 @@ INSERT INTO test_sample.link VALUES
testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil)
_, err := query.Exec(db) _, err := query.Exec(db)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestInsertModelsObject(t *testing.T) { func TestInsertModelsObject(t *testing.T) {
expectedSQL := ` expectedSQL := `
INSERT INTO test_sample.link (url, name) VALUES INSERT INTO test_sample.link (url, name)
('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
('http://www.google.com', 'Google'), ('http://www.google.com', 'Google'),
('http://www.yahoo.com', 'Yahoo'); ('http://www.yahoo.com', 'Yahoo');
` `
@ -172,13 +175,13 @@ INSERT INTO test_sample.link (url, name) VALUES
"http://www.yahoo.com", "Yahoo") "http://www.yahoo.com", "Yahoo")
_, err := query.Exec(db) _, err := query.Exec(db)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestInsertUsingMutableColumns(t *testing.T) { func TestInsertUsingMutableColumns(t *testing.T) {
var expectedSQL = ` var expectedSQL = `
INSERT INTO test_sample.link (url, name, description) VALUES INSERT INTO test_sample.link (url, name, description)
('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
('http://www.google.com', 'Google', NULL), ('http://www.google.com', 'Google', NULL),
('http://www.google.com', 'Google', NULL), ('http://www.google.com', 'Google', NULL),
('http://www.yahoo.com', 'Yahoo', NULL); ('http://www.yahoo.com', 'Yahoo', NULL);
@ -207,14 +210,14 @@ INSERT INTO test_sample.link (url, name, description) VALUES
"http://www.yahoo.com", "Yahoo", nil) "http://www.yahoo.com", "Yahoo", nil)
_, err := stmt.Exec(db) _, err := stmt.Exec(db)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestInsertQuery(t *testing.T) { func TestInsertQuery(t *testing.T) {
_, err := Link.DELETE(). _, err := Link.DELETE().
WHERE(Link.ID.NOT_EQ(Int(1)).AND(Link.Name.EQ(String("Youtube")))). WHERE(Link.ID.NOT_EQ(Int(1)).AND(Link.Name.EQ(String("Youtube")))).
Exec(db) Exec(db)
assert.NoError(t, err) require.NoError(t, err)
var expectedSQL = ` var expectedSQL = `
INSERT INTO test_sample.link (url, name) ( INSERT INTO test_sample.link (url, name) (
@ -236,7 +239,7 @@ INSERT INTO test_sample.link (url, name) (
testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1)) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1))
_, err = query.Exec(db) _, err = query.Exec(db)
assert.NoError(t, err) require.NoError(t, err)
youtubeLinks := []model.Link{} youtubeLinks := []model.Link{}
err = Link. err = Link.
@ -244,8 +247,48 @@ INSERT INTO test_sample.link (url, name) (
WHERE(Link.Name.EQ(String("Youtube"))). WHERE(Link.Name.EQ(String("Youtube"))).
Query(db, &youtubeLinks) Query(db, &youtubeLinks)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(youtubeLinks), 2) require.Equal(t, len(youtubeLinks), 2)
}
func TestInsertOnDuplicateKey(t *testing.T) {
randId := rand.Int31()
stmt := Link.INSERT().
VALUES(randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
ON_DUPLICATE_KEY_UPDATE(
Link.ID.SET(Link.ID.ADD(Int(11))),
Link.Name.SET(String("PostgreSQL Tutorial 2")),
)
testutils.AssertStatementSql(t, stmt, `
INSERT INTO test_sample.link
VALUES (?, ?, ?, DEFAULT),
(?, ?, ?, DEFAULT)
ON DUPLICATE KEY UPDATE id = (id + ?),
name = ?;
`, randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
int64(11), "PostgreSQL Tutorial 2")
testutils.AssertExec(t, stmt, db, 3)
newLinks := []model.Link{}
err := SELECT(Link.AllColumns).
FROM(Link).
WHERE(Link.ID.EQ(Int(int64(randId)).ADD(Int(11)))).
Query(db, &newLinks)
require.NoError(t, err)
require.Len(t, newLinks, 1)
require.Equal(t, newLinks[0], model.Link{
ID: randId + 11,
URL: "http://www.postgresqltutorial.com",
Name: "PostgreSQL Tutorial 2",
Description: nil,
})
} }
func TestInsertWithQueryContext(t *testing.T) { func TestInsertWithQueryContext(t *testing.T) {
@ -262,7 +305,7 @@ func TestInsertWithQueryContext(t *testing.T) {
dest := []model.Link{} dest := []model.Link{}
err := stmt.QueryContext(ctx, db, &dest) err := stmt.QueryContext(ctx, db, &dest)
assert.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")
} }
func TestInsertWithExecContext(t *testing.T) { func TestInsertWithExecContext(t *testing.T) {
@ -278,10 +321,10 @@ func TestInsertWithExecContext(t *testing.T) {
_, err := stmt.ExecContext(ctx, db) _, err := stmt.ExecContext(ctx, db)
assert.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")
} }
func cleanUpLinkTable(t *testing.T) { func cleanUpLinkTable(t *testing.T) {
_, err := Link.DELETE().WHERE(Link.ID.GT(Int(1))).Exec(db) _, err := Link.DELETE().WHERE(Link.ID.GT(Int(1))).Exec(db)
assert.NoError(t, err) require.NoError(t, err)
} }

View file

@ -4,7 +4,7 @@ import (
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/mysql"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"testing" "testing"
) )
@ -16,7 +16,8 @@ LOCK TABLES dvds.customer READ;
`) `)
_, err := query.Exec(db) _, err := query.Exec(db)
assert.NoError(t, err) require.NoError(t, err)
requireLogged(t, query)
} }
func TestLockWrite(t *testing.T) { func TestLockWrite(t *testing.T) {
@ -27,7 +28,8 @@ LOCK TABLES dvds.customer WRITE;
`) `)
_, err := query.Exec(db) _, err := query.Exec(db)
assert.NoError(t, err) require.NoError(t, err)
requireLogged(t, query)
} }
func TestUnlockTables(t *testing.T) { func TestUnlockTables(t *testing.T) {
@ -38,5 +40,6 @@ UNLOCK TABLES;
`) `)
_, err := query.Exec(db) _, err := query.Exec(db)
assert.NoError(t, err) require.NoError(t, err)
requireLogged(t, query)
} }

View file

@ -1,9 +1,15 @@
package mysql package mysql
import ( import (
"context"
"database/sql" "database/sql"
"flag" "flag"
jetmysql "github.com/go-jet/jet/mysql"
"github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/tests/dbconfig" "github.com/go-jet/jet/tests/dbconfig"
"github.com/stretchr/testify/require"
"math/rand"
"time"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
@ -28,6 +34,7 @@ func sourceIsMariaDB() bool {
} }
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
rand.Seed(time.Now().Unix())
defer profile.Start().Stop() defer profile.Start().Stop()
var err error var err error
@ -41,3 +48,21 @@ func TestMain(m *testing.M) {
os.Exit(ret) os.Exit(ret)
} }
var loggedSQL string
var loggedSQLArgs []interface{}
var loggedDebugSQL string
func init() {
jetmysql.SetLogger(func(ctx context.Context, statement jetmysql.PrintableStatement) {
loggedSQL, loggedSQLArgs = statement.Sql()
loggedDebugSQL = statement.DebugSql()
})
}
func requireLogged(t *testing.T, statement postgres.Statement) {
query, args := statement.Sql()
require.Equal(t, loggedSQL, query)
require.Equal(t, loggedSQLArgs, args)
require.Equal(t, loggedDebugSQL, statement.DebugSql())
}

View file

@ -7,7 +7,7 @@ import (
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/model" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table"
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/view" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/view"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"testing" "testing"
) )
@ -30,9 +30,10 @@ WHERE actor.actor_id = ?;
actor := model.Actor{} actor := model.Actor{}
err := query.Query(db, &actor) err := query.Query(db, &actor)
assert.NoError(t, err) require.NoError(t, err)
testutils.AssertDeepEqual(t, actor, actor2) testutils.AssertDeepEqual(t, actor, actor2)
requireLogged(t, query)
} }
var actor2 = model.Actor{ var actor2 = model.Actor{
@ -59,14 +60,15 @@ ORDER BY actor.actor_id;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(dest), 200) require.Equal(t, len(dest), 200)
testutils.AssertDeepEqual(t, dest[1], actor2) testutils.AssertDeepEqual(t, dest[1], actor2)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
//testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json") //testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/all_actors.json") testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/all_actors.json")
requireLogged(t, query)
} }
func TestSelectGroupByHaving(t *testing.T) { func TestSelectGroupByHaving(t *testing.T) {
@ -136,14 +138,15 @@ ORDER BY payment.customer_id, SUM(payment.amount) ASC;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
assert.Equal(t, len(dest), 174) require.Equal(t, len(dest), 174)
//testutils.SaveJsonFile(dest, "mysql/testdata/customer_payment_sum.json") //testutils.SaveJsonFile(dest, "mysql/testdata/customer_payment_sum.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/customer_payment_sum.json") testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/customer_payment_sum.json")
requireLogged(t, query)
} }
func TestSubQuery(t *testing.T) { func TestSubQuery(t *testing.T) {
@ -176,7 +179,7 @@ func TestSubQuery(t *testing.T) {
} }
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
//testutils.SaveJsonFile(dest, "mysql/testdata/r_rating_films.json") //testutils.SaveJsonFile(dest, "mysql/testdata/r_rating_films.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/r_rating_films.json") testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/r_rating_films.json")
@ -229,7 +232,7 @@ LIMIT ?;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestSelectUNION(t *testing.T) { func TestSelectUNION(t *testing.T) {
@ -265,7 +268,7 @@ LIMIT ?;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestSelectUNION_ALL(t *testing.T) { func TestSelectUNION_ALL(t *testing.T) {
@ -308,7 +311,7 @@ OFFSET ?;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestJoinQueryStruct(t *testing.T) { func TestJoinQueryStruct(t *testing.T) {
@ -406,10 +409,10 @@ LIMIT ?;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
//assert.Equal(t, len(dest), 1) //require.Equal(t, len(dest), 1)
//assert.Equal(t, len(dest[0].Films), 10) //require.Equal(t, len(dest[0].Films), 10)
//assert.Equal(t, len(dest[0].Films[0].Actors), 10) //require.Equal(t, len(dest[0].Films[0].Actors), 10)
//testutils.SaveJsonFile(dest, "./mysql/testdata/lang_film_actor_inventory_rental.json") //testutils.SaveJsonFile(dest, "./mysql/testdata/lang_film_actor_inventory_rental.json")
@ -450,10 +453,10 @@ FOR`
tx, _ := db.Begin() tx, _ := db.Begin()
_, err := query.Exec(tx) _, err := query.Exec(tx)
assert.NoError(t, err) require.NoError(t, err)
err = tx.Rollback() err = tx.Rollback()
assert.NoError(t, err) require.NoError(t, err)
} }
for lockType, lockTypeStr := range getRowLockTestData() { for lockType, lockTypeStr := range getRowLockTestData() {
@ -464,10 +467,10 @@ FOR`
tx, _ := db.Begin() tx, _ := db.Begin()
_, err := query.Exec(tx) _, err := query.Exec(tx)
assert.NoError(t, err) require.NoError(t, err)
err = tx.Rollback() err = tx.Rollback()
assert.NoError(t, err) require.NoError(t, err)
} }
if sourceIsMariaDB() { if sourceIsMariaDB() {
@ -482,10 +485,10 @@ FOR`
tx, _ := db.Begin() tx, _ := db.Begin()
_, err := query.Exec(tx) _, err := query.Exec(tx)
assert.NoError(t, err) require.NoError(t, err)
err = tx.Rollback() err = tx.Rollback()
assert.NoError(t, err) require.NoError(t, err)
} }
} }
@ -514,7 +517,7 @@ SELECT true,
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestLockInShareMode(t *testing.T) { func TestLockInShareMode(t *testing.T) {
@ -535,7 +538,7 @@ LOCK IN SHARE MODE;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestWindowFunction(t *testing.T) { func TestWindowFunction(t *testing.T) {
@ -612,7 +615,7 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestWindowClause(t *testing.T) { func TestWindowClause(t *testing.T) {
@ -649,7 +652,7 @@ ORDER BY payment.customer_id;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestSimpleView(t *testing.T) { func TestSimpleView(t *testing.T) {
@ -670,9 +673,9 @@ func TestSimpleView(t *testing.T) {
var dest []ActorInfo var dest []ActorInfo
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(dest), 10) require.Equal(t, len(dest), 10)
testutils.AssertJSON(t, dest[1:2], ` testutils.AssertJSON(t, dest[1:2], `
[ [
{ {
@ -702,11 +705,11 @@ func TestJoinViewWithTable(t *testing.T) {
} }
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(dest), 2) require.Equal(t, len(dest), 2)
assert.Equal(t, len(dest[0].Rentals), 32) require.Equal(t, len(dest[0].Rentals), 32)
assert.Equal(t, len(dest[1].Rentals), 27) require.Equal(t, len(dest[1].Rentals), 27)
} }
func TestConditionalProjectionList(t *testing.T) { func TestConditionalProjectionList(t *testing.T) {
@ -737,7 +740,7 @@ LIMIT 3;
`) `)
var dest []model.Customer var dest []model.Customer
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(dest), 3) require.Equal(t, len(dest), 3)
} }

View file

@ -8,7 +8,7 @@ import (
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
"testing" "testing"
"time" "time"
) )
@ -16,22 +16,35 @@ import (
func TestUpdateValues(t *testing.T) { func TestUpdateValues(t *testing.T) {
setupLinkTableForUpdateTest(t) setupLinkTableForUpdateTest(t)
query := Link.
UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
var expectedSQL = ` var expectedSQL = `
UPDATE test_sample.link UPDATE test_sample.link
SET name = 'Bong', SET name = 'Bong',
url = 'http://bong.com' url = 'http://bong.com'
WHERE link.name = 'Bing'; WHERE link.name = 'Bing';
` `
t.Run("old version", func(t *testing.T) {
query := Link.
UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
fmt.Println(query.DebugSql())
testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing")
testutils.AssertExec(t, query, db) testutils.AssertExec(t, query, db)
requireLogged(t, query)
})
t.Run("new version", func(t *testing.T) {
stmt := Link.UPDATE().
SET(
Link.Name.SET(String("Bong")),
Link.URL.SET(String("http://bong.com")),
).
WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "Bong", "http://bong.com", "Bing")
testutils.AssertExec(t, stmt, db)
requireLogged(t, stmt)
})
links := []model.Link{} links := []model.Link{}
@ -40,8 +53,8 @@ WHERE link.name = 'Bing';
WHERE(Link.Name.EQ(String("Bong"))). WHERE(Link.Name.EQ(String("Bong"))).
Query(db, &links) Query(db, &links)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(links), 1) require.Equal(t, len(links), 1)
testutils.AssertDeepEqual(t, links[0], model.Link{ testutils.AssertDeepEqual(t, links[0], model.Link{
ID: 204, ID: 204,
URL: "http://bong.com", URL: "http://bong.com",
@ -52,16 +65,6 @@ WHERE link.name = 'Bing';
func TestUpdateWithSubQueries(t *testing.T) { func TestUpdateWithSubQueries(t *testing.T) {
setupLinkTableForUpdateTest(t) setupLinkTableForUpdateTest(t)
query := Link.
UPDATE(Link.Name, Link.URL).
SET(
SELECT(String("Bong")),
SELECT(Link2.URL).
FROM(Link2).
WHERE(Link2.Name.EQ(String("Youtube"))),
).
WHERE(Link.Name.EQ(String("Bing")))
expectedSQL := ` expectedSQL := `
UPDATE test_sample.link UPDATE test_sample.link
SET name = ( SET name = (
@ -74,10 +77,39 @@ SET name = (
) )
WHERE link.name = ?; WHERE link.name = ?;
` `
fmt.Println(query.Sql()) t.Run("old version", func(t *testing.T) {
testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") query := Link.
UPDATE(Link.Name, Link.URL).
SET(
SELECT(String("Bong")),
SELECT(Link2.URL).
FROM(Link2).
WHERE(Link2.Name.EQ(String("Youtube"))),
).
WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing")
testutils.AssertExec(t, query, db) testutils.AssertExec(t, query, db)
requireLogged(t, query)
})
t.Run("new version", func(t *testing.T) {
query := Link.
UPDATE().
SET(
Link.Name.SET(StringExp(SELECT(String("Bong")))),
Link.URL.SET(StringExp(
SELECT(Link2.URL).
FROM(Link2).
WHERE(Link2.Name.EQ(String("Youtube"))),
)),
).
WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing")
testutils.AssertExec(t, query, db)
requireLogged(t, query)
})
} }
func TestUpdateWithModelData(t *testing.T) { func TestUpdateWithModelData(t *testing.T) {
@ -102,10 +134,10 @@ SET id = ?,
description = ? description = ?
WHERE link.id = ?; WHERE link.id = ?;
` `
fmt.Println(stmt.Sql())
testutils.AssertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) testutils.AssertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201))
testutils.AssertExec(t, stmt, db) testutils.AssertExec(t, stmt, db)
requireLogged(t, stmt)
} }
func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
@ -137,10 +169,10 @@ WHERE link.id = 201;
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201)) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201))
testutils.AssertExec(t, stmt, db) testutils.AssertExec(t, stmt, db)
requireLogged(t, stmt)
} }
func TestUpdateWithModelDataAndMutableColumns(t *testing.T) { func TestUpdateWithModelDataAndMutableColumns(t *testing.T) {
setupLinkTableForUpdateTest(t) setupLinkTableForUpdateTest(t)
link := model.Link{ link := model.Link{
@ -164,15 +196,13 @@ WHERE link.id = 201;
fmt.Println(stmt.DebugSql()) fmt.Println(stmt.DebugSql())
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201))
testutils.AssertExec(t, stmt, db) testutils.AssertExec(t, stmt, db)
} }
func TestUpdateWithInvalidModelData(t *testing.T) { func TestUpdateWithInvalidModelData(t *testing.T) {
defer func() { defer func() {
r := recover() r := recover()
require.Equal(t, r, "missing struct field for column : id")
assert.Equal(t, r, "missing struct field for column : id")
}() }()
setupLinkTableForUpdateTest(t) setupLinkTableForUpdateTest(t)
@ -213,7 +243,7 @@ func TestUpdateQueryContext(t *testing.T) {
dest := []model.Link{} dest := []model.Link{}
err := updateStmt.QueryContext(ctx, db, &dest) err := updateStmt.QueryContext(ctx, db, &dest)
assert.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")
} }
func TestUpdateExecContext(t *testing.T) { func TestUpdateExecContext(t *testing.T) {
@ -231,7 +261,7 @@ func TestUpdateExecContext(t *testing.T) {
_, err := updateStmt.ExecContext(ctx, db) _, err := updateStmt.ExecContext(ctx, db)
assert.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")
} }
func TestUpdateWithJoin(t *testing.T) { func TestUpdateWithJoin(t *testing.T) {
@ -244,7 +274,7 @@ func TestUpdateWithJoin(t *testing.T) {
//fmt.Println(query.DebugSql()) //fmt.Println(query.DebugSql())
_, err := query.Exec(db) _, err := query.Exec(db)
assert.NoError(t, err) require.NoError(t, err)
} }
func setupLinkTableForUpdateTest(t *testing.T) { func setupLinkTableForUpdateTest(t *testing.T) {
@ -259,5 +289,5 @@ func setupLinkTableForUpdateTest(t *testing.T) {
VALUES(204, "http://www.bing.com", "Bing", DEFAULT). VALUES(204, "http://www.bing.com", "Bing", DEFAULT).
Exec(db) Exec(db)
assert.NoError(t, err) require.NoError(t, err)
} }

159
tests/mysql/with_test.go Normal file
View file

@ -0,0 +1,159 @@
package mysql
import (
"github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/mysql"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table"
"github.com/stretchr/testify/require"
"strings"
"testing"
)
func TestWITH_And_SELECT(t *testing.T) {
salesRep := CTE("sales_rep")
salesRepStaffID := Staff.StaffID.From(salesRep)
salesRepFullName := StringColumn("sales_rep_full_name").From(salesRep)
customerSalesRep := CTE("customer_sales_rep")
stmt := WITH(
salesRep.AS(
SELECT(
Staff.StaffID,
Staff.FirstName.CONCAT(Staff.LastName).AS(salesRepFullName.Name()),
).FROM(Staff),
),
customerSalesRep.AS(
SELECT(
Customer.FirstName.CONCAT(Customer.LastName).AS("customer_name"),
salesRepFullName,
).FROM(
salesRep.
INNER_JOIN(Store, Store.ManagerStaffID.EQ(salesRepStaffID)).
INNER_JOIN(Customer, Customer.StoreID.EQ(Store.StoreID)),
),
),
)(
SELECT(customerSalesRep.AllColumns()).
FROM(customerSalesRep),
)
//fmt.Println(stmt.DebugSql())
testutils.AssertStatementSql(t, stmt, strings.Replace(`
WITH sales_rep AS (
SELECT staff.staff_id AS "staff.staff_id",
(CONCAT(staff.first_name, staff.last_name)) AS "sales_rep_full_name"
FROM dvds.staff
),customer_sales_rep AS (
SELECT (CONCAT(customer.first_name, customer.last_name)) AS "customer_name",
sales_rep.sales_rep_full_name AS "sales_rep_full_name"
FROM sales_rep
INNER JOIN dvds.store ON (store.manager_staff_id = sales_rep.''staff.staff_id'')
INNER JOIN dvds.customer ON (customer.store_id = store.store_id)
)
SELECT customer_sales_rep.customer_name AS "customer_name",
customer_sales_rep.sales_rep_full_name AS "sales_rep_full_name"
FROM customer_sales_rep;
`, "''", "`", -1))
var dest []struct {
CustomerName string
SalesRepFullName string
}
err := stmt.Query(db, &dest)
require.Equal(t, len(dest), 599)
require.NoError(t, err)
}
//func TestWITH_And_INSERT(t *testing.T) {
// paymentsToInsert := CTE("payments_to_insert")
//
// stmt := WITH(
// paymentsToInsert.AS(
// SELECT(Payment.AllColumns).
// FROM(Payment).
// WHERE(Payment.Amount.LT(Float(0.5))),
// ),
// )(
// Payment.INSERT(Payment.AllColumns).
// QUERY(
// SELECT(paymentsToInsert.AllColumns()).
// FROM(paymentsToInsert),
// ).ON_DUPLICATE_KEY_UPDATE(
// Payment.PaymentID.SET(Payment.PaymentID.ADD(Int(100000))),
// ),
// )
//
// //fmt.Println(stmt.DebugSql())
//
// tx, err := db.Begin()
// require.NoError(t, err)
// defer tx.Rollback()
//
// testutils.AssertExec(t, stmt, tx, 24)
//}
func TestWITH_And_UPDATE(t *testing.T) {
if sourceIsMariaDB() {
return
}
paymentsToUpdate := CTE("payments_to_update")
paymentsToDeleteID := Payment.PaymentID.From(paymentsToUpdate)
stmt := WITH(
paymentsToUpdate.AS(
SELECT(Payment.AllColumns).
FROM(Payment).
WHERE(Payment.Amount.LT(Float(0.5))),
),
)(
Payment.UPDATE().
SET(Payment.Amount.SET(Float(0.0))).
WHERE(Payment.PaymentID.IN(
SELECT(paymentsToDeleteID).
FROM(paymentsToUpdate),
),
),
)
//fmt.Println(stmt.DebugSql())
tx, err := db.Begin()
require.NoError(t, err)
defer tx.Rollback()
testutils.AssertExec(t, stmt, tx)
}
func TestWITH_And_DELETE(t *testing.T) {
if sourceIsMariaDB() {
return
}
paymentsToDelete := CTE("payments_to_delete")
paymentsToDeleteID := Payment.PaymentID.From(paymentsToDelete)
stmt := WITH(
paymentsToDelete.AS(
SELECT(Payment.AllColumns).
FROM(Payment).
WHERE(Payment.Amount.LT(Float(0.5))),
),
)(
Payment.DELETE().
WHERE(Payment.PaymentID.IN(
SELECT(paymentsToDeleteID).
FROM(paymentsToDelete),
),
),
)
//fmt.Println(stmt.DebugSql())
tx, err := db.Begin()
require.NoError(t, err)
defer tx.Rollback()
testutils.AssertExec(t, stmt, tx, 24)
}

View file

@ -1,25 +1,24 @@
package postgres package postgres
import ( import (
"github.com/stretchr/testify/require"
"testing" "testing"
"time" "time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/view" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/view"
"github.com/go-jet/jet/tests/testdata/results/common" "github.com/go-jet/jet/tests/testdata/results/common"
"github.com/google/uuid"
) )
func TestAllTypesSelect(t *testing.T) { func TestAllTypesSelect(t *testing.T) {
dest := []model.AllTypes{} dest := []model.AllTypes{}
err := AllTypes.SELECT(AllTypes.AllColumns).Query(db, &dest) err := AllTypes.SELECT(AllTypes.AllColumns).Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
testutils.AssertDeepEqual(t, dest[0], allTypesRow0) testutils.AssertDeepEqual(t, dest[0], allTypesRow0)
testutils.AssertDeepEqual(t, dest[1], allTypesRow1) testutils.AssertDeepEqual(t, dest[1], allTypesRow1)
@ -31,7 +30,7 @@ func TestAllTypesViewSelect(t *testing.T) {
dest := []AllTypesView{} dest := []AllTypesView{}
err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
testutils.AssertDeepEqual(t, dest[0], AllTypesView(allTypesRow0)) testutils.AssertDeepEqual(t, dest[0], AllTypesView(allTypesRow0))
testutils.AssertDeepEqual(t, dest[1], AllTypesView(allTypesRow1)) testutils.AssertDeepEqual(t, dest[1], AllTypesView(allTypesRow1))
@ -45,9 +44,9 @@ func TestAllTypesInsertModel(t *testing.T) {
dest := []model.AllTypes{} dest := []model.AllTypes{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(dest), 2) require.Equal(t, len(dest), 2)
testutils.AssertDeepEqual(t, dest[0], allTypesRow0) testutils.AssertDeepEqual(t, dest[0], allTypesRow0)
testutils.AssertDeepEqual(t, dest[1], allTypesRow1) testutils.AssertDeepEqual(t, dest[1], allTypesRow1)
} }
@ -64,8 +63,8 @@ func TestAllTypesInsertQuery(t *testing.T) {
dest := []model.AllTypes{} dest := []model.AllTypes{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(dest), 2) require.Equal(t, len(dest), 2)
testutils.AssertDeepEqual(t, dest[0], allTypesRow0) testutils.AssertDeepEqual(t, dest[0], allTypesRow0)
testutils.AssertDeepEqual(t, dest[1], allTypesRow1) testutils.AssertDeepEqual(t, dest[1], allTypesRow1)
} }
@ -80,7 +79,7 @@ func TestAllTypesFromSubQuery(t *testing.T) {
FROM(subQuery). FROM(subQuery).
LIMIT(2) LIMIT(2)
assert.Equal(t, mainQuery.DebugSql(), ` require.Equal(t, mainQuery.DebugSql(), `
SELECT "allTypesSubQuery"."all_types.small_int_ptr" AS "all_types.small_int_ptr", SELECT "allTypesSubQuery"."all_types.small_int_ptr" AS "all_types.small_int_ptr",
"allTypesSubQuery"."all_types.small_int" AS "all_types.small_int", "allTypesSubQuery"."all_types.small_int" AS "all_types.small_int",
"allTypesSubQuery"."all_types.integer_ptr" AS "all_types.integer_ptr", "allTypesSubQuery"."all_types.integer_ptr" AS "all_types.integer_ptr",
@ -212,8 +211,8 @@ LIMIT 2;
dest := []model.AllTypes{} dest := []model.AllTypes{}
err := mainQuery.Query(db, &dest) err := mainQuery.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(dest), 2) require.Equal(t, len(dest), 2)
} }
func TestExpressionOperators(t *testing.T) { func TestExpressionOperators(t *testing.T) {
@ -251,7 +250,7 @@ LIMIT $5;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -320,7 +319,7 @@ func TestExpressionCast(t *testing.T) {
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestStringOperators(t *testing.T) { func TestStringOperators(t *testing.T) {
@ -400,7 +399,7 @@ func TestStringOperators(t *testing.T) {
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestBoolOperators(t *testing.T) { func TestBoolOperators(t *testing.T) {
@ -469,7 +468,7 @@ LIMIT $5;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json") testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json")
} }
@ -519,7 +518,7 @@ func TestFloatOperators(t *testing.T) {
queryStr, _ := query.Sql() queryStr, _ := query.Sql()
assert.Equal(t, queryStr, ` require.Equal(t, queryStr, `
SELECT (all_types.numeric = all_types.numeric) AS "eq1", SELECT (all_types.numeric = all_types.numeric) AS "eq1",
(all_types.decimal = $1) AS "eq2", (all_types.decimal = $1) AS "eq2",
(all_types.real = $2) AS "eq3", (all_types.real = $2) AS "eq3",
@ -565,7 +564,7 @@ LIMIT $35;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -704,7 +703,7 @@ LIMIT $23;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
//testutils.SaveJsonFile("./testdata/common/int_operators.json", dest) //testutils.SaveJsonFile("./testdata/common/int_operators.json", dest)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -783,7 +782,7 @@ func TestTimeExpression(t *testing.T) {
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestInterval(t *testing.T) { func TestInterval(t *testing.T) {
@ -834,7 +833,8 @@ func TestInterval(t *testing.T) {
//fmt.Println(stmt.DebugSql()) //fmt.Println(stmt.DebugSql())
err := stmt.Query(db, &struct{}{}) err := stmt.Query(db, &struct{}{})
assert.NoError(t, err) require.NoError(t, err)
requireLogged(t, stmt)
} }
func TestSubQueryColumnReference(t *testing.T) { func TestSubQueryColumnReference(t *testing.T) {
@ -986,12 +986,12 @@ FROM`
dest1 := []model.AllTypes{} dest1 := []model.AllTypes{}
err := stmt1.Query(db, &dest1) err := stmt1.Query(db, &dest1)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(dest1), 2) require.Equal(t, len(dest1), 2)
assert.Equal(t, dest1[0].Boolean, allTypesRow0.Boolean) require.Equal(t, dest1[0].Boolean, allTypesRow0.Boolean)
assert.Equal(t, dest1[0].Integer, allTypesRow0.Integer) require.Equal(t, dest1[0].Integer, allTypesRow0.Integer)
assert.Equal(t, dest1[0].Real, allTypesRow0.Real) require.Equal(t, dest1[0].Real, allTypesRow0.Real)
assert.Equal(t, dest1[0].Text, allTypesRow0.Text) require.Equal(t, dest1[0].Text, allTypesRow0.Text)
testutils.AssertDeepEqual(t, dest1[0].Time, allTypesRow0.Time) testutils.AssertDeepEqual(t, dest1[0].Time, allTypesRow0.Time)
testutils.AssertDeepEqual(t, dest1[0].Timez, allTypesRow0.Timez) testutils.AssertDeepEqual(t, dest1[0].Timez, allTypesRow0.Timez)
testutils.AssertDeepEqual(t, dest1[0].Timestamp, allTypesRow0.Timestamp) testutils.AssertDeepEqual(t, dest1[0].Timestamp, allTypesRow0.Timestamp)
@ -1008,15 +1008,16 @@ FROM`
dest2 := []model.AllTypes{} dest2 := []model.AllTypes{}
err = stmt2.Query(db, &dest2) err = stmt2.Query(db, &dest2)
assert.NoError(t, err) require.NoError(t, err)
testutils.AssertDeepEqual(t, dest1, dest2) testutils.AssertDeepEqual(t, dest1, dest2)
requireLogged(t, stmt2)
} }
} }
func TestTimeLiterals(t *testing.T) { func TestTimeLiterals(t *testing.T) {
loc, err := time.LoadLocation("Europe/Berlin") loc, err := time.LoadLocation("Europe/Berlin")
assert.NoError(t, err) require.NoError(t, err)
var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, loc) var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, loc)
@ -1051,7 +1052,7 @@ LIMIT $6;
err = query.Query(db, &dest) err = query.Query(db, &dest)
assert.NoError(t, err) require.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -1063,35 +1064,36 @@ LIMIT $6;
"Timestamp": "2009-11-17T20:34:58.651387Z" "Timestamp": "2009-11-17T20:34:58.651387Z"
} }
`) `)
requireLogged(t, query)
} }
var allTypesRow0 = model.AllTypes{ var allTypesRow0 = model.AllTypes{
SmallIntPtr: Int16Ptr(14), SmallIntPtr: testutils.Int16Ptr(14),
SmallInt: 14, SmallInt: 14,
IntegerPtr: Int32Ptr(300), IntegerPtr: testutils.Int32Ptr(300),
Integer: 300, Integer: 300,
BigIntPtr: Int64Ptr(50000), BigIntPtr: testutils.Int64Ptr(50000),
BigInt: 5000, BigInt: 5000,
DecimalPtr: Float64Ptr(1.11), DecimalPtr: testutils.Float64Ptr(1.11),
Decimal: 1.11, Decimal: 1.11,
NumericPtr: Float64Ptr(2.22), NumericPtr: testutils.Float64Ptr(2.22),
Numeric: 2.22, Numeric: 2.22,
RealPtr: Float32Ptr(5.55), RealPtr: testutils.Float32Ptr(5.55),
Real: 5.55, Real: 5.55,
DoublePrecisionPtr: Float64Ptr(11111111.22), DoublePrecisionPtr: testutils.Float64Ptr(11111111.22),
DoublePrecision: 11111111.22, DoublePrecision: 11111111.22,
Smallserial: 1, Smallserial: 1,
Serial: 1, Serial: 1,
Bigserial: 1, Bigserial: 1,
//MoneyPtr: nil, //MoneyPtr: nil,
//Money: //Money:
VarCharPtr: StringPtr("ABBA"), VarCharPtr: testutils.StringPtr("ABBA"),
VarChar: "ABBA", VarChar: "ABBA",
CharPtr: StringPtr("JOHN "), CharPtr: testutils.StringPtr("JOHN "),
Char: "JOHN ", Char: "JOHN ",
TextPtr: StringPtr("Some text"), TextPtr: testutils.StringPtr("Some text"),
Text: "Some text", Text: "Some text",
ByteaPtr: ByteArrayPtr([]byte("bytea")), ByteaPtr: testutils.ByteArrayPtr([]byte("bytea")),
Bytea: []byte("bytea"), Bytea: []byte("bytea"),
TimestampzPtr: testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), TimestampzPtr: testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0),
Timestampz: *testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), Timestampz: *testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0),
@ -1103,31 +1105,31 @@ var allTypesRow0 = model.AllTypes{
Timez: *testutils.TimeWithTimeZone("04:05:06 -0800"), Timez: *testutils.TimeWithTimeZone("04:05:06 -0800"),
TimePtr: testutils.TimeWithoutTimeZone("04:05:06"), TimePtr: testutils.TimeWithoutTimeZone("04:05:06"),
Time: *testutils.TimeWithoutTimeZone("04:05:06"), Time: *testutils.TimeWithoutTimeZone("04:05:06"),
IntervalPtr: StringPtr("3 days 04:05:06"), IntervalPtr: testutils.StringPtr("3 days 04:05:06"),
Interval: "3 days 04:05:06", Interval: "3 days 04:05:06",
BooleanPtr: BoolPtr(true), BooleanPtr: testutils.BoolPtr(true),
Boolean: false, Boolean: false,
PointPtr: StringPtr("(2,3)"), PointPtr: testutils.StringPtr("(2,3)"),
BitPtr: StringPtr("101"), BitPtr: testutils.StringPtr("101"),
Bit: "101", Bit: "101",
BitVaryingPtr: StringPtr("101111"), BitVaryingPtr: testutils.StringPtr("101111"),
BitVarying: "101111", BitVarying: "101111",
TsvectorPtr: StringPtr("'supernova':1"), TsvectorPtr: testutils.StringPtr("'supernova':1"),
Tsvector: "'supernova':1", Tsvector: "'supernova':1",
UUIDPtr: UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), UUIDPtr: testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"),
UUID: uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), UUID: uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"),
XMLPtr: StringPtr("<Sub>abc</Sub>"), XMLPtr: testutils.StringPtr("<Sub>abc</Sub>"),
XML: "<Sub>abc</Sub>", XML: "<Sub>abc</Sub>",
JSONPtr: StringPtr(`{"a": 1, "b": 3}`), JSONPtr: testutils.StringPtr(`{"a": 1, "b": 3}`),
JSON: `{"a": 1, "b": 3}`, JSON: `{"a": 1, "b": 3}`,
JsonbPtr: StringPtr(`{"a": 1, "b": 3}`), JsonbPtr: testutils.StringPtr(`{"a": 1, "b": 3}`),
Jsonb: `{"a": 1, "b": 3}`, Jsonb: `{"a": 1, "b": 3}`,
IntegerArrayPtr: StringPtr("{1,2,3}"), IntegerArrayPtr: testutils.StringPtr("{1,2,3}"),
IntegerArray: "{1,2,3}", IntegerArray: "{1,2,3}",
TextArrayPtr: StringPtr("{breakfast,consulting}"), TextArrayPtr: testutils.StringPtr("{breakfast,consulting}"),
TextArray: "{breakfast,consulting}", TextArray: "{breakfast,consulting}",
JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`, JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`,
TextMultiDimArrayPtr: StringPtr("{{meeting,lunch},{training,presentation}}"), TextMultiDimArrayPtr: testutils.StringPtr("{{meeting,lunch},{training,presentation}}"),
TextMultiDimArray: "{{meeting,lunch},{training,presentation}}", TextMultiDimArray: "{{meeting,lunch},{training,presentation}}",
} }

Some files were not shown because too many files have changed in this diff Show more