diff --git a/.circleci/config.yml b/.circleci/config.yml index a36737b..1518364 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -11,7 +11,7 @@ jobs: - image: cimg/go:1.22.8 # Please keep the version in sync with test/docker-compose.yaml - - image: cimg/postgres:14.10 + - image: cimg/postgres:14.1 environment: POSTGRES_USER: jet POSTGRES_PASSWORD: jet @@ -19,7 +19,7 @@ jobs: PGPORT: 50901 # Please keep the version in sync with test/docker-compose.yaml - - image: circleci/mysql:8.0.27 + - image: cimg/mysql:8.0.27 command: [ --default-authentication-plugin=mysql_native_password ] environment: MYSQL_ROOT_PASSWORD: jet @@ -29,7 +29,7 @@ jobs: MYSQL_TCP_PORT: 50902 # Please keep the version in sync with test/docker-compose.yaml - - image: circleci/mariadb:10.3 + - image: cimg/mariadb:11.4 command: [ '--default-authentication-plugin=mysql_native_password', '--port=50903' ] environment: MYSQL_ROOT_PASSWORD: jet @@ -116,25 +116,27 @@ jobs: name: Create MySQL/MariaDB user and test databases command: | mysql -h 127.0.0.1 -P 50902 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';" - mysql -h 127.0.0.1 -P 50902 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';" + mysql -h 127.0.0.1 -P 50902 -u root -pjet -e "set global sql_mode = 'STRICT_ALL_TABLES,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';" mysql -h 127.0.0.1 -P 50902 -u jet -pjet -e "create database test_sample" mysql -h 127.0.0.1 -P 50902 -u jet -pjet -e "create database dvds2" mysql -h 127.0.0.1 -P 50903 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';" - mysql -h 127.0.0.1 -P 50903 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';" + mysql -h 127.0.0.1 -P 50903 -u root -pjet -e "set global sql_mode = 'STRICT_ALL_TABLES,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';" mysql -h 127.0.0.1 -P 50903 -u jet -pjet -e "create database test_sample" mysql -h 127.0.0.1 -P 50903 -u jet -pjet -e "create database dvds2" - - run: - name: Init databases - command: | - cd tests - go run ./init/init.go -testsuite all - - run: name: Install gotestsum command: go install gotest.tools/gotestsum@latest + - run: + name: Init databases (postgres, mysql, sqlite) and generate jet files + command: | + cd tests + go run ./init/init.go -testsuite postgres + go run ./init/init.go -testsuite mysql + go run ./init/init.go -testsuite sqlite + # to create test results report - run: mkdir -p $TEST_RESULTS @@ -146,14 +148,14 @@ jobs: name: Running tests with statement caching enabled command: JET_TESTS_WITH_STMT_CACHE=true go test -tags postgres -v ./tests/... - # run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests - run: - name: Jet generate mariadb and cockroachdb + name: Init databases (mariadb, cockroachdb) and generate jet files command: | cd tests - make jet-gen-mariadb - make jet-gen-cockroach + go run ./init/init.go -testsuite mariadb + go run ./init/init.go -testsuite cockroach + # run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests - run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/ - run: PG_SOURCE=COCKROACH_DB go test -v ./tests/postgres/ diff --git a/README.md b/README.md index 3d5f57d..70e29e9 100644 --- a/README.md +++ b/README.md @@ -579,5 +579,5 @@ To run the tests, additional dependencies are required: ## License -Copyright 2019-2024 Goran Bjelanovic +Copyright 2019-2025 Goran Bjelanovic Licensed under the Apache License, Version 2.0. diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index 14da5aa..5063f7c 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -18,8 +18,8 @@ func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTyp SELECT t.table_name as "table.name", col.COLUMN_NAME AS "column.Name", - col.COLUMN_DEFAULT IS NOT NULL AND t.table_type != 'VIEW' as "column.HasDefault", - col.IS_NULLABLE = "YES" AS "column.IsNullable", + (col.COLUMN_DEFAULT IS NOT NULL AND col.COLUMN_DEFAULT != 'NULL') AND t.table_type != 'VIEW' as "column.HasDefault", + col.IS_NULLABLE = 'YES' AS "column.IsNullable", col.COLUMN_COMMENT AS "column.Comment", COALESCE(pk.IsPrimaryKey, 0) AS "column.IsPrimaryKey", IF (col.COLUMN_TYPE = 'tinyint(1)', diff --git a/generator/template/sql_builder_template.go b/generator/template/sql_builder_template.go index cfd40aa..dcb4e97 100644 --- a/generator/template/sql_builder_template.go +++ b/generator/template/sql_builder_template.go @@ -180,11 +180,15 @@ func getSqlBuilderColumnType(columnMetaData metadata.Column) string { return "Timez" case "interval": return "Interval" - case "user-defined", "enum", "text", "character", "character varying", "bytea", "uuid", + case "user-defined", "enum", "text", "character", "character varying", "uuid", "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", - "char", "varchar", "nvarchar", "binary", "varbinary", "bpchar", "varbit", - "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL + "char", "varchar", "nvarchar", "bpchar", "varbit", + "tinytext", "mediumtext", "longtext": // MySQL return "String" + case "bytea": // postgres + return "Bytea" + case "binary", "varbinary", "tinyblob", "mediumblob", "longblob", "blob": // mysql and sqlite + return "Blob" case "real", "numeric", "decimal", "double precision", "float", "float4", "float8", "double": // MySQL return "Float" diff --git a/go.mod b/go.mod index 135fc3d..23d0e07 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/go-jet/jet/v2 -go 1.21 +go 1.22 // used by jet generator require ( diff --git a/internal/3rdparty/snaker/snaker.go b/internal/3rdparty/snaker/snaker.go index 32a19e6..4177e1c 100644 --- a/internal/3rdparty/snaker/snaker.go +++ b/internal/3rdparty/snaker/snaker.go @@ -40,14 +40,23 @@ func snakeToCamel(s string, upperCase bool) string { if upperCase || i > 0 { result += camelizeWord(word, len(words) > 1) - } else { - result += word + } else { // lowerCase and i == 0 + result += toLowerFirstLetter(word) } } return result } +func toLowerFirstLetter(s string) string { + if s == "" { + return s + } + runes := []rune(s) + runes[0] = unicode.ToLower(runes[0]) + return string(runes) +} + func camelizeWord(word string, force bool) string { runes := []rune(word) diff --git a/internal/3rdparty/snaker/snaker_test.go b/internal/3rdparty/snaker/snaker_test.go index f828a91..e05ca23 100644 --- a/internal/3rdparty/snaker/snaker_test.go +++ b/internal/3rdparty/snaker/snaker_test.go @@ -7,7 +7,10 @@ import ( func TestSnakeToCamel(t *testing.T) { require.Equal(t, SnakeToCamel(""), "") + require.Equal(t, SnakeToCamel("_", false), "") require.Equal(t, SnakeToCamel("potato_"), "Potato") + require.Equal(t, SnakeToCamel("potato_", false), "potato") + require.Equal(t, SnakeToCamel("Potato_", false), "potato") require.Equal(t, SnakeToCamel("this_has_to_be_uppercased"), "ThisHasToBeUppercased") require.Equal(t, SnakeToCamel("this_is_an_id"), "ThisIsAnID") require.Equal(t, SnakeToCamel("this_is_an_identifier"), "ThisIsAnIdentifier") diff --git a/internal/jet/alias.go b/internal/jet/alias.go index 8693b13..d096a65 100644 --- a/internal/jet/alias.go +++ b/internal/jet/alias.go @@ -18,10 +18,45 @@ func (a *alias) fromImpl(subQuery SelectTable) Projection { // Generated columns have default aliasing. tableName, columnName := extractTableAndColumnName(a.alias) - column := NewColumnImpl(columnName, tableName, nil) - column.subQuery = subQuery + newDummyColumn := newDummyColumnForExpression(a.expression, columnName) + newDummyColumn.setTableName(tableName) + newDummyColumn.setSubQuery(subQuery) - return &column + return newDummyColumn +} + +// This function is used to create dummy columns when exporting sub-query columns using subQuery.AllColumns() +// In most case we don't care about type of the column, except when sub-query columns are used as SELECT_JSON projection. +// We need to know type to encode value for json unmarshal. At the moment only bool, time and blob columns are of interest, +// so we don't have to support every column type. +func newDummyColumnForExpression(exp Expression, name string) ColumnExpression { + + switch exp.(type) { + case BoolExpression: + return BoolColumn(name) + case IntegerExpression: + return IntegerColumn(name) + case FloatExpression: + return FloatColumn(name) + case BlobExpression: + return BlobColumn(name) + case DateExpression: + return DateColumn(name) + case TimeExpression: + return TimeColumn(name) + case TimezExpression: + return TimezColumn(name) + case TimestampExpression: + return TimestampColumn(name) + case TimestampzExpression: + return TimestampzColumn(name) + case IntervalExpression: + return IntervalColumn(name) + case StringExpression: + return StringColumn(name) + } + + return StringColumn(name) } func (a *alias) serializeForProjection(statement StatementType, out *SQLBuilder) { @@ -30,3 +65,15 @@ func (a *alias) serializeForProjection(statement StatementType, out *SQLBuilder) out.WriteString("AS") out.WriteAlias(a.alias) } + +func (a *alias) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) { + out.WriteJsonObjKey(a.alias) + a.expression.serializeForJsonValue(statement, out) +} + +func (a *alias) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) { + a.expression.serializeForJsonValue(statement, out) + + out.WriteString("AS") + out.WriteAlias(a.alias) +} diff --git a/internal/jet/blob_expression.go b/internal/jet/blob_expression.go new file mode 100644 index 0000000..435c68e --- /dev/null +++ b/internal/jet/blob_expression.go @@ -0,0 +1,104 @@ +package jet + +// BlobExpression interface +type BlobExpression interface { + Expression + + isStringOrBlob() + + EQ(rhs BlobExpression) BoolExpression + NOT_EQ(rhs BlobExpression) BoolExpression + IS_DISTINCT_FROM(rhs BlobExpression) BoolExpression + IS_NOT_DISTINCT_FROM(rhs BlobExpression) BoolExpression + + LT(rhs BlobExpression) BoolExpression + LT_EQ(rhs BlobExpression) BoolExpression + GT(rhs BlobExpression) BoolExpression + GT_EQ(rhs BlobExpression) BoolExpression + BETWEEN(min, max BlobExpression) BoolExpression + NOT_BETWEEN(min, max BlobExpression) BoolExpression + + CONCAT(rhs BlobExpression) BlobExpression + + LIKE(pattern BlobExpression) BoolExpression + NOT_LIKE(pattern BlobExpression) BoolExpression +} + +type blobInterfaceImpl struct { + parent BlobExpression +} + +func (b *blobInterfaceImpl) isStringOrBlob() {} + +func (b *blobInterfaceImpl) EQ(rhs BlobExpression) BoolExpression { + return Eq(b.parent, rhs) +} + +func (b *blobInterfaceImpl) NOT_EQ(rhs BlobExpression) BoolExpression { + return NotEq(b.parent, rhs) +} + +func (b *blobInterfaceImpl) IS_DISTINCT_FROM(rhs BlobExpression) BoolExpression { + return IsDistinctFrom(b.parent, rhs) +} + +func (b *blobInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs BlobExpression) BoolExpression { + return IsNotDistinctFrom(b.parent, rhs) +} + +func (b *blobInterfaceImpl) GT(rhs BlobExpression) BoolExpression { + return Gt(b.parent, rhs) +} + +func (b *blobInterfaceImpl) GT_EQ(rhs BlobExpression) BoolExpression { + return GtEq(b.parent, rhs) +} + +func (b *blobInterfaceImpl) LT(rhs BlobExpression) BoolExpression { + return Lt(b.parent, rhs) +} + +func (b *blobInterfaceImpl) LT_EQ(rhs BlobExpression) BoolExpression { + return LtEq(b.parent, rhs) +} + +func (b *blobInterfaceImpl) BETWEEN(min, max BlobExpression) BoolExpression { + return NewBetweenOperatorExpression(b.parent, min, max, false) +} + +func (b *blobInterfaceImpl) NOT_BETWEEN(min, max BlobExpression) BoolExpression { + return NewBetweenOperatorExpression(b.parent, min, max, true) +} + +func (b *blobInterfaceImpl) CONCAT(rhs BlobExpression) BlobExpression { + return BlobExp(newBinaryStringOperatorExpression(b.parent, rhs, StringConcatOperator)) +} + +func (b *blobInterfaceImpl) LIKE(pattern BlobExpression) BoolExpression { + return newBinaryBoolOperatorExpression(b.parent, pattern, "LIKE") +} + +func (b *blobInterfaceImpl) NOT_LIKE(pattern BlobExpression) BoolExpression { + return newBinaryBoolOperatorExpression(b.parent, pattern, "NOT LIKE") +} + +//---------------------------------------------------// + +type blobExpressionWrapper struct { + Expression + blobInterfaceImpl +} + +func newBlobExpressionWrap(expression Expression) BlobExpression { + blobExpressionWrap := &blobExpressionWrapper{Expression: expression} + blobExpressionWrap.blobInterfaceImpl.parent = blobExpressionWrap + expression.setParent(blobExpressionWrap) + return blobExpressionWrap +} + +// BlobExp is blob expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as blob expression. +// Does not add sql cast to generated sql builder output. +func BlobExp(expression Expression) BlobExpression { + return newBlobExpressionWrap(expression) +} diff --git a/internal/jet/bool_expression.go b/internal/jet/bool_expression.go index 41bdcc4..450375c 100644 --- a/internal/jet/bool_expression.go +++ b/internal/jet/bool_expression.go @@ -102,9 +102,10 @@ type boolExpressionWrapper struct { } func newBoolExpressionWrap(expression Expression) BoolExpression { - boolExpressionWrap := boolExpressionWrapper{Expression: expression} - boolExpressionWrap.boolInterfaceImpl.parent = &boolExpressionWrap - return &boolExpressionWrap + boolExpressionWrap := &boolExpressionWrapper{Expression: expression} + boolExpressionWrap.boolInterfaceImpl.parent = boolExpressionWrap + expression.setParent(boolExpressionWrap) + return boolExpressionWrap } // BoolExp is bool expression wrapper around arbitrary expression. diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 533d223..dab2f81 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -41,6 +41,8 @@ type ClauseSelect struct { DistinctOnColumns []ColumnExpression ProjectionList []Projection + IsForRowToJson bool + // MySQL only OptimizerHints optimizerHints } @@ -52,6 +54,10 @@ func (s *ClauseSelect) Projections() ProjectionList { // Serialize serializes clause into SQLBuilder func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { + if len(s.ProjectionList) == 0 { + panic("jet: SELECT clause has to have at least one projection") + } + out.NewLine() out.WriteString("SELECT") s.OptimizerHints.Serialize(statementType, out, options...) @@ -66,11 +72,13 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o out.WriteByte(')') } - if len(s.ProjectionList) == 0 { - panic("jet: SELECT clause has to have at least one projection") + if s.IsForRowToJson { + out.IncreaseIdent() + out.WriteRowToJsonProjections(statementType, s.ProjectionList) + out.DecreaseIdent() + } else { + out.WriteProjections(statementType, s.ProjectionList) } - - out.WriteProjections(statementType, s.ProjectionList) } // ClauseFrom struct diff --git a/internal/jet/column.go b/internal/jet/column.go index 2b1b930..dda19ff 100644 --- a/internal/jet/column.go +++ b/internal/jet/column.go @@ -2,6 +2,10 @@ package jet +import ( + "github.com/go-jet/jet/v2/internal/3rdparty/snaker" +) + // Column is common column interface for all types of columns. type Column interface { Name() string @@ -35,19 +39,19 @@ type ColumnExpressionImpl struct { } // NewColumnImpl creates new ColumnExpressionImpl -func NewColumnImpl(name string, tableName string, parent ColumnExpression) ColumnExpressionImpl { - bc := ColumnExpressionImpl{ +func NewColumnImpl(name string, tableName string, parent ColumnExpression) *ColumnExpressionImpl { + newColumn := &ColumnExpressionImpl{ name: name, tableName: tableName, } if parent != nil { - bc.ExpressionInterfaceImpl.Parent = parent + newColumn.ExpressionInterfaceImpl.Parent = parent } else { - bc.ExpressionInterfaceImpl.Parent = &bc + newColumn.ExpressionInterfaceImpl.Parent = newColumn } - return bc + return newColumn } // Name returns name of the column @@ -76,13 +80,6 @@ func (c *ColumnExpressionImpl) defaultAlias() string { return c.name } -func (c *ColumnExpressionImpl) fromImpl(subQuery SelectTable) Projection { - newColumn := NewColumnImpl(c.name, c.tableName, nil) - newColumn.setSubQuery(subQuery) - - return &newColumn -} - func (c *ColumnExpressionImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { if statement == SetStatementType { // set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause @@ -93,14 +90,28 @@ func (c *ColumnExpressionImpl) serializeForOrderBy(statement StatementType, out c.serialize(statement, out) } -func (c ColumnExpressionImpl) serializeForProjection(statement StatementType, out *SQLBuilder) { +func (c *ColumnExpressionImpl) serializeForProjection(statement StatementType, out *SQLBuilder) { c.serialize(statement, out) out.WriteString("AS") + out.WriteAlias(c.defaultAlias()) } -func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (c *ColumnExpressionImpl) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) { + out.WriteJsonObjKey(snaker.SnakeToCamel(c.name, false)) + c.Parent.serializeForJsonValue(statement, out) +} + +func (c *ColumnExpressionImpl) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) { + c.Parent.serializeForJsonValue(statement, out) + + out.WriteString("AS") + + out.WriteAlias(snaker.SnakeToCamel(c.name, false)) +} + +func (c *ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { if c.subQuery != nil { out.WriteIdentifier(c.subQuery.Alias()) diff --git a/internal/jet/column_list.go b/internal/jet/column_list.go index a07b9ba..0ea88fe 100644 --- a/internal/jet/column_list.go +++ b/internal/jet/column_list.go @@ -78,6 +78,18 @@ func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBui SerializeProjectionList(statement, projections, out) } +func (cl ColumnList) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) { + projections := ColumnListToProjectionList(cl) + + SerializeProjectionListJsonObj(statement, projections, out) +} + +func (cl ColumnList) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) { + projections := ColumnListToProjectionList(cl) + + out.WriteRowToJsonProjections(statement, projections) +} + // dummy column interface implementation // Name is placeholder for ColumnList to implement Column interface diff --git a/internal/jet/column_test.go b/internal/jet/column_test.go index 7e6c0ed..23e581a 100644 --- a/internal/jet/column_test.go +++ b/internal/jet/column_test.go @@ -4,11 +4,10 @@ import "testing" func TestColumn(t *testing.T) { column := NewColumnImpl("col", "", nil) - column.ExpressionInterfaceImpl.Parent = &column assertClauseSerialize(t, column, "col") column.setTableName("table1") assertClauseSerialize(t, column, "table1.col") - assertProjectionSerialize(t, &column, `table1.col AS "table1.col"`) + assertProjectionSerialize(t, column, `table1.col AS "table1.col"`) assertProjectionSerialize(t, column.AS("alias1"), `table1.col AS "alias1"`) } diff --git a/internal/jet/column_types.go b/internal/jet/column_types.go index a732061..00e333e 100644 --- a/internal/jet/column_types.go +++ b/internal/jet/column_types.go @@ -11,7 +11,11 @@ type ColumnBool interface { type boolColumnImpl struct { boolInterfaceImpl - ColumnExpressionImpl + *ColumnExpressionImpl +} + +func (i *boolColumnImpl) fromImpl(subQuery SelectTable) Projection { + return i.From(subQuery) } func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool { @@ -51,7 +55,11 @@ type ColumnFloat interface { type floatColumnImpl struct { floatInterfaceImpl - ColumnExpressionImpl + *ColumnExpressionImpl +} + +func (i *floatColumnImpl) fromImpl(subQuery SelectTable) Projection { + return i.From(subQuery) } func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat { @@ -92,7 +100,11 @@ type ColumnInteger interface { type integerColumnImpl struct { integerInterfaceImpl - ColumnExpressionImpl + *ColumnExpressionImpl +} + +func (i *integerColumnImpl) fromImpl(subQuery SelectTable) Projection { + return i.From(subQuery) } func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger { @@ -122,7 +134,7 @@ func IntegerColumn(name string) ColumnInteger { //------------------------------------------------------// // ColumnString is interface for SQL text, character, character varying -// bytea, uuid columns and enums types. +// uuid columns and enums types. type ColumnString interface { StringExpression Column @@ -134,7 +146,11 @@ type ColumnString interface { type stringColumnImpl struct { stringInterfaceImpl - ColumnExpressionImpl + *ColumnExpressionImpl +} + +func (i *stringColumnImpl) fromImpl(subQuery SelectTable) Projection { + return i.From(subQuery) } func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString { @@ -163,6 +179,51 @@ func StringColumn(name string) ColumnString { //------------------------------------------------------// +// ColumnBlob is interface for binary data types (bytea, binary, blob, etc...) +type ColumnBlob interface { + BlobExpression + Column + + From(subQuery SelectTable) ColumnBlob + SET(blob BlobExpression) ColumnAssigment +} + +type blobColumnImpl struct { + blobInterfaceImpl + + *ColumnExpressionImpl +} + +func (i *blobColumnImpl) fromImpl(subQuery SelectTable) Projection { + return i.From(subQuery) +} + +func (i *blobColumnImpl) From(subQuery SelectTable) ColumnBlob { + newBlobColumn := BlobColumn(i.name) + newBlobColumn.setTableName(i.tableName) + newBlobColumn.setSubQuery(subQuery) + + return newBlobColumn +} + +func (i *blobColumnImpl) SET(blobExp BlobExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: blobExp, + } +} + +// BlobColumn creates named blob column. +func BlobColumn(name string) ColumnBlob { + blobColumn := &blobColumnImpl{} + blobColumn.blobInterfaceImpl.parent = blobColumn + blobColumn.ColumnExpressionImpl = NewColumnImpl(name, "", blobColumn) + + return blobColumn +} + +//------------------------------------------------------// + // ColumnTime is interface for SQL time column. type ColumnTime interface { TimeExpression @@ -174,7 +235,11 @@ type ColumnTime interface { type timeColumnImpl struct { timeInterfaceImpl - ColumnExpressionImpl + *ColumnExpressionImpl +} + +func (i *timeColumnImpl) fromImpl(subQuery SelectTable) Projection { + return i.From(subQuery) } func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime { @@ -213,7 +278,11 @@ type ColumnTimez interface { type timezColumnImpl struct { timezInterfaceImpl - ColumnExpressionImpl + *ColumnExpressionImpl +} + +func (i *timezColumnImpl) fromImpl(subQuery SelectTable) Projection { + return i.From(subQuery) } func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez { @@ -253,7 +322,11 @@ type ColumnTimestamp interface { type timestampColumnImpl struct { timestampInterfaceImpl - ColumnExpressionImpl + *ColumnExpressionImpl +} + +func (i *timestampColumnImpl) fromImpl(subQuery SelectTable) Projection { + return i.From(subQuery) } func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp { @@ -293,7 +366,11 @@ type ColumnTimestampz interface { type timestampzColumnImpl struct { timestampzInterfaceImpl - ColumnExpressionImpl + *ColumnExpressionImpl +} + +func (i *timestampzColumnImpl) fromImpl(subQuery SelectTable) Projection { + return i.From(subQuery) } func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz { @@ -333,7 +410,11 @@ type ColumnDate interface { type dateColumnImpl struct { dateInterfaceImpl - ColumnExpressionImpl + *ColumnExpressionImpl +} + +func (i *dateColumnImpl) fromImpl(subQuery SelectTable) Projection { + return i.From(subQuery) } func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate { @@ -361,6 +442,51 @@ func DateColumn(name string) ColumnDate { //------------------------------------------------------// +// ColumnInterval is interface of PostgreSQL interval columns. +type ColumnInterval interface { + IntervalExpression + Column + + From(subQuery SelectTable) ColumnInterval + SET(intervalExp IntervalExpression) ColumnAssigment +} + +//------------------------------------------------------// + +type intervalColumnImpl struct { + *ColumnExpressionImpl + intervalInterfaceImpl +} + +func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: intervalExp, + } +} + +func (i *intervalColumnImpl) fromImpl(subQuery SelectTable) Projection { + return i.From(subQuery) +} + +func (i *intervalColumnImpl) From(subQuery SelectTable) ColumnInterval { + newIntervalColumn := IntervalColumn(i.name) + newIntervalColumn.setTableName(i.tableName) + newIntervalColumn.setSubQuery(subQuery) + + return newIntervalColumn +} + +// IntervalColumn creates named interval column. +func IntervalColumn(name string) ColumnInterval { + intervalColumn := &intervalColumnImpl{} + intervalColumn.ColumnExpressionImpl = NewColumnImpl(name, "", intervalColumn) + intervalColumn.intervalInterfaceImpl.parent = intervalColumn + return intervalColumn +} + +//------------------------------------------------------// + // ColumnRange is interface for range columns which can be int range, string range // timestamp range or date range. type ColumnRange[T Expression] interface { @@ -373,7 +499,11 @@ type ColumnRange[T Expression] interface { type rangeColumnImpl[T Expression] struct { rangeInterfaceImpl[T] - ColumnExpressionImpl + *ColumnExpressionImpl +} + +func (i *rangeColumnImpl[T]) fromImpl(subQuery SelectTable) Projection { + return i.From(subQuery) } func (i *rangeColumnImpl[T]) From(subQuery SelectTable) ColumnRange[T] { diff --git a/internal/jet/date_expression.go b/internal/jet/date_expression.go index 560688a..2f62a50 100644 --- a/internal/jet/date_expression.go +++ b/internal/jet/date_expression.go @@ -80,9 +80,10 @@ type dateExpressionWrapper struct { } func newDateExpressionWrap(expression Expression) DateExpression { - dateExpressionWrap := dateExpressionWrapper{Expression: expression} - dateExpressionWrap.dateInterfaceImpl.parent = &dateExpressionWrap - return &dateExpressionWrap + dateExpressionWrap := &dateExpressionWrapper{Expression: expression} + dateExpressionWrap.dateInterfaceImpl.parent = dateExpressionWrap + expression.setParent(dateExpressionWrap) + return dateExpressionWrap } // DateExp is date expression wrapper around arbitrary expression. diff --git a/internal/jet/date_expression_test.go b/internal/jet/date_expression_test.go deleted file mode 100644 index 14fdd76..0000000 --- a/internal/jet/date_expression_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package jet - -import ( - "testing" -) - -func TestDateArithmetic(t *testing.T) { - timestamp := Timestamp(2000, 1, 1, 0, 0, 0) - assertClauseDebugSerialize(t, table1ColDate.ADD(NewInterval(String("1 HOUR"))).EQ(timestamp), - "((table1.col_date + INTERVAL '1 HOUR') = '2000-01-01 00:00:00')") - assertClauseDebugSerialize(t, table1ColDate.SUB(NewInterval(String("1 HOUR"))).EQ(timestamp), - "((table1.col_date - INTERVAL '1 HOUR') = '2000-01-01 00:00:00')") -} diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index 68c4c02..678e044 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -1,6 +1,8 @@ package jet -import "strings" +import ( + "strings" +) // Dialect interface type Dialect interface { @@ -11,9 +13,11 @@ type Dialect interface { AliasQuoteChar() byte IdentifierQuoteChar() byte ArgumentPlaceholder() QueryPlaceholderFunc + ArgumentToString(value any) (string, bool) IsReservedWord(name string) bool SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc ValuesDefaultColumnName(index int) string + JsonValueEncode(expr Expression) Expression } // SerializerFunc func @@ -34,9 +38,11 @@ type DialectParams struct { AliasQuoteChar byte IdentifierQuoteChar byte ArgumentPlaceholder QueryPlaceholderFunc + ArgumentToString func(value any) (string, bool) ReservedWords []string SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc ValuesDefaultColumnName func(index int) string + JsonValueEncode func(expr Expression) Expression } // NewDialect creates new dialect with params @@ -49,9 +55,11 @@ func NewDialect(params DialectParams) Dialect { aliasQuoteChar: params.AliasQuoteChar, identifierQuoteChar: params.IdentifierQuoteChar, argumentPlaceholder: params.ArgumentPlaceholder, + argumentToString: params.ArgumentToString, reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords), serializeOrderBy: params.SerializeOrderBy, valuesDefaultColumnName: params.ValuesDefaultColumnName, + jsonValueEncode: params.JsonValueEncode, } } @@ -63,9 +71,11 @@ type dialectImpl struct { aliasQuoteChar byte identifierQuoteChar byte argumentPlaceholder QueryPlaceholderFunc + argumentToString func(value any) (string, bool) reservedWords map[string]bool serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc valuesDefaultColumnName func(index int) string + jsonValueEncode func(expr Expression) Expression } func (d *dialectImpl) Name() string { @@ -102,6 +112,10 @@ func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc { return d.argumentPlaceholder } +func (d *dialectImpl) ArgumentToString(value any) (string, bool) { + return d.argumentToString(value) +} + func (d *dialectImpl) IsReservedWord(name string) bool { _, isReservedWord := d.reservedWords[strings.ToLower(name)] return isReservedWord @@ -115,6 +129,10 @@ func (d *dialectImpl) ValuesDefaultColumnName(index int) string { return d.valuesDefaultColumnName(index) } +func (d *dialectImpl) JsonValueEncode(expr Expression) Expression { + return d.jsonValueEncode(expr) +} + func arrayOfStringsToMapOfStrings(arr []string) map[string]bool { ret := map[string]bool{} for _, elem := range arr { diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 9999803..2ec5edc 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -2,7 +2,7 @@ package jet import "fmt" -// Expression is common interface for all expressions. +// Expression is a common interface for all expressions. // Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions. type Expression interface { Serializer @@ -10,6 +10,9 @@ type Expression interface { GroupByClause OrderByClause + serializeForJsonValue(statement StatementType, out *SQLBuilder) + setParent(parent Expression) + // IS_NULL tests expression whether it is a NULL value. IS_NULL() BoolExpression // IS_NOT_NULL tests expression whether it is a non-NULL value. @@ -34,6 +37,10 @@ type ExpressionInterfaceImpl struct { Parent Expression } +func (e *ExpressionInterfaceImpl) setParent(parent Expression) { + e.Parent = parent +} + func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection { panic(fmt.Sprintf("jet: can't export unaliased expression subQuery: %s, expression: %s", subQuery.Alias(), serializeToDefaultDebugString(e.Parent))) @@ -92,6 +99,18 @@ func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType e.Parent.serialize(statement, out, NoWrap) } +func (e *ExpressionInterfaceImpl) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) { + panic("jet: expression need to be aliased when used as SELECT JSON projection.") +} + +func (e *ExpressionInterfaceImpl) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) { + panic("jet: expression need to be aliased when used as SELECT JSON projection.") +} + +func (e *ExpressionInterfaceImpl) serializeForJsonValue(statement StatementType, out *SQLBuilder) { + out.Dialect.JsonValueEncode(e.Parent).serialize(statement, out) +} + func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { e.Parent.serialize(statement, out, NoWrap) } @@ -152,7 +171,7 @@ func newExpressionListOperator(operator string, expressions ...Expression) *expr } func newBoolExpressionListOperator(operator string, expressions ...BoolExpression) BoolExpression { - return BoolExp(newExpressionListOperator(operator, BoolExpressionListToExpressionList(expressions)...)) + return BoolExp(newExpressionListOperator(operator, ToExpressionList(expressions)...)) } func (elo *expressionListOperator) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { diff --git a/internal/jet/float_expression.go b/internal/jet/float_expression.go index 52c97eb..f695ec3 100644 --- a/internal/jet/float_expression.go +++ b/internal/jet/float_expression.go @@ -102,9 +102,10 @@ type floatExpressionWrapper struct { } func newFloatExpressionWrap(expression Expression) FloatExpression { - floatExpressionWrap := floatExpressionWrapper{Expression: expression} - floatExpressionWrap.floatInterfaceImpl.parent = &floatExpressionWrap - return &floatExpressionWrap + floatExpressionWrap := &floatExpressionWrapper{Expression: expression} + floatExpressionWrap.floatInterfaceImpl.parent = floatExpressionWrap + expression.setParent(floatExpressionWrap) + return floatExpressionWrap } // FloatExp is date expression wrapper around arbitrary expression. diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index ddc579e..46c8f0d 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -255,18 +255,30 @@ func leadLagImpl(name string, expr Expression, offsetAndDefault ...interface{}) //------------ String functions ------------------// +// HEX function takes an input and returns its equivalent hexadecimal representation +func HEX(expression Expression) StringExpression { + return StringExp(Func("HEX", expression)) +} + +// UNHEX for a string argument str, UNHEX(str) interprets each pair of characters in the argument +// as a hexadecimal number and converts it to the byte represented by the number. +// The return value is a binary string. +func UNHEX(expression StringExpression) BlobExpression { + return BlobExp(Func("UNHEX", expression)) +} + // BIT_LENGTH returns number of bits in string expression -func BIT_LENGTH(stringExpression StringExpression) IntegerExpression { +func BIT_LENGTH(stringExpression StringOrBlobExpression) IntegerExpression { return newIntegerFunc("BIT_LENGTH", stringExpression) } // CHAR_LENGTH returns number of characters in string expression -func CHAR_LENGTH(stringExpression StringExpression) IntegerExpression { +func CHAR_LENGTH(stringExpression StringOrBlobExpression) IntegerExpression { return newIntegerFunc("CHAR_LENGTH", stringExpression) } // OCTET_LENGTH returns number of bytes in string expression -func OCTET_LENGTH(stringExpression StringExpression) IntegerExpression { +func OCTET_LENGTH(stringExpression StringOrBlobExpression) IntegerExpression { return newIntegerFunc("OCTET_LENGTH", stringExpression) } @@ -282,7 +294,7 @@ func UPPER(stringExpression StringExpression) StringExpression { // BTRIM removes the longest string consisting only of characters // in characters (a space by default) from the start and end of string -func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) StringExpression { +func BTRIM(stringExpression StringOrBlobExpression, trimChars ...StringOrBlobExpression) StringExpression { if len(trimChars) > 0 { return NewStringFunc("BTRIM", stringExpression, trimChars[0]) } @@ -291,7 +303,7 @@ func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) Str // LTRIM removes the longest string containing only characters // from characters (a space by default) from the start of string -func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { +func LTRIM(str StringOrBlobExpression, trimChars ...StringOrBlobExpression) StringExpression { if len(trimChars) > 0 { return NewStringFunc("LTRIM", str, trimChars[0]) } @@ -300,7 +312,7 @@ func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression // RTRIM removes the longest string containing only characters // from characters (a space by default) from the end of string -func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { +func RTRIM(str StringOrBlobExpression, trimChars ...StringOrBlobExpression) StringExpression { if len(trimChars) > 0 { return NewStringFunc("RTRIM", str, trimChars[0]) } @@ -324,32 +336,32 @@ func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression // CONVERT converts string to dest_encoding. The original encoding is // specified by src_encoding. The string must be valid in this encoding. -func CONVERT(str StringExpression, srcEncoding StringExpression, destEncoding StringExpression) StringExpression { - return NewStringFunc("CONVERT", str, srcEncoding, destEncoding) +func CONVERT(str BlobExpression, srcEncoding StringExpression, destEncoding StringExpression) BlobExpression { + return BlobExp(Func("CONVERT", str, srcEncoding, destEncoding)) } // CONVERT_FROM converts string to the database encoding. The original // encoding is specified by src_encoding. The string must be valid in this encoding. -func CONVERT_FROM(str StringExpression, srcEncoding StringExpression) StringExpression { +func CONVERT_FROM(str BlobExpression, srcEncoding StringExpression) StringExpression { return NewStringFunc("CONVERT_FROM", str, srcEncoding) } // CONVERT_TO converts string to dest_encoding. -func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression { - return NewStringFunc("CONVERT_TO", str, toEncoding) +func CONVERT_TO(str StringExpression, toEncoding StringExpression) BlobExpression { + return BlobExp(Func("CONVERT_TO", str, toEncoding)) } // ENCODE encodes binary data into a textual representation. // Supported formats are: base64, hex, escape. escape converts zero bytes and // high-bit-set bytes to octal sequences (\nnn) and doubles backslashes. -func ENCODE(data StringExpression, format StringExpression) StringExpression { - return NewStringFunc("ENCODE", data, format) +func ENCODE(data BlobExpression, format StringExpression) StringExpression { + return StringExp(Func("ENCODE", data, format)) } // DECODE decodes binary data from textual representation in string. // Options for format are same as in encode. -func DECODE(data StringExpression, format StringExpression) StringExpression { - return NewStringFunc("DECODE", data, format) +func DECODE(data StringExpression, format StringExpression) BlobExpression { + return BlobExp(Func("DECODE", data, format)) } // FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. @@ -379,11 +391,11 @@ func RIGHT(str StringExpression, n IntegerExpression) StringExpression { } // LENGTH returns number of characters in string with a given encoding -func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression { +func LENGTH(str StringOrBlobExpression, encoding ...StringExpression) IntegerExpression { if len(encoding) > 0 { - return NewStringFunc("LENGTH", str, encoding[0]) + return IntExp(Func("LENGTH", str, encoding[0])) } - return NewStringFunc("LENGTH", str) + return IntExp(Func("LENGTH", str)) } // LPAD fills up the string to length length by prepending the characters @@ -407,8 +419,13 @@ func RPAD(str StringExpression, length IntegerExpression, text ...StringExpressi return NewStringFunc("RPAD", str, length) } +// BIT_COUNT returns the number of bits set in the binary string (also known as “popcount”). +func BIT_COUNT(bytes BlobExpression) IntegerExpression { + return IntExp(Func("BIT_COUNT", bytes)) +} + // MD5 calculates the MD5 hash of string, returning the result in hexadecimal -func MD5(stringExpression StringExpression) StringExpression { +func MD5(stringExpression StringOrBlobExpression) StringExpression { return NewStringFunc("MD5", stringExpression) } @@ -434,7 +451,7 @@ func STRPOS(str, substring StringExpression) IntegerExpression { } // SUBSTR extracts substring -func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression { +func SUBSTR(str StringOrBlobExpression, from IntegerExpression, count ...IntegerExpression) StringExpression { if len(count) > 0 { return NewStringFunc("SUBSTR", str, from, count[0]) } diff --git a/internal/jet/integer_expression.go b/internal/jet/integer_expression.go index f5451a0..d2761fc 100644 --- a/internal/jet/integer_expression.go +++ b/internal/jet/integer_expression.go @@ -141,11 +141,11 @@ type integerExpressionWrapper struct { } func newIntExpressionWrap(expression Expression) IntegerExpression { - intExpressionWrap := integerExpressionWrapper{Expression: expression} + intExpressionWrap := &integerExpressionWrapper{Expression: expression} + intExpressionWrap.integerInterfaceImpl.parent = intExpressionWrap + expression.setParent(intExpressionWrap) - intExpressionWrap.integerInterfaceImpl.parent = &intExpressionWrap - - return &intExpressionWrap + return intExpressionWrap } // IntExp is int expression wrapper around arbitrary expression. diff --git a/internal/jet/interval.go b/internal/jet/interval.go deleted file mode 100644 index debcb57..0000000 --- a/internal/jet/interval.go +++ /dev/null @@ -1,37 +0,0 @@ -package jet - -// Interval is internal common representation of sql interval -type Interval interface { - Serializer - IsInterval -} - -// IsInterval interface -type IsInterval interface { - isInterval() -} - -// IsIntervalImpl is implementation of IsInterval interface -type IsIntervalImpl struct{} - -func (i *IsIntervalImpl) isInterval() {} - -// NewInterval creates new interval from serializer -func NewInterval(s Serializer) *IntervalImpl { - newInterval := &IntervalImpl{ - Value: s, - } - - return newInterval -} - -// IntervalImpl is implementation of Interval type -type IntervalImpl struct { - Value Serializer - IsIntervalImpl -} - -func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString("INTERVAL") - i.Value.serialize(statement, out, FallTrough(options)...) -} diff --git a/internal/jet/interval_expression.go b/internal/jet/interval_expression.go new file mode 100644 index 0000000..5d49462 --- /dev/null +++ b/internal/jet/interval_expression.go @@ -0,0 +1,112 @@ +package jet + +// IntervalExpression interface +type IntervalExpression interface { + Expression + isInterval() + + EQ(rhs IntervalExpression) BoolExpression + NOT_EQ(rhs IntervalExpression) BoolExpression + IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression + IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression + + LT(rhs IntervalExpression) BoolExpression + LT_EQ(rhs IntervalExpression) BoolExpression + GT(rhs IntervalExpression) BoolExpression + GT_EQ(rhs IntervalExpression) BoolExpression + BETWEEN(min, max IntervalExpression) BoolExpression + NOT_BETWEEN(min, max IntervalExpression) BoolExpression + + ADD(rhs IntervalExpression) IntervalExpression + SUB(rhs IntervalExpression) IntervalExpression + + MUL(rhs NumericExpression) IntervalExpression + DIV(rhs NumericExpression) IntervalExpression +} + +type intervalInterfaceImpl struct { + parent IntervalExpression +} + +func (i *intervalInterfaceImpl) isInterval() {} + +func (i *intervalInterfaceImpl) EQ(rhs IntervalExpression) BoolExpression { + return Eq(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) NOT_EQ(rhs IntervalExpression) BoolExpression { + return NotEq(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression { + return IsDistinctFrom(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression { + return IsNotDistinctFrom(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) LT(rhs IntervalExpression) BoolExpression { + return Lt(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) LT_EQ(rhs IntervalExpression) BoolExpression { + return LtEq(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) GT(rhs IntervalExpression) BoolExpression { + return Gt(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) GT_EQ(rhs IntervalExpression) BoolExpression { + return GtEq(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) BETWEEN(min, max IntervalExpression) BoolExpression { + return NewBetweenOperatorExpression(i.parent, min, max, false) +} + +func (i *intervalInterfaceImpl) NOT_BETWEEN(min, max IntervalExpression) BoolExpression { + return NewBetweenOperatorExpression(i.parent, min, max, true) +} + +func (i *intervalInterfaceImpl) ADD(rhs IntervalExpression) IntervalExpression { + return IntervalExp(Add(i.parent, rhs)) +} + +func (i *intervalInterfaceImpl) SUB(rhs IntervalExpression) IntervalExpression { + return IntervalExp(Sub(i.parent, rhs)) +} + +func (i *intervalInterfaceImpl) MUL(rhs NumericExpression) IntervalExpression { + return IntervalExp(Mul(i.parent, rhs)) +} + +func (i *intervalInterfaceImpl) DIV(rhs NumericExpression) IntervalExpression { + return IntervalExp(Div(i.parent, rhs)) +} + +type intervalWrapper struct { + intervalInterfaceImpl + Expression +} + +func newIntervalExpressionWrap(expression Expression) IntervalExpression { + intervalWrap := &intervalWrapper{Expression: expression} + intervalWrap.intervalInterfaceImpl.parent = intervalWrap + expression.setParent(intervalWrap) + return intervalWrap +} + +// IntervalExp is interval expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as interval expression. +// Does not add sql cast to generated sql builder output. +func IntervalExp(expression Expression) IntervalExpression { + return newIntervalExpressionWrap(expression) +} + +// Interval interface +type Interval interface { + Serializer + isInterval() +} diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 251d3ab..a2eec63 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -412,17 +412,6 @@ func Raw(raw string, namedArgs ...map[string]interface{}) Expression { return rawExp } -// RawWithParent is a Raw constructor used for construction dialect specific expression -func RawWithParent(raw string, parent ...Expression) Expression { - rawExp := &rawExpression{ - Raw: raw, - noWrap: true, - } - rawExp.ExpressionInterfaceImpl.Parent = OptionalOrDefaultExpression(rawExp, parent...) - - return rawExp -} - // RawBool helper that for raw string boolean expressions func RawBool(raw string, namedArgs ...map[string]interface{}) BoolExpression { return BoolExp(Raw(raw, namedArgs...)) @@ -468,6 +457,11 @@ func RawDate(raw string, namedArgs ...map[string]interface{}) DateExpression { return DateExp(Raw(raw, namedArgs...)) } +// RawBlob is raw query helper that for blob expressions +func RawBlob(raw string, namedArgs ...map[string]interface{}) BlobExpression { + return BlobExp(Raw(raw, namedArgs...)) +} + // RawRange helper that for range expressions func RawRange[T Expression](raw string, namedArgs ...map[string]interface{}) Range[T] { return RangeExp[T](Raw(raw, namedArgs...)) diff --git a/internal/jet/projection.go b/internal/jet/projection.go index 3b2ccd8..03a94f1 100644 --- a/internal/jet/projection.go +++ b/internal/jet/projection.go @@ -3,6 +3,8 @@ package jet // Projection is interface for all projection types. Types that can be part of, for instance SELECT clause. type Projection interface { serializeForProjection(statement StatementType, out *SQLBuilder) + serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) + serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) fromImpl(subQuery SelectTable) Projection } @@ -28,6 +30,10 @@ func (pl ProjectionList) serializeForProjection(statement StatementType, out *SQ SerializeProjectionList(statement, pl, out) } +func (pl ProjectionList) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) { + SerializeProjectionListJsonObj(statement, pl, out) +} + // As will create new projection list where each column is wrapped with a new table alias. // tableAlias should be in the form 'name' or 'name.*', or it can be an empty string, which will remove existing table alias. // For instance: If projection list has a column 'Artist.Name', and tableAlias is 'Musician.*', returned projection list will @@ -79,3 +85,18 @@ func (pl ProjectionList) Except(toExclude ...Column) ProjectionList { return ret } + +func (pl ProjectionList) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) { + out.WriteRowToJsonProjections(statement, pl) +} + +// JsonObjProjectionList redefines []Projection so projections can be serialized as json object key/values +type JsonObjProjectionList []Projection + +func (j JsonObjProjectionList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.IncreaseIdent() + out.NewLine() + SerializeProjectionListJsonObj(statement, j, out) + out.DecreaseIdent() + out.NewLine() +} diff --git a/internal/jet/range_expression.go b/internal/jet/range_expression.go index 9e15c2b..77ed290 100644 --- a/internal/jet/range_expression.go +++ b/internal/jet/range_expression.go @@ -118,9 +118,10 @@ type rangeExpressionWrapper[T Expression] struct { } func newRangeExpressionWrap[T Expression](expression Expression) Range[T] { - rangeExpressionWrap := rangeExpressionWrapper[T]{Expression: expression} - rangeExpressionWrap.rangeInterfaceImpl.parent = &rangeExpressionWrap - return &rangeExpressionWrap + rangeExpressionWrap := &rangeExpressionWrapper[T]{Expression: expression} + rangeExpressionWrap.rangeInterfaceImpl.parent = rangeExpressionWrap + expression.setParent(rangeExpressionWrap) + return rangeExpressionWrap } // RangeExp is range expression wrapper around arbitrary expression. diff --git a/internal/jet/raw_statement.go b/internal/jet/raw_statement.go index 99fb8eb..427b3c8 100644 --- a/internal/jet/raw_statement.go +++ b/internal/jet/raw_statement.go @@ -1,7 +1,7 @@ package jet type rawStatementImpl struct { - serializerStatementInterfaceImpl + statementInterfaceImpl RawQuery string NamedArguments map[string]interface{} @@ -10,7 +10,7 @@ type rawStatementImpl struct { // RawStatement creates new sql statements from raw query and optional map of named arguments func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) SerializerStatement { newRawStatement := rawStatementImpl{ - serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ + statementInterfaceImpl: statementInterfaceImpl{ dialect: dialect, statementType: "", parent: nil, diff --git a/internal/jet/row_expression.go b/internal/jet/row_expression.go index e7d5ed5..a07d140 100644 --- a/internal/jet/row_expression.go +++ b/internal/jet/row_expression.go @@ -17,9 +17,9 @@ type RowExpression interface { } type rowInterfaceImpl struct { - parent Expression - dialect Dialect - elemCount int + parent Expression + dialect Dialect + expressions []Expression } func (n *rowInterfaceImpl) EQ(rhs RowExpression) BoolExpression { @@ -57,9 +57,8 @@ func (n *rowInterfaceImpl) LT_EQ(rhs RowExpression) BoolExpression { func (n *rowInterfaceImpl) projections() ProjectionList { var ret ProjectionList - for i := 0; i < n.elemCount; i++ { - rowColumn := NewColumnImpl(n.dialect.ValuesDefaultColumnName(i), "", nil) - ret = append(ret, &rowColumn) + for i, expression := range n.expressions { + ret = append(ret, newDummyColumnForExpression(expression, n.dialect.ValuesDefaultColumnName(i))) } return ret @@ -77,7 +76,7 @@ func newRowExpression(name string, dialect Dialect, expressions ...Expression) R ret.Expression = NewFunc(name, expressions, ret) ret.dialect = dialect - ret.elemCount = len(expressions) + ret.expressions = expressions return ret } diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index 9d36de4..d876f36 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -24,14 +24,16 @@ type StatementType string // Statement types const ( - SelectStatementType StatementType = "SELECT" - InsertStatementType StatementType = "INSERT" - UpdateStatementType StatementType = "UPDATE" - DeleteStatementType StatementType = "DELETE" - SetStatementType StatementType = "SET" - LockStatementType StatementType = "LOCK" - UnLockStatementType StatementType = "UNLOCK" - WithStatementType StatementType = "WITH" + SelectStatementType StatementType = "SELECT" + SelectJsonObjStatementType StatementType = "SELECT_JSON_OBJ" + SelectJsonArrStatementType StatementType = "SELECT_JSON_ARR" + InsertStatementType StatementType = "INSERT" + UpdateStatementType StatementType = "UPDATE" + DeleteStatementType StatementType = "DELETE" + SetStatementType StatementType = "SET" + LockStatementType StatementType = "LOCK" + UnLockStatementType StatementType = "UNLOCK" + WithStatementType StatementType = "WITH" ) // Serializer interface diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 46f47ad..7077af4 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -61,6 +61,17 @@ func (s *SQLBuilder) WriteProjections(statement StatementType, projections []Pro s.DecreaseIdent() } +// WriteRowToJsonProjections serializes slice of projections intended for row_to_json json aggregation +func (s *SQLBuilder) WriteRowToJsonProjections(statement StatementType, projections []Projection) { + for i, projection := range projections { + if i > 0 { + s.WriteString(",") + s.NewLine() + } + projection.serializeForRowToJsonProjection(statement, s) + } +} + // NewLine adds new line to output SQL func (s *SQLBuilder) NewLine() { s.write([]byte{'\n'}) @@ -99,6 +110,11 @@ func (s *SQLBuilder) WriteString(str string) { s.write([]byte(str)) } +// WriteJsonObjKey serializes json object key +func (s *SQLBuilder) WriteJsonObjKey(key string) { + s.WriteString(fmt.Sprintf(`'%s', `, key)) +} + // WriteIdentifier adds identifier to output SQL func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) { if s.shouldQuote(name, alwaysQuote...) { @@ -123,7 +139,7 @@ func (s *SQLBuilder) finalize() (string, []interface{}) { } func (s *SQLBuilder) insertConstantArgument(arg interface{}) { - s.WriteString(argToString(arg)) + s.WriteString(s.argToString(arg)) } func (s *SQLBuilder) insertParametrizedArgument(arg interface{}) { @@ -196,7 +212,7 @@ func (s *SQLBuilder) insertRawQuery(raw string, namedArg map[string]interface{}) } if s.Debug { - placeholder = argToString(namedArgumentPos.Value) + placeholder = s.argToString(namedArgumentPos.Value) } raw = strings.Replace(raw, namedArgumentPos.Name, placeholder, toReplace) @@ -205,11 +221,17 @@ func (s *SQLBuilder) insertRawQuery(raw string, namedArg map[string]interface{}) s.WriteString(raw) } -func argToString(value interface{}) string { +func (s *SQLBuilder) argToString(value interface{}) string { if is.Nil(value) { return "NULL" } + strVal, ok := s.Dialect.ArgumentToString(value) + + if ok { + return strVal + } + switch bindVal := value.(type) { case bool: if bindVal { @@ -246,7 +268,7 @@ func argToString(value interface{}) string { return err.Error() } - return argToString(val) + return s.argToString(val) } panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String())) diff --git a/internal/jet/sql_builder_test.go b/internal/jet/sql_builder_test.go index 3356e6e..ebcf83f 100644 --- a/internal/jet/sql_builder_test.go +++ b/internal/jet/sql_builder_test.go @@ -8,37 +8,39 @@ import ( ) func TestArgToString(t *testing.T) { - require.Equal(t, argToString(true), "TRUE") - require.Equal(t, argToString(false), "FALSE") + s := &SQLBuilder{Dialect: defaultDialect, Debug: true} - require.Equal(t, argToString(int(-32)), "-32") - require.Equal(t, argToString(uint(32)), "32") - require.Equal(t, argToString(int8(-43)), "-43") - require.Equal(t, argToString(uint8(43)), "43") - require.Equal(t, argToString(int16(-54)), "-54") - require.Equal(t, argToString(uint16(54)), "54") - require.Equal(t, argToString(int32(-65)), "-65") - require.Equal(t, argToString(uint32(65)), "65") - require.Equal(t, argToString(int64(-64)), "-64") - require.Equal(t, argToString(uint64(64)), "64") - require.Equal(t, argToString(float32(2.0)), "2") - require.Equal(t, argToString(float64(1.11)), "1.11") + require.Equal(t, s.argToString(true), "TRUE") + require.Equal(t, s.argToString(false), "FALSE") - require.Equal(t, argToString("john"), "'john'") - require.Equal(t, argToString("It's text"), "'It''s text'") - require.Equal(t, argToString([]byte("john")), "'john'") - require.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'") + require.Equal(t, s.argToString(int(-32)), "-32") + require.Equal(t, s.argToString(uint(32)), "32") + require.Equal(t, s.argToString(int8(-43)), "-43") + require.Equal(t, s.argToString(uint8(43)), "43") + require.Equal(t, s.argToString(int16(-54)), "-54") + require.Equal(t, s.argToString(uint16(54)), "54") + require.Equal(t, s.argToString(int32(-65)), "-65") + require.Equal(t, s.argToString(uint32(65)), "65") + require.Equal(t, s.argToString(int64(-64)), "-64") + require.Equal(t, s.argToString(uint64(64)), "64") + require.Equal(t, s.argToString(float32(2.0)), "2") + require.Equal(t, s.argToString(float64(1.11)), "1.11") + + require.Equal(t, s.argToString("john"), "'john'") + require.Equal(t, s.argToString("It's text"), "'It''s text'") + require.Equal(t, s.argToString([]byte("john")), "'john'") + require.Equal(t, s.argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'") time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006") require.NoError(t, err) - require.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'") + require.Equal(t, s.argToString(time), "'2006-01-02 15:04:05-07:00'") func() { defer func() { require.Equal(t, recover().(string), "jet: map[string]bool type can not be used as SQL query parameter") }() - argToString(map[string]bool{}) + s.argToString(map[string]bool{}) }() } diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 2ca229d..bfa9e8e 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -7,25 +7,38 @@ import ( "time" ) -// Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK) +// Statement is a common interface for all SQL statements, including SELECT, SELECT_JSON_ARR, SELECT_JSON_OBJ, INSERT, +// UPDATE, DELETE, and LOCK. type Statement interface { - // Sql returns parametrized sql query with list of arguments. + // Sql returns a parameterized SQL query along with its list of arguments. Sql() (query string, args []interface{}) - // DebugSql returns debug query where every parametrized placeholder is replaced with its argument string representation. - // Do not use it in production. Use it only for debug purposes. + + // DebugSql returns a debug-friendly SQL query where all parameterized placeholders + // are replaced with their respective argument string representations. + // + // Warning: This method should only be used for debugging purposes. + // Do not use it in production, as it may lead to security risks such as SQL injection. DebugSql() (query string) - // Query executes statement over database connection/transaction db and stores row results in destination. - // Destination can be either pointer to struct or pointer to a slice. - // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. + + // Query delegates call to QueryContext using context.Background() as parameter. Query(db qrm.Queryable, destination interface{}) error - // QueryContext executes statement with a context over database connection/transaction db and stores row result in destination. - // Destination can be either pointer to struct or pointer to a slice. - // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. + + // QueryContext executes the statement with the provided context over a database connection or transaction (`db`), + // and stores the retrieved row results in the given destination. + // + // For statements of type SELECT, INSERT, UPDATE, or DELETE, the destination must be a pointer to either a struct or a slice. + // For SELECT_JSON_ARR statements, the destination must be a pointer to a slice of structs or a pointer to []map[string]any. + // For SELECT_JSON_OBJ statements, the destination must be a pointer to a struct or a pointer to map[string]any. + // + // If the destination is a pointer to a struct and the query returns no rows, QueryContext returns qrm.ErrNoRows. QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error - // Exec executes statement over db connection/transaction without returning any rows. + + // Exec delegates call to ExecContext using context.Background() as parameter. Exec(db qrm.Executable) (sql.Result, error) + // ExecContext executes statement with context over db connection/transaction without returning any rows. ExecContext(ctx context.Context, db qrm.Executable) (sql.Result, error) + // Rows executes statements over db connection/transaction and returns rows Rows(ctx context.Context, db qrm.Queryable) (*Rows, error) } @@ -60,14 +73,14 @@ type SerializerHasProjections interface { HasProjections } -// serializerStatementInterfaceImpl struct -type serializerStatementInterfaceImpl struct { +// statementInterfaceImpl struct +type statementInterfaceImpl struct { dialect Dialect statementType StatementType parent SerializerStatement } -func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface{}) { +func (s *statementInterfaceImpl) Sql() (query string, args []interface{}) { queryData := &SQLBuilder{Dialect: s.dialect} @@ -77,7 +90,7 @@ func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface return } -func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { +func (s *statementInterfaceImpl) DebugSql() (query string) { sqlBuilder := &SQLBuilder{Dialect: s.dialect, Debug: true} s.parent.serialize(s.statementType, sqlBuilder, NoWrap) @@ -86,11 +99,27 @@ func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { return } -func (s *serializerStatementInterfaceImpl) Query(db qrm.Queryable, destination interface{}) error { +func (s *statementInterfaceImpl) Query(db qrm.Queryable, destination interface{}) error { return s.QueryContext(context.Background(), db, destination) } -func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error { +func (s *statementInterfaceImpl) QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error { + return s.query(ctx, func(query string, args []interface{}) (int64, error) { + switch s.statementType { + case SelectJsonObjStatementType: + return qrm.QueryJsonObj(ctx, db, query, args, destination) + case SelectJsonArrStatementType: + return qrm.QueryJsonArr(ctx, db, query, args, destination) + default: + return qrm.Query(ctx, db, query, args, destination) + } + }) +} + +func (s *statementInterfaceImpl) query( + ctx context.Context, + queryFunc func(query string, args []interface{}) (int64, error), +) error { query, args := s.Sql() callLogger(ctx, s) @@ -99,7 +128,7 @@ func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db var err error duration := duration(func() { - rowsProcessed, err = qrm.Query(ctx, db, query, args, destination) + rowsProcessed, err = queryFunc(query, args) }) callQueryLoggerFunc(ctx, QueryInfo{ @@ -112,11 +141,11 @@ func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db return err } -func (s *serializerStatementInterfaceImpl) Exec(db qrm.Executable) (res sql.Result, err error) { +func (s *statementInterfaceImpl) Exec(db qrm.Executable) (res sql.Result, err error) { return s.ExecContext(context.Background(), db) } -func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.Executable) (res sql.Result, err error) { +func (s *statementInterfaceImpl) ExecContext(ctx context.Context, db qrm.Executable) (res sql.Result, err error) { query, args := s.Sql() callLogger(ctx, s) @@ -141,7 +170,7 @@ func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db q return res, err } -func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.Queryable) (*Rows, error) { +func (s *statementInterfaceImpl) Rows(ctx context.Context, db qrm.Queryable) (*Rows, error) { query, args := s.Sql() callLogger(ctx, s) @@ -191,11 +220,15 @@ type ExpressionStatement interface { } // NewExpressionStatementImpl creates new expression statement -func NewExpressionStatementImpl(Dialect Dialect, statementType StatementType, parent ExpressionStatement, clauses ...Clause) ExpressionStatement { +func NewExpressionStatementImpl(Dialect Dialect, + statementType StatementType, + parent ExpressionStatement, + clauses ...Clause) ExpressionStatement { + return &expressionStatementImpl{ ExpressionInterfaceImpl{Parent: parent}, statementImpl{ - serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ + statementInterfaceImpl: statementInterfaceImpl{ parent: parent, dialect: Dialect, statementType: statementType, @@ -214,10 +247,14 @@ func (s *expressionStatementImpl) serializeForProjection(statement StatementType s.serialize(statement, out) } +func (e *expressionStatementImpl) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) { + panic("jet: SELECT JSON statements need to be aliased when used as a projection.") +} + // NewStatementImpl creates new statementImpl func NewStatementImpl(Dialect Dialect, statementType StatementType, parent SerializerStatement, clauses ...Clause) SerializerStatement { return &statementImpl{ - serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ + statementInterfaceImpl: statementInterfaceImpl{ parent: parent, dialect: Dialect, statementType: statementType, @@ -227,7 +264,7 @@ func NewStatementImpl(Dialect Dialect, statementType StatementType, parent Seria } type statementImpl struct { - serializerStatementInterfaceImpl + statementInterfaceImpl Clauses []Clause } diff --git a/internal/jet/string_expression.go b/internal/jet/string_expression.go index 29b2447..c7f1158 100644 --- a/internal/jet/string_expression.go +++ b/internal/jet/string_expression.go @@ -3,6 +3,7 @@ package jet // StringExpression interface type StringExpression interface { Expression + isStringOrBlob() EQ(rhs StringExpression) BoolExpression NOT_EQ(rhs StringExpression) BoolExpression @@ -29,6 +30,8 @@ type stringInterfaceImpl struct { parent StringExpression } +func (s *stringInterfaceImpl) isStringOrBlob() {} + func (s *stringInterfaceImpl) EQ(rhs StringExpression) BoolExpression { return Eq(s.parent, rhs) } @@ -102,9 +105,10 @@ type stringExpressionWrapper struct { } func newStringExpressionWrap(expression Expression) StringExpression { - stringExpressionWrap := stringExpressionWrapper{Expression: expression} - stringExpressionWrap.stringInterfaceImpl.parent = &stringExpressionWrap - return &stringExpressionWrap + stringExpressionWrap := &stringExpressionWrapper{Expression: expression} + stringExpressionWrap.stringInterfaceImpl.parent = stringExpressionWrap + expression.setParent(stringExpressionWrap) + return stringExpressionWrap } // StringExp is string expression wrapper around arbitrary expression. diff --git a/internal/jet/string_or_blob_expression.go b/internal/jet/string_or_blob_expression.go new file mode 100644 index 0000000..f05a149 --- /dev/null +++ b/internal/jet/string_or_blob_expression.go @@ -0,0 +1,8 @@ +package jet + +// StringOrBlobExpression is common interface for all string and blob expressions +type StringOrBlobExpression interface { + Expression + + isStringOrBlob() +} diff --git a/internal/jet/testutils.go b/internal/jet/testutils.go index 70b21c7..866048a 100644 --- a/internal/jet/testutils.go +++ b/internal/jet/testutils.go @@ -12,6 +12,9 @@ var defaultDialect = NewDialect(DialectParams{ // just for tests ArgumentPlaceholder: func(ord int) string { return "$" + strconv.Itoa(ord) }, + ArgumentToString: func(value any) (string, bool) { + return "", false + }, }) var ( diff --git a/internal/jet/time_expression.go b/internal/jet/time_expression.go index efd146f..215d3bc 100644 --- a/internal/jet/time_expression.go +++ b/internal/jet/time_expression.go @@ -75,14 +75,15 @@ func (t *timeInterfaceImpl) SUB(rhs Interval) TimeExpression { //---------------------------------------------------// type timeExpressionWrapper struct { - timeInterfaceImpl Expression + timeInterfaceImpl } func newTimeExpressionWrap(expression Expression) TimeExpression { - timeExpressionWrap := timeExpressionWrapper{Expression: expression} - timeExpressionWrap.timeInterfaceImpl.parent = &timeExpressionWrap - return &timeExpressionWrap + timeExpressionWrap := &timeExpressionWrapper{Expression: expression} + timeExpressionWrap.timeInterfaceImpl.parent = timeExpressionWrap + expression.setParent(timeExpressionWrap) + return timeExpressionWrap } // TimeExp is time expression wrapper around arbitrary expression. diff --git a/internal/jet/time_expression_test.go b/internal/jet/time_expression_test.go index 61ee29f..2b3d015 100644 --- a/internal/jet/time_expression_test.go +++ b/internal/jet/time_expression_test.go @@ -52,11 +52,3 @@ func TestTimeExp(t *testing.T) { assertClauseSerialize(t, TimeExp(table1ColFloat).LT(Time(1, 1, 1, 1*time.Millisecond)), "(table1.col_float < $1)", string("01:01:01.001")) } - -func TestTimeArithmetic(t *testing.T) { - time := Time(10, 20, 3) - assertClauseDebugSerialize(t, table1ColTime.ADD(NewInterval(String("1 HOUR"))).EQ(time), - "((table1.col_time + INTERVAL '1 HOUR') = '10:20:03')") - assertClauseDebugSerialize(t, table1ColTime.SUB(NewInterval(String("1 HOUR"))).EQ(time), - "((table1.col_time - INTERVAL '1 HOUR') = '10:20:03')") -} diff --git a/internal/jet/timestamp_expression.go b/internal/jet/timestamp_expression.go index 1013ce1..222b1e4 100644 --- a/internal/jet/timestamp_expression.go +++ b/internal/jet/timestamp_expression.go @@ -80,9 +80,10 @@ type timestampExpressionWrapper struct { } func newTimestampExpressionWrap(expression Expression) TimestampExpression { - timestampExpressionWrap := timestampExpressionWrapper{Expression: expression} - timestampExpressionWrap.timestampInterfaceImpl.parent = ×tampExpressionWrap - return ×tampExpressionWrap + timestampExpressionWrap := ×tampExpressionWrapper{Expression: expression} + timestampExpressionWrap.timestampInterfaceImpl.parent = timestampExpressionWrap + expression.setParent(timestampExpressionWrap) + return timestampExpressionWrap } // TimestampExp is timestamp expression wrapper around arbitrary expression. diff --git a/internal/jet/timestamp_expression_test.go b/internal/jet/timestamp_expression_test.go index e34d8dd..9a9ceb4 100644 --- a/internal/jet/timestamp_expression_test.go +++ b/internal/jet/timestamp_expression_test.go @@ -53,11 +53,3 @@ func TestTimestampExp(t *testing.T) { assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp), "(table1.col_float < $1)", "2000-01-31 10:20:00.003") } - -func TestTimestampArithmetic(t *testing.T) { - timestamp := Timestamp(2000, 1, 1, 0, 0, 0) - assertClauseDebugSerialize(t, table1ColTimestamp.ADD(NewInterval(String("1 HOUR"))).EQ(timestamp), - "((table1.col_timestamp + INTERVAL '1 HOUR') = '2000-01-01 00:00:00')") - assertClauseDebugSerialize(t, table1ColTimestamp.SUB(NewInterval(String("1 HOUR"))).EQ(timestamp), - "((table1.col_timestamp - INTERVAL '1 HOUR') = '2000-01-01 00:00:00')") -} diff --git a/internal/jet/timestampz_expression.go b/internal/jet/timestampz_expression.go index b8fe103..6500251 100644 --- a/internal/jet/timestampz_expression.go +++ b/internal/jet/timestampz_expression.go @@ -80,9 +80,10 @@ type timestampzExpressionWrapper struct { } func newTimestampzExpressionWrap(expression Expression) TimestampzExpression { - timestampzExpressionWrap := timestampzExpressionWrapper{Expression: expression} - timestampzExpressionWrap.timestampzInterfaceImpl.parent = ×tampzExpressionWrap - return ×tampzExpressionWrap + timestampzExpressionWrap := ×tampzExpressionWrapper{Expression: expression} + timestampzExpressionWrap.timestampzInterfaceImpl.parent = timestampzExpressionWrap + expression.setParent(timestampzExpressionWrap) + return timestampzExpressionWrap } // TimestampzExp is timestamp with time zone expression wrapper around arbitrary expression. diff --git a/internal/jet/timestampz_expression_test.go b/internal/jet/timestampz_expression_test.go index 1ff1eac..6880c93 100644 --- a/internal/jet/timestampz_expression_test.go +++ b/internal/jet/timestampz_expression_test.go @@ -53,11 +53,3 @@ func TestTimestampzExp(t *testing.T) { assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz), "(table1.col_float < $1)", "2000-01-31 10:20:05.000023 +200") } - -func TestTimestampzArithmetic(t *testing.T) { - timestampz := Timestampz(2000, 1, 1, 0, 0, 0, 100, "UTC") - assertClauseDebugSerialize(t, table1ColTimestampz.ADD(NewInterval(String("1 HOUR"))).EQ(timestampz), - "((table1.col_timestampz + INTERVAL '1 HOUR') = '2000-01-01 00:00:00.0000001 UTC')") - assertClauseDebugSerialize(t, table1ColTimestampz.SUB(NewInterval(String("1 HOUR"))).EQ(timestampz), - "((table1.col_timestampz - INTERVAL '1 HOUR') = '2000-01-01 00:00:00.0000001 UTC')") -} diff --git a/internal/jet/timez_expression.go b/internal/jet/timez_expression.go index 8896dce..590ec96 100644 --- a/internal/jet/timez_expression.go +++ b/internal/jet/timez_expression.go @@ -75,14 +75,15 @@ func (t *timezInterfaceImpl) SUB(rhs Interval) TimezExpression { //---------------------------------------------------// type timezExpressionWrapper struct { - timezInterfaceImpl Expression + timezInterfaceImpl } func newTimezExpressionWrap(expression Expression) TimezExpression { - timezExpressionWrap := timezExpressionWrapper{Expression: expression} - timezExpressionWrap.timezInterfaceImpl.parent = &timezExpressionWrap - return &timezExpressionWrap + timezExpressionWrap := &timezExpressionWrapper{Expression: expression} + timezExpressionWrap.timezInterfaceImpl.parent = timezExpressionWrap + expression.setParent(timezExpressionWrap) + return timezExpressionWrap } // TimezExp is time with time zone expression wrapper around arbitrary expression. diff --git a/internal/jet/timez_expression_test.go b/internal/jet/timez_expression_test.go index 9f21c08..104e613 100644 --- a/internal/jet/timez_expression_test.go +++ b/internal/jet/timez_expression_test.go @@ -51,11 +51,3 @@ func TestTimezExp(t *testing.T) { assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, "+4:00")), "(table1.col_float < $1)", string("01:01:01.000000001 +4:00")) } - -func TestTimezArithmetic(t *testing.T) { - timez := Timez(0, 0, 0, 100, "UTC") - assertClauseDebugSerialize(t, table1ColTimez.ADD(NewInterval(String("1 HOUR"))).EQ(timez), - "((table1.col_timez + INTERVAL '1 HOUR') = '00:00:00.0000001 UTC')") - assertClauseDebugSerialize(t, table1ColTimez.SUB(NewInterval(String("1 HOUR"))).EQ(timez), - "((table1.col_timez - INTERVAL '1 HOUR') = '00:00:00.0000001 UTC')") -} diff --git a/internal/jet/utils.go b/internal/jet/utils.go index 466f2a5..d793f24 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -58,6 +58,23 @@ func SerializeProjectionList(statement StatementType, projections []Projection, } } +// SerializeProjectionListJsonObj serializes a list of projections for JSON object +func SerializeProjectionListJsonObj(statement StatementType, projections []Projection, out *SQLBuilder) { + + for i, p := range projections { + if i > 0 { + out.WriteString(",") + out.NewLine() + } + + if p == nil { + panic("jet: Projection is nil") + } + + p.serializeForJsonObjEntry(statement, out) + } +} + // SerializeColumnNames func func SerializeColumnNames(columns []Column, out *SQLBuilder) { for i, col := range columns { @@ -115,8 +132,8 @@ func ExpressionListToSerializerList(expressions []Expression) []Serializer { return ret } -// BoolExpressionListToExpressionList converts list of bool expressions to list of expressions -func BoolExpressionListToExpressionList(expressions []BoolExpression) []Expression { +// ToExpressionList converts list of any expressions to list of expressions +func ToExpressionList[T Expression](expressions []T) []Expression { var ret []Expression for _, expression := range expressions { diff --git a/internal/jet/with_statement.go b/internal/jet/with_statement.go index 03330e6..d507072 100644 --- a/internal/jet/with_statement.go +++ b/internal/jet/with_statement.go @@ -7,7 +7,7 @@ func WITH(dialect Dialect, recursive bool, cte ...*CommonTableExpression) func(s newWithImpl := &withImpl{ recursive: recursive, ctes: cte, - serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ + statementInterfaceImpl: statementInterfaceImpl{ dialect: dialect, statementType: WithStatementType, }, @@ -25,7 +25,7 @@ func WITH(dialect Dialect, recursive bool, cte ...*CommonTableExpression) func(s } type withImpl struct { - serializerStatementInterfaceImpl + statementInterfaceImpl recursive bool ctes []*CommonTableExpression primaryStatement SerializerStatement diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 817de32..eb1f8bf 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -115,6 +115,16 @@ func AssertJSON(t *testing.T, data interface{}, expectedJSON string) { require.Equal(t, dataJson, expectedJSON) } +// AssertJsonEqual checks if actual and expected json representation are the same +func AssertJsonEqual(t require.TestingT, actual, expected interface{}, option ...cmp.Option) { + actualJsonData, err := json.MarshalIndent(actual, "", "\t") + require.NoError(t, err) + expectedJsonData, err := json.MarshalIndent(expected, "", "\t") + require.NoError(t, err) + + require.Equal(t, string(actualJsonData), string(expectedJsonData)) +} + // SaveJSONFile saves v as json at testRelativePath // nolint:unused func SaveJSONFile(v interface{}, testRelativePath string) { @@ -127,7 +137,10 @@ func SaveJSONFile(v interface{}, testRelativePath string) { } // AssertJSONFile check if data json representation is the same as json at testRelativePath -func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) { +func AssertJSONFile(t require.TestingT, data interface{}, testRelativePath string) { + if _, ok := t.(*testing.B); ok { + return // skip assert for benchmarks + } filePath := getFullPath(testRelativePath) fileJSONData, err := os.ReadFile(filePath) // #nosec G304 @@ -145,7 +158,11 @@ func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) { } // AssertStatementSql check if statement Sql() is the same as expectedQuery and expectedArgs -func AssertStatementSql(t *testing.T, query jet.PrintableStatement, expectedQuery string, expectedArgs ...interface{}) { +func AssertStatementSql(t require.TestingT, query jet.PrintableStatement, expectedQuery string, expectedArgs ...interface{}) { + if _, ok := t.(*testing.B); ok { + return // skip assert for benchmarks + } + queryStr, args := query.Sql() assertQueryString(t, queryStr, expectedQuery) @@ -283,14 +300,14 @@ func AssertFileNamesEqual(t *testing.T, dirPath string, fileNames ...string) { } // AssertDeepEqual checks if actual and expected objects are deeply equal. -func AssertDeepEqual(t *testing.T, actual, expected interface{}, option ...cmp.Option) { +func AssertDeepEqual(t require.TestingT, actual, expected interface{}, option ...cmp.Option) { if !assert.True(t, cmp.Equal(actual, expected, option...)) { printDiff(actual, expected, option...) t.FailNow() } } -func assertQueryString(t *testing.T, actual, expected string) { +func assertQueryString(t require.TestingT, actual, expected string) { if !assert.Equal(t, actual, expected) { printDiff(actual, expected) t.FailNow() diff --git a/internal/utils/datetime/duration.go b/internal/utils/datetime/duration.go index 11cc57f..a702623 100644 --- a/internal/utils/datetime/duration.go +++ b/internal/utils/datetime/duration.go @@ -1,6 +1,9 @@ package datetime -import "time" +import ( + //"github.com/go-jet/jet/v2/internal/utils/min" + "time" +) // ExtractTimeComponents extracts number of days, hours, minutes, seconds, microseconds from duration func ExtractTimeComponents(duration time.Duration) (days, hours, minutes, seconds, microseconds int64) { @@ -20,3 +23,36 @@ func ExtractTimeComponents(duration time.Duration) (days, hours, minutes, second return } + +// TryParseAsTime attempts to parse the provided value as a time using one of the given formats. +// +// The function iterates over the provided formats and tries to parse the value into a time.Time object. +// It returns the parsed time and a boolean indicating whether the parsing was successful. +func TryParseAsTime(value interface{}, formats []string) (time.Time, bool) { + + var timeStr string + + switch v := value.(type) { + case string: + timeStr = v + case []byte: + timeStr = string(v) + case int64: + return time.Unix(v, 0), true // sqlite + default: + return time.Time{}, false + } + + for _, format := range formats { + formatLen := min(len(format), len(timeStr)) + t, err := time.Parse(format[:formatLen], timeStr) + + if err != nil { + continue + } + + return t, true + } + + return time.Time{}, false +} diff --git a/internal/utils/is/is.go b/internal/utils/is/is.go index 8824b06..1f5ee54 100644 --- a/internal/utils/is/is.go +++ b/internal/utils/is/is.go @@ -1,6 +1,8 @@ package is -import "reflect" +import ( + "reflect" +) // Nil check if v is nil func Nil(v interface{}) bool { diff --git a/internal/utils/min/min.go b/internal/utils/min/min.go deleted file mode 100644 index 0e92146..0000000 --- a/internal/utils/min/min.go +++ /dev/null @@ -1,9 +0,0 @@ -package min - -// Int returns minimum of two int values -func Int(a, b int) int { - if a < b { - return a - } - return b -} diff --git a/mysql/cast.go b/mysql/cast.go index dcf1f57..fbce06c 100644 --- a/mysql/cast.go +++ b/mysql/cast.go @@ -70,6 +70,6 @@ func (c *cast) AS_TIME() TimeExpression { } // AS_BINARY casts expression as BINARY type -func (c *cast) AS_BINARY() StringExpression { - return StringExp(c.AS("BINARY")) +func (c *cast) AS_BINARY() BlobExpression { + return BlobExp(c.AS("BINARY")) } diff --git a/mysql/columns.go b/mysql/columns.go index 3f08396..c0df1aa 100644 --- a/mysql/columns.go +++ b/mysql/columns.go @@ -21,6 +21,12 @@ type ColumnString = jet.ColumnString // StringColumn creates named string column. var StringColumn = jet.StringColumn +// ColumnBlob is interface for blob columns. +type ColumnBlob = jet.ColumnBlob + +// BlobColumn creates named blob column. +var BlobColumn = jet.BlobColumn + // ColumnInteger is interface for SQL smallint, integer, bigint columns. type ColumnInteger = jet.ColumnInteger diff --git a/mysql/dialect.go b/mysql/dialect.go index 9628bfb..8f95156 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -1,11 +1,12 @@ package mysql import ( + "encoding/hex" "fmt" "github.com/go-jet/jet/v2/internal/jet" ) -// Dialect is implementation of MySQL dialect for SQL Builder serialisation. +// Dialect is implementation of MySQL dialect for SQL Builder serialization. var Dialect = newDialect() func newDialect() jet.Dialect { @@ -27,16 +28,43 @@ func newDialect() jet.Dialect { ArgumentPlaceholder: func(int) string { return "?" }, + ArgumentToString: argumentToString, ReservedWords: reservedWords, SerializeOrderBy: serializeOrderBy, ValuesDefaultColumnName: func(index int) string { return fmt.Sprintf("column_%d", index) }, + JsonValueEncode: func(expr Expression) Expression { + switch e := expr.(type) { + case BlobExpression: + return TO_BASE64(e) + + // CustomExpression used bellow (instead DATE_FORMAT function) so that only expr is parametrized + case TimestampExpression: + return CustomExpression(Token("DATE_FORMAT("), e, Token(",'%Y-%m-%dT%H:%i:%s.%fZ')")) + case TimeExpression: + return CustomExpression(Token("CONCAT('0000-01-01T', DATE_FORMAT("), e, Token(",'%H:%i:%s.%fZ'))")) + case DateExpression: + return CustomExpression(Token("CONCAT(DATE_FORMAT("), e, Token(",'%Y-%m-%d')"), Token(", 'T00:00:00Z')")) + case BoolExpression: + return CustomExpression(e, Token(" = 1")) + } + return expr + }, } return jet.NewDialect(mySQLDialectParams) } +func argumentToString(value any) (string, bool) { + switch bindVal := value.(type) { + case []byte: + return fmt.Sprintf("X'%s'", hex.EncodeToString(bindVal)), true + } + + return "", false +} + func mysqlBitXor(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { diff --git a/mysql/expressions.go b/mysql/expressions.go index 4073ef5..ac2be48 100644 --- a/mysql/expressions.go +++ b/mysql/expressions.go @@ -12,6 +12,9 @@ type BoolExpression = jet.BoolExpression // StringExpression interface type StringExpression = jet.StringExpression +// BlobExpression interface +type BlobExpression = jet.BlobExpression + // IntegerExpression interface type IntegerExpression = jet.IntegerExpression @@ -43,6 +46,11 @@ var BoolExp = jet.BoolExp // Does not add sql cast to generated sql builder output. var StringExp = jet.StringExp +// BlobExp is blob expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as blob expression. +// Does not add sql cast to generated sql builder output. +var BlobExp = jet.BlobExp + // IntExp is int expression wrapper around arbitrary expression. // Allows go compiler to see any expression as int expression. // Does not add sql cast to generated sql builder output. @@ -100,6 +108,7 @@ var ( RawTime = jet.RawTime RawTimestamp = jet.RawTimestamp RawDate = jet.RawDate + RawBlob = jet.RawBlob ) // Func can be used to call custom or unsupported database functions. diff --git a/mysql/functions.go b/mysql/functions.go index ceec7ab..56aa517 100644 --- a/mysql/functions.go +++ b/mysql/functions.go @@ -148,6 +148,14 @@ var NTH_VALUE = jet.NTH_VALUE //--------------------- String functions ------------------// +// HEX function in MySQL takes an input and returns its equivalent hexadecimal representation +var HEX = jet.HEX + +// UNHEX for a string argument str, UNHEX(str) interprets each pair of characters in the argument +// as a hexadecimal number and converts it to the byte represented by the number. +// The return value is a binary string. +var UNHEX = jet.UNHEX + // BIT_LENGTH returns number of bits in string expression var BIT_LENGTH = jet.BIT_LENGTH @@ -157,6 +165,23 @@ var CHAR_LENGTH = jet.CHAR_LENGTH // OCTET_LENGTH returns number of bytes in string expression var OCTET_LENGTH = jet.OCTET_LENGTH +// ELT returns the Nth element of the list of strings: str1 if N = 1, str2 if N = 2, and so on. +// Returns NULL if N is less than 1, greater than the number of arguments, or NULL. +func ELT(n IntegerExpression, list ...StringExpression) StringExpression { + args := []Expression{n} + args = append(args, jet.ToExpressionList(list)...) + + return StringExp(Func("ELT", args...)) +} + +// FIELD returns the index (position) of str in the str1, str2, str3, ... list. Returns 0 if str is not found. +func FIELD(str StringExpression, list ...StringExpression) StringExpression { + args := []Expression{str} + args = append(args, jet.ToExpressionList(list)...) + + return StringExp(Func("FIELD", args...)) +} + // LOWER returns string expression in lower case var LOWER = jet.LOWER @@ -178,7 +203,35 @@ var CONCAT = jet.CONCAT var CONCAT_WS = jet.CONCAT_WS // FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. -var FORMAT = jet.FORMAT +func FORMAT(number jet.NumericExpression, decimals IntegerExpression, optionalLocale ...StringExpression) StringExpression { + if len(optionalLocale) > 0 { + return StringExp(Func("FORMAT", number, decimals, optionalLocale[0])) + } + + return StringExp(Func("FORMAT", number, decimals)) +} + +// TO_BASE64 converts the string argument to base-64 encoded form and returns the +// result as a character string with the connection character set and collation. +func TO_BASE64(data jet.StringOrBlobExpression) StringExpression { + return StringExp(Func("TO_BASE64", data)) +} + +// FROM_BASE64 takes a string encoded with the base-64 encoded rules used by TO_BASE64() +// and returns the decoded result as a binary string. +func FROM_BASE64(data StringExpression) BlobExpression { + return BlobExp(Func("FROM_BASE64", data)) +} + +// CHARSET returns the character set of the string argument, or NULL if the argument is NULL. +func CHARSET(exp Expression) StringExpression { + return StringExp(Func("CHARSET", exp)) +} + +// COLLATION returns the collation of the string argument. +func COLLATION(exp Expression) StringExpression { + return StringExp(Func("COLLATION ", exp)) +} // LEFT returns first n characters in the string. // When n is negative, return all but last |n| characters. @@ -189,7 +242,7 @@ var LEFT = jet.LEFT var RIGHT = jet.RIGHT // LENGTH returns number of characters in string with a given encoding -func LENGTH(str jet.StringExpression) jet.StringExpression { +func LENGTH(str jet.StringOrBlobExpression) jet.IntegerExpression { return jet.LENGTH(str) } diff --git a/mysql/interval.go b/mysql/interval_literal.go similarity index 96% rename from mysql/interval.go rename to mysql/interval_literal.go index 23be21f..4fff705 100644 --- a/mysql/interval.go +++ b/mysql/interval_literal.go @@ -98,13 +98,10 @@ func INTERVAL(value interface{}, unitType unitType) Interval { // INTERVALe creates new temporal interval from expresion and unit type. func INTERVALe(expr Expression, unitType unitType) Interval { - return jet.NewInterval(jet.ListSerializer{ - Serializers: []jet.Serializer{expr, jet.RawWithParent(string(unitType))}, - Separator: " ", - }) + return jet.IntervalExp(CustomExpression(Token("INTERVAL"), expr, Token(unitType))) } -// INTERVALd temoral interval from time.Duration +// INTERVALd creates new temporal interval from time.Duration func INTERVALd(duration time.Duration) Interval { var sign int64 = 1 if duration < 0 { diff --git a/mysql/interval_test.go b/mysql/interval_literal_test.go similarity index 100% rename from mysql/interval_test.go rename to mysql/interval_literal_test.go diff --git a/mysql/literal.go b/mysql/literal.go index ca720c8..d23fd7a 100644 --- a/mysql/literal.go +++ b/mysql/literal.go @@ -56,6 +56,11 @@ var String = jet.String // value can be any uuid type with a String method var UUID = jet.UUID +// Blob creates new blob literal expression +func Blob(data []byte) BlobExpression { + return BlobExp(jet.Literal(data)) +} + // Date creates new date literal func Date(year int, month time.Month, day int) DateExpression { return CAST(jet.Date(year, month, day)).AS_DATE() diff --git a/mysql/select_json.go b/mysql/select_json.go new file mode 100644 index 0000000..cfeb30f --- /dev/null +++ b/mysql/select_json.go @@ -0,0 +1,79 @@ +package mysql + +import ( + "github.com/go-jet/jet/v2/internal/jet" +) + +// SelectJsonStatement is an interface for MySQL statements that generate JSON on the server. +type SelectJsonStatement interface { + Statement + jet.Serializer + + AS(alias string) Projection + + FROM(table ReadableTable) SelectJsonStatement + WHERE(condition BoolExpression) SelectJsonStatement + ORDER_BY(orderByClauses ...OrderByClause) SelectJsonStatement + LIMIT(limit int64) SelectJsonStatement + OFFSET(offset int64) SelectJsonStatement +} + +// SELECT_JSON_ARR creates a new SelectJsonStatement with a list of projections. +func SELECT_JSON_ARR(projections ...Projection) SelectJsonStatement { + return newSelectStatementJson(projections, jet.SelectJsonArrStatementType) +} + +// SELECT_JSON_OBJ creates a new SelectJsonStatement with a list of projections. +func SELECT_JSON_OBJ(projections ...Projection) SelectJsonStatement { + return newSelectStatementJson(projections, jet.SelectJsonObjStatementType) +} + +type selectJsonStatement struct { + *selectStatementImpl +} + +func newSelectStatementJson(projections []Projection, statementType jet.StatementType) SelectJsonStatement { + newSelect := &selectJsonStatement{ + selectStatementImpl: newSelectStatement(statementType, nil, nil), + } + + newSelect.Select.ProjectionList = ProjectionList{constructJsonFunc(projections, statementType).AS("json")} + + return newSelect +} + +func constructJsonFunc(projections []Projection, statementType jet.StatementType) Expression { + jsonObj := Func("JSON_OBJECT", CustomExpression(jet.JsonObjProjectionList(projections))) + + if statementType == jet.SelectJsonArrStatementType { + return Func("JSON_ARRAYAGG", jsonObj) + } + + return jsonObj +} + +func (s *selectJsonStatement) FROM(table ReadableTable) SelectJsonStatement { + s.From.Tables = []jet.Serializer{table} + + return s +} + +func (s *selectJsonStatement) WHERE(condition BoolExpression) SelectJsonStatement { + s.Where.Condition = condition + return s +} + +func (s *selectJsonStatement) ORDER_BY(orderByClauses ...OrderByClause) SelectJsonStatement { + s.OrderBy.List = orderByClauses + return s +} + +func (s *selectJsonStatement) LIMIT(limit int64) SelectJsonStatement { + s.Limit.Count = limit + return s +} + +func (s *selectJsonStatement) OFFSET(offset int64) SelectJsonStatement { + s.Offset.Count = Int(offset) + return s +} diff --git a/mysql/select_statement.go b/mysql/select_statement.go index aaeff9a..915fe93 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -62,12 +62,12 @@ type SelectStatement interface { // SELECT creates new SelectStatement with list of projections func SELECT(projection Projection, projections ...Projection) SelectStatement { - return newSelectStatement(nil, append([]Projection{projection}, projections...)) + return newSelectStatement(jet.SelectStatementType, nil, append([]Projection{projection}, projections...)) } -func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { +func newSelectStatement(stmtType jet.StatementType, table ReadableTable, projections []Projection) *selectStatementImpl { newSelect := &selectStatementImpl{} - newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, + newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, stmtType, newSelect, &newSelect.Select, &newSelect.From, &newSelect.Where, diff --git a/mysql/table.go b/mysql/table.go index 0ae7ee8..ebc67e2 100644 --- a/mysql/table.go +++ b/mysql/table.go @@ -50,7 +50,7 @@ type readableTableInterfaceImpl struct { // Generates a select query on the current tableName. func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { - return newSelectStatement(r.parent, append([]Projection{projection1}, projections...)) + return newSelectStatement(jet.SelectStatementType, r.parent, append([]Projection{projection1}, projections...)) } // Creates a inner join tableName Expression using onCondition. diff --git a/postgres/cast.go b/postgres/cast.go index 935cb08..e4d09ce 100644 --- a/postgres/cast.go +++ b/postgres/cast.go @@ -101,9 +101,9 @@ func (b *cast) AS_DECIMAL() FloatExpression { return FloatExp(b.AS("decimal")) } -// AS_BYTEA casts expression AS text type -func (b *cast) AS_BYTEA() StringExpression { - return StringExp(b.AS("bytea")) +// AS_BYTEA casts expression AS bytea type +func (b *cast) AS_BYTEA() ByteaExpression { + return ByteaExp(b.AS("bytea")) } // AS_TIME casts expression AS date type diff --git a/postgres/columns.go b/postgres/columns.go index a70c234..01af0d7 100644 --- a/postgres/columns.go +++ b/postgres/columns.go @@ -23,6 +23,12 @@ type ColumnString = jet.ColumnString // StringColumn creates named string column. var StringColumn = jet.StringColumn +// ColumnBytea is interface for bytea columns +type ColumnBytea = jet.ColumnBlob + +// ByteaColumn creates new named bytea column. +var ByteaColumn = jet.BlobColumn + // ColumnInteger is interface for SQL smallint, integer, bigint columns. type ColumnInteger = jet.ColumnInteger @@ -65,6 +71,12 @@ type ColumnTimestampz = jet.ColumnTimestampz // TimestampzColumn creates named timestamp with time zone column. var TimestampzColumn = jet.TimestampzColumn +// ColumnInterval is interface of PostgreSQL interval columns. +type ColumnInterval = jet.ColumnInterval + +// IntervalColumn creates named interval column +var IntervalColumn = jet.IntervalColumn + // ColumnDateRange is interface of SQL date range column type ColumnDateRange = jet.ColumnRange[DateExpression] @@ -100,41 +112,3 @@ type ColumnInt8Range jet.ColumnRange[jet.Int8Expression] // Int8RangeColumn creates named range with range column var Int8RangeColumn = jet.RangeColumn[jet.Int8Expression] - -//------------------------------------------------------// - -// ColumnInterval is interface of PostgreSQL interval columns. -type ColumnInterval interface { - IntervalExpression - jet.Column - - From(subQuery SelectTable) ColumnInterval - SET(intervalExp IntervalExpression) ColumnAssigment -} - -//------------------------------------------------------// - -type intervalColumnImpl struct { - jet.ColumnExpressionImpl - intervalInterfaceImpl -} - -func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment { - return jet.NewColumnAssignment(i, intervalExp) -} - -func (i *intervalColumnImpl) From(subQuery SelectTable) ColumnInterval { - newIntervalColumn := IntervalColumn(i.Name()) - jet.SetTableName(newIntervalColumn, i.TableName()) - jet.SetSubQuery(newIntervalColumn, subQuery) - - return newIntervalColumn -} - -// IntervalColumn creates named interval column. -func IntervalColumn(name string) ColumnInterval { - intervalColumn := &intervalColumnImpl{} - intervalColumn.ColumnExpressionImpl = jet.NewColumnImpl(name, "", intervalColumn) - intervalColumn.intervalInterfaceImpl.parent = intervalColumn - return intervalColumn -} diff --git a/postgres/dialect.go b/postgres/dialect.go index 7885cf7..9ffa5f7 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -1,10 +1,10 @@ package postgres import ( + "encoding/hex" "fmt" - "strconv" - "github.com/go-jet/jet/v2/internal/jet" + "strconv" ) // Dialect is implementation of postgres dialect for SQL Builder serialisation. @@ -26,15 +26,42 @@ func newDialect() jet.Dialect { ArgumentPlaceholder: func(ord int) string { return "$" + strconv.Itoa(ord) }, - ReservedWords: reservedWords, + ArgumentToString: argumentToString, + ReservedWords: reservedWords, ValuesDefaultColumnName: func(index int) string { return fmt.Sprintf("column%d", index+1) }, + JsonValueEncode: func(expr Expression) Expression { + switch e := expr.(type) { + case ByteaExpression: + return ENCODE(e, Base64) + + // CustomExpression used bellow (instead TO_CHAR function) so that only expr is parametrized + case TimeExpression: + return CustomExpression(Token("'0000-01-01T' || to_char('2000-10-10'::date + "), e, Token(`, 'HH24:MI:SS.USZ')`)) + case TimezExpression: + return CustomExpression(Token("'0000-01-01T' || to_char('2000-10-10'::date + "), e, Token(`, 'HH24:MI:SS.USTZH:TZM')`)) + case TimestampExpression: + return CustomExpression(Token("to_char("), e, Token(`, 'YYYY-MM-DD"T"HH24:MI:SS.USZ')`)) + case DateExpression: + return CustomExpression(Token("to_char("), e, Token(`::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z'`)) + } + return expr + }, } return jet.NewDialect(dialectParams) } +func argumentToString(value any) (string, bool) { + switch bindVal := value.(type) { + case []byte: + return fmt.Sprintf("'\\x%s'", hex.EncodeToString(bindVal)), true + } + + return "", false +} + func postgresCAST(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { diff --git a/postgres/expressions.go b/postgres/expressions.go index d8ad34b..f4fbb13 100644 --- a/postgres/expressions.go +++ b/postgres/expressions.go @@ -12,6 +12,8 @@ type BoolExpression = jet.BoolExpression // StringExpression interface type StringExpression = jet.StringExpression +type ByteaExpression = jet.BlobExpression + // NumericExpression interface type NumericExpression = jet.NumericExpression @@ -39,6 +41,9 @@ type TimestampzExpression = jet.TimestampzExpression // RowExpression interface type RowExpression = jet.RowExpression +// IntervalExpression interface +type IntervalExpression = jet.IntervalExpression + // DateRange Expression interface type DateRange = jet.Range[DateExpression] @@ -82,6 +87,11 @@ var TimeExp = jet.TimeExp // Does not add sql cast to generated sql builder output. var StringExp = jet.StringExp +// ByteaExp is blob expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as string expression. +// Does not add sql cast to generated sql builder output. +var ByteaExp = jet.BlobExp + // TimezExp is time with time zone expression wrapper around arbitrary expression. // Allows go compiler to see any expression as time with time zone expression. // Does not add sql cast to generated sql builder output. @@ -102,6 +112,11 @@ var TimestampExp = jet.TimestampExp // Does not add sql cast to generated sql builder output. var TimestampzExp = jet.TimestampzExp +// IntervalExp is interval expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as interval expression. +// Does not add sql cast to generated sql builder output. +var IntervalExp = jet.IntervalExp + // RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression. // This enables the Go compiler to interpret any expression as a row expression // Note: This does not modify the generated SQL builder output by adding a SQL CAST operation. @@ -134,15 +149,17 @@ type RawArgs = map[string]interface{} var ( Raw = jet.Raw - RawBool = jet.RawBool - RawInt = jet.RawInt - RawFloat = jet.RawFloat - RawString = jet.RawString - RawTime = jet.RawTime - RawTimez = jet.RawTimez - RawTimestamp = jet.RawTimestamp - RawTimestampz = jet.RawTimestampz - RawDate = jet.RawDate + RawBool = jet.RawBool + RawInt = jet.RawInt + RawFloat = jet.RawFloat + RawString = jet.RawString + RawTime = jet.RawTime + RawTimez = jet.RawTimez + RawTimestamp = jet.RawTimestamp + RawTimestampz = jet.RawTimestampz + RawDate = jet.RawDate + RawBytea = jet.RawBlob + RawNumRange = jet.RawRange[jet.NumericExpression] RawInt4Range = jet.RawRange[jet.Int4Expression] RawInt8Range = jet.RawRange[jet.Int8Expression] diff --git a/postgres/functions.go b/postgres/functions.go index bce2e98..2e92648 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -192,9 +192,27 @@ func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression return jet.CONCAT_WS(explicitLiteralCast(separator), explicitLiteralCasts(expressions...)...) } +// Character encodings for CONVERT, CONVERT_FROM and CONVERT_TO functions +var ( + UTF8 = StringExp(jet.FixedLiteral("UTF8")) + LATIN1 = StringExp(jet.FixedLiteral("LATIN1")) + LATIN2 = StringExp(jet.FixedLiteral("LATIN2")) + LATIN3 = StringExp(jet.FixedLiteral("LATIN3")) + LATIN4 = StringExp(jet.FixedLiteral("LATIN4")) + WIN1252 = StringExp(jet.FixedLiteral("WIN1252")) + ISO_8859_5 = StringExp(jet.FixedLiteral("ISO_8859_5")) + ISO_8859_6 = StringExp(jet.FixedLiteral("ISO_8859_6")) + ISO_8859_7 = StringExp(jet.FixedLiteral("ISO_8859_7")) + ISO_8859_8 = StringExp(jet.FixedLiteral("ISO_8859_8")) + KOI8R = StringExp(jet.FixedLiteral("KOI8R")) + KOI8U = StringExp(jet.FixedLiteral("KOI8U")) +) + // CONVERT converts string to dest_encoding. The original encoding is // specified by src_encoding. The string must be valid in this encoding. -var CONVERT = jet.CONVERT +func CONVERT(str ByteaExpression, srcEncoding StringExpression, destEncoding StringExpression) ByteaExpression { + return jet.CONVERT(str, srcEncoding, destEncoding) +} // CONVERT_FROM converts string to the database encoding. The original // encoding is specified by src_encoding. The string must be valid in this encoding. @@ -203,6 +221,13 @@ var CONVERT_FROM = jet.CONVERT_FROM // CONVERT_TO converts string to dest_encoding. var CONVERT_TO = jet.CONVERT_TO +// ENCODE/DECODE textual formats +var ( + Base64 = StringExp(jet.FixedLiteral("base64")) + Escape = StringExp(jet.FixedLiteral("escape")) + Hex = StringExp(jet.FixedLiteral("hex")) +) + // ENCODE encodes binary data into a textual representation. // Supported formats are: base64, hex, escape. escape converts zero bytes and // high-bit-set bytes to octal sequences (\nnn) and doubles backslashes. @@ -212,7 +237,7 @@ var ENCODE = jet.ENCODE // Options for format are same as in encode. var DECODE = jet.DECODE -// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. +// FORMAT formats the arguments according to a format string. This function is similar to the C function sprintf. func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression { return jet.FORMAT(formatStr, explicitLiteralCasts(formatArgs...)...) } @@ -242,6 +267,49 @@ var LPAD = jet.LPAD // fill (a space by default). If the string is already longer than length then it is truncated. var RPAD = jet.RPAD +// BIT_COUNT returns the number of bits set in the binary string (also known as “popcount”). +var BIT_COUNT = jet.BIT_COUNT + +// GET_BIT extracts n'th bit from binary string. +func GET_BIT(bytes ByteaExpression, n IntegerExpression) IntegerExpression { + return IntExp(Func("GET_BIT", bytes, n)) +} + +// GET_BYTE extracts n'th byte from binary string. +func GET_BYTE(bytes ByteaExpression, n IntegerExpression) IntegerExpression { + return IntExp(Func("GET_BYTE", bytes, n)) +} + +// SET_BIT sets n'th bit in binary string to newvalue. +func SET_BIT(bytes ByteaExpression, n IntegerExpression, newValue IntegerExpression) ByteaExpression { + return ByteaExp(Func("SET_BIT", bytes, n, newValue)) +} + +// SET_BYTE sets n'th byte in binary string to newvalue. +func SET_BYTE(bytes ByteaExpression, n IntegerExpression, newValue IntegerExpression) ByteaExpression { + return ByteaExp(Func("SET_BYTE", bytes, n, newValue)) +} + +// SHA224 computes the SHA-224 hash of the binary string. +func SHA224(bytes ByteaExpression) ByteaExpression { + return ByteaExp(Func("SHA224", bytes)) +} + +// SHA256 computes the SHA-256 hash of the binary string. +func SHA256(bytes ByteaExpression) ByteaExpression { + return ByteaExp(Func("SHA256", bytes)) +} + +// SHA384 computes the SHA-384 hash of the binary string. +func SHA384(bytes ByteaExpression) ByteaExpression { + return ByteaExp(Func("SHA384", bytes)) +} + +// SHA512 computes the SHA-512 hash of the binary string. +func SHA512(bytes ByteaExpression) ByteaExpression { + return ByteaExp(Func("SHA512", bytes)) +} + // MD5 calculates the MD5 hash of string, returning the result in hexadecimal var MD5 = jet.MD5 diff --git a/postgres/interval_expression.go b/postgres/interval_expression.go deleted file mode 100644 index 91d8d8f..0000000 --- a/postgres/interval_expression.go +++ /dev/null @@ -1,257 +0,0 @@ -package postgres - -import ( - "fmt" - "github.com/go-jet/jet/v2/internal/jet" - "github.com/go-jet/jet/v2/internal/utils/datetime" - "strconv" - "strings" - "time" -) - -type quantityAndUnit = float64 -type unit = float64 - -// Interval unit types -const ( - YEAR unit = 123456789 + iota - MONTH - WEEK - DAY - HOUR - MINUTE - SECOND - MILLISECOND - MICROSECOND - DECADE - CENTURY - MILLENNIUM -) - -// IntervalExpression is representation of postgres INTERVAL -type IntervalExpression interface { - jet.IsInterval - jet.Expression - - EQ(rhs IntervalExpression) BoolExpression - NOT_EQ(rhs IntervalExpression) BoolExpression - IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression - IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression - - LT(rhs IntervalExpression) BoolExpression - LT_EQ(rhs IntervalExpression) BoolExpression - GT(rhs IntervalExpression) BoolExpression - GT_EQ(rhs IntervalExpression) BoolExpression - BETWEEN(min, max IntervalExpression) BoolExpression - NOT_BETWEEN(min, max IntervalExpression) BoolExpression - - ADD(rhs IntervalExpression) IntervalExpression - SUB(rhs IntervalExpression) IntervalExpression - - MUL(rhs NumericExpression) IntervalExpression - DIV(rhs NumericExpression) IntervalExpression -} - -type intervalInterfaceImpl struct { - jet.IsIntervalImpl - - parent IntervalExpression -} - -func (i *intervalInterfaceImpl) EQ(rhs IntervalExpression) BoolExpression { - return jet.Eq(i.parent, rhs) -} - -func (i *intervalInterfaceImpl) NOT_EQ(rhs IntervalExpression) BoolExpression { - return jet.NotEq(i.parent, rhs) -} - -func (i *intervalInterfaceImpl) IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression { - return jet.IsDistinctFrom(i.parent, rhs) -} - -func (i *intervalInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression { - return jet.IsNotDistinctFrom(i.parent, rhs) -} - -func (i *intervalInterfaceImpl) LT(rhs IntervalExpression) BoolExpression { - return jet.Lt(i.parent, rhs) -} - -func (i *intervalInterfaceImpl) LT_EQ(rhs IntervalExpression) BoolExpression { - return jet.LtEq(i.parent, rhs) -} - -func (i *intervalInterfaceImpl) GT(rhs IntervalExpression) BoolExpression { - return jet.Gt(i.parent, rhs) -} - -func (i *intervalInterfaceImpl) GT_EQ(rhs IntervalExpression) BoolExpression { - return jet.GtEq(i.parent, rhs) -} - -func (i *intervalInterfaceImpl) BETWEEN(min, max IntervalExpression) BoolExpression { - return jet.NewBetweenOperatorExpression(i.parent, min, max, false) -} - -func (i *intervalInterfaceImpl) NOT_BETWEEN(min, max IntervalExpression) BoolExpression { - return jet.NewBetweenOperatorExpression(i.parent, min, max, true) -} - -func (i *intervalInterfaceImpl) ADD(rhs IntervalExpression) IntervalExpression { - return IntervalExp(jet.Add(i.parent, rhs)) -} - -func (i *intervalInterfaceImpl) SUB(rhs IntervalExpression) IntervalExpression { - return IntervalExp(jet.Sub(i.parent, rhs)) -} - -func (i *intervalInterfaceImpl) MUL(rhs NumericExpression) IntervalExpression { - return IntervalExp(jet.Mul(i.parent, rhs)) -} - -func (i *intervalInterfaceImpl) DIV(rhs NumericExpression) IntervalExpression { - return IntervalExp(jet.Div(i.parent, rhs)) -} - -type intervalExpression struct { - jet.Expression - intervalInterfaceImpl -} - -// INTERVAL creates new interval expression from the list of quantity-unit pairs. -// -// INTERVAL(1, DAY, 3, MINUTE) -func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { - quantityAndUnitLen := len(quantityAndUnit) - if quantityAndUnitLen == 0 || quantityAndUnitLen%2 != 0 { - panic("jet: invalid number of quantity and unit fields") - } - - var fields []string - - for i := 0; i < len(quantityAndUnit); i += 2 { - quantity := strconv.FormatFloat(quantityAndUnit[i], 'f', -1, 64) - unitString := unitToString(quantityAndUnit[i+1]) - fields = append(fields, quantity+" "+unitString) - } - - intervalStr := fmt.Sprintf("INTERVAL '%s'", strings.Join(fields, " ")) - - newInterval := &intervalExpression{} - - newInterval.Expression = jet.RawWithParent(intervalStr, newInterval) - newInterval.intervalInterfaceImpl.parent = newInterval - - return newInterval -} - -// INTERVALd creates interval expression from time.Duration -func INTERVALd(duration time.Duration) IntervalExpression { - days, hours, minutes, seconds, microseconds := datetime.ExtractTimeComponents(duration) - - var quantityAndUnits []quantityAndUnit - - if days > 0 { - quantityAndUnits = append(quantityAndUnits, quantityAndUnit(days)) - quantityAndUnits = append(quantityAndUnits, DAY) - } - - if hours > 0 { - quantityAndUnits = append(quantityAndUnits, quantityAndUnit(hours)) - quantityAndUnits = append(quantityAndUnits, HOUR) - } - - if minutes > 0 { - quantityAndUnits = append(quantityAndUnits, quantityAndUnit(minutes)) - quantityAndUnits = append(quantityAndUnits, MINUTE) - } - - if seconds > 0 { - quantityAndUnits = append(quantityAndUnits, quantityAndUnit(seconds)) - quantityAndUnits = append(quantityAndUnits, SECOND) - } - - if microseconds > 0 { - quantityAndUnits = append(quantityAndUnits, quantityAndUnit(microseconds)) - quantityAndUnits = append(quantityAndUnits, MICROSECOND) - } - - if len(quantityAndUnits) == 0 { - return INTERVAL(0, MICROSECOND) - } - - return INTERVAL(quantityAndUnits...) -} - -func unitToString(unit quantityAndUnit) string { - switch unit { - case YEAR: - return "YEAR" - case MONTH: - return "MONTH" - case WEEK: - return "WEEK" - case DAY: - return "DAY" - case HOUR: - return "HOUR" - case MINUTE: - return "MINUTE" - case SECOND: - return "SECOND" - case MILLISECOND: - return "MILLISECOND" - case MICROSECOND: - return "MICROSECOND" - case DECADE: - return "DECADE" - case CENTURY: - return "CENTURY" - case MILLENNIUM: - return "MILLENNIUM" - // additional field units for EXTRACT function - case DOW: - return "DOW" - case DOY: - return "DOY" - case EPOCH: - return "EPOCH" - case ISODOW: - return "ISODOW" - case ISOYEAR: - return "ISOYEAR" - case JULIAN: - return "JULIAN" - case QUARTER: - return "QUARTER" - case TIMEZONE: - return "TIMEZONE" - case TIMEZONE_HOUR: - return "TIMEZONE_HOUR" - case TIMEZONE_MINUTE: - return "TIMEZONE_MINUTE" - default: - panic("jet: invalid INTERVAL unit type") - } -} - -//---------------------------------------------------// - -type intervalWrapper struct { - intervalInterfaceImpl - Expression -} - -func newIntervalExpressionWrap(expression Expression) IntervalExpression { - intervalWrap := &intervalWrapper{Expression: expression} - intervalWrap.intervalInterfaceImpl.parent = intervalWrap - return intervalWrap -} - -// IntervalExp is interval expression wrapper around arbitrary expression. -// Allows go compiler to see any expression as interval expression. -// Does not add sql cast to generated sql builder output. -func IntervalExp(expression Expression) IntervalExpression { - return newIntervalExpressionWrap(expression) -} diff --git a/postgres/interval_literal.go b/postgres/interval_literal.go new file mode 100644 index 0000000..c36f539 --- /dev/null +++ b/postgres/interval_literal.go @@ -0,0 +1,140 @@ +package postgres + +import ( + "fmt" + "github.com/go-jet/jet/v2/internal/utils/datetime" + "strconv" + "strings" + "time" +) + +type quantityAndUnit = float64 +type unit = float64 + +// Interval unit types +const ( + YEAR unit = 123456789 + iota + MONTH + WEEK + DAY + HOUR + MINUTE + SECOND + MILLISECOND + MICROSECOND + DECADE + CENTURY + MILLENNIUM +) + +// INTERVAL creates new interval expression from the list of quantity-unit pairs. +// +// INTERVAL(1, DAY, 3, MINUTE) +func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { + quantityAndUnitLen := len(quantityAndUnit) + if quantityAndUnitLen == 0 || quantityAndUnitLen%2 != 0 { + panic("jet: invalid number of quantity and unit fields") + } + + var fields []string + + for i := 0; i < len(quantityAndUnit); i += 2 { + quantity := strconv.FormatFloat(quantityAndUnit[i], 'f', -1, 64) + unitString := unitToString(quantityAndUnit[i+1]) + fields = append(fields, quantity+" "+unitString) + } + + return IntervalExp(CustomExpression(Token(fmt.Sprintf("INTERVAL '%s'", strings.Join(fields, " "))))) +} + +// INTERVALd creates interval expression from time.Duration +func INTERVALd(duration time.Duration) IntervalExpression { + days, hours, minutes, seconds, microseconds := datetime.ExtractTimeComponents(duration) + + var quantityAndUnits []quantityAndUnit + + if days > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(days)) + quantityAndUnits = append(quantityAndUnits, DAY) + } + + if hours > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(hours)) + quantityAndUnits = append(quantityAndUnits, HOUR) + } + + if minutes > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(minutes)) + quantityAndUnits = append(quantityAndUnits, MINUTE) + } + + if seconds > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(seconds)) + quantityAndUnits = append(quantityAndUnits, SECOND) + } + + if microseconds > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(microseconds)) + quantityAndUnits = append(quantityAndUnits, MICROSECOND) + } + + if len(quantityAndUnits) == 0 { + return INTERVAL(0, MICROSECOND) + } + + return INTERVAL(quantityAndUnits...) +} + +func unitToString(unit quantityAndUnit) string { + switch unit { + case YEAR: + return "YEAR" + case MONTH: + return "MONTH" + case WEEK: + return "WEEK" + case DAY: + return "DAY" + case HOUR: + return "HOUR" + case MINUTE: + return "MINUTE" + case SECOND: + return "SECOND" + case MILLISECOND: + return "MILLISECOND" + case MICROSECOND: + return "MICROSECOND" + case DECADE: + return "DECADE" + case CENTURY: + return "CENTURY" + case MILLENNIUM: + return "MILLENNIUM" + // additional field units for EXTRACT function + case DOW: + return "DOW" + case DOY: + return "DOY" + case EPOCH: + return "EPOCH" + case ISODOW: + return "ISODOW" + case ISOYEAR: + return "ISOYEAR" + case JULIAN: + return "JULIAN" + case QUARTER: + return "QUARTER" + case TIMEZONE: + return "TIMEZONE" + case TIMEZONE_HOUR: + return "TIMEZONE_HOUR" + case TIMEZONE_MINUTE: + return "TIMEZONE_MINUTE" + default: + panic("jet: invalid INTERVAL unit type") + } +} + +//---------------------------------------------------// diff --git a/postgres/interval_expression_test.go b/postgres/interval_literal_test.go similarity index 100% rename from postgres/interval_expression_test.go rename to postgres/interval_literal_test.go diff --git a/postgres/literal.go b/postgres/literal.go index d070c77..a6a1618 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -127,7 +127,7 @@ func Json(value interface{}) StringExpression { var UUID = jet.UUID // Bytea creates new bytea literal expression -func Bytea(value interface{}) StringExpression { +func Bytea(value interface{}) ByteaExpression { switch value.(type) { case string, []byte: default: diff --git a/postgres/select_json.go b/postgres/select_json.go new file mode 100644 index 0000000..3e5ca51 --- /dev/null +++ b/postgres/select_json.go @@ -0,0 +1,132 @@ +package postgres + +import ( + "github.com/go-jet/jet/v2/internal/jet" + "strings" +) + +// SELECT_JSON_ARR creates a new SelectJsonStatement with a list of projections. +func SELECT_JSON_ARR(projections ...Projection) SelectStatement { + return newSelectStatementJson(projections, jet.SelectJsonArrStatementType) +} + +// SELECT_JSON_OBJ creates a new SelectJsonStatement with a list of projections. +func SELECT_JSON_OBJ(projections ...Projection) SelectStatement { + return newSelectStatementJson(projections, jet.SelectJsonObjStatementType) +} + +type selectJsonStatement struct { + *selectStatementImpl + + subQuery *selectStatementImpl + statementType jet.StatementType +} + +func (s *selectJsonStatement) AS(alias string) Projection { + s.setSubQueryAlias(strings.ToLower(alias) + "_") + + return s.selectStatementImpl.AS(alias) +} + +func (s *selectJsonStatement) FROM(table ...ReadableTable) SelectStatement { + s.subQuery.From.Tables = readableTablesToSerializerList(table) + + return s +} + +func (s *selectJsonStatement) DISTINCT(on ...jet.ColumnExpression) SelectStatement { + s.subQuery.Select.Distinct = true + s.subQuery.Select.DistinctOnColumns = on + return s +} + +func (s *selectJsonStatement) WHERE(condition BoolExpression) SelectStatement { + s.subQuery.Where.Condition = condition + return s +} + +func (s *selectJsonStatement) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement { + s.subQuery.GroupBy.List = groupByClauses + return s +} + +func (s *selectJsonStatement) HAVING(boolExpression BoolExpression) SelectStatement { + s.subQuery.Having.Condition = boolExpression + return s +} + +func (s *selectJsonStatement) WINDOW(name string) windowExpand { + s.subQuery.Window.Definitions = append(s.subQuery.Window.Definitions, jet.WindowDefinition{Name: name}) + return windowExpand{ + selectStatement: s.subQuery, + rootStmt: s, + } +} + +func (s *selectJsonStatement) ORDER_BY(orderByClauses ...OrderByClause) SelectStatement { + s.subQuery.OrderBy.List = orderByClauses + return s +} + +func (s *selectJsonStatement) LIMIT(limit int64) SelectStatement { + s.subQuery.Limit.Count = limit + return s +} + +func (s *selectJsonStatement) OFFSET(offset int64) SelectStatement { + s.subQuery.Offset.Count = Int(offset) + return s +} + +func (s *selectJsonStatement) OFFSET_e(offset IntegerExpression) SelectStatement { + s.subQuery.Offset.Count = offset + return s +} + +func (s *selectJsonStatement) FETCH_FIRST(count IntegerExpression) fetchExpand { + s.subQuery.Fetch.Count = count + + return fetchExpand{ + selectStatement: s.subQuery, + rootStmt: s, + } +} + +func (s *selectJsonStatement) FOR(lock RowLock) SelectStatement { + s.subQuery.For.Lock = lock + return s +} + +func newSelectStatementJson(projections []Projection, statementType jet.StatementType) SelectStatement { + newSelectJson := &selectJsonStatement{ + selectStatementImpl: newSelectStatement(statementType, nil, nil), + subQuery: newSelectStatement(statementType, nil, projections), + statementType: statementType, + } + + newSelectJson.setOperatorsImpl.stmtRoot = newSelectJson + newSelectJson.subQuery.Select.IsForRowToJson = true + + newSelectJson.setSubQueryAlias("") + + return newSelectJson +} + +func (s *selectJsonStatement) setSubQueryAlias(alias string) { + subQueryAlias := alias + "records" + jsonAlias := alias + "json" + + s.Select.ProjectionList = ProjectionList{constructJsonFunc(s.statementType, subQueryAlias).AS(jsonAlias)} + + s.From.Tables = []jet.Serializer{newSelectTable(s.subQuery, subQueryAlias, nil)} +} + +func constructJsonFunc(statementType jet.StatementType, subQueryAlias string) Expression { + rowToJson := Func("row_to_json", CustomExpression(Token(subQueryAlias))) + + if statementType == jet.SelectJsonArrStatementType { + return Func("json_agg", rowToJson) + } + + return rowToJson +} diff --git a/postgres/select_statement.go b/postgres/select_statement.go index 2a48fc5..f2427af 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -70,12 +70,12 @@ type SelectStatement interface { // SELECT creates new SelectStatement with list of projections func SELECT(projection Projection, projections ...Projection) SelectStatement { - return newSelectStatement(nil, append([]Projection{projection}, projections...)) + return newSelectStatement(jet.SelectStatementType, nil, append([]Projection{projection}, projections...)) } -func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { +func newSelectStatement(stmtType jet.StatementType, table ReadableTable, projections []Projection) *selectStatementImpl { newSelect := &selectStatementImpl{} - newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, + newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, stmtType, newSelect, &newSelect.Select, &newSelect.From, &newSelect.Where, @@ -94,7 +94,7 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta } newSelect.Limit.Count = -1 - newSelect.setOperatorsImpl.parent = newSelect + newSelect.setOperatorsImpl.stmtRoot = newSelect return newSelect } @@ -144,7 +144,10 @@ func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatem func (s *selectStatementImpl) WINDOW(name string) windowExpand { s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name}) - return windowExpand{selectStatement: s} + return windowExpand{ + selectStatement: s, + rootStmt: s, + } } func (s *selectStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) SelectStatement { @@ -172,6 +175,7 @@ func (s *selectStatementImpl) FETCH_FIRST(count IntegerExpression) fetchExpand { return fetchExpand{ selectStatement: s, + rootStmt: s, } } @@ -188,6 +192,7 @@ func (s *selectStatementImpl) AsTable(alias string) SelectTable { type windowExpand struct { selectStatement *selectStatementImpl + rootStmt SelectStatement } func (w windowExpand) AS(window ...jet.Window) SelectStatement { @@ -196,7 +201,7 @@ func (w windowExpand) AS(window ...jet.Window) SelectStatement { } windowsDefinition := w.selectStatement.Window.Definitions windowsDefinition[len(windowsDefinition)-1].Window = window[0] - return w.selectStatement + return w.rootStmt } func toJetFrameOffset(offset int64) jet.Serializer { @@ -216,16 +221,17 @@ func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer { type fetchExpand struct { selectStatement *selectStatementImpl + rootStmt SelectStatement } func (f fetchExpand) ROWS_ONLY() SelectStatement { f.selectStatement.Fetch.WithTies = false - return f.selectStatement + return f.rootStmt } func (f fetchExpand) ROWS_WITH_TIES() SelectStatement { f.selectStatement.Fetch.WithTies = true - return f.selectStatement + return f.rootStmt } diff --git a/postgres/set_statement.go b/postgres/set_statement.go index 0dee00d..5f553cb 100644 --- a/postgres/set_statement.go +++ b/postgres/set_statement.go @@ -65,31 +65,31 @@ type setOperators interface { } type setOperatorsImpl struct { - parent setOperators + stmtRoot setOperators } func (s *setOperatorsImpl) UNION(rhs SelectStatement) setStatement { - return UNION(s.parent, rhs) + return UNION(s.stmtRoot, rhs) } func (s *setOperatorsImpl) UNION_ALL(rhs SelectStatement) setStatement { - return UNION_ALL(s.parent, rhs) + return UNION_ALL(s.stmtRoot, rhs) } func (s *setOperatorsImpl) INTERSECT(rhs SelectStatement) setStatement { - return INTERSECT(s.parent, rhs) + return INTERSECT(s.stmtRoot, rhs) } func (s *setOperatorsImpl) INTERSECT_ALL(rhs SelectStatement) setStatement { - return INTERSECT_ALL(s.parent, rhs) + return INTERSECT_ALL(s.stmtRoot, rhs) } func (s *setOperatorsImpl) EXCEPT(rhs SelectStatement) setStatement { - return EXCEPT(s.parent, rhs) + return EXCEPT(s.stmtRoot, rhs) } func (s *setOperatorsImpl) EXCEPT_ALL(rhs SelectStatement) setStatement { - return EXCEPT_ALL(s.parent, rhs) + return EXCEPT_ALL(s.stmtRoot, rhs) } type setStatementImpl struct { @@ -110,7 +110,7 @@ func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStat newSetStatement.setOperator.Selects = selects newSetStatement.setOperator.Limit.Count = -1 - newSetStatement.setOperatorsImpl.parent = newSetStatement + newSetStatement.setOperatorsImpl.stmtRoot = newSetStatement return newSetStatement } diff --git a/postgres/table.go b/postgres/table.go index f90c114..aa54213 100644 --- a/postgres/table.go +++ b/postgres/table.go @@ -55,7 +55,7 @@ type readableTableInterfaceImpl struct { // Generates a select query on the current tableName. func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { - return newSelectStatement(r.parent, append([]Projection{projection1}, projections...)) + return newSelectStatement(jet.SelectStatementType, r.parent, append([]Projection{projection1}, projections...)) } // Creates a inner join tableName Expression using onCondition. diff --git a/qrm/internal/null_types.go b/qrm/internal/null_types.go index 85d9c68..1881b72 100644 --- a/qrm/internal/null_types.go +++ b/qrm/internal/null_types.go @@ -4,14 +4,13 @@ import ( "database/sql" "database/sql/driver" "fmt" - "github.com/go-jet/jet/v2/internal/utils/min" + "github.com/go-jet/jet/v2/internal/utils/datetime" "reflect" "strconv" - "time" ) var ( - castOverFlowError = fmt.Errorf("cannot cast a negative value to an unsigned value, buffer overflow error") + errCastOverFlow = fmt.Errorf("cannot cast a negative value to an unsigned value, buffer overflow error") ) // NullBool struct @@ -64,7 +63,12 @@ func (nt *NullTime) Scan(value interface{}) error { // Some of the drivers (pgx, mysql) are not parsing all of the time formats(date, time with time zone,...) and are just forwarding string value. // At this point we try to parse those values using some of the predefined formats - nt.Time, nt.Valid = tryParseAsTime(value) + nt.Time, nt.Valid = datetime.TryParseAsTime(value, []string{ + "2006-01-02 15:04:05-07:00", // sqlite + "2006-01-02 15:04:05.999999", // go-sql-driver/mysql + "15:04:05-07", // pgx + "15:04:05.999999", // pgx + }) if !nt.Valid { return fmt.Errorf("can't scan time.Time from %q", value) @@ -73,42 +77,6 @@ func (nt *NullTime) Scan(value interface{}) error { return nil } -var formats = []string{ - "2006-01-02 15:04:05-07:00", // sqlite - "2006-01-02 15:04:05.999999", // go-sql-driver/mysql - "15:04:05-07", // pgx - "15:04:05.999999", // pgx -} - -func tryParseAsTime(value interface{}) (time.Time, bool) { - - var timeStr string - - switch v := value.(type) { - case string: - timeStr = v - case []byte: - timeStr = string(v) - case int64: - return time.Unix(v, 0), true // sqlite - default: - return time.Time{}, false - } - - for _, format := range formats { - formatLen := min.Int(len(format), len(timeStr)) - t, err := time.Parse(format[:formatLen], timeStr) - - if err != nil { - continue - } - - return t, true - } - - return time.Time{}, false -} - // NullUInt64 struct type NullUInt64 struct { UInt64 uint64 @@ -124,31 +92,31 @@ func (n *NullUInt64) Scan(value interface{}) error { return nil case int64: if v < 0 { - return castOverFlowError + return errCastOverFlow } n.UInt64, n.Valid = uint64(v), true return nil case int32: if v < 0 { - return castOverFlowError + return errCastOverFlow } n.UInt64, n.Valid = uint64(v), true return nil case int16: if v < 0 { - return castOverFlowError + return errCastOverFlow } n.UInt64, n.Valid = uint64(v), true return nil case int8: if v < 0 { - return castOverFlowError + return errCastOverFlow } n.UInt64, n.Valid = uint64(v), true return nil case int: if v < 0 { - return castOverFlowError + return errCastOverFlow } n.UInt64, n.Valid = uint64(v), true return nil diff --git a/qrm/internal/null_types_test.go b/qrm/internal/null_types_test.go index eab2dd2..3feeffc 100644 --- a/qrm/internal/null_types_test.go +++ b/qrm/internal/null_types_test.go @@ -103,25 +103,25 @@ func TestNullUInt64(t *testing.T) { //Validate negative use cases err := nullUInt64.Scan(int64(-5)) assert.NotNil(t, err) - assert.Error(t, err, castOverFlowError) + assert.Error(t, err, errCastOverFlow) //Validate negative use cases err = nullUInt64.Scan(-5) assert.NotNil(t, err) - assert.Error(t, err, castOverFlowError) + assert.Error(t, err, errCastOverFlow) //Validate negative use cases err = nullUInt64.Scan(int32(-5)) assert.NotNil(t, err) - assert.Error(t, err, castOverFlowError) + assert.Error(t, err, errCastOverFlow) //Validate negative use cases err = nullUInt64.Scan(int16(-5)) assert.NotNil(t, err) - assert.Error(t, err, castOverFlowError) + assert.Error(t, err, errCastOverFlow) //Validate negative use cases err = nullUInt64.Scan(int8(-5)) assert.NotNil(t, err) - assert.Error(t, err, castOverFlowError) + assert.Error(t, err, errCastOverFlow) } diff --git a/qrm/qrm.go b/qrm/qrm.go index edd9387..4d49d46 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -3,6 +3,7 @@ package qrm import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "github.com/go-jet/jet/v2/internal/utils/must" @@ -12,10 +13,130 @@ import ( // ErrNoRows is returned by Query when query result set is empty var ErrNoRows = errors.New("qrm: no rows in result set") -// Query executes Query Result Mapping (QRM) of `query` with list of parametrized arguments `arg` over database connection `db` -// using context `ctx` into destination `destPtr`. -// Destination can be either pointer to struct or pointer to slice of structs. -// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. +// QueryJsonObj executes a SQL query that returns a JSON object, unmarshals the result into the provided destination, +// and returns the number of rows processed. +// +// The query must return exactly one row with a single column; otherwise, an error is returned. +// +// Parameters: +// +// ctx - The context for managing query execution (timeouts, cancellations). +// db - The database connection or transaction that implements the Queryable interface. +// query - The SQL query string to be executed. +// args - A slice of arguments to be used with the query. +// destPtr - A pointer to the variable where the unmarshaled JSON result will be stored. +// The destination should be a pointer to a struct or map[string]any. +// +// Returns: +// +// rowsProcessed - The number of rows processed by the query execution. +// err - An error if query execution or unmarshaling fails. +func QueryJsonObj(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) { + must.BeInitializedPtr(destPtr, "jet: destination is nil") + must.BeTypeKind(destPtr, reflect.Ptr, jsonDestObjErr) + destType := reflect.TypeOf(destPtr).Elem() + must.BeTrue(destType.Kind() == reflect.Struct || destType.Kind() == reflect.Map, jsonDestObjErr) + + return queryJson(ctx, db, query, args, destPtr) +} + +// QueryJsonArr executes a SQL query that returns a JSON array, unmarshals the result into the provided destination, +// and returns the number of rows processed. +// +// The query must return exactly one row with a single column; otherwise, an error is returned. +// +// Parameters: +// +// ctx - The context for managing query execution (timeouts, cancellations). +// db - The database connection or transaction that implements the Queryable interface. +// query - The SQL query string to be executed. +// args - A slice of arguments to be used with the query. +// destPtr - A pointer to the variable where the unmarshaled JSON array will be stored. +// The destination should be a pointer to a slice of structs or []map[string]any. +// +// Returns: +// +// rowsProcessed - The number of rows processed by the query execution. +// err - An error if query execution or unmarshaling fails. +func QueryJsonArr(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) { + must.BeInitializedPtr(destPtr, "jet: destination is nil") + must.BeTypeKind(destPtr, reflect.Ptr, jsonDestArrErr) + destType := reflect.TypeOf(destPtr).Elem() + must.BeTrue(destType.Kind() == reflect.Slice, jsonDestArrErr) + + return queryJson(ctx, db, query, args, destPtr) +} + +var jsonDestObjErr = "jet: SELECT_JSON_OBJ destination has to be a pointer to struct or pointer to map[string]any" +var jsonDestArrErr = "jet: SELECT_JSON_ARR destination has to be a pointer to slice of struct or pointer to []map[string]any" + +func queryJson(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) { + must.BeInitializedPtr(db, "jet: db is nil") + + var rows *sql.Rows + rows, err = db.QueryContext(ctx, query, args...) + + if err != nil { + return 0, err + } + + defer rows.Close() + + if !rows.Next() { + err = rows.Err() + if err != nil { + return 0, err + } + return 0, ErrNoRows + } + + var jsonData []byte + err = rows.Scan(&jsonData) + + if err != nil { + return 1, err + } + + if jsonData == nil { + return 1, nil + } + + err = json.Unmarshal(jsonData, &destPtr) + + if err != nil { + return 1, fmt.Errorf("jet: invalid json, %w", err) + } + + if rows.Next() { + return 1, fmt.Errorf("jet: query returned more then one row") + } + + err = rows.Close() + if err != nil { + return 1, err + } + + return 1, nil +} + +// Query executes a Query Result Mapping (QRM) of the provided SQL `query` with a list of parameterized arguments `args` +// over the database connection `db` using the provided context `ctx` and stores the result in the destination `destPtr`. +// +// The destination must be a pointer to either a struct or a slice of structs +// If the destination is a pointer to a struct and no rows are returned, the method returns qrm.ErrNoRows. +// +// Parameters: +// +// ctx - The context for managing query execution (timeouts, cancellations). +// db - The database connection or transaction implementing the Queryable interface. +// query - The SQL query string to be executed. +// args - A slice of arguments to be used with the query. +// destPtr - A pointer to the variable where the query result will be stored. This can be a pointer to a struct or a slice of structs. +// +// Returns: +// +// rowsProcessed - The number of rows processed by the query execution. +// err - An error if query execution or result mapping fails, or if no rows are found when a struct is expected. func Query(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) { must.BeInitializedPtr(db, "jet: db is nil") @@ -185,7 +306,7 @@ func mapRowToSlice( func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) { index := 0 if field != nil { - typeName, columnName := getTypeAndFieldName("", *field) + typeName, columnName, _ := getTypeAndFieldName("", *field) if index = scanContext.typeToColumnIndex(typeName, columnName); index < 0 { return } @@ -233,9 +354,11 @@ func mapRowToStruct( continue } - fieldMap := typeInf.fieldMappings[i] + fieldMappingInfo := typeInf.fieldMappings[i] - if fieldMap.complexType { + switch fieldMappingInfo.Type { + + case complexType: var changed bool changed, err = mapRowToDestinationValue(scanContext, concat(groupKey, ":", field.Name), fieldValue, &field) @@ -246,13 +369,12 @@ func mapRowToStruct( if changed { updated = true } - - } else { - if mapOnlySlices || fieldMap.rowIndex == -1 { + default: + if mapOnlySlices || fieldMappingInfo.rowIndex == -1 { continue } - scannedValue := scanContext.rowElemValue(fieldMap.rowIndex) + scannedValue := scanContext.rowElemValue(fieldMappingInfo.rowIndex) if !scannedValue.IsValid() { setZeroValue(fieldValue) // scannedValue is nil, destination should be set to zero value @@ -261,7 +383,8 @@ func mapRowToStruct( updated = true - if fieldMap.implementsScanner { + switch fieldMappingInfo.Type { + case implementsScanner: initializeValueIfNilPtr(fieldValue) fieldScanner := getScanner(fieldValue) @@ -270,14 +393,27 @@ func mapRowToStruct( err := fieldScanner.Scan(value) if err != nil { - return updated, fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, value, value, field.Name, field.Type.String(), err) + return updated, qrmAssignError(scannedValue, field, err) } - } else { + case jsonUnmarshal: + value, ok := scannedValue.Interface().([]byte) + + if !ok { + return updated, qrmAssignError(scannedValue, field, fmt.Errorf("value not convertable to []byte")) + } + + fieldInterface := fieldValue.Addr().Interface() + + err := json.Unmarshal(value, fieldInterface) + + if err != nil { + return updated, qrmAssignError(scannedValue, field, fmt.Errorf("invalid json, %w", err)) + } + default: // simple type err := assign(scannedValue, fieldValue) if err != nil { - return updated, fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, scannedValue.Interface(), scannedValue.Interface(), - field.Name, field.Type.String(), err) + return updated, qrmAssignError(scannedValue, field, err) } } } @@ -286,6 +422,11 @@ func mapRowToStruct( return } +func qrmAssignError(scannedValue reflect.Value, field reflect.StructField, err error) error { + return fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, scannedValue.Interface(), scannedValue.Interface(), + field.Name, field.Type.String(), err) +} + func mapRowToDestinationValue( scanContext *ScanContext, groupKey string, diff --git a/qrm/scan_context.go b/qrm/scan_context.go index 7a3c538..e28c4f1 100644 --- a/qrm/scan_context.go +++ b/qrm/scan_context.go @@ -75,10 +75,18 @@ type typeInfo struct { fieldMappings []fieldMapping } +type fieldMappingType int + +const ( + simpleType fieldMappingType = iota + complexType // slice and struct are complex types supported + implementsScanner + jsonUnmarshal +) + type fieldMapping struct { - complexType bool // slice and struct are complex types - rowIndex int // index in ScanContext.row - implementsScanner bool + rowIndex int // index in ScanContext.row + Type fieldMappingType } func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo { @@ -100,17 +108,21 @@ func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect. for i := 0; i < structType.NumField(); i++ { field := structType.Field(i) - newTypeName, fieldName := getTypeAndFieldName(typeName, field) + newTypeName, fieldName, jsonUnmarshaler := getTypeAndFieldName(typeName, field) columnIndex := s.typeToColumnIndex(newTypeName, fieldName) fieldMap := fieldMapping{ rowIndex: columnIndex, } - if implementsScannerType(field.Type) { - fieldMap.implementsScanner = true + if jsonUnmarshaler { + fieldMap.Type = jsonUnmarshal + } else if implementsScannerType(field.Type) { + fieldMap.Type = implementsScanner } else if !isSimpleModelType(field.Type) { - fieldMap.complexType = true + fieldMap.Type = complexType + } else { + fieldMap.Type = simpleType } newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap) @@ -188,7 +200,7 @@ func (s *ScanContext) getGroupKeyInfo( fieldType := indirectType(field.Type) if isPrimaryKey(field, primaryKeyOverwrites) { - newTypeName, fieldName := getTypeAndFieldName(typeName, field) + newTypeName, fieldName, _ := getTypeAndFieldName(typeName, field) pkIndex := s.typeToColumnIndex(newTypeName, fieldName) diff --git a/qrm/utill.go b/qrm/utill.go index b43ee29..1775fda 100644 --- a/qrm/utill.go +++ b/qrm/utill.go @@ -107,20 +107,26 @@ func getTypeName(structType reflect.Type, parentField *reflect.StructField) stri return toCommonIdentifier(aliasParts[0]) } -func getTypeAndFieldName(structType string, field reflect.StructField) (string, string) { +func getTypeAndFieldName(structType string, field reflect.StructField) (string, string, bool) { aliasTag := field.Tag.Get("alias") - if aliasTag == "" { - return structType, field.Name + if aliasTag != "" { + aliasParts := strings.Split(aliasTag, ".") + + if len(aliasParts) == 1 { + return structType, toCommonIdentifier(aliasParts[0]), false + } + + return toCommonIdentifier(aliasParts[0]), toCommonIdentifier(aliasParts[1]), false } - aliasParts := strings.Split(aliasTag, ".") + jsonColumnTag := field.Tag.Get("json_column") - if len(aliasParts) == 1 { - return structType, toCommonIdentifier(aliasParts[0]) + if jsonColumnTag != "" { + return "", toCommonIdentifier(jsonColumnTag), true } - return toCommonIdentifier(aliasParts[0]), toCommonIdentifier(aliasParts[1]) + return structType, field.Name, false } var replacer = strings.NewReplacer(" ", "", "-", "", "_", "") diff --git a/sqlite/cast.go b/sqlite/cast.go index 0f68f43..fb74820 100644 --- a/sqlite/cast.go +++ b/sqlite/cast.go @@ -42,6 +42,6 @@ func (c *cast) AS_REAL() FloatExpression { } // AS_BLOB cast expression to BLOB type -func (c *cast) AS_BLOB() StringExpression { - return StringExp(c.AS("BLOB")) +func (c *cast) AS_BLOB() BlobExpression { + return BlobExp(c.AS("BLOB")) } diff --git a/sqlite/columns.go b/sqlite/columns.go index 2941b8d..44b6145 100644 --- a/sqlite/columns.go +++ b/sqlite/columns.go @@ -21,6 +21,12 @@ type ColumnString = jet.ColumnString // StringColumn creates named string column. var StringColumn = jet.StringColumn +// ColumnBlob is interface for +type ColumnBlob = jet.ColumnBlob + +// BlobColumn creates new named blob column +var BlobColumn = jet.BlobColumn + // ColumnInteger is interface for SQL smallint, integer, bigint columns. type ColumnInteger = jet.ColumnInteger diff --git a/sqlite/dialect.go b/sqlite/dialect.go index da03364..3585651 100644 --- a/sqlite/dialect.go +++ b/sqlite/dialect.go @@ -1,6 +1,7 @@ package sqlite import ( + "encoding/hex" "fmt" "github.com/go-jet/jet/v2/internal/jet" ) @@ -23,7 +24,8 @@ func newDialect() jet.Dialect { ArgumentPlaceholder: func(int) string { return "?" }, - ReservedWords: reservedWords2, + ArgumentToString: argumentToString, + ReservedWords: reservedWords2, ValuesDefaultColumnName: func(index int) string { return fmt.Sprintf("column%d", index+1) }, @@ -32,6 +34,15 @@ func newDialect() jet.Dialect { return jet.NewDialect(mySQLDialectParams) } +func argumentToString(value any) (string, bool) { + switch bindVal := value.(type) { + case []byte: + return fmt.Sprintf("X'%s'", hex.EncodeToString(bindVal)), true + } + + return "", false +} + func sqliteBitXOR(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { diff --git a/sqlite/expressions.go b/sqlite/expressions.go index 0b2d320..3de5cd2 100644 --- a/sqlite/expressions.go +++ b/sqlite/expressions.go @@ -12,6 +12,9 @@ type BoolExpression = jet.BoolExpression // StringExpression interface type StringExpression = jet.StringExpression +// BlobExpression interface +type BlobExpression = jet.BlobExpression + // NumericExpression is shared interface for integer or real expression type NumericExpression = jet.NumericExpression @@ -46,6 +49,11 @@ var BoolExp = jet.BoolExp // Does not add sql cast to generated sql builder output. var StringExp = jet.StringExp +// BlobExp is blob expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as blob expression. +// Does not add sql cast to generated sql builder output. +var BlobExp = jet.BlobExp + // IntExp is int expression wrapper around arbitrary expression. // Allows go compiler to see any expression as int expression. // Does not add sql cast to generated sql builder output. diff --git a/sqlite/functions.go b/sqlite/functions.go index ac6bd08..92139b4 100644 --- a/sqlite/functions.go +++ b/sqlite/functions.go @@ -196,11 +196,22 @@ var RTRIM = jet.RTRIM // return jet.NewStringFunc("RIGHTSTR", str, n) //} +// HEX function takes an input and returns its equivalent hexadecimal representation +var HEX = jet.HEX + +// UNHEX for a string argument str, UNHEX(str) interprets each pair of characters in the argument +// as a hexadecimal number and converts it to the byte represented by the number. +// The return value is a binary string. +var UNHEX = jet.UNHEX + // LENGTH returns number of characters in string with a given encoding -func LENGTH(str jet.StringExpression) jet.StringExpression { +func LENGTH(str jet.StringOrBlobExpression) jet.IntegerExpression { return jet.LENGTH(str) } +// OCTET_LENGTH returns number of bytes in string expression +var OCTET_LENGTH = jet.OCTET_LENGTH + // LPAD fills up the string to length length by prepending the characters // fill (a space by default). If the string is already longer than length // then it is truncated (on the right). diff --git a/sqlite/literal.go b/sqlite/literal.go index 2df5dd7..296bee9 100644 --- a/sqlite/literal.go +++ b/sqlite/literal.go @@ -50,6 +50,11 @@ var Decimal = jet.Decimal // String creates new string literal expression var String = jet.String +// Blob creates new blob literal expression +func Blob(data []byte) BlobExpression { + return BlobExp(jet.Literal(data)) +} + // UUID is a helper function to create string literal expression from uuid object // value can be any uuid type with a String method var UUID = jet.UUID diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index bcbbb25..8c0d2a2 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -26,7 +26,7 @@ services: - ./testdata/init/mysql:/docker-entrypoint-initdb.d mariadb: - image: mariadb:10.3 + image: mariadb:11.4 command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1'] restart: always environment: diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 61bc7f2..ae3de1b 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -23,19 +23,115 @@ func TestAllTypes(t *testing.T) { var dest []model.AllTypes - err := AllTypes. - SELECT(AllTypes.AllColumns). + err := SELECT(AllTypes.AllColumns). + FROM(AllTypes). LIMIT(2). Query(db, &dest) require.NoError(t, err) - require.Equal(t, len(dest), 2) //testutils.PrintJson(dest) testutils.AssertJSON(t, dest, allTypesJson) } +func TestAllTypesJSON(t *testing.T) { + + stmt := SELECT_JSON_ARR( + AllTypes.AllColumns.Except( + AllTypes.JSON, + AllTypes.JSONPtr, + AllTypes.Bit, + AllTypes.BitPtr, + ), + CAST(AllTypes.JSON).AS_CHAR().AS("Json"), + CAST(AllTypes.JSONPtr).AS_CHAR().AS("JsonPtr"), + CAST(AllTypes.Bit).AS_CHAR().AS("Bit"), + CAST(AllTypes.BitPtr).AS_CHAR().AS("BitPtr"), + ).FROM(AllTypes) + + testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(` +SELECT JSON_ARRAYAGG(JSON_OBJECT( + 'id', all_types.id, + 'boolean', all_types.boolean = 1, + 'booleanPtr', all_types.boolean_ptr = 1, + 'tinyInt', all_types.tiny_int, + 'uTinyInt', all_types.u_tiny_int, + 'smallInt', all_types.small_int, + 'uSmallInt', all_types.u_small_int, + 'mediumInt', all_types.medium_int, + 'uMediumInt', all_types.u_medium_int, + 'integer', all_types.''integer'', + 'uInteger', all_types.u_integer, + 'bigInt', all_types.big_int, + 'uBigInt', all_types.u_big_int, + 'tinyIntPtr', all_types.tiny_int_ptr, + 'uTinyIntPtr', all_types.u_tiny_int_ptr, + 'smallIntPtr', all_types.small_int_ptr, + 'uSmallIntPtr', all_types.u_small_int_ptr, + 'mediumIntPtr', all_types.medium_int_ptr, + 'uMediumIntPtr', all_types.u_medium_int_ptr, + 'integerPtr', all_types.integer_ptr, + 'uIntegerPtr', all_types.u_integer_ptr, + 'bigIntPtr', all_types.big_int_ptr, + 'uBigIntPtr', all_types.u_big_int_ptr, + 'decimal', all_types.''decimal'', + 'decimalPtr', all_types.decimal_ptr, + 'numeric', all_types.''numeric'', + 'numericPtr', all_types.numeric_ptr, + 'float', all_types.''float'', + 'floatPtr', all_types.float_ptr, + 'double', all_types.''double'', + 'doublePtr', all_types.double_ptr, + 'real', all_types.''real'', + 'realPtr', all_types.real_ptr, + 'time', CONCAT('0000-01-01T', DATE_FORMAT(all_types.time,'%H:%i:%s.%fZ')), + 'timePtr', CONCAT('0000-01-01T', DATE_FORMAT(all_types.time_ptr,'%H:%i:%s.%fZ')), + 'date', CONCAT(DATE_FORMAT(all_types.date,'%Y-%m-%d'), 'T00:00:00Z'), + 'datePtr', CONCAT(DATE_FORMAT(all_types.date_ptr,'%Y-%m-%d'), 'T00:00:00Z'), + 'dateTime', DATE_FORMAT(all_types.date_time,'%Y-%m-%dT%H:%i:%s.%fZ'), + 'dateTimePtr', DATE_FORMAT(all_types.date_time_ptr,'%Y-%m-%dT%H:%i:%s.%fZ'), + 'timestamp', DATE_FORMAT(all_types.timestamp,'%Y-%m-%dT%H:%i:%s.%fZ'), + 'timestampPtr', DATE_FORMAT(all_types.timestamp_ptr,'%Y-%m-%dT%H:%i:%s.%fZ'), + 'year', all_types.year, + 'yearPtr', all_types.year_ptr, + 'char', all_types.''char'', + 'charPtr', all_types.char_ptr, + 'varChar', all_types.var_char, + 'varCharPtr', all_types.var_char_ptr, + 'binary', TO_BASE64(all_types.''binary''), + 'binaryPtr', TO_BASE64(all_types.binary_ptr), + 'varBinary', TO_BASE64(all_types.var_binary), + 'varBinaryPtr', TO_BASE64(all_types.var_binary_ptr), + 'blob', TO_BASE64(all_types.''blob''), + 'blobPtr', TO_BASE64(all_types.blob_ptr), + 'text', all_types.text, + 'textPtr', all_types.text_ptr, + 'enum', all_types.enum, + 'enumPtr', all_types.enum_ptr, + 'set', all_types.''set'', + 'setPtr', all_types.set_ptr, + 'Json', CAST(all_types.json AS CHAR), + 'JsonPtr', CAST(all_types.json_ptr AS CHAR), + 'Bit', CAST(all_types.bit AS CHAR), + 'BitPtr', CAST(all_types.bit_ptr AS CHAR) + )) AS "json" +FROM test_sample.all_types; +`, "''", "`")) + + var dest []model.AllTypes + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) + + // fix float rounding lost before comparison + dest[0].Float = 3.33 + dest[0].FloatPtr = ptr.Of(3.33) + dest[1].Float = 3.33 + + testutils.AssertJSON(t, dest, allTypesJson) +} + func TestAllTypesViewSelect(t *testing.T) { type AllTypesView model.AllTypes @@ -467,7 +563,8 @@ func TestStringOperators(t *testing.T) { RTRIM(AllTypes.VarCharPtr), CONCAT(String("string1"), Int(1), Float(11.12)), CONCAT_WS(String("string1"), Int(1), Float(11.12)), - FORMAT(String("Hello %s, %1$s"), String("World")), + FORMAT(Int(11), Int(2)), + FORMAT(Int(11), Int(2), String("de_DE")), LEFT(String("abcde"), Int(2)), RIGHT(String("abcde"), Int(2)), LENGTH(String("jose")), @@ -479,6 +576,12 @@ func TestStringOperators(t *testing.T) { REVERSE(AllTypes.VarCharPtr), SUBSTR(AllTypes.CharPtr, Int(3)), SUBSTR(AllTypes.CharPtr, Int(3), Int(2)), + ELT(Int(2), AllTypes.CharPtr, AllTypes.Char, AllTypes.Text), + FIELD(AllTypes.Char, AllTypes.VarChar, AllTypes.Text), + FROM_BASE64(String("SGVsbG8gV29ybGQ=")), + TO_BASE64(String("Hello World")), + CHARSET(AllTypes.Char), + COLLATION(AllTypes.Text), } if !sourceIsMariaDB() { @@ -500,6 +603,71 @@ func TestStringOperators(t *testing.T) { require.NoError(t, err) } +func TestBlob(t *testing.T) { + + var sampleBlob = Blob([]byte{11, 0, 22, 33, 44}) + var textBlob = Blob([]byte("text blob")) + + stmt := SELECT( + AllTypes.BlobPtr.EQ(sampleBlob), + AllTypes.BlobPtr.EQ(AllTypes.BlobPtr), + AllTypes.BlobPtr.NOT_EQ(sampleBlob), + AllTypes.BlobPtr.GT(textBlob), + AllTypes.BlobPtr.GT_EQ(AllTypes.BlobPtr), + AllTypes.BlobPtr.LT(AllTypes.BlobPtr), + AllTypes.BlobPtr.LT_EQ(sampleBlob), + AllTypes.BlobPtr.BETWEEN(Blob([]byte("min")), Blob([]byte("max"))), + AllTypes.BlobPtr.NOT_BETWEEN(AllTypes.BlobPtr, AllTypes.BlobPtr), + AllTypes.BlobPtr.CONCAT(textBlob), + AllTypes.BlobPtr.LIKE(AllTypes.BlobPtr), + AllTypes.BlobPtr.NOT_LIKE(sampleBlob), + + BIT_LENGTH(textBlob), + LENGTH(sampleBlob), + CHAR_LENGTH(AllTypes.BlobPtr), + OCTET_LENGTH(textBlob), + CONCAT(sampleBlob, Int(1), Float(11.12)), + TO_BASE64(sampleBlob), + HEX(sampleBlob), + UNHEX(String("616B263A")), + SUBSTR(AllTypes.BlobPtr, Int(3)), + SUBSTR(AllTypes.BlobPtr, Int(3), Int(2)), + ).FROM( + AllTypes, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT all_types.blob_ptr = X'0b0016212c', + all_types.blob_ptr = all_types.blob_ptr, + all_types.blob_ptr != X'0b0016212c', + all_types.blob_ptr > X'7465787420626c6f62', + all_types.blob_ptr >= all_types.blob_ptr, + all_types.blob_ptr < all_types.blob_ptr, + all_types.blob_ptr <= X'0b0016212c', + all_types.blob_ptr BETWEEN X'6d696e' AND X'6d6178', + all_types.blob_ptr NOT BETWEEN all_types.blob_ptr AND all_types.blob_ptr, + CONCAT(all_types.blob_ptr, X'7465787420626c6f62'), + all_types.blob_ptr LIKE all_types.blob_ptr, + all_types.blob_ptr NOT LIKE X'0b0016212c', + BIT_LENGTH(X'7465787420626c6f62'), + LENGTH(X'0b0016212c'), + CHAR_LENGTH(all_types.blob_ptr), + OCTET_LENGTH(X'7465787420626c6f62'), + CONCAT(X'0b0016212c', 1, 11.12), + TO_BASE64(X'0b0016212c'), + HEX(X'0b0016212c'), + UNHEX('616B263A'), + SUBSTR(all_types.blob_ptr, 3), + SUBSTR(all_types.blob_ptr, 3, 2) +FROM test_sample.all_types; +`) + + var dest []struct{} + err := stmt.Query(db, &dest) + + require.NoError(t, err) +} + var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC) func TestTimeExpressions(t *testing.T) { @@ -1066,6 +1234,118 @@ func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) { require.NoError(t, err) } +func TestAllTypesSubQueryFrom(t *testing.T) { + subQuery := SELECT( + AllTypes.Boolean, + AllTypes.Integer, + AllTypes.Double, + AllTypes.Text, + AllTypes.Date, + AllTypes.Time, + AllTypes.Timestamp, + AllTypes.Blob, + ).FROM( + AllTypes, + ).AsTable("sub_query") + + stmt := SELECT( + AllTypes.Boolean.From(subQuery), + AllTypes.Integer.From(subQuery), + AllTypes.Double.From(subQuery), + AllTypes.Text.From(subQuery), + AllTypes.Date.From(subQuery), + AllTypes.Time.From(subQuery), + AllTypes.Timestamp.From(subQuery), + AllTypes.Blob.From(subQuery), + ).FROM( + subQuery, + ) + + testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(` +SELECT sub_query.''all_types.boolean'' AS "all_types.boolean", + sub_query.''all_types.integer'' AS "all_types.integer", + sub_query.''all_types.double'' AS "all_types.double", + sub_query.''all_types.text'' AS "all_types.text", + sub_query.''all_types.date'' AS "all_types.date", + sub_query.''all_types.time'' AS "all_types.time", + sub_query.''all_types.timestamp'' AS "all_types.timestamp", + sub_query.''all_types.blob'' AS "all_types.blob" +FROM ( + SELECT all_types.boolean AS "all_types.boolean", + all_types.''integer'' AS "all_types.integer", + all_types.''double'' AS "all_types.double", + all_types.text AS "all_types.text", + all_types.date AS "all_types.date", + all_types.time AS "all_types.time", + all_types.timestamp AS "all_types.timestamp", + all_types.''blob'' AS "all_types.blob" + FROM test_sample.all_types + ) AS sub_query; +`, "''", "`")) + + var dest []model.AllTypes + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.NotEmpty(t, dest) + + t.Run("using SELECT_JSON", func(t *testing.T) { + stmtJson := SELECT_JSON_ARR( + AllTypes.Boolean.From(subQuery), + AllTypes.Integer.From(subQuery), + AllTypes.Double.From(subQuery), + AllTypes.Text.From(subQuery), + AllTypes.Date.From(subQuery), + AllTypes.Time.From(subQuery), + AllTypes.Timestamp.From(subQuery), + AllTypes.Blob.From(subQuery), + ).FROM( + subQuery, + ) + + testutils.AssertDebugStatementSql(t, stmtJson, strings.ReplaceAll(` +SELECT JSON_ARRAYAGG(JSON_OBJECT( + 'boolean', sub_query.''all_types.boolean'' = 1, + 'integer', sub_query.''all_types.integer'', + 'double', sub_query.''all_types.double'', + 'text', sub_query.''all_types.text'', + 'date', CONCAT(DATE_FORMAT(sub_query.''all_types.date'','%Y-%m-%d'), 'T00:00:00Z'), + 'time', CONCAT('0000-01-01T', DATE_FORMAT(sub_query.''all_types.time'','%H:%i:%s.%fZ')), + 'timestamp', DATE_FORMAT(sub_query.''all_types.timestamp'','%Y-%m-%dT%H:%i:%s.%fZ'), + 'blob', TO_BASE64(sub_query.''all_types.blob'') + )) AS "json" +FROM ( + SELECT all_types.boolean AS "all_types.boolean", + all_types.''integer'' AS "all_types.integer", + all_types.''double'' AS "all_types.double", + all_types.text AS "all_types.text", + all_types.date AS "all_types.date", + all_types.time AS "all_types.time", + all_types.timestamp AS "all_types.timestamp", + all_types.''blob'' AS "all_types.blob" + FROM test_sample.all_types + ) AS sub_query; +`, "''", "`")) + + var destJson []model.AllTypes + + err := stmtJson.QueryContext(ctx, db, &destJson) + require.NoError(t, err) + + t.Run("using AllColumns()", func(t *testing.T) { + stmtJsonAllColumns := SELECT_JSON_ARR( + subQuery.AllColumns(), + ).FROM( + subQuery, + ) + + require.Equal(t, stmtJson.DebugSql(), stmtJsonAllColumns.DebugSql()) + }) + + testutils.AssertJsonEqual(t, dest, destJson) + }) +} + var toInsert = model.AllTypes{ Boolean: false, BooleanPtr: ptr.Of(true), @@ -1131,6 +1411,7 @@ var toInsert = model.AllTypes{ var allTypesJson = ` [ { + "ID": 1, "Boolean": false, "BooleanPtr": true, "TinyInt": -3, @@ -1195,6 +1476,7 @@ var allTypesJson = ` "JSONPtr": "{\"key1\": \"value1\", \"key2\": \"value2\"}" }, { + "ID": 2, "Boolean": false, "BooleanPtr": null, "TinyInt": -3, diff --git a/tests/mysql/bench_test.go b/tests/mysql/bench_test.go new file mode 100644 index 0000000..e5abe5c --- /dev/null +++ b/tests/mysql/bench_test.go @@ -0,0 +1,138 @@ +//go:build bench +// +build bench + +package mysql + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/mysql" + "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table" + "github.com/stretchr/testify/require" + "testing" +) + +type allInfo []struct { + model.Actor + + Films []struct { + model.Film + + Language model.Language + Categories []model.Category + + Inventories []struct { + model.Inventory + + Rentals []struct { + model.Rental + + Customer model.Customer + } + } + } +} + +func BenchmarkTestDVDsJoinEverything(b *testing.B) { + for i := 0; i < b.N; i++ { + testDVDsJoinEverything(b) + } +} + +func TestDVDsJoinEverything(t *testing.T) { + testDVDsJoinEverything(t) +} + +func testDVDsJoinEverything(t require.TestingT) { + stmt := SELECT( + Actor.AllColumns, + Film.AllColumns, + Language.AllColumns, + Category.AllColumns, + Inventory.AllColumns, + Rental.AllColumns, + Customer.AllColumns, + ).FROM( + Actor. + LEFT_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.ActorID)). + LEFT_JOIN(Film, Film.FilmID.EQ(FilmActor.FilmID)). + LEFT_JOIN(Language, Language.LanguageID.EQ(Film.LanguageID)). + LEFT_JOIN(FilmCategory, FilmCategory.FilmID.EQ(Film.FilmID)). + LEFT_JOIN(Category, Category.CategoryID.EQ(FilmCategory.CategoryID)). + LEFT_JOIN(Inventory, Inventory.FilmID.EQ(Film.FilmID)). + LEFT_JOIN(Rental, Rental.InventoryID.EQ(Inventory.InventoryID)). + LEFT_JOIN(Customer, Customer.CustomerID.EQ(Rental.CustomerID)), + ).ORDER_BY( + Actor.ActorID.ASC(), + Film.FilmID.ASC(), + Category.CategoryID.ASC(), + Inventory.InventoryID.ASC(), + Rental.RentalID.ASC(), + ) + + var dest allInfo + + err := stmt.Query(db, &dest) + require.NoError(t, err) + + //testutils.SaveJSONFile(dest, "./testdata/results/mysql/dvds_join_everything.json") + testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/dvds_join_everything.json") +} + +func BenchmarkTestDVDsJoinEverythingJSON(b *testing.B) { + for i := 0; i < b.N; i++ { + testDVDsJoinEverythingJSON(b) + } +} + +func TestDVDsJoinEverythingJSON(t *testing.T) { + testDVDsJoinEverythingJSON(t) +} + +func testDVDsJoinEverythingJSON(t require.TestingT) { + stmt := SELECT_JSON_ARR( + Actor.ActorID, Actor.FirstName, Actor.LastName, Actor.LastUpdate, + + SELECT_JSON_ARR( + Film.AllColumns, + + SELECT_JSON_OBJ(Language.AllColumns). + FROM(Language). + WHERE(Language.LanguageID.EQ(Film.LanguageID)).AS("Language"), + + SELECT_JSON_ARR(Category.AllColumns). + FROM(Category.INNER_JOIN(FilmCategory, FilmCategory.CategoryID.EQ(Category.CategoryID))). + WHERE(FilmCategory.FilmID.EQ(Film.FilmID)).AS("Categories"), + + SELECT_JSON_ARR( + Inventory.AllColumns, + + SELECT_JSON_ARR( + Rental.AllColumns, + + SELECT_JSON_OBJ(Customer.AllColumns). + FROM(Customer). + WHERE(Customer.CustomerID.EQ(Rental.CustomerID)).AS("Customer"), + ).FROM(Rental). + WHERE(Rental.InventoryID.EQ(Inventory.InventoryID)). + ORDER_BY(Rental.RentalID).AS("Rentals"), + ).FROM(Inventory). + WHERE(Inventory.FilmID.EQ(Film.FilmID)). + ORDER_BY(Inventory.InventoryID).AS("Inventories"), + ).FROM(Film. + INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(Film.FilmID)), + ).WHERE(FilmActor.ActorID.EQ(Actor.ActorID)). + ORDER_BY(Film.FilmID.ASC()).AS("Films"), + ).FROM(Actor). + ORDER_BY(Actor.ActorID.ASC()) + + //fmt.Println(stmt.DebugSql()) + + var dest allInfo + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) + + //testutils.SaveJSONFile(dest, "./testdata/results/mysql/dvds_join_everything2.json") + testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/dvds_join_everything.json") +} diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index 800de18..c037c6f 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -3,6 +3,7 @@ package mysql import ( "os" "os/exec" + "path/filepath" "strconv" "testing" @@ -304,7 +305,7 @@ func newLinkTableImpl(schemaName, tableName, alias string) linkTable { DescriptionColumn = mysql.StringColumn("description") allColumns = mysql.ColumnList{IDColumn, URLColumn, NameColumn, DescriptionColumn} mutableColumns = mysql.ColumnList{URLColumn, NameColumn, DescriptionColumn} - defaultColumns = mysql.ColumnList{DescriptionColumn} + defaultColumns = mysql.ColumnList{} ) return linkTable{ @@ -606,3 +607,398 @@ func UseSchema(schema string) { StaffList = StaffList.FromSchema(schema) } ` + +func TestGeneratedTestSampleDatabase(t *testing.T) { + + enumDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/enum/") + modelDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/model/") + tableDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/table/") + viewDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/view/") + + testutils.AssertFileNamesEqual(t, enumDir, "all_types_enum.go", "all_types_enum_ptr.go", + "all_types_view_enum.go", "all_types_view_enum_ptr.go") + testutils.AssertFileContent(t, enumDir+"/all_types_enum.go", allTypesEnum) + + testutils.AssertFileNamesEqual(t, modelDir, "all_types.go", "all_types_enum.go", "all_types_enum_ptr.go", + "all_types_view.go", "all_types_view_enum.go", "all_types_view_enum_ptr.go", "link.go", "link2.go", + "floats.go", "user.go") + + testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent) + + testutils.AssertFileNamesEqual(t, tableDir, "all_types.go", + "link.go", "link2.go", "user.go", "floats.go", "table_use_schema.go") + testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent) + + testutils.AssertFileNamesEqual(t, viewDir, "all_types_view.go", "view_use_schema.go") +} + +var allTypesEnum = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package enum + +import "github.com/go-jet/jet/v2/mysql" + +var AllTypesEnum = &struct { + Value1 mysql.StringExpression + Value2 mysql.StringExpression + Value3 mysql.StringExpression +}{ + Value1: mysql.NewEnumValue("value1"), + Value2: mysql.NewEnumValue("value2"), + Value3: mysql.NewEnumValue("value3"), +} +` + +var allTypesModelContent = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package model + +import ( + "time" +) + +type AllTypes struct { + ID int32 ` + "`" + `sql:"primary_key"` + "`" + ` + Boolean bool + BooleanPtr *bool + TinyInt int8 + UTinyInt uint8 + SmallInt int16 + USmallInt uint16 + MediumInt int32 + UMediumInt uint32 + Integer int32 + UInteger uint32 + BigInt int64 + UBigInt uint64 + TinyIntPtr *int8 + UTinyIntPtr *uint8 + SmallIntPtr *int16 + USmallIntPtr *uint16 + MediumIntPtr *int32 + UMediumIntPtr *uint32 + IntegerPtr *int32 + UIntegerPtr *uint32 + BigIntPtr *int64 + UBigIntPtr *uint64 + Decimal float64 + DecimalPtr *float64 + Numeric float64 + NumericPtr *float64 + Float float64 + FloatPtr *float64 + Double float64 + DoublePtr *float64 + Real float64 + RealPtr *float64 + Bit string + BitPtr *string + Time time.Time + TimePtr *time.Time + Date time.Time + DatePtr *time.Time + DateTime time.Time + DateTimePtr *time.Time + Timestamp time.Time + TimestampPtr *time.Time + Year int16 + YearPtr *int16 + Char string + CharPtr *string + VarChar string + VarCharPtr *string + Binary []byte + BinaryPtr *[]byte + VarBinary []byte + VarBinaryPtr *[]byte + Blob []byte + BlobPtr *[]byte + Text string + TextPtr *string + Enum AllTypesEnum + EnumPtr *AllTypesEnumPtr + Set string + SetPtr *string + JSON string + JSONPtr *string +} +` + +var allTypesTableContent = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package table + +import ( + "github.com/go-jet/jet/v2/mysql" +) + +var AllTypes = newAllTypesTable("test_sample", "all_types", "") + +type allTypesTable struct { + mysql.Table + + // Columns + ID mysql.ColumnInteger + Boolean mysql.ColumnBool + BooleanPtr mysql.ColumnBool + TinyInt mysql.ColumnInteger + UTinyInt mysql.ColumnInteger + SmallInt mysql.ColumnInteger + USmallInt mysql.ColumnInteger + MediumInt mysql.ColumnInteger + UMediumInt mysql.ColumnInteger + Integer mysql.ColumnInteger + UInteger mysql.ColumnInteger + BigInt mysql.ColumnInteger + UBigInt mysql.ColumnInteger + TinyIntPtr mysql.ColumnInteger + UTinyIntPtr mysql.ColumnInteger + SmallIntPtr mysql.ColumnInteger + USmallIntPtr mysql.ColumnInteger + MediumIntPtr mysql.ColumnInteger + UMediumIntPtr mysql.ColumnInteger + IntegerPtr mysql.ColumnInteger + UIntegerPtr mysql.ColumnInteger + BigIntPtr mysql.ColumnInteger + UBigIntPtr mysql.ColumnInteger + Decimal mysql.ColumnFloat + DecimalPtr mysql.ColumnFloat + Numeric mysql.ColumnFloat + NumericPtr mysql.ColumnFloat + Float mysql.ColumnFloat + FloatPtr mysql.ColumnFloat + Double mysql.ColumnFloat + DoublePtr mysql.ColumnFloat + Real mysql.ColumnFloat + RealPtr mysql.ColumnFloat + Bit mysql.ColumnString + BitPtr mysql.ColumnString + Time mysql.ColumnTime + TimePtr mysql.ColumnTime + Date mysql.ColumnDate + DatePtr mysql.ColumnDate + DateTime mysql.ColumnTimestamp + DateTimePtr mysql.ColumnTimestamp + Timestamp mysql.ColumnTimestamp + TimestampPtr mysql.ColumnTimestamp + Year mysql.ColumnInteger + YearPtr mysql.ColumnInteger + Char mysql.ColumnString + CharPtr mysql.ColumnString + VarChar mysql.ColumnString + VarCharPtr mysql.ColumnString + Binary mysql.ColumnBlob + BinaryPtr mysql.ColumnBlob + VarBinary mysql.ColumnBlob + VarBinaryPtr mysql.ColumnBlob + Blob mysql.ColumnBlob + BlobPtr mysql.ColumnBlob + Text mysql.ColumnString + TextPtr mysql.ColumnString + Enum mysql.ColumnString + EnumPtr mysql.ColumnString + Set mysql.ColumnString + SetPtr mysql.ColumnString + JSON mysql.ColumnString + JSONPtr mysql.ColumnString + + AllColumns mysql.ColumnList + MutableColumns mysql.ColumnList + DefaultColumns mysql.ColumnList +} + +type AllTypesTable struct { + allTypesTable + + NEW allTypesTable +} + +// AS creates new AllTypesTable with assigned alias +func (a AllTypesTable) AS(alias string) *AllTypesTable { + return newAllTypesTable(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new AllTypesTable with assigned schema name +func (a AllTypesTable) FromSchema(schemaName string) *AllTypesTable { + return newAllTypesTable(schemaName, a.TableName(), a.Alias()) +} + +// WithPrefix creates new AllTypesTable with assigned table prefix +func (a AllTypesTable) WithPrefix(prefix string) *AllTypesTable { + return newAllTypesTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new AllTypesTable with assigned table suffix +func (a AllTypesTable) WithSuffix(suffix string) *AllTypesTable { + return newAllTypesTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + +func newAllTypesTable(schemaName, tableName, alias string) *AllTypesTable { + return &AllTypesTable{ + allTypesTable: newAllTypesTableImpl(schemaName, tableName, alias), + NEW: newAllTypesTableImpl("", "new", ""), + } +} + +func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable { + var ( + IDColumn = mysql.IntegerColumn("id") + BooleanColumn = mysql.BoolColumn("boolean") + BooleanPtrColumn = mysql.BoolColumn("boolean_ptr") + TinyIntColumn = mysql.IntegerColumn("tiny_int") + UTinyIntColumn = mysql.IntegerColumn("u_tiny_int") + SmallIntColumn = mysql.IntegerColumn("small_int") + USmallIntColumn = mysql.IntegerColumn("u_small_int") + MediumIntColumn = mysql.IntegerColumn("medium_int") + UMediumIntColumn = mysql.IntegerColumn("u_medium_int") + IntegerColumn = mysql.IntegerColumn("integer") + UIntegerColumn = mysql.IntegerColumn("u_integer") + BigIntColumn = mysql.IntegerColumn("big_int") + UBigIntColumn = mysql.IntegerColumn("u_big_int") + TinyIntPtrColumn = mysql.IntegerColumn("tiny_int_ptr") + UTinyIntPtrColumn = mysql.IntegerColumn("u_tiny_int_ptr") + SmallIntPtrColumn = mysql.IntegerColumn("small_int_ptr") + USmallIntPtrColumn = mysql.IntegerColumn("u_small_int_ptr") + MediumIntPtrColumn = mysql.IntegerColumn("medium_int_ptr") + UMediumIntPtrColumn = mysql.IntegerColumn("u_medium_int_ptr") + IntegerPtrColumn = mysql.IntegerColumn("integer_ptr") + UIntegerPtrColumn = mysql.IntegerColumn("u_integer_ptr") + BigIntPtrColumn = mysql.IntegerColumn("big_int_ptr") + UBigIntPtrColumn = mysql.IntegerColumn("u_big_int_ptr") + DecimalColumn = mysql.FloatColumn("decimal") + DecimalPtrColumn = mysql.FloatColumn("decimal_ptr") + NumericColumn = mysql.FloatColumn("numeric") + NumericPtrColumn = mysql.FloatColumn("numeric_ptr") + FloatColumn = mysql.FloatColumn("float") + FloatPtrColumn = mysql.FloatColumn("float_ptr") + DoubleColumn = mysql.FloatColumn("double") + DoublePtrColumn = mysql.FloatColumn("double_ptr") + RealColumn = mysql.FloatColumn("real") + RealPtrColumn = mysql.FloatColumn("real_ptr") + BitColumn = mysql.StringColumn("bit") + BitPtrColumn = mysql.StringColumn("bit_ptr") + TimeColumn = mysql.TimeColumn("time") + TimePtrColumn = mysql.TimeColumn("time_ptr") + DateColumn = mysql.DateColumn("date") + DatePtrColumn = mysql.DateColumn("date_ptr") + DateTimeColumn = mysql.TimestampColumn("date_time") + DateTimePtrColumn = mysql.TimestampColumn("date_time_ptr") + TimestampColumn = mysql.TimestampColumn("timestamp") + TimestampPtrColumn = mysql.TimestampColumn("timestamp_ptr") + YearColumn = mysql.IntegerColumn("year") + YearPtrColumn = mysql.IntegerColumn("year_ptr") + CharColumn = mysql.StringColumn("char") + CharPtrColumn = mysql.StringColumn("char_ptr") + VarCharColumn = mysql.StringColumn("var_char") + VarCharPtrColumn = mysql.StringColumn("var_char_ptr") + BinaryColumn = mysql.BlobColumn("binary") + BinaryPtrColumn = mysql.BlobColumn("binary_ptr") + VarBinaryColumn = mysql.BlobColumn("var_binary") + VarBinaryPtrColumn = mysql.BlobColumn("var_binary_ptr") + BlobColumn = mysql.BlobColumn("blob") + BlobPtrColumn = mysql.BlobColumn("blob_ptr") + TextColumn = mysql.StringColumn("text") + TextPtrColumn = mysql.StringColumn("text_ptr") + EnumColumn = mysql.StringColumn("enum") + EnumPtrColumn = mysql.StringColumn("enum_ptr") + SetColumn = mysql.StringColumn("set") + SetPtrColumn = mysql.StringColumn("set_ptr") + JSONColumn = mysql.StringColumn("json") + JSONPtrColumn = mysql.StringColumn("json_ptr") + allColumns = mysql.ColumnList{IDColumn, BooleanColumn, BooleanPtrColumn, TinyIntColumn, UTinyIntColumn, SmallIntColumn, USmallIntColumn, MediumIntColumn, UMediumIntColumn, IntegerColumn, UIntegerColumn, BigIntColumn, UBigIntColumn, TinyIntPtrColumn, UTinyIntPtrColumn, SmallIntPtrColumn, USmallIntPtrColumn, MediumIntPtrColumn, UMediumIntPtrColumn, IntegerPtrColumn, UIntegerPtrColumn, BigIntPtrColumn, UBigIntPtrColumn, DecimalColumn, DecimalPtrColumn, NumericColumn, NumericPtrColumn, FloatColumn, FloatPtrColumn, DoubleColumn, DoublePtrColumn, RealColumn, RealPtrColumn, BitColumn, BitPtrColumn, TimeColumn, TimePtrColumn, DateColumn, DatePtrColumn, DateTimeColumn, DateTimePtrColumn, TimestampColumn, TimestampPtrColumn, YearColumn, YearPtrColumn, CharColumn, CharPtrColumn, VarCharColumn, VarCharPtrColumn, BinaryColumn, BinaryPtrColumn, VarBinaryColumn, VarBinaryPtrColumn, BlobColumn, BlobPtrColumn, TextColumn, TextPtrColumn, EnumColumn, EnumPtrColumn, SetColumn, SetPtrColumn, JSONColumn, JSONPtrColumn} + mutableColumns = mysql.ColumnList{BooleanColumn, BooleanPtrColumn, TinyIntColumn, UTinyIntColumn, SmallIntColumn, USmallIntColumn, MediumIntColumn, UMediumIntColumn, IntegerColumn, UIntegerColumn, BigIntColumn, UBigIntColumn, TinyIntPtrColumn, UTinyIntPtrColumn, SmallIntPtrColumn, USmallIntPtrColumn, MediumIntPtrColumn, UMediumIntPtrColumn, IntegerPtrColumn, UIntegerPtrColumn, BigIntPtrColumn, UBigIntPtrColumn, DecimalColumn, DecimalPtrColumn, NumericColumn, NumericPtrColumn, FloatColumn, FloatPtrColumn, DoubleColumn, DoublePtrColumn, RealColumn, RealPtrColumn, BitColumn, BitPtrColumn, TimeColumn, TimePtrColumn, DateColumn, DatePtrColumn, DateTimeColumn, DateTimePtrColumn, TimestampColumn, TimestampPtrColumn, YearColumn, YearPtrColumn, CharColumn, CharPtrColumn, VarCharColumn, VarCharPtrColumn, BinaryColumn, BinaryPtrColumn, VarBinaryColumn, VarBinaryPtrColumn, BlobColumn, BlobPtrColumn, TextColumn, TextPtrColumn, EnumColumn, EnumPtrColumn, SetColumn, SetPtrColumn, JSONColumn, JSONPtrColumn} + defaultColumns = mysql.ColumnList{BooleanColumn, TinyIntColumn, UTinyIntColumn, SmallIntColumn, USmallIntColumn, MediumIntColumn, UMediumIntColumn, IntegerColumn, UIntegerColumn, BigIntColumn, UBigIntColumn, DecimalColumn, NumericColumn, FloatColumn, DoubleColumn, RealColumn, BitColumn, TimeColumn, DateColumn, DateTimeColumn, TimestampColumn, YearColumn, CharColumn, VarCharColumn, BinaryColumn, VarBinaryColumn, EnumColumn, SetColumn} + ) + + return allTypesTable{ + Table: mysql.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns + ID: IDColumn, + Boolean: BooleanColumn, + BooleanPtr: BooleanPtrColumn, + TinyInt: TinyIntColumn, + UTinyInt: UTinyIntColumn, + SmallInt: SmallIntColumn, + USmallInt: USmallIntColumn, + MediumInt: MediumIntColumn, + UMediumInt: UMediumIntColumn, + Integer: IntegerColumn, + UInteger: UIntegerColumn, + BigInt: BigIntColumn, + UBigInt: UBigIntColumn, + TinyIntPtr: TinyIntPtrColumn, + UTinyIntPtr: UTinyIntPtrColumn, + SmallIntPtr: SmallIntPtrColumn, + USmallIntPtr: USmallIntPtrColumn, + MediumIntPtr: MediumIntPtrColumn, + UMediumIntPtr: UMediumIntPtrColumn, + IntegerPtr: IntegerPtrColumn, + UIntegerPtr: UIntegerPtrColumn, + BigIntPtr: BigIntPtrColumn, + UBigIntPtr: UBigIntPtrColumn, + Decimal: DecimalColumn, + DecimalPtr: DecimalPtrColumn, + Numeric: NumericColumn, + NumericPtr: NumericPtrColumn, + Float: FloatColumn, + FloatPtr: FloatPtrColumn, + Double: DoubleColumn, + DoublePtr: DoublePtrColumn, + Real: RealColumn, + RealPtr: RealPtrColumn, + Bit: BitColumn, + BitPtr: BitPtrColumn, + Time: TimeColumn, + TimePtr: TimePtrColumn, + Date: DateColumn, + DatePtr: DatePtrColumn, + DateTime: DateTimeColumn, + DateTimePtr: DateTimePtrColumn, + Timestamp: TimestampColumn, + TimestampPtr: TimestampPtrColumn, + Year: YearColumn, + YearPtr: YearPtrColumn, + Char: CharColumn, + CharPtr: CharPtrColumn, + VarChar: VarCharColumn, + VarCharPtr: VarCharPtrColumn, + Binary: BinaryColumn, + BinaryPtr: BinaryPtrColumn, + VarBinary: VarBinaryColumn, + VarBinaryPtr: VarBinaryPtrColumn, + Blob: BlobColumn, + BlobPtr: BlobPtrColumn, + Text: TextColumn, + TextPtr: TextPtrColumn, + Enum: EnumColumn, + EnumPtr: EnumPtrColumn, + Set: SetColumn, + SetPtr: SetPtrColumn, + JSON: JSONColumn, + JSONPtr: JSONPtrColumn, + + AllColumns: allColumns, + MutableColumns: mutableColumns, + DefaultColumns: defaultColumns, + } +} +` diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index e4f4322..1d6f68d 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -8,6 +8,7 @@ import ( jetmysql "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/stmtcache" "github.com/go-jet/jet/v2/tests/dbconfig" + "github.com/go-jet/jet/v2/tests/internal/utils/repo" _ "github.com/go-sql-driver/mysql" "github.com/stretchr/testify/require" "runtime" @@ -21,12 +22,14 @@ var db *stmtcache.DB var source string var withStatementCaching bool +var testRoot string const MariaDB = "MariaDB" func init() { source = os.Getenv("MY_SQL_SOURCE") withStatementCaching = os.Getenv("JET_TESTS_WITH_STMT_CACHE") == "true" + testRoot = repo.GetTestsDirPath() } func sourceIsMariaDB() bool { diff --git a/tests/mysql/select_json_test.go b/tests/mysql/select_json_test.go new file mode 100644 index 0000000..cea3328 --- /dev/null +++ b/tests/mysql/select_json_test.go @@ -0,0 +1,453 @@ +package mysql + +import ( + "context" + "fmt" + "github.com/go-jet/jet/v2/qrm" + "strings" + "testing" + + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/mysql" + "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table" + + "github.com/stretchr/testify/require" +) + +var ctx = context.Background() + +func TestSelectJsonObj(t *testing.T) { + stmt := SELECT_JSON_OBJ(Actor.AllColumns). + FROM(Actor). + WHERE(Actor.ActorID.EQ(Int(2))) + + testutils.AssertStatementSql(t, stmt, ` +SELECT JSON_OBJECT( + 'actorID', actor.actor_id, + 'firstName', actor.first_name, + 'lastName', actor.last_name, + 'lastUpdate', DATE_FORMAT(actor.last_update,'%Y-%m-%dT%H:%i:%s.%fZ') + ) AS "json" +FROM dvds.actor +WHERE actor.actor_id = ?; +`, int64(2)) + + var dest model.Actor + + err := stmt.Query(db, &dest) + require.Nil(t, err) + + testutils.AssertDeepEqual(t, dest, actor2) + requireLogged(t, stmt) + requireQueryLogged(t, stmt, 1) +} + +func TestSelectJsonObj_NestedObj(t *testing.T) { + stmt := SELECT_JSON_OBJ( + Actor.AllColumns, + + SELECT_JSON_OBJ(Film.AllColumns). + FROM(FilmActor.INNER_JOIN(Film, Film.FilmID.EQ(FilmActor.FilmID))). + WHERE(Actor.ActorID.EQ(FilmActor.ActorID)). + ORDER_BY(Film.Length.DESC()). + LIMIT(1).OFFSET(3).AS("LongestFilm"), + ).FROM( + Actor, + ).WHERE( + Actor.ActorID.EQ(Int(2)), + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT JSON_OBJECT( + 'actorID', actor.actor_id, + 'firstName', actor.first_name, + 'lastName', actor.last_name, + 'lastUpdate', DATE_FORMAT(actor.last_update,'%Y-%m-%dT%H:%i:%s.%fZ'), + 'LongestFilm', ( + SELECT JSON_OBJECT( + 'filmID', film.film_id, + 'title', film.title, + 'description', film.description, + 'releaseYear', film.release_year, + 'languageID', film.language_id, + 'originalLanguageID', film.original_language_id, + 'rentalDuration', film.rental_duration, + 'rentalRate', film.rental_rate, + 'length', film.length, + 'replacementCost', film.replacement_cost, + 'rating', film.rating, + 'specialFeatures', film.special_features, + 'lastUpdate', DATE_FORMAT(film.last_update,'%Y-%m-%dT%H:%i:%s.%fZ') + ) AS "json" + FROM dvds.film_actor + INNER JOIN dvds.film ON (film.film_id = film_actor.film_id) + WHERE actor.actor_id = film_actor.actor_id + ORDER BY film.length DESC + LIMIT ? + OFFSET ? + ) + ) AS "json" +FROM dvds.actor +WHERE actor.actor_id = ?; +`) + + var dest struct { + model.Actor + + LongestFilm model.Film + } + + err := stmt.QueryContext(ctx, db, &dest) + require.Nil(t, err) + testutils.AssertJSON(t, dest, ` +{ + "ActorID": 2, + "FirstName": "NICK", + "LastName": "WAHLBERG", + "LastUpdate": "2006-02-15T04:34:33Z", + "LongestFilm": { + "FilmID": 754, + "Title": "RUSHMORE MERMAID", + "Description": "A Boring Story of a Woman And a Moose who must Reach a Husband in A Shark Tank", + "ReleaseYear": 2006, + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 6, + "RentalRate": 2.99, + "Length": 150, + "ReplacementCost": 17.99, + "Rating": "PG-13", + "SpecialFeatures": "Trailers,Commentaries,Deleted Scenes", + "LastUpdate": "2006-02-15T05:03:42Z" + } +} +`) + +} + +func TestSelectJsonArr(t *testing.T) { + stmt := SELECT_JSON_ARR(Actor.AllColumns). + FROM(Actor). + ORDER_BY(Actor.ActorID) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT JSON_ARRAYAGG(JSON_OBJECT( + 'actorID', actor.actor_id, + 'firstName', actor.first_name, + 'lastName', actor.last_name, + 'lastUpdate', DATE_FORMAT(actor.last_update,'%Y-%m-%dT%H:%i:%s.%fZ') + )) AS "json" +FROM dvds.actor +ORDER BY actor.actor_id; +`) + + var dest []model.Actor + + err := stmt.Query(db, &dest) + require.Nil(t, err) + + testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/all_actors.json") + requireLogged(t, stmt) + requireQueryLogged(t, stmt, 1) +} + +func TestSelectJsonArr_NestedArr(t *testing.T) { + stmt := SELECT_JSON_ARR( + Actor.AllColumns, + + SELECT_JSON_ARR( + Film.AllColumns, + ).FROM( + FilmActor.INNER_JOIN( + Film, + Film.FilmID.EQ(FilmActor.FilmID).AND( + Actor.ActorID.EQ(FilmActor.ActorID)), + ), + ).WHERE( + Film.FilmID.MOD(Int(17)).EQ(Int(0)), + ).ORDER_BY( + Film.Length.DESC(), + ).AS("Films"), + ).FROM( + Actor, + ).WHERE( + Actor.ActorID.BETWEEN(Int(1), Int(3)), + ).ORDER_BY( + Actor.ActorID, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT JSON_ARRAYAGG(JSON_OBJECT( + 'actorID', actor.actor_id, + 'firstName', actor.first_name, + 'lastName', actor.last_name, + 'lastUpdate', DATE_FORMAT(actor.last_update,'%Y-%m-%dT%H:%i:%s.%fZ'), + 'Films', ( + SELECT JSON_ARRAYAGG(JSON_OBJECT( + 'filmID', film.film_id, + 'title', film.title, + 'description', film.description, + 'releaseYear', film.release_year, + 'languageID', film.language_id, + 'originalLanguageID', film.original_language_id, + 'rentalDuration', film.rental_duration, + 'rentalRate', film.rental_rate, + 'length', film.length, + 'replacementCost', film.replacement_cost, + 'rating', film.rating, + 'specialFeatures', film.special_features, + 'lastUpdate', DATE_FORMAT(film.last_update,'%Y-%m-%dT%H:%i:%s.%fZ') + )) AS "json" + FROM dvds.film_actor + INNER JOIN dvds.film ON ((film.film_id = film_actor.film_id) AND (actor.actor_id = film_actor.actor_id)) + WHERE (film.film_id % 17) = 0 + ORDER BY film.length DESC + ) + )) AS "json" +FROM dvds.actor +WHERE actor.actor_id BETWEEN 1 AND 3 +ORDER BY actor.actor_id; +`) + + var dest []struct { + model.Actor + + Films []model.Film + } + + err := stmt.QueryContext(ctx, db, &dest) + fmt.Println(err) + require.Nil(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "ActorID": 1, + "FirstName": "PENELOPE", + "LastName": "GUINESS", + "LastUpdate": "2006-02-15T04:34:33Z", + "Films": null + }, + { + "ActorID": 2, + "FirstName": "NICK", + "LastName": "WAHLBERG", + "LastUpdate": "2006-02-15T04:34:33Z", + "Films": [ + { + "FilmID": 357, + "Title": "GILBERT PELICAN", + "Description": "A Fateful Tale of a Man And a Feminist who must Conquer a Crocodile in A Manhattan Penthouse", + "ReleaseYear": 2006, + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 7, + "RentalRate": 0.99, + "Length": 114, + "ReplacementCost": 13.99, + "Rating": "G", + "SpecialFeatures": "Trailers,Commentaries", + "LastUpdate": "2006-02-15T05:03:42Z" + }, + { + "FilmID": 561, + "Title": "MASK PEACH", + "Description": "A Boring Character Study of a Student And a Robot who must Meet a Woman in California", + "ReleaseYear": 2006, + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 6, + "RentalRate": 2.99, + "Length": 123, + "ReplacementCost": 26.99, + "Rating": "NC-17", + "SpecialFeatures": "Commentaries,Deleted Scenes", + "LastUpdate": "2006-02-15T05:03:42Z" + } + ] + }, + { + "ActorID": 3, + "FirstName": "ED", + "LastName": "CHASE", + "LastUpdate": "2006-02-15T04:34:33Z", + "Films": [ + { + "FilmID": 17, + "Title": "ALONE TRIP", + "Description": "A Fast-Paced Character Study of a Composer And a Dog who must Outgun a Boat in An Abandoned Fun House", + "ReleaseYear": 2006, + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 3, + "RentalRate": 0.99, + "Length": 82, + "ReplacementCost": 14.99, + "Rating": "R", + "SpecialFeatures": "Trailers,Behind the Scenes", + "LastUpdate": "2006-02-15T05:03:42Z" + }, + { + "FilmID": 289, + "Title": "EVE RESURRECTION", + "Description": "A Awe-Inspiring Yarn of a Pastry Chef And a Database Administrator who must Challenge a Teacher in A Baloon", + "ReleaseYear": 2006, + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 5, + "RentalRate": 4.99, + "Length": 66, + "ReplacementCost": 25.99, + "Rating": "G", + "SpecialFeatures": "Trailers,Commentaries,Deleted Scenes", + "LastUpdate": "2006-02-15T05:03:42Z" + } + ] + } +] +`) + +} + +func TestSelectJson_GroupBy(t *testing.T) { + skipForMariaDB(t) // scope issues with select without FROM + + subQuery := SELECT( + Customer.AllColumns, + + SUM(Payment.Amount).AS("sum"), + AVG(Payment.Amount).AS("avg"), + MAX(Payment.Amount).AS("max"), + MIN(Payment.Amount).AS("min"), + COUNT(Payment.Amount).AS("count"), + ).FROM( + Payment. + INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)), + ).GROUP_BY( + Customer.CustomerID, + ).HAVING( + SUMf(Payment.Amount).GT(Float(125)), + ).ORDER_BY( + Customer.CustomerID, SUM(Payment.Amount).ASC(), + ).AsTable("customers_info") + + stmt := SELECT_JSON_ARR( + subQuery.AllColumns().Except( // TODO: remove when ColumnList.From() is implemented + FloatColumn("sum"), + FloatColumn("avg"), + FloatColumn("max"), + FloatColumn("min"), + FloatColumn("count"), + ), + + SELECT_JSON_OBJ( + FloatColumn("sum").From(subQuery), + FloatColumn("avg").From(subQuery), + FloatColumn("max").From(subQuery), + FloatColumn("min").From(subQuery), + FloatColumn("count").From(subQuery), + ).AS("amount"), + ).FROM(subQuery) + + testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(` +SELECT JSON_ARRAYAGG(JSON_OBJECT( + 'customerID', customers_info.''customer.customer_id'', + 'storeID', customers_info.''customer.store_id'', + 'firstName', customers_info.''customer.first_name'', + 'lastName', customers_info.''customer.last_name'', + 'email', customers_info.''customer.email'', + 'addressID', customers_info.''customer.address_id'', + 'active', customers_info.''customer.active'' = 1, + 'createDate', DATE_FORMAT(customers_info.''customer.create_date'','%Y-%m-%dT%H:%i:%s.%fZ'), + 'lastUpdate', DATE_FORMAT(customers_info.''customer.last_update'','%Y-%m-%dT%H:%i:%s.%fZ'), + 'amount', ( + SELECT JSON_OBJECT( + 'sum', customers_info.sum, + 'avg', customers_info.avg, + 'max', customers_info.max, + 'min', customers_info.min, + 'count', customers_info.count + ) AS "json" + ) + )) AS "json" +FROM ( + SELECT customer.customer_id AS "customer.customer_id", + customer.store_id AS "customer.store_id", + customer.first_name AS "customer.first_name", + customer.last_name AS "customer.last_name", + customer.email AS "customer.email", + customer.address_id AS "customer.address_id", + customer.active AS "customer.active", + customer.create_date AS "customer.create_date", + customer.last_update AS "customer.last_update", + SUM(payment.amount) AS "sum", + AVG(payment.amount) AS "avg", + MAX(payment.amount) AS "max", + MIN(payment.amount) AS "min", + COUNT(payment.amount) AS "count" + FROM dvds.payment + INNER JOIN dvds.customer ON (customer.customer_id = payment.customer_id) + GROUP BY customer.customer_id + HAVING SUM(payment.amount) > 125 + ORDER BY customer.customer_id, SUM(payment.amount) ASC + ) AS customers_info; +`, "''", "`")) + + var dest []struct { + model.Customer + + Amount struct { + Sum float64 + Avg float64 + Max float64 + Min float64 + Count int64 + } + } + + err := stmt.QueryContext(ctx, db, &dest) + require.Nil(t, err) + + testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/customer_payment_sum.json") + requireLogged(t, stmt) +} + +func TestSelectJsonObject_EmptyResult(t *testing.T) { + + t.Run("json obj", func(t *testing.T) { + stmt := SELECT_JSON_OBJ(Actor.AllColumns). + FROM(Actor). + WHERE(Actor.FirstName.EQ(String("Kowalski"))) + + var dest model.Actor + + err := stmt.QueryContext(ctx, db, &dest) + require.ErrorIs(t, err, qrm.ErrNoRows) + }) + + t.Run("json arr", func(t *testing.T) { + stmt := SELECT_JSON_ARR(Actor.AllColumns). + FROM(Actor). + WHERE(Actor.FirstName.EQ(String("Kowalski"))) + + var dest []model.Actor + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) + require.Empty(t, dest) + }) +} + +func TestSelectJson_ProjectionNotAliased(t *testing.T) { + + t.Run("expression not aliased", func(t *testing.T) { + testutils.AssertPanicErr(t, func() { + stmt := SELECT_JSON_ARR( + Int(2).ADD(Customer.CustomerID), + ).FROM(Customer) + + stmt.DebugSql() + + }, "jet: expression need to be aliased when used as SELECT JSON projection.") + }) +} diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 6e4507d..e27c683 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -19,9 +19,9 @@ import ( ) func TestSelect_ScanToStruct(t *testing.T) { - query := Actor. - SELECT(Actor.AllColumns). + query := SELECT(Actor.AllColumns). DISTINCT(). + FROM(Actor). WHERE(Actor.ActorID.EQ(Int(2))) testutils.AssertStatementSql(t, query, ` @@ -50,9 +50,56 @@ var actor2 = model.Actor{ LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 04:34:33", 2), } +func TestSelect_NestedObject(t *testing.T) { + stmt := SELECT( + Actor.AllColumns, + Film.AllColumns, + ).FROM( + Actor. + LEFT_JOIN(FilmActor, FilmActor.ActorID.EQ(Actor.ActorID)). + LEFT_JOIN(Film, Film.FilmID.EQ(FilmActor.FilmID)), + ).WHERE( + Actor.ActorID.EQ(Int(2)), + ).ORDER_BY( + Film.LastUpdate.DESC(), + ).LIMIT(1) + + var dest struct { + model.Actor + + LatestFilm model.Film + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +{ + "ActorID": 2, + "FirstName": "NICK", + "LastName": "WAHLBERG", + "LastUpdate": "2006-02-15T04:34:33Z", + "LatestFilm": { + "FilmID": 3, + "Title": "ADAPTATION HOLES", + "Description": "A Astounding Reflection of a Lumberjack And a Car who must Sink a Lumberjack in A Baloon Factory", + "ReleaseYear": 2006, + "LanguageID": 1, + "OriginalLanguageID": null, + "RentalDuration": 7, + "RentalRate": 2.99, + "Length": 50, + "ReplacementCost": 18.99, + "Rating": "NC-17", + "SpecialFeatures": "Trailers,Deleted Scenes", + "LastUpdate": "2006-02-15T05:03:42Z" + } +} +`) +} + func TestSelect_ScanToSlice(t *testing.T) { - query := Actor. - SELECT(Actor.AllColumns). + query := SELECT(Actor.AllColumns). + FROM(Actor). ORDER_BY(Actor.ActorID) testutils.AssertStatementSql(t, query, ` @@ -107,19 +154,20 @@ GROUP BY payment.customer_id HAVING SUM(payment.amount) > 125.6 ORDER BY payment.customer_id, SUM(payment.amount) ASC; ` - query := Payment. - INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)). - SELECT( - Customer.AllColumns, + query := SELECT( + Customer.AllColumns, - SUMf(Payment.Amount).AS("amount.sum"), - AVG(Payment.Amount).AS("amount.avg"), - MAX(Payment.PaymentDate).AS("amount.max_date"), - MAXf(Payment.Amount).AS("amount.max"), - MIN(Payment.PaymentDate).AS("amount.min_date"), - MINf(Payment.Amount).AS("amount.min"), - COUNT(Payment.Amount).AS("amount.count"), - ). + SUMf(Payment.Amount).AS("amount.sum"), + AVG(Payment.Amount).AS("amount.avg"), + MAX(Payment.PaymentDate).AS("amount.max_date"), + MAXf(Payment.Amount).AS("amount.max"), + MIN(Payment.PaymentDate).AS("amount.min_date"), + MINf(Payment.Amount).AS("amount.min"), + COUNT(Payment.Amount).AS("amount.count"), + ).FROM( + Payment. + INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)), + ). GROUP_BY(Payment.CustomerID). HAVING( SUMf(Payment.Amount).GT(Float(125.6)), @@ -1122,7 +1170,7 @@ WHERE payment.payment_id < ? WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id) ORDER BY payment.customer_id; ` - query := Payment.SELECT( + query := SELECT( AVG(Payment.Amount).OVER(), AVG(Payment.Amount).OVER(Window("w1")), AVG(Payment.Amount).OVER( @@ -1131,7 +1179,7 @@ ORDER BY payment.customer_id; RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)), ), AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), - ). + ).FROM(Payment). WHERE(Payment.PaymentID.LT(Int(10))). WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)). WINDOW("w2").AS(Window("w1")). diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index e720e24..c2c6a5c 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -1,6 +1,8 @@ package postgres import ( + "encoding/base64" + "fmt" "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/stretchr/testify/assert" "math" @@ -36,6 +38,141 @@ func TestAllTypesSelect(t *testing.T) { testutils.AssertDeepEqual(t, dest[1], allTypesRow1) } +func TestAllTypesSelectJson(t *testing.T) { + + stmt := SELECT_JSON_ARR( + AllTypesAllColumns.Except( + AllTypes.JSON, AllTypes.JSONPtr, + AllTypes.Jsonb, AllTypes.JsonbPtr, + AllTypes.TextArray, AllTypes.TextArrayPtr, + AllTypes.JsonbArray, AllTypes.IntegerArray, AllTypes.IntegerArrayPtr, + AllTypes.TextMultiDimArray, AllTypes.TextMultiDimArrayPtr, + ), + // unsupported at the moment, casting to text allows these columns to be assigned to string fields + CAST(AllTypes.JSONPtr).AS_TEXT().AS("jsonPtr"), + CAST(AllTypes.JSON).AS_TEXT().AS("JSON"), + CAST(AllTypes.JsonbPtr).AS_TEXT().AS("jsonbPtr"), + CAST(AllTypes.Jsonb).AS_TEXT().AS("Jsonb"), + CAST(AllTypes.TextArrayPtr).AS_TEXT().AS("TextArrayPtr"), + CAST(AllTypes.TextArray).AS_TEXT().AS("TextArray"), + CAST(AllTypes.JsonbArray).AS_TEXT().AS("JsonbArray"), + CAST(AllTypes.IntegerArray).AS_TEXT().AS("IntegerArray"), + CAST(AllTypes.IntegerArrayPtr).AS_TEXT().AS("IntegerArrayPtr"), + CAST(AllTypes.TextMultiDimArray).AS_TEXT().AS("TextMultiDimArray"), + CAST(AllTypes.TextMultiDimArrayPtr).AS_TEXT().AS("TextMultiDimArrayPtr"), + ).FROM(AllTypes) + + testutils.AssertStatementSql(t, stmt, ` +SELECT json_agg(row_to_json(records)) AS "json" +FROM ( + SELECT all_types.small_int_ptr AS "smallIntPtr", + all_types.small_int AS "smallInt", + all_types.integer_ptr AS "integerPtr", + all_types.integer AS "integer", + all_types.big_int_ptr AS "bigIntPtr", + all_types.big_int AS "bigInt", + all_types.decimal_ptr AS "decimalPtr", + all_types.decimal AS "decimal", + all_types.numeric_ptr AS "numericPtr", + all_types.numeric AS "numeric", + all_types.real_ptr AS "realPtr", + all_types.real AS "real", + all_types.double_precision_ptr AS "doublePrecisionPtr", + all_types.double_precision AS "doublePrecision", + all_types.smallserial AS "smallserial", + all_types.serial AS "serial", + all_types.bigserial AS "bigserial", + all_types.var_char_ptr AS "varCharPtr", + all_types.var_char AS "varChar", + all_types.char_ptr AS "charPtr", + all_types.char AS "char", + all_types.text_ptr AS "textPtr", + all_types.text AS "text", + ENCODE(all_types.bytea_ptr, 'base64') AS "byteaPtr", + ENCODE(all_types.bytea, 'base64') AS "bytea", + all_types.timestampz_ptr AS "timestampzPtr", + all_types.timestampz AS "timestampz", + to_char(all_types.timestamp_ptr, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestampPtr", + to_char(all_types.timestamp, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestamp", + to_char(all_types.date_ptr::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "datePtr", + to_char(all_types.date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "date", + '0000-01-01T' || to_char('2000-10-10'::date + all_types.timez_ptr, 'HH24:MI:SS.USTZH:TZM') AS "timezPtr", + '0000-01-01T' || to_char('2000-10-10'::date + all_types.timez, 'HH24:MI:SS.USTZH:TZM') AS "timez", + '0000-01-01T' || to_char('2000-10-10'::date + all_types.time_ptr, 'HH24:MI:SS.USZ') AS "timePtr", + '0000-01-01T' || to_char('2000-10-10'::date + all_types.time, 'HH24:MI:SS.USZ') AS "time", + all_types.interval_ptr AS "intervalPtr", + all_types.interval AS "interval", + all_types.boolean_ptr AS "booleanPtr", + all_types.boolean AS "boolean", + all_types.point_ptr AS "pointPtr", + all_types.bit_ptr AS "bitPtr", + all_types.bit AS "bit", + all_types.bit_varying_ptr AS "bitVaryingPtr", + all_types.bit_varying AS "bitVarying", + all_types.tsvector_ptr AS "tsvectorPtr", + all_types.tsvector AS "tsvector", + all_types.uuid_ptr AS "uuidPtr", + all_types.uuid AS "uuid", + all_types.xml_ptr AS "xmlPtr", + all_types.xml AS "xml", + all_types.mood_ptr AS "moodPtr", + all_types.mood AS "mood", + all_types.json_ptr::text AS "jsonPtr", + all_types.json::text AS "JSON", + all_types.jsonb_ptr::text AS "jsonbPtr", + all_types.jsonb::text AS "Jsonb", + all_types.text_array_ptr::text AS "TextArrayPtr", + all_types.text_array::text AS "TextArray", + all_types.jsonb_array::text AS "JsonbArray", + all_types.integer_array::text AS "IntegerArray", + all_types.integer_array_ptr::text AS "IntegerArrayPtr", + all_types.text_multi_dim_array::text AS "TextMultiDimArray", + all_types.text_multi_dim_array_ptr::text AS "TextMultiDimArrayPtr" + FROM test_sample.all_types + ) AS records; +`) + + var dest []model.AllTypes + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) + + // fix inconsistencies between postgres and cockroachdb. + // cockroachdb returns char[N] columns with trailing whitespaces trimmed + if sourceIsCockroachDB() { + dest[0].Char = allTypesRow0.Char + dest[0].CharPtr = allTypesRow0.CharPtr + + dest[1].Char = allTypesRow1.Char + dest[1].CharPtr = allTypesRow1.CharPtr + } + + minus8 := time.FixedZone("UTC", -8*60*60) + plus1 := time.FixedZone("UTC", 60*60) + + // set time local before comparison + dest[0].Timez = *toTZ(&dest[0].Timez, minus8) + dest[0].TimezPtr = toTZ(dest[0].TimezPtr, minus8) + dest[1].Timez = *toTZ(&dest[1].Timez, minus8) + dest[1].TimezPtr = toTZ(dest[1].TimezPtr, minus8) + + dest[0].Timestampz = *toTZ(&dest[0].Timestampz, plus1) + dest[0].TimestampzPtr = toTZ(dest[0].TimestampzPtr, plus1) + dest[1].Timestampz = *toTZ(&dest[1].Timestampz, plus1) + dest[1].TimestampzPtr = toTZ(dest[1].TimestampzPtr, plus1) + + testutils.AssertJsonEqual(t, dest[0], allTypesRow0) + testutils.AssertJsonEqual(t, dest[1], allTypesRow1) +} + +func toTZ(tm *time.Time, loc *time.Location) *time.Time { + if tm == nil { + return nil + } + + return ptr.Of(tm.In(loc)) +} + func TestAllTypesViewSelect(t *testing.T) { type AllTypesView model.AllTypes var dest []AllTypesView @@ -132,7 +269,7 @@ WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'; requireLogged(t, query) } -func TestBytea(t *testing.T) { +func TestByteaInsert(t *testing.T) { byteArrHex := "\\x48656c6c6f20476f7068657221" byteArrBin := []byte("\x48\x65\x6c\x6c\x6f\x20\x47\x6f\x70\x68\x65\x72\x21") @@ -147,40 +284,42 @@ RETURNING all_types.bytea AS "all_types.bytea", all_types.bytea_ptr AS "all_types.bytea_ptr"; `, byteArrHex, byteArrBin) - var inserted model.AllTypes - err := insertStmt.Query(db, &inserted) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) { + var inserted model.AllTypes + err := insertStmt.Query(tx, &inserted) + require.NoError(t, err) - require.Equal(t, string(*inserted.ByteaPtr), "Hello Gopher!") - // It is not possible to initiate bytea column using hex format '\xDEADBEEF' with pq driver. - // pq driver always encodes parameter string if destination column is of type bytea. - // Probably pq driver error. - // require.Equal(t, string(inserted.Bytea), "Hello Gopher!") + require.Equal(t, string(*inserted.ByteaPtr), "Hello Gopher!") + // It is not possible to initiate bytea column using hex format '\xDEADBEEF' with pq driver. + // pq driver always encodes parameter string if destination column is of type bytea. + // Probably pq driver error. + // require.Equal(t, string(inserted.Bytea), "Hello Gopher!") - stmt := SELECT( - AllTypes.Bytea, - AllTypes.ByteaPtr, - ).FROM( - AllTypes, - ).WHERE( - AllTypes.ByteaPtr.EQ(Bytea(byteArrBin)), - ) + stmt := SELECT( + AllTypes.Bytea, + AllTypes.ByteaPtr, + ).FROM( + AllTypes, + ).WHERE( + AllTypes.ByteaPtr.EQ(Bytea(byteArrBin)), + ) - testutils.AssertStatementSql(t, stmt, ` + testutils.AssertStatementSql(t, stmt, ` SELECT all_types.bytea AS "all_types.bytea", all_types.bytea_ptr AS "all_types.bytea_ptr" FROM test_sample.all_types WHERE all_types.bytea_ptr = $1::bytea; `, byteArrBin) - var dest model.AllTypes + var dest model.AllTypes - err = stmt.Query(db, &dest) - require.NoError(t, err) + err = stmt.Query(tx, &dest) + require.NoError(t, err) - require.Equal(t, string(*dest.ByteaPtr), "Hello Gopher!") - // Probably pq driver error. - // require.Equal(t, string(dest.Bytea), "Hello Gopher!") + require.Equal(t, string(*dest.ByteaPtr), "Hello Gopher!") + // Probably pq driver error. + // require.Equal(t, string(dest.Bytea), "Hello Gopher!") + }) } func TestAllTypesFromSubQuery(t *testing.T) { @@ -424,6 +563,7 @@ func TestExpressionCast(t *testing.T) { CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(), CAST(String("1999-01-08 04:05:06+01:00")).AS_TIMESTAMPZ(), CAST(String("04:05:06")).AS_INTERVAL(), + CAST(String("some text")).AS_BYTEA().EQ(Bytea([]byte("some text"))), func() ProjectionList { if sourceIsCockroachDB() { @@ -477,7 +617,6 @@ func TestStringOperators(t *testing.T) { AllTypes.Text.BETWEEN(String("min"), String("max")), AllTypes.Text.NOT_BETWEEN(AllTypes.VarChar, AllTypes.CharPtr), AllTypes.Text.CONCAT(String("text2")), - AllTypes.Text.CONCAT(Int(11)), AllTypes.Text.LIKE(String("abc")), AllTypes.Text.NOT_LIKE(String("_b_")), AllTypes.Text.REGEXP_LIKE(String("^t")), @@ -508,18 +647,18 @@ func TestStringOperators(t *testing.T) { CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, String("aaa"), Int(1)), CONCAT(Bool(false), Int(1), Float(22.2), String("test test")), CONCAT_WS(String("string1"), Int(1), Float(11.22), String("bytea"), Bool(false)), //Float(11.12)), - CONVERT(Bytea("bytea"), String("UTF8"), String("LATIN1")), - CONVERT(AllTypes.Bytea, String("UTF8"), String("LATIN1")), - CONVERT_FROM(Bytea("text_in_utf8"), String("UTF8")), - CONVERT_TO(String("text_in_utf8"), String("UTF8")), - ENCODE(Bytea("123\000\001"), String("base64")), - DECODE(String("MTIzAAE="), String("base64")), + CONVERT(Bytea("bytea"), UTF8, LATIN1), + CONVERT(AllTypes.Bytea, UTF8, LATIN1), + CONVERT_FROM(Bytea("text_in_utf8"), UTF8), + CONVERT_TO(String("text_in_utf8"), UTF8), + ENCODE(Bytea("some text"), Escape), + DECODE(String("MTIzAAE="), Base64), FORMAT(String("Hello %s, %1$s"), String("World")), INITCAP(String("hi THOMAS")), LEFT(String("abcde"), Int(2)), RIGHT(String("abcde"), Int(2)), LENGTH(Bytea("jose")), - LENGTH(Bytea("jose"), String("UTF8")), + LENGTH(Bytea("jose"), UTF8), LPAD(String("Hi"), Int(5)), LPAD(String("Hi"), Int(5), String("xy")), RPAD(String("Hi"), Int(5)), @@ -540,6 +679,202 @@ func TestStringOperators(t *testing.T) { require.NoError(t, err) } +func TestBytea(t *testing.T) { + + var sampleBytea = Bytea([]byte{11, 0, 22, 33, 44}) + var textBytea = Bytea([]byte("text blob")) + + stmt := SELECT( + AllTypes.Bytea.EQ(sampleBytea), + AllTypes.Bytea.EQ(AllTypes.ByteaPtr), + AllTypes.Bytea.NOT_EQ(sampleBytea), + AllTypes.Bytea.GT(textBytea), + AllTypes.Bytea.GT_EQ(AllTypes.ByteaPtr), + AllTypes.Bytea.LT(AllTypes.ByteaPtr), + AllTypes.Bytea.LT_EQ(sampleBytea), + AllTypes.Bytea.BETWEEN(Bytea([]byte("min")), Bytea([]byte("max"))), + AllTypes.Bytea.NOT_BETWEEN(AllTypes.Bytea, AllTypes.ByteaPtr), + AllTypes.Bytea.CONCAT(textBytea), + + func() ProjectionList { + if sourceIsCockroachDB() { + return ProjectionList{NULL} + } + // cockroach doesn't support currently + return ProjectionList{ + AllTypes.Bytea.LIKE(Bytea("b'%pattern%'")), + AllTypes.Bytea.NOT_LIKE(Bytea("b'%pattern%'")), + + BTRIM(AllTypes.Bytea, Bytea([]byte{33})), + RTRIM(AllTypes.ByteaPtr, sampleBytea), + LTRIM(sampleBytea, textBytea), + CONCAT(sampleBytea, AllTypes.ByteaPtr, textBytea), + BIT_COUNT(sampleBytea).EQ(Int(3)), + LENGTH(textBytea, UTF8).EQ(Int(4)), + + CONVERT(textBytea, UTF8, WIN1252), + CONVERT(AllTypes.Bytea, UTF8, LATIN1).EQ(sampleBytea), + } + }(), + + BIT_LENGTH(textBytea), + OCTET_LENGTH(textBytea), + + GET_BIT(textBytea, Int(2)).EQ(Int(23)), + GET_BYTE(sampleBytea, Int(1)).EQ(Int(0)), + SET_BIT(textBytea, Int(1), Int(0)).EQ(sampleBytea), + SET_BYTE(textBytea, Int(1), Int(0)).EQ(textBytea), + LENGTH(sampleBytea), + + SUBSTR(AllTypes.Bytea, Int(0), Int(2)), + + MD5(AllTypes.Bytea), + SHA224(AllTypes.Bytea), + SHA256(AllTypes.Bytea), + SHA384(AllTypes.Bytea), + SHA512(AllTypes.Bytea), + + ENCODE(sampleBytea, Base64), + DECODE(String("A234C12B"), Hex).EQ(sampleBytea), + + CONVERT_FROM(AllTypes.ByteaPtr, UTF8).EQ(AllTypes.VarChar), + CONVERT_TO(AllTypes.Text, UTF8).NOT_EQ(textBytea), + + RawBytea("DECODE(#1::text, #2)", RawArgs{ + "#1": "A234C12B", + "#2": "hex", + }).EQ(sampleBytea), + ).FROM( + AllTypes, + ) + + if !sourceIsCockroachDB() { + testutils.AssertStatementSql(t, stmt, ` +SELECT all_types.bytea = $1::bytea, + all_types.bytea = all_types.bytea_ptr, + all_types.bytea != $2::bytea, + all_types.bytea > $3::bytea, + all_types.bytea >= all_types.bytea_ptr, + all_types.bytea < all_types.bytea_ptr, + all_types.bytea <= $4::bytea, + all_types.bytea BETWEEN $5::bytea AND $6::bytea, + all_types.bytea NOT BETWEEN all_types.bytea AND all_types.bytea_ptr, + all_types.bytea || $7::bytea, + all_types.bytea LIKE $8::bytea, + all_types.bytea NOT LIKE $9::bytea, + BTRIM(all_types.bytea, $10::bytea), + RTRIM(all_types.bytea_ptr, $11::bytea), + LTRIM($12::bytea, $13::bytea), + CONCAT($14::bytea, all_types.bytea_ptr, $15::bytea), + BIT_COUNT($16::bytea) = $17, + LENGTH($18::bytea, 'UTF8') = $19, + CONVERT($20::bytea, 'UTF8', 'WIN1252'), + CONVERT(all_types.bytea, 'UTF8', 'LATIN1') = $21::bytea, + BIT_LENGTH($22::bytea), + OCTET_LENGTH($23::bytea), + GET_BIT($24::bytea, $25) = $26, + GET_BYTE($27::bytea, $28) = $29, + SET_BIT($30::bytea, $31, $32) = $33::bytea, + SET_BYTE($34::bytea, $35, $36) = $37::bytea, + LENGTH($38::bytea), + SUBSTR(all_types.bytea, $39, $40), + MD5(all_types.bytea), + SHA224(all_types.bytea), + SHA256(all_types.bytea), + SHA384(all_types.bytea), + SHA512(all_types.bytea), + ENCODE($41::bytea, 'base64'), + DECODE($42::text, 'hex') = $43::bytea, + CONVERT_FROM(all_types.bytea_ptr, 'UTF8') = all_types.var_char, + CONVERT_TO(all_types.text, 'UTF8') != $44::bytea, + (DECODE($45::text, $46)) = $47::bytea +FROM test_sample.all_types; +`) + } + + var dest []struct{} + err := stmt.Query(db, &dest) + + require.NoError(t, err) +} + +func TestBlobConversion(t *testing.T) { + + nonPrintable := []byte{11, 22, 33, 44, 55} + printable := []byte("this is blob") + + stmt := SELECT( + Bytea(nonPrintable).AS("test_dest.non_printable"), + Bytea(printable).AS("test_dest.printable"), + + Bytea(nonPrintable).CONCAT(Bytea(printable)).AS("test_dest.bytea_concat"), + + ENCODE(Bytea(nonPrintable), Base64).AS("test_dest.non_printable_base64"), + CONVERT_FROM(Bytea(printable), UTF8).AS("test_dest.printable_utf8"), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT '\x0b16212c37'::bytea AS "test_dest.non_printable", + '\x7468697320697320626c6f62'::bytea AS "test_dest.printable", + ('\x0b16212c37'::bytea || '\x7468697320697320626c6f62'::bytea) AS "test_dest.bytea_concat", + ENCODE('\x0b16212c37'::bytea, 'base64') AS "test_dest.non_printable_base64", + CONVERT_FROM('\x7468697320697320626c6f62'::bytea, 'UTF8') AS "test_dest.printable_utf8"; +`) + + type testDest struct { + NonPrintable []byte + Printable []byte + + ByteaConcat []byte + NonPrintableBase64 string + PrintableUTF8 string + } + + var dest testDest + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + require.Equal(t, dest.NonPrintable, nonPrintable) + require.Equal(t, dest.Printable, printable) + + require.Equal(t, dest.ByteaConcat, append(nonPrintable, printable...)) + require.Equal(t, dest.NonPrintableBase64, base64.StdEncoding.EncodeToString(nonPrintable)) + require.Equal(t, dest.PrintableUTF8, string(printable)) + + t.Run("using select json", func(t *testing.T) { + stmtJson := SELECT_JSON_OBJ( + Bytea(nonPrintable).AS("nonPrintable"), + Bytea(printable).AS("printable"), + + Bytea(nonPrintable).CONCAT(Bytea(printable)).AS("byteaConcat"), + + ENCODE(Bytea(nonPrintable), Base64).AS("nonPrintableBase64"), + CONVERT_FROM(Bytea(printable), UTF8).AS("printableUtf8"), + ) + + testutils.AssertStatementSql(t, stmtJson, ` +SELECT row_to_json(records) AS "json" +FROM ( + SELECT ENCODE($1::bytea, 'base64') AS "nonPrintable", + ENCODE($2::bytea, 'base64') AS "printable", + ENCODE($3::bytea || $4::bytea, 'base64') AS "byteaConcat", + ENCODE($5::bytea, 'base64') AS "nonPrintableBase64", + CONVERT_FROM($6::bytea, 'UTF8') AS "printableUtf8" + ) AS records; +`) + + var destSelectJson testDest + + err := stmtJson.QueryContext(ctx, db, &destSelectJson) + require.NoError(t, err) + testutils.PrintJson(destSelectJson) + + require.Equal(t, dest, destSelectJson) + + }) +} + func TestBoolOperators(t *testing.T) { query := AllTypes.SELECT( AllTypes.Boolean.EQ(AllTypes.BooleanPtr).AS("EQ1"), @@ -941,6 +1276,190 @@ func TestTimeExpression(t *testing.T) { require.NoError(t, err) } +func TestTimeScan(t *testing.T) { + loc, err := time.LoadLocation("Japan") + require.NoError(t, err) + + timeT := time.Date(3, 3, 3, 11, 22, 33, 0, time.UTC) + timeWithNanoSeconds := time.Date(3, 3, 3, 1, 2, 3, 1000, time.UTC) + + timez := time.Date(3, 3, 3, 7, 8, 9, 0, time.UTC) + timezWithNanoSeconds := time.Date(3, 3, 3, 4, 5, 6, 1000, loc) + + // '1999-01-08 04:05:06' + timestamp := time.Date(1999, 01, 8, 4, 5, 6, 0, time.UTC) + timestampWithNanoSeconds := time.Date(3, 3, 3, 8, 9, 10, 1000, time.UTC) + + timestampz := time.Date(2003, 10, 3, 9, 10, 11, 0, loc) + timestampzWithNanoSeconds := time.Date(3, 3, 3, 8, 9, 10, 1000, loc) + + date := time.Date(2010, 2, 3, 0, 0, 0, 0, time.UTC) + + stmt := SELECT( + TimeT(timeT).AS("time"), + TimeT(timeWithNanoSeconds).AS("timeWithNanoSeconds"), + TimezT(timez).AS("timez"), + TimezT(timezWithNanoSeconds).AS("timezWithNanoSeconds"), + Timestamp(1999, 01, 8, 4, 5, 6).AS("timestamp"), + TimestampT(timestampWithNanoSeconds).AS("timestampWithNanoSeconds"), + TimestampzT(timestampz).AS("timestampz"), + TimestampzT(timestampzWithNanoSeconds).AS("timestampzWithNanoSeconds"), + DateT(date).AS("date"), + + TimeT(timeT).ADD(INTERVAL(2, HOUR)).AS("timeExpression"), + + SELECT_JSON_OBJ( + TimeT(timeT).AS("time"), + TimeT(timeWithNanoSeconds).AS("timeWithNanoSeconds"), + TimezT(timez).AS("timez"), + TimezT(timezWithNanoSeconds).AS("timezWithNanoSeconds"), + TimestampT(timestamp).AS("timestamp"), + TimestampT(timestampWithNanoSeconds).AS("timestampWithNanoSeconds"), + TimestampzT(timestampz).AS("timestampz"), + TimestampzT(timestampzWithNanoSeconds).AS("timestampzWithNanoSeconds"), + DateT(date).AS("date"), + + TimeT(timeT).ADD(INTERVAL(2, HOUR)).AS("timeExpression"), + ).AS("json"), + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT $1::time without time zone AS "time", + $2::time without time zone AS "timeWithNanoSeconds", + $3::time with time zone AS "timez", + $4::time with time zone AS "timezWithNanoSeconds", + $5::timestamp without time zone AS "timestamp", + $6::timestamp without time zone AS "timestampWithNanoSeconds", + $7::timestamp with time zone AS "timestampz", + $8::timestamp with time zone AS "timestampzWithNanoSeconds", + $9::date AS "date", + ($10::time without time zone + INTERVAL '2 HOUR') AS "timeExpression", + ( + SELECT row_to_json(json_records) AS "json_json" + FROM ( + SELECT '0000-01-01T' || to_char('2000-10-10'::date + $11::time without time zone, 'HH24:MI:SS.USZ') AS "time", + '0000-01-01T' || to_char('2000-10-10'::date + $12::time without time zone, 'HH24:MI:SS.USZ') AS "timeWithNanoSeconds", + '0000-01-01T' || to_char('2000-10-10'::date + $13::time with time zone, 'HH24:MI:SS.USTZH:TZM') AS "timez", + '0000-01-01T' || to_char('2000-10-10'::date + $14::time with time zone, 'HH24:MI:SS.USTZH:TZM') AS "timezWithNanoSeconds", + to_char($15::timestamp without time zone, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestamp", + to_char($16::timestamp without time zone, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestampWithNanoSeconds", + $17::timestamp with time zone AS "timestampz", + $18::timestamp with time zone AS "timestampzWithNanoSeconds", + to_char($19::date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "date", + '0000-01-01T' || to_char('2000-10-10'::date + ($20::time without time zone + INTERVAL '2 HOUR'), 'HH24:MI:SS.USZ') AS "timeExpression" + ) AS json_records + ) AS "json"; +`) + + var dest struct { + Time time.Time + TimeWithNanoSeconds time.Time + Timez time.Time + TimezWithNanoSeconds time.Time + Timestamp time.Time + TimestampWithNanoSeconds time.Time + Timestampz time.Time + TimestampzWithNanoSeconds time.Time + Date time.Time + + TimeExpression time.Time + + Json struct { + Time time.Time + TimeWithNanoSeconds time.Time + Timez time.Time + TimezWithNanoSeconds time.Time + Timestamp time.Time + TimestampWithNanoSeconds time.Time + Timestampz time.Time + TimestampzWithNanoSeconds time.Time + Date time.Time + + TimeExpression time.Time + } `json_column:"json"` + } + + err = stmt.Query(db, &dest) + require.NoError(t, err) + + ensureTimezEqual(t, timeT.Add(2*time.Hour), dest.TimeExpression, loc) + ensureTimezEqual(t, timeT.Add(2*time.Hour), dest.Json.TimeExpression, loc) + + ensureTimezEqual(t, timeT, dest.Time, loc) + ensureTimezEqual(t, timeT, dest.Json.Time, loc) + ensureTimezEqual(t, timeWithNanoSeconds, dest.TimeWithNanoSeconds, loc) + ensureTimezEqual(t, timeWithNanoSeconds, dest.Json.TimeWithNanoSeconds, loc) + + ensureTimezEqual(t, timez, dest.Timez, loc) + ensureTimezEqual(t, timez, dest.Json.Timez, loc) + ensureTimezEqual(t, timezWithNanoSeconds, dest.TimezWithNanoSeconds, loc) + ensureTimezEqual(t, timezWithNanoSeconds, dest.Json.TimezWithNanoSeconds, loc) + + ensureTimezEqual(t, timestamp, dest.Timestamp, loc) + ensureTimezEqual(t, timestamp, dest.Json.Timestamp, loc) + ensureTimezEqual(t, timestampWithNanoSeconds, dest.TimestampWithNanoSeconds, loc) + ensureTimezEqual(t, timestampWithNanoSeconds, dest.Json.TimestampWithNanoSeconds, loc) + + ensureTimezEqual(t, timestampz, dest.Timestampz, loc) + ensureTimezEqual(t, timestampz, dest.Json.Timestampz, loc) + ensureTimezEqual(t, timestampzWithNanoSeconds, dest.TimestampzWithNanoSeconds, loc) + ensureTimezEqual(t, timestampzWithNanoSeconds, dest.Json.TimestampzWithNanoSeconds, loc) + + ensureTimezEqual(t, date, dest.Date, loc) + ensureTimezEqual(t, date, dest.Json.Date, loc) + + t.Run("json only", func(t *testing.T) { + stmtJson := SELECT_JSON_OBJ( + TimeT(timeT).AS("time"), + TimeT(timeWithNanoSeconds).AS("timeWithNanoSeconds"), + + TimezT(timez).AS("timez"), + TimezT(timezWithNanoSeconds).AS("timezWithNanoSeconds"), + + Timestamp(1999, 01, 8, 4, 5, 6).AS("timestamp"), + TimestampT(timestampWithNanoSeconds).AS("timestampWithNanoSeconds"), + + TimestampzT(timestampz).AS("timestampz"), + TimestampzT(timestampzWithNanoSeconds).AS("timestampzWithNanoSeconds"), + + DateT(date).AS("date"), + ) + + var jsonDest struct { + Time time.Time + TimeWithNanoSeconds time.Time + + Timez time.Time + TimezWithNanoSeconds time.Time + + Timestamp time.Time + TimestampWithNanoSeconds time.Time + + Timestampz time.Time + TimestampzWithNanoSeconds time.Time + + Date time.Time + } + + err := stmtJson.QueryContext(ctx, db, &jsonDest) + require.NoError(t, err) + }) +} + +func ensureTimezEqual(t *testing.T, time1, time2 time.Time, loc *time.Location) { + time1Loc := time1.In(loc) + time2Loc := time2.In(loc) + + require.Equal(t, time1Loc.Hour(), time2Loc.Hour()) + require.Equal(t, time1Loc.Minute(), time2Loc.Minute()) + require.Equal(t, time1Loc.Second(), time2Loc.Second()) + require.Equal(t, toMicroSeconds(time1Loc.Nanosecond()), toMicroSeconds(time2Loc.Nanosecond())) +} + +func toMicroSeconds(nanoseconds int) int { + return nanoseconds / 1000 +} + func TestIntervalSetFunctionality(t *testing.T) { t.Run("updateQueryIntervalTest", func(t *testing.T) { @@ -1052,7 +1571,50 @@ func TestInterval(t *testing.T) { AllTypes.IntervalPtr.DIV(Float(22.222)).EQ(AllTypes.IntervalPtr), ).FROM(AllTypes) - //fmt.Println(stmt.DebugSql()) + fmt.Println(stmt.Sql()) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT INTERVAL '1 YEAR', + INTERVAL '1 MONTH', + INTERVAL '1 WEEK', + INTERVAL '1 DAY', + INTERVAL '1 HOUR', + INTERVAL '1 MINUTE', + INTERVAL '1 SECOND', + INTERVAL '1 MILLISECOND', + INTERVAL '1 MICROSECOND', + INTERVAL '1 DECADE', + INTERVAL '1 CENTURY', + INTERVAL '1 MILLENNIUM', + INTERVAL '1 YEAR 10 MONTH', + INTERVAL '1 YEAR 10 MONTH 20 DAY', + INTERVAL '1 YEAR 10 MONTH 20 DAY 3 HOUR', + INTERVAL '1 YEAR' IS NOT NULL, + INTERVAL '1 YEAR' AS "one year", + INTERVAL '0 MICROSECOND', + INTERVAL '1 MICROSECOND', + INTERVAL '1000 MICROSECOND', + INTERVAL '1 SECOND', + INTERVAL '1 MINUTE', + INTERVAL '1 HOUR', + INTERVAL '1 DAY', + INTERVAL '1 DAY 2 HOUR 3 MINUTE 4 SECOND 5 MICROSECOND', + (all_types.interval = INTERVAL '2 HOUR 20 MINUTE') = TRUE::boolean, + (all_types.interval_ptr != INTERVAL '2 HOUR 20 MINUTE') = FALSE::boolean, + (all_types.interval IS DISTINCT FROM INTERVAL '2 HOUR 20 MINUTE') = all_types.boolean, + (all_types.interval_ptr IS NOT DISTINCT FROM INTERVAL '10 MICROSECOND') = all_types.boolean, + (all_types.interval < all_types.interval_ptr) = all_types.boolean_ptr, + (all_types.interval <= all_types.interval_ptr) = all_types.boolean_ptr, + (all_types.interval > all_types.interval_ptr) = all_types.boolean_ptr, + (all_types.interval >= all_types.interval_ptr) = all_types.boolean_ptr, + all_types.interval BETWEEN INTERVAL '1 HOUR' AND INTERVAL '2 HOUR', + all_types.interval NOT BETWEEN all_types.interval_ptr AND INTERVAL '30 SECOND', + (all_types.interval + all_types.interval_ptr) = INTERVAL '17 SECOND', + (all_types.interval - all_types.interval_ptr) = INTERVAL '100 MICROSECOND', + (all_types.interval_ptr * 11) = all_types.interval, + (all_types.interval_ptr / 22.222) = all_types.interval_ptr +FROM test_sample.all_types; +`) err := stmt.Query(db, &struct{}{}) require.NoError(t, err) @@ -1159,6 +1721,187 @@ SELECT ROW($1::integer, $2::real, $3::text) AS "row", require.NoError(t, err) } +func TestAllTypesSubQueryFrom(t *testing.T) { + subQuery := SELECT( + AllTypes.Boolean, + AllTypes.Integer, + AllTypes.DoublePrecision, + AllTypes.Text, + AllTypes.Date, + AllTypes.Time, + AllTypes.Timez, + AllTypes.Timestamp, + AllTypes.Timestampz, + AllTypes.Interval, + AllTypes.Bytea, + ).FROM( + AllTypes, + ).AsTable("subQuery") + + stmt := SELECT( + AllTypes.Boolean.From(subQuery), + AllTypes.Integer.From(subQuery), + AllTypes.DoublePrecision.From(subQuery), + AllTypes.Text.From(subQuery), + AllTypes.Date.From(subQuery), + AllTypes.Time.From(subQuery), + AllTypes.Timez.From(subQuery), + AllTypes.Timestamp.From(subQuery), + AllTypes.Timestampz.From(subQuery), + AllTypes.Interval.From(subQuery), + AllTypes.Bytea.From(subQuery), + ).FROM( + subQuery, + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT "subQuery"."all_types.boolean" AS "all_types.boolean", + "subQuery"."all_types.integer" AS "all_types.integer", + "subQuery"."all_types.double_precision" AS "all_types.double_precision", + "subQuery"."all_types.text" AS "all_types.text", + "subQuery"."all_types.date" AS "all_types.date", + "subQuery"."all_types.time" AS "all_types.time", + "subQuery"."all_types.timez" AS "all_types.timez", + "subQuery"."all_types.timestamp" AS "all_types.timestamp", + "subQuery"."all_types.timestampz" AS "all_types.timestampz", + "subQuery"."all_types.interval" AS "all_types.interval", + "subQuery"."all_types.bytea" AS "all_types.bytea" +FROM ( + SELECT all_types.boolean AS "all_types.boolean", + all_types.integer AS "all_types.integer", + all_types.double_precision AS "all_types.double_precision", + all_types.text AS "all_types.text", + all_types.date AS "all_types.date", + all_types.time AS "all_types.time", + all_types.timez AS "all_types.timez", + all_types.timestamp AS "all_types.timestamp", + all_types.timestampz AS "all_types.timestampz", + all_types.interval AS "all_types.interval", + all_types.bytea AS "all_types.bytea" + FROM test_sample.all_types + ) AS "subQuery"; +`) + + var dest []model.AllTypes + + err := stmt.Query(db, &dest) + require.NoError(t, err) + + t.Run("using SELECT_JSON", func(t *testing.T) { + stmtJson := SELECT_JSON_ARR( + AllTypes.Boolean.From(subQuery), + AllTypes.Integer.From(subQuery), + AllTypes.DoublePrecision.From(subQuery), + AllTypes.Text.From(subQuery), + AllTypes.Date.From(subQuery), + AllTypes.Time.From(subQuery), + AllTypes.Timez.From(subQuery), + AllTypes.Timestamp.From(subQuery), + AllTypes.Timestampz.From(subQuery), + AllTypes.Interval.From(subQuery), + AllTypes.Bytea.From(subQuery), + ).FROM( + subQuery, + ) + + testutils.AssertDebugStatementSql(t, stmtJson, ` +SELECT json_agg(row_to_json(records)) AS "json" +FROM ( + SELECT "subQuery"."all_types.boolean" AS "boolean", + "subQuery"."all_types.integer" AS "integer", + "subQuery"."all_types.double_precision" AS "doublePrecision", + "subQuery"."all_types.text" AS "text", + to_char("subQuery"."all_types.date"::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "date", + '0000-01-01T' || to_char('2000-10-10'::date + "subQuery"."all_types.time", 'HH24:MI:SS.USZ') AS "time", + '0000-01-01T' || to_char('2000-10-10'::date + "subQuery"."all_types.timez", 'HH24:MI:SS.USTZH:TZM') AS "timez", + to_char("subQuery"."all_types.timestamp", 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestamp", + "subQuery"."all_types.timestampz" AS "timestampz", + "subQuery"."all_types.interval" AS "interval", + ENCODE("subQuery"."all_types.bytea", 'base64') AS "bytea" + FROM ( + SELECT all_types.boolean AS "all_types.boolean", + all_types.integer AS "all_types.integer", + all_types.double_precision AS "all_types.double_precision", + all_types.text AS "all_types.text", + all_types.date AS "all_types.date", + all_types.time AS "all_types.time", + all_types.timez AS "all_types.timez", + all_types.timestamp AS "all_types.timestamp", + all_types.timestampz AS "all_types.timestampz", + all_types.interval AS "all_types.interval", + all_types.bytea AS "all_types.bytea" + FROM test_sample.all_types + ) AS "subQuery" + ) AS records; +`) + + var destJson []model.AllTypes + + err := stmtJson.QueryContext(ctx, db, &destJson) + require.NoError(t, err) + + t.Run("using AllColumns()", func(t *testing.T) { + stmtJsonAllColumns := SELECT_JSON_ARR( + subQuery.AllColumns(), + ).FROM( + subQuery, + ) + + require.Equal(t, stmtJson.DebugSql(), stmtJsonAllColumns.DebugSql()) + }) + + // fix timezone before comparisons + minus8 := time.FixedZone("UTC", -8*60*60) + destJson[0].Timez = *toTZ(&destJson[0].Timez, minus8) + destJson[1].Timez = *toTZ(&destJson[1].Timez, minus8) + + destJson[0].Timestampz = *toTZ(&destJson[0].Timestampz, time.UTC) + destJson[1].Timestampz = *toTZ(&destJson[1].Timestampz, time.UTC) + + dest[0].Timestampz = *toTZ(&dest[0].Timestampz, time.UTC) + dest[1].Timestampz = *toTZ(&dest[1].Timestampz, time.UTC) + + testutils.AssertJsonEqual(t, dest, destJson) + }) +} + +func TestAllTypesUpdateSet(t *testing.T) { + + stmt := AllTypes.UPDATE(). + SET( + AllTypes.Boolean.SET(Bool(false)), + AllTypes.Integer.SET(Int(2)), + AllTypes.DoublePrecision.SET(Float(2.22)), + AllTypes.Text.SET(Text("some text")), + AllTypes.Date.SET(DateT(time.Now())), + AllTypes.Time.SET(TimeT(time.Now())), + AllTypes.Timez.SET(TimezT(time.Now())), + AllTypes.Timestamp.SET(TimestampT(time.Now())), + AllTypes.Interval.SET(INTERVAL(1, HOUR)), + AllTypes.Bytea.SET(Bytea([]byte{11, 22, 33, 44})), + ).WHERE(Bool(true)) + + testutils.AssertStatementSql(t, stmt, ` +UPDATE test_sample.all_types +SET boolean = $1::boolean, + integer = $2, + double_precision = $3, + text = $4::text, + date = $5::date, + time = $6::time without time zone, + timez = $7::time with time zone, + timestamp = $8::timestamp without time zone, + interval = INTERVAL '1 HOUR', + bytea = $9::bytea +WHERE $10::boolean; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) { + _, err := stmt.Exec(tx) + require.NoError(t, err) + }) +} + func TestSubQueryColumnReference(t *testing.T) { type expected struct { sql string diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index 349a4ba..74f2097 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -188,7 +188,129 @@ ORDER BY "Artist"."ArtistId", "Album"."AlbumId", "Track"."TrackId"; `) } +type AllArtistDetails []struct { //list of all artist + model.Artist + + Albums []struct { // list of albums per artist + model.Album + + Tracks []struct { // list of tracks per album + model.Track + + Genre model.Genre // track genre + MediaType model.MediaType // track media type + + Playlists []model.Playlist // list of playlist where track is used + + Invoices []struct { // list of invoices where track occurs + model.Invoice + + Customer struct { // customer data for invoice + model.Customer + + Employee *struct { // employee data for customer if exists + model.Employee + + Manager *model.Employee `alias:"Manager"` + } + } + } + } + } +} + +func BenchmarkJoinEverythingJSON(b *testing.B) { + for i := 0; i < b.N; i++ { + testJoinEverythingJSON(b) + } +} + +func TestJoinEverythingJSON(t *testing.T) { + testJoinEverythingJSON(t) +} + +func testJoinEverythingJSON(t require.TestingT) { + + manager := Employee.AS("Manager") + + stmt := SELECT_JSON_ARR( + Artist.AllColumns, + + SELECT_JSON_ARR( + Album.AllColumns, + + SELECT_JSON_ARR( + Track.AllColumns, + + SELECT_JSON_OBJ(Genre.AllColumns). + FROM(Genre). + WHERE(Genre.GenreId.EQ(Track.GenreId)).AS("Genre"), + + SELECT_JSON_OBJ(MediaType.AllColumns). + FROM(MediaType). + WHERE(MediaType.MediaTypeId.EQ(Track.MediaTypeId)).AS("MediaType"), + + SELECT_JSON_ARR(Playlist.AllColumns). + FROM(Playlist. + INNER_JOIN(PlaylistTrack, Playlist.PlaylistId.EQ(PlaylistTrack.PlaylistId))). + WHERE(PlaylistTrack.TrackId.EQ(Track.TrackId)). + ORDER_BY(Playlist.PlaylistId).AS("Playlists"), + + SELECT_JSON_ARR( + Invoice.AllColumns, + + SELECT_JSON_OBJ( + Customer.AllColumns, + + SELECT_JSON_OBJ( + Employee.AllColumns, + + SELECT_JSON_OBJ(manager.AllColumns). + FROM(manager). + WHERE(manager.EmployeeId.EQ(Employee.ReportsTo)).AS("Manager"), + ).FROM(Employee). + WHERE(Employee.EmployeeId.EQ(Customer.SupportRepId)).AS("Employee"), + ).FROM(Customer). + WHERE(Customer.CustomerId.EQ(Invoice.CustomerId)).AS("Customer"), + ).FROM( + Invoice. + INNER_JOIN(InvoiceLine, InvoiceLine.InvoiceId.EQ(Invoice.InvoiceId)), + ).WHERE(InvoiceLine.TrackId.EQ(Track.TrackId)). + ORDER_BY(Invoice.InvoiceId).AS("Invoices"), + ).FROM(Track). + WHERE(Track.AlbumId.EQ(Album.AlbumId)). + ORDER_BY(Track.TrackId).AS("Tracks"), + ).FROM(Album). + WHERE(Album.ArtistId.EQ(Artist.ArtistId)). + ORDER_BY(Album.AlbumId).AS("Albums"), + ).FROM(Artist). + ORDER_BY(Artist.ArtistId) + + //fmt.Println(stmt.DebugSql()) + + var dest AllArtistDetails + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) + + require.Equal(t, len(dest), 275) + //testutils.SaveJSONFile(dest, "./testdata/results/postgres/joined_everything2.json") + testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/joined_everything.json") + requireLogged(t, stmt) + requireQueryLogged(t, stmt, 1) +} + +func BenchmarkJoinEverything(b *testing.B) { + for i := 0; i < b.N; i++ { + testJoinEverything(b) + } +} + func TestJoinEverything(t *testing.T) { + testJoinEverything(t) +} + +func testJoinEverything(t require.TestingT) { manager := Employee.AS("Manager") @@ -223,37 +345,6 @@ func TestJoinEverything(t *testing.T) { Invoice.InvoiceId, Customer.CustomerId, ) - var dest []struct { //list of all artist - model.Artist - - Albums []struct { // list of albums per artist - model.Album - - Tracks []struct { // list of tracks per album - model.Track - - Genre model.Genre // track genre - MediaType model.MediaType // track media type - - Playlists []model.Playlist // list of playlist where track is used - - Invoices []struct { // list of invoices where track occurs - model.Invoice - - Customer struct { // customer data for invoice - model.Customer - - Employee *struct { // employee data for customer if exists - model.Employee - - Manager *model.Employee `alias:"Manager"` - } - } - } - } - } - } - testutils.AssertStatementSql(t, stmt, ` SELECT "Artist"."ArtistId" AS "Artist.ArtistId", "Artist"."Name" AS "Artist.Name", @@ -344,7 +435,7 @@ FROM chinook."Artist" LEFT JOIN chinook."Employee" AS "Manager" ON ("Manager"."EmployeeId" = "Employee"."ReportsTo") ORDER BY "Artist"."ArtistId", "Album"."AlbumId", "Track"."TrackId", "Genre"."GenreId", "MediaType"."MediaTypeId", "Playlist"."PlaylistId", "Invoice"."InvoiceId", "Customer"."CustomerId"; `) - + var dest AllArtistDetails err := stmt.QueryContext(context.Background(), db, &dest) require.NoError(t, err) diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 005e703..06478a6 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -974,8 +974,8 @@ type allTypesTable struct { Char postgres.ColumnString TextPtr postgres.ColumnString Text postgres.ColumnString - ByteaPtr postgres.ColumnString - Bytea postgres.ColumnString + ByteaPtr postgres.ColumnBytea + Bytea postgres.ColumnBytea TimestampzPtr postgres.ColumnTimestampz Timestampz postgres.ColumnTimestampz TimestampPtr postgres.ColumnTimestamp @@ -1078,8 +1078,8 @@ func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable { CharColumn = postgres.StringColumn("char") TextPtrColumn = postgres.StringColumn("text_ptr") TextColumn = postgres.StringColumn("text") - ByteaPtrColumn = postgres.StringColumn("bytea_ptr") - ByteaColumn = postgres.StringColumn("bytea") + ByteaPtrColumn = postgres.ByteaColumn("bytea_ptr") + ByteaColumn = postgres.ByteaColumn("bytea") TimestampzPtrColumn = postgres.TimestampzColumn("timestampz_ptr") TimestampzColumn = postgres.TimestampzColumn("timestampz") TimestampPtrColumn = postgres.TimestampColumn("timestamp_ptr") diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index 46dfd94..915387a 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -20,6 +20,8 @@ import ( _ "github.com/jackc/pgx/v4/stdlib" ) +var ctx = context.Background() + var db *stmtcache.DB var testRoot string @@ -31,6 +33,7 @@ const CockroachDB = "COCKROACH_DB" func init() { source = os.Getenv("PG_SOURCE") withStatementCaching = os.Getenv("JET_TESTS_WITH_STMT_CACHE") == "true" + testRoot = repo.GetTestsDirPath() } func sourceIsCockroachDB() bool { @@ -46,8 +49,6 @@ func skipForCockroachDB(t *testing.T) { func TestMain(m *testing.M) { defer profile.Start().Stop() - setTestRoot() - for _, driverName := range []string{"postgres", "pgx"} { fmt.Printf("\nRunning postgres tests for driver: %s, caching enabled: %t \n", driverName, withStatementCaching) @@ -94,10 +95,6 @@ func getConnectionString() string { return dbconfig.PostgresConnectString } -func setTestRoot() { - testRoot = repo.GetTestsDirPath() -} - var loggedSQL string var loggedSQLArgs []interface{} var loggedDebugSQL string @@ -119,14 +116,22 @@ func init() { }) } -func requireLogged(t *testing.T, statement postgres.Statement) { +func requireLogged(t require.TestingT, statement postgres.Statement) { + if _, ok := t.(*testing.B); ok { + return // skip assert for benchmarks + } + query, args := statement.Sql() require.Equal(t, loggedSQL, query) require.Equal(t, loggedSQLArgs, args) require.Equal(t, loggedDebugSQL, statement.DebugSql()) } -func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) { +func requireQueryLogged(t require.TestingT, statement postgres.Statement, rowsProcessed int64) { + if _, ok := t.(*testing.B); ok { + return // skip assert for benchmarks + } + query, args := statement.Sql() queryLogged, argsLogged := queryInfo.Statement.Sql() diff --git a/tests/postgres/northwind_test.go b/tests/postgres/northwind_test.go index 7aad0d4..5615f45 100644 --- a/tests/postgres/northwind_test.go +++ b/tests/postgres/northwind_test.go @@ -9,7 +9,50 @@ import ( "testing" ) -func TestNorthwindJoinEverything(t *testing.T) { +type Dest []struct { + model.Customers + + Demographics model.CustomerDemographics + + Orders []struct { + model.Orders + + Shipper model.Shippers + + Employee struct { + model.Employees + + Territories []struct { + model.Territories + + Region model.Region + } + } + + Details []struct { + model.OrderDetails + + Products struct { + model.Products + + Category model.Categories + Supplier model.Suppliers + } + } + } +} + +func BenchmarkTestNorthwindJoinEverything(b *testing.B) { + for i := 0; i < b.N; i++ { + testNorthwindJoinEverything(b) + } +} + +func TestTestNorthwindJoinEverything(t *testing.T) { + testNorthwindJoinEverything(t) +} + +func testNorthwindJoinEverything(t require.TestingT) { stmt := SELECT( @@ -21,6 +64,9 @@ func TestNorthwindJoinEverything(t *testing.T) { Products.AllColumns, Categories.AllColumns, Suppliers.AllColumns, + Employees.AllColumns, + Territories.AllColumns, + Region.AllColumns, ).FROM( Customers. LEFT_JOIN(CustomerCustomerDemo, Customers.CustomerID.EQ(CustomerCustomerDemo.CustomerID)). @@ -35,35 +81,110 @@ func TestNorthwindJoinEverything(t *testing.T) { LEFT_JOIN(EmployeeTerritories, EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID)). LEFT_JOIN(Territories, EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)). LEFT_JOIN(Region, Territories.RegionID.EQ(Region.RegionID)), - ).ORDER_BY(Customers.CustomerID, Orders.OrderID, Products.ProductID) + ).ORDER_BY( + Customers.CustomerID, + Orders.OrderID, + Products.ProductID, + Territories.TerritoryID, + ) - var dest []struct { - model.Customers + //fmt.Println(stmt.DebugSql()) - Demographics model.CustomerDemographics - - Orders []struct { - model.Orders - - Shipper model.Shippers - - Details struct { - model.OrderDetails - - Products []struct { - model.Products - - Category model.Categories - Supplier model.Suppliers - } - } - } - } + var dest Dest err := stmt.Query(db, &dest) require.NoError(t, err) - //jsonSave("./testdata/northwind-all.json", dest) + //testutils.SaveJSONFile(dest, "./testdata/results/postgres/northwind-all.json") testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/northwind-all.json") requireLogged(t, stmt) } + +func BenchmarkTestNorthwindJoinEverythingJson(b *testing.B) { + for i := 0; i < b.N; i++ { + testNorthwindJoinEverythingJson(b) + } +} + +func TestNorthwindJoinEverythingJson(t *testing.T) { + testNorthwindJoinEverythingJson(t) +} + +func testNorthwindJoinEverythingJson(t require.TestingT) { + + stmt := SELECT_JSON_ARR( + Customers.AllColumns, + + SELECT_JSON_OBJ(CustomerDemographics.AllColumns). + FROM(CustomerDemographics.INNER_JOIN(CustomerCustomerDemo, CustomerCustomerDemo.CustomerTypeID.EQ(CustomerDemographics.CustomerTypeID))). + WHERE(CustomerCustomerDemo.CustomerID.EQ(Customers.CustomerID)).AS("Demographics"), + + SELECT_JSON_ARR( + Orders.AllColumns, + + SELECT_JSON_OBJ(Shippers.AllColumns). + FROM(Shippers). + WHERE(Shippers.ShipperID.EQ(Orders.ShipVia)).AS("Shipper"), + + SELECT_JSON_OBJ( + Employees.AllColumns, + SELECT_JSON_ARR( + Territories.AllColumns, + + SELECT_JSON_OBJ(Region.AllColumns). + FROM(Region). + WHERE(Region.RegionID.EQ(Territories.RegionID)).AS("Region"), + ).FROM( + EmployeeTerritories.LEFT_JOIN( + Territories, + EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)), + ).WHERE( + EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID), + ).AS("Territories"), + ).FROM(Employees). + WHERE(Orders.EmployeeID.EQ(Employees.EmployeeID)).AS("Employee"), + + SELECT_JSON_ARR( + OrderDetails.AllColumns, + + SELECT_JSON_OBJ( + Products.AllColumns, + + SELECT_JSON_OBJ( + Categories.AllColumns, + ).FROM(Categories). + WHERE(Categories.CategoryID.EQ(Products.CategoryID)).AS("Category"), + + SELECT_JSON_OBJ(Suppliers.AllColumns). + FROM(Suppliers). + WHERE(Suppliers.SupplierID.EQ(Products.SupplierID)).AS("Supplier"), + ).FROM(Products). + WHERE(Products.ProductID.EQ(OrderDetails.ProductID)).AS("Products"), + ).FROM( + OrderDetails, + ).WHERE( + OrderDetails.OrderID.EQ(Orders.OrderID), + ).AS("Details"), + ).FROM( + Orders, + ).WHERE( + Orders.CustomerID.EQ(Customers.CustomerID), + ).ORDER_BY( + Orders.OrderID, + ).AS("Orders"), + ).FROM( + Customers, + ).ORDER_BY( + Customers.CustomerID, + ) + + //fmt.Println(stmt.DebugSql()) + + var dest Dest + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) + + //testutils.SaveJSONFile(dest, "./testdata/results/postgres/northwind-all2.json") + testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/northwind-all.json") +} diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 0a252db..63bfef9 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -220,20 +220,7 @@ func TestUUIDComplex(t *testing.T) { requireLogged(t, query) }) - t.Run("slice of structs left join", func(t *testing.T) { - leftQuery := Person.LEFT_JOIN(PersonPhone, PersonPhone.PersonID.EQ(Person.PersonID)). - SELECT(Person.AllColumns, PersonPhone.AllColumns). - ORDER_BY(Person.PersonID.ASC(), PersonPhone.PhoneID.ASC()) - var dest []struct { - model.Person - Phones []struct { - model.PersonPhone - } - } - err := leftQuery.Query(db, &dest) - - require.NoError(t, err) - testutils.AssertJSON(t, dest, ` + var expectedSliceOfStructsLeftJoin = ` [ { "PersonID": "b68dbff4-a87d-11e9-a7f2-98ded00c39c6", @@ -274,10 +261,50 @@ func TestUUIDComplex(t *testing.T) { ] } ] -`) +` + + t.Run("slice of structs left join", func(t *testing.T) { + leftQuery := Person.LEFT_JOIN(PersonPhone, PersonPhone.PersonID.EQ(Person.PersonID)). + SELECT(Person.AllColumns, PersonPhone.AllColumns). + ORDER_BY(Person.PersonID.ASC(), PersonPhone.PhoneID.ASC()) + var dest []struct { + model.Person + Phones []struct { + model.PersonPhone + } + } + err := leftQuery.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, expectedSliceOfStructsLeftJoin) requireLogged(t, leftQuery) }) + t.Run("select json", func(t *testing.T) { + jsonQuery := SELECT_JSON_ARR( + Person.AllColumns, + SELECT_JSON_ARR(PersonPhone.AllColumns). + FROM(PersonPhone). + WHERE(PersonPhone.PersonID.EQ(Person.PersonID)). + ORDER_BY(PersonPhone.PhoneID).AS("Phones"), + ).FROM( + Person, + ).ORDER_BY( + Person.PersonID.ASC(), + ) + + var dest []struct { + model.Person + Phones []struct { + model.PersonPhone + } + } + + err := jsonQuery.QueryContext(ctx, db, &dest) + require.NoError(t, err) + testutils.AssertJSON(t, dest, expectedSliceOfStructsLeftJoin) + }) + } func TestEnumType(t *testing.T) { query := Person. diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index 321fc38..a630145 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -209,7 +209,7 @@ func TestScanToStruct(t *testing.T) { err := query.Query(db, &dest) require.Error(t, err) - require.EqualError(t, err, "jet: can't scan int64('\\x01') to 'InventoryID uuid.UUID': Scan: unable to scan type int64 into UUID") + require.EqualError(t, err, "jet: can't assign int64('\\x01') to 'InventoryID uuid.UUID': Scan: unable to scan type int64 into UUID") }) t.Run("type mismatch base type", func(t *testing.T) { diff --git a/tests/postgres/select_json_test.go b/tests/postgres/select_json_test.go new file mode 100644 index 0000000..8691018 --- /dev/null +++ b/tests/postgres/select_json_test.go @@ -0,0 +1,908 @@ +package postgres + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/internal/utils/ptr" + "github.com/go-jet/jet/v2/qrm" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/view" + "github.com/stretchr/testify/require" + "testing" + "time" + + . "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table" +) + +func TestSelectJsonObject(t *testing.T) { + stmt := SELECT_JSON_OBJ(Actor.AllColumns). + FROM(Actor). + WHERE(Actor.ActorID.EQ(Int32(2))) + + testutils.AssertStatementSql(t, stmt, ` +SELECT row_to_json(records) AS "json" +FROM ( + SELECT actor.actor_id AS "actorID", + actor.first_name AS "firstName", + actor.last_name AS "lastName", + to_char(actor.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate" + FROM dvds.actor + WHERE actor.actor_id = $1::integer + ) AS records; +`, int32(2)) + + var dest model.Actor + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) + testutils.AssertJsonEqual(t, dest, actor2) + requireLogged(t, stmt) + + t.Run("scan to map", func(t *testing.T) { + var dest2 map[string]interface{} + + err = stmt.QueryContext(ctx, db, &dest2) + require.NoError(t, err) + //testutils.PrintJson(dest2) + testutils.AssertDeepEqual(t, dest2, map[string]interface{}{ + "actorID": float64(2), + "firstName": "Nick", + "lastName": "Wahlberg", + "lastUpdate": "2013-05-26T14:47:57.620000Z", + }) + }) +} + +func TestSelectJsonArr(t *testing.T) { + stmt := SELECT_JSON_ARR( + Rental.StaffID, + Rental.CustomerID, + Rental.RentalID, + ).DISTINCT( + Rental.StaffID, + Rental.CustomerID, + ).FROM( + Rental, + ).WHERE( + Rental.CustomerID.LT(Int(2)), + ).ORDER_BY( + Rental.StaffID.ASC(), + Rental.CustomerID.ASC(), + Rental.RentalID.ASC(), + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT json_agg(row_to_json(records)) AS "json" +FROM ( + SELECT DISTINCT ON (rental.staff_id, rental.customer_id) rental.staff_id AS "staffID", + rental.customer_id AS "customerID", + rental.rental_id AS "rentalID" + FROM dvds.rental + WHERE rental.customer_id < $1 + ORDER BY rental.staff_id ASC, rental.customer_id ASC, rental.rental_id ASC + ) AS records; +`, int64(2)) + + var dest []model.Rental + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "RentalID": 573, + "RentalDate": "0001-01-01T00:00:00Z", + "InventoryID": 0, + "CustomerID": 1, + "ReturnDate": null, + "StaffID": 1, + "LastUpdate": "0001-01-01T00:00:00Z" + }, + { + "RentalID": 76, + "RentalDate": "0001-01-01T00:00:00Z", + "InventoryID": 0, + "CustomerID": 1, + "ReturnDate": null, + "StaffID": 2, + "LastUpdate": "0001-01-01T00:00:00Z" + } +] +`) + t.Run("scan to array map", func(t *testing.T) { + var dest2 []map[string]interface{} + + err := stmt.QueryContext(ctx, db, &dest2) + require.NoError(t, err) + testutils.AssertDeepEqual(t, dest2, []map[string]interface{}{ + { + "rentalID": 573., + "customerID": 1., + "staffID": 1., + }, + { + "rentalID": 76., + "customerID": 1., + "staffID": 2., + }, + }) + }) +} + +func TestSelectJsonArr_NestedArr(t *testing.T) { + + stmt := SELECT_JSON_ARR( + Customer.AllColumns, + + SELECT_JSON_ARR(Rental.AllColumns). + FROM(Rental). + WHERE(Rental.CustomerID.EQ(Customer.CustomerID)). + ORDER_BY(Rental.RentalID). + OFFSET_e(Int(1)).LIMIT(3).AS("Rentals"), + ).FROM( + Customer, + ).ORDER_BY( + Customer.CustomerID, + ).LIMIT(2).OFFSET(1) + + testutils.AssertStatementSql(t, stmt, ` +SELECT json_agg(row_to_json(records)) AS "json" +FROM ( + SELECT customer.customer_id AS "customerID", + customer.store_id AS "storeID", + customer.first_name AS "firstName", + customer.last_name AS "lastName", + customer.email AS "email", + customer.address_id AS "addressID", + customer.activebool AS "activebool", + to_char(customer.create_date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "createDate", + to_char(customer.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate", + customer.active AS "active", + ( + SELECT json_agg(row_to_json(rentals_records)) AS "rentals_json" + FROM ( + SELECT rental.rental_id AS "rentalID", + to_char(rental.rental_date, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "rentalDate", + rental.inventory_id AS "inventoryID", + rental.customer_id AS "customerID", + to_char(rental.return_date, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "returnDate", + rental.staff_id AS "staffID", + to_char(rental.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate" + FROM dvds.rental + WHERE rental.customer_id = customer.customer_id + ORDER BY rental.rental_id + LIMIT $1 + OFFSET $2 + ) AS rentals_records + ) AS "Rentals" + FROM dvds.customer + ORDER BY customer.customer_id + LIMIT $3 + OFFSET $4 + ) AS records; +`) + + var dest []struct { + model.Customer + + Rentals []model.Rental + } + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) + + t.Run("partial select json", func(t *testing.T) { + + stmt := SELECT( + Customer.AllColumns, + + SELECT_JSON_ARR(Rental.AllColumns). + FROM(Rental). + WHERE(Rental.CustomerID.EQ(Customer.CustomerID)). + ORDER_BY(Rental.RentalID). + OFFSET_e(Int(1)).LIMIT(3).AS("Rentals"), + ).FROM( + Customer, + ).ORDER_BY( + Customer.CustomerID, + ).OFFSET(1).LIMIT(2) + + testutils.AssertStatementSql(t, stmt, ` +SELECT customer.customer_id AS "customer.customer_id", + customer.store_id AS "customer.store_id", + customer.first_name AS "customer.first_name", + customer.last_name AS "customer.last_name", + customer.email AS "customer.email", + customer.address_id AS "customer.address_id", + customer.activebool AS "customer.activebool", + customer.create_date AS "customer.create_date", + customer.last_update AS "customer.last_update", + customer.active AS "customer.active", + ( + SELECT json_agg(row_to_json(rentals_records)) AS "rentals_json" + FROM ( + SELECT rental.rental_id AS "rentalID", + to_char(rental.rental_date, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "rentalDate", + rental.inventory_id AS "inventoryID", + rental.customer_id AS "customerID", + to_char(rental.return_date, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "returnDate", + rental.staff_id AS "staffID", + to_char(rental.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate" + FROM dvds.rental + WHERE rental.customer_id = customer.customer_id + ORDER BY rental.rental_id + LIMIT $1 + OFFSET $2 + ) AS rentals_records + ) AS "Rentals" +FROM dvds.customer +ORDER BY customer.customer_id +LIMIT $3 +OFFSET $4; +`) + + var dest2 []struct { + model.Customer + + Rentals []model.Rental `json_column:"Rentals"` + } + + err := stmt.Query(db, &dest2) + require.NoError(t, err) + testutils.AssertJsonEqual(t, dest, dest2) + + var dest3 []struct { + model.Customer + + Rentals *[]model.Rental `json_column:"rentals"` + } + + err = stmt.Query(db, &dest3) + require.NoError(t, err) + testutils.AssertJsonEqual(t, dest, dest3) + + var dest4 []struct { + model.Customer + + Rentals []*model.Rental `json_column:"rentals"` + } + + err = stmt.Query(db, &dest4) + require.NoError(t, err) + testutils.AssertJsonEqual(t, dest, dest4) + }) +} + +func TestSelectJson_GroupByHaving(t *testing.T) { + stmt := SELECT_JSON_ARR( + Customer.AllColumns, + + SELECT_JSON_OBJ( + SUM(Payment.Amount).AS("sum"), + AVG(Payment.Amount).AS("avg"), + MAX(Payment.PaymentDate).AS("max_date"), + MAX(Payment.Amount).AS("max"), + MIN(Payment.PaymentDate).AS("min_date"), + MIN(Payment.Amount).AS("min"), + COUNT(Payment.Amount).AS("count"), + ).AS("amount"), + ).FROM( + Payment. + INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)), + ).GROUP_BY( + Customer.CustomerID, + ).HAVING( + SUMf(Payment.Amount).GT(Real(125)), + ).ORDER_BY( + Customer.CustomerID, SUM(Payment.Amount).ASC(), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT json_agg(row_to_json(records)) AS "json" +FROM ( + SELECT customer.customer_id AS "customerID", + customer.store_id AS "storeID", + customer.first_name AS "firstName", + customer.last_name AS "lastName", + customer.email AS "email", + customer.address_id AS "addressID", + customer.activebool AS "activebool", + to_char(customer.create_date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "createDate", + to_char(customer.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate", + customer.active AS "active", + ( + SELECT row_to_json(amount_records) AS "amount_json" + FROM ( + SELECT SUM(payment.amount) AS "sum", + AVG(payment.amount) AS "avg", + MAX(payment.payment_date) AS "max_date", + MAX(payment.amount) AS "max", + MIN(payment.payment_date) AS "min_date", + MIN(payment.amount) AS "min", + COUNT(payment.amount) AS "count" + ) AS amount_records + ) AS "amount" + FROM dvds.payment + INNER JOIN dvds.customer ON (customer.customer_id = payment.customer_id) + GROUP BY customer.customer_id + HAVING SUM(payment.amount) > 125::real + ORDER BY customer.customer_id, SUM(payment.amount) ASC + ) AS records; +`) + + var dest []struct { + model.Customer + + Amount struct { + Sum float64 + Avg float64 + Max float64 + Min float64 + Count int64 + } `alias:"amount"` + } + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) + + if sourceIsCockroachDB() { + return // small precision difference in result + } + + testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/customer_payment_sum.json") +} + +func TestSelectQuickStartJSON(t *testing.T) { + + stmt := SELECT_JSON_ARR( + Actor.ActorID, Actor.FirstName, Actor.LastName, Actor.LastUpdate, + + SELECT_JSON_ARR( + Film.AllColumns.Except(Film.SpecialFeatures), + CAST(Film.SpecialFeatures).AS_TEXT().AS("SpecialFeatures"), + + SELECT_JSON_OBJ( + Language.AllColumns, + ).FROM( + Language, + ).WHERE( + Language.LanguageID.EQ(Film.LanguageID).AND( + Language.Name.EQ(Char(20)("English")), + ), + ).AS("Language"), + + SELECT_JSON_ARR( + Category.AllColumns, + ).FROM( + Category. + INNER_JOIN(FilmCategory, FilmCategory.CategoryID.EQ(Category.CategoryID)), + ).WHERE( + FilmCategory.FilmID.EQ(Film.FilmID).AND( + Category.Name.NOT_EQ(Text("Action")), + ), + ).AS("Categories"), + ).FROM( + Film. + INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(Film.FilmID)), + ).WHERE( + FilmActor.ActorID.EQ(Actor.ActorID).AND(Film.Length.GT(Int32(180))), + ).ORDER_BY( + Film.FilmID.ASC(), + ).AS("Films"), + ).FROM( + Actor, + ).ORDER_BY( + Actor.ActorID.ASC(), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT json_agg(row_to_json(records)) AS "json" +FROM ( + SELECT actor.actor_id AS "actorID", + actor.first_name AS "firstName", + actor.last_name AS "lastName", + to_char(actor.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate", + ( + SELECT json_agg(row_to_json(films_records)) AS "films_json" + FROM ( + SELECT film.film_id AS "filmID", + film.title AS "title", + film.description AS "description", + film.release_year AS "releaseYear", + film.language_id AS "languageID", + film.rental_duration AS "rentalDuration", + film.rental_rate AS "rentalRate", + film.length AS "length", + film.replacement_cost AS "replacementCost", + film.rating AS "rating", + to_char(film.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate", + film.fulltext AS "fulltext", + film.special_features::text AS "SpecialFeatures", + ( + SELECT row_to_json(language_records) AS "language_json" + FROM ( + SELECT language.language_id AS "languageID", + language.name AS "name", + to_char(language.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate" + FROM dvds.language + WHERE (language.language_id = film.language_id) AND (language.name = 'English'::char(20)) + ) AS language_records + ) AS "Language", + ( + SELECT json_agg(row_to_json(categories_records)) AS "categories_json" + FROM ( + SELECT category.category_id AS "categoryID", + category.name AS "name", + to_char(category.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate" + FROM dvds.category + INNER JOIN dvds.film_category ON (film_category.category_id = category.category_id) + WHERE (film_category.film_id = film.film_id) AND (category.name != 'Action'::text) + ) AS categories_records + ) AS "Categories" + FROM dvds.film + INNER JOIN dvds.film_actor ON (film_actor.film_id = film.film_id) + WHERE (film_actor.actor_id = actor.actor_id) AND (film.length > 180::integer) + ORDER BY film.film_id ASC + ) AS films_records + ) AS "Films" + FROM dvds.actor + ORDER BY actor.actor_id ASC + ) AS records; +`) + + var dest []struct { + model.Actor + + Films []struct { + model.Film + + Language model.Language + Categories []model.Category + } + } + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) + require.Len(t, dest, 200) + + if sourceIsCockroachDB() { + return // char[n] columns whitespaces are trimmed when returned as json in cockroachdb + } + + //testutils.SaveJSONFile(dest, "./testdata/results/postgres/quick-start-json-dest2.json") + //testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/quick-start-json-dest.json") +} + +func TestSelectJsonInReturning(t *testing.T) { + + stmt := Rental. + UPDATE(Rental.ReturnDate). + MODEL(model.Rental{ + ReturnDate: ptr.Of(time.Date(2010, 2, 4, 5, 6, 7, 8, time.UTC)), + }). + WHERE( + Rental.RentalID.EQ(Int(11496)), + ). + RETURNING( + Rental.AllColumns.Except(Rental.LastUpdate), + + SELECT_JSON_OBJ( + Customer.AllColumns, + ).FROM( + Customer, + ).WHERE( + Customer.CustomerID.EQ(Rental.CustomerID), + ).AS("Customer"), + ) + + testutils.AssertStatementSql(t, stmt, ` +UPDATE dvds.rental +SET return_date = $1 +WHERE rental.rental_id = $2 +RETURNING rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + ( + SELECT row_to_json(customer_records) AS "customer_json" + FROM ( + SELECT customer.customer_id AS "customerID", + customer.store_id AS "storeID", + customer.first_name AS "firstName", + customer.last_name AS "lastName", + customer.email AS "email", + customer.address_id AS "addressID", + customer.activebool AS "activebool", + to_char(customer.create_date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "createDate", + to_char(customer.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate", + customer.active AS "active" + FROM dvds.customer + WHERE customer.customer_id = rental.customer_id + ) AS customer_records + ) AS "Customer"; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) { + var dest struct { + model.Rental + + Customer model.Customer `json_column:"Customer"` + } + + err := stmt.Query(tx, &dest) + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +{ + "RentalID": 11496, + "RentalDate": "2006-02-14T15:16:03Z", + "InventoryID": 2047, + "CustomerID": 155, + "ReturnDate": "2010-02-04T05:06:07Z", + "StaffID": 1, + "LastUpdate": "0001-01-01T00:00:00Z", + "Customer": { + "CustomerID": 155, + "StoreID": 1, + "FirstName": "Gail", + "LastName": "Knight", + "Email": "gail.knight@sakilacustomer.org", + "AddressID": 159, + "Activebool": true, + "CreateDate": "2006-02-14T00:00:00Z", + "LastUpdate": "2013-05-26T14:49:45.738Z", + "Active": 1 + } +} +`) + }) +} + +func TestSelectJson_FetchFirst(t *testing.T) { + stmt := SELECT_JSON_ARR(Actor.AllColumns). + FROM(Actor). + ORDER_BY(Actor.ActorID). + OFFSET(2). + FETCH_FIRST(Int(3)).ROWS_ONLY() + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT json_agg(row_to_json(records)) AS "json" +FROM ( + SELECT actor.actor_id AS "actorID", + actor.first_name AS "firstName", + actor.last_name AS "lastName", + to_char(actor.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate" + FROM dvds.actor + ORDER BY actor.actor_id + OFFSET 2 + FETCH FIRST 3 ROWS ONLY + ) AS records; +`) + + var dest []model.Actor + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "ActorID": 3, + "FirstName": "Ed", + "LastName": "Chase", + "LastUpdate": "2013-05-26T14:47:57.62Z" + }, + { + "ActorID": 4, + "FirstName": "Jennifer", + "LastName": "Davis", + "LastUpdate": "2013-05-26T14:47:57.62Z" + }, + { + "ActorID": 5, + "FirstName": "Johnny", + "LastName": "Lollobrigida", + "LastUpdate": "2013-05-26T14:47:57.62Z" + } +] +`) +} + +func TestSelectJson_RowLock(t *testing.T) { + + stmt := SELECT_JSON_OBJ(Actor.AllColumns). + FROM(Actor). + WHERE(Actor.ActorID.EQ(Int(200))). + FOR(UPDATE().NOWAIT()) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT row_to_json(records) AS "json" +FROM ( + SELECT actor.actor_id AS "actorID", + actor.first_name AS "firstName", + actor.last_name AS "lastName", + to_char(actor.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate" + FROM dvds.actor + WHERE actor.actor_id = 200 + FOR UPDATE NOWAIT + ) AS records; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) { + var dest model.Actor + + err := stmt.QueryContext(ctx, tx, &dest) + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +{ + "ActorID": 200, + "FirstName": "Thora", + "LastName": "Temple", + "LastUpdate": "2013-05-26T14:47:57.62Z" +} +`) + }) + +} + +func TestSelectJson_UNION(t *testing.T) { + + stmt := UNION_ALL( + SELECT_JSON_OBJ(Actor.AllColumns). + FROM(Actor). + WHERE(Actor.ActorID.EQ(Int(20))), + + SELECT_JSON_OBJ(Actor.AllColumns). + FROM(Actor). + WHERE(Actor.ActorID.EQ(Int(21))), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +( + SELECT row_to_json(records) AS "json" + FROM ( + SELECT actor.actor_id AS "actorID", + actor.first_name AS "firstName", + actor.last_name AS "lastName", + to_char(actor.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate" + FROM dvds.actor + WHERE actor.actor_id = 20 + ) AS records +) +UNION ALL +( + SELECT row_to_json(records) AS "json" + FROM ( + SELECT actor.actor_id AS "actorID", + actor.first_name AS "firstName", + actor.last_name AS "lastName", + to_char(actor.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate" + FROM dvds.actor + WHERE actor.actor_id = 21 + ) AS records +); +`) + + var dest []struct { + model.Actor `json_column:"json"` + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "ActorID": 20, + "FirstName": "Lucille", + "LastName": "Tracy", + "LastUpdate": "2013-05-26T14:47:57.62Z" + }, + { + "ActorID": 21, + "FirstName": "Kirsten", + "LastName": "Paltrow", + "LastUpdate": "2013-05-26T14:47:57.62Z" + } +] +`) +} + +func TestSelectJson_Window(t *testing.T) { + stmt := SELECT_JSON_ARR( + AVG(Payment.Amount).OVER().AS("avgOver"), + AVG(Payment.Amount).OVER(Window("w1")).AS("avgOverW1"), + AVG(Payment.Amount).OVER( + Window("w2"). + ORDER_BY(Payment.CustomerID). + RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)), + ).AS("avgOverW2"), + AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))).AS("avgOverW3"), + ).FROM( + Payment, + ).WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)). + WINDOW("w2").AS(Window("w1")). + WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)). + ORDER_BY(Payment.CustomerID). + LIMIT(4) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT json_agg(row_to_json(records)) AS "json" +FROM ( + SELECT AVG(payment.amount) OVER () AS "avgOver", + AVG(payment.amount) OVER (w1) AS "avgOverW1", + AVG(payment.amount) OVER (w2 ORDER BY payment.customer_id RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "avgOverW2", + AVG(payment.amount) OVER (w3 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "avgOverW3" + FROM dvds.payment + WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id) + ORDER BY payment.customer_id + LIMIT 4 + ) AS records; +`) + + var dest []struct { + AvgOver float64 + AvgOverW1 float64 + AvgOverW2 float64 + AvgOverW3 float64 + } + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) +} + +func TestSelectJson_QueryWithoutUnMarshaling(t *testing.T) { + stmt := SELECT( + SELECT_JSON_ARR( + view.CustomerList.AllColumns, + + SELECT_JSON_ARR(Rental.AllColumns). + FROM(Rental). + WHERE(view.CustomerList.ID.EQ(Rental.CustomerID)). + ORDER_BY(Rental.CustomerID). + AS("Rentals"), + ).FROM( + view.CustomerList, + ).WHERE( + view.CustomerList.ID.LT_EQ(Int(2)), + ).ORDER_BY( + view.CustomerList.ID, + ).AS("raw_json"), + ) + + //fmt.Println(stmt.DebugSql()) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT ( + SELECT json_agg(row_to_json(raw_json_records)) AS "raw_json_json" + FROM ( + SELECT customer_list.id AS "id", + customer_list.name AS "name", + customer_list.address AS "address", + customer_list."zip code" AS "zip code", + customer_list.phone AS "phone", + customer_list.city AS "city", + customer_list.country AS "country", + customer_list.notes AS "notes", + customer_list.sid AS "sid", + ( + SELECT json_agg(row_to_json(rentals_records)) AS "rentals_json" + FROM ( + SELECT rental.rental_id AS "rentalID", + to_char(rental.rental_date, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "rentalDate", + rental.inventory_id AS "inventoryID", + rental.customer_id AS "customerID", + to_char(rental.return_date, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "returnDate", + rental.staff_id AS "staffID", + to_char(rental.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate" + FROM dvds.rental + WHERE customer_list.id = rental.customer_id + ORDER BY rental.customer_id + ) AS rentals_records + ) AS "Rentals" + FROM dvds.customer_list + WHERE customer_list.id <= 2 + ORDER BY customer_list.id + ) AS raw_json_records + ) AS "raw_json"; +`) + + var dest struct { + RawJson []byte + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + + if sourceIsCockroachDB() { + require.Equal(t, string(dest.RawJson), `[{"Rentals": [{"customerID": 1, "inventoryID": 3021, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-05-25T11:30:37.000000Z", "rentalID": 76, "returnDate": "2005-06-03T12:00:37.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 4020, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-05-28T10:35:23.000000Z", "rentalID": 573, "returnDate": "2005-06-03T06:32:23.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 2785, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-06-15T00:54:12.000000Z", "rentalID": 1185, "returnDate": "2005-06-23T02:42:12.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 1021, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-06-15T18:02:53.000000Z", "rentalID": 1422, "returnDate": "2005-06-19T15:54:53.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 1407, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-06-15T21:08:46.000000Z", "rentalID": 1476, "returnDate": "2005-06-25T02:26:46.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 726, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-06-16T15:18:57.000000Z", "rentalID": 1725, "returnDate": "2005-06-17T21:05:57.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 197, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-06-18T08:41:48.000000Z", "rentalID": 2308, "returnDate": "2005-06-22T03:36:48.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 3497, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-06-18T13:33:59.000000Z", "rentalID": 2363, "returnDate": "2005-06-19T17:40:59.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 4566, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-06-21T06:24:45.000000Z", "rentalID": 3284, "returnDate": "2005-06-28T03:28:45.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 1443, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-08T03:17:05.000000Z", "rentalID": 4526, "returnDate": "2005-07-14T01:19:05.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 3486, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-08T07:33:56.000000Z", "rentalID": 4611, "returnDate": "2005-07-12T13:25:56.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 3726, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-09T13:24:07.000000Z", "rentalID": 5244, "returnDate": "2005-07-14T14:01:07.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 797, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-09T16:38:01.000000Z", "rentalID": 5326, "returnDate": "2005-07-13T18:02:01.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 1330, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-11T10:13:46.000000Z", "rentalID": 6163, "returnDate": "2005-07-19T13:15:46.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 2465, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-27T11:31:22.000000Z", "rentalID": 7273, "returnDate": "2005-07-31T06:50:22.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 1092, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-28T09:04:45.000000Z", "rentalID": 7841, "returnDate": "2005-07-30T12:37:45.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 4268, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-28T16:18:23.000000Z", "rentalID": 8033, "returnDate": "2005-07-30T17:56:23.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 1558, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-28T17:33:39.000000Z", "rentalID": 8074, "returnDate": "2005-07-29T20:17:39.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 4497, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-28T19:20:07.000000Z", "rentalID": 8116, "returnDate": "2005-07-29T22:54:07.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 108, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-29T03:58:49.000000Z", "rentalID": 8326, "returnDate": "2005-08-01T05:16:49.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 2219, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-31T02:42:18.000000Z", "rentalID": 9571, "returnDate": "2005-08-02T23:26:18.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 14, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-01T08:51:04.000000Z", "rentalID": 10437, "returnDate": "2005-08-10T12:12:04.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 3232, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-02T15:36:52.000000Z", "rentalID": 11299, "returnDate": "2005-08-10T16:40:52.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 1440, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-02T18:01:38.000000Z", "rentalID": 11367, "returnDate": "2005-08-04T13:19:38.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 2639, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-17T12:37:54.000000Z", "rentalID": 11824, "returnDate": "2005-08-19T10:11:54.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 921, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-18T03:57:29.000000Z", "rentalID": 12250, "returnDate": "2005-08-22T23:05:29.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 3019, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-19T09:55:16.000000Z", "rentalID": 13068, "returnDate": "2005-08-20T14:44:16.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 2269, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-19T13:56:54.000000Z", "rentalID": 13176, "returnDate": "2005-08-23T08:50:54.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 4249, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-21T23:33:57.000000Z", "rentalID": 14762, "returnDate": "2005-08-23T01:30:57.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 1449, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-22T01:27:57.000000Z", "rentalID": 14825, "returnDate": "2005-08-27T07:01:57.000000Z", "staffID": 2}, {"customerID": 1, "inventoryID": 1446, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-22T19:41:37.000000Z", "rentalID": 15298, "returnDate": "2005-08-28T22:49:37.000000Z", "staffID": 1}, {"customerID": 1, "inventoryID": 312, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-22T20:03:46.000000Z", "rentalID": 15315, "returnDate": "2005-08-30T01:51:46.000000Z", "staffID": 2}], "address": "1913 Hanoi Way", "city": "Sasebo", "country": "Japan", "id": 1, "name": "Mary Smith", "notes": "active", "phone": "28303384290", "sid": 1, "zip code": "35200"}, {"Rentals": [{"customerID": 2, "inventoryID": 1090, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-05-27T00:09:24.000000Z", "rentalID": 320, "returnDate": "2005-05-28T04:30:24.000000Z", "staffID": 2}, {"customerID": 2, "inventoryID": 352, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-06-17T20:54:58.000000Z", "rentalID": 2128, "returnDate": "2005-06-24T00:41:58.000000Z", "staffID": 2}, {"customerID": 2, "inventoryID": 4116, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-10T06:31:24.000000Z", "rentalID": 5636, "returnDate": "2005-07-13T02:36:24.000000Z", "staffID": 1}, {"customerID": 2, "inventoryID": 2760, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-10T12:38:56.000000Z", "rentalID": 5755, "returnDate": "2005-07-19T17:02:56.000000Z", "staffID": 1}, {"customerID": 2, "inventoryID": 741, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-27T14:30:42.000000Z", "rentalID": 7346, "returnDate": "2005-08-02T16:48:42.000000Z", "staffID": 1}, {"customerID": 2, "inventoryID": 488, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-27T15:23:02.000000Z", "rentalID": 7376, "returnDate": "2005-08-04T10:35:02.000000Z", "staffID": 2}, {"customerID": 2, "inventoryID": 2053, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-27T18:40:20.000000Z", "rentalID": 7459, "returnDate": "2005-08-02T21:07:20.000000Z", "staffID": 2}, {"customerID": 2, "inventoryID": 1937, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-29T00:12:59.000000Z", "rentalID": 8230, "returnDate": "2005-08-06T19:52:59.000000Z", "staffID": 2}, {"customerID": 2, "inventoryID": 626, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-29T12:56:59.000000Z", "rentalID": 8598, "returnDate": "2005-08-01T08:39:59.000000Z", "staffID": 2}, {"customerID": 2, "inventoryID": 4038, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-29T17:14:29.000000Z", "rentalID": 8705, "returnDate": "2005-08-02T16:01:29.000000Z", "staffID": 1}, {"customerID": 2, "inventoryID": 2377, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-30T06:06:10.000000Z", "rentalID": 9031, "returnDate": "2005-08-04T10:45:10.000000Z", "staffID": 2}, {"customerID": 2, "inventoryID": 4030, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-30T13:47:43.000000Z", "rentalID": 9236, "returnDate": "2005-08-08T18:52:43.000000Z", "staffID": 1}, {"customerID": 2, "inventoryID": 1382, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-30T14:14:11.000000Z", "rentalID": 9248, "returnDate": "2005-08-05T11:19:11.000000Z", "staffID": 1}, {"customerID": 2, "inventoryID": 4088, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-30T16:21:13.000000Z", "rentalID": 9296, "returnDate": "2005-08-08T11:57:13.000000Z", "staffID": 1}, {"customerID": 2, "inventoryID": 3084, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-30T22:39:53.000000Z", "rentalID": 9465, "returnDate": "2005-08-06T16:43:53.000000Z", "staffID": 2}, {"customerID": 2, "inventoryID": 3142, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-07-31T21:58:56.000000Z", "rentalID": 10136, "returnDate": "2005-08-03T19:44:56.000000Z", "staffID": 1}, {"customerID": 2, "inventoryID": 138, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-01T09:45:26.000000Z", "rentalID": 10466, "returnDate": "2005-08-06T06:28:26.000000Z", "staffID": 1}, {"customerID": 2, "inventoryID": 3418, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-02T02:10:56.000000Z", "rentalID": 10918, "returnDate": "2005-08-02T21:23:56.000000Z", "staffID": 1}, {"customerID": 2, "inventoryID": 654, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-02T07:41:41.000000Z", "rentalID": 11087, "returnDate": "2005-08-10T10:37:41.000000Z", "staffID": 2}, {"customerID": 2, "inventoryID": 1149, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-02T10:43:48.000000Z", "rentalID": 11177, "returnDate": "2005-08-10T10:55:48.000000Z", "staffID": 2}, {"customerID": 2, "inventoryID": 2060, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-02T13:44:53.000000Z", "rentalID": 11256, "returnDate": "2005-08-04T16:39:53.000000Z", "staffID": 1}, {"customerID": 2, "inventoryID": 805, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-17T03:52:18.000000Z", "rentalID": 11614, "returnDate": "2005-08-20T07:04:18.000000Z", "staffID": 1}, {"customerID": 2, "inventoryID": 1521, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-19T06:26:04.000000Z", "rentalID": 12963, "returnDate": "2005-08-23T11:37:04.000000Z", "staffID": 2}, {"customerID": 2, "inventoryID": 3164, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-21T13:24:32.000000Z", "rentalID": 14475, "returnDate": "2005-08-27T08:59:32.000000Z", "staffID": 2}, {"customerID": 2, "inventoryID": 4570, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-21T22:41:56.000000Z", "rentalID": 14743, "returnDate": "2005-08-29T00:18:56.000000Z", "staffID": 1}, {"customerID": 2, "inventoryID": 2179, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-22T13:53:04.000000Z", "rentalID": 15145, "returnDate": "2005-08-31T15:51:04.000000Z", "staffID": 1}, {"customerID": 2, "inventoryID": 2898, "lastUpdate": "2006-02-16T02:30:53.000000Z", "rentalDate": "2005-08-23T17:39:35.000000Z", "rentalID": 15907, "returnDate": "2005-08-25T23:23:35.000000Z", "staffID": 1}], "address": "1121 Loja Avenue", "city": "San Bernardino", "country": "United States", "id": 2, "name": "Patricia Johnson", "notes": "active", "phone": "838635286649", "sid": 1, "zip code": "17886"}]`) + } else { + require.Equal(t, string(dest.RawJson), `[{"id":1,"name":"Mary Smith","address":"1913 Hanoi Way","zip code":"35200","phone":"28303384290","city":"Sasebo","country":"Japan","notes":"active","sid":1,"Rentals":[{"rentalID":76,"rentalDate":"2005-05-25T11:30:37.000000Z","inventoryID":3021,"customerID":1,"returnDate":"2005-06-03T12:00:37.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":573,"rentalDate":"2005-05-28T10:35:23.000000Z","inventoryID":4020,"customerID":1,"returnDate":"2005-06-03T06:32:23.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":1185,"rentalDate":"2005-06-15T00:54:12.000000Z","inventoryID":2785,"customerID":1,"returnDate":"2005-06-23T02:42:12.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":1422,"rentalDate":"2005-06-15T18:02:53.000000Z","inventoryID":1021,"customerID":1,"returnDate":"2005-06-19T15:54:53.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":1476,"rentalDate":"2005-06-15T21:08:46.000000Z","inventoryID":1407,"customerID":1,"returnDate":"2005-06-25T02:26:46.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":1725,"rentalDate":"2005-06-16T15:18:57.000000Z","inventoryID":726,"customerID":1,"returnDate":"2005-06-17T21:05:57.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":2308,"rentalDate":"2005-06-18T08:41:48.000000Z","inventoryID":197,"customerID":1,"returnDate":"2005-06-22T03:36:48.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":2363,"rentalDate":"2005-06-18T13:33:59.000000Z","inventoryID":3497,"customerID":1,"returnDate":"2005-06-19T17:40:59.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":3284,"rentalDate":"2005-06-21T06:24:45.000000Z","inventoryID":4566,"customerID":1,"returnDate":"2005-06-28T03:28:45.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":4526,"rentalDate":"2005-07-08T03:17:05.000000Z","inventoryID":1443,"customerID":1,"returnDate":"2005-07-14T01:19:05.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":4611,"rentalDate":"2005-07-08T07:33:56.000000Z","inventoryID":3486,"customerID":1,"returnDate":"2005-07-12T13:25:56.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":5244,"rentalDate":"2005-07-09T13:24:07.000000Z","inventoryID":3726,"customerID":1,"returnDate":"2005-07-14T14:01:07.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":5326,"rentalDate":"2005-07-09T16:38:01.000000Z","inventoryID":797,"customerID":1,"returnDate":"2005-07-13T18:02:01.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":6163,"rentalDate":"2005-07-11T10:13:46.000000Z","inventoryID":1330,"customerID":1,"returnDate":"2005-07-19T13:15:46.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":7273,"rentalDate":"2005-07-27T11:31:22.000000Z","inventoryID":2465,"customerID":1,"returnDate":"2005-07-31T06:50:22.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":7841,"rentalDate":"2005-07-28T09:04:45.000000Z","inventoryID":1092,"customerID":1,"returnDate":"2005-07-30T12:37:45.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":8033,"rentalDate":"2005-07-28T16:18:23.000000Z","inventoryID":4268,"customerID":1,"returnDate":"2005-07-30T17:56:23.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":8074,"rentalDate":"2005-07-28T17:33:39.000000Z","inventoryID":1558,"customerID":1,"returnDate":"2005-07-29T20:17:39.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":8116,"rentalDate":"2005-07-28T19:20:07.000000Z","inventoryID":4497,"customerID":1,"returnDate":"2005-07-29T22:54:07.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":8326,"rentalDate":"2005-07-29T03:58:49.000000Z","inventoryID":108,"customerID":1,"returnDate":"2005-08-01T05:16:49.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":9571,"rentalDate":"2005-07-31T02:42:18.000000Z","inventoryID":2219,"customerID":1,"returnDate":"2005-08-02T23:26:18.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":10437,"rentalDate":"2005-08-01T08:51:04.000000Z","inventoryID":14,"customerID":1,"returnDate":"2005-08-10T12:12:04.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":11299,"rentalDate":"2005-08-02T15:36:52.000000Z","inventoryID":3232,"customerID":1,"returnDate":"2005-08-10T16:40:52.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":11367,"rentalDate":"2005-08-02T18:01:38.000000Z","inventoryID":1440,"customerID":1,"returnDate":"2005-08-04T13:19:38.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":11824,"rentalDate":"2005-08-17T12:37:54.000000Z","inventoryID":2639,"customerID":1,"returnDate":"2005-08-19T10:11:54.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":12250,"rentalDate":"2005-08-18T03:57:29.000000Z","inventoryID":921,"customerID":1,"returnDate":"2005-08-22T23:05:29.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":13068,"rentalDate":"2005-08-19T09:55:16.000000Z","inventoryID":3019,"customerID":1,"returnDate":"2005-08-20T14:44:16.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":13176,"rentalDate":"2005-08-19T13:56:54.000000Z","inventoryID":2269,"customerID":1,"returnDate":"2005-08-23T08:50:54.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":14762,"rentalDate":"2005-08-21T23:33:57.000000Z","inventoryID":4249,"customerID":1,"returnDate":"2005-08-23T01:30:57.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":14825,"rentalDate":"2005-08-22T01:27:57.000000Z","inventoryID":1449,"customerID":1,"returnDate":"2005-08-27T07:01:57.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":15298,"rentalDate":"2005-08-22T19:41:37.000000Z","inventoryID":1446,"customerID":1,"returnDate":"2005-08-28T22:49:37.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":15315,"rentalDate":"2005-08-22T20:03:46.000000Z","inventoryID":312,"customerID":1,"returnDate":"2005-08-30T01:51:46.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}]}, {"id":2,"name":"Patricia Johnson","address":"1121 Loja Avenue","zip code":"17886","phone":"838635286649","city":"San Bernardino","country":"United States","notes":"active","sid":1,"Rentals":[{"rentalID":320,"rentalDate":"2005-05-27T00:09:24.000000Z","inventoryID":1090,"customerID":2,"returnDate":"2005-05-28T04:30:24.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":2128,"rentalDate":"2005-06-17T20:54:58.000000Z","inventoryID":352,"customerID":2,"returnDate":"2005-06-24T00:41:58.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":5636,"rentalDate":"2005-07-10T06:31:24.000000Z","inventoryID":4116,"customerID":2,"returnDate":"2005-07-13T02:36:24.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":5755,"rentalDate":"2005-07-10T12:38:56.000000Z","inventoryID":2760,"customerID":2,"returnDate":"2005-07-19T17:02:56.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":7346,"rentalDate":"2005-07-27T14:30:42.000000Z","inventoryID":741,"customerID":2,"returnDate":"2005-08-02T16:48:42.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":7376,"rentalDate":"2005-07-27T15:23:02.000000Z","inventoryID":488,"customerID":2,"returnDate":"2005-08-04T10:35:02.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":7459,"rentalDate":"2005-07-27T18:40:20.000000Z","inventoryID":2053,"customerID":2,"returnDate":"2005-08-02T21:07:20.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":8230,"rentalDate":"2005-07-29T00:12:59.000000Z","inventoryID":1937,"customerID":2,"returnDate":"2005-08-06T19:52:59.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":8598,"rentalDate":"2005-07-29T12:56:59.000000Z","inventoryID":626,"customerID":2,"returnDate":"2005-08-01T08:39:59.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":8705,"rentalDate":"2005-07-29T17:14:29.000000Z","inventoryID":4038,"customerID":2,"returnDate":"2005-08-02T16:01:29.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":9031,"rentalDate":"2005-07-30T06:06:10.000000Z","inventoryID":2377,"customerID":2,"returnDate":"2005-08-04T10:45:10.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":9236,"rentalDate":"2005-07-30T13:47:43.000000Z","inventoryID":4030,"customerID":2,"returnDate":"2005-08-08T18:52:43.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":9248,"rentalDate":"2005-07-30T14:14:11.000000Z","inventoryID":1382,"customerID":2,"returnDate":"2005-08-05T11:19:11.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":9296,"rentalDate":"2005-07-30T16:21:13.000000Z","inventoryID":4088,"customerID":2,"returnDate":"2005-08-08T11:57:13.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":9465,"rentalDate":"2005-07-30T22:39:53.000000Z","inventoryID":3084,"customerID":2,"returnDate":"2005-08-06T16:43:53.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":10136,"rentalDate":"2005-07-31T21:58:56.000000Z","inventoryID":3142,"customerID":2,"returnDate":"2005-08-03T19:44:56.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":10466,"rentalDate":"2005-08-01T09:45:26.000000Z","inventoryID":138,"customerID":2,"returnDate":"2005-08-06T06:28:26.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":10918,"rentalDate":"2005-08-02T02:10:56.000000Z","inventoryID":3418,"customerID":2,"returnDate":"2005-08-02T21:23:56.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":11087,"rentalDate":"2005-08-02T07:41:41.000000Z","inventoryID":654,"customerID":2,"returnDate":"2005-08-10T10:37:41.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":11177,"rentalDate":"2005-08-02T10:43:48.000000Z","inventoryID":1149,"customerID":2,"returnDate":"2005-08-10T10:55:48.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":11256,"rentalDate":"2005-08-02T13:44:53.000000Z","inventoryID":2060,"customerID":2,"returnDate":"2005-08-04T16:39:53.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":11614,"rentalDate":"2005-08-17T03:52:18.000000Z","inventoryID":805,"customerID":2,"returnDate":"2005-08-20T07:04:18.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":12963,"rentalDate":"2005-08-19T06:26:04.000000Z","inventoryID":1521,"customerID":2,"returnDate":"2005-08-23T11:37:04.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":14475,"rentalDate":"2005-08-21T13:24:32.000000Z","inventoryID":3164,"customerID":2,"returnDate":"2005-08-27T08:59:32.000000Z","staffID":2,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":14743,"rentalDate":"2005-08-21T22:41:56.000000Z","inventoryID":4570,"customerID":2,"returnDate":"2005-08-29T00:18:56.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":15145,"rentalDate":"2005-08-22T13:53:04.000000Z","inventoryID":2179,"customerID":2,"returnDate":"2005-08-31T15:51:04.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}, {"rentalID":15907,"rentalDate":"2005-08-23T17:39:35.000000Z","inventoryID":2898,"customerID":2,"returnDate":"2005-08-25T23:23:35.000000Z","staffID":1,"lastUpdate":"2006-02-16T02:30:53.000000Z"}]}]`) + } +} + +func TestSelectJsonObject_EmptyResult(t *testing.T) { + t.Run("json obj", func(t *testing.T) { + stmt := SELECT_JSON_OBJ(Actor.AllColumns). + FROM(Actor). + WHERE(Actor.FirstName.EQ(Text("Kowalski"))) + + var dest model.Actor + + err := stmt.QueryContext(ctx, db, &dest) + require.ErrorIs(t, err, qrm.ErrNoRows) + }) + + t.Run("json arr", func(t *testing.T) { + stmt := SELECT_JSON_ARR(Actor.AllColumns). + FROM(Actor). + WHERE(Actor.FirstName.EQ(Text("Kowalski"))) + + var dest []model.Actor + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) + require.Empty(t, dest) + }) +} + +func TestSelectJson_InvalidDestination(t *testing.T) { + t.Run("json obj", func(t *testing.T) { + stmt := SELECT_JSON_OBJ(Actor.AllColumns). + FROM(Actor) + + testutils.AssertQueryPanicErr(t, stmt, db, &[]model.Actor{}, "jet: SELECT_JSON_OBJ destination has to be a pointer to struct or pointer to map[string]any") + testutils.AssertQueryPanicErr(t, stmt, db, model.Actor{}, "jet: SELECT_JSON_OBJ destination has to be a pointer to struct or pointer to map[string]any") + testutils.AssertQueryPanicErr(t, stmt, nil, &model.Actor{}, "jet: db is nil") + testutils.AssertQueryPanicErr(t, stmt, db, nil, "jet: destination is nil") + }) + + t.Run("json arr", func(t *testing.T) { + stmt := SELECT_JSON_ARR(Actor.AllColumns). + FROM(Actor) + + testutils.AssertQueryPanicErr(t, stmt, db, &model.Actor{}, "jet: SELECT_JSON_ARR destination has to be a pointer to slice of struct or pointer to []map[string]any") + testutils.AssertQueryPanicErr(t, stmt, db, []model.Actor{}, "jet: SELECT_JSON_ARR destination has to be a pointer to slice of struct or pointer to []map[string]any") + testutils.AssertQueryPanicErr(t, stmt, nil, &[]model.Actor{}, "jet: db is nil") + testutils.AssertQueryPanicErr(t, stmt, db, nil, "jet: destination is nil") + }) +} + +func TestSelectJson_ProjectionNotAliased(t *testing.T) { + t.Run("statement not aliased", func(t *testing.T) { + testutils.AssertPanicErr(t, func() { + stmt := SELECT_JSON_ARR( + Customer.AllColumns, + + SELECT_JSON_ARR(Rental.AllColumns). + FROM(Rental). + WHERE(Rental.CustomerID.EQ(Customer.CustomerID)), + ).FROM(Customer) + + stmt.DebugSql() + + }, "jet: SELECT JSON statements need to be aliased when used as a projection.") + }) + + t.Run("expression not aliased", func(t *testing.T) { + testutils.AssertPanicErr(t, func() { + stmt := SELECT_JSON_ARR( + Int(2).ADD(Customer.CustomerID), + ).FROM(Customer) + + stmt.DebugSql() + + }, "jet: expression need to be aliased when used as SELECT JSON projection.") + }) +} + +func TestSelectJson_InvalidJson(t *testing.T) { + stmt := SELECT( + Bytea("}invalid json {()").AS("invalid_json"), + ) + + var dest struct { + InvalidJson []byte `json_column:"invalid_json"` + } + err := stmt.QueryContext(ctx, db, &dest) + require.ErrorContains(t, err, "invalid json") +} diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index d39987d..3716a70 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -21,36 +21,36 @@ import ( ) func TestSelect_ScanToStruct(t *testing.T) { - expectedSQL := ` + + t.Run("standard", func(t *testing.T) { + stmt := SELECT(Actor.AllColumns). + DISTINCT(). + FROM(Actor). + WHERE(Actor.ActorID.EQ(Int(2))) + + testutils.AssertDebugStatementSql(t, stmt, ` SELECT DISTINCT actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor WHERE actor.actor_id = 2; -` +`, int64(2)) - query := SELECT(Actor.AllColumns). - DISTINCT(). - FROM(Actor). - WHERE(Actor.ActorID.EQ(Int(2))) + var dest model.Actor + err := stmt.Query(db, &dest) - testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(2)) + require.NoError(t, err) + testutils.AssertDeepEqual(t, dest, actor2) + requireLogged(t, stmt) + }) +} - actor := model.Actor{} - err := query.Query(db, &actor) - - require.NoError(t, err) - - expectedActor := model.Actor{ - ActorID: 2, - FirstName: "Nick", - LastName: "Wahlberg", - LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:47:57.62", 2), - } - - testutils.AssertDeepEqual(t, actor, expectedActor) - requireLogged(t, query) +var actor2 = model.Actor{ + ActorID: 2, + FirstName: "Nick", + LastName: "Wahlberg", + LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:47:57.62", 2), } func TestSelectDistinctOn(t *testing.T) { @@ -85,7 +85,6 @@ ORDER BY rental.staff_id ASC, rental.customer_id ASC, rental.rental_id ASC; err := stmt.Query(db, &dest) require.NoError(t, err) - testutils.AssertJSON(t, dest, ` [ { @@ -187,6 +186,21 @@ ORDER BY customer.customer_id ASC; testutils.AssertDeepEqual(t, lastCustomer, customers[598]) requireLogged(t, query) + + t.Run("select json", func(t *testing.T) { + stmt := SELECT_JSON_ARR( + Customer.AllColumns, + ).FROM( + Customer, + ).ORDER_BY(Customer.CustomerID.ASC()) + + var dest []model.Customer + + err := stmt.QueryContext(ctx, db, &dest) + require.NoError(t, err) + + testutils.AssertDeepEqual(t, customers, dest) + }) } func TestSelectAndUnionInProjection(t *testing.T) { @@ -217,15 +231,14 @@ FROM dvds.payment LIMIT 12; ` - query := Payment. - SELECT( - Payment.PaymentID, - Customer.SELECT(Customer.CustomerID).LIMIT(1), - UNION( - Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10), - Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2), - ).LIMIT(1), - ). + query := SELECT( + Payment.PaymentID, + Customer.SELECT(Customer.CustomerID).LIMIT(1), + UNION( + Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10), + Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2), + ).LIMIT(1), + ).FROM(Payment). LIMIT(12) //fmt.Println(query.DebugSql()) @@ -2771,7 +2784,8 @@ ORDER BY actor.actor_id ASC, film.film_id ASC; err := stmt.Query(db, &dest) require.NoError(t, err) - //jsonSave("./testdata/quick-start-dest.json", dest) + //testutils.SaveJSONFile(dest, "./testdata/results/postgres/quick-start-dest.json") + testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/quick-start-dest.json") var dest2 []struct { @@ -2784,7 +2798,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC; err = stmt.Query(db, &dest2) require.NoError(t, err) - //jsonSave("./testdata/quick-start-dest2.json", dest2) + //testutils.SaveJSONFile(dest2, "./testdata/results/postgres/quick-start-dest2.json") testutils.AssertJSONFile(t, dest2, "./testdata/results/postgres/quick-start-dest2.json") } @@ -2966,7 +2980,7 @@ WHERE payment.payment_id < $1 WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id) ORDER BY payment.customer_id; ` - query := Payment.SELECT( + query := SELECT( AVG(Payment.Amount).OVER(), AVG(Payment.Amount).OVER(Window("w1")), AVG(Payment.Amount).OVER( @@ -2976,6 +2990,7 @@ ORDER BY payment.customer_id; ), AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), ). + FROM(Payment). WHERE(Payment.PaymentID.LT(Int(10))). WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)). WINDOW("w2").AS(Window("w1")). diff --git a/tests/sqlite/alltypes_test.go b/tests/sqlite/alltypes_test.go index 080e870..b676bb4 100644 --- a/tests/sqlite/alltypes_test.go +++ b/tests/sqlite/alltypes_test.go @@ -1,6 +1,7 @@ package sqlite import ( + "encoding/hex" "github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/utils/ptr" . "github.com/go-jet/jet/v2/sqlite" @@ -18,7 +19,7 @@ import ( func TestAllTypes(t *testing.T) { - dest := []model.AllTypes{} + var dest []model.AllTypes err := SELECT(AllTypes.AllColumns). FROM(AllTypes). @@ -571,12 +572,102 @@ func TestStringOperators(t *testing.T) { SUBSTR(AllTypes.CharPtr, Int(3), Int(2)), ).FROM(AllTypes) - dest := []struct{}{} + var dest []struct{} err := query.Query(sampleDB, &dest) require.NoError(t, err) } +func TestBlob(t *testing.T) { + + var sampleBlob = Blob([]byte{11, 0, 22, 33, 44}) + var textBlob = Blob([]byte("text blob")) + + stmt := SELECT( + AllTypes.Blob.EQ(sampleBlob), + AllTypes.Blob.EQ(AllTypes.BlobPtr), + AllTypes.Blob.NOT_EQ(sampleBlob), + AllTypes.Blob.GT(textBlob), + AllTypes.Blob.GT_EQ(AllTypes.BlobPtr), + AllTypes.Blob.LT(AllTypes.BlobPtr), + AllTypes.Blob.LT_EQ(sampleBlob), + AllTypes.Blob.BETWEEN(Blob([]byte("min")), Blob([]byte("max"))), + AllTypes.Blob.NOT_BETWEEN(AllTypes.Blob, AllTypes.BlobPtr), + AllTypes.Blob.CONCAT(textBlob), + AllTypes.Blob.LIKE(AllTypes.BlobPtr), + AllTypes.Blob.NOT_LIKE(sampleBlob), + + RTRIM(AllTypes.BlobPtr, sampleBlob), + LTRIM(sampleBlob, textBlob), + LENGTH(sampleBlob), + OCTET_LENGTH(textBlob), + SUBSTR(AllTypes.Blob, Int(0), Int(2)), + + HEX(AllTypes.Blob), + UNHEX(AllTypes.Text), + ).FROM( + AllTypes, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT all_types.blob = X'0b0016212c', + all_types.blob = all_types.blob_ptr, + all_types.blob != X'0b0016212c', + all_types.blob > X'7465787420626c6f62', + all_types.blob >= all_types.blob_ptr, + all_types.blob < all_types.blob_ptr, + all_types.blob <= X'0b0016212c', + all_types.blob BETWEEN X'6d696e' AND X'6d6178', + all_types.blob NOT BETWEEN all_types.blob AND all_types.blob_ptr, + all_types.blob || X'7465787420626c6f62', + all_types.blob LIKE all_types.blob_ptr, + all_types.blob NOT LIKE X'0b0016212c', + RTRIM(all_types.blob_ptr, X'0b0016212c'), + LTRIM(X'0b0016212c', X'7465787420626c6f62'), + LENGTH(X'0b0016212c'), + OCTET_LENGTH(X'7465787420626c6f62'), + SUBSTR(all_types.blob, 0, 2), + HEX(all_types.blob), + UNHEX(all_types.text) +FROM all_types; +`) + + var dest []struct{} + err := stmt.Query(sampleDB, &dest) + + require.NoError(t, err) +} + +func TestBlobConversion(t *testing.T) { + + nonPrintable := []byte{0x11, 0x22, 0x33, 0x44, 0x55} + printable := []byte("this is blob") + + stmt := SELECT( + Blob(nonPrintable).AS("non_printable"), + Blob(printable).AS("printable"), + + HEX(Blob(nonPrintable)).AS("non_printable_hex"), + UNHEX(String("1122334455")).AS("non_printable_unhex"), + ) + + var dest struct { + NonPrintable []byte + Printable []byte + + NonPrintableHex string + NonPrintableUnHex []byte + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + require.Equal(t, dest.NonPrintable, nonPrintable) + require.Equal(t, dest.Printable, printable) + require.Equal(t, dest.NonPrintableHex, hex.EncodeToString(nonPrintable)) + require.Equal(t, dest.NonPrintableUnHex, nonPrintable) +} + func TestReservedWord(t *testing.T) { stmt := SELECT(ReservedWords.AllColumns). FROM(ReservedWords) diff --git a/tests/sqlite/sample_test.go b/tests/sqlite/sample_test.go index 0655f1f..2740f8b 100644 --- a/tests/sqlite/sample_test.go +++ b/tests/sqlite/sample_test.go @@ -2,8 +2,8 @@ package sqlite import ( "github.com/go-jet/jet/v2/internal/testutils" - "github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/internal/utils/ptr" + "github.com/go-jet/jet/v2/qrm" "github.com/stretchr/testify/require" "testing" diff --git a/tests/testdata b/tests/testdata index 1c501ac..b0ff9d7 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 1c501acb72bea389788404988ef0130b733f9cee +Subproject commit b0ff9d75f2b829a83706b485fba4aaf8563860e2