diff --git a/.circleci/config.yml b/.circleci/config.yml
index e21f00a..74f2bd8 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -8,7 +8,7 @@ jobs:
build_and_tests:
docker:
# specify the version
- - image: cimg/go:1.21.6
+ - image: cimg/go:1.22.8
- image: cimg/postgres:14.10
environment:
POSTGRES_USER: jet
@@ -128,17 +128,13 @@ jobs:
# to create test results report
- run:
- name: Install go-junit-report
- command: go install github.com/jstemmer/go-junit-report@latest
-
+ name: Install gotestsum
+ command: go install gotest.tools/gotestsum@latest
- run: mkdir -p $TEST_RESULTS
- # this will run all tests and exclude test files from code coverage report
- - run: |
- go test -v ./... \
- -covermode=atomic \
- -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... \
- -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml
+ - run:
+ name: Running tests
+ command: gotestsum --junitfile $TEST_RESULTS/report.xml --format testname -- -coverprofile=cover.out -covermode=atomic -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... ./...
# run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests
- run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/
@@ -163,4 +159,4 @@ workflows:
version: 2
build_and_test:
jobs:
- - build_and_tests
\ No newline at end of file
+ - build_and_tests
diff --git a/.github/workflows/code_scanner.yml b/.github/workflows/code_scanner.yml
new file mode 100644
index 0000000..d4ecc5d
--- /dev/null
+++ b/.github/workflows/code_scanner.yml
@@ -0,0 +1,45 @@
+name: Code Scanners
+on:
+ push:
+ branches:
+ - master
+ pull_request:
+ branches:
+ - master
+
+permissions:
+ contents: read
+
+env:
+ go_version: "1.22.8"
+
+
+jobs:
+ security_scanning:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout Source
+ uses: actions/checkout@v4
+ - uses: actions/setup-go@v5
+ with:
+ go-version: ${{ env.go_version }}
+ cache: true
+ - name: Setup Tools
+ run: |
+ go install github.com/securego/gosec/v2/cmd/gosec@latest
+ - name: Running Scan
+ run: gosec --exclude=G402,G304 ./...
+ lint_scanner:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout Source
+ uses: actions/checkout@v4
+ - uses: actions/setup-go@v5
+ with:
+ go-version: ${{ env.go_version }}
+ cache: true
+ - name: Setup Tools
+ run: |
+ go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
+ - name: Running Scan
+ run: golangci-lint run --timeout=30m ./...
diff --git a/.golangci.yml b/.golangci.yml
new file mode 100644
index 0000000..4eabe05
--- /dev/null
+++ b/.golangci.yml
@@ -0,0 +1,19 @@
+run:
+ # The default concurrency value is the number of available CPU.
+ concurrency: 4
+ # Timeout for analysis, e.g. 30s, 5m.
+ # Default: 1m
+ timeout: 30m
+ # Exit code when at least one issue was found.
+ # Default: 1
+ issues-exit-code: 2
+ # Include test files or not.
+ # Default: true
+ tests: false
+
+issues:
+ exclude-dirs:
+ - tests
+ exclude-files:
+ - "_test.go"
+ - "testutils.go"
diff --git a/cmd/jet/version.go b/cmd/jet/version.go
index 4241ec7..3300008 100644
--- a/cmd/jet/version.go
+++ b/cmd/jet/version.go
@@ -1,3 +1,3 @@
package main
-const version = "v2.11.0"
+const version = "v2.11.1"
diff --git a/examples/quick-start/quick-start.go b/examples/quick-start/quick-start.go
index e453f51..dc84989 100644
--- a/examples/quick-start/quick-start.go
+++ b/examples/quick-start/quick-start.go
@@ -4,7 +4,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
- "io/ioutil"
+ "os"
_ "github.com/lib/pq"
@@ -90,7 +90,7 @@ func main() {
func jsonSave(path string, v interface{}) {
jsonText, _ := json.MarshalIndent(v, "", "\t")
- err := ioutil.WriteFile(path, jsonText, 0644)
+ err := os.WriteFile(path, jsonText, 0600)
panicOnError(err)
}
diff --git a/generator/metadata/column_meta_data.go b/generator/metadata/column_meta_data.go
index 5c888a3..ecd61e2 100644
--- a/generator/metadata/column_meta_data.go
+++ b/generator/metadata/column_meta_data.go
@@ -1,29 +1,16 @@
package metadata
-import (
- "regexp"
-)
-
// Column struct
type Column struct {
Name string `sql:"primary_key"`
IsPrimaryKey bool
IsNullable bool
IsGenerated bool
+ HasDefault bool
DataType DataType
Comment string
}
-// GoLangComment returns column comment without ascii control characters
-func (c Column) GoLangComment() string {
- if c.Comment == "" {
- return ""
- }
-
- // remove ascii control characters from string
- return regexp.MustCompile(`[[:cntrl:]]+`).ReplaceAllString(c.Comment, "")
-}
-
// DataTypeKind is database type kind(base, enum, user-defined, array)
type DataTypeKind string
diff --git a/generator/metadata/enum_meta_data.go b/generator/metadata/enum_meta_data.go
index 7aea3d6..8cce596 100644
--- a/generator/metadata/enum_meta_data.go
+++ b/generator/metadata/enum_meta_data.go
@@ -2,6 +2,7 @@ package metadata
// Enum metadata struct
type Enum struct {
- Name string `sql:"primary_key"`
- Values []string
+ Name string `sql:"primary_key"`
+ Comment string
+ Values []string
}
diff --git a/generator/metadata/table_meta_data.go b/generator/metadata/table_meta_data.go
index df9514e..1d56bc0 100644
--- a/generator/metadata/table_meta_data.go
+++ b/generator/metadata/table_meta_data.go
@@ -3,6 +3,7 @@ package metadata
// Table metadata struct
type Table struct {
Name string `sql:"primary_key"`
+ Comment string
Columns []Column
}
diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go
index bd4d593..6dc2714 100644
--- a/generator/mysql/query_set.go
+++ b/generator/mysql/query_set.go
@@ -18,6 +18,7 @@ func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTyp
SELECT
t.table_name as "table.name",
col.COLUMN_NAME AS "column.Name",
+ col.COLUMN_DEFAULT IS NOT NULL as "column.HasDefault",
col.IS_NULLABLE = "YES" AS "column.IsNullable",
col.COLUMN_COMMENT AS "column.Comment",
COALESCE(pk.IsPrimaryKey, 0) AS "column.IsPrimaryKey",
diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go
index 2d8835b..fc4135a 100644
--- a/generator/postgres/query_set.go
+++ b/generator/postgres/query_set.go
@@ -14,7 +14,7 @@ type postgresQuerySet struct{}
func (p postgresQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) ([]metadata.Table, error) {
query := `
-SELECT table_name as "table.name"
+SELECT table_name as "table.name", obj_description((quote_ident(table_schema)||'.'||quote_ident(table_name))::regclass) as "table.comment"
FROM information_schema.tables
WHERE table_schema = $1 and table_type = $2
ORDER BY table_name;
@@ -57,6 +57,7 @@ func getColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]meta
query := `
select
attr.attname as "column.Name",
+ col_description(attr.attrelid, attr.attnum) as "column.Comment",
exists(
select 1
from pg_catalog.pg_index indx
@@ -64,11 +65,13 @@ select
) as "column.IsPrimaryKey",
not attr.attnotnull as "column.isNullable",
attr.attgenerated = 's' as "column.isGenerated",
- (case tp.typtype
- when 'b' then 'base'
- when 'd' then 'base'
- when 'e' then 'enum'
- when 'r' then 'range'
+ attr.atthasdef as "column.hasDefault",
+ (case
+ when tp.typtype = 'b' AND tp.typcategory <> 'A' then 'base'
+ when tp.typtype = 'b' AND tp.typcategory = 'A' then 'array'
+ when tp.typtype = 'd' then 'base'
+ when tp.typtype = 'e' then 'enum'
+ when tp.typtype = 'r' then 'range'
end) as "dataType.Kind",
(case when tp.typtype = 'd' then (select pg_type.typname from pg_catalog.pg_type where pg_type.oid = tp.typbasetype)
when tp.typcategory = 'A' then pg_catalog.format_type(attr.atttypid, attr.atttypmod)
@@ -99,6 +102,7 @@ order by
func (p postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.Enum, error) {
query := `
SELECT t.typname as "enum.name",
+ obj_description(t.oid) as "enum.comment",
e.enumlabel as "values"
FROM pg_catalog.pg_type t
JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid
diff --git a/generator/sqlite/query_set.go b/generator/sqlite/query_set.go
index 745aae4..dbf9d15 100644
--- a/generator/sqlite/query_set.go
+++ b/generator/sqlite/query_set.go
@@ -4,10 +4,11 @@ import (
"context"
"database/sql"
"fmt"
+ "strings"
+
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils/semantic"
"github.com/go-jet/jet/v2/qrm"
- "strings"
)
// sqliteQuerySet is dialect query set for SQLite
@@ -74,11 +75,12 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t
}
var columnInfos []struct {
- Name string
- Type string
- NotNull int32
- Pk int32
- Hidden int32
+ Name string
+ Type string
+ NotNull int32
+ DfltValue string
+ Pk int32
+ Hidden int32
}
_, err = qrm.Query(context.Background(), db, tableInfoQuery, []interface{}{tableName}, &columnInfos)
@@ -91,12 +93,14 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t
for _, columnInfo := range columnInfos {
columnType := strings.TrimSuffix(getColumnType(columnInfo.Type), " GENERATED ALWAYS")
isGenerated := columnInfo.Hidden == 2 || columnInfo.Hidden == 3 // stored or virtual column
+ hasDefault := columnInfo.DfltValue != ""
columns = append(columns, metadata.Column{
Name: columnInfo.Name,
IsPrimaryKey: columnInfo.Pk != 0,
IsNullable: columnInfo.NotNull != 1,
IsGenerated: isGenerated,
+ HasDefault: hasDefault,
DataType: metadata.DataType{
Name: columnType,
Kind: metadata.BaseType,
diff --git a/generator/template/file_templates.go b/generator/template/file_templates.go
index 731b1af..1538031 100644
--- a/generator/template/file_templates.go
+++ b/generator/template/file_templates.go
@@ -26,13 +26,14 @@ import (
var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "{{tableTemplate.DefaultAlias}}")
+{{golangComment .Comment}}
type {{structImplName}} struct {
{{dialect.PackageName}}.Table
// Columns
{{- range $i, $c := .Columns}}
{{- $field := columnField $c}}
- {{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}} {{- if $c.Comment }} // {{$c.GoLangComment}} {{end}}
+ {{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}} {{golangComment .Comment}}
{{- end}}
AllColumns {{dialect.PackageName}}.ColumnList
@@ -119,10 +120,11 @@ import (
{{end}}
{{$modelTableTemplate := tableTemplate}}
+{{golangComment .Comment}}
type {{$modelTableTemplate.TypeName}} struct {
{{- range .Columns}}
{{- $field := structField .}}
- {{$field.Name}} {{$field.Type.Name}} ` + "{{$field.TagsString}}" + ` {{- if .Comment }} // {{.GoLangComment}} {{end}}
+ {{$field.Name}} {{$field.Type.Name}} ` + "{{$field.TagsString}}" + ` {{golangComment .Comment}}
{{- end}}
}
@@ -132,6 +134,7 @@ var enumSQLBuilderTemplate = `package {{package}}
import "github.com/go-jet/jet/v2/{{dialect.PackageName}}"
+{{golangComment .Comment}}
var {{enumTemplate.InstanceName}} = &struct {
{{- range $index, $value := .Values}}
{{enumValueName $value}} {{dialect.PackageName}}.StringExpression
@@ -148,6 +151,7 @@ var enumModelTemplate = `package {{package}}
import "errors"
+{{golangComment .Comment}}
type {{$enumTemplate.TypeName}} string
const (
@@ -156,6 +160,12 @@ const (
{{- end}}
)
+var {{$enumTemplate.TypeName}}AllValues = []{{$enumTemplate.TypeName}} {
+{{- range $_, $value := .Values}}
+ {{valueName $value}},
+{{- end}}
+}
+
func (e *{{$enumTemplate.TypeName}}) Scan(value interface{}) error {
var enumValue string
switch val := value.(type) {
diff --git a/generator/template/format.go b/generator/template/format.go
new file mode 100644
index 0000000..bd88fb1
--- /dev/null
+++ b/generator/template/format.go
@@ -0,0 +1,13 @@
+package template
+
+import "regexp"
+
+// Returns the provided string as golang comment without ascii control characters
+func formatGolangComment(comment string) string {
+ if len(comment) == 0 {
+ return ""
+ }
+
+ // Format as colang comment and remove ascii control characters from string
+ return "// " + regexp.MustCompile(`[[:cntrl:]]+`).ReplaceAllString(comment, "")
+}
diff --git a/generator/template/format_test.go b/generator/template/format_test.go
new file mode 100644
index 0000000..b43b61d
--- /dev/null
+++ b/generator/template/format_test.go
@@ -0,0 +1,27 @@
+package template
+
+import "testing"
+
+func Test_formatGolangComment(t *testing.T) {
+ type args struct {
+ comment string
+ }
+ tests := []struct {
+ name string
+ args args
+ want string
+ }{
+ {name: "Empty string", args: args{comment: ""}, want: ""},
+ {name: "Non-empty string", args: args{comment: "This is a comment"}, want: "// This is a comment"},
+ {name: "String with control characters", args: args{comment: "This is a comment with control characters \x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f and text after"}, want: "// This is a comment with control characters and text after"},
+ {name: "String with escape characters", args: args{comment: "This is a comment with escape characters \n\r\t and text after"}, want: "// This is a comment with escape characters and text after"},
+ {name: "String with unicode characters", args: args{comment: "This is a comment with unicode characters ₲鬼佬℧⇄↻ and text after"}, want: "// This is a comment with unicode characters ₲鬼佬℧⇄↻ and text after"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := formatGolangComment(tt.args.comment); got != tt.want {
+ t.Errorf("formatGoLangComment() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/generator/template/process.go b/generator/template/process.go
index 3f3798a..5abef87 100644
--- a/generator/template/process.go
+++ b/generator/template/process.go
@@ -140,6 +140,7 @@ func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []
"enumValueName": func(enumValue string) string {
return enumTemplate.ValueName(enumValue)
},
+ "golangComment": formatGolangComment,
})
if err != nil {
return fmt.Errorf("failed to generete enum type %s: %w", enumTemplate.FileName, err)
@@ -215,6 +216,7 @@ func processTableSQLBuilder(fileTypes, dirPath string,
"insertedRowAlias": func() string {
return insertedRowAlias(dialect)
},
+ "golangComment": formatGolangComment,
})
if err != nil {
return fmt.Errorf("failed to generate table sql builder type %s: %w", tableSQLBuilder.TypeName, err)
@@ -307,6 +309,7 @@ func processTableModels(fileTypes, modelDirPath string, tablesMetaData []metadat
"structField": func(columnMetaData metadata.Column) TableModelField {
return tableTemplate.Field(columnMetaData)
},
+ "golangComment": formatGolangComment,
})
if err != nil {
return fmt.Errorf("failed to generate model type '%s': %w", tableMetaData.Name, err)
@@ -347,6 +350,7 @@ func processEnumModels(modelDir string, enumsMetaData []metadata.Enum, modelTemp
"valueName": func(value string) string {
return enumTemplate.ValueName(value)
},
+ "golangComment": formatGolangComment,
})
if err != nil {
diff --git a/go.mod b/go.mod
index cb204af..890dc0d 100644
--- a/go.mod
+++ b/go.mod
@@ -4,16 +4,16 @@ go 1.18
require (
github.com/go-sql-driver/mysql v1.8.1
+ github.com/google/go-cmp v0.6.0
github.com/google/uuid v1.6.0
github.com/jackc/pgconn v1.14.3
+ github.com/jackc/pgtype v1.14.3
+ github.com/jackc/pgx/v4 v4.18.3
github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.17
)
require (
- github.com/google/go-cmp v0.6.0
- github.com/jackc/pgtype v1.14.3
- github.com/jackc/pgx/v4 v4.18.3
github.com/pkg/profile v1.7.0
github.com/shopspring/decimal v1.4.0
github.com/stretchr/testify v1.9.0
diff --git a/internal/jet/column_assigment.go b/internal/jet/column_assigment.go
index 440f3eb..c888433 100644
--- a/internal/jet/column_assigment.go
+++ b/internal/jet/column_assigment.go
@@ -11,6 +11,13 @@ type columnAssigmentImpl struct {
expression Expression
}
+func NewColumnAssignment(serializer ColumnSerializer, expression Expression) ColumnAssigment {
+ return &columnAssigmentImpl{
+ column: serializer,
+ expression: expression,
+ }
+}
+
func (a columnAssigmentImpl) isColumnAssigment() {}
func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go
index f3ad2b4..68c4c02 100644
--- a/internal/jet/dialect.go
+++ b/internal/jet/dialect.go
@@ -13,6 +13,7 @@ type Dialect interface {
ArgumentPlaceholder() QueryPlaceholderFunc
IsReservedWord(name string) bool
SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
+ ValuesDefaultColumnName(index int) string
}
// SerializerFunc func
@@ -35,6 +36,7 @@ type DialectParams struct {
ArgumentPlaceholder QueryPlaceholderFunc
ReservedWords []string
SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
+ ValuesDefaultColumnName func(index int) string
}
// NewDialect creates new dialect with params
@@ -49,6 +51,7 @@ func NewDialect(params DialectParams) Dialect {
argumentPlaceholder: params.ArgumentPlaceholder,
reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords),
serializeOrderBy: params.SerializeOrderBy,
+ valuesDefaultColumnName: params.ValuesDefaultColumnName,
}
}
@@ -62,6 +65,7 @@ type dialectImpl struct {
argumentPlaceholder QueryPlaceholderFunc
reservedWords map[string]bool
serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
+ valuesDefaultColumnName func(index int) string
}
func (d *dialectImpl) Name() string {
@@ -107,6 +111,10 @@ func (d *dialectImpl) SerializeOrderBy() func(expression Expression, ascending,
return d.serializeOrderBy
}
+func (d *dialectImpl) ValuesDefaultColumnName(index int) string {
+ return d.valuesDefaultColumnName(index)
+}
+
func arrayOfStringsToMapOfStrings(arr []string) map[string]bool {
ret := map[string]bool{}
for _, elem := range arr {
diff --git a/internal/jet/expression.go b/internal/jet/expression.go
index 05b1797..9999803 100644
--- a/internal/jet/expression.go
+++ b/internal/jet/expression.go
@@ -51,12 +51,12 @@ func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression {
// IN checks if this expressions matches any in expressions list
func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression {
- return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "IN")
+ return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "IN")
}
// NOT_IN checks if this expressions is different of all expressions in expressions list
func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression {
- return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "NOT IN")
+ return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "NOT IN")
}
// AS the temporary alias name to assign to the expression
@@ -316,15 +316,6 @@ func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder,
}
}
-type skipParenthesisWrap struct {
- Expression
-}
-
-func skipWrap(expression Expression) Expression {
- return &skipParenthesisWrap{expression}
-}
-
-// since the expression is a function parameter, there is no need to wrap it in parentheses
-func (s *skipParenthesisWrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
- s.Expression.serialize(statement, out, append(options, NoWrap)...)
+func wrap(expressions ...Expression) Expression {
+ return NewFunc("", expressions, nil)
}
diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go
index 7e49880..ddc579e 100644
--- a/internal/jet/func_expression.go
+++ b/internal/jet/func_expression.go
@@ -12,11 +12,6 @@ func OR(expressions ...BoolExpression) BoolExpression {
return newBoolExpressionListOperator("OR", expressions...)
}
-// ROW function is used to create a tuple value that consists of a set of expressions or column values.
-func ROW(expressions ...Expression) Expression {
- return NewFunc("ROW", expressions, nil)
-}
-
// ------------------ Mathematical functions ---------------//
// ABSf calculates absolute value from float expression
@@ -711,7 +706,7 @@ func (p parametersSerializer) serialize(statement StatementType, out *SQLBuilder
if _, isStatement := expression.(Statement); isStatement {
expression.serialize(statement, out, options...)
} else {
- skipWrap(expression).serialize(statement, out, options...)
+ expression.serialize(statement, out, append(options, NoWrap, Ident)...)
}
}
}
diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go
index d6f0b41..251d3ab 100644
--- a/internal/jet/literal_expression.go
+++ b/internal/jet/literal_expression.go
@@ -374,32 +374,6 @@ func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, option
//---------------------------------------------------//
-type wrap struct {
- ExpressionInterfaceImpl
- expressions []Expression
-}
-
-func (n *wrap) serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
- out.WriteString("(")
-
- if len(n.expressions) == 1 {
- options = append(options, NoWrap, Ident)
- }
- serializeExpressionList(statementType, n.expressions, ", ", out, options...)
-
- out.WriteString(")")
-}
-
-// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... )
-func WRAP(expression ...Expression) Expression {
- wrap := &wrap{expressions: expression}
- wrap.ExpressionInterfaceImpl.Parent = wrap
-
- return wrap
-}
-
-//---------------------------------------------------//
-
type rawExpression struct {
ExpressionInterfaceImpl
diff --git a/internal/jet/order_set_aggregate_functions.go b/internal/jet/order_set_aggregate_functions.go
index 8ce5d1e..c853845 100644
--- a/internal/jet/order_set_aggregate_functions.go
+++ b/internal/jet/order_set_aggregate_functions.go
@@ -54,7 +54,12 @@ type orderSetAggregateFuncExpression struct {
func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString(p.name)
- WRAP(p.fraction).serialize(statement, out, FallTrough(options)...)
+
+ if p.fraction != nil {
+ wrap(p.fraction).serialize(statement, out, FallTrough(options)...)
+ } else {
+ wrap().serialize(statement, out, FallTrough(options)...)
+ }
out.WriteString("WITHIN GROUP")
p.orderBy.serialize(statement, out)
}
diff --git a/internal/jet/projection_test.go b/internal/jet/projection_test.go
index 0370b43..61dd920 100644
--- a/internal/jet/projection_test.go
+++ b/internal/jet/projection_test.go
@@ -39,7 +39,7 @@ AVG(table1.col_int) AS "avg",
table2.col3 AS "col3",
table2.col4 AS "col4"`)
- subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery"))
+ subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery", nil))
assertProjectionSerialize(t, subQueryProjections,
`"subQuery"."table1.col3" AS "table1.col3",
diff --git a/internal/jet/raw_statement.go b/internal/jet/raw_statement.go
index 191c7b4..99fb8eb 100644
--- a/internal/jet/raw_statement.go
+++ b/internal/jet/raw_statement.go
@@ -8,7 +8,7 @@ type rawStatementImpl struct {
}
// RawStatement creates new sql statements from raw query and optional map of named arguments
-func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) Statement {
+func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) SerializerStatement {
newRawStatement := rawStatementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
dialect: dialect,
diff --git a/internal/jet/row_expression.go b/internal/jet/row_expression.go
new file mode 100644
index 0000000..e7d5ed5
--- /dev/null
+++ b/internal/jet/row_expression.go
@@ -0,0 +1,102 @@
+package jet
+
+// RowExpression interface
+type RowExpression interface {
+ Expression
+ HasProjections
+
+ EQ(rhs RowExpression) BoolExpression
+ NOT_EQ(rhs RowExpression) BoolExpression
+ IS_DISTINCT_FROM(rhs RowExpression) BoolExpression
+ IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression
+
+ LT(rhs RowExpression) BoolExpression
+ LT_EQ(rhs RowExpression) BoolExpression
+ GT(rhs RowExpression) BoolExpression
+ GT_EQ(rhs RowExpression) BoolExpression
+}
+
+type rowInterfaceImpl struct {
+ parent Expression
+ dialect Dialect
+ elemCount int
+}
+
+func (n *rowInterfaceImpl) EQ(rhs RowExpression) BoolExpression {
+ return Eq(n.parent, rhs)
+}
+
+func (n *rowInterfaceImpl) NOT_EQ(rhs RowExpression) BoolExpression {
+ return NotEq(n.parent, rhs)
+}
+
+func (n *rowInterfaceImpl) IS_DISTINCT_FROM(rhs RowExpression) BoolExpression {
+ return IsDistinctFrom(n.parent, rhs)
+}
+
+func (n *rowInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression {
+ return IsNotDistinctFrom(n.parent, rhs)
+}
+
+func (n *rowInterfaceImpl) GT(rhs RowExpression) BoolExpression {
+ return Gt(n.parent, rhs)
+}
+
+func (n *rowInterfaceImpl) GT_EQ(rhs RowExpression) BoolExpression {
+ return GtEq(n.parent, rhs)
+}
+
+func (n *rowInterfaceImpl) LT(rhs RowExpression) BoolExpression {
+ return Lt(n.parent, rhs)
+}
+
+func (n *rowInterfaceImpl) LT_EQ(rhs RowExpression) BoolExpression {
+ return LtEq(n.parent, rhs)
+}
+
+func (n *rowInterfaceImpl) projections() ProjectionList {
+ var ret ProjectionList
+
+ for i := 0; i < n.elemCount; i++ {
+ rowColumn := NewColumnImpl(n.dialect.ValuesDefaultColumnName(i), "", nil)
+ ret = append(ret, &rowColumn)
+ }
+
+ return ret
+}
+
+// ---------------------------------------------------//
+type rowExpressionWrapper struct {
+ rowInterfaceImpl
+ Expression
+}
+
+func newRowExpression(name string, dialect Dialect, expressions ...Expression) RowExpression {
+ ret := &rowExpressionWrapper{}
+ ret.rowInterfaceImpl.parent = ret
+
+ ret.Expression = NewFunc(name, expressions, ret)
+ ret.dialect = dialect
+ ret.elemCount = len(expressions)
+
+ return ret
+}
+
+// ROW function is used to create a tuple value that consists of a set of expressions or column values.
+func ROW(dialect Dialect, expressions ...Expression) RowExpression {
+ return newRowExpression("ROW", dialect, expressions...)
+}
+
+// WRAP creates row expressions without ROW keyword `( expression1, expression2, ... )`.
+func WRAP(dialect Dialect, expressions ...Expression) RowExpression {
+ return newRowExpression("", dialect, expressions...)
+}
+
+// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression.
+// This enables the Go compiler to interpret any expression as a row expression
+// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation.
+func RowExp(expression Expression) RowExpression {
+ rowExpressionWrap := rowExpressionWrapper{Expression: expression}
+ rowExpressionWrap.rowInterfaceImpl.parent = &rowExpressionWrap
+ return &rowExpressionWrap
+}
diff --git a/internal/jet/select_table.go b/internal/jet/select_table.go
index c25fba3..f1f58ac 100644
--- a/internal/jet/select_table.go
+++ b/internal/jet/select_table.go
@@ -8,15 +8,21 @@ type SelectTable interface {
}
type selectTableImpl struct {
- Statement SerializerHasProjections
- alias string
+ Statement SerializerHasProjections
+ alias string
+ columnAliases []ColumnExpression
}
// NewSelectTable func
-func NewSelectTable(selectStmt SerializerHasProjections, alias string) selectTableImpl {
+func NewSelectTable(selectStmt SerializerHasProjections, alias string, columnAliases []ColumnExpression) selectTableImpl {
selectTable := selectTableImpl{
- Statement: selectStmt,
- alias: alias,
+ Statement: selectStmt,
+ alias: alias,
+ columnAliases: columnAliases,
+ }
+
+ for _, column := range selectTable.columnAliases {
+ column.setSubQuery(selectTable)
}
return selectTable
@@ -31,6 +37,10 @@ func (s selectTableImpl) Alias() string {
}
func (s selectTableImpl) AllColumns() ProjectionList {
+ if len(s.columnAliases) > 0 {
+ return ColumnListToProjectionList(s.columnAliases)
+ }
+
projectionList := s.projections().fromImpl(s)
return projectionList.(ProjectionList)
}
@@ -40,6 +50,12 @@ func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, opt
out.WriteString("AS")
out.WriteIdentifier(s.alias)
+
+ if len(s.columnAliases) > 0 {
+ out.WriteByte('(')
+ SerializeColumnExpressionNames(s.columnAliases, out)
+ out.WriteByte(')')
+ }
}
// --------------------------------------
@@ -50,7 +66,7 @@ type lateralImpl struct {
// NewLateral creates new lateral expression from select statement with alias
func NewLateral(selectStmt SerializerStatement, alias string) SelectTable {
- return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias)}
+ return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias, nil)}
}
func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
diff --git a/internal/jet/statement.go b/internal/jet/statement.go
index 6703154..2ca229d 100644
--- a/internal/jet/statement.go
+++ b/internal/jet/statement.go
@@ -180,7 +180,7 @@ func duration(f func()) time.Duration {
f()
- return time.Now().Sub(start)
+ return time.Since(start)
}
// ExpressionStatement interfacess
diff --git a/internal/jet/values.go b/internal/jet/values.go
new file mode 100644
index 0000000..9d0f691
--- /dev/null
+++ b/internal/jet/values.go
@@ -0,0 +1,35 @@
+package jet
+
+// Values hold a set of one or more rows
+type Values []RowExpression
+
+func (v Values) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
+ out.WriteByte('(')
+ out.IncreaseIdent(5)
+
+ out.NewLine()
+ out.WriteString("VALUES")
+
+ for rowIndex, row := range v {
+ if rowIndex > 0 {
+ out.WriteString(",")
+ out.NewLine()
+ } else {
+ out.IncreaseIdent(7)
+ }
+
+ row.serialize(statement, out, options...)
+ }
+ out.DecreaseIdent(7)
+ out.DecreaseIdent(5)
+ out.NewLine()
+ out.WriteByte(')')
+}
+
+func (v Values) projections() ProjectionList {
+ if len(v) == 0 {
+ return nil
+ }
+
+ return v[0].projections()
+}
diff --git a/internal/jet/with_statement.go b/internal/jet/with_statement.go
index 783fa27..03330e6 100644
--- a/internal/jet/with_statement.go
+++ b/internal/jet/with_statement.go
@@ -64,7 +64,7 @@ type CommonTableExpression struct {
// CTE creates new named CommonTableExpression
func CTE(name string, columns ...ColumnExpression) CommonTableExpression {
cte := CommonTableExpression{
- selectTableImpl: NewSelectTable(nil, name),
+ selectTableImpl: NewSelectTable(nil, name, columns),
Columns: columns,
}
@@ -99,12 +99,3 @@ func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilde
out.WriteIdentifier(c.alias)
}
}
-
-// AllColumns returns list of all projections in the CTE
-func (c CommonTableExpression) AllColumns() ProjectionList {
- if len(c.Columns) > 0 {
- return ColumnListToProjectionList(c.Columns)
- }
-
- return c.selectTableImpl.AllColumns()
-}
diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go
index 3e03506..a4a06ee 100644
--- a/internal/testutils/test_utils.go
+++ b/internal/testutils/test_utils.go
@@ -9,17 +9,15 @@ import (
jet2 "github.com/go-jet/jet/v2/internal/jet/db"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm"
+ "github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "io/ioutil"
"os"
"path/filepath"
"runtime"
"testing"
"time"
-
- "github.com/google/go-cmp/cmp"
)
// UnixTimeComparer will compare time equality while ignoring time zone
@@ -105,11 +103,12 @@ func AssertJSON(t *testing.T, data interface{}, expectedJSON string) {
}
// SaveJSONFile saves v as json at testRelativePath
+// nolint:unused
func SaveJSONFile(v interface{}, testRelativePath string) {
jsonText, _ := json.MarshalIndent(v, "", "\t")
filePath := getFullPath(testRelativePath)
- err := ioutil.WriteFile(filePath, jsonText, 0644)
+ err := os.WriteFile(filePath, jsonText, 0600)
throw.OnError(err)
}
@@ -118,7 +117,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) {
func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) {
filePath := getFullPath(testRelativePath)
- fileJSONData, err := ioutil.ReadFile(filePath)
+ fileJSONData, err := os.ReadFile(filePath) // #nosec G304
require.NoError(t, err)
if runtime.GOOS == "windows" {
@@ -245,7 +244,7 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest inter
// AssertFileContent check if file content at filePath contains expectedContent text.
func AssertFileContent(t *testing.T, filePath string, expectedContent string) {
- enumFileData, err := ioutil.ReadFile(filePath)
+ enumFileData, err := os.ReadFile(filePath) // #nosec G304
require.NoError(t, err)
@@ -254,7 +253,7 @@ func AssertFileContent(t *testing.T, filePath string, expectedContent string) {
// AssertFileNamesEqual check if all filesInfos are contained in fileNames
func AssertFileNamesEqual(t *testing.T, dirPath string, fileNames ...string) {
- files, err := ioutil.ReadDir(dirPath)
+ files, err := os.ReadDir(dirPath)
require.NoError(t, err)
require.Equal(t, len(files), len(fileNames))
@@ -293,76 +292,6 @@ func printDiff(actual, expected interface{}, options ...cmp.Option) {
fmt.Println(expected)
}
-// BoolPtr returns address of bool parameter
-func BoolPtr(b bool) *bool {
- return &b
-}
-
-// Int8Ptr returns address of int8 parameter
-func Int8Ptr(i int8) *int8 {
- return &i
-}
-
-// UInt8Ptr returns address of uint8 parameter
-func UInt8Ptr(i uint8) *uint8 {
- return &i
-}
-
-// Int16Ptr returns address of int16 parameter
-func Int16Ptr(i int16) *int16 {
- return &i
-}
-
-// UInt16Ptr returns address of uint16 parameter
-func UInt16Ptr(i uint16) *uint16 {
- return &i
-}
-
-// Int32Ptr returns address of int32 parameter
-func Int32Ptr(i int32) *int32 {
- return &i
-}
-
-// UInt32Ptr returns address of uint32 parameter
-func UInt32Ptr(i uint32) *uint32 {
- return &i
-}
-
-// Int64Ptr returns address of int64 parameter
-func Int64Ptr(i int64) *int64 {
- return &i
-}
-
-// UInt64Ptr returns address of uint64 parameter
-func UInt64Ptr(i uint64) *uint64 {
- return &i
-}
-
-// StringPtr returns address of string parameter
-func StringPtr(s string) *string {
- return &s
-}
-
-// TimePtr returns address of time.Time parameter
-func TimePtr(t time.Time) *time.Time {
- return &t
-}
-
-// ByteArrayPtr returns address of []byte parameter
-func ByteArrayPtr(arr []byte) *[]byte {
- return &arr
-}
-
-// Float32Ptr returns address of float32 parameter
-func Float32Ptr(f float32) *float32 {
- return &f
-}
-
-// Float64Ptr returns address of float64 parameter
-func Float64Ptr(f float64) *float64 {
- return &f
-}
-
// UUIDPtr returns address of uuid.UUID
func UUIDPtr(u string) *uuid.UUID {
newUUID := uuid.MustParse(u)
diff --git a/internal/utils/filesys/filesys.go b/internal/utils/filesys/filesys.go
index d904be8..0724470 100644
--- a/internal/utils/filesys/filesys.go
+++ b/internal/utils/filesys/filesys.go
@@ -1,6 +1,7 @@
package filesys
import (
+ "errors"
"fmt"
"go/format"
"os"
@@ -16,7 +17,7 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error {
newGoFilePath += ".go"
}
- file, err := os.Create(newGoFilePath)
+ file, err := os.Create(newGoFilePath) // #nosec 304
if err != nil {
return err
@@ -28,7 +29,10 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error {
// if there is a format error we will write unformulated text for debug purposes
if err != nil {
- file.Write(text)
+ _, writeErr := file.Write(text)
+ if writeErr != nil {
+ return errors.Join(writeErr, fmt.Errorf("failed to format '%s', check '%s' for syntax errors: %w", fileName, newGoFilePath, err))
+ }
return fmt.Errorf("failed to format '%s', check '%s' for syntax errors: %w", fileName, newGoFilePath, err)
}
@@ -43,7 +47,7 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error {
// EnsureDirPathExist ensures dir path exists. If path does not exist, creates new path.
func EnsureDirPathExist(dirPath string) error {
if _, err := os.Stat(dirPath); os.IsNotExist(err) {
- err := os.MkdirAll(dirPath, os.ModePerm)
+ err := os.MkdirAll(dirPath, 0o750)
if err != nil {
return fmt.Errorf("can't create directory - %s: %w", dirPath, err)
diff --git a/internal/utils/ptr/ptr.go b/internal/utils/ptr/ptr.go
new file mode 100644
index 0000000..26f2688
--- /dev/null
+++ b/internal/utils/ptr/ptr.go
@@ -0,0 +1,6 @@
+package ptr
+
+// Of returns the address of any given parameter
+func Of[T any](value T) *T {
+ return &value
+}
diff --git a/mysql/cast.go b/mysql/cast.go
index 83a0578..ca647a5 100644
--- a/mysql/cast.go
+++ b/mysql/cast.go
@@ -6,23 +6,27 @@ import (
)
type cast interface {
- // Cast expressions as castType type
+ // AS casts expressions as castType type
AS(castType string) Expression
- // Cast expression as char with optional length
+ // AS_CHAR casts expression as char with optional length
AS_CHAR(length ...int) StringExpression
- // Cast expression AS date type
+ // AS_DATE casts expression AS date type
AS_DATE() DateExpression
- // Cast expression AS numeric type, using precision and optionally scale
+ // AS_FLOAT casts expressions as float type
+ AS_FLOAT() FloatExpression
+ // AS_DOUBLE casts expressions as double type
+ AS_DOUBLE() FloatExpression
+ // AS_DECIMAL casts expression AS numeric type
AS_DECIMAL() FloatExpression
- // Cast expression AS time type
+ // AS_TIME casts expression AS time type
AS_TIME() TimeExpression
- // Cast expression as datetime type
+ // AS_DATETIME casts expression as datetime type
AS_DATETIME() DateTimeExpression
- // Cast expressions as signed integer type
+ // AS_SIGNED casts expressions as signed integer type
AS_SIGNED() IntegerExpression
- // Cast expression as unsigned integer type
+ // AS_UNSIGNED casts expression as unsigned integer type
AS_UNSIGNED() IntegerExpression
- // Cast expression as binary type
+ // AS_BINARY casts expression as binary type
AS_BINARY() StringExpression
}
@@ -73,6 +77,14 @@ func (c *castImpl) AS_DATE() DateExpression {
return DateExp(c.AS("DATE"))
}
+func (c *castImpl) AS_FLOAT() FloatExpression {
+ return FloatExp(c.AS("FLOAT"))
+}
+
+func (c *castImpl) AS_DOUBLE() FloatExpression {
+ return FloatExp(c.AS("DOUBLE"))
+}
+
// AS_DECIMAL casts expression AS DECIMAL type
func (c *castImpl) AS_DECIMAL() FloatExpression {
return FloatExp(c.AS("DECIMAL"))
diff --git a/mysql/dialect.go b/mysql/dialect.go
index 18d2eec..9628bfb 100644
--- a/mysql/dialect.go
+++ b/mysql/dialect.go
@@ -1,6 +1,7 @@
package mysql
import (
+ "fmt"
"github.com/go-jet/jet/v2/internal/jet"
)
@@ -28,6 +29,9 @@ func newDialect() jet.Dialect {
},
ReservedWords: reservedWords,
SerializeOrderBy: serializeOrderBy,
+ ValuesDefaultColumnName: func(index int) string {
+ return fmt.Sprintf("column_%d", index)
+ },
}
return jet.NewDialect(mySQLDialectParams)
diff --git a/mysql/expressions.go b/mysql/expressions.go
index 53b1fa7..4073ef5 100644
--- a/mysql/expressions.go
+++ b/mysql/expressions.go
@@ -30,6 +30,9 @@ type DateTimeExpression = jet.TimestampExpression
// TimestampExpression interface
type TimestampExpression = jet.TimestampExpression
+// RowExpression interface
+type RowExpression = jet.RowExpression
+
// BoolExp is bool expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as bool expression.
// Does not add sql cast to generated sql builder output.
@@ -70,6 +73,11 @@ var DateTimeExp = jet.TimestampExp
// Does not add sql cast to generated sql builder output.
var TimestampExp = jet.TimestampExp
+// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression.
+// This enables the Go compiler to interpret any expression as a row expression
+// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation.
+var RowExp = jet.RowExp
+
// CustomExpression is used to define custom expressions.
var CustomExpression = jet.CustomExpression
diff --git a/mysql/functions.go b/mysql/functions.go
index ca31d18..ceec7ab 100644
--- a/mysql/functions.go
+++ b/mysql/functions.go
@@ -12,7 +12,9 @@ var (
)
// ROW function is used to create a tuple value that consists of a set of expressions or column values.
-var ROW = jet.ROW
+func ROW(expressions ...Expression) RowExpression {
+ return jet.ROW(Dialect, expressions...)
+}
// ------------------ Mathematical functions ---------------//
diff --git a/mysql/interval.go b/mysql/interval.go
index ce1c609..23be21f 100644
--- a/mysql/interval.go
+++ b/mysql/interval.go
@@ -14,25 +14,25 @@ type unitType string
// List of interval unit types for MySQL
const (
MICROSECOND unitType = "MICROSECOND"
- SECOND = "SECOND"
- MINUTE = "MINUTE"
- HOUR = "HOUR"
- DAY = "DAY"
- WEEK = "WEEK"
- MONTH = "MONTH"
- QUARTER = "QUARTER"
- YEAR = "YEAR"
- SECOND_MICROSECOND = "SECOND_MICROSECOND"
- MINUTE_MICROSECOND = "MINUTE_MICROSECOND"
- MINUTE_SECOND = "MINUTE_SECOND"
- HOUR_MICROSECOND = "HOUR_MICROSECOND"
- HOUR_SECOND = "HOUR_SECOND"
- HOUR_MINUTE = "HOUR_MINUTE"
- DAY_MICROSECOND = "DAY_MICROSECOND"
- DAY_SECOND = "DAY_SECOND"
- DAY_MINUTE = "DAY_MINUTE"
- DAY_HOUR = "DAY_HOUR"
- YEAR_MONTH = "YEAR_MONTH"
+ SECOND unitType = "SECOND"
+ MINUTE unitType = "MINUTE"
+ HOUR unitType = "HOUR"
+ DAY unitType = "DAY"
+ WEEK unitType = "WEEK"
+ MONTH unitType = "MONTH"
+ QUARTER unitType = "QUARTER"
+ YEAR unitType = "YEAR"
+ SECOND_MICROSECOND unitType = "SECOND_MICROSECOND"
+ MINUTE_MICROSECOND unitType = "MINUTE_MICROSECOND"
+ MINUTE_SECOND unitType = "MINUTE_SECOND"
+ HOUR_MICROSECOND unitType = "HOUR_MICROSECOND"
+ HOUR_SECOND unitType = "HOUR_SECOND"
+ HOUR_MINUTE unitType = "HOUR_MINUTE"
+ DAY_MICROSECOND unitType = "DAY_MICROSECOND"
+ DAY_SECOND unitType = "DAY_SECOND"
+ DAY_MINUTE unitType = "DAY_MINUTE"
+ DAY_HOUR unitType = "DAY_HOUR"
+ YEAR_MONTH unitType = "YEAR_MONTH"
)
// Interval is representation of MySQL interval
diff --git a/mysql/select_statement.go b/mysql/select_statement.go
index 6c5f345..aaeff9a 100644
--- a/mysql/select_statement.go
+++ b/mysql/select_statement.go
@@ -172,7 +172,7 @@ func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement {
}
func (s *selectStatementImpl) AsTable(alias string) SelectTable {
- return newSelectTable(s, alias)
+ return newSelectTable(s, alias, nil)
}
//-----------------------------------------------------
diff --git a/mysql/select_table.go b/mysql/select_table.go
index ad22193..8ca06d4 100644
--- a/mysql/select_table.go
+++ b/mysql/select_table.go
@@ -13,9 +13,9 @@ type selectTableImpl struct {
readableTableInterfaceImpl
}
-func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable {
+func newSelectTable(selectStmt jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable {
subQuery := &selectTableImpl{
- SelectTable: jet.NewSelectTable(selectStmt, alias),
+ SelectTable: jet.NewSelectTable(selectStmt, alias, columnAliases),
}
subQuery.readableTableInterfaceImpl.parent = subQuery
diff --git a/mysql/set_statement.go b/mysql/set_statement.go
index 2df75a0..7147f00 100644
--- a/mysql/set_statement.go
+++ b/mysql/set_statement.go
@@ -85,7 +85,7 @@ func (s *setStatementImpl) OFFSET(offset int64) setStatement {
}
func (s *setStatementImpl) AsTable(alias string) SelectTable {
- return newSelectTable(s, alias)
+ return newSelectTable(s, alias, nil)
}
const (
diff --git a/mysql/statement.go b/mysql/statement.go
index a0bb0a7..008ace5 100644
--- a/mysql/statement.go
+++ b/mysql/statement.go
@@ -6,7 +6,7 @@ import (
)
// RawStatement creates new sql statements from raw query and optional map of named arguments
-func RawStatement(rawQuery string, namedArguments ...RawArgs) Statement {
+func RawStatement(rawQuery string, namedArguments ...RawArgs) jet.SerializerStatement {
return jet.RawStatement(Dialect, rawQuery, namedArguments...)
}
diff --git a/mysql/values.go b/mysql/values.go
new file mode 100644
index 0000000..b55895b
--- /dev/null
+++ b/mysql/values.go
@@ -0,0 +1,32 @@
+package mysql
+
+import "github.com/go-jet/jet/v2/internal/jet"
+
+type values struct {
+ jet.Values
+}
+
+// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table.
+// Each row is defined by the ROW constructor, which takes one or more expressions.
+//
+// Example usage:
+//
+// VALUES(
+// ROW(Int32(204), Float32(1.21)),
+// ROW(Int32(207), Float32(1.02)),
+// )
+func VALUES(rows ...RowExpression) values {
+ return values{Values: jet.Values(rows)}
+}
+
+// AS assigns an alias to the temporary VALUES table, allowing it to be referenced
+// within SQL FROM clauses, just like a regular table.
+// By default, VALUES columns are named `column1`, `column2`, etc... Default column aliasing can be
+// overwritten by passing new list of columns.
+//
+// Example usage:
+//
+// VALUES(...).AS("film_values", IntegerColumn("length"), TimestampColumn("update_date"))
+func (v values) AS(alias string, columns ...Column) SelectTable {
+ return newSelectTable(v, alias, columns)
+}
diff --git a/mysql/with_statement.go b/mysql/with_statement.go
index ca608cb..03b4d5b 100644
--- a/mysql/with_statement.go
+++ b/mysql/with_statement.go
@@ -6,7 +6,7 @@ import "github.com/go-jet/jet/v2/internal/jet"
type CommonTableExpression interface {
SelectTable
- AS(statement jet.SerializerStatement) CommonTableExpression
+ AS(statement jet.SerializerHasProjections) CommonTableExpression
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
ALIAS(alias string) SelectTable
@@ -41,7 +41,7 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression {
}
// AS is used to define a CTE query
-func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression {
+func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression {
c.CommonTableExpression.Statement = statement
return c
}
@@ -52,7 +52,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression {
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
func (c *commonTableExpression) ALIAS(name string) SelectTable {
- return newSelectTable(c, name)
+ return newSelectTable(c, name, nil)
}
func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression {
diff --git a/postgres/columns.go b/postgres/columns.go
index 819da38..a70c234 100644
--- a/postgres/columns.go
+++ b/postgres/columns.go
@@ -109,13 +109,20 @@ type ColumnInterval interface {
jet.Column
From(subQuery SelectTable) ColumnInterval
+ SET(intervalExp IntervalExpression) ColumnAssigment
}
+//------------------------------------------------------//
+
type intervalColumnImpl struct {
jet.ColumnExpressionImpl
intervalInterfaceImpl
}
+func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment {
+ return jet.NewColumnAssignment(i, intervalExp)
+}
+
func (i *intervalColumnImpl) From(subQuery SelectTable) ColumnInterval {
newIntervalColumn := IntervalColumn(i.Name())
jet.SetTableName(newIntervalColumn, i.TableName())
diff --git a/postgres/dialect.go b/postgres/dialect.go
index 9484ab1..14929ad 100644
--- a/postgres/dialect.go
+++ b/postgres/dialect.go
@@ -1,6 +1,7 @@
package postgres
import (
+ "fmt"
"github.com/go-jet/jet/v2/internal/jet"
"strconv"
)
@@ -25,6 +26,9 @@ func newDialect() jet.Dialect {
return "$" + strconv.Itoa(ord)
},
ReservedWords: reservedWords,
+ ValuesDefaultColumnName: func(index int) string {
+ return fmt.Sprintf("column%d", index+1)
+ },
}
return jet.NewDialect(dialectParams)
diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go
index 9aadbc9..6fb987c 100644
--- a/postgres/dialect_test.go
+++ b/postgres/dialect_test.go
@@ -46,33 +46,33 @@ func TestExists(t *testing.T) {
func TestIN(t *testing.T) {
assertSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)),
- `($1 IN (
+ `($1 IN ((
SELECT table1.col1 AS "table1.col1"
FROM db.table1
-))`, float64(1.11))
+)))`, float64(1.11))
assertSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)),
- `(ROW($1, table1.col1) IN (
+ `(ROW($1, table1.col1) IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
-))`, int64(12))
+)))`, int64(12))
}
func TestNOT_IN(t *testing.T) {
assertSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)),
- `($1 NOT IN (
+ `($1 NOT IN ((
SELECT table1.col1 AS "table1.col1"
FROM db.table1
-))`, float64(1.11))
+)))`, float64(1.11))
assertSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)),
- `(ROW($1, table1.col1) NOT IN (
+ `(ROW($1, table1.col1) NOT IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
-))`, int64(12))
+)))`, int64(12))
}
func TestReservedWordEscaped(t *testing.T) {
diff --git a/postgres/expressions.go b/postgres/expressions.go
index 9872910..d8ad34b 100644
--- a/postgres/expressions.go
+++ b/postgres/expressions.go
@@ -36,6 +36,9 @@ type TimestampExpression = jet.TimestampExpression
// TimestampzExpression interface
type TimestampzExpression = jet.TimestampzExpression
+// RowExpression interface
+type RowExpression = jet.RowExpression
+
// DateRange Expression interface
type DateRange = jet.Range[DateExpression]
@@ -99,6 +102,11 @@ var TimestampExp = jet.TimestampExp
// Does not add sql cast to generated sql builder output.
var TimestampzExp = jet.TimestampzExp
+// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression.
+// This enables the Go compiler to interpret any expression as a row expression
+// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation.
+var RowExp = jet.RowExp
+
// RangeExp is range expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as range expression.
// Does not add sql cast to generated sql builder output.
diff --git a/postgres/functions.go b/postgres/functions.go
index 7b6d1e1..3134325 100644
--- a/postgres/functions.go
+++ b/postgres/functions.go
@@ -14,7 +14,9 @@ var (
)
// ROW function is used to create a tuple value that consists of a set of expressions or column values.
-var ROW = jet.ROW
+func ROW(expressions ...Expression) RowExpression {
+ return jet.ROW(Dialect, expressions...)
+}
// ------------------ Mathematical functions ---------------//
@@ -332,6 +334,16 @@ var LOCALTIMESTAMP = jet.LOCALTIMESTAMP
// NOW returns current date and time
var NOW = jet.NOW
+// DATE_TRUNC returns the truncated date and time using optional time zone.
+// Use TimestampzExp if you need timestamp with time zone and IntervalExp if you need interval.
+func DATE_TRUNC(field unit, source Expression, timezone ...string) TimestampExpression {
+ if len(timezone) > 0 {
+ return jet.NewTimestampFunc("DATE_TRUNC", jet.FixedLiteral(unitToString(field)), source, jet.FixedLiteral(timezone[0]))
+ }
+
+ return jet.NewTimestampFunc("DATE_TRUNC", jet.FixedLiteral(unitToString(field)), source)
+}
+
// --------------- Conditional Expressions Functions -------------//
// COALESCE function returns the first of its arguments that is not null.
@@ -415,10 +427,12 @@ func castFloatLiteral(fraction FloatExpression) FloatExpression {
// ),
var GROUPING_SETS = jet.GROUPING_SETS
-// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... )
-// The construct (a, b) is normally recognized in expressions as a row constructor. WRAP and ROW method behave exactly the same,
-// except when used in GROUPING_SETS. For top level GROUPING SETS expression lists WRAP has to be used.
-var WRAP = jet.WRAP
+// WRAP surrounds a list of expressions or columns with parentheses, producing new row: (expression1, expression2, ...)
+// The construct (a, b) is normally recognized in expressions as a row constructor. WRAP and ROW methods behave exactly the same,
+// except when used in GROUPING_SETS and VALUES. In these contexts, WRAP must be used instead of ROW.
+func WRAP(expressions ...Expression) RowExpression {
+ return jet.WRAP(Dialect, expressions...)
+}
// ROLLUP operator is used with the GROUP BY clause to generate all prefixes of a group of columns including the empty list.
// It creates extra rows in the result set that represent the subtotal values for each combination of columns.
diff --git a/postgres/functions_test.go b/postgres/functions_test.go
index 4190f70..d0081b0 100644
--- a/postgres/functions_test.go
+++ b/postgres/functions_test.go
@@ -1,6 +1,8 @@
package postgres
-import "testing"
+import (
+ "testing"
+)
func TestROW(t *testing.T) {
assertSerialize(t, ROW(SELECT(Int(1))), `ROW((
@@ -10,3 +12,12 @@ func TestROW(t *testing.T) {
SELECT $2
), $3)`)
}
+
+func TestDATE_TRUNC(t *testing.T) {
+ assertSerialize(t, DATE_TRUNC(YEAR, NOW()), "DATE_TRUNC('YEAR', NOW())")
+ assertSerialize(
+ t,
+ DATE_TRUNC(DAY, NOW().ADD(INTERVAL(1, HOUR)), "Australia/Sydney"),
+ "DATE_TRUNC('DAY', NOW() + INTERVAL '1 HOUR', 'Australia/Sydney')",
+ )
+}
diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go
index 25300c2..5ace301 100644
--- a/postgres/insert_statement_test.go
+++ b/postgres/insert_statement_test.go
@@ -1,7 +1,6 @@
package postgres
import (
- "github.com/go-jet/jet/v2/internal/jet"
"github.com/stretchr/testify/require"
"testing"
"time"
@@ -151,12 +150,13 @@ func TestInsert_ON_CONFLICT(t *testing.T) {
VALUES("one", "two").
VALUES("1", "2").
VALUES("theta", "beta").
- ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE(
- SET(table1ColBool.SET(Bool(true)),
- table2ColInt.SET(Int(1)),
- ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))),
- ).WHERE(table1Col1.GT(Int(2))),
- ).
+ ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).
+ DO_UPDATE(
+ SET(table1ColBool.SET(Bool(true)),
+ table2ColInt.SET(Int(1)),
+ ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))),
+ ).WHERE(table1Col1.GT(Int(2))),
+ ).
RETURNING(table1Col1, table1ColBool)
assertDebugStatementSql(t, stmt, `
@@ -178,12 +178,12 @@ func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColBool).
VALUES("one", "two").
VALUES("1", "2").
- ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE(
- SET(table1ColBool.SET(Bool(false)),
- table2ColInt.SET(Int(1)),
- ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))),
- ).WHERE(table1Col1.GT(Int(2))),
- ).
+ ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").
+ DO_UPDATE(
+ SET(table1ColBool.SET(Bool(false)),
+ table2ColInt.SET(Int(1)),
+ ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))),
+ ).WHERE(table1Col1.GT(Int(2)))).
RETURNING(table1Col1, table1ColBool)
assertDebugStatementSql(t, stmt, `
diff --git a/postgres/literal.go b/postgres/literal.go
index e3a95b3..26b75d8 100644
--- a/postgres/literal.go
+++ b/postgres/literal.go
@@ -57,6 +57,16 @@ func Uint64(value uint64) IntegerExpression {
// Float creates new float literal expression
var Float = jet.Float
+// Float32 is constructor for 32 bit float literals
+func Float32(value float32) FloatExpression {
+ return CAST(jet.Literal(value)).AS_REAL()
+}
+
+// Float64 is constructor for 64 bit float literals
+func Float64(value float64) FloatExpression {
+ return CAST(jet.Literal(value)).AS_DOUBLE()
+}
+
// Decimal creates new float literal expression
var Decimal = jet.Decimal
diff --git a/postgres/select_statement.go b/postgres/select_statement.go
index 70a9a50..2a48fc5 100644
--- a/postgres/select_statement.go
+++ b/postgres/select_statement.go
@@ -181,7 +181,7 @@ func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement {
}
func (s *selectStatementImpl) AsTable(alias string) SelectTable {
- return newSelectTable(s, alias)
+ return newSelectTable(s, alias, nil)
}
//-----------------------------------------------------
diff --git a/postgres/select_table.go b/postgres/select_table.go
index f3d680d..5f57c1d 100644
--- a/postgres/select_table.go
+++ b/postgres/select_table.go
@@ -2,7 +2,7 @@ package postgres
import "github.com/go-jet/jet/v2/internal/jet"
-// SelectTable is interface for postgres sub-queries
+// SelectTable is interface for postgres temporary tables like sub-queries, VALUES, CTEs etc...
type SelectTable interface {
readableTable
jet.SelectTable
@@ -13,9 +13,9 @@ type selectTableImpl struct {
readableTableInterfaceImpl
}
-func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable {
+func newSelectTable(serializerWithProjections jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable {
subQuery := &selectTableImpl{
- SelectTable: jet.NewSelectTable(selectStmt, alias),
+ SelectTable: jet.NewSelectTable(serializerWithProjections, alias, columnAliases),
}
subQuery.readableTableInterfaceImpl.parent = subQuery
diff --git a/postgres/set_statement.go b/postgres/set_statement.go
index 834560d..0dee00d 100644
--- a/postgres/set_statement.go
+++ b/postgres/set_statement.go
@@ -136,7 +136,7 @@ func (s *setStatementImpl) OFFSET_e(offset IntegerExpression) setStatement {
}
func (s *setStatementImpl) AsTable(alias string) SelectTable {
- return newSelectTable(s, alias)
+ return newSelectTable(s, alias, nil)
}
const (
diff --git a/postgres/values.go b/postgres/values.go
new file mode 100644
index 0000000..28d17cd
--- /dev/null
+++ b/postgres/values.go
@@ -0,0 +1,32 @@
+package postgres
+
+import "github.com/go-jet/jet/v2/internal/jet"
+
+type values struct {
+ jet.Values
+}
+
+// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table.
+// Each row is defined by the WRAP constructor, which takes one or more expressions.
+//
+// Example usage:
+//
+// VALUES(
+// WRAP(Int32(204), Float32(1.21)),
+// WRAP(Int32(207), Float32(1.02)),
+// )
+func VALUES(rows ...RowExpression) values {
+ return values{Values: jet.Values(rows)}
+}
+
+// AS assigns an alias to the temporary VALUES table, allowing it to be referenced
+// within SQL FROM clauses, just like a regular table.
+// By default, VALUES columns are named `column1`, `column2`, etc... Default column aliasing can be
+// overwritten by passing new list of columns.
+//
+// Example usage:
+//
+// VALUES(...).AS("film_values", IntegerColumn("length"), TimestampColumn("update_date"))
+func (v values) AS(alias string, columns ...Column) SelectTable {
+ return newSelectTable(v, alias, columns)
+}
diff --git a/postgres/with_statement.go b/postgres/with_statement.go
index 698d6e3..99ddc8f 100644
--- a/postgres/with_statement.go
+++ b/postgres/with_statement.go
@@ -6,7 +6,7 @@ import "github.com/go-jet/jet/v2/internal/jet"
type CommonTableExpression interface {
SelectTable
- AS(statement jet.SerializerStatement) CommonTableExpression
+ AS(statement jet.SerializerHasProjections) CommonTableExpression
AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
ALIAS(alias string) SelectTable
@@ -42,7 +42,7 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression {
}
// AS is used to define a CTE query
-func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression {
+func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression {
c.CommonTableExpression.Statement = statement
return c
}
@@ -60,7 +60,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression {
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
func (c *commonTableExpression) ALIAS(name string) SelectTable {
- return newSelectTable(c, name)
+ return newSelectTable(c, name, nil)
}
func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression {
diff --git a/qrm/internal/null_types.go b/qrm/internal/null_types.go
index ab75cf6..85d9c68 100644
--- a/qrm/internal/null_types.go
+++ b/qrm/internal/null_types.go
@@ -10,6 +10,10 @@ import (
"time"
)
+var (
+ castOverFlowError = fmt.Errorf("cannot cast a negative value to an unsigned value, buffer overflow error")
+)
+
// NullBool struct
type NullBool struct {
sql.NullBool
@@ -119,32 +123,47 @@ func (n *NullUInt64) Scan(value interface{}) error {
n.Valid = false
return nil
case int64:
+ if v < 0 {
+ return castOverFlowError
+ }
+ n.UInt64, n.Valid = uint64(v), true
+ return nil
+ case int32:
+ if v < 0 {
+ return castOverFlowError
+ }
+ n.UInt64, n.Valid = uint64(v), true
+ return nil
+ case int16:
+ if v < 0 {
+ return castOverFlowError
+ }
+ n.UInt64, n.Valid = uint64(v), true
+ return nil
+ case int8:
+ if v < 0 {
+ return castOverFlowError
+ }
+ n.UInt64, n.Valid = uint64(v), true
+ return nil
+ case int:
+ if v < 0 {
+ return castOverFlowError
+ }
n.UInt64, n.Valid = uint64(v), true
return nil
case uint64:
n.UInt64, n.Valid = v, true
return nil
- case int32:
- n.UInt64, n.Valid = uint64(v), true
- return nil
case uint32:
n.UInt64, n.Valid = uint64(v), true
return nil
- case int16:
- n.UInt64, n.Valid = uint64(v), true
- return nil
case uint16:
n.UInt64, n.Valid = uint64(v), true
return nil
- case int8:
- n.UInt64, n.Valid = uint64(v), true
- return nil
case uint8:
n.UInt64, n.Valid = uint64(v), true
return nil
- case int:
- n.UInt64, n.Valid = uint64(v), true
- return nil
case uint:
n.UInt64, n.Valid = uint64(v), true
return nil
diff --git a/qrm/internal/null_types_test.go b/qrm/internal/null_types_test.go
index a15b104..eab2dd2 100644
--- a/qrm/internal/null_types_test.go
+++ b/qrm/internal/null_types_test.go
@@ -2,6 +2,7 @@ package internal
import (
"fmt"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"time"
@@ -62,11 +63,21 @@ func TestNullUInt64(t *testing.T) {
value, _ := nullUInt64.Value()
require.Equal(t, value, uint64(11))
+ require.NoError(t, nullUInt64.Scan(uint64(11)))
+ require.Equal(t, nullUInt64.Valid, true)
+ value, _ = nullUInt64.Value()
+ require.Equal(t, value, uint64(11))
+
require.NoError(t, nullUInt64.Scan(int32(32)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(32))
+ require.NoError(t, nullUInt64.Scan(uint32(32)))
+ require.Equal(t, nullUInt64.Valid, true)
+ value, _ = nullUInt64.Value()
+ require.Equal(t, value, uint64(32))
+
require.NoError(t, nullUInt64.Scan(int16(20)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
@@ -88,4 +99,29 @@ func TestNullUInt64(t *testing.T) {
require.Equal(t, value, uint64(30))
require.Error(t, nullUInt64.Scan("text"), "can't scan int32 from text")
+
+ //Validate negative use cases
+ err := nullUInt64.Scan(int64(-5))
+ assert.NotNil(t, err)
+ assert.Error(t, err, castOverFlowError)
+
+ //Validate negative use cases
+ err = nullUInt64.Scan(-5)
+ assert.NotNil(t, err)
+ assert.Error(t, err, castOverFlowError)
+
+ //Validate negative use cases
+ err = nullUInt64.Scan(int32(-5))
+ assert.NotNil(t, err)
+ assert.Error(t, err, castOverFlowError)
+
+ //Validate negative use cases
+ err = nullUInt64.Scan(int16(-5))
+ assert.NotNil(t, err)
+ assert.Error(t, err, castOverFlowError)
+
+ //Validate negative use cases
+ err = nullUInt64.Scan(int8(-5))
+ assert.NotNil(t, err)
+ assert.Error(t, err, castOverFlowError)
}
diff --git a/sqlite/dialect.go b/sqlite/dialect.go
index 93e1d2f..da03364 100644
--- a/sqlite/dialect.go
+++ b/sqlite/dialect.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "fmt"
"github.com/go-jet/jet/v2/internal/jet"
)
@@ -23,6 +24,9 @@ func newDialect() jet.Dialect {
return "?"
},
ReservedWords: reservedWords2,
+ ValuesDefaultColumnName: func(index int) string {
+ return fmt.Sprintf("column%d", index+1)
+ },
}
return jet.NewDialect(mySQLDialectParams)
diff --git a/sqlite/expressions.go b/sqlite/expressions.go
index 42ccc96..0b2d320 100644
--- a/sqlite/expressions.go
+++ b/sqlite/expressions.go
@@ -33,6 +33,9 @@ type DateTimeExpression = jet.TimestampExpression
// TimestampExpression interface
type TimestampExpression = jet.TimestampExpression
+// RowExpression interface
+type RowExpression = jet.RowExpression
+
// BoolExp is bool expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as bool expression.
// Does not add sql cast to generated sql builder output.
@@ -73,6 +76,11 @@ var DateTimeExp = jet.TimestampExp
// Does not add sql cast to generated sql builder output.
var TimestampExp = jet.TimestampExp
+// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression.
+// This enables the Go compiler to interpret any expression as a row expression
+// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation.
+var RowExp = jet.RowExp
+
// CustomExpression is used to define custom expressions.
var CustomExpression = jet.CustomExpression
diff --git a/sqlite/functions.go b/sqlite/functions.go
index 47a0a5b..ac6bd08 100644
--- a/sqlite/functions.go
+++ b/sqlite/functions.go
@@ -15,9 +15,9 @@ var (
OR = jet.OR
)
-// ROW is construct one table row from list of expressions.
-func ROW(expressions ...Expression) Expression {
- return jet.NewFunc("", expressions, nil)
+// ROW function is used to create a tuple value that consists of a set of expressions or column values.
+func ROW(expressions ...Expression) RowExpression {
+ return jet.WRAP(Dialect, expressions...)
}
// ------------------ Mathematical functions ---------------//
diff --git a/sqlite/select_statement.go b/sqlite/select_statement.go
index e74cda6..531ae76 100644
--- a/sqlite/select_statement.go
+++ b/sqlite/select_statement.go
@@ -155,7 +155,7 @@ func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement {
}
func (s *selectStatementImpl) AsTable(alias string) SelectTable {
- return newSelectTable(s, alias)
+ return newSelectTable(s, alias, nil)
}
//-----------------------------------------------------
diff --git a/sqlite/select_table.go b/sqlite/select_table.go
index 9ac7f72..7421f4b 100644
--- a/sqlite/select_table.go
+++ b/sqlite/select_table.go
@@ -13,9 +13,9 @@ type selectTableImpl struct {
readableTableInterfaceImpl
}
-func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable {
+func newSelectTable(selectStmt jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable {
subQuery := &selectTableImpl{
- SelectTable: jet.NewSelectTable(selectStmt, alias),
+ SelectTable: jet.NewSelectTable(selectStmt, alias, columnAliases),
}
subQuery.readableTableInterfaceImpl.parent = subQuery
diff --git a/sqlite/set_statement.go b/sqlite/set_statement.go
index 0a004bf..47723a3 100644
--- a/sqlite/set_statement.go
+++ b/sqlite/set_statement.go
@@ -86,7 +86,7 @@ func (s *setStatementImpl) OFFSET(offset int64) setStatement {
}
func (s *setStatementImpl) AsTable(alias string) SelectTable {
- return newSelectTable(s, alias)
+ return newSelectTable(s, alias, nil)
}
const (
diff --git a/sqlite/values.go b/sqlite/values.go
new file mode 100644
index 0000000..cb683cb
--- /dev/null
+++ b/sqlite/values.go
@@ -0,0 +1,26 @@
+package sqlite
+
+import "github.com/go-jet/jet/v2/internal/jet"
+
+type values struct {
+ jet.Values
+}
+
+// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table.
+// Each row is defined by the ROW constructor, which takes one or more expressions.
+//
+// Example usage:
+//
+// VALUES(
+// ROW(Int32(204), Float32(1.21)),
+// ROW(Int32(207), Float32(1.02)),
+// )
+func VALUES(rows ...RowExpression) values {
+ return values{Values: jet.Values(rows)}
+}
+
+// AS assigns an alias to the temporary VALUES table, allowing it to be referenced
+// within SQL FROM clauses, just like a regular table.
+func (v values) AS(alias string) SelectTable {
+ return newSelectTable(v, alias, nil)
+}
diff --git a/sqlite/with_statement.go b/sqlite/with_statement.go
index 5375fff..b05da7d 100644
--- a/sqlite/with_statement.go
+++ b/sqlite/with_statement.go
@@ -6,8 +6,8 @@ import "github.com/go-jet/jet/v2/internal/jet"
type CommonTableExpression interface {
SelectTable
- AS(statement jet.SerializerStatement) CommonTableExpression
- AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression
+ AS(statement jet.SerializerHasProjections) CommonTableExpression
+ AS_NOT_MATERIALIZED(statement jet.SerializerHasProjections) CommonTableExpression
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
ALIAS(alias string) SelectTable
@@ -42,13 +42,13 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression {
}
// AS is used to define a CTE query
-func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression {
+func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression {
c.CommonTableExpression.Statement = statement
return c
}
// AS_NOT_MATERIALIZED is used to define not materialized CTE query
-func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression {
+func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerHasProjections) CommonTableExpression {
c.CommonTableExpression.NotMaterialized = true
c.CommonTableExpression.Statement = statement
return c
@@ -60,7 +60,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression {
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
func (c *commonTableExpression) ALIAS(name string) SelectTable {
- return newSelectTable(c, name)
+ return newSelectTable(c, name, nil)
}
func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression {
diff --git a/tests/Makefile b/tests/Makefile
index 1a84778..43a8d66 100644
--- a/tests/Makefile
+++ b/tests/Makefile
@@ -6,6 +6,8 @@ setup: checkout-testdata docker-compose-up
checkout-testdata:
git submodule init
git submodule update
+#
+checkout-latest-testdata: checkout-testdata
cd ./testdata && git fetch && git checkout master && git pull
# docker-compose-up will download docker image for each of the databases listed in docker-compose.yaml file, and then it will initialize
diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml
index 9b3af50..09ce9d7 100644
--- a/tests/docker-compose.yaml
+++ b/tests/docker-compose.yaml
@@ -13,7 +13,7 @@ services:
- ./testdata/init/postgres:/docker-entrypoint-initdb.d
mysql:
- image: mysql:8.0.27
+ image: mysql:8.0
command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1']
restart: always
environment:
diff --git a/tests/init/init.go b/tests/init/init.go
index 5a21ee7..5a0ccdb 100644
--- a/tests/init/init.go
+++ b/tests/init/init.go
@@ -3,6 +3,7 @@ package main
import (
"context"
"database/sql"
+ "errors"
"flag"
"fmt"
"github.com/go-jet/jet/v2/generator/mysql"
@@ -10,7 +11,6 @@ import (
"github.com/go-jet/jet/v2/generator/sqlite"
"github.com/go-jet/jet/v2/internal/utils/errfmt"
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
- "io/ioutil"
"os"
"os/exec"
"strings"
@@ -125,7 +125,7 @@ func initMySQLDB(isMariaDB bool) error {
fmt.Println(cmdLine)
- cmd := exec.Command("sh", "-c", cmdLine)
+ cmd := exec.Command("sh", "-c", cmdLine) // #nosec G204
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
@@ -184,7 +184,7 @@ func initPostgresDB(dbType string, connectionString string) error {
}
func execFile(db *sql.DB, sqlFilePath string) error {
- testSampleSql, err := ioutil.ReadFile(sqlFilePath)
+ testSampleSql, err := os.ReadFile(sqlFilePath) // #nosec G304
if err != nil {
return fmt.Errorf("failed to read sql file - %s: %w", sqlFilePath, err)
}
@@ -211,7 +211,10 @@ func execInTx(db *sql.DB, f func(tx *sql.Tx) error) error {
err = f(tx)
if err != nil {
- tx.Rollback()
+ rollBackError := tx.Rollback()
+ if rollBackError != nil {
+ return errors.Join(rollBackError, err)
+ }
return err
}
diff --git a/tests/internal/utils/file/file.go b/tests/internal/utils/file/file.go
index 6d08d22..73095ea 100644
--- a/tests/internal/utils/file/file.go
+++ b/tests/internal/utils/file/file.go
@@ -2,7 +2,6 @@ package file
import (
"github.com/stretchr/testify/require"
- "io/ioutil"
"os"
"path"
"testing"
@@ -11,7 +10,7 @@ import (
// Exists expects file to exist on path constructed from pathElems and returns content of the file
func Exists(t *testing.T, pathElems ...string) (fileContent string) {
modelFilePath := path.Join(pathElems...)
- file, err := ioutil.ReadFile(modelFilePath)
+ file, err := os.ReadFile(modelFilePath) // #nosec G304
require.Nil(t, err)
require.NotEmpty(t, file)
return string(file)
@@ -20,6 +19,6 @@ func Exists(t *testing.T, pathElems ...string) (fileContent string) {
// NotExists expects file not to exist on path constructed from pathElems
func NotExists(t *testing.T, pathElems ...string) {
modelFilePath := path.Join(pathElems...)
- _, err := ioutil.ReadFile(modelFilePath)
+ _, err := os.ReadFile(modelFilePath) // #nosec G304
require.True(t, os.IsNotExist(err))
}
diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go
index c90f774..61bc7f2 100644
--- a/tests/mysql/alltypes_test.go
+++ b/tests/mysql/alltypes_test.go
@@ -1,6 +1,7 @@
package mysql
import (
+ "github.com/go-jet/jet/v2/internal/utils/ptr"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"strings"
@@ -96,18 +97,18 @@ func TestExpressionOperators(t *testing.T) {
SELECT all_types.'integer' IS NULL AS "result.is_null",
all_types.date_ptr IS NOT NULL AS "result.is_not_null",
(all_types.small_int_ptr IN (?, ?)) AS "result.in",
- (all_types.small_int_ptr IN (
+ (all_types.small_int_ptr IN ((
SELECT all_types.'integer' AS "all_types.integer"
FROM test_sample.all_types
- )) AS "result.in_select",
+ ))) AS "result.in_select",
(CURRENT_USER()) AS "result.raw",
(? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg",
(? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2",
(all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in",
- (all_types.small_int_ptr NOT IN (
+ (all_types.small_int_ptr NOT IN ((
SELECT all_types.'integer' AS "all_types.integer"
FROM test_sample.all_types
- )) AS "result.not_in_select"
+ ))) AS "result.not_in_select"
FROM test_sample.all_types
LIMIT ?;
`, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2))
@@ -1067,7 +1068,7 @@ func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) {
var toInsert = model.AllTypes{
Boolean: false,
- BooleanPtr: testutils.BoolPtr(true),
+ BooleanPtr: ptr.Of(true),
TinyInt: 1,
UTinyInt: 2,
SmallInt: 3,
@@ -1078,53 +1079,53 @@ var toInsert = model.AllTypes{
UInteger: 8,
BigInt: 9,
UBigInt: 1122334455,
- TinyIntPtr: testutils.Int8Ptr(11),
- UTinyIntPtr: testutils.UInt8Ptr(22),
- SmallIntPtr: testutils.Int16Ptr(33),
- USmallIntPtr: testutils.UInt16Ptr(44),
- MediumIntPtr: testutils.Int32Ptr(55),
- UMediumIntPtr: testutils.UInt32Ptr(66),
- IntegerPtr: testutils.Int32Ptr(77),
- UIntegerPtr: testutils.UInt32Ptr(88),
- BigIntPtr: testutils.Int64Ptr(99),
- UBigIntPtr: testutils.UInt64Ptr(111),
+ TinyIntPtr: ptr.Of(int8(11)),
+ UTinyIntPtr: ptr.Of(uint8(22)),
+ SmallIntPtr: ptr.Of(int16(33)),
+ USmallIntPtr: ptr.Of(uint16(44)),
+ MediumIntPtr: ptr.Of(int32(55)),
+ UMediumIntPtr: ptr.Of(uint32(66)),
+ IntegerPtr: ptr.Of(int32(77)),
+ UIntegerPtr: ptr.Of(uint32(88)),
+ BigIntPtr: ptr.Of(int64(99)),
+ UBigIntPtr: ptr.Of(uint64(111)),
Decimal: 11.22,
- DecimalPtr: testutils.Float64Ptr(33.44),
+ DecimalPtr: ptr.Of(33.44),
Numeric: 55.66,
- NumericPtr: testutils.Float64Ptr(77.88),
+ NumericPtr: ptr.Of(77.88),
Float: 99.00,
- FloatPtr: testutils.Float64Ptr(11.22),
+ FloatPtr: ptr.Of(11.22),
Double: 33.44,
- DoublePtr: testutils.Float64Ptr(55.66),
+ DoublePtr: ptr.Of(55.66),
Real: 77.88,
- RealPtr: testutils.Float64Ptr(99.00),
+ RealPtr: ptr.Of(99.00),
Bit: "1",
- BitPtr: testutils.StringPtr("0"),
+ BitPtr: ptr.Of("0"),
Time: time.Date(1, 1, 1, 10, 11, 12, 100, &time.Location{}),
- TimePtr: testutils.TimePtr(time.Date(1, 1, 1, 10, 11, 12, 100, time.UTC)),
+ TimePtr: ptr.Of(time.Date(1, 1, 1, 10, 11, 12, 100, time.UTC)),
Date: time.Now(),
- DatePtr: testutils.TimePtr(time.Now()),
+ DatePtr: ptr.Of(time.Now()),
DateTime: time.Now(),
- DateTimePtr: testutils.TimePtr(time.Now()),
+ DateTimePtr: ptr.Of(time.Now()),
Timestamp: time.Now(),
//TimestampPtr: testutils.TimePtr(time.Now()), // TODO: build fails for MariaDB
Year: 2000,
- YearPtr: testutils.Int16Ptr(2001),
+ YearPtr: ptr.Of(int16(2001)),
Char: "abcd",
- CharPtr: testutils.StringPtr("absd"),
+ CharPtr: ptr.Of("absd"),
VarChar: "abcd",
- VarCharPtr: testutils.StringPtr("absd"),
+ VarCharPtr: ptr.Of("absd"),
Binary: []byte("1010"),
- BinaryPtr: testutils.ByteArrayPtr([]byte("100001")),
+ BinaryPtr: ptr.Of([]byte("100001")),
VarBinary: []byte("1010"),
- VarBinaryPtr: testutils.ByteArrayPtr([]byte("100001")),
+ VarBinaryPtr: ptr.Of([]byte("100001")),
Blob: []byte("large file"),
- BlobPtr: testutils.ByteArrayPtr([]byte("very large file")),
+ BlobPtr: ptr.Of([]byte("very large file")),
Text: "some text",
- TextPtr: testutils.StringPtr("text"),
+ TextPtr: ptr.Of("text"),
Enum: model.AllTypesEnum_Value1,
JSON: "{}",
- JSONPtr: testutils.StringPtr(`{"a": 1}`),
+ JSONPtr: ptr.Of(`{"a": 1}`),
}
var allTypesJson = `
@@ -1358,17 +1359,17 @@ func TestExactDecimals(t *testing.T) {
Floats: model.Floats{
// overwritten by wrapped(floats) scope
Numeric: 0.1,
- NumericPtr: testutils.Float64Ptr(0.1),
+ NumericPtr: ptr.Of(0.1),
Decimal: 0.1,
- DecimalPtr: testutils.Float64Ptr(0.1),
+ DecimalPtr: ptr.Of(0.1),
// not overwritten
Float: 0.2,
- FloatPtr: testutils.Float64Ptr(0.22),
+ FloatPtr: ptr.Of(0.22),
Double: 0.3,
- DoublePtr: testutils.Float64Ptr(0.33),
+ DoublePtr: ptr.Of(0.33),
Real: 0.4,
- RealPtr: testutils.Float64Ptr(0.44),
+ RealPtr: ptr.Of(0.44),
},
Numeric: decimal.RequireFromString("12.35"),
NumericPtr: decimal.RequireFromString("56.79"),
@@ -1403,3 +1404,34 @@ VALUES ('91.23', '45.67', '12.35', '56.79', 0.2, 0.22, 0.3, 0.33, 0.4, 0.44);
require.Equal(t, 45.67, *result.Floats.DecimalPtr)
})
}
+
+func TestRowExpression(t *testing.T) {
+ now := time.Now()
+ nowAddHour := time.Now().Add(time.Hour)
+
+ stmt := SELECT(
+ ROW(Bool(false), DateT(now)).EQ(ROW(Bool(true), DateT(now))),
+ ROW(Bool(false), DateT(now)).NOT_EQ(ROW(Bool(true), DateT(now))),
+ ROW(TimestampT(nowAddHour), String("txt")).IS_DISTINCT_FROM(RowExp(Raw("row(NOW(), 'png')"))),
+ ROW(TimestampT(now), DateTimeT(nowAddHour)).GT(ROW(TimestampT(now), DateTimeT(now))),
+ ROW(DateTimeT(nowAddHour), Int(1)).GT_EQ(ROW(DateTimeT(now), Int(2))),
+ ROW(TimestampT(now), DateTimeT(nowAddHour)).LT(ROW(TimestampT(now), DateTimeT(now))),
+ ROW(DateTimeT(nowAddHour), Float(1.22)).LT_EQ(ROW(DateTimeT(now), Float(2.33))),
+ )
+
+ //fmt.Println(stmt.Sql())
+ //fmt.Println(stmt.DebugSql())
+
+ testutils.AssertStatementSql(t, stmt, `
+SELECT ROW(?, CAST(? AS DATE)) = ROW(?, CAST(? AS DATE)),
+ ROW(?, CAST(? AS DATE)) != ROW(?, CAST(? AS DATE)),
+ NOT(ROW(TIMESTAMP(?), ?) <=> (row(NOW(), 'png'))),
+ ROW(TIMESTAMP(?), CAST(? AS DATETIME)) > ROW(TIMESTAMP(?), CAST(? AS DATETIME)),
+ ROW(CAST(? AS DATETIME), ?) >= ROW(CAST(? AS DATETIME), ?),
+ ROW(TIMESTAMP(?), CAST(? AS DATETIME)) < ROW(TIMESTAMP(?), CAST(? AS DATETIME)),
+ ROW(CAST(? AS DATETIME), ?) <= ROW(CAST(? AS DATETIME), ?);
+`)
+
+ err := stmt.Query(db, &struct{}{})
+ require.NoError(t, err)
+}
diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go
index a7b5554..2a915a5 100644
--- a/tests/mysql/generator_test.go
+++ b/tests/mysql/generator_test.go
@@ -8,8 +8,11 @@ import (
"github.com/stretchr/testify/require"
+ "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/mysql"
+ "github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/testutils"
+ mysql2 "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/tests/dbconfig"
)
@@ -39,6 +42,39 @@ func TestGenerator(t *testing.T) {
require.NoError(t, err)
}
+func TestGenerator_TableMetadata(t *testing.T) {
+ var schema metadata.Schema
+ err := mysql.Generate(genTestDir3, dbConnection("dvds"),
+ template.Default(mysql2.Dialect).UseSchema(func(m metadata.Schema) template.Schema {
+ schema = m
+ return template.DefaultSchema(m)
+ }))
+ require.NoError(t, err)
+
+ // Spot check the actor table and assert that the emitted
+ // properties are as expected.
+ var got metadata.Table
+ for _, table := range schema.TablesMetaData {
+ if table.Name == "actor" {
+ got = table
+ }
+ }
+
+ want := metadata.Table{
+ Name: "actor",
+ Columns: []metadata.Column{
+ {Name: "actor_id", IsPrimaryKey: true, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "smallint", Kind: "base", IsUnsigned: true}, Comment: ""},
+ {Name: "first_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""},
+ {Name: "last_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""},
+ {Name: "last_update", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: true, DataType: metadata.DataType{Name: "timestamp", Kind: "base", IsUnsigned: false}, Comment: ""},
+ },
+ }
+ require.Equal(t, want, got)
+
+ err = os.RemoveAll(genTestDirRoot)
+ require.NoError(t, err)
+}
+
func TestCmdGenerator(t *testing.T) {
err := os.RemoveAll(genTestDir3)
require.NoError(t, err)
diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go
index 7eb6ec8..54b37c0 100644
--- a/tests/mysql/insert_test.go
+++ b/tests/mysql/insert_test.go
@@ -3,10 +3,12 @@ package mysql
import (
"context"
"github.com/go-jet/jet/v2/internal/testutils"
+ "github.com/go-jet/jet/v2/internal/utils/ptr"
. "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/qrm"
"github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/table"
+
"github.com/stretchr/testify/require"
"math/rand"
"testing"
@@ -300,7 +302,7 @@ func TestInsertOnDuplicateKeyUpdateNEW(t *testing.T) {
ID: randId,
URL: "https://www.yahoo.com",
Name: "Yahoo",
- Description: testutils.StringPtr("web portal and search engine"),
+ Description: ptr.Of("web portal and search engine"),
},
}).AS_NEW().
ON_DUPLICATE_KEY_UPDATE(
@@ -337,7 +339,7 @@ ON DUPLICATE KEY UPDATE id = (link.id + ?),
ID: randId + 11,
URL: "https://www.yahoo.com",
Name: "Yahoo",
- Description: testutils.StringPtr("web portal and search engine"),
+ Description: ptr.Of("web portal and search engine"),
})
})
}
diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go
index fd711b4..ee02a77 100644
--- a/tests/mysql/main_test.go
+++ b/tests/mysql/main_test.go
@@ -6,12 +6,9 @@ import (
jetmysql "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/dbconfig"
- "github.com/stretchr/testify/require"
- "math/rand"
- "runtime"
- "time"
-
_ "github.com/go-sql-driver/mysql"
+ "github.com/stretchr/testify/require"
+ "runtime"
"github.com/pkg/profile"
"os"
@@ -33,7 +30,6 @@ func sourceIsMariaDB() bool {
}
func TestMain(m *testing.M) {
- rand.Seed(time.Now().Unix())
defer profile.Start().Stop()
var err error
@@ -101,3 +97,9 @@ func skipForMariaDB(t *testing.T) {
t.SkipNow()
}
}
+
+func onlyMariaDB(t *testing.T) {
+ if !sourceIsMariaDB() {
+ t.SkipNow()
+ }
+}
diff --git a/tests/mysql/values_test.go b/tests/mysql/values_test.go
new file mode 100644
index 0000000..09e9332
--- /dev/null
+++ b/tests/mysql/values_test.go
@@ -0,0 +1,347 @@
+package mysql
+
+import (
+ "github.com/go-jet/jet/v2/internal/testutils"
+ "github.com/stretchr/testify/require"
+ "strings"
+ "testing"
+ "time"
+
+ . "github.com/go-jet/jet/v2/mysql"
+
+ "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model"
+ . "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table"
+)
+
+func TestVALUES(t *testing.T) {
+ skipForMariaDB(t)
+
+ valuesTable := VALUES(
+ ROW(Int32(1), Int32(2), Float(4.666), Bool(false), String("txt")),
+ ROW(Int32(11).ADD(Int32(2)), Int32(22), Float(33.222), Bool(true), String("png")),
+ ROW(Int32(11), Int32(22), Float(33.222), Bool(true), NULL),
+ ).AS("values_table")
+
+ stmt := SELECT(
+ valuesTable.AllColumns(),
+ ).FROM(
+ valuesTable,
+ )
+
+ testutils.AssertStatementSql(t, stmt, `
+SELECT values_table.column_0 AS "column_0",
+ values_table.column_1 AS "column_1",
+ values_table.column_2 AS "column_2",
+ values_table.column_3 AS "column_3",
+ values_table.column_4 AS "column_4"
+FROM (
+ VALUES ROW(?, ?, ?, ?, ?),
+ ROW(? + ?, ?, ?, ?, ?),
+ ROW(?, ?, ?, ?, NULL)
+ ) AS values_table;
+`)
+
+ var dest []struct {
+ Column0 int
+ Column1 int
+ Column2 float32
+ Column3 bool
+ Column4 *string
+ }
+
+ err := stmt.Query(db, &dest)
+
+ require.NoError(t, err)
+ testutils.AssertJSON(t, dest, `
+[
+ {
+ "Column0": 1,
+ "Column1": 2,
+ "Column2": 4.666,
+ "Column3": false,
+ "Column4": "txt"
+ },
+ {
+ "Column0": 13,
+ "Column1": 22,
+ "Column2": 33.222,
+ "Column3": true,
+ "Column4": "png"
+ },
+ {
+ "Column0": 11,
+ "Column1": 22,
+ "Column2": 33.222,
+ "Column3": true,
+ "Column4": null
+ }
+]
+`)
+}
+
+func TestVALUES_Join(t *testing.T) {
+ skipForMariaDB(t)
+
+ title := StringColumn("title")
+ releaseYear := IntegerColumn("ReleaseYear")
+ rentalRate := FloatColumn("rental_rate")
+
+ lastUpdate := Timestamp(2007, time.February, 11, 12, 0, 0)
+
+ films := VALUES(
+ ROW(String("Chamber Italian"), Int64(117), Int32(2005), Float(5.82), lastUpdate),
+ ROW(String("Grosse Wonderful"), Int64(49), Int32(2004), Float(6.242), lastUpdate.ADD(INTERVAL(1, HOUR))),
+ ROW(String("Airport Pollock"), Int64(54), Int32(2001), Float(7.22), NULL),
+ ROW(String("Bright Encounters"), Int64(73), Int32(2002), Float(8.25), NULL),
+ ROW(String("Academy Dinosaur"), Int64(83), Int32(2010), Float(9.22), lastUpdate.SUB(INTERVAL(2, MINUTE))),
+ ).AS("film_values",
+ title, IntegerColumn("length"), releaseYear, rentalRate, TimestampColumn("last_update"))
+
+ stmt := SELECT(
+ Film.AllColumns,
+ films.AllColumns(),
+ ).FROM(
+ Film.
+ INNER_JOIN(films, title.EQ(Film.Title)),
+ ).WHERE(AND(
+ Film.ReleaseYear.GT(releaseYear),
+ Film.RentalRate.LT(rentalRate),
+ )).ORDER_BY(
+ title,
+ )
+
+ testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(`
+SELECT film.film_id AS "film.film_id",
+ film.title AS "film.title",
+ film.description AS "film.description",
+ film.release_year AS "film.release_year",
+ film.language_id AS "film.language_id",
+ film.original_language_id AS "film.original_language_id",
+ film.rental_duration AS "film.rental_duration",
+ film.rental_rate AS "film.rental_rate",
+ film.length AS "film.length",
+ film.replacement_cost AS "film.replacement_cost",
+ film.rating AS "film.rating",
+ film.special_features AS "film.special_features",
+ film.last_update AS "film.last_update",
+ film_values.title AS "title",
+ film_values.length AS "length",
+ film_values.''ReleaseYear'' AS "ReleaseYear",
+ film_values.rental_rate AS "rental_rate",
+ film_values.last_update AS "last_update"
+FROM dvds.film
+ INNER JOIN (
+ VALUES ROW('Chamber Italian', 117, 2005, 5.82, TIMESTAMP('2007-02-11 12:00:00')),
+ ROW('Grosse Wonderful', 49, 2004, 6.242, TIMESTAMP('2007-02-11 12:00:00') + INTERVAL 1 HOUR),
+ ROW('Airport Pollock', 54, 2001, 7.22, NULL),
+ ROW('Bright Encounters', 73, 2002, 8.25, NULL),
+ ROW('Academy Dinosaur', 83, 2010, 9.22, TIMESTAMP('2007-02-11 12:00:00') - INTERVAL 2 MINUTE)
+ ) AS film_values (title, length, ''ReleaseYear'', rental_rate, last_update) ON (film_values.title = film.title)
+WHERE (
+ (film.release_year > film_values.''ReleaseYear'')
+ AND (film.rental_rate < film_values.rental_rate)
+ )
+ORDER BY film_values.title;
+`, "''", "`"))
+
+ var dest []struct {
+ Film model.Film
+
+ Title string
+ Length int
+ ReleaseYear int
+ RentalRate float32
+ LastUpdate *time.Time
+ }
+
+ err := stmt.Query(db, &dest)
+
+ require.NoError(t, err)
+ require.Len(t, dest, 4)
+ testutils.AssertJSON(t, dest[0:2], `
+[
+ {
+ "Film": {
+ "FilmID": 8,
+ "Title": "AIRPORT POLLOCK",
+ "Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India",
+ "ReleaseYear": 2006,
+ "LanguageID": 1,
+ "OriginalLanguageID": null,
+ "RentalDuration": 6,
+ "RentalRate": 4.99,
+ "Length": 54,
+ "ReplacementCost": 15.99,
+ "Rating": "R",
+ "SpecialFeatures": "Trailers",
+ "LastUpdate": "2006-02-15T05:03:42Z"
+ },
+ "Title": "Airport Pollock",
+ "Length": 54,
+ "ReleaseYear": 2001,
+ "RentalRate": 7.22,
+ "LastUpdate": null
+ },
+ {
+ "Film": {
+ "FilmID": 98,
+ "Title": "BRIGHT ENCOUNTERS",
+ "Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat",
+ "ReleaseYear": 2006,
+ "LanguageID": 1,
+ "OriginalLanguageID": null,
+ "RentalDuration": 4,
+ "RentalRate": 4.99,
+ "Length": 73,
+ "ReplacementCost": 12.99,
+ "Rating": "PG-13",
+ "SpecialFeatures": "Trailers",
+ "LastUpdate": "2006-02-15T05:03:42Z"
+ },
+ "Title": "Bright Encounters",
+ "Length": 73,
+ "ReleaseYear": 2002,
+ "RentalRate": 8.25,
+ "LastUpdate": null
+ }
+]
+`)
+}
+
+func TestVALUES_CTE_Update(t *testing.T) {
+ skipForMariaDB(t)
+
+ paymentID := IntegerColumn("payment_id")
+ increase := FloatColumn("increase")
+ paymentsToUpdate := CTE("values_cte", paymentID, increase)
+
+ stmt := WITH(
+ paymentsToUpdate.AS(
+ VALUES(
+ ROW(Int32(204), Float(1.21)),
+ ROW(Int32(207), Float(1.02)),
+ ROW(Int32(200), Float(1.34)),
+ ROW(Int32(203), Float(1.72)),
+ ),
+ ),
+ )(
+ Payment.INNER_JOIN(paymentsToUpdate, paymentID.EQ(Payment.PaymentID)).
+ UPDATE().
+ SET(
+ Payment.Amount.SET(Payment.Amount.MUL(increase)),
+ ).WHERE(Bool(true)),
+ )
+
+ testutils.AssertStatementSql(t, stmt, `
+WITH values_cte (payment_id, increase) AS (
+ VALUES ROW(?, ?),
+ ROW(?, ?),
+ ROW(?, ?),
+ ROW(?, ?)
+)
+UPDATE dvds.payment
+INNER JOIN values_cte ON (values_cte.payment_id = payment.payment_id)
+SET amount = (payment.amount * values_cte.increase)
+WHERE ?;
+`)
+
+ testutils.AssertExecAndRollback(t, stmt, db, 4)
+}
+
+func TestVALUES_MariaDB(t *testing.T) {
+ onlyMariaDB(t) // mariadb won't accept values rows if all the elements are placeholders, so we have to use raw statement
+
+ paymentID := IntegerColumn("payment_id")
+ increase := FloatColumn("increase")
+ paymentsToUpdate := CTE("values_cte", paymentID, increase)
+
+ stmt := WITH(
+ paymentsToUpdate.AS(
+ RawStatement(`
+ VALUES (204, 1.21),
+ (207, 1.02),
+ (200, 1.34),
+ (203, 1.72)
+ `),
+ ),
+ )(
+ SELECT(
+ Payment.AllColumns,
+ paymentsToUpdate.AllColumns(),
+ ).FROM(
+ Payment.
+ INNER_JOIN(paymentsToUpdate, paymentID.EQ(Payment.PaymentID)),
+ ).WHERE(
+ increase.GT(Float(1.03)),
+ ).ORDER_BY(
+ increase,
+ ),
+ )
+
+ testutils.AssertStatementSql(t, stmt, `
+WITH values_cte (payment_id, increase) AS (
+ VALUES (204, 1.21),
+ (207, 1.02),
+ (200, 1.34),
+ (203, 1.72)
+
+)
+SELECT payment.payment_id AS "payment.payment_id",
+ payment.customer_id AS "payment.customer_id",
+ payment.staff_id AS "payment.staff_id",
+ payment.rental_id AS "payment.rental_id",
+ payment.amount AS "payment.amount",
+ payment.payment_date AS "payment.payment_date",
+ payment.last_update AS "payment.last_update",
+ values_cte.payment_id AS "payment_id",
+ values_cte.increase AS "increase"
+FROM dvds.payment
+ INNER JOIN values_cte ON (values_cte.payment_id = payment.payment_id)
+WHERE values_cte.increase > ?
+ORDER BY values_cte.increase;
+`)
+
+ var dest []struct {
+ model.Payment
+
+ Increase float64
+ }
+
+ err := stmt.Query(db, &dest)
+
+ require.NoError(t, err)
+ testutils.AssertJSON(t, dest, `
+[
+ {
+ "PaymentID": 204,
+ "CustomerID": 7,
+ "StaffID": 1,
+ "RentalID": 13476,
+ "Amount": 2.99,
+ "PaymentDate": "2005-08-20T01:06:04Z",
+ "LastUpdate": "2006-02-15T22:12:31Z",
+ "Increase": 1.21
+ },
+ {
+ "PaymentID": 200,
+ "CustomerID": 7,
+ "StaffID": 2,
+ "RentalID": 11542,
+ "Amount": 7.99,
+ "PaymentDate": "2005-08-17T00:51:32Z",
+ "LastUpdate": "2006-02-15T22:12:31Z",
+ "Increase": 1.34
+ },
+ {
+ "PaymentID": 203,
+ "CustomerID": 7,
+ "StaffID": 2,
+ "RentalID": 13373,
+ "Amount": 2.99,
+ "PaymentDate": "2005-08-19T21:23:31Z",
+ "LastUpdate": "2006-02-15T22:12:31Z",
+ "Increase": 1.72
+ }
+]
+`)
+}
diff --git a/tests/mysql/with_test.go b/tests/mysql/with_test.go
index d7d8d3d..59da3a5 100644
--- a/tests/mysql/with_test.go
+++ b/tests/mysql/with_test.go
@@ -164,10 +164,10 @@ WITH payments_to_delete AS (
WHERE payment.amount < 0.5
)
DELETE FROM dvds.payment
-WHERE payment.payment_id IN (
+WHERE payment.payment_id IN ((
SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id"
FROM payments_to_delete
- );
+ ));
`, "''", "`"))
tx, err := db.Begin()
diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go
index 25a29d2..6d9c0b7 100644
--- a/tests/postgres/alltypes_test.go
+++ b/tests/postgres/alltypes_test.go
@@ -1,6 +1,10 @@
package postgres
import (
+ "database/sql"
+ "github.com/go-jet/jet/v2/internal/utils/ptr"
+ "github.com/stretchr/testify/assert"
+
"github.com/go-jet/jet/v2/qrm"
"testing"
"time"
@@ -344,24 +348,22 @@ func TestExpressionOperators(t *testing.T) {
AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"),
).LIMIT(2)
- //fmt.Println(query.Sql())
-
testutils.AssertStatementSql(t, query, `
SELECT all_types.integer IS NULL AS "result.is_null",
all_types.date_ptr IS NOT NULL AS "result.is_not_null",
(all_types.small_int_ptr IN ($1::smallint, $2::smallint)) AS "result.in",
- (all_types.small_int_ptr IN (
+ (all_types.small_int_ptr IN ((
SELECT all_types.integer AS "all_types.integer"
FROM test_sample.all_types
- )) AS "result.in_select",
+ ))) AS "result.in_select",
(CURRENT_USER) AS "result.raw",
($3 + COALESCE(all_types.small_int_ptr, 0) + $4) AS "result.raw_arg",
($5 + all_types.integer + $6 + $5 + $7 + $8) AS "result.raw_arg2",
(all_types.small_int_ptr NOT IN ($9, $10::smallint, NULL)) AS "result.not_in",
- (all_types.small_int_ptr NOT IN (
+ (all_types.small_int_ptr NOT IN ((
SELECT all_types.integer AS "all_types.integer"
FROM test_sample.all_types
- )) AS "result.not_in_select"
+ ))) AS "result.not_in_select"
FROM test_sample.all_types
LIMIT $11;
`, int8(11), int8(22), 78, 56, 11, 22, 33, 44, int64(11), int16(22), int64(2))
@@ -373,9 +375,6 @@ LIMIT $11;
err := query.Query(db, &dest)
require.NoError(t, err)
-
- //testutils.PrintJson(dest)
-
testutils.AssertJSON(t, dest, `
[
{
@@ -930,6 +929,68 @@ func TestTimeExpression(t *testing.T) {
require.NoError(t, err)
}
+func TestIntervalSetFunctionality(t *testing.T) {
+
+ t.Run("updateQueryIntervalTest", func(t *testing.T) {
+
+ expectedQuery := `
+UPDATE test_sample.employee
+SET pto_accrual = INTERVAL '3 HOUR'
+WHERE employee.employee_id = $1
+RETURNING employee.employee_id AS "employee.employee_id",
+ employee.first_name AS "employee.first_name",
+ employee.last_name AS "employee.last_name",
+ employee.employment_date AS "employee.employment_date",
+ employee.manager_id AS "employee.manager_id",
+ employee.pto_accrual AS "employee.pto_accrual";
+`
+ testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
+ var windy model.Employee
+ windy.PtoAccrual = ptr.Of("3h")
+ stmt := Employee.UPDATE(Employee.PtoAccrual).SET(
+ Employee.PtoAccrual.SET(INTERVAL(3, HOUR)),
+ ).WHERE(Employee.EmployeeID.EQ(Int(1))).RETURNING(Employee.AllColumns)
+
+ testutils.AssertStatementSql(t, stmt, expectedQuery)
+ err := stmt.Query(tx, &windy)
+ assert.Nil(t, err)
+ assert.Equal(t, *windy.PtoAccrual, "03:00:00")
+
+ })
+ })
+
+ t.Run("upsertQueryIntervalTest", func(t *testing.T) {
+ expectedQuery := `
+INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual)
+VALUES ($1, $2, $3, $4, $5, $6)
+ON CONFLICT (employee_id) DO UPDATE
+ SET pto_accrual = excluded.pto_accrual
+RETURNING employee.employee_id AS "employee.employee_id",
+ employee.first_name AS "employee.first_name",
+ employee.last_name AS "employee.last_name",
+ employee.employment_date AS "employee.employment_date",
+ employee.manager_id AS "employee.manager_id",
+ employee.pto_accrual AS "employee.pto_accrual";
+`
+ testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
+ var employee model.Employee
+ employee.PtoAccrual = ptr.Of("5h")
+ stmt := Employee.INSERT(Employee.AllColumns).
+ MODEL(employee).
+ ON_CONFLICT(Employee.EmployeeID).
+ DO_UPDATE(SET(
+ Employee.PtoAccrual.SET(Employee.EXCLUDED.PtoAccrual),
+ )).RETURNING(Employee.AllColumns)
+
+ testutils.AssertStatementSql(t, stmt, expectedQuery)
+ err := stmt.Query(tx, &employee)
+ assert.Nil(t, err)
+ assert.Equal(t, *employee.PtoAccrual, "05:00:00")
+
+ })
+ })
+}
+
func TestInterval(t *testing.T) {
skipForCockroachDB(t)
@@ -1046,6 +1107,46 @@ FROM test_sample.all_types;
require.NoError(t, err)
}
+func TestRowExpression(t *testing.T) {
+ now := time.Now()
+ nowAddHour := time.Now().Add(time.Hour)
+
+ stmt := SELECT(
+ ROW(Int32(1), Float32(11.22), String("john")).AS("row"),
+ WRAP(Int64(1), Float64(11.22), String("john")).AS("wrap"),
+
+ ROW(Bool(false), DateT(now)).EQ(ROW(Bool(true), DateT(now))),
+ WRAP(Bool(false), DateT(now)).NOT_EQ(WRAP(Bool(true), DateT(now))),
+
+ ROW(TimeT(nowAddHour)).IS_DISTINCT_FROM(RowExp(Raw("row(NOW()::time)"))),
+ ROW().IS_NOT_DISTINCT_FROM(ROW()),
+
+ ROW(TimestampT(now), TimestampzT(nowAddHour)).GT(WRAP(TimestampT(now), TimestampzT(now))),
+ ROW(TimestampzT(nowAddHour)).GT_EQ(ROW(TimestampzT(now))),
+ WRAP(TimestampT(now), TimestampzT(nowAddHour)).LT(ROW(TimestampT(now), TimestampzT(now))),
+ ROW(TimestampzT(nowAddHour)).LT_EQ(ROW(TimestampzT(now))),
+ )
+
+ //fmt.Println(stmt.Sql())
+ //fmt.Println(stmt.DebugSql())
+
+ testutils.AssertStatementSql(t, stmt, `
+SELECT ROW($1::integer, $2::real, $3::text) AS "row",
+ ($4::bigint, $5::double precision, $6::text) AS "wrap",
+ ROW($7::boolean, $8::date) = ROW($9::boolean, $10::date),
+ ($11::boolean, $12::date) != ($13::boolean, $14::date),
+ ROW($15::time without time zone) IS DISTINCT FROM (row(NOW()::time)),
+ ROW() IS NOT DISTINCT FROM ROW(),
+ ROW($16::timestamp without time zone, $17::timestamp with time zone) > ($18::timestamp without time zone, $19::timestamp with time zone),
+ ROW($20::timestamp with time zone) >= ROW($21::timestamp with time zone),
+ ($22::timestamp without time zone, $23::timestamp with time zone) < ROW($24::timestamp without time zone, $25::timestamp with time zone),
+ ROW($26::timestamp with time zone) <= ROW($27::timestamp with time zone);
+`)
+
+ err := stmt.Query(db, &struct{}{})
+ require.NoError(t, err)
+}
+
func TestSubQueryColumnReference(t *testing.T) {
type expected struct {
sql string
@@ -1305,32 +1406,32 @@ RETURNING all_types.json AS "all_types.json";
var moodSad = model.Mood_Sad
var allTypesRow0 = model.AllTypes{
- SmallIntPtr: testutils.Int16Ptr(14),
+ SmallIntPtr: ptr.Of(int16(14)),
SmallInt: 14,
- IntegerPtr: testutils.Int32Ptr(300),
+ IntegerPtr: ptr.Of(int32(300)),
Integer: 300,
- BigIntPtr: testutils.Int64Ptr(50000),
+ BigIntPtr: ptr.Of(int64(50000)),
BigInt: 5000,
- DecimalPtr: testutils.Float64Ptr(1.11),
+ DecimalPtr: ptr.Of(1.11),
Decimal: 1.11,
- NumericPtr: testutils.Float64Ptr(2.22),
+ NumericPtr: ptr.Of(2.22),
Numeric: 2.22,
- RealPtr: testutils.Float32Ptr(5.55),
+ RealPtr: ptr.Of(float32(5.55)),
Real: 5.55,
- DoublePrecisionPtr: testutils.Float64Ptr(11111111.22),
+ DoublePrecisionPtr: ptr.Of(11111111.22),
DoublePrecision: 11111111.22,
Smallserial: 1,
Serial: 1,
Bigserial: 1,
//MoneyPtr: nil,
//Money:
- VarCharPtr: testutils.StringPtr("ABBA"),
+ VarCharPtr: ptr.Of("ABBA"),
VarChar: "ABBA",
- CharPtr: testutils.StringPtr("JOHN "),
+ CharPtr: ptr.Of("JOHN "),
Char: "JOHN ",
- TextPtr: testutils.StringPtr("Some text"),
+ TextPtr: ptr.Of("Some text"),
Text: "Some text",
- ByteaPtr: testutils.ByteArrayPtr([]byte("bytea")),
+ ByteaPtr: ptr.Of([]byte("bytea")),
Bytea: []byte("bytea"),
TimestampzPtr: testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0),
Timestampz: *testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0),
@@ -1342,31 +1443,31 @@ var allTypesRow0 = model.AllTypes{
Timez: *testutils.TimeWithTimeZone("04:05:06 -0800"),
TimePtr: testutils.TimeWithoutTimeZone("04:05:06"),
Time: *testutils.TimeWithoutTimeZone("04:05:06"),
- IntervalPtr: testutils.StringPtr("3 days 04:05:06"),
+ IntervalPtr: ptr.Of("3 days 04:05:06"),
Interval: "3 days 04:05:06",
- BooleanPtr: testutils.BoolPtr(true),
+ BooleanPtr: ptr.Of(true),
Boolean: false,
- PointPtr: testutils.StringPtr("(2,3)"),
- BitPtr: testutils.StringPtr("101"),
+ PointPtr: ptr.Of("(2,3)"),
+ BitPtr: ptr.Of("101"),
Bit: "101",
- BitVaryingPtr: testutils.StringPtr("101111"),
+ BitVaryingPtr: ptr.Of("101111"),
BitVarying: "101111",
- TsvectorPtr: testutils.StringPtr("'supernova':1"),
+ TsvectorPtr: ptr.Of("'supernova':1"),
Tsvector: "'supernova':1",
UUIDPtr: testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"),
UUID: uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"),
- XMLPtr: testutils.StringPtr("abc"),
+ XMLPtr: ptr.Of("abc"),
XML: "abc",
- JSONPtr: testutils.StringPtr(`{"a": 1, "b": 3}`),
+ JSONPtr: ptr.Of(`{"a": 1, "b": 3}`),
JSON: `{"a": 1, "b": 3}`,
- JsonbPtr: testutils.StringPtr(`{"a": 1, "b": 3}`),
+ JsonbPtr: ptr.Of(`{"a": 1, "b": 3}`),
Jsonb: `{"a": 1, "b": 3}`,
- IntegerArrayPtr: testutils.StringPtr("{1,2,3}"),
+ IntegerArrayPtr: ptr.Of("{1,2,3}"),
IntegerArray: "{1,2,3}",
- TextArrayPtr: testutils.StringPtr("{breakfast,consulting}"),
+ TextArrayPtr: ptr.Of("{breakfast,consulting}"),
TextArray: "{breakfast,consulting}",
JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`,
- TextMultiDimArrayPtr: testutils.StringPtr("{{meeting,lunch},{training,presentation}}"),
+ TextMultiDimArrayPtr: ptr.Of("{{meeting,lunch},{training,presentation}}"),
TextMultiDimArray: "{{meeting,lunch},{training,presentation}}",
MoodPtr: &moodSad,
Mood: model.Mood_Happy,
diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go
index 462e205..349a4ba 100644
--- a/tests/postgres/chinook_db_test.go
+++ b/tests/postgres/chinook_db_test.go
@@ -3,6 +3,7 @@ package postgres
import (
"context"
"github.com/go-jet/jet/v2/internal/testutils"
+ "github.com/go-jet/jet/v2/internal/utils/ptr"
. "github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook/table"
@@ -455,7 +456,7 @@ FROM (
require.Len(t, dest, 275)
require.Equal(t, dest[0].Artist1.Artist, model.Artist{
ArtistId: 1,
- Name: testutils.StringPtr("AC/DC"),
+ Name: ptr.Of("AC/DC"),
})
require.Equal(t, dest[0].Artist1.CustomColumn1, "custom_column_1")
require.Equal(t, dest[0].Artist1.CustomColumn2, "custom_column_2")
diff --git a/tests/postgres/generator_template_test.go b/tests/postgres/generator_template_test.go
index 2bc2d32..65603cd 100644
--- a/tests/postgres/generator_template_test.go
+++ b/tests/postgres/generator_template_test.go
@@ -3,6 +3,9 @@ package postgres
import (
"database/sql"
"fmt"
+ "path"
+ "testing"
+
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/postgres"
"github.com/go-jet/jet/v2/generator/template"
@@ -13,8 +16,6 @@ import (
"github.com/go-jet/jet/v2/tests/dbconfig"
file2 "github.com/go-jet/jet/v2/tests/internal/utils/file"
"github.com/stretchr/testify/require"
- "path"
- "testing"
)
const tempTestDir = "./.tempTestDir"
@@ -170,6 +171,7 @@ func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) {
mpaaRating := file2.Exists(t, defaultModelPath, "mpaa_rating_enum.go")
require.Contains(t, mpaaRating, "type MpaaRatingEnum string")
+ require.Contains(t, mpaaRating, "MpaaRatingEnumAllValues")
}
func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) {
@@ -267,7 +269,6 @@ func UseSchema(schema string) {
FilmList = FilmList.FromSchema(schema)
}
`)
-
}
func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) {
@@ -365,7 +366,6 @@ func TestGeneratorTemplate_SQLBuilder_DefaultAlias(t *testing.T) {
}
func TestGeneratorTemplate_Model_AddTags(t *testing.T) {
-
err := postgres.Generate(
tempTestDir,
dbConnection,
diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go
index 479d82f..87a7eee 100644
--- a/tests/postgres/generator_test.go
+++ b/tests/postgres/generator_test.go
@@ -2,7 +2,6 @@ package postgres
import (
"fmt"
- "github.com/go-jet/jet/v2/tests/internal/utils/file"
"os"
"os/exec"
"path/filepath"
@@ -12,11 +11,14 @@ import (
"github.com/stretchr/testify/require"
+ "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/postgres"
+ "github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/testutils"
- "github.com/go-jet/jet/v2/tests/dbconfig"
-
+ postgres2 "github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
+ "github.com/go-jet/jet/v2/tests/dbconfig"
+ "github.com/go-jet/jet/v2/tests/internal/utils/file"
)
func dsn(host string, port int, dbName, user, password string) string {
@@ -208,6 +210,45 @@ func TestGenerator(t *testing.T) {
require.NoError(t, err)
}
+func TestGenerator_TableMetadata(t *testing.T) {
+ var schema metadata.Schema
+ err := postgres.GenerateDSN(defaultDSN(), "dvds", genTestDir2,
+ template.Default(postgres2.Dialect).UseSchema(func(m metadata.Schema) template.Schema {
+ schema = m
+ return template.DefaultSchema(m)
+ }))
+ require.NoError(t, err)
+
+ // Spot check the actor table and assert that the emitted
+ // properties are as expected.
+ var got metadata.Table
+ var specialFeatures metadata.Column
+ for _, table := range schema.TablesMetaData {
+ if table.Name == "actor" {
+ got = table
+ }
+ if table.Name == "film" {
+ for _, column := range table.Columns {
+ if column.Name == "special_features" {
+ specialFeatures = column
+ }
+ }
+ }
+ }
+
+ want := metadata.Table{
+ Name: "actor",
+ Columns: []metadata.Column{
+ {Name: "actor_id", IsPrimaryKey: true, IsNullable: false, IsGenerated: false, HasDefault: true, DataType: metadata.DataType{Name: "int4", Kind: "base", IsUnsigned: false}, Comment: ""},
+ {Name: "first_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""},
+ {Name: "last_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""},
+ {Name: "last_update", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "timestamp", Kind: "base", IsUnsigned: false}, Comment: ""},
+ },
+ }
+ require.Equal(t, want, got)
+ require.Equal(t, metadata.ArrayType, specialFeatures.DataType.Kind)
+}
+
func TestGeneratorSpecialCharacters(t *testing.T) {
t.SkipNow()
err := postgres.Generate(genTestDir2, postgres.DBConnection{
@@ -563,6 +604,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) {
"mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go", "floats.go", "people.go",
"components.go", "vulnerabilities.go", "all_types_materialized_view.go", "sample_ranges.go")
testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent)
+ testutils.AssertFileContent(t, modelDir+"/link.go", linkModelContent)
testutils.AssertFileNamesEqual(t, tableDir, "all_types.go", "employee.go", "link.go",
"person.go", "person_phone.go", "weird_names_table.go", "user.go", "floats.go", "people.go", "table_use_schema.go",
@@ -570,6 +612,8 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) {
testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent)
testutils.AssertFileContent(t, tableDir+"/sample_ranges.go", sampleRangeTableContent)
+ testutils.AssertFileContent(t, tableDir+"/link.go", linkTableContent)
+
testutils.AssertFileNamesEqual(t, viewDir, "all_types_materialized_view.go", "all_types_view.go",
"view_use_schema.go")
}
@@ -609,6 +653,7 @@ package enum
import "github.com/go-jet/jet/v2/postgres"
+// Level enum
var Level = &struct {
Level1 postgres.StringExpression
Level2 postgres.StringExpression
@@ -706,6 +751,25 @@ type AllTypes struct {
}
`
+var linkModelContent = `
+//
+// Code generated by go-jet DO NOT EDIT.
+//
+// WARNING: Changes to this file may cause incorrect behavior
+// and will be lost if the code is regenerated
+//
+
+package model
+
+// Link table
+type Link struct {
+ ID int64 ` + "`sql:\"primary_key\"`" + ` // this is link id
+ URL string // link url
+ Name string // Unicode characters comment ₲鬼佬℧⇄↻
+ Description *string // '"Z\%_
+}
+`
+
var allTypesTableContent = `
//
// Code generated by go-jet DO NOT EDIT.
@@ -1062,3 +1126,91 @@ func newSampleRangesTableImpl(schemaName, tableName, alias string) sampleRangesT
}
}
`
+
+var linkTableContent = `
+//
+// Code generated by go-jet DO NOT EDIT.
+//
+// WARNING: Changes to this file may cause incorrect behavior
+// and will be lost if the code is regenerated
+//
+
+package table
+
+import (
+ "github.com/go-jet/jet/v2/postgres"
+)
+
+var Link = newLinkTable("test_sample", "link", "")
+
+// Link table
+type linkTable struct {
+ postgres.Table
+
+ // Columns
+ ID postgres.ColumnInteger // this is link id
+ URL postgres.ColumnString // link url
+ Name postgres.ColumnString // Unicode characters comment ₲鬼佬℧⇄↻
+ Description postgres.ColumnString // '"Z\%_
+
+ AllColumns postgres.ColumnList
+ MutableColumns postgres.ColumnList
+}
+
+type LinkTable struct {
+ linkTable
+
+ EXCLUDED linkTable
+}
+
+// AS creates new LinkTable with assigned alias
+func (a LinkTable) AS(alias string) *LinkTable {
+ return newLinkTable(a.SchemaName(), a.TableName(), alias)
+}
+
+// Schema creates new LinkTable with assigned schema name
+func (a LinkTable) FromSchema(schemaName string) *LinkTable {
+ return newLinkTable(schemaName, a.TableName(), a.Alias())
+}
+
+// WithPrefix creates new LinkTable with assigned table prefix
+func (a LinkTable) WithPrefix(prefix string) *LinkTable {
+ return newLinkTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
+}
+
+// WithSuffix creates new LinkTable with assigned table suffix
+func (a LinkTable) WithSuffix(suffix string) *LinkTable {
+ return newLinkTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
+}
+
+func newLinkTable(schemaName, tableName, alias string) *LinkTable {
+ return &LinkTable{
+ linkTable: newLinkTableImpl(schemaName, tableName, alias),
+ EXCLUDED: newLinkTableImpl("", "excluded", ""),
+ }
+}
+
+func newLinkTableImpl(schemaName, tableName, alias string) linkTable {
+ var (
+ IDColumn = postgres.IntegerColumn("id")
+ URLColumn = postgres.StringColumn("url")
+ NameColumn = postgres.StringColumn("name")
+ DescriptionColumn = postgres.StringColumn("description")
+ allColumns = postgres.ColumnList{IDColumn, URLColumn, NameColumn, DescriptionColumn}
+ mutableColumns = postgres.ColumnList{URLColumn, NameColumn, DescriptionColumn}
+ )
+
+ return linkTable{
+ Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
+
+ //Columns
+ ID: IDColumn,
+ URL: URLColumn,
+ Name: NameColumn,
+ Description: DescriptionColumn,
+
+ AllColumns: allColumns,
+ MutableColumns: mutableColumns,
+ }
+}
+`
diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go
index 6f05a18..c375396 100644
--- a/tests/postgres/insert_test.go
+++ b/tests/postgres/insert_test.go
@@ -93,9 +93,9 @@ func TestInsertOnConflict(t *testing.T) {
ON_CONFLICT().DO_NOTHING()
testutils.AssertStatementSql(t, stmt, `
-INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id)
-VALUES ($1, $2, $3, $4, $5),
- ($6, $7, $8, $9, $10)
+INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual)
+VALUES ($1, $2, $3, $4, $5, $6),
+ ($7, $8, $9, $10, $11, $12)
ON CONFLICT DO NOTHING;
`)
testutils.AssertExecAndRollback(t, stmt, db, 1)
@@ -111,9 +111,9 @@ ON CONFLICT DO NOTHING;
ON_CONFLICT(Employee.EmployeeID).DO_NOTHING()
testutils.AssertStatementSql(t, stmt, `
-INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id)
-VALUES ($1, $2, $3, $4, $5),
- ($6, $7, $8, $9, $10)
+INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual)
+VALUES ($1, $2, $3, $4, $5, $6),
+ ($7, $8, $9, $10, $11, $12)
ON CONFLICT (employee_id) DO NOTHING;
`)
testutils.AssertExecAndRollback(t, stmt, db, 1)
@@ -130,9 +130,9 @@ ON CONFLICT (employee_id) DO NOTHING;
ON_CONFLICT().ON_CONSTRAINT("employee_pkey").DO_NOTHING()
testutils.AssertStatementSql(t, stmt, `
-INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id)
-VALUES ($1, $2, $3, $4, $5),
- ($6, $7, $8, $9, $10)
+INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual)
+VALUES ($1, $2, $3, $4, $5, $6),
+ ($7, $8, $9, $10, $11, $12)
ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING;
`)
testutils.AssertExecAndRollback(t, stmt, db, 1)
@@ -234,8 +234,8 @@ ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE
ON_CONFLICT().DO_UPDATE(nil)
testutils.AssertStatementSql(t, stmt, `
-INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id)
-VALUES ($1, $2, $3, $4, $5);
+INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual)
+VALUES ($1, $2, $3, $4, $5, $6);
`)
testutils.AssertExecAndRollback(t, stmt, db, 1)
requireLogged(t, stmt)
diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go
index 1251c93..1f02e69 100644
--- a/tests/postgres/main_test.go
+++ b/tests/postgres/main_test.go
@@ -5,13 +5,10 @@ import (
"database/sql"
"fmt"
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
- "math/rand"
+ "github.com/jackc/pgx/v4/stdlib"
"os"
"runtime"
"testing"
- "time"
-
- "github.com/jackc/pgx/v4/stdlib"
"github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/dbconfig"
@@ -44,7 +41,6 @@ func skipForCockroachDB(t *testing.T) {
}
func TestMain(m *testing.M) {
- rand.Seed(time.Now().Unix())
defer profile.Start().Stop()
setTestRoot()
diff --git a/tests/postgres/northwind_test.go b/tests/postgres/northwind_test.go
index 3084b0d..7aad0d4 100644
--- a/tests/postgres/northwind_test.go
+++ b/tests/postgres/northwind_test.go
@@ -11,34 +11,31 @@ import (
func TestNorthwindJoinEverything(t *testing.T) {
- stmt := SELECT(
- Customers.AllColumns,
- CustomerDemographics.AllColumns,
- Orders.AllColumns,
- Shippers.AllColumns,
- OrderDetails.AllColumns,
- Products.AllColumns,
- Categories.AllColumns,
- Suppliers.AllColumns,
- ).FROM(
- Customers.
- LEFT_JOIN(CustomerCustomerDemo, Customers.CustomerID.EQ(CustomerCustomerDemo.CustomerID)).
- LEFT_JOIN(CustomerDemographics, CustomerCustomerDemo.CustomerTypeID.EQ(CustomerDemographics.CustomerTypeID)).
- LEFT_JOIN(Orders, Orders.CustomerID.EQ(Customers.CustomerID)).
- LEFT_JOIN(Shippers, Orders.ShipVia.EQ(Shippers.ShipperID)).
- LEFT_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID)).
- LEFT_JOIN(Products, OrderDetails.ProductID.EQ(Products.ProductID)).
- LEFT_JOIN(Categories, Products.CategoryID.EQ(Categories.CategoryID)).
- LEFT_JOIN(Suppliers, Products.SupplierID.EQ(Suppliers.SupplierID)).
- LEFT_JOIN(Employees, Orders.EmployeeID.EQ(Employees.EmployeeID)).
- LEFT_JOIN(EmployeeTerritories, EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID)).
- LEFT_JOIN(Territories, EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)).
- LEFT_JOIN(Region, Territories.RegionID.EQ(Region.RegionID)),
- ).ORDER_BY(
- Customers.CustomerID,
- Orders.OrderID,
- Products.ProductID,
- )
+ stmt :=
+ SELECT(
+ Customers.AllColumns,
+ CustomerDemographics.AllColumns,
+ Orders.AllColumns,
+ Shippers.AllColumns,
+ OrderDetails.AllColumns,
+ Products.AllColumns,
+ Categories.AllColumns,
+ Suppliers.AllColumns,
+ ).FROM(
+ Customers.
+ LEFT_JOIN(CustomerCustomerDemo, Customers.CustomerID.EQ(CustomerCustomerDemo.CustomerID)).
+ LEFT_JOIN(CustomerDemographics, CustomerCustomerDemo.CustomerTypeID.EQ(CustomerDemographics.CustomerTypeID)).
+ LEFT_JOIN(Orders, Orders.CustomerID.EQ(Customers.CustomerID)).
+ LEFT_JOIN(Shippers, Orders.ShipVia.EQ(Shippers.ShipperID)).
+ LEFT_JOIN(OrderDetails, Orders.OrderID.EQ(OrderDetails.OrderID)).
+ LEFT_JOIN(Products, OrderDetails.ProductID.EQ(Products.ProductID)).
+ LEFT_JOIN(Categories, Products.CategoryID.EQ(Categories.CategoryID)).
+ LEFT_JOIN(Suppliers, Products.SupplierID.EQ(Suppliers.SupplierID)).
+ LEFT_JOIN(Employees, Orders.EmployeeID.EQ(Employees.EmployeeID)).
+ LEFT_JOIN(EmployeeTerritories, EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID)).
+ LEFT_JOIN(Territories, EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)).
+ LEFT_JOIN(Region, Territories.RegionID.EQ(Region.RegionID)),
+ ).ORDER_BY(Customers.CustomerID, Orders.OrderID, Products.ProductID)
var dest []struct {
model.Customers
diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go
index 3d4d011..73b652c 100644
--- a/tests/postgres/sample_test.go
+++ b/tests/postgres/sample_test.go
@@ -2,6 +2,7 @@ package postgres
import (
"github.com/go-jet/jet/v2/qrm"
+ "github.com/go-jet/jet/v2/internal/utils/ptr"
"github.com/google/uuid"
"testing"
@@ -63,15 +64,15 @@ func TestExactDecimals(t *testing.T) {
Floats: model.Floats{
// overwritten by wrapped(floats) scope
Numeric: 0.1,
- NumericPtr: testutils.Float64Ptr(0.1),
+ NumericPtr: ptr.Of(0.1),
Decimal: 0.1,
- DecimalPtr: testutils.Float64Ptr(0.1),
+ DecimalPtr: ptr.Of(0.1),
// not overwritten
Real: 0.4,
- RealPtr: testutils.Float32Ptr(0.44),
+ RealPtr: ptr.Of(float32(0.44)),
Double: 0.3,
- DoublePtr: testutils.Float64Ptr(0.33),
+ DoublePtr: ptr.Of(0.33),
},
Numeric: decimal.RequireFromString("0.1234567890123456789"),
NumericPtr: decimal.RequireFromString("1.1111111111111111111"),
@@ -330,11 +331,13 @@ SELECT employee.employee_id AS "employee.employee_id",
employee.last_name AS "employee.last_name",
employee.employment_date AS "employee.employment_date",
employee.manager_id AS "employee.manager_id",
+ employee.pto_accrual AS "employee.pto_accrual",
manager.employee_id AS "manager.employee_id",
manager.first_name AS "manager.first_name",
manager.last_name AS "manager.last_name",
manager.employment_date AS "manager.employment_date",
- manager.manager_id AS "manager.manager_id"
+ manager.manager_id AS "manager.manager_id",
+ manager.pto_accrual AS "manager.pto_accrual"
FROM test_sample.employee
LEFT JOIN test_sample.employee AS manager ON (manager.employee_id = employee.manager_id)
ORDER BY employee.employee_id;
@@ -369,6 +372,7 @@ ORDER BY employee.employee_id;
LastName: "Hays",
EmploymentDate: testutils.TimestampWithTimeZone("1999-01-08 04:05:06.1 +0100 CET", 1),
ManagerID: nil,
+ PtoAccrual: ptr.Of("22:00:00"),
})
require.True(t, dest[0].Manager == nil)
@@ -378,7 +382,7 @@ ORDER BY employee.employee_id;
FirstName: "Salley",
LastName: "Lester",
EmploymentDate: testutils.TimestampWithTimeZone("1999-01-08 04:05:06 +0100 CET", 1),
- ManagerID: testutils.Int32Ptr(3),
+ ManagerID: ptr.Of(int32(3)),
})
}
@@ -420,7 +424,7 @@ FROM test_sample."WEIRD NAMES TABLE";
WeirdColumnName5: "Doe",
WeirdColumnName6: "Doe",
WeirdColumnName7: "Doe",
- Weirdcolumnname8: testutils.StringPtr("Doe"),
+ Weirdcolumnname8: ptr.Of("Doe"),
WeirdColName9: "Doe",
WeirdColuName10: "Doe",
WeirdColuName11: "Doe",
@@ -518,7 +522,7 @@ func TestMutableColumnsExcludeGeneratedColumn(t *testing.T) {
).MODEL(
model.People{
PeopleName: "Dario",
- PeopleHeightCm: testutils.Float64Ptr(120),
+ PeopleHeightCm: ptr.Of(120.0),
},
).RETURNING(
People.MutableColumns,
diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go
index 24b5949..321fc38 100644
--- a/tests/postgres/scan_test.go
+++ b/tests/postgres/scan_test.go
@@ -2,6 +2,7 @@ package postgres
import (
"context"
+ "github.com/go-jet/jet/v2/internal/utils/ptr"
"github.com/volatiletech/null/v8"
"testing"
"time"
@@ -93,7 +94,7 @@ func TestScanToValidDestination(t *testing.T) {
err := oneInventoryQuery.Query(db, &dest)
require.NoError(t, err)
- require.Equal(t, dest[0], testutils.Int32Ptr(1))
+ require.Equal(t, dest[0], ptr.Of(int32(1)))
})
t.Run("NULL to integer", func(t *testing.T) {
@@ -530,10 +531,10 @@ func TestScanToSlice(t *testing.T) {
require.NoError(t, err)
require.Equal(t, len(dest), 2)
testutils.AssertDeepEqual(t, dest[0].Film, film1)
- testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{testutils.Int32Ptr(1), testutils.Int32Ptr(2), testutils.Int32Ptr(3), testutils.Int32Ptr(4),
- testutils.Int32Ptr(5), testutils.Int32Ptr(6), testutils.Int32Ptr(7), testutils.Int32Ptr(8)})
+ testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{ptr.Of(int32(1)), ptr.Of(int32(2)), ptr.Of(int32(3)), ptr.Of(int32(4)),
+ ptr.Of(int32(5)), ptr.Of(int32(6)), ptr.Of(int32(7)), ptr.Of(int32(8))})
testutils.AssertDeepEqual(t, dest[1].Film, film2)
- testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{testutils.Int32Ptr(9), testutils.Int32Ptr(10)})
+ testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{ptr.Of(int32(9)), ptr.Of(int32(10))})
})
t.Run("complex struct 1", func(t *testing.T) {
@@ -1076,10 +1077,10 @@ VALUES (1234, 0, 'Joe', '', NULL, 1, TRUE, '2020-02-02 10:00:00Z', NULL, 1);
var address256 = model.Address{
AddressID: 256,
Address: "1497 Yuzhou Drive",
- Address2: testutils.StringPtr(""),
+ Address2: ptr.Of(""),
District: "England",
CityID: 312,
- PostalCode: testutils.StringPtr("3433"),
+ PostalCode: ptr.Of("3433"),
Phone: "246810237916",
LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 09:45:30", 0),
}
@@ -1087,10 +1088,10 @@ var address256 = model.Address{
var addres517 = model.Address{
AddressID: 517,
Address: "548 Uruapan Street",
- Address2: testutils.StringPtr(""),
+ Address2: ptr.Of(""),
District: "Ontario",
CityID: 312,
- PostalCode: testutils.StringPtr("35653"),
+ PostalCode: ptr.Of("35653"),
Phone: "879347453467",
LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 09:45:30", 0),
}
@@ -1100,12 +1101,12 @@ var customer256 = model.Customer{
StoreID: 2,
FirstName: "Mattie",
LastName: "Hoffman",
- Email: testutils.StringPtr("mattie.hoffman@sakilacustomer.org"),
+ Email: ptr.Of("mattie.hoffman@sakilacustomer.org"),
AddressID: 256,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 0),
- Active: testutils.Int32Ptr(1),
+ Active: ptr.Of(int32(1)),
}
var customer512 = model.Customer{
@@ -1113,12 +1114,12 @@ var customer512 = model.Customer{
StoreID: 1,
FirstName: "Cecil",
LastName: "Vines",
- Email: testutils.StringPtr("cecil.vines@sakilacustomer.org"),
+ Email: ptr.Of("cecil.vines@sakilacustomer.org"),
AddressID: 517,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 0),
- Active: testutils.Int32Ptr(1),
+ Active: ptr.Of(int32(1)),
}
var countryUk = model.Country{
@@ -1151,32 +1152,32 @@ var inventory2 = model.Inventory{
var film1 = model.Film{
FilmID: 1,
Title: "Academy Dinosaur",
- Description: testutils.StringPtr("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"),
- ReleaseYear: testutils.Int32Ptr(2006),
+ Description: ptr.Of("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"),
+ ReleaseYear: ptr.Of(int32(2006)),
LanguageID: 1,
RentalDuration: 6,
RentalRate: 0.99,
- Length: testutils.Int16Ptr(86),
+ Length: ptr.Of(int16(86)),
ReplacementCost: 20.99,
Rating: &pgRating,
LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3),
- SpecialFeatures: testutils.StringPtr("{\"Deleted Scenes\",\"Behind the Scenes\"}"),
+ SpecialFeatures: ptr.Of("{\"Deleted Scenes\",\"Behind the Scenes\"}"),
Fulltext: "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17",
}
var film2 = model.Film{
FilmID: 2,
Title: "Ace Goldfinger",
- Description: testutils.StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"),
- ReleaseYear: testutils.Int32Ptr(2006),
+ Description: ptr.Of("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"),
+ ReleaseYear: ptr.Of(int32(2006)),
LanguageID: 1,
RentalDuration: 3,
RentalRate: 4.99,
- Length: testutils.Int16Ptr(48),
+ Length: ptr.Of(int16(48)),
ReplacementCost: 12.99,
Rating: &gRating,
LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3),
- SpecialFeatures: testutils.StringPtr(`{Trailers,"Deleted Scenes"}`),
+ SpecialFeatures: ptr.Of(`{Trailers,"Deleted Scenes"}`),
Fulltext: `'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14`,
}
diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go
index 0bcbffd..8bc816e 100644
--- a/tests/postgres/select_test.go
+++ b/tests/postgres/select_test.go
@@ -3,6 +3,7 @@ package postgres
import (
"context"
"database/sql"
+ "github.com/go-jet/jet/v2/internal/utils/ptr"
"testing"
"time"
@@ -1828,16 +1829,16 @@ ORDER BY film.film_id ASC;
testutils.AssertDeepEqual(t, maxRentalRateFilms[0], model.Film{
FilmID: 2,
Title: "Ace Goldfinger",
- Description: testutils.StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"),
- ReleaseYear: testutils.Int32Ptr(2006),
+ Description: ptr.Of("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"),
+ ReleaseYear: ptr.Of(int32(2006)),
LanguageID: 1,
RentalRate: 4.99,
- Length: testutils.Int16Ptr(48),
+ Length: ptr.Of(int16(48)),
ReplacementCost: 12.99,
Rating: &gRating,
RentalDuration: 3,
LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3),
- SpecialFeatures: testutils.StringPtr("{Trailers,\"Deleted Scenes\"}"),
+ SpecialFeatures: ptr.Of("{Trailers,\"Deleted Scenes\"}"),
Fulltext: "'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14",
})
}
@@ -2286,11 +2287,11 @@ ORDER BY customer_payment_sum.amount_sum ASC;
FirstName: "Brian",
LastName: "Wyman",
AddressID: 323,
- Email: testutils.StringPtr("brian.wyman@sakilacustomer.org"),
+ Email: ptr.Of("brian.wyman@sakilacustomer.org"),
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
- Active: testutils.Int32Ptr(1),
+ Active: ptr.Of(int32(1)),
})
require.Equal(t, customersWithAmounts[0].AmountSum, 27.93)
@@ -3110,8 +3111,8 @@ func TestSelectDynamicCondition(t *testing.T) {
Active *bool
}
- request.CustomerID = testutils.Int64Ptr(1)
- request.Active = testutils.BoolPtr(true)
+ request.CustomerID = ptr.Of(int64(1))
+ request.Active = ptr.Of(true)
// ...
@@ -3871,12 +3872,12 @@ var customer0 = model.Customer{
StoreID: 1,
FirstName: "Mary",
LastName: "Smith",
- Email: testutils.StringPtr("mary.smith@sakilacustomer.org"),
+ Email: ptr.Of("mary.smith@sakilacustomer.org"),
AddressID: 5,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
- Active: testutils.Int32Ptr(1),
+ Active: ptr.Of(int32(1)),
}
var customer1 = model.Customer{
@@ -3884,12 +3885,12 @@ var customer1 = model.Customer{
StoreID: 1,
FirstName: "Patricia",
LastName: "Johnson",
- Email: testutils.StringPtr("patricia.johnson@sakilacustomer.org"),
+ Email: ptr.Of("patricia.johnson@sakilacustomer.org"),
AddressID: 6,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
- Active: testutils.Int32Ptr(1),
+ Active: ptr.Of(int32(1)),
}
var lastCustomer = model.Customer{
@@ -3897,10 +3898,10 @@ var lastCustomer = model.Customer{
StoreID: 2,
FirstName: "Austin",
LastName: "Cintron",
- Email: testutils.StringPtr("austin.cintron@sakilacustomer.org"),
+ Email: ptr.Of("austin.cintron@sakilacustomer.org"),
AddressID: 605,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
- Active: testutils.Int32Ptr(1),
+ Active: ptr.Of(int32(1)),
}
diff --git a/tests/postgres/values_test.go b/tests/postgres/values_test.go
new file mode 100644
index 0000000..9e89f7e
--- /dev/null
+++ b/tests/postgres/values_test.go
@@ -0,0 +1,284 @@
+package postgres
+
+import (
+ "database/sql"
+ "github.com/go-jet/jet/v2/internal/testutils"
+ . "github.com/go-jet/jet/v2/postgres"
+ "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
+ . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "testing"
+ "time"
+)
+
+func TestVALUES(t *testing.T) {
+
+ values := VALUES(
+ WRAP(Int32(1), Int32(2), Float32(4.666), Bool(false), String("txt")),
+ WRAP(Int32(11).ADD(Int32(2)), Int32(22), Float32(33.222), Bool(true), String("png")),
+ WRAP(Int32(11), Int32(22), Float32(33.222), Bool(true), NULL),
+ ).AS("values_table")
+
+ stmt := SELECT(
+ values.AllColumns(),
+ ).FROM(
+ values,
+ )
+
+ testutils.AssertStatementSql(t, stmt, `
+SELECT values_table.column1 AS "column1",
+ values_table.column2 AS "column2",
+ values_table.column3 AS "column3",
+ values_table.column4 AS "column4",
+ values_table.column5 AS "column5"
+FROM (
+ VALUES ($1::integer, $2::integer, $3::real, $4::boolean, $5::text),
+ ($6::integer + $7::integer, $8::integer, $9::real, $10::boolean, $11::text),
+ ($12::integer, $13::integer, $14::real, $15::boolean, NULL)
+ ) AS values_table;
+`)
+
+ var dest []struct {
+ Column1 int
+ Column2 int
+ Column3 float32
+ Column4 bool
+ Column5 *string
+ }
+
+ err := stmt.Query(db, &dest)
+
+ require.NoError(t, err)
+ testutils.AssertJSON(t, dest, `
+[
+ {
+ "Column1": 1,
+ "Column2": 2,
+ "Column3": 4.666,
+ "Column4": false,
+ "Column5": "txt"
+ },
+ {
+ "Column1": 13,
+ "Column2": 22,
+ "Column3": 33.222,
+ "Column4": true,
+ "Column5": "png"
+ },
+ {
+ "Column1": 11,
+ "Column2": 22,
+ "Column3": 33.222,
+ "Column4": true,
+ "Column5": null
+ }
+]
+`)
+}
+
+func TestVALUES_Join(t *testing.T) {
+
+ title := StringColumn("title")
+ releaseYear := IntegerColumn("ReleaseYear")
+ rentalRate := FloatColumn("rental_rate")
+
+ lastUpdate := Timestamp(2007, time.February, 11, 12, 0, 0)
+
+ filmValues := VALUES(
+ WRAP(String("Chamber Italian"), Int64(117), Int32(2005), Float32(5.82), lastUpdate),
+ WRAP(String("Grosse Wonderful"), Int64(49), Int32(2004), Float32(6.242), lastUpdate.ADD(INTERVAL(1, HOUR))),
+ WRAP(String("Airport Pollock"), Int64(54), Int32(2001), Float32(7.22), NULL),
+ WRAP(String("Bright Encounters"), Int64(73), Int32(2002), Float32(8.25), NULL),
+ WRAP(String("Academy Dinosaur"), Int64(83), Int32(2010), Float32(9.22), lastUpdate.SUB(INTERVAL(2, MINUTE))),
+ ).AS("film_values",
+ title, IntegerColumn("length"), releaseYear, rentalRate, TimestampColumn("update_date"))
+
+ stmt := SELECT(
+ Film.AllColumns,
+ filmValues.AllColumns(),
+ ).FROM(
+ Film.
+ INNER_JOIN(filmValues, title.EQ(Film.Title)),
+ ).WHERE(AND(
+ Film.ReleaseYear.GT(releaseYear),
+ Film.RentalRate.LT(rentalRate),
+ )).ORDER_BY(
+ title,
+ )
+
+ testutils.AssertDebugStatementSql(t, stmt, `
+SELECT film.film_id AS "film.film_id",
+ film.title AS "film.title",
+ film.description AS "film.description",
+ film.release_year AS "film.release_year",
+ film.language_id AS "film.language_id",
+ film.rental_duration AS "film.rental_duration",
+ film.rental_rate AS "film.rental_rate",
+ film.length AS "film.length",
+ film.replacement_cost AS "film.replacement_cost",
+ film.rating AS "film.rating",
+ film.last_update AS "film.last_update",
+ film.special_features AS "film.special_features",
+ film.fulltext AS "film.fulltext",
+ film_values.title AS "title",
+ film_values.length AS "length",
+ film_values."ReleaseYear" AS "ReleaseYear",
+ film_values.rental_rate AS "rental_rate",
+ film_values.update_date AS "update_date"
+FROM dvds.film
+ INNER JOIN (
+ VALUES ('Chamber Italian'::text, 117::bigint, 2005::integer, 5.820000171661377::real, '2007-02-11 12:00:00'::timestamp without time zone),
+ ('Grosse Wonderful'::text, 49::bigint, 2004::integer, 6.242000102996826::real, '2007-02-11 12:00:00'::timestamp without time zone + INTERVAL '1 HOUR'),
+ ('Airport Pollock'::text, 54::bigint, 2001::integer, 7.21999979019165::real, NULL),
+ ('Bright Encounters'::text, 73::bigint, 2002::integer, 8.25::real, NULL),
+ ('Academy Dinosaur'::text, 83::bigint, 2010::integer, 9.220000267028809::real, '2007-02-11 12:00:00'::timestamp without time zone - INTERVAL '2 MINUTE')
+ ) AS film_values (title, length, "ReleaseYear", rental_rate, update_date) ON (film_values.title = film.title)
+WHERE (
+ (film.release_year > film_values."ReleaseYear")
+ AND (film.rental_rate < film_values.rental_rate)
+ )
+ORDER BY film_values.title;
+`)
+
+ //fmt.Println(stmt.DebugSql())
+
+ var dest []struct {
+ Film model.Film
+
+ Title string
+ Length int
+ ReleaseYear int
+ RentalRate float32
+ UpdateDate *time.Time
+ }
+
+ err := stmt.Query(db, &dest)
+ require.NoError(t, err)
+
+ assert.Len(t, dest, 4)
+ testutils.AssertJSON(t, dest[0:2], `
+[
+ {
+ "Film": {
+ "FilmID": 8,
+ "Title": "Airport Pollock",
+ "Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India",
+ "ReleaseYear": 2006,
+ "LanguageID": 1,
+ "RentalDuration": 6,
+ "RentalRate": 4.99,
+ "Length": 54,
+ "ReplacementCost": 15.99,
+ "Rating": "R",
+ "LastUpdate": "2013-05-26T14:50:58.951Z",
+ "SpecialFeatures": "{Trailers}",
+ "Fulltext": "'airport':1 'ancient':18 'confront':14 'epic':4 'girl':11 'india':19 'monkey':16 'moos':8 'must':13 'pollock':2 'tale':5"
+ },
+ "Title": "Airport Pollock",
+ "Length": 54,
+ "ReleaseYear": 2001,
+ "RentalRate": 7.22,
+ "UpdateDate": null
+ },
+ {
+ "Film": {
+ "FilmID": 98,
+ "Title": "Bright Encounters",
+ "Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat",
+ "ReleaseYear": 2006,
+ "LanguageID": 1,
+ "RentalDuration": 4,
+ "RentalRate": 4.99,
+ "Length": 73,
+ "ReplacementCost": 12.99,
+ "Rating": "PG-13",
+ "LastUpdate": "2013-05-26T14:50:58.951Z",
+ "SpecialFeatures": "{Trailers}",
+ "Fulltext": "'boat':20 'bright':1 'conquer':14 'encount':2 'fate':4 'feminist':11 'jet':19 'lumberjack':8 'must':13 'student':16 'yarn':5"
+ },
+ "Title": "Bright Encounters",
+ "Length": 73,
+ "ReleaseYear": 2002,
+ "RentalRate": 8.25,
+ "UpdateDate": null
+ }
+]
+`)
+}
+
+func TestVALUES_CTE_Update(t *testing.T) {
+
+ paymentID := IntegerColumn("payment_ID")
+ increase := FloatColumn("increase")
+ paymentsToUpdate := CTE("values_cte", paymentID, increase)
+
+ stmt := WITH(
+ paymentsToUpdate.AS(
+ VALUES(
+ WRAP(Int32(20564), Float32(1.21)),
+ WRAP(Int32(20567), Float32(1.02)),
+ WRAP(Int32(20570), Float32(1.34)),
+ WRAP(Int32(20573), Float32(1.72)),
+ ),
+ ),
+ )(
+ Payment.UPDATE().
+ SET(
+ Payment.Amount.SET(Payment.Amount.MUL(CAST(increase).AS_DECIMAL())),
+ ).
+ FROM(paymentsToUpdate).
+ WHERE(Payment.PaymentID.EQ(paymentID)).
+ RETURNING(Payment.AllColumns),
+ )
+
+ testutils.AssertDebugStatementSql(t, stmt, `
+WITH values_cte ("payment_ID", increase) AS (
+ VALUES (20564::integer, 1.2100000381469727::real),
+ (20567::integer, 1.0199999809265137::real),
+ (20570::integer, 1.340000033378601::real),
+ (20573::integer, 1.7200000286102295::real)
+)
+UPDATE dvds.payment
+SET amount = (payment.amount * values_cte.increase::decimal)
+FROM values_cte
+WHERE payment.payment_id = values_cte."payment_ID"
+RETURNING payment.payment_id AS "payment.payment_id",
+ payment.customer_id AS "payment.customer_id",
+ payment.staff_id AS "payment.staff_id",
+ payment.rental_id AS "payment.rental_id",
+ payment.amount AS "payment.amount",
+ payment.payment_date AS "payment.payment_date";
+`)
+
+ testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
+
+ var payments []model.Payment
+
+ err := stmt.Query(tx, &payments)
+ require.NoError(t, err)
+
+ assert.Len(t, payments, 4)
+ testutils.AssertJSON(t, payments[0:2], `
+[
+ {
+ "PaymentID": 20564,
+ "CustomerID": 379,
+ "StaffID": 2,
+ "RentalID": 11457,
+ "Amount": 4.83,
+ "PaymentDate": "2007-03-02T19:42:42.996577Z"
+ },
+ {
+ "PaymentID": 20567,
+ "CustomerID": 379,
+ "StaffID": 2,
+ "RentalID": 13397,
+ "Amount": 8.15,
+ "PaymentDate": "2007-03-19T20:35:01.996577Z"
+ }
+]
+`)
+ })
+
+}
diff --git a/tests/postgres/with_test.go b/tests/postgres/with_test.go
index 21fca32..92b5649 100644
--- a/tests/postgres/with_test.go
+++ b/tests/postgres/with_test.go
@@ -83,10 +83,10 @@ SELECT orders.ship_region AS "orders.ship_region",
SUM(order_details.quantity) AS "product_sales"
FROM northwind.orders
INNER JOIN northwind.order_details ON (orders.order_id = order_details.order_id)
-WHERE orders.ship_region IN (
+WHERE orders.ship_region IN ((
SELECT top_region."orders.ship_region" AS "orders.ship_region"
FROM top_region
- )
+ ))
GROUP BY orders.ship_region, order_details.product_id
ORDER BY SUM(order_details.quantity) DESC;
`)
@@ -157,19 +157,19 @@ func TestWithStatementDeleteAndInsert(t *testing.T) {
testutils.AssertStatementSql(t, stmt, `
WITH remove_discontinued_orders AS (
DELETE FROM northwind.order_details
- WHERE order_details.product_id IN (
+ WHERE order_details.product_id IN ((
SELECT products.product_id AS "products.product_id"
FROM northwind.products
WHERE products.discontinued = $1
- )
+ ))
RETURNING order_details.product_id AS "order_details.product_id"
),update_discontinued_price AS (
UPDATE northwind.products
SET unit_price = $2
- WHERE products.product_id IN (
+ WHERE products.product_id IN ((
SELECT remove_discontinued_orders."order_details.product_id" AS "order_details.product_id"
FROM remove_discontinued_orders
- )
+ ))
RETURNING products.product_id AS "products.product_id",
products.product_name AS "products.product_name",
products.supplier_id AS "products.supplier_id",
diff --git a/tests/sqlite/alltypes_test.go b/tests/sqlite/alltypes_test.go
index 58a9da6..080e870 100644
--- a/tests/sqlite/alltypes_test.go
+++ b/tests/sqlite/alltypes_test.go
@@ -2,6 +2,7 @@ package sqlite
import (
"github.com/go-jet/jet/v2/internal/testutils"
+ "github.com/go-jet/jet/v2/internal/utils/ptr"
. "github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table"
@@ -153,43 +154,43 @@ func TestAllTypesInsert(t *testing.T) {
var toInsert = model.AllTypes{
Boolean: false,
- BooleanPtr: testutils.BoolPtr(true),
+ BooleanPtr: ptr.Of(true),
TinyInt: 1,
SmallInt: 3,
MediumInt: 5,
Integer: 7,
BigInt: 9,
- TinyIntPtr: testutils.Int8Ptr(11),
- SmallIntPtr: testutils.Int16Ptr(33),
- MediumIntPtr: testutils.Int32Ptr(55),
- IntegerPtr: testutils.Int32Ptr(77),
- BigIntPtr: testutils.Int64Ptr(99),
+ TinyIntPtr: ptr.Of(int8(11)),
+ SmallIntPtr: ptr.Of(int16(33)),
+ MediumIntPtr: ptr.Of(int32(55)),
+ IntegerPtr: ptr.Of(int32(77)),
+ BigIntPtr: ptr.Of(int64(99)),
Decimal: 11.22,
- DecimalPtr: testutils.Float64Ptr(33.44),
+ DecimalPtr: ptr.Of(33.44),
Numeric: 55.66,
- NumericPtr: testutils.Float64Ptr(77.88),
+ NumericPtr: ptr.Of(77.88),
Float: 99.00,
- FloatPtr: testutils.Float64Ptr(11.22),
+ FloatPtr: ptr.Of(11.22),
Double: 33.44,
- DoublePtr: testutils.Float64Ptr(55.66),
+ DoublePtr: ptr.Of(55.66),
Real: 77.88,
- RealPtr: testutils.Float32Ptr(99.00),
+ RealPtr: ptr.Of(float32(99.00)),
Time: time.Date(1, 1, 1, 1, 1, 1, 10, time.UTC),
- TimePtr: testutils.TimePtr(time.Date(2, 2, 2, 2, 2, 2, 200, time.UTC)),
+ TimePtr: ptr.Of(time.Date(2, 2, 2, 2, 2, 2, 200, time.UTC)),
Date: time.Now(),
- DatePtr: testutils.TimePtr(time.Now()),
+ DatePtr: ptr.Of(time.Now()),
DateTime: time.Now(),
- DateTimePtr: testutils.TimePtr(time.Now()),
+ DateTimePtr: ptr.Of(time.Now()),
Timestamp: time.Now(),
- TimestampPtr: testutils.TimePtr(time.Now()),
+ TimestampPtr: ptr.Of(time.Now()),
Char: "abcd",
- CharPtr: testutils.StringPtr("absd"),
+ CharPtr: ptr.Of("absd"),
VarChar: "abcd",
- VarCharPtr: testutils.StringPtr("absd"),
+ VarCharPtr: ptr.Of("absd"),
Blob: []byte("large file"),
- BlobPtr: testutils.ByteArrayPtr([]byte("very large file")),
+ BlobPtr: ptr.Of([]byte("very large file")),
Text: "some text",
- TextPtr: testutils.StringPtr("text"),
+ TextPtr: ptr.Of("text"),
}
func TestUUID(t *testing.T) {
@@ -233,18 +234,18 @@ func TestExpressionOperators(t *testing.T) {
SELECT all_types.integer IS NULL AS "result.is_null",
all_types.date_ptr IS NOT NULL AS "result.is_not_null",
(all_types.small_int_ptr IN (?, ?)) AS "result.in",
- (all_types.small_int_ptr IN (
+ (all_types.small_int_ptr IN ((
SELECT all_types.integer AS "all_types.integer"
FROM all_types
- )) AS "result.in_select",
+ ))) AS "result.in_select",
(length(121232459)) AS "result.raw",
(? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg",
(? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2",
(all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in",
- (all_types.small_int_ptr NOT IN (
+ (all_types.small_int_ptr NOT IN ((
SELECT all_types.integer AS "all_types.integer"
FROM all_types
- )) AS "result.not_in_select"
+ ))) AS "result.not_in_select"
FROM all_types
LIMIT ?;
`, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2))
@@ -659,7 +660,7 @@ func TestExactDecimals(t *testing.T) {
// not overwritten
Numeric: "6.7",
- NumericPtr: testutils.StringPtr("7.7"),
+ NumericPtr: ptr.Of("7.7"),
},
Decimal: decimal.RequireFromString("91.23"),
DecimalPtr: decimal.RequireFromString("45.67"),
@@ -899,3 +900,36 @@ func TestDateTimeExpressions(t *testing.T) {
require.Equal(t, dest.JulianDay, 2.4551543576232754e+06)
require.Equal(t, dest.StrfTime, "20:34")
}
+
+func TestRowExpression(t *testing.T) {
+ date := Date(2000, 9, 9)
+ time := Time(11, 22, 11)
+ dateTime := DateTime(2008, 11, 22, 10, 12, 40)
+ dateTime2 := DateTime(2011, 1, 2, 5, 12, 40)
+
+ stmt := SELECT(
+ ROW(Bool(false), date).EQ(ROW(Bool(true), date)),
+ ROW(Bool(false), time).NOT_EQ(ROW(Bool(true), time)),
+ ROW(time).IS_DISTINCT_FROM(RowExp(Raw("(time('now'))"))),
+ ROW(dateTime, dateTime2).GT(ROW(dateTime, dateTime2)),
+ ROW(dateTime2).GT_EQ(ROW(dateTime)),
+ ROW(dateTime, dateTime2).LT(ROW(dateTime, dateTime2)),
+ ROW(dateTime2).LT_EQ(ROW(dateTime2)),
+ )
+
+ //fmt.Println(stmt.Sql())
+ //fmt.Println(stmt.DebugSql())
+
+ testutils.AssertDebugStatementSql(t, stmt, `
+SELECT (FALSE, DATE('2000-09-09')) = (TRUE, DATE('2000-09-09')),
+ (FALSE, TIME('11:22:11')) != (TRUE, TIME('11:22:11')),
+ (TIME('11:22:11')) IS NOT ((time('now'))),
+ (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')) > (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')),
+ (DATETIME('2011-01-02 05:12:40')) >= (DATETIME('2008-11-22 10:12:40')),
+ (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')) < (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')),
+ (DATETIME('2011-01-02 05:12:40')) <= (DATETIME('2011-01-02 05:12:40'));
+`)
+
+ err := stmt.Query(db, &struct{}{})
+ require.NoError(t, err)
+}
diff --git a/tests/sqlite/generator_test.go b/tests/sqlite/generator_test.go
index b232aba..8a15568 100644
--- a/tests/sqlite/generator_test.go
+++ b/tests/sqlite/generator_test.go
@@ -8,8 +8,11 @@ import (
"github.com/stretchr/testify/require"
+ "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/sqlite"
+ "github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/testutils"
+ sqlite2 "github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model"
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
)
@@ -58,6 +61,36 @@ func TestGenerator(t *testing.T) {
require.NoError(t, err)
}
+func TestGenerator_TableMetadata(t *testing.T) {
+ var schema metadata.Schema
+ err := sqlite.GenerateDSN(testDatabaseFilePath, genDestDir,
+ template.Default(sqlite2.Dialect).UseSchema(func(m metadata.Schema) template.Schema {
+ schema = m
+ return template.DefaultSchema(m)
+ }))
+ require.NoError(t, err)
+
+ // Spot check the actor table and assert that the emitted
+ // properties are as expected.
+ var got metadata.Table
+ for _, table := range schema.TablesMetaData {
+ if table.Name == "actor" {
+ got = table
+ }
+ }
+
+ want := metadata.Table{
+ Name: "actor",
+ Columns: []metadata.Column{
+ {Name: "actor_id", IsPrimaryKey: true, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "INTEGER", Kind: "base", IsUnsigned: false}, Comment: ""},
+ {Name: "first_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "VARCHAR", Kind: "base", IsUnsigned: false}, Comment: ""},
+ {Name: "last_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "VARCHAR", Kind: "base", IsUnsigned: false}, Comment: ""},
+ {Name: "last_update", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: true, DataType: metadata.DataType{Name: "TIMESTAMP", Kind: "base", IsUnsigned: false}, Comment: ""},
+ },
+ }
+ require.Equal(t, want, got)
+}
+
func TestCmdGenerator(t *testing.T) {
cmd := exec.Command("jet", "-source=SQLite", "-dsn=file://"+testDatabaseFilePath, "-path="+genDestDir)
diff --git a/tests/sqlite/insert_test.go b/tests/sqlite/insert_test.go
index 079ba4f..c504d90 100644
--- a/tests/sqlite/insert_test.go
+++ b/tests/sqlite/insert_test.go
@@ -3,6 +3,8 @@ package sqlite
import (
"context"
"github.com/go-jet/jet/v2/qrm"
+ "database/sql"
+ "github.com/go-jet/jet/v2/internal/utils/ptr"
"math/rand"
"testing"
@@ -49,7 +51,7 @@ VALUES (?, ?, ?, ?),
ID: 101,
URL: "http://www.google.com",
Name: "Google",
- Description: testutils.StringPtr("Search engine"),
+ Description: ptr.Of("Search engine"),
})
testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{
ID: 102,
diff --git a/tests/sqlite/main_test.go b/tests/sqlite/main_test.go
index 3e06124..593d630 100644
--- a/tests/sqlite/main_test.go
+++ b/tests/sqlite/main_test.go
@@ -8,16 +8,11 @@ import (
"github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/dbconfig"
- "github.com/stretchr/testify/require"
- "math/rand"
- "os"
- "os/exec"
- "runtime"
- "strings"
- "testing"
- "time"
-
"github.com/pkg/profile"
+ "github.com/stretchr/testify/require"
+ "os"
+ "runtime"
+ "testing"
_ "github.com/mattn/go-sqlite3"
)
@@ -27,11 +22,8 @@ var sampleDB *sqlite.DB
var testRoot string
func TestMain(m *testing.M) {
- rand.Seed(time.Now().Unix())
defer profile.Start().Stop()
- setTestRoot()
-
sqlDB, err := sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath)
throw.OnError(err)
db = sqlite.NewDB(sqlDB).WithStatementsCaching(true)
@@ -59,16 +51,6 @@ func TestMain(m *testing.M) {
}
}
-func setTestRoot() {
- cmd := exec.Command("git", "rev-parse", "--show-toplevel")
- byteArr, err := cmd.Output()
- if err != nil {
- panic(err)
- }
-
- testRoot = strings.TrimSpace(string(byteArr)) + "/tests/"
-}
-
var loggedSQL string
var loggedSQLArgs []interface{}
var loggedDebugSQL string
diff --git a/tests/sqlite/sample_test.go b/tests/sqlite/sample_test.go
index 1a68b4c..0655f1f 100644
--- a/tests/sqlite/sample_test.go
+++ b/tests/sqlite/sample_test.go
@@ -3,6 +3,7 @@ package sqlite
import (
"github.com/go-jet/jet/v2/internal/testutils"
"github.com/go-jet/jet/v2/qrm"
+ "github.com/go-jet/jet/v2/internal/utils/ptr"
"github.com/stretchr/testify/require"
"testing"
@@ -54,7 +55,7 @@ WHERE people.people_id = ?;
).MODEL(
model.People{
PeopleName: "Dario",
- PeopleHeightCm: testutils.Float64Ptr(190),
+ PeopleHeightCm: ptr.Of(190.0),
},
).RETURNING(
People.AllColumns,
diff --git a/tests/sqlite/select_test.go b/tests/sqlite/select_test.go
index 57ee587..52c89b3 100644
--- a/tests/sqlite/select_test.go
+++ b/tests/sqlite/select_test.go
@@ -2,6 +2,7 @@ package sqlite
import (
"context"
+ "github.com/go-jet/jet/v2/internal/utils/ptr"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/model"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/table"
"strings"
@@ -846,15 +847,15 @@ func TestSimpleView(t *testing.T) {
require.Equal(t, len(dest), 10)
require.Equal(t, dest[2], model.CustomerList{
- ID: testutils.Int32Ptr(3),
- Name: testutils.StringPtr("LINDA WILLIAMS"),
- Address: testutils.StringPtr("692 Joliet Street"),
- ZipCode: testutils.StringPtr("83579"),
- Phone: testutils.StringPtr(" "),
- City: testutils.StringPtr("Athenai"),
- Country: testutils.StringPtr("Greece"),
- Notes: testutils.StringPtr("active"),
- Sid: testutils.Int32Ptr(1),
+ ID: ptr.Of(int32(3)),
+ Name: ptr.Of("LINDA WILLIAMS"),
+ Address: ptr.Of("692 Joliet Street"),
+ ZipCode: ptr.Of("83579"),
+ Phone: ptr.Of(" "),
+ City: ptr.Of("Athenai"),
+ Country: ptr.Of("Greece"),
+ Notes: ptr.Of("active"),
+ Sid: ptr.Of(int32(1)),
})
}
diff --git a/tests/sqlite/values_test.go b/tests/sqlite/values_test.go
new file mode 100644
index 0000000..0793397
--- /dev/null
+++ b/tests/sqlite/values_test.go
@@ -0,0 +1,344 @@
+package sqlite
+
+import (
+ "database/sql"
+ "github.com/go-jet/jet/v2/internal/testutils"
+ "github.com/stretchr/testify/require"
+ "strings"
+ "testing"
+ "time"
+
+ . "github.com/go-jet/jet/v2/sqlite"
+ "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model"
+ . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table"
+)
+
+func TestVALUES(t *testing.T) {
+
+ values := VALUES(
+ ROW(Int32(1), Int32(2), Float(4.666), Bool(false), String("txt")),
+ ROW(Int32(11).ADD(Int32(2)), Int32(22), Float(33.222), Bool(true), String("png")),
+ ROW(Int32(11), Int32(22), Float(33.222), Bool(true), NULL),
+ ).AS("values_table")
+
+ stmt := SELECT(
+ values.AllColumns(),
+ ).FROM(
+ values,
+ )
+
+ testutils.AssertStatementSql(t, stmt, `
+SELECT values_table.column1 AS "column1",
+ values_table.column2 AS "column2",
+ values_table.column3 AS "column3",
+ values_table.column4 AS "column4",
+ values_table.column5 AS "column5"
+FROM (
+ VALUES (?, ?, ?, ?, ?),
+ (? + ?, ?, ?, ?, ?),
+ (?, ?, ?, ?, NULL)
+ ) AS values_table;
+`)
+
+ var dest []struct {
+ Column1 int
+ Column2 int
+ Column3 float32
+ Column4 bool
+ Column5 *string
+ }
+
+ err := stmt.Query(db, &dest)
+
+ require.NoError(t, err)
+ testutils.AssertJSON(t, dest, `
+[
+ {
+ "Column1": 1,
+ "Column2": 2,
+ "Column3": 4.666,
+ "Column4": false,
+ "Column5": "txt"
+ },
+ {
+ "Column1": 13,
+ "Column2": 22,
+ "Column3": 33.222,
+ "Column4": true,
+ "Column5": "png"
+ },
+ {
+ "Column1": 11,
+ "Column2": 22,
+ "Column3": 33.222,
+ "Column4": true,
+ "Column5": null
+ }
+]
+`)
+}
+
+func TestVALUES_Join(t *testing.T) {
+
+ lastUpdate := DateTime(2007, time.February, 11, 12, 0, 0)
+
+ films := VALUES(
+ ROW(String("Chamber Italian"), Int64(117), Int32(2005), Float(5.82), lastUpdate),
+ ROW(String("Grosse Wonderful"), Int64(49), Int32(2004), Float(6.242), lastUpdate),
+ ROW(String("Airport Pollock"), Int64(54), Int32(2001), Float(7.22), NULL),
+ ROW(String("Bright Encounters"), Int64(73), Int32(2002), Float(8.25), NULL),
+ ROW(String("Academy Dinosaur"), Int64(83), Int32(2010), Float(9.22), DATETIME(lastUpdate, YEARS(2))),
+ ).AS("film_values")
+
+ title := StringColumn("column1").From(films)
+ releaseYear := IntegerColumn("column3").From(films)
+ rentalRate := FloatColumn("column4").From(films)
+
+ stmt := SELECT(
+ Film.AllColumns,
+ films.AllColumns(),
+ ).FROM(
+ Film.
+ INNER_JOIN(films, LOWER(title).EQ(LOWER(Film.Title))),
+ ).WHERE(AND(
+ CAST(Film.ReleaseYear).AS_INTEGER().GT(releaseYear),
+ Film.RentalRate.LT(rentalRate),
+ )).ORDER_BY(
+ title,
+ )
+
+ testutils.AssertDebugStatementSql(t, stmt, `
+SELECT film.film_id AS "film.film_id",
+ film.title AS "film.title",
+ film.description AS "film.description",
+ film.release_year AS "film.release_year",
+ film.language_id AS "film.language_id",
+ film.original_language_id AS "film.original_language_id",
+ film.rental_duration AS "film.rental_duration",
+ film.rental_rate AS "film.rental_rate",
+ film.length AS "film.length",
+ film.replacement_cost AS "film.replacement_cost",
+ film.rating AS "film.rating",
+ film.special_features AS "film.special_features",
+ film.last_update AS "film.last_update",
+ film_values.column1 AS "column1",
+ film_values.column2 AS "column2",
+ film_values.column3 AS "column3",
+ film_values.column4 AS "column4",
+ film_values.column5 AS "column5"
+FROM film
+ INNER JOIN (
+ VALUES ('Chamber Italian', 117, 2005, 5.82, DATETIME('2007-02-11 12:00:00')),
+ ('Grosse Wonderful', 49, 2004, 6.242, DATETIME('2007-02-11 12:00:00')),
+ ('Airport Pollock', 54, 2001, 7.22, NULL),
+ ('Bright Encounters', 73, 2002, 8.25, NULL),
+ ('Academy Dinosaur', 83, 2010, 9.22, DATETIME(DATETIME('2007-02-11 12:00:00'), '2 YEARS'))
+ ) AS film_values ON (LOWER(film_values.column1) = LOWER(film.title))
+WHERE (
+ (CAST(film.release_year AS INTEGER) > film_values.column3)
+ AND (film.rental_rate < film_values.column4)
+ )
+ORDER BY film_values.column1;
+`)
+
+ var dest []struct {
+ Film model.Film
+
+ Column1 string
+ Column2 int
+ Column3 int
+ Column4 float32
+ Column5 *time.Time
+ }
+
+ err := stmt.Query(db, &dest)
+
+ require.NoError(t, err)
+ testutils.AssertJSON(t, dest, `
+[
+ {
+ "Film": {
+ "FilmID": 8,
+ "Title": "AIRPORT POLLOCK",
+ "Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India",
+ "ReleaseYear": "2006",
+ "LanguageID": 1,
+ "OriginalLanguageID": null,
+ "RentalDuration": 6,
+ "RentalRate": 4.99,
+ "Length": 54,
+ "ReplacementCost": 15.99,
+ "Rating": "R",
+ "SpecialFeatures": "Trailers",
+ "LastUpdate": "2019-04-11T18:11:48Z"
+ },
+ "Column1": "Airport Pollock",
+ "Column2": 54,
+ "Column3": 2001,
+ "Column4": 7.22,
+ "Column5": null
+ },
+ {
+ "Film": {
+ "FilmID": 98,
+ "Title": "BRIGHT ENCOUNTERS",
+ "Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat",
+ "ReleaseYear": "2006",
+ "LanguageID": 1,
+ "OriginalLanguageID": null,
+ "RentalDuration": 4,
+ "RentalRate": 4.99,
+ "Length": 73,
+ "ReplacementCost": 12.99,
+ "Rating": "PG-13",
+ "SpecialFeatures": "Trailers",
+ "LastUpdate": "2019-04-11T18:11:48Z"
+ },
+ "Column1": "Bright Encounters",
+ "Column2": 73,
+ "Column3": 2002,
+ "Column4": 8.25,
+ "Column5": null
+ },
+ {
+ "Film": {
+ "FilmID": 133,
+ "Title": "CHAMBER ITALIAN",
+ "Description": "A Fateful Reflection of a Moose And a Husband who must Overcome a Monkey in Nigeria",
+ "ReleaseYear": "2006",
+ "LanguageID": 1,
+ "OriginalLanguageID": null,
+ "RentalDuration": 7,
+ "RentalRate": 4.99,
+ "Length": 117,
+ "ReplacementCost": 14.99,
+ "Rating": "NC-17",
+ "SpecialFeatures": "Trailers",
+ "LastUpdate": "2019-04-11T18:11:48Z"
+ },
+ "Column1": "Chamber Italian",
+ "Column2": 117,
+ "Column3": 2005,
+ "Column4": 5.82,
+ "Column5": "2007-02-11T12:00:00Z"
+ },
+ {
+ "Film": {
+ "FilmID": 384,
+ "Title": "GROSSE WONDERFUL",
+ "Description": "A Epic Drama of a Cat And a Explorer who must Redeem a Moose in Australia",
+ "ReleaseYear": "2006",
+ "LanguageID": 1,
+ "OriginalLanguageID": null,
+ "RentalDuration": 5,
+ "RentalRate": 4.99,
+ "Length": 49,
+ "ReplacementCost": 19.99,
+ "Rating": "R",
+ "SpecialFeatures": "Behind the Scenes",
+ "LastUpdate": "2019-04-11T18:11:48Z"
+ },
+ "Column1": "Grosse Wonderful",
+ "Column2": 49,
+ "Column3": 2004,
+ "Column4": 6.242,
+ "Column5": "2007-02-11T12:00:00Z"
+ }
+]
+`)
+}
+
+func TestVALUES_CTE_Update(t *testing.T) {
+
+ paymentID := IntegerColumn("payment_ID")
+ increase := FloatColumn("increase")
+ paymentsToUpdate := CTE("values_cte", paymentID, increase)
+
+ stmt := WITH(
+ paymentsToUpdate.AS(
+ VALUES(
+ ROW(Int32(204), Float(1.21)),
+ ROW(Int32(207), Float(1.02)),
+ ROW(Int32(200), Float(1.34)),
+ ROW(Int32(203), Float(1.72)),
+ ),
+ ),
+ )(
+ Payment.UPDATE().
+ SET(
+ Payment.Amount.SET(Payment.Amount.MUL(increase)),
+ ).
+ FROM(paymentsToUpdate).
+ WHERE(Payment.PaymentID.EQ(paymentID)).
+ RETURNING(Payment.AllColumns),
+ )
+
+ testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(`
+WITH values_cte (''payment_ID'', increase) AS (
+ VALUES (?, ?),
+ (?, ?),
+ (?, ?),
+ (?, ?)
+)
+UPDATE payment
+SET amount = (payment.amount * values_cte.increase)
+FROM values_cte
+WHERE payment.payment_id = values_cte.''payment_ID''
+RETURNING payment.payment_id AS "payment.payment_id",
+ payment.customer_id AS "payment.customer_id",
+ payment.staff_id AS "payment.staff_id",
+ payment.rental_id AS "payment.rental_id",
+ payment.amount AS "payment.amount",
+ payment.payment_date AS "payment.payment_date",
+ payment.last_update AS "payment.last_update";
+`, "''", "`"))
+
+ testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
+ var payments []model.Payment
+
+ err := stmt.Query(tx, &payments)
+
+ require.NoError(t, err)
+ testutils.AssertJSON(t, payments, `
+[
+ {
+ "PaymentID": 200,
+ "CustomerID": 7,
+ "StaffID": 2,
+ "RentalID": 11542,
+ "Amount": 10.706600000000002,
+ "PaymentDate": "2005-08-17T00:51:32Z",
+ "LastUpdate": "2019-04-11T18:11:50Z"
+ },
+ {
+ "PaymentID": 203,
+ "CustomerID": 7,
+ "StaffID": 2,
+ "RentalID": 13373,
+ "Amount": 5.1428,
+ "PaymentDate": "2005-08-19T21:23:31Z",
+ "LastUpdate": "2019-04-11T18:11:50Z"
+ },
+ {
+ "PaymentID": 204,
+ "CustomerID": 7,
+ "StaffID": 1,
+ "RentalID": 13476,
+ "Amount": 3.6179,
+ "PaymentDate": "2005-08-20T01:06:04Z",
+ "LastUpdate": "2019-04-11T18:11:50Z"
+ },
+ {
+ "PaymentID": 207,
+ "CustomerID": 8,
+ "StaffID": 2,
+ "RentalID": 866,
+ "Amount": 7.1298,
+ "PaymentDate": "2005-05-30T03:43:54Z",
+ "LastUpdate": "2019-04-11T18:11:50Z"
+ }
+]
+`)
+ })
+
+}
diff --git a/tests/sqlite/with_test.go b/tests/sqlite/with_test.go
index 402df2f..485655c 100644
--- a/tests/sqlite/with_test.go
+++ b/tests/sqlite/with_test.go
@@ -153,10 +153,10 @@ WITH payments_to_update AS (
)
UPDATE payment
SET amount = 0
-WHERE payment.payment_id IN (
+WHERE payment.payment_id IN ((
SELECT payments_to_update.''payment.payment_id'' AS "payment.payment_id"
FROM payments_to_update
- );
+ ));
`, "''", "`", -1))
tx := beginDBTx(t)
@@ -205,10 +205,10 @@ WITH payments_to_delete AS (
WHERE payment.amount < 0.5
)
DELETE FROM payment
-WHERE payment.payment_id IN (
+WHERE payment.payment_id IN ((
SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id"
FROM payments_to_delete
- );
+ ));
`, "''", "`", -1))
tx := beginDBTx(t)
diff --git a/tests/testdata b/tests/testdata
index 915bdc1..6a39774 160000
--- a/tests/testdata
+++ b/tests/testdata
@@ -1 +1 @@
-Subproject commit 915bdc16b723d89becc577c780949baef861a6ae
+Subproject commit 6a397747d310938b41d3950d68009578180d3dd5