Merge pull request #457 from go-jet/select_json

Add support for SELECT_JSON statements
This commit is contained in:
go-jet 2025-03-09 18:43:49 +01:00 committed by GitHub
commit 1f3215c879
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
104 changed files with 5249 additions and 900 deletions

View file

@ -11,7 +11,7 @@ jobs:
- image: cimg/go:1.22.8
# Please keep the version in sync with test/docker-compose.yaml
- image: cimg/postgres:14.10
- image: cimg/postgres:14.1
environment:
POSTGRES_USER: jet
POSTGRES_PASSWORD: jet
@ -19,7 +19,7 @@ jobs:
PGPORT: 50901
# Please keep the version in sync with test/docker-compose.yaml
- image: circleci/mysql:8.0.27
- image: cimg/mysql:8.0.27
command: [ --default-authentication-plugin=mysql_native_password ]
environment:
MYSQL_ROOT_PASSWORD: jet
@ -29,7 +29,7 @@ jobs:
MYSQL_TCP_PORT: 50902
# Please keep the version in sync with test/docker-compose.yaml
- image: circleci/mariadb:10.3
- image: cimg/mariadb:11.4
command: [ '--default-authentication-plugin=mysql_native_password', '--port=50903' ]
environment:
MYSQL_ROOT_PASSWORD: jet
@ -116,25 +116,27 @@ jobs:
name: Create MySQL/MariaDB user and test databases
command: |
mysql -h 127.0.0.1 -P 50902 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';"
mysql -h 127.0.0.1 -P 50902 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -P 50902 -u root -pjet -e "set global sql_mode = 'STRICT_ALL_TABLES,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -P 50902 -u jet -pjet -e "create database test_sample"
mysql -h 127.0.0.1 -P 50902 -u jet -pjet -e "create database dvds2"
mysql -h 127.0.0.1 -P 50903 -u root -pjet -e "grant all privileges on *.* to 'jet'@'%';"
mysql -h 127.0.0.1 -P 50903 -u root -pjet -e "set global sql_mode = 'STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -P 50903 -u root -pjet -e "set global sql_mode = 'STRICT_ALL_TABLES,STRICT_TRANS_TABLES,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';"
mysql -h 127.0.0.1 -P 50903 -u jet -pjet -e "create database test_sample"
mysql -h 127.0.0.1 -P 50903 -u jet -pjet -e "create database dvds2"
- run:
name: Init databases
command: |
cd tests
go run ./init/init.go -testsuite all
- run:
name: Install gotestsum
command: go install gotest.tools/gotestsum@latest
- run:
name: Init databases (postgres, mysql, sqlite) and generate jet files
command: |
cd tests
go run ./init/init.go -testsuite postgres
go run ./init/init.go -testsuite mysql
go run ./init/init.go -testsuite sqlite
# to create test results report
- run: mkdir -p $TEST_RESULTS
@ -146,14 +148,14 @@ jobs:
name: Running tests with statement caching enabled
command: JET_TESTS_WITH_STMT_CACHE=true go test -tags postgres -v ./tests/...
# run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests
- run:
name: Jet generate mariadb and cockroachdb
name: Init databases (mariadb, cockroachdb) and generate jet files
command: |
cd tests
make jet-gen-mariadb
make jet-gen-cockroach
go run ./init/init.go -testsuite mariadb
go run ./init/init.go -testsuite cockroach
# run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests
- run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/
- run: PG_SOURCE=COCKROACH_DB go test -v ./tests/postgres/

View file

@ -579,5 +579,5 @@ To run the tests, additional dependencies are required:
## License
Copyright 2019-2024 Goran Bjelanovic
Copyright 2019-2025 Goran Bjelanovic
Licensed under the Apache License, Version 2.0.

View file

@ -18,8 +18,8 @@ func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTyp
SELECT
t.table_name as "table.name",
col.COLUMN_NAME AS "column.Name",
col.COLUMN_DEFAULT IS NOT NULL AND t.table_type != 'VIEW' as "column.HasDefault",
col.IS_NULLABLE = "YES" AS "column.IsNullable",
(col.COLUMN_DEFAULT IS NOT NULL AND col.COLUMN_DEFAULT != 'NULL') AND t.table_type != 'VIEW' as "column.HasDefault",
col.IS_NULLABLE = 'YES' AS "column.IsNullable",
col.COLUMN_COMMENT AS "column.Comment",
COALESCE(pk.IsPrimaryKey, 0) AS "column.IsPrimaryKey",
IF (col.COLUMN_TYPE = 'tinyint(1)',

View file

@ -180,11 +180,15 @@ func getSqlBuilderColumnType(columnMetaData metadata.Column) string {
return "Timez"
case "interval":
return "Interval"
case "user-defined", "enum", "text", "character", "character varying", "bytea", "uuid",
case "user-defined", "enum", "text", "character", "character varying", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY",
"char", "varchar", "nvarchar", "binary", "varbinary", "bpchar", "varbit",
"tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL
"char", "varchar", "nvarchar", "bpchar", "varbit",
"tinytext", "mediumtext", "longtext": // MySQL
return "String"
case "bytea": // postgres
return "Bytea"
case "binary", "varbinary", "tinyblob", "mediumblob", "longblob", "blob": // mysql and sqlite
return "Blob"
case "real", "numeric", "decimal", "double precision", "float", "float4", "float8",
"double": // MySQL
return "Float"

2
go.mod
View file

@ -1,6 +1,6 @@
module github.com/go-jet/jet/v2
go 1.21
go 1.22
// used by jet generator
require (

View file

@ -40,14 +40,23 @@ func snakeToCamel(s string, upperCase bool) string {
if upperCase || i > 0 {
result += camelizeWord(word, len(words) > 1)
} else {
result += word
} else { // lowerCase and i == 0
result += toLowerFirstLetter(word)
}
}
return result
}
func toLowerFirstLetter(s string) string {
if s == "" {
return s
}
runes := []rune(s)
runes[0] = unicode.ToLower(runes[0])
return string(runes)
}
func camelizeWord(word string, force bool) string {
runes := []rune(word)

View file

@ -7,7 +7,10 @@ import (
func TestSnakeToCamel(t *testing.T) {
require.Equal(t, SnakeToCamel(""), "")
require.Equal(t, SnakeToCamel("_", false), "")
require.Equal(t, SnakeToCamel("potato_"), "Potato")
require.Equal(t, SnakeToCamel("potato_", false), "potato")
require.Equal(t, SnakeToCamel("Potato_", false), "potato")
require.Equal(t, SnakeToCamel("this_has_to_be_uppercased"), "ThisHasToBeUppercased")
require.Equal(t, SnakeToCamel("this_is_an_id"), "ThisIsAnID")
require.Equal(t, SnakeToCamel("this_is_an_identifier"), "ThisIsAnIdentifier")

View file

@ -18,10 +18,45 @@ func (a *alias) fromImpl(subQuery SelectTable) Projection {
// Generated columns have default aliasing.
tableName, columnName := extractTableAndColumnName(a.alias)
column := NewColumnImpl(columnName, tableName, nil)
column.subQuery = subQuery
newDummyColumn := newDummyColumnForExpression(a.expression, columnName)
newDummyColumn.setTableName(tableName)
newDummyColumn.setSubQuery(subQuery)
return &column
return newDummyColumn
}
// This function is used to create dummy columns when exporting sub-query columns using subQuery.AllColumns()
// In most case we don't care about type of the column, except when sub-query columns are used as SELECT_JSON projection.
// We need to know type to encode value for json unmarshal. At the moment only bool, time and blob columns are of interest,
// so we don't have to support every column type.
func newDummyColumnForExpression(exp Expression, name string) ColumnExpression {
switch exp.(type) {
case BoolExpression:
return BoolColumn(name)
case IntegerExpression:
return IntegerColumn(name)
case FloatExpression:
return FloatColumn(name)
case BlobExpression:
return BlobColumn(name)
case DateExpression:
return DateColumn(name)
case TimeExpression:
return TimeColumn(name)
case TimezExpression:
return TimezColumn(name)
case TimestampExpression:
return TimestampColumn(name)
case TimestampzExpression:
return TimestampzColumn(name)
case IntervalExpression:
return IntervalColumn(name)
case StringExpression:
return StringColumn(name)
}
return StringColumn(name)
}
func (a *alias) serializeForProjection(statement StatementType, out *SQLBuilder) {
@ -30,3 +65,15 @@ func (a *alias) serializeForProjection(statement StatementType, out *SQLBuilder)
out.WriteString("AS")
out.WriteAlias(a.alias)
}
func (a *alias) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) {
out.WriteJsonObjKey(a.alias)
a.expression.serializeForJsonValue(statement, out)
}
func (a *alias) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) {
a.expression.serializeForJsonValue(statement, out)
out.WriteString("AS")
out.WriteAlias(a.alias)
}

View file

@ -0,0 +1,104 @@
package jet
// BlobExpression interface
type BlobExpression interface {
Expression
isStringOrBlob()
EQ(rhs BlobExpression) BoolExpression
NOT_EQ(rhs BlobExpression) BoolExpression
IS_DISTINCT_FROM(rhs BlobExpression) BoolExpression
IS_NOT_DISTINCT_FROM(rhs BlobExpression) BoolExpression
LT(rhs BlobExpression) BoolExpression
LT_EQ(rhs BlobExpression) BoolExpression
GT(rhs BlobExpression) BoolExpression
GT_EQ(rhs BlobExpression) BoolExpression
BETWEEN(min, max BlobExpression) BoolExpression
NOT_BETWEEN(min, max BlobExpression) BoolExpression
CONCAT(rhs BlobExpression) BlobExpression
LIKE(pattern BlobExpression) BoolExpression
NOT_LIKE(pattern BlobExpression) BoolExpression
}
type blobInterfaceImpl struct {
parent BlobExpression
}
func (b *blobInterfaceImpl) isStringOrBlob() {}
func (b *blobInterfaceImpl) EQ(rhs BlobExpression) BoolExpression {
return Eq(b.parent, rhs)
}
func (b *blobInterfaceImpl) NOT_EQ(rhs BlobExpression) BoolExpression {
return NotEq(b.parent, rhs)
}
func (b *blobInterfaceImpl) IS_DISTINCT_FROM(rhs BlobExpression) BoolExpression {
return IsDistinctFrom(b.parent, rhs)
}
func (b *blobInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs BlobExpression) BoolExpression {
return IsNotDistinctFrom(b.parent, rhs)
}
func (b *blobInterfaceImpl) GT(rhs BlobExpression) BoolExpression {
return Gt(b.parent, rhs)
}
func (b *blobInterfaceImpl) GT_EQ(rhs BlobExpression) BoolExpression {
return GtEq(b.parent, rhs)
}
func (b *blobInterfaceImpl) LT(rhs BlobExpression) BoolExpression {
return Lt(b.parent, rhs)
}
func (b *blobInterfaceImpl) LT_EQ(rhs BlobExpression) BoolExpression {
return LtEq(b.parent, rhs)
}
func (b *blobInterfaceImpl) BETWEEN(min, max BlobExpression) BoolExpression {
return NewBetweenOperatorExpression(b.parent, min, max, false)
}
func (b *blobInterfaceImpl) NOT_BETWEEN(min, max BlobExpression) BoolExpression {
return NewBetweenOperatorExpression(b.parent, min, max, true)
}
func (b *blobInterfaceImpl) CONCAT(rhs BlobExpression) BlobExpression {
return BlobExp(newBinaryStringOperatorExpression(b.parent, rhs, StringConcatOperator))
}
func (b *blobInterfaceImpl) LIKE(pattern BlobExpression) BoolExpression {
return newBinaryBoolOperatorExpression(b.parent, pattern, "LIKE")
}
func (b *blobInterfaceImpl) NOT_LIKE(pattern BlobExpression) BoolExpression {
return newBinaryBoolOperatorExpression(b.parent, pattern, "NOT LIKE")
}
//---------------------------------------------------//
type blobExpressionWrapper struct {
Expression
blobInterfaceImpl
}
func newBlobExpressionWrap(expression Expression) BlobExpression {
blobExpressionWrap := &blobExpressionWrapper{Expression: expression}
blobExpressionWrap.blobInterfaceImpl.parent = blobExpressionWrap
expression.setParent(blobExpressionWrap)
return blobExpressionWrap
}
// BlobExp is blob expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as blob expression.
// Does not add sql cast to generated sql builder output.
func BlobExp(expression Expression) BlobExpression {
return newBlobExpressionWrap(expression)
}

View file

@ -102,9 +102,10 @@ type boolExpressionWrapper struct {
}
func newBoolExpressionWrap(expression Expression) BoolExpression {
boolExpressionWrap := boolExpressionWrapper{Expression: expression}
boolExpressionWrap.boolInterfaceImpl.parent = &boolExpressionWrap
return &boolExpressionWrap
boolExpressionWrap := &boolExpressionWrapper{Expression: expression}
boolExpressionWrap.boolInterfaceImpl.parent = boolExpressionWrap
expression.setParent(boolExpressionWrap)
return boolExpressionWrap
}
// BoolExp is bool expression wrapper around arbitrary expression.

View file

@ -41,6 +41,8 @@ type ClauseSelect struct {
DistinctOnColumns []ColumnExpression
ProjectionList []Projection
IsForRowToJson bool
// MySQL only
OptimizerHints optimizerHints
}
@ -52,6 +54,10 @@ func (s *ClauseSelect) Projections() ProjectionList {
// Serialize serializes clause into SQLBuilder
func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(s.ProjectionList) == 0 {
panic("jet: SELECT clause has to have at least one projection")
}
out.NewLine()
out.WriteString("SELECT")
s.OptimizerHints.Serialize(statementType, out, options...)
@ -66,11 +72,13 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o
out.WriteByte(')')
}
if len(s.ProjectionList) == 0 {
panic("jet: SELECT clause has to have at least one projection")
if s.IsForRowToJson {
out.IncreaseIdent()
out.WriteRowToJsonProjections(statementType, s.ProjectionList)
out.DecreaseIdent()
} else {
out.WriteProjections(statementType, s.ProjectionList)
}
out.WriteProjections(statementType, s.ProjectionList)
}
// ClauseFrom struct

View file

@ -2,6 +2,10 @@
package jet
import (
"github.com/go-jet/jet/v2/internal/3rdparty/snaker"
)
// Column is common column interface for all types of columns.
type Column interface {
Name() string
@ -35,19 +39,19 @@ type ColumnExpressionImpl struct {
}
// NewColumnImpl creates new ColumnExpressionImpl
func NewColumnImpl(name string, tableName string, parent ColumnExpression) ColumnExpressionImpl {
bc := ColumnExpressionImpl{
func NewColumnImpl(name string, tableName string, parent ColumnExpression) *ColumnExpressionImpl {
newColumn := &ColumnExpressionImpl{
name: name,
tableName: tableName,
}
if parent != nil {
bc.ExpressionInterfaceImpl.Parent = parent
newColumn.ExpressionInterfaceImpl.Parent = parent
} else {
bc.ExpressionInterfaceImpl.Parent = &bc
newColumn.ExpressionInterfaceImpl.Parent = newColumn
}
return bc
return newColumn
}
// Name returns name of the column
@ -76,13 +80,6 @@ func (c *ColumnExpressionImpl) defaultAlias() string {
return c.name
}
func (c *ColumnExpressionImpl) fromImpl(subQuery SelectTable) Projection {
newColumn := NewColumnImpl(c.name, c.tableName, nil)
newColumn.setSubQuery(subQuery)
return &newColumn
}
func (c *ColumnExpressionImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) {
if statement == SetStatementType {
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
@ -93,14 +90,28 @@ func (c *ColumnExpressionImpl) serializeForOrderBy(statement StatementType, out
c.serialize(statement, out)
}
func (c ColumnExpressionImpl) serializeForProjection(statement StatementType, out *SQLBuilder) {
func (c *ColumnExpressionImpl) serializeForProjection(statement StatementType, out *SQLBuilder) {
c.serialize(statement, out)
out.WriteString("AS")
out.WriteAlias(c.defaultAlias())
}
func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
func (c *ColumnExpressionImpl) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) {
out.WriteJsonObjKey(snaker.SnakeToCamel(c.name, false))
c.Parent.serializeForJsonValue(statement, out)
}
func (c *ColumnExpressionImpl) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) {
c.Parent.serializeForJsonValue(statement, out)
out.WriteString("AS")
out.WriteAlias(snaker.SnakeToCamel(c.name, false))
}
func (c *ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if c.subQuery != nil {
out.WriteIdentifier(c.subQuery.Alias())

View file

@ -78,6 +78,18 @@ func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBui
SerializeProjectionList(statement, projections, out)
}
func (cl ColumnList) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) {
projections := ColumnListToProjectionList(cl)
SerializeProjectionListJsonObj(statement, projections, out)
}
func (cl ColumnList) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) {
projections := ColumnListToProjectionList(cl)
out.WriteRowToJsonProjections(statement, projections)
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface

View file

@ -4,11 +4,10 @@ import "testing"
func TestColumn(t *testing.T) {
column := NewColumnImpl("col", "", nil)
column.ExpressionInterfaceImpl.Parent = &column
assertClauseSerialize(t, column, "col")
column.setTableName("table1")
assertClauseSerialize(t, column, "table1.col")
assertProjectionSerialize(t, &column, `table1.col AS "table1.col"`)
assertProjectionSerialize(t, column, `table1.col AS "table1.col"`)
assertProjectionSerialize(t, column.AS("alias1"), `table1.col AS "alias1"`)
}

View file

@ -11,7 +11,11 @@ type ColumnBool interface {
type boolColumnImpl struct {
boolInterfaceImpl
ColumnExpressionImpl
*ColumnExpressionImpl
}
func (i *boolColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
}
func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
@ -51,7 +55,11 @@ type ColumnFloat interface {
type floatColumnImpl struct {
floatInterfaceImpl
ColumnExpressionImpl
*ColumnExpressionImpl
}
func (i *floatColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
}
func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
@ -92,7 +100,11 @@ type ColumnInteger interface {
type integerColumnImpl struct {
integerInterfaceImpl
ColumnExpressionImpl
*ColumnExpressionImpl
}
func (i *integerColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
}
func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
@ -122,7 +134,7 @@ func IntegerColumn(name string) ColumnInteger {
//------------------------------------------------------//
// ColumnString is interface for SQL text, character, character varying
// bytea, uuid columns and enums types.
// uuid columns and enums types.
type ColumnString interface {
StringExpression
Column
@ -134,7 +146,11 @@ type ColumnString interface {
type stringColumnImpl struct {
stringInterfaceImpl
ColumnExpressionImpl
*ColumnExpressionImpl
}
func (i *stringColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
}
func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
@ -163,6 +179,51 @@ func StringColumn(name string) ColumnString {
//------------------------------------------------------//
// ColumnBlob is interface for binary data types (bytea, binary, blob, etc...)
type ColumnBlob interface {
BlobExpression
Column
From(subQuery SelectTable) ColumnBlob
SET(blob BlobExpression) ColumnAssigment
}
type blobColumnImpl struct {
blobInterfaceImpl
*ColumnExpressionImpl
}
func (i *blobColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
}
func (i *blobColumnImpl) From(subQuery SelectTable) ColumnBlob {
newBlobColumn := BlobColumn(i.name)
newBlobColumn.setTableName(i.tableName)
newBlobColumn.setSubQuery(subQuery)
return newBlobColumn
}
func (i *blobColumnImpl) SET(blobExp BlobExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: blobExp,
}
}
// BlobColumn creates named blob column.
func BlobColumn(name string) ColumnBlob {
blobColumn := &blobColumnImpl{}
blobColumn.blobInterfaceImpl.parent = blobColumn
blobColumn.ColumnExpressionImpl = NewColumnImpl(name, "", blobColumn)
return blobColumn
}
//------------------------------------------------------//
// ColumnTime is interface for SQL time column.
type ColumnTime interface {
TimeExpression
@ -174,7 +235,11 @@ type ColumnTime interface {
type timeColumnImpl struct {
timeInterfaceImpl
ColumnExpressionImpl
*ColumnExpressionImpl
}
func (i *timeColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
}
func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
@ -213,7 +278,11 @@ type ColumnTimez interface {
type timezColumnImpl struct {
timezInterfaceImpl
ColumnExpressionImpl
*ColumnExpressionImpl
}
func (i *timezColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
}
func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
@ -253,7 +322,11 @@ type ColumnTimestamp interface {
type timestampColumnImpl struct {
timestampInterfaceImpl
ColumnExpressionImpl
*ColumnExpressionImpl
}
func (i *timestampColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
}
func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
@ -293,7 +366,11 @@ type ColumnTimestampz interface {
type timestampzColumnImpl struct {
timestampzInterfaceImpl
ColumnExpressionImpl
*ColumnExpressionImpl
}
func (i *timestampzColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
}
func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
@ -333,7 +410,11 @@ type ColumnDate interface {
type dateColumnImpl struct {
dateInterfaceImpl
ColumnExpressionImpl
*ColumnExpressionImpl
}
func (i *dateColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
}
func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
@ -361,6 +442,51 @@ func DateColumn(name string) ColumnDate {
//------------------------------------------------------//
// ColumnInterval is interface of PostgreSQL interval columns.
type ColumnInterval interface {
IntervalExpression
Column
From(subQuery SelectTable) ColumnInterval
SET(intervalExp IntervalExpression) ColumnAssigment
}
//------------------------------------------------------//
type intervalColumnImpl struct {
*ColumnExpressionImpl
intervalInterfaceImpl
}
func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: intervalExp,
}
}
func (i *intervalColumnImpl) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
}
func (i *intervalColumnImpl) From(subQuery SelectTable) ColumnInterval {
newIntervalColumn := IntervalColumn(i.name)
newIntervalColumn.setTableName(i.tableName)
newIntervalColumn.setSubQuery(subQuery)
return newIntervalColumn
}
// IntervalColumn creates named interval column.
func IntervalColumn(name string) ColumnInterval {
intervalColumn := &intervalColumnImpl{}
intervalColumn.ColumnExpressionImpl = NewColumnImpl(name, "", intervalColumn)
intervalColumn.intervalInterfaceImpl.parent = intervalColumn
return intervalColumn
}
//------------------------------------------------------//
// ColumnRange is interface for range columns which can be int range, string range
// timestamp range or date range.
type ColumnRange[T Expression] interface {
@ -373,7 +499,11 @@ type ColumnRange[T Expression] interface {
type rangeColumnImpl[T Expression] struct {
rangeInterfaceImpl[T]
ColumnExpressionImpl
*ColumnExpressionImpl
}
func (i *rangeColumnImpl[T]) fromImpl(subQuery SelectTable) Projection {
return i.From(subQuery)
}
func (i *rangeColumnImpl[T]) From(subQuery SelectTable) ColumnRange[T] {

View file

@ -80,9 +80,10 @@ type dateExpressionWrapper struct {
}
func newDateExpressionWrap(expression Expression) DateExpression {
dateExpressionWrap := dateExpressionWrapper{Expression: expression}
dateExpressionWrap.dateInterfaceImpl.parent = &dateExpressionWrap
return &dateExpressionWrap
dateExpressionWrap := &dateExpressionWrapper{Expression: expression}
dateExpressionWrap.dateInterfaceImpl.parent = dateExpressionWrap
expression.setParent(dateExpressionWrap)
return dateExpressionWrap
}
// DateExp is date expression wrapper around arbitrary expression.

View file

@ -1,13 +0,0 @@
package jet
import (
"testing"
)
func TestDateArithmetic(t *testing.T) {
timestamp := Timestamp(2000, 1, 1, 0, 0, 0)
assertClauseDebugSerialize(t, table1ColDate.ADD(NewInterval(String("1 HOUR"))).EQ(timestamp),
"((table1.col_date + INTERVAL '1 HOUR') = '2000-01-01 00:00:00')")
assertClauseDebugSerialize(t, table1ColDate.SUB(NewInterval(String("1 HOUR"))).EQ(timestamp),
"((table1.col_date - INTERVAL '1 HOUR') = '2000-01-01 00:00:00')")
}

View file

@ -1,6 +1,8 @@
package jet
import "strings"
import (
"strings"
)
// Dialect interface
type Dialect interface {
@ -11,9 +13,11 @@ type Dialect interface {
AliasQuoteChar() byte
IdentifierQuoteChar() byte
ArgumentPlaceholder() QueryPlaceholderFunc
ArgumentToString(value any) (string, bool)
IsReservedWord(name string) bool
SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
ValuesDefaultColumnName(index int) string
JsonValueEncode(expr Expression) Expression
}
// SerializerFunc func
@ -34,9 +38,11 @@ type DialectParams struct {
AliasQuoteChar byte
IdentifierQuoteChar byte
ArgumentPlaceholder QueryPlaceholderFunc
ArgumentToString func(value any) (string, bool)
ReservedWords []string
SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
ValuesDefaultColumnName func(index int) string
JsonValueEncode func(expr Expression) Expression
}
// NewDialect creates new dialect with params
@ -49,9 +55,11 @@ func NewDialect(params DialectParams) Dialect {
aliasQuoteChar: params.AliasQuoteChar,
identifierQuoteChar: params.IdentifierQuoteChar,
argumentPlaceholder: params.ArgumentPlaceholder,
argumentToString: params.ArgumentToString,
reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords),
serializeOrderBy: params.SerializeOrderBy,
valuesDefaultColumnName: params.ValuesDefaultColumnName,
jsonValueEncode: params.JsonValueEncode,
}
}
@ -63,9 +71,11 @@ type dialectImpl struct {
aliasQuoteChar byte
identifierQuoteChar byte
argumentPlaceholder QueryPlaceholderFunc
argumentToString func(value any) (string, bool)
reservedWords map[string]bool
serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
valuesDefaultColumnName func(index int) string
jsonValueEncode func(expr Expression) Expression
}
func (d *dialectImpl) Name() string {
@ -102,6 +112,10 @@ func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc {
return d.argumentPlaceholder
}
func (d *dialectImpl) ArgumentToString(value any) (string, bool) {
return d.argumentToString(value)
}
func (d *dialectImpl) IsReservedWord(name string) bool {
_, isReservedWord := d.reservedWords[strings.ToLower(name)]
return isReservedWord
@ -115,6 +129,10 @@ func (d *dialectImpl) ValuesDefaultColumnName(index int) string {
return d.valuesDefaultColumnName(index)
}
func (d *dialectImpl) JsonValueEncode(expr Expression) Expression {
return d.jsonValueEncode(expr)
}
func arrayOfStringsToMapOfStrings(arr []string) map[string]bool {
ret := map[string]bool{}
for _, elem := range arr {

View file

@ -2,7 +2,7 @@ package jet
import "fmt"
// Expression is common interface for all expressions.
// Expression is a common interface for all expressions.
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
type Expression interface {
Serializer
@ -10,6 +10,9 @@ type Expression interface {
GroupByClause
OrderByClause
serializeForJsonValue(statement StatementType, out *SQLBuilder)
setParent(parent Expression)
// IS_NULL tests expression whether it is a NULL value.
IS_NULL() BoolExpression
// IS_NOT_NULL tests expression whether it is a non-NULL value.
@ -34,6 +37,10 @@ type ExpressionInterfaceImpl struct {
Parent Expression
}
func (e *ExpressionInterfaceImpl) setParent(parent Expression) {
e.Parent = parent
}
func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection {
panic(fmt.Sprintf("jet: can't export unaliased expression subQuery: %s, expression: %s",
subQuery.Alias(), serializeToDefaultDebugString(e.Parent)))
@ -92,6 +99,18 @@ func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType
e.Parent.serialize(statement, out, NoWrap)
}
func (e *ExpressionInterfaceImpl) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) {
panic("jet: expression need to be aliased when used as SELECT JSON projection.")
}
func (e *ExpressionInterfaceImpl) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) {
panic("jet: expression need to be aliased when used as SELECT JSON projection.")
}
func (e *ExpressionInterfaceImpl) serializeForJsonValue(statement StatementType, out *SQLBuilder) {
out.Dialect.JsonValueEncode(e.Parent).serialize(statement, out)
}
func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) {
e.Parent.serialize(statement, out, NoWrap)
}
@ -152,7 +171,7 @@ func newExpressionListOperator(operator string, expressions ...Expression) *expr
}
func newBoolExpressionListOperator(operator string, expressions ...BoolExpression) BoolExpression {
return BoolExp(newExpressionListOperator(operator, BoolExpressionListToExpressionList(expressions)...))
return BoolExp(newExpressionListOperator(operator, ToExpressionList(expressions)...))
}
func (elo *expressionListOperator) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {

View file

@ -102,9 +102,10 @@ type floatExpressionWrapper struct {
}
func newFloatExpressionWrap(expression Expression) FloatExpression {
floatExpressionWrap := floatExpressionWrapper{Expression: expression}
floatExpressionWrap.floatInterfaceImpl.parent = &floatExpressionWrap
return &floatExpressionWrap
floatExpressionWrap := &floatExpressionWrapper{Expression: expression}
floatExpressionWrap.floatInterfaceImpl.parent = floatExpressionWrap
expression.setParent(floatExpressionWrap)
return floatExpressionWrap
}
// FloatExp is date expression wrapper around arbitrary expression.

View file

@ -255,18 +255,30 @@ func leadLagImpl(name string, expr Expression, offsetAndDefault ...interface{})
//------------ String functions ------------------//
// HEX function takes an input and returns its equivalent hexadecimal representation
func HEX(expression Expression) StringExpression {
return StringExp(Func("HEX", expression))
}
// UNHEX for a string argument str, UNHEX(str) interprets each pair of characters in the argument
// as a hexadecimal number and converts it to the byte represented by the number.
// The return value is a binary string.
func UNHEX(expression StringExpression) BlobExpression {
return BlobExp(Func("UNHEX", expression))
}
// BIT_LENGTH returns number of bits in string expression
func BIT_LENGTH(stringExpression StringExpression) IntegerExpression {
func BIT_LENGTH(stringExpression StringOrBlobExpression) IntegerExpression {
return newIntegerFunc("BIT_LENGTH", stringExpression)
}
// CHAR_LENGTH returns number of characters in string expression
func CHAR_LENGTH(stringExpression StringExpression) IntegerExpression {
func CHAR_LENGTH(stringExpression StringOrBlobExpression) IntegerExpression {
return newIntegerFunc("CHAR_LENGTH", stringExpression)
}
// OCTET_LENGTH returns number of bytes in string expression
func OCTET_LENGTH(stringExpression StringExpression) IntegerExpression {
func OCTET_LENGTH(stringExpression StringOrBlobExpression) IntegerExpression {
return newIntegerFunc("OCTET_LENGTH", stringExpression)
}
@ -282,7 +294,7 @@ func UPPER(stringExpression StringExpression) StringExpression {
// BTRIM removes the longest string consisting only of characters
// in characters (a space by default) from the start and end of string
func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) StringExpression {
func BTRIM(stringExpression StringOrBlobExpression, trimChars ...StringOrBlobExpression) StringExpression {
if len(trimChars) > 0 {
return NewStringFunc("BTRIM", stringExpression, trimChars[0])
}
@ -291,7 +303,7 @@ func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) Str
// LTRIM removes the longest string containing only characters
// from characters (a space by default) from the start of string
func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression {
func LTRIM(str StringOrBlobExpression, trimChars ...StringOrBlobExpression) StringExpression {
if len(trimChars) > 0 {
return NewStringFunc("LTRIM", str, trimChars[0])
}
@ -300,7 +312,7 @@ func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression
// RTRIM removes the longest string containing only characters
// from characters (a space by default) from the end of string
func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression {
func RTRIM(str StringOrBlobExpression, trimChars ...StringOrBlobExpression) StringExpression {
if len(trimChars) > 0 {
return NewStringFunc("RTRIM", str, trimChars[0])
}
@ -324,32 +336,32 @@ func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression
// CONVERT converts string to dest_encoding. The original encoding is
// specified by src_encoding. The string must be valid in this encoding.
func CONVERT(str StringExpression, srcEncoding StringExpression, destEncoding StringExpression) StringExpression {
return NewStringFunc("CONVERT", str, srcEncoding, destEncoding)
func CONVERT(str BlobExpression, srcEncoding StringExpression, destEncoding StringExpression) BlobExpression {
return BlobExp(Func("CONVERT", str, srcEncoding, destEncoding))
}
// CONVERT_FROM converts string to the database encoding. The original
// encoding is specified by src_encoding. The string must be valid in this encoding.
func CONVERT_FROM(str StringExpression, srcEncoding StringExpression) StringExpression {
func CONVERT_FROM(str BlobExpression, srcEncoding StringExpression) StringExpression {
return NewStringFunc("CONVERT_FROM", str, srcEncoding)
}
// CONVERT_TO converts string to dest_encoding.
func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression {
return NewStringFunc("CONVERT_TO", str, toEncoding)
func CONVERT_TO(str StringExpression, toEncoding StringExpression) BlobExpression {
return BlobExp(Func("CONVERT_TO", str, toEncoding))
}
// ENCODE encodes binary data into a textual representation.
// Supported formats are: base64, hex, escape. escape converts zero bytes and
// high-bit-set bytes to octal sequences (\nnn) and doubles backslashes.
func ENCODE(data StringExpression, format StringExpression) StringExpression {
return NewStringFunc("ENCODE", data, format)
func ENCODE(data BlobExpression, format StringExpression) StringExpression {
return StringExp(Func("ENCODE", data, format))
}
// DECODE decodes binary data from textual representation in string.
// Options for format are same as in encode.
func DECODE(data StringExpression, format StringExpression) StringExpression {
return NewStringFunc("DECODE", data, format)
func DECODE(data StringExpression, format StringExpression) BlobExpression {
return BlobExp(Func("DECODE", data, format))
}
// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string.
@ -379,11 +391,11 @@ func RIGHT(str StringExpression, n IntegerExpression) StringExpression {
}
// LENGTH returns number of characters in string with a given encoding
func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression {
func LENGTH(str StringOrBlobExpression, encoding ...StringExpression) IntegerExpression {
if len(encoding) > 0 {
return NewStringFunc("LENGTH", str, encoding[0])
return IntExp(Func("LENGTH", str, encoding[0]))
}
return NewStringFunc("LENGTH", str)
return IntExp(Func("LENGTH", str))
}
// LPAD fills up the string to length length by prepending the characters
@ -407,8 +419,13 @@ func RPAD(str StringExpression, length IntegerExpression, text ...StringExpressi
return NewStringFunc("RPAD", str, length)
}
// BIT_COUNT returns the number of bits set in the binary string (also known as “popcount”).
func BIT_COUNT(bytes BlobExpression) IntegerExpression {
return IntExp(Func("BIT_COUNT", bytes))
}
// MD5 calculates the MD5 hash of string, returning the result in hexadecimal
func MD5(stringExpression StringExpression) StringExpression {
func MD5(stringExpression StringOrBlobExpression) StringExpression {
return NewStringFunc("MD5", stringExpression)
}
@ -434,7 +451,7 @@ func STRPOS(str, substring StringExpression) IntegerExpression {
}
// SUBSTR extracts substring
func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression {
func SUBSTR(str StringOrBlobExpression, from IntegerExpression, count ...IntegerExpression) StringExpression {
if len(count) > 0 {
return NewStringFunc("SUBSTR", str, from, count[0])
}

View file

@ -141,11 +141,11 @@ type integerExpressionWrapper struct {
}
func newIntExpressionWrap(expression Expression) IntegerExpression {
intExpressionWrap := integerExpressionWrapper{Expression: expression}
intExpressionWrap := &integerExpressionWrapper{Expression: expression}
intExpressionWrap.integerInterfaceImpl.parent = intExpressionWrap
expression.setParent(intExpressionWrap)
intExpressionWrap.integerInterfaceImpl.parent = &intExpressionWrap
return &intExpressionWrap
return intExpressionWrap
}
// IntExp is int expression wrapper around arbitrary expression.

View file

@ -1,37 +0,0 @@
package jet
// Interval is internal common representation of sql interval
type Interval interface {
Serializer
IsInterval
}
// IsInterval interface
type IsInterval interface {
isInterval()
}
// IsIntervalImpl is implementation of IsInterval interface
type IsIntervalImpl struct{}
func (i *IsIntervalImpl) isInterval() {}
// NewInterval creates new interval from serializer
func NewInterval(s Serializer) *IntervalImpl {
newInterval := &IntervalImpl{
Value: s,
}
return newInterval
}
// IntervalImpl is implementation of Interval type
type IntervalImpl struct {
Value Serializer
IsIntervalImpl
}
func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("INTERVAL")
i.Value.serialize(statement, out, FallTrough(options)...)
}

View file

@ -0,0 +1,112 @@
package jet
// IntervalExpression interface
type IntervalExpression interface {
Expression
isInterval()
EQ(rhs IntervalExpression) BoolExpression
NOT_EQ(rhs IntervalExpression) BoolExpression
IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression
IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression
LT(rhs IntervalExpression) BoolExpression
LT_EQ(rhs IntervalExpression) BoolExpression
GT(rhs IntervalExpression) BoolExpression
GT_EQ(rhs IntervalExpression) BoolExpression
BETWEEN(min, max IntervalExpression) BoolExpression
NOT_BETWEEN(min, max IntervalExpression) BoolExpression
ADD(rhs IntervalExpression) IntervalExpression
SUB(rhs IntervalExpression) IntervalExpression
MUL(rhs NumericExpression) IntervalExpression
DIV(rhs NumericExpression) IntervalExpression
}
type intervalInterfaceImpl struct {
parent IntervalExpression
}
func (i *intervalInterfaceImpl) isInterval() {}
func (i *intervalInterfaceImpl) EQ(rhs IntervalExpression) BoolExpression {
return Eq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) NOT_EQ(rhs IntervalExpression) BoolExpression {
return NotEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression {
return IsDistinctFrom(i.parent, rhs)
}
func (i *intervalInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression {
return IsNotDistinctFrom(i.parent, rhs)
}
func (i *intervalInterfaceImpl) LT(rhs IntervalExpression) BoolExpression {
return Lt(i.parent, rhs)
}
func (i *intervalInterfaceImpl) LT_EQ(rhs IntervalExpression) BoolExpression {
return LtEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) GT(rhs IntervalExpression) BoolExpression {
return Gt(i.parent, rhs)
}
func (i *intervalInterfaceImpl) GT_EQ(rhs IntervalExpression) BoolExpression {
return GtEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) BETWEEN(min, max IntervalExpression) BoolExpression {
return NewBetweenOperatorExpression(i.parent, min, max, false)
}
func (i *intervalInterfaceImpl) NOT_BETWEEN(min, max IntervalExpression) BoolExpression {
return NewBetweenOperatorExpression(i.parent, min, max, true)
}
func (i *intervalInterfaceImpl) ADD(rhs IntervalExpression) IntervalExpression {
return IntervalExp(Add(i.parent, rhs))
}
func (i *intervalInterfaceImpl) SUB(rhs IntervalExpression) IntervalExpression {
return IntervalExp(Sub(i.parent, rhs))
}
func (i *intervalInterfaceImpl) MUL(rhs NumericExpression) IntervalExpression {
return IntervalExp(Mul(i.parent, rhs))
}
func (i *intervalInterfaceImpl) DIV(rhs NumericExpression) IntervalExpression {
return IntervalExp(Div(i.parent, rhs))
}
type intervalWrapper struct {
intervalInterfaceImpl
Expression
}
func newIntervalExpressionWrap(expression Expression) IntervalExpression {
intervalWrap := &intervalWrapper{Expression: expression}
intervalWrap.intervalInterfaceImpl.parent = intervalWrap
expression.setParent(intervalWrap)
return intervalWrap
}
// IntervalExp is interval expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as interval expression.
// Does not add sql cast to generated sql builder output.
func IntervalExp(expression Expression) IntervalExpression {
return newIntervalExpressionWrap(expression)
}
// Interval interface
type Interval interface {
Serializer
isInterval()
}

View file

@ -412,17 +412,6 @@ func Raw(raw string, namedArgs ...map[string]interface{}) Expression {
return rawExp
}
// RawWithParent is a Raw constructor used for construction dialect specific expression
func RawWithParent(raw string, parent ...Expression) Expression {
rawExp := &rawExpression{
Raw: raw,
noWrap: true,
}
rawExp.ExpressionInterfaceImpl.Parent = OptionalOrDefaultExpression(rawExp, parent...)
return rawExp
}
// RawBool helper that for raw string boolean expressions
func RawBool(raw string, namedArgs ...map[string]interface{}) BoolExpression {
return BoolExp(Raw(raw, namedArgs...))
@ -468,6 +457,11 @@ func RawDate(raw string, namedArgs ...map[string]interface{}) DateExpression {
return DateExp(Raw(raw, namedArgs...))
}
// RawBlob is raw query helper that for blob expressions
func RawBlob(raw string, namedArgs ...map[string]interface{}) BlobExpression {
return BlobExp(Raw(raw, namedArgs...))
}
// RawRange helper that for range expressions
func RawRange[T Expression](raw string, namedArgs ...map[string]interface{}) Range[T] {
return RangeExp[T](Raw(raw, namedArgs...))

View file

@ -3,6 +3,8 @@ package jet
// Projection is interface for all projection types. Types that can be part of, for instance SELECT clause.
type Projection interface {
serializeForProjection(statement StatementType, out *SQLBuilder)
serializeForJsonObjEntry(statement StatementType, out *SQLBuilder)
serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder)
fromImpl(subQuery SelectTable) Projection
}
@ -28,6 +30,10 @@ func (pl ProjectionList) serializeForProjection(statement StatementType, out *SQ
SerializeProjectionList(statement, pl, out)
}
func (pl ProjectionList) serializeForJsonObjEntry(statement StatementType, out *SQLBuilder) {
SerializeProjectionListJsonObj(statement, pl, out)
}
// As will create new projection list where each column is wrapped with a new table alias.
// tableAlias should be in the form 'name' or 'name.*', or it can be an empty string, which will remove existing table alias.
// For instance: If projection list has a column 'Artist.Name', and tableAlias is 'Musician.*', returned projection list will
@ -79,3 +85,18 @@ func (pl ProjectionList) Except(toExclude ...Column) ProjectionList {
return ret
}
func (pl ProjectionList) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) {
out.WriteRowToJsonProjections(statement, pl)
}
// JsonObjProjectionList redefines []Projection so projections can be serialized as json object key/values
type JsonObjProjectionList []Projection
func (j JsonObjProjectionList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.IncreaseIdent()
out.NewLine()
SerializeProjectionListJsonObj(statement, j, out)
out.DecreaseIdent()
out.NewLine()
}

View file

@ -118,9 +118,10 @@ type rangeExpressionWrapper[T Expression] struct {
}
func newRangeExpressionWrap[T Expression](expression Expression) Range[T] {
rangeExpressionWrap := rangeExpressionWrapper[T]{Expression: expression}
rangeExpressionWrap.rangeInterfaceImpl.parent = &rangeExpressionWrap
return &rangeExpressionWrap
rangeExpressionWrap := &rangeExpressionWrapper[T]{Expression: expression}
rangeExpressionWrap.rangeInterfaceImpl.parent = rangeExpressionWrap
expression.setParent(rangeExpressionWrap)
return rangeExpressionWrap
}
// RangeExp is range expression wrapper around arbitrary expression.

View file

@ -1,7 +1,7 @@
package jet
type rawStatementImpl struct {
serializerStatementInterfaceImpl
statementInterfaceImpl
RawQuery string
NamedArguments map[string]interface{}
@ -10,7 +10,7 @@ type rawStatementImpl struct {
// RawStatement creates new sql statements from raw query and optional map of named arguments
func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) SerializerStatement {
newRawStatement := rawStatementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
statementInterfaceImpl: statementInterfaceImpl{
dialect: dialect,
statementType: "",
parent: nil,

View file

@ -17,9 +17,9 @@ type RowExpression interface {
}
type rowInterfaceImpl struct {
parent Expression
dialect Dialect
elemCount int
parent Expression
dialect Dialect
expressions []Expression
}
func (n *rowInterfaceImpl) EQ(rhs RowExpression) BoolExpression {
@ -57,9 +57,8 @@ func (n *rowInterfaceImpl) LT_EQ(rhs RowExpression) BoolExpression {
func (n *rowInterfaceImpl) projections() ProjectionList {
var ret ProjectionList
for i := 0; i < n.elemCount; i++ {
rowColumn := NewColumnImpl(n.dialect.ValuesDefaultColumnName(i), "", nil)
ret = append(ret, &rowColumn)
for i, expression := range n.expressions {
ret = append(ret, newDummyColumnForExpression(expression, n.dialect.ValuesDefaultColumnName(i)))
}
return ret
@ -77,7 +76,7 @@ func newRowExpression(name string, dialect Dialect, expressions ...Expression) R
ret.Expression = NewFunc(name, expressions, ret)
ret.dialect = dialect
ret.elemCount = len(expressions)
ret.expressions = expressions
return ret
}

View file

@ -24,14 +24,16 @@ type StatementType string
// Statement types
const (
SelectStatementType StatementType = "SELECT"
InsertStatementType StatementType = "INSERT"
UpdateStatementType StatementType = "UPDATE"
DeleteStatementType StatementType = "DELETE"
SetStatementType StatementType = "SET"
LockStatementType StatementType = "LOCK"
UnLockStatementType StatementType = "UNLOCK"
WithStatementType StatementType = "WITH"
SelectStatementType StatementType = "SELECT"
SelectJsonObjStatementType StatementType = "SELECT_JSON_OBJ"
SelectJsonArrStatementType StatementType = "SELECT_JSON_ARR"
InsertStatementType StatementType = "INSERT"
UpdateStatementType StatementType = "UPDATE"
DeleteStatementType StatementType = "DELETE"
SetStatementType StatementType = "SET"
LockStatementType StatementType = "LOCK"
UnLockStatementType StatementType = "UNLOCK"
WithStatementType StatementType = "WITH"
)
// Serializer interface

View file

@ -61,6 +61,17 @@ func (s *SQLBuilder) WriteProjections(statement StatementType, projections []Pro
s.DecreaseIdent()
}
// WriteRowToJsonProjections serializes slice of projections intended for row_to_json json aggregation
func (s *SQLBuilder) WriteRowToJsonProjections(statement StatementType, projections []Projection) {
for i, projection := range projections {
if i > 0 {
s.WriteString(",")
s.NewLine()
}
projection.serializeForRowToJsonProjection(statement, s)
}
}
// NewLine adds new line to output SQL
func (s *SQLBuilder) NewLine() {
s.write([]byte{'\n'})
@ -99,6 +110,11 @@ func (s *SQLBuilder) WriteString(str string) {
s.write([]byte(str))
}
// WriteJsonObjKey serializes json object key
func (s *SQLBuilder) WriteJsonObjKey(key string) {
s.WriteString(fmt.Sprintf(`'%s', `, key))
}
// WriteIdentifier adds identifier to output SQL
func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
if s.shouldQuote(name, alwaysQuote...) {
@ -123,7 +139,7 @@ func (s *SQLBuilder) finalize() (string, []interface{}) {
}
func (s *SQLBuilder) insertConstantArgument(arg interface{}) {
s.WriteString(argToString(arg))
s.WriteString(s.argToString(arg))
}
func (s *SQLBuilder) insertParametrizedArgument(arg interface{}) {
@ -196,7 +212,7 @@ func (s *SQLBuilder) insertRawQuery(raw string, namedArg map[string]interface{})
}
if s.Debug {
placeholder = argToString(namedArgumentPos.Value)
placeholder = s.argToString(namedArgumentPos.Value)
}
raw = strings.Replace(raw, namedArgumentPos.Name, placeholder, toReplace)
@ -205,11 +221,17 @@ func (s *SQLBuilder) insertRawQuery(raw string, namedArg map[string]interface{})
s.WriteString(raw)
}
func argToString(value interface{}) string {
func (s *SQLBuilder) argToString(value interface{}) string {
if is.Nil(value) {
return "NULL"
}
strVal, ok := s.Dialect.ArgumentToString(value)
if ok {
return strVal
}
switch bindVal := value.(type) {
case bool:
if bindVal {
@ -246,7 +268,7 @@ func argToString(value interface{}) string {
return err.Error()
}
return argToString(val)
return s.argToString(val)
}
panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String()))

View file

@ -8,37 +8,39 @@ import (
)
func TestArgToString(t *testing.T) {
require.Equal(t, argToString(true), "TRUE")
require.Equal(t, argToString(false), "FALSE")
s := &SQLBuilder{Dialect: defaultDialect, Debug: true}
require.Equal(t, argToString(int(-32)), "-32")
require.Equal(t, argToString(uint(32)), "32")
require.Equal(t, argToString(int8(-43)), "-43")
require.Equal(t, argToString(uint8(43)), "43")
require.Equal(t, argToString(int16(-54)), "-54")
require.Equal(t, argToString(uint16(54)), "54")
require.Equal(t, argToString(int32(-65)), "-65")
require.Equal(t, argToString(uint32(65)), "65")
require.Equal(t, argToString(int64(-64)), "-64")
require.Equal(t, argToString(uint64(64)), "64")
require.Equal(t, argToString(float32(2.0)), "2")
require.Equal(t, argToString(float64(1.11)), "1.11")
require.Equal(t, s.argToString(true), "TRUE")
require.Equal(t, s.argToString(false), "FALSE")
require.Equal(t, argToString("john"), "'john'")
require.Equal(t, argToString("It's text"), "'It''s text'")
require.Equal(t, argToString([]byte("john")), "'john'")
require.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'")
require.Equal(t, s.argToString(int(-32)), "-32")
require.Equal(t, s.argToString(uint(32)), "32")
require.Equal(t, s.argToString(int8(-43)), "-43")
require.Equal(t, s.argToString(uint8(43)), "43")
require.Equal(t, s.argToString(int16(-54)), "-54")
require.Equal(t, s.argToString(uint16(54)), "54")
require.Equal(t, s.argToString(int32(-65)), "-65")
require.Equal(t, s.argToString(uint32(65)), "65")
require.Equal(t, s.argToString(int64(-64)), "-64")
require.Equal(t, s.argToString(uint64(64)), "64")
require.Equal(t, s.argToString(float32(2.0)), "2")
require.Equal(t, s.argToString(float64(1.11)), "1.11")
require.Equal(t, s.argToString("john"), "'john'")
require.Equal(t, s.argToString("It's text"), "'It''s text'")
require.Equal(t, s.argToString([]byte("john")), "'john'")
require.Equal(t, s.argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'")
time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006")
require.NoError(t, err)
require.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'")
require.Equal(t, s.argToString(time), "'2006-01-02 15:04:05-07:00'")
func() {
defer func() {
require.Equal(t, recover().(string), "jet: map[string]bool type can not be used as SQL query parameter")
}()
argToString(map[string]bool{})
s.argToString(map[string]bool{})
}()
}

View file

@ -7,25 +7,38 @@ import (
"time"
)
// Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
// Statement is a common interface for all SQL statements, including SELECT, SELECT_JSON_ARR, SELECT_JSON_OBJ, INSERT,
// UPDATE, DELETE, and LOCK.
type Statement interface {
// Sql returns parametrized sql query with list of arguments.
// Sql returns a parameterized SQL query along with its list of arguments.
Sql() (query string, args []interface{})
// DebugSql returns debug query where every parametrized placeholder is replaced with its argument string representation.
// Do not use it in production. Use it only for debug purposes.
// DebugSql returns a debug-friendly SQL query where all parameterized placeholders
// are replaced with their respective argument string representations.
//
// Warning: This method should only be used for debugging purposes.
// Do not use it in production, as it may lead to security risks such as SQL injection.
DebugSql() (query string)
// Query executes statement over database connection/transaction db and stores row results in destination.
// Destination can be either pointer to struct or pointer to a slice.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
// Query delegates call to QueryContext using context.Background() as parameter.
Query(db qrm.Queryable, destination interface{}) error
// QueryContext executes statement with a context over database connection/transaction db and stores row result in destination.
// Destination can be either pointer to struct or pointer to a slice.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
// QueryContext executes the statement with the provided context over a database connection or transaction (`db`),
// and stores the retrieved row results in the given destination.
//
// For statements of type SELECT, INSERT, UPDATE, or DELETE, the destination must be a pointer to either a struct or a slice.
// For SELECT_JSON_ARR statements, the destination must be a pointer to a slice of structs or a pointer to []map[string]any.
// For SELECT_JSON_OBJ statements, the destination must be a pointer to a struct or a pointer to map[string]any.
//
// If the destination is a pointer to a struct and the query returns no rows, QueryContext returns qrm.ErrNoRows.
QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error
// Exec executes statement over db connection/transaction without returning any rows.
// Exec delegates call to ExecContext using context.Background() as parameter.
Exec(db qrm.Executable) (sql.Result, error)
// ExecContext executes statement with context over db connection/transaction without returning any rows.
ExecContext(ctx context.Context, db qrm.Executable) (sql.Result, error)
// Rows executes statements over db connection/transaction and returns rows
Rows(ctx context.Context, db qrm.Queryable) (*Rows, error)
}
@ -60,14 +73,14 @@ type SerializerHasProjections interface {
HasProjections
}
// serializerStatementInterfaceImpl struct
type serializerStatementInterfaceImpl struct {
// statementInterfaceImpl struct
type statementInterfaceImpl struct {
dialect Dialect
statementType StatementType
parent SerializerStatement
}
func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface{}) {
func (s *statementInterfaceImpl) Sql() (query string, args []interface{}) {
queryData := &SQLBuilder{Dialect: s.dialect}
@ -77,7 +90,7 @@ func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface
return
}
func (s *serializerStatementInterfaceImpl) DebugSql() (query string) {
func (s *statementInterfaceImpl) DebugSql() (query string) {
sqlBuilder := &SQLBuilder{Dialect: s.dialect, Debug: true}
s.parent.serialize(s.statementType, sqlBuilder, NoWrap)
@ -86,11 +99,27 @@ func (s *serializerStatementInterfaceImpl) DebugSql() (query string) {
return
}
func (s *serializerStatementInterfaceImpl) Query(db qrm.Queryable, destination interface{}) error {
func (s *statementInterfaceImpl) Query(db qrm.Queryable, destination interface{}) error {
return s.QueryContext(context.Background(), db, destination)
}
func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error {
func (s *statementInterfaceImpl) QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error {
return s.query(ctx, func(query string, args []interface{}) (int64, error) {
switch s.statementType {
case SelectJsonObjStatementType:
return qrm.QueryJsonObj(ctx, db, query, args, destination)
case SelectJsonArrStatementType:
return qrm.QueryJsonArr(ctx, db, query, args, destination)
default:
return qrm.Query(ctx, db, query, args, destination)
}
})
}
func (s *statementInterfaceImpl) query(
ctx context.Context,
queryFunc func(query string, args []interface{}) (int64, error),
) error {
query, args := s.Sql()
callLogger(ctx, s)
@ -99,7 +128,7 @@ func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db
var err error
duration := duration(func() {
rowsProcessed, err = qrm.Query(ctx, db, query, args, destination)
rowsProcessed, err = queryFunc(query, args)
})
callQueryLoggerFunc(ctx, QueryInfo{
@ -112,11 +141,11 @@ func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db
return err
}
func (s *serializerStatementInterfaceImpl) Exec(db qrm.Executable) (res sql.Result, err error) {
func (s *statementInterfaceImpl) Exec(db qrm.Executable) (res sql.Result, err error) {
return s.ExecContext(context.Background(), db)
}
func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.Executable) (res sql.Result, err error) {
func (s *statementInterfaceImpl) ExecContext(ctx context.Context, db qrm.Executable) (res sql.Result, err error) {
query, args := s.Sql()
callLogger(ctx, s)
@ -141,7 +170,7 @@ func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db q
return res, err
}
func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.Queryable) (*Rows, error) {
func (s *statementInterfaceImpl) Rows(ctx context.Context, db qrm.Queryable) (*Rows, error) {
query, args := s.Sql()
callLogger(ctx, s)
@ -191,11 +220,15 @@ type ExpressionStatement interface {
}
// NewExpressionStatementImpl creates new expression statement
func NewExpressionStatementImpl(Dialect Dialect, statementType StatementType, parent ExpressionStatement, clauses ...Clause) ExpressionStatement {
func NewExpressionStatementImpl(Dialect Dialect,
statementType StatementType,
parent ExpressionStatement,
clauses ...Clause) ExpressionStatement {
return &expressionStatementImpl{
ExpressionInterfaceImpl{Parent: parent},
statementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
statementInterfaceImpl: statementInterfaceImpl{
parent: parent,
dialect: Dialect,
statementType: statementType,
@ -214,10 +247,14 @@ func (s *expressionStatementImpl) serializeForProjection(statement StatementType
s.serialize(statement, out)
}
func (e *expressionStatementImpl) serializeForRowToJsonProjection(statement StatementType, out *SQLBuilder) {
panic("jet: SELECT JSON statements need to be aliased when used as a projection.")
}
// NewStatementImpl creates new statementImpl
func NewStatementImpl(Dialect Dialect, statementType StatementType, parent SerializerStatement, clauses ...Clause) SerializerStatement {
return &statementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
statementInterfaceImpl: statementInterfaceImpl{
parent: parent,
dialect: Dialect,
statementType: statementType,
@ -227,7 +264,7 @@ func NewStatementImpl(Dialect Dialect, statementType StatementType, parent Seria
}
type statementImpl struct {
serializerStatementInterfaceImpl
statementInterfaceImpl
Clauses []Clause
}

View file

@ -3,6 +3,7 @@ package jet
// StringExpression interface
type StringExpression interface {
Expression
isStringOrBlob()
EQ(rhs StringExpression) BoolExpression
NOT_EQ(rhs StringExpression) BoolExpression
@ -29,6 +30,8 @@ type stringInterfaceImpl struct {
parent StringExpression
}
func (s *stringInterfaceImpl) isStringOrBlob() {}
func (s *stringInterfaceImpl) EQ(rhs StringExpression) BoolExpression {
return Eq(s.parent, rhs)
}
@ -102,9 +105,10 @@ type stringExpressionWrapper struct {
}
func newStringExpressionWrap(expression Expression) StringExpression {
stringExpressionWrap := stringExpressionWrapper{Expression: expression}
stringExpressionWrap.stringInterfaceImpl.parent = &stringExpressionWrap
return &stringExpressionWrap
stringExpressionWrap := &stringExpressionWrapper{Expression: expression}
stringExpressionWrap.stringInterfaceImpl.parent = stringExpressionWrap
expression.setParent(stringExpressionWrap)
return stringExpressionWrap
}
// StringExp is string expression wrapper around arbitrary expression.

View file

@ -0,0 +1,8 @@
package jet
// StringOrBlobExpression is common interface for all string and blob expressions
type StringOrBlobExpression interface {
Expression
isStringOrBlob()
}

View file

@ -12,6 +12,9 @@ var defaultDialect = NewDialect(DialectParams{ // just for tests
ArgumentPlaceholder: func(ord int) string {
return "$" + strconv.Itoa(ord)
},
ArgumentToString: func(value any) (string, bool) {
return "", false
},
})
var (

View file

@ -75,14 +75,15 @@ func (t *timeInterfaceImpl) SUB(rhs Interval) TimeExpression {
//---------------------------------------------------//
type timeExpressionWrapper struct {
timeInterfaceImpl
Expression
timeInterfaceImpl
}
func newTimeExpressionWrap(expression Expression) TimeExpression {
timeExpressionWrap := timeExpressionWrapper{Expression: expression}
timeExpressionWrap.timeInterfaceImpl.parent = &timeExpressionWrap
return &timeExpressionWrap
timeExpressionWrap := &timeExpressionWrapper{Expression: expression}
timeExpressionWrap.timeInterfaceImpl.parent = timeExpressionWrap
expression.setParent(timeExpressionWrap)
return timeExpressionWrap
}
// TimeExp is time expression wrapper around arbitrary expression.

View file

@ -52,11 +52,3 @@ func TestTimeExp(t *testing.T) {
assertClauseSerialize(t, TimeExp(table1ColFloat).LT(Time(1, 1, 1, 1*time.Millisecond)),
"(table1.col_float < $1)", string("01:01:01.001"))
}
func TestTimeArithmetic(t *testing.T) {
time := Time(10, 20, 3)
assertClauseDebugSerialize(t, table1ColTime.ADD(NewInterval(String("1 HOUR"))).EQ(time),
"((table1.col_time + INTERVAL '1 HOUR') = '10:20:03')")
assertClauseDebugSerialize(t, table1ColTime.SUB(NewInterval(String("1 HOUR"))).EQ(time),
"((table1.col_time - INTERVAL '1 HOUR') = '10:20:03')")
}

View file

@ -80,9 +80,10 @@ type timestampExpressionWrapper struct {
}
func newTimestampExpressionWrap(expression Expression) TimestampExpression {
timestampExpressionWrap := timestampExpressionWrapper{Expression: expression}
timestampExpressionWrap.timestampInterfaceImpl.parent = &timestampExpressionWrap
return &timestampExpressionWrap
timestampExpressionWrap := &timestampExpressionWrapper{Expression: expression}
timestampExpressionWrap.timestampInterfaceImpl.parent = timestampExpressionWrap
expression.setParent(timestampExpressionWrap)
return timestampExpressionWrap
}
// TimestampExp is timestamp expression wrapper around arbitrary expression.

View file

@ -53,11 +53,3 @@ func TestTimestampExp(t *testing.T) {
assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp),
"(table1.col_float < $1)", "2000-01-31 10:20:00.003")
}
func TestTimestampArithmetic(t *testing.T) {
timestamp := Timestamp(2000, 1, 1, 0, 0, 0)
assertClauseDebugSerialize(t, table1ColTimestamp.ADD(NewInterval(String("1 HOUR"))).EQ(timestamp),
"((table1.col_timestamp + INTERVAL '1 HOUR') = '2000-01-01 00:00:00')")
assertClauseDebugSerialize(t, table1ColTimestamp.SUB(NewInterval(String("1 HOUR"))).EQ(timestamp),
"((table1.col_timestamp - INTERVAL '1 HOUR') = '2000-01-01 00:00:00')")
}

View file

@ -80,9 +80,10 @@ type timestampzExpressionWrapper struct {
}
func newTimestampzExpressionWrap(expression Expression) TimestampzExpression {
timestampzExpressionWrap := timestampzExpressionWrapper{Expression: expression}
timestampzExpressionWrap.timestampzInterfaceImpl.parent = &timestampzExpressionWrap
return &timestampzExpressionWrap
timestampzExpressionWrap := &timestampzExpressionWrapper{Expression: expression}
timestampzExpressionWrap.timestampzInterfaceImpl.parent = timestampzExpressionWrap
expression.setParent(timestampzExpressionWrap)
return timestampzExpressionWrap
}
// TimestampzExp is timestamp with time zone expression wrapper around arbitrary expression.

View file

@ -53,11 +53,3 @@ func TestTimestampzExp(t *testing.T) {
assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz),
"(table1.col_float < $1)", "2000-01-31 10:20:05.000023 +200")
}
func TestTimestampzArithmetic(t *testing.T) {
timestampz := Timestampz(2000, 1, 1, 0, 0, 0, 100, "UTC")
assertClauseDebugSerialize(t, table1ColTimestampz.ADD(NewInterval(String("1 HOUR"))).EQ(timestampz),
"((table1.col_timestampz + INTERVAL '1 HOUR') = '2000-01-01 00:00:00.0000001 UTC')")
assertClauseDebugSerialize(t, table1ColTimestampz.SUB(NewInterval(String("1 HOUR"))).EQ(timestampz),
"((table1.col_timestampz - INTERVAL '1 HOUR') = '2000-01-01 00:00:00.0000001 UTC')")
}

View file

@ -75,14 +75,15 @@ func (t *timezInterfaceImpl) SUB(rhs Interval) TimezExpression {
//---------------------------------------------------//
type timezExpressionWrapper struct {
timezInterfaceImpl
Expression
timezInterfaceImpl
}
func newTimezExpressionWrap(expression Expression) TimezExpression {
timezExpressionWrap := timezExpressionWrapper{Expression: expression}
timezExpressionWrap.timezInterfaceImpl.parent = &timezExpressionWrap
return &timezExpressionWrap
timezExpressionWrap := &timezExpressionWrapper{Expression: expression}
timezExpressionWrap.timezInterfaceImpl.parent = timezExpressionWrap
expression.setParent(timezExpressionWrap)
return timezExpressionWrap
}
// TimezExp is time with time zone expression wrapper around arbitrary expression.

View file

@ -51,11 +51,3 @@ func TestTimezExp(t *testing.T) {
assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, "+4:00")),
"(table1.col_float < $1)", string("01:01:01.000000001 +4:00"))
}
func TestTimezArithmetic(t *testing.T) {
timez := Timez(0, 0, 0, 100, "UTC")
assertClauseDebugSerialize(t, table1ColTimez.ADD(NewInterval(String("1 HOUR"))).EQ(timez),
"((table1.col_timez + INTERVAL '1 HOUR') = '00:00:00.0000001 UTC')")
assertClauseDebugSerialize(t, table1ColTimez.SUB(NewInterval(String("1 HOUR"))).EQ(timez),
"((table1.col_timez - INTERVAL '1 HOUR') = '00:00:00.0000001 UTC')")
}

View file

@ -58,6 +58,23 @@ func SerializeProjectionList(statement StatementType, projections []Projection,
}
}
// SerializeProjectionListJsonObj serializes a list of projections for JSON object
func SerializeProjectionListJsonObj(statement StatementType, projections []Projection, out *SQLBuilder) {
for i, p := range projections {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
if p == nil {
panic("jet: Projection is nil")
}
p.serializeForJsonObjEntry(statement, out)
}
}
// SerializeColumnNames func
func SerializeColumnNames(columns []Column, out *SQLBuilder) {
for i, col := range columns {
@ -115,8 +132,8 @@ func ExpressionListToSerializerList(expressions []Expression) []Serializer {
return ret
}
// BoolExpressionListToExpressionList converts list of bool expressions to list of expressions
func BoolExpressionListToExpressionList(expressions []BoolExpression) []Expression {
// ToExpressionList converts list of any expressions to list of expressions
func ToExpressionList[T Expression](expressions []T) []Expression {
var ret []Expression
for _, expression := range expressions {

View file

@ -7,7 +7,7 @@ func WITH(dialect Dialect, recursive bool, cte ...*CommonTableExpression) func(s
newWithImpl := &withImpl{
recursive: recursive,
ctes: cte,
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
statementInterfaceImpl: statementInterfaceImpl{
dialect: dialect,
statementType: WithStatementType,
},
@ -25,7 +25,7 @@ func WITH(dialect Dialect, recursive bool, cte ...*CommonTableExpression) func(s
}
type withImpl struct {
serializerStatementInterfaceImpl
statementInterfaceImpl
recursive bool
ctes []*CommonTableExpression
primaryStatement SerializerStatement

View file

@ -115,6 +115,16 @@ func AssertJSON(t *testing.T, data interface{}, expectedJSON string) {
require.Equal(t, dataJson, expectedJSON)
}
// AssertJsonEqual checks if actual and expected json representation are the same
func AssertJsonEqual(t require.TestingT, actual, expected interface{}, option ...cmp.Option) {
actualJsonData, err := json.MarshalIndent(actual, "", "\t")
require.NoError(t, err)
expectedJsonData, err := json.MarshalIndent(expected, "", "\t")
require.NoError(t, err)
require.Equal(t, string(actualJsonData), string(expectedJsonData))
}
// SaveJSONFile saves v as json at testRelativePath
// nolint:unused
func SaveJSONFile(v interface{}, testRelativePath string) {
@ -127,7 +137,10 @@ func SaveJSONFile(v interface{}, testRelativePath string) {
}
// AssertJSONFile check if data json representation is the same as json at testRelativePath
func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) {
func AssertJSONFile(t require.TestingT, data interface{}, testRelativePath string) {
if _, ok := t.(*testing.B); ok {
return // skip assert for benchmarks
}
filePath := getFullPath(testRelativePath)
fileJSONData, err := os.ReadFile(filePath) // #nosec G304
@ -145,7 +158,11 @@ func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) {
}
// AssertStatementSql check if statement Sql() is the same as expectedQuery and expectedArgs
func AssertStatementSql(t *testing.T, query jet.PrintableStatement, expectedQuery string, expectedArgs ...interface{}) {
func AssertStatementSql(t require.TestingT, query jet.PrintableStatement, expectedQuery string, expectedArgs ...interface{}) {
if _, ok := t.(*testing.B); ok {
return // skip assert for benchmarks
}
queryStr, args := query.Sql()
assertQueryString(t, queryStr, expectedQuery)
@ -283,14 +300,14 @@ func AssertFileNamesEqual(t *testing.T, dirPath string, fileNames ...string) {
}
// AssertDeepEqual checks if actual and expected objects are deeply equal.
func AssertDeepEqual(t *testing.T, actual, expected interface{}, option ...cmp.Option) {
func AssertDeepEqual(t require.TestingT, actual, expected interface{}, option ...cmp.Option) {
if !assert.True(t, cmp.Equal(actual, expected, option...)) {
printDiff(actual, expected, option...)
t.FailNow()
}
}
func assertQueryString(t *testing.T, actual, expected string) {
func assertQueryString(t require.TestingT, actual, expected string) {
if !assert.Equal(t, actual, expected) {
printDiff(actual, expected)
t.FailNow()

View file

@ -1,6 +1,9 @@
package datetime
import "time"
import (
//"github.com/go-jet/jet/v2/internal/utils/min"
"time"
)
// ExtractTimeComponents extracts number of days, hours, minutes, seconds, microseconds from duration
func ExtractTimeComponents(duration time.Duration) (days, hours, minutes, seconds, microseconds int64) {
@ -20,3 +23,36 @@ func ExtractTimeComponents(duration time.Duration) (days, hours, minutes, second
return
}
// TryParseAsTime attempts to parse the provided value as a time using one of the given formats.
//
// The function iterates over the provided formats and tries to parse the value into a time.Time object.
// It returns the parsed time and a boolean indicating whether the parsing was successful.
func TryParseAsTime(value interface{}, formats []string) (time.Time, bool) {
var timeStr string
switch v := value.(type) {
case string:
timeStr = v
case []byte:
timeStr = string(v)
case int64:
return time.Unix(v, 0), true // sqlite
default:
return time.Time{}, false
}
for _, format := range formats {
formatLen := min(len(format), len(timeStr))
t, err := time.Parse(format[:formatLen], timeStr)
if err != nil {
continue
}
return t, true
}
return time.Time{}, false
}

View file

@ -1,6 +1,8 @@
package is
import "reflect"
import (
"reflect"
)
// Nil check if v is nil
func Nil(v interface{}) bool {

View file

@ -1,9 +0,0 @@
package min
// Int returns minimum of two int values
func Int(a, b int) int {
if a < b {
return a
}
return b
}

View file

@ -70,6 +70,6 @@ func (c *cast) AS_TIME() TimeExpression {
}
// AS_BINARY casts expression as BINARY type
func (c *cast) AS_BINARY() StringExpression {
return StringExp(c.AS("BINARY"))
func (c *cast) AS_BINARY() BlobExpression {
return BlobExp(c.AS("BINARY"))
}

View file

@ -21,6 +21,12 @@ type ColumnString = jet.ColumnString
// StringColumn creates named string column.
var StringColumn = jet.StringColumn
// ColumnBlob is interface for blob columns.
type ColumnBlob = jet.ColumnBlob
// BlobColumn creates named blob column.
var BlobColumn = jet.BlobColumn
// ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger = jet.ColumnInteger

View file

@ -1,11 +1,12 @@
package mysql
import (
"encoding/hex"
"fmt"
"github.com/go-jet/jet/v2/internal/jet"
)
// Dialect is implementation of MySQL dialect for SQL Builder serialisation.
// Dialect is implementation of MySQL dialect for SQL Builder serialization.
var Dialect = newDialect()
func newDialect() jet.Dialect {
@ -27,16 +28,43 @@ func newDialect() jet.Dialect {
ArgumentPlaceholder: func(int) string {
return "?"
},
ArgumentToString: argumentToString,
ReservedWords: reservedWords,
SerializeOrderBy: serializeOrderBy,
ValuesDefaultColumnName: func(index int) string {
return fmt.Sprintf("column_%d", index)
},
JsonValueEncode: func(expr Expression) Expression {
switch e := expr.(type) {
case BlobExpression:
return TO_BASE64(e)
// CustomExpression used bellow (instead DATE_FORMAT function) so that only expr is parametrized
case TimestampExpression:
return CustomExpression(Token("DATE_FORMAT("), e, Token(",'%Y-%m-%dT%H:%i:%s.%fZ')"))
case TimeExpression:
return CustomExpression(Token("CONCAT('0000-01-01T', DATE_FORMAT("), e, Token(",'%H:%i:%s.%fZ'))"))
case DateExpression:
return CustomExpression(Token("CONCAT(DATE_FORMAT("), e, Token(",'%Y-%m-%d')"), Token(", 'T00:00:00Z')"))
case BoolExpression:
return CustomExpression(e, Token(" = 1"))
}
return expr
},
}
return jet.NewDialect(mySQLDialectParams)
}
func argumentToString(value any) (string, bool) {
switch bindVal := value.(type) {
case []byte:
return fmt.Sprintf("X'%s'", hex.EncodeToString(bindVal)), true
}
return "", false
}
func mysqlBitXor(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {

View file

@ -12,6 +12,9 @@ type BoolExpression = jet.BoolExpression
// StringExpression interface
type StringExpression = jet.StringExpression
// BlobExpression interface
type BlobExpression = jet.BlobExpression
// IntegerExpression interface
type IntegerExpression = jet.IntegerExpression
@ -43,6 +46,11 @@ var BoolExp = jet.BoolExp
// Does not add sql cast to generated sql builder output.
var StringExp = jet.StringExp
// BlobExp is blob expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as blob expression.
// Does not add sql cast to generated sql builder output.
var BlobExp = jet.BlobExp
// IntExp is int expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as int expression.
// Does not add sql cast to generated sql builder output.
@ -100,6 +108,7 @@ var (
RawTime = jet.RawTime
RawTimestamp = jet.RawTimestamp
RawDate = jet.RawDate
RawBlob = jet.RawBlob
)
// Func can be used to call custom or unsupported database functions.

View file

@ -148,6 +148,14 @@ var NTH_VALUE = jet.NTH_VALUE
//--------------------- String functions ------------------//
// HEX function in MySQL takes an input and returns its equivalent hexadecimal representation
var HEX = jet.HEX
// UNHEX for a string argument str, UNHEX(str) interprets each pair of characters in the argument
// as a hexadecimal number and converts it to the byte represented by the number.
// The return value is a binary string.
var UNHEX = jet.UNHEX
// BIT_LENGTH returns number of bits in string expression
var BIT_LENGTH = jet.BIT_LENGTH
@ -157,6 +165,23 @@ var CHAR_LENGTH = jet.CHAR_LENGTH
// OCTET_LENGTH returns number of bytes in string expression
var OCTET_LENGTH = jet.OCTET_LENGTH
// ELT returns the Nth element of the list of strings: str1 if N = 1, str2 if N = 2, and so on.
// Returns NULL if N is less than 1, greater than the number of arguments, or NULL.
func ELT(n IntegerExpression, list ...StringExpression) StringExpression {
args := []Expression{n}
args = append(args, jet.ToExpressionList(list)...)
return StringExp(Func("ELT", args...))
}
// FIELD returns the index (position) of str in the str1, str2, str3, ... list. Returns 0 if str is not found.
func FIELD(str StringExpression, list ...StringExpression) StringExpression {
args := []Expression{str}
args = append(args, jet.ToExpressionList(list)...)
return StringExp(Func("FIELD", args...))
}
// LOWER returns string expression in lower case
var LOWER = jet.LOWER
@ -178,7 +203,35 @@ var CONCAT = jet.CONCAT
var CONCAT_WS = jet.CONCAT_WS
// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string.
var FORMAT = jet.FORMAT
func FORMAT(number jet.NumericExpression, decimals IntegerExpression, optionalLocale ...StringExpression) StringExpression {
if len(optionalLocale) > 0 {
return StringExp(Func("FORMAT", number, decimals, optionalLocale[0]))
}
return StringExp(Func("FORMAT", number, decimals))
}
// TO_BASE64 converts the string argument to base-64 encoded form and returns the
// result as a character string with the connection character set and collation.
func TO_BASE64(data jet.StringOrBlobExpression) StringExpression {
return StringExp(Func("TO_BASE64", data))
}
// FROM_BASE64 takes a string encoded with the base-64 encoded rules used by TO_BASE64()
// and returns the decoded result as a binary string.
func FROM_BASE64(data StringExpression) BlobExpression {
return BlobExp(Func("FROM_BASE64", data))
}
// CHARSET returns the character set of the string argument, or NULL if the argument is NULL.
func CHARSET(exp Expression) StringExpression {
return StringExp(Func("CHARSET", exp))
}
// COLLATION returns the collation of the string argument.
func COLLATION(exp Expression) StringExpression {
return StringExp(Func("COLLATION ", exp))
}
// LEFT returns first n characters in the string.
// When n is negative, return all but last |n| characters.
@ -189,7 +242,7 @@ var LEFT = jet.LEFT
var RIGHT = jet.RIGHT
// LENGTH returns number of characters in string with a given encoding
func LENGTH(str jet.StringExpression) jet.StringExpression {
func LENGTH(str jet.StringOrBlobExpression) jet.IntegerExpression {
return jet.LENGTH(str)
}

View file

@ -98,13 +98,10 @@ func INTERVAL(value interface{}, unitType unitType) Interval {
// INTERVALe creates new temporal interval from expresion and unit type.
func INTERVALe(expr Expression, unitType unitType) Interval {
return jet.NewInterval(jet.ListSerializer{
Serializers: []jet.Serializer{expr, jet.RawWithParent(string(unitType))},
Separator: " ",
})
return jet.IntervalExp(CustomExpression(Token("INTERVAL"), expr, Token(unitType)))
}
// INTERVALd temoral interval from time.Duration
// INTERVALd creates new temporal interval from time.Duration
func INTERVALd(duration time.Duration) Interval {
var sign int64 = 1
if duration < 0 {

View file

@ -56,6 +56,11 @@ var String = jet.String
// value can be any uuid type with a String method
var UUID = jet.UUID
// Blob creates new blob literal expression
func Blob(data []byte) BlobExpression {
return BlobExp(jet.Literal(data))
}
// Date creates new date literal
func Date(year int, month time.Month, day int) DateExpression {
return CAST(jet.Date(year, month, day)).AS_DATE()

79
mysql/select_json.go Normal file
View file

@ -0,0 +1,79 @@
package mysql
import (
"github.com/go-jet/jet/v2/internal/jet"
)
// SelectJsonStatement is an interface for MySQL statements that generate JSON on the server.
type SelectJsonStatement interface {
Statement
jet.Serializer
AS(alias string) Projection
FROM(table ReadableTable) SelectJsonStatement
WHERE(condition BoolExpression) SelectJsonStatement
ORDER_BY(orderByClauses ...OrderByClause) SelectJsonStatement
LIMIT(limit int64) SelectJsonStatement
OFFSET(offset int64) SelectJsonStatement
}
// SELECT_JSON_ARR creates a new SelectJsonStatement with a list of projections.
func SELECT_JSON_ARR(projections ...Projection) SelectJsonStatement {
return newSelectStatementJson(projections, jet.SelectJsonArrStatementType)
}
// SELECT_JSON_OBJ creates a new SelectJsonStatement with a list of projections.
func SELECT_JSON_OBJ(projections ...Projection) SelectJsonStatement {
return newSelectStatementJson(projections, jet.SelectJsonObjStatementType)
}
type selectJsonStatement struct {
*selectStatementImpl
}
func newSelectStatementJson(projections []Projection, statementType jet.StatementType) SelectJsonStatement {
newSelect := &selectJsonStatement{
selectStatementImpl: newSelectStatement(statementType, nil, nil),
}
newSelect.Select.ProjectionList = ProjectionList{constructJsonFunc(projections, statementType).AS("json")}
return newSelect
}
func constructJsonFunc(projections []Projection, statementType jet.StatementType) Expression {
jsonObj := Func("JSON_OBJECT", CustomExpression(jet.JsonObjProjectionList(projections)))
if statementType == jet.SelectJsonArrStatementType {
return Func("JSON_ARRAYAGG", jsonObj)
}
return jsonObj
}
func (s *selectJsonStatement) FROM(table ReadableTable) SelectJsonStatement {
s.From.Tables = []jet.Serializer{table}
return s
}
func (s *selectJsonStatement) WHERE(condition BoolExpression) SelectJsonStatement {
s.Where.Condition = condition
return s
}
func (s *selectJsonStatement) ORDER_BY(orderByClauses ...OrderByClause) SelectJsonStatement {
s.OrderBy.List = orderByClauses
return s
}
func (s *selectJsonStatement) LIMIT(limit int64) SelectJsonStatement {
s.Limit.Count = limit
return s
}
func (s *selectJsonStatement) OFFSET(offset int64) SelectJsonStatement {
s.Offset.Count = Int(offset)
return s
}

View file

@ -62,12 +62,12 @@ type SelectStatement interface {
// SELECT creates new SelectStatement with list of projections
func SELECT(projection Projection, projections ...Projection) SelectStatement {
return newSelectStatement(nil, append([]Projection{projection}, projections...))
return newSelectStatement(jet.SelectStatementType, nil, append([]Projection{projection}, projections...))
}
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement {
func newSelectStatement(stmtType jet.StatementType, table ReadableTable, projections []Projection) *selectStatementImpl {
newSelect := &selectStatementImpl{}
newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect,
newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, stmtType, newSelect,
&newSelect.Select,
&newSelect.From,
&newSelect.Where,

View file

@ -50,7 +50,7 @@ type readableTableInterfaceImpl struct {
// Generates a select query on the current tableName.
func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement {
return newSelectStatement(r.parent, append([]Projection{projection1}, projections...))
return newSelectStatement(jet.SelectStatementType, r.parent, append([]Projection{projection1}, projections...))
}
// Creates a inner join tableName Expression using onCondition.

View file

@ -101,9 +101,9 @@ func (b *cast) AS_DECIMAL() FloatExpression {
return FloatExp(b.AS("decimal"))
}
// AS_BYTEA casts expression AS text type
func (b *cast) AS_BYTEA() StringExpression {
return StringExp(b.AS("bytea"))
// AS_BYTEA casts expression AS bytea type
func (b *cast) AS_BYTEA() ByteaExpression {
return ByteaExp(b.AS("bytea"))
}
// AS_TIME casts expression AS date type

View file

@ -23,6 +23,12 @@ type ColumnString = jet.ColumnString
// StringColumn creates named string column.
var StringColumn = jet.StringColumn
// ColumnBytea is interface for bytea columns
type ColumnBytea = jet.ColumnBlob
// ByteaColumn creates new named bytea column.
var ByteaColumn = jet.BlobColumn
// ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger = jet.ColumnInteger
@ -65,6 +71,12 @@ type ColumnTimestampz = jet.ColumnTimestampz
// TimestampzColumn creates named timestamp with time zone column.
var TimestampzColumn = jet.TimestampzColumn
// ColumnInterval is interface of PostgreSQL interval columns.
type ColumnInterval = jet.ColumnInterval
// IntervalColumn creates named interval column
var IntervalColumn = jet.IntervalColumn
// ColumnDateRange is interface of SQL date range column
type ColumnDateRange = jet.ColumnRange[DateExpression]
@ -100,41 +112,3 @@ type ColumnInt8Range jet.ColumnRange[jet.Int8Expression]
// Int8RangeColumn creates named range with range column
var Int8RangeColumn = jet.RangeColumn[jet.Int8Expression]
//------------------------------------------------------//
// ColumnInterval is interface of PostgreSQL interval columns.
type ColumnInterval interface {
IntervalExpression
jet.Column
From(subQuery SelectTable) ColumnInterval
SET(intervalExp IntervalExpression) ColumnAssigment
}
//------------------------------------------------------//
type intervalColumnImpl struct {
jet.ColumnExpressionImpl
intervalInterfaceImpl
}
func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment {
return jet.NewColumnAssignment(i, intervalExp)
}
func (i *intervalColumnImpl) From(subQuery SelectTable) ColumnInterval {
newIntervalColumn := IntervalColumn(i.Name())
jet.SetTableName(newIntervalColumn, i.TableName())
jet.SetSubQuery(newIntervalColumn, subQuery)
return newIntervalColumn
}
// IntervalColumn creates named interval column.
func IntervalColumn(name string) ColumnInterval {
intervalColumn := &intervalColumnImpl{}
intervalColumn.ColumnExpressionImpl = jet.NewColumnImpl(name, "", intervalColumn)
intervalColumn.intervalInterfaceImpl.parent = intervalColumn
return intervalColumn
}

View file

@ -1,10 +1,10 @@
package postgres
import (
"encoding/hex"
"fmt"
"strconv"
"github.com/go-jet/jet/v2/internal/jet"
"strconv"
)
// Dialect is implementation of postgres dialect for SQL Builder serialisation.
@ -26,15 +26,42 @@ func newDialect() jet.Dialect {
ArgumentPlaceholder: func(ord int) string {
return "$" + strconv.Itoa(ord)
},
ReservedWords: reservedWords,
ArgumentToString: argumentToString,
ReservedWords: reservedWords,
ValuesDefaultColumnName: func(index int) string {
return fmt.Sprintf("column%d", index+1)
},
JsonValueEncode: func(expr Expression) Expression {
switch e := expr.(type) {
case ByteaExpression:
return ENCODE(e, Base64)
// CustomExpression used bellow (instead TO_CHAR function) so that only expr is parametrized
case TimeExpression:
return CustomExpression(Token("'0000-01-01T' || to_char('2000-10-10'::date + "), e, Token(`, 'HH24:MI:SS.USZ')`))
case TimezExpression:
return CustomExpression(Token("'0000-01-01T' || to_char('2000-10-10'::date + "), e, Token(`, 'HH24:MI:SS.USTZH:TZM')`))
case TimestampExpression:
return CustomExpression(Token("to_char("), e, Token(`, 'YYYY-MM-DD"T"HH24:MI:SS.USZ')`))
case DateExpression:
return CustomExpression(Token("to_char("), e, Token(`::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z'`))
}
return expr
},
}
return jet.NewDialect(dialectParams)
}
func argumentToString(value any) (string, bool) {
switch bindVal := value.(type) {
case []byte:
return fmt.Sprintf("'\\x%s'", hex.EncodeToString(bindVal)), true
}
return "", false
}
func postgresCAST(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {

View file

@ -12,6 +12,8 @@ type BoolExpression = jet.BoolExpression
// StringExpression interface
type StringExpression = jet.StringExpression
type ByteaExpression = jet.BlobExpression
// NumericExpression interface
type NumericExpression = jet.NumericExpression
@ -39,6 +41,9 @@ type TimestampzExpression = jet.TimestampzExpression
// RowExpression interface
type RowExpression = jet.RowExpression
// IntervalExpression interface
type IntervalExpression = jet.IntervalExpression
// DateRange Expression interface
type DateRange = jet.Range[DateExpression]
@ -82,6 +87,11 @@ var TimeExp = jet.TimeExp
// Does not add sql cast to generated sql builder output.
var StringExp = jet.StringExp
// ByteaExp is blob expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as string expression.
// Does not add sql cast to generated sql builder output.
var ByteaExp = jet.BlobExp
// TimezExp is time with time zone expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as time with time zone expression.
// Does not add sql cast to generated sql builder output.
@ -102,6 +112,11 @@ var TimestampExp = jet.TimestampExp
// Does not add sql cast to generated sql builder output.
var TimestampzExp = jet.TimestampzExp
// IntervalExp is interval expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as interval expression.
// Does not add sql cast to generated sql builder output.
var IntervalExp = jet.IntervalExp
// RowExp serves as a wrapper for an arbitrary expression, treating it as a row expression.
// This enables the Go compiler to interpret any expression as a row expression
// Note: This does not modify the generated SQL builder output by adding a SQL CAST operation.
@ -134,15 +149,17 @@ type RawArgs = map[string]interface{}
var (
Raw = jet.Raw
RawBool = jet.RawBool
RawInt = jet.RawInt
RawFloat = jet.RawFloat
RawString = jet.RawString
RawTime = jet.RawTime
RawTimez = jet.RawTimez
RawTimestamp = jet.RawTimestamp
RawTimestampz = jet.RawTimestampz
RawDate = jet.RawDate
RawBool = jet.RawBool
RawInt = jet.RawInt
RawFloat = jet.RawFloat
RawString = jet.RawString
RawTime = jet.RawTime
RawTimez = jet.RawTimez
RawTimestamp = jet.RawTimestamp
RawTimestampz = jet.RawTimestampz
RawDate = jet.RawDate
RawBytea = jet.RawBlob
RawNumRange = jet.RawRange[jet.NumericExpression]
RawInt4Range = jet.RawRange[jet.Int4Expression]
RawInt8Range = jet.RawRange[jet.Int8Expression]

View file

@ -192,9 +192,27 @@ func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression
return jet.CONCAT_WS(explicitLiteralCast(separator), explicitLiteralCasts(expressions...)...)
}
// Character encodings for CONVERT, CONVERT_FROM and CONVERT_TO functions
var (
UTF8 = StringExp(jet.FixedLiteral("UTF8"))
LATIN1 = StringExp(jet.FixedLiteral("LATIN1"))
LATIN2 = StringExp(jet.FixedLiteral("LATIN2"))
LATIN3 = StringExp(jet.FixedLiteral("LATIN3"))
LATIN4 = StringExp(jet.FixedLiteral("LATIN4"))
WIN1252 = StringExp(jet.FixedLiteral("WIN1252"))
ISO_8859_5 = StringExp(jet.FixedLiteral("ISO_8859_5"))
ISO_8859_6 = StringExp(jet.FixedLiteral("ISO_8859_6"))
ISO_8859_7 = StringExp(jet.FixedLiteral("ISO_8859_7"))
ISO_8859_8 = StringExp(jet.FixedLiteral("ISO_8859_8"))
KOI8R = StringExp(jet.FixedLiteral("KOI8R"))
KOI8U = StringExp(jet.FixedLiteral("KOI8U"))
)
// CONVERT converts string to dest_encoding. The original encoding is
// specified by src_encoding. The string must be valid in this encoding.
var CONVERT = jet.CONVERT
func CONVERT(str ByteaExpression, srcEncoding StringExpression, destEncoding StringExpression) ByteaExpression {
return jet.CONVERT(str, srcEncoding, destEncoding)
}
// CONVERT_FROM converts string to the database encoding. The original
// encoding is specified by src_encoding. The string must be valid in this encoding.
@ -203,6 +221,13 @@ var CONVERT_FROM = jet.CONVERT_FROM
// CONVERT_TO converts string to dest_encoding.
var CONVERT_TO = jet.CONVERT_TO
// ENCODE/DECODE textual formats
var (
Base64 = StringExp(jet.FixedLiteral("base64"))
Escape = StringExp(jet.FixedLiteral("escape"))
Hex = StringExp(jet.FixedLiteral("hex"))
)
// ENCODE encodes binary data into a textual representation.
// Supported formats are: base64, hex, escape. escape converts zero bytes and
// high-bit-set bytes to octal sequences (\nnn) and doubles backslashes.
@ -212,7 +237,7 @@ var ENCODE = jet.ENCODE
// Options for format are same as in encode.
var DECODE = jet.DECODE
// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string.
// FORMAT formats the arguments according to a format string. This function is similar to the C function sprintf.
func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression {
return jet.FORMAT(formatStr, explicitLiteralCasts(formatArgs...)...)
}
@ -242,6 +267,49 @@ var LPAD = jet.LPAD
// fill (a space by default). If the string is already longer than length then it is truncated.
var RPAD = jet.RPAD
// BIT_COUNT returns the number of bits set in the binary string (also known as “popcount”).
var BIT_COUNT = jet.BIT_COUNT
// GET_BIT extracts n'th bit from binary string.
func GET_BIT(bytes ByteaExpression, n IntegerExpression) IntegerExpression {
return IntExp(Func("GET_BIT", bytes, n))
}
// GET_BYTE extracts n'th byte from binary string.
func GET_BYTE(bytes ByteaExpression, n IntegerExpression) IntegerExpression {
return IntExp(Func("GET_BYTE", bytes, n))
}
// SET_BIT sets n'th bit in binary string to newvalue.
func SET_BIT(bytes ByteaExpression, n IntegerExpression, newValue IntegerExpression) ByteaExpression {
return ByteaExp(Func("SET_BIT", bytes, n, newValue))
}
// SET_BYTE sets n'th byte in binary string to newvalue.
func SET_BYTE(bytes ByteaExpression, n IntegerExpression, newValue IntegerExpression) ByteaExpression {
return ByteaExp(Func("SET_BYTE", bytes, n, newValue))
}
// SHA224 computes the SHA-224 hash of the binary string.
func SHA224(bytes ByteaExpression) ByteaExpression {
return ByteaExp(Func("SHA224", bytes))
}
// SHA256 computes the SHA-256 hash of the binary string.
func SHA256(bytes ByteaExpression) ByteaExpression {
return ByteaExp(Func("SHA256", bytes))
}
// SHA384 computes the SHA-384 hash of the binary string.
func SHA384(bytes ByteaExpression) ByteaExpression {
return ByteaExp(Func("SHA384", bytes))
}
// SHA512 computes the SHA-512 hash of the binary string.
func SHA512(bytes ByteaExpression) ByteaExpression {
return ByteaExp(Func("SHA512", bytes))
}
// MD5 calculates the MD5 hash of string, returning the result in hexadecimal
var MD5 = jet.MD5

View file

@ -1,257 +0,0 @@
package postgres
import (
"fmt"
"github.com/go-jet/jet/v2/internal/jet"
"github.com/go-jet/jet/v2/internal/utils/datetime"
"strconv"
"strings"
"time"
)
type quantityAndUnit = float64
type unit = float64
// Interval unit types
const (
YEAR unit = 123456789 + iota
MONTH
WEEK
DAY
HOUR
MINUTE
SECOND
MILLISECOND
MICROSECOND
DECADE
CENTURY
MILLENNIUM
)
// IntervalExpression is representation of postgres INTERVAL
type IntervalExpression interface {
jet.IsInterval
jet.Expression
EQ(rhs IntervalExpression) BoolExpression
NOT_EQ(rhs IntervalExpression) BoolExpression
IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression
IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression
LT(rhs IntervalExpression) BoolExpression
LT_EQ(rhs IntervalExpression) BoolExpression
GT(rhs IntervalExpression) BoolExpression
GT_EQ(rhs IntervalExpression) BoolExpression
BETWEEN(min, max IntervalExpression) BoolExpression
NOT_BETWEEN(min, max IntervalExpression) BoolExpression
ADD(rhs IntervalExpression) IntervalExpression
SUB(rhs IntervalExpression) IntervalExpression
MUL(rhs NumericExpression) IntervalExpression
DIV(rhs NumericExpression) IntervalExpression
}
type intervalInterfaceImpl struct {
jet.IsIntervalImpl
parent IntervalExpression
}
func (i *intervalInterfaceImpl) EQ(rhs IntervalExpression) BoolExpression {
return jet.Eq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) NOT_EQ(rhs IntervalExpression) BoolExpression {
return jet.NotEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression {
return jet.IsDistinctFrom(i.parent, rhs)
}
func (i *intervalInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression {
return jet.IsNotDistinctFrom(i.parent, rhs)
}
func (i *intervalInterfaceImpl) LT(rhs IntervalExpression) BoolExpression {
return jet.Lt(i.parent, rhs)
}
func (i *intervalInterfaceImpl) LT_EQ(rhs IntervalExpression) BoolExpression {
return jet.LtEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) GT(rhs IntervalExpression) BoolExpression {
return jet.Gt(i.parent, rhs)
}
func (i *intervalInterfaceImpl) GT_EQ(rhs IntervalExpression) BoolExpression {
return jet.GtEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) BETWEEN(min, max IntervalExpression) BoolExpression {
return jet.NewBetweenOperatorExpression(i.parent, min, max, false)
}
func (i *intervalInterfaceImpl) NOT_BETWEEN(min, max IntervalExpression) BoolExpression {
return jet.NewBetweenOperatorExpression(i.parent, min, max, true)
}
func (i *intervalInterfaceImpl) ADD(rhs IntervalExpression) IntervalExpression {
return IntervalExp(jet.Add(i.parent, rhs))
}
func (i *intervalInterfaceImpl) SUB(rhs IntervalExpression) IntervalExpression {
return IntervalExp(jet.Sub(i.parent, rhs))
}
func (i *intervalInterfaceImpl) MUL(rhs NumericExpression) IntervalExpression {
return IntervalExp(jet.Mul(i.parent, rhs))
}
func (i *intervalInterfaceImpl) DIV(rhs NumericExpression) IntervalExpression {
return IntervalExp(jet.Div(i.parent, rhs))
}
type intervalExpression struct {
jet.Expression
intervalInterfaceImpl
}
// INTERVAL creates new interval expression from the list of quantity-unit pairs.
//
// INTERVAL(1, DAY, 3, MINUTE)
func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression {
quantityAndUnitLen := len(quantityAndUnit)
if quantityAndUnitLen == 0 || quantityAndUnitLen%2 != 0 {
panic("jet: invalid number of quantity and unit fields")
}
var fields []string
for i := 0; i < len(quantityAndUnit); i += 2 {
quantity := strconv.FormatFloat(quantityAndUnit[i], 'f', -1, 64)
unitString := unitToString(quantityAndUnit[i+1])
fields = append(fields, quantity+" "+unitString)
}
intervalStr := fmt.Sprintf("INTERVAL '%s'", strings.Join(fields, " "))
newInterval := &intervalExpression{}
newInterval.Expression = jet.RawWithParent(intervalStr, newInterval)
newInterval.intervalInterfaceImpl.parent = newInterval
return newInterval
}
// INTERVALd creates interval expression from time.Duration
func INTERVALd(duration time.Duration) IntervalExpression {
days, hours, minutes, seconds, microseconds := datetime.ExtractTimeComponents(duration)
var quantityAndUnits []quantityAndUnit
if days > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(days))
quantityAndUnits = append(quantityAndUnits, DAY)
}
if hours > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(hours))
quantityAndUnits = append(quantityAndUnits, HOUR)
}
if minutes > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(minutes))
quantityAndUnits = append(quantityAndUnits, MINUTE)
}
if seconds > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(seconds))
quantityAndUnits = append(quantityAndUnits, SECOND)
}
if microseconds > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(microseconds))
quantityAndUnits = append(quantityAndUnits, MICROSECOND)
}
if len(quantityAndUnits) == 0 {
return INTERVAL(0, MICROSECOND)
}
return INTERVAL(quantityAndUnits...)
}
func unitToString(unit quantityAndUnit) string {
switch unit {
case YEAR:
return "YEAR"
case MONTH:
return "MONTH"
case WEEK:
return "WEEK"
case DAY:
return "DAY"
case HOUR:
return "HOUR"
case MINUTE:
return "MINUTE"
case SECOND:
return "SECOND"
case MILLISECOND:
return "MILLISECOND"
case MICROSECOND:
return "MICROSECOND"
case DECADE:
return "DECADE"
case CENTURY:
return "CENTURY"
case MILLENNIUM:
return "MILLENNIUM"
// additional field units for EXTRACT function
case DOW:
return "DOW"
case DOY:
return "DOY"
case EPOCH:
return "EPOCH"
case ISODOW:
return "ISODOW"
case ISOYEAR:
return "ISOYEAR"
case JULIAN:
return "JULIAN"
case QUARTER:
return "QUARTER"
case TIMEZONE:
return "TIMEZONE"
case TIMEZONE_HOUR:
return "TIMEZONE_HOUR"
case TIMEZONE_MINUTE:
return "TIMEZONE_MINUTE"
default:
panic("jet: invalid INTERVAL unit type")
}
}
//---------------------------------------------------//
type intervalWrapper struct {
intervalInterfaceImpl
Expression
}
func newIntervalExpressionWrap(expression Expression) IntervalExpression {
intervalWrap := &intervalWrapper{Expression: expression}
intervalWrap.intervalInterfaceImpl.parent = intervalWrap
return intervalWrap
}
// IntervalExp is interval expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as interval expression.
// Does not add sql cast to generated sql builder output.
func IntervalExp(expression Expression) IntervalExpression {
return newIntervalExpressionWrap(expression)
}

View file

@ -0,0 +1,140 @@
package postgres
import (
"fmt"
"github.com/go-jet/jet/v2/internal/utils/datetime"
"strconv"
"strings"
"time"
)
type quantityAndUnit = float64
type unit = float64
// Interval unit types
const (
YEAR unit = 123456789 + iota
MONTH
WEEK
DAY
HOUR
MINUTE
SECOND
MILLISECOND
MICROSECOND
DECADE
CENTURY
MILLENNIUM
)
// INTERVAL creates new interval expression from the list of quantity-unit pairs.
//
// INTERVAL(1, DAY, 3, MINUTE)
func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression {
quantityAndUnitLen := len(quantityAndUnit)
if quantityAndUnitLen == 0 || quantityAndUnitLen%2 != 0 {
panic("jet: invalid number of quantity and unit fields")
}
var fields []string
for i := 0; i < len(quantityAndUnit); i += 2 {
quantity := strconv.FormatFloat(quantityAndUnit[i], 'f', -1, 64)
unitString := unitToString(quantityAndUnit[i+1])
fields = append(fields, quantity+" "+unitString)
}
return IntervalExp(CustomExpression(Token(fmt.Sprintf("INTERVAL '%s'", strings.Join(fields, " ")))))
}
// INTERVALd creates interval expression from time.Duration
func INTERVALd(duration time.Duration) IntervalExpression {
days, hours, minutes, seconds, microseconds := datetime.ExtractTimeComponents(duration)
var quantityAndUnits []quantityAndUnit
if days > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(days))
quantityAndUnits = append(quantityAndUnits, DAY)
}
if hours > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(hours))
quantityAndUnits = append(quantityAndUnits, HOUR)
}
if minutes > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(minutes))
quantityAndUnits = append(quantityAndUnits, MINUTE)
}
if seconds > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(seconds))
quantityAndUnits = append(quantityAndUnits, SECOND)
}
if microseconds > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(microseconds))
quantityAndUnits = append(quantityAndUnits, MICROSECOND)
}
if len(quantityAndUnits) == 0 {
return INTERVAL(0, MICROSECOND)
}
return INTERVAL(quantityAndUnits...)
}
func unitToString(unit quantityAndUnit) string {
switch unit {
case YEAR:
return "YEAR"
case MONTH:
return "MONTH"
case WEEK:
return "WEEK"
case DAY:
return "DAY"
case HOUR:
return "HOUR"
case MINUTE:
return "MINUTE"
case SECOND:
return "SECOND"
case MILLISECOND:
return "MILLISECOND"
case MICROSECOND:
return "MICROSECOND"
case DECADE:
return "DECADE"
case CENTURY:
return "CENTURY"
case MILLENNIUM:
return "MILLENNIUM"
// additional field units for EXTRACT function
case DOW:
return "DOW"
case DOY:
return "DOY"
case EPOCH:
return "EPOCH"
case ISODOW:
return "ISODOW"
case ISOYEAR:
return "ISOYEAR"
case JULIAN:
return "JULIAN"
case QUARTER:
return "QUARTER"
case TIMEZONE:
return "TIMEZONE"
case TIMEZONE_HOUR:
return "TIMEZONE_HOUR"
case TIMEZONE_MINUTE:
return "TIMEZONE_MINUTE"
default:
panic("jet: invalid INTERVAL unit type")
}
}
//---------------------------------------------------//

View file

@ -127,7 +127,7 @@ func Json(value interface{}) StringExpression {
var UUID = jet.UUID
// Bytea creates new bytea literal expression
func Bytea(value interface{}) StringExpression {
func Bytea(value interface{}) ByteaExpression {
switch value.(type) {
case string, []byte:
default:

132
postgres/select_json.go Normal file
View file

@ -0,0 +1,132 @@
package postgres
import (
"github.com/go-jet/jet/v2/internal/jet"
"strings"
)
// SELECT_JSON_ARR creates a new SelectJsonStatement with a list of projections.
func SELECT_JSON_ARR(projections ...Projection) SelectStatement {
return newSelectStatementJson(projections, jet.SelectJsonArrStatementType)
}
// SELECT_JSON_OBJ creates a new SelectJsonStatement with a list of projections.
func SELECT_JSON_OBJ(projections ...Projection) SelectStatement {
return newSelectStatementJson(projections, jet.SelectJsonObjStatementType)
}
type selectJsonStatement struct {
*selectStatementImpl
subQuery *selectStatementImpl
statementType jet.StatementType
}
func (s *selectJsonStatement) AS(alias string) Projection {
s.setSubQueryAlias(strings.ToLower(alias) + "_")
return s.selectStatementImpl.AS(alias)
}
func (s *selectJsonStatement) FROM(table ...ReadableTable) SelectStatement {
s.subQuery.From.Tables = readableTablesToSerializerList(table)
return s
}
func (s *selectJsonStatement) DISTINCT(on ...jet.ColumnExpression) SelectStatement {
s.subQuery.Select.Distinct = true
s.subQuery.Select.DistinctOnColumns = on
return s
}
func (s *selectJsonStatement) WHERE(condition BoolExpression) SelectStatement {
s.subQuery.Where.Condition = condition
return s
}
func (s *selectJsonStatement) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement {
s.subQuery.GroupBy.List = groupByClauses
return s
}
func (s *selectJsonStatement) HAVING(boolExpression BoolExpression) SelectStatement {
s.subQuery.Having.Condition = boolExpression
return s
}
func (s *selectJsonStatement) WINDOW(name string) windowExpand {
s.subQuery.Window.Definitions = append(s.subQuery.Window.Definitions, jet.WindowDefinition{Name: name})
return windowExpand{
selectStatement: s.subQuery,
rootStmt: s,
}
}
func (s *selectJsonStatement) ORDER_BY(orderByClauses ...OrderByClause) SelectStatement {
s.subQuery.OrderBy.List = orderByClauses
return s
}
func (s *selectJsonStatement) LIMIT(limit int64) SelectStatement {
s.subQuery.Limit.Count = limit
return s
}
func (s *selectJsonStatement) OFFSET(offset int64) SelectStatement {
s.subQuery.Offset.Count = Int(offset)
return s
}
func (s *selectJsonStatement) OFFSET_e(offset IntegerExpression) SelectStatement {
s.subQuery.Offset.Count = offset
return s
}
func (s *selectJsonStatement) FETCH_FIRST(count IntegerExpression) fetchExpand {
s.subQuery.Fetch.Count = count
return fetchExpand{
selectStatement: s.subQuery,
rootStmt: s,
}
}
func (s *selectJsonStatement) FOR(lock RowLock) SelectStatement {
s.subQuery.For.Lock = lock
return s
}
func newSelectStatementJson(projections []Projection, statementType jet.StatementType) SelectStatement {
newSelectJson := &selectJsonStatement{
selectStatementImpl: newSelectStatement(statementType, nil, nil),
subQuery: newSelectStatement(statementType, nil, projections),
statementType: statementType,
}
newSelectJson.setOperatorsImpl.stmtRoot = newSelectJson
newSelectJson.subQuery.Select.IsForRowToJson = true
newSelectJson.setSubQueryAlias("")
return newSelectJson
}
func (s *selectJsonStatement) setSubQueryAlias(alias string) {
subQueryAlias := alias + "records"
jsonAlias := alias + "json"
s.Select.ProjectionList = ProjectionList{constructJsonFunc(s.statementType, subQueryAlias).AS(jsonAlias)}
s.From.Tables = []jet.Serializer{newSelectTable(s.subQuery, subQueryAlias, nil)}
}
func constructJsonFunc(statementType jet.StatementType, subQueryAlias string) Expression {
rowToJson := Func("row_to_json", CustomExpression(Token(subQueryAlias)))
if statementType == jet.SelectJsonArrStatementType {
return Func("json_agg", rowToJson)
}
return rowToJson
}

View file

@ -70,12 +70,12 @@ type SelectStatement interface {
// SELECT creates new SelectStatement with list of projections
func SELECT(projection Projection, projections ...Projection) SelectStatement {
return newSelectStatement(nil, append([]Projection{projection}, projections...))
return newSelectStatement(jet.SelectStatementType, nil, append([]Projection{projection}, projections...))
}
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement {
func newSelectStatement(stmtType jet.StatementType, table ReadableTable, projections []Projection) *selectStatementImpl {
newSelect := &selectStatementImpl{}
newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect,
newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, stmtType, newSelect,
&newSelect.Select,
&newSelect.From,
&newSelect.Where,
@ -94,7 +94,7 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta
}
newSelect.Limit.Count = -1
newSelect.setOperatorsImpl.parent = newSelect
newSelect.setOperatorsImpl.stmtRoot = newSelect
return newSelect
}
@ -144,7 +144,10 @@ func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatem
func (s *selectStatementImpl) WINDOW(name string) windowExpand {
s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name})
return windowExpand{selectStatement: s}
return windowExpand{
selectStatement: s,
rootStmt: s,
}
}
func (s *selectStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) SelectStatement {
@ -172,6 +175,7 @@ func (s *selectStatementImpl) FETCH_FIRST(count IntegerExpression) fetchExpand {
return fetchExpand{
selectStatement: s,
rootStmt: s,
}
}
@ -188,6 +192,7 @@ func (s *selectStatementImpl) AsTable(alias string) SelectTable {
type windowExpand struct {
selectStatement *selectStatementImpl
rootStmt SelectStatement
}
func (w windowExpand) AS(window ...jet.Window) SelectStatement {
@ -196,7 +201,7 @@ func (w windowExpand) AS(window ...jet.Window) SelectStatement {
}
windowsDefinition := w.selectStatement.Window.Definitions
windowsDefinition[len(windowsDefinition)-1].Window = window[0]
return w.selectStatement
return w.rootStmt
}
func toJetFrameOffset(offset int64) jet.Serializer {
@ -216,16 +221,17 @@ func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer {
type fetchExpand struct {
selectStatement *selectStatementImpl
rootStmt SelectStatement
}
func (f fetchExpand) ROWS_ONLY() SelectStatement {
f.selectStatement.Fetch.WithTies = false
return f.selectStatement
return f.rootStmt
}
func (f fetchExpand) ROWS_WITH_TIES() SelectStatement {
f.selectStatement.Fetch.WithTies = true
return f.selectStatement
return f.rootStmt
}

View file

@ -65,31 +65,31 @@ type setOperators interface {
}
type setOperatorsImpl struct {
parent setOperators
stmtRoot setOperators
}
func (s *setOperatorsImpl) UNION(rhs SelectStatement) setStatement {
return UNION(s.parent, rhs)
return UNION(s.stmtRoot, rhs)
}
func (s *setOperatorsImpl) UNION_ALL(rhs SelectStatement) setStatement {
return UNION_ALL(s.parent, rhs)
return UNION_ALL(s.stmtRoot, rhs)
}
func (s *setOperatorsImpl) INTERSECT(rhs SelectStatement) setStatement {
return INTERSECT(s.parent, rhs)
return INTERSECT(s.stmtRoot, rhs)
}
func (s *setOperatorsImpl) INTERSECT_ALL(rhs SelectStatement) setStatement {
return INTERSECT_ALL(s.parent, rhs)
return INTERSECT_ALL(s.stmtRoot, rhs)
}
func (s *setOperatorsImpl) EXCEPT(rhs SelectStatement) setStatement {
return EXCEPT(s.parent, rhs)
return EXCEPT(s.stmtRoot, rhs)
}
func (s *setOperatorsImpl) EXCEPT_ALL(rhs SelectStatement) setStatement {
return EXCEPT_ALL(s.parent, rhs)
return EXCEPT_ALL(s.stmtRoot, rhs)
}
type setStatementImpl struct {
@ -110,7 +110,7 @@ func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStat
newSetStatement.setOperator.Selects = selects
newSetStatement.setOperator.Limit.Count = -1
newSetStatement.setOperatorsImpl.parent = newSetStatement
newSetStatement.setOperatorsImpl.stmtRoot = newSetStatement
return newSetStatement
}

View file

@ -55,7 +55,7 @@ type readableTableInterfaceImpl struct {
// Generates a select query on the current tableName.
func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement {
return newSelectStatement(r.parent, append([]Projection{projection1}, projections...))
return newSelectStatement(jet.SelectStatementType, r.parent, append([]Projection{projection1}, projections...))
}
// Creates a inner join tableName Expression using onCondition.

View file

@ -4,14 +4,13 @@ import (
"database/sql"
"database/sql/driver"
"fmt"
"github.com/go-jet/jet/v2/internal/utils/min"
"github.com/go-jet/jet/v2/internal/utils/datetime"
"reflect"
"strconv"
"time"
)
var (
castOverFlowError = fmt.Errorf("cannot cast a negative value to an unsigned value, buffer overflow error")
errCastOverFlow = fmt.Errorf("cannot cast a negative value to an unsigned value, buffer overflow error")
)
// NullBool struct
@ -64,7 +63,12 @@ func (nt *NullTime) Scan(value interface{}) error {
// Some of the drivers (pgx, mysql) are not parsing all of the time formats(date, time with time zone,...) and are just forwarding string value.
// At this point we try to parse those values using some of the predefined formats
nt.Time, nt.Valid = tryParseAsTime(value)
nt.Time, nt.Valid = datetime.TryParseAsTime(value, []string{
"2006-01-02 15:04:05-07:00", // sqlite
"2006-01-02 15:04:05.999999", // go-sql-driver/mysql
"15:04:05-07", // pgx
"15:04:05.999999", // pgx
})
if !nt.Valid {
return fmt.Errorf("can't scan time.Time from %q", value)
@ -73,42 +77,6 @@ func (nt *NullTime) Scan(value interface{}) error {
return nil
}
var formats = []string{
"2006-01-02 15:04:05-07:00", // sqlite
"2006-01-02 15:04:05.999999", // go-sql-driver/mysql
"15:04:05-07", // pgx
"15:04:05.999999", // pgx
}
func tryParseAsTime(value interface{}) (time.Time, bool) {
var timeStr string
switch v := value.(type) {
case string:
timeStr = v
case []byte:
timeStr = string(v)
case int64:
return time.Unix(v, 0), true // sqlite
default:
return time.Time{}, false
}
for _, format := range formats {
formatLen := min.Int(len(format), len(timeStr))
t, err := time.Parse(format[:formatLen], timeStr)
if err != nil {
continue
}
return t, true
}
return time.Time{}, false
}
// NullUInt64 struct
type NullUInt64 struct {
UInt64 uint64
@ -124,31 +92,31 @@ func (n *NullUInt64) Scan(value interface{}) error {
return nil
case int64:
if v < 0 {
return castOverFlowError
return errCastOverFlow
}
n.UInt64, n.Valid = uint64(v), true
return nil
case int32:
if v < 0 {
return castOverFlowError
return errCastOverFlow
}
n.UInt64, n.Valid = uint64(v), true
return nil
case int16:
if v < 0 {
return castOverFlowError
return errCastOverFlow
}
n.UInt64, n.Valid = uint64(v), true
return nil
case int8:
if v < 0 {
return castOverFlowError
return errCastOverFlow
}
n.UInt64, n.Valid = uint64(v), true
return nil
case int:
if v < 0 {
return castOverFlowError
return errCastOverFlow
}
n.UInt64, n.Valid = uint64(v), true
return nil

View file

@ -103,25 +103,25 @@ func TestNullUInt64(t *testing.T) {
//Validate negative use cases
err := nullUInt64.Scan(int64(-5))
assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError)
assert.Error(t, err, errCastOverFlow)
//Validate negative use cases
err = nullUInt64.Scan(-5)
assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError)
assert.Error(t, err, errCastOverFlow)
//Validate negative use cases
err = nullUInt64.Scan(int32(-5))
assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError)
assert.Error(t, err, errCastOverFlow)
//Validate negative use cases
err = nullUInt64.Scan(int16(-5))
assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError)
assert.Error(t, err, errCastOverFlow)
//Validate negative use cases
err = nullUInt64.Scan(int8(-5))
assert.NotNil(t, err)
assert.Error(t, err, castOverFlowError)
assert.Error(t, err, errCastOverFlow)
}

View file

@ -3,6 +3,7 @@ package qrm
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"github.com/go-jet/jet/v2/internal/utils/must"
@ -12,10 +13,130 @@ import (
// ErrNoRows is returned by Query when query result set is empty
var ErrNoRows = errors.New("qrm: no rows in result set")
// Query executes Query Result Mapping (QRM) of `query` with list of parametrized arguments `arg` over database connection `db`
// using context `ctx` into destination `destPtr`.
// Destination can be either pointer to struct or pointer to slice of structs.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
// QueryJsonObj executes a SQL query that returns a JSON object, unmarshals the result into the provided destination,
// and returns the number of rows processed.
//
// The query must return exactly one row with a single column; otherwise, an error is returned.
//
// Parameters:
//
// ctx - The context for managing query execution (timeouts, cancellations).
// db - The database connection or transaction that implements the Queryable interface.
// query - The SQL query string to be executed.
// args - A slice of arguments to be used with the query.
// destPtr - A pointer to the variable where the unmarshaled JSON result will be stored.
// The destination should be a pointer to a struct or map[string]any.
//
// Returns:
//
// rowsProcessed - The number of rows processed by the query execution.
// err - An error if query execution or unmarshaling fails.
func QueryJsonObj(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) {
must.BeInitializedPtr(destPtr, "jet: destination is nil")
must.BeTypeKind(destPtr, reflect.Ptr, jsonDestObjErr)
destType := reflect.TypeOf(destPtr).Elem()
must.BeTrue(destType.Kind() == reflect.Struct || destType.Kind() == reflect.Map, jsonDestObjErr)
return queryJson(ctx, db, query, args, destPtr)
}
// QueryJsonArr executes a SQL query that returns a JSON array, unmarshals the result into the provided destination,
// and returns the number of rows processed.
//
// The query must return exactly one row with a single column; otherwise, an error is returned.
//
// Parameters:
//
// ctx - The context for managing query execution (timeouts, cancellations).
// db - The database connection or transaction that implements the Queryable interface.
// query - The SQL query string to be executed.
// args - A slice of arguments to be used with the query.
// destPtr - A pointer to the variable where the unmarshaled JSON array will be stored.
// The destination should be a pointer to a slice of structs or []map[string]any.
//
// Returns:
//
// rowsProcessed - The number of rows processed by the query execution.
// err - An error if query execution or unmarshaling fails.
func QueryJsonArr(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) {
must.BeInitializedPtr(destPtr, "jet: destination is nil")
must.BeTypeKind(destPtr, reflect.Ptr, jsonDestArrErr)
destType := reflect.TypeOf(destPtr).Elem()
must.BeTrue(destType.Kind() == reflect.Slice, jsonDestArrErr)
return queryJson(ctx, db, query, args, destPtr)
}
var jsonDestObjErr = "jet: SELECT_JSON_OBJ destination has to be a pointer to struct or pointer to map[string]any"
var jsonDestArrErr = "jet: SELECT_JSON_ARR destination has to be a pointer to slice of struct or pointer to []map[string]any"
func queryJson(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) {
must.BeInitializedPtr(db, "jet: db is nil")
var rows *sql.Rows
rows, err = db.QueryContext(ctx, query, args...)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
err = rows.Err()
if err != nil {
return 0, err
}
return 0, ErrNoRows
}
var jsonData []byte
err = rows.Scan(&jsonData)
if err != nil {
return 1, err
}
if jsonData == nil {
return 1, nil
}
err = json.Unmarshal(jsonData, &destPtr)
if err != nil {
return 1, fmt.Errorf("jet: invalid json, %w", err)
}
if rows.Next() {
return 1, fmt.Errorf("jet: query returned more then one row")
}
err = rows.Close()
if err != nil {
return 1, err
}
return 1, nil
}
// Query executes a Query Result Mapping (QRM) of the provided SQL `query` with a list of parameterized arguments `args`
// over the database connection `db` using the provided context `ctx` and stores the result in the destination `destPtr`.
//
// The destination must be a pointer to either a struct or a slice of structs
// If the destination is a pointer to a struct and no rows are returned, the method returns qrm.ErrNoRows.
//
// Parameters:
//
// ctx - The context for managing query execution (timeouts, cancellations).
// db - The database connection or transaction implementing the Queryable interface.
// query - The SQL query string to be executed.
// args - A slice of arguments to be used with the query.
// destPtr - A pointer to the variable where the query result will be stored. This can be a pointer to a struct or a slice of structs.
//
// Returns:
//
// rowsProcessed - The number of rows processed by the query execution.
// err - An error if query execution or result mapping fails, or if no rows are found when a struct is expected.
func Query(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) {
must.BeInitializedPtr(db, "jet: db is nil")
@ -185,7 +306,7 @@ func mapRowToSlice(
func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) {
index := 0
if field != nil {
typeName, columnName := getTypeAndFieldName("", *field)
typeName, columnName, _ := getTypeAndFieldName("", *field)
if index = scanContext.typeToColumnIndex(typeName, columnName); index < 0 {
return
}
@ -233,9 +354,11 @@ func mapRowToStruct(
continue
}
fieldMap := typeInf.fieldMappings[i]
fieldMappingInfo := typeInf.fieldMappings[i]
if fieldMap.complexType {
switch fieldMappingInfo.Type {
case complexType:
var changed bool
changed, err = mapRowToDestinationValue(scanContext, concat(groupKey, ":", field.Name), fieldValue, &field)
@ -246,13 +369,12 @@ func mapRowToStruct(
if changed {
updated = true
}
} else {
if mapOnlySlices || fieldMap.rowIndex == -1 {
default:
if mapOnlySlices || fieldMappingInfo.rowIndex == -1 {
continue
}
scannedValue := scanContext.rowElemValue(fieldMap.rowIndex)
scannedValue := scanContext.rowElemValue(fieldMappingInfo.rowIndex)
if !scannedValue.IsValid() {
setZeroValue(fieldValue) // scannedValue is nil, destination should be set to zero value
@ -261,7 +383,8 @@ func mapRowToStruct(
updated = true
if fieldMap.implementsScanner {
switch fieldMappingInfo.Type {
case implementsScanner:
initializeValueIfNilPtr(fieldValue)
fieldScanner := getScanner(fieldValue)
@ -270,14 +393,27 @@ func mapRowToStruct(
err := fieldScanner.Scan(value)
if err != nil {
return updated, fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, value, value, field.Name, field.Type.String(), err)
return updated, qrmAssignError(scannedValue, field, err)
}
} else {
case jsonUnmarshal:
value, ok := scannedValue.Interface().([]byte)
if !ok {
return updated, qrmAssignError(scannedValue, field, fmt.Errorf("value not convertable to []byte"))
}
fieldInterface := fieldValue.Addr().Interface()
err := json.Unmarshal(value, fieldInterface)
if err != nil {
return updated, qrmAssignError(scannedValue, field, fmt.Errorf("invalid json, %w", err))
}
default: // simple type
err := assign(scannedValue, fieldValue)
if err != nil {
return updated, fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, scannedValue.Interface(), scannedValue.Interface(),
field.Name, field.Type.String(), err)
return updated, qrmAssignError(scannedValue, field, err)
}
}
}
@ -286,6 +422,11 @@ func mapRowToStruct(
return
}
func qrmAssignError(scannedValue reflect.Value, field reflect.StructField, err error) error {
return fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, scannedValue.Interface(), scannedValue.Interface(),
field.Name, field.Type.String(), err)
}
func mapRowToDestinationValue(
scanContext *ScanContext,
groupKey string,

View file

@ -75,10 +75,18 @@ type typeInfo struct {
fieldMappings []fieldMapping
}
type fieldMappingType int
const (
simpleType fieldMappingType = iota
complexType // slice and struct are complex types supported
implementsScanner
jsonUnmarshal
)
type fieldMapping struct {
complexType bool // slice and struct are complex types
rowIndex int // index in ScanContext.row
implementsScanner bool
rowIndex int // index in ScanContext.row
Type fieldMappingType
}
func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo {
@ -100,17 +108,21 @@ func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
newTypeName, fieldName := getTypeAndFieldName(typeName, field)
newTypeName, fieldName, jsonUnmarshaler := getTypeAndFieldName(typeName, field)
columnIndex := s.typeToColumnIndex(newTypeName, fieldName)
fieldMap := fieldMapping{
rowIndex: columnIndex,
}
if implementsScannerType(field.Type) {
fieldMap.implementsScanner = true
if jsonUnmarshaler {
fieldMap.Type = jsonUnmarshal
} else if implementsScannerType(field.Type) {
fieldMap.Type = implementsScanner
} else if !isSimpleModelType(field.Type) {
fieldMap.complexType = true
fieldMap.Type = complexType
} else {
fieldMap.Type = simpleType
}
newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap)
@ -188,7 +200,7 @@ func (s *ScanContext) getGroupKeyInfo(
fieldType := indirectType(field.Type)
if isPrimaryKey(field, primaryKeyOverwrites) {
newTypeName, fieldName := getTypeAndFieldName(typeName, field)
newTypeName, fieldName, _ := getTypeAndFieldName(typeName, field)
pkIndex := s.typeToColumnIndex(newTypeName, fieldName)

View file

@ -107,20 +107,26 @@ func getTypeName(structType reflect.Type, parentField *reflect.StructField) stri
return toCommonIdentifier(aliasParts[0])
}
func getTypeAndFieldName(structType string, field reflect.StructField) (string, string) {
func getTypeAndFieldName(structType string, field reflect.StructField) (string, string, bool) {
aliasTag := field.Tag.Get("alias")
if aliasTag == "" {
return structType, field.Name
if aliasTag != "" {
aliasParts := strings.Split(aliasTag, ".")
if len(aliasParts) == 1 {
return structType, toCommonIdentifier(aliasParts[0]), false
}
return toCommonIdentifier(aliasParts[0]), toCommonIdentifier(aliasParts[1]), false
}
aliasParts := strings.Split(aliasTag, ".")
jsonColumnTag := field.Tag.Get("json_column")
if len(aliasParts) == 1 {
return structType, toCommonIdentifier(aliasParts[0])
if jsonColumnTag != "" {
return "", toCommonIdentifier(jsonColumnTag), true
}
return toCommonIdentifier(aliasParts[0]), toCommonIdentifier(aliasParts[1])
return structType, field.Name, false
}
var replacer = strings.NewReplacer(" ", "", "-", "", "_", "")

View file

@ -42,6 +42,6 @@ func (c *cast) AS_REAL() FloatExpression {
}
// AS_BLOB cast expression to BLOB type
func (c *cast) AS_BLOB() StringExpression {
return StringExp(c.AS("BLOB"))
func (c *cast) AS_BLOB() BlobExpression {
return BlobExp(c.AS("BLOB"))
}

View file

@ -21,6 +21,12 @@ type ColumnString = jet.ColumnString
// StringColumn creates named string column.
var StringColumn = jet.StringColumn
// ColumnBlob is interface for
type ColumnBlob = jet.ColumnBlob
// BlobColumn creates new named blob column
var BlobColumn = jet.BlobColumn
// ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger = jet.ColumnInteger

View file

@ -1,6 +1,7 @@
package sqlite
import (
"encoding/hex"
"fmt"
"github.com/go-jet/jet/v2/internal/jet"
)
@ -23,7 +24,8 @@ func newDialect() jet.Dialect {
ArgumentPlaceholder: func(int) string {
return "?"
},
ReservedWords: reservedWords2,
ArgumentToString: argumentToString,
ReservedWords: reservedWords2,
ValuesDefaultColumnName: func(index int) string {
return fmt.Sprintf("column%d", index+1)
},
@ -32,6 +34,15 @@ func newDialect() jet.Dialect {
return jet.NewDialect(mySQLDialectParams)
}
func argumentToString(value any) (string, bool) {
switch bindVal := value.(type) {
case []byte:
return fmt.Sprintf("X'%s'", hex.EncodeToString(bindVal)), true
}
return "", false
}
func sqliteBitXOR(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {

View file

@ -12,6 +12,9 @@ type BoolExpression = jet.BoolExpression
// StringExpression interface
type StringExpression = jet.StringExpression
// BlobExpression interface
type BlobExpression = jet.BlobExpression
// NumericExpression is shared interface for integer or real expression
type NumericExpression = jet.NumericExpression
@ -46,6 +49,11 @@ var BoolExp = jet.BoolExp
// Does not add sql cast to generated sql builder output.
var StringExp = jet.StringExp
// BlobExp is blob expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as blob expression.
// Does not add sql cast to generated sql builder output.
var BlobExp = jet.BlobExp
// IntExp is int expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as int expression.
// Does not add sql cast to generated sql builder output.

View file

@ -196,11 +196,22 @@ var RTRIM = jet.RTRIM
// return jet.NewStringFunc("RIGHTSTR", str, n)
//}
// HEX function takes an input and returns its equivalent hexadecimal representation
var HEX = jet.HEX
// UNHEX for a string argument str, UNHEX(str) interprets each pair of characters in the argument
// as a hexadecimal number and converts it to the byte represented by the number.
// The return value is a binary string.
var UNHEX = jet.UNHEX
// LENGTH returns number of characters in string with a given encoding
func LENGTH(str jet.StringExpression) jet.StringExpression {
func LENGTH(str jet.StringOrBlobExpression) jet.IntegerExpression {
return jet.LENGTH(str)
}
// OCTET_LENGTH returns number of bytes in string expression
var OCTET_LENGTH = jet.OCTET_LENGTH
// LPAD fills up the string to length length by prepending the characters
// fill (a space by default). If the string is already longer than length
// then it is truncated (on the right).

View file

@ -50,6 +50,11 @@ var Decimal = jet.Decimal
// String creates new string literal expression
var String = jet.String
// Blob creates new blob literal expression
func Blob(data []byte) BlobExpression {
return BlobExp(jet.Literal(data))
}
// UUID is a helper function to create string literal expression from uuid object
// value can be any uuid type with a String method
var UUID = jet.UUID

View file

@ -26,7 +26,7 @@ services:
- ./testdata/init/mysql:/docker-entrypoint-initdb.d
mariadb:
image: mariadb:10.3
image: mariadb:11.4
command: ['--default-authentication-plugin=mysql_native_password', '--log_bin_trust_function_creators=1']
restart: always
environment:

View file

@ -23,19 +23,115 @@ func TestAllTypes(t *testing.T) {
var dest []model.AllTypes
err := AllTypes.
SELECT(AllTypes.AllColumns).
err := SELECT(AllTypes.AllColumns).
FROM(AllTypes).
LIMIT(2).
Query(db, &dest)
require.NoError(t, err)
require.Equal(t, len(dest), 2)
//testutils.PrintJson(dest)
testutils.AssertJSON(t, dest, allTypesJson)
}
func TestAllTypesJSON(t *testing.T) {
stmt := SELECT_JSON_ARR(
AllTypes.AllColumns.Except(
AllTypes.JSON,
AllTypes.JSONPtr,
AllTypes.Bit,
AllTypes.BitPtr,
),
CAST(AllTypes.JSON).AS_CHAR().AS("Json"),
CAST(AllTypes.JSONPtr).AS_CHAR().AS("JsonPtr"),
CAST(AllTypes.Bit).AS_CHAR().AS("Bit"),
CAST(AllTypes.BitPtr).AS_CHAR().AS("BitPtr"),
).FROM(AllTypes)
testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(`
SELECT JSON_ARRAYAGG(JSON_OBJECT(
'id', all_types.id,
'boolean', all_types.boolean = 1,
'booleanPtr', all_types.boolean_ptr = 1,
'tinyInt', all_types.tiny_int,
'uTinyInt', all_types.u_tiny_int,
'smallInt', all_types.small_int,
'uSmallInt', all_types.u_small_int,
'mediumInt', all_types.medium_int,
'uMediumInt', all_types.u_medium_int,
'integer', all_types.''integer'',
'uInteger', all_types.u_integer,
'bigInt', all_types.big_int,
'uBigInt', all_types.u_big_int,
'tinyIntPtr', all_types.tiny_int_ptr,
'uTinyIntPtr', all_types.u_tiny_int_ptr,
'smallIntPtr', all_types.small_int_ptr,
'uSmallIntPtr', all_types.u_small_int_ptr,
'mediumIntPtr', all_types.medium_int_ptr,
'uMediumIntPtr', all_types.u_medium_int_ptr,
'integerPtr', all_types.integer_ptr,
'uIntegerPtr', all_types.u_integer_ptr,
'bigIntPtr', all_types.big_int_ptr,
'uBigIntPtr', all_types.u_big_int_ptr,
'decimal', all_types.''decimal'',
'decimalPtr', all_types.decimal_ptr,
'numeric', all_types.''numeric'',
'numericPtr', all_types.numeric_ptr,
'float', all_types.''float'',
'floatPtr', all_types.float_ptr,
'double', all_types.''double'',
'doublePtr', all_types.double_ptr,
'real', all_types.''real'',
'realPtr', all_types.real_ptr,
'time', CONCAT('0000-01-01T', DATE_FORMAT(all_types.time,'%H:%i:%s.%fZ')),
'timePtr', CONCAT('0000-01-01T', DATE_FORMAT(all_types.time_ptr,'%H:%i:%s.%fZ')),
'date', CONCAT(DATE_FORMAT(all_types.date,'%Y-%m-%d'), 'T00:00:00Z'),
'datePtr', CONCAT(DATE_FORMAT(all_types.date_ptr,'%Y-%m-%d'), 'T00:00:00Z'),
'dateTime', DATE_FORMAT(all_types.date_time,'%Y-%m-%dT%H:%i:%s.%fZ'),
'dateTimePtr', DATE_FORMAT(all_types.date_time_ptr,'%Y-%m-%dT%H:%i:%s.%fZ'),
'timestamp', DATE_FORMAT(all_types.timestamp,'%Y-%m-%dT%H:%i:%s.%fZ'),
'timestampPtr', DATE_FORMAT(all_types.timestamp_ptr,'%Y-%m-%dT%H:%i:%s.%fZ'),
'year', all_types.year,
'yearPtr', all_types.year_ptr,
'char', all_types.''char'',
'charPtr', all_types.char_ptr,
'varChar', all_types.var_char,
'varCharPtr', all_types.var_char_ptr,
'binary', TO_BASE64(all_types.''binary''),
'binaryPtr', TO_BASE64(all_types.binary_ptr),
'varBinary', TO_BASE64(all_types.var_binary),
'varBinaryPtr', TO_BASE64(all_types.var_binary_ptr),
'blob', TO_BASE64(all_types.''blob''),
'blobPtr', TO_BASE64(all_types.blob_ptr),
'text', all_types.text,
'textPtr', all_types.text_ptr,
'enum', all_types.enum,
'enumPtr', all_types.enum_ptr,
'set', all_types.''set'',
'setPtr', all_types.set_ptr,
'Json', CAST(all_types.json AS CHAR),
'JsonPtr', CAST(all_types.json_ptr AS CHAR),
'Bit', CAST(all_types.bit AS CHAR),
'BitPtr', CAST(all_types.bit_ptr AS CHAR)
)) AS "json"
FROM test_sample.all_types;
`, "''", "`"))
var dest []model.AllTypes
err := stmt.QueryContext(ctx, db, &dest)
require.NoError(t, err)
// fix float rounding lost before comparison
dest[0].Float = 3.33
dest[0].FloatPtr = ptr.Of(3.33)
dest[1].Float = 3.33
testutils.AssertJSON(t, dest, allTypesJson)
}
func TestAllTypesViewSelect(t *testing.T) {
type AllTypesView model.AllTypes
@ -467,7 +563,8 @@ func TestStringOperators(t *testing.T) {
RTRIM(AllTypes.VarCharPtr),
CONCAT(String("string1"), Int(1), Float(11.12)),
CONCAT_WS(String("string1"), Int(1), Float(11.12)),
FORMAT(String("Hello %s, %1$s"), String("World")),
FORMAT(Int(11), Int(2)),
FORMAT(Int(11), Int(2), String("de_DE")),
LEFT(String("abcde"), Int(2)),
RIGHT(String("abcde"), Int(2)),
LENGTH(String("jose")),
@ -479,6 +576,12 @@ func TestStringOperators(t *testing.T) {
REVERSE(AllTypes.VarCharPtr),
SUBSTR(AllTypes.CharPtr, Int(3)),
SUBSTR(AllTypes.CharPtr, Int(3), Int(2)),
ELT(Int(2), AllTypes.CharPtr, AllTypes.Char, AllTypes.Text),
FIELD(AllTypes.Char, AllTypes.VarChar, AllTypes.Text),
FROM_BASE64(String("SGVsbG8gV29ybGQ=")),
TO_BASE64(String("Hello World")),
CHARSET(AllTypes.Char),
COLLATION(AllTypes.Text),
}
if !sourceIsMariaDB() {
@ -500,6 +603,71 @@ func TestStringOperators(t *testing.T) {
require.NoError(t, err)
}
func TestBlob(t *testing.T) {
var sampleBlob = Blob([]byte{11, 0, 22, 33, 44})
var textBlob = Blob([]byte("text blob"))
stmt := SELECT(
AllTypes.BlobPtr.EQ(sampleBlob),
AllTypes.BlobPtr.EQ(AllTypes.BlobPtr),
AllTypes.BlobPtr.NOT_EQ(sampleBlob),
AllTypes.BlobPtr.GT(textBlob),
AllTypes.BlobPtr.GT_EQ(AllTypes.BlobPtr),
AllTypes.BlobPtr.LT(AllTypes.BlobPtr),
AllTypes.BlobPtr.LT_EQ(sampleBlob),
AllTypes.BlobPtr.BETWEEN(Blob([]byte("min")), Blob([]byte("max"))),
AllTypes.BlobPtr.NOT_BETWEEN(AllTypes.BlobPtr, AllTypes.BlobPtr),
AllTypes.BlobPtr.CONCAT(textBlob),
AllTypes.BlobPtr.LIKE(AllTypes.BlobPtr),
AllTypes.BlobPtr.NOT_LIKE(sampleBlob),
BIT_LENGTH(textBlob),
LENGTH(sampleBlob),
CHAR_LENGTH(AllTypes.BlobPtr),
OCTET_LENGTH(textBlob),
CONCAT(sampleBlob, Int(1), Float(11.12)),
TO_BASE64(sampleBlob),
HEX(sampleBlob),
UNHEX(String("616B263A")),
SUBSTR(AllTypes.BlobPtr, Int(3)),
SUBSTR(AllTypes.BlobPtr, Int(3), Int(2)),
).FROM(
AllTypes,
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT all_types.blob_ptr = X'0b0016212c',
all_types.blob_ptr = all_types.blob_ptr,
all_types.blob_ptr != X'0b0016212c',
all_types.blob_ptr > X'7465787420626c6f62',
all_types.blob_ptr >= all_types.blob_ptr,
all_types.blob_ptr < all_types.blob_ptr,
all_types.blob_ptr <= X'0b0016212c',
all_types.blob_ptr BETWEEN X'6d696e' AND X'6d6178',
all_types.blob_ptr NOT BETWEEN all_types.blob_ptr AND all_types.blob_ptr,
CONCAT(all_types.blob_ptr, X'7465787420626c6f62'),
all_types.blob_ptr LIKE all_types.blob_ptr,
all_types.blob_ptr NOT LIKE X'0b0016212c',
BIT_LENGTH(X'7465787420626c6f62'),
LENGTH(X'0b0016212c'),
CHAR_LENGTH(all_types.blob_ptr),
OCTET_LENGTH(X'7465787420626c6f62'),
CONCAT(X'0b0016212c', 1, 11.12),
TO_BASE64(X'0b0016212c'),
HEX(X'0b0016212c'),
UNHEX('616B263A'),
SUBSTR(all_types.blob_ptr, 3),
SUBSTR(all_types.blob_ptr, 3, 2)
FROM test_sample.all_types;
`)
var dest []struct{}
err := stmt.Query(db, &dest)
require.NoError(t, err)
}
var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
func TestTimeExpressions(t *testing.T) {
@ -1066,6 +1234,118 @@ func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) {
require.NoError(t, err)
}
func TestAllTypesSubQueryFrom(t *testing.T) {
subQuery := SELECT(
AllTypes.Boolean,
AllTypes.Integer,
AllTypes.Double,
AllTypes.Text,
AllTypes.Date,
AllTypes.Time,
AllTypes.Timestamp,
AllTypes.Blob,
).FROM(
AllTypes,
).AsTable("sub_query")
stmt := SELECT(
AllTypes.Boolean.From(subQuery),
AllTypes.Integer.From(subQuery),
AllTypes.Double.From(subQuery),
AllTypes.Text.From(subQuery),
AllTypes.Date.From(subQuery),
AllTypes.Time.From(subQuery),
AllTypes.Timestamp.From(subQuery),
AllTypes.Blob.From(subQuery),
).FROM(
subQuery,
)
testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(`
SELECT sub_query.''all_types.boolean'' AS "all_types.boolean",
sub_query.''all_types.integer'' AS "all_types.integer",
sub_query.''all_types.double'' AS "all_types.double",
sub_query.''all_types.text'' AS "all_types.text",
sub_query.''all_types.date'' AS "all_types.date",
sub_query.''all_types.time'' AS "all_types.time",
sub_query.''all_types.timestamp'' AS "all_types.timestamp",
sub_query.''all_types.blob'' AS "all_types.blob"
FROM (
SELECT all_types.boolean AS "all_types.boolean",
all_types.''integer'' AS "all_types.integer",
all_types.''double'' AS "all_types.double",
all_types.text AS "all_types.text",
all_types.date AS "all_types.date",
all_types.time AS "all_types.time",
all_types.timestamp AS "all_types.timestamp",
all_types.''blob'' AS "all_types.blob"
FROM test_sample.all_types
) AS sub_query;
`, "''", "`"))
var dest []model.AllTypes
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.NotEmpty(t, dest)
t.Run("using SELECT_JSON", func(t *testing.T) {
stmtJson := SELECT_JSON_ARR(
AllTypes.Boolean.From(subQuery),
AllTypes.Integer.From(subQuery),
AllTypes.Double.From(subQuery),
AllTypes.Text.From(subQuery),
AllTypes.Date.From(subQuery),
AllTypes.Time.From(subQuery),
AllTypes.Timestamp.From(subQuery),
AllTypes.Blob.From(subQuery),
).FROM(
subQuery,
)
testutils.AssertDebugStatementSql(t, stmtJson, strings.ReplaceAll(`
SELECT JSON_ARRAYAGG(JSON_OBJECT(
'boolean', sub_query.''all_types.boolean'' = 1,
'integer', sub_query.''all_types.integer'',
'double', sub_query.''all_types.double'',
'text', sub_query.''all_types.text'',
'date', CONCAT(DATE_FORMAT(sub_query.''all_types.date'','%Y-%m-%d'), 'T00:00:00Z'),
'time', CONCAT('0000-01-01T', DATE_FORMAT(sub_query.''all_types.time'','%H:%i:%s.%fZ')),
'timestamp', DATE_FORMAT(sub_query.''all_types.timestamp'','%Y-%m-%dT%H:%i:%s.%fZ'),
'blob', TO_BASE64(sub_query.''all_types.blob'')
)) AS "json"
FROM (
SELECT all_types.boolean AS "all_types.boolean",
all_types.''integer'' AS "all_types.integer",
all_types.''double'' AS "all_types.double",
all_types.text AS "all_types.text",
all_types.date AS "all_types.date",
all_types.time AS "all_types.time",
all_types.timestamp AS "all_types.timestamp",
all_types.''blob'' AS "all_types.blob"
FROM test_sample.all_types
) AS sub_query;
`, "''", "`"))
var destJson []model.AllTypes
err := stmtJson.QueryContext(ctx, db, &destJson)
require.NoError(t, err)
t.Run("using AllColumns()", func(t *testing.T) {
stmtJsonAllColumns := SELECT_JSON_ARR(
subQuery.AllColumns(),
).FROM(
subQuery,
)
require.Equal(t, stmtJson.DebugSql(), stmtJsonAllColumns.DebugSql())
})
testutils.AssertJsonEqual(t, dest, destJson)
})
}
var toInsert = model.AllTypes{
Boolean: false,
BooleanPtr: ptr.Of(true),
@ -1131,6 +1411,7 @@ var toInsert = model.AllTypes{
var allTypesJson = `
[
{
"ID": 1,
"Boolean": false,
"BooleanPtr": true,
"TinyInt": -3,
@ -1195,6 +1476,7 @@ var allTypesJson = `
"JSONPtr": "{\"key1\": \"value1\", \"key2\": \"value2\"}"
},
{
"ID": 2,
"Boolean": false,
"BooleanPtr": null,
"TinyInt": -3,

138
tests/mysql/bench_test.go Normal file
View file

@ -0,0 +1,138 @@
//go:build bench
// +build bench
package mysql
import (
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table"
"github.com/stretchr/testify/require"
"testing"
)
type allInfo []struct {
model.Actor
Films []struct {
model.Film
Language model.Language
Categories []model.Category
Inventories []struct {
model.Inventory
Rentals []struct {
model.Rental
Customer model.Customer
}
}
}
}
func BenchmarkTestDVDsJoinEverything(b *testing.B) {
for i := 0; i < b.N; i++ {
testDVDsJoinEverything(b)
}
}
func TestDVDsJoinEverything(t *testing.T) {
testDVDsJoinEverything(t)
}
func testDVDsJoinEverything(t require.TestingT) {
stmt := SELECT(
Actor.AllColumns,
Film.AllColumns,
Language.AllColumns,
Category.AllColumns,
Inventory.AllColumns,
Rental.AllColumns,
Customer.AllColumns,
).FROM(
Actor.
LEFT_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.ActorID)).
LEFT_JOIN(Film, Film.FilmID.EQ(FilmActor.FilmID)).
LEFT_JOIN(Language, Language.LanguageID.EQ(Film.LanguageID)).
LEFT_JOIN(FilmCategory, FilmCategory.FilmID.EQ(Film.FilmID)).
LEFT_JOIN(Category, Category.CategoryID.EQ(FilmCategory.CategoryID)).
LEFT_JOIN(Inventory, Inventory.FilmID.EQ(Film.FilmID)).
LEFT_JOIN(Rental, Rental.InventoryID.EQ(Inventory.InventoryID)).
LEFT_JOIN(Customer, Customer.CustomerID.EQ(Rental.CustomerID)),
).ORDER_BY(
Actor.ActorID.ASC(),
Film.FilmID.ASC(),
Category.CategoryID.ASC(),
Inventory.InventoryID.ASC(),
Rental.RentalID.ASC(),
)
var dest allInfo
err := stmt.Query(db, &dest)
require.NoError(t, err)
//testutils.SaveJSONFile(dest, "./testdata/results/mysql/dvds_join_everything.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/dvds_join_everything.json")
}
func BenchmarkTestDVDsJoinEverythingJSON(b *testing.B) {
for i := 0; i < b.N; i++ {
testDVDsJoinEverythingJSON(b)
}
}
func TestDVDsJoinEverythingJSON(t *testing.T) {
testDVDsJoinEverythingJSON(t)
}
func testDVDsJoinEverythingJSON(t require.TestingT) {
stmt := SELECT_JSON_ARR(
Actor.ActorID, Actor.FirstName, Actor.LastName, Actor.LastUpdate,
SELECT_JSON_ARR(
Film.AllColumns,
SELECT_JSON_OBJ(Language.AllColumns).
FROM(Language).
WHERE(Language.LanguageID.EQ(Film.LanguageID)).AS("Language"),
SELECT_JSON_ARR(Category.AllColumns).
FROM(Category.INNER_JOIN(FilmCategory, FilmCategory.CategoryID.EQ(Category.CategoryID))).
WHERE(FilmCategory.FilmID.EQ(Film.FilmID)).AS("Categories"),
SELECT_JSON_ARR(
Inventory.AllColumns,
SELECT_JSON_ARR(
Rental.AllColumns,
SELECT_JSON_OBJ(Customer.AllColumns).
FROM(Customer).
WHERE(Customer.CustomerID.EQ(Rental.CustomerID)).AS("Customer"),
).FROM(Rental).
WHERE(Rental.InventoryID.EQ(Inventory.InventoryID)).
ORDER_BY(Rental.RentalID).AS("Rentals"),
).FROM(Inventory).
WHERE(Inventory.FilmID.EQ(Film.FilmID)).
ORDER_BY(Inventory.InventoryID).AS("Inventories"),
).FROM(Film.
INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(Film.FilmID)),
).WHERE(FilmActor.ActorID.EQ(Actor.ActorID)).
ORDER_BY(Film.FilmID.ASC()).AS("Films"),
).FROM(Actor).
ORDER_BY(Actor.ActorID.ASC())
//fmt.Println(stmt.DebugSql())
var dest allInfo
err := stmt.QueryContext(ctx, db, &dest)
require.NoError(t, err)
//testutils.SaveJSONFile(dest, "./testdata/results/mysql/dvds_join_everything2.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/dvds_join_everything.json")
}

View file

@ -3,6 +3,7 @@ package mysql
import (
"os"
"os/exec"
"path/filepath"
"strconv"
"testing"
@ -304,7 +305,7 @@ func newLinkTableImpl(schemaName, tableName, alias string) linkTable {
DescriptionColumn = mysql.StringColumn("description")
allColumns = mysql.ColumnList{IDColumn, URLColumn, NameColumn, DescriptionColumn}
mutableColumns = mysql.ColumnList{URLColumn, NameColumn, DescriptionColumn}
defaultColumns = mysql.ColumnList{DescriptionColumn}
defaultColumns = mysql.ColumnList{}
)
return linkTable{
@ -606,3 +607,398 @@ func UseSchema(schema string) {
StaffList = StaffList.FromSchema(schema)
}
`
func TestGeneratedTestSampleDatabase(t *testing.T) {
enumDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/enum/")
modelDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/model/")
tableDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/table/")
viewDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/view/")
testutils.AssertFileNamesEqual(t, enumDir, "all_types_enum.go", "all_types_enum_ptr.go",
"all_types_view_enum.go", "all_types_view_enum_ptr.go")
testutils.AssertFileContent(t, enumDir+"/all_types_enum.go", allTypesEnum)
testutils.AssertFileNamesEqual(t, modelDir, "all_types.go", "all_types_enum.go", "all_types_enum_ptr.go",
"all_types_view.go", "all_types_view_enum.go", "all_types_view_enum_ptr.go", "link.go", "link2.go",
"floats.go", "user.go")
testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent)
testutils.AssertFileNamesEqual(t, tableDir, "all_types.go",
"link.go", "link2.go", "user.go", "floats.go", "table_use_schema.go")
testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent)
testutils.AssertFileNamesEqual(t, viewDir, "all_types_view.go", "view_use_schema.go")
}
var allTypesEnum = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package enum
import "github.com/go-jet/jet/v2/mysql"
var AllTypesEnum = &struct {
Value1 mysql.StringExpression
Value2 mysql.StringExpression
Value3 mysql.StringExpression
}{
Value1: mysql.NewEnumValue("value1"),
Value2: mysql.NewEnumValue("value2"),
Value3: mysql.NewEnumValue("value3"),
}
`
var allTypesModelContent = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
import (
"time"
)
type AllTypes struct {
ID int32 ` + "`" + `sql:"primary_key"` + "`" + `
Boolean bool
BooleanPtr *bool
TinyInt int8
UTinyInt uint8
SmallInt int16
USmallInt uint16
MediumInt int32
UMediumInt uint32
Integer int32
UInteger uint32
BigInt int64
UBigInt uint64
TinyIntPtr *int8
UTinyIntPtr *uint8
SmallIntPtr *int16
USmallIntPtr *uint16
MediumIntPtr *int32
UMediumIntPtr *uint32
IntegerPtr *int32
UIntegerPtr *uint32
BigIntPtr *int64
UBigIntPtr *uint64
Decimal float64
DecimalPtr *float64
Numeric float64
NumericPtr *float64
Float float64
FloatPtr *float64
Double float64
DoublePtr *float64
Real float64
RealPtr *float64
Bit string
BitPtr *string
Time time.Time
TimePtr *time.Time
Date time.Time
DatePtr *time.Time
DateTime time.Time
DateTimePtr *time.Time
Timestamp time.Time
TimestampPtr *time.Time
Year int16
YearPtr *int16
Char string
CharPtr *string
VarChar string
VarCharPtr *string
Binary []byte
BinaryPtr *[]byte
VarBinary []byte
VarBinaryPtr *[]byte
Blob []byte
BlobPtr *[]byte
Text string
TextPtr *string
Enum AllTypesEnum
EnumPtr *AllTypesEnumPtr
Set string
SetPtr *string
JSON string
JSONPtr *string
}
`
var allTypesTableContent = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/mysql"
)
var AllTypes = newAllTypesTable("test_sample", "all_types", "")
type allTypesTable struct {
mysql.Table
// Columns
ID mysql.ColumnInteger
Boolean mysql.ColumnBool
BooleanPtr mysql.ColumnBool
TinyInt mysql.ColumnInteger
UTinyInt mysql.ColumnInteger
SmallInt mysql.ColumnInteger
USmallInt mysql.ColumnInteger
MediumInt mysql.ColumnInteger
UMediumInt mysql.ColumnInteger
Integer mysql.ColumnInteger
UInteger mysql.ColumnInteger
BigInt mysql.ColumnInteger
UBigInt mysql.ColumnInteger
TinyIntPtr mysql.ColumnInteger
UTinyIntPtr mysql.ColumnInteger
SmallIntPtr mysql.ColumnInteger
USmallIntPtr mysql.ColumnInteger
MediumIntPtr mysql.ColumnInteger
UMediumIntPtr mysql.ColumnInteger
IntegerPtr mysql.ColumnInteger
UIntegerPtr mysql.ColumnInteger
BigIntPtr mysql.ColumnInteger
UBigIntPtr mysql.ColumnInteger
Decimal mysql.ColumnFloat
DecimalPtr mysql.ColumnFloat
Numeric mysql.ColumnFloat
NumericPtr mysql.ColumnFloat
Float mysql.ColumnFloat
FloatPtr mysql.ColumnFloat
Double mysql.ColumnFloat
DoublePtr mysql.ColumnFloat
Real mysql.ColumnFloat
RealPtr mysql.ColumnFloat
Bit mysql.ColumnString
BitPtr mysql.ColumnString
Time mysql.ColumnTime
TimePtr mysql.ColumnTime
Date mysql.ColumnDate
DatePtr mysql.ColumnDate
DateTime mysql.ColumnTimestamp
DateTimePtr mysql.ColumnTimestamp
Timestamp mysql.ColumnTimestamp
TimestampPtr mysql.ColumnTimestamp
Year mysql.ColumnInteger
YearPtr mysql.ColumnInteger
Char mysql.ColumnString
CharPtr mysql.ColumnString
VarChar mysql.ColumnString
VarCharPtr mysql.ColumnString
Binary mysql.ColumnBlob
BinaryPtr mysql.ColumnBlob
VarBinary mysql.ColumnBlob
VarBinaryPtr mysql.ColumnBlob
Blob mysql.ColumnBlob
BlobPtr mysql.ColumnBlob
Text mysql.ColumnString
TextPtr mysql.ColumnString
Enum mysql.ColumnString
EnumPtr mysql.ColumnString
Set mysql.ColumnString
SetPtr mysql.ColumnString
JSON mysql.ColumnString
JSONPtr mysql.ColumnString
AllColumns mysql.ColumnList
MutableColumns mysql.ColumnList
DefaultColumns mysql.ColumnList
}
type AllTypesTable struct {
allTypesTable
NEW allTypesTable
}
// AS creates new AllTypesTable with assigned alias
func (a AllTypesTable) AS(alias string) *AllTypesTable {
return newAllTypesTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new AllTypesTable with assigned schema name
func (a AllTypesTable) FromSchema(schemaName string) *AllTypesTable {
return newAllTypesTable(schemaName, a.TableName(), a.Alias())
}
// WithPrefix creates new AllTypesTable with assigned table prefix
func (a AllTypesTable) WithPrefix(prefix string) *AllTypesTable {
return newAllTypesTable(a.SchemaName(), prefix+a.TableName(), a.TableName())
}
// WithSuffix creates new AllTypesTable with assigned table suffix
func (a AllTypesTable) WithSuffix(suffix string) *AllTypesTable {
return newAllTypesTable(a.SchemaName(), a.TableName()+suffix, a.TableName())
}
func newAllTypesTable(schemaName, tableName, alias string) *AllTypesTable {
return &AllTypesTable{
allTypesTable: newAllTypesTableImpl(schemaName, tableName, alias),
NEW: newAllTypesTableImpl("", "new", ""),
}
}
func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable {
var (
IDColumn = mysql.IntegerColumn("id")
BooleanColumn = mysql.BoolColumn("boolean")
BooleanPtrColumn = mysql.BoolColumn("boolean_ptr")
TinyIntColumn = mysql.IntegerColumn("tiny_int")
UTinyIntColumn = mysql.IntegerColumn("u_tiny_int")
SmallIntColumn = mysql.IntegerColumn("small_int")
USmallIntColumn = mysql.IntegerColumn("u_small_int")
MediumIntColumn = mysql.IntegerColumn("medium_int")
UMediumIntColumn = mysql.IntegerColumn("u_medium_int")
IntegerColumn = mysql.IntegerColumn("integer")
UIntegerColumn = mysql.IntegerColumn("u_integer")
BigIntColumn = mysql.IntegerColumn("big_int")
UBigIntColumn = mysql.IntegerColumn("u_big_int")
TinyIntPtrColumn = mysql.IntegerColumn("tiny_int_ptr")
UTinyIntPtrColumn = mysql.IntegerColumn("u_tiny_int_ptr")
SmallIntPtrColumn = mysql.IntegerColumn("small_int_ptr")
USmallIntPtrColumn = mysql.IntegerColumn("u_small_int_ptr")
MediumIntPtrColumn = mysql.IntegerColumn("medium_int_ptr")
UMediumIntPtrColumn = mysql.IntegerColumn("u_medium_int_ptr")
IntegerPtrColumn = mysql.IntegerColumn("integer_ptr")
UIntegerPtrColumn = mysql.IntegerColumn("u_integer_ptr")
BigIntPtrColumn = mysql.IntegerColumn("big_int_ptr")
UBigIntPtrColumn = mysql.IntegerColumn("u_big_int_ptr")
DecimalColumn = mysql.FloatColumn("decimal")
DecimalPtrColumn = mysql.FloatColumn("decimal_ptr")
NumericColumn = mysql.FloatColumn("numeric")
NumericPtrColumn = mysql.FloatColumn("numeric_ptr")
FloatColumn = mysql.FloatColumn("float")
FloatPtrColumn = mysql.FloatColumn("float_ptr")
DoubleColumn = mysql.FloatColumn("double")
DoublePtrColumn = mysql.FloatColumn("double_ptr")
RealColumn = mysql.FloatColumn("real")
RealPtrColumn = mysql.FloatColumn("real_ptr")
BitColumn = mysql.StringColumn("bit")
BitPtrColumn = mysql.StringColumn("bit_ptr")
TimeColumn = mysql.TimeColumn("time")
TimePtrColumn = mysql.TimeColumn("time_ptr")
DateColumn = mysql.DateColumn("date")
DatePtrColumn = mysql.DateColumn("date_ptr")
DateTimeColumn = mysql.TimestampColumn("date_time")
DateTimePtrColumn = mysql.TimestampColumn("date_time_ptr")
TimestampColumn = mysql.TimestampColumn("timestamp")
TimestampPtrColumn = mysql.TimestampColumn("timestamp_ptr")
YearColumn = mysql.IntegerColumn("year")
YearPtrColumn = mysql.IntegerColumn("year_ptr")
CharColumn = mysql.StringColumn("char")
CharPtrColumn = mysql.StringColumn("char_ptr")
VarCharColumn = mysql.StringColumn("var_char")
VarCharPtrColumn = mysql.StringColumn("var_char_ptr")
BinaryColumn = mysql.BlobColumn("binary")
BinaryPtrColumn = mysql.BlobColumn("binary_ptr")
VarBinaryColumn = mysql.BlobColumn("var_binary")
VarBinaryPtrColumn = mysql.BlobColumn("var_binary_ptr")
BlobColumn = mysql.BlobColumn("blob")
BlobPtrColumn = mysql.BlobColumn("blob_ptr")
TextColumn = mysql.StringColumn("text")
TextPtrColumn = mysql.StringColumn("text_ptr")
EnumColumn = mysql.StringColumn("enum")
EnumPtrColumn = mysql.StringColumn("enum_ptr")
SetColumn = mysql.StringColumn("set")
SetPtrColumn = mysql.StringColumn("set_ptr")
JSONColumn = mysql.StringColumn("json")
JSONPtrColumn = mysql.StringColumn("json_ptr")
allColumns = mysql.ColumnList{IDColumn, BooleanColumn, BooleanPtrColumn, TinyIntColumn, UTinyIntColumn, SmallIntColumn, USmallIntColumn, MediumIntColumn, UMediumIntColumn, IntegerColumn, UIntegerColumn, BigIntColumn, UBigIntColumn, TinyIntPtrColumn, UTinyIntPtrColumn, SmallIntPtrColumn, USmallIntPtrColumn, MediumIntPtrColumn, UMediumIntPtrColumn, IntegerPtrColumn, UIntegerPtrColumn, BigIntPtrColumn, UBigIntPtrColumn, DecimalColumn, DecimalPtrColumn, NumericColumn, NumericPtrColumn, FloatColumn, FloatPtrColumn, DoubleColumn, DoublePtrColumn, RealColumn, RealPtrColumn, BitColumn, BitPtrColumn, TimeColumn, TimePtrColumn, DateColumn, DatePtrColumn, DateTimeColumn, DateTimePtrColumn, TimestampColumn, TimestampPtrColumn, YearColumn, YearPtrColumn, CharColumn, CharPtrColumn, VarCharColumn, VarCharPtrColumn, BinaryColumn, BinaryPtrColumn, VarBinaryColumn, VarBinaryPtrColumn, BlobColumn, BlobPtrColumn, TextColumn, TextPtrColumn, EnumColumn, EnumPtrColumn, SetColumn, SetPtrColumn, JSONColumn, JSONPtrColumn}
mutableColumns = mysql.ColumnList{BooleanColumn, BooleanPtrColumn, TinyIntColumn, UTinyIntColumn, SmallIntColumn, USmallIntColumn, MediumIntColumn, UMediumIntColumn, IntegerColumn, UIntegerColumn, BigIntColumn, UBigIntColumn, TinyIntPtrColumn, UTinyIntPtrColumn, SmallIntPtrColumn, USmallIntPtrColumn, MediumIntPtrColumn, UMediumIntPtrColumn, IntegerPtrColumn, UIntegerPtrColumn, BigIntPtrColumn, UBigIntPtrColumn, DecimalColumn, DecimalPtrColumn, NumericColumn, NumericPtrColumn, FloatColumn, FloatPtrColumn, DoubleColumn, DoublePtrColumn, RealColumn, RealPtrColumn, BitColumn, BitPtrColumn, TimeColumn, TimePtrColumn, DateColumn, DatePtrColumn, DateTimeColumn, DateTimePtrColumn, TimestampColumn, TimestampPtrColumn, YearColumn, YearPtrColumn, CharColumn, CharPtrColumn, VarCharColumn, VarCharPtrColumn, BinaryColumn, BinaryPtrColumn, VarBinaryColumn, VarBinaryPtrColumn, BlobColumn, BlobPtrColumn, TextColumn, TextPtrColumn, EnumColumn, EnumPtrColumn, SetColumn, SetPtrColumn, JSONColumn, JSONPtrColumn}
defaultColumns = mysql.ColumnList{BooleanColumn, TinyIntColumn, UTinyIntColumn, SmallIntColumn, USmallIntColumn, MediumIntColumn, UMediumIntColumn, IntegerColumn, UIntegerColumn, BigIntColumn, UBigIntColumn, DecimalColumn, NumericColumn, FloatColumn, DoubleColumn, RealColumn, BitColumn, TimeColumn, DateColumn, DateTimeColumn, TimestampColumn, YearColumn, CharColumn, VarCharColumn, BinaryColumn, VarBinaryColumn, EnumColumn, SetColumn}
)
return allTypesTable{
Table: mysql.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
ID: IDColumn,
Boolean: BooleanColumn,
BooleanPtr: BooleanPtrColumn,
TinyInt: TinyIntColumn,
UTinyInt: UTinyIntColumn,
SmallInt: SmallIntColumn,
USmallInt: USmallIntColumn,
MediumInt: MediumIntColumn,
UMediumInt: UMediumIntColumn,
Integer: IntegerColumn,
UInteger: UIntegerColumn,
BigInt: BigIntColumn,
UBigInt: UBigIntColumn,
TinyIntPtr: TinyIntPtrColumn,
UTinyIntPtr: UTinyIntPtrColumn,
SmallIntPtr: SmallIntPtrColumn,
USmallIntPtr: USmallIntPtrColumn,
MediumIntPtr: MediumIntPtrColumn,
UMediumIntPtr: UMediumIntPtrColumn,
IntegerPtr: IntegerPtrColumn,
UIntegerPtr: UIntegerPtrColumn,
BigIntPtr: BigIntPtrColumn,
UBigIntPtr: UBigIntPtrColumn,
Decimal: DecimalColumn,
DecimalPtr: DecimalPtrColumn,
Numeric: NumericColumn,
NumericPtr: NumericPtrColumn,
Float: FloatColumn,
FloatPtr: FloatPtrColumn,
Double: DoubleColumn,
DoublePtr: DoublePtrColumn,
Real: RealColumn,
RealPtr: RealPtrColumn,
Bit: BitColumn,
BitPtr: BitPtrColumn,
Time: TimeColumn,
TimePtr: TimePtrColumn,
Date: DateColumn,
DatePtr: DatePtrColumn,
DateTime: DateTimeColumn,
DateTimePtr: DateTimePtrColumn,
Timestamp: TimestampColumn,
TimestampPtr: TimestampPtrColumn,
Year: YearColumn,
YearPtr: YearPtrColumn,
Char: CharColumn,
CharPtr: CharPtrColumn,
VarChar: VarCharColumn,
VarCharPtr: VarCharPtrColumn,
Binary: BinaryColumn,
BinaryPtr: BinaryPtrColumn,
VarBinary: VarBinaryColumn,
VarBinaryPtr: VarBinaryPtrColumn,
Blob: BlobColumn,
BlobPtr: BlobPtrColumn,
Text: TextColumn,
TextPtr: TextPtrColumn,
Enum: EnumColumn,
EnumPtr: EnumPtrColumn,
Set: SetColumn,
SetPtr: SetPtrColumn,
JSON: JSONColumn,
JSONPtr: JSONPtrColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
DefaultColumns: defaultColumns,
}
}
`

View file

@ -8,6 +8,7 @@ import (
jetmysql "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/stmtcache"
"github.com/go-jet/jet/v2/tests/dbconfig"
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
_ "github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/require"
"runtime"
@ -21,12 +22,14 @@ var db *stmtcache.DB
var source string
var withStatementCaching bool
var testRoot string
const MariaDB = "MariaDB"
func init() {
source = os.Getenv("MY_SQL_SOURCE")
withStatementCaching = os.Getenv("JET_TESTS_WITH_STMT_CACHE") == "true"
testRoot = repo.GetTestsDirPath()
}
func sourceIsMariaDB() bool {

View file

@ -0,0 +1,453 @@
package mysql
import (
"context"
"fmt"
"github.com/go-jet/jet/v2/qrm"
"strings"
"testing"
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table"
"github.com/stretchr/testify/require"
)
var ctx = context.Background()
func TestSelectJsonObj(t *testing.T) {
stmt := SELECT_JSON_OBJ(Actor.AllColumns).
FROM(Actor).
WHERE(Actor.ActorID.EQ(Int(2)))
testutils.AssertStatementSql(t, stmt, `
SELECT JSON_OBJECT(
'actorID', actor.actor_id,
'firstName', actor.first_name,
'lastName', actor.last_name,
'lastUpdate', DATE_FORMAT(actor.last_update,'%Y-%m-%dT%H:%i:%s.%fZ')
) AS "json"
FROM dvds.actor
WHERE actor.actor_id = ?;
`, int64(2))
var dest model.Actor
err := stmt.Query(db, &dest)
require.Nil(t, err)
testutils.AssertDeepEqual(t, dest, actor2)
requireLogged(t, stmt)
requireQueryLogged(t, stmt, 1)
}
func TestSelectJsonObj_NestedObj(t *testing.T) {
stmt := SELECT_JSON_OBJ(
Actor.AllColumns,
SELECT_JSON_OBJ(Film.AllColumns).
FROM(FilmActor.INNER_JOIN(Film, Film.FilmID.EQ(FilmActor.FilmID))).
WHERE(Actor.ActorID.EQ(FilmActor.ActorID)).
ORDER_BY(Film.Length.DESC()).
LIMIT(1).OFFSET(3).AS("LongestFilm"),
).FROM(
Actor,
).WHERE(
Actor.ActorID.EQ(Int(2)),
)
testutils.AssertStatementSql(t, stmt, `
SELECT JSON_OBJECT(
'actorID', actor.actor_id,
'firstName', actor.first_name,
'lastName', actor.last_name,
'lastUpdate', DATE_FORMAT(actor.last_update,'%Y-%m-%dT%H:%i:%s.%fZ'),
'LongestFilm', (
SELECT JSON_OBJECT(
'filmID', film.film_id,
'title', film.title,
'description', film.description,
'releaseYear', film.release_year,
'languageID', film.language_id,
'originalLanguageID', film.original_language_id,
'rentalDuration', film.rental_duration,
'rentalRate', film.rental_rate,
'length', film.length,
'replacementCost', film.replacement_cost,
'rating', film.rating,
'specialFeatures', film.special_features,
'lastUpdate', DATE_FORMAT(film.last_update,'%Y-%m-%dT%H:%i:%s.%fZ')
) AS "json"
FROM dvds.film_actor
INNER JOIN dvds.film ON (film.film_id = film_actor.film_id)
WHERE actor.actor_id = film_actor.actor_id
ORDER BY film.length DESC
LIMIT ?
OFFSET ?
)
) AS "json"
FROM dvds.actor
WHERE actor.actor_id = ?;
`)
var dest struct {
model.Actor
LongestFilm model.Film
}
err := stmt.QueryContext(ctx, db, &dest)
require.Nil(t, err)
testutils.AssertJSON(t, dest, `
{
"ActorID": 2,
"FirstName": "NICK",
"LastName": "WAHLBERG",
"LastUpdate": "2006-02-15T04:34:33Z",
"LongestFilm": {
"FilmID": 754,
"Title": "RUSHMORE MERMAID",
"Description": "A Boring Story of a Woman And a Moose who must Reach a Husband in A Shark Tank",
"ReleaseYear": 2006,
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 6,
"RentalRate": 2.99,
"Length": 150,
"ReplacementCost": 17.99,
"Rating": "PG-13",
"SpecialFeatures": "Trailers,Commentaries,Deleted Scenes",
"LastUpdate": "2006-02-15T05:03:42Z"
}
}
`)
}
func TestSelectJsonArr(t *testing.T) {
stmt := SELECT_JSON_ARR(Actor.AllColumns).
FROM(Actor).
ORDER_BY(Actor.ActorID)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT JSON_ARRAYAGG(JSON_OBJECT(
'actorID', actor.actor_id,
'firstName', actor.first_name,
'lastName', actor.last_name,
'lastUpdate', DATE_FORMAT(actor.last_update,'%Y-%m-%dT%H:%i:%s.%fZ')
)) AS "json"
FROM dvds.actor
ORDER BY actor.actor_id;
`)
var dest []model.Actor
err := stmt.Query(db, &dest)
require.Nil(t, err)
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/all_actors.json")
requireLogged(t, stmt)
requireQueryLogged(t, stmt, 1)
}
func TestSelectJsonArr_NestedArr(t *testing.T) {
stmt := SELECT_JSON_ARR(
Actor.AllColumns,
SELECT_JSON_ARR(
Film.AllColumns,
).FROM(
FilmActor.INNER_JOIN(
Film,
Film.FilmID.EQ(FilmActor.FilmID).AND(
Actor.ActorID.EQ(FilmActor.ActorID)),
),
).WHERE(
Film.FilmID.MOD(Int(17)).EQ(Int(0)),
).ORDER_BY(
Film.Length.DESC(),
).AS("Films"),
).FROM(
Actor,
).WHERE(
Actor.ActorID.BETWEEN(Int(1), Int(3)),
).ORDER_BY(
Actor.ActorID,
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT JSON_ARRAYAGG(JSON_OBJECT(
'actorID', actor.actor_id,
'firstName', actor.first_name,
'lastName', actor.last_name,
'lastUpdate', DATE_FORMAT(actor.last_update,'%Y-%m-%dT%H:%i:%s.%fZ'),
'Films', (
SELECT JSON_ARRAYAGG(JSON_OBJECT(
'filmID', film.film_id,
'title', film.title,
'description', film.description,
'releaseYear', film.release_year,
'languageID', film.language_id,
'originalLanguageID', film.original_language_id,
'rentalDuration', film.rental_duration,
'rentalRate', film.rental_rate,
'length', film.length,
'replacementCost', film.replacement_cost,
'rating', film.rating,
'specialFeatures', film.special_features,
'lastUpdate', DATE_FORMAT(film.last_update,'%Y-%m-%dT%H:%i:%s.%fZ')
)) AS "json"
FROM dvds.film_actor
INNER JOIN dvds.film ON ((film.film_id = film_actor.film_id) AND (actor.actor_id = film_actor.actor_id))
WHERE (film.film_id % 17) = 0
ORDER BY film.length DESC
)
)) AS "json"
FROM dvds.actor
WHERE actor.actor_id BETWEEN 1 AND 3
ORDER BY actor.actor_id;
`)
var dest []struct {
model.Actor
Films []model.Film
}
err := stmt.QueryContext(ctx, db, &dest)
fmt.Println(err)
require.Nil(t, err)
testutils.AssertJSON(t, dest, `
[
{
"ActorID": 1,
"FirstName": "PENELOPE",
"LastName": "GUINESS",
"LastUpdate": "2006-02-15T04:34:33Z",
"Films": null
},
{
"ActorID": 2,
"FirstName": "NICK",
"LastName": "WAHLBERG",
"LastUpdate": "2006-02-15T04:34:33Z",
"Films": [
{
"FilmID": 357,
"Title": "GILBERT PELICAN",
"Description": "A Fateful Tale of a Man And a Feminist who must Conquer a Crocodile in A Manhattan Penthouse",
"ReleaseYear": 2006,
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 7,
"RentalRate": 0.99,
"Length": 114,
"ReplacementCost": 13.99,
"Rating": "G",
"SpecialFeatures": "Trailers,Commentaries",
"LastUpdate": "2006-02-15T05:03:42Z"
},
{
"FilmID": 561,
"Title": "MASK PEACH",
"Description": "A Boring Character Study of a Student And a Robot who must Meet a Woman in California",
"ReleaseYear": 2006,
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 6,
"RentalRate": 2.99,
"Length": 123,
"ReplacementCost": 26.99,
"Rating": "NC-17",
"SpecialFeatures": "Commentaries,Deleted Scenes",
"LastUpdate": "2006-02-15T05:03:42Z"
}
]
},
{
"ActorID": 3,
"FirstName": "ED",
"LastName": "CHASE",
"LastUpdate": "2006-02-15T04:34:33Z",
"Films": [
{
"FilmID": 17,
"Title": "ALONE TRIP",
"Description": "A Fast-Paced Character Study of a Composer And a Dog who must Outgun a Boat in An Abandoned Fun House",
"ReleaseYear": 2006,
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 3,
"RentalRate": 0.99,
"Length": 82,
"ReplacementCost": 14.99,
"Rating": "R",
"SpecialFeatures": "Trailers,Behind the Scenes",
"LastUpdate": "2006-02-15T05:03:42Z"
},
{
"FilmID": 289,
"Title": "EVE RESURRECTION",
"Description": "A Awe-Inspiring Yarn of a Pastry Chef And a Database Administrator who must Challenge a Teacher in A Baloon",
"ReleaseYear": 2006,
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 5,
"RentalRate": 4.99,
"Length": 66,
"ReplacementCost": 25.99,
"Rating": "G",
"SpecialFeatures": "Trailers,Commentaries,Deleted Scenes",
"LastUpdate": "2006-02-15T05:03:42Z"
}
]
}
]
`)
}
func TestSelectJson_GroupBy(t *testing.T) {
skipForMariaDB(t) // scope issues with select without FROM
subQuery := SELECT(
Customer.AllColumns,
SUM(Payment.Amount).AS("sum"),
AVG(Payment.Amount).AS("avg"),
MAX(Payment.Amount).AS("max"),
MIN(Payment.Amount).AS("min"),
COUNT(Payment.Amount).AS("count"),
).FROM(
Payment.
INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)),
).GROUP_BY(
Customer.CustomerID,
).HAVING(
SUMf(Payment.Amount).GT(Float(125)),
).ORDER_BY(
Customer.CustomerID, SUM(Payment.Amount).ASC(),
).AsTable("customers_info")
stmt := SELECT_JSON_ARR(
subQuery.AllColumns().Except( // TODO: remove when ColumnList.From() is implemented
FloatColumn("sum"),
FloatColumn("avg"),
FloatColumn("max"),
FloatColumn("min"),
FloatColumn("count"),
),
SELECT_JSON_OBJ(
FloatColumn("sum").From(subQuery),
FloatColumn("avg").From(subQuery),
FloatColumn("max").From(subQuery),
FloatColumn("min").From(subQuery),
FloatColumn("count").From(subQuery),
).AS("amount"),
).FROM(subQuery)
testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(`
SELECT JSON_ARRAYAGG(JSON_OBJECT(
'customerID', customers_info.''customer.customer_id'',
'storeID', customers_info.''customer.store_id'',
'firstName', customers_info.''customer.first_name'',
'lastName', customers_info.''customer.last_name'',
'email', customers_info.''customer.email'',
'addressID', customers_info.''customer.address_id'',
'active', customers_info.''customer.active'' = 1,
'createDate', DATE_FORMAT(customers_info.''customer.create_date'','%Y-%m-%dT%H:%i:%s.%fZ'),
'lastUpdate', DATE_FORMAT(customers_info.''customer.last_update'','%Y-%m-%dT%H:%i:%s.%fZ'),
'amount', (
SELECT JSON_OBJECT(
'sum', customers_info.sum,
'avg', customers_info.avg,
'max', customers_info.max,
'min', customers_info.min,
'count', customers_info.count
) AS "json"
)
)) AS "json"
FROM (
SELECT customer.customer_id AS "customer.customer_id",
customer.store_id AS "customer.store_id",
customer.first_name AS "customer.first_name",
customer.last_name AS "customer.last_name",
customer.email AS "customer.email",
customer.address_id AS "customer.address_id",
customer.active AS "customer.active",
customer.create_date AS "customer.create_date",
customer.last_update AS "customer.last_update",
SUM(payment.amount) AS "sum",
AVG(payment.amount) AS "avg",
MAX(payment.amount) AS "max",
MIN(payment.amount) AS "min",
COUNT(payment.amount) AS "count"
FROM dvds.payment
INNER JOIN dvds.customer ON (customer.customer_id = payment.customer_id)
GROUP BY customer.customer_id
HAVING SUM(payment.amount) > 125
ORDER BY customer.customer_id, SUM(payment.amount) ASC
) AS customers_info;
`, "''", "`"))
var dest []struct {
model.Customer
Amount struct {
Sum float64
Avg float64
Max float64
Min float64
Count int64
}
}
err := stmt.QueryContext(ctx, db, &dest)
require.Nil(t, err)
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/customer_payment_sum.json")
requireLogged(t, stmt)
}
func TestSelectJsonObject_EmptyResult(t *testing.T) {
t.Run("json obj", func(t *testing.T) {
stmt := SELECT_JSON_OBJ(Actor.AllColumns).
FROM(Actor).
WHERE(Actor.FirstName.EQ(String("Kowalski")))
var dest model.Actor
err := stmt.QueryContext(ctx, db, &dest)
require.ErrorIs(t, err, qrm.ErrNoRows)
})
t.Run("json arr", func(t *testing.T) {
stmt := SELECT_JSON_ARR(Actor.AllColumns).
FROM(Actor).
WHERE(Actor.FirstName.EQ(String("Kowalski")))
var dest []model.Actor
err := stmt.QueryContext(ctx, db, &dest)
require.NoError(t, err)
require.Empty(t, dest)
})
}
func TestSelectJson_ProjectionNotAliased(t *testing.T) {
t.Run("expression not aliased", func(t *testing.T) {
testutils.AssertPanicErr(t, func() {
stmt := SELECT_JSON_ARR(
Int(2).ADD(Customer.CustomerID),
).FROM(Customer)
stmt.DebugSql()
}, "jet: expression need to be aliased when used as SELECT JSON projection.")
})
}

View file

@ -19,9 +19,9 @@ import (
)
func TestSelect_ScanToStruct(t *testing.T) {
query := Actor.
SELECT(Actor.AllColumns).
query := SELECT(Actor.AllColumns).
DISTINCT().
FROM(Actor).
WHERE(Actor.ActorID.EQ(Int(2)))
testutils.AssertStatementSql(t, query, `
@ -50,9 +50,56 @@ var actor2 = model.Actor{
LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 04:34:33", 2),
}
func TestSelect_NestedObject(t *testing.T) {
stmt := SELECT(
Actor.AllColumns,
Film.AllColumns,
).FROM(
Actor.
LEFT_JOIN(FilmActor, FilmActor.ActorID.EQ(Actor.ActorID)).
LEFT_JOIN(Film, Film.FilmID.EQ(FilmActor.FilmID)),
).WHERE(
Actor.ActorID.EQ(Int(2)),
).ORDER_BY(
Film.LastUpdate.DESC(),
).LIMIT(1)
var dest struct {
model.Actor
LatestFilm model.Film
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
{
"ActorID": 2,
"FirstName": "NICK",
"LastName": "WAHLBERG",
"LastUpdate": "2006-02-15T04:34:33Z",
"LatestFilm": {
"FilmID": 3,
"Title": "ADAPTATION HOLES",
"Description": "A Astounding Reflection of a Lumberjack And a Car who must Sink a Lumberjack in A Baloon Factory",
"ReleaseYear": 2006,
"LanguageID": 1,
"OriginalLanguageID": null,
"RentalDuration": 7,
"RentalRate": 2.99,
"Length": 50,
"ReplacementCost": 18.99,
"Rating": "NC-17",
"SpecialFeatures": "Trailers,Deleted Scenes",
"LastUpdate": "2006-02-15T05:03:42Z"
}
}
`)
}
func TestSelect_ScanToSlice(t *testing.T) {
query := Actor.
SELECT(Actor.AllColumns).
query := SELECT(Actor.AllColumns).
FROM(Actor).
ORDER_BY(Actor.ActorID)
testutils.AssertStatementSql(t, query, `
@ -107,19 +154,20 @@ GROUP BY payment.customer_id
HAVING SUM(payment.amount) > 125.6
ORDER BY payment.customer_id, SUM(payment.amount) ASC;
`
query := Payment.
INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)).
SELECT(
Customer.AllColumns,
query := SELECT(
Customer.AllColumns,
SUMf(Payment.Amount).AS("amount.sum"),
AVG(Payment.Amount).AS("amount.avg"),
MAX(Payment.PaymentDate).AS("amount.max_date"),
MAXf(Payment.Amount).AS("amount.max"),
MIN(Payment.PaymentDate).AS("amount.min_date"),
MINf(Payment.Amount).AS("amount.min"),
COUNT(Payment.Amount).AS("amount.count"),
).
SUMf(Payment.Amount).AS("amount.sum"),
AVG(Payment.Amount).AS("amount.avg"),
MAX(Payment.PaymentDate).AS("amount.max_date"),
MAXf(Payment.Amount).AS("amount.max"),
MIN(Payment.PaymentDate).AS("amount.min_date"),
MINf(Payment.Amount).AS("amount.min"),
COUNT(Payment.Amount).AS("amount.count"),
).FROM(
Payment.
INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)),
).
GROUP_BY(Payment.CustomerID).
HAVING(
SUMf(Payment.Amount).GT(Float(125.6)),
@ -1122,7 +1170,7 @@ WHERE payment.payment_id < ?
WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id)
ORDER BY payment.customer_id;
`
query := Payment.SELECT(
query := SELECT(
AVG(Payment.Amount).OVER(),
AVG(Payment.Amount).OVER(Window("w1")),
AVG(Payment.Amount).OVER(
@ -1131,7 +1179,7 @@ ORDER BY payment.customer_id;
RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)),
),
AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))),
).
).FROM(Payment).
WHERE(Payment.PaymentID.LT(Int(10))).
WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)).
WINDOW("w2").AS(Window("w1")).

View file

@ -1,6 +1,8 @@
package postgres
import (
"encoding/base64"
"fmt"
"github.com/go-jet/jet/v2/internal/utils/ptr"
"github.com/stretchr/testify/assert"
"math"
@ -36,6 +38,141 @@ func TestAllTypesSelect(t *testing.T) {
testutils.AssertDeepEqual(t, dest[1], allTypesRow1)
}
func TestAllTypesSelectJson(t *testing.T) {
stmt := SELECT_JSON_ARR(
AllTypesAllColumns.Except(
AllTypes.JSON, AllTypes.JSONPtr,
AllTypes.Jsonb, AllTypes.JsonbPtr,
AllTypes.TextArray, AllTypes.TextArrayPtr,
AllTypes.JsonbArray, AllTypes.IntegerArray, AllTypes.IntegerArrayPtr,
AllTypes.TextMultiDimArray, AllTypes.TextMultiDimArrayPtr,
),
// unsupported at the moment, casting to text allows these columns to be assigned to string fields
CAST(AllTypes.JSONPtr).AS_TEXT().AS("jsonPtr"),
CAST(AllTypes.JSON).AS_TEXT().AS("JSON"),
CAST(AllTypes.JsonbPtr).AS_TEXT().AS("jsonbPtr"),
CAST(AllTypes.Jsonb).AS_TEXT().AS("Jsonb"),
CAST(AllTypes.TextArrayPtr).AS_TEXT().AS("TextArrayPtr"),
CAST(AllTypes.TextArray).AS_TEXT().AS("TextArray"),
CAST(AllTypes.JsonbArray).AS_TEXT().AS("JsonbArray"),
CAST(AllTypes.IntegerArray).AS_TEXT().AS("IntegerArray"),
CAST(AllTypes.IntegerArrayPtr).AS_TEXT().AS("IntegerArrayPtr"),
CAST(AllTypes.TextMultiDimArray).AS_TEXT().AS("TextMultiDimArray"),
CAST(AllTypes.TextMultiDimArrayPtr).AS_TEXT().AS("TextMultiDimArrayPtr"),
).FROM(AllTypes)
testutils.AssertStatementSql(t, stmt, `
SELECT json_agg(row_to_json(records)) AS "json"
FROM (
SELECT all_types.small_int_ptr AS "smallIntPtr",
all_types.small_int AS "smallInt",
all_types.integer_ptr AS "integerPtr",
all_types.integer AS "integer",
all_types.big_int_ptr AS "bigIntPtr",
all_types.big_int AS "bigInt",
all_types.decimal_ptr AS "decimalPtr",
all_types.decimal AS "decimal",
all_types.numeric_ptr AS "numericPtr",
all_types.numeric AS "numeric",
all_types.real_ptr AS "realPtr",
all_types.real AS "real",
all_types.double_precision_ptr AS "doublePrecisionPtr",
all_types.double_precision AS "doublePrecision",
all_types.smallserial AS "smallserial",
all_types.serial AS "serial",
all_types.bigserial AS "bigserial",
all_types.var_char_ptr AS "varCharPtr",
all_types.var_char AS "varChar",
all_types.char_ptr AS "charPtr",
all_types.char AS "char",
all_types.text_ptr AS "textPtr",
all_types.text AS "text",
ENCODE(all_types.bytea_ptr, 'base64') AS "byteaPtr",
ENCODE(all_types.bytea, 'base64') AS "bytea",
all_types.timestampz_ptr AS "timestampzPtr",
all_types.timestampz AS "timestampz",
to_char(all_types.timestamp_ptr, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestampPtr",
to_char(all_types.timestamp, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestamp",
to_char(all_types.date_ptr::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "datePtr",
to_char(all_types.date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "date",
'0000-01-01T' || to_char('2000-10-10'::date + all_types.timez_ptr, 'HH24:MI:SS.USTZH:TZM') AS "timezPtr",
'0000-01-01T' || to_char('2000-10-10'::date + all_types.timez, 'HH24:MI:SS.USTZH:TZM') AS "timez",
'0000-01-01T' || to_char('2000-10-10'::date + all_types.time_ptr, 'HH24:MI:SS.USZ') AS "timePtr",
'0000-01-01T' || to_char('2000-10-10'::date + all_types.time, 'HH24:MI:SS.USZ') AS "time",
all_types.interval_ptr AS "intervalPtr",
all_types.interval AS "interval",
all_types.boolean_ptr AS "booleanPtr",
all_types.boolean AS "boolean",
all_types.point_ptr AS "pointPtr",
all_types.bit_ptr AS "bitPtr",
all_types.bit AS "bit",
all_types.bit_varying_ptr AS "bitVaryingPtr",
all_types.bit_varying AS "bitVarying",
all_types.tsvector_ptr AS "tsvectorPtr",
all_types.tsvector AS "tsvector",
all_types.uuid_ptr AS "uuidPtr",
all_types.uuid AS "uuid",
all_types.xml_ptr AS "xmlPtr",
all_types.xml AS "xml",
all_types.mood_ptr AS "moodPtr",
all_types.mood AS "mood",
all_types.json_ptr::text AS "jsonPtr",
all_types.json::text AS "JSON",
all_types.jsonb_ptr::text AS "jsonbPtr",
all_types.jsonb::text AS "Jsonb",
all_types.text_array_ptr::text AS "TextArrayPtr",
all_types.text_array::text AS "TextArray",
all_types.jsonb_array::text AS "JsonbArray",
all_types.integer_array::text AS "IntegerArray",
all_types.integer_array_ptr::text AS "IntegerArrayPtr",
all_types.text_multi_dim_array::text AS "TextMultiDimArray",
all_types.text_multi_dim_array_ptr::text AS "TextMultiDimArrayPtr"
FROM test_sample.all_types
) AS records;
`)
var dest []model.AllTypes
err := stmt.QueryContext(ctx, db, &dest)
require.NoError(t, err)
// fix inconsistencies between postgres and cockroachdb.
// cockroachdb returns char[N] columns with trailing whitespaces trimmed
if sourceIsCockroachDB() {
dest[0].Char = allTypesRow0.Char
dest[0].CharPtr = allTypesRow0.CharPtr
dest[1].Char = allTypesRow1.Char
dest[1].CharPtr = allTypesRow1.CharPtr
}
minus8 := time.FixedZone("UTC", -8*60*60)
plus1 := time.FixedZone("UTC", 60*60)
// set time local before comparison
dest[0].Timez = *toTZ(&dest[0].Timez, minus8)
dest[0].TimezPtr = toTZ(dest[0].TimezPtr, minus8)
dest[1].Timez = *toTZ(&dest[1].Timez, minus8)
dest[1].TimezPtr = toTZ(dest[1].TimezPtr, minus8)
dest[0].Timestampz = *toTZ(&dest[0].Timestampz, plus1)
dest[0].TimestampzPtr = toTZ(dest[0].TimestampzPtr, plus1)
dest[1].Timestampz = *toTZ(&dest[1].Timestampz, plus1)
dest[1].TimestampzPtr = toTZ(dest[1].TimestampzPtr, plus1)
testutils.AssertJsonEqual(t, dest[0], allTypesRow0)
testutils.AssertJsonEqual(t, dest[1], allTypesRow1)
}
func toTZ(tm *time.Time, loc *time.Location) *time.Time {
if tm == nil {
return nil
}
return ptr.Of(tm.In(loc))
}
func TestAllTypesViewSelect(t *testing.T) {
type AllTypesView model.AllTypes
var dest []AllTypesView
@ -132,7 +269,7 @@ WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11';
requireLogged(t, query)
}
func TestBytea(t *testing.T) {
func TestByteaInsert(t *testing.T) {
byteArrHex := "\\x48656c6c6f20476f7068657221"
byteArrBin := []byte("\x48\x65\x6c\x6c\x6f\x20\x47\x6f\x70\x68\x65\x72\x21")
@ -147,40 +284,42 @@ RETURNING all_types.bytea AS "all_types.bytea",
all_types.bytea_ptr AS "all_types.bytea_ptr";
`, byteArrHex, byteArrBin)
var inserted model.AllTypes
err := insertStmt.Query(db, &inserted)
require.NoError(t, err)
testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) {
var inserted model.AllTypes
err := insertStmt.Query(tx, &inserted)
require.NoError(t, err)
require.Equal(t, string(*inserted.ByteaPtr), "Hello Gopher!")
// It is not possible to initiate bytea column using hex format '\xDEADBEEF' with pq driver.
// pq driver always encodes parameter string if destination column is of type bytea.
// Probably pq driver error.
// require.Equal(t, string(inserted.Bytea), "Hello Gopher!")
require.Equal(t, string(*inserted.ByteaPtr), "Hello Gopher!")
// It is not possible to initiate bytea column using hex format '\xDEADBEEF' with pq driver.
// pq driver always encodes parameter string if destination column is of type bytea.
// Probably pq driver error.
// require.Equal(t, string(inserted.Bytea), "Hello Gopher!")
stmt := SELECT(
AllTypes.Bytea,
AllTypes.ByteaPtr,
).FROM(
AllTypes,
).WHERE(
AllTypes.ByteaPtr.EQ(Bytea(byteArrBin)),
)
stmt := SELECT(
AllTypes.Bytea,
AllTypes.ByteaPtr,
).FROM(
AllTypes,
).WHERE(
AllTypes.ByteaPtr.EQ(Bytea(byteArrBin)),
)
testutils.AssertStatementSql(t, stmt, `
testutils.AssertStatementSql(t, stmt, `
SELECT all_types.bytea AS "all_types.bytea",
all_types.bytea_ptr AS "all_types.bytea_ptr"
FROM test_sample.all_types
WHERE all_types.bytea_ptr = $1::bytea;
`, byteArrBin)
var dest model.AllTypes
var dest model.AllTypes
err = stmt.Query(db, &dest)
require.NoError(t, err)
err = stmt.Query(tx, &dest)
require.NoError(t, err)
require.Equal(t, string(*dest.ByteaPtr), "Hello Gopher!")
// Probably pq driver error.
// require.Equal(t, string(dest.Bytea), "Hello Gopher!")
require.Equal(t, string(*dest.ByteaPtr), "Hello Gopher!")
// Probably pq driver error.
// require.Equal(t, string(dest.Bytea), "Hello Gopher!")
})
}
func TestAllTypesFromSubQuery(t *testing.T) {
@ -424,6 +563,7 @@ func TestExpressionCast(t *testing.T) {
CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(),
CAST(String("1999-01-08 04:05:06+01:00")).AS_TIMESTAMPZ(),
CAST(String("04:05:06")).AS_INTERVAL(),
CAST(String("some text")).AS_BYTEA().EQ(Bytea([]byte("some text"))),
func() ProjectionList {
if sourceIsCockroachDB() {
@ -477,7 +617,6 @@ func TestStringOperators(t *testing.T) {
AllTypes.Text.BETWEEN(String("min"), String("max")),
AllTypes.Text.NOT_BETWEEN(AllTypes.VarChar, AllTypes.CharPtr),
AllTypes.Text.CONCAT(String("text2")),
AllTypes.Text.CONCAT(Int(11)),
AllTypes.Text.LIKE(String("abc")),
AllTypes.Text.NOT_LIKE(String("_b_")),
AllTypes.Text.REGEXP_LIKE(String("^t")),
@ -508,18 +647,18 @@ func TestStringOperators(t *testing.T) {
CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, String("aaa"), Int(1)),
CONCAT(Bool(false), Int(1), Float(22.2), String("test test")),
CONCAT_WS(String("string1"), Int(1), Float(11.22), String("bytea"), Bool(false)), //Float(11.12)),
CONVERT(Bytea("bytea"), String("UTF8"), String("LATIN1")),
CONVERT(AllTypes.Bytea, String("UTF8"), String("LATIN1")),
CONVERT_FROM(Bytea("text_in_utf8"), String("UTF8")),
CONVERT_TO(String("text_in_utf8"), String("UTF8")),
ENCODE(Bytea("123\000\001"), String("base64")),
DECODE(String("MTIzAAE="), String("base64")),
CONVERT(Bytea("bytea"), UTF8, LATIN1),
CONVERT(AllTypes.Bytea, UTF8, LATIN1),
CONVERT_FROM(Bytea("text_in_utf8"), UTF8),
CONVERT_TO(String("text_in_utf8"), UTF8),
ENCODE(Bytea("some text"), Escape),
DECODE(String("MTIzAAE="), Base64),
FORMAT(String("Hello %s, %1$s"), String("World")),
INITCAP(String("hi THOMAS")),
LEFT(String("abcde"), Int(2)),
RIGHT(String("abcde"), Int(2)),
LENGTH(Bytea("jose")),
LENGTH(Bytea("jose"), String("UTF8")),
LENGTH(Bytea("jose"), UTF8),
LPAD(String("Hi"), Int(5)),
LPAD(String("Hi"), Int(5), String("xy")),
RPAD(String("Hi"), Int(5)),
@ -540,6 +679,202 @@ func TestStringOperators(t *testing.T) {
require.NoError(t, err)
}
func TestBytea(t *testing.T) {
var sampleBytea = Bytea([]byte{11, 0, 22, 33, 44})
var textBytea = Bytea([]byte("text blob"))
stmt := SELECT(
AllTypes.Bytea.EQ(sampleBytea),
AllTypes.Bytea.EQ(AllTypes.ByteaPtr),
AllTypes.Bytea.NOT_EQ(sampleBytea),
AllTypes.Bytea.GT(textBytea),
AllTypes.Bytea.GT_EQ(AllTypes.ByteaPtr),
AllTypes.Bytea.LT(AllTypes.ByteaPtr),
AllTypes.Bytea.LT_EQ(sampleBytea),
AllTypes.Bytea.BETWEEN(Bytea([]byte("min")), Bytea([]byte("max"))),
AllTypes.Bytea.NOT_BETWEEN(AllTypes.Bytea, AllTypes.ByteaPtr),
AllTypes.Bytea.CONCAT(textBytea),
func() ProjectionList {
if sourceIsCockroachDB() {
return ProjectionList{NULL}
}
// cockroach doesn't support currently
return ProjectionList{
AllTypes.Bytea.LIKE(Bytea("b'%pattern%'")),
AllTypes.Bytea.NOT_LIKE(Bytea("b'%pattern%'")),
BTRIM(AllTypes.Bytea, Bytea([]byte{33})),
RTRIM(AllTypes.ByteaPtr, sampleBytea),
LTRIM(sampleBytea, textBytea),
CONCAT(sampleBytea, AllTypes.ByteaPtr, textBytea),
BIT_COUNT(sampleBytea).EQ(Int(3)),
LENGTH(textBytea, UTF8).EQ(Int(4)),
CONVERT(textBytea, UTF8, WIN1252),
CONVERT(AllTypes.Bytea, UTF8, LATIN1).EQ(sampleBytea),
}
}(),
BIT_LENGTH(textBytea),
OCTET_LENGTH(textBytea),
GET_BIT(textBytea, Int(2)).EQ(Int(23)),
GET_BYTE(sampleBytea, Int(1)).EQ(Int(0)),
SET_BIT(textBytea, Int(1), Int(0)).EQ(sampleBytea),
SET_BYTE(textBytea, Int(1), Int(0)).EQ(textBytea),
LENGTH(sampleBytea),
SUBSTR(AllTypes.Bytea, Int(0), Int(2)),
MD5(AllTypes.Bytea),
SHA224(AllTypes.Bytea),
SHA256(AllTypes.Bytea),
SHA384(AllTypes.Bytea),
SHA512(AllTypes.Bytea),
ENCODE(sampleBytea, Base64),
DECODE(String("A234C12B"), Hex).EQ(sampleBytea),
CONVERT_FROM(AllTypes.ByteaPtr, UTF8).EQ(AllTypes.VarChar),
CONVERT_TO(AllTypes.Text, UTF8).NOT_EQ(textBytea),
RawBytea("DECODE(#1::text, #2)", RawArgs{
"#1": "A234C12B",
"#2": "hex",
}).EQ(sampleBytea),
).FROM(
AllTypes,
)
if !sourceIsCockroachDB() {
testutils.AssertStatementSql(t, stmt, `
SELECT all_types.bytea = $1::bytea,
all_types.bytea = all_types.bytea_ptr,
all_types.bytea != $2::bytea,
all_types.bytea > $3::bytea,
all_types.bytea >= all_types.bytea_ptr,
all_types.bytea < all_types.bytea_ptr,
all_types.bytea <= $4::bytea,
all_types.bytea BETWEEN $5::bytea AND $6::bytea,
all_types.bytea NOT BETWEEN all_types.bytea AND all_types.bytea_ptr,
all_types.bytea || $7::bytea,
all_types.bytea LIKE $8::bytea,
all_types.bytea NOT LIKE $9::bytea,
BTRIM(all_types.bytea, $10::bytea),
RTRIM(all_types.bytea_ptr, $11::bytea),
LTRIM($12::bytea, $13::bytea),
CONCAT($14::bytea, all_types.bytea_ptr, $15::bytea),
BIT_COUNT($16::bytea) = $17,
LENGTH($18::bytea, 'UTF8') = $19,
CONVERT($20::bytea, 'UTF8', 'WIN1252'),
CONVERT(all_types.bytea, 'UTF8', 'LATIN1') = $21::bytea,
BIT_LENGTH($22::bytea),
OCTET_LENGTH($23::bytea),
GET_BIT($24::bytea, $25) = $26,
GET_BYTE($27::bytea, $28) = $29,
SET_BIT($30::bytea, $31, $32) = $33::bytea,
SET_BYTE($34::bytea, $35, $36) = $37::bytea,
LENGTH($38::bytea),
SUBSTR(all_types.bytea, $39, $40),
MD5(all_types.bytea),
SHA224(all_types.bytea),
SHA256(all_types.bytea),
SHA384(all_types.bytea),
SHA512(all_types.bytea),
ENCODE($41::bytea, 'base64'),
DECODE($42::text, 'hex') = $43::bytea,
CONVERT_FROM(all_types.bytea_ptr, 'UTF8') = all_types.var_char,
CONVERT_TO(all_types.text, 'UTF8') != $44::bytea,
(DECODE($45::text, $46)) = $47::bytea
FROM test_sample.all_types;
`)
}
var dest []struct{}
err := stmt.Query(db, &dest)
require.NoError(t, err)
}
func TestBlobConversion(t *testing.T) {
nonPrintable := []byte{11, 22, 33, 44, 55}
printable := []byte("this is blob")
stmt := SELECT(
Bytea(nonPrintable).AS("test_dest.non_printable"),
Bytea(printable).AS("test_dest.printable"),
Bytea(nonPrintable).CONCAT(Bytea(printable)).AS("test_dest.bytea_concat"),
ENCODE(Bytea(nonPrintable), Base64).AS("test_dest.non_printable_base64"),
CONVERT_FROM(Bytea(printable), UTF8).AS("test_dest.printable_utf8"),
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT '\x0b16212c37'::bytea AS "test_dest.non_printable",
'\x7468697320697320626c6f62'::bytea AS "test_dest.printable",
('\x0b16212c37'::bytea || '\x7468697320697320626c6f62'::bytea) AS "test_dest.bytea_concat",
ENCODE('\x0b16212c37'::bytea, 'base64') AS "test_dest.non_printable_base64",
CONVERT_FROM('\x7468697320697320626c6f62'::bytea, 'UTF8') AS "test_dest.printable_utf8";
`)
type testDest struct {
NonPrintable []byte
Printable []byte
ByteaConcat []byte
NonPrintableBase64 string
PrintableUTF8 string
}
var dest testDest
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest.NonPrintable, nonPrintable)
require.Equal(t, dest.Printable, printable)
require.Equal(t, dest.ByteaConcat, append(nonPrintable, printable...))
require.Equal(t, dest.NonPrintableBase64, base64.StdEncoding.EncodeToString(nonPrintable))
require.Equal(t, dest.PrintableUTF8, string(printable))
t.Run("using select json", func(t *testing.T) {
stmtJson := SELECT_JSON_OBJ(
Bytea(nonPrintable).AS("nonPrintable"),
Bytea(printable).AS("printable"),
Bytea(nonPrintable).CONCAT(Bytea(printable)).AS("byteaConcat"),
ENCODE(Bytea(nonPrintable), Base64).AS("nonPrintableBase64"),
CONVERT_FROM(Bytea(printable), UTF8).AS("printableUtf8"),
)
testutils.AssertStatementSql(t, stmtJson, `
SELECT row_to_json(records) AS "json"
FROM (
SELECT ENCODE($1::bytea, 'base64') AS "nonPrintable",
ENCODE($2::bytea, 'base64') AS "printable",
ENCODE($3::bytea || $4::bytea, 'base64') AS "byteaConcat",
ENCODE($5::bytea, 'base64') AS "nonPrintableBase64",
CONVERT_FROM($6::bytea, 'UTF8') AS "printableUtf8"
) AS records;
`)
var destSelectJson testDest
err := stmtJson.QueryContext(ctx, db, &destSelectJson)
require.NoError(t, err)
testutils.PrintJson(destSelectJson)
require.Equal(t, dest, destSelectJson)
})
}
func TestBoolOperators(t *testing.T) {
query := AllTypes.SELECT(
AllTypes.Boolean.EQ(AllTypes.BooleanPtr).AS("EQ1"),
@ -941,6 +1276,190 @@ func TestTimeExpression(t *testing.T) {
require.NoError(t, err)
}
func TestTimeScan(t *testing.T) {
loc, err := time.LoadLocation("Japan")
require.NoError(t, err)
timeT := time.Date(3, 3, 3, 11, 22, 33, 0, time.UTC)
timeWithNanoSeconds := time.Date(3, 3, 3, 1, 2, 3, 1000, time.UTC)
timez := time.Date(3, 3, 3, 7, 8, 9, 0, time.UTC)
timezWithNanoSeconds := time.Date(3, 3, 3, 4, 5, 6, 1000, loc)
// '1999-01-08 04:05:06'
timestamp := time.Date(1999, 01, 8, 4, 5, 6, 0, time.UTC)
timestampWithNanoSeconds := time.Date(3, 3, 3, 8, 9, 10, 1000, time.UTC)
timestampz := time.Date(2003, 10, 3, 9, 10, 11, 0, loc)
timestampzWithNanoSeconds := time.Date(3, 3, 3, 8, 9, 10, 1000, loc)
date := time.Date(2010, 2, 3, 0, 0, 0, 0, time.UTC)
stmt := SELECT(
TimeT(timeT).AS("time"),
TimeT(timeWithNanoSeconds).AS("timeWithNanoSeconds"),
TimezT(timez).AS("timez"),
TimezT(timezWithNanoSeconds).AS("timezWithNanoSeconds"),
Timestamp(1999, 01, 8, 4, 5, 6).AS("timestamp"),
TimestampT(timestampWithNanoSeconds).AS("timestampWithNanoSeconds"),
TimestampzT(timestampz).AS("timestampz"),
TimestampzT(timestampzWithNanoSeconds).AS("timestampzWithNanoSeconds"),
DateT(date).AS("date"),
TimeT(timeT).ADD(INTERVAL(2, HOUR)).AS("timeExpression"),
SELECT_JSON_OBJ(
TimeT(timeT).AS("time"),
TimeT(timeWithNanoSeconds).AS("timeWithNanoSeconds"),
TimezT(timez).AS("timez"),
TimezT(timezWithNanoSeconds).AS("timezWithNanoSeconds"),
TimestampT(timestamp).AS("timestamp"),
TimestampT(timestampWithNanoSeconds).AS("timestampWithNanoSeconds"),
TimestampzT(timestampz).AS("timestampz"),
TimestampzT(timestampzWithNanoSeconds).AS("timestampzWithNanoSeconds"),
DateT(date).AS("date"),
TimeT(timeT).ADD(INTERVAL(2, HOUR)).AS("timeExpression"),
).AS("json"),
)
testutils.AssertStatementSql(t, stmt, `
SELECT $1::time without time zone AS "time",
$2::time without time zone AS "timeWithNanoSeconds",
$3::time with time zone AS "timez",
$4::time with time zone AS "timezWithNanoSeconds",
$5::timestamp without time zone AS "timestamp",
$6::timestamp without time zone AS "timestampWithNanoSeconds",
$7::timestamp with time zone AS "timestampz",
$8::timestamp with time zone AS "timestampzWithNanoSeconds",
$9::date AS "date",
($10::time without time zone + INTERVAL '2 HOUR') AS "timeExpression",
(
SELECT row_to_json(json_records) AS "json_json"
FROM (
SELECT '0000-01-01T' || to_char('2000-10-10'::date + $11::time without time zone, 'HH24:MI:SS.USZ') AS "time",
'0000-01-01T' || to_char('2000-10-10'::date + $12::time without time zone, 'HH24:MI:SS.USZ') AS "timeWithNanoSeconds",
'0000-01-01T' || to_char('2000-10-10'::date + $13::time with time zone, 'HH24:MI:SS.USTZH:TZM') AS "timez",
'0000-01-01T' || to_char('2000-10-10'::date + $14::time with time zone, 'HH24:MI:SS.USTZH:TZM') AS "timezWithNanoSeconds",
to_char($15::timestamp without time zone, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestamp",
to_char($16::timestamp without time zone, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestampWithNanoSeconds",
$17::timestamp with time zone AS "timestampz",
$18::timestamp with time zone AS "timestampzWithNanoSeconds",
to_char($19::date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "date",
'0000-01-01T' || to_char('2000-10-10'::date + ($20::time without time zone + INTERVAL '2 HOUR'), 'HH24:MI:SS.USZ') AS "timeExpression"
) AS json_records
) AS "json";
`)
var dest struct {
Time time.Time
TimeWithNanoSeconds time.Time
Timez time.Time
TimezWithNanoSeconds time.Time
Timestamp time.Time
TimestampWithNanoSeconds time.Time
Timestampz time.Time
TimestampzWithNanoSeconds time.Time
Date time.Time
TimeExpression time.Time
Json struct {
Time time.Time
TimeWithNanoSeconds time.Time
Timez time.Time
TimezWithNanoSeconds time.Time
Timestamp time.Time
TimestampWithNanoSeconds time.Time
Timestampz time.Time
TimestampzWithNanoSeconds time.Time
Date time.Time
TimeExpression time.Time
} `json_column:"json"`
}
err = stmt.Query(db, &dest)
require.NoError(t, err)
ensureTimezEqual(t, timeT.Add(2*time.Hour), dest.TimeExpression, loc)
ensureTimezEqual(t, timeT.Add(2*time.Hour), dest.Json.TimeExpression, loc)
ensureTimezEqual(t, timeT, dest.Time, loc)
ensureTimezEqual(t, timeT, dest.Json.Time, loc)
ensureTimezEqual(t, timeWithNanoSeconds, dest.TimeWithNanoSeconds, loc)
ensureTimezEqual(t, timeWithNanoSeconds, dest.Json.TimeWithNanoSeconds, loc)
ensureTimezEqual(t, timez, dest.Timez, loc)
ensureTimezEqual(t, timez, dest.Json.Timez, loc)
ensureTimezEqual(t, timezWithNanoSeconds, dest.TimezWithNanoSeconds, loc)
ensureTimezEqual(t, timezWithNanoSeconds, dest.Json.TimezWithNanoSeconds, loc)
ensureTimezEqual(t, timestamp, dest.Timestamp, loc)
ensureTimezEqual(t, timestamp, dest.Json.Timestamp, loc)
ensureTimezEqual(t, timestampWithNanoSeconds, dest.TimestampWithNanoSeconds, loc)
ensureTimezEqual(t, timestampWithNanoSeconds, dest.Json.TimestampWithNanoSeconds, loc)
ensureTimezEqual(t, timestampz, dest.Timestampz, loc)
ensureTimezEqual(t, timestampz, dest.Json.Timestampz, loc)
ensureTimezEqual(t, timestampzWithNanoSeconds, dest.TimestampzWithNanoSeconds, loc)
ensureTimezEqual(t, timestampzWithNanoSeconds, dest.Json.TimestampzWithNanoSeconds, loc)
ensureTimezEqual(t, date, dest.Date, loc)
ensureTimezEqual(t, date, dest.Json.Date, loc)
t.Run("json only", func(t *testing.T) {
stmtJson := SELECT_JSON_OBJ(
TimeT(timeT).AS("time"),
TimeT(timeWithNanoSeconds).AS("timeWithNanoSeconds"),
TimezT(timez).AS("timez"),
TimezT(timezWithNanoSeconds).AS("timezWithNanoSeconds"),
Timestamp(1999, 01, 8, 4, 5, 6).AS("timestamp"),
TimestampT(timestampWithNanoSeconds).AS("timestampWithNanoSeconds"),
TimestampzT(timestampz).AS("timestampz"),
TimestampzT(timestampzWithNanoSeconds).AS("timestampzWithNanoSeconds"),
DateT(date).AS("date"),
)
var jsonDest struct {
Time time.Time
TimeWithNanoSeconds time.Time
Timez time.Time
TimezWithNanoSeconds time.Time
Timestamp time.Time
TimestampWithNanoSeconds time.Time
Timestampz time.Time
TimestampzWithNanoSeconds time.Time
Date time.Time
}
err := stmtJson.QueryContext(ctx, db, &jsonDest)
require.NoError(t, err)
})
}
func ensureTimezEqual(t *testing.T, time1, time2 time.Time, loc *time.Location) {
time1Loc := time1.In(loc)
time2Loc := time2.In(loc)
require.Equal(t, time1Loc.Hour(), time2Loc.Hour())
require.Equal(t, time1Loc.Minute(), time2Loc.Minute())
require.Equal(t, time1Loc.Second(), time2Loc.Second())
require.Equal(t, toMicroSeconds(time1Loc.Nanosecond()), toMicroSeconds(time2Loc.Nanosecond()))
}
func toMicroSeconds(nanoseconds int) int {
return nanoseconds / 1000
}
func TestIntervalSetFunctionality(t *testing.T) {
t.Run("updateQueryIntervalTest", func(t *testing.T) {
@ -1052,7 +1571,50 @@ func TestInterval(t *testing.T) {
AllTypes.IntervalPtr.DIV(Float(22.222)).EQ(AllTypes.IntervalPtr),
).FROM(AllTypes)
//fmt.Println(stmt.DebugSql())
fmt.Println(stmt.Sql())
testutils.AssertDebugStatementSql(t, stmt, `
SELECT INTERVAL '1 YEAR',
INTERVAL '1 MONTH',
INTERVAL '1 WEEK',
INTERVAL '1 DAY',
INTERVAL '1 HOUR',
INTERVAL '1 MINUTE',
INTERVAL '1 SECOND',
INTERVAL '1 MILLISECOND',
INTERVAL '1 MICROSECOND',
INTERVAL '1 DECADE',
INTERVAL '1 CENTURY',
INTERVAL '1 MILLENNIUM',
INTERVAL '1 YEAR 10 MONTH',
INTERVAL '1 YEAR 10 MONTH 20 DAY',
INTERVAL '1 YEAR 10 MONTH 20 DAY 3 HOUR',
INTERVAL '1 YEAR' IS NOT NULL,
INTERVAL '1 YEAR' AS "one year",
INTERVAL '0 MICROSECOND',
INTERVAL '1 MICROSECOND',
INTERVAL '1000 MICROSECOND',
INTERVAL '1 SECOND',
INTERVAL '1 MINUTE',
INTERVAL '1 HOUR',
INTERVAL '1 DAY',
INTERVAL '1 DAY 2 HOUR 3 MINUTE 4 SECOND 5 MICROSECOND',
(all_types.interval = INTERVAL '2 HOUR 20 MINUTE') = TRUE::boolean,
(all_types.interval_ptr != INTERVAL '2 HOUR 20 MINUTE') = FALSE::boolean,
(all_types.interval IS DISTINCT FROM INTERVAL '2 HOUR 20 MINUTE') = all_types.boolean,
(all_types.interval_ptr IS NOT DISTINCT FROM INTERVAL '10 MICROSECOND') = all_types.boolean,
(all_types.interval < all_types.interval_ptr) = all_types.boolean_ptr,
(all_types.interval <= all_types.interval_ptr) = all_types.boolean_ptr,
(all_types.interval > all_types.interval_ptr) = all_types.boolean_ptr,
(all_types.interval >= all_types.interval_ptr) = all_types.boolean_ptr,
all_types.interval BETWEEN INTERVAL '1 HOUR' AND INTERVAL '2 HOUR',
all_types.interval NOT BETWEEN all_types.interval_ptr AND INTERVAL '30 SECOND',
(all_types.interval + all_types.interval_ptr) = INTERVAL '17 SECOND',
(all_types.interval - all_types.interval_ptr) = INTERVAL '100 MICROSECOND',
(all_types.interval_ptr * 11) = all_types.interval,
(all_types.interval_ptr / 22.222) = all_types.interval_ptr
FROM test_sample.all_types;
`)
err := stmt.Query(db, &struct{}{})
require.NoError(t, err)
@ -1159,6 +1721,187 @@ SELECT ROW($1::integer, $2::real, $3::text) AS "row",
require.NoError(t, err)
}
func TestAllTypesSubQueryFrom(t *testing.T) {
subQuery := SELECT(
AllTypes.Boolean,
AllTypes.Integer,
AllTypes.DoublePrecision,
AllTypes.Text,
AllTypes.Date,
AllTypes.Time,
AllTypes.Timez,
AllTypes.Timestamp,
AllTypes.Timestampz,
AllTypes.Interval,
AllTypes.Bytea,
).FROM(
AllTypes,
).AsTable("subQuery")
stmt := SELECT(
AllTypes.Boolean.From(subQuery),
AllTypes.Integer.From(subQuery),
AllTypes.DoublePrecision.From(subQuery),
AllTypes.Text.From(subQuery),
AllTypes.Date.From(subQuery),
AllTypes.Time.From(subQuery),
AllTypes.Timez.From(subQuery),
AllTypes.Timestamp.From(subQuery),
AllTypes.Timestampz.From(subQuery),
AllTypes.Interval.From(subQuery),
AllTypes.Bytea.From(subQuery),
).FROM(
subQuery,
)
testutils.AssertStatementSql(t, stmt, `
SELECT "subQuery"."all_types.boolean" AS "all_types.boolean",
"subQuery"."all_types.integer" AS "all_types.integer",
"subQuery"."all_types.double_precision" AS "all_types.double_precision",
"subQuery"."all_types.text" AS "all_types.text",
"subQuery"."all_types.date" AS "all_types.date",
"subQuery"."all_types.time" AS "all_types.time",
"subQuery"."all_types.timez" AS "all_types.timez",
"subQuery"."all_types.timestamp" AS "all_types.timestamp",
"subQuery"."all_types.timestampz" AS "all_types.timestampz",
"subQuery"."all_types.interval" AS "all_types.interval",
"subQuery"."all_types.bytea" AS "all_types.bytea"
FROM (
SELECT all_types.boolean AS "all_types.boolean",
all_types.integer AS "all_types.integer",
all_types.double_precision AS "all_types.double_precision",
all_types.text AS "all_types.text",
all_types.date AS "all_types.date",
all_types.time AS "all_types.time",
all_types.timez AS "all_types.timez",
all_types.timestamp AS "all_types.timestamp",
all_types.timestampz AS "all_types.timestampz",
all_types.interval AS "all_types.interval",
all_types.bytea AS "all_types.bytea"
FROM test_sample.all_types
) AS "subQuery";
`)
var dest []model.AllTypes
err := stmt.Query(db, &dest)
require.NoError(t, err)
t.Run("using SELECT_JSON", func(t *testing.T) {
stmtJson := SELECT_JSON_ARR(
AllTypes.Boolean.From(subQuery),
AllTypes.Integer.From(subQuery),
AllTypes.DoublePrecision.From(subQuery),
AllTypes.Text.From(subQuery),
AllTypes.Date.From(subQuery),
AllTypes.Time.From(subQuery),
AllTypes.Timez.From(subQuery),
AllTypes.Timestamp.From(subQuery),
AllTypes.Timestampz.From(subQuery),
AllTypes.Interval.From(subQuery),
AllTypes.Bytea.From(subQuery),
).FROM(
subQuery,
)
testutils.AssertDebugStatementSql(t, stmtJson, `
SELECT json_agg(row_to_json(records)) AS "json"
FROM (
SELECT "subQuery"."all_types.boolean" AS "boolean",
"subQuery"."all_types.integer" AS "integer",
"subQuery"."all_types.double_precision" AS "doublePrecision",
"subQuery"."all_types.text" AS "text",
to_char("subQuery"."all_types.date"::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "date",
'0000-01-01T' || to_char('2000-10-10'::date + "subQuery"."all_types.time", 'HH24:MI:SS.USZ') AS "time",
'0000-01-01T' || to_char('2000-10-10'::date + "subQuery"."all_types.timez", 'HH24:MI:SS.USTZH:TZM') AS "timez",
to_char("subQuery"."all_types.timestamp", 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestamp",
"subQuery"."all_types.timestampz" AS "timestampz",
"subQuery"."all_types.interval" AS "interval",
ENCODE("subQuery"."all_types.bytea", 'base64') AS "bytea"
FROM (
SELECT all_types.boolean AS "all_types.boolean",
all_types.integer AS "all_types.integer",
all_types.double_precision AS "all_types.double_precision",
all_types.text AS "all_types.text",
all_types.date AS "all_types.date",
all_types.time AS "all_types.time",
all_types.timez AS "all_types.timez",
all_types.timestamp AS "all_types.timestamp",
all_types.timestampz AS "all_types.timestampz",
all_types.interval AS "all_types.interval",
all_types.bytea AS "all_types.bytea"
FROM test_sample.all_types
) AS "subQuery"
) AS records;
`)
var destJson []model.AllTypes
err := stmtJson.QueryContext(ctx, db, &destJson)
require.NoError(t, err)
t.Run("using AllColumns()", func(t *testing.T) {
stmtJsonAllColumns := SELECT_JSON_ARR(
subQuery.AllColumns(),
).FROM(
subQuery,
)
require.Equal(t, stmtJson.DebugSql(), stmtJsonAllColumns.DebugSql())
})
// fix timezone before comparisons
minus8 := time.FixedZone("UTC", -8*60*60)
destJson[0].Timez = *toTZ(&destJson[0].Timez, minus8)
destJson[1].Timez = *toTZ(&destJson[1].Timez, minus8)
destJson[0].Timestampz = *toTZ(&destJson[0].Timestampz, time.UTC)
destJson[1].Timestampz = *toTZ(&destJson[1].Timestampz, time.UTC)
dest[0].Timestampz = *toTZ(&dest[0].Timestampz, time.UTC)
dest[1].Timestampz = *toTZ(&dest[1].Timestampz, time.UTC)
testutils.AssertJsonEqual(t, dest, destJson)
})
}
func TestAllTypesUpdateSet(t *testing.T) {
stmt := AllTypes.UPDATE().
SET(
AllTypes.Boolean.SET(Bool(false)),
AllTypes.Integer.SET(Int(2)),
AllTypes.DoublePrecision.SET(Float(2.22)),
AllTypes.Text.SET(Text("some text")),
AllTypes.Date.SET(DateT(time.Now())),
AllTypes.Time.SET(TimeT(time.Now())),
AllTypes.Timez.SET(TimezT(time.Now())),
AllTypes.Timestamp.SET(TimestampT(time.Now())),
AllTypes.Interval.SET(INTERVAL(1, HOUR)),
AllTypes.Bytea.SET(Bytea([]byte{11, 22, 33, 44})),
).WHERE(Bool(true))
testutils.AssertStatementSql(t, stmt, `
UPDATE test_sample.all_types
SET boolean = $1::boolean,
integer = $2,
double_precision = $3,
text = $4::text,
date = $5::date,
time = $6::time without time zone,
timez = $7::time with time zone,
timestamp = $8::timestamp without time zone,
interval = INTERVAL '1 HOUR',
bytea = $9::bytea
WHERE $10::boolean;
`)
testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) {
_, err := stmt.Exec(tx)
require.NoError(t, err)
})
}
func TestSubQueryColumnReference(t *testing.T) {
type expected struct {
sql string

View file

@ -188,7 +188,129 @@ ORDER BY "Artist"."ArtistId", "Album"."AlbumId", "Track"."TrackId";
`)
}
type AllArtistDetails []struct { //list of all artist
model.Artist
Albums []struct { // list of albums per artist
model.Album
Tracks []struct { // list of tracks per album
model.Track
Genre model.Genre // track genre
MediaType model.MediaType // track media type
Playlists []model.Playlist // list of playlist where track is used
Invoices []struct { // list of invoices where track occurs
model.Invoice
Customer struct { // customer data for invoice
model.Customer
Employee *struct { // employee data for customer if exists
model.Employee
Manager *model.Employee `alias:"Manager"`
}
}
}
}
}
}
func BenchmarkJoinEverythingJSON(b *testing.B) {
for i := 0; i < b.N; i++ {
testJoinEverythingJSON(b)
}
}
func TestJoinEverythingJSON(t *testing.T) {
testJoinEverythingJSON(t)
}
func testJoinEverythingJSON(t require.TestingT) {
manager := Employee.AS("Manager")
stmt := SELECT_JSON_ARR(
Artist.AllColumns,
SELECT_JSON_ARR(
Album.AllColumns,
SELECT_JSON_ARR(
Track.AllColumns,
SELECT_JSON_OBJ(Genre.AllColumns).
FROM(Genre).
WHERE(Genre.GenreId.EQ(Track.GenreId)).AS("Genre"),
SELECT_JSON_OBJ(MediaType.AllColumns).
FROM(MediaType).
WHERE(MediaType.MediaTypeId.EQ(Track.MediaTypeId)).AS("MediaType"),
SELECT_JSON_ARR(Playlist.AllColumns).
FROM(Playlist.
INNER_JOIN(PlaylistTrack, Playlist.PlaylistId.EQ(PlaylistTrack.PlaylistId))).
WHERE(PlaylistTrack.TrackId.EQ(Track.TrackId)).
ORDER_BY(Playlist.PlaylistId).AS("Playlists"),
SELECT_JSON_ARR(
Invoice.AllColumns,
SELECT_JSON_OBJ(
Customer.AllColumns,
SELECT_JSON_OBJ(
Employee.AllColumns,
SELECT_JSON_OBJ(manager.AllColumns).
FROM(manager).
WHERE(manager.EmployeeId.EQ(Employee.ReportsTo)).AS("Manager"),
).FROM(Employee).
WHERE(Employee.EmployeeId.EQ(Customer.SupportRepId)).AS("Employee"),
).FROM(Customer).
WHERE(Customer.CustomerId.EQ(Invoice.CustomerId)).AS("Customer"),
).FROM(
Invoice.
INNER_JOIN(InvoiceLine, InvoiceLine.InvoiceId.EQ(Invoice.InvoiceId)),
).WHERE(InvoiceLine.TrackId.EQ(Track.TrackId)).
ORDER_BY(Invoice.InvoiceId).AS("Invoices"),
).FROM(Track).
WHERE(Track.AlbumId.EQ(Album.AlbumId)).
ORDER_BY(Track.TrackId).AS("Tracks"),
).FROM(Album).
WHERE(Album.ArtistId.EQ(Artist.ArtistId)).
ORDER_BY(Album.AlbumId).AS("Albums"),
).FROM(Artist).
ORDER_BY(Artist.ArtistId)
//fmt.Println(stmt.DebugSql())
var dest AllArtistDetails
err := stmt.QueryContext(ctx, db, &dest)
require.NoError(t, err)
require.Equal(t, len(dest), 275)
//testutils.SaveJSONFile(dest, "./testdata/results/postgres/joined_everything2.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/joined_everything.json")
requireLogged(t, stmt)
requireQueryLogged(t, stmt, 1)
}
func BenchmarkJoinEverything(b *testing.B) {
for i := 0; i < b.N; i++ {
testJoinEverything(b)
}
}
func TestJoinEverything(t *testing.T) {
testJoinEverything(t)
}
func testJoinEverything(t require.TestingT) {
manager := Employee.AS("Manager")
@ -223,37 +345,6 @@ func TestJoinEverything(t *testing.T) {
Invoice.InvoiceId, Customer.CustomerId,
)
var dest []struct { //list of all artist
model.Artist
Albums []struct { // list of albums per artist
model.Album
Tracks []struct { // list of tracks per album
model.Track
Genre model.Genre // track genre
MediaType model.MediaType // track media type
Playlists []model.Playlist // list of playlist where track is used
Invoices []struct { // list of invoices where track occurs
model.Invoice
Customer struct { // customer data for invoice
model.Customer
Employee *struct { // employee data for customer if exists
model.Employee
Manager *model.Employee `alias:"Manager"`
}
}
}
}
}
}
testutils.AssertStatementSql(t, stmt, `
SELECT "Artist"."ArtistId" AS "Artist.ArtistId",
"Artist"."Name" AS "Artist.Name",
@ -344,7 +435,7 @@ FROM chinook."Artist"
LEFT JOIN chinook."Employee" AS "Manager" ON ("Manager"."EmployeeId" = "Employee"."ReportsTo")
ORDER BY "Artist"."ArtistId", "Album"."AlbumId", "Track"."TrackId", "Genre"."GenreId", "MediaType"."MediaTypeId", "Playlist"."PlaylistId", "Invoice"."InvoiceId", "Customer"."CustomerId";
`)
var dest AllArtistDetails
err := stmt.QueryContext(context.Background(), db, &dest)
require.NoError(t, err)

View file

@ -974,8 +974,8 @@ type allTypesTable struct {
Char postgres.ColumnString
TextPtr postgres.ColumnString
Text postgres.ColumnString
ByteaPtr postgres.ColumnString
Bytea postgres.ColumnString
ByteaPtr postgres.ColumnBytea
Bytea postgres.ColumnBytea
TimestampzPtr postgres.ColumnTimestampz
Timestampz postgres.ColumnTimestampz
TimestampPtr postgres.ColumnTimestamp
@ -1078,8 +1078,8 @@ func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable {
CharColumn = postgres.StringColumn("char")
TextPtrColumn = postgres.StringColumn("text_ptr")
TextColumn = postgres.StringColumn("text")
ByteaPtrColumn = postgres.StringColumn("bytea_ptr")
ByteaColumn = postgres.StringColumn("bytea")
ByteaPtrColumn = postgres.ByteaColumn("bytea_ptr")
ByteaColumn = postgres.ByteaColumn("bytea")
TimestampzPtrColumn = postgres.TimestampzColumn("timestampz_ptr")
TimestampzColumn = postgres.TimestampzColumn("timestampz")
TimestampPtrColumn = postgres.TimestampColumn("timestamp_ptr")

View file

@ -20,6 +20,8 @@ import (
_ "github.com/jackc/pgx/v4/stdlib"
)
var ctx = context.Background()
var db *stmtcache.DB
var testRoot string
@ -31,6 +33,7 @@ const CockroachDB = "COCKROACH_DB"
func init() {
source = os.Getenv("PG_SOURCE")
withStatementCaching = os.Getenv("JET_TESTS_WITH_STMT_CACHE") == "true"
testRoot = repo.GetTestsDirPath()
}
func sourceIsCockroachDB() bool {
@ -46,8 +49,6 @@ func skipForCockroachDB(t *testing.T) {
func TestMain(m *testing.M) {
defer profile.Start().Stop()
setTestRoot()
for _, driverName := range []string{"postgres", "pgx"} {
fmt.Printf("\nRunning postgres tests for driver: %s, caching enabled: %t \n", driverName, withStatementCaching)
@ -94,10 +95,6 @@ func getConnectionString() string {
return dbconfig.PostgresConnectString
}
func setTestRoot() {
testRoot = repo.GetTestsDirPath()
}
var loggedSQL string
var loggedSQLArgs []interface{}
var loggedDebugSQL string
@ -119,14 +116,22 @@ func init() {
})
}
func requireLogged(t *testing.T, statement postgres.Statement) {
func requireLogged(t require.TestingT, statement postgres.Statement) {
if _, ok := t.(*testing.B); ok {
return // skip assert for benchmarks
}
query, args := statement.Sql()
require.Equal(t, loggedSQL, query)
require.Equal(t, loggedSQLArgs, args)
require.Equal(t, loggedDebugSQL, statement.DebugSql())
}
func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) {
func requireQueryLogged(t require.TestingT, statement postgres.Statement, rowsProcessed int64) {
if _, ok := t.(*testing.B); ok {
return // skip assert for benchmarks
}
query, args := statement.Sql()
queryLogged, argsLogged := queryInfo.Statement.Sql()

View file

@ -9,7 +9,50 @@ import (
"testing"
)
func TestNorthwindJoinEverything(t *testing.T) {
type Dest []struct {
model.Customers
Demographics model.CustomerDemographics
Orders []struct {
model.Orders
Shipper model.Shippers
Employee struct {
model.Employees
Territories []struct {
model.Territories
Region model.Region
}
}
Details []struct {
model.OrderDetails
Products struct {
model.Products
Category model.Categories
Supplier model.Suppliers
}
}
}
}
func BenchmarkTestNorthwindJoinEverything(b *testing.B) {
for i := 0; i < b.N; i++ {
testNorthwindJoinEverything(b)
}
}
func TestTestNorthwindJoinEverything(t *testing.T) {
testNorthwindJoinEverything(t)
}
func testNorthwindJoinEverything(t require.TestingT) {
stmt :=
SELECT(
@ -21,6 +64,9 @@ func TestNorthwindJoinEverything(t *testing.T) {
Products.AllColumns,
Categories.AllColumns,
Suppliers.AllColumns,
Employees.AllColumns,
Territories.AllColumns,
Region.AllColumns,
).FROM(
Customers.
LEFT_JOIN(CustomerCustomerDemo, Customers.CustomerID.EQ(CustomerCustomerDemo.CustomerID)).
@ -35,35 +81,110 @@ func TestNorthwindJoinEverything(t *testing.T) {
LEFT_JOIN(EmployeeTerritories, EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID)).
LEFT_JOIN(Territories, EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)).
LEFT_JOIN(Region, Territories.RegionID.EQ(Region.RegionID)),
).ORDER_BY(Customers.CustomerID, Orders.OrderID, Products.ProductID)
).ORDER_BY(
Customers.CustomerID,
Orders.OrderID,
Products.ProductID,
Territories.TerritoryID,
)
var dest []struct {
model.Customers
//fmt.Println(stmt.DebugSql())
Demographics model.CustomerDemographics
Orders []struct {
model.Orders
Shipper model.Shippers
Details struct {
model.OrderDetails
Products []struct {
model.Products
Category model.Categories
Supplier model.Suppliers
}
}
}
}
var dest Dest
err := stmt.Query(db, &dest)
require.NoError(t, err)
//jsonSave("./testdata/northwind-all.json", dest)
//testutils.SaveJSONFile(dest, "./testdata/results/postgres/northwind-all.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/northwind-all.json")
requireLogged(t, stmt)
}
func BenchmarkTestNorthwindJoinEverythingJson(b *testing.B) {
for i := 0; i < b.N; i++ {
testNorthwindJoinEverythingJson(b)
}
}
func TestNorthwindJoinEverythingJson(t *testing.T) {
testNorthwindJoinEverythingJson(t)
}
func testNorthwindJoinEverythingJson(t require.TestingT) {
stmt := SELECT_JSON_ARR(
Customers.AllColumns,
SELECT_JSON_OBJ(CustomerDemographics.AllColumns).
FROM(CustomerDemographics.INNER_JOIN(CustomerCustomerDemo, CustomerCustomerDemo.CustomerTypeID.EQ(CustomerDemographics.CustomerTypeID))).
WHERE(CustomerCustomerDemo.CustomerID.EQ(Customers.CustomerID)).AS("Demographics"),
SELECT_JSON_ARR(
Orders.AllColumns,
SELECT_JSON_OBJ(Shippers.AllColumns).
FROM(Shippers).
WHERE(Shippers.ShipperID.EQ(Orders.ShipVia)).AS("Shipper"),
SELECT_JSON_OBJ(
Employees.AllColumns,
SELECT_JSON_ARR(
Territories.AllColumns,
SELECT_JSON_OBJ(Region.AllColumns).
FROM(Region).
WHERE(Region.RegionID.EQ(Territories.RegionID)).AS("Region"),
).FROM(
EmployeeTerritories.LEFT_JOIN(
Territories,
EmployeeTerritories.TerritoryID.EQ(Territories.TerritoryID)),
).WHERE(
EmployeeTerritories.EmployeeID.EQ(Employees.EmployeeID),
).AS("Territories"),
).FROM(Employees).
WHERE(Orders.EmployeeID.EQ(Employees.EmployeeID)).AS("Employee"),
SELECT_JSON_ARR(
OrderDetails.AllColumns,
SELECT_JSON_OBJ(
Products.AllColumns,
SELECT_JSON_OBJ(
Categories.AllColumns,
).FROM(Categories).
WHERE(Categories.CategoryID.EQ(Products.CategoryID)).AS("Category"),
SELECT_JSON_OBJ(Suppliers.AllColumns).
FROM(Suppliers).
WHERE(Suppliers.SupplierID.EQ(Products.SupplierID)).AS("Supplier"),
).FROM(Products).
WHERE(Products.ProductID.EQ(OrderDetails.ProductID)).AS("Products"),
).FROM(
OrderDetails,
).WHERE(
OrderDetails.OrderID.EQ(Orders.OrderID),
).AS("Details"),
).FROM(
Orders,
).WHERE(
Orders.CustomerID.EQ(Customers.CustomerID),
).ORDER_BY(
Orders.OrderID,
).AS("Orders"),
).FROM(
Customers,
).ORDER_BY(
Customers.CustomerID,
)
//fmt.Println(stmt.DebugSql())
var dest Dest
err := stmt.QueryContext(ctx, db, &dest)
require.NoError(t, err)
//testutils.SaveJSONFile(dest, "./testdata/results/postgres/northwind-all2.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/northwind-all.json")
}

View file

@ -220,20 +220,7 @@ func TestUUIDComplex(t *testing.T) {
requireLogged(t, query)
})
t.Run("slice of structs left join", func(t *testing.T) {
leftQuery := Person.LEFT_JOIN(PersonPhone, PersonPhone.PersonID.EQ(Person.PersonID)).
SELECT(Person.AllColumns, PersonPhone.AllColumns).
ORDER_BY(Person.PersonID.ASC(), PersonPhone.PhoneID.ASC())
var dest []struct {
model.Person
Phones []struct {
model.PersonPhone
}
}
err := leftQuery.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
var expectedSliceOfStructsLeftJoin = `
[
{
"PersonID": "b68dbff4-a87d-11e9-a7f2-98ded00c39c6",
@ -274,10 +261,50 @@ func TestUUIDComplex(t *testing.T) {
]
}
]
`)
`
t.Run("slice of structs left join", func(t *testing.T) {
leftQuery := Person.LEFT_JOIN(PersonPhone, PersonPhone.PersonID.EQ(Person.PersonID)).
SELECT(Person.AllColumns, PersonPhone.AllColumns).
ORDER_BY(Person.PersonID.ASC(), PersonPhone.PhoneID.ASC())
var dest []struct {
model.Person
Phones []struct {
model.PersonPhone
}
}
err := leftQuery.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, expectedSliceOfStructsLeftJoin)
requireLogged(t, leftQuery)
})
t.Run("select json", func(t *testing.T) {
jsonQuery := SELECT_JSON_ARR(
Person.AllColumns,
SELECT_JSON_ARR(PersonPhone.AllColumns).
FROM(PersonPhone).
WHERE(PersonPhone.PersonID.EQ(Person.PersonID)).
ORDER_BY(PersonPhone.PhoneID).AS("Phones"),
).FROM(
Person,
).ORDER_BY(
Person.PersonID.ASC(),
)
var dest []struct {
model.Person
Phones []struct {
model.PersonPhone
}
}
err := jsonQuery.QueryContext(ctx, db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, expectedSliceOfStructsLeftJoin)
})
}
func TestEnumType(t *testing.T) {
query := Person.

View file

@ -209,7 +209,7 @@ func TestScanToStruct(t *testing.T) {
err := query.Query(db, &dest)
require.Error(t, err)
require.EqualError(t, err, "jet: can't scan int64('\\x01') to 'InventoryID uuid.UUID': Scan: unable to scan type int64 into UUID")
require.EqualError(t, err, "jet: can't assign int64('\\x01') to 'InventoryID uuid.UUID': Scan: unable to scan type int64 into UUID")
})
t.Run("type mismatch base type", func(t *testing.T) {

File diff suppressed because one or more lines are too long

Some files were not shown because too many files have changed in this diff Show more