Merge pull request #457 from go-jet/select_json

Add support for SELECT_JSON statements
This commit is contained in:
go-jet 2025-03-09 18:43:49 +01:00 committed by GitHub
commit 1f3215c879
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
104 changed files with 5249 additions and 900 deletions

View file

@ -11,7 +11,7 @@ jobs:
- image: cimg/go:1.22.8 - image: cimg/go:1.22.8
# Please keep the version in sync with test/docker-compose.yaml # Please keep the version in sync with test/docker-compose.yaml
- image: cimg/postgres:14.10 - image: cimg/postgres:14.1
environment: environment:
POSTGRES_USER: jet POSTGRES_USER: jet
POSTGRES_PASSWORD: jet POSTGRES_PASSWORD: jet
@ -19,7 +19,7 @@ jobs:
PGPORT: 50901 PGPORT: 50901
# Please keep the version in sync with test/docker-compose.yaml # Please keep the version in sync with test/docker-compose.yaml
- image: circleci/mysql:8.0.27 - image: cimg/mysql:8.0.27
command: [ --default-authentication-plugin=mysql_native_password ] command: [ --default-authentication-plugin=mysql_native_password ]
environment: environment:
MYSQL_ROOT_PASSWORD: jet MYSQL_ROOT_PASSWORD: jet
@ -29,7 +29,7 @@ jobs:
MYSQL_TCP_PORT: 50902 MYSQL_TCP_PORT: 50902
# Please keep the version in sync with test/docker-compose.yaml # Please keep the version in sync with test/docker-compose.yaml
- image: circleci/mariadb:10.3 - image: cimg/mariadb:11.4
command: [ '--default-authentication-plugin=mysql_native_password', '--port=50903' ] command: [ '--default-authentication-plugin=mysql_native_password', '--port=50903' ]
environment: environment:
MYSQL_ROOT_PASSWORD: jet MYSQL_ROOT_PASSWORD: jet
@ -116,25 +116,27 @@ jobs:
name: Create MySQL/MariaDB user and test databases name: Create MySQL/MariaDB user and test databases
command: | command: |
mysql -h 127.0.0.1 -P 50902 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';" mysql -h 127.0.0.1 -P 50902 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';"
mysql -h 127.0.0.1 -P 50902 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';" mysql -h 127.0.0.1 -P 50902 -u root -pjet -e "set global sql_mode = 'STRICT_ALL_TABLES,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -P 50902 -u jet -pjet -e "create database test_sample" mysql -h 127.0.0.1 -P 50902 -u jet -pjet -e "create database test_sample"
mysql -h 127.0.0.1 -P 50902 -u jet -pjet -e "create database dvds2" mysql -h 127.0.0.1 -P 50902 -u jet -pjet -e "create database dvds2"
mysql -h 127.0.0.1 -P 50903 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';" mysql -h 127.0.0.1 -P 50903 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';"
mysql -h 127.0.0.1 -P 50903 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';" mysql -h 127.0.0.1 -P 50903 -u root -pjet -e "set global sql_mode = 'STRICT_ALL_TABLES,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -P 50903 -u jet -pjet -e "create database test_sample" mysql -h 127.0.0.1 -P 50903 -u jet -pjet -e "create database test_sample"
mysql -h 127.0.0.1 -P 50903 -u jet -pjet -e "create database dvds2" mysql -h 127.0.0.1 -P 50903 -u jet -pjet -e "create database dvds2"
- run:
name: Init databases
command: |
cd tests
go run ./init/init.go -testsuite all
- run: - run:
name: Install gotestsum name: Install gotestsum
command: go install gotest.tools/gotestsum@latest command: go install gotest.tools/gotestsum@latest
- run:
name: Init databases (postgres, mysql, sqlite) and generate jet files
command: |
cd tests
go run ./init/init.go -testsuite postgres
go run ./init/init.go -testsuite mysql
go run ./init/init.go -testsuite sqlite
# to create test results report # to create test results report
- run: mkdir -p $TEST_RESULTS - run: mkdir -p $TEST_RESULTS
@ -146,14 +148,14 @@ jobs:
name: Running tests with statement caching enabled name: Running tests with statement caching enabled
command: JET_TESTS_WITH_STMT_CACHE=true go test -tags postgres -v ./tests/... command: JET_TESTS_WITH_STMT_CACHE=true go test -tags postgres -v ./tests/...
# run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests
- run: - run:
name: Jet generate mariadb and cockroachdb name: Init databases (mariadb, cockroachdb) and generate jet files
command: | command: |
cd tests cd tests
make jet-gen-mariadb go run ./init/init.go -testsuite mariadb
make jet-gen-cockroach go run ./init/init.go -testsuite cockroach
# run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests
- run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/ - run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/
- run: PG_SOURCE=COCKROACH_DB go test -v ./tests/postgres/ - run: PG_SOURCE=COCKROACH_DB go test -v ./tests/postgres/

View file

@ -579,5 +579,5 @@ To run the tests, additional dependencies are required:
## License ## License
Copyright 2019-2024 Goran Bjelanovic Copyright 2019-2025 Goran Bjelanovic
Licensed under the Apache License, Version 2.0. Licensed under the Apache License, Version 2.0.

View file

