diff --git a/.circleci/config.yml b/.circleci/config.yml index e21f00a..74f2bd8 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -8,7 +8,7 @@ jobs: build_and_tests: docker: # specify the version - - image: cimg/go:1.21.6 + - image: cimg/go:1.22.8 - image: cimg/postgres:14.10 environment: POSTGRES_USER: jet @@ -128,17 +128,13 @@ jobs: # to create test results report - run: - name: Install go-junit-report - command: go install github.com/jstemmer/go-junit-report@latest - + name: Install gotestsum + command: go install gotest.tools/gotestsum@latest - run: mkdir -p $TEST_RESULTS - # this will run all tests and exclude test files from code coverage report - - run: | - go test -v ./... \ - -covermode=atomic \ - -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... \ - -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml + - run: + name: Running tests + command: gotestsum --junitfile $TEST_RESULTS/report.xml --format testname -- -coverprofile=cover.out -covermode=atomic -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... ./... # run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests - run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/ @@ -163,4 +159,4 @@ workflows: version: 2 build_and_test: jobs: - - build_and_tests \ No newline at end of file + - build_and_tests diff --git a/.github/workflows/code_scanner.yml b/.github/workflows/code_scanner.yml new file mode 100644 index 0000000..d4ecc5d --- /dev/null +++ b/.github/workflows/code_scanner.yml @@ -0,0 +1,45 @@ +name: Code Scanners +on: + push: + branches: + - master + pull_request: + branches: + - master + +permissions: + contents: read + +env: + go_version: "1.22.8" + + +jobs: + security_scanning: + runs-on: ubuntu-latest + steps: + - name: Checkout Source + uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: ${{ env.go_version }} + cache: true + - name: Setup Tools + run: | + go install github.com/securego/gosec/v2/cmd/gosec@latest + - name: Running Scan + run: gosec --exclude=G402,G304 ./... + lint_scanner: + runs-on: ubuntu-latest + steps: + - name: Checkout Source + uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: ${{ env.go_version }} + cache: true + - name: Setup Tools + run: | + go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + - name: Running Scan + run: golangci-lint run --timeout=30m ./... diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..4eabe05 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,19 @@ +run: + # The default concurrency value is the number of available CPU. + concurrency: 4 + # Timeout for analysis, e.g. 30s, 5m. + # Default: 1m + timeout: 30m + # Exit code when at least one issue was found. + # Default: 1 + issues-exit-code: 2 + # Include test files or not. + # Default: true + tests: false + +issues: + exclude-dirs: + - tests + exclude-files: + - "_test.go" + - "testutils.go" diff --git a/cmd/jet/version.go b/cmd/jet/version.go index 4241ec7..3300008 100644 --- a/cmd/jet/version.go +++ b/cmd/jet/version.go @@ -1,3 +1,3 @@ package main -const version = "v2.11.0" +const version = "v2.11.1" diff --git a/examples/quick-start/quick-start.go b/examples/quick-start/quick-start.go index e453f51..dc84989 100644 --- a/examples/quick-start/quick-start.go +++ b/examples/quick-start/quick-start.go @@ -4,7 +4,7 @@ import ( "database/sql" "encoding/json" "fmt" - "io/ioutil" + "os" _ "github.com/lib/pq" @@ -90,7 +90,7 @@ func main() { func jsonSave(path string, v interface{}) { jsonText, _ := json.MarshalIndent(v, "", "\t") - err := ioutil.WriteFile(path, jsonText, 0644) + err := os.WriteFile(path, jsonText, 0600) panicOnError(err) } diff --git a/generator/metadata/column_meta_data.go b/generator/metadata/column_meta_data.go index 5c888a3..ecd61e2 100644 --- a/generator/metadata/column_meta_data.go +++ b/generator/metadata/column_meta_data.go @@ -1,29 +1,16 @@ package metadata -import ( - "regexp" -) - // Column struct type Column struct { Name string `sql:"primary_key"` IsPrimaryKey bool IsNullable bool IsGenerated bool + HasDefault bool DataType DataType Comment string } -// GoLangComment returns column comment without ascii control characters -func (c Column) GoLangComment() string { - if c.Comment == "" { - return "" - } - - // remove ascii control characters from string - return regexp.MustCompile(`[[:cntrl:]]+`).ReplaceAllString(c.Comment, "") -} - // DataTypeKind is database type kind(base, enum, user-defined, array) type DataTypeKind string diff --git a/generator/metadata/enum_meta_data.go b/generator/metadata/enum_meta_data.go index 7aea3d6..8cce596 100644 --- a/generator/metadata/enum_meta_data.go +++ b/generator/metadata/enum_meta_data.go @@ -2,6 +2,7 @@ package metadata // Enum metadata struct type Enum struct { - Name string `sql:"primary_key"` - Values []string + Name string `sql:"primary_key"` + Comment string + Values []string } diff --git a/generator/metadata/table_meta_data.go b/generator/metadata/table_meta_data.go index df9514e..1d56bc0 100644 --- a/generator/metadata/table_meta_data.go +++ b/generator/metadata/table_meta_data.go @@ -3,6 +3,7 @@ package metadata // Table metadata struct type Table struct { Name string `sql:"primary_key"` + Comment string Columns []Column } diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index bd4d593..6dc2714 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -18,6 +18,7 @@ func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTyp SELECT t.table_name as "table.name", col.COLUMN_NAME AS "column.Name", + col.COLUMN_DEFAULT IS NOT NULL as "column.HasDefault", col.IS_NULLABLE = "YES" AS "column.IsNullable", col.COLUMN_COMMENT AS "column.Comment", COALESCE(pk.IsPrimaryKey, 0) AS "column.IsPrimaryKey", diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index 2d8835b..fc4135a 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -14,7 +14,7 @@ type postgresQuerySet struct{} func (p postgresQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) ([]metadata.Table, error) { query := ` -SELECT table_name as "table.name" +SELECT table_name as "table.name", obj_description((quote_ident(table_schema)||'.'||quote_ident(table_name))::regclass) as "table.comment" FROM information_schema.tables WHERE table_schema = $1 and table_type = $2 ORDER BY table_name; @@ -57,6 +57,7 @@ func getColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]meta query := ` select attr.attname as "column.Name", + col_description(attr.attrelid, attr.attnum) as "column.Comment", exists( select 1 from pg_catalog.pg_index indx @@ -64,11 +65,13 @@ select ) as "column.IsPrimaryKey", not attr.attnotnull as "column.isNullable", attr.attgenerated = 's' as "column.isGenerated", - (case tp.typtype - when 'b' then 'base' - when 'd' then 'base' - when 'e' then 'enum' - when 'r' then 'range' + attr.atthasdef as "column.hasDefault", + (case + when tp.typtype = 'b' AND tp.typcategory <> 'A' then 'base' + when tp.typtype = 'b' AND tp.typcategory = 'A' then 'array' + when tp.typtype = 'd' then 'base' + when tp.typtype = 'e' then 'enum' + when tp.typtype = 'r' then 'range' end) as "dataType.Kind", (case when tp.typtype = 'd' then (select pg_type.typname from pg_catalog.pg_type where pg_type.oid = tp.typbasetype) when tp.typcategory = 'A' then pg_catalog.format_type(attr.atttypid, attr.atttypmod) @@ -99,6 +102,7 @@ order by func (p postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.Enum, error) { query := ` SELECT t.typname as "enum.name", + obj_description(t.oid) as "enum.comment", e.enumlabel as "values" FROM pg_catalog.pg_type t JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid diff --git a/generator/sqlite/query_set.go b/generator/sqlite/query_set.go index 745aae4..dbf9d15 100644 --- a/generator/sqlite/query_set.go +++ b/generator/sqlite/query_set.go @@ -4,10 +4,11 @@ import ( "context" "database/sql" "fmt" + "strings" + "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/internal/utils/semantic" "github.com/go-jet/jet/v2/qrm" - "strings" ) // sqliteQuerySet is dialect query set for SQLite @@ -74,11 +75,12 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t } var columnInfos []struct { - Name string - Type string - NotNull int32 - Pk int32 - Hidden int32 + Name string + Type string + NotNull int32 + DfltValue string + Pk int32 + Hidden int32 } _, err = qrm.Query(context.Background(), db, tableInfoQuery, []interface{}{tableName}, &columnInfos) @@ -91,12 +93,14 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t for _, columnInfo := range columnInfos { columnType := strings.TrimSuffix(getColumnType(columnInfo.Type), " GENERATED ALWAYS") isGenerated := columnInfo.Hidden == 2 || columnInfo.Hidden == 3 // stored or virtual column + hasDefault := columnInfo.DfltValue != "" columns = append(columns, metadata.Column{ Name: columnInfo.Name, IsPrimaryKey: columnInfo.Pk != 0, IsNullable: columnInfo.NotNull != 1, IsGenerated: isGenerated, + HasDefault: hasDefault, DataType: metadata.DataType{ Name: columnType, Kind: metadata.BaseType, diff --git a/generator/template/file_templates.go b/generator/template/file_templates.go index 731b1af..1538031 100644 --- a/generator/template/file_templates.go +++ b/generator/template/file_templates.go @@ -26,13 +26,14 @@ import ( var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "{{tableTemplate.DefaultAlias}}") +{{golangComment .Comment}} type {{structImplName}} struct { {{dialect.PackageName}}.Table // Columns {{- range $i, $c := .Columns}} {{- $field := columnField $c}} - {{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}} {{- if $c.Comment }} // {{$c.GoLangComment}} {{end}} + {{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}} {{golangComment .Comment}} {{- end}} AllColumns {{dialect.PackageName}}.ColumnList @@ -119,10 +120,11 @@ import ( {{end}} {{$modelTableTemplate := tableTemplate}} +{{golangComment .Comment}} type {{$modelTableTemplate.TypeName}} struct { {{- range .Columns}} {{- $field := structField .}} - {{$field.Name}} {{$field.Type.Name}} ` + "{{$field.TagsString}}" + ` {{- if .Comment }} // {{.GoLangComment}} {{end}} + {{$field.Name}} {{$field.Type.Name}} ` + "{{$field.TagsString}}" + ` {{golangComment .Comment}} {{- end}} } @@ -132,6 +134,7 @@ var enumSQLBuilderTemplate = `package {{package}} import "github.com/go-jet/jet/v2/{{dialect.PackageName}}" +{{golangComment .Comment}} var {{enumTemplate.InstanceName}} = &struct { {{- range $index, $value := .Values}} {{enumValueName $value}} {{dialect.PackageName}}.StringExpression @@ -148,6 +151,7 @@ var enumModelTemplate = `package {{package}} import "errors" +{{golangComment .Comment}} type {{$enumTemplate.TypeName}} string const ( @@ -156,6 +160,12 @@ const ( {{- end}} ) +var {{$enumTemplate.TypeName}}AllValues = []{{$enumTemplate.TypeName}} { +{{- range $_, $value := .Values}} + {{valueName $value}}, +{{- end}} +} + func (e *{{$enumTemplate.TypeName}}) Scan(value interface{}) error { var enumValue string switch val := value.(type) { diff --git a/generator/template/format.go b/generator/template/format.go new file mode 100644 index 0000000..bd88fb1 --- /dev/null +++ b/generator/template/format.go @@ -0,0 +1,13 @@ +package template + +import "regexp" + +// Returns the provided string as golang comment without ascii control characters +func formatGolangComment(comment string) string { + if len(comment) == 0 { + return "" + } + + // Format as colang comment and remove ascii control characters from string + return "// " + regexp.MustCompile(`[[:cntrl:]]+`).ReplaceAllString(comment, "") +} diff --git a/generator/template/format_test.go b/generator/template/format_test.go new file mode 100644 index 0000000..b43b61d --- /dev/null +++ b/generator/template/format_test.go @@ -0,0 +1,27 @@ +package template + +import "testing" + +func Test_formatGolangComment(t *testing.T) { + type args struct { + comment string + } + tests := []struct { + name string + args args + want string + }{ + {name: "Empty string", args: args{comment: ""}, want: ""}, + {name: "Non-empty string", args: args{comment: "This is a comment"}, want: "// This is a comment"}, + {name: "String with control characters", args: args{comment: "This is a comment with control characters \x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f and text after"}, want: "// This is a comment with control characters and text after"}, + {name: "String with escape characters", args: args{comment: "This is a comment with escape characters \n\r\t and text after"}, want: "// This is a comment with escape characters and text after"}, + {name: "String with unicode characters", args: args{comment: "This is a comment with unicode characters ₲鬼佬℧⇄↻ and text after"}, want: "// This is a comment with unicode characters ₲鬼佬℧⇄↻ and text after"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := formatGolangComment(tt.args.comment); got != tt.want { + t.Errorf("formatGoLangComment() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/generator/template/process.go b/generator/template/process.go index 3f3798a..5abef87 100644 --- a/generator/template/process.go +++ b/generator/template/process.go @@ -140,6 +140,7 @@ func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData [] "enumValueName": func(enumValue string) string { return enumTemplate.ValueName(enumValue) }, + "golangComment": formatGolangComment, }) if err != nil { return fmt.Errorf("failed to generete enum type %s: %w", enumTemplate.FileName, err) @@ -215,6 +216,7 @@ func processTableSQLBuilder(fileTypes, dirPath string, "insertedRowAlias": func() string { return insertedRowAlias(dialect) }, + "golangComment": formatGolangComment, }) if err != nil { return fmt.Errorf("failed to generate table sql builder type %s: %w", tableSQLBuilder.TypeName, err) @@ -307,6 +309,7 @@ func processTableModels(fileTypes, modelDirPath string, tablesMetaData []metadat "structField": func(columnMetaData metadata.Column) TableModelField { return tableTemplate.Field(columnMetaData) }, + "golangComment": formatGolangComment, }) if err != nil { return fmt.Errorf("failed to generate model type '%s': %w", tableMetaData.Name, err) @@ -347,6 +350,7 @@ func processEnumModels(modelDir string, enumsMetaData []metadata.Enum, modelTemp "valueName": func(value string) string { return enumTemplate.ValueName(value) }, + "golangComment": formatGolangComment, }) if err != nil { diff --git a/go.mod b/go.mod index cb204af..890dc0d 100644 --- a/go.mod +++ b/go.mod @@ -4,16 +4,16 @@ go 1.18 require ( github.com/go-sql-driver/mysql v1.8.1 + github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/jackc/pgconn v1.14.3 + github.com/jackc/pgtype v1.14.3 + github.com/jackc/pgx/v4 v4.18.3 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.17 ) require ( - github.com/google/go-cmp v0.6.0 - github.com/jackc/pgtype v1.14.3 - github.com/jackc/pgx/v4 v4.18.3 github.com/pkg/profile v1.7.0 github.com/shopspring/decimal v1.4.0 github.com/stretchr/testify v1.9.0 diff --git a/internal/jet/column_assigment.go b/internal/jet/column_assigment.go index 440f3eb..c888433 100644 --- a/internal/jet/column_assigment.go +++ b/internal/jet/column_assigment.go @@ -11,6 +11,13 @@ type columnAssigmentImpl struct { expression Expression } +func NewColumnAssignment(serializer ColumnSerializer, expression Expression) ColumnAssigment { + return &columnAssigmentImpl{ + column: serializer, + expression: expression, + } +} + func (a columnAssigmentImpl) isColumnAssigment() {} func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index f3ad2b4..68c4c02 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -13,6 +13,7 @@ type Dialect interface { ArgumentPlaceholder() QueryPlaceholderFunc IsReservedWord(name string) bool SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc + ValuesDefaultColumnName(index int) string } // SerializerFunc func @@ -35,6 +36,7 @@ type DialectParams struct { ArgumentPlaceholder QueryPlaceholderFunc ReservedWords []string SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc + ValuesDefaultColumnName func(index int) string } // NewDialect creates new dialect with params @@ -49,6 +51,7 @@ func NewDialect(params DialectParams) Dialect { argumentPlaceholder: params.ArgumentPlaceholder, reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords), serializeOrderBy: params.SerializeOrderBy, + valuesDefaultColumnName: params.ValuesDefaultColumnName, } } @@ -62,6 +65,7 @@ type dialectImpl struct { argumentPlaceholder QueryPlaceholderFunc reservedWords map[string]bool serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc + valuesDefaultColumnName func(index int) string } func (d *dialectImpl) Name() string { @@ -107,6 +111,10 @@ func (d *dialectImpl) SerializeOrderBy() func(expression Expression, ascending, return d.serializeOrderBy } +func (d *dialectImpl) ValuesDefaultColumnName(index int) string { + return d.valuesDefaultColumnName(index) +} + func arrayOfStringsToMapOfStrings(arr []string) map[string]bool { ret := map[string]bool{} for _, elem := range arr { diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 05b1797..9999803 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -51,12 +51,12 @@ func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression { // IN checks if this expressions matches any in expressions list func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression { - return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "IN") + return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "IN") } // NOT_IN checks if this expressions is different of all expressions in expressions list func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression { - return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "NOT IN") + return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "NOT IN") } // AS the temporary alias name to assign to the expression @@ -316,15 +316,6 @@ func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder, } } -type skipParenthesisWrap struct { - Expression -} - -func skipWrap(expression Expression) Expression { - return &skipParenthesisWrap{expression} -} - -// since the expression is a function parameter, there is no need to wrap it in parentheses -func (s *skipParenthesisWrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - s.Expression.serialize(statement, out, append(options, NoWrap)...) +func wrap(expressions ...Expression) Expression { + return NewFunc("", expressions, nil) } diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 7e49880..ddc579e 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -12,11 +12,6 @@ func OR(expressions ...BoolExpression) BoolExpression { return newBoolExpressionListOperator("OR", expressions...) } -// ROW function is used to create a tuple value that consists of a set of expressions or column values. -func ROW(expressions ...Expression) Expression { - return NewFunc("ROW", expressions, nil) -} - // ------------------ Mathematical functions ---------------// // ABSf calculates absolute value from float expression @@ -711,7 +706,7 @@ func (p parametersSerializer) serialize(statement StatementType, out *SQLBuilder if _, isStatement := expression.(Statement); isStatement { expression.serialize(statement, out, options...) } else { - skipWrap(expression).serialize(statement, out, options...) + expression.serialize(statement, out, append(options, NoWrap, Ident)...) } } } diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index d6f0b41..251d3ab 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -374,32 +374,6 @@ func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, option //---------------------------------------------------// -type wrap struct { - ExpressionInterfaceImpl - expressions []Expression -} - -func (n *wrap) serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString("(") - - if len(n.expressions) == 1 { - options = append(options, NoWrap, Ident) - } - serializeExpressionList(statementType, n.expressions, ", ", out, options...) - - out.WriteString(")") -} - -// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... ) -func WRAP(expression ...Expression) Expression { - wrap := &wrap{expressions: expression} - wrap.ExpressionInterfaceImpl.Parent = wrap - - return wrap -} - -//---------------------------------------------------// - type rawExpression struct { ExpressionInterfaceImpl diff --git a/internal/jet/order_set_aggregate_functions.go b/internal/jet/order_set_aggregate_functions.go index 8ce5d1e..c853845 100644 --- a/internal/jet/order_set_aggregate_functions.go +++ b/internal/jet/order_set_aggregate_functions.go @@ -54,7 +54,12 @@ type orderSetAggregateFuncExpression struct { func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString(p.name) - WRAP(p.fraction).serialize(statement, out, FallTrough(options)...) + + if p.fraction != nil { + wrap(p.fraction).serialize(statement, out, FallTrough(options)...) + } else { + wrap().serialize(statement, out, FallTrough(options)...) + } out.WriteString("WITHIN GROUP") p.orderBy.serialize(statement, out) } diff --git a/internal/jet/projection_test.go b/internal/jet/projection_test.go index 0370b43..61dd920 100644 --- a/internal/jet/projection_test.go +++ b/internal/jet/projection_test.go @@ -39,7 +39,7 @@ AVG(table1.col_int) AS "avg", table2.col3 AS "col3", table2.col4 AS "col4"`) - subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery")) + subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery", nil)) assertProjectionSerialize(t, subQueryProjections, `"subQuery"."table1.col3" AS "table1.col3", diff --git a/internal/jet/raw_statement.go b/internal/jet/raw_statement.go index 191c7b4..99fb8eb 100644 --- a/internal/jet/raw_statement.go +++ b/internal/jet/raw_statement.go @@ -8,7 +8,7 @@ type rawStatementImpl struct { } // RawStatement creates new sql statements from raw query and optional map of named arguments -func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) Statement { +func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) SerializerStatement { newRawStatement := rawStatementImpl{ serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ dialect: dialect, diff --git a/internal/jet/row_expression.go b/internal/jet/row_expression.go new file mode 100644 index 0000000..e7d5ed5 --- /dev/null +++ b/internal/jet/row_expression.go @@ -0,0 +1,102 @@ +package jet + +// RowExpression interface +type RowExpression interface { + Expression + HasProjections + + EQ(rhs RowExpression) BoolExpression + NOT_EQ(rhs RowExpression) BoolExpression + IS_DISTINCT_FROM(rhs RowExpression) BoolExpression + IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression + + LT(rhs RowExpression) BoolExpression + LT_EQ(rhs RowExpression) BoolExpression + GT(rhs RowExpression) BoolExpression + GT_EQ(rhs RowExpression) BoolExpression +} + +type rowInterfaceImpl struct { + parent Expression + dialect Dialect + elemCount int +} + +func (n *rowInterfaceImpl) EQ(rhs RowExpression) BoolExpression { + return Eq(n.parent, rhs) +} + +func (n *rowInterfaceImpl) NOT_EQ(rhs RowExpression) BoolExpression { + return NotEq(n.parent, rhs) +} + +func (n *rowInterfaceImpl) IS_DISTINCT_FROM(rhs RowExpression) BoolExpression { + return IsDistinctFrom(n.parent, rhs) +} + +func (n *rowInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression { + return IsNotDistinctFrom(n.parent, rhs) +} + +func (n *rowInterfaceImpl) GT(rhs RowExpression) BoolExpression { + return Gt(n.parent, rhs) +} + +func (n *rowInterfaceImpl) GT_EQ(rhs RowExpression) BoolExpression { + return GtEq(n.parent, rhs) +} + +func (n *rowInterfaceImpl) LT(rhs RowExpression) BoolExpression { + return Lt(n.parent, rhs) +} + +func (n *rowInterfaceImpl) LT_EQ(rhs RowExpression) BoolExpression { + return LtEq(n.parent, rhs) +} + +func (n *rowInterfaceImpl) projections() ProjectionList { + var ret ProjectionList + + for i := 0; i < n.elemCount; i++ { + rowColumn := NewColumnImpl(n.dialect.ValuesDefaultColumnName(i), "", nil) + ret = append(ret, &rowColumn) + } + + return ret +} + +// ---------------------------------------------------// +type rowExpressionWrapper struct { + rowInterfaceImpl + Expression +} + +func newRowExpression(name string, dialect Dialect, expressions ...Expression) RowExpression { + ret := &rowExpressionWrapper{} + ret.rowInterfaceImpl.parent = ret + + ret.Expression = NewFunc(name, expressions, ret) + ret.dialect = dialect + ret.elemCount = len(expressions) + + return ret +} + +// ROW function is used to create a tuple value that consists of a set of expressions or column values. +func ROW(dialect Dialect, expressions ...Expression) RowExpression { + return newRowExpression("ROW", dialect, expressions...) +} + +// WRAP creates row expressions without ROW keyword `( expression1, expression2, ... )`. +func WRAP(dialect Dialect, expressions ...Expression) RowExpression { + return newRowExpression("", dialect, expressions...) +} + +// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. +// This enables the Go compiler to interpret any expression as a row expression +// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. +func RowExp(expression Expression) RowExpression { + rowExpressionWrap := rowExpressionWrapper{Expression: expression} + rowExpressionWrap.rowInterfaceImpl.parent = &rowExpressionWrap + return &rowExpressionWrap +} diff --git a/internal/jet/select_table.go b/internal/jet/select_table.go index c25fba3..f1f58ac 100644 --- a/internal/jet/select_table.go +++ b/internal/jet/select_table.go @@ -8,15 +8,21 @@ type SelectTable interface { } type selectTableImpl struct { - Statement SerializerHasProjections - alias string + Statement SerializerHasProjections + alias string + columnAliases []ColumnExpression } // NewSelectTable func -func NewSelectTable(selectStmt SerializerHasProjections, alias string) selectTableImpl { +func NewSelectTable(selectStmt SerializerHasProjections, alias string, columnAliases []ColumnExpression) selectTableImpl { selectTable := selectTableImpl{ - Statement: selectStmt, - alias: alias, + Statement: selectStmt, + alias: alias, + columnAliases: columnAliases, + } + + for _, column := range selectTable.columnAliases { + column.setSubQuery(selectTable) } return selectTable @@ -31,6 +37,10 @@ func (s selectTableImpl) Alias() string { } func (s selectTableImpl) AllColumns() ProjectionList { + if len(s.columnAliases) > 0 { + return ColumnListToProjectionList(s.columnAliases) + } + projectionList := s.projections().fromImpl(s) return projectionList.(ProjectionList) } @@ -40,6 +50,12 @@ func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, opt out.WriteString("AS") out.WriteIdentifier(s.alias) + + if len(s.columnAliases) > 0 { + out.WriteByte('(') + SerializeColumnExpressionNames(s.columnAliases, out) + out.WriteByte(')') + } } // -------------------------------------- @@ -50,7 +66,7 @@ type lateralImpl struct { // NewLateral creates new lateral expression from select statement with alias func NewLateral(selectStmt SerializerStatement, alias string) SelectTable { - return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias)} + return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias, nil)} } func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 6703154..2ca229d 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -180,7 +180,7 @@ func duration(f func()) time.Duration { f() - return time.Now().Sub(start) + return time.Since(start) } // ExpressionStatement interfacess diff --git a/internal/jet/values.go b/internal/jet/values.go new file mode 100644 index 0000000..9d0f691 --- /dev/null +++ b/internal/jet/values.go @@ -0,0 +1,35 @@ +package jet + +// Values hold a set of one or more rows +type Values []RowExpression + +func (v Values) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteByte('(') + out.IncreaseIdent(5) + + out.NewLine() + out.WriteString("VALUES") + + for rowIndex, row := range v { + if rowIndex > 0 { + out.WriteString(",") + out.NewLine() + } else { + out.IncreaseIdent(7) + } + + row.serialize(statement, out, options...) + } + out.DecreaseIdent(7) + out.DecreaseIdent(5) + out.NewLine() + out.WriteByte(')') +} + +func (v Values) projections() ProjectionList { + if len(v) == 0 { + return nil + } + + return v[0].projections() +} diff --git a/internal/jet/with_statement.go b/internal/jet/with_statement.go index 783fa27..03330e6 100644 --- a/internal/jet/with_statement.go +++ b/internal/jet/with_statement.go @@ -64,7 +64,7 @@ type CommonTableExpression struct { // CTE creates new named CommonTableExpression func CTE(name string, columns ...ColumnExpression) CommonTableExpression { cte := CommonTableExpression{ - selectTableImpl: NewSelectTable(nil, name), + selectTableImpl: NewSelectTable(nil, name, columns), Columns: columns, } @@ -99,12 +99,3 @@ func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilde out.WriteIdentifier(c.alias) } } - -// AllColumns returns list of all projections in the CTE -func (c CommonTableExpression) AllColumns() ProjectionList { - if len(c.Columns) > 0 { - return ColumnListToProjectionList(c.Columns) - } - - return c.selectTableImpl.AllColumns() -} diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 3e03506..a4a06ee 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -9,17 +9,15 @@ import ( jet2 "github.com/go-jet/jet/v2/internal/jet/db" "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/qrm" + "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "io/ioutil" "os" "path/filepath" "runtime" "testing" "time" - - "github.com/google/go-cmp/cmp" ) // UnixTimeComparer will compare time equality while ignoring time zone @@ -105,11 +103,12 @@ func AssertJSON(t *testing.T, data interface{}, expectedJSON string) { } // SaveJSONFile saves v as json at testRelativePath +// nolint:unused func SaveJSONFile(v interface{}, testRelativePath string) { jsonText, _ := json.MarshalIndent(v, "", "\t") filePath := getFullPath(testRelativePath) - err := ioutil.WriteFile(filePath, jsonText, 0644) + err := os.WriteFile(filePath, jsonText, 0600) throw.OnError(err) } @@ -118,7 +117,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) { func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) { filePath := getFullPath(testRelativePath) - fileJSONData, err := ioutil.ReadFile(filePath) + fileJSONData, err := os.ReadFile(filePath) // #nosec G304 require.NoError(t, err) if runtime.GOOS == "windows" { @@ -245,7 +244,7 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest inter // AssertFileContent check if file content at filePath contains expectedContent text. func AssertFileContent(t *testing.T, filePath string, expectedContent string) { - enumFileData, err := ioutil.ReadFile(filePath) + enumFileData, err := os.ReadFile(filePath) // #nosec G304 require.NoError(t, err) @@ -254,7 +253,7 @@ func AssertFileContent(t *testing.T, filePath string, expectedContent string) { // AssertFileNamesEqual check if all filesInfos are contained in fileNames func AssertFileNamesEqual(t *testing.T, dirPath string, fileNames ...string) { - files, err := ioutil.ReadDir(dirPath) + files, err := os.ReadDir(dirPath) require.NoError(t, err) require.Equal(t, len(files), len(fileNames)) @@ -293,76 +292,6 @@ func printDiff(actual, expected interface{}, options ...cmp.Option) { fmt.Println(expected) } -// BoolPtr returns address of bool parameter -func BoolPtr(b bool) *bool { - return &b -} - -// Int8Ptr returns address of int8 parameter -func Int8Ptr(i int8) *int8 { - return &i -} - -// UInt8Ptr returns address of uint8 parameter -func UInt8Ptr(i uint8) *uint8 { - return &i -} - -// Int16Ptr returns address of int16 parameter -func Int16Ptr(i int16) *int16 { - return &i -} - -// UInt16Ptr returns address of uint16 parameter -func UInt16Ptr(i uint16) *uint16 { - return &i -} - -// Int32Ptr returns address of int32 parameter -func Int32Ptr(i int32) *int32 { - return &i -} - -// UInt32Ptr returns address of uint32 parameter -func UInt32Ptr(i uint32) *uint32 { - return &i -} - -// Int64Ptr returns address of int64 parameter -func Int64Ptr(i int64) *int64 { - return &i -} - -// UInt64Ptr returns address of uint64 parameter -func UInt64Ptr(i uint64) *uint64 { - return &i -} - -// StringPtr returns address of string parameter -func StringPtr(s string) *string { - return &s -} - -// TimePtr returns address of time.Time parameter -func TimePtr(t time.Time) *time.Time { - return &t -} - -// ByteArrayPtr returns address of []byte parameter -func ByteArrayPtr(arr []byte) *[]byte { - return &arr -} - -// Float32Ptr returns address of float32 parameter -func Float32Ptr(f float32) *float32 { - return &f -} - -// Float64Ptr returns address of float64 parameter -func Float64Ptr(f float64) *float64 { - return &f -} - // UUIDPtr returns address of uuid.UUID func UUIDPtr(u string) *uuid.UUID { newUUID := uuid.MustParse(u) diff --git a/internal/utils/filesys/filesys.go b/internal/utils/filesys/filesys.go index d904be8..0724470 100644 --- a/internal/utils/filesys/filesys.go +++ b/internal/utils/filesys/filesys.go @@ -1,6 +1,7 @@ package filesys import ( + "errors" "fmt" "go/format" "os" @@ -16,7 +17,7 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error { newGoFilePath += ".go" } - file, err := os.Create(newGoFilePath) + file, err := os.Create(newGoFilePath) // #nosec 304 if err != nil { return err @@ -28,7 +29,10 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error { // if there is a format error we will write unformulated text for debug purposes if err != nil { - file.Write(text) + _, writeErr := file.Write(text) + if writeErr != nil { + return errors.Join(writeErr, fmt.Errorf("failed to format '%s', check '%s' for syntax errors: %w", fileName, newGoFilePath, err)) + } return fmt.Errorf("failed to format '%s', check '%s' for syntax errors: %w", fileName, newGoFilePath, err) } @@ -43,7 +47,7 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error { // EnsureDirPathExist ensures dir path exists. If path does not exist, creates new path. func EnsureDirPathExist(dirPath string) error { if _, err := os.Stat(dirPath); os.IsNotExist(err) { - err := os.MkdirAll(dirPath, os.ModePerm) + err := os.MkdirAll(dirPath, 0o750) if err != nil { return fmt.Errorf("can't create directory - %s: %w", dirPath, err) diff --git a/internal/utils/ptr/ptr.go b/internal/utils/ptr/ptr.go new file mode 100644 index 0000000..26f2688 --- /dev/null +++ b/internal/utils/ptr/ptr.go @@ -0,0 +1,6 @@ +package ptr + +// Of returns the address of any given parameter +func Of[T any](value T) *T { + return &value +} diff --git a/mysql/cast.go b/mysql/cast.go index 83a0578..ca647a5 100644 --- a/mysql/cast.go +++ b/mysql/cast.go @@ -6,23 +6,27 @@ import ( ) type cast interface { - // Cast expressions as castType type + // AS casts expressions as castType type AS(castType string) Expression - // Cast expression as char with optional length + // AS_CHAR casts expression as char with optional length AS_CHAR(length ...int) StringExpression - // Cast expression AS date type + // AS_DATE casts expression AS date type AS_DATE() DateExpression - // Cast expression AS numeric type, using precision and optionally scale + // AS_FLOAT casts expressions as float type + AS_FLOAT() FloatExpression + // AS_DOUBLE casts expressions as double type + AS_DOUBLE() FloatExpression + // AS_DECIMAL casts expression AS numeric type AS_DECIMAL() FloatExpression - // Cast expression AS time type + // AS_TIME casts expression AS time type AS_TIME() TimeExpression - // Cast expression as datetime type + // AS_DATETIME casts expression as datetime type AS_DATETIME() DateTimeExpression - // Cast expressions as signed integer type + // AS_SIGNED casts expressions as signed integer type AS_SIGNED() IntegerExpression - // Cast expression as unsigned integer type + // AS_UNSIGNED casts expression as unsigned integer type AS_UNSIGNED() IntegerExpression - // Cast expression as binary type + // AS_BINARY casts expression as binary type AS_BINARY() StringExpression } @@ -73,6 +77,14 @@ func (c *castImpl) AS_DATE() DateExpression { return DateExp(c.AS("DATE")) } +func (c *castImpl) AS_FLOAT() FloatExpression { + return FloatExp(c.AS("FLOAT")) +} + +func (c *castImpl) AS_DOUBLE() FloatExpression { + return FloatExp(c.AS("DOUBLE")) +} + // AS_DECIMAL casts expression AS DECIMAL type func (c *castImpl) AS_DECIMAL() FloatExpression { return FloatExp(c.AS("DECIMAL")) diff --git a/mysql/dialect.go b/mysql/dialect.go index 18d2eec..9628bfb 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -1,6 +1,7 @@ package mysql import ( + "fmt" "github.com/go-jet/jet/v2/internal/jet" ) @@ -28,6 +29,9 @@ func newDialect() jet.Dialect { }, ReservedWords: reservedWords, SerializeOrderBy: serializeOrderBy, + ValuesDefaultColumnName: func(index int) string { + return fmt.Sprintf("column_%d", index) + }, } return jet.NewDialect(mySQLDialectParams) diff --git a/mysql/expressions.go b/mysql/expressions.go index 53b1fa7..4073ef5 100644 --- a/mysql/expressions.go +++ b/mysql/expressions.go @@ -30,6 +30,9 @@ type DateTimeExpression = jet.TimestampExpression // TimestampExpression interface type TimestampExpression = jet.TimestampExpression +// RowExpression interface +type RowExpression = jet.RowExpression + // BoolExp is bool expression wrapper around arbitrary expression. // Allows go compiler to see any expression as bool expression. // Does not add sql cast to generated sql builder output. @@ -70,6 +73,11 @@ var DateTimeExp = jet.TimestampExp // Does not add sql cast to generated sql builder output. var TimestampExp = jet.TimestampExp +// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. +// This enables the Go compiler to interpret any expression as a row expression +// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. +var RowExp = jet.RowExp + // CustomExpression is used to define custom expressions. var CustomExpression = jet.CustomExpression diff --git a/mysql/functions.go b/mysql/functions.go index ca31d18..ceec7ab 100644 --- a/mysql/functions.go +++ b/mysql/functions.go @@ -12,7 +12,9 @@ var ( ) // ROW function is used to create a tuple value that consists of a set of expressions or column values. -var ROW = jet.ROW +func ROW(expressions ...Expression) RowExpression { + return jet.ROW(Dialect, expressions...) +} // ------------------ Mathematical functions ---------------// diff --git a/mysql/interval.go b/mysql/interval.go index ce1c609..23be21f 100644 --- a/mysql/interval.go +++ b/mysql/interval.go @@ -14,25 +14,25 @@ type unitType string // List of interval unit types for MySQL const ( MICROSECOND unitType = "MICROSECOND" - SECOND = "SECOND" - MINUTE = "MINUTE" - HOUR = "HOUR" - DAY = "DAY" - WEEK = "WEEK" - MONTH = "MONTH" - QUARTER = "QUARTER" - YEAR = "YEAR" - SECOND_MICROSECOND = "SECOND_MICROSECOND" - MINUTE_MICROSECOND = "MINUTE_MICROSECOND" - MINUTE_SECOND = "MINUTE_SECOND" - HOUR_MICROSECOND = "HOUR_MICROSECOND" - HOUR_SECOND = "HOUR_SECOND" - HOUR_MINUTE = "HOUR_MINUTE" - DAY_MICROSECOND = "DAY_MICROSECOND" - DAY_SECOND = "DAY_SECOND" - DAY_MINUTE = "DAY_MINUTE" - DAY_HOUR = "DAY_HOUR" - YEAR_MONTH = "YEAR_MONTH" + SECOND unitType = "SECOND" + MINUTE unitType = "MINUTE" + HOUR unitType = "HOUR" + DAY unitType = "DAY" + WEEK unitType = "WEEK" + MONTH unitType = "MONTH" + QUARTER unitType = "QUARTER" + YEAR unitType = "YEAR" + SECOND_MICROSECOND unitType = "SECOND_MICROSECOND" + MINUTE_MICROSECOND unitType = "MINUTE_MICROSECOND" + MINUTE_SECOND unitType = "MINUTE_SECOND" + HOUR_MICROSECOND unitType = "HOUR_MICROSECOND" + HOUR_SECOND unitType = "HOUR_SECOND" + HOUR_MINUTE unitType = "HOUR_MINUTE" + DAY_MICROSECOND unitType = "DAY_MICROSECOND" + DAY_SECOND unitType = "DAY_SECOND" + DAY_MINUTE unitType = "DAY_MINUTE" + DAY_HOUR unitType = "DAY_HOUR" + YEAR_MONTH unitType = "YEAR_MONTH" ) // Interval is representation of MySQL interval diff --git a/mysql/select_statement.go b/mysql/select_statement.go index 6c5f345..aaeff9a 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -172,7 +172,7 @@ func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement { } func (s *selectStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } //----------------------------------------------------- diff --git a/mysql/select_table.go b/mysql/select_table.go index ad22193..8ca06d4 100644 --- a/mysql/select_table.go +++ b/mysql/select_table.go @@ -13,9 +13,9 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { +func newSelectTable(selectStmt jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable { subQuery := &selectTableImpl{ - SelectTable: jet.NewSelectTable(selectStmt, alias), + SelectTable: jet.NewSelectTable(selectStmt, alias, columnAliases), } subQuery.readableTableInterfaceImpl.parent = subQuery diff --git a/mysql/set_statement.go b/mysql/set_statement.go index 2df75a0..7147f00 100644 --- a/mysql/set_statement.go +++ b/mysql/set_statement.go @@ -85,7 +85,7 @@ func (s *setStatementImpl) OFFSET(offset int64) setStatement { } func (s *setStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } const ( diff --git a/mysql/statement.go b/mysql/statement.go index a0bb0a7..008ace5 100644 --- a/mysql/statement.go +++ b/mysql/statement.go @@ -6,7 +6,7 @@ import ( ) // RawStatement creates new sql statements from raw query and optional map of named arguments -func RawStatement(rawQuery string, namedArguments ...RawArgs) Statement { +func RawStatement(rawQuery string, namedArguments ...RawArgs) jet.SerializerStatement { return jet.RawStatement(Dialect, rawQuery, namedArguments...) } diff --git a/mysql/values.go b/mysql/values.go new file mode 100644 index 0000000..b55895b --- /dev/null +++ b/mysql/values.go @@ -0,0 +1,32 @@ +package mysql + +import "github.com/go-jet/jet/v2/internal/jet" + +type values struct { + jet.Values +} + +// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table. +// Each row is defined by the ROW constructor, which takes one or more expressions. +// +// Example usage: +// +// VALUES( +// ROW(Int32(204), Float32(1.21)), +// ROW(Int32(207), Float32(1.02)), +// ) +func VALUES(rows ...RowExpression) values { + return values{Values: jet.Values(rows)} +} + +// AS assigns an alias to the temporary VALUES table, allowing it to be referenced +// within SQL FROM clauses, just like a regular table. +// By default, VALUES columns are named `column1`, `column2`, etc... Default column aliasing can be +// overwritten by passing new list of columns. +// +// Example usage: +// +// VALUES(...).AS("film_values", IntegerColumn("length"), TimestampColumn("update_date")) +func (v values) AS(alias string, columns ...Column) SelectTable { + return newSelectTable(v, alias, columns) +} diff --git a/mysql/with_statement.go b/mysql/with_statement.go index ca608cb..03b4d5b 100644 --- a/mysql/with_statement.go +++ b/mysql/with_statement.go @@ -6,7 +6,7 @@ import "github.com/go-jet/jet/v2/internal/jet" type CommonTableExpression interface { SelectTable - AS(statement jet.SerializerStatement) CommonTableExpression + AS(statement jet.SerializerHasProjections) CommonTableExpression // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. ALIAS(alias string) SelectTable @@ -41,7 +41,7 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { } // AS is used to define a CTE query -func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { +func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression { c.CommonTableExpression.Statement = statement return c } @@ -52,7 +52,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. func (c *commonTableExpression) ALIAS(name string) SelectTable { - return newSelectTable(c, name) + return newSelectTable(c, name, nil) } func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { diff --git a/postgres/columns.go b/postgres/columns.go index 819da38..a70c234 100644 --- a/postgres/columns.go +++ b/postgres/columns.go @@ -109,13 +109,20 @@ type ColumnInterval interface { jet.Column From(subQuery SelectTable) ColumnInterval + SET(intervalExp IntervalExpression) ColumnAssigment } +//------------------------------------------------------// + type intervalColumnImpl struct { jet.ColumnExpressionImpl intervalInterfaceImpl } +func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment { + return jet.NewColumnAssignment(i, intervalExp) +} + func (i *intervalColumnImpl) From(subQuery SelectTable) ColumnInterval { newIntervalColumn := IntervalColumn(i.Name()) jet.SetTableName(newIntervalColumn, i.TableName()) diff --git a/postgres/dialect.go b/postgres/dialect.go index 9484ab1..14929ad 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -1,6 +1,7 @@ package postgres import ( + "fmt" "github.com/go-jet/jet/v2/internal/jet" "strconv" ) @@ -25,6 +26,9 @@ func newDialect() jet.Dialect { return "$" + strconv.Itoa(ord) }, ReservedWords: reservedWords, + ValuesDefaultColumnName: func(index int) string { + return fmt.Sprintf("column%d", index+1) + }, } return jet.NewDialect(dialectParams) diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index 9aadbc9..6fb987c 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -46,33 +46,33 @@ func TestExists(t *testing.T) { func TestIN(t *testing.T) { assertSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), - `($1 IN ( + `($1 IN (( SELECT table1.col1 AS "table1.col1" FROM db.table1 -))`, float64(1.11)) +)))`, float64(1.11)) assertSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), - `(ROW($1, table1.col1) IN ( + `(ROW($1, table1.col1) IN (( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2 -))`, int64(12)) +)))`, int64(12)) } func TestNOT_IN(t *testing.T) { assertSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), - `($1 NOT IN ( + `($1 NOT IN (( SELECT table1.col1 AS "table1.col1" FROM db.table1 -))`, float64(1.11)) +)))`, float64(1.11)) assertSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), - `(ROW($1, table1.col1) NOT IN ( + `(ROW($1, table1.col1) NOT IN (( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2 -))`, int64(12)) +)))`, int64(12)) } func TestReservedWordEscaped(t *testing.T) { diff --git a/postgres/expressions.go b/postgres/expressions.go index 9872910..d8ad34b 100644 --- a/postgres/expressions.go +++ b/postgres/expressions.go @@ -36,6 +36,9 @@ type TimestampExpression = jet.TimestampExpression // TimestampzExpression interface type TimestampzExpression = jet.TimestampzExpression +// RowExpression interface +type RowExpression = jet.RowExpression + // DateRange Expression interface type DateRange = jet.Range[DateExpression] @@ -99,6 +102,11 @@ var TimestampExp = jet.TimestampExp // Does not add sql cast to generated sql builder output. var TimestampzExp = jet.TimestampzExp +// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. +// This enables the Go compiler to interpret any expression as a row expression +// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. +var RowExp = jet.RowExp + // RangeExp is range expression wrapper around arbitrary expression. // Allows go compiler to see any expression as range expression. // Does not add sql cast to generated sql builder output. diff --git a/postgres/functions.go b/postgres/functions.go index 7b6d1e1..3134325 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -14,7 +14,9 @@ var ( ) // ROW function is used to create a tuple value that consists of a set of expressions or column values. -var ROW = jet.ROW +func ROW(expressions ...Expression) RowExpression { + return jet.ROW(Dialect, expressions...) +} // ------------------ Mathematical functions ---------------// @@ -332,6 +334,16 @@ var LOCALTIMESTAMP = jet.LOCALTIMESTAMP // NOW returns current date and time var NOW = jet.NOW +// DATE_TRUNC returns the truncated date and time using optional time zone. +// Use TimestampzExp if you need timestamp with time zone and IntervalExp if you need interval. +func DATE_TRUNC(field unit, source Expression, timezone ...string) TimestampExpression { + if len(timezone) > 0 { + return jet.NewTimestampFunc("DATE_TRUNC", jet.FixedLiteral(unitToString(field)), source, jet.FixedLiteral(timezone[0])) + } + + return jet.NewTimestampFunc("DATE_TRUNC", jet.FixedLiteral(unitToString(field)), source) +} + // --------------- Conditional Expressions Functions -------------// // COALESCE function returns the first of its arguments that is not null. @@ -415,10 +427,12 @@ func castFloatLiteral(fraction FloatExpression) FloatExpression { // ), var GROUPING_SETS = jet.GROUPING_SETS -// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... ) -// The construct (a, b) is normally recognized in expressions as a row constructor. WRAP and ROW method behave exactly the same, -// except when used in GROUPING_SETS. For top level GROUPING SETS expression lists WRAP has to be used. -var WRAP = jet.WRAP +// WRAP surrounds a list of expressions or columns with parentheses, producing new row: (expression1, expression2, ...) +// The construct (a, b) is normally recognized in expressions as a row constructor. WRAP and ROW methods behave exactly the same, +// except when used in GROUPING_SETS and VALUES. In these contexts, WRAP must be used instead of ROW. +func WRAP(expressions ...Expression) RowExpression { + return jet.WRAP(Dialect, expressions...) +} // ROLLUP operator is used with the GROUP BY clause to generate all prefixes of a group of columns including the empty list. // It creates extra rows in the result set that represent the subtotal values for each combination of columns. diff --git a/postgres/functions_test.go b/postgres/functions_test.go index 4190f70..d0081b0 100644 --- a/postgres/functions_test.go +++ b/postgres/functions_test.go @@ -1,6 +1,8 @@ package postgres -import "testing" +import ( + "testing" +) func TestROW(t *testing.T) { assertSerialize(t, ROW(SELECT(Int(1))), `ROW(( @@ -10,3 +12,12 @@ func TestROW(t *testing.T) { SELECT $2 ), $3)`) } + +func TestDATE_TRUNC(t *testing.T) { + assertSerialize(t, DATE_TRUNC(YEAR, NOW()), "DATE_TRUNC('YEAR', NOW())") + assertSerialize( + t, + DATE_TRUNC(DAY, NOW().ADD(INTERVAL(1, HOUR)), "Australia/Sydney"), + "DATE_TRUNC('DAY', NOW() + INTERVAL '1 HOUR', 'Australia/Sydney')", + ) +} diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index 25300c2..5ace301 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -1,7 +1,6 @@ package postgres import ( - "github.com/go-jet/jet/v2/internal/jet" "github.com/stretchr/testify/require" "testing" "time" @@ -151,12 +150,13 @@ func TestInsert_ON_CONFLICT(t *testing.T) { VALUES("one", "two"). VALUES("1", "2"). VALUES("theta", "beta"). - ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE( - SET(table1ColBool.SET(Bool(true)), - table2ColInt.SET(Int(1)), - ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))), - ).WHERE(table1Col1.GT(Int(2))), - ). + ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()). + DO_UPDATE( + SET(table1ColBool.SET(Bool(true)), + table2ColInt.SET(Int(1)), + ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))), + ).WHERE(table1Col1.GT(Int(2))), + ). RETURNING(table1Col1, table1ColBool) assertDebugStatementSql(t, stmt, ` @@ -178,12 +178,12 @@ func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) { stmt := table1.INSERT(table1Col1, table1ColBool). VALUES("one", "two"). VALUES("1", "2"). - ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE( - SET(table1ColBool.SET(Bool(false)), - table2ColInt.SET(Int(1)), - ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))), - ).WHERE(table1Col1.GT(Int(2))), - ). + ON_CONFLICT().ON_CONSTRAINT("idk_primary_key"). + DO_UPDATE( + SET(table1ColBool.SET(Bool(false)), + table2ColInt.SET(Int(1)), + ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))), + ).WHERE(table1Col1.GT(Int(2)))). RETURNING(table1Col1, table1ColBool) assertDebugStatementSql(t, stmt, ` diff --git a/postgres/literal.go b/postgres/literal.go index e3a95b3..26b75d8 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -57,6 +57,16 @@ func Uint64(value uint64) IntegerExpression { // Float creates new float literal expression var Float = jet.Float +// Float32 is constructor for 32 bit float literals +func Float32(value float32) FloatExpression { + return CAST(jet.Literal(value)).AS_REAL() +} + +// Float64 is constructor for 64 bit float literals +func Float64(value float64) FloatExpression { + return CAST(jet.Literal(value)).AS_DOUBLE() +} + // Decimal creates new float literal expression var Decimal = jet.Decimal diff --git a/postgres/select_statement.go b/postgres/select_statement.go index 70a9a50..2a48fc5 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -181,7 +181,7 @@ func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement { } func (s *selectStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } //----------------------------------------------------- diff --git a/postgres/select_table.go b/postgres/select_table.go index f3d680d..5f57c1d 100644 --- a/postgres/select_table.go +++ b/postgres/select_table.go @@ -2,7 +2,7 @@ package postgres import "github.com/go-jet/jet/v2/internal/jet" -// SelectTable is interface for postgres sub-queries +// SelectTable is interface for postgres temporary tables like sub-queries, VALUES, CTEs etc... type SelectTable interface { readableTable jet.SelectTable @@ -13,9 +13,9 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { +func newSelectTable(serializerWithProjections jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable { subQuery := &selectTableImpl{ - SelectTable: jet.NewSelectTable(selectStmt, alias), + SelectTable: jet.NewSelectTable(serializerWithProjections, alias, columnAliases), } subQuery.readableTableInterfaceImpl.parent = subQuery diff --git a/postgres/set_statement.go b/postgres/set_statement.go index 834560d..0dee00d 100644 --- a/postgres/set_statement.go +++ b/postgres/set_statement.go @@ -136,7 +136,7 @@ func (s *setStatementImpl) OFFSET_e(offset IntegerExpression) setStatement { } func (s *setStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } const ( diff --git a/postgres/values.go b/postgres/values.go new file mode 100644 index 0000000..28d17cd --- /dev/null +++ b/postgres/values.go @@ -0,0 +1,32 @@ +package postgres + +import "github.com/go-jet/jet/v2/internal/jet" + +type values struct { + jet.Values +} + +// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table. +// Each row is defined by the WRAP constructor, which takes one or more expressions. +// +// Example usage: +// +// VALUES( +// WRAP(Int32(204), Float32(1.21)), +// WRAP(Int32(207), Float32(1.02)), +// ) +func VALUES(rows ...RowExpression) values { + return values{Values: jet.Values(rows)} +} + +// AS assigns an alias to the temporary VALUES table, allowing it to be referenced +// within SQL FROM clauses, just like a regular table. +// By default, VALUES columns are named `column1`, `column2`, etc... Default column aliasing can be +// overwritten by passing new list of columns. +// +// Example usage: +// +// VALUES(...).AS("film_values", IntegerColumn("length"), TimestampColumn("update_date")) +func (v values) AS(alias string, columns ...Column) SelectTable { + return newSelectTable(v, alias, columns) +} diff --git a/postgres/with_statement.go b/postgres/with_statement.go index 698d6e3..99ddc8f 100644 --- a/postgres/with_statement.go +++ b/postgres/with_statement.go @@ -6,7 +6,7 @@ import "github.com/go-jet/jet/v2/internal/jet" type CommonTableExpression interface { SelectTable - AS(statement jet.SerializerStatement) CommonTableExpression + AS(statement jet.SerializerHasProjections) CommonTableExpression AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. ALIAS(alias string) SelectTable @@ -42,7 +42,7 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { } // AS is used to define a CTE query -func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { +func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression { c.CommonTableExpression.Statement = statement return c } @@ -60,7 +60,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. func (c *commonTableExpression) ALIAS(name string) SelectTable { - return newSelectTable(c, name) + return newSelectTable(c, name, nil) } func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { diff --git a/qrm/internal/null_types.go b/qrm/internal/null_types.go index ab75cf6..85d9c68 100644 --- a/qrm/internal/null_types.go +++ b/qrm/internal/null_types.go @@ -10,6 +10,10 @@ import ( "time" ) +var ( + castOverFlowError = fmt.Errorf("cannot cast a negative value to an unsigned value, buffer overflow error") +) + // NullBool struct type NullBool struct { sql.NullBool @@ -119,32 +123,47 @@ func (n *NullUInt64) Scan(value interface{}) error { n.Valid = false return nil case int64: + if v < 0 { + return castOverFlowError + } + n.UInt64, n.Valid = uint64(v), true + return nil + case int32: + if v < 0 { + return castOverFlowError + } + n.UInt64, n.Valid = uint64(v), true + return nil + case int16: + if v < 0 { + return castOverFlowError + } + n.UInt64, n.Valid = uint64(v), true + return nil + case int8: + if v < 0 { + return castOverFlowError + } + n.UInt64, n.Valid = uint64(v), true + return nil + case int: + if v < 0 { + return castOverFlowError + } n.UInt64, n.Valid = uint64(v), true return nil case uint64: n.UInt64, n.Valid = v, true return nil - case int32: - n.UInt64, n.Valid = uint64(v), true - return nil case uint32: n.UInt64, n.Valid = uint64(v), true return nil - case int16: - n.UInt64, n.Valid = uint64(v), true - return nil case uint16: n.UInt64, n.Valid = uint64(v), true return nil - case int8: - n.UInt64, n.Valid = uint64(v), true - return nil case uint8: n.UInt64, n.Valid = uint64(v), true return nil - case int: - n.UInt64, n.Valid = uint64(v), true - return nil case uint: n.UInt64, n.Valid = uint64(v), true return nil diff --git a/qrm/internal/null_types_test.go b/qrm/internal/null_types_test.go index a15b104..eab2dd2 100644 --- a/qrm/internal/null_types_test.go +++ b/qrm/internal/null_types_test.go @@ -2,6 +2,7 @@ package internal import ( "fmt" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "testing" "time" @@ -62,11 +63,21 @@ func TestNullUInt64(t *testing.T) { value, _ := nullUInt64.Value() require.Equal(t, value, uint64(11)) + require.NoError(t, nullUInt64.Scan(uint64(11))) + require.Equal(t, nullUInt64.Valid, true) + value, _ = nullUInt64.Value() + require.Equal(t, value, uint64(11)) + require.NoError(t, nullUInt64.Scan(int32(32))) require.Equal(t, nullUInt64.Valid, true) value, _ = nullUInt64.Value() require.Equal(t, value, uint64(32)) + require.NoError(t, nullUInt64.Scan(uint32(32))) + require.Equal(t, nullUInt64.Valid, true) + value, _ = nullUInt64.Value() + require.Equal(t, value, uint64(32)) + require.NoError(t, nullUInt64.Scan(int16(20))) require.Equal(t, nullUInt64.Valid, true) value, _ = nullUInt64.Value() @@ -88,4 +99,29 @@ func TestNullUInt64(t *testing.T) { require.Equal(t, value, uint64(30)) require.Error(t, nullUInt64.Scan("text"), "can't scan int32 from text") + + //Validate negative use cases + err := nullUInt64.Scan(int64(-5)) + assert.NotNil(t, err) + assert.Error(t, err, castOverFlowError) + + //Validate negative use cases + err = nullUInt64.Scan(-5) + assert.NotNil(t, err) + assert.Error(t, err, castOverFlowError) + + //Validate negative use cases + err = nullUInt64.Scan(int32(-5)) + assert.NotNil(t, err) + assert.Error(t, err, castOverFlowError) + + //Validate negative use cases + err = nullUInt64.Scan(int16(-5)) + assert.NotNil(t, err) + assert.Error(t, err, castOverFlowError) + + //Validate negative use cases + err = nullUInt64.Scan(int8(-5)) + assert.NotNil(t, err) + assert.Error(t, err, castOverFlowError) } diff --git a/sqlite/dialect.go b/sqlite/dialect.go index 93e1d2f..da03364 100644 --- a/sqlite/dialect.go +++ b/sqlite/dialect.go @@ -1,6 +1,7 @@ package sqlite import ( + "fmt" "github.com/go-jet/jet/v2/internal/jet" ) @@ -23,6 +24,9 @@ func newDialect() jet.Dialect { return "?" }, ReservedWords: reservedWords2, + ValuesDefaultColumnName: func(index int) string { + return fmt.Sprintf("column%d", index+1) + }, } return jet.NewDialect(mySQLDialectParams) diff --git a/sqlite/expressions.go b/sqlite/expressions.go index 42ccc96..0b2d320 100644 --- a/sqlite/expressions.go +++ b/sqlite/expressions.go @@ -33,6 +33,9 @@ type DateTimeExpression = jet.TimestampExpression // TimestampExpression interface type TimestampExpression = jet.TimestampExpression +// RowExpression interface +type RowExpression = jet.RowExpression + // BoolExp is bool expression wrapper around arbitrary expression. // Allows go compiler to see any expression as bool expression. // Does not add sql cast to generated sql builder output. @@ -73,6 +76,11 @@ var DateTimeExp = jet.TimestampExp // Does not add sql cast to generated sql builder output. var TimestampExp = jet.TimestampExp +// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. +// This enables the Go compiler to interpret any expression as a row expression +// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. +var RowExp = jet.RowExp + // CustomExpression is used to define custom expressions. var CustomExpression = jet.CustomExpression diff --git a/sqlite/functions.go b/sqlite/functions.go index 47a0a5b..ac6bd08 100644 --- a/sqlite/functions.go +++ b/sqlite/functions.go @@ -15,9 +15,9 @@ var ( OR = jet.OR ) -// ROW is construct one table row from list of expressions. -func ROW(expressions ...Expression) Expression { - return jet.NewFunc("", expressions, nil) +// ROW function is used to create a tuple value that consists of a set of expressions or column values. +func ROW(expressions ...Expression) RowExpression { + return jet.WRAP(Dialect, expressions...) } // ------------------ Mathematical functions ---------------// diff --git a/sqlite/select_statement.go b/sqlite/select_statement.go index e74cda6..531ae76 100644 --- a/sqlite/select_statement.go +++ b/sqlite/select_statement.go @@ -155,7 +155,7 @@ func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement { } func (s *selectStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } //----------------------------------------------------- diff --git a/sqlite/select_table.go b/sqlite/select_table.go index 9ac7f72..7421f4b 100644 --- a/sqlite/select_table.go +++ b/sqlite/select_table.go @@ -13,9 +13,9 @@ type selectTableImpl struct { readableTableInterfaceImpl } -func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable { +func newSelectTable(selectStmt jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable { subQuery := &selectTableImpl{ - SelectTable: jet.NewSelectTable(selectStmt, alias), + SelectTable: jet.NewSelectTable(selectStmt, alias, columnAliases), } subQuery.readableTableInterfaceImpl.parent = subQuery diff --git a/sqlite/set_statement.go b/sqlite/set_statement.go index 0a004bf..47723a3 100644 --- a/sqlite/set_statement.go +++ b/sqlite/set_statement.go @@ -86,7 +86,7 @@ func (s *setStatementImpl) OFFSET(offset int64) setStatement { } func (s *setStatementImpl) AsTable(alias string) SelectTable { - return newSelectTable(s, alias) + return newSelectTable(s, alias, nil) } const ( diff --git a/sqlite/values.go b/sqlite/values.go new file mode 100644 index 0000000..cb683cb --- /dev/null +++ b/sqlite/values.go @@ -0,0 +1,26 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +type values struct { + jet.Values +} + +// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table. +// Each row is defined by the ROW constructor, which takes one or more expressions. +// +// Example usage: +// +// VALUES( +// ROW(Int32(204), Float32(1.21)), +// ROW(Int32(207), Float32(1.02)), +// ) +func VALUES(rows ...RowExpression) values { + return values{Values: jet.Values(rows)} +} + +// AS assigns an alias to the temporary VALUES table, allowing it to be referenced +// within SQL FROM clauses, just like a regular table. +func (v values) AS(alias string) SelectTable { + return newSelectTable(v, alias, nil) +} diff --git a/sqlite/with_statement.go b/sqlite/with_statement.go index 5375fff..b05da7d 100644 --- a/sqlite/with_statement.go +++ b/sqlite/with_statement.go @@ -6,8 +6,8 @@ import "github.com/go-jet/jet/v2/internal/jet" type CommonTableExpression interface { SelectTable - AS(statement jet.SerializerStatement) CommonTableExpression - AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression + AS(statement jet.SerializerHasProjections) CommonTableExpression + AS_NOT_MATERIALIZED(statement jet.SerializerHasProjections) CommonTableExpression // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. ALIAS(alias string) SelectTable @@ -42,13 +42,13 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression { } // AS is used to define a CTE query -func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression { +func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression { c.CommonTableExpression.Statement = statement return c } // AS_NOT_MATERIALIZED is used to define not materialized CTE query -func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression { +func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerHasProjections) CommonTableExpression { c.CommonTableExpression.NotMaterialized = true c.CommonTableExpression.Statement = statement return c @@ -60,7 +60,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression { // ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query. func (c *commonTableExpression) ALIAS(name string) SelectTable { - return newSelectTable(c, name) + return newSelectTable(c, name, nil) } func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression { diff --git a/tests/Makefile b/tests/Makefile index 1a84778..43a8d66 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -6,6 +6,8 @@ setup: checkout-testdata docker-compose-up checkout-testdata: git submodule init git submodule update +# +checkout-latest-testdata: checkout-testdata cd ./testdata && git fetch && git checkout master && git pull # docker-compose-up will download docker image for each of the databases listed in docker-compose.yaml file, and then it will initialize diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 9b3af50..09ce9d7 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -13,7 +13,7 @@ services: - ./testdata/init/postgres:/docker-entrypoint-initdb.d mysql: - image: mysql:8.0.27 + image: mysql:8.0 command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1'] restart: always environment: diff --git a/tests/init/init.go b/tests/init/init.go index 5a21ee7..5a0ccdb 100644 --- a/tests/init/init.go +++ b/tests/init/init.go @@ -3,6 +3,7 @@ package main import ( "context" "database/sql" + "errors" "flag" "fmt" "github.com/go-jet/jet/v2/generator/mysql" @@ -10,7 +11,6 @@ import ( "github.com/go-jet/jet/v2/generator/sqlite" "github.com/go-jet/jet/v2/internal/utils/errfmt" "github.com/go-jet/jet/v2/tests/internal/utils/repo" - "io/ioutil" "os" "os/exec" "strings" @@ -125,7 +125,7 @@ func initMySQLDB(isMariaDB bool) error { fmt.Println(cmdLine) - cmd := exec.Command("sh", "-c", cmdLine) + cmd := exec.Command("sh", "-c", cmdLine) // #nosec G204 cmd.Stderr = os.Stderr cmd.Stdout = os.Stdout @@ -184,7 +184,7 @@ func initPostgresDB(dbType string, connectionString string) error { } func execFile(db *sql.DB, sqlFilePath string) error { - testSampleSql, err := ioutil.ReadFile(sqlFilePath) + testSampleSql, err := os.ReadFile(sqlFilePath) // #nosec G304 if err != nil { return fmt.Errorf("failed to read sql file - %s: %w", sqlFilePath, err) } @@ -211,7 +211,10 @@ func execInTx(db *sql.DB, f func(tx *sql.Tx) error) error { err = f(tx) if err != nil { - tx.Rollback() + rollBackError := tx.Rollback() + if rollBackError != nil { + return errors.Join(rollBackError, err) + } return err } diff --git a/tests/internal/utils/file/file.go b/tests/internal/utils/file/file.go index 6d08d22..73095ea 100644 --- a/tests/internal/utils/file/file.go +++ b/tests/internal/utils/file/file.go @@ -2,7 +2,6 @@ package file import ( "github.com/stretchr/testify/require" - "io/ioutil" "os" "path" "testing" @@ -11,7 +10,7 @@ import ( // Exists expects file to exist on path constructed from pathElems and returns content of the file func Exists(t *testing.T, pathElems ...string) (fileContent string) { modelFilePath := path.Join(pathElems...) - file, err := ioutil.ReadFile(modelFilePath) + file, err := os.ReadFile(modelFilePath) // #nosec G304 require.Nil(t, err) require.NotEmpty(t, file) return string(file) @@ -20,6 +19,6 @@ func Exists(t *testing.T, pathElems ...string) (fileContent string) { // NotExists expects file not to exist on path constructed from pathElems func NotExists(t *testing.T, pathElems ...string) { modelFilePath := path.Join(pathElems...) - _, err := ioutil.ReadFile(modelFilePath) + _, err := os.ReadFile(modelFilePath) // #nosec G304 require.True(t, os.IsNotExist(err)) } diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index c90f774..61bc7f2 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -1,6 +1,7 @@ package mysql import ( + "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/shopspring/decimal" "github.com/stretchr/testify/require" "strings" @@ -96,18 +97,18 @@ func TestExpressionOperators(t *testing.T) { SELECT all_types.'integer' IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN (?, ?)) AS "result.in", - (all_types.small_int_ptr IN ( + (all_types.small_int_ptr IN (( SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types - )) AS "result.in_select", + ))) AS "result.in_select", (CURRENT_USER()) AS "result.raw", (? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg", (? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2", (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", - (all_types.small_int_ptr NOT IN ( + (all_types.small_int_ptr NOT IN (( SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types - )) AS "result.not_in_select" + ))) AS "result.not_in_select" FROM test_sample.all_types LIMIT ?; `, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2)) @@ -1067,7 +1068,7 @@ func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) { var toInsert = model.AllTypes{ Boolean: false, - BooleanPtr: testutils.BoolPtr(true), + BooleanPtr: ptr.Of(true), TinyInt: 1, UTinyInt: 2, SmallInt: 3, @@ -1078,53 +1079,53 @@ var toInsert = model.AllTypes{ UInteger: 8, BigInt: 9, UBigInt: 1122334455, - TinyIntPtr: testutils.Int8Ptr(11), - UTinyIntPtr: testutils.UInt8Ptr(22), - SmallIntPtr: testutils.Int16Ptr(33), - USmallIntPtr: testutils.UInt16Ptr(44), - MediumIntPtr: testutils.Int32Ptr(55), - UMediumIntPtr: testutils.UInt32Ptr(66), - IntegerPtr: testutils.Int32Ptr(77), - UIntegerPtr: testutils.UInt32Ptr(88), - BigIntPtr: testutils.Int64Ptr(99), - UBigIntPtr: testutils.UInt64Ptr(111), + TinyIntPtr: ptr.Of(int8(11)), + UTinyIntPtr: ptr.Of(uint8(22)), + SmallIntPtr: ptr.Of(int16(33)), + USmallIntPtr: ptr.Of(uint16(44)), + MediumIntPtr: ptr.Of(int32(55)), + UMediumIntPtr: ptr.Of(uint32(66)), + IntegerPtr: ptr.Of(int32(77)), + UIntegerPtr: ptr.Of(uint32(88)), + BigIntPtr: ptr.Of(int64(99)), + UBigIntPtr: ptr.Of(uint64(111)), Decimal: 11.22, - DecimalPtr: testutils.Float64Ptr(33.44), + DecimalPtr: ptr.Of(33.44), Numeric: 55.66, - NumericPtr: testutils.Float64Ptr(77.88), + NumericPtr: ptr.Of(77.88), Float: 99.00, - FloatPtr: testutils.Float64Ptr(11.22), + FloatPtr: ptr.Of(11.22), Double: 33.44, - DoublePtr: testutils.Float64Ptr(55.66), + DoublePtr: ptr.Of(55.66), Real: 77.88, - RealPtr: testutils.Float64Ptr(99.00), + RealPtr: ptr.Of(99.00), Bit: "1", - BitPtr: testutils.StringPtr("0"), + BitPtr: ptr.Of("0"), Time: time.Date(1, 1, 1, 10, 11, 12, 100, &time.Location{}), - TimePtr: testutils.TimePtr(time.Date(1, 1, 1, 10, 11, 12, 100, time.UTC)), + TimePtr: ptr.Of(time.Date(1, 1, 1, 10, 11, 12, 100, time.UTC)), Date: time.Now(), - DatePtr: testutils.TimePtr(time.Now()), + DatePtr: ptr.Of(time.Now()), DateTime: time.Now(), - DateTimePtr: testutils.TimePtr(time.Now()), + DateTimePtr: ptr.Of(time.Now()), Timestamp: time.Now(), //TimestampPtr: testutils.TimePtr(time.Now()), // TODO: build fails for MariaDB Year: 2000, - YearPtr: testutils.Int16Ptr(2001), + YearPtr: ptr.Of(int16(2001)), Char: "abcd", - CharPtr: testutils.StringPtr("absd"), + CharPtr: ptr.Of("absd"), VarChar: "abcd", - VarCharPtr: testutils.StringPtr("absd"), + VarCharPtr: ptr.Of("absd"), Binary: []byte("1010"), - BinaryPtr: testutils.ByteArrayPtr([]byte("100001")), + BinaryPtr: ptr.Of([]byte("100001")), VarBinary: []byte("1010"), - VarBinaryPtr: testutils.ByteArrayPtr([]byte("100001")), + VarBinaryPtr: ptr.Of([]byte("100001")), Blob: []byte("large file"), - BlobPtr: testutils.ByteArrayPtr([]byte("very large file")), + BlobPtr: ptr.Of([]byte("very large file")), Text: "some text", - TextPtr: testutils.StringPtr("text"), + TextPtr: ptr.Of("text"), Enum: model.AllTypesEnum_Value1, JSON: "{}", - JSONPtr: testutils.StringPtr(`{"a": 1}`), + JSONPtr: ptr.Of(`{"a": 1}`), } var allTypesJson = ` @@ -1358,17 +1359,17 @@ func TestExactDecimals(t *testing.T) { Floats: model.Floats{ // overwritten by wrapped(floats) scope Numeric: 0.1, - NumericPtr: testutils.Float64Ptr(0.1), + NumericPtr: ptr.Of(0.1), Decimal: 0.1, - DecimalPtr: testutils.Float64Ptr(0.1), + DecimalPtr: ptr.Of(0.1), // not overwritten Float: 0.2, - FloatPtr: testutils.Float64Ptr(0.22), + FloatPtr: ptr.Of(0.22), Double: 0.3, - DoublePtr: testutils.Float64Ptr(0.33), + DoublePtr: ptr.Of(0.33), Real: 0.4, - RealPtr: testutils.Float64Ptr(0.44), + RealPtr: ptr.Of(0.44), }, Numeric: decimal.RequireFromString("12.35"), NumericPtr: decimal.RequireFromString("56.79"), @@ -1403,3 +1404,34 @@ VALUES ('91.23', '45.67', '12.35', '56.79', 0.2, 0.22, 0.3, 0.33, 0.4, 0.44); require.Equal(t, 45.67, *result.Floats.DecimalPtr) }) } + +func TestRowExpression(t *testing.T) { + now := time.Now() + nowAddHour := time.Now().Add(time.Hour) + + stmt := SELECT( + ROW(Bool(false), DateT(now)).EQ(ROW(Bool(true), DateT(now))), + ROW(Bool(false), DateT(now)).NOT_EQ(ROW(Bool(true), DateT(now))), + ROW(TimestampT(nowAddHour), String("txt")).IS_DISTINCT_FROM(RowExp(Raw("row(NOW(), 'png')"))), + ROW(TimestampT(now), DateTimeT(nowAddHour)).GT(ROW(TimestampT(now), DateTimeT(now))), + ROW(DateTimeT(nowAddHour), Int(1)).GT_EQ(ROW(DateTimeT(now), Int(2))), + ROW(TimestampT(now), DateTimeT(nowAddHour)).LT(ROW(TimestampT(now), DateTimeT(now))), + ROW(DateTimeT(nowAddHour), Float(1.22)).LT_EQ(ROW(DateTimeT(now), Float(2.33))), + ) + + //fmt.Println(stmt.Sql()) + //fmt.Println(stmt.DebugSql()) + + testutils.AssertStatementSql(t, stmt, ` +SELECT ROW(?, CAST(? AS DATE)) = ROW(?, CAST(? AS DATE)), + ROW(?, CAST(? AS DATE)) != ROW(?, CAST(? AS DATE)), + NOT(ROW(TIMESTAMP(?), ?) <=> (row(NOW(), 'png'))), + ROW(TIMESTAMP(?), CAST(? AS DATETIME)) > ROW(TIMESTAMP(?), CAST(? AS DATETIME)), + ROW(CAST(? AS DATETIME), ?) >= ROW(CAST(? AS DATETIME), ?), + ROW(TIMESTAMP(?), CAST(? AS DATETIME)) < ROW(TIMESTAMP(?), CAST(? AS DATETIME)), + ROW(CAST(? AS DATETIME), ?) <= ROW(CAST(? AS DATETIME), ?); +`) + + err := stmt.Query(db, &struct{}{}) + require.NoError(t, err) +} diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index a7b5554..2a915a5 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -8,8 +8,11 @@ import ( "github.com/stretchr/testify/require" + "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/mysql" + "github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/internal/testutils" + mysql2 "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/tests/dbconfig" ) @@ -39,6 +42,39 @@ func TestGenerator(t *testing.T) { require.NoError(t, err) } +func TestGenerator_TableMetadata(t *testing.T) { + var schema metadata.Schema + err := mysql.Generate(genTestDir3, dbConnection("dvds"), + template.Default(mysql2.Dialect).UseSchema(func(m metadata.Schema) template.Schema { + schema = m + return template.DefaultSchema(m) + })) + require.NoError(t, err) + + // Spot check the actor table and assert that the emitted + // properties are as expected. + var got metadata.Table + for _, table := range schema.TablesMetaData { + if table.Name == "actor" { + got = table + } + } + + want := metadata.Table{ + Name: "actor", + Columns: []metadata.Column{ + {Name: "actor_id", IsPrimaryKey: true, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "smallint", Kind: "base", IsUnsigned: true}, Comment: ""}, + {Name: "first_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "last_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "last_update", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: true, DataType: metadata.DataType{Name: "timestamp", Kind: "base", IsUnsigned: false}, Comment: ""}, + }, + } + require.Equal(t, want, got) + + err = os.RemoveAll(genTestDirRoot) + require.NoError(t, err) +} + func TestCmdGenerator(t *testing.T) { err := os.RemoveAll(genTestDir3) require.NoError(t, err) diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index 7eb6ec8..54b37c0 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -3,10 +3,12 @@ package mysql import ( "context" "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/internal/utils/ptr" . "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/table" + "github.com/stretchr/testify/require" "math/rand" "testing" @@ -300,7 +302,7 @@ func TestInsertOnDuplicateKeyUpdateNEW(t *testing.T) { ID: randId, URL: "https://www.yahoo.com", Name: "Yahoo", - Description: testutils.StringPtr("web portal and search engine"), + Description: ptr.Of("web portal and search engine"), }, }).AS_NEW(). ON_DUPLICATE_KEY_UPDATE( @@ -337,7 +339,7 @@ ON DUPLICATE KEY UPDATE id = (link.id + ?), ID: randId + 11, URL: "https://www.yahoo.com", Name: "Yahoo", - Description: testutils.StringPtr("web portal and search engine"), + Description: ptr.Of("web portal and search engine"), }) }) } diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index fd711b4..ee02a77 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -6,12 +6,9 @@ import ( jetmysql "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/dbconfig" - "github.com/stretchr/testify/require" - "math/rand" - "runtime" - "time" - _ "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/require" + "runtime" "github.com/pkg/profile" "os" @@ -33,7 +30,6 @@ func sourceIsMariaDB() bool { } func TestMain(m *testing.M) { - rand.Seed(time.Now().Unix()) defer profile.Start().Stop() var err error @@ -101,3 +97,9 @@ func skipForMariaDB(t *testing.T) { t.SkipNow() } } + +func onlyMariaDB(t *testing.T) { + if !sourceIsMariaDB() { + t.SkipNow() + } +} diff --git a/tests/mysql/values_test.go b/tests/mysql/values_test.go new file mode 100644 index 0000000..09e9332 --- /dev/null +++ b/tests/mysql/values_test.go @@ -0,0 +1,347 @@ +package mysql + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/stretchr/testify/require" + "strings" + "testing" + "time" + + . "github.com/go-jet/jet/v2/mysql" + + "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table" +) + +func TestVALUES(t *testing.T) { + skipForMariaDB(t) + + valuesTable := VALUES( + ROW(Int32(1), Int32(2), Float(4.666), Bool(false), String("txt")), + ROW(Int32(11).ADD(Int32(2)), Int32(22), Float(33.222), Bool(true), String("png")), + ROW(Int32(11), Int32(22), Float(33.222), Bool(true), NULL), + ).AS("values_table") + + stmt := SELECT( + valuesTable.AllColumns(), + ).FROM( + valuesTable, + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT values_table.column_0 AS "column_0", + values_table.column_1 AS "column_1", + values_table.column_2 AS "column_2", + values_table.column_3 AS "column_3", + values_table.column_4 AS "column_4" +FROM ( + VALUES ROW(?, ?, ?, ?, ?), + ROW(? + ?, ?, ?, ?, ?), + ROW(?, ?, ?, ?, NULL) + ) AS values_table; +`) + + var dest []struct { + Column0 int + Column1 int + Column2 float32 + Column3 bool + Column4 *string + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "Column0": 1, + "Column1": 2, + "Column2": 4.666, + "Column3": false, + "Column4": "txt" + }, + { + "Column0": 13, + "Column1": 22, + "Column2": 33.222, + "Column3": true, + "Column4": "png" + }, + { + "Column0": 11, + "Column1": 22, + "Column2": 33.222, + "Column3": true, + "Column4": null + } +] +`) +} + +func TestVALUES_Join(t *testing.T) { + skipForMariaDB(t) + + title := StringColumn("title") + releaseYear := IntegerColumn("ReleaseYear") + rentalRate := FloatColumn("rental_rate") + + lastUpdate := Timestamp(2007, time.February, 11, 12, 0, 0) + + films := VALUES( + ROW(String("Chamber Italian"), Int64(117), Int32(2005), Float(5.82), lastUpdate), + ROW(String("Grosse Wonderful"), Int64(49), Int32(2004), Float(6.242), lastUpdate.ADD(INTERVAL(1, HOUR))), + ROW(String("Airport Pollock"), Int64(54), Int32(2001), Float(7.22), NULL), + ROW(String("Bright Encounters"), Int64(73), Int32(2002), Float(8.25), NULL), + ROW(String("Academy Dinosaur"), Int64(83), Int32(2010), Float(9.22), lastUpdate.SUB(INTERVAL(2, MINUTE))), + ).AS("film_values", + title, IntegerColumn("length"), releaseYear, rentalRate, TimestampColumn("last_update")) + + stmt := SELECT( + Film.AllColumns, + films.AllColumns(), + ).FROM( + Film. + INNER_JOIN(films, title.EQ(Film.Title)), + ).WHERE(AND( + Film.ReleaseYear.GT(releaseYear), + Film.RentalRate.LT(rentalRate), + )).ORDER_BY( + title, + ) + + testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + film.description AS "film.description", + film.release_year AS "film.release_year", + film.language_id AS "film.language_id", + film.original_language_id AS "film.original_language_id", + film.rental_duration AS "film.rental_duration", + film.rental_rate AS "film.rental_rate", + film.length AS "film.length", + film.replacement_cost AS "film.replacement_cost", + film.rating AS "film.rating", + film.special_features AS "film.special_features", + film.last_update AS "film.last_update", + film_values.title AS "title", + film_values.length AS "length", + film_values.''ReleaseYear'' AS "ReleaseYear", + film_values.rental_rate AS "rental_rate", + film_values.last_update AS "last_update" +FROM dvds.film + INNER JOIN ( + VALUES ROW('Chamber Italian', 117, 2005, 5.82, TIMESTAMP('2007-02-11 12:00:00')), + ROW('Grosse Wonderful', 49, 2004, 6.242, TIMESTAMP('2007-02-11 12:00:00') + INTERVAL 1 HOUR), + ROW('Airport Pollock', 54, 2001, 7.22, NULL), + ROW('Bright Encounters', 73, 2002, 8.25, NULL), + ROW('Academy Dinosaur', 83, 2010, 9.22, TIMESTAMP('2007-02-11 12:00:00') - INTERVAL 2 MINUTE) + ) AS film_values (title, length, ''ReleaseYear'', rental_rate, last_update) ON (film_values.title = film.title) +WHERE ( + (film.release_year > film_values.''ReleaseYear'') + AND (film.rental_rate < film_values.rental_rate) + ) +ORDER BY film_values.title; +`, "''", "`")) + + var dest []struct { + Film model.Film + + Title string + Length int + ReleaseYear int + RentalRate float32 + LastUpdate *time.Time + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + require.Len(t, dest, 4) + testutils.AssertJSON(t, dest[0:2], ` +[ + { + "Film": { + "FilmID": 8, + "Title": "AIRPORT POLLOCK", + "Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India", + "ReleaseYear": 2006, + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 6, + "RentalRate": 4.99, + "Length": 54, + "ReplacementCost": 15.99, + "Rating": "R", + "SpecialFeatures": "Trailers", + "LastUpdate": "2006-02-15T05:03:42Z" + }, + "Title": "Airport Pollock", + "Length": 54, + "ReleaseYear": 2001, + "RentalRate": 7.22, + "LastUpdate": null + }, + { + "Film": { + "FilmID": 98, + "Title": "BRIGHT ENCOUNTERS", + "Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat", + "ReleaseYear": 2006, + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 4, + "RentalRate": 4.99, + "Length": 73, + "ReplacementCost": 12.99, + "Rating": "PG-13", + "SpecialFeatures": "Trailers", + "LastUpdate": "2006-02-15T05:03:42Z" + }, + "Title": "Bright Encounters", + "Length": 73, + "ReleaseYear": 2002, + "RentalRate": 8.25, + "LastUpdate": null + } +] +`) +} + +func TestVALUES_CTE_Update(t *testing.T) { + skipForMariaDB(t) + + paymentID := IntegerColumn("payment_id") + increase := FloatColumn("increase") + paymentsToUpdate := CTE("values_cte", paymentID, increase) + + stmt := WITH( + paymentsToUpdate.AS( + VALUES( + ROW(Int32(204), Float(1.21)), + ROW(Int32(207), Float(1.02)), + ROW(Int32(200), Float(1.34)), + ROW(Int32(203), Float(1.72)), + ), + ), + )( + Payment.INNER_JOIN(paymentsToUpdate, paymentID.EQ(Payment.PaymentID)). + UPDATE(). + SET( + Payment.Amount.SET(Payment.Amount.MUL(increase)), + ).WHERE(Bool(true)), + ) + + testutils.AssertStatementSql(t, stmt, ` +WITH values_cte (payment_id, increase) AS ( + VALUES ROW(?, ?), + ROW(?, ?), + ROW(?, ?), + ROW(?, ?) +) +UPDATE dvds.payment +INNER JOIN values_cte ON (values_cte.payment_id = payment.payment_id) +SET amount = (payment.amount * values_cte.increase) +WHERE ?; +`) + + testutils.AssertExecAndRollback(t, stmt, db, 4) +} + +func TestVALUES_MariaDB(t *testing.T) { + onlyMariaDB(t) // mariadb won't accept values rows if all the elements are placeholders, so we have to use raw statement + + paymentID := IntegerColumn("payment_id") + increase := FloatColumn("increase") + paymentsToUpdate := CTE("values_cte", paymentID, increase) + + stmt := WITH( + paymentsToUpdate.AS( + RawStatement(` + VALUES (204, 1.21), + (207, 1.02), + (200, 1.34), + (203, 1.72) + `), + ), + )( + SELECT( + Payment.AllColumns, + paymentsToUpdate.AllColumns(), + ).FROM( + Payment. + INNER_JOIN(paymentsToUpdate, paymentID.EQ(Payment.PaymentID)), + ).WHERE( + increase.GT(Float(1.03)), + ).ORDER_BY( + increase, + ), + ) + + testutils.AssertStatementSql(t, stmt, ` +WITH values_cte (payment_id, increase) AS ( + VALUES (204, 1.21), + (207, 1.02), + (200, 1.34), + (203, 1.72) + +) +SELECT payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + payment.last_update AS "payment.last_update", + values_cte.payment_id AS "payment_id", + values_cte.increase AS "increase" +FROM dvds.payment + INNER JOIN values_cte ON (values_cte.payment_id = payment.payment_id) +WHERE values_cte.increase > ? +ORDER BY values_cte.increase; +`) + + var dest []struct { + model.Payment + + Increase float64 + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "PaymentID": 204, + "CustomerID": 7, + "StaffID": 1, + "RentalID": 13476, + "Amount": 2.99, + "PaymentDate": "2005-08-20T01:06:04Z", + "LastUpdate": "2006-02-15T22:12:31Z", + "Increase": 1.21 + }, + { + "PaymentID": 200, + "CustomerID": 7, + "StaffID": 2, + "RentalID": 11542, + "Amount": 7.99, + "PaymentDate": "2005-08-17T00:51:32Z", + "LastUpdate": "2006-02-15T22:12:31Z", + "Increase": 1.34 + }, + { + "PaymentID": 203, + "CustomerID": 7, + "StaffID": 2, + "RentalID": 13373, + "Amount": 2.99, + "PaymentDate": "2005-08-19T21:23:31Z", + "LastUpdate": "2006-02-15T22:12:31Z", + "Increase": 1.72 + } +] +`) +} diff --git a/tests/mysql/with_test.go b/tests/mysql/with_test.go index d7d8d3d..59da3a5 100644 --- a/tests/mysql/with_test.go +++ b/tests/mysql/with_test.go @@ -164,10 +164,10 @@ WITH payments_to_delete AS ( WHERE payment.amount < 0.5 ) DELETE FROM dvds.payment -WHERE payment.payment_id IN ( +WHERE payment.payment_id IN (( SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id" FROM payments_to_delete - ); + )); `, "''", "`")) tx, err := db.Begin() diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 25a29d2..6d9c0b7 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -1,6 +1,10 @@ package postgres import ( + "database/sql" + "github.com/go-jet/jet/v2/internal/utils/ptr" + "github.com/stretchr/testify/assert" + "github.com/go-jet/jet/v2/qrm" "testing" "time" @@ -344,24 +348,22 @@ func TestExpressionOperators(t *testing.T) { AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"), ).LIMIT(2) - //fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, ` SELECT all_types.integer IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN ($1::smallint, $2::smallint)) AS "result.in", - (all_types.small_int_ptr IN ( + (all_types.small_int_ptr IN (( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types - )) AS "result.in_select", + ))) AS "result.in_select", (CURRENT_USER) AS "result.raw", ($3 + COALESCE(all_types.small_int_ptr, 0) + $4) AS "result.raw_arg", ($5 + all_types.integer + $6 + $5 + $7 + $8) AS "result.raw_arg2", (all_types.small_int_ptr NOT IN ($9, $10::smallint, NULL)) AS "result.not_in", - (all_types.small_int_ptr NOT IN ( + (all_types.small_int_ptr NOT IN (( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types - )) AS "result.not_in_select" + ))) AS "result.not_in_select" FROM test_sample.all_types LIMIT $11; `, int8(11), int8(22), 78, 56, 11, 22, 33, 44, int64(11), int16(22), int64(2)) @@ -373,9 +375,6 @@ LIMIT $11; err := query.Query(db, &dest) require.NoError(t, err) - - //testutils.PrintJson(dest) - testutils.AssertJSON(t, dest, ` [ { @@ -930,6 +929,68 @@ func TestTimeExpression(t *testing.T) { require.NoError(t, err) } +func TestIntervalSetFunctionality(t *testing.T) { + + t.Run("updateQueryIntervalTest", func(t *testing.T) { + + expectedQuery := ` +UPDATE test_sample.employee +SET pto_accrual = INTERVAL '3 HOUR' +WHERE employee.employee_id = $1 +RETURNING employee.employee_id AS "employee.employee_id", + employee.first_name AS "employee.first_name", + employee.last_name AS "employee.last_name", + employee.employment_date AS "employee.employment_date", + employee.manager_id AS "employee.manager_id", + employee.pto_accrual AS "employee.pto_accrual"; +` + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var windy model.Employee + windy.PtoAccrual = ptr.Of("3h") + stmt := Employee.UPDATE(Employee.PtoAccrual).SET( + Employee.PtoAccrual.SET(INTERVAL(3, HOUR)), + ).WHERE(Employee.EmployeeID.EQ(Int(1))).RETURNING(Employee.AllColumns) + + testutils.AssertStatementSql(t, stmt, expectedQuery) + err := stmt.Query(tx, &windy) + assert.Nil(t, err) + assert.Equal(t, *windy.PtoAccrual, "03:00:00") + + }) + }) + + t.Run("upsertQueryIntervalTest", func(t *testing.T) { + expectedQuery := ` +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6) +ON CONFLICT (employee_id) DO UPDATE + SET pto_accrual = excluded.pto_accrual +RETURNING employee.employee_id AS "employee.employee_id", + employee.first_name AS "employee.first_name", + employee.last_name AS "employee.last_name", + employee.employment_date AS "employee.employment_date", + employee.manager_id AS "employee.manager_id", + employee.pto_accrual AS "employee.pto_accrual"; +` + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var employee model.Employee + employee.PtoAccrual = ptr.Of("5h") + stmt := Employee.INSERT(Employee.AllColumns). + MODEL(employee). + ON_CONFLICT(Employee.EmployeeID). + DO_UPDATE(SET( + Employee.PtoAccrual.SET(Employee.EXCLUDED.PtoAccrual), + )).RETURNING(Employee.AllColumns) + + testutils.AssertStatementSql(t, stmt, expectedQuery) + err := stmt.Query(tx, &employee) + assert.Nil(t, err) + assert.Equal(t, *employee.PtoAccrual, "05:00:00") + + }) + }) +} + func TestInterval(t *testing.T) { skipForCockroachDB(t) @@ -1046,6 +1107,46 @@ FROM test_sample.all_types; require.NoError(t, err) } +func TestRowExpression(t *testing.T) { + now := time.Now() + nowAddHour := time.Now().Add(time.Hour) + + stmt := SELECT( + ROW(Int32(1), Float32(11.22), String("john")).AS("row"), + WRAP(Int64(1), Float64(11.22), String("john")).AS("wrap"), + + ROW(Bool(false), DateT(now)).EQ(ROW(Bool(true), DateT(now))), + WRAP(Bool(false), DateT(now)).NOT_EQ(WRAP(Bool(true), DateT(now))), + + ROW(TimeT(nowAddHour)).IS_DISTINCT_FROM(RowExp(Raw("row(NOW()::time)"))), + ROW().IS_NOT_DISTINCT_FROM(ROW()), + + ROW(TimestampT(now), TimestampzT(nowAddHour)).GT(WRAP(TimestampT(now), TimestampzT(now))), + ROW(TimestampzT(nowAddHour)).GT_EQ(ROW(TimestampzT(now))), + WRAP(TimestampT(now), TimestampzT(nowAddHour)).LT(ROW(TimestampT(now), TimestampzT(now))), + ROW(TimestampzT(nowAddHour)).LT_EQ(ROW(TimestampzT(now))), + ) + + //fmt.Println(stmt.Sql()) + //fmt.Println(stmt.DebugSql()) + + testutils.AssertStatementSql(t, stmt, ` +SELECT ROW($1::integer, $2::real, $3::text) AS "row", + ($4::bigint, $5::double precision, $6::text) AS "wrap", + ROW($7::boolean, $8::date) = ROW($9::boolean, $10::date), + ($11::boolean, $12::date) != ($13::boolean, $14::date), + ROW($15::time without time zone) IS DISTINCT FROM (row(NOW()::time)), + ROW() IS NOT DISTINCT FROM ROW(), + ROW($16::timestamp without time zone, $17::timestamp with time zone) > ($18::timestamp without time zone, $19::timestamp with time zone), + ROW($20::timestamp with time zone) >= ROW($21::timestamp with time zone), + ($22::timestamp without time zone, $23::timestamp with time zone) < ROW($24::timestamp without time zone, $25::timestamp with time zone), + ROW($26::timestamp with time zone) <= ROW($27::timestamp with time zone); +`) + + err := stmt.Query(db, &struct{}{}) + require.NoError(t, err) +} + func TestSubQueryColumnReference(t *testing.T) { type expected struct { sql string @@ -1305,32 +1406,32 @@ RETURNING all_types.json AS "all_types.json"; var moodSad = model.Mood_Sad var allTypesRow0 = model.AllTypes{ - SmallIntPtr: testutils.Int16Ptr(14), + SmallIntPtr: ptr.Of(int16(14)), SmallInt: 14, - IntegerPtr: testutils.Int32Ptr(300), + IntegerPtr: ptr.Of(int32(300)), Integer: 300, - BigIntPtr: testutils.Int64Ptr(50000), + BigIntPtr: ptr.Of(int64(50000)), BigInt: 5000, - DecimalPtr: testutils.Float64Ptr(1.11), + DecimalPtr: ptr.Of(1.11), Decimal: 1.11, - NumericPtr: testutils.Float64Ptr(2.22), + NumericPtr: ptr.Of(2.22), Numeric: 2.22, - RealPtr: testutils.Float32Ptr(5.55), + RealPtr: ptr.Of(float32(5.55)), Real: 5.55, - DoublePrecisionPtr: testutils.Float64Ptr(11111111.22), + DoublePrecisionPtr: ptr.Of(11111111.22), DoublePrecision: 11111111.22, Smallserial: 1, Serial: 1, Bigserial: 1, //MoneyPtr: nil, //Money: - VarCharPtr: testutils.StringPtr("ABBA"), + VarCharPtr: ptr.Of("ABBA"), VarChar: "ABBA", - CharPtr: testutils.StringPtr("JOHN "), + CharPtr: ptr.Of("JOHN "), Char: "JOHN ", - TextPtr: testutils.StringPtr("Some text"), + TextPtr: ptr.Of("Some text"), Text: "Some text", - ByteaPtr: testutils.ByteArrayPtr([]byte("bytea")), + ByteaPtr: ptr.Of([]byte("bytea")), Bytea: []byte("bytea"), TimestampzPtr: testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), Timestampz: *testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), @@ -1342,31 +1443,31 @@ var allTypesRow0 = model.AllTypes{ Timez: *testutils.TimeWithTimeZone("04:05:06 -0800"), TimePtr: testutils.TimeWithoutTimeZone("04:05:06"), Time: *testutils.TimeWithoutTimeZone("04:05:06"), - IntervalPtr: testutils.StringPtr("3 days 04:05:06"), + IntervalPtr: ptr.Of("3 days 04:05:06"), Interval: "3 days 04:05:06", - BooleanPtr: testutils.BoolPtr(true), + BooleanPtr: ptr.Of(true), Boolean: false, - PointPtr: testutils.StringPtr("(2,3)"), - BitPtr: testutils.StringPtr("101"), + PointPtr: ptr.Of("(2,3)"), + BitPtr: ptr.Of("101"), Bit: "101", - BitVaryingPtr: testutils.StringPtr("101111"), + BitVaryingPtr: ptr.Of("101111"), BitVarying: "101111", - TsvectorPtr: testutils.StringPtr("'supernova':1"), + TsvectorPtr: ptr.Of("'supernova':1"), Tsvector: "'supernova':1", UUIDPtr: testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), UUID: uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), - XMLPtr: testutils.StringPtr("abc"), + XMLPtr: ptr.Of("abc"), XML: "abc", - JSONPtr: testutils.StringPtr(`{"a": 1, "b": 3}`), + JSONPtr: ptr.Of(`{"a": 1, "b": 3}`), JSON: `{"a": 1, "b": 3}`, - JsonbPtr: testutils.StringPtr(`{"a": 1, "b": 3}`), + JsonbPtr: ptr.Of(`{"a": 1, "b": 3}`), Jsonb: `{"a": 1, "b": 3}`, - IntegerArrayPtr: testutils.StringPtr("{1,2,3}"), + IntegerArrayPtr: ptr.Of("{1,2,3}"), IntegerArray: "{1,2,3}", - TextArrayPtr: testutils.StringPtr("{breakfast,consulting}"), + TextArrayPtr: ptr.Of("{breakfast,consulting}"), TextArray: "{breakfast,consulting}", JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`, - TextMultiDimArrayPtr: testutils.StringPtr("{{meeting,lunch},{training,presentation}}"), + TextMultiDimArrayPtr: ptr.Of("{{meeting,lunch},{training,presentation}}"), TextMultiDimArray: "{{meeting,lunch},{training,presentation}}", MoodPtr: &moodSad, Mood: model.Mood_Happy, diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index 462e205..349a4ba 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -3,6 +3,7 @@ package postgres import ( "context" "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/internal/utils/ptr" . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook/table" @@ -455,7 +456,7 @@ FROM ( require.Len(t, dest, 275) require.Equal(t, dest[0].Artist1.Artist, model.Artist{ ArtistId: 1, - Name: testutils.StringPtr("AC/DC"), + Name: ptr.Of("AC/DC"), }) require.Equal(t, dest[0].Artist1.CustomColumn1, "custom_column_1") require.Equal(t, dest[0].Artist1.CustomColumn2, "custom_column_2") diff --git a/tests/postgres/generator_template_test.go b/tests/postgres/generator_template_test.go index 2bc2d32..65603cd 100644 --- a/tests/postgres/generator_template_test.go +++ b/tests/postgres/generator_template_test.go @@ -3,6 +3,9 @@ package postgres import ( "database/sql" "fmt" + "path" + "testing" + "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/generator/template" @@ -13,8 +16,6 @@ import ( "github.com/go-jet/jet/v2/tests/dbconfig" file2 "github.com/go-jet/jet/v2/tests/internal/utils/file" "github.com/stretchr/testify/require" - "path" - "testing" ) const tempTestDir = "./.tempTestDir" @@ -170,6 +171,7 @@ func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) { mpaaRating := file2.Exists(t, defaultModelPath, "mpaa_rating_enum.go") require.Contains(t, mpaaRating, "type MpaaRatingEnum string") + require.Contains(t, mpaaRating, "MpaaRatingEnumAllValues") } func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) { @@ -267,7 +269,6 @@ func UseSchema(schema string) { FilmList = FilmList.FromSchema(schema) } `) - } func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) { @@ -365,7 +366,6 @@ func TestGeneratorTemplate_SQLBuilder_DefaultAlias(t *testing.T) { } func TestGeneratorTemplate_Model_AddTags(t *testing.T) { - err := postgres.Generate( tempTestDir, dbConnection, diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 479d82f..87a7eee 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -2,7 +2,6 @@ package postgres import ( "fmt" - "github.com/go-jet/jet/v2/tests/internal/utils/file" "os" "os/exec" "path/filepath" @@ -12,11 +11,14 @@ import ( "github.com/stretchr/testify/require" + "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/postgres" + "github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/internal/testutils" - "github.com/go-jet/jet/v2/tests/dbconfig" - + postgres2 "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" + "github.com/go-jet/jet/v2/tests/dbconfig" + "github.com/go-jet/jet/v2/tests/internal/utils/file" ) func dsn(host string, port int, dbName, user, password string) string { @@ -208,6 +210,45 @@ func TestGenerator(t *testing.T) { require.NoError(t, err) } +func TestGenerator_TableMetadata(t *testing.T) { + var schema metadata.Schema + err := postgres.GenerateDSN(defaultDSN(), "dvds", genTestDir2, + template.Default(postgres2.Dialect).UseSchema(func(m metadata.Schema) template.Schema { + schema = m + return template.DefaultSchema(m) + })) + require.NoError(t, err) + + // Spot check the actor table and assert that the emitted + // properties are as expected. + var got metadata.Table + var specialFeatures metadata.Column + for _, table := range schema.TablesMetaData { + if table.Name == "actor" { + got = table + } + if table.Name == "film" { + for _, column := range table.Columns { + if column.Name == "special_features" { + specialFeatures = column + } + } + } + } + + want := metadata.Table{ + Name: "actor", + Columns: []metadata.Column{ + {Name: "actor_id", IsPrimaryKey: true, IsNullable: false, IsGenerated: false, HasDefault: true, DataType: metadata.DataType{Name: "int4", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "first_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "last_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "last_update", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "timestamp", Kind: "base", IsUnsigned: false}, Comment: ""}, + }, + } + require.Equal(t, want, got) + require.Equal(t, metadata.ArrayType, specialFeatures.DataType.Kind) +} + func TestGeneratorSpecialCharacters(t *testing.T) { t.SkipNow() err := postgres.Generate(genTestDir2, postgres.DBConnection{ @@ -563,6 +604,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go", "floats.go", "people.go", "components.go", "vulnerabilities.go", "all_types_materialized_view.go", "sample_ranges.go") testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent) + testutils.AssertFileContent(t, modelDir+"/link.go", linkModelContent) testutils.AssertFileNamesEqual(t, tableDir, "all_types.go", "employee.go", "link.go", "person.go", "person_phone.go", "weird_names_table.go", "user.go", "floats.go", "people.go", "table_use_schema.go", @@ -570,6 +612,8 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent) testutils.AssertFileContent(t, tableDir+"/sample_ranges.go", sampleRangeTableContent) + testutils.AssertFileContent(t, tableDir+"/link.go", linkTableContent) + testutils.AssertFileNamesEqual(t, viewDir, "all_types_materialized_view.go", "all_types_view.go", "view_use_schema.go") } @@ -609,6 +653,7 @@ package enum import "github.com/go-jet/jet/v2/postgres" +// Level enum var Level = &struct { Level1 postgres.StringExpression Level2 postgres.StringExpression @@ -706,6 +751,25 @@ type AllTypes struct { } ` +var linkModelContent = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package model + +// Link table +type Link struct { + ID int64 ` + "`sql:\"primary_key\"`" + ` // this is link id + URL string // link url + Name string // Unicode characters comment ₲鬼佬℧⇄↻ + Description *string // '"Z\%_ +} +` + var allTypesTableContent = ` // // Code generated by go-jet DO NOT EDIT. @@ -1062,3 +1126,91 @@ func newSampleRangesTableImpl(schemaName, tableName, alias string) sampleRangesT } } ` + +var linkTableContent = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package table + +import ( + "github.com/go-jet/jet/v2/postgres" +) + +var Link = newLinkTable("test_sample", "link", "") + +// Link table +type linkTable struct { + postgres.Table + + // Columns + ID postgres.ColumnInteger // this is link id + URL postgres.ColumnString // link url + Name postgres.ColumnString // Unicode characters comment ₲鬼佬℧⇄↻ + Description postgres.ColumnString // '"Z\%_ + + AllColumns postgres.ColumnList + MutableColumns postgres.ColumnList +} + +type LinkTable struct { + linkTable + + EXCLUDED linkTable +} + +// AS creates new LinkTable with assigned alias +func (a LinkTable) AS(alias string) *LinkTable { + return newLinkTable(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new LinkTable with assigned schema name +func (a LinkTable) FromSchema(schemaName string) *LinkTable { + return newLinkTable(schemaName, a.TableName(), a.Alias()) +} + +// WithPrefix creates new LinkTable with assigned table prefix +func (a LinkTable) WithPrefix(prefix string) *LinkTable { + return newLinkTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new LinkTable with assigned table suffix +func (a LinkTable) WithSuffix(suffix string) *LinkTable { + return newLinkTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + +func newLinkTable(schemaName, tableName, alias string) *LinkTable { + return &LinkTable{ + linkTable: newLinkTableImpl(schemaName, tableName, alias), + EXCLUDED: newLinkTableImpl("", "excluded", ""), + } +} + +func newLinkTableImpl(schemaName, tableName, alias string) linkTable { + var ( + IDColumn = postgres.IntegerColumn("id") + URLColumn = postgres.StringColumn("url") + NameColumn = postgres.StringColumn("name") + DescriptionColumn = postgres.StringColumn("description") + allColumns = postgres.ColumnList{IDColumn, URLColumn, NameColumn, DescriptionColumn} + mutableColumns = postgres.ColumnList{URLColumn, NameColumn, DescriptionColumn} + ) + + return linkTable{ + Table: postgres.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns + ID: IDColumn, + URL: URLColumn, + Name: NameColumn, + Description: DescriptionColumn, + + AllColumns: allColumns, + MutableColumns: mutableColumns, + } +} +` diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index 6f05a18..c375396 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -93,9 +93,9 @@ func TestInsertOnConflict(t *testing.T) { ON_CONFLICT().DO_NOTHING() testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5), - ($6, $7, $8, $9, $10) +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6), + ($7, $8, $9, $10, $11, $12) ON CONFLICT DO NOTHING; `) testutils.AssertExecAndRollback(t, stmt, db, 1) @@ -111,9 +111,9 @@ ON CONFLICT DO NOTHING; ON_CONFLICT(Employee.EmployeeID).DO_NOTHING() testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5), - ($6, $7, $8, $9, $10) +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6), + ($7, $8, $9, $10, $11, $12) ON CONFLICT (employee_id) DO NOTHING; `) testutils.AssertExecAndRollback(t, stmt, db, 1) @@ -130,9 +130,9 @@ ON CONFLICT (employee_id) DO NOTHING; ON_CONFLICT().ON_CONSTRAINT("employee_pkey").DO_NOTHING() testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5), - ($6, $7, $8, $9, $10) +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6), + ($7, $8, $9, $10, $11, $12) ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; `) testutils.AssertExecAndRollback(t, stmt, db, 1) @@ -234,8 +234,8 @@ ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE ON_CONFLICT().DO_UPDATE(nil) testutils.AssertStatementSql(t, stmt, ` -INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id) -VALUES ($1, $2, $3, $4, $5); +INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual) +VALUES ($1, $2, $3, $4, $5, $6); `) testutils.AssertExecAndRollback(t, stmt, db, 1) requireLogged(t, stmt) diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index 1251c93..1f02e69 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -5,13 +5,10 @@ import ( "database/sql" "fmt" "github.com/go-jet/jet/v2/tests/internal/utils/repo" - "math/rand" + "github.com/jackc/pgx/v4/stdlib" "os" "runtime" "testing" - "time" - - "github.com/jackc/pgx/v4/stdlib" "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/dbconfig" @@ -44,7 +41,6 @@ func skipForCockroachDB(t *testing.T) { } func TestMain(m *testing.M) { - rand.Seed(time.Now().Unix()) defer profile.Start().Stop() setTestRoot() diff --git a/tests/postgres/northwind_test.go b/tests/postgres/northwind_test.go index 3084b0d..7aad0d4 100644 --- a/tests/postgres/northwind_test.go +++ b/tests/postgres/northwind_test.go @@ -11,34 +11,31 @@ import ( func TestNorthwindJoinEverything(t *testing.T) { - stmt := SELECT( - Customers.AllColumns, - CustomerDemographics.AllColumns, - Orders.AllColumns, - Shippers.AllColumns, - OrderDetails.AllColumns, - Products.AllColumns, - Categories.AllColumns, - Suppliers.AllColumns, - ).FROM( - Customers. - LEFT_JOIN(CustomerCustomerDemo, Customers.CustomerID.EQ(CustomerCustomerDemo.CustomerID)). - LEFT_JOIN(CustomerDemographics, CustomerCustomerDemo.CustomerTypeID.EQ(CustomerDemographics.CustomerTypeID)). - LEFT_JOIN(Orders, Orders.CustomerID.EQ(Customers.CustomerID)). - LEFT_JOIN(Shippers, Orders.ShipVia.EQ(Shippers.ShipperID)). - LEFT_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID)). - LEFT_JOIN(Products, OrderDetails.ProductID.EQ(Products.ProductID)). - LEFT_JOIN(Categories, Products.CategoryID.EQ(Categories.CategoryID)). - LEFT_JOIN(Suppliers, Products.SupplierID.EQ(Suppliers.SupplierID)). - LEFT_JOIN(Employees, Orders.EmployeeID.EQ(Employees.EmployeeID)). - LEFT_JOIN(EmployeeTerritories, EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID)). - LEFT_JOIN(Territories, EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)). - LEFT_JOIN(Region, Territories.RegionID.EQ(Region.RegionID)), - ).ORDER_BY( - Customers.CustomerID, - Orders.OrderID, - Products.ProductID, - ) + stmt := + SELECT( + Customers.AllColumns, + CustomerDemographics.AllColumns, + Orders.AllColumns, + Shippers.AllColumns, + OrderDetails.AllColumns, + Products.AllColumns, + Categories.AllColumns, + Suppliers.AllColumns, + ).FROM( + Customers. + LEFT_JOIN(CustomerCustomerDemo, Customers.CustomerID.EQ(CustomerCustomerDemo.CustomerID)). + LEFT_JOIN(CustomerDemographics, CustomerCustomerDemo.CustomerTypeID.EQ(CustomerDemographics.CustomerTypeID)). + LEFT_JOIN(Orders, Orders.CustomerID.EQ(Customers.CustomerID)). + LEFT_JOIN(Shippers, Orders.ShipVia.EQ(Shippers.ShipperID)). + LEFT_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID)). + LEFT_JOIN(Products, OrderDetails.ProductID.EQ(Products.ProductID)). + LEFT_JOIN(Categories, Products.CategoryID.EQ(Categories.CategoryID)). + LEFT_JOIN(Suppliers, Products.SupplierID.EQ(Suppliers.SupplierID)). + LEFT_JOIN(Employees, Orders.EmployeeID.EQ(Employees.EmployeeID)). + LEFT_JOIN(EmployeeTerritories, EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID)). + LEFT_JOIN(Territories, EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)). + LEFT_JOIN(Region, Territories.RegionID.EQ(Region.RegionID)), + ).ORDER_BY(Customers.CustomerID, Orders.OrderID, Products.ProductID) var dest []struct { model.Customers diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 3d4d011..73b652c 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -2,6 +2,7 @@ package postgres import ( "github.com/go-jet/jet/v2/qrm" + "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/google/uuid" "testing" @@ -63,15 +64,15 @@ func TestExactDecimals(t *testing.T) { Floats: model.Floats{ // overwritten by wrapped(floats) scope Numeric: 0.1, - NumericPtr: testutils.Float64Ptr(0.1), + NumericPtr: ptr.Of(0.1), Decimal: 0.1, - DecimalPtr: testutils.Float64Ptr(0.1), + DecimalPtr: ptr.Of(0.1), // not overwritten Real: 0.4, - RealPtr: testutils.Float32Ptr(0.44), + RealPtr: ptr.Of(float32(0.44)), Double: 0.3, - DoublePtr: testutils.Float64Ptr(0.33), + DoublePtr: ptr.Of(0.33), }, Numeric: decimal.RequireFromString("0.1234567890123456789"), NumericPtr: decimal.RequireFromString("1.1111111111111111111"), @@ -330,11 +331,13 @@ SELECT employee.employee_id AS "employee.employee_id", employee.last_name AS "employee.last_name", employee.employment_date AS "employee.employment_date", employee.manager_id AS "employee.manager_id", + employee.pto_accrual AS "employee.pto_accrual", manager.employee_id AS "manager.employee_id", manager.first_name AS "manager.first_name", manager.last_name AS "manager.last_name", manager.employment_date AS "manager.employment_date", - manager.manager_id AS "manager.manager_id" + manager.manager_id AS "manager.manager_id", + manager.pto_accrual AS "manager.pto_accrual" FROM test_sample.employee LEFT JOIN test_sample.employee AS manager ON (manager.employee_id = employee.manager_id) ORDER BY employee.employee_id; @@ -369,6 +372,7 @@ ORDER BY employee.employee_id; LastName: "Hays", EmploymentDate: testutils.TimestampWithTimeZone("1999-01-08 04:05:06.1 +0100 CET", 1), ManagerID: nil, + PtoAccrual: ptr.Of("22:00:00"), }) require.True(t, dest[0].Manager == nil) @@ -378,7 +382,7 @@ ORDER BY employee.employee_id; FirstName: "Salley", LastName: "Lester", EmploymentDate: testutils.TimestampWithTimeZone("1999-01-08 04:05:06 +0100 CET", 1), - ManagerID: testutils.Int32Ptr(3), + ManagerID: ptr.Of(int32(3)), }) } @@ -420,7 +424,7 @@ FROM test_sample."WEIRD NAMES TABLE"; WeirdColumnName5: "Doe", WeirdColumnName6: "Doe", WeirdColumnName7: "Doe", - Weirdcolumnname8: testutils.StringPtr("Doe"), + Weirdcolumnname8: ptr.Of("Doe"), WeirdColName9: "Doe", WeirdColuName10: "Doe", WeirdColuName11: "Doe", @@ -518,7 +522,7 @@ func TestMutableColumnsExcludeGeneratedColumn(t *testing.T) { ).MODEL( model.People{ PeopleName: "Dario", - PeopleHeightCm: testutils.Float64Ptr(120), + PeopleHeightCm: ptr.Of(120.0), }, ).RETURNING( People.MutableColumns, diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index 24b5949..321fc38 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/volatiletech/null/v8" "testing" "time" @@ -93,7 +94,7 @@ func TestScanToValidDestination(t *testing.T) { err := oneInventoryQuery.Query(db, &dest) require.NoError(t, err) - require.Equal(t, dest[0], testutils.Int32Ptr(1)) + require.Equal(t, dest[0], ptr.Of(int32(1))) }) t.Run("NULL to integer", func(t *testing.T) { @@ -530,10 +531,10 @@ func TestScanToSlice(t *testing.T) { require.NoError(t, err) require.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0].Film, film1) - testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{testutils.Int32Ptr(1), testutils.Int32Ptr(2), testutils.Int32Ptr(3), testutils.Int32Ptr(4), - testutils.Int32Ptr(5), testutils.Int32Ptr(6), testutils.Int32Ptr(7), testutils.Int32Ptr(8)}) + testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{ptr.Of(int32(1)), ptr.Of(int32(2)), ptr.Of(int32(3)), ptr.Of(int32(4)), + ptr.Of(int32(5)), ptr.Of(int32(6)), ptr.Of(int32(7)), ptr.Of(int32(8))}) testutils.AssertDeepEqual(t, dest[1].Film, film2) - testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{testutils.Int32Ptr(9), testutils.Int32Ptr(10)}) + testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{ptr.Of(int32(9)), ptr.Of(int32(10))}) }) t.Run("complex struct 1", func(t *testing.T) { @@ -1076,10 +1077,10 @@ VALUES (1234, 0, 'Joe', '', NULL, 1, TRUE, '2020-02-02 10:00:00Z', NULL, 1); var address256 = model.Address{ AddressID: 256, Address: "1497 Yuzhou Drive", - Address2: testutils.StringPtr(""), + Address2: ptr.Of(""), District: "England", CityID: 312, - PostalCode: testutils.StringPtr("3433"), + PostalCode: ptr.Of("3433"), Phone: "246810237916", LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 09:45:30", 0), } @@ -1087,10 +1088,10 @@ var address256 = model.Address{ var addres517 = model.Address{ AddressID: 517, Address: "548 Uruapan Street", - Address2: testutils.StringPtr(""), + Address2: ptr.Of(""), District: "Ontario", CityID: 312, - PostalCode: testutils.StringPtr("35653"), + PostalCode: ptr.Of("35653"), Phone: "879347453467", LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 09:45:30", 0), } @@ -1100,12 +1101,12 @@ var customer256 = model.Customer{ StoreID: 2, FirstName: "Mattie", LastName: "Hoffman", - Email: testutils.StringPtr("mattie.hoffman@sakilacustomer.org"), + Email: ptr.Of("mattie.hoffman@sakilacustomer.org"), AddressID: 256, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 0), - Active: testutils.Int32Ptr(1), + Active: ptr.Of(int32(1)), } var customer512 = model.Customer{ @@ -1113,12 +1114,12 @@ var customer512 = model.Customer{ StoreID: 1, FirstName: "Cecil", LastName: "Vines", - Email: testutils.StringPtr("cecil.vines@sakilacustomer.org"), + Email: ptr.Of("cecil.vines@sakilacustomer.org"), AddressID: 517, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 0), - Active: testutils.Int32Ptr(1), + Active: ptr.Of(int32(1)), } var countryUk = model.Country{ @@ -1151,32 +1152,32 @@ var inventory2 = model.Inventory{ var film1 = model.Film{ FilmID: 1, Title: "Academy Dinosaur", - Description: testutils.StringPtr("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"), - ReleaseYear: testutils.Int32Ptr(2006), + Description: ptr.Of("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"), + ReleaseYear: ptr.Of(int32(2006)), LanguageID: 1, RentalDuration: 6, RentalRate: 0.99, - Length: testutils.Int16Ptr(86), + Length: ptr.Of(int16(86)), ReplacementCost: 20.99, Rating: &pgRating, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: testutils.StringPtr("{\"Deleted Scenes\",\"Behind the Scenes\"}"), + SpecialFeatures: ptr.Of("{\"Deleted Scenes\",\"Behind the Scenes\"}"), Fulltext: "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", } var film2 = model.Film{ FilmID: 2, Title: "Ace Goldfinger", - Description: testutils.StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), - ReleaseYear: testutils.Int32Ptr(2006), + Description: ptr.Of("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), + ReleaseYear: ptr.Of(int32(2006)), LanguageID: 1, RentalDuration: 3, RentalRate: 4.99, - Length: testutils.Int16Ptr(48), + Length: ptr.Of(int16(48)), ReplacementCost: 12.99, Rating: &gRating, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: testutils.StringPtr(`{Trailers,"Deleted Scenes"}`), + SpecialFeatures: ptr.Of(`{Trailers,"Deleted Scenes"}`), Fulltext: `'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14`, } diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 0bcbffd..8bc816e 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "github.com/go-jet/jet/v2/internal/utils/ptr" "testing" "time" @@ -1828,16 +1829,16 @@ ORDER BY film.film_id ASC; testutils.AssertDeepEqual(t, maxRentalRateFilms[0], model.Film{ FilmID: 2, Title: "Ace Goldfinger", - Description: testutils.StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), - ReleaseYear: testutils.Int32Ptr(2006), + Description: ptr.Of("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), + ReleaseYear: ptr.Of(int32(2006)), LanguageID: 1, RentalRate: 4.99, - Length: testutils.Int16Ptr(48), + Length: ptr.Of(int16(48)), ReplacementCost: 12.99, Rating: &gRating, RentalDuration: 3, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: testutils.StringPtr("{Trailers,\"Deleted Scenes\"}"), + SpecialFeatures: ptr.Of("{Trailers,\"Deleted Scenes\"}"), Fulltext: "'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14", }) } @@ -2286,11 +2287,11 @@ ORDER BY customer_payment_sum.amount_sum ASC; FirstName: "Brian", LastName: "Wyman", AddressID: 323, - Email: testutils.StringPtr("brian.wyman@sakilacustomer.org"), + Email: ptr.Of("brian.wyman@sakilacustomer.org"), Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.Int32Ptr(1), + Active: ptr.Of(int32(1)), }) require.Equal(t, customersWithAmounts[0].AmountSum, 27.93) @@ -3110,8 +3111,8 @@ func TestSelectDynamicCondition(t *testing.T) { Active *bool } - request.CustomerID = testutils.Int64Ptr(1) - request.Active = testutils.BoolPtr(true) + request.CustomerID = ptr.Of(int64(1)) + request.Active = ptr.Of(true) // ... @@ -3871,12 +3872,12 @@ var customer0 = model.Customer{ StoreID: 1, FirstName: "Mary", LastName: "Smith", - Email: testutils.StringPtr("mary.smith@sakilacustomer.org"), + Email: ptr.Of("mary.smith@sakilacustomer.org"), AddressID: 5, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.Int32Ptr(1), + Active: ptr.Of(int32(1)), } var customer1 = model.Customer{ @@ -3884,12 +3885,12 @@ var customer1 = model.Customer{ StoreID: 1, FirstName: "Patricia", LastName: "Johnson", - Email: testutils.StringPtr("patricia.johnson@sakilacustomer.org"), + Email: ptr.Of("patricia.johnson@sakilacustomer.org"), AddressID: 6, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.Int32Ptr(1), + Active: ptr.Of(int32(1)), } var lastCustomer = model.Customer{ @@ -3897,10 +3898,10 @@ var lastCustomer = model.Customer{ StoreID: 2, FirstName: "Austin", LastName: "Cintron", - Email: testutils.StringPtr("austin.cintron@sakilacustomer.org"), + Email: ptr.Of("austin.cintron@sakilacustomer.org"), AddressID: 605, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.Int32Ptr(1), + Active: ptr.Of(int32(1)), } diff --git a/tests/postgres/values_test.go b/tests/postgres/values_test.go new file mode 100644 index 0000000..9e89f7e --- /dev/null +++ b/tests/postgres/values_test.go @@ -0,0 +1,284 @@ +package postgres + +import ( + "database/sql" + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func TestVALUES(t *testing.T) { + + values := VALUES( + WRAP(Int32(1), Int32(2), Float32(4.666), Bool(false), String("txt")), + WRAP(Int32(11).ADD(Int32(2)), Int32(22), Float32(33.222), Bool(true), String("png")), + WRAP(Int32(11), Int32(22), Float32(33.222), Bool(true), NULL), + ).AS("values_table") + + stmt := SELECT( + values.AllColumns(), + ).FROM( + values, + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT values_table.column1 AS "column1", + values_table.column2 AS "column2", + values_table.column3 AS "column3", + values_table.column4 AS "column4", + values_table.column5 AS "column5" +FROM ( + VALUES ($1::integer, $2::integer, $3::real, $4::boolean, $5::text), + ($6::integer + $7::integer, $8::integer, $9::real, $10::boolean, $11::text), + ($12::integer, $13::integer, $14::real, $15::boolean, NULL) + ) AS values_table; +`) + + var dest []struct { + Column1 int + Column2 int + Column3 float32 + Column4 bool + Column5 *string + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "Column1": 1, + "Column2": 2, + "Column3": 4.666, + "Column4": false, + "Column5": "txt" + }, + { + "Column1": 13, + "Column2": 22, + "Column3": 33.222, + "Column4": true, + "Column5": "png" + }, + { + "Column1": 11, + "Column2": 22, + "Column3": 33.222, + "Column4": true, + "Column5": null + } +] +`) +} + +func TestVALUES_Join(t *testing.T) { + + title := StringColumn("title") + releaseYear := IntegerColumn("ReleaseYear") + rentalRate := FloatColumn("rental_rate") + + lastUpdate := Timestamp(2007, time.February, 11, 12, 0, 0) + + filmValues := VALUES( + WRAP(String("Chamber Italian"), Int64(117), Int32(2005), Float32(5.82), lastUpdate), + WRAP(String("Grosse Wonderful"), Int64(49), Int32(2004), Float32(6.242), lastUpdate.ADD(INTERVAL(1, HOUR))), + WRAP(String("Airport Pollock"), Int64(54), Int32(2001), Float32(7.22), NULL), + WRAP(String("Bright Encounters"), Int64(73), Int32(2002), Float32(8.25), NULL), + WRAP(String("Academy Dinosaur"), Int64(83), Int32(2010), Float32(9.22), lastUpdate.SUB(INTERVAL(2, MINUTE))), + ).AS("film_values", + title, IntegerColumn("length"), releaseYear, rentalRate, TimestampColumn("update_date")) + + stmt := SELECT( + Film.AllColumns, + filmValues.AllColumns(), + ).FROM( + Film. + INNER_JOIN(filmValues, title.EQ(Film.Title)), + ).WHERE(AND( + Film.ReleaseYear.GT(releaseYear), + Film.RentalRate.LT(rentalRate), + )).ORDER_BY( + title, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + film.description AS "film.description", + film.release_year AS "film.release_year", + film.language_id AS "film.language_id", + film.rental_duration AS "film.rental_duration", + film.rental_rate AS "film.rental_rate", + film.length AS "film.length", + film.replacement_cost AS "film.replacement_cost", + film.rating AS "film.rating", + film.last_update AS "film.last_update", + film.special_features AS "film.special_features", + film.fulltext AS "film.fulltext", + film_values.title AS "title", + film_values.length AS "length", + film_values."ReleaseYear" AS "ReleaseYear", + film_values.rental_rate AS "rental_rate", + film_values.update_date AS "update_date" +FROM dvds.film + INNER JOIN ( + VALUES ('Chamber Italian'::text, 117::bigint, 2005::integer, 5.820000171661377::real, '2007-02-11 12:00:00'::timestamp without time zone), + ('Grosse Wonderful'::text, 49::bigint, 2004::integer, 6.242000102996826::real, '2007-02-11 12:00:00'::timestamp without time zone + INTERVAL '1 HOUR'), + ('Airport Pollock'::text, 54::bigint, 2001::integer, 7.21999979019165::real, NULL), + ('Bright Encounters'::text, 73::bigint, 2002::integer, 8.25::real, NULL), + ('Academy Dinosaur'::text, 83::bigint, 2010::integer, 9.220000267028809::real, '2007-02-11 12:00:00'::timestamp without time zone - INTERVAL '2 MINUTE') + ) AS film_values (title, length, "ReleaseYear", rental_rate, update_date) ON (film_values.title = film.title) +WHERE ( + (film.release_year > film_values."ReleaseYear") + AND (film.rental_rate < film_values.rental_rate) + ) +ORDER BY film_values.title; +`) + + //fmt.Println(stmt.DebugSql()) + + var dest []struct { + Film model.Film + + Title string + Length int + ReleaseYear int + RentalRate float32 + UpdateDate *time.Time + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + + assert.Len(t, dest, 4) + testutils.AssertJSON(t, dest[0:2], ` +[ + { + "Film": { + "FilmID": 8, + "Title": "Airport Pollock", + "Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India", + "ReleaseYear": 2006, + "LanguageID": 1, + "RentalDuration": 6, + "RentalRate": 4.99, + "Length": 54, + "ReplacementCost": 15.99, + "Rating": "R", + "LastUpdate": "2013-05-26T14:50:58.951Z", + "SpecialFeatures": "{Trailers}", + "Fulltext": "'airport':1 'ancient':18 'confront':14 'epic':4 'girl':11 'india':19 'monkey':16 'moos':8 'must':13 'pollock':2 'tale':5" + }, + "Title": "Airport Pollock", + "Length": 54, + "ReleaseYear": 2001, + "RentalRate": 7.22, + "UpdateDate": null + }, + { + "Film": { + "FilmID": 98, + "Title": "Bright Encounters", + "Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat", + "ReleaseYear": 2006, + "LanguageID": 1, + "RentalDuration": 4, + "RentalRate": 4.99, + "Length": 73, + "ReplacementCost": 12.99, + "Rating": "PG-13", + "LastUpdate": "2013-05-26T14:50:58.951Z", + "SpecialFeatures": "{Trailers}", + "Fulltext": "'boat':20 'bright':1 'conquer':14 'encount':2 'fate':4 'feminist':11 'jet':19 'lumberjack':8 'must':13 'student':16 'yarn':5" + }, + "Title": "Bright Encounters", + "Length": 73, + "ReleaseYear": 2002, + "RentalRate": 8.25, + "UpdateDate": null + } +] +`) +} + +func TestVALUES_CTE_Update(t *testing.T) { + + paymentID := IntegerColumn("payment_ID") + increase := FloatColumn("increase") + paymentsToUpdate := CTE("values_cte", paymentID, increase) + + stmt := WITH( + paymentsToUpdate.AS( + VALUES( + WRAP(Int32(20564), Float32(1.21)), + WRAP(Int32(20567), Float32(1.02)), + WRAP(Int32(20570), Float32(1.34)), + WRAP(Int32(20573), Float32(1.72)), + ), + ), + )( + Payment.UPDATE(). + SET( + Payment.Amount.SET(Payment.Amount.MUL(CAST(increase).AS_DECIMAL())), + ). + FROM(paymentsToUpdate). + WHERE(Payment.PaymentID.EQ(paymentID)). + RETURNING(Payment.AllColumns), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +WITH values_cte ("payment_ID", increase) AS ( + VALUES (20564::integer, 1.2100000381469727::real), + (20567::integer, 1.0199999809265137::real), + (20570::integer, 1.340000033378601::real), + (20573::integer, 1.7200000286102295::real) +) +UPDATE dvds.payment +SET amount = (payment.amount * values_cte.increase::decimal) +FROM values_cte +WHERE payment.payment_id = values_cte."payment_ID" +RETURNING payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date"; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + + var payments []model.Payment + + err := stmt.Query(tx, &payments) + require.NoError(t, err) + + assert.Len(t, payments, 4) + testutils.AssertJSON(t, payments[0:2], ` +[ + { + "PaymentID": 20564, + "CustomerID": 379, + "StaffID": 2, + "RentalID": 11457, + "Amount": 4.83, + "PaymentDate": "2007-03-02T19:42:42.996577Z" + }, + { + "PaymentID": 20567, + "CustomerID": 379, + "StaffID": 2, + "RentalID": 13397, + "Amount": 8.15, + "PaymentDate": "2007-03-19T20:35:01.996577Z" + } +] +`) + }) + +} diff --git a/tests/postgres/with_test.go b/tests/postgres/with_test.go index 21fca32..92b5649 100644 --- a/tests/postgres/with_test.go +++ b/tests/postgres/with_test.go @@ -83,10 +83,10 @@ SELECT orders.ship_region AS "orders.ship_region", SUM(order_details.quantity) AS "product_sales" FROM northwind.orders INNER JOIN northwind.order_details ON (orders.order_id = order_details.order_id) -WHERE orders.ship_region IN ( +WHERE orders.ship_region IN (( SELECT top_region."orders.ship_region" AS "orders.ship_region" FROM top_region - ) + )) GROUP BY orders.ship_region, order_details.product_id ORDER BY SUM(order_details.quantity) DESC; `) @@ -157,19 +157,19 @@ func TestWithStatementDeleteAndInsert(t *testing.T) { testutils.AssertStatementSql(t, stmt, ` WITH remove_discontinued_orders AS ( DELETE FROM northwind.order_details - WHERE order_details.product_id IN ( + WHERE order_details.product_id IN (( SELECT products.product_id AS "products.product_id" FROM northwind.products WHERE products.discontinued = $1 - ) + )) RETURNING order_details.product_id AS "order_details.product_id" ),update_discontinued_price AS ( UPDATE northwind.products SET unit_price = $2 - WHERE products.product_id IN ( + WHERE products.product_id IN (( SELECT remove_discontinued_orders."order_details.product_id" AS "order_details.product_id" FROM remove_discontinued_orders - ) + )) RETURNING products.product_id AS "products.product_id", products.product_name AS "products.product_name", products.supplier_id AS "products.supplier_id", diff --git a/tests/sqlite/alltypes_test.go b/tests/sqlite/alltypes_test.go index 58a9da6..080e870 100644 --- a/tests/sqlite/alltypes_test.go +++ b/tests/sqlite/alltypes_test.go @@ -2,6 +2,7 @@ package sqlite import ( "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/internal/utils/ptr" . "github.com/go-jet/jet/v2/sqlite" "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table" @@ -153,43 +154,43 @@ func TestAllTypesInsert(t *testing.T) { var toInsert = model.AllTypes{ Boolean: false, - BooleanPtr: testutils.BoolPtr(true), + BooleanPtr: ptr.Of(true), TinyInt: 1, SmallInt: 3, MediumInt: 5, Integer: 7, BigInt: 9, - TinyIntPtr: testutils.Int8Ptr(11), - SmallIntPtr: testutils.Int16Ptr(33), - MediumIntPtr: testutils.Int32Ptr(55), - IntegerPtr: testutils.Int32Ptr(77), - BigIntPtr: testutils.Int64Ptr(99), + TinyIntPtr: ptr.Of(int8(11)), + SmallIntPtr: ptr.Of(int16(33)), + MediumIntPtr: ptr.Of(int32(55)), + IntegerPtr: ptr.Of(int32(77)), + BigIntPtr: ptr.Of(int64(99)), Decimal: 11.22, - DecimalPtr: testutils.Float64Ptr(33.44), + DecimalPtr: ptr.Of(33.44), Numeric: 55.66, - NumericPtr: testutils.Float64Ptr(77.88), + NumericPtr: ptr.Of(77.88), Float: 99.00, - FloatPtr: testutils.Float64Ptr(11.22), + FloatPtr: ptr.Of(11.22), Double: 33.44, - DoublePtr: testutils.Float64Ptr(55.66), + DoublePtr: ptr.Of(55.66), Real: 77.88, - RealPtr: testutils.Float32Ptr(99.00), + RealPtr: ptr.Of(float32(99.00)), Time: time.Date(1, 1, 1, 1, 1, 1, 10, time.UTC), - TimePtr: testutils.TimePtr(time.Date(2, 2, 2, 2, 2, 2, 200, time.UTC)), + TimePtr: ptr.Of(time.Date(2, 2, 2, 2, 2, 2, 200, time.UTC)), Date: time.Now(), - DatePtr: testutils.TimePtr(time.Now()), + DatePtr: ptr.Of(time.Now()), DateTime: time.Now(), - DateTimePtr: testutils.TimePtr(time.Now()), + DateTimePtr: ptr.Of(time.Now()), Timestamp: time.Now(), - TimestampPtr: testutils.TimePtr(time.Now()), + TimestampPtr: ptr.Of(time.Now()), Char: "abcd", - CharPtr: testutils.StringPtr("absd"), + CharPtr: ptr.Of("absd"), VarChar: "abcd", - VarCharPtr: testutils.StringPtr("absd"), + VarCharPtr: ptr.Of("absd"), Blob: []byte("large file"), - BlobPtr: testutils.ByteArrayPtr([]byte("very large file")), + BlobPtr: ptr.Of([]byte("very large file")), Text: "some text", - TextPtr: testutils.StringPtr("text"), + TextPtr: ptr.Of("text"), } func TestUUID(t *testing.T) { @@ -233,18 +234,18 @@ func TestExpressionOperators(t *testing.T) { SELECT all_types.integer IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN (?, ?)) AS "result.in", - (all_types.small_int_ptr IN ( + (all_types.small_int_ptr IN (( SELECT all_types.integer AS "all_types.integer" FROM all_types - )) AS "result.in_select", + ))) AS "result.in_select", (length(121232459)) AS "result.raw", (? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg", (? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2", (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", - (all_types.small_int_ptr NOT IN ( + (all_types.small_int_ptr NOT IN (( SELECT all_types.integer AS "all_types.integer" FROM all_types - )) AS "result.not_in_select" + ))) AS "result.not_in_select" FROM all_types LIMIT ?; `, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2)) @@ -659,7 +660,7 @@ func TestExactDecimals(t *testing.T) { // not overwritten Numeric: "6.7", - NumericPtr: testutils.StringPtr("7.7"), + NumericPtr: ptr.Of("7.7"), }, Decimal: decimal.RequireFromString("91.23"), DecimalPtr: decimal.RequireFromString("45.67"), @@ -899,3 +900,36 @@ func TestDateTimeExpressions(t *testing.T) { require.Equal(t, dest.JulianDay, 2.4551543576232754e+06) require.Equal(t, dest.StrfTime, "20:34") } + +func TestRowExpression(t *testing.T) { + date := Date(2000, 9, 9) + time := Time(11, 22, 11) + dateTime := DateTime(2008, 11, 22, 10, 12, 40) + dateTime2 := DateTime(2011, 1, 2, 5, 12, 40) + + stmt := SELECT( + ROW(Bool(false), date).EQ(ROW(Bool(true), date)), + ROW(Bool(false), time).NOT_EQ(ROW(Bool(true), time)), + ROW(time).IS_DISTINCT_FROM(RowExp(Raw("(time('now'))"))), + ROW(dateTime, dateTime2).GT(ROW(dateTime, dateTime2)), + ROW(dateTime2).GT_EQ(ROW(dateTime)), + ROW(dateTime, dateTime2).LT(ROW(dateTime, dateTime2)), + ROW(dateTime2).LT_EQ(ROW(dateTime2)), + ) + + //fmt.Println(stmt.Sql()) + //fmt.Println(stmt.DebugSql()) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT (FALSE, DATE('2000-09-09')) = (TRUE, DATE('2000-09-09')), + (FALSE, TIME('11:22:11')) != (TRUE, TIME('11:22:11')), + (TIME('11:22:11')) IS NOT ((time('now'))), + (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')) > (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')), + (DATETIME('2011-01-02 05:12:40')) >= (DATETIME('2008-11-22 10:12:40')), + (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')) < (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')), + (DATETIME('2011-01-02 05:12:40')) <= (DATETIME('2011-01-02 05:12:40')); +`) + + err := stmt.Query(db, &struct{}{}) + require.NoError(t, err) +} diff --git a/tests/sqlite/generator_test.go b/tests/sqlite/generator_test.go index b232aba..8a15568 100644 --- a/tests/sqlite/generator_test.go +++ b/tests/sqlite/generator_test.go @@ -8,8 +8,11 @@ import ( "github.com/stretchr/testify/require" + "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/sqlite" + "github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/internal/testutils" + sqlite2 "github.com/go-jet/jet/v2/sqlite" "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" "github.com/go-jet/jet/v2/tests/internal/utils/repo" ) @@ -58,6 +61,36 @@ func TestGenerator(t *testing.T) { require.NoError(t, err) } +func TestGenerator_TableMetadata(t *testing.T) { + var schema metadata.Schema + err := sqlite.GenerateDSN(testDatabaseFilePath, genDestDir, + template.Default(sqlite2.Dialect).UseSchema(func(m metadata.Schema) template.Schema { + schema = m + return template.DefaultSchema(m) + })) + require.NoError(t, err) + + // Spot check the actor table and assert that the emitted + // properties are as expected. + var got metadata.Table + for _, table := range schema.TablesMetaData { + if table.Name == "actor" { + got = table + } + } + + want := metadata.Table{ + Name: "actor", + Columns: []metadata.Column{ + {Name: "actor_id", IsPrimaryKey: true, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "INTEGER", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "first_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "VARCHAR", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "last_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "VARCHAR", Kind: "base", IsUnsigned: false}, Comment: ""}, + {Name: "last_update", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: true, DataType: metadata.DataType{Name: "TIMESTAMP", Kind: "base", IsUnsigned: false}, Comment: ""}, + }, + } + require.Equal(t, want, got) +} + func TestCmdGenerator(t *testing.T) { cmd := exec.Command("jet", "-source=SQLite", "-dsn=file://"+testDatabaseFilePath, "-path="+genDestDir) diff --git a/tests/sqlite/insert_test.go b/tests/sqlite/insert_test.go index 079ba4f..c504d90 100644 --- a/tests/sqlite/insert_test.go +++ b/tests/sqlite/insert_test.go @@ -3,6 +3,8 @@ package sqlite import ( "context" "github.com/go-jet/jet/v2/qrm" + "database/sql" + "github.com/go-jet/jet/v2/internal/utils/ptr" "math/rand" "testing" @@ -49,7 +51,7 @@ VALUES (?, ?, ?, ?), ID: 101, URL: "http://www.google.com", Name: "Google", - Description: testutils.StringPtr("Search engine"), + Description: ptr.Of("Search engine"), }) testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ ID: 102, diff --git a/tests/sqlite/main_test.go b/tests/sqlite/main_test.go index 3e06124..593d630 100644 --- a/tests/sqlite/main_test.go +++ b/tests/sqlite/main_test.go @@ -8,16 +8,11 @@ import ( "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/sqlite" "github.com/go-jet/jet/v2/tests/dbconfig" - "github.com/stretchr/testify/require" - "math/rand" - "os" - "os/exec" - "runtime" - "strings" - "testing" - "time" - "github.com/pkg/profile" + "github.com/stretchr/testify/require" + "os" + "runtime" + "testing" _ "github.com/mattn/go-sqlite3" ) @@ -27,11 +22,8 @@ var sampleDB *sqlite.DB var testRoot string func TestMain(m *testing.M) { - rand.Seed(time.Now().Unix()) defer profile.Start().Stop() - setTestRoot() - sqlDB, err := sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath) throw.OnError(err) db = sqlite.NewDB(sqlDB).WithStatementsCaching(true) @@ -59,16 +51,6 @@ func TestMain(m *testing.M) { } } -func setTestRoot() { - cmd := exec.Command("git", "rev-parse", "--show-toplevel") - byteArr, err := cmd.Output() - if err != nil { - panic(err) - } - - testRoot = strings.TrimSpace(string(byteArr)) + "/tests/" -} - var loggedSQL string var loggedSQLArgs []interface{} var loggedDebugSQL string diff --git a/tests/sqlite/sample_test.go b/tests/sqlite/sample_test.go index 1a68b4c..0655f1f 100644 --- a/tests/sqlite/sample_test.go +++ b/tests/sqlite/sample_test.go @@ -3,6 +3,7 @@ package sqlite import ( "github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/qrm" + "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/stretchr/testify/require" "testing" @@ -54,7 +55,7 @@ WHERE people.people_id = ?; ).MODEL( model.People{ PeopleName: "Dario", - PeopleHeightCm: testutils.Float64Ptr(190), + PeopleHeightCm: ptr.Of(190.0), }, ).RETURNING( People.AllColumns, diff --git a/tests/sqlite/select_test.go b/tests/sqlite/select_test.go index 57ee587..52c89b3 100644 --- a/tests/sqlite/select_test.go +++ b/tests/sqlite/select_test.go @@ -2,6 +2,7 @@ package sqlite import ( "context" + "github.com/go-jet/jet/v2/internal/utils/ptr" model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/model" "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/table" "strings" @@ -846,15 +847,15 @@ func TestSimpleView(t *testing.T) { require.Equal(t, len(dest), 10) require.Equal(t, dest[2], model.CustomerList{ - ID: testutils.Int32Ptr(3), - Name: testutils.StringPtr("LINDA WILLIAMS"), - Address: testutils.StringPtr("692 Joliet Street"), - ZipCode: testutils.StringPtr("83579"), - Phone: testutils.StringPtr(" "), - City: testutils.StringPtr("Athenai"), - Country: testutils.StringPtr("Greece"), - Notes: testutils.StringPtr("active"), - Sid: testutils.Int32Ptr(1), + ID: ptr.Of(int32(3)), + Name: ptr.Of("LINDA WILLIAMS"), + Address: ptr.Of("692 Joliet Street"), + ZipCode: ptr.Of("83579"), + Phone: ptr.Of(" "), + City: ptr.Of("Athenai"), + Country: ptr.Of("Greece"), + Notes: ptr.Of("active"), + Sid: ptr.Of(int32(1)), }) } diff --git a/tests/sqlite/values_test.go b/tests/sqlite/values_test.go new file mode 100644 index 0000000..0793397 --- /dev/null +++ b/tests/sqlite/values_test.go @@ -0,0 +1,344 @@ +package sqlite + +import ( + "database/sql" + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/stretchr/testify/require" + "strings" + "testing" + "time" + + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table" +) + +func TestVALUES(t *testing.T) { + + values := VALUES( + ROW(Int32(1), Int32(2), Float(4.666), Bool(false), String("txt")), + ROW(Int32(11).ADD(Int32(2)), Int32(22), Float(33.222), Bool(true), String("png")), + ROW(Int32(11), Int32(22), Float(33.222), Bool(true), NULL), + ).AS("values_table") + + stmt := SELECT( + values.AllColumns(), + ).FROM( + values, + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT values_table.column1 AS "column1", + values_table.column2 AS "column2", + values_table.column3 AS "column3", + values_table.column4 AS "column4", + values_table.column5 AS "column5" +FROM ( + VALUES (?, ?, ?, ?, ?), + (? + ?, ?, ?, ?, ?), + (?, ?, ?, ?, NULL) + ) AS values_table; +`) + + var dest []struct { + Column1 int + Column2 int + Column3 float32 + Column4 bool + Column5 *string + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "Column1": 1, + "Column2": 2, + "Column3": 4.666, + "Column4": false, + "Column5": "txt" + }, + { + "Column1": 13, + "Column2": 22, + "Column3": 33.222, + "Column4": true, + "Column5": "png" + }, + { + "Column1": 11, + "Column2": 22, + "Column3": 33.222, + "Column4": true, + "Column5": null + } +] +`) +} + +func TestVALUES_Join(t *testing.T) { + + lastUpdate := DateTime(2007, time.February, 11, 12, 0, 0) + + films := VALUES( + ROW(String("Chamber Italian"), Int64(117), Int32(2005), Float(5.82), lastUpdate), + ROW(String("Grosse Wonderful"), Int64(49), Int32(2004), Float(6.242), lastUpdate), + ROW(String("Airport Pollock"), Int64(54), Int32(2001), Float(7.22), NULL), + ROW(String("Bright Encounters"), Int64(73), Int32(2002), Float(8.25), NULL), + ROW(String("Academy Dinosaur"), Int64(83), Int32(2010), Float(9.22), DATETIME(lastUpdate, YEARS(2))), + ).AS("film_values") + + title := StringColumn("column1").From(films) + releaseYear := IntegerColumn("column3").From(films) + rentalRate := FloatColumn("column4").From(films) + + stmt := SELECT( + Film.AllColumns, + films.AllColumns(), + ).FROM( + Film. + INNER_JOIN(films, LOWER(title).EQ(LOWER(Film.Title))), + ).WHERE(AND( + CAST(Film.ReleaseYear).AS_INTEGER().GT(releaseYear), + Film.RentalRate.LT(rentalRate), + )).ORDER_BY( + title, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + film.description AS "film.description", + film.release_year AS "film.release_year", + film.language_id AS "film.language_id", + film.original_language_id AS "film.original_language_id", + film.rental_duration AS "film.rental_duration", + film.rental_rate AS "film.rental_rate", + film.length AS "film.length", + film.replacement_cost AS "film.replacement_cost", + film.rating AS "film.rating", + film.special_features AS "film.special_features", + film.last_update AS "film.last_update", + film_values.column1 AS "column1", + film_values.column2 AS "column2", + film_values.column3 AS "column3", + film_values.column4 AS "column4", + film_values.column5 AS "column5" +FROM film + INNER JOIN ( + VALUES ('Chamber Italian', 117, 2005, 5.82, DATETIME('2007-02-11 12:00:00')), + ('Grosse Wonderful', 49, 2004, 6.242, DATETIME('2007-02-11 12:00:00')), + ('Airport Pollock', 54, 2001, 7.22, NULL), + ('Bright Encounters', 73, 2002, 8.25, NULL), + ('Academy Dinosaur', 83, 2010, 9.22, DATETIME(DATETIME('2007-02-11 12:00:00'), '2 YEARS')) + ) AS film_values ON (LOWER(film_values.column1) = LOWER(film.title)) +WHERE ( + (CAST(film.release_year AS INTEGER) > film_values.column3) + AND (film.rental_rate < film_values.column4) + ) +ORDER BY film_values.column1; +`) + + var dest []struct { + Film model.Film + + Column1 string + Column2 int + Column3 int + Column4 float32 + Column5 *time.Time + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "Film": { + "FilmID": 8, + "Title": "AIRPORT POLLOCK", + "Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India", + "ReleaseYear": "2006", + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 6, + "RentalRate": 4.99, + "Length": 54, + "ReplacementCost": 15.99, + "Rating": "R", + "SpecialFeatures": "Trailers", + "LastUpdate": "2019-04-11T18:11:48Z" + }, + "Column1": "Airport Pollock", + "Column2": 54, + "Column3": 2001, + "Column4": 7.22, + "Column5": null + }, + { + "Film": { + "FilmID": 98, + "Title": "BRIGHT ENCOUNTERS", + "Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat", + "ReleaseYear": "2006", + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 4, + "RentalRate": 4.99, + "Length": 73, + "ReplacementCost": 12.99, + "Rating": "PG-13", + "SpecialFeatures": "Trailers", + "LastUpdate": "2019-04-11T18:11:48Z" + }, + "Column1": "Bright Encounters", + "Column2": 73, + "Column3": 2002, + "Column4": 8.25, + "Column5": null + }, + { + "Film": { + "FilmID": 133, + "Title": "CHAMBER ITALIAN", + "Description": "A Fateful Reflection of a Moose And a Husband who must Overcome a Monkey in Nigeria", + "ReleaseYear": "2006", + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 7, + "RentalRate": 4.99, + "Length": 117, + "ReplacementCost": 14.99, + "Rating": "NC-17", + "SpecialFeatures": "Trailers", + "LastUpdate": "2019-04-11T18:11:48Z" + }, + "Column1": "Chamber Italian", + "Column2": 117, + "Column3": 2005, + "Column4": 5.82, + "Column5": "2007-02-11T12:00:00Z" + }, + { + "Film": { + "FilmID": 384, + "Title": "GROSSE WONDERFUL", + "Description": "A Epic Drama of a Cat And a Explorer who must Redeem a Moose in Australia", + "ReleaseYear": "2006", + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 5, + "RentalRate": 4.99, + "Length": 49, + "ReplacementCost": 19.99, + "Rating": "R", + "SpecialFeatures": "Behind the Scenes", + "LastUpdate": "2019-04-11T18:11:48Z" + }, + "Column1": "Grosse Wonderful", + "Column2": 49, + "Column3": 2004, + "Column4": 6.242, + "Column5": "2007-02-11T12:00:00Z" + } +] +`) +} + +func TestVALUES_CTE_Update(t *testing.T) { + + paymentID := IntegerColumn("payment_ID") + increase := FloatColumn("increase") + paymentsToUpdate := CTE("values_cte", paymentID, increase) + + stmt := WITH( + paymentsToUpdate.AS( + VALUES( + ROW(Int32(204), Float(1.21)), + ROW(Int32(207), Float(1.02)), + ROW(Int32(200), Float(1.34)), + ROW(Int32(203), Float(1.72)), + ), + ), + )( + Payment.UPDATE(). + SET( + Payment.Amount.SET(Payment.Amount.MUL(increase)), + ). + FROM(paymentsToUpdate). + WHERE(Payment.PaymentID.EQ(paymentID)). + RETURNING(Payment.AllColumns), + ) + + testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(` +WITH values_cte (''payment_ID'', increase) AS ( + VALUES (?, ?), + (?, ?), + (?, ?), + (?, ?) +) +UPDATE payment +SET amount = (payment.amount * values_cte.increase) +FROM values_cte +WHERE payment.payment_id = values_cte.''payment_ID'' +RETURNING payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + payment.last_update AS "payment.last_update"; +`, "''", "`")) + + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var payments []model.Payment + + err := stmt.Query(tx, &payments) + + require.NoError(t, err) + testutils.AssertJSON(t, payments, ` +[ + { + "PaymentID": 200, + "CustomerID": 7, + "StaffID": 2, + "RentalID": 11542, + "Amount": 10.706600000000002, + "PaymentDate": "2005-08-17T00:51:32Z", + "LastUpdate": "2019-04-11T18:11:50Z" + }, + { + "PaymentID": 203, + "CustomerID": 7, + "StaffID": 2, + "RentalID": 13373, + "Amount": 5.1428, + "PaymentDate": "2005-08-19T21:23:31Z", + "LastUpdate": "2019-04-11T18:11:50Z" + }, + { + "PaymentID": 204, + "CustomerID": 7, + "StaffID": 1, + "RentalID": 13476, + "Amount": 3.6179, + "PaymentDate": "2005-08-20T01:06:04Z", + "LastUpdate": "2019-04-11T18:11:50Z" + }, + { + "PaymentID": 207, + "CustomerID": 8, + "StaffID": 2, + "RentalID": 866, + "Amount": 7.1298, + "PaymentDate": "2005-05-30T03:43:54Z", + "LastUpdate": "2019-04-11T18:11:50Z" + } +] +`) + }) + +} diff --git a/tests/sqlite/with_test.go b/tests/sqlite/with_test.go index 402df2f..485655c 100644 --- a/tests/sqlite/with_test.go +++ b/tests/sqlite/with_test.go @@ -153,10 +153,10 @@ WITH payments_to_update AS ( ) UPDATE payment SET amount = 0 -WHERE payment.payment_id IN ( +WHERE payment.payment_id IN (( SELECT payments_to_update.''payment.payment_id'' AS "payment.payment_id" FROM payments_to_update - ); + )); `, "''", "`", -1)) tx := beginDBTx(t) @@ -205,10 +205,10 @@ WITH payments_to_delete AS ( WHERE payment.amount < 0.5 ) DELETE FROM payment -WHERE payment.payment_id IN ( +WHERE payment.payment_id IN (( SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id" FROM payments_to_delete - ); + )); `, "''", "`", -1)) tx := beginDBTx(t) diff --git a/tests/testdata b/tests/testdata index 915bdc1..6a39774 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 915bdc16b723d89becc577c780949baef861a6ae +Subproject commit 6a397747d310938b41d3950d68009578180d3dd5