Merge pull request #32 from go-jet/develop

Interval type support. Sample usage can be seen here.
Datetime arithmetic with interval types.
Dynamic projection list support. Sample usage.
[bug] Escape reserved words used as identifier(issue).
[bug] Fix crash on generating enum SQL Builder files when database enum contains numeric values(issue).
This commit is contained in:
go-jet 2020-02-17 20:08:39 +01:00 committed by GitHub
commit eea776a1ac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
93 changed files with 3152 additions and 1174 deletions

View file

@ -46,7 +46,8 @@ jobs:
go get github.com/go-sql-driver/mysql go get github.com/go-sql-driver/mysql
go get github.com/pkg/profile go get github.com/pkg/profile
go get gotest.tools/assert go get github.com/stretchr/testify/assert
go get github.com/google/go-cmp/cmp
go get github.com/davecgh/go-spew/spew go get github.com/davecgh/go-spew/spew
go get github.com/jstemmer/go-junit-report go get github.com/jstemmer/go-junit-report
@ -142,7 +143,8 @@ jobs:
go get github.com/go-sql-driver/mysql go get github.com/go-sql-driver/mysql
go get github.com/pkg/profile go get github.com/pkg/profile
go get gotest.tools/assert go get github.com/stretchr/testify/assert
go get github.com/google/go-cmp/cmp
go get github.com/davecgh/go-spew/spew go get github.com/davecgh/go-spew/spew
go get github.com/jstemmer/go-junit-report go get github.com/jstemmer/go-junit-report

View file

@ -35,13 +35,13 @@ https://medium.com/@go.jet/jet-5f3667efa0cc
## Features ## Features
1) Auto-generated type-safe SQL Builder 1) Auto-generated type-safe SQL Builder
- PostgreSQL: - PostgreSQL:
* SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, sub-queries)` * SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, WINDOW, sub-queries)`
* INSERT `(VALUES, query, RETURNING)`, * INSERT `(VALUES, query, RETURNING)`,
* UPDATE `(SET, WHERE, RETURNING)`, * UPDATE `(SET, WHERE, RETURNING)`,
* DELETE `(WHERE, RETURNING)`, * DELETE `(WHERE, RETURNING)`,
* LOCK `(IN, NOWAIT)` * LOCK `(IN, NOWAIT)`
- MySQL and MariaDB: - MySQL and MariaDB:
* SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, LOCK_IN_SHARE_MODE, sub-queries)` * SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, LOCK_IN_SHARE_MODE, WINDOW, sub-queries)`
* INSERT `(VALUES, query)`, * INSERT `(VALUES, query)`,
* UPDATE `(SET, WHERE)`, * UPDATE `(SET, WHERE)`,
* DELETE `(WHERE, ORDER_BY, LIMIT)`, * DELETE `(WHERE, ORDER_BY, LIMIT)`,
@ -560,15 +560,12 @@ At the moment Jet dependence only of:
To run the tests, additional dependencies are required: To run the tests, additional dependencies are required:
- `github.com/pkg/profile` - `github.com/pkg/profile`
- `gotest.tools/assert` - `github.com/stretchr/testify`
## Versioning ## Versioning
[SemVer](http://semver.org/) is used for versioning. For the versions available, take a look at the [releases](https://github.com/go-jet/jet/releases). [SemVer](http://semver.org/) is used for versioning. For the versions available, take a look at the [releases](https://github.com/go-jet/jet/releases).
For now there is no guarantee that public API will remain backward compatible. Please read new release drafts to get acquaint how to handle possible build breakable API changes.
## License ## License
Copyright 2019 Goran Bjelanovic Copyright 2019 Goran Bjelanovic

View file

@ -57,8 +57,10 @@ func (c ColumnMetaData) getSqlBuilderColumnType() string {
return "Time" return "Time"
case "time with time zone": case "time with time zone":
return "Timez" return "Timez"
case "interval":
return "Interval"
case "USER-DEFINED", "enum", "text", "character", "character varying", "bytea", "uuid", case "USER-DEFINED", "enum", "text", "character", "character varying", "bytea", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "interval", "line", "ARRAY", "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY",
"char", "varchar", "binary", "varbinary", "char", "varchar", "binary", "varbinary",
"tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL
return "String" return "String"

View file

@ -73,7 +73,8 @@ func generateGoFiles(dirPath, packageName string, template string, metaDataList
func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect, params ...map[string]interface{}) ([]byte, error) { func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect, params ...map[string]interface{}) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{ t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
"ToGoIdentifier": utils.ToGoIdentifier, "ToGoIdentifier": utils.ToGoIdentifier,
"ToGoEnumValueIdentifier": utils.ToGoEnumValueIdentifier,
"now": func() string { "now": func() string {
return time.Now().Format(time.RFC850) return time.Now().Format(time.RFC850)
}, },

View file

@ -94,11 +94,11 @@ import "github.com/go-jet/jet/{{dialect.PackageName}}"
var {{ToGoIdentifier $.Name}} = &struct { var {{ToGoIdentifier $.Name}} = &struct {
{{- range $index, $element := .Values}} {{- range $index, $element := .Values}}
{{ToGoIdentifier $element}} {{dialect.PackageName}}.StringExpression {{ToGoEnumValueIdentifier $.Name $element}} {{dialect.PackageName}}.StringExpression
{{- end}} {{- end}}
} { } {
{{- range $index, $element := .Values}} {{- range $index, $element := .Values}}
{{ToGoIdentifier $element}}: {{dialect.PackageName}}.NewEnumValue("{{$element}}"), {{ToGoEnumValueIdentifier $.Name $element}}: {{dialect.PackageName}}.NewEnumValue("{{$element}}"),
{{- end}} {{- end}}
} }
` `

View file

@ -1,7 +1,7 @@
package snaker package snaker
import ( import (
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
) )

View file

@ -13,8 +13,7 @@ func newAlias(expression Expression, aliasName string) Projection {
} }
func (a *alias) fromImpl(subQuery SelectTable) Projection { func (a *alias) fromImpl(subQuery SelectTable) Projection {
column := newColumn(a.alias, "", nil) column := NewColumnImpl(a.alias, "", nil)
column.Parent = &column
column.subQuery = subQuery column.subQuery = subQuery
return &column return &column

View file

@ -37,109 +37,69 @@ type boolInterfaceImpl struct {
} }
func (b *boolInterfaceImpl) EQ(expression BoolExpression) BoolExpression { func (b *boolInterfaceImpl) EQ(expression BoolExpression) BoolExpression {
return eq(b.parent, expression) return Eq(b.parent, expression)
} }
func (b *boolInterfaceImpl) NOT_EQ(expression BoolExpression) BoolExpression { func (b *boolInterfaceImpl) NOT_EQ(expression BoolExpression) BoolExpression {
return notEq(b.parent, expression) return NotEq(b.parent, expression)
} }
func (b *boolInterfaceImpl) IS_DISTINCT_FROM(rhs BoolExpression) BoolExpression { func (b *boolInterfaceImpl) IS_DISTINCT_FROM(rhs BoolExpression) BoolExpression {
return isDistinctFrom(b.parent, rhs) return IsDistinctFrom(b.parent, rhs)
} }
func (b *boolInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs BoolExpression) BoolExpression { func (b *boolInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs BoolExpression) BoolExpression {
return isNotDistinctFrom(b.parent, rhs) return IsNotDistinctFrom(b.parent, rhs)
} }
func (b *boolInterfaceImpl) AND(expression BoolExpression) BoolExpression { func (b *boolInterfaceImpl) AND(expression BoolExpression) BoolExpression {
return newBinaryBoolOperator(b.parent, expression, "AND") return newBinaryBoolOperatorExpression(b.parent, expression, "AND")
} }
func (b *boolInterfaceImpl) OR(expression BoolExpression) BoolExpression { func (b *boolInterfaceImpl) OR(expression BoolExpression) BoolExpression {
return newBinaryBoolOperator(b.parent, expression, "OR") return newBinaryBoolOperatorExpression(b.parent, expression, "OR")
} }
func (b *boolInterfaceImpl) IS_TRUE() BoolExpression { func (b *boolInterfaceImpl) IS_TRUE() BoolExpression {
return newPostifxBoolExpression(b.parent, "IS TRUE") return newPostfixBoolOperatorExpression(b.parent, "IS TRUE")
} }
func (b *boolInterfaceImpl) IS_NOT_TRUE() BoolExpression { func (b *boolInterfaceImpl) IS_NOT_TRUE() BoolExpression {
return newPostifxBoolExpression(b.parent, "IS NOT TRUE") return newPostfixBoolOperatorExpression(b.parent, "IS NOT TRUE")
} }
func (b *boolInterfaceImpl) IS_FALSE() BoolExpression { func (b *boolInterfaceImpl) IS_FALSE() BoolExpression {
return newPostifxBoolExpression(b.parent, "IS FALSE") return newPostfixBoolOperatorExpression(b.parent, "IS FALSE")
} }
func (b *boolInterfaceImpl) IS_NOT_FALSE() BoolExpression { func (b *boolInterfaceImpl) IS_NOT_FALSE() BoolExpression {
return newPostifxBoolExpression(b.parent, "IS NOT FALSE") return newPostfixBoolOperatorExpression(b.parent, "IS NOT FALSE")
} }
func (b *boolInterfaceImpl) IS_UNKNOWN() BoolExpression { func (b *boolInterfaceImpl) IS_UNKNOWN() BoolExpression {
return newPostifxBoolExpression(b.parent, "IS UNKNOWN") return newPostfixBoolOperatorExpression(b.parent, "IS UNKNOWN")
} }
func (b *boolInterfaceImpl) IS_NOT_UNKNOWN() BoolExpression { func (b *boolInterfaceImpl) IS_NOT_UNKNOWN() BoolExpression {
return newPostifxBoolExpression(b.parent, "IS NOT UNKNOWN") return newPostfixBoolOperatorExpression(b.parent, "IS NOT UNKNOWN")
} }
//---------------------------------------------------// //---------------------------------------------------//
type binaryBoolExpression struct { func newBinaryBoolOperatorExpression(lhs, rhs Expression, operator string, additionalParams ...Expression) BoolExpression {
expressionInterfaceImpl return BoolExp(NewBinaryOperatorExpression(lhs, rhs, operator, additionalParams...))
boolInterfaceImpl
binaryOpExpression
}
func newBinaryBoolOperator(lhs, rhs Expression, operator string, additionalParams ...Expression) BoolExpression {
binaryBoolExpression := binaryBoolExpression{}
binaryBoolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator, additionalParams...)
binaryBoolExpression.expressionInterfaceImpl.Parent = &binaryBoolExpression
binaryBoolExpression.boolInterfaceImpl.parent = &binaryBoolExpression
return &binaryBoolExpression
} }
//---------------------------------------------------// //---------------------------------------------------//
type prefixBoolExpression struct { func newPrefixBoolOperatorExpression(expression Expression, operator string) BoolExpression {
expressionInterfaceImpl return BoolExp(newPrefixOperatorExpression(expression, operator))
boolInterfaceImpl
prefixOpExpression
}
func newPrefixBoolOperator(expression Expression, operator string) BoolExpression {
exp := prefixBoolExpression{}
exp.prefixOpExpression = newPrefixExpression(expression, operator)
exp.expressionInterfaceImpl.Parent = &exp
exp.boolInterfaceImpl.parent = &exp
return &exp
} }
//---------------------------------------------------// //---------------------------------------------------//
type postfixBoolOpExpression struct { func newPostfixBoolOperatorExpression(expression Expression, operator string) BoolExpression {
expressionInterfaceImpl return BoolExp(newPostfixOperatorExpression(expression, operator))
boolInterfaceImpl
postfixOpExpression
}
func newPostifxBoolExpression(expression Expression, operator string) BoolExpression {
exp := postfixBoolOpExpression{}
exp.postfixOpExpression = newPostfixOpExpression(expression, operator)
exp.expressionInterfaceImpl.Parent = &exp
exp.boolInterfaceImpl.parent = &exp
return &exp
} }
//---------------------------------------------------// //---------------------------------------------------//
type boolExpressionWrapper struct { type boolExpressionWrapper struct {
boolInterfaceImpl boolInterfaceImpl
Expression Expression

View file

@ -24,13 +24,13 @@ func (b *castImpl) AS(castType string) Expression {
cast: string(castType), cast: string(castType),
} }
castExp.expressionInterfaceImpl.Parent = castExp castExp.ExpressionInterfaceImpl.Parent = castExp
return castExp return castExp
} }
type castExpression struct { type castExpression struct {
expressionInterfaceImpl ExpressionInterfaceImpl
expression Expression expression Expression
cast string cast string

View file

@ -1,7 +1,7 @@
package jet package jet
import ( import (
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
) )

View file

@ -18,9 +18,9 @@ type ColumnExpression interface {
Expression Expression
} }
// The base type for real materialized columns. // ColumnExpressionImpl is base type for sql columns.
type columnImpl struct { type ColumnExpressionImpl struct {
expressionInterfaceImpl ExpressionInterfaceImpl
name string name string
tableName string tableName string
@ -28,34 +28,41 @@ type columnImpl struct {
subQuery SelectTable subQuery SelectTable
} }
func newColumn(name string, tableName string, parent ColumnExpression) columnImpl { // NewColumnImpl creates new ColumnExpressionImpl
bc := columnImpl{ func NewColumnImpl(name string, tableName string, parent ColumnExpression) ColumnExpressionImpl {
bc := ColumnExpressionImpl{
name: name, name: name,
tableName: tableName, tableName: tableName,
} }
bc.expressionInterfaceImpl.Parent = parent if parent != nil {
bc.ExpressionInterfaceImpl.Parent = parent
} else {
bc.ExpressionInterfaceImpl.Parent = &bc
}
return bc return bc
} }
func (c *columnImpl) Name() string { // Name returns name of the column
func (c *ColumnExpressionImpl) Name() string {
return c.name return c.name
} }
func (c *columnImpl) TableName() string { // TableName returns column table name
func (c *ColumnExpressionImpl) TableName() string {
return c.tableName return c.tableName
} }
func (c *columnImpl) setTableName(table string) { func (c *ColumnExpressionImpl) setTableName(table string) {
c.tableName = table c.tableName = table
} }
func (c *columnImpl) setSubQuery(subQuery SelectTable) { func (c *ColumnExpressionImpl) setSubQuery(subQuery SelectTable) {
c.subQuery = subQuery c.subQuery = subQuery
} }
func (c *columnImpl) defaultAlias() string { func (c *ColumnExpressionImpl) defaultAlias() string {
if c.tableName != "" { if c.tableName != "" {
return c.tableName + "." + c.name return c.tableName + "." + c.name
} }
@ -63,25 +70,31 @@ func (c *columnImpl) defaultAlias() string {
return c.name return c.name
} }
func (c *columnImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { func (c *ColumnExpressionImpl) fromImpl(subQuery SelectTable) Projection {
newColumn := NewColumnImpl(c.name, c.tableName, nil)
newColumn.setSubQuery(subQuery)
return &newColumn
}
func (c *ColumnExpressionImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) {
if statement == SetStatementType { if statement == SetStatementType {
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause // set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
out.WriteAlias(c.defaultAlias()) //always quote out.WriteAlias(c.defaultAlias()) //always quote
return return
} }
c.serialize(statement, out) c.serialize(statement, out)
} }
func (c columnImpl) serializeForProjection(statement StatementType, out *SQLBuilder) { func (c ColumnExpressionImpl) serializeForProjection(statement StatementType, out *SQLBuilder) {
c.serialize(statement, out) c.serialize(statement, out)
out.WriteString("AS") out.WriteString("AS")
out.WriteAlias(c.defaultAlias()) out.WriteAlias(c.defaultAlias())
} }
func (c columnImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if c.subQuery != nil { if c.subQuery != nil {
out.WriteIdentifier(c.subQuery.Alias()) out.WriteIdentifier(c.subQuery.Alias())
@ -128,3 +141,13 @@ func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {} func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery SelectTable) {} func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
func (cl ColumnList) defaultAlias() string { return "" } func (cl ColumnList) defaultAlias() string { return "" }
// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName
func SetTableName(columnExpression ColumnExpression, tableName string) {
columnExpression.setTableName(tableName)
}
// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery
func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) {
columnExpression.setSubQuery(subQuery)
}

View file

@ -3,8 +3,8 @@ package jet
import "testing" import "testing"
func TestColumn(t *testing.T) { func TestColumn(t *testing.T) {
column := newColumn("col", "", nil) column := NewColumnImpl("col", "", nil)
column.expressionInterfaceImpl.Parent = &column column.ExpressionInterfaceImpl.Parent = &column
assertClauseSerialize(t, column, "col") assertClauseSerialize(t, column, "col")
column.setTableName("table1") column.setTableName("table1")

View file

@ -10,11 +10,10 @@ type ColumnBool interface {
type boolColumnImpl struct { type boolColumnImpl struct {
boolInterfaceImpl boolInterfaceImpl
ColumnExpressionImpl
columnImpl
} }
func (i *boolColumnImpl) fromImpl(subQuery SelectTable) Projection { func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
newBoolColumn := BoolColumn(i.name) newBoolColumn := BoolColumn(i.name)
newBoolColumn.setTableName(i.tableName) newBoolColumn.setTableName(i.tableName)
newBoolColumn.setSubQuery(subQuery) newBoolColumn.setSubQuery(subQuery)
@ -22,16 +21,10 @@ func (i *boolColumnImpl) fromImpl(subQuery SelectTable) Projection {
return newBoolColumn return newBoolColumn
} }
func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
newBoolColumn := i.fromImpl(subQuery).(ColumnBool)
return newBoolColumn
}
// BoolColumn creates named bool column. // BoolColumn creates named bool column.
func BoolColumn(name string) ColumnBool { func BoolColumn(name string) ColumnBool {
boolColumn := &boolColumnImpl{} boolColumn := &boolColumnImpl{}
boolColumn.columnImpl = newColumn(name, "", boolColumn) boolColumn.ColumnExpressionImpl = NewColumnImpl(name, "", boolColumn)
boolColumn.boolInterfaceImpl.parent = boolColumn boolColumn.boolInterfaceImpl.parent = boolColumn
return boolColumn return boolColumn
@ -49,19 +42,13 @@ type ColumnFloat interface {
type floatColumnImpl struct { type floatColumnImpl struct {
floatInterfaceImpl floatInterfaceImpl
columnImpl ColumnExpressionImpl
}
func (i *floatColumnImpl) fromImpl(subQuery SelectTable) Projection {
newFloatColumn := FloatColumn(i.name)
newFloatColumn.setTableName(i.tableName)
newFloatColumn.setSubQuery(subQuery)
return newFloatColumn
} }
func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat { func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
newFloatColumn := i.fromImpl(subQuery).(ColumnFloat) newFloatColumn := FloatColumn(i.name)
newFloatColumn.setTableName(i.tableName)
newFloatColumn.setSubQuery(subQuery)
return newFloatColumn return newFloatColumn
} }
@ -70,7 +57,7 @@ func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
func FloatColumn(name string) ColumnFloat { func FloatColumn(name string) ColumnFloat {
floatColumn := &floatColumnImpl{} floatColumn := &floatColumnImpl{}
floatColumn.floatInterfaceImpl.parent = floatColumn floatColumn.floatInterfaceImpl.parent = floatColumn
floatColumn.columnImpl = newColumn(name, "", floatColumn) floatColumn.ColumnExpressionImpl = NewColumnImpl(name, "", floatColumn)
return floatColumn return floatColumn
} }
@ -88,10 +75,10 @@ type ColumnInteger interface {
type integerColumnImpl struct { type integerColumnImpl struct {
integerInterfaceImpl integerInterfaceImpl
columnImpl ColumnExpressionImpl
} }
func (i *integerColumnImpl) fromImpl(subQuery SelectTable) Projection { func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
newIntColumn := IntegerColumn(i.name) newIntColumn := IntegerColumn(i.name)
newIntColumn.setTableName(i.tableName) newIntColumn.setTableName(i.tableName)
newIntColumn.setSubQuery(subQuery) newIntColumn.setSubQuery(subQuery)
@ -99,15 +86,11 @@ func (i *integerColumnImpl) fromImpl(subQuery SelectTable) Projection {
return newIntColumn return newIntColumn
} }
func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
return i.fromImpl(subQuery).(ColumnInteger)
}
// IntegerColumn creates named integer column. // IntegerColumn creates named integer column.
func IntegerColumn(name string) ColumnInteger { func IntegerColumn(name string) ColumnInteger {
integerColumn := &integerColumnImpl{} integerColumn := &integerColumnImpl{}
integerColumn.integerInterfaceImpl.parent = integerColumn integerColumn.integerInterfaceImpl.parent = integerColumn
integerColumn.columnImpl = newColumn(name, "", integerColumn) integerColumn.ColumnExpressionImpl = NewColumnImpl(name, "", integerColumn)
return integerColumn return integerColumn
} }
@ -126,10 +109,10 @@ type ColumnString interface {
type stringColumnImpl struct { type stringColumnImpl struct {
stringInterfaceImpl stringInterfaceImpl
columnImpl ColumnExpressionImpl
} }
func (i *stringColumnImpl) fromImpl(subQuery SelectTable) Projection { func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
newStrColumn := StringColumn(i.name) newStrColumn := StringColumn(i.name)
newStrColumn.setTableName(i.tableName) newStrColumn.setTableName(i.tableName)
newStrColumn.setSubQuery(subQuery) newStrColumn.setSubQuery(subQuery)
@ -137,15 +120,11 @@ func (i *stringColumnImpl) fromImpl(subQuery SelectTable) Projection {
return newStrColumn return newStrColumn
} }
func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
return i.fromImpl(subQuery).(ColumnString)
}
// StringColumn creates named string column. // StringColumn creates named string column.
func StringColumn(name string) ColumnString { func StringColumn(name string) ColumnString {
stringColumn := &stringColumnImpl{} stringColumn := &stringColumnImpl{}
stringColumn.stringInterfaceImpl.parent = stringColumn stringColumn.stringInterfaceImpl.parent = stringColumn
stringColumn.columnImpl = newColumn(name, "", stringColumn) stringColumn.ColumnExpressionImpl = NewColumnImpl(name, "", stringColumn)
return stringColumn return stringColumn
} }
@ -162,10 +141,10 @@ type ColumnTime interface {
type timeColumnImpl struct { type timeColumnImpl struct {
timeInterfaceImpl timeInterfaceImpl
columnImpl ColumnExpressionImpl
} }
func (i *timeColumnImpl) fromImpl(subQuery SelectTable) Projection { func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
newTimeColumn := TimeColumn(i.name) newTimeColumn := TimeColumn(i.name)
newTimeColumn.setTableName(i.tableName) newTimeColumn.setTableName(i.tableName)
newTimeColumn.setSubQuery(subQuery) newTimeColumn.setSubQuery(subQuery)
@ -173,15 +152,11 @@ func (i *timeColumnImpl) fromImpl(subQuery SelectTable) Projection {
return newTimeColumn return newTimeColumn
} }
func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
return i.fromImpl(subQuery).(ColumnTime)
}
// TimeColumn creates named time column // TimeColumn creates named time column
func TimeColumn(name string) ColumnTime { func TimeColumn(name string) ColumnTime {
timeColumn := &timeColumnImpl{} timeColumn := &timeColumnImpl{}
timeColumn.timeInterfaceImpl.parent = timeColumn timeColumn.timeInterfaceImpl.parent = timeColumn
timeColumn.columnImpl = newColumn(name, "", timeColumn) timeColumn.ColumnExpressionImpl = NewColumnImpl(name, "", timeColumn)
return timeColumn return timeColumn
} }
@ -197,11 +172,10 @@ type ColumnTimez interface {
type timezColumnImpl struct { type timezColumnImpl struct {
timezInterfaceImpl timezInterfaceImpl
ColumnExpressionImpl
columnImpl
} }
func (i *timezColumnImpl) fromImpl(subQuery SelectTable) Projection { func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
newTimezColumn := TimezColumn(i.name) newTimezColumn := TimezColumn(i.name)
newTimezColumn.setTableName(i.tableName) newTimezColumn.setTableName(i.tableName)
newTimezColumn.setSubQuery(subQuery) newTimezColumn.setSubQuery(subQuery)
@ -209,15 +183,11 @@ func (i *timezColumnImpl) fromImpl(subQuery SelectTable) Projection {
return newTimezColumn return newTimezColumn
} }
func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
return i.fromImpl(subQuery).(ColumnTimez)
}
// TimezColumn creates named time with time zone column. // TimezColumn creates named time with time zone column.
func TimezColumn(name string) ColumnTimez { func TimezColumn(name string) ColumnTimez {
timezColumn := &timezColumnImpl{} timezColumn := &timezColumnImpl{}
timezColumn.timezInterfaceImpl.parent = timezColumn timezColumn.timezInterfaceImpl.parent = timezColumn
timezColumn.columnImpl = newColumn(name, "", timezColumn) timezColumn.ColumnExpressionImpl = NewColumnImpl(name, "", timezColumn)
return timezColumn return timezColumn
} }
@ -234,11 +204,10 @@ type ColumnTimestamp interface {
type timestampColumnImpl struct { type timestampColumnImpl struct {
timestampInterfaceImpl timestampInterfaceImpl
ColumnExpressionImpl
columnImpl
} }
func (i *timestampColumnImpl) fromImpl(subQuery SelectTable) Projection { func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
newTimestampColumn := TimestampColumn(i.name) newTimestampColumn := TimestampColumn(i.name)
newTimestampColumn.setTableName(i.tableName) newTimestampColumn.setTableName(i.tableName)
newTimestampColumn.setSubQuery(subQuery) newTimestampColumn.setSubQuery(subQuery)
@ -246,15 +215,11 @@ func (i *timestampColumnImpl) fromImpl(subQuery SelectTable) Projection {
return newTimestampColumn return newTimestampColumn
} }
func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
return i.fromImpl(subQuery).(ColumnTimestamp)
}
// TimestampColumn creates named timestamp column // TimestampColumn creates named timestamp column
func TimestampColumn(name string) ColumnTimestamp { func TimestampColumn(name string) ColumnTimestamp {
timestampColumn := &timestampColumnImpl{} timestampColumn := &timestampColumnImpl{}
timestampColumn.timestampInterfaceImpl.parent = timestampColumn timestampColumn.timestampInterfaceImpl.parent = timestampColumn
timestampColumn.columnImpl = newColumn(name, "", timestampColumn) timestampColumn.ColumnExpressionImpl = NewColumnImpl(name, "", timestampColumn)
return timestampColumn return timestampColumn
} }
@ -271,11 +236,10 @@ type ColumnTimestampz interface {
type timestampzColumnImpl struct { type timestampzColumnImpl struct {
timestampzInterfaceImpl timestampzInterfaceImpl
ColumnExpressionImpl
columnImpl
} }
func (i *timestampzColumnImpl) fromImpl(subQuery SelectTable) Projection { func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
newTimestampzColumn := TimestampzColumn(i.name) newTimestampzColumn := TimestampzColumn(i.name)
newTimestampzColumn.setTableName(i.tableName) newTimestampzColumn.setTableName(i.tableName)
newTimestampzColumn.setSubQuery(subQuery) newTimestampzColumn.setSubQuery(subQuery)
@ -283,15 +247,11 @@ func (i *timestampzColumnImpl) fromImpl(subQuery SelectTable) Projection {
return newTimestampzColumn return newTimestampzColumn
} }
func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
return i.fromImpl(subQuery).(ColumnTimestampz)
}
// TimestampzColumn creates named timestamp with time zone column. // TimestampzColumn creates named timestamp with time zone column.
func TimestampzColumn(name string) ColumnTimestampz { func TimestampzColumn(name string) ColumnTimestampz {
timestampzColumn := &timestampzColumnImpl{} timestampzColumn := &timestampzColumnImpl{}
timestampzColumn.timestampzInterfaceImpl.parent = timestampzColumn timestampzColumn.timestampzInterfaceImpl.parent = timestampzColumn
timestampzColumn.columnImpl = newColumn(name, "", timestampzColumn) timestampzColumn.ColumnExpressionImpl = NewColumnImpl(name, "", timestampzColumn)
return timestampzColumn return timestampzColumn
} }
@ -308,11 +268,10 @@ type ColumnDate interface {
type dateColumnImpl struct { type dateColumnImpl struct {
dateInterfaceImpl dateInterfaceImpl
ColumnExpressionImpl
columnImpl
} }
func (i *dateColumnImpl) fromImpl(subQuery SelectTable) Projection { func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
newDateColumn := DateColumn(i.name) newDateColumn := DateColumn(i.name)
newDateColumn.setTableName(i.tableName) newDateColumn.setTableName(i.tableName)
newDateColumn.setSubQuery(subQuery) newDateColumn.setSubQuery(subQuery)
@ -320,14 +279,10 @@ func (i *dateColumnImpl) fromImpl(subQuery SelectTable) Projection {
return newDateColumn return newDateColumn
} }
func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
return i.fromImpl(subQuery).(ColumnDate)
}
// DateColumn creates named date column. // DateColumn creates named date column.
func DateColumn(name string) ColumnDate { func DateColumn(name string) ColumnDate {
dateColumn := &dateColumnImpl{} dateColumn := &dateColumnImpl{}
dateColumn.dateInterfaceImpl.parent = dateColumn dateColumn.dateInterfaceImpl.parent = dateColumn
dateColumn.columnImpl = newColumn(name, "", dateColumn) dateColumn.ColumnExpressionImpl = NewColumnImpl(name, "", dateColumn)
return dateColumn return dateColumn
} }

View file

@ -43,5 +43,74 @@ func TestNewFloatColumnColumn(t *testing.T) {
assertClauseSerialize(t, floatColumn2, `sub_query."table1.col_float"`) assertClauseSerialize(t, floatColumn2, `sub_query."table1.col_float"`)
assertClauseSerialize(t, floatColumn2.EQ(Float(2.22)), `(sub_query."table1.col_float" = $1)`, float64(2.22)) assertClauseSerialize(t, floatColumn2.EQ(Float(2.22)), `(sub_query."table1.col_float" = $1)`, float64(2.22))
assertProjectionSerialize(t, floatColumn2, `sub_query."table1.col_float" AS "table1.col_float"`) assertProjectionSerialize(t, floatColumn2, `sub_query."table1.col_float" AS "table1.col_float"`)
}
func TestNewDateColumnColumn(t *testing.T) {
dateColumn := DateColumn("col_date").From(subQuery)
assertClauseSerialize(t, dateColumn, `sub_query."col_date"`)
assertClauseSerialize(t, dateColumn.EQ(Date(2002, 2, 3)),
`(sub_query."col_date" = $1)`, "2002-02-03")
assertProjectionSerialize(t, dateColumn, `sub_query."col_date" AS "col_date"`)
dateColumn2 := table1ColDate.From(subQuery)
assertClauseSerialize(t, dateColumn2, `sub_query."table1.col_date"`)
assertClauseSerialize(t, dateColumn2.EQ(Date(2002, 2, 3)),
`(sub_query."table1.col_date" = $1)`, "2002-02-03")
assertProjectionSerialize(t, dateColumn2, `sub_query."table1.col_date" AS "table1.col_date"`)
}
func TestNewTimeColumnColumn(t *testing.T) {
timeColumn := TimeColumn("col_time").From(subQuery)
assertClauseSerialize(t, timeColumn, `sub_query."col_time"`)
assertClauseSerialize(t, timeColumn.EQ(Time(1, 1, 1, 1)),
`(sub_query."col_time" = $1)`, "01:01:01.000000001")
assertProjectionSerialize(t, timeColumn, `sub_query."col_time" AS "col_time"`)
timeColumn2 := table1ColTime.From(subQuery)
assertClauseSerialize(t, timeColumn2, `sub_query."table1.col_time"`)
assertClauseSerialize(t, timeColumn2.EQ(Time(2, 2, 2)),
`(sub_query."table1.col_time" = $1)`, "02:02:02")
assertProjectionSerialize(t, timeColumn2, `sub_query."table1.col_time" AS "table1.col_time"`)
}
func TestNewTimezColumnColumn(t *testing.T) {
timezColumn := TimezColumn("col_timez").From(subQuery)
assertClauseSerialize(t, timezColumn, `sub_query."col_timez"`)
assertClauseSerialize(t, timezColumn.EQ(Timez(1, 1, 1, 1, "UTC")),
`(sub_query."col_timez" = $1)`, "01:01:01.000000001 UTC")
assertProjectionSerialize(t, timezColumn, `sub_query."col_timez" AS "col_timez"`)
timezColumn2 := table1ColTimez.From(subQuery)
assertClauseSerialize(t, timezColumn2, `sub_query."table1.col_timez"`)
assertClauseSerialize(t, timezColumn2.EQ(Timez(2, 2, 2, 0, "UTC")),
`(sub_query."table1.col_timez" = $1)`, "02:02:02 UTC")
assertProjectionSerialize(t, timezColumn2, `sub_query."table1.col_timez" AS "table1.col_timez"`)
}
func TestNewTimestampColumnColumn(t *testing.T) {
timestampColumn := TimestampColumn("col_timestamp").From(subQuery)
assertClauseSerialize(t, timestampColumn, `sub_query."col_timestamp"`)
assertClauseSerialize(t, timestampColumn.EQ(Timestamp(1, 1, 1, 1, 1, 1)),
`(sub_query."col_timestamp" = $1)`, "0001-01-01 01:01:01")
assertProjectionSerialize(t, timestampColumn, `sub_query."col_timestamp" AS "col_timestamp"`)
timestampColumn2 := table1ColTimestamp.From(subQuery)
assertClauseSerialize(t, timestampColumn2, `sub_query."table1.col_timestamp"`)
assertClauseSerialize(t, timestampColumn2.EQ(Timestamp(2, 2, 2, 2, 2, 2)),
`(sub_query."table1.col_timestamp" = $1)`, "0002-02-02 02:02:02")
assertProjectionSerialize(t, timestampColumn2, `sub_query."table1.col_timestamp" AS "table1.col_timestamp"`)
}
func TestNewTimestampzColumnColumn(t *testing.T) {
timestampzColumn := TimestampzColumn("col_timestampz").From(subQuery)
assertClauseSerialize(t, timestampzColumn, `sub_query."col_timestampz"`)
assertClauseSerialize(t, timestampzColumn.EQ(Timestampz(1, 1, 1, 1, 1, 1, 0, "UTC")),
`(sub_query."col_timestampz" = $1)`, "0001-01-01 01:01:01 UTC")
assertProjectionSerialize(t, timestampzColumn, `sub_query."col_timestampz" AS "col_timestampz"`)
timestampzColumn2 := table1ColTimestampz.From(subQuery)
assertClauseSerialize(t, timestampzColumn2, `sub_query."table1.col_timestampz"`)
assertClauseSerialize(t, timestampzColumn2.EQ(Timestampz(2, 2, 2, 2, 2, 2, 0, "UTC")),
`(sub_query."table1.col_timestampz" = $1)`, "0002-02-02 02:02:02 UTC")
assertProjectionSerialize(t, timestampzColumn2, `sub_query."table1.col_timestampz" AS "table1.col_timestampz"`)
} }

View file

@ -13,42 +13,53 @@ type DateExpression interface {
LT_EQ(rhs DateExpression) BoolExpression LT_EQ(rhs DateExpression) BoolExpression
GT(rhs DateExpression) BoolExpression GT(rhs DateExpression) BoolExpression
GT_EQ(rhs DateExpression) BoolExpression GT_EQ(rhs DateExpression) BoolExpression
ADD(rhs Interval) TimestampExpression
SUB(rhs Interval) TimestampExpression
} }
type dateInterfaceImpl struct { type dateInterfaceImpl struct {
parent DateExpression parent DateExpression
} }
func (t *dateInterfaceImpl) EQ(rhs DateExpression) BoolExpression { func (d *dateInterfaceImpl) EQ(rhs DateExpression) BoolExpression {
return eq(t.parent, rhs) return Eq(d.parent, rhs)
} }
func (t *dateInterfaceImpl) NOT_EQ(rhs DateExpression) BoolExpression { func (d *dateInterfaceImpl) NOT_EQ(rhs DateExpression) BoolExpression {
return notEq(t.parent, rhs) return NotEq(d.parent, rhs)
} }
func (t *dateInterfaceImpl) IS_DISTINCT_FROM(rhs DateExpression) BoolExpression { func (d *dateInterfaceImpl) IS_DISTINCT_FROM(rhs DateExpression) BoolExpression {
return isDistinctFrom(t.parent, rhs) return IsDistinctFrom(d.parent, rhs)
} }
func (t *dateInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs DateExpression) BoolExpression { func (d *dateInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs DateExpression) BoolExpression {
return isNotDistinctFrom(t.parent, rhs) return IsNotDistinctFrom(d.parent, rhs)
} }
func (t *dateInterfaceImpl) LT(rhs DateExpression) BoolExpression { func (d *dateInterfaceImpl) LT(rhs DateExpression) BoolExpression {
return lt(t.parent, rhs) return Lt(d.parent, rhs)
} }
func (t *dateInterfaceImpl) LT_EQ(rhs DateExpression) BoolExpression { func (d *dateInterfaceImpl) LT_EQ(rhs DateExpression) BoolExpression {
return ltEq(t.parent, rhs) return LtEq(d.parent, rhs)
} }
func (t *dateInterfaceImpl) GT(rhs DateExpression) BoolExpression { func (d *dateInterfaceImpl) GT(rhs DateExpression) BoolExpression {
return gt(t.parent, rhs) return Gt(d.parent, rhs)
} }
func (t *dateInterfaceImpl) GT_EQ(rhs DateExpression) BoolExpression { func (d *dateInterfaceImpl) GT_EQ(rhs DateExpression) BoolExpression {
return gtEq(t.parent, rhs) return GtEq(d.parent, rhs)
}
func (d *dateInterfaceImpl) ADD(rhs Interval) TimestampExpression {
return TimestampExp(Add(d.parent, rhs))
}
func (d *dateInterfaceImpl) SUB(rhs Interval) TimestampExpression {
return TimestampExp(Sub(d.parent, rhs))
} }
//---------------------------------------------------// //---------------------------------------------------//

View file

@ -0,0 +1,13 @@
package jet
import (
"testing"
)
func TestDateArithmetic(t *testing.T) {
timestamp := Timestamp(2000, 1, 1, 0, 0, 0)
assertClauseDebugSerialize(t, table1ColDate.ADD(NewInterval(String("1 HOUR"))).EQ(timestamp),
"((table1.col_date + INTERVAL '1 HOUR') = '2000-01-01 00:00:00')")
assertClauseDebugSerialize(t, table1ColDate.SUB(NewInterval(String("1 HOUR"))).EQ(timestamp),
"((table1.col_date - INTERVAL '1 HOUR') = '2000-01-01 00:00:00')")
}

View file

@ -1,5 +1,7 @@
package jet package jet
import "strings"
// Dialect interface // Dialect interface
type Dialect interface { type Dialect interface {
Name() string Name() string
@ -9,13 +11,14 @@ type Dialect interface {
AliasQuoteChar() byte AliasQuoteChar() byte
IdentifierQuoteChar() byte IdentifierQuoteChar() byte
ArgumentPlaceholder() QueryPlaceholderFunc ArgumentPlaceholder() QueryPlaceholderFunc
IsReservedWord(name string) bool
} }
// SerializeFunc func // SerializerFunc func
type SerializeFunc func(statement StatementType, out *SQLBuilder, options ...SerializeOption) type SerializerFunc func(statement StatementType, out *SQLBuilder, options ...SerializeOption)
// SerializeOverride func // SerializeOverride func
type SerializeOverride func(expressions ...Expression) SerializeFunc type SerializeOverride func(expressions ...Serializer) SerializerFunc
// QueryPlaceholderFunc func // QueryPlaceholderFunc func
type QueryPlaceholderFunc func(ord int) string type QueryPlaceholderFunc func(ord int) string
@ -29,6 +32,7 @@ type DialectParams struct {
AliasQuoteChar byte AliasQuoteChar byte
IdentifierQuoteChar byte IdentifierQuoteChar byte
ArgumentPlaceholder QueryPlaceholderFunc ArgumentPlaceholder QueryPlaceholderFunc
ReservedWords []string
} }
// NewDialect creates new dialect with params // NewDialect creates new dialect with params
@ -41,6 +45,7 @@ func NewDialect(params DialectParams) Dialect {
aliasQuoteChar: params.AliasQuoteChar, aliasQuoteChar: params.AliasQuoteChar,
identifierQuoteChar: params.IdentifierQuoteChar, identifierQuoteChar: params.IdentifierQuoteChar,
argumentPlaceholder: params.ArgumentPlaceholder, argumentPlaceholder: params.ArgumentPlaceholder,
reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords),
} }
} }
@ -52,6 +57,7 @@ type dialectImpl struct {
aliasQuoteChar byte aliasQuoteChar byte
identifierQuoteChar byte identifierQuoteChar byte
argumentPlaceholder QueryPlaceholderFunc argumentPlaceholder QueryPlaceholderFunc
reservedWords map[string]bool
supportsReturning bool supportsReturning bool
} }
@ -89,3 +95,17 @@ func (d *dialectImpl) IdentifierQuoteChar() byte {
func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc { func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc {
return d.argumentPlaceholder return d.argumentPlaceholder
} }
func (d *dialectImpl) IsReservedWord(name string) bool {
_, isReservedWord := d.reservedWords[strings.ToLower(name)]
return isReservedWord
}
func arrayOfStringsToMapOfStrings(arr []string) map[string]bool {
ret := map[string]bool{}
for _, elem := range arr {
ret[strings.ToLower(elem)] = true
}
return ret
}

View file

@ -1,7 +1,7 @@
package jet package jet
type enumValue struct { type enumValue struct {
expressionInterfaceImpl ExpressionInterfaceImpl
stringInterfaceImpl stringInterfaceImpl
name string name string
@ -11,7 +11,7 @@ type enumValue struct {
func NewEnumValue(name string) StringExpression { func NewEnumValue(name string) StringExpression {
enumValue := &enumValue{name: name} enumValue := &enumValue{name: name}
enumValue.expressionInterfaceImpl.Parent = enumValue enumValue.ExpressionInterfaceImpl.Parent = enumValue
enumValue.stringInterfaceImpl.parent = enumValue enumValue.stringInterfaceImpl.parent = enumValue
return enumValue return enumValue

View file

@ -8,82 +8,93 @@ type Expression interface {
GroupByClause GroupByClause
OrderByClause OrderByClause
// Test expression whether it is a NULL value. // IS_NULL tests expression whether it is a NULL value.
IS_NULL() BoolExpression IS_NULL() BoolExpression
// Test expression whether it is a non-NULL value. // IS_NOT_NULL tests expression whether it is a non-NULL value.
IS_NOT_NULL() BoolExpression IS_NOT_NULL() BoolExpression
// Check if this expressions matches any in expressions list // IN checks if this expressions matches any in expressions list
IN(expressions ...Expression) BoolExpression IN(expressions ...Expression) BoolExpression
// Check if this expressions is different of all expressions in expressions list // NOT_IN checks if this expressions is different of all expressions in expressions list
NOT_IN(expressions ...Expression) BoolExpression NOT_IN(expressions ...Expression) BoolExpression
// The temporary alias name to assign to the expression // AS the temporary alias name to assign to the expression
AS(alias string) Projection AS(alias string) Projection
// Expression will be used to sort query result in ascending order // ASC expression will be used to sort query result in ascending order
ASC() OrderByClause ASC() OrderByClause
// Expression will be used to sort query result in ascending order // DESC expression will be used to sort query result in ascending order
DESC() OrderByClause DESC() OrderByClause
} }
type expressionInterfaceImpl struct { // ExpressionInterfaceImpl implements Expression interface methods
type ExpressionInterfaceImpl struct {
Parent Expression Parent Expression
} }
func (e *expressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection { func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection {
return e.Parent return e.Parent
} }
func (e *expressionInterfaceImpl) IS_NULL() BoolExpression { // IS_NULL tests expression whether it is a NULL value.
return newPostifxBoolExpression(e.Parent, "IS NULL") func (e *ExpressionInterfaceImpl) IS_NULL() BoolExpression {
return newPostfixBoolOperatorExpression(e.Parent, "IS NULL")
} }
func (e *expressionInterfaceImpl) IS_NOT_NULL() BoolExpression { // IS_NOT_NULL tests expression whether it is a non-NULL value.
return newPostifxBoolExpression(e.Parent, "IS NOT NULL") func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression {
return newPostfixBoolOperatorExpression(e.Parent, "IS NOT NULL")
} }
func (e *expressionInterfaceImpl) IN(expressions ...Expression) BoolExpression { // IN checks if this expressions matches any in expressions list
return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "IN") func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "IN")
} }
func (e *expressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression { // NOT_IN checks if this expressions is different of all expressions in expressions list
return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "NOT IN") func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression {
return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "NOT IN")
} }
func (e *expressionInterfaceImpl) AS(alias string) Projection { // AS the temporary alias name to assign to the expression
func (e *ExpressionInterfaceImpl) AS(alias string) Projection {
return newAlias(e.Parent, alias) return newAlias(e.Parent, alias)
} }
func (e *expressionInterfaceImpl) ASC() OrderByClause { // ASC expression will be used to sort query result in ascending order
func (e *ExpressionInterfaceImpl) ASC() OrderByClause {
return newOrderByClause(e.Parent, true) return newOrderByClause(e.Parent, true)
} }
func (e *expressionInterfaceImpl) DESC() OrderByClause { // DESC expression will be used to sort query result in ascending order
func (e *ExpressionInterfaceImpl) DESC() OrderByClause {
return newOrderByClause(e.Parent, false) return newOrderByClause(e.Parent, false)
} }
func (e *expressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SQLBuilder) { func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SQLBuilder) {
e.Parent.serialize(statement, out, noWrap) e.Parent.serialize(statement, out, noWrap)
} }
func (e *expressionInterfaceImpl) serializeForProjection(statement StatementType, out *SQLBuilder) { func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType, out *SQLBuilder) {
e.Parent.serialize(statement, out, noWrap) e.Parent.serialize(statement, out, noWrap)
} }
func (e *expressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) {
e.Parent.serialize(statement, out, noWrap) e.Parent.serialize(statement, out, noWrap)
} }
// Representation of binary operations (e.g. comparisons, arithmetic) // Representation of binary operations (e.g. comparisons, arithmetic)
type binaryOpExpression struct { type binaryOperatorExpression struct {
lhs, rhs Expression ExpressionInterfaceImpl
additionalParam Expression
lhs, rhs Serializer
additionalParam Serializer
operator string operator string
} }
func newBinaryExpression(lhs, rhs Expression, operator string, additionalParam ...Expression) binaryOpExpression { // NewBinaryOperatorExpression creates new binaryOperatorExpression
binaryExpression := binaryOpExpression{ func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) *binaryOperatorExpression {
binaryExpression := &binaryOperatorExpression{
lhs: lhs, lhs: lhs,
rhs: rhs, rhs: rhs,
operator: operator, operator: operator,
@ -93,10 +104,12 @@ func newBinaryExpression(lhs, rhs Expression, operator string, additionalParam .
binaryExpression.additionalParam = additionalParam[0] binaryExpression.additionalParam = additionalParam[0]
} }
binaryExpression.ExpressionInterfaceImpl.Parent = binaryExpression
return binaryExpression return binaryExpression
} }
func (c *binaryOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if c.lhs == nil { if c.lhs == nil {
panic("jet: lhs is nil for '" + c.operator + "' operator") panic("jet: lhs is nil for '" + c.operator + "' operator")
} }
@ -125,21 +138,24 @@ func (c *binaryOpExpression) serialize(statement StatementType, out *SQLBuilder,
} }
// A prefix operator Expression // A prefix operator Expression
type prefixOpExpression struct { type prefixExpression struct {
ExpressionInterfaceImpl
expression Expression expression Expression
operator string operator string
} }
func newPrefixExpression(expression Expression, operator string) prefixOpExpression { func newPrefixOperatorExpression(expression Expression, operator string) *prefixExpression {
prefixExpression := prefixOpExpression{ prefixExpression := &prefixExpression{
expression: expression, expression: expression,
operator: operator, operator: operator,
} }
prefixExpression.ExpressionInterfaceImpl.Parent = prefixExpression
return prefixExpression return prefixExpression
} }
func (p *prefixOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (p *prefixExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(") out.WriteString("(")
out.WriteString(p.operator) out.WriteString(p.operator)
@ -152,18 +168,22 @@ func (p *prefixOpExpression) serialize(statement StatementType, out *SQLBuilder,
out.WriteString(")") out.WriteString(")")
} }
// A postifx operator Expression // A postfix operator Expression
type postfixOpExpression struct { type postfixOpExpression struct {
ExpressionInterfaceImpl
expression Expression expression Expression
operator string operator string
} }
func newPostfixOpExpression(expression Expression, operator string) postfixOpExpression { func newPostfixOperatorExpression(expression Expression, operator string) *postfixOpExpression {
postfixOpExpression := postfixOpExpression{ postfixOpExpression := &postfixOpExpression{
expression: expression, expression: expression,
operator: operator, operator: operator,
} }
postfixOpExpression.ExpressionInterfaceImpl.Parent = postfixOpExpression
return postfixOpExpression return postfixOpExpression
} }

View file

@ -29,78 +29,59 @@ type floatInterfaceImpl struct {
} }
func (n *floatInterfaceImpl) EQ(rhs FloatExpression) BoolExpression { func (n *floatInterfaceImpl) EQ(rhs FloatExpression) BoolExpression {
return eq(n.parent, rhs) return Eq(n.parent, rhs)
} }
func (n *floatInterfaceImpl) NOT_EQ(rhs FloatExpression) BoolExpression { func (n *floatInterfaceImpl) NOT_EQ(rhs FloatExpression) BoolExpression {
return notEq(n.parent, rhs) return NotEq(n.parent, rhs)
} }
func (n *floatInterfaceImpl) IS_DISTINCT_FROM(rhs FloatExpression) BoolExpression { func (n *floatInterfaceImpl) IS_DISTINCT_FROM(rhs FloatExpression) BoolExpression {
return isDistinctFrom(n.parent, rhs) return IsDistinctFrom(n.parent, rhs)
} }
func (n *floatInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs FloatExpression) BoolExpression { func (n *floatInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs FloatExpression) BoolExpression {
return isNotDistinctFrom(n.parent, rhs) return IsNotDistinctFrom(n.parent, rhs)
} }
func (n *floatInterfaceImpl) GT(rhs FloatExpression) BoolExpression { func (n *floatInterfaceImpl) GT(rhs FloatExpression) BoolExpression {
return gt(n.parent, rhs) return Gt(n.parent, rhs)
} }
func (n *floatInterfaceImpl) GT_EQ(rhs FloatExpression) BoolExpression { func (n *floatInterfaceImpl) GT_EQ(rhs FloatExpression) BoolExpression {
return gtEq(n.parent, rhs) return GtEq(n.parent, rhs)
} }
func (n *floatInterfaceImpl) LT(expression FloatExpression) BoolExpression { func (n *floatInterfaceImpl) LT(rhs FloatExpression) BoolExpression {
return lt(n.parent, expression) return Lt(n.parent, rhs)
} }
func (n *floatInterfaceImpl) LT_EQ(expression FloatExpression) BoolExpression { func (n *floatInterfaceImpl) LT_EQ(rhs FloatExpression) BoolExpression {
return ltEq(n.parent, expression) return LtEq(n.parent, rhs)
} }
func (n *floatInterfaceImpl) ADD(expression NumericExpression) FloatExpression { func (n *floatInterfaceImpl) ADD(rhs NumericExpression) FloatExpression {
return newBinaryFloatExpression(n.parent, expression, "+") return FloatExp(Add(n.parent, rhs))
} }
func (n *floatInterfaceImpl) SUB(expression NumericExpression) FloatExpression { func (n *floatInterfaceImpl) SUB(rhs NumericExpression) FloatExpression {
return newBinaryFloatExpression(n.parent, expression, "-") return FloatExp(Sub(n.parent, rhs))
} }
func (n *floatInterfaceImpl) MUL(expression NumericExpression) FloatExpression { func (n *floatInterfaceImpl) MUL(rhs NumericExpression) FloatExpression {
return newBinaryFloatExpression(n.parent, expression, "*") return FloatExp(Mul(n.parent, rhs))
} }
func (n *floatInterfaceImpl) DIV(expression NumericExpression) FloatExpression { func (n *floatInterfaceImpl) DIV(rhs NumericExpression) FloatExpression {
return newBinaryFloatExpression(n.parent, expression, "/") return FloatExp(Div(n.parent, rhs))
} }
func (n *floatInterfaceImpl) MOD(expression NumericExpression) FloatExpression { func (n *floatInterfaceImpl) MOD(rhs NumericExpression) FloatExpression {
return newBinaryFloatExpression(n.parent, expression, "%") return FloatExp(Mod(n.parent, rhs))
} }
func (n *floatInterfaceImpl) POW(expression NumericExpression) FloatExpression { func (n *floatInterfaceImpl) POW(rhs NumericExpression) FloatExpression {
return POW(n.parent, expression) return POW(n.parent, rhs)
}
//---------------------------------------------------//
type binaryFloatExpression struct {
expressionInterfaceImpl
floatInterfaceImpl
binaryOpExpression
}
func newBinaryFloatExpression(lhs, rhs Expression, operator string) FloatExpression {
floatExpression := binaryFloatExpression{}
floatExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
floatExpression.expressionInterfaceImpl.Parent = &floatExpression
floatExpression.floatInterfaceImpl.parent = &floatExpression
return &floatExpression
} }
//---------------------------------------------------// //---------------------------------------------------//

View file

@ -578,7 +578,7 @@ func LEAST(value Expression, values ...Expression) Expression {
//--------------------------------------------------------------------// //--------------------------------------------------------------------//
type funcExpressionImpl struct { type funcExpressionImpl struct {
expressionInterfaceImpl ExpressionInterfaceImpl
name string name string
expressions []Expression expressions []Expression
@ -592,9 +592,9 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr
} }
if parent != nil { if parent != nil {
funcExp.expressionInterfaceImpl.Parent = parent funcExp.ExpressionInterfaceImpl.Parent = parent
} else { } else {
funcExp.expressionInterfaceImpl.Parent = funcExp funcExp.ExpressionInterfaceImpl.Parent = funcExp
} }
return funcExp return funcExp
@ -605,14 +605,14 @@ func newWindowFunc(name string, expressions ...Expression) windowExpression {
newFun := newFunc(name, expressions, nil) newFun := newFunc(name, expressions, nil)
windowExpr := newWindowExpression(newFun) windowExpr := newWindowExpression(newFun)
newFun.expressionInterfaceImpl.Parent = windowExpr newFun.ExpressionInterfaceImpl.Parent = windowExpr
return windowExpr return windowExpr
} }
func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil { if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(f.expressions...) serializeOverrideFunc := serializeOverride(ExpressionListToSerializerList(f.expressions)...)
serializeOverrideFunc(statement, out, options...) serializeOverrideFunc(statement, out, options...)
return return
} }
@ -642,7 +642,7 @@ func newBoolFunc(name string, expressions ...Expression) BoolExpression {
boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc)
boolFunc.boolInterfaceImpl.parent = boolFunc boolFunc.boolInterfaceImpl.parent = boolFunc
boolFunc.expressionInterfaceImpl.Parent = boolFunc boolFunc.ExpressionInterfaceImpl.Parent = boolFunc
return boolFunc return boolFunc
} }
@ -654,7 +654,7 @@ func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpress
boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc)
intWindowFunc := newBoolWindowExpression(boolFunc) intWindowFunc := newBoolWindowExpression(boolFunc)
boolFunc.boolInterfaceImpl.parent = intWindowFunc boolFunc.boolInterfaceImpl.parent = intWindowFunc
boolFunc.expressionInterfaceImpl.Parent = intWindowFunc boolFunc.ExpressionInterfaceImpl.Parent = intWindowFunc
return intWindowFunc return intWindowFunc
} }
@ -681,7 +681,7 @@ func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpre
floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc)
floatWindowFunc := newFloatWindowExpression(floatFunc) floatWindowFunc := newFloatWindowExpression(floatFunc)
floatFunc.floatInterfaceImpl.parent = floatWindowFunc floatFunc.floatInterfaceImpl.parent = floatWindowFunc
floatFunc.expressionInterfaceImpl.Parent = floatWindowFunc floatFunc.ExpressionInterfaceImpl.Parent = floatWindowFunc
return floatWindowFunc return floatWindowFunc
} }
@ -707,7 +707,7 @@ func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowE
integerFunc.funcExpressionImpl = *newFunc(name, expressions, integerFunc) integerFunc.funcExpressionImpl = *newFunc(name, expressions, integerFunc)
intWindowFunc := newIntegerWindowExpression(integerFunc) intWindowFunc := newIntegerWindowExpression(integerFunc)
integerFunc.integerInterfaceImpl.parent = intWindowFunc integerFunc.integerInterfaceImpl.parent = intWindowFunc
integerFunc.expressionInterfaceImpl.Parent = intWindowFunc integerFunc.ExpressionInterfaceImpl.Parent = intWindowFunc
return intWindowFunc return intWindowFunc
} }

View file

@ -54,134 +54,89 @@ type integerInterfaceImpl struct {
} }
func (i *integerInterfaceImpl) EQ(rhs IntegerExpression) BoolExpression { func (i *integerInterfaceImpl) EQ(rhs IntegerExpression) BoolExpression {
return eq(i.parent, rhs) return Eq(i.parent, rhs)
} }
func (i *integerInterfaceImpl) NOT_EQ(rhs IntegerExpression) BoolExpression { func (i *integerInterfaceImpl) NOT_EQ(rhs IntegerExpression) BoolExpression {
return notEq(i.parent, rhs) return NotEq(i.parent, rhs)
} }
func (i *integerInterfaceImpl) IS_DISTINCT_FROM(rhs IntegerExpression) BoolExpression { func (i *integerInterfaceImpl) IS_DISTINCT_FROM(rhs IntegerExpression) BoolExpression {
return isDistinctFrom(i.parent, rhs) return IsDistinctFrom(i.parent, rhs)
} }
func (i *integerInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs IntegerExpression) BoolExpression { func (i *integerInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs IntegerExpression) BoolExpression {
return isNotDistinctFrom(i.parent, rhs) return IsNotDistinctFrom(i.parent, rhs)
} }
func (i *integerInterfaceImpl) GT(rhs IntegerExpression) BoolExpression { func (i *integerInterfaceImpl) GT(rhs IntegerExpression) BoolExpression {
return gt(i.parent, rhs) return Gt(i.parent, rhs)
} }
func (i *integerInterfaceImpl) GT_EQ(rhs IntegerExpression) BoolExpression { func (i *integerInterfaceImpl) GT_EQ(rhs IntegerExpression) BoolExpression {
return gtEq(i.parent, rhs) return GtEq(i.parent, rhs)
} }
func (i *integerInterfaceImpl) LT(expression IntegerExpression) BoolExpression { func (i *integerInterfaceImpl) LT(rhs IntegerExpression) BoolExpression {
return lt(i.parent, expression) return Lt(i.parent, rhs)
} }
func (i *integerInterfaceImpl) LT_EQ(expression IntegerExpression) BoolExpression { func (i *integerInterfaceImpl) LT_EQ(rhs IntegerExpression) BoolExpression {
return ltEq(i.parent, expression) return LtEq(i.parent, rhs)
} }
func (i *integerInterfaceImpl) ADD(expression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) ADD(rhs IntegerExpression) IntegerExpression {
return newBinaryIntegerExpression(i.parent, expression, "+") return IntExp(Add(i.parent, rhs))
} }
func (i *integerInterfaceImpl) SUB(expression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) SUB(rhs IntegerExpression) IntegerExpression {
return newBinaryIntegerExpression(i.parent, expression, "-") return IntExp(Sub(i.parent, rhs))
} }
func (i *integerInterfaceImpl) MUL(expression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) MUL(rhs IntegerExpression) IntegerExpression {
return newBinaryIntegerExpression(i.parent, expression, "*") return IntExp(Mul(i.parent, rhs))
} }
func (i *integerInterfaceImpl) DIV(expression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) DIV(rhs IntegerExpression) IntegerExpression {
return newBinaryIntegerExpression(i.parent, expression, "/") return IntExp(Div(i.parent, rhs))
} }
func (i *integerInterfaceImpl) MOD(expression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) MOD(rhs IntegerExpression) IntegerExpression {
return newBinaryIntegerExpression(i.parent, expression, "%") return IntExp(Mod(i.parent, rhs))
} }
func (i *integerInterfaceImpl) POW(expression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) POW(rhs IntegerExpression) IntegerExpression {
return IntExp(POW(i.parent, expression)) return IntExp(POW(i.parent, rhs))
} }
func (i *integerInterfaceImpl) BIT_AND(expression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) BIT_AND(rhs IntegerExpression) IntegerExpression {
return newBinaryIntegerExpression(i.parent, expression, "&") return newBinaryIntegerOperatorExpression(i.parent, rhs, "&")
} }
func (i *integerInterfaceImpl) BIT_OR(expression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) BIT_OR(rhs IntegerExpression) IntegerExpression {
return newBinaryIntegerExpression(i.parent, expression, "|") return newBinaryIntegerOperatorExpression(i.parent, rhs, "|")
} }
func (i *integerInterfaceImpl) BIT_XOR(expression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) BIT_XOR(rhs IntegerExpression) IntegerExpression {
return newBinaryIntegerExpression(i.parent, expression, "#") return newBinaryIntegerOperatorExpression(i.parent, rhs, "#")
} }
func (i *integerInterfaceImpl) BIT_SHIFT_LEFT(intExpression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) BIT_SHIFT_LEFT(intExpression IntegerExpression) IntegerExpression {
return newBinaryIntegerExpression(i.parent, intExpression, "<<") return newBinaryIntegerOperatorExpression(i.parent, intExpression, "<<")
} }
func (i *integerInterfaceImpl) BIT_SHIFT_RIGHT(intExpression IntegerExpression) IntegerExpression { func (i *integerInterfaceImpl) BIT_SHIFT_RIGHT(intExpression IntegerExpression) IntegerExpression {
return newBinaryIntegerExpression(i.parent, intExpression, ">>") return newBinaryIntegerOperatorExpression(i.parent, intExpression, ">>")
} }
//---------------------------------------------------// //---------------------------------------------------//
type binaryIntegerExpression struct { func newBinaryIntegerOperatorExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression {
expressionInterfaceImpl return IntExp(NewBinaryOperatorExpression(lhs, rhs, operator))
integerInterfaceImpl
binaryOpExpression
}
func newBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression {
integerExpression := binaryIntegerExpression{}
integerExpression.expressionInterfaceImpl.Parent = &integerExpression
integerExpression.integerInterfaceImpl.parent = &integerExpression
integerExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
return &integerExpression
} }
//---------------------------------------------------// //---------------------------------------------------//
type prefixIntegerOpExpression struct { func newPrefixIntegerOperatorExpression(expression IntegerExpression, operator string) IntegerExpression {
expressionInterfaceImpl return IntExp(newPrefixOperatorExpression(expression, operator))
integerInterfaceImpl
prefixOpExpression
}
func newPrefixIntegerOperator(expression IntegerExpression, operator string) IntegerExpression {
integerExpression := prefixIntegerOpExpression{}
integerExpression.prefixOpExpression = newPrefixExpression(expression, operator)
integerExpression.expressionInterfaceImpl.Parent = &integerExpression
integerExpression.integerInterfaceImpl.parent = &integerExpression
return &integerExpression
}
//---------------------------------------------------//
type prefixFloatOpExpression struct {
expressionInterfaceImpl
floatInterfaceImpl
prefixOpExpression
}
func newPrefixFloatOperator(expression FloatExpression, operator string) FloatExpression {
floatOpExpression := prefixFloatOpExpression{}
floatOpExpression.prefixOpExpression = newPrefixExpression(expression, operator)
floatOpExpression.expressionInterfaceImpl.Parent = &floatOpExpression
floatOpExpression.floatInterfaceImpl.parent = &floatOpExpression
return &floatOpExpression
} }
//---------------------------------------------------// //---------------------------------------------------//

37
internal/jet/interval.go Normal file
View file

@ -0,0 +1,37 @@
package jet
// Interval is internal common representation of sql interval
type Interval interface {
Serializer
IsInterval
}
// IsInterval interface
type IsInterval interface {
isInterval()
}
// IsIntervalImpl is implementation of IsInterval interface
type IsIntervalImpl struct{}
func (i *IsIntervalImpl) isInterval() {}
// NewInterval creates new interval from serializer
func NewInterval(s Serializer) *IntervalImpl {
newInterval := &IntervalImpl{
interval: s,
}
return newInterval
}
// IntervalImpl is implementation of Interval type
type IntervalImpl struct {
interval Serializer
IsIntervalImpl
}
func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("INTERVAL")
i.interval.serialize(statement, out, options...)
}

View file

@ -14,7 +14,7 @@ type LiteralExpression interface {
} }
type literalExpressionImpl struct { type literalExpressionImpl struct {
expressionInterfaceImpl ExpressionInterfaceImpl
value interface{} value interface{}
constant bool constant bool
@ -27,11 +27,17 @@ func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl
exp.constant = optionalConstant[0] exp.constant = optionalConstant[0]
} }
exp.expressionInterfaceImpl.Parent = &exp exp.ExpressionInterfaceImpl.Parent = &exp
return &exp return &exp
} }
// Literal is injected directly to SQL query, and does not appear in parametrized argument list.
func Literal(value interface{}) *literalExpressionImpl {
exp := literal(value)
return exp
}
// FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list. // FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list.
func FixedLiteral(value interface{}) *literalExpressionImpl { func FixedLiteral(value interface{}) *literalExpressionImpl {
exp := literal(value) exp := literal(value)
@ -273,13 +279,13 @@ func formatNanoseconds(nanoseconds ...time.Duration) string {
//--------------------------------------------------// //--------------------------------------------------//
type nullLiteral struct { type nullLiteral struct {
expressionInterfaceImpl ExpressionInterfaceImpl
} }
func newNullLiteral() Expression { func newNullLiteral() Expression {
nullExpression := &nullLiteral{} nullExpression := &nullLiteral{}
nullExpression.expressionInterfaceImpl.Parent = nullExpression nullExpression.ExpressionInterfaceImpl.Parent = nullExpression
return nullExpression return nullExpression
} }
@ -290,13 +296,13 @@ func (n *nullLiteral) serialize(statement StatementType, out *SQLBuilder, option
//--------------------------------------------------// //--------------------------------------------------//
type starLiteral struct { type starLiteral struct {
expressionInterfaceImpl ExpressionInterfaceImpl
} }
func newStarLiteral() Expression { func newStarLiteral() Expression {
starExpression := &starLiteral{} starExpression := &starLiteral{}
starExpression.expressionInterfaceImpl.Parent = starExpression starExpression.ExpressionInterfaceImpl.Parent = starExpression
return starExpression return starExpression
} }
@ -308,7 +314,7 @@ func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, option
//---------------------------------------------------// //---------------------------------------------------//
type wrap struct { type wrap struct {
expressionInterfaceImpl ExpressionInterfaceImpl
expressions []Expression expressions []Expression
} }
@ -321,7 +327,7 @@ func (n *wrap) serialize(statement StatementType, out *SQLBuilder, options ...Se
// WRAP wraps list of expressions with brackets '(' and ')' // WRAP wraps list of expressions with brackets '(' and ')'
func WRAP(expression ...Expression) Expression { func WRAP(expression ...Expression) Expression {
wrap := &wrap{expressions: expression} wrap := &wrap{expressions: expression}
wrap.expressionInterfaceImpl.Parent = wrap wrap.ExpressionInterfaceImpl.Parent = wrap
return wrap return wrap
} }
@ -329,20 +335,20 @@ func WRAP(expression ...Expression) Expression {
//---------------------------------------------------// //---------------------------------------------------//
type rawExpression struct { type rawExpression struct {
expressionInterfaceImpl ExpressionInterfaceImpl
raw string Raw string
} }
func (n *rawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (n *rawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString(n.raw) out.WriteString(n.Raw)
} }
// Raw can be used for any unsupported functions, operators or expressions. // Raw can be used for any unsupported functions, operators or expressions.
// For example: Raw("current_database()") // For example: Raw("current_database()")
func Raw(raw string) Expression { func Raw(raw string, parent ...Expression) Expression {
rawExp := &rawExpression{raw: raw} rawExp := &rawExpression{Raw: raw}
rawExp.expressionInterfaceImpl.Parent = rawExp rawExp.ExpressionInterfaceImpl.Parent = OptionalOrDefaultExpression(rawExp, parent...)
return rawExp return rawExp
} }

View file

@ -11,7 +11,7 @@ const (
// NOT returns negation of bool expression result // NOT returns negation of bool expression result
func NOT(exp BoolExpression) BoolExpression { func NOT(exp BoolExpression) BoolExpression {
return newPrefixBoolOperator(exp, "NOT") return newPrefixBoolOperatorExpression(exp, "NOT")
} }
// BIT_NOT inverts every bit in integer expression result // BIT_NOT inverts every bit in integer expression result
@ -19,52 +19,79 @@ func BIT_NOT(expr IntegerExpression) IntegerExpression {
if literalExp, ok := expr.(LiteralExpression); ok { if literalExp, ok := expr.(LiteralExpression); ok {
literalExp.SetConstant(true) literalExp.SetConstant(true)
} }
return newPrefixIntegerOperator(expr, "~") return newPrefixIntegerOperatorExpression(expr, "~")
} }
//----------- Comparison operators ---------------// //----------- Comparison operators ---------------//
// EXISTS checks for existence of the rows in subQuery // EXISTS checks for existence of the rows in subQuery
func EXISTS(subQuery Expression) BoolExpression { func EXISTS(subQuery Expression) BoolExpression {
return newPrefixBoolOperator(subQuery, "EXISTS") return newPrefixBoolOperatorExpression(subQuery, "EXISTS")
} }
// Returns a representation of "a=b" // Eq returns a representation of "a=b"
func eq(lhs, rhs Expression) BoolExpression { func Eq(lhs, rhs Expression) BoolExpression {
return newBinaryBoolOperator(lhs, rhs, "=") return newBinaryBoolOperatorExpression(lhs, rhs, "=")
} }
// Returns a representation of "a!=b" // NotEq returns a representation of "a!=b"
func notEq(lhs, rhs Expression) BoolExpression { func NotEq(lhs, rhs Expression) BoolExpression {
return newBinaryBoolOperator(lhs, rhs, "!=") return newBinaryBoolOperatorExpression(lhs, rhs, "!=")
} }
func isDistinctFrom(lhs, rhs Expression) BoolExpression { // IsDistinctFrom returns a representation of "a IS DISTINCT FROM b"
return newBinaryBoolOperator(lhs, rhs, "IS DISTINCT FROM") func IsDistinctFrom(lhs, rhs Expression) BoolExpression {
return newBinaryBoolOperatorExpression(lhs, rhs, "IS DISTINCT FROM")
} }
func isNotDistinctFrom(lhs, rhs Expression) BoolExpression { // IsNotDistinctFrom returns a representation of "a IS NOT DISTINCT FROM b"
return newBinaryBoolOperator(lhs, rhs, "IS NOT DISTINCT FROM") func IsNotDistinctFrom(lhs, rhs Expression) BoolExpression {
return newBinaryBoolOperatorExpression(lhs, rhs, "IS NOT DISTINCT FROM")
} }
// Returns a representation of "a<b" // Lt returns a representation of "a<b"
func lt(lhs Expression, rhs Expression) BoolExpression { func Lt(lhs Expression, rhs Expression) BoolExpression {
return newBinaryBoolOperator(lhs, rhs, "<") return newBinaryBoolOperatorExpression(lhs, rhs, "<")
} }
// Returns a representation of "a<=b" // LtEq returns a representation of "a<=b"
func ltEq(lhs, rhs Expression) BoolExpression { func LtEq(lhs, rhs Expression) BoolExpression {
return newBinaryBoolOperator(lhs, rhs, "<=") return newBinaryBoolOperatorExpression(lhs, rhs, "<=")
} }
// Returns a representation of "a>b" // Gt returns a representation of "a>b"
func gt(lhs, rhs Expression) BoolExpression { func Gt(lhs, rhs Expression) BoolExpression {
return newBinaryBoolOperator(lhs, rhs, ">") return newBinaryBoolOperatorExpression(lhs, rhs, ">")
} }
// Returns a representation of "a>=b" // GtEq returns a representation of "a>=b"
func gtEq(lhs, rhs Expression) BoolExpression { func GtEq(lhs, rhs Expression) BoolExpression {
return newBinaryBoolOperator(lhs, rhs, ">=") return newBinaryBoolOperatorExpression(lhs, rhs, ">=")
}
// Add notEq returns a representation of "a + b"
func Add(lhs, rhs Serializer) Expression {
return NewBinaryOperatorExpression(lhs, rhs, "+")
}
// Sub notEq returns a representation of "a - b"
func Sub(lhs, rhs Serializer) Expression {
return NewBinaryOperatorExpression(lhs, rhs, "-")
}
// Mul returns a representation of "a * b"
func Mul(lhs, rhs Serializer) Expression {
return NewBinaryOperatorExpression(lhs, rhs, "*")
}
// Div returns a representation of "a / b"
func Div(lhs, rhs Serializer) Expression {
return NewBinaryOperatorExpression(lhs, rhs, "/")
}
// Mod returns a representation of "a % b"
func Mod(lhs, rhs Serializer) Expression {
return NewBinaryOperatorExpression(lhs, rhs, "%")
} }
// --------------- CASE operator -------------------// // --------------- CASE operator -------------------//
@ -79,7 +106,7 @@ type CaseOperator interface {
} }
type caseOperatorImpl struct { type caseOperatorImpl struct {
expressionInterfaceImpl ExpressionInterfaceImpl
expression Expression expression Expression
when []Expression when []Expression
@ -95,7 +122,7 @@ func CASE(expression ...Expression) CaseOperator {
caseExp.expression = expression[0] caseExp.expression = expression[0]
} }
caseExp.expressionInterfaceImpl.Parent = caseExp caseExp.ExpressionInterfaceImpl.Parent = caseExp
return caseExp return caseExp
} }

View file

@ -41,3 +41,18 @@ func contains(options []SerializeOption, option SerializeOption) bool {
return false return false
} }
// ListSerializer serializes list of serializers with separator
type ListSerializer struct {
Serializers []Serializer
Separator string
}
func (s ListSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
for i, ser := range s.Serializers {
if i > 0 {
out.WriteString(s.Separator)
}
ser.serialize(statement, out)
}
}

View file

@ -22,7 +22,7 @@ type SQLBuilder struct {
lastChar byte lastChar byte
ident int ident int
debug bool Debug bool
} }
const defaultIdent = 5 const defaultIdent = 5
@ -98,7 +98,7 @@ func (s *SQLBuilder) WriteString(str string) {
// WriteIdentifier adds identifier to output SQL // WriteIdentifier adds identifier to output SQL
func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) { func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
if shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 { if s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 {
identQuoteChar := string(s.Dialect.IdentifierQuoteChar()) identQuoteChar := string(s.Dialect.IdentifierQuoteChar())
s.WriteString(identQuoteChar + name + identQuoteChar) s.WriteString(identQuoteChar + name + identQuoteChar)
} else { } else {
@ -120,7 +120,7 @@ func (s *SQLBuilder) insertConstantArgument(arg interface{}) {
} }
func (s *SQLBuilder) insertParametrizedArgument(arg interface{}) { func (s *SQLBuilder) insertParametrizedArgument(arg interface{}) {
if s.debug { if s.Debug {
s.insertConstantArgument(arg) s.insertConstantArgument(arg)
return return
} }
@ -142,12 +142,8 @@ func argToString(value interface{}) string {
return "TRUE" return "TRUE"
} }
return "FALSE" return "FALSE"
case int: case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return strconv.FormatInt(int64(bindVal), 10) return integerTypesToString(bindVal)
case int32:
return strconv.FormatInt(int64(bindVal), 10)
case int64:
return strconv.FormatInt(bindVal, 10)
case float32: case float32:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
@ -167,6 +163,32 @@ func argToString(value interface{}) string {
} }
} }
func integerTypesToString(value interface{}) string {
switch bindVal := value.(type) {
case int:
return strconv.FormatInt(int64(bindVal), 10)
case uint:
return strconv.FormatUint(uint64(bindVal), 10)
case int8:
return strconv.FormatInt(int64(bindVal), 10)
case uint8:
return strconv.FormatUint(uint64(bindVal), 10)
case int16:
return strconv.FormatInt(int64(bindVal), 10)
case uint16:
return strconv.FormatUint(uint64(bindVal), 10)
case int32:
return strconv.FormatInt(int64(bindVal), 10)
case uint32:
return strconv.FormatUint(uint64(bindVal), 10)
case int64:
return strconv.FormatInt(bindVal, 10)
case uint64:
return strconv.FormatUint(bindVal, 10)
}
panic("jet: Unsupported integer type: " + reflect.TypeOf(value).String())
}
func shouldQuoteIdentifier(identifier string) bool { func shouldQuoteIdentifier(identifier string) bool {
for _, c := range identifier { for _, c := range identifier {
if unicode.IsNumber(c) || c == '_' { if unicode.IsNumber(c) || c == '_' {

View file

@ -2,7 +2,7 @@ package jet
import ( import (
"github.com/google/uuid" "github.com/google/uuid"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )
@ -12,8 +12,16 @@ func TestArgToString(t *testing.T) {
assert.Equal(t, argToString(false), "FALSE") assert.Equal(t, argToString(false), "FALSE")
assert.Equal(t, argToString(int(-32)), "-32") assert.Equal(t, argToString(int(-32)), "-32")
assert.Equal(t, argToString(int32(-32)), "-32") assert.Equal(t, argToString(uint(32)), "32")
assert.Equal(t, argToString(int8(-43)), "-43")
assert.Equal(t, argToString(uint8(43)), "43")
assert.Equal(t, argToString(int16(-54)), "-54")
assert.Equal(t, argToString(uint16(54)), "54")
assert.Equal(t, argToString(int32(-65)), "-65")
assert.Equal(t, argToString(uint32(65)), "65")
assert.Equal(t, argToString(int64(-64)), "-64") assert.Equal(t, argToString(int64(-64)), "-64")
assert.Equal(t, argToString(uint64(64)), "64")
assert.Equal(t, argToString(float32(2.0)), "2")
assert.Equal(t, argToString(float64(1.11)), "1.11") assert.Equal(t, argToString(float64(1.11)), "1.11")
assert.Equal(t, argToString("john"), "'john'") assert.Equal(t, argToString("john"), "'john'")
@ -22,7 +30,7 @@ func TestArgToString(t *testing.T) {
assert.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'") assert.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'")
time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006") time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006")
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'") assert.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'")
func() { func() {

View file

@ -65,7 +65,7 @@ func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface
} }
func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { func (s *serializerStatementInterfaceImpl) DebugSql() (query string) {
sqlBuilder := &SQLBuilder{Dialect: s.dialect, debug: true} sqlBuilder := &SQLBuilder{Dialect: s.dialect, Debug: true}
s.parent.serialize(s.statementType, sqlBuilder, noWrap) s.parent.serialize(s.statementType, sqlBuilder, noWrap)
@ -106,7 +106,7 @@ type ExpressionStatement interface {
// NewExpressionStatementImpl creates new expression statement // NewExpressionStatementImpl creates new expression statement
func NewExpressionStatementImpl(Dialect Dialect, statementType StatementType, parent ExpressionStatement, clauses ...Clause) ExpressionStatement { func NewExpressionStatementImpl(Dialect Dialect, statementType StatementType, parent ExpressionStatement, clauses ...Clause) ExpressionStatement {
return &expressionStatementImpl{ return &expressionStatementImpl{
expressionInterfaceImpl{Parent: parent}, ExpressionInterfaceImpl{Parent: parent},
statementImpl{ statementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
parent: parent, parent: parent,
@ -119,7 +119,7 @@ func NewExpressionStatementImpl(Dialect Dialect, statementType StatementType, pa
} }
type expressionStatementImpl struct { type expressionStatementImpl struct {
expressionInterfaceImpl ExpressionInterfaceImpl
statementImpl statementImpl
} }

View file

@ -28,74 +28,60 @@ type stringInterfaceImpl struct {
} }
func (s *stringInterfaceImpl) EQ(rhs StringExpression) BoolExpression { func (s *stringInterfaceImpl) EQ(rhs StringExpression) BoolExpression {
return eq(s.parent, rhs) return Eq(s.parent, rhs)
} }
func (s *stringInterfaceImpl) NOT_EQ(rhs StringExpression) BoolExpression { func (s *stringInterfaceImpl) NOT_EQ(rhs StringExpression) BoolExpression {
return notEq(s.parent, rhs) return NotEq(s.parent, rhs)
} }
func (s *stringInterfaceImpl) IS_DISTINCT_FROM(rhs StringExpression) BoolExpression { func (s *stringInterfaceImpl) IS_DISTINCT_FROM(rhs StringExpression) BoolExpression {
return isDistinctFrom(s.parent, rhs) return IsDistinctFrom(s.parent, rhs)
} }
func (s *stringInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs StringExpression) BoolExpression { func (s *stringInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs StringExpression) BoolExpression {
return isNotDistinctFrom(s.parent, rhs) return IsNotDistinctFrom(s.parent, rhs)
} }
func (s *stringInterfaceImpl) GT(rhs StringExpression) BoolExpression { func (s *stringInterfaceImpl) GT(rhs StringExpression) BoolExpression {
return gt(s.parent, rhs) return Gt(s.parent, rhs)
} }
func (s *stringInterfaceImpl) GT_EQ(rhs StringExpression) BoolExpression { func (s *stringInterfaceImpl) GT_EQ(rhs StringExpression) BoolExpression {
return gtEq(s.parent, rhs) return GtEq(s.parent, rhs)
} }
func (s *stringInterfaceImpl) LT(rhs StringExpression) BoolExpression { func (s *stringInterfaceImpl) LT(rhs StringExpression) BoolExpression {
return lt(s.parent, rhs) return Lt(s.parent, rhs)
} }
func (s *stringInterfaceImpl) LT_EQ(rhs StringExpression) BoolExpression { func (s *stringInterfaceImpl) LT_EQ(rhs StringExpression) BoolExpression {
return ltEq(s.parent, rhs) return LtEq(s.parent, rhs)
} }
func (s *stringInterfaceImpl) CONCAT(rhs Expression) StringExpression { func (s *stringInterfaceImpl) CONCAT(rhs Expression) StringExpression {
return newBinaryStringExpression(s.parent, rhs, StringConcatOperator) return newBinaryStringOperatorExpression(s.parent, rhs, StringConcatOperator)
} }
func (s *stringInterfaceImpl) LIKE(pattern StringExpression) BoolExpression { func (s *stringInterfaceImpl) LIKE(pattern StringExpression) BoolExpression {
return newBinaryBoolOperator(s.parent, pattern, "LIKE") return newBinaryBoolOperatorExpression(s.parent, pattern, "LIKE")
} }
func (s *stringInterfaceImpl) NOT_LIKE(pattern StringExpression) BoolExpression { func (s *stringInterfaceImpl) NOT_LIKE(pattern StringExpression) BoolExpression {
return newBinaryBoolOperator(s.parent, pattern, "NOT LIKE") return newBinaryBoolOperatorExpression(s.parent, pattern, "NOT LIKE")
} }
func (s *stringInterfaceImpl) REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression { func (s *stringInterfaceImpl) REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression {
return newBinaryBoolOperator(s.parent, pattern, StringRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) return newBinaryBoolOperatorExpression(s.parent, pattern, StringRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0]))
} }
func (s *stringInterfaceImpl) NOT_REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression { func (s *stringInterfaceImpl) NOT_REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression {
return newBinaryBoolOperator(s.parent, pattern, StringNotRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) return newBinaryBoolOperatorExpression(s.parent, pattern, StringNotRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0]))
} }
//---------------------------------------------------// //---------------------------------------------------//
func newBinaryStringOperatorExpression(lhs, rhs Expression, operator string) StringExpression {
type binaryStringExpression struct { return StringExp(NewBinaryOperatorExpression(lhs, rhs, operator))
expressionInterfaceImpl
stringInterfaceImpl
binaryOpExpression
}
func newBinaryStringExpression(lhs, rhs Expression, operator string) StringExpression {
boolExpression := binaryStringExpression{}
boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
boolExpression.expressionInterfaceImpl.Parent = &boolExpression
boolExpression.stringInterfaceImpl.parent = &boolExpression
return &boolExpression
} }
//---------------------------------------------------// //---------------------------------------------------//

View file

@ -1,7 +1,7 @@
package jet package jet
import ( import (
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
) )

View file

@ -1,7 +1,7 @@
package jet package jet
import ( import (
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"strconv" "strconv"
"testing" "testing"
) )
@ -14,36 +14,40 @@ var defaultDialect = NewDialect(DialectParams{ // just for tests
}, },
}) })
var table1Col1 = IntegerColumn("col1") var (
var table1ColInt = IntegerColumn("col_int") table1Col1 = IntegerColumn("col1")
var table1ColFloat = FloatColumn("col_float") table1ColInt = IntegerColumn("col_int")
var table1Col3 = IntegerColumn("col3") table1ColFloat = FloatColumn("col_float")
var table1ColTime = TimeColumn("col_time") table1Col3 = IntegerColumn("col3")
var table1ColTimez = TimezColumn("col_timez") table1ColTime = TimeColumn("col_time")
var table1ColTimestamp = TimestampColumn("col_timestamp") table1ColTimez = TimezColumn("col_timez")
var table1ColTimestampz = TimestampzColumn("col_timestampz") table1ColTimestamp = TimestampColumn("col_timestamp")
var table1ColBool = BoolColumn("col_bool") table1ColTimestampz = TimestampzColumn("col_timestampz")
var table1ColDate = DateColumn("col_date") table1ColBool = BoolColumn("col_bool")
table1ColDate = DateColumn("col_date")
)
var table1 = NewTable("db", "table1", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTimestampz) var table1 = NewTable("db", "table1", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTimestampz)
var table2Col3 = IntegerColumn("col3") var (
var table2Col4 = IntegerColumn("col4") table2Col3 = IntegerColumn("col3")
var table2ColInt = IntegerColumn("col_int") table2Col4 = IntegerColumn("col4")
var table2ColFloat = FloatColumn("col_float") table2ColInt = IntegerColumn("col_int")
var table2ColStr = StringColumn("col_str") table2ColFloat = FloatColumn("col_float")
var table2ColBool = BoolColumn("col_bool") table2ColStr = StringColumn("col_str")
var table2ColTime = TimeColumn("col_time") table2ColBool = BoolColumn("col_bool")
var table2ColTimez = TimezColumn("col_timez") table2ColTime = TimeColumn("col_time")
var table2ColTimestamp = TimestampColumn("col_timestamp") table2ColTimez = TimezColumn("col_timez")
var table2ColTimestampz = TimestampzColumn("col_timestampz") table2ColTimestamp = TimestampColumn("col_timestamp")
var table2ColDate = DateColumn("col_date") table2ColTimestampz = TimestampzColumn("col_timestampz")
table2ColDate = DateColumn("col_date")
)
var table2 = NewTable("db", "table2", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz) var table2 = NewTable("db", "table2", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz)
var table3Col1 = IntegerColumn("col1") var (
var table3ColInt = IntegerColumn("col_int") table3Col1 = IntegerColumn("col1")
var table3StrCol = StringColumn("col2") table3ColInt = IntegerColumn("col_int")
table3StrCol = StringColumn("col2")
)
var table3 = NewTable("db", "table3", table3Col1, table3ColInt, table3StrCol) var table3 = NewTable("db", "table3", table3Col1, table3ColInt, table3StrCol)
func assertClauseSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) { func assertClauseSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) {
@ -52,8 +56,8 @@ func assertClauseSerialize(t *testing.T, clause Serializer, query string, args .
//fmt.Println(out.Buff.String()) //fmt.Println(out.Buff.String())
assert.DeepEqual(t, out.Buff.String(), query) assert.Equal(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args) assert.Equal(t, out.Args, args)
} }
func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) { func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) {
@ -67,19 +71,19 @@ func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string)
} }
func assertClauseDebugSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) { func assertClauseDebugSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) {
out := SQLBuilder{Dialect: defaultDialect, debug: true} out := SQLBuilder{Dialect: defaultDialect, Debug: true}
clause.serialize(SelectStatementType, &out) clause.serialize(SelectStatementType, &out)
//fmt.Println(out.Buff.String()) //fmt.Println(out.Buff.String())
assert.DeepEqual(t, out.Buff.String(), query) assert.Equal(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args) assert.Equal(t, out.Args, args)
} }
func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) { func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) {
out := SQLBuilder{Dialect: defaultDialect} out := SQLBuilder{Dialect: defaultDialect}
projection.serializeForProjection(SelectStatementType, &out) projection.serializeForProjection(SelectStatementType, &out)
assert.DeepEqual(t, out.Buff.String(), query) assert.Equal(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args) assert.Equal(t, out.Args, args)
} }

View file

@ -13,6 +13,9 @@ type TimeExpression interface {
LT_EQ(rhs TimeExpression) BoolExpression LT_EQ(rhs TimeExpression) BoolExpression
GT(rhs TimeExpression) BoolExpression GT(rhs TimeExpression) BoolExpression
GT_EQ(rhs TimeExpression) BoolExpression GT_EQ(rhs TimeExpression) BoolExpression
ADD(rhs Interval) TimeExpression
SUB(rhs Interval) TimeExpression
} }
type timeInterfaceImpl struct { type timeInterfaceImpl struct {
@ -20,54 +23,44 @@ type timeInterfaceImpl struct {
} }
func (t *timeInterfaceImpl) EQ(rhs TimeExpression) BoolExpression { func (t *timeInterfaceImpl) EQ(rhs TimeExpression) BoolExpression {
return eq(t.parent, rhs) return Eq(t.parent, rhs)
} }
func (t *timeInterfaceImpl) NOT_EQ(rhs TimeExpression) BoolExpression { func (t *timeInterfaceImpl) NOT_EQ(rhs TimeExpression) BoolExpression {
return notEq(t.parent, rhs) return NotEq(t.parent, rhs)
} }
func (t *timeInterfaceImpl) IS_DISTINCT_FROM(rhs TimeExpression) BoolExpression { func (t *timeInterfaceImpl) IS_DISTINCT_FROM(rhs TimeExpression) BoolExpression {
return isDistinctFrom(t.parent, rhs) return IsDistinctFrom(t.parent, rhs)
} }
func (t *timeInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs TimeExpression) BoolExpression { func (t *timeInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs TimeExpression) BoolExpression {
return isNotDistinctFrom(t.parent, rhs) return IsNotDistinctFrom(t.parent, rhs)
} }
func (t *timeInterfaceImpl) LT(rhs TimeExpression) BoolExpression { func (t *timeInterfaceImpl) LT(rhs TimeExpression) BoolExpression {
return lt(t.parent, rhs) return Lt(t.parent, rhs)
} }
func (t *timeInterfaceImpl) LT_EQ(rhs TimeExpression) BoolExpression { func (t *timeInterfaceImpl) LT_EQ(rhs TimeExpression) BoolExpression {
return ltEq(t.parent, rhs) return LtEq(t.parent, rhs)
} }
func (t *timeInterfaceImpl) GT(rhs TimeExpression) BoolExpression { func (t *timeInterfaceImpl) GT(rhs TimeExpression) BoolExpression {
return gt(t.parent, rhs) return Gt(t.parent, rhs)
} }
func (t *timeInterfaceImpl) GT_EQ(rhs TimeExpression) BoolExpression { func (t *timeInterfaceImpl) GT_EQ(rhs TimeExpression) BoolExpression {
return gtEq(t.parent, rhs) return GtEq(t.parent, rhs)
} }
//---------------------------------------------------// func (t *timeInterfaceImpl) ADD(rhs Interval) TimeExpression {
type prefixTimeExpression struct { return TimeExp(Add(t.parent, rhs))
expressionInterfaceImpl
timeInterfaceImpl
prefixOpExpression
} }
//func newPrefixTimeExpression(operator string, expression Expression) TimeExpression { func (t *timeInterfaceImpl) SUB(rhs Interval) TimeExpression {
// timeExpr := prefixTimeExpression{} return TimeExp(Sub(t.parent, rhs))
// timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) }
//
// timeExpr.expressionInterfaceImpl.parent = &timeExpr
// timeExpr.timeInterfaceImpl.parent = &timeExpr
//
// return &timeExpr
//}
//---------------------------------------------------// //---------------------------------------------------//

View file

@ -52,3 +52,11 @@ func TestTimeExp(t *testing.T) {
assertClauseSerialize(t, TimeExp(table1ColFloat).LT(Time(1, 1, 1, 1*time.Millisecond)), assertClauseSerialize(t, TimeExp(table1ColFloat).LT(Time(1, 1, 1, 1*time.Millisecond)),
"(table1.col_float < $1)", string("01:01:01.001")) "(table1.col_float < $1)", string("01:01:01.001"))
} }
func TestTimeArithmetic(t *testing.T) {
time := Time(10, 20, 3)
assertClauseDebugSerialize(t, table1ColTime.ADD(NewInterval(String("1 HOUR"))).EQ(time),
"((table1.col_time + INTERVAL '1 HOUR') = '10:20:03')")
assertClauseDebugSerialize(t, table1ColTime.SUB(NewInterval(String("1 HOUR"))).EQ(time),
"((table1.col_time - INTERVAL '1 HOUR') = '10:20:03')")
}

View file

@ -13,6 +13,9 @@ type TimestampExpression interface {
LT_EQ(rhs TimestampExpression) BoolExpression LT_EQ(rhs TimestampExpression) BoolExpression
GT(rhs TimestampExpression) BoolExpression GT(rhs TimestampExpression) BoolExpression
GT_EQ(rhs TimestampExpression) BoolExpression GT_EQ(rhs TimestampExpression) BoolExpression
ADD(rhs Interval) TimestampExpression
SUB(rhs Interval) TimestampExpression
} }
type timestampInterfaceImpl struct { type timestampInterfaceImpl struct {
@ -20,35 +23,43 @@ type timestampInterfaceImpl struct {
} }
func (t *timestampInterfaceImpl) EQ(rhs TimestampExpression) BoolExpression { func (t *timestampInterfaceImpl) EQ(rhs TimestampExpression) BoolExpression {
return eq(t.parent, rhs) return Eq(t.parent, rhs)
} }
func (t *timestampInterfaceImpl) NOT_EQ(rhs TimestampExpression) BoolExpression { func (t *timestampInterfaceImpl) NOT_EQ(rhs TimestampExpression) BoolExpression {
return notEq(t.parent, rhs) return NotEq(t.parent, rhs)
} }
func (t *timestampInterfaceImpl) IS_DISTINCT_FROM(rhs TimestampExpression) BoolExpression { func (t *timestampInterfaceImpl) IS_DISTINCT_FROM(rhs TimestampExpression) BoolExpression {
return isDistinctFrom(t.parent, rhs) return IsDistinctFrom(t.parent, rhs)
} }
func (t *timestampInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs TimestampExpression) BoolExpression { func (t *timestampInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs TimestampExpression) BoolExpression {
return isNotDistinctFrom(t.parent, rhs) return IsNotDistinctFrom(t.parent, rhs)
} }
func (t *timestampInterfaceImpl) LT(rhs TimestampExpression) BoolExpression { func (t *timestampInterfaceImpl) LT(rhs TimestampExpression) BoolExpression {
return lt(t.parent, rhs) return Lt(t.parent, rhs)
} }
func (t *timestampInterfaceImpl) LT_EQ(rhs TimestampExpression) BoolExpression { func (t *timestampInterfaceImpl) LT_EQ(rhs TimestampExpression) BoolExpression {
return ltEq(t.parent, rhs) return LtEq(t.parent, rhs)
} }
func (t *timestampInterfaceImpl) GT(rhs TimestampExpression) BoolExpression { func (t *timestampInterfaceImpl) GT(rhs TimestampExpression) BoolExpression {
return gt(t.parent, rhs) return Gt(t.parent, rhs)
} }
func (t *timestampInterfaceImpl) GT_EQ(rhs TimestampExpression) BoolExpression { func (t *timestampInterfaceImpl) GT_EQ(rhs TimestampExpression) BoolExpression {
return gtEq(t.parent, rhs) return GtEq(t.parent, rhs)
}
func (t *timestampInterfaceImpl) ADD(rhs Interval) TimestampExpression {
return TimestampExp(Add(t.parent, rhs))
}
func (t *timestampInterfaceImpl) SUB(rhs Interval) TimestampExpression {
return TimestampExp(Sub(t.parent, rhs))
} }
//------------------------------------------------- //-------------------------------------------------

View file

@ -53,3 +53,11 @@ func TestTimestampExp(t *testing.T) {
assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp), assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp),
"(table1.col_float < $1)", "2000-01-31 10:20:00.003") "(table1.col_float < $1)", "2000-01-31 10:20:00.003")
} }
func TestTimestampArithmetic(t *testing.T) {
timestamp := Timestamp(2000, 1, 1, 0, 0, 0)
assertClauseDebugSerialize(t, table1ColTimestamp.ADD(NewInterval(String("1 HOUR"))).EQ(timestamp),
"((table1.col_timestamp + INTERVAL '1 HOUR') = '2000-01-01 00:00:00')")
assertClauseDebugSerialize(t, table1ColTimestamp.SUB(NewInterval(String("1 HOUR"))).EQ(timestamp),
"((table1.col_timestamp - INTERVAL '1 HOUR') = '2000-01-01 00:00:00')")
}

View file

@ -13,6 +13,9 @@ type TimestampzExpression interface {
LT_EQ(rhs TimestampzExpression) BoolExpression LT_EQ(rhs TimestampzExpression) BoolExpression
GT(rhs TimestampzExpression) BoolExpression GT(rhs TimestampzExpression) BoolExpression
GT_EQ(rhs TimestampzExpression) BoolExpression GT_EQ(rhs TimestampzExpression) BoolExpression
ADD(rhs Interval) TimestampzExpression
SUB(rhs Interval) TimestampzExpression
} }
type timestampzInterfaceImpl struct { type timestampzInterfaceImpl struct {
@ -20,44 +23,43 @@ type timestampzInterfaceImpl struct {
} }
func (t *timestampzInterfaceImpl) EQ(rhs TimestampzExpression) BoolExpression { func (t *timestampzInterfaceImpl) EQ(rhs TimestampzExpression) BoolExpression {
return eq(t.parent, rhs) return Eq(t.parent, rhs)
} }
func (t *timestampzInterfaceImpl) NOT_EQ(rhs TimestampzExpression) BoolExpression { func (t *timestampzInterfaceImpl) NOT_EQ(rhs TimestampzExpression) BoolExpression {
return notEq(t.parent, rhs) return NotEq(t.parent, rhs)
} }
func (t *timestampzInterfaceImpl) IS_DISTINCT_FROM(rhs TimestampzExpression) BoolExpression { func (t *timestampzInterfaceImpl) IS_DISTINCT_FROM(rhs TimestampzExpression) BoolExpression {
return isDistinctFrom(t.parent, rhs) return IsDistinctFrom(t.parent, rhs)
} }
func (t *timestampzInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs TimestampzExpression) BoolExpression { func (t *timestampzInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs TimestampzExpression) BoolExpression {
return isNotDistinctFrom(t.parent, rhs) return IsNotDistinctFrom(t.parent, rhs)
} }
func (t *timestampzInterfaceImpl) LT(rhs TimestampzExpression) BoolExpression { func (t *timestampzInterfaceImpl) LT(rhs TimestampzExpression) BoolExpression {
return lt(t.parent, rhs) return Lt(t.parent, rhs)
} }
func (t *timestampzInterfaceImpl) LT_EQ(rhs TimestampzExpression) BoolExpression { func (t *timestampzInterfaceImpl) LT_EQ(rhs TimestampzExpression) BoolExpression {
return ltEq(t.parent, rhs) return LtEq(t.parent, rhs)
} }
func (t *timestampzInterfaceImpl) GT(rhs TimestampzExpression) BoolExpression { func (t *timestampzInterfaceImpl) GT(rhs TimestampzExpression) BoolExpression {
return gt(t.parent, rhs) return Gt(t.parent, rhs)
} }
func (t *timestampzInterfaceImpl) GT_EQ(rhs TimestampzExpression) BoolExpression { func (t *timestampzInterfaceImpl) GT_EQ(rhs TimestampzExpression) BoolExpression {
return gtEq(t.parent, rhs) return GtEq(t.parent, rhs)
} }
//---------------------------------------------------// func (t *timestampzInterfaceImpl) ADD(rhs Interval) TimestampzExpression {
return TimestampzExp(Add(t.parent, rhs))
}
type prefixTimestampzOperator struct { func (t *timestampzInterfaceImpl) SUB(rhs Interval) TimestampzExpression {
expressionInterfaceImpl return TimestampzExp(Sub(t.parent, rhs))
timestampzInterfaceImpl
prefixOpExpression
} }
//------------------------------------------------- //-------------------------------------------------

View file

@ -53,3 +53,11 @@ func TestTimestampzExp(t *testing.T) {
assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz), assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz),
"(table1.col_float < $1)", "2000-01-31 10:20:05.000023 +200") "(table1.col_float < $1)", "2000-01-31 10:20:05.000023 +200")
} }
func TestTimestampzArithmetic(t *testing.T) {
timestampz := Timestampz(2000, 1, 1, 0, 0, 0, 100, "UTC")
assertClauseDebugSerialize(t, table1ColTimestampz.ADD(NewInterval(String("1 HOUR"))).EQ(timestampz),
"((table1.col_timestampz + INTERVAL '1 HOUR') = '2000-01-01 00:00:00.0000001 UTC')")
assertClauseDebugSerialize(t, table1ColTimestampz.SUB(NewInterval(String("1 HOUR"))).EQ(timestampz),
"((table1.col_timestampz - INTERVAL '1 HOUR') = '2000-01-01 00:00:00.0000001 UTC')")
}

View file

@ -4,23 +4,18 @@ package jet
type TimezExpression interface { type TimezExpression interface {
Expression Expression
//EQ
EQ(rhs TimezExpression) BoolExpression EQ(rhs TimezExpression) BoolExpression
//NOT_EQ
NOT_EQ(rhs TimezExpression) BoolExpression NOT_EQ(rhs TimezExpression) BoolExpression
//IS_DISTINCT_FROM
IS_DISTINCT_FROM(rhs TimezExpression) BoolExpression IS_DISTINCT_FROM(rhs TimezExpression) BoolExpression
//IS_NOT_DISTINCT_FROM
IS_NOT_DISTINCT_FROM(rhs TimezExpression) BoolExpression IS_NOT_DISTINCT_FROM(rhs TimezExpression) BoolExpression
//LT
LT(rhs TimezExpression) BoolExpression LT(rhs TimezExpression) BoolExpression
//LT_EQ
LT_EQ(rhs TimezExpression) BoolExpression LT_EQ(rhs TimezExpression) BoolExpression
//GT
GT(rhs TimezExpression) BoolExpression GT(rhs TimezExpression) BoolExpression
//GT_EQ
GT_EQ(rhs TimezExpression) BoolExpression GT_EQ(rhs TimezExpression) BoolExpression
ADD(rhs Interval) TimezExpression
SUB(rhs Interval) TimezExpression
} }
type timezInterfaceImpl struct { type timezInterfaceImpl struct {
@ -28,54 +23,44 @@ type timezInterfaceImpl struct {
} }
func (t *timezInterfaceImpl) EQ(rhs TimezExpression) BoolExpression { func (t *timezInterfaceImpl) EQ(rhs TimezExpression) BoolExpression {
return eq(t.parent, rhs) return Eq(t.parent, rhs)
} }
func (t *timezInterfaceImpl) NOT_EQ(rhs TimezExpression) BoolExpression { func (t *timezInterfaceImpl) NOT_EQ(rhs TimezExpression) BoolExpression {
return notEq(t.parent, rhs) return NotEq(t.parent, rhs)
} }
func (t *timezInterfaceImpl) IS_DISTINCT_FROM(rhs TimezExpression) BoolExpression { func (t *timezInterfaceImpl) IS_DISTINCT_FROM(rhs TimezExpression) BoolExpression {
return isDistinctFrom(t.parent, rhs) return IsDistinctFrom(t.parent, rhs)
} }
func (t *timezInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs TimezExpression) BoolExpression { func (t *timezInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs TimezExpression) BoolExpression {
return isNotDistinctFrom(t.parent, rhs) return IsNotDistinctFrom(t.parent, rhs)
} }
func (t *timezInterfaceImpl) LT(rhs TimezExpression) BoolExpression { func (t *timezInterfaceImpl) LT(rhs TimezExpression) BoolExpression {
return lt(t.parent, rhs) return Lt(t.parent, rhs)
} }
func (t *timezInterfaceImpl) LT_EQ(rhs TimezExpression) BoolExpression { func (t *timezInterfaceImpl) LT_EQ(rhs TimezExpression) BoolExpression {
return ltEq(t.parent, rhs) return LtEq(t.parent, rhs)
} }
func (t *timezInterfaceImpl) GT(rhs TimezExpression) BoolExpression { func (t *timezInterfaceImpl) GT(rhs TimezExpression) BoolExpression {
return gt(t.parent, rhs) return Gt(t.parent, rhs)
} }
func (t *timezInterfaceImpl) GT_EQ(rhs TimezExpression) BoolExpression { func (t *timezInterfaceImpl) GT_EQ(rhs TimezExpression) BoolExpression {
return gtEq(t.parent, rhs) return GtEq(t.parent, rhs)
} }
//---------------------------------------------------// func (t *timezInterfaceImpl) ADD(rhs Interval) TimezExpression {
type prefixTimezExpression struct { return TimezExp(Add(t.parent, rhs))
expressionInterfaceImpl
timezInterfaceImpl
prefixOpExpression
} }
//func newPrefixTimezExpression(operator string, expression Expression) TimezExpression { func (t *timezInterfaceImpl) SUB(rhs Interval) TimezExpression {
// timeExpr := prefixTimezExpression{} return TimezExp(Sub(t.parent, rhs))
// timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) }
//
// timeExpr.expressionInterfaceImpl.parent = &timeExpr
// timeExpr.timezInterfaceImpl.parent = &timeExpr
//
// return &timeExpr
//}
//---------------------------------------------------// //---------------------------------------------------//

View file

@ -1,6 +1,8 @@
package jet package jet
import "testing" import (
"testing"
)
var timezVar = Timez(10, 20, 0, 0, "+4:00") var timezVar = Timez(10, 20, 0, 0, "+4:00")
@ -49,3 +51,11 @@ func TestTimezExp(t *testing.T) {
assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, "+4:00")), assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, "+4:00")),
"(table1.col_float < $1)", string("01:01:01.000000001 +4:00")) "(table1.col_float < $1)", string("01:01:01.000000001 +4:00"))
} }
func TestTimezArithmetic(t *testing.T) {
timez := Timez(0, 0, 0, 100, "UTC")
assertClauseDebugSerialize(t, table1ColTimez.ADD(NewInterval(String("1 HOUR"))).EQ(timez),
"((table1.col_timez + INTERVAL '1 HOUR') = '00:00:00.0000001 UTC')")
assertClauseDebugSerialize(t, table1ColTimez.SUB(NewInterval(String("1 HOUR"))).EQ(timez),
"((table1.col_timez - INTERVAL '1 HOUR') = '00:00:00.0000001 UTC')")
}

View file

@ -63,6 +63,17 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) {
} }
} }
// ExpressionListToSerializerList converts list of expressions to list of serializers
func ExpressionListToSerializerList(expressions []Expression) []Serializer {
var ret []Serializer
for _, expr := range expressions {
ret = append(ret, expr)
}
return ret
}
// ColumnListToProjectionList func // ColumnListToProjectionList func
func ColumnListToProjectionList(columns []ColumnExpression) []Projection { func ColumnListToProjectionList(columns []ColumnExpression) []Projection {
var ret []Projection var ret []Projection
@ -176,3 +187,23 @@ func UnwidColumnList(columns []Column) []Column {
return ret return ret
} }
// OptionalOrDefaultString will return first value from variable argument list str or
// defaultStr if variable argument list is empty
func OptionalOrDefaultString(defaultStr string, str ...string) string {
if len(str) > 0 {
return str[0]
}
return defaultStr
}
// OptionalOrDefaultExpression will return first value from variable argument list expression or
// defaultExpression if variable argument list is empty
func OptionalOrDefaultExpression(defaultExpression Expression, expression ...Expression) Expression {
if len(expression) > 0 {
return expression[0]
}
return defaultExpression
}

View file

@ -0,0 +1,19 @@
package jet
import (
"github.com/stretchr/testify/assert"
"testing"
)
func TestOptionalOrDefaultString(t *testing.T) {
assert.Equal(t, OptionalOrDefaultString("default"), "default")
assert.Equal(t, OptionalOrDefaultString("default", "optional"), "optional")
}
func TestOptionalOrDefaultExpression(t *testing.T) {
defaultExpression := table2ColFloat
optionalExpression := table1Col1
assert.Equal(t, OptionalOrDefaultExpression(defaultExpression), defaultExpression)
assert.Equal(t, OptionalOrDefaultExpression(defaultExpression, optionalExpression), optionalExpression)
}

View file

@ -7,21 +7,23 @@ import (
"github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/jet"
"github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/internal/utils"
"github.com/go-jet/jet/qrm" "github.com/go-jet/jet/qrm"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"testing" "testing"
"github.com/google/go-cmp/cmp"
) )
// AssertExec assert statement execution for successful execution and number of rows affected // AssertExec assert statement execution for successful execution and number of rows affected
func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) { func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) {
res, err := stmt.Exec(db) res, err := stmt.Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
rows, err := res.RowsAffected() rows, err := res.RowsAffected()
assert.NilError(t, err) assert.NoError(t, err)
if len(rowsAffected) > 0 { if len(rowsAffected) > 0 {
assert.Equal(t, rows, rowsAffected[0]) assert.Equal(t, rows, rowsAffected[0])
@ -49,7 +51,7 @@ func PrintJson(v interface{}) {
// AssertJSON check if data json output is the same as expectedJSON // AssertJSON check if data json output is the same as expectedJSON
func AssertJSON(t *testing.T, data interface{}, expectedJSON string) { func AssertJSON(t *testing.T, data interface{}, expectedJSON string) {
jsonData, err := json.MarshalIndent(data, "", "\t") jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, "\n"+string(jsonData)+"\n", expectedJSON) assert.Equal(t, "\n"+string(jsonData)+"\n", expectedJSON)
} }
@ -69,17 +71,17 @@ func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) {
filePath := getFullPath(testRelativePath) filePath := getFullPath(testRelativePath)
fileJSONData, err := ioutil.ReadFile(filePath) fileJSONData, err := ioutil.ReadFile(filePath)
assert.NilError(t, err) assert.NoError(t, err)
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
fileJSONData = bytes.Replace(fileJSONData, []byte("\r\n"), []byte("\n"), -1) fileJSONData = bytes.Replace(fileJSONData, []byte("\r\n"), []byte("\n"), -1)
} }
jsonData, err := json.MarshalIndent(data, "", "\t") jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NilError(t, err) assert.NoError(t, err)
assert.Assert(t, string(fileJSONData) == string(jsonData)) assert.True(t, string(fileJSONData) == string(jsonData))
//assert.DeepEqual(t, string(fileJSONData), string(jsonData)) //AssertDeepEqual(t, string(fileJSONData), string(jsonData))
} }
// AssertStatementSql check if statement Sql() is the same as expectedQuery and expectedArgs // AssertStatementSql check if statement Sql() is the same as expectedQuery and expectedArgs
@ -90,7 +92,7 @@ func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string,
if len(expectedArgs) == 0 { if len(expectedArgs) == 0 {
return return
} }
assert.DeepEqual(t, args, expectedArgs) AssertDeepEqual(t, args, expectedArgs)
} }
// AssertStatementSqlErr checks if statement Sql() panics with errorStr // AssertStatementSqlErr checks if statement Sql() panics with errorStr
@ -108,7 +110,7 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st
_, args := query.Sql() _, args := query.Sql()
if len(expectedArgs) > 0 { if len(expectedArgs) > 0 {
assert.DeepEqual(t, args, expectedArgs) AssertDeepEqual(t, args, expectedArgs)
} }
debuqSql := query.DebugSql() debuqSql := query.DebugSql()
@ -122,13 +124,35 @@ func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Seriali
//fmt.Println(out.Buff.String()) //fmt.Println(out.Buff.String())
assert.DeepEqual(t, out.Buff.String(), query) AssertDeepEqual(t, out.Buff.String(), query)
if len(args) > 0 { if len(args) > 0 {
assert.DeepEqual(t, out.Args, args) AssertDeepEqual(t, out.Args, args)
} }
} }
// AssertDebugClauseSerialize checks if clause serialize produces expected debug query and args
func AssertDebugClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) {
out := jet.SQLBuilder{Dialect: dialect, Debug: true}
jet.Serialize(clause, jet.SelectStatementType, &out)
AssertDeepEqual(t, out.Buff.String(), query)
if len(args) > 0 {
AssertDeepEqual(t, out.Args, args)
}
}
// AssertPanicErr checks if running a function fun produces a panic with errorStr string
func AssertPanicErr(t *testing.T, fun func(), errorStr string) {
defer func() {
r := recover()
assert.Equal(t, r, errorStr)
}()
fun()
}
// AssertClauseSerializeErr check if clause serialize panics with errString // AssertClauseSerializeErr check if clause serialize panics with errString
func AssertClauseSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) { func AssertClauseSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) {
defer func() { defer func() {
@ -145,8 +169,8 @@ func AssertProjectionSerialize(t *testing.T, dialect jet.Dialect, projection jet
out := jet.SQLBuilder{Dialect: dialect} out := jet.SQLBuilder{Dialect: dialect}
jet.SerializeForProjection(projection, jet.SelectStatementType, &out) jet.SerializeForProjection(projection, jet.SelectStatementType, &out)
assert.DeepEqual(t, out.Buff.String(), query) AssertDeepEqual(t, out.Buff.String(), query)
assert.DeepEqual(t, out.Args, args) AssertDeepEqual(t, out.Args, args)
} }
// AssertQueryPanicErr check if statement Query execution panics with error errString // AssertQueryPanicErr check if statement Query execution panics with error errString
@ -163,13 +187,13 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest inter
func AssertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) { func AssertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) {
enumFileData, err := ioutil.ReadFile(filePath) enumFileData, err := ioutil.ReadFile(filePath)
assert.NilError(t, err) assert.NoError(t, err)
beginIndex := bytes.Index(enumFileData, []byte(contentBegin)) beginIndex := bytes.Index(enumFileData, []byte(contentBegin))
//fmt.Println("-"+string(enumFileData[beginIndex:])+"-") //fmt.Println("-"+string(enumFileData[beginIndex:])+"-")
assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent) AssertDeepEqual(t, string(enumFileData[beginIndex:]), expectedContent)
} }
// AssertFileNamesEqual check if all filesInfos are contained in fileNames // AssertFileNamesEqual check if all filesInfos are contained in fileNames
@ -183,6 +207,11 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st
} }
for _, fileName := range fileNames { for _, fileName := range fileNames {
assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.") assert.True(t, fileNamesMap[fileName], fileName+" does not exist.")
} }
} }
// AssertDeepEqual checks if actual and expected objects are deeply equal.
func AssertDeepEqual(t *testing.T, actual, expected interface{}) {
assert.True(t, cmp.Equal(actual, expected))
}

View file

@ -9,6 +9,8 @@ import (
"path/filepath" "path/filepath"
"reflect" "reflect"
"strings" "strings"
"time"
"unicode"
) )
// ToGoIdentifier converts database to Go identifier. // ToGoIdentifier converts database to Go identifier.
@ -16,6 +18,16 @@ func ToGoIdentifier(databaseIdentifier string) string {
return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier)) return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier))
} }
// ToGoEnumValueIdentifier converts enum value name to Go identifier name.
func ToGoEnumValueIdentifier(enumName, enumValue string) string {
enumValueIdentifier := ToGoIdentifier(enumValue)
if !unicode.IsLetter([]rune(enumValueIdentifier)[0]) {
return ToGoIdentifier(enumName) + enumValueIdentifier
}
return enumValueIdentifier
}
// ToGoFileName converts database identifier to Go file name. // ToGoFileName converts database identifier to Go file name.
func ToGoFileName(databaseIdentifier string) string { func ToGoFileName(databaseIdentifier string) string {
return strings.ToLower(replaceInvalidChars(databaseIdentifier)) return strings.ToLower(replaceInvalidChars(databaseIdentifier))
@ -182,3 +194,22 @@ func StringSliceContains(strings []string, contains string) bool {
return false return false
} }
// ExtractDateTimeComponents extracts number of days, hours, minutes, seconds, microseconds from duration
func ExtractDateTimeComponents(duration time.Duration) (days, hours, minutes, seconds, microseconds int64) {
days = int64(duration / (24 * time.Hour))
reminder := duration % (24 * time.Hour)
hours = int64(reminder / time.Hour)
reminder = reminder % time.Hour
minutes = int64(reminder / time.Minute)
reminder = reminder % time.Minute
seconds = int64(reminder / time.Second)
reminder = reminder % time.Second
microseconds = int64(reminder / time.Microsecond)
return
}

View file

@ -2,7 +2,7 @@ package utils
import ( import (
"fmt" "fmt"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
) )
@ -25,6 +25,11 @@ func TestToGoIdentifier(t *testing.T) {
assert.Equal(t, ToGoIdentifier("My-Table"), "MyTable") assert.Equal(t, ToGoIdentifier("My-Table"), "MyTable")
} }
func TestToGoEnumValueIdentifier(t *testing.T) {
assert.Equal(t, ToGoEnumValueIdentifier("enum_name", "enum_value"), "EnumValue")
assert.Equal(t, ToGoEnumValueIdentifier("NumEnum", "100"), "NumEnum100")
}
func TestErrorCatchErr(t *testing.T) { func TestErrorCatchErr(t *testing.T) {
var err error var err error

View file

@ -5,14 +5,14 @@ import (
) )
func TestCAST(t *testing.T) { func TestCAST(t *testing.T) {
assertClauseSerialize(t, CAST(Float(11.22)).AS("bigint"), `CAST(? AS bigint)`) assertSerialize(t, CAST(Float(11.22)).AS("bigint"), `CAST(? AS bigint)`)
assertClauseSerialize(t, CAST(Int(22)).AS_CHAR(), `CAST(? AS CHAR)`) assertSerialize(t, CAST(Int(22)).AS_CHAR(), `CAST(? AS CHAR)`)
assertClauseSerialize(t, CAST(Int(22)).AS_CHAR(10), `CAST(? AS CHAR(10))`) assertSerialize(t, CAST(Int(22)).AS_CHAR(10), `CAST(? AS CHAR(10))`)
assertClauseSerialize(t, CAST(Int(22)).AS_DATE(), `CAST(? AS DATE)`) assertSerialize(t, CAST(Int(22)).AS_DATE(), `CAST(? AS DATE)`)
assertClauseSerialize(t, CAST(Int(22)).AS_DECIMAL(), `CAST(? AS DECIMAL)`) assertSerialize(t, CAST(Int(22)).AS_DECIMAL(), `CAST(? AS DECIMAL)`)
assertClauseSerialize(t, CAST(Int(22)).AS_TIME(), `CAST(? AS TIME)`) assertSerialize(t, CAST(Int(22)).AS_TIME(), `CAST(? AS TIME)`)
assertClauseSerialize(t, CAST(Int(22)).AS_DATETIME(), `CAST(? AS DATETIME)`) assertSerialize(t, CAST(Int(22)).AS_DATETIME(), `CAST(? AS DATETIME)`)
assertClauseSerialize(t, CAST(Int(22)).AS_SIGNED(), `CAST(? AS SIGNED)`) assertSerialize(t, CAST(Int(22)).AS_SIGNED(), `CAST(? AS SIGNED)`)
assertClauseSerialize(t, CAST(Int(22)).AS_UNSIGNED(), `CAST(? AS UNSIGNED)`) assertSerialize(t, CAST(Int(22)).AS_UNSIGNED(), `CAST(? AS UNSIGNED)`)
assertClauseSerialize(t, CAST(Int(22)).AS_BINARY(), `CAST(? AS BINARY)`) assertSerialize(t, CAST(Int(22)).AS_BINARY(), `CAST(? AS BINARY)`)
} }

View file

@ -8,7 +8,6 @@ import (
var Dialect = newDialect() var Dialect = newDialect()
func newDialect() jet.Dialect { func newDialect() jet.Dialect {
operatorSerializeOverrides := map[string]jet.SerializeOverride{} operatorSerializeOverrides := map[string]jet.SerializeOverride{}
operatorSerializeOverrides[jet.StringRegexpLikeOperator] = mysqlREGEXPLIKEoperator operatorSerializeOverrides[jet.StringRegexpLikeOperator] = mysqlREGEXPLIKEoperator
operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = mysqlNOTREGEXPLIKEoperator operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = mysqlNOTREGEXPLIKEoperator
@ -32,7 +31,7 @@ func newDialect() jet.Dialect {
return jet.NewDialect(mySQLDialectParams) return jet.NewDialect(mySQLDialectParams)
} }
func mysqlBitXor(expressions ...jet.Expression) jet.SerializeFunc { func mysqlBitXor(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator XOR") panic("jet: invalid number of expressions for operator XOR")
@ -49,7 +48,7 @@ func mysqlBitXor(expressions ...jet.Expression) jet.SerializeFunc {
} }
} }
func mysqlCONCAToperator(expressions ...jet.Expression) jet.SerializeFunc { func mysqlCONCAToperator(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator CONCAT") panic("jet: invalid number of expressions for operator CONCAT")
@ -66,7 +65,7 @@ func mysqlCONCAToperator(expressions ...jet.Expression) jet.SerializeFunc {
} }
} }
func mysqlDivision(expressions ...jet.Expression) jet.SerializeFunc { func mysqlDivision(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator DIV") panic("jet: invalid number of expressions for operator DIV")
@ -90,7 +89,7 @@ func mysqlDivision(expressions ...jet.Expression) jet.SerializeFunc {
} }
} }
func mysqlISNOTDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc { func mysqlISNOTDISTINCTFROM(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator") panic("jet: invalid number of expressions for operator")
@ -102,7 +101,7 @@ func mysqlISNOTDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc {
} }
} }
func mysqlISDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc { func mysqlISDISTINCTFROM(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
out.WriteString("NOT(") out.WriteString("NOT(")
mysqlISNOTDISTINCTFROM(expressions...)(statement, out, options...) mysqlISNOTDISTINCTFROM(expressions...)(statement, out, options...)
@ -110,7 +109,7 @@ func mysqlISDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc {
} }
} }
func mysqlREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { func mysqlREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator") panic("jet: invalid number of expressions for operator")
@ -136,7 +135,7 @@ func mysqlREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc {
} }
} }
func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator") panic("jet: invalid number of expressions for operator")

View file

@ -5,37 +5,37 @@ import (
) )
func TestBoolExpressionIS_DISTINCT_FROM(t *testing.T) { func TestBoolExpressionIS_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColBool.IS_DISTINCT_FROM(table2ColBool), "(NOT(table1.col_bool <=> table2.col_bool))") assertSerialize(t, table1ColBool.IS_DISTINCT_FROM(table2ColBool), "(NOT(table1.col_bool <=> table2.col_bool))")
assertClauseSerialize(t, table1ColBool.IS_DISTINCT_FROM(Bool(false)), "(NOT(table1.col_bool <=> ?))", false) assertSerialize(t, table1ColBool.IS_DISTINCT_FROM(Bool(false)), "(NOT(table1.col_bool <=> ?))", false)
} }
func TestBoolExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { func TestBoolExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertClauseSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(table2ColBool), "(table1.col_bool <=> table2.col_bool)") assertSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(table2ColBool), "(table1.col_bool <=> table2.col_bool)")
assertClauseSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(Bool(false)), "(table1.col_bool <=> ?)", false) assertSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(Bool(false)), "(table1.col_bool <=> ?)", false)
} }
func TestBoolLiteral(t *testing.T) { func TestBoolLiteral(t *testing.T) {
assertClauseSerialize(t, Bool(true), "?", true) assertSerialize(t, Bool(true), "?", true)
assertClauseSerialize(t, Bool(false), "?", false) assertSerialize(t, Bool(false), "?", false)
} }
func TestIntegerExpressionDIV(t *testing.T) { func TestIntegerExpressionDIV(t *testing.T) {
assertClauseSerialize(t, table1ColInt.DIV(table2ColInt), "(table1.col_int DIV table2.col_int)") assertSerialize(t, table1ColInt.DIV(table2ColInt), "(table1.col_int DIV table2.col_int)")
assertClauseSerialize(t, table1ColInt.DIV(Int(11)), "(table1.col_int DIV ?)", int64(11)) assertSerialize(t, table1ColInt.DIV(Int(11)), "(table1.col_int DIV ?)", int64(11))
} }
func TestIntExpressionPOW(t *testing.T) { func TestIntExpressionPOW(t *testing.T) {
assertClauseSerialize(t, table1ColInt.POW(table2ColInt), "POW(table1.col_int, table2.col_int)") assertSerialize(t, table1ColInt.POW(table2ColInt), "POW(table1.col_int, table2.col_int)")
assertClauseSerialize(t, table1ColInt.POW(Int(11)), "POW(table1.col_int, ?)", int64(11)) assertSerialize(t, table1ColInt.POW(Int(11)), "POW(table1.col_int, ?)", int64(11))
} }
func TestIntExpressionBIT_XOR(t *testing.T) { func TestIntExpressionBIT_XOR(t *testing.T) {
assertClauseSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "(table1.col_int ^ table2.col_int)") assertSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "(table1.col_int ^ table2.col_int)")
assertClauseSerialize(t, table1ColInt.BIT_XOR(Int(11)), "(table1.col_int ^ ?)", int64(11)) assertSerialize(t, table1ColInt.BIT_XOR(Int(11)), "(table1.col_int ^ ?)", int64(11))
} }
func TestExists(t *testing.T) { func TestExists(t *testing.T) {
assertClauseSerialize(t, EXISTS( assertSerialize(t, EXISTS(
table2. table2.
SELECT(Int(1)). SELECT(Int(1)).
WHERE(table1Col1.EQ(table2Col3)), WHERE(table1Col1.EQ(table2Col3)),
@ -48,15 +48,15 @@ func TestExists(t *testing.T) {
} }
func TestString_REGEXP_LIKE_operator(t *testing.T) { func TestString_REGEXP_LIKE_operator(t *testing.T) {
assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 REGEXP table2.col_str)") assertSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 REGEXP table2.col_str)")
assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 REGEXP ?)", "JOHN") assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 REGEXP ?)", "JOHN")
assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 REGEXP ?)", "JOHN") assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 REGEXP ?)", "JOHN")
assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 REGEXP BINARY ?)", "JOHN") assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 REGEXP BINARY ?)", "JOHN")
} }
func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) { func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) {
assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 NOT REGEXP table2.col_str)") assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 NOT REGEXP table2.col_str)")
assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 NOT REGEXP ?)", "JOHN") assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 NOT REGEXP ?)", "JOHN")
assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 NOT REGEXP ?)", "JOHN") assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 NOT REGEXP ?)", "JOHN")
assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 NOT REGEXP BINARY ?)", "JOHN") assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 NOT REGEXP BINARY ?)", "JOHN")
} }

View file

@ -1,7 +1,7 @@
package mysql package mysql
import ( import (
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )

195
mysql/interval.go Normal file
View file

@ -0,0 +1,195 @@
package mysql
import (
"fmt"
"regexp"
"time"
"github.com/go-jet/jet/internal/jet"
"github.com/go-jet/jet/internal/utils"
)
type unitType string
// List of interval unit types for MySQL
const (
MICROSECOND unitType = "MICROSECOND"
SECOND = "SECOND"
MINUTE = "MINUTE"
HOUR = "HOUR"
DAY = "DAY"
WEEK = "WEEK"
MONTH = "MONTH"
QUARTER = "QUARTER"
YEAR = "YEAR"
SECOND_MICROSECOND = "SECOND_MICROSECOND"
MINUTE_MICROSECOND = "MINUTE_MICROSECOND"
MINUTE_SECOND = "MINUTE_SECOND"
HOUR_MICROSECOND = "HOUR_MICROSECOND"
HOUR_SECOND = "HOUR_SECOND"
HOUR_MINUTE = "HOUR_MINUTE"
DAY_MICROSECOND = "DAY_MICROSECOND"
DAY_SECOND = "DAY_SECOND"
DAY_MINUTE = "DAY_MINUTE"
DAY_HOUR = "DAY_HOUR"
YEAR_MONTH = "YEAR_MONTH"
)
// Interval is representation of MySQL interval
type Interval = jet.Interval
// INTERVAL creates new temporal interval.
// In a case of MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR unit type
// value parameter should be number. For example: INTERVAL(1, DAY)
// In a case of other unit types, value should be string with appropriate format.
// For example: INTERVAL("10:08:50", HOUR_SECOND)
func INTERVAL(value interface{}, unitType unitType) Interval {
switch unitType {
case MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR:
if !isNumericType(value) {
panic("jet: INTERVAL invalid value type. Numeric type expected")
}
return INTERVALe(jet.FixedLiteral(value), unitType)
default:
strValue, ok := value.(string)
if !ok {
panic("jet: INTERNAL invalid value type. String type expected")
}
var regexp *regexp.Regexp
switch unitType {
case SECOND_MICROSECOND:
regexp = regexSecondMicrosecond
case MINUTE_MICROSECOND:
regexp = regexMinuteMicrosecond
case MINUTE_SECOND:
regexp = regexMinuteSecond
case HOUR_MICROSECOND:
regexp = regexHourMicrosecond
case HOUR_SECOND:
regexp = regexHourSecond
case HOUR_MINUTE:
regexp = regexHourMinute
case DAY_MICROSECOND:
regexp = regexDayMicrosecond
case DAY_SECOND:
regexp = regexDaySecond
case DAY_MINUTE:
regexp = regexDayMinute
case DAY_HOUR:
regexp = regexDayHour
case YEAR_MONTH:
regexp = regexYearMonth
default:
panic("jet: INTERVAL invalid unit type")
}
if !regexp.MatchString(strValue) {
panic("jet: INTERVAL invalid format")
}
return INTERVALe(jet.Literal(value), unitType)
}
}
// INTERVALe creates new temporal interval from expresion and unit type.
func INTERVALe(expr Expression, unitType unitType) Interval {
return jet.NewInterval(jet.ListSerializer{
Serializers: []jet.Serializer{expr, jet.Raw(string(unitType))},
Separator: " ",
})
}
// INTERVALd temoral interval from time.Duration
func INTERVALd(duration time.Duration) Interval {
var sign int64 = 1
if duration < 0 {
sign = -1
duration = -duration
}
days, hours, minutes, sec, microsec := utils.ExtractDateTimeComponents(duration)
if days != 0 {
switch {
case microsec > 0:
intervalStr := fmt.Sprintf("%d %02d:%02d:%02d.%06d", sign*days, hours, minutes, sec, microsec)
return INTERVAL(intervalStr, DAY_MICROSECOND)
case sec > 0:
intervalStr := fmt.Sprintf("%d %02d:%02d:%02d", sign*days, hours, minutes, sec)
return INTERVAL(intervalStr, DAY_SECOND)
case minutes > 0:
intervalStr := fmt.Sprintf("%d %02d:%02d", sign*days, hours, minutes)
return INTERVAL(intervalStr, DAY_MINUTE)
case hours > 0:
intervalStr := fmt.Sprintf("%d %02d", sign*days, hours)
return INTERVAL(intervalStr, DAY_HOUR)
default:
return INTERVAL(sign*days, DAY)
}
}
if hours != 0 {
switch {
case microsec > 0:
intervalStr := fmt.Sprintf("%02d:%02d:%02d.%06d", sign*hours, minutes, sec, microsec)
return INTERVAL(intervalStr, HOUR_MICROSECOND)
case sec > 0:
intervalStr := fmt.Sprintf("%02d:%02d:%02d", sign*hours, minutes, sec)
return INTERVAL(intervalStr, HOUR_SECOND)
case minutes > 0:
intervalStr := fmt.Sprintf("%02d:%02d", sign*hours, minutes)
return INTERVAL(intervalStr, HOUR_MINUTE)
default:
return INTERVAL(sign*hours, HOUR)
}
}
if minutes != 0 {
switch {
case microsec > 0:
intervalStr := fmt.Sprintf("%02d:%02d.%06d", sign*minutes, sec, microsec)
return INTERVAL(intervalStr, MINUTE_MICROSECOND)
case sec > 0:
intervalStr := fmt.Sprintf("%02d:%02d", sign*minutes, sec)
return INTERVAL(intervalStr, MINUTE_SECOND)
default:
return INTERVAL(sign*minutes, MINUTE)
}
}
if sec != 0 {
if microsec > 0 {
intervalStr := fmt.Sprintf("%02d.%06d", sign*sec, microsec)
return INTERVAL(intervalStr, SECOND_MICROSECOND)
}
return INTERVAL(sign*sec, SECOND)
}
return INTERVAL(sign*microsec, MICROSECOND)
}
var (
regexSecondMicrosecond = regexp.MustCompile(`^-?\d{1,2}\.\d+$`) //'SECONDS.MICROSECONDS'
regexMinuteMicrosecond = regexp.MustCompile(`^-?\d{1,2}:\d{2}\.\d+$`) //'MINUTE:SECONDS.MICROSECONDS'
regexMinuteSecond = regexp.MustCompile(`^-?\d{1,2}:\d{2}$`) //'MINUTE:SECONDS'
regexHourMicrosecond = regexp.MustCompile(`^-?\d{1,2}:\d{2}:\d{2}\.\d+$`) //'HOUR:MINUTE:SECONDS.MICROSECONDS'
regexHourSecond = regexp.MustCompile(`^-?\d{1,2}:\d{2}:\d{2}$`) //'HOUR:MINUTE:SECONDS'
regexHourMinute = regexp.MustCompile(`^-?\d{1,2}:\d{2}$`) //'HOUR:MINUTE'
regexDayMicrosecond = regexp.MustCompile(`^-?\d+ \d{1,2}:\d{2}:\d{2}.\d+$`) //'DAY HOUR:MINUTE:SECONDS'
regexDaySecond = regexp.MustCompile(`^-?\d+ \d{1,2}:\d{2}:\d{2}$`) //'DAY HOUR:MINUTE:SECONDS'
regexDayMinute = regexp.MustCompile(`^-?\d+ \d{1,2}:\d{2}$`) //'DAY HOUR:MINUTE'
regexDayHour = regexp.MustCompile(`^-?\d+ \d{1,2}$`) //'DAY HOUR:MINUTE'
regexYearMonth = regexp.MustCompile(`^-?\d+-\d{1,2}$`) //'YEAR-MONTH'
)
func isNumericType(value interface{}) bool {
switch value.(type) {
case float64, float32, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return true
default:
return false
}
}

99
mysql/interval_test.go Normal file
View file

@ -0,0 +1,99 @@
package mysql
import (
"testing"
"time"
)
func TestINTERVAL(t *testing.T) {
assertSerialize(t, INTERVAL("3-2", YEAR_MONTH), "INTERVAL ? YEAR_MONTH")
assertDebugSerialize(t, INTERVAL("3-2", YEAR_MONTH), "INTERVAL '3-2' YEAR_MONTH")
assertDebugSerialize(t, INTERVAL("-3-2", YEAR_MONTH), "INTERVAL '-3-2' YEAR_MONTH")
assertDebugSerialize(t, INTERVAL("10 25", DAY_HOUR), "INTERVAL '10 25' DAY_HOUR")
assertDebugSerialize(t, INTERVAL("-10 25", DAY_HOUR), "INTERVAL '-10 25' DAY_HOUR")
assertDebugSerialize(t, INTERVAL("10 25:15", DAY_MINUTE), "INTERVAL '10 25:15' DAY_MINUTE")
assertDebugSerialize(t, INTERVAL("-10 25:15", DAY_MINUTE), "INTERVAL '-10 25:15' DAY_MINUTE")
assertDebugSerialize(t, INTERVAL("10 25:15:08", DAY_SECOND), "INTERVAL '10 25:15:08' DAY_SECOND")
assertDebugSerialize(t, INTERVAL("-10 25:15:08", DAY_SECOND), "INTERVAL '-10 25:15:08' DAY_SECOND")
assertDebugSerialize(t, INTERVAL("10 25:15:08.000100", DAY_MICROSECOND), "INTERVAL '10 25:15:08.000100' DAY_MICROSECOND")
assertDebugSerialize(t, INTERVAL("-10 25:15:08.000100", DAY_MICROSECOND), "INTERVAL '-10 25:15:08.000100' DAY_MICROSECOND")
assertDebugSerialize(t, INTERVAL("15:08", HOUR_MINUTE), "INTERVAL '15:08' HOUR_MINUTE")
assertDebugSerialize(t, INTERVAL("-15:08", HOUR_MINUTE), "INTERVAL '-15:08' HOUR_MINUTE")
assertDebugSerialize(t, INTERVAL("15:08", HOUR_MINUTE), "INTERVAL '15:08' HOUR_MINUTE")
assertDebugSerialize(t, INTERVAL("-15:08", HOUR_MINUTE), "INTERVAL '-15:08' HOUR_MINUTE")
assertDebugSerialize(t, INTERVAL("15:08:03", HOUR_SECOND), "INTERVAL '15:08:03' HOUR_SECOND")
assertDebugSerialize(t, INTERVAL("-15:08:03", HOUR_SECOND), "INTERVAL '-15:08:03' HOUR_SECOND")
assertDebugSerialize(t, INTERVAL("25:15:08.000100", HOUR_MICROSECOND), "INTERVAL '25:15:08.000100' HOUR_MICROSECOND")
assertDebugSerialize(t, INTERVAL("-25:15:08.000100", HOUR_MICROSECOND), "INTERVAL '-25:15:08.000100' HOUR_MICROSECOND")
assertDebugSerialize(t, INTERVAL("08:03", MINUTE_SECOND), "INTERVAL '08:03' MINUTE_SECOND")
assertDebugSerialize(t, INTERVAL("-08:03", MINUTE_SECOND), "INTERVAL '-08:03' MINUTE_SECOND")
assertDebugSerialize(t, INTERVAL("15:08.000100", MINUTE_MICROSECOND), "INTERVAL '15:08.000100' MINUTE_MICROSECOND")
assertDebugSerialize(t, INTERVAL("-15:08.000100", MINUTE_MICROSECOND), "INTERVAL '-15:08.000100' MINUTE_MICROSECOND")
assertDebugSerialize(t, INTERVAL("08.000100", SECOND_MICROSECOND), "INTERVAL '08.000100' SECOND_MICROSECOND")
assertDebugSerialize(t, INTERVAL("-08.000100", SECOND_MICROSECOND), "INTERVAL '-08.000100' SECOND_MICROSECOND")
assertSerialize(t, INTERVAL(15, SECOND), "INTERVAL 15 SECOND")
assertSerialize(t, INTERVAL(1, MICROSECOND), "INTERVAL 1 MICROSECOND")
assertSerialize(t, INTERVAL(2, MINUTE), "INTERVAL 2 MINUTE")
assertSerialize(t, INTERVAL(3, HOUR), "INTERVAL 3 HOUR")
assertSerialize(t, INTERVAL(4, DAY), "INTERVAL 4 DAY")
assertSerialize(t, INTERVAL(5, MONTH), "INTERVAL 5 MONTH")
assertSerialize(t, INTERVAL(6, YEAR), "INTERVAL 6 YEAR")
assertSerialize(t, INTERVAL(-6, YEAR), "INTERVAL -6 YEAR")
assertSerialize(t, INTERVAL(uint(6), YEAR), "INTERVAL 6 YEAR")
assertSerialize(t, INTERVAL(int16(7), YEAR), "INTERVAL 7 YEAR")
assertSerialize(t, INTERVAL(3.5, YEAR), "INTERVAL 3.5 YEAR")
}
func TestINTERVAL_InvalidUnitType(t *testing.T) {
assertPanicErr(t, func() { INTERVAL("11", HOUR) }, "jet: INTERVAL invalid value type. Numeric type expected")
assertPanicErr(t, func() { INTERVAL("11", YEAR_MONTH) }, "jet: INTERVAL invalid format")
assertPanicErr(t, func() { INTERVAL("11+11", YEAR_MONTH) }, "jet: INTERVAL invalid format")
assertPanicErr(t, func() { INTERVAL(156.11, YEAR_MONTH) }, "jet: INTERNAL invalid value type. String type expected")
}
func TestINTERVALd(t *testing.T) {
assertDebugSerialize(t, INTERVALd(3*time.Microsecond), "INTERVAL 3 MICROSECOND")
assertDebugSerialize(t, INTERVALd(-1*time.Microsecond), "INTERVAL -1 MICROSECOND")
assertDebugSerialize(t, INTERVALd(3*time.Second), "INTERVAL 3 SECOND")
assertDebugSerialize(t, INTERVALd(3*time.Second+4*time.Microsecond), "INTERVAL '03.000004' SECOND_MICROSECOND")
assertDebugSerialize(t, INTERVALd(-1*time.Second), "INTERVAL -1 SECOND")
assertDebugSerialize(t, INTERVALd(3*time.Minute), "INTERVAL 3 MINUTE")
assertDebugSerialize(t, INTERVALd(3*time.Minute+4*time.Second), "INTERVAL '03:04' MINUTE_SECOND")
assertDebugSerialize(t, INTERVALd(3*time.Minute+4*time.Second+5*time.Microsecond), "INTERVAL '03:04.000005' MINUTE_MICROSECOND")
assertDebugSerialize(t, INTERVALd(-11*time.Minute), "INTERVAL -11 MINUTE")
assertDebugSerialize(t, INTERVALd(-11*time.Minute-22*time.Second), "INTERVAL '-11:22' MINUTE_SECOND")
assertDebugSerialize(t, INTERVALd(3*time.Hour), "INTERVAL 3 HOUR")
assertDebugSerialize(t, INTERVALd(3*time.Hour+4*time.Minute), "INTERVAL '03:04' HOUR_MINUTE")
assertDebugSerialize(t, INTERVALd(3*time.Hour+4*time.Minute+5*time.Second), "INTERVAL '03:04:05' HOUR_SECOND")
assertDebugSerialize(t, INTERVALd(3*time.Hour+4*time.Minute+5*time.Second+6*time.Millisecond), "INTERVAL '03:04:05.006000' HOUR_MICROSECOND")
assertDebugSerialize(t, INTERVALd(-11*time.Hour), "INTERVAL -11 HOUR")
assertDebugSerialize(t, INTERVALd(-11*time.Hour-22*time.Minute), "INTERVAL '-11:22' HOUR_MINUTE")
assertDebugSerialize(t, INTERVALd(3*24*time.Hour), "INTERVAL 3 DAY")
assertDebugSerialize(t, INTERVALd(3*24*time.Hour+4*time.Hour), "INTERVAL '3 04' DAY_HOUR")
assertDebugSerialize(t, INTERVALd(3*24*time.Hour+4*time.Hour+5*time.Minute), "INTERVAL '3 04:05' DAY_MINUTE")
assertDebugSerialize(t, INTERVALd(3*24*time.Hour+4*time.Hour+5*time.Minute+6*time.Second), "INTERVAL '3 04:05:06' DAY_SECOND")
assertDebugSerialize(t, INTERVALd(3*24*time.Hour+4*time.Hour+5*time.Minute+6*time.Second+7*time.Microsecond), "INTERVAL '3 04:05:06.000007' DAY_MICROSECOND")
assertDebugSerialize(t, INTERVALd(-11*24*time.Hour), "INTERVAL -11 DAY")
assertDebugSerialize(t, INTERVALd(1*time.Hour+2*time.Minute+3*time.Second+345*time.Microsecond), "INTERVAL '01:02:03.000345' HOUR_MICROSECOND")
assertDebugSerialize(t, INTERVALd(-1*(1*time.Hour+2*time.Minute+3*time.Second+345*time.Microsecond)), "INTERVAL '-1:02:03.000345' HOUR_MICROSECOND")
}
func TestINTERVALe(t *testing.T) {
assertSerialize(t, INTERVALe(table1ColFloat, MICROSECOND), "INTERVAL table1.col_float MICROSECOND")
assertSerialize(t, INTERVALe(table1ColFloat, SECOND), "INTERVAL table1.col_float SECOND")
assertSerialize(t, INTERVALe(table1ColFloat, MINUTE), "INTERVAL table1.col_float MINUTE")
assertSerialize(t, INTERVALe(table1ColFloat, HOUR), "INTERVAL table1.col_float HOUR")
assertSerialize(t, INTERVALe(table1ColFloat, DAY), "INTERVAL table1.col_float DAY")
assertSerialize(t, INTERVALe(table1ColFloat, WEEK), "INTERVAL table1.col_float WEEK")
assertSerialize(t, INTERVALe(table1ColFloat, MONTH), "INTERVAL table1.col_float MONTH")
assertSerialize(t, INTERVALe(table1ColFloat, QUARTER), "INTERVAL table1.col_float QUARTER")
assertSerialize(t, INTERVALe(table1ColFloat, YEAR), "INTERVAL table1.col_float YEAR")
}

View file

@ -6,37 +6,37 @@ import (
) )
func TestBool(t *testing.T) { func TestBool(t *testing.T) {
assertClauseSerialize(t, Bool(false), `?`, false) assertSerialize(t, Bool(false), `?`, false)
} }
func TestInt(t *testing.T) { func TestInt(t *testing.T) {
assertClauseSerialize(t, Int(11), `?`, int64(11)) assertSerialize(t, Int(11), `?`, int64(11))
} }
func TestFloat(t *testing.T) { func TestFloat(t *testing.T) {
assertClauseSerialize(t, Float(12.34), `?`, float64(12.34)) assertSerialize(t, Float(12.34), `?`, float64(12.34))
} }
func TestString(t *testing.T) { func TestString(t *testing.T) {
assertClauseSerialize(t, String("Some text"), `?`, "Some text") assertSerialize(t, String("Some text"), `?`, "Some text")
} }
func TestDate(t *testing.T) { func TestDate(t *testing.T) {
assertClauseSerialize(t, Date(2014, time.January, 2), `CAST(? AS DATE)`, "2014-01-02") assertSerialize(t, Date(2014, time.January, 2), `CAST(? AS DATE)`, "2014-01-02")
assertClauseSerialize(t, DateT(time.Now()), `CAST(? AS DATE)`) assertSerialize(t, DateT(time.Now()), `CAST(? AS DATE)`)
} }
func TestTime(t *testing.T) { func TestTime(t *testing.T) {
assertClauseSerialize(t, Time(10, 15, 30), `CAST(? AS TIME)`, "10:15:30") assertSerialize(t, Time(10, 15, 30), `CAST(? AS TIME)`, "10:15:30")
assertClauseSerialize(t, TimeT(time.Now()), `CAST(? AS TIME)`) assertSerialize(t, TimeT(time.Now()), `CAST(? AS TIME)`)
} }
func TestDateTime(t *testing.T) { func TestDateTime(t *testing.T) {
assertClauseSerialize(t, DateTime(2010, time.March, 30, 10, 15, 30), `CAST(? AS DATETIME)`, "2010-03-30 10:15:30") assertSerialize(t, DateTime(2010, time.March, 30, 10, 15, 30), `CAST(? AS DATETIME)`, "2010-03-30 10:15:30")
assertClauseSerialize(t, DateTimeT(time.Now()), `CAST(? AS DATETIME)`) assertSerialize(t, DateTimeT(time.Now()), `CAST(? AS DATETIME)`)
} }
func TestTimestamp(t *testing.T) { func TestTimestamp(t *testing.T) {
assertClauseSerialize(t, Timestamp(2010, time.March, 30, 10, 15, 30), `TIMESTAMP(?)`, "2010-03-30 10:15:30") assertSerialize(t, Timestamp(2010, time.March, 30, 10, 15, 30), `TIMESTAMP(?)`, "2010-03-30 10:15:30")
assertClauseSerialize(t, TimestampT(time.Now()), `TIMESTAMP(?)`) assertSerialize(t, TimestampT(time.Now()), `TIMESTAMP(?)`)
} }

View file

@ -12,17 +12,17 @@ func TestJoinNilInputs(t *testing.T) {
} }
func TestINNER_JOIN(t *testing.T) { func TestINNER_JOIN(t *testing.T) {
assertClauseSerialize(t, table1. assertSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)), INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1 `db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`) INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)). INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).
INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)), INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1 `db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int) INNER JOIN db.table2 ON (table1.col_int = table2.col_int)
INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`) INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(Int(1))). INNER_JOIN(table2, table1ColInt.EQ(Int(1))).
INNER_JOIN(table3, table1ColInt.EQ(Int(2))), INNER_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1 `db.table1
@ -31,17 +31,17 @@ INNER JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
} }
func TestLEFT_JOIN(t *testing.T) { func TestLEFT_JOIN(t *testing.T) {
assertClauseSerialize(t, table1. assertSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)), LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1 `db.table1
LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`) LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)). LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)).
LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)), LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1 `db.table1
LEFT JOIN db.table2 ON (table1.col_int = table2.col_int) LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)
LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`) LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(Int(1))). LEFT_JOIN(table2, table1ColInt.EQ(Int(1))).
LEFT_JOIN(table3, table1ColInt.EQ(Int(2))), LEFT_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1 `db.table1
@ -50,17 +50,17 @@ LEFT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
} }
func TestRIGHT_JOIN(t *testing.T) { func TestRIGHT_JOIN(t *testing.T) {
assertClauseSerialize(t, table1. assertSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)), RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1 `db.table1
RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`) RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)). RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)).
RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)), RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1 `db.table1
RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int) RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)
RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`) RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))). RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))).
RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))), RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1 `db.table1
@ -69,17 +69,17 @@ RIGHT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
} }
func TestFULL_JOIN(t *testing.T) { func TestFULL_JOIN(t *testing.T) {
assertClauseSerialize(t, table1. assertSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)), FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1 `db.table1
FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`) FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)). FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)).
FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)), FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1 `db.table1
FULL JOIN db.table2 ON (table1.col_int = table2.col_int) FULL JOIN db.table2 ON (table1.col_int = table2.col_int)
FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`) FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(Int(1))). FULL_JOIN(table2, table1ColInt.EQ(Int(1))).
FULL_JOIN(table3, table1ColInt.EQ(Int(2))), FULL_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1 `db.table1
@ -88,11 +88,11 @@ FULL JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
} }
func TestCROSS_JOIN(t *testing.T) { func TestCROSS_JOIN(t *testing.T) {
assertClauseSerialize(t, table1. assertSerialize(t, table1.
CROSS_JOIN(table2), CROSS_JOIN(table2),
`db.table1 `db.table1
CROSS JOIN db.table2`) CROSS JOIN db.table2`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
CROSS_JOIN(table2). CROSS_JOIN(table2).
CROSS_JOIN(table3), CROSS_JOIN(table3),
`db.table1 `db.table1

View file

@ -7,3 +7,6 @@ type Statement = jet.Statement
// Projection is interface for all projection types. Types that can be part of, for instance SELECT clause. // Projection is interface for all projection types. Types that can be part of, for instance SELECT clause.
type Projection = jet.Projection type Projection = jet.Projection
// ProjectionList can be used to create conditional constructed projection list.
type ProjectionList = jet.ProjectionList

View file

@ -58,10 +58,14 @@ var table3 = NewTable(
table3ColInt, table3ColInt,
table3StrCol) table3StrCol)
func assertClauseSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) {
testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) testutils.AssertClauseSerialize(t, Dialect, clause, query, args...)
} }
func assertDebugSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) {
testutils.AssertDebugClauseSerialize(t, Dialect, clause, query, args...)
}
func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) { func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) {
testutils.AssertClauseSerializeErr(t, Dialect, clause, errString) testutils.AssertClauseSerializeErr(t, Dialect, clause, errString)
} }
@ -70,5 +74,6 @@ func assertProjectionSerialize(t *testing.T, projection jet.Projection, query st
testutils.AssertProjectionSerialize(t, Dialect, projection, query, args...) testutils.AssertProjectionSerialize(t, Dialect, projection, query, args...)
} }
var assertPanicErr = testutils.AssertPanicErr
var assertStatementSql = testutils.AssertStatementSql var assertStatementSql = testutils.AssertStatementSql
var assertStatementSqlErr = testutils.AssertStatementSqlErr var assertStatementSqlErr = testutils.AssertStatementSqlErr

View file

@ -2,8 +2,9 @@ package postgres
import ( import (
"fmt" "fmt"
"github.com/go-jet/jet/internal/jet"
"strconv" "strconv"
"github.com/go-jet/jet/internal/jet"
) )
type cast interface { type cast interface {
@ -32,7 +33,7 @@ type cast interface {
AS_TIME() TimeExpression AS_TIME() TimeExpression
// Cast expression AS text type // Cast expression AS text type
AS_TEXT() StringExpression AS_TEXT() StringExpression
// Cast expression AS bytea type
AS_BYTEA() StringExpression AS_BYTEA() StringExpression
// Cast expression AS time with time timezone type // Cast expression AS time with time timezone type
AS_TIMEZ() TimezExpression AS_TIMEZ() TimezExpression
@ -40,6 +41,8 @@ type cast interface {
AS_TIMESTAMP() TimestampExpression AS_TIMESTAMP() TimestampExpression
// Cast expression AS timestamp with timezone type // Cast expression AS timestamp with timezone type
AS_TIMESTAMPZ() TimestampzExpression AS_TIMESTAMPZ() TimestampzExpression
// Cast expression AS interval type
AS_INTERVAL() IntervalExpression
} }
type castImpl struct { type castImpl struct {
@ -151,3 +154,8 @@ func (b *castImpl) AS_TIMESTAMP() TimestampExpression {
func (b *castImpl) AS_TIMESTAMPZ() TimestampzExpression { func (b *castImpl) AS_TIMESTAMPZ() TimestampzExpression {
return TimestampzExp(b.AS("timestamp with time zone")) return TimestampzExp(b.AS("timestamp with time zone"))
} }
// Cast expression AS interval type
func (b *castImpl) AS_INTERVAL() IntervalExpression {
return IntervalExp(b.AS("interval"))
}

View file

@ -5,60 +5,67 @@ import (
) )
func TestExpressionCAST_AS(t *testing.T) { func TestExpressionCAST_AS(t *testing.T) {
assertClauseSerialize(t, CAST(String("test")).AS("text"), `$1::text`, "test") assertSerialize(t, CAST(String("test")).AS("text"), `$1::text`, "test")
} }
func TestExpressionCAST_AS_BOOL(t *testing.T) { func TestExpressionCAST_AS_BOOL(t *testing.T) {
assertClauseSerialize(t, CAST(Int(1)).AS_BOOL(), "$1::boolean", int64(1)) assertSerialize(t, CAST(Int(1)).AS_BOOL(), "$1::boolean", int64(1))
assertClauseSerialize(t, CAST(table2Col3).AS_BOOL(), "table2.col3::boolean") assertSerialize(t, CAST(table2Col3).AS_BOOL(), "table2.col3::boolean")
assertClauseSerialize(t, CAST(table2Col3.ADD(table2Col3)).AS_BOOL(), "(table2.col3 + table2.col3)::boolean") assertSerialize(t, CAST(table2Col3.ADD(table2Col3)).AS_BOOL(), "(table2.col3 + table2.col3)::boolean")
} }
func TestExpressionCAST_AS_SMALLINT(t *testing.T) { func TestExpressionCAST_AS_SMALLINT(t *testing.T) {
assertClauseSerialize(t, CAST(table2Col3).AS_SMALLINT(), "table2.col3::smallint") assertSerialize(t, CAST(table2Col3).AS_SMALLINT(), "table2.col3::smallint")
} }
func TestExpressionCAST_AS_INTEGER(t *testing.T) { func TestExpressionCAST_AS_INTEGER(t *testing.T) {
assertClauseSerialize(t, CAST(table2Col3).AS_INTEGER(), "table2.col3::integer") assertSerialize(t, CAST(table2Col3).AS_INTEGER(), "table2.col3::integer")
} }
func TestExpressionCAST_AS_BIGINT(t *testing.T) { func TestExpressionCAST_AS_BIGINT(t *testing.T) {
assertClauseSerialize(t, CAST(table2Col3).AS_BIGINT(), "table2.col3::bigint") assertSerialize(t, CAST(table2Col3).AS_BIGINT(), "table2.col3::bigint")
} }
func TestExpressionCAST_AS_NUMERIC(t *testing.T) { func TestExpressionCAST_AS_NUMERIC(t *testing.T) {
assertClauseSerialize(t, CAST(table2Col3).AS_NUMERIC(11, 11), "table2.col3::numeric(11, 11)") assertSerialize(t, CAST(table2Col3).AS_NUMERIC(11, 11), "table2.col3::numeric(11, 11)")
assertClauseSerialize(t, CAST(table2Col3).AS_NUMERIC(11), "table2.col3::numeric(11)") assertSerialize(t, CAST(table2Col3).AS_NUMERIC(11), "table2.col3::numeric(11)")
} }
func TestExpressionCAST_AS_REAL(t *testing.T) { func TestExpressionCAST_AS_REAL(t *testing.T) {
assertClauseSerialize(t, CAST(table2Col3).AS_REAL(), "table2.col3::real") assertSerialize(t, CAST(table2Col3).AS_REAL(), "table2.col3::real")
} }
func TestExpressionCAST_AS_DOUBLE(t *testing.T) { func TestExpressionCAST_AS_DOUBLE(t *testing.T) {
assertClauseSerialize(t, CAST(table2Col3).AS_DOUBLE(), "table2.col3::double precision") assertSerialize(t, CAST(table2Col3).AS_DOUBLE(), "table2.col3::double precision")
} }
func TestExpressionCAST_AS_TEXT(t *testing.T) { func TestExpressionCAST_AS_TEXT(t *testing.T) {
assertClauseSerialize(t, CAST(table2Col3).AS_TEXT(), "table2.col3::text") assertSerialize(t, CAST(table2Col3).AS_TEXT(), "table2.col3::text")
} }
func TestExpressionCAST_AS_DATE(t *testing.T) { func TestExpressionCAST_AS_DATE(t *testing.T) {
assertClauseSerialize(t, CAST(table2Col3).AS_DATE(), "table2.col3::date") assertSerialize(t, CAST(table2Col3).AS_DATE(), "table2.col3::date")
} }
func TestExpressionCAST_AS_TIME(t *testing.T) { func TestExpressionCAST_AS_TIME(t *testing.T) {
assertClauseSerialize(t, CAST(table2Col3).AS_TIME(), "table2.col3::time without time zone") assertSerialize(t, CAST(table2Col3).AS_TIME(), "table2.col3::time without time zone")
} }
func TestExpressionCAST_AS_TIMEZ(t *testing.T) { func TestExpressionCAST_AS_TIMEZ(t *testing.T) {
assertClauseSerialize(t, CAST(table2Col3).AS_TIMEZ(), "table2.col3::time with time zone") assertSerialize(t, CAST(table2Col3).AS_TIMEZ(), "table2.col3::time with time zone")
} }
func TestExpressionCAST_AS_TIMESTAMP(t *testing.T) { func TestExpressionCAST_AS_TIMESTAMP(t *testing.T) {
assertClauseSerialize(t, CAST(table2Col3).AS_TIMESTAMP(), "table2.col3::timestamp without time zone") assertSerialize(t, CAST(table2Col3).AS_TIMESTAMP(), "table2.col3::timestamp without time zone")
} }
func TestExpressionCAST_AS_TIMESTAMPZ(t *testing.T) { func TestExpressionCAST_AS_TIMESTAMPZ(t *testing.T) {
assertClauseSerialize(t, CAST(table2Col3).AS_TIMESTAMPZ(), "table2.col3::timestamp with time zone") assertSerialize(t, CAST(table2Col3).AS_TIMESTAMPZ(), "table2.col3::timestamp with time zone")
}
func TestExpressionCAST_AS_INTERVAL(t *testing.T) {
assertSerialize(t, CAST(table2ColTimez).AS_INTERVAL(), "table2.col_timez::interval")
assertSerialize(t, CAST(Time(20, 11, 10)).AS_INTERVAL(), "$1::time without time zone::interval", "20:11:10")
assertSerialize(t, table2ColDate.SUB(CAST(Time(20, 11, 10)).AS_INTERVAL()),
"(table2.col_date - $1::time without time zone::interval)", "20:11:10")
} }

View file

@ -1,6 +1,8 @@
package postgres package postgres
import "github.com/go-jet/jet/internal/jet" import (
"github.com/go-jet/jet/internal/jet"
)
// Column is common column interface for all types of columns. // Column is common column interface for all types of columns.
type Column = jet.ColumnExpression type Column = jet.ColumnExpression
@ -62,3 +64,34 @@ type ColumnTimestampz = jet.ColumnTimestampz
// TimestampzColumn creates named timestamp with time zone column. // TimestampzColumn creates named timestamp with time zone column.
var TimestampzColumn = jet.TimestampzColumn var TimestampzColumn = jet.TimestampzColumn
//------------------------------------------------------//
// ColumnInterval is interface of PostgreSQL interval columns.
type ColumnInterval interface {
IntervalExpression
jet.Column
From(subQuery SelectTable) ColumnInterval
}
type intervalColumnImpl struct {
jet.ColumnExpressionImpl
intervalInterfaceImpl
}
func (i *intervalColumnImpl) From(subQuery SelectTable) ColumnInterval {
newIntervalColumn := IntervalColumn(i.Name())
jet.SetTableName(newIntervalColumn, i.TableName())
jet.SetSubQuery(newIntervalColumn, subQuery)
return newIntervalColumn
}
// IntervalColumn creates named interval column.
func IntervalColumn(name string) ColumnInterval {
intervalColumn := &intervalColumnImpl{}
intervalColumn.ColumnExpressionImpl = jet.NewColumnImpl(name, "", intervalColumn)
intervalColumn.intervalInterfaceImpl.parent = intervalColumn
return intervalColumn
}

20
postgres/columns_test.go Normal file
View file

@ -0,0 +1,20 @@
package postgres
import (
"testing"
)
func TestNewIntervalColumn(t *testing.T) {
subQuery := SELECT(Int(1)).AsTable("sub_query")
subQueryIntervalColumn := IntervalColumn("col_interval").From(subQuery)
assertSerialize(t, subQueryIntervalColumn, `sub_query."col_interval"`)
assertSerialize(t, subQueryIntervalColumn.EQ(INTERVAL(2, HOUR, 10, MINUTE)),
`(sub_query."col_interval" = INTERVAL '2 HOUR 10 MINUTE')`)
assertProjectionSerialize(t, subQueryIntervalColumn, `sub_query."col_interval" AS "col_interval"`)
subQueryIntervalColumn2 := table1ColInterval.From(subQuery)
assertSerialize(t, subQueryIntervalColumn2, `sub_query."table1.col_interval"`)
assertSerialize(t, subQueryIntervalColumn2.EQ(INTERVAL(1, DAY)), `(sub_query."table1.col_interval" = INTERVAL '1 DAY')`)
assertProjectionSerialize(t, subQueryIntervalColumn2, `sub_query."table1.col_interval" AS "table1.col_interval"`)
}

View file

@ -24,12 +24,13 @@ func newDialect() jet.Dialect {
ArgumentPlaceholder: func(ord int) string { ArgumentPlaceholder: func(ord int) string {
return "$" + strconv.Itoa(ord) return "$" + strconv.Itoa(ord)
}, },
ReservedWords: reservedWords,
} }
return jet.NewDialect(dialectParams) return jet.NewDialect(dialectParams)
} }
func postgresCAST(expressions ...jet.Expression) jet.SerializeFunc { func postgresCAST(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator") panic("jet: invalid number of expressions for operator")
@ -54,7 +55,7 @@ func postgresCAST(expressions ...jet.Expression) jet.SerializeFunc {
} }
} }
func postgresREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { func postgresREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator") panic("jet: invalid number of expressions for operator")
@ -80,7 +81,7 @@ func postgresREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc
} }
} }
func postgresNOTREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { func postgresNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator") panic("jet: invalid number of expressions for operator")
@ -105,3 +106,83 @@ func postgresNOTREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeF
jet.Serialize(expressions[1], statement, out, options...) jet.Serialize(expressions[1], statement, out, options...)
} }
} }
var reservedWords = []string{
"ALL",
"ANALYSE",
"ANALYZE",
"AND",
"ANY",
"ARRAY",
"AS",
"ASC",
"ASYMMETRIC",
"BOTH",
"CASE",
"CAST",
"CHECK",
"COLLATE",
"COLUMN",
"CONSTRAINT",
"CREATE",
"CURRENT_CATALOG",
"CURRENT_DATE",
"CURRENT_ROLE",
"CURRENT_TIME",
"CURRENT_TIMESTAMP",
"CURRENT_USER",
"DEFAULT",
"DEFERRABLE",
"DESC",
"DISTINCT",
"DO",
"ELSE",
"END",
"EXCEPT",
"FALSE",
"FETCH",
"FOR",
"FOREIGN",
"FROM",
"GRANT",
"GROUP",
"HAVING",
"IN",
"INITIALLY",
"INTERSECT",
"INTO",
"LATERAL",
"LEADING",
"LIMIT",
"LOCALTIME",
"LOCALTIMESTAMP",
"NOT",
"NULL",
"OFFSET",
"ON",
"ONLY",
"OR",
"ORDER",
"PLACING",
"PRIMARY",
"REFERENCES",
"RETURNING",
"SELECT",
"SESSION_USER",
"SOME",
"SYMMETRIC",
"TABLE",
"THEN",
"TO",
"TRAILING",
"TRUE",
"UNION",
"UNIQUE",
"USER",
"USING",
"VARIADIC",
"WHEN",
"WHERE",
"WINDOW",
"WITH",
}

View file

@ -3,21 +3,21 @@ package postgres
import "testing" import "testing"
func TestString_REGEXP_LIKE_operator(t *testing.T) { func TestString_REGEXP_LIKE_operator(t *testing.T) {
assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 ~* table2.col_str)") assertSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 ~* table2.col_str)")
assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 ~* $1)", "JOHN") assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 ~* $1)", "JOHN")
assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 ~* $1)", "JOHN") assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 ~* $1)", "JOHN")
assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 ~ $1)", "JOHN") assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 ~ $1)", "JOHN")
} }
func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) { func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) {
assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 !~* table2.col_str)") assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 !~* table2.col_str)")
assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 !~* $1)", "JOHN") assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 !~* $1)", "JOHN")
assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 !~* $1)", "JOHN") assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 !~* $1)", "JOHN")
assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 !~ $1)", "JOHN") assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 !~ $1)", "JOHN")
} }
func TestExists(t *testing.T) { func TestExists(t *testing.T) {
assertClauseSerialize(t, EXISTS( assertSerialize(t, EXISTS(
table2. table2.
SELECT(Int(1)). SELECT(Int(1)).
WHERE(table1Col1.EQ(table2Col3)), WHERE(table1Col1.EQ(table2Col3)),
@ -27,17 +27,31 @@ func TestExists(t *testing.T) {
FROM db.table2 FROM db.table2
WHERE table1.col1 = table2.col3 WHERE table1.col1 = table2.col3
))`, int64(1)) ))`, int64(1))
assertSerialize(t, EXISTS(
SELECT(Int(1)),
).EQ(Bool(true)),
`((EXISTS (
SELECT $1
)) = $2)`, int64(1), true)
assertProjectionSerialize(t, EXISTS(
SELECT(Int(1)),
).AS("exists"),
`(EXISTS (
SELECT $1
)) AS "exists"`, int64(1))
} }
func TestIN(t *testing.T) { func TestIN(t *testing.T) {
assertClauseSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), assertSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)),
`($1 IN (( `($1 IN ((
SELECT table1.col1 AS "table1.col1" SELECT table1.col1 AS "table1.col1"
FROM db.table1 FROM db.table1
)))`, float64(1.11)) )))`, float64(1.11))
assertClauseSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), assertSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) IN (( `(ROW($1, table1.col1) IN ((
SELECT table2.col3 AS "table2.col3", SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1" table3.col1 AS "table3.col1"
@ -47,16 +61,34 @@ func TestIN(t *testing.T) {
func TestNOT_IN(t *testing.T) { func TestNOT_IN(t *testing.T) {
assertClauseSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), assertSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)),
`($1 NOT IN (( `($1 NOT IN ((
SELECT table1.col1 AS "table1.col1" SELECT table1.col1 AS "table1.col1"
FROM db.table1 FROM db.table1
)))`, float64(1.11)) )))`, float64(1.11))
assertClauseSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), assertSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)),
`(ROW($1, table1.col1) NOT IN (( `(ROW($1, table1.col1) NOT IN ((
SELECT table2.col3 AS "table2.col3", SELECT table2.col3 AS "table2.col3",
table3.col1 AS "table3.col1" table3.col1 AS "table3.col1"
FROM db.table2 FROM db.table2
)))`, int64(12)) )))`, int64(12))
} }
func TestReservedWordEscaped(t *testing.T) {
var table1ColUser = IntervalColumn("user")
var table1ColVariadic = IntervalColumn("VARIADIC")
var table1ColProcedure = IntervalColumn("procedure")
_ = NewTable(
"db",
"table1",
table1ColUser,
table1ColVariadic,
table1ColProcedure,
)
assertSerialize(t, table1ColUser, `table1."user"`)
assertSerialize(t, table1ColVariadic, `table1."VARIADIC"`)
assertSerialize(t, table1ColProcedure, `table1.procedure`)
}

View file

@ -12,6 +12,9 @@ type BoolExpression = jet.BoolExpression
// StringExpression interface // StringExpression interface
type StringExpression = jet.StringExpression type StringExpression = jet.StringExpression
// NumericExpression interface
type NumericExpression = jet.NumericExpression
// IntegerExpression interface // IntegerExpression interface
type IntegerExpression = jet.IntegerExpression type IntegerExpression = jet.IntegerExpression

View file

@ -1,7 +1,7 @@
package postgres package postgres
import ( import (
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )

View file

@ -0,0 +1,224 @@
package postgres
import (
"fmt"
"github.com/go-jet/jet/internal/jet"
"github.com/go-jet/jet/internal/utils"
"strconv"
"strings"
"time"
)
type quantityAndUnit = float64
// Interval unit types
const (
YEAR quantityAndUnit = 123456789 + iota
MONTH
WEEK
DAY
HOUR
MINUTE
SECOND
MILLISECOND
MICROSECOND
DECADE
CENTURY
MILLENNIUM
)
// IntervalExpression is representation of postgres INTERVAL
type IntervalExpression interface {
jet.IsInterval
jet.Expression
EQ(rhs IntervalExpression) BoolExpression
NOT_EQ(rhs IntervalExpression) BoolExpression
IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression
IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression
LT(rhs IntervalExpression) BoolExpression
LT_EQ(rhs IntervalExpression) BoolExpression
GT(rhs IntervalExpression) BoolExpression
GT_EQ(rhs IntervalExpression) BoolExpression
ADD(rhs IntervalExpression) IntervalExpression
SUB(rhs IntervalExpression) IntervalExpression
MUL(rhs NumericExpression) IntervalExpression
DIV(rhs NumericExpression) IntervalExpression
}
type intervalInterfaceImpl struct {
jet.IsIntervalImpl
parent IntervalExpression
}
func (i *intervalInterfaceImpl) EQ(rhs IntervalExpression) BoolExpression {
return jet.Eq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) NOT_EQ(rhs IntervalExpression) BoolExpression {
return jet.NotEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression {
return jet.IsDistinctFrom(i.parent, rhs)
}
func (i *intervalInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression {
return jet.IsNotDistinctFrom(i.parent, rhs)
}
func (i *intervalInterfaceImpl) LT(rhs IntervalExpression) BoolExpression {
return jet.Lt(i.parent, rhs)
}
func (i *intervalInterfaceImpl) LT_EQ(rhs IntervalExpression) BoolExpression {
return jet.LtEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) GT(rhs IntervalExpression) BoolExpression {
return jet.Gt(i.parent, rhs)
}
func (i *intervalInterfaceImpl) GT_EQ(rhs IntervalExpression) BoolExpression {
return jet.GtEq(i.parent, rhs)
}
func (i *intervalInterfaceImpl) ADD(rhs IntervalExpression) IntervalExpression {
return IntervalExp(jet.Add(i.parent, rhs))
}
func (i *intervalInterfaceImpl) SUB(rhs IntervalExpression) IntervalExpression {
return IntervalExp(jet.Sub(i.parent, rhs))
}
func (i *intervalInterfaceImpl) MUL(rhs NumericExpression) IntervalExpression {
return IntervalExp(jet.Mul(i.parent, rhs))
}
func (i *intervalInterfaceImpl) DIV(rhs NumericExpression) IntervalExpression {
return IntervalExp(jet.Div(i.parent, rhs))
}
type intervalExpression struct {
jet.Expression
intervalInterfaceImpl
}
// INTERVAL creates new interval expression from the list of quantity-unit pairs.
// For example: INTERVAL(1, DAY, 3, MINUTE)
func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression {
quantityAndUnitLen := len(quantityAndUnit)
if quantityAndUnitLen == 0 || quantityAndUnitLen%2 != 0 {
panic("jet: invalid number of quantity and unit fields")
}
fields := []string{}
for i := 0; i < len(quantityAndUnit); i += 2 {
quantity := strconv.FormatFloat(quantityAndUnit[i], 'f', -1, 64)
unitString := unitToString(quantityAndUnit[i+1])
fields = append(fields, quantity+" "+unitString)
}
intervalStr := fmt.Sprintf("INTERVAL '%s'", strings.Join(fields, " "))
newInterval := &intervalExpression{}
newInterval.Expression = jet.Raw(intervalStr, newInterval)
newInterval.intervalInterfaceImpl.parent = newInterval
return newInterval
}
// INTERVALd creates interval expression from time.Duration
func INTERVALd(duration time.Duration) IntervalExpression {
days, hours, minutes, seconds, microseconds := utils.ExtractDateTimeComponents(duration)
quantityAndUnits := []quantityAndUnit{}
if days > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(days))
quantityAndUnits = append(quantityAndUnits, DAY)
}
if hours > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(hours))
quantityAndUnits = append(quantityAndUnits, HOUR)
}
if minutes > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(minutes))
quantityAndUnits = append(quantityAndUnits, MINUTE)
}
if seconds > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(seconds))
quantityAndUnits = append(quantityAndUnits, SECOND)
}
if microseconds > 0 {
quantityAndUnits = append(quantityAndUnits, quantityAndUnit(microseconds))
quantityAndUnits = append(quantityAndUnits, MICROSECOND)
}
if len(quantityAndUnits) == 0 {
return INTERVAL(0, MICROSECOND)
}
return INTERVAL(quantityAndUnits...)
}
func unitToString(unit quantityAndUnit) string {
switch unit {
case YEAR:
return "YEAR"
case MONTH:
return "MONTH"
case WEEK:
return "WEEK"
case DAY:
return "DAY"
case HOUR:
return "HOUR"
case MINUTE:
return "MINUTE"
case SECOND:
return "SECOND"
case MILLISECOND:
return "MILLISECOND"
case MICROSECOND:
return "MICROSECOND"
case DECADE:
return "DECADE"
case CENTURY:
return "CENTURY"
case MILLENNIUM:
return "MILLENNIUM"
default:
panic("jet: invalid INTERVAL unit type")
}
}
//---------------------------------------------------//
type intervalWrapper struct {
intervalInterfaceImpl
Expression
}
func newIntervalExpressionWrap(expression Expression) IntervalExpression {
intervalWrap := &intervalWrapper{Expression: expression}
intervalWrap.intervalInterfaceImpl.parent = intervalWrap
return intervalWrap
}
// IntervalExp is interval expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as interval expression.
// Does not add sql cast to generated sql builder output.
func IntervalExp(expression Expression) IntervalExpression {
return newIntervalExpressionWrap(expression)
}

View file

@ -0,0 +1,84 @@
package postgres
import (
"testing"
"time"
)
func TestINTERVAL(t *testing.T) {
assertSerialize(t, INTERVAL(1, YEAR), "INTERVAL '1 YEAR'")
assertSerialize(t, INTERVAL(1, MONTH), "INTERVAL '1 MONTH'")
assertSerialize(t, INTERVAL(1, WEEK), "INTERVAL '1 WEEK'")
assertSerialize(t, INTERVAL(1, DAY), "INTERVAL '1 DAY'")
assertSerialize(t, INTERVAL(1, HOUR), "INTERVAL '1 HOUR'")
assertSerialize(t, INTERVAL(1, MINUTE), "INTERVAL '1 MINUTE'")
assertSerialize(t, INTERVAL(1, SECOND), "INTERVAL '1 SECOND'")
assertSerialize(t, INTERVAL(1, MILLISECOND), "INTERVAL '1 MILLISECOND'")
assertSerialize(t, INTERVAL(1, MICROSECOND), "INTERVAL '1 MICROSECOND'")
assertSerialize(t, INTERVAL(1, DECADE), "INTERVAL '1 DECADE'")
assertSerialize(t, INTERVAL(1, CENTURY), "INTERVAL '1 CENTURY'")
assertSerialize(t, INTERVAL(1, MILLENNIUM), "INTERVAL '1 MILLENNIUM'")
assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH), "INTERVAL '1 YEAR 10 MONTH'")
assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH, 20, DAY), "INTERVAL '1 YEAR 10 MONTH 20 DAY'")
assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH, 20, DAY, 3, HOUR), "INTERVAL '1 YEAR 10 MONTH 20 DAY 3 HOUR'")
assertSerialize(t, INTERVAL(1, YEAR).IS_NOT_NULL(), "INTERVAL '1 YEAR' IS NOT NULL")
assertProjectionSerialize(t, INTERVAL(1, YEAR).AS("one year"), `INTERVAL '1 YEAR' AS "one year"`)
f := 5.2
assertSerialize(t, INTERVAL(f, YEAR), "INTERVAL '5.2 YEAR'")
}
func TestINTERVALd(t *testing.T) {
assertSerialize(t, INTERVALd(0), "INTERVAL '0 MICROSECOND'")
assertSerialize(t, INTERVALd(1*time.Microsecond), "INTERVAL '1 MICROSECOND'")
assertSerialize(t, INTERVALd(1*time.Millisecond), "INTERVAL '1000 MICROSECOND'")
assertSerialize(t, INTERVALd(1*time.Second), "INTERVAL '1 SECOND'")
assertSerialize(t, INTERVALd(1*time.Minute), "INTERVAL '1 MINUTE'")
assertSerialize(t, INTERVALd(1*time.Hour), "INTERVAL '1 HOUR'")
assertSerialize(t, INTERVALd(24*time.Hour), "INTERVAL '1 DAY'")
assertSerialize(t, INTERVALd(24*time.Hour+2*time.Hour+3*time.Minute+4*time.Second+5*time.Microsecond),
"INTERVAL '1 DAY 2 HOUR 3 MINUTE 4 SECOND 5 MICROSECOND'")
}
func TestINTERVAL_InvalidParams(t *testing.T) {
assertPanicErr(t, func() { INTERVAL() }, "jet: invalid number of quantity and unit fields")
assertPanicErr(t, func() { INTERVAL(1) }, "jet: invalid number of quantity and unit fields")
assertPanicErr(t, func() { INTERVAL(1, 2) }, "jet: invalid INTERVAL unit type")
}
func TestDateTimeIntervalArithmetic(t *testing.T) {
assertSerialize(t, table2ColDate.ADD(INTERVAL(1, HOUR)), "(table2.col_date + INTERVAL '1 HOUR')")
assertSerialize(t, table2ColDate.SUB(INTERVAL(1, HOUR)), "(table2.col_date - INTERVAL '1 HOUR')")
assertSerialize(t, table2ColTime.ADD(INTERVAL(1, HOUR)), "(table2.col_time + INTERVAL '1 HOUR')")
assertSerialize(t, table2ColTime.SUB(INTERVAL(1, HOUR)), "(table2.col_time - INTERVAL '1 HOUR')")
assertSerialize(t, table2ColTimez.ADD(INTERVAL(1, HOUR)), "(table2.col_timez + INTERVAL '1 HOUR')")
assertSerialize(t, table2ColTimez.SUB(INTERVAL(1, HOUR)), "(table2.col_timez - INTERVAL '1 HOUR')")
assertSerialize(t, table2ColTimestamp.ADD(INTERVAL(1, HOUR)), "(table2.col_timestamp + INTERVAL '1 HOUR')")
assertSerialize(t, table2ColTimestamp.SUB(INTERVAL(1, HOUR)), "(table2.col_timestamp - INTERVAL '1 HOUR')")
assertSerialize(t, table2ColTimestampz.ADD(INTERVAL(1, HOUR)), "(table2.col_timestampz + INTERVAL '1 HOUR')")
assertSerialize(t, table2ColTimestampz.SUB(INTERVAL(1, HOUR)), "(table2.col_timestampz - INTERVAL '1 HOUR')")
}
func TestIntervalExpressionMethods(t *testing.T) {
assertSerialize(t, table1ColInterval.EQ(table2ColInterval), "(table1.col_interval = table2.col_interval)")
assertSerialize(t, table1ColInterval.EQ(INTERVAL(10, SECOND)), "(table1.col_interval = INTERVAL '10 SECOND')")
assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)), "(table1.col_interval = INTERVAL '11 MINUTE')")
assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)).EQ(Bool(false)),
"((table1.col_interval = INTERVAL '11 MINUTE') = $1)", false)
assertSerialize(t, table1ColInterval.NOT_EQ(table2ColInterval), "(table1.col_interval != table2.col_interval)")
assertSerialize(t, table1ColInterval.IS_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS DISTINCT FROM table2.col_interval)")
assertSerialize(t, table1ColInterval.IS_NOT_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS NOT DISTINCT FROM table2.col_interval)")
assertSerialize(t, table1ColInterval.LT(table2ColInterval), "(table1.col_interval < table2.col_interval)")
assertSerialize(t, table1ColInterval.LT_EQ(table2ColInterval), "(table1.col_interval <= table2.col_interval)")
assertSerialize(t, table1ColInterval.GT(table2ColInterval), "(table1.col_interval > table2.col_interval)")
assertSerialize(t, table1ColInterval.GT_EQ(table2ColInterval), "(table1.col_interval >= table2.col_interval)")
assertSerialize(t, table1ColInterval.ADD(table2ColInterval), "(table1.col_interval + table2.col_interval)")
assertSerialize(t, table1ColInterval.SUB(table2ColInterval), "(table1.col_interval - table2.col_interval)")
assertSerialize(t, table1ColInterval.MUL(table2ColInt), "(table1.col_interval * table2.col_int)")
assertSerialize(t, table1ColInterval.MUL(table2ColFloat), "(table1.col_interval * table2.col_float)")
assertSerialize(t, table1ColInterval.DIV(table2ColInt), "(table1.col_interval / table2.col_int)")
assertSerialize(t, table1ColInterval.DIV(table2ColFloat), "(table1.col_interval / table2.col_float)")
}

View file

@ -6,45 +6,45 @@ import (
) )
func TestBool(t *testing.T) { func TestBool(t *testing.T) {
assertClauseSerialize(t, Bool(false), `$1`, false) assertSerialize(t, Bool(false), `$1`, false)
} }
func TestInt(t *testing.T) { func TestInt(t *testing.T) {
assertClauseSerialize(t, Int(11), `$1`, int64(11)) assertSerialize(t, Int(11), `$1`, int64(11))
} }
func TestFloat(t *testing.T) { func TestFloat(t *testing.T) {
assertClauseSerialize(t, Float(12.34), `$1`, float64(12.34)) assertSerialize(t, Float(12.34), `$1`, float64(12.34))
} }
func TestString(t *testing.T) { func TestString(t *testing.T) {
assertClauseSerialize(t, String("Some text"), `$1`, "Some text") assertSerialize(t, String("Some text"), `$1`, "Some text")
} }
func TestDate(t *testing.T) { func TestDate(t *testing.T) {
assertClauseSerialize(t, Date(2014, time.January, 2), `$1::date`, "2014-01-02") assertSerialize(t, Date(2014, time.January, 2), `$1::date`, "2014-01-02")
assertClauseSerialize(t, DateT(time.Now()), `$1::date`) assertSerialize(t, DateT(time.Now()), `$1::date`)
} }
func TestTime(t *testing.T) { func TestTime(t *testing.T) {
assertClauseSerialize(t, Time(10, 15, 30), `$1::time without time zone`, "10:15:30") assertSerialize(t, Time(10, 15, 30), `$1::time without time zone`, "10:15:30")
assertClauseSerialize(t, TimeT(time.Now()), `$1::time without time zone`) assertSerialize(t, TimeT(time.Now()), `$1::time without time zone`)
} }
func TestTimez(t *testing.T) { func TestTimez(t *testing.T) {
assertClauseSerialize(t, Timez(10, 15, 30, 0, "UTC"), assertSerialize(t, Timez(10, 15, 30, 0, "UTC"),
`$1::time with time zone`, "10:15:30 UTC") `$1::time with time zone`, "10:15:30 UTC")
assertClauseSerialize(t, TimezT(time.Now()), `$1::time with time zone`) assertSerialize(t, TimezT(time.Now()), `$1::time with time zone`)
} }
func TestTimestamp(t *testing.T) { func TestTimestamp(t *testing.T) {
assertClauseSerialize(t, Timestamp(2010, time.March, 30, 10, 15, 30), assertSerialize(t, Timestamp(2010, time.March, 30, 10, 15, 30),
`$1::timestamp without time zone`, "2010-03-30 10:15:30") `$1::timestamp without time zone`, "2010-03-30 10:15:30")
assertClauseSerialize(t, TimestampT(time.Now()), `$1::timestamp without time zone`) assertSerialize(t, TimestampT(time.Now()), `$1::timestamp without time zone`)
} }
func TestTimestampz(t *testing.T) { func TestTimestampz(t *testing.T) {
assertClauseSerialize(t, Timestampz(2010, time.March, 30, 10, 15, 30, 0, "UTC"), assertSerialize(t, Timestampz(2010, time.March, 30, 10, 15, 30, 0, "UTC"),
`$1::timestamp with time zone`, "2010-03-30 10:15:30 UTC") `$1::timestamp with time zone`, "2010-03-30 10:15:30 UTC")
assertClauseSerialize(t, TimestampzT(time.Now()), `$1::timestamp with time zone`) assertSerialize(t, TimestampzT(time.Now()), `$1::timestamp with time zone`)
} }

View file

@ -12,17 +12,17 @@ func TestJoinNilInputs(t *testing.T) {
} }
func TestINNER_JOIN(t *testing.T) { func TestINNER_JOIN(t *testing.T) {
assertClauseSerialize(t, table1. assertSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)), INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1 `db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`) INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)). INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).
INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)), INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1 `db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int) INNER JOIN db.table2 ON (table1.col_int = table2.col_int)
INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`) INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(Int(1))). INNER_JOIN(table2, table1ColInt.EQ(Int(1))).
INNER_JOIN(table3, table1ColInt.EQ(Int(2))), INNER_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1 `db.table1
@ -31,17 +31,17 @@ INNER JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2))
} }
func TestLEFT_JOIN(t *testing.T) { func TestLEFT_JOIN(t *testing.T) {
assertClauseSerialize(t, table1. assertSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)), LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1 `db.table1
LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`) LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)). LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)).
LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)), LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1 `db.table1
LEFT JOIN db.table2 ON (table1.col_int = table2.col_int) LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)
LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`) LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(Int(1))). LEFT_JOIN(table2, table1ColInt.EQ(Int(1))).
LEFT_JOIN(table3, table1ColInt.EQ(Int(2))), LEFT_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1 `db.table1
@ -50,17 +50,17 @@ LEFT JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2))
} }
func TestRIGHT_JOIN(t *testing.T) { func TestRIGHT_JOIN(t *testing.T) {
assertClauseSerialize(t, table1. assertSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)), RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1 `db.table1
RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`) RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)). RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)).
RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)), RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1 `db.table1
RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int) RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)
RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`) RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))). RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))).
RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))), RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1 `db.table1
@ -69,17 +69,17 @@ RIGHT JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2))
} }
func TestFULL_JOIN(t *testing.T) { func TestFULL_JOIN(t *testing.T) {
assertClauseSerialize(t, table1. assertSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)), FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1 `db.table1
FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`) FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)). FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)).
FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)), FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1 `db.table1
FULL JOIN db.table2 ON (table1.col_int = table2.col_int) FULL JOIN db.table2 ON (table1.col_int = table2.col_int)
FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`) FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(Int(1))). FULL_JOIN(table2, table1ColInt.EQ(Int(1))).
FULL_JOIN(table3, table1ColInt.EQ(Int(2))), FULL_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1 `db.table1
@ -88,11 +88,11 @@ FULL JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2))
} }
func TestCROSS_JOIN(t *testing.T) { func TestCROSS_JOIN(t *testing.T) {
assertClauseSerialize(t, table1. assertSerialize(t, table1.
CROSS_JOIN(table2), CROSS_JOIN(table2),
`db.table1 `db.table1
CROSS JOIN db.table2`) CROSS JOIN db.table2`)
assertClauseSerialize(t, table1. assertSerialize(t, table1.
CROSS_JOIN(table2). CROSS_JOIN(table2).
CROSS_JOIN(table3), CROSS_JOIN(table3),
`db.table1 `db.table1

View file

@ -7,3 +7,6 @@ type Statement = jet.Statement
// Projection is interface for all projection types. Types that can be part of, for instance SELECT clause. // Projection is interface for all projection types. Types that can be part of, for instance SELECT clause.
type Projection = jet.Projection type Projection = jet.Projection
// ProjectionList can be used to create conditional constructed projection list.
type ProjectionList = jet.ProjectionList

View file

@ -1,9 +1,10 @@
package postgres package postgres
import ( import (
"testing"
"github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/jet"
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
"testing"
) )
var table1Col1 = IntegerColumn("col1") var table1Col1 = IntegerColumn("col1")
@ -16,6 +17,7 @@ var table1ColTimestamp = TimestampColumn("col_timestamp")
var table1ColTimestampz = TimestampzColumn("col_timestampz") var table1ColTimestampz = TimestampzColumn("col_timestampz")
var table1ColBool = BoolColumn("col_bool") var table1ColBool = BoolColumn("col_bool")
var table1ColDate = DateColumn("col_date") var table1ColDate = DateColumn("col_date")
var table1ColInterval = IntervalColumn("col_interval")
var table1 = NewTable( var table1 = NewTable(
"db", "db",
@ -30,6 +32,7 @@ var table1 = NewTable(
table1ColDate, table1ColDate,
table1ColTimestamp, table1ColTimestamp,
table1ColTimestampz, table1ColTimestampz,
table1ColInterval,
) )
var table2Col3 = IntegerColumn("col3") var table2Col3 = IntegerColumn("col3")
@ -43,6 +46,7 @@ var table2ColTimez = TimezColumn("col_timez")
var table2ColTimestamp = TimestampColumn("col_timestamp") var table2ColTimestamp = TimestampColumn("col_timestamp")
var table2ColTimestampz = TimestampzColumn("col_timestampz") var table2ColTimestampz = TimestampzColumn("col_timestampz")
var table2ColDate = DateColumn("col_date") var table2ColDate = DateColumn("col_date")
var table2ColInterval = IntervalColumn("col_interval")
var table2 = NewTable( var table2 = NewTable(
"db", "db",
@ -58,6 +62,7 @@ var table2 = NewTable(
table2ColDate, table2ColDate,
table2ColTimestamp, table2ColTimestamp,
table2ColTimestampz, table2ColTimestampz,
table2ColInterval,
) )
var table3Col1 = IntegerColumn("col1") var table3Col1 = IntegerColumn("col1")
@ -70,7 +75,7 @@ var table3 = NewTable(
table3ColInt, table3ColInt,
table3StrCol) table3StrCol)
func assertClauseSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) {
testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) testutils.AssertClauseSerialize(t, Dialect, clause, query, args...)
} }
@ -84,3 +89,4 @@ func assertProjectionSerialize(t *testing.T, projection jet.Projection, query st
var assertStatementSql = testutils.AssertStatementSql var assertStatementSql = testutils.AssertStatementSql
var assertStatementSqlErr = testutils.AssertStatementSqlErr var assertStatementSqlErr = testutils.AssertStatementSqlErr
var assertPanicErr = testutils.AssertPanicErr

View file

@ -2,7 +2,7 @@ package internal
import ( import (
"fmt" "fmt"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )
@ -10,10 +10,10 @@ import (
func TestNullByteArray(t *testing.T) { func TestNullByteArray(t *testing.T) {
var array NullByteArray var array NullByteArray
assert.NilError(t, array.Scan(nil)) assert.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false) assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan([]byte("bytea"))) assert.NoError(t, array.Scan([]byte("bytea")))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
assert.Equal(t, string(array.ByteArray), string([]byte("bytea"))) assert.Equal(t, string(array.ByteArray), string([]byte("bytea")))
@ -23,21 +23,21 @@ func TestNullByteArray(t *testing.T) {
func TestNullTime(t *testing.T) { func TestNullTime(t *testing.T) {
var array NullTime var array NullTime
assert.NilError(t, array.Scan(nil)) assert.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false) assert.Equal(t, array.Valid, false)
time := time.Now() time := time.Now()
assert.NilError(t, array.Scan(time)) assert.NoError(t, array.Scan(time))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ := array.Value() value, _ := array.Value()
assert.Equal(t, value, time) assert.Equal(t, value, time)
assert.NilError(t, array.Scan([]byte("13:10:11"))) assert.NoError(t, array.Scan([]byte("13:10:11")))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
assert.NilError(t, array.Scan("13:10:11")) assert.NoError(t, array.Scan("13:10:11"))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
@ -48,10 +48,10 @@ func TestNullTime(t *testing.T) {
func TestNullInt8(t *testing.T) { func TestNullInt8(t *testing.T) {
var array NullInt8 var array NullInt8
assert.NilError(t, array.Scan(nil)) assert.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false) assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan(int64(11))) assert.NoError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ := array.Value() value, _ := array.Value()
assert.Equal(t, value, int8(11)) assert.Equal(t, value, int8(11))
@ -62,25 +62,25 @@ func TestNullInt8(t *testing.T) {
func TestNullInt16(t *testing.T) { func TestNullInt16(t *testing.T) {
var array NullInt16 var array NullInt16
assert.NilError(t, array.Scan(nil)) assert.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false) assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan(int64(11))) assert.NoError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ := array.Value() value, _ := array.Value()
assert.Equal(t, value, int16(11)) assert.Equal(t, value, int16(11))
assert.NilError(t, array.Scan(int16(20))) assert.NoError(t, array.Scan(int16(20)))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int16(20)) assert.Equal(t, value, int16(20))
assert.NilError(t, array.Scan(int8(30))) assert.NoError(t, array.Scan(int8(30)))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int16(30)) assert.Equal(t, value, int16(30))
assert.NilError(t, array.Scan(uint8(30))) assert.NoError(t, array.Scan(uint8(30)))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int16(30)) assert.Equal(t, value, int16(30))
@ -91,35 +91,35 @@ func TestNullInt16(t *testing.T) {
func TestNullInt32(t *testing.T) { func TestNullInt32(t *testing.T) {
var array NullInt32 var array NullInt32
assert.NilError(t, array.Scan(nil)) assert.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false) assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan(int64(11))) assert.NoError(t, array.Scan(int64(11)))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ := array.Value() value, _ := array.Value()
assert.Equal(t, value, int32(11)) assert.Equal(t, value, int32(11))
assert.NilError(t, array.Scan(int32(32))) assert.NoError(t, array.Scan(int32(32)))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int32(32)) assert.Equal(t, value, int32(32))
assert.NilError(t, array.Scan(int16(20))) assert.NoError(t, array.Scan(int16(20)))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int32(20)) assert.Equal(t, value, int32(20))
assert.NilError(t, array.Scan(uint16(16))) assert.NoError(t, array.Scan(uint16(16)))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int32(16)) assert.Equal(t, value, int32(16))
assert.NilError(t, array.Scan(int8(30))) assert.NoError(t, array.Scan(int8(30)))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int32(30)) assert.Equal(t, value, int32(30))
assert.NilError(t, array.Scan(uint8(30))) assert.NoError(t, array.Scan(uint8(30)))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, int32(30)) assert.Equal(t, value, int32(30))
@ -130,15 +130,15 @@ func TestNullInt32(t *testing.T) {
func TestNullFloat32(t *testing.T) { func TestNullFloat32(t *testing.T) {
var array NullFloat32 var array NullFloat32
assert.NilError(t, array.Scan(nil)) assert.NoError(t, array.Scan(nil))
assert.Equal(t, array.Valid, false) assert.Equal(t, array.Valid, false)
assert.NilError(t, array.Scan(float64(64))) assert.NoError(t, array.Scan(float64(64)))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ := array.Value() value, _ := array.Value()
assert.Equal(t, value, float32(64)) assert.Equal(t, value, float32(64))
assert.NilError(t, array.Scan(float32(32))) assert.NoError(t, array.Scan(float32(32)))
assert.Equal(t, array.Valid, true) assert.Equal(t, array.Valid, true)
value, _ = array.Value() value, _ = array.Value()
assert.Equal(t, value, float32(32)) assert.Equal(t, value, float32(32))

View file

@ -2,28 +2,28 @@ package qrm
import ( import (
"github.com/google/uuid" "github.com/google/uuid"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"reflect" "reflect"
"testing" "testing"
"time" "time"
) )
func TestIsSimpleModelType(t *testing.T) { func TestIsSimpleModelType(t *testing.T) {
assert.Assert(t, isSimpleModelType(reflect.TypeOf(int8(11)))) assert.True(t, isSimpleModelType(reflect.TypeOf(int8(11))))
assert.Assert(t, isSimpleModelType(reflect.TypeOf(int16(11)))) assert.True(t, isSimpleModelType(reflect.TypeOf(int16(11))))
assert.Assert(t, isSimpleModelType(reflect.TypeOf(int32(11)))) assert.True(t, isSimpleModelType(reflect.TypeOf(int32(11))))
assert.Assert(t, isSimpleModelType(reflect.TypeOf(int64(11)))) assert.True(t, isSimpleModelType(reflect.TypeOf(int64(11))))
assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint8(11)))) assert.True(t, isSimpleModelType(reflect.TypeOf(uint8(11))))
assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint16(11)))) assert.True(t, isSimpleModelType(reflect.TypeOf(uint16(11))))
assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint32(11)))) assert.True(t, isSimpleModelType(reflect.TypeOf(uint32(11))))
assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint64(11)))) assert.True(t, isSimpleModelType(reflect.TypeOf(uint64(11))))
assert.Assert(t, isSimpleModelType(reflect.TypeOf(float32(123.46)))) assert.True(t, isSimpleModelType(reflect.TypeOf(float32(123.46))))
assert.Assert(t, isSimpleModelType(reflect.TypeOf(float64(123.46)))) assert.True(t, isSimpleModelType(reflect.TypeOf(float64(123.46))))
assert.Assert(t, isSimpleModelType(reflect.TypeOf([]byte("Text")))) assert.True(t, isSimpleModelType(reflect.TypeOf([]byte("Text"))))
assert.Assert(t, isSimpleModelType(reflect.TypeOf(time.Now()))) assert.True(t, isSimpleModelType(reflect.TypeOf(time.Now())))
assert.Assert(t, isSimpleModelType(reflect.TypeOf(uuid.New()))) assert.True(t, isSimpleModelType(reflect.TypeOf(uuid.New())))
complexModelType := struct { complexModelType := struct {
Field1 string Field1 string
@ -32,4 +32,6 @@ func TestIsSimpleModelType(t *testing.T) {
assert.Equal(t, isSimpleModelType(reflect.TypeOf(complexModelType)), false) assert.Equal(t, isSimpleModelType(reflect.TypeOf(complexModelType)), false)
assert.Equal(t, isSimpleModelType(reflect.TypeOf(&complexModelType)), false) assert.Equal(t, isSimpleModelType(reflect.TypeOf(&complexModelType)), false)
assert.Equal(t, isSimpleModelType(reflect.TypeOf([]string{"str"})), false)
assert.Equal(t, isSimpleModelType(reflect.TypeOf([]int{1, 2})), false)
} }

View file

@ -1,19 +1,20 @@
package mysql package mysql
import ( import (
"fmt" "testing"
"time"
"github.com/google/uuid"
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/view" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/view"
"github.com/go-jet/jet/tests/testdata/results/common" "github.com/go-jet/jet/tests/testdata/results/common"
"github.com/google/uuid"
"time"
. "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/mysql"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing"
) )
func TestAllTypes(t *testing.T) { func TestAllTypes(t *testing.T) {
@ -25,7 +26,7 @@ func TestAllTypes(t *testing.T) {
LIMIT(2). LIMIT(2).
Query(db, &dest) Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
@ -44,7 +45,7 @@ func TestAllTypesViewSelect(t *testing.T) {
dest := []AllTypesView{} dest := []AllTypesView{}
err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert
@ -73,10 +74,10 @@ func TestUUID(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Assert(t, dest.StrUUID != nil) assert.True(t, dest.StrUUID != nil)
assert.Assert(t, dest.UUID.String() != uuid.UUID{}.String()) assert.True(t, dest.UUID.String() != uuid.UUID{}.String())
assert.Assert(t, dest.StrUUID.String() != uuid.UUID{}.String()) assert.True(t, dest.StrUUID.String() != uuid.UUID{}.String())
assert.Equal(t, dest.StrUUID.String(), dest.BinUUID.String()) assert.Equal(t, dest.StrUUID.String(), dest.BinUUID.String())
} }
@ -118,7 +119,7 @@ LIMIT ?;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -209,7 +210,7 @@ FROM test_sample.all_types;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json") testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json")
} }
@ -306,7 +307,7 @@ LIMIT ?;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertJSONFile(t, dest, "./testdata/results/common/float_operators.json") testutils.AssertJSONFile(t, dest, "./testdata/results/common/float_operators.json")
} }
@ -443,7 +444,7 @@ LIMIT ?;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -506,20 +507,16 @@ func TestStringOperators(t *testing.T) {
REGEXP_LIKE(AllTypes.Text, String("aba"), "i"), REGEXP_LIKE(AllTypes.Text, String("aba"), "i"),
}...) }...)
} }
//_, args, _ := query.Sql()
//fmt.Println(query.Sql())
//fmt.Println(args[15])
query := SELECT(projectionList[0], projectionList[1:]...). query := SELECT(projectionList[0], projectionList[1:]...).
FROM(AllTypes) FROM(AllTypes)
fmt.Println(query.DebugSql()) //fmt.Println(query.DebugSql())
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC) var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
@ -555,32 +552,49 @@ func TestTimeExpressions(t *testing.T) {
AllTypes.Time.GT_EQ(AllTypes.Time), AllTypes.Time.GT_EQ(AllTypes.Time),
AllTypes.Time.GT_EQ(Time(14, 26, 36)), AllTypes.Time.GT_EQ(Time(14, 26, 36)),
AllTypes.Time.ADD(INTERVAL(10, MINUTE)),
AllTypes.Time.ADD(INTERVALe(AllTypes.Integer, MINUTE)),
AllTypes.Time.ADD(INTERVALd(3*time.Hour)),
AllTypes.Time.SUB(INTERVAL(20, MINUTE)),
AllTypes.Time.SUB(INTERVALe(AllTypes.SmallInt, MINUTE)),
AllTypes.Time.SUB(INTERVALd(3*time.Minute)),
AllTypes.Time.ADD(INTERVAL(20, MINUTE)).SUB(INTERVAL(11, HOUR)),
CURRENT_TIME(), CURRENT_TIME(),
CURRENT_TIME(3), CURRENT_TIME(3),
) )
//fmt.Println(query.Sql()) //fmt.Println(query.DebugSql())
testutils.AssertStatementSql(t, query, ` testutils.AssertDebugStatementSql(t, query, `
SELECT CAST(? AS TIME), SELECT CAST('20:34:58' AS TIME),
all_types.time = all_types.time, all_types.time = all_types.time,
all_types.time = CAST(? AS TIME), all_types.time = CAST('23:06:06' AS TIME),
all_types.time = CAST(? AS TIME), all_types.time = CAST('22:06:06.011' AS TIME),
all_types.time = CAST(? AS TIME), all_types.time = CAST('21:06:06.011111' AS TIME),
all_types.time_ptr != all_types.time, all_types.time_ptr != all_types.time,
all_types.time_ptr != CAST(? AS TIME), all_types.time_ptr != CAST('20:16:06' AS TIME),
NOT(all_types.time <=> all_types.time), NOT(all_types.time <=> all_types.time),
NOT(all_types.time <=> CAST(? AS TIME)), NOT(all_types.time <=> CAST('19:26:06' AS TIME)),
all_types.time <=> all_types.time, all_types.time <=> all_types.time,
all_types.time <=> CAST(? AS TIME), all_types.time <=> CAST('18:36:06' AS TIME),
all_types.time < all_types.time, all_types.time < all_types.time,
all_types.time < CAST(? AS TIME), all_types.time < CAST('17:46:06' AS TIME),
all_types.time <= all_types.time, all_types.time <= all_types.time,
all_types.time <= CAST(? AS TIME), all_types.time <= CAST('16:56:56' AS TIME),
all_types.time > all_types.time, all_types.time > all_types.time,
all_types.time > CAST(? AS TIME), all_types.time > CAST('15:16:46' AS TIME),
all_types.time >= all_types.time, all_types.time >= all_types.time,
all_types.time >= CAST(? AS TIME), all_types.time >= CAST('14:26:36' AS TIME),
all_types.time + INTERVAL 10 MINUTE,
all_types.time + INTERVAL all_types.integer MINUTE,
all_types.time + INTERVAL 3 HOUR,
all_types.time - INTERVAL 20 MINUTE,
all_types.time - INTERVAL all_types.small_int MINUTE,
all_types.time - INTERVAL 3 MINUTE,
(all_types.time + INTERVAL 20 MINUTE) - INTERVAL 11 HOUR,
CURRENT_TIME, CURRENT_TIME,
CURRENT_TIME(3) CURRENT_TIME(3)
FROM test_sample.all_types; FROM test_sample.all_types;
@ -590,7 +604,7 @@ FROM test_sample.all_types;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestDateExpressions(t *testing.T) { func TestDateExpressions(t *testing.T) {
@ -621,10 +635,18 @@ func TestDateExpressions(t *testing.T) {
AllTypes.Date.GT_EQ(AllTypes.Date), AllTypes.Date.GT_EQ(AllTypes.Date),
AllTypes.Date.GT_EQ(Date(2019, 2, 3)), AllTypes.Date.GT_EQ(Date(2019, 2, 3)),
AllTypes.Date.ADD(INTERVAL("10:20.000100", MINUTE_MICROSECOND)),
AllTypes.Date.ADD(INTERVALe(AllTypes.BigInt, MINUTE)),
AllTypes.Date.ADD(INTERVALd(15*time.Hour)),
AllTypes.Date.SUB(INTERVAL(20, MINUTE)),
AllTypes.Date.SUB(INTERVALe(AllTypes.SmallInt, MINUTE)),
AllTypes.Date.SUB(INTERVALd(3*time.Minute)),
CURRENT_DATE(), CURRENT_DATE(),
) )
//fmt.Println(query.Sql()) //fmt.Println(query.DebugSql())
testutils.AssertStatementSql(t, query, ` testutils.AssertStatementSql(t, query, `
SELECT CAST(? AS DATE), SELECT CAST(? AS DATE),
@ -644,6 +666,12 @@ SELECT CAST(? AS DATE),
all_types.date > CAST(? AS DATE), all_types.date > CAST(? AS DATE),
all_types.date >= all_types.date, all_types.date >= all_types.date,
all_types.date >= CAST(? AS DATE), all_types.date >= CAST(? AS DATE),
all_types.date + INTERVAL ? MINUTE_MICROSECOND,
all_types.date + INTERVAL all_types.big_int MINUTE,
all_types.date + INTERVAL 15 HOUR,
all_types.date - INTERVAL 20 MINUTE,
all_types.date - INTERVAL all_types.small_int MINUTE,
all_types.date - INTERVAL 3 MINUTE,
CURRENT_DATE CURRENT_DATE
FROM test_sample.all_types; FROM test_sample.all_types;
`) `)
@ -651,7 +679,7 @@ FROM test_sample.all_types;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestDateTimeExpressions(t *testing.T) { func TestDateTimeExpressions(t *testing.T) {
@ -683,11 +711,19 @@ func TestDateTimeExpressions(t *testing.T) {
AllTypes.DateTime.GT_EQ(AllTypes.DateTime), AllTypes.DateTime.GT_EQ(AllTypes.DateTime),
AllTypes.DateTime.GT_EQ(dateTime), AllTypes.DateTime.GT_EQ(dateTime),
AllTypes.DateTime.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)),
AllTypes.DateTime.ADD(INTERVALe(AllTypes.BigInt, HOUR)),
AllTypes.DateTime.ADD(INTERVALd(2*time.Hour)),
AllTypes.DateTime.SUB(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)),
AllTypes.DateTime.SUB(INTERVALe(AllTypes.IntegerPtr, HOUR)),
AllTypes.DateTime.SUB(INTERVALd(3*time.Hour)),
NOW(), NOW(),
NOW(1), NOW(1),
) )
//fmt.Println(query.DebugSql()) //Println(query.DebugSql())
testutils.AssertDebugStatementSql(t, query, ` testutils.AssertDebugStatementSql(t, query, `
SELECT all_types.date_time = all_types.date_time, SELECT all_types.date_time = all_types.date_time,
@ -706,6 +742,12 @@ SELECT all_types.date_time = all_types.date_time,
all_types.date_time > CAST('2019-06-06 10:02:46' AS DATETIME), all_types.date_time > CAST('2019-06-06 10:02:46' AS DATETIME),
all_types.date_time >= all_types.date_time, all_types.date_time >= all_types.date_time,
all_types.date_time >= CAST('2019-06-06 10:02:46' AS DATETIME), all_types.date_time >= CAST('2019-06-06 10:02:46' AS DATETIME),
all_types.date_time + INTERVAL '05:10:20.000100' HOUR_MICROSECOND,
all_types.date_time + INTERVAL all_types.big_int HOUR,
all_types.date_time + INTERVAL 2 HOUR,
all_types.date_time - INTERVAL '05:10:20.000100' HOUR_MICROSECOND,
all_types.date_time - INTERVAL all_types.integer_ptr HOUR,
all_types.date_time - INTERVAL 3 HOUR,
NOW(), NOW(),
NOW(1) NOW(1)
FROM test_sample.all_types; FROM test_sample.all_types;
@ -714,7 +756,7 @@ FROM test_sample.all_types;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestTimestampExpressions(t *testing.T) { func TestTimestampExpressions(t *testing.T) {
@ -746,6 +788,14 @@ func TestTimestampExpressions(t *testing.T) {
AllTypes.Timestamp.GT_EQ(AllTypes.Timestamp), AllTypes.Timestamp.GT_EQ(AllTypes.Timestamp),
AllTypes.Timestamp.GT_EQ(timestamp), AllTypes.Timestamp.GT_EQ(timestamp),
AllTypes.Timestamp.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)),
AllTypes.Timestamp.ADD(INTERVALe(AllTypes.BigInt, HOUR)),
AllTypes.Timestamp.ADD(INTERVALd(2*time.Hour)),
AllTypes.Timestamp.SUB(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)),
AllTypes.Timestamp.SUB(INTERVALe(AllTypes.IntegerPtr, HOUR)),
AllTypes.Timestamp.SUB(INTERVALd(3*time.Hour)),
CURRENT_TIMESTAMP(), CURRENT_TIMESTAMP(),
CURRENT_TIMESTAMP(2), CURRENT_TIMESTAMP(2),
) )
@ -769,6 +819,12 @@ SELECT all_types.timestamp = all_types.timestamp,
all_types.timestamp > TIMESTAMP('2019-06-06 10:02:46'), all_types.timestamp > TIMESTAMP('2019-06-06 10:02:46'),
all_types.timestamp >= all_types.timestamp, all_types.timestamp >= all_types.timestamp,
all_types.timestamp >= TIMESTAMP('2019-06-06 10:02:46'), all_types.timestamp >= TIMESTAMP('2019-06-06 10:02:46'),
all_types.timestamp + INTERVAL '05:10:20.000100' HOUR_MICROSECOND,
all_types.timestamp + INTERVAL all_types.big_int HOUR,
all_types.timestamp + INTERVAL 2 HOUR,
all_types.timestamp - INTERVAL '05:10:20.000100' HOUR_MICROSECOND,
all_types.timestamp - INTERVAL all_types.integer_ptr HOUR,
all_types.timestamp - INTERVAL 3 HOUR,
CURRENT_TIMESTAMP, CURRENT_TIMESTAMP,
CURRENT_TIMESTAMP(2) CURRENT_TIMESTAMP(2)
FROM test_sample.all_types; FROM test_sample.all_types;
@ -776,13 +832,13 @@ FROM test_sample.all_types;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestTimeLiterals(t *testing.T) { func TestTimeLiterals(t *testing.T) {
loc, err := time.LoadLocation("Europe/Berlin") loc, err := time.LoadLocation("Europe/Berlin")
assert.NilError(t, err) assert.NoError(t, err)
var timeT = time.Date(2009, 11, 17, 20, 34, 58, 351387237, loc) var timeT = time.Date(2009, 11, 17, 20, 34, 58, 351387237, loc)
@ -821,7 +877,7 @@ LIMIT ?;
} }
err = query.Query(db, &dest) err = query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -853,6 +909,60 @@ LIMIT ?;
} }
func TestINTERVAL(t *testing.T) {
query := SELECT(
Date(2000, 2, 10).ADD(INTERVAL(1, MICROSECOND)).
EQ(Timestamp(2000, 2, 10, 0, 0, 0, 1*time.Microsecond)),
Date(2000, 2, 10).SUB(INTERVAL(2, SECOND)),
Date(2000, 2, 10).ADD(INTERVAL(3, MINUTE)),
Date(2000, 2, 10).SUB(INTERVAL(4, HOUR)),
Date(2000, 2, 10).ADD(INTERVAL(5, DAY)),
Date(2000, 2, 10).SUB(INTERVAL(6, MONTH)),
Date(2000, 2, 10).ADD(INTERVAL(7, YEAR)),
Date(2000, 2, 10).ADD(INTERVAL(-7, YEAR)),
Date(2000, 2, 10).ADD(INTERVAL("20.0000100", SECOND_MICROSECOND)),
Date(2000, 2, 10).SUB(INTERVAL("02:20.0000100", MINUTE_MICROSECOND)),
Date(2000, 2, 10).SUB(INTERVAL("11:02:20.0000100", HOUR_MICROSECOND)),
Date(2000, 2, 10).SUB(INTERVAL("100 11:02:20.0000100", DAY_MICROSECOND)),
Date(2000, 2, 10).SUB(INTERVAL("11:02", MINUTE_SECOND)),
Date(2000, 2, 10).SUB(INTERVAL("11:02:20", HOUR_SECOND)),
Date(2000, 2, 10).SUB(INTERVAL("11:02", HOUR_MINUTE)),
Date(2000, 2, 10).SUB(INTERVAL("11 02:03:04", DAY_SECOND)),
Date(2000, 2, 10).SUB(INTERVAL("11 02:03", DAY_MINUTE)),
Date(2000, 2, 10).SUB(INTERVAL("11 2", DAY_HOUR)),
Date(2000, 2, 10).SUB(INTERVAL("2000-2", YEAR_MONTH)),
Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, MICROSECOND)),
Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, SECOND)),
Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, MINUTE)),
Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, HOUR)),
Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, DAY)),
Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, WEEK)),
Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, MONTH)),
Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, QUARTER)),
Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, YEAR)),
Date(2000, 2, 10).SUB(INTERVALd(3*time.Microsecond)),
Date(2000, 2, 10).SUB(INTERVALd(-3*time.Microsecond)),
Date(2000, 2, 10).SUB(INTERVALd(3*time.Second)),
Date(2000, 2, 10).SUB(INTERVALd(3*time.Second+4*time.Microsecond)),
Date(2000, 2, 10).SUB(INTERVALd(3*time.Minute+4*time.Second+5*time.Microsecond)),
Date(2000, 2, 10).SUB(INTERVALd(3*time.Hour+4*time.Minute+5*time.Second+6*time.Microsecond)),
Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour+3*time.Hour+4*time.Minute+5*time.Second+6*time.Microsecond)),
Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour+3*time.Hour+4*time.Minute+5*time.Second)),
Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour+3*time.Hour+4*time.Minute)),
Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour+3*time.Hour)),
Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour)),
Date(2000, 2, 10).SUB(INTERVALd(3*time.Hour)),
Date(2000, 2, 10).SUB(INTERVALd(1*time.Hour+2*time.Minute+3*time.Second+345*time.Microsecond)),
).FROM(AllTypes)
//fmt.Println(query.DebugSql())
err := query.Query(db, &struct{}{})
assert.NoError(t, err)
}
var allTypesJson = ` var allTypesJson = `
[ [
{ {
@ -985,3 +1095,53 @@ var allTypesJson = `
} }
] ]
` `
func TestReservedWord(t *testing.T) {
stmt := SELECT(User.AllColumns).
FROM(User)
// NOTE: A word that follows a period in a qualified name must be an identifier, so it
// need not be quoted even if it is reserved
testutils.AssertDebugStatementSql(t, stmt, `
SELECT user.column AS "user.column",
user.use AS "user.use",
user.ceil AS "user.ceil",
user.commit AS "user.commit",
user.create AS "user.create",
user.default AS "user.default",
user.desc AS "user.desc",
user.empty AS "user.empty",
user.float AS "user.float",
user.join AS "user.join",
user.like AS "user.like",
user.max AS "user.max",
user.rank AS "user.rank"
FROM test_sample.user;
`)
var dest []model.User
err := stmt.Query(db, &dest)
assert.NoError(t, err)
testutils.PrintJson(dest)
testutils.AssertJSON(t, dest, `
[
{
"Column": "Column",
"Use": "CHECK",
"Ceil": "CEIL",
"Commit": "COMMIT",
"Create": "CREATE",
"Default": "DEFAULT",
"Desc": "DESC",
"Empty": "EMPTY",
"Float": "FLOAT",
"Join": "JOIN",
"Like": "LIKE",
"Max": "MAX",
"Rank": "RANK"
}
]
`)
}

View file

@ -4,7 +4,7 @@ import (
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/mysql"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )
@ -55,9 +55,9 @@ FROM test_sample.all_types;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest, Result{ testutils.AssertDeepEqual(t, dest, Result{
As1: "test", As1: "test",
Date1: *testutils.Date("2011-02-02"), Date1: *testutils.Date("2011-02-02"),
Time: *testutils.TimeWithoutTimeZone("14:06:10"), Time: *testutils.TimeWithoutTimeZone("14:06:10"),

View file

@ -6,7 +6,7 @@ import (
. "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/mysql"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )

View file

@ -4,7 +4,7 @@ import (
"github.com/go-jet/jet/generator/mysql" "github.com/go-jet/jet/generator/mysql"
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/tests/dbconfig" "github.com/go-jet/jet/tests/dbconfig"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"io/ioutil" "io/ioutil"
"os" "os"
"os/exec" "os/exec"
@ -25,23 +25,23 @@ func TestGenerator(t *testing.T) {
DBName: "dvds", DBName: "dvds",
}) })
assert.NilError(t, err) assert.NoError(t, err)
assertGeneratedFiles(t) assertGeneratedFiles(t)
} }
err := os.RemoveAll(genTestDirRoot) err := os.RemoveAll(genTestDirRoot)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestCmdGenerator(t *testing.T) { func TestCmdGenerator(t *testing.T) {
goInstallJet := exec.Command("sh", "-c", "go install github.com/go-jet/jet/cmd/jet") goInstallJet := exec.Command("sh", "-c", "go install github.com/go-jet/jet/cmd/jet")
goInstallJet.Stderr = os.Stderr goInstallJet.Stderr = os.Stderr
err := goInstallJet.Run() err := goInstallJet.Run()
assert.NilError(t, err) assert.NoError(t, err)
err = os.RemoveAll(genTestDir3) err = os.RemoveAll(genTestDir3)
assert.NilError(t, err) assert.NoError(t, err)
cmd := exec.Command("jet", "-source=MySQL", "-dbname=dvds", "-host=localhost", "-port=3306", cmd := exec.Command("jet", "-source=MySQL", "-dbname=dvds", "-host=localhost", "-port=3306",
"-user=jet", "-password=jet", "-path="+genTestDir3) "-user=jet", "-password=jet", "-path="+genTestDir3)
@ -50,18 +50,18 @@ func TestCmdGenerator(t *testing.T) {
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
err = cmd.Run() err = cmd.Run()
assert.NilError(t, err) assert.NoError(t, err)
assertGeneratedFiles(t) assertGeneratedFiles(t)
err = os.RemoveAll(genTestDirRoot) err = os.RemoveAll(genTestDirRoot)
assert.NilError(t, err) assert.NoError(t, err)
} }
func assertGeneratedFiles(t *testing.T) { func assertGeneratedFiles(t *testing.T) {
// Table SQL Builder files // Table SQL Builder files
tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table") tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table")
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go",
@ -71,7 +71,7 @@ func assertGeneratedFiles(t *testing.T) {
// View SQL Builder files // View SQL Builder files
viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view") viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view")
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go",
"sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go")
@ -80,14 +80,14 @@ func assertGeneratedFiles(t *testing.T) {
// Enums SQL Builder files // Enums SQL Builder files
enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum") enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum")
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertFileNamesEqual(t, enumFiles, "film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go") testutils.AssertFileNamesEqual(t, enumFiles, "film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go")
testutils.AssertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", "\npackage enum", mpaaRatingEnumFile) testutils.AssertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", "\npackage enum", mpaaRatingEnumFile)
// Model files // Model files
modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model") modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model")
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go",

View file

@ -6,7 +6,7 @@ import (
. "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/mysql"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )
@ -32,7 +32,7 @@ INSERT INTO test_sample.link (id, url, name, description) VALUES
102, "http://www.yahoo.com", "Yahoo", nil) 102, "http://www.yahoo.com", "Yahoo", nil)
_, err := insertQuery.Exec(db) _, err := insertQuery.Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
insertedLinks := []model.Link{} insertedLinks := []model.Link{}
@ -41,18 +41,18 @@ INSERT INTO test_sample.link (id, url, name, description) VALUES
ORDER_BY(Link.ID). ORDER_BY(Link.ID).
Query(db, &insertedLinks) Query(db, &insertedLinks)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(insertedLinks), 3) assert.Equal(t, len(insertedLinks), 3)
assert.DeepEqual(t, insertedLinks[0], postgreTutorial) testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial)
assert.DeepEqual(t, insertedLinks[1], model.Link{ testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{
ID: 101, ID: 101,
URL: "http://www.google.com", URL: "http://www.google.com",
Name: "Google", Name: "Google",
}) })
assert.DeepEqual(t, insertedLinks[2], model.Link{ testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{
ID: 102, ID: 102,
URL: "http://www.yahoo.com", URL: "http://www.yahoo.com",
Name: "Yahoo", Name: "Yahoo",
@ -80,7 +80,7 @@ INSERT INTO test_sample.link VALUES
100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial") 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial")
_, err := stmt.Exec(db) _, err := stmt.Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
insertedLinks := []model.Link{} insertedLinks := []model.Link{}
@ -89,9 +89,9 @@ INSERT INTO test_sample.link VALUES
ORDER_BY(Link.ID). ORDER_BY(Link.ID).
Query(db, &insertedLinks) Query(db, &insertedLinks)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(insertedLinks), 1) assert.Equal(t, len(insertedLinks), 1)
assert.DeepEqual(t, insertedLinks[0], postgreTutorial) testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial)
} }
func TestInsertModelObject(t *testing.T) { func TestInsertModelObject(t *testing.T) {
@ -113,7 +113,7 @@ INSERT INTO test_sample.link (url, name) VALUES
testutils.AssertDebugStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go") testutils.AssertDebugStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go")
_, err := query.Exec(db) _, err := query.Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestInsertModelObjectEmptyColumnList(t *testing.T) { func TestInsertModelObjectEmptyColumnList(t *testing.T) {
@ -136,7 +136,7 @@ INSERT INTO test_sample.link VALUES
testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil)
_, err := query.Exec(db) _, err := query.Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestInsertModelsObject(t *testing.T) { func TestInsertModelsObject(t *testing.T) {
@ -172,7 +172,7 @@ INSERT INTO test_sample.link (url, name) VALUES
"http://www.yahoo.com", "Yahoo") "http://www.yahoo.com", "Yahoo")
_, err := query.Exec(db) _, err := query.Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestInsertUsingMutableColumns(t *testing.T) { func TestInsertUsingMutableColumns(t *testing.T) {
@ -207,14 +207,14 @@ INSERT INTO test_sample.link (url, name, description) VALUES
"http://www.yahoo.com", "Yahoo", nil) "http://www.yahoo.com", "Yahoo", nil)
_, err := stmt.Exec(db) _, err := stmt.Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestInsertQuery(t *testing.T) { func TestInsertQuery(t *testing.T) {
_, err := Link.DELETE(). _, err := Link.DELETE().
WHERE(Link.ID.NOT_EQ(Int(1)).AND(Link.Name.EQ(String("Youtube")))). WHERE(Link.ID.NOT_EQ(Int(1)).AND(Link.Name.EQ(String("Youtube")))).
Exec(db) Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
var expectedSQL = ` var expectedSQL = `
INSERT INTO test_sample.link (url, name) ( INSERT INTO test_sample.link (url, name) (
@ -236,7 +236,7 @@ INSERT INTO test_sample.link (url, name) (
testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1)) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1))
_, err = query.Exec(db) _, err = query.Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
youtubeLinks := []model.Link{} youtubeLinks := []model.Link{}
err = Link. err = Link.
@ -244,7 +244,7 @@ INSERT INTO test_sample.link (url, name) (
WHERE(Link.Name.EQ(String("Youtube"))). WHERE(Link.Name.EQ(String("Youtube"))).
Query(db, &youtubeLinks) Query(db, &youtubeLinks)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(youtubeLinks), 2) assert.Equal(t, len(youtubeLinks), 2)
} }
@ -283,5 +283,5 @@ func TestInsertWithExecContext(t *testing.T) {
func cleanUpLinkTable(t *testing.T) { func cleanUpLinkTable(t *testing.T) {
_, err := Link.DELETE().WHERE(Link.ID.GT(Int(1))).Exec(db) _, err := Link.DELETE().WHERE(Link.ID.GT(Int(1))).Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
} }

View file

@ -4,7 +4,7 @@ import (
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/mysql"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
) )
@ -16,7 +16,7 @@ LOCK TABLES dvds.customer READ;
`) `)
_, err := query.Exec(db) _, err := query.Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestLockWrite(t *testing.T) { func TestLockWrite(t *testing.T) {
@ -27,7 +27,7 @@ LOCK TABLES dvds.customer WRITE;
`) `)
_, err := query.Exec(db) _, err := query.Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestUnlockTables(t *testing.T) { func TestUnlockTables(t *testing.T) {
@ -38,5 +38,5 @@ UNLOCK TABLES;
`) `)
_, err := query.Exec(db) _, err := query.Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
} }

View file

@ -1,14 +1,13 @@
package mysql package mysql
import ( import (
"fmt"
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/mysql"
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/enum" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/enum"
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/model" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table"
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/view" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/view"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
) )
@ -31,9 +30,9 @@ WHERE actor.actor_id = ?;
actor := model.Actor{} actor := model.Actor{}
err := query.Query(db, &actor) err := query.Query(db, &actor)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, actor, actor2) testutils.AssertDeepEqual(t, actor, actor2)
} }
var actor2 = model.Actor{ var actor2 = model.Actor{
@ -60,10 +59,10 @@ ORDER BY actor.actor_id;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 200) assert.Equal(t, len(dest), 200)
assert.DeepEqual(t, dest[1], actor2) testutils.AssertDeepEqual(t, dest[1], actor2)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
//testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json") //testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json")
@ -137,7 +136,7 @@ ORDER BY payment.customer_id, SUM(payment.amount) ASC;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -177,7 +176,7 @@ func TestSubQuery(t *testing.T) {
} }
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//testutils.SaveJsonFile(dest, "mysql/testdata/r_rating_films.json") //testutils.SaveJsonFile(dest, "mysql/testdata/r_rating_films.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/r_rating_films.json") testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/r_rating_films.json")
@ -230,7 +229,7 @@ LIMIT ?;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestSelectUNION(t *testing.T) { func TestSelectUNION(t *testing.T) {
@ -266,7 +265,7 @@ LIMIT ?;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestSelectUNION_ALL(t *testing.T) { func TestSelectUNION_ALL(t *testing.T) {
@ -309,7 +308,7 @@ OFFSET ?;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestJoinQueryStruct(t *testing.T) { func TestJoinQueryStruct(t *testing.T) {
@ -407,7 +406,7 @@ LIMIT ?;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//assert.Equal(t, len(dest), 1) //assert.Equal(t, len(dest), 1)
//assert.Equal(t, len(dest[0].Films), 10) //assert.Equal(t, len(dest[0].Films), 10)
//assert.Equal(t, len(dest[0].Films[0].Actors), 10) //assert.Equal(t, len(dest[0].Films[0].Actors), 10)
@ -451,10 +450,10 @@ FOR`
tx, _ := db.Begin() tx, _ := db.Begin()
_, err := query.Exec(tx) _, err := query.Exec(tx)
assert.NilError(t, err) assert.NoError(t, err)
err = tx.Rollback() err = tx.Rollback()
assert.NilError(t, err) assert.NoError(t, err)
} }
for lockType, lockTypeStr := range getRowLockTestData() { for lockType, lockTypeStr := range getRowLockTestData() {
@ -465,10 +464,10 @@ FOR`
tx, _ := db.Begin() tx, _ := db.Begin()
_, err := query.Exec(tx) _, err := query.Exec(tx)
assert.NilError(t, err) assert.NoError(t, err)
err = tx.Rollback() err = tx.Rollback()
assert.NilError(t, err) assert.NoError(t, err)
} }
if sourceIsMariaDB() { if sourceIsMariaDB() {
@ -483,10 +482,10 @@ FOR`
tx, _ := db.Begin() tx, _ := db.Begin()
_, err := query.Exec(tx) _, err := query.Exec(tx)
assert.NilError(t, err) assert.NoError(t, err)
err = tx.Rollback() err = tx.Rollback()
assert.NilError(t, err) assert.NoError(t, err)
} }
} }
@ -515,7 +514,7 @@ SELECT true,
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestLockInShareMode(t *testing.T) { func TestLockInShareMode(t *testing.T) {
@ -536,7 +535,7 @@ LOCK IN SHARE MODE;
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestWindowFunction(t *testing.T) { func TestWindowFunction(t *testing.T) {
@ -607,13 +606,13 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date;
).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate). ).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate).
WHERE(Payment.PaymentID.LT(Int(10))) WHERE(Payment.PaymentID.LT(Int(10)))
fmt.Println(query.Sql()) //fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10))
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestWindowClause(t *testing.T) { func TestWindowClause(t *testing.T) {
@ -643,14 +642,14 @@ ORDER BY payment.customer_id;
WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)). WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)).
ORDER_BY(Payment.CustomerID) ORDER_BY(Payment.CustomerID)
fmt.Println(query.Sql()) //fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) testutils.AssertStatementSql(t, query, expectedSQL, int64(10))
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestSimpleView(t *testing.T) { func TestSimpleView(t *testing.T) {
@ -671,7 +670,7 @@ func TestSimpleView(t *testing.T) {
var dest []ActorInfo var dest []ActorInfo
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 10) assert.Equal(t, len(dest), 10)
testutils.AssertJSON(t, dest[1:2], ` testutils.AssertJSON(t, dest[1:2], `
@ -703,9 +702,42 @@ func TestJoinViewWithTable(t *testing.T) {
} }
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assert.Equal(t, len(dest[0].Rentals), 32) assert.Equal(t, len(dest[0].Rentals), 32)
assert.Equal(t, len(dest[1].Rentals), 27) assert.Equal(t, len(dest[1].Rentals), 27)
} }
func TestConditionalProjectionList(t *testing.T) {
projectionList := ProjectionList{}
columnsToSelect := []string{"customer_id", "create_date"}
for _, columnName := range columnsToSelect {
switch columnName {
case Customer.CustomerID.Name():
projectionList = append(projectionList, Customer.CustomerID)
case Customer.Email.Name():
projectionList = append(projectionList, Customer.Email)
case Customer.CreateDate.Name():
projectionList = append(projectionList, Customer.CreateDate)
}
}
stmt := SELECT(projectionList).
FROM(Customer).
LIMIT(3)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT customer.customer_id AS "customer.customer_id",
customer.create_date AS "customer.create_date"
FROM dvds.customer
LIMIT 3;
`)
var dest []model.Customer
err := stmt.Query(db, &dest)
assert.NoError(t, err)
assert.Equal(t, len(dest), 3)
}

View file

@ -8,7 +8,7 @@ import (
"github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )
@ -40,9 +40,9 @@ WHERE link.name = 'Bing';
WHERE(Link.Name.EQ(String("Bong"))). WHERE(Link.Name.EQ(String("Bong"))).
Query(db, &links) Query(db, &links)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(links), 1) assert.Equal(t, len(links), 1)
assert.DeepEqual(t, links[0], model.Link{ testutils.AssertDeepEqual(t, links[0], model.Link{
ID: 204, ID: 204,
URL: "http://bong.com", URL: "http://bong.com",
Name: "Bong", Name: "Bong",
@ -244,7 +244,7 @@ func TestUpdateWithJoin(t *testing.T) {
//fmt.Println(query.DebugSql()) //fmt.Println(query.DebugSql())
_, err := query.Exec(db) _, err := query.Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
} }
func setupLinkTableForUpdateTest(t *testing.T) { func setupLinkTableForUpdateTest(t *testing.T) {
@ -259,5 +259,5 @@ func setupLinkTableForUpdateTest(t *testing.T) {
VALUES(204, "http://www.bing.com", "Bing", DEFAULT). VALUES(204, "http://www.bing.com", "Bing", DEFAULT).
Exec(db) Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
} }

View file

@ -1,40 +1,40 @@
package postgres package postgres
import ( import (
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/postgres"
. "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/view" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/view"
"github.com/go-jet/jet/tests/testdata/results/common" "github.com/go-jet/jet/tests/testdata/results/common"
"github.com/google/uuid"
"gotest.tools/assert"
"testing"
"time"
) )
func TestAllTypesSelect(t *testing.T) { func TestAllTypesSelect(t *testing.T) {
dest := []model.AllTypes{} dest := []model.AllTypes{}
err := AllTypes.SELECT(AllTypes.AllColumns).Query(db, &dest) err := AllTypes.SELECT(AllTypes.AllColumns).Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest[0], allTypesRow0) testutils.AssertDeepEqual(t, dest[0], allTypesRow0)
assert.DeepEqual(t, dest[1], allTypesRow1) testutils.AssertDeepEqual(t, dest[1], allTypesRow1)
} }
func TestAllTypesViewSelect(t *testing.T) { func TestAllTypesViewSelect(t *testing.T) {
type AllTypesView model.AllTypes type AllTypesView model.AllTypes
dest := []AllTypesView{} dest := []AllTypesView{}
err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest[0], AllTypesView(allTypesRow0)) testutils.AssertDeepEqual(t, dest[0], AllTypesView(allTypesRow0))
assert.DeepEqual(t, dest[1], AllTypesView(allTypesRow1)) testutils.AssertDeepEqual(t, dest[1], AllTypesView(allTypesRow1))
} }
func TestAllTypesInsertModel(t *testing.T) { func TestAllTypesInsertModel(t *testing.T) {
@ -45,11 +45,11 @@ func TestAllTypesInsertModel(t *testing.T) {
dest := []model.AllTypes{} dest := []model.AllTypes{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assert.DeepEqual(t, dest[0], allTypesRow0) testutils.AssertDeepEqual(t, dest[0], allTypesRow0)
assert.DeepEqual(t, dest[1], allTypesRow1) testutils.AssertDeepEqual(t, dest[1], allTypesRow1)
} }
func TestAllTypesInsertQuery(t *testing.T) { func TestAllTypesInsertQuery(t *testing.T) {
@ -64,10 +64,156 @@ func TestAllTypesInsertQuery(t *testing.T) {
dest := []model.AllTypes{} dest := []model.AllTypes{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2)
testutils.AssertDeepEqual(t, dest[0], allTypesRow0)
testutils.AssertDeepEqual(t, dest[1], allTypesRow1)
}
func TestAllTypesFromSubQuery(t *testing.T) {
subQuery := SELECT(AllTypes.AllColumns).
FROM(AllTypes).
AsTable("allTypesSubQuery")
mainQuery := SELECT(subQuery.AllColumns()).
FROM(subQuery).
LIMIT(2)
assert.Equal(t, mainQuery.DebugSql(), `
SELECT "allTypesSubQuery"."all_types.small_int_ptr" AS "all_types.small_int_ptr",
"allTypesSubQuery"."all_types.small_int" AS "all_types.small_int",
"allTypesSubQuery"."all_types.integer_ptr" AS "all_types.integer_ptr",
"allTypesSubQuery"."all_types.integer" AS "all_types.integer",
"allTypesSubQuery"."all_types.big_int_ptr" AS "all_types.big_int_ptr",
"allTypesSubQuery"."all_types.big_int" AS "all_types.big_int",
"allTypesSubQuery"."all_types.decimal_ptr" AS "all_types.decimal_ptr",
"allTypesSubQuery"."all_types.decimal" AS "all_types.decimal",
"allTypesSubQuery"."all_types.numeric_ptr" AS "all_types.numeric_ptr",
"allTypesSubQuery"."all_types.numeric" AS "all_types.numeric",
"allTypesSubQuery"."all_types.real_ptr" AS "all_types.real_ptr",
"allTypesSubQuery"."all_types.real" AS "all_types.real",
"allTypesSubQuery"."all_types.double_precision_ptr" AS "all_types.double_precision_ptr",
"allTypesSubQuery"."all_types.double_precision" AS "all_types.double_precision",
"allTypesSubQuery"."all_types.smallserial" AS "all_types.smallserial",
"allTypesSubQuery"."all_types.serial" AS "all_types.serial",
"allTypesSubQuery"."all_types.bigserial" AS "all_types.bigserial",
"allTypesSubQuery"."all_types.var_char_ptr" AS "all_types.var_char_ptr",
"allTypesSubQuery"."all_types.var_char" AS "all_types.var_char",
"allTypesSubQuery"."all_types.char_ptr" AS "all_types.char_ptr",
"allTypesSubQuery"."all_types.char" AS "all_types.char",
"allTypesSubQuery"."all_types.text_ptr" AS "all_types.text_ptr",
"allTypesSubQuery"."all_types.text" AS "all_types.text",
"allTypesSubQuery"."all_types.bytea_ptr" AS "all_types.bytea_ptr",
"allTypesSubQuery"."all_types.bytea" AS "all_types.bytea",
"allTypesSubQuery"."all_types.timestampz_ptr" AS "all_types.timestampz_ptr",
"allTypesSubQuery"."all_types.timestampz" AS "all_types.timestampz",
"allTypesSubQuery"."all_types.timestamp_ptr" AS "all_types.timestamp_ptr",
"allTypesSubQuery"."all_types.timestamp" AS "all_types.timestamp",
"allTypesSubQuery"."all_types.date_ptr" AS "all_types.date_ptr",
"allTypesSubQuery"."all_types.date" AS "all_types.date",
"allTypesSubQuery"."all_types.timez_ptr" AS "all_types.timez_ptr",
"allTypesSubQuery"."all_types.timez" AS "all_types.timez",
"allTypesSubQuery"."all_types.time_ptr" AS "all_types.time_ptr",
"allTypesSubQuery"."all_types.time" AS "all_types.time",
"allTypesSubQuery"."all_types.interval_ptr" AS "all_types.interval_ptr",
"allTypesSubQuery"."all_types.interval" AS "all_types.interval",
"allTypesSubQuery"."all_types.boolean_ptr" AS "all_types.boolean_ptr",
"allTypesSubQuery"."all_types.boolean" AS "all_types.boolean",
"allTypesSubQuery"."all_types.point_ptr" AS "all_types.point_ptr",
"allTypesSubQuery"."all_types.bit_ptr" AS "all_types.bit_ptr",
"allTypesSubQuery"."all_types.bit" AS "all_types.bit",
"allTypesSubQuery"."all_types.bit_varying_ptr" AS "all_types.bit_varying_ptr",
"allTypesSubQuery"."all_types.bit_varying" AS "all_types.bit_varying",
"allTypesSubQuery"."all_types.tsvector_ptr" AS "all_types.tsvector_ptr",
"allTypesSubQuery"."all_types.tsvector" AS "all_types.tsvector",
"allTypesSubQuery"."all_types.uuid_ptr" AS "all_types.uuid_ptr",
"allTypesSubQuery"."all_types.uuid" AS "all_types.uuid",
"allTypesSubQuery"."all_types.xml_ptr" AS "all_types.xml_ptr",
"allTypesSubQuery"."all_types.xml" AS "all_types.xml",
"allTypesSubQuery"."all_types.json_ptr" AS "all_types.json_ptr",
"allTypesSubQuery"."all_types.json" AS "all_types.json",
"allTypesSubQuery"."all_types.jsonb_ptr" AS "all_types.jsonb_ptr",
"allTypesSubQuery"."all_types.jsonb" AS "all_types.jsonb",
"allTypesSubQuery"."all_types.integer_array_ptr" AS "all_types.integer_array_ptr",
"allTypesSubQuery"."all_types.integer_array" AS "all_types.integer_array",
"allTypesSubQuery"."all_types.text_array_ptr" AS "all_types.text_array_ptr",
"allTypesSubQuery"."all_types.text_array" AS "all_types.text_array",
"allTypesSubQuery"."all_types.jsonb_array" AS "all_types.jsonb_array",
"allTypesSubQuery"."all_types.text_multi_dim_array_ptr" AS "all_types.text_multi_dim_array_ptr",
"allTypesSubQuery"."all_types.text_multi_dim_array" AS "all_types.text_multi_dim_array"
FROM (
SELECT all_types.small_int_ptr AS "all_types.small_int_ptr",
all_types.small_int AS "all_types.small_int",
all_types.integer_ptr AS "all_types.integer_ptr",
all_types.integer AS "all_types.integer",
all_types.big_int_ptr AS "all_types.big_int_ptr",
all_types.big_int AS "all_types.big_int",
all_types.decimal_ptr AS "all_types.decimal_ptr",
all_types.decimal AS "all_types.decimal",
all_types.numeric_ptr AS "all_types.numeric_ptr",
all_types.numeric AS "all_types.numeric",
all_types.real_ptr AS "all_types.real_ptr",
all_types.real AS "all_types.real",
all_types.double_precision_ptr AS "all_types.double_precision_ptr",
all_types.double_precision AS "all_types.double_precision",
all_types.smallserial AS "all_types.smallserial",
all_types.serial AS "all_types.serial",
all_types.bigserial AS "all_types.bigserial",
all_types.var_char_ptr AS "all_types.var_char_ptr",
all_types.var_char AS "all_types.var_char",
all_types.char_ptr AS "all_types.char_ptr",
all_types.char AS "all_types.char",
all_types.text_ptr AS "all_types.text_ptr",
all_types.text AS "all_types.text",
all_types.bytea_ptr AS "all_types.bytea_ptr",
all_types.bytea AS "all_types.bytea",
all_types.timestampz_ptr AS "all_types.timestampz_ptr",
all_types.timestampz AS "all_types.timestampz",
all_types.timestamp_ptr AS "all_types.timestamp_ptr",
all_types.timestamp AS "all_types.timestamp",
all_types.date_ptr AS "all_types.date_ptr",
all_types.date AS "all_types.date",
all_types.timez_ptr AS "all_types.timez_ptr",
all_types.timez AS "all_types.timez",
all_types.time_ptr AS "all_types.time_ptr",
all_types.time AS "all_types.time",
all_types.interval_ptr AS "all_types.interval_ptr",
all_types.interval AS "all_types.interval",
all_types.boolean_ptr AS "all_types.boolean_ptr",
all_types.boolean AS "all_types.boolean",
all_types.point_ptr AS "all_types.point_ptr",
all_types.bit_ptr AS "all_types.bit_ptr",
all_types.bit AS "all_types.bit",
all_types.bit_varying_ptr AS "all_types.bit_varying_ptr",
all_types.bit_varying AS "all_types.bit_varying",
all_types.tsvector_ptr AS "all_types.tsvector_ptr",
all_types.tsvector AS "all_types.tsvector",
all_types.uuid_ptr AS "all_types.uuid_ptr",
all_types.uuid AS "all_types.uuid",
all_types.xml_ptr AS "all_types.xml_ptr",
all_types.xml AS "all_types.xml",
all_types.json_ptr AS "all_types.json_ptr",
all_types.json AS "all_types.json",
all_types.jsonb_ptr AS "all_types.jsonb_ptr",
all_types.jsonb AS "all_types.jsonb",
all_types.integer_array_ptr AS "all_types.integer_array_ptr",
all_types.integer_array AS "all_types.integer_array",
all_types.text_array_ptr AS "all_types.text_array_ptr",
all_types.text_array AS "all_types.text_array",
all_types.jsonb_array AS "all_types.jsonb_array",
all_types.text_multi_dim_array_ptr AS "all_types.text_multi_dim_array_ptr",
all_types.text_multi_dim_array AS "all_types.text_multi_dim_array"
FROM test_sample.all_types
) AS "allTypesSubQuery"
LIMIT 2;
`)
dest := []model.AllTypes{}
err := mainQuery.Query(db, &dest)
assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assert.DeepEqual(t, dest[0], allTypesRow0)
assert.DeepEqual(t, dest[1], allTypesRow1)
} }
func TestExpressionOperators(t *testing.T) { func TestExpressionOperators(t *testing.T) {
@ -105,7 +251,7 @@ LIMIT $5;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -134,22 +280,23 @@ LIMIT $5;
func TestExpressionCast(t *testing.T) { func TestExpressionCast(t *testing.T) {
query := AllTypes.SELECT( query := AllTypes.SELECT(
postgres.CAST(Int(150)).AS_CHAR(12).AS("char12"), CAST(Int(150)).AS_CHAR(12).AS("char12"),
postgres.CAST(String("TRUE")).AS_BOOL(), CAST(String("TRUE")).AS_BOOL(),
postgres.CAST(String("111")).AS_SMALLINT(), CAST(String("111")).AS_SMALLINT(),
postgres.CAST(String("111")).AS_INTEGER(), CAST(String("111")).AS_INTEGER(),
postgres.CAST(String("111")).AS_BIGINT(), CAST(String("111")).AS_BIGINT(),
postgres.CAST(String("11.23")).AS_NUMERIC(30, 10), CAST(String("11.23")).AS_NUMERIC(30, 10),
postgres.CAST(String("11.23")).AS_NUMERIC(30), CAST(String("11.23")).AS_NUMERIC(30),
postgres.CAST(String("11.23")).AS_NUMERIC(), CAST(String("11.23")).AS_NUMERIC(),
postgres.CAST(String("11.23")).AS_REAL(), CAST(String("11.23")).AS_REAL(),
postgres.CAST(String("11.23")).AS_DOUBLE(), CAST(String("11.23")).AS_DOUBLE(),
postgres.CAST(Int(234)).AS_TEXT(), CAST(Int(234)).AS_TEXT(),
postgres.CAST(String("1/8/1999")).AS_DATE(), CAST(String("1/8/1999")).AS_DATE(),
postgres.CAST(String("04:05:06.789")).AS_TIME(), CAST(String("04:05:06.789")).AS_TIME(),
postgres.CAST(String("04:05:06 PST")).AS_TIMEZ(), CAST(String("04:05:06 PST")).AS_TIMEZ(),
postgres.CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(), CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(),
postgres.CAST(String("January 8 04:05:06 1999 PST")).AS_TIMESTAMPZ(), CAST(String("January 8 04:05:06 1999 PST")).AS_TIMESTAMPZ(),
CAST(String("04:05:06")).AS_INTERVAL(),
TO_CHAR(AllTypes.Timestamp, String("HH12:MI:SS")), TO_CHAR(AllTypes.Timestamp, String("HH12:MI:SS")),
TO_CHAR(AllTypes.Integer, String("999")), TO_CHAR(AllTypes.Integer, String("999")),
@ -173,7 +320,7 @@ func TestExpressionCast(t *testing.T) {
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestStringOperators(t *testing.T) { func TestStringOperators(t *testing.T) {
@ -253,7 +400,7 @@ func TestStringOperators(t *testing.T) {
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestBoolOperators(t *testing.T) { func TestBoolOperators(t *testing.T) {
@ -322,7 +469,7 @@ LIMIT $5;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json") testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json")
} }
@ -359,7 +506,7 @@ func TestFloatOperators(t *testing.T) {
TRUNC(ABSf(AllTypes.Decimal), Int(2)).AS("abs"), TRUNC(ABSf(AllTypes.Decimal), Int(2)).AS("abs"),
TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int(2)).AS("power"), TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int(2)).AS("power"),
TRUNC(SQRT(AllTypes.Decimal), Int(2)).AS("sqrt"), TRUNC(SQRT(AllTypes.Decimal), Int(2)).AS("sqrt"),
TRUNC(postgres.CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int(2)).AS("cbrt"), TRUNC(CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int(2)).AS("cbrt"),
CEIL(AllTypes.Real).AS("ceil"), CEIL(AllTypes.Real).AS("ceil"),
FLOOR(AllTypes.Real).AS("floor"), FLOOR(AllTypes.Real).AS("floor"),
@ -418,7 +565,7 @@ LIMIT $35;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -557,7 +704,7 @@ LIMIT $23;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//testutils.SaveJsonFile("./testdata/common/int_operators.json", dest) //testutils.SaveJsonFile("./testdata/common/int_operators.json", dest)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -606,6 +753,19 @@ func TestTimeExpression(t *testing.T) {
AllTypes.Time.GT_EQ(AllTypes.Time), AllTypes.Time.GT_EQ(AllTypes.Time),
AllTypes.Time.GT_EQ(Time(23, 6, 6, 1)), AllTypes.Time.GT_EQ(Time(23, 6, 6, 1)),
AllTypes.Date.ADD(INTERVAL(1, HOUR)),
AllTypes.Date.SUB(INTERVAL(1, MINUTE)),
AllTypes.Time.ADD(INTERVAL(1, HOUR)),
AllTypes.Time.SUB(INTERVAL(1, MINUTE)),
AllTypes.Timez.ADD(INTERVAL(1, HOUR)),
AllTypes.Timez.SUB(INTERVAL(1, MINUTE)),
AllTypes.Timestamp.ADD(INTERVAL(1, HOUR)),
AllTypes.Timestamp.SUB(INTERVAL(1, MINUTE)),
AllTypes.Timestampz.ADD(INTERVAL(1, HOUR)),
AllTypes.Timestampz.SUB(INTERVAL(1, MINUTE)),
AllTypes.Date.SUB(CAST(String("04:05:06")).AS_INTERVAL()),
CURRENT_DATE(), CURRENT_DATE(),
CURRENT_TIME(), CURRENT_TIME(),
CURRENT_TIME(2), CURRENT_TIME(2),
@ -623,7 +783,58 @@ func TestTimeExpression(t *testing.T) {
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
}
func TestInterval(t *testing.T) {
stmt := SELECT(
INTERVAL(1, YEAR),
INTERVAL(1, MONTH),
INTERVAL(1, WEEK),
INTERVAL(1, DAY),
INTERVAL(1, HOUR),
INTERVAL(1, MINUTE),
INTERVAL(1, SECOND),
INTERVAL(1, MILLISECOND),
INTERVAL(1, MICROSECOND),
INTERVAL(1, DECADE),
INTERVAL(1, CENTURY),
INTERVAL(1, MILLENNIUM),
INTERVAL(1, YEAR, 10, MONTH),
INTERVAL(1, YEAR, 10, MONTH, 20, DAY),
INTERVAL(1, YEAR, 10, MONTH, 20, DAY, 3, HOUR),
INTERVAL(1, YEAR).IS_NOT_NULL(),
INTERVAL(1, YEAR).AS("one year"),
INTERVALd(0),
INTERVALd(1*time.Microsecond),
INTERVALd(1*time.Millisecond),
INTERVALd(1*time.Second),
INTERVALd(1*time.Minute),
INTERVALd(1*time.Hour),
INTERVALd(24*time.Hour),
INTERVALd(24*time.Hour+2*time.Hour+3*time.Minute+4*time.Second+5*time.Microsecond),
AllTypes.Interval.EQ(INTERVAL(2, HOUR, 20, MINUTE)).EQ(Bool(true)),
AllTypes.IntervalPtr.NOT_EQ(INTERVAL(2, HOUR, 20, MINUTE)).EQ(Bool(false)),
AllTypes.Interval.IS_DISTINCT_FROM(INTERVAL(2, HOUR, 20, MINUTE)).EQ(AllTypes.Boolean),
AllTypes.IntervalPtr.IS_NOT_DISTINCT_FROM(INTERVALd(10*time.Microsecond)).EQ(AllTypes.Boolean),
AllTypes.Interval.LT(AllTypes.IntervalPtr).EQ(AllTypes.BooleanPtr),
AllTypes.Interval.LT_EQ(AllTypes.IntervalPtr).EQ(AllTypes.BooleanPtr),
AllTypes.Interval.GT(AllTypes.IntervalPtr).EQ(AllTypes.BooleanPtr),
AllTypes.Interval.GT_EQ(AllTypes.IntervalPtr).EQ(AllTypes.BooleanPtr),
AllTypes.Interval.ADD(AllTypes.IntervalPtr).EQ(INTERVALd(17*time.Second)),
AllTypes.Interval.SUB(AllTypes.IntervalPtr).EQ(INTERVAL(100, MICROSECOND)),
AllTypes.IntervalPtr.MUL(Int(11)).EQ(AllTypes.Interval),
AllTypes.IntervalPtr.DIV(Float(22.222)).EQ(AllTypes.IntervalPtr),
).FROM(AllTypes)
//fmt.Println(stmt.DebugSql())
err := stmt.Query(db, &struct{}{})
assert.NoError(t, err)
} }
func TestSubQueryColumnReference(t *testing.T) { func TestSubQueryColumnReference(t *testing.T) {
@ -775,17 +986,17 @@ FROM`
dest1 := []model.AllTypes{} dest1 := []model.AllTypes{}
err := stmt1.Query(db, &dest1) err := stmt1.Query(db, &dest1)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest1), 2) assert.Equal(t, len(dest1), 2)
assert.Equal(t, dest1[0].Boolean, allTypesRow0.Boolean) assert.Equal(t, dest1[0].Boolean, allTypesRow0.Boolean)
assert.Equal(t, dest1[0].Integer, allTypesRow0.Integer) assert.Equal(t, dest1[0].Integer, allTypesRow0.Integer)
assert.Equal(t, dest1[0].Real, allTypesRow0.Real) assert.Equal(t, dest1[0].Real, allTypesRow0.Real)
assert.Equal(t, dest1[0].Text, allTypesRow0.Text) assert.Equal(t, dest1[0].Text, allTypesRow0.Text)
assert.DeepEqual(t, dest1[0].Time, allTypesRow0.Time) testutils.AssertDeepEqual(t, dest1[0].Time, allTypesRow0.Time)
assert.DeepEqual(t, dest1[0].Timez, allTypesRow0.Timez) testutils.AssertDeepEqual(t, dest1[0].Timez, allTypesRow0.Timez)
assert.DeepEqual(t, dest1[0].Timestamp, allTypesRow0.Timestamp) testutils.AssertDeepEqual(t, dest1[0].Timestamp, allTypesRow0.Timestamp)
assert.DeepEqual(t, dest1[0].Timestampz, allTypesRow0.Timestampz) testutils.AssertDeepEqual(t, dest1[0].Timestampz, allTypesRow0.Timestampz)
assert.DeepEqual(t, dest1[0].Date, allTypesRow0.Date) testutils.AssertDeepEqual(t, dest1[0].Date, allTypesRow0.Date)
stmt2 := SELECT( stmt2 := SELECT(
subQuery.AllColumns(), subQuery.AllColumns(),
@ -797,15 +1008,15 @@ FROM`
dest2 := []model.AllTypes{} dest2 := []model.AllTypes{}
err = stmt2.Query(db, &dest2) err = stmt2.Query(db, &dest2)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest1, dest2) testutils.AssertDeepEqual(t, dest1, dest2)
} }
} }
func TestTimeLiterals(t *testing.T) { func TestTimeLiterals(t *testing.T) {
loc, err := time.LoadLocation("Europe/Berlin") loc, err := time.LoadLocation("Europe/Berlin")
assert.NilError(t, err) assert.NoError(t, err)
var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, loc) var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, loc)
@ -840,7 +1051,7 @@ LIMIT $6;
err = query.Query(db, &dest) err = query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)

View file

@ -7,7 +7,7 @@ import (
. "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/table"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )
@ -30,11 +30,11 @@ ORDER BY "Album"."AlbumId" ASC;
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 347) assert.Equal(t, len(dest), 347)
assert.DeepEqual(t, dest[0], album1) testutils.AssertDeepEqual(t, dest[0], album1)
assert.DeepEqual(t, dest[1], album2) testutils.AssertDeepEqual(t, dest[1], album2)
assert.DeepEqual(t, dest[len(dest)-1], album347) testutils.AssertDeepEqual(t, dest[len(dest)-1], album347)
} }
func TestJoinEverything(t *testing.T) { func TestJoinEverything(t *testing.T) {
@ -103,7 +103,7 @@ func TestJoinEverything(t *testing.T) {
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 275) assert.Equal(t, len(dest), 275)
testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/joined_everything.json") testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/joined_everything.json")
} }
@ -143,7 +143,7 @@ ORDER BY "Employee"."EmployeeId";
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 8) assert.Equal(t, len(dest), 8)
testutils.AssertJSON(t, dest[0:2], ` testutils.AssertJSON(t, dest[0:2], `
[ [
@ -236,11 +236,11 @@ ORDER BY "Album.AlbumId";
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assert.DeepEqual(t, dest[0], album1) testutils.AssertDeepEqual(t, dest[0], album1)
assert.DeepEqual(t, dest[1], album2) testutils.AssertDeepEqual(t, dest[1], album2)
} }
func TestQueryWithContext(t *testing.T) { func TestQueryWithContext(t *testing.T) {
@ -327,7 +327,7 @@ ORDER BY "first10Artist"."Artist.ArtistId";
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//spew.Dump(dest) //spew.Dump(dest)
} }

View file

@ -6,7 +6,7 @@ import (
. "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )
@ -48,11 +48,11 @@ RETURNING link.id AS "link.id",
err := deleteStmt.Query(db, &dest) err := deleteStmt.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assert.DeepEqual(t, dest[0].Name, "Gmail") testutils.AssertDeepEqual(t, dest[0].Name, "Gmail")
assert.DeepEqual(t, dest[1].Name, "Outlook") testutils.AssertDeepEqual(t, dest[1].Name, "Outlook")
} }
func initForDeleteTest(t *testing.T) { func initForDeleteTest(t *testing.T) {

View file

@ -4,7 +4,7 @@ import (
"github.com/go-jet/jet/generator/postgres" "github.com/go-jet/jet/generator/postgres"
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/tests/dbconfig" "github.com/go-jet/jet/tests/dbconfig"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"io/ioutil" "io/ioutil"
"os" "os"
"os/exec" "os/exec"
@ -15,12 +15,11 @@ import (
) )
func TestGeneratedModel(t *testing.T) { func TestGeneratedModel(t *testing.T) {
actor := model.Actor{} actor := model.Actor{}
assert.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32") assert.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32")
actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID") actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID")
assert.Assert(t, ok) assert.True(t, ok)
assert.Equal(t, actorIDField.Tag.Get("sql"), "primary_key") assert.Equal(t, actorIDField.Tag.Get("sql"), "primary_key")
assert.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string") assert.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string")
assert.Equal(t, reflect.TypeOf(actor.LastName).String(), "string") assert.Equal(t, reflect.TypeOf(actor.LastName).String(), "string")
@ -30,12 +29,12 @@ func TestGeneratedModel(t *testing.T) {
assert.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int16") assert.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int16")
filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID") filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID")
assert.Assert(t, ok) assert.True(t, ok)
assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key") assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key")
assert.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int16") assert.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int16")
actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID") actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID")
assert.Assert(t, ok) assert.True(t, ok)
assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key") assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key")
staff := model.Staff{} staff := model.Staff{}
@ -50,10 +49,10 @@ func TestCmdGenerator(t *testing.T) {
goInstallJet := exec.Command("sh", "-c", "go install github.com/go-jet/jet/cmd/jet") goInstallJet := exec.Command("sh", "-c", "go install github.com/go-jet/jet/cmd/jet")
goInstallJet.Stderr = os.Stderr goInstallJet.Stderr = os.Stderr
err := goInstallJet.Run() err := goInstallJet.Run()
assert.NilError(t, err) assert.NoError(t, err)
err = os.RemoveAll(genTestDir2) err = os.RemoveAll(genTestDir2)
assert.NilError(t, err) assert.NoError(t, err)
cmd := exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost", "-port=5432", cmd := exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost", "-port=5432",
"-user=jet", "-password=jet", "-schema=dvds", "-path="+genTestDir2) "-user=jet", "-password=jet", "-schema=dvds", "-path="+genTestDir2)
@ -61,12 +60,12 @@ func TestCmdGenerator(t *testing.T) {
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
err = cmd.Run() err = cmd.Run()
assert.NilError(t, err) assert.NoError(t, err)
assertGeneratedFiles(t) assertGeneratedFiles(t)
err = os.RemoveAll(genTestDir2) err = os.RemoveAll(genTestDir2)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestGenerator(t *testing.T) { func TestGenerator(t *testing.T) {
@ -84,19 +83,19 @@ func TestGenerator(t *testing.T) {
SchemaName: "dvds", SchemaName: "dvds",
}) })
assert.NilError(t, err) assert.NoError(t, err)
assertGeneratedFiles(t) assertGeneratedFiles(t)
} }
err := os.RemoveAll(genTestDir2) err := os.RemoveAll(genTestDir2)
assert.NilError(t, err) assert.NoError(t, err)
} }
func assertGeneratedFiles(t *testing.T) { func assertGeneratedFiles(t *testing.T) {
// Table SQL Builder files // Table SQL Builder files
tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table") tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table")
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
@ -106,7 +105,7 @@ func assertGeneratedFiles(t *testing.T) {
// View SQL Builder files // View SQL Builder files
viewSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/view") viewSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/view")
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go",
"sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go")
@ -115,14 +114,14 @@ func assertGeneratedFiles(t *testing.T) {
// Enums SQL Builder files // Enums SQL Builder files
enumFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum") enumFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum")
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertFileNamesEqual(t, enumFiles, "mpaa_rating.go") testutils.AssertFileNamesEqual(t, enumFiles, "mpaa_rating.go")
testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", "\npackage enum", mpaaRatingEnumFile) testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", "\npackage enum", mpaaRatingEnumFile)
// Model files // Model files
modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model") modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model")
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
@ -275,3 +274,366 @@ func newActorInfoTable() *ActorInfoTable {
} }
} }
` `
func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) {
enumDir := testRoot + ".gentestdata/jetdb/test_sample/enum/"
modelDir := testRoot + ".gentestdata/jetdb/test_sample/model/"
tableDir := testRoot + ".gentestdata/jetdb/test_sample/table/"
enumFiles, err := ioutil.ReadDir(enumDir)
assert.NoError(t, err)
testutils.AssertFileNamesEqual(t, enumFiles, "mood.go", "level.go")
testutils.AssertFileContent(t, enumDir+"mood.go", "\npackage enum", moodEnumContent)
testutils.AssertFileContent(t, enumDir+"level.go", "\npackage enum", levelEnumContent)
modelFiles, err := ioutil.ReadDir(modelDir)
assert.NoError(t, err)
testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go",
"mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go")
testutils.AssertFileContent(t, modelDir+"all_types.go", "\npackage model", allTypesModelContent)
tableFiles, err := ioutil.ReadDir(tableDir)
assert.NoError(t, err)
testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go",
"person.go", "person_phone.go", "weird_names_table.go", "user.go")
testutils.AssertFileContent(t, tableDir+"all_types.go", "\npackage table", allTypesTableContent)
}
var moodEnumContent = `
package enum
import "github.com/go-jet/jet/postgres"
var Mood = &struct {
Sad postgres.StringExpression
Ok postgres.StringExpression
Happy postgres.StringExpression
}{
Sad: postgres.NewEnumValue("sad"),
Ok: postgres.NewEnumValue("ok"),
Happy: postgres.NewEnumValue("happy"),
}
`
var levelEnumContent = `
package enum
import "github.com/go-jet/jet/postgres"
var Level = &struct {
Level1 postgres.StringExpression
Level2 postgres.StringExpression
Level3 postgres.StringExpression
Level4 postgres.StringExpression
Level5 postgres.StringExpression
}{
Level1: postgres.NewEnumValue("1"),
Level2: postgres.NewEnumValue("2"),
Level3: postgres.NewEnumValue("3"),
Level4: postgres.NewEnumValue("4"),
Level5: postgres.NewEnumValue("5"),
}
`
var allTypesModelContent = `
package model
import (
"github.com/google/uuid"
"time"
)
type AllTypes struct {
SmallIntPtr *int16
SmallInt int16
IntegerPtr *int32
Integer int32
BigIntPtr *int64
BigInt int64
DecimalPtr *float64
Decimal float64
NumericPtr *float64
Numeric float64
RealPtr *float32
Real float32
DoublePrecisionPtr *float64
DoublePrecision float64
Smallserial int16
Serial int32
Bigserial int64
VarCharPtr *string
VarChar string
CharPtr *string
Char string
TextPtr *string
Text string
ByteaPtr *[]byte
Bytea []byte
TimestampzPtr *time.Time
Timestampz time.Time
TimestampPtr *time.Time
Timestamp time.Time
DatePtr *time.Time
Date time.Time
TimezPtr *time.Time
Timez time.Time
TimePtr *time.Time
Time time.Time
IntervalPtr *string
Interval string
BooleanPtr *bool
Boolean bool
PointPtr *string
BitPtr *string
Bit string
BitVaryingPtr *string
BitVarying string
TsvectorPtr *string
Tsvector string
UUIDPtr *uuid.UUID
UUID uuid.UUID
XMLPtr *string
XML string
JSONPtr *string
JSON string
JsonbPtr *string
Jsonb string
IntegerArrayPtr *string
IntegerArray string
TextArrayPtr *string
TextArray string
JsonbArray string
TextMultiDimArrayPtr *string
TextMultiDimArray string
}
`
var allTypesTableContent = `
package table
import (
"github.com/go-jet/jet/postgres"
)
var AllTypes = newAllTypesTable()
type AllTypesTable struct {
postgres.Table
//Columns
SmallIntPtr postgres.ColumnInteger
SmallInt postgres.ColumnInteger
IntegerPtr postgres.ColumnInteger
Integer postgres.ColumnInteger
BigIntPtr postgres.ColumnInteger
BigInt postgres.ColumnInteger
DecimalPtr postgres.ColumnFloat
Decimal postgres.ColumnFloat
NumericPtr postgres.ColumnFloat
Numeric postgres.ColumnFloat
RealPtr postgres.ColumnFloat
Real postgres.ColumnFloat
DoublePrecisionPtr postgres.ColumnFloat
DoublePrecision postgres.ColumnFloat
Smallserial postgres.ColumnInteger
Serial postgres.ColumnInteger
Bigserial postgres.ColumnInteger
VarCharPtr postgres.ColumnString
VarChar postgres.ColumnString
CharPtr postgres.ColumnString
Char postgres.ColumnString
TextPtr postgres.ColumnString
Text postgres.ColumnString
ByteaPtr postgres.ColumnString
Bytea postgres.ColumnString
TimestampzPtr postgres.ColumnTimestampz
Timestampz postgres.ColumnTimestampz
TimestampPtr postgres.ColumnTimestamp
Timestamp postgres.ColumnTimestamp
DatePtr postgres.ColumnDate
Date postgres.ColumnDate
TimezPtr postgres.ColumnTimez
Timez postgres.ColumnTimez
TimePtr postgres.ColumnTime
Time postgres.ColumnTime
IntervalPtr postgres.ColumnInterval
Interval postgres.ColumnInterval
BooleanPtr postgres.ColumnBool
Boolean postgres.ColumnBool
PointPtr postgres.ColumnString
BitPtr postgres.ColumnString
Bit postgres.ColumnString
BitVaryingPtr postgres.ColumnString
BitVarying postgres.ColumnString
TsvectorPtr postgres.ColumnString
Tsvector postgres.ColumnString
UUIDPtr postgres.ColumnString
UUID postgres.ColumnString
XMLPtr postgres.ColumnString
XML postgres.ColumnString
JSONPtr postgres.ColumnString
JSON postgres.ColumnString
JsonbPtr postgres.ColumnString
Jsonb postgres.ColumnString
IntegerArrayPtr postgres.ColumnString
IntegerArray postgres.ColumnString
TextArrayPtr postgres.ColumnString
TextArray postgres.ColumnString
JsonbArray postgres.ColumnString
TextMultiDimArrayPtr postgres.ColumnString
TextMultiDimArray postgres.ColumnString
AllColumns postgres.ColumnList
MutableColumns postgres.ColumnList
}
// creates new AllTypesTable with assigned alias
func (a *AllTypesTable) AS(alias string) *AllTypesTable {
aliasTable := newAllTypesTable()
aliasTable.Table.AS(alias)
return aliasTable
}
func newAllTypesTable() *AllTypesTable {
var (
SmallIntPtrColumn = postgres.IntegerColumn("small_int_ptr")
SmallIntColumn = postgres.IntegerColumn("small_int")
IntegerPtrColumn = postgres.IntegerColumn("integer_ptr")
IntegerColumn = postgres.IntegerColumn("integer")
BigIntPtrColumn = postgres.IntegerColumn("big_int_ptr")
BigIntColumn = postgres.IntegerColumn("big_int")
DecimalPtrColumn = postgres.FloatColumn("decimal_ptr")
DecimalColumn = postgres.FloatColumn("decimal")
NumericPtrColumn = postgres.FloatColumn("numeric_ptr")
NumericColumn = postgres.FloatColumn("numeric")
RealPtrColumn = postgres.FloatColumn("real_ptr")
RealColumn = postgres.FloatColumn("real")
DoublePrecisionPtrColumn = postgres.FloatColumn("double_precision_ptr")
DoublePrecisionColumn = postgres.FloatColumn("double_precision")
SmallserialColumn = postgres.IntegerColumn("smallserial")
SerialColumn = postgres.IntegerColumn("serial")
BigserialColumn = postgres.IntegerColumn("bigserial")
VarCharPtrColumn = postgres.StringColumn("var_char_ptr")
VarCharColumn = postgres.StringColumn("var_char")
CharPtrColumn = postgres.StringColumn("char_ptr")
CharColumn = postgres.StringColumn("char")
TextPtrColumn = postgres.StringColumn("text_ptr")
TextColumn = postgres.StringColumn("text")
ByteaPtrColumn = postgres.StringColumn("bytea_ptr")
ByteaColumn = postgres.StringColumn("bytea")
TimestampzPtrColumn = postgres.TimestampzColumn("timestampz_ptr")
TimestampzColumn = postgres.TimestampzColumn("timestampz")
TimestampPtrColumn = postgres.TimestampColumn("timestamp_ptr")
TimestampColumn = postgres.TimestampColumn("timestamp")
DatePtrColumn = postgres.DateColumn("date_ptr")
DateColumn = postgres.DateColumn("date")
TimezPtrColumn = postgres.TimezColumn("timez_ptr")
TimezColumn = postgres.TimezColumn("timez")
TimePtrColumn = postgres.TimeColumn("time_ptr")
TimeColumn = postgres.TimeColumn("time")
IntervalPtrColumn = postgres.IntervalColumn("interval_ptr")
IntervalColumn = postgres.IntervalColumn("interval")
BooleanPtrColumn = postgres.BoolColumn("boolean_ptr")
BooleanColumn = postgres.BoolColumn("boolean")
PointPtrColumn = postgres.StringColumn("point_ptr")
BitPtrColumn = postgres.StringColumn("bit_ptr")
BitColumn = postgres.StringColumn("bit")
BitVaryingPtrColumn = postgres.StringColumn("bit_varying_ptr")
BitVaryingColumn = postgres.StringColumn("bit_varying")
TsvectorPtrColumn = postgres.StringColumn("tsvector_ptr")
TsvectorColumn = postgres.StringColumn("tsvector")
UUIDPtrColumn = postgres.StringColumn("uuid_ptr")
UUIDColumn = postgres.StringColumn("uuid")
XMLPtrColumn = postgres.StringColumn("xml_ptr")
XMLColumn = postgres.StringColumn("xml")
JSONPtrColumn = postgres.StringColumn("json_ptr")
JSONColumn = postgres.StringColumn("json")
JsonbPtrColumn = postgres.StringColumn("jsonb_ptr")
JsonbColumn = postgres.StringColumn("jsonb")
IntegerArrayPtrColumn = postgres.StringColumn("integer_array_ptr")
IntegerArrayColumn = postgres.StringColumn("integer_array")
TextArrayPtrColumn = postgres.StringColumn("text_array_ptr")
TextArrayColumn = postgres.StringColumn("text_array")
JsonbArrayColumn = postgres.StringColumn("jsonb_array")
TextMultiDimArrayPtrColumn = postgres.StringColumn("text_multi_dim_array_ptr")
TextMultiDimArrayColumn = postgres.StringColumn("text_multi_dim_array")
)
return &AllTypesTable{
Table: postgres.NewTable("test_sample", "all_types", SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn),
//Columns
SmallIntPtr: SmallIntPtrColumn,
SmallInt: SmallIntColumn,
IntegerPtr: IntegerPtrColumn,
Integer: IntegerColumn,
BigIntPtr: BigIntPtrColumn,
BigInt: BigIntColumn,
DecimalPtr: DecimalPtrColumn,
Decimal: DecimalColumn,
NumericPtr: NumericPtrColumn,
Numeric: NumericColumn,
RealPtr: RealPtrColumn,
Real: RealColumn,
DoublePrecisionPtr: DoublePrecisionPtrColumn,
DoublePrecision: DoublePrecisionColumn,
Smallserial: SmallserialColumn,
Serial: SerialColumn,
Bigserial: BigserialColumn,
VarCharPtr: VarCharPtrColumn,
VarChar: VarCharColumn,
CharPtr: CharPtrColumn,
Char: CharColumn,
TextPtr: TextPtrColumn,
Text: TextColumn,
ByteaPtr: ByteaPtrColumn,
Bytea: ByteaColumn,
TimestampzPtr: TimestampzPtrColumn,
Timestampz: TimestampzColumn,
TimestampPtr: TimestampPtrColumn,
Timestamp: TimestampColumn,
DatePtr: DatePtrColumn,
Date: DateColumn,
TimezPtr: TimezPtrColumn,
Timez: TimezColumn,
TimePtr: TimePtrColumn,
Time: TimeColumn,
IntervalPtr: IntervalPtrColumn,
Interval: IntervalColumn,
BooleanPtr: BooleanPtrColumn,
Boolean: BooleanColumn,
PointPtr: PointPtrColumn,
BitPtr: BitPtrColumn,
Bit: BitColumn,
BitVaryingPtr: BitVaryingPtrColumn,
BitVarying: BitVaryingColumn,
TsvectorPtr: TsvectorPtrColumn,
Tsvector: TsvectorColumn,
UUIDPtr: UUIDPtrColumn,
UUID: UUIDColumn,
XMLPtr: XMLPtrColumn,
XML: XMLColumn,
JSONPtr: JSONPtrColumn,
JSON: JSONColumn,
JsonbPtr: JsonbPtrColumn,
Jsonb: JsonbColumn,
IntegerArrayPtr: IntegerArrayPtrColumn,
IntegerArray: IntegerArrayColumn,
TextArrayPtr: TextArrayPtrColumn,
TextArray: TextArrayColumn,
JsonbArray: JsonbArrayColumn,
TextMultiDimArrayPtr: TextMultiDimArrayPtrColumn,
TextMultiDimArray: TextMultiDimArrayColumn,
AllColumns: postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn},
MutableColumns: postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn},
}
}
`

View file

@ -6,7 +6,7 @@ import (
. "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )
@ -24,7 +24,6 @@ RETURNING link.id AS "link.id",
link.name AS "link.name", link.name AS "link.name",
link.description AS "link.description"; link.description AS "link.description";
` `
Link.ID.Name()
insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(101, "http://www.google.com", "Google", DEFAULT). VALUES(101, "http://www.google.com", "Google", DEFAULT).
@ -40,23 +39,23 @@ RETURNING link.id AS "link.id",
err := insertQuery.Query(db, &insertedLinks) err := insertQuery.Query(db, &insertedLinks)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(insertedLinks), 3) assert.Equal(t, len(insertedLinks), 3)
assert.DeepEqual(t, insertedLinks[0], model.Link{ testutils.AssertDeepEqual(t, insertedLinks[0], model.Link{
ID: 100, ID: 100,
URL: "http://www.postgresqltutorial.com", URL: "http://www.postgresqltutorial.com",
Name: "PostgreSQL Tutorial", Name: "PostgreSQL Tutorial",
}) })
assert.DeepEqual(t, insertedLinks[1], model.Link{ testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{
ID: 101, ID: 101,
URL: "http://www.google.com", URL: "http://www.google.com",
Name: "Google", Name: "Google",
}) })
assert.DeepEqual(t, insertedLinks[2], model.Link{ testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{
ID: 102, ID: 102,
URL: "http://www.yahoo.com", URL: "http://www.yahoo.com",
Name: "Yahoo", Name: "Yahoo",
@ -69,9 +68,9 @@ RETURNING link.id AS "link.id",
ORDER_BY(Link.ID). ORDER_BY(Link.ID).
Query(db, &allLinks) Query(db, &allLinks)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, insertedLinks, allLinks) testutils.AssertDeepEqual(t, insertedLinks, allLinks)
} }
func TestInsertEmptyColumnList(t *testing.T) { func TestInsertEmptyColumnList(t *testing.T) {
@ -207,7 +206,7 @@ func TestInsertQuery(t *testing.T) {
_, err := Link.DELETE(). _, err := Link.DELETE().
WHERE(Link.ID.NOT_EQ(Int(0)).AND(Link.Name.EQ(String("Youtube")))). WHERE(Link.ID.NOT_EQ(Int(0)).AND(Link.Name.EQ(String("Youtube")))).
Exec(db) Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
var expectedSQL = ` var expectedSQL = `
INSERT INTO test_sample.link (url, name) ( INSERT INTO test_sample.link (url, name) (
@ -237,7 +236,7 @@ RETURNING link.id AS "link.id",
err = query.Query(db, &dest) err = query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
youtubeLinks := []model.Link{} youtubeLinks := []model.Link{}
err = Link. err = Link.
@ -245,7 +244,7 @@ RETURNING link.id AS "link.id",
WHERE(Link.Name.EQ(String("Youtube"))). WHERE(Link.Name.EQ(String("Youtube"))).
Query(db, &youtubeLinks) Query(db, &youtubeLinks)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(youtubeLinks), 2) assert.Equal(t, len(youtubeLinks), 2)
} }

View file

@ -3,7 +3,7 @@ package postgres
import ( import (
"context" "context"
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
@ -35,11 +35,11 @@ LOCK TABLE dvds.address IN`
_, err := query.Exec(tx) _, err := query.Exec(tx)
assert.NilError(t, err) assert.NoError(t, err)
err = tx.Rollback() err = tx.Rollback()
assert.NilError(t, err) assert.NoError(t, err)
} }
for _, lockMode := range testData { for _, lockMode := range testData {
@ -51,11 +51,11 @@ LOCK TABLE dvds.address IN`
_, err := query.Exec(tx) _, err := query.Exec(tx)
assert.NilError(t, err) assert.NoError(t, err)
err = tx.Rollback() err = tx.Rollback()
assert.NilError(t, err) assert.NoError(t, err)
} }
} }

View file

@ -6,14 +6,19 @@ import (
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/pkg/profile" "github.com/pkg/profile"
"os" "os"
"os/exec"
"strings"
"testing" "testing"
) )
var db *sql.DB var db *sql.DB
var testRoot string
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
defer profile.Start().Stop() defer profile.Start().Stop()
setTestRoot()
var err error var err error
db, err = sql.Open("postgres", dbconfig.PostgresConnectString) db, err = sql.Open("postgres", dbconfig.PostgresConnectString)
if err != nil { if err != nil {
@ -25,3 +30,13 @@ func TestMain(m *testing.M) {
os.Exit(ret) os.Exit(ret)
} }
func setTestRoot() {
cmd := exec.Command("git", "rev-parse", "--show-toplevel")
byteArr, err := cmd.Output()
if err != nil {
panic(err)
}
testRoot = strings.TrimSpace(string(byteArr)) + "/tests/"
}

View file

@ -4,7 +4,7 @@ import (
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/northwind/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/northwind/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/northwind/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/northwind/table"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
) )
@ -59,7 +59,7 @@ func TestNorthwindJoinEverything(t *testing.T) {
} }
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//jsonSave("./testdata/northwind-all.json", dest) //jsonSave("./testdata/northwind-all.json", dest)
testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/northwind-all.json") testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/northwind-all.json")

View file

@ -6,7 +6,7 @@ import (
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"github.com/google/uuid" "github.com/google/uuid"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
) )
@ -25,9 +25,9 @@ WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11';
result := model.AllTypes{} result := model.AllTypes{}
err := query.Query(db, &result) err := query.Query(db, &result)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) assert.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
assert.DeepEqual(t, result.UUIDPtr, UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) testutils.AssertDeepEqual(t, result.UUIDPtr, UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
} }
func TestUUIDComplex(t *testing.T) { func TestUUIDComplex(t *testing.T) {
@ -46,7 +46,7 @@ func TestUUIDComplex(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
testutils.AssertJSON(t, dest, ` testutils.AssertJSON(t, dest, `
[ [
@ -96,7 +96,7 @@ func TestUUIDComplex(t *testing.T) {
} }
} }
err := singleQuery.Query(db, &dest) err := singleQuery.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertJSON(t, dest, ` testutils.AssertJSON(t, dest, `
{ {
@ -132,7 +132,7 @@ func TestUUIDComplex(t *testing.T) {
} }
err := leftQuery.Query(db, &dest) err := leftQuery.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertJSON(t, dest, ` testutils.AssertJSON(t, dest, `
[ [
{ {
@ -194,7 +194,7 @@ FROM test_sample.person;
err := query.Query(db, &result) err := query.Query(db, &result)
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertJSON(t, result, ` testutils.AssertJSON(t, result, `
[ [
{ {
@ -258,9 +258,9 @@ ORDER BY employee.employee_id;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 8) assert.Equal(t, len(dest), 8)
assert.DeepEqual(t, dest[0].Employee, model.Employee{ testutils.AssertDeepEqual(t, dest[0].Employee, model.Employee{
EmployeeID: 1, EmployeeID: 1,
FirstName: "Windy", FirstName: "Windy",
LastName: "Hays", LastName: "Hays",
@ -268,9 +268,9 @@ ORDER BY employee.employee_id;
ManagerID: nil, ManagerID: nil,
}) })
assert.Assert(t, dest[0].Manager == nil) assert.True(t, dest[0].Manager == nil)
assert.DeepEqual(t, dest[7].Employee, model.Employee{ testutils.AssertDeepEqual(t, dest[7].Employee, model.Employee{
EmployeeID: 8, EmployeeID: 8,
FirstName: "Salley", FirstName: "Salley",
LastName: "Lester", LastName: "Lester",
@ -306,10 +306,10 @@ FROM test_sample."WEIRD NAMES TABLE";
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 1) assert.Equal(t, len(dest), 1)
assert.DeepEqual(t, dest[0], model.WeirdNamesTable{ testutils.AssertDeepEqual(t, dest[0], model.WeirdNamesTable{
WeirdColumnName1: "Doe", WeirdColumnName1: "Doe",
WeirdColumnName2: "Doe", WeirdColumnName2: "Doe",
WeirdColumnName3: "Doe", WeirdColumnName3: "Doe",
@ -328,3 +328,54 @@ FROM test_sample."WEIRD NAMES TABLE";
WeirdColuName16: "Doe", WeirdColuName16: "Doe",
}) })
} }
func TestReserwedWordEscape(t *testing.T) {
stmt := SELECT(User.AllColumns).
FROM(User)
//fmt.Println(stmt.DebugSql())
testutils.AssertDebugStatementSql(t, stmt, `
SELECT "User"."column" AS "User.column",
"User"."check" AS "User.check",
"User".ceil AS "User.ceil",
"User".commit AS "User.commit",
"User"."create" AS "User.create",
"User"."default" AS "User.default",
"User"."desc" AS "User.desc",
"User".empty AS "User.empty",
"User".float AS "User.float",
"User".join AS "User.join",
"User".like AS "User.like",
"User".max AS "User.max",
"User".rank AS "User.rank"
FROM test_sample."User";
`)
var dest []model.User
err := stmt.Query(db, &dest)
assert.NoError(t, err)
testutils.PrintJson(dest)
testutils.AssertJSON(t, dest, `
[
{
"Column": "Column",
"Check": "CHECK",
"Ceil": "CEIL",
"Commit": "COMMIT",
"Create": "CREATE",
"Default": "DEFAULT",
"Desc": "DESC",
"Empty": "EMPTY",
"Float": "FLOAT",
"Join": "JOIN",
"Like": "LIKE",
"Max": "MAX",
"Rank": "RANK"
}
]
`)
}

View file

@ -8,7 +8,7 @@ import (
"github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table"
"github.com/google/uuid" "github.com/google/uuid"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
) )
@ -53,38 +53,38 @@ func TestScanToValidDestination(t *testing.T) {
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
}) })
t.Run("global query function scan", func(t *testing.T) { t.Run("global query function scan", func(t *testing.T) {
queryStr, args := query.Sql() queryStr, args := query.Sql()
dest := []struct{}{} dest := []struct{}{}
err := qrm.Query(nil, db, queryStr, args, &dest) err := qrm.Query(nil, db, queryStr, args, &dest)
assert.NilError(t, err) assert.NoError(t, err)
}) })
t.Run("pointer to slice", func(t *testing.T) { t.Run("pointer to slice", func(t *testing.T) {
err := query.Query(db, &[]struct{}{}) err := query.Query(db, &[]struct{}{})
assert.NilError(t, err) assert.NoError(t, err)
}) })
t.Run("pointer to slice of pointer to structs", func(t *testing.T) { t.Run("pointer to slice of pointer to structs", func(t *testing.T) {
err := query.Query(db, &[]*struct{}{}) err := query.Query(db, &[]*struct{}{})
assert.NilError(t, err) assert.NoError(t, err)
}) })
t.Run("pointer to slice of strings", func(t *testing.T) { t.Run("pointer to slice of strings", func(t *testing.T) {
err := query.Query(db, &[]int32{}) err := query.Query(db, &[]int32{})
assert.NilError(t, err) assert.NoError(t, err)
}) })
t.Run("pointer to slice of strings", func(t *testing.T) { t.Run("pointer to slice of strings", func(t *testing.T) {
err := query.Query(db, &[]*int32{}) err := query.Query(db, &[]*int32{})
assert.NilError(t, err) assert.NoError(t, err)
}) })
} }
@ -99,16 +99,16 @@ func TestScanToStruct(t *testing.T) {
dest := model.Inventory{} dest := model.Inventory{}
err := query.LIMIT(1).Query(db, &dest) err := query.LIMIT(1).Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, inventory1, dest) testutils.AssertDeepEqual(t, inventory1, dest)
}) })
t.Run("multiple structs, just first one used", func(t *testing.T) { t.Run("multiple structs, just first one used", func(t *testing.T) {
dest := model.Inventory{} dest := model.Inventory{}
err := query.LIMIT(10).Query(db, &dest) err := query.LIMIT(10).Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, inventory1, dest) testutils.AssertDeepEqual(t, inventory1, dest)
}) })
t.Run("one struct", func(t *testing.T) { t.Run("one struct", func(t *testing.T) {
@ -117,8 +117,8 @@ func TestScanToStruct(t *testing.T) {
}{} }{}
err := query.LIMIT(1).Query(db, &dest) err := query.LIMIT(1).Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, inventory1, dest.Inventory) testutils.AssertDeepEqual(t, inventory1, dest.Inventory)
}) })
t.Run("one struct", func(t *testing.T) { t.Run("one struct", func(t *testing.T) {
@ -127,8 +127,8 @@ func TestScanToStruct(t *testing.T) {
}{} }{}
err := query.LIMIT(1).Query(db, &dest) err := query.LIMIT(1).Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, inventory1, *dest.Inventory) testutils.AssertDeepEqual(t, inventory1, *dest.Inventory)
}) })
t.Run("invalid dest", func(t *testing.T) { t.Run("invalid dest", func(t *testing.T) {
@ -158,7 +158,7 @@ func TestScanToStruct(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, *dest.InventoryID, int32(1)) assert.Equal(t, *dest.InventoryID, int32(1))
assert.Equal(t, dest.FilmID, int16(1)) assert.Equal(t, dest.FilmID, int16(1))
@ -175,7 +175,7 @@ func TestScanToStruct(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
}) })
t.Run("type mismatch scanner type", func(t *testing.T) { t.Run("type mismatch scanner type", func(t *testing.T) {
@ -217,10 +217,10 @@ func TestScanToNestedStruct(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Inventory, inventory1)
assert.DeepEqual(t, dest.Film, film1) testutils.AssertDeepEqual(t, dest.Film, film1)
assert.DeepEqual(t, dest.Store, store1) testutils.AssertDeepEqual(t, dest.Store, store1)
}) })
t.Run("embedded pointer structs", func(t *testing.T) { t.Run("embedded pointer structs", func(t *testing.T) {
@ -232,10 +232,10 @@ func TestScanToNestedStruct(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, *dest.Inventory, inventory1) testutils.AssertDeepEqual(t, *dest.Inventory, inventory1)
assert.DeepEqual(t, *dest.Film, film1) testutils.AssertDeepEqual(t, *dest.Film, film1)
assert.DeepEqual(t, *dest.Store, store1) testutils.AssertDeepEqual(t, *dest.Store, store1)
}) })
t.Run("embedded unused structs", func(t *testing.T) { t.Run("embedded unused structs", func(t *testing.T) {
@ -246,9 +246,9 @@ func TestScanToNestedStruct(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Inventory, inventory1)
assert.DeepEqual(t, dest.Actor, model.Actor{}) testutils.AssertDeepEqual(t, dest.Actor, model.Actor{})
}) })
t.Run("embedded unused pointer structs", func(t *testing.T) { t.Run("embedded unused pointer structs", func(t *testing.T) {
@ -259,9 +259,9 @@ func TestScanToNestedStruct(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Inventory, inventory1)
assert.DeepEqual(t, dest.Actor, (*model.Actor)(nil)) testutils.AssertDeepEqual(t, dest.Actor, (*model.Actor)(nil))
}) })
t.Run("embedded unused pointer structs", func(t *testing.T) { t.Run("embedded unused pointer structs", func(t *testing.T) {
@ -272,9 +272,9 @@ func TestScanToNestedStruct(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Inventory, inventory1)
assert.DeepEqual(t, dest.Actor, (*model.Actor)(nil)) testutils.AssertDeepEqual(t, dest.Actor, (*model.Actor)(nil))
}) })
t.Run("embedded pointer to selected column", func(t *testing.T) { t.Run("embedded pointer to selected column", func(t *testing.T) {
@ -291,9 +291,9 @@ func TestScanToNestedStruct(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Inventory, inventory1)
assert.Assert(t, dest.Actor != nil) assert.True(t, dest.Actor != nil)
}) })
t.Run("struct embedded unused pointer", func(t *testing.T) { t.Run("struct embedded unused pointer", func(t *testing.T) {
@ -306,9 +306,9 @@ func TestScanToNestedStruct(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Inventory, inventory1)
assert.DeepEqual(t, dest.Actor, (*struct{ model.Actor })(nil)) testutils.AssertDeepEqual(t, dest.Actor, (*struct{ model.Actor })(nil))
}) })
t.Run("multiple embedded unused pointer", func(t *testing.T) { t.Run("multiple embedded unused pointer", func(t *testing.T) {
@ -322,9 +322,9 @@ func TestScanToNestedStruct(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Inventory, inventory1)
assert.DeepEqual(t, dest.Actor, (*struct { testutils.AssertDeepEqual(t, dest.Actor, (*struct {
model.Actor model.Actor
model.Language model.Language
})(nil)) })(nil))
@ -341,11 +341,11 @@ func TestScanToNestedStruct(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Inventory, inventory1)
assert.Assert(t, dest.Actor != nil) assert.True(t, dest.Actor != nil)
assert.DeepEqual(t, dest.Actor.Actor, model.Actor{}) testutils.AssertDeepEqual(t, dest.Actor.Actor, model.Actor{})
assert.DeepEqual(t, dest.Actor.Film, film1) testutils.AssertDeepEqual(t, dest.Actor.Film, film1)
}) })
t.Run("field not nil, deeply nested selected model", func(t *testing.T) { t.Run("field not nil, deeply nested selected model", func(t *testing.T) {
@ -361,11 +361,11 @@ func TestScanToNestedStruct(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Inventory, inventory1)
assert.Assert(t, dest.Actor != nil) assert.True(t, dest.Actor != nil)
assert.Assert(t, dest.Actor.Film != nil) assert.True(t, dest.Actor.Film != nil)
assert.DeepEqual(t, dest.Actor.Film.Film, &film1) testutils.AssertDeepEqual(t, dest.Actor.Film.Film, &film1)
}) })
t.Run("embedded structs", func(t *testing.T) { t.Run("embedded structs", func(t *testing.T) {
@ -398,15 +398,15 @@ func TestScanToNestedStruct(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1) testutils.AssertDeepEqual(t, dest.Inventory, inventory1)
assert.DeepEqual(t, dest.Film.Film, film1) testutils.AssertDeepEqual(t, dest.Film.Film, film1)
assert.DeepEqual(t, dest.Store, store1) testutils.AssertDeepEqual(t, dest.Store, store1)
assert.DeepEqual(t, dest.Film.Language, language1) testutils.AssertDeepEqual(t, dest.Film.Language, language1)
assert.DeepEqual(t, dest.Film.Lang.Language, language1) testutils.AssertDeepEqual(t, dest.Film.Lang.Language, language1)
assert.DeepEqual(t, dest.Film.Lang2.Language, language1) testutils.AssertDeepEqual(t, dest.Film.Lang2.Language, language1)
assert.DeepEqual(t, dest.Film.Language2, &language1) testutils.AssertDeepEqual(t, dest.Film.Language2, &language1)
assert.DeepEqual(t, model.Language(*dest.Film.Language3), language1) testutils.AssertDeepEqual(t, model.Language(*dest.Film.Language3), language1)
}) })
} }
@ -423,18 +423,18 @@ func TestScanToSlice(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 10) assert.Equal(t, len(dest), 10)
assert.DeepEqual(t, dest[0], inventory1) testutils.AssertDeepEqual(t, dest[0], inventory1)
assert.DeepEqual(t, dest[1], inventory2) testutils.AssertDeepEqual(t, dest[1], inventory2)
}) })
t.Run("slice of ints", func(t *testing.T) { t.Run("slice of ints", func(t *testing.T) {
var dest []int32 var dest []int32
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest, []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) testutils.AssertDeepEqual(t, dest, []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
}) })
@ -442,7 +442,7 @@ func TestScanToSlice(t *testing.T) {
var dest []int var dest []int
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
}) })
t.Run("slice type mismatch", func(t *testing.T) { t.Run("slice type mismatch", func(t *testing.T) {
@ -473,9 +473,9 @@ func TestScanToSlice(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest.Film, film1) testutils.AssertDeepEqual(t, dest.Film, film1)
assert.DeepEqual(t, dest.IDs, []int32{1, 2, 3, 4, 5, 6, 7, 8}) testutils.AssertDeepEqual(t, dest.IDs, []int32{1, 2, 3, 4, 5, 6, 7, 8})
}) })
t.Run("slice of structs with slice of ints", func(t *testing.T) { t.Run("slice of structs with slice of ints", func(t *testing.T) {
@ -486,12 +486,12 @@ func TestScanToSlice(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assert.DeepEqual(t, dest[0].Film, film1) testutils.AssertDeepEqual(t, dest[0].Film, film1)
assert.DeepEqual(t, dest[0].IDs, []int32{1, 2, 3, 4, 5, 6, 7, 8}) testutils.AssertDeepEqual(t, dest[0].IDs, []int32{1, 2, 3, 4, 5, 6, 7, 8})
assert.DeepEqual(t, dest[1].Film, film2) testutils.AssertDeepEqual(t, dest[1].Film, film2)
assert.DeepEqual(t, dest[1].IDs, []int32{9, 10}) testutils.AssertDeepEqual(t, dest[1].IDs, []int32{9, 10})
}) })
t.Run("slice of structs with slice of pointer to ints", func(t *testing.T) { t.Run("slice of structs with slice of pointer to ints", func(t *testing.T) {
@ -502,13 +502,13 @@ func TestScanToSlice(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assert.DeepEqual(t, dest[0].Film, film1) testutils.AssertDeepEqual(t, dest[0].Film, film1)
assert.DeepEqual(t, dest[0].IDs, []*int32{Int32Ptr(1), Int32Ptr(2), Int32Ptr(3), Int32Ptr(4), testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{Int32Ptr(1), Int32Ptr(2), Int32Ptr(3), Int32Ptr(4),
Int32Ptr(5), Int32Ptr(6), Int32Ptr(7), Int32Ptr(8)}) Int32Ptr(5), Int32Ptr(6), Int32Ptr(7), Int32Ptr(8)})
assert.DeepEqual(t, dest[1].Film, film2) testutils.AssertDeepEqual(t, dest[1].Film, film2)
assert.DeepEqual(t, dest[1].IDs, []*int32{Int32Ptr(9), Int32Ptr(10)}) testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{Int32Ptr(9), Int32Ptr(10)})
}) })
t.Run("complex struct 1", func(t *testing.T) { t.Run("complex struct 1", func(t *testing.T) {
@ -520,13 +520,13 @@ func TestScanToSlice(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 10) assert.Equal(t, len(dest), 10)
assert.DeepEqual(t, dest[0].Inventory, inventory1) testutils.AssertDeepEqual(t, dest[0].Inventory, inventory1)
assert.DeepEqual(t, dest[0].Film, film1) testutils.AssertDeepEqual(t, dest[0].Film, film1)
assert.DeepEqual(t, dest[0].Store, store1) testutils.AssertDeepEqual(t, dest[0].Store, store1)
assert.DeepEqual(t, dest[1].Inventory, inventory2) testutils.AssertDeepEqual(t, dest[1].Inventory, inventory2)
}) })
t.Run("complex struct 2", func(t *testing.T) { t.Run("complex struct 2", func(t *testing.T) {
@ -538,13 +538,13 @@ func TestScanToSlice(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 10) assert.Equal(t, len(dest), 10)
assert.DeepEqual(t, dest[0].Inventory, &inventory1) testutils.AssertDeepEqual(t, dest[0].Inventory, &inventory1)
assert.DeepEqual(t, dest[0].Film, film1) testutils.AssertDeepEqual(t, dest[0].Film, film1)
assert.DeepEqual(t, dest[0].Store, &store1) testutils.AssertDeepEqual(t, dest[0].Store, &store1)
assert.DeepEqual(t, dest[1].Inventory, &inventory2) testutils.AssertDeepEqual(t, dest[1].Inventory, &inventory2)
}) })
t.Run("complex struct 3", func(t *testing.T) { t.Run("complex struct 3", func(t *testing.T) {
@ -558,13 +558,13 @@ func TestScanToSlice(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 10) assert.Equal(t, len(dest), 10)
assert.DeepEqual(t, dest[0].Inventory, inventory1) testutils.AssertDeepEqual(t, dest[0].Inventory, inventory1)
assert.DeepEqual(t, dest[0].Film, &film1) testutils.AssertDeepEqual(t, dest[0].Film, &film1)
assert.DeepEqual(t, dest[0].Store.Store, &store1) testutils.AssertDeepEqual(t, dest[0].Store.Store, &store1)
assert.DeepEqual(t, dest[1].Inventory, inventory2) testutils.AssertDeepEqual(t, dest[1].Inventory, inventory2)
}) })
t.Run("complex struct 4", func(t *testing.T) { t.Run("complex struct 4", func(t *testing.T) {
@ -579,12 +579,12 @@ func TestScanToSlice(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assert.DeepEqual(t, dest[0].Film, film1) testutils.AssertDeepEqual(t, dest[0].Film, film1)
assert.DeepEqual(t, len(dest[0].Inventories), 8) testutils.AssertDeepEqual(t, len(dest[0].Inventories), 8)
assert.DeepEqual(t, dest[0].Inventories[0].Inventory, inventory1) testutils.AssertDeepEqual(t, dest[0].Inventories[0].Inventory, inventory1)
assert.DeepEqual(t, dest[0].Inventories[0].Store, store1) testutils.AssertDeepEqual(t, dest[0].Inventories[0].Store, store1)
}) })
t.Run("complex struct 5", func(t *testing.T) { t.Run("complex struct 5", func(t *testing.T) {
@ -601,14 +601,14 @@ func TestScanToSlice(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assert.DeepEqual(t, dest[0].Film, film1) testutils.AssertDeepEqual(t, dest[0].Film, film1)
assert.Equal(t, len(dest[0].Inventories), 8) assert.Equal(t, len(dest[0].Inventories), 8)
assert.DeepEqual(t, dest[0].Inventories[0].Inventory, inventory1) testutils.AssertDeepEqual(t, dest[0].Inventories[0].Inventory, inventory1)
assert.Assert(t, dest[0].Inventories[0].Rentals == nil) assert.True(t, dest[0].Inventories[0].Rentals == nil)
assert.Assert(t, dest[0].Inventories[0].Rentals2 == nil) assert.True(t, dest[0].Inventories[0].Rentals2 == nil)
}) })
}) })
@ -638,16 +638,16 @@ func TestScanToSlice(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 108) assert.Equal(t, len(dest), 108)
assert.DeepEqual(t, dest[100].Country, countryUk) testutils.AssertDeepEqual(t, dest[100].Country, countryUk)
assert.Equal(t, len(dest[100].Cities), 8) assert.Equal(t, len(dest[100].Cities), 8)
assert.DeepEqual(t, dest[100].Cities[2].City, cityLondon) testutils.AssertDeepEqual(t, dest[100].Cities[2].City, cityLondon)
assert.Equal(t, len(dest[100].Cities[2].Adresses), 2) assert.Equal(t, len(dest[100].Cities[2].Adresses), 2)
assert.DeepEqual(t, dest[100].Cities[2].Adresses[0].Address, address256) testutils.AssertDeepEqual(t, dest[100].Cities[2].Adresses[0].Address, address256)
assert.DeepEqual(t, dest[100].Cities[2].Adresses[0].Customer, customer256) testutils.AssertDeepEqual(t, dest[100].Cities[2].Adresses[0].Customer, customer256)
assert.DeepEqual(t, dest[100].Cities[2].Adresses[1].Address, addres517) testutils.AssertDeepEqual(t, dest[100].Cities[2].Adresses[1].Address, addres517)
assert.DeepEqual(t, dest[100].Cities[2].Adresses[1].Customer, customer512) testutils.AssertDeepEqual(t, dest[100].Cities[2].Adresses[1].Customer, customer512)
}) })
t.Run("dest1", func(t *testing.T) { t.Run("dest1", func(t *testing.T) {
@ -667,16 +667,16 @@ func TestScanToSlice(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 108) assert.Equal(t, len(dest), 108)
assert.DeepEqual(t, dest[100].Country, &countryUk) testutils.AssertDeepEqual(t, dest[100].Country, &countryUk)
assert.Equal(t, len(dest[100].Cities), 8) assert.Equal(t, len(dest[100].Cities), 8)
assert.DeepEqual(t, dest[100].Cities[2].City, &cityLondon) testutils.AssertDeepEqual(t, dest[100].Cities[2].City, &cityLondon)
assert.Equal(t, len(*dest[100].Cities[2].Adresses), 2) assert.Equal(t, len(*dest[100].Cities[2].Adresses), 2)
assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[0].Address, &address256) testutils.AssertDeepEqual(t, (*dest[100].Cities[2].Adresses)[0].Address, &address256)
assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[0].Customer, &customer256) testutils.AssertDeepEqual(t, (*dest[100].Cities[2].Adresses)[0].Customer, &customer256)
assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[1].Address, &addres517) testutils.AssertDeepEqual(t, (*dest[100].Cities[2].Adresses)[1].Address, &addres517)
assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[1].Customer, &customer512) testutils.AssertDeepEqual(t, (*dest[100].Cities[2].Adresses)[1].Customer, &customer512)
}) })
}) })
@ -716,8 +716,8 @@ func TestStructScanAllNull(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.DeepEqual(t, dest, struct { testutils.AssertDeepEqual(t, dest, struct {
Null1 *int Null1 *int
Null2 *int Null2 *int
}{}) }{})

View file

@ -8,7 +8,7 @@ import (
"github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/view" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/view"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )
@ -33,7 +33,7 @@ WHERE actor.actor_id = 2;
actor := model.Actor{} actor := model.Actor{}
err := query.Query(db, &actor) err := query.Query(db, &actor)
assert.NilError(t, err) assert.NoError(t, err)
expectedActor := model.Actor{ expectedActor := model.Actor{
ActorID: 2, ActorID: 2,
@ -42,7 +42,7 @@ WHERE actor.actor_id = 2;
LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:47:57.62", 2), LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:47:57.62", 2),
} }
assert.DeepEqual(t, actor, expectedActor) testutils.AssertDeepEqual(t, actor, expectedActor)
} }
func TestClassicSelect(t *testing.T) { func TestClassicSelect(t *testing.T) {
@ -84,7 +84,7 @@ LIMIT 30;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 30) assert.Equal(t, len(dest), 30)
} }
@ -110,13 +110,13 @@ ORDER BY customer.customer_id ASC;
testutils.AssertDebugStatementSql(t, query, expectedSQL) testutils.AssertDebugStatementSql(t, query, expectedSQL)
err := query.Query(db, &customers) err := query.Query(db, &customers)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(customers), 599) assert.Equal(t, len(customers), 599)
assert.DeepEqual(t, customer0, customers[0]) testutils.AssertDeepEqual(t, customer0, customers[0])
assert.DeepEqual(t, customer1, customers[1]) testutils.AssertDeepEqual(t, customer1, customers[1])
assert.DeepEqual(t, lastCustomer, customers[598]) testutils.AssertDeepEqual(t, lastCustomer, customers[598])
} }
func TestSelectAndUnionInProjection(t *testing.T) { func TestSelectAndUnionInProjection(t *testing.T) {
@ -158,13 +158,13 @@ LIMIT 12;
). ).
LIMIT(12) LIMIT(12)
fmt.Println(query.DebugSql()) //fmt.Println(query.DebugSql())
testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12))
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestJoinQueryStruct(t *testing.T) { func TestJoinQueryStruct(t *testing.T) {
@ -253,7 +253,7 @@ LIMIT 1000;
err := query.Query(db, &languageActorFilm) err := query.Query(db, &languageActorFilm)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(languageActorFilm), 1) assert.Equal(t, len(languageActorFilm), 1)
assert.Equal(t, len(languageActorFilm[0].Films), 10) assert.Equal(t, len(languageActorFilm[0].Films), 10)
assert.Equal(t, len(languageActorFilm[0].Films[0].Actors), 10) assert.Equal(t, len(languageActorFilm[0].Films[0].Actors), 10)
@ -302,7 +302,7 @@ LIMIT 15;
err := query.Query(db, &filmsPerLanguage) err := query.Query(db, &filmsPerLanguage)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(filmsPerLanguage), 1) assert.Equal(t, len(filmsPerLanguage), 1)
assert.Equal(t, len(filmsPerLanguage[0].Film), limit) assert.Equal(t, len(filmsPerLanguage[0].Film), limit)
@ -313,7 +313,7 @@ LIMIT 15;
filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} filmsPerLanguageWithPtrs := []*FilmsPerLanguage{}
err = query.Query(db, &filmsPerLanguageWithPtrs) err = query.Query(db, &filmsPerLanguageWithPtrs)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(filmsPerLanguage), 1) assert.Equal(t, len(filmsPerLanguage), 1)
assert.Equal(t, len(filmsPerLanguage[0].Film), limit) assert.Equal(t, len(filmsPerLanguage[0].Film), limit)
} }
@ -359,7 +359,7 @@ ORDER BY city.city_id, address.address_id, customer.customer_id;
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assert.Equal(t, dest[0].City.City, "London") assert.Equal(t, dest[0].City.City, "London")
@ -423,7 +423,7 @@ ORDER BY city.city_id, address.address_id, customer.customer_id;
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assert.Equal(t, dest[0].Name, "London") assert.Equal(t, dest[0].Name, "London")
@ -481,7 +481,7 @@ ORDER BY city.city_id, address.address_id, customer.customer_id;
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assert.Equal(t, dest[0].CityName, "London") assert.Equal(t, dest[0].CityName, "London")
@ -538,7 +538,7 @@ ORDER BY city.city_id, address.address_id, customer.customer_id;
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
testutils.AssertJSON(t, dest, ` testutils.AssertJSON(t, dest, `
[ [
@ -597,7 +597,7 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) {
filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} filmsPerLanguageWithPtrs := []*FilmsPerLanguage{}
err := query.Query(db, &filmsPerLanguageWithPtrs) err := query.Query(db, &filmsPerLanguageWithPtrs)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(filmsPerLanguageWithPtrs), 1) assert.Equal(t, len(filmsPerLanguageWithPtrs), 1)
assert.Equal(t, len(*filmsPerLanguageWithPtrs[0].Film), int(limit)) assert.Equal(t, len(*filmsPerLanguageWithPtrs[0].Film), int(limit))
} }
@ -609,7 +609,7 @@ func TestSelect_WithoutUniqueColumnSelected(t *testing.T) {
err := query.Query(db, &customers) err := query.Query(db, &customers)
assert.NilError(t, err) assert.NoError(t, err)
//spew.Dump(customers) //spew.Dump(customers)
@ -623,7 +623,7 @@ func TestSelectOrderByAscDesc(t *testing.T) {
ORDER_BY(Customer.FirstName.ASC()). ORDER_BY(Customer.FirstName.ASC()).
Query(db, &customersAsc) Query(db, &customersAsc)
assert.NilError(t, err) assert.NoError(t, err)
firstCustomerAsc := customersAsc[0] firstCustomerAsc := customersAsc[0]
lastCustomerAsc := customersAsc[len(customersAsc)-1] lastCustomerAsc := customersAsc[len(customersAsc)-1]
@ -633,20 +633,20 @@ func TestSelectOrderByAscDesc(t *testing.T) {
ORDER_BY(Customer.FirstName.DESC()). ORDER_BY(Customer.FirstName.DESC()).
Query(db, &customersDesc) Query(db, &customersDesc)
assert.NilError(t, err) assert.NoError(t, err)
firstCustomerDesc := customersDesc[0] firstCustomerDesc := customersDesc[0]
lastCustomerDesc := customersDesc[len(customersAsc)-1] lastCustomerDesc := customersDesc[len(customersAsc)-1]
assert.DeepEqual(t, firstCustomerAsc, lastCustomerDesc) testutils.AssertDeepEqual(t, firstCustomerAsc, lastCustomerDesc)
assert.DeepEqual(t, lastCustomerAsc, firstCustomerDesc) testutils.AssertDeepEqual(t, lastCustomerAsc, firstCustomerDesc)
customersAscDesc := []model.Customer{} customersAscDesc := []model.Customer{}
err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName).
ORDER_BY(Customer.FirstName.ASC(), Customer.LastName.DESC()). ORDER_BY(Customer.FirstName.ASC(), Customer.LastName.DESC()).
Query(db, &customersAscDesc) Query(db, &customersAscDesc)
assert.NilError(t, err) assert.NoError(t, err)
customerAscDesc326 := model.Customer{ customerAscDesc326 := model.Customer{
CustomerID: 67, CustomerID: 67,
@ -660,8 +660,8 @@ func TestSelectOrderByAscDesc(t *testing.T) {
LastName: "Knott", LastName: "Knott",
} }
assert.DeepEqual(t, customerAscDesc326, customersAscDesc[326]) testutils.AssertDeepEqual(t, customerAscDesc326, customersAscDesc[326])
assert.DeepEqual(t, customerAscDesc327, customersAscDesc[327]) testutils.AssertDeepEqual(t, customerAscDesc327, customersAscDesc[327])
} }
func TestSelectFullJoin(t *testing.T) { func TestSelectFullJoin(t *testing.T) {
@ -702,16 +702,16 @@ ORDER BY customer.customer_id ASC;
err := query.Query(db, &allCustomersAndAddress) err := query.Query(db, &allCustomersAndAddress)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(allCustomersAndAddress), 603) assert.Equal(t, len(allCustomersAndAddress), 603)
assert.DeepEqual(t, allCustomersAndAddress[0].Customer, &customer0) testutils.AssertDeepEqual(t, allCustomersAndAddress[0].Customer, &customer0)
assert.Assert(t, allCustomersAndAddress[0].Address != nil) assert.True(t, allCustomersAndAddress[0].Address != nil)
lastCustomerAddress := allCustomersAndAddress[len(allCustomersAndAddress)-1] lastCustomerAddress := allCustomersAndAddress[len(allCustomersAndAddress)-1]
assert.Assert(t, lastCustomerAddress.Customer == nil) assert.True(t, lastCustomerAddress.Customer == nil)
assert.Assert(t, lastCustomerAddress.Address != nil) assert.True(t, lastCustomerAddress.Address != nil)
} }
@ -757,7 +757,7 @@ LIMIT 1000;
assert.Equal(t, len(customerAddresCrosJoined), 1000) assert.Equal(t, len(customerAddresCrosJoined), 1000)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestSelectSelfJoin(t *testing.T) { func TestSelectSelfJoin(t *testing.T) {
@ -813,7 +813,7 @@ ORDER BY f1.film_id ASC;
err := query.Query(db, &theSameLengthFilms) err := query.Query(db, &theSameLengthFilms)
assert.NilError(t, err) assert.NoError(t, err)
//spew.Dump(theSameLengthFilms) //spew.Dump(theSameLengthFilms)
@ -854,12 +854,12 @@ LIMIT 1000;
err := query.Query(db, &films) err := query.Query(db, &films)
assert.NilError(t, err) assert.NoError(t, err)
//spew.Dump(films) //spew.Dump(films)
assert.Equal(t, len(films), 1000) assert.Equal(t, len(films), 1000)
assert.DeepEqual(t, films[0], thesameLengthFilms{"Alien Center", "Iron Moon", 46}) testutils.AssertDeepEqual(t, films[0], thesameLengthFilms{"Alien Center", "Iron Moon", 46})
} }
func TestSubQuery(t *testing.T) { func TestSubQuery(t *testing.T) {
@ -911,7 +911,7 @@ FROM dvds.actor
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestSelectFunctions(t *testing.T) { func TestSelectFunctions(t *testing.T) {
@ -931,7 +931,7 @@ FROM dvds.film;
err := query.Query(db, &ret) err := query.Query(db, &ret)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, ret.MaxFilmRate, 4.99) assert.Equal(t, ret.MaxFilmRate, 4.99)
} }
@ -973,13 +973,13 @@ ORDER BY film.film_id ASC;
maxRentalRateFilms := []model.Film{} maxRentalRateFilms := []model.Film{}
err := query.Query(db, &maxRentalRateFilms) err := query.Query(db, &maxRentalRateFilms)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(maxRentalRateFilms), 336) assert.Equal(t, len(maxRentalRateFilms), 336)
gRating := model.MpaaRating_G gRating := model.MpaaRating_G
assert.DeepEqual(t, maxRentalRateFilms[0], model.Film{ testutils.AssertDeepEqual(t, maxRentalRateFilms[0], model.Film{
FilmID: 2, FilmID: 2,
Title: "Ace Goldfinger", Title: "Ace Goldfinger",
Description: StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), Description: StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"),
@ -1060,7 +1060,7 @@ ORDER BY customer.customer_id, SUM(payment.amount) ASC;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//testutils.PrintJson(dest) //testutils.PrintJson(dest)
@ -1121,10 +1121,10 @@ ORDER BY customer_payment_sum."amount_sum" ASC;
customersWithAmounts := []CustomerWithAmounts{} customersWithAmounts := []CustomerWithAmounts{}
err := query.Query(db, &customersWithAmounts) err := query.Query(db, &customersWithAmounts)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(customersWithAmounts), 599) assert.Equal(t, len(customersWithAmounts), 599)
assert.DeepEqual(t, customersWithAmounts[0].Customer, &model.Customer{ testutils.AssertDeepEqual(t, customersWithAmounts[0].Customer, &model.Customer{
CustomerID: 318, CustomerID: 318,
StoreID: 1, StoreID: 1,
FirstName: "Brian", FirstName: "Brian",
@ -1145,7 +1145,7 @@ func TestSelectStaff(t *testing.T) {
err := Staff.SELECT(Staff.AllColumns).Query(db, &staffs) err := Staff.SELECT(Staff.AllColumns).Query(db, &staffs)
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertJSON(t, staffs, ` testutils.AssertJSON(t, staffs, `
[ [
@ -1203,12 +1203,12 @@ ORDER BY payment.payment_date ASC;
err := query.Query(db, &payments) err := query.Query(db, &payments)
assert.NilError(t, err) assert.NoError(t, err)
//spew.Dump(payments) //spew.Dump(payments)
assert.Equal(t, len(payments), 9) assert.Equal(t, len(payments), 9)
assert.DeepEqual(t, payments[0], model.Payment{ testutils.AssertDeepEqual(t, payments[0], model.Payment{
PaymentID: 17793, PaymentID: 17793,
CustomerID: 416, CustomerID: 416,
StaffID: 2, StaffID: 2,
@ -1257,17 +1257,17 @@ OFFSET 20;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 10) assert.Equal(t, len(dest), 10)
assert.DeepEqual(t, dest[0], model.Payment{ testutils.AssertDeepEqual(t, dest[0], model.Payment{
PaymentID: 17523, PaymentID: 17523,
Amount: 4.99, Amount: 4.99,
}) })
assert.DeepEqual(t, dest[1], model.Payment{ testutils.AssertDeepEqual(t, dest[1], model.Payment{
PaymentID: 17524, PaymentID: 17524,
Amount: 0.99, Amount: 0.99,
}) })
assert.DeepEqual(t, dest[9], model.Payment{ testutils.AssertDeepEqual(t, dest[9], model.Payment{
PaymentID: 17532, PaymentID: 17532,
Amount: 8.99, Amount: 8.99,
}) })
@ -1283,7 +1283,7 @@ func TestAllSetOperators(t *testing.T) {
dest := []model.Payment{} dest := []model.Payment{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 20) assert.Equal(t, len(dest), 20)
}) })
@ -1293,7 +1293,7 @@ func TestAllSetOperators(t *testing.T) {
dest := []model.Payment{} dest := []model.Payment{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 20) assert.Equal(t, len(dest), 20)
}) })
@ -1303,7 +1303,7 @@ func TestAllSetOperators(t *testing.T) {
dest := []model.Payment{} dest := []model.Payment{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 0) assert.Equal(t, len(dest), 0)
}) })
@ -1313,7 +1313,7 @@ func TestAllSetOperators(t *testing.T) {
dest := []model.Payment{} dest := []model.Payment{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 0) assert.Equal(t, len(dest), 0)
}) })
@ -1323,7 +1323,7 @@ func TestAllSetOperators(t *testing.T) {
dest := []model.Payment{} dest := []model.Payment{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 10) assert.Equal(t, len(dest), 10)
}) })
@ -1333,7 +1333,7 @@ func TestAllSetOperators(t *testing.T) {
dest := []model.Payment{} dest := []model.Payment{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 10) assert.Equal(t, len(dest), 10)
}) })
} }
@ -1363,7 +1363,7 @@ LIMIT 20;
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 20) assert.Equal(t, len(dest), 20)
assert.Equal(t, dest[0].StaffIDNum, "TWO") assert.Equal(t, dest[0].StaffIDNum, "TWO")
assert.Equal(t, dest[1].StaffIDNum, "ONE") assert.Equal(t, dest[1].StaffIDNum, "ONE")
@ -1396,12 +1396,12 @@ FOR`
tx, _ := db.Begin() tx, _ := db.Begin()
res, err := query.Exec(tx) res, err := query.Exec(tx)
assert.NilError(t, err) assert.NoError(t, err)
rowsAffected, _ := res.RowsAffected() rowsAffected, _ := res.RowsAffected()
assert.Equal(t, rowsAffected, int64(3)) assert.Equal(t, rowsAffected, int64(3))
err = tx.Rollback() err = tx.Rollback()
assert.NilError(t, err) assert.NoError(t, err)
} }
for lockType, lockTypeStr := range getRowLockTestData() { for lockType, lockTypeStr := range getRowLockTestData() {
@ -1412,12 +1412,12 @@ FOR`
tx, _ := db.Begin() tx, _ := db.Begin()
res, err := query.Exec(tx) res, err := query.Exec(tx)
assert.NilError(t, err) assert.NoError(t, err)
rowsAffected, _ := res.RowsAffected() rowsAffected, _ := res.RowsAffected()
assert.Equal(t, rowsAffected, int64(3)) assert.Equal(t, rowsAffected, int64(3))
err = tx.Rollback() err = tx.Rollback()
assert.NilError(t, err) assert.NoError(t, err)
} }
for lockType, lockTypeStr := range getRowLockTestData() { for lockType, lockTypeStr := range getRowLockTestData() {
@ -1428,12 +1428,12 @@ FOR`
tx, _ := db.Begin() tx, _ := db.Begin()
res, err := query.Exec(tx) res, err := query.Exec(tx)
assert.NilError(t, err) assert.NoError(t, err)
rowsAffected, _ := res.RowsAffected() rowsAffected, _ := res.RowsAffected()
assert.Equal(t, rowsAffected, int64(3)) assert.Equal(t, rowsAffected, int64(3))
err = tx.Rollback() err = tx.Rollback()
assert.NilError(t, err) assert.NoError(t, err)
} }
} }
@ -1509,7 +1509,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC;
} }
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//jsonSave("./testdata/quick-start-dest.json", dest) //jsonSave("./testdata/quick-start-dest.json", dest)
testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/quick-start-dest.json") testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/quick-start-dest.json")
@ -1522,7 +1522,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC;
} }
err = stmt.Query(db, &dest2) err = stmt.Query(db, &dest2)
assert.NilError(t, err) assert.NoError(t, err)
//jsonSave("./testdata/quick-start-dest2.json", dest2) //jsonSave("./testdata/quick-start-dest2.json", dest2)
testutils.AssertJSONFile(t, dest2, "./testdata/results/postgres/quick-start-dest2.json") testutils.AssertJSONFile(t, dest2, "./testdata/results/postgres/quick-start-dest2.json")
@ -1574,7 +1574,7 @@ func TestQuickStartWithSubQueries(t *testing.T) {
} }
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
//jsonSave("./testdata/quick-start-dest.json", dest) //jsonSave("./testdata/quick-start-dest.json", dest)
testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/quick-start-dest.json") testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/quick-start-dest.json")
@ -1587,7 +1587,7 @@ func TestQuickStartWithSubQueries(t *testing.T) {
} }
err = stmt.Query(db, &dest2) err = stmt.Query(db, &dest2)
assert.NilError(t, err) assert.NoError(t, err)
//jsonSave("./testdata/quick-start-dest2.json", dest2) //jsonSave("./testdata/quick-start-dest2.json", dest2)
testutils.AssertJSONFile(t, dest2, "./testdata/results/postgres/quick-start-dest2.json") testutils.AssertJSONFile(t, dest2, "./testdata/results/postgres/quick-start-dest2.json")
@ -1620,7 +1620,7 @@ SELECT true,
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestWindowFunction(t *testing.T) { func TestWindowFunction(t *testing.T) {
@ -1686,13 +1686,13 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date;
).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate). ).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate).
WHERE(Payment.PaymentID.LT(Int(10))) WHERE(Payment.PaymentID.LT(Int(10)))
fmt.Println(query.Sql()) //fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10))
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestWindowClause(t *testing.T) { func TestWindowClause(t *testing.T) {
@ -1722,14 +1722,14 @@ ORDER BY payment.customer_id;
WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)). WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)).
ORDER_BY(Payment.CustomerID) ORDER_BY(Payment.CustomerID)
fmt.Println(query.Sql()) //fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) testutils.AssertStatementSql(t, query, expectedSQL, int64(10))
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
} }
func TestSimpleView(t *testing.T) { func TestSimpleView(t *testing.T) {
@ -1748,16 +1748,10 @@ func TestSimpleView(t *testing.T) {
FilmInfo string FilmInfo string
} }
//sql, args := query.Sql()
//
//row := db.QueryRow(sql, args...)
//
//row.Scan()
var dest []ActorInfo var dest []ActorInfo
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
testutils.AssertJSON(t, dest[1:2], ` testutils.AssertJSON(t, dest[1:2], `
[ [
@ -1791,9 +1785,106 @@ func TestJoinViewWithTable(t *testing.T) {
fmt.Println(query.DebugSql()) fmt.Println(query.DebugSql())
err := query.Query(db, &dest) err := query.Query(db, &dest)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest), 2)
assert.Equal(t, len(dest[0].Rentals), 32) assert.Equal(t, len(dest[0].Rentals), 32)
assert.Equal(t, len(dest[1].Rentals), 27) assert.Equal(t, len(dest[1].Rentals), 27)
} }
func TestDynamicProjectionList(t *testing.T) {
var request struct {
ColumnsToSelect []string
ShowFullName bool
}
request.ColumnsToSelect = []string{"customer_id", "create_date"}
request.ShowFullName = true
// ...
projectionList := ProjectionList{}
for _, columnName := range request.ColumnsToSelect {
switch columnName {
case Customer.CustomerID.Name():
projectionList = append(projectionList, Customer.CustomerID)
case Customer.Email.Name():
projectionList = append(projectionList, Customer.Email)
case Customer.CreateDate.Name():
projectionList = append(projectionList, Customer.CreateDate)
}
}
var showFullName bool
if showFullName {
projectionList = append(projectionList, Customer.FirstName.CONCAT(Customer.LastName))
}
stmt := SELECT(projectionList).
FROM(Customer).
LIMIT(3)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT customer.customer_id AS "customer.customer_id",
customer.create_date AS "customer.create_date"
FROM dvds.customer
LIMIT 3;
`)
var dest []model.Customer
err := stmt.Query(db, &dest)
assert.NoError(t, err)
assert.Equal(t, len(dest), 3)
}
func TestDynamicCondition(t *testing.T) {
var request struct {
CustomerID *int64
Email *string
Active *bool
}
request.CustomerID = Int64Ptr(1)
request.Active = BoolPtr(true)
// ...
condition := Bool(true)
if request.CustomerID != nil {
condition = condition.AND(Customer.CustomerID.EQ(Int(*request.CustomerID)))
}
if request.Email != nil {
condition = condition.AND(Customer.Email.EQ(String(*request.Email)))
}
if request.Active != nil {
condition = condition.AND(Customer.Activebool.EQ(Bool(*request.Active)))
}
stmt := SELECT(Customer.AllColumns).
FROM(Customer).
WHERE(condition)
testutils.AssertStatementSql(t, stmt, `
SELECT customer.customer_id AS "customer.customer_id",
customer.store_id AS "customer.store_id",
customer.first_name AS "customer.first_name",
customer.last_name AS "customer.last_name",
customer.email AS "customer.email",
customer.address_id AS "customer.address_id",
customer.activebool AS "customer.activebool",
customer.create_date AS "customer.create_date",
customer.last_update AS "customer.last_update",
customer.active AS "customer.active"
FROM dvds.customer
WHERE ($1 AND (customer.customer_id = $2)) AND (customer.activebool = $3);
`, true, int64(1), true)
dest := []model.Customer{}
err := stmt.Query(db, &dest)
assert.NoError(t, err)
assert.Len(t, dest, 1)
testutils.AssertDeepEqual(t, dest[0], customer0)
}

View file

@ -6,7 +6,7 @@ import (
. "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
) )
@ -35,9 +35,9 @@ WHERE link.name = 'Bing';
WHERE(Link.Name.EQ(String("Bong"))). WHERE(Link.Name.EQ(String("Bong"))).
Query(db, &links) Query(db, &links)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(links), 1) assert.Equal(t, len(links), 1)
assert.DeepEqual(t, links[0], model.Link{ testutils.AssertDeepEqual(t, links[0], model.Link{
ID: 204, ID: 204,
URL: "http://bong.com", URL: "http://bong.com",
Name: "Bong", Name: "Bong",
@ -99,7 +99,7 @@ RETURNING link.id AS "link.id",
err := stmt.Query(db, &links) err := stmt.Query(db, &links)
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, len(links), 2) assert.Equal(t, len(links), 2)
assert.Equal(t, links[0].Name, "DuckDuckGo") assert.Equal(t, links[0].Name, "DuckDuckGo")
assert.Equal(t, links[1].Name, "DuckDuckGo") assert.Equal(t, links[1].Name, "DuckDuckGo")
@ -293,10 +293,10 @@ func setupLinkTableForUpdateTest(t *testing.T) {
VALUES(204, "http://www.bing.com", "Bing", DEFAULT). VALUES(204, "http://www.bing.com", "Bing", DEFAULT).
Exec(db) Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
} }
func cleanUpLinkTable(t *testing.T) { func cleanUpLinkTable(t *testing.T) {
_, err := Link.DELETE().WHERE(Link.ID.GT(Int(0))).Exec(db) _, err := Link.DELETE().WHERE(Link.ID.GT(Int(0))).Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
} }

View file

@ -5,16 +5,16 @@ import (
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model"
"github.com/google/uuid" "github.com/google/uuid"
"gotest.tools/assert" "github.com/stretchr/testify/assert"
"testing" "testing"
) )
func AssertExec(t *testing.T, stmt jet.Statement, rowsAffected int64) { func AssertExec(t *testing.T, stmt jet.Statement, rowsAffected int64) {
res, err := stmt.Exec(db) res, err := stmt.Exec(db)
assert.NilError(t, err) assert.NoError(t, err)
rows, err := res.RowsAffected() rows, err := res.RowsAffected()
assert.NilError(t, err) assert.NoError(t, err)
assert.Equal(t, rows, rowsAffected) assert.Equal(t, rows, rowsAffected)
} }

@ -1 +1 @@
Subproject commit 02e0795d1e06b959d0c564dc1e349159d57b1bf6 Subproject commit 889e07c0ebaf6b4021e31cce29b5861eb5c8cc17