From 16258ee18b3ecf046d5f4726f0d3a5109d3139fc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 25 Apr 2023 21:59:33 +0000 Subject: [PATCH 01/30] Bump github.com/go-sql-driver/mysql from 1.7.0 to 1.7.1 Bumps [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql) from 1.7.0 to 1.7.1. - [Release notes](https://github.com/go-sql-driver/mysql/releases) - [Changelog](https://github.com/go-sql-driver/mysql/blob/master/CHANGELOG.md) - [Commits](https://github.com/go-sql-driver/mysql/compare/v1.7.0...v1.7.1) --- updated-dependencies: - dependency-name: github.com/go-sql-driver/mysql dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 685bebb..c957367 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/go-jet/jet/v2 go 1.11 require ( - github.com/go-sql-driver/mysql v1.7.0 + github.com/go-sql-driver/mysql v1.7.1 github.com/google/uuid v1.3.0 github.com/jackc/pgconn v1.14.0 github.com/lib/pq v1.10.8 diff --git a/go.sum b/go.sum index 2c11a02..3a06634 100644 --- a/go.sum +++ b/go.sum @@ -18,8 +18,8 @@ github.com/friendsofgo/errors v0.9.2 h1:X6NYxef4efCBdwI7BgS820zFaN7Cphrmb+Pljdzj github.com/friendsofgo/errors v0.9.2/go.mod h1:yCvFW5AkDIL9qn7suHVLiI/gH228n7PC4Pn44IGoTOI= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= -github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= -github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= +github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= From 8bbfa2c29027540e8a20ef0f4f84f8ce2ef38118 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 20 Jul 2023 21:10:23 +0000 Subject: [PATCH 02/30] Bump github.com/jackc/pgconn from 1.14.0 to 1.14.1 Bumps [github.com/jackc/pgconn](https://github.com/jackc/pgconn) from 1.14.0 to 1.14.1. - [Changelog](https://github.com/jackc/pgconn/blob/master/CHANGELOG.md) - [Commits](https://github.com/jackc/pgconn/compare/v1.14.0...v1.14.1) --- updated-dependencies: - dependency-name: github.com/jackc/pgconn dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --- go.mod | 2 +- go.sum | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 685bebb..9b6c7cc 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.11 require ( github.com/go-sql-driver/mysql v1.7.0 github.com/google/uuid v1.3.0 - github.com/jackc/pgconn v1.14.0 + github.com/jackc/pgconn v1.14.1 github.com/lib/pq v1.10.8 github.com/mattn/go-sqlite3 v1.14.16 ) diff --git a/go.sum b/go.sum index 2c11a02..ebf7f90 100644 --- a/go.sum +++ b/go.sum @@ -43,8 +43,9 @@ github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsU github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= -github.com/jackc/pgconn v1.14.0 h1:vrbA9Ud87g6JdFWkHTJXppVce58qPIdP7N8y0Ml/A7Q= github.com/jackc/pgconn v1.14.0/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E= +github.com/jackc/pgconn v1.14.1 h1:smbxIaZA08n6YuxEX1sDyjV/qkbtUtkH20qLkR9MUR4= +github.com/jackc/pgconn v1.14.1/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= From 39f9996b345137fc34499d6c9e406b3e9691f31b Mon Sep 17 00:00:00 2001 From: quirell Date: Wed, 2 Aug 2023 10:03:59 +0900 Subject: [PATCH 03/30] Add DefaultAlias option to TableSQLBuilder --- generator/template/file_templates.go | 2 +- generator/template/sql_builder_template.go | 8 ++++++++ tests/mysql/generator_template_test.go | 23 ++++++++++++++++++++++ tests/postgres/generator_template_test.go | 23 ++++++++++++++++++++++ 4 files changed, 55 insertions(+), 1 deletion(-) diff --git a/generator/template/file_templates.go b/generator/template/file_templates.go index c2dc33b..731b1af 100644 --- a/generator/template/file_templates.go +++ b/generator/template/file_templates.go @@ -24,7 +24,7 @@ import ( "github.com/go-jet/jet/v2/{{dialect.PackageName}}" ) -var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "") +var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "{{tableTemplate.DefaultAlias}}") type {{structImplName}} struct { {{dialect.PackageName}}.Table diff --git a/generator/template/sql_builder_template.go b/generator/template/sql_builder_template.go index 4b68f21..c9ba7c4 100644 --- a/generator/template/sql_builder_template.go +++ b/generator/template/sql_builder_template.go @@ -59,6 +59,7 @@ type TableSQLBuilder struct { FileName string InstanceName string TypeName string + DefaultAlias string Column func(columnMetaData metadata.Column) TableSQLBuilderColumn } @@ -72,6 +73,7 @@ func DefaultTableSQLBuilder(tableMetaData metadata.Table) TableSQLBuilder { FileName: dbidentifier.ToGoFileName(tableMetaData.Name), InstanceName: dbidentifier.ToGoIdentifier(tableMetaData.Name), TypeName: dbidentifier.ToGoIdentifier(tableMetaData.Name) + "Table", + DefaultAlias: "", Column: DefaultTableSQLBuilderColumn, } } @@ -112,6 +114,12 @@ func (tb TableSQLBuilder) UseTypeName(name string) TableSQLBuilder { return tb } +// UseDefaultAlias returns new TableSQLBuilder with new default alias set +func (tb TableSQLBuilder) UseDefaultAlias(defaultAlias string) TableSQLBuilder { + tb.DefaultAlias = defaultAlias + return tb +} + // UseColumn returns new TableSQLBuilder with new column template function set func (tb TableSQLBuilder) UseColumn(columnsFunc func(column metadata.Column) TableSQLBuilderColumn) TableSQLBuilder { tb.Column = columnsFunc diff --git a/tests/mysql/generator_template_test.go b/tests/mysql/generator_template_test.go index a47a21d..a257f31 100644 --- a/tests/mysql/generator_template_test.go +++ b/tests/mysql/generator_template_test.go @@ -285,6 +285,29 @@ func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) { require.Contains(t, mpaaRating, "var FilmRatingEnumSQLBuilder = &struct {") } +func TestGeneratorTemplate_SQLBuilder_DefaultAlias(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection("dvds"), + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + if table.Name == "actor" { + return template.DefaultTableSQLBuilder(table).UseDefaultAlias("actors") + } + return template.DefaultTableSQLBuilder(table) + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultTableSQLBuilderFilePath, "actor.go") + require.Contains(t, actor, "var Actor = newActorTable(\"dvds\", \"actor\", \"actors\")") +} + func TestGeneratorTemplate_Model_AddTags(t *testing.T) { err := mysql2.Generate( diff --git a/tests/postgres/generator_template_test.go b/tests/postgres/generator_template_test.go index 4f2f0da..2bc2d32 100644 --- a/tests/postgres/generator_template_test.go +++ b/tests/postgres/generator_template_test.go @@ -341,6 +341,29 @@ func UseSchema(schema string) { `) } +func TestGeneratorTemplate_SQLBuilder_DefaultAlias(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + if table.Name == "actor" { + return template.DefaultTableSQLBuilder(table).UseDefaultAlias("actors") + } + return template.DefaultTableSQLBuilder(table) + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultTableSQLBuilderFilePath, "actor.go") + require.Contains(t, actor, "var Actor = newActorTable(\"dvds\", \"actor\", \"actors\")") +} + func TestGeneratorTemplate_Model_AddTags(t *testing.T) { err := postgres.Generate( From a3267eb6c1cbf33fb59094d07e3ad445f0e4f53e Mon Sep 17 00:00:00 2001 From: ryym Date: Sun, 20 Aug 2023 10:22:18 +0900 Subject: [PATCH 04/30] Fix typo in README gen/ -> .gen/ --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 23b1386..e6b1891 100644 --- a/README.md +++ b/README.md @@ -105,23 +105,23 @@ Done ``` Procedure is similar for MySQL, CockroachDB, MariaDB and SQLite. For example: ```sh -jet -source=mysql -dsn="user:pass@tcp(localhost:3306)/dbname" -path=./gen +jet -source=mysql -dsn="user:pass@tcp(localhost:3306)/dbname" -path=./.gen jet -dsn=postgres://user:pass@localhost:26257/jetdb?sslmode=disable -schema=dvds -path=./.gen #cockroachdb -jet -dsn="mariadb://user:pass@tcp(localhost:3306)/dvds" -path=./gen # source flag can be omitted if data source appears in dsn -jet -source=sqlite -dsn="/path/to/sqlite/database/file" -schema=dvds -path=./gen -jet -dsn="file:///path/to/sqlite/database/file" -schema=dvds -path=./gen # sqlite database assumed for 'file' data sources +jet -dsn="mariadb://user:pass@tcp(localhost:3306)/dvds" -path=./.gen # source flag can be omitted if data source appears in dsn +jet -source=sqlite -dsn="/path/to/sqlite/database/file" -schema=dvds -path=./.gen +jet -dsn="file:///path/to/sqlite/database/file" -schema=dvds -path=./.gen # sqlite database assumed for 'file' data sources ``` _*User has to have a permission to read information schema tables._ As command output suggest, Jet will: - connect to postgres database and retrieve information about the _tables_, _views_ and _enums_ of `dvds` schema -- delete everything in schema destination folder - `./gen/jetdb/dvds`, +- delete everything in schema destination folder - `./.gen/jetdb/dvds`, - and finally generate SQL Builder and Model types for each schema table, view and enum. Generated files folder structure will look like this: ```sh -|-- gen # -path +|-- .gen # -path | `-- jetdb # database name | `-- dvds # schema name | |-- enum # sql builder package for enums @@ -156,7 +156,7 @@ import ( . "github.com/go-jet/jet/v2/examples/quick-start/.gen/jetdb/dvds/table" . "github.com/go-jet/jet/v2/postgres" - "github.com/go-jet/jet/v2/examples/quick-start/gen/jetdb/dvds/model" + "github.com/go-jet/jet/v2/examples/quick-start/.gen/jetdb/dvds/model" ) ``` Let's say we want to retrieve the list of all _actors_ that acted in _films_ longer than 180 minutes, _film language_ is 'English' From db808f136b6452f5b7b2caa1748054006c3adc8c Mon Sep 17 00:00:00 2001 From: Yosyp Buchma Date: Mon, 18 Sep 2023 17:35:11 +0300 Subject: [PATCH 05/30] Faster MySQL GetTableColumnsMetaData query --- generator/mysql/query_set.go | 49 +++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index 9eb9257..442b74b 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -39,27 +39,36 @@ ORDER BY table_name; func (m mySqlQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) { query := ` -SELECT COLUMN_NAME AS "column.Name", - IS_NULLABLE = "YES" AS "column.IsNullable", - columns.COLUMN_COMMENT as "column.Comment", - (EXISTS( - SELECT 1 +SELECT + col.COLUMN_NAME AS "column.Name", + col.IS_NULLABLE = "YES" AS "column.IsNullable", + col.COLUMN_COMMENT AS "column.Comment", + pk.IsPrimaryKey AS "column.IsPrimaryKey", + IF (col.COLUMN_TYPE = 'tinyint(1)', + 'boolean', + IF (col.DATA_TYPE = 'enum', + CONCAT(col.TABLE_NAME, '_', col.COLUMN_NAME), + col.DATA_TYPE) + ) AS "dataType.Name", + IF (col.DATA_TYPE = 'enum', 'enum', 'base') AS "dataType.Kind", + col.COLUMN_TYPE LIKE '%unsigned%' AS "dataType.IsUnsigned" +FROM + information_schema.columns AS col +LEFT JOIN ( + SELECT k.column_name, 1 AS IsPrimaryKey FROM information_schema.table_constraints t - JOIN information_schema.key_column_usage k USING(constraint_name,table_schema,table_name) - WHERE table_schema = ? AND table_name = ? AND t.constraint_type='PRIMARY KEY' AND k.column_name = columns.column_name - )) AS "column.IsPrimaryKey", - IF (COLUMN_TYPE = 'tinyint(1)', - 'boolean', - IF (DATA_TYPE='enum', - CONCAT(TABLE_NAME, '_', COLUMN_NAME), - DATA_TYPE) - ) AS "dataType.Name", - IF (DATA_TYPE = 'enum', 'enum', 'base') AS "dataType.Kind", - COLUMN_TYPE LIKE '%unsigned%' AS "dataType.IsUnsigned" -FROM information_schema.columns -WHERE table_schema = ? AND table_name = ? -ORDER BY ordinal_position; -` + JOIN information_schema.key_column_usage k USING(constraint_name, table_schema, table_name) + WHERE t.table_schema = ? + AND t.table_name = ? + AND t.constraint_type = 'PRIMARY KEY' +) AS pk ON col.COLUMN_NAME = pk.column_name +WHERE + col.table_schema = ? + AND col.table_name = ? +ORDER BY + col.ordinal_position; + ` + var columns []metadata.Column _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns) if err != nil { From 98dfce2ae587ebd9b2a2064f457af464e19b9ded Mon Sep 17 00:00:00 2001 From: Yosyp Buchma Date: Mon, 18 Sep 2023 17:35:58 +0300 Subject: [PATCH 06/30] Concurrent GetTableColumnsMetaData for MySQL --- generator/mysql/query_set.go | 39 ++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index 442b74b..95d6a93 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -4,7 +4,9 @@ import ( "context" "database/sql" "fmt" + "runtime" "strings" + "sync" "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/qrm" @@ -27,11 +29,40 @@ ORDER BY table_name; return nil, fmt.Errorf("failed to query %s metadata result: %w", tableType, err) } + tblChan := make(chan int, len(tables)) + errChan := make(chan error, 1) + + wg := sync.WaitGroup{} + for i := 0; i < runtime.NumCPU(); i++ { + wg.Add(1) + go func() { + defer wg.Done() + var err1 error + for tblIdx := range tblChan { + tables[tblIdx].Columns, err1 = m.GetTableColumnsMetaData(db, schemaName, tables[tblIdx].Name) + if err1 != nil { + select { + case errChan <- fmt.Errorf("failed to get '%s' table columns metadata: %w", tables[tblIdx].Name, err1): + return + default: + } + return + } + } + }() + } + for i := range tables { - tables[i].Columns, err = m.GetTableColumnsMetaData(db, schemaName, tables[i].Name) - if err != nil { - return nil, fmt.Errorf("failed to get '%s' table columns metadata: %w", tables[i].Name, err) - } + tblChan <- i + } + + close(tblChan) + wg.Wait() + + select { + case err = <-errChan: + return nil, err + default: } return tables, nil From ffabf8b26e90b94fb019988d06dbca26d6aceece Mon Sep 17 00:00:00 2001 From: Yosyp Buchma Date: Tue, 19 Sep 2023 20:56:21 +0300 Subject: [PATCH 07/30] coalesce pk.IsPrimaryKey --- generator/mysql/query_set.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index 95d6a93..f886357 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -74,7 +74,7 @@ SELECT col.COLUMN_NAME AS "column.Name", col.IS_NULLABLE = "YES" AS "column.IsNullable", col.COLUMN_COMMENT AS "column.Comment", - pk.IsPrimaryKey AS "column.IsPrimaryKey", + COALESCE(pk.IsPrimaryKey, 0) AS "column.IsPrimaryKey", IF (col.COLUMN_TYPE = 'tinyint(1)', 'boolean', IF (col.DATA_TYPE = 'enum', From f472becd892d3cfca593ecca62aa92389e6f3ebb Mon Sep 17 00:00:00 2001 From: Yosyp Buchma Date: Tue, 19 Sep 2023 20:56:54 +0300 Subject: [PATCH 08/30] simplified concurrent querying --- generator/mysql/query_set.go | 49 ++++++++++-------------------------- go.mod | 1 + go.sum | 2 ++ 3 files changed, 16 insertions(+), 36 deletions(-) diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index f886357..b58ff82 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -4,12 +4,11 @@ import ( "context" "database/sql" "fmt" - "runtime" "strings" - "sync" "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/qrm" + "golang.org/x/sync/errgroup" ) // mySqlQuerySet is dialect query set for MySQL @@ -29,43 +28,21 @@ ORDER BY table_name; return nil, fmt.Errorf("failed to query %s metadata result: %w", tableType, err) } - tblChan := make(chan int, len(tables)) - errChan := make(chan error, 1) + const maxConns = 32 + db.SetMaxOpenConns(maxConns) + db.SetMaxIdleConns(maxConns) - wg := sync.WaitGroup{} - for i := 0; i < runtime.NumCPU(); i++ { - wg.Add(1) - go func() { - defer wg.Done() - var err1 error - for tblIdx := range tblChan { - tables[tblIdx].Columns, err1 = m.GetTableColumnsMetaData(db, schemaName, tables[tblIdx].Name) - if err1 != nil { - select { - case errChan <- fmt.Errorf("failed to get '%s' table columns metadata: %w", tables[tblIdx].Name, err1): - return - default: - } - return - } - } - }() + wg := errgroup.Group{} + for i := 0; i < len(tables); i++ { + i := i + wg.Go(func() (err1 error) { + tables[i].Columns, err1 = m.GetTableColumnsMetaData(db, schemaName, tables[i].Name) + return err1 + }) } - for i := range tables { - tblChan <- i - } - - close(tblChan) - wg.Wait() - - select { - case err = <-errChan: - return nil, err - default: - } - - return tables, nil + err = wg.Wait() + return tables, err } func (m mySqlQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) { diff --git a/go.mod b/go.mod index 685bebb..d0a810a 100644 --- a/go.mod +++ b/go.mod @@ -18,5 +18,6 @@ require ( github.com/shopspring/decimal v1.3.1 github.com/stretchr/testify v1.8.2 github.com/volatiletech/null/v8 v8.1.2 + golang.org/x/sync v0.3.0 gopkg.in/guregu/null.v4 v4.0.0 ) diff --git a/go.sum b/go.sum index 2c11a02..4b06a0a 100644 --- a/go.sum +++ b/go.sum @@ -182,6 +182,8 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= From dd8d043cb88440207670da9604777e994a78e830 Mon Sep 17 00:00:00 2001 From: Yosyp Buchma Date: Wed, 20 Sep 2023 13:13:50 +0300 Subject: [PATCH 09/30] moved mysql connection pool config to openConection func --- generator/mysql/mysql_generator.go | 5 +++++ generator/mysql/query_set.go | 4 ---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/generator/mysql/mysql_generator.go b/generator/mysql/mysql_generator.go index 635bcbc..7495bec 100644 --- a/generator/mysql/mysql_generator.go +++ b/generator/mysql/mysql_generator.go @@ -12,6 +12,8 @@ import ( mysqldr "github.com/go-sql-driver/mysql" ) +const mysqlMaxConns = 10 + // DBConnection contains MySQL connection details type DBConnection struct { Host string @@ -83,6 +85,9 @@ func openConnection(connectionString string) (*sql.DB, error) { return nil, fmt.Errorf("failed to open mysql connection: %w", err) } + db.SetMaxOpenConns(mysqlMaxConns) + db.SetMaxIdleConns(mysqlMaxConns) + err = db.Ping() if err != nil { return nil, fmt.Errorf("failed to ping database: %w", err) diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index b58ff82..80406bd 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -28,10 +28,6 @@ ORDER BY table_name; return nil, fmt.Errorf("failed to query %s metadata result: %w", tableType, err) } - const maxConns = 32 - db.SetMaxOpenConns(maxConns) - db.SetMaxIdleConns(maxConns) - wg := errgroup.Group{} for i := 0; i < len(tables); i++ { i := i From f16f0b5e5dfc13445255a79688ac2167e8f33dcd Mon Sep 17 00:00:00 2001 From: Matthew Dowdell Date: Thu, 30 Nov 2023 07:52:54 +0000 Subject: [PATCH 10/30] Add support for OF in row lock clauses This adds support for statements such as `SELECT ... FOR UPDATE OF table NOWAIT` where `OF table` could not be specified previously. Fixes #285. --- internal/jet/select_lock.go | 26 +++++++++++++++++++++++++- mysql/select_statement_test.go | 5 +++++ postgres/select_statement_test.go | 6 ++++++ tests/postgres/select_test.go | 12 ++++++++++++ 4 files changed, 48 insertions(+), 1 deletion(-) diff --git a/internal/jet/select_lock.go b/internal/jet/select_lock.go index 4694ee3..c9d431a 100644 --- a/internal/jet/select_lock.go +++ b/internal/jet/select_lock.go @@ -4,12 +4,14 @@ package jet type RowLock interface { Serializer + OF(...Table) RowLock NOWAIT() RowLock SKIP_LOCKED() RowLock } type selectLockImpl struct { lockStrength string + of []Table noWait, skipLocked bool } @@ -20,10 +22,15 @@ func NewRowLock(name string) func() RowLock { } } -func newSelectLock(lockStrength string) RowLock { +func newSelectLock(lockStrength string) *selectLockImpl { return &selectLockImpl{lockStrength: lockStrength} } +func (s *selectLockImpl) OF(tables ...Table) RowLock { + s.of = tables + return s +} + func (s *selectLockImpl) NOWAIT() RowLock { s.noWait = true return s @@ -37,6 +44,23 @@ func (s *selectLockImpl) SKIP_LOCKED() RowLock { func (s *selectLockImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString(s.lockStrength) + if len(s.of) > 0 { + out.WriteString("OF") + + for i, of := range s.of { + if i > 0 { + out.WriteString(", ") + } + + table := of.Alias() + if table == "" { + table = of.TableName() + } + + out.WriteIdentifier(table) + } + } + if s.noWait { out.WriteString("NOWAIT") } diff --git a/mysql/select_statement_test.go b/mysql/select_statement_test.go index bd3a0e9..1136dd7 100644 --- a/mysql/select_statement_test.go +++ b/mysql/select_statement_test.go @@ -122,6 +122,11 @@ FOR UPDATE; SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 FOR SHARE NOWAIT; +`) + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE().OF(table1).NOWAIT()), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +FOR UPDATE OF table1 NOWAIT; `) } diff --git a/postgres/select_statement_test.go b/postgres/select_statement_test.go index b487f90..5dcacf0 100644 --- a/postgres/select_statement_test.go +++ b/postgres/select_statement_test.go @@ -132,5 +132,11 @@ FOR KEY SHARE NOWAIT; SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 FOR NO KEY UPDATE SKIP LOCKED; +`) + + assertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE().OF(table1, table2).NOWAIT()), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +FOR UPDATE OF table1, table2 NOWAIT; `) } diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index a6d9588..1f686b2 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -2251,6 +2251,18 @@ FOR` } } +func TestRowLockWithJoins(t *testing.T) { + query := SELECT(STAR). + FROM( + Film. + INNER_JOIN(FilmCategory, FilmCategory.FilmID.EQ(Film.FilmID)). + LEFT_JOIN(FilmActor, FilmActor.FilmID.EQ(Film.FilmID))). + LIMIT(1). + FOR(UPDATE().OF(Film, FilmCategory).NOWAIT()) + + testutils.AssertExecAndRollback(t, query, db, 1) +} + func TestQuickStart(t *testing.T) { var expectedSQL = ` From b6d57075e821ff33ad74a83344db90e491ec10df Mon Sep 17 00:00:00 2001 From: go-jet Date: Thu, 1 Feb 2024 14:43:12 +0100 Subject: [PATCH 11/30] Additional tests for row lock UPDATE OF use case. --- tests/mysql/select_test.go | 81 ++++++++++++++++++++++++++++++++++ tests/postgres/select_test.go | 83 +++++++++++++++++++++++++++++++---- 2 files changed, 155 insertions(+), 9 deletions(-) diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 4009bb8..deb4529 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -2,6 +2,7 @@ package mysql import ( "context" + "database/sql" "strings" "testing" "time" @@ -632,6 +633,86 @@ FOR` } } +func TestRowLockWithUpdateOf(t *testing.T) { + skipForMariaDB(t) // MariaDB does not support UPDATE OF + + stmt := SELECT( + Film.FilmID, + Film.Title, + Actor.ActorID, + Actor.FirstName, + ).FROM( + Film. + INNER_JOIN(FilmCategory, FilmCategory.FilmID.EQ(Film.FilmID)). + INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(Film.FilmID)). + INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID)), + ).LIMIT( + 1, + ).FOR( + UPDATE().OF(Film, Actor).NOWAIT(), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name" +FROM dvds.film + INNER JOIN dvds.film_category ON (film_category.film_id = film.film_id) + INNER JOIN dvds.film_actor ON (film_actor.film_id = film.film_id) + INNER JOIN dvds.actor ON (actor.actor_id = film_actor.actor_id) +LIMIT 1 +FOR UPDATE OF film, actor NOWAIT; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []struct { + model.Film + CategoryID int + Actor []model.Actor + } + + err := stmt.Query(tx, &dest) + require.NoError(t, err) + require.Len(t, dest, 1) + }) +} + +func TestRowLockWithUpdateOfAliasedTable(t *testing.T) { + skipForMariaDB(t) // MariaDB does not support UPDATE OF + + myFilm := Film.AS("myFilm") + + stmt := SELECT( + myFilm.FilmID, + myFilm.Title, + ).FROM( + myFilm, + ).LIMIT( + 1, + ).FOR( + UPDATE().OF(myFilm), + ) + + testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(` +SELECT ''myFilm''.film_id AS "myFilm.film_id", + ''myFilm''.title AS "myFilm.title" +FROM dvds.film AS ''myFilm'' +LIMIT 1 +FOR UPDATE OF ''myFilm''; +`, "''", "`")) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []struct { + model.Film `alias:"myFilm.*"` + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 1) + }) +} + func TestExpressionWrappers(t *testing.T) { query := SELECT( BoolExp(Raw("true")), diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 1f686b2..dcf8993 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -2251,16 +2251,81 @@ FOR` } } -func TestRowLockWithJoins(t *testing.T) { - query := SELECT(STAR). - FROM( - Film. - INNER_JOIN(FilmCategory, FilmCategory.FilmID.EQ(Film.FilmID)). - LEFT_JOIN(FilmActor, FilmActor.FilmID.EQ(Film.FilmID))). - LIMIT(1). - FOR(UPDATE().OF(Film, FilmCategory).NOWAIT()) +func TestRowLockWithUpdateOf(t *testing.T) { + stmt := SELECT( + Film.FilmID, + Film.Title, + Actor.ActorID, + Actor.FirstName, + ).FROM( + Film. + INNER_JOIN(FilmCategory, FilmCategory.FilmID.EQ(Film.FilmID)). + INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(Film.FilmID)). + INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID)), + ).LIMIT( + 1, + ).FOR( + UPDATE().OF(Film, Actor).NOWAIT(), + ) - testutils.AssertExecAndRollback(t, query, db, 1) + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name" +FROM dvds.film + INNER JOIN dvds.film_category ON (film_category.film_id = film.film_id) + INNER JOIN dvds.film_actor ON (film_actor.film_id = film.film_id) + INNER JOIN dvds.actor ON (actor.actor_id = film_actor.actor_id) +LIMIT 1 +FOR UPDATE OF film, actor NOWAIT; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []struct { + model.Film + CategoryID int + Actor []model.Actor + } + + err := stmt.Query(tx, &dest) + require.NoError(t, err) + require.Len(t, dest, 1) + }) +} + +func TestRowLockWithUpdateOfAliasedTable(t *testing.T) { + + myFilm := Film.AS("myFilm") + + stmt := SELECT( + myFilm.FilmID, + myFilm.Title, + ).FROM( + myFilm, + ).LIMIT( + 1, + ).FOR( + UPDATE().OF(myFilm), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT "myFilm".film_id AS "myFilm.film_id", + "myFilm".title AS "myFilm.title" +FROM dvds.film AS "myFilm" +LIMIT 1 +FOR UPDATE OF "myFilm"; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []struct { + model.Film `alias:"myFilm.*"` + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 1) + }) } func TestQuickStart(t *testing.T) { From 7f48e9fb67070e4c029ae6b1f2a87d97e2d96cdb Mon Sep 17 00:00:00 2001 From: go-jet Date: Thu, 1 Feb 2024 15:20:49 +0100 Subject: [PATCH 12/30] Add support for materialized views. --- generator/metadata/column_meta_data.go | 1 + generator/postgres/query_set.go | 83 +++++++++++++--------- generator/template/sql_builder_template.go | 15 ++-- tests/docker-compose.yaml | 4 +- tests/postgres/alltypes_test.go | 51 ++++++++++--- tests/postgres/generator_test.go | 22 ++++-- tests/postgres/sample_test.go | 12 +++- tests/testdata | 2 +- 8 files changed, 127 insertions(+), 63 deletions(-) diff --git a/generator/metadata/column_meta_data.go b/generator/metadata/column_meta_data.go index 55533f4..679f6c1 100644 --- a/generator/metadata/column_meta_data.go +++ b/generator/metadata/column_meta_data.go @@ -33,6 +33,7 @@ const ( EnumType DataTypeKind = "enum" UserDefinedType DataTypeKind = "user-defined" ArrayType DataTypeKind = "array" + RangeType DataTypeKind = "range" ) // DataType contains information about column data type diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index 856173d..2d8835b 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -26,8 +26,25 @@ ORDER BY table_name; return nil, fmt.Errorf("failed to query %s metadata: %w", tableType, err) } + // add materialized views separately, because materialized views are not part of standard information schema + if tableType == metadata.ViewTable { + matViewQuery := ` + select matviewname as "table.name" + from pg_matviews + where schemaname = $1; + ` + var matViews []metadata.Table + + _, err := qrm.Query(context.Background(), db, matViewQuery, []interface{}{schemaName}, &matViews) + if err != nil { + return nil, fmt.Errorf("failed to query materialized view metadata: %w", err) + } + + tables = append(tables, matViews...) + } + for i := range tables { - tables[i].Columns, err = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name) + tables[i].Columns, err = getColumnsMetaData(db, schemaName, tables[i].Name) if err != nil { return nil, fmt.Errorf("failed to query %s columns metadata: %w", tableType, err) } @@ -36,39 +53,39 @@ ORDER BY table_name; return tables, nil } -func (p postgresQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) { +func getColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) { query := ` -WITH primaryKeys AS ( - SELECT column_name - FROM information_schema.key_column_usage AS c - LEFT JOIN information_schema.table_constraints AS t - ON t.constraint_name = c.constraint_name AND - c.table_schema = t.table_schema AND - c.table_name = t.table_name - WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY' -) -SELECT column_name as "column.Name", - is_nullable = 'YES' as "column.isNullable", - is_generated = 'ALWAYS' or is_generated = 'YES' as "column.isGenerated", - (EXISTS(SELECT 1 from primaryKeys as pk where pk.column_name = columns.column_name)) as "column.IsPrimaryKey", - dataType.kind as "dataType.Kind", - (case dataType.Kind when 'base' then data_type else LTRIM(udt_name, '_') end) as "dataType.Name", - FALSE as "dataType.isUnsigned" -FROM information_schema.columns, - LATERAL (select (case data_type - when 'ARRAY' then 'array' - when 'USER-DEFINED' then - case (select t.typtype - from pg_type as t - join pg_namespace as p on p.oid = t.typnamespace - where t.typname = columns.udt_name and p.nspname = $1) - when 'e' then 'enum' - else 'user-defined' - end - else 'base' - end) as Kind) as dataType -where table_schema = $1 and table_name = $2 -order by ordinal_position; +select + attr.attname as "column.Name", + exists( + select 1 + from pg_catalog.pg_index indx + where attr.attrelid = indx.indrelid and attr.attnum = any(indx.indkey) and indx.indisprimary + ) as "column.IsPrimaryKey", + not attr.attnotnull as "column.isNullable", + attr.attgenerated = 's' as "column.isGenerated", + (case tp.typtype + when 'b' then 'base' + when 'd' then 'base' + when 'e' then 'enum' + when 'r' then 'range' + end) as "dataType.Kind", + (case when tp.typtype = 'd' then (select pg_type.typname from pg_catalog.pg_type where pg_type.oid = tp.typbasetype) + when tp.typcategory = 'A' then pg_catalog.format_type(attr.atttypid, attr.atttypmod) + else tp.typname + end) as "dataType.Name", + false as "dataType.isUnsigned" +from pg_catalog.pg_attribute as attr + join pg_catalog.pg_class as cls on cls.oid = attr.attrelid + join pg_catalog.pg_namespace as ns on ns.oid = cls.relnamespace + join pg_catalog.pg_type as tp on tp.oid = attr.atttypid +where + ns.nspname = $1 and + cls.relname = $2 and + not attr.attisdropped and + attr.attnum > 0 +order by + attr.attnum; ` var columns []metadata.Column _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns) diff --git a/generator/template/sql_builder_template.go b/generator/template/sql_builder_template.go index c9ba7c4..869b8e5 100644 --- a/generator/template/sql_builder_template.go +++ b/generator/template/sql_builder_template.go @@ -142,14 +142,15 @@ func DefaultTableSQLBuilderColumn(columnMetaData metadata.Column) TableSQLBuilde // getSqlBuilderColumnType returns type of jet sql builder column func getSqlBuilderColumnType(columnMetaData metadata.Column) string { - if columnMetaData.DataType.Kind != metadata.BaseType { + if columnMetaData.DataType.Kind != metadata.BaseType && + columnMetaData.DataType.Kind != metadata.RangeType { return "String" } switch strings.ToLower(columnMetaData.DataType.Name) { - case "boolean": + case "boolean", "bool": return "Bool" - case "smallint", "integer", "bigint", + case "smallint", "integer", "bigint", "int2", "int4", "int8", "tinyint", "mediumint", "int", "year": //MySQL return "Integer" case "date": @@ -157,21 +158,21 @@ func getSqlBuilderColumnType(columnMetaData metadata.Column) string { case "timestamp without time zone", "timestamp", "datetime": //MySQL: return "Timestamp" - case "timestamp with time zone": + case "timestamp with time zone", "timestamptz": return "Timestampz" case "time without time zone", "time": //MySQL return "Time" - case "time with time zone": + case "time with time zone", "timetz": return "Timez" case "interval": return "Interval" case "user-defined", "enum", "text", "character", "character varying", "bytea", "uuid", "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", - "char", "varchar", "nvarchar", "binary", "varbinary", + "char", "varchar", "nvarchar", "binary", "varbinary", "bpchar", "varbit", "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL return "String" - case "real", "numeric", "decimal", "double precision", "float", + case "real", "numeric", "decimal", "double precision", "float", "float4", "float8", "double": // MySQL return "Float" default: diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 9c562fb..9b3af50 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -39,14 +39,14 @@ services: - ./testdata/init/mysql:/docker-entrypoint-initdb.d cockroach: - image: cockroachdb/cockroach-unstable:v22.1.0-beta.4 + image: cockroachdb/cockroach-unstable:v23.1.0-rc.2 environment: - COCKROACH_USER=jet - COCKROACH_PASSWORD=jet - COCKROACH_DATABASE=jetdb ports: - "26257:26257" - command: start-single-node --insecure + command: start-single-node --accept-sql-without-tls # volumes: # - ./testdata/init/cockroach:/docker-entrypoint-initdb.d diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 8a11191..d41feee 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -17,31 +17,52 @@ import ( "github.com/go-jet/jet/v2/tests/testdata/results/common" ) +var AllTypesAllColumns = AllTypes.AllColumns. + Except(IntegerColumn("rowid")) // cockroachDB: exclude rowid column + func TestAllTypesSelect(t *testing.T) { var dest []model.AllTypes - err := AllTypes.SELECT( - AllTypesAllColumns, - ).LIMIT(2). + err := AllTypes.SELECT(AllTypesAllColumns). + LIMIT(2). Query(db, &dest) - require.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest[0], allTypesRow0) testutils.AssertDeepEqual(t, dest[1], allTypesRow1) } func TestAllTypesViewSelect(t *testing.T) { type AllTypesView model.AllTypes - var dest []AllTypesView - err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) - require.NoError(t, err) + err := SELECT(view.AllTypesView.AllColumns). + FROM(view.AllTypesView). + Query(db, &dest) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest[0], AllTypesView(allTypesRow0)) testutils.AssertDeepEqual(t, dest[1], AllTypesView(allTypesRow1)) } +func TestMaterializedViewAllTypes(t *testing.T) { + stmt := SELECT( + view.AllTypesMaterializedView.AllColumns. + Except(IntegerColumn("rowid")), // cockroachDB: exclude rowid column + ).FROM( + view.AllTypesMaterializedView, + ) + + type AllTypesMaterializedView model.AllTypes + var dest []AllTypesMaterializedView + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertDeepEqual(t, dest[0], AllTypesMaterializedView(allTypesRow0)) + testutils.AssertDeepEqual(t, dest[1], AllTypesMaterializedView(allTypesRow1)) +} + func TestAllTypesInsertModel(t *testing.T) { skipForPgxDriver(t) // pgx driver bug ERROR: date/time field value out of range: "0000-01-01 12:05:06Z" (SQLSTATE 22008) @@ -64,8 +85,6 @@ func TestAllTypesInsertModel(t *testing.T) { }) } -var AllTypesAllColumns = AllTypes.AllColumns.Except(IntegerColumn("rowid")) - func TestAllTypesInsertQuery(t *testing.T) { query := AllTypes.INSERT(AllTypesAllColumns). QUERY( @@ -230,7 +249,9 @@ SELECT "allTypesSubQuery"."all_types.small_int_ptr" AS "all_types.small_int_ptr" "allTypesSubQuery"."all_types.text_array" AS "all_types.text_array", "allTypesSubQuery"."all_types.jsonb_array" AS "all_types.jsonb_array", "allTypesSubQuery"."all_types.text_multi_dim_array_ptr" AS "all_types.text_multi_dim_array_ptr", - "allTypesSubQuery"."all_types.text_multi_dim_array" AS "all_types.text_multi_dim_array" + "allTypesSubQuery"."all_types.text_multi_dim_array" AS "all_types.text_multi_dim_array", + "allTypesSubQuery"."all_types.mood_ptr" AS "all_types.mood_ptr", + "allTypesSubQuery"."all_types.mood" AS "all_types.mood" FROM ( SELECT all_types.small_int_ptr AS "all_types.small_int_ptr", all_types.small_int AS "all_types.small_int", @@ -292,7 +313,9 @@ FROM ( all_types.text_array AS "all_types.text_array", all_types.jsonb_array AS "all_types.jsonb_array", all_types.text_multi_dim_array_ptr AS "all_types.text_multi_dim_array_ptr", - all_types.text_multi_dim_array AS "all_types.text_multi_dim_array" + all_types.text_multi_dim_array AS "all_types.text_multi_dim_array", + all_types.mood_ptr AS "all_types.mood_ptr", + all_types.mood AS "all_types.mood" FROM test_sample.all_types ) AS "allTypesSubQuery" LIMIT 2; @@ -1279,6 +1302,8 @@ RETURNING all_types.json AS "all_types.json"; }) } +var moodSad = model.Mood_Sad + var allTypesRow0 = model.AllTypes{ SmallIntPtr: testutils.Int16Ptr(14), SmallInt: 14, @@ -1343,6 +1368,8 @@ var allTypesRow0 = model.AllTypes{ JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`, TextMultiDimArrayPtr: testutils.StringPtr("{{meeting,lunch},{training,presentation}}"), TextMultiDimArray: "{{meeting,lunch},{training,presentation}}", + MoodPtr: &moodSad, + Mood: model.Mood_Happy, } var allTypesRow1 = model.AllTypes{ @@ -1409,6 +1436,8 @@ var allTypesRow1 = model.AllTypes{ JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`, TextMultiDimArrayPtr: nil, TextMultiDimArray: "{{meeting,lunch},{training,presentation}}", + MoodPtr: nil, + Mood: model.Mood_Ok, } func TestAliasedDuplicateSliceSubType(t *testing.T) { diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index a0257e7..5aaaa31 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -548,11 +548,12 @@ func UseSchema(schema string) { ` func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { - skipForCockroachDB(t) + skipForCockroachDB(t) // because of rowid column enumDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/enum/") modelDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/model/") tableDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/table/") + viewDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/view/") testutils.AssertFileNamesEqual(t, enumDir, "mood.go", "level.go") testutils.AssertFileContent(t, enumDir+"/mood.go", moodEnumContent) @@ -560,15 +561,16 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, modelDir, "all_types.go", "all_types_view.go", "employee.go", "link.go", "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go", "floats.go", "people.go", - "components.go", "vulnerabilities.go") - + "components.go", "vulnerabilities.go", "all_types_materialized_view.go") testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent) testutils.AssertFileNamesEqual(t, tableDir, "all_types.go", "employee.go", "link.go", "person.go", "person_phone.go", "weird_names_table.go", "user.go", "floats.go", "people.go", "table_use_schema.go", "components.go", "vulnerabilities.go") - testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent) + + testutils.AssertFileNamesEqual(t, viewDir, "all_types_materialized_view.go", "all_types_view.go", + "view_use_schema.go") } var moodEnumContent = ` @@ -698,6 +700,8 @@ type AllTypes struct { JsonbArray string TextMultiDimArrayPtr *string TextMultiDimArray string + MoodPtr *Mood + Mood Mood } ` @@ -782,6 +786,8 @@ type allTypesTable struct { JsonbArray postgres.ColumnString TextMultiDimArrayPtr postgres.ColumnString TextMultiDimArray postgres.ColumnString + MoodPtr postgres.ColumnString + Mood postgres.ColumnString AllColumns postgres.ColumnList MutableColumns postgres.ColumnList @@ -883,8 +889,10 @@ func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable { JsonbArrayColumn = postgres.StringColumn("jsonb_array") TextMultiDimArrayPtrColumn = postgres.StringColumn("text_multi_dim_array_ptr") TextMultiDimArrayColumn = postgres.StringColumn("text_multi_dim_array") - allColumns = postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn} - mutableColumns = postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn} + MoodPtrColumn = postgres.StringColumn("mood_ptr") + MoodColumn = postgres.StringColumn("mood") + allColumns = postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn, MoodPtrColumn, MoodColumn} + mutableColumns = postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn, MoodPtrColumn, MoodColumn} ) return allTypesTable{ @@ -952,6 +960,8 @@ func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable { JsonbArray: JsonbArrayColumn, TextMultiDimArrayPtr: TextMultiDimArrayPtrColumn, TextMultiDimArray: TextMultiDimArrayColumn, + MoodPtr: MoodPtrColumn, + Mood: MoodColumn, AllColumns: allColumns, MutableColumns: mutableColumns, diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 8bc8105..1df72dd 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -120,9 +120,15 @@ RETURNING floats.decimal_ptr AS "floats.decimal_ptr", } func TestUUIDComplex(t *testing.T) { - query := Person.INNER_JOIN(PersonPhone, PersonPhone.PersonID.EQ(Person.PersonID)). - SELECT(Person.AllColumns, PersonPhone.AllColumns). - ORDER_BY(Person.PersonID.ASC(), PersonPhone.PhoneID.ASC()) + query := SELECT( + Person.AllColumns, + PersonPhone.AllColumns, + ).FROM( + Person.INNER_JOIN(PersonPhone, PersonPhone.PersonID.EQ(Person.PersonID)), + ).ORDER_BY( + Person.PersonID.ASC(), + PersonPhone.PhoneID.ASC(), + ) t.Run("slice of structs", func(t *testing.T) { diff --git a/tests/testdata b/tests/testdata index 3398b97..b2f98e8 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 3398b9735b9d097d2ee0c282976726affc6b96f0 +Subproject commit b2f98e8297c34e86e02ada4226c125de53f64a6d From 1cbbf495db05533ac77a0231882e93d754720f39 Mon Sep 17 00:00:00 2001 From: go-jet Date: Thu, 1 Feb 2024 17:36:12 +0100 Subject: [PATCH 13/30] Update circle.ci config. --- .circleci/config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index f83283c..132c49f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -33,8 +33,8 @@ jobs: MYSQL_USER: jet MYSQL_PASSWORD: jet - - image: cockroachdb/cockroach-unstable:v22.1.0-beta.4 - command: ['start-single-node', '--insecure'] + - image: cockroachdb/cockroach-unstable:v23.1.0-rc.2 + command: ['start-single-node', '--accept-sql-without-tls'] environment: COCKROACH_USER: jet COCKROACH_PASSWORD: jet From 23cb5dcfbc0aac3bf36c78847133cd768a5c8751 Mon Sep 17 00:00:00 2001 From: go-jet Date: Thu, 1 Feb 2024 17:46:11 +0100 Subject: [PATCH 14/30] Set status code if tests init command fails. --- tests/init/init.go | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/init/init.go b/tests/init/init.go index 2543eb1..5a21ee7 100644 --- a/tests/init/init.go +++ b/tests/init/init.go @@ -78,6 +78,7 @@ func main() { if err != nil { fmt.Println(errfmt.Trace(err)) + os.Exit(1) } } From 71fb1c7cd1dead9c3fa0f9bf6814470569699c8e Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 16 Dec 2023 11:43:40 +0100 Subject: [PATCH 15/30] Add support for sqlite generated columns. --- generator/sqlite/query_set.go | 36 +++++++++++++- internal/utils/semantic/version.go | 66 +++++++++++++++++++++++++ tests/sqlite/sample_test.go | 78 ++++++++++++++++++++++++++++++ 3 files changed, 178 insertions(+), 2 deletions(-) create mode 100644 internal/utils/semantic/version.go create mode 100644 tests/sqlite/sample_test.go diff --git a/generator/sqlite/query_set.go b/generator/sqlite/query_set.go index d1d0bf7..9cf6541 100644 --- a/generator/sqlite/query_set.go +++ b/generator/sqlite/query_set.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/utils/semantic" "github.com/go-jet/jet/v2/qrm" "strings" ) @@ -42,16 +43,45 @@ func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTy return tables, nil } +func getTableInfoQuery(db *sql.DB) (string, error) { + var version string + err := db.QueryRow("select sqlite_version();").Scan(&version) + + if err != nil { + return "", fmt.Errorf("failed to get sqlite version: %w", err) + } + + sqliteVersion, err := semantic.VersionFromString(version) + + if err != nil { + return "", fmt.Errorf("can't parse sqlite version: %w", err) + } + + // generated columns were added in version 3.26.0 + if sqliteVersion.Lt(semantic.Version{Major: 3, Minor: 26, Patch: 0}) { + return `select * from pragma_table_info(?);`, nil + } + + return `select * from pragma_table_xinfo(?);`, nil +} + func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) { - query := fmt.Sprintf(`select * from pragma_table_info(?);`) + + tableInfoQuery, err := getTableInfoQuery(db) + + if err != nil { + return nil, err + } + var columnInfos []struct { Name string Type string NotNull int32 Pk int32 + Hidden int32 } - _, err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos) + _, err = qrm.Query(context.Background(), db, tableInfoQuery, []interface{}{tableName}, &columnInfos) if err != nil { return nil, fmt.Errorf("failed to query '%s' column metadata: %w", tableName, err) } @@ -60,11 +90,13 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t for _, columnInfo := range columnInfos { columnType := getColumnType(columnInfo.Type) + isGenerated := columnInfo.Hidden == 2 || columnInfo.Hidden == 3 // stored or virtual column columns = append(columns, metadata.Column{ Name: columnInfo.Name, IsPrimaryKey: columnInfo.Pk != 0, IsNullable: columnInfo.NotNull != 1, + IsGenerated: isGenerated, DataType: metadata.DataType{ Name: columnType, Kind: metadata.BaseType, diff --git a/internal/utils/semantic/version.go b/internal/utils/semantic/version.go new file mode 100644 index 0000000..b13413f --- /dev/null +++ b/internal/utils/semantic/version.go @@ -0,0 +1,66 @@ +package semantic + +import ( + "fmt" + "strconv" + "strings" +) + +// Version struct holds semantic versioning information +type Version struct { + Major int + Minor int + Patch int +} + +// VersionFromString creates new semantic Version by parsing version string +func VersionFromString(version string) (Version, error) { + parts := strings.Split(version, ".") + + var ret Version + + if len(parts) > 0 { + major, err := strconv.Atoi(parts[0]) + + if err != nil { + return ret, fmt.Errorf("major is not a number: %w", err) + } + + ret.Major = major + } + + if len(parts) > 1 { + minor, err := strconv.Atoi(parts[1]) + + if err != nil { + return ret, fmt.Errorf("minor is not a number: %w", err) + } + + ret.Minor = minor + } + + if len(parts) > 2 { + patch, err := strconv.Atoi(parts[2]) + + if err != nil { + return ret, fmt.Errorf("patch is not a number: %w", err) + } + + ret.Patch = patch + } + + return ret, nil +} + +// Lt returns true if this version is less than version parameter +func (v Version) Lt(version Version) bool { + if v.Major < version.Major { + return true + } + + if v.Minor < version.Minor { + return true + } + + return v.Patch < version.Patch +} diff --git a/tests/sqlite/sample_test.go b/tests/sqlite/sample_test.go new file mode 100644 index 0000000..1671f5d --- /dev/null +++ b/tests/sqlite/sample_test.go @@ -0,0 +1,78 @@ +package sqlite + +import ( + "database/sql" + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/stretchr/testify/require" + "testing" + + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table" +) + +func TestMutableColumnsExcludeGeneratedColumn(t *testing.T) { + + t.Run("should not have the generated column in mutableColumns", func(t *testing.T) { + require.Equal(t, 2, len(People.MutableColumns)) + require.Equal(t, People.PeopleName, People.MutableColumns[0]) + require.Equal(t, People.PeopleHeightCm, People.MutableColumns[1]) + }) + + t.Run("should query with all columns", func(t *testing.T) { + query := SELECT( + People.AllColumns, + ).FROM( + People, + ).WHERE( + People.PeopleID.EQ(Int(3)), + ) + + testutils.AssertStatementSql(t, query, ` +SELECT people.people_id AS "people.people_id", + people.people_name AS "people.people_name", + people.people_height_cm AS "people.people_height_cm", + people.people_height_in AS "people.people_height_in" +FROM people +WHERE people.people_id = ?; +`) + var result model.People + + err := query.Query(sampleDB, &result) + require.NoError(t, err) + + require.Equal(t, "Carla", result.PeopleName) + require.Equal(t, 155., *result.PeopleHeightCm) + require.InEpsilon(t, 61.02, *result.PeopleHeightIn, 1e-3) + }) + + t.Run("should insert without generated columns", func(t *testing.T) { + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + insertQuery := People.INSERT( + People.MutableColumns, + ).MODEL( + model.People{ + PeopleName: "Dario", + PeopleHeightCm: testutils.Float64Ptr(120), + }, + ).RETURNING( + People.AllColumns, + ) + + testutils.AssertDebugStatementSql(t, insertQuery, ` +INSERT INTO people (people_name, people_height_cm) +VALUES ('Dario', 120) +RETURNING people.people_id AS "people.people_id", + people.people_name AS "people.people_name", + people.people_height_cm AS "people.people_height_cm", + people.people_height_in AS "people.people_height_in"; +`) + var result model.People + err := insertQuery.Query(tx, &result) + require.NoError(t, err) + + require.Equal(t, "Dario", result.PeopleName) + require.Equal(t, 120., *result.PeopleHeightCm) + }) + }) +} From 2eaa75345c86e75793c92f904204fdff5497c4a6 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 4 Feb 2024 18:45:48 +0100 Subject: [PATCH 16/30] [sqlite] Generated columns additional tests. --- generator/sqlite/query_set.go | 2 +- go.mod | 2 +- go.sum | 4 ++-- tests/sqlite/sample_test.go | 16 ++++++++++------ tests/sqlite/update_test.go | 4 ++-- 5 files changed, 16 insertions(+), 12 deletions(-) diff --git a/generator/sqlite/query_set.go b/generator/sqlite/query_set.go index 9cf6541..745aae4 100644 --- a/generator/sqlite/query_set.go +++ b/generator/sqlite/query_set.go @@ -89,7 +89,7 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t var columns []metadata.Column for _, columnInfo := range columnInfos { - columnType := getColumnType(columnInfo.Type) + columnType := strings.TrimSuffix(getColumnType(columnInfo.Type), " GENERATED ALWAYS") isGenerated := columnInfo.Hidden == 2 || columnInfo.Hidden == 3 // stored or virtual column columns = append(columns, metadata.Column{ diff --git a/go.mod b/go.mod index d0a810a..cac32e5 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/google/uuid v1.3.0 github.com/jackc/pgconn v1.14.0 github.com/lib/pq v1.10.8 - github.com/mattn/go-sqlite3 v1.14.16 + github.com/mattn/go-sqlite3 v1.14.17 ) // test dependencies diff --git a/go.sum b/go.sum index 4b06a0a..c9d541e 100644 --- a/go.sum +++ b/go.sum @@ -102,8 +102,8 @@ github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= -github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= diff --git a/tests/sqlite/sample_test.go b/tests/sqlite/sample_test.go index 1671f5d..4775eeb 100644 --- a/tests/sqlite/sample_test.go +++ b/tests/sqlite/sample_test.go @@ -32,7 +32,8 @@ func TestMutableColumnsExcludeGeneratedColumn(t *testing.T) { SELECT people.people_id AS "people.people_id", people.people_name AS "people.people_name", people.people_height_cm AS "people.people_height_cm", - people.people_height_in AS "people.people_height_in" + people.people_height_inch AS "people.people_height_inch", + people.people_height_feet AS "people.people_height_feet" FROM people WHERE people.people_id = ?; `) @@ -43,7 +44,7 @@ WHERE people.people_id = ?; require.Equal(t, "Carla", result.PeopleName) require.Equal(t, 155., *result.PeopleHeightCm) - require.InEpsilon(t, 61.02, *result.PeopleHeightIn, 1e-3) + require.InEpsilon(t, 61.02, *result.PeopleHeightInch, 1e-3) }) t.Run("should insert without generated columns", func(t *testing.T) { @@ -53,7 +54,7 @@ WHERE people.people_id = ?; ).MODEL( model.People{ PeopleName: "Dario", - PeopleHeightCm: testutils.Float64Ptr(120), + PeopleHeightCm: testutils.Float64Ptr(190), }, ).RETURNING( People.AllColumns, @@ -61,18 +62,21 @@ WHERE people.people_id = ?; testutils.AssertDebugStatementSql(t, insertQuery, ` INSERT INTO people (people_name, people_height_cm) -VALUES ('Dario', 120) +VALUES ('Dario', 190) RETURNING people.people_id AS "people.people_id", people.people_name AS "people.people_name", people.people_height_cm AS "people.people_height_cm", - people.people_height_in AS "people.people_height_in"; + people.people_height_inch AS "people.people_height_inch", + people.people_height_feet AS "people.people_height_feet"; `) var result model.People err := insertQuery.Query(tx, &result) require.NoError(t, err) require.Equal(t, "Dario", result.PeopleName) - require.Equal(t, 120., *result.PeopleHeightCm) + require.Equal(t, 190., *result.PeopleHeightCm) + require.InEpsilon(t, float32(74.80314), *result.PeopleHeightInch, 1e-3) + require.InEpsilon(t, float32(6.233595), *result.PeopleHeightFeet, 1e-3) }) }) } diff --git a/tests/sqlite/update_test.go b/tests/sqlite/update_test.go index 9ec1acf..110c659 100644 --- a/tests/sqlite/update_test.go +++ b/tests/sqlite/update_test.go @@ -183,8 +183,8 @@ RETURNING link.id AS "link.id", BinaryOperator: 31, CastOperator: "20", LikeOperator: false, - // IsNull: true, //TODO: uncomment when sqlite driver updates to sqlite version > 3.40.1 - CaseOperator: "unknown", + IsNull: true, + CaseOperator: "unknown", }) requireLogged(t, stmt) } From 64ad9de99e9f49f3aa08f958e6cfbd74168c8d25 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 4 Feb 2024 18:56:37 +0100 Subject: [PATCH 17/30] Fix circle.ci --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 132c49f..574a4f3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -8,7 +8,7 @@ jobs: build_and_tests: docker: # specify the version - - image: circleci/golang:1.16 + - image: circleci/golang:1.19 - image: circleci/postgres:12 environment: POSTGRES_USER: jet From 44e1b7f4d91110d93fc87f0a3ad66a5e3ae52141 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 4 Feb 2024 18:59:43 +0100 Subject: [PATCH 18/30] Fix circle.ci --- .circleci/config.yml | 2 +- tests/testdata | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 574a4f3..b0e39a8 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -8,7 +8,7 @@ jobs: build_and_tests: docker: # specify the version - - image: circleci/golang:1.19 + - image: cimg/go:1.21.6 - image: circleci/postgres:12 environment: POSTGRES_USER: jet diff --git a/tests/testdata b/tests/testdata index b2f98e8..08bcfcb 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit b2f98e8297c34e86e02ada4226c125de53f64a6d +Subproject commit 08bcfcbb2e1eadfca54c4522802fc65f1fee865c From e03773a79ec5e78fdc2ddd9ceb8057dd0d58c0c6 Mon Sep 17 00:00:00 2001 From: go-jet Date: Mon, 5 Feb 2024 11:33:26 +0100 Subject: [PATCH 19/30] Update README.md --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index e6b1891..847208c 100644 --- a/README.md +++ b/README.md @@ -121,9 +121,9 @@ As command output suggest, Jet will: Generated files folder structure will look like this: ```sh -|-- .gen # -path -| `-- jetdb # database name -| `-- dvds # schema name +|-- .gen # path +| -- jetdb # database name +| -- dvds # schema name | |-- enum # sql builder package for enums | | |-- mpaa_rating.go | |-- table # sql builder package for tables @@ -131,7 +131,7 @@ Generated files folder structure will look like this: | |-- address.go | |-- category.go | ... -| |-- view # sql builder package for views +| |-- view # sql builder package for views | |-- actor_info.go | |-- film_list.go | ... @@ -530,8 +530,8 @@ Automatic scan to arbitrary structure removes a lot of headache and boilerplate ##### Speed of execution -While ORM libraries can introduce significant performance penalties due to number of round-trips to the database, -Jet will always perform better as developers can write complex query and retrieve result with a single database call. +While ORM libraries can introduce significant performance penalties due to number of round-trips to the database(N+1 query problem), +`jet` will always perform better as developers can write complex query and retrieve result with a single database call. Thus handler time lost on latency between server and database can be constant. Handler execution will be proportional only to the query complexity and the number of rows returned from database. From e51ddd5506d71e8b40403d618b091347144a77a5 Mon Sep 17 00:00:00 2001 From: go-jet Date: Wed, 7 Feb 2024 11:07:50 +0100 Subject: [PATCH 20/30] Add support for FETCH FIRST clause. --- internal/jet/clause.go | 23 ++++++++++ postgres/select_statement.go | 41 +++++++++++++++-- tests/postgres/select_test.go | 84 +++++++++++++++++++++++++++++++++++ 3 files changed, 145 insertions(+), 3 deletions(-) diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 1619124..ed899b3 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -234,6 +234,29 @@ func (o *ClauseOffset) Serialize(statementType StatementType, out *SQLBuilder, o } } +// ClauseFetch struct +type ClauseFetch struct { + Count IntegerExpression + WithTies bool +} + +// Serialize serializes ClauseFetch into sql builder output +func (o *ClauseFetch) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { + if is.Nil(o.Count) { + return + } + + out.NewLine() + out.WriteString("FETCH FIRST") + o.Count.serialize(statementType, out, options...) + + if o.WithTies { + out.WriteString("ROWS WITH TIES") + } else { + out.WriteString("ROWS ONLY") + } +} + // ClauseFor struct type ClauseFor struct { Lock RowLock diff --git a/postgres/select_statement.go b/postgres/select_statement.go index d44d6aa..0ee80bf 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -53,6 +53,7 @@ type SelectStatement interface { ORDER_BY(orderByClauses ...OrderByClause) SelectStatement LIMIT(limit int64) SelectStatement OFFSET(offset int64) SelectStatement + FETCH_FIRST(count IntegerExpression) fetchExpand FOR(lock RowLock) SelectStatement UNION(rhs SelectStatement) setStatement @@ -72,9 +73,18 @@ func SELECT(projection Projection, projections ...Projection) SelectStatement { func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { newSelect := &selectStatementImpl{} - newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, - &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy, - &newSelect.Limit, &newSelect.Offset, &newSelect.For) + newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, + &newSelect.Select, + &newSelect.From, + &newSelect.Where, + &newSelect.GroupBy, + &newSelect.Having, + &newSelect.Window, + &newSelect.OrderBy, + &newSelect.Limit, + &newSelect.Offset, + &newSelect.Fetch, + &newSelect.For) newSelect.Select.ProjectionList = projections if table != nil { @@ -101,6 +111,7 @@ type selectStatementImpl struct { OrderBy jet.ClauseOrderBy Limit jet.ClauseLimit Offset jet.ClauseOffset + Fetch jet.ClauseFetch For jet.ClauseFor } @@ -150,6 +161,14 @@ func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement { return s } +func (s *selectStatementImpl) FETCH_FIRST(count IntegerExpression) fetchExpand { + s.Fetch.Count = count + + return fetchExpand{ + selectStatement: s, + } +} + func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement { s.For.Lock = lock return s @@ -188,3 +207,19 @@ func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer { } return ret } + +type fetchExpand struct { + selectStatement *selectStatementImpl +} + +func (f fetchExpand) ROWS_ONLY() SelectStatement { + f.selectStatement.Fetch.WithTies = false + + return f.selectStatement +} + +func (f fetchExpand) ROWS_WITH_TIES() SelectStatement { + f.selectStatement.Fetch.WithTies = true + + return f.selectStatement +} diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index dcf8993..d9b30ad 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -236,6 +236,90 @@ LIMIT 12; require.NoError(t, err) } +func TestFetchFirst(t *testing.T) { + + t.Run("rows only", func(t *testing.T) { + stmt := SELECT(Actor.AllColumns). + FROM(Actor). + ORDER_BY(Actor.ActorID). + OFFSET(2). + FETCH_FIRST(Int(3)).ROWS_ONLY() + + testutils.AssertStatementSql(t, stmt, ` +SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" +FROM dvds.actor +ORDER BY actor.actor_id +OFFSET $1 +FETCH FIRST $2 ROWS ONLY; +`) + + var dest []model.Actor + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 3) + require.Equal(t, dest[0].ActorID, int32(3)) + require.Equal(t, dest[2].ActorID, int32(5)) + }) + + t.Run("rows with ties", func(t *testing.T) { + skipForCockroachDB(t) // ROWS_WITH_TIES is not supported on cockroachdb + + stmt := SELECT(Actor.AllColumns). + FROM(Actor). + ORDER_BY(Actor.LastUpdate). + FETCH_FIRST(Int(3)).ROWS_WITH_TIES() + + testutils.AssertStatementSql(t, stmt, ` +SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" +FROM dvds.actor +ORDER BY actor.last_update +FETCH FIRST $1 ROWS WITH TIES; +`) + + var dest []model.Actor + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 200) + }) + + t.Run("complex expression", func(t *testing.T) { + stmt := SELECT(Actor.AllColumns). + FROM(Actor). + ORDER_BY(Actor.LastUpdate). + FETCH_FIRST(IntExp( + SELECT(MAX(Store.StoreID)). + FROM(Store), + )).ROWS_ONLY() + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" +FROM dvds.actor +ORDER BY actor.last_update +FETCH FIRST ( + SELECT MAX(store.store_id) + FROM dvds.store +) ROWS ONLY; +`) + + var dest []model.Actor + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 2) + }) +} + func TestJoinQueryStruct(t *testing.T) { expectedSQL := ` From c19b3e7ae1e8f1b74d17df55ea25ec39ec08ee6b Mon Sep 17 00:00:00 2001 From: go-jet Date: Wed, 7 Feb 2024 11:15:36 +0100 Subject: [PATCH 21/30] Update circle.ci postgres image. --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index b0e39a8..269116b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -9,7 +9,7 @@ jobs: docker: # specify the version - image: cimg/go:1.21.6 - - image: circleci/postgres:12 + - image: cimg/postgres:16.1.0 environment: POSTGRES_USER: jet POSTGRES_PASSWORD: jet From a46f5c1bd667123594d2c2f37e2303fe152a9c29 Mon Sep 17 00:00:00 2001 From: go-jet Date: Wed, 7 Feb 2024 11:18:29 +0100 Subject: [PATCH 22/30] Update circle.ci postgres image. --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 269116b..e21f00a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -9,7 +9,7 @@ jobs: docker: # specify the version - image: cimg/go:1.21.6 - - image: cimg/postgres:16.1.0 + - image: cimg/postgres:14.10 environment: POSTGRES_USER: jet POSTGRES_PASSWORD: jet From 12faa55c529bb0c1eb04a1f2ee41665d65a87975 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 7 Feb 2024 21:22:43 +0000 Subject: [PATCH 23/30] Bump github.com/google/uuid from 1.3.0 to 1.6.0 Bumps [github.com/google/uuid](https://github.com/google/uuid) from 1.3.0 to 1.6.0. - [Release notes](https://github.com/google/uuid/releases) - [Changelog](https://github.com/google/uuid/blob/master/CHANGELOG.md) - [Commits](https://github.com/google/uuid/compare/v1.3.0...v1.6.0) --- updated-dependencies: - dependency-name: github.com/google/uuid dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 9b73508..16c19fa 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.11 require ( github.com/go-sql-driver/mysql v1.7.1 - github.com/google/uuid v1.3.0 + github.com/google/uuid v1.6.0 github.com/jackc/pgconn v1.14.1 github.com/lib/pq v1.10.8 github.com/mattn/go-sqlite3 v1.14.17 diff --git a/go.sum b/go.sum index 5bde7a7..deb45b8 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,8 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y= github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= From dab153a73918acfab60c094e0e59c19ff3773524 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 10 Feb 2024 14:03:31 +0100 Subject: [PATCH 24/30] Add support for NULLS_FIRST and NULLS_LAST sorting order. --- internal/jet/dialect.go | 10 +- internal/jet/expression.go | 18 +- internal/jet/order_by_clause.go | 70 ++++++-- internal/jet/serializer.go | 4 + mysql/dialect.go | 50 +++++- tests/mysql/select_test.go | 291 ++++++++++++++++++++++++++++++++ tests/postgres/select_test.go | 231 ++++++++++++++++++++++++- tests/sqlite/select_test.go | 227 +++++++++++++++++++++++++ 8 files changed, 882 insertions(+), 19 deletions(-) diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index 434e3b8..f3ad2b4 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -12,6 +12,7 @@ type Dialect interface { IdentifierQuoteChar() byte ArgumentPlaceholder() QueryPlaceholderFunc IsReservedWord(name string) bool + SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc } // SerializerFunc func @@ -33,6 +34,7 @@ type DialectParams struct { IdentifierQuoteChar byte ArgumentPlaceholder QueryPlaceholderFunc ReservedWords []string + SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc } // NewDialect creates new dialect with params @@ -46,6 +48,7 @@ func NewDialect(params DialectParams) Dialect { identifierQuoteChar: params.IdentifierQuoteChar, argumentPlaceholder: params.ArgumentPlaceholder, reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords), + serializeOrderBy: params.SerializeOrderBy, } } @@ -58,8 +61,7 @@ type dialectImpl struct { identifierQuoteChar byte argumentPlaceholder QueryPlaceholderFunc reservedWords map[string]bool - - supportsReturning bool + serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc } func (d *dialectImpl) Name() string { @@ -101,6 +103,10 @@ func (d *dialectImpl) IsReservedWord(name string) bool { return isReservedWord } +func (d *dialectImpl) SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc { + return d.serializeOrderBy +} + func arrayOfStringsToMapOfStrings(arr []string) map[string]bool { ret := map[string]bool{} for _, elem := range arr { diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 436b9d6..d62920c 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -64,14 +64,24 @@ func (e *ExpressionInterfaceImpl) AS(alias string) Projection { return newAlias(e.Parent, alias) } -// ASC expression will be used to sort query result in ascending order +// ASC expression will be used to sort a query result in ascending order func (e *ExpressionInterfaceImpl) ASC() OrderByClause { - return newOrderByClause(e.Parent, true) + return newOrderByAscending(e.Parent, true) } -// DESC expression will be used to sort query result in descending order +// DESC expression will be used to sort a query result in descending order func (e *ExpressionInterfaceImpl) DESC() OrderByClause { - return newOrderByClause(e.Parent, false) + return newOrderByAscending(e.Parent, false) +} + +// NULLS_FIRST specifies sort where null values appear before all non-null values +func (e *ExpressionInterfaceImpl) NULLS_FIRST() OrderByClause { + return newOrderByNullsFirst(e.Parent, true) +} + +// NULLS_LAST specifies sort where null values appear after all non-null values +func (e *ExpressionInterfaceImpl) NULLS_LAST() OrderByClause { + return newOrderByNullsFirst(e.Parent, false) } func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SQLBuilder) { diff --git a/internal/jet/order_by_clause.go b/internal/jet/order_by_clause.go index 55fb7d2..09d3ea6 100644 --- a/internal/jet/order_by_clause.go +++ b/internal/jet/order_by_clause.go @@ -2,28 +2,78 @@ package jet // OrderByClause interface type OrderByClause interface { + // NULLS_FIRST specifies sort where null values appear before all non-null values. + // For some dialects(mysql,mariadb), which do not support NULL_FIRST, NULL_FIRST is simulated + // with additional IS_NOT_NULL expression. + // For instance, + // Rental.ReturnDate.DESC().NULLS_FIRST() + // would translate to, + // rental.return_date IS NOT NULL, rental.return_date DESC + NULLS_FIRST() OrderByClause + + // NULLS_LAST specifies sort where null values appear after all non-null values. + // For some dialects(mysql,mariadb), which do not support NULLS_LAST, NULLS_LAST is simulated + // with additional IS_NULL expression. + // For instance, + // Rental.ReturnDate.ASC().NULLS_LAST() + // would translate to, + // rental.return_date IS NULL, rental.return_date ASC + NULLS_LAST() OrderByClause + serializeForOrderBy(statement StatementType, out *SQLBuilder) } type orderByClauseImpl struct { expression Expression - ascent bool + ascending *bool + nullsFirst *bool } -func (o *orderByClauseImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { - if o.expression == nil { +func (ord *orderByClauseImpl) NULLS_FIRST() OrderByClause { + nullsFirst := true + ord.nullsFirst = &nullsFirst + return ord +} +func (ord *orderByClauseImpl) NULLS_LAST() OrderByClause { + nullsFirst := false + ord.nullsFirst = &nullsFirst + return ord +} + +func (ord *orderByClauseImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { + customSerializer := out.Dialect.SerializeOrderBy() + if customSerializer != nil { + customSerializer(ord.expression, ord.ascending, ord.nullsFirst)(statement, out) + return + } + + if ord.expression == nil { panic("jet: nil expression in ORDER BY clause") } - o.expression.serializeForOrderBy(statement, out) + ord.expression.serializeForOrderBy(statement, out) - if o.ascent { - out.WriteString("ASC") - } else { - out.WriteString("DESC") + if ord.ascending != nil { + if *ord.ascending { + out.WriteString("ASC") + } else { + out.WriteString("DESC") + } + } + + if ord.nullsFirst != nil { + if *ord.nullsFirst { + out.WriteString("NULLS FIRST") + } else { + out.WriteString("NULLS LAST") + } } } -func newOrderByClause(expression Expression, ascent bool) OrderByClause { - return &orderByClauseImpl{expression: expression, ascent: ascent} +func newOrderByAscending(expression Expression, ascending bool) OrderByClause { + return &orderByClauseImpl{expression: expression, ascending: &ascending} +} + +func newOrderByNullsFirst(expression Expression, nullsFirst bool) OrderByClause { + return &orderByClauseImpl{expression: expression, nullsFirst: &nullsFirst} } diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index 93a1d3b..9d36de4 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -44,6 +44,10 @@ func Serialize(exp Serializer, statementType StatementType, out *SQLBuilder, opt exp.serialize(statementType, out, options...) } +func SerializeForOrderBy(exp Expression, statementType StatementType, out *SQLBuilder) { + exp.serializeForOrderBy(statementType, out) +} + func contains(options []SerializeOption, option SerializeOption) bool { for _, opt := range options { if opt == option { diff --git a/mysql/dialect.go b/mysql/dialect.go index 24b8755..18d2eec 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -26,7 +26,8 @@ func newDialect() jet.Dialect { ArgumentPlaceholder: func(int) string { return "?" }, - ReservedWords: reservedWords, + ReservedWords: reservedWords, + SerializeOrderBy: serializeOrderBy, } return jet.NewDialect(mySQLDialectParams) @@ -162,6 +163,53 @@ func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFun } } +func serializeOrderBy(expression Expression, ascending, nullsFirst *bool) jet.SerializerFunc { + return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + + if nullsFirst == nil { + jet.SerializeForOrderBy(expression, statement, out) + + if ascending != nil { + serializeAscending(*ascending, out) + } + return + } + + asc := true + + if ascending != nil { + asc = *ascending + } + + if asc { + if !*nullsFirst { + jet.SerializeForOrderBy(expression.IS_NULL(), statement, out) + out.WriteString(", ") + } + jet.SerializeForOrderBy(expression, statement, out) + if ascending != nil { + serializeAscending(asc, out) + } + } else { + if *nullsFirst { + jet.SerializeForOrderBy(expression.IS_NOT_NULL(), statement, out) + out.WriteString(", ") + } + + jet.SerializeForOrderBy(expression, statement, out) + serializeAscending(asc, out) + } + } +} + +func serializeAscending(ascending bool, out *jet.SQLBuilder) { + if ascending { + out.WriteString("ASC") + } else { + out.WriteString("DESC") + } +} + var reservedWords = []string{ "ACCESSIBLE", "ADD", diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index deb4529..8e03f82 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -3,6 +3,7 @@ package mysql import ( "context" "database/sql" + "github.com/go-jet/jet/v2/postgres" "strings" "testing" "time" @@ -296,6 +297,296 @@ ORDER BY inventory.film_id, inventory.store_id; `) } +func TestOrderBy(t *testing.T) { + // NULLS_FIRST and NULLS_LAST are simulated using IS_NULL, ... + + ensureNullsFirstRentalResult := func(t *testing.T, stmt postgres.Statement, asc bool) { + var dest []model.Rental + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 200) + require.Nil(t, dest[0].ReturnDate) + require.NotNil(t, dest[199].ReturnDate) + + if asc { + require.True(t, dest[198].ReturnDate.Before(*dest[199].ReturnDate)) + } else { + require.True(t, dest[199].ReturnDate.Before(*dest[198].ReturnDate)) + } + } + + ensureNullsLastRentalResult := func(t *testing.T, stmt postgres.Statement, asc bool) { + var dest []model.Rental + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 200) + require.NotNil(t, dest[0].ReturnDate) + require.Nil(t, dest[199].ReturnDate) + + if asc { + require.True(t, dest[0].ReturnDate.Before(*dest[1].ReturnDate)) + } else { + require.True(t, dest[1].ReturnDate.Before(*dest[0].ReturnDate)) + } + } + + t.Run("default", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate, + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date +LIMIT 200; +`) + ensureNullsFirstRentalResult(t, stmt, true) + }) + + t.Run("NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date +LIMIT 200; +`) + ensureNullsFirstRentalResult(t, stmt, true) + }) + + t.Run("NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.NULLS_LAST(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date IS NULL, rental.return_date +LIMIT 200 +OFFSET 15800; +`) + ensureNullsLastRentalResult(t, stmt, true) + }) + + t.Run("ASC", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date ASC +LIMIT 200; +`) + ensureNullsFirstRentalResult(t, stmt, true) + }) + + t.Run("ASC NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC().NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date ASC +LIMIT 200; +`) + + ensureNullsFirstRentalResult(t, stmt, true) + }) + + t.Run("ASC NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC().NULLS_LAST(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date IS NULL, rental.return_date ASC +LIMIT 200 +OFFSET 15800; +`) + + ensureNullsLastRentalResult(t, stmt, true) + }) + + t.Run("DESC", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date DESC +LIMIT 200 +OFFSET 15800; +`) + + ensureNullsLastRentalResult(t, stmt, false) + }) + + t.Run("DESC NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC().NULLS_LAST(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date DESC +LIMIT 200 +OFFSET 15800; +`) + + ensureNullsLastRentalResult(t, stmt, false) + }) + + t.Run("DESC NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC().NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date IS NOT NULL, rental.return_date DESC +LIMIT 200; +`) + + ensureNullsFirstRentalResult(t, stmt, false) + }) + + t.Run("complex", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.RentalID.DESC(), + Rental.ReturnDate.DESC().NULLS_FIRST(), + Rental.LastUpdate.ASC(), + Rental.InventoryID.ADD(Rental.RentalID).ASC().NULLS_LAST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.rental_id DESC, rental.return_date IS NOT NULL, rental.return_date DESC, rental.last_update ASC, (rental.inventory_id + rental.rental_id) IS NULL, rental.inventory_id + rental.rental_id ASC +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + + }) + +} + func TestSubQuery(t *testing.T) { rRatingFilms := SELECT( diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index d9b30ad..8c7e4bc 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1151,7 +1151,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { firstCustomerAsc := customersAsc[0] lastCustomerAsc := customersAsc[len(customersAsc)-1] - customersDesc := []model.Customer{} + var customersDesc []model.Customer err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). ORDER_BY(Customer.FirstName.DESC()). Query(db, &customersDesc) @@ -1164,7 +1164,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { testutils.AssertDeepEqual(t, firstCustomerAsc, lastCustomerDesc) testutils.AssertDeepEqual(t, lastCustomerAsc, firstCustomerDesc) - customersAscDesc := []model.Customer{} + var customersAscDesc []model.Customer err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). ORDER_BY(Customer.FirstName.ASC(), Customer.LastName.DESC()). Query(db, &customersAscDesc) @@ -1187,6 +1187,233 @@ func TestSelectOrderByAscDesc(t *testing.T) { testutils.AssertDeepEqual(t, customerAscDesc327, customersAscDesc[327]) } +func TestOrderBy(t *testing.T) { + + t.Run("default", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate, + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date NULLS FIRST +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.NULLS_LAST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date NULLS LAST +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("ASC", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date ASC +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("ASC NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC().NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date ASC NULLS FIRST +LIMIT 200; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("ASC NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC().NULLS_LAST(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date ASC NULLS LAST +LIMIT 200 +OFFSET 15800; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("DESC", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date DESC +LIMIT 200 +OFFSET 15800; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("DESC NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC().NULLS_LAST(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date DESC NULLS LAST +LIMIT 200 +OFFSET 15800; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("DESC NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC().NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date DESC NULLS FIRST +LIMIT 200; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) +} + func TestSelectFullJoin(t *testing.T) { expectedSQL := ` SELECT customer.customer_id AS "customer.customer_id", diff --git a/tests/sqlite/select_test.go b/tests/sqlite/select_test.go index 838cb69..63d43a9 100644 --- a/tests/sqlite/select_test.go +++ b/tests/sqlite/select_test.go @@ -145,6 +145,233 @@ ORDER BY payment.customer_id, SUM(payment.amount) ASC; requireLogged(t, query) } +func TestOrderBy(t *testing.T) { + + t.Run("default", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate, + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date NULLS FIRST +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.NULLS_LAST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date NULLS LAST +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("ASC", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date ASC +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("ASC NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC().NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date ASC NULLS FIRST +LIMIT 200; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("ASC NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC().NULLS_LAST(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date ASC NULLS LAST +LIMIT 200 +OFFSET 15800; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("DESC", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date DESC +LIMIT 200 +OFFSET 15800; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("DESC NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC().NULLS_LAST(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date DESC NULLS LAST +LIMIT 200 +OFFSET 15800; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("DESC NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC().NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date DESC NULLS FIRST +LIMIT 200; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) +} + func TestAggregateFunctionDistinct(t *testing.T) { stmt := SELECT( Payment.CustomerID, From 255f4a8eaf0088d493bce5a912aad55f3c2f5cdf Mon Sep 17 00:00:00 2001 From: go-jet Date: Tue, 13 Feb 2024 14:01:13 +0100 Subject: [PATCH 25/30] Add support for expression in OFFSET clause. --- README.md | 2 +- internal/jet/clause.go | 12 +++-- mysql/select_statement.go | 3 +- mysql/set_statement.go | 3 +- postgres/select_statement.go | 8 ++- postgres/set_statement.go | 8 ++- sqlite/select_statement.go | 3 +- sqlite/set_statement.go | 3 +- tests/postgres/select_test.go | 94 ++++++++++++++++++++++++++++++++--- 9 files changed, 113 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 847208c..15fd887 100644 --- a/README.md +++ b/README.md @@ -579,5 +579,5 @@ To run the tests, additional dependencies are required: ## License -Copyright 2019-2023 Goran Bjelanovic +Copyright 2019-2024 Goran Bjelanovic Licensed under the Apache License, Version 2.0. diff --git a/internal/jet/clause.go b/internal/jet/clause.go index ed899b3..533d223 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -222,16 +222,18 @@ func (l *ClauseLimit) Serialize(statementType StatementType, out *SQLBuilder, op // ClauseOffset struct type ClauseOffset struct { - Count int64 + Count IntegerExpression } // Serialize serializes clause into SQLBuilder func (o *ClauseOffset) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { - if o.Count >= 0 { - out.NewLine() - out.WriteString("OFFSET") - out.insertParametrizedArgument(o.Count) + if is.Nil(o.Count) { + return } + + out.NewLine() + out.WriteString("OFFSET") + o.Count.serialize(statementType, out, options...) } // ClauseFetch struct diff --git a/mysql/select_statement.go b/mysql/select_statement.go index 45c7782..6c5f345 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -86,7 +86,6 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta newSelect.From.Tables = []jet.Serializer{table} } newSelect.Limit.Count = -1 - newSelect.Offset.Count = -1 newSelect.ShareLock.Name = "LOCK IN SHARE MODE" newSelect.ShareLock.InNewLine = true @@ -158,7 +157,7 @@ func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement { } func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement { - s.Offset.Count = offset + s.Offset.Count = Int(offset) return s } diff --git a/mysql/set_statement.go b/mysql/set_statement.go index ec9f8fa..2df75a0 100644 --- a/mysql/set_statement.go +++ b/mysql/set_statement.go @@ -63,7 +63,6 @@ func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStat newSetStatement.setOperator.All = all newSetStatement.setOperator.Selects = selects newSetStatement.setOperator.Limit.Count = -1 - newSetStatement.setOperator.Offset.Count = -1 newSetStatement.setOperatorsImpl.parent = newSetStatement @@ -81,7 +80,7 @@ func (s *setStatementImpl) LIMIT(limit int64) setStatement { } func (s *setStatementImpl) OFFSET(offset int64) setStatement { - s.setOperator.Offset.Count = offset + s.setOperator.Offset.Count = Int(offset) return s } diff --git a/postgres/select_statement.go b/postgres/select_statement.go index 0ee80bf..70a9a50 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -53,6 +53,8 @@ type SelectStatement interface { ORDER_BY(orderByClauses ...OrderByClause) SelectStatement LIMIT(limit int64) SelectStatement OFFSET(offset int64) SelectStatement + // OFFSET_e can be used when an integer expression is needed as offset, otherwise OFFSET can be used + OFFSET_e(offset IntegerExpression) SelectStatement FETCH_FIRST(count IntegerExpression) fetchExpand FOR(lock RowLock) SelectStatement @@ -91,7 +93,6 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta newSelect.From.Tables = []jet.Serializer{table} } newSelect.Limit.Count = -1 - newSelect.Offset.Count = -1 newSelect.setOperatorsImpl.parent = newSelect @@ -157,6 +158,11 @@ func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement { } func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement { + s.Offset.Count = Int(offset) + return s +} + +func (s *selectStatementImpl) OFFSET_e(offset IntegerExpression) SelectStatement { s.Offset.Count = offset return s } diff --git a/postgres/set_statement.go b/postgres/set_statement.go index 236f8de..834560d 100644 --- a/postgres/set_statement.go +++ b/postgres/set_statement.go @@ -45,6 +45,8 @@ type setStatement interface { LIMIT(limit int64) setStatement OFFSET(offset int64) setStatement + // OFFSET_e can be used when an integer expression is needed as offset, otherwise OFFSET can be used + OFFSET_e(offset IntegerExpression) setStatement AsTable(alias string) SelectTable } @@ -107,7 +109,6 @@ func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStat newSetStatement.setOperator.All = all newSetStatement.setOperator.Selects = selects newSetStatement.setOperator.Limit.Count = -1 - newSetStatement.setOperator.Offset.Count = -1 newSetStatement.setOperatorsImpl.parent = newSetStatement @@ -125,6 +126,11 @@ func (s *setStatementImpl) LIMIT(limit int64) setStatement { } func (s *setStatementImpl) OFFSET(offset int64) setStatement { + s.setOperator.Offset.Count = Int(offset) + return s +} + +func (s *setStatementImpl) OFFSET_e(offset IntegerExpression) setStatement { s.setOperator.Offset.Count = offset return s } diff --git a/sqlite/select_statement.go b/sqlite/select_statement.go index 5e92d52..e74cda6 100644 --- a/sqlite/select_statement.go +++ b/sqlite/select_statement.go @@ -74,7 +74,6 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta newSelect.From.Tables = []jet.Serializer{table} } newSelect.Limit.Count = -1 - newSelect.Offset.Count = -1 newSelect.ShareLock.Name = "LOCK IN SHARE MODE" newSelect.ShareLock.InNewLine = true @@ -141,7 +140,7 @@ func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement { } func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement { - s.Offset.Count = offset + s.Offset.Count = Int(offset) return s } diff --git a/sqlite/set_statement.go b/sqlite/set_statement.go index 18bcca5..0a004bf 100644 --- a/sqlite/set_statement.go +++ b/sqlite/set_statement.go @@ -63,7 +63,6 @@ func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStat newSetStatement.setOperator.All = all newSetStatement.setOperator.Selects = selects newSetStatement.setOperator.Limit.Count = -1 - newSetStatement.setOperator.Offset.Count = -1 newSetStatement.setOperator.SkipSelectWrap = true newSetStatement.setOperatorsImpl.parent = newSetStatement @@ -82,7 +81,7 @@ func (s *setStatementImpl) LIMIT(limit int64) setStatement { } func (s *setStatementImpl) OFFSET(offset int64) setStatement { - s.setOperator.Offset.Count = offset + s.setOperator.Offset.Count = Int(offset) return s } diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 8c7e4bc..048db14 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -320,6 +320,38 @@ FETCH FIRST ( }) } +func TestOffsetExpression(t *testing.T) { + + stmt := SELECT(Actor.AllColumns). + FROM(Actor). + ORDER_BY(Actor.ActorID). + OFFSET_e(IntExp( + SELECT(MAX(Store.StoreID)). + FROM(Store), + )).LIMIT(10) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" +FROM dvds.actor +ORDER BY actor.actor_id +LIMIT 10 +OFFSET ( + SELECT MAX(store.store_id) + FROM dvds.store +); +`) + + var dest []model.Actor + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 10) + require.Equal(t, dest[0].ActorID, int32(3)) +} + func TestJoinQueryStruct(t *testing.T) { expectedSQL := ` @@ -2365,16 +2397,14 @@ OFFSET 20; Payment. SELECT(Payment.PaymentID, Payment.Amount). WHERE(Payment.Amount.GT_EQ(Float(200))), - ). - ORDER_BY(IntegerColumn("payment.payment_id").ASC(), Payment.Amount.DESC()). - LIMIT(10). - OFFSET(20) - - //fmt.Println(query.DebugSql()) + ).ORDER_BY( + IntegerColumn("payment.payment_id").ASC(), + Payment.Amount.DESC(), + ).LIMIT(10).OFFSET(20) testutils.AssertDebugStatementSql(t, query, expectedQuery, float64(100), float64(200), int64(10), int64(20)) - dest := []model.Payment{} + var dest []model.Payment err := query.Query(db, &dest) @@ -2394,6 +2424,56 @@ OFFSET 20; }) } +func TestUnionOffsetWithExpression(t *testing.T) { + stmt := UNION( + SELECT(Rental.AllColumns). + FROM(Rental). + WHERE(Rental.ReturnDate.IS_NULL()), + + SELECT(Rental.AllColumns). + FROM(Rental). + WHERE(Rental.LastUpdate.GT(LOCALTIMESTAMP())), + ).OFFSET_e(IntExp( + SELECT(Int32(3)), + )).LIMIT(10) + + testutils.AssertStatementSql(t, stmt, ` +( + SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" + FROM dvds.rental + WHERE rental.return_date IS NULL +) +UNION +( + SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" + FROM dvds.rental + WHERE rental.last_update > LOCALTIMESTAMP +) +LIMIT $1 +OFFSET ( + SELECT $2::integer +); +`) + + var dest []model.Rental + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 10) +} + func TestAllSetOperators(t *testing.T) { var select1 = Payment.SELECT(Payment.AllColumns).WHERE(Payment.PaymentID.GT_EQ(Int(17600)).AND(Payment.PaymentID.LT(Int(17610)))) var select2 = Payment.SELECT(Payment.AllColumns).WHERE(Payment.PaymentID.GT_EQ(Int(17620)).AND(Payment.PaymentID.LT(Int(17630)))) From 6b098b8e41fc8adc22418202019b1e2cb2441b58 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 17 Feb 2024 12:46:00 +0100 Subject: [PATCH 26/30] Replace invalid character from the go identifiers with description string. --- generator/template/sql_builder_template.go | 6 +- internal/utils/dbidentifier/dbidentifier.go | 93 ++++++++++++++++++- .../utils/dbidentifier/dbidentifier_test.go | 43 +++++++++ 3 files changed, 135 insertions(+), 7 deletions(-) diff --git a/generator/template/sql_builder_template.go b/generator/template/sql_builder_template.go index 869b8e5..cf7e121 100644 --- a/generator/template/sql_builder_template.go +++ b/generator/template/sql_builder_template.go @@ -68,11 +68,13 @@ type ViewSQLBuilder = TableSQLBuilder // DefaultTableSQLBuilder returns default implementation for TableSQLBuilder func DefaultTableSQLBuilder(tableMetaData metadata.Table) TableSQLBuilder { + tableNameGoIdentifier := dbidentifier.ToGoIdentifier(tableMetaData.Name) + return TableSQLBuilder{ Path: "/table", FileName: dbidentifier.ToGoFileName(tableMetaData.Name), - InstanceName: dbidentifier.ToGoIdentifier(tableMetaData.Name), - TypeName: dbidentifier.ToGoIdentifier(tableMetaData.Name) + "Table", + InstanceName: tableNameGoIdentifier, + TypeName: tableNameGoIdentifier + "Table", DefaultAlias: "", Column: DefaultTableSQLBuilderColumn, } diff --git a/internal/utils/dbidentifier/dbidentifier.go b/internal/utils/dbidentifier/dbidentifier.go index c4278f6..7aee4a9 100644 --- a/internal/utils/dbidentifier/dbidentifier.go +++ b/internal/utils/dbidentifier/dbidentifier.go @@ -3,6 +3,7 @@ package dbidentifier import ( "github.com/go-jet/jet/v2/internal/3rdparty/snaker" "strings" + "unicode" ) // ToGoIdentifier converts database identifier to Go identifier. @@ -15,10 +16,92 @@ func ToGoFileName(databaseIdentifier string) string { return strings.ToLower(replaceInvalidChars(databaseIdentifier)) } -func replaceInvalidChars(str string) string { - str = strings.Replace(str, " ", "_", -1) - str = strings.Replace(str, "-", "_", -1) - str = strings.Replace(str, ".", "_", -1) +func replaceInvalidChars(identifier string) string { + increase, needs := needsCharReplacement(identifier) - return str + if !needs { + return identifier + } + + var b strings.Builder + + b.Grow(len(identifier) + increase) + + for _, c := range identifier { + switch { + case unicode.IsSpace(c): + b.WriteByte('_') + case unicode.IsControl(c): + continue + default: + replacement, ok := asciiCharacterReplacement[c] + + if ok { + b.WriteByte('_') + b.WriteString(replacement) + b.WriteByte('_') + } else { + b.WriteRune(c) + } + } + + } + + return b.String() +} + +func needsCharReplacement(identifier string) (increase int, needs bool) { + for _, c := range identifier { + switch { + case unicode.IsSpace(c): + needs = true + case unicode.IsControl(c): + increase += -1 + needs = true + continue + default: + replacement, ok := asciiCharacterReplacement[c] + + if ok { + increase += len(replacement) + 1 + needs = true + } + } + } + + return increase, needs +} + +var asciiCharacterReplacement = map[rune]string{ + '!': "exclamation", + '"': "quotation", + '#': "number", + '$': "dollar", + '%': "percent", + '&': "ampersand", + '\'': "apostrophe", + '(': "opening_parentheses", + ')': "closing_parentheses", + '*': "asterisk", + '+': "plus", + ',': "comma", + '-': "_", + '.': "_", + '/': "slash", + ':': "colon", + ';': "semicolon", + '<': "less", + '=': "equal", + '>': "greater", + '?': "question", + '@': "at", + '[': "opening_bracket", + '\\': "backslash", + ']': "closing_bracket", + '^': "caret", + '`': "accent", + '{': "opening_braces", + '|': "vertical_bar", + '}': "closing_braces", + '~': "tilde", } diff --git a/internal/utils/dbidentifier/dbidentifier_test.go b/internal/utils/dbidentifier/dbidentifier_test.go index 339fb5a..668bf56 100644 --- a/internal/utils/dbidentifier/dbidentifier_test.go +++ b/internal/utils/dbidentifier/dbidentifier_test.go @@ -8,6 +8,7 @@ import ( func TestToGoIdentifier(t *testing.T) { require.Equal(t, ToGoIdentifier(""), "") require.Equal(t, ToGoIdentifier("uuid"), "UUID") + require.Equal(t, ToGoIdentifier("uuid_ptr"), "UUIDPtr") require.Equal(t, ToGoIdentifier("col1"), "Col1") require.Equal(t, ToGoIdentifier("PG-13"), "Pg13") require.Equal(t, ToGoIdentifier("13_pg"), "13Pg") @@ -18,8 +19,50 @@ func TestToGoIdentifier(t *testing.T) { require.Equal(t, ToGoIdentifier("myTaBlE"), "MyTaBlE") require.Equal(t, ToGoIdentifier("my_table"), "MyTable") + require.Equal(t, ToGoIdentifier("my_____table"), "MyTable") require.Equal(t, ToGoIdentifier("MY_TABLE"), "MyTable") require.Equal(t, ToGoIdentifier("My_Table"), "MyTable") require.Equal(t, ToGoIdentifier("My Table"), "MyTable") require.Equal(t, ToGoIdentifier("My-Table"), "MyTable") + + require.Equal(t, ToGoIdentifier("EN\bUM"), "Enum") // control character + require.Equal(t, ToGoIdentifier("EN\tUM"), "EnUm") // space character + require.Equal(t, ToGoIdentifier("S3:INIT"), "S3ColonInit") // replacement chars + require.Equal(t, ToGoIdentifier("Entity-"), "Entity") + require.Equal(t, ToGoIdentifier("Entity+"), "EntityPlus") + require.Equal(t, ToGoIdentifier("="), "Equal") + require.Equal(t, ToGoIdentifier("<="), "LessEqual") + require.Equal(t, ToGoIdentifier(">="), "GreaterEqual") + require.Equal(t, ToGoIdentifier("some#$%name"), "SomeNumberDollarPercentName") + require.Equal(t, ToGoIdentifier(`An!"them`), "AnExclamationQuotationThem") + require.Equal(t, ToGoIdentifier(`An(Um)`), + "AnOpeningParenthesesUmClosingParentheses") +} + +func TestNeedsCharReplacement(t *testing.T) { + increase, needs := needsCharReplacement("some_name") + require.False(t, needs) + require.Zero(t, increase) + + increase, needs = needsCharReplacement("some name") + require.True(t, needs) + require.Zero(t, increase) + + increase, needs = needsCharReplacement("some\bname") + require.True(t, needs) + require.Equal(t, increase, -1) + + increase, needs = needsCharReplacement("some#$%name") + require.True(t, needs) + require.Equal(t, increase, 22) +} + +func TestToGoFileName(t *testing.T) { + require.Equal(t, ToGoFileName("FileName"), "filename") + require.Equal(t, ToGoFileName("File_Name"), "file_name") + require.Equal(t, ToGoFileName("File___Name__"), "file___name__") + require.Equal(t, ToGoFileName("File___Name__"), "file___name__") + require.Equal(t, ToGoFileName("File\bName"), "filename") + require.Equal(t, ToGoFileName("File\tName"), "file_name") + require.Equal(t, ToGoFileName("File^^Name"), "file_caret__caret_name") } From bffec36917df447f56d11eaf8b8eef132161718b Mon Sep 17 00:00:00 2001 From: Jupp Mueller Date: Sat, 3 Feb 2024 14:35:28 -0800 Subject: [PATCH 27/30] Improve performance of mysql generator This change improves performance for generating mysql models for databases with large number of tables. In my local testing for a database with about 1000 tables and 140k columns, generation time was reduced from about 1h to less than one second. --- generator/metadata/column_meta_data.go | 2 +- generator/metadata/table_meta_data.go | 2 +- generator/mysql/query_set.go | 53 +++++++------------------- 3 files changed, 16 insertions(+), 41 deletions(-) diff --git a/generator/metadata/column_meta_data.go b/generator/metadata/column_meta_data.go index 679f6c1..5c888a3 100644 --- a/generator/metadata/column_meta_data.go +++ b/generator/metadata/column_meta_data.go @@ -6,7 +6,7 @@ import ( // Column struct type Column struct { - Name string + Name string `sql:"primary_key"` IsPrimaryKey bool IsNullable bool IsGenerated bool diff --git a/generator/metadata/table_meta_data.go b/generator/metadata/table_meta_data.go index 2e01c61..df9514e 100644 --- a/generator/metadata/table_meta_data.go +++ b/generator/metadata/table_meta_data.go @@ -2,7 +2,7 @@ package metadata // Table metadata struct type Table struct { - Name string + Name string `sql:"primary_key"` Columns []Column } diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index 80406bd..bd4d593 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -8,7 +8,6 @@ import ( "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/qrm" - "golang.org/x/sync/errgroup" ) // mySqlQuerySet is dialect query set for MySQL @@ -16,34 +15,8 @@ type mySqlQuerySet struct{} func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) ([]metadata.Table, error) { query := ` -SELECT table_name as "table.name" -FROM INFORMATION_SCHEMA.tables -WHERE table_schema = ? and table_type = ? -ORDER BY table_name; -` - var tables []metadata.Table - - _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) - if err != nil { - return nil, fmt.Errorf("failed to query %s metadata result: %w", tableType, err) - } - - wg := errgroup.Group{} - for i := 0; i < len(tables); i++ { - i := i - wg.Go(func() (err1 error) { - tables[i].Columns, err1 = m.GetTableColumnsMetaData(db, schemaName, tables[i].Name) - return err1 - }) - } - - err = wg.Wait() - return tables, err -} - -func (m mySqlQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) { - query := ` SELECT + t.table_name as "table.name", col.COLUMN_NAME AS "column.Name", col.IS_NULLABLE = "YES" AS "column.IsNullable", col.COLUMN_COMMENT AS "column.Comment", @@ -56,30 +29,32 @@ SELECT ) AS "dataType.Name", IF (col.DATA_TYPE = 'enum', 'enum', 'base') AS "dataType.Kind", col.COLUMN_TYPE LIKE '%unsigned%' AS "dataType.IsUnsigned" -FROM +FROM INFORMATION_SCHEMA.tables AS t +INNER JOIN information_schema.columns AS col + ON t.table_schema = col.table_schema AND t.table_name = col.table_name LEFT JOIN ( - SELECT k.column_name, 1 AS IsPrimaryKey + SELECT k.column_name, 1 AS IsPrimaryKey, k.table_name FROM information_schema.table_constraints t JOIN information_schema.key_column_usage k USING(constraint_name, table_schema, table_name) WHERE t.table_schema = ? - AND t.table_name = ? AND t.constraint_type = 'PRIMARY KEY' -) AS pk ON col.COLUMN_NAME = pk.column_name -WHERE - col.table_schema = ? - AND col.table_name = ? +) AS pk ON col.COLUMN_NAME = pk.column_name AND col.table_name = pk.table_name +WHERE t.table_schema = ? + AND t.table_type = ? ORDER BY + t.table_name, col.ordinal_position; ` - var columns []metadata.Column - _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns) + var tables []metadata.Table + + _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, schemaName, tableType}, &tables) if err != nil { - return nil, fmt.Errorf("failed to query %s column meta data: %w", tableName, err) + return nil, fmt.Errorf("failed to query column meta data: %w", err) } - return columns, nil + return tables, nil } func (m mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.Enum, error) { From 09fe45b09cc9dff7a0d62feeb055e3e66eb6d07b Mon Sep 17 00:00:00 2001 From: Jay Date: Tue, 20 Feb 2024 23:56:11 +0530 Subject: [PATCH 28/30] mysql: added a helper to compare UUID strings with uuid_to_bin --- mysql/literal.go | 15 ++++++++++++++- mysql/literal_test.go | 10 ++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/mysql/literal.go b/mysql/literal.go index 1c69c31..3c7c07a 100644 --- a/mysql/literal.go +++ b/mysql/literal.go @@ -1,8 +1,10 @@ package mysql import ( - "github.com/go-jet/jet/v2/internal/jet" + "fmt" "time" + + "github.com/go-jet/jet/v2/internal/jet" ) // Keywords @@ -55,6 +57,17 @@ var String = jet.String // value can be any uuid type with a String method var UUID = jet.UUID +// UUIDToBin takes ay object with a String method and calls StringUUIDToBin. +func UUIDToBin(str fmt.Stringer) StringExpression { + return StringUUIDToBin(str.String()) +} + +// StringUUIDToBin is a helper function that calls "uuid_to_bin" function on the passed value. +func StringUUIDToBin(str string) StringExpression { + fn := Func("uuid_to_bin", String(str)) + return StringExp(fn) +} + // Date creates new date literal func Date(year int, month time.Month, day int) DateExpression { return CAST(jet.Date(year, month, day)).AS_DATE() diff --git a/mysql/literal_test.go b/mysql/literal_test.go index fb96641..d5831b6 100644 --- a/mysql/literal_test.go +++ b/mysql/literal_test.go @@ -4,6 +4,8 @@ import ( "math" "testing" "time" + + "github.com/google/uuid" ) func TestBool(t *testing.T) { @@ -81,3 +83,11 @@ func TestTimestamp(t *testing.T) { assertSerialize(t, Timestamp(2010, time.March, 30, 10, 15, 30), `TIMESTAMP(?)`, "2010-03-30 10:15:30") assertSerialize(t, TimestampT(time.Now()), `TIMESTAMP(?)`) } + +func TestUUIDToBin(t *testing.T) { + assertSerialize(t, UUIDToBin(uuid.Nil), `uuid_to_bin(?)`, uuid.Nil.String()) +} + +func TestStringUUIDToBin(t *testing.T) { + assertSerialize(t, StringUUIDToBin(uuid.Nil.String()), `uuid_to_bin(?)`, uuid.Nil.String()) +} From 33ec120437b23edbc05a798491ca523c92beaad1 Mon Sep 17 00:00:00 2001 From: Jay Date: Thu, 22 Feb 2024 17:23:14 +0530 Subject: [PATCH 29/30] replaced the UUIDToBin functions with a singular UUID_TO_BIN --- mysql/functions.go | 6 ++++++ mysql/functions_test.go | 11 +++++++++++ mysql/literal.go | 12 ------------ mysql/literal_test.go | 10 ---------- 4 files changed, 17 insertions(+), 22 deletions(-) create mode 100644 mysql/functions_test.go diff --git a/mysql/functions.go b/mysql/functions.go index 4eba942..ca31d18 100644 --- a/mysql/functions.go +++ b/mysql/functions.go @@ -222,6 +222,12 @@ var SUBSTR = jet.SUBSTR // REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise. var REGEXP_LIKE = jet.REGEXP_LIKE +// UUID_TO_BIN is a helper function that calls "uuid_to_bin" function on the passed value. +func UUID_TO_BIN(str StringExpression) StringExpression { + fn := Func("uuid_to_bin", str) + return StringExp(fn) +} + //----------------- Date/Time Functions and Operators ------------// // EXTRACT function retrieves subfields such as year or hour from date/time values diff --git a/mysql/functions_test.go b/mysql/functions_test.go new file mode 100644 index 0000000..197b515 --- /dev/null +++ b/mysql/functions_test.go @@ -0,0 +1,11 @@ +package mysql + +import ( + "testing" + + "github.com/google/uuid" +) + +func TestUUIDToBin(t *testing.T) { + assertSerialize(t, UUID_TO_BIN(String(uuid.Nil.String())), `uuid_to_bin(?)`, uuid.Nil.String()) +} diff --git a/mysql/literal.go b/mysql/literal.go index 3c7c07a..ca720c8 100644 --- a/mysql/literal.go +++ b/mysql/literal.go @@ -1,7 +1,6 @@ package mysql import ( - "fmt" "time" "github.com/go-jet/jet/v2/internal/jet" @@ -57,17 +56,6 @@ var String = jet.String // value can be any uuid type with a String method var UUID = jet.UUID -// UUIDToBin takes ay object with a String method and calls StringUUIDToBin. -func UUIDToBin(str fmt.Stringer) StringExpression { - return StringUUIDToBin(str.String()) -} - -// StringUUIDToBin is a helper function that calls "uuid_to_bin" function on the passed value. -func StringUUIDToBin(str string) StringExpression { - fn := Func("uuid_to_bin", String(str)) - return StringExp(fn) -} - // Date creates new date literal func Date(year int, month time.Month, day int) DateExpression { return CAST(jet.Date(year, month, day)).AS_DATE() diff --git a/mysql/literal_test.go b/mysql/literal_test.go index d5831b6..fb96641 100644 --- a/mysql/literal_test.go +++ b/mysql/literal_test.go @@ -4,8 +4,6 @@ import ( "math" "testing" "time" - - "github.com/google/uuid" ) func TestBool(t *testing.T) { @@ -83,11 +81,3 @@ func TestTimestamp(t *testing.T) { assertSerialize(t, Timestamp(2010, time.March, 30, 10, 15, 30), `TIMESTAMP(?)`, "2010-03-30 10:15:30") assertSerialize(t, TimestampT(time.Now()), `TIMESTAMP(?)`) } - -func TestUUIDToBin(t *testing.T) { - assertSerialize(t, UUIDToBin(uuid.Nil), `uuid_to_bin(?)`, uuid.Nil.String()) -} - -func TestStringUUIDToBin(t *testing.T) { - assertSerialize(t, StringUUIDToBin(uuid.Nil.String()), `uuid_to_bin(?)`, uuid.Nil.String()) -} From 893567daca6adaf7b6f512bc62f547375875306a Mon Sep 17 00:00:00 2001 From: Sarkan Date: Wed, 31 Jan 2024 15:30:09 +0100 Subject: [PATCH 30/30] Range types implemented plus and minus infinity keyword tests implemented range table tests added skip cockroach db added select test case added for range fields generator modified to generate correct types generator tests modified to include sample range table model and template generators modified to support range fields returning the T in UPPER and LOWER functions raw ranges implemented bounds set as optional dep modified dependencies modified and issue fixed range expression with templates implemented rangeExpression change to make it more type safe third parameter of constructor function fixed literals removed, functions added tests modified constructor functions used for creating range expressions NumRange converted to a constructor function from literal range_lower and range_upper renamed to lower_bound and upper_bound range literal removed PlusInfinity and MinusInfinity implemented int4 and int8 castings added issues fixed and tests checked number, ts, tstz literal and cast implemented date range literal expression modified and raw function used parent type converted from RangeExpression to Expression range type implemented for postgres range column type, function and literal expression implemented CONTAINS and OVERLAP operations added for range expressions range expressions implemented --- generator/template/model_template.go | 13 + generator/template/sql_builder_template.go | 12 + go.mod | 25 +- go.sum | 2 - internal/jet/column_types.go | 41 ++ internal/jet/func_expression.go | 59 +++ internal/jet/func_expression_test.go | 11 + internal/jet/literal_expression.go | 9 + internal/jet/operators.go | 10 + internal/jet/range_expression.go | 95 ++++ internal/jet/range_expression_test.go | 63 +++ internal/jet/testutils.go | 6 +- postgres/columns.go | 36 ++ postgres/expressions.go | 49 +- postgres/functions.go | 27 ++ postgres/keywords.go | 4 + postgres/utils_test.go | 5 +- tests/postgres/generator_test.go | 98 +++- tests/postgres/range_test.go | 514 +++++++++++++++++++++ tests/testdata | 2 +- 20 files changed, 1062 insertions(+), 19 deletions(-) create mode 100644 internal/jet/range_expression.go create mode 100644 internal/jet/range_expression_test.go create mode 100644 tests/postgres/range_test.go diff --git a/generator/template/model_template.go b/generator/template/model_template.go index 47f0339..f89ebd1 100644 --- a/generator/template/model_template.go +++ b/generator/template/model_template.go @@ -5,6 +5,7 @@ import ( "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/internal/utils/dbidentifier" "github.com/google/uuid" + "github.com/jackc/pgtype" "path" "reflect" "strings" @@ -320,6 +321,18 @@ func toGoType(column metadata.Column) interface{} { return float64(0.0) case "uuid": return uuid.UUID{} + case "daterange": + return pgtype.Daterange{} + case "tsrange": + return pgtype.Tsrange{} + case "tstzrange": + return pgtype.Tstzrange{} + case "int4range": + return pgtype.Int4range{} + case "int8range": + return pgtype.Int8range{} + case "numrange": + return pgtype.Numrange{} default: fmt.Println("- [Model ] Unsupported sql column '" + column.Name + " " + column.DataType.Name + "', using string instead.") return "" diff --git a/generator/template/sql_builder_template.go b/generator/template/sql_builder_template.go index cf7e121..fe7fba5 100644 --- a/generator/template/sql_builder_template.go +++ b/generator/template/sql_builder_template.go @@ -177,6 +177,18 @@ func getSqlBuilderColumnType(columnMetaData metadata.Column) string { case "real", "numeric", "decimal", "double precision", "float", "float4", "float8", "double": // MySQL return "Float" + case "daterange": + return "DateRange" + case "tsrange": + return "TimestampRange" + case "tstzrange": + return "TimestampzRange" + case "int4range": + return "Int4Range" + case "int8range": + return "Int8Range" + case "numrange": + return "NumericRange" default: fmt.Println("- [SQL Builder] Unsupported sql column '" + columnMetaData.Name + " " + columnMetaData.DataType.Name + "', using StringColumn instead.") return "String" diff --git a/go.mod b/go.mod index 16c19fa..2f60aea 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/go-jet/jet/v2 -go 1.11 +go 1.21.6 require ( github.com/go-sql-driver/mysql v1.7.1 @@ -10,7 +10,6 @@ require ( github.com/mattn/go-sqlite3 v1.14.17 ) -// test dependencies require ( github.com/google/go-cmp v0.5.9 github.com/jackc/pgx/v4 v4.18.1 @@ -21,3 +20,25 @@ require ( golang.org/x/sync v0.3.0 gopkg.in/guregu/null.v4 v4.0.0 ) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/felixge/fgprof v0.9.3 // indirect + github.com/friendsofgo/errors v0.9.2 // indirect + github.com/gofrs/uuid v4.0.0+incompatible // indirect + github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect + github.com/jackc/chunkreader/v2 v2.0.1 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgproto3/v2 v2.3.2 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgtype v1.14.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/volatiletech/inflect v0.0.1 // indirect + github.com/volatiletech/randomize v0.0.1 // indirect + github.com/volatiletech/strmangle v0.0.1 // indirect + golang.org/x/crypto v0.6.0 // indirect + golang.org/x/text v0.7.0 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index deb45b8..2044f31 100644 --- a/go.sum +++ b/go.sum @@ -32,7 +32,6 @@ github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm4 github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= @@ -54,7 +53,6 @@ github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5W github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= diff --git a/internal/jet/column_types.go b/internal/jet/column_types.go index 4748b2c..a732061 100644 --- a/internal/jet/column_types.go +++ b/internal/jet/column_types.go @@ -358,3 +358,44 @@ func DateColumn(name string) ColumnDate { dateColumn.ColumnExpressionImpl = NewColumnImpl(name, "", dateColumn) return dateColumn } + +//------------------------------------------------------// + +// ColumnRange is interface for range columns which can be int range, string range +// timestamp range or date range. +type ColumnRange[T Expression] interface { + Range[T] + Column + + From(subQuery SelectTable) ColumnRange[T] + SET(rangeExp Range[T]) ColumnAssigment +} + +type rangeColumnImpl[T Expression] struct { + rangeInterfaceImpl[T] + ColumnExpressionImpl +} + +func (i *rangeColumnImpl[T]) From(subQuery SelectTable) ColumnRange[T] { + newRangeColumn := RangeColumn[T](i.name) + newRangeColumn.setTableName(i.tableName) + newRangeColumn.setSubQuery(subQuery) + + return newRangeColumn +} + +func (i *rangeColumnImpl[T]) SET(rangeExp Range[T]) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: rangeExp, + } +} + +// RangeColumn creates named range column. +func RangeColumn[T Expression](name string) ColumnRange[T] { + rangeColumn := &rangeColumnImpl[T]{} + rangeColumn.rangeInterfaceImpl.parent = rangeColumn + rangeColumn.ColumnExpressionImpl = NewColumnImpl(name, "", rangeColumn) + + return rangeColumn +} diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 4ff3791..c56053f 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -468,6 +468,33 @@ func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType return newBoolFunc("REGEXP_LIKE", stringExp, pattern) } +//----------Range Type Functions ----------------------// + +// LOWER_BOUND returns range expressions lower bound +func LOWER_BOUND[T Expression](rangeExpression Range[T]) T { + return rangeTypeCaster[T](rangeExpression, NewFunc("LOWER", []Expression{rangeExpression}, nil)) +} + +// UPPER_BOUND returns range expressions upper bound +func UPPER_BOUND[T Expression](rangeExpression Range[T]) T { + return rangeTypeCaster[T](rangeExpression, NewFunc("UPPER", []Expression{rangeExpression}, nil)) +} + +func rangeTypeCaster[T Expression](rangeExpression Range[T], exp Expression) T { + var i Expression + switch rangeExpression.(type) { + case Range[DateExpression]: + i = DateExp(exp) + case Range[IntegerExpression]: + i = IntExp(exp) + case Range[TimestampzExpression]: + i = TimestampzExp(exp) + case Range[TimestampExpression]: + i = TimestampExp(exp) + } + return i.(T) +} + //----------Data Type Formatting Functions ----------------------// // TO_CHAR converts expression to string with format @@ -843,3 +870,35 @@ func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc { func Func(name string, expressions ...Expression) Expression { return NewFunc(name, expressions, nil) } + +func NumRange(lowNum, highNum NumericExpression, bounds ...StringExpression) Range[NumericExpression] { + return RangeExp[NumericExpression](NewFunc("numrange", rangeFuncParamCombiner[NumericExpression](lowNum, highNum, bounds...), nil)) +} + +func Int4Range(lowNum, highNum IntegerExpression, bounds ...StringExpression) Range[IntegerExpression] { + return RangeExp[IntegerExpression](NewFunc("int4range", rangeFuncParamCombiner[IntegerExpression](lowNum, highNum, bounds...), nil)) +} + +func Int8Range(lowNum, highNum IntegerExpression, bounds ...StringExpression) Range[IntegerExpression] { + return RangeExp[IntegerExpression](NewFunc("int8range", rangeFuncParamCombiner[IntegerExpression](lowNum, highNum, bounds...), nil)) +} + +func TimestampRange(lowTs, highTs TimestampExpression, bounds ...StringExpression) Range[TimestampExpression] { + return RangeExp[TimestampExpression](NewFunc("tsrange", rangeFuncParamCombiner[TimestampExpression](lowTs, highTs, bounds...), nil)) +} + +func TimestampzRange(lowTs, highTs TimestampzExpression, bounds ...StringExpression) Range[TimestampzExpression] { + return RangeExp[TimestampzExpression](NewFunc("tstzrange", rangeFuncParamCombiner[TimestampzExpression](lowTs, highTs, bounds...), nil)) +} + +func DateRange(lowTs, highTs DateExpression, bounds ...StringExpression) Range[DateExpression] { + return RangeExp[DateExpression](NewFunc("daterange", rangeFuncParamCombiner[DateExpression](lowTs, highTs, bounds...), nil)) +} + +func rangeFuncParamCombiner[T Expression](low, high T, bounds ...StringExpression) []Expression { + exp := []Expression{low, high} + if len(bounds) != 0 { + exp = append(exp, bounds[0]) + } + return exp +} diff --git a/internal/jet/func_expression_test.go b/internal/jet/func_expression_test.go index 048ade2..ca32589 100644 --- a/internal/jet/func_expression_test.go +++ b/internal/jet/func_expression_test.go @@ -202,3 +202,14 @@ func TestTO_ASCII(t *testing.T) { func TestFunc(t *testing.T) { assertClauseSerialize(t, Func("FOO", String("test"), NULL, MAX(Int(1))), "FOO($1, NULL, MAX($2))", "test", int64(1)) } + +func Test_rangePointCaster(t *testing.T) { + mainRange := Int8Range(Int8(10), Int8(12)) + exp := NewFunc("UPPER", []Expression{mainRange}, nil) + + got := rangeTypeCaster(mainRange, exp) + _, ok := got.(IntegerExpression) + if !ok { + t.Errorf("expecting to get IntegerExpression but got %v", got) + } +} diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 8c28b44..d6f0b41 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -333,6 +333,10 @@ var ( NULL = newNullLiteral() // STAR is jet equivalent of SQL * STAR = newStarLiteral() + // PLUS_INFINITY is jet equivalent for sql infinity + PLUS_INFINITY = String("infinity") + // MINUS_INFINITY is jet equivalent for sql -infinity + MINUS_INFINITY = String("-infinity") ) type nullLiteral struct { @@ -490,6 +494,11 @@ func RawDate(raw string, namedArgs ...map[string]interface{}) DateExpression { return DateExp(Raw(raw, namedArgs...)) } +// RawRange helper that for range expressions +func RawRange[T Expression](raw string, namedArgs ...map[string]interface{}) Range[T] { + return RangeExp[T](Raw(raw, namedArgs...)) +} + // UUID is a helper function to create string literal expression from uuid object // value can be any uuid type with a String method func UUID(value fmt.Stringer) StringExpression { diff --git a/internal/jet/operators.go b/internal/jet/operators.go index b73a451..bf1dedf 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -69,6 +69,16 @@ func GtEq(lhs, rhs Expression) BoolExpression { return newBinaryBoolOperatorExpression(lhs, rhs, ">=") } +// Contains returns a representation of "a @> b" +func Contains(lhs Expression, rhs Expression) BoolExpression { + return newBinaryBoolOperatorExpression(lhs, rhs, "@>") +} + +// Overlap returns a representation of "a && b" +func Overlap(lhs, rhs Expression) BoolExpression { + return newBinaryBoolOperatorExpression(lhs, rhs, "&&") +} + // Add notEq returns a representation of "a + b" func Add(lhs, rhs Serializer) Expression { return NewBinaryOperatorExpression(lhs, rhs, "+") diff --git a/internal/jet/range_expression.go b/internal/jet/range_expression.go new file mode 100644 index 0000000..f05fdea --- /dev/null +++ b/internal/jet/range_expression.go @@ -0,0 +1,95 @@ +package jet + +// Range Expression is interface for date range types +type Range[T Expression] interface { + Expression + + EQ(rhs Range[T]) BoolExpression + NOT_EQ(rhs Range[T]) BoolExpression + + LT(rhs Range[T]) BoolExpression + LT_EQ(rhs Range[T]) BoolExpression + GT(rhs Range[T]) BoolExpression + GT_EQ(rhs Range[T]) BoolExpression + + CONTAINS(rhs T) BoolExpression + CONTAINS_RANGE(rhs Range[T]) BoolExpression + OVERLAP(rhs Range[T]) BoolExpression + UNION(rhs Range[T]) Range[T] + INTERSECTION(rhs Range[T]) Range[T] + DIFFERENCE(rhs Range[T]) Range[T] +} + +type rangeInterfaceImpl[T Expression] struct { + parent Expression +} + +func (r *rangeInterfaceImpl[T]) EQ(rhs Range[T]) BoolExpression { + return Eq(r.parent, rhs) +} + +func (r *rangeInterfaceImpl[T]) NOT_EQ(rhs Range[T]) BoolExpression { + return NotEq(r.parent, rhs) +} + +func (r *rangeInterfaceImpl[T]) LT(rhs Range[T]) BoolExpression { + return Lt(r.parent, rhs) +} + +func (r *rangeInterfaceImpl[T]) LT_EQ(rhs Range[T]) BoolExpression { + return LtEq(r.parent, rhs) + +} + +func (r *rangeInterfaceImpl[T]) GT(rhs Range[T]) BoolExpression { + return Gt(r.parent, rhs) + +} + +func (r *rangeInterfaceImpl[T]) GT_EQ(rhs Range[T]) BoolExpression { + return GtEq(r.parent, rhs) +} + +func (r *rangeInterfaceImpl[T]) CONTAINS(rhs T) BoolExpression { + return Contains(r.parent, rhs) +} + +func (r *rangeInterfaceImpl[T]) CONTAINS_RANGE(rhs Range[T]) BoolExpression { + return Contains(r.parent, rhs) +} + +func (r *rangeInterfaceImpl[T]) OVERLAP(rhs Range[T]) BoolExpression { + return Overlap(r.parent, rhs) +} + +func (r *rangeInterfaceImpl[T]) UNION(rhs Range[T]) Range[T] { + return RangeExp[T](Add(r.parent, rhs)) +} + +func (r *rangeInterfaceImpl[T]) INTERSECTION(rhs Range[T]) Range[T] { + return RangeExp[T](Mul(r.parent, rhs)) +} + +func (r *rangeInterfaceImpl[T]) DIFFERENCE(rhs Range[T]) Range[T] { + return RangeExp[T](Sub(r.parent, rhs)) +} + +//---------------------------------------------------// + +type rangeExpressionWrapper[T Expression] struct { + rangeInterfaceImpl[T] + Expression +} + +func newRangeExpressionWrap[T Expression](expression Expression) Range[T] { + rangeExpressionWrap := rangeExpressionWrapper[T]{Expression: expression} + rangeExpressionWrap.rangeInterfaceImpl.parent = &rangeExpressionWrap + return &rangeExpressionWrap +} + +// RangeExp is range expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as range expression. +// Does not add sql cast to generated sql builder output. +func RangeExp[T Expression](expression Expression) Range[T] { + return newRangeExpressionWrap[T](expression) +} diff --git a/internal/jet/range_expression_test.go b/internal/jet/range_expression_test.go new file mode 100644 index 0000000..e5e8b75 --- /dev/null +++ b/internal/jet/range_expression_test.go @@ -0,0 +1,63 @@ +package jet + +import "testing" + +func TestRangeExpressionEQ(t *testing.T) { + assertClauseSerialize(t, table1ColRange.EQ(table2ColRange), "(table1.col_range = table2.col_range)") + assertClauseSerialize(t, table1ColRange.EQ(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range = int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionNOT_EQ(t *testing.T) { + assertClauseSerialize(t, table1ColRange.NOT_EQ(table2ColRange), "(table1.col_range != table2.col_range)") + assertClauseSerialize(t, table1ColRange.NOT_EQ(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range != int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionLT(t *testing.T) { + assertClauseSerialize(t, table1ColRange.LT(table2ColRange), "(table1.col_range < table2.col_range)") + assertClauseSerialize(t, table1ColRange.LT(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range < int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionLT_EQ(t *testing.T) { + assertClauseSerialize(t, table1ColRange.LT_EQ(table2ColRange), "(table1.col_range <= table2.col_range)") + assertClauseSerialize(t, table1ColRange.LT_EQ(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range <= int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionGT(t *testing.T) { + assertClauseSerialize(t, table1ColRange.GT(table2ColRange), "(table1.col_range > table2.col_range)") + assertClauseSerialize(t, table1ColRange.GT(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range > int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionGT_EQ(t *testing.T) { + assertClauseSerialize(t, table1ColRange.GT_EQ(table2ColRange), "(table1.col_range >= table2.col_range)") + assertClauseSerialize(t, table1ColRange.GT_EQ(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range >= int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionCONTAINS_RANGE(t *testing.T) { + assertClauseSerialize(t, table1ColRange.CONTAINS_RANGE(table2ColRange), "(table1.col_range @> table2.col_range)") + assertClauseSerialize(t, table1ColRange.CONTAINS_RANGE(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range @> int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionCONTAINS(t *testing.T) { + assertClauseSerialize(t, table1ColRange.CONTAINS(table2Col3), "(table1.col_range @> table2.col3)") + assertClauseSerialize(t, table1ColRange.CONTAINS(Int8(1)), "(table1.col_range @> $1)", int8(1)) +} + +func TestRangeExpressionOVERLAP(t *testing.T) { + assertClauseSerialize(t, table1ColRange.OVERLAP(table2ColRange), "(table1.col_range && table2.col_range)") + assertClauseSerialize(t, table1ColRange.OVERLAP(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range && int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionUNION(t *testing.T) { + assertClauseSerialize(t, table1ColRange.UNION(table2ColRange), "(table1.col_range + table2.col_range)") + assertClauseSerialize(t, table1ColRange.UNION(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range + int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionINTERSECTION(t *testing.T) { + assertClauseSerialize(t, table1ColRange.INTERSECTION(table2ColRange), "(table1.col_range * table2.col_range)") + assertClauseSerialize(t, table1ColRange.INTERSECTION(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range * int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionDIFFERENCE(t *testing.T) { + assertClauseSerialize(t, table1ColRange.DIFFERENCE(table2ColRange), "(table1.col_range - table2.col_range)") + assertClauseSerialize(t, table1ColRange.DIFFERENCE(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range - int8range($1, $2, $3))", int8(1), int8(4), "[)") +} diff --git a/internal/jet/testutils.go b/internal/jet/testutils.go index 43f1f5d..d91e30f 100644 --- a/internal/jet/testutils.go +++ b/internal/jet/testutils.go @@ -25,8 +25,9 @@ var ( table1ColTimestampz = TimestampzColumn("col_timestampz") table1ColBool = BoolColumn("col_bool") table1ColDate = DateColumn("col_date") + table1ColRange = RangeColumn[IntegerExpression]("col_range") ) -var table1 = NewTable("db", "table1", "", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTimestampz) +var table1 = NewTable("db", "table1", "", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColRange, table1ColTimestamp, table1ColTimestampz) var ( table2Col3 = IntegerColumn("col3") @@ -40,8 +41,9 @@ var ( table2ColTimestamp = TimestampColumn("col_timestamp") table2ColTimestampz = TimestampzColumn("col_timestampz") table2ColDate = DateColumn("col_date") + table2ColRange = RangeColumn[IntegerExpression]("col_range") ) -var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz) +var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColRange, table2ColTimestamp, table2ColTimestampz) var ( table3Col1 = IntegerColumn("col1") diff --git a/postgres/columns.go b/postgres/columns.go index a25a88c..aee2896 100644 --- a/postgres/columns.go +++ b/postgres/columns.go @@ -65,6 +65,42 @@ type ColumnTimestampz = jet.ColumnTimestampz // TimestampzColumn creates named timestamp with time zone column. var TimestampzColumn = jet.TimestampzColumn +// ColumnDateRange is interface of SQL date range column +type ColumnDateRange = jet.ColumnRange[DateExpression] + +// DateRangeColumn creates named range with range column +var DateRangeColumn = jet.RangeColumn[DateExpression] + +// ColumnNumericRange is interface of SQL numeric range column +type ColumnNumericRange = jet.ColumnRange[NumericExpression] + +// NumericRangeColumn creates named range with range column +var NumericRangeColumn = jet.RangeColumn[NumericExpression] + +// ColumnTimestampRange is interface of SQL timestamp range column +type ColumnTimestampRange = jet.ColumnRange[TimestampExpression] + +// TimestampRangeColumn creates named range with range column +var TimestampRangeColumn = jet.RangeColumn[TimestampExpression] + +// ColumnTimestampzRange is interface of SQL timestamp range column +type ColumnTimestampzRange = jet.ColumnRange[TimestampzExpression] + +// TimestampzRangeColumn creates named range with range column +var TimestampzRangeColumn = jet.RangeColumn[TimestampzExpression] + +// ColumnInt4Range is interface of SQL int range column +type ColumnInt4Range = jet.ColumnRange[IntegerExpression] + +// Int4RangeColumn creates named range with range column +var Int4RangeColumn = jet.RangeColumn[IntegerExpression] + +// ColumnInt8Range is interface of SQL int range column +type ColumnInt8Range = jet.ColumnRange[IntegerExpression] + +// Int8RangeColumn creates named range with range column +var Int8RangeColumn = jet.RangeColumn[IntegerExpression] + //------------------------------------------------------// // ColumnInterval is interface of PostgreSQL interval columns. diff --git a/postgres/expressions.go b/postgres/expressions.go index faf153d..a903860 100644 --- a/postgres/expressions.go +++ b/postgres/expressions.go @@ -36,6 +36,24 @@ type TimestampExpression = jet.TimestampExpression // TimestampzExpression interface type TimestampzExpression = jet.TimestampzExpression +// DateRange Expression interface +type DateRange = jet.Range[DateExpression] + +// TimestampRange Expression interface +type TimestampRange = jet.Range[TimestampExpression] + +// TimestampzRange Expression interface +type TimestampzRange = jet.Range[TimestampzExpression] + +// NumericRange Expression interface +type NumericRange = jet.Range[NumericExpression] + +// Int4Range Expression interface +type Int4Range = jet.Range[IntegerExpression] + +// Int8Range Expression interface +type Int8Range = jet.Range[IntegerExpression] + // BoolExp is bool expression wrapper around arbitrary expression. // Allows go compiler to see any expression as bool expression. // Does not add sql cast to generated sql builder output. @@ -81,6 +99,13 @@ var TimestampExp = jet.TimestampExp // Does not add sql cast to generated sql builder output. var TimestampzExp = jet.TimestampzExp +// RangeExp is range expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as range expression. +// Does not add sql cast to generated sql builder output. +func RangeExp[T Expression](expression T) jet.Range[T] { + return jet.RangeExp[T](expression) +} + // RawArgs is type used to pass optional arguments to Raw method type RawArgs = map[string]interface{} @@ -90,15 +115,21 @@ type RawArgs = map[string]interface{} var ( Raw = jet.Raw - RawBool = jet.RawBool - RawInt = jet.RawInt - RawFloat = jet.RawFloat - RawString = jet.RawString - RawTime = jet.RawTime - RawTimez = jet.RawTimez - RawTimestamp = jet.RawTimestamp - RawTimestampz = jet.RawTimestampz - RawDate = jet.RawDate + RawBool = jet.RawBool + RawInt = jet.RawInt + RawFloat = jet.RawFloat + RawString = jet.RawString + RawTime = jet.RawTime + RawTimez = jet.RawTimez + RawTimestamp = jet.RawTimestamp + RawTimestampz = jet.RawTimestampz + RawDate = jet.RawDate + RawNumRange = jet.RawRange[jet.NumericExpression] + RawInt4Range = jet.RawRange[jet.IntegerExpression] + RawInt8Range = jet.RawRange[jet.IntegerExpression] + RawTimestampRange = jet.RawRange[jet.TimestampExpression] + RawTimestampzRange = jet.RawRange[jet.TimestampzExpression] + RawDateRange = jet.RawRange[jet.DateExpression] ) // Func can be used to call custom or unsupported database functions. diff --git a/postgres/functions.go b/postgres/functions.go index 43a5f39..eddc930 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -267,6 +267,18 @@ var TO_HEX = jet.TO_HEX //----------Data Type Formatting Functions ----------------------// +// LOWER_BOUND returns range expressions lower bound +func LOWER_BOUND[T Expression](expression jet.Range[T]) T { + return jet.LOWER_BOUND[T](expression) +} + +// UPPER_BOUND returns range expressions upper bound +func UPPER_BOUND[T Expression](expression jet.Range[T]) T { + return jet.UPPER_BOUND[T](expression) +} + +//----------Data Type Formatting Functions ----------------------// + // TO_CHAR converts expression to string with format var TO_CHAR = jet.TO_CHAR @@ -421,3 +433,18 @@ var CUBE = jet.CUBE // It can be also used with multiple parameters to check if a set of columns is included in the current grouping set. The result // of the GROUPING function would then be an integer bit mask having 1’s for the arguments which have GROUPING(argument) as 1. var GROUPING = jet.GROUPING + +var ( + // DATE_RANGE constructor function to create a date range + DATE_RANGE = jet.DateRange + // NUM_Range constructor function to create a numeric range + NUM_Range = jet.NumRange + // TIMESTAMP_RANGE constructor function to create a timestamp range + TIMESTAMP_RANGE = jet.TimestampRange + // TIMESTAMPTZ_RANGE constructor function to create a timestampz range + TIMESTAMPTZ_RANGE = jet.TimestampzRange + // INT4_RANGE constructor function to create a int4 range + INT4_RANGE = jet.Int4Range + // INT8_RANGE constructor function to create a int8 range + INT8_RANGE = jet.Int8Range +) diff --git a/postgres/keywords.go b/postgres/keywords.go index cfc90a2..a468cc5 100644 --- a/postgres/keywords.go +++ b/postgres/keywords.go @@ -12,4 +12,8 @@ var ( NULL = jet.NULL // STAR is jet equivalent of SQL * STAR = jet.STAR + // PLUS_INFINITY is jet equivalent for sql infinity + PLUS_INFINITY = jet.PLUS_INFINITY + // MINUS_INFINITY is jet equivalent for sql -infinity + MINUS_INFINITY = jet.MINUS_INFINITY ) diff --git a/postgres/utils_test.go b/postgres/utils_test.go index 292d7e4..96bb13b 100644 --- a/postgres/utils_test.go +++ b/postgres/utils_test.go @@ -17,6 +17,7 @@ var table1ColTimestampz = TimestampzColumn("col_timestampz") var table1ColBool = BoolColumn("col_bool") var table1ColDate = DateColumn("col_date") var table1ColInterval = IntervalColumn("col_interval") +var table1ColRange = Int8RangeColumn("col_range") var table1 = NewTable( "db", @@ -32,6 +33,7 @@ var table1 = NewTable( table1ColTimestamp, table1ColTimestampz, table1ColInterval, + table1ColRange, ) var table2Col3 = IntegerColumn("col3") @@ -46,8 +48,9 @@ var table2ColTimestamp = TimestampColumn("col_timestamp") var table2ColTimestampz = TimestampzColumn("col_timestampz") var table2ColDate = DateColumn("col_date") var table2ColInterval = IntervalColumn("col_interval") +var table2ColRange = Int8RangeColumn("col_range") -var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz, table2ColInterval) +var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz, table2ColInterval, table2ColRange) var table3Col1 = IntegerColumn("col1") var table3ColInt = IntegerColumn("col_int") diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 5aaaa31..479d82f 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -561,13 +561,14 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, modelDir, "all_types.go", "all_types_view.go", "employee.go", "link.go", "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go", "floats.go", "people.go", - "components.go", "vulnerabilities.go", "all_types_materialized_view.go") + "components.go", "vulnerabilities.go", "all_types_materialized_view.go", "sample_ranges.go") testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent) testutils.AssertFileNamesEqual(t, tableDir, "all_types.go", "employee.go", "link.go", "person.go", "person_phone.go", "weird_names_table.go", "user.go", "floats.go", "people.go", "table_use_schema.go", - "components.go", "vulnerabilities.go") + "components.go", "vulnerabilities.go", "sample_ranges.go") testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent) + testutils.AssertFileContent(t, tableDir+"/sample_ranges.go", sampleRangeTableContent) testutils.AssertFileNamesEqual(t, viewDir, "all_types_materialized_view.go", "all_types_view.go", "view_use_schema.go") @@ -968,3 +969,96 @@ func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable { } } ` + +var sampleRangeTableContent = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package table + +import ( + "github.com/go-jet/jet/v2/postgres" +) + +var SampleRanges = newSampleRangesTable("test_sample", "sample_ranges", "") + +type sampleRangesTable struct { + postgres.Table + + // Columns + DateRange postgres.ColumnDateRange + TimestampRange postgres.ColumnTimestampRange + TimestampzRange postgres.ColumnTimestampzRange + Int4Range postgres.ColumnInt4Range + Int8Range postgres.ColumnInt8Range + NumRange postgres.ColumnNumericRange + + AllColumns postgres.ColumnList + MutableColumns postgres.ColumnList +} + +type SampleRangesTable struct { + sampleRangesTable + + EXCLUDED sampleRangesTable +} + +// AS creates new SampleRangesTable with assigned alias +func (a SampleRangesTable) AS(alias string) *SampleRangesTable { + return newSampleRangesTable(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new SampleRangesTable with assigned schema name +func (a SampleRangesTable) FromSchema(schemaName string) *SampleRangesTable { + return newSampleRangesTable(schemaName, a.TableName(), a.Alias()) +} + +// WithPrefix creates new SampleRangesTable with assigned table prefix +func (a SampleRangesTable) WithPrefix(prefix string) *SampleRangesTable { + return newSampleRangesTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new SampleRangesTable with assigned table suffix +func (a SampleRangesTable) WithSuffix(suffix string) *SampleRangesTable { + return newSampleRangesTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + +func newSampleRangesTable(schemaName, tableName, alias string) *SampleRangesTable { + return &SampleRangesTable{ + sampleRangesTable: newSampleRangesTableImpl(schemaName, tableName, alias), + EXCLUDED: newSampleRangesTableImpl("", "excluded", ""), + } +} + +func newSampleRangesTableImpl(schemaName, tableName, alias string) sampleRangesTable { + var ( + DateRangeColumn = postgres.DateRangeColumn("date_range") + TimestampRangeColumn = postgres.TimestampRangeColumn("timestamp_range") + TimestampzRangeColumn = postgres.TimestampzRangeColumn("timestampz_range") + Int4RangeColumn = postgres.Int4RangeColumn("int4_range") + Int8RangeColumn = postgres.Int8RangeColumn("int8_range") + NumRangeColumn = postgres.NumericRangeColumn("num_range") + allColumns = postgres.ColumnList{DateRangeColumn, TimestampRangeColumn, TimestampzRangeColumn, Int4RangeColumn, Int8RangeColumn, NumRangeColumn} + mutableColumns = postgres.ColumnList{DateRangeColumn, TimestampRangeColumn, TimestampzRangeColumn, Int4RangeColumn, Int8RangeColumn, NumRangeColumn} + ) + + return sampleRangesTable{ + Table: postgres.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns + DateRange: DateRangeColumn, + TimestampRange: TimestampRangeColumn, + TimestampzRange: TimestampzRangeColumn, + Int4Range: Int4RangeColumn, + Int8Range: Int8RangeColumn, + NumRange: NumRangeColumn, + + AllColumns: allColumns, + MutableColumns: mutableColumns, + } +} +` diff --git a/tests/postgres/range_test.go b/tests/postgres/range_test.go new file mode 100644 index 0000000..63498eb --- /dev/null +++ b/tests/postgres/range_test.go @@ -0,0 +1,514 @@ +package postgres + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/qrm" + "github.com/google/go-cmp/cmp" + "github.com/jackc/pgtype" + "github.com/stretchr/testify/require" + "math/big" + "testing" + "time" + + . "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/table" +) + +func TestRangeTable_DateContainsSingle(t *testing.T) { + skipForCockroachDB(t) + expectedSQL := ` +SELECT DISTINCT sample_ranges.date_range AS "sample_ranges.date_range", + sample_ranges.timestamp_range AS "sample_ranges.timestamp_range", + sample_ranges.timestampz_range AS "sample_ranges.timestampz_range", + sample_ranges.int4_range AS "sample_ranges.int4_range", + sample_ranges.int8_range AS "sample_ranges.int8_range", + sample_ranges.num_range AS "sample_ranges.num_range" +FROM test_sample.sample_ranges +WHERE sample_ranges.date_range @> '2023-12-12'::date; +` + + query := SELECT(SampleRanges.AllColumns). + DISTINCT(). + FROM(SampleRanges). + WHERE(SampleRanges.DateRange.CONTAINS(Date(2023, 12, 12))) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, "2023-12-12") + + sample := model.SampleRanges{} + err := query.Query(db, &sample) + + require.NoError(t, err) + + expectedRow := model.SampleRanges{ + DateRange: pgtype.Daterange{ + Lower: pgtype.Date{ + Time: time.Date(2023, 9, 25, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + Upper: pgtype.Date{ + Time: time.Date(2024, 2, 10, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + TimestampRange: pgtype.Tsrange{ + Lower: pgtype.Timestamp{ + Time: time.Date(2020, 01, 01, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + Upper: pgtype.Timestamp{ + Time: time.Date(2021, 01, 01, 15, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Inclusive, + Status: pgtype.Present, + }, + TimestampzRange: pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{ + Time: time.Date(2024, 05, 07, 15, 0, 0, 0, time.FixedZone("", 0)), + Status: pgtype.Present, + }, + Upper: pgtype.Timestamptz{ + Time: time.Date(2024, 10, 11, 14, 0, 0, 0, time.FixedZone("", 0)), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + Int4Range: pgtype.Int4range{ + Lower: pgtype.Int4{ + Int: 11, + Status: pgtype.Present, + }, + Upper: pgtype.Int4{ + Int: 20, + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + Int8Range: pgtype.Int8range{ + Lower: pgtype.Int8{ + Int: 200, + Status: pgtype.Present, + }, + Upper: pgtype.Int8{ + Int: 2450, + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + NumRange: pgtype.Numrange{ + Lower: pgtype.Numeric{ + Int: big.NewInt(2), + Exp: 3, + Status: pgtype.Present, + }, + Upper: pgtype.Numeric{ + Int: big.NewInt(5), + Status: pgtype.Present, + Exp: 3, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + } + + testutils.AssertDeepEqual(t, sample, expectedRow, cmp.AllowUnexported(big.Int{})) + requireLogged(t, query) +} + +func TestRangeTable_IntContainsRange(t *testing.T) { + skipForCockroachDB(t) + expectedSQL := ` +SELECT DISTINCT sample_ranges.date_range AS "sample_ranges.date_range", + sample_ranges.timestamp_range AS "sample_ranges.timestamp_range", + sample_ranges.timestampz_range AS "sample_ranges.timestampz_range", + sample_ranges.int4_range AS "sample_ranges.int4_range", + sample_ranges.int8_range AS "sample_ranges.int8_range", + sample_ranges.num_range AS "sample_ranges.num_range" +FROM test_sample.sample_ranges +WHERE sample_ranges.int4_range @> int4range(12, 18, '[)'::text); +` + + query := SELECT(SampleRanges.AllColumns). + DISTINCT(). + FROM(SampleRanges). + WHERE(SampleRanges.Int4Range.CONTAINS_RANGE(INT4_RANGE(Int(12), Int(18), String("[)")))) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(12), int64(18), "[)") + + sample := model.SampleRanges{} + err := query.Query(db, &sample) + + require.NoError(t, err) + + expectedRow := model.SampleRanges{ + DateRange: pgtype.Daterange{ + Lower: pgtype.Date{ + Time: time.Date(2023, 9, 25, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + Upper: pgtype.Date{ + Time: time.Date(2024, 2, 10, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + TimestampRange: pgtype.Tsrange{ + Lower: pgtype.Timestamp{ + Time: time.Date(2020, 01, 01, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + Upper: pgtype.Timestamp{ + Time: time.Date(2021, 01, 01, 15, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Inclusive, + Status: pgtype.Present, + }, + TimestampzRange: pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{ + Time: time.Date(2024, 05, 07, 15, 0, 0, 0, time.FixedZone("", 0)), + Status: pgtype.Present, + }, + Upper: pgtype.Timestamptz{ + Time: time.Date(2024, 10, 11, 14, 0, 0, 0, time.FixedZone("", 0)), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + Int4Range: pgtype.Int4range{ + Lower: pgtype.Int4{ + Int: 11, + Status: pgtype.Present, + }, + Upper: pgtype.Int4{ + Int: 20, + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + Int8Range: pgtype.Int8range{ + Lower: pgtype.Int8{ + Int: 200, + Status: pgtype.Present, + }, + Upper: pgtype.Int8{ + Int: 2450, + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + NumRange: pgtype.Numrange{ + Lower: pgtype.Numeric{ + Int: big.NewInt(2), + Exp: 3, + Status: pgtype.Present, + }, + Upper: pgtype.Numeric{ + Int: big.NewInt(5), + Status: pgtype.Present, + Exp: 3, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + } + + testutils.AssertDeepEqual(t, sample, expectedRow, cmp.AllowUnexported(big.Int{})) + requireLogged(t, query) +} + +func TestRangeTable_TimestampContainsRange(t *testing.T) { + skipForCockroachDB(t) + expectedSQL := ` +SELECT DISTINCT sample_ranges.date_range AS "sample_ranges.date_range", + sample_ranges.timestamp_range AS "sample_ranges.timestamp_range", + sample_ranges.timestampz_range AS "sample_ranges.timestampz_range", + sample_ranges.int4_range AS "sample_ranges.int4_range", + sample_ranges.int8_range AS "sample_ranges.int8_range", + sample_ranges.num_range AS "sample_ranges.num_range" +FROM test_sample.sample_ranges +WHERE sample_ranges.timestamp_range @> tsrange('2020-02-01 00:00:00'::timestamp without time zone, '2020-10-01 00:00:00'::timestamp without time zone, '[)'::text); +` + + query := SELECT(SampleRanges.AllColumns). + DISTINCT(). + FROM(SampleRanges). + WHERE(SampleRanges.TimestampRange.CONTAINS_RANGE(TIMESTAMP_RANGE(Timestamp(2020, 02, 01, 0, 0, 0), Timestamp(2020, 10, 01, 0, 0, 0), String("[)")))) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, "2020-02-01 00:00:00", "2020-10-01 00:00:00", "[)") + + sample := model.SampleRanges{} + err := query.Query(db, &sample) + + require.NoError(t, err) + + expectedRow := model.SampleRanges{ + DateRange: pgtype.Daterange{ + Lower: pgtype.Date{ + Time: time.Date(2023, 9, 25, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + Upper: pgtype.Date{ + Time: time.Date(2024, 2, 10, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + TimestampRange: pgtype.Tsrange{ + Lower: pgtype.Timestamp{ + Time: time.Date(2020, 01, 01, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + Upper: pgtype.Timestamp{ + Time: time.Date(2021, 01, 01, 15, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Inclusive, + Status: pgtype.Present, + }, + TimestampzRange: pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{ + Time: time.Date(2024, 05, 07, 15, 0, 0, 0, time.FixedZone("", 0)), + Status: pgtype.Present, + }, + Upper: pgtype.Timestamptz{ + Time: time.Date(2024, 10, 11, 14, 0, 0, 0, time.FixedZone("", 0)), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + Int4Range: pgtype.Int4range{ + Lower: pgtype.Int4{ + Int: 11, + Status: pgtype.Present, + }, + Upper: pgtype.Int4{ + Int: 20, + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + Int8Range: pgtype.Int8range{ + Lower: pgtype.Int8{ + Int: 200, + Status: pgtype.Present, + }, + Upper: pgtype.Int8{ + Int: 2450, + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + NumRange: pgtype.Numrange{ + Lower: pgtype.Numeric{ + Int: big.NewInt(2), + Exp: 3, + Status: pgtype.Present, + }, + Upper: pgtype.Numeric{ + Int: big.NewInt(5), + Status: pgtype.Present, + Exp: 3, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + } + + testutils.AssertDeepEqual(t, sample, expectedRow, cmp.AllowUnexported(big.Int{})) + requireLogged(t, query) +} + +func TestRangeTable_ContainsOutOfRange(t *testing.T) { + skipForCockroachDB(t) + expectedSQL := ` +SELECT DISTINCT sample_ranges.date_range AS "sample_ranges.date_range", + sample_ranges.timestamp_range AS "sample_ranges.timestamp_range", + sample_ranges.timestampz_range AS "sample_ranges.timestampz_range", + sample_ranges.int4_range AS "sample_ranges.int4_range", + sample_ranges.int8_range AS "sample_ranges.int8_range", + sample_ranges.num_range AS "sample_ranges.num_range" +FROM test_sample.sample_ranges +WHERE sample_ranges.int4_range @> int4range(12, 30, '[)'::text); +` + + query := SELECT(SampleRanges.AllColumns). + DISTINCT(). + FROM(SampleRanges). + WHERE(SampleRanges.Int4Range.CONTAINS_RANGE(INT4_RANGE(Int(12), Int(30), String("[)")))) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(12), int64(30), "[)") + + sample := model.SampleRanges{} + err := query.Query(db, &sample) + + require.ErrorIs(t, err, qrm.ErrNoRows) + requireLogged(t, query) +} + +func TestRangeTable_InsertColumn(t *testing.T) { + skipForCockroachDB(t) + + insertQuery := SampleRanges.INSERT(SampleRanges.AllColumns). + VALUES( + DATE_RANGE( + Date(2010, 01, 01), + Date(2014, 01, 01), + String("[)"), + ), + DEFAULT, + TIMESTAMPTZ_RANGE( + TimestampzT(time.Date(2010, 01, 01, 23, 0, 0, 0, time.UTC)), + TimestampzT(time.Date(2014, 01, 01, 15, 0, 0, 0, time.UTC)), + String("[)"), + ), + INT4_RANGE(Int(64), Int(128), String("[]")), + INT8_RANGE(Int(1024), Int(2048), String("[]")), + DEFAULT, + ). + RETURNING(SampleRanges.AllColumns) + + expectedQuery := ` +INSERT INTO test_sample.sample_ranges (date_range, timestamp_range, timestampz_range, int4_range, int8_range, num_range) +VALUES (daterange('2010-01-01'::date, '2014-01-01'::date, '[)'::text), DEFAULT, tstzrange('2010-01-01 23:00:00Z'::timestamp with time zone, '2014-01-01 15:00:00Z'::timestamp with time zone, '[)'::text), int4range(64, 128, '[]'::text), int8range(1024, 2048, '[]'::text), DEFAULT) +RETURNING sample_ranges.date_range AS "sample_ranges.date_range", + sample_ranges.timestamp_range AS "sample_ranges.timestamp_range", + sample_ranges.timestampz_range AS "sample_ranges.timestampz_range", + sample_ranges.int4_range AS "sample_ranges.int4_range", + sample_ranges.int8_range AS "sample_ranges.int8_range", + sample_ranges.num_range AS "sample_ranges.num_range"; +` + testutils.AssertDebugStatementSql(t, insertQuery, expectedQuery, + "2010-01-01", "2014-01-01", "[)", + time.Date(2010, 01, 01, 23, 0, 0, 0, time.UTC), time.Date(2014, 01, 01, 15, 0, 0, 0, time.UTC), "[)", + int64(64), int64(128), "[]", + int64(1024), int64(2048), "[]", + ) +} + +func TestRangeTable_UpperBound(t *testing.T) { + skipForCockroachDB(t) + + expectedSQL := ` +SELECT UPPER(sample_ranges.date_range) +FROM test_sample.sample_ranges +WHERE sample_ranges.date_range @> '2023-12-12'::date; +` + + query := SELECT(UPPER_BOUND[DateExpression](SampleRanges.DateRange)). + FROM(SampleRanges). + WHERE(SampleRanges.DateRange.CONTAINS(Date(2023, 12, 12))) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, "2023-12-12") + + var date time.Time + err := query.Query(db, &date) + require.NoError(t, err) + + expectedYear := 2024 + expectedMonth := time.February + expectedDay := 10 + if expectedYear != date.Year() || expectedMonth != date.Month() || expectedDay != date.Day() { + t.Errorf("expected: 2024-02-10 got: %s", date.Format("2006-01-02")) + } +} + +func TestRangeTable_LowerBound(t *testing.T) { + skipForCockroachDB(t) + + expectedSQL := ` +SELECT LOWER(sample_ranges.date_range) +FROM test_sample.sample_ranges +WHERE sample_ranges.date_range @> '2023-12-12'::date; +` + + query := SELECT(LOWER_BOUND[DateExpression](SampleRanges.DateRange)). + FROM(SampleRanges). + WHERE(SampleRanges.DateRange.CONTAINS(Date(2023, 12, 12))) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, "2023-12-12") + + var date time.Time + err := query.Query(db, &date) + require.NoError(t, err) + + expectedYear := 2023 + expectedMonth := time.September + expectedDay := 25 + if expectedYear != date.Year() || expectedMonth != date.Month() || expectedDay != date.Day() { + t.Errorf("expected: 2023-09-25 got: %s", date.Format("2006-01-02")) + } +} + +func TestRangeTable_InsertInfinite(t *testing.T) { + skipForCockroachDB(t) + + insertQuery := SampleRanges.INSERT(SampleRanges.AllColumns). + VALUES( + DATE_RANGE( + Date(2010, 01, 01), + DateExp(PLUS_INFINITY), + String("[)"), + ), + DEFAULT, + TIMESTAMPTZ_RANGE( + TimestampzExp(MINUS_INFINITY), + TimestampzT(time.Date(2014, 01, 01, 15, 0, 0, 0, time.UTC)), + String("[)"), + ), + INT4_RANGE(Int(64), Int(128), String("[]")), + INT8_RANGE(Int(1024), Int(2048), String("[]")), + DEFAULT, + ). + RETURNING(SampleRanges.AllColumns) + + expectedQuery := ` +INSERT INTO test_sample.sample_ranges (date_range, timestamp_range, timestampz_range, int4_range, int8_range, num_range) +VALUES (daterange('2010-01-01'::date, 'infinity', '[)'::text), DEFAULT, tstzrange('-infinity', '2014-01-01 15:00:00Z'::timestamp with time zone, '[)'::text), int4range(64, 128, '[]'::text), int8range(1024, 2048, '[]'::text), DEFAULT) +RETURNING sample_ranges.date_range AS "sample_ranges.date_range", + sample_ranges.timestamp_range AS "sample_ranges.timestamp_range", + sample_ranges.timestampz_range AS "sample_ranges.timestampz_range", + sample_ranges.int4_range AS "sample_ranges.int4_range", + sample_ranges.int8_range AS "sample_ranges.int8_range", + sample_ranges.num_range AS "sample_ranges.num_range"; +` + + testutils.AssertDebugStatementSql(t, insertQuery, expectedQuery, + "2010-01-01", "infinity", "[)", + "-infinity", time.Date(2014, 01, 01, 15, 0, 0, 0, time.UTC), "[)", + int64(64), int64(128), "[]", + int64(1024), int64(2048), "[]", + ) +} diff --git a/tests/testdata b/tests/testdata index 08bcfcb..915bdc1 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 08bcfcbb2e1eadfca54c4522802fc65f1fee865c +Subproject commit 915bdc16b723d89becc577c780949baef861a6ae