Merge pull request #40 from go-jet/develop

Merge develop to master for 2.3.0 release
This commit is contained in:
go-jet 2020-06-03 07:28:40 +02:00 committed by GitHub
commit c3903948c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
115 changed files with 3769 additions and 1460 deletions

View file

@ -35,17 +35,19 @@ https://medium.com/@go.jet/jet-5f3667efa0cc
## Features
1) Auto-generated type-safe SQL Builder
- PostgreSQL:
* SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, WINDOW, sub-queries)`
* INSERT `(VALUES, query, RETURNING)`,
* UPDATE `(SET, WHERE, RETURNING)`,
* DELETE `(WHERE, RETURNING)`,
* LOCK `(IN, NOWAIT)`
* [SELECT](https://github.com/go-jet/jet/wiki/SELECT) `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, WINDOW, sub-queries)`
* [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, MODEL, MODELS, QUERY, ON_CONFLICT, RETURNING)`,
* [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, MODEL, WHERE, RETURNING)`,
* [DELETE](https://github.com/go-jet/jet/wiki/DELETE) `(WHERE, RETURNING)`,
* [LOCK](https://github.com/go-jet/jet/wiki/LOCK) `(IN, NOWAIT)`
* [WITH](https://github.com/go-jet/jet/wiki/WITH)
- MySQL and MariaDB:
* SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, LOCK_IN_SHARE_MODE, WINDOW, sub-queries)`
* INSERT `(VALUES, query)`,
* UPDATE `(SET, WHERE)`,
* DELETE `(WHERE, ORDER_BY, LIMIT)`,
* LOCK `(READ, WRITE)`
* [SELECT](https://github.com/go-jet/jet/wiki/SELECT) `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, LOCK_IN_SHARE_MODE, WINDOW, sub-queries)`
* [INSERT](https://github.com/go-jet/jet/wiki/INSERT) `(VALUES, MODEL, MODELS, ON_DUPLICATE_KEY_UPDATE, query)`,
* [UPDATE](https://github.com/go-jet/jet/wiki/UPDATE) `(SET, MODEL, WHERE)`,
* [DELETE](https://github.com/go-jet/jet/wiki/DELETE) `(WHERE, ORDER_BY, LIMIT)`,
* [LOCK](https://github.com/go-jet/jet/wiki/LOCK) `(READ, WRITE)`
* [WITH](https://github.com/go-jet/jet/wiki/WITH)
2) Auto-generated Data Model types - Go types mapped to database type (table, view or enum), used to store
result of database queries. Can be combined to create desired query result destination.
3) Query execution with result mapping to arbitrary destination structure.
@ -561,6 +563,7 @@ At the moment Jet dependence only of:
To run the tests, additional dependencies are required:
- `github.com/pkg/profile`
- `github.com/stretchr/testify`
- `github.com/google/go-cmp`
## Versioning
@ -568,5 +571,5 @@ To run the tests, additional dependencies are required:
## License
Copyright 2019 Goran Bjelanovic
Copyright 2019-2020 Goran Bjelanovic
Licensed under the Apache License, Version 2.0.

View file

@ -47,7 +47,7 @@ func main() {
flag.Usage = func() {
_, _ = fmt.Fprint(os.Stdout, `
Jet generator 2.0.0
Jet generator 2.3.0
Usage:
-source string

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var Actor = newActorTable()
type ActorTable struct {
type actorTable struct {
postgres.Table
//Columns
@ -27,25 +26,38 @@ type ActorTable struct {
MutableColumns postgres.ColumnList
}
// creates new ActorTable with assigned alias
type ActorTable struct {
actorTable
EXCLUDED actorTable
}
// AS creates new ActorTable with assigned alias
func (a *ActorTable) AS(alias string) *ActorTable {
aliasTable := newActorTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newActorTable() *ActorTable {
return &ActorTable{
actorTable: newActorTableImpl("dvds", "actor"),
EXCLUDED: newActorTableImpl("", "excluded"),
}
}
func newActorTableImpl(schemaName, tableName string) actorTable {
var (
ActorIDColumn = postgres.IntegerColumn("actor_id")
FirstNameColumn = postgres.StringColumn("first_name")
LastNameColumn = postgres.StringColumn("last_name")
LastUpdateColumn = postgres.TimestampColumn("last_update")
allColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn}
mutableColumns = postgres.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn}
)
return &ActorTable{
Table: postgres.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn),
return actorTable{
Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns
ActorID: ActorIDColumn,
@ -53,7 +65,7 @@ func newActorTable() *ActorTable {
LastName: LastNameColumn,
LastUpdate: LastUpdateColumn,
AllColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn},
MutableColumns: postgres.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn},
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var Category = newCategoryTable()
type CategoryTable struct {
type categoryTable struct {
postgres.Table
//Columns
@ -26,31 +25,44 @@ type CategoryTable struct {
MutableColumns postgres.ColumnList
}
// creates new CategoryTable with assigned alias
type CategoryTable struct {
categoryTable
EXCLUDED categoryTable
}
// AS creates new CategoryTable with assigned alias
func (a *CategoryTable) AS(alias string) *CategoryTable {
aliasTable := newCategoryTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newCategoryTable() *CategoryTable {
return &CategoryTable{
categoryTable: newCategoryTableImpl("dvds", "category"),
EXCLUDED: newCategoryTableImpl("", "excluded"),
}
}
func newCategoryTableImpl(schemaName, tableName string) categoryTable {
var (
CategoryIDColumn = postgres.IntegerColumn("category_id")
NameColumn = postgres.StringColumn("name")
LastUpdateColumn = postgres.TimestampColumn("last_update")
allColumns = postgres.ColumnList{CategoryIDColumn, NameColumn, LastUpdateColumn}
mutableColumns = postgres.ColumnList{NameColumn, LastUpdateColumn}
)
return &CategoryTable{
Table: postgres.NewTable("dvds", "category", CategoryIDColumn, NameColumn, LastUpdateColumn),
return categoryTable{
Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns
CategoryID: CategoryIDColumn,
Name: NameColumn,
LastUpdate: LastUpdateColumn,
AllColumns: postgres.ColumnList{CategoryIDColumn, NameColumn, LastUpdateColumn},
MutableColumns: postgres.ColumnList{NameColumn, LastUpdateColumn},
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var Film = newFilmTable()
type FilmTable struct {
type filmTable struct {
postgres.Table
//Columns
@ -36,16 +35,27 @@ type FilmTable struct {
MutableColumns postgres.ColumnList
}
// creates new FilmTable with assigned alias
type FilmTable struct {
filmTable
EXCLUDED filmTable
}
// AS creates new FilmTable with assigned alias
func (a *FilmTable) AS(alias string) *FilmTable {
aliasTable := newFilmTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newFilmTable() *FilmTable {
return &FilmTable{
filmTable: newFilmTableImpl("dvds", "film"),
EXCLUDED: newFilmTableImpl("", "excluded"),
}
}
func newFilmTableImpl(schemaName, tableName string) filmTable {
var (
FilmIDColumn = postgres.IntegerColumn("film_id")
TitleColumn = postgres.StringColumn("title")
@ -60,10 +70,12 @@ func newFilmTable() *FilmTable {
LastUpdateColumn = postgres.TimestampColumn("last_update")
SpecialFeaturesColumn = postgres.StringColumn("special_features")
FulltextColumn = postgres.StringColumn("fulltext")
allColumns = postgres.ColumnList{FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn}
mutableColumns = postgres.ColumnList{TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn}
)
return &FilmTable{
Table: postgres.NewTable("dvds", "film", FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn),
return filmTable{
Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns
FilmID: FilmIDColumn,
@ -80,7 +92,7 @@ func newFilmTable() *FilmTable {
SpecialFeatures: SpecialFeaturesColumn,
Fulltext: FulltextColumn,
AllColumns: postgres.ColumnList{FilmIDColumn, TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn},
MutableColumns: postgres.ColumnList{TitleColumn, DescriptionColumn, ReleaseYearColumn, LanguageIDColumn, RentalDurationColumn, RentalRateColumn, LengthColumn, ReplacementCostColumn, RatingColumn, LastUpdateColumn, SpecialFeaturesColumn, FulltextColumn},
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var FilmActor = newFilmActorTable()
type FilmActorTable struct {
type filmActorTable struct {
postgres.Table
//Columns
@ -26,31 +25,44 @@ type FilmActorTable struct {
MutableColumns postgres.ColumnList
}
// creates new FilmActorTable with assigned alias
type FilmActorTable struct {
filmActorTable
EXCLUDED filmActorTable
}
// AS creates new FilmActorTable with assigned alias
func (a *FilmActorTable) AS(alias string) *FilmActorTable {
aliasTable := newFilmActorTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newFilmActorTable() *FilmActorTable {
return &FilmActorTable{
filmActorTable: newFilmActorTableImpl("dvds", "film_actor"),
EXCLUDED: newFilmActorTableImpl("", "excluded"),
}
}
func newFilmActorTableImpl(schemaName, tableName string) filmActorTable {
var (
ActorIDColumn = postgres.IntegerColumn("actor_id")
FilmIDColumn = postgres.IntegerColumn("film_id")
LastUpdateColumn = postgres.TimestampColumn("last_update")
allColumns = postgres.ColumnList{ActorIDColumn, FilmIDColumn, LastUpdateColumn}
mutableColumns = postgres.ColumnList{LastUpdateColumn}
)
return &FilmActorTable{
Table: postgres.NewTable("dvds", "film_actor", ActorIDColumn, FilmIDColumn, LastUpdateColumn),
return filmActorTable{
Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns
ActorID: ActorIDColumn,
FilmID: FilmIDColumn,
LastUpdate: LastUpdateColumn,
AllColumns: postgres.ColumnList{ActorIDColumn, FilmIDColumn, LastUpdateColumn},
MutableColumns: postgres.ColumnList{LastUpdateColumn},
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var FilmCategory = newFilmCategoryTable()
type FilmCategoryTable struct {
type filmCategoryTable struct {
postgres.Table
//Columns
@ -26,31 +25,44 @@ type FilmCategoryTable struct {
MutableColumns postgres.ColumnList
}
// creates new FilmCategoryTable with assigned alias
type FilmCategoryTable struct {
filmCategoryTable
EXCLUDED filmCategoryTable
}
// AS creates new FilmCategoryTable with assigned alias
func (a *FilmCategoryTable) AS(alias string) *FilmCategoryTable {
aliasTable := newFilmCategoryTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newFilmCategoryTable() *FilmCategoryTable {
return &FilmCategoryTable{
filmCategoryTable: newFilmCategoryTableImpl("dvds", "film_category"),
EXCLUDED: newFilmCategoryTableImpl("", "excluded"),
}
}
func newFilmCategoryTableImpl(schemaName, tableName string) filmCategoryTable {
var (
FilmIDColumn = postgres.IntegerColumn("film_id")
CategoryIDColumn = postgres.IntegerColumn("category_id")
LastUpdateColumn = postgres.TimestampColumn("last_update")
allColumns = postgres.ColumnList{FilmIDColumn, CategoryIDColumn, LastUpdateColumn}
mutableColumns = postgres.ColumnList{LastUpdateColumn}
)
return &FilmCategoryTable{
Table: postgres.NewTable("dvds", "film_category", FilmIDColumn, CategoryIDColumn, LastUpdateColumn),
return filmCategoryTable{
Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns
FilmID: FilmIDColumn,
CategoryID: CategoryIDColumn,
LastUpdate: LastUpdateColumn,
AllColumns: postgres.ColumnList{FilmIDColumn, CategoryIDColumn, LastUpdateColumn},
MutableColumns: postgres.ColumnList{LastUpdateColumn},
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var Language = newLanguageTable()
type LanguageTable struct {
type languageTable struct {
postgres.Table
//Columns
@ -26,31 +25,44 @@ type LanguageTable struct {
MutableColumns postgres.ColumnList
}
// creates new LanguageTable with assigned alias
type LanguageTable struct {
languageTable
EXCLUDED languageTable
}
// AS creates new LanguageTable with assigned alias
func (a *LanguageTable) AS(alias string) *LanguageTable {
aliasTable := newLanguageTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newLanguageTable() *LanguageTable {
return &LanguageTable{
languageTable: newLanguageTableImpl("dvds", "language"),
EXCLUDED: newLanguageTableImpl("", "excluded"),
}
}
func newLanguageTableImpl(schemaName, tableName string) languageTable {
var (
LanguageIDColumn = postgres.IntegerColumn("language_id")
NameColumn = postgres.StringColumn("name")
LastUpdateColumn = postgres.TimestampColumn("last_update")
allColumns = postgres.ColumnList{LanguageIDColumn, NameColumn, LastUpdateColumn}
mutableColumns = postgres.ColumnList{NameColumn, LastUpdateColumn}
)
return &LanguageTable{
Table: postgres.NewTable("dvds", "language", LanguageIDColumn, NameColumn, LastUpdateColumn),
return languageTable{
Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns
LanguageID: LanguageIDColumn,
Name: NameColumn,
LastUpdate: LastUpdateColumn,
AllColumns: postgres.ColumnList{LanguageIDColumn, NameColumn, LastUpdateColumn},
MutableColumns: postgres.ColumnList{NameColumn, LastUpdateColumn},
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var ActorInfo = newActorInfoTable()
type ActorInfoTable struct {
type actorInfoTable struct {
postgres.Table
//Columns
@ -27,25 +26,38 @@ type ActorInfoTable struct {
MutableColumns postgres.ColumnList
}
// creates new ActorInfoTable with assigned alias
type ActorInfoTable struct {
actorInfoTable
EXCLUDED actorInfoTable
}
// AS creates new ActorInfoTable with assigned alias
func (a *ActorInfoTable) AS(alias string) *ActorInfoTable {
aliasTable := newActorInfoTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newActorInfoTable() *ActorInfoTable {
return &ActorInfoTable{
actorInfoTable: newActorInfoTableImpl("dvds", "actor_info"),
EXCLUDED: newActorInfoTableImpl("", "excluded"),
}
}
func newActorInfoTableImpl(schemaName, tableName string) actorInfoTable {
var (
ActorIDColumn = postgres.IntegerColumn("actor_id")
FirstNameColumn = postgres.StringColumn("first_name")
LastNameColumn = postgres.StringColumn("last_name")
FilmInfoColumn = postgres.StringColumn("film_info")
allColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}
mutableColumns = postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}
)
return &ActorInfoTable{
Table: postgres.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn),
return actorInfoTable{
Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns
ActorID: ActorIDColumn,
@ -53,7 +65,7 @@ func newActorInfoTable() *ActorInfoTable {
LastName: LastNameColumn,
FilmInfo: FilmInfoColumn,
AllColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn},
MutableColumns: postgres.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn},
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}

View file

@ -1,6 +1,5 @@
//
// Code generated by go-jet DO NOT EDIT.
// Generated at Thursday, 26-Sep-19 12:02:13 CEST
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -14,7 +13,7 @@ import (
var CustomerList = newCustomerListTable()
type CustomerListTable struct {
type customerListTable struct {
postgres.Table
//Columns
@ -32,30 +31,43 @@ type CustomerListTable struct {
MutableColumns postgres.ColumnList
}
// creates new CustomerListTable with assigned alias
type CustomerListTable struct {
customerListTable
EXCLUDED customerListTable
}
// AS creates new CustomerListTable with assigned alias
func (a *CustomerListTable) AS(alias string) *CustomerListTable {
aliasTable := newCustomerListTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newCustomerListTable() *CustomerListTable {
return &CustomerListTable{
customerListTable: newCustomerListTableImpl("dvds", "customer_list"),
EXCLUDED: newCustomerListTableImpl("", "excluded"),
}
}
func newCustomerListTableImpl(schemaName, tableName string) customerListTable {
var (
IDColumn = postgres.IntegerColumn("id")
NameColumn = postgres.StringColumn("name")
AddressColumn = postgres.StringColumn("address")
ZipCodeColumn = postgres.StringColumn("zip code")
PhoneColumn = postgres.StringColumn("phone")
CityColumn = postgres.StringColumn("city")
CountryColumn = postgres.StringColumn("country")
NotesColumn = postgres.StringColumn("notes")
SidColumn = postgres.IntegerColumn("sid")
IDColumn = postgres.IntegerColumn("id")
NameColumn = postgres.StringColumn("name")
AddressColumn = postgres.StringColumn("address")
ZipCodeColumn = postgres.StringColumn("zip code")
PhoneColumn = postgres.StringColumn("phone")
CityColumn = postgres.StringColumn("city")
CountryColumn = postgres.StringColumn("country")
NotesColumn = postgres.StringColumn("notes")
SidColumn = postgres.IntegerColumn("sid")
allColumns = postgres.ColumnList{IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn}
mutableColumns = postgres.ColumnList{IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn}
)
return &CustomerListTable{
Table: postgres.NewTable("dvds", "customer_list", IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn),
return customerListTable{
Table: postgres.NewTable(schemaName, tableName, allColumns...),
//Columns
ID: IDColumn,
@ -68,7 +80,7 @@ func newCustomerListTable() *CustomerListTable {
Notes: NotesColumn,
Sid: SidColumn,
AllColumns: postgres.ColumnList{IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn},
MutableColumns: postgres.ColumnList{IDColumn, NameColumn, AddressColumn, ZipCodeColumn, PhoneColumn, CityColumn, CountryColumn, NotesColumn, SidColumn},
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}

View file

@ -3,10 +3,10 @@
This package contains sample usage for Jet framework.
Jet generated files of interest are in ./gen folder.
Jet generated files of interest are in `./gen` folder.
quick-start.go contains code explained at [README.md](../../README.md#quick-start),
with difference of redirecting json output to files(dest.json and dest2.json) rather then to a
`quick-start.go` - contains code explained at main [README.md](../../README.md#quick-start),
with a difference of redirecting json output to files(`dest.json` and `dest2.json`) rather then to a
standard output.
./gen, dest.json and dest2.json - added to git for presentation purposes.
`./gen`, `dest.json` and `dest2.json` - added to git for presentation purposes.

View file

@ -7,7 +7,7 @@ import (
_ "github.com/lib/pq"
"io/ioutil"
// dot import so go code would resemble as much as native SQL
// dot import so that jet go code would resemble as much as native SQL
// dot import is not mandatory
. "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/table"
. "github.com/go-jet/jet/postgres"
@ -98,15 +98,15 @@ func printStatementInfo(stmt SelectStatement) {
query, args := stmt.Sql()
fmt.Println("Parameterized query: ")
fmt.Println("==============================")
fmt.Println(query)
fmt.Println("Arguments: ")
fmt.Println(args)
debugSQL := stmt.DebugSql()
fmt.Println("\n\n==============================")
fmt.Println("\n\nDebug sql: ")
fmt.Println("==============================")
fmt.Println(debugSQL)
}

View file

@ -3,6 +3,7 @@ package metadata
import (
"database/sql"
"github.com/go-jet/jet/internal/utils"
"strings"
)
// TableMetaData metadata struct
@ -67,15 +68,19 @@ func (t TableMetaData) GoStructName() string {
return utils.ToGoIdentifier(t.name) + "Table"
}
// GoStructImplName returns go struct impl name for sql builder
func (t TableMetaData) GoStructImplName() string {
name := utils.ToGoIdentifier(t.name) + "Table"
return string(strings.ToLower(name)[0]) + name[1:]
}
// GetTableMetaData returns table info metadata
func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData) {
tableInfo.SchemaName = schemaName
tableInfo.name = tableName
tableInfo.PrimaryKeys = getPrimaryKeys(db, querySet, schemaName, tableName)
tableInfo.Columns = getColumnsMetaData(db, querySet, schemaName, tableName)
return
}

View file

@ -8,7 +8,6 @@ import (
"github.com/go-jet/jet/internal/utils"
"path/filepath"
"text/template"
"time"
)
// GenerateFiles generates Go files from tables and enums metadata
@ -22,6 +21,7 @@ func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect j
err := utils.CleanUpGeneratedFiles(destDir)
utils.PanicOnError(err)
tableSQLBuilderTemplate := getTableSQLBuilderTemplate(dialect)
generateSQLBuilderFiles(destDir, "table", tableSQLBuilderTemplate, schemaInfo.TablesMetaData, dialect)
generateSQLBuilderFiles(destDir, "view", tableSQLBuilderTemplate, schemaInfo.ViewsMetaData, dialect)
generateSQLBuilderFiles(destDir, "enum", enumSQLBuilderTemplate, schemaInfo.EnumsMetaData, dialect)
@ -33,6 +33,14 @@ func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect j
fmt.Println("Done")
}
func getTableSQLBuilderTemplate(dialect jet.Dialect) string {
if dialect.Name() == "PostgreSQL" {
return tablePostgreSQLBuilderTemplate
}
return tableSQLBuilderTemplate
}
func generateSQLBuilderFiles(destDir, fileTypes, sqlBuilderTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) {
if len(metaData) == 0 {
return
@ -75,9 +83,6 @@ func GenerateTemplate(templateText string, templateData interface{}, dialect jet
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
"ToGoIdentifier": utils.ToGoIdentifier,
"ToGoEnumValueIdentifier": utils.ToGoEnumValueIdentifier,
"now": func() string {
return time.Now().Format(time.RFC850)
},
"dialect": func() jet.Dialect {
return dialect
},

View file

@ -3,7 +3,6 @@ package template
var autoGenWarningTemplate = `
//
// Code generated by go-jet DO NOT EDIT.
// Generated at {{now}}
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
@ -38,35 +37,104 @@ type {{.GoStructName}} struct {
MutableColumns {{dialect.PackageName}}.ColumnList
}
// creates new {{.GoStructName}} with assigned alias
func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} {
// AS creates new {{.GoStructName}} with assigned alias
func (a *{{.GoStructName}}) AS(alias string) {{.GoStructName}} {
aliasTable := new{{.GoStructName}}()
aliasTable.Table.AS(alias)
return aliasTable
}
func new{{.GoStructName}}() *{{.GoStructName}} {
func new{{.GoStructName}}() {{.GoStructName}} {
var (
{{- range .Columns}}
{{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
{{- end}}
allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }
mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }
)
return &{{.GoStructName}}{
Table: {{dialect.PackageName}}.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}),
return {{.GoStructName}}{
Table: {{dialect.PackageName}}.NewTable("{{.SchemaName}}", "{{.Name}}", allColumns...),
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column,
{{- end}}
AllColumns: {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} },
MutableColumns: {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} },
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`
var tablePostgreSQLBuilderTemplate = `
{{define "column-list" -}}
{{- range $i, $c := . }}
{{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column
{{- end}}
{{- end}}
package {{param "package"}}
import (
"github.com/go-jet/jet/{{dialect.PackageName}}"
)
var {{ToGoIdentifier .Name}} = new{{.GoStructName}}()
type {{.GoStructImplName}} struct {
{{dialect.PackageName}}.Table
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}}
{{- end}}
AllColumns {{dialect.PackageName}}.ColumnList
MutableColumns {{dialect.PackageName}}.ColumnList
}
type {{.GoStructName}} struct {
{{.GoStructImplName}}
EXCLUDED {{.GoStructImplName}}
}
// AS creates new {{.GoStructName}} with assigned alias
func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} {
aliasTable := new{{.GoStructName}}()
aliasTable.Table.AS(alias)
return aliasTable
}
func new{{.GoStructName}}() *{{.GoStructName}} {
return &{{.GoStructName}}{
{{.GoStructImplName}}: new{{.GoStructName}}Impl("{{.SchemaName}}", "{{.Name}}"),
EXCLUDED: new{{.GoStructName}}Impl("", "excluded"),
}
}
func new{{.GoStructName}}Impl(schemaName, tableName string) {{.GoStructImplName}} {
var (
{{- range .Columns}}
{{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
{{- end}}
allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }
mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }
)
return {{.GoStructImplName}}{
Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, allColumns...),
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column,
{{- end}}
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`
var tableModelTemplate = `package model

10
go.mod Normal file
View file

@ -0,0 +1,10 @@
module github.com/go-jet/jet
require (
github.com/go-sql-driver/mysql v1.5.0
github.com/google/go-cmp v0.4.1
github.com/google/uuid v1.1.1
github.com/lib/pq v1.6.0
github.com/pkg/profile v1.5.0
github.com/stretchr/testify v1.6.0
)

60
go.sum Normal file
View file

@ -0,0 +1,60 @@
github.com/alexbrainman/sspi v0.0.0-20180613141037-e580b900e9f5 h1:P5U+E4x5OkVEKQDklVPmzs71WM56RTTRqV4OrDC//Y4=
github.com/alexbrainman/sspi v0.0.0-20180613141037-e580b900e9f5/go.mod h1:976q2ETgjT2snVCf2ZaBnyBbVoPERGjUz+0sofzEfro=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/google/go-cmp v0.4.1 h1:/exdXoGamhu5ONeUJH0deniYLWYvQwW66yvlfiiKTu0=
github.com/google/go-cmp v0.4.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.0 h1:S7P+1Hm5V/AT9cjEcUD5uDaQSX0OE577aCXgoaKpYbQ=
github.com/gorilla/sessions v1.2.0/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM=
github.com/jcmturner/gofork v1.0.0 h1:J7uCkflzTEhUZ64xqKnkDxq3kzc96ajM1Gli5ktUem8=
github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o=
github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o=
github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg=
github.com/jcmturner/gokrb5/v8 v8.2.0 h1:lzPl/30ZLkTveYsYZPKMcgXc8MbnE6RsTd4F9KgiLtk=
github.com/jcmturner/gokrb5/v8 v8.2.0/go.mod h1:T1hnNppQsBtxW0tCHMHTkAt8n/sABdzZgZdoFrZaZNM=
github.com/jcmturner/rpc/v2 v2.0.2 h1:gMB4IwRXYsWw4Bc6o/az2HJgFUA1ffSh90i26ZJ6Xl0=
github.com/jcmturner/rpc/v2 v2.0.2/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
github.com/lib/pq v1.6.0 h1:I5DPxhYJChW9KYc66se+oKFFQX6VuQrKiprsX6ivRZc=
github.com/lib/pq v1.6.0/go.mod h1:4vXEAYvW1fRQ2/FhZ78H73A60MHw1geSm145z2mdY1g=
github.com/pkg/profile v1.5.0 h1:042Buzk+NhDI+DeSAA62RwJL8VAuZUMQZUjCsRz1Mug=
github.com/pkg/profile v1.5.0/go.mod h1:qBsxPvzyUincmltOk6iyRVxHYg4adc0OFOv72ZdLa18=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.0 h1:jlIyCplCJFULU/01vCkhKuTyc3OorI3bJFuw6obfgho=
github.com/stretchr/testify v1.6.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200117160349-530e935923ad/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4 h1:QmwruyY+bKbDDL0BaglrbZABEali68eoMFhTZpCjYVA=
golang.org/x/crypto v0.0.0-20200311171314-f7b00557c8c4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/jcmturner/aescts.v1 v1.0.1/go.mod h1:nsR8qBOg+OucoIW+WMhB3GspUQXq9XorLnQb9XtvcOo=
gopkg.in/jcmturner/dnsutils.v1 v1.0.1/go.mod h1:m3v+5svpVOhtFAP/wSz+yzh4Mc0Fg7eRhxkJMWSIz9Q=
gopkg.in/jcmturner/goidentity.v3 v3.0.0/go.mod h1:oG2kH0IvSYNIu80dVAyu/yoefjq1mNfM5bm88whjWx4=
gopkg.in/jcmturner/gokrb5.v7 v7.5.0/go.mod h1:l8VISx+WGYp+Fp7KRbsiUuXTTOnxIc3Tuvyavf11/WM=
gopkg.in/jcmturner/rpc.v1 v1.1.0/go.mod h1:YIdkC4XfD6GXbzje11McwsDuOlZQSb9W4vfLvuNnlv8=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View file

@ -1,16 +1,16 @@
package snaker
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
func TestSnakeToCamel(t *testing.T) {
assert.Equal(t, SnakeToCamel(""), "")
assert.Equal(t, SnakeToCamel("potato_"), "Potato")
assert.Equal(t, SnakeToCamel("this_has_to_be_uppercased"), "ThisHasToBeUppercased")
assert.Equal(t, SnakeToCamel("this_is_an_id"), "ThisIsAnID")
assert.Equal(t, SnakeToCamel("this_is_an_identifier"), "ThisIsAnIdentifier")
assert.Equal(t, SnakeToCamel("id"), "ID")
assert.Equal(t, SnakeToCamel("oauth_client"), "OAuthClient")
require.Equal(t, SnakeToCamel(""), "")
require.Equal(t, SnakeToCamel("potato_"), "Potato")
require.Equal(t, SnakeToCamel("this_has_to_be_uppercased"), "ThisHasToBeUppercased")
require.Equal(t, SnakeToCamel("this_is_an_id"), "ThisIsAnID")
require.Equal(t, SnakeToCamel("this_is_an_identifier"), "ThisIsAnIdentifier")
require.Equal(t, SnakeToCamel("id"), "ID")
require.Equal(t, SnakeToCamel("oauth_client"), "OAuthClient")
}

View file

@ -42,12 +42,12 @@ func (b *castExpression) serialize(statement StatementType, out *SQLBuilder, opt
castType := b.cast
if castOverride := out.Dialect.OperatorSerializeOverride("CAST"); castOverride != nil {
castOverride(expression, String(castType))(statement, out, options...)
castOverride(expression, String(castType))(statement, out, FallTrough(options)...)
return
}
out.WriteString("CAST(")
expression.serialize(statement, out, options...)
expression.serialize(statement, out, FallTrough(options)...)
out.WriteString("AS")
out.WriteString(castType + ")")
}

View file

@ -6,28 +6,29 @@ import (
// Clause interface
type Clause interface {
Serialize(statementType StatementType, out *SQLBuilder)
Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption)
}
// ClauseWithProjections interface
type ClauseWithProjections interface {
Clause
projections() ProjectionList
Projections() ProjectionList
}
// ClauseSelect struct
type ClauseSelect struct {
Distinct bool
Projections []Projection
Distinct bool
ProjectionList []Projection
}
func (s *ClauseSelect) projections() ProjectionList {
return s.Projections
// Projections returns list of projections for select clause
func (s *ClauseSelect) Projections() ProjectionList {
return s.ProjectionList
}
// Serialize serializes clause into SQLBuilder
func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder) {
func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.NewLine()
out.WriteString("SELECT")
@ -35,11 +36,11 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder) {
out.WriteString("DISTINCT")
}
if len(s.Projections) == 0 {
if len(s.ProjectionList) == 0 {
panic("jet: SELECT clause has to have at least one projection")
}
out.WriteProjections(statementType, s.Projections)
out.WriteProjections(statementType, s.ProjectionList)
}
// ClauseFrom struct
@ -48,7 +49,7 @@ type ClauseFrom struct {
}
// Serialize serializes clause into SQLBuilder
func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder) {
func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if f.Table == nil {
return
}
@ -56,7 +57,7 @@ func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder) {
out.WriteString("FROM")
out.IncreaseIdent()
f.Table.serialize(statementType, out)
f.Table.serialize(statementType, out, FallTrough(options)...)
out.DecreaseIdent()
}
@ -67,18 +68,20 @@ type ClauseWhere struct {
}
// Serialize serializes clause into SQLBuilder
func (c *ClauseWhere) Serialize(statementType StatementType, out *SQLBuilder) {
func (c *ClauseWhere) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if c.Condition == nil {
if c.Mandatory {
panic("jet: WHERE clause not set")
}
return
}
out.NewLine()
if !contains(options, SkipNewLine) {
out.NewLine()
}
out.WriteString("WHERE")
out.IncreaseIdent()
c.Condition.serialize(statementType, out, noWrap)
c.Condition.serialize(statementType, out, NoWrap.WithFallTrough(options)...)
out.DecreaseIdent()
}
@ -88,7 +91,7 @@ type ClauseGroupBy struct {
}
// Serialize serializes clause into SQLBuilder
func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SQLBuilder) {
func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(c.List) == 0 {
return
}
@ -119,7 +122,7 @@ type ClauseHaving struct {
}
// Serialize serializes clause into SQLBuilder
func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder) {
func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if c.Condition == nil {
return
}
@ -128,7 +131,7 @@ func (c *ClauseHaving) Serialize(statementType StatementType, out *SQLBuilder) {
out.WriteString("HAVING")
out.IncreaseIdent()
c.Condition.serialize(statementType, out, noWrap)
c.Condition.serialize(statementType, out, NoWrap.WithFallTrough(options)...)
out.DecreaseIdent()
}
@ -139,7 +142,7 @@ type ClauseOrderBy struct {
}
// Serialize serializes clause into SQLBuilder
func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SQLBuilder) {
func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if o.List == nil {
return
}
@ -168,7 +171,7 @@ type ClauseLimit struct {
}
// Serialize serializes clause into SQLBuilder
func (l *ClauseLimit) Serialize(statementType StatementType, out *SQLBuilder) {
func (l *ClauseLimit) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if l.Count >= 0 {
out.NewLine()
out.WriteString("LIMIT")
@ -182,7 +185,7 @@ type ClauseOffset struct {
}
// Serialize serializes clause into SQLBuilder
func (o *ClauseOffset) Serialize(statementType StatementType, out *SQLBuilder) {
func (o *ClauseOffset) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if o.Count >= 0 {
out.NewLine()
out.WriteString("OFFSET")
@ -196,27 +199,28 @@ type ClauseFor struct {
}
// Serialize serializes clause into SQLBuilder
func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder) {
func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if f.Lock == nil {
return
}
out.NewLine()
out.WriteString("FOR")
f.Lock.serialize(statementType, out)
f.Lock.serialize(statementType, out, FallTrough(options)...)
}
// ClauseSetStmtOperator struct
type ClauseSetStmtOperator struct {
Operator string
All bool
Selects []StatementWithProjections
Selects []SerializerStatement
OrderBy ClauseOrderBy
Limit ClauseLimit
Offset ClauseOffset
}
func (s *ClauseSetStmtOperator) projections() ProjectionList {
// Projections returns set of projections for ClauseSetStmtOperator
func (s *ClauseSetStmtOperator) Projections() ProjectionList {
if len(s.Selects) > 0 {
return s.Selects[0].projections()
}
@ -224,7 +228,7 @@ func (s *ClauseSetStmtOperator) projections() ProjectionList {
}
// Serialize serializes clause into SQLBuilder
func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLBuilder) {
func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(s.Selects) < 2 {
panic("jet: UNION Statement must contain at least two SELECT statements")
}
@ -244,7 +248,7 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLB
panic("jet: select statement of '" + s.Operator + "' is nil")
}
selectStmt.serialize(statementType, out)
selectStmt.serialize(statementType, out, FallTrough(options)...)
}
s.OrderBy.Serialize(statementType, out)
@ -258,7 +262,7 @@ type ClauseUpdate struct {
}
// Serialize serializes clause into SQLBuilder
func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder) {
func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.NewLine()
out.WriteString("UPDATE")
@ -266,17 +270,20 @@ func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder) {
panic("jet: table to update is nil")
}
u.Table.serialize(statementType, out)
u.Table.serialize(statementType, out, FallTrough(options)...)
}
// ClauseSet struct
type ClauseSet struct {
// SetClause struct
type SetClause struct {
Columns []Column
Values []Serializer
}
// Serialize serializes clause into SQLBuilder
func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder) {
func (s *SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(s.Values) == 0 {
return
}
out.NewLine()
out.WriteString("SET")
@ -287,7 +294,7 @@ func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder) {
out.IncreaseIdent(4)
for i, column := range s.Columns {
if i > 0 {
out.WriteString(", ")
out.WriteString(",")
out.NewLine()
}
@ -299,7 +306,7 @@ func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder) {
out.WriteString(" = ")
s.Values[i].serialize(UpdateStatementType, out)
s.Values[i].serialize(UpdateStatementType, out, FallTrough(options)...)
}
out.DecreaseIdent(4)
}
@ -320,7 +327,7 @@ func (i *ClauseInsert) GetColumns() []Column {
}
// Serialize serializes clause into SQLBuilder
func (i *ClauseInsert) Serialize(statementType StatementType, out *SQLBuilder) {
func (i *ClauseInsert) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.NewLine()
out.WriteString("INSERT INTO")
@ -346,7 +353,7 @@ type ClauseValuesQuery struct {
}
// Serialize serializes clause into SQLBuilder
func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuilder) {
func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(v.Rows) == 0 && v.Query == nil {
panic("jet: VALUES or QUERY has to be specified for INSERT statement")
}
@ -355,8 +362,8 @@ func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuild
panic("jet: VALUES or QUERY has to be specified for INSERT statement")
}
v.ClauseValues.Serialize(statementType, out)
v.ClauseQuery.Serialize(statementType, out)
v.ClauseValues.Serialize(statementType, out, FallTrough(options)...)
v.ClauseQuery.Serialize(statementType, out, FallTrough(options)...)
}
// ClauseValues struct
@ -365,27 +372,29 @@ type ClauseValues struct {
}
// Serialize serializes clause into SQLBuilder
func (v *ClauseValues) Serialize(statementType StatementType, out *SQLBuilder) {
func (v *ClauseValues) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(v.Rows) == 0 {
return
}
out.NewLine()
out.WriteString("VALUES")
for rowIndex, row := range v.Rows {
if rowIndex > 0 {
out.WriteString(",")
out.NewLine()
} else {
out.IncreaseIdent(7)
}
out.IncreaseIdent()
out.NewLine()
out.WriteString("(")
SerializeClauseList(statementType, row, out)
out.WriteByte(')')
out.DecreaseIdent()
}
out.DecreaseIdent(7)
}
// ClauseQuery struct
@ -394,12 +403,12 @@ type ClauseQuery struct {
}
// Serialize serializes clause into SQLBuilder
func (v *ClauseQuery) Serialize(statementType StatementType, out *SQLBuilder) {
func (v *ClauseQuery) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if v.Query == nil {
return
}
v.Query.serialize(statementType, out)
v.Query.serialize(statementType, out, FallTrough(options)...)
}
// ClauseDelete struct
@ -408,7 +417,7 @@ type ClauseDelete struct {
}
// Serialize serializes clause into SQLBuilder
func (d *ClauseDelete) Serialize(statementType StatementType, out *SQLBuilder) {
func (d *ClauseDelete) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.NewLine()
out.WriteString("DELETE FROM")
@ -416,7 +425,7 @@ func (d *ClauseDelete) Serialize(statementType StatementType, out *SQLBuilder) {
panic("jet: nil table in DELETE clause")
}
d.Table.serialize(statementType, out)
d.Table.serialize(statementType, out, FallTrough(options)...)
}
// ClauseStatementBegin struct
@ -426,7 +435,7 @@ type ClauseStatementBegin struct {
}
// Serialize serializes clause into SQLBuilder
func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SQLBuilder) {
func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.NewLine()
out.WriteString(d.Name)
@ -435,7 +444,7 @@ func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SQLBu
out.WriteString(", ")
}
table.serialize(statementType, out)
table.serialize(statementType, out, FallTrough(options)...)
}
}
@ -447,7 +456,7 @@ type ClauseOptional struct {
}
// Serialize serializes clause into SQLBuilder
func (d *ClauseOptional) Serialize(statementType StatementType, out *SQLBuilder) {
func (d *ClauseOptional) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if !d.Show {
return
}
@ -463,7 +472,7 @@ type ClauseIn struct {
}
// Serialize serializes clause into SQLBuilder
func (i *ClauseIn) Serialize(statementType StatementType, out *SQLBuilder) {
func (i *ClauseIn) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if i.LockMode == "" {
return
}
@ -485,7 +494,7 @@ type ClauseWindow struct {
}
// Serialize serializes clause into SQLBuilder
func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder) {
func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(i.Definitions) == 0 {
return
}
@ -503,6 +512,46 @@ func (i *ClauseWindow) Serialize(statementType StatementType, out *SQLBuilder) {
out.WriteString("()")
continue
}
def.Window.serialize(statementType, out)
def.Window.serialize(statementType, out, FallTrough(options)...)
}
}
// SetPair clause
type SetPair struct {
Column ColumnSerializer
Value Serializer
}
// SetClauseNew clause
type SetClauseNew []ColumnAssigment
// Serialize for SetClauseNew
func (s SetClauseNew) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(s) == 0 {
return
}
out.NewLine()
out.WriteString("SET")
out.IncreaseIdent(4)
for i, assigment := range s {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
assigment.serialize(statementType, out, FallTrough(options)...)
}
out.DecreaseIdent(4)
}
// KeywordClause type
type KeywordClause struct {
Keyword
}
// Serialize for KeywordClause
func (k KeywordClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
k.serialize(statementType, out, FallTrough(options)...)
}

View file

@ -1,14 +1,14 @@
package jet
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
func TestClauseSelect_Serialize(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "jet: SELECT clause has to have at least one projection")
require.Equal(t, r, "jet: SELECT clause has to have at least one projection")
}()
selectClause := &ClauseSelect{}

View file

@ -12,6 +12,12 @@ type Column interface {
defaultAlias() string
}
// ColumnSerializer is interface for all serializable columns
type ColumnSerializer interface {
Serializer
Column
}
// ColumnExpression interface
type ColumnExpression interface {
Column
@ -99,9 +105,9 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder
if c.subQuery != nil {
out.WriteIdentifier(c.subQuery.Alias())
out.WriteByte('.')
out.WriteIdentifier(c.defaultAlias(), true)
out.WriteIdentifier(c.defaultAlias())
} else {
if c.tableName != "" {
if c.tableName != "" && !contains(options, ShortName) {
out.WriteIdentifier(c.tableName)
out.WriteByte('.')
}
@ -109,45 +115,3 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder
out.WriteIdentifier(c.name)
}
}
//------------------------------------------------------//
// ColumnList is a helper type to support list of columns as single projection
type ColumnList []ColumnExpression
func (cl ColumnList) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.fromImpl(subQuery))
}
return newProjectionList
}
func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) {
projections := ColumnListToProjectionList(cl)
SerializeProjectionList(statement, projections, out)
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl ColumnList) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
func (cl ColumnList) defaultAlias() string { return "" }
// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName
func SetTableName(columnExpression ColumnExpression, tableName string) {
columnExpression.setTableName(tableName)
}
// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery
func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) {
columnExpression.setSubQuery(subQuery)
}

