From b7904cde4e3e40e79f51074445648aac743aa416 Mon Sep 17 00:00:00 2001 From: Andrew Harding Date: Tue, 13 Aug 2024 14:52:54 -0600 Subject: [PATCH 01/28] Add HasDefault column metadata --- generator/metadata/column_meta_data.go | 1 + generator/mysql/query_set.go | 1 + generator/postgres/query_set.go | 1 + generator/sqlite/query_set.go | 16 ++++++---- tests/mysql/generator_test.go | 36 +++++++++++++++++++++ tests/postgres/generator_template_test.go | 5 +-- tests/postgres/generator_test.go | 38 +++++++++++++++++++++-- tests/sqlite/generator_test.go | 33 ++++++++++++++++++++ 8 files changed, 120 insertions(+), 11 deletions(-) diff --git a/generator/metadata/column_meta_data.go b/generator/metadata/column_meta_data.go index 5c888a3..1502719 100644 --- a/generator/metadata/column_meta_data.go +++ b/generator/metadata/column_meta_data.go @@ -10,6 +10,7 @@ type Column struct { IsPrimaryKey bool IsNullable bool IsGenerated bool + HasDefault bool DataType DataType Comment string } diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index bd4d593..6dc2714 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -18,6 +18,7 @@ func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTyp SELECT t.table_name as "table.name", col.COLUMN_NAME AS "column.Name", + col.COLUMN_DEFAULT IS NOT NULL as "column.HasDefault", col.IS_NULLABLE = "YES" AS "column.IsNullable", col.COLUMN_COMMENT AS "column.Comment", COALESCE(pk.IsPrimaryKey, 0) AS "column.IsPrimaryKey", diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index 2d8835b..d61bd34 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -64,6 +64,7 @@ select ) as "column.IsPrimaryKey", not attr.attnotnull as "column.isNullable", attr.attgenerated = 's' as "column.isGenerated", + attr.atthasdef as "column.hasDefault", (case tp.typtype when 'b' then 'base' when 'd' then 'base' diff --git a/generator/sqlite/query_set.go b/generator/sqlite/query_set.go index 745aae4..dbf9d15 100644 --- a/generator/sqlite/query_set.go +++ b/generator/sqlite/query_set.go @@ -4,10 +4,11 @@ import ( "context" "database/sql" "fmt" + "strings" + "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/internal/utils/semantic" "github.com/go-jet/jet/v2/qrm" - "strings" ) // sqliteQuerySet is dialect query set for SQLite @@ -74,11 +75,12 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t } var columnInfos []struct { - Name string - Type string - NotNull int32 - Pk int32 - Hidden int32 + Name string + Type string + NotNull int32 + DfltValue string + Pk int32 + Hidden int32 } _, err = qrm.Query(context.Background(), db, tableInfoQuery, []interface{}{tableName}, &columnInfos) @@ -91,12 +93,14 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t for _, columnInfo := range columnInfos { columnType := strings.TrimSuffix(getColumnType(columnInfo.Type), " GENERATED ALWAYS") isGenerated := columnInfo.Hidden == 2 || columnInfo.Hidden == 3 // stored or virtual column + hasDefault := columnInfo.DfltValue != "" columns = append(columns, metadata.Column{ Name: columnInfo.Name, IsPrimaryKey: columnInfo.Pk != 0, IsNullable: columnInfo.NotNull != 1, IsGenerated: isGenerated, + HasDefault: hasDefault, DataType: metadata.DataType{ Name: columnType, Kind: metadata.BaseType, diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index a7b5554..2a915a5 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -8,8 +8,11 @@ import ( "github.com/stretchr/testify/require" + "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/mysql" + "github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/internal/testutils" + mysql2 "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/tests/dbconfig" ) @@ -39,6 +42,39 @@ func TestGenerator(t *testing.T) { require.NoError(t, err) } +func TestGenerator_TableMetadata(t *testing.T) { + var schema metadata.Schema + err := mysql.Generate(genTestDir3, dbConnection("dvds"), + template.Default(mysql2.Dialect).UseSchema(func(m metadata.Schema) template.Schema { + schema = m + return template.DefaultSchema(m) + })) + require.NoError(t, err) + + // Spot check the actor table and assert that the emitted + // properties are as expected. + var got metadata.Table + for _, table := range schema.TablesMetaData { + if table.Name == "actor" { + got = table + } + } + + want := metadata.Table{ + Name: "actor", + Columns: []metadata.Column{ + {Name: "actor_id", IsPrimaryKey: true, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "smallint", Kind: "base", IsUnsigned: true}, Comment: ""}, + {Name: "first_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "last_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "last_update", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: true, DataType: metadata.DataType{Name: "timestamp", Kind: "base", IsUnsigned: false}, Comment: ""}, + }, + } + require.Equal(t, want, got) + + err = os.RemoveAll(genTestDirRoot) + require.NoError(t, err) +} + func TestCmdGenerator(t *testing.T) { err := os.RemoveAll(genTestDir3) require.NoError(t, err) diff --git a/tests/postgres/generator_template_test.go b/tests/postgres/generator_template_test.go index 2bc2d32..4d87295 100644 --- a/tests/postgres/generator_template_test.go +++ b/tests/postgres/generator_template_test.go @@ -3,6 +3,9 @@ package postgres import ( "database/sql" "fmt" + "path" + "testing" + "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/generator/template" @@ -13,8 +16,6 @@ import ( "github.com/go-jet/jet/v2/tests/dbconfig" file2 "github.com/go-jet/jet/v2/tests/internal/utils/file" "github.com/stretchr/testify/require" - "path" - "testing" ) const tempTestDir = "./.tempTestDir" diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 479d82f..a1ea307 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -2,7 +2,6 @@ package postgres import ( "fmt" - "github.com/go-jet/jet/v2/tests/internal/utils/file" "os" "os/exec" "path/filepath" @@ -12,11 +11,14 @@ import ( "github.com/stretchr/testify/require" + "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/postgres" + "github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/internal/testutils" - "github.com/go-jet/jet/v2/tests/dbconfig" - + postgres2 "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" + "github.com/go-jet/jet/v2/tests/dbconfig" + "github.com/go-jet/jet/v2/tests/internal/utils/file" ) func dsn(host string, port int, dbName, user, password string) string { @@ -208,6 +210,36 @@ func TestGenerator(t *testing.T) { require.NoError(t, err) } +func TestGenerator_TableMetadata(t *testing.T) { + var schema metadata.Schema + err := postgres.GenerateDSN(defaultDSN(), "dvds", genTestDir2, + template.Default(postgres2.Dialect).UseSchema(func(m metadata.Schema) template.Schema { + schema = m + return template.DefaultSchema(m) + })) + require.NoError(t, err) + + // Spot check the actor table and assert that the emitted + // properties are as expected. + var got metadata.Table + for _, table := range schema.TablesMetaData { + if table.Name == "actor" { + got = table + } + } + + want := metadata.Table{ + Name: "actor", + Columns: []metadata.Column{ + {Name: "actor_id", IsPrimaryKey: true, IsNullable: false, IsGenerated: false, HasDefault: true, DataType: metadata.DataType{Name: "int4", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "first_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "last_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "last_update", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "timestamp", Kind: "base", IsUnsigned: false}, Comment: ""}, + }, + } + require.Equal(t, want, got) +} + func TestGeneratorSpecialCharacters(t *testing.T) { t.SkipNow() err := postgres.Generate(genTestDir2, postgres.DBConnection{ diff --git a/tests/sqlite/generator_test.go b/tests/sqlite/generator_test.go index b232aba..8a15568 100644 --- a/tests/sqlite/generator_test.go +++ b/tests/sqlite/generator_test.go @@ -8,8 +8,11 @@ import ( "github.com/stretchr/testify/require" + "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/sqlite" + "github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/internal/testutils" + sqlite2 "github.com/go-jet/jet/v2/sqlite" "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" "github.com/go-jet/jet/v2/tests/internal/utils/repo" ) @@ -58,6 +61,36 @@ func TestGenerator(t *testing.T) { require.NoError(t, err) } +func TestGenerator_TableMetadata(t *testing.T) { + var schema metadata.Schema + err := sqlite.GenerateDSN(testDatabaseFilePath, genDestDir, + template.Default(sqlite2.Dialect).UseSchema(func(m metadata.Schema) template.Schema { + schema = m + return template.DefaultSchema(m) + })) + require.NoError(t, err) + + // Spot check the actor table and assert that the emitted + // properties are as expected. + var got metadata.Table + for _, table := range schema.TablesMetaData { + if table.Name == "actor" { + got = table + } + } + + want := metadata.Table{ + Name: "actor", + Columns: []metadata.Column{ + {Name: "actor_id", IsPrimaryKey: true, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "INTEGER", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "first_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "VARCHAR", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "last_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "VARCHAR", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "last_update", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: true, DataType: metadata.DataType{Name: "TIMESTAMP", Kind: "base", IsUnsigned: false}, Comment: ""}, + }, + } + require.Equal(t, want, got) +} + func TestCmdGenerator(t *testing.T) { cmd := exec.Command("jet", "-source=SQLite", "-dsn=file://"+testDatabaseFilePath, "-path="+genDestDir) From 882b4311b50a212c4d1e3906fbfbd05b9d1a0113 Mon Sep 17 00:00:00 2001 From: Mike Nelson Date: Wed, 28 Aug 2024 11:09:16 -0600 Subject: [PATCH 02/28] Fix Postgres column array detection --- generator/postgres/query_set.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index d61bd34..329741a 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -65,11 +65,12 @@ select not attr.attnotnull as "column.isNullable", attr.attgenerated = 's' as "column.isGenerated", attr.atthasdef as "column.hasDefault", - (case tp.typtype - when 'b' then 'base' - when 'd' then 'base' - when 'e' then 'enum' - when 'r' then 'range' + (case + when tp.typtype = 'b' AND tp.typcategory <> 'A' then 'base' + when tp.typtype = 'b' AND tp.typcategory = 'A' then 'array' + when tp.typtype = 'd' then 'base' + when tp.typtype = 'e' then 'enum' + when tp.typtype = 'r' then 'range' end) as "dataType.Kind", (case when tp.typtype = 'd' then (select pg_type.typname from pg_catalog.pg_type where pg_type.oid = tp.typbasetype) when tp.typcategory = 'A' then pg_catalog.format_type(attr.atttypid, attr.atttypmod) From 52df6593185632b31e908c172c82b856437b485c Mon Sep 17 00:00:00 2001 From: Mike Nelson Date: Wed, 28 Aug 2024 11:45:44 -0600 Subject: [PATCH 03/28] test --- tests/postgres/generator_test.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index a1ea307..fe1407f 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -222,10 +222,18 @@ func TestGenerator_TableMetadata(t *testing.T) { // Spot check the actor table and assert that the emitted // properties are as expected. var got metadata.Table + var specialFeatures metadata.Column for _, table := range schema.TablesMetaData { if table.Name == "actor" { got = table } + if table.Name == "film" { + for _, column := range table.Columns { + if column.Name == "special_features" { + specialFeatures = column + } + } + } } want := metadata.Table{ @@ -238,6 +246,7 @@ func TestGenerator_TableMetadata(t *testing.T) { }, } require.Equal(t, want, got) + require.Equal(t, metadata.ArrayType, specialFeatures.DataType.Kind) } func TestGeneratorSpecialCharacters(t *testing.T) { From cf08bcd6f725c98d1749b2dc3654e1dc52bea7d9 Mon Sep 17 00:00:00 2001 From: Mike Nelson Date: Wed, 28 Aug 2024 11:46:35 -0600 Subject: [PATCH 04/28] spacing --- generator/postgres/query_set.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index 329741a..abb21ba 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -67,7 +67,7 @@ select attr.atthasdef as "column.hasDefault", (case when tp.typtype = 'b' AND tp.typcategory <> 'A' then 'base' - when tp.typtype = 'b' AND tp.typcategory = 'A' then 'array' + when tp.typtype = 'b' AND tp.typcategory = 'A' then 'array' when tp.typtype = 'd' then 'base' when tp.typtype = 'e' then 'enum' when tp.typtype = 'r' then 'range' From 42a37c09d0741acac7ad4239e95cb24531fbdab3 Mon Sep 17 00:00:00 2001 From: Samir Faci Date: Wed, 18 Sep 2024 12:09:43 -0400 Subject: [PATCH 05/28] Adding support for enum model AllValues Fixes #368 ChangeLog: - Updating test mysql version as it no longer exists. - Add a simple validation test --- .circleci/config.yml | 4 ++-- generator/template/file_templates.go | 6 ++++++ tests/docker-compose.yaml | 2 +- tests/postgres/generator_template_test.go | 1 + 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index e21f00a..e634858 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -16,7 +16,7 @@ jobs: POSTGRES_DB: jetdb PGPORT: 50901 - - image: circleci/mysql:8.0.27 + - image: circleci/mysql:8.0 command: [ --default-authentication-plugin=mysql_native_password ] environment: MYSQL_ROOT_PASSWORD: jet @@ -163,4 +163,4 @@ workflows: version: 2 build_and_test: jobs: - - build_and_tests \ No newline at end of file + - build_and_tests diff --git a/generator/template/file_templates.go b/generator/template/file_templates.go index 731b1af..fa16864 100644 --- a/generator/template/file_templates.go +++ b/generator/template/file_templates.go @@ -156,6 +156,12 @@ const ( {{- end}} ) +var {{$enumTemplate.TypeName}}_AllValues = []{{$enumTemplate.TypeName}} { +{{- range $_, $value := .Values}} + {{valueName $value}}, +{{- end}} +} + func (e *{{$enumTemplate.TypeName}}) Scan(value interface{}) error { var enumValue string switch val := value.(type) { diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 9b3af50..09ce9d7 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -13,7 +13,7 @@ services: - ./testdata/init/postgres:/docker-entrypoint-initdb.d mysql: - image: mysql:8.0.27 + image: mysql:8.0 command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1'] restart: always environment: diff --git a/tests/postgres/generator_template_test.go b/tests/postgres/generator_template_test.go index 4d87295..a9a244b 100644 --- a/tests/postgres/generator_template_test.go +++ b/tests/postgres/generator_template_test.go @@ -171,6 +171,7 @@ func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) { mpaaRating := file2.Exists(t, defaultModelPath, "mpaa_rating_enum.go") require.Contains(t, mpaaRating, "type MpaaRatingEnum string") + require.Contains(t, mpaaRating, "MpaaRatingEnum_AllValues") } func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) { From eb57b8003f76b2e76f0edb2f9094d0cb85382872 Mon Sep 17 00:00:00 2001 From: Samir Faci Date: Wed, 18 Sep 2024 12:20:08 -0400 Subject: [PATCH 06/28] Migrating circleci images from legacy to update cimg namespace --- .circleci/config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index e634858..0fb458f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -16,7 +16,7 @@ jobs: POSTGRES_DB: jetdb PGPORT: 50901 - - image: circleci/mysql:8.0 + - image: cimg/mysql:8.0 command: [ --default-authentication-plugin=mysql_native_password ] environment: MYSQL_ROOT_PASSWORD: jet @@ -25,7 +25,7 @@ jobs: MYSQL_PASSWORD: jet MYSQL_TCP_PORT: 50902 - - image: circleci/mariadb:10.3 + - image: cimg/mariadb:10.3 command: [ '--default-authentication-plugin=mysql_native_password', '--port=50903' ] environment: MYSQL_ROOT_PASSWORD: jet From b25b2aa2133276497c923cf003b72eae092a93d5 Mon Sep 17 00:00:00 2001 From: Branislav Lazic Date: Fri, 20 Sep 2024 09:47:40 +0200 Subject: [PATCH 07/28] Add Postgres DATE_TRUNC function --- postgres/functions.go | 9 +++++++++ postgres/functions_test.go | 13 ++++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/postgres/functions.go b/postgres/functions.go index 7b6d1e1..75a6925 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -332,6 +332,15 @@ var LOCALTIMESTAMP = jet.LOCALTIMESTAMP // NOW returns current date and time var NOW = jet.NOW +// DATE_TRUNC returns the truncated date and time using optional time zone +func DATE_TRUNC(field unit, source Expression, timezone ...string) TimestampExpression { + if len(timezone) > 0 { + return jet.NewTimestampFunc("DATE_TRUNC", jet.FixedLiteral(unitToString(field)), source, jet.FixedLiteral(timezone[0])) + } + + return jet.NewTimestampFunc("DATE_TRUNC", jet.FixedLiteral(unitToString(field)), source) +} + // --------------- Conditional Expressions Functions -------------// // COALESCE function returns the first of its arguments that is not null. diff --git a/postgres/functions_test.go b/postgres/functions_test.go index 4190f70..d0081b0 100644 --- a/postgres/functions_test.go +++ b/postgres/functions_test.go @@ -1,6 +1,8 @@ package postgres -import "testing" +import ( + "testing" +) func TestROW(t *testing.T) { assertSerialize(t, ROW(SELECT(Int(1))), `ROW(( @@ -10,3 +12,12 @@ func TestROW(t *testing.T) { SELECT $2 ), $3)`) } + +func TestDATE_TRUNC(t *testing.T) { + assertSerialize(t, DATE_TRUNC(YEAR, NOW()), "DATE_TRUNC('YEAR', NOW())") + assertSerialize( + t, + DATE_TRUNC(DAY, NOW().ADD(INTERVAL(1, HOUR)), "Australia/Sydney"), + "DATE_TRUNC('DAY', NOW() + INTERVAL '1 HOUR', 'Australia/Sydney')", + ) +} From 4f80e0d36bb0d902a8b7927b1212e9b0b7bfaa70 Mon Sep 17 00:00:00 2001 From: Samir Faci Date: Sun, 22 Sep 2024 09:44:13 -0400 Subject: [PATCH 08/28] Addressing Code review comments --- generator/template/file_templates.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generator/template/file_templates.go b/generator/template/file_templates.go index fa16864..f3aa505 100644 --- a/generator/template/file_templates.go +++ b/generator/template/file_templates.go @@ -156,7 +156,7 @@ const ( {{- end}} ) -var {{$enumTemplate.TypeName}}_AllValues = []{{$enumTemplate.TypeName}} { +var {{$enumTemplate.TypeName}}AllValues = []{{$enumTemplate.TypeName}} { {{- range $_, $value := .Values}} {{valueName $value}}, {{- end}} From d3ce39f27582c8c20567451841f852847b0fd7a8 Mon Sep 17 00:00:00 2001 From: Samir Faci Date: Sun, 22 Sep 2024 16:04:11 -0400 Subject: [PATCH 09/28] Reverting circleci docker-compose --- .circleci/config.yml | 6 +++--- tests/docker-compose.yaml | 2 +- tests/postgres/generator_template_test.go | 4 +--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 0fb458f..e21f00a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -16,7 +16,7 @@ jobs: POSTGRES_DB: jetdb PGPORT: 50901 - - image: cimg/mysql:8.0 + - image: circleci/mysql:8.0.27 command: [ --default-authentication-plugin=mysql_native_password ] environment: MYSQL_ROOT_PASSWORD: jet @@ -25,7 +25,7 @@ jobs: MYSQL_PASSWORD: jet MYSQL_TCP_PORT: 50902 - - image: cimg/mariadb:10.3 + - image: circleci/mariadb:10.3 command: [ '--default-authentication-plugin=mysql_native_password', '--port=50903' ] environment: MYSQL_ROOT_PASSWORD: jet @@ -163,4 +163,4 @@ workflows: version: 2 build_and_test: jobs: - - build_and_tests + - build_and_tests \ No newline at end of file diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 09ce9d7..9b3af50 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -13,7 +13,7 @@ services: - ./testdata/init/postgres:/docker-entrypoint-initdb.d mysql: - image: mysql:8.0 + image: mysql:8.0.27 command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1'] restart: always environment: diff --git a/tests/postgres/generator_template_test.go b/tests/postgres/generator_template_test.go index a9a244b..65603cd 100644 --- a/tests/postgres/generator_template_test.go +++ b/tests/postgres/generator_template_test.go @@ -171,7 +171,7 @@ func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) { mpaaRating := file2.Exists(t, defaultModelPath, "mpaa_rating_enum.go") require.Contains(t, mpaaRating, "type MpaaRatingEnum string") - require.Contains(t, mpaaRating, "MpaaRatingEnum_AllValues") + require.Contains(t, mpaaRating, "MpaaRatingEnumAllValues") } func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) { @@ -269,7 +269,6 @@ func UseSchema(schema string) { FilmList = FilmList.FromSchema(schema) } `) - } func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) { @@ -367,7 +366,6 @@ func TestGeneratorTemplate_SQLBuilder_DefaultAlias(t *testing.T) { } func TestGeneratorTemplate_Model_AddTags(t *testing.T) { - err := postgres.Generate( tempTestDir, dbConnection, From dce5fd6552e9cd22c346b9aa83eabb7de3cbd518 Mon Sep 17 00:00:00 2001 From: Branislav Lazic Date: Mon, 23 Sep 2024 09:11:53 +0200 Subject: [PATCH 10/28] Add return type switching note --- postgres/functions.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/postgres/functions.go b/postgres/functions.go index 75a6925..7e71a96 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -332,7 +332,8 @@ var LOCALTIMESTAMP = jet.LOCALTIMESTAMP // NOW returns current date and time var NOW = jet.NOW -// DATE_TRUNC returns the truncated date and time using optional time zone +// DATE_TRUNC returns the truncated date and time using optional time zone. +// Use TimestampzExp if you need timestamp with time zone and IntervalExp if you need interval. func DATE_TRUNC(field unit, source Expression, timezone ...string) TimestampExpression { if len(timezone) > 0 { return jet.NewTimestampFunc("DATE_TRUNC", jet.FixedLiteral(unitToString(field)), source, jet.FixedLiteral(timezone[0])) From 929109622e7a3a6fa6125048cee531da14d71e2d Mon Sep 17 00:00:00 2001 From: Volker Lieber <42102008+VolkerLieber@users.noreply.github.com> Date: Mon, 23 Sep 2024 20:34:34 +0200 Subject: [PATCH 11/28] Include postgres comments in output #391 --- generator/postgres/query_set.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index abb21ba..d549cf1 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -57,6 +57,7 @@ func getColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]meta query := ` select attr.attname as "column.Name", + dsc.description as "column.Comment", exists( select 1 from pg_catalog.pg_index indx @@ -81,6 +82,7 @@ from pg_catalog.pg_attribute as attr join pg_catalog.pg_class as cls on cls.oid = attr.attrelid join pg_catalog.pg_namespace as ns on ns.oid = cls.relnamespace join pg_catalog.pg_type as tp on tp.oid = attr.atttypid + left join pg_catalog.pg_description as dsc on dsc.objoid = attr.attrelid and dsc.objsubid = attr.attnum where ns.nspname = $1 and cls.relname = $2 and From ff82eb5df7f1e971a1d32956a9bbf23f2d976884 Mon Sep 17 00:00:00 2001 From: Volker Lieber <42102008+VolkerLieber@users.noreply.github.com> Date: Tue, 24 Sep 2024 20:41:27 +0200 Subject: [PATCH 12/28] Implemented postgres table and enum comment generation --- generator/metadata/enum_meta_data.go | 17 +++++++++++++++-- generator/metadata/table_meta_data.go | 13 +++++++++++++ generator/postgres/query_set.go | 2 +- generator/template/file_templates.go | 4 ++++ 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/generator/metadata/enum_meta_data.go b/generator/metadata/enum_meta_data.go index 7aea3d6..733150b 100644 --- a/generator/metadata/enum_meta_data.go +++ b/generator/metadata/enum_meta_data.go @@ -1,7 +1,20 @@ package metadata +import "regexp" + // Enum metadata struct type Enum struct { - Name string `sql:"primary_key"` - Values []string + Name string `sql:"primary_key"` + Comment string + Values []string +} + +// GoLangComment returns enum comment without ascii control characters +func (e Enum) GoLangComment() string { + if e.Comment == "" { + return "" + } + + // remove ascii control characters from string + return regexp.MustCompile(`[[:cntrl:]]+`).ReplaceAllString(e.Comment, "") } diff --git a/generator/metadata/table_meta_data.go b/generator/metadata/table_meta_data.go index df9514e..95fd3c4 100644 --- a/generator/metadata/table_meta_data.go +++ b/generator/metadata/table_meta_data.go @@ -1,8 +1,11 @@ package metadata +import "regexp" + // Table metadata struct type Table struct { Name string `sql:"primary_key"` + Comment string Columns []Column } @@ -20,3 +23,13 @@ func (t Table) MutableColumns() []Column { return ret } + +// GoLangComment returns table comment without ascii control characters +func (t Table) GoLangComment() string { + if t.Comment == "" { + return "" + } + + // remove ascii control characters from string + return regexp.MustCompile(`[[:cntrl:]]+`).ReplaceAllString(t.Comment, "") +} diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index d549cf1..50129d1 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -14,7 +14,7 @@ type postgresQuerySet struct{} func (p postgresQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) ([]metadata.Table, error) { query := ` -SELECT table_name as "table.name" +SELECT table_name as "table.name", obj_description((table_schema||'.'||quote_ident(table_name))::regclass) as "table.comment" FROM information_schema.tables WHERE table_schema = $1 and table_type = $2 ORDER BY table_name; diff --git a/generator/template/file_templates.go b/generator/template/file_templates.go index f3aa505..45104db 100644 --- a/generator/template/file_templates.go +++ b/generator/template/file_templates.go @@ -26,6 +26,7 @@ import ( var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "{{tableTemplate.DefaultAlias}}") +{{if .Comment }} // {{.GoLangComment}} {{end}} type {{structImplName}} struct { {{dialect.PackageName}}.Table @@ -119,6 +120,7 @@ import ( {{end}} {{$modelTableTemplate := tableTemplate}} +{{if .Comment }} // {{.GoLangComment}} {{end}} type {{$modelTableTemplate.TypeName}} struct { {{- range .Columns}} {{- $field := structField .}} @@ -132,6 +134,7 @@ var enumSQLBuilderTemplate = `package {{package}} import "github.com/go-jet/jet/v2/{{dialect.PackageName}}" +{{if .Comment }} // {{.GoLangComment}} {{end}} var {{enumTemplate.InstanceName}} = &struct { {{- range $index, $value := .Values}} {{enumValueName $value}} {{dialect.PackageName}}.StringExpression @@ -148,6 +151,7 @@ var enumModelTemplate = `package {{package}} import "errors" +{{if .Comment }} // {{.GoLangComment}} {{end}} type {{$enumTemplate.TypeName}} string const ( From 0f21699a1f66480f2fa3093b18e736cf01bff7f8 Mon Sep 17 00:00:00 2001 From: Volker Lieber <42102008+VolkerLieber@users.noreply.github.com> Date: Wed, 25 Sep 2024 14:42:27 +0200 Subject: [PATCH 13/28] Implemented postgres enum comment generation --- generator/postgres/query_set.go | 1 + 1 file changed, 1 insertion(+) diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index 50129d1..4474c56 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -103,6 +103,7 @@ order by func (p postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.Enum, error) { query := ` SELECT t.typname as "enum.name", + obj_description(t.oid) as "enum.comment", e.enumlabel as "values" FROM pg_catalog.pg_type t JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid From 30e02dc9c07bbc0bf89f36d27f9c00b9b84b8123 Mon Sep 17 00:00:00 2001 From: Volker Lieber <42102008+VolkerLieber@users.noreply.github.com> Date: Wed, 25 Sep 2024 15:45:06 +0200 Subject: [PATCH 14/28] Implemented postgres generator comment tests https://github.com/go-jet/jet-test-data/pull/6 --- tests/postgres/generator_test.go | 111 +++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index fe1407f..87a7eee 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -604,6 +604,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go", "floats.go", "people.go", "components.go", "vulnerabilities.go", "all_types_materialized_view.go", "sample_ranges.go") testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent) + testutils.AssertFileContent(t, modelDir+"/link.go", linkModelContent) testutils.AssertFileNamesEqual(t, tableDir, "all_types.go", "employee.go", "link.go", "person.go", "person_phone.go", "weird_names_table.go", "user.go", "floats.go", "people.go", "table_use_schema.go", @@ -611,6 +612,8 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent) testutils.AssertFileContent(t, tableDir+"/sample_ranges.go", sampleRangeTableContent) + testutils.AssertFileContent(t, tableDir+"/link.go", linkTableContent) + testutils.AssertFileNamesEqual(t, viewDir, "all_types_materialized_view.go", "all_types_view.go", "view_use_schema.go") } @@ -650,6 +653,7 @@ package enum import "github.com/go-jet/jet/v2/postgres" +// Level enum var Level = &struct { Level1 postgres.StringExpression Level2 postgres.StringExpression @@ -747,6 +751,25 @@ type AllTypes struct { } ` +var linkModelContent = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package model + +// Link table +type Link struct { + ID int64 ` + "`sql:\"primary_key\"`" + ` // this is link id + URL string // link url + Name string // Unicode characters comment ₲鬼佬℧⇄↻ + Description *string // '"Z\%_ +} +` + var allTypesTableContent = ` // // Code generated by go-jet DO NOT EDIT. @@ -1103,3 +1126,91 @@ func newSampleRangesTableImpl(schemaName, tableName, alias string) sampleRangesT } } ` + +var linkTableContent = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package table + +import ( + "github.com/go-jet/jet/v2/postgres" +) + +var Link = newLinkTable("test_sample", "link", "") + +// Link table +type linkTable struct { + postgres.Table + + // Columns + ID postgres.ColumnInteger // this is link id + URL postgres.ColumnString // link url + Name postgres.ColumnString // Unicode characters comment ₲鬼佬℧⇄↻ + Description postgres.ColumnString // '"Z\%_ + + AllColumns postgres.ColumnList + MutableColumns postgres.ColumnList +} + +type LinkTable struct { + linkTable + + EXCLUDED linkTable +} + +// AS creates new LinkTable with assigned alias +func (a LinkTable) AS(alias string) *LinkTable { + return newLinkTable(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new LinkTable with assigned schema name +func (a LinkTable) FromSchema(schemaName string) *LinkTable { + return newLinkTable(schemaName, a.TableName(), a.Alias()) +} + +// WithPrefix creates new LinkTable with assigned table prefix +func (a LinkTable) WithPrefix(prefix string) *LinkTable { + return newLinkTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new LinkTable with assigned table suffix +func (a LinkTable) WithSuffix(suffix string) *LinkTable { + return newLinkTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + +func newLinkTable(schemaName, tableName, alias string) *LinkTable { + return &LinkTable{ + linkTable: newLinkTableImpl(schemaName, tableName, alias), + EXCLUDED: newLinkTableImpl("", "excluded", ""), + } +} + +func newLinkTableImpl(schemaName, tableName, alias string) linkTable { + var ( + IDColumn = postgres.IntegerColumn("id") + URLColumn = postgres.StringColumn("url") + NameColumn = postgres.StringColumn("name") + DescriptionColumn = postgres.StringColumn("description") + allColumns = postgres.ColumnList{IDColumn, URLColumn, NameColumn, DescriptionColumn} + mutableColumns = postgres.ColumnList{URLColumn, NameColumn, DescriptionColumn} + ) + + return linkTable{ + Table: postgres.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns + ID: IDColumn, + URL: URLColumn, + Name: NameColumn, + Description: DescriptionColumn, + + AllColumns: allColumns, + MutableColumns: mutableColumns, + } +} +` From 64884e496971236fe98fa10e9b0857ae15b127bf Mon Sep 17 00:00:00 2001 From: Volker Lieber <42102008+VolkerLieber@users.noreply.github.com> Date: Thu, 26 Sep 2024 14:17:21 +0200 Subject: [PATCH 15/28] Extracted golang comment format function --- generator/metadata/column_meta_data.go | 14 ------------- generator/metadata/enum_meta_data.go | 12 ------------ generator/metadata/table_meta_data.go | 12 ------------ generator/template/file_templates.go | 12 ++++++------ generator/template/format.go | 13 +++++++++++++ generator/template/format_test.go | 27 ++++++++++++++++++++++++++ generator/template/process.go | 4 ++++ 7 files changed, 50 insertions(+), 44 deletions(-) create mode 100644 generator/template/format.go create mode 100644 generator/template/format_test.go diff --git a/generator/metadata/column_meta_data.go b/generator/metadata/column_meta_data.go index 1502719..ecd61e2 100644 --- a/generator/metadata/column_meta_data.go +++ b/generator/metadata/column_meta_data.go @@ -1,9 +1,5 @@ package metadata -import ( - "regexp" -) - // Column struct type Column struct { Name string `sql:"primary_key"` @@ -15,16 +11,6 @@ type Column struct { Comment string } -// GoLangComment returns column comment without ascii control characters -func (c Column) GoLangComment() string { - if c.Comment == "" { - return "" - } - - // remove ascii control characters from string - return regexp.MustCompile(`[[:cntrl:]]+`).ReplaceAllString(c.Comment, "") -} - // DataTypeKind is database type kind(base, enum, user-defined, array) type DataTypeKind string diff --git a/generator/metadata/enum_meta_data.go b/generator/metadata/enum_meta_data.go index 733150b..8cce596 100644 --- a/generator/metadata/enum_meta_data.go +++ b/generator/metadata/enum_meta_data.go @@ -1,20 +1,8 @@ package metadata -import "regexp" - // Enum metadata struct type Enum struct { Name string `sql:"primary_key"` Comment string Values []string } - -// GoLangComment returns enum comment without ascii control characters -func (e Enum) GoLangComment() string { - if e.Comment == "" { - return "" - } - - // remove ascii control characters from string - return regexp.MustCompile(`[[:cntrl:]]+`).ReplaceAllString(e.Comment, "") -} diff --git a/generator/metadata/table_meta_data.go b/generator/metadata/table_meta_data.go index 95fd3c4..1d56bc0 100644 --- a/generator/metadata/table_meta_data.go +++ b/generator/metadata/table_meta_data.go @@ -1,7 +1,5 @@ package metadata -import "regexp" - // Table metadata struct type Table struct { Name string `sql:"primary_key"` @@ -23,13 +21,3 @@ func (t Table) MutableColumns() []Column { return ret } - -// GoLangComment returns table comment without ascii control characters -func (t Table) GoLangComment() string { - if t.Comment == "" { - return "" - } - - // remove ascii control characters from string - return regexp.MustCompile(`[[:cntrl:]]+`).ReplaceAllString(t.Comment, "") -} diff --git a/generator/template/file_templates.go b/generator/template/file_templates.go index 45104db..1538031 100644 --- a/generator/template/file_templates.go +++ b/generator/template/file_templates.go @@ -26,14 +26,14 @@ import ( var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "{{tableTemplate.DefaultAlias}}") -{{if .Comment }} // {{.GoLangComment}} {{end}} +{{golangComment .Comment}} type {{structImplName}} struct { {{dialect.PackageName}}.Table // Columns {{- range $i, $c := .Columns}} {{- $field := columnField $c}} - {{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}} {{- if $c.Comment }} // {{$c.GoLangComment}} {{end}} + {{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}} {{golangComment .Comment}} {{- end}} AllColumns {{dialect.PackageName}}.ColumnList @@ -120,11 +120,11 @@ import ( {{end}} {{$modelTableTemplate := tableTemplate}} -{{if .Comment }} // {{.GoLangComment}} {{end}} +{{golangComment .Comment}} type {{$modelTableTemplate.TypeName}} struct { {{- range .Columns}} {{- $field := structField .}} - {{$field.Name}} {{$field.Type.Name}} ` + "{{$field.TagsString}}" + ` {{- if .Comment }} // {{.GoLangComment}} {{end}} + {{$field.Name}} {{$field.Type.Name}} ` + "{{$field.TagsString}}" + ` {{golangComment .Comment}} {{- end}} } @@ -134,7 +134,7 @@ var enumSQLBuilderTemplate = `package {{package}} import "github.com/go-jet/jet/v2/{{dialect.PackageName}}" -{{if .Comment }} // {{.GoLangComment}} {{end}} +{{golangComment .Comment}} var {{enumTemplate.InstanceName}} = &struct { {{- range $index, $value := .Values}} {{enumValueName $value}} {{dialect.PackageName}}.StringExpression @@ -151,7 +151,7 @@ var enumModelTemplate = `package {{package}} import "errors" -{{if .Comment }} // {{.GoLangComment}} {{end}} +{{golangComment .Comment}} type {{$enumTemplate.TypeName}} string const ( diff --git a/generator/template/format.go b/generator/template/format.go new file mode 100644 index 0000000..bd88fb1 --- /dev/null +++ b/generator/template/format.go @@ -0,0 +1,13 @@ +package template + +import "regexp" + +// Returns the provided string as golang comment without ascii control characters +func formatGolangComment(comment string) string { + if len(comment) == 0 { + return "" + } + + // Format as colang comment and remove ascii control characters from string + return "// " + regexp.MustCompile(`[[:cntrl:]]+`).ReplaceAllString(comment, "") +} diff --git a/generator/template/format_test.go b/generator/template/format_test.go new file mode 100644 index 0000000..b43b61d --- /dev/null +++ b/generator/template/format_test.go @@ -0,0 +1,27 @@ +package template + +import "testing" + +func Test_formatGolangComment(t *testing.T) { + type args struct { + comment string + } + tests := []struct { + name string + args args + want string + }{ + {name: "Empty string", args: args{comment: ""}, want: ""}, + {name: "Non-empty string", args: args{comment: "This is a comment"}, want: "// This is a comment"}, + {name: "String with control characters", args: args{comment: "This is a comment with control characters \x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f and text after"}, want: "// This is a comment with control characters and text after"}, + {name: "String with escape characters", args: args{comment: "This is a comment with escape characters \n\r\t and text after"}, want: "// This is a comment with escape characters and text after"}, + {name: "String with unicode characters", args: args{comment: "This is a comment with unicode characters ₲鬼佬℧⇄↻ and text after"}, want: "// This is a comment with unicode characters ₲鬼佬℧⇄↻ and text after"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := formatGolangComment(tt.args.comment); got != tt.want { + t.Errorf("formatGoLangComment() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/generator/template/process.go b/generator/template/process.go index 3f3798a..5abef87 100644 --- a/generator/template/process.go +++ b/generator/template/process.go @@ -140,6 +140,7 @@ func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData [] "enumValueName": func(enumValue string) string { return enumTemplate.ValueName(enumValue) }, + "golangComment": formatGolangComment, }) if err != nil { return fmt.Errorf("failed to generete enum type %s: %w", enumTemplate.FileName, err) @@ -215,6 +216,7 @@ func processTableSQLBuilder(fileTypes, dirPath string, "insertedRowAlias": func() string { return insertedRowAlias(dialect) }, + "golangComment": formatGolangComment, }) if err != nil { return fmt.Errorf("failed to generate table sql builder type %s: %w", tableSQLBuilder.TypeName, err) @@ -307,6 +309,7 @@ func processTableModels(fileTypes, modelDirPath string, tablesMetaData []metadat "structField": func(columnMetaData metadata.Column) TableModelField { return tableTemplate.Field(columnMetaData) }, + "golangComment": formatGolangComment, }) if err != nil { return fmt.Errorf("failed to generate model type '%s': %w", tableMetaData.Name, err) @@ -347,6 +350,7 @@ func processEnumModels(modelDir string, enumsMetaData []metadata.Enum, modelTemp "valueName": func(value string) string { return enumTemplate.ValueName(value) }, + "golangComment": formatGolangComment, }) if err != nil { From b5f04ffea8c66718ee104fe1ee1ee6d8ef07aa0b Mon Sep 17 00:00:00 2001 From: Volker Lieber <42102008+VolkerLieber@users.noreply.github.com> Date: Thu, 26 Sep 2024 14:33:23 +0200 Subject: [PATCH 16/28] Improved postgres table comment generation --- generator/postgres/query_set.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index 4474c56..fc4135a 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -14,7 +14,7 @@ type postgresQuerySet struct{} func (p postgresQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) ([]metadata.Table, error) { query := ` -SELECT table_name as "table.name", obj_description((table_schema||'.'||quote_ident(table_name))::regclass) as "table.comment" +SELECT table_name as "table.name", obj_description((quote_ident(table_schema)||'.'||quote_ident(table_name))::regclass) as "table.comment" FROM information_schema.tables WHERE table_schema = $1 and table_type = $2 ORDER BY table_name; @@ -57,7 +57,7 @@ func getColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]meta query := ` select attr.attname as "column.Name", - dsc.description as "column.Comment", + col_description(attr.attrelid, attr.attnum) as "column.Comment", exists( select 1 from pg_catalog.pg_index indx @@ -82,7 +82,6 @@ from pg_catalog.pg_attribute as attr join pg_catalog.pg_class as cls on cls.oid = attr.attrelid join pg_catalog.pg_namespace as ns on ns.oid = cls.relnamespace join pg_catalog.pg_type as tp on tp.oid = attr.atttypid - left join pg_catalog.pg_description as dsc on dsc.objoid = attr.attrelid and dsc.objsubid = attr.attnum where ns.nspname = $1 and cls.relname = $2 and From 99be328e9dbeef5ab783271e8005ae627139a7c4 Mon Sep 17 00:00:00 2001 From: Samir Faci Date: Sat, 5 Oct 2024 12:32:03 -0400 Subject: [PATCH 17/28] Replacing several test util function with a generic version ChangeLog: - updated several test utils with a generic PtrOf - updated references using iotils (Deprecated) with os equivalent import. --- examples/quick-start/quick-start.go | 4 +- internal/testutils/test_utils.go | 80 +++-------------------------- tests/init/init.go | 3 +- tests/internal/utils/file/file.go | 5 +- tests/mysql/alltypes_test.go | 66 ++++++++++++------------ tests/mysql/insert_test.go | 5 +- tests/postgres/alltypes_test.go | 46 ++++++++--------- tests/postgres/chinook_db_test.go | 2 +- tests/postgres/sample_test.go | 14 ++--- tests/postgres/scan_test.go | 40 +++++++-------- tests/postgres/select_test.go | 28 +++++----- tests/sqlite/alltypes_test.go | 40 +++++++-------- tests/sqlite/insert_test.go | 2 +- tests/sqlite/sample_test.go | 2 +- tests/sqlite/select_test.go | 18 +++---- 15 files changed, 144 insertions(+), 211 deletions(-) diff --git a/examples/quick-start/quick-start.go b/examples/quick-start/quick-start.go index e453f51..b4707f6 100644 --- a/examples/quick-start/quick-start.go +++ b/examples/quick-start/quick-start.go @@ -4,7 +4,7 @@ import ( "database/sql" "encoding/json" "fmt" - "io/ioutil" + "os" _ "github.com/lib/pq" @@ -90,7 +90,7 @@ func main() { func jsonSave(path string, v interface{}) { jsonText, _ := json.MarshalIndent(v, "", "\t") - err := ioutil.WriteFile(path, jsonText, 0644) + err := os.WriteFile(path, jsonText, 0644) panicOnError(err) } diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 22c0d9c..c996077 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -12,7 +12,6 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "io/ioutil" "os" "path/filepath" "runtime" @@ -109,7 +108,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) { jsonText, _ := json.MarshalIndent(v, "", "\t") filePath := getFullPath(testRelativePath) - err := ioutil.WriteFile(filePath, jsonText, 0644) + err := os.WriteFile(filePath, jsonText, 0644) throw.OnError(err) } @@ -118,7 +117,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) { func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) { filePath := getFullPath(testRelativePath) - fileJSONData, err := ioutil.ReadFile(filePath) + fileJSONData, err := os.ReadFile(filePath) require.NoError(t, err) if runtime.GOOS == "windows" { @@ -245,7 +244,7 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest inter // AssertFileContent check if file content at filePath contains expectedContent text. func AssertFileContent(t *testing.T, filePath string, expectedContent string) { - enumFileData, err := ioutil.ReadFile(filePath) + enumFileData, err := os.ReadFile(filePath) require.NoError(t, err) @@ -254,7 +253,7 @@ func AssertFileContent(t *testing.T, filePath string, expectedContent string) { // AssertFileNamesEqual check if all filesInfos are contained in fileNames func AssertFileNamesEqual(t *testing.T, dirPath string, fileNames ...string) { - files, err := ioutil.ReadDir(dirPath) + files, err := os.ReadDir(dirPath) require.NoError(t, err) require.Equal(t, len(files), len(fileNames)) @@ -293,74 +292,9 @@ func printDiff(actual, expected interface{}, options ...cmp.Option) { fmt.Println(expected) } -// BoolPtr returns address of bool parameter -func BoolPtr(b bool) *bool { - return &b -} - -// Int8Ptr returns address of int8 parameter -func Int8Ptr(i int8) *int8 { - return &i -} - -// UInt8Ptr returns address of uint8 parameter -func UInt8Ptr(i uint8) *uint8 { - return &i -} - -// Int16Ptr returns address of int16 parameter -func Int16Ptr(i int16) *int16 { - return &i -} - -// UInt16Ptr returns address of uint16 parameter -func UInt16Ptr(i uint16) *uint16 { - return &i -} - -// Int32Ptr returns address of int32 parameter -func Int32Ptr(i int32) *int32 { - return &i -} - -// UInt32Ptr returns address of uint32 parameter -func UInt32Ptr(i uint32) *uint32 { - return &i -} - -// Int64Ptr returns address of int64 parameter -func Int64Ptr(i int64) *int64 { - return &i -} - -// UInt64Ptr returns address of uint64 parameter -func UInt64Ptr(i uint64) *uint64 { - return &i -} - -// StringPtr returns address of string parameter -func StringPtr(s string) *string { - return &s -} - -// TimePtr returns address of time.Time parameter -func TimePtr(t time.Time) *time.Time { - return &t -} - -// ByteArrayPtr returns address of []byte parameter -func ByteArrayPtr(arr []byte) *[]byte { - return &arr -} - -// Float32Ptr returns address of float32 parameter -func Float32Ptr(f float32) *float32 { - return &f -} - -// Float64Ptr returns address of float64 parameter -func Float64Ptr(f float64) *float64 { - return &f +// PtrOf returns the address of any given parameter +func PtrOf[T any](value T) *T { + return &value } // UUIDPtr returns address of uuid.UUID diff --git a/tests/init/init.go b/tests/init/init.go index 5a21ee7..10631f1 100644 --- a/tests/init/init.go +++ b/tests/init/init.go @@ -10,7 +10,6 @@ import ( "github.com/go-jet/jet/v2/generator/sqlite" "github.com/go-jet/jet/v2/internal/utils/errfmt" "github.com/go-jet/jet/v2/tests/internal/utils/repo" - "io/ioutil" "os" "os/exec" "strings" @@ -184,7 +183,7 @@ func initPostgresDB(dbType string, connectionString string) error { } func execFile(db *sql.DB, sqlFilePath string) error { - testSampleSql, err := ioutil.ReadFile(sqlFilePath) + testSampleSql, err := os.ReadFile(sqlFilePath) if err != nil { return fmt.Errorf("failed to read sql file - %s: %w", sqlFilePath, err) } diff --git a/tests/internal/utils/file/file.go b/tests/internal/utils/file/file.go index 6d08d22..fea0634 100644 --- a/tests/internal/utils/file/file.go +++ b/tests/internal/utils/file/file.go @@ -2,7 +2,6 @@ package file import ( "github.com/stretchr/testify/require" - "io/ioutil" "os" "path" "testing" @@ -11,7 +10,7 @@ import ( // Exists expects file to exist on path constructed from pathElems and returns content of the file func Exists(t *testing.T, pathElems ...string) (fileContent string) { modelFilePath := path.Join(pathElems...) - file, err := ioutil.ReadFile(modelFilePath) + file, err := os.ReadFile(modelFilePath) require.Nil(t, err) require.NotEmpty(t, file) return string(file) @@ -20,6 +19,6 @@ func Exists(t *testing.T, pathElems ...string) (fileContent string) { // NotExists expects file not to exist on path constructed from pathElems func NotExists(t *testing.T, pathElems ...string) { modelFilePath := path.Join(pathElems...) - _, err := ioutil.ReadFile(modelFilePath) + _, err := os.ReadFile(modelFilePath) require.True(t, os.IsNotExist(err)) } diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index c90f774..b3f6a55 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -1067,7 +1067,7 @@ func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) { var toInsert = model.AllTypes{ Boolean: false, - BooleanPtr: testutils.BoolPtr(true), + BooleanPtr: testutils.PtrOf(true), TinyInt: 1, UTinyInt: 2, SmallInt: 3, @@ -1078,53 +1078,53 @@ var toInsert = model.AllTypes{ UInteger: 8, BigInt: 9, UBigInt: 1122334455, - TinyIntPtr: testutils.Int8Ptr(11), - UTinyIntPtr: testutils.UInt8Ptr(22), - SmallIntPtr: testutils.Int16Ptr(33), - USmallIntPtr: testutils.UInt16Ptr(44), - MediumIntPtr: testutils.Int32Ptr(55), - UMediumIntPtr: testutils.UInt32Ptr(66), - IntegerPtr: testutils.Int32Ptr(77), - UIntegerPtr: testutils.UInt32Ptr(88), - BigIntPtr: testutils.Int64Ptr(99), - UBigIntPtr: testutils.UInt64Ptr(111), + TinyIntPtr: testutils.PtrOf(int8(11)), + UTinyIntPtr: testutils.PtrOf(uint8(22)), + SmallIntPtr: testutils.PtrOf(int16(33)), + USmallIntPtr: testutils.PtrOf(uint16(44)), + MediumIntPtr: testutils.PtrOf(int32(55)), + UMediumIntPtr: testutils.PtrOf(uint32(66)), + IntegerPtr: testutils.PtrOf(int32(77)), + UIntegerPtr: testutils.PtrOf(uint32(88)), + BigIntPtr: testutils.PtrOf(int64(99)), + UBigIntPtr: testutils.PtrOf(uint64(111)), Decimal: 11.22, - DecimalPtr: testutils.Float64Ptr(33.44), + DecimalPtr: testutils.PtrOf(33.44), Numeric: 55.66, - NumericPtr: testutils.Float64Ptr(77.88), + NumericPtr: testutils.PtrOf(77.88), Float: 99.00, - FloatPtr: testutils.Float64Ptr(11.22), + FloatPtr: testutils.PtrOf(11.22), Double: 33.44, - DoublePtr: testutils.Float64Ptr(55.66), + DoublePtr: testutils.PtrOf(55.66), Real: 77.88, - RealPtr: testutils.Float64Ptr(99.00), + RealPtr: testutils.PtrOf(99.00), Bit: "1", - BitPtr: testutils.StringPtr("0"), + BitPtr: testutils.PtrOf("0"), Time: time.Date(1, 1, 1, 10, 11, 12, 100, &time.Location{}), - TimePtr: testutils.TimePtr(time.Date(1, 1, 1, 10, 11, 12, 100, time.UTC)), + TimePtr: testutils.PtrOf(time.Date(1, 1, 1, 10, 11, 12, 100, time.UTC)), Date: time.Now(), - DatePtr: testutils.TimePtr(time.Now()), + DatePtr: testutils.PtrOf(time.Now()), DateTime: time.Now(), - DateTimePtr: testutils.TimePtr(time.Now()), + DateTimePtr: testutils.PtrOf(time.Now()), Timestamp: time.Now(), //TimestampPtr: testutils.TimePtr(time.Now()), // TODO: build fails for MariaDB Year: 2000, - YearPtr: testutils.Int16Ptr(2001), + YearPtr: testutils.PtrOf(int16(2001)), Char: "abcd", - CharPtr: testutils.StringPtr("absd"), + CharPtr: testutils.PtrOf("absd"), VarChar: "abcd", - VarCharPtr: testutils.StringPtr("absd"), + VarCharPtr: testutils.PtrOf("absd"), Binary: []byte("1010"), - BinaryPtr: testutils.ByteArrayPtr([]byte("100001")), + BinaryPtr: testutils.PtrOf([]byte("100001")), VarBinary: []byte("1010"), - VarBinaryPtr: testutils.ByteArrayPtr([]byte("100001")), + VarBinaryPtr: testutils.PtrOf([]byte("100001")), Blob: []byte("large file"), - BlobPtr: testutils.ByteArrayPtr([]byte("very large file")), + BlobPtr: testutils.PtrOf([]byte("very large file")), Text: "some text", - TextPtr: testutils.StringPtr("text"), + TextPtr: testutils.PtrOf("text"), Enum: model.AllTypesEnum_Value1, JSON: "{}", - JSONPtr: testutils.StringPtr(`{"a": 1}`), + JSONPtr: testutils.PtrOf(`{"a": 1}`), } var allTypesJson = ` @@ -1358,17 +1358,17 @@ func TestExactDecimals(t *testing.T) { Floats: model.Floats{ // overwritten by wrapped(floats) scope Numeric: 0.1, - NumericPtr: testutils.Float64Ptr(0.1), + NumericPtr: testutils.PtrOf(0.1), Decimal: 0.1, - DecimalPtr: testutils.Float64Ptr(0.1), + DecimalPtr: testutils.PtrOf(0.1), // not overwritten Float: 0.2, - FloatPtr: testutils.Float64Ptr(0.22), + FloatPtr: testutils.PtrOf(0.22), Double: 0.3, - DoublePtr: testutils.Float64Ptr(0.33), + DoublePtr: testutils.PtrOf(0.33), Real: 0.4, - RealPtr: testutils.Float64Ptr(0.44), + RealPtr: testutils.PtrOf(0.44), }, Numeric: decimal.RequireFromString("12.35"), NumericPtr: decimal.RequireFromString("56.79"), diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index b05c91d..f0655f2 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -7,6 +7,7 @@ import ( . "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/table" + "github.com/stretchr/testify/require" "math/rand" "testing" @@ -300,7 +301,7 @@ func TestInsertOnDuplicateKeyUpdateNEW(t *testing.T) { ID: randId, URL: "https://www.yahoo.com", Name: "Yahoo", - Description: testutils.StringPtr("web portal and search engine"), + Description: testutils.PtrOf("web portal and search engine"), }, }).AS_NEW(). ON_DUPLICATE_KEY_UPDATE( @@ -337,7 +338,7 @@ ON DUPLICATE KEY UPDATE id = (link.id + ?), ID: randId + 11, URL: "https://www.yahoo.com", Name: "Yahoo", - Description: testutils.StringPtr("web portal and search engine"), + Description: testutils.PtrOf("web portal and search engine"), }) }) } diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index d41feee..7894293 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -1305,32 +1305,32 @@ RETURNING all_types.json AS "all_types.json"; var moodSad = model.Mood_Sad var allTypesRow0 = model.AllTypes{ - SmallIntPtr: testutils.Int16Ptr(14), + SmallIntPtr: testutils.PtrOf(int16(14)), SmallInt: 14, - IntegerPtr: testutils.Int32Ptr(300), + IntegerPtr: testutils.PtrOf(int32(300)), Integer: 300, - BigIntPtr: testutils.Int64Ptr(50000), + BigIntPtr: testutils.PtrOf(int64(50000)), BigInt: 5000, - DecimalPtr: testutils.Float64Ptr(1.11), + DecimalPtr: testutils.PtrOf(1.11), Decimal: 1.11, - NumericPtr: testutils.Float64Ptr(2.22), + NumericPtr: testutils.PtrOf(2.22), Numeric: 2.22, - RealPtr: testutils.Float32Ptr(5.55), + RealPtr: testutils.PtrOf(float32(5.55)), Real: 5.55, - DoublePrecisionPtr: testutils.Float64Ptr(11111111.22), + DoublePrecisionPtr: testutils.PtrOf(11111111.22), DoublePrecision: 11111111.22, Smallserial: 1, Serial: 1, Bigserial: 1, //MoneyPtr: nil, //Money: - VarCharPtr: testutils.StringPtr("ABBA"), + VarCharPtr: testutils.PtrOf("ABBA"), VarChar: "ABBA", - CharPtr: testutils.StringPtr("JOHN "), + CharPtr: testutils.PtrOf("JOHN "), Char: "JOHN ", - TextPtr: testutils.StringPtr("Some text"), + TextPtr: testutils.PtrOf("Some text"), Text: "Some text", - ByteaPtr: testutils.ByteArrayPtr([]byte("bytea")), + ByteaPtr: testutils.PtrOf([]byte("bytea")), Bytea: []byte("bytea"), TimestampzPtr: testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), Timestampz: *testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), @@ -1342,31 +1342,31 @@ var allTypesRow0 = model.AllTypes{ Timez: *testutils.TimeWithTimeZone("04:05:06 -0800"), TimePtr: testutils.TimeWithoutTimeZone("04:05:06"), Time: *testutils.TimeWithoutTimeZone("04:05:06"), - IntervalPtr: testutils.StringPtr("3 days 04:05:06"), + IntervalPtr: testutils.PtrOf("3 days 04:05:06"), Interval: "3 days 04:05:06", - BooleanPtr: testutils.BoolPtr(true), + BooleanPtr: testutils.PtrOf(true), Boolean: false, - PointPtr: testutils.StringPtr("(2,3)"), - BitPtr: testutils.StringPtr("101"), + PointPtr: testutils.PtrOf("(2,3)"), + BitPtr: testutils.PtrOf("101"), Bit: "101", - BitVaryingPtr: testutils.StringPtr("101111"), + BitVaryingPtr: testutils.PtrOf("101111"), BitVarying: "101111", - TsvectorPtr: testutils.StringPtr("'supernova':1"), + TsvectorPtr: testutils.PtrOf("'supernova':1"), Tsvector: "'supernova':1", UUIDPtr: testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), UUID: uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), - XMLPtr: testutils.StringPtr("abc"), + XMLPtr: testutils.PtrOf("abc"), XML: "abc", - JSONPtr: testutils.StringPtr(`{"a": 1, "b": 3}`), + JSONPtr: testutils.PtrOf(`{"a": 1, "b": 3}`), JSON: `{"a": 1, "b": 3}`, - JsonbPtr: testutils.StringPtr(`{"a": 1, "b": 3}`), + JsonbPtr: testutils.PtrOf(`{"a": 1, "b": 3}`), Jsonb: `{"a": 1, "b": 3}`, - IntegerArrayPtr: testutils.StringPtr("{1,2,3}"), + IntegerArrayPtr: testutils.PtrOf("{1,2,3}"), IntegerArray: "{1,2,3}", - TextArrayPtr: testutils.StringPtr("{breakfast,consulting}"), + TextArrayPtr: testutils.PtrOf("{breakfast,consulting}"), TextArray: "{breakfast,consulting}", JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`, - TextMultiDimArrayPtr: testutils.StringPtr("{{meeting,lunch},{training,presentation}}"), + TextMultiDimArrayPtr: testutils.PtrOf("{{meeting,lunch},{training,presentation}}"), TextMultiDimArray: "{{meeting,lunch},{training,presentation}}", MoodPtr: &moodSad, Mood: model.Mood_Happy, diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index afc4842..7c7f957 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -455,7 +455,7 @@ FROM ( require.Len(t, dest, 275) require.Equal(t, dest[0].Artist1.Artist, model.Artist{ ArtistId: 1, - Name: testutils.StringPtr("AC/DC"), + Name: testutils.PtrOf("AC/DC"), }) require.Equal(t, dest[0].Artist1.CustomColumn1, "custom_column_1") require.Equal(t, dest[0].Artist1.CustomColumn2, "custom_column_2") diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 1df72dd..8d12d0c 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -63,15 +63,15 @@ func TestExactDecimals(t *testing.T) { Floats: model.Floats{ // overwritten by wrapped(floats) scope Numeric: 0.1, - NumericPtr: testutils.Float64Ptr(0.1), + NumericPtr: testutils.PtrOf(0.1), Decimal: 0.1, - DecimalPtr: testutils.Float64Ptr(0.1), + DecimalPtr: testutils.PtrOf(0.1), // not overwritten Real: 0.4, - RealPtr: testutils.Float32Ptr(0.44), + RealPtr: testutils.PtrOf(float32(0.44)), Double: 0.3, - DoublePtr: testutils.Float64Ptr(0.33), + DoublePtr: testutils.PtrOf(0.33), }, Numeric: decimal.RequireFromString("0.1234567890123456789"), NumericPtr: decimal.RequireFromString("1.1111111111111111111"), @@ -378,7 +378,7 @@ ORDER BY employee.employee_id; FirstName: "Salley", LastName: "Lester", EmploymentDate: testutils.TimestampWithTimeZone("1999-01-08 04:05:06 +0100 CET", 1), - ManagerID: testutils.Int32Ptr(3), + ManagerID: testutils.PtrOf(int32(3)), }) } @@ -420,7 +420,7 @@ FROM test_sample."WEIRD NAMES TABLE"; WeirdColumnName5: "Doe", WeirdColumnName6: "Doe", WeirdColumnName7: "Doe", - Weirdcolumnname8: testutils.StringPtr("Doe"), + Weirdcolumnname8: testutils.PtrOf("Doe"), WeirdColName9: "Doe", WeirdColuName10: "Doe", WeirdColuName11: "Doe", @@ -518,7 +518,7 @@ func TestMutableColumnsExcludeGeneratedColumn(t *testing.T) { ).MODEL( model.People{ PeopleName: "Dario", - PeopleHeightCm: testutils.Float64Ptr(120), + PeopleHeightCm: testutils.PtrOf(120.0), }, ).RETURNING( People.MutableColumns, diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index 24b5949..b71f49a 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -93,7 +93,7 @@ func TestScanToValidDestination(t *testing.T) { err := oneInventoryQuery.Query(db, &dest) require.NoError(t, err) - require.Equal(t, dest[0], testutils.Int32Ptr(1)) + require.Equal(t, dest[0], testutils.PtrOf(int32(1))) }) t.Run("NULL to integer", func(t *testing.T) { @@ -530,10 +530,10 @@ func TestScanToSlice(t *testing.T) { require.NoError(t, err) require.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0].Film, film1) - testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{testutils.Int32Ptr(1), testutils.Int32Ptr(2), testutils.Int32Ptr(3), testutils.Int32Ptr(4), - testutils.Int32Ptr(5), testutils.Int32Ptr(6), testutils.Int32Ptr(7), testutils.Int32Ptr(8)}) + testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{testutils.PtrOf(int32(1)), testutils.PtrOf(int32(2)), testutils.PtrOf(int32(3)), testutils.PtrOf(int32(4)), + testutils.PtrOf(int32(5)), testutils.PtrOf(int32(6)), testutils.PtrOf(int32(7)), testutils.PtrOf(int32(8))}) testutils.AssertDeepEqual(t, dest[1].Film, film2) - testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{testutils.Int32Ptr(9), testutils.Int32Ptr(10)}) + testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{testutils.PtrOf(int32(9)), testutils.PtrOf(int32(10))}) }) t.Run("complex struct 1", func(t *testing.T) { @@ -1076,10 +1076,10 @@ VALUES (1234, 0, 'Joe', '', NULL, 1, TRUE, '2020-02-02 10:00:00Z', NULL, 1); var address256 = model.Address{ AddressID: 256, Address: "1497 Yuzhou Drive", - Address2: testutils.StringPtr(""), + Address2: testutils.PtrOf(""), District: "England", CityID: 312, - PostalCode: testutils.StringPtr("3433"), + PostalCode: testutils.PtrOf("3433"), Phone: "246810237916", LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 09:45:30", 0), } @@ -1087,10 +1087,10 @@ var address256 = model.Address{ var addres517 = model.Address{ AddressID: 517, Address: "548 Uruapan Street", - Address2: testutils.StringPtr(""), + Address2: testutils.PtrOf(""), District: "Ontario", CityID: 312, - PostalCode: testutils.StringPtr("35653"), + PostalCode: testutils.PtrOf("35653"), Phone: "879347453467", LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 09:45:30", 0), } @@ -1100,12 +1100,12 @@ var customer256 = model.Customer{ StoreID: 2, FirstName: "Mattie", LastName: "Hoffman", - Email: testutils.StringPtr("mattie.hoffman@sakilacustomer.org"), + Email: testutils.PtrOf("mattie.hoffman@sakilacustomer.org"), AddressID: 256, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 0), - Active: testutils.Int32Ptr(1), + Active: testutils.PtrOf(int32(1)), } var customer512 = model.Customer{ @@ -1113,12 +1113,12 @@ var customer512 = model.Customer{ StoreID: 1, FirstName: "Cecil", LastName: "Vines", - Email: testutils.StringPtr("cecil.vines@sakilacustomer.org"), + Email: testutils.PtrOf("cecil.vines@sakilacustomer.org"), AddressID: 517, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 0), - Active: testutils.Int32Ptr(1), + Active: testutils.PtrOf(int32(1)), } var countryUk = model.Country{ @@ -1151,32 +1151,32 @@ var inventory2 = model.Inventory{ var film1 = model.Film{ FilmID: 1, Title: "Academy Dinosaur", - Description: testutils.StringPtr("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"), - ReleaseYear: testutils.Int32Ptr(2006), + Description: testutils.PtrOf("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"), + ReleaseYear: testutils.PtrOf(int32(2006)), LanguageID: 1, RentalDuration: 6, RentalRate: 0.99, - Length: testutils.Int16Ptr(86), + Length: testutils.PtrOf(int16(86)), ReplacementCost: 20.99, Rating: &pgRating, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: testutils.StringPtr("{\"Deleted Scenes\",\"Behind the Scenes\"}"), + SpecialFeatures: testutils.PtrOf("{\"Deleted Scenes\",\"Behind the Scenes\"}"), Fulltext: "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", } var film2 = model.Film{ FilmID: 2, Title: "Ace Goldfinger", - Description: testutils.StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), - ReleaseYear: testutils.Int32Ptr(2006), + Description: testutils.PtrOf("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), + ReleaseYear: testutils.PtrOf(int32(2006)), LanguageID: 1, RentalDuration: 3, RentalRate: 4.99, - Length: testutils.Int16Ptr(48), + Length: testutils.PtrOf(int16(48)), ReplacementCost: 12.99, Rating: &gRating, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: testutils.StringPtr(`{Trailers,"Deleted Scenes"}`), + SpecialFeatures: testutils.PtrOf(`{Trailers,"Deleted Scenes"}`), Fulltext: `'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14`, } diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 048db14..94bd68d 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1828,16 +1828,16 @@ ORDER BY film.film_id ASC; testutils.AssertDeepEqual(t, maxRentalRateFilms[0], model.Film{ FilmID: 2, Title: "Ace Goldfinger", - Description: testutils.StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), - ReleaseYear: testutils.Int32Ptr(2006), + Description: testutils.PtrOf("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), + ReleaseYear: testutils.PtrOf(int32(2006)), LanguageID: 1, RentalRate: 4.99, - Length: testutils.Int16Ptr(48), + Length: testutils.PtrOf(int16(48)), ReplacementCost: 12.99, Rating: &gRating, RentalDuration: 3, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: testutils.StringPtr("{Trailers,\"Deleted Scenes\"}"), + SpecialFeatures: testutils.PtrOf("{Trailers,\"Deleted Scenes\"}"), Fulltext: "'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14", }) } @@ -2286,11 +2286,11 @@ ORDER BY customer_payment_sum.amount_sum ASC; FirstName: "Brian", LastName: "Wyman", AddressID: 323, - Email: testutils.StringPtr("brian.wyman@sakilacustomer.org"), + Email: testutils.PtrOf("brian.wyman@sakilacustomer.org"), Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.Int32Ptr(1), + Active: testutils.PtrOf(int32(1)), }) require.Equal(t, customersWithAmounts[0].AmountSum, 27.93) @@ -3133,8 +3133,8 @@ func TestDynamicCondition(t *testing.T) { Active *bool } - request.CustomerID = testutils.Int64Ptr(1) - request.Active = testutils.BoolPtr(true) + request.CustomerID = testutils.PtrOf(int64(1)) + request.Active = testutils.PtrOf(true) // ... @@ -3894,12 +3894,12 @@ var customer0 = model.Customer{ StoreID: 1, FirstName: "Mary", LastName: "Smith", - Email: testutils.StringPtr("mary.smith@sakilacustomer.org"), + Email: testutils.PtrOf("mary.smith@sakilacustomer.org"), AddressID: 5, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.Int32Ptr(1), + Active: testutils.PtrOf(int32(1)), } var customer1 = model.Customer{ @@ -3907,12 +3907,12 @@ var customer1 = model.Customer{ StoreID: 1, FirstName: "Patricia", LastName: "Johnson", - Email: testutils.StringPtr("patricia.johnson@sakilacustomer.org"), + Email: testutils.PtrOf("patricia.johnson@sakilacustomer.org"), AddressID: 6, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.Int32Ptr(1), + Active: testutils.PtrOf(int32(1)), } var lastCustomer = model.Customer{ @@ -3920,10 +3920,10 @@ var lastCustomer = model.Customer{ StoreID: 2, FirstName: "Austin", LastName: "Cintron", - Email: testutils.StringPtr("austin.cintron@sakilacustomer.org"), + Email: testutils.PtrOf("austin.cintron@sakilacustomer.org"), AddressID: 605, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.Int32Ptr(1), + Active: testutils.PtrOf(int32(1)), } diff --git a/tests/sqlite/alltypes_test.go b/tests/sqlite/alltypes_test.go index 58a9da6..63d1835 100644 --- a/tests/sqlite/alltypes_test.go +++ b/tests/sqlite/alltypes_test.go @@ -153,43 +153,43 @@ func TestAllTypesInsert(t *testing.T) { var toInsert = model.AllTypes{ Boolean: false, - BooleanPtr: testutils.BoolPtr(true), + BooleanPtr: testutils.PtrOf(true), TinyInt: 1, SmallInt: 3, MediumInt: 5, Integer: 7, BigInt: 9, - TinyIntPtr: testutils.Int8Ptr(11), - SmallIntPtr: testutils.Int16Ptr(33), - MediumIntPtr: testutils.Int32Ptr(55), - IntegerPtr: testutils.Int32Ptr(77), - BigIntPtr: testutils.Int64Ptr(99), + TinyIntPtr: testutils.PtrOf(int8(11)), + SmallIntPtr: testutils.PtrOf(int16(33)), + MediumIntPtr: testutils.PtrOf(int32(55)), + IntegerPtr: testutils.PtrOf(int32(77)), + BigIntPtr: testutils.PtrOf(int64(99)), Decimal: 11.22, - DecimalPtr: testutils.Float64Ptr(33.44), + DecimalPtr: testutils.PtrOf(33.44), Numeric: 55.66, - NumericPtr: testutils.Float64Ptr(77.88), + NumericPtr: testutils.PtrOf(77.88), Float: 99.00, - FloatPtr: testutils.Float64Ptr(11.22), + FloatPtr: testutils.PtrOf(11.22), Double: 33.44, - DoublePtr: testutils.Float64Ptr(55.66), + DoublePtr: testutils.PtrOf(55.66), Real: 77.88, - RealPtr: testutils.Float32Ptr(99.00), + RealPtr: testutils.PtrOf(float32(99.00)), Time: time.Date(1, 1, 1, 1, 1, 1, 10, time.UTC), - TimePtr: testutils.TimePtr(time.Date(2, 2, 2, 2, 2, 2, 200, time.UTC)), + TimePtr: testutils.PtrOf(time.Date(2, 2, 2, 2, 2, 2, 200, time.UTC)), Date: time.Now(), - DatePtr: testutils.TimePtr(time.Now()), + DatePtr: testutils.PtrOf(time.Now()), DateTime: time.Now(), - DateTimePtr: testutils.TimePtr(time.Now()), + DateTimePtr: testutils.PtrOf(time.Now()), Timestamp: time.Now(), - TimestampPtr: testutils.TimePtr(time.Now()), + TimestampPtr: testutils.PtrOf(time.Now()), Char: "abcd", - CharPtr: testutils.StringPtr("absd"), + CharPtr: testutils.PtrOf("absd"), VarChar: "abcd", - VarCharPtr: testutils.StringPtr("absd"), + VarCharPtr: testutils.PtrOf("absd"), Blob: []byte("large file"), - BlobPtr: testutils.ByteArrayPtr([]byte("very large file")), + BlobPtr: testutils.PtrOf([]byte("very large file")), Text: "some text", - TextPtr: testutils.StringPtr("text"), + TextPtr: testutils.PtrOf("text"), } func TestUUID(t *testing.T) { @@ -659,7 +659,7 @@ func TestExactDecimals(t *testing.T) { // not overwritten Numeric: "6.7", - NumericPtr: testutils.StringPtr("7.7"), + NumericPtr: testutils.PtrOf("7.7"), }, Decimal: decimal.RequireFromString("91.23"), DecimalPtr: decimal.RequireFromString("45.67"), diff --git a/tests/sqlite/insert_test.go b/tests/sqlite/insert_test.go index 776d546..5e9c7b2 100644 --- a/tests/sqlite/insert_test.go +++ b/tests/sqlite/insert_test.go @@ -49,7 +49,7 @@ VALUES (?, ?, ?, ?), ID: 101, URL: "http://www.google.com", Name: "Google", - Description: testutils.StringPtr("Search engine"), + Description: testutils.PtrOf("Search engine"), }) testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ ID: 102, diff --git a/tests/sqlite/sample_test.go b/tests/sqlite/sample_test.go index 4775eeb..7349d76 100644 --- a/tests/sqlite/sample_test.go +++ b/tests/sqlite/sample_test.go @@ -54,7 +54,7 @@ WHERE people.people_id = ?; ).MODEL( model.People{ PeopleName: "Dario", - PeopleHeightCm: testutils.Float64Ptr(190), + PeopleHeightCm: testutils.PtrOf(190.0), }, ).RETURNING( People.AllColumns, diff --git a/tests/sqlite/select_test.go b/tests/sqlite/select_test.go index 63d43a9..da4fa71 100644 --- a/tests/sqlite/select_test.go +++ b/tests/sqlite/select_test.go @@ -846,15 +846,15 @@ func TestSimpleView(t *testing.T) { require.Equal(t, len(dest), 10) require.Equal(t, dest[2], model.CustomerList{ - ID: testutils.Int32Ptr(3), - Name: testutils.StringPtr("LINDA WILLIAMS"), - Address: testutils.StringPtr("692 Joliet Street"), - ZipCode: testutils.StringPtr("83579"), - Phone: testutils.StringPtr(" "), - City: testutils.StringPtr("Athenai"), - Country: testutils.StringPtr("Greece"), - Notes: testutils.StringPtr("active"), - Sid: testutils.Int32Ptr(1), + ID: testutils.PtrOf(int32(3)), + Name: testutils.PtrOf("LINDA WILLIAMS"), + Address: testutils.PtrOf("692 Joliet Street"), + ZipCode: testutils.PtrOf("83579"), + Phone: testutils.PtrOf(" "), + City: testutils.PtrOf("Athenai"), + Country: testutils.PtrOf("Greece"), + Notes: testutils.PtrOf("active"), + Sid: testutils.PtrOf(int32(1)), }) } From c2703558d7b8aa4ecf5cc51fbdd5e7285cf46327 Mon Sep 17 00:00:00 2001 From: Samir Faci Date: Sun, 6 Oct 2024 09:04:10 -0400 Subject: [PATCH 18/28] Moving PtrOf to package internal/ptr --- internal/testutils/test_utils.go | 5 --- internal/utils/ptr/ptr.go | 6 +++ tests/mysql/alltypes_test.go | 67 ++++++++++++++++--------------- tests/mysql/insert_test.go | 5 ++- tests/postgres/alltypes_test.go | 47 +++++++++++----------- tests/postgres/chinook_db_test.go | 3 +- tests/postgres/sample_test.go | 15 +++---- tests/postgres/scan_test.go | 41 ++++++++++--------- tests/postgres/select_test.go | 29 ++++++------- tests/sqlite/alltypes_test.go | 41 ++++++++++--------- tests/sqlite/insert_test.go | 3 +- tests/sqlite/sample_test.go | 3 +- tests/sqlite/select_test.go | 19 ++++----- 13 files changed, 148 insertions(+), 136 deletions(-) create mode 100644 internal/utils/ptr/ptr.go diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index c996077..97f90d6 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -292,11 +292,6 @@ func printDiff(actual, expected interface{}, options ...cmp.Option) { fmt.Println(expected) } -// PtrOf returns the address of any given parameter -func PtrOf[T any](value T) *T { - return &value -} - // UUIDPtr returns address of uuid.UUID func UUIDPtr(u string) *uuid.UUID { newUUID := uuid.MustParse(u) diff --git a/internal/utils/ptr/ptr.go b/internal/utils/ptr/ptr.go new file mode 100644 index 0000000..26f2688 --- /dev/null +++ b/internal/utils/ptr/ptr.go @@ -0,0 +1,6 @@ +package ptr + +// Of returns the address of any given parameter +func Of[T any](value T) *T { + return &value +} diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index b3f6a55..d5702bd 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -1,6 +1,7 @@ package mysql import ( + "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/shopspring/decimal" "github.com/stretchr/testify/require" "strings" @@ -1067,7 +1068,7 @@ func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) { var toInsert = model.AllTypes{ Boolean: false, - BooleanPtr: testutils.PtrOf(true), + BooleanPtr: ptr.Of(true), TinyInt: 1, UTinyInt: 2, SmallInt: 3, @@ -1078,53 +1079,53 @@ var toInsert = model.AllTypes{ UInteger: 8, BigInt: 9, UBigInt: 1122334455, - TinyIntPtr: testutils.PtrOf(int8(11)), - UTinyIntPtr: testutils.PtrOf(uint8(22)), - SmallIntPtr: testutils.PtrOf(int16(33)), - USmallIntPtr: testutils.PtrOf(uint16(44)), - MediumIntPtr: testutils.PtrOf(int32(55)), - UMediumIntPtr: testutils.PtrOf(uint32(66)), - IntegerPtr: testutils.PtrOf(int32(77)), - UIntegerPtr: testutils.PtrOf(uint32(88)), - BigIntPtr: testutils.PtrOf(int64(99)), - UBigIntPtr: testutils.PtrOf(uint64(111)), + TinyIntPtr: ptr.Of(int8(11)), + UTinyIntPtr: ptr.Of(uint8(22)), + SmallIntPtr: ptr.Of(int16(33)), + USmallIntPtr: ptr.Of(uint16(44)), + MediumIntPtr: ptr.Of(int32(55)), + UMediumIntPtr: ptr.Of(uint32(66)), + IntegerPtr: ptr.Of(int32(77)), + UIntegerPtr: ptr.Of(uint32(88)), + BigIntPtr: ptr.Of(int64(99)), + UBigIntPtr: ptr.Of(uint64(111)), Decimal: 11.22, - DecimalPtr: testutils.PtrOf(33.44), + DecimalPtr: ptr.Of(33.44), Numeric: 55.66, - NumericPtr: testutils.PtrOf(77.88), + NumericPtr: ptr.Of(77.88), Float: 99.00, - FloatPtr: testutils.PtrOf(11.22), + FloatPtr: ptr.Of(11.22), Double: 33.44, - DoublePtr: testutils.PtrOf(55.66), + DoublePtr: ptr.Of(55.66), Real: 77.88, - RealPtr: testutils.PtrOf(99.00), + RealPtr: ptr.Of(99.00), Bit: "1", - BitPtr: testutils.PtrOf("0"), + BitPtr: ptr.Of("0"), Time: time.Date(1, 1, 1, 10, 11, 12, 100, &time.Location{}), - TimePtr: testutils.PtrOf(time.Date(1, 1, 1, 10, 11, 12, 100, time.UTC)), + TimePtr: ptr.Of(time.Date(1, 1, 1, 10, 11, 12, 100, time.UTC)), Date: time.Now(), - DatePtr: testutils.PtrOf(time.Now()), + DatePtr: ptr.Of(time.Now()), DateTime: time.Now(), - DateTimePtr: testutils.PtrOf(time.Now()), + DateTimePtr: ptr.Of(time.Now()), Timestamp: time.Now(), //TimestampPtr: testutils.TimePtr(time.Now()), // TODO: build fails for MariaDB Year: 2000, - YearPtr: testutils.PtrOf(int16(2001)), + YearPtr: ptr.Of(int16(2001)), Char: "abcd", - CharPtr: testutils.PtrOf("absd"), + CharPtr: ptr.Of("absd"), VarChar: "abcd", - VarCharPtr: testutils.PtrOf("absd"), + VarCharPtr: ptr.Of("absd"), Binary: []byte("1010"), - BinaryPtr: testutils.PtrOf([]byte("100001")), + BinaryPtr: ptr.Of([]byte("100001")), VarBinary: []byte("1010"), - VarBinaryPtr: testutils.PtrOf([]byte("100001")), + VarBinaryPtr: ptr.Of([]byte("100001")), Blob: []byte("large file"), - BlobPtr: testutils.PtrOf([]byte("very large file")), + BlobPtr: ptr.Of([]byte("very large file")), Text: "some text", - TextPtr: testutils.PtrOf("text"), + TextPtr: ptr.Of("text"), Enum: model.AllTypesEnum_Value1, JSON: "{}", - JSONPtr: testutils.PtrOf(`{"a": 1}`), + JSONPtr: ptr.Of(`{"a": 1}`), } var allTypesJson = ` @@ -1358,17 +1359,17 @@ func TestExactDecimals(t *testing.T) { Floats: model.Floats{ // overwritten by wrapped(floats) scope Numeric: 0.1, - NumericPtr: testutils.PtrOf(0.1), + NumericPtr: ptr.Of(0.1), Decimal: 0.1, - DecimalPtr: testutils.PtrOf(0.1), + DecimalPtr: ptr.Of(0.1), // not overwritten Float: 0.2, - FloatPtr: testutils.PtrOf(0.22), + FloatPtr: ptr.Of(0.22), Double: 0.3, - DoublePtr: testutils.PtrOf(0.33), + DoublePtr: ptr.Of(0.33), Real: 0.4, - RealPtr: testutils.PtrOf(0.44), + RealPtr: ptr.Of(0.44), }, Numeric: decimal.RequireFromString("12.35"), NumericPtr: decimal.RequireFromString("56.79"), diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index f0655f2..8874456 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/internal/utils/ptr" . "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/table" @@ -301,7 +302,7 @@ func TestInsertOnDuplicateKeyUpdateNEW(t *testing.T) { ID: randId, URL: "https://www.yahoo.com", Name: "Yahoo", - Description: testutils.PtrOf("web portal and search engine"), + Description: ptr.Of("web portal and search engine"), }, }).AS_NEW(). ON_DUPLICATE_KEY_UPDATE( @@ -338,7 +339,7 @@ ON DUPLICATE KEY UPDATE id = (link.id + ?), ID: randId + 11, URL: "https://www.yahoo.com", Name: "Yahoo", - Description: testutils.PtrOf("web portal and search engine"), + Description: ptr.Of("web portal and search engine"), }) }) } diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 7894293..2f1be14 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -2,6 +2,7 @@ package postgres import ( "database/sql" + "github.com/go-jet/jet/v2/internal/utils/ptr" "testing" "time" @@ -1305,32 +1306,32 @@ RETURNING all_types.json AS "all_types.json"; var moodSad = model.Mood_Sad var allTypesRow0 = model.AllTypes{ - SmallIntPtr: testutils.PtrOf(int16(14)), + SmallIntPtr: ptr.Of(int16(14)), SmallInt: 14, - IntegerPtr: testutils.PtrOf(int32(300)), + IntegerPtr: ptr.Of(int32(300)), Integer: 300, - BigIntPtr: testutils.PtrOf(int64(50000)), + BigIntPtr: ptr.Of(int64(50000)), BigInt: 5000, - DecimalPtr: testutils.PtrOf(1.11), + DecimalPtr: ptr.Of(1.11), Decimal: 1.11, - NumericPtr: testutils.PtrOf(2.22), + NumericPtr: ptr.Of(2.22), Numeric: 2.22, - RealPtr: testutils.PtrOf(float32(5.55)), + RealPtr: ptr.Of(float32(5.55)), Real: 5.55, - DoublePrecisionPtr: testutils.PtrOf(11111111.22), + DoublePrecisionPtr: ptr.Of(11111111.22), DoublePrecision: 11111111.22, Smallserial: 1, Serial: 1, Bigserial: 1, //MoneyPtr: nil, //Money: - VarCharPtr: testutils.PtrOf("ABBA"), + VarCharPtr: ptr.Of("ABBA"), VarChar: "ABBA", - CharPtr: testutils.PtrOf("JOHN "), + CharPtr: ptr.Of("JOHN "), Char: "JOHN ", - TextPtr: testutils.PtrOf("Some text"), + TextPtr: ptr.Of("Some text"), Text: "Some text", - ByteaPtr: testutils.PtrOf([]byte("bytea")), + ByteaPtr: ptr.Of([]byte("bytea")), Bytea: []byte("bytea"), TimestampzPtr: testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), Timestampz: *testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), @@ -1342,31 +1343,31 @@ var allTypesRow0 = model.AllTypes{ Timez: *testutils.TimeWithTimeZone("04:05:06 -0800"), TimePtr: testutils.TimeWithoutTimeZone("04:05:06"), Time: *testutils.TimeWithoutTimeZone("04:05:06"), - IntervalPtr: testutils.PtrOf("3 days 04:05:06"), + IntervalPtr: ptr.Of("3 days 04:05:06"), Interval: "3 days 04:05:06", - BooleanPtr: testutils.PtrOf(true), + BooleanPtr: ptr.Of(true), Boolean: false, - PointPtr: testutils.PtrOf("(2,3)"), - BitPtr: testutils.PtrOf("101"), + PointPtr: ptr.Of("(2,3)"), + BitPtr: ptr.Of("101"), Bit: "101", - BitVaryingPtr: testutils.PtrOf("101111"), + BitVaryingPtr: ptr.Of("101111"), BitVarying: "101111", - TsvectorPtr: testutils.PtrOf("'supernova':1"), + TsvectorPtr: ptr.Of("'supernova':1"), Tsvector: "'supernova':1", UUIDPtr: testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), UUID: uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), - XMLPtr: testutils.PtrOf("abc"), + XMLPtr: ptr.Of("abc"), XML: "abc", - JSONPtr: testutils.PtrOf(`{"a": 1, "b": 3}`), + JSONPtr: ptr.Of(`{"a": 1, "b": 3}`), JSON: `{"a": 1, "b": 3}`, - JsonbPtr: testutils.PtrOf(`{"a": 1, "b": 3}`), + JsonbPtr: ptr.Of(`{"a": 1, "b": 3}`), Jsonb: `{"a": 1, "b": 3}`, - IntegerArrayPtr: testutils.PtrOf("{1,2,3}"), + IntegerArrayPtr: ptr.Of("{1,2,3}"), IntegerArray: "{1,2,3}", - TextArrayPtr: testutils.PtrOf("{breakfast,consulting}"), + TextArrayPtr: ptr.Of("{breakfast,consulting}"), TextArray: "{breakfast,consulting}", JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`, - TextMultiDimArrayPtr: testutils.PtrOf("{{meeting,lunch},{training,presentation}}"), + TextMultiDimArrayPtr: ptr.Of("{{meeting,lunch},{training,presentation}}"), TextMultiDimArray: "{{meeting,lunch},{training,presentation}}", MoodPtr: &moodSad, Mood: model.Mood_Happy, diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index 7c7f957..33507e0 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -3,6 +3,7 @@ package postgres import ( "context" "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/internal/utils/ptr" . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook/table" @@ -455,7 +456,7 @@ FROM ( require.Len(t, dest, 275) require.Equal(t, dest[0].Artist1.Artist, model.Artist{ ArtistId: 1, - Name: testutils.PtrOf("AC/DC"), + Name: ptr.Of("AC/DC"), }) require.Equal(t, dest[0].Artist1.CustomColumn1, "custom_column_1") require.Equal(t, dest[0].Artist1.CustomColumn2, "custom_column_2") diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 8d12d0c..a1d4c2d 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -2,6 +2,7 @@ package postgres import ( "database/sql" + "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/google/uuid" "testing" @@ -63,15 +64,15 @@ func TestExactDecimals(t *testing.T) { Floats: model.Floats{ // overwritten by wrapped(floats) scope Numeric: 0.1, - NumericPtr: testutils.PtrOf(0.1), + NumericPtr: ptr.Of(0.1), Decimal: 0.1, - DecimalPtr: testutils.PtrOf(0.1), + DecimalPtr: ptr.Of(0.1), // not overwritten Real: 0.4, - RealPtr: testutils.PtrOf(float32(0.44)), + RealPtr: ptr.Of(float32(0.44)), Double: 0.3, - DoublePtr: testutils.PtrOf(0.33), + DoublePtr: ptr.Of(0.33), }, Numeric: decimal.RequireFromString("0.1234567890123456789"), NumericPtr: decimal.RequireFromString("1.1111111111111111111"), @@ -378,7 +379,7 @@ ORDER BY employee.employee_id; FirstName: "Salley", LastName: "Lester", EmploymentDate: testutils.TimestampWithTimeZone("1999-01-08 04:05:06 +0100 CET", 1), - ManagerID: testutils.PtrOf(int32(3)), + ManagerID: ptr.Of(int32(3)), }) } @@ -420,7 +421,7 @@ FROM test_sample."WEIRD NAMES TABLE"; WeirdColumnName5: "Doe", WeirdColumnName6: "Doe", WeirdColumnName7: "Doe", - Weirdcolumnname8: testutils.PtrOf("Doe"), + Weirdcolumnname8: ptr.Of("Doe"), WeirdColName9: "Doe", WeirdColuName10: "Doe", WeirdColuName11: "Doe", @@ -518,7 +519,7 @@ func TestMutableColumnsExcludeGeneratedColumn(t *testing.T) { ).MODEL( model.People{ PeopleName: "Dario", - PeopleHeightCm: testutils.PtrOf(120.0), + PeopleHeightCm: ptr.Of(120.0), }, ).RETURNING( People.MutableColumns, diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index b71f49a..321fc38 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/volatiletech/null/v8" "testing" "time" @@ -93,7 +94,7 @@ func TestScanToValidDestination(t *testing.T) { err := oneInventoryQuery.Query(db, &dest) require.NoError(t, err) - require.Equal(t, dest[0], testutils.PtrOf(int32(1))) + require.Equal(t, dest[0], ptr.Of(int32(1))) }) t.Run("NULL to integer", func(t *testing.T) { @@ -530,10 +531,10 @@ func TestScanToSlice(t *testing.T) { require.NoError(t, err) require.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0].Film, film1) - testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{testutils.PtrOf(int32(1)), testutils.PtrOf(int32(2)), testutils.PtrOf(int32(3)), testutils.PtrOf(int32(4)), - testutils.PtrOf(int32(5)), testutils.PtrOf(int32(6)), testutils.PtrOf(int32(7)), testutils.PtrOf(int32(8))}) + testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{ptr.Of(int32(1)), ptr.Of(int32(2)), ptr.Of(int32(3)), ptr.Of(int32(4)), + ptr.Of(int32(5)), ptr.Of(int32(6)), ptr.Of(int32(7)), ptr.Of(int32(8))}) testutils.AssertDeepEqual(t, dest[1].Film, film2) - testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{testutils.PtrOf(int32(9)), testutils.PtrOf(int32(10))}) + testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{ptr.Of(int32(9)), ptr.Of(int32(10))}) }) t.Run("complex struct 1", func(t *testing.T) { @@ -1076,10 +1077,10 @@ VALUES (1234, 0, 'Joe', '', NULL, 1, TRUE, '2020-02-02 10:00:00Z', NULL, 1); var address256 = model.Address{ AddressID: 256, Address: "1497 Yuzhou Drive", - Address2: testutils.PtrOf(""), + Address2: ptr.Of(""), District: "England", CityID: 312, - PostalCode: testutils.PtrOf("3433"), + PostalCode: ptr.Of("3433"), Phone: "246810237916", LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 09:45:30", 0), } @@ -1087,10 +1088,10 @@ var address256 = model.Address{ var addres517 = model.Address{ AddressID: 517, Address: "548 Uruapan Street", - Address2: testutils.PtrOf(""), + Address2: ptr.Of(""), District: "Ontario", CityID: 312, - PostalCode: testutils.PtrOf("35653"), + PostalCode: ptr.Of("35653"), Phone: "879347453467", LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 09:45:30", 0), } @@ -1100,12 +1101,12 @@ var customer256 = model.Customer{ StoreID: 2, FirstName: "Mattie", LastName: "Hoffman", - Email: testutils.PtrOf("mattie.hoffman@sakilacustomer.org"), + Email: ptr.Of("mattie.hoffman@sakilacustomer.org"), AddressID: 256, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 0), - Active: testutils.PtrOf(int32(1)), + Active: ptr.Of(int32(1)), } var customer512 = model.Customer{ @@ -1113,12 +1114,12 @@ var customer512 = model.Customer{ StoreID: 1, FirstName: "Cecil", LastName: "Vines", - Email: testutils.PtrOf("cecil.vines@sakilacustomer.org"), + Email: ptr.Of("cecil.vines@sakilacustomer.org"), AddressID: 517, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 0), - Active: testutils.PtrOf(int32(1)), + Active: ptr.Of(int32(1)), } var countryUk = model.Country{ @@ -1151,32 +1152,32 @@ var inventory2 = model.Inventory{ var film1 = model.Film{ FilmID: 1, Title: "Academy Dinosaur", - Description: testutils.PtrOf("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"), - ReleaseYear: testutils.PtrOf(int32(2006)), + Description: ptr.Of("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"), + ReleaseYear: ptr.Of(int32(2006)), LanguageID: 1, RentalDuration: 6, RentalRate: 0.99, - Length: testutils.PtrOf(int16(86)), + Length: ptr.Of(int16(86)), ReplacementCost: 20.99, Rating: &pgRating, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: testutils.PtrOf("{\"Deleted Scenes\",\"Behind the Scenes\"}"), + SpecialFeatures: ptr.Of("{\"Deleted Scenes\",\"Behind the Scenes\"}"), Fulltext: "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", } var film2 = model.Film{ FilmID: 2, Title: "Ace Goldfinger", - Description: testutils.PtrOf("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), - ReleaseYear: testutils.PtrOf(int32(2006)), + Description: ptr.Of("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), + ReleaseYear: ptr.Of(int32(2006)), LanguageID: 1, RentalDuration: 3, RentalRate: 4.99, - Length: testutils.PtrOf(int16(48)), + Length: ptr.Of(int16(48)), ReplacementCost: 12.99, Rating: &gRating, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: testutils.PtrOf(`{Trailers,"Deleted Scenes"}`), + SpecialFeatures: ptr.Of(`{Trailers,"Deleted Scenes"}`), Fulltext: `'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14`, } diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 94bd68d..14dd627 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "github.com/go-jet/jet/v2/internal/utils/ptr" "testing" "time" @@ -1828,16 +1829,16 @@ ORDER BY film.film_id ASC; testutils.AssertDeepEqual(t, maxRentalRateFilms[0], model.Film{ FilmID: 2, Title: "Ace Goldfinger", - Description: testutils.PtrOf("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), - ReleaseYear: testutils.PtrOf(int32(2006)), + Description: ptr.Of("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), + ReleaseYear: ptr.Of(int32(2006)), LanguageID: 1, RentalRate: 4.99, - Length: testutils.PtrOf(int16(48)), + Length: ptr.Of(int16(48)), ReplacementCost: 12.99, Rating: &gRating, RentalDuration: 3, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: testutils.PtrOf("{Trailers,\"Deleted Scenes\"}"), + SpecialFeatures: ptr.Of("{Trailers,\"Deleted Scenes\"}"), Fulltext: "'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14", }) } @@ -2286,11 +2287,11 @@ ORDER BY customer_payment_sum.amount_sum ASC; FirstName: "Brian", LastName: "Wyman", AddressID: 323, - Email: testutils.PtrOf("brian.wyman@sakilacustomer.org"), + Email: ptr.Of("brian.wyman@sakilacustomer.org"), Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.PtrOf(int32(1)), + Active: ptr.Of(int32(1)), }) require.Equal(t, customersWithAmounts[0].AmountSum, 27.93) @@ -3133,8 +3134,8 @@ func TestDynamicCondition(t *testing.T) { Active *bool } - request.CustomerID = testutils.PtrOf(int64(1)) - request.Active = testutils.PtrOf(true) + request.CustomerID = ptr.Of(int64(1)) + request.Active = ptr.Of(true) // ... @@ -3894,12 +3895,12 @@ var customer0 = model.Customer{ StoreID: 1, FirstName: "Mary", LastName: "Smith", - Email: testutils.PtrOf("mary.smith@sakilacustomer.org"), + Email: ptr.Of("mary.smith@sakilacustomer.org"), AddressID: 5, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.PtrOf(int32(1)), + Active: ptr.Of(int32(1)), } var customer1 = model.Customer{ @@ -3907,12 +3908,12 @@ var customer1 = model.Customer{ StoreID: 1, FirstName: "Patricia", LastName: "Johnson", - Email: testutils.PtrOf("patricia.johnson@sakilacustomer.org"), + Email: ptr.Of("patricia.johnson@sakilacustomer.org"), AddressID: 6, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.PtrOf(int32(1)), + Active: ptr.Of(int32(1)), } var lastCustomer = model.Customer{ @@ -3920,10 +3921,10 @@ var lastCustomer = model.Customer{ StoreID: 2, FirstName: "Austin", LastName: "Cintron", - Email: testutils.PtrOf("austin.cintron@sakilacustomer.org"), + Email: ptr.Of("austin.cintron@sakilacustomer.org"), AddressID: 605, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.PtrOf(int32(1)), + Active: ptr.Of(int32(1)), } diff --git a/tests/sqlite/alltypes_test.go b/tests/sqlite/alltypes_test.go index 63d1835..41e5cf7 100644 --- a/tests/sqlite/alltypes_test.go +++ b/tests/sqlite/alltypes_test.go @@ -2,6 +2,7 @@ package sqlite import ( "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/internal/utils/ptr" . "github.com/go-jet/jet/v2/sqlite" "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table" @@ -153,43 +154,43 @@ func TestAllTypesInsert(t *testing.T) { var toInsert = model.AllTypes{ Boolean: false, - BooleanPtr: testutils.PtrOf(true), + BooleanPtr: ptr.Of(true), TinyInt: 1, SmallInt: 3, MediumInt: 5, Integer: 7, BigInt: 9, - TinyIntPtr: testutils.PtrOf(int8(11)), - SmallIntPtr: testutils.PtrOf(int16(33)), - MediumIntPtr: testutils.PtrOf(int32(55)), - IntegerPtr: testutils.PtrOf(int32(77)), - BigIntPtr: testutils.PtrOf(int64(99)), + TinyIntPtr: ptr.Of(int8(11)), + SmallIntPtr: ptr.Of(int16(33)), + MediumIntPtr: ptr.Of(int32(55)), + IntegerPtr: ptr.Of(int32(77)), + BigIntPtr: ptr.Of(int64(99)), Decimal: 11.22, - DecimalPtr: testutils.PtrOf(33.44), + DecimalPtr: ptr.Of(33.44), Numeric: 55.66, - NumericPtr: testutils.PtrOf(77.88), + NumericPtr: ptr.Of(77.88), Float: 99.00, - FloatPtr: testutils.PtrOf(11.22), + FloatPtr: ptr.Of(11.22), Double: 33.44, - DoublePtr: testutils.PtrOf(55.66), + DoublePtr: ptr.Of(55.66), Real: 77.88, - RealPtr: testutils.PtrOf(float32(99.00)), + RealPtr: ptr.Of(float32(99.00)), Time: time.Date(1, 1, 1, 1, 1, 1, 10, time.UTC), - TimePtr: testutils.PtrOf(time.Date(2, 2, 2, 2, 2, 2, 200, time.UTC)), + TimePtr: ptr.Of(time.Date(2, 2, 2, 2, 2, 2, 200, time.UTC)), Date: time.Now(), - DatePtr: testutils.PtrOf(time.Now()), + DatePtr: ptr.Of(time.Now()), DateTime: time.Now(), - DateTimePtr: testutils.PtrOf(time.Now()), + DateTimePtr: ptr.Of(time.Now()), Timestamp: time.Now(), - TimestampPtr: testutils.PtrOf(time.Now()), + TimestampPtr: ptr.Of(time.Now()), Char: "abcd", - CharPtr: testutils.PtrOf("absd"), + CharPtr: ptr.Of("absd"), VarChar: "abcd", - VarCharPtr: testutils.PtrOf("absd"), + VarCharPtr: ptr.Of("absd"), Blob: []byte("large file"), - BlobPtr: testutils.PtrOf([]byte("very large file")), + BlobPtr: ptr.Of([]byte("very large file")), Text: "some text", - TextPtr: testutils.PtrOf("text"), + TextPtr: ptr.Of("text"), } func TestUUID(t *testing.T) { @@ -659,7 +660,7 @@ func TestExactDecimals(t *testing.T) { // not overwritten Numeric: "6.7", - NumericPtr: testutils.PtrOf("7.7"), + NumericPtr: ptr.Of("7.7"), }, Decimal: decimal.RequireFromString("91.23"), DecimalPtr: decimal.RequireFromString("45.67"), diff --git a/tests/sqlite/insert_test.go b/tests/sqlite/insert_test.go index 5e9c7b2..dd31784 100644 --- a/tests/sqlite/insert_test.go +++ b/tests/sqlite/insert_test.go @@ -3,6 +3,7 @@ package sqlite import ( "context" "database/sql" + "github.com/go-jet/jet/v2/internal/utils/ptr" "math/rand" "testing" @@ -49,7 +50,7 @@ VALUES (?, ?, ?, ?), ID: 101, URL: "http://www.google.com", Name: "Google", - Description: testutils.PtrOf("Search engine"), + Description: ptr.Of("Search engine"), }) testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ ID: 102, diff --git a/tests/sqlite/sample_test.go b/tests/sqlite/sample_test.go index 7349d76..c306421 100644 --- a/tests/sqlite/sample_test.go +++ b/tests/sqlite/sample_test.go @@ -3,6 +3,7 @@ package sqlite import ( "database/sql" "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/stretchr/testify/require" "testing" @@ -54,7 +55,7 @@ WHERE people.people_id = ?; ).MODEL( model.People{ PeopleName: "Dario", - PeopleHeightCm: testutils.PtrOf(190.0), + PeopleHeightCm: ptr.Of(190.0), }, ).RETURNING( People.AllColumns, diff --git a/tests/sqlite/select_test.go b/tests/sqlite/select_test.go index da4fa71..baafa76 100644 --- a/tests/sqlite/select_test.go +++ b/tests/sqlite/select_test.go @@ -2,6 +2,7 @@ package sqlite import ( "context" + "github.com/go-jet/jet/v2/internal/utils/ptr" model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/model" "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/table" "strings" @@ -846,15 +847,15 @@ func TestSimpleView(t *testing.T) { require.Equal(t, len(dest), 10) require.Equal(t, dest[2], model.CustomerList{ - ID: testutils.PtrOf(int32(3)), - Name: testutils.PtrOf("LINDA WILLIAMS"), - Address: testutils.PtrOf("692 Joliet Street"), - ZipCode: testutils.PtrOf("83579"), - Phone: testutils.PtrOf(" "), - City: testutils.PtrOf("Athenai"), - Country: testutils.PtrOf("Greece"), - Notes: testutils.PtrOf("active"), - Sid: testutils.PtrOf(int32(1)), + ID: ptr.Of(int32(3)), + Name: ptr.Of("LINDA WILLIAMS"), + Address: ptr.Of("692 Joliet Street"), + ZipCode: ptr.Of("83579"), + Phone: ptr.Of(" "), + City: ptr.Of("Athenai"), + Country: ptr.Of("Greece"), + Notes: ptr.Of("active"), + Sid: ptr.Of(int32(1)), }) } From 743df3ae7dddd32bece0a1dd5b90362dae3c048d Mon Sep 17 00:00:00 2001 From: Samir Faci Date: Mon, 7 Oct 2024 08:42:44 -0400 Subject: [PATCH 19/28] Updating mysql version to allow to run on OS X as well --- tests/Makefile | 2 +- tests/docker-compose.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/Makefile b/tests/Makefile index 1a84778..8d40346 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -6,7 +6,7 @@ setup: checkout-testdata docker-compose-up checkout-testdata: git submodule init git submodule update - cd ./testdata && git fetch && git checkout master && git pull +# cd ./testdata && git fetch && git checkout master && git pull # docker-compose-up will download docker image for each of the databases listed in docker-compose.yaml file, and then it will initialize # database with testdata retrieved in previous step. diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 9b3af50..09ce9d7 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -13,7 +13,7 @@ services: - ./testdata/init/postgres:/docker-entrypoint-initdb.d mysql: - image: mysql:8.0.27 + image: mysql:8.0 command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1'] restart: always environment: From a77ecc3a30a50259619239e8f8af9d1ee22f00ce Mon Sep 17 00:00:00 2001 From: Samir Faci Date: Mon, 7 Oct 2024 12:54:15 -0400 Subject: [PATCH 20/28] Addressing code review comments --- tests/Makefile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/Makefile b/tests/Makefile index 8d40346..43a8d66 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -6,7 +6,9 @@ setup: checkout-testdata docker-compose-up checkout-testdata: git submodule init git submodule update -# cd ./testdata && git fetch && git checkout master && git pull +# +checkout-latest-testdata: checkout-testdata + cd ./testdata && git fetch && git checkout master && git pull # docker-compose-up will download docker image for each of the databases listed in docker-compose.yaml file, and then it will initialize # database with testdata retrieved in previous step. From c30a3507e328e64f99d61736dee5c9d6a5354396 Mon Sep 17 00:00:00 2001 From: go-jet Date: Mon, 7 Oct 2024 20:23:25 +0200 Subject: [PATCH 21/28] Fix build. --- cmd/jet/version.go | 2 +- tests/testdata | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/jet/version.go b/cmd/jet/version.go index 4241ec7..3300008 100644 --- a/cmd/jet/version.go +++ b/cmd/jet/version.go @@ -1,3 +1,3 @@ package main -const version = "v2.11.0" +const version = "v2.11.1" diff --git a/tests/testdata b/tests/testdata index 915bdc1..1e9247e 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 915bdc16b723d89becc577c780949baef861a6ae +Subproject commit 1e9247e333babd5172cf162e38518d993f5f3df4 From 288ebdc373953bbd02b59d9eeacafb52795bbee7 Mon Sep 17 00:00:00 2001 From: Samir Faci Date: Mon, 7 Oct 2024 15:02:12 -0400 Subject: [PATCH 22/28] Improved support for intervals in postgres Fixes #393 --- internal/jet/column_assigment.go | 7 ++++++ internal/testutils/test_utils.go | 3 +-- postgres/columns.go | 7 ++++++ tests/postgres/alltypes_test.go | 38 ++++++++++++++++++++++++++++++++ tests/postgres/insert_test.go | 22 +++++++++--------- tests/postgres/sample_test.go | 5 ++++- tests/testdata | 2 +- 7 files changed, 69 insertions(+), 15 deletions(-) diff --git a/internal/jet/column_assigment.go b/internal/jet/column_assigment.go index 440f3eb..c888433 100644 --- a/internal/jet/column_assigment.go +++ b/internal/jet/column_assigment.go @@ -11,6 +11,13 @@ type columnAssigmentImpl struct { expression Expression } +func NewColumnAssignment(serializer ColumnSerializer, expression Expression) ColumnAssigment { + return &columnAssigmentImpl{ + column: serializer, + expression: expression, + } +} + func (a columnAssigmentImpl) isColumnAssigment() {} func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 97f90d6..cfedde4 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -9,6 +9,7 @@ import ( "github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/qrm" + "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,8 +18,6 @@ import ( "runtime" "testing" "time" - - "github.com/google/go-cmp/cmp" ) // UnixTimeComparer will compare time equality while ignoring time zone diff --git a/postgres/columns.go b/postgres/columns.go index 819da38..a70c234 100644 --- a/postgres/columns.go +++ b/postgres/columns.go @@ -109,13 +109,20 @@ type ColumnInterval interface { jet.Column From(subQuery SelectTable) ColumnInterval + SET(intervalExp IntervalExpression) ColumnAssigment } +//------------------------------------------------------// + type intervalColumnImpl struct { jet.ColumnExpressionImpl intervalInterfaceImpl } +func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment { + return jet.NewColumnAssignment(i, intervalExp) +} + func (i *intervalColumnImpl) From(subQuery SelectTable) ColumnInterval { newIntervalColumn := IntervalColumn(i.Name()) jet.SetTableName(newIntervalColumn, i.TableName()) diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 2f1be14..fdb6ac9 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -3,6 +3,7 @@ package postgres import ( "database/sql" "github.com/go-jet/jet/v2/internal/utils/ptr" + "github.com/stretchr/testify/assert" "testing" "time" @@ -931,6 +932,43 @@ func TestTimeExpression(t *testing.T) { require.NoError(t, err) } +func TestIntervalUpsert(t *testing.T) { + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + stmt := SELECT(Employee.AllColumns).FROM(Employee). + WHERE(Employee.EmployeeID.EQ(Int(1))) + + //Validate initial dataset + var windy model.Employee + err := stmt.Query(tx, &windy) + assert.Equal(t, windy.EmployeeID, int32(1)) + assert.Equal(t, windy.FirstName, "Windy") + assert.Equal(t, windy.LastName, "Hays") + assert.Equal(t, *windy.PtoAccrual, "22:00:00") + assert.Nil(t, err) + windy.PtoAccrual = ptr.Of("3h") + //Update data + updateStmt := Employee.UPDATE(Employee.PtoAccrual).SET( + Employee.PtoAccrual.SET(INTERVAL(3, HOUR)), + ).WHERE(Employee.EmployeeID.EQ(Int(1))).RETURNING(Employee.AllColumns) + + err = updateStmt.Query(tx, &windy) + err = stmt.Query(tx, &windy) + assert.Nil(t, err) + assert.Equal(t, *windy.PtoAccrual, "03:00:00") + //Upsert dataset with a different value + windy.PtoAccrual = ptr.Of("5h") + insertStmt := Employee.INSERT(Employee.AllColumns). + MODEL(&windy). + ON_CONFLICT(Employee.EmployeeID). + DO_UPDATE(SET( + Employee.PtoAccrual.SET(Employee.EXCLUDED.PtoAccrual), + )).RETURNING(Employee.AllColumns) + err = insertStmt.Query(tx, &windy) + assert.Nil(t, err) + assert.Equal(t, *windy.PtoAccrual, "05:00:00") + }) +} + func TestInterval(t *testing.T) { skipForCockroachDB(t) diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index a079091..7edefd4 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -93,9 +93,9 @@ func TestInsertOnConflict(t *testing.T) { ON_CONFLICT().DO_NOTHING() testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5), - ($6, $7, $8, $9, $10) +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6), + ($7, $8, $9, $10, $11, $12) ON CONFLICT DO NOTHING; `) testutils.AssertExecAndRollback(t, stmt, db, 1) @@ -111,9 +111,9 @@ ON CONFLICT DO NOTHING; ON_CONFLICT(Employee.EmployeeID).DO_NOTHING() testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5), - ($6, $7, $8, $9, $10) +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6), + ($7, $8, $9, $10, $11, $12) ON CONFLICT (employee_id) DO NOTHING; `) testutils.AssertExecAndRollback(t, stmt, db, 1) @@ -130,9 +130,9 @@ ON CONFLICT (employee_id) DO NOTHING; ON_CONFLICT().ON_CONSTRAINT("employee_pkey").DO_NOTHING() testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5), - ($6, $7, $8, $9, $10) +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6), + ($7, $8, $9, $10, $11, $12) ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; `) testutils.AssertExecAndRollback(t, stmt, db, 1) @@ -234,8 +234,8 @@ ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE ON_CONFLICT().DO_UPDATE(nil) testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5); +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6); `) testutils.AssertExecAndRollback(t, stmt, db, 1) requireLogged(t, stmt) diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index a1d4c2d..f252631 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -331,11 +331,13 @@ SELECT employee.employee_id AS "employee.employee_id", employee.last_name AS "employee.last_name", employee.employment_date AS "employee.employment_date", employee.manager_id AS "employee.manager_id", + employee.pto_accrual AS "employee.pto_accrual", manager.employee_id AS "manager.employee_id", manager.first_name AS "manager.first_name", manager.last_name AS "manager.last_name", manager.employment_date AS "manager.employment_date", - manager.manager_id AS "manager.manager_id" + manager.manager_id AS "manager.manager_id", + manager.pto_accrual AS "manager.pto_accrual" FROM test_sample.employee LEFT JOIN test_sample.employee AS manager ON (manager.employee_id = employee.manager_id) ORDER BY employee.employee_id; @@ -370,6 +372,7 @@ ORDER BY employee.employee_id; LastName: "Hays", EmploymentDate: testutils.TimestampWithTimeZone("1999-01-08 04:05:06.1 +0100 CET", 1), ManagerID: nil, + PtoAccrual: ptr.Of("22:00:00"), }) require.True(t, dest[0].Manager == nil) diff --git a/tests/testdata b/tests/testdata index 1e9247e..6a39774 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 1e9247e333babd5172cf162e38518d993f5f3df4 +Subproject commit 6a397747d310938b41d3950d68009578180d3dd5 From 600e0a7ce726df70e7b7f7f306adf3041c7cf04d Mon Sep 17 00:00:00 2001 From: Samir Faci Date: Wed, 9 Oct 2024 08:30:33 -0400 Subject: [PATCH 23/28] Addressing code review comments and adding a query validation --- tests/postgres/alltypes_test.go | 106 ++++++++++++++++++++++---------- 1 file changed, 73 insertions(+), 33 deletions(-) diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index fdb6ac9..62e0bc6 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -2,8 +2,10 @@ package postgres import ( "database/sql" + "github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/stretchr/testify/assert" + "log/slog" "testing" "time" @@ -932,41 +934,79 @@ func TestTimeExpression(t *testing.T) { require.NoError(t, err) } -func TestIntervalUpsert(t *testing.T) { - testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { - stmt := SELECT(Employee.AllColumns).FROM(Employee). - WHERE(Employee.EmployeeID.EQ(Int(1))) +func TestIntervalSetFunctionality(t *testing.T) { + updateQuery := ` +UPDATE test_sample.employee +SET pto_accrual = INTERVAL '3 HOUR' +WHERE employee.employee_id = $1 +RETURNING employee.employee_id AS "employee.employee_id", + employee.first_name AS "employee.first_name", + employee.last_name AS "employee.last_name", + employee.employment_date AS "employee.employment_date", + employee.manager_id AS "employee.manager_id", + employee.pto_accrual AS "employee.pto_accrual"; +` + insertQuery := ` +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6) +ON CONFLICT (employee_id) DO UPDATE + SET pto_accrual = excluded.pto_accrual +RETURNING employee.employee_id AS "employee.employee_id", + employee.first_name AS "employee.first_name", + employee.last_name AS "employee.last_name", + employee.employment_date AS "employee.employment_date", + employee.manager_id AS "employee.manager_id", + employee.pto_accrual AS "employee.pto_accrual"; +` - //Validate initial dataset - var windy model.Employee - err := stmt.Query(tx, &windy) - assert.Equal(t, windy.EmployeeID, int32(1)) - assert.Equal(t, windy.FirstName, "Windy") - assert.Equal(t, windy.LastName, "Hays") - assert.Equal(t, *windy.PtoAccrual, "22:00:00") - assert.Nil(t, err) - windy.PtoAccrual = ptr.Of("3h") - //Update data - updateStmt := Employee.UPDATE(Employee.PtoAccrual).SET( - Employee.PtoAccrual.SET(INTERVAL(3, HOUR)), - ).WHERE(Employee.EmployeeID.EQ(Int(1))).RETURNING(Employee.AllColumns) + testCases := []struct { + expectedQuery string + name string + duration string + expectedInterval string + statement func(employee *model.Employee) jet.Statement + }{ + { + name: "updateQuery", + expectedQuery: updateQuery, + duration: "3h", + expectedInterval: "03:00:00", + statement: func(employee *model.Employee) jet.Statement { + return Employee.UPDATE(Employee.PtoAccrual).SET( + Employee.PtoAccrual.SET(INTERVAL(3, HOUR)), + ).WHERE(Employee.EmployeeID.EQ(Int(1))).RETURNING(Employee.AllColumns) + }, + }, + { + expectedQuery: insertQuery, + name: "insertQuery", + duration: "5h", + expectedInterval: "05:00:00", + statement: func(employee *model.Employee) jet.Statement { + return Employee.INSERT(Employee.AllColumns). + MODEL(employee). + ON_CONFLICT(Employee.EmployeeID). + DO_UPDATE(SET( + Employee.PtoAccrual.SET(Employee.EXCLUDED.PtoAccrual), + )).RETURNING(Employee.AllColumns) + }, + }, + } - err = updateStmt.Query(tx, &windy) - err = stmt.Query(tx, &windy) - assert.Nil(t, err) - assert.Equal(t, *windy.PtoAccrual, "03:00:00") - //Upsert dataset with a different value - windy.PtoAccrual = ptr.Of("5h") - insertStmt := Employee.INSERT(Employee.AllColumns). - MODEL(&windy). - ON_CONFLICT(Employee.EmployeeID). - DO_UPDATE(SET( - Employee.PtoAccrual.SET(Employee.EXCLUDED.PtoAccrual), - )).RETURNING(Employee.AllColumns) - err = insertStmt.Query(tx, &windy) - assert.Nil(t, err) - assert.Equal(t, *windy.PtoAccrual, "05:00:00") - }) + for _, tc := range testCases { + slog.Info("Running test", slog.Any("test", tc.name)) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var windy model.Employee + windy.PtoAccrual = ptr.Of(tc.duration) + stmt := tc.statement(&windy) + + testutils.AssertStatementSql(t, stmt, tc.expectedQuery) + err := stmt.Query(tx, &windy) + assert.Nil(t, err) + assert.Equal(t, *windy.PtoAccrual, tc.expectedInterval) + + }) + } } func TestInterval(t *testing.T) { From c0a0b450aac24235de8d9a78b15465c979e064e4 Mon Sep 17 00:00:00 2001 From: Samir Faci Date: Wed, 9 Oct 2024 14:53:47 -0400 Subject: [PATCH 24/28] Changing from Table Driven tests to subtests --- tests/postgres/alltypes_test.go | 80 ++++++++++++++------------------- 1 file changed, 33 insertions(+), 47 deletions(-) diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 62e0bc6..70c4332 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -2,7 +2,6 @@ package postgres import ( "database/sql" - "github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/stretchr/testify/assert" "log/slog" @@ -935,7 +934,10 @@ func TestTimeExpression(t *testing.T) { } func TestIntervalSetFunctionality(t *testing.T) { - updateQuery := ` + + t.Run("updateQueryIntervalTest", func(t *testing.T) { + slog.Info("Running test", slog.Any("test", t.Name())) + expectedQuery := ` UPDATE test_sample.employee SET pto_accrual = INTERVAL '3 HOUR' WHERE employee.employee_id = $1 @@ -946,7 +948,23 @@ RETURNING employee.employee_id AS "employee.employee_id", employee.manager_id AS "employee.manager_id", employee.pto_accrual AS "employee.pto_accrual"; ` - insertQuery := ` + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var windy model.Employee + windy.PtoAccrual = ptr.Of("3h") + stmt := Employee.UPDATE(Employee.PtoAccrual).SET( + Employee.PtoAccrual.SET(INTERVAL(3, HOUR)), + ).WHERE(Employee.EmployeeID.EQ(Int(1))).RETURNING(Employee.AllColumns) + + testutils.AssertStatementSql(t, stmt, expectedQuery) + err := stmt.Query(tx, &windy) + assert.Nil(t, err) + assert.Equal(t, *windy.PtoAccrual, "03:00:00") + + }) + }) + + t.Run("upsertQueryIntervalTest", func(t *testing.T) { + expectedQuery := ` INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (employee_id) DO UPDATE @@ -958,55 +976,23 @@ RETURNING employee.employee_id AS "employee.employee_id", employee.manager_id AS "employee.manager_id", employee.pto_accrual AS "employee.pto_accrual"; ` - - testCases := []struct { - expectedQuery string - name string - duration string - expectedInterval string - statement func(employee *model.Employee) jet.Statement - }{ - { - name: "updateQuery", - expectedQuery: updateQuery, - duration: "3h", - expectedInterval: "03:00:00", - statement: func(employee *model.Employee) jet.Statement { - return Employee.UPDATE(Employee.PtoAccrual).SET( - Employee.PtoAccrual.SET(INTERVAL(3, HOUR)), - ).WHERE(Employee.EmployeeID.EQ(Int(1))).RETURNING(Employee.AllColumns) - }, - }, - { - expectedQuery: insertQuery, - name: "insertQuery", - duration: "5h", - expectedInterval: "05:00:00", - statement: func(employee *model.Employee) jet.Statement { - return Employee.INSERT(Employee.AllColumns). - MODEL(employee). - ON_CONFLICT(Employee.EmployeeID). - DO_UPDATE(SET( - Employee.PtoAccrual.SET(Employee.EXCLUDED.PtoAccrual), - )).RETURNING(Employee.AllColumns) - }, - }, - } - - for _, tc := range testCases { - slog.Info("Running test", slog.Any("test", tc.name)) testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { - var windy model.Employee - windy.PtoAccrual = ptr.Of(tc.duration) - stmt := tc.statement(&windy) + var employee model.Employee + employee.PtoAccrual = ptr.Of("5h") + stmt := Employee.INSERT(Employee.AllColumns). + MODEL(employee). + ON_CONFLICT(Employee.EmployeeID). + DO_UPDATE(SET( + Employee.PtoAccrual.SET(Employee.EXCLUDED.PtoAccrual), + )).RETURNING(Employee.AllColumns) - testutils.AssertStatementSql(t, stmt, tc.expectedQuery) - err := stmt.Query(tx, &windy) + testutils.AssertStatementSql(t, stmt, expectedQuery) + err := stmt.Query(tx, &employee) assert.Nil(t, err) - assert.Equal(t, *windy.PtoAccrual, tc.expectedInterval) + assert.Equal(t, *employee.PtoAccrual, "05:00:00") }) - } + }) } func TestInterval(t *testing.T) { From f7082eda688aff81632b001519deb5a10968ab53 Mon Sep 17 00:00:00 2001 From: Samir Faci Date: Tue, 8 Oct 2024 10:17:25 -0400 Subject: [PATCH 25/28] Adding gosec and lint, fixing null_type overflow ChangeLog: - Adding gosec linting - Adding static type to enum - fixing nulltype overflow - Trying out gotestsum as an alternative to go-junit-report.xml --- .circleci/config.yml | 18 +++++------ .github/workflows/code_scanner.yml | 48 +++++++++++++++++++++++++++++ .golangci.yml | 19 ++++++++++++ examples/quick-start/quick-start.go | 2 +- go.mod | 8 ++--- internal/jet/statement.go | 2 +- internal/testutils/test_utils.go | 7 +++-- internal/utils/filesys/filesys.go | 10 ++++-- mysql/interval.go | 38 +++++++++++------------ qrm/internal/null_types.go | 43 ++++++++++++++++++-------- qrm/internal/null_types_test.go | 36 ++++++++++++++++++++++ tests/init/init.go | 10 ++++-- tests/internal/utils/file/file.go | 4 +-- tests/mysql/main_test.go | 8 ++--- tests/postgres/main_test.go | 6 +--- tests/sqlite/main_test.go | 27 +++------------- 16 files changed, 193 insertions(+), 93 deletions(-) create mode 100644 .github/workflows/code_scanner.yml create mode 100644 .golangci.yml diff --git a/.circleci/config.yml b/.circleci/config.yml index e21f00a..563616f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -8,7 +8,7 @@ jobs: build_and_tests: docker: # specify the version - - image: cimg/go:1.21.6 + - image: cimg/go:1.22.2 - image: cimg/postgres:14.10 environment: POSTGRES_USER: jet @@ -128,17 +128,13 @@ jobs: # to create test results report - run: - name: Install go-junit-report - command: go install github.com/jstemmer/go-junit-report@latest - + name: Install gotestsum + command: go install gotest.tools/gotestsum@latest - run: mkdir -p $TEST_RESULTS - # this will run all tests and exclude test files from code coverage report - - run: | - go test -v ./... \ - -covermode=atomic \ - -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... \ - -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml + - run: + name: Running tests + command: gotestsum --junitfile report.xml --format testname -- -coverprofile=cover.out -covermode=atomic -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... ./... # run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests - run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/ @@ -163,4 +159,4 @@ workflows: version: 2 build_and_test: jobs: - - build_and_tests \ No newline at end of file + - build_and_tests diff --git a/.github/workflows/code_scanner.yml b/.github/workflows/code_scanner.yml new file mode 100644 index 0000000..affc756 --- /dev/null +++ b/.github/workflows/code_scanner.yml @@ -0,0 +1,48 @@ +name: Code Scanners +on: + push: + branches: + - master + pull_request: + branches: + - master + +env: + GO_VERSION: 1.22.0 + + +permissions: + contents: read + # Optional: allow read access to pull request. Use with `only-new-issues` option. + # pull-requests: read + + +jobs: + security_scanning: + runs-on: ubuntu-latest + steps: + - name: Checkout Source + uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: 1.22.0 + cache: true + - name: Setup Tools + run: | + go install github.com/securego/gosec/v2/cmd/gosec@latest + - name: Running Scan + run: gosec --exclude=G402,G304 ./... + lint_scanner: + runs-on: ubuntu-latest + steps: + - name: Checkout Source + uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: 1.22.0 + cache: true + - name: Setup Tools + run: | + go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + - name: Running Scan + run: golangci-lint run --timeout=30m ./... diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..4eabe05 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,19 @@ +run: + # The default concurrency value is the number of available CPU. + concurrency: 4 + # Timeout for analysis, e.g. 30s, 5m. + # Default: 1m + timeout: 30m + # Exit code when at least one issue was found. + # Default: 1 + issues-exit-code: 2 + # Include test files or not. + # Default: true + tests: false + +issues: + exclude-dirs: + - tests + exclude-files: + - "_test.go" + - "testutils.go" diff --git a/examples/quick-start/quick-start.go b/examples/quick-start/quick-start.go index b4707f6..dc84989 100644 --- a/examples/quick-start/quick-start.go +++ b/examples/quick-start/quick-start.go @@ -90,7 +90,7 @@ func main() { func jsonSave(path string, v interface{}) { jsonText, _ := json.MarshalIndent(v, "", "\t") - err := os.WriteFile(path, jsonText, 0644) + err := os.WriteFile(path, jsonText, 0600) panicOnError(err) } diff --git a/go.mod b/go.mod index cb204af..ffa977e 100644 --- a/go.mod +++ b/go.mod @@ -1,19 +1,19 @@ module github.com/go-jet/jet/v2 -go 1.18 +go 1.22.0 require ( github.com/go-sql-driver/mysql v1.8.1 + github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/jackc/pgconn v1.14.3 + github.com/jackc/pgtype v1.14.3 + github.com/jackc/pgx/v4 v4.18.3 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.17 ) require ( - github.com/google/go-cmp v0.6.0 - github.com/jackc/pgtype v1.14.3 - github.com/jackc/pgx/v4 v4.18.3 github.com/pkg/profile v1.7.0 github.com/shopspring/decimal v1.4.0 github.com/stretchr/testify v1.9.0 diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 6703154..2ca229d 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -180,7 +180,7 @@ func duration(f func()) time.Duration { f() - return time.Now().Sub(start) + return time.Since(start) } // ExpressionStatement interfacess diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index cfedde4..c2fdc8d 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -103,11 +103,12 @@ func AssertJSON(t *testing.T, data interface{}, expectedJSON string) { } // SaveJSONFile saves v as json at testRelativePath +// nolint:unused func SaveJSONFile(v interface{}, testRelativePath string) { jsonText, _ := json.MarshalIndent(v, "", "\t") filePath := getFullPath(testRelativePath) - err := os.WriteFile(filePath, jsonText, 0644) + err := os.WriteFile(filePath, jsonText, 0600) throw.OnError(err) } @@ -116,7 +117,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) { func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) { filePath := getFullPath(testRelativePath) - fileJSONData, err := os.ReadFile(filePath) + fileJSONData, err := os.ReadFile(filePath) // #nosec G304 require.NoError(t, err) if runtime.GOOS == "windows" { @@ -243,7 +244,7 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest inter // AssertFileContent check if file content at filePath contains expectedContent text. func AssertFileContent(t *testing.T, filePath string, expectedContent string) { - enumFileData, err := os.ReadFile(filePath) + enumFileData, err := os.ReadFile(filePath) // #nosec G304 require.NoError(t, err) diff --git a/internal/utils/filesys/filesys.go b/internal/utils/filesys/filesys.go index d904be8..0724470 100644 --- a/internal/utils/filesys/filesys.go +++ b/internal/utils/filesys/filesys.go @@ -1,6 +1,7 @@ package filesys import ( + "errors" "fmt" "go/format" "os" @@ -16,7 +17,7 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error { newGoFilePath += ".go" } - file, err := os.Create(newGoFilePath) + file, err := os.Create(newGoFilePath) // #nosec 304 if err != nil { return err @@ -28,7 +29,10 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error { // if there is a format error we will write unformulated text for debug purposes if err != nil { - file.Write(text) + _, writeErr := file.Write(text) + if writeErr != nil { + return errors.Join(writeErr, fmt.Errorf("failed to format '%s', check '%s' for syntax errors: %w", fileName, newGoFilePath, err)) + } return fmt.Errorf("failed to format '%s', check '%s' for syntax errors: %w", fileName, newGoFilePath, err) } @@ -43,7 +47,7 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error { // EnsureDirPathExist ensures dir path exists. If path does not exist, creates new path. func EnsureDirPathExist(dirPath string) error { if _, err := os.Stat(dirPath); os.IsNotExist(err) { - err := os.MkdirAll(dirPath, os.ModePerm) + err := os.MkdirAll(dirPath, 0o750) if err != nil { return fmt.Errorf("can't create directory - %s: %w", dirPath, err) diff --git a/mysql/interval.go b/mysql/interval.go index ce1c609..23be21f 100644 --- a/mysql/interval.go +++ b/mysql/interval.go @@ -14,25 +14,25 @@ type unitType string // List of interval unit types for MySQL const ( MICROSECOND unitType = "MICROSECOND" - SECOND = "SECOND" - MINUTE = "MINUTE" - HOUR = "HOUR" - DAY = "DAY" - WEEK = "WEEK" - MONTH = "MONTH" - QUARTER = "QUARTER" - YEAR = "YEAR" - SECOND_MICROSECOND = "SECOND_MICROSECOND" - MINUTE_MICROSECOND = "MINUTE_MICROSECOND" - MINUTE_SECOND = "MINUTE_SECOND" - HOUR_MICROSECOND = "HOUR_MICROSECOND" - HOUR_SECOND = "HOUR_SECOND" - HOUR_MINUTE = "HOUR_MINUTE" - DAY_MICROSECOND = "DAY_MICROSECOND" - DAY_SECOND = "DAY_SECOND" - DAY_MINUTE = "DAY_MINUTE" - DAY_HOUR = "DAY_HOUR" - YEAR_MONTH = "YEAR_MONTH" + SECOND unitType = "SECOND" + MINUTE unitType = "MINUTE" + HOUR unitType = "HOUR" + DAY unitType = "DAY" + WEEK unitType = "WEEK" + MONTH unitType = "MONTH" + QUARTER unitType = "QUARTER" + YEAR unitType = "YEAR" + SECOND_MICROSECOND unitType = "SECOND_MICROSECOND" + MINUTE_MICROSECOND unitType = "MINUTE_MICROSECOND" + MINUTE_SECOND unitType = "MINUTE_SECOND" + HOUR_MICROSECOND unitType = "HOUR_MICROSECOND" + HOUR_SECOND unitType = "HOUR_SECOND" + HOUR_MINUTE unitType = "HOUR_MINUTE" + DAY_MICROSECOND unitType = "DAY_MICROSECOND" + DAY_SECOND unitType = "DAY_SECOND" + DAY_MINUTE unitType = "DAY_MINUTE" + DAY_HOUR unitType = "DAY_HOUR" + YEAR_MONTH unitType = "YEAR_MONTH" ) // Interval is representation of MySQL interval diff --git a/qrm/internal/null_types.go b/qrm/internal/null_types.go index ab75cf6..85d9c68 100644 --- a/qrm/internal/null_types.go +++ b/qrm/internal/null_types.go @@ -10,6 +10,10 @@ import ( "time" ) +var ( + castOverFlowError = fmt.Errorf("cannot cast a negative value to an unsigned value, buffer overflow error") +) + // NullBool struct type NullBool struct { sql.NullBool @@ -119,32 +123,47 @@ func (n *NullUInt64) Scan(value interface{}) error { n.Valid = false return nil case int64: + if v < 0 { + return castOverFlowError + } + n.UInt64, n.Valid = uint64(v), true + return nil + case int32: + if v < 0 { + return castOverFlowError + } + n.UInt64, n.Valid = uint64(v), true + return nil + case int16: + if v < 0 { + return castOverFlowError + } + n.UInt64, n.Valid = uint64(v), true + return nil + case int8: + if v < 0 { + return castOverFlowError + } + n.UInt64, n.Valid = uint64(v), true + return nil + case int: + if v < 0 { + return castOverFlowError + } n.UInt64, n.Valid = uint64(v), true return nil case uint64: n.UInt64, n.Valid = v, true return nil - case int32: - n.UInt64, n.Valid = uint64(v), true - return nil case uint32: n.UInt64, n.Valid = uint64(v), true return nil - case int16: - n.UInt64, n.Valid = uint64(v), true - return nil case uint16: n.UInt64, n.Valid = uint64(v), true return nil - case int8: - n.UInt64, n.Valid = uint64(v), true - return nil case uint8: n.UInt64, n.Valid = uint64(v), true return nil - case int: - n.UInt64, n.Valid = uint64(v), true - return nil case uint: n.UInt64, n.Valid = uint64(v), true return nil diff --git a/qrm/internal/null_types_test.go b/qrm/internal/null_types_test.go index a15b104..eab2dd2 100644 --- a/qrm/internal/null_types_test.go +++ b/qrm/internal/null_types_test.go @@ -2,6 +2,7 @@ package internal import ( "fmt" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "testing" "time" @@ -62,11 +63,21 @@ func TestNullUInt64(t *testing.T) { value, _ := nullUInt64.Value() require.Equal(t, value, uint64(11)) + require.NoError(t, nullUInt64.Scan(uint64(11))) + require.Equal(t, nullUInt64.Valid, true) + value, _ = nullUInt64.Value() + require.Equal(t, value, uint64(11)) + require.NoError(t, nullUInt64.Scan(int32(32))) require.Equal(t, nullUInt64.Valid, true) value, _ = nullUInt64.Value() require.Equal(t, value, uint64(32)) + require.NoError(t, nullUInt64.Scan(uint32(32))) + require.Equal(t, nullUInt64.Valid, true) + value, _ = nullUInt64.Value() + require.Equal(t, value, uint64(32)) + require.NoError(t, nullUInt64.Scan(int16(20))) require.Equal(t, nullUInt64.Valid, true) value, _ = nullUInt64.Value() @@ -88,4 +99,29 @@ func TestNullUInt64(t *testing.T) { require.Equal(t, value, uint64(30)) require.Error(t, nullUInt64.Scan("text"), "can't scan int32 from text") + + //Validate negative use cases + err := nullUInt64.Scan(int64(-5)) + assert.NotNil(t, err) + assert.Error(t, err, castOverFlowError) + + //Validate negative use cases + err = nullUInt64.Scan(-5) + assert.NotNil(t, err) + assert.Error(t, err, castOverFlowError) + + //Validate negative use cases + err = nullUInt64.Scan(int32(-5)) + assert.NotNil(t, err) + assert.Error(t, err, castOverFlowError) + + //Validate negative use cases + err = nullUInt64.Scan(int16(-5)) + assert.NotNil(t, err) + assert.Error(t, err, castOverFlowError) + + //Validate negative use cases + err = nullUInt64.Scan(int8(-5)) + assert.NotNil(t, err) + assert.Error(t, err, castOverFlowError) } diff --git a/tests/init/init.go b/tests/init/init.go index 10631f1..5a0ccdb 100644 --- a/tests/init/init.go +++ b/tests/init/init.go @@ -3,6 +3,7 @@ package main import ( "context" "database/sql" + "errors" "flag" "fmt" "github.com/go-jet/jet/v2/generator/mysql" @@ -124,7 +125,7 @@ func initMySQLDB(isMariaDB bool) error { fmt.Println(cmdLine) - cmd := exec.Command("sh", "-c", cmdLine) + cmd := exec.Command("sh", "-c", cmdLine) // #nosec G204 cmd.Stderr = os.Stderr cmd.Stdout = os.Stdout @@ -183,7 +184,7 @@ func initPostgresDB(dbType string, connectionString string) error { } func execFile(db *sql.DB, sqlFilePath string) error { - testSampleSql, err := os.ReadFile(sqlFilePath) + testSampleSql, err := os.ReadFile(sqlFilePath) // #nosec G304 if err != nil { return fmt.Errorf("failed to read sql file - %s: %w", sqlFilePath, err) } @@ -210,7 +211,10 @@ func execInTx(db *sql.DB, f func(tx *sql.Tx) error) error { err = f(tx) if err != nil { - tx.Rollback() + rollBackError := tx.Rollback() + if rollBackError != nil { + return errors.Join(rollBackError, err) + } return err } diff --git a/tests/internal/utils/file/file.go b/tests/internal/utils/file/file.go index fea0634..73095ea 100644 --- a/tests/internal/utils/file/file.go +++ b/tests/internal/utils/file/file.go @@ -10,7 +10,7 @@ import ( // Exists expects file to exist on path constructed from pathElems and returns content of the file func Exists(t *testing.T, pathElems ...string) (fileContent string) { modelFilePath := path.Join(pathElems...) - file, err := os.ReadFile(modelFilePath) + file, err := os.ReadFile(modelFilePath) // #nosec G304 require.Nil(t, err) require.NotEmpty(t, file) return string(file) @@ -19,6 +19,6 @@ func Exists(t *testing.T, pathElems ...string) (fileContent string) { // NotExists expects file not to exist on path constructed from pathElems func NotExists(t *testing.T, pathElems ...string) { modelFilePath := path.Join(pathElems...) - _, err := os.ReadFile(modelFilePath) + _, err := os.ReadFile(modelFilePath) // #nosec G304 require.True(t, os.IsNotExist(err)) } diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index f6ce57d..4e3b5d5 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -6,12 +6,9 @@ import ( jetmysql "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/dbconfig" - "github.com/stretchr/testify/require" - "math/rand" - "runtime" - "time" - _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/require" + "runtime" "github.com/pkg/profile" "os" @@ -33,7 +30,6 @@ func sourceIsMariaDB() bool { } func TestMain(m *testing.M) { - rand.Seed(time.Now().Unix()) defer profile.Start().Stop() var err error diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index 08af67a..5b7f48d 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -5,13 +5,10 @@ import ( "database/sql" "fmt" "github.com/go-jet/jet/v2/tests/internal/utils/repo" - "math/rand" + "github.com/jackc/pgx/v4/stdlib" "os" "runtime" "testing" - "time" - - "github.com/jackc/pgx/v4/stdlib" "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/dbconfig" @@ -44,7 +41,6 @@ func skipForCockroachDB(t *testing.T) { } func TestMain(m *testing.M) { - rand.Seed(time.Now().Unix()) defer profile.Start().Stop() setTestRoot() diff --git a/tests/sqlite/main_test.go b/tests/sqlite/main_test.go index 4975845..c5d4fb6 100644 --- a/tests/sqlite/main_test.go +++ b/tests/sqlite/main_test.go @@ -8,30 +8,21 @@ import ( "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/sqlite" "github.com/go-jet/jet/v2/tests/dbconfig" - "github.com/stretchr/testify/require" - "math/rand" - "os" - "os/exec" - "runtime" - "strings" - "testing" - "time" - "github.com/pkg/profile" + "github.com/stretchr/testify/require" + "os" + "runtime" + "testing" _ "github.com/mattn/go-sqlite3" ) var db *sql.DB var sampleDB *sql.DB -var testRoot string func TestMain(m *testing.M) { - rand.Seed(time.Now().Unix()) defer profile.Start().Stop() - setTestRoot() - var err error db, err = sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath) throw.OnError(err) @@ -50,16 +41,6 @@ func TestMain(m *testing.M) { } } -func setTestRoot() { - cmd := exec.Command("git", "rev-parse", "--show-toplevel") - byteArr, err := cmd.Output() - if err != nil { - panic(err) - } - - testRoot = strings.TrimSpace(string(byteArr)) + "/tests/" -} - var loggedSQL string var loggedSQLArgs []interface{} var loggedDebugSQL string From dbd23ed612d77102d2ace76351f692e798dfcdfe Mon Sep 17 00:00:00 2001 From: Samir Faci Date: Tue, 15 Oct 2024 08:54:28 -0400 Subject: [PATCH 26/28] Addressing code review comments --- .circleci/config.yml | 4 ++-- .github/workflows/code_scanner.yml | 13 +++++-------- go.mod | 2 +- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 563616f..74f2bd8 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -8,7 +8,7 @@ jobs: build_and_tests: docker: # specify the version - - image: cimg/go:1.22.2 + - image: cimg/go:1.22.8 - image: cimg/postgres:14.10 environment: POSTGRES_USER: jet @@ -134,7 +134,7 @@ jobs: - run: name: Running tests - command: gotestsum --junitfile report.xml --format testname -- -coverprofile=cover.out -covermode=atomic -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... ./... + command: gotestsum --junitfile $TEST_RESULTS/report.xml --format testname -- -coverprofile=cover.out -covermode=atomic -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... ./... # run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests - run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/ diff --git a/.github/workflows/code_scanner.yml b/.github/workflows/code_scanner.yml index affc756..d4ecc5d 100644 --- a/.github/workflows/code_scanner.yml +++ b/.github/workflows/code_scanner.yml @@ -7,14 +7,11 @@ on: branches: - master -env: - GO_VERSION: 1.22.0 - - permissions: contents: read - # Optional: allow read access to pull request. Use with `only-new-issues` option. - # pull-requests: read + +env: + go_version: "1.22.8" jobs: @@ -25,7 +22,7 @@ jobs: uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: - go-version: 1.22.0 + go-version: ${{ env.go_version }} cache: true - name: Setup Tools run: | @@ -39,7 +36,7 @@ jobs: uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: - go-version: 1.22.0 + go-version: ${{ env.go_version }} cache: true - name: Setup Tools run: | diff --git a/go.mod b/go.mod index ffa977e..890dc0d 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/go-jet/jet/v2 -go 1.22.0 +go 1.18 require ( github.com/go-sql-driver/mysql v1.8.1 From 3fcbbec427774fa8bd3e812a4234dcece791456f Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 6 Oct 2024 14:21:42 +0200 Subject: [PATCH 27/28] Add support for Row expression. --- internal/jet/func_expression.go | 7 +- internal/jet/literal_expression.go | 26 ------- internal/jet/order_set_aggregate_functions.go | 7 +- internal/jet/row_expression.go | 78 +++++++++++++++++++ mysql/expressions.go | 8 ++ postgres/expressions.go | 8 ++ postgres/literal.go | 10 +++ sqlite/expressions.go | 8 ++ sqlite/functions.go | 6 +- tests/mysql/alltypes_test.go | 39 +++++++++- tests/mysql/with_test.go | 4 +- tests/postgres/alltypes_test.go | 50 ++++++++++-- tests/postgres/update_test.go | 5 +- tests/postgres/with_test.go | 12 +-- tests/sqlite/alltypes_test.go | 41 +++++++++- tests/sqlite/with_test.go | 8 +- 16 files changed, 254 insertions(+), 63 deletions(-) create mode 100644 internal/jet/row_expression.go diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 7e49880..ddc579e 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -12,11 +12,6 @@ func OR(expressions ...BoolExpression) BoolExpression { return newBoolExpressionListOperator("OR", expressions...) } -// ROW function is used to create a tuple value that consists of a set of expressions or column values. -func ROW(expressions ...Expression) Expression { - return NewFunc("ROW", expressions, nil) -} - // ------------------ Mathematical functions ---------------// // ABSf calculates absolute value from float expression @@ -711,7 +706,7 @@ func (p parametersSerializer) serialize(statement StatementType, out *SQLBuilder if _, isStatement := expression.(Statement); isStatement { expression.serialize(statement, out, options...) } else { - skipWrap(expression).serialize(statement, out, options...) + expression.serialize(statement, out, append(options, NoWrap, Ident)...) } } } diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index d6f0b41..251d3ab 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -374,32 +374,6 @@ func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, option //---------------------------------------------------// -type wrap struct { - ExpressionInterfaceImpl - expressions []Expression -} - -func (n *wrap) serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString("(") - - if len(n.expressions) == 1 { - options = append(options, NoWrap, Ident) - } - serializeExpressionList(statementType, n.expressions, ", ", out, options...) - - out.WriteString(")") -} - -// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... ) -func WRAP(expression ...Expression) Expression { - wrap := &wrap{expressions: expression} - wrap.ExpressionInterfaceImpl.Parent = wrap - - return wrap -} - -//---------------------------------------------------// - type rawExpression struct { ExpressionInterfaceImpl diff --git a/internal/jet/order_set_aggregate_functions.go b/internal/jet/order_set_aggregate_functions.go index 8ce5d1e..eff954a 100644 --- a/internal/jet/order_set_aggregate_functions.go +++ b/internal/jet/order_set_aggregate_functions.go @@ -54,7 +54,12 @@ type orderSetAggregateFuncExpression struct { func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString(p.name) - WRAP(p.fraction).serialize(statement, out, FallTrough(options)...) + + if p.fraction != nil { + WRAP(p.fraction).serialize(statement, out, FallTrough(options)...) + } else { + WRAP().serialize(statement, out, FallTrough(options)...) + } out.WriteString("WITHIN GROUP") p.orderBy.serialize(statement, out) } diff --git a/internal/jet/row_expression.go b/internal/jet/row_expression.go new file mode 100644 index 0000000..819cf15 --- /dev/null +++ b/internal/jet/row_expression.go @@ -0,0 +1,78 @@ +package jet + +// RowExpression interface +type RowExpression interface { + Expression + + EQ(rhs RowExpression) BoolExpression + NOT_EQ(rhs RowExpression) BoolExpression + IS_DISTINCT_FROM(rhs RowExpression) BoolExpression + IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression + + LT(rhs RowExpression) BoolExpression + LT_EQ(rhs RowExpression) BoolExpression + GT(rhs RowExpression) BoolExpression + GT_EQ(rhs RowExpression) BoolExpression +} + +type rowInterfaceImpl struct { + parent RowExpression +} + +func (n *rowInterfaceImpl) EQ(rhs RowExpression) BoolExpression { + return Eq(n.parent, rhs) +} + +func (n *rowInterfaceImpl) NOT_EQ(rhs RowExpression) BoolExpression { + return NotEq(n.parent, rhs) +} + +func (n *rowInterfaceImpl) IS_DISTINCT_FROM(rhs RowExpression) BoolExpression { + return IsDistinctFrom(n.parent, rhs) +} + +func (n *rowInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression { + return IsNotDistinctFrom(n.parent, rhs) +} + +func (n *rowInterfaceImpl) GT(rhs RowExpression) BoolExpression { + return Gt(n.parent, rhs) +} + +func (n *rowInterfaceImpl) GT_EQ(rhs RowExpression) BoolExpression { + return GtEq(n.parent, rhs) +} + +func (n *rowInterfaceImpl) LT(rhs RowExpression) BoolExpression { + return Lt(n.parent, rhs) +} + +func (n *rowInterfaceImpl) LT_EQ(rhs RowExpression) BoolExpression { + return LtEq(n.parent, rhs) +} + +//---------------------------------------------------// + +type rowExpressionWrapper struct { + rowInterfaceImpl + Expression +} + +// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. +// This enables the Go compiler to interpret any expression as a row expression +// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. +func RowExp(expression Expression) RowExpression { + rowExpressionWrap := rowExpressionWrapper{Expression: expression} + rowExpressionWrap.rowInterfaceImpl.parent = &rowExpressionWrap + return &rowExpressionWrap +} + +// ROW function is used to create a tuple value that consists of a set of expressions or column values. +func ROW(expressions ...Expression) RowExpression { + return RowExp(NewFunc("ROW", expressions, nil)) +} + +// WRAP creates row expressions without ROW keyword `( expression1, expression2, ... )`. +func WRAP(expressions ...Expression) RowExpression { + return RowExp(NewFunc("", expressions, nil)) +} diff --git a/mysql/expressions.go b/mysql/expressions.go index 53b1fa7..4073ef5 100644 --- a/mysql/expressions.go +++ b/mysql/expressions.go @@ -30,6 +30,9 @@ type DateTimeExpression = jet.TimestampExpression // TimestampExpression interface type TimestampExpression = jet.TimestampExpression +// RowExpression interface +type RowExpression = jet.RowExpression + // BoolExp is bool expression wrapper around arbitrary expression. // Allows go compiler to see any expression as bool expression. // Does not add sql cast to generated sql builder output. @@ -70,6 +73,11 @@ var DateTimeExp = jet.TimestampExp // Does not add sql cast to generated sql builder output. var TimestampExp = jet.TimestampExp +// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. +// This enables the Go compiler to interpret any expression as a row expression +// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. +var RowExp = jet.RowExp + // CustomExpression is used to define custom expressions. var CustomExpression = jet.CustomExpression diff --git a/postgres/expressions.go b/postgres/expressions.go index 9872910..d8ad34b 100644 --- a/postgres/expressions.go +++ b/postgres/expressions.go @@ -36,6 +36,9 @@ type TimestampExpression = jet.TimestampExpression // TimestampzExpression interface type TimestampzExpression = jet.TimestampzExpression +// RowExpression interface +type RowExpression = jet.RowExpression + // DateRange Expression interface type DateRange = jet.Range[DateExpression] @@ -99,6 +102,11 @@ var TimestampExp = jet.TimestampExp // Does not add sql cast to generated sql builder output. var TimestampzExp = jet.TimestampzExp +// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. +// This enables the Go compiler to interpret any expression as a row expression +// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. +var RowExp = jet.RowExp + // RangeExp is range expression wrapper around arbitrary expression. // Allows go compiler to see any expression as range expression. // Does not add sql cast to generated sql builder output. diff --git a/postgres/literal.go b/postgres/literal.go index e3a95b3..26b75d8 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -57,6 +57,16 @@ func Uint64(value uint64) IntegerExpression { // Float creates new float literal expression var Float = jet.Float +// Float32 is constructor for 32 bit float literals +func Float32(value float32) FloatExpression { + return CAST(jet.Literal(value)).AS_REAL() +} + +// Float64 is constructor for 64 bit float literals +func Float64(value float64) FloatExpression { + return CAST(jet.Literal(value)).AS_DOUBLE() +} + // Decimal creates new float literal expression var Decimal = jet.Decimal diff --git a/sqlite/expressions.go b/sqlite/expressions.go index 42ccc96..0b2d320 100644 --- a/sqlite/expressions.go +++ b/sqlite/expressions.go @@ -33,6 +33,9 @@ type DateTimeExpression = jet.TimestampExpression // TimestampExpression interface type TimestampExpression = jet.TimestampExpression +// RowExpression interface +type RowExpression = jet.RowExpression + // BoolExp is bool expression wrapper around arbitrary expression. // Allows go compiler to see any expression as bool expression. // Does not add sql cast to generated sql builder output. @@ -73,6 +76,11 @@ var DateTimeExp = jet.TimestampExp // Does not add sql cast to generated sql builder output. var TimestampExp = jet.TimestampExp +// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. +// This enables the Go compiler to interpret any expression as a row expression +// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. +var RowExp = jet.RowExp + // CustomExpression is used to define custom expressions. var CustomExpression = jet.CustomExpression diff --git a/sqlite/functions.go b/sqlite/functions.go index 47a0a5b..a76f236 100644 --- a/sqlite/functions.go +++ b/sqlite/functions.go @@ -15,10 +15,8 @@ var ( OR = jet.OR ) -// ROW is construct one table row from list of expressions. -func ROW(expressions ...Expression) Expression { - return jet.NewFunc("", expressions, nil) -} +// ROW is construct one row from a list of expressions. +var ROW = jet.WRAP // ------------------ Mathematical functions ---------------// diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index d5702bd..61bc7f2 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -97,18 +97,18 @@ func TestExpressionOperators(t *testing.T) { SELECT all_types.'integer' IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN (?, ?)) AS "result.in", - (all_types.small_int_ptr IN ( + (all_types.small_int_ptr IN (( SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types - )) AS "result.in_select", + ))) AS "result.in_select", (CURRENT_USER()) AS "result.raw", (? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg", (? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2", (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", - (all_types.small_int_ptr NOT IN ( + (all_types.small_int_ptr NOT IN (( SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types - )) AS "result.not_in_select" + ))) AS "result.not_in_select" FROM test_sample.all_types LIMIT ?; `, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2)) @@ -1404,3 +1404,34 @@ VALUES ('91.23', '45.67', '12.35', '56.79', 0.2, 0.22, 0.3, 0.33, 0.4, 0.44); require.Equal(t, 45.67, *result.Floats.DecimalPtr) }) } + +func TestRowExpression(t *testing.T) { + now := time.Now() + nowAddHour := time.Now().Add(time.Hour) + + stmt := SELECT( + ROW(Bool(false), DateT(now)).EQ(ROW(Bool(true), DateT(now))), + ROW(Bool(false), DateT(now)).NOT_EQ(ROW(Bool(true), DateT(now))), + ROW(TimestampT(nowAddHour), String("txt")).IS_DISTINCT_FROM(RowExp(Raw("row(NOW(), 'png')"))), + ROW(TimestampT(now), DateTimeT(nowAddHour)).GT(ROW(TimestampT(now), DateTimeT(now))), + ROW(DateTimeT(nowAddHour), Int(1)).GT_EQ(ROW(DateTimeT(now), Int(2))), + ROW(TimestampT(now), DateTimeT(nowAddHour)).LT(ROW(TimestampT(now), DateTimeT(now))), + ROW(DateTimeT(nowAddHour), Float(1.22)).LT_EQ(ROW(DateTimeT(now), Float(2.33))), + ) + + //fmt.Println(stmt.Sql()) + //fmt.Println(stmt.DebugSql()) + + testutils.AssertStatementSql(t, stmt, ` +SELECT ROW(?, CAST(? AS DATE)) = ROW(?, CAST(? AS DATE)), + ROW(?, CAST(? AS DATE)) != ROW(?, CAST(? AS DATE)), + NOT(ROW(TIMESTAMP(?), ?) <=> (row(NOW(), 'png'))), + ROW(TIMESTAMP(?), CAST(? AS DATETIME)) > ROW(TIMESTAMP(?), CAST(? AS DATETIME)), + ROW(CAST(? AS DATETIME), ?) >= ROW(CAST(? AS DATETIME), ?), + ROW(TIMESTAMP(?), CAST(? AS DATETIME)) < ROW(TIMESTAMP(?), CAST(? AS DATETIME)), + ROW(CAST(? AS DATETIME), ?) <= ROW(CAST(? AS DATETIME), ?); +`) + + err := stmt.Query(db, &struct{}{}) + require.NoError(t, err) +} diff --git a/tests/mysql/with_test.go b/tests/mysql/with_test.go index d7d8d3d..59da3a5 100644 --- a/tests/mysql/with_test.go +++ b/tests/mysql/with_test.go @@ -164,10 +164,10 @@ WITH payments_to_delete AS ( WHERE payment.amount < 0.5 ) DELETE FROM dvds.payment -WHERE payment.payment_id IN ( +WHERE payment.payment_id IN (( SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id" FROM payments_to_delete - ); + )); `, "''", "`")) tx, err := db.Begin() diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 70c4332..07350b8 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -347,24 +347,24 @@ func TestExpressionOperators(t *testing.T) { AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"), ).LIMIT(2) - //fmt.Println(query.Sql()) + // fmt.Println(query.Sql()) testutils.AssertStatementSql(t, query, ` SELECT all_types.integer IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN ($1::smallint, $2::smallint)) AS "result.in", - (all_types.small_int_ptr IN ( + (all_types.small_int_ptr IN (( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types - )) AS "result.in_select", + ))) AS "result.in_select", (CURRENT_USER) AS "result.raw", ($3 + COALESCE(all_types.small_int_ptr, 0) + $4) AS "result.raw_arg", ($5 + all_types.integer + $6 + $5 + $7 + $8) AS "result.raw_arg2", (all_types.small_int_ptr NOT IN ($9, $10::smallint, NULL)) AS "result.not_in", - (all_types.small_int_ptr NOT IN ( + (all_types.small_int_ptr NOT IN (( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types - )) AS "result.not_in_select" + ))) AS "result.not_in_select" FROM test_sample.all_types LIMIT $11; `, int8(11), int8(22), 78, 56, 11, 22, 33, 44, int64(11), int16(22), int64(2)) @@ -1111,6 +1111,46 @@ FROM test_sample.all_types; require.NoError(t, err) } +func TestRowExpression(t *testing.T) { + now := time.Now() + nowAddHour := time.Now().Add(time.Hour) + + stmt := SELECT( + ROW(Int32(1), Float32(11.22), String("john")).AS("row"), + WRAP(Int64(1), Float64(11.22), String("john")).AS("wrap"), + + ROW(Bool(false), DateT(now)).EQ(ROW(Bool(true), DateT(now))), + WRAP(Bool(false), DateT(now)).NOT_EQ(WRAP(Bool(true), DateT(now))), + + ROW(TimeT(nowAddHour)).IS_DISTINCT_FROM(RowExp(Raw("row(NOW()::time)"))), + ROW().IS_NOT_DISTINCT_FROM(ROW()), + + ROW(TimestampT(now), TimestampzT(nowAddHour)).GT(WRAP(TimestampT(now), TimestampzT(now))), + ROW(TimestampzT(nowAddHour)).GT_EQ(ROW(TimestampzT(now))), + WRAP(TimestampT(now), TimestampzT(nowAddHour)).LT(ROW(TimestampT(now), TimestampzT(now))), + ROW(TimestampzT(nowAddHour)).LT_EQ(ROW(TimestampzT(now))), + ) + + //fmt.Println(stmt.Sql()) + //fmt.Println(stmt.DebugSql()) + + testutils.AssertStatementSql(t, stmt, ` +SELECT ROW($1::integer, $2::real, $3::text) AS "row", + ($4::bigint, $5::double precision, $6::text) AS "wrap", + ROW($7::boolean, $8::date) = ROW($9::boolean, $10::date), + ($11::boolean, $12::date) != ($13::boolean, $14::date), + ROW($15::time without time zone) IS DISTINCT FROM (row(NOW()::time)), + ROW() IS NOT DISTINCT FROM ROW(), + ROW($16::timestamp without time zone, $17::timestamp with time zone) > ($18::timestamp without time zone, $19::timestamp with time zone), + ROW($20::timestamp with time zone) >= ROW($21::timestamp with time zone), + ($22::timestamp without time zone, $23::timestamp with time zone) < ROW($24::timestamp without time zone, $25::timestamp with time zone), + ROW($26::timestamp with time zone) <= ROW($27::timestamp with time zone); +`) + + err := stmt.Query(db, &struct{}{}) + require.NoError(t, err) +} + func TestSubQueryColumnReference(t *testing.T) { type expected struct { sql string diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index e0c7da2..103b420 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -344,7 +344,10 @@ func TestUpdateExecContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - testutils.AssertExecContextErr(ctx, t, updateStmt, db, "context deadline exceeded") + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := updateStmt.ExecContext(ctx, tx) + require.Error(t, err, "context deadline exceeded") + }) } func TestUpdateFrom(t *testing.T) { diff --git a/tests/postgres/with_test.go b/tests/postgres/with_test.go index 21fca32..92b5649 100644 --- a/tests/postgres/with_test.go +++ b/tests/postgres/with_test.go @@ -83,10 +83,10 @@ SELECT orders.ship_region AS "orders.ship_region", SUM(order_details.quantity) AS "product_sales" FROM northwind.orders INNER JOIN northwind.order_details ON (orders.order_id = order_details.order_id) -WHERE orders.ship_region IN ( +WHERE orders.ship_region IN (( SELECT top_region."orders.ship_region" AS "orders.ship_region" FROM top_region - ) + )) GROUP BY orders.ship_region, order_details.product_id ORDER BY SUM(order_details.quantity) DESC; `) @@ -157,19 +157,19 @@ func TestWithStatementDeleteAndInsert(t *testing.T) { testutils.AssertStatementSql(t, stmt, ` WITH remove_discontinued_orders AS ( DELETE FROM northwind.order_details - WHERE order_details.product_id IN ( + WHERE order_details.product_id IN (( SELECT products.product_id AS "products.product_id" FROM northwind.products WHERE products.discontinued = $1 - ) + )) RETURNING order_details.product_id AS "order_details.product_id" ),update_discontinued_price AS ( UPDATE northwind.products SET unit_price = $2 - WHERE products.product_id IN ( + WHERE products.product_id IN (( SELECT remove_discontinued_orders."order_details.product_id" AS "order_details.product_id" FROM remove_discontinued_orders - ) + )) RETURNING products.product_id AS "products.product_id", products.product_name AS "products.product_name", products.supplier_id AS "products.supplier_id", diff --git a/tests/sqlite/alltypes_test.go b/tests/sqlite/alltypes_test.go index 41e5cf7..080e870 100644 --- a/tests/sqlite/alltypes_test.go +++ b/tests/sqlite/alltypes_test.go @@ -234,18 +234,18 @@ func TestExpressionOperators(t *testing.T) { SELECT all_types.integer IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN (?, ?)) AS "result.in", - (all_types.small_int_ptr IN ( + (all_types.small_int_ptr IN (( SELECT all_types.integer AS "all_types.integer" FROM all_types - )) AS "result.in_select", + ))) AS "result.in_select", (length(121232459)) AS "result.raw", (? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg", (? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2", (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", - (all_types.small_int_ptr NOT IN ( + (all_types.small_int_ptr NOT IN (( SELECT all_types.integer AS "all_types.integer" FROM all_types - )) AS "result.not_in_select" + ))) AS "result.not_in_select" FROM all_types LIMIT ?; `, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2)) @@ -900,3 +900,36 @@ func TestDateTimeExpressions(t *testing.T) { require.Equal(t, dest.JulianDay, 2.4551543576232754e+06) require.Equal(t, dest.StrfTime, "20:34") } + +func TestRowExpression(t *testing.T) { + date := Date(2000, 9, 9) + time := Time(11, 22, 11) + dateTime := DateTime(2008, 11, 22, 10, 12, 40) + dateTime2 := DateTime(2011, 1, 2, 5, 12, 40) + + stmt := SELECT( + ROW(Bool(false), date).EQ(ROW(Bool(true), date)), + ROW(Bool(false), time).NOT_EQ(ROW(Bool(true), time)), + ROW(time).IS_DISTINCT_FROM(RowExp(Raw("(time('now'))"))), + ROW(dateTime, dateTime2).GT(ROW(dateTime, dateTime2)), + ROW(dateTime2).GT_EQ(ROW(dateTime)), + ROW(dateTime, dateTime2).LT(ROW(dateTime, dateTime2)), + ROW(dateTime2).LT_EQ(ROW(dateTime2)), + ) + + //fmt.Println(stmt.Sql()) + //fmt.Println(stmt.DebugSql()) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT (FALSE, DATE('2000-09-09')) = (TRUE, DATE('2000-09-09')), + (FALSE, TIME('11:22:11')) != (TRUE, TIME('11:22:11')), + (TIME('11:22:11')) IS NOT ((time('now'))), + (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')) > (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')), + (DATETIME('2011-01-02 05:12:40')) >= (DATETIME('2008-11-22 10:12:40')), + (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')) < (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')), + (DATETIME('2011-01-02 05:12:40')) <= (DATETIME('2011-01-02 05:12:40')); +`) + + err := stmt.Query(db, &struct{}{}) + require.NoError(t, err) +} diff --git a/tests/sqlite/with_test.go b/tests/sqlite/with_test.go index 402df2f..485655c 100644 --- a/tests/sqlite/with_test.go +++ b/tests/sqlite/with_test.go @@ -153,10 +153,10 @@ WITH payments_to_update AS ( ) UPDATE payment SET amount = 0 -WHERE payment.payment_id IN ( +WHERE payment.payment_id IN (( SELECT payments_to_update.''payment.payment_id'' AS "payment.payment_id" FROM payments_to_update - ); + )); `, "''", "`", -1)) tx := beginDBTx(t) @@ -205,10 +205,10 @@ WITH payments_to_delete AS ( WHERE payment.amount < 0.5 ) DELETE FROM payment -WHERE payment.payment_id IN ( +WHERE payment.payment_id IN (( SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id" FROM payments_to_delete - ); + )); `, "''", "`", -1)) tx := beginDBTx(t) From 8d112f7db8ec674496bf5e8a0919c892b8827fd9 Mon Sep 17 00:00:00 2001 From: go-jet Date: Thu, 17 Oct 2024 14:12:21 +0200 Subject: [PATCH 28/28] Add support for VALUES statement. --- internal/jet/dialect.go | 8 + internal/jet/expression.go | 17 +- internal/jet/order_set_aggregate_functions.go | 4 +- internal/jet/projection_test.go | 2 +- internal/jet/raw_statement.go | 2 +- internal/jet/row_expression.go | 48 ++- internal/jet/select_table.go | 28 +- internal/jet/values.go | 35 ++ internal/jet/with_statement.go | 11 +- mysql/cast.go | 30 +- mysql/dialect.go | 4 + mysql/functions.go | 4 +- mysql/select_statement.go | 2 +- mysql/select_table.go | 4 +- mysql/set_statement.go | 2 +- mysql/statement.go | 2 +- mysql/values.go | 32 ++ mysql/with_statement.go | 6 +- postgres/dialect.go | 4 + postgres/dialect_test.go | 16 +- postgres/functions.go | 14 +- postgres/insert_statement_test.go | 26 +- postgres/select_statement.go | 2 +- postgres/select_table.go | 6 +- postgres/set_statement.go | 2 +- postgres/values.go | 32 ++ postgres/with_statement.go | 6 +- sqlite/dialect.go | 4 + sqlite/functions.go | 6 +- sqlite/select_statement.go | 2 +- sqlite/select_table.go | 4 +- sqlite/set_statement.go | 2 +- sqlite/values.go | 26 ++ sqlite/with_statement.go | 10 +- tests/mysql/main_test.go | 6 + tests/mysql/values_test.go | 347 ++++++++++++++++++ tests/postgres/alltypes_test.go | 9 +- tests/postgres/northwind_test.go | 32 +- tests/postgres/values_test.go | 284 ++++++++++++++ tests/sqlite/update_test.go | 2 +- tests/sqlite/values_test.go | 344 +++++++++++++++++ 41 files changed, 1296 insertions(+), 131 deletions(-) create mode 100644 internal/jet/values.go create mode 100644 mysql/values.go create mode 100644 postgres/values.go create mode 100644 sqlite/values.go create mode 100644 tests/mysql/values_test.go create mode 100644 tests/postgres/values_test.go create mode 100644 tests/sqlite/values_test.go diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index f3ad2b4..68c4c02 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -13,6 +13,7 @@ type Dialect interface { ArgumentPlaceholder() QueryPlaceholderFunc IsReservedWord(name string) bool SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc + ValuesDefaultColumnName(index int) string } // SerializerFunc func @@ -35,6 +36,7 @@ type DialectParams struct { ArgumentPlaceholder QueryPlaceholderFunc ReservedWords []string SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc + ValuesDefaultColumnName func(index int) string } // NewDialect creates new dialect with params @@ -49,6 +51,7 @@ func NewDialect(params DialectParams) Dialect { argumentPlaceholder: params.ArgumentPlaceholder, reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords), serializeOrderBy: params.SerializeOrderBy, + valuesDefaultColumnName: params.ValuesDefaultColumnName, } } @@ -62,6 +65,7 @@ type dialectImpl struct { argumentPlaceholder QueryPlaceholderFunc reservedWords map[string]bool serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc + valuesDefaultColumnName func(index int) string } func (d *dialectImpl) Name() string { @@ -107,6 +111,10 @@ func (d *dialectImpl) SerializeOrderBy() func(expression Expression, ascending, return d.serializeOrderBy } +func (d *dialectImpl) ValuesDefaultColumnName(index int) string { + return d.valuesDefaultColumnName(index) +} + func arrayOfStringsToMapOfStrings(arr []string) map[string]bool { ret := map[string]bool{} for _, elem := range arr { diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 05b1797..9999803 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -51,12 +51,12 @@ func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression { // IN checks if this expressions matches any in expressions list func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression { - return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "IN") + return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "IN") } // NOT_IN checks if this expressions is different of all expressions in expressions list func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression { - return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "NOT IN") + return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "NOT IN") } // AS the temporary alias name to assign to the expression @@ -316,15 +316,6 @@ func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder, } } -type skipParenthesisWrap struct { - Expression -} - -func skipWrap(expression Expression) Expression { - return &skipParenthesisWrap{expression} -} - -// since the expression is a function parameter, there is no need to wrap it in parentheses -func (s *skipParenthesisWrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - s.Expression.serialize(statement, out, append(options, NoWrap)...) +func wrap(expressions ...Expression) Expression { + return NewFunc("", expressions, nil) } diff --git a/internal/jet/order_set_aggregate_functions.go b/internal/jet/order_set_aggregate_functions.go index eff954a..c853845 100644 --- a/internal/jet/order_set_aggregate_functions.go +++ b/internal/jet/order_set_aggregate_functions.go @@ -56,9 +56,9 @@ func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out out.WriteString(p.name) if p.fraction != nil { - WRAP(p.fraction).serialize(statement, out, FallTrough(options)...) + wrap(p.fraction).serialize(statement, out, FallTrough(options)...) } else { - WRAP().serialize(statement, out, FallTrough(options)...) + wrap().serialize(statement, out, FallTrough(options)...) } out.WriteString("WITHIN GROUP") p.orderBy.serialize(statement, out) diff --git a/internal/jet/projection_test.go b/internal/jet/projection_test.go index 0370b43..61dd920 100644 --- a/internal/jet/projection_test.go +++ b/internal/jet/projection_test.go @@ -39,7 +39,7 @@ AVG(table1.col_int) AS "avg", table2.col3 AS "col3", table2.col4 AS "col4"`) - subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery")) + subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery", nil)) assertProjectionSerialize(t, subQueryProjections, `"subQuery"."table1.col3" AS "table1.col3", diff --git a/internal/jet/raw_statement.go b/internal/jet/raw_statement.go index 191c7b4..99fb8eb 100644 --- a/internal/jet/raw_statement.go +++ b/internal/jet/raw_statement.go @@ -8,7 +8,7 @@ type rawStatementImpl struct { } // RawStatement creates new sql statements from raw query and optional map of named arguments -func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) Statement { +func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) SerializerStatement { newRawStatement := rawStatementImpl{ serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ dialect: dialect, diff --git a/internal/jet/row_expression.go b/internal/jet/row_expression.go index 819cf15..e7d5ed5 100644 --- a/internal/jet/row_expression.go +++ b/internal/jet/row_expression.go @@ -3,6 +3,7 @@ package jet // RowExpression interface type RowExpression interface { Expression + HasProjections EQ(rhs RowExpression) BoolExpression NOT_EQ(rhs RowExpression) BoolExpression @@ -16,7 +17,9 @@ type RowExpression interface { } type rowInterfaceImpl struct { - parent RowExpression + parent Expression + dialect Dialect + elemCount int } func (n *rowInterfaceImpl) EQ(rhs RowExpression) BoolExpression { @@ -51,13 +54,44 @@ func (n *rowInterfaceImpl) LT_EQ(rhs RowExpression) BoolExpression { return LtEq(n.parent, rhs) } -//---------------------------------------------------// +func (n *rowInterfaceImpl) projections() ProjectionList { + var ret ProjectionList + for i := 0; i < n.elemCount; i++ { + rowColumn := NewColumnImpl(n.dialect.ValuesDefaultColumnName(i), "", nil) + ret = append(ret, &rowColumn) + } + + return ret +} + +// ---------------------------------------------------// type rowExpressionWrapper struct { rowInterfaceImpl Expression } +func newRowExpression(name string, dialect Dialect, expressions ...Expression) RowExpression { + ret := &rowExpressionWrapper{} + ret.rowInterfaceImpl.parent = ret + + ret.Expression = NewFunc(name, expressions, ret) + ret.dialect = dialect + ret.elemCount = len(expressions) + + return ret +} + +// ROW function is used to create a tuple value that consists of a set of expressions or column values. +func ROW(dialect Dialect, expressions ...Expression) RowExpression { + return newRowExpression("ROW", dialect, expressions...) +} + +// WRAP creates row expressions without ROW keyword `( expression1, expression2, ... )`. +func WRAP(dialect Dialect, expressions ...Expression) RowExpression { + return newRowExpression("", dialect, expressions...) +} + // RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. // This enables the Go compiler to interpret any expression as a row expression // Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. @@ -66,13 +100,3 @@ func RowExp(expression Expression) RowExpression { rowExpressionWrap.rowInterfaceImpl.parent = &rowExpressionWrap return &rowExpressionWrap } - -// ROW function is used to create a tuple value that consists of a set of expressions or column values. -func ROW(expressions ...Expression) RowExpression { - return RowExp(NewFunc("ROW", expressions, nil)) -} - -// WRAP creates row expressions without ROW keyword `( expression1, expression2, ... )`. -func WRAP(expressions ...Expression) RowExpression { - return RowExp(NewFunc("", expressions, nil)) -} diff --git a/internal/jet/select_table.go b/internal/jet/select_table.go index c25fba3..f1f58ac 100644 --- a/internal/jet/select_table.go +++ b/internal/jet/select_table.go @@ -8,15 +8,21 @@ type SelectTable interface { } type selectTableImpl struct { - Statement SerializerHasProjections - alias string + Statement SerializerHasProjections + alias string + columnAliases []ColumnExpression } // NewSelectTable func -func NewSelectTable(selectStmt SerializerHasProjections, alias string) selectTableImpl { +func NewSelectTable(selectStmt SerializerHasProjections, alias string, columnAliases []ColumnExpression) selectTableImpl { selectTable := selectTableImpl{ - Statement: selectStmt, - alias: alias, + Statement: selectStmt, + alias: alias, + columnAliases: columnAliases, + } + + for _, column := range selectTable.columnAliases { + column.setSubQuery(selectTable) } return selectTable @@ -31,6 +37,10 @@ func (s selectTableImpl) Alias() string { } func (s selectTableImpl) AllColumns() ProjectionList { + if len(s.columnAliases) > 0 { + return ColumnListToProjectionList(s.columnAliases) + } + projectionList := s.projections().fromImpl(s) return projectionList.(ProjectionList) } @@ -40,6 +50,12 @@ func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, opt out.WriteString("AS") out.WriteIdentifier(s.alias) + + if len(s.columnAliases) > 0 { + out.WriteByte('(') + SerializeColumnExpressionNames(s.columnAliases, out) + out.WriteByte(')') + } } // -------------------------------------- @@ -50,7 +66,7 @@ type lateralImpl struct { // NewLateral creates new lateral expression from select statement with alias func NewLateral(selectStmt SerializerStatement, alias string) SelectTable { - return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias)} + return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias, nil)} } func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { diff --git a/internal/jet/values.go b/internal/jet/values.go new file mode 100644 index 0000000..9d0f691 --- /dev/null +++ b/internal/jet/values.go @@ -0,0 +1,35 @@ +package jet + +// Values hold a set of one or more rows +type Values []RowExpression + +func (v Values) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteByte('(') + out.IncreaseIdent(5) + + out.NewLine() + out.WriteString("VALUES") + + for rowIndex, row := range v { + if rowIndex > 0 { + out.WriteString(",") + out.NewLine() + } else { + out.IncreaseIdent(7) + } + + row.serialize(statement, out, options...) + } + out.DecreaseIdent(7) + out.DecreaseIdent(5) + out.NewLine() + out.WriteByte(')') +} + +func (v Values) projections() ProjectionList { + if len(v) == 0 { + return nil + } + + return v[0].projections() +} diff --git a/internal/jet/with_statement.go b/internal/jet/with_statement.go index 783fa27..03330e6 100644 --- a/internal/jet/with_statement.go +++ b/internal/jet/with_statement.go @@ -64,7 +64,7 @@ type CommonTableExpression struct { // CTE creates new named CommonTableExpression func CTE(name string, columns ...ColumnExpression) CommonTableExpression { cte := CommonTableExpression{ - selectTableImpl: NewSelectTable(nil, name), + selectTableImpl: NewSelectTable(nil, name, columns), Columns: columns, } @@ -99,12 +99,3 @@ func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilde out.WriteIdentifier(c.alias) } } - -// AllColumns returns list of all projections in the CTE -func (c CommonTableExpression) AllColumns() ProjectionList { - if len(c.Columns) > 0 { - return ColumnListToProjectionList(c.Columns) - } - - return c.selectTableImpl.AllColumns() -} diff --git a/mysql/cast.go b/mysql/cast.go index 83a0578..ca647a5 100644 --- a/mysql/cast.go +++ b/mysql/cast.go @@ -6,23 +6,27 @@ import ( ) type cast interface { - // Cast expressions as castType type + // AS casts expressions as castType type AS(castType string) Expression - // Cast expression as char with optional length + // AS_CHAR casts expression as char with optional length AS_CHAR(length ...int) StringExpression - // Cast expression AS date type + // AS_DATE casts expression AS date type AS_DATE() DateExpression - // Cast expression AS numeric type, using precision and optionally scale + // AS_FLOAT casts expressions as float type + AS_FLOAT() FloatExpression + // AS_DOUBLE casts expressions as double type + AS_DOUBLE() FloatExpression + // AS_DECIMAL casts expression AS numeric type AS_DECIMAL() FloatExpression - // Cast expression AS time type + // AS_TIME casts expression AS time type AS_TIME() TimeExpression - // Cast expression as datetime type + // AS_DATETIME casts expression as datetime type AS_DATETIME() DateTimeExpression - // Cast expressions as signed integer type + // AS_SIGNED casts expressions as signed integer type AS_SIGNED() IntegerExpression - // Cast expression as unsigned integer type + // AS_UNSIGNED casts expression as unsigned integer type AS_UNSIGNED() IntegerExpression - // Cast expression as binary type + // AS_BINARY casts expression as binary type AS_BINARY() StringExpression } @@ -73,6 +77,14 @@ func (c *castImpl) AS_DATE() DateExpression { return DateExp(c.AS("DATE")) } +func (c *castImpl) AS_FLOAT() FloatExpression { + return FloatExp(c.AS("FLOAT")) +} + +func (c *castImpl) AS_DOUBLE() FloatExpression { + return FloatExp(c.AS("DOUBLE")) +} + // AS_DECIMAL casts expression AS DECIMAL type func (c *castImpl) AS_DECIMAL() FloatExpression { return FloatExp(c.AS("DECIMAL")) diff --git a/mysql/dialect.go b/mysql/dialect.go index 18d2eec..9628bfb 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -1,6 +1,7 @@ package mysql import ( + "fmt" "github.com/go-jet/jet/v2/internal/jet" ) @@ -28,6 +29,9 @@ func newDialect() jet.Dialect { }, ReservedWords: reservedWords, SerializeOrderBy: serializeOrderBy, + ValuesDefaultColumnName: func(index int) string { + return fmt.Sprintf("column_%d", index) + }, } return jet.NewDialect(mySQLDialectParams) diff --git a/mysql/functions.go b/mysql/functions.go index ca31d18..ceec7ab 100644 --- a/mysql/functions.go +++ b/mysql/functions.go @@ -12,7 +12,9 @@ var ( ) // ROW function is used to create a tuple value that consists of a set of expressions or column values. -var ROW = jet.ROW +func ROW(expressions ...Expression) RowExpression { + return jet.ROW(Dialect, expressions...) +} // ------------------ Mathematical functions ---------------// diff --git a/mysql/select_statement.go b/mysql/select_statement.go index 6c5f345..aaeff9a 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -172,7 +172,7 @@ func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement { } func (s *selectStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } //----------------------------------------------------- diff --git a/mysql/select_table.go b/mysql/select_table.go index ad22193..8ca06d4 100644 --- a/mysql/select_table.go +++ b/mysql/select_table.go @@ -13,9 +13,9 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { +func newSelectTable(selectStmt jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable { subQuery := &selectTableImpl{ - SelectTable: jet.NewSelectTable(selectStmt, alias), + SelectTable: jet.NewSelectTable(selectStmt, alias, columnAliases), } subQuery.readableTableInterfaceImpl.parent = subQuery diff --git a/mysql/set_statement.go b/mysql/set_statement.go index 2df75a0..7147f00 100644 --- a/mysql/set_statement.go +++ b/mysql/set_statement.go @@ -85,7 +85,7 @@ func (s *setStatementImpl) OFFSET(offset int64) setStatement { } func (s *setStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } const ( diff --git a/mysql/statement.go b/mysql/statement.go index 073adce..883b99d 100644 --- a/mysql/statement.go +++ b/mysql/statement.go @@ -3,6 +3,6 @@ package mysql import "github.com/go-jet/jet/v2/internal/jet" // RawStatement creates new sql statements from raw query and optional map of named arguments -func RawStatement(rawQuery string, namedArguments ...RawArgs) Statement { +func RawStatement(rawQuery string, namedArguments ...RawArgs) jet.SerializerStatement { return jet.RawStatement(Dialect, rawQuery, namedArguments...) } diff --git a/mysql/values.go b/mysql/values.go new file mode 100644 index 0000000..b55895b --- /dev/null +++ b/mysql/values.go @@ -0,0 +1,32 @@ +package mysql + +import "github.com/go-jet/jet/v2/internal/jet" + +type values struct { + jet.Values +} + +// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table. +// Each row is defined by the ROW constructor, which takes one or more expressions. +// +// Example usage: +// +// VALUES( +// ROW(Int32(204), Float32(1.21)), +// ROW(Int32(207), Float32(1.02)), +// ) +func VALUES(rows ...RowExpression) values { + return values{Values: jet.Values(rows)} +} + +// AS assigns an alias to the temporary VALUES table, allowing it to be referenced +// within SQL FROM clauses, just like a regular table. +// By default, VALUES columns are named `column1`, `column2`, etc... Default column aliasing can be +// overwritten by passing new list of columns. +// +// Example usage: +// +// VALUES(...).AS("film_values", IntegerColumn("length"), TimestampColumn("update_date")) +func (v values) AS(alias string, columns ...Column) SelectTable { + return newSelectTable(v, alias, columns) +} diff --git a/mysql/with_statement.go b/mysql/with_statement.go index ca608cb..03b4d5b 100644 --- a/mysql/with_statement.go +++ b/mysql/with_statement.go @@ -6,7 +6,7 @@ import "github.com/go-jet/jet/v2/internal/jet" type CommonTableExpression interface { SelectTable - AS(statement jet.SerializerStatement) CommonTableExpression + AS(statement jet.SerializerHasProjections) CommonTableExpression // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. ALIAS(alias string) SelectTable @@ -41,7 +41,7 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { } // AS is used to define a CTE query -func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { +func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression { c.CommonTableExpression.Statement = statement return c } @@ -52,7 +52,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. func (c *commonTableExpression) ALIAS(name string) SelectTable { - return newSelectTable(c, name) + return newSelectTable(c, name, nil) } func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { diff --git a/postgres/dialect.go b/postgres/dialect.go index 9484ab1..14929ad 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -1,6 +1,7 @@ package postgres import ( + "fmt" "github.com/go-jet/jet/v2/internal/jet" "strconv" ) @@ -25,6 +26,9 @@ func newDialect() jet.Dialect { return "$" + strconv.Itoa(ord) }, ReservedWords: reservedWords, + ValuesDefaultColumnName: func(index int) string { + return fmt.Sprintf("column%d", index+1) + }, } return jet.NewDialect(dialectParams) diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index 9aadbc9..6fb987c 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -46,33 +46,33 @@ func TestExists(t *testing.T) { func TestIN(t *testing.T) { assertSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), - `($1 IN ( + `($1 IN (( SELECT table1.col1 AS "table1.col1" FROM db.table1 -))`, float64(1.11)) +)))`, float64(1.11)) assertSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), - `(ROW($1, table1.col1) IN ( + `(ROW($1, table1.col1) IN (( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2 -))`, int64(12)) +)))`, int64(12)) } func TestNOT_IN(t *testing.T) { assertSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), - `($1 NOT IN ( + `($1 NOT IN (( SELECT table1.col1 AS "table1.col1" FROM db.table1 -))`, float64(1.11)) +)))`, float64(1.11)) assertSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), - `(ROW($1, table1.col1) NOT IN ( + `(ROW($1, table1.col1) NOT IN (( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2 -))`, int64(12)) +)))`, int64(12)) } func TestReservedWordEscaped(t *testing.T) { diff --git a/postgres/functions.go b/postgres/functions.go index 7e71a96..3134325 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -14,7 +14,9 @@ var ( ) // ROW function is used to create a tuple value that consists of a set of expressions or column values. -var ROW = jet.ROW +func ROW(expressions ...Expression) RowExpression { + return jet.ROW(Dialect, expressions...) +} // ------------------ Mathematical functions ---------------// @@ -425,10 +427,12 @@ func castFloatLiteral(fraction FloatExpression) FloatExpression { // ), var GROUPING_SETS = jet.GROUPING_SETS -// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... ) -// The construct (a, b) is normally recognized in expressions as a row constructor. WRAP and ROW method behave exactly the same, -// except when used in GROUPING_SETS. For top level GROUPING SETS expression lists WRAP has to be used. -var WRAP = jet.WRAP +// WRAP surrounds a list of expressions or columns with parentheses, producing new row: (expression1, expression2, ...) +// The construct (a, b) is normally recognized in expressions as a row constructor. WRAP and ROW methods behave exactly the same, +// except when used in GROUPING_SETS and VALUES. In these contexts, WRAP must be used instead of ROW. +func WRAP(expressions ...Expression) RowExpression { + return jet.WRAP(Dialect, expressions...) +} // ROLLUP operator is used with the GROUP BY clause to generate all prefixes of a group of columns including the empty list. // It creates extra rows in the result set that represent the subtotal values for each combination of columns. diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index 25300c2..5ace301 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -1,7 +1,6 @@ package postgres import ( - "github.com/go-jet/jet/v2/internal/jet" "github.com/stretchr/testify/require" "testing" "time" @@ -151,12 +150,13 @@ func TestInsert_ON_CONFLICT(t *testing.T) { VALUES("one", "two"). VALUES("1", "2"). VALUES("theta", "beta"). - ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE( - SET(table1ColBool.SET(Bool(true)), - table2ColInt.SET(Int(1)), - ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))), - ).WHERE(table1Col1.GT(Int(2))), - ). + ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()). + DO_UPDATE( + SET(table1ColBool.SET(Bool(true)), + table2ColInt.SET(Int(1)), + ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))), + ).WHERE(table1Col1.GT(Int(2))), + ). RETURNING(table1Col1, table1ColBool) assertDebugStatementSql(t, stmt, ` @@ -178,12 +178,12 @@ func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) { stmt := table1.INSERT(table1Col1, table1ColBool). VALUES("one", "two"). VALUES("1", "2"). - ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE( - SET(table1ColBool.SET(Bool(false)), - table2ColInt.SET(Int(1)), - ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))), - ).WHERE(table1Col1.GT(Int(2))), - ). + ON_CONFLICT().ON_CONSTRAINT("idk_primary_key"). + DO_UPDATE( + SET(table1ColBool.SET(Bool(false)), + table2ColInt.SET(Int(1)), + ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))), + ).WHERE(table1Col1.GT(Int(2)))). RETURNING(table1Col1, table1ColBool) assertDebugStatementSql(t, stmt, ` diff --git a/postgres/select_statement.go b/postgres/select_statement.go index 70a9a50..2a48fc5 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -181,7 +181,7 @@ func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement { } func (s *selectStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } //----------------------------------------------------- diff --git a/postgres/select_table.go b/postgres/select_table.go index f3d680d..5f57c1d 100644 --- a/postgres/select_table.go +++ b/postgres/select_table.go @@ -2,7 +2,7 @@ package postgres import "github.com/go-jet/jet/v2/internal/jet" -// SelectTable is interface for postgres sub-queries +// SelectTable is interface for postgres temporary tables like sub-queries, VALUES, CTEs etc... type SelectTable interface { readableTable jet.SelectTable @@ -13,9 +13,9 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { +func newSelectTable(serializerWithProjections jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable { subQuery := &selectTableImpl{ - SelectTable: jet.NewSelectTable(selectStmt, alias), + SelectTable: jet.NewSelectTable(serializerWithProjections, alias, columnAliases), } subQuery.readableTableInterfaceImpl.parent = subQuery diff --git a/postgres/set_statement.go b/postgres/set_statement.go index 834560d..0dee00d 100644 --- a/postgres/set_statement.go +++ b/postgres/set_statement.go @@ -136,7 +136,7 @@ func (s *setStatementImpl) OFFSET_e(offset IntegerExpression) setStatement { } func (s *setStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } const ( diff --git a/postgres/values.go b/postgres/values.go new file mode 100644 index 0000000..28d17cd --- /dev/null +++ b/postgres/values.go @@ -0,0 +1,32 @@ +package postgres + +import "github.com/go-jet/jet/v2/internal/jet" + +type values struct { + jet.Values +} + +// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table. +// Each row is defined by the WRAP constructor, which takes one or more expressions. +// +// Example usage: +// +// VALUES( +// WRAP(Int32(204), Float32(1.21)), +// WRAP(Int32(207), Float32(1.02)), +// ) +func VALUES(rows ...RowExpression) values { + return values{Values: jet.Values(rows)} +} + +// AS assigns an alias to the temporary VALUES table, allowing it to be referenced +// within SQL FROM clauses, just like a regular table. +// By default, VALUES columns are named `column1`, `column2`, etc... Default column aliasing can be +// overwritten by passing new list of columns. +// +// Example usage: +// +// VALUES(...).AS("film_values", IntegerColumn("length"), TimestampColumn("update_date")) +func (v values) AS(alias string, columns ...Column) SelectTable { + return newSelectTable(v, alias, columns) +} diff --git a/postgres/with_statement.go b/postgres/with_statement.go index 698d6e3..99ddc8f 100644 --- a/postgres/with_statement.go +++ b/postgres/with_statement.go @@ -6,7 +6,7 @@ import "github.com/go-jet/jet/v2/internal/jet" type CommonTableExpression interface { SelectTable - AS(statement jet.SerializerStatement) CommonTableExpression + AS(statement jet.SerializerHasProjections) CommonTableExpression AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. ALIAS(alias string) SelectTable @@ -42,7 +42,7 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { } // AS is used to define a CTE query -func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { +func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression { c.CommonTableExpression.Statement = statement return c } @@ -60,7 +60,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. func (c *commonTableExpression) ALIAS(name string) SelectTable { - return newSelectTable(c, name) + return newSelectTable(c, name, nil) } func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { diff --git a/sqlite/dialect.go b/sqlite/dialect.go index 93e1d2f..da03364 100644 --- a/sqlite/dialect.go +++ b/sqlite/dialect.go @@ -1,6 +1,7 @@ package sqlite import ( + "fmt" "github.com/go-jet/jet/v2/internal/jet" ) @@ -23,6 +24,9 @@ func newDialect() jet.Dialect { return "?" }, ReservedWords: reservedWords2, + ValuesDefaultColumnName: func(index int) string { + return fmt.Sprintf("column%d", index+1) + }, } return jet.NewDialect(mySQLDialectParams) diff --git a/sqlite/functions.go b/sqlite/functions.go index a76f236..ac6bd08 100644 --- a/sqlite/functions.go +++ b/sqlite/functions.go @@ -15,8 +15,10 @@ var ( OR = jet.OR ) -// ROW is construct one row from a list of expressions. -var ROW = jet.WRAP +// ROW function is used to create a tuple value that consists of a set of expressions or column values. +func ROW(expressions ...Expression) RowExpression { + return jet.WRAP(Dialect, expressions...) +} // ------------------ Mathematical functions ---------------// diff --git a/sqlite/select_statement.go b/sqlite/select_statement.go index e74cda6..531ae76 100644 --- a/sqlite/select_statement.go +++ b/sqlite/select_statement.go @@ -155,7 +155,7 @@ func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement { } func (s *selectStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } //----------------------------------------------------- diff --git a/sqlite/select_table.go b/sqlite/select_table.go index 9ac7f72..7421f4b 100644 --- a/sqlite/select_table.go +++ b/sqlite/select_table.go @@ -13,9 +13,9 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { +func newSelectTable(selectStmt jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable { subQuery := &selectTableImpl{ - SelectTable: jet.NewSelectTable(selectStmt, alias), + SelectTable: jet.NewSelectTable(selectStmt, alias, columnAliases), } subQuery.readableTableInterfaceImpl.parent = subQuery diff --git a/sqlite/set_statement.go b/sqlite/set_statement.go index 0a004bf..47723a3 100644 --- a/sqlite/set_statement.go +++ b/sqlite/set_statement.go @@ -86,7 +86,7 @@ func (s *setStatementImpl) OFFSET(offset int64) setStatement { } func (s *setStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } const ( diff --git a/sqlite/values.go b/sqlite/values.go new file mode 100644 index 0000000..cb683cb --- /dev/null +++ b/sqlite/values.go @@ -0,0 +1,26 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +type values struct { + jet.Values +} + +// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table. +// Each row is defined by the ROW constructor, which takes one or more expressions. +// +// Example usage: +// +// VALUES( +// ROW(Int32(204), Float32(1.21)), +// ROW(Int32(207), Float32(1.02)), +// ) +func VALUES(rows ...RowExpression) values { + return values{Values: jet.Values(rows)} +} + +// AS assigns an alias to the temporary VALUES table, allowing it to be referenced +// within SQL FROM clauses, just like a regular table. +func (v values) AS(alias string) SelectTable { + return newSelectTable(v, alias, nil) +} diff --git a/sqlite/with_statement.go b/sqlite/with_statement.go index 5375fff..b05da7d 100644 --- a/sqlite/with_statement.go +++ b/sqlite/with_statement.go @@ -6,8 +6,8 @@ import "github.com/go-jet/jet/v2/internal/jet" type CommonTableExpression interface { SelectTable - AS(statement jet.SerializerStatement) CommonTableExpression - AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression + AS(statement jet.SerializerHasProjections) CommonTableExpression + AS_NOT_MATERIALIZED(statement jet.SerializerHasProjections) CommonTableExpression // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. ALIAS(alias string) SelectTable @@ -42,13 +42,13 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { } // AS is used to define a CTE query -func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { +func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression { c.CommonTableExpression.Statement = statement return c } // AS_NOT_MATERIALIZED is used to define not materialized CTE query -func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression { +func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerHasProjections) CommonTableExpression { c.CommonTableExpression.NotMaterialized = true c.CommonTableExpression.Statement = statement return c @@ -60,7 +60,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. func (c *commonTableExpression) ALIAS(name string) SelectTable { - return newSelectTable(c, name) + return newSelectTable(c, name, nil) } func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index 4e3b5d5..b67b901 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -92,3 +92,9 @@ func skipForMariaDB(t *testing.T) { t.SkipNow() } } + +func onlyMariaDB(t *testing.T) { + if !sourceIsMariaDB() { + t.SkipNow() + } +} diff --git a/tests/mysql/values_test.go b/tests/mysql/values_test.go new file mode 100644 index 0000000..09e9332 --- /dev/null +++ b/tests/mysql/values_test.go @@ -0,0 +1,347 @@ +package mysql + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/stretchr/testify/require" + "strings" + "testing" + "time" + + . "github.com/go-jet/jet/v2/mysql" + + "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table" +) + +func TestVALUES(t *testing.T) { + skipForMariaDB(t) + + valuesTable := VALUES( + ROW(Int32(1), Int32(2), Float(4.666), Bool(false), String("txt")), + ROW(Int32(11).ADD(Int32(2)), Int32(22), Float(33.222), Bool(true), String("png")), + ROW(Int32(11), Int32(22), Float(33.222), Bool(true), NULL), + ).AS("values_table") + + stmt := SELECT( + valuesTable.AllColumns(), + ).FROM( + valuesTable, + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT values_table.column_0 AS "column_0", + values_table.column_1 AS "column_1", + values_table.column_2 AS "column_2", + values_table.column_3 AS "column_3", + values_table.column_4 AS "column_4" +FROM ( + VALUES ROW(?, ?, ?, ?, ?), + ROW(? + ?, ?, ?, ?, ?), + ROW(?, ?, ?, ?, NULL) + ) AS values_table; +`) + + var dest []struct { + Column0 int + Column1 int + Column2 float32 + Column3 bool + Column4 *string + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "Column0": 1, + "Column1": 2, + "Column2": 4.666, + "Column3": false, + "Column4": "txt" + }, + { + "Column0": 13, + "Column1": 22, + "Column2": 33.222, + "Column3": true, + "Column4": "png" + }, + { + "Column0": 11, + "Column1": 22, + "Column2": 33.222, + "Column3": true, + "Column4": null + } +] +`) +} + +func TestVALUES_Join(t *testing.T) { + skipForMariaDB(t) + + title := StringColumn("title") + releaseYear := IntegerColumn("ReleaseYear") + rentalRate := FloatColumn("rental_rate") + + lastUpdate := Timestamp(2007, time.February, 11, 12, 0, 0) + + films := VALUES( + ROW(String("Chamber Italian"), Int64(117), Int32(2005), Float(5.82), lastUpdate), + ROW(String("Grosse Wonderful"), Int64(49), Int32(2004), Float(6.242), lastUpdate.ADD(INTERVAL(1, HOUR))), + ROW(String("Airport Pollock"), Int64(54), Int32(2001), Float(7.22), NULL), + ROW(String("Bright Encounters"), Int64(73), Int32(2002), Float(8.25), NULL), + ROW(String("Academy Dinosaur"), Int64(83), Int32(2010), Float(9.22), lastUpdate.SUB(INTERVAL(2, MINUTE))), + ).AS("film_values", + title, IntegerColumn("length"), releaseYear, rentalRate, TimestampColumn("last_update")) + + stmt := SELECT( + Film.AllColumns, + films.AllColumns(), + ).FROM( + Film. + INNER_JOIN(films, title.EQ(Film.Title)), + ).WHERE(AND( + Film.ReleaseYear.GT(releaseYear), + Film.RentalRate.LT(rentalRate), + )).ORDER_BY( + title, + ) + + testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + film.description AS "film.description", + film.release_year AS "film.release_year", + film.language_id AS "film.language_id", + film.original_language_id AS "film.original_language_id", + film.rental_duration AS "film.rental_duration", + film.rental_rate AS "film.rental_rate", + film.length AS "film.length", + film.replacement_cost AS "film.replacement_cost", + film.rating AS "film.rating", + film.special_features AS "film.special_features", + film.last_update AS "film.last_update", + film_values.title AS "title", + film_values.length AS "length", + film_values.''ReleaseYear'' AS "ReleaseYear", + film_values.rental_rate AS "rental_rate", + film_values.last_update AS "last_update" +FROM dvds.film + INNER JOIN ( + VALUES ROW('Chamber Italian', 117, 2005, 5.82, TIMESTAMP('2007-02-11 12:00:00')), + ROW('Grosse Wonderful', 49, 2004, 6.242, TIMESTAMP('2007-02-11 12:00:00') + INTERVAL 1 HOUR), + ROW('Airport Pollock', 54, 2001, 7.22, NULL), + ROW('Bright Encounters', 73, 2002, 8.25, NULL), + ROW('Academy Dinosaur', 83, 2010, 9.22, TIMESTAMP('2007-02-11 12:00:00') - INTERVAL 2 MINUTE) + ) AS film_values (title, length, ''ReleaseYear'', rental_rate, last_update) ON (film_values.title = film.title) +WHERE ( + (film.release_year > film_values.''ReleaseYear'') + AND (film.rental_rate < film_values.rental_rate) + ) +ORDER BY film_values.title; +`, "''", "`")) + + var dest []struct { + Film model.Film + + Title string + Length int + ReleaseYear int + RentalRate float32 + LastUpdate *time.Time + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + require.Len(t, dest, 4) + testutils.AssertJSON(t, dest[0:2], ` +[ + { + "Film": { + "FilmID": 8, + "Title": "AIRPORT POLLOCK", + "Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India", + "ReleaseYear": 2006, + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 6, + "RentalRate": 4.99, + "Length": 54, + "ReplacementCost": 15.99, + "Rating": "R", + "SpecialFeatures": "Trailers", + "LastUpdate": "2006-02-15T05:03:42Z" + }, + "Title": "Airport Pollock", + "Length": 54, + "ReleaseYear": 2001, + "RentalRate": 7.22, + "LastUpdate": null + }, + { + "Film": { + "FilmID": 98, + "Title": "BRIGHT ENCOUNTERS", + "Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat", + "ReleaseYear": 2006, + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 4, + "RentalRate": 4.99, + "Length": 73, + "ReplacementCost": 12.99, + "Rating": "PG-13", + "SpecialFeatures": "Trailers", + "LastUpdate": "2006-02-15T05:03:42Z" + }, + "Title": "Bright Encounters", + "Length": 73, + "ReleaseYear": 2002, + "RentalRate": 8.25, + "LastUpdate": null + } +] +`) +} + +func TestVALUES_CTE_Update(t *testing.T) { + skipForMariaDB(t) + + paymentID := IntegerColumn("payment_id") + increase := FloatColumn("increase") + paymentsToUpdate := CTE("values_cte", paymentID, increase) + + stmt := WITH( + paymentsToUpdate.AS( + VALUES( + ROW(Int32(204), Float(1.21)), + ROW(Int32(207), Float(1.02)), + ROW(Int32(200), Float(1.34)), + ROW(Int32(203), Float(1.72)), + ), + ), + )( + Payment.INNER_JOIN(paymentsToUpdate, paymentID.EQ(Payment.PaymentID)). + UPDATE(). + SET( + Payment.Amount.SET(Payment.Amount.MUL(increase)), + ).WHERE(Bool(true)), + ) + + testutils.AssertStatementSql(t, stmt, ` +WITH values_cte (payment_id, increase) AS ( + VALUES ROW(?, ?), + ROW(?, ?), + ROW(?, ?), + ROW(?, ?) +) +UPDATE dvds.payment +INNER JOIN values_cte ON (values_cte.payment_id = payment.payment_id) +SET amount = (payment.amount * values_cte.increase) +WHERE ?; +`) + + testutils.AssertExecAndRollback(t, stmt, db, 4) +} + +func TestVALUES_MariaDB(t *testing.T) { + onlyMariaDB(t) // mariadb won't accept values rows if all the elements are placeholders, so we have to use raw statement + + paymentID := IntegerColumn("payment_id") + increase := FloatColumn("increase") + paymentsToUpdate := CTE("values_cte", paymentID, increase) + + stmt := WITH( + paymentsToUpdate.AS( + RawStatement(` + VALUES (204, 1.21), + (207, 1.02), + (200, 1.34), + (203, 1.72) + `), + ), + )( + SELECT( + Payment.AllColumns, + paymentsToUpdate.AllColumns(), + ).FROM( + Payment. + INNER_JOIN(paymentsToUpdate, paymentID.EQ(Payment.PaymentID)), + ).WHERE( + increase.GT(Float(1.03)), + ).ORDER_BY( + increase, + ), + ) + + testutils.AssertStatementSql(t, stmt, ` +WITH values_cte (payment_id, increase) AS ( + VALUES (204, 1.21), + (207, 1.02), + (200, 1.34), + (203, 1.72) + +) +SELECT payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + payment.last_update AS "payment.last_update", + values_cte.payment_id AS "payment_id", + values_cte.increase AS "increase" +FROM dvds.payment + INNER JOIN values_cte ON (values_cte.payment_id = payment.payment_id) +WHERE values_cte.increase > ? +ORDER BY values_cte.increase; +`) + + var dest []struct { + model.Payment + + Increase float64 + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "PaymentID": 204, + "CustomerID": 7, + "StaffID": 1, + "RentalID": 13476, + "Amount": 2.99, + "PaymentDate": "2005-08-20T01:06:04Z", + "LastUpdate": "2006-02-15T22:12:31Z", + "Increase": 1.21 + }, + { + "PaymentID": 200, + "CustomerID": 7, + "StaffID": 2, + "RentalID": 11542, + "Amount": 7.99, + "PaymentDate": "2005-08-17T00:51:32Z", + "LastUpdate": "2006-02-15T22:12:31Z", + "Increase": 1.34 + }, + { + "PaymentID": 203, + "CustomerID": 7, + "StaffID": 2, + "RentalID": 13373, + "Amount": 2.99, + "PaymentDate": "2005-08-19T21:23:31Z", + "LastUpdate": "2006-02-15T22:12:31Z", + "Increase": 1.72 + } +] +`) +} diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 07350b8..b899ca2 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -4,7 +4,7 @@ import ( "database/sql" "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/stretchr/testify/assert" - "log/slog" + "testing" "time" @@ -347,8 +347,6 @@ func TestExpressionOperators(t *testing.T) { AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"), ).LIMIT(2) - // fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, ` SELECT all_types.integer IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", @@ -376,9 +374,6 @@ LIMIT $11; err := query.Query(db, &dest) require.NoError(t, err) - - //testutils.PrintJson(dest) - testutils.AssertJSON(t, dest, ` [ { @@ -936,7 +931,7 @@ func TestTimeExpression(t *testing.T) { func TestIntervalSetFunctionality(t *testing.T) { t.Run("updateQueryIntervalTest", func(t *testing.T) { - slog.Info("Running test", slog.Any("test", t.Name())) + expectedQuery := ` UPDATE test_sample.employee SET pto_accrual = INTERVAL '3 HOUR' diff --git a/tests/postgres/northwind_test.go b/tests/postgres/northwind_test.go index 2d6784d..7aad0d4 100644 --- a/tests/postgres/northwind_test.go +++ b/tests/postgres/northwind_test.go @@ -2,6 +2,7 @@ package postgres import ( "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/northwind/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/northwind/table" "github.com/stretchr/testify/require" @@ -10,19 +11,7 @@ import ( func TestNorthwindJoinEverything(t *testing.T) { - stmt := Customers. - LEFT_JOIN(CustomerCustomerDemo, Customers.CustomerID.EQ(CustomerCustomerDemo.CustomerID)). - LEFT_JOIN(CustomerDemographics, CustomerCustomerDemo.CustomerTypeID.EQ(CustomerDemographics.CustomerTypeID)). - LEFT_JOIN(Orders, Orders.CustomerID.EQ(Customers.CustomerID)). - LEFT_JOIN(Shippers, Orders.ShipVia.EQ(Shippers.ShipperID)). - LEFT_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID)). - LEFT_JOIN(Products, OrderDetails.ProductID.EQ(Products.ProductID)). - LEFT_JOIN(Categories, Products.CategoryID.EQ(Categories.CategoryID)). - LEFT_JOIN(Suppliers, Products.SupplierID.EQ(Suppliers.SupplierID)). - LEFT_JOIN(Employees, Orders.EmployeeID.EQ(Employees.EmployeeID)). - LEFT_JOIN(EmployeeTerritories, EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID)). - LEFT_JOIN(Territories, EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)). - LEFT_JOIN(Region, Territories.RegionID.EQ(Region.RegionID)). + stmt := SELECT( Customers.AllColumns, CustomerDemographics.AllColumns, @@ -32,8 +21,21 @@ func TestNorthwindJoinEverything(t *testing.T) { Products.AllColumns, Categories.AllColumns, Suppliers.AllColumns, - ). - ORDER_BY(Customers.CustomerID, Orders.OrderID, Products.ProductID) + ).FROM( + Customers. + LEFT_JOIN(CustomerCustomerDemo, Customers.CustomerID.EQ(CustomerCustomerDemo.CustomerID)). + LEFT_JOIN(CustomerDemographics, CustomerCustomerDemo.CustomerTypeID.EQ(CustomerDemographics.CustomerTypeID)). + LEFT_JOIN(Orders, Orders.CustomerID.EQ(Customers.CustomerID)). + LEFT_JOIN(Shippers, Orders.ShipVia.EQ(Shippers.ShipperID)). + LEFT_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID)). + LEFT_JOIN(Products, OrderDetails.ProductID.EQ(Products.ProductID)). + LEFT_JOIN(Categories, Products.CategoryID.EQ(Categories.CategoryID)). + LEFT_JOIN(Suppliers, Products.SupplierID.EQ(Suppliers.SupplierID)). + LEFT_JOIN(Employees, Orders.EmployeeID.EQ(Employees.EmployeeID)). + LEFT_JOIN(EmployeeTerritories, EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID)). + LEFT_JOIN(Territories, EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)). + LEFT_JOIN(Region, Territories.RegionID.EQ(Region.RegionID)), + ).ORDER_BY(Customers.CustomerID, Orders.OrderID, Products.ProductID) var dest []struct { model.Customers diff --git a/tests/postgres/values_test.go b/tests/postgres/values_test.go new file mode 100644 index 0000000..9e89f7e --- /dev/null +++ b/tests/postgres/values_test.go @@ -0,0 +1,284 @@ +package postgres + +import ( + "database/sql" + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func TestVALUES(t *testing.T) { + + values := VALUES( + WRAP(Int32(1), Int32(2), Float32(4.666), Bool(false), String("txt")), + WRAP(Int32(11).ADD(Int32(2)), Int32(22), Float32(33.222), Bool(true), String("png")), + WRAP(Int32(11), Int32(22), Float32(33.222), Bool(true), NULL), + ).AS("values_table") + + stmt := SELECT( + values.AllColumns(), + ).FROM( + values, + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT values_table.column1 AS "column1", + values_table.column2 AS "column2", + values_table.column3 AS "column3", + values_table.column4 AS "column4", + values_table.column5 AS "column5" +FROM ( + VALUES ($1::integer, $2::integer, $3::real, $4::boolean, $5::text), + ($6::integer + $7::integer, $8::integer, $9::real, $10::boolean, $11::text), + ($12::integer, $13::integer, $14::real, $15::boolean, NULL) + ) AS values_table; +`) + + var dest []struct { + Column1 int + Column2 int + Column3 float32 + Column4 bool + Column5 *string + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "Column1": 1, + "Column2": 2, + "Column3": 4.666, + "Column4": false, + "Column5": "txt" + }, + { + "Column1": 13, + "Column2": 22, + "Column3": 33.222, + "Column4": true, + "Column5": "png" + }, + { + "Column1": 11, + "Column2": 22, + "Column3": 33.222, + "Column4": true, + "Column5": null + } +] +`) +} + +func TestVALUES_Join(t *testing.T) { + + title := StringColumn("title") + releaseYear := IntegerColumn("ReleaseYear") + rentalRate := FloatColumn("rental_rate") + + lastUpdate := Timestamp(2007, time.February, 11, 12, 0, 0) + + filmValues := VALUES( + WRAP(String("Chamber Italian"), Int64(117), Int32(2005), Float32(5.82), lastUpdate), + WRAP(String("Grosse Wonderful"), Int64(49), Int32(2004), Float32(6.242), lastUpdate.ADD(INTERVAL(1, HOUR))), + WRAP(String("Airport Pollock"), Int64(54), Int32(2001), Float32(7.22), NULL), + WRAP(String("Bright Encounters"), Int64(73), Int32(2002), Float32(8.25), NULL), + WRAP(String("Academy Dinosaur"), Int64(83), Int32(2010), Float32(9.22), lastUpdate.SUB(INTERVAL(2, MINUTE))), + ).AS("film_values", + title, IntegerColumn("length"), releaseYear, rentalRate, TimestampColumn("update_date")) + + stmt := SELECT( + Film.AllColumns, + filmValues.AllColumns(), + ).FROM( + Film. + INNER_JOIN(filmValues, title.EQ(Film.Title)), + ).WHERE(AND( + Film.ReleaseYear.GT(releaseYear), + Film.RentalRate.LT(rentalRate), + )).ORDER_BY( + title, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + film.description AS "film.description", + film.release_year AS "film.release_year", + film.language_id AS "film.language_id", + film.rental_duration AS "film.rental_duration", + film.rental_rate AS "film.rental_rate", + film.length AS "film.length", + film.replacement_cost AS "film.replacement_cost", + film.rating AS "film.rating", + film.last_update AS "film.last_update", + film.special_features AS "film.special_features", + film.fulltext AS "film.fulltext", + film_values.title AS "title", + film_values.length AS "length", + film_values."ReleaseYear" AS "ReleaseYear", + film_values.rental_rate AS "rental_rate", + film_values.update_date AS "update_date" +FROM dvds.film + INNER JOIN ( + VALUES ('Chamber Italian'::text, 117::bigint, 2005::integer, 5.820000171661377::real, '2007-02-11 12:00:00'::timestamp without time zone), + ('Grosse Wonderful'::text, 49::bigint, 2004::integer, 6.242000102996826::real, '2007-02-11 12:00:00'::timestamp without time zone + INTERVAL '1 HOUR'), + ('Airport Pollock'::text, 54::bigint, 2001::integer, 7.21999979019165::real, NULL), + ('Bright Encounters'::text, 73::bigint, 2002::integer, 8.25::real, NULL), + ('Academy Dinosaur'::text, 83::bigint, 2010::integer, 9.220000267028809::real, '2007-02-11 12:00:00'::timestamp without time zone - INTERVAL '2 MINUTE') + ) AS film_values (title, length, "ReleaseYear", rental_rate, update_date) ON (film_values.title = film.title) +WHERE ( + (film.release_year > film_values."ReleaseYear") + AND (film.rental_rate < film_values.rental_rate) + ) +ORDER BY film_values.title; +`) + + //fmt.Println(stmt.DebugSql()) + + var dest []struct { + Film model.Film + + Title string + Length int + ReleaseYear int + RentalRate float32 + UpdateDate *time.Time + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + + assert.Len(t, dest, 4) + testutils.AssertJSON(t, dest[0:2], ` +[ + { + "Film": { + "FilmID": 8, + "Title": "Airport Pollock", + "Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India", + "ReleaseYear": 2006, + "LanguageID": 1, + "RentalDuration": 6, + "RentalRate": 4.99, + "Length": 54, + "ReplacementCost": 15.99, + "Rating": "R", + "LastUpdate": "2013-05-26T14:50:58.951Z", + "SpecialFeatures": "{Trailers}", + "Fulltext": "'airport':1 'ancient':18 'confront':14 'epic':4 'girl':11 'india':19 'monkey':16 'moos':8 'must':13 'pollock':2 'tale':5" + }, + "Title": "Airport Pollock", + "Length": 54, + "ReleaseYear": 2001, + "RentalRate": 7.22, + "UpdateDate": null + }, + { + "Film": { + "FilmID": 98, + "Title": "Bright Encounters", + "Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat", + "ReleaseYear": 2006, + "LanguageID": 1, + "RentalDuration": 4, + "RentalRate": 4.99, + "Length": 73, + "ReplacementCost": 12.99, + "Rating": "PG-13", + "LastUpdate": "2013-05-26T14:50:58.951Z", + "SpecialFeatures": "{Trailers}", + "Fulltext": "'boat':20 'bright':1 'conquer':14 'encount':2 'fate':4 'feminist':11 'jet':19 'lumberjack':8 'must':13 'student':16 'yarn':5" + }, + "Title": "Bright Encounters", + "Length": 73, + "ReleaseYear": 2002, + "RentalRate": 8.25, + "UpdateDate": null + } +] +`) +} + +func TestVALUES_CTE_Update(t *testing.T) { + + paymentID := IntegerColumn("payment_ID") + increase := FloatColumn("increase") + paymentsToUpdate := CTE("values_cte", paymentID, increase) + + stmt := WITH( + paymentsToUpdate.AS( + VALUES( + WRAP(Int32(20564), Float32(1.21)), + WRAP(Int32(20567), Float32(1.02)), + WRAP(Int32(20570), Float32(1.34)), + WRAP(Int32(20573), Float32(1.72)), + ), + ), + )( + Payment.UPDATE(). + SET( + Payment.Amount.SET(Payment.Amount.MUL(CAST(increase).AS_DECIMAL())), + ). + FROM(paymentsToUpdate). + WHERE(Payment.PaymentID.EQ(paymentID)). + RETURNING(Payment.AllColumns), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +WITH values_cte ("payment_ID", increase) AS ( + VALUES (20564::integer, 1.2100000381469727::real), + (20567::integer, 1.0199999809265137::real), + (20570::integer, 1.340000033378601::real), + (20573::integer, 1.7200000286102295::real) +) +UPDATE dvds.payment +SET amount = (payment.amount * values_cte.increase::decimal) +FROM values_cte +WHERE payment.payment_id = values_cte."payment_ID" +RETURNING payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date"; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + + var payments []model.Payment + + err := stmt.Query(tx, &payments) + require.NoError(t, err) + + assert.Len(t, payments, 4) + testutils.AssertJSON(t, payments[0:2], ` +[ + { + "PaymentID": 20564, + "CustomerID": 379, + "StaffID": 2, + "RentalID": 11457, + "Amount": 4.83, + "PaymentDate": "2007-03-02T19:42:42.996577Z" + }, + { + "PaymentID": 20567, + "CustomerID": 379, + "StaffID": 2, + "RentalID": 13397, + "Amount": 8.15, + "PaymentDate": "2007-03-19T20:35:01.996577Z" + } +] +`) + }) + +} diff --git a/tests/sqlite/update_test.go b/tests/sqlite/update_test.go index 110c659..e5c5245 100644 --- a/tests/sqlite/update_test.go +++ b/tests/sqlite/update_test.go @@ -283,7 +283,7 @@ func TestUpdateContextDeadlineExceeded(t *testing.T) { time.Sleep(10 * time.Millisecond) - dest := []model.Link{} + var dest []model.Link err := updateStmt.QueryContext(ctx, tx, &dest) require.Error(t, err, "context deadline exceeded") diff --git a/tests/sqlite/values_test.go b/tests/sqlite/values_test.go new file mode 100644 index 0000000..0793397 --- /dev/null +++ b/tests/sqlite/values_test.go @@ -0,0 +1,344 @@ +package sqlite + +import ( + "database/sql" + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/stretchr/testify/require" + "strings" + "testing" + "time" + + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table" +) + +func TestVALUES(t *testing.T) { + + values := VALUES( + ROW(Int32(1), Int32(2), Float(4.666), Bool(false), String("txt")), + ROW(Int32(11).ADD(Int32(2)), Int32(22), Float(33.222), Bool(true), String("png")), + ROW(Int32(11), Int32(22), Float(33.222), Bool(true), NULL), + ).AS("values_table") + + stmt := SELECT( + values.AllColumns(), + ).FROM( + values, + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT values_table.column1 AS "column1", + values_table.column2 AS "column2", + values_table.column3 AS "column3", + values_table.column4 AS "column4", + values_table.column5 AS "column5" +FROM ( + VALUES (?, ?, ?, ?, ?), + (? + ?, ?, ?, ?, ?), + (?, ?, ?, ?, NULL) + ) AS values_table; +`) + + var dest []struct { + Column1 int + Column2 int + Column3 float32 + Column4 bool + Column5 *string + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "Column1": 1, + "Column2": 2, + "Column3": 4.666, + "Column4": false, + "Column5": "txt" + }, + { + "Column1": 13, + "Column2": 22, + "Column3": 33.222, + "Column4": true, + "Column5": "png" + }, + { + "Column1": 11, + "Column2": 22, + "Column3": 33.222, + "Column4": true, + "Column5": null + } +] +`) +} + +func TestVALUES_Join(t *testing.T) { + + lastUpdate := DateTime(2007, time.February, 11, 12, 0, 0) + + films := VALUES( + ROW(String("Chamber Italian"), Int64(117), Int32(2005), Float(5.82), lastUpdate), + ROW(String("Grosse Wonderful"), Int64(49), Int32(2004), Float(6.242), lastUpdate), + ROW(String("Airport Pollock"), Int64(54), Int32(2001), Float(7.22), NULL), + ROW(String("Bright Encounters"), Int64(73), Int32(2002), Float(8.25), NULL), + ROW(String("Academy Dinosaur"), Int64(83), Int32(2010), Float(9.22), DATETIME(lastUpdate, YEARS(2))), + ).AS("film_values") + + title := StringColumn("column1").From(films) + releaseYear := IntegerColumn("column3").From(films) + rentalRate := FloatColumn("column4").From(films) + + stmt := SELECT( + Film.AllColumns, + films.AllColumns(), + ).FROM( + Film. + INNER_JOIN(films, LOWER(title).EQ(LOWER(Film.Title))), + ).WHERE(AND( + CAST(Film.ReleaseYear).AS_INTEGER().GT(releaseYear), + Film.RentalRate.LT(rentalRate), + )).ORDER_BY( + title, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + film.description AS "film.description", + film.release_year AS "film.release_year", + film.language_id AS "film.language_id", + film.original_language_id AS "film.original_language_id", + film.rental_duration AS "film.rental_duration", + film.rental_rate AS "film.rental_rate", + film.length AS "film.length", + film.replacement_cost AS "film.replacement_cost", + film.rating AS "film.rating", + film.special_features AS "film.special_features", + film.last_update AS "film.last_update", + film_values.column1 AS "column1", + film_values.column2 AS "column2", + film_values.column3 AS "column3", + film_values.column4 AS "column4", + film_values.column5 AS "column5" +FROM film + INNER JOIN ( + VALUES ('Chamber Italian', 117, 2005, 5.82, DATETIME('2007-02-11 12:00:00')), + ('Grosse Wonderful', 49, 2004, 6.242, DATETIME('2007-02-11 12:00:00')), + ('Airport Pollock', 54, 2001, 7.22, NULL), + ('Bright Encounters', 73, 2002, 8.25, NULL), + ('Academy Dinosaur', 83, 2010, 9.22, DATETIME(DATETIME('2007-02-11 12:00:00'), '2 YEARS')) + ) AS film_values ON (LOWER(film_values.column1) = LOWER(film.title)) +WHERE ( + (CAST(film.release_year AS INTEGER) > film_values.column3) + AND (film.rental_rate < film_values.column4) + ) +ORDER BY film_values.column1; +`) + + var dest []struct { + Film model.Film + + Column1 string + Column2 int + Column3 int + Column4 float32 + Column5 *time.Time + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "Film": { + "FilmID": 8, + "Title": "AIRPORT POLLOCK", + "Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India", + "ReleaseYear": "2006", + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 6, + "RentalRate": 4.99, + "Length": 54, + "ReplacementCost": 15.99, + "Rating": "R", + "SpecialFeatures": "Trailers", + "LastUpdate": "2019-04-11T18:11:48Z" + }, + "Column1": "Airport Pollock", + "Column2": 54, + "Column3": 2001, + "Column4": 7.22, + "Column5": null + }, + { + "Film": { + "FilmID": 98, + "Title": "BRIGHT ENCOUNTERS", + "Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat", + "ReleaseYear": "2006", + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 4, + "RentalRate": 4.99, + "Length": 73, + "ReplacementCost": 12.99, + "Rating": "PG-13", + "SpecialFeatures": "Trailers", + "LastUpdate": "2019-04-11T18:11:48Z" + }, + "Column1": "Bright Encounters", + "Column2": 73, + "Column3": 2002, + "Column4": 8.25, + "Column5": null + }, + { + "Film": { + "FilmID": 133, + "Title": "CHAMBER ITALIAN", + "Description": "A Fateful Reflection of a Moose And a Husband who must Overcome a Monkey in Nigeria", + "ReleaseYear": "2006", + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 7, + "RentalRate": 4.99, + "Length": 117, + "ReplacementCost": 14.99, + "Rating": "NC-17", + "SpecialFeatures": "Trailers", + "LastUpdate": "2019-04-11T18:11:48Z" + }, + "Column1": "Chamber Italian", + "Column2": 117, + "Column3": 2005, + "Column4": 5.82, + "Column5": "2007-02-11T12:00:00Z" + }, + { + "Film": { + "FilmID": 384, + "Title": "GROSSE WONDERFUL", + "Description": "A Epic Drama of a Cat And a Explorer who must Redeem a Moose in Australia", + "ReleaseYear": "2006", + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 5, + "RentalRate": 4.99, + "Length": 49, + "ReplacementCost": 19.99, + "Rating": "R", + "SpecialFeatures": "Behind the Scenes", + "LastUpdate": "2019-04-11T18:11:48Z" + }, + "Column1": "Grosse Wonderful", + "Column2": 49, + "Column3": 2004, + "Column4": 6.242, + "Column5": "2007-02-11T12:00:00Z" + } +] +`) +} + +func TestVALUES_CTE_Update(t *testing.T) { + + paymentID := IntegerColumn("payment_ID") + increase := FloatColumn("increase") + paymentsToUpdate := CTE("values_cte", paymentID, increase) + + stmt := WITH( + paymentsToUpdate.AS( + VALUES( + ROW(Int32(204), Float(1.21)), + ROW(Int32(207), Float(1.02)), + ROW(Int32(200), Float(1.34)), + ROW(Int32(203), Float(1.72)), + ), + ), + )( + Payment.UPDATE(). + SET( + Payment.Amount.SET(Payment.Amount.MUL(increase)), + ). + FROM(paymentsToUpdate). + WHERE(Payment.PaymentID.EQ(paymentID)). + RETURNING(Payment.AllColumns), + ) + + testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(` +WITH values_cte (''payment_ID'', increase) AS ( + VALUES (?, ?), + (?, ?), + (?, ?), + (?, ?) +) +UPDATE payment +SET amount = (payment.amount * values_cte.increase) +FROM values_cte +WHERE payment.payment_id = values_cte.''payment_ID'' +RETURNING payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + payment.last_update AS "payment.last_update"; +`, "''", "`")) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var payments []model.Payment + + err := stmt.Query(tx, &payments) + + require.NoError(t, err) + testutils.AssertJSON(t, payments, ` +[ + { + "PaymentID": 200, + "CustomerID": 7, + "StaffID": 2, + "RentalID": 11542, + "Amount": 10.706600000000002, + "PaymentDate": "2005-08-17T00:51:32Z", + "LastUpdate": "2019-04-11T18:11:50Z" + }, + { + "PaymentID": 203, + "CustomerID": 7, + "StaffID": 2, + "RentalID": 13373, + "Amount": 5.1428, + "PaymentDate": "2005-08-19T21:23:31Z", + "LastUpdate": "2019-04-11T18:11:50Z" + }, + { + "PaymentID": 204, + "CustomerID": 7, + "StaffID": 1, + "RentalID": 13476, + "Amount": 3.6179, + "PaymentDate": "2005-08-20T01:06:04Z", + "LastUpdate": "2019-04-11T18:11:50Z" + }, + { + "PaymentID": 207, + "CustomerID": 8, + "StaffID": 2, + "RentalID": 866, + "Amount": 7.1298, + "PaymentDate": "2005-05-30T03:43:54Z", + "LastUpdate": "2019-04-11T18:11:50Z" + } +] +`) + }) + +}