@ -18,8 +18,8 @@ func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTyp
SELECT SELECT
t.table_name as "table.name", t.table_name as "table.name",
col.COLUMN_NAME AS "column.Name", col.COLUMN_NAME AS "column.Name",
col.COLUMN_DEFAULT IS NOT NULL AND t.table_type != 'VIEW' as "column.HasDefault", (col.COLUMN_DEFAULT IS NOT NULL AND col.COLUMN_DEFAULT != 'NULL') AND t.table_type != 'VIEW' as "column.HasDefault",
col.IS_NULLABLE = "YES" AS "column.IsNullable", col.IS_NULLABLE = 'YES' AS "column.IsNullable",
col.COLUMN_COMMENT AS "column.Comment", col.COLUMN_COMMENT AS "column.Comment",
COALESCE(pk.IsPrimaryKey, 0) AS "column.IsPrimaryKey", COALESCE(pk.IsPrimaryKey, 0) AS "column.IsPrimaryKey",
IF (col.COLUMN_TYPE = 'tinyint(1)', IF (col.COLUMN_TYPE = 'tinyint(1)',

View file

@ -180,11 +180,15 @@ func getSqlBuilderColumnType(columnMetaData metadata.Column) string {
return "Timez" return "Timez"
case "interval": case "interval":
return "Interval" return "Interval"
case "user-defined", "enum", "text", "character", "character varying", "bytea", "uuid", case "user-defined", "enum", "text", "character", "character varying", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY",
"char", "varchar", "nvarchar", "binary", "varbinary", "bpchar", "varbit", "char", "varchar", "nvarchar", "bpchar", "varbit",
"tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL "tinytext", "mediumtext", "longtext": // MySQL
return "String" return "String"
case "bytea": // postgres
return "Bytea"
case "binary", "varbinary", "tinyblob", "mediumblob", "longblob", "blob": // mysql and sqlite
return "Blob"
case "real", "numeric", "decimal", "double precision", "float", "float4", "float8", case "real", "numeric", "decimal", "double precision", "float", "float4", "float8",
"double": // MySQL "double": // MySQL
return "Float" return "Float"

2
go.mod
View file

@ -1,6 +1,6 @@
module github.com/go-jet/jet/v2 module github.com/go-jet/jet/v2
go 1.21 go 1.22
// used by jet generator // used by jet generator
require ( require (

View file

@ -40,14 +40,23 @@ func snakeToCamel(s string, upperCase bool) string {
if upperCase || i > 0 { if upperCase || i > 0 {
result += camelizeWord(word, len(words) > 1) result += camelizeWord(word, len(words) > 1)
} else { } else { // lowerCase and i == 0
result += word result += toLowerFirstLetter(word)
} }
} }
return result return result
} }
func toLowerFirstLetter(s string) string {
if s == "" {
return s
}
runes := []rune(s)
runes[0] = unicode.ToLower(runes[0])
return string(runes)
}
func camelizeWord(word string, force bool) string { func camelizeWord(word string, force bool) string {
runes := []rune(word) runes := []rune(word)

View file

@ -7,7 +7,10 @@ import (
func TestSnakeToCamel(t *testing.T) { func TestSnakeToCamel(t *testing.T) {
require.Equal(t, SnakeToCamel(""), "") require.Equal(t, SnakeToCamel(""), "")
require.Equal(t, SnakeToCamel("_", false), "")
require.Equal(t, SnakeToCamel("potato_"), "Potato") require.Equal(t, SnakeToCamel("potato_"), "Potato")
require.Equal(t, SnakeToCamel("potato_", false), "potato")
require.Equal(t, SnakeToCamel("Potato_", false), "potato")
require.Equal(t, SnakeToCamel("this_has_to_be_uppercased"), "ThisHasToBeUppercased") require.Equal(t, SnakeToCamel("this_has_to_be_uppercased"), "ThisHasToBeUppercased")
require.Equal(t, SnakeToCamel("this_is_an_id"), "ThisIsAnID") require.Equal(t, SnakeToCamel("this_is_an_id"), "ThisIsAnID")
require.Equal(t, SnakeToCamel("this_is_an_identifier"), "ThisIsAnIdentifier") require.Equal(t, SnakeToCamel("this_is_an_identifier"), "ThisIsAnIdentifier")

View file

@ -18,10 +18,45 @@ func (a *alias) fromImpl(subQuery SelectTable) Projection {
// Generated columns have default aliasing. // Generated columns have default aliasing.
tableName, columnName := extractTableAndColumnName(a.alias) tableName, columnName := extractTableAndColumnName(a.alias)
column := NewColumnImpl(columnName, tableName, nil) newDummyColumn := newDummyColumnForExpression(a.expression, columnName)
column.subQuery = subQuery newDummyColumn.setTableName(tableName)
newDummyColumn.setSubQuery(subQuery)
return &column return newDummyColumn
}
// This function is used to create dummy columns when exporting sub-query columns using subQuery.AllColumns()
// In most case we don't care about type of the column, except when sub-query columns are used as SELECT_JSON projection.
// We need to know type to encode value for json unmarshal. At the moment only bool, time and blob columns are of interest,
// so we don't have to support every column type.
func newDummyColumnForExpression(exp Expression, name string) ColumnExpression {
switch exp.(type) {
case BoolExpression:
return BoolColumn(name)
case IntegerExpression:
return IntegerColumn(name)
case FloatExpression:
return FloatColumn(name)
case BlobExpression:
return BlobColumn(name)
case DateExpression:
return DateColumn(name)
case TimeExpression:
return TimeColumn(name)
case TimezExpression:
return TimezColumn(name)
case TimestampExpression:
return TimestampColumn(name)
case TimestampzExpression:
return TimestampzColumn(name)
case IntervalExpression:
return IntervalColumn(name)
case StringExpression:
return StringColumn(name)
}
return StringColumn(name)
} }
func (a *alias) serializeForProjection(statement StatementType, out *SQLBuilder) { func (a *alias) serializeForProjection(statement StatementType, out *SQLBuilder) {
@ -30,3 +65,15 @@ func (a *alias) serializeForProjection(statement StatementType, out *SQLBuilder)
out.WriteString("AS") out.WriteString("AS")
out.WriteAlias(a.alias) out.WriteAlias(a.alias)
} }
func (a *alias) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) {
out.WriteJsonObjKey(a.alias)
a.expression.serializeForJsonValue(statement, out)
}
func (a *alias) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) {
a.expression.serializeForJsonValue(statement, out)
out.WriteString("AS")
out.WriteAlias(a.alias)
}

View file

@ -0,0 +1,104 @@
package jet
// BlobExpression interface
type BlobExpression interface {
Expression
isStringOrBlob()
EQ(rhs BlobExpression) BoolExpression
NOT_EQ(rhs BlobExpression) BoolExpression
IS_DISTINCT_FROM(rhs BlobExpression) BoolExpression
IS_NOT_DISTINCT_FROM(rhs BlobExpression) BoolExpression
LT(rhs BlobExpression) BoolExpression
LT_EQ(rhs BlobExpression) BoolExpression
GT(rhs BlobExpression) BoolExpression
GT_EQ(rhs BlobExpression) BoolExpression
BETWEEN(min, max BlobExpression) BoolExpression
NOT_BETWEEN(min, max BlobExpression) BoolExpression
CONCAT(rhs BlobExpression) BlobExpression
LIKE(pattern BlobExpression) BoolExpression
NOT_LIKE(pattern BlobExpression) BoolExpression
}
type blobInterfaceImpl struct {
parent BlobExpression
}
func (b *blobInterfaceImpl) isStringOrBlob() {}
func (b *blobInterfaceImpl) EQ(rhs BlobExpression) BoolExpression {
return Eq(b.parent, rhs)
}
func (b *blobInterfaceImpl) NOT_EQ(rhs BlobExpression) BoolExpression {
return NotEq(b.parent, rhs)
}
func (b *blobInterfaceImpl) IS_DISTINCT_FROM(rhs BlobExpression) BoolExpression {
return IsDistinctFrom(b.parent, rhs)
}
func (b *blobInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs BlobExpression) BoolExpression {
return IsNotDistinctFrom(b.parent, rhs)
}
func (b *blobInterfaceImpl) GT(rhs BlobExpression) BoolExpression {
return Gt(b.parent, rhs)
}
func (b *blobInterfaceImpl) GT_EQ(rhs BlobExpression) BoolExpression {
return GtEq(b.parent, rhs)
}
func (b *blobInterfaceImpl) LT(rhs BlobExpression) BoolExpression {
return Lt(b.parent, rhs)
}
func (b *blobInterfaceImpl) LT_EQ(rhs BlobExpression) BoolExpression {
return LtEq(b.parent, rhs)
}
func (b *blobInterfaceImpl) BETWEEN(min, max BlobExpression) BoolExpression {
return NewBetweenOperatorExpression(b.parent, min, max, false)
}
func (b *blobInterfaceImpl) NOT_BETWEEN(min, max BlobExpression) BoolExpression {
return NewBetweenOperatorExpression(b.parent, min, max, true)
}
func (b *blobInterfaceImpl) CONCAT(rhs BlobExpression) BlobExpression {
return BlobExp(newBinaryStringOperatorExpression(b.parent, rhs, StringConcatOperator))
}
func (b *blobInterfaceImpl) LIKE(pattern BlobExpression) BoolExpression {
return newBinaryBoolOperatorExpression(b.parent, pattern, "LIKE")
}
func (b *blobInterfaceImpl) NOT_LIKE(pattern BlobExpression) BoolExpression {
return newBinaryBoolOperatorExpression(b.parent, pattern, "NOT LIKE")
}
//---------------------------------------------------//
type blobExpressionWrapper struct {
Expression
blobInterfaceImpl
}
func newBlobExpressionWrap(expression Expression) BlobExpression {
blobExpressionWrap := &blobExpressionWrapper{Expression: expression}
blobExpressionWrap.blobInterfaceImpl.parent = blobExpressionWrap
expression.setParent(blobExpressionWrap)
return blobExpressionWrap
}
// BlobExp is blob expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as blob expression.
// Does not add sql cast to generated sql builder output.
func BlobExp(expression Expression) BlobExpression {
return newBlobExpressionWrap(expression)
}

View file

@ -102,9 +102,10 @@ type boolExpressionWrapper struct {
} }
func newBoolExpressionWrap(expression Expression) BoolExpression { func newBoolExpressionWrap(expression Expression) BoolExpression {
boolExpressionWrap := boolExpressionWrapper{Expression: expression} boolExpressionWrap := &boolExpressionWrapper{Expression: expression}
boolExpressionWrap.boolInterfaceImpl.parent = &boolExpressionWrap boolExpressionWrap.boolInterfaceImpl.parent = boolExpressionWrap
return &boolExpressionWrap expression.setParent(boolExpressionWrap)
return boolExpressionWrap
} }
// BoolExp is bool expression wrapper around arbitrary expression. // BoolExp is bool expression wrapper around arbitrary expression.

View file

@ -41,6 +41,8 @@ type ClauseSelect struct {
DistinctOnColumns []ColumnExpression DistinctOnColumns []ColumnExpression
ProjectionList []Projection ProjectionList []Projection
IsForRowToJson bool
// MySQL only // MySQL only
OptimizerHints optimizerHints OptimizerHints optimizerHints
} }
@ -52,6 +54,10 @@ func (s *ClauseSelect) Projections() ProjectionList {
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(s.ProjectionList) == 0 {
panic("jet: SELECT clause has to have at least one projection")
}
out.NewLine() out.NewLine()
out.WriteString("SELECT") out.WriteString("SELECT")
s.OptimizerHints.Serialize(statementType, out, options...) s.OptimizerHints.Serialize(statementType, out, options...)
@ -66,12 +72,14 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o
out.WriteByte(')') out.WriteByte(')')
} }
if len(s.ProjectionList) == 0 { if s.IsForRowToJson {
panic("jet: SELECT clause has to have at least one projection") out.IncreaseIdent()
} out.WriteRowToJsonProjections(statementType, s.ProjectionList)
out.DecreaseIdent()
} else {
out.WriteProjections(statementType, s.ProjectionList) out.WriteProjections(statementType, s.ProjectionList)
} }
}
// ClauseFrom struct // ClauseFrom struct
type ClauseFrom struct { type ClauseFrom struct {

View file

@ -2,6 +2,10 @@
package jet package jet
import (
"github.com/go-jet/jet/v2/internal/3rdparty/snaker"
)
// Column is common column interface for all types of columns. // Column is common column interface for all types of columns.
type Column interface { type Column interface {
Name() string Name() string
@ -35,19 +39,19 @@ type ColumnExpressionImpl struct {
} }
// NewColumnImpl creates new ColumnExpressionImpl // NewColumnImpl creates new ColumnExpressionImpl
func NewColumnImpl(name string, tableName string, parent ColumnExpression) ColumnExpressionImpl { func NewColumnImpl(name string, tableName string, parent ColumnExpression) *ColumnExpressionImpl {
bc := ColumnExpressionImpl{ newColumn := &ColumnExpressionImpl{
name: name, name: name,
tableName: tableName, tableName: tableName,
} }
if parent != nil { if parent != nil {
bc.ExpressionInterfaceImpl.Parent = parent newColumn.ExpressionInterfaceImpl.Parent = parent
} else { } else {
bc.ExpressionInterfaceImpl.Parent = &bc newColumn.ExpressionInterfaceImpl.Parent = newColumn
} }
return bc return newColumn
} }
// Name returns name of the column // Name returns name of the column
@ -76,13 +80,6 @@ func (c *ColumnExpressionImpl) defaultAlias() string {
return c.name return c.name
} }
func (c *ColumnExpressionImpl) fromImpl(subQuery SelectTable) Projection {
newColumn := NewColumnImpl(c.name, c.tableName, nil)
newColumn.setSubQuery(subQuery)
return &newColumn
}
func (c *ColumnExpressionImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { func (c *ColumnExpressionImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) {
if statement == SetStatementType { if statement == SetStatementType {
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause // set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
@ -93,14 +90,28 @@ func (c *ColumnExpressionImpl) serializeForOrderBy(statement StatementType, out
c.serialize(statement, out) c.serialize(statement, out)
} }
func (c ColumnExpressionImpl) serializeForProjection(statement StatementType, out *SQLBuilder) { func (c *ColumnExpressionImpl) serializeForProjection(statement StatementType, out *SQLBuilder) {
c.serialize(statement, out) c.serialize(statement, out)
out.WriteString("AS") out.WriteString("AS")
out.WriteAlias(c.defaultAlias()) out.WriteAlias(c.defaultAlias())
} }
func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (c *ColumnExpressionImpl) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) {
out.WriteJsonObjKey(snaker.SnakeToCamel(c.name, false))
c.Parent.serializeForJsonValue(statement, out)
}
func (c *ColumnExpressionImpl) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) {
c.Parent.serializeForJsonValue(statement, out)
out.WriteString("AS")
out.WriteAlias(snaker.SnakeToCamel(c.name, false))
}
func (c *ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if c.subQuery != nil { if c.subQuery != nil {
out.WriteIdentifier(c.subQuery.Alias()) out.WriteIdentifier(c.subQuery.Alias())

View file

@ -78,6 +78,18 @@ func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBui
SerializeProjectionList(statement, projections, out) SerializeProjectionList(statement, projections, out)
} }
func (cl ColumnList) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) {
projections := ColumnListToProjectionList(cl)
SerializeProjectionListJsonObj(statement, projections, out)
}
func (cl ColumnList) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) {
projections := ColumnListToProjectionList(cl)
out.WriteRowToJsonProjections(statement, projections)
}
// dummy column interface implementation // dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface // Name is placeholder for ColumnList to implement Column interface

View file

@ -4,11 +4,10 @@ import "testing"
func TestColumn(t *testing.T) { func TestColumn(t *testing.T) {
column := NewColumnImpl("col", "", nil) column := NewColumnImpl("col", "", nil)
column.ExpressionInterfaceImpl.Parent = &column
assertClauseSerialize(t, column, "col") assertClauseSerialize(t, column, "col")
column.setTableName("table1") column.setTableName("table1")
assertClauseSerialize(t, column, "table1.col") assertClauseSerialize(t, column, "table1.col")
assertProjectionSerialize(t, &column, `table1.col AS "table1.col"`) assertProjectionSerialize(t, column, `table1.col AS "table1.col"`)
assertProjectionSerialize(t, column.AS("alias1"), `table1.col AS "alias1"`) assertProjectionSerialize(t, column.AS("alias1"), `table1.col AS "alias1"`)
} }

View file

@ -11,7 +11,11 @@ type ColumnBool interface {
type boolColumnImpl struct { type boolColumnImpl struct {
boolInterfaceImpl boolInterfaceImpl
ColumnExpressionImpl *ColumnExpressionImpl
}
func (i *boolColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
} }
func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool { func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
@ -51,7 +55,11 @@ type ColumnFloat interface {
type floatColumnImpl struct { type floatColumnImpl struct {
floatInterfaceImpl floatInterfaceImpl
ColumnExpressionImpl *ColumnExpressionImpl
}
func (i *floatColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
} }
func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat { func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
@ -92,7 +100,11 @@ type ColumnInteger interface {
type integerColumnImpl struct { type integerColumnImpl struct {
integerInterfaceImpl integerInterfaceImpl
ColumnExpressionImpl *ColumnExpressionImpl
}
func (i *integerColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
} }
func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger { func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
@ -122,7 +134,7 @@ func IntegerColumn(name string) ColumnInteger {
//------------------------------------------------------// //------------------------------------------------------//
// ColumnString is interface for SQL text, character, character varying // ColumnString is interface for SQL text, character, character varying
// bytea, uuid columns and enums types. // uuid columns and enums types.
type ColumnString interface { type ColumnString interface {
StringExpression StringExpression
Column Column
@ -134,7 +146,11 @@ type ColumnString interface {
type stringColumnImpl struct { type stringColumnImpl struct {
stringInterfaceImpl stringInterfaceImpl
ColumnExpressionImpl *ColumnExpressionImpl
}
func (i *stringColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
} }
func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString { func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
@ -163,6 +179,51 @@ func StringColumn(name string) ColumnString {
//------------------------------------------------------// //------------------------------------------------------//
// ColumnBlob is interface for binary data types (bytea, binary, blob, etc...)
type ColumnBlob interface {
BlobExpression
Column
From(subQuery SelectTable) ColumnBlob
SET(blob BlobExpression) ColumnAssigment
}
type blobColumnImpl struct {
blobInterfaceImpl
*ColumnExpressionImpl
}
func (i *blobColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
}
func (i *blobColumnImpl) From(subQuery SelectTable) ColumnBlob {
newBlobColumn := BlobColumn(i.name)
newBlobColumn.setTableName(i.tableName)
newBlobColumn.setSubQuery(subQuery)
return newBlobColumn
}
func (i *blobColumnImpl) SET(blobExp BlobExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: blobExp,
}
}
// BlobColumn creates named blob column.
func BlobColumn(name string) ColumnBlob {
blobColumn := &blobColumnImpl{}
blobColumn.blobInterfaceImpl.parent = blobColumn
blobColumn.ColumnExpressionImpl = NewColumnImpl(name, "", blobColumn)
return blobColumn
}
//------------------------------------------------------//
// ColumnTime is interface for SQL time column. // ColumnTime is interface for SQL time column.
type ColumnTime interface { type ColumnTime interface {
TimeExpression TimeExpression
@ -174,7 +235,11 @@ type ColumnTime interface {
type timeColumnImpl struct { type timeColumnImpl struct {
timeInterfaceImpl timeInterfaceImpl
ColumnExpressionImpl *ColumnExpressionImpl
}
func (i *timeColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
} }
func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime { func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
@ -213,7 +278,11 @@ type ColumnTimez interface {
type timezColumnImpl struct { type timezColumnImpl struct {
timezInterfaceImpl timezInterfaceImpl
ColumnExpressionImpl *ColumnExpressionImpl
}
func (i *timezColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
} }
func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez { func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
@ -253,7 +322,11 @@ type ColumnTimestamp interface {
type timestampColumnImpl struct { type timestampColumnImpl struct {
timestampInterfaceImpl timestampInterfaceImpl
ColumnExpressionImpl *ColumnExpressionImpl
}
func (i *timestampColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
} }
func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp { func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
@ -293,7 +366,11 @@ type ColumnTimestampz interface {
type timestampzColumnImpl struct { type timestampzColumnImpl struct {
timestampzInterfaceImpl timestampzInterfaceImpl
ColumnExpressionImpl *ColumnExpressionImpl
}
func (i *timestampzColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
} }
func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz { func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
@ -333,7 +410,11 @@ type ColumnDate interface {
type dateColumnImpl struct { type dateColumnImpl struct {
dateInterfaceImpl dateInterfaceImpl
ColumnExpressionImpl *ColumnExpressionImpl
}
func (i *dateColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
} }
func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate { func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
@ -361,6 +442,51 @@ func DateColumn(name string) ColumnDate {
//------------------------------------------------------// //------------------------------------------------------//
// ColumnInterval is interface of PostgreSQL interval columns.
type ColumnInterval interface {
IntervalExpression
Column
From(subQuery SelectTable) ColumnInterval
SET(intervalExp IntervalExpression) ColumnAssigment
}
//------------------------------------------------------//
type intervalColumnImpl struct {
*ColumnExpressionImpl
intervalInterfaceImpl
}
func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: intervalExp,
}
}
func (i *intervalColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
}
func (i *intervalColumnImpl) From(subQuery SelectTable) ColumnInterval {
newIntervalColumn := IntervalColumn(i.name)
newIntervalColumn.setTableName(i.tableName)
newIntervalColumn.setSubQuery(subQuery)
return newIntervalColumn
}
// IntervalColumn creates named interval column.
func IntervalColumn(name string) ColumnInterval {
intervalColumn := &intervalColumnImpl{}
intervalColumn.ColumnExpressionImpl = NewColumnImpl(name, "", intervalColumn)
intervalColumn.intervalInterfaceImpl.parent = intervalColumn
return intervalColumn
}
//------------------------------------------------------//
// ColumnRange is interface for range columns which can be int range, string range // ColumnRange is interface for range columns which can be int range, string range
// timestamp range or date range. // timestamp range or date range.
type ColumnRange[T Expression] interface { type ColumnRange[T Expression] interface {
@ -373,7 +499,11 @@ type ColumnRange[T Expression] interface {
type rangeColumnImpl[T Expression] struct { type rangeColumnImpl[T Expression] struct {
rangeInterfaceImpl[T] rangeInterfaceImpl[T]
ColumnExpressionImpl *ColumnExpressionImpl
}
func (i *rangeColumnImpl[T]) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
} }
func (i *rangeColumnImpl[T]) From(subQuery SelectTable) ColumnRange[T] { func (i *rangeColumnImpl[T]) From(subQuery SelectTable) ColumnRange[T] {

View file

@ -80,9 +80,10 @@ type dateExpressionWrapper struct {
} }
func newDateExpressionWrap(expression Expression) DateExpression { func newDateExpressionWrap(expression Expression) DateExpression {
dateExpressionWrap := dateExpressionWrapper{Expression: expression} dateExpressionWrap := &dateExpressionWrapper{Expression: expression}
dateExpressionWrap.dateInterfaceImpl.parent = &dateExpressionWrap dateExpressionWrap.dateInterfaceImpl.parent = dateExpressionWrap
return &dateExpressionWrap expression.setParent(dateExpressionWrap)
return dateExpressionWrap
} }
// DateExp is date expression wrapper around arbitrary expression. // DateExp is date expression wrapper around arbitrary expression.

View file

@ -1,13 +0,0 @@
package jet
import (
"testing"
)
func TestDateArithmetic(t *testing.T) {
timestamp := Timestamp(2000, 1, 1, 0, 0, 0)
assertClauseDebugSerialize(t, table1ColDate.ADD(NewInterval(String("1 HOUR"))).EQ(timestamp),
"((table1.col_date + INTERVAL '1 HOUR') = '2000-01-01 00:00:00')")
assertClauseDebugSerialize(t, table1ColDate.SUB(NewInterval(String("1 HOUR"))).EQ(timestamp),
"((table1.col_date - INTERVAL '1 HOUR') = '2000-01-01 00:00:00')")
}

View file

@ -1,6 +1,8 @@
package jet package jet
import "strings" import (
"strings"
)
// Dialect interface // Dialect interface
type Dialect interface { type Dialect interface {
@ -11,9 +13,11 @@ type Dialect interface {
AliasQuoteChar() byte AliasQuoteChar() byte
IdentifierQuoteChar() byte IdentifierQuoteChar() byte
ArgumentPlaceholder() QueryPlaceholderFunc ArgumentPlaceholder() QueryPlaceholderFunc
ArgumentToString(value any) (string, bool)
IsReservedWord(name string) bool IsReservedWord(name string) bool
SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
ValuesDefaultColumnName(index int) string ValuesDefaultColumnName(index int) string
JsonValueEncode(expr Expression) Expression
} }
// SerializerFunc func // SerializerFunc func
@ -34,9 +38,11 @@ type DialectParams struct {
AliasQuoteChar byte AliasQuoteChar byte
IdentifierQuoteChar byte IdentifierQuoteChar byte
ArgumentPlaceholder QueryPlaceholderFunc ArgumentPlaceholder QueryPlaceholderFunc
ArgumentToString func(value any) (string, bool)
ReservedWords []string ReservedWords []string
SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
ValuesDefaultColumnName func(index int) string ValuesDefaultColumnName func(index int) string
JsonValueEncode func(expr Expression) Expression
} }
// NewDialect creates new dialect with params // NewDialect creates new dialect with params
@ -49,9 +55,11 @@ func NewDialect(params DialectParams) Dialect {
aliasQuoteChar: params.AliasQuoteChar, aliasQuoteChar: params.AliasQuoteChar,
identifierQuoteChar: params.IdentifierQuoteChar, identifierQuoteChar: params.IdentifierQuoteChar,
argumentPlaceholder: params.ArgumentPlaceholder, argumentPlaceholder: params.ArgumentPlaceholder,
argumentToString: params.ArgumentToString,
reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords), reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords),
serializeOrderBy: params.SerializeOrderBy, serializeOrderBy: params.SerializeOrderBy,
valuesDefaultColumnName: params.ValuesDefaultColumnName, valuesDefaultColumnName: params.ValuesDefaultColumnName,
jsonValueEncode: params.JsonValueEncode,
} }
} }
@ -63,9 +71,11 @@ type dialectImpl struct {
aliasQuoteChar byte aliasQuoteChar byte
identifierQuoteChar byte identifierQuoteChar byte
argumentPlaceholder QueryPlaceholderFunc argumentPlaceholder QueryPlaceholderFunc
argumentToString func(value any) (string, bool)
reservedWords map[string]bool reservedWords map[string]bool
serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
valuesDefaultColumnName func(index int) string valuesDefaultColumnName func(index int) string
jsonValueEncode func(expr Expression) Expression
} }
func (d *dialectImpl) Name() string { func (d *dialectImpl) Name() string {
@ -102,6 +112,10 @@ func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc {
return d.argumentPlaceholder return d.argumentPlaceholder
} }
func (d *dialectImpl) ArgumentToString(value any) (string, bool) {
return d.argumentToString(value)
}
func (d *dialectImpl) IsReservedWord(name string) bool { func (d *dialectImpl) IsReservedWord(name string) bool {
_, isReservedWord := d.reservedWords[strings.ToLower(name)] _, isReservedWord := d.reservedWords[strings.ToLower(name)]
return isReservedWord return isReservedWord
@ -115,6 +129,10 @@ func (d *dialectImpl) ValuesDefaultColumnName(index int) string {
return d.valuesDefaultColumnName(index) return d.valuesDefaultColumnName(index)
} }
func (d *dialectImpl) JsonValueEncode(expr Expression) Expression {
return d.jsonValueEncode(expr)
}
func arrayOfStringsToMapOfStrings(arr []string) map[string]bool { func arrayOfStringsToMapOfStrings(arr []string) map[string]bool {
ret := map[string]bool{} ret := map[string]bool{}
for _, elem := range arr { for _, elem := range arr {

View file

@ -2,7 +2,7 @@ package jet
import "fmt" import "fmt"
// Expression is common interface for all expressions. // Expression is a common interface for all expressions.
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions. // Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
type Expression interface { type Expression interface {
Serializer Serializer
@ -10,6 +10,9 @@ type Expression interface {
GroupByClause GroupByClause
OrderByClause OrderByClause
serializeForJsonValue(statement StatementType, out *SQLBuilder)
setParent(parent Expression)
// IS_NULL tests expression whether it is a NULL value. // IS_NULL tests expression whether it is a NULL value.
IS_NULL() BoolExpression IS_NULL() BoolExpression
// IS_NOT_NULL tests expression whether it is a non-NULL value. // IS_NOT_NULL tests expression whether it is a non-NULL value.
@ -34,6 +37,10 @@ type ExpressionInterfaceImpl struct {
Parent Expression Parent Expression
} }
func (e *ExpressionInterfaceImpl) setParent(parent Expression) {
e.Parent = parent
}
func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection { func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection {
panic(fmt.Sprintf("jet: can't export unaliased expression subQuery: %s, expression: %s", panic(fmt.Sprintf("jet: can't export unaliased expression subQuery: %s, expression: %s",
subQuery.Alias(), serializeToDefaultDebugString(e.Parent))) subQuery.Alias(), serializeToDefaultDebugString(e.Parent)))
@ -92,6 +99,18 @@ func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType
e.Parent.serialize(statement, out, NoWrap) e.Parent.serialize(statement, out, NoWrap)
} }
func (e *ExpressionInterfaceImpl) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) {
panic("jet: expression need to be aliased when used as SELECT JSON projection.")
}
func (e *ExpressionInterfaceImpl) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) {
panic("jet: expression need to be aliased when used as SELECT JSON projection.")
}
func (e *ExpressionInterfaceImpl) serializeForJsonValue(statement StatementType, out *SQLBuilder) {
out.Dialect.JsonValueEncode(e.Parent).serialize(statement, out)
}
func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) {
e.Parent.serialize(statement, out, NoWrap) e.Parent.serialize(statement, out, NoWrap)
} }
@ -152,7 +171,7 @@ func newExpressionListOperator(operator string, expressions ...Expression) *expr
} }
func newBoolExpressionListOperator(operator string, expressions ...BoolExpression) BoolExpression { func newBoolExpressionListOperator(operator string, expressions ...BoolExpression) BoolExpression {
return BoolExp(newExpressionListOperator(operator, BoolExpressionListToExpressionList(expressions)...)) return BoolExp(newExpressionListOperator(operator, ToExpressionList(expressions)...))
} }
func (elo *expressionListOperator) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (elo *expressionListOperator) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {

View file

@ -102,9 +102,10 @@ type floatExpressionWrapper struct {
} }
func newFloatExpressionWrap(expression Expression) FloatExpression { func newFloatExpressionWrap(expression Expression) FloatExpression {
floatExpressionWrap := floatExpressionWrapper{Expression: expression} floatExpressionWrap := &floatExpressionWrapper{Expression: expression}
floatExpressionWrap.floatInterfaceImpl.parent = &floatExpressionWrap floatExpressionWrap.floatInterfaceImpl.parent = floatExpressionWrap
return &floatExpressionWrap expression.setParent(floatExpressionWrap)
return floatExpressionWrap
} }
// FloatExp is date expression wrapper around arbitrary expression. // FloatExp is date expression wrapper around arbitrary expression.

View file

@ -255,18 +255,30 @@ func leadLagImpl(name string, expr Expression, offsetAndDefault ...interface{})
//------------ String functions ------------------// //------------ String functions ------------------//
// HEX function takes an input and returns its equivalent hexadecimal representation
func HEX(expression Expression) StringExpression {
return StringExp(Func("HEX", expression))
}
// UNHEX for a string argument str, UNHEX(str) interprets each pair of characters in the argument
// as a hexadecimal number and converts it to the byte represented by the number.
// The return value is a binary string.
func UNHEX(expression StringExpression) BlobExpression {
return BlobExp(Func("UNHEX", expression))
}
// BIT_LENGTH returns number of bits in string expression // BIT_LENGTH returns number of bits in string expression
func BIT_LENGTH(stringExpression StringExpression) IntegerExpression { func BIT_LENGTH(stringExpression StringOrBlobExpression) IntegerExpression {
return newIntegerFunc("BIT_LENGTH", stringExpression) return newIntegerFunc("BIT_LENGTH", stringExpression)
} }
// CHAR_LENGTH returns number of characters in string expression // CHAR_LENGTH returns number of characters in string expression
func CHAR_LENGTH(stringExpression StringExpression) IntegerExpression { func CHAR_LENGTH(stringExpression StringOrBlobExpression) IntegerExpression {
return newIntegerFunc("CHAR_LENGTH", stringExpression) return newIntegerFunc("CHAR_LENGTH", stringExpression)
} }
// OCTET_LENGTH returns number of bytes in string expression // OCTET_LENGTH returns number of bytes in string expression
func OCTET_LENGTH(stringExpression StringExpression) IntegerExpression { func OCTET_LENGTH(stringExpression StringOrBlobExpression) IntegerExpression {
return newIntegerFunc("OCTET_LENGTH", stringExpression) return newIntegerFunc("OCTET_LENGTH", stringExpression)
} }
@ -282,7 +294,7 @@ func UPPER(stringExpression StringExpression) StringExpression {
// BTRIM removes the longest string consisting only of characters // BTRIM removes the longest string consisting only of characters
// in characters (a space by default) from the start and end of string // in characters (a space by default) from the start and end of string
func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) StringExpression { func BTRIM(stringExpression StringOrBlobExpression, trimChars ...StringOrBlobExpression) StringExpression {
if len(trimChars) > 0 { if len(trimChars) > 0 {
return NewStringFunc("BTRIM", stringExpression, trimChars[0]) return NewStringFunc("BTRIM", stringExpression, trimChars[0])
} }
@ -291,7 +303,7 @@ func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) Str
// LTRIM removes the longest string containing only characters // LTRIM removes the longest string containing only characters
// from characters (a space by default) from the start of string // from characters (a space by default) from the start of string
func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { func LTRIM(str StringOrBlobExpression, trimChars ...StringOrBlobExpression) StringExpression {
if len(trimChars) > 0 { if len(trimChars) > 0 {
return NewStringFunc("LTRIM", str, trimChars[0]) return NewStringFunc("LTRIM", str, trimChars[0])
} }
@ -300,7 +312,7 @@ func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression
// RTRIM removes the longest string containing only characters // RTRIM removes the longest string containing only characters
// from characters (a space by default) from the end of string // from characters (a space by default) from the end of string
func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { func RTRIM(str StringOrBlobExpression, trimChars ...StringOrBlobExpression) StringExpression {
if len(trimChars) > 0 { if len(trimChars) > 0 {
return NewStringFunc("RTRIM", str, trimChars[0]) return NewStringFunc("RTRIM", str, trimChars[0])
} }
@ -324,32 +336,32 @@ func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression
// CONVERT converts string to dest_encoding. The original encoding is // CONVERT converts string to dest_encoding. The original encoding is
// specified by src_encoding. The string must be valid in this encoding. // specified by src_encoding. The string must be valid in this encoding.
func CONVERT(str StringExpression, srcEncoding StringExpression, destEncoding StringExpression) StringExpression { func CONVERT(str BlobExpression, srcEncoding StringExpression, destEncoding StringExpression) BlobExpression {
return NewStringFunc("CONVERT", str, srcEncoding, destEncoding) return BlobExp(Func("CONVERT", str, srcEncoding, destEncoding))
} }
// CONVERT_FROM converts string to the database encoding. The original // CONVERT_FROM converts string to the database encoding. The original
// encoding is specified by src_encoding. The string must be valid in this encoding. // encoding is specified by src_encoding. The string must be valid in this encoding.
func CONVERT_FROM(str StringExpression, srcEncoding StringExpression) StringExpression { func CONVERT_FROM(str BlobExpression, srcEncoding StringExpression) StringExpression {
return NewStringFunc("CONVERT_FROM", str, srcEncoding) return NewStringFunc("CONVERT_FROM", str, srcEncoding)
} }
// CONVERT_TO converts string to dest_encoding. // CONVERT_TO converts string to dest_encoding.
func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression { func CONVERT_TO(str StringExpression, toEncoding StringExpression) BlobExpression {
return NewStringFunc("CONVERT_TO", str, toEncoding) return BlobExp(Func("CONVERT_TO", str, toEncoding))
} }
// ENCODE encodes binary data into a textual representation. // ENCODE encodes binary data into a textual representation.
// Supported formats are: base64, hex, escape. escape converts zero bytes and // Supported formats are: base64, hex, escape. escape converts zero bytes and
// high-bit-set bytes to octal sequences (\nnn) and doubles backslashes. // high-bit-set bytes to octal sequences (\nnn) and doubles backslashes.
func ENCODE(data StringExpression, format StringExpression) StringExpression { func ENCODE(data BlobExpression, format StringExpression) StringExpression {
return NewStringFunc("ENCODE", data, format) return StringExp(Func("ENCODE", data, format))
} }
// DECODE decodes binary data from textual representation in string. // DECODE decodes binary data from textual representation in string.
// Options for format are same as in encode. // Options for format are same as in encode.
func DECODE(data StringExpression, format StringExpression) StringExpression { func DECODE(data StringExpression, format StringExpression) BlobExpression {
return NewStringFunc("DECODE", data, format) return BlobExp(Func("DECODE", data, format))
} }
// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. // FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string.
@ -379,11 +391,11 @@ func RIGHT(str StringExpression, n IntegerExpression) StringExpression {
} }
// LENGTH returns number of characters in string with a given encoding // LENGTH returns number of characters in string with a given encoding
func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression { func LENGTH(str StringOrBlobExpression, encoding ...StringExpression) IntegerExpression {
if len(encoding) > 0 { if len(encoding) > 0 {
return NewStringFunc("LENGTH", str, encoding[0]) return IntExp(Func("LENGTH", str, encoding[0]))
} }
return NewStringFunc("LENGTH", str) return IntExp(Func("LENGTH", str))
} }
// LPAD fills up the string to length length by prepending the characters // LPAD fills up the string to length length by prepending the characters
@ -407,8 +419,13 @@ func RPAD(str StringExpression, length IntegerExpression, text ...StringExpressi
return NewStringFunc("RPAD", str, length) return NewStringFunc("RPAD", str, length)
} }
// BIT_COUNT returns the number of bits set in the binary string (also known as “popcount”).
func BIT_COUNT(bytes BlobExpression) IntegerExpression {
return IntExp(Func("BIT_COUNT", bytes))
}
// MD5 calculates the MD5 hash of string, returning the result in hexadecimal // MD5 calculates the MD5 hash of string, returning the result in hexadecimal
func MD5(stringExpression StringExpression) StringExpression { func MD5(stringExpression StringOrBlobExpression) StringExpression {
return NewStringFunc("MD5", stringExpression) return NewStringFunc("MD5", stringExpression)
} }
@ -434,7 +451,7 @@ func STRPOS(str, substring StringExpression) IntegerExpression {
} }
// SUBSTR extracts substring // SUBSTR extracts substring
func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression { func SUBSTR(str StringOrBlobExpression, from IntegerExpression, count ...IntegerExpression) StringExpression {
if len(count) > 0 { if len(count) > 0 {
return NewStringFunc("SUBSTR", str, from, count[0]) return NewStringFunc("SUBSTR", str, from, count[0])
} }

View file

@ -141,11 +141,11 @@ type integerExpressionWrapper struct {
} }
func newIntExpressionWrap(expression Expression) IntegerExpression { func newIntExpressionWrap(expression Expression) IntegerExpression {
intExpressionWrap := integerExpressionWrapper{Expression: expression} intExpressionWrap := &integerExpressionWrapper{Expression: expression}
intExpressionWrap.integerInterfaceImpl.parent = intExpressionWrap
expression.setParent(intExpressionWrap)
intExpressionWrap.integerInterfaceImpl.parent = &intExpressionWrap return intExpressionWrap
return &intExpressionWrap
} }
// IntExp is int expression wrapper around arbitrary expression. // IntExp is int expression wrapper around arbitrary expression.

View file

@ -1,37 +0,0 @@
package jet
// Interval is internal common representation of sql interval
type Interval interface {
Serializer
IsInterval
}
// IsInterval interface
type IsInterval interface {
isInterval()
}
// IsIntervalImpl is implementation of IsInterval interface
type IsIntervalImpl struct{}
func (i *IsIntervalImpl) isInterval() {}
// NewInterval creates new interval from serializer
func NewInterval(s Serializer) *IntervalImpl {
newInterval := &IntervalImpl{
Value: s,
}
return newInterval
}
// IntervalImpl is implementation of Interval type
type IntervalImpl struct {
Value Serializer
IsIntervalImpl
}
func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("INTERVAL")
i.Value.serialize(statement, out, FallTrough(options)...)
}

View file

@ -0,0 +1,112 @@
package jet
// IntervalExpression interface
type IntervalExpression interface {
Expression
isInterval()
EQ(rhs IntervalExpression) BoolExpression
NOT_EQ(rhs IntervalExpression) BoolExpression
IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression
IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression
LT(rhs IntervalExpression) BoolExpression
LT_EQ(rhs IntervalExpression) BoolExpression
GT(rhs IntervalExpression) BoolExpression
GT_EQ(rhs IntervalExpression) BoolExpression
BETWEEN(min, max IntervalExpression) BoolExpression
NOT_BETWEEN(min, max IntervalExpression) BoolExpression
ADD(rhs IntervalExpression) IntervalExpression
SUB(rhs IntervalExpression) IntervalExpression
MUL(rhs NumericExpression) IntervalExpression
DIV(rhs NumericExpression) IntervalExpression
}
type intervalInterfaceImpl struct {
parent IntervalExpression
}
func (i *intervalInterfaceImpl) isInterval() {}
func (i *intervalInterfaceImpl) EQ(rhs IntervalExpression) BoolExpression {
return Eq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) NOT_EQ(rhs IntervalExpression) BoolExpression {
return NotEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression {
return IsDistinctFrom(i.parent, rhs)
}
func (i *intervalInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression {
return IsNotDistinctFrom(i.parent, rhs)
}
func (i *intervalInterfaceImpl) LT(rhs IntervalExpression) BoolExpression {
return Lt(i.parent, rhs)
}
func (i *intervalInterfaceImpl) LT_EQ(rhs IntervalExpression) BoolExpression {
return LtEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) GT(rhs IntervalExpression) BoolExpression {
return Gt(i.parent, rhs)
}
func (i *intervalInterfaceImpl) GT_EQ(rhs IntervalExpression) BoolExpression {
return GtEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) BETWEEN(min, max IntervalExpression) BoolExpression {
return NewBetweenOperatorExpression(i.parent, min, max, false)
}
func (i *intervalInterfaceImpl) NOT_BETWEEN(min, max IntervalExpression) BoolExpression {
return NewBetweenOperatorExpression(i.parent, min, max, true)
}
func (i *intervalInterfaceImpl) ADD(rhs IntervalExpression) IntervalExpression {
return IntervalExp(Add(i.parent, rhs))
}
func (i *intervalInterfaceImpl) SUB(rhs IntervalExpression) IntervalExpression {
return IntervalExp(Sub(i.parent, rhs))
}
func (i *intervalInterfaceImpl) MUL(rhs NumericExpression) IntervalExpression {
return IntervalExp(Mul(i.parent, rhs))
}
func (i *intervalInterfaceImpl) DIV(rhs NumericExpression) IntervalExpression {
return IntervalExp(Div(i.parent, rhs))
}
type intervalWrapper struct {
intervalInterfaceImpl
Expression
}
func newIntervalExpressionWrap(expression Expression) IntervalExpression {
intervalWrap := &intervalWrapper{Expression: expression}
intervalWrap.intervalInterfaceImpl.parent = intervalWrap
expression.setParent(intervalWrap)
return intervalWrap
}
// IntervalExp is interval expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as interval expression.
// Does not add sql cast to generated sql builder output.
func IntervalExp(expression Expression) IntervalExpression {
return newIntervalExpressionWrap(expression)
}
// Interval interface
type Interval interface {
Serializer
isInterval()
}

View file

@ -412,17 +412,6 @@ func Raw(raw string, namedArgs ...map[string]interface{}) Expression {
return rawExp return rawExp
} }
// RawWithParent is a Raw constructor used for construction dialect specific expression
func RawWithParent(raw string, parent ...Expression) Expression {
rawExp := &rawExpression{
Raw: raw,
noWrap: true,
}
rawExp.ExpressionInterfaceImpl.Parent = OptionalOrDefaultExpression(rawExp, parent...)
return rawExp
}
// RawBool helper that for raw string boolean expressions // RawBool helper that for raw string boolean expressions
func RawBool(raw string, namedArgs ...map[string]interface{}) BoolExpression { func RawBool(raw string, namedArgs ...map[string]interface{}) BoolExpression {
return BoolExp(Raw(raw, namedArgs...)) return BoolExp(Raw(raw, namedArgs...))
@ -468,6 +457,11 @@ func RawDate(raw string, namedArgs ...map[string]interface{}) DateExpression {
return DateExp(Raw(raw, namedArgs...)) return DateExp(Raw(raw, namedArgs...))
} }
// RawBlob is raw query helper that for blob expressions
func RawBlob(raw string, namedArgs ...map[string]interface{}) BlobExpression {
return BlobExp(Raw(raw, namedArgs...))
}
// RawRange helper that for range expressions // RawRange helper that for range expressions
func RawRange[T Expression](raw string, namedArgs ...map[string]interface{}) Range[T] { func RawRange[T Expression](raw string, namedArgs ...map[string]interface{}) Range[T] {
return RangeExp[T](Raw(raw, namedArgs...)) return RangeExp[T](Raw(raw, namedArgs...))

View file

@ -3,6 +3,8 @@ package jet
// Projection is interface for all projection types. Types that can be part of, for instance SELECT clause. // Projection is interface for all projection types. Types that can be part of, for instance SELECT clause.
type Projection interface { type Projection interface {
serializeForProjection(statement StatementType, out *SQLBuilder) serializeForProjection(statement StatementType, out *SQLBuilder)
serializeForJsonObjEntry(statement StatementType, out *SQLBuilder)
serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder)
fromImpl(subQuery SelectTable) Projection fromImpl(subQuery SelectTable) Projection
} }
@ -28,6 +30,10 @@ func (pl ProjectionList) serializeForProjection(statement StatementType, out *SQ
SerializeProjectionList(statement, pl, out) SerializeProjectionList(statement, pl, out)
} }
func (pl ProjectionList) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) {
SerializeProjectionListJsonObj(statement, pl, out)
}
// As will create new projection list where each column is wrapped with a new table alias. // As will create new projection list where each column is wrapped with a new table alias.
// tableAlias should be in the form 'name' or 'name.*', or it can be an empty string, which will remove existing table alias. // tableAlias should be in the form 'name' or 'name.*', or it can be an empty string, which will remove existing table alias.
// For instance: If projection list has a column 'Artist.Name', and tableAlias is 'Musician.*', returned projection list will // For instance: If projection list has a column 'Artist.Name', and tableAlias is 'Musician.*', returned projection list will
@ -79,3 +85,18 @@ func (pl ProjectionList) Except(toExclude ...Column) ProjectionList {
return ret return ret
} }
func (pl ProjectionList) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) {
out.WriteRowToJsonProjections(statement, pl)
}
// JsonObjProjectionList redefines []Projection so projections can be serialized as json object key/values
type JsonObjProjectionList []Projection
func (j JsonObjProjectionList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.IncreaseIdent()
out.NewLine()
SerializeProjectionListJsonObj(statement, j, out)
out.DecreaseIdent()
out.NewLine()
}

View file

@ -118,9 +118,10 @@ type rangeExpressionWrapper[T Expression] struct {
} }
func newRangeExpressionWrap[T Expression](expression Expression) Range[T] { func newRangeExpressionWrap[T Expression](expression Expression) Range[T] {
rangeExpressionWrap := rangeExpressionWrapper[T]{Expression: expression} rangeExpressionWrap := &rangeExpressionWrapper[T]{Expression: expression}
rangeExpressionWrap.rangeInterfaceImpl.parent = &rangeExpressionWrap rangeExpressionWrap.rangeInterfaceImpl.parent = rangeExpressionWrap
return &rangeExpressionWrap expression.setParent(rangeExpressionWrap)
return rangeExpressionWrap
} }
// RangeExp is range expression wrapper around arbitrary expression. // RangeExp is range expression wrapper around arbitrary expression.

View file

@ -1,7 +1,7 @@
package jet package jet
type rawStatementImpl struct { type rawStatementImpl struct {
serializerStatementInterfaceImpl statementInterfaceImpl
RawQuery string RawQuery string
NamedArguments map[string]interface{} NamedArguments map[string]interface{}
@ -10,7 +10,7 @@ type rawStatementImpl struct {
// RawStatement creates new sql statements from raw query and optional map of named arguments // RawStatement creates new sql statements from raw query and optional map of named arguments
func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) SerializerStatement { func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) SerializerStatement {
newRawStatement := rawStatementImpl{ newRawStatement := rawStatementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ statementInterfaceImpl: statementInterfaceImpl{
dialect: dialect, dialect: dialect,
statementType: "", statementType: "",
parent: nil, parent: nil,

View file

@ -19,7 +19,7 @@ type RowExpression interface {
type rowInterfaceImpl struct { type rowInterfaceImpl struct {
parent Expression parent Expression
dialect Dialect dialect Dialect
elemCount int expressions []Expression
} }
func (n *rowInterfaceImpl) EQ(rhs RowExpression) BoolExpression { func (n *rowInterfaceImpl) EQ(rhs RowExpression) BoolExpression {
@ -57,9 +57,8 @@ func (n *rowInterfaceImpl) LT_EQ(rhs RowExpression) BoolExpression {
func (n *rowInterfaceImpl) projections() ProjectionList { func (n *rowInterfaceImpl) projections() ProjectionList {
var ret ProjectionList var ret ProjectionList
for i := 0; i < n.elemCount; i++ { for i, expression := range n.expressions {
rowColumn := NewColumnImpl(n.dialect.ValuesDefaultColumnName(i), "", nil) ret = append(ret, newDummyColumnForExpression(expression, n.dialect.ValuesDefaultColumnName(i)))
ret = append(ret, &rowColumn)
} }
return ret return ret
@ -77,7 +76,7 @@ func newRowExpression(name string, dialect Dialect, expressions ...Expression) R
ret.Expression = NewFunc(name, expressions, ret) ret.Expression = NewFunc(name, expressions, ret)
ret.dialect = dialect ret.dialect = dialect
ret.elemCount = len(expressions) ret.expressions = expressions
return ret return ret
} }

View file

@ -25,6 +25,8 @@ type StatementType string
// Statement types // Statement types
const ( const (
SelectStatementType StatementType = "SELECT" SelectStatementType StatementType = "SELECT"
SelectJsonObjStatementType StatementType = "SELECT_JSON_OBJ"
SelectJsonArrStatementType StatementType = "SELECT_JSON_ARR"
InsertStatementType StatementType = "INSERT" InsertStatementType StatementType = "INSERT"
UpdateStatementType StatementType = "UPDATE" UpdateStatementType StatementType = "UPDATE"
DeleteStatementType StatementType = "DELETE" DeleteStatementType StatementType = "DELETE"

View file

@ -61,6 +61,17 @@ func (s *SQLBuilder) WriteProjections(statement StatementType, projections []Pro
s.DecreaseIdent() s.DecreaseIdent()
} }
// WriteRowToJsonProjections serializes slice of projections intended for row_to_json json aggregation
func (s *SQLBuilder) WriteRowToJsonProjections(statement StatementType, projections []Projection) {
for i, projection := range projections {
if i > 0 {
s.WriteString(",")
s.NewLine()
}
projection.serializeForRowToJsonProjection(statement, s)
}
}
// NewLine adds new line to output SQL // NewLine adds new line to output SQL
func (s *SQLBuilder) NewLine() { func (s *SQLBuilder) NewLine() {
s.write([]byte{'\n'}) s.write([]byte{'\n'})
@ -99,6 +110,11 @@ func (s *SQLBuilder) WriteString(str string) {
s.write([]byte(str)) s.write([]byte(str))
} }
// WriteJsonObjKey serializes json object key
func (s *SQLBuilder) WriteJsonObjKey(key string) {
s.WriteString(fmt.Sprintf(`'%s', `, key))
}
// WriteIdentifier adds identifier to output SQL // WriteIdentifier adds identifier to output SQL
func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) { func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
if s.shouldQuote(name, alwaysQuote...) { if s.shouldQuote(name, alwaysQuote...) {
@ -123,7 +139,7 @@ func (s *SQLBuilder) finalize() (string, []interface{}) {
} }
func (s *SQLBuilder) insertConstantArgument(arg interface{}) { func (s *SQLBuilder) insertConstantArgument(arg interface{}) {
s.WriteString(argToString(arg)) s.WriteString(s.argToString(arg))
} }
func (s *SQLBuilder) insertParametrizedArgument(arg interface{}) { func (s *SQLBuilder) insertParametrizedArgument(arg interface{}) {
@ -196,7 +212,7 @@ func (s *SQLBuilder) insertRawQuery(raw string, namedArg map[string]interface{})
} }
if s.Debug { if s.Debug {
placeholder = argToString(namedArgumentPos.Value) placeholder = s.argToString(namedArgumentPos.Value)
} }
raw = strings.Replace(raw, namedArgumentPos.Name, placeholder, toReplace) raw = strings.Replace(raw, namedArgumentPos.Name, placeholder, toReplace)
@ -205,11 +221,17 @@ func (s *SQLBuilder) insertRawQuery(raw string, namedArg map[string]interface{})
s.WriteString(raw) s.WriteString(raw)
} }
func argToString(value interface{}) string { func (s *SQLBuilder) argToString(value interface{}) string {
if is.Nil(value) { if is.Nil(value) {
return "NULL" return "NULL"
} }
strVal, ok := s.Dialect.ArgumentToString(value)
if ok {
return strVal
}
switch bindVal := value.(type) { switch bindVal := value.(type) {
case bool: case bool:
if bindVal { if bindVal {
@ -246,7 +268,7 @@ func argToString(value interface{}) string {
return err.Error() return err.Error()
} }
return argToString(val) return s.argToString(val)
} }
panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String())) panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String()))

View file

@ -8,37 +8,39 @@ import (
) )
func TestArgToString(t *testing.T) { func TestArgToString(t *testing.T) {
require.Equal(t, argToString(true), "TRUE") s := &SQLBuilder{Dialect: defaultDialect, Debug: true}
require.Equal(t, argToString(false), "FALSE")
require.Equal(t, argToString(int(-32)), "-32") require.Equal(t, s.argToString(true), "TRUE")
require.Equal(t, argToString(uint(32)), "32") require.Equal(t, s.argToString(false), "FALSE")
require.Equal(t, argToString(int8(-43)), "-43")
require.Equal(t, argToString(uint8(43)), "43")
require.Equal(t, argToString(int16(-54)), "-54")
require.Equal(t, argToString(uint16(54)), "54")
require.Equal(t, argToString(int32(-65)), "-65")
require.Equal(t, argToString(uint32(65)), "65")
require.Equal(t, argToString(int64(-64)), "-64")
require.Equal(t, argToString(uint64(64)), "64")
require.Equal(t, argToString(float32(2.0)), "2")
require.Equal(t, argToString(float64(1.11)), "1.11")
require.Equal(t, argToString("john"), "'john'") require.Equal(t, s.argToString(int(-32)), "-32")
require.Equal(t, argToString("It's text"), "'It''s text'") require.Equal(t, s.argToString(uint(32)), "32")
require.Equal(t, argToString([]byte("john")), "'john'") require.Equal(t, s.argToString(int8(-43)), "-43")
require.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'") require.Equal(t, s.argToString(uint8(43)), "43")
require.Equal(t, s.argToString(int16(-54)), "-54")
require.Equal(t, s.argToString(uint16(54)), "54")
require.Equal(t, s.argToString(int32(-65)), "-65")
require.Equal(t, s.argToString(uint32(65)), "65")
require.Equal(t, s.argToString(int64(-64)), "-64")
require.Equal(t, s.argToString(uint64(64)), "64")
require.Equal(t, s.argToString(float32(2.0)), "2")
require.Equal(t, s.argToString(float64(1.11)), "1.11")
require.Equal(t, s.argToString("john"), "'john'")
require.Equal(t, s.argToString("It's text"), "'It''s text'")
require.Equal(t, s.argToString([]byte("john")), "'john'")
require.Equal(t, s.argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'")
time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006") time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'") require.Equal(t, s.argToString(time), "'2006-01-02 15:04:05-07:00'")
func() { func() {
defer func() { defer func() {
require.Equal(t, recover().(string), "jet: map[string]bool type can not be used as SQL query parameter") require.Equal(t, recover().(string), "jet: map[string]bool type can not be used as SQL query parameter")
}() }()
argToString(map[string]bool{}) s.argToString(map[string]bool{})
}() }()
} }

View file

@ -7,25 +7,38 @@ import (
"time" "time"
) )
// Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK) // Statement is a common interface for all SQL statements, including SELECT, SELECT_JSON_ARR, SELECT_JSON_OBJ, INSERT,
// UPDATE, DELETE, and LOCK.
type Statement interface { type Statement interface {
// Sql returns parametrized sql query with list of arguments. // Sql returns a parameterized SQL query along with its list of arguments.
Sql() (query string, args []interface{}) Sql() (query string, args []interface{})
// DebugSql returns debug query where every parametrized placeholder is replaced with its argument string representation.
// Do not use it in production. Use it only for debug purposes. // DebugSql returns a debug-friendly SQL query where all parameterized placeholders
// are replaced with their respective argument string representations.
//
// Warning: This method should only be used for debugging purposes.
// Do not use it in production, as it may lead to security risks such as SQL injection.
DebugSql() (query string) DebugSql() (query string)
// Query executes statement over database connection/transaction db and stores row results in destination.
// Destination can be either pointer to struct or pointer to a slice. // Query delegates call to QueryContext using context.Background() as parameter.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
Query(db qrm.Queryable, destination interface{}) error Query(db qrm.Queryable, destination interface{}) error
// QueryContext executes statement with a context over database connection/transaction db and stores row result in destination.
// Destination can be either pointer to struct or pointer to a slice. // QueryContext executes the statement with the provided context over a database connection or transaction (`db`),
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. // and stores the retrieved row results in the given destination.
//
// For statements of type SELECT, INSERT, UPDATE, or DELETE, the destination must be a pointer to either a struct or a slice.
// For SELECT_JSON_ARR statements, the destination must be a pointer to a slice of structs or a pointer to []map[string]any.
// For SELECT_JSON_OBJ statements, the destination must be a pointer to a struct or a pointer to map[string]any.
//
// If the destination is a pointer to a struct and the query returns no rows, QueryContext returns qrm.ErrNoRows.
QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error
// Exec executes statement over db connection/transaction without returning any rows.
// Exec delegates call to ExecContext using context.Background() as parameter.
Exec(db qrm.Executable) (sql.Result, error) Exec(db qrm.Executable) (sql.Result, error)
// ExecContext executes statement with context over db connection/transaction without returning any rows. // ExecContext executes statement with context over db connection/transaction without returning any rows.
ExecContext(ctx context.Context, db qrm.Executable) (sql.Result, error) ExecContext(ctx context.Context, db qrm.Executable) (sql.Result, error)
// Rows executes statements over db connection/transaction and returns rows // Rows executes statements over db connection/transaction and returns rows
Rows(ctx context.Context, db qrm.Queryable) (*Rows, error) Rows(ctx context.Context, db qrm.Queryable) (*Rows, error)
} }
@ -60,14 +73,14 @@ type SerializerHasProjections interface {
HasProjections HasProjections
} }
// serializerStatementInterfaceImpl struct // statementInterfaceImpl struct
type serializerStatementInterfaceImpl struct { type statementInterfaceImpl struct {
dialect Dialect dialect Dialect
statementType StatementType statementType StatementType
parent SerializerStatement parent SerializerStatement
} }
func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface{}) { func (s *statementInterfaceImpl) Sql() (query string, args []interface{}) {
queryData := &SQLBuilder{Dialect: s.dialect} queryData := &SQLBuilder{Dialect: s.dialect}
@ -77,7 +90,7 @@ func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface
return return
} }
func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { func (s *statementInterfaceImpl) DebugSql() (query string) {
sqlBuilder := &SQLBuilder{Dialect: s.dialect, Debug: true} sqlBuilder := &SQLBuilder{Dialect: s.dialect, Debug: true}
s.parent.serialize(s.statementType, sqlBuilder, NoWrap) s.parent.serialize(s.statementType, sqlBuilder, NoWrap)
@ -86,11 +99,27 @@ func (s *serializerStatementInterfaceImpl) DebugSql() (query string) {
return return
} }
func (s *serializerStatementInterfaceImpl) Query(db qrm.Queryable, destination interface{}) error { func (s *statementInterfaceImpl) Query(db qrm.Queryable, destination interface{}) error {
return s.QueryContext(context.Background(), db, destination) return s.QueryContext(context.Background(), db, destination)
} }
func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error { func (s *statementInterfaceImpl) QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error {
return s.query(ctx, func(query string, args []interface{}) (int64, error) {
switch s.statementType {
case SelectJsonObjStatementType:
return qrm.QueryJsonObj(ctx, db, query, args, destination)
case SelectJsonArrStatementType:
return qrm.QueryJsonArr(ctx, db, query, args, destination)
default:
return qrm.Query(ctx, db, query, args, destination)
}
})
}
func (s *statementInterfaceImpl) query(
ctx context.Context,
queryFunc func(query string, args []interface{}) (int64, error),
) error {
query, args := s.Sql() query, args := s.Sql()
callLogger(ctx, s) callLogger(ctx, s)
@ -99,7 +128,7 @@ func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db
var err error var err error
duration := duration(func() { duration := duration(func() {
rowsProcessed, err = qrm.Query(ctx, db, query, args, destination) rowsProcessed, err = queryFunc(query, args)
}) })
callQueryLoggerFunc(ctx, QueryInfo{ callQueryLoggerFunc(ctx, QueryInfo{
@ -112,11 +141,11 @@ func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db
return err return err
} }
func (s *serializerStatementInterfaceImpl) Exec(db qrm.Executable) (res sql.Result, err error) { func (s *statementInterfaceImpl) Exec(db qrm.Executable) (res sql.Result, err error) {
return s.ExecContext(context.Background(), db) return s.ExecContext(context.Background(), db)
} }
func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.Executable) (res sql.Result, err error) { func (s *statementInterfaceImpl) ExecContext(ctx context.Context, db qrm.Executable) (res sql.Result, err error) {
query, args := s.Sql() query, args := s.Sql()
callLogger(ctx, s) callLogger(ctx, s)
@ -141,7 +170,7 @@ func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db q
return res, err return res, err
} }
func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.Queryable) (*Rows, error) { func (s *statementInterfaceImpl) Rows(ctx context.Context, db qrm.Queryable) (*Rows, error) {
query, args := s.Sql() query, args := s.Sql()
callLogger(ctx, s) callLogger(ctx, s)
@ -191,11 +220,15 @@ type ExpressionStatement interface {
} }
// NewExpressionStatementImpl creates new expression statement // NewExpressionStatementImpl creates new expression statement
func NewExpressionStatementImpl(Dialect Dialect, statementType StatementType, parent ExpressionStatement, clauses ...Clause) ExpressionStatement { func NewExpressionStatementImpl(Dialect Dialect,
statementType StatementType,
parent ExpressionStatement,
clauses ...Clause) ExpressionStatement {
return &expressionStatementImpl{ return &expressionStatementImpl{
ExpressionInterfaceImpl{Parent: parent}, ExpressionInterfaceImpl{Parent: parent},
statementImpl{ statementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ statementInterfaceImpl: statementInterfaceImpl{
parent: parent, parent: parent,
dialect: Dialect, dialect: Dialect,
statementType: statementType, statementType: statementType,
@ -214,10 +247,14 @@ func (s *expressionStatementImpl) serializeForProjection(statement StatementType
s.serialize(statement, out) s.serialize(statement, out)
} }
func (e *expressionStatementImpl) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) {
panic("jet: SELECT JSON statements need to be aliased when used as a projection.")
}
// NewStatementImpl creates new statementImpl // NewStatementImpl creates new statementImpl
func NewStatementImpl(Dialect Dialect, statementType StatementType, parent SerializerStatement, clauses ...Clause) SerializerStatement { func NewStatementImpl(Dialect Dialect, statementType StatementType, parent SerializerStatement, clauses ...Clause) SerializerStatement {
return &statementImpl{ return &statementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ statementInterfaceImpl: statementInterfaceImpl{
parent: parent, parent: parent,
dialect: Dialect, dialect: Dialect,
statementType: statementType, statementType: statementType,
@ -227,7 +264,7 @@ func NewStatementImpl(Dialect Dialect, statementType StatementType, parent Seria
} }
type statementImpl struct { type statementImpl struct {
serializerStatementInterfaceImpl statementInterfaceImpl
Clauses []Clause Clauses []Clause
} }

View file

@ -3,6 +3,7 @@ package jet
// StringExpression interface // StringExpression interface
type StringExpression interface { type StringExpression interface {
Expression Expression
isStringOrBlob()
EQ(rhs StringExpression) BoolExpression EQ(rhs StringExpression) BoolExpression
NOT_EQ(rhs StringExpression) BoolExpression NOT_EQ(rhs StringExpression) BoolExpression
@ -29,6 +30,8 @@ type stringInterfaceImpl struct {
parent StringExpression parent StringExpression
} }
func (s *stringInterfaceImpl) isStringOrBlob() {}
func (s *stringInterfaceImpl) EQ(rhs StringExpression) BoolExpression { func (s *stringInterfaceImpl) EQ(rhs StringExpression) BoolExpression {
return Eq(s.parent, rhs) return Eq(s.parent, rhs)
} }
@ -102,9 +105,10 @@ type stringExpressionWrapper struct {
} }
func newStringExpressionWrap(expression Expression) StringExpression { func newStringExpressionWrap(expression Expression) StringExpression {
stringExpressionWrap := stringExpressionWrapper{Expression: expression} stringExpressionWrap := &stringExpressionWrapper{Expression: expression}
stringExpressionWrap.stringInterfaceImpl.parent = &stringExpressionWrap stringExpressionWrap.stringInterfaceImpl.parent = stringExpressionWrap
return &stringExpressionWrap expression.setParent(stringExpressionWrap)
return stringExpressionWrap
} }
// StringExp is string expression wrapper around arbitrary expression. // StringExp is string expression wrapper around arbitrary expression.

View file

@ -0,0 +1,8 @@
package jet
// StringOrBlobExpression is common interface for all string and blob expressions
type StringOrBlobExpression interface {
Expression
isStringOrBlob()
}

View file

@ -12,6 +12,9 @@ var defaultDialect = NewDialect(DialectParams{ // just for tests
ArgumentPlaceholder: func(ord int) string { ArgumentPlaceholder: func(ord int) string {
return "$" + strconv.Itoa(ord) return "$" + strconv.Itoa(ord)
}, },
ArgumentToString: func(value any) (string, bool) {
return "", false
},
}) })
var ( var (

View file

@ -75,14 +75,15 @@ func (t *timeInterfaceImpl) SUB(rhs Interval) TimeExpression {
//---------------------------------------------------// //---------------------------------------------------//
type timeExpressionWrapper struct { type timeExpressionWrapper struct {
timeInterfaceImpl
Expression Expression
timeInterfaceImpl
} }
func newTimeExpressionWrap(expression Expression) TimeExpression { func newTimeExpressionWrap(expression Expression) TimeExpression {
timeExpressionWrap := timeExpressionWrapper{Expression: expression} timeExpressionWrap := &timeExpressionWrapper{Expression: expression}
timeExpressionWrap.timeInterfaceImpl.parent = &timeExpressionWrap timeExpressionWrap.timeInterfaceImpl.parent = timeExpressionWrap
return &timeExpressionWrap expression.setParent(timeExpressionWrap)
return timeExpressionWrap
} }
// TimeExp is time expression wrapper around arbitrary expression. // TimeExp is time expression wrapper around arbitrary expression.

View file

@ -52,11 +52,3 @@ func TestTimeExp(t *testing.T) {
assertClauseSerialize(t, TimeExp(table1ColFloat).LT(Time(1, 1, 1, 1*time.Millisecond)), assertClauseSerialize(t, TimeExp(table1ColFloat).LT(Time(1, 1, 1, 1*time.Millisecond)),
"(table1.col_float < $1)", string("01:01:01.001")) "(table1.col_float < $1)", string("01:01:01.001"))
} }
func TestTimeArithmetic(t *testing.T) {
time := Time(10, 20, 3)
assertClauseDebugSerialize(t, table1ColTime.ADD(NewInterval(String("1 HOUR"))).EQ(time),
"((table1.col_time + INTERVAL '1 HOUR') = '10:20:03')")
assertClauseDebugSerialize(t, table1ColTime.SUB(NewInterval(String("1 HOUR"))).EQ(time),
"((table1.col_time - INTERVAL '1 HOUR') = '10:20:03')")
}

View file

@ -80,9 +80,10 @@ type timestampExpressionWrapper struct {
} }
func newTimestampExpressionWrap(expression Expression) TimestampExpression { func newTimestampExpressionWrap(expression Expression) TimestampExpression {
timestampExpressionWrap := timestampExpressionWrapper{Expression: expression} timestampExpressionWrap := &timestampExpressionWrapper{Expression: expression}
timestampExpressionWrap.timestampInterfaceImpl.parent = &timestampExpressionWrap timestampExpressionWrap.timestampInterfaceImpl.parent = timestampExpressionWrap
return &timestampExpressionWrap expression.setParent(timestampExpressionWrap)
return timestampExpressionWrap
} }
// TimestampExp is timestamp expression wrapper around arbitrary expression. // TimestampExp is timestamp expression wrapper around arbitrary expression.

View file

@ -53,11 +53,3 @@ func TestTimestampExp(t *testing.T) {
assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp), assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp),
"(table1.col_float < $1)", "2000-01-31 10:20:00.003") "(table1.col_float < $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampArithmetic(t *testing.T) {
timestamp := Timestamp(2000, 1, 1, 0, 0, 0)
assertClauseDebugSerialize(t, table1ColTimestamp.ADD(NewInterval(String("1 HOUR"))).EQ(timestamp),
"((table1.col_timestamp + INTERVAL '1 HOUR') = '2000-01-01 00:00:00')")
assertClauseDebugSerialize(t, table1ColTimestamp.SUB(NewInterval(String("1 HOUR"))).EQ(timestamp),
"((table1.col_timestamp - INTERVAL '1 HOUR') = '2000-01-01 00:00:00')")
}

View file

@ -80,9 +80,10 @@ type timestampzExpressionWrapper struct {
} }
func newTimestampzExpressionWrap(expression Expression) TimestampzExpression { func newTimestampzExpressionWrap(expression Expression) TimestampzExpression {
timestampzExpressionWrap := timestampzExpressionWrapper{Expression: expression} timestampzExpressionWrap := &timestampzExpressionWrapper{Expression: expression}
timestampzExpressionWrap.timestampzInterfaceImpl.parent = &timestampzExpressionWrap timestampzExpressionWrap.timestampzInterfaceImpl.parent = timestampzExpressionWrap
return &timestampzExpressionWrap expression.setParent(timestampzExpressionWrap)
return timestampzExpressionWrap
} }
// TimestampzExp is timestamp with time zone expression wrapper around arbitrary expression. // TimestampzExp is timestamp with time zone expression wrapper around arbitrary expression.

View file

@ -53,11 +53,3 @@ func TestTimestampzExp(t *testing.T) {
assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz), assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz),
"(table1.col_float < $1)", "2000-01-31 10:20:05.000023 +200") "(table1.col_float < $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzArithmetic(t *testing.T) {
timestampz := Timestampz(2000, 1, 1, 0, 0, 0, 100, "UTC")
assertClauseDebugSerialize(t, table1ColTimestampz.ADD(NewInterval(String("1 HOUR"))).EQ(timestampz),
"((table1.col_timestampz + INTERVAL '1 HOUR') = '2000-01-01 00:00:00.0000001 UTC')")
assertClauseDebugSerialize(t, table1ColTimestampz.SUB(NewInterval(String("1 HOUR"))).EQ(timestampz),
"((table1.col_timestampz - INTERVAL '1 HOUR') = '2000-01-01 00:00:00.0000001 UTC')")
}

View file

@ -75,14 +75,15 @@ func (t *timezInterfaceImpl) SUB(rhs Interval) TimezExpression {
//---------------------------------------------------// //---------------------------------------------------//
type timezExpressionWrapper struct { type timezExpressionWrapper struct {
timezInterfaceImpl
Expression Expression
timezInterfaceImpl
} }
func newTimezExpressionWrap(expression Expression) TimezExpression { func newTimezExpressionWrap(expression Expression) TimezExpression {
timezExpressionWrap := timezExpressionWrapper{Expression: expression} timezExpressionWrap := &timezExpressionWrapper{Expression: expression}
timezExpressionWrap.timezInterfaceImpl.parent = &timezExpressionWrap timezExpressionWrap.timezInterfaceImpl.parent = timezExpressionWrap
return &timezExpressionWrap expression.setParent(timezExpressionWrap)
return timezExpressionWrap
} }
// TimezExp is time with time zone expression wrapper around arbitrary expression. // TimezExp is time with time zone expression wrapper around arbitrary expression.

View file

@ -51,11 +51,3 @@ func TestTimezExp(t *testing.T) {
assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, "+4:00")), assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, "+4:00")),
"(table1.col_float < $1)", string("01:01:01.000000001 +4:00")) "(table1.col_float < $1)", string("01:01:01.000000001 +4:00"))
} }
func TestTimezArithmetic(t *testing.T) {
timez := Timez(0, 0, 0, 100, "UTC")
assertClauseDebugSerialize(t, table1ColTimez.ADD(NewInterval(String("1 HOUR"))).EQ(timez),
"((table1.col_timez + INTERVAL '1 HOUR') = '00:00:00.0000001 UTC')")
assertClauseDebugSerialize(t, table1ColTimez.SUB(NewInterval(String("1 HOUR"))).EQ(timez),
"((table1.col_timez - INTERVAL '1 HOUR') = '00:00:00.0000001 UTC')")
}

View file

@ -58,6 +58,23 @@ func SerializeProjectionList(statement StatementType, projections []Projection,
} }
} }
// SerializeProjectionListJsonObj serializes a list of projections for JSON object
func SerializeProjectionListJsonObj(statement StatementType, projections []Projection, out *SQLBuilder) {
for i, p := range projections {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
if p == nil {
panic("jet: Projection is nil")
}
p.serializeForJsonObjEntry(statement, out)
}
}
// SerializeColumnNames func // SerializeColumnNames func
func SerializeColumnNames(columns []Column, out *SQLBuilder) { func SerializeColumnNames(columns []Column, out *SQLBuilder) {
for i, col := range columns { for i, col := range columns {
@ -115,8 +132,8 @@ func ExpressionListToSerializerList(expressions []Expression) []Serializer {
return ret return ret
} }
// BoolExpressionListToExpressionList converts list of bool expressions to list of expressions // ToExpressionList converts list of any expressions to list of expressions
func BoolExpressionListToExpressionList(expressions []BoolExpression) []Expression { func ToExpressionList[T Expression](expressions []T) []Expression {
var ret []Expression var ret []Expression
for _, expression := range expressions { for _, expression := range expressions {

View file

@ -7,7 +7,7 @@ func WITH(dialect Dialect, recursive bool, cte ...*CommonTableExpression) func(s
newWithImpl := &withImpl{ newWithImpl := &withImpl{
recursive: recursive, recursive: recursive,
ctes: cte, ctes: cte,
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ statementInterfaceImpl: statementInterfaceImpl{
dialect: dialect, dialect: dialect,
statementType: WithStatementType, statementType: WithStatementType,
}, },
@ -25,7 +25,7 @@ func WITH(dialect Dialect, recursive bool, cte ...*CommonTableExpression) func(s
} }
type withImpl struct { type withImpl struct {
serializerStatementInterfaceImpl statementInterfaceImpl
recursive bool recursive bool
ctes []*CommonTableExpression ctes []*CommonTableExpression
primaryStatement SerializerStatement primaryStatement SerializerStatement

View file

@ -115,6 +115,16 @@ func AssertJSON(t *testing.T, data interface{}, expectedJSON string) {
require.Equal(t, dataJson, expectedJSON) require.Equal(t, dataJson, expectedJSON)
} }
// AssertJsonEqual checks if actual and expected json representation are the same
func AssertJsonEqual(t require.TestingT, actual, expected interface{}, option ...cmp.Option) {
actualJsonData, err := json.MarshalIndent(actual, "", "\t")
require.NoError(t, err)
expectedJsonData, err := json.MarshalIndent(expected, "", "\t")
require.NoError(t, err)
require.Equal(t, string(actualJsonData), string(expectedJsonData))
}
// SaveJSONFile saves v as json at testRelativePath // SaveJSONFile saves v as json at testRelativePath
// nolint:unused // nolint:unused
func SaveJSONFile(v interface{}, testRelativePath string) { func SaveJSONFile(v interface{}, testRelativePath string) {
@ -127,7 +137,10 @@ func SaveJSONFile(v interface{}, testRelativePath string) {
} }
// AssertJSONFile check if data json representation is the same as json at testRelativePath // AssertJSONFile check if data json representation is the same as json at testRelativePath
func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) { func AssertJSONFile(t require.TestingT, data interface{}, testRelativePath string) {
if _, ok := t.(*testing.B); ok {
return // skip assert for benchmarks
}
filePath := getFullPath(testRelativePath) filePath := getFullPath(testRelativePath)
fileJSONData, err := os.ReadFile(filePath) // #nosec G304 fileJSONData, err := os.ReadFile(filePath) // #nosec G304
@ -145,7 +158,11 @@ func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) {
} }
// AssertStatementSql check if statement Sql() is the same as expectedQuery and expectedArgs // AssertStatementSql check if statement Sql() is the same as expectedQuery and expectedArgs
func AssertStatementSql(t *testing.T, query jet.PrintableStatement, expectedQuery string, expectedArgs ...interface{}) { func AssertStatementSql(t require.TestingT, query jet.PrintableStatement, expectedQuery string, expectedArgs ...interface{}) {
if _, ok := t.(*testing.B); ok {
return // skip assert for benchmarks
}
queryStr, args := query.Sql() queryStr, args := query.Sql()
assertQueryString(t, queryStr, expectedQuery) assertQueryString(t, queryStr, expectedQuery)
@ -283,14 +300,14 @@ func AssertFileNamesEqual(t *testing.T, dirPath string, fileNames ...string) {
} }
// AssertDeepEqual checks if actual and expected objects are deeply equal. // AssertDeepEqual checks if actual and expected objects are deeply equal.
func AssertDeepEqual(t *testing.T, actual, expected interface{}, option ...cmp.Option) { func AssertDeepEqual(t require.TestingT, actual, expected interface{}, option ...cmp.Option) {
if !assert.True(t, cmp.Equal(actual, expected, option...)) { if !assert.True(t, cmp.Equal(actual, expected, option...)) {
printDiff(actual, expected, option...) printDiff(actual, expected, option...)
t.FailNow() t.FailNow()
} }
} }
func assertQueryString(t *testing.T, actual, expected string) { func assertQueryString(t require.TestingT, actual, expected string) {
if !assert.Equal(t, actual, expected) { if !assert.Equal(t, actual, expected) {
printDiff(actual, expected) printDiff(actual, expected)
t.FailNow() t.FailNow()

View file

@ -1,6 +1,9 @@
package datetime package datetime
import "time" import (
//"github.com/go-jet/jet/v2/internal/utils/min"
"time"
)
// ExtractTimeComponents extracts number of days, hours, minutes, seconds, microseconds from duration // ExtractTimeComponents extracts number of days, hours, minutes, seconds, microseconds from duration
func ExtractTimeComponents(duration time.Duration) (days, hours, minutes, seconds, microseconds int64) { func ExtractTimeComponents(duration time.Duration) (days, hours, minutes, seconds, microseconds int64) {
@ -20,3 +23,36 @@ func ExtractTimeComponents(duration time.Duration) (days, hours, minutes, second
return return
} }
// TryParseAsTime attempts to parse the provided value as a time using one of the given formats.
//
// The function iterates over the provided formats and tries to parse the value into a time.Time object.
// It returns the parsed time and a boolean indicating whether the parsing was successful.
func TryParseAsTime(value interface{}, formats []string) (time.Time, bool) {
var timeStr string
switch v := value.(type) {
case string:
timeStr = v
case []byte:
timeStr = string(v)
case int64:
return time.Unix(v, 0), true // sqlite
default:
return time.Time{}, false
}
for _, format := range formats {
formatLen := min(len(format), len(timeStr))
t, err := time.Parse(format[:formatLen], timeStr)
if err != nil {
continue
}
return t, true
}
return time.Time{}, false
}

View file

@ -1,6 +1,8 @@
package is package is
import "reflect" import (
"reflect"
)
// Nil check if v is nil // Nil check if v is nil
func Nil(v interface{}) bool { func Nil(v interface{}) bool {

View file

@ -1,9 +0,0 @@
package min
// Int returns minimum of two int values
func Int(a, b int) int {
if a < b {
return a
}
return b
}

View file

@ -70,6 +70,6 @@ func (c *cast) AS_TIME() TimeExpression {
} }
// AS_BINARY casts expression as BINARY type // AS_BINARY casts expression as BINARY type
func (c *cast) AS_BINARY() StringExpression { func (c *cast) AS_BINARY() BlobExpression {
return StringExp(c.AS("BINARY")) return BlobExp(c.AS("BINARY"))
} }

View file

@ -21,6 +21,12 @@ type ColumnString = jet.ColumnString
// StringColumn creates named string column. // StringColumn creates named string column.
var StringColumn = jet.StringColumn var StringColumn = jet.StringColumn
// ColumnBlob is interface for blob columns.
type ColumnBlob = jet.ColumnBlob
// BlobColumn creates named blob column.
var BlobColumn = jet.BlobColumn
// ColumnInteger is interface for SQL smallint, integer, bigint columns. // ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger = jet.ColumnInteger type ColumnInteger = jet.ColumnInteger

View file

@ -1,11 +1,12 @@
package mysql package mysql
import ( import (
"encoding/hex"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/jet"
) )
// Dialect is implementation of MySQL dialect for SQL Builder serialisation. // Dialect is implementation of MySQL dialect for SQL Builder serialization.
var Dialect = newDialect() var Dialect = newDialect()
func newDialect() jet.Dialect { func newDialect() jet.Dialect {
@ -27,16 +28,43 @@ func newDialect() jet.Dialect {
ArgumentPlaceholder: func(int) string { ArgumentPlaceholder: func(int) string {
return "?" return "?"
}, },
ArgumentToString: argumentToString,
ReservedWords: reservedWords, ReservedWords: reservedWords,
SerializeOrderBy: serializeOrderBy, SerializeOrderBy: serializeOrderBy,
ValuesDefaultColumnName: func(index int) string { ValuesDefaultColumnName: func(index int) string {
return fmt.Sprintf("column_%d", index) return fmt.Sprintf("column_%d", index)
}, },
JsonValueEncode: func(expr Expression) Expression {
switch e := expr.(type) {
case BlobExpression:
return TO_BASE64(e)
// CustomExpression used bellow (instead DATE_FORMAT function) so that only expr is parametrized
case TimestampExpression:
return CustomExpression(Token("DATE_FORMAT("), e, Token(",'%Y-%m-%dT%H:%i:%s.%fZ')"))
case TimeExpression:
return CustomExpression(Token("CONCAT('0000-01-01T', DATE_FORMAT("), e, Token(",'%H:%i:%s.%fZ'))"))
case DateExpression:
return CustomExpression(Token("CONCAT(DATE_FORMAT("), e, Token(",'%Y-%m-%d')"), Token(", 'T00:00:00Z')"))
case BoolExpression:
return CustomExpression(e, Token(" = 1"))
}
return expr
},
} }
return jet.NewDialect(mySQLDialectParams) return jet.NewDialect(mySQLDialectParams)
} }
func argumentToString(value any) (string, bool) {
switch bindVal := value.(type) {
case []byte:
return fmt.Sprintf("X'%s'", hex.EncodeToString(bindVal)), true
}
return "", false
}
func mysqlBitXor(expressions ...jet.Serializer) jet.SerializerFunc { func mysqlBitXor(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {

View file

@ -12,6 +12,9 @@ type BoolExpression = jet.BoolExpression
// StringExpression interface // StringExpression interface
type StringExpression = jet.StringExpression type StringExpression = jet.StringExpression
// BlobExpression interface
type BlobExpression = jet.BlobExpression
// IntegerExpression interface // IntegerExpression interface
type IntegerExpression = jet.IntegerExpression type IntegerExpression = jet.IntegerExpression
@ -43,6 +46,11 @@ var BoolExp = jet.BoolExp
// Does not add sql cast to generated sql builder output. // Does not add sql cast to generated sql builder output.
var StringExp = jet.StringExp var StringExp = jet.StringExp
// BlobExp is blob expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as blob expression.
// Does not add sql cast to generated sql builder output.
var BlobExp = jet.BlobExp
// IntExp is int expression wrapper around arbitrary expression. // IntExp is int expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as int expression. // Allows go compiler to see any expression as int expression.
// Does not add sql cast to generated sql builder output. // Does not add sql cast to generated sql builder output.
@ -100,6 +108,7 @@ var (
RawTime = jet.RawTime RawTime = jet.RawTime
RawTimestamp = jet.RawTimestamp RawTimestamp = jet.RawTimestamp
RawDate = jet.RawDate RawDate = jet.RawDate
RawBlob = jet.RawBlob
) )
// Func can be used to call custom or unsupported database functions. // Func can be used to call custom or unsupported database functions.

View file

@ -148,6 +148,14 @@ var NTH_VALUE = jet.NTH_VALUE
//--------------------- String functions ------------------// //--------------------- String functions ------------------//
// HEX function in MySQL takes an input and returns its equivalent hexadecimal representation
var HEX = jet.HEX
// UNHEX for a string argument str, UNHEX(str) interprets each pair of characters in the argument
// as a hexadecimal number and converts it to the byte represented by the number.
// The return value is a binary string.
var UNHEX = jet.UNHEX
// BIT_LENGTH returns number of bits in string expression // BIT_LENGTH returns number of bits in string expression
var BIT_LENGTH = jet.BIT_LENGTH var BIT_LENGTH = jet.BIT_LENGTH
@ -157,6 +165,23 @@ var CHAR_LENGTH = jet.CHAR_LENGTH
// OCTET_LENGTH returns number of bytes in string expression // OCTET_LENGTH returns number of bytes in string expression
var OCTET_LENGTH = jet.OCTET_LENGTH var OCTET_LENGTH = jet.OCTET_LENGTH
// ELT returns the Nth element of the list of strings: str1 if N = 1, str2 if N = 2, and so on.
// Returns NULL if N is less than 1, greater than the number of arguments, or NULL.
func ELT(n IntegerExpression, list ...StringExpression) StringExpression {
args := []Expression{n}
args = append(args, jet.ToExpressionList(list)...)
return StringExp(Func("ELT", args...))
}
// FIELD returns the index (position) of str in the str1, str2, str3, ... list. Returns 0 if str is not found.
func FIELD(str StringExpression, list ...StringExpression) StringExpression {
args := []Expression{str}
args = append(args, jet.ToExpressionList(list)...)
return StringExp(Func("FIELD", args...))
}
// LOWER returns string expression in lower case // LOWER returns string expression in lower case
var LOWER = jet.LOWER var LOWER = jet.LOWER
@ -178,7 +203,35 @@ var CONCAT = jet.CONCAT
var CONCAT_WS = jet.CONCAT_WS var CONCAT_WS = jet.CONCAT_WS
// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. // FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string.
var FORMAT = jet.FORMAT func FORMAT(number jet.NumericExpression, decimals IntegerExpression, optionalLocale ...StringExpression) StringExpression {
if len(optionalLocale) > 0 {
return StringExp(Func("FORMAT", number, decimals, optionalLocale[0]))
}
return StringExp(Func("FORMAT", number, decimals))
}
// TO_BASE64 converts the string argument to base-64 encoded form and returns the
// result as a character string with the connection character set and collation.
func TO_BASE64(data jet.StringOrBlobExpression) StringExpression {
return StringExp(Func("TO_BASE64", data))
}
// FROM_BASE64 takes a string encoded with the base-64 encoded rules used by TO_BASE64()
// and returns the decoded result as a binary string.
func FROM_BASE64(data StringExpression) BlobExpression {
return BlobExp(Func("FROM_BASE64", data))
}
// CHARSET returns the character set of the string argument, or NULL if the argument is NULL.
func CHARSET(exp Expression) StringExpression {
return StringExp(Func("CHARSET", exp))
}
// COLLATION returns the collation of the string argument.
func COLLATION(exp Expression) StringExpression {
return StringExp(Func("COLLATION ", exp))
}
// LEFT returns first n characters in the string. // LEFT returns first n characters in the string.
// When n is negative, return all but last |n| characters. // When n is negative, return all but last |n| characters.
@ -189,7 +242,7 @@ var LEFT = jet.LEFT
var RIGHT = jet.RIGHT var RIGHT = jet.RIGHT
// LENGTH returns number of characters in string with a given encoding // LENGTH returns number of characters in string with a given encoding
func LENGTH(str jet.StringExpression) jet.StringExpression { func LENGTH(str jet.StringOrBlobExpression) jet.IntegerExpression {
return jet.LENGTH(str) return jet.LENGTH(str)
} }

View file

@ -98,13 +98,10 @@ func INTERVAL(value interface{}, unitType unitType) Interval {
// INTERVALe creates new temporal interval from expresion and unit type. // INTERVALe creates new temporal interval from expresion and unit type.
func INTERVALe(expr Expression, unitType unitType) Interval { func INTERVALe(expr Expression, unitType unitType) Interval {
return jet.NewInterval(jet.ListSerializer{ return jet.IntervalExp(CustomExpression(Token("INTERVAL"), expr, Token(unitType)))
Serializers: []jet.Serializer{expr, jet.RawWithParent(string(unitType))},
Separator: " ",
})
} }
// INTERVALd temoral interval from time.Duration // INTERVALd creates new temporal interval from time.Duration
func INTERVALd(duration time.Duration) Interval { func INTERVALd(duration time.Duration) Interval {
var sign int64 = 1 var sign int64 = 1
if duration < 0 { if duration < 0 {

View file

@ -56,6 +56,11 @@ var String = jet.String
// value can be any uuid type with a String method // value can be any uuid type with a String method
var UUID = jet.UUID var UUID = jet.UUID
// Blob creates new blob literal expression
func Blob(data []byte) BlobExpression {
return BlobExp(jet.Literal(data))
}
// Date creates new date literal // Date creates new date literal
func Date(year int, month time.Month, day int) DateExpression { func Date(year int, month time.Month, day int) DateExpression {
return CAST(jet.Date(year, month, day)).AS_DATE() return CAST(jet.Date(year, month, day)).AS_DATE()

79
mysql/select_json.go Normal file
View file

@ -0,0 +1,79 @@
package mysql
import (
"github.com/go-jet/jet/v2/internal/jet"
)
// SelectJsonStatement is an interface for MySQL statements that generate JSON on the server.
type SelectJsonStatement interface {
Statement
jet.Serializer
AS(alias string) Projection
FROM(table ReadableTable) SelectJsonStatement
WHERE(condition BoolExpression) SelectJsonStatement
ORDER_BY(orderByClauses ...OrderByClause) SelectJsonStatement
LIMIT(limit int64) SelectJsonStatement
OFFSET(offset int64) SelectJsonStatement
}
// SELECT_JSON_ARR creates a new SelectJsonStatement with a list of projections.
func SELECT_JSON_ARR(projections ...Projection) SelectJsonStatement {
return newSelectStatementJson(projections, jet.SelectJsonArrStatementType)
}
// SELECT_JSON_OBJ creates a new SelectJsonStatement with a list of projections.
func SELECT_JSON_OBJ(projections ...Projection) SelectJsonStatement {
return newSelectStatementJson(projections, jet.SelectJsonObjStatementType)
}
type selectJsonStatement struct {
*selectStatementImpl
}
func newSelectStatementJson(projections []Projection, statementType jet.StatementType) SelectJsonStatement {
newSelect := &selectJsonStatement{
selectStatementImpl: newSelectStatement(statementType, nil, nil),
}
newSelect.Select.ProjectionList = ProjectionList{constructJsonFunc(projections, statementType).AS("json")}
return newSelect
}
func constructJsonFunc(projections []Projection, statementType jet.StatementType) Expression {
jsonObj := Func("JSON_OBJECT", CustomExpression(jet.JsonObjProjectionList(projections)))
if statementType == jet.SelectJsonArrStatementType {
return Func("JSON_ARRAYAGG", jsonObj)
}
return jsonObj
}
func (s *selectJsonStatement) FROM(table ReadableTable) SelectJsonStatement {
s.From.Tables = []jet.Serializer{table}
return s
}
func (s *selectJsonStatement) WHERE(condition BoolExpression) SelectJsonStatement {
s.Where.Condition = condition
return s
}
func (s *selectJsonStatement) ORDER_BY(orderByClauses ...OrderByClause) SelectJsonStatement {
s.OrderBy.List = orderByClauses
return s
}
func (s *selectJsonStatement) LIMIT(limit int64) SelectJsonStatement {
s.Limit.Count = limit
return s
}
func (s *selectJsonStatement) OFFSET(offset int64) SelectJsonStatement {
s.Offset.Count = Int(offset)
return s
}

View file

@ -62,12 +62,12 @@ type SelectStatement interface {
// SELECT creates new SelectStatement with list of projections // SELECT creates new SelectStatement with list of projections
func SELECT(projection Projection, projections ...Projection) SelectStatement { func SELECT(projection Projection, projections ...Projection) SelectStatement {
return newSelectStatement(nil, append([]Projection{projection}, projections...)) return newSelectStatement(jet.SelectStatementType, nil, append([]Projection{projection}, projections...))
} }
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { func newSelectStatement(stmtType jet.StatementType, table ReadableTable, projections []Projection) *selectStatementImpl {
newSelect := &selectStatementImpl{} newSelect := &selectStatementImpl{}
newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, stmtType, newSelect,
&newSelect.Select, &newSelect.Select,
&newSelect.From, &newSelect.From,
&newSelect.Where, &newSelect.Where,

View file

@ -50,7 +50,7 @@ type readableTableInterfaceImpl struct {
// Generates a select query on the current tableName. // Generates a select query on the current tableName.
func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement {
return newSelectStatement(r.parent, append([]Projection{projection1}, projections...)) return newSelectStatement(jet.SelectStatementType, r.parent, append([]Projection{projection1}, projections...))
} }
// Creates a inner join tableName Expression using onCondition. // Creates a inner join tableName Expression using onCondition.

View file

@ -101,9 +101,9 @@ func (b *cast) AS_DECIMAL() FloatExpression {
return FloatExp(b.AS("decimal")) return FloatExp(b.AS("decimal"))
} }
// AS_BYTEA casts expression AS text type // AS_BYTEA casts expression AS bytea type
func (b *cast) AS_BYTEA() StringExpression { func (b *cast) AS_BYTEA() ByteaExpression {
return StringExp(b.AS("bytea")) return ByteaExp(b.AS("bytea"))
} }
// AS_TIME casts expression AS date type // AS_TIME casts expression AS date type

View file

@ -23,6 +23,12 @@ type ColumnString = jet.ColumnString
// StringColumn creates named string column. // StringColumn creates named string column.
var StringColumn = jet.StringColumn var StringColumn = jet.StringColumn
// ColumnBytea is interface for bytea columns
type ColumnBytea = jet.ColumnBlob
// ByteaColumn creates new named bytea column.
var ByteaColumn = jet.BlobColumn
// ColumnInteger is interface for SQL smallint, integer, bigint columns. // ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger = jet.ColumnInteger type ColumnInteger = jet.ColumnInteger
@ -65,6 +71,12 @@ type ColumnTimestampz = jet.ColumnTimestampz
// TimestampzColumn creates named timestamp with time zone column. // TimestampzColumn creates named timestamp with time zone column.
var TimestampzColumn = jet.TimestampzColumn var TimestampzColumn = jet.TimestampzColumn
// ColumnInterval is interface of PostgreSQL interval columns.
type ColumnInterval = jet.ColumnInterval
// IntervalColumn creates named interval column
var IntervalColumn = jet.IntervalColumn
// ColumnDateRange is interface of SQL date range column // ColumnDateRange is interface of SQL date range column
type ColumnDateRange = jet.ColumnRange[DateExpression] type ColumnDateRange = jet.ColumnRange[DateExpression]
@ -100,41 +112,3 @@ type ColumnInt8Range jet.ColumnRange[jet.Int8Expression]
// Int8RangeColumn creates named range with range column // Int8RangeColumn creates named range with range column
var Int8RangeColumn = jet.RangeColumn[jet.Int8Expression] var Int8RangeColumn = jet.RangeColumn[jet.Int8Expression]
//------------------------------------------------------//
// ColumnInterval is interface of PostgreSQL interval columns.
type ColumnInterval interface {
IntervalExpression
jet.Column
From(subQuery SelectTable) ColumnInterval
SET(intervalExp IntervalExpression) ColumnAssigment
}
//------------------------------------------------------//
type intervalColumnImpl struct {
jet.ColumnExpressionImpl
intervalInterfaceImpl
}
func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment {
return jet.NewColumnAssignment(i, intervalExp)
}
func (i *intervalColumnImpl) From(subQuery SelectTable) ColumnInterval {
newIntervalColumn := IntervalColumn(i.Name())
jet.SetTableName(newIntervalColumn, i.TableName())
jet.SetSubQuery(newIntervalColumn, subQuery)
return newIntervalColumn
}
// IntervalColumn creates named interval column.
func IntervalColumn(name string) ColumnInterval {
intervalColumn := &intervalColumnImpl{}
intervalColumn.ColumnExpressionImpl = jet.NewColumnImpl(name, "", intervalColumn)
intervalColumn.intervalInterfaceImpl.parent = intervalColumn
return intervalColumn
}

View file

@ -1,10 +1,10 @@
package postgres package postgres
import ( import (
"encoding/hex"
"fmt" "fmt"
"strconv"
"github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/jet"
"strconv"
) )
// Dialect is implementation of postgres dialect for SQL Builder serialisation. // Dialect is implementation of postgres dialect for SQL Builder serialisation.
@ -26,15 +26,42 @@ func newDialect() jet.Dialect {
ArgumentPlaceholder: func(ord int) string { ArgumentPlaceholder: func(ord int) string {
return "$" + strconv.Itoa(ord) return "$" + strconv.Itoa(ord)
}, },
ArgumentToString: argumentToString,
ReservedWords: reservedWords, ReservedWords: reservedWords,
ValuesDefaultColumnName: func(index int) string { ValuesDefaultColumnName: func(index int) string {
return fmt.Sprintf("column%d", index+1) return fmt.Sprintf("column%d", index+1)
}, },
JsonValueEncode: func(expr Expression) Expression {
switch e := expr.(type) {
case ByteaExpression:
return ENCODE(e, Base64)
// CustomExpression used bellow (instead TO_CHAR function) so that only expr is parametrized
case TimeExpression:
return CustomExpression(Token("'0000-01-01T' || to_char('2000-10-10'::date + "), e, Token(`, 'HH24:MI:SS.USZ')`))
case TimezExpression:
return CustomExpression(Token("'0000-01-01T' || to_char('2000-10-10'::date + "), e, Token(`, 'HH24:MI:SS.USTZH:TZM')`))
case TimestampExpression:
return CustomExpression(Token("to_char("), e, Token(`, 'YYYY-MM-DD"T"HH24:MI:SS.USZ')`))
case DateExpression:
return CustomExpression(Token("to_char("), e, Token(`::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z'`))
}
return expr
},
} }
return jet.NewDialect(dialectParams) return jet.NewDialect(dialectParams)
} }
func argumentToString(value any) (string, bool) {
switch bindVal := value.(type) {
case []byte:
return fmt.Sprintf("'\\x%s'", hex.EncodeToString(bindVal)), true
}
return "", false
}
func postgresCAST(expressions ...jet.Serializer) jet.SerializerFunc { func postgresCAST(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {

View file

@ -12,6 +12,8 @@ type BoolExpression = jet.BoolExpression
// StringExpression interface // StringExpression interface
type StringExpression = jet.StringExpression type StringExpression = jet.StringExpression
type ByteaExpression = jet.BlobExpression
// NumericExpression interface // NumericExpression interface
type NumericExpression = jet.NumericExpression type NumericExpression = jet.NumericExpression
@ -39,6 +41,9 @@ type TimestampzExpression = jet.TimestampzExpression
// RowExpression interface // RowExpression interface
type RowExpression = jet.RowExpression type RowExpression = jet.RowExpression
// IntervalExpression interface
type IntervalExpression = jet.IntervalExpression
// DateRange Expression interface // DateRange Expression interface
type DateRange = jet.Range[DateExpression] type DateRange = jet.Range[DateExpression]
@ -82,6 +87,11 @@ var TimeExp = jet.TimeExp
// Does not add sql cast to generated sql builder output. // Does not add sql cast to generated sql builder output.
var StringExp = jet.StringExp var StringExp = jet.StringExp
// ByteaExp is blob expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as string expression.
// Does not add sql cast to generated sql builder output.
var ByteaExp = jet.BlobExp
// TimezExp is time with time zone expression wrapper around arbitrary expression. // TimezExp is time with time zone expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as time with time zone expression. // Allows go compiler to see any expression as time with time zone expression.
// Does not add sql cast to generated sql builder output. // Does not add sql cast to generated sql builder output.
@ -102,6 +112,11 @@ var TimestampExp = jet.TimestampExp
// Does not add sql cast to generated sql builder output. // Does not add sql cast to generated sql builder output.
var TimestampzExp = jet.TimestampzExp var TimestampzExp = jet.TimestampzExp
// IntervalExp is interval expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as interval expression.
// Does not add sql cast to generated sql builder output.
var IntervalExp = jet.IntervalExp
// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. // RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression.
// This enables the Go compiler to interpret any expression as a row expression // This enables the Go compiler to interpret any expression as a row expression
// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. // Note: This does not modify the generated SQL builder output by adding a SQL CAST operation.
@ -143,6 +158,8 @@ var (
RawTimestamp = jet.RawTimestamp RawTimestamp = jet.RawTimestamp
RawTimestampz = jet.RawTimestampz RawTimestampz = jet.RawTimestampz
RawDate = jet.RawDate RawDate = jet.RawDate
RawBytea = jet.RawBlob
RawNumRange = jet.RawRange[jet.NumericExpression] RawNumRange = jet.RawRange[jet.NumericExpression]
RawInt4Range = jet.RawRange[jet.Int4Expression] RawInt4Range = jet.RawRange[jet.Int4Expression]
RawInt8Range = jet.RawRange[jet.Int8Expression] RawInt8Range = jet.RawRange[jet.Int8Expression]

View file

@ -192,9 +192,27 @@ func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression
return jet.CONCAT_WS(explicitLiteralCast(separator), explicitLiteralCasts(expressions...)...) return jet.CONCAT_WS(explicitLiteralCast(separator), explicitLiteralCasts(expressions...)...)
} }
// Character encodings for CONVERT, CONVERT_FROM and CONVERT_TO functions
var (
UTF8 = StringExp(jet.FixedLiteral("UTF8"))
LATIN1 = StringExp(jet.FixedLiteral("LATIN1"))
LATIN2 = StringExp(jet.FixedLiteral("LATIN2"))
LATIN3 = StringExp(jet.FixedLiteral("LATIN3"))
LATIN4 = StringExp(jet.FixedLiteral("LATIN4"))
WIN1252 = StringExp(jet.FixedLiteral("WIN1252"))
ISO_8859_5 = StringExp(jet.FixedLiteral("ISO_8859_5"))
ISO_8859_6 = StringExp(jet.FixedLiteral("ISO_8859_6"))
ISO_8859_7 = StringExp(jet.FixedLiteral("ISO_8859_7"))
ISO_8859_8 = StringExp(jet.FixedLiteral("ISO_8859_8"))
KOI8R = StringExp(jet.FixedLiteral("KOI8R"))
KOI8U = StringExp(jet.FixedLiteral("KOI8U"))
)
// CONVERT converts string to dest_encoding. The original encoding is // CONVERT converts string to dest_encoding. The original encoding is
// specified by src_encoding. The string must be valid in this encoding. // specified by src_encoding. The string must be valid in this encoding.
var CONVERT = jet.CONVERT func CONVERT(str ByteaExpression, srcEncoding StringExpression, destEncoding StringExpression) ByteaExpression {
return jet.CONVERT(str, srcEncoding, destEncoding)
}
// CONVERT_FROM converts string to the database encoding. The original // CONVERT_FROM converts string to the database encoding. The original
// encoding is specified by src_encoding. The string must be valid in this encoding. // encoding is specified by src_encoding. The string must be valid in this encoding.
@ -203,6 +221,13 @@ var CONVERT_FROM = jet.CONVERT_FROM
// CONVERT_TO converts string to dest_encoding. // CONVERT_TO converts string to dest_encoding.
var CONVERT_TO = jet.CONVERT_TO var CONVERT_TO = jet.CONVERT_TO
// ENCODE/DECODE textual formats
var (
Base64 = StringExp(jet.FixedLiteral("base64"))
Escape = StringExp(jet.FixedLiteral("escape"))
Hex = StringExp(jet.FixedLiteral("hex"))
)
// ENCODE encodes binary data into a textual representation. // ENCODE encodes binary data into a textual representation.
// Supported formats are: base64, hex, escape. escape converts zero bytes and // Supported formats are: base64, hex, escape. escape converts zero bytes and
// high-bit-set bytes to octal sequences (\nnn) and doubles backslashes. // high-bit-set bytes to octal sequences (\nnn) and doubles backslashes.
@ -212,7 +237,7 @@ var ENCODE = jet.ENCODE
// Options for format are same as in encode. // Options for format are same as in encode.
var DECODE = jet.DECODE var DECODE = jet.DECODE
// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. // FORMAT formats the arguments according to a format string. This function is similar to the C function sprintf.
func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression { func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression {
return jet.FORMAT(formatStr, explicitLiteralCasts(formatArgs...)...) return jet.FORMAT(formatStr, explicitLiteralCasts(formatArgs...)...)
} }
@ -242,6 +267,49 @@ var LPAD = jet.LPAD
// fill (a space by default). If the string is already longer than length then it is truncated. // fill (a space by default). If the string is already longer than length then it is truncated.
var RPAD = jet.RPAD var RPAD = jet.RPAD
// BIT_COUNT returns the number of bits set in the binary string (also known as “popcount”).
var BIT_COUNT = jet.BIT_COUNT
// GET_BIT extracts n'th bit from binary string.
func GET_BIT(bytes ByteaExpression, n IntegerExpression) IntegerExpression {
return IntExp(Func("GET_BIT", bytes, n))
}
// GET_BYTE extracts n'th byte from binary string.
func GET_BYTE(bytes ByteaExpression, n IntegerExpression) IntegerExpression {
return IntExp(Func("GET_BYTE", bytes, n))
}
// SET_BIT sets n'th bit in binary string to newvalue.
func SET_BIT(bytes ByteaExpression, n IntegerExpression, newValue IntegerExpression) ByteaExpression {
return ByteaExp(Func("SET_BIT", bytes, n, newValue))
}
// SET_BYTE sets n'th byte in binary string to newvalue.
func SET_BYTE(bytes ByteaExpression, n IntegerExpression, newValue IntegerExpression) ByteaExpression {
return ByteaExp(Func("SET_BYTE", bytes, n, newValue))
}
// SHA224 computes the SHA-224 hash of the binary string.
func SHA224(bytes ByteaExpression) ByteaExpression {
return ByteaExp(Func("SHA224", bytes))
}
// SHA256 computes the SHA-256 hash of the binary string.
func SHA256(bytes ByteaExpression) ByteaExpression {
return ByteaExp(Func("SHA256", bytes))
}
// SHA384 computes the SHA-384 hash of the binary string.
func SHA384(bytes ByteaExpression) ByteaExpression {
return ByteaExp(Func("SHA384", bytes))
}
// SHA512 computes the SHA-512 hash of the binary string.
func SHA512(bytes ByteaExpression) ByteaExpression {
return ByteaExp(Func("SHA512", bytes))
}
// MD5 calculates the MD5 hash of string, returning the result in hexadecimal // MD5 calculates the MD5 hash of string, returning the result in hexadecimal
var MD5 = jet.MD5 var MD5 = jet.MD5

View file

@ -1,257 +0,0 @@
package postgres
import (
"fmt"
"github.com/go-jet/jet/v2/internal/jet"
"github.com/go-jet/jet/v2/internal/utils/datetime"
"strconv"
"strings"
"time"
)
type quantityAndUnit = float64
type unit = float64
// Interval unit types
const (
YEAR unit = 123456789 + iota
MONTH
WEEK
DAY
HOUR
MINUTE
SECOND
MILLISECOND
MICROSECOND
DECADE
CENTURY
MILLENNIUM
)
// IntervalExpression is representation of postgres INTERVAL
type IntervalExpression interface {
jet.IsInterval
jet.Expression
EQ(rhs IntervalExpression) BoolExpression
NOT_EQ(rhs IntervalExpression) BoolExpression
IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression
IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression
LT(rhs IntervalExpression) BoolExpression
LT_EQ(rhs IntervalExpression) BoolExpression
GT(rhs IntervalExpression) BoolExpression
GT_EQ(rhs IntervalExpression) BoolExpression
BETWEEN(min, max IntervalExpression) BoolExpression
NOT_BETWEEN(min, max IntervalExpression) BoolExpression
ADD(rhs IntervalExpression) IntervalExpression
SUB(rhs IntervalExpression) IntervalExpression
MUL(rhs NumericExpression) IntervalExpression
DIV(rhs NumericExpression) IntervalExpression
}
type intervalInterfaceImpl struct {
jet.IsIntervalImpl
parent IntervalExpression
}
func (i *intervalInterfaceImpl) EQ(rhs IntervalExpression) BoolExpression {
return jet.Eq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) NOT_EQ(rhs IntervalExpression) BoolExpression {
return jet.NotEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression {
return jet.IsDistinctFrom(i.parent, rhs)
}
func (i *intervalInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression {
return jet.IsNotDistinctFrom(i.parent, rhs)
}
func (i *intervalInterfaceImpl) LT(rhs IntervalExpression) BoolExpression {
return jet.Lt(i.parent, rhs)
}
func (i *intervalInterfaceImpl) LT_EQ(rhs IntervalExpression) BoolExpression {
return jet.LtEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) GT(rhs IntervalExpression) BoolExpression {
return jet.Gt(i.parent, rhs)
}
func (i *intervalInterfaceImpl) GT_EQ(rhs IntervalExpression) BoolExpression {
return jet.GtEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) BETWEEN(min, max IntervalExpression) BoolExpression {
return jet.NewBetweenOperatorExpression(i.parent, min, max, false)
}
func (i *intervalInterfaceImpl) NOT_BETWEEN(min, max IntervalExpression) BoolExpression {
return jet.NewBetweenOperatorExpression(i.parent, min, max, true)
}
func (i *intervalInterfaceImpl) ADD(rhs IntervalExpression) IntervalExpression {
return IntervalExp(jet.Add(i.parent, rhs))
}
func (i *intervalInterfaceImpl) SUB(rhs IntervalExpression) IntervalExpression {
return IntervalExp(jet.Sub(i.parent, rhs))
}
func (i *intervalInterfaceImpl) MUL(rhs NumericExpression) IntervalExpression {
return IntervalExp(jet.Mul(i.parent, rhs))
}
func (i *intervalInterfaceImpl) DIV(rhs NumericExpression) IntervalExpression {
return IntervalExp(jet.Div(i.parent, rhs))
}
type intervalExpression struct {
jet.Expression
intervalInterfaceImpl
}
// INTERVAL creates new interval expression from the list of quantity-unit pairs.
//
// INTERVAL(1, DAY, 3, MINUTE)
func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression {
quantityAndUnitLen := len(quantityAndUnit)
if quantityAndUnitLen == 0 || quantityAndUnitLen%2 != 0 {
panic("jet: invalid number of quantity and unit fields")
}
var fields []string
for i := 0; i < len(quantityAndUnit); i += 2 {
quantity := strconv.FormatFloat(quantityAndUnit[i], 'f', -1, 64)
unitString := unitToString(quantityAndUnit[i+1])
fields = append(fields, quantity+" "+unitString)
}
intervalStr := fmt.Sprintf("INTERVAL '%s'", strings.Join(fields, " "))
newInterval := &intervalExpression{}
newInterval.Expression = jet.RawWithParent(intervalStr, newInterval)
newInterval.intervalInterfaceImpl.parent = newInterval
return newInterval
}
// INTERVALd creates interval expression from time.Duration
func INTERVALd(duration time.Duration) IntervalExpression {
days, hours, minutes, seconds, microseconds := datetime.ExtractTimeComponents(duration)
var quantityAndUnits []quantityAndUnit
if days > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(days))
quantityAndUnits = append(quantityAndUnits, DAY)
}
if hours > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(hours))
quantityAndUnits = append(quantityAndUnits, HOUR)
}
if minutes > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(minutes))
quantityAndUnits = append(quantityAndUnits, MINUTE)
}
if seconds > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(seconds))
quantityAndUnits = append(quantityAndUnits, SECOND)
}
if microseconds > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(microseconds))
quantityAndUnits = append(quantityAndUnits, MICROSECOND)
}
if len(quantityAndUnits) == 0 {
return INTERVAL(0, MICROSECOND)
}
return INTERVAL(quantityAndUnits...)
}
func unitToString(unit quantityAndUnit) string {
switch unit {
case YEAR:
return "YEAR"
case MONTH:
return "MONTH"
case WEEK:
return "WEEK"
case DAY:
return "DAY"
case HOUR:
return "HOUR"
case MINUTE:
return "MINUTE"
case SECOND:
return "SECOND"
case MILLISECOND:
return "MILLISECOND"
case MICROSECOND:
return "MICROSECOND"
case DECADE:
return "DECADE"
case CENTURY:
return "CENTURY"
case MILLENNIUM:
return "MILLENNIUM"
// additional field units for EXTRACT function
case DOW:
return "DOW"
case DOY:
return "DOY"
case EPOCH:
return "EPOCH"
case ISODOW:
return "ISODOW"
case ISOYEAR:
return "ISOYEAR"
case JULIAN:
return "JULIAN"
case QUARTER:
return "QUARTER"
case TIMEZONE:
return "TIMEZONE"
case TIMEZONE_HOUR:
return "TIMEZONE_HOUR"
case TIMEZONE_MINUTE:
return "TIMEZONE_MINUTE"
default:
panic("jet: invalid INTERVAL unit type")
}
}
//---------------------------------------------------//
type intervalWrapper struct {
intervalInterfaceImpl
Expression
}
func newIntervalExpressionWrap(expression Expression) IntervalExpression {
intervalWrap := &intervalWrapper{Expression: expression}
intervalWrap.intervalInterfaceImpl.parent = intervalWrap
return intervalWrap
}
// IntervalExp is interval expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as interval expression.
// Does not add sql cast to generated sql builder output.
func IntervalExp(expression Expression) IntervalExpression {
return newIntervalExpressionWrap(expression)
}

View file

@ -0,0 +1,140 @@
package postgres
import (
"fmt"
"github.com/go-jet/jet/v2/internal/utils/datetime"
"strconv"
"strings"
"time"
)
type quantityAndUnit = float64
type unit = float64
// Interval unit types
const (
YEAR unit = 123456789 + iota
MONTH
WEEK
DAY
HOUR
MINUTE
SECOND
MILLISECOND
MICROSECOND
DECADE
CENTURY
MILLENNIUM
)
// INTERVAL creates new interval expression from the list of quantity-unit pairs.
//
// INTERVAL(1, DAY, 3, MINUTE)
func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression {
quantityAndUnitLen := len(quantityAndUnit)
if quantityAndUnitLen == 0 || quantityAndUnitLen%2 != 0 {
panic("jet: invalid number of quantity and unit fields")
}
var fields []string
for i := 0; i < len(quantityAndUnit); i += 2 {
quantity := strconv.FormatFloat(quantityAndUnit[i], 'f', -1, 64)
unitString := unitToString(quantityAndUnit[i+1])
fields = append(fields, quantity+" "+unitString)
}
return IntervalExp(CustomExpression(Token(fmt.Sprintf("INTERVAL '%s'", strings.Join(fields, " ")))))
}
// INTERVALd creates interval expression from time.Duration
func INTERVALd(duration time.Duration) IntervalExpression {
days, hours, minutes, seconds, microseconds := datetime.ExtractTimeComponents(duration)
var quantityAndUnits []quantityAndUnit
if days > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(days))
quantityAndUnits = append(quantityAndUnits, DAY)
}
if hours > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(hours))
quantityAndUnits = append(quantityAndUnits, HOUR)
}
if minutes > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(minutes))
quantityAndUnits = append(quantityAndUnits, MINUTE)
}
if seconds > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(seconds))
quantityAndUnits = append(quantityAndUnits, SECOND)
}
if microseconds > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(microseconds))
quantityAndUnits = append(quantityAndUnits, MICROSECOND)
}
if len(quantityAndUnits) == 0 {
return INTERVAL(0, MICROSECOND)
}
return INTERVAL(quantityAndUnits...)
}
func unitToString(unit quantityAndUnit) string {
switch unit {
case YEAR:
return "YEAR"
case MONTH:
return "MONTH"
case WEEK:
return "WEEK"
case DAY:
return "DAY"
case HOUR:
return "HOUR"
case MINUTE:
return "MINUTE"
case SECOND:
return "SECOND"
case MILLISECOND:
return "MILLISECOND"
case MICROSECOND:
return "MICROSECOND"
case DECADE:
return "DECADE"
case CENTURY:
return "CENTURY"
case MILLENNIUM:
return "MILLENNIUM"
// additional field units for EXTRACT function
case DOW:
return "DOW"
case DOY:
return "DOY"
case EPOCH:
return "EPOCH"
case ISODOW:
return "ISODOW"
case ISOYEAR:
return "ISOYEAR"
case JULIAN:
return "JULIAN"
case QUARTER:
return "QUARTER"
case TIMEZONE:
return "TIMEZONE"
case TIMEZONE_HOUR:
return "TIMEZONE_HOUR"
case TIMEZONE_MINUTE:
return "TIMEZONE_MINUTE"
default:
panic("jet: invalid INTERVAL unit type")
}
}
//---------------------------------------------------//

View file

@ -127,7 +127,7 @@ func Json(value interface{}) StringExpression {
var UUID = jet.UUID var UUID = jet.UUID
// Bytea creates new bytea literal expression // Bytea creates new bytea literal expression
func Bytea(value interface{}) StringExpression { func Bytea(value interface{}) ByteaExpression {
switch value.(type) { switch value.(type) {
case string, []byte: case string, []byte:
default: default:

132
postgres/select_json.go Normal file
View file

@ -0,0 +1,132 @@
package postgres
import (
"github.com/go-jet/jet/v2/internal/jet"
"strings"
)
// SELECT_JSON_ARR creates a new SelectJsonStatement with a list of projections.
func SELECT_JSON_ARR(projections ...Projection) SelectStatement {
return newSelectStatementJson(projections, jet.SelectJsonArrStatementType)
}
// SELECT_JSON_OBJ creates a new SelectJsonStatement with a list of projections.
func SELECT_JSON_OBJ(projections ...Projection) SelectStatement {
return newSelectStatementJson(projections, jet.SelectJsonObjStatementType)
}
type selectJsonStatement struct {
*selectStatementImpl
subQuery *selectStatementImpl
statementType jet.StatementType
}
func (s *selectJsonStatement) AS(alias string) Projection {
s.setSubQueryAlias(strings.ToLower(alias) + "_")
return s.selectStatementImpl.AS(alias)
}
func (s *selectJsonStatement) FROM(table ...ReadableTable) SelectStatement {
s.subQuery.From.Tables = readableTablesToSerializerList(table)
return s
}
func (s *selectJsonStatement) DISTINCT(on ...jet.ColumnExpression) SelectStatement {
s.subQuery.Select.Distinct = true
s.subQuery.Select.DistinctOnColumns = on
return s
}
func (s *selectJsonStatement) WHERE(condition BoolExpression) SelectStatement {
s.subQuery.Where.Condition = condition
return s
}
func (s *selectJsonStatement) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement {
s.subQuery.GroupBy.List = groupByClauses
return s
}
func (s *selectJsonStatement) HAVING(boolExpression BoolExpression) SelectStatement {
s.subQuery.Having.Condition = boolExpression
return s
}
func (s *selectJsonStatement) WINDOW(name string) windowExpand {
s.subQuery.Window.Definitions = append(s.subQuery.Window.Definitions, jet.WindowDefinition{Name: name})
return windowExpand{
selectStatement: s.subQuery,
rootStmt: s,
}
}
func (s *selectJsonStatement) ORDER_BY(orderByClauses ...OrderByClause) SelectStatement {
s.subQuery.OrderBy.List = orderByClauses
return s
}
func (s *selectJsonStatement) LIMIT(limit int64) SelectStatement {
s.subQuery.Limit.Count = limit
return s
}
func (s *selectJsonStatement) OFFSET(offset int64) SelectStatement {
s.subQuery.Offset.Count = Int(offset)
return s
}
func (s *selectJsonStatement) OFFSET_e(offset IntegerExpression) SelectStatement {
s.subQuery.Offset.Count = offset
return s
}
func (s *selectJsonStatement) FETCH_FIRST(count IntegerExpression) fetchExpand {
s.subQuery.Fetch.Count = count
return fetchExpand{
selectStatement: s.subQuery,
rootStmt: s,
}
}
func (s *selectJsonStatement) FOR(lock RowLock) SelectStatement {
s.subQuery.For.Lock = lock
return s
}
func newSelectStatementJson(projections []Projection, statementType jet.StatementType) SelectStatement {
newSelectJson := &selectJsonStatement{
selectStatementImpl: newSelectStatement(statementType, nil, nil),
subQuery: newSelectStatement(statementType, nil, projections),
statementType: statementType,
}
newSelectJson.setOperatorsImpl.stmtRoot = newSelectJson
newSelectJson.subQuery.Select.IsForRowToJson = true
newSelectJson.setSubQueryAlias("")
return newSelectJson
}
func (s *selectJsonStatement) setSubQueryAlias(alias string) {
subQueryAlias := alias + "records"
jsonAlias := alias + "json"
s.Select.ProjectionList = ProjectionList{constructJsonFunc(s.statementType, subQueryAlias).AS(jsonAlias)}
s.From.Tables = []jet.Serializer{newSelectTable(s.subQuery, subQueryAlias, nil)}
}
func constructJsonFunc(statementType jet.StatementType, subQueryAlias string) Expression {
rowToJson := Func("row_to_json", CustomExpression(Token(subQueryAlias)))
if statementType == jet.SelectJsonArrStatementType {
return Func("json_agg", rowToJson)
}
return rowToJson
}

View file

@ -70,12 +70,12 @@ type SelectStatement interface {
// SELECT creates new SelectStatement with list of projections // SELECT creates new SelectStatement with list of projections
func SELECT(projection Projection, projections ...Projection) SelectStatement { func SELECT(projection Projection, projections ...Projection) SelectStatement {
return newSelectStatement(nil, append([]Projection{projection}, projections...)) return newSelectStatement(jet.SelectStatementType, nil, append([]Projection{projection}, projections...))
} }
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { func newSelectStatement(stmtType jet.StatementType, table ReadableTable, projections []Projection) *selectStatementImpl {
newSelect := &selectStatementImpl{} newSelect := &selectStatementImpl{}
newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, stmtType, newSelect,
&newSelect.Select, &newSelect.Select,
&newSelect.From, &newSelect.From,
&newSelect.Where, &newSelect.Where,
@ -94,7 +94,7 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta
} }
newSelect.Limit.Count = -1 newSelect.Limit.Count = -1
newSelect.setOperatorsImpl.parent = newSelect newSelect.setOperatorsImpl.stmtRoot = newSelect
return newSelect return newSelect
} }
@ -144,7 +144,10 @@ func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatem
func (s *selectStatementImpl) WINDOW(name string) windowExpand { func (s *selectStatementImpl) WINDOW(name string) windowExpand {
s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name}) s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name})
return windowExpand{selectStatement: s} return windowExpand{
selectStatement: s,
rootStmt: s,
}
} }
func (s *selectStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) SelectStatement { func (s *selectStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) SelectStatement {
@ -172,6 +175,7 @@ func (s *selectStatementImpl) FETCH_FIRST(count IntegerExpression) fetchExpand {
return fetchExpand{ return fetchExpand{
selectStatement: s, selectStatement: s,
rootStmt: s,
} }
} }
@ -188,6 +192,7 @@ func (s *selectStatementImpl) AsTable(alias string) SelectTable {
type windowExpand struct { type windowExpand struct {
selectStatement *selectStatementImpl selectStatement *selectStatementImpl
rootStmt SelectStatement
} }
func (w windowExpand) AS(window ...jet.Window) SelectStatement { func (w windowExpand) AS(window ...jet.Window) SelectStatement {
@ -196,7 +201,7 @@ func (w windowExpand) AS(window ...jet.Window) SelectStatement {
} }
windowsDefinition := w.selectStatement.Window.Definitions windowsDefinition := w.selectStatement.Window.Definitions
windowsDefinition[len(windowsDefinition)-1].Window = window[0] windowsDefinition[len(windowsDefinition)-1].Window = window[0]
return w.selectStatement return w.rootStmt
} }
func toJetFrameOffset(offset int64) jet.Serializer { func toJetFrameOffset(offset int64) jet.Serializer {
@ -216,16 +221,17 @@ func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer {
type fetchExpand struct { type fetchExpand struct {
selectStatement *selectStatementImpl selectStatement *selectStatementImpl
rootStmt SelectStatement
} }
func (f fetchExpand) ROWS_ONLY() SelectStatement { func (f fetchExpand) ROWS_ONLY() SelectStatement {
f.selectStatement.Fetch.WithTies = false f.selectStatement.Fetch.WithTies = false
return f.selectStatement return f.rootStmt
} }
func (f fetchExpand) ROWS_WITH_TIES() SelectStatement { func (f fetchExpand) ROWS_WITH_TIES() SelectStatement {
f.selectStatement.Fetch.WithTies = true f.selectStatement.Fetch.WithTies = true
return f.selectStatement return f.rootStmt
} }

View file

@ -65,31 +65,31 @@ type setOperators interface {
} }
type setOperatorsImpl struct { type setOperatorsImpl struct {
parent setOperators stmtRoot setOperators
} }
func (s *setOperatorsImpl) UNION(rhs SelectStatement) setStatement { func (s *setOperatorsImpl) UNION(rhs SelectStatement) setStatement {
return UNION(s.parent, rhs) return UNION(s.stmtRoot, rhs)
} }
func (s *setOperatorsImpl) UNION_ALL(rhs SelectStatement) setStatement { func (s *setOperatorsImpl) UNION_ALL(rhs SelectStatement) setStatement {
return UNION_ALL(s.parent, rhs) return UNION_ALL(s.stmtRoot, rhs)
} }
func (s *setOperatorsImpl) INTERSECT(rhs SelectStatement) setStatement { func (s *setOperatorsImpl) INTERSECT(rhs SelectStatement) setStatement {
return INTERSECT(s.parent, rhs) return INTERSECT(s.stmtRoot, rhs)
} }
func (s *setOperatorsImpl) INTERSECT_ALL(rhs SelectStatement) setStatement { func (s *setOperatorsImpl) INTERSECT_ALL(rhs SelectStatement) setStatement {
return INTERSECT_ALL(s.parent, rhs) return INTERSECT_ALL(s.stmtRoot, rhs)
} }
func (s *setOperatorsImpl) EXCEPT(rhs SelectStatement) setStatement { func (s *setOperatorsImpl) EXCEPT(rhs SelectStatement) setStatement {
return EXCEPT(s.parent, rhs) return EXCEPT(s.stmtRoot, rhs)
} }
func (s *setOperatorsImpl) EXCEPT_ALL(rhs SelectStatement) setStatement { func (s *setOperatorsImpl) EXCEPT_ALL(rhs SelectStatement) setStatement {
return EXCEPT_ALL(s.parent, rhs) return EXCEPT_ALL(s.stmtRoot, rhs)
} }
type setStatementImpl struct { type setStatementImpl struct {
@ -110,7 +110,7 @@ func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStat
newSetStatement.setOperator.Selects = selects newSetStatement.setOperator.Selects = selects
newSetStatement.setOperator.Limit.Count = -1 newSetStatement.setOperator.Limit.Count = -1
newSetStatement.setOperatorsImpl.parent = newSetStatement newSetStatement.setOperatorsImpl.stmtRoot = newSetStatement
return newSetStatement return newSetStatement
} }

View file

@ -55,7 +55,7 @@ type readableTableInterfaceImpl struct {
// Generates a select query on the current tableName. // Generates a select query on the current tableName.
func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement {
return newSelectStatement(r.parent, append([]Projection{projection1}, projections...)) return newSelectStatement(jet.SelectStatementType, r.parent, append([]Projection{projection1}, projections...))
} }
// Creates a inner join tableName Expression using onCondition. // Creates a inner join tableName Expression using onCondition.

View file

@ -4,14 +4,13 @@ import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/utils/min" "github.com/go-jet/jet/v2/internal/utils/datetime"
"reflect" "reflect"
"strconv" "strconv"
"time"
) )
var ( var (
castOverFlowError = fmt.Errorf("cannot cast a negative value to an unsigned value, buffer overflow error") errCastOverFlow = fmt.Errorf("cannot cast a negative value to an unsigned value, buffer overflow error")
) )
// NullBool struct // NullBool struct
@ -64,7 +63,12 @@ func (nt *NullTime) Scan(value interface{}) error {
// Some of the drivers (pgx, mysql) are not parsing all of the time formats(date, time with time zone,...) and are just forwarding string value. // Some of the drivers (pgx, mysql) are not parsing all of the time formats(date, time with time zone,...) and are just forwarding string value.
// At this point we try to parse those values using some of the predefined formats // At this point we try to parse those values using some of the predefined formats
nt.Time, nt.Valid = tryParseAsTime(value) nt.Time, nt.Valid = datetime.TryParseAsTime(value, []string{
"2006-01-02 15:04:05-07:00", // sqlite
"2006-01-02 15:04:05.999999", // go-sql-driver/mysql
"15:04:05-07", // pgx
"15:04:05.999999", // pgx
})
if !nt.Valid { if !nt.Valid {
return fmt.Errorf("can't scan time.Time from %q", value) return fmt.Errorf("can't scan time.Time from %q", value)
@ -73,42 +77,6 @@ func (nt *NullTime) Scan(value interface{}) error {
return nil return nil
} }
var formats = []string{
"2006-01-02 15:04:05-07:00", // sqlite
"2006-01-02 15:04:05.999999", // go-sql-driver/mysql
"15:04:05-07", // pgx
"15:04:05.999999", // pgx
}
func tryParseAsTime(value interface{}) (time.Time, bool) {
var timeStr string
switch v := value.(type) {
case string:
timeStr = v
case []byte:
timeStr = string(v)
case int64:
return time.Unix(v, 0), true // sqlite
default:
return time.Time{}, false
}
for _, format := range formats {
formatLen := min.Int(len(format), len(timeStr))
t, err := time.Parse(format[:formatLen], timeStr)
if err != nil {
continue
}
return t, true
}
return time.Time{}, false
}
// NullUInt64 struct // NullUInt64 struct
type NullUInt64 struct { type NullUInt64 struct {
UInt64 uint64 UInt64 uint64
@ -124,31 +92,31 @@ func (n *NullUInt64) Scan(value interface{}) error {
return nil return nil
case int64: case int64:
if v < 0 { if v < 0 {
return castOverFlowError return errCastOverFlow
} }
n.UInt64, n.Valid = uint64(v), true n.UInt64, n.Valid = uint64(v), true
return nil return nil
case int32: case int32:
if v < 0 { if v < 0 {
return castOverFlowError return errCastOverFlow
} }
n.UInt64, n.Valid = uint64(v), true n.UInt64, n.Valid = uint64(v), true
return nil return nil
case int16: case int16:
if v < 0 { if v < 0 {
return castOverFlowError return errCastOverFlow
} }
n.UInt64, n.Valid = uint64(v), true n.UInt64, n.Valid = uint64(v), true
return nil return nil
case int8: case int8:
if v < 0 { if v < 0 {
return castOverFlowError return errCastOverFlow
} }
n.UInt64, n.Valid = uint64(v), true n.UInt64, n.Valid = uint64(v), true
return nil return nil
case int: case int:
if v < 0 { if v < 0 {
return castOverFlowError return errCastOverFlow
} }
n.UInt64, n.Valid = uint64(v), true n.UInt64, n.Valid = uint64(v), true
return nil return nil

View file

@ -103,25 +103,25 @@ func TestNullUInt64(t *testing.T) {
//Validate negative use cases //Validate negative use cases
err := nullUInt64.Scan(int64(-5)) err := nullUInt64.Scan(int64(-5))
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError) assert.Error(t, err, errCastOverFlow)
//Validate negative use cases //Validate negative use cases
err = nullUInt64.Scan(-5) err = nullUInt64.Scan(-5)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError) assert.Error(t, err, errCastOverFlow)
//Validate negative use cases //Validate negative use cases
err = nullUInt64.Scan(int32(-5)) err = nullUInt64.Scan(int32(-5))
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError) assert.Error(t, err, errCastOverFlow)
//Validate negative use cases //Validate negative use cases
err = nullUInt64.Scan(int16(-5)) err = nullUInt64.Scan(int16(-5))
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError) assert.Error(t, err, errCastOverFlow)
//Validate negative use cases //Validate negative use cases
err = nullUInt64.Scan(int8(-5)) err = nullUInt64.Scan(int8(-5))
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError) assert.Error(t, err, errCastOverFlow)
} }

View file

@ -3,6 +3,7 @@ package qrm
import ( import (
"context" "context"
"database/sql" "database/sql"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/utils/must" "github.com/go-jet/jet/v2/internal/utils/must"
@ -12,10 +13,130 @@ import (
// ErrNoRows is returned by Query when query result set is empty // ErrNoRows is returned by Query when query result set is empty
var ErrNoRows = errors.New("qrm: no rows in result set") var ErrNoRows = errors.New("qrm: no rows in result set")
// Query executes Query Result Mapping (QRM) of `query` with list of parametrized arguments `arg` over database connection `db` // QueryJsonObj executes a SQL query that returns a JSON object, unmarshals the result into the provided destination,
// using context `ctx` into destination `destPtr`. // and returns the number of rows processed.
// Destination can be either pointer to struct or pointer to slice of structs. //
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. // The query must return exactly one row with a single column; otherwise, an error is returned.
//
// Parameters:
//
// ctx - The context for managing query execution (timeouts, cancellations).
// db - The database connection or transaction that implements the Queryable interface.
// query - The SQL query string to be executed.
// args - A slice of arguments to be used with the query.
// destPtr - A pointer to the variable where the unmarshaled JSON result will be stored.
// The destination should be a pointer to a struct or map[string]any.
//
// Returns:
//
// rowsProcessed - The number of rows processed by the query execution.
// err - An error if query execution or unmarshaling fails.
func QueryJsonObj(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) {
must.BeInitializedPtr(destPtr, "jet: destination is nil")
must.BeTypeKind(destPtr, reflect.Ptr, jsonDestObjErr)
destType := reflect.TypeOf(destPtr).Elem()
must.BeTrue(destType.Kind() == reflect.Struct || destType.Kind() == reflect.Map, jsonDestObjErr)
return queryJson(ctx, db, query, args, destPtr)
}
// QueryJsonArr executes a SQL query that returns a JSON array, unmarshals the result into the provided destination,
// and returns the number of rows processed.
//
// The query must return exactly one row with a single column; otherwise, an error is returned.
//
// Parameters:
//
// ctx - The context for managing query execution (timeouts, cancellations).
// db - The database connection or transaction that implements the Queryable interface.
// query - The SQL query string to be executed.
// args - A slice of arguments to be used with the query.
// destPtr - A pointer to the variable where the unmarshaled JSON array will be stored.
// The destination should be a pointer to a slice of structs or []map[string]any.
//
// Returns:
//
// rowsProcessed - The number of rows processed by the query execution.
// err - An error if query execution or unmarshaling fails.
func QueryJsonArr(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) {
must.BeInitializedPtr(destPtr, "jet: destination is nil")
must.BeTypeKind(destPtr, reflect.Ptr, jsonDestArrErr)
destType := reflect.TypeOf(destPtr).Elem()
must.BeTrue(destType.Kind() == reflect.Slice, jsonDestArrErr)
return queryJson(ctx, db, query, args, destPtr)
}
var jsonDestObjErr = "jet: SELECT_JSON_OBJ destination has to be a pointer to struct or pointer to map[string]any"
var jsonDestArrErr = "jet: SELECT_JSON_ARR destination has to be a pointer to slice of struct or pointer to []map[string]any"
func queryJson(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) {
must.BeInitializedPtr(db, "jet: db is nil")
var rows *sql.Rows
rows, err = db.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
err = rows.Err()
if err != nil {
return 0, err
}
return 0, ErrNoRows
}
var jsonData []byte
err = rows.Scan(&jsonData)
if err != nil {
return 1, err
}
if jsonData == nil {
return 1, nil
}
err = json.Unmarshal(jsonData, &destPtr)
if err != nil {
return 1, fmt.Errorf("jet: invalid json, %w", err)
}
if rows.Next() {
return 1, fmt.Errorf("jet: query returned more then one row")
}
err = rows.Close()
if err != nil {
return 1, err
}
return 1, nil
}
// Query executes a Query Result Mapping (QRM) of the provided SQL `query` with a list of parameterized arguments `args`
// over the database connection `db` using the provided context `ctx` and stores the result in the destination `destPtr`.
//
// The destination must be a pointer to either a struct or a slice of structs
// If the destination is a pointer to a struct and no rows are returned, the method returns qrm.ErrNoRows.
//
// Parameters:
//
// ctx - The context for managing query execution (timeouts, cancellations).
// db - The database connection or transaction implementing the Queryable interface.
// query - The SQL query string to be executed.
// args - A slice of arguments to be used with the query.
// destPtr - A pointer to the variable where the query result will be stored. This can be a pointer to a struct or a slice of structs.
//
// Returns:
//
// rowsProcessed - The number of rows processed by the query execution.
// err - An error if query execution or result mapping fails, or if no rows are found when a struct is expected.
func Query(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) { func Query(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) {
must.BeInitializedPtr(db, "jet: db is nil") must.BeInitializedPtr(db, "jet: db is nil")
@ -185,7 +306,7 @@ func mapRowToSlice(
func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) { func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) {
index := 0 index := 0
if field != nil { if field != nil {
typeName, columnName := getTypeAndFieldName("", *field) typeName, columnName, _ := getTypeAndFieldName("", *field)
if index = scanContext.typeToColumnIndex(typeName, columnName); index < 0 { if index = scanContext.typeToColumnIndex(typeName, columnName); index < 0 {
return return
} }
@ -233,9 +354,11 @@ func mapRowToStruct(
continue continue
} }
fieldMap := typeInf.fieldMappings[i] fieldMappingInfo := typeInf.fieldMappings[i]
if fieldMap.complexType { switch fieldMappingInfo.Type {
case complexType:
var changed bool var changed bool
changed, err = mapRowToDestinationValue(scanContext, concat(groupKey, ":", field.Name), fieldValue, &field) changed, err = mapRowToDestinationValue(scanContext, concat(groupKey, ":", field.Name), fieldValue, &field)
@ -246,13 +369,12 @@ func mapRowToStruct(
if changed { if changed {
updated = true updated = true
} }
default:
} else { if mapOnlySlices || fieldMappingInfo.rowIndex == -1 {
if mapOnlySlices || fieldMap.rowIndex == -1 {
continue continue
} }
scannedValue := scanContext.rowElemValue(fieldMap.rowIndex) scannedValue := scanContext.rowElemValue(fieldMappingInfo.rowIndex)
if !scannedValue.IsValid() { if !scannedValue.IsValid() {
setZeroValue(fieldValue) // scannedValue is nil, destination should be set to zero value setZeroValue(fieldValue) // scannedValue is nil, destination should be set to zero value
@ -261,7 +383,8 @@ func mapRowToStruct(
updated = true updated = true
if fieldMap.implementsScanner { switch fieldMappingInfo.Type {
case implementsScanner:
initializeValueIfNilPtr(fieldValue) initializeValueIfNilPtr(fieldValue)
fieldScanner := getScanner(fieldValue) fieldScanner := getScanner(fieldValue)
@ -270,14 +393,27 @@ func mapRowToStruct(
err := fieldScanner.Scan(value) err := fieldScanner.Scan(value)
if err != nil { if err != nil {
return updated, fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, value, value, field.Name, field.Type.String(), err) return updated, qrmAssignError(scannedValue, field, err)
} }
} else { case jsonUnmarshal:
value, ok := scannedValue.Interface().([]byte)
if !ok {
return updated, qrmAssignError(scannedValue, field, fmt.Errorf("value not convertable to []byte"))
}
fieldInterface := fieldValue.Addr().Interface()
err := json.Unmarshal(value, fieldInterface)
if err != nil {
return updated, qrmAssignError(scannedValue, field, fmt.Errorf("invalid json, %w", err))
}
default: // simple type
err := assign(scannedValue, fieldValue) err := assign(scannedValue, fieldValue)
if err != nil { if err != nil {
return updated, fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, scannedValue.Interface(), scannedValue.Interface(), return updated, qrmAssignError(scannedValue, field, err)
field.Name, field.Type.String(), err)
} }
} }
} }
@ -286,6 +422,11 @@ func mapRowToStruct(
return return
} }
func qrmAssignError(scannedValue reflect.Value, field reflect.StructField, err error) error {
return fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, scannedValue.Interface(), scannedValue.Interface(),
field.Name, field.Type.String(), err)
}
func mapRowToDestinationValue( func mapRowToDestinationValue(
scanContext *ScanContext, scanContext *ScanContext,
groupKey string, groupKey string,

View file

@ -75,10 +75,18 @@ type typeInfo struct {
fieldMappings []fieldMapping fieldMappings []fieldMapping
} }
type fieldMappingType int
const (
simpleType fieldMappingType = iota
complexType // slice and struct are complex types supported
implementsScanner
jsonUnmarshal
)
type fieldMapping struct { type fieldMapping struct {
complexType bool // slice and struct are complex types
rowIndex int // index in ScanContext.row rowIndex int // index in ScanContext.row
implementsScanner bool Type fieldMappingType
} }
func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo { func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo {
@ -100,17 +108,21 @@ func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.
for i := 0; i < structType.NumField(); i++ { for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i) field := structType.Field(i)
newTypeName, fieldName := getTypeAndFieldName(typeName, field) newTypeName, fieldName, jsonUnmarshaler := getTypeAndFieldName(typeName, field)
columnIndex := s.typeToColumnIndex(newTypeName, fieldName) columnIndex := s.typeToColumnIndex(newTypeName, fieldName)
fieldMap := fieldMapping{ fieldMap := fieldMapping{
rowIndex: columnIndex, rowIndex: columnIndex,
} }
if implementsScannerType(field.Type) { if jsonUnmarshaler {
fieldMap.implementsScanner = true fieldMap.Type = jsonUnmarshal
} else if implementsScannerType(field.Type) {
fieldMap.Type = implementsScanner
} else if !isSimpleModelType(field.Type) { } else if !isSimpleModelType(field.Type) {
fieldMap.complexType = true fieldMap.Type = complexType
} else {
fieldMap.Type = simpleType
} }
newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap) newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap)
@ -188,7 +200,7 @@ func (s *ScanContext) getGroupKeyInfo(
fieldType := indirectType(field.Type) fieldType := indirectType(field.Type)
if isPrimaryKey(field, primaryKeyOverwrites) { if isPrimaryKey(field, primaryKeyOverwrites) {
newTypeName, fieldName := getTypeAndFieldName(typeName, field) newTypeName, fieldName, _ := getTypeAndFieldName(typeName, field)
pkIndex := s.typeToColumnIndex(newTypeName, fieldName) pkIndex := s.typeToColumnIndex(newTypeName, fieldName)

View file

@ -107,20 +107,26 @@ func getTypeName(structType reflect.Type, parentField *reflect.StructField) stri
return toCommonIdentifier(aliasParts[0]) return toCommonIdentifier(aliasParts[0])
} }
func getTypeAndFieldName(structType string, field reflect.StructField) (string, string) { func getTypeAndFieldName(structType string, field reflect.StructField) (string, string, bool) {
aliasTag := field.Tag.Get("alias") aliasTag := field.Tag.Get("alias")
if aliasTag == "" { if aliasTag != "" {
return structType, field.Name
}
aliasParts := strings.Split(aliasTag, ".") aliasParts := strings.Split(aliasTag, ".")
if len(aliasParts) == 1 { if len(aliasParts) == 1 {
return structType, toCommonIdentifier(aliasParts[0]) return structType, toCommonIdentifier(aliasParts[0]), false
} }
return toCommonIdentifier(aliasParts[0]), toCommonIdentifier(aliasParts[1]) return toCommonIdentifier(aliasParts[0]), toCommonIdentifier(aliasParts[1]), false
}
jsonColumnTag := field.Tag.Get("json_column")
if jsonColumnTag != "" {
return "", toCommonIdentifier(jsonColumnTag), true
}
return structType, field.Name, false
} }
var replacer = strings.NewReplacer(" ", "", "-", "", "_", "") var replacer = strings.NewReplacer(" ", "", "-", "", "_", "")

View file

@ -42,6 +42,6 @@ func (c *cast) AS_REAL() FloatExpression {
} }
// AS_BLOB cast expression to BLOB type // AS_BLOB cast expression to BLOB type
func (c *cast) AS_BLOB() StringExpression { func (c *cast) AS_BLOB() BlobExpression {
return StringExp(c.AS("BLOB")) return BlobExp(c.AS("BLOB"))
} }

View file

@ -21,6 +21,12 @@ type ColumnString = jet.ColumnString
// StringColumn creates named string column. // StringColumn creates named string column.
var StringColumn = jet.StringColumn var StringColumn = jet.StringColumn
// ColumnBlob is interface for
type ColumnBlob = jet.ColumnBlob
// BlobColumn creates new named blob column
var BlobColumn = jet.BlobColumn
// ColumnInteger is interface for SQL smallint, integer, bigint columns. // ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger = jet.ColumnInteger type ColumnInteger = jet.ColumnInteger

View file

@ -1,6 +1,7 @@
package sqlite package sqlite
import ( import (
"encoding/hex"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/jet"
) )
@ -23,6 +24,7 @@ func newDialect() jet.Dialect {
ArgumentPlaceholder: func(int) string { ArgumentPlaceholder: func(int) string {
return "?" return "?"
}, },
ArgumentToString: argumentToString,
ReservedWords: reservedWords2, ReservedWords: reservedWords2,
ValuesDefaultColumnName: func(index int) string { ValuesDefaultColumnName: func(index int) string {
return fmt.Sprintf("column%d", index+1) return fmt.Sprintf("column%d", index+1)
@ -32,6 +34,15 @@ func newDialect() jet.Dialect {
return jet.NewDialect(mySQLDialectParams) return jet.NewDialect(mySQLDialectParams)
} }
func argumentToString(value any) (string, bool) {
switch bindVal := value.(type) {
case []byte:
return fmt.Sprintf("X'%s'", hex.EncodeToString(bindVal)), true
}
return "", false
}
func sqliteBitXOR(expressions ...jet.Serializer) jet.SerializerFunc { func sqliteBitXOR(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {

View file

@ -12,6 +12,9 @@ type BoolExpression = jet.BoolExpression
// StringExpression interface // StringExpression interface
type StringExpression = jet.StringExpression type StringExpression = jet.StringExpression
// BlobExpression interface
type BlobExpression = jet.BlobExpression
// NumericExpression is shared interface for integer or real expression // NumericExpression is shared interface for integer or real expression
type NumericExpression = jet.NumericExpression type NumericExpression = jet.NumericExpression
@ -46,6 +49,11 @@ var BoolExp = jet.BoolExp
// Does not add sql cast to generated sql builder output. // Does not add sql cast to generated sql builder output.
var StringExp = jet.StringExp var StringExp = jet.StringExp
// BlobExp is blob expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as blob expression.
// Does not add sql cast to generated sql builder output.
var BlobExp = jet.BlobExp
// IntExp is int expression wrapper around arbitrary expression. // IntExp is int expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as int expression. // Allows go compiler to see any expression as int expression.
// Does not add sql cast to generated sql builder output. // Does not add sql cast to generated sql builder output.

View file

@ -196,11 +196,22 @@ var RTRIM = jet.RTRIM
// return jet.NewStringFunc("RIGHTSTR", str, n) // return jet.NewStringFunc("RIGHTSTR", str, n)
//} //}
// HEX function takes an input and returns its equivalent hexadecimal representation
var HEX = jet.HEX
// UNHEX for a string argument str, UNHEX(str) interprets each pair of characters in the argument
// as a hexadecimal number and converts it to the byte represented by the number.
// The return value is a binary string.
var UNHEX = jet.UNHEX
// LENGTH returns number of characters in string with a given encoding // LENGTH returns number of characters in string with a given encoding
func LENGTH(str jet.StringExpression) jet.StringExpression { func LENGTH(str jet.StringOrBlobExpression) jet.IntegerExpression {
return jet.LENGTH(str) return jet.LENGTH(str)
} }
// OCTET_LENGTH returns number of bytes in string expression
var OCTET_LENGTH = jet.OCTET_LENGTH
// LPAD fills up the string to length length by prepending the characters // LPAD fills up the string to length length by prepending the characters
// fill (a space by default). If the string is already longer than length // fill (a space by default). If the string is already longer than length
// then it is truncated (on the right). // then it is truncated (on the right).

View file

@ -50,6 +50,11 @@ var Decimal = jet.Decimal
// String creates new string literal expression // String creates new string literal expression
var String = jet.String var String = jet.String
// Blob creates new blob literal expression
func Blob(data []byte) BlobExpression {
return BlobExp(jet.Literal(data))
}
// UUID is a helper function to create string literal expression from uuid object // UUID is a helper function to create string literal expression from uuid object
// value can be any uuid type with a String method // value can be any uuid type with a String method
var UUID = jet.UUID var UUID = jet.UUID

View file

@ -26,7 +26,7 @@ services:
- ./testdata/init/mysql:/docker-entrypoint-initdb.d - ./testdata/init/mysql:/docker-entrypoint-initdb.d
mariadb: mariadb:
image: mariadb:10.3 image: mariadb:11.4
command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1'] command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1']
restart: always restart: always
environment: environment:

View file

@ -23,19 +23,115 @@ func TestAllTypes(t *testing.T) {
var dest []model.AllTypes var dest []model.AllTypes
err := AllTypes. err := SELECT(AllTypes.AllColumns).
SELECT(AllTypes.AllColumns). FROM(AllTypes).
LIMIT(2). LIMIT(2).
Query(db, &dest) Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(dest), 2) require.Equal(t, len(dest), 2)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
testutils.AssertJSON(t, dest, allTypesJson) testutils.AssertJSON(t, dest, allTypesJson)
} }
func TestAllTypesJSON(t *testing.T) {
stmt := SELECT_JSON_ARR(
AllTypes.AllColumns.Except(
AllTypes.JSON,
AllTypes.JSONPtr,
AllTypes.Bit,
AllTypes.BitPtr,
),
CAST(AllTypes.JSON).AS_CHAR().AS("Json"),
CAST(AllTypes.JSONPtr).AS_CHAR().AS("JsonPtr"),
CAST(AllTypes.Bit).AS_CHAR().AS("Bit"),
CAST(AllTypes.BitPtr).AS_CHAR().AS("BitPtr"),
).FROM(AllTypes)
testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(`
SELECT JSON_ARRAYAGG(JSON_OBJECT(
'id', all_types.id,
'boolean', all_types.boolean = 1,
'booleanPtr', all_types.boolean_ptr = 1,
'tinyInt', all_types.tiny_int,
'uTinyInt', all_types.u_tiny_int,
'smallInt', all_types.small_int,
'uSmallInt', all_types.u_small_int,
'mediumInt', all_types.medium_int,
'uMediumInt', all_types.u_medium_int,
'integer', all_types.''integer'',
'uInteger', all_types.u_integer,
'bigInt', all_types.big_int,
'uBigInt', all_types.u_big_int,
'tinyIntPtr', all_types.tiny_int_ptr,
'uTinyIntPtr', all_types.u_tiny_int_ptr,
'smallIntPtr', all_types.small_int_ptr,
'uSmallIntPtr', all_types.u_small_int_ptr,
'mediumIntPtr', all_types.medium_int_ptr,
'uMediumIntPtr', all_types.u_medium_int_ptr,
'integerPtr', all_types.integer_ptr,
'uIntegerPtr', all_types.u_integer_ptr,
'bigIntPtr', all_types.big_int_ptr,
'uBigIntPtr', all_types.u_big_int_ptr,
'decimal', all_types.''decimal'',
'decimalPtr', all_types.decimal_ptr,
'numeric', all_types.''numeric'',
'numericPtr', all_types.numeric_ptr,
'float', all_types.''float'',
'floatPtr', all_types.float_ptr,
'double', all_types.''double'',
'doublePtr', all_types.double_ptr,
'real', all_types.''real'',
'realPtr', all_types.real_ptr,
'time', CONCAT('0000-01-01T', DATE_FORMAT(all_types.time,'%H:%i:%s.%fZ')),
'timePtr', CONCAT('0000-01-01T', DATE_FORMAT(all_types.time_ptr,'%H:%i:%s.%fZ')),
'date', CONCAT(DATE_FORMAT(all_types.date,'%Y-%m-%d'), 'T00:00:00Z'),
'datePtr', CONCAT(DATE_FORMAT(all_types.date_ptr,'%Y-%m-%d'), 'T00:00:00Z'),
'dateTime', DATE_FORMAT(all_types.date_time,'%Y-%m-%dT%H:%i:%s.%fZ'),
'dateTimePtr', DATE_FORMAT(all_types.date_time_ptr,'%Y-%m-%dT%H:%i:%s.%fZ'),
'timestamp', DATE_FORMAT(all_types.timestamp,'%Y-%m-%dT%H:%i:%s.%fZ'),
'timestampPtr', DATE_FORMAT(all_types.timestamp_ptr,'%Y-%m-%dT%H:%i:%s.%fZ'),
'year', all_types.year,
'yearPtr', all_types.year_ptr,
'char', all_types.''char'',
'charPtr', all_types.char_ptr,
'varChar', all_types.var_char,
'varCharPtr', all_types.var_char_ptr,
'binary', TO_BASE64(all_types.''binary''),
'binaryPtr', TO_BASE64(all_types.binary_ptr),
'varBinary', TO_BASE64(all_types.var_binary),
'varBinaryPtr', TO_BASE64(all_types.var_binary_ptr),
'blob', TO_BASE64(all_types.''blob''),
'blobPtr', TO_BASE64(all_types.blob_ptr),
'text', all_types.text,
'textPtr', all_types.text_ptr,
'enum', all_types.enum,
'enumPtr', all_types.enum_ptr,
'set', all_types.''set'',
'setPtr', all_types.set_ptr,
'Json', CAST(all_types.json AS CHAR),
'JsonPtr', CAST(all_types.json_ptr AS CHAR),
'Bit', CAST(all_types.bit AS CHAR),
'BitPtr', CAST(all_types.bit_ptr AS CHAR)
)) AS "json"
FROM test_sample.all_types;
`, "''", "`"))
var dest []model.AllTypes
err := stmt.QueryContext(ctx, db, &dest)
require.NoError(t, err)
// fix float rounding lost before comparison
dest[0].Float = 3.33
dest[0].FloatPtr = ptr.Of(3.33)
dest[1].Float = 3.33
testutils.AssertJSON(t, dest, allTypesJson)
}
func TestAllTypesViewSelect(t *testing.T) { func TestAllTypesViewSelect(t *testing.T) {
type AllTypesView model.AllTypes type AllTypesView model.AllTypes
@ -467,7 +563,8 @@ func TestStringOperators(t *testing.T) {
RTRIM(AllTypes.VarCharPtr), RTRIM(AllTypes.VarCharPtr),
CONCAT(String("string1"), Int(1), Float(11.12)), CONCAT(String("string1"), Int(1), Float(11.12)),
CONCAT_WS(String("string1"), Int(1), Float(11.12)), CONCAT_WS(String("string1"), Int(1), Float(11.12)),
FORMAT(String("Hello %s, %1$s"), String("World")), FORMAT(Int(11), Int(2)),
FORMAT(Int(11), Int(2), String("de_DE")),
LEFT(String("abcde"), Int(2)), LEFT(String("abcde"), Int(2)),
RIGHT(String("abcde"), Int(2)), RIGHT(String("abcde"), Int(2)),
LENGTH(String("jose")), LENGTH(String("jose")),
@ -479,6 +576,12 @@ func TestStringOperators(t *testing.T) {
REVERSE(AllTypes.VarCharPtr), REVERSE(AllTypes.VarCharPtr),
SUBSTR(AllTypes.CharPtr, Int(3)), SUBSTR(AllTypes.CharPtr, Int(3)),
SUBSTR(AllTypes.CharPtr, Int(3), Int(2)), SUBSTR(AllTypes.CharPtr, Int(3), Int(2)),
ELT(Int(2), AllTypes.CharPtr, AllTypes.Char, AllTypes.Text),
FIELD(AllTypes.Char, AllTypes.VarChar, AllTypes.Text),
FROM_BASE64(String("SGVsbG8gV29ybGQ=")),
TO_BASE64(String("Hello World")),
CHARSET(AllTypes.Char),
COLLATION(AllTypes.Text),
} }
if !sourceIsMariaDB() { if !sourceIsMariaDB() {
@ -500,6 +603,71 @@ func TestStringOperators(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
func TestBlob(t *testing.T) {
var sampleBlob = Blob([]byte{11, 0, 22, 33, 44})
var textBlob = Blob([]byte("text blob"))
stmt := SELECT(
AllTypes.BlobPtr.EQ(sampleBlob),
AllTypes.BlobPtr.EQ(AllTypes.BlobPtr),
AllTypes.BlobPtr.NOT_EQ(sampleBlob),
AllTypes.BlobPtr.GT(textBlob),
AllTypes.BlobPtr.GT_EQ(AllTypes.BlobPtr),
AllTypes.BlobPtr.LT(AllTypes.BlobPtr),
AllTypes.BlobPtr.LT_EQ(sampleBlob),
AllTypes.BlobPtr.BETWEEN(Blob([]byte("min")), Blob([]byte("max"))),
AllTypes.BlobPtr.NOT_BETWEEN(AllTypes.BlobPtr, AllTypes.BlobPtr),
AllTypes.BlobPtr.CONCAT(textBlob),
AllTypes.BlobPtr.LIKE(AllTypes.BlobPtr),
AllTypes.BlobPtr.NOT_LIKE(sampleBlob),
BIT_LENGTH(textBlob),
LENGTH(sampleBlob),
CHAR_LENGTH(AllTypes.BlobPtr),
OCTET_LENGTH(textBlob),
CONCAT(sampleBlob, Int(1), Float(11.12)),
TO_BASE64(sampleBlob),
HEX(sampleBlob),
UNHEX(String("616B263A")),
SUBSTR(AllTypes.BlobPtr, Int(3)),
SUBSTR(AllTypes.BlobPtr, Int(3), Int(2)),
).FROM(
AllTypes,
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT all_types.blob_ptr = X'0b0016212c',
all_types.blob_ptr = all_types.blob_ptr,
all_types.blob_ptr != X'0b0016212c',
all_types.blob_ptr > X'7465787420626c6f62',
all_types.blob_ptr >= all_types.blob_ptr,
all_types.blob_ptr < all_types.blob_ptr,
all_types.blob_ptr <= X'0b0016212c',
all_types.blob_ptr BETWEEN X'6d696e' AND X'6d6178',
all_types.blob_ptr NOT BETWEEN all_types.blob_ptr AND all_types.blob_ptr,
CONCAT(all_types.blob_ptr, X'7465787420626c6f62'),
all_types.blob_ptr LIKE all_types.blob_ptr,
all_types.blob_ptr NOT LIKE X'0b0016212c',
BIT_LENGTH(X'7465787420626c6f62'),
LENGTH(X'0b0016212c'),
CHAR_LENGTH(all_types.blob_ptr),
OCTET_LENGTH(X'7465787420626c6f62'),
CONCAT(X'0b0016212c', 1, 11.12),
TO_BASE64(X'0b0016212c'),
HEX(X'0b0016212c'),
UNHEX('616B263A'),
SUBSTR(all_types.blob_ptr, 3),
SUBSTR(all_types.blob_ptr, 3, 2)
FROM test_sample.all_types;
`)
var dest []struct{}
err := stmt.Query(db, &dest)
require.NoError(t, err)
}
var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC) var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
func TestTimeExpressions(t *testing.T) { func TestTimeExpressions(t *testing.T) {
@ -1066,6 +1234,118 @@ func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
func TestAllTypesSubQueryFrom(t *testing.T) {
subQuery := SELECT(
AllTypes.Boolean,
AllTypes.Integer,
AllTypes.Double,
AllTypes.Text,
AllTypes.Date,
AllTypes.Time,
AllTypes.Timestamp,
AllTypes.Blob,
).FROM(
AllTypes,
).AsTable("sub_query")
stmt := SELECT(
AllTypes.Boolean.From(subQuery),
AllTypes.Integer.From(subQuery),
AllTypes.Double.From(subQuery),
AllTypes.Text.From(subQuery),
AllTypes.Date.From(subQuery),
AllTypes.Time.From(subQuery),
AllTypes.Timestamp.From(subQuery),
AllTypes.Blob.From(subQuery),
).FROM(
subQuery,
)
testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(`
SELECT sub_query.''all_types.boolean'' AS "all_types.boolean",
sub_query.''all_types.integer'' AS "all_types.integer",
sub_query.''all_types.double'' AS "all_types.double",
sub_query.''all_types.text'' AS "all_types.text",
sub_query.''all_types.date'' AS "all_types.date",
sub_query.''all_types.time'' AS "all_types.time",
sub_query.''all_types.timestamp'' AS "all_types.timestamp",
sub_query.''all_types.blob'' AS "all_types.blob"
FROM (
SELECT all_types.boolean AS "all_types.boolean",
all_types.''integer'' AS "all_types.integer",
all_types.''double'' AS "all_types.double",
all_types.text AS "all_types.text",
all_types.date AS "all_types.date",
all_types.time AS "all_types.time",
all_types.timestamp AS "all_types.timestamp",
all_types.''blob'' AS "all_types.blob"
FROM test_sample.all_types
) AS sub_query;
`, "''", "`"))
var dest []model.AllTypes
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.NotEmpty(t, dest)
t.Run("using SELECT_JSON", func(t *testing.T) {
stmtJson := SELECT_JSON_ARR(
AllTypes.Boolean.From(subQuery),
AllTypes.Integer.From(subQuery),
AllTypes.Double.From(subQuery),
AllTypes.Text.From(subQuery),
AllTypes.Date.From(subQuery),
AllTypes.Time.From(subQuery),
AllTypes.Timestamp.From(subQuery),
AllTypes.Blob.From(subQuery),
).FROM(
subQuery,
)
testutils.AssertDebugStatementSql(t, stmtJson, strings.ReplaceAll(`
SELECT JSON_ARRAYAGG(JSON_OBJECT(
'boolean', sub_query.''all_types.boolean'' = 1,
'integer', sub_query.''all_types.integer'',
'double', sub_query.''all_types.double'',
'text', sub_query.''all_types.text'',
'date', CONCAT(DATE_FORMAT(sub_query.''all_types.date'','%Y-%m-%d'), 'T00:00:00Z'),
'time', CONCAT('0000-01-01T', DATE_FORMAT(sub_query.''all_types.time'','%H:%i:%s.%fZ')),
'timestamp', DATE_FORMAT(sub_query.''all_types.timestamp'','%Y-%m-%dT%H:%i:%s.%fZ'),
'blob', TO_BASE64(sub_query.''all_types.blob'')
)) AS "json"
FROM (
SELECT all_types.boolean AS "all_types.boolean",
all_types.''integer'' AS "all_types.integer",
all_types.''double'' AS "all_types.double",
all_types.text AS "all_types.text",
all_types.date AS "all_types.date",
all_types.time AS "all_types.time",
all_types.timestamp AS "all_types.timestamp",
all_types.''blob'' AS "all_types.blob"
FROM test_sample.all_types
) AS sub_query;
`, "''", "`"))
var destJson []model.AllTypes
err := stmtJson.QueryContext(ctx, db, &destJson)
require.NoError(t, err)
t.Run("using AllColumns()", func(t *testing.T) {
stmtJsonAllColumns := SELECT_JSON_ARR(
subQuery.AllColumns(),
).FROM(
subQuery,
)
require.Equal(t, stmtJson.DebugSql(), stmtJsonAllColumns.DebugSql())
})
testutils.AssertJsonEqual(t, dest, destJson)
})
}
var toInsert = model.AllTypes{ var toInsert = model.AllTypes{
Boolean: false, Boolean: false,
BooleanPtr: ptr.Of(true), BooleanPtr: ptr.Of(true),
@ -1131,6 +1411,7 @@ var toInsert = model.AllTypes{
var allTypesJson = ` var allTypesJson = `
[ [
{ {
"ID": 1,
"Boolean": false, "Boolean": false,
"BooleanPtr": true, "BooleanPtr": true,
"TinyInt": -3, "TinyInt": -3,
@ -1195,6 +1476,7 @@ var allTypesJson = `
"JSONPtr": "{\"key1\": \"value1\", \"key2\": \"value2\"}" "JSONPtr": "{\"key1\": \"value1\", \"key2\": \"value2\"}"
}, },
{ {
"ID": 2,
"Boolean": false, "Boolean": false,
"BooleanPtr": null, "BooleanPtr": null,
"TinyInt": -3, "TinyInt": -3,

138
tests/mysql/bench_test.go Normal file
View file

@ -0,0 +1,138 @@
//go:build bench
// +build bench
package mysql
import (
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table"
"github.com/stretchr/testify/require"
"testing"
)
type allInfo []struct {
model.Actor
Films []struct {
model.Film
Language model.Language
Categories []model.Category
Inventories []struct {
model.Inventory
Rentals []struct {
model.Rental
Customer model.Customer
}
}
}
}
func BenchmarkTestDVDsJoinEverything(b *testing.B) {
for i := 0; i < b.N; i++ {
testDVDsJoinEverything(b)
}
}
func TestDVDsJoinEverything(t *testing.T) {
testDVDsJoinEverything(t)
}
func testDVDsJoinEverything(t require.TestingT) {
stmt := SELECT(
Actor.AllColumns,
Film.AllColumns,
Language.AllColumns,
Category.AllColumns,
Inventory.AllColumns,
Rental.AllColumns,
Customer.AllColumns,
).FROM(
Actor.
LEFT_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.ActorID)).
LEFT_JOIN(Film, Film.FilmID.EQ(FilmActor.FilmID)).
LEFT_JOIN(Language, Language.LanguageID.EQ(Film.LanguageID)).
LEFT_JOIN(FilmCategory, FilmCategory.FilmID.EQ(Film.FilmID)).
LEFT_JOIN(Category, Category.CategoryID.EQ(FilmCategory.CategoryID)).
LEFT_JOIN(Inventory, Inventory.FilmID.EQ(Film.FilmID)).
LEFT_JOIN(Rental, Rental.InventoryID.EQ(Inventory.InventoryID)).
LEFT_JOIN(Customer, Customer.CustomerID.EQ(Rental.CustomerID)),
).ORDER_BY(
Actor.ActorID.ASC(),
Film.FilmID.ASC(),
Category.CategoryID.ASC(),
Inventory.InventoryID.ASC(),
Rental.RentalID.ASC(),
)
var dest allInfo
err := stmt.Query(db, &dest)
require.NoError(t, err)
//testutils.SaveJSONFile(dest, "./testdata/results/mysql/dvds_join_everything.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/dvds_join_everything.json")
}
func BenchmarkTestDVDsJoinEverythingJSON(b *testing.B) {
for i := 0; i < b.N; i++ {
testDVDsJoinEverythingJSON(b)
}
}
func TestDVDsJoinEverythingJSON(t *testing.T) {
testDVDsJoinEverythingJSON(t)
}
func testDVDsJoinEverythingJSON(t require.TestingT) {
stmt := SELECT_JSON_ARR(
Actor.ActorID, Actor.FirstName, Actor.LastName, Actor.LastUpdate,
SELECT_JSON_ARR(
Film.AllColumns,
SELECT_JSON_OBJ(Language.AllColumns).
FROM(Language).
WHERE(Language.LanguageID.EQ(Film.LanguageID)).AS("Language"),
SELECT_JSON_ARR(Category.AllColumns).
FROM(Category.INNER_JOIN(FilmCategory, FilmCategory.CategoryID.EQ(Category.CategoryID))).
WHERE(FilmCategory.FilmID.EQ(Film.FilmID)).AS("Categories"),
SELECT_JSON_ARR(
Inventory.AllColumns,
SELECT_JSON_ARR(
Rental.AllColumns,
SELECT_JSON_OBJ(Customer.AllColumns).
FROM(Customer).
WHERE(Customer.CustomerID.EQ(Rental.CustomerID)).AS("Customer"),
).FROM(Rental).
WHERE(Rental.InventoryID.EQ(Inventory.InventoryID)).
ORDER_BY(Rental.RentalID).AS("Rentals"),
).FROM(Inventory).
WHERE(Inventory.FilmID.EQ(Film.FilmID)).
ORDER_BY(Inventory.InventoryID).AS("Inventories"),
).FROM(Film.
INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(Film.FilmID)),
).WHERE(FilmActor.ActorID.EQ(Actor.ActorID)).
ORDER_BY(Film.FilmID.ASC()).AS("Films"),
).FROM(Actor).
ORDER_BY(Actor.ActorID.ASC())
//fmt.Println(stmt.DebugSql())
var dest allInfo
err := stmt.QueryContext(ctx, db, &dest)
require.NoError(t, err)
//testutils.SaveJSONFile(dest, "./testdata/results/mysql/dvds_join_everything2.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/dvds_join_everything.json")
}

View file

@ -3,6 +3,7 @@ package mysql
import ( import (
"os" "os"
"os/exec" "os/exec"
"path/filepath"
"strconv" "strconv"
"testing" "testing"
@ -304,7 +305,7 @@ func newLinkTableImpl(schemaName, tableName, alias string) linkTable {
DescriptionColumn = mysql.StringColumn("description") DescriptionColumn = mysql.StringColumn("description")
allColumns = mysql.ColumnList{IDColumn, URLColumn, NameColumn, DescriptionColumn} allColumns = mysql.ColumnList{IDColumn, URLColumn, NameColumn, DescriptionColumn}
mutableColumns = mysql.ColumnList{URLColumn, NameColumn, DescriptionColumn} mutableColumns = mysql.ColumnList{URLColumn, NameColumn, DescriptionColumn}
defaultColumns = mysql.ColumnList{DescriptionColumn} defaultColumns = mysql.ColumnList{}
) )
return linkTable{ return linkTable{
@ -606,3 +607,398 @@ func UseSchema(schema string) {
StaffList = StaffList.FromSchema(schema) StaffList = StaffList.FromSchema(schema)
} }
` `
func TestGeneratedTestSampleDatabase(t *testing.T) {
enumDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/enum/")
modelDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/model/")
tableDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/table/")
viewDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/view/")
testutils.AssertFileNamesEqual(t, enumDir, "all_types_enum.go", "all_types_enum_ptr.go",
"all_types_view_enum.go", "all_types_view_enum_ptr.go")
testutils.AssertFileContent(t, enumDir+"/all_types_enum.go", allTypesEnum)
testutils.AssertFileNamesEqual(t, modelDir, "all_types.go", "all_types_enum.go", "all_types_enum_ptr.go",
"all_types_view.go", "all_types_view_enum.go", "all_types_view_enum_ptr.go", "link.go", "link2.go",
"floats.go", "user.go")
testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent)
testutils.AssertFileNamesEqual(t, tableDir, "all_types.go",
"link.go", "link2.go", "user.go", "floats.go", "table_use_schema.go")
testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent)
testutils.AssertFileNamesEqual(t, viewDir, "all_types_view.go", "view_use_schema.go")
}
var allTypesEnum = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package enum
import "github.com/go-jet/jet/v2/mysql"
var AllTypesEnum = &struct {
Value1 mysql.StringExpression
Value2 mysql.StringExpression
Value3 mysql.StringExpression
}{
Value1: mysql.NewEnumValue("value1"),
Value2: mysql.NewEnumValue("value2"),
Value3: mysql.NewEnumValue("value3"),
}
`
var allTypesModelContent = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
import (
"time"
)
type AllTypes struct {
ID int32 ` + "`" + `sql:"primary_key"` + "`" + `
Boolean bool
BooleanPtr *bool
TinyInt int8
UTinyInt uint8
SmallInt int16
USmallInt uint16
MediumInt int32
UMediumInt uint32
Integer int32
UInteger uint32
BigInt int64
UBigInt uint64
TinyIntPtr *int8
UTinyIntPtr *uint8
SmallIntPtr *int16
USmallIntPtr *uint16
MediumIntPtr *int32
UMediumIntPtr *uint32
IntegerPtr *int32
UIntegerPtr *uint32
BigIntPtr *int64
UBigIntPtr *uint64
Decimal float64
DecimalPtr *float64
Numeric float64
NumericPtr *float64
Float float64
FloatPtr *float64
Double float64
DoublePtr *float64
Real float64
RealPtr *float64
Bit string
BitPtr *string
Time time.Time
TimePtr *time.Time
Date time.Time
DatePtr *time.Time
DateTime time.Time
DateTimePtr *time.Time
Timestamp time.Time
TimestampPtr *time.Time
Year int16
YearPtr *int16
Char string
CharPtr *string
VarChar string
VarCharPtr *string
Binary []byte
BinaryPtr *[]byte
VarBinary []byte
VarBinaryPtr *[]byte
Blob []byte
BlobPtr *[]byte
Text string
TextPtr *string
Enum AllTypesEnum
EnumPtr *AllTypesEnumPtr
Set string
SetPtr *string
JSON string
JSONPtr *string
}
`
var allTypesTableContent = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/mysql"
)
var AllTypes = newAllTypesTable("test_sample", "all_types", "")
type allTypesTable struct {
mysql.Table
// Columns
ID mysql.ColumnInteger
Boolean mysql.ColumnBool
BooleanPtr mysql.ColumnBool
TinyInt mysql.ColumnInteger
UTinyInt mysql.ColumnInteger
SmallInt mysql.ColumnInteger
USmallInt mysql.ColumnInteger
MediumInt mysql.ColumnInteger
UMediumInt mysql.ColumnInteger
Integer mysql.ColumnInteger
UInteger mysql.ColumnInteger
BigInt mysql.ColumnInteger
UBigInt mysql.ColumnInteger
TinyIntPtr mysql.ColumnInteger
UTinyIntPtr mysql.ColumnInteger
SmallIntPtr mysql.ColumnInteger
USmallIntPtr mysql.ColumnInteger
MediumIntPtr mysql.ColumnInteger
UMediumIntPtr mysql.ColumnInteger
IntegerPtr mysql.ColumnInteger
UIntegerPtr mysql.ColumnInteger
BigIntPtr mysql.ColumnInteger
UBigIntPtr mysql.ColumnInteger
Decimal mysql.ColumnFloat
DecimalPtr mysql.ColumnFloat
Numeric mysql.ColumnFloat
NumericPtr mysql.ColumnFloat
Float mysql.ColumnFloat
FloatPtr mysql.ColumnFloat
Double mysql.ColumnFloat
DoublePtr mysql.ColumnFloat
Real mysql.ColumnFloat
RealPtr mysql.ColumnFloat
Bit mysql.ColumnString
BitPtr mysql.ColumnString
Time mysql.ColumnTime
TimePtr mysql.ColumnTime
Date mysql.ColumnDate
DatePtr mysql.ColumnDate
DateTime mysql.ColumnTimestamp
DateTimePtr mysql.ColumnTimestamp
Timestamp mysql.ColumnTimestamp
TimestampPtr mysql.ColumnTimestamp
Year mysql.ColumnInteger
YearPtr mysql.ColumnInteger
Char mysql.ColumnString
CharPtr mysql.ColumnString
VarChar mysql.ColumnString
VarCharPtr mysql.ColumnString
Binary mysql.ColumnBlob
BinaryPtr mysql.ColumnBlob
VarBinary mysql.ColumnBlob
VarBinaryPtr mysql.ColumnBlob
Blob mysql.ColumnBlob
BlobPtr mysql.ColumnBlob
Text mysql.ColumnString
TextPtr mysql.ColumnString
Enum mysql.ColumnString
EnumPtr mysql.ColumnString
Set mysql.ColumnString
SetPtr mysql.ColumnString
JSON mysql.ColumnString
JSONPtr mysql.ColumnString
AllColumns mysql.ColumnList
MutableColumns mysql.ColumnList
DefaultColumns mysql.ColumnList
}
type AllTypesTable struct {
allTypesTable
NEW allTypesTable
}
// AS creates new AllTypesTable with assigned alias
func (a AllTypesTable) AS(alias string) *AllTypesTable {
return newAllTypesTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new AllTypesTable with assigned schema name
func (a AllTypesTable) FromSchema(schemaName string) *AllTypesTable {
return newAllTypesTable(schemaName, a.TableName(), a.Alias())
}
// WithPrefix creates new AllTypesTable with assigned table prefix
func (a AllTypesTable) WithPrefix(prefix string) *AllTypesTable {
return newAllTypesTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
}
// WithSuffix creates new AllTypesTable with assigned table suffix
func (a AllTypesTable) WithSuffix(suffix string) *AllTypesTable {
return newAllTypesTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
}
func newAllTypesTable(schemaName, tableName, alias string) *AllTypesTable {
return &AllTypesTable{
allTypesTable: newAllTypesTableImpl(schemaName, tableName, alias),
NEW: newAllTypesTableImpl("", "new", ""),
}
}
func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable {
var (
IDColumn = mysql.IntegerColumn("id")
BooleanColumn = mysql.BoolColumn("boolean")
BooleanPtrColumn = mysql.BoolColumn("boolean_ptr")
TinyIntColumn = mysql.IntegerColumn("tiny_int")
UTinyIntColumn = mysql.IntegerColumn("u_tiny_int")
SmallIntColumn = mysql.IntegerColumn("small_int")
USmallIntColumn = mysql.IntegerColumn("u_small_int")
MediumIntColumn = mysql.IntegerColumn("medium_int")
UMediumIntColumn = mysql.IntegerColumn("u_medium_int")
IntegerColumn = mysql.IntegerColumn("integer")
UIntegerColumn = mysql.IntegerColumn("u_integer")
BigIntColumn = mysql.IntegerColumn("big_int")
UBigIntColumn = mysql.IntegerColumn("u_big_int")
TinyIntPtrColumn = mysql.IntegerColumn("tiny_int_ptr")
UTinyIntPtrColumn = mysql.IntegerColumn("u_tiny_int_ptr")
SmallIntPtrColumn = mysql.IntegerColumn("small_int_ptr")
USmallIntPtrColumn = mysql.IntegerColumn("u_small_int_ptr")
MediumIntPtrColumn = mysql.IntegerColumn("medium_int_ptr")
UMediumIntPtrColumn = mysql.IntegerColumn("u_medium_int_ptr")
IntegerPtrColumn = mysql.IntegerColumn("integer_ptr")
UIntegerPtrColumn = mysql.IntegerColumn("u_integer_ptr")
BigIntPtrColumn = mysql.IntegerColumn("big_int_ptr")
UBigIntPtrColumn = mysql.IntegerColumn("u_big_int_ptr")
DecimalColumn = mysql.FloatColumn("decimal")
DecimalPtrColumn = mysql.FloatColumn("decimal_ptr")
NumericColumn = mysql.FloatColumn("numeric")
NumericPtrColumn = mysql.FloatColumn("numeric_ptr")
FloatColumn = mysql.FloatColumn("float")
FloatPtrColumn = mysql.FloatColumn("float_ptr")
DoubleColumn = mysql.FloatColumn("double")
DoublePtrColumn = mysql.FloatColumn("double_ptr")
RealColumn = mysql.FloatColumn("real")
RealPtrColumn = mysql.FloatColumn("real_ptr")
BitColumn = mysql.StringColumn("bit")
BitPtrColumn = mysql.StringColumn("bit_ptr")
TimeColumn = mysql.TimeColumn("time")
TimePtrColumn = mysql.TimeColumn("time_ptr")
DateColumn = mysql.DateColumn("date")
DatePtrColumn = mysql.DateColumn("date_ptr")
DateTimeColumn = mysql.TimestampColumn("date_time")
DateTimePtrColumn = mysql.TimestampColumn("date_time_ptr")
TimestampColumn = mysql.TimestampColumn("timestamp")
TimestampPtrColumn = mysql.TimestampColumn("timestamp_ptr")
YearColumn = mysql.IntegerColumn("year")
YearPtrColumn = mysql.IntegerColumn("year_ptr")
CharColumn = mysql.StringColumn("char")
CharPtrColumn = mysql.StringColumn("char_ptr")
VarCharColumn = mysql.StringColumn("var_char")
VarCharPtrColumn = mysql.StringColumn("var_char_ptr")
BinaryColumn = mysql.BlobColumn("binary")
BinaryPtrColumn = mysql.BlobColumn("binary_ptr")
VarBinaryColumn = mysql.BlobColumn("var_binary")
VarBinaryPtrColumn = mysql.BlobColumn("var_binary_ptr")
BlobColumn = mysql.BlobColumn("blob")
BlobPtrColumn = mysql.BlobColumn("blob_ptr")
TextColumn = mysql.StringColumn("text")
TextPtrColumn = mysql.StringColumn("text_ptr")
EnumColumn = mysql.StringColumn("enum")
EnumPtrColumn = mysql.StringColumn("enum_ptr")
SetColumn = mysql.StringColumn("set")
SetPtrColumn = mysql.StringColumn("set_ptr")
JSONColumn = mysql.StringColumn("json")
JSONPtrColumn = mysql.StringColumn("json_ptr")
allColumns = mysql.ColumnList{IDColumn, BooleanColumn, BooleanPtrColumn, TinyIntColumn, UTinyIntColumn, SmallIntColumn, USmallIntColumn, MediumIntColumn, UMediumIntColumn, IntegerColumn, UIntegerColumn, BigIntColumn, UBigIntColumn, TinyIntPtrColumn, UTinyIntPtrColumn, SmallIntPtrColumn, USmallIntPtrColumn, MediumIntPtrColumn, UMediumIntPtrColumn, IntegerPtrColumn, UIntegerPtrColumn, BigIntPtrColumn, UBigIntPtrColumn, DecimalColumn, DecimalPtrColumn, NumericColumn, NumericPtrColumn, FloatColumn, FloatPtrColumn, DoubleColumn, DoublePtrColumn, RealColumn, RealPtrColumn, BitColumn, BitPtrColumn, TimeColumn, TimePtrColumn, DateColumn, DatePtrColumn, DateTimeColumn, DateTimePtrColumn, TimestampColumn, TimestampPtrColumn, YearColumn, YearPtrColumn, CharColumn, CharPtrColumn, VarCharColumn, VarCharPtrColumn, BinaryColumn, BinaryPtrColumn, VarBinaryColumn, VarBinaryPtrColumn, BlobColumn, BlobPtrColumn, TextColumn, TextPtrColumn, EnumColumn, EnumPtrColumn, SetColumn, SetPtrColumn, JSONColumn, JSONPtrColumn}
mutableColumns = mysql.ColumnList{BooleanColumn, BooleanPtrColumn, TinyIntColumn, UTinyIntColumn, SmallIntColumn, USmallIntColumn, MediumIntColumn, UMediumIntColumn, IntegerColumn, UIntegerColumn, BigIntColumn, UBigIntColumn, TinyIntPtrColumn, UTinyIntPtrColumn, SmallIntPtrColumn, USmallIntPtrColumn, MediumIntPtrColumn, UMediumIntPtrColumn, IntegerPtrColumn, UIntegerPtrColumn, BigIntPtrColumn, UBigIntPtrColumn, DecimalColumn, DecimalPtrColumn, NumericColumn, NumericPtrColumn, FloatColumn, FloatPtrColumn, DoubleColumn, DoublePtrColumn, RealColumn, RealPtrColumn, BitColumn, BitPtrColumn, TimeColumn, TimePtrColumn, DateColumn, DatePtrColumn, DateTimeColumn, DateTimePtrColumn, TimestampColumn, TimestampPtrColumn, YearColumn, YearPtrColumn, CharColumn, CharPtrColumn, VarCharColumn, VarCharPtrColumn, BinaryColumn, BinaryPtrColumn, VarBinaryColumn, VarBinaryPtrColumn, BlobColumn, BlobPtrColumn, TextColumn, TextPtrColumn, EnumColumn, EnumPtrColumn, SetColumn, SetPtrColumn, JSONColumn, JSONPtrColumn}
defaultColumns = mysql.ColumnList{BooleanColumn, TinyIntColumn, UTinyIntColumn, SmallIntColumn, USmallIntColumn, MediumIntColumn, UMediumIntColumn, IntegerColumn, UIntegerColumn, BigIntColumn, UBigIntColumn, DecimalColumn, NumericColumn, FloatColumn, DoubleColumn, RealColumn, BitColumn, TimeColumn, DateColumn, DateTimeColumn, TimestampColumn, YearColumn, CharColumn, VarCharColumn, BinaryColumn, VarBinaryColumn, EnumColumn, SetColumn}
)
return allTypesTable{
Table: mysql.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
ID: IDColumn,
Boolean: BooleanColumn,
BooleanPtr: BooleanPtrColumn,
TinyInt: TinyIntColumn,
UTinyInt: UTinyIntColumn,
SmallInt: SmallIntColumn,
USmallInt: USmallIntColumn,
MediumInt: MediumIntColumn,
UMediumInt: UMediumIntColumn,
Integer: IntegerColumn,
UInteger: UIntegerColumn,
BigInt: BigIntColumn,
UBigInt: UBigIntColumn,
TinyIntPtr: TinyIntPtrColumn,
UTinyIntPtr: UTinyIntPtrColumn,
SmallIntPtr: SmallIntPtrColumn,
USmallIntPtr: USmallIntPtrColumn,
MediumIntPtr: MediumIntPtrColumn,
UMediumIntPtr: UMediumIntPtrColumn,
IntegerPtr: IntegerPtrColumn,
UIntegerPtr: UIntegerPtrColumn,
BigIntPtr: BigIntPtrColumn,
UBigIntPtr: UBigIntPtrColumn,
Decimal: DecimalColumn,
DecimalPtr: DecimalPtrColumn,
Numeric: NumericColumn,
NumericPtr: NumericPtrColumn,
Float: FloatColumn,
FloatPtr: FloatPtrColumn,
Double: DoubleColumn,
DoublePtr: DoublePtrColumn,
Real: RealColumn,
RealPtr: RealPtrColumn,
Bit: BitColumn,
BitPtr: BitPtrColumn,
Time: TimeColumn,
TimePtr: TimePtrColumn,
Date: DateColumn,
DatePtr: DatePtrColumn,
DateTime: DateTimeColumn,
DateTimePtr: DateTimePtrColumn,
Timestamp: TimestampColumn,
TimestampPtr: TimestampPtrColumn,
Year: YearColumn,
YearPtr: YearPtrColumn,
Char: CharColumn,
CharPtr: CharPtrColumn,
VarChar: VarCharColumn,
VarCharPtr: VarCharPtrColumn,
Binary: BinaryColumn,
BinaryPtr: BinaryPtrColumn,
VarBinary: VarBinaryColumn,
VarBinaryPtr: VarBinaryPtrColumn,
Blob: BlobColumn,
BlobPtr: BlobPtrColumn,
Text: TextColumn,
TextPtr: TextPtrColumn,
Enum: EnumColumn,
EnumPtr: EnumPtrColumn,
Set: SetColumn,
SetPtr: SetPtrColumn,
JSON: JSONColumn,
JSONPtr: JSONPtrColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}
`

View file

@ -8,6 +8,7 @@ import (
jetmysql "github.com/go-jet/jet/v2/mysql" jetmysql "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/stmtcache" "github.com/go-jet/jet/v2/stmtcache"
"github.com/go-jet/jet/v2/tests/dbconfig" "github.com/go-jet/jet/v2/tests/dbconfig"
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"runtime" "runtime"
@ -21,12 +22,14 @@ var db *stmtcache.DB
var source string var source string
var withStatementCaching bool var withStatementCaching bool
var testRoot string
const MariaDB = "MariaDB" const MariaDB = "MariaDB"
func init() { func init() {
source = os.Getenv("MY_SQL_SOURCE") source = os.Getenv("MY_SQL_SOURCE")
withStatementCaching = os.Getenv("JET_TESTS_WITH_STMT_CACHE") == "true" withStatementCaching = os.Getenv("JET_TESTS_WITH_STMT_CACHE") == "true"
testRoot = repo.GetTestsDirPath()
} }
func sourceIsMariaDB() bool { func sourceIsMariaDB() bool {

View file

@ -0,0 +1,453 @@
package mysql
import (
"context"
"fmt"
"github.com/go-jet/jet/v2/qrm"
"strings"
"testing"
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table"
"github.com/stretchr/testify/require"
)
var ctx = context.Background()
func TestSelectJsonObj(t *testing.T) {
stmt := SELECT_JSON_OBJ(Actor.AllColumns).
FROM(Actor).
WHERE(Actor.ActorID.EQ(Int(2)))
testutils.AssertStatementSql(t, stmt, `
SELECT JSON_OBJECT(
'actorID', actor.actor_id,
'firstName', actor.first_name,
'lastName', actor.last_name,
'lastUpdate', DATE_FORMAT(actor.last_update,'%Y-%m-%dT%H:%i:%s.%fZ')
) AS "json"
FROM dvds.actor
WHERE actor.actor_id = ?;
`, int64(2))
var dest model.Actor
err := stmt.Query(db, &dest)
require.Nil(t, err)
testutils.AssertDeepEqual(t, dest, actor2)
requireLogged(t, stmt)
requireQueryLogged(t, stmt, 1)
}
func TestSelectJsonObj_NestedObj(t *testing.T) {
stmt := SELECT_JSON_OBJ(
Actor.AllColumns,
SELECT_JSON_OBJ(Film.AllColumns).
FROM(FilmActor.INNER_JOIN(Film, Film.FilmID.EQ(FilmActor.FilmID))).
WHERE(Actor.ActorID.EQ(FilmActor.ActorID)).
ORDER_BY(Film.Length.DESC()).
LIMIT(1).OFFSET(3).AS("LongestFilm"),
).FROM(
Actor,
).WHERE(
Actor.ActorID.EQ(Int(2)),
)
testutils.AssertStatementSql(t, stmt, `
SELECT JSON_OBJECT(
'actorID', actor.actor_id,
'firstName', actor.first_name,
'lastName', actor.last_name,
'lastUpdate', DATE_FORMAT(actor.last_update,'%Y-%m-%dT%H:%i:%s.%fZ'),
'LongestFilm', (
SELECT JSON_OBJECT(
'filmID', film.film_id,
'title', film.title,
'description', film.description,
'releaseYear', film.release_year,
'languageID', film.language_id,
'originalLanguageID', film.original_language_id,
'rentalDuration', film.rental_duration,
'rentalRate', film.rental_rate,
'length', film.length,
'replacementCost', film.replacement_cost,
'rating', film.rating,
'specialFeatures', film.special_features,
'lastUpdate', DATE_FORMAT(film.last_update,'%Y-%m-%dT%H:%i:%s.%fZ')
) AS "json"
FROM dvds.film_actor
INNER JOIN dvds.film ON (film.film_id = film_actor.film_id)
WHERE actor.actor_id = film_actor.actor_id
ORDER BY film.length DESC
LIMIT ?
OFFSET ?
)
) AS "json"
FROM dvds.actor
WHERE actor.actor_id = ?;
`)
var dest struct {
model.Actor
LongestFilm model.Film
}
err := stmt.QueryContext(ctx, db, &dest)
require.Nil(t, err)
testutils.AssertJSON(t, dest, `
{
"ActorID": 2,
"FirstName": "NICK",
"LastName": "WAHLBERG",
"LastUpdate": "2006-02-15T04:34:33Z",
"LongestFilm": {
"FilmID": 754,
"Title": "RUSHMORE MERMAID",
"Description": "A Boring Story of a Woman And a Moose who must Reach a Husband in A Shark Tank",
"ReleaseYear": 2006,
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 6,
"RentalRate": 2.99,
"Length": 150,
"ReplacementCost": 17.99,
"Rating": "PG-13",
"SpecialFeatures": "Trailers,Commentaries,Deleted Scenes",
"LastUpdate": "2006-02-15T05:03:42Z"
}
}
`)
}
func TestSelectJsonArr(t *testing.T) {
stmt := SELECT_JSON_ARR(Actor.AllColumns).
FROM(Actor).
ORDER_BY(Actor.ActorID)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT JSON_ARRAYAGG(JSON_OBJECT(
'actorID', actor.actor_id,
'firstName', actor.first_name,
'lastName', actor.last_name,
'lastUpdate', DATE_FORMAT(actor.last_update,'%Y-%m-%dT%H:%i:%s.%fZ')
)) AS "json"
FROM dvds.actor
ORDER BY actor.actor_id;
`)
var dest []model.Actor
err := stmt.Query(db, &dest)
require.Nil(t, err)
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/all_actors.json")
requireLogged(t, stmt)
requireQueryLogged(t, stmt, 1)
}
func TestSelectJsonArr_NestedArr(t *testing.T) {
stmt := SELECT_JSON_ARR(
Actor.AllColumns,
SELECT_JSON_ARR(
Film.AllColumns,
).FROM(
FilmActor.INNER_JOIN(
Film,
Film.FilmID.EQ(FilmActor.FilmID).AND(
Actor.ActorID.EQ(FilmActor.ActorID)),
),
).WHERE(
Film.FilmID.MOD(Int(17)).EQ(Int(0)),
).ORDER_BY(
Film.Length.DESC(),
).AS("Films"),
).FROM(
Actor,
).WHERE(
Actor.ActorID.BETWEEN(Int(1), Int(3)),
).ORDER_BY(
Actor.ActorID,
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT JSON_ARRAYAGG(JSON_OBJECT(
'actorID', actor.actor_id,
'firstName', actor.first_name,
'lastName', actor.last_name,
'lastUpdate', DATE_FORMAT(actor.last_update,'%Y-%m-%dT%H:%i:%s.%fZ'),
'Films', (
SELECT JSON_ARRAYAGG(JSON_OBJECT(
'filmID', film.film_id,
'title', film.title,
'description', film.description,
'releaseYear', film.release_year,
'languageID', film.language_id,
'originalLanguageID', film.original_language_id,
'rentalDuration', film.rental_duration,
'rentalRate', film.rental_rate,
'length', film.length,
'replacementCost', film.replacement_cost,
'rating', film.rating,
'specialFeatures', film.special_features,
'lastUpdate', DATE_FORMAT(film.last_update,'%Y-%m-%dT%H:%i:%s.%fZ')
)) AS "json"
FROM dvds.film_actor
INNER JOIN dvds.film ON ((film.film_id = film_actor.film_id) AND (actor.actor_id = film_actor.actor_id))
WHERE (film.film_id % 17) = 0
ORDER BY film.length DESC
)
)) AS "json"
FROM dvds.actor
WHERE actor.actor_id BETWEEN 1 AND 3
ORDER BY actor.actor_id;
`)
var dest []struct {
model.Actor
Films []model.Film
}
err := stmt.QueryContext(ctx, db, &dest)
fmt.Println(err)
require.Nil(t, err)
testutils.AssertJSON(t, dest, `
[
{
"ActorID": 1,
"FirstName": "PENELOPE",
"LastName": "GUINESS",
"LastUpdate": "2006-02-15T04:34:33Z",
"Films": null
},
{
"ActorID": 2,
"FirstName": "NICK",
"LastName": "WAHLBERG",
"LastUpdate": "2006-02-15T04:34:33Z",
"Films": [
{
"FilmID": 357,
"Title": "GILBERT PELICAN",
"Description": "A Fateful Tale of a Man And a Feminist who must Conquer a Crocodile in A Manhattan Penthouse",
"ReleaseYear": 2006,
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 7,
"RentalRate": 0.99,
"Length": 114,
"ReplacementCost": 13.99,
"Rating": "G",
"SpecialFeatures": "Trailers,Commentaries",
"LastUpdate": "2006-02-15T05:03:42Z"
},
{
"FilmID": 561,
"Title": "MASK PEACH",
"Description": "A Boring Character Study of a Student And a Robot who must Meet a Woman in California",
"ReleaseYear": 2006,
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 6,
"RentalRate": 2.99,
"Length": 123,
"ReplacementCost": 26.99,
"Rating": "NC-17",
"SpecialFeatures": "Commentaries,Deleted Scenes",
"LastUpdate": "2006-02-15T05:03:42Z"
}
]
},
{
"ActorID": 3,
"FirstName": "ED",
"LastName": "CHASE",
"LastUpdate": "2006-02-15T04:34:33Z",
"Films": [
{
"FilmID": 17,
"Title": "ALONE TRIP",
"Description": "A Fast-Paced Character Study of a Composer And a Dog who must Outgun a Boat in An Abandoned Fun House",
"ReleaseYear": 2006,
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 3,
"RentalRate": 0.99,
"Length": 82,
"ReplacementCost": 14.99,
"Rating": "R",
"SpecialFeatures": "Trailers,Behind the Scenes",
"LastUpdate": "2006-02-15T05:03:42Z"
},
{
"FilmID": 289,
"Title": "EVE RESURRECTION",
"Description": "A Awe-Inspiring Yarn of a Pastry Chef And a Database Administrator who must Challenge a Teacher in A Baloon",
"ReleaseYear": 2006,
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 5,
"RentalRate": 4.99,
"Length": 66,
"ReplacementCost": 25.99,
"Rating": "G",
"SpecialFeatures": "Trailers,Commentaries,Deleted Scenes",
"LastUpdate": "2006-02-15T05:03:42Z"
}
]
}
]
`)
}
func TestSelectJson_GroupBy(t *testing.T) {
skipForMariaDB(t) // scope issues with select without FROM
subQuery := SELECT(
Customer.AllColumns,
SUM(Payment.Amount).AS("sum"),
AVG(Payment.Amount).AS("avg"),
MAX(Payment.Amount).AS("max"),
MIN(Payment.Amount).AS("min"),
COUNT(Payment.Amount).AS("count"),
).FROM(
Payment.
INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)),
).GROUP_BY(
Customer.CustomerID,
).HAVING(
SUMf(Payment.Amount).GT(Float(125)),
).ORDER_BY(
Customer.CustomerID, SUM(Payment.Amount).ASC(),
).AsTable("customers_info")
stmt := SELECT_JSON_ARR(
subQuery.AllColumns().Except( // TODO: remove when ColumnList.From() is implemented
FloatColumn("sum"),
FloatColumn("avg"),
FloatColumn("max"),
FloatColumn("min"),
FloatColumn("count"),
),
SELECT_JSON_OBJ(
FloatColumn("sum").From(subQuery),
FloatColumn("avg").From(subQuery),
FloatColumn("max").From(subQuery),
FloatColumn("min").From(subQuery),
FloatColumn("count").From(subQuery),
).AS("amount"),
).FROM(subQuery)
testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(`
SELECT JSON_ARRAYAGG(JSON_OBJECT(
'customerID', customers_info.''customer.customer_id'',
'storeID', customers_info.''customer.store_id'',
'firstName', customers_info.''customer.first_name'',
'lastName', customers_info.''customer.last_name'',
'email', customers_info.''customer.email'',
'addressID', customers_info.''customer.address_id'',
'active', customers_info.''customer.active'' = 1,
'createDate', DATE_FORMAT(customers_info.''customer.create_date'','%Y-%m-%dT%H:%i:%s.%fZ'),
'lastUpdate', DATE_FORMAT(customers_info.''customer.last_update'','%Y-%m-%dT%H:%i:%s.%fZ'),
'amount', (
SELECT JSON_OBJECT(
'sum', customers_info.sum,
'avg', customers_info.avg,
'max', customers_info.max,
'min', customers_info.min,
'count', customers_info.count
) AS "json"
)
)) AS "json"
FROM (
SELECT customer.customer_id AS "customer.customer_id",
customer.store_id AS "customer.store_id",
customer.first_name AS "customer.first_name",
customer.last_name AS "customer.last_name",
customer.email AS "customer.email",
customer.address_id AS "customer.address_id",
customer.active AS "customer.active",
customer.create_date AS "customer.create_date",
customer.last_update AS "customer.last_update",
SUM(payment.amount) AS "sum",
AVG(payment.amount) AS "avg",
MAX(payment.amount) AS "max",
MIN(payment.amount) AS "min",
COUNT(payment.amount) AS "count"
FROM dvds.payment
INNER JOIN dvds.customer ON (customer.customer_id = payment.customer_id)
GROUP BY customer.customer_id
HAVING SUM(payment.amount) > 125
ORDER BY customer.customer_id, SUM(payment.amount) ASC
) AS customers_info;
`, "''", "`"))
var dest []struct {
model.Customer
Amount struct {
Sum float64
Avg float64
Max float64
Min float64
Count int64
}
}
err := stmt.QueryContext(ctx, db, &dest)
require.Nil(t, err)
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/customer_payment_sum.json")
requireLogged(t, stmt)
}
func TestSelectJsonObject_EmptyResult(t *testing.T) {
t.Run("json obj", func(t *testing.T) {
stmt := SELECT_JSON_OBJ(Actor.AllColumns).
FROM(Actor).
WHERE(Actor.FirstName.EQ(String("Kowalski")))
var dest model.Actor
err := stmt.QueryContext(ctx, db, &dest)
require.ErrorIs(t, err, qrm.ErrNoRows)
})
t.Run("json arr", func(t *testing.T) {
stmt := SELECT_JSON_ARR(Actor.AllColumns).
FROM(Actor).
WHERE(Actor.FirstName.EQ(String("Kowalski")))
var dest []model.Actor
err := stmt.QueryContext(ctx, db, &dest)
require.NoError(t, err)
require.Empty(t, dest)
})
}
func TestSelectJson_ProjectionNotAliased(t *testing.T) {
t.Run("expression not aliased", func(t *testing.T) {
testutils.AssertPanicErr(t, func() {
stmt := SELECT_JSON_ARR(
Int(2).ADD(Customer.CustomerID),
).FROM(Customer)
stmt.DebugSql()
}, "jet: expression need to be aliased when used as SELECT JSON projection.")
})
}

View file

@ -19,9 +19,9 @@ import (
) )
func TestSelect_ScanToStruct(t *testing.T) { func TestSelect_ScanToStruct(t *testing.T) {
query := Actor. query := SELECT(Actor.AllColumns).
SELECT(Actor.AllColumns).
DISTINCT(). DISTINCT().
FROM(Actor).
WHERE(Actor.ActorID.EQ(Int(2))) WHERE(Actor.ActorID.EQ(Int(2)))
testutils.AssertStatementSql(t, query, ` testutils.AssertStatementSql(t, query, `
@ -50,9 +50,56 @@ var actor2 = model.Actor{
LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 04:34:33", 2), LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 04:34:33", 2),
} }
func TestSelect_NestedObject(t *testing.T) {
stmt := SELECT(
Actor.AllColumns,
Film.AllColumns,
).FROM(
Actor.
LEFT_JOIN(FilmActor, FilmActor.ActorID.EQ(Actor.ActorID)).
LEFT_JOIN(Film, Film.FilmID.EQ(FilmActor.FilmID)),
).WHERE(
Actor.ActorID.EQ(Int(2)),
).ORDER_BY(
Film.LastUpdate.DESC(),
).LIMIT(1)
var dest struct {
model.Actor
LatestFilm model.Film
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
{
"ActorID": 2,
"FirstName": "NICK",
"LastName": "WAHLBERG",
"LastUpdate": "2006-02-15T04:34:33Z",
"LatestFilm": {
"FilmID": 3,
"Title": "ADAPTATION HOLES",
"Description": "A Astounding Reflection of a Lumberjack And a Car who must Sink a Lumberjack in A Baloon Factory",
"ReleaseYear": 2006,
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 7,
"RentalRate": 2.99,
"Length": 50,
"ReplacementCost": 18.99,
"Rating": "NC-17",
"SpecialFeatures": "Trailers,Deleted Scenes",
"LastUpdate": "2006-02-15T05:03:42Z"
}
}
`)
}
func TestSelect_ScanToSlice(t *testing.T) { func TestSelect_ScanToSlice(t *testing.T) {
query := Actor. query := SELECT(Actor.AllColumns).
SELECT(Actor.AllColumns). FROM(Actor).
ORDER_BY(Actor.ActorID) ORDER_BY(Actor.ActorID)
testutils.AssertStatementSql(t, query, ` testutils.AssertStatementSql(t, query, `
@ -107,9 +154,7 @@ GROUP BY payment.customer_id
HAVING SUM(payment.amount) > 125.6 HAVING SUM(payment.amount) > 125.6
ORDER BY payment.customer_id, SUM(payment.amount) ASC; ORDER BY payment.customer_id, SUM(payment.amount) ASC;
` `
query := Payment. query := SELECT(
INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)).
SELECT(
Customer.AllColumns, Customer.AllColumns,
SUMf(Payment.Amount).AS("amount.sum"), SUMf(Payment.Amount).AS("amount.sum"),
@ -119,6 +164,9 @@ ORDER BY payment.customer_id, SUM(payment.amount) ASC;
MIN(Payment.PaymentDate).AS("amount.min_date"), MIN(Payment.PaymentDate).AS("amount.min_date"),
MINf(Payment.Amount).AS("amount.min"), MINf(Payment.Amount).AS("amount.min"),
COUNT(Payment.Amount).AS("amount.count"), COUNT(Payment.Amount).AS("amount.count"),
).FROM(
Payment.
INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)),
). ).
GROUP_BY(Payment.CustomerID). GROUP_BY(Payment.CustomerID).
HAVING( HAVING(
@ -1122,7 +1170,7 @@ WHERE payment.payment_id < ?
WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id) WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id)
ORDER BY payment.customer_id; ORDER BY payment.customer_id;
` `
query := Payment.SELECT( query := SELECT(
AVG(Payment.Amount).OVER(), AVG(Payment.Amount).OVER(),
AVG(Payment.Amount).OVER(Window("w1")), AVG(Payment.Amount).OVER(Window("w1")),
AVG(Payment.Amount).OVER( AVG(Payment.Amount).OVER(
@ -1131,7 +1179,7 @@ ORDER BY payment.customer_id;
RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)), RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)),
), ),
AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))),
). ).FROM(Payment).
WHERE(Payment.PaymentID.LT(Int(10))). WHERE(Payment.PaymentID.LT(Int(10))).
WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)). WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)).
WINDOW("w2").AS(Window("w1")). WINDOW("w2").AS(Window("w1")).

View file

@ -1,6 +1,8 @@
package postgres package postgres
import ( import (
"encoding/base64"
"fmt"
"github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/go-jet/jet/v2/internal/utils/ptr"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"math" "math"
@ -36,6 +38,141 @@ func TestAllTypesSelect(t *testing.T) {
testutils.AssertDeepEqual(t, dest[1], allTypesRow1) testutils.AssertDeepEqual(t, dest[1], allTypesRow1)
} }
func TestAllTypesSelectJson(t *testing.T) {
stmt := SELECT_JSON_ARR(
AllTypesAllColumns.Except(
AllTypes.JSON, AllTypes.JSONPtr,
AllTypes.Jsonb, AllTypes.JsonbPtr,
AllTypes.TextArray, AllTypes.TextArrayPtr,
AllTypes.JsonbArray, AllTypes.IntegerArray, AllTypes.IntegerArrayPtr,
AllTypes.TextMultiDimArray, AllTypes.TextMultiDimArrayPtr,
),
// unsupported at the moment, casting to text allows these columns to be assigned to string fields
CAST(AllTypes.JSONPtr).AS_TEXT().AS("jsonPtr"),
CAST(AllTypes.JSON).AS_TEXT().AS("JSON"),
CAST(AllTypes.JsonbPtr).AS_TEXT().AS("jsonbPtr"),
CAST(AllTypes.Jsonb).AS_TEXT().AS("Jsonb"),
CAST(AllTypes.TextArrayPtr).AS_TEXT().AS("TextArrayPtr"),
CAST(AllTypes.TextArray).AS_TEXT().AS("TextArray"),
CAST(AllTypes.JsonbArray).AS_TEXT().AS("JsonbArray"),
CAST(AllTypes.IntegerArray).AS_TEXT().AS("IntegerArray"),
CAST(AllTypes.IntegerArrayPtr).AS_TEXT().AS("IntegerArrayPtr"),
CAST(AllTypes.TextMultiDimArray).AS_TEXT().AS("TextMultiDimArray"),
CAST(AllTypes.TextMultiDimArrayPtr).AS_TEXT().AS("TextMultiDimArrayPtr"),
).FROM(AllTypes)
testutils.AssertStatementSql(t, stmt, `
SELECT json_agg(row_to_json(records)) AS "json"
FROM (
SELECT all_types.small_int_ptr AS "smallIntPtr",
all_types.small_int AS "smallInt",
all_types.integer_ptr AS "integerPtr",
all_types.integer AS "integer",
all_types.big_int_ptr AS "bigIntPtr",
all_types.big_int AS "bigInt",
all_types.decimal_ptr AS "decimalPtr",
all_types.decimal AS "decimal",
all_types.numeric_ptr AS "numericPtr",
all_types.numeric AS "numeric",
all_types.real_ptr AS "realPtr",
all_types.real AS "real",
all_types.double_precision_ptr AS "doublePrecisionPtr",
all_types.double_precision AS "doublePrecision",
all_types.smallserial AS "smallserial",
all_types.serial AS "serial",
all_types.bigserial AS "bigserial",
all_types.var_char_ptr AS "varCharPtr",
all_types.var_char AS "varChar",
all_types.char_ptr AS "charPtr",
all_types.char AS "char",
all_types.text_ptr AS "textPtr",
all_types.text AS "text",
ENCODE(all_types.bytea_ptr, 'base64') AS "byteaPtr",
ENCODE(all_types.bytea, 'base64') AS "bytea",
all_types.timestampz_ptr AS "timestampzPtr",
all_types.timestampz AS "timestampz",
to_char(all_types.timestamp_ptr, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestampPtr",
to_char(all_types.timestamp, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestamp",
to_char(all_types.date_ptr::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "datePtr",
to_char(all_types.date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "date",
'0000-01-01T' || to_char('2000-10-10'::date + all_types.timez_ptr, 'HH24:MI:SS.USTZH:TZM') AS "timezPtr",
'0000-01-01T' || to_char('2000-10-10'::date + all_types.timez, 'HH24:MI:SS.USTZH:TZM') AS "timez",
'0000-01-01T' || to_char('2000-10-10'::date + all_types.time_ptr, 'HH24:MI:SS.USZ') AS "timePtr",
'0000-01-01T' || to_char('2000-10-10'::date + all_types.time, 'HH24:MI:SS.USZ') AS "time",
all_types.interval_ptr AS "intervalPtr",
all_types.interval AS "interval",
all_types.boolean_ptr AS "booleanPtr",
all_types.boolean AS "boolean",
all_types.point_ptr AS "pointPtr",
all_types.bit_ptr AS "bitPtr",
all_types.bit AS "bit",
all_types.bit_varying_ptr AS "bitVaryingPtr",
all_types.bit_varying AS "bitVarying",
all_types.tsvector_ptr AS "tsvectorPtr",
all_types.tsvector AS "tsvector",
all_types.uuid_ptr AS "uuidPtr",
all_types.uuid AS "uuid",
all_types.xml_ptr AS "xmlPtr",
all_types.xml AS "xml",
all_types.mood_ptr AS "moodPtr",
all_types.mood AS "mood",
all_types.json_ptr::text AS "jsonPtr",
all_types.json::text AS "JSON",
all_types.jsonb_ptr::text AS "jsonbPtr",
all_types.jsonb::text AS "Jsonb",
all_types.text_array_ptr::text AS "TextArrayPtr",
all_types.text_array::text AS "TextArray",
all_types.jsonb_array::text AS "JsonbArray",
all_types.integer_array::text AS "IntegerArray",
all_types.integer_array_ptr::text AS "IntegerArrayPtr",
all_types.text_multi_dim_array::text AS "TextMultiDimArray",
all_types.text_multi_dim_array_ptr::text AS "TextMultiDimArrayPtr"
FROM test_sample.all_types
) AS records;
`)
var dest []model.AllTypes
err := stmt.QueryContext(ctx, db, &dest)
require.NoError(t, err)
// fix inconsistencies between postgres and cockroachdb.
// cockroachdb returns char[N] columns with trailing whitespaces trimmed
if sourceIsCockroachDB() {
dest[0].Char = allTypesRow0.Char
dest[0].CharPtr = allTypesRow0.CharPtr
dest[1].Char = allTypesRow1.Char
dest[1].CharPtr = allTypesRow1.CharPtr
}
minus8 := time.FixedZone("UTC", -8*60*60)
plus1 := time.FixedZone("UTC", 60*60)
// set time local before comparison
dest[0].Timez = *toTZ(&dest[0].Timez, minus8)
dest[0].TimezPtr = toTZ(dest[0].TimezPtr, minus8)
dest[1].Timez = *toTZ(&dest[1].Timez, minus8)
dest[1].TimezPtr = toTZ(dest[1].TimezPtr, minus8)
dest[0].Timestampz = *toTZ(&dest[0].Timestampz, plus1)
dest[0].TimestampzPtr = toTZ(dest[0].TimestampzPtr, plus1)
dest[1].Timestampz = *toTZ(&dest[1].Timestampz, plus1)
dest[1].TimestampzPtr = toTZ(dest[1].TimestampzPtr, plus1)
testutils.AssertJsonEqual(t, dest[0], allTypesRow0)
testutils.AssertJsonEqual(t, dest[1], allTypesRow1)
}
func toTZ(tm *time.Time, loc *time.Location) *time.Time {
if tm == nil {
return nil
}
return ptr.Of(tm.In(loc))
}
func TestAllTypesViewSelect(t *testing.T) { func TestAllTypesViewSelect(t *testing.T) {
type AllTypesView model.AllTypes type AllTypesView model.AllTypes
var dest []AllTypesView var dest []AllTypesView
@ -132,7 +269,7 @@ WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11';
requireLogged(t, query) requireLogged(t, query)
} }
func TestBytea(t *testing.T) { func TestByteaInsert(t *testing.T) {
byteArrHex := "\\x48656c6c6f20476f7068657221" byteArrHex := "\\x48656c6c6f20476f7068657221"
byteArrBin := []byte("\x48\x65\x6c\x6c\x6f\x20\x47\x6f\x70\x68\x65\x72\x21") byteArrBin := []byte("\x48\x65\x6c\x6c\x6f\x20\x47\x6f\x70\x68\x65\x72\x21")
@ -147,8 +284,9 @@ RETURNING all_types.bytea AS "all_types.bytea",
all_types.bytea_ptr AS "all_types.bytea_ptr"; all_types.bytea_ptr AS "all_types.bytea_ptr";
`, byteArrHex, byteArrBin) `, byteArrHex, byteArrBin)
testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) {
var inserted model.AllTypes var inserted model.AllTypes
err := insertStmt.Query(db, &inserted) err := insertStmt.Query(tx, &inserted)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, string(*inserted.ByteaPtr), "Hello Gopher!") require.Equal(t, string(*inserted.ByteaPtr), "Hello Gopher!")
@ -175,12 +313,13 @@ WHERE all_types.bytea_ptr = $1::bytea;
var dest model.AllTypes var dest model.AllTypes
err = stmt.Query(db, &dest) err = stmt.Query(tx, &dest)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, string(*dest.ByteaPtr), "Hello Gopher!") require.Equal(t, string(*dest.ByteaPtr), "Hello Gopher!")
// Probably pq driver error. // Probably pq driver error.
// require.Equal(t, string(dest.Bytea), "Hello Gopher!") // require.Equal(t, string(dest.Bytea), "Hello Gopher!")
})
} }
func TestAllTypesFromSubQuery(t *testing.T) { func TestAllTypesFromSubQuery(t *testing.T) {
@ -424,6 +563,7 @@ func TestExpressionCast(t *testing.T) {
CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(), CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(),
CAST(String("1999-01-08 04:05:06+01:00")).AS_TIMESTAMPZ(), CAST(String("1999-01-08 04:05:06+01:00")).AS_TIMESTAMPZ(),
CAST(String("04:05:06")).AS_INTERVAL(), CAST(String("04:05:06")).AS_INTERVAL(),
CAST(String("some text")).AS_BYTEA().EQ(Bytea([]byte("some text"))),
func() ProjectionList { func() ProjectionList {
if sourceIsCockroachDB() { if sourceIsCockroachDB() {
@ -477,7 +617,6 @@ func TestStringOperators(t *testing.T) {
AllTypes.Text.BETWEEN(String("min"), String("max")), AllTypes.Text.BETWEEN(String("min"), String("max")),
AllTypes.Text.NOT_BETWEEN(AllTypes.VarChar, AllTypes.CharPtr), AllTypes.Text.NOT_BETWEEN(AllTypes.VarChar, AllTypes.CharPtr),
AllTypes.Text.CONCAT(String("text2")), AllTypes.Text.CONCAT(String("text2")),
AllTypes.Text.CONCAT(Int(11)),
AllTypes.Text.LIKE(String("abc")), AllTypes.Text.LIKE(String("abc")),
AllTypes.Text.NOT_LIKE(String("_b_")), AllTypes.Text.NOT_LIKE(String("_b_")),
AllTypes.Text.REGEXP_LIKE(String("^t")), AllTypes.Text.REGEXP_LIKE(String("^t")),
@ -508,18 +647,18 @@ func TestStringOperators(t *testing.T) {
CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, String("aaa"), Int(1)), CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, String("aaa"), Int(1)),
CONCAT(Bool(false), Int(1), Float(22.2), String("test test")), CONCAT(Bool(false), Int(1), Float(22.2), String("test test")),
CONCAT_WS(String("string1"), Int(1), Float(11.22), String("bytea"), Bool(false)), //Float(11.12)), CONCAT_WS(String("string1"), Int(1), Float(11.22), String("bytea"), Bool(false)), //Float(11.12)),
CONVERT(Bytea("bytea"), String("UTF8"), String("LATIN1")), CONVERT(Bytea("bytea"), UTF8, LATIN1),
CONVERT(AllTypes.Bytea, String("UTF8"), String("LATIN1")), CONVERT(AllTypes.Bytea, UTF8, LATIN1),
CONVERT_FROM(Bytea("text_in_utf8"), String("UTF8")), CONVERT_FROM(Bytea("text_in_utf8"), UTF8),
CONVERT_TO(String("text_in_utf8"), String("UTF8")), CONVERT_TO(String("text_in_utf8"), UTF8),
ENCODE(Bytea("123\000\001"), String("base64")), ENCODE(Bytea("some text"), Escape),
DECODE(String("MTIzAAE="), String("base64")), DECODE(String("MTIzAAE="), Base64),
FORMAT(String("Hello %s, %1$s"), String("World")), FORMAT(String("Hello %s, %1$s"), String("World")),
INITCAP(String("hi THOMAS")), INITCAP(String("hi THOMAS")),
LEFT(String("abcde"), Int(2)), LEFT(String("abcde"), Int(2)),
RIGHT(String("abcde"), Int(2)), RIGHT(String("abcde"), Int(2)),
LENGTH(Bytea("jose")), LENGTH(Bytea("jose")),
LENGTH(Bytea("jose"), String("UTF8")), LENGTH(Bytea("jose"), UTF8),
LPAD(String("Hi"), Int(5)), LPAD(String("Hi"), Int(5)),
LPAD(String("Hi"), Int(5), String("xy")), LPAD(String("Hi"), Int(5), String("xy")),
RPAD(String("Hi"), Int(5)), RPAD(String("Hi"), Int(5)),
@ -540,6 +679,202 @@ func TestStringOperators(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
func TestBytea(t *testing.T) {
var sampleBytea = Bytea([]byte{11, 0, 22, 33, 44})
var textBytea = Bytea([]byte("text blob"))
stmt := SELECT(
AllTypes.Bytea.EQ(sampleBytea),
AllTypes.Bytea.EQ(AllTypes.ByteaPtr),
AllTypes.Bytea.NOT_EQ(sampleBytea),
AllTypes.Bytea.GT(textBytea),
AllTypes.Bytea.GT_EQ(AllTypes.ByteaPtr),
AllTypes.Bytea.LT(AllTypes.ByteaPtr),
AllTypes.Bytea.LT_EQ(sampleBytea),
AllTypes.Bytea.BETWEEN(Bytea([]byte("min")), Bytea([]byte("max"))),
AllTypes.Bytea.NOT_BETWEEN(AllTypes.Bytea, AllTypes.ByteaPtr),
AllTypes.Bytea.CONCAT(textBytea),
func() ProjectionList {
if sourceIsCockroachDB() {
return ProjectionList{NULL}
}
// cockroach doesn't support currently
return ProjectionList{
AllTypes.Bytea.LIKE(Bytea("b'%pattern%'")),
AllTypes.Bytea.NOT_LIKE(Bytea("b'%pattern%'")),
BTRIM(AllTypes.Bytea, Bytea([]byte{33})),
RTRIM(AllTypes.ByteaPtr, sampleBytea),
LTRIM(sampleBytea, textBytea),
CONCAT(sampleBytea, AllTypes.ByteaPtr, textBytea),
BIT_COUNT(sampleBytea).EQ(Int(3)),
LENGTH(textBytea, UTF8).EQ(Int(4)),
CONVERT(textBytea, UTF8, WIN1252),
CONVERT(AllTypes.Bytea, UTF8, LATIN1).EQ(sampleBytea),
}
}(),
BIT_LENGTH(textBytea),
OCTET_LENGTH(textBytea),
GET_BIT(textBytea, Int(2)).EQ(Int(23)),
GET_BYTE(sampleBytea, Int(1)).EQ(Int(0)),
SET_BIT(textBytea, Int(1), Int(0)).EQ(sampleBytea),
SET_BYTE(textBytea, Int(1), Int(0)).EQ(textBytea),
LENGTH(sampleBytea),
SUBSTR(AllTypes.Bytea, Int(0), Int(2)),
MD5(AllTypes.Bytea),
SHA224(AllTypes.Bytea),
SHA256(AllTypes.Bytea),
SHA384(AllTypes.Bytea),
SHA512(AllTypes.Bytea),
ENCODE(sampleBytea, Base64),
DECODE(String("A234C12B"), Hex).EQ(sampleBytea),
CONVERT_FROM(AllTypes.ByteaPtr, UTF8).EQ(AllTypes.VarChar),
CONVERT_TO(AllTypes.Text, UTF8).NOT_EQ(textBytea),
RawBytea("DECODE(#1::text, #2)", RawArgs{
"#1": "A234C12B",
"#2": "hex",
}).EQ(sampleBytea),
).FROM(
AllTypes,
)
if !sourceIsCockroachDB() {
testutils.AssertStatementSql(t, stmt, `
SELECT all_types.bytea = $1::bytea,
all_types.bytea = all_types.bytea_ptr,
all_types.bytea != $2::bytea,
all_types.bytea > $3::bytea,
all_types.bytea >= all_types.bytea_ptr,
all_types.bytea < all_types.bytea_ptr,
all_types.bytea <= $4::bytea,
all_types.bytea BETWEEN $5::bytea AND $6::bytea,
all_types.bytea NOT BETWEEN all_types.bytea AND all_types.bytea_ptr,
all_types.bytea || $7::bytea,
all_types.bytea LIKE $8::bytea,
all_types.bytea NOT LIKE $9::bytea,
BTRIM(all_types.bytea, $10::bytea),
RTRIM(all_types.bytea_ptr, $11::bytea),
LTRIM($12::bytea, $13::bytea),
CONCAT($14::bytea, all_types.bytea_ptr, $15::bytea),
BIT_COUNT($16::bytea) = $17,
LENGTH($18::bytea, 'UTF8') = $19,
CONVERT($20::bytea, 'UTF8', 'WIN1252'),
CONVERT(all_types.bytea, 'UTF8', 'LATIN1') = $21::bytea,
BIT_LENGTH($22::bytea),
OCTET_LENGTH($23::bytea),
GET_BIT($24::bytea, $25) = $26,
GET_BYTE($27::bytea, $28) = $29,
SET_BIT($30::bytea, $31, $32) = $33::bytea,
SET_BYTE($34::bytea, $35, $36) = $37::bytea,
LENGTH($38::bytea),
SUBSTR(all_types.bytea, $39, $40),
MD5(all_types.bytea),
SHA224(all_types.bytea),
SHA256(all_types.bytea),
SHA384(all_types.bytea),
SHA512(all_types.bytea),
ENCODE($41::bytea, 'base64'),
DECODE($42::text, 'hex') = $43::bytea,
CONVERT_FROM(all_types.bytea_ptr, 'UTF8') = all_types.var_char,
CONVERT_TO(all_types.text, 'UTF8') != $44::bytea,
(DECODE($45::text, $46)) = $47::bytea
FROM test_sample.all_types;
`)
}
var dest []struct{}
err := stmt.Query(db, &dest)
require.NoError(t, err)
}
func TestBlobConversion(t *testing.T) {
nonPrintable := []byte{11, 22, 33, 44, 55}
printable := []byte("this is blob")
stmt := SELECT(
Bytea(nonPrintable).AS("test_dest.non_printable"),
Bytea(printable).AS("test_dest.printable"),
Bytea(nonPrintable).CONCAT(Bytea(printable)).AS("test_dest.bytea_concat"),
ENCODE(Bytea(nonPrintable), Base64).AS("test_dest.non_printable_base64"),
CONVERT_FROM(Bytea(printable), UTF8).AS("test_dest.printable_utf8"),
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT '\x0b16212c37'::bytea AS "test_dest.non_printable",
'\x7468697320697320626c6f62'::bytea AS "test_dest.printable",
('\x0b16212c37'::bytea || '\x7468697320697320626c6f62'::bytea) AS "test_dest.bytea_concat",
ENCODE('\x0b16212c37'::bytea, 'base64') AS "test_dest.non_printable_base64",
CONVERT_FROM('\x7468697320697320626c6f62'::bytea, 'UTF8') AS "test_dest.printable_utf8";
`)
type testDest struct {
NonPrintable []byte
Printable []byte
ByteaConcat []byte
NonPrintableBase64 string
PrintableUTF8 string
}
var dest testDest
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest.NonPrintable, nonPrintable)
require.Equal(t, dest.Printable, printable)
require.Equal(t, dest.ByteaConcat, append(nonPrintable, printable...))
require.Equal(t, dest.NonPrintableBase64, base64.StdEncoding.EncodeToString(nonPrintable))
require.Equal(t, dest.PrintableUTF8, string(printable))
t.Run("using select json", func(t *testing.T) {
stmtJson := SELECT_JSON_OBJ(
Bytea(nonPrintable).AS("nonPrintable"),
Bytea(printable).AS("printable"),
Bytea(nonPrintable).CONCAT(Bytea(printable)).AS("byteaConcat"),
ENCODE(Bytea(nonPrintable), Base64).AS("nonPrintableBase64"),
CONVERT_FROM(Bytea(printable), UTF8).AS("printableUtf8"),
)
testutils.AssertStatementSql(t, stmtJson, `
SELECT row_to_json(records) AS "json"
FROM (
SELECT ENCODE($1::bytea, 'base64') AS "nonPrintable",
ENCODE($2::bytea, 'base64') AS "printable",
ENCODE($3::bytea || $4::bytea, 'base64') AS "byteaConcat",
ENCODE($5::bytea, 'base64') AS "nonPrintableBase64",
CONVERT_FROM($6::bytea, 'UTF8') AS "printableUtf8"
) AS records;
`)
var destSelectJson testDest
err := stmtJson.QueryContext(ctx, db, &destSelectJson)
require.NoError(t, err)
testutils.PrintJson(destSelectJson)
require.Equal(t, dest, destSelectJson)
})
}
func TestBoolOperators(t *testing.T) { func TestBoolOperators(t *testing.T) {
query := AllTypes.SELECT( query := AllTypes.SELECT(
AllTypes.Boolean.EQ(AllTypes.BooleanPtr).AS("EQ1"), AllTypes.Boolean.EQ(AllTypes.BooleanPtr).AS("EQ1"),
@ -941,6 +1276,190 @@ func TestTimeExpression(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
func TestTimeScan(t *testing.T) {
loc, err := time.LoadLocation("Japan")
require.NoError(t, err)
timeT := time.Date(3, 3, 3, 11, 22, 33, 0, time.UTC)
timeWithNanoSeconds := time.Date(3, 3, 3, 1, 2, 3, 1000, time.UTC)
timez := time.Date(3, 3, 3, 7, 8, 9, 0, time.UTC)
timezWithNanoSeconds := time.Date(3, 3, 3, 4, 5, 6, 1000, loc)
// '1999-01-08 04:05:06'
timestamp := time.Date(1999, 01, 8, 4, 5, 6, 0, time.UTC)
timestampWithNanoSeconds := time.Date(3, 3, 3, 8, 9, 10, 1000, time.UTC)
timestampz := time.Date(2003, 10, 3, 9, 10, 11, 0, loc)
timestampzWithNanoSeconds := time.Date(3, 3, 3, 8, 9, 10, 1000, loc)
date := time.Date(2010, 2, 3, 0, 0, 0, 0, time.UTC)
stmt := SELECT(
TimeT(timeT).AS("time"),
TimeT(timeWithNanoSeconds).AS("timeWithNanoSeconds"),
TimezT(timez).AS("timez"),
TimezT(timezWithNanoSeconds).AS("timezWithNanoSeconds"),
Timestamp(1999, 01, 8, 4, 5, 6).AS("timestamp"),
TimestampT(timestampWithNanoSeconds).AS("timestampWithNanoSeconds"),
TimestampzT(timestampz).AS("timestampz"),
TimestampzT(timestampzWithNanoSeconds).AS("timestampzWithNanoSeconds"),
DateT(date).AS("date"),
TimeT(timeT).ADD(INTERVAL(2, HOUR)).AS("timeExpression"),
SELECT_JSON_OBJ(
TimeT(timeT).AS("time"),
TimeT(timeWithNanoSeconds).AS("timeWithNanoSeconds"),
TimezT(timez).AS("timez"),
TimezT(timezWithNanoSeconds).AS("timezWithNanoSeconds"),
TimestampT(timestamp).AS("timestamp"),
TimestampT(timestampWithNanoSeconds).AS("timestampWithNanoSeconds"),
TimestampzT(timestampz).AS("timestampz"),
TimestampzT(timestampzWithNanoSeconds).AS("timestampzWithNanoSeconds"),
DateT(date).AS("date"),
TimeT(timeT).ADD(INTERVAL(2, HOUR)).AS("timeExpression"),
).AS("json"),
)
testutils.AssertStatementSql(t, stmt, `
SELECT $1::time without time zone AS "time",
$2::time without time zone AS "timeWithNanoSeconds",
$3::time with time zone AS "timez",
$4::time with time zone AS "timezWithNanoSeconds",
$5::timestamp without time zone AS "timestamp",
$6::timestamp without time zone AS "timestampWithNanoSeconds",
$7::timestamp with time zone AS "timestampz",
$8::timestamp with time zone AS "timestampzWithNanoSeconds",
$9::date AS "date",
($10::time without time zone + INTERVAL '2 HOUR') AS "timeExpression",
(
SELECT row_to_json(json_records) AS "json_json"
FROM (
SELECT '0000-01-01T' || to_char('2000-10-10'::date + $11::time without time zone, 'HH24:MI:SS.USZ') AS "time",
'0000-01-01T' || to_char('2000-10-10'::date + $12::time without time zone, 'HH24:MI:SS.USZ') AS "timeWithNanoSeconds",
'0000-01-01T' || to_char('2000-10-10'::date + $13::time with time zone, 'HH24:MI:SS.USTZH:TZM') AS "timez",
'0000-01-01T' || to_char('2000-10-10'::date + $14::time with time zone, 'HH24:MI:SS.USTZH:TZM') AS "timezWithNanoSeconds",
to_char($15::timestamp without time zone, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestamp",
to_char($16::timestamp without time zone, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestampWithNanoSeconds",
$17::timestamp with time zone AS "timestampz",
$18::timestamp with time zone AS "timestampzWithNanoSeconds",
to_char($19::date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "date",
'0000-01-01T' || to_char('2000-10-10'::date + ($20::time without time zone + INTERVAL '2 HOUR'), 'HH24:MI:SS.USZ') AS "timeExpression"
) AS json_records
) AS "json";
`)
var dest struct {
Time time.Time
TimeWithNanoSeconds time.Time
Timez time.Time
TimezWithNanoSeconds time.Time
Timestamp time.Time
TimestampWithNanoSeconds time.Time
Timestampz time.Time
TimestampzWithNanoSeconds time.Time
Date time.Time
TimeExpression time.Time
Json struct {
Time time.Time
TimeWithNanoSeconds time.Time
Timez time.Time
TimezWithNanoSeconds time.Time
Timestamp time.Time
TimestampWithNanoSeconds time.Time
Timestampz time.Time
TimestampzWithNanoSeconds time.Time
Date time.Time
TimeExpression time.Time
} `json_column:"json"`
}
err = stmt.Query(db, &dest)
require.NoError(t, err)
ensureTimezEqual(t, timeT.Add(2*time.Hour), dest.TimeExpression, loc)
ensureTimezEqual(t, timeT.Add(2*time.Hour), dest.Json.TimeExpression, loc)
ensureTimezEqual(t, timeT, dest.Time, loc)
ensureTimezEqual(t, timeT, dest.Json.Time, loc)
ensureTimezEqual(t, timeWithNanoSeconds, dest.TimeWithNanoSeconds, loc)
ensureTimezEqual(t, timeWithNanoSeconds, dest.Json.TimeWithNanoSeconds, loc)
ensureTimezEqual(t, timez, dest.Timez, loc)
ensureTimezEqual(t, timez, dest.Json.Timez, loc)
ensureTimezEqual(t, timezWithNanoSeconds, dest.TimezWithNanoSeconds, loc)
ensureTimezEqual(t, timezWithNanoSeconds, dest.Json.TimezWithNanoSeconds, loc)
ensureTimezEqual(t, timestamp, dest.Timestamp, loc)
ensureTimezEqual(t, timestamp, dest.Json.Timestamp, loc)
ensureTimezEqual(t, timestampWithNanoSeconds, dest.TimestampWithNanoSeconds, loc)
ensureTimezEqual(t, timestampWithNanoSeconds, dest.Json.TimestampWithNanoSeconds, loc)
ensureTimezEqual(t, timestampz, dest.Timestampz, loc)
ensureTimezEqual(t, timestampz, dest.Json.Timestampz, loc)
ensureTimezEqual(t, timestampzWithNanoSeconds, dest.TimestampzWithNanoSeconds, loc)
ensureTimezEqual(t, timestampzWithNanoSeconds, dest.Json.TimestampzWithNanoSeconds, loc)
ensureTimezEqual(t, date, dest.Date, loc)
ensureTimezEqual(t, date, dest.Json.Date, loc)
t.Run("json only", func(t *testing.T) {
stmtJson := SELECT_JSON_OBJ(
TimeT(timeT).AS("time"),
TimeT(timeWithNanoSeconds).AS("timeWithNanoSeconds"),
TimezT(timez).AS("timez"),
TimezT(timezWithNanoSeconds).AS("timezWithNanoSeconds"),
Timestamp(1999, 01, 8, 4, 5, 6).AS("timestamp"),
TimestampT(timestampWithNanoSeconds).AS("timestampWithNanoSeconds"),
TimestampzT(timestampz).AS("timestampz"),
TimestampzT(timestampzWithNanoSeconds).AS("timestampzWithNanoSeconds"),
DateT(date).AS("date"),
)
var jsonDest struct {
Time time.Time
TimeWithNanoSeconds time.Time
Timez time.Time
TimezWithNanoSeconds time.Time
Timestamp time.Time
TimestampWithNanoSeconds time.Time
Timestampz time.Time
TimestampzWithNanoSeconds time.Time
Date time.Time
}
err := stmtJson.QueryContext(ctx, db, &jsonDest)
require.NoError(t, err)
})
}
func ensureTimezEqual(t *testing.T, time1, time2 time.Time, loc *time.Location) {
time1Loc := time1.In(loc)
time2Loc := time2.In(loc)
require.Equal(t, time1Loc.Hour(), time2Loc.Hour())
require.Equal(t, time1Loc.Minute(), time2Loc.Minute())
require.Equal(t, time1Loc.Second(), time2Loc.Second())
require.Equal(t, toMicroSeconds(time1Loc.Nanosecond()), toMicroSeconds(time2Loc.Nanosecond()))
}
func toMicroSeconds(nanoseconds int) int {
return nanoseconds / 1000
}
func TestIntervalSetFunctionality(t *testing.T) { func TestIntervalSetFunctionality(t *testing.T) {
t.Run("updateQueryIntervalTest", func(t *testing.T) { t.Run("updateQueryIntervalTest", func(t *testing.T) {
@ -1052,7 +1571,50 @@ func TestInterval(t *testing.T) {
AllTypes.IntervalPtr.DIV(Float(22.222)).EQ(AllTypes.IntervalPtr), AllTypes.IntervalPtr.DIV(Float(22.222)).EQ(AllTypes.IntervalPtr),
).FROM(AllTypes) ).FROM(AllTypes)
//fmt.Println(stmt.DebugSql()) fmt.Println(stmt.Sql())
testutils.AssertDebugStatementSql(t, stmt, `
SELECT INTERVAL '1 YEAR',
INTERVAL '1 MONTH',
INTERVAL '1 WEEK',
INTERVAL '1 DAY',
INTERVAL '1 HOUR',
INTERVAL '1 MINUTE',
INTERVAL '1 SECOND',
INTERVAL '1 MILLISECOND',
INTERVAL '1 MICROSECOND',
INTERVAL '1 DECADE',
INTERVAL '1 CENTURY',
INTERVAL '1 MILLENNIUM',
INTERVAL '1 YEAR 10 MONTH',
INTERVAL '1 YEAR 10 MONTH 20 DAY',
INTERVAL '1 YEAR 10 MONTH 20 DAY 3 HOUR',
INTERVAL '1 YEAR' IS NOT NULL,
INTERVAL '1 YEAR' AS "one year",
INTERVAL '0 MICROSECOND',
INTERVAL '1 MICROSECOND',
INTERVAL '1000 MICROSECOND',
INTERVAL '1 SECOND',
INTERVAL '1 MINUTE',
INTERVAL '1 HOUR',
INTERVAL '1 DAY',
INTERVAL '1 DAY 2 HOUR 3 MINUTE 4 SECOND 5 MICROSECOND',
(all_types.interval = INTERVAL '2 HOUR 20 MINUTE') = TRUE::boolean,
(all_types.interval_ptr != INTERVAL '2 HOUR 20 MINUTE') = FALSE::boolean,
(all_types.interval IS DISTINCT FROM INTERVAL '2 HOUR 20 MINUTE') = all_types.boolean,
(all_types.interval_ptr IS NOT DISTINCT FROM INTERVAL '10 MICROSECOND') = all_types.boolean,
(all_types.interval < all_types.interval_ptr) = all_types.boolean_ptr,
(all_types.interval <= all_types.interval_ptr) = all_types.boolean_ptr,
(all_types.interval > all_types.interval_ptr) = all_types.boolean_ptr,
(all_types.interval >= all_types.interval_ptr) = all_types.boolean_ptr,
all_types.interval BETWEEN INTERVAL '1 HOUR' AND INTERVAL '2 HOUR',
all_types.interval NOT BETWEEN all_types.interval_ptr AND INTERVAL '30 SECOND',
(all_types.interval + all_types.interval_ptr) = INTERVAL '17 SECOND',
(all_types.interval - all_types.interval_ptr) = INTERVAL '100 MICROSECOND',
(all_types.interval_ptr * 11) = all_types.interval,
(all_types.interval_ptr / 22.222) = all_types.interval_ptr
FROM test_sample.all_types;
`)
err := stmt.Query(db, &struct{}{}) err := stmt.Query(db, &struct{}{})
require.NoError(t, err) require.NoError(t, err)
@ -1159,6 +1721,187 @@ SELECT ROW($1::integer, $2::real, $3::text) AS "row",
require.NoError(t, err) require.NoError(t, err)
} }
func TestAllTypesSubQueryFrom(t *testing.T) {
subQuery := SELECT(
AllTypes.Boolean,
AllTypes.Integer,
AllTypes.DoublePrecision,
AllTypes.Text,
AllTypes.Date,
AllTypes.Time,
AllTypes.Timez,
AllTypes.Timestamp,
AllTypes.Timestampz,
AllTypes.Interval,
AllTypes.Bytea,
).FROM(
AllTypes,
).AsTable("subQuery")
stmt := SELECT(
AllTypes.Boolean.From(subQuery),
AllTypes.Integer.From(subQuery),
AllTypes.DoublePrecision.From(subQuery),
AllTypes.Text.From(subQuery),
AllTypes.Date.From(subQuery),
AllTypes.Time.From(subQuery),
AllTypes.Timez.From(subQuery),
AllTypes.Timestamp.From(subQuery),
AllTypes.Timestampz.From(subQuery),
AllTypes.Interval.From(subQuery),
AllTypes.Bytea.From(subQuery),
).FROM(
subQuery,
)
testutils.AssertStatementSql(t, stmt, `
SELECT "subQuery"."all_types.boolean" AS "all_types.boolean",
"subQuery"."all_types.integer" AS "all_types.integer",
"subQuery"."all_types.double_precision" AS "all_types.double_precision",
"subQuery"."all_types.text" AS "all_types.text",
"subQuery"."all_types.date" AS "all_types.date",
"subQuery"."all_types.time" AS "all_types.time",
"subQuery"."all_types.timez" AS "all_types.timez",
"subQuery"."all_types.timestamp" AS "all_types.timestamp",
"subQuery"."all_types.timestampz" AS "all_types.timestampz",
"subQuery"."all_types.interval" AS "all_types.interval",
"subQuery"."all_types.bytea" AS "all_types.bytea"
FROM (
SELECT all_types.boolean AS "all_types.boolean",
all_types.integer AS "all_types.integer",
all_types.double_precision AS "all_types.double_precision",
all_types.text AS "all_types.text",
all_types.date AS "all_types.date",
all_types.time AS "all_types.time",
all_types.timez AS "all_types.timez",
all_types.timestamp AS "all_types.timestamp",
all_types.timestampz AS "all_types.timestampz",
all_types.interval AS "all_types.interval",
all_types.bytea AS "all_types.bytea"
FROM test_sample.all_types
) AS "subQuery";
`)
var dest []model.AllTypes
err := stmt.Query(db, &dest)
require.NoError(t, err)
t.Run("using SELECT_JSON", func(t *testing.T) {
stmtJson := SELECT_JSON_ARR(
AllTypes.Boolean.From(subQuery),
AllTypes.Integer.From(subQuery),
AllTypes.DoublePrecision.From(subQuery),
AllTypes.Text.From(subQuery),
AllTypes.Date.From(subQuery),
AllTypes.Time.From(subQuery),
AllTypes.Timez.From(subQuery),
AllTypes.Timestamp.From(subQuery),
AllTypes.Timestampz.From(subQuery),
AllTypes.Interval.From(subQuery),
AllTypes.Bytea.From(subQuery),
).FROM(
subQuery,
)
testutils.AssertDebugStatementSql(t, stmtJson, `
SELECT json_agg(row_to_json(records)) AS "json"
FROM (
SELECT "subQuery"."all_types.boolean" AS "boolean",
"subQuery"."all_types.integer" AS "integer",
"subQuery"."all_types.double_precision" AS "doublePrecision",
"subQuery"."all_types.text" AS "text",
to_char("subQuery"."all_types.date"::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "date",
'0000-01-01T' || to_char('2000-10-10'::date + "subQuery"."all_types.time", 'HH24:MI:SS.USZ') AS "time",
'0000-01-01T' || to_char('2000-10-10'::date + "subQuery"."all_types.timez", 'HH24:MI:SS.USTZH:TZM') AS "timez",
to_char("subQuery"."all_types.timestamp", 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestamp",
"subQuery"."all_types.timestampz" AS "timestampz",
"subQuery"."all_types.interval" AS "interval",
ENCODE("subQuery"."all_types.bytea", 'base64') AS "bytea"
FROM (
SELECT all_types.boolean AS "all_types.boolean",
all_types.integer AS "all_types.integer",
all_types.double_precision AS "all_types.double_precision",
all_types.text AS "all_types.text",
all_types.date AS "all_types.date",
all_types.time AS "all_types.time",
all_types.timez AS "all_types.timez",
all_types.timestamp AS "all_types.timestamp",
all_types.timestampz AS "all_types.timestampz",
all_types.interval AS "all_types.interval",
all_types.bytea AS "all_types.bytea"
FROM test_sample.all_types
) AS "subQuery"
) AS records;
`)
var destJson []model.AllTypes
err := stmtJson.QueryContext(ctx, db, &destJson)
require.NoError(t, err)
t.Run("using AllColumns()", func(t *testing.T) {
stmtJsonAllColumns := SELECT_JSON_ARR(
subQuery.AllColumns(),
).FROM(
subQuery,
)
require.Equal(t, stmtJson.DebugSql(), stmtJsonAllColumns.DebugSql())
})
// fix timezone before comparisons
minus8 := time.FixedZone("UTC", -8*60*60)
destJson[0].Timez = *toTZ(&destJson[0].Timez, minus8)
destJson[1].Timez = *toTZ(&destJson[1].Timez, minus8)
destJson[0].Timestampz = *toTZ(&destJson[0].Timestampz, time.UTC)
destJson[1].Timestampz = *toTZ(&destJson[1].Timestampz, time.UTC)
dest[0].Timestampz = *toTZ(&dest[0].Timestampz, time.UTC)
dest[1].Timestampz = *toTZ(&dest[1].Timestampz, time.UTC)
testutils.AssertJsonEqual(t, dest, destJson)
})
}
func TestAllTypesUpdateSet(t *testing.T) {
stmt := AllTypes.UPDATE().
SET(
AllTypes.Boolean.SET(Bool(false)),
AllTypes.Integer.SET(Int(2)),
AllTypes.DoublePrecision.SET(Float(2.22)),
AllTypes.Text.SET(Text("some text")),
AllTypes.Date.SET(DateT(time.Now())),
AllTypes.Time.SET(TimeT(time.Now())),
AllTypes.Timez.SET(TimezT(time.Now())),
AllTypes.Timestamp.SET(TimestampT(time.Now())),
AllTypes.Interval.SET(INTERVAL(1, HOUR)),
AllTypes.Bytea.SET(Bytea([]byte{11, 22, 33, 44})),
).WHERE(Bool(true))
testutils.AssertStatementSql(t, stmt, `
UPDATE test_sample.all_types
SET boolean = $1::boolean,
integer = $2,
double_precision = $3,
text = $4::text,
date = $5::date,
time = $6::time without time zone,
timez = $7::time with time zone,
timestamp = $8::timestamp without time zone,
interval = INTERVAL '1 HOUR',
bytea = $9::bytea
WHERE $10::boolean;
`)
testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) {
_, err := stmt.Exec(tx)
require.NoError(t, err)
})
}
func TestSubQueryColumnReference(t *testing.T) { func TestSubQueryColumnReference(t *testing.T) {
type expected struct { type expected struct {
sql string sql string

View file

@ -188,7 +188,129 @@ ORDER BY "Artist"."ArtistId", "Album"."AlbumId", "Track"."TrackId";
`) `)
} }
type AllArtistDetails []struct { //list of all artist
model.Artist
Albums []struct { // list of albums per artist
model.Album
Tracks []struct { // list of tracks per album
model.Track
Genre model.Genre // track genre
MediaType model.MediaType // track media type
Playlists []model.Playlist // list of playlist where track is used
Invoices []struct { // list of invoices where track occurs
model.Invoice
Customer struct { // customer data for invoice
model.Customer
Employee *struct { // employee data for customer if exists
model.Employee
Manager *model.Employee `alias:"Manager"`
}
}
}
}
}
}
func BenchmarkJoinEverythingJSON(b *testing.B) {
for i := 0; i < b.N; i++ {
testJoinEverythingJSON(b)
}
}
func TestJoinEverythingJSON(t *testing.T) {
testJoinEverythingJSON(t)
}
func testJoinEverythingJSON(t require.TestingT) {
manager := Employee.AS("Manager")
stmt := SELECT_JSON_ARR(
Artist.AllColumns,
SELECT_JSON_ARR(
Album.AllColumns,
SELECT_JSON_ARR(
Track.AllColumns,
SELECT_JSON_OBJ(Genre.AllColumns).
FROM(Genre).
WHERE(Genre.GenreId.EQ(Track.GenreId)).AS("Genre"),
SELECT_JSON_OBJ(MediaType.AllColumns).
FROM(MediaType).
WHERE(MediaType.MediaTypeId.EQ(Track.MediaTypeId)).AS("MediaType"),
SELECT_JSON_ARR(Playlist.AllColumns).
FROM(Playlist.
INNER_JOIN(PlaylistTrack, Playlist.PlaylistId.EQ(PlaylistTrack.PlaylistId))).
WHERE(PlaylistTrack.TrackId.EQ(Track.TrackId)).
ORDER_BY(Playlist.PlaylistId).AS("Playlists"),
SELECT_JSON_ARR(
Invoice.AllColumns,
SELECT_JSON_OBJ(
Customer.AllColumns,
SELECT_JSON_OBJ(
Employee.AllColumns,
SELECT_JSON_OBJ(manager.AllColumns).
FROM(manager).
WHERE(manager.EmployeeId.EQ(Employee.ReportsTo)).AS("Manager"),
).FROM(Employee).
WHERE(Employee.EmployeeId.EQ(Customer.SupportRepId)).AS("Employee"),
).FROM(Customer).
WHERE(Customer.CustomerId.EQ(Invoice.CustomerId)).AS("Customer"),
).FROM(
Invoice.
INNER_JOIN(InvoiceLine, InvoiceLine.InvoiceId.EQ(Invoice.InvoiceId)),
).WHERE(InvoiceLine.TrackId.EQ(Track.TrackId)).
ORDER_BY(Invoice.InvoiceId).AS("Invoices"),
).FROM(Track).
WHERE(Track.AlbumId.EQ(Album.AlbumId)).
ORDER_BY(Track.TrackId).AS("Tracks"),
).FROM(Album).
WHERE(Album.ArtistId.EQ(Artist.ArtistId)).
ORDER_BY(Album.AlbumId).AS("Albums"),
).FROM(Artist).
ORDER_BY(Artist.ArtistId)
//fmt.Println(stmt.DebugSql())
var dest AllArtistDetails
err := stmt.QueryContext(ctx, db, &dest)
require.NoError(t, err)
require.Equal(t, len(dest), 275)
//testutils.SaveJSONFile(dest, "./testdata/results/postgres/joined_everything2.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/joined_everything.json")
requireLogged(t, stmt)
requireQueryLogged(t, stmt, 1)
}
func BenchmarkJoinEverything(b *testing.B) {
for i := 0; i < b.N; i++ {
testJoinEverything(b)
}
}
func TestJoinEverything(t *testing.T) { func TestJoinEverything(t *testing.T) {
testJoinEverything(t)
}
func testJoinEverything(t require.TestingT) {
manager := Employee.AS("Manager") manager := Employee.AS("Manager")
@ -223,37 +345,6 @@ func TestJoinEverything(t *testing.T) {
Invoice.InvoiceId, Customer.CustomerId, Invoice.InvoiceId, Customer.CustomerId,
) )
var dest []struct { //list of all artist
model.Artist
Albums []struct { // list of albums per artist
model.Album
Tracks []struct { // list of tracks per album
model.Track
Genre model.Genre // track genre
MediaType model.MediaType // track media type
Playlists []model.Playlist // list of playlist where track is used
Invoices []struct { // list of invoices where track occurs
model.Invoice
Customer struct { // customer data for invoice
model.Customer
Employee *struct { // employee data for customer if exists
model.Employee
Manager *model.Employee `alias:"Manager"`
}
}
}
}
}
}
testutils.AssertStatementSql(t, stmt, ` testutils.AssertStatementSql(t, stmt, `
SELECT "Artist"."ArtistId" AS "Artist.ArtistId", SELECT "Artist"."ArtistId" AS "Artist.ArtistId",
"Artist"."Name" AS "Artist.Name", "Artist"."Name" AS "Artist.Name",
@ -344,7 +435,7 @@ FROM chinook."Artist"
LEFT JOIN chinook."Employee" AS "Manager" ON ("Manager"."EmployeeId" = "Employee"."ReportsTo") LEFT JOIN chinook."Employee" AS "Manager" ON ("Manager"."EmployeeId" = "Employee"."ReportsTo")
ORDER BY "Artist"."ArtistId", "Album"."AlbumId", "Track"."TrackId", "Genre"."GenreId", "MediaType"."MediaTypeId", "Playlist"."PlaylistId", "Invoice"."InvoiceId", "Customer"."CustomerId"; ORDER BY "Artist"."ArtistId", "Album"."AlbumId", "Track"."TrackId", "Genre"."GenreId", "MediaType"."MediaTypeId", "Playlist"."PlaylistId", "Invoice"."InvoiceId", "Customer"."CustomerId";
`) `)
var dest AllArtistDetails
err := stmt.QueryContext(context.Background(), db, &dest) err := stmt.QueryContext(context.Background(), db, &dest)
require.NoError(t, err) require.NoError(t, err)

View file

@ -974,8 +974,8 @@ type allTypesTable struct {
Char postgres.ColumnString Char postgres.ColumnString
TextPtr postgres.ColumnString TextPtr postgres.ColumnString
Text postgres.ColumnString Text postgres.ColumnString
ByteaPtr postgres.ColumnString ByteaPtr postgres.ColumnBytea
Bytea postgres.ColumnString Bytea postgres.ColumnBytea
TimestampzPtr postgres.ColumnTimestampz TimestampzPtr postgres.ColumnTimestampz
Timestampz postgres.ColumnTimestampz Timestampz postgres.ColumnTimestampz
TimestampPtr postgres.ColumnTimestamp TimestampPtr postgres.ColumnTimestamp
@ -1078,8 +1078,8 @@ func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable {
CharColumn = postgres.StringColumn("char") CharColumn = postgres.StringColumn("char")
TextPtrColumn = postgres.StringColumn("text_ptr") TextPtrColumn = postgres.StringColumn("text_ptr")
TextColumn = postgres.StringColumn("text") TextColumn = postgres.StringColumn("text")
ByteaPtrColumn = postgres.StringColumn("bytea_ptr") ByteaPtrColumn = postgres.ByteaColumn("bytea_ptr")
ByteaColumn = postgres.StringColumn("bytea") ByteaColumn = postgres.ByteaColumn("bytea")
TimestampzPtrColumn = postgres.TimestampzColumn("timestampz_ptr") TimestampzPtrColumn = postgres.TimestampzColumn("timestampz_ptr")
TimestampzColumn = postgres.TimestampzColumn("timestampz") TimestampzColumn = postgres.TimestampzColumn("timestampz")
TimestampPtrColumn = postgres.TimestampColumn("timestamp_ptr") TimestampPtrColumn = postgres.TimestampColumn("timestamp_ptr")

View file

@ -20,6 +20,8 @@ import (
_ "github.com/jackc/pgx/v4/stdlib" _ "github.com/jackc/pgx/v4/stdlib"
) )
var ctx = context.Background()
var db *stmtcache.DB var db *stmtcache.DB
var testRoot string var testRoot string
@ -31,6 +33,7 @@ const CockroachDB = "COCKROACH_DB"
func init() { func init() {
source = os.Getenv("PG_SOURCE") source = os.Getenv("PG_SOURCE")
withStatementCaching = os.Getenv("JET_TESTS_WITH_STMT_CACHE") == "true" withStatementCaching = os.Getenv("JET_TESTS_WITH_STMT_CACHE") == "true"
testRoot = repo.GetTestsDirPath()
} }
func sourceIsCockroachDB() bool { func sourceIsCockroachDB() bool {
@ -46,8 +49,6 @@ func skipForCockroachDB(t *testing.T) {
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
defer profile.Start().Stop() defer profile.Start().Stop()
setTestRoot()
for _, driverName := range []string{"postgres", "pgx"} { for _, driverName := range []string{"postgres", "pgx"} {
fmt.Printf("\nRunning postgres tests for driver: %s, caching enabled: %t \n", driverName, withStatementCaching) fmt.Printf("\nRunning postgres tests for driver: %s, caching enabled: %t \n", driverName, withStatementCaching)
@ -94,10 +95,6 @@ func getConnectionString() string {
return dbconfig.PostgresConnectString return dbconfig.PostgresConnectString
} }
func setTestRoot() {
testRoot = repo.GetTestsDirPath()
}
var loggedSQL string var loggedSQL string
var loggedSQLArgs []interface{} var loggedSQLArgs []interface{}
var loggedDebugSQL string var loggedDebugSQL string
@ -119,14 +116,22 @@ func init() {
}) })
} }
func requireLogged(t *testing.T, statement postgres.Statement) { func requireLogged(t require.TestingT, statement postgres.Statement) {
if _, ok := t.(*testing.B); ok {
return // skip assert for benchmarks
}
query, args := statement.Sql() query, args := statement.Sql()
require.Equal(t, loggedSQL, query) require.Equal(t, loggedSQL, query)
require.Equal(t, loggedSQLArgs, args) require.Equal(t, loggedSQLArgs, args)
require.Equal(t, loggedDebugSQL, statement.DebugSql()) require.Equal(t, loggedDebugSQL, statement.DebugSql())
} }
func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) { func requireQueryLogged(t require.TestingT, statement postgres.Statement, rowsProcessed int64) {
if _, ok := t.(*testing.B); ok {
return // skip assert for benchmarks
}
query, args := statement.Sql() query, args := statement.Sql()
queryLogged, argsLogged := queryInfo.Statement.Sql() queryLogged, argsLogged := queryInfo.Statement.Sql()

View file

@ -9,7 +9,50 @@ import (
"testing" "testing"
) )
func TestNorthwindJoinEverything(t *testing.T) { type Dest []struct {
model.Customers
Demographics model.CustomerDemographics
Orders []struct {
model.Orders
Shipper model.Shippers
Employee struct {
model.Employees
Territories []struct {
model.Territories
Region model.Region
}
}
Details []struct {
model.OrderDetails
Products struct {
model.Products
Category model.Categories
Supplier model.Suppliers
}
}
}
}
func BenchmarkTestNorthwindJoinEverything(b *testing.B) {
for i := 0; i < b.N; i++ {
testNorthwindJoinEverything(b)
}
}
func TestTestNorthwindJoinEverything(t *testing.T) {
testNorthwindJoinEverything(t)
}
func testNorthwindJoinEverything(t require.TestingT) {
stmt := stmt :=
SELECT( SELECT(
@ -21,6 +64,9 @@ func TestNorthwindJoinEverything(t *testing.T) {
Products.AllColumns, Products.AllColumns,
Categories.AllColumns, Categories.AllColumns,
Suppliers.AllColumns, Suppliers.AllColumns,
Employees.AllColumns,
Territories.AllColumns,
Region.AllColumns,
).FROM( ).FROM(
Customers. Customers.
LEFT_JOIN(CustomerCustomerDemo, Customers.CustomerID.EQ(CustomerCustomerDemo.CustomerID)). LEFT_JOIN(CustomerCustomerDemo, Customers.CustomerID.EQ(CustomerCustomerDemo.CustomerID)).
@ -35,35 +81,110 @@ func TestNorthwindJoinEverything(t *testing.T) {
LEFT_JOIN(EmployeeTerritories, EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID)). LEFT_JOIN(EmployeeTerritories, EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID)).
LEFT_JOIN(Territories, EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)). LEFT_JOIN(Territories, EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)).
LEFT_JOIN(Region, Territories.RegionID.EQ(Region.RegionID)), LEFT_JOIN(Region, Territories.RegionID.EQ(Region.RegionID)),
).ORDER_BY(Customers.CustomerID, Orders.OrderID, Products.ProductID) ).ORDER_BY(
Customers.CustomerID,
Orders.OrderID,
Products.ProductID,
Territories.TerritoryID,
)
var dest []struct { //fmt.Println(stmt.DebugSql())
model.Customers
Demographics model.CustomerDemographics var dest Dest
Orders []struct {
model.Orders
Shipper model.Shippers
Details struct {
model.OrderDetails
Products []struct {
model.Products
Category model.Categories
Supplier model.Suppliers
}
}
}
}
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
//jsonSave("./testdata/northwind-all.json", dest) //testutils.SaveJSONFile(dest, "./testdata/results/postgres/northwind-all.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/northwind-all.json") testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/northwind-all.json")
requireLogged(t, stmt) requireLogged(t, stmt)
} }
func BenchmarkTestNorthwindJoinEverythingJson(b *testing.B) {
for i := 0; i < b.N; i++ {
testNorthwindJoinEverythingJson(b)
}
}
func TestNorthwindJoinEverythingJson(t *testing.T) {
testNorthwindJoinEverythingJson(t)
}
func testNorthwindJoinEverythingJson(t require.TestingT) {
stmt := SELECT_JSON_ARR(
Customers.AllColumns,
SELECT_JSON_OBJ(CustomerDemographics.AllColumns).
FROM(CustomerDemographics.INNER_JOIN(CustomerCustomerDemo, CustomerCustomerDemo.CustomerTypeID.EQ(CustomerDemographics.CustomerTypeID))).
WHERE(CustomerCustomerDemo.CustomerID.EQ(Customers.CustomerID)).AS("Demographics"),
SELECT_JSON_ARR(
Orders.AllColumns,
SELECT_JSON_OBJ(Shippers.AllColumns).
FROM(Shippers).
WHERE(Shippers.ShipperID.EQ(Orders.ShipVia)).AS("Shipper"),
SELECT_JSON_OBJ(
Employees.AllColumns,
SELECT_JSON_ARR(
Territories.AllColumns,
SELECT_JSON_OBJ(Region.AllColumns).
FROM(Region).
WHERE(Region.RegionID.EQ(Territories.RegionID)).AS("Region"),
).FROM(
EmployeeTerritories.LEFT_JOIN(
Territories,
EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)),
).WHERE(
EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID),
).AS("Territories"),
).FROM(Employees).
WHERE(Orders.EmployeeID.EQ(Employees.EmployeeID)).AS("Employee"),
SELECT_JSON_ARR(
OrderDetails.AllColumns,
SELECT_JSON_OBJ(
Products.AllColumns,
SELECT_JSON_OBJ(
Categories.AllColumns,
).FROM(Categories).
WHERE(Categories.CategoryID.EQ(Products.CategoryID)).AS("Category"),
SELECT_JSON_OBJ(Suppliers.AllColumns).
FROM(Suppliers).
WHERE(Suppliers.SupplierID.EQ(Products.SupplierID)).AS("Supplier"),
).FROM(Products).
WHERE(Products.ProductID.EQ(OrderDetails.ProductID)).AS("Products"),
).FROM(
OrderDetails,
).WHERE(
OrderDetails.OrderID.EQ(Orders.OrderID),
).AS("Details"),
).FROM(
Orders,
).WHERE(
Orders.CustomerID.EQ(Customers.CustomerID),
).ORDER_BY(
Orders.OrderID,
).AS("Orders"),
).FROM(
Customers,
).ORDER_BY(
Customers.CustomerID,
)
//fmt.Println(stmt.DebugSql())
var dest Dest
err := stmt.QueryContext(ctx, db, &dest)
require.NoError(t, err)
//testutils.SaveJSONFile(dest, "./testdata/results/postgres/northwind-all2.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/northwind-all.json")
}

View file

@ -220,20 +220,7 @@ func TestUUIDComplex(t *testing.T) {
requireLogged(t, query) requireLogged(t, query)
}) })
t.Run("slice of structs left join", func(t *testing.T) { var expectedSliceOfStructsLeftJoin = `
leftQuery := Person.LEFT_JOIN(PersonPhone, PersonPhone.PersonID.EQ(Person.PersonID)).
SELECT(Person.AllColumns, PersonPhone.AllColumns).
ORDER_BY(Person.PersonID.ASC(), PersonPhone.PhoneID.ASC())
var dest []struct {
model.Person
Phones []struct {
model.PersonPhone
}
}
err := leftQuery.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
[ [
{ {
"PersonID": "b68dbff4-a87d-11e9-a7f2-98ded00c39c6", "PersonID": "b68dbff4-a87d-11e9-a7f2-98ded00c39c6",
@ -274,10 +261,50 @@ func TestUUIDComplex(t *testing.T) {
] ]
} }
] ]
`) `
t.Run("slice of structs left join", func(t *testing.T) {
leftQuery := Person.LEFT_JOIN(PersonPhone, PersonPhone.PersonID.EQ(Person.PersonID)).
SELECT(Person.AllColumns, PersonPhone.AllColumns).
ORDER_BY(Person.PersonID.ASC(), PersonPhone.PhoneID.ASC())
var dest []struct {
model.Person
Phones []struct {
model.PersonPhone
}
}
err := leftQuery.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, expectedSliceOfStructsLeftJoin)
requireLogged(t, leftQuery) requireLogged(t, leftQuery)
}) })
t.Run("select json", func(t *testing.T) {
jsonQuery := SELECT_JSON_ARR(
Person.AllColumns,
SELECT_JSON_ARR(PersonPhone.AllColumns).
FROM(PersonPhone).
WHERE(PersonPhone.PersonID.EQ(Person.PersonID)).
ORDER_BY(PersonPhone.PhoneID).AS("Phones"),
).FROM(
Person,
).ORDER_BY(
Person.PersonID.ASC(),
)
var dest []struct {
model.Person
Phones []struct {
model.PersonPhone
}
}
err := jsonQuery.QueryContext(ctx, db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, expectedSliceOfStructsLeftJoin)
})
} }
func TestEnumType(t *testing.T) { func TestEnumType(t *testing.T) {
query := Person. query := Person.

View file

@ -209,7 +209,7 @@ func TestScanToStruct(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
require.Error(t, err) require.Error(t, err)
require.EqualError(t, err, "jet: can't scan int64('\\x01') to 'InventoryID uuid.UUID': Scan: unable to scan type int64 into UUID") require.EqualError(t, err, "jet: can't assign int64('\\x01') to 'InventoryID uuid.UUID': Scan: unable to scan type int64 into UUID")
}) })
t.Run("type mismatch base type", func(t *testing.T) { t.Run("type mismatch base type", func(t *testing.T) {

File diff suppressed because one or more lines are too long

Some files were not shown because too many files have changed in this diff Show more