diff --git a/.circleci/config.yml b/.circleci/config.yml index f83283c..e21f00a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -8,8 +8,8 @@ jobs: build_and_tests: docker: # specify the version - - image: circleci/golang:1.16 - - image: circleci/postgres:12 + - image: cimg/go:1.21.6 + - image: cimg/postgres:14.10 environment: POSTGRES_USER: jet POSTGRES_PASSWORD: jet @@ -33,8 +33,8 @@ jobs: MYSQL_USER: jet MYSQL_PASSWORD: jet - - image: cockroachdb/cockroach-unstable:v22.1.0-beta.4 - command: ['start-single-node', '--insecure'] + - image: cockroachdb/cockroach-unstable:v23.1.0-rc.2 + command: ['start-single-node', '--accept-sql-without-tls'] environment: COCKROACH_USER: jet COCKROACH_PASSWORD: jet diff --git a/README.md b/README.md index 23b1386..15fd887 100644 --- a/README.md +++ b/README.md @@ -105,25 +105,25 @@ Done ``` Procedure is similar for MySQL, CockroachDB, MariaDB and SQLite. For example: ```sh -jet -source=mysql -dsn="user:pass@tcp(localhost:3306)/dbname" -path=./gen +jet -source=mysql -dsn="user:pass@tcp(localhost:3306)/dbname" -path=./.gen jet -dsn=postgres://user:pass@localhost:26257/jetdb?sslmode=disable -schema=dvds -path=./.gen #cockroachdb -jet -dsn="mariadb://user:pass@tcp(localhost:3306)/dvds" -path=./gen # source flag can be omitted if data source appears in dsn -jet -source=sqlite -dsn="/path/to/sqlite/database/file" -schema=dvds -path=./gen -jet -dsn="file:///path/to/sqlite/database/file" -schema=dvds -path=./gen # sqlite database assumed for 'file' data sources +jet -dsn="mariadb://user:pass@tcp(localhost:3306)/dvds" -path=./.gen # source flag can be omitted if data source appears in dsn +jet -source=sqlite -dsn="/path/to/sqlite/database/file" -schema=dvds -path=./.gen +jet -dsn="file:///path/to/sqlite/database/file" -schema=dvds -path=./.gen # sqlite database assumed for 'file' data sources ``` _*User has to have a permission to read information schema tables._ As command output suggest, Jet will: - connect to postgres database and retrieve information about the _tables_, _views_ and _enums_ of `dvds` schema -- delete everything in schema destination folder - `./gen/jetdb/dvds`, +- delete everything in schema destination folder - `./.gen/jetdb/dvds`, - and finally generate SQL Builder and Model types for each schema table, view and enum. Generated files folder structure will look like this: ```sh -|-- gen # -path -| `-- jetdb # database name -| `-- dvds # schema name +|-- .gen # path +| -- jetdb # database name +| -- dvds # schema name | |-- enum # sql builder package for enums | | |-- mpaa_rating.go | |-- table # sql builder package for tables @@ -131,7 +131,7 @@ Generated files folder structure will look like this: | |-- address.go | |-- category.go | ... -| |-- view # sql builder package for views +| |-- view # sql builder package for views | |-- actor_info.go | |-- film_list.go | ... @@ -156,7 +156,7 @@ import ( . "github.com/go-jet/jet/v2/examples/quick-start/.gen/jetdb/dvds/table" . "github.com/go-jet/jet/v2/postgres" - "github.com/go-jet/jet/v2/examples/quick-start/gen/jetdb/dvds/model" + "github.com/go-jet/jet/v2/examples/quick-start/.gen/jetdb/dvds/model" ) ``` Let's say we want to retrieve the list of all _actors_ that acted in _films_ longer than 180 minutes, _film language_ is 'English' @@ -530,8 +530,8 @@ Automatic scan to arbitrary structure removes a lot of headache and boilerplate ##### Speed of execution -While ORM libraries can introduce significant performance penalties due to number of round-trips to the database, -Jet will always perform better as developers can write complex query and retrieve result with a single database call. +While ORM libraries can introduce significant performance penalties due to number of round-trips to the database(N+1 query problem), +`jet` will always perform better as developers can write complex query and retrieve result with a single database call. Thus handler time lost on latency between server and database can be constant. Handler execution will be proportional only to the query complexity and the number of rows returned from database. @@ -579,5 +579,5 @@ To run the tests, additional dependencies are required: ## License -Copyright 2019-2023 Goran Bjelanovic +Copyright 2019-2024 Goran Bjelanovic Licensed under the Apache License, Version 2.0. diff --git a/generator/metadata/column_meta_data.go b/generator/metadata/column_meta_data.go index 55533f4..5c888a3 100644 --- a/generator/metadata/column_meta_data.go +++ b/generator/metadata/column_meta_data.go @@ -6,7 +6,7 @@ import ( // Column struct type Column struct { - Name string + Name string `sql:"primary_key"` IsPrimaryKey bool IsNullable bool IsGenerated bool @@ -33,6 +33,7 @@ const ( EnumType DataTypeKind = "enum" UserDefinedType DataTypeKind = "user-defined" ArrayType DataTypeKind = "array" + RangeType DataTypeKind = "range" ) // DataType contains information about column data type diff --git a/generator/metadata/table_meta_data.go b/generator/metadata/table_meta_data.go index 2e01c61..df9514e 100644 --- a/generator/metadata/table_meta_data.go +++ b/generator/metadata/table_meta_data.go @@ -2,7 +2,7 @@ package metadata // Table metadata struct type Table struct { - Name string + Name string `sql:"primary_key"` Columns []Column } diff --git a/generator/mysql/mysql_generator.go b/generator/mysql/mysql_generator.go index 635bcbc..7495bec 100644 --- a/generator/mysql/mysql_generator.go +++ b/generator/mysql/mysql_generator.go @@ -12,6 +12,8 @@ import ( mysqldr "github.com/go-sql-driver/mysql" ) +const mysqlMaxConns = 10 + // DBConnection contains MySQL connection details type DBConnection struct { Host string @@ -83,6 +85,9 @@ func openConnection(connectionString string) (*sql.DB, error) { return nil, fmt.Errorf("failed to open mysql connection: %w", err) } + db.SetMaxOpenConns(mysqlMaxConns) + db.SetMaxIdleConns(mysqlMaxConns) + err = db.Ping() if err != nil { return nil, fmt.Errorf("failed to ping database: %w", err) diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index 9eb9257..bd4d593 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -15,60 +15,48 @@ type mySqlQuerySet struct{} func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) ([]metadata.Table, error) { query := ` -SELECT table_name as "table.name" -FROM INFORMATION_SCHEMA.tables -WHERE table_schema = ? and table_type = ? -ORDER BY table_name; -` +SELECT + t.table_name as "table.name", + col.COLUMN_NAME AS "column.Name", + col.IS_NULLABLE = "YES" AS "column.IsNullable", + col.COLUMN_COMMENT AS "column.Comment", + COALESCE(pk.IsPrimaryKey, 0) AS "column.IsPrimaryKey", + IF (col.COLUMN_TYPE = 'tinyint(1)', + 'boolean', + IF (col.DATA_TYPE = 'enum', + CONCAT(col.TABLE_NAME, '_', col.COLUMN_NAME), + col.DATA_TYPE) + ) AS "dataType.Name", + IF (col.DATA_TYPE = 'enum', 'enum', 'base') AS "dataType.Kind", + col.COLUMN_TYPE LIKE '%unsigned%' AS "dataType.IsUnsigned" +FROM INFORMATION_SCHEMA.tables AS t +INNER JOIN + information_schema.columns AS col + ON t.table_schema = col.table_schema AND t.table_name = col.table_name +LEFT JOIN ( + SELECT k.column_name, 1 AS IsPrimaryKey, k.table_name + FROM information_schema.table_constraints t + JOIN information_schema.key_column_usage k USING(constraint_name, table_schema, table_name) + WHERE t.table_schema = ? + AND t.constraint_type = 'PRIMARY KEY' +) AS pk ON col.COLUMN_NAME = pk.column_name AND col.table_name = pk.table_name +WHERE t.table_schema = ? + AND t.table_type = ? +ORDER BY + t.table_name, + col.ordinal_position; + ` + var tables []metadata.Table - _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) + _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, schemaName, tableType}, &tables) if err != nil { - return nil, fmt.Errorf("failed to query %s metadata result: %w", tableType, err) - } - - for i := range tables { - tables[i].Columns, err = m.GetTableColumnsMetaData(db, schemaName, tables[i].Name) - if err != nil { - return nil, fmt.Errorf("failed to get '%s' table columns metadata: %w", tables[i].Name, err) - } + return nil, fmt.Errorf("failed to query column meta data: %w", err) } return tables, nil } -func (m mySqlQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) { - query := ` -SELECT COLUMN_NAME AS "column.Name", - IS_NULLABLE = "YES" AS "column.IsNullable", - columns.COLUMN_COMMENT as "column.Comment", - (EXISTS( - SELECT 1 - FROM information_schema.table_constraints t - JOIN information_schema.key_column_usage k USING(constraint_name,table_schema,table_name) - WHERE table_schema = ? AND table_name = ? AND t.constraint_type='PRIMARY KEY' AND k.column_name = columns.column_name - )) AS "column.IsPrimaryKey", - IF (COLUMN_TYPE = 'tinyint(1)', - 'boolean', - IF (DATA_TYPE='enum', - CONCAT(TABLE_NAME, '_', COLUMN_NAME), - DATA_TYPE) - ) AS "dataType.Name", - IF (DATA_TYPE = 'enum', 'enum', 'base') AS "dataType.Kind", - COLUMN_TYPE LIKE '%unsigned%' AS "dataType.IsUnsigned" -FROM information_schema.columns -WHERE table_schema = ? AND table_name = ? -ORDER BY ordinal_position; -` - var columns []metadata.Column - _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns) - if err != nil { - return nil, fmt.Errorf("failed to query %s column meta data: %w", tableName, err) - } - - return columns, nil -} - func (m mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.Enum, error) { query := ` SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ) as "name", diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index 856173d..2d8835b 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -26,8 +26,25 @@ ORDER BY table_name; return nil, fmt.Errorf("failed to query %s metadata: %w", tableType, err) } + // add materialized views separately, because materialized views are not part of standard information schema + if tableType == metadata.ViewTable { + matViewQuery := ` + select matviewname as "table.name" + from pg_matviews + where schemaname = $1; + ` + var matViews []metadata.Table + + _, err := qrm.Query(context.Background(), db, matViewQuery, []interface{}{schemaName}, &matViews) + if err != nil { + return nil, fmt.Errorf("failed to query materialized view metadata: %w", err) + } + + tables = append(tables, matViews...) + } + for i := range tables { - tables[i].Columns, err = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name) + tables[i].Columns, err = getColumnsMetaData(db, schemaName, tables[i].Name) if err != nil { return nil, fmt.Errorf("failed to query %s columns metadata: %w", tableType, err) } @@ -36,39 +53,39 @@ ORDER BY table_name; return tables, nil } -func (p postgresQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) { +func getColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) { query := ` -WITH primaryKeys AS ( - SELECT column_name - FROM information_schema.key_column_usage AS c - LEFT JOIN information_schema.table_constraints AS t - ON t.constraint_name = c.constraint_name AND - c.table_schema = t.table_schema AND - c.table_name = t.table_name - WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY' -) -SELECT column_name as "column.Name", - is_nullable = 'YES' as "column.isNullable", - is_generated = 'ALWAYS' or is_generated = 'YES' as "column.isGenerated", - (EXISTS(SELECT 1 from primaryKeys as pk where pk.column_name = columns.column_name)) as "column.IsPrimaryKey", - dataType.kind as "dataType.Kind", - (case dataType.Kind when 'base' then data_type else LTRIM(udt_name, '_') end) as "dataType.Name", - FALSE as "dataType.isUnsigned" -FROM information_schema.columns, - LATERAL (select (case data_type - when 'ARRAY' then 'array' - when 'USER-DEFINED' then - case (select t.typtype - from pg_type as t - join pg_namespace as p on p.oid = t.typnamespace - where t.typname = columns.udt_name and p.nspname = $1) - when 'e' then 'enum' - else 'user-defined' - end - else 'base' - end) as Kind) as dataType -where table_schema = $1 and table_name = $2 -order by ordinal_position; +select + attr.attname as "column.Name", + exists( + select 1 + from pg_catalog.pg_index indx + where attr.attrelid = indx.indrelid and attr.attnum = any(indx.indkey) and indx.indisprimary + ) as "column.IsPrimaryKey", + not attr.attnotnull as "column.isNullable", + attr.attgenerated = 's' as "column.isGenerated", + (case tp.typtype + when 'b' then 'base' + when 'd' then 'base' + when 'e' then 'enum' + when 'r' then 'range' + end) as "dataType.Kind", + (case when tp.typtype = 'd' then (select pg_type.typname from pg_catalog.pg_type where pg_type.oid = tp.typbasetype) + when tp.typcategory = 'A' then pg_catalog.format_type(attr.atttypid, attr.atttypmod) + else tp.typname + end) as "dataType.Name", + false as "dataType.isUnsigned" +from pg_catalog.pg_attribute as attr + join pg_catalog.pg_class as cls on cls.oid = attr.attrelid + join pg_catalog.pg_namespace as ns on ns.oid = cls.relnamespace + join pg_catalog.pg_type as tp on tp.oid = attr.atttypid +where + ns.nspname = $1 and + cls.relname = $2 and + not attr.attisdropped and + attr.attnum > 0 +order by + attr.attnum; ` var columns []metadata.Column _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns) diff --git a/generator/sqlite/query_set.go b/generator/sqlite/query_set.go index d1d0bf7..745aae4 100644 --- a/generator/sqlite/query_set.go +++ b/generator/sqlite/query_set.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/utils/semantic" "github.com/go-jet/jet/v2/qrm" "strings" ) @@ -42,16 +43,45 @@ func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTy return tables, nil } +func getTableInfoQuery(db *sql.DB) (string, error) { + var version string + err := db.QueryRow("select sqlite_version();").Scan(&version) + + if err != nil { + return "", fmt.Errorf("failed to get sqlite version: %w", err) + } + + sqliteVersion, err := semantic.VersionFromString(version) + + if err != nil { + return "", fmt.Errorf("can't parse sqlite version: %w", err) + } + + // generated columns were added in version 3.26.0 + if sqliteVersion.Lt(semantic.Version{Major: 3, Minor: 26, Patch: 0}) { + return `select * from pragma_table_info(?);`, nil + } + + return `select * from pragma_table_xinfo(?);`, nil +} + func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) { - query := fmt.Sprintf(`select * from pragma_table_info(?);`) + + tableInfoQuery, err := getTableInfoQuery(db) + + if err != nil { + return nil, err + } + var columnInfos []struct { Name string Type string NotNull int32 Pk int32 + Hidden int32 } - _, err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos) + _, err = qrm.Query(context.Background(), db, tableInfoQuery, []interface{}{tableName}, &columnInfos) if err != nil { return nil, fmt.Errorf("failed to query '%s' column metadata: %w", tableName, err) } @@ -59,12 +89,14 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t var columns []metadata.Column for _, columnInfo := range columnInfos { - columnType := getColumnType(columnInfo.Type) + columnType := strings.TrimSuffix(getColumnType(columnInfo.Type), " GENERATED ALWAYS") + isGenerated := columnInfo.Hidden == 2 || columnInfo.Hidden == 3 // stored or virtual column columns = append(columns, metadata.Column{ Name: columnInfo.Name, IsPrimaryKey: columnInfo.Pk != 0, IsNullable: columnInfo.NotNull != 1, + IsGenerated: isGenerated, DataType: metadata.DataType{ Name: columnType, Kind: metadata.BaseType, diff --git a/generator/template/file_templates.go b/generator/template/file_templates.go index c2dc33b..731b1af 100644 --- a/generator/template/file_templates.go +++ b/generator/template/file_templates.go @@ -24,7 +24,7 @@ import ( "github.com/go-jet/jet/v2/{{dialect.PackageName}}" ) -var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "") +var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "{{tableTemplate.DefaultAlias}}") type {{structImplName}} struct { {{dialect.PackageName}}.Table diff --git a/generator/template/model_template.go b/generator/template/model_template.go index 47f0339..f89ebd1 100644 --- a/generator/template/model_template.go +++ b/generator/template/model_template.go @@ -5,6 +5,7 @@ import ( "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/internal/utils/dbidentifier" "github.com/google/uuid" + "github.com/jackc/pgtype" "path" "reflect" "strings" @@ -320,6 +321,18 @@ func toGoType(column metadata.Column) interface{} { return float64(0.0) case "uuid": return uuid.UUID{} + case "daterange": + return pgtype.Daterange{} + case "tsrange": + return pgtype.Tsrange{} + case "tstzrange": + return pgtype.Tstzrange{} + case "int4range": + return pgtype.Int4range{} + case "int8range": + return pgtype.Int8range{} + case "numrange": + return pgtype.Numrange{} default: fmt.Println("- [Model ] Unsupported sql column '" + column.Name + " " + column.DataType.Name + "', using string instead.") return "" diff --git a/generator/template/sql_builder_template.go b/generator/template/sql_builder_template.go index 4b68f21..fe7fba5 100644 --- a/generator/template/sql_builder_template.go +++ b/generator/template/sql_builder_template.go @@ -59,6 +59,7 @@ type TableSQLBuilder struct { FileName string InstanceName string TypeName string + DefaultAlias string Column func(columnMetaData metadata.Column) TableSQLBuilderColumn } @@ -67,11 +68,14 @@ type ViewSQLBuilder = TableSQLBuilder // DefaultTableSQLBuilder returns default implementation for TableSQLBuilder func DefaultTableSQLBuilder(tableMetaData metadata.Table) TableSQLBuilder { + tableNameGoIdentifier := dbidentifier.ToGoIdentifier(tableMetaData.Name) + return TableSQLBuilder{ Path: "/table", FileName: dbidentifier.ToGoFileName(tableMetaData.Name), - InstanceName: dbidentifier.ToGoIdentifier(tableMetaData.Name), - TypeName: dbidentifier.ToGoIdentifier(tableMetaData.Name) + "Table", + InstanceName: tableNameGoIdentifier, + TypeName: tableNameGoIdentifier + "Table", + DefaultAlias: "", Column: DefaultTableSQLBuilderColumn, } } @@ -112,6 +116,12 @@ func (tb TableSQLBuilder) UseTypeName(name string) TableSQLBuilder { return tb } +// UseDefaultAlias returns new TableSQLBuilder with new default alias set +func (tb TableSQLBuilder) UseDefaultAlias(defaultAlias string) TableSQLBuilder { + tb.DefaultAlias = defaultAlias + return tb +} + // UseColumn returns new TableSQLBuilder with new column template function set func (tb TableSQLBuilder) UseColumn(columnsFunc func(column metadata.Column) TableSQLBuilderColumn) TableSQLBuilder { tb.Column = columnsFunc @@ -134,14 +144,15 @@ func DefaultTableSQLBuilderColumn(columnMetaData metadata.Column) TableSQLBuilde // getSqlBuilderColumnType returns type of jet sql builder column func getSqlBuilderColumnType(columnMetaData metadata.Column) string { - if columnMetaData.DataType.Kind != metadata.BaseType { + if columnMetaData.DataType.Kind != metadata.BaseType && + columnMetaData.DataType.Kind != metadata.RangeType { return "String" } switch strings.ToLower(columnMetaData.DataType.Name) { - case "boolean": + case "boolean", "bool": return "Bool" - case "smallint", "integer", "bigint", + case "smallint", "integer", "bigint", "int2", "int4", "int8", "tinyint", "mediumint", "int", "year": //MySQL return "Integer" case "date": @@ -149,23 +160,35 @@ func getSqlBuilderColumnType(columnMetaData metadata.Column) string { case "timestamp without time zone", "timestamp", "datetime": //MySQL: return "Timestamp" - case "timestamp with time zone": + case "timestamp with time zone", "timestamptz": return "Timestampz" case "time without time zone", "time": //MySQL return "Time" - case "time with time zone": + case "time with time zone", "timetz": return "Timez" case "interval": return "Interval" case "user-defined", "enum", "text", "character", "character varying", "bytea", "uuid", "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", - "char", "varchar", "nvarchar", "binary", "varbinary", + "char", "varchar", "nvarchar", "binary", "varbinary", "bpchar", "varbit", "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL return "String" - case "real", "numeric", "decimal", "double precision", "float", + case "real", "numeric", "decimal", "double precision", "float", "float4", "float8", "double": // MySQL return "Float" + case "daterange": + return "DateRange" + case "tsrange": + return "TimestampRange" + case "tstzrange": + return "TimestampzRange" + case "int4range": + return "Int4Range" + case "int8range": + return "Int8Range" + case "numrange": + return "NumericRange" default: fmt.Println("- [SQL Builder] Unsupported sql column '" + columnMetaData.Name + " " + columnMetaData.DataType.Name + "', using StringColumn instead.") return "String" diff --git a/go.mod b/go.mod index 685bebb..2f60aea 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,15 @@ module github.com/go-jet/jet/v2 -go 1.11 +go 1.21.6 require ( - github.com/go-sql-driver/mysql v1.7.0 - github.com/google/uuid v1.3.0 - github.com/jackc/pgconn v1.14.0 + github.com/go-sql-driver/mysql v1.7.1 + github.com/google/uuid v1.6.0 + github.com/jackc/pgconn v1.14.1 github.com/lib/pq v1.10.8 - github.com/mattn/go-sqlite3 v1.14.16 + github.com/mattn/go-sqlite3 v1.14.17 ) -// test dependencies require ( github.com/google/go-cmp v0.5.9 github.com/jackc/pgx/v4 v4.18.1 @@ -18,5 +17,28 @@ require ( github.com/shopspring/decimal v1.3.1 github.com/stretchr/testify v1.8.2 github.com/volatiletech/null/v8 v8.1.2 + golang.org/x/sync v0.3.0 gopkg.in/guregu/null.v4 v4.0.0 ) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/felixge/fgprof v0.9.3 // indirect + github.com/friendsofgo/errors v0.9.2 // indirect + github.com/gofrs/uuid v4.0.0+incompatible // indirect + github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect + github.com/jackc/chunkreader/v2 v2.0.1 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgproto3/v2 v2.3.2 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgtype v1.14.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/volatiletech/inflect v0.0.1 // indirect + github.com/volatiletech/randomize v0.0.1 // indirect + github.com/volatiletech/strmangle v0.0.1 // indirect + golang.org/x/crypto v0.6.0 // indirect + golang.org/x/text v0.7.0 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index 2c11a02..2044f31 100644 --- a/go.sum +++ b/go.sum @@ -18,8 +18,8 @@ github.com/friendsofgo/errors v0.9.2 h1:X6NYxef4efCBdwI7BgS820zFaN7Cphrmb+Pljdzj github.com/friendsofgo/errors v0.9.2/go.mod h1:yCvFW5AkDIL9qn7suHVLiI/gH228n7PC4Pn44IGoTOI= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= -github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= -github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= +github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= @@ -29,10 +29,9 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y= github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= -github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= @@ -43,8 +42,9 @@ github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsU github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= -github.com/jackc/pgconn v1.14.0 h1:vrbA9Ud87g6JdFWkHTJXppVce58qPIdP7N8y0Ml/A7Q= github.com/jackc/pgconn v1.14.0/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E= +github.com/jackc/pgconn v1.14.1 h1:smbxIaZA08n6YuxEX1sDyjV/qkbtUtkH20qLkR9MUR4= +github.com/jackc/pgconn v1.14.1/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= @@ -53,7 +53,6 @@ github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5W github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= @@ -102,8 +101,8 @@ github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= -github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= @@ -182,6 +181,8 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 1619124..533d223 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -222,15 +222,40 @@ func (l *ClauseLimit) Serialize(statementType StatementType, out *SQLBuilder, op // ClauseOffset struct type ClauseOffset struct { - Count int64 + Count IntegerExpression } // Serialize serializes clause into SQLBuilder func (o *ClauseOffset) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { - if o.Count >= 0 { - out.NewLine() - out.WriteString("OFFSET") - out.insertParametrizedArgument(o.Count) + if is.Nil(o.Count) { + return + } + + out.NewLine() + out.WriteString("OFFSET") + o.Count.serialize(statementType, out, options...) +} + +// ClauseFetch struct +type ClauseFetch struct { + Count IntegerExpression + WithTies bool +} + +// Serialize serializes ClauseFetch into sql builder output +func (o *ClauseFetch) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { + if is.Nil(o.Count) { + return + } + + out.NewLine() + out.WriteString("FETCH FIRST") + o.Count.serialize(statementType, out, options...) + + if o.WithTies { + out.WriteString("ROWS WITH TIES") + } else { + out.WriteString("ROWS ONLY") } } diff --git a/internal/jet/column_types.go b/internal/jet/column_types.go index 4748b2c..a732061 100644 --- a/internal/jet/column_types.go +++ b/internal/jet/column_types.go @@ -358,3 +358,44 @@ func DateColumn(name string) ColumnDate { dateColumn.ColumnExpressionImpl = NewColumnImpl(name, "", dateColumn) return dateColumn } + +//------------------------------------------------------// + +// ColumnRange is interface for range columns which can be int range, string range +// timestamp range or date range. +type ColumnRange[T Expression] interface { + Range[T] + Column + + From(subQuery SelectTable) ColumnRange[T] + SET(rangeExp Range[T]) ColumnAssigment +} + +type rangeColumnImpl[T Expression] struct { + rangeInterfaceImpl[T] + ColumnExpressionImpl +} + +func (i *rangeColumnImpl[T]) From(subQuery SelectTable) ColumnRange[T] { + newRangeColumn := RangeColumn[T](i.name) + newRangeColumn.setTableName(i.tableName) + newRangeColumn.setSubQuery(subQuery) + + return newRangeColumn +} + +func (i *rangeColumnImpl[T]) SET(rangeExp Range[T]) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: rangeExp, + } +} + +// RangeColumn creates named range column. +func RangeColumn[T Expression](name string) ColumnRange[T] { + rangeColumn := &rangeColumnImpl[T]{} + rangeColumn.rangeInterfaceImpl.parent = rangeColumn + rangeColumn.ColumnExpressionImpl = NewColumnImpl(name, "", rangeColumn) + + return rangeColumn +} diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index 434e3b8..f3ad2b4 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -12,6 +12,7 @@ type Dialect interface { IdentifierQuoteChar() byte ArgumentPlaceholder() QueryPlaceholderFunc IsReservedWord(name string) bool + SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc } // SerializerFunc func @@ -33,6 +34,7 @@ type DialectParams struct { IdentifierQuoteChar byte ArgumentPlaceholder QueryPlaceholderFunc ReservedWords []string + SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc } // NewDialect creates new dialect with params @@ -46,6 +48,7 @@ func NewDialect(params DialectParams) Dialect { identifierQuoteChar: params.IdentifierQuoteChar, argumentPlaceholder: params.ArgumentPlaceholder, reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords), + serializeOrderBy: params.SerializeOrderBy, } } @@ -58,8 +61,7 @@ type dialectImpl struct { identifierQuoteChar byte argumentPlaceholder QueryPlaceholderFunc reservedWords map[string]bool - - supportsReturning bool + serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc } func (d *dialectImpl) Name() string { @@ -101,6 +103,10 @@ func (d *dialectImpl) IsReservedWord(name string) bool { return isReservedWord } +func (d *dialectImpl) SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc { + return d.serializeOrderBy +} + func arrayOfStringsToMapOfStrings(arr []string) map[string]bool { ret := map[string]bool{} for _, elem := range arr { diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 436b9d6..d62920c 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -64,14 +64,24 @@ func (e *ExpressionInterfaceImpl) AS(alias string) Projection { return newAlias(e.Parent, alias) } -// ASC expression will be used to sort query result in ascending order +// ASC expression will be used to sort a query result in ascending order func (e *ExpressionInterfaceImpl) ASC() OrderByClause { - return newOrderByClause(e.Parent, true) + return newOrderByAscending(e.Parent, true) } -// DESC expression will be used to sort query result in descending order +// DESC expression will be used to sort a query result in descending order func (e *ExpressionInterfaceImpl) DESC() OrderByClause { - return newOrderByClause(e.Parent, false) + return newOrderByAscending(e.Parent, false) +} + +// NULLS_FIRST specifies sort where null values appear before all non-null values +func (e *ExpressionInterfaceImpl) NULLS_FIRST() OrderByClause { + return newOrderByNullsFirst(e.Parent, true) +} + +// NULLS_LAST specifies sort where null values appear after all non-null values +func (e *ExpressionInterfaceImpl) NULLS_LAST() OrderByClause { + return newOrderByNullsFirst(e.Parent, false) } func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SQLBuilder) { diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 4ff3791..c56053f 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -468,6 +468,33 @@ func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType return newBoolFunc("REGEXP_LIKE", stringExp, pattern) } +//----------Range Type Functions ----------------------// + +// LOWER_BOUND returns range expressions lower bound +func LOWER_BOUND[T Expression](rangeExpression Range[T]) T { + return rangeTypeCaster[T](rangeExpression, NewFunc("LOWER", []Expression{rangeExpression}, nil)) +} + +// UPPER_BOUND returns range expressions upper bound +func UPPER_BOUND[T Expression](rangeExpression Range[T]) T { + return rangeTypeCaster[T](rangeExpression, NewFunc("UPPER", []Expression{rangeExpression}, nil)) +} + +func rangeTypeCaster[T Expression](rangeExpression Range[T], exp Expression) T { + var i Expression + switch rangeExpression.(type) { + case Range[DateExpression]: + i = DateExp(exp) + case Range[IntegerExpression]: + i = IntExp(exp) + case Range[TimestampzExpression]: + i = TimestampzExp(exp) + case Range[TimestampExpression]: + i = TimestampExp(exp) + } + return i.(T) +} + //----------Data Type Formatting Functions ----------------------// // TO_CHAR converts expression to string with format @@ -843,3 +870,35 @@ func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc { func Func(name string, expressions ...Expression) Expression { return NewFunc(name, expressions, nil) } + +func NumRange(lowNum, highNum NumericExpression, bounds ...StringExpression) Range[NumericExpression] { + return RangeExp[NumericExpression](NewFunc("numrange", rangeFuncParamCombiner[NumericExpression](lowNum, highNum, bounds...), nil)) +} + +func Int4Range(lowNum, highNum IntegerExpression, bounds ...StringExpression) Range[IntegerExpression] { + return RangeExp[IntegerExpression](NewFunc("int4range", rangeFuncParamCombiner[IntegerExpression](lowNum, highNum, bounds...), nil)) +} + +func Int8Range(lowNum, highNum IntegerExpression, bounds ...StringExpression) Range[IntegerExpression] { + return RangeExp[IntegerExpression](NewFunc("int8range", rangeFuncParamCombiner[IntegerExpression](lowNum, highNum, bounds...), nil)) +} + +func TimestampRange(lowTs, highTs TimestampExpression, bounds ...StringExpression) Range[TimestampExpression] { + return RangeExp[TimestampExpression](NewFunc("tsrange", rangeFuncParamCombiner[TimestampExpression](lowTs, highTs, bounds...), nil)) +} + +func TimestampzRange(lowTs, highTs TimestampzExpression, bounds ...StringExpression) Range[TimestampzExpression] { + return RangeExp[TimestampzExpression](NewFunc("tstzrange", rangeFuncParamCombiner[TimestampzExpression](lowTs, highTs, bounds...), nil)) +} + +func DateRange(lowTs, highTs DateExpression, bounds ...StringExpression) Range[DateExpression] { + return RangeExp[DateExpression](NewFunc("daterange", rangeFuncParamCombiner[DateExpression](lowTs, highTs, bounds...), nil)) +} + +func rangeFuncParamCombiner[T Expression](low, high T, bounds ...StringExpression) []Expression { + exp := []Expression{low, high} + if len(bounds) != 0 { + exp = append(exp, bounds[0]) + } + return exp +} diff --git a/internal/jet/func_expression_test.go b/internal/jet/func_expression_test.go index 048ade2..ca32589 100644 --- a/internal/jet/func_expression_test.go +++ b/internal/jet/func_expression_test.go @@ -202,3 +202,14 @@ func TestTO_ASCII(t *testing.T) { func TestFunc(t *testing.T) { assertClauseSerialize(t, Func("FOO", String("test"), NULL, MAX(Int(1))), "FOO($1, NULL, MAX($2))", "test", int64(1)) } + +func Test_rangePointCaster(t *testing.T) { + mainRange := Int8Range(Int8(10), Int8(12)) + exp := NewFunc("UPPER", []Expression{mainRange}, nil) + + got := rangeTypeCaster(mainRange, exp) + _, ok := got.(IntegerExpression) + if !ok { + t.Errorf("expecting to get IntegerExpression but got %v", got) + } +} diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 8c28b44..d6f0b41 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -333,6 +333,10 @@ var ( NULL = newNullLiteral() // STAR is jet equivalent of SQL * STAR = newStarLiteral() + // PLUS_INFINITY is jet equivalent for sql infinity + PLUS_INFINITY = String("infinity") + // MINUS_INFINITY is jet equivalent for sql -infinity + MINUS_INFINITY = String("-infinity") ) type nullLiteral struct { @@ -490,6 +494,11 @@ func RawDate(raw string, namedArgs ...map[string]interface{}) DateExpression { return DateExp(Raw(raw, namedArgs...)) } +// RawRange helper that for range expressions +func RawRange[T Expression](raw string, namedArgs ...map[string]interface{}) Range[T] { + return RangeExp[T](Raw(raw, namedArgs...)) +} + // UUID is a helper function to create string literal expression from uuid object // value can be any uuid type with a String method func UUID(value fmt.Stringer) StringExpression { diff --git a/internal/jet/operators.go b/internal/jet/operators.go index b73a451..bf1dedf 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -69,6 +69,16 @@ func GtEq(lhs, rhs Expression) BoolExpression { return newBinaryBoolOperatorExpression(lhs, rhs, ">=") } +// Contains returns a representation of "a @> b" +func Contains(lhs Expression, rhs Expression) BoolExpression { + return newBinaryBoolOperatorExpression(lhs, rhs, "@>") +} + +// Overlap returns a representation of "a && b" +func Overlap(lhs, rhs Expression) BoolExpression { + return newBinaryBoolOperatorExpression(lhs, rhs, "&&") +} + // Add notEq returns a representation of "a + b" func Add(lhs, rhs Serializer) Expression { return NewBinaryOperatorExpression(lhs, rhs, "+") diff --git a/internal/jet/order_by_clause.go b/internal/jet/order_by_clause.go index 55fb7d2..09d3ea6 100644 --- a/internal/jet/order_by_clause.go +++ b/internal/jet/order_by_clause.go @@ -2,28 +2,78 @@ package jet // OrderByClause interface type OrderByClause interface { + // NULLS_FIRST specifies sort where null values appear before all non-null values. + // For some dialects(mysql,mariadb), which do not support NULL_FIRST, NULL_FIRST is simulated + // with additional IS_NOT_NULL expression. + // For instance, + // Rental.ReturnDate.DESC().NULLS_FIRST() + // would translate to, + // rental.return_date IS NOT NULL, rental.return_date DESC + NULLS_FIRST() OrderByClause + + // NULLS_LAST specifies sort where null values appear after all non-null values. + // For some dialects(mysql,mariadb), which do not support NULLS_LAST, NULLS_LAST is simulated + // with additional IS_NULL expression. + // For instance, + // Rental.ReturnDate.ASC().NULLS_LAST() + // would translate to, + // rental.return_date IS NULL, rental.return_date ASC + NULLS_LAST() OrderByClause + serializeForOrderBy(statement StatementType, out *SQLBuilder) } type orderByClauseImpl struct { expression Expression - ascent bool + ascending *bool + nullsFirst *bool } -func (o *orderByClauseImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { - if o.expression == nil { +func (ord *orderByClauseImpl) NULLS_FIRST() OrderByClause { + nullsFirst := true + ord.nullsFirst = &nullsFirst + return ord +} +func (ord *orderByClauseImpl) NULLS_LAST() OrderByClause { + nullsFirst := false + ord.nullsFirst = &nullsFirst + return ord +} + +func (ord *orderByClauseImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { + customSerializer := out.Dialect.SerializeOrderBy() + if customSerializer != nil { + customSerializer(ord.expression, ord.ascending, ord.nullsFirst)(statement, out) + return + } + + if ord.expression == nil { panic("jet: nil expression in ORDER BY clause") } - o.expression.serializeForOrderBy(statement, out) + ord.expression.serializeForOrderBy(statement, out) - if o.ascent { - out.WriteString("ASC") - } else { - out.WriteString("DESC") + if ord.ascending != nil { + if *ord.ascending { + out.WriteString("ASC") + } else { + out.WriteString("DESC") + } + } + + if ord.nullsFirst != nil { + if *ord.nullsFirst { + out.WriteString("NULLS FIRST") + } else { + out.WriteString("NULLS LAST") + } } } -func newOrderByClause(expression Expression, ascent bool) OrderByClause { - return &orderByClauseImpl{expression: expression, ascent: ascent} +func newOrderByAscending(expression Expression, ascending bool) OrderByClause { + return &orderByClauseImpl{expression: expression, ascending: &ascending} +} + +func newOrderByNullsFirst(expression Expression, nullsFirst bool) OrderByClause { + return &orderByClauseImpl{expression: expression, nullsFirst: &nullsFirst} } diff --git a/internal/jet/range_expression.go b/internal/jet/range_expression.go new file mode 100644 index 0000000..f05fdea --- /dev/null +++ b/internal/jet/range_expression.go @@ -0,0 +1,95 @@ +package jet + +// Range Expression is interface for date range types +type Range[T Expression] interface { + Expression + + EQ(rhs Range[T]) BoolExpression + NOT_EQ(rhs Range[T]) BoolExpression + + LT(rhs Range[T]) BoolExpression + LT_EQ(rhs Range[T]) BoolExpression + GT(rhs Range[T]) BoolExpression + GT_EQ(rhs Range[T]) BoolExpression + + CONTAINS(rhs T) BoolExpression + CONTAINS_RANGE(rhs Range[T]) BoolExpression + OVERLAP(rhs Range[T]) BoolExpression + UNION(rhs Range[T]) Range[T] + INTERSECTION(rhs Range[T]) Range[T] + DIFFERENCE(rhs Range[T]) Range[T] +} + +type rangeInterfaceImpl[T Expression] struct { + parent Expression +} + +func (r *rangeInterfaceImpl[T]) EQ(rhs Range[T]) BoolExpression { + return Eq(r.parent, rhs) +} + +func (r *rangeInterfaceImpl[T]) NOT_EQ(rhs Range[T]) BoolExpression { + return NotEq(r.parent, rhs) +} + +func (r *rangeInterfaceImpl[T]) LT(rhs Range[T]) BoolExpression { + return Lt(r.parent, rhs) +} + +func (r *rangeInterfaceImpl[T]) LT_EQ(rhs Range[T]) BoolExpression { + return LtEq(r.parent, rhs) + +} + +func (r *rangeInterfaceImpl[T]) GT(rhs Range[T]) BoolExpression { + return Gt(r.parent, rhs) + +} + +func (r *rangeInterfaceImpl[T]) GT_EQ(rhs Range[T]) BoolExpression { + return GtEq(r.parent, rhs) +} + +func (r *rangeInterfaceImpl[T]) CONTAINS(rhs T) BoolExpression { + return Contains(r.parent, rhs) +} + +func (r *rangeInterfaceImpl[T]) CONTAINS_RANGE(rhs Range[T]) BoolExpression { + return Contains(r.parent, rhs) +} + +func (r *rangeInterfaceImpl[T]) OVERLAP(rhs Range[T]) BoolExpression { + return Overlap(r.parent, rhs) +} + +func (r *rangeInterfaceImpl[T]) UNION(rhs Range[T]) Range[T] { + return RangeExp[T](Add(r.parent, rhs)) +} + +func (r *rangeInterfaceImpl[T]) INTERSECTION(rhs Range[T]) Range[T] { + return RangeExp[T](Mul(r.parent, rhs)) +} + +func (r *rangeInterfaceImpl[T]) DIFFERENCE(rhs Range[T]) Range[T] { + return RangeExp[T](Sub(r.parent, rhs)) +} + +//---------------------------------------------------// + +type rangeExpressionWrapper[T Expression] struct { + rangeInterfaceImpl[T] + Expression +} + +func newRangeExpressionWrap[T Expression](expression Expression) Range[T] { + rangeExpressionWrap := rangeExpressionWrapper[T]{Expression: expression} + rangeExpressionWrap.rangeInterfaceImpl.parent = &rangeExpressionWrap + return &rangeExpressionWrap +} + +// RangeExp is range expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as range expression. +// Does not add sql cast to generated sql builder output. +func RangeExp[T Expression](expression Expression) Range[T] { + return newRangeExpressionWrap[T](expression) +} diff --git a/internal/jet/range_expression_test.go b/internal/jet/range_expression_test.go new file mode 100644 index 0000000..e5e8b75 --- /dev/null +++ b/internal/jet/range_expression_test.go @@ -0,0 +1,63 @@ +package jet + +import "testing" + +func TestRangeExpressionEQ(t *testing.T) { + assertClauseSerialize(t, table1ColRange.EQ(table2ColRange), "(table1.col_range = table2.col_range)") + assertClauseSerialize(t, table1ColRange.EQ(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range = int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionNOT_EQ(t *testing.T) { + assertClauseSerialize(t, table1ColRange.NOT_EQ(table2ColRange), "(table1.col_range != table2.col_range)") + assertClauseSerialize(t, table1ColRange.NOT_EQ(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range != int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionLT(t *testing.T) { + assertClauseSerialize(t, table1ColRange.LT(table2ColRange), "(table1.col_range < table2.col_range)") + assertClauseSerialize(t, table1ColRange.LT(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range < int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionLT_EQ(t *testing.T) { + assertClauseSerialize(t, table1ColRange.LT_EQ(table2ColRange), "(table1.col_range <= table2.col_range)") + assertClauseSerialize(t, table1ColRange.LT_EQ(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range <= int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionGT(t *testing.T) { + assertClauseSerialize(t, table1ColRange.GT(table2ColRange), "(table1.col_range > table2.col_range)") + assertClauseSerialize(t, table1ColRange.GT(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range > int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionGT_EQ(t *testing.T) { + assertClauseSerialize(t, table1ColRange.GT_EQ(table2ColRange), "(table1.col_range >= table2.col_range)") + assertClauseSerialize(t, table1ColRange.GT_EQ(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range >= int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionCONTAINS_RANGE(t *testing.T) { + assertClauseSerialize(t, table1ColRange.CONTAINS_RANGE(table2ColRange), "(table1.col_range @> table2.col_range)") + assertClauseSerialize(t, table1ColRange.CONTAINS_RANGE(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range @> int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionCONTAINS(t *testing.T) { + assertClauseSerialize(t, table1ColRange.CONTAINS(table2Col3), "(table1.col_range @> table2.col3)") + assertClauseSerialize(t, table1ColRange.CONTAINS(Int8(1)), "(table1.col_range @> $1)", int8(1)) +} + +func TestRangeExpressionOVERLAP(t *testing.T) { + assertClauseSerialize(t, table1ColRange.OVERLAP(table2ColRange), "(table1.col_range && table2.col_range)") + assertClauseSerialize(t, table1ColRange.OVERLAP(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range && int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionUNION(t *testing.T) { + assertClauseSerialize(t, table1ColRange.UNION(table2ColRange), "(table1.col_range + table2.col_range)") + assertClauseSerialize(t, table1ColRange.UNION(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range + int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionINTERSECTION(t *testing.T) { + assertClauseSerialize(t, table1ColRange.INTERSECTION(table2ColRange), "(table1.col_range * table2.col_range)") + assertClauseSerialize(t, table1ColRange.INTERSECTION(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range * int8range($1, $2, $3))", int8(1), int8(4), "[)") +} + +func TestRangeExpressionDIFFERENCE(t *testing.T) { + assertClauseSerialize(t, table1ColRange.DIFFERENCE(table2ColRange), "(table1.col_range - table2.col_range)") + assertClauseSerialize(t, table1ColRange.DIFFERENCE(Int8Range(Int8(1), Int8(4), String("[)"))), "(table1.col_range - int8range($1, $2, $3))", int8(1), int8(4), "[)") +} diff --git a/internal/jet/select_lock.go b/internal/jet/select_lock.go index 4694ee3..c9d431a 100644 --- a/internal/jet/select_lock.go +++ b/internal/jet/select_lock.go @@ -4,12 +4,14 @@ package jet type RowLock interface { Serializer + OF(...Table) RowLock NOWAIT() RowLock SKIP_LOCKED() RowLock } type selectLockImpl struct { lockStrength string + of []Table noWait, skipLocked bool } @@ -20,10 +22,15 @@ func NewRowLock(name string) func() RowLock { } } -func newSelectLock(lockStrength string) RowLock { +func newSelectLock(lockStrength string) *selectLockImpl { return &selectLockImpl{lockStrength: lockStrength} } +func (s *selectLockImpl) OF(tables ...Table) RowLock { + s.of = tables + return s +} + func (s *selectLockImpl) NOWAIT() RowLock { s.noWait = true return s @@ -37,6 +44,23 @@ func (s *selectLockImpl) SKIP_LOCKED() RowLock { func (s *selectLockImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString(s.lockStrength) + if len(s.of) > 0 { + out.WriteString("OF") + + for i, of := range s.of { + if i > 0 { + out.WriteString(", ") + } + + table := of.Alias() + if table == "" { + table = of.TableName() + } + + out.WriteIdentifier(table) + } + } + if s.noWait { out.WriteString("NOWAIT") } diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index 93a1d3b..9d36de4 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -44,6 +44,10 @@ func Serialize(exp Serializer, statementType StatementType, out *SQLBuilder, opt exp.serialize(statementType, out, options...) } +func SerializeForOrderBy(exp Expression, statementType StatementType, out *SQLBuilder) { + exp.serializeForOrderBy(statementType, out) +} + func contains(options []SerializeOption, option SerializeOption) bool { for _, opt := range options { if opt == option { diff --git a/internal/jet/testutils.go b/internal/jet/testutils.go index 43f1f5d..d91e30f 100644 --- a/internal/jet/testutils.go +++ b/internal/jet/testutils.go @@ -25,8 +25,9 @@ var ( table1ColTimestampz = TimestampzColumn("col_timestampz") table1ColBool = BoolColumn("col_bool") table1ColDate = DateColumn("col_date") + table1ColRange = RangeColumn[IntegerExpression]("col_range") ) -var table1 = NewTable("db", "table1", "", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTimestampz) +var table1 = NewTable("db", "table1", "", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColRange, table1ColTimestamp, table1ColTimestampz) var ( table2Col3 = IntegerColumn("col3") @@ -40,8 +41,9 @@ var ( table2ColTimestamp = TimestampColumn("col_timestamp") table2ColTimestampz = TimestampzColumn("col_timestampz") table2ColDate = DateColumn("col_date") + table2ColRange = RangeColumn[IntegerExpression]("col_range") ) -var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz) +var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColRange, table2ColTimestamp, table2ColTimestampz) var ( table3Col1 = IntegerColumn("col1") diff --git a/internal/utils/dbidentifier/dbidentifier.go b/internal/utils/dbidentifier/dbidentifier.go index c4278f6..7aee4a9 100644 --- a/internal/utils/dbidentifier/dbidentifier.go +++ b/internal/utils/dbidentifier/dbidentifier.go @@ -3,6 +3,7 @@ package dbidentifier import ( "github.com/go-jet/jet/v2/internal/3rdparty/snaker" "strings" + "unicode" ) // ToGoIdentifier converts database identifier to Go identifier. @@ -15,10 +16,92 @@ func ToGoFileName(databaseIdentifier string) string { return strings.ToLower(replaceInvalidChars(databaseIdentifier)) } -func replaceInvalidChars(str string) string { - str = strings.Replace(str, " ", "_", -1) - str = strings.Replace(str, "-", "_", -1) - str = strings.Replace(str, ".", "_", -1) +func replaceInvalidChars(identifier string) string { + increase, needs := needsCharReplacement(identifier) - return str + if !needs { + return identifier + } + + var b strings.Builder + + b.Grow(len(identifier) + increase) + + for _, c := range identifier { + switch { + case unicode.IsSpace(c): + b.WriteByte('_') + case unicode.IsControl(c): + continue + default: + replacement, ok := asciiCharacterReplacement[c] + + if ok { + b.WriteByte('_') + b.WriteString(replacement) + b.WriteByte('_') + } else { + b.WriteRune(c) + } + } + + } + + return b.String() +} + +func needsCharReplacement(identifier string) (increase int, needs bool) { + for _, c := range identifier { + switch { + case unicode.IsSpace(c): + needs = true + case unicode.IsControl(c): + increase += -1 + needs = true + continue + default: + replacement, ok := asciiCharacterReplacement[c] + + if ok { + increase += len(replacement) + 1 + needs = true + } + } + } + + return increase, needs +} + +var asciiCharacterReplacement = map[rune]string{ + '!': "exclamation", + '"': "quotation", + '#': "number", + '$': "dollar", + '%': "percent", + '&': "ampersand", + '\'': "apostrophe", + '(': "opening_parentheses", + ')': "closing_parentheses", + '*': "asterisk", + '+': "plus", + ',': "comma", + '-': "_", + '.': "_", + '/': "slash", + ':': "colon", + ';': "semicolon", + '<': "less", + '=': "equal", + '>': "greater", + '?': "question", + '@': "at", + '[': "opening_bracket", + '\\': "backslash", + ']': "closing_bracket", + '^': "caret", + '`': "accent", + '{': "opening_braces", + '|': "vertical_bar", + '}': "closing_braces", + '~': "tilde", } diff --git a/internal/utils/dbidentifier/dbidentifier_test.go b/internal/utils/dbidentifier/dbidentifier_test.go index 339fb5a..668bf56 100644 --- a/internal/utils/dbidentifier/dbidentifier_test.go +++ b/internal/utils/dbidentifier/dbidentifier_test.go @@ -8,6 +8,7 @@ import ( func TestToGoIdentifier(t *testing.T) { require.Equal(t, ToGoIdentifier(""), "") require.Equal(t, ToGoIdentifier("uuid"), "UUID") + require.Equal(t, ToGoIdentifier("uuid_ptr"), "UUIDPtr") require.Equal(t, ToGoIdentifier("col1"), "Col1") require.Equal(t, ToGoIdentifier("PG-13"), "Pg13") require.Equal(t, ToGoIdentifier("13_pg"), "13Pg") @@ -18,8 +19,50 @@ func TestToGoIdentifier(t *testing.T) { require.Equal(t, ToGoIdentifier("myTaBlE"), "MyTaBlE") require.Equal(t, ToGoIdentifier("my_table"), "MyTable") + require.Equal(t, ToGoIdentifier("my_____table"), "MyTable") require.Equal(t, ToGoIdentifier("MY_TABLE"), "MyTable") require.Equal(t, ToGoIdentifier("My_Table"), "MyTable") require.Equal(t, ToGoIdentifier("My Table"), "MyTable") require.Equal(t, ToGoIdentifier("My-Table"), "MyTable") + + require.Equal(t, ToGoIdentifier("EN\bUM"), "Enum") // control character + require.Equal(t, ToGoIdentifier("EN\tUM"), "EnUm") // space character + require.Equal(t, ToGoIdentifier("S3:INIT"), "S3ColonInit") // replacement chars + require.Equal(t, ToGoIdentifier("Entity-"), "Entity") + require.Equal(t, ToGoIdentifier("Entity+"), "EntityPlus") + require.Equal(t, ToGoIdentifier("="), "Equal") + require.Equal(t, ToGoIdentifier("<="), "LessEqual") + require.Equal(t, ToGoIdentifier(">="), "GreaterEqual") + require.Equal(t, ToGoIdentifier("some#$%name"), "SomeNumberDollarPercentName") + require.Equal(t, ToGoIdentifier(`An!"them`), "AnExclamationQuotationThem") + require.Equal(t, ToGoIdentifier(`An(Um)`), + "AnOpeningParenthesesUmClosingParentheses") +} + +func TestNeedsCharReplacement(t *testing.T) { + increase, needs := needsCharReplacement("some_name") + require.False(t, needs) + require.Zero(t, increase) + + increase, needs = needsCharReplacement("some name") + require.True(t, needs) + require.Zero(t, increase) + + increase, needs = needsCharReplacement("some\bname") + require.True(t, needs) + require.Equal(t, increase, -1) + + increase, needs = needsCharReplacement("some#$%name") + require.True(t, needs) + require.Equal(t, increase, 22) +} + +func TestToGoFileName(t *testing.T) { + require.Equal(t, ToGoFileName("FileName"), "filename") + require.Equal(t, ToGoFileName("File_Name"), "file_name") + require.Equal(t, ToGoFileName("File___Name__"), "file___name__") + require.Equal(t, ToGoFileName("File___Name__"), "file___name__") + require.Equal(t, ToGoFileName("File\bName"), "filename") + require.Equal(t, ToGoFileName("File\tName"), "file_name") + require.Equal(t, ToGoFileName("File^^Name"), "file_caret__caret_name") } diff --git a/internal/utils/semantic/version.go b/internal/utils/semantic/version.go new file mode 100644 index 0000000..b13413f --- /dev/null +++ b/internal/utils/semantic/version.go @@ -0,0 +1,66 @@ +package semantic + +import ( + "fmt" + "strconv" + "strings" +) + +// Version struct holds semantic versioning information +type Version struct { + Major int + Minor int + Patch int +} + +// VersionFromString creates new semantic Version by parsing version string +func VersionFromString(version string) (Version, error) { + parts := strings.Split(version, ".") + + var ret Version + + if len(parts) > 0 { + major, err := strconv.Atoi(parts[0]) + + if err != nil { + return ret, fmt.Errorf("major is not a number: %w", err) + } + + ret.Major = major + } + + if len(parts) > 1 { + minor, err := strconv.Atoi(parts[1]) + + if err != nil { + return ret, fmt.Errorf("minor is not a number: %w", err) + } + + ret.Minor = minor + } + + if len(parts) > 2 { + patch, err := strconv.Atoi(parts[2]) + + if err != nil { + return ret, fmt.Errorf("patch is not a number: %w", err) + } + + ret.Patch = patch + } + + return ret, nil +} + +// Lt returns true if this version is less than version parameter +func (v Version) Lt(version Version) bool { + if v.Major < version.Major { + return true + } + + if v.Minor < version.Minor { + return true + } + + return v.Patch < version.Patch +} diff --git a/mysql/dialect.go b/mysql/dialect.go index 24b8755..18d2eec 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -26,7 +26,8 @@ func newDialect() jet.Dialect { ArgumentPlaceholder: func(int) string { return "?" }, - ReservedWords: reservedWords, + ReservedWords: reservedWords, + SerializeOrderBy: serializeOrderBy, } return jet.NewDialect(mySQLDialectParams) @@ -162,6 +163,53 @@ func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFun } } +func serializeOrderBy(expression Expression, ascending, nullsFirst *bool) jet.SerializerFunc { + return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + + if nullsFirst == nil { + jet.SerializeForOrderBy(expression, statement, out) + + if ascending != nil { + serializeAscending(*ascending, out) + } + return + } + + asc := true + + if ascending != nil { + asc = *ascending + } + + if asc { + if !*nullsFirst { + jet.SerializeForOrderBy(expression.IS_NULL(), statement, out) + out.WriteString(", ") + } + jet.SerializeForOrderBy(expression, statement, out) + if ascending != nil { + serializeAscending(asc, out) + } + } else { + if *nullsFirst { + jet.SerializeForOrderBy(expression.IS_NOT_NULL(), statement, out) + out.WriteString(", ") + } + + jet.SerializeForOrderBy(expression, statement, out) + serializeAscending(asc, out) + } + } +} + +func serializeAscending(ascending bool, out *jet.SQLBuilder) { + if ascending { + out.WriteString("ASC") + } else { + out.WriteString("DESC") + } +} + var reservedWords = []string{ "ACCESSIBLE", "ADD", diff --git a/mysql/functions.go b/mysql/functions.go index 4eba942..ca31d18 100644 --- a/mysql/functions.go +++ b/mysql/functions.go @@ -222,6 +222,12 @@ var SUBSTR = jet.SUBSTR // REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise. var REGEXP_LIKE = jet.REGEXP_LIKE +// UUID_TO_BIN is a helper function that calls "uuid_to_bin" function on the passed value. +func UUID_TO_BIN(str StringExpression) StringExpression { + fn := Func("uuid_to_bin", str) + return StringExp(fn) +} + //----------------- Date/Time Functions and Operators ------------// // EXTRACT function retrieves subfields such as year or hour from date/time values diff --git a/mysql/functions_test.go b/mysql/functions_test.go new file mode 100644 index 0000000..197b515 --- /dev/null +++ b/mysql/functions_test.go @@ -0,0 +1,11 @@ +package mysql + +import ( + "testing" + + "github.com/google/uuid" +) + +func TestUUIDToBin(t *testing.T) { + assertSerialize(t, UUID_TO_BIN(String(uuid.Nil.String())), `uuid_to_bin(?)`, uuid.Nil.String()) +} diff --git a/mysql/literal.go b/mysql/literal.go index 1c69c31..ca720c8 100644 --- a/mysql/literal.go +++ b/mysql/literal.go @@ -1,8 +1,9 @@ package mysql import ( - "github.com/go-jet/jet/v2/internal/jet" "time" + + "github.com/go-jet/jet/v2/internal/jet" ) // Keywords diff --git a/mysql/select_statement.go b/mysql/select_statement.go index 45c7782..6c5f345 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -86,7 +86,6 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta newSelect.From.Tables = []jet.Serializer{table} } newSelect.Limit.Count = -1 - newSelect.Offset.Count = -1 newSelect.ShareLock.Name = "LOCK IN SHARE MODE" newSelect.ShareLock.InNewLine = true @@ -158,7 +157,7 @@ func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement { } func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement { - s.Offset.Count = offset + s.Offset.Count = Int(offset) return s } diff --git a/mysql/select_statement_test.go b/mysql/select_statement_test.go index bd3a0e9..1136dd7 100644 --- a/mysql/select_statement_test.go +++ b/mysql/select_statement_test.go @@ -122,6 +122,11 @@ FOR UPDATE; SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 FOR SHARE NOWAIT; +`) + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE().OF(table1).NOWAIT()), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +FOR UPDATE OF table1 NOWAIT; `) } diff --git a/mysql/set_statement.go b/mysql/set_statement.go index ec9f8fa..2df75a0 100644 --- a/mysql/set_statement.go +++ b/mysql/set_statement.go @@ -63,7 +63,6 @@ func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStat newSetStatement.setOperator.All = all newSetStatement.setOperator.Selects = selects newSetStatement.setOperator.Limit.Count = -1 - newSetStatement.setOperator.Offset.Count = -1 newSetStatement.setOperatorsImpl.parent = newSetStatement @@ -81,7 +80,7 @@ func (s *setStatementImpl) LIMIT(limit int64) setStatement { } func (s *setStatementImpl) OFFSET(offset int64) setStatement { - s.setOperator.Offset.Count = offset + s.setOperator.Offset.Count = Int(offset) return s } diff --git a/postgres/columns.go b/postgres/columns.go index a25a88c..aee2896 100644 --- a/postgres/columns.go +++ b/postgres/columns.go @@ -65,6 +65,42 @@ type ColumnTimestampz = jet.ColumnTimestampz // TimestampzColumn creates named timestamp with time zone column. var TimestampzColumn = jet.TimestampzColumn +// ColumnDateRange is interface of SQL date range column +type ColumnDateRange = jet.ColumnRange[DateExpression] + +// DateRangeColumn creates named range with range column +var DateRangeColumn = jet.RangeColumn[DateExpression] + +// ColumnNumericRange is interface of SQL numeric range column +type ColumnNumericRange = jet.ColumnRange[NumericExpression] + +// NumericRangeColumn creates named range with range column +var NumericRangeColumn = jet.RangeColumn[NumericExpression] + +// ColumnTimestampRange is interface of SQL timestamp range column +type ColumnTimestampRange = jet.ColumnRange[TimestampExpression] + +// TimestampRangeColumn creates named range with range column +var TimestampRangeColumn = jet.RangeColumn[TimestampExpression] + +// ColumnTimestampzRange is interface of SQL timestamp range column +type ColumnTimestampzRange = jet.ColumnRange[TimestampzExpression] + +// TimestampzRangeColumn creates named range with range column +var TimestampzRangeColumn = jet.RangeColumn[TimestampzExpression] + +// ColumnInt4Range is interface of SQL int range column +type ColumnInt4Range = jet.ColumnRange[IntegerExpression] + +// Int4RangeColumn creates named range with range column +var Int4RangeColumn = jet.RangeColumn[IntegerExpression] + +// ColumnInt8Range is interface of SQL int range column +type ColumnInt8Range = jet.ColumnRange[IntegerExpression] + +// Int8RangeColumn creates named range with range column +var Int8RangeColumn = jet.RangeColumn[IntegerExpression] + //------------------------------------------------------// // ColumnInterval is interface of PostgreSQL interval columns. diff --git a/postgres/expressions.go b/postgres/expressions.go index faf153d..a903860 100644 --- a/postgres/expressions.go +++ b/postgres/expressions.go @@ -36,6 +36,24 @@ type TimestampExpression = jet.TimestampExpression // TimestampzExpression interface type TimestampzExpression = jet.TimestampzExpression +// DateRange Expression interface +type DateRange = jet.Range[DateExpression] + +// TimestampRange Expression interface +type TimestampRange = jet.Range[TimestampExpression] + +// TimestampzRange Expression interface +type TimestampzRange = jet.Range[TimestampzExpression] + +// NumericRange Expression interface +type NumericRange = jet.Range[NumericExpression] + +// Int4Range Expression interface +type Int4Range = jet.Range[IntegerExpression] + +// Int8Range Expression interface +type Int8Range = jet.Range[IntegerExpression] + // BoolExp is bool expression wrapper around arbitrary expression. // Allows go compiler to see any expression as bool expression. // Does not add sql cast to generated sql builder output. @@ -81,6 +99,13 @@ var TimestampExp = jet.TimestampExp // Does not add sql cast to generated sql builder output. var TimestampzExp = jet.TimestampzExp +// RangeExp is range expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as range expression. +// Does not add sql cast to generated sql builder output. +func RangeExp[T Expression](expression T) jet.Range[T] { + return jet.RangeExp[T](expression) +} + // RawArgs is type used to pass optional arguments to Raw method type RawArgs = map[string]interface{} @@ -90,15 +115,21 @@ type RawArgs = map[string]interface{} var ( Raw = jet.Raw - RawBool = jet.RawBool - RawInt = jet.RawInt - RawFloat = jet.RawFloat - RawString = jet.RawString - RawTime = jet.RawTime - RawTimez = jet.RawTimez - RawTimestamp = jet.RawTimestamp - RawTimestampz = jet.RawTimestampz - RawDate = jet.RawDate + RawBool = jet.RawBool + RawInt = jet.RawInt + RawFloat = jet.RawFloat + RawString = jet.RawString + RawTime = jet.RawTime + RawTimez = jet.RawTimez + RawTimestamp = jet.RawTimestamp + RawTimestampz = jet.RawTimestampz + RawDate = jet.RawDate + RawNumRange = jet.RawRange[jet.NumericExpression] + RawInt4Range = jet.RawRange[jet.IntegerExpression] + RawInt8Range = jet.RawRange[jet.IntegerExpression] + RawTimestampRange = jet.RawRange[jet.TimestampExpression] + RawTimestampzRange = jet.RawRange[jet.TimestampzExpression] + RawDateRange = jet.RawRange[jet.DateExpression] ) // Func can be used to call custom or unsupported database functions. diff --git a/postgres/functions.go b/postgres/functions.go index 43a5f39..eddc930 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -267,6 +267,18 @@ var TO_HEX = jet.TO_HEX //----------Data Type Formatting Functions ----------------------// +// LOWER_BOUND returns range expressions lower bound +func LOWER_BOUND[T Expression](expression jet.Range[T]) T { + return jet.LOWER_BOUND[T](expression) +} + +// UPPER_BOUND returns range expressions upper bound +func UPPER_BOUND[T Expression](expression jet.Range[T]) T { + return jet.UPPER_BOUND[T](expression) +} + +//----------Data Type Formatting Functions ----------------------// + // TO_CHAR converts expression to string with format var TO_CHAR = jet.TO_CHAR @@ -421,3 +433,18 @@ var CUBE = jet.CUBE // It can be also used with multiple parameters to check if a set of columns is included in the current grouping set. The result // of the GROUPING function would then be an integer bit mask having 1’s for the arguments which have GROUPING(argument) as 1. var GROUPING = jet.GROUPING + +var ( + // DATE_RANGE constructor function to create a date range + DATE_RANGE = jet.DateRange + // NUM_Range constructor function to create a numeric range + NUM_Range = jet.NumRange + // TIMESTAMP_RANGE constructor function to create a timestamp range + TIMESTAMP_RANGE = jet.TimestampRange + // TIMESTAMPTZ_RANGE constructor function to create a timestampz range + TIMESTAMPTZ_RANGE = jet.TimestampzRange + // INT4_RANGE constructor function to create a int4 range + INT4_RANGE = jet.Int4Range + // INT8_RANGE constructor function to create a int8 range + INT8_RANGE = jet.Int8Range +) diff --git a/postgres/keywords.go b/postgres/keywords.go index cfc90a2..a468cc5 100644 --- a/postgres/keywords.go +++ b/postgres/keywords.go @@ -12,4 +12,8 @@ var ( NULL = jet.NULL // STAR is jet equivalent of SQL * STAR = jet.STAR + // PLUS_INFINITY is jet equivalent for sql infinity + PLUS_INFINITY = jet.PLUS_INFINITY + // MINUS_INFINITY is jet equivalent for sql -infinity + MINUS_INFINITY = jet.MINUS_INFINITY ) diff --git a/postgres/select_statement.go b/postgres/select_statement.go index d44d6aa..70a9a50 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -53,6 +53,9 @@ type SelectStatement interface { ORDER_BY(orderByClauses ...OrderByClause) SelectStatement LIMIT(limit int64) SelectStatement OFFSET(offset int64) SelectStatement + // OFFSET_e can be used when an integer expression is needed as offset, otherwise OFFSET can be used + OFFSET_e(offset IntegerExpression) SelectStatement + FETCH_FIRST(count IntegerExpression) fetchExpand FOR(lock RowLock) SelectStatement UNION(rhs SelectStatement) setStatement @@ -72,16 +75,24 @@ func SELECT(projection Projection, projections ...Projection) SelectStatement { func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { newSelect := &selectStatementImpl{} - newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, - &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy, - &newSelect.Limit, &newSelect.Offset, &newSelect.For) + newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, + &newSelect.Select, + &newSelect.From, + &newSelect.Where, + &newSelect.GroupBy, + &newSelect.Having, + &newSelect.Window, + &newSelect.OrderBy, + &newSelect.Limit, + &newSelect.Offset, + &newSelect.Fetch, + &newSelect.For) newSelect.Select.ProjectionList = projections if table != nil { newSelect.From.Tables = []jet.Serializer{table} } newSelect.Limit.Count = -1 - newSelect.Offset.Count = -1 newSelect.setOperatorsImpl.parent = newSelect @@ -101,6 +112,7 @@ type selectStatementImpl struct { OrderBy jet.ClauseOrderBy Limit jet.ClauseLimit Offset jet.ClauseOffset + Fetch jet.ClauseFetch For jet.ClauseFor } @@ -146,10 +158,23 @@ func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement { } func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement { + s.Offset.Count = Int(offset) + return s +} + +func (s *selectStatementImpl) OFFSET_e(offset IntegerExpression) SelectStatement { s.Offset.Count = offset return s } +func (s *selectStatementImpl) FETCH_FIRST(count IntegerExpression) fetchExpand { + s.Fetch.Count = count + + return fetchExpand{ + selectStatement: s, + } +} + func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement { s.For.Lock = lock return s @@ -188,3 +213,19 @@ func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer { } return ret } + +type fetchExpand struct { + selectStatement *selectStatementImpl +} + +func (f fetchExpand) ROWS_ONLY() SelectStatement { + f.selectStatement.Fetch.WithTies = false + + return f.selectStatement +} + +func (f fetchExpand) ROWS_WITH_TIES() SelectStatement { + f.selectStatement.Fetch.WithTies = true + + return f.selectStatement +} diff --git a/postgres/select_statement_test.go b/postgres/select_statement_test.go index b487f90..5dcacf0 100644 --- a/postgres/select_statement_test.go +++ b/postgres/select_statement_test.go @@ -132,5 +132,11 @@ FOR KEY SHARE NOWAIT; SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 FOR NO KEY UPDATE SKIP LOCKED; +`) + + assertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE().OF(table1, table2).NOWAIT()), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +FOR UPDATE OF table1, table2 NOWAIT; `) } diff --git a/postgres/set_statement.go b/postgres/set_statement.go index 236f8de..834560d 100644 --- a/postgres/set_statement.go +++ b/postgres/set_statement.go @@ -45,6 +45,8 @@ type setStatement interface { LIMIT(limit int64) setStatement OFFSET(offset int64) setStatement + // OFFSET_e can be used when an integer expression is needed as offset, otherwise OFFSET can be used + OFFSET_e(offset IntegerExpression) setStatement AsTable(alias string) SelectTable } @@ -107,7 +109,6 @@ func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStat newSetStatement.setOperator.All = all newSetStatement.setOperator.Selects = selects newSetStatement.setOperator.Limit.Count = -1 - newSetStatement.setOperator.Offset.Count = -1 newSetStatement.setOperatorsImpl.parent = newSetStatement @@ -125,6 +126,11 @@ func (s *setStatementImpl) LIMIT(limit int64) setStatement { } func (s *setStatementImpl) OFFSET(offset int64) setStatement { + s.setOperator.Offset.Count = Int(offset) + return s +} + +func (s *setStatementImpl) OFFSET_e(offset IntegerExpression) setStatement { s.setOperator.Offset.Count = offset return s } diff --git a/postgres/utils_test.go b/postgres/utils_test.go index 292d7e4..96bb13b 100644 --- a/postgres/utils_test.go +++ b/postgres/utils_test.go @@ -17,6 +17,7 @@ var table1ColTimestampz = TimestampzColumn("col_timestampz") var table1ColBool = BoolColumn("col_bool") var table1ColDate = DateColumn("col_date") var table1ColInterval = IntervalColumn("col_interval") +var table1ColRange = Int8RangeColumn("col_range") var table1 = NewTable( "db", @@ -32,6 +33,7 @@ var table1 = NewTable( table1ColTimestamp, table1ColTimestampz, table1ColInterval, + table1ColRange, ) var table2Col3 = IntegerColumn("col3") @@ -46,8 +48,9 @@ var table2ColTimestamp = TimestampColumn("col_timestamp") var table2ColTimestampz = TimestampzColumn("col_timestampz") var table2ColDate = DateColumn("col_date") var table2ColInterval = IntervalColumn("col_interval") +var table2ColRange = Int8RangeColumn("col_range") -var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz, table2ColInterval) +var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz, table2ColInterval, table2ColRange) var table3Col1 = IntegerColumn("col1") var table3ColInt = IntegerColumn("col_int") diff --git a/sqlite/select_statement.go b/sqlite/select_statement.go index 5e92d52..e74cda6 100644 --- a/sqlite/select_statement.go +++ b/sqlite/select_statement.go @@ -74,7 +74,6 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta newSelect.From.Tables = []jet.Serializer{table} } newSelect.Limit.Count = -1 - newSelect.Offset.Count = -1 newSelect.ShareLock.Name = "LOCK IN SHARE MODE" newSelect.ShareLock.InNewLine = true @@ -141,7 +140,7 @@ func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement { } func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement { - s.Offset.Count = offset + s.Offset.Count = Int(offset) return s } diff --git a/sqlite/set_statement.go b/sqlite/set_statement.go index 18bcca5..0a004bf 100644 --- a/sqlite/set_statement.go +++ b/sqlite/set_statement.go @@ -63,7 +63,6 @@ func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStat newSetStatement.setOperator.All = all newSetStatement.setOperator.Selects = selects newSetStatement.setOperator.Limit.Count = -1 - newSetStatement.setOperator.Offset.Count = -1 newSetStatement.setOperator.SkipSelectWrap = true newSetStatement.setOperatorsImpl.parent = newSetStatement @@ -82,7 +81,7 @@ func (s *setStatementImpl) LIMIT(limit int64) setStatement { } func (s *setStatementImpl) OFFSET(offset int64) setStatement { - s.setOperator.Offset.Count = offset + s.setOperator.Offset.Count = Int(offset) return s } diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 9c562fb..9b3af50 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -39,14 +39,14 @@ services: - ./testdata/init/mysql:/docker-entrypoint-initdb.d cockroach: - image: cockroachdb/cockroach-unstable:v22.1.0-beta.4 + image: cockroachdb/cockroach-unstable:v23.1.0-rc.2 environment: - COCKROACH_USER=jet - COCKROACH_PASSWORD=jet - COCKROACH_DATABASE=jetdb ports: - "26257:26257" - command: start-single-node --insecure + command: start-single-node --accept-sql-without-tls # volumes: # - ./testdata/init/cockroach:/docker-entrypoint-initdb.d diff --git a/tests/init/init.go b/tests/init/init.go index 2543eb1..5a21ee7 100644 --- a/tests/init/init.go +++ b/tests/init/init.go @@ -78,6 +78,7 @@ func main() { if err != nil { fmt.Println(errfmt.Trace(err)) + os.Exit(1) } } diff --git a/tests/mysql/generator_template_test.go b/tests/mysql/generator_template_test.go index a47a21d..a257f31 100644 --- a/tests/mysql/generator_template_test.go +++ b/tests/mysql/generator_template_test.go @@ -285,6 +285,29 @@ func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) { require.Contains(t, mpaaRating, "var FilmRatingEnumSQLBuilder = &struct {") } +func TestGeneratorTemplate_SQLBuilder_DefaultAlias(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection("dvds"), + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + if table.Name == "actor" { + return template.DefaultTableSQLBuilder(table).UseDefaultAlias("actors") + } + return template.DefaultTableSQLBuilder(table) + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultTableSQLBuilderFilePath, "actor.go") + require.Contains(t, actor, "var Actor = newActorTable(\"dvds\", \"actor\", \"actors\")") +} + func TestGeneratorTemplate_Model_AddTags(t *testing.T) { err := mysql2.Generate( diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 4009bb8..8e03f82 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -2,6 +2,8 @@ package mysql import ( "context" + "database/sql" + "github.com/go-jet/jet/v2/postgres" "strings" "testing" "time" @@ -295,6 +297,296 @@ ORDER BY inventory.film_id, inventory.store_id; `) } +func TestOrderBy(t *testing.T) { + // NULLS_FIRST and NULLS_LAST are simulated using IS_NULL, ... + + ensureNullsFirstRentalResult := func(t *testing.T, stmt postgres.Statement, asc bool) { + var dest []model.Rental + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 200) + require.Nil(t, dest[0].ReturnDate) + require.NotNil(t, dest[199].ReturnDate) + + if asc { + require.True(t, dest[198].ReturnDate.Before(*dest[199].ReturnDate)) + } else { + require.True(t, dest[199].ReturnDate.Before(*dest[198].ReturnDate)) + } + } + + ensureNullsLastRentalResult := func(t *testing.T, stmt postgres.Statement, asc bool) { + var dest []model.Rental + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 200) + require.NotNil(t, dest[0].ReturnDate) + require.Nil(t, dest[199].ReturnDate) + + if asc { + require.True(t, dest[0].ReturnDate.Before(*dest[1].ReturnDate)) + } else { + require.True(t, dest[1].ReturnDate.Before(*dest[0].ReturnDate)) + } + } + + t.Run("default", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate, + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date +LIMIT 200; +`) + ensureNullsFirstRentalResult(t, stmt, true) + }) + + t.Run("NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date +LIMIT 200; +`) + ensureNullsFirstRentalResult(t, stmt, true) + }) + + t.Run("NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.NULLS_LAST(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date IS NULL, rental.return_date +LIMIT 200 +OFFSET 15800; +`) + ensureNullsLastRentalResult(t, stmt, true) + }) + + t.Run("ASC", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date ASC +LIMIT 200; +`) + ensureNullsFirstRentalResult(t, stmt, true) + }) + + t.Run("ASC NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC().NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date ASC +LIMIT 200; +`) + + ensureNullsFirstRentalResult(t, stmt, true) + }) + + t.Run("ASC NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC().NULLS_LAST(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date IS NULL, rental.return_date ASC +LIMIT 200 +OFFSET 15800; +`) + + ensureNullsLastRentalResult(t, stmt, true) + }) + + t.Run("DESC", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date DESC +LIMIT 200 +OFFSET 15800; +`) + + ensureNullsLastRentalResult(t, stmt, false) + }) + + t.Run("DESC NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC().NULLS_LAST(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date DESC +LIMIT 200 +OFFSET 15800; +`) + + ensureNullsLastRentalResult(t, stmt, false) + }) + + t.Run("DESC NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC().NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date IS NOT NULL, rental.return_date DESC +LIMIT 200; +`) + + ensureNullsFirstRentalResult(t, stmt, false) + }) + + t.Run("complex", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.RentalID.DESC(), + Rental.ReturnDate.DESC().NULLS_FIRST(), + Rental.LastUpdate.ASC(), + Rental.InventoryID.ADD(Rental.RentalID).ASC().NULLS_LAST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.rental_id DESC, rental.return_date IS NOT NULL, rental.return_date DESC, rental.last_update ASC, (rental.inventory_id + rental.rental_id) IS NULL, rental.inventory_id + rental.rental_id ASC +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + + }) + +} + func TestSubQuery(t *testing.T) { rRatingFilms := SELECT( @@ -632,6 +924,86 @@ FOR` } } +func TestRowLockWithUpdateOf(t *testing.T) { + skipForMariaDB(t) // MariaDB does not support UPDATE OF + + stmt := SELECT( + Film.FilmID, + Film.Title, + Actor.ActorID, + Actor.FirstName, + ).FROM( + Film. + INNER_JOIN(FilmCategory, FilmCategory.FilmID.EQ(Film.FilmID)). + INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(Film.FilmID)). + INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID)), + ).LIMIT( + 1, + ).FOR( + UPDATE().OF(Film, Actor).NOWAIT(), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name" +FROM dvds.film + INNER JOIN dvds.film_category ON (film_category.film_id = film.film_id) + INNER JOIN dvds.film_actor ON (film_actor.film_id = film.film_id) + INNER JOIN dvds.actor ON (actor.actor_id = film_actor.actor_id) +LIMIT 1 +FOR UPDATE OF film, actor NOWAIT; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []struct { + model.Film + CategoryID int + Actor []model.Actor + } + + err := stmt.Query(tx, &dest) + require.NoError(t, err) + require.Len(t, dest, 1) + }) +} + +func TestRowLockWithUpdateOfAliasedTable(t *testing.T) { + skipForMariaDB(t) // MariaDB does not support UPDATE OF + + myFilm := Film.AS("myFilm") + + stmt := SELECT( + myFilm.FilmID, + myFilm.Title, + ).FROM( + myFilm, + ).LIMIT( + 1, + ).FOR( + UPDATE().OF(myFilm), + ) + + testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(` +SELECT ''myFilm''.film_id AS "myFilm.film_id", + ''myFilm''.title AS "myFilm.title" +FROM dvds.film AS ''myFilm'' +LIMIT 1 +FOR UPDATE OF ''myFilm''; +`, "''", "`")) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []struct { + model.Film `alias:"myFilm.*"` + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 1) + }) +} + func TestExpressionWrappers(t *testing.T) { query := SELECT( BoolExp(Raw("true")), diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 8a11191..d41feee 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -17,31 +17,52 @@ import ( "github.com/go-jet/jet/v2/tests/testdata/results/common" ) +var AllTypesAllColumns = AllTypes.AllColumns. + Except(IntegerColumn("rowid")) // cockroachDB: exclude rowid column + func TestAllTypesSelect(t *testing.T) { var dest []model.AllTypes - err := AllTypes.SELECT( - AllTypesAllColumns, - ).LIMIT(2). + err := AllTypes.SELECT(AllTypesAllColumns). + LIMIT(2). Query(db, &dest) - require.NoError(t, err) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest[0], allTypesRow0) testutils.AssertDeepEqual(t, dest[1], allTypesRow1) } func TestAllTypesViewSelect(t *testing.T) { type AllTypesView model.AllTypes - var dest []AllTypesView - err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) - require.NoError(t, err) + err := SELECT(view.AllTypesView.AllColumns). + FROM(view.AllTypesView). + Query(db, &dest) + require.NoError(t, err) testutils.AssertDeepEqual(t, dest[0], AllTypesView(allTypesRow0)) testutils.AssertDeepEqual(t, dest[1], AllTypesView(allTypesRow1)) } +func TestMaterializedViewAllTypes(t *testing.T) { + stmt := SELECT( + view.AllTypesMaterializedView.AllColumns. + Except(IntegerColumn("rowid")), // cockroachDB: exclude rowid column + ).FROM( + view.AllTypesMaterializedView, + ) + + type AllTypesMaterializedView model.AllTypes + var dest []AllTypesMaterializedView + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertDeepEqual(t, dest[0], AllTypesMaterializedView(allTypesRow0)) + testutils.AssertDeepEqual(t, dest[1], AllTypesMaterializedView(allTypesRow1)) +} + func TestAllTypesInsertModel(t *testing.T) { skipForPgxDriver(t) // pgx driver bug ERROR: date/time field value out of range: "0000-01-01 12:05:06Z" (SQLSTATE 22008) @@ -64,8 +85,6 @@ func TestAllTypesInsertModel(t *testing.T) { }) } -var AllTypesAllColumns = AllTypes.AllColumns.Except(IntegerColumn("rowid")) - func TestAllTypesInsertQuery(t *testing.T) { query := AllTypes.INSERT(AllTypesAllColumns). QUERY( @@ -230,7 +249,9 @@ SELECT "allTypesSubQuery"."all_types.small_int_ptr" AS "all_types.small_int_ptr" "allTypesSubQuery"."all_types.text_array" AS "all_types.text_array", "allTypesSubQuery"."all_types.jsonb_array" AS "all_types.jsonb_array", "allTypesSubQuery"."all_types.text_multi_dim_array_ptr" AS "all_types.text_multi_dim_array_ptr", - "allTypesSubQuery"."all_types.text_multi_dim_array" AS "all_types.text_multi_dim_array" + "allTypesSubQuery"."all_types.text_multi_dim_array" AS "all_types.text_multi_dim_array", + "allTypesSubQuery"."all_types.mood_ptr" AS "all_types.mood_ptr", + "allTypesSubQuery"."all_types.mood" AS "all_types.mood" FROM ( SELECT all_types.small_int_ptr AS "all_types.small_int_ptr", all_types.small_int AS "all_types.small_int", @@ -292,7 +313,9 @@ FROM ( all_types.text_array AS "all_types.text_array", all_types.jsonb_array AS "all_types.jsonb_array", all_types.text_multi_dim_array_ptr AS "all_types.text_multi_dim_array_ptr", - all_types.text_multi_dim_array AS "all_types.text_multi_dim_array" + all_types.text_multi_dim_array AS "all_types.text_multi_dim_array", + all_types.mood_ptr AS "all_types.mood_ptr", + all_types.mood AS "all_types.mood" FROM test_sample.all_types ) AS "allTypesSubQuery" LIMIT 2; @@ -1279,6 +1302,8 @@ RETURNING all_types.json AS "all_types.json"; }) } +var moodSad = model.Mood_Sad + var allTypesRow0 = model.AllTypes{ SmallIntPtr: testutils.Int16Ptr(14), SmallInt: 14, @@ -1343,6 +1368,8 @@ var allTypesRow0 = model.AllTypes{ JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`, TextMultiDimArrayPtr: testutils.StringPtr("{{meeting,lunch},{training,presentation}}"), TextMultiDimArray: "{{meeting,lunch},{training,presentation}}", + MoodPtr: &moodSad, + Mood: model.Mood_Happy, } var allTypesRow1 = model.AllTypes{ @@ -1409,6 +1436,8 @@ var allTypesRow1 = model.AllTypes{ JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`, TextMultiDimArrayPtr: nil, TextMultiDimArray: "{{meeting,lunch},{training,presentation}}", + MoodPtr: nil, + Mood: model.Mood_Ok, } func TestAliasedDuplicateSliceSubType(t *testing.T) { diff --git a/tests/postgres/generator_template_test.go b/tests/postgres/generator_template_test.go index 4f2f0da..2bc2d32 100644 --- a/tests/postgres/generator_template_test.go +++ b/tests/postgres/generator_template_test.go @@ -341,6 +341,29 @@ func UseSchema(schema string) { `) } +func TestGeneratorTemplate_SQLBuilder_DefaultAlias(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + if table.Name == "actor" { + return template.DefaultTableSQLBuilder(table).UseDefaultAlias("actors") + } + return template.DefaultTableSQLBuilder(table) + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultTableSQLBuilderFilePath, "actor.go") + require.Contains(t, actor, "var Actor = newActorTable(\"dvds\", \"actor\", \"actors\")") +} + func TestGeneratorTemplate_Model_AddTags(t *testing.T) { err := postgres.Generate( diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index a0257e7..479d82f 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -548,11 +548,12 @@ func UseSchema(schema string) { ` func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { - skipForCockroachDB(t) + skipForCockroachDB(t) // because of rowid column enumDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/enum/") modelDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/model/") tableDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/table/") + viewDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/view/") testutils.AssertFileNamesEqual(t, enumDir, "mood.go", "level.go") testutils.AssertFileContent(t, enumDir+"/mood.go", moodEnumContent) @@ -560,15 +561,17 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileNamesEqual(t, modelDir, "all_types.go", "all_types_view.go", "employee.go", "link.go", "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go", "floats.go", "people.go", - "components.go", "vulnerabilities.go") - + "components.go", "vulnerabilities.go", "all_types_materialized_view.go", "sample_ranges.go") testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent) testutils.AssertFileNamesEqual(t, tableDir, "all_types.go", "employee.go", "link.go", "person.go", "person_phone.go", "weird_names_table.go", "user.go", "floats.go", "people.go", "table_use_schema.go", - "components.go", "vulnerabilities.go") - + "components.go", "vulnerabilities.go", "sample_ranges.go") testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent) + testutils.AssertFileContent(t, tableDir+"/sample_ranges.go", sampleRangeTableContent) + + testutils.AssertFileNamesEqual(t, viewDir, "all_types_materialized_view.go", "all_types_view.go", + "view_use_schema.go") } var moodEnumContent = ` @@ -698,6 +701,8 @@ type AllTypes struct { JsonbArray string TextMultiDimArrayPtr *string TextMultiDimArray string + MoodPtr *Mood + Mood Mood } ` @@ -782,6 +787,8 @@ type allTypesTable struct { JsonbArray postgres.ColumnString TextMultiDimArrayPtr postgres.ColumnString TextMultiDimArray postgres.ColumnString + MoodPtr postgres.ColumnString + Mood postgres.ColumnString AllColumns postgres.ColumnList MutableColumns postgres.ColumnList @@ -883,8 +890,10 @@ func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable { JsonbArrayColumn = postgres.StringColumn("jsonb_array") TextMultiDimArrayPtrColumn = postgres.StringColumn("text_multi_dim_array_ptr") TextMultiDimArrayColumn = postgres.StringColumn("text_multi_dim_array") - allColumns = postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn} - mutableColumns = postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn} + MoodPtrColumn = postgres.StringColumn("mood_ptr") + MoodColumn = postgres.StringColumn("mood") + allColumns = postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn, MoodPtrColumn, MoodColumn} + mutableColumns = postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn, MoodPtrColumn, MoodColumn} ) return allTypesTable{ @@ -952,6 +961,101 @@ func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable { JsonbArray: JsonbArrayColumn, TextMultiDimArrayPtr: TextMultiDimArrayPtrColumn, TextMultiDimArray: TextMultiDimArrayColumn, + MoodPtr: MoodPtrColumn, + Mood: MoodColumn, + + AllColumns: allColumns, + MutableColumns: mutableColumns, + } +} +` + +var sampleRangeTableContent = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package table + +import ( + "github.com/go-jet/jet/v2/postgres" +) + +var SampleRanges = newSampleRangesTable("test_sample", "sample_ranges", "") + +type sampleRangesTable struct { + postgres.Table + + // Columns + DateRange postgres.ColumnDateRange + TimestampRange postgres.ColumnTimestampRange + TimestampzRange postgres.ColumnTimestampzRange + Int4Range postgres.ColumnInt4Range + Int8Range postgres.ColumnInt8Range + NumRange postgres.ColumnNumericRange + + AllColumns postgres.ColumnList + MutableColumns postgres.ColumnList +} + +type SampleRangesTable struct { + sampleRangesTable + + EXCLUDED sampleRangesTable +} + +// AS creates new SampleRangesTable with assigned alias +func (a SampleRangesTable) AS(alias string) *SampleRangesTable { + return newSampleRangesTable(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new SampleRangesTable with assigned schema name +func (a SampleRangesTable) FromSchema(schemaName string) *SampleRangesTable { + return newSampleRangesTable(schemaName, a.TableName(), a.Alias()) +} + +// WithPrefix creates new SampleRangesTable with assigned table prefix +func (a SampleRangesTable) WithPrefix(prefix string) *SampleRangesTable { + return newSampleRangesTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new SampleRangesTable with assigned table suffix +func (a SampleRangesTable) WithSuffix(suffix string) *SampleRangesTable { + return newSampleRangesTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + +func newSampleRangesTable(schemaName, tableName, alias string) *SampleRangesTable { + return &SampleRangesTable{ + sampleRangesTable: newSampleRangesTableImpl(schemaName, tableName, alias), + EXCLUDED: newSampleRangesTableImpl("", "excluded", ""), + } +} + +func newSampleRangesTableImpl(schemaName, tableName, alias string) sampleRangesTable { + var ( + DateRangeColumn = postgres.DateRangeColumn("date_range") + TimestampRangeColumn = postgres.TimestampRangeColumn("timestamp_range") + TimestampzRangeColumn = postgres.TimestampzRangeColumn("timestampz_range") + Int4RangeColumn = postgres.Int4RangeColumn("int4_range") + Int8RangeColumn = postgres.Int8RangeColumn("int8_range") + NumRangeColumn = postgres.NumericRangeColumn("num_range") + allColumns = postgres.ColumnList{DateRangeColumn, TimestampRangeColumn, TimestampzRangeColumn, Int4RangeColumn, Int8RangeColumn, NumRangeColumn} + mutableColumns = postgres.ColumnList{DateRangeColumn, TimestampRangeColumn, TimestampzRangeColumn, Int4RangeColumn, Int8RangeColumn, NumRangeColumn} + ) + + return sampleRangesTable{ + Table: postgres.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns + DateRange: DateRangeColumn, + TimestampRange: TimestampRangeColumn, + TimestampzRange: TimestampzRangeColumn, + Int4Range: Int4RangeColumn, + Int8Range: Int8RangeColumn, + NumRange: NumRangeColumn, AllColumns: allColumns, MutableColumns: mutableColumns, diff --git a/tests/postgres/range_test.go b/tests/postgres/range_test.go new file mode 100644 index 0000000..63498eb --- /dev/null +++ b/tests/postgres/range_test.go @@ -0,0 +1,514 @@ +package postgres + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/qrm" + "github.com/google/go-cmp/cmp" + "github.com/jackc/pgtype" + "github.com/stretchr/testify/require" + "math/big" + "testing" + "time" + + . "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/table" +) + +func TestRangeTable_DateContainsSingle(t *testing.T) { + skipForCockroachDB(t) + expectedSQL := ` +SELECT DISTINCT sample_ranges.date_range AS "sample_ranges.date_range", + sample_ranges.timestamp_range AS "sample_ranges.timestamp_range", + sample_ranges.timestampz_range AS "sample_ranges.timestampz_range", + sample_ranges.int4_range AS "sample_ranges.int4_range", + sample_ranges.int8_range AS "sample_ranges.int8_range", + sample_ranges.num_range AS "sample_ranges.num_range" +FROM test_sample.sample_ranges +WHERE sample_ranges.date_range @> '2023-12-12'::date; +` + + query := SELECT(SampleRanges.AllColumns). + DISTINCT(). + FROM(SampleRanges). + WHERE(SampleRanges.DateRange.CONTAINS(Date(2023, 12, 12))) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, "2023-12-12") + + sample := model.SampleRanges{} + err := query.Query(db, &sample) + + require.NoError(t, err) + + expectedRow := model.SampleRanges{ + DateRange: pgtype.Daterange{ + Lower: pgtype.Date{ + Time: time.Date(2023, 9, 25, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + Upper: pgtype.Date{ + Time: time.Date(2024, 2, 10, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + TimestampRange: pgtype.Tsrange{ + Lower: pgtype.Timestamp{ + Time: time.Date(2020, 01, 01, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + Upper: pgtype.Timestamp{ + Time: time.Date(2021, 01, 01, 15, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Inclusive, + Status: pgtype.Present, + }, + TimestampzRange: pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{ + Time: time.Date(2024, 05, 07, 15, 0, 0, 0, time.FixedZone("", 0)), + Status: pgtype.Present, + }, + Upper: pgtype.Timestamptz{ + Time: time.Date(2024, 10, 11, 14, 0, 0, 0, time.FixedZone("", 0)), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + Int4Range: pgtype.Int4range{ + Lower: pgtype.Int4{ + Int: 11, + Status: pgtype.Present, + }, + Upper: pgtype.Int4{ + Int: 20, + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + Int8Range: pgtype.Int8range{ + Lower: pgtype.Int8{ + Int: 200, + Status: pgtype.Present, + }, + Upper: pgtype.Int8{ + Int: 2450, + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + NumRange: pgtype.Numrange{ + Lower: pgtype.Numeric{ + Int: big.NewInt(2), + Exp: 3, + Status: pgtype.Present, + }, + Upper: pgtype.Numeric{ + Int: big.NewInt(5), + Status: pgtype.Present, + Exp: 3, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + } + + testutils.AssertDeepEqual(t, sample, expectedRow, cmp.AllowUnexported(big.Int{})) + requireLogged(t, query) +} + +func TestRangeTable_IntContainsRange(t *testing.T) { + skipForCockroachDB(t) + expectedSQL := ` +SELECT DISTINCT sample_ranges.date_range AS "sample_ranges.date_range", + sample_ranges.timestamp_range AS "sample_ranges.timestamp_range", + sample_ranges.timestampz_range AS "sample_ranges.timestampz_range", + sample_ranges.int4_range AS "sample_ranges.int4_range", + sample_ranges.int8_range AS "sample_ranges.int8_range", + sample_ranges.num_range AS "sample_ranges.num_range" +FROM test_sample.sample_ranges +WHERE sample_ranges.int4_range @> int4range(12, 18, '[)'::text); +` + + query := SELECT(SampleRanges.AllColumns). + DISTINCT(). + FROM(SampleRanges). + WHERE(SampleRanges.Int4Range.CONTAINS_RANGE(INT4_RANGE(Int(12), Int(18), String("[)")))) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(12), int64(18), "[)") + + sample := model.SampleRanges{} + err := query.Query(db, &sample) + + require.NoError(t, err) + + expectedRow := model.SampleRanges{ + DateRange: pgtype.Daterange{ + Lower: pgtype.Date{ + Time: time.Date(2023, 9, 25, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + Upper: pgtype.Date{ + Time: time.Date(2024, 2, 10, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + TimestampRange: pgtype.Tsrange{ + Lower: pgtype.Timestamp{ + Time: time.Date(2020, 01, 01, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + Upper: pgtype.Timestamp{ + Time: time.Date(2021, 01, 01, 15, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Inclusive, + Status: pgtype.Present, + }, + TimestampzRange: pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{ + Time: time.Date(2024, 05, 07, 15, 0, 0, 0, time.FixedZone("", 0)), + Status: pgtype.Present, + }, + Upper: pgtype.Timestamptz{ + Time: time.Date(2024, 10, 11, 14, 0, 0, 0, time.FixedZone("", 0)), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + Int4Range: pgtype.Int4range{ + Lower: pgtype.Int4{ + Int: 11, + Status: pgtype.Present, + }, + Upper: pgtype.Int4{ + Int: 20, + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + Int8Range: pgtype.Int8range{ + Lower: pgtype.Int8{ + Int: 200, + Status: pgtype.Present, + }, + Upper: pgtype.Int8{ + Int: 2450, + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + NumRange: pgtype.Numrange{ + Lower: pgtype.Numeric{ + Int: big.NewInt(2), + Exp: 3, + Status: pgtype.Present, + }, + Upper: pgtype.Numeric{ + Int: big.NewInt(5), + Status: pgtype.Present, + Exp: 3, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + } + + testutils.AssertDeepEqual(t, sample, expectedRow, cmp.AllowUnexported(big.Int{})) + requireLogged(t, query) +} + +func TestRangeTable_TimestampContainsRange(t *testing.T) { + skipForCockroachDB(t) + expectedSQL := ` +SELECT DISTINCT sample_ranges.date_range AS "sample_ranges.date_range", + sample_ranges.timestamp_range AS "sample_ranges.timestamp_range", + sample_ranges.timestampz_range AS "sample_ranges.timestampz_range", + sample_ranges.int4_range AS "sample_ranges.int4_range", + sample_ranges.int8_range AS "sample_ranges.int8_range", + sample_ranges.num_range AS "sample_ranges.num_range" +FROM test_sample.sample_ranges +WHERE sample_ranges.timestamp_range @> tsrange('2020-02-01 00:00:00'::timestamp without time zone, '2020-10-01 00:00:00'::timestamp without time zone, '[)'::text); +` + + query := SELECT(SampleRanges.AllColumns). + DISTINCT(). + FROM(SampleRanges). + WHERE(SampleRanges.TimestampRange.CONTAINS_RANGE(TIMESTAMP_RANGE(Timestamp(2020, 02, 01, 0, 0, 0), Timestamp(2020, 10, 01, 0, 0, 0), String("[)")))) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, "2020-02-01 00:00:00", "2020-10-01 00:00:00", "[)") + + sample := model.SampleRanges{} + err := query.Query(db, &sample) + + require.NoError(t, err) + + expectedRow := model.SampleRanges{ + DateRange: pgtype.Daterange{ + Lower: pgtype.Date{ + Time: time.Date(2023, 9, 25, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + Upper: pgtype.Date{ + Time: time.Date(2024, 2, 10, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + TimestampRange: pgtype.Tsrange{ + Lower: pgtype.Timestamp{ + Time: time.Date(2020, 01, 01, 0, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + Upper: pgtype.Timestamp{ + Time: time.Date(2021, 01, 01, 15, 0, 0, 0, time.UTC), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Inclusive, + Status: pgtype.Present, + }, + TimestampzRange: pgtype.Tstzrange{ + Lower: pgtype.Timestamptz{ + Time: time.Date(2024, 05, 07, 15, 0, 0, 0, time.FixedZone("", 0)), + Status: pgtype.Present, + }, + Upper: pgtype.Timestamptz{ + Time: time.Date(2024, 10, 11, 14, 0, 0, 0, time.FixedZone("", 0)), + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + Int4Range: pgtype.Int4range{ + Lower: pgtype.Int4{ + Int: 11, + Status: pgtype.Present, + }, + Upper: pgtype.Int4{ + Int: 20, + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + Int8Range: pgtype.Int8range{ + Lower: pgtype.Int8{ + Int: 200, + Status: pgtype.Present, + }, + Upper: pgtype.Int8{ + Int: 2450, + Status: pgtype.Present, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + NumRange: pgtype.Numrange{ + Lower: pgtype.Numeric{ + Int: big.NewInt(2), + Exp: 3, + Status: pgtype.Present, + }, + Upper: pgtype.Numeric{ + Int: big.NewInt(5), + Status: pgtype.Present, + Exp: 3, + }, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Exclusive, + Status: pgtype.Present, + }, + } + + testutils.AssertDeepEqual(t, sample, expectedRow, cmp.AllowUnexported(big.Int{})) + requireLogged(t, query) +} + +func TestRangeTable_ContainsOutOfRange(t *testing.T) { + skipForCockroachDB(t) + expectedSQL := ` +SELECT DISTINCT sample_ranges.date_range AS "sample_ranges.date_range", + sample_ranges.timestamp_range AS "sample_ranges.timestamp_range", + sample_ranges.timestampz_range AS "sample_ranges.timestampz_range", + sample_ranges.int4_range AS "sample_ranges.int4_range", + sample_ranges.int8_range AS "sample_ranges.int8_range", + sample_ranges.num_range AS "sample_ranges.num_range" +FROM test_sample.sample_ranges +WHERE sample_ranges.int4_range @> int4range(12, 30, '[)'::text); +` + + query := SELECT(SampleRanges.AllColumns). + DISTINCT(). + FROM(SampleRanges). + WHERE(SampleRanges.Int4Range.CONTAINS_RANGE(INT4_RANGE(Int(12), Int(30), String("[)")))) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(12), int64(30), "[)") + + sample := model.SampleRanges{} + err := query.Query(db, &sample) + + require.ErrorIs(t, err, qrm.ErrNoRows) + requireLogged(t, query) +} + +func TestRangeTable_InsertColumn(t *testing.T) { + skipForCockroachDB(t) + + insertQuery := SampleRanges.INSERT(SampleRanges.AllColumns). + VALUES( + DATE_RANGE( + Date(2010, 01, 01), + Date(2014, 01, 01), + String("[)"), + ), + DEFAULT, + TIMESTAMPTZ_RANGE( + TimestampzT(time.Date(2010, 01, 01, 23, 0, 0, 0, time.UTC)), + TimestampzT(time.Date(2014, 01, 01, 15, 0, 0, 0, time.UTC)), + String("[)"), + ), + INT4_RANGE(Int(64), Int(128), String("[]")), + INT8_RANGE(Int(1024), Int(2048), String("[]")), + DEFAULT, + ). + RETURNING(SampleRanges.AllColumns) + + expectedQuery := ` +INSERT INTO test_sample.sample_ranges (date_range, timestamp_range, timestampz_range, int4_range, int8_range, num_range) +VALUES (daterange('2010-01-01'::date, '2014-01-01'::date, '[)'::text), DEFAULT, tstzrange('2010-01-01 23:00:00Z'::timestamp with time zone, '2014-01-01 15:00:00Z'::timestamp with time zone, '[)'::text), int4range(64, 128, '[]'::text), int8range(1024, 2048, '[]'::text), DEFAULT) +RETURNING sample_ranges.date_range AS "sample_ranges.date_range", + sample_ranges.timestamp_range AS "sample_ranges.timestamp_range", + sample_ranges.timestampz_range AS "sample_ranges.timestampz_range", + sample_ranges.int4_range AS "sample_ranges.int4_range", + sample_ranges.int8_range AS "sample_ranges.int8_range", + sample_ranges.num_range AS "sample_ranges.num_range"; +` + testutils.AssertDebugStatementSql(t, insertQuery, expectedQuery, + "2010-01-01", "2014-01-01", "[)", + time.Date(2010, 01, 01, 23, 0, 0, 0, time.UTC), time.Date(2014, 01, 01, 15, 0, 0, 0, time.UTC), "[)", + int64(64), int64(128), "[]", + int64(1024), int64(2048), "[]", + ) +} + +func TestRangeTable_UpperBound(t *testing.T) { + skipForCockroachDB(t) + + expectedSQL := ` +SELECT UPPER(sample_ranges.date_range) +FROM test_sample.sample_ranges +WHERE sample_ranges.date_range @> '2023-12-12'::date; +` + + query := SELECT(UPPER_BOUND[DateExpression](SampleRanges.DateRange)). + FROM(SampleRanges). + WHERE(SampleRanges.DateRange.CONTAINS(Date(2023, 12, 12))) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, "2023-12-12") + + var date time.Time + err := query.Query(db, &date) + require.NoError(t, err) + + expectedYear := 2024 + expectedMonth := time.February + expectedDay := 10 + if expectedYear != date.Year() || expectedMonth != date.Month() || expectedDay != date.Day() { + t.Errorf("expected: 2024-02-10 got: %s", date.Format("2006-01-02")) + } +} + +func TestRangeTable_LowerBound(t *testing.T) { + skipForCockroachDB(t) + + expectedSQL := ` +SELECT LOWER(sample_ranges.date_range) +FROM test_sample.sample_ranges +WHERE sample_ranges.date_range @> '2023-12-12'::date; +` + + query := SELECT(LOWER_BOUND[DateExpression](SampleRanges.DateRange)). + FROM(SampleRanges). + WHERE(SampleRanges.DateRange.CONTAINS(Date(2023, 12, 12))) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, "2023-12-12") + + var date time.Time + err := query.Query(db, &date) + require.NoError(t, err) + + expectedYear := 2023 + expectedMonth := time.September + expectedDay := 25 + if expectedYear != date.Year() || expectedMonth != date.Month() || expectedDay != date.Day() { + t.Errorf("expected: 2023-09-25 got: %s", date.Format("2006-01-02")) + } +} + +func TestRangeTable_InsertInfinite(t *testing.T) { + skipForCockroachDB(t) + + insertQuery := SampleRanges.INSERT(SampleRanges.AllColumns). + VALUES( + DATE_RANGE( + Date(2010, 01, 01), + DateExp(PLUS_INFINITY), + String("[)"), + ), + DEFAULT, + TIMESTAMPTZ_RANGE( + TimestampzExp(MINUS_INFINITY), + TimestampzT(time.Date(2014, 01, 01, 15, 0, 0, 0, time.UTC)), + String("[)"), + ), + INT4_RANGE(Int(64), Int(128), String("[]")), + INT8_RANGE(Int(1024), Int(2048), String("[]")), + DEFAULT, + ). + RETURNING(SampleRanges.AllColumns) + + expectedQuery := ` +INSERT INTO test_sample.sample_ranges (date_range, timestamp_range, timestampz_range, int4_range, int8_range, num_range) +VALUES (daterange('2010-01-01'::date, 'infinity', '[)'::text), DEFAULT, tstzrange('-infinity', '2014-01-01 15:00:00Z'::timestamp with time zone, '[)'::text), int4range(64, 128, '[]'::text), int8range(1024, 2048, '[]'::text), DEFAULT) +RETURNING sample_ranges.date_range AS "sample_ranges.date_range", + sample_ranges.timestamp_range AS "sample_ranges.timestamp_range", + sample_ranges.timestampz_range AS "sample_ranges.timestampz_range", + sample_ranges.int4_range AS "sample_ranges.int4_range", + sample_ranges.int8_range AS "sample_ranges.int8_range", + sample_ranges.num_range AS "sample_ranges.num_range"; +` + + testutils.AssertDebugStatementSql(t, insertQuery, expectedQuery, + "2010-01-01", "infinity", "[)", + "-infinity", time.Date(2014, 01, 01, 15, 0, 0, 0, time.UTC), "[)", + int64(64), int64(128), "[]", + int64(1024), int64(2048), "[]", + ) +} diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 8bc8105..1df72dd 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -120,9 +120,15 @@ RETURNING floats.decimal_ptr AS "floats.decimal_ptr", } func TestUUIDComplex(t *testing.T) { - query := Person.INNER_JOIN(PersonPhone, PersonPhone.PersonID.EQ(Person.PersonID)). - SELECT(Person.AllColumns, PersonPhone.AllColumns). - ORDER_BY(Person.PersonID.ASC(), PersonPhone.PhoneID.ASC()) + query := SELECT( + Person.AllColumns, + PersonPhone.AllColumns, + ).FROM( + Person.INNER_JOIN(PersonPhone, PersonPhone.PersonID.EQ(Person.PersonID)), + ).ORDER_BY( + Person.PersonID.ASC(), + PersonPhone.PhoneID.ASC(), + ) t.Run("slice of structs", func(t *testing.T) { diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index a6d9588..048db14 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -236,6 +236,122 @@ LIMIT 12; require.NoError(t, err) } +func TestFetchFirst(t *testing.T) { + + t.Run("rows only", func(t *testing.T) { + stmt := SELECT(Actor.AllColumns). + FROM(Actor). + ORDER_BY(Actor.ActorID). + OFFSET(2). + FETCH_FIRST(Int(3)).ROWS_ONLY() + + testutils.AssertStatementSql(t, stmt, ` +SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" +FROM dvds.actor +ORDER BY actor.actor_id +OFFSET $1 +FETCH FIRST $2 ROWS ONLY; +`) + + var dest []model.Actor + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 3) + require.Equal(t, dest[0].ActorID, int32(3)) + require.Equal(t, dest[2].ActorID, int32(5)) + }) + + t.Run("rows with ties", func(t *testing.T) { + skipForCockroachDB(t) // ROWS_WITH_TIES is not supported on cockroachdb + + stmt := SELECT(Actor.AllColumns). + FROM(Actor). + ORDER_BY(Actor.LastUpdate). + FETCH_FIRST(Int(3)).ROWS_WITH_TIES() + + testutils.AssertStatementSql(t, stmt, ` +SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" +FROM dvds.actor +ORDER BY actor.last_update +FETCH FIRST $1 ROWS WITH TIES; +`) + + var dest []model.Actor + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 200) + }) + + t.Run("complex expression", func(t *testing.T) { + stmt := SELECT(Actor.AllColumns). + FROM(Actor). + ORDER_BY(Actor.LastUpdate). + FETCH_FIRST(IntExp( + SELECT(MAX(Store.StoreID)). + FROM(Store), + )).ROWS_ONLY() + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" +FROM dvds.actor +ORDER BY actor.last_update +FETCH FIRST ( + SELECT MAX(store.store_id) + FROM dvds.store +) ROWS ONLY; +`) + + var dest []model.Actor + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 2) + }) +} + +func TestOffsetExpression(t *testing.T) { + + stmt := SELECT(Actor.AllColumns). + FROM(Actor). + ORDER_BY(Actor.ActorID). + OFFSET_e(IntExp( + SELECT(MAX(Store.StoreID)). + FROM(Store), + )).LIMIT(10) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" +FROM dvds.actor +ORDER BY actor.actor_id +LIMIT 10 +OFFSET ( + SELECT MAX(store.store_id) + FROM dvds.store +); +`) + + var dest []model.Actor + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 10) + require.Equal(t, dest[0].ActorID, int32(3)) +} + func TestJoinQueryStruct(t *testing.T) { expectedSQL := ` @@ -1067,7 +1183,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { firstCustomerAsc := customersAsc[0] lastCustomerAsc := customersAsc[len(customersAsc)-1] - customersDesc := []model.Customer{} + var customersDesc []model.Customer err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). ORDER_BY(Customer.FirstName.DESC()). Query(db, &customersDesc) @@ -1080,7 +1196,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { testutils.AssertDeepEqual(t, firstCustomerAsc, lastCustomerDesc) testutils.AssertDeepEqual(t, lastCustomerAsc, firstCustomerDesc) - customersAscDesc := []model.Customer{} + var customersAscDesc []model.Customer err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). ORDER_BY(Customer.FirstName.ASC(), Customer.LastName.DESC()). Query(db, &customersAscDesc) @@ -1103,6 +1219,233 @@ func TestSelectOrderByAscDesc(t *testing.T) { testutils.AssertDeepEqual(t, customerAscDesc327, customersAscDesc[327]) } +func TestOrderBy(t *testing.T) { + + t.Run("default", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate, + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date NULLS FIRST +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.NULLS_LAST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date NULLS LAST +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("ASC", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date ASC +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("ASC NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC().NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date ASC NULLS FIRST +LIMIT 200; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("ASC NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC().NULLS_LAST(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date ASC NULLS LAST +LIMIT 200 +OFFSET 15800; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("DESC", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date DESC +LIMIT 200 +OFFSET 15800; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("DESC NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC().NULLS_LAST(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date DESC NULLS LAST +LIMIT 200 +OFFSET 15800; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("DESC NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC().NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.rental +ORDER BY rental.return_date DESC NULLS FIRST +LIMIT 200; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) +} + func TestSelectFullJoin(t *testing.T) { expectedSQL := ` SELECT customer.customer_id AS "customer.customer_id", @@ -2054,16 +2397,14 @@ OFFSET 20; Payment. SELECT(Payment.PaymentID, Payment.Amount). WHERE(Payment.Amount.GT_EQ(Float(200))), - ). - ORDER_BY(IntegerColumn("payment.payment_id").ASC(), Payment.Amount.DESC()). - LIMIT(10). - OFFSET(20) - - //fmt.Println(query.DebugSql()) + ).ORDER_BY( + IntegerColumn("payment.payment_id").ASC(), + Payment.Amount.DESC(), + ).LIMIT(10).OFFSET(20) testutils.AssertDebugStatementSql(t, query, expectedQuery, float64(100), float64(200), int64(10), int64(20)) - dest := []model.Payment{} + var dest []model.Payment err := query.Query(db, &dest) @@ -2083,6 +2424,56 @@ OFFSET 20; }) } +func TestUnionOffsetWithExpression(t *testing.T) { + stmt := UNION( + SELECT(Rental.AllColumns). + FROM(Rental). + WHERE(Rental.ReturnDate.IS_NULL()), + + SELECT(Rental.AllColumns). + FROM(Rental). + WHERE(Rental.LastUpdate.GT(LOCALTIMESTAMP())), + ).OFFSET_e(IntExp( + SELECT(Int32(3)), + )).LIMIT(10) + + testutils.AssertStatementSql(t, stmt, ` +( + SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" + FROM dvds.rental + WHERE rental.return_date IS NULL +) +UNION +( + SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" + FROM dvds.rental + WHERE rental.last_update > LOCALTIMESTAMP +) +LIMIT $1 +OFFSET ( + SELECT $2::integer +); +`) + + var dest []model.Rental + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 10) +} + func TestAllSetOperators(t *testing.T) { var select1 = Payment.SELECT(Payment.AllColumns).WHERE(Payment.PaymentID.GT_EQ(Int(17600)).AND(Payment.PaymentID.LT(Int(17610)))) var select2 = Payment.SELECT(Payment.AllColumns).WHERE(Payment.PaymentID.GT_EQ(Int(17620)).AND(Payment.PaymentID.LT(Int(17630)))) @@ -2251,6 +2642,83 @@ FOR` } } +func TestRowLockWithUpdateOf(t *testing.T) { + stmt := SELECT( + Film.FilmID, + Film.Title, + Actor.ActorID, + Actor.FirstName, + ).FROM( + Film. + INNER_JOIN(FilmCategory, FilmCategory.FilmID.EQ(Film.FilmID)). + INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(Film.FilmID)). + INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID)), + ).LIMIT( + 1, + ).FOR( + UPDATE().OF(Film, Actor).NOWAIT(), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name" +FROM dvds.film + INNER JOIN dvds.film_category ON (film_category.film_id = film.film_id) + INNER JOIN dvds.film_actor ON (film_actor.film_id = film.film_id) + INNER JOIN dvds.actor ON (actor.actor_id = film_actor.actor_id) +LIMIT 1 +FOR UPDATE OF film, actor NOWAIT; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []struct { + model.Film + CategoryID int + Actor []model.Actor + } + + err := stmt.Query(tx, &dest) + require.NoError(t, err) + require.Len(t, dest, 1) + }) +} + +func TestRowLockWithUpdateOfAliasedTable(t *testing.T) { + + myFilm := Film.AS("myFilm") + + stmt := SELECT( + myFilm.FilmID, + myFilm.Title, + ).FROM( + myFilm, + ).LIMIT( + 1, + ).FOR( + UPDATE().OF(myFilm), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT "myFilm".film_id AS "myFilm.film_id", + "myFilm".title AS "myFilm.title" +FROM dvds.film AS "myFilm" +LIMIT 1 +FOR UPDATE OF "myFilm"; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []struct { + model.Film `alias:"myFilm.*"` + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 1) + }) +} + func TestQuickStart(t *testing.T) { var expectedSQL = ` diff --git a/tests/sqlite/sample_test.go b/tests/sqlite/sample_test.go new file mode 100644 index 0000000..4775eeb --- /dev/null +++ b/tests/sqlite/sample_test.go @@ -0,0 +1,82 @@ +package sqlite + +import ( + "database/sql" + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/stretchr/testify/require" + "testing" + + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table" +) + +func TestMutableColumnsExcludeGeneratedColumn(t *testing.T) { + + t.Run("should not have the generated column in mutableColumns", func(t *testing.T) { + require.Equal(t, 2, len(People.MutableColumns)) + require.Equal(t, People.PeopleName, People.MutableColumns[0]) + require.Equal(t, People.PeopleHeightCm, People.MutableColumns[1]) + }) + + t.Run("should query with all columns", func(t *testing.T) { + query := SELECT( + People.AllColumns, + ).FROM( + People, + ).WHERE( + People.PeopleID.EQ(Int(3)), + ) + + testutils.AssertStatementSql(t, query, ` +SELECT people.people_id AS "people.people_id", + people.people_name AS "people.people_name", + people.people_height_cm AS "people.people_height_cm", + people.people_height_inch AS "people.people_height_inch", + people.people_height_feet AS "people.people_height_feet" +FROM people +WHERE people.people_id = ?; +`) + var result model.People + + err := query.Query(sampleDB, &result) + require.NoError(t, err) + + require.Equal(t, "Carla", result.PeopleName) + require.Equal(t, 155., *result.PeopleHeightCm) + require.InEpsilon(t, 61.02, *result.PeopleHeightInch, 1e-3) + }) + + t.Run("should insert without generated columns", func(t *testing.T) { + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + insertQuery := People.INSERT( + People.MutableColumns, + ).MODEL( + model.People{ + PeopleName: "Dario", + PeopleHeightCm: testutils.Float64Ptr(190), + }, + ).RETURNING( + People.AllColumns, + ) + + testutils.AssertDebugStatementSql(t, insertQuery, ` +INSERT INTO people (people_name, people_height_cm) +VALUES ('Dario', 190) +RETURNING people.people_id AS "people.people_id", + people.people_name AS "people.people_name", + people.people_height_cm AS "people.people_height_cm", + people.people_height_inch AS "people.people_height_inch", + people.people_height_feet AS "people.people_height_feet"; +`) + var result model.People + err := insertQuery.Query(tx, &result) + require.NoError(t, err) + + require.Equal(t, "Dario", result.PeopleName) + require.Equal(t, 190., *result.PeopleHeightCm) + require.InEpsilon(t, float32(74.80314), *result.PeopleHeightInch, 1e-3) + require.InEpsilon(t, float32(6.233595), *result.PeopleHeightFeet, 1e-3) + }) + }) +} diff --git a/tests/sqlite/select_test.go b/tests/sqlite/select_test.go index 838cb69..63d43a9 100644 --- a/tests/sqlite/select_test.go +++ b/tests/sqlite/select_test.go @@ -145,6 +145,233 @@ ORDER BY payment.customer_id, SUM(payment.amount) ASC; requireLogged(t, query) } +func TestOrderBy(t *testing.T) { + + t.Run("default", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate, + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date NULLS FIRST +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.NULLS_LAST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date NULLS LAST +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("ASC", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date ASC +LIMIT 200; +`) + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("ASC NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC().NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date ASC NULLS FIRST +LIMIT 200; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("ASC NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.ASC().NULLS_LAST(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date ASC NULLS LAST +LIMIT 200 +OFFSET 15800; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("DESC", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date DESC +LIMIT 200 +OFFSET 15800; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("DESC NULLS LAST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC().NULLS_LAST(), + ).LIMIT(200).OFFSET(15800) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date DESC NULLS LAST +LIMIT 200 +OFFSET 15800; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) + + t.Run("DESC NULLS FIRST", func(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).ORDER_BY( + Rental.ReturnDate.DESC().NULLS_FIRST(), + ).LIMIT(200) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM rental +ORDER BY rental.return_date DESC NULLS FIRST +LIMIT 200; +`) + + require.NoError(t, stmt.Query(db, &struct{}{})) + }) +} + func TestAggregateFunctionDistinct(t *testing.T) { stmt := SELECT( Payment.CustomerID, diff --git a/tests/sqlite/update_test.go b/tests/sqlite/update_test.go index 9ec1acf..110c659 100644 --- a/tests/sqlite/update_test.go +++ b/tests/sqlite/update_test.go @@ -183,8 +183,8 @@ RETURNING link.id AS "link.id", BinaryOperator: 31, CastOperator: "20", LikeOperator: false, - // IsNull: true, //TODO: uncomment when sqlite driver updates to sqlite version > 3.40.1 - CaseOperator: "unknown", + IsNull: true, + CaseOperator: "unknown", }) requireLogged(t, stmt) } diff --git a/tests/testdata b/tests/testdata index 3398b97..915bdc1 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 3398b9735b9d097d2ee0c282976726affc6b96f0 +Subproject commit 915bdc16b723d89becc577c780949baef861a6ae