Merge remote-tracking branch 'upstream/master' into stmt-cache2

# Conflicts:
#	tests/postgres/alltypes_test.go
#	tests/postgres/northwind_test.go
#	tests/postgres/sample_test.go
#	tests/postgres/update_test.go
#	tests/sqlite/insert_test.go
#	tests/sqlite/main_test.go
#	tests/sqlite/sample_test.go
#	tests/sqlite/update_test.go
This commit is contained in:
go-jet 2024-10-19 14:01:55 +02:00
commit 4bb9775134
97 changed files with 2306 additions and 537 deletions

View file

@ -8,7 +8,7 @@ jobs:
build_and_tests:
docker:
# specify the version
- image: cimg/go:1.21.6
- image: cimg/go:1.22.8
- image: cimg/postgres:14.10
environment:
POSTGRES_USER: jet
@ -128,17 +128,13 @@ jobs:
# to create test results report
- run:
name: Install go-junit-report
command: go install github.com/jstemmer/go-junit-report@latest
name: Install gotestsum
command: go install gotest.tools/gotestsum@latest
- run: mkdir -p $TEST_RESULTS
# this will run all tests and exclude test files from code coverage report
- run: |
go test -v ./... \
-covermode=atomic \
-coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... \
-coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml
- run:
name: Running tests
command: gotestsum --junitfile $TEST_RESULTS/report.xml --format testname -- -coverprofile=cover.out -covermode=atomic -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... ./...
# run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests
- run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/

45
.github/workflows/code_scanner.yml vendored Normal file
View file

@ -0,0 +1,45 @@
name: Code Scanners
on:
push:
branches:
- master
pull_request:
branches:
- master
permissions:
contents: read
env:
go_version: "1.22.8"
jobs:
security_scanning:
runs-on: ubuntu-latest
steps:
- name: Checkout Source
uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: ${{ env.go_version }}
cache: true
- name: Setup Tools
run: |
go install github.com/securego/gosec/v2/cmd/gosec@latest
- name: Running Scan
run: gosec --exclude=G402,G304 ./...
lint_scanner:
runs-on: ubuntu-latest
steps:
- name: Checkout Source
uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: ${{ env.go_version }}
cache: true
- name: Setup Tools
run: |
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
- name: Running Scan
run: golangci-lint run --timeout=30m ./...

19
.golangci.yml Normal file
View file

@ -0,0 +1,19 @@
run:
# The default concurrency value is the number of available CPU.
concurrency: 4
# Timeout for analysis, e.g. 30s, 5m.
# Default: 1m
timeout: 30m
# Exit code when at least one issue was found.
# Default: 1
issues-exit-code: 2
# Include test files or not.
# Default: true
tests: false
issues:
exclude-dirs:
- tests
exclude-files:
- "_test.go"
- "testutils.go"

View file

@ -1,3 +1,3 @@
package main
const version = "v2.11.0"
const version = "v2.11.1"

View file

@ -4,7 +4,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
"io/ioutil"
"os"
_ "github.com/lib/pq"
@ -90,7 +90,7 @@ func main() {
func jsonSave(path string, v interface{}) {
jsonText, _ := json.MarshalIndent(v, "", "\t")
err := ioutil.WriteFile(path, jsonText, 0644)
err := os.WriteFile(path, jsonText, 0600)
panicOnError(err)
}

View file

@ -1,29 +1,16 @@
package metadata
import (
"regexp"
)
// Column struct
type Column struct {
Name string `sql:"primary_key"`
IsPrimaryKey bool
IsNullable bool
IsGenerated bool
HasDefault bool
DataType DataType
Comment string
}
// GoLangComment returns column comment without ascii control characters
func (c Column) GoLangComment() string {
if c.Comment == "" {
return ""
}
// remove ascii control characters from string
return regexp.MustCompile(`[[:cntrl:]]+`).ReplaceAllString(c.Comment, "")
}
// DataTypeKind is database type kind(base, enum, user-defined, array)
type DataTypeKind string

View file

@ -3,5 +3,6 @@ package metadata
// Enum metadata struct
type Enum struct {
Name string `sql:"primary_key"`
Comment string
Values []string
}

View file

@ -3,6 +3,7 @@ package metadata
// Table metadata struct
type Table struct {
Name string `sql:"primary_key"`
Comment string
Columns []Column
}

View file

@ -18,6 +18,7 @@ func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTyp
SELECT
t.table_name as "table.name",
col.COLUMN_NAME AS "column.Name",
col.COLUMN_DEFAULT IS NOT NULL as "column.HasDefault",
col.IS_NULLABLE = "YES" AS "column.IsNullable",
col.COLUMN_COMMENT AS "column.Comment",
COALESCE(pk.IsPrimaryKey, 0) AS "column.IsPrimaryKey",

View file

@ -14,7 +14,7 @@ type postgresQuerySet struct{}
func (p postgresQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) ([]metadata.Table, error) {
query := `
SELECT table_name as "table.name"
SELECT table_name as "table.name", obj_description((quote_ident(table_schema)||'.'||quote_ident(table_name))::regclass) as "table.comment"
FROM information_schema.tables
WHERE table_schema = $1 and table_type = $2
ORDER BY table_name;
@ -57,6 +57,7 @@ func getColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]meta
query := `
select
attr.attname as "column.Name",
col_description(attr.attrelid, attr.attnum) as "column.Comment",
exists(
select 1
from pg_catalog.pg_index indx
@ -64,11 +65,13 @@ select
) as "column.IsPrimaryKey",
not attr.attnotnull as "column.isNullable",
attr.attgenerated = 's' as "column.isGenerated",
(case tp.typtype
when 'b' then 'base'
when 'd' then 'base'
when 'e' then 'enum'
when 'r' then 'range'
attr.atthasdef as "column.hasDefault",
(case
when tp.typtype = 'b' AND tp.typcategory <> 'A' then 'base'
when tp.typtype = 'b' AND tp.typcategory = 'A' then 'array'
when tp.typtype = 'd' then 'base'
when tp.typtype = 'e' then 'enum'
when tp.typtype = 'r' then 'range'
end) as "dataType.Kind",
(case when tp.typtype = 'd' then (select pg_type.typname from pg_catalog.pg_type where pg_type.oid = tp.typbasetype)
when tp.typcategory = 'A' then pg_catalog.format_type(attr.atttypid, attr.atttypmod)
@ -99,6 +102,7 @@ order by
func (p postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.Enum, error) {
query := `
SELECT t.typname as "enum.name",
obj_description(t.oid) as "enum.comment",
e.enumlabel as "values"
FROM pg_catalog.pg_type t
JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid

View file

@ -4,10 +4,11 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils/semantic"
"github.com/go-jet/jet/v2/qrm"
"strings"
)
// sqliteQuerySet is dialect query set for SQLite
@ -77,6 +78,7 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t
Name string
Type string
NotNull int32
DfltValue string
Pk int32
Hidden int32
}
@ -91,12 +93,14 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t
for _, columnInfo := range columnInfos {
columnType := strings.TrimSuffix(getColumnType(columnInfo.Type), " GENERATED ALWAYS")
isGenerated := columnInfo.Hidden == 2 || columnInfo.Hidden == 3 // stored or virtual column
hasDefault := columnInfo.DfltValue != ""
columns = append(columns, metadata.Column{
Name: columnInfo.Name,
IsPrimaryKey: columnInfo.Pk != 0,
IsNullable: columnInfo.NotNull != 1,
IsGenerated: isGenerated,
HasDefault: hasDefault,
DataType: metadata.DataType{
Name: columnType,
Kind: metadata.BaseType,

View file

@ -26,13 +26,14 @@ import (
var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "{{tableTemplate.DefaultAlias}}")
{{golangComment .Comment}}
type {{structImplName}} struct {
{{dialect.PackageName}}.Table
// Columns
{{- range $i, $c := .Columns}}
{{- $field := columnField $c}}
{{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}} {{- if $c.Comment }} // {{$c.GoLangComment}} {{end}}
{{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}} {{golangComment .Comment}}
{{- end}}
AllColumns {{dialect.PackageName}}.ColumnList
@ -119,10 +120,11 @@ import (
{{end}}
{{$modelTableTemplate := tableTemplate}}
{{golangComment .Comment}}
type {{$modelTableTemplate.TypeName}} struct {
{{- range .Columns}}
{{- $field := structField .}}
{{$field.Name}} {{$field.Type.Name}} ` + "{{$field.TagsString}}" + ` {{- if .Comment }} // {{.GoLangComment}} {{end}}
{{$field.Name}} {{$field.Type.Name}} ` + "{{$field.TagsString}}" + ` {{golangComment .Comment}}
{{- end}}
}
@ -132,6 +134,7 @@ var enumSQLBuilderTemplate = `package {{package}}
import "github.com/go-jet/jet/v2/{{dialect.PackageName}}"
{{golangComment .Comment}}
var {{enumTemplate.InstanceName}} = &struct {
{{- range $index, $value := .Values}}
{{enumValueName $value}} {{dialect.PackageName}}.StringExpression
@ -148,6 +151,7 @@ var enumModelTemplate = `package {{package}}
import "errors"
{{golangComment .Comment}}
type {{$enumTemplate.TypeName}} string
const (
@ -156,6 +160,12 @@ const (
{{- end}}
)
var {{$enumTemplate.TypeName}}AllValues = []{{$enumTemplate.TypeName}} {
{{- range $_, $value := .Values}}
{{valueName $value}},
{{- end}}
}
func (e *{{$enumTemplate.TypeName}}) Scan(value interface{}) error {
var enumValue string
switch val := value.(type) {

View file

@ -0,0 +1,13 @@
package template
import "regexp"
// Returns the provided string as golang comment without ascii control characters
func formatGolangComment(comment string) string {
if len(comment) == 0 {
return ""
}
// Format as colang comment and remove ascii control characters from string
return "// " + regexp.MustCompile(`[[:cntrl:]]+`).ReplaceAllString(comment, "")
}

View file

@ -0,0 +1,27 @@
package template
import "testing"
func Test_formatGolangComment(t *testing.T) {
type args struct {
comment string
}
tests := []struct {
name string
args args
want string
}{
{name: "Empty string", args: args{comment: ""}, want: ""},
{name: "Non-empty string", args: args{comment: "This is a comment"}, want: "// This is a comment"},
{name: "String with control characters", args: args{comment: "This is a comment with control characters \x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f and text after"}, want: "// This is a comment with control characters and text after"},
{name: "String with escape characters", args: args{comment: "This is a comment with escape characters \n\r\t and text after"}, want: "// This is a comment with escape characters and text after"},
{name: "String with unicode characters", args: args{comment: "This is a comment with unicode characters ₲鬼佬℧⇄↻ and text after"}, want: "// This is a comment with unicode characters ₲鬼佬℧⇄↻ and text after"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := formatGolangComment(tt.args.comment); got != tt.want {
t.Errorf("formatGoLangComment() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -140,6 +140,7 @@ func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []
"enumValueName": func(enumValue string) string {
return enumTemplate.ValueName(enumValue)
},
"golangComment": formatGolangComment,
})
if err != nil {
return fmt.Errorf("failed to generete enum type %s: %w", enumTemplate.FileName, err)
@ -215,6 +216,7 @@ func processTableSQLBuilder(fileTypes, dirPath string,
"insertedRowAlias": func() string {
return insertedRowAlias(dialect)
},
"golangComment": formatGolangComment,
})
if err != nil {
return fmt.Errorf("failed to generate table sql builder type %s: %w", tableSQLBuilder.TypeName, err)
@ -307,6 +309,7 @@ func processTableModels(fileTypes, modelDirPath string, tablesMetaData []metadat
"structField": func(columnMetaData metadata.Column) TableModelField {
return tableTemplate.Field(columnMetaData)
},
"golangComment": formatGolangComment,
})
if err != nil {
return fmt.Errorf("failed to generate model type '%s': %w", tableMetaData.Name, err)
@ -347,6 +350,7 @@ func processEnumModels(modelDir string, enumsMetaData []metadata.Enum, modelTemp
"valueName": func(value string) string {
return enumTemplate.ValueName(value)
},
"golangComment": formatGolangComment,
})
if err != nil {

6
go.mod
View file

@ -4,16 +4,16 @@ go 1.18
require (
github.com/go-sql-driver/mysql v1.8.1
github.com/google/go-cmp v0.6.0
github.com/google/uuid v1.6.0
github.com/jackc/pgconn v1.14.3
github.com/jackc/pgtype v1.14.3
github.com/jackc/pgx/v4 v4.18.3
github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.17
)
require (
github.com/google/go-cmp v0.6.0
github.com/jackc/pgtype v1.14.3
github.com/jackc/pgx/v4 v4.18.3
github.com/pkg/profile v1.7.0
github.com/shopspring/decimal v1.4.0
github.com/stretchr/testify v1.9.0

View file

@ -11,6 +11,13 @@ type columnAssigmentImpl struct {
expression Expression
}
func NewColumnAssignment(serializer ColumnSerializer, expression Expression) ColumnAssigment {
return &columnAssigmentImpl{
column: serializer,
expression: expression,
}
}
func (a columnAssigmentImpl) isColumnAssigment() {}
func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {

View file

@ -13,6 +13,7 @@ type Dialect interface {
ArgumentPlaceholder() QueryPlaceholderFunc
IsReservedWord(name string) bool
SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
ValuesDefaultColumnName(index int) string
}
// SerializerFunc func
@ -35,6 +36,7 @@ type DialectParams struct {
ArgumentPlaceholder QueryPlaceholderFunc
ReservedWords []string
SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
ValuesDefaultColumnName func(index int) string
}
// NewDialect creates new dialect with params
@ -49,6 +51,7 @@ func NewDialect(params DialectParams) Dialect {
argumentPlaceholder: params.ArgumentPlaceholder,
reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords),
serializeOrderBy: params.SerializeOrderBy,
valuesDefaultColumnName: params.ValuesDefaultColumnName,
}
}
@ -62,6 +65,7 @@ type dialectImpl struct {
argumentPlaceholder QueryPlaceholderFunc
reservedWords map[string]bool
serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
valuesDefaultColumnName func(index int) string
}
func (d *dialectImpl) Name() string {
@ -107,6 +111,10 @@ func (d *dialectImpl) SerializeOrderBy() func(expression Expression, ascending,
return d.serializeOrderBy
}
func (d *dialectImpl) ValuesDefaultColumnName(index int) string {
return d.valuesDefaultColumnName(index)
}
func arrayOfStringsToMapOfStrings(arr []string) map[string]bool {
ret := map[string]bool{}
for _, elem := range arr {

View file

@ -51,12 +51,12 @@ func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression {
// IN checks if this expressions matches any in expressions list
func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "IN")
return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "IN")
}
// NOT_IN checks if this expressions is different of all expressions in expressions list
func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "NOT IN")
return newBinaryBoolOperatorExpression(e.Parent, wrap(expressions...), "NOT IN")
}
// AS the temporary alias name to assign to the expression
@ -316,15 +316,6 @@ func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder,
}
}
type skipParenthesisWrap struct {
Expression
}
func skipWrap(expression Expression) Expression {
return &skipParenthesisWrap{expression}
}
// since the expression is a function parameter, there is no need to wrap it in parentheses
func (s *skipParenthesisWrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
s.Expression.serialize(statement, out, append(options, NoWrap)...)
func wrap(expressions ...Expression) Expression {
return NewFunc("", expressions, nil)
}

View file

@ -12,11 +12,6 @@ func OR(expressions ...BoolExpression) BoolExpression {
return newBoolExpressionListOperator("OR", expressions...)
}
// ROW function is used to create a tuple value that consists of a set of expressions or column values.
func ROW(expressions ...Expression) Expression {
return NewFunc("ROW", expressions, nil)
}
// ------------------ Mathematical functions ---------------//
// ABSf calculates absolute value from float expression
@ -711,7 +706,7 @@ func (p parametersSerializer) serialize(statement StatementType, out *SQLBuilder
if _, isStatement := expression.(Statement); isStatement {
expression.serialize(statement, out, options...)
} else {
skipWrap(expression).serialize(statement, out, options...)
expression.serialize(statement, out, append(options, NoWrap, Ident)...)
}
}
}

View file

@ -374,32 +374,6 @@ func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, option
//---------------------------------------------------//
type wrap struct {
ExpressionInterfaceImpl
expressions []Expression
}
func (n *wrap) serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")
if len(n.expressions) == 1 {
options = append(options, NoWrap, Ident)
}
serializeExpressionList(statementType, n.expressions, ", ", out, options...)
out.WriteString(")")
}
// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... )
func WRAP(expression ...Expression) Expression {
wrap := &wrap{expressions: expression}
wrap.ExpressionInterfaceImpl.Parent = wrap
return wrap
}
//---------------------------------------------------//
type rawExpression struct {
ExpressionInterfaceImpl

View file

@ -54,7 +54,12 @@ type orderSetAggregateFuncExpression struct {
func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString(p.name)
WRAP(p.fraction).serialize(statement, out, FallTrough(options)...)
if p.fraction != nil {
wrap(p.fraction).serialize(statement, out, FallTrough(options)...)
} else {
wrap().serialize(statement, out, FallTrough(options)...)
}
out.WriteString("WITHIN GROUP")
p.orderBy.serialize(statement, out)
}

View file

@ -39,7 +39,7 @@ AVG(table1.col_int) AS "avg",
table2.col3 AS "col3",
table2.col4 AS "col4"`)
subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery"))
subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery", nil))
assertProjectionSerialize(t, subQueryProjections,
`"subQuery"."table1.col3" AS "table1.col3",

View file

@ -8,7 +8,7 @@ type rawStatementImpl struct {
}
// RawStatement creates new sql statements from raw query and optional map of named arguments
func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) Statement {
func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) SerializerStatement {
newRawStatement := rawStatementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
dialect: dialect,

View file

@ -0,0 +1,102 @@
package jet
// RowExpression interface
type RowExpression interface {
Expression
HasProjections
EQ(rhs RowExpression) BoolExpression
NOT_EQ(rhs RowExpression) BoolExpression
IS_DISTINCT_FROM(rhs RowExpression) BoolExpression
IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression
LT(rhs RowExpression) BoolExpression
LT_EQ(rhs RowExpression) BoolExpression
GT(rhs RowExpression) BoolExpression
GT_EQ(rhs RowExpression) BoolExpression
}
type rowInterfaceImpl struct {
parent Expression
dialect Dialect
elemCount int
}
func (n *rowInterfaceImpl) EQ(rhs RowExpression) BoolExpression {
return Eq(n.parent, rhs)
}
func (n *rowInterfaceImpl) NOT_EQ(rhs RowExpression) BoolExpression {
return NotEq(n.parent, rhs)
}
func (n *rowInterfaceImpl) IS_DISTINCT_FROM(rhs RowExpression) BoolExpression {
return IsDistinctFrom(n.parent, rhs)
}
func (n *rowInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs RowExpression) BoolExpression {
return IsNotDistinctFrom(n.parent, rhs)
}
func (n *rowInterfaceImpl) GT(rhs RowExpression) BoolExpression {
return Gt(n.parent, rhs)
}
func (n *rowInterfaceImpl) GT_EQ(rhs RowExpression) BoolExpression {
return GtEq(n.parent, rhs)
}
func (n *rowInterfaceImpl) LT(rhs RowExpression) BoolExpression {
return Lt(n.parent, rhs)
}
func (n *rowInterfaceImpl) LT_EQ(rhs RowExpression) BoolExpression {
return LtEq(n.parent, rhs)
}
func (n *rowInterfaceImpl) projections() ProjectionList {
var ret ProjectionList
for i := 0; i < n.elemCount; i++ {
rowColumn := NewColumnImpl(n.dialect.ValuesDefaultColumnName(i), "", nil)
ret = append(ret, &rowColumn)
}
return ret
}
// ---------------------------------------------------//
type rowExpressionWrapper struct {
rowInterfaceImpl
Expression
}
func newRowExpression(name string, dialect Dialect, expressions ...Expression) RowExpression {
ret := &rowExpressionWrapper{}
ret.rowInterfaceImpl.parent = ret
ret.Expression = NewFunc(name, expressions, ret)
ret.dialect = dialect
ret.elemCount = len(expressions)
return ret
}
// ROW function is used to create a tuple value that consists of a set of expressions or column values.
func ROW(dialect Dialect, expressions ...Expression) RowExpression {
return newRowExpression("ROW", dialect, expressions...)
}
// WRAP creates row expressions without ROW keyword `( expression1, expression2, ... )`.
func WRAP(dialect Dialect, expressions ...Expression) RowExpression {
return newRowExpression("", dialect, expressions...)
}
// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression.
// This enables the Go compiler to interpret any expression as a row expression
// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation.
func RowExp(expression Expression) RowExpression {
rowExpressionWrap := rowExpressionWrapper{Expression: expression}
rowExpressionWrap.rowInterfaceImpl.parent = &rowExpressionWrap
return &rowExpressionWrap
}

View file

@ -10,13 +10,19 @@ type SelectTable interface {
type selectTableImpl struct {
Statement SerializerHasProjections
alias string
columnAliases []ColumnExpression
}
// NewSelectTable func
func NewSelectTable(selectStmt SerializerHasProjections, alias string) selectTableImpl {
func NewSelectTable(selectStmt SerializerHasProjections, alias string, columnAliases []ColumnExpression) selectTableImpl {
selectTable := selectTableImpl{
Statement: selectStmt,
alias: alias,
columnAliases: columnAliases,
}
for _, column := range selectTable.columnAliases {
column.setSubQuery(selectTable)
}
return selectTable
@ -31,6 +37,10 @@ func (s selectTableImpl) Alias() string {
}
func (s selectTableImpl) AllColumns() ProjectionList {
if len(s.columnAliases) > 0 {
return ColumnListToProjectionList(s.columnAliases)
}
projectionList := s.projections().fromImpl(s)
return projectionList.(ProjectionList)
}
@ -40,6 +50,12 @@ func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, opt
out.WriteString("AS")
out.WriteIdentifier(s.alias)
if len(s.columnAliases) > 0 {
out.WriteByte('(')
SerializeColumnExpressionNames(s.columnAliases, out)
out.WriteByte(')')
}
}
// --------------------------------------
@ -50,7 +66,7 @@ type lateralImpl struct {
// NewLateral creates new lateral expression from select statement with alias
func NewLateral(selectStmt SerializerStatement, alias string) SelectTable {
return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias)}
return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias, nil)}
}
func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {

View file

@ -180,7 +180,7 @@ func duration(f func()) time.Duration {
f()
return time.Now().Sub(start)
return time.Since(start)
}
// ExpressionStatement interfacess

35
internal/jet/values.go Normal file
View file

@ -0,0 +1,35 @@
package jet
// Values hold a set of one or more rows
type Values []RowExpression
func (v Values) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteByte('(')
out.IncreaseIdent(5)
out.NewLine()
out.WriteString("VALUES")
for rowIndex, row := range v {
if rowIndex > 0 {
out.WriteString(",")
out.NewLine()
} else {
out.IncreaseIdent(7)
}
row.serialize(statement, out, options...)
}
out.DecreaseIdent(7)
out.DecreaseIdent(5)
out.NewLine()
out.WriteByte(')')
}
func (v Values) projections() ProjectionList {
if len(v) == 0 {
return nil
}
return v[0].projections()
}

View file

@ -64,7 +64,7 @@ type CommonTableExpression struct {
// CTE creates new named CommonTableExpression
func CTE(name string, columns ...ColumnExpression) CommonTableExpression {
cte := CommonTableExpression{
selectTableImpl: NewSelectTable(nil, name),
selectTableImpl: NewSelectTable(nil, name, columns),
Columns: columns,
}
@ -99,12 +99,3 @@ func (c CommonTableExpression) serialize(statement StatementType, out *SQLBuilde
out.WriteIdentifier(c.alias)
}
}
// AllColumns returns list of all projections in the CTE
func (c CommonTableExpression) AllColumns() ProjectionList {
if len(c.Columns) > 0 {
return ColumnListToProjectionList(c.Columns)
}
return c.selectTableImpl.AllColumns()
}

View file

@ -9,17 +9,15 @@ import (
jet2 "github.com/go-jet/jet/v2/internal/jet/db"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
// UnixTimeComparer will compare time equality while ignoring time zone
@ -105,11 +103,12 @@ func AssertJSON(t *testing.T, data interface{}, expectedJSON string) {
}
// SaveJSONFile saves v as json at testRelativePath
// nolint:unused
func SaveJSONFile(v interface{}, testRelativePath string) {
jsonText, _ := json.MarshalIndent(v, "", "\t")
filePath := getFullPath(testRelativePath)
err := ioutil.WriteFile(filePath, jsonText, 0644)
err := os.WriteFile(filePath, jsonText, 0600)
throw.OnError(err)
}
@ -118,7 +117,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) {
func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) {
filePath := getFullPath(testRelativePath)
fileJSONData, err := ioutil.ReadFile(filePath)
fileJSONData, err := os.ReadFile(filePath) // #nosec G304
require.NoError(t, err)
if runtime.GOOS == "windows" {
@ -245,7 +244,7 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest inter
// AssertFileContent check if file content at filePath contains expectedContent text.
func AssertFileContent(t *testing.T, filePath string, expectedContent string) {
enumFileData, err := ioutil.ReadFile(filePath)
enumFileData, err := os.ReadFile(filePath) // #nosec G304
require.NoError(t, err)
@ -254,7 +253,7 @@ func AssertFileContent(t *testing.T, filePath string, expectedContent string) {
// AssertFileNamesEqual check if all filesInfos are contained in fileNames
func AssertFileNamesEqual(t *testing.T, dirPath string, fileNames ...string) {
files, err := ioutil.ReadDir(dirPath)
files, err := os.ReadDir(dirPath)
require.NoError(t, err)
require.Equal(t, len(files), len(fileNames))
@ -293,76 +292,6 @@ func printDiff(actual, expected interface{}, options ...cmp.Option) {
fmt.Println(expected)
}
// BoolPtr returns address of bool parameter
func BoolPtr(b bool) *bool {
return &b
}
// Int8Ptr returns address of int8 parameter
func Int8Ptr(i int8) *int8 {
return &i
}
// UInt8Ptr returns address of uint8 parameter
func UInt8Ptr(i uint8) *uint8 {
return &i
}
// Int16Ptr returns address of int16 parameter
func Int16Ptr(i int16) *int16 {
return &i
}
// UInt16Ptr returns address of uint16 parameter
func UInt16Ptr(i uint16) *uint16 {
return &i
}
// Int32Ptr returns address of int32 parameter
func Int32Ptr(i int32) *int32 {
return &i
}
// UInt32Ptr returns address of uint32 parameter
func UInt32Ptr(i uint32) *uint32 {
return &i
}
// Int64Ptr returns address of int64 parameter
func Int64Ptr(i int64) *int64 {
return &i
}
// UInt64Ptr returns address of uint64 parameter
func UInt64Ptr(i uint64) *uint64 {
return &i
}
// StringPtr returns address of string parameter
func StringPtr(s string) *string {
return &s
}
// TimePtr returns address of time.Time parameter
func TimePtr(t time.Time) *time.Time {
return &t
}
// ByteArrayPtr returns address of []byte parameter
func ByteArrayPtr(arr []byte) *[]byte {
return &arr
}
// Float32Ptr returns address of float32 parameter
func Float32Ptr(f float32) *float32 {
return &f
}
// Float64Ptr returns address of float64 parameter
func Float64Ptr(f float64) *float64 {
return &f
}
// UUIDPtr returns address of uuid.UUID
func UUIDPtr(u string) *uuid.UUID {
newUUID := uuid.MustParse(u)

View file

@ -1,6 +1,7 @@
package filesys
import (
"errors"
"fmt"
"go/format"
"os"
@ -16,7 +17,7 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error {
newGoFilePath += ".go"
}
file, err := os.Create(newGoFilePath)
file, err := os.Create(newGoFilePath) // #nosec 304
if err != nil {
return err
@ -28,7 +29,10 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error {
// if there is a format error we will write unformulated text for debug purposes
if err != nil {
file.Write(text)
_, writeErr := file.Write(text)
if writeErr != nil {
return errors.Join(writeErr, fmt.Errorf("failed to format '%s', check '%s' for syntax errors: %w", fileName, newGoFilePath, err))
}
return fmt.Errorf("failed to format '%s', check '%s' for syntax errors: %w", fileName, newGoFilePath, err)
}
@ -43,7 +47,7 @@ func FormatAndSaveGoFile(dirPath, fileName string, text []byte) error {
// EnsureDirPathExist ensures dir path exists. If path does not exist, creates new path.
func EnsureDirPathExist(dirPath string) error {
if _, err := os.Stat(dirPath); os.IsNotExist(err) {
err := os.MkdirAll(dirPath, os.ModePerm)
err := os.MkdirAll(dirPath, 0o750)
if err != nil {
return fmt.Errorf("can't create directory - %s: %w", dirPath, err)

View file

@ -0,0 +1,6 @@
package ptr
// Of returns the address of any given parameter
func Of[T any](value T) *T {
return &value
}

View file

@ -6,23 +6,27 @@ import (
)
type cast interface {
// Cast expressions as castType type
// AS casts expressions as castType type
AS(castType string) Expression
// Cast expression as char with optional length
// AS_CHAR casts expression as char with optional length
AS_CHAR(length ...int) StringExpression
// Cast expression AS date type
// AS_DATE casts expression AS date type
AS_DATE() DateExpression
// Cast expression AS numeric type, using precision and optionally scale
// AS_FLOAT casts expressions as float type
AS_FLOAT() FloatExpression
// AS_DOUBLE casts expressions as double type
AS_DOUBLE() FloatExpression
// AS_DECIMAL casts expression AS numeric type
AS_DECIMAL() FloatExpression
// Cast expression AS time type
// AS_TIME casts expression AS time type
AS_TIME() TimeExpression
// Cast expression as datetime type
// AS_DATETIME casts expression as datetime type
AS_DATETIME() DateTimeExpression
// Cast expressions as signed integer type
// AS_SIGNED casts expressions as signed integer type
AS_SIGNED() IntegerExpression
// Cast expression as unsigned integer type
// AS_UNSIGNED casts expression as unsigned integer type
AS_UNSIGNED() IntegerExpression
// Cast expression as binary type
// AS_BINARY casts expression as binary type
AS_BINARY() StringExpression
}
@ -73,6 +77,14 @@ func (c *castImpl) AS_DATE() DateExpression {
return DateExp(c.AS("DATE"))
}
func (c *castImpl) AS_FLOAT() FloatExpression {
return FloatExp(c.AS("FLOAT"))
}
func (c *castImpl) AS_DOUBLE() FloatExpression {
return FloatExp(c.AS("DOUBLE"))
}
// AS_DECIMAL casts expression AS DECIMAL type
func (c *castImpl) AS_DECIMAL() FloatExpression {
return FloatExp(c.AS("DECIMAL"))

View file

@ -1,6 +1,7 @@
package mysql
import (
"fmt"
"github.com/go-jet/jet/v2/internal/jet"
)
@ -28,6 +29,9 @@ func newDialect() jet.Dialect {
},
ReservedWords: reservedWords,
SerializeOrderBy: serializeOrderBy,
ValuesDefaultColumnName: func(index int) string {
return fmt.Sprintf("column_%d", index)
},
}
return jet.NewDialect(mySQLDialectParams)

View file

@ -30,6 +30,9 @@ type DateTimeExpression = jet.TimestampExpression
// TimestampExpression interface
type TimestampExpression = jet.TimestampExpression
// RowExpression interface
type RowExpression = jet.RowExpression
// BoolExp is bool expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as bool expression.
// Does not add sql cast to generated sql builder output.
@ -70,6 +73,11 @@ var DateTimeExp = jet.TimestampExp
// Does not add sql cast to generated sql builder output.
var TimestampExp = jet.TimestampExp
// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression.
// This enables the Go compiler to interpret any expression as a row expression
// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation.
var RowExp = jet.RowExp
// CustomExpression is used to define custom expressions.
var CustomExpression = jet.CustomExpression

View file

@ -12,7 +12,9 @@ var (
)
// ROW function is used to create a tuple value that consists of a set of expressions or column values.
var ROW = jet.ROW
func ROW(expressions ...Expression) RowExpression {
return jet.ROW(Dialect, expressions...)
}
// ------------------ Mathematical functions ---------------//

View file

@ -14,25 +14,25 @@ type unitType string
// List of interval unit types for MySQL
const (
MICROSECOND unitType = "MICROSECOND"
SECOND = "SECOND"
MINUTE = "MINUTE"
HOUR = "HOUR"
DAY = "DAY"
WEEK = "WEEK"
MONTH = "MONTH"
QUARTER = "QUARTER"
YEAR = "YEAR"
SECOND_MICROSECOND = "SECOND_MICROSECOND"
MINUTE_MICROSECOND = "MINUTE_MICROSECOND"
MINUTE_SECOND = "MINUTE_SECOND"
HOUR_MICROSECOND = "HOUR_MICROSECOND"
HOUR_SECOND = "HOUR_SECOND"
HOUR_MINUTE = "HOUR_MINUTE"
DAY_MICROSECOND = "DAY_MICROSECOND"
DAY_SECOND = "DAY_SECOND"
DAY_MINUTE = "DAY_MINUTE"
DAY_HOUR = "DAY_HOUR"
YEAR_MONTH = "YEAR_MONTH"
SECOND unitType = "SECOND"
MINUTE unitType = "MINUTE"
HOUR unitType = "HOUR"
DAY unitType = "DAY"
WEEK unitType = "WEEK"
MONTH unitType = "MONTH"
QUARTER unitType = "QUARTER"
YEAR unitType = "YEAR"
SECOND_MICROSECOND unitType = "SECOND_MICROSECOND"
MINUTE_MICROSECOND unitType = "MINUTE_MICROSECOND"
MINUTE_SECOND unitType = "MINUTE_SECOND"
HOUR_MICROSECOND unitType = "HOUR_MICROSECOND"
HOUR_SECOND unitType = "HOUR_SECOND"
HOUR_MINUTE unitType = "HOUR_MINUTE"
DAY_MICROSECOND unitType = "DAY_MICROSECOND"
DAY_SECOND unitType = "DAY_SECOND"
DAY_MINUTE unitType = "DAY_MINUTE"
DAY_HOUR unitType = "DAY_HOUR"
YEAR_MONTH unitType = "YEAR_MONTH"
)
// Interval is representation of MySQL interval

View file

@ -172,7 +172,7 @@ func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement {
}
func (s *selectStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s, alias)
return newSelectTable(s, alias, nil)
}
//-----------------------------------------------------

View file

@ -13,9 +13,9 @@ type selectTableImpl struct {
readableTableInterfaceImpl
}
func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable {
func newSelectTable(selectStmt jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable {
subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias),
SelectTable: jet.NewSelectTable(selectStmt, alias, columnAliases),
}
subQuery.readableTableInterfaceImpl.parent = subQuery

View file

@ -85,7 +85,7 @@ func (s *setStatementImpl) OFFSET(offset int64) setStatement {
}
func (s *setStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s, alias)
return newSelectTable(s, alias, nil)
}
const (

View file

@ -6,7 +6,7 @@ import (
)
// RawStatement creates new sql statements from raw query and optional map of named arguments
func RawStatement(rawQuery string, namedArguments ...RawArgs) Statement {
func RawStatement(rawQuery string, namedArguments ...RawArgs) jet.SerializerStatement {
return jet.RawStatement(Dialect, rawQuery, namedArguments...)
}

32
mysql/values.go Normal file
View file

@ -0,0 +1,32 @@
package mysql
import "github.com/go-jet/jet/v2/internal/jet"
type values struct {
jet.Values
}
// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table.
// Each row is defined by the ROW constructor, which takes one or more expressions.
//
// Example usage:
//
// VALUES(
// ROW(Int32(204), Float32(1.21)),
// ROW(Int32(207), Float32(1.02)),
// )
func VALUES(rows ...RowExpression) values {
return values{Values: jet.Values(rows)}
}
// AS assigns an alias to the temporary VALUES table, allowing it to be referenced
// within SQL FROM clauses, just like a regular table.
// By default, VALUES columns are named `column1`, `column2`, etc... Default column aliasing can be
// overwritten by passing new list of columns.
//
// Example usage:
//
// VALUES(...).AS("film_values", IntegerColumn("length"), TimestampColumn("update_date"))
func (v values) AS(alias string, columns ...Column) SelectTable {
return newSelectTable(v, alias, columns)
}

View file

@ -6,7 +6,7 @@ import "github.com/go-jet/jet/v2/internal/jet"
type CommonTableExpression interface {
SelectTable
AS(statement jet.SerializerStatement) CommonTableExpression
AS(statement jet.SerializerHasProjections) CommonTableExpression
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
ALIAS(alias string) SelectTable
@ -41,7 +41,7 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression {
}
// AS is used to define a CTE query
func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression {
func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression {
c.CommonTableExpression.Statement = statement
return c
}
@ -52,7 +52,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression {
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
func (c *commonTableExpression) ALIAS(name string) SelectTable {
return newSelectTable(c, name)
return newSelectTable(c, name, nil)
}
func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression {

View file

@ -109,13 +109,20 @@ type ColumnInterval interface {
jet.Column
From(subQuery SelectTable) ColumnInterval
SET(intervalExp IntervalExpression) ColumnAssigment
}
//------------------------------------------------------//
type intervalColumnImpl struct {
jet.ColumnExpressionImpl
intervalInterfaceImpl
}
func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment {
return jet.NewColumnAssignment(i, intervalExp)
}
func (i *intervalColumnImpl) From(subQuery SelectTable) ColumnInterval {
newIntervalColumn := IntervalColumn(i.Name())
jet.SetTableName(newIntervalColumn, i.TableName())

View file

@ -1,6 +1,7 @@
package postgres
import (
"fmt"
"github.com/go-jet/jet/v2/internal/jet"
"strconv"
)
@ -25,6 +26,9 @@ func newDialect() jet.Dialect {
return "$" + strconv.Itoa(ord)
},
ReservedWords: reservedWords,
ValuesDefaultColumnName: func(index int) string {
return fmt.Sprintf("column%d", index+1)
},
}
return jet.NewDialect(dialectParams)

View file

@ -46,33 +46,33 @@ func TestExists(t *testing.T) {
func TestIN(t *testing.T) {
assertSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)),
`($1 IN (
`($1 IN ((
SELECT table1.col1 AS "table1.col1"
FROM db.table1
))`, float64(1.11))
)))`, float64(1.11))
assertSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) IN (
`(ROW($1, table1.col1) IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
))`, int64(12))
)))`, int64(12))
}
func TestNOT_IN(t *testing.T) {
assertSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)),
`($1 NOT IN (
`($1 NOT IN ((
SELECT table1.col1 AS "table1.col1"
FROM db.table1
))`, float64(1.11))
)))`, float64(1.11))
assertSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) NOT IN (
`(ROW($1, table1.col1) NOT IN ((
SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1"
FROM db.table2
))`, int64(12))
)))`, int64(12))
}
func TestReservedWordEscaped(t *testing.T) {

View file

@ -36,6 +36,9 @@ type TimestampExpression = jet.TimestampExpression
// TimestampzExpression interface
type TimestampzExpression = jet.TimestampzExpression
// RowExpression interface
type RowExpression = jet.RowExpression
// DateRange Expression interface
type DateRange = jet.Range[DateExpression]
@ -99,6 +102,11 @@ var TimestampExp = jet.TimestampExp
// Does not add sql cast to generated sql builder output.
var TimestampzExp = jet.TimestampzExp
// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression.
// This enables the Go compiler to interpret any expression as a row expression
// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation.
var RowExp = jet.RowExp
// RangeExp is range expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as range expression.
// Does not add sql cast to generated sql builder output.

View file

@ -14,7 +14,9 @@ var (
)
// ROW function is used to create a tuple value that consists of a set of expressions or column values.
var ROW = jet.ROW
func ROW(expressions ...Expression) RowExpression {
return jet.ROW(Dialect, expressions...)
}
// ------------------ Mathematical functions ---------------//
@ -332,6 +334,16 @@ var LOCALTIMESTAMP = jet.LOCALTIMESTAMP
// NOW returns current date and time
var NOW = jet.NOW
// DATE_TRUNC returns the truncated date and time using optional time zone.
// Use TimestampzExp if you need timestamp with time zone and IntervalExp if you need interval.
func DATE_TRUNC(field unit, source Expression, timezone ...string) TimestampExpression {
if len(timezone) > 0 {
return jet.NewTimestampFunc("DATE_TRUNC", jet.FixedLiteral(unitToString(field)), source, jet.FixedLiteral(timezone[0]))
}
return jet.NewTimestampFunc("DATE_TRUNC", jet.FixedLiteral(unitToString(field)), source)
}
// --------------- Conditional Expressions Functions -------------//
// COALESCE function returns the first of its arguments that is not null.
@ -415,10 +427,12 @@ func castFloatLiteral(fraction FloatExpression) FloatExpression {
// ),
var GROUPING_SETS = jet.GROUPING_SETS
// WRAP wraps list of expressions with brackets - ( expression1, expression2, ... )
// The construct (a, b) is normally recognized in expressions as a row constructor. WRAP and ROW method behave exactly the same,
// except when used in GROUPING_SETS. For top level GROUPING SETS expression lists WRAP has to be used.
var WRAP = jet.WRAP
// WRAP surrounds a list of expressions or columns with parentheses, producing new row: (expression1, expression2, ...)
// The construct (a, b) is normally recognized in expressions as a row constructor. WRAP and ROW methods behave exactly the same,
// except when used in GROUPING_SETS and VALUES. In these contexts, WRAP must be used instead of ROW.
func WRAP(expressions ...Expression) RowExpression {
return jet.WRAP(Dialect, expressions...)
}
// ROLLUP operator is used with the GROUP BY clause to generate all prefixes of a group of columns including the empty list.
// It creates extra rows in the result set that represent the subtotal values for each combination of columns.

View file

@ -1,6 +1,8 @@
package postgres
import "testing"
import (
"testing"
)
func TestROW(t *testing.T) {
assertSerialize(t, ROW(SELECT(Int(1))), `ROW((
@ -10,3 +12,12 @@ func TestROW(t *testing.T) {
SELECT $2
), $3)`)
}
func TestDATE_TRUNC(t *testing.T) {
assertSerialize(t, DATE_TRUNC(YEAR, NOW()), "DATE_TRUNC('YEAR', NOW())")
assertSerialize(
t,
DATE_TRUNC(DAY, NOW().ADD(INTERVAL(1, HOUR)), "Australia/Sydney"),
"DATE_TRUNC('DAY', NOW() + INTERVAL '1 HOUR', 'Australia/Sydney')",
)
}

View file

@ -1,7 +1,6 @@
package postgres
import (
"github.com/go-jet/jet/v2/internal/jet"
"github.com/stretchr/testify/require"
"testing"
"time"
@ -151,7 +150,8 @@ func TestInsert_ON_CONFLICT(t *testing.T) {
VALUES("one", "two").
VALUES("1", "2").
VALUES("theta", "beta").
ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE(
ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).
DO_UPDATE(
SET(table1ColBool.SET(Bool(true)),
table2ColInt.SET(Int(1)),
ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))),
@ -178,12 +178,12 @@ func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColBool).
VALUES("one", "two").
VALUES("1", "2").
ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE(
ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").
DO_UPDATE(
SET(table1ColBool.SET(Bool(false)),
table2ColInt.SET(Int(1)),
ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))),
).WHERE(table1Col1.GT(Int(2))),
).
ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))),
).WHERE(table1Col1.GT(Int(2)))).
RETURNING(table1Col1, table1ColBool)
assertDebugStatementSql(t, stmt, `

View file

@ -57,6 +57,16 @@ func Uint64(value uint64) IntegerExpression {
// Float creates new float literal expression
var Float = jet.Float
// Float32 is constructor for 32 bit float literals
func Float32(value float32) FloatExpression {
return CAST(jet.Literal(value)).AS_REAL()
}
// Float64 is constructor for 64 bit float literals
func Float64(value float64) FloatExpression {
return CAST(jet.Literal(value)).AS_DOUBLE()
}
// Decimal creates new float literal expression
var Decimal = jet.Decimal

View file

@ -181,7 +181,7 @@ func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement {
}
func (s *selectStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s, alias)
return newSelectTable(s, alias, nil)
}
//-----------------------------------------------------

View file

@ -2,7 +2,7 @@ package postgres
import "github.com/go-jet/jet/v2/internal/jet"
// SelectTable is interface for postgres sub-queries
// SelectTable is interface for postgres temporary tables like sub-queries, VALUES, CTEs etc...
type SelectTable interface {
readableTable
jet.SelectTable
@ -13,9 +13,9 @@ type selectTableImpl struct {
readableTableInterfaceImpl
}
func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable {
func newSelectTable(serializerWithProjections jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable {
subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias),
SelectTable: jet.NewSelectTable(serializerWithProjections, alias, columnAliases),
}
subQuery.readableTableInterfaceImpl.parent = subQuery

View file

@ -136,7 +136,7 @@ func (s *setStatementImpl) OFFSET_e(offset IntegerExpression) setStatement {
}
func (s *setStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s, alias)
return newSelectTable(s, alias, nil)
}
const (

32
postgres/values.go Normal file
View file

@ -0,0 +1,32 @@
package postgres
import "github.com/go-jet/jet/v2/internal/jet"
type values struct {
jet.Values
}
// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table.
// Each row is defined by the WRAP constructor, which takes one or more expressions.
//
// Example usage:
//
// VALUES(
// WRAP(Int32(204), Float32(1.21)),
// WRAP(Int32(207), Float32(1.02)),
// )
func VALUES(rows ...RowExpression) values {
return values{Values: jet.Values(rows)}
}
// AS assigns an alias to the temporary VALUES table, allowing it to be referenced
// within SQL FROM clauses, just like a regular table.
// By default, VALUES columns are named `column1`, `column2`, etc... Default column aliasing can be
// overwritten by passing new list of columns.
//
// Example usage:
//
// VALUES(...).AS("film_values", IntegerColumn("length"), TimestampColumn("update_date"))
func (v values) AS(alias string, columns ...Column) SelectTable {
return newSelectTable(v, alias, columns)
}

View file

@ -6,7 +6,7 @@ import "github.com/go-jet/jet/v2/internal/jet"
type CommonTableExpression interface {
SelectTable
AS(statement jet.SerializerStatement) CommonTableExpression
AS(statement jet.SerializerHasProjections) CommonTableExpression
AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
ALIAS(alias string) SelectTable
@ -42,7 +42,7 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression {
}
// AS is used to define a CTE query
func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression {
func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression {
c.CommonTableExpression.Statement = statement
return c
}
@ -60,7 +60,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression {
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
func (c *commonTableExpression) ALIAS(name string) SelectTable {
return newSelectTable(c, name)
return newSelectTable(c, name, nil)
}
func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression {

View file

@ -10,6 +10,10 @@ import (
"time"
)
var (
castOverFlowError = fmt.Errorf("cannot cast a negative value to an unsigned value, buffer overflow error")
)
// NullBool struct
type NullBool struct {
sql.NullBool
@ -119,32 +123,47 @@ func (n *NullUInt64) Scan(value interface{}) error {
n.Valid = false
return nil
case int64:
if v < 0 {
return castOverFlowError
}
n.UInt64, n.Valid = uint64(v), true
return nil
case int32:
if v < 0 {
return castOverFlowError
}
n.UInt64, n.Valid = uint64(v), true
return nil
case int16:
if v < 0 {
return castOverFlowError
}
n.UInt64, n.Valid = uint64(v), true
return nil
case int8:
if v < 0 {
return castOverFlowError
}
n.UInt64, n.Valid = uint64(v), true
return nil
case int:
if v < 0 {
return castOverFlowError
}
n.UInt64, n.Valid = uint64(v), true
return nil
case uint64:
n.UInt64, n.Valid = v, true
return nil
case int32:
n.UInt64, n.Valid = uint64(v), true
return nil
case uint32:
n.UInt64, n.Valid = uint64(v), true
return nil
case int16:
n.UInt64, n.Valid = uint64(v), true
return nil
case uint16:
n.UInt64, n.Valid = uint64(v), true
return nil
case int8:
n.UInt64, n.Valid = uint64(v), true
return nil
case uint8:
n.UInt64, n.Valid = uint64(v), true
return nil
case int:
n.UInt64, n.Valid = uint64(v), true
return nil
case uint:
n.UInt64, n.Valid = uint64(v), true
return nil

View file

@ -2,6 +2,7 @@ package internal
import (
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"time"
@ -62,11 +63,21 @@ func TestNullUInt64(t *testing.T) {
value, _ := nullUInt64.Value()
require.Equal(t, value, uint64(11))
require.NoError(t, nullUInt64.Scan(uint64(11)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(11))
require.NoError(t, nullUInt64.Scan(int32(32)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(32))
require.NoError(t, nullUInt64.Scan(uint32(32)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(32))
require.NoError(t, nullUInt64.Scan(int16(20)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
@ -88,4 +99,29 @@ func TestNullUInt64(t *testing.T) {
require.Equal(t, value, uint64(30))
require.Error(t, nullUInt64.Scan("text"), "can't scan int32 from text")
//Validate negative use cases
err := nullUInt64.Scan(int64(-5))
assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError)
//Validate negative use cases
err = nullUInt64.Scan(-5)
assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError)
//Validate negative use cases
err = nullUInt64.Scan(int32(-5))
assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError)
//Validate negative use cases
err = nullUInt64.Scan(int16(-5))
assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError)
//Validate negative use cases
err = nullUInt64.Scan(int8(-5))
assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError)
}

View file

@ -1,6 +1,7 @@
package sqlite
import (
"fmt"
"github.com/go-jet/jet/v2/internal/jet"
)
@ -23,6 +24,9 @@ func newDialect() jet.Dialect {
return "?"
},
ReservedWords: reservedWords2,
ValuesDefaultColumnName: func(index int) string {
return fmt.Sprintf("column%d", index+1)
},
}
return jet.NewDialect(mySQLDialectParams)

View file

@ -33,6 +33,9 @@ type DateTimeExpression = jet.TimestampExpression
// TimestampExpression interface
type TimestampExpression = jet.TimestampExpression
// RowExpression interface
type RowExpression = jet.RowExpression
// BoolExp is bool expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as bool expression.
// Does not add sql cast to generated sql builder output.
@ -73,6 +76,11 @@ var DateTimeExp = jet.TimestampExp
// Does not add sql cast to generated sql builder output.
var TimestampExp = jet.TimestampExp
// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression.
// This enables the Go compiler to interpret any expression as a row expression
// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation.
var RowExp = jet.RowExp
// CustomExpression is used to define custom expressions.
var CustomExpression = jet.CustomExpression

View file

@ -15,9 +15,9 @@ var (
OR = jet.OR
)
// ROW is construct one table row from list of expressions.
func ROW(expressions ...Expression) Expression {
return jet.NewFunc("", expressions, nil)
// ROW function is used to create a tuple value that consists of a set of expressions or column values.
func ROW(expressions ...Expression) RowExpression {
return jet.WRAP(Dialect, expressions...)
}
// ------------------ Mathematical functions ---------------//

View file

@ -155,7 +155,7 @@ func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement {
}
func (s *selectStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s, alias)
return newSelectTable(s, alias, nil)
}
//-----------------------------------------------------

View file

@ -13,9 +13,9 @@ type selectTableImpl struct {
readableTableInterfaceImpl
}
func newSelectTable(selectStmt jet.SerializerHasProjections, alias string) SelectTable {
func newSelectTable(selectStmt jet.SerializerHasProjections, alias string, columnAliases []jet.ColumnExpression) SelectTable {
subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias),
SelectTable: jet.NewSelectTable(selectStmt, alias, columnAliases),
}
subQuery.readableTableInterfaceImpl.parent = subQuery

View file

@ -86,7 +86,7 @@ func (s *setStatementImpl) OFFSET(offset int64) setStatement {
}
func (s *setStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s, alias)
return newSelectTable(s, alias, nil)
}
const (

26
sqlite/values.go Normal file
View file

@ -0,0 +1,26 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
type values struct {
jet.Values
}
// VALUES is a table value constructor that computes a set of one or more rows as a temporary constant table.
// Each row is defined by the ROW constructor, which takes one or more expressions.
//
// Example usage:
//
// VALUES(
// ROW(Int32(204), Float32(1.21)),
// ROW(Int32(207), Float32(1.02)),
// )
func VALUES(rows ...RowExpression) values {
return values{Values: jet.Values(rows)}
}
// AS assigns an alias to the temporary VALUES table, allowing it to be referenced
// within SQL FROM clauses, just like a regular table.
func (v values) AS(alias string) SelectTable {
return newSelectTable(v, alias, nil)
}

View file

@ -6,8 +6,8 @@ import "github.com/go-jet/jet/v2/internal/jet"
type CommonTableExpression interface {
SelectTable
AS(statement jet.SerializerStatement) CommonTableExpression
AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression
AS(statement jet.SerializerHasProjections) CommonTableExpression
AS_NOT_MATERIALIZED(statement jet.SerializerHasProjections) CommonTableExpression
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
ALIAS(alias string) SelectTable
@ -42,13 +42,13 @@ func CTE(name string, columns ...jet.ColumnExpression) CommonTableExpression {
}
// AS is used to define a CTE query
func (c *commonTableExpression) AS(statement jet.SerializerStatement) CommonTableExpression {
func (c *commonTableExpression) AS(statement jet.SerializerHasProjections) CommonTableExpression {
c.CommonTableExpression.Statement = statement
return c
}
// AS_NOT_MATERIALIZED is used to define not materialized CTE query
func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerStatement) CommonTableExpression {
func (c *commonTableExpression) AS_NOT_MATERIALIZED(statement jet.SerializerHasProjections) CommonTableExpression {
c.CommonTableExpression.NotMaterialized = true
c.CommonTableExpression.Statement = statement
return c
@ -60,7 +60,7 @@ func (c *commonTableExpression) internalCTE() *jet.CommonTableExpression {
// ALIAS is used to create another alias of the CTE, if a CTE needs to appear multiple times in the main query.
func (c *commonTableExpression) ALIAS(name string) SelectTable {
return newSelectTable(c, name)
return newSelectTable(c, name, nil)
}
func toInternalCTE(ctes []CommonTableExpression) []*jet.CommonTableExpression {

View file

@ -6,6 +6,8 @@ setup: checkout-testdata docker-compose-up
checkout-testdata:
git submodule init
git submodule update
#
checkout-latest-testdata: checkout-testdata
cd ./testdata && git fetch && git checkout master && git pull
# docker-compose-up will download docker image for each of the databases listed in docker-compose.yaml file, and then it will initialize

View file

@ -13,7 +13,7 @@ services:
- ./testdata/init/postgres:/docker-entrypoint-initdb.d
mysql:
image: mysql:8.0.27
image: mysql:8.0
command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1']
restart: always
environment:

View file

@ -3,6 +3,7 @@ package main
import (
"context"
"database/sql"
"errors"
"flag"
"fmt"
"github.com/go-jet/jet/v2/generator/mysql"
@ -10,7 +11,6 @@ import (
"github.com/go-jet/jet/v2/generator/sqlite"
"github.com/go-jet/jet/v2/internal/utils/errfmt"
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
"io/ioutil"
"os"
"os/exec"
"strings"
@ -125,7 +125,7 @@ func initMySQLDB(isMariaDB bool) error {
fmt.Println(cmdLine)
cmd := exec.Command("sh", "-c", cmdLine)
cmd := exec.Command("sh", "-c", cmdLine) // #nosec G204
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
@ -184,7 +184,7 @@ func initPostgresDB(dbType string, connectionString string) error {
}
func execFile(db *sql.DB, sqlFilePath string) error {
testSampleSql, err := ioutil.ReadFile(sqlFilePath)
testSampleSql, err := os.ReadFile(sqlFilePath) // #nosec G304
if err != nil {
return fmt.Errorf("failed to read sql file - %s: %w", sqlFilePath, err)
}
@ -211,7 +211,10 @@ func execInTx(db *sql.DB, f func(tx *sql.Tx) error) error {
err = f(tx)
if err != nil {
tx.Rollback()
rollBackError := tx.Rollback()
if rollBackError != nil {
return errors.Join(rollBackError, err)
}
return err
}

View file

@ -2,7 +2,6 @@ package file
import (
"github.com/stretchr/testify/require"
"io/ioutil"
"os"
"path"
"testing"
@ -11,7 +10,7 @@ import (
// Exists expects file to exist on path constructed from pathElems and returns content of the file
func Exists(t *testing.T, pathElems ...string) (fileContent string) {
modelFilePath := path.Join(pathElems...)
file, err := ioutil.ReadFile(modelFilePath)
file, err := os.ReadFile(modelFilePath) // #nosec G304
require.Nil(t, err)
require.NotEmpty(t, file)
return string(file)
@ -20,6 +19,6 @@ func Exists(t *testing.T, pathElems ...string) (fileContent string) {
// NotExists expects file not to exist on path constructed from pathElems
func NotExists(t *testing.T, pathElems ...string) {
modelFilePath := path.Join(pathElems...)
_, err := ioutil.ReadFile(modelFilePath)
_, err := os.ReadFile(modelFilePath) // #nosec G304
require.True(t, os.IsNotExist(err))
}

View file

@ -1,6 +1,7 @@
package mysql
import (
"github.com/go-jet/jet/v2/internal/utils/ptr"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"strings"
@ -96,18 +97,18 @@ func TestExpressionOperators(t *testing.T) {
SELECT all_types.'integer' IS NULL AS "result.is_null",
all_types.date_ptr IS NOT NULL AS "result.is_not_null",
(all_types.small_int_ptr IN (?, ?)) AS "result.in",
(all_types.small_int_ptr IN (
(all_types.small_int_ptr IN ((
SELECT all_types.'integer' AS "all_types.integer"
FROM test_sample.all_types
)) AS "result.in_select",
))) AS "result.in_select",
(CURRENT_USER()) AS "result.raw",
(? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg",
(? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2",
(all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in",
(all_types.small_int_ptr NOT IN (
(all_types.small_int_ptr NOT IN ((
SELECT all_types.'integer' AS "all_types.integer"
FROM test_sample.all_types
)) AS "result.not_in_select"
))) AS "result.not_in_select"
FROM test_sample.all_types
LIMIT ?;
`, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2))
@ -1067,7 +1068,7 @@ func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) {
var toInsert = model.AllTypes{
Boolean: false,
BooleanPtr: testutils.BoolPtr(true),
BooleanPtr: ptr.Of(true),
TinyInt: 1,
UTinyInt: 2,
SmallInt: 3,
@ -1078,53 +1079,53 @@ var toInsert = model.AllTypes{
UInteger: 8,
BigInt: 9,
UBigInt: 1122334455,
TinyIntPtr: testutils.Int8Ptr(11),
UTinyIntPtr: testutils.UInt8Ptr(22),
SmallIntPtr: testutils.Int16Ptr(33),
USmallIntPtr: testutils.UInt16Ptr(44),
MediumIntPtr: testutils.Int32Ptr(55),
UMediumIntPtr: testutils.UInt32Ptr(66),
IntegerPtr: testutils.Int32Ptr(77),
UIntegerPtr: testutils.UInt32Ptr(88),
BigIntPtr: testutils.Int64Ptr(99),
UBigIntPtr: testutils.UInt64Ptr(111),
TinyIntPtr: ptr.Of(int8(11)),
UTinyIntPtr: ptr.Of(uint8(22)),
SmallIntPtr: ptr.Of(int16(33)),
USmallIntPtr: ptr.Of(uint16(44)),
MediumIntPtr: ptr.Of(int32(55)),
UMediumIntPtr: ptr.Of(uint32(66)),
IntegerPtr: ptr.Of(int32(77)),
UIntegerPtr: ptr.Of(uint32(88)),
BigIntPtr: ptr.Of(int64(99)),
UBigIntPtr: ptr.Of(uint64(111)),
Decimal: 11.22,
DecimalPtr: testutils.Float64Ptr(33.44),
DecimalPtr: ptr.Of(33.44),
Numeric: 55.66,
NumericPtr: testutils.Float64Ptr(77.88),
NumericPtr: ptr.Of(77.88),
Float: 99.00,
FloatPtr: testutils.Float64Ptr(11.22),
FloatPtr: ptr.Of(11.22),
Double: 33.44,
DoublePtr: testutils.Float64Ptr(55.66),
DoublePtr: ptr.Of(55.66),
Real: 77.88,
RealPtr: testutils.Float64Ptr(99.00),
RealPtr: ptr.Of(99.00),
Bit: "1",
BitPtr: testutils.StringPtr("0"),
BitPtr: ptr.Of("0"),
Time: time.Date(1, 1, 1, 10, 11, 12, 100, &time.Location{}),
TimePtr: testutils.TimePtr(time.Date(1, 1, 1, 10, 11, 12, 100, time.UTC)),
TimePtr: ptr.Of(time.Date(1, 1, 1, 10, 11, 12, 100, time.UTC)),
Date: time.Now(),
DatePtr: testutils.TimePtr(time.Now()),
DatePtr: ptr.Of(time.Now()),
DateTime: time.Now(),
DateTimePtr: testutils.TimePtr(time.Now()),
DateTimePtr: ptr.Of(time.Now()),
Timestamp: time.Now(),
//TimestampPtr: testutils.TimePtr(time.Now()), // TODO: build fails for MariaDB
Year: 2000,
YearPtr: testutils.Int16Ptr(2001),
YearPtr: ptr.Of(int16(2001)),
Char: "abcd",
CharPtr: testutils.StringPtr("absd"),
CharPtr: ptr.Of("absd"),
VarChar: "abcd",
VarCharPtr: testutils.StringPtr("absd"),
VarCharPtr: ptr.Of("absd"),
Binary: []byte("1010"),
BinaryPtr: testutils.ByteArrayPtr([]byte("100001")),
BinaryPtr: ptr.Of([]byte("100001")),
VarBinary: []byte("1010"),
VarBinaryPtr: testutils.ByteArrayPtr([]byte("100001")),
VarBinaryPtr: ptr.Of([]byte("100001")),
Blob: []byte("large file"),
BlobPtr: testutils.ByteArrayPtr([]byte("very large file")),
BlobPtr: ptr.Of([]byte("very large file")),
Text: "some text",
TextPtr: testutils.StringPtr("text"),
TextPtr: ptr.Of("text"),
Enum: model.AllTypesEnum_Value1,
JSON: "{}",
JSONPtr: testutils.StringPtr(`{"a": 1}`),
JSONPtr: ptr.Of(`{"a": 1}`),
}
var allTypesJson = `
@ -1358,17 +1359,17 @@ func TestExactDecimals(t *testing.T) {
Floats: model.Floats{
// overwritten by wrapped(floats) scope
Numeric: 0.1,
NumericPtr: testutils.Float64Ptr(0.1),
NumericPtr: ptr.Of(0.1),
Decimal: 0.1,
DecimalPtr: testutils.Float64Ptr(0.1),
DecimalPtr: ptr.Of(0.1),
// not overwritten
Float: 0.2,
FloatPtr: testutils.Float64Ptr(0.22),
FloatPtr: ptr.Of(0.22),
Double: 0.3,
DoublePtr: testutils.Float64Ptr(0.33),
DoublePtr: ptr.Of(0.33),
Real: 0.4,
RealPtr: testutils.Float64Ptr(0.44),
RealPtr: ptr.Of(0.44),
},
Numeric: decimal.RequireFromString("12.35"),
NumericPtr: decimal.RequireFromString("56.79"),
@ -1403,3 +1404,34 @@ VALUES ('91.23', '45.67', '12.35', '56.79', 0.2, 0.22, 0.3, 0.33, 0.4, 0.44);
require.Equal(t, 45.67, *result.Floats.DecimalPtr)
})
}
func TestRowExpression(t *testing.T) {
now := time.Now()
nowAddHour := time.Now().Add(time.Hour)
stmt := SELECT(
ROW(Bool(false), DateT(now)).EQ(ROW(Bool(true), DateT(now))),
ROW(Bool(false), DateT(now)).NOT_EQ(ROW(Bool(true), DateT(now))),
ROW(TimestampT(nowAddHour), String("txt")).IS_DISTINCT_FROM(RowExp(Raw("row(NOW(), 'png')"))),
ROW(TimestampT(now), DateTimeT(nowAddHour)).GT(ROW(TimestampT(now), DateTimeT(now))),
ROW(DateTimeT(nowAddHour), Int(1)).GT_EQ(ROW(DateTimeT(now), Int(2))),
ROW(TimestampT(now), DateTimeT(nowAddHour)).LT(ROW(TimestampT(now), DateTimeT(now))),
ROW(DateTimeT(nowAddHour), Float(1.22)).LT_EQ(ROW(DateTimeT(now), Float(2.33))),
)
//fmt.Println(stmt.Sql())
//fmt.Println(stmt.DebugSql())
testutils.AssertStatementSql(t, stmt, `
SELECT ROW(?, CAST(? AS DATE)) = ROW(?, CAST(? AS DATE)),
ROW(?, CAST(? AS DATE)) != ROW(?, CAST(? AS DATE)),
NOT(ROW(TIMESTAMP(?), ?) <=> (row(NOW(), 'png'))),
ROW(TIMESTAMP(?), CAST(? AS DATETIME)) > ROW(TIMESTAMP(?), CAST(? AS DATETIME)),
ROW(CAST(? AS DATETIME), ?) >= ROW(CAST(? AS DATETIME), ?),
ROW(TIMESTAMP(?), CAST(? AS DATETIME)) < ROW(TIMESTAMP(?), CAST(? AS DATETIME)),
ROW(CAST(? AS DATETIME), ?) <= ROW(CAST(? AS DATETIME), ?);
`)
err := stmt.Query(db, &struct{}{})
require.NoError(t, err)
}

View file

@ -8,8 +8,11 @@ import (
"github.com/stretchr/testify/require"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/mysql"
"github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/testutils"
mysql2 "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/tests/dbconfig"
)
@ -39,6 +42,39 @@ func TestGenerator(t *testing.T) {
require.NoError(t, err)
}
func TestGenerator_TableMetadata(t *testing.T) {
var schema metadata.Schema
err := mysql.Generate(genTestDir3, dbConnection("dvds"),
template.Default(mysql2.Dialect).UseSchema(func(m metadata.Schema) template.Schema {
schema = m
return template.DefaultSchema(m)
}))
require.NoError(t, err)
// Spot check the actor table and assert that the emitted
// properties are as expected.
var got metadata.Table
for _, table := range schema.TablesMetaData {
if table.Name == "actor" {
got = table
}
}
want := metadata.Table{
Name: "actor",
Columns: []metadata.Column{
{Name: "actor_id", IsPrimaryKey: true, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "smallint", Kind: "base", IsUnsigned: true}, Comment: ""},
{Name: "first_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""},
{Name: "last_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""},
{Name: "last_update", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: true, DataType: metadata.DataType{Name: "timestamp", Kind: "base", IsUnsigned: false}, Comment: ""},
},
}
require.Equal(t, want, got)
err = os.RemoveAll(genTestDirRoot)
require.NoError(t, err)
}
func TestCmdGenerator(t *testing.T) {
err := os.RemoveAll(genTestDir3)
require.NoError(t, err)

View file

@ -3,10 +3,12 @@ package mysql
import (
"context"
"github.com/go-jet/jet/v2/internal/testutils"
"github.com/go-jet/jet/v2/internal/utils/ptr"
. "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/qrm"
"github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/table"
"github.com/stretchr/testify/require"
"math/rand"
"testing"
@ -300,7 +302,7 @@ func TestInsertOnDuplicateKeyUpdateNEW(t *testing.T) {
ID: randId,
URL: "https://www.yahoo.com",
Name: "Yahoo",
Description: testutils.StringPtr("web portal and search engine"),
Description: ptr.Of("web portal and search engine"),
},
}).AS_NEW().
ON_DUPLICATE_KEY_UPDATE(
@ -337,7 +339,7 @@ ON DUPLICATE KEY UPDATE id = (link.id + ?),
ID: randId + 11,
URL: "https://www.yahoo.com",
Name: "Yahoo",
Description: testutils.StringPtr("web portal and search engine"),
Description: ptr.Of("web portal and search engine"),
})
})
}

View file

@ -6,12 +6,9 @@ import (
jetmysql "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/dbconfig"
"github.com/stretchr/testify/require"
"math/rand"
"runtime"
"time"
_ "github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/require"
"runtime"
"github.com/pkg/profile"
"os"
@ -33,7 +30,6 @@ func sourceIsMariaDB() bool {
}
func TestMain(m *testing.M) {
rand.Seed(time.Now().Unix())
defer profile.Start().Stop()
var err error
@ -101,3 +97,9 @@ func skipForMariaDB(t *testing.T) {
t.SkipNow()
}
}
func onlyMariaDB(t *testing.T) {
if !sourceIsMariaDB() {
t.SkipNow()
}
}

347
tests/mysql/values_test.go Normal file
View file

@ -0,0 +1,347 @@
package mysql
import (
"github.com/go-jet/jet/v2/internal/testutils"
"github.com/stretchr/testify/require"
"strings"
"testing"
"time"
. "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table"
)
func TestVALUES(t *testing.T) {
skipForMariaDB(t)
valuesTable := VALUES(
ROW(Int32(1), Int32(2), Float(4.666), Bool(false), String("txt")),
ROW(Int32(11).ADD(Int32(2)), Int32(22), Float(33.222), Bool(true), String("png")),
ROW(Int32(11), Int32(22), Float(33.222), Bool(true), NULL),
).AS("values_table")
stmt := SELECT(
valuesTable.AllColumns(),
).FROM(
valuesTable,
)
testutils.AssertStatementSql(t, stmt, `
SELECT values_table.column_0 AS "column_0",
values_table.column_1 AS "column_1",
values_table.column_2 AS "column_2",
values_table.column_3 AS "column_3",
values_table.column_4 AS "column_4"
FROM (
VALUES ROW(?, ?, ?, ?, ?),
ROW(? + ?, ?, ?, ?, ?),
ROW(?, ?, ?, ?, NULL)
) AS values_table;
`)
var dest []struct {
Column0 int
Column1 int
Column2 float32
Column3 bool
Column4 *string
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
[
{
"Column0": 1,
"Column1": 2,
"Column2": 4.666,
"Column3": false,
"Column4": "txt"
},
{
"Column0": 13,
"Column1": 22,
"Column2": 33.222,
"Column3": true,
"Column4": "png"
},
{
"Column0": 11,
"Column1": 22,
"Column2": 33.222,
"Column3": true,
"Column4": null
}
]
`)
}
func TestVALUES_Join(t *testing.T) {
skipForMariaDB(t)
title := StringColumn("title")
releaseYear := IntegerColumn("ReleaseYear")
rentalRate := FloatColumn("rental_rate")
lastUpdate := Timestamp(2007, time.February, 11, 12, 0, 0)
films := VALUES(
ROW(String("Chamber Italian"), Int64(117), Int32(2005), Float(5.82), lastUpdate),
ROW(String("Grosse Wonderful"), Int64(49), Int32(2004), Float(6.242), lastUpdate.ADD(INTERVAL(1, HOUR))),
ROW(String("Airport Pollock"), Int64(54), Int32(2001), Float(7.22), NULL),
ROW(String("Bright Encounters"), Int64(73), Int32(2002), Float(8.25), NULL),
ROW(String("Academy Dinosaur"), Int64(83), Int32(2010), Float(9.22), lastUpdate.SUB(INTERVAL(2, MINUTE))),
).AS("film_values",
title, IntegerColumn("length"), releaseYear, rentalRate, TimestampColumn("last_update"))
stmt := SELECT(
Film.AllColumns,
films.AllColumns(),
).FROM(
Film.
INNER_JOIN(films, title.EQ(Film.Title)),
).WHERE(AND(
Film.ReleaseYear.GT(releaseYear),
Film.RentalRate.LT(rentalRate),
)).ORDER_BY(
title,
)
testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(`
SELECT film.film_id AS "film.film_id",
film.title AS "film.title",
film.description AS "film.description",
film.release_year AS "film.release_year",
film.language_id AS "film.language_id",
film.original_language_id AS "film.original_language_id",
film.rental_duration AS "film.rental_duration",
film.rental_rate AS "film.rental_rate",
film.length AS "film.length",
film.replacement_cost AS "film.replacement_cost",
film.rating AS "film.rating",
film.special_features AS "film.special_features",
film.last_update AS "film.last_update",
film_values.title AS "title",
film_values.length AS "length",
film_values.''ReleaseYear'' AS "ReleaseYear",
film_values.rental_rate AS "rental_rate",
film_values.last_update AS "last_update"
FROM dvds.film
INNER JOIN (
VALUES ROW('Chamber Italian', 117, 2005, 5.82, TIMESTAMP('2007-02-11 12:00:00')),
ROW('Grosse Wonderful', 49, 2004, 6.242, TIMESTAMP('2007-02-11 12:00:00') + INTERVAL 1 HOUR),
ROW('Airport Pollock', 54, 2001, 7.22, NULL),
ROW('Bright Encounters', 73, 2002, 8.25, NULL),
ROW('Academy Dinosaur', 83, 2010, 9.22, TIMESTAMP('2007-02-11 12:00:00') - INTERVAL 2 MINUTE)
) AS film_values (title, length, ''ReleaseYear'', rental_rate, last_update) ON (film_values.title = film.title)
WHERE (
(film.release_year > film_values.''ReleaseYear'')
AND (film.rental_rate < film_values.rental_rate)
)
ORDER BY film_values.title;
`, "''", "`"))
var dest []struct {
Film model.Film
Title string
Length int
ReleaseYear int
RentalRate float32
LastUpdate *time.Time
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Len(t, dest, 4)
testutils.AssertJSON(t, dest[0:2], `
[
{
"Film": {
"FilmID": 8,
"Title": "AIRPORT POLLOCK",
"Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India",
"ReleaseYear": 2006,
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 6,
"RentalRate": 4.99,
"Length": 54,
"ReplacementCost": 15.99,
"Rating": "R",
"SpecialFeatures": "Trailers",
"LastUpdate": "2006-02-15T05:03:42Z"
},
"Title": "Airport Pollock",
"Length": 54,
"ReleaseYear": 2001,
"RentalRate": 7.22,
"LastUpdate": null
},
{
"Film": {
"FilmID": 98,
"Title": "BRIGHT ENCOUNTERS",
"Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat",
"ReleaseYear": 2006,
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 4,
"RentalRate": 4.99,
"Length": 73,
"ReplacementCost": 12.99,
"Rating": "PG-13",
"SpecialFeatures": "Trailers",
"LastUpdate": "2006-02-15T05:03:42Z"
},
"Title": "Bright Encounters",
"Length": 73,
"ReleaseYear": 2002,
"RentalRate": 8.25,
"LastUpdate": null
}
]
`)
}
func TestVALUES_CTE_Update(t *testing.T) {
skipForMariaDB(t)
paymentID := IntegerColumn("payment_id")
increase := FloatColumn("increase")
paymentsToUpdate := CTE("values_cte", paymentID, increase)
stmt := WITH(
paymentsToUpdate.AS(
VALUES(
ROW(Int32(204), Float(1.21)),
ROW(Int32(207), Float(1.02)),
ROW(Int32(200), Float(1.34)),
ROW(Int32(203), Float(1.72)),
),
),
)(
Payment.INNER_JOIN(paymentsToUpdate, paymentID.EQ(Payment.PaymentID)).
UPDATE().
SET(
Payment.Amount.SET(Payment.Amount.MUL(increase)),
).WHERE(Bool(true)),
)
testutils.AssertStatementSql(t, stmt, `
WITH values_cte (payment_id, increase) AS (
VALUES ROW(?, ?),
ROW(?, ?),
ROW(?, ?),
ROW(?, ?)
)
UPDATE dvds.payment
INNER JOIN values_cte ON (values_cte.payment_id = payment.payment_id)
SET amount = (payment.amount * values_cte.increase)
WHERE ?;
`)
testutils.AssertExecAndRollback(t, stmt, db, 4)
}
func TestVALUES_MariaDB(t *testing.T) {
onlyMariaDB(t) // mariadb won't accept values rows if all the elements are placeholders, so we have to use raw statement
paymentID := IntegerColumn("payment_id")
increase := FloatColumn("increase")
paymentsToUpdate := CTE("values_cte", paymentID, increase)
stmt := WITH(
paymentsToUpdate.AS(
RawStatement(`
VALUES (204, 1.21),
(207, 1.02),
(200, 1.34),
(203, 1.72)
`),
),
)(
SELECT(
Payment.AllColumns,
paymentsToUpdate.AllColumns(),
).FROM(
Payment.
INNER_JOIN(paymentsToUpdate, paymentID.EQ(Payment.PaymentID)),
).WHERE(
increase.GT(Float(1.03)),
).ORDER_BY(
increase,
),
)
testutils.AssertStatementSql(t, stmt, `
WITH values_cte (payment_id, increase) AS (
VALUES (204, 1.21),
(207, 1.02),
(200, 1.34),
(203, 1.72)
)
SELECT payment.payment_id AS "payment.payment_id",
payment.customer_id AS "payment.customer_id",
payment.staff_id AS "payment.staff_id",
payment.rental_id AS "payment.rental_id",
payment.amount AS "payment.amount",
payment.payment_date AS "payment.payment_date",
payment.last_update AS "payment.last_update",
values_cte.payment_id AS "payment_id",
values_cte.increase AS "increase"
FROM dvds.payment
INNER JOIN values_cte ON (values_cte.payment_id = payment.payment_id)
WHERE values_cte.increase > ?
ORDER BY values_cte.increase;
`)
var dest []struct {
model.Payment
Increase float64
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
[
{
"PaymentID": 204,
"CustomerID": 7,
"StaffID": 1,
"RentalID": 13476,
"Amount": 2.99,
"PaymentDate": "2005-08-20T01:06:04Z",
"LastUpdate": "2006-02-15T22:12:31Z",
"Increase": 1.21
},
{
"PaymentID": 200,
"CustomerID": 7,
"StaffID": 2,
"RentalID": 11542,
"Amount": 7.99,
"PaymentDate": "2005-08-17T00:51:32Z",
"LastUpdate": "2006-02-15T22:12:31Z",
"Increase": 1.34
},
{
"PaymentID": 203,
"CustomerID": 7,
"StaffID": 2,
"RentalID": 13373,
"Amount": 2.99,
"PaymentDate": "2005-08-19T21:23:31Z",
"LastUpdate": "2006-02-15T22:12:31Z",
"Increase": 1.72
}
]
`)
}

View file

@ -164,10 +164,10 @@ WITH payments_to_delete AS (
WHERE payment.amount < 0.5
)
DELETE FROM dvds.payment
WHERE payment.payment_id IN (
WHERE payment.payment_id IN ((
SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id"
FROM payments_to_delete
);
));
`, "''", "`"))
tx, err := db.Begin()

View file

@ -1,6 +1,10 @@
package postgres
import (
"database/sql"
"github.com/go-jet/jet/v2/internal/utils/ptr"
"github.com/stretchr/testify/assert"
"github.com/go-jet/jet/v2/qrm"
"testing"
"time"
@ -344,24 +348,22 @@ func TestExpressionOperators(t *testing.T) {
AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"),
).LIMIT(2)
//fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, `
SELECT all_types.integer IS NULL AS "result.is_null",
all_types.date_ptr IS NOT NULL AS "result.is_not_null",
(all_types.small_int_ptr IN ($1::smallint, $2::smallint)) AS "result.in",
(all_types.small_int_ptr IN (
(all_types.small_int_ptr IN ((
SELECT all_types.integer AS "all_types.integer"
FROM test_sample.all_types
)) AS "result.in_select",
))) AS "result.in_select",
(CURRENT_USER) AS "result.raw",
($3 + COALESCE(all_types.small_int_ptr, 0) + $4) AS "result.raw_arg",
($5 + all_types.integer + $6 + $5 + $7 + $8) AS "result.raw_arg2",
(all_types.small_int_ptr NOT IN ($9, $10::smallint, NULL)) AS "result.not_in",
(all_types.small_int_ptr NOT IN (
(all_types.small_int_ptr NOT IN ((
SELECT all_types.integer AS "all_types.integer"
FROM test_sample.all_types
)) AS "result.not_in_select"
))) AS "result.not_in_select"
FROM test_sample.all_types
LIMIT $11;
`, int8(11), int8(22), 78, 56, 11, 22, 33, 44, int64(11), int16(22), int64(2))
@ -373,9 +375,6 @@ LIMIT $11;
err := query.Query(db, &dest)
require.NoError(t, err)
//testutils.PrintJson(dest)
testutils.AssertJSON(t, dest, `
[
{
@ -930,6 +929,68 @@ func TestTimeExpression(t *testing.T) {
require.NoError(t, err)
}
func TestIntervalSetFunctionality(t *testing.T) {
t.Run("updateQueryIntervalTest", func(t *testing.T) {
expectedQuery := `
UPDATE test_sample.employee
SET pto_accrual = INTERVAL '3 HOUR'
WHERE employee.employee_id = $1
RETURNING employee.employee_id AS "employee.employee_id",
employee.first_name AS "employee.first_name",
employee.last_name AS "employee.last_name",
employee.employment_date AS "employee.employment_date",
employee.manager_id AS "employee.manager_id",
employee.pto_accrual AS "employee.pto_accrual";
`
testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
var windy model.Employee
windy.PtoAccrual = ptr.Of("3h")
stmt := Employee.UPDATE(Employee.PtoAccrual).SET(
Employee.PtoAccrual.SET(INTERVAL(3, HOUR)),
).WHERE(Employee.EmployeeID.EQ(Int(1))).RETURNING(Employee.AllColumns)
testutils.AssertStatementSql(t, stmt, expectedQuery)
err := stmt.Query(tx, &windy)
assert.Nil(t, err)
assert.Equal(t, *windy.PtoAccrual, "03:00:00")
})
})
t.Run("upsertQueryIntervalTest", func(t *testing.T) {
expectedQuery := `
INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual)
VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (employee_id) DO UPDATE
SET pto_accrual = excluded.pto_accrual
RETURNING employee.employee_id AS "employee.employee_id",
employee.first_name AS "employee.first_name",
employee.last_name AS "employee.last_name",
employee.employment_date AS "employee.employment_date",
employee.manager_id AS "employee.manager_id",
employee.pto_accrual AS "employee.pto_accrual";
`
testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
var employee model.Employee
employee.PtoAccrual = ptr.Of("5h")
stmt := Employee.INSERT(Employee.AllColumns).
MODEL(employee).
ON_CONFLICT(Employee.EmployeeID).
DO_UPDATE(SET(
Employee.PtoAccrual.SET(Employee.EXCLUDED.PtoAccrual),
)).RETURNING(Employee.AllColumns)
testutils.AssertStatementSql(t, stmt, expectedQuery)
err := stmt.Query(tx, &employee)
assert.Nil(t, err)
assert.Equal(t, *employee.PtoAccrual, "05:00:00")
})
})
}
func TestInterval(t *testing.T) {
skipForCockroachDB(t)
@ -1046,6 +1107,46 @@ FROM test_sample.all_types;
require.NoError(t, err)
}
func TestRowExpression(t *testing.T) {
now := time.Now()
nowAddHour := time.Now().Add(time.Hour)
stmt := SELECT(
ROW(Int32(1), Float32(11.22), String("john")).AS("row"),
WRAP(Int64(1), Float64(11.22), String("john")).AS("wrap"),
ROW(Bool(false), DateT(now)).EQ(ROW(Bool(true), DateT(now))),
WRAP(Bool(false), DateT(now)).NOT_EQ(WRAP(Bool(true), DateT(now))),
ROW(TimeT(nowAddHour)).IS_DISTINCT_FROM(RowExp(Raw("row(NOW()::time)"))),
ROW().IS_NOT_DISTINCT_FROM(ROW()),
ROW(TimestampT(now), TimestampzT(nowAddHour)).GT(WRAP(TimestampT(now), TimestampzT(now))),
ROW(TimestampzT(nowAddHour)).GT_EQ(ROW(TimestampzT(now))),
WRAP(TimestampT(now), TimestampzT(nowAddHour)).LT(ROW(TimestampT(now), TimestampzT(now))),
ROW(TimestampzT(nowAddHour)).LT_EQ(ROW(TimestampzT(now))),
)
//fmt.Println(stmt.Sql())
//fmt.Println(stmt.DebugSql())
testutils.AssertStatementSql(t, stmt, `
SELECT ROW($1::integer, $2::real, $3::text) AS "row",
($4::bigint, $5::double precision, $6::text) AS "wrap",
ROW($7::boolean, $8::date) = ROW($9::boolean, $10::date),
($11::boolean, $12::date) != ($13::boolean, $14::date),
ROW($15::time without time zone) IS DISTINCT FROM (row(NOW()::time)),
ROW() IS NOT DISTINCT FROM ROW(),
ROW($16::timestamp without time zone, $17::timestamp with time zone) > ($18::timestamp without time zone, $19::timestamp with time zone),
ROW($20::timestamp with time zone) >= ROW($21::timestamp with time zone),
($22::timestamp without time zone, $23::timestamp with time zone) < ROW($24::timestamp without time zone, $25::timestamp with time zone),
ROW($26::timestamp with time zone) <= ROW($27::timestamp with time zone);
`)
err := stmt.Query(db, &struct{}{})
require.NoError(t, err)
}
func TestSubQueryColumnReference(t *testing.T) {
type expected struct {
sql string
@ -1305,32 +1406,32 @@ RETURNING all_types.json AS "all_types.json";
var moodSad = model.Mood_Sad
var allTypesRow0 = model.AllTypes{
SmallIntPtr: testutils.Int16Ptr(14),
SmallIntPtr: ptr.Of(int16(14)),
SmallInt: 14,
IntegerPtr: testutils.Int32Ptr(300),
IntegerPtr: ptr.Of(int32(300)),
Integer: 300,
BigIntPtr: testutils.Int64Ptr(50000),
BigIntPtr: ptr.Of(int64(50000)),
BigInt: 5000,
DecimalPtr: testutils.Float64Ptr(1.11),
DecimalPtr: ptr.Of(1.11),
Decimal: 1.11,
NumericPtr: testutils.Float64Ptr(2.22),
NumericPtr: ptr.Of(2.22),
Numeric: 2.22,
RealPtr: testutils.Float32Ptr(5.55),
RealPtr: ptr.Of(float32(5.55)),
Real: 5.55,
DoublePrecisionPtr: testutils.Float64Ptr(11111111.22),
DoublePrecisionPtr: ptr.Of(11111111.22),
DoublePrecision: 11111111.22,
Smallserial: 1,
Serial: 1,
Bigserial: 1,
//MoneyPtr: nil,
//Money:
VarCharPtr: testutils.StringPtr("ABBA"),
VarCharPtr: ptr.Of("ABBA"),
VarChar: "ABBA",
CharPtr: testutils.StringPtr("JOHN "),
CharPtr: ptr.Of("JOHN "),
Char: "JOHN ",
TextPtr: testutils.StringPtr("Some text"),
TextPtr: ptr.Of("Some text"),
Text: "Some text",
ByteaPtr: testutils.ByteArrayPtr([]byte("bytea")),
ByteaPtr: ptr.Of([]byte("bytea")),
Bytea: []byte("bytea"),
TimestampzPtr: testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0),
Timestampz: *testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0),
@ -1342,31 +1443,31 @@ var allTypesRow0 = model.AllTypes{
Timez: *testutils.TimeWithTimeZone("04:05:06 -0800"),
TimePtr: testutils.TimeWithoutTimeZone("04:05:06"),
Time: *testutils.TimeWithoutTimeZone("04:05:06"),
IntervalPtr: testutils.StringPtr("3 days 04:05:06"),
IntervalPtr: ptr.Of("3 days 04:05:06"),
Interval: "3 days 04:05:06",
BooleanPtr: testutils.BoolPtr(true),
BooleanPtr: ptr.Of(true),
Boolean: false,
PointPtr: testutils.StringPtr("(2,3)"),
BitPtr: testutils.StringPtr("101"),
PointPtr: ptr.Of("(2,3)"),
BitPtr: ptr.Of("101"),
Bit: "101",
BitVaryingPtr: testutils.StringPtr("101111"),
BitVaryingPtr: ptr.Of("101111"),
BitVarying: "101111",
TsvectorPtr: testutils.StringPtr("'supernova':1"),
TsvectorPtr: ptr.Of("'supernova':1"),
Tsvector: "'supernova':1",
UUIDPtr: testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"),
UUID: uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"),
XMLPtr: testutils.StringPtr("<Sub>abc</Sub>"),
XMLPtr: ptr.Of("<Sub>abc</Sub>"),
XML: "<Sub>abc</Sub>",
JSONPtr: testutils.StringPtr(`{"a": 1, "b": 3}`),
JSONPtr: ptr.Of(`{"a": 1, "b": 3}`),
JSON: `{"a": 1, "b": 3}`,
JsonbPtr: testutils.StringPtr(`{"a": 1, "b": 3}`),
JsonbPtr: ptr.Of(`{"a": 1, "b": 3}`),
Jsonb: `{"a": 1, "b": 3}`,
IntegerArrayPtr: testutils.StringPtr("{1,2,3}"),
IntegerArrayPtr: ptr.Of("{1,2,3}"),
IntegerArray: "{1,2,3}",
TextArrayPtr: testutils.StringPtr("{breakfast,consulting}"),
TextArrayPtr: ptr.Of("{breakfast,consulting}"),
TextArray: "{breakfast,consulting}",
JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`,
TextMultiDimArrayPtr: testutils.StringPtr("{{meeting,lunch},{training,presentation}}"),
TextMultiDimArrayPtr: ptr.Of("{{meeting,lunch},{training,presentation}}"),
TextMultiDimArray: "{{meeting,lunch},{training,presentation}}",
MoodPtr: &moodSad,
Mood: model.Mood_Happy,

View file

@ -3,6 +3,7 @@ package postgres
import (
"context"
"github.com/go-jet/jet/v2/internal/testutils"
"github.com/go-jet/jet/v2/internal/utils/ptr"
. "github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook/table"
@ -455,7 +456,7 @@ FROM (
require.Len(t, dest, 275)
require.Equal(t, dest[0].Artist1.Artist, model.Artist{
ArtistId: 1,
Name: testutils.StringPtr("AC/DC"),
Name: ptr.Of("AC/DC"),
})
require.Equal(t, dest[0].Artist1.CustomColumn1, "custom_column_1")
require.Equal(t, dest[0].Artist1.CustomColumn2, "custom_column_2")

View file

@ -3,6 +3,9 @@ package postgres
import (
"database/sql"
"fmt"
"path"
"testing"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/postgres"
"github.com/go-jet/jet/v2/generator/template"
@ -13,8 +16,6 @@ import (
"github.com/go-jet/jet/v2/tests/dbconfig"
file2 "github.com/go-jet/jet/v2/tests/internal/utils/file"
"github.com/stretchr/testify/require"
"path"
"testing"
)
const tempTestDir = "./.tempTestDir"
@ -170,6 +171,7 @@ func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) {
mpaaRating := file2.Exists(t, defaultModelPath, "mpaa_rating_enum.go")
require.Contains(t, mpaaRating, "type MpaaRatingEnum string")
require.Contains(t, mpaaRating, "MpaaRatingEnumAllValues")
}
func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) {
@ -267,7 +269,6 @@ func UseSchema(schema string) {
FilmList = FilmList.FromSchema(schema)
}
`)
}
func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) {
@ -365,7 +366,6 @@ func TestGeneratorTemplate_SQLBuilder_DefaultAlias(t *testing.T) {
}
func TestGeneratorTemplate_Model_AddTags(t *testing.T) {
err := postgres.Generate(
tempTestDir,
dbConnection,

View file

@ -2,7 +2,6 @@ package postgres
import (
"fmt"
"github.com/go-jet/jet/v2/tests/internal/utils/file"
"os"
"os/exec"
"path/filepath"
@ -12,11 +11,14 @@ import (
"github.com/stretchr/testify/require"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/postgres"
"github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/testutils"
"github.com/go-jet/jet/v2/tests/dbconfig"
postgres2 "github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
"github.com/go-jet/jet/v2/tests/dbconfig"
"github.com/go-jet/jet/v2/tests/internal/utils/file"
)
func dsn(host string, port int, dbName, user, password string) string {
@ -208,6 +210,45 @@ func TestGenerator(t *testing.T) {
require.NoError(t, err)
}
func TestGenerator_TableMetadata(t *testing.T) {
var schema metadata.Schema
err := postgres.GenerateDSN(defaultDSN(), "dvds", genTestDir2,
template.Default(postgres2.Dialect).UseSchema(func(m metadata.Schema) template.Schema {
schema = m
return template.DefaultSchema(m)
}))
require.NoError(t, err)
// Spot check the actor table and assert that the emitted
// properties are as expected.
var got metadata.Table
var specialFeatures metadata.Column
for _, table := range schema.TablesMetaData {
if table.Name == "actor" {
got = table
}
if table.Name == "film" {
for _, column := range table.Columns {
if column.Name == "special_features" {
specialFeatures = column
}
}
}
}
want := metadata.Table{
Name: "actor",
Columns: []metadata.Column{
{Name: "actor_id", IsPrimaryKey: true, IsNullable: false, IsGenerated: false, HasDefault: true, DataType: metadata.DataType{Name: "int4", Kind: "base", IsUnsigned: false}, Comment: ""},
{Name: "first_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""},
{Name: "last_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "varchar", Kind: "base", IsUnsigned: false}, Comment: ""},
{Name: "last_update", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "timestamp", Kind: "base", IsUnsigned: false}, Comment: ""},
},
}
require.Equal(t, want, got)
require.Equal(t, metadata.ArrayType, specialFeatures.DataType.Kind)
}
func TestGeneratorSpecialCharacters(t *testing.T) {
t.SkipNow()
err := postgres.Generate(genTestDir2, postgres.DBConnection{
@ -563,6 +604,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) {
"mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go", "floats.go", "people.go",
"components.go", "vulnerabilities.go", "all_types_materialized_view.go", "sample_ranges.go")
testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent)
testutils.AssertFileContent(t, modelDir+"/link.go", linkModelContent)
testutils.AssertFileNamesEqual(t, tableDir, "all_types.go", "employee.go", "link.go",
"person.go", "person_phone.go", "weird_names_table.go", "user.go", "floats.go", "people.go", "table_use_schema.go",
@ -570,6 +612,8 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) {
testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent)
testutils.AssertFileContent(t, tableDir+"/sample_ranges.go", sampleRangeTableContent)
testutils.AssertFileContent(t, tableDir+"/link.go", linkTableContent)
testutils.AssertFileNamesEqual(t, viewDir, "all_types_materialized_view.go", "all_types_view.go",
"view_use_schema.go")
}
@ -609,6 +653,7 @@ package enum
import "github.com/go-jet/jet/v2/postgres"
// Level enum
var Level = &struct {
Level1 postgres.StringExpression
Level2 postgres.StringExpression
@ -706,6 +751,25 @@ type AllTypes struct {
}
`
var linkModelContent = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
// Link table
type Link struct {
ID int64 ` + "`sql:\"primary_key\"`" + ` // this is link id
URL string // link url
Name string // Unicode characters comment ₲鬼佬℧⇄↻
Description *string // '"Z\%_
}
`
var allTypesTableContent = `
//
// Code generated by go-jet DO NOT EDIT.
@ -1062,3 +1126,91 @@ func newSampleRangesTableImpl(schemaName, tableName, alias string) sampleRangesT
}
}
`
var linkTableContent = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/postgres"
)
var Link = newLinkTable("test_sample", "link", "")
// Link table
type linkTable struct {
postgres.Table
// Columns
ID postgres.ColumnInteger // this is link id
URL postgres.ColumnString // link url
Name postgres.ColumnString // Unicode characters comment ₲鬼佬℧⇄↻
Description postgres.ColumnString // '"Z\%_
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
}
type LinkTable struct {
linkTable
EXCLUDED linkTable
}
// AS creates new LinkTable with assigned alias
func (a LinkTable) AS(alias string) *LinkTable {
return newLinkTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new LinkTable with assigned schema name
func (a LinkTable) FromSchema(schemaName string) *LinkTable {
return newLinkTable(schemaName, a.TableName(), a.Alias())
}
// WithPrefix creates new LinkTable with assigned table prefix
func (a LinkTable) WithPrefix(prefix string) *LinkTable {
return newLinkTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
}
// WithSuffix creates new LinkTable with assigned table suffix
func (a LinkTable) WithSuffix(suffix string) *LinkTable {
return newLinkTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
}
func newLinkTable(schemaName, tableName, alias string) *LinkTable {
return &LinkTable{
linkTable: newLinkTableImpl(schemaName, tableName, alias),
EXCLUDED: newLinkTableImpl("", "excluded", ""),
}
}
func newLinkTableImpl(schemaName, tableName, alias string) linkTable {
var (
IDColumn = postgres.IntegerColumn("id")
URLColumn = postgres.StringColumn("url")
NameColumn = postgres.StringColumn("name")
DescriptionColumn = postgres.StringColumn("description")
allColumns = postgres.ColumnList{IDColumn, URLColumn, NameColumn, DescriptionColumn}
mutableColumns = postgres.ColumnList{URLColumn, NameColumn, DescriptionColumn}
)
return linkTable{
Table: postgres.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
ID: IDColumn,
URL: URLColumn,
Name: NameColumn,
Description: DescriptionColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`

View file

@ -93,9 +93,9 @@ func TestInsertOnConflict(t *testing.T) {
ON_CONFLICT().DO_NOTHING()
testutils.AssertStatementSql(t, stmt, `
INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id)
VALUES ($1, $2, $3, $4, $5),
($6, $7, $8, $9, $10)
INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual)
VALUES ($1, $2, $3, $4, $5, $6),
($7, $8, $9, $10, $11, $12)
ON CONFLICT DO NOTHING;
`)
testutils.AssertExecAndRollback(t, stmt, db, 1)
@ -111,9 +111,9 @@ ON CONFLICT DO NOTHING;
ON_CONFLICT(Employee.EmployeeID).DO_NOTHING()
testutils.AssertStatementSql(t, stmt, `
INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id)
VALUES ($1, $2, $3, $4, $5),
($6, $7, $8, $9, $10)
INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual)
VALUES ($1, $2, $3, $4, $5, $6),
($7, $8, $9, $10, $11, $12)
ON CONFLICT (employee_id) DO NOTHING;
`)
testutils.AssertExecAndRollback(t, stmt, db, 1)
@ -130,9 +130,9 @@ ON CONFLICT (employee_id) DO NOTHING;
ON_CONFLICT().ON_CONSTRAINT("employee_pkey").DO_NOTHING()
testutils.AssertStatementSql(t, stmt, `
INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id)
VALUES ($1, $2, $3, $4, $5),
($6, $7, $8, $9, $10)
INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual)
VALUES ($1, $2, $3, $4, $5, $6),
($7, $8, $9, $10, $11, $12)
ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING;
`)
testutils.AssertExecAndRollback(t, stmt, db, 1)
@ -234,8 +234,8 @@ ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE
ON_CONFLICT().DO_UPDATE(nil)
testutils.AssertStatementSql(t, stmt, `
INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id)
VALUES ($1, $2, $3, $4, $5);
INSERT INTO test_sample.employee (employee_id, first_name, last_name, employment_date, manager_id, pto_accrual)
VALUES ($1, $2, $3, $4, $5, $6);
`)
testutils.AssertExecAndRollback(t, stmt, db, 1)
requireLogged(t, stmt)

View file

@ -5,13 +5,10 @@ import (
"database/sql"
"fmt"
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
"math/rand"
"github.com/jackc/pgx/v4/stdlib"
"os"
"runtime"
"testing"
"time"
"github.com/jackc/pgx/v4/stdlib"
"github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/dbconfig"
@ -44,7 +41,6 @@ func skipForCockroachDB(t *testing.T) {
}
func TestMain(m *testing.M) {
rand.Seed(time.Now().Unix())
defer profile.Start().Stop()
setTestRoot()

View file

@ -11,7 +11,8 @@ import (
func TestNorthwindJoinEverything(t *testing.T) {
stmt := SELECT(
stmt :=
SELECT(
Customers.AllColumns,
CustomerDemographics.AllColumns,
Orders.AllColumns,
@ -34,11 +35,7 @@ func TestNorthwindJoinEverything(t *testing.T) {
LEFT_JOIN(EmployeeTerritories, EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID)).
LEFT_JOIN(Territories, EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)).
LEFT_JOIN(Region, Territories.RegionID.EQ(Region.RegionID)),
).ORDER_BY(
Customers.CustomerID,
Orders.OrderID,
Products.ProductID,
)
).ORDER_BY(Customers.CustomerID, Orders.OrderID, Products.ProductID)
var dest []struct {
model.Customers

View file

@ -2,6 +2,7 @@ package postgres
import (
"github.com/go-jet/jet/v2/qrm"
"github.com/go-jet/jet/v2/internal/utils/ptr"
"github.com/google/uuid"
"testing"
@ -63,15 +64,15 @@ func TestExactDecimals(t *testing.T) {
Floats: model.Floats{
// overwritten by wrapped(floats) scope
Numeric: 0.1,
NumericPtr: testutils.Float64Ptr(0.1),
NumericPtr: ptr.Of(0.1),
Decimal: 0.1,
DecimalPtr: testutils.Float64Ptr(0.1),
DecimalPtr: ptr.Of(0.1),
// not overwritten
Real: 0.4,
RealPtr: testutils.Float32Ptr(0.44),
RealPtr: ptr.Of(float32(0.44)),
Double: 0.3,
DoublePtr: testutils.Float64Ptr(0.33),
DoublePtr: ptr.Of(0.33),
},
Numeric: decimal.RequireFromString("0.1234567890123456789"),
NumericPtr: decimal.RequireFromString("1.1111111111111111111"),
@ -330,11 +331,13 @@ SELECT employee.employee_id AS "employee.employee_id",
employee.last_name AS "employee.last_name",
employee.employment_date AS "employee.employment_date",
employee.manager_id AS "employee.manager_id",
employee.pto_accrual AS "employee.pto_accrual",
manager.employee_id AS "manager.employee_id",
manager.first_name AS "manager.first_name",
manager.last_name AS "manager.last_name",
manager.employment_date AS "manager.employment_date",
manager.manager_id AS "manager.manager_id"
manager.manager_id AS "manager.manager_id",
manager.pto_accrual AS "manager.pto_accrual"
FROM test_sample.employee
LEFT JOIN test_sample.employee AS manager ON (manager.employee_id = employee.manager_id)
ORDER BY employee.employee_id;
@ -369,6 +372,7 @@ ORDER BY employee.employee_id;
LastName: "Hays",
EmploymentDate: testutils.TimestampWithTimeZone("1999-01-08 04:05:06.1 +0100 CET", 1),
ManagerID: nil,
PtoAccrual: ptr.Of("22:00:00"),
})
require.True(t, dest[0].Manager == nil)
@ -378,7 +382,7 @@ ORDER BY employee.employee_id;
FirstName: "Salley",
LastName: "Lester",
EmploymentDate: testutils.TimestampWithTimeZone("1999-01-08 04:05:06 +0100 CET", 1),
ManagerID: testutils.Int32Ptr(3),
ManagerID: ptr.Of(int32(3)),
})
}
@ -420,7 +424,7 @@ FROM test_sample."WEIRD NAMES TABLE";
WeirdColumnName5: "Doe",
WeirdColumnName6: "Doe",
WeirdColumnName7: "Doe",
Weirdcolumnname8: testutils.StringPtr("Doe"),
Weirdcolumnname8: ptr.Of("Doe"),
WeirdColName9: "Doe",
WeirdColuName10: "Doe",
WeirdColuName11: "Doe",
@ -518,7 +522,7 @@ func TestMutableColumnsExcludeGeneratedColumn(t *testing.T) {
).MODEL(
model.People{
PeopleName: "Dario",
PeopleHeightCm: testutils.Float64Ptr(120),
PeopleHeightCm: ptr.Of(120.0),
},
).RETURNING(
People.MutableColumns,

View file

@ -2,6 +2,7 @@ package postgres
import (
"context"
"github.com/go-jet/jet/v2/internal/utils/ptr"
"github.com/volatiletech/null/v8"
"testing"
"time"
@ -93,7 +94,7 @@ func TestScanToValidDestination(t *testing.T) {
err := oneInventoryQuery.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest[0], testutils.Int32Ptr(1))
require.Equal(t, dest[0], ptr.Of(int32(1)))
})
t.Run("NULL to integer", func(t *testing.T) {
@ -530,10 +531,10 @@ func TestScanToSlice(t *testing.T) {
require.NoError(t, err)
require.Equal(t, len(dest), 2)
testutils.AssertDeepEqual(t, dest[0].Film, film1)
testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{testutils.Int32Ptr(1), testutils.Int32Ptr(2), testutils.Int32Ptr(3), testutils.Int32Ptr(4),
testutils.Int32Ptr(5), testutils.Int32Ptr(6), testutils.Int32Ptr(7), testutils.Int32Ptr(8)})
testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{ptr.Of(int32(1)), ptr.Of(int32(2)), ptr.Of(int32(3)), ptr.Of(int32(4)),
ptr.Of(int32(5)), ptr.Of(int32(6)), ptr.Of(int32(7)), ptr.Of(int32(8))})
testutils.AssertDeepEqual(t, dest[1].Film, film2)
testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{testutils.Int32Ptr(9), testutils.Int32Ptr(10)})
testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{ptr.Of(int32(9)), ptr.Of(int32(10))})
})
t.Run("complex struct 1", func(t *testing.T) {
@ -1076,10 +1077,10 @@ VALUES (1234, 0, 'Joe', '', NULL, 1, TRUE, '2020-02-02 10:00:00Z', NULL, 1);
var address256 = model.Address{
AddressID: 256,
Address: "1497 Yuzhou Drive",
Address2: testutils.StringPtr(""),
Address2: ptr.Of(""),
District: "England",
CityID: 312,
PostalCode: testutils.StringPtr("3433"),
PostalCode: ptr.Of("3433"),
Phone: "246810237916",
LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 09:45:30", 0),
}
@ -1087,10 +1088,10 @@ var address256 = model.Address{
var addres517 = model.Address{
AddressID: 517,
Address: "548 Uruapan Street",
Address2: testutils.StringPtr(""),
Address2: ptr.Of(""),
District: "Ontario",
CityID: 312,
PostalCode: testutils.StringPtr("35653"),
PostalCode: ptr.Of("35653"),
Phone: "879347453467",
LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 09:45:30", 0),
}
@ -1100,12 +1101,12 @@ var customer256 = model.Customer{
StoreID: 2,
FirstName: "Mattie",
LastName: "Hoffman",
Email: testutils.StringPtr("mattie.hoffman@sakilacustomer.org"),
Email: ptr.Of("mattie.hoffman@sakilacustomer.org"),
AddressID: 256,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 0),
Active: testutils.Int32Ptr(1),
Active: ptr.Of(int32(1)),
}
var customer512 = model.Customer{
@ -1113,12 +1114,12 @@ var customer512 = model.Customer{
StoreID: 1,
FirstName: "Cecil",
LastName: "Vines",
Email: testutils.StringPtr("cecil.vines@sakilacustomer.org"),
Email: ptr.Of("cecil.vines@sakilacustomer.org"),
AddressID: 517,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 0),
Active: testutils.Int32Ptr(1),
Active: ptr.Of(int32(1)),
}
var countryUk = model.Country{
@ -1151,32 +1152,32 @@ var inventory2 = model.Inventory{
var film1 = model.Film{
FilmID: 1,
Title: "Academy Dinosaur",
Description: testutils.StringPtr("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"),
ReleaseYear: testutils.Int32Ptr(2006),
Description: ptr.Of("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"),
ReleaseYear: ptr.Of(int32(2006)),
LanguageID: 1,
RentalDuration: 6,
RentalRate: 0.99,
Length: testutils.Int16Ptr(86),
Length: ptr.Of(int16(86)),
ReplacementCost: 20.99,
Rating: &pgRating,
LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3),
SpecialFeatures: testutils.StringPtr("{\"Deleted Scenes\",\"Behind the Scenes\"}"),
SpecialFeatures: ptr.Of("{\"Deleted Scenes\",\"Behind the Scenes\"}"),
Fulltext: "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17",
}
var film2 = model.Film{
FilmID: 2,
Title: "Ace Goldfinger",
Description: testutils.StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"),
ReleaseYear: testutils.Int32Ptr(2006),
Description: ptr.Of("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"),
ReleaseYear: ptr.Of(int32(2006)),
LanguageID: 1,
RentalDuration: 3,
RentalRate: 4.99,
Length: testutils.Int16Ptr(48),
Length: ptr.Of(int16(48)),
ReplacementCost: 12.99,
Rating: &gRating,
LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3),
SpecialFeatures: testutils.StringPtr(`{Trailers,"Deleted Scenes"}`),
SpecialFeatures: ptr.Of(`{Trailers,"Deleted Scenes"}`),
Fulltext: `'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14`,
}

View file

@ -3,6 +3,7 @@ package postgres
import (
"context"
"database/sql"
"github.com/go-jet/jet/v2/internal/utils/ptr"
"testing"
"time"
@ -1828,16 +1829,16 @@ ORDER BY film.film_id ASC;
testutils.AssertDeepEqual(t, maxRentalRateFilms[0], model.Film{
FilmID: 2,
Title: "Ace Goldfinger",
Description: testutils.StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"),
ReleaseYear: testutils.Int32Ptr(2006),
Description: ptr.Of("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"),
ReleaseYear: ptr.Of(int32(2006)),
LanguageID: 1,
RentalRate: 4.99,
Length: testutils.Int16Ptr(48),
Length: ptr.Of(int16(48)),
ReplacementCost: 12.99,
Rating: &gRating,
RentalDuration: 3,
LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3),
SpecialFeatures: testutils.StringPtr("{Trailers,\"Deleted Scenes\"}"),
SpecialFeatures: ptr.Of("{Trailers,\"Deleted Scenes\"}"),
Fulltext: "'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14",
})
}
@ -2286,11 +2287,11 @@ ORDER BY customer_payment_sum.amount_sum ASC;
FirstName: "Brian",
LastName: "Wyman",
AddressID: 323,
Email: testutils.StringPtr("brian.wyman@sakilacustomer.org"),
Email: ptr.Of("brian.wyman@sakilacustomer.org"),
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
Active: testutils.Int32Ptr(1),
Active: ptr.Of(int32(1)),
})
require.Equal(t, customersWithAmounts[0].AmountSum, 27.93)
@ -3110,8 +3111,8 @@ func TestSelectDynamicCondition(t *testing.T) {
Active *bool
}
request.CustomerID = testutils.Int64Ptr(1)
request.Active = testutils.BoolPtr(true)
request.CustomerID = ptr.Of(int64(1))
request.Active = ptr.Of(true)
// ...
@ -3871,12 +3872,12 @@ var customer0 = model.Customer{
StoreID: 1,
FirstName: "Mary",
LastName: "Smith",
Email: testutils.StringPtr("mary.smith@sakilacustomer.org"),
Email: ptr.Of("mary.smith@sakilacustomer.org"),
AddressID: 5,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
Active: testutils.Int32Ptr(1),
Active: ptr.Of(int32(1)),
}
var customer1 = model.Customer{
@ -3884,12 +3885,12 @@ var customer1 = model.Customer{
StoreID: 1,
FirstName: "Patricia",
LastName: "Johnson",
Email: testutils.StringPtr("patricia.johnson@sakilacustomer.org"),
Email: ptr.Of("patricia.johnson@sakilacustomer.org"),
AddressID: 6,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
Active: testutils.Int32Ptr(1),
Active: ptr.Of(int32(1)),
}
var lastCustomer = model.Customer{
@ -3897,10 +3898,10 @@ var lastCustomer = model.Customer{
StoreID: 2,
FirstName: "Austin",
LastName: "Cintron",
Email: testutils.StringPtr("austin.cintron@sakilacustomer.org"),
Email: ptr.Of("austin.cintron@sakilacustomer.org"),
AddressID: 605,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
Active: testutils.Int32Ptr(1),
Active: ptr.Of(int32(1)),
}

View file

@ -0,0 +1,284 @@
package postgres
import (
"database/sql"
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"time"
)
func TestVALUES(t *testing.T) {
values := VALUES(
WRAP(Int32(1), Int32(2), Float32(4.666), Bool(false), String("txt")),
WRAP(Int32(11).ADD(Int32(2)), Int32(22), Float32(33.222), Bool(true), String("png")),
WRAP(Int32(11), Int32(22), Float32(33.222), Bool(true), NULL),
).AS("values_table")
stmt := SELECT(
values.AllColumns(),
).FROM(
values,
)
testutils.AssertStatementSql(t, stmt, `
SELECT values_table.column1 AS "column1",
values_table.column2 AS "column2",
values_table.column3 AS "column3",
values_table.column4 AS "column4",
values_table.column5 AS "column5"
FROM (
VALUES ($1::integer, $2::integer, $3::real, $4::boolean, $5::text),
($6::integer + $7::integer, $8::integer, $9::real, $10::boolean, $11::text),
($12::integer, $13::integer, $14::real, $15::boolean, NULL)
) AS values_table;
`)
var dest []struct {
Column1 int
Column2 int
Column3 float32
Column4 bool
Column5 *string
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
[
{
"Column1": 1,
"Column2": 2,
"Column3": 4.666,
"Column4": false,
"Column5": "txt"
},
{
"Column1": 13,
"Column2": 22,
"Column3": 33.222,
"Column4": true,
"Column5": "png"
},
{
"Column1": 11,
"Column2": 22,
"Column3": 33.222,
"Column4": true,
"Column5": null
}
]
`)
}
func TestVALUES_Join(t *testing.T) {
title := StringColumn("title")
releaseYear := IntegerColumn("ReleaseYear")
rentalRate := FloatColumn("rental_rate")
lastUpdate := Timestamp(2007, time.February, 11, 12, 0, 0)
filmValues := VALUES(
WRAP(String("Chamber Italian"), Int64(117), Int32(2005), Float32(5.82), lastUpdate),
WRAP(String("Grosse Wonderful"), Int64(49), Int32(2004), Float32(6.242), lastUpdate.ADD(INTERVAL(1, HOUR))),
WRAP(String("Airport Pollock"), Int64(54), Int32(2001), Float32(7.22), NULL),
WRAP(String("Bright Encounters"), Int64(73), Int32(2002), Float32(8.25), NULL),
WRAP(String("Academy Dinosaur"), Int64(83), Int32(2010), Float32(9.22), lastUpdate.SUB(INTERVAL(2, MINUTE))),
).AS("film_values",
title, IntegerColumn("length"), releaseYear, rentalRate, TimestampColumn("update_date"))
stmt := SELECT(
Film.AllColumns,
filmValues.AllColumns(),
).FROM(
Film.
INNER_JOIN(filmValues, title.EQ(Film.Title)),
).WHERE(AND(
Film.ReleaseYear.GT(releaseYear),
Film.RentalRate.LT(rentalRate),
)).ORDER_BY(
title,
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT film.film_id AS "film.film_id",
film.title AS "film.title",
film.description AS "film.description",
film.release_year AS "film.release_year",
film.language_id AS "film.language_id",
film.rental_duration AS "film.rental_duration",
film.rental_rate AS "film.rental_rate",
film.length AS "film.length",
film.replacement_cost AS "film.replacement_cost",
film.rating AS "film.rating",
film.last_update AS "film.last_update",
film.special_features AS "film.special_features",
film.fulltext AS "film.fulltext",
film_values.title AS "title",
film_values.length AS "length",
film_values."ReleaseYear" AS "ReleaseYear",
film_values.rental_rate AS "rental_rate",
film_values.update_date AS "update_date"
FROM dvds.film
INNER JOIN (
VALUES ('Chamber Italian'::text, 117::bigint, 2005::integer, 5.820000171661377::real, '2007-02-11 12:00:00'::timestamp without time zone),
('Grosse Wonderful'::text, 49::bigint, 2004::integer, 6.242000102996826::real, '2007-02-11 12:00:00'::timestamp without time zone + INTERVAL '1 HOUR'),
('Airport Pollock'::text, 54::bigint, 2001::integer, 7.21999979019165::real, NULL),
('Bright Encounters'::text, 73::bigint, 2002::integer, 8.25::real, NULL),
('Academy Dinosaur'::text, 83::bigint, 2010::integer, 9.220000267028809::real, '2007-02-11 12:00:00'::timestamp without time zone - INTERVAL '2 MINUTE')
) AS film_values (title, length, "ReleaseYear", rental_rate, update_date) ON (film_values.title = film.title)
WHERE (
(film.release_year > film_values."ReleaseYear")
AND (film.rental_rate < film_values.rental_rate)
)
ORDER BY film_values.title;
`)
//fmt.Println(stmt.DebugSql())
var dest []struct {
Film model.Film
Title string
Length int
ReleaseYear int
RentalRate float32
UpdateDate *time.Time
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
assert.Len(t, dest, 4)
testutils.AssertJSON(t, dest[0:2], `
[
{
"Film": {
"FilmID": 8,
"Title": "Airport Pollock",
"Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India",
"ReleaseYear": 2006,
"LanguageID": 1,
"RentalDuration": 6,
"RentalRate": 4.99,
"Length": 54,
"ReplacementCost": 15.99,
"Rating": "R",
"LastUpdate": "2013-05-26T14:50:58.951Z",
"SpecialFeatures": "{Trailers}",
"Fulltext": "'airport':1 'ancient':18 'confront':14 'epic':4 'girl':11 'india':19 'monkey':16 'moos':8 'must':13 'pollock':2 'tale':5"
},
"Title": "Airport Pollock",
"Length": 54,
"ReleaseYear": 2001,
"RentalRate": 7.22,
"UpdateDate": null
},
{
"Film": {
"FilmID": 98,
"Title": "Bright Encounters",
"Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat",
"ReleaseYear": 2006,
"LanguageID": 1,
"RentalDuration": 4,
"RentalRate": 4.99,
"Length": 73,
"ReplacementCost": 12.99,
"Rating": "PG-13",
"LastUpdate": "2013-05-26T14:50:58.951Z",
"SpecialFeatures": "{Trailers}",
"Fulltext": "'boat':20 'bright':1 'conquer':14 'encount':2 'fate':4 'feminist':11 'jet':19 'lumberjack':8 'must':13 'student':16 'yarn':5"
},
"Title": "Bright Encounters",
"Length": 73,
"ReleaseYear": 2002,
"RentalRate": 8.25,
"UpdateDate": null
}
]
`)
}
func TestVALUES_CTE_Update(t *testing.T) {
paymentID := IntegerColumn("payment_ID")
increase := FloatColumn("increase")
paymentsToUpdate := CTE("values_cte", paymentID, increase)
stmt := WITH(
paymentsToUpdate.AS(
VALUES(
WRAP(Int32(20564), Float32(1.21)),
WRAP(Int32(20567), Float32(1.02)),
WRAP(Int32(20570), Float32(1.34)),
WRAP(Int32(20573), Float32(1.72)),
),
),
)(
Payment.UPDATE().
SET(
Payment.Amount.SET(Payment.Amount.MUL(CAST(increase).AS_DECIMAL())),
).
FROM(paymentsToUpdate).
WHERE(Payment.PaymentID.EQ(paymentID)).
RETURNING(Payment.AllColumns),
)
testutils.AssertDebugStatementSql(t, stmt, `
WITH values_cte ("payment_ID", increase) AS (
VALUES (20564::integer, 1.2100000381469727::real),
(20567::integer, 1.0199999809265137::real),
(20570::integer, 1.340000033378601::real),
(20573::integer, 1.7200000286102295::real)
)
UPDATE dvds.payment
SET amount = (payment.amount * values_cte.increase::decimal)
FROM values_cte
WHERE payment.payment_id = values_cte."payment_ID"
RETURNING payment.payment_id AS "payment.payment_id",
payment.customer_id AS "payment.customer_id",
payment.staff_id AS "payment.staff_id",
payment.rental_id AS "payment.rental_id",
payment.amount AS "payment.amount",
payment.payment_date AS "payment.payment_date";
`)
testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
var payments []model.Payment
err := stmt.Query(tx, &payments)
require.NoError(t, err)
assert.Len(t, payments, 4)
testutils.AssertJSON(t, payments[0:2], `
[
{
"PaymentID": 20564,
"CustomerID": 379,
"StaffID": 2,
"RentalID": 11457,
"Amount": 4.83,
"PaymentDate": "2007-03-02T19:42:42.996577Z"
},
{
"PaymentID": 20567,
"CustomerID": 379,
"StaffID": 2,
"RentalID": 13397,
"Amount": 8.15,
"PaymentDate": "2007-03-19T20:35:01.996577Z"
}
]
`)
})
}

View file

@ -83,10 +83,10 @@ SELECT orders.ship_region AS "orders.ship_region",
SUM(order_details.quantity) AS "product_sales"
FROM northwind.orders
INNER JOIN northwind.order_details ON (orders.order_id = order_details.order_id)
WHERE orders.ship_region IN (
WHERE orders.ship_region IN ((
SELECT top_region."orders.ship_region" AS "orders.ship_region"
FROM top_region
)
))
GROUP BY orders.ship_region, order_details.product_id
ORDER BY SUM(order_details.quantity) DESC;
`)
@ -157,19 +157,19 @@ func TestWithStatementDeleteAndInsert(t *testing.T) {
testutils.AssertStatementSql(t, stmt, `
WITH remove_discontinued_orders AS (
DELETE FROM northwind.order_details
WHERE order_details.product_id IN (
WHERE order_details.product_id IN ((
SELECT products.product_id AS "products.product_id"
FROM northwind.products
WHERE products.discontinued = $1
)
))
RETURNING order_details.product_id AS "order_details.product_id"
),update_discontinued_price AS (
UPDATE northwind.products
SET unit_price = $2
WHERE products.product_id IN (
WHERE products.product_id IN ((
SELECT remove_discontinued_orders."order_details.product_id" AS "order_details.product_id"
FROM remove_discontinued_orders
)
))
RETURNING products.product_id AS "products.product_id",
products.product_name AS "products.product_name",
products.supplier_id AS "products.supplier_id",

View file

@ -2,6 +2,7 @@ package sqlite
import (
"github.com/go-jet/jet/v2/internal/testutils"
"github.com/go-jet/jet/v2/internal/utils/ptr"
. "github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table"
@ -153,43 +154,43 @@ func TestAllTypesInsert(t *testing.T) {
var toInsert = model.AllTypes{
Boolean: false,
BooleanPtr: testutils.BoolPtr(true),
BooleanPtr: ptr.Of(true),
TinyInt: 1,
SmallInt: 3,
MediumInt: 5,
Integer: 7,
BigInt: 9,
TinyIntPtr: testutils.Int8Ptr(11),
SmallIntPtr: testutils.Int16Ptr(33),
MediumIntPtr: testutils.Int32Ptr(55),
IntegerPtr: testutils.Int32Ptr(77),
BigIntPtr: testutils.Int64Ptr(99),
TinyIntPtr: ptr.Of(int8(11)),
SmallIntPtr: ptr.Of(int16(33)),
MediumIntPtr: ptr.Of(int32(55)),
IntegerPtr: ptr.Of(int32(77)),
BigIntPtr: ptr.Of(int64(99)),
Decimal: 11.22,
DecimalPtr: testutils.Float64Ptr(33.44),
DecimalPtr: ptr.Of(33.44),
Numeric: 55.66,
NumericPtr: testutils.Float64Ptr(77.88),
NumericPtr: ptr.Of(77.88),
Float: 99.00,
FloatPtr: testutils.Float64Ptr(11.22),
FloatPtr: ptr.Of(11.22),
Double: 33.44,
DoublePtr: testutils.Float64Ptr(55.66),
DoublePtr: ptr.Of(55.66),
Real: 77.88,
RealPtr: testutils.Float32Ptr(99.00),
RealPtr: ptr.Of(float32(99.00)),
Time: time.Date(1, 1, 1, 1, 1, 1, 10, time.UTC),
TimePtr: testutils.TimePtr(time.Date(2, 2, 2, 2, 2, 2, 200, time.UTC)),
TimePtr: ptr.Of(time.Date(2, 2, 2, 2, 2, 2, 200, time.UTC)),
Date: time.Now(),
DatePtr: testutils.TimePtr(time.Now()),
DatePtr: ptr.Of(time.Now()),
DateTime: time.Now(),
DateTimePtr: testutils.TimePtr(time.Now()),
DateTimePtr: ptr.Of(time.Now()),
Timestamp: time.Now(),
TimestampPtr: testutils.TimePtr(time.Now()),
TimestampPtr: ptr.Of(time.Now()),
Char: "abcd",
CharPtr: testutils.StringPtr("absd"),
CharPtr: ptr.Of("absd"),
VarChar: "abcd",
VarCharPtr: testutils.StringPtr("absd"),
VarCharPtr: ptr.Of("absd"),
Blob: []byte("large file"),
BlobPtr: testutils.ByteArrayPtr([]byte("very large file")),
BlobPtr: ptr.Of([]byte("very large file")),
Text: "some text",
TextPtr: testutils.StringPtr("text"),
TextPtr: ptr.Of("text"),
}
func TestUUID(t *testing.T) {
@ -233,18 +234,18 @@ func TestExpressionOperators(t *testing.T) {
SELECT all_types.integer IS NULL AS "result.is_null",
all_types.date_ptr IS NOT NULL AS "result.is_not_null",
(all_types.small_int_ptr IN (?, ?)) AS "result.in",
(all_types.small_int_ptr IN (
(all_types.small_int_ptr IN ((
SELECT all_types.integer AS "all_types.integer"
FROM all_types
)) AS "result.in_select",
))) AS "result.in_select",
(length(121232459)) AS "result.raw",
(? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg",
(? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2",
(all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in",
(all_types.small_int_ptr NOT IN (
(all_types.small_int_ptr NOT IN ((
SELECT all_types.integer AS "all_types.integer"
FROM all_types
)) AS "result.not_in_select"
))) AS "result.not_in_select"
FROM all_types
LIMIT ?;
`, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2))
@ -659,7 +660,7 @@ func TestExactDecimals(t *testing.T) {
// not overwritten
Numeric: "6.7",
NumericPtr: testutils.StringPtr("7.7"),
NumericPtr: ptr.Of("7.7"),
},
Decimal: decimal.RequireFromString("91.23"),
DecimalPtr: decimal.RequireFromString("45.67"),
@ -899,3 +900,36 @@ func TestDateTimeExpressions(t *testing.T) {
require.Equal(t, dest.JulianDay, 2.4551543576232754e+06)
require.Equal(t, dest.StrfTime, "20:34")
}
func TestRowExpression(t *testing.T) {
date := Date(2000, 9, 9)
time := Time(11, 22, 11)
dateTime := DateTime(2008, 11, 22, 10, 12, 40)
dateTime2 := DateTime(2011, 1, 2, 5, 12, 40)
stmt := SELECT(
ROW(Bool(false), date).EQ(ROW(Bool(true), date)),
ROW(Bool(false), time).NOT_EQ(ROW(Bool(true), time)),
ROW(time).IS_DISTINCT_FROM(RowExp(Raw("(time('now'))"))),
ROW(dateTime, dateTime2).GT(ROW(dateTime, dateTime2)),
ROW(dateTime2).GT_EQ(ROW(dateTime)),
ROW(dateTime, dateTime2).LT(ROW(dateTime, dateTime2)),
ROW(dateTime2).LT_EQ(ROW(dateTime2)),
)
//fmt.Println(stmt.Sql())
//fmt.Println(stmt.DebugSql())
testutils.AssertDebugStatementSql(t, stmt, `
SELECT (FALSE, DATE('2000-09-09')) = (TRUE, DATE('2000-09-09')),
(FALSE, TIME('11:22:11')) != (TRUE, TIME('11:22:11')),
(TIME('11:22:11')) IS NOT ((time('now'))),
(DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')) > (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')),
(DATETIME('2011-01-02 05:12:40')) >= (DATETIME('2008-11-22 10:12:40')),
(DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')) < (DATETIME('2008-11-22 10:12:40'), DATETIME('2011-01-02 05:12:40')),
(DATETIME('2011-01-02 05:12:40')) <= (DATETIME('2011-01-02 05:12:40'));
`)
err := stmt.Query(db, &struct{}{})
require.NoError(t, err)
}

View file

@ -8,8 +8,11 @@ import (
"github.com/stretchr/testify/require"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/sqlite"
"github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/testutils"
sqlite2 "github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model"
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
)
@ -58,6 +61,36 @@ func TestGenerator(t *testing.T) {
require.NoError(t, err)
}
func TestGenerator_TableMetadata(t *testing.T) {
var schema metadata.Schema
err := sqlite.GenerateDSN(testDatabaseFilePath, genDestDir,
template.Default(sqlite2.Dialect).UseSchema(func(m metadata.Schema) template.Schema {
schema = m
return template.DefaultSchema(m)
}))
require.NoError(t, err)
// Spot check the actor table and assert that the emitted
// properties are as expected.
var got metadata.Table
for _, table := range schema.TablesMetaData {
if table.Name == "actor" {
got = table
}
}
want := metadata.Table{
Name: "actor",
Columns: []metadata.Column{
{Name: "actor_id", IsPrimaryKey: true, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "INTEGER", Kind: "base", IsUnsigned: false}, Comment: ""},
{Name: "first_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "VARCHAR", Kind: "base", IsUnsigned: false}, Comment: ""},
{Name: "last_name", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: false, DataType: metadata.DataType{Name: "VARCHAR", Kind: "base", IsUnsigned: false}, Comment: ""},
{Name: "last_update", IsPrimaryKey: false, IsNullable: false, IsGenerated: false, HasDefault: true, DataType: metadata.DataType{Name: "TIMESTAMP", Kind: "base", IsUnsigned: false}, Comment: ""},
},
}
require.Equal(t, want, got)
}
func TestCmdGenerator(t *testing.T) {
cmd := exec.Command("jet", "-source=SQLite", "-dsn=file://"+testDatabaseFilePath, "-path="+genDestDir)

View file

@ -3,6 +3,8 @@ package sqlite
import (
"context"
"github.com/go-jet/jet/v2/qrm"
"database/sql"
"github.com/go-jet/jet/v2/internal/utils/ptr"
"math/rand"
"testing"
@ -49,7 +51,7 @@ VALUES (?, ?, ?, ?),
ID: 101,
URL: "http://www.google.com",
Name: "Google",
Description: testutils.StringPtr("Search engine"),
Description: ptr.Of("Search engine"),
})
testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{
ID: 102,

View file

@ -8,16 +8,11 @@ import (
"github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/dbconfig"
"github.com/stretchr/testify/require"
"math/rand"
"os"
"os/exec"
"runtime"
"strings"
"testing"
"time"
"github.com/pkg/profile"
"github.com/stretchr/testify/require"
"os"
"runtime"
"testing"
_ "github.com/mattn/go-sqlite3"
)
@ -27,11 +22,8 @@ var sampleDB *sqlite.DB
var testRoot string
func TestMain(m *testing.M) {
rand.Seed(time.Now().Unix())
defer profile.Start().Stop()
setTestRoot()
sqlDB, err := sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath)
throw.OnError(err)
db = sqlite.NewDB(sqlDB).WithStatementsCaching(true)
@ -59,16 +51,6 @@ func TestMain(m *testing.M) {
}
}
func setTestRoot() {
cmd := exec.Command("git", "rev-parse", "--show-toplevel")
byteArr, err := cmd.Output()
if err != nil {
panic(err)
}
testRoot = strings.TrimSpace(string(byteArr)) + "/tests/"
}
var loggedSQL string
var loggedSQLArgs []interface{}
var loggedDebugSQL string

View file

@ -3,6 +3,7 @@ package sqlite
import (
"github.com/go-jet/jet/v2/internal/testutils"
"github.com/go-jet/jet/v2/qrm"
"github.com/go-jet/jet/v2/internal/utils/ptr"
"github.com/stretchr/testify/require"
"testing"
@ -54,7 +55,7 @@ WHERE people.people_id = ?;
).MODEL(
model.People{
PeopleName: "Dario",
PeopleHeightCm: testutils.Float64Ptr(190),
PeopleHeightCm: ptr.Of(190.0),
},
).RETURNING(
People.AllColumns,

View file

@ -2,6 +2,7 @@ package sqlite
import (
"context"
"github.com/go-jet/jet/v2/internal/utils/ptr"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/model"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/table"
"strings"
@ -846,15 +847,15 @@ func TestSimpleView(t *testing.T) {
require.Equal(t, len(dest), 10)
require.Equal(t, dest[2], model.CustomerList{
ID: testutils.Int32Ptr(3),
Name: testutils.StringPtr("LINDA WILLIAMS"),
Address: testutils.StringPtr("692 Joliet Street"),
ZipCode: testutils.StringPtr("83579"),
Phone: testutils.StringPtr(" "),
City: testutils.StringPtr("Athenai"),
Country: testutils.StringPtr("Greece"),
Notes: testutils.StringPtr("active"),
Sid: testutils.Int32Ptr(1),
ID: ptr.Of(int32(3)),
Name: ptr.Of("LINDA WILLIAMS"),
Address: ptr.Of("692 Joliet Street"),
ZipCode: ptr.Of("83579"),
Phone: ptr.Of(" "),
City: ptr.Of("Athenai"),
Country: ptr.Of("Greece"),
Notes: ptr.Of("active"),
Sid: ptr.Of(int32(1)),
})
}

344
tests/sqlite/values_test.go Normal file
View file

@ -0,0 +1,344 @@
package sqlite
import (
"database/sql"
"github.com/go-jet/jet/v2/internal/testutils"
"github.com/stretchr/testify/require"
"strings"
"testing"
"time"
. "github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table"
)
func TestVALUES(t *testing.T) {
values := VALUES(
ROW(Int32(1), Int32(2), Float(4.666), Bool(false), String("txt")),
ROW(Int32(11).ADD(Int32(2)), Int32(22), Float(33.222), Bool(true), String("png")),
ROW(Int32(11), Int32(22), Float(33.222), Bool(true), NULL),
).AS("values_table")
stmt := SELECT(
values.AllColumns(),
).FROM(
values,
)
testutils.AssertStatementSql(t, stmt, `
SELECT values_table.column1 AS "column1",
values_table.column2 AS "column2",
values_table.column3 AS "column3",
values_table.column4 AS "column4",
values_table.column5 AS "column5"
FROM (
VALUES (?, ?, ?, ?, ?),
(? + ?, ?, ?, ?, ?),
(?, ?, ?, ?, NULL)
) AS values_table;
`)
var dest []struct {
Column1 int
Column2 int
Column3 float32
Column4 bool
Column5 *string
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
[
{
"Column1": 1,
"Column2": 2,
"Column3": 4.666,
"Column4": false,
"Column5": "txt"
},
{
"Column1": 13,
"Column2": 22,
"Column3": 33.222,
"Column4": true,
"Column5": "png"
},
{
"Column1": 11,
"Column2": 22,
"Column3": 33.222,
"Column4": true,
"Column5": null
}
]
`)
}
func TestVALUES_Join(t *testing.T) {
lastUpdate := DateTime(2007, time.February, 11, 12, 0, 0)
films := VALUES(
ROW(String("Chamber Italian"), Int64(117), Int32(2005), Float(5.82), lastUpdate),
ROW(String("Grosse Wonderful"), Int64(49), Int32(2004), Float(6.242), lastUpdate),
ROW(String("Airport Pollock"), Int64(54), Int32(2001), Float(7.22), NULL),
ROW(String("Bright Encounters"), Int64(73), Int32(2002), Float(8.25), NULL),
ROW(String("Academy Dinosaur"), Int64(83), Int32(2010), Float(9.22), DATETIME(lastUpdate, YEARS(2))),
).AS("film_values")
title := StringColumn("column1").From(films)
releaseYear := IntegerColumn("column3").From(films)
rentalRate := FloatColumn("column4").From(films)
stmt := SELECT(
Film.AllColumns,
films.AllColumns(),
).FROM(
Film.
INNER_JOIN(films, LOWER(title).EQ(LOWER(Film.Title))),
).WHERE(AND(
CAST(Film.ReleaseYear).AS_INTEGER().GT(releaseYear),
Film.RentalRate.LT(rentalRate),
)).ORDER_BY(
title,
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT film.film_id AS "film.film_id",
film.title AS "film.title",
film.description AS "film.description",
film.release_year AS "film.release_year",
film.language_id AS "film.language_id",
film.original_language_id AS "film.original_language_id",
film.rental_duration AS "film.rental_duration",
film.rental_rate AS "film.rental_rate",
film.length AS "film.length",
film.replacement_cost AS "film.replacement_cost",
film.rating AS "film.rating",
film.special_features AS "film.special_features",
film.last_update AS "film.last_update",
film_values.column1 AS "column1",
film_values.column2 AS "column2",
film_values.column3 AS "column3",
film_values.column4 AS "column4",
film_values.column5 AS "column5"
FROM film
INNER JOIN (
VALUES ('Chamber Italian', 117, 2005, 5.82, DATETIME('2007-02-11 12:00:00')),
('Grosse Wonderful', 49, 2004, 6.242, DATETIME('2007-02-11 12:00:00')),
('Airport Pollock', 54, 2001, 7.22, NULL),
('Bright Encounters', 73, 2002, 8.25, NULL),
('Academy Dinosaur', 83, 2010, 9.22, DATETIME(DATETIME('2007-02-11 12:00:00'), '2 YEARS'))
) AS film_values ON (LOWER(film_values.column1) = LOWER(film.title))
WHERE (
(CAST(film.release_year AS INTEGER) > film_values.column3)
AND (film.rental_rate < film_values.column4)
)
ORDER BY film_values.column1;
`)
var dest []struct {
Film model.Film
Column1 string
Column2 int
Column3 int
Column4 float32
Column5 *time.Time
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
[
{
"Film": {
"FilmID": 8,
"Title": "AIRPORT POLLOCK",
"Description": "A Epic Tale of a Moose And a Girl who must Confront a Monkey in Ancient India",
"ReleaseYear": "2006",
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 6,
"RentalRate": 4.99,
"Length": 54,
"ReplacementCost": 15.99,
"Rating": "R",
"SpecialFeatures": "Trailers",
"LastUpdate": "2019-04-11T18:11:48Z"
},
"Column1": "Airport Pollock",
"Column2": 54,
"Column3": 2001,
"Column4": 7.22,
"Column5": null
},
{
"Film": {
"FilmID": 98,
"Title": "BRIGHT ENCOUNTERS",
"Description": "A Fateful Yarn of a Lumberjack And a Feminist who must Conquer a Student in A Jet Boat",
"ReleaseYear": "2006",
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 4,
"RentalRate": 4.99,
"Length": 73,
"ReplacementCost": 12.99,
"Rating": "PG-13",
"SpecialFeatures": "Trailers",
"LastUpdate": "2019-04-11T18:11:48Z"
},
"Column1": "Bright Encounters",
"Column2": 73,
"Column3": 2002,
"Column4": 8.25,
"Column5": null
},
{
"Film": {
"FilmID": 133,
"Title": "CHAMBER ITALIAN",
"Description": "A Fateful Reflection of a Moose And a Husband who must Overcome a Monkey in Nigeria",
"ReleaseYear": "2006",
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 7,
"RentalRate": 4.99,
"Length": 117,
"ReplacementCost": 14.99,
"Rating": "NC-17",
"SpecialFeatures": "Trailers",
"LastUpdate": "2019-04-11T18:11:48Z"
},
"Column1": "Chamber Italian",
"Column2": 117,
"Column3": 2005,
"Column4": 5.82,
"Column5": "2007-02-11T12:00:00Z"
},
{
"Film": {
"FilmID": 384,
"Title": "GROSSE WONDERFUL",
"Description": "A Epic Drama of a Cat And a Explorer who must Redeem a Moose in Australia",
"ReleaseYear": "2006",
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 5,
"RentalRate": 4.99,
"Length": 49,
"ReplacementCost": 19.99,
"Rating": "R",
"SpecialFeatures": "Behind the Scenes",
"LastUpdate": "2019-04-11T18:11:48Z"
},
"Column1": "Grosse Wonderful",
"Column2": 49,
"Column3": 2004,
"Column4": 6.242,
"Column5": "2007-02-11T12:00:00Z"
}
]
`)
}
func TestVALUES_CTE_Update(t *testing.T) {
paymentID := IntegerColumn("payment_ID")
increase := FloatColumn("increase")
paymentsToUpdate := CTE("values_cte", paymentID, increase)
stmt := WITH(
paymentsToUpdate.AS(
VALUES(
ROW(Int32(204), Float(1.21)),
ROW(Int32(207), Float(1.02)),
ROW(Int32(200), Float(1.34)),
ROW(Int32(203), Float(1.72)),
),
),
)(
Payment.UPDATE().
SET(
Payment.Amount.SET(Payment.Amount.MUL(increase)),
).
FROM(paymentsToUpdate).
WHERE(Payment.PaymentID.EQ(paymentID)).
RETURNING(Payment.AllColumns),
)
testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(`
WITH values_cte (''payment_ID'', increase) AS (
VALUES (?, ?),
(?, ?),
(?, ?),
(?, ?)
)
UPDATE payment
SET amount = (payment.amount * values_cte.increase)
FROM values_cte
WHERE payment.payment_id = values_cte.''payment_ID''
RETURNING payment.payment_id AS "payment.payment_id",
payment.customer_id AS "payment.customer_id",
payment.staff_id AS "payment.staff_id",
payment.rental_id AS "payment.rental_id",
payment.amount AS "payment.amount",
payment.payment_date AS "payment.payment_date",
payment.last_update AS "payment.last_update";
`, "''", "`"))
testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
var payments []model.Payment
err := stmt.Query(tx, &payments)
require.NoError(t, err)
testutils.AssertJSON(t, payments, `
[
{
"PaymentID": 200,
"CustomerID": 7,
"StaffID": 2,
"RentalID": 11542,
"Amount": 10.706600000000002,
"PaymentDate": "2005-08-17T00:51:32Z",
"LastUpdate": "2019-04-11T18:11:50Z"
},
{
"PaymentID": 203,
"CustomerID": 7,
"StaffID": 2,
"RentalID": 13373,
"Amount": 5.1428,
"PaymentDate": "2005-08-19T21:23:31Z",
"LastUpdate": "2019-04-11T18:11:50Z"
},
{
"PaymentID": 204,
"CustomerID": 7,
"StaffID": 1,
"RentalID": 13476,
"Amount": 3.6179,
"PaymentDate": "2005-08-20T01:06:04Z",
"LastUpdate": "2019-04-11T18:11:50Z"
},
{
"PaymentID": 207,
"CustomerID": 8,
"StaffID": 2,
"RentalID": 866,
"Amount": 7.1298,
"PaymentDate": "2005-05-30T03:43:54Z",
"LastUpdate": "2019-04-11T18:11:50Z"
}
]
`)
})
}

View file

@ -153,10 +153,10 @@ WITH payments_to_update AS (
)
UPDATE payment
SET amount = 0
WHERE payment.payment_id IN (
WHERE payment.payment_id IN ((
SELECT payments_to_update.''payment.payment_id'' AS "payment.payment_id"
FROM payments_to_update
);
));
`, "''", "`", -1))
tx := beginDBTx(t)
@ -205,10 +205,10 @@ WITH payments_to_delete AS (
WHERE payment.amount < 0.5
)
DELETE FROM payment
WHERE payment.payment_id IN (
WHERE payment.payment_id IN ((
SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id"
FROM payments_to_delete
);
));
`, "''", "`", -1))
tx := beginDBTx(t)

@ -1 +1 @@
Subproject commit 915bdc16b723d89becc577c780949baef861a6ae
Subproject commit 6a397747d310938b41d3950d68009578180d3dd5