View file

@ -0,0 +1,20 @@
package jet
// ColumnAssigment is interface wrapper around column assigment
type ColumnAssigment interface {
Serializer
isColumnAssigment()
}
type columnAssigmentImpl struct {
column ColumnSerializer
expression Expression
}
func (a columnAssigmentImpl) isColumnAssigment() {}
func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
a.column.serialize(statement, out, ShortName.WithFallTrough(options)...)
out.WriteString("=")
a.expression.serialize(statement, out, FallTrough(options)...)
}

View file

@ -0,0 +1,60 @@
package jet
// ColumnList is a helper type to support list of columns as single projection
type ColumnList []ColumnExpression
// SET creates column assigment for each column in column list. expression should be created by ROW function
func (cl ColumnList) SET(expression Expression) ColumnAssigment {
return columnAssigmentImpl{
column: cl,
expression: expression,
}
}
func (cl ColumnList) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.fromImpl(subQuery))
}
return newProjectionList
}
func (cl ColumnList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")
for i, column := range cl {
if i > 0 {
out.WriteString(", ")
}
column.serialize(statement, out, FallTrough(options)...)
}
out.WriteString(")")
}
func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) {
projections := ColumnListToProjectionList(cl)
SerializeProjectionList(statement, projections, out)
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl ColumnList) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
func (cl ColumnList) defaultAlias() string { return "" }
// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName
func SetTableName(columnExpression ColumnExpression, tableName string) {
columnExpression.setTableName(tableName)
}
// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery
func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) {
columnExpression.setSubQuery(subQuery)
}

View file

@ -6,6 +6,7 @@ type ColumnBool interface {
Column
From(subQuery SelectTable) ColumnBool
SET(boolExp BoolExpression) ColumnAssigment
}
type boolColumnImpl struct {
@ -21,6 +22,13 @@ func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
return newBoolColumn
}
func (i *boolColumnImpl) SET(boolExp BoolExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: boolExp,
}
}
// BoolColumn creates named bool column.
func BoolColumn(name string) ColumnBool {
boolColumn := &boolColumnImpl{}
@ -38,6 +46,7 @@ type ColumnFloat interface {
Column
From(subQuery SelectTable) ColumnFloat
SET(floatExp FloatExpression) ColumnAssigment
}
type floatColumnImpl struct {
@ -53,6 +62,13 @@ func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
return newFloatColumn
}
func (i *floatColumnImpl) SET(floatExp FloatExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: floatExp,
}
}
// FloatColumn creates named float column.
func FloatColumn(name string) ColumnFloat {
floatColumn := &floatColumnImpl{}
@ -70,6 +86,7 @@ type ColumnInteger interface {
Column
From(subQuery SelectTable) ColumnInteger
SET(intExp IntegerExpression) ColumnAssigment
}
type integerColumnImpl struct {
@ -86,6 +103,13 @@ func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
return newIntColumn
}
func (i *integerColumnImpl) SET(intExp IntegerExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: intExp,
}
}
// IntegerColumn creates named integer column.
func IntegerColumn(name string) ColumnInteger {
integerColumn := &integerColumnImpl{}
@ -104,6 +128,7 @@ type ColumnString interface {
Column
From(subQuery SelectTable) ColumnString
SET(stringExp StringExpression) ColumnAssigment
}
type stringColumnImpl struct {
@ -120,6 +145,13 @@ func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
return newStrColumn
}
func (i *stringColumnImpl) SET(stringExp StringExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: stringExp,
}
}
// StringColumn creates named string column.
func StringColumn(name string) ColumnString {
stringColumn := &stringColumnImpl{}
@ -137,6 +169,7 @@ type ColumnTime interface {
Column
From(subQuery SelectTable) ColumnTime
SET(timeExp TimeExpression) ColumnAssigment
}
type timeColumnImpl struct {
@ -152,6 +185,13 @@ func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
return newTimeColumn
}
func (i *timeColumnImpl) SET(timeExp TimeExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timeExp,
}
}
// TimeColumn creates named time column
func TimeColumn(name string) ColumnTime {
timeColumn := &timeColumnImpl{}
@ -183,6 +223,13 @@ func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
return newTimezColumn
}
func (i *timezColumnImpl) SET(timezExp TimezExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timezExp,
}
}
// TimezColumn creates named time with time zone column.
func TimezColumn(name string) ColumnTimez {
timezColumn := &timezColumnImpl{}
@ -200,6 +247,7 @@ type ColumnTimestamp interface {
Column
From(subQuery SelectTable) ColumnTimestamp
SET(timestampExp TimestampExpression) ColumnAssigment
}
type timestampColumnImpl struct {
@ -215,6 +263,13 @@ func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
return newTimestampColumn
}
func (i *timestampColumnImpl) SET(timestampExp TimestampExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timestampExp,
}
}
// TimestampColumn creates named timestamp column
func TimestampColumn(name string) ColumnTimestamp {
timestampColumn := &timestampColumnImpl{}
@ -232,6 +287,7 @@ type ColumnTimestampz interface {
Column
From(subQuery SelectTable) ColumnTimestampz
SET(timestampzExp TimestampzExpression) ColumnAssigment
}
type timestampzColumnImpl struct {
@ -247,6 +303,13 @@ func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
return newTimestampzColumn
}
func (i *timestampzColumnImpl) SET(timestampzExp TimestampzExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timestampzExp,
}
}
// TimestampzColumn creates named timestamp with time zone column.
func TimestampzColumn(name string) ColumnTimestampz {
timestampzColumn := &timestampzColumnImpl{}
@ -264,6 +327,7 @@ type ColumnDate interface {
Column
From(subQuery SelectTable) ColumnDate
SET(dateExp DateExpression) ColumnAssigment
}
type dateColumnImpl struct {
@ -279,6 +343,13 @@ func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
return newDateColumn
}
func (i *dateColumnImpl) SET(dateExp DateExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: dateExp,
}
}
// DateColumn creates named date column.
func DateColumn(name string) ColumnDate {
dateColumn := &dateColumnImpl{}

View file

@ -22,9 +22,9 @@ func TestNewBoolColumn(t *testing.T) {
func TestNewIntColumn(t *testing.T) {
intColumn := IntegerColumn("col_int").From(subQuery)
assertClauseSerialize(t, intColumn, `sub_query."col_int"`)
assertClauseSerialize(t, intColumn.EQ(Int(12)), `(sub_query."col_int" = $1)`, int64(12))
assertProjectionSerialize(t, intColumn, `sub_query."col_int" AS "col_int"`)
assertClauseSerialize(t, intColumn, `sub_query.col_int`)
assertClauseSerialize(t, intColumn.EQ(Int(12)), `(sub_query.col_int = $1)`, int64(12))
assertProjectionSerialize(t, intColumn, `sub_query.col_int AS "col_int"`)
intColumn2 := table1ColInt.From(subQuery)
assertClauseSerialize(t, intColumn2, `sub_query."table1.col_int"`)
@ -35,9 +35,9 @@ func TestNewIntColumn(t *testing.T) {
func TestNewFloatColumnColumn(t *testing.T) {
floatColumn := FloatColumn("col_float").From(subQuery)
assertClauseSerialize(t, floatColumn, `sub_query."col_float"`)
assertClauseSerialize(t, floatColumn.EQ(Float(1.11)), `(sub_query."col_float" = $1)`, float64(1.11))
assertProjectionSerialize(t, floatColumn, `sub_query."col_float" AS "col_float"`)
assertClauseSerialize(t, floatColumn, `sub_query.col_float`)
assertClauseSerialize(t, floatColumn.EQ(Float(1.11)), `(sub_query.col_float = $1)`, float64(1.11))
assertProjectionSerialize(t, floatColumn, `sub_query.col_float AS "col_float"`)
floatColumn2 := table1ColFloat.From(subQuery)
assertClauseSerialize(t, floatColumn2, `sub_query."table1.col_float"`)
@ -47,10 +47,10 @@ func TestNewFloatColumnColumn(t *testing.T) {
func TestNewDateColumnColumn(t *testing.T) {
dateColumn := DateColumn("col_date").From(subQuery)
assertClauseSerialize(t, dateColumn, `sub_query."col_date"`)
assertClauseSerialize(t, dateColumn, `sub_query.col_date`)
assertClauseSerialize(t, dateColumn.EQ(Date(2002, 2, 3)),
`(sub_query."col_date" = $1)`, "2002-02-03")
assertProjectionSerialize(t, dateColumn, `sub_query."col_date" AS "col_date"`)
`(sub_query.col_date = $1)`, "2002-02-03")
assertProjectionSerialize(t, dateColumn, `sub_query.col_date AS "col_date"`)
dateColumn2 := table1ColDate.From(subQuery)
assertClauseSerialize(t, dateColumn2, `sub_query."table1.col_date"`)
@ -61,10 +61,10 @@ func TestNewDateColumnColumn(t *testing.T) {
func TestNewTimeColumnColumn(t *testing.T) {
timeColumn := TimeColumn("col_time").From(subQuery)
assertClauseSerialize(t, timeColumn, `sub_query."col_time"`)
assertClauseSerialize(t, timeColumn, `sub_query.col_time`)
assertClauseSerialize(t, timeColumn.EQ(Time(1, 1, 1, 1)),
`(sub_query."col_time" = $1)`, "01:01:01.000000001")
assertProjectionSerialize(t, timeColumn, `sub_query."col_time" AS "col_time"`)
`(sub_query.col_time = $1)`, "01:01:01.000000001")
assertProjectionSerialize(t, timeColumn, `sub_query.col_time AS "col_time"`)
timeColumn2 := table1ColTime.From(subQuery)
assertClauseSerialize(t, timeColumn2, `sub_query."table1.col_time"`)
@ -75,10 +75,10 @@ func TestNewTimeColumnColumn(t *testing.T) {
func TestNewTimezColumnColumn(t *testing.T) {
timezColumn := TimezColumn("col_timez").From(subQuery)
assertClauseSerialize(t, timezColumn, `sub_query."col_timez"`)
assertClauseSerialize(t, timezColumn, `sub_query.col_timez`)
assertClauseSerialize(t, timezColumn.EQ(Timez(1, 1, 1, 1, "UTC")),
`(sub_query."col_timez" = $1)`, "01:01:01.000000001 UTC")
assertProjectionSerialize(t, timezColumn, `sub_query."col_timez" AS "col_timez"`)
`(sub_query.col_timez = $1)`, "01:01:01.000000001 UTC")
assertProjectionSerialize(t, timezColumn, `sub_query.col_timez AS "col_timez"`)
timezColumn2 := table1ColTimez.From(subQuery)
assertClauseSerialize(t, timezColumn2, `sub_query."table1.col_timez"`)
@ -89,10 +89,10 @@ func TestNewTimezColumnColumn(t *testing.T) {
func TestNewTimestampColumnColumn(t *testing.T) {
timestampColumn := TimestampColumn("col_timestamp").From(subQuery)
assertClauseSerialize(t, timestampColumn, `sub_query."col_timestamp"`)
assertClauseSerialize(t, timestampColumn, `sub_query.col_timestamp`)
assertClauseSerialize(t, timestampColumn.EQ(Timestamp(1, 1, 1, 1, 1, 1)),
`(sub_query."col_timestamp" = $1)`, "0001-01-01 01:01:01")
assertProjectionSerialize(t, timestampColumn, `sub_query."col_timestamp" AS "col_timestamp"`)
`(sub_query.col_timestamp = $1)`, "0001-01-01 01:01:01")
assertProjectionSerialize(t, timestampColumn, `sub_query.col_timestamp AS "col_timestamp"`)
timestampColumn2 := table1ColTimestamp.From(subQuery)
assertClauseSerialize(t, timestampColumn2, `sub_query."table1.col_timestamp"`)
@ -103,10 +103,10 @@ func TestNewTimestampColumnColumn(t *testing.T) {
func TestNewTimestampzColumnColumn(t *testing.T) {
timestampzColumn := TimestampzColumn("col_timestampz").From(subQuery)
assertClauseSerialize(t, timestampzColumn, `sub_query."col_timestampz"`)
assertClauseSerialize(t, timestampzColumn, `sub_query.col_timestampz`)
assertClauseSerialize(t, timestampzColumn.EQ(Timestampz(1, 1, 1, 1, 1, 1, 0, "UTC")),
`(sub_query."col_timestampz" = $1)`, "0001-01-01 01:01:01 UTC")
assertProjectionSerialize(t, timestampzColumn, `sub_query."col_timestampz" AS "col_timestampz"`)
`(sub_query.col_timestampz = $1)`, "0001-01-01 01:01:01 UTC")
assertProjectionSerialize(t, timestampzColumn, `sub_query.col_timestampz AS "col_timestampz"`)
timestampzColumn2 := table1ColTimestampz.From(subQuery)
assertClauseSerialize(t, timestampzColumn2, `sub_query."table1.col_timestampz"`)

View file

@ -72,15 +72,15 @@ func (e *ExpressionInterfaceImpl) DESC() OrderByClause {
}
func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SQLBuilder) {
e.Parent.serialize(statement, out, noWrap)
e.Parent.serialize(statement, out, NoWrap)
}
func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType, out *SQLBuilder) {
e.Parent.serialize(statement, out, noWrap)
e.Parent.serialize(statement, out, NoWrap)
}
func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) {
e.Parent.serialize(statement, out, noWrap)
e.Parent.serialize(statement, out, NoWrap)
}
// Representation of binary operations (e.g. comparisons, arithmetic)
@ -117,7 +117,7 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu
panic("jet: rhs is nil for '" + c.operator + "' operator")
}
wrap := !contains(options, noWrap)
wrap := !contains(options, NoWrap)
if wrap {
out.WriteString("(")
@ -125,11 +125,11 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu
if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam)
serializeOverrideFunc(statement, out, options...)
serializeOverrideFunc(statement, out, FallTrough(options)...)
} else {
c.lhs.serialize(statement, out)
c.lhs.serialize(statement, out, FallTrough(options)...)
out.WriteString(c.operator)
c.rhs.serialize(statement, out)
c.rhs.serialize(statement, out, FallTrough(options)...)
}
if wrap {
@ -163,7 +163,7 @@ func (p *prefixExpression) serialize(statement StatementType, out *SQLBuilder, o
panic("jet: nil prefix expression in prefix operator " + p.operator)
}
p.expression.serialize(statement, out)
p.expression.serialize(statement, out, FallTrough(options)...)
out.WriteString(")")
}
@ -192,7 +192,7 @@ func (p *postfixOpExpression) serialize(statement StatementType, out *SQLBuilder
panic("jet: nil prefix expression in postfix operator " + p.operator)
}
p.expression.serialize(statement, out)
p.expression.serialize(statement, out, FallTrough(options)...)
out.WriteString(p.operator)
}

View file

@ -145,6 +145,11 @@ func MINi(integerExpression IntegerExpression) integerWindowExpression {
return newIntegerWindowFunc("MIN", integerExpression)
}
// SUM is aggregate function. Returns sum of all expressions
func SUM(expression Expression) Expression {
return newWindowFunc("SUM", expression)
}
// SUMf is aggregate function. Returns sum of expression across all float expressions
func SUMf(floatExpression FloatExpression) floatWindowExpression {
return NewFloatWindowFunc("SUM", floatExpression)
@ -613,7 +618,7 @@ func newWindowFunc(name string, expressions ...Expression) windowExpression {
func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(ExpressionListToSerializerList(f.expressions)...)
serializeOverrideFunc(statement, out, options...)
serializeOverrideFunc(statement, out, FallTrough(options)...)
return
}

View file

@ -33,5 +33,5 @@ type IntervalImpl struct {
func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("INTERVAL")
i.interval.serialize(statement, out, options...)
i.interval.serialize(statement, out, FallTrough(options)...)
}

View file

@ -2,18 +2,12 @@ package jet
const (
// DEFAULT is jet equivalent of SQL DEFAULT
DEFAULT keywordClause = "DEFAULT"
DEFAULT Keyword = "DEFAULT"
)
var (
// NULL is jet equivalent of SQL NULL
NULL = newNullLiteral()
// STAR is jet equivalent of SQL *
STAR = newStarLiteral()
)
// Keyword type
type Keyword string
type keywordClause string
func (k keywordClause) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
func (k Keyword) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString(string(k))
}

View file

@ -278,6 +278,14 @@ func formatNanoseconds(nanoseconds ...time.Duration) string {
}
//--------------------------------------------------//
var (
// NULL is jet equivalent of SQL NULL
NULL = newNullLiteral()
// STAR is jet equivalent of SQL *
STAR = newStarLiteral()
)
type nullLiteral struct {
ExpressionInterfaceImpl
}

19
internal/jet/logger.go Normal file
View file

@ -0,0 +1,19 @@
package jet
import "context"
// PrintableStatement is a statement which sql query can be logged
type PrintableStatement interface {
Sql() (query string, args []interface{})
DebugSql() (query string)
}
// LoggerFunc is a definition of a function user can implement to support automatic statement logging.
type LoggerFunc func(ctx context.Context, statement PrintableStatement)
var logger LoggerFunc
// SetLoggerFunc sets automatic statement logging
func SetLoggerFunc(loggerFunc LoggerFunc) {
logger = loggerFunc
}

View file

@ -147,7 +147,7 @@ func (c *caseOperatorImpl) serialize(statement StatementType, out *SQLBuilder, o
out.WriteString("(CASE")
if c.expression != nil {
c.expression.serialize(statement, out)
c.expression.serialize(statement, out, FallTrough(options)...)
}
if len(c.when) == 0 || len(c.then) == 0 {
@ -160,15 +160,15 @@ func (c *caseOperatorImpl) serialize(statement StatementType, out *SQLBuilder, o
for i, when := range c.when {
out.WriteString("WHEN")
when.serialize(statement, out, noWrap)
when.serialize(statement, out, NoWrap)
out.WriteString("THEN")
c.then[i].serialize(statement, out, noWrap)
c.then[i].serialize(statement, out, NoWrap)
}
if c.els != nil {
out.WriteString("ELSE")
c.els.serialize(statement, out, noWrap)
c.els.serialize(statement, out, NoWrap)
}
out.WriteString("END)")

View file

@ -8,35 +8,31 @@ type SelectTable interface {
}
type selectTableImpl struct {
selectStmt StatementWithProjections
selectStmt SerializerStatement
alias string
projections ProjectionList
}
// NewSelectTable func
func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTable {
selectTable := selectTableImpl{selectStmt: selectStmt, alias: alias}
projectionList := selectStmt.projections().fromImpl(&selectTable)
selectTable.projections = projectionList.(ProjectionList)
return &selectTable
func NewSelectTable(selectStmt SerializerStatement, alias string) SelectTable {
selectTable := &selectTableImpl{selectStmt: selectStmt, alias: alias}
return selectTable
}
func (s *selectTableImpl) Alias() string {
func (s selectTableImpl) Alias() string {
return s.alias
}
func (s *selectTableImpl) AllColumns() ProjectionList {
return s.projections
}
func (s *selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if s == nil {
panic("jet: expression table is nil. ")
func (s selectTableImpl) AllColumns() ProjectionList {
statementWithProjections, ok := s.selectStmt.(HasProjections)
if !ok {
return ProjectionList{}
}
projectionList := statementWithProjections.projections().fromImpl(s)
return projectionList.(ProjectionList)
}
func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
s.selectStmt.serialize(statement, out)
out.WriteString("AS")

View file

@ -5,9 +5,18 @@ type SerializeOption int
// Serialize options
const (
noWrap SerializeOption = iota
NoWrap SerializeOption = iota
SkipNewLine
fallTroughOptions // fall trough options
ShortName
)
// WithFallTrough extends existing serialize options with additional
func (s SerializeOption) WithFallTrough(options []SerializeOption) []SerializeOption {
return append(FallTrough(options), s)
}
// StatementType is type of the SQL statement
type StatementType string
@ -20,6 +29,7 @@ const (
SetStatementType StatementType = "SET"
LockStatementType StatementType = "LOCK"
UnLockStatementType StatementType = "UNLOCK"
WithStatementType StatementType = "WITH"
)
// Serializer interface
@ -42,6 +52,19 @@ func contains(options []SerializeOption, option SerializeOption) bool {
return false
}
// FallTrough filters fall-trough options from the list
func FallTrough(options []SerializeOption) []SerializeOption {
var ret []SerializeOption
for _, option := range options {
if option > fallTroughOptions {
ret = append(ret, option)
}
}
return ret
}
// ListSerializer serializes list of serializers with separator
type ListSerializer struct {
Serializers []Serializer
@ -53,6 +76,21 @@ func (s ListSerializer) serialize(statement StatementType, out *SQLBuilder, opti
if i > 0 {
out.WriteString(s.Separator)
}
ser.serialize(statement, out)
ser.serialize(statement, out, FallTrough(options)...)
}
}
// NewSerializerClauseImpl is constructor for Seralizer with list of clauses
func NewSerializerClauseImpl(clauses ...Clause) Serializer {
return &serializerImpl{Clauses: clauses}
}
type serializerImpl struct {
Clauses []Clause
}
func (s serializerImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
for _, clause := range s.Clauses {
clause.Serialize(statement, out, FallTrough(options)...)
}
}

View file

@ -98,7 +98,7 @@ func (s *SQLBuilder) WriteString(str string) {
// WriteIdentifier adds identifier to output SQL
func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
if s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 {
if s.shouldQuote(name, alwaysQuote...) {
identQuoteChar := string(s.Dialect.IdentifierQuoteChar())
s.WriteString(identQuoteChar + name + identQuoteChar)
} else {
@ -106,6 +106,10 @@ func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
}
}
func (s *SQLBuilder) shouldQuote(name string, alwaysQuote ...bool) bool {
return s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0
}
// WriteByte writes byte to output SQL
func (s *SQLBuilder) WriteByte(b byte) {
s.write([]byte{b})
@ -159,10 +163,17 @@ func argToString(value interface{}) string {
case time.Time:
return stringQuote(string(pq.FormatTimestamp(bindVal)))
default:
if strBindValue, ok := bindVal.(toStringInterface); ok {
return stringQuote(strBindValue.String())
}
panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String()))
}
}
type toStringInterface interface {
String() string
}
func integerTypesToString(value interface{}) string {
switch bindVal := value.(type) {
case int:
@ -190,6 +201,13 @@ func integerTypesToString(value interface{}) string {
}
func shouldQuoteIdentifier(identifier string) bool {
_, err := strconv.ParseInt(identifier, 10, 64)
if err == nil { // if it is a number we should quote it
return true
}
// check if contains non ascii characters
for _, c := range identifier {
if unicode.IsNumber(c) || c == '_' {
continue

View file

@ -2,42 +2,58 @@ package jet
import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"time"
)
func TestArgToString(t *testing.T) {
assert.Equal(t, argToString(true), "TRUE")
assert.Equal(t, argToString(false), "FALSE")
require.Equal(t, argToString(true), "TRUE")
require.Equal(t, argToString(false), "FALSE")
assert.Equal(t, argToString(int(-32)), "-32")
assert.Equal(t, argToString(uint(32)), "32")
assert.Equal(t, argToString(int8(-43)), "-43")
assert.Equal(t, argToString(uint8(43)), "43")
assert.Equal(t, argToString(int16(-54)), "-54")
assert.Equal(t, argToString(uint16(54)), "54")
assert.Equal(t, argToString(int32(-65)), "-65")
assert.Equal(t, argToString(uint32(65)), "65")
assert.Equal(t, argToString(int64(-64)), "-64")
assert.Equal(t, argToString(uint64(64)), "64")
assert.Equal(t, argToString(float32(2.0)), "2")
assert.Equal(t, argToString(float64(1.11)), "1.11")
require.Equal(t, argToString(int(-32)), "-32")
require.Equal(t, argToString(uint(32)), "32")
require.Equal(t, argToString(int8(-43)), "-43")
require.Equal(t, argToString(uint8(43)), "43")
require.Equal(t, argToString(int16(-54)), "-54")
require.Equal(t, argToString(uint16(54)), "54")
require.Equal(t, argToString(int32(-65)), "-65")
require.Equal(t, argToString(uint32(65)), "65")
require.Equal(t, argToString(int64(-64)), "-64")
require.Equal(t, argToString(uint64(64)), "64")
require.Equal(t, argToString(float32(2.0)), "2")
require.Equal(t, argToString(float64(1.11)), "1.11")
assert.Equal(t, argToString("john"), "'john'")
assert.Equal(t, argToString("It's text"), "'It''s text'")
assert.Equal(t, argToString([]byte("john")), "'john'")
assert.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'")
require.Equal(t, argToString("john"), "'john'")
require.Equal(t, argToString("It's text"), "'It''s text'")
require.Equal(t, argToString([]byte("john")), "'john'")
require.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'")
time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006")
assert.NoError(t, err)
assert.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'")
require.NoError(t, err)
require.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'")
func() {
defer func() {
assert.Equal(t, recover().(string), "jet: map[string]bool type can not be used as SQL query parameter")
require.Equal(t, recover().(string), "jet: map[string]bool type can not be used as SQL query parameter")
}()
argToString(map[string]bool{})
}()
}
func TestFallTrough(t *testing.T) {
require.Equal(t, FallTrough([]SerializeOption{ShortName}), []SerializeOption{ShortName})
require.Equal(t, FallTrough([]SerializeOption{SkipNewLine}), []SerializeOption(nil))
require.Equal(t, FallTrough([]SerializeOption{ShortName, SkipNewLine}), []SerializeOption{ShortName})
}
func TestShouldQuote(t *testing.T) {
require.Equal(t, shouldQuoteIdentifier("123"), true)
require.Equal(t, shouldQuoteIdentifier("123.235"), true)
require.Equal(t, shouldQuoteIdentifier("abc123"), false)
require.Equal(t, shouldQuoteIdentifier("abc.123"), true)
require.Equal(t, shouldQuoteIdentifier("abc_123"), false)
require.Equal(t, shouldQuoteIdentifier("Abc_123"), true)
require.Equal(t, shouldQuoteIdentifier("DŽƜĐǶ"), true)
}

View file

@ -13,7 +13,6 @@ type Statement interface {
// DebugSql returns debug query where every parametrized placeholder is replaced with its argument.
// Do not use it in production. Use it only for debug purposes.
DebugSql() (query string)
// Query executes statement over database connection db and stores row result in destination.
// Destination can be either pointer to struct or pointer to a slice.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
@ -21,25 +20,19 @@ type Statement interface {
// QueryContext executes statement with a context over database connection db and stores row result in destination.
// Destination can be either pointer to struct or pointer to a slice.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
QueryContext(context context.Context, db qrm.DB, destination interface{}) error
QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error
//Exec executes statement over db connection without returning any rows.
Exec(db qrm.DB) (sql.Result, error)
//Exec executes statement with context over db connection without returning any rows.
ExecContext(context context.Context, db qrm.DB) (sql.Result, error)
ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error)
}
// SerializerStatement interface
type SerializerStatement interface {
Serializer
Statement
}
// StatementWithProjections interface
type StatementWithProjections interface {
Statement
HasProjections
Serializer
}
// HasProjections interface
@ -58,7 +51,7 @@ func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface
queryData := &SQLBuilder{Dialect: s.dialect}
s.parent.serialize(s.statementType, queryData, noWrap)
s.parent.serialize(s.statementType, queryData, NoWrap)
query, args = queryData.finalize()
return
@ -67,7 +60,7 @@ func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface
func (s *serializerStatementInterfaceImpl) DebugSql() (query string) {
sqlBuilder := &SQLBuilder{Dialect: s.dialect, Debug: true}
s.parent.serialize(s.statementType, sqlBuilder, noWrap)
s.parent.serialize(s.statementType, sqlBuilder, NoWrap)
query, _ = sqlBuilder.finalize()
return
@ -75,25 +68,41 @@ func (s *serializerStatementInterfaceImpl) DebugSql() (query string) {
func (s *serializerStatementInterfaceImpl) Query(db qrm.DB, destination interface{}) error {
query, args := s.Sql()
ctx := context.Background()
return qrm.Query(context.Background(), db, query, args, destination)
callLogger(ctx, s)
return qrm.Query(ctx, db, query, args, destination)
}
func (s *serializerStatementInterfaceImpl) QueryContext(context context.Context, db qrm.DB, destination interface{}) error {
func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error {
query, args := s.Sql()
return qrm.Query(context, db, query, args, destination)
callLogger(ctx, s)
return qrm.Query(ctx, db, query, args, destination)
}
func (s *serializerStatementInterfaceImpl) Exec(db qrm.DB) (res sql.Result, err error) {
query, args := s.Sql()
callLogger(context.Background(), s)
return db.Exec(query, args...)
}
func (s *serializerStatementInterfaceImpl) ExecContext(context context.Context, db qrm.DB) (res sql.Result, err error) {
func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.DB) (res sql.Result, err error) {
query, args := s.Sql()
return db.ExecContext(context, query, args...)
callLogger(ctx, s)
return db.ExecContext(ctx, query, args...)
}
func callLogger(ctx context.Context, statement Statement) {
if logger != nil {
logger(ctx, statement)
}
}
// ExpressionStatement interfacess
@ -148,7 +157,7 @@ type statementImpl struct {
func (s *statementImpl) projections() ProjectionList {
for _, clause := range s.Clauses {
if selectClause, ok := clause.(ClauseWithProjections); ok {
return selectClause.projections()
return selectClause.Projections()
}
}
@ -156,17 +165,16 @@ func (s *statementImpl) projections() ProjectionList {
}
func (s *statementImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, noWrap) {
if !contains(options, NoWrap) {
out.WriteString("(")
out.IncreaseIdent()
}
for _, clause := range s.Clauses {
clause.Serialize(statement, out)
clause.Serialize(statement, out, FallTrough(options)...)
}
if !contains(options, noWrap) {
if !contains(options, NoWrap) {
out.DecreaseIdent()
out.NewLine()
out.WriteString(")")

View file

@ -19,17 +19,15 @@ type Table interface {
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, column ColumnExpression, columns ...ColumnExpression) SerializerTable {
columnList := append([]ColumnExpression{column}, columns...)
func NewTable(schemaName, name string, columns ...ColumnExpression) SerializerTable {
t := tableImpl{
schemaName: schemaName,
name: name,
columnList: columnList,
columnList: columns,
}
for _, c := range columnList {
for _, c := range columns {
c.setTableName(name)
}
@ -156,7 +154,7 @@ func (t *joinTableImpl) serialize(statement StatementType, out *SQLBuilder, opti
panic("jet: left hand side of join operation is nil table")
}
t.lhs.serialize(statement, out)
t.lhs.serialize(statement, out, FallTrough(options)...)
out.NewLine()

View file

@ -1,18 +1,18 @@
package jet
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
func TestNewTable(t *testing.T) {
newTable := NewTable("schema", "table", IntegerColumn("intCol"))
assert.Equal(t, newTable.SchemaName(), "schema")
assert.Equal(t, newTable.TableName(), "table")
require.Equal(t, newTable.SchemaName(), "schema")
require.Equal(t, newTable.TableName(), "table")
assert.Equal(t, len(newTable.columns()), 1)
assert.Equal(t, newTable.columns()[0].Name(), "intCol")
require.Equal(t, len(newTable.columns()), 1)
require.Equal(t, newTable.columns()[0].Name(), "intCol")
}
func TestNewJoinTable(t *testing.T) {
@ -24,10 +24,10 @@ func TestNewJoinTable(t *testing.T) {
assertClauseSerialize(t, joinTable, `schema.table
INNER JOIN schema.table2 ON ("intCol1" = "intCol2")`)
assert.Equal(t, joinTable.SchemaName(), "schema")
assert.Equal(t, joinTable.TableName(), "")
require.Equal(t, joinTable.SchemaName(), "schema")
require.Equal(t, joinTable.TableName(), "")
assert.Equal(t, len(joinTable.columns()), 2)
assert.Equal(t, joinTable.columns()[0].Name(), "intCol1")
assert.Equal(t, joinTable.columns()[1].Name(), "intCol2")
require.Equal(t, len(joinTable.columns()), 2)
require.Equal(t, joinTable.columns()[0].Name(), "intCol1")
require.Equal(t, joinTable.columns()[1].Name(), "intCol2")
}

View file

@ -1,7 +1,7 @@
package jet
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"strconv"
"testing"
)
@ -56,14 +56,14 @@ func assertClauseSerialize(t *testing.T, clause Serializer, query string, args .
//fmt.Println(out.Buff.String())
assert.Equal(t, out.Buff.String(), query)
assert.Equal(t, out.Args, args)
require.Equal(t, out.Buff.String(), query)
require.Equal(t, out.Args, args)
}
func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) {
defer func() {
r := recover()
assert.Equal(t, r, errString)
require.Equal(t, r, errString)
}()
out := SQLBuilder{Dialect: defaultDialect}
@ -76,14 +76,14 @@ func assertClauseDebugSerialize(t *testing.T, clause Serializer, query string, a
//fmt.Println(out.Buff.String())
assert.Equal(t, out.Buff.String(), query)
assert.Equal(t, out.Args, args)
require.Equal(t, out.Buff.String(), query)
require.Equal(t, out.Args, args)
}
func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) {
out := SQLBuilder{Dialect: defaultDialect}
projection.serializeForProjection(SelectStatementType, &out)
assert.Equal(t, out.Buff.String(), query)
assert.Equal(t, out.Args, args)
require.Equal(t, out.Buff.String(), query)
require.Equal(t, out.Args, args)
}

View file

@ -59,7 +59,23 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) {
panic("jet: nil column in columns list")
}
out.WriteString(col.Name())
out.WriteIdentifier(col.Name())
}
}
// SerializeColumnExpressionNames func
func SerializeColumnExpressionNames(columns []ColumnExpression, statementType StatementType,
out *SQLBuilder, options ...SerializeOption) {
for i, col := range columns {
if i > 0 {
out.WriteString(", ")
}
if col == nil {
panic("jet: nil column in columns list")
}
col.serialize(statementType, out, options...)
}
}
@ -85,7 +101,8 @@ func ColumnListToProjectionList(columns []ColumnExpression) []Projection {
return ret
}
func valueToClause(value interface{}) Serializer {
// ToSerializerValue creates Serializer type from the value
func ToSerializerValue(value interface{}) Serializer {
if clause, ok := value.(Serializer); ok {
return clause
}
@ -148,7 +165,7 @@ func UnwindRowFromValues(value interface{}, values []interface{}) []Serializer {
allValues := append([]interface{}{value}, values...)
for _, val := range allValues {
row = append(row, valueToClause(val))
row = append(row, ToSerializerValue(val))
}
return row

View file

@ -1,19 +1,19 @@
package jet
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
func TestOptionalOrDefaultString(t *testing.T) {
assert.Equal(t, OptionalOrDefaultString("default"), "default")
assert.Equal(t, OptionalOrDefaultString("default", "optional"), "optional")
require.Equal(t, OptionalOrDefaultString("default"), "default")
require.Equal(t, OptionalOrDefaultString("default", "optional"), "optional")
}
func TestOptionalOrDefaultExpression(t *testing.T) {
defaultExpression := table2ColFloat
optionalExpression := table1Col1
assert.Equal(t, OptionalOrDefaultExpression(defaultExpression), defaultExpression)
assert.Equal(t, OptionalOrDefaultExpression(defaultExpression, optionalExpression), optionalExpression)
require.Equal(t, OptionalOrDefaultExpression(defaultExpression), defaultExpression)
require.Equal(t, OptionalOrDefaultExpression(defaultExpression, optionalExpression), optionalExpression)
}

View file

@ -17,7 +17,7 @@ func (w *commonWindowImpl) serialize(statement StatementType, out *SQLBuilder, o
w.expression.serialize(statement, out)
if w.window != nil {
out.WriteString("OVER")
w.window.serialize(statement, out)
w.window.serialize(statement, out, FallTrough(options)...)
}
}
@ -49,7 +49,7 @@ func (f *windowExpressionImpl) OVER(window ...Window) Expression {
}
func (f *windowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
f.commonWindowImpl.serialize(statement, out, FallTrough(options)...)
}
// -----------------------------------------------------
@ -80,7 +80,7 @@ func (f *floatWindowExpressionImpl) OVER(window ...Window) FloatExpression {
}
func (f *floatWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
f.commonWindowImpl.serialize(statement, out, FallTrough(options)...)
}
// ------------------------------------------------
@ -111,7 +111,7 @@ func (f *integerWindowExpressionImpl) OVER(window ...Window) IntegerExpression {
}
func (f *integerWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
f.commonWindowImpl.serialize(statement, out, FallTrough(options)...)
}
// ------------------------------------------------
@ -142,5 +142,5 @@ func (f *boolWindowExpressionImpl) OVER(window ...Window) BoolExpression {
}
func (f *boolWindowExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
f.commonWindowImpl.serialize(statement, out)
f.commonWindowImpl.serialize(statement, out, FallTrough(options)...)
}

View file

@ -30,7 +30,7 @@ func newWindowImpl(parent Window) *windowImpl {
}
func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, noWrap) {
if !contains(options, NoWrap) {
out.WriteByte('(')
}
@ -40,7 +40,7 @@ func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options
serializeExpressionList(statement, w.partitionBy, ", ", out)
}
w.orderBy.SkipNewLine = true
w.orderBy.Serialize(statement, out)
w.orderBy.Serialize(statement, out, FallTrough(options)...)
if w.frameUnits != "" {
out.WriteString(w.frameUnits)
@ -55,7 +55,7 @@ func (w *windowImpl) serialize(statement StatementType, out *SQLBuilder, options
}
}
if !contains(options, noWrap) {
if !contains(options, NoWrap) {
out.WriteByte(')')
}
}
@ -139,7 +139,7 @@ func (f *frameExtentImpl) serialize(statement StatementType, out *SQLBuilder, op
if f == nil {
return
}
f.offset.serialize(statement, out)
f.offset.serialize(statement, out, FallTrough(options)...)
if f.preceding {
out.WriteString("PRECEDING")
@ -152,12 +152,12 @@ func (f *frameExtentImpl) serialize(statement StatementType, out *SQLBuilder, op
// Window function keywords
var (
UNBOUNDED = keywordClause("UNBOUNDED")
UNBOUNDED = Keyword("UNBOUNDED")
CURRENT_ROW = frameExtentKeyword{"CURRENT ROW"}
)
type frameExtentKeyword struct {
keywordClause
Keyword
}
func (f frameExtentKeyword) isFrameExtent() {}
@ -180,7 +180,7 @@ func (w windowName) serialize(statement StatementType, out *SQLBuilder, options
out.WriteByte('(')
out.WriteString(w.name)
w.windowImpl.serialize(statement, out, noWrap)
w.windowImpl.serialize(statement, out, NoWrap.WithFallTrough(options)...)
out.WriteByte(')')
}

View file

@ -0,0 +1,82 @@
package jet
// WITH function creates new with statement from list of common table expressions for specified dialect
func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statement Statement) Statement {
newWithImpl := &withImpl{
ctes: cte,
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
dialect: dialect,
statementType: WithStatementType,
},
}
newWithImpl.parent = newWithImpl
return func(primaryStatement Statement) Statement {
serializerStatement, ok := primaryStatement.(SerializerStatement)
if !ok {
panic("jet: unsupported main WITH statement.")
}
newWithImpl.primaryStatement = serializerStatement
return newWithImpl
}
}
type withImpl struct {
serializerStatementInterfaceImpl
ctes []CommonTableExpressionDefinition
primaryStatement SerializerStatement
}
func (w withImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.NewLine()
out.WriteString("WITH")
for i, cte := range w.ctes {
if i > 0 {
out.WriteString(",")
}
cte.serialize(statement, out, FallTrough(options)...)
}
w.primaryStatement.serialize(statement, out, NoWrap.WithFallTrough(options)...)
}
func (w withImpl) projections() ProjectionList {
return ProjectionList{}
}
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
selectTableImpl
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
return CommonTableExpression{
selectTableImpl: selectTableImpl{
selectStmt: nil,
alias: name,
},
}
}
func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteIdentifier(c.alias)
}
// AS returns sets definition for a CTE
func (c *CommonTableExpression) AS(statement SerializerStatement) CommonTableExpressionDefinition {
c.selectStmt = statement
return CommonTableExpressionDefinition{cte: c}
}
// CommonTableExpressionDefinition contains implementation details of CTE
type CommonTableExpressionDefinition struct {
cte *CommonTableExpression
}
func (c CommonTableExpressionDefinition) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteIdentifier(c.cte.alias)
out.WriteString("AS")
c.cte.selectStmt.serialize(statement, out, FallTrough(options)...)
}

View file

@ -7,12 +7,14 @@ import (
"github.com/go-jet/jet/internal/jet"
"github.com/go-jet/jet/internal/utils"
"github.com/go-jet/jet/qrm"
"github.com/stretchr/testify/assert"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
@ -21,12 +23,12 @@ import (
func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) {
res, err := stmt.Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
rows, err := res.RowsAffected()
assert.NoError(t, err)
require.NoError(t, err)
if len(rowsAffected) > 0 {
assert.Equal(t, rows, rowsAffected[0])
require.Equal(t, rowsAffected[0], rows)
}
}
@ -34,7 +36,7 @@ func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int
func AssertExecErr(t *testing.T, stmt jet.Statement, db qrm.DB, errorStr string) {
_, err := stmt.Exec(db)
assert.Error(t, err, errorStr)
require.Error(t, err, errorStr)
}
func getFullPath(relativePath string) string {
@ -51,9 +53,9 @@ func PrintJson(v interface{}) {
// AssertJSON check if data json output is the same as expectedJSON
func AssertJSON(t *testing.T, data interface{}, expectedJSON string) {
jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, "\n"+string(jsonData)+"\n", expectedJSON)
require.Equal(t, "\n"+string(jsonData)+"\n", expectedJSON)
}
// SaveJSONFile saves v as json at testRelativePath
@ -71,23 +73,23 @@ func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) {
filePath := getFullPath(testRelativePath)
fileJSONData, err := ioutil.ReadFile(filePath)
assert.NoError(t, err)
require.NoError(t, err)
if runtime.GOOS == "windows" {
fileJSONData = bytes.Replace(fileJSONData, []byte("\r\n"), []byte("\n"), -1)
}
jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NoError(t, err)
require.NoError(t, err)
assert.True(t, string(fileJSONData) == string(jsonData))
require.True(t, string(fileJSONData) == string(jsonData))
//AssertDeepEqual(t, string(fileJSONData), string(jsonData))
}
// AssertStatementSql check if statement Sql() is the same as expectedQuery and expectedArgs
func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) {
queryStr, args := query.Sql()
assert.Equal(t, queryStr, expectedQuery)
require.Equal(t, queryStr, expectedQuery)
if len(expectedArgs) == 0 {
return
@ -99,7 +101,7 @@ func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string,
func AssertStatementSqlErr(t *testing.T, stmt jet.Statement, errorStr string) {
defer func() {
r := recover()
assert.Equal(t, r, errorStr)
require.Equal(t, r, errorStr)
}()
stmt.Sql()
@ -110,17 +112,17 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st
_, args := query.Sql()
if len(expectedArgs) > 0 {
AssertDeepEqual(t, args, expectedArgs)
AssertDeepEqual(t, args, expectedArgs, "arguments are not equal")
}
debuqSql := query.DebugSql()
assert.Equal(t, debuqSql, expectedQuery)
require.Equal(t, debuqSql, expectedQuery)
}
// AssertClauseSerialize checks if clause serialize produces expected query and args
func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) {
// AssertSerialize checks if clause serialize produces expected query and args
func AssertSerialize(t *testing.T, dialect jet.Dialect, serializer jet.Serializer, query string, args ...interface{}) {
out := jet.SQLBuilder{Dialect: dialect}
jet.Serialize(clause, jet.SelectStatementType, &out)
jet.Serialize(serializer, jet.SelectStatementType, &out)
//fmt.Println(out.Buff.String())
@ -131,8 +133,20 @@ func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Seriali
}
}
// AssertDebugClauseSerialize checks if clause serialize produces expected debug query and args
func AssertDebugClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) {
// AssertClauseSerialize checks if clause serialize produces expected query and args
func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Clause, query string, args ...interface{}) {
out := jet.SQLBuilder{Dialect: dialect}
clause.Serialize(jet.SelectStatementType, &out)
require.Equal(t, out.Buff.String(), query)
if len(args) > 0 {
AssertDeepEqual(t, out.Args, args)
}
}
// AssertDebugSerialize checks if clause serialize produces expected debug query and args
func AssertDebugSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) {
out := jet.SQLBuilder{Dialect: dialect, Debug: true}
jet.Serialize(clause, jet.SelectStatementType, &out)
@ -147,17 +161,17 @@ func AssertDebugClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Se
func AssertPanicErr(t *testing.T, fun func(), errorStr string) {
defer func() {
r := recover()
assert.Equal(t, r, errorStr)
require.Equal(t, r, errorStr)
}()
fun()
}
// AssertClauseSerializeErr check if clause serialize panics with errString
func AssertClauseSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) {
// AssertSerializeErr check if clause serialize panics with errString
func AssertSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) {
defer func() {
r := recover()
assert.Equal(t, r, errString)
require.Equal(t, r, errString)
}()
out := jet.SQLBuilder{Dialect: dialect}
@ -177,28 +191,24 @@ func AssertProjectionSerialize(t *testing.T, dialect jet.Dialect, projection jet
func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest interface{}, errString string) {
defer func() {
r := recover()
assert.Equal(t, r, errString)
require.Equal(t, r, errString)
}()
stmt.Query(db, dest)
}
// AssertFileContent check if file content at filePath contains expectedContent text.
func AssertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) {
func AssertFileContent(t *testing.T, filePath string, expectedContent string) {
enumFileData, err := ioutil.ReadFile(filePath)
assert.NoError(t, err)
require.NoError(t, err)
beginIndex := bytes.Index(enumFileData, []byte(contentBegin))
//fmt.Println("-"+string(enumFileData[beginIndex:])+"-")
AssertDeepEqual(t, string(enumFileData[beginIndex:]), expectedContent)
require.Equal(t, "\n"+string(enumFileData), expectedContent)
}
// AssertFileNamesEqual check if all filesInfos are contained in fileNames
func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...string) {
assert.Equal(t, len(fileInfos), len(fileNames))
require.Equal(t, len(fileInfos), len(fileNames))
fileNamesMap := map[string]bool{}
@ -207,11 +217,88 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st
}
for _, fileName := range fileNames {
assert.True(t, fileNamesMap[fileName], fileName+" does not exist.")
require.True(t, fileNamesMap[fileName], fileName+" does not exist.")
}
}
// AssertDeepEqual checks if actual and expected objects are deeply equal.
func AssertDeepEqual(t *testing.T, actual, expected interface{}) {
assert.True(t, cmp.Equal(actual, expected))
func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) {
require.True(t, cmp.Equal(actual, expected), msg)
}
// BoolPtr returns address of bool parameter
func BoolPtr(b bool) *bool {
return &b
}
// Int8Ptr returns address of int8 parameter
func Int8Ptr(i int8) *int8 {
return &i
}
// UInt8Ptr returns address of uint8 parameter
func UInt8Ptr(i uint8) *uint8 {
return &i
}
// Int16Ptr returns address of int16 parameter
func Int16Ptr(i int16) *int16 {
return &i
}
// UInt16Ptr returns address of uint16 parameter
func UInt16Ptr(i uint16) *uint16 {
return &i
}
// Int32Ptr returns address of int32 parameter
func Int32Ptr(i int32) *int32 {
return &i
}
// UInt32Ptr returns address of uint32 parameter
func UInt32Ptr(i uint32) *uint32 {
return &i
}
// Int64Ptr returns address of int64 parameter
func Int64Ptr(i int64) *int64 {
return &i
}
// UInt64Ptr returns address of uint64 parameter
func UInt64Ptr(i uint64) *uint64 {
return &i
}
// StringPtr returns address of string parameter
func StringPtr(s string) *string {
return &s
}
// TimePtr returns address of time.Time parameter
func TimePtr(t time.Time) *time.Time {
return &t
}
// ByteArrayPtr returns address of []byte parameter
func ByteArrayPtr(arr []byte) *[]byte {
return &arr
}
// Float32Ptr returns address of float32 parameter
func Float32Ptr(f float32) *float32 {
return &f
}
// Float64Ptr returns address of float64 parameter
func Float64Ptr(f float64) *float64 {
return &f
}
// UUIDPtr returns address of uuid.UUID
func UUIDPtr(u string) *uuid.UUID {
newUUID := uuid.MustParse(u)
return &newUUID
}

View file

@ -2,32 +2,32 @@ package utils
import (
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
func TestToGoIdentifier(t *testing.T) {
assert.Equal(t, ToGoIdentifier(""), "")
assert.Equal(t, ToGoIdentifier("uuid"), "UUID")
assert.Equal(t, ToGoIdentifier("col1"), "Col1")
assert.Equal(t, ToGoIdentifier("PG-13"), "Pg13")
assert.Equal(t, ToGoIdentifier("13_pg"), "13Pg")
require.Equal(t, ToGoIdentifier(""), "")
require.Equal(t, ToGoIdentifier("uuid"), "UUID")
require.Equal(t, ToGoIdentifier("col1"), "Col1")
require.Equal(t, ToGoIdentifier("PG-13"), "Pg13")
require.Equal(t, ToGoIdentifier("13_pg"), "13Pg")
assert.Equal(t, ToGoIdentifier("mytable"), "Mytable")
assert.Equal(t, ToGoIdentifier("MYTABLE"), "Mytable")
assert.Equal(t, ToGoIdentifier("MyTaBlE"), "MyTaBlE")
assert.Equal(t, ToGoIdentifier("myTaBlE"), "MyTaBlE")
require.Equal(t, ToGoIdentifier("mytable"), "Mytable")
require.Equal(t, ToGoIdentifier("MYTABLE"), "Mytable")
require.Equal(t, ToGoIdentifier("MyTaBlE"), "MyTaBlE")
require.Equal(t, ToGoIdentifier("myTaBlE"), "MyTaBlE")
assert.Equal(t, ToGoIdentifier("my_table"), "MyTable")
assert.Equal(t, ToGoIdentifier("MY_TABLE"), "MyTable")
assert.Equal(t, ToGoIdentifier("My_Table"), "MyTable")
assert.Equal(t, ToGoIdentifier("My Table"), "MyTable")
assert.Equal(t, ToGoIdentifier("My-Table"), "MyTable")
require.Equal(t, ToGoIdentifier("my_table"), "MyTable")
require.Equal(t, ToGoIdentifier("MY_TABLE"), "MyTable")
require.Equal(t, ToGoIdentifier("My_Table"), "MyTable")
require.Equal(t, ToGoIdentifier("My Table"), "MyTable")
require.Equal(t, ToGoIdentifier("My-Table"), "MyTable")
}
func TestToGoEnumValueIdentifier(t *testing.T) {
assert.Equal(t, ToGoEnumValueIdentifier("enum_name", "enum_value"), "EnumValue")
assert.Equal(t, ToGoEnumValueIdentifier("NumEnum", "100"), "NumEnum100")
require.Equal(t, ToGoEnumValueIdentifier("enum_name", "enum_value"), "EnumValue")
require.Equal(t, ToGoEnumValueIdentifier("NumEnum", "100"), "NumEnum100")
}
func TestErrorCatchErr(t *testing.T) {
@ -39,7 +39,7 @@ func TestErrorCatchErr(t *testing.T) {
panic(fmt.Errorf("newError"))
}()
assert.Error(t, err, "newError")
require.Error(t, err, "newError")
}
func TestErrorCatchNonErr(t *testing.T) {
@ -51,5 +51,5 @@ func TestErrorCatchNonErr(t *testing.T) {
panic(11)
}()
assert.Error(t, err, "11")
require.Error(t, err, "11")
}

View file

@ -26,6 +26,7 @@ func newDialect() jet.Dialect {
ArgumentPlaceholder: func(int) string {
return "?"
},
ReservedWords: reservedWords,
}
return jet.NewDialect(mySQLDialectParams)
@ -160,3 +161,267 @@ func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFun
jet.Serialize(expressions[1], statement, out, options...)
}
}
var reservedWords = []string{
"ACCESSIBLE",
"ADD",
"ALL",
"ALTER",
"ANALYZE",
"AND",
"AS",
"ASC",
"ASENSITIVE",
"BEFORE",
"BETWEEN",
"BIGINT",
"BINARY",
"BLOB",
"BOTH",
"BY",
"CALL",
"CASCADE",
"CASE",
"CHANGE",
"CHAR",
"CHARACTER",
"CHECK",
"COLLATE",
"COLUMN",
"CONDITION",
"CONSTRAINT",
"CONTINUE",
"CONVERT",
"CREATE",
"CROSS",
"CUBE",
"CUME_DIST",
"CURRENT_DATE",
"CURRENT_TIME",
"CURRENT_TIMESTAMP",
"CURRENT_USER",
"CURSOR",
"DATABASE",
"DATABASES",
"DAY_HOUR",
"DAY_MICROSECOND",
"DAY_MINUTE",
"DAY_SECOND",
"DEC",
"DECIMAL",
"DECLARE",
"DEFAULT",
"DELAYED",
"DELETE",
"DENSE_RANK",
"DESC",
"DESCRIBE",
"DETERMINISTIC",
"DISTINCT",
"DISTINCTROW",
"DIV",
"DOUBLE",
"DROP",
"DUAL",
"EACH",
"ELSE",
"ELSEIF",
"EMPTY",
"ENCLOSED",
"ESCAPED",
"EXCEPT",
"EXISTS",
"EXIT",
"EXPLAIN",
"FALSE",
"FETCH",
"FIRST_VALUE",
"FLOAT",
"FLOAT4",
"FLOAT8",
"FOR",
"FORCE",
"FOREIGN",
"FROM",
"FULLTEXT",
"FUNCTION",
"GENERATED",
"GET",
"GRANT",
"GROUP",
"GROUPING",
"GROUPS",
"HAVING",
"HIGH_PRIORITY",
"HOUR_MICROSECOND",
"HOUR_MINUTE",
"HOUR_SECOND",
"IF",
"IGNORE",
"IN",
"INDEX",
"INFILE",
"INNER",
"INOUT",
"INSENSITIVE",
"INSERT",
"INT",
"INT1",
"INT2",
"INT3",
"INT4",
"INT8",
"INTEGER",
"INTERVAL",
"INTO",
"IO_AFTER_GTIDS",
"IO_BEFORE_GTIDS",
"IS",
"ITERATE",
"JOIN",
"JSON_TABLE",
"KEY",
"KEYS",
"KILL",
"LAG",
"LAST_VALUE",
"LATERAL",
"LEAD",
"LEADING",
"LEAVE",
"LEFT",
"LIKE",
"LIMIT",
"LINEAR",
"LINES",
"LOAD",
"LOCALTIME",
"LOCALTIMESTAMP",
"LOCK",
"LONG",
"LONGBLOB",
"LONGTEXT",
"LOOP",
"LOW_PRIORITY",
"MASTER_BIND",
"MASTER_SSL_VERIFY_SERVER_CERT",
"MATCH",
"MAXVALUE",
"MEDIUMBLOB",
"MEDIUMINT",
"MEDIUMTEXT",
"MIDDLEINT",
"MINUTE_MICROSECOND",
"MINUTE_SECOND",
"MOD",
"MODIFIES",
"NATURAL",
"NOT",
"NO_WRITE_TO_BINLOG",
"NTH_VALUE",
"NTILE",
"NULL",
"NUMERIC",
"OF",
"ON",
"OPTIMIZE",
"OPTIMIZER_COSTS",
"OPTION",
"OPTIONALLY",
"OR",
"ORDER",
"OUT",
"OUTER",
"OUTFILE",
"OVER",
"PARTITION",
"PERCENT_RANK",
"PRECISION",
"PRIMARY",
"PROCEDURE",
"PURGE",
"RANGE",
"RANK",
"READ",
"READS",
"READ_WRITE",
"REAL",
"RECURSIVE",
"REFERENCES",
"REGEXP",
"RELEASE",
"RENAME",
"REPEAT",
"REPLACE",
"REQUIRE",
"RESIGNAL",
"RESTRICT",
"RETURN",
"REVOKE",
"RIGHT",
"RLIKE",
"ROW",
"ROWS",
"ROW_NUMBER",
"SCHEMA",
"SCHEMAS",
"SECOND_MICROSECOND",
"SELECT",
"SENSITIVE",
"SEPARATOR",
"SET",
"SHOW",
"SIGNAL",
"SMALLINT",
"SPATIAL",
"SPECIFIC",
"SQL",
"SQLEXCEPTION",
"SQLSTATE",
"SQLWARNING",
"SQL_BIG_RESULT",
"SQL_CALC_FOUND_ROWS",
"SQL_SMALL_RESULT",
"SSL",
"STARTING",
"STORED",
"STRAIGHT_JOIN",
"SYSTEM",
"TABLE",
"TERMINATED",
"THEN",
"TINYBLOB",
"TINYINT",
"TINYTEXT",
"TO",
"TRAILING",
"TRIGGER",
"TRUE",
"UNDO",
"UNION",
"UNIQUE",
"UNLOCK",
"UNSIGNED",
"UPDATE",
"USAGE",
"USE",
"USING",
"UTC_DATE",
"UTC_TIME",
"UTC_TIMESTAMP",
"VALUES",
"VARBINARY",
"VARCHAR",
"VARCHARACTER",
"VARYING",
"VIRTUAL",
"WHEN",
"WHERE",
"WHILE",
"WINDOW",
"WITH",
"WRITE",
"XOR",
"YEAR_MONTH",
"ZEROFILL",
}

View file

@ -85,6 +85,9 @@ var MINi = jet.MINi
// MINf is aggregate function. Returns minimum value of float expression across all input values
var MINf = jet.MINf
// SUM is aggregate function. Returns sum of all expressions
var SUM = jet.SUM
// SUMi is aggregate function. Returns sum of integer expression.
var SUMi = jet.SUMi

View file

@ -13,13 +13,15 @@ type InsertStatement interface {
MODEL(data interface{}) InsertStatement
MODELS(data interface{}) InsertStatement
ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
}
func newInsertStatement(table Table, columns []jet.Column) InsertStatement {
newInsert := &insertStatementImpl{}
newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
&newInsert.Insert, &newInsert.ValuesQuery)
&newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnDuplicateKey)
newInsert.Insert.Table = table
newInsert.Insert.Columns = columns
@ -30,26 +32,55 @@ func newInsertStatement(table Table, columns []jet.Column) InsertStatement {
type insertStatementImpl struct {
jet.SerializerStatement
Insert jet.ClauseInsert
ValuesQuery jet.ClauseValuesQuery
Insert jet.ClauseInsert
ValuesQuery jet.ClauseValuesQuery
OnDuplicateKey onDuplicateKeyUpdateClause
}
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values))
return i
func (is *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values))
return is
}
func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement {
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowFromModel(i.Insert.GetColumns(), data))
return i
func (is *insertStatementImpl) MODEL(data interface{}) InsertStatement {
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromModel(is.Insert.GetColumns(), data))
return is
}
func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement {
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowsFromModels(i.Insert.GetColumns(), data)...)
return i
func (is *insertStatementImpl) MODELS(data interface{}) InsertStatement {
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowsFromModels(is.Insert.GetColumns(), data)...)
return is
}
func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
i.ValuesQuery.Query = selectStatement
return i
func (is *insertStatementImpl) ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement {
is.OnDuplicateKey = assigments
return is
}
func (is *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
is.ValuesQuery.Query = selectStatement
return is
}
type onDuplicateKeyUpdateClause []jet.ColumnAssigment
// Serialize for SetClause
func (s onDuplicateKeyUpdateClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(s) == 0 {
return
}
out.NewLine()
out.WriteString("ON DUPLICATE KEY UPDATE")
out.IncreaseIdent(24)
for i, assigment := range s {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
jet.Serialize(assigment, statementType, out, jet.ShortName.WithFallTrough(options)...)
}
out.DecreaseIdent(24)
}

View file

@ -1,7 +1,7 @@
package mysql
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"time"
)
@ -13,15 +13,15 @@ func TestInvalidInsert(t *testing.T) {
func TestInsertNilValue(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), `
INSERT INTO db.table1 (col1) VALUES
(?);
INSERT INTO db.table1 (col1)
VALUES (?);
`, nil)
}
func TestInsertSingleValue(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), `
INSERT INTO db.table1 (col1) VALUES
(?);
INSERT INTO db.table1 (col1)
VALUES (?);
`, int(1))
}
@ -31,8 +31,8 @@ func TestInsertWithColumnList(t *testing.T) {
columnList = append(columnList, table3StrCol)
assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), `
INSERT INTO db.table3 (col_int, col2) VALUES
(?, ?);
INSERT INTO db.table3 (col_int, col2)
VALUES (?, ?);
`, 1, 3)
}
@ -40,15 +40,15 @@ func TestInsertDate(t *testing.T) {
date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC)
assertStatementSql(t, table1.INSERT(table1ColTimestamp).VALUES(date), `
INSERT INTO db.table1 (col_timestamp) VALUES
(?);
INSERT INTO db.table1 (col_timestamp)
VALUES (?);
`, date)
}
func TestInsertMultipleValues(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), `
INSERT INTO db.table1 (col1, col_float, col3) VALUES
(?, ?, ?);
INSERT INTO db.table1 (col1, col_float, col3)
VALUES (?, ?, ?);
`, 1, 2, 3)
}
@ -59,10 +59,10 @@ func TestInsertMultipleRows(t *testing.T) {
VALUES(111, 222)
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float) VALUES
(?, ?),
(?, ?),
(?, ?);
INSERT INTO db.table1 (col1, col_float)
VALUES (?, ?),
(?, ?),
(?, ?);
`, 1, 2, 11, 22, 111, 222)
}
@ -84,9 +84,9 @@ func TestInsertValuesFromModel(t *testing.T) {
MODEL(&toInsert)
expectedSQL := `
INSERT INTO db.table1 (col1, col_float) VALUES
(?, ?),
(?, ?);
INSERT INTO db.table1 (col1, col_float)
VALUES (?, ?),
(?, ?);
`
assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11))
@ -95,7 +95,7 @@ INSERT INTO db.table1 (col1, col_float) VALUES
func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "missing struct field for column : col1")
require.Equal(t, r, "missing struct field for column : col1")
}()
type Table1Model struct {
Col1Prim int
@ -116,7 +116,7 @@ func TestInsertFromNonStructModel(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "jet: data has to be a struct")
require.Equal(t, r, "jet: data has to be a struct")
}()
table2.INSERT(table2ColInt).MODEL([]int{})
@ -127,9 +127,56 @@ func TestInsertDefaultValue(t *testing.T) {
VALUES(DEFAULT, "two")
var expectedSQL = `
INSERT INTO db.table1 (col1, col_float) VALUES
(DEFAULT, ?);
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, ?);
`
assertStatementSql(t, stmt, expectedSQL, "two")
}
func TestInsertOnDuplicateKeyUpdate(t *testing.T) {
stmt := func() InsertStatement {
return table1.INSERT(table1Col1, table1ColFloat).
VALUES(DEFAULT, "two")
}
t.Run("empty list", func(t *testing.T) {
stmt := stmt().ON_DUPLICATE_KEY_UPDATE()
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, ?);
`, "two")
})
t.Run("one set", func(t *testing.T) {
stmt := stmt().ON_DUPLICATE_KEY_UPDATE(table1ColFloat.SET(Float(11.1)))
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, ?)
ON DUPLICATE KEY UPDATE col_float = ?;
`, "two", 11.1)
})
t.Run("all types set", func(t *testing.T) {
stmt := stmt().ON_DUPLICATE_KEY_UPDATE(
table1ColBool.SET(Bool(true)),
table1ColInt.SET(Int(11)),
table1ColFloat.SET(Float(11.1)),
table1ColString.SET(String("str")),
table1ColTime.SET(Time(11, 23, 11)),
table1ColTimestamp.SET(Timestamp(2020, 1, 22, 3, 4, 5)),
table1ColDate.SET(Date(2020, 12, 1)),
)
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, ?)
ON DUPLICATE KEY UPDATE col_bool = ?,
col_int = ?,
col_float = ?,
col_string = ?,
col_time = CAST(? AS TIME),
col_timestamp = TIMESTAMP(?),
col_date = CAST(? AS DATE);
`, "two", true, int64(11), 11.1, "str", "11:23:11", "2020-01-22 03:04:05", "2020-12-01")
})
}

View file

@ -69,7 +69,7 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta
&newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy,
&newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock)
newSelect.Select.Projections = projections
newSelect.Select.ProjectionList = projections
newSelect.From.Table = table
newSelect.Limit.Count = -1
newSelect.Offset.Count = -1

View file

@ -13,7 +13,7 @@ type selectTableImpl struct {
readableTableInterfaceImpl
}
func newSelectTable(selectStmt jet.StatementWithProjections, alias string) SelectTable {
func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable {
subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias),
}

View file

@ -4,13 +4,13 @@ import "github.com/go-jet/jet/internal/jet"
// UNION effectively appends the result of sub-queries(select statements) into single query.
// It eliminates duplicate rows from its result.
func UNION(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement {
func UNION(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...))
}
// UNION_ALL effectively appends the result of sub-queries(select statements) into single query.
// It does not eliminates duplicate rows from its result.
func UNION_ALL(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement {
func UNION_ALL(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...))
}
@ -54,7 +54,7 @@ type setStatementImpl struct {
setOperator jet.ClauseSetStmtOperator
}
func newSetStatementImpl(operator string, all bool, selects []jet.StatementWithProjections) setStatement {
func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStatement) setStatement {
newSetStatement := &setStatementImpl{}
newSetStatement.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SetStatementType, newSetStatement,
&newSetStatement.setOperator)
@ -93,6 +93,6 @@ const (
union = "UNION"
)
func toSelectList(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) []jet.StatementWithProjections {
return append([]jet.StatementWithProjections{lhs, rhs}, selects...)
func toSelectList(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) []jet.SerializerStatement {
return append([]jet.SerializerStatement{lhs, rhs}, selects...)
}

View file

@ -8,7 +8,7 @@ type Table interface {
readableTable
INSERT(columns ...jet.Column) InsertStatement
UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement
UPDATE(columns ...jet.Column) UpdateStatement
DELETE() DeleteStatement
LOCK() LockStatement
}
@ -35,7 +35,7 @@ type readableTable interface {
type joinSelectUpdateTable interface {
ReadableTable
UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement
UPDATE(columns ...jet.Column) UpdateStatement
}
// ReadableTable interface
@ -49,37 +49,37 @@ type readableTableInterfaceImpl struct {
}
// Generates a select query on the current tableName.
func (r *readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement {
func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement {
return newSelectStatement(r.parent, append([]Projection{projection1}, projections...))
}
// Creates a inner join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
func (r readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.InnerJoin, onCondition)
}
// Creates a left join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
func (r readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.LeftJoin, onCondition)
}
// Creates a right join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
func (r readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.RightJoin, onCondition)
}
func (r *readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
func (r readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.FullJoin, onCondition)
}
func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectUpdateTable {
func (r readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.CrossJoin, nil)
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table {
func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table {
t := &tableImpl{
SerializerTable: jet.NewTable(schemaName, name, column, columns...),
SerializerTable: jet.NewTable(schemaName, name, columns...),
}
t.readableTableInterfaceImpl.parent = t
@ -98,8 +98,8 @@ func (t *tableImpl) INSERT(columns ...jet.Column) InsertStatement {
return newInsertStatement(t.parent, jet.UnwidColumnList(columns))
}
func (t *tableImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement {
return newUpdateStatement(t.parent, jet.UnwindColumns(column, columns...))
func (t *tableImpl) UPDATE(columns ...jet.Column) UpdateStatement {
return newUpdateStatement(t.parent, jet.UnwidColumnList(columns))
}
func (t *tableImpl) DELETE() DeleteStatement {

View file

@ -5,9 +5,9 @@ import (
)
func TestJoinNilInputs(t *testing.T) {
assertClauseSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)),
assertSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)),
"jet: right hand side of join operation is nil table")
assertClauseSerializeErr(t, table2.INNER_JOIN(table1, nil),
assertSerializeErr(t, table2.INNER_JOIN(table1, nil),
"jet: join condition is nil")
}

View file

@ -10,3 +10,12 @@ type Projection = jet.Projection
// ProjectionList can be used to create conditional constructed projection list.
type ProjectionList = jet.ProjectionList
// ColumnAssigment is interface wrapper around column assigment
type ColumnAssigment = jet.ColumnAssigment
// PrintableStatement is a statement which sql query can be logged
type PrintableStatement = jet.PrintableStatement
// SetLogger sets automatic statement logging
var SetLogger = jet.SetLoggerFunc

View file

@ -16,14 +16,18 @@ type updateStatementImpl struct {
jet.SerializerStatement
Update jet.ClauseUpdate
Set jet.ClauseSet
Set jet.SetClause
SetNew jet.SetClauseNew
Where jet.ClauseWhere
}
func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
update := &updateStatementImpl{}
update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update,
&update.Set, &update.Where)
update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update,
&update.Update,
&update.Set,
&update.SetNew,
&update.Where)
update.Update.Table = table
update.Set.Columns = columns
@ -33,7 +37,17 @@ func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
}
func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement {
u.Set.Values = jet.UnwindRowFromValues(value, values)
columnAssigment, isColumnAssigment := value.(ColumnAssigment)
if isColumnAssigment {
u.SetNew = []ColumnAssigment{columnAssigment}
for _, value := range values {
u.SetNew = append(u.SetNew, value.(ColumnAssigment))
}
} else {
u.Set.Values = jet.UnwindRowFromValues(value, values)
}
return u
}

View file

@ -23,7 +23,7 @@ WHERE table1.col_int >= ?;
func TestUpdateWithValues(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_int = ?,
SET col_int = ?,
col_float = ?
WHERE table1.col_int >= ?;
`

View file

@ -7,12 +7,14 @@ import (
)
var table1Col1 = IntegerColumn("col1")
var table1ColBool = BoolColumn("col_bool")
var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float")
var table1ColString = StringColumn("col_string")
var table1Col3 = IntegerColumn("col3")
var table1ColTimestamp = TimestampColumn("col_timestamp")
var table1ColBool = BoolColumn("col_bool")
var table1ColDate = DateColumn("col_date")
var table1ColTime = TimeColumn("col_time")
var table1 = NewTable(
"db",
@ -20,10 +22,12 @@ var table1 = NewTable(
table1Col1,
table1ColInt,
table1ColFloat,
table1ColString,
table1Col3,
table1ColBool,
table1ColDate,
table1ColTimestamp,
table1ColTime,
)
var table2Col3 = IntegerColumn("col3")
@ -59,15 +63,15 @@ var table3 = NewTable(
table3StrCol)
func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) {
testutils.AssertClauseSerialize(t, Dialect, clause, query, args...)
testutils.AssertSerialize(t, Dialect, clause, query, args...)
}
func assertDebugSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) {
testutils.AssertDebugClauseSerialize(t, Dialect, clause, query, args...)
testutils.AssertDebugSerialize(t, Dialect, clause, query, args...)
}
func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) {
testutils.AssertClauseSerializeErr(t, Dialect, clause, errString)
func assertSerializeErr(t *testing.T, clause jet.Serializer, errString string) {
testutils.AssertSerializeErr(t, Dialect, clause, errString)
}
func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) {

26
mysql/with_statement.go Normal file
View file

@ -0,0 +1,26 @@
package mysql
import "github.com/go-jet/jet/internal/jet"
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
readableTableInterfaceImpl
jet.CommonTableExpression
}
// WITH function creates new WITH statement from list of common table expressions
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, cte...)
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
cte := CommonTableExpression{
readableTableInterfaceImpl: readableTableInterfaceImpl{},
CommonTableExpression: jet.CTE(name),
}
cte.parent = &cte
return cte
}

92
postgres/clause.go Normal file
View file

@ -0,0 +1,92 @@
package postgres
import (
"github.com/go-jet/jet/internal/jet"
)
type clauseReturning struct {
ProjectionList []jet.Projection
}
func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(r.ProjectionList) == 0 {
return
}
out.NewLine()
out.WriteString("RETURNING")
out.IncreaseIdent()
out.WriteProjections(statementType, r.ProjectionList)
out.DecreaseIdent()
}
func (r clauseReturning) Projections() ProjectionList {
return r.ProjectionList
}
// ========================================== //
type onConflict interface {
ON_CONSTRAINT(name string) conflictTarget
WHERE(indexPredicate BoolExpression) conflictTarget
DO_NOTHING() InsertStatement
DO_UPDATE(action conflictAction) InsertStatement
}
type conflictTarget interface {
DO_NOTHING() InsertStatement
DO_UPDATE(action conflictAction) InsertStatement
}
type onConflictClause struct {
insertStatement InsertStatement
constraint string
indexExpressions []jet.ColumnExpression
whereClause jet.ClauseWhere
do jet.Serializer
}
func (o *onConflictClause) ON_CONSTRAINT(name string) conflictTarget {
o.constraint = name
return o
}
func (o *onConflictClause) WHERE(indexPredicate BoolExpression) conflictTarget {
o.whereClause.Condition = indexPredicate
return o
}
func (o *onConflictClause) DO_NOTHING() InsertStatement {
o.do = jet.Keyword("DO NOTHING")
return o.insertStatement
}
func (o *onConflictClause) DO_UPDATE(action conflictAction) InsertStatement {
o.do = action
return o.insertStatement
}
func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(o.indexExpressions) == 0 && o.constraint == "" {
return
}
out.NewLine()
out.WriteString("ON CONFLICT")
if len(o.indexExpressions) > 0 {
out.WriteString("(")
jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName)
out.WriteString(")")
}
if o.constraint != "" {
out.WriteString("ON CONSTRAINT")
out.WriteString(o.constraint)
}
o.whereClause.Serialize(statementType, out, jet.SkipNewLine, jet.ShortName)
out.IncreaseIdent(7)
jet.Serialize(o.do, statementType, out)
out.DecreaseIdent(7)
}

35
postgres/clause_test.go Normal file
View file

@ -0,0 +1,35 @@
package postgres
import "testing"
func TestOnConflict(t *testing.T) {
assertClauseSerialize(t, &onConflictClause{}, "")
onConflict := &onConflictClause{}
onConflict.DO_NOTHING()
assertClauseSerialize(t, onConflict, "")
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool}}
onConflict.DO_NOTHING()
assertClauseSerialize(t, onConflict, `
ON CONFLICT (col_bool) DO NOTHING`)
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool}}
onConflict.ON_CONSTRAINT("table_pkey").DO_NOTHING()
assertClauseSerialize(t, onConflict, `
ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`)
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool, table2ColFloat}}
onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)).
DO_UPDATE(
SET(table1ColBool.SET(Bool(true)),
table1ColInt.SET(Int(11))).
WHERE(table2ColFloat.GT(Float(11.1))),
)
assertClauseSerialize(t, onConflict, `
ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE
SET col_bool = $1,
col_int = $2
WHERE table2.col_float > $3`)
}

View file

@ -1,20 +0,0 @@
package postgres
import (
"github.com/go-jet/jet/internal/jet"
)
type clauseReturning struct {
Projections []jet.Projection
}
func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder) {
if len(r.Projections) == 0 {
return
}
out.NewLine()
out.WriteString("RETURNING")
out.IncreaseIdent()
out.WriteProjections(statementType, r.Projections)
}

View file

@ -8,10 +8,10 @@ func TestNewIntervalColumn(t *testing.T) {
subQuery := SELECT(Int(1)).AsTable("sub_query")
subQueryIntervalColumn := IntervalColumn("col_interval").From(subQuery)
assertSerialize(t, subQueryIntervalColumn, `sub_query."col_interval"`)
assertSerialize(t, subQueryIntervalColumn, `sub_query.col_interval`)
assertSerialize(t, subQueryIntervalColumn.EQ(INTERVAL(2, HOUR, 10, MINUTE)),
`(sub_query."col_interval" = INTERVAL '2 HOUR 10 MINUTE')`)
assertProjectionSerialize(t, subQueryIntervalColumn, `sub_query."col_interval" AS "col_interval"`)
`(sub_query.col_interval = INTERVAL '2 HOUR 10 MINUTE')`)
assertProjectionSerialize(t, subQueryIntervalColumn, `sub_query.col_interval AS "col_interval"`)
subQueryIntervalColumn2 := table1ColInterval.From(subQuery)
assertSerialize(t, subQueryIntervalColumn2, `sub_query."table1.col_interval"`)

View file

@ -0,0 +1,30 @@
package postgres
import "github.com/go-jet/jet/internal/jet"
type conflictAction interface {
jet.Serializer
WHERE(condition BoolExpression) conflictAction
}
// SET creates conflict action for ON_CONFLICT clause
func SET(assigments ...ColumnAssigment) conflictAction {
conflictAction := updateConflictActionImpl{}
conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"}
conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where)
conflictAction.set = assigments
return &conflictAction
}
type updateConflictActionImpl struct {
jet.Serializer
doUpdate jet.KeywordClause
set jet.SetClauseNew
where jet.ClauseWhere
}
func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction {
u.where.Condition = condition
return u
}

View file

@ -4,7 +4,7 @@ import "github.com/go-jet/jet/internal/jet"
// DeleteStatement is interface for PostgreSQL DELETE statement
type DeleteStatement interface {
Statement
jet.SerializerStatement
WHERE(expression BoolExpression) DeleteStatement
@ -37,6 +37,6 @@ func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
}
func (d *deleteStatementImpl) RETURNING(projections ...jet.Projection) DeleteStatement {
d.Returning.Projections = projections
d.Returning.ProjectionList = projections
return d
}

View file

@ -87,6 +87,9 @@ var MINf = jet.MINf
// MINi is aggregate function. Returns minimum value of int expression across all input values
var MINi = jet.MINi
// SUM is aggregate function. Returns sum of all expressions
var SUM = jet.SUM
// SUMf is aggregate function. Returns sum of expression across all float expressions
var SUMf = jet.SUMf

View file

@ -4,25 +4,25 @@ import "github.com/go-jet/jet/internal/jet"
// InsertStatement is interface for SQL INSERT statements
type InsertStatement interface {
Statement
jet.SerializerStatement
// Insert row of values
VALUES(value interface{}, values ...interface{}) InsertStatement
// Insert row of values, where value for each column is extracted from filed of structure data.
// If data is not struct or there is no field for every column selected, this method will panic.
MODEL(data interface{}) InsertStatement
MODELS(data interface{}) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
RETURNING(projections ...jet.Projection) InsertStatement
ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict
RETURNING(projections ...Projection) InsertStatement
}
func newInsertStatement(table WritableTable, columns []jet.Column) InsertStatement {
newInsert := &insertStatementImpl{}
newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
&newInsert.Insert, &newInsert.ValuesQuery, &newInsert.Returning)
&newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnConflict, &newInsert.Returning)
newInsert.Insert.Table = table
newInsert.Insert.Columns = columns
@ -36,6 +36,7 @@ type insertStatementImpl struct {
Insert jet.ClauseInsert
ValuesQuery jet.ClauseValuesQuery
Returning clauseReturning
OnConflict onConflictClause
}
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
@ -54,7 +55,7 @@ func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement {
}
func (i *insertStatementImpl) RETURNING(projections ...jet.Projection) InsertStatement {
i.Returning.Projections = projections
i.Returning.ProjectionList = projections
return i
}
@ -62,3 +63,11 @@ func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertState
i.ValuesQuery.Query = selectStatement
return i
}
func (i *insertStatementImpl) ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict {
i.OnConflict = onConflictClause{
insertStatement: i,
indexExpressions: indexExpressions,
}
return &i.OnConflict
}

View file

@ -1,7 +1,8 @@
package postgres
import (
"github.com/stretchr/testify/assert"
"github.com/go-jet/jet/internal/jet"
"github.com/stretchr/testify/require"
"testing"
"time"
)
@ -13,15 +14,15 @@ func TestInvalidInsert(t *testing.T) {
func TestInsertNilValue(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), `
INSERT INTO db.table1 (col1) VALUES
($1);
INSERT INTO db.table1 (col1)
VALUES ($1);
`, nil)
}
func TestInsertSingleValue(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), `
INSERT INTO db.table1 (col1) VALUES
($1);
INSERT INTO db.table1 (col1)
VALUES ($1);
`, int(1))
}
@ -29,8 +30,8 @@ func TestInsertWithColumnList(t *testing.T) {
columnList := ColumnList{table3ColInt, table3StrCol}
assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), `
INSERT INTO db.table3 (col_int, col2) VALUES
($1, $2);
INSERT INTO db.table3 (col_int, col2)
VALUES ($1, $2);
`, 1, 3)
}
@ -38,15 +39,15 @@ func TestInsertDate(t *testing.T) {
date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC)
assertStatementSql(t, table1.INSERT(table1ColTime).VALUES(date), `
INSERT INTO db.table1 (col_time) VALUES
($1);
INSERT INTO db.table1 (col_time)
VALUES ($1);
`, date)
}
func TestInsertMultipleValues(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), `
INSERT INTO db.table1 (col1, col_float, col3) VALUES
($1, $2, $3);
assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1ColBool).VALUES(1, 2, 3), `
INSERT INTO db.table1 (col1, col_float, col_bool)
VALUES ($1, $2, $3);
`, 1, 2, 3)
}
@ -57,10 +58,10 @@ func TestInsertMultipleRows(t *testing.T) {
VALUES(111, 222)
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float) VALUES
($1, $2),
($3, $4),
($5, $6);
INSERT INTO db.table1 (col1, col_float)
VALUES ($1, $2),
($3, $4),
($5, $6);
`, 1, 2, 11, 22, 111, 222)
}
@ -82,18 +83,18 @@ func TestInsertValuesFromModel(t *testing.T) {
MODEL(&toInsert)
expectedSQL := `
INSERT INTO db.table1 (col1, col_float) VALUES
($1, $2),
($3, $4);
INSERT INTO db.table1 (col1, col_float)
VALUES ($1, $2),
($3, $4);
`
assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11))
assertStatementSql(t, stmt, expectedSQL, 1, float64(1.11), 1, float64(1.11))
}
func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "missing struct field for column : col1")
require.Equal(t, r, "missing struct field for column : col1")
}()
type Table1Model struct {
Col1Prim int
@ -114,7 +115,7 @@ func TestInsertFromNonStructModel(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "jet: data has to be a struct")
require.Equal(t, r, "jet: data has to be a struct")
}()
table2.INSERT(table2ColInt).MODEL([]int{})
@ -139,9 +140,63 @@ func TestInsertDefaultValue(t *testing.T) {
VALUES(DEFAULT, "two")
var expectedSQL = `
INSERT INTO db.table1 (col1, col_float) VALUES
(DEFAULT, $1);
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, $1);
`
assertStatementSql(t, stmt, expectedSQL, "two")
}
func TestInsert_ON_CONFLICT(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColBool).
VALUES("one", "two").
VALUES("1", "2").
VALUES("theta", "beta").
ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE(
SET(table1ColBool.SET(Bool(true)),
table2ColInt.SET(Int(1)),
ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))),
).WHERE(table1Col1.GT(Int(2))),
).
RETURNING(table1Col1, table1ColBool)
assertDebugStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_bool)
VALUES ('one', 'two'),
('1', '2'),
('theta', 'beta')
ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE
SET col_bool = TRUE,
col_int = 1,
(col1, col_bool) = ROW(2, 'two')
WHERE table1.col1 > 2
RETURNING table1.col1 AS "table1.col1",
table1.col_bool AS "table1.col_bool";
`)
}
func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColBool).
VALUES("one", "two").
VALUES("1", "2").
ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE(
SET(table1ColBool.SET(Bool(false)),
table2ColInt.SET(Int(1)),
ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))),
).WHERE(table1Col1.GT(Int(2))),
).
RETURNING(table1Col1, table1ColBool)
assertDebugStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_bool)
VALUES ('one', 'two'),
('1', '2')
ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE
SET col_bool = FALSE,
col_int = 1,
(col1, col_bool) = ROW(2, 'two')
WHERE table1.col1 > 2
RETURNING table1.col1 AS "table1.col1",
table1.col_bool AS "table1.col_bool";
`)
}

View file

@ -75,7 +75,7 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta
&newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy,
&newSelect.Limit, &newSelect.Offset, &newSelect.For)
newSelect.Select.Projections = projections
newSelect.Select.ProjectionList = projections
newSelect.From.Table = table
newSelect.Limit.Count = -1
newSelect.Offset.Count = -1

View file

@ -13,7 +13,7 @@ type selectTableImpl struct {
readableTableInterfaceImpl
}
func newSelectTable(selectStmt jet.StatementWithProjections, alias string) SelectTable {
func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable {
subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias),
}

View file

@ -4,37 +4,37 @@ import "github.com/go-jet/jet/internal/jet"
// UNION effectively appends the result of sub-queries(select statements) into single query.
// It eliminates duplicate rows from its result.
func UNION(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement {
func UNION(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...))
}
// UNION_ALL effectively appends the result of sub-queries(select statements) into single query.
// It does not eliminates duplicate rows from its result.
func UNION_ALL(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement {
func UNION_ALL(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...))
}
// INTERSECT returns all rows that are in query results.
// It eliminates duplicate rows from its result.
func INTERSECT(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement {
func INTERSECT(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(intersect, false, toSelectList(lhs, rhs, selects...))
}
// INTERSECT_ALL returns all rows that are in query results.
// It does not eliminates duplicate rows from its result.
func INTERSECT_ALL(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) setStatement {
func INTERSECT_ALL(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(intersect, true, toSelectList(lhs, rhs, selects...))
}
// EXCEPT returns all rows that are in the result of query lhs but not in the result of query rhs.
// It eliminates duplicate rows from its result.
func EXCEPT(lhs, rhs jet.StatementWithProjections) setStatement {
func EXCEPT(lhs, rhs jet.SerializerStatement) setStatement {
return newSetStatementImpl(except, false, toSelectList(lhs, rhs))
}
// EXCEPT_ALL returns all rows that are in the result of query lhs but not in the result of query rhs.
// It does not eliminates duplicate rows from its result.
func EXCEPT_ALL(lhs, rhs jet.StatementWithProjections) setStatement {
func EXCEPT_ALL(lhs, rhs jet.SerializerStatement) setStatement {
return newSetStatementImpl(except, true, toSelectList(lhs, rhs))
}
@ -98,7 +98,7 @@ type setStatementImpl struct {
setOperator jet.ClauseSetStmtOperator
}
func newSetStatementImpl(operator string, all bool, selects []jet.StatementWithProjections) setStatement {
func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStatement) setStatement {
newSetStatement := &setStatementImpl{}
newSetStatement.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SetStatementType, newSetStatement,
&newSetStatement.setOperator)
@ -139,6 +139,6 @@ const (
except = "EXCEPT"
)
func toSelectList(lhs, rhs jet.StatementWithProjections, selects ...jet.StatementWithProjections) []jet.StatementWithProjections {
return append([]jet.StatementWithProjections{lhs, rhs}, selects...)
func toSelectList(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) []jet.SerializerStatement {
return append([]jet.SerializerStatement{lhs, rhs}, selects...)
}

View file

@ -31,7 +31,7 @@ type readableTable interface {
type writableTable interface {
INSERT(columns ...jet.Column) InsertStatement
UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement
UPDATE(columns ...jet.Column) UpdateStatement
DELETE() DeleteStatement
LOCK() LockStatement
}
@ -54,30 +54,30 @@ type readableTableInterfaceImpl struct {
}
// Generates a select query on the current tableName.
func (r *readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement {
func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement {
return newSelectStatement(r.parent, append([]Projection{projection1}, projections...))
}
// Creates a inner join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
func (r readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, jet.InnerJoin, onCondition)
}
// Creates a left join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
func (r readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, jet.LeftJoin, onCondition)
}
// Creates a right join tableName Expression using onCondition.
func (r *readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
func (r readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, jet.RightJoin, onCondition)
}
func (r *readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
func (r readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return newJoinTable(r.parent, table, jet.FullJoin, onCondition)
}
func (r *readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) ReadableTable {
func (r readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) ReadableTable {
return newJoinTable(r.parent, table, jet.CrossJoin, nil)
}
@ -89,8 +89,8 @@ func (w *writableTableInterfaceImpl) INSERT(columns ...jet.Column) InsertStateme
return newInsertStatement(w.parent, jet.UnwidColumnList(columns))
}
func (w *writableTableInterfaceImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement {
return newUpdateStatement(w.parent, jet.UnwindColumns(column, columns...))
func (w *writableTableInterfaceImpl) UPDATE(columns ...jet.Column) UpdateStatement {
return newUpdateStatement(w.parent, jet.UnwidColumnList(columns))
}
func (w *writableTableInterfaceImpl) DELETE() DeleteStatement {
@ -109,10 +109,10 @@ type tableImpl struct {
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table {
func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table {
t := &tableImpl{
SerializerTable: jet.NewTable(schemaName, name, column, columns...),
SerializerTable: jet.NewTable(schemaName, name, columns...),
}
t.readableTableInterfaceImpl.parent = t

View file

@ -5,9 +5,9 @@ import (
)
func TestJoinNilInputs(t *testing.T) {
assertClauseSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)),
assertSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)),
"jet: right hand side of join operation is nil table")
assertClauseSerializeErr(t, table2.INNER_JOIN(table1, nil),
assertSerializeErr(t, table2.INNER_JOIN(table1, nil),
"jet: join condition is nil")
}

View file

@ -10,3 +10,12 @@ type Projection = jet.Projection
// ProjectionList can be used to create conditional constructed projection list.
type ProjectionList = jet.ProjectionList
// ColumnAssigment is interface wrapper around column assigment
type ColumnAssigment = jet.ColumnAssigment
// PrintableStatement is a statement which sql query can be logged
type PrintableStatement = jet.PrintableStatement
// SetLogger sets automatic statement logging
var SetLogger = jet.SetLoggerFunc

View file

@ -6,7 +6,7 @@ import (
// UpdateStatement is interface of SQL UPDATE statement
type UpdateStatement interface {
Statement
jet.SerializerStatement
SET(value interface{}, values ...interface{}) UpdateStatement
MODEL(data interface{}) UpdateStatement
@ -20,14 +20,19 @@ type updateStatementImpl struct {
Update jet.ClauseUpdate
Set clauseSet
SetNew jet.SetClauseNew
Where jet.ClauseWhere
Returning clauseReturning
}
func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStatement {
update := &updateStatementImpl{}
update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update,
&update.Set, &update.Where, &update.Returning)
update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update,
&update.Update,
&update.Set,
&update.SetNew,
&update.Where,
&update.Returning)
update.Update.Table = table
update.Set.Columns = columns
@ -37,7 +42,17 @@ func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStateme
}
func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement {
u.Set.Values = jet.UnwindRowFromValues(value, values)
columnAssigment, isColumnAssigment := value.(ColumnAssigment)
if isColumnAssigment {
u.SetNew = []ColumnAssigment{columnAssigment}
for _, value := range values {
u.SetNew = append(u.SetNew, value.(ColumnAssigment))
}
} else {
u.Set.Values = jet.UnwindRowFromValues(value, values)
}
return u
}
@ -52,7 +67,7 @@ func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
}
func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateStatement {
u.Returning.Projections = projections
u.Returning.ProjectionList = projections
return u
}
@ -61,7 +76,10 @@ type clauseSet struct {
Values []jet.Serializer
}
func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder) {
func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(s.Values) == 0 {
return
}
out.NewLine()
out.WriteString("SET")

View file

@ -10,7 +10,6 @@ import (
var table1Col1 = IntegerColumn("col1")
var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float")
var table1Col3 = IntegerColumn("col3")
var table1ColTime = TimeColumn("col_time")
var table1ColTimez = TimezColumn("col_timez")
var table1ColTimestamp = TimestampColumn("col_timestamp")
@ -25,7 +24,6 @@ var table1 = NewTable(
table1Col1,
table1ColInt,
table1ColFloat,
table1Col3,
table1ColTime,
table1ColTimez,
table1ColBool,
@ -75,12 +73,16 @@ var table3 = NewTable(
table3ColInt,
table3StrCol)
func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) {
func assertSerialize(t *testing.T, serializer jet.Serializer, query string, args ...interface{}) {
testutils.AssertSerialize(t, Dialect, serializer, query, args...)
}
func assertClauseSerialize(t *testing.T, clause jet.Clause, query string, args ...interface{}) {
testutils.AssertClauseSerialize(t, Dialect, clause, query, args...)
}
func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) {
testutils.AssertClauseSerializeErr(t, Dialect, clause, errString)
func assertSerializeErr(t *testing.T, serializer jet.Serializer, errString string) {
testutils.AssertSerializeErr(t, Dialect, serializer, errString)
}
func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) {
@ -88,5 +90,6 @@ func assertProjectionSerialize(t *testing.T, projection jet.Projection, query st
}
var assertStatementSql = testutils.AssertStatementSql
var assertDebugStatementSql = testutils.AssertDebugStatementSql
var assertStatementSqlErr = testutils.AssertStatementSqlErr
var assertPanicErr = testutils.AssertPanicErr

View file

@ -0,0 +1,26 @@
package postgres
import "github.com/go-jet/jet/internal/jet"
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
readableTableInterfaceImpl
jet.CommonTableExpression
}
// WITH function creates new WITH statement from list of common table expressions
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, cte...)
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
cte := CommonTableExpression{
readableTableInterfaceImpl: readableTableInterfaceImpl{},
CommonTableExpression: jet.CTE(name),
}
cte.parent = &cte
return cte
}

View file

@ -2,7 +2,7 @@ package internal
import (
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"time"
)
@ -10,138 +10,138 @@ import (
func TestNullByteArray(t *testing.T) {
var array NullByteArray
assert.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
assert.NoError(t, array.Scan([]byte("bytea")))
assert.Equal(t, array.Valid, true)
assert.Equal(t, string(array.ByteArray), string([]byte("bytea")))
require.NoError(t, array.Scan([]byte("bytea")))
require.Equal(t, array.Valid, true)
require.Equal(t, string(array.ByteArray), string([]byte("bytea")))
assert.Error(t, array.Scan(12), "can't scan []byte from 12")
require.Error(t, array.Scan(12), "can't scan []byte from 12")
}
func TestNullTime(t *testing.T) {
var array NullTime
assert.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
time := time.Now()
assert.NoError(t, array.Scan(time))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan(time))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, time)
require.Equal(t, value, time)
assert.NoError(t, array.Scan([]byte("13:10:11")))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan([]byte("13:10:11")))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
assert.NoError(t, array.Scan("13:10:11"))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan("13:10:11"))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
assert.Error(t, array.Scan(12), "can't scan time.Time from 12")
require.Error(t, array.Scan(12), "can't scan time.Time from 12")
}
func TestNullInt8(t *testing.T) {
var array NullInt8
assert.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
assert.NoError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan(int64(11)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, int8(11))
require.Equal(t, value, int8(11))
assert.Error(t, array.Scan("text"), "can't scan int8 from text")
require.Error(t, array.Scan("text"), "can't scan int8 from text")
}
func TestNullInt16(t *testing.T) {
var array NullInt16
assert.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
assert.NoError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan(int64(11)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, int16(11))
require.Equal(t, value, int16(11))
assert.NoError(t, array.Scan(int16(20)))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan(int16(20)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int16(20))
require.Equal(t, value, int16(20))
assert.NoError(t, array.Scan(int8(30)))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan(int8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int16(30))
require.Equal(t, value, int16(30))
assert.NoError(t, array.Scan(uint8(30)))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan(uint8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int16(30))
require.Equal(t, value, int16(30))
assert.Error(t, array.Scan("text"), "can't scan int16 from text")
require.Error(t, array.Scan("text"), "can't scan int16 from text")
}
func TestNullInt32(t *testing.T) {
var array NullInt32
assert.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
assert.NoError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan(int64(11)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, int32(11))
require.Equal(t, value, int32(11))
assert.NoError(t, array.Scan(int32(32)))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan(int32(32)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(32))
require.Equal(t, value, int32(32))
assert.NoError(t, array.Scan(int16(20)))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan(int16(20)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(20))
require.Equal(t, value, int32(20))
assert.NoError(t, array.Scan(uint16(16)))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan(uint16(16)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(16))
require.Equal(t, value, int32(16))
assert.NoError(t, array.Scan(int8(30)))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan(int8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(30))
require.Equal(t, value, int32(30))
assert.NoError(t, array.Scan(uint8(30)))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan(uint8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, int32(30))
require.Equal(t, value, int32(30))
assert.Error(t, array.Scan("text"), "can't scan int32 from text")
require.Error(t, array.Scan("text"), "can't scan int32 from text")
}
func TestNullFloat32(t *testing.T) {
var array NullFloat32
assert.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
assert.NoError(t, array.Scan(float64(64)))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan(float64(64)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
assert.Equal(t, value, float32(64))
require.Equal(t, value, float32(64))
assert.NoError(t, array.Scan(float32(32)))
assert.Equal(t, array.Valid, true)
require.NoError(t, array.Scan(float32(32)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
assert.Equal(t, value, float32(32))
require.Equal(t, value, float32(32))
assert.Error(t, array.Scan(12), "can't scan float32 from 12")
require.Error(t, array.Scan(12), "can't scan float32 from 12")
}

View file

@ -2,36 +2,36 @@ package qrm
import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"reflect"
"testing"
"time"
)
func TestIsSimpleModelType(t *testing.T) {
assert.True(t, isSimpleModelType(reflect.TypeOf(int8(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(int16(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(int32(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(int64(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(uint8(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(uint16(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(uint32(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(uint64(11))))
require.True(t, isSimpleModelType(reflect.TypeOf(int8(11))))
require.True(t, isSimpleModelType(reflect.TypeOf(int16(11))))
require.True(t, isSimpleModelType(reflect.TypeOf(int32(11))))
require.True(t, isSimpleModelType(reflect.TypeOf(int64(11))))
require.True(t, isSimpleModelType(reflect.TypeOf(uint8(11))))
require.True(t, isSimpleModelType(reflect.TypeOf(uint16(11))))
require.True(t, isSimpleModelType(reflect.TypeOf(uint32(11))))
require.True(t, isSimpleModelType(reflect.TypeOf(uint64(11))))
assert.True(t, isSimpleModelType(reflect.TypeOf(float32(123.46))))
assert.True(t, isSimpleModelType(reflect.TypeOf(float64(123.46))))
require.True(t, isSimpleModelType(reflect.TypeOf(float32(123.46))))
require.True(t, isSimpleModelType(reflect.TypeOf(float64(123.46))))
assert.True(t, isSimpleModelType(reflect.TypeOf([]byte("Text"))))
assert.True(t, isSimpleModelType(reflect.TypeOf(time.Now())))
assert.True(t, isSimpleModelType(reflect.TypeOf(uuid.New())))
require.True(t, isSimpleModelType(reflect.TypeOf([]byte("Text"))))
require.True(t, isSimpleModelType(reflect.TypeOf(time.Now())))
require.True(t, isSimpleModelType(reflect.TypeOf(uuid.New())))
complexModelType := struct {
Field1 string
Field2 string
}{}
assert.Equal(t, isSimpleModelType(reflect.TypeOf(complexModelType)), false)
assert.Equal(t, isSimpleModelType(reflect.TypeOf(&complexModelType)), false)
assert.Equal(t, isSimpleModelType(reflect.TypeOf([]string{"str"})), false)
assert.Equal(t, isSimpleModelType(reflect.TypeOf([]int{1, 2})), false)
require.Equal(t, isSimpleModelType(reflect.TypeOf(complexModelType)), false)
require.Equal(t, isSimpleModelType(reflect.TypeOf(&complexModelType)), false)
require.Equal(t, isSimpleModelType(reflect.TypeOf([]string{"str"})), false)
require.Equal(t, isSimpleModelType(reflect.TypeOf([]int{1, 2})), false)
}

View file

@ -1,6 +1,9 @@
package mysql
import (
"fmt"
"github.com/stretchr/testify/require"
"strings"
"testing"
"time"
@ -13,8 +16,6 @@ import (
"github.com/go-jet/jet/tests/testdata/results/common"
. "github.com/go-jet/jet/mysql"
"github.com/stretchr/testify/assert"
)
func TestAllTypes(t *testing.T) {
@ -26,9 +27,9 @@ func TestAllTypes(t *testing.T) {
LIMIT(2).
Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, len(dest), 2)
require.Equal(t, len(dest), 2)
if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert
return
@ -45,8 +46,8 @@ func TestAllTypesViewSelect(t *testing.T) {
dest := []AllTypesView{}
err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest)
assert.NoError(t, err)
assert.Equal(t, len(dest), 2)
require.NoError(t, err)
require.Equal(t, len(dest), 2)
if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert
return
@ -74,11 +75,12 @@ func TestUUID(t *testing.T) {
err := query.Query(db, &dest)
assert.NoError(t, err)
assert.True(t, dest.StrUUID != nil)
assert.True(t, dest.UUID.String() != uuid.UUID{}.String())
assert.True(t, dest.StrUUID.String() != uuid.UUID{}.String())
assert.Equal(t, dest.StrUUID.String(), dest.BinUUID.String())
require.NoError(t, err)
require.True(t, dest.StrUUID != nil)
require.True(t, dest.UUID.String() != uuid.UUID{}.String())
require.True(t, dest.StrUUID.String() != uuid.UUID{}.String())
require.Equal(t, dest.StrUUID.String(), dest.BinUUID.String())
requireLogged(t, query)
}
func TestExpressionOperators(t *testing.T) {
@ -95,23 +97,23 @@ func TestExpressionOperators(t *testing.T) {
//fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, `
SELECT all_types.integer IS NULL AS "result.is_null",
testutils.AssertStatementSql(t, query, strings.Replace(`
SELECT all_types.'integer' IS NULL AS "result.is_null",
all_types.date_ptr IS NOT NULL AS "result.is_not_null",
(all_types.small_int_ptr IN (?, ?)) AS "result.in",
(all_types.small_int_ptr IN ((
SELECT all_types.integer AS "all_types.integer"
SELECT all_types.'integer' AS "all_types.integer"
FROM test_sample.all_types
))) AS "result.in_select",
(all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in",
(all_types.small_int_ptr NOT IN ((
SELECT all_types.integer AS "all_types.integer"
SELECT all_types.'integer' AS "all_types.integer"
FROM test_sample.all_types
))) AS "result.not_in_select",
DATABASE()
FROM test_sample.all_types
LIMIT ?;
`, int64(11), int64(22), int64(11), int64(22), int64(2))
`, "'", "`", -1), int64(11), int64(22), int64(11), int64(22), int64(2))
var dest []struct {
common.ExpressionTestResult `alias:"result.*"`
@ -119,7 +121,7 @@ LIMIT ?;
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
//testutils.PrintJson(dest)
@ -210,7 +212,7 @@ FROM test_sample.all_types;
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json")
}
@ -261,45 +263,47 @@ func TestFloatOperators(t *testing.T) {
queryStr, _ := query.Sql()
assert.Equal(t, queryStr, `
SELECT (all_types.numeric = all_types.numeric) AS "eq1",
(all_types.decimal = ?) AS "eq2",
(all_types.real = ?) AS "eq3",
(NOT(all_types.numeric <=> all_types.numeric)) AS "distinct1",
(NOT(all_types.decimal <=> ?)) AS "distinct2",
(NOT(all_types.real <=> ?)) AS "distinct3",
(all_types.numeric <=> all_types.numeric) AS "not_distinct1",
(all_types.decimal <=> ?) AS "not_distinct2",
(all_types.real <=> ?) AS "not_distinct3",
(all_types.numeric < ?) AS "lt1",
(all_types.numeric < ?) AS "lt2",
(all_types.numeric > ?) AS "gt1",
(all_types.numeric > ?) AS "gt2",
TRUNCATE((all_types.decimal + all_types.decimal), ?) AS "add1",
TRUNCATE((all_types.decimal + ?), ?) AS "add2",
TRUNCATE((all_types.decimal - all_types.decimal_ptr), ?) AS "sub1",
TRUNCATE((all_types.decimal - ?), ?) AS "sub2",
TRUNCATE((all_types.decimal * all_types.decimal_ptr), ?) AS "mul1",
TRUNCATE((all_types.decimal * ?), ?) AS "mul2",
TRUNCATE((all_types.decimal / all_types.decimal_ptr), ?) AS "div1",
TRUNCATE((all_types.decimal / ?), ?) AS "div2",
TRUNCATE((all_types.decimal % all_types.decimal_ptr), ?) AS "mod1",
TRUNCATE((all_types.decimal % ?), ?) AS "mod2",
TRUNCATE(POW(all_types.decimal, all_types.decimal_ptr), ?) AS "pow1",
TRUNCATE(POW(all_types.decimal, ?), ?) AS "pow2",
TRUNCATE(ABS(all_types.decimal), ?) AS "abs",
TRUNCATE(POWER(all_types.decimal, ?), ?) AS "power",
TRUNCATE(SQRT(all_types.decimal), ?) AS "sqrt",
TRUNCATE(POWER(all_types.decimal, (? / ?)), ?) AS "cbrt",
CEIL(all_types.real) AS "ceil",
FLOOR(all_types.real) AS "floor",
ROUND(all_types.decimal) AS "round1",
ROUND(all_types.decimal, ?) AS "round2",
SIGN(all_types.real) AS "sign",
TRUNCATE(all_types.decimal, ?) AS "trunc"
//fmt.Println(queryStr)
require.Equal(t, queryStr, strings.Replace(`
SELECT (all_types.'numeric' = all_types.'numeric') AS "eq1",
(all_types.'decimal' = ?) AS "eq2",
(all_types.'real' = ?) AS "eq3",
(NOT(all_types.'numeric' <=> all_types.'numeric')) AS "distinct1",
(NOT(all_types.'decimal' <=> ?)) AS "distinct2",
(NOT(all_types.'real' <=> ?)) AS "distinct3",
(all_types.'numeric' <=> all_types.'numeric') AS "not_distinct1",
(all_types.'decimal' <=> ?) AS "not_distinct2",
(all_types.'real' <=> ?) AS "not_distinct3",
(all_types.'numeric' < ?) AS "lt1",
(all_types.'numeric' < ?) AS "lt2",
(all_types.'numeric' > ?) AS "gt1",
(all_types.'numeric' > ?) AS "gt2",
TRUNCATE((all_types.'decimal' + all_types.'decimal'), ?) AS "add1",
TRUNCATE((all_types.'decimal' + ?), ?) AS "add2",
TRUNCATE((all_types.'decimal' - all_types.decimal_ptr), ?) AS "sub1",
TRUNCATE((all_types.'decimal' - ?), ?) AS "sub2",
TRUNCATE((all_types.'decimal' * all_types.decimal_ptr), ?) AS "mul1",
TRUNCATE((all_types.'decimal' * ?), ?) AS "mul2",
TRUNCATE((all_types.'decimal' / all_types.decimal_ptr), ?) AS "div1",
TRUNCATE((all_types.'decimal' / ?), ?) AS "div2",
TRUNCATE((all_types.'decimal' % all_types.decimal_ptr), ?) AS "mod1",
TRUNCATE((all_types.'decimal' % ?), ?) AS "mod2",
TRUNCATE(POW(all_types.'decimal', all_types.decimal_ptr), ?) AS "pow1",
TRUNCATE(POW(all_types.'decimal', ?), ?) AS "pow2",
TRUNCATE(ABS(all_types.'decimal'), ?) AS "abs",
TRUNCATE(POWER(all_types.'decimal', ?), ?) AS "power",
TRUNCATE(SQRT(all_types.'decimal'), ?) AS "sqrt",
TRUNCATE(POWER(all_types.'decimal', (? / ?)), ?) AS "cbrt",
CEIL(all_types.'real') AS "ceil",
FLOOR(all_types.'real') AS "floor",
ROUND(all_types.'decimal') AS "round1",
ROUND(all_types.'decimal', ?) AS "round2",
SIGN(all_types.'real') AS "sign",
TRUNCATE(all_types.'decimal', ?) AS "trunc"
FROM test_sample.all_types
LIMIT ?;
`)
`, "'", "`", -1))
var dest []struct {
common.FloatExpressionTestResult `alias:"."`
@ -307,7 +311,7 @@ LIMIT ?;
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
testutils.AssertJSONFile(t, dest, "./testdata/results/common/float_operators.json")
}
@ -444,7 +448,7 @@ LIMIT ?;
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
//testutils.PrintJson(dest)
@ -516,7 +520,7 @@ func TestStringOperators(t *testing.T) {
dest := []struct{}{}
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
}
var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
@ -568,7 +572,7 @@ func TestTimeExpressions(t *testing.T) {
//fmt.Println(query.DebugSql())
testutils.AssertDebugStatementSql(t, query, `
testutils.AssertDebugStatementSql(t, query, strings.Replace(`
SELECT CAST('20:34:58' AS TIME),
all_types.time = all_types.time,
all_types.time = CAST('23:06:06' AS TIME),
@ -589,7 +593,7 @@ SELECT CAST('20:34:58' AS TIME),
all_types.time >= all_types.time,
all_types.time >= CAST('14:26:36' AS TIME),
all_types.time + INTERVAL 10 MINUTE,
all_types.time + INTERVAL all_types.integer MINUTE,
all_types.time + INTERVAL all_types.''integer'' MINUTE,
all_types.time + INTERVAL 3 HOUR,
all_types.time - INTERVAL 20 MINUTE,
all_types.time - INTERVAL all_types.small_int MINUTE,
@ -598,13 +602,13 @@ SELECT CAST('20:34:58' AS TIME),
CURRENT_TIME,
CURRENT_TIME(3)
FROM test_sample.all_types;
`, "20:34:58", "23:06:06", "22:06:06.011", "21:06:06.011111", "20:16:06",
`, "''", "`", -1), "20:34:58", "23:06:06", "22:06:06.011", "21:06:06.011111", "20:16:06",
"19:26:06", "18:36:06", "17:46:06", "16:56:56", "15:16:46", "14:26:36")
dest := []struct{}{}
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestDateExpressions(t *testing.T) {
@ -648,25 +652,25 @@ func TestDateExpressions(t *testing.T) {
//fmt.Println(query.DebugSql())
testutils.AssertStatementSql(t, query, `
SELECT CAST(? AS DATE),
testutils.AssertDebugStatementSql(t, query, `
SELECT CAST('2009-11-17' AS DATE),
all_types.date = all_types.date,
all_types.date = CAST(? AS DATE),
all_types.date = CAST('2019-06-06' AS DATE),
all_types.date_ptr != all_types.date,
all_types.date_ptr != CAST(? AS DATE),
all_types.date_ptr != CAST('2019-01-06' AS DATE),
NOT(all_types.date <=> all_types.date),
NOT(all_types.date <=> CAST(? AS DATE)),
NOT(all_types.date <=> CAST('2019-02-06' AS DATE)),
all_types.date <=> all_types.date,
all_types.date <=> CAST(? AS DATE),
all_types.date <=> CAST('2019-03-06' AS DATE),
all_types.date < all_types.date,
all_types.date < CAST(? AS DATE),
all_types.date < CAST('2019-04-06' AS DATE),
all_types.date <= all_types.date,
all_types.date <= CAST(? AS DATE),
all_types.date <= CAST('2019-05-05' AS DATE),
all_types.date > all_types.date,
all_types.date > CAST(? AS DATE),
all_types.date > CAST('2019-01-04' AS DATE),
all_types.date >= all_types.date,
all_types.date >= CAST(? AS DATE),
all_types.date + INTERVAL ? MINUTE_MICROSECOND,
all_types.date >= CAST('2019-02-03' AS DATE),
all_types.date + INTERVAL '10:20.000100' MINUTE_MICROSECOND,
all_types.date + INTERVAL all_types.big_int MINUTE,
all_types.date + INTERVAL 15 HOUR,
all_types.date - INTERVAL 20 MINUTE,
@ -679,7 +683,7 @@ FROM test_sample.all_types;
dest := []struct{}{}
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestDateTimeExpressions(t *testing.T) {
@ -756,7 +760,7 @@ FROM test_sample.all_types;
dest := []struct{}{}
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestTimestampExpressions(t *testing.T) {
@ -832,13 +836,13 @@ FROM test_sample.all_types;
dest := []struct{}{}
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestTimeLiterals(t *testing.T) {
loc, err := time.LoadLocation("Europe/Berlin")
assert.NoError(t, err)
require.NoError(t, err)
var timeT = time.Date(2009, 11, 17, 20, 34, 58, 351387237, loc)
@ -877,7 +881,7 @@ LIMIT ?;
}
err = query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
//testutils.PrintJson(dest)
@ -960,7 +964,139 @@ func TestINTERVAL(t *testing.T) {
//fmt.Println(query.DebugSql())
err := query.Query(db, &struct{}{})
assert.NoError(t, err)
require.NoError(t, err)
}
func TestAllTypesInsert(t *testing.T) {
tx, err := db.Begin()
require.NoError(t, err)
stmt := AllTypes.INSERT(AllTypes.AllColumns).
MODEL(toInsert)
fmt.Println(stmt.DebugSql())
testutils.AssertExec(t, stmt, tx, 1)
var dest model.AllTypes
err = AllTypes.SELECT(AllTypes.AllColumns).
WHERE(AllTypes.BigInt.EQ(Int(toInsert.BigInt))).
Query(tx, &dest)
require.NoError(t, err)
require.Equal(t, toInsert.TinyInt, dest.TinyInt)
err = tx.Rollback()
require.NoError(t, err)
}
func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) {
tx, err := db.Begin()
require.NoError(t, err)
toInsert := model.AllTypes{
Boolean: true,
Integer: 124,
Float: 45.67,
Blob: []byte("blob"),
Text: "text",
JSON: "{}",
Time: time.Now(),
Timestamp: time.Now(),
Date: time.Now(),
}
stmt := AllTypes.INSERT(
AllTypes.Boolean,
AllTypes.Integer,
AllTypes.Float,
AllTypes.Blob,
AllTypes.Text,
AllTypes.JSON,
AllTypes.Time,
AllTypes.Timestamp,
AllTypes.Date,
).
MODEL(toInsert).
ON_DUPLICATE_KEY_UPDATE(
AllTypes.Boolean.SET(Bool(false)),
AllTypes.Integer.SET(Int(4)),
AllTypes.Float.SET(Float(0.67)),
AllTypes.Text.SET(String("new text")),
AllTypes.Time.SET(TimeT(time.Now())),
AllTypes.Timestamp.SET(TimestampT(time.Now())),
AllTypes.Date.SET(DateT(time.Now())),
)
fmt.Println(stmt.DebugSql())
_, err = stmt.Exec(tx)
require.NoError(t, err)
err = tx.Rollback()
require.NoError(t, err)
}
var toInsert = model.AllTypes{
Boolean: false,
BooleanPtr: testutils.BoolPtr(true),
TinyInt: 1,
UTinyInt: 2,
SmallInt: 3,
USmallInt: 4,
MediumInt: 5,
UMediumInt: 6,
Integer: 7,
UInteger: 8,
BigInt: 9,
UBigInt: 1122334455,
TinyIntPtr: testutils.Int8Ptr(11),
UTinyIntPtr: testutils.UInt8Ptr(22),
SmallIntPtr: testutils.Int16Ptr(33),
USmallIntPtr: testutils.UInt16Ptr(44),
MediumIntPtr: testutils.Int32Ptr(55),
UMediumIntPtr: testutils.UInt32Ptr(66),
IntegerPtr: testutils.Int32Ptr(77),
UIntegerPtr: testutils.UInt32Ptr(88),
BigIntPtr: testutils.Int64Ptr(99),
UBigIntPtr: testutils.UInt64Ptr(111),
Decimal: 11.22,
DecimalPtr: testutils.Float64Ptr(33.44),
Numeric: 55.66,
NumericPtr: testutils.Float64Ptr(77.88),
Float: 99.00,
FloatPtr: testutils.Float64Ptr(11.22),
Double: 33.44,
DoublePtr: testutils.Float64Ptr(55.66),
Real: 77.88,
RealPtr: testutils.Float64Ptr(99.00),
Bit: "1",
BitPtr: testutils.StringPtr("0"),
Time: time.Date(0, 0, 0, 10, 11, 12, 100, &time.Location{}),
TimePtr: testutils.TimePtr(time.Date(0, 0, 0, 10, 11, 12, 100, time.UTC)),
Date: time.Now(),
DatePtr: testutils.TimePtr(time.Now()),
DateTime: time.Now(),
DateTimePtr: testutils.TimePtr(time.Now()),
Timestamp: time.Now(),
//TimestampPtr: testutils.TimePtr(time.Now()), // TODO: build fails for MariaDB
Year: 2000,
YearPtr: testutils.Int16Ptr(2001),
Char: "abcd",
CharPtr: testutils.StringPtr("absd"),
VarChar: "abcd",
VarCharPtr: testutils.StringPtr("absd"),
Binary: []byte("1010"),
BinaryPtr: testutils.ByteArrayPtr([]byte("100001")),
VarBinary: []byte("1010"),
VarBinaryPtr: testutils.ByteArrayPtr([]byte("100001")),
Blob: []byte("large file"),
BlobPtr: testutils.ByteArrayPtr([]byte("very large file")),
Text: "some text",
TextPtr: testutils.StringPtr("text"),
Enum: model.AllTypesEnum_Value1,
JSON: "{}",
JSONPtr: testutils.StringPtr(`{"a": 1}`),
}
var allTypesJson = `
@ -1100,28 +1236,26 @@ func TestReservedWord(t *testing.T) {
stmt := SELECT(User.AllColumns).
FROM(User)
// NOTE: A word that follows a period in a qualified name must be an identifier, so it
// need not be quoted even if it is reserved
testutils.AssertDebugStatementSql(t, stmt, `
SELECT user.column AS "user.column",
user.use AS "user.use",
testutils.AssertDebugStatementSql(t, stmt, strings.Replace(`
SELECT user.''column'' AS "user.column",
user.''use'' AS "user.use",
user.ceil AS "user.ceil",
user.commit AS "user.commit",
user.create AS "user.create",
user.default AS "user.default",
user.desc AS "user.desc",
user.empty AS "user.empty",
user.float AS "user.float",
user.join AS "user.join",
user.like AS "user.like",
user.''create'' AS "user.create",
user.''default'' AS "user.default",
user.''desc'' AS "user.desc",
user.''empty'' AS "user.empty",
user.''float'' AS "user.float",
user.''join'' AS "user.join",
user.''like'' AS "user.like",
user.max AS "user.max",
user.rank AS "user.rank"
user.''rank'' AS "user.rank"
FROM test_sample.user;
`)
`, "''", "`", -1))
var dest []model.User
err := stmt.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
testutils.PrintJson(dest)

View file

@ -4,7 +4,7 @@ import (
"github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/mysql"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"time"
)
@ -55,7 +55,7 @@ FROM test_sample.all_types;
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
testutils.AssertDeepEqual(t, dest, Result{
As1: "test",
@ -68,4 +68,6 @@ FROM test_sample.all_types;
Unsigned: 15,
Binary: "Some text",
})
requireLogged(t, query)
}

View file

@ -6,7 +6,7 @@ import (
. "github.com/go-jet/jet/mysql"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"time"
)
@ -24,6 +24,7 @@ WHERE link.name IN ('Gmail', 'Outlook');
testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook")
testutils.AssertExec(t, deleteStmt, db, 2)
requireLogged(t, deleteStmt)
}
func TestDeleteWithWhereOrderByLimit(t *testing.T) {
@ -43,6 +44,7 @@ LIMIT 1;
testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook", int64(1))
testutils.AssertExec(t, deleteStmt, db, 1)
requireLogged(t, deleteStmt)
}
func TestDeleteQueryContext(t *testing.T) {
@ -60,7 +62,8 @@ func TestDeleteQueryContext(t *testing.T) {
dest := []model.Link{}
err := deleteStmt.QueryContext(ctx, db, &dest)
assert.Error(t, err, "context deadline exceeded")
require.Error(t, err, "context deadline exceeded")
requireLogged(t, deleteStmt)
}
func TestDeleteExecContext(t *testing.T) {
@ -77,7 +80,7 @@ func TestDeleteExecContext(t *testing.T) {
_, err := deleteStmt.ExecContext(ctx, db)
assert.Error(t, err, "context deadline exceeded")
require.Error(t, err, "context deadline exceeded")
}
func initForDeleteTest(t *testing.T) {

View file

@ -4,7 +4,7 @@ import (
"github.com/go-jet/jet/generator/mysql"
"github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/tests/dbconfig"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io/ioutil"
"os"
"os/exec"
@ -25,23 +25,23 @@ func TestGenerator(t *testing.T) {
DBName: "dvds",
})
assert.NoError(t, err)
require.NoError(t, err)
assertGeneratedFiles(t)
}
err := os.RemoveAll(genTestDirRoot)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestCmdGenerator(t *testing.T) {
goInstallJet := exec.Command("sh", "-c", "go install github.com/go-jet/jet/cmd/jet")
goInstallJet.Stderr = os.Stderr
err := goInstallJet.Run()
assert.NoError(t, err)
require.NoError(t, err)
err = os.RemoveAll(genTestDir3)
assert.NoError(t, err)
require.NoError(t, err)
cmd := exec.Command("jet", "-source=MySQL", "-dbname=dvds", "-host=localhost", "-port=3306",
"-user=jet", "-password=jet", "-path="+genTestDir3)
@ -50,44 +50,44 @@ func TestCmdGenerator(t *testing.T) {
cmd.Stdout = os.Stdout
err = cmd.Run()
assert.NoError(t, err)
require.NoError(t, err)
assertGeneratedFiles(t)
err = os.RemoveAll(genTestDirRoot)
assert.NoError(t, err)
require.NoError(t, err)
}
func assertGeneratedFiles(t *testing.T) {
// Table SQL Builder files
tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table")
assert.NoError(t, err)
require.NoError(t, err)
testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go")
testutils.AssertFileContent(t, genTestDir3+"/dvds/table/actor.go", "\npackage table", actorSQLBuilderFile)
testutils.AssertFileContent(t, genTestDir3+"/dvds/table/actor.go", actorSQLBuilderFile)
// View SQL Builder files
viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view")
assert.NoError(t, err)
require.NoError(t, err)
testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go",
"sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go")
testutils.AssertFileContent(t, genTestDir3+"/dvds/view/actor_info.go", "\npackage view", actorInfoSQLBuilerFile)
testutils.AssertFileContent(t, genTestDir3+"/dvds/view/actor_info.go", actorInfoSQLBuilderFile)
// Enums SQL Builder files
enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum")
assert.NoError(t, err)
require.NoError(t, err)
testutils.AssertFileNamesEqual(t, enumFiles, "film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go")
testutils.AssertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", "\npackage enum", mpaaRatingEnumFile)
testutils.AssertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", mpaaRatingEnumFile)
// Model files
modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model")
assert.NoError(t, err)
require.NoError(t, err)
testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go",
@ -96,10 +96,17 @@ func assertGeneratedFiles(t *testing.T) {
"actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go",
"customer_list.go", "sales_by_store.go", "staff_list.go")
testutils.AssertFileContent(t, genTestDir3+"/dvds/model/actor.go", "\npackage model", actorModelFile)
testutils.AssertFileContent(t, genTestDir3+"/dvds/model/actor.go", actorModelFile)
}
var mpaaRatingEnumFile = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package enum
import "github.com/go-jet/jet/mysql"
@ -120,6 +127,13 @@ var FilmRating = &struct {
`
var actorSQLBuilderFile = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
@ -141,25 +155,25 @@ type ActorTable struct {
MutableColumns mysql.ColumnList
}
// creates new ActorTable with assigned alias
func (a *ActorTable) AS(alias string) *ActorTable {
// AS creates new ActorTable with assigned alias
func (a *ActorTable) AS(alias string) ActorTable {
aliasTable := newActorTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newActorTable() *ActorTable {
func newActorTable() ActorTable {
var (
ActorIDColumn = mysql.IntegerColumn("actor_id")
FirstNameColumn = mysql.StringColumn("first_name")
LastNameColumn = mysql.StringColumn("last_name")
LastUpdateColumn = mysql.TimestampColumn("last_update")
allColumns = mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn}
mutableColumns = mysql.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn}
)
return &ActorTable{
Table: mysql.NewTable("dvds", "actor", ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn),
return ActorTable{
Table: mysql.NewTable("dvds", "actor", allColumns...),
//Columns
ActorID: ActorIDColumn,
@ -167,13 +181,20 @@ func newActorTable() *ActorTable {
LastName: LastNameColumn,
LastUpdate: LastUpdateColumn,
AllColumns: mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn},
MutableColumns: mysql.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn},
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`
var actorModelFile = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
import (
@ -188,7 +209,14 @@ type Actor struct {
}
`
var actorInfoSQLBuilerFile = `
var actorInfoSQLBuilderFile = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package view
import (
@ -210,25 +238,25 @@ type ActorInfoTable struct {
MutableColumns mysql.ColumnList
}
// creates new ActorInfoTable with assigned alias
func (a *ActorInfoTable) AS(alias string) *ActorInfoTable {
// AS creates new ActorInfoTable with assigned alias
func (a *ActorInfoTable) AS(alias string) ActorInfoTable {
aliasTable := newActorInfoTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newActorInfoTable() *ActorInfoTable {
func newActorInfoTable() ActorInfoTable {
var (
ActorIDColumn = mysql.IntegerColumn("actor_id")
FirstNameColumn = mysql.StringColumn("first_name")
LastNameColumn = mysql.StringColumn("last_name")
FilmInfoColumn = mysql.StringColumn("film_info")
allColumns = mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}
mutableColumns = mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn}
)
return &ActorInfoTable{
Table: mysql.NewTable("dvds", "actor_info", ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn),
return ActorInfoTable{
Table: mysql.NewTable("dvds", "actor_info", allColumns...),
//Columns
ActorID: ActorIDColumn,
@ -236,8 +264,8 @@ func newActorInfoTable() *ActorInfoTable {
LastName: LastNameColumn,
FilmInfo: FilmInfoColumn,
AllColumns: mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn},
MutableColumns: mysql.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, FilmInfoColumn},
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`

View file

@ -6,7 +6,8 @@ import (
. "github.com/go-jet/jet/mysql"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"math/rand"
"testing"
"time"
)
@ -15,10 +16,10 @@ func TestInsertValues(t *testing.T) {
cleanUpLinkTable(t)
var expectedSQL = `
INSERT INTO test_sample.link (id, url, name, description) VALUES
(100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
(101, 'http://www.google.com', 'Google', DEFAULT),
(102, 'http://www.yahoo.com', 'Yahoo', NULL);
INSERT INTO test_sample.link (id, url, name, description)
VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
(101, 'http://www.google.com', 'Google', DEFAULT),
(102, 'http://www.yahoo.com', 'Yahoo', NULL);
`
insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
@ -32,7 +33,8 @@ INSERT INTO test_sample.link (id, url, name, description) VALUES
102, "http://www.yahoo.com", "Yahoo", nil)
_, err := insertQuery.Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
requireLogged(t, insertQuery)
insertedLinks := []model.Link{}
@ -41,8 +43,8 @@ INSERT INTO test_sample.link (id, url, name, description) VALUES
ORDER_BY(Link.ID).
Query(db, &insertedLinks)
assert.NoError(t, err)
assert.Equal(t, len(insertedLinks), 3)
require.NoError(t, err)
require.Equal(t, len(insertedLinks), 3)
testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial)
@ -69,8 +71,8 @@ func TestInsertEmptyColumnList(t *testing.T) {
cleanUpLinkTable(t)
expectedSQL := `
INSERT INTO test_sample.link VALUES
(100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT);
INSERT INTO test_sample.link
VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT);
`
stmt := Link.INSERT().
@ -80,7 +82,8 @@ INSERT INTO test_sample.link VALUES
100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial")
_, err := stmt.Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
requireLogged(t, stmt)
insertedLinks := []model.Link{}
@ -89,16 +92,16 @@ INSERT INTO test_sample.link VALUES
ORDER_BY(Link.ID).
Query(db, &insertedLinks)
assert.NoError(t, err)
assert.Equal(t, len(insertedLinks), 1)
require.NoError(t, err)
require.Equal(t, len(insertedLinks), 1)
testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial)
}
func TestInsertModelObject(t *testing.T) {
cleanUpLinkTable(t)
var expectedSQL = `
INSERT INTO test_sample.link (url, name) VALUES
('http://www.duckduckgo.com', 'Duck Duck go');
INSERT INTO test_sample.link (url, name)
VALUES ('http://www.duckduckgo.com', 'Duck Duck go');
`
linkData := model.Link{
@ -113,14 +116,14 @@ INSERT INTO test_sample.link (url, name) VALUES
testutils.AssertDebugStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go")
_, err := query.Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestInsertModelObjectEmptyColumnList(t *testing.T) {
cleanUpLinkTable(t)
var expectedSQL = `
INSERT INTO test_sample.link VALUES
(1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
INSERT INTO test_sample.link
VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
`
linkData := model.Link{
@ -136,15 +139,15 @@ INSERT INTO test_sample.link VALUES
testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil)
_, err := query.Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestInsertModelsObject(t *testing.T) {
expectedSQL := `
INSERT INTO test_sample.link (url, name) VALUES
('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
('http://www.google.com', 'Google'),
('http://www.yahoo.com', 'Yahoo');
INSERT INTO test_sample.link (url, name)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
('http://www.google.com', 'Google'),
('http://www.yahoo.com', 'Yahoo');
`
tutorial := model.Link{
@ -172,16 +175,16 @@ INSERT INTO test_sample.link (url, name) VALUES
"http://www.yahoo.com", "Yahoo")
_, err := query.Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestInsertUsingMutableColumns(t *testing.T) {
var expectedSQL = `
INSERT INTO test_sample.link (url, name, description) VALUES
('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
('http://www.google.com', 'Google', NULL),
('http://www.google.com', 'Google', NULL),
('http://www.yahoo.com', 'Yahoo', NULL);
INSERT INTO test_sample.link (url, name, description)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
('http://www.google.com', 'Google', NULL),
('http://www.google.com', 'Google', NULL),
('http://www.yahoo.com', 'Yahoo', NULL);
`
google := model.Link{
@ -207,14 +210,14 @@ INSERT INTO test_sample.link (url, name, description) VALUES
"http://www.yahoo.com", "Yahoo", nil)
_, err := stmt.Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestInsertQuery(t *testing.T) {
_, err := Link.DELETE().
WHERE(Link.ID.NOT_EQ(Int(1)).AND(Link.Name.EQ(String("Youtube")))).
Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
var expectedSQL = `
INSERT INTO test_sample.link (url, name) (
@ -236,7 +239,7 @@ INSERT INTO test_sample.link (url, name) (
testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1))
_, err = query.Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
youtubeLinks := []model.Link{}
err = Link.
@ -244,8 +247,48 @@ INSERT INTO test_sample.link (url, name) (
WHERE(Link.Name.EQ(String("Youtube"))).
Query(db, &youtubeLinks)
assert.NoError(t, err)
assert.Equal(t, len(youtubeLinks), 2)
require.NoError(t, err)
require.Equal(t, len(youtubeLinks), 2)
}
func TestInsertOnDuplicateKey(t *testing.T) {
randId := rand.Int31()
stmt := Link.INSERT().
VALUES(randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
ON_DUPLICATE_KEY_UPDATE(
Link.ID.SET(Link.ID.ADD(Int(11))),
Link.Name.SET(String("PostgreSQL Tutorial 2")),
)
testutils.AssertStatementSql(t, stmt, `
INSERT INTO test_sample.link
VALUES (?, ?, ?, DEFAULT),
(?, ?, ?, DEFAULT)
ON DUPLICATE KEY UPDATE id = (id + ?),
name = ?;
`, randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
int64(11), "PostgreSQL Tutorial 2")
testutils.AssertExec(t, stmt, db, 3)
newLinks := []model.Link{}
err := SELECT(Link.AllColumns).
FROM(Link).
WHERE(Link.ID.EQ(Int(int64(randId)).ADD(Int(11)))).
Query(db, &newLinks)
require.NoError(t, err)
require.Len(t, newLinks, 1)
require.Equal(t, newLinks[0], model.Link{
ID: randId + 11,
URL: "http://www.postgresqltutorial.com",
Name: "PostgreSQL Tutorial 2",
Description: nil,
})
}
func TestInsertWithQueryContext(t *testing.T) {
@ -262,7 +305,7 @@ func TestInsertWithQueryContext(t *testing.T) {
dest := []model.Link{}
err := stmt.QueryContext(ctx, db, &dest)
assert.Error(t, err, "context deadline exceeded")
require.Error(t, err, "context deadline exceeded")
}
func TestInsertWithExecContext(t *testing.T) {
@ -278,10 +321,10 @@ func TestInsertWithExecContext(t *testing.T) {
_, err := stmt.ExecContext(ctx, db)
assert.Error(t, err, "context deadline exceeded")
require.Error(t, err, "context deadline exceeded")
}
func cleanUpLinkTable(t *testing.T) {
_, err := Link.DELETE().WHERE(Link.ID.GT(Int(1))).Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
}

View file

@ -4,7 +4,7 @@ import (
"github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/mysql"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
@ -16,7 +16,8 @@ LOCK TABLES dvds.customer READ;
`)
_, err := query.Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
requireLogged(t, query)
}
func TestLockWrite(t *testing.T) {
@ -27,7 +28,8 @@ LOCK TABLES dvds.customer WRITE;
`)
_, err := query.Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
requireLogged(t, query)
}
func TestUnlockTables(t *testing.T) {
@ -38,5 +40,6 @@ UNLOCK TABLES;
`)
_, err := query.Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
requireLogged(t, query)
}

View file

@ -1,9 +1,15 @@
package mysql
import (
"context"
"database/sql"
"flag"
jetmysql "github.com/go-jet/jet/mysql"
"github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/tests/dbconfig"
"github.com/stretchr/testify/require"
"math/rand"
"time"
_ "github.com/go-sql-driver/mysql"
@ -28,6 +34,7 @@ func sourceIsMariaDB() bool {
}
func TestMain(m *testing.M) {
rand.Seed(time.Now().Unix())
defer profile.Start().Stop()
var err error
@ -41,3 +48,21 @@ func TestMain(m *testing.M) {
os.Exit(ret)
}
var loggedSQL string
var loggedSQLArgs []interface{}
var loggedDebugSQL string
func init() {
jetmysql.SetLogger(func(ctx context.Context, statement jetmysql.PrintableStatement) {
loggedSQL, loggedSQLArgs = statement.Sql()
loggedDebugSQL = statement.DebugSql()
})
}
func requireLogged(t *testing.T, statement postgres.Statement) {
query, args := statement.Sql()
require.Equal(t, loggedSQL, query)
require.Equal(t, loggedSQLArgs, args)
require.Equal(t, loggedDebugSQL, statement.DebugSql())
}

View file

@ -7,7 +7,7 @@ import (
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table"
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/view"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
@ -30,9 +30,10 @@ WHERE actor.actor_id = ?;
actor := model.Actor{}
err := query.Query(db, &actor)
assert.NoError(t, err)
require.NoError(t, err)
testutils.AssertDeepEqual(t, actor, actor2)
requireLogged(t, query)
}
var actor2 = model.Actor{
@ -59,14 +60,15 @@ ORDER BY actor.actor_id;
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, len(dest), 200)
require.Equal(t, len(dest), 200)
testutils.AssertDeepEqual(t, dest[1], actor2)
//testutils.PrintJson(dest)
//testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/all_actors.json")
requireLogged(t, query)
}
func TestSelectGroupByHaving(t *testing.T) {
@ -136,14 +138,15 @@ ORDER BY payment.customer_id, SUM(payment.amount) ASC;
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
//testutils.PrintJson(dest)
assert.Equal(t, len(dest), 174)
require.Equal(t, len(dest), 174)
//testutils.SaveJsonFile(dest, "mysql/testdata/customer_payment_sum.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/customer_payment_sum.json")
requireLogged(t, query)
}
func TestSubQuery(t *testing.T) {
@ -176,7 +179,7 @@ func TestSubQuery(t *testing.T) {
}
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
//testutils.SaveJsonFile(dest, "mysql/testdata/r_rating_films.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/r_rating_films.json")
@ -229,7 +232,7 @@ LIMIT ?;
dest := []struct{}{}
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestSelectUNION(t *testing.T) {
@ -265,7 +268,7 @@ LIMIT ?;
dest := []struct{}{}
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestSelectUNION_ALL(t *testing.T) {
@ -308,7 +311,7 @@ OFFSET ?;
dest := []struct{}{}
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestJoinQueryStruct(t *testing.T) {
@ -406,10 +409,10 @@ LIMIT ?;
err := query.Query(db, &dest)
assert.NoError(t, err)
//assert.Equal(t, len(dest), 1)
//assert.Equal(t, len(dest[0].Films), 10)
//assert.Equal(t, len(dest[0].Films[0].Actors), 10)
require.NoError(t, err)
//require.Equal(t, len(dest), 1)
//require.Equal(t, len(dest[0].Films), 10)
//require.Equal(t, len(dest[0].Films[0].Actors), 10)
//testutils.SaveJsonFile(dest, "./mysql/testdata/lang_film_actor_inventory_rental.json")
@ -450,10 +453,10 @@ FOR`
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NoError(t, err)
require.NoError(t, err)
err = tx.Rollback()
assert.NoError(t, err)
require.NoError(t, err)
}
for lockType, lockTypeStr := range getRowLockTestData() {
@ -464,10 +467,10 @@ FOR`
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NoError(t, err)
require.NoError(t, err)
err = tx.Rollback()
assert.NoError(t, err)
require.NoError(t, err)
}
if sourceIsMariaDB() {
@ -482,10 +485,10 @@ FOR`
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NoError(t, err)
require.NoError(t, err)
err = tx.Rollback()
assert.NoError(t, err)
require.NoError(t, err)
}
}
@ -514,7 +517,7 @@ SELECT true,
dest := []struct{}{}
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestLockInShareMode(t *testing.T) {
@ -535,7 +538,7 @@ LOCK IN SHARE MODE;
dest := []struct{}{}
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestWindowFunction(t *testing.T) {
@ -612,7 +615,7 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date;
dest := []struct{}{}
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestWindowClause(t *testing.T) {
@ -649,7 +652,7 @@ ORDER BY payment.customer_id;
dest := []struct{}{}
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
}
func TestSimpleView(t *testing.T) {
@ -670,9 +673,9 @@ func TestSimpleView(t *testing.T) {
var dest []ActorInfo
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, len(dest), 10)
require.Equal(t, len(dest), 10)
testutils.AssertJSON(t, dest[1:2], `
[
{
@ -702,11 +705,11 @@ func TestJoinViewWithTable(t *testing.T) {
}
err := query.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, len(dest), 2)
assert.Equal(t, len(dest[0].Rentals), 32)
assert.Equal(t, len(dest[1].Rentals), 27)
require.Equal(t, len(dest), 2)
require.Equal(t, len(dest[0].Rentals), 32)
require.Equal(t, len(dest[1].Rentals), 27)
}
func TestConditionalProjectionList(t *testing.T) {
@ -737,7 +740,7 @@ LIMIT 3;
`)
var dest []model.Customer
err := stmt.Query(db, &dest)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, len(dest), 3)
require.Equal(t, len(dest), 3)
}

View file

@ -8,7 +8,7 @@ import (
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"time"
)
@ -16,22 +16,35 @@ import (
func TestUpdateValues(t *testing.T) {
setupLinkTableForUpdateTest(t)
query := Link.
UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
var expectedSQL = `
UPDATE test_sample.link
SET name = 'Bong',
SET name = 'Bong',
url = 'http://bong.com'
WHERE link.name = 'Bing';
`
t.Run("old version", func(t *testing.T) {
query := Link.
UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
fmt.Println(query.DebugSql())
testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing")
testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing")
testutils.AssertExec(t, query, db)
requireLogged(t, query)
})
testutils.AssertExec(t, query, db)
t.Run("new version", func(t *testing.T) {
stmt := Link.UPDATE().
SET(
Link.Name.SET(String("Bong")),
Link.URL.SET(String("http://bong.com")),
).
WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "Bong", "http://bong.com", "Bing")
testutils.AssertExec(t, stmt, db)
requireLogged(t, stmt)
})
links := []model.Link{}
@ -40,8 +53,8 @@ WHERE link.name = 'Bing';
WHERE(Link.Name.EQ(String("Bong"))).
Query(db, &links)
assert.NoError(t, err)
assert.Equal(t, len(links), 1)
require.NoError(t, err)
require.Equal(t, len(links), 1)
testutils.AssertDeepEqual(t, links[0], model.Link{
ID: 204,
URL: "http://bong.com",
@ -52,21 +65,11 @@ WHERE link.name = 'Bing';
func TestUpdateWithSubQueries(t *testing.T) {
setupLinkTableForUpdateTest(t)
query := Link.
UPDATE(Link.Name, Link.URL).
SET(
SELECT(String("Bong")),
SELECT(Link2.URL).
FROM(Link2).
WHERE(Link2.Name.EQ(String("Youtube"))),
).
WHERE(Link.Name.EQ(String("Bing")))
expectedSQL := `
UPDATE test_sample.link
SET name = (
SELECT ?
),
),
url = (
SELECT link2.url AS "link2.url"
FROM test_sample.link2
@ -74,10 +77,39 @@ SET name = (
)
WHERE link.name = ?;
`
fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing")
t.Run("old version", func(t *testing.T) {
query := Link.
UPDATE(Link.Name, Link.URL).
SET(
SELECT(String("Bong")),
SELECT(Link2.URL).
FROM(Link2).
WHERE(Link2.Name.EQ(String("Youtube"))),
).
WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertExec(t, query, db)
testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing")
testutils.AssertExec(t, query, db)
requireLogged(t, query)
})
t.Run("new version", func(t *testing.T) {
query := Link.
UPDATE().
SET(
Link.Name.SET(StringExp(SELECT(String("Bong")))),
Link.URL.SET(StringExp(
SELECT(Link2.URL).
FROM(Link2).
WHERE(Link2.Name.EQ(String("Youtube"))),
)),
).
WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing")
testutils.AssertExec(t, query, db)
requireLogged(t, query)
})
}
func TestUpdateWithModelData(t *testing.T) {
@ -96,16 +128,16 @@ func TestUpdateWithModelData(t *testing.T) {
expectedSQL := `
UPDATE test_sample.link
SET id = ?,
url = ?,
name = ?,
SET id = ?,
url = ?,
name = ?,
description = ?
WHERE link.id = ?;
`
fmt.Println(stmt.Sql())
testutils.AssertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201))
testutils.AssertExec(t, stmt, db)
requireLogged(t, stmt)
}
func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
@ -127,8 +159,8 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
var expectedSQL = `
UPDATE test_sample.link
SET description = NULL,
name = 'DuckDuckGo',
SET description = NULL,
name = 'DuckDuckGo',
url = 'http://www.duckduckgo.com'
WHERE link.id = 201;
`
@ -137,10 +169,10 @@ WHERE link.id = 201;
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201))
testutils.AssertExec(t, stmt, db)
requireLogged(t, stmt)
}
func TestUpdateWithModelDataAndMutableColumns(t *testing.T) {
setupLinkTableForUpdateTest(t)
link := model.Link{
@ -156,23 +188,21 @@ func TestUpdateWithModelDataAndMutableColumns(t *testing.T) {
var expectedSQL = `
UPDATE test_sample.link
SET url = 'http://www.duckduckgo.com',
name = 'DuckDuckGo',
SET url = 'http://www.duckduckgo.com',
name = 'DuckDuckGo',
description = NULL
WHERE link.id = 201;
`
fmt.Println(stmt.DebugSql())
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201))
testutils.AssertExec(t, stmt, db)
}
func TestUpdateWithInvalidModelData(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "missing struct field for column : id")
require.Equal(t, r, "missing struct field for column : id")
}()
setupLinkTableForUpdateTest(t)
@ -213,7 +243,7 @@ func TestUpdateQueryContext(t *testing.T) {
dest := []model.Link{}
err := updateStmt.QueryContext(ctx, db, &dest)
assert.Error(t, err, "context deadline exceeded")
require.Error(t, err, "context deadline exceeded")
}
func TestUpdateExecContext(t *testing.T) {
@ -231,7 +261,7 @@ func TestUpdateExecContext(t *testing.T) {
_, err := updateStmt.ExecContext(ctx, db)
assert.Error(t, err, "context deadline exceeded")
require.Error(t, err, "context deadline exceeded")
}
func TestUpdateWithJoin(t *testing.T) {
@ -244,7 +274,7 @@ func TestUpdateWithJoin(t *testing.T) {
//fmt.Println(query.DebugSql())
_, err := query.Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
}
func setupLinkTableForUpdateTest(t *testing.T) {
@ -259,5 +289,5 @@ func setupLinkTableForUpdateTest(t *testing.T) {
VALUES(204, "http://www.bing.com", "Bing", DEFAULT).
Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
}

159
tests/mysql/with_test.go Normal file
View file

@ -0,0 +1,159 @@
package mysql
import (
"github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/mysql"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table"
"github.com/stretchr/testify/require"
"strings"
"testing"
)
func TestWITH_And_SELECT(t *testing.T) {
salesRep := CTE("sales_rep")
salesRepStaffID := Staff.StaffID.From(salesRep)
salesRepFullName := StringColumn("sales_rep_full_name").From(salesRep)
customerSalesRep := CTE("customer_sales_rep")
stmt := WITH(
salesRep.AS(
SELECT(
Staff.StaffID,
Staff.FirstName.CONCAT(Staff.LastName).AS(salesRepFullName.Name()),
).FROM(Staff),
),
customerSalesRep.AS(
SELECT(
Customer.FirstName.CONCAT(Customer.LastName).AS("customer_name"),
salesRepFullName,
).FROM(
salesRep.
INNER_JOIN(Store, Store.ManagerStaffID.EQ(salesRepStaffID)).
INNER_JOIN(Customer, Customer.StoreID.EQ(Store.StoreID)),
),
),
)(
SELECT(customerSalesRep.AllColumns()).
FROM(customerSalesRep),
)
//fmt.Println(stmt.DebugSql())
testutils.AssertStatementSql(t, stmt, strings.Replace(`
WITH sales_rep AS (
SELECT staff.staff_id AS "staff.staff_id",
(CONCAT(staff.first_name, staff.last_name)) AS "sales_rep_full_name"
FROM dvds.staff
),customer_sales_rep AS (
SELECT (CONCAT(customer.first_name, customer.last_name)) AS "customer_name",
sales_rep.sales_rep_full_name AS "sales_rep_full_name"
FROM sales_rep
INNER JOIN dvds.store ON (store.manager_staff_id = sales_rep.''staff.staff_id'')
INNER JOIN dvds.customer ON (customer.store_id = store.store_id)
)
SELECT customer_sales_rep.customer_name AS "customer_name",
customer_sales_rep.sales_rep_full_name AS "sales_rep_full_name"
FROM customer_sales_rep;
`, "''", "`", -1))
var dest []struct {
CustomerName string
SalesRepFullName string
}
err := stmt.Query(db, &dest)
require.Equal(t, len(dest), 599)
require.NoError(t, err)
}
//func TestWITH_And_INSERT(t *testing.T) {
// paymentsToInsert := CTE("payments_to_insert")
//
// stmt := WITH(
// paymentsToInsert.AS(
// SELECT(Payment.AllColumns).
// FROM(Payment).
// WHERE(Payment.Amount.LT(Float(0.5))),
// ),
// )(
// Payment.INSERT(Payment.AllColumns).
// QUERY(
// SELECT(paymentsToInsert.AllColumns()).
// FROM(paymentsToInsert),
// ).ON_DUPLICATE_KEY_UPDATE(
// Payment.PaymentID.SET(Payment.PaymentID.ADD(Int(100000))),
// ),
// )
//
// //fmt.Println(stmt.DebugSql())
//
// tx, err := db.Begin()
// require.NoError(t, err)
// defer tx.Rollback()
//
// testutils.AssertExec(t, stmt, tx, 24)
//}
func TestWITH_And_UPDATE(t *testing.T) {
if sourceIsMariaDB() {
return
}
paymentsToUpdate := CTE("payments_to_update")
paymentsToDeleteID := Payment.PaymentID.From(paymentsToUpdate)
stmt := WITH(
paymentsToUpdate.AS(
SELECT(Payment.AllColumns).
FROM(Payment).
WHERE(Payment.Amount.LT(Float(0.5))),
),
)(
Payment.UPDATE().
SET(Payment.Amount.SET(Float(0.0))).
WHERE(Payment.PaymentID.IN(
SELECT(paymentsToDeleteID).
FROM(paymentsToUpdate),
),
),
)
//fmt.Println(stmt.DebugSql())
tx, err := db.Begin()
require.NoError(t, err)
defer tx.Rollback()
testutils.AssertExec(t, stmt, tx)
}
func TestWITH_And_DELETE(t *testing.T) {
if sourceIsMariaDB() {
return
}
paymentsToDelete := CTE("payments_to_delete")
paymentsToDeleteID := Payment.PaymentID.From(paymentsToDelete)
stmt := WITH(
paymentsToDelete.AS(
SELECT(Payment.AllColumns).
FROM(Payment).
WHERE(Payment.Amount.LT(Float(0.5))),
),
)(
Payment.DELETE().
WHERE(Payment.PaymentID.IN(
SELECT(paymentsToDeleteID).
FROM(paymentsToDelete),
),
),
)
//fmt.Println(stmt.DebugSql())
tx, err := db.Begin()
require.NoError(t, err)
defer tx.Rollback()
testutils.AssertExec(t, stmt, tx, 24)
}

Some files were not shown because too many files have changed in this diff Show more