From 947d2df47e90b4de1ea3b8a7e34fcdfc9f1ce8b0 Mon Sep 17 00:00:00 2001 From: go-jet Date: Fri, 18 Oct 2019 09:56:38 +0200 Subject: [PATCH 01/19] Additional qrm tests. --- .circleci/config.yml | 2 +- qrm/utill.go | 9 ++------- qrm/utill_test.go | 35 +++++++++++++++++++++++++++++++++++ tests/testdata | 2 +- 4 files changed, 39 insertions(+), 9 deletions(-) create mode 100644 qrm/utill_test.go diff --git a/.circleci/config.yml b/.circleci/config.yml index 6c73d31..d3ea5d0 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -94,7 +94,7 @@ jobs: - run: mkdir -p $TEST_RESULTS - - run: go test -v ./... -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/execution/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml + - run: go test -v ./... -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/qrm/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml - run: name: Upload code coverage diff --git a/qrm/utill.go b/qrm/utill.go index baee26e..20b8b54 100644 --- a/qrm/utill.go +++ b/qrm/utill.go @@ -156,6 +156,7 @@ func valueToString(value reflect.Value) string { var timeType = reflect.TypeOf(time.Now()) var uuidType = reflect.TypeOf(uuid.New()) +var byteArrayType = reflect.TypeOf([]byte("")) func isSimpleModelType(objType reflect.Type) bool { objType = indirectType(objType) @@ -167,15 +168,9 @@ func isSimpleModelType(objType reflect.Type) bool { reflect.String, reflect.Bool: return true - case reflect.Slice: - return objType.Elem().Kind() == reflect.Uint8 //[]byte - case reflect.Struct: - return objType == timeType - case reflect.Array: - return objType == uuidType // uuid.UUID returns reflect.Array kind } - return false + return objType == timeType || objType == uuidType || objType == byteArrayType } func isIntegerType(value reflect.Type) bool { diff --git a/qrm/utill_test.go b/qrm/utill_test.go new file mode 100644 index 0000000..e4ab53d --- /dev/null +++ b/qrm/utill_test.go @@ -0,0 +1,35 @@ +package qrm + +import ( + "github.com/google/uuid" + "gotest.tools/assert" + "reflect" + "testing" + "time" +) + +func TestIsSimpleModelType(t *testing.T) { + assert.Assert(t, isSimpleModelType(reflect.TypeOf(int8(11)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(int16(11)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(int32(11)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(int64(11)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint8(11)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint16(11)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint32(11)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint64(11)))) + + assert.Assert(t, isSimpleModelType(reflect.TypeOf(float32(123.46)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(float64(123.46)))) + + assert.Assert(t, isSimpleModelType(reflect.TypeOf([]byte("Text")))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(time.Now()))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(uuid.New()))) + + complexModelType := struct { + Field1 string + Field2 string + }{} + + assert.Equal(t, isSimpleModelType(reflect.TypeOf(complexModelType)), false) + assert.Equal(t, isSimpleModelType(reflect.TypeOf(&complexModelType)), false) +} diff --git a/tests/testdata b/tests/testdata index 1f6bd8b..02e0795 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 1f6bd8bb86458019fa43b1e2cd7ae9488a7ac9a4 +Subproject commit 02e0795d1e06b959d0c564dc1e349159d57b1bf6 From 0b00a6b12cd6b8dc9f5941ae5fe62f1b3f597d19 Mon Sep 17 00:00:00 2001 From: go-jet Date: Fri, 18 Oct 2019 10:09:56 +0200 Subject: [PATCH 02/19] Some linter errors. --- internal/utils/utils.go | 8 ++++++++ postgres/functions.go | 2 +- qrm/scan_context.go | 4 +--- qrm/utill.go | 5 ++--- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 9be5310..22ea1c3 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -113,6 +113,13 @@ func IsNil(v interface{}) bool { return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) } +// MustBeTrue panics when condition is false +func MustBeTrue(condition bool, errorStr string) { + if !condition { + panic(errorStr) + } +} + // MustBe panics with errorStr error, if v interface is not of reflect kind func MustBe(v interface{}, kind reflect.Kind, errorStr string) { if reflect.TypeOf(v).Kind() != kind { @@ -165,6 +172,7 @@ func ErrorCatch(err *error) { } } +// StringSliceContains checks if slice of strings contains a string func StringSliceContains(strings []string, contains string) bool { for _, str := range strings { if str == contains { diff --git a/postgres/functions.go b/postgres/functions.go index 6993de4..ddd01db 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -69,7 +69,7 @@ var COUNT = jet.COUNT // EVERY is aggregate function. Returns true if all input values are true, otherwise false var EVERY = jet.EVERY -// MAXf is aggregate function. Returns maximum value of expression across all input values +// MAX is aggregate function. Returns maximum value of expression across all input values var MAX = jet.MAX // MAXf is aggregate function. Returns maximum value of float expression across all input values diff --git a/qrm/scan_context.go b/qrm/scan_context.go index 8141ed7..e3f7f40 100644 --- a/qrm/scan_context.go +++ b/qrm/scan_context.go @@ -212,9 +212,7 @@ func (s *scanContext) rowElem(index int) interface{} { valuer, ok := s.row[index].(driver.Valuer) - if !ok { - panic("jet: internal error, scan value doesn't implement driver.Valuer") - } + utils.MustBeTrue(ok, "jet: internal error, scan value doesn't implement driver.Valuer") value, err := valuer.Value() diff --git a/qrm/utill.go b/qrm/utill.go index 20b8b54..7791f9a 100644 --- a/qrm/utill.go +++ b/qrm/utill.go @@ -50,9 +50,8 @@ func getSliceElemPtrAt(slicePtrValue reflect.Value, index int) reflect.Value { } func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) error { - if slicePtrValue.IsNil() { - panic("jet: internal, slice is nil") - } + utils.MustBeTrue(!slicePtrValue.IsNil(), "jet: internal, slice is nil") + sliceValue := slicePtrValue.Elem() sliceElemType := sliceValue.Type().Elem() From 15acb1c3262bee09053b62cfb9e5cc7c69282dbb Mon Sep 17 00:00:00 2001 From: go-jet Date: Fri, 18 Oct 2019 10:15:08 +0200 Subject: [PATCH 03/19] QRM returns qrm.ErrNoRows when scanning into struct destination and query result set is empty. --- internal/jet/statement.go | 4 ++-- qrm/qrm.go | 9 ++++++--- tests/postgres/scan_test.go | 3 +-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/internal/jet/statement.go b/internal/jet/statement.go index ab0c655..e4ba41b 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -16,11 +16,11 @@ type Statement interface { // Query executes statement over database connection db and stores row result in destination. // Destination can be either pointer to struct or pointer to a slice. - // If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows. + // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. Query(db qrm.DB, destination interface{}) error // QueryContext executes statement with a context over database connection db and stores row result in destination. // Destination can be either pointer to struct or pointer to a slice. - // If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows. + // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. QueryContext(context context.Context, db qrm.DB, destination interface{}) error //Exec executes statement over db connection without returning any rows. diff --git a/qrm/qrm.go b/qrm/qrm.go index 8f79c57..e7e6406 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -2,15 +2,18 @@ package qrm import ( "context" - "database/sql" + "errors" "github.com/go-jet/jet/internal/utils" "reflect" ) +// ErrNoRows is returned by Query when query result set is empty +var ErrNoRows = errors.New("qrm: no rows in result set") + // Query executes Query Result Mapping (QRM) of `query` with list of parametrized arguments `arg` over database connection `db` // using context `ctx` into destination `destPtr`. // Destination can be either pointer to struct or pointer to slice of structs. -// If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows. +// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) error { utils.MustBeInitializedPtr(db, "jet: db is nil") @@ -33,7 +36,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr } if rowsProcessed == 0 { - return sql.ErrNoRows + return ErrNoRows } // edge case when row result set contains only NULLs. diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index f42dbd5..11dac96 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -1,7 +1,6 @@ package postgres import ( - "database/sql" "fmt" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" @@ -704,7 +703,7 @@ func TestStructScanErrNoRows(t *testing.T) { err := query.Query(db, &customer) - assert.Error(t, err, sql.ErrNoRows.Error()) + assert.Error(t, err, qrm.ErrNoRows.Error()) } func TestStructScanAllNull(t *testing.T) { From d1970b3a554e645c26cdfcb529f08764de73cfc0 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 1 Dec 2019 18:25:30 +0100 Subject: [PATCH 04/19] MySQL interval with date/time expression arithmetic. --- internal/jet/bool_expression.go | 23 +-- internal/jet/cast.go | 4 +- internal/jet/column.go | 4 +- internal/jet/column_test.go | 2 +- internal/jet/date_expression.go | 43 +++-- internal/jet/date_expression_test.go | 13 ++ internal/jet/dialect.go | 8 +- internal/jet/enum_value.go | 4 +- internal/jet/expression.go | 40 +++-- internal/jet/float_expression.go | 16 +- internal/jet/func_expression.go | 18 +- internal/jet/integer_expression.go | 38 +---- internal/jet/interval.go | 29 ++++ internal/jet/literal_expression.go | 36 ++-- internal/jet/operators.go | 4 +- internal/jet/serializer.go | 14 ++ internal/jet/sql_builder.go | 4 +- internal/jet/statement.go | 6 +- internal/jet/string_expression.go | 12 +- internal/jet/testutils.go | 58 ++++--- internal/jet/time_expression.go | 23 +-- internal/jet/time_expression_test.go | 8 + internal/jet/timestamp_expression.go | 11 ++ internal/jet/timestamp_expression_test.go | 8 + internal/jet/timestampz_expression.go | 14 +- internal/jet/timestampz_expression_test.go | 8 + internal/jet/timez_expression.go | 31 +--- internal/jet/timez_expression_test.go | 12 +- internal/jet/utils.go | 10 ++ internal/testutils/test_utils.go | 22 +++ internal/utils/utils.go | 19 +++ mysql/cast_test.go | 20 +-- mysql/dialect.go | 15 +- mysql/dialect_test.go | 42 ++--- mysql/interval.go | 187 +++++++++++++++++++++ mysql/interval_test.go | 95 +++++++++++ mysql/literal_test.go | 24 +-- mysql/table_test.go | 28 +-- mysql/utils_test.go | 7 +- tests/mysql/alltypes_test.go | 158 ++++++++++++++--- tests/mysql/select_test.go | 5 +- 41 files changed, 805 insertions(+), 318 deletions(-) create mode 100644 internal/jet/date_expression_test.go create mode 100644 internal/jet/interval.go create mode 100644 mysql/interval.go create mode 100644 mysql/interval_test.go diff --git a/internal/jet/bool_expression.go b/internal/jet/bool_expression.go index 5bdda95..fa9342f 100644 --- a/internal/jet/bool_expression.go +++ b/internal/jet/bool_expression.go @@ -85,26 +85,13 @@ func (b *boolInterfaceImpl) IS_NOT_UNKNOWN() BoolExpression { } //---------------------------------------------------// -type binaryBoolExpression struct { - expressionInterfaceImpl - boolInterfaceImpl - - binaryOpExpression -} - func newBinaryBoolOperator(lhs, rhs Expression, operator string, additionalParams ...Expression) BoolExpression { - binaryBoolExpression := binaryBoolExpression{} - - binaryBoolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator, additionalParams...) - binaryBoolExpression.expressionInterfaceImpl.Parent = &binaryBoolExpression - binaryBoolExpression.boolInterfaceImpl.parent = &binaryBoolExpression - - return &binaryBoolExpression + return BoolExp(newBinaryOperatorExpression(lhs, rhs, operator, additionalParams...)) } //---------------------------------------------------// type prefixBoolExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl boolInterfaceImpl prefixOpExpression @@ -114,7 +101,7 @@ func newPrefixBoolOperator(expression Expression, operator string) BoolExpressio exp := prefixBoolExpression{} exp.prefixOpExpression = newPrefixExpression(expression, operator) - exp.expressionInterfaceImpl.Parent = &exp + exp.ExpressionInterfaceImpl.Parent = &exp exp.boolInterfaceImpl.parent = &exp return &exp @@ -122,7 +109,7 @@ func newPrefixBoolOperator(expression Expression, operator string) BoolExpressio //---------------------------------------------------// type postfixBoolOpExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl boolInterfaceImpl postfixOpExpression @@ -132,7 +119,7 @@ func newPostifxBoolExpression(expression Expression, operator string) BoolExpres exp := postfixBoolOpExpression{} exp.postfixOpExpression = newPostfixOpExpression(expression, operator) - exp.expressionInterfaceImpl.Parent = &exp + exp.ExpressionInterfaceImpl.Parent = &exp exp.boolInterfaceImpl.parent = &exp return &exp diff --git a/internal/jet/cast.go b/internal/jet/cast.go index 84e1962..c5fe9a7 100644 --- a/internal/jet/cast.go +++ b/internal/jet/cast.go @@ -24,13 +24,13 @@ func (b *castImpl) AS(castType string) Expression { cast: string(castType), } - castExp.expressionInterfaceImpl.Parent = castExp + castExp.ExpressionInterfaceImpl.Parent = castExp return castExp } type castExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl expression Expression cast string diff --git a/internal/jet/column.go b/internal/jet/column.go index d1422d4..c7a5f41 100644 --- a/internal/jet/column.go +++ b/internal/jet/column.go @@ -20,7 +20,7 @@ type ColumnExpression interface { // The base type for real materialized columns. type columnImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl name string tableName string @@ -34,7 +34,7 @@ func newColumn(name string, tableName string, parent ColumnExpression) columnImp tableName: tableName, } - bc.expressionInterfaceImpl.Parent = parent + bc.ExpressionInterfaceImpl.Parent = parent return bc } diff --git a/internal/jet/column_test.go b/internal/jet/column_test.go index 2159c68..ca3f5f6 100644 --- a/internal/jet/column_test.go +++ b/internal/jet/column_test.go @@ -4,7 +4,7 @@ import "testing" func TestColumn(t *testing.T) { column := newColumn("col", "", nil) - column.expressionInterfaceImpl.Parent = &column + column.ExpressionInterfaceImpl.Parent = &column assertClauseSerialize(t, column, "col") column.setTableName("table1") diff --git a/internal/jet/date_expression.go b/internal/jet/date_expression.go index 357e8a5..8b0a524 100644 --- a/internal/jet/date_expression.go +++ b/internal/jet/date_expression.go @@ -13,42 +13,53 @@ type DateExpression interface { LT_EQ(rhs DateExpression) BoolExpression GT(rhs DateExpression) BoolExpression GT_EQ(rhs DateExpression) BoolExpression + + ADD(rhs Interval) TimestampExpression + SUB(rhs Interval) TimestampExpression } type dateInterfaceImpl struct { parent DateExpression } -func (t *dateInterfaceImpl) EQ(rhs DateExpression) BoolExpression { - return eq(t.parent, rhs) +func (d *dateInterfaceImpl) EQ(rhs DateExpression) BoolExpression { + return eq(d.parent, rhs) } -func (t *dateInterfaceImpl) NOT_EQ(rhs DateExpression) BoolExpression { - return notEq(t.parent, rhs) +func (d *dateInterfaceImpl) NOT_EQ(rhs DateExpression) BoolExpression { + return notEq(d.parent, rhs) } -func (t *dateInterfaceImpl) IS_DISTINCT_FROM(rhs DateExpression) BoolExpression { - return isDistinctFrom(t.parent, rhs) +func (d *dateInterfaceImpl) IS_DISTINCT_FROM(rhs DateExpression) BoolExpression { + return isDistinctFrom(d.parent, rhs) } -func (t *dateInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs DateExpression) BoolExpression { - return isNotDistinctFrom(t.parent, rhs) +func (d *dateInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs DateExpression) BoolExpression { + return isNotDistinctFrom(d.parent, rhs) } -func (t *dateInterfaceImpl) LT(rhs DateExpression) BoolExpression { - return lt(t.parent, rhs) +func (d *dateInterfaceImpl) LT(rhs DateExpression) BoolExpression { + return lt(d.parent, rhs) } -func (t *dateInterfaceImpl) LT_EQ(rhs DateExpression) BoolExpression { - return ltEq(t.parent, rhs) +func (d *dateInterfaceImpl) LT_EQ(rhs DateExpression) BoolExpression { + return ltEq(d.parent, rhs) } -func (t *dateInterfaceImpl) GT(rhs DateExpression) BoolExpression { - return gt(t.parent, rhs) +func (d *dateInterfaceImpl) GT(rhs DateExpression) BoolExpression { + return gt(d.parent, rhs) } -func (t *dateInterfaceImpl) GT_EQ(rhs DateExpression) BoolExpression { - return gtEq(t.parent, rhs) +func (d *dateInterfaceImpl) GT_EQ(rhs DateExpression) BoolExpression { + return gtEq(d.parent, rhs) +} + +func (d *dateInterfaceImpl) ADD(rhs Interval) TimestampExpression { + return TimestampExp(newBinaryOperatorExpression(d.parent, rhs, "+")) +} + +func (d *dateInterfaceImpl) SUB(rhs Interval) TimestampExpression { + return TimestampExp(newBinaryOperatorExpression(d.parent, rhs, "-")) } //---------------------------------------------------// diff --git a/internal/jet/date_expression_test.go b/internal/jet/date_expression_test.go new file mode 100644 index 0000000..14fdd76 --- /dev/null +++ b/internal/jet/date_expression_test.go @@ -0,0 +1,13 @@ +package jet + +import ( + "testing" +) + +func TestDateArithmetic(t *testing.T) { + timestamp := Timestamp(2000, 1, 1, 0, 0, 0) + assertClauseDebugSerialize(t, table1ColDate.ADD(NewInterval(String("1 HOUR"))).EQ(timestamp), + "((table1.col_date + INTERVAL '1 HOUR') = '2000-01-01 00:00:00')") + assertClauseDebugSerialize(t, table1ColDate.SUB(NewInterval(String("1 HOUR"))).EQ(timestamp), + "((table1.col_date - INTERVAL '1 HOUR') = '2000-01-01 00:00:00')") +} diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index e46000a..71720d1 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -11,11 +11,11 @@ type Dialect interface { ArgumentPlaceholder() QueryPlaceholderFunc } -// SerializeFunc func -type SerializeFunc func(statement StatementType, out *SQLBuilder, options ...SerializeOption) +// SerializerFunc func +type SerializerFunc func(statement StatementType, out *SQLBuilder, options ...SerializeOption) -// SerializeOverride func -type SerializeOverride func(expressions ...Expression) SerializeFunc +//// SerializeOverride func +type SerializeOverride func(expressions ...Serializer) SerializerFunc // QueryPlaceholderFunc func type QueryPlaceholderFunc func(ord int) string diff --git a/internal/jet/enum_value.go b/internal/jet/enum_value.go index 5bd609d..17e8c74 100644 --- a/internal/jet/enum_value.go +++ b/internal/jet/enum_value.go @@ -1,7 +1,7 @@ package jet type enumValue struct { - expressionInterfaceImpl + ExpressionInterfaceImpl stringInterfaceImpl name string @@ -11,7 +11,7 @@ type enumValue struct { func NewEnumValue(name string) StringExpression { enumValue := &enumValue{name: name} - enumValue.expressionInterfaceImpl.Parent = enumValue + enumValue.ExpressionInterfaceImpl.Parent = enumValue enumValue.stringInterfaceImpl.parent = enumValue return enumValue diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 4ded89f..6a99638 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -27,63 +27,65 @@ type Expression interface { DESC() OrderByClause } -type expressionInterfaceImpl struct { +type ExpressionInterfaceImpl struct { Parent Expression } -func (e *expressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection { +func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection { return e.Parent } -func (e *expressionInterfaceImpl) IS_NULL() BoolExpression { +func (e *ExpressionInterfaceImpl) IS_NULL() BoolExpression { return newPostifxBoolExpression(e.Parent, "IS NULL") } -func (e *expressionInterfaceImpl) IS_NOT_NULL() BoolExpression { +func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression { return newPostifxBoolExpression(e.Parent, "IS NOT NULL") } -func (e *expressionInterfaceImpl) IN(expressions ...Expression) BoolExpression { +func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression { return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "IN") } -func (e *expressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression { +func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression { return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "NOT IN") } -func (e *expressionInterfaceImpl) AS(alias string) Projection { +func (e *ExpressionInterfaceImpl) AS(alias string) Projection { return newAlias(e.Parent, alias) } -func (e *expressionInterfaceImpl) ASC() OrderByClause { +func (e *ExpressionInterfaceImpl) ASC() OrderByClause { return newOrderByClause(e.Parent, true) } -func (e *expressionInterfaceImpl) DESC() OrderByClause { +func (e *ExpressionInterfaceImpl) DESC() OrderByClause { return newOrderByClause(e.Parent, false) } -func (e *expressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SQLBuilder) { +func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SQLBuilder) { e.Parent.serialize(statement, out, noWrap) } -func (e *expressionInterfaceImpl) serializeForProjection(statement StatementType, out *SQLBuilder) { +func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType, out *SQLBuilder) { e.Parent.serialize(statement, out, noWrap) } -func (e *expressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { +func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { e.Parent.serialize(statement, out, noWrap) } // Representation of binary operations (e.g. comparisons, arithmetic) -type binaryOpExpression struct { - lhs, rhs Expression - additionalParam Expression +type binaryOperatorExpression struct { + ExpressionInterfaceImpl + + lhs, rhs Serializer + additionalParam Serializer operator string } -func newBinaryExpression(lhs, rhs Expression, operator string, additionalParam ...Expression) binaryOpExpression { - binaryExpression := binaryOpExpression{ +func newBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) *binaryOperatorExpression { + binaryExpression := &binaryOperatorExpression{ lhs: lhs, rhs: rhs, operator: operator, @@ -93,10 +95,12 @@ func newBinaryExpression(lhs, rhs Expression, operator string, additionalParam . binaryExpression.additionalParam = additionalParam[0] } + binaryExpression.ExpressionInterfaceImpl.Parent = binaryExpression + return binaryExpression } -func (c *binaryOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { if c.lhs == nil { panic("jet: lhs is nil for '" + c.operator + "' operator") } diff --git a/internal/jet/float_expression.go b/internal/jet/float_expression.go index c2ec535..aa821ba 100644 --- a/internal/jet/float_expression.go +++ b/internal/jet/float_expression.go @@ -85,22 +85,8 @@ func (n *floatInterfaceImpl) POW(expression NumericExpression) FloatExpression { } //---------------------------------------------------// -type binaryFloatExpression struct { - expressionInterfaceImpl - floatInterfaceImpl - - binaryOpExpression -} - func newBinaryFloatExpression(lhs, rhs Expression, operator string) FloatExpression { - floatExpression := binaryFloatExpression{} - - floatExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) - - floatExpression.expressionInterfaceImpl.Parent = &floatExpression - floatExpression.floatInterfaceImpl.parent = &floatExpression - - return &floatExpression + return FloatExp(newBinaryOperatorExpression(lhs, rhs, operator)) } //---------------------------------------------------// diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 3b334c6..f38c9a2 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -578,7 +578,7 @@ func LEAST(value Expression, values ...Expression) Expression { //--------------------------------------------------------------------// type funcExpressionImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl name string expressions []Expression @@ -592,9 +592,9 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr } if parent != nil { - funcExp.expressionInterfaceImpl.Parent = parent + funcExp.ExpressionInterfaceImpl.Parent = parent } else { - funcExp.expressionInterfaceImpl.Parent = funcExp + funcExp.ExpressionInterfaceImpl.Parent = funcExp } return funcExp @@ -605,14 +605,14 @@ func newWindowFunc(name string, expressions ...Expression) windowExpression { newFun := newFunc(name, expressions, nil) windowExpr := newWindowExpression(newFun) - newFun.expressionInterfaceImpl.Parent = windowExpr + newFun.ExpressionInterfaceImpl.Parent = windowExpr return windowExpr } func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil { - serializeOverrideFunc := serializeOverride(f.expressions...) + serializeOverrideFunc := serializeOverride(ExpressionListToSerializerList(f.expressions)...) serializeOverrideFunc(statement, out, options...) return } @@ -642,7 +642,7 @@ func newBoolFunc(name string, expressions ...Expression) BoolExpression { boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) boolFunc.boolInterfaceImpl.parent = boolFunc - boolFunc.expressionInterfaceImpl.Parent = boolFunc + boolFunc.ExpressionInterfaceImpl.Parent = boolFunc return boolFunc } @@ -654,7 +654,7 @@ func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpress boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) intWindowFunc := newBoolWindowExpression(boolFunc) boolFunc.boolInterfaceImpl.parent = intWindowFunc - boolFunc.expressionInterfaceImpl.Parent = intWindowFunc + boolFunc.ExpressionInterfaceImpl.Parent = intWindowFunc return intWindowFunc } @@ -681,7 +681,7 @@ func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpre floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) floatWindowFunc := newFloatWindowExpression(floatFunc) floatFunc.floatInterfaceImpl.parent = floatWindowFunc - floatFunc.expressionInterfaceImpl.Parent = floatWindowFunc + floatFunc.ExpressionInterfaceImpl.Parent = floatWindowFunc return floatWindowFunc } @@ -707,7 +707,7 @@ func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowE integerFunc.funcExpressionImpl = *newFunc(name, expressions, integerFunc) intWindowFunc := newIntegerWindowExpression(integerFunc) integerFunc.integerInterfaceImpl.parent = intWindowFunc - integerFunc.expressionInterfaceImpl.Parent = intWindowFunc + integerFunc.ExpressionInterfaceImpl.Parent = intWindowFunc return intWindowFunc } diff --git a/internal/jet/integer_expression.go b/internal/jet/integer_expression.go index 74a3927..efda817 100644 --- a/internal/jet/integer_expression.go +++ b/internal/jet/integer_expression.go @@ -130,27 +130,13 @@ func (i *integerInterfaceImpl) BIT_SHIFT_RIGHT(intExpression IntegerExpression) } //---------------------------------------------------// -type binaryIntegerExpression struct { - expressionInterfaceImpl - integerInterfaceImpl - - binaryOpExpression -} - func newBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression { - integerExpression := binaryIntegerExpression{} - - integerExpression.expressionInterfaceImpl.Parent = &integerExpression - integerExpression.integerInterfaceImpl.parent = &integerExpression - - integerExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) - - return &integerExpression + return IntExp(newBinaryOperatorExpression(lhs, rhs, operator)) } //---------------------------------------------------// type prefixIntegerOpExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl integerInterfaceImpl prefixOpExpression @@ -160,30 +146,12 @@ func newPrefixIntegerOperator(expression IntegerExpression, operator string) Int integerExpression := prefixIntegerOpExpression{} integerExpression.prefixOpExpression = newPrefixExpression(expression, operator) - integerExpression.expressionInterfaceImpl.Parent = &integerExpression + integerExpression.ExpressionInterfaceImpl.Parent = &integerExpression integerExpression.integerInterfaceImpl.parent = &integerExpression return &integerExpression } -//---------------------------------------------------// -type prefixFloatOpExpression struct { - expressionInterfaceImpl - floatInterfaceImpl - - prefixOpExpression -} - -func newPrefixFloatOperator(expression FloatExpression, operator string) FloatExpression { - floatOpExpression := prefixFloatOpExpression{} - floatOpExpression.prefixOpExpression = newPrefixExpression(expression, operator) - - floatOpExpression.expressionInterfaceImpl.Parent = &floatOpExpression - floatOpExpression.floatInterfaceImpl.parent = &floatOpExpression - - return &floatOpExpression -} - //---------------------------------------------------// type integerExpressionWrapper struct { integerInterfaceImpl diff --git a/internal/jet/interval.go b/internal/jet/interval.go new file mode 100644 index 0000000..705f2be --- /dev/null +++ b/internal/jet/interval.go @@ -0,0 +1,29 @@ +package jet + +type Interval interface { + Serializer + IsInterval +} + +type IsInterval interface { + isInterval() +} + +func NewInterval(s Serializer) Interval { + newInterval := &intervalImpl{ + interval: s, + } + + return newInterval +} + +type intervalImpl struct { + interval Serializer +} + +func (i intervalImpl) isInterval() {} + +func (i intervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteString("INTERVAL") + i.interval.serialize(statement, out, options...) +} diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 68fb429..6983b8f 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -14,7 +14,7 @@ type LiteralExpression interface { } type literalExpressionImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl value interface{} constant bool @@ -27,11 +27,17 @@ func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl exp.constant = optionalConstant[0] } - exp.expressionInterfaceImpl.Parent = &exp + exp.ExpressionInterfaceImpl.Parent = &exp return &exp } +// Literal is injected directly to SQL query, and does not appear in parametrized argument list. +func Literal(value interface{}) *literalExpressionImpl { + exp := literal(value) + return exp +} + // FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list. func FixedLiteral(value interface{}) *literalExpressionImpl { exp := literal(value) @@ -273,13 +279,13 @@ func formatNanoseconds(nanoseconds ...time.Duration) string { //--------------------------------------------------// type nullLiteral struct { - expressionInterfaceImpl + ExpressionInterfaceImpl } func newNullLiteral() Expression { nullExpression := &nullLiteral{} - nullExpression.expressionInterfaceImpl.Parent = nullExpression + nullExpression.ExpressionInterfaceImpl.Parent = nullExpression return nullExpression } @@ -290,13 +296,13 @@ func (n *nullLiteral) serialize(statement StatementType, out *SQLBuilder, option //--------------------------------------------------// type starLiteral struct { - expressionInterfaceImpl + ExpressionInterfaceImpl } func newStarLiteral() Expression { starExpression := &starLiteral{} - starExpression.expressionInterfaceImpl.Parent = starExpression + starExpression.ExpressionInterfaceImpl.Parent = starExpression return starExpression } @@ -308,7 +314,7 @@ func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, option //---------------------------------------------------// type wrap struct { - expressionInterfaceImpl + ExpressionInterfaceImpl expressions []Expression } @@ -321,28 +327,28 @@ func (n *wrap) serialize(statement StatementType, out *SQLBuilder, options ...Se // WRAP wraps list of expressions with brackets '(' and ')' func WRAP(expression ...Expression) Expression { wrap := &wrap{expressions: expression} - wrap.expressionInterfaceImpl.Parent = wrap + wrap.ExpressionInterfaceImpl.Parent = wrap return wrap } //---------------------------------------------------// -type rawExpression struct { - expressionInterfaceImpl +type RawExpression struct { + ExpressionInterfaceImpl - raw string + Raw string } -func (n *rawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString(n.raw) +func (n *RawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteString(n.Raw) } // Raw can be used for any unsupported functions, operators or expressions. // For example: Raw("current_database()") func Raw(raw string) Expression { - rawExp := &rawExpression{raw: raw} - rawExp.expressionInterfaceImpl.Parent = rawExp + rawExp := &RawExpression{Raw: raw} + rawExp.ExpressionInterfaceImpl.Parent = rawExp return rawExp } diff --git a/internal/jet/operators.go b/internal/jet/operators.go index 4a4a32d..9229717 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -79,7 +79,7 @@ type CaseOperator interface { } type caseOperatorImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl expression Expression when []Expression @@ -95,7 +95,7 @@ func CASE(expression ...Expression) CaseOperator { caseExp.expression = expression[0] } - caseExp.expressionInterfaceImpl.Parent = caseExp + caseExp.ExpressionInterfaceImpl.Parent = caseExp return caseExp } diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index 585d7db..d5ff4b9 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -41,3 +41,17 @@ func contains(options []SerializeOption, option SerializeOption) bool { return false } + +type ListSerializer struct { + Serializers []Serializer + Separator string +} + +func (s ListSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + for i, ser := range s.Serializers { + if i > 0 { + out.WriteString(s.Separator) + } + ser.serialize(statement, out) + } +} diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 4eaf626..3b34ab2 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -22,7 +22,7 @@ type SQLBuilder struct { lastChar byte ident int - debug bool + Debug bool } const defaultIdent = 5 @@ -120,7 +120,7 @@ func (s *SQLBuilder) insertConstantArgument(arg interface{}) { } func (s *SQLBuilder) insertParametrizedArgument(arg interface{}) { - if s.debug { + if s.Debug { s.insertConstantArgument(arg) return } diff --git a/internal/jet/statement.go b/internal/jet/statement.go index e4ba41b..3b0638d 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -65,7 +65,7 @@ func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface } func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { - sqlBuilder := &SQLBuilder{Dialect: s.dialect, debug: true} + sqlBuilder := &SQLBuilder{Dialect: s.dialect, Debug: true} s.parent.serialize(s.statementType, sqlBuilder, noWrap) @@ -106,7 +106,7 @@ type ExpressionStatement interface { // NewExpressionStatementImpl creates new expression statement func NewExpressionStatementImpl(Dialect Dialect, statementType StatementType, parent ExpressionStatement, clauses ...Clause) ExpressionStatement { return &expressionStatementImpl{ - expressionInterfaceImpl{Parent: parent}, + ExpressionInterfaceImpl{Parent: parent}, statementImpl{ serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ parent: parent, @@ -119,7 +119,7 @@ func NewExpressionStatementImpl(Dialect Dialect, statementType StatementType, pa } type expressionStatementImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl statementImpl } diff --git a/internal/jet/string_expression.go b/internal/jet/string_expression.go index b0351ea..a3a76c7 100644 --- a/internal/jet/string_expression.go +++ b/internal/jet/string_expression.go @@ -82,20 +82,14 @@ func (s *stringInterfaceImpl) NOT_REGEXP_LIKE(pattern StringExpression, caseSens //---------------------------------------------------// type binaryStringExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl stringInterfaceImpl - binaryOpExpression + binaryOperatorExpression } func newBinaryStringExpression(lhs, rhs Expression, operator string) StringExpression { - boolExpression := binaryStringExpression{} - - boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) - boolExpression.expressionInterfaceImpl.Parent = &boolExpression - boolExpression.stringInterfaceImpl.parent = &boolExpression - - return &boolExpression + return StringExp(newBinaryOperatorExpression(lhs, rhs, operator)) } //---------------------------------------------------// diff --git a/internal/jet/testutils.go b/internal/jet/testutils.go index 3c0a969..545f12c 100644 --- a/internal/jet/testutils.go +++ b/internal/jet/testutils.go @@ -14,36 +14,40 @@ var defaultDialect = NewDialect(DialectParams{ // just for tests }, }) -var table1Col1 = IntegerColumn("col1") -var table1ColInt = IntegerColumn("col_int") -var table1ColFloat = FloatColumn("col_float") -var table1Col3 = IntegerColumn("col3") -var table1ColTime = TimeColumn("col_time") -var table1ColTimez = TimezColumn("col_timez") -var table1ColTimestamp = TimestampColumn("col_timestamp") -var table1ColTimestampz = TimestampzColumn("col_timestampz") -var table1ColBool = BoolColumn("col_bool") -var table1ColDate = DateColumn("col_date") - +var ( + table1Col1 = IntegerColumn("col1") + table1ColInt = IntegerColumn("col_int") + table1ColFloat = FloatColumn("col_float") + table1Col3 = IntegerColumn("col3") + table1ColTime = TimeColumn("col_time") + table1ColTimez = TimezColumn("col_timez") + table1ColTimestamp = TimestampColumn("col_timestamp") + table1ColTimestampz = TimestampzColumn("col_timestampz") + table1ColBool = BoolColumn("col_bool") + table1ColDate = DateColumn("col_date") +) var table1 = NewTable("db", "table1", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTimestampz) -var table2Col3 = IntegerColumn("col3") -var table2Col4 = IntegerColumn("col4") -var table2ColInt = IntegerColumn("col_int") -var table2ColFloat = FloatColumn("col_float") -var table2ColStr = StringColumn("col_str") -var table2ColBool = BoolColumn("col_bool") -var table2ColTime = TimeColumn("col_time") -var table2ColTimez = TimezColumn("col_timez") -var table2ColTimestamp = TimestampColumn("col_timestamp") -var table2ColTimestampz = TimestampzColumn("col_timestampz") -var table2ColDate = DateColumn("col_date") - +var ( + table2Col3 = IntegerColumn("col3") + table2Col4 = IntegerColumn("col4") + table2ColInt = IntegerColumn("col_int") + table2ColFloat = FloatColumn("col_float") + table2ColStr = StringColumn("col_str") + table2ColBool = BoolColumn("col_bool") + table2ColTime = TimeColumn("col_time") + table2ColTimez = TimezColumn("col_timez") + table2ColTimestamp = TimestampColumn("col_timestamp") + table2ColTimestampz = TimestampzColumn("col_timestampz") + table2ColDate = DateColumn("col_date") +) var table2 = NewTable("db", "table2", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz) -var table3Col1 = IntegerColumn("col1") -var table3ColInt = IntegerColumn("col_int") -var table3StrCol = StringColumn("col2") +var ( + table3Col1 = IntegerColumn("col1") + table3ColInt = IntegerColumn("col_int") + table3StrCol = StringColumn("col2") +) var table3 = NewTable("db", "table3", table3Col1, table3ColInt, table3StrCol) func assertClauseSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) { @@ -67,7 +71,7 @@ func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) } func assertClauseDebugSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) { - out := SQLBuilder{Dialect: defaultDialect, debug: true} + out := SQLBuilder{Dialect: defaultDialect, Debug: true} clause.serialize(SelectStatementType, &out) //fmt.Println(out.Buff.String()) diff --git a/internal/jet/time_expression.go b/internal/jet/time_expression.go index 779d37f..b83f731 100644 --- a/internal/jet/time_expression.go +++ b/internal/jet/time_expression.go @@ -13,6 +13,9 @@ type TimeExpression interface { LT_EQ(rhs TimeExpression) BoolExpression GT(rhs TimeExpression) BoolExpression GT_EQ(rhs TimeExpression) BoolExpression + + ADD(rhs Interval) TimeExpression + SUB(rhs Interval) TimeExpression } type timeInterfaceImpl struct { @@ -51,23 +54,13 @@ func (t *timeInterfaceImpl) GT_EQ(rhs TimeExpression) BoolExpression { return gtEq(t.parent, rhs) } -//---------------------------------------------------// -type prefixTimeExpression struct { - expressionInterfaceImpl - timeInterfaceImpl - - prefixOpExpression +func (t *timeInterfaceImpl) ADD(rhs Interval) TimeExpression { + return TimeExp(newBinaryOperatorExpression(t.parent, rhs, "+")) } -//func newPrefixTimeExpression(operator string, expression Expression) TimeExpression { -// timeExpr := prefixTimeExpression{} -// timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) -// -// timeExpr.expressionInterfaceImpl.parent = &timeExpr -// timeExpr.timeInterfaceImpl.parent = &timeExpr -// -// return &timeExpr -//} +func (t *timeInterfaceImpl) SUB(rhs Interval) TimeExpression { + return TimeExp(newBinaryOperatorExpression(t.parent, rhs, "-")) +} //---------------------------------------------------// diff --git a/internal/jet/time_expression_test.go b/internal/jet/time_expression_test.go index 2b3d015..61ee29f 100644 --- a/internal/jet/time_expression_test.go +++ b/internal/jet/time_expression_test.go @@ -52,3 +52,11 @@ func TestTimeExp(t *testing.T) { assertClauseSerialize(t, TimeExp(table1ColFloat).LT(Time(1, 1, 1, 1*time.Millisecond)), "(table1.col_float < $1)", string("01:01:01.001")) } + +func TestTimeArithmetic(t *testing.T) { + time := Time(10, 20, 3) + assertClauseDebugSerialize(t, table1ColTime.ADD(NewInterval(String("1 HOUR"))).EQ(time), + "((table1.col_time + INTERVAL '1 HOUR') = '10:20:03')") + assertClauseDebugSerialize(t, table1ColTime.SUB(NewInterval(String("1 HOUR"))).EQ(time), + "((table1.col_time - INTERVAL '1 HOUR') = '10:20:03')") +} diff --git a/internal/jet/timestamp_expression.go b/internal/jet/timestamp_expression.go index f76c27a..81eda61 100644 --- a/internal/jet/timestamp_expression.go +++ b/internal/jet/timestamp_expression.go @@ -13,6 +13,9 @@ type TimestampExpression interface { LT_EQ(rhs TimestampExpression) BoolExpression GT(rhs TimestampExpression) BoolExpression GT_EQ(rhs TimestampExpression) BoolExpression + + ADD(rhs Interval) TimestampExpression + SUB(rhs Interval) TimestampExpression } type timestampInterfaceImpl struct { @@ -51,6 +54,14 @@ func (t *timestampInterfaceImpl) GT_EQ(rhs TimestampExpression) BoolExpression { return gtEq(t.parent, rhs) } +func (t *timestampInterfaceImpl) ADD(rhs Interval) TimestampExpression { + return TimestampExp(newBinaryOperatorExpression(t.parent, rhs, "+")) +} + +func (t *timestampInterfaceImpl) SUB(rhs Interval) TimestampExpression { + return TimestampExp(newBinaryOperatorExpression(t.parent, rhs, "-")) +} + //------------------------------------------------- type timestampExpressionWrapper struct { diff --git a/internal/jet/timestamp_expression_test.go b/internal/jet/timestamp_expression_test.go index 9a9ceb4..e34d8dd 100644 --- a/internal/jet/timestamp_expression_test.go +++ b/internal/jet/timestamp_expression_test.go @@ -53,3 +53,11 @@ func TestTimestampExp(t *testing.T) { assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp), "(table1.col_float < $1)", "2000-01-31 10:20:00.003") } + +func TestTimestampArithmetic(t *testing.T) { + timestamp := Timestamp(2000, 1, 1, 0, 0, 0) + assertClauseDebugSerialize(t, table1ColTimestamp.ADD(NewInterval(String("1 HOUR"))).EQ(timestamp), + "((table1.col_timestamp + INTERVAL '1 HOUR') = '2000-01-01 00:00:00')") + assertClauseDebugSerialize(t, table1ColTimestamp.SUB(NewInterval(String("1 HOUR"))).EQ(timestamp), + "((table1.col_timestamp - INTERVAL '1 HOUR') = '2000-01-01 00:00:00')") +} diff --git a/internal/jet/timestampz_expression.go b/internal/jet/timestampz_expression.go index 4f3e6ec..a9f8c9f 100644 --- a/internal/jet/timestampz_expression.go +++ b/internal/jet/timestampz_expression.go @@ -13,6 +13,9 @@ type TimestampzExpression interface { LT_EQ(rhs TimestampzExpression) BoolExpression GT(rhs TimestampzExpression) BoolExpression GT_EQ(rhs TimestampzExpression) BoolExpression + + ADD(rhs Interval) TimestampzExpression + SUB(rhs Interval) TimestampzExpression } type timestampzInterfaceImpl struct { @@ -51,13 +54,12 @@ func (t *timestampzInterfaceImpl) GT_EQ(rhs TimestampzExpression) BoolExpression return gtEq(t.parent, rhs) } -//---------------------------------------------------// +func (t *timestampzInterfaceImpl) ADD(rhs Interval) TimestampzExpression { + return TimestampzExp(newBinaryOperatorExpression(t.parent, rhs, "+")) +} -type prefixTimestampzOperator struct { - expressionInterfaceImpl - timestampzInterfaceImpl - - prefixOpExpression +func (t *timestampzInterfaceImpl) SUB(rhs Interval) TimestampzExpression { + return TimestampzExp(newBinaryOperatorExpression(t.parent, rhs, "-")) } //------------------------------------------------- diff --git a/internal/jet/timestampz_expression_test.go b/internal/jet/timestampz_expression_test.go index 6880c93..1ff1eac 100644 --- a/internal/jet/timestampz_expression_test.go +++ b/internal/jet/timestampz_expression_test.go @@ -53,3 +53,11 @@ func TestTimestampzExp(t *testing.T) { assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz), "(table1.col_float < $1)", "2000-01-31 10:20:05.000023 +200") } + +func TestTimestampzArithmetic(t *testing.T) { + timestampz := Timestampz(2000, 1, 1, 0, 0, 0, 100, "UTC") + assertClauseDebugSerialize(t, table1ColTimestampz.ADD(NewInterval(String("1 HOUR"))).EQ(timestampz), + "((table1.col_timestampz + INTERVAL '1 HOUR') = '2000-01-01 00:00:00.0000001 UTC')") + assertClauseDebugSerialize(t, table1ColTimestampz.SUB(NewInterval(String("1 HOUR"))).EQ(timestampz), + "((table1.col_timestampz - INTERVAL '1 HOUR') = '2000-01-01 00:00:00.0000001 UTC')") +} diff --git a/internal/jet/timez_expression.go b/internal/jet/timez_expression.go index 36b5c8f..c791c62 100644 --- a/internal/jet/timez_expression.go +++ b/internal/jet/timez_expression.go @@ -4,23 +4,18 @@ package jet type TimezExpression interface { Expression - //EQ EQ(rhs TimezExpression) BoolExpression - //NOT_EQ NOT_EQ(rhs TimezExpression) BoolExpression - //IS_DISTINCT_FROM IS_DISTINCT_FROM(rhs TimezExpression) BoolExpression - //IS_NOT_DISTINCT_FROM IS_NOT_DISTINCT_FROM(rhs TimezExpression) BoolExpression - //LT LT(rhs TimezExpression) BoolExpression - //LT_EQ LT_EQ(rhs TimezExpression) BoolExpression - //GT GT(rhs TimezExpression) BoolExpression - //GT_EQ GT_EQ(rhs TimezExpression) BoolExpression + + ADD(rhs Interval) TimezExpression + SUB(rhs Interval) TimezExpression } type timezInterfaceImpl struct { @@ -59,23 +54,13 @@ func (t *timezInterfaceImpl) GT_EQ(rhs TimezExpression) BoolExpression { return gtEq(t.parent, rhs) } -//---------------------------------------------------// -type prefixTimezExpression struct { - expressionInterfaceImpl - timezInterfaceImpl - - prefixOpExpression +func (t *timezInterfaceImpl) ADD(rhs Interval) TimezExpression { + return TimezExp(newBinaryOperatorExpression(t.parent, rhs, "+")) } -//func newPrefixTimezExpression(operator string, expression Expression) TimezExpression { -// timeExpr := prefixTimezExpression{} -// timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) -// -// timeExpr.expressionInterfaceImpl.parent = &timeExpr -// timeExpr.timezInterfaceImpl.parent = &timeExpr -// -// return &timeExpr -//} +func (t *timezInterfaceImpl) SUB(rhs Interval) TimezExpression { + return TimezExp(newBinaryOperatorExpression(t.parent, rhs, "-")) +} //---------------------------------------------------// diff --git a/internal/jet/timez_expression_test.go b/internal/jet/timez_expression_test.go index 2a0312a..9f21c08 100644 --- a/internal/jet/timez_expression_test.go +++ b/internal/jet/timez_expression_test.go @@ -1,6 +1,8 @@ package jet -import "testing" +import ( + "testing" +) var timezVar = Timez(10, 20, 0, 0, "+4:00") @@ -49,3 +51,11 @@ func TestTimezExp(t *testing.T) { assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, "+4:00")), "(table1.col_float < $1)", string("01:01:01.000000001 +4:00")) } + +func TestTimezArithmetic(t *testing.T) { + timez := Timez(0, 0, 0, 100, "UTC") + assertClauseDebugSerialize(t, table1ColTimez.ADD(NewInterval(String("1 HOUR"))).EQ(timez), + "((table1.col_timez + INTERVAL '1 HOUR') = '00:00:00.0000001 UTC')") + assertClauseDebugSerialize(t, table1ColTimez.SUB(NewInterval(String("1 HOUR"))).EQ(timez), + "((table1.col_timez - INTERVAL '1 HOUR') = '00:00:00.0000001 UTC')") +} diff --git a/internal/jet/utils.go b/internal/jet/utils.go index fdaf1f6..67c2c6c 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -63,6 +63,16 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) { } } +func ExpressionListToSerializerList(expressions []Expression) []Serializer { + var ret []Serializer + + for _, expr := range expressions { + ret = append(ret, expr) + } + + return ret +} + // ColumnListToProjectionList func func ColumnListToProjectionList(columns []ColumnExpression) []Projection { var ret []Projection diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 7948504..1b19a20 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -129,6 +129,28 @@ func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Seriali } } +// AssertClauseSerialize checks if clause serialize produces expected query and args +func AssertDebugClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { + out := jet.SQLBuilder{Dialect: dialect, Debug: true} + jet.Serialize(clause, jet.SelectStatementType, &out) + + assert.DeepEqual(t, out.Buff.String(), query) + + if len(args) > 0 { + assert.DeepEqual(t, out.Args, args) + } +} + +// AssertPanicErr checks if running a function fun produces a panic with errorStr string +func AssertPanicErr(t *testing.T, fun func(), errorStr string) { + defer func() { + r := recover() + assert.Equal(t, r, errorStr) + }() + + fun() +} + // AssertClauseSerializeErr check if clause serialize panics with errString func AssertClauseSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) { defer func() { diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 22ea1c3..97ac48a 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -9,6 +9,7 @@ import ( "path/filepath" "reflect" "strings" + "time" ) // ToGoIdentifier converts database to Go identifier. @@ -182,3 +183,21 @@ func StringSliceContains(strings []string, contains string) bool { return false } + +func ExtractDateTimeComponents(duration time.Duration) (days, hours, minutes, seconds, microseconds int64) { + days = int64(duration / (24 * time.Hour)) + reminder := duration % (24 * time.Hour) + + hours = int64(reminder / time.Hour) + reminder = reminder % time.Hour + + minutes = int64(reminder / time.Minute) + reminder = reminder % time.Minute + + seconds = int64(reminder / time.Second) + reminder = reminder % time.Second + + microseconds = int64(reminder / time.Microsecond) + + return +} diff --git a/mysql/cast_test.go b/mysql/cast_test.go index cc1a809..170cde8 100644 --- a/mysql/cast_test.go +++ b/mysql/cast_test.go @@ -5,14 +5,14 @@ import ( ) func TestCAST(t *testing.T) { - assertClauseSerialize(t, CAST(Float(11.22)).AS("bigint"), `CAST(? AS bigint)`) - assertClauseSerialize(t, CAST(Int(22)).AS_CHAR(), `CAST(? AS CHAR)`) - assertClauseSerialize(t, CAST(Int(22)).AS_CHAR(10), `CAST(? AS CHAR(10))`) - assertClauseSerialize(t, CAST(Int(22)).AS_DATE(), `CAST(? AS DATE)`) - assertClauseSerialize(t, CAST(Int(22)).AS_DECIMAL(), `CAST(? AS DECIMAL)`) - assertClauseSerialize(t, CAST(Int(22)).AS_TIME(), `CAST(? AS TIME)`) - assertClauseSerialize(t, CAST(Int(22)).AS_DATETIME(), `CAST(? AS DATETIME)`) - assertClauseSerialize(t, CAST(Int(22)).AS_SIGNED(), `CAST(? AS SIGNED)`) - assertClauseSerialize(t, CAST(Int(22)).AS_UNSIGNED(), `CAST(? AS UNSIGNED)`) - assertClauseSerialize(t, CAST(Int(22)).AS_BINARY(), `CAST(? AS BINARY)`) + assertSerialize(t, CAST(Float(11.22)).AS("bigint"), `CAST(? AS bigint)`) + assertSerialize(t, CAST(Int(22)).AS_CHAR(), `CAST(? AS CHAR)`) + assertSerialize(t, CAST(Int(22)).AS_CHAR(10), `CAST(? AS CHAR(10))`) + assertSerialize(t, CAST(Int(22)).AS_DATE(), `CAST(? AS DATE)`) + assertSerialize(t, CAST(Int(22)).AS_DECIMAL(), `CAST(? AS DECIMAL)`) + assertSerialize(t, CAST(Int(22)).AS_TIME(), `CAST(? AS TIME)`) + assertSerialize(t, CAST(Int(22)).AS_DATETIME(), `CAST(? AS DATETIME)`) + assertSerialize(t, CAST(Int(22)).AS_SIGNED(), `CAST(? AS SIGNED)`) + assertSerialize(t, CAST(Int(22)).AS_UNSIGNED(), `CAST(? AS UNSIGNED)`) + assertSerialize(t, CAST(Int(22)).AS_BINARY(), `CAST(? AS BINARY)`) } diff --git a/mysql/dialect.go b/mysql/dialect.go index 45509a7..cfd452a 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -8,7 +8,6 @@ import ( var Dialect = newDialect() func newDialect() jet.Dialect { - operatorSerializeOverrides := map[string]jet.SerializeOverride{} operatorSerializeOverrides[jet.StringRegexpLikeOperator] = mysqlREGEXPLIKEoperator operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = mysqlNOTREGEXPLIKEoperator @@ -32,7 +31,7 @@ func newDialect() jet.Dialect { return jet.NewDialect(mySQLDialectParams) } -func mysqlBitXor(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlBitXor(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator XOR") @@ -49,7 +48,7 @@ func mysqlBitXor(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysqlCONCAToperator(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlCONCAToperator(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator CONCAT") @@ -66,7 +65,7 @@ func mysqlCONCAToperator(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysqlDivision(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlDivision(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator DIV") @@ -90,7 +89,7 @@ func mysqlDivision(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysqlISNOTDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlISNOTDISTINCTFROM(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") @@ -102,7 +101,7 @@ func mysqlISNOTDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysqlISDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlISDISTINCTFROM(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { out.WriteString("NOT(") mysqlISNOTDISTINCTFROM(expressions...)(statement, out, options...) @@ -110,7 +109,7 @@ func mysqlISDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysqlREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") @@ -136,7 +135,7 @@ func mysqlREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") diff --git a/mysql/dialect_test.go b/mysql/dialect_test.go index 936277a..92ece4d 100644 --- a/mysql/dialect_test.go +++ b/mysql/dialect_test.go @@ -5,37 +5,37 @@ import ( ) func TestBoolExpressionIS_DISTINCT_FROM(t *testing.T) { - assertClauseSerialize(t, table1ColBool.IS_DISTINCT_FROM(table2ColBool), "(NOT(table1.col_bool <=> table2.col_bool))") - assertClauseSerialize(t, table1ColBool.IS_DISTINCT_FROM(Bool(false)), "(NOT(table1.col_bool <=> ?))", false) + assertSerialize(t, table1ColBool.IS_DISTINCT_FROM(table2ColBool), "(NOT(table1.col_bool <=> table2.col_bool))") + assertSerialize(t, table1ColBool.IS_DISTINCT_FROM(Bool(false)), "(NOT(table1.col_bool <=> ?))", false) } func TestBoolExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { - assertClauseSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(table2ColBool), "(table1.col_bool <=> table2.col_bool)") - assertClauseSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(Bool(false)), "(table1.col_bool <=> ?)", false) + assertSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(table2ColBool), "(table1.col_bool <=> table2.col_bool)") + assertSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(Bool(false)), "(table1.col_bool <=> ?)", false) } func TestBoolLiteral(t *testing.T) { - assertClauseSerialize(t, Bool(true), "?", true) - assertClauseSerialize(t, Bool(false), "?", false) + assertSerialize(t, Bool(true), "?", true) + assertSerialize(t, Bool(false), "?", false) } func TestIntegerExpressionDIV(t *testing.T) { - assertClauseSerialize(t, table1ColInt.DIV(table2ColInt), "(table1.col_int DIV table2.col_int)") - assertClauseSerialize(t, table1ColInt.DIV(Int(11)), "(table1.col_int DIV ?)", int64(11)) + assertSerialize(t, table1ColInt.DIV(table2ColInt), "(table1.col_int DIV table2.col_int)") + assertSerialize(t, table1ColInt.DIV(Int(11)), "(table1.col_int DIV ?)", int64(11)) } func TestIntExpressionPOW(t *testing.T) { - assertClauseSerialize(t, table1ColInt.POW(table2ColInt), "POW(table1.col_int, table2.col_int)") - assertClauseSerialize(t, table1ColInt.POW(Int(11)), "POW(table1.col_int, ?)", int64(11)) + assertSerialize(t, table1ColInt.POW(table2ColInt), "POW(table1.col_int, table2.col_int)") + assertSerialize(t, table1ColInt.POW(Int(11)), "POW(table1.col_int, ?)", int64(11)) } func TestIntExpressionBIT_XOR(t *testing.T) { - assertClauseSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "(table1.col_int ^ table2.col_int)") - assertClauseSerialize(t, table1ColInt.BIT_XOR(Int(11)), "(table1.col_int ^ ?)", int64(11)) + assertSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "(table1.col_int ^ table2.col_int)") + assertSerialize(t, table1ColInt.BIT_XOR(Int(11)), "(table1.col_int ^ ?)", int64(11)) } func TestExists(t *testing.T) { - assertClauseSerialize(t, EXISTS( + assertSerialize(t, EXISTS( table2. SELECT(Int(1)). WHERE(table1Col1.EQ(table2Col3)), @@ -48,15 +48,15 @@ func TestExists(t *testing.T) { } func TestString_REGEXP_LIKE_operator(t *testing.T) { - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 REGEXP table2.col_str)") - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 REGEXP ?)", "JOHN") - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 REGEXP ?)", "JOHN") - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 REGEXP BINARY ?)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 REGEXP table2.col_str)") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 REGEXP ?)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 REGEXP ?)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 REGEXP BINARY ?)", "JOHN") } func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) { - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 NOT REGEXP table2.col_str)") - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 NOT REGEXP ?)", "JOHN") - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 NOT REGEXP ?)", "JOHN") - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 NOT REGEXP BINARY ?)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 NOT REGEXP table2.col_str)") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 NOT REGEXP ?)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 NOT REGEXP ?)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 NOT REGEXP BINARY ?)", "JOHN") } diff --git a/mysql/interval.go b/mysql/interval.go new file mode 100644 index 0000000..57c2cfc --- /dev/null +++ b/mysql/interval.go @@ -0,0 +1,187 @@ +package mysql + +import ( + "fmt" + "regexp" + "time" + + "github.com/go-jet/jet/internal/jet" + "github.com/go-jet/jet/internal/utils" +) + +type UnitType string + +const ( + MICROSECOND UnitType = "MICROSECOND" + SECOND = "SECOND" + MINUTE = "MINUTE" + HOUR = "HOUR" + DAY = "DAY" + WEEK = "WEEK" + MONTH = "MONTH" + QUARTER = "QUARTER" + YEAR = "YEAR" + SECOND_MICROSECOND = "SECOND_MICROSECOND" + MINUTE_MICROSECOND = "MINUTE_MICROSECOND" + MINUTE_SECOND = "MINUTE_SECOND" + HOUR_MICROSECOND = "HOUR_MICROSECOND" + HOUR_SECOND = "HOUR_SECOND" + HOUR_MINUTE = "HOUR_MINUTE" + DAY_MICROSECOND = "DAY_MICROSECOND" + DAY_SECOND = "DAY_SECOND" + DAY_MINUTE = "DAY_MINUTE" + DAY_HOUR = "DAY_HOUR" + YEAR_MONTH = "YEAR_MONTH" +) + +type Interval = jet.Interval + +func INTERVAL(value interface{}, unitType UnitType) Interval { + switch unitType { + case MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR: + if !isNumericType(value) { + panic("jet: INTERVAL invalid value type. Numeric type expected") + } + return INTERVALe(jet.FixedLiteral(value), unitType) + default: + strValue, ok := value.(string) + + if !ok { + panic("jet: INTERNAL invalid value type. String type expected") + } + + var regexp *regexp.Regexp + + switch unitType { + case SECOND_MICROSECOND: + regexp = regexSecondMicrosecond + case MINUTE_MICROSECOND: + regexp = regexMinuteMicrosecond + case MINUTE_SECOND: + regexp = regexMinuteSecond + case HOUR_MICROSECOND: + regexp = regexHourMicrosecond + case HOUR_SECOND: + regexp = regexHourSecond + case HOUR_MINUTE: + regexp = regexHourMinute + case DAY_MICROSECOND: + regexp = regexDayMicrosecond + case DAY_SECOND: + regexp = regexDaySecond + case DAY_MINUTE: + regexp = regexDayMinute + case DAY_HOUR: + regexp = regexDayHour + case YEAR_MONTH: + regexp = regexYearMonth + default: + panic("jet: INTERVAL invalid unit type") + } + + if !regexp.MatchString(strValue) { + panic("jet: INTERVAL invalid format") + } + + return INTERVALe(jet.Literal(value), unitType) + } +} + +func INTERVALe(expr Expression, unitType UnitType) Interval { + return jet.NewInterval(jet.ListSerializer{ + Serializers: []jet.Serializer{expr, jet.Raw(string(unitType))}, + Separator: " ", + }) +} + +// INTERVALd returns a representation of duration as MySQL INTERVAL +func INTERVALd(duration time.Duration) Interval { + var sign int64 = 1 + if duration < 0 { + sign = -1 + duration = -duration + } + + days, hours, minutes, sec, microsec := utils.ExtractDateTimeComponents(duration) + + if days != 0 { + switch { + case microsec > 0: + intervalStr := fmt.Sprintf("%d %02d:%02d:%02d.%06d", sign*days, hours, minutes, sec, microsec) + return INTERVAL(intervalStr, DAY_MICROSECOND) + case sec > 0: + intervalStr := fmt.Sprintf("%d %02d:%02d:%02d", sign*days, hours, minutes, sec) + return INTERVAL(intervalStr, DAY_SECOND) + case minutes > 0: + intervalStr := fmt.Sprintf("%d %02d:%02d", sign*days, hours, minutes) + return INTERVAL(intervalStr, DAY_MINUTE) + case hours > 0: + intervalStr := fmt.Sprintf("%d %02d", sign*days, hours) + return INTERVAL(intervalStr, DAY_HOUR) + default: + return INTERVAL(sign*days, DAY) + } + } + + if hours != 0 { + switch { + case microsec > 0: + intervalStr := fmt.Sprintf("%02d:%02d:%02d.%06d", sign*hours, minutes, sec, microsec) + return INTERVAL(intervalStr, HOUR_MICROSECOND) + case sec > 0: + intervalStr := fmt.Sprintf("%02d:%02d:%02d", sign*hours, minutes, sec) + return INTERVAL(intervalStr, HOUR_SECOND) + case minutes > 0: + intervalStr := fmt.Sprintf("%02d:%02d", sign*hours, minutes) + return INTERVAL(intervalStr, HOUR_MINUTE) + default: + return INTERVAL(sign*hours, HOUR) + } + } + + if minutes != 0 { + switch { + case microsec > 0: + intervalStr := fmt.Sprintf("%02d:%02d.%06d", sign*minutes, sec, microsec) + return INTERVAL(intervalStr, MINUTE_MICROSECOND) + case sec > 0: + intervalStr := fmt.Sprintf("%02d:%02d", sign*minutes, sec) + return INTERVAL(intervalStr, MINUTE_SECOND) + default: + return INTERVAL(sign*minutes, MINUTE) + } + } + + if sec != 0 { + if microsec > 0 { + intervalStr := fmt.Sprintf("%02d.%06d", sign*sec, microsec) + return INTERVAL(intervalStr, SECOND_MICROSECOND) + } + return INTERVAL(sign*sec, SECOND) + } + + return INTERVAL(sign*microsec, MICROSECOND) +} + +var ( + regexSecondMicrosecond = regexp.MustCompile(`^-?\d{1,2}\.\d+$`) //'SECONDS.MICROSECONDS' + regexMinuteMicrosecond = regexp.MustCompile(`^-?\d{1,2}:\d{2}\.\d+$`) //'MINUTE:SECONDS.MICROSECONDS' + regexMinuteSecond = regexp.MustCompile(`^-?\d{1,2}:\d{2}$`) //'MINUTE:SECONDS' + regexHourMicrosecond = regexp.MustCompile(`^-?\d{1,2}:\d{2}:\d{2}\.\d+$`) //'HOUR:MINUTE:SECONDS.MICROSECONDS' + regexHourSecond = regexp.MustCompile(`^-?\d{1,2}:\d{2}:\d{2}$`) //'HOUR:MINUTE:SECONDS' + regexHourMinute = regexp.MustCompile(`^-?\d{1,2}:\d{2}$`) //'HOUR:MINUTE' + regexDayMicrosecond = regexp.MustCompile(`^-?\d+ \d{1,2}:\d{2}:\d{2}.\d+$`) //'DAY HOUR:MINUTE:SECONDS' + regexDaySecond = regexp.MustCompile(`^-?\d+ \d{1,2}:\d{2}:\d{2}$`) //'DAY HOUR:MINUTE:SECONDS' + regexDayMinute = regexp.MustCompile(`^-?\d+ \d{1,2}:\d{2}$`) //'DAY HOUR:MINUTE' + regexDayHour = regexp.MustCompile(`^-?\d+ \d{1,2}$`) //'DAY HOUR:MINUTE' + regexYearMonth = regexp.MustCompile(`^-?\d+-\d{1,2}$`) //'YEAR-MONTH' +) + +func isNumericType(value interface{}) bool { + switch value.(type) { + case float64, float32, int16, int32, int64, uint16, uint32, uint64, int, uint: + return true + default: + return false + } +} diff --git a/mysql/interval_test.go b/mysql/interval_test.go new file mode 100644 index 0000000..a7c16f8 --- /dev/null +++ b/mysql/interval_test.go @@ -0,0 +1,95 @@ +package mysql + +import ( + "testing" + "time" +) + +func TestINTERVAL(t *testing.T) { + assertSerialize(t, INTERVAL("3-2", YEAR_MONTH), "INTERVAL ? YEAR_MONTH") + assertDebugSerialize(t, INTERVAL("3-2", YEAR_MONTH), "INTERVAL '3-2' YEAR_MONTH") + assertDebugSerialize(t, INTERVAL("-3-2", YEAR_MONTH), "INTERVAL '-3-2' YEAR_MONTH") + assertDebugSerialize(t, INTERVAL("10 25", DAY_HOUR), "INTERVAL '10 25' DAY_HOUR") + assertDebugSerialize(t, INTERVAL("-10 25", DAY_HOUR), "INTERVAL '-10 25' DAY_HOUR") + assertDebugSerialize(t, INTERVAL("10 25:15", DAY_MINUTE), "INTERVAL '10 25:15' DAY_MINUTE") + assertDebugSerialize(t, INTERVAL("-10 25:15", DAY_MINUTE), "INTERVAL '-10 25:15' DAY_MINUTE") + assertDebugSerialize(t, INTERVAL("10 25:15:08", DAY_SECOND), "INTERVAL '10 25:15:08' DAY_SECOND") + assertDebugSerialize(t, INTERVAL("-10 25:15:08", DAY_SECOND), "INTERVAL '-10 25:15:08' DAY_SECOND") + assertDebugSerialize(t, INTERVAL("10 25:15:08.000100", DAY_MICROSECOND), "INTERVAL '10 25:15:08.000100' DAY_MICROSECOND") + assertDebugSerialize(t, INTERVAL("-10 25:15:08.000100", DAY_MICROSECOND), "INTERVAL '-10 25:15:08.000100' DAY_MICROSECOND") + assertDebugSerialize(t, INTERVAL("15:08", HOUR_MINUTE), "INTERVAL '15:08' HOUR_MINUTE") + assertDebugSerialize(t, INTERVAL("-15:08", HOUR_MINUTE), "INTERVAL '-15:08' HOUR_MINUTE") + assertDebugSerialize(t, INTERVAL("15:08", HOUR_MINUTE), "INTERVAL '15:08' HOUR_MINUTE") + assertDebugSerialize(t, INTERVAL("-15:08", HOUR_MINUTE), "INTERVAL '-15:08' HOUR_MINUTE") + assertDebugSerialize(t, INTERVAL("15:08:03", HOUR_SECOND), "INTERVAL '15:08:03' HOUR_SECOND") + assertDebugSerialize(t, INTERVAL("-15:08:03", HOUR_SECOND), "INTERVAL '-15:08:03' HOUR_SECOND") + assertDebugSerialize(t, INTERVAL("25:15:08.000100", HOUR_MICROSECOND), "INTERVAL '25:15:08.000100' HOUR_MICROSECOND") + assertDebugSerialize(t, INTERVAL("-25:15:08.000100", HOUR_MICROSECOND), "INTERVAL '-25:15:08.000100' HOUR_MICROSECOND") + assertDebugSerialize(t, INTERVAL("08:03", MINUTE_SECOND), "INTERVAL '08:03' MINUTE_SECOND") + assertDebugSerialize(t, INTERVAL("-08:03", MINUTE_SECOND), "INTERVAL '-08:03' MINUTE_SECOND") + assertDebugSerialize(t, INTERVAL("15:08.000100", MINUTE_MICROSECOND), "INTERVAL '15:08.000100' MINUTE_MICROSECOND") + assertDebugSerialize(t, INTERVAL("-15:08.000100", MINUTE_MICROSECOND), "INTERVAL '-15:08.000100' MINUTE_MICROSECOND") + assertDebugSerialize(t, INTERVAL("08.000100", SECOND_MICROSECOND), "INTERVAL '08.000100' SECOND_MICROSECOND") + assertDebugSerialize(t, INTERVAL("-08.000100", SECOND_MICROSECOND), "INTERVAL '-08.000100' SECOND_MICROSECOND") + + assertDebugSerialize(t, INTERVAL(15, SECOND), "INTERVAL 15 SECOND") + assertDebugSerialize(t, INTERVAL(1, MICROSECOND), "INTERVAL 1 MICROSECOND") + assertDebugSerialize(t, INTERVAL(2, MINUTE), "INTERVAL 2 MINUTE") + assertDebugSerialize(t, INTERVAL(3, HOUR), "INTERVAL 3 HOUR") + assertDebugSerialize(t, INTERVAL(4, DAY), "INTERVAL 4 DAY") + assertDebugSerialize(t, INTERVAL(5, MONTH), "INTERVAL 5 MONTH") + assertDebugSerialize(t, INTERVAL(6, YEAR), "INTERVAL 6 YEAR") + assertDebugSerialize(t, INTERVAL(-6, YEAR), "INTERVAL -6 YEAR") +} + +func TestINTERVAL_InvalidUnitType(t *testing.T) { + assertPanicErr(t, func() { INTERVAL("11", HOUR) }, "jet: INTERVAL invalid value type. Numeric type expected") + assertPanicErr(t, func() { INTERVAL("11", YEAR_MONTH) }, "jet: INTERVAL invalid format") + assertPanicErr(t, func() { INTERVAL("11+11", YEAR_MONTH) }, "jet: INTERVAL invalid format") + assertPanicErr(t, func() { INTERVAL(156.11, YEAR_MONTH) }, "jet: INTERNAL invalid value type. String type expected") +} + +func TestINTERVALd(t *testing.T) { + assertDebugSerialize(t, INTERVALd(3*time.Microsecond), "INTERVAL 3 MICROSECOND") + assertDebugSerialize(t, INTERVALd(-1*time.Microsecond), "INTERVAL -1 MICROSECOND") + + assertDebugSerialize(t, INTERVALd(3*time.Second), "INTERVAL 3 SECOND") + assertDebugSerialize(t, INTERVALd(3*time.Second+4*time.Microsecond), "INTERVAL '03.000004' SECOND_MICROSECOND") + assertDebugSerialize(t, INTERVALd(-1*time.Second), "INTERVAL -1 SECOND") + + assertDebugSerialize(t, INTERVALd(3*time.Minute), "INTERVAL 3 MINUTE") + assertDebugSerialize(t, INTERVALd(3*time.Minute+4*time.Second), "INTERVAL '03:04' MINUTE_SECOND") + assertDebugSerialize(t, INTERVALd(3*time.Minute+4*time.Second+5*time.Microsecond), "INTERVAL '03:04.000005' MINUTE_MICROSECOND") + assertDebugSerialize(t, INTERVALd(-11*time.Minute), "INTERVAL -11 MINUTE") + assertDebugSerialize(t, INTERVALd(-11*time.Minute-22*time.Second), "INTERVAL '-11:22' MINUTE_SECOND") + + assertDebugSerialize(t, INTERVALd(3*time.Hour), "INTERVAL 3 HOUR") + assertDebugSerialize(t, INTERVALd(3*time.Hour+4*time.Minute), "INTERVAL '03:04' HOUR_MINUTE") + assertDebugSerialize(t, INTERVALd(3*time.Hour+4*time.Minute+5*time.Second), "INTERVAL '03:04:05' HOUR_SECOND") + assertDebugSerialize(t, INTERVALd(3*time.Hour+4*time.Minute+5*time.Second+6*time.Millisecond), "INTERVAL '03:04:05.006000' HOUR_MICROSECOND") + assertDebugSerialize(t, INTERVALd(-11*time.Hour), "INTERVAL -11 HOUR") + assertDebugSerialize(t, INTERVALd(-11*time.Hour-22*time.Minute), "INTERVAL '-11:22' HOUR_MINUTE") + + assertDebugSerialize(t, INTERVALd(3*24*time.Hour), "INTERVAL 3 DAY") + assertDebugSerialize(t, INTERVALd(3*24*time.Hour+4*time.Hour), "INTERVAL '3 04' DAY_HOUR") + assertDebugSerialize(t, INTERVALd(3*24*time.Hour+4*time.Hour+5*time.Minute), "INTERVAL '3 04:05' DAY_MINUTE") + assertDebugSerialize(t, INTERVALd(3*24*time.Hour+4*time.Hour+5*time.Minute+6*time.Second), "INTERVAL '3 04:05:06' DAY_SECOND") + assertDebugSerialize(t, INTERVALd(3*24*time.Hour+4*time.Hour+5*time.Minute+6*time.Second+7*time.Microsecond), "INTERVAL '3 04:05:06.000007' DAY_MICROSECOND") + + assertDebugSerialize(t, INTERVALd(-11*24*time.Hour), "INTERVAL -11 DAY") + + assertDebugSerialize(t, INTERVALd(1*time.Hour+2*time.Minute+3*time.Second+345*time.Microsecond), "INTERVAL '01:02:03.000345' HOUR_MICROSECOND") + assertDebugSerialize(t, INTERVALd(-1*(1*time.Hour+2*time.Minute+3*time.Second+345*time.Microsecond)), "INTERVAL '-1:02:03.000345' HOUR_MICROSECOND") +} + +func TestINTERVALe(t *testing.T) { + assertSerialize(t, INTERVALe(table1ColFloat, MICROSECOND), "INTERVAL table1.col_float MICROSECOND") + assertSerialize(t, INTERVALe(table1ColFloat, SECOND), "INTERVAL table1.col_float SECOND") + assertSerialize(t, INTERVALe(table1ColFloat, MINUTE), "INTERVAL table1.col_float MINUTE") + assertSerialize(t, INTERVALe(table1ColFloat, HOUR), "INTERVAL table1.col_float HOUR") + assertSerialize(t, INTERVALe(table1ColFloat, DAY), "INTERVAL table1.col_float DAY") + assertSerialize(t, INTERVALe(table1ColFloat, WEEK), "INTERVAL table1.col_float WEEK") + assertSerialize(t, INTERVALe(table1ColFloat, MONTH), "INTERVAL table1.col_float MONTH") + assertSerialize(t, INTERVALe(table1ColFloat, QUARTER), "INTERVAL table1.col_float QUARTER") + assertSerialize(t, INTERVALe(table1ColFloat, YEAR), "INTERVAL table1.col_float YEAR") +} diff --git a/mysql/literal_test.go b/mysql/literal_test.go index f677d11..09d331e 100644 --- a/mysql/literal_test.go +++ b/mysql/literal_test.go @@ -6,37 +6,37 @@ import ( ) func TestBool(t *testing.T) { - assertClauseSerialize(t, Bool(false), `?`, false) + assertSerialize(t, Bool(false), `?`, false) } func TestInt(t *testing.T) { - assertClauseSerialize(t, Int(11), `?`, int64(11)) + assertSerialize(t, Int(11), `?`, int64(11)) } func TestFloat(t *testing.T) { - assertClauseSerialize(t, Float(12.34), `?`, float64(12.34)) + assertSerialize(t, Float(12.34), `?`, float64(12.34)) } func TestString(t *testing.T) { - assertClauseSerialize(t, String("Some text"), `?`, "Some text") + assertSerialize(t, String("Some text"), `?`, "Some text") } func TestDate(t *testing.T) { - assertClauseSerialize(t, Date(2014, time.January, 2), `CAST(? AS DATE)`, "2014-01-02") - assertClauseSerialize(t, DateT(time.Now()), `CAST(? AS DATE)`) + assertSerialize(t, Date(2014, time.January, 2), `CAST(? AS DATE)`, "2014-01-02") + assertSerialize(t, DateT(time.Now()), `CAST(? AS DATE)`) } func TestTime(t *testing.T) { - assertClauseSerialize(t, Time(10, 15, 30), `CAST(? AS TIME)`, "10:15:30") - assertClauseSerialize(t, TimeT(time.Now()), `CAST(? AS TIME)`) + assertSerialize(t, Time(10, 15, 30), `CAST(? AS TIME)`, "10:15:30") + assertSerialize(t, TimeT(time.Now()), `CAST(? AS TIME)`) } func TestDateTime(t *testing.T) { - assertClauseSerialize(t, DateTime(2010, time.March, 30, 10, 15, 30), `CAST(? AS DATETIME)`, "2010-03-30 10:15:30") - assertClauseSerialize(t, DateTimeT(time.Now()), `CAST(? AS DATETIME)`) + assertSerialize(t, DateTime(2010, time.March, 30, 10, 15, 30), `CAST(? AS DATETIME)`, "2010-03-30 10:15:30") + assertSerialize(t, DateTimeT(time.Now()), `CAST(? AS DATETIME)`) } func TestTimestamp(t *testing.T) { - assertClauseSerialize(t, Timestamp(2010, time.March, 30, 10, 15, 30), `TIMESTAMP(?)`, "2010-03-30 10:15:30") - assertClauseSerialize(t, TimestampT(time.Now()), `TIMESTAMP(?)`) + assertSerialize(t, Timestamp(2010, time.March, 30, 10, 15, 30), `TIMESTAMP(?)`, "2010-03-30 10:15:30") + assertSerialize(t, TimestampT(time.Now()), `TIMESTAMP(?)`) } diff --git a/mysql/table_test.go b/mysql/table_test.go index da45f36..3894378 100644 --- a/mysql/table_test.go +++ b/mysql/table_test.go @@ -12,17 +12,17 @@ func TestJoinNilInputs(t *testing.T) { } func TestINNER_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)). INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 INNER JOIN db.table2 ON (table1.col_int = table2.col_int) INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. INNER_JOIN(table2, table1ColInt.EQ(Int(1))). INNER_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -31,17 +31,17 @@ INNER JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) } func TestLEFT_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)). LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 LEFT JOIN db.table2 ON (table1.col_int = table2.col_int) LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. LEFT_JOIN(table2, table1ColInt.EQ(Int(1))). LEFT_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -50,17 +50,17 @@ LEFT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) } func TestRIGHT_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)). RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int) RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))). RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -69,17 +69,17 @@ RIGHT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) } func TestFULL_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)). FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 FULL JOIN db.table2 ON (table1.col_int = table2.col_int) FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. FULL_JOIN(table2, table1ColInt.EQ(Int(1))). FULL_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -88,11 +88,11 @@ FULL JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) } func TestCROSS_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. CROSS_JOIN(table2), `db.table1 CROSS JOIN db.table2`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. CROSS_JOIN(table2). CROSS_JOIN(table3), `db.table1 diff --git a/mysql/utils_test.go b/mysql/utils_test.go index 5804a07..1cc42f1 100644 --- a/mysql/utils_test.go +++ b/mysql/utils_test.go @@ -58,10 +58,14 @@ var table3 = NewTable( table3ColInt, table3StrCol) -func assertClauseSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { +func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) } +func assertDebugSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { + testutils.AssertDebugClauseSerialize(t, Dialect, clause, query, args...) +} + func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) { testutils.AssertClauseSerializeErr(t, Dialect, clause, errString) } @@ -70,5 +74,6 @@ func assertProjectionSerialize(t *testing.T, projection jet.Projection, query st testutils.AssertProjectionSerialize(t, Dialect, projection, query, args...) } +var assertPanicErr = testutils.AssertPanicErr var assertStatementSql = testutils.AssertStatementSql var assertStatementSqlErr = testutils.AssertStatementSqlErr diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index f0f7406..361bb8d 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -1,19 +1,20 @@ package mysql import ( - "fmt" + "testing" + "time" + + "github.com/google/uuid" + "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/view" "github.com/go-jet/jet/tests/testdata/results/common" - "github.com/google/uuid" - "time" . "github.com/go-jet/jet/mysql" "gotest.tools/assert" - "testing" ) func TestAllTypes(t *testing.T) { @@ -506,15 +507,11 @@ func TestStringOperators(t *testing.T) { REGEXP_LIKE(AllTypes.Text, String("aba"), "i"), }...) } - //_, args, _ := query.Sql() - - //fmt.Println(query.Sql()) - //fmt.Println(args[15]) query := SELECT(projectionList[0], projectionList[1:]...). FROM(AllTypes) - fmt.Println(query.DebugSql()) + //fmt.Println(query.DebugSql()) dest := []struct{}{} err := query.Query(db, &dest) @@ -555,32 +552,49 @@ func TestTimeExpressions(t *testing.T) { AllTypes.Time.GT_EQ(AllTypes.Time), AllTypes.Time.GT_EQ(Time(14, 26, 36)), + AllTypes.Time.ADD(INTERVAL(10, MINUTE)), + AllTypes.Time.ADD(INTERVALe(AllTypes.Integer, MINUTE)), + AllTypes.Time.ADD(INTERVALd(3*time.Hour)), + + AllTypes.Time.SUB(INTERVAL(20, MINUTE)), + AllTypes.Time.SUB(INTERVALe(AllTypes.SmallInt, MINUTE)), + AllTypes.Time.SUB(INTERVALd(3*time.Minute)), + + AllTypes.Time.ADD(INTERVAL(20, MINUTE)).SUB(INTERVAL(11, HOUR)), + CURRENT_TIME(), CURRENT_TIME(3), ) - //fmt.Println(query.Sql()) + //fmt.Println(query.DebugSql()) - testutils.AssertStatementSql(t, query, ` -SELECT CAST(? AS TIME), + testutils.AssertDebugStatementSql(t, query, ` +SELECT CAST('20:34:58' AS TIME), all_types.time = all_types.time, - all_types.time = CAST(? AS TIME), - all_types.time = CAST(? AS TIME), - all_types.time = CAST(? AS TIME), + all_types.time = CAST('23:06:06' AS TIME), + all_types.time = CAST('22:06:06.011' AS TIME), + all_types.time = CAST('21:06:06.011111' AS TIME), all_types.time_ptr != all_types.time, - all_types.time_ptr != CAST(? AS TIME), + all_types.time_ptr != CAST('20:16:06' AS TIME), NOT(all_types.time <=> all_types.time), - NOT(all_types.time <=> CAST(? AS TIME)), + NOT(all_types.time <=> CAST('19:26:06' AS TIME)), all_types.time <=> all_types.time, - all_types.time <=> CAST(? AS TIME), + all_types.time <=> CAST('18:36:06' AS TIME), all_types.time < all_types.time, - all_types.time < CAST(? AS TIME), + all_types.time < CAST('17:46:06' AS TIME), all_types.time <= all_types.time, - all_types.time <= CAST(? AS TIME), + all_types.time <= CAST('16:56:56' AS TIME), all_types.time > all_types.time, - all_types.time > CAST(? AS TIME), + all_types.time > CAST('15:16:46' AS TIME), all_types.time >= all_types.time, - all_types.time >= CAST(? AS TIME), + all_types.time >= CAST('14:26:36' AS TIME), + all_types.time + INTERVAL 10 MINUTE, + all_types.time + INTERVAL all_types.integer MINUTE, + all_types.time + INTERVAL 3 HOUR, + all_types.time - INTERVAL 20 MINUTE, + all_types.time - INTERVAL all_types.small_int MINUTE, + all_types.time - INTERVAL 3 MINUTE, + (all_types.time + INTERVAL 20 MINUTE) - INTERVAL 11 HOUR, CURRENT_TIME, CURRENT_TIME(3) FROM test_sample.all_types; @@ -621,10 +635,18 @@ func TestDateExpressions(t *testing.T) { AllTypes.Date.GT_EQ(AllTypes.Date), AllTypes.Date.GT_EQ(Date(2019, 2, 3)), + AllTypes.Date.ADD(INTERVAL("10:20.000100", MINUTE_MICROSECOND)), + AllTypes.Date.ADD(INTERVALe(AllTypes.BigInt, MINUTE)), + AllTypes.Date.ADD(INTERVALd(15*time.Hour)), + + AllTypes.Date.SUB(INTERVAL(20, MINUTE)), + AllTypes.Date.SUB(INTERVALe(AllTypes.SmallInt, MINUTE)), + AllTypes.Date.SUB(INTERVALd(3*time.Minute)), + CURRENT_DATE(), ) - //fmt.Println(query.Sql()) + //fmt.Println(query.DebugSql()) testutils.AssertStatementSql(t, query, ` SELECT CAST(? AS DATE), @@ -644,6 +666,12 @@ SELECT CAST(? AS DATE), all_types.date > CAST(? AS DATE), all_types.date >= all_types.date, all_types.date >= CAST(? AS DATE), + all_types.date + INTERVAL ? MINUTE_MICROSECOND, + all_types.date + INTERVAL all_types.big_int MINUTE, + all_types.date + INTERVAL 15 HOUR, + all_types.date - INTERVAL 20 MINUTE, + all_types.date - INTERVAL all_types.small_int MINUTE, + all_types.date - INTERVAL 3 MINUTE, CURRENT_DATE FROM test_sample.all_types; `) @@ -683,11 +711,19 @@ func TestDateTimeExpressions(t *testing.T) { AllTypes.DateTime.GT_EQ(AllTypes.DateTime), AllTypes.DateTime.GT_EQ(dateTime), + AllTypes.DateTime.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), + AllTypes.DateTime.ADD(INTERVALe(AllTypes.BigInt, HOUR)), + AllTypes.DateTime.ADD(INTERVALd(2*time.Hour)), + + AllTypes.DateTime.SUB(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), + AllTypes.DateTime.SUB(INTERVALe(AllTypes.IntegerPtr, HOUR)), + AllTypes.DateTime.SUB(INTERVALd(3*time.Hour)), + NOW(), NOW(1), ) - //fmt.Println(query.DebugSql()) + //Println(query.DebugSql()) testutils.AssertDebugStatementSql(t, query, ` SELECT all_types.date_time = all_types.date_time, @@ -706,6 +742,12 @@ SELECT all_types.date_time = all_types.date_time, all_types.date_time > CAST('2019-06-06 10:02:46' AS DATETIME), all_types.date_time >= all_types.date_time, all_types.date_time >= CAST('2019-06-06 10:02:46' AS DATETIME), + all_types.date_time + INTERVAL '05:10:20.000100' HOUR_MICROSECOND, + all_types.date_time + INTERVAL all_types.big_int HOUR, + all_types.date_time + INTERVAL 2 HOUR, + all_types.date_time - INTERVAL '05:10:20.000100' HOUR_MICROSECOND, + all_types.date_time - INTERVAL all_types.integer_ptr HOUR, + all_types.date_time - INTERVAL 3 HOUR, NOW(), NOW(1) FROM test_sample.all_types; @@ -746,6 +788,14 @@ func TestTimestampExpressions(t *testing.T) { AllTypes.Timestamp.GT_EQ(AllTypes.Timestamp), AllTypes.Timestamp.GT_EQ(timestamp), + AllTypes.Timestamp.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), + AllTypes.Timestamp.ADD(INTERVALe(AllTypes.BigInt, HOUR)), + AllTypes.Timestamp.ADD(INTERVALd(2*time.Hour)), + + AllTypes.Timestamp.SUB(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), + AllTypes.Timestamp.SUB(INTERVALe(AllTypes.IntegerPtr, HOUR)), + AllTypes.Timestamp.SUB(INTERVALd(3*time.Hour)), + CURRENT_TIMESTAMP(), CURRENT_TIMESTAMP(2), ) @@ -769,6 +819,12 @@ SELECT all_types.timestamp = all_types.timestamp, all_types.timestamp > TIMESTAMP('2019-06-06 10:02:46'), all_types.timestamp >= all_types.timestamp, all_types.timestamp >= TIMESTAMP('2019-06-06 10:02:46'), + all_types.timestamp + INTERVAL '05:10:20.000100' HOUR_MICROSECOND, + all_types.timestamp + INTERVAL all_types.big_int HOUR, + all_types.timestamp + INTERVAL 2 HOUR, + all_types.timestamp - INTERVAL '05:10:20.000100' HOUR_MICROSECOND, + all_types.timestamp - INTERVAL all_types.integer_ptr HOUR, + all_types.timestamp - INTERVAL 3 HOUR, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP(2) FROM test_sample.all_types; @@ -853,6 +909,60 @@ LIMIT ?; } +func TestINTERVAL(t *testing.T) { + query := SELECT( + Date(2000, 2, 10).ADD(INTERVAL(1, MICROSECOND)). + EQ(Timestamp(2000, 2, 10, 0, 0, 0, 1*time.Microsecond)), + Date(2000, 2, 10).SUB(INTERVAL(2, SECOND)), + Date(2000, 2, 10).ADD(INTERVAL(3, MINUTE)), + Date(2000, 2, 10).SUB(INTERVAL(4, HOUR)), + Date(2000, 2, 10).ADD(INTERVAL(5, DAY)), + Date(2000, 2, 10).SUB(INTERVAL(6, MONTH)), + Date(2000, 2, 10).ADD(INTERVAL(7, YEAR)), + Date(2000, 2, 10).ADD(INTERVAL(-7, YEAR)), + Date(2000, 2, 10).ADD(INTERVAL("20.0000100", SECOND_MICROSECOND)), + Date(2000, 2, 10).SUB(INTERVAL("02:20.0000100", MINUTE_MICROSECOND)), + Date(2000, 2, 10).SUB(INTERVAL("11:02:20.0000100", HOUR_MICROSECOND)), + Date(2000, 2, 10).SUB(INTERVAL("100 11:02:20.0000100", DAY_MICROSECOND)), + Date(2000, 2, 10).SUB(INTERVAL("11:02", MINUTE_SECOND)), + Date(2000, 2, 10).SUB(INTERVAL("11:02:20", HOUR_SECOND)), + Date(2000, 2, 10).SUB(INTERVAL("11:02", HOUR_MINUTE)), + Date(2000, 2, 10).SUB(INTERVAL("11 02:03:04", DAY_SECOND)), + Date(2000, 2, 10).SUB(INTERVAL("11 02:03", DAY_MINUTE)), + Date(2000, 2, 10).SUB(INTERVAL("11 2", DAY_HOUR)), + Date(2000, 2, 10).SUB(INTERVAL("2000-2", YEAR_MONTH)), + + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, MICROSECOND)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, SECOND)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, MINUTE)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, HOUR)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, DAY)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, WEEK)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, MONTH)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, QUARTER)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, YEAR)), + + Date(2000, 2, 10).SUB(INTERVALd(3*time.Microsecond)), + Date(2000, 2, 10).SUB(INTERVALd(-3*time.Microsecond)), + Date(2000, 2, 10).SUB(INTERVALd(3*time.Second)), + Date(2000, 2, 10).SUB(INTERVALd(3*time.Second+4*time.Microsecond)), + Date(2000, 2, 10).SUB(INTERVALd(3*time.Minute+4*time.Second+5*time.Microsecond)), + Date(2000, 2, 10).SUB(INTERVALd(3*time.Hour+4*time.Minute+5*time.Second+6*time.Microsecond)), + Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour+3*time.Hour+4*time.Minute+5*time.Second+6*time.Microsecond)), + Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour+3*time.Hour+4*time.Minute+5*time.Second)), + Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour+3*time.Hour+4*time.Minute)), + Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour+3*time.Hour)), + Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour)), + Date(2000, 2, 10).SUB(INTERVALd(3*time.Hour)), + Date(2000, 2, 10).SUB(INTERVALd(1*time.Hour+2*time.Minute+3*time.Second+345*time.Microsecond)), + ).FROM(AllTypes) + + //fmt.Println(query.DebugSql()) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) +} + var allTypesJson = ` [ { diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index f404e29..952eb63 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -1,7 +1,6 @@ package mysql import ( - "fmt" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/mysql" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/enum" @@ -607,7 +606,7 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date; ).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate). WHERE(Payment.PaymentID.LT(Int(10))) - fmt.Println(query.Sql()) + //fmt.Println(query.Sql()) testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) @@ -643,7 +642,7 @@ ORDER BY payment.customer_id; WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)). ORDER_BY(Payment.CustomerID) - fmt.Println(query.Sql()) + //fmt.Println(query.Sql()) testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) From 4a3579e7f954f1639908df46f348a55c3247b9d6 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 1 Dec 2019 18:26:01 +0100 Subject: [PATCH 05/19] Postgres interval with date/time expression arithmetic. --- postgres/cast_test.go | 34 ++++----- postgres/dialect.go | 6 +- postgres/dialect_test.go | 26 +++---- postgres/interval.go | 131 ++++++++++++++++++++++++++++++++ postgres/interval_test.go | 59 ++++++++++++++ postgres/literal_test.go | 28 +++---- postgres/table_test.go | 28 +++---- postgres/utils_test.go | 6 +- tests/postgres/alltypes_test.go | 59 +++++++++++++- tests/postgres/select_test.go | 12 +-- 10 files changed, 313 insertions(+), 76 deletions(-) create mode 100644 postgres/interval.go create mode 100644 postgres/interval_test.go diff --git a/postgres/cast_test.go b/postgres/cast_test.go index 4537784..a1e4be5 100644 --- a/postgres/cast_test.go +++ b/postgres/cast_test.go @@ -5,60 +5,60 @@ import ( ) func TestExpressionCAST_AS(t *testing.T) { - assertClauseSerialize(t, CAST(String("test")).AS("text"), `$1::text`, "test") + assertSerialize(t, CAST(String("test")).AS("text"), `$1::text`, "test") } func TestExpressionCAST_AS_BOOL(t *testing.T) { - assertClauseSerialize(t, CAST(Int(1)).AS_BOOL(), "$1::boolean", int64(1)) - assertClauseSerialize(t, CAST(table2Col3).AS_BOOL(), "table2.col3::boolean") - assertClauseSerialize(t, CAST(table2Col3.ADD(table2Col3)).AS_BOOL(), "(table2.col3 + table2.col3)::boolean") + assertSerialize(t, CAST(Int(1)).AS_BOOL(), "$1::boolean", int64(1)) + assertSerialize(t, CAST(table2Col3).AS_BOOL(), "table2.col3::boolean") + assertSerialize(t, CAST(table2Col3.ADD(table2Col3)).AS_BOOL(), "(table2.col3 + table2.col3)::boolean") } func TestExpressionCAST_AS_SMALLINT(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_SMALLINT(), "table2.col3::smallint") + assertSerialize(t, CAST(table2Col3).AS_SMALLINT(), "table2.col3::smallint") } func TestExpressionCAST_AS_INTEGER(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_INTEGER(), "table2.col3::integer") + assertSerialize(t, CAST(table2Col3).AS_INTEGER(), "table2.col3::integer") } func TestExpressionCAST_AS_BIGINT(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_BIGINT(), "table2.col3::bigint") + assertSerialize(t, CAST(table2Col3).AS_BIGINT(), "table2.col3::bigint") } func TestExpressionCAST_AS_NUMERIC(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_NUMERIC(11, 11), "table2.col3::numeric(11, 11)") - assertClauseSerialize(t, CAST(table2Col3).AS_NUMERIC(11), "table2.col3::numeric(11)") + assertSerialize(t, CAST(table2Col3).AS_NUMERIC(11, 11), "table2.col3::numeric(11, 11)") + assertSerialize(t, CAST(table2Col3).AS_NUMERIC(11), "table2.col3::numeric(11)") } func TestExpressionCAST_AS_REAL(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_REAL(), "table2.col3::real") + assertSerialize(t, CAST(table2Col3).AS_REAL(), "table2.col3::real") } func TestExpressionCAST_AS_DOUBLE(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_DOUBLE(), "table2.col3::double precision") + assertSerialize(t, CAST(table2Col3).AS_DOUBLE(), "table2.col3::double precision") } func TestExpressionCAST_AS_TEXT(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_TEXT(), "table2.col3::text") + assertSerialize(t, CAST(table2Col3).AS_TEXT(), "table2.col3::text") } func TestExpressionCAST_AS_DATE(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_DATE(), "table2.col3::date") + assertSerialize(t, CAST(table2Col3).AS_DATE(), "table2.col3::date") } func TestExpressionCAST_AS_TIME(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_TIME(), "table2.col3::time without time zone") + assertSerialize(t, CAST(table2Col3).AS_TIME(), "table2.col3::time without time zone") } func TestExpressionCAST_AS_TIMEZ(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_TIMEZ(), "table2.col3::time with time zone") + assertSerialize(t, CAST(table2Col3).AS_TIMEZ(), "table2.col3::time with time zone") } func TestExpressionCAST_AS_TIMESTAMP(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_TIMESTAMP(), "table2.col3::timestamp without time zone") + assertSerialize(t, CAST(table2Col3).AS_TIMESTAMP(), "table2.col3::timestamp without time zone") } func TestExpressionCAST_AS_TIMESTAMPZ(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_TIMESTAMPZ(), "table2.col3::timestamp with time zone") + assertSerialize(t, CAST(table2Col3).AS_TIMESTAMPZ(), "table2.col3::timestamp with time zone") } diff --git a/postgres/dialect.go b/postgres/dialect.go index 114e5a6..c1e8c0b 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -29,7 +29,7 @@ func newDialect() jet.Dialect { return jet.NewDialect(dialectParams) } -func postgresCAST(expressions ...jet.Expression) jet.SerializeFunc { +func postgresCAST(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") @@ -54,7 +54,7 @@ func postgresCAST(expressions ...jet.Expression) jet.SerializeFunc { } } -func postgresREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { +func postgresREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") @@ -80,7 +80,7 @@ func postgresREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc } } -func postgresNOTREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { +func postgresNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index b6061b7..a37c0c9 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -3,21 +3,21 @@ package postgres import "testing" func TestString_REGEXP_LIKE_operator(t *testing.T) { - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 ~* table2.col_str)") - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 ~* $1)", "JOHN") - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 ~* $1)", "JOHN") - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 ~ $1)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 ~* table2.col_str)") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 ~* $1)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 ~* $1)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 ~ $1)", "JOHN") } func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) { - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 !~* table2.col_str)") - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 !~* $1)", "JOHN") - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 !~* $1)", "JOHN") - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 !~ $1)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 !~* table2.col_str)") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 !~* $1)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 !~* $1)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 !~ $1)", "JOHN") } func TestExists(t *testing.T) { - assertClauseSerialize(t, EXISTS( + assertSerialize(t, EXISTS( table2. SELECT(Int(1)). WHERE(table1Col1.EQ(table2Col3)), @@ -31,13 +31,13 @@ func TestExists(t *testing.T) { func TestIN(t *testing.T) { - assertClauseSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), + assertSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), `($1 IN (( SELECT table1.col1 AS "table1.col1" FROM db.table1 )))`, float64(1.11)) - assertClauseSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), + assertSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), `(ROW($1, table1.col1) IN (( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" @@ -47,13 +47,13 @@ func TestIN(t *testing.T) { func TestNOT_IN(t *testing.T) { - assertClauseSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), + assertSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), `($1 NOT IN (( SELECT table1.col1 AS "table1.col1" FROM db.table1 )))`, float64(1.11)) - assertClauseSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), + assertSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), `(ROW($1, table1.col1) NOT IN (( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" diff --git a/postgres/interval.go b/postgres/interval.go new file mode 100644 index 0000000..dba2c6c --- /dev/null +++ b/postgres/interval.go @@ -0,0 +1,131 @@ +package postgres + +import ( + "fmt" + "github.com/go-jet/jet/internal/jet" + "github.com/go-jet/jet/internal/utils" + "strconv" + "strings" + "time" +) + +type quantityAndUnit float64 + +const ( + pow2_32 = -4.294967296e+09 + + YEAR quantityAndUnit = pow2_32 + iota + MONTH + WEEK + DAY + HOUR + MINUTE + SECOND + MILLISECOND + MICROSECOND + DECADE + CENTURY + MILLENNIUM +) + +type intervalExpressionImpl struct { + jet.Interval + jet.ExpressionInterfaceImpl +} + +type IntervalExpression interface { + jet.IsInterval + jet.Expression +} + +func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { + if len(quantityAndUnit)%2 != 0 { + panic("jet: invalid number of quantity and unit fields") + } + + fields := []string{} + + for i := 0; i < len(quantityAndUnit); i += 2 { + quantity := strconv.FormatFloat(float64(quantityAndUnit[i]), 'f', -1, 64) + unitString := unitToString(quantityAndUnit[i+1]) + fields = append(fields, quantity+" "+unitString) + } + + intervalStr := fmt.Sprintf("'%s'", strings.Join(fields, " ")) + + newInterval := &intervalExpressionImpl{ + Interval: jet.NewInterval(jet.Raw(intervalStr)), + } + + newInterval.ExpressionInterfaceImpl.Parent = newInterval + + return newInterval +} + +func INTERVALd(duration time.Duration) IntervalExpression { + days, hours, minutes, seconds, microseconds := utils.ExtractDateTimeComponents(duration) + + quantityAndUnits := []quantityAndUnit{} + + if days > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(days)) + quantityAndUnits = append(quantityAndUnits, DAY) + } + + if hours > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(hours)) + quantityAndUnits = append(quantityAndUnits, HOUR) + } + + if minutes > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(minutes)) + quantityAndUnits = append(quantityAndUnits, MINUTE) + } + + if seconds > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(seconds)) + quantityAndUnits = append(quantityAndUnits, SECOND) + } + + if microseconds > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(microseconds)) + quantityAndUnits = append(quantityAndUnits, MICROSECOND) + } + + if len(quantityAndUnits) == 0 { + return INTERVAL(0, MICROSECOND) + } + + return INTERVAL(quantityAndUnits...) +} + +func unitToString(unit quantityAndUnit) string { + switch unit { + case YEAR: + return "YEAR" + case MONTH: + return "MONTH" + case WEEK: + return "WEEK" + case DAY: + return "DAY" + case HOUR: + return "HOUR" + case MINUTE: + return "MINUTE" + case SECOND: + return "SECOND" + case MILLISECOND: + return "MILLISECOND" + case MICROSECOND: + return "MICROSECOND" + case DECADE: + return "DECADE" + case CENTURY: + return "CENTURY" + case MILLENNIUM: + return "MILLENNIUM" + default: + panic("jet: invalid INTERVAL unit type") + } +} diff --git a/postgres/interval_test.go b/postgres/interval_test.go new file mode 100644 index 0000000..1d1e3de --- /dev/null +++ b/postgres/interval_test.go @@ -0,0 +1,59 @@ +package postgres + +import ( + "testing" + "time" +) + +func TestINTERVAL(t *testing.T) { + assertSerialize(t, INTERVAL(1, YEAR), "INTERVAL '1 YEAR'") + assertSerialize(t, INTERVAL(1, MONTH), "INTERVAL '1 MONTH'") + assertSerialize(t, INTERVAL(1, WEEK), "INTERVAL '1 WEEK'") + assertSerialize(t, INTERVAL(1, DAY), "INTERVAL '1 DAY'") + assertSerialize(t, INTERVAL(1, HOUR), "INTERVAL '1 HOUR'") + assertSerialize(t, INTERVAL(1, MINUTE), "INTERVAL '1 MINUTE'") + assertSerialize(t, INTERVAL(1, SECOND), "INTERVAL '1 SECOND'") + assertSerialize(t, INTERVAL(1, MILLISECOND), "INTERVAL '1 MILLISECOND'") + assertSerialize(t, INTERVAL(1, MICROSECOND), "INTERVAL '1 MICROSECOND'") + assertSerialize(t, INTERVAL(1, DECADE), "INTERVAL '1 DECADE'") + assertSerialize(t, INTERVAL(1, CENTURY), "INTERVAL '1 CENTURY'") + assertSerialize(t, INTERVAL(1, MILLENNIUM), "INTERVAL '1 MILLENNIUM'") + + assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH), "INTERVAL '1 YEAR 10 MONTH'") + assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH, 20, DAY), "INTERVAL '1 YEAR 10 MONTH 20 DAY'") + assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH, 20, DAY, 3, HOUR), "INTERVAL '1 YEAR 10 MONTH 20 DAY 3 HOUR'") + + assertSerialize(t, INTERVAL(1, YEAR).IS_NOT_NULL(), "INTERVAL '1 YEAR' IS NOT NULL") + assertProjectionSerialize(t, INTERVAL(1, YEAR).AS("one year"), `INTERVAL '1 YEAR' AS "one year"`) +} + +func TestINTERVALd(t *testing.T) { + assertSerialize(t, INTERVALd(0), "INTERVAL '0 MICROSECOND'") + assertSerialize(t, INTERVALd(1*time.Microsecond), "INTERVAL '1 MICROSECOND'") + assertSerialize(t, INTERVALd(1*time.Millisecond), "INTERVAL '1000 MICROSECOND'") + assertSerialize(t, INTERVALd(1*time.Second), "INTERVAL '1 SECOND'") + assertSerialize(t, INTERVALd(1*time.Minute), "INTERVAL '1 MINUTE'") + assertSerialize(t, INTERVALd(1*time.Hour), "INTERVAL '1 HOUR'") + assertSerialize(t, INTERVALd(24*time.Hour), "INTERVAL '1 DAY'") + + assertSerialize(t, INTERVALd(24*time.Hour+2*time.Hour+3*time.Minute+4*time.Second+5*time.Microsecond), + "INTERVAL '1 DAY 2 HOUR 3 MINUTE 4 SECOND 5 MICROSECOND'") +} + +func TestINTERVAL_InvalidParams(t *testing.T) { + assertPanicErr(t, func() { INTERVAL(1) }, "jet: invalid number of quantity and unit fields") + assertPanicErr(t, func() { INTERVAL(1, 2) }, "jet: invalid INTERVAL unit type") +} + +func TestIntervalArithmetic(t *testing.T) { + assertSerialize(t, table2ColDate.ADD(INTERVAL(1, HOUR)), "(table2.col_date + INTERVAL '1 HOUR')") + assertSerialize(t, table2ColDate.SUB(INTERVAL(1, HOUR)), "(table2.col_date - INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTime.ADD(INTERVAL(1, HOUR)), "(table2.col_time + INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTime.SUB(INTERVAL(1, HOUR)), "(table2.col_time - INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTimez.ADD(INTERVAL(1, HOUR)), "(table2.col_timez + INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTimez.SUB(INTERVAL(1, HOUR)), "(table2.col_timez - INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTimestamp.ADD(INTERVAL(1, HOUR)), "(table2.col_timestamp + INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTimestamp.SUB(INTERVAL(1, HOUR)), "(table2.col_timestamp - INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTimestampz.ADD(INTERVAL(1, HOUR)), "(table2.col_timestampz + INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTimestampz.SUB(INTERVAL(1, HOUR)), "(table2.col_timestampz - INTERVAL '1 HOUR')") +} diff --git a/postgres/literal_test.go b/postgres/literal_test.go index f30ef01..5206aaa 100644 --- a/postgres/literal_test.go +++ b/postgres/literal_test.go @@ -6,45 +6,45 @@ import ( ) func TestBool(t *testing.T) { - assertClauseSerialize(t, Bool(false), `$1`, false) + assertSerialize(t, Bool(false), `$1`, false) } func TestInt(t *testing.T) { - assertClauseSerialize(t, Int(11), `$1`, int64(11)) + assertSerialize(t, Int(11), `$1`, int64(11)) } func TestFloat(t *testing.T) { - assertClauseSerialize(t, Float(12.34), `$1`, float64(12.34)) + assertSerialize(t, Float(12.34), `$1`, float64(12.34)) } func TestString(t *testing.T) { - assertClauseSerialize(t, String("Some text"), `$1`, "Some text") + assertSerialize(t, String("Some text"), `$1`, "Some text") } func TestDate(t *testing.T) { - assertClauseSerialize(t, Date(2014, time.January, 2), `$1::date`, "2014-01-02") - assertClauseSerialize(t, DateT(time.Now()), `$1::date`) + assertSerialize(t, Date(2014, time.January, 2), `$1::date`, "2014-01-02") + assertSerialize(t, DateT(time.Now()), `$1::date`) } func TestTime(t *testing.T) { - assertClauseSerialize(t, Time(10, 15, 30), `$1::time without time zone`, "10:15:30") - assertClauseSerialize(t, TimeT(time.Now()), `$1::time without time zone`) + assertSerialize(t, Time(10, 15, 30), `$1::time without time zone`, "10:15:30") + assertSerialize(t, TimeT(time.Now()), `$1::time without time zone`) } func TestTimez(t *testing.T) { - assertClauseSerialize(t, Timez(10, 15, 30, 0, "UTC"), + assertSerialize(t, Timez(10, 15, 30, 0, "UTC"), `$1::time with time zone`, "10:15:30 UTC") - assertClauseSerialize(t, TimezT(time.Now()), `$1::time with time zone`) + assertSerialize(t, TimezT(time.Now()), `$1::time with time zone`) } func TestTimestamp(t *testing.T) { - assertClauseSerialize(t, Timestamp(2010, time.March, 30, 10, 15, 30), + assertSerialize(t, Timestamp(2010, time.March, 30, 10, 15, 30), `$1::timestamp without time zone`, "2010-03-30 10:15:30") - assertClauseSerialize(t, TimestampT(time.Now()), `$1::timestamp without time zone`) + assertSerialize(t, TimestampT(time.Now()), `$1::timestamp without time zone`) } func TestTimestampz(t *testing.T) { - assertClauseSerialize(t, Timestampz(2010, time.March, 30, 10, 15, 30, 0, "UTC"), + assertSerialize(t, Timestampz(2010, time.March, 30, 10, 15, 30, 0, "UTC"), `$1::timestamp with time zone`, "2010-03-30 10:15:30 UTC") - assertClauseSerialize(t, TimestampzT(time.Now()), `$1::timestamp with time zone`) + assertSerialize(t, TimestampzT(time.Now()), `$1::timestamp with time zone`) } diff --git a/postgres/table_test.go b/postgres/table_test.go index 6573b02..43aa096 100644 --- a/postgres/table_test.go +++ b/postgres/table_test.go @@ -12,17 +12,17 @@ func TestJoinNilInputs(t *testing.T) { } func TestINNER_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)). INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 INNER JOIN db.table2 ON (table1.col_int = table2.col_int) INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. INNER_JOIN(table2, table1ColInt.EQ(Int(1))). INNER_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -31,17 +31,17 @@ INNER JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2)) } func TestLEFT_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)). LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 LEFT JOIN db.table2 ON (table1.col_int = table2.col_int) LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. LEFT_JOIN(table2, table1ColInt.EQ(Int(1))). LEFT_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -50,17 +50,17 @@ LEFT JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2)) } func TestRIGHT_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)). RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int) RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))). RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -69,17 +69,17 @@ RIGHT JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2)) } func TestFULL_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)). FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 FULL JOIN db.table2 ON (table1.col_int = table2.col_int) FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. FULL_JOIN(table2, table1ColInt.EQ(Int(1))). FULL_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -88,11 +88,11 @@ FULL JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2)) } func TestCROSS_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. CROSS_JOIN(table2), `db.table1 CROSS JOIN db.table2`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. CROSS_JOIN(table2). CROSS_JOIN(table3), `db.table1 diff --git a/postgres/utils_test.go b/postgres/utils_test.go index c65d5b6..4a80954 100644 --- a/postgres/utils_test.go +++ b/postgres/utils_test.go @@ -1,9 +1,10 @@ package postgres import ( + "testing" + "github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/testutils" - "testing" ) var table1Col1 = IntegerColumn("col1") @@ -70,7 +71,7 @@ var table3 = NewTable( table3ColInt, table3StrCol) -func assertClauseSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { +func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) } @@ -84,3 +85,4 @@ func assertProjectionSerialize(t *testing.T, projection jet.Projection, query st var assertStatementSql = testutils.AssertStatementSql var assertStatementSqlErr = testutils.AssertStatementSqlErr +var assertPanicErr = testutils.AssertPanicErr diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index f2e8fbe..eed0695 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -1,6 +1,12 @@ package postgres import ( + "testing" + "time" + + "github.com/google/uuid" + "gotest.tools/assert" + "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres" @@ -8,10 +14,6 @@ import ( . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/view" "github.com/go-jet/jet/tests/testdata/results/common" - "github.com/google/uuid" - "gotest.tools/assert" - "testing" - "time" ) func TestAllTypesSelect(t *testing.T) { @@ -606,6 +608,17 @@ func TestTimeExpression(t *testing.T) { AllTypes.Time.GT_EQ(AllTypes.Time), AllTypes.Time.GT_EQ(Time(23, 6, 6, 1)), + AllTypes.Date.ADD(INTERVAL(1, HOUR)), + AllTypes.Date.SUB(INTERVAL(1, MINUTE)), + AllTypes.Time.ADD(INTERVAL(1, HOUR)), + AllTypes.Time.SUB(INTERVAL(1, MINUTE)), + AllTypes.Timez.ADD(INTERVAL(1, HOUR)), + AllTypes.Timez.SUB(INTERVAL(1, MINUTE)), + AllTypes.Timestamp.ADD(INTERVAL(1, HOUR)), + AllTypes.Timestamp.SUB(INTERVAL(1, MINUTE)), + AllTypes.Timestampz.ADD(INTERVAL(1, HOUR)), + AllTypes.Timestampz.SUB(INTERVAL(1, MINUTE)), + CURRENT_DATE(), CURRENT_TIME(), CURRENT_TIME(2), @@ -626,6 +639,44 @@ func TestTimeExpression(t *testing.T) { assert.NilError(t, err) } +func TestInterval(t *testing.T) { + stmt := SELECT( + INTERVAL(1, YEAR), + INTERVAL(1, MONTH), + INTERVAL(1, WEEK), + INTERVAL(1, DAY), + INTERVAL(1, HOUR), + INTERVAL(1, MINUTE), + INTERVAL(1, SECOND), + INTERVAL(1, MILLISECOND), + INTERVAL(1, MICROSECOND), + INTERVAL(1, DECADE), + INTERVAL(1, CENTURY), + INTERVAL(1, MILLENNIUM), + + INTERVAL(1, YEAR, 10, MONTH), + INTERVAL(1, YEAR, 10, MONTH, 20, DAY), + INTERVAL(1, YEAR, 10, MONTH, 20, DAY, 3, HOUR), + + INTERVAL(1, YEAR).IS_NOT_NULL(), + INTERVAL(1, YEAR).AS("one year"), + + INTERVALd(0), + INTERVALd(1*time.Microsecond), + INTERVALd(1*time.Millisecond), + INTERVALd(1*time.Second), + INTERVALd(1*time.Minute), + INTERVALd(1*time.Hour), + INTERVALd(24*time.Hour), + INTERVALd(24*time.Hour+2*time.Hour+3*time.Minute+4*time.Second+5*time.Microsecond), + ) + + //fmt.Println(stmt.DebugSql()) + + err := stmt.Query(db, &struct{}{}) + assert.NilError(t, err) +} + func TestSubQueryColumnReference(t *testing.T) { type expected struct { diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 3cd2fba..2ac6c36 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -158,7 +158,7 @@ LIMIT 12; ). LIMIT(12) - fmt.Println(query.DebugSql()) + //fmt.Println(query.DebugSql()) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) @@ -1686,7 +1686,7 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date; ).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate). WHERE(Payment.PaymentID.LT(Int(10))) - fmt.Println(query.Sql()) + //fmt.Println(query.Sql()) testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) @@ -1722,7 +1722,7 @@ ORDER BY payment.customer_id; WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)). ORDER_BY(Payment.CustomerID) - fmt.Println(query.Sql()) + //fmt.Println(query.Sql()) testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) @@ -1748,12 +1748,6 @@ func TestSimpleView(t *testing.T) { FilmInfo string } - //sql, args := query.Sql() - // - //row := db.QueryRow(sql, args...) - // - //row.Scan() - var dest []ActorInfo err := query.Query(db, &dest) From a2fbc4f53af1e548230438c5f8419d03cc50f10f Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 7 Dec 2019 13:52:51 +0100 Subject: [PATCH 06/19] Add postgres interval cast. --- postgres/cast.go | 12 +++++++++-- postgres/cast_test.go | 7 ++++++ postgres/interval.go | 19 +++++++++++++++++ tests/postgres/alltypes_test.go | 38 +++++++++++++++++---------------- 4 files changed, 56 insertions(+), 20 deletions(-) diff --git a/postgres/cast.go b/postgres/cast.go index e9ec209..0f4b255 100644 --- a/postgres/cast.go +++ b/postgres/cast.go @@ -2,8 +2,9 @@ package postgres import ( "fmt" - "github.com/go-jet/jet/internal/jet" "strconv" + + "github.com/go-jet/jet/internal/jet" ) type cast interface { @@ -32,7 +33,7 @@ type cast interface { AS_TIME() TimeExpression // Cast expression AS text type AS_TEXT() StringExpression - + // Cast expression AS bytea type AS_BYTEA() StringExpression // Cast expression AS time with time timezone type AS_TIMEZ() TimezExpression @@ -40,6 +41,8 @@ type cast interface { AS_TIMESTAMP() TimestampExpression // Cast expression AS timestamp with timezone type AS_TIMESTAMPZ() TimestampzExpression + // Cast expression AS interval type + AS_INTERVAL() IntervalExpression } type castImpl struct { @@ -151,3 +154,8 @@ func (b *castImpl) AS_TIMESTAMP() TimestampExpression { func (b *castImpl) AS_TIMESTAMPZ() TimestampzExpression { return TimestampzExp(b.AS("timestamp with time zone")) } + +// Cast expression AS interval type +func (b *castImpl) AS_INTERVAL() IntervalExpression { + return IntervalExp(b.AS("interval")) +} diff --git a/postgres/cast_test.go b/postgres/cast_test.go index a1e4be5..e02336a 100644 --- a/postgres/cast_test.go +++ b/postgres/cast_test.go @@ -62,3 +62,10 @@ func TestExpressionCAST_AS_TIMESTAMP(t *testing.T) { func TestExpressionCAST_AS_TIMESTAMPZ(t *testing.T) { assertSerialize(t, CAST(table2Col3).AS_TIMESTAMPZ(), "table2.col3::timestamp with time zone") } + +func TestExpressionCAST_AS_INTERVAL(t *testing.T) { + assertSerialize(t, CAST(table2ColTimez).AS_INTERVAL(), "table2.col_timez::interval") + assertSerialize(t, CAST(Time(20, 11, 10)).AS_INTERVAL(), "$1::time without time zone::interval", "20:11:10") + assertSerialize(t, table2ColDate.SUB(CAST(Time(20, 11, 10)).AS_INTERVAL()), + "(table2.col_date - $1::time without time zone::interval)", "20:11:10") +} diff --git a/postgres/interval.go b/postgres/interval.go index dba2c6c..f6344fd 100644 --- a/postgres/interval.go +++ b/postgres/interval.go @@ -129,3 +129,22 @@ func unitToString(unit quantityAndUnit) string { panic("jet: invalid INTERVAL unit type") } } + +//---------------------------------------------------// + +type intervalWrapper struct { + jet.IsInterval + Expression +} + +func newIntervalExpressionWrap(expression Expression) IntervalExpression { + intervalWrap := intervalWrapper{Expression: expression} + return &intervalWrap +} + +// IntervalExp is interval expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as interval expression. +// Does not add sql cast to generated sql builder output. +func IntervalExp(expression Expression) IntervalExpression { + return newIntervalExpressionWrap(expression) +} diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index eed0695..7b9d1d6 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -8,7 +8,6 @@ import ( "gotest.tools/assert" "github.com/go-jet/jet/internal/testutils" - "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" @@ -136,22 +135,23 @@ LIMIT $5; func TestExpressionCast(t *testing.T) { query := AllTypes.SELECT( - postgres.CAST(Int(150)).AS_CHAR(12).AS("char12"), - postgres.CAST(String("TRUE")).AS_BOOL(), - postgres.CAST(String("111")).AS_SMALLINT(), - postgres.CAST(String("111")).AS_INTEGER(), - postgres.CAST(String("111")).AS_BIGINT(), - postgres.CAST(String("11.23")).AS_NUMERIC(30, 10), - postgres.CAST(String("11.23")).AS_NUMERIC(30), - postgres.CAST(String("11.23")).AS_NUMERIC(), - postgres.CAST(String("11.23")).AS_REAL(), - postgres.CAST(String("11.23")).AS_DOUBLE(), - postgres.CAST(Int(234)).AS_TEXT(), - postgres.CAST(String("1/8/1999")).AS_DATE(), - postgres.CAST(String("04:05:06.789")).AS_TIME(), - postgres.CAST(String("04:05:06 PST")).AS_TIMEZ(), - postgres.CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(), - postgres.CAST(String("January 8 04:05:06 1999 PST")).AS_TIMESTAMPZ(), + CAST(Int(150)).AS_CHAR(12).AS("char12"), + CAST(String("TRUE")).AS_BOOL(), + CAST(String("111")).AS_SMALLINT(), + CAST(String("111")).AS_INTEGER(), + CAST(String("111")).AS_BIGINT(), + CAST(String("11.23")).AS_NUMERIC(30, 10), + CAST(String("11.23")).AS_NUMERIC(30), + CAST(String("11.23")).AS_NUMERIC(), + CAST(String("11.23")).AS_REAL(), + CAST(String("11.23")).AS_DOUBLE(), + CAST(Int(234)).AS_TEXT(), + CAST(String("1/8/1999")).AS_DATE(), + CAST(String("04:05:06.789")).AS_TIME(), + CAST(String("04:05:06 PST")).AS_TIMEZ(), + CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(), + CAST(String("January 8 04:05:06 1999 PST")).AS_TIMESTAMPZ(), + CAST(String("04:05:06")).AS_INTERVAL(), TO_CHAR(AllTypes.Timestamp, String("HH12:MI:SS")), TO_CHAR(AllTypes.Integer, String("999")), @@ -361,7 +361,7 @@ func TestFloatOperators(t *testing.T) { TRUNC(ABSf(AllTypes.Decimal), Int(2)).AS("abs"), TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int(2)).AS("power"), TRUNC(SQRT(AllTypes.Decimal), Int(2)).AS("sqrt"), - TRUNC(postgres.CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int(2)).AS("cbrt"), + TRUNC(CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int(2)).AS("cbrt"), CEIL(AllTypes.Real).AS("ceil"), FLOOR(AllTypes.Real).AS("floor"), @@ -619,6 +619,8 @@ func TestTimeExpression(t *testing.T) { AllTypes.Timestampz.ADD(INTERVAL(1, HOUR)), AllTypes.Timestampz.SUB(INTERVAL(1, MINUTE)), + AllTypes.Date.SUB(CAST(String("04:05:06")).AS_INTERVAL()), + CURRENT_DATE(), CURRENT_TIME(), CURRENT_TIME(2), From 57aa62f483fc34018048830697072ca708ef141c Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 7 Dec 2019 18:52:46 +0100 Subject: [PATCH 07/19] Simplified creation of operator expression. --- internal/jet/bool_expression.go | 53 ++++++++---------------------- internal/jet/expression.go | 29 +++++++++------- internal/jet/integer_expression.go | 39 ++++++++-------------- internal/jet/operators.go | 22 ++++++------- internal/jet/string_expression.go | 20 ++++------- postgres/dialect_test.go | 14 ++++++++ 6 files changed, 75 insertions(+), 102 deletions(-) diff --git a/internal/jet/bool_expression.go b/internal/jet/bool_expression.go index fa9342f..1a05ab6 100644 --- a/internal/jet/bool_expression.go +++ b/internal/jet/bool_expression.go @@ -53,80 +53,53 @@ func (b *boolInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs BoolExpression) BoolExpress } func (b *boolInterfaceImpl) AND(expression BoolExpression) BoolExpression { - return newBinaryBoolOperator(b.parent, expression, "AND") + return newBinaryBoolOperatorExpression(b.parent, expression, "AND") } func (b *boolInterfaceImpl) OR(expression BoolExpression) BoolExpression { - return newBinaryBoolOperator(b.parent, expression, "OR") + return newBinaryBoolOperatorExpression(b.parent, expression, "OR") } func (b *boolInterfaceImpl) IS_TRUE() BoolExpression { - return newPostifxBoolExpression(b.parent, "IS TRUE") + return newPostfixBoolOperatorExpression(b.parent, "IS TRUE") } func (b *boolInterfaceImpl) IS_NOT_TRUE() BoolExpression { - return newPostifxBoolExpression(b.parent, "IS NOT TRUE") + return newPostfixBoolOperatorExpression(b.parent, "IS NOT TRUE") } func (b *boolInterfaceImpl) IS_FALSE() BoolExpression { - return newPostifxBoolExpression(b.parent, "IS FALSE") + return newPostfixBoolOperatorExpression(b.parent, "IS FALSE") } func (b *boolInterfaceImpl) IS_NOT_FALSE() BoolExpression { - return newPostifxBoolExpression(b.parent, "IS NOT FALSE") + return newPostfixBoolOperatorExpression(b.parent, "IS NOT FALSE") } func (b *boolInterfaceImpl) IS_UNKNOWN() BoolExpression { - return newPostifxBoolExpression(b.parent, "IS UNKNOWN") + return newPostfixBoolOperatorExpression(b.parent, "IS UNKNOWN") } func (b *boolInterfaceImpl) IS_NOT_UNKNOWN() BoolExpression { - return newPostifxBoolExpression(b.parent, "IS NOT UNKNOWN") + return newPostfixBoolOperatorExpression(b.parent, "IS NOT UNKNOWN") } //---------------------------------------------------// -func newBinaryBoolOperator(lhs, rhs Expression, operator string, additionalParams ...Expression) BoolExpression { +func newBinaryBoolOperatorExpression(lhs, rhs Expression, operator string, additionalParams ...Expression) BoolExpression { return BoolExp(newBinaryOperatorExpression(lhs, rhs, operator, additionalParams...)) } //---------------------------------------------------// -type prefixBoolExpression struct { - ExpressionInterfaceImpl - boolInterfaceImpl - - prefixOpExpression -} - -func newPrefixBoolOperator(expression Expression, operator string) BoolExpression { - exp := prefixBoolExpression{} - exp.prefixOpExpression = newPrefixExpression(expression, operator) - - exp.ExpressionInterfaceImpl.Parent = &exp - exp.boolInterfaceImpl.parent = &exp - - return &exp +func newPrefixBoolOperatorExpression(expression Expression, operator string) BoolExpression { + return BoolExp(newPrefixOperatorExpression(expression, operator)) } //---------------------------------------------------// -type postfixBoolOpExpression struct { - ExpressionInterfaceImpl - boolInterfaceImpl - - postfixOpExpression -} - -func newPostifxBoolExpression(expression Expression, operator string) BoolExpression { - exp := postfixBoolOpExpression{} - exp.postfixOpExpression = newPostfixOpExpression(expression, operator) - - exp.ExpressionInterfaceImpl.Parent = &exp - exp.boolInterfaceImpl.parent = &exp - - return &exp +func newPostfixBoolOperatorExpression(expression Expression, operator string) BoolExpression { + return BoolExp(newPostfixOperatorExpression(expression, operator)) } //---------------------------------------------------// - type boolExpressionWrapper struct { boolInterfaceImpl Expression diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 6a99638..379bd70 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -36,19 +36,19 @@ func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection { } func (e *ExpressionInterfaceImpl) IS_NULL() BoolExpression { - return newPostifxBoolExpression(e.Parent, "IS NULL") + return newPostfixBoolOperatorExpression(e.Parent, "IS NULL") } func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression { - return newPostifxBoolExpression(e.Parent, "IS NOT NULL") + return newPostfixBoolOperatorExpression(e.Parent, "IS NOT NULL") } func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression { - return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "IN") + return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "IN") } func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression { - return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "NOT IN") + return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "NOT IN") } func (e *ExpressionInterfaceImpl) AS(alias string) Projection { @@ -129,21 +129,24 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu } // A prefix operator Expression -type prefixOpExpression struct { +type prefixExpression struct { + ExpressionInterfaceImpl + expression Expression operator string } -func newPrefixExpression(expression Expression, operator string) prefixOpExpression { - prefixExpression := prefixOpExpression{ +func newPrefixOperatorExpression(expression Expression, operator string) *prefixExpression { + prefixExpression := &prefixExpression{ expression: expression, operator: operator, } + prefixExpression.ExpressionInterfaceImpl.Parent = prefixExpression return prefixExpression } -func (p *prefixOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (p *prefixExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString("(") out.WriteString(p.operator) @@ -156,18 +159,22 @@ func (p *prefixOpExpression) serialize(statement StatementType, out *SQLBuilder, out.WriteString(")") } -// A postifx operator Expression +// A postfix operator Expression type postfixOpExpression struct { + ExpressionInterfaceImpl + expression Expression operator string } -func newPostfixOpExpression(expression Expression, operator string) postfixOpExpression { - postfixOpExpression := postfixOpExpression{ +func newPostfixOperatorExpression(expression Expression, operator string) *postfixOpExpression { + postfixOpExpression := &postfixOpExpression{ expression: expression, operator: operator, } + postfixOpExpression.ExpressionInterfaceImpl.Parent = postfixOpExpression + return postfixOpExpression } diff --git a/internal/jet/integer_expression.go b/internal/jet/integer_expression.go index efda817..c004437 100644 --- a/internal/jet/integer_expression.go +++ b/internal/jet/integer_expression.go @@ -86,23 +86,23 @@ func (i *integerInterfaceImpl) LT_EQ(expression IntegerExpression) BoolExpressio } func (i *integerInterfaceImpl) ADD(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "+") + return newBinaryIntegerOperatorExpression(i.parent, expression, "+") } func (i *integerInterfaceImpl) SUB(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "-") + return newBinaryIntegerOperatorExpression(i.parent, expression, "-") } func (i *integerInterfaceImpl) MUL(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "*") + return newBinaryIntegerOperatorExpression(i.parent, expression, "*") } func (i *integerInterfaceImpl) DIV(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "/") + return newBinaryIntegerOperatorExpression(i.parent, expression, "/") } func (i *integerInterfaceImpl) MOD(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "%") + return newBinaryIntegerOperatorExpression(i.parent, expression, "%") } func (i *integerInterfaceImpl) POW(expression IntegerExpression) IntegerExpression { @@ -110,46 +110,33 @@ func (i *integerInterfaceImpl) POW(expression IntegerExpression) IntegerExpressi } func (i *integerInterfaceImpl) BIT_AND(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "&") + return newBinaryIntegerOperatorExpression(i.parent, expression, "&") } func (i *integerInterfaceImpl) BIT_OR(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "|") + return newBinaryIntegerOperatorExpression(i.parent, expression, "|") } func (i *integerInterfaceImpl) BIT_XOR(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "#") + return newBinaryIntegerOperatorExpression(i.parent, expression, "#") } func (i *integerInterfaceImpl) BIT_SHIFT_LEFT(intExpression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, intExpression, "<<") + return newBinaryIntegerOperatorExpression(i.parent, intExpression, "<<") } func (i *integerInterfaceImpl) BIT_SHIFT_RIGHT(intExpression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, intExpression, ">>") + return newBinaryIntegerOperatorExpression(i.parent, intExpression, ">>") } //---------------------------------------------------// -func newBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression { +func newBinaryIntegerOperatorExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression { return IntExp(newBinaryOperatorExpression(lhs, rhs, operator)) } //---------------------------------------------------// -type prefixIntegerOpExpression struct { - ExpressionInterfaceImpl - integerInterfaceImpl - - prefixOpExpression -} - -func newPrefixIntegerOperator(expression IntegerExpression, operator string) IntegerExpression { - integerExpression := prefixIntegerOpExpression{} - integerExpression.prefixOpExpression = newPrefixExpression(expression, operator) - - integerExpression.ExpressionInterfaceImpl.Parent = &integerExpression - integerExpression.integerInterfaceImpl.parent = &integerExpression - - return &integerExpression +func newPrefixIntegerOperatorExpression(expression IntegerExpression, operator string) IntegerExpression { + return IntExp(newPrefixOperatorExpression(expression, operator)) } //---------------------------------------------------// diff --git a/internal/jet/operators.go b/internal/jet/operators.go index 9229717..fad1e26 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -11,7 +11,7 @@ const ( // NOT returns negation of bool expression result func NOT(exp BoolExpression) BoolExpression { - return newPrefixBoolOperator(exp, "NOT") + return newPrefixBoolOperatorExpression(exp, "NOT") } // BIT_NOT inverts every bit in integer expression result @@ -19,52 +19,52 @@ func BIT_NOT(expr IntegerExpression) IntegerExpression { if literalExp, ok := expr.(LiteralExpression); ok { literalExp.SetConstant(true) } - return newPrefixIntegerOperator(expr, "~") + return newPrefixIntegerOperatorExpression(expr, "~") } //----------- Comparison operators ---------------// // EXISTS checks for existence of the rows in subQuery func EXISTS(subQuery Expression) BoolExpression { - return newPrefixBoolOperator(subQuery, "EXISTS") + return newPrefixBoolOperatorExpression(subQuery, "EXISTS") } // Returns a representation of "a=b" func eq(lhs, rhs Expression) BoolExpression { - return newBinaryBoolOperator(lhs, rhs, "=") + return newBinaryBoolOperatorExpression(lhs, rhs, "=") } // Returns a representation of "a!=b" func notEq(lhs, rhs Expression) BoolExpression { - return newBinaryBoolOperator(lhs, rhs, "!=") + return newBinaryBoolOperatorExpression(lhs, rhs, "!=") } func isDistinctFrom(lhs, rhs Expression) BoolExpression { - return newBinaryBoolOperator(lhs, rhs, "IS DISTINCT FROM") + return newBinaryBoolOperatorExpression(lhs, rhs, "IS DISTINCT FROM") } func isNotDistinctFrom(lhs, rhs Expression) BoolExpression { - return newBinaryBoolOperator(lhs, rhs, "IS NOT DISTINCT FROM") + return newBinaryBoolOperatorExpression(lhs, rhs, "IS NOT DISTINCT FROM") } // Returns a representation of "ab" func gt(lhs, rhs Expression) BoolExpression { - return newBinaryBoolOperator(lhs, rhs, ">") + return newBinaryBoolOperatorExpression(lhs, rhs, ">") } // Returns a representation of "a>=b" func gtEq(lhs, rhs Expression) BoolExpression { - return newBinaryBoolOperator(lhs, rhs, ">=") + return newBinaryBoolOperatorExpression(lhs, rhs, ">=") } // --------------- CASE operator -------------------// diff --git a/internal/jet/string_expression.go b/internal/jet/string_expression.go index a3a76c7..29ceca6 100644 --- a/internal/jet/string_expression.go +++ b/internal/jet/string_expression.go @@ -60,35 +60,27 @@ func (s *stringInterfaceImpl) LT_EQ(rhs StringExpression) BoolExpression { } func (s *stringInterfaceImpl) CONCAT(rhs Expression) StringExpression { - return newBinaryStringExpression(s.parent, rhs, StringConcatOperator) + return newBinaryStringOperatorExpression(s.parent, rhs, StringConcatOperator) } func (s *stringInterfaceImpl) LIKE(pattern StringExpression) BoolExpression { - return newBinaryBoolOperator(s.parent, pattern, "LIKE") + return newBinaryBoolOperatorExpression(s.parent, pattern, "LIKE") } func (s *stringInterfaceImpl) NOT_LIKE(pattern StringExpression) BoolExpression { - return newBinaryBoolOperator(s.parent, pattern, "NOT LIKE") + return newBinaryBoolOperatorExpression(s.parent, pattern, "NOT LIKE") } func (s *stringInterfaceImpl) REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression { - return newBinaryBoolOperator(s.parent, pattern, StringRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) + return newBinaryBoolOperatorExpression(s.parent, pattern, StringRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) } func (s *stringInterfaceImpl) NOT_REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression { - return newBinaryBoolOperator(s.parent, pattern, StringNotRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) + return newBinaryBoolOperatorExpression(s.parent, pattern, StringNotRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) } //---------------------------------------------------// - -type binaryStringExpression struct { - ExpressionInterfaceImpl - stringInterfaceImpl - - binaryOperatorExpression -} - -func newBinaryStringExpression(lhs, rhs Expression, operator string) StringExpression { +func newBinaryStringOperatorExpression(lhs, rhs Expression, operator string) StringExpression { return StringExp(newBinaryOperatorExpression(lhs, rhs, operator)) } diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index a37c0c9..f53587e 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -27,6 +27,20 @@ func TestExists(t *testing.T) { FROM db.table2 WHERE table1.col1 = table2.col3 ))`, int64(1)) + + assertSerialize(t, EXISTS( + SELECT(Int(1)), + ).EQ(Bool(true)), + `((EXISTS ( + SELECT $1 +)) = $2)`, int64(1), true) + + assertProjectionSerialize(t, EXISTS( + SELECT(Int(1)), + ).AS("exists"), + `(EXISTS ( + SELECT $1 +)) AS "exists"`, int64(1)) } func TestIN(t *testing.T) { From 2487c48428b66d2bce79e721dd0595eb6c536172 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 8 Dec 2019 11:07:49 +0100 Subject: [PATCH 08/19] Fix linter errors. --- internal/jet/dialect.go | 2 +- internal/jet/expression.go | 22 ++++++++---- internal/jet/interval.go | 3 ++ internal/jet/literal_expression.go | 6 ++-- internal/jet/serializer.go | 1 + internal/jet/sql_builder.go | 35 +++++++++++++++---- ...serializer_test.go => sql_builder_test.go} | 10 +++++- internal/jet/utils.go | 1 + internal/testutils/test_utils.go | 2 +- internal/utils/utils.go | 1 + mysql/interval.go | 20 +++++++---- mysql/interval_test.go | 20 ++++++----- postgres/interval.go | 11 +++--- postgres/interval_test.go | 3 ++ 14 files changed, 100 insertions(+), 37 deletions(-) rename internal/jet/{serializer_test.go => sql_builder_test.go} (70%) diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index 71720d1..acf03d9 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -14,7 +14,7 @@ type Dialect interface { // SerializerFunc func type SerializerFunc func(statement StatementType, out *SQLBuilder, options ...SerializeOption) -//// SerializeOverride func +// SerializeOverride func type SerializeOverride func(expressions ...Serializer) SerializerFunc // QueryPlaceholderFunc func diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 379bd70..26b9186 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -8,25 +8,26 @@ type Expression interface { GroupByClause OrderByClause - // Test expression whether it is a NULL value. + // IS_NULL tests expression whether it is a NULL value. IS_NULL() BoolExpression - // Test expression whether it is a non-NULL value. + // IS_NOT_NULL tests expression whether it is a non-NULL value. IS_NOT_NULL() BoolExpression - // Check if this expressions matches any in expressions list + // IN checks if this expressions matches any in expressions list IN(expressions ...Expression) BoolExpression - // Check if this expressions is different of all expressions in expressions list + // NOT_IN checks if this expressions is different of all expressions in expressions list NOT_IN(expressions ...Expression) BoolExpression - // The temporary alias name to assign to the expression + // AS the temporary alias name to assign to the expression AS(alias string) Projection - // Expression will be used to sort query result in ascending order + // ASC expression will be used to sort query result in ascending order ASC() OrderByClause - // Expression will be used to sort query result in ascending order + // DESC expression will be used to sort query result in ascending order DESC() OrderByClause } +// ExpressionInterfaceImpl implements Expression interface methods type ExpressionInterfaceImpl struct { Parent Expression } @@ -35,30 +36,37 @@ func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection { return e.Parent } +// IS_NULL tests expression whether it is a NULL value. func (e *ExpressionInterfaceImpl) IS_NULL() BoolExpression { return newPostfixBoolOperatorExpression(e.Parent, "IS NULL") } +// IS_NOT_NULL tests expression whether it is a non-NULL value. func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression { return newPostfixBoolOperatorExpression(e.Parent, "IS NOT NULL") } +// IN checks if this expressions matches any in expressions list func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression { return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "IN") } +// NOT_IN checks if this expressions is different of all expressions in expressions list func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression { return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "NOT IN") } +// AS the temporary alias name to assign to the expression func (e *ExpressionInterfaceImpl) AS(alias string) Projection { return newAlias(e.Parent, alias) } +// ASC expression will be used to sort query result in ascending order func (e *ExpressionInterfaceImpl) ASC() OrderByClause { return newOrderByClause(e.Parent, true) } +// DESC expression will be used to sort query result in ascending order func (e *ExpressionInterfaceImpl) DESC() OrderByClause { return newOrderByClause(e.Parent, false) } diff --git a/internal/jet/interval.go b/internal/jet/interval.go index 705f2be..e66ca56 100644 --- a/internal/jet/interval.go +++ b/internal/jet/interval.go @@ -1,14 +1,17 @@ package jet +// Interval is internal common representation of sql interval type Interval interface { Serializer IsInterval } +// IsInterval interface type IsInterval interface { isInterval() } +// NewInterval creates new interval from serializer func NewInterval(s Serializer) Interval { newInterval := &intervalImpl{ interval: s, diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 6983b8f..499b7b4 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -334,20 +334,20 @@ func WRAP(expression ...Expression) Expression { //---------------------------------------------------// -type RawExpression struct { +type rawExpression struct { ExpressionInterfaceImpl Raw string } -func (n *RawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (n *rawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString(n.Raw) } // Raw can be used for any unsupported functions, operators or expressions. // For example: Raw("current_database()") func Raw(raw string) Expression { - rawExp := &RawExpression{Raw: raw} + rawExp := &rawExpression{Raw: raw} rawExp.ExpressionInterfaceImpl.Parent = rawExp return rawExp diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index d5ff4b9..dc661d7 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -42,6 +42,7 @@ func contains(options []SerializeOption, option SerializeOption) bool { return false } +// ListSerializer serializes list of serializers with separator type ListSerializer struct { Serializers []Serializer Separator string diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 3b34ab2..95bd0b6 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -142,12 +142,8 @@ func argToString(value interface{}) string { return "TRUE" } return "FALSE" - case int: - return strconv.FormatInt(int64(bindVal), 10) - case int32: - return strconv.FormatInt(int64(bindVal), 10) - case int64: - return strconv.FormatInt(bindVal, 10) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return integerTypesToString(bindVal) case float32: return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) @@ -167,6 +163,33 @@ func argToString(value interface{}) string { } } +func integerTypesToString(value interface{}) string { + switch bindVal := value.(type) { + case bool: + case int: + return strconv.FormatInt(int64(bindVal), 10) + case uint: + return strconv.FormatUint(uint64(bindVal), 10) + case int8: + return strconv.FormatInt(int64(bindVal), 10) + case uint8: + return strconv.FormatUint(uint64(bindVal), 10) + case int16: + return strconv.FormatInt(int64(bindVal), 10) + case uint16: + return strconv.FormatUint(uint64(bindVal), 10) + case int32: + return strconv.FormatInt(int64(bindVal), 10) + case uint32: + return strconv.FormatUint(uint64(bindVal), 10) + case int64: + return strconv.FormatInt(bindVal, 10) + case uint64: + return strconv.FormatUint(bindVal, 10) + } + panic("jet: Unsupported integer type: " + reflect.TypeOf(value).String()) +} + func shouldQuoteIdentifier(identifier string) bool { for _, c := range identifier { if unicode.IsNumber(c) || c == '_' { diff --git a/internal/jet/serializer_test.go b/internal/jet/sql_builder_test.go similarity index 70% rename from internal/jet/serializer_test.go rename to internal/jet/sql_builder_test.go index 6d2fd4a..dc4a476 100644 --- a/internal/jet/serializer_test.go +++ b/internal/jet/sql_builder_test.go @@ -12,8 +12,16 @@ func TestArgToString(t *testing.T) { assert.Equal(t, argToString(false), "FALSE") assert.Equal(t, argToString(int(-32)), "-32") - assert.Equal(t, argToString(int32(-32)), "-32") + assert.Equal(t, argToString(uint(32)), "32") + assert.Equal(t, argToString(int8(-43)), "-43") + assert.Equal(t, argToString(uint8(43)), "43") + assert.Equal(t, argToString(int16(-54)), "-54") + assert.Equal(t, argToString(uint16(54)), "54") + assert.Equal(t, argToString(int32(-65)), "-65") + assert.Equal(t, argToString(uint32(65)), "65") assert.Equal(t, argToString(int64(-64)), "-64") + assert.Equal(t, argToString(uint64(64)), "64") + assert.Equal(t, argToString(float32(2.0)), "2") assert.Equal(t, argToString(float64(1.11)), "1.11") assert.Equal(t, argToString("john"), "'john'") diff --git a/internal/jet/utils.go b/internal/jet/utils.go index 67c2c6c..58394f4 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -63,6 +63,7 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) { } } +// ExpressionListToSerializerList converts list of expressions to list of serializers func ExpressionListToSerializerList(expressions []Expression) []Serializer { var ret []Serializer diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 1b19a20..c3d3ff0 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -129,7 +129,7 @@ func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Seriali } } -// AssertClauseSerialize checks if clause serialize produces expected query and args +// AssertDebugClauseSerialize checks if clause serialize produces expected debug query and args func AssertDebugClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { out := jet.SQLBuilder{Dialect: dialect, Debug: true} jet.Serialize(clause, jet.SelectStatementType, &out) diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 97ac48a..42a5c36 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -184,6 +184,7 @@ func StringSliceContains(strings []string, contains string) bool { return false } +// ExtractDateTimeComponents extracts number of days, hours, minutes, seconds, microseconds from duration func ExtractDateTimeComponents(duration time.Duration) (days, hours, minutes, seconds, microseconds int64) { days = int64(duration / (24 * time.Hour)) reminder := duration % (24 * time.Hour) diff --git a/mysql/interval.go b/mysql/interval.go index 57c2cfc..478e9c4 100644 --- a/mysql/interval.go +++ b/mysql/interval.go @@ -9,10 +9,11 @@ import ( "github.com/go-jet/jet/internal/utils" ) -type UnitType string +type unitType string +// List of interval unit types for MySQL const ( - MICROSECOND UnitType = "MICROSECOND" + MICROSECOND unitType = "MICROSECOND" SECOND = "SECOND" MINUTE = "MINUTE" HOUR = "HOUR" @@ -34,9 +35,15 @@ const ( YEAR_MONTH = "YEAR_MONTH" ) +// Interval is representation of MySQL interval type Interval = jet.Interval -func INTERVAL(value interface{}, unitType UnitType) Interval { +// INTERVAL creates new Interval type. +// In a case of MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR unit type +// value parameter should be number. For example: INTERVAL(1, DAY) +// In a case of other unit types, value should be string with appropriate format. +// For example: INTERVAL("10:08:50", HOUR_SECOND) +func INTERVAL(value interface{}, unitType unitType) Interval { switch unitType { case MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR: if !isNumericType(value) { @@ -87,14 +94,15 @@ func INTERVAL(value interface{}, unitType UnitType) Interval { } } -func INTERVALe(expr Expression, unitType UnitType) Interval { +// INTERVALe creates new Interval type from expresion and unit type. +func INTERVALe(expr Expression, unitType unitType) Interval { return jet.NewInterval(jet.ListSerializer{ Serializers: []jet.Serializer{expr, jet.Raw(string(unitType))}, Separator: " ", }) } -// INTERVALd returns a representation of duration as MySQL INTERVAL +// INTERVALd returns a interval representation from duration func INTERVALd(duration time.Duration) Interval { var sign int64 = 1 if duration < 0 { @@ -179,7 +187,7 @@ var ( func isNumericType(value interface{}) bool { switch value.(type) { - case float64, float32, int16, int32, int64, uint16, uint32, uint64, int, uint: + case float64, float32, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: return true default: return false diff --git a/mysql/interval_test.go b/mysql/interval_test.go index a7c16f8..c88b808 100644 --- a/mysql/interval_test.go +++ b/mysql/interval_test.go @@ -32,14 +32,18 @@ func TestINTERVAL(t *testing.T) { assertDebugSerialize(t, INTERVAL("08.000100", SECOND_MICROSECOND), "INTERVAL '08.000100' SECOND_MICROSECOND") assertDebugSerialize(t, INTERVAL("-08.000100", SECOND_MICROSECOND), "INTERVAL '-08.000100' SECOND_MICROSECOND") - assertDebugSerialize(t, INTERVAL(15, SECOND), "INTERVAL 15 SECOND") - assertDebugSerialize(t, INTERVAL(1, MICROSECOND), "INTERVAL 1 MICROSECOND") - assertDebugSerialize(t, INTERVAL(2, MINUTE), "INTERVAL 2 MINUTE") - assertDebugSerialize(t, INTERVAL(3, HOUR), "INTERVAL 3 HOUR") - assertDebugSerialize(t, INTERVAL(4, DAY), "INTERVAL 4 DAY") - assertDebugSerialize(t, INTERVAL(5, MONTH), "INTERVAL 5 MONTH") - assertDebugSerialize(t, INTERVAL(6, YEAR), "INTERVAL 6 YEAR") - assertDebugSerialize(t, INTERVAL(-6, YEAR), "INTERVAL -6 YEAR") + assertSerialize(t, INTERVAL(15, SECOND), "INTERVAL 15 SECOND") + assertSerialize(t, INTERVAL(1, MICROSECOND), "INTERVAL 1 MICROSECOND") + assertSerialize(t, INTERVAL(2, MINUTE), "INTERVAL 2 MINUTE") + assertSerialize(t, INTERVAL(3, HOUR), "INTERVAL 3 HOUR") + assertSerialize(t, INTERVAL(4, DAY), "INTERVAL 4 DAY") + assertSerialize(t, INTERVAL(5, MONTH), "INTERVAL 5 MONTH") + assertSerialize(t, INTERVAL(6, YEAR), "INTERVAL 6 YEAR") + assertSerialize(t, INTERVAL(-6, YEAR), "INTERVAL -6 YEAR") + + assertSerialize(t, INTERVAL(uint(6), YEAR), "INTERVAL 6 YEAR") + assertSerialize(t, INTERVAL(int16(7), YEAR), "INTERVAL 7 YEAR") + assertSerialize(t, INTERVAL(3.5, YEAR), "INTERVAL 3.5 YEAR") } func TestINTERVAL_InvalidUnitType(t *testing.T) { diff --git a/postgres/interval.go b/postgres/interval.go index f6344fd..de659e7 100644 --- a/postgres/interval.go +++ b/postgres/interval.go @@ -9,12 +9,11 @@ import ( "time" ) -type quantityAndUnit float64 +type quantityAndUnit = float64 +// Interval unit types const ( - pow2_32 = -4.294967296e+09 - - YEAR quantityAndUnit = pow2_32 + iota + YEAR quantityAndUnit = 123456789 + iota MONTH WEEK DAY @@ -33,11 +32,14 @@ type intervalExpressionImpl struct { jet.ExpressionInterfaceImpl } +// IntervalExpression is representation of postgres INTERVAL type IntervalExpression interface { jet.IsInterval jet.Expression } +// INTERVAL creates new interval expression from the list of quantity-unit pairs. +// For example: INTERVAL(1, DAY, 3, MINUTE) func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { if len(quantityAndUnit)%2 != 0 { panic("jet: invalid number of quantity and unit fields") @@ -62,6 +64,7 @@ func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { return newInterval } +// INTERVALd creates interval expression from duration func INTERVALd(duration time.Duration) IntervalExpression { days, hours, minutes, seconds, microseconds := utils.ExtractDateTimeComponents(duration) diff --git a/postgres/interval_test.go b/postgres/interval_test.go index 1d1e3de..785f1d5 100644 --- a/postgres/interval_test.go +++ b/postgres/interval_test.go @@ -25,6 +25,9 @@ func TestINTERVAL(t *testing.T) { assertSerialize(t, INTERVAL(1, YEAR).IS_NOT_NULL(), "INTERVAL '1 YEAR' IS NOT NULL") assertProjectionSerialize(t, INTERVAL(1, YEAR).AS("one year"), `INTERVAL '1 YEAR' AS "one year"`) + + f := 5.2 + assertSerialize(t, INTERVAL(f, YEAR), "INTERVAL '5.2 YEAR'") } func TestINTERVALd(t *testing.T) { From 74725e8e112fddfb3d112594232ef35ee8e8fabe Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 14 Dec 2019 18:32:40 +0100 Subject: [PATCH 09/19] Clean up. --- internal/jet/sql_builder.go | 1 - qrm/utill_test.go | 2 ++ tests/postgres/insert_test.go | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 95bd0b6..f71ca1e 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -165,7 +165,6 @@ func argToString(value interface{}) string { func integerTypesToString(value interface{}) string { switch bindVal := value.(type) { - case bool: case int: return strconv.FormatInt(int64(bindVal), 10) case uint: diff --git a/qrm/utill_test.go b/qrm/utill_test.go index e4ab53d..168c55f 100644 --- a/qrm/utill_test.go +++ b/qrm/utill_test.go @@ -32,4 +32,6 @@ func TestIsSimpleModelType(t *testing.T) { assert.Equal(t, isSimpleModelType(reflect.TypeOf(complexModelType)), false) assert.Equal(t, isSimpleModelType(reflect.TypeOf(&complexModelType)), false) + assert.Equal(t, isSimpleModelType(reflect.TypeOf([]string{"str"})), false) + assert.Equal(t, isSimpleModelType(reflect.TypeOf([]int{1, 2})), false) } diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index ee402b8..161ffed 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -24,7 +24,6 @@ RETURNING link.id AS "link.id", link.name AS "link.name", link.description AS "link.description"; ` - Link.ID.Name() insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(101, "http://www.google.com", "Google", DEFAULT). From d2fbdb68e6acb32d7aecd8cc17d9952015f1a19f Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 14 Dec 2019 19:11:35 +0100 Subject: [PATCH 10/19] Add support for conditional constructed projection list. --- mysql/types.go | 3 +++ postgres/types.go | 3 +++ tests/mysql/select_test.go | 33 +++++++++++++++++++++++++++++++++ tests/postgres/select_test.go | 33 +++++++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+) diff --git a/mysql/types.go b/mysql/types.go index 2902e72..4ef84b4 100644 --- a/mysql/types.go +++ b/mysql/types.go @@ -7,3 +7,6 @@ type Statement = jet.Statement // Projection is interface for all projection types. Types that can be part of, for instance SELECT clause. type Projection = jet.Projection + +// ProjectionList can be used to create conditional constructed projection list. +type ProjectionList = jet.ProjectionList diff --git a/postgres/types.go b/postgres/types.go index 215cceb..58a8ae9 100644 --- a/postgres/types.go +++ b/postgres/types.go @@ -7,3 +7,6 @@ type Statement = jet.Statement // Projection is interface for all projection types. Types that can be part of, for instance SELECT clause. type Projection = jet.Projection + +// ProjectionList can be used to create conditional constructed projection list. +type ProjectionList = jet.ProjectionList diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 952eb63..ed75bad 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -708,3 +708,36 @@ func TestJoinViewWithTable(t *testing.T) { assert.Equal(t, len(dest[0].Rentals), 32) assert.Equal(t, len(dest[1].Rentals), 27) } + +func TestConditionalProjectionList(t *testing.T) { + projectionList := ProjectionList{} + + columnsToSelect := []string{"customer_id", "create_date"} + + for _, columnName := range columnsToSelect { + switch columnName { + case Customer.CustomerID.Name(): + projectionList = append(projectionList, Customer.CustomerID) + case Customer.Email.Name(): + projectionList = append(projectionList, Customer.Email) + case Customer.CreateDate.Name(): + projectionList = append(projectionList, Customer.CreateDate) + } + } + + stmt := SELECT(projectionList). + FROM(Customer). + LIMIT(3) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT customer.customer_id AS "customer.customer_id", + customer.create_date AS "customer.create_date" +FROM dvds.customer +LIMIT 3; +`) + var dest []model.Customer + err := stmt.Query(db, &dest) + assert.NilError(t, err) + + assert.Equal(t, len(dest), 3) +} diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 2ac6c36..35ba373 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1791,3 +1791,36 @@ func TestJoinViewWithTable(t *testing.T) { assert.Equal(t, len(dest[0].Rentals), 32) assert.Equal(t, len(dest[1].Rentals), 27) } + +func TestConditionalProjectionList(t *testing.T) { + projectionList := ProjectionList{} + + columnsToSelect := []string{"customer_id", "create_date"} + + for _, columnName := range columnsToSelect { + switch columnName { + case Customer.CustomerID.Name(): + projectionList = append(projectionList, Customer.CustomerID) + case Customer.Email.Name(): + projectionList = append(projectionList, Customer.Email) + case Customer.CreateDate.Name(): + projectionList = append(projectionList, Customer.CreateDate) + } + } + + stmt := SELECT(projectionList). + FROM(Customer). + LIMIT(3) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT customer.customer_id AS "customer.customer_id", + customer.create_date AS "customer.create_date" +FROM dvds.customer +LIMIT 3; +`) + var dest []model.Customer + err := stmt.Query(db, &dest) + assert.NilError(t, err) + + assert.Equal(t, len(dest), 3) +} From 6e446ee100dc3a46f215a7812a28588a38014b92 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 21 Dec 2019 16:01:16 +0100 Subject: [PATCH 11/19] Update Readme.md --- README.md | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 9875199..76d36be 100644 --- a/README.md +++ b/README.md @@ -35,13 +35,13 @@ https://medium.com/@go.jet/jet-5f3667efa0cc ## Features 1) Auto-generated type-safe SQL Builder - PostgreSQL: - * SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, sub-queries)` + * SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, INTERSECT, EXCEPT, WINDOW, sub-queries)` * INSERT `(VALUES, query, RETURNING)`, * UPDATE `(SET, WHERE, RETURNING)`, * DELETE `(WHERE, RETURNING)`, * LOCK `(IN, NOWAIT)` - MySQL and MariaDB: - * SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, LOCK_IN_SHARE_MODE, sub-queries)` + * SELECT `(DISTINCT, FROM, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET, FOR, UNION, LOCK_IN_SHARE_MODE, WINDOW, sub-queries)` * INSERT `(VALUES, query)`, * UPDATE `(SET, WHERE)`, * DELETE `(WHERE, ORDER_BY, LIMIT)`, @@ -566,9 +566,6 @@ To run the tests, additional dependencies are required: [SemVer](http://semver.org/) is used for versioning. For the versions available, take a look at the [releases](https://github.com/go-jet/jet/releases). - -For now there is no guarantee that public API will remain backward compatible. Please read new release drafts to get acquaint how to handle possible build breakable API changes. - ## License Copyright 2019 Goran Bjelanovic From 641c62098c6c45cc4ec1724f7fd7c48c6bb2a5c2 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 21 Dec 2019 16:50:16 +0100 Subject: [PATCH 12/19] Update comments. --- mysql/interval.go | 6 +++--- postgres/interval.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mysql/interval.go b/mysql/interval.go index 478e9c4..ee0354e 100644 --- a/mysql/interval.go +++ b/mysql/interval.go @@ -38,7 +38,7 @@ const ( // Interval is representation of MySQL interval type Interval = jet.Interval -// INTERVAL creates new Interval type. +// INTERVAL creates new temporal interval. // In a case of MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR unit type // value parameter should be number. For example: INTERVAL(1, DAY) // In a case of other unit types, value should be string with appropriate format. @@ -94,7 +94,7 @@ func INTERVAL(value interface{}, unitType unitType) Interval { } } -// INTERVALe creates new Interval type from expresion and unit type. +// INTERVALe creates new temporal interval from expresion and unit type. func INTERVALe(expr Expression, unitType unitType) Interval { return jet.NewInterval(jet.ListSerializer{ Serializers: []jet.Serializer{expr, jet.Raw(string(unitType))}, @@ -102,7 +102,7 @@ func INTERVALe(expr Expression, unitType unitType) Interval { }) } -// INTERVALd returns a interval representation from duration +// INTERVALd temoral interval from time.Duration func INTERVALd(duration time.Duration) Interval { var sign int64 = 1 if duration < 0 { diff --git a/postgres/interval.go b/postgres/interval.go index de659e7..80592e8 100644 --- a/postgres/interval.go +++ b/postgres/interval.go @@ -64,7 +64,7 @@ func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { return newInterval } -// INTERVALd creates interval expression from duration +// INTERVALd creates interval expression from time.Duration func INTERVALd(duration time.Duration) IntervalExpression { days, hours, minutes, seconds, microseconds := utils.ExtractDateTimeComponents(duration) From 3013dc36479b05eebf05afd85bc8be98e8f354f5 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 9 Feb 2020 18:37:48 +0100 Subject: [PATCH 13/19] Add support for PostgreSQL interval column --- .../internal/metadata/column_meta_data.go | 4 +- internal/jet/alias.go | 3 +- internal/jet/bool_expression.go | 10 +- internal/jet/column.go | 51 ++- internal/jet/column_test.go | 2 +- internal/jet/column_types.go | 103 ++---- internal/jet/column_types_test.go | 71 +++- internal/jet/date_expression.go | 20 +- internal/jet/expression.go | 3 +- internal/jet/float_expression.go | 49 ++- internal/jet/integer_expression.go | 58 +-- internal/jet/interval.go | 17 +- internal/jet/literal_expression.go | 4 +- internal/jet/operators.go | 55 ++- internal/jet/string_expression.go | 18 +- internal/jet/time_expression.go | 20 +- internal/jet/timestamp_expression.go | 20 +- internal/jet/timestampz_expression.go | 20 +- internal/jet/timez_expression.go | 20 +- internal/jet/utils.go | 20 + internal/jet/utils_test.go | 19 + postgres/columns.go | 35 +- postgres/columns_test.go | 20 + postgres/expressions.go | 3 + .../{interval.go => interval_expression.go} | 101 +++++- ...al_test.go => interval_expression_test.go} | 24 +- postgres/utils_test.go | 4 + tests/postgres/alltypes_test.go | 161 +++++++- tests/postgres/generator_test.go | 343 +++++++++++++++++- tests/postgres/main_test.go | 15 + 30 files changed, 1038 insertions(+), 255 deletions(-) create mode 100644 internal/jet/utils_test.go create mode 100644 postgres/columns_test.go rename postgres/{interval.go => interval_expression.go} (52%) rename postgres/{interval_test.go => interval_expression_test.go} (61%) diff --git a/generator/internal/metadata/column_meta_data.go b/generator/internal/metadata/column_meta_data.go index 69a16f7..56ae54c 100644 --- a/generator/internal/metadata/column_meta_data.go +++ b/generator/internal/metadata/column_meta_data.go @@ -57,8 +57,10 @@ func (c ColumnMetaData) getSqlBuilderColumnType() string { return "Time" case "time with time zone": return "Timez" + case "interval": + return "Interval" case "USER-DEFINED", "enum", "text", "character", "character varying", "bytea", "uuid", - "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "interval", "line", "ARRAY", + "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", "char", "varchar", "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL return "String" diff --git a/internal/jet/alias.go b/internal/jet/alias.go index a2b1ae3..57f55cd 100644 --- a/internal/jet/alias.go +++ b/internal/jet/alias.go @@ -13,8 +13,7 @@ func newAlias(expression Expression, aliasName string) Projection { } func (a *alias) fromImpl(subQuery SelectTable) Projection { - column := newColumn(a.alias, "", nil) - column.Parent = &column + column := NewColumnImpl(a.alias, "", nil) column.subQuery = subQuery return &column diff --git a/internal/jet/bool_expression.go b/internal/jet/bool_expression.go index 1a05ab6..b4015b2 100644 --- a/internal/jet/bool_expression.go +++ b/internal/jet/bool_expression.go @@ -37,19 +37,19 @@ type boolInterfaceImpl struct { } func (b *boolInterfaceImpl) EQ(expression BoolExpression) BoolExpression { - return eq(b.parent, expression) + return Eq(b.parent, expression) } func (b *boolInterfaceImpl) NOT_EQ(expression BoolExpression) BoolExpression { - return notEq(b.parent, expression) + return NotEq(b.parent, expression) } func (b *boolInterfaceImpl) IS_DISTINCT_FROM(rhs BoolExpression) BoolExpression { - return isDistinctFrom(b.parent, rhs) + return IsDistinctFrom(b.parent, rhs) } func (b *boolInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs BoolExpression) BoolExpression { - return isNotDistinctFrom(b.parent, rhs) + return IsNotDistinctFrom(b.parent, rhs) } func (b *boolInterfaceImpl) AND(expression BoolExpression) BoolExpression { @@ -86,7 +86,7 @@ func (b *boolInterfaceImpl) IS_NOT_UNKNOWN() BoolExpression { //---------------------------------------------------// func newBinaryBoolOperatorExpression(lhs, rhs Expression, operator string, additionalParams ...Expression) BoolExpression { - return BoolExp(newBinaryOperatorExpression(lhs, rhs, operator, additionalParams...)) + return BoolExp(NewBinaryOperatorExpression(lhs, rhs, operator, additionalParams...)) } //---------------------------------------------------// diff --git a/internal/jet/column.go b/internal/jet/column.go index c7a5f41..85c053e 100644 --- a/internal/jet/column.go +++ b/internal/jet/column.go @@ -18,8 +18,8 @@ type ColumnExpression interface { Expression } -// The base type for real materialized columns. -type columnImpl struct { +// ColumnExpressionImpl is base type for sql columns. +type ColumnExpressionImpl struct { ExpressionInterfaceImpl name string @@ -28,34 +28,41 @@ type columnImpl struct { subQuery SelectTable } -func newColumn(name string, tableName string, parent ColumnExpression) columnImpl { - bc := columnImpl{ +// NewColumnImpl creates new ColumnExpressionImpl +func NewColumnImpl(name string, tableName string, parent ColumnExpression) ColumnExpressionImpl { + bc := ColumnExpressionImpl{ name: name, tableName: tableName, } - bc.ExpressionInterfaceImpl.Parent = parent + if parent != nil { + bc.ExpressionInterfaceImpl.Parent = parent + } else { + bc.ExpressionInterfaceImpl.Parent = &bc + } return bc } -func (c *columnImpl) Name() string { +// Name returns name of the column +func (c *ColumnExpressionImpl) Name() string { return c.name } -func (c *columnImpl) TableName() string { +// TableName returns column table name +func (c *ColumnExpressionImpl) TableName() string { return c.tableName } -func (c *columnImpl) setTableName(table string) { +func (c *ColumnExpressionImpl) setTableName(table string) { c.tableName = table } -func (c *columnImpl) setSubQuery(subQuery SelectTable) { +func (c *ColumnExpressionImpl) setSubQuery(subQuery SelectTable) { c.subQuery = subQuery } -func (c *columnImpl) defaultAlias() string { +func (c *ColumnExpressionImpl) defaultAlias() string { if c.tableName != "" { return c.tableName + "." + c.name } @@ -63,25 +70,31 @@ func (c *columnImpl) defaultAlias() string { return c.name } -func (c *columnImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { +func (c *ColumnExpressionImpl) fromImpl(subQuery SelectTable) Projection { + newColumn := NewColumnImpl(c.name, c.tableName, nil) + newColumn.setSubQuery(subQuery) + + return &newColumn +} + +func (c *ColumnExpressionImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { if statement == SetStatementType { // set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause out.WriteAlias(c.defaultAlias()) //always quote - return } c.serialize(statement, out) } -func (c columnImpl) serializeForProjection(statement StatementType, out *SQLBuilder) { +func (c ColumnExpressionImpl) serializeForProjection(statement StatementType, out *SQLBuilder) { c.serialize(statement, out) out.WriteString("AS") out.WriteAlias(c.defaultAlias()) } -func (c columnImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { if c.subQuery != nil { out.WriteIdentifier(c.subQuery.Alias()) @@ -128,3 +141,13 @@ func (cl ColumnList) TableName() string { return "" } func (cl ColumnList) setTableName(name string) {} func (cl ColumnList) setSubQuery(subQuery SelectTable) {} func (cl ColumnList) defaultAlias() string { return "" } + +// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName +func SetTableName(columnExpression ColumnExpression, tableName string) { + columnExpression.setTableName(tableName) +} + +// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery +func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) { + columnExpression.setSubQuery(subQuery) +} diff --git a/internal/jet/column_test.go b/internal/jet/column_test.go index ca3f5f6..7e6c0ed 100644 --- a/internal/jet/column_test.go +++ b/internal/jet/column_test.go @@ -3,7 +3,7 @@ package jet import "testing" func TestColumn(t *testing.T) { - column := newColumn("col", "", nil) + column := NewColumnImpl("col", "", nil) column.ExpressionInterfaceImpl.Parent = &column assertClauseSerialize(t, column, "col") diff --git a/internal/jet/column_types.go b/internal/jet/column_types.go index 5443911..58f6751 100644 --- a/internal/jet/column_types.go +++ b/internal/jet/column_types.go @@ -10,11 +10,10 @@ type ColumnBool interface { type boolColumnImpl struct { boolInterfaceImpl - - columnImpl + ColumnExpressionImpl } -func (i *boolColumnImpl) fromImpl(subQuery SelectTable) Projection { +func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool { newBoolColumn := BoolColumn(i.name) newBoolColumn.setTableName(i.tableName) newBoolColumn.setSubQuery(subQuery) @@ -22,16 +21,10 @@ func (i *boolColumnImpl) fromImpl(subQuery SelectTable) Projection { return newBoolColumn } -func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool { - newBoolColumn := i.fromImpl(subQuery).(ColumnBool) - - return newBoolColumn -} - // BoolColumn creates named bool column. func BoolColumn(name string) ColumnBool { boolColumn := &boolColumnImpl{} - boolColumn.columnImpl = newColumn(name, "", boolColumn) + boolColumn.ColumnExpressionImpl = NewColumnImpl(name, "", boolColumn) boolColumn.boolInterfaceImpl.parent = boolColumn return boolColumn @@ -49,19 +42,13 @@ type ColumnFloat interface { type floatColumnImpl struct { floatInterfaceImpl - columnImpl -} - -func (i *floatColumnImpl) fromImpl(subQuery SelectTable) Projection { - newFloatColumn := FloatColumn(i.name) - newFloatColumn.setTableName(i.tableName) - newFloatColumn.setSubQuery(subQuery) - - return newFloatColumn + ColumnExpressionImpl } func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat { - newFloatColumn := i.fromImpl(subQuery).(ColumnFloat) + newFloatColumn := FloatColumn(i.name) + newFloatColumn.setTableName(i.tableName) + newFloatColumn.setSubQuery(subQuery) return newFloatColumn } @@ -70,7 +57,7 @@ func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat { func FloatColumn(name string) ColumnFloat { floatColumn := &floatColumnImpl{} floatColumn.floatInterfaceImpl.parent = floatColumn - floatColumn.columnImpl = newColumn(name, "", floatColumn) + floatColumn.ColumnExpressionImpl = NewColumnImpl(name, "", floatColumn) return floatColumn } @@ -88,10 +75,10 @@ type ColumnInteger interface { type integerColumnImpl struct { integerInterfaceImpl - columnImpl + ColumnExpressionImpl } -func (i *integerColumnImpl) fromImpl(subQuery SelectTable) Projection { +func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger { newIntColumn := IntegerColumn(i.name) newIntColumn.setTableName(i.tableName) newIntColumn.setSubQuery(subQuery) @@ -99,15 +86,11 @@ func (i *integerColumnImpl) fromImpl(subQuery SelectTable) Projection { return newIntColumn } -func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger { - return i.fromImpl(subQuery).(ColumnInteger) -} - // IntegerColumn creates named integer column. func IntegerColumn(name string) ColumnInteger { integerColumn := &integerColumnImpl{} integerColumn.integerInterfaceImpl.parent = integerColumn - integerColumn.columnImpl = newColumn(name, "", integerColumn) + integerColumn.ColumnExpressionImpl = NewColumnImpl(name, "", integerColumn) return integerColumn } @@ -126,10 +109,10 @@ type ColumnString interface { type stringColumnImpl struct { stringInterfaceImpl - columnImpl + ColumnExpressionImpl } -func (i *stringColumnImpl) fromImpl(subQuery SelectTable) Projection { +func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString { newStrColumn := StringColumn(i.name) newStrColumn.setTableName(i.tableName) newStrColumn.setSubQuery(subQuery) @@ -137,15 +120,11 @@ func (i *stringColumnImpl) fromImpl(subQuery SelectTable) Projection { return newStrColumn } -func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString { - return i.fromImpl(subQuery).(ColumnString) -} - // StringColumn creates named string column. func StringColumn(name string) ColumnString { stringColumn := &stringColumnImpl{} stringColumn.stringInterfaceImpl.parent = stringColumn - stringColumn.columnImpl = newColumn(name, "", stringColumn) + stringColumn.ColumnExpressionImpl = NewColumnImpl(name, "", stringColumn) return stringColumn } @@ -162,10 +141,10 @@ type ColumnTime interface { type timeColumnImpl struct { timeInterfaceImpl - columnImpl + ColumnExpressionImpl } -func (i *timeColumnImpl) fromImpl(subQuery SelectTable) Projection { +func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime { newTimeColumn := TimeColumn(i.name) newTimeColumn.setTableName(i.tableName) newTimeColumn.setSubQuery(subQuery) @@ -173,15 +152,11 @@ func (i *timeColumnImpl) fromImpl(subQuery SelectTable) Projection { return newTimeColumn } -func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime { - return i.fromImpl(subQuery).(ColumnTime) -} - // TimeColumn creates named time column func TimeColumn(name string) ColumnTime { timeColumn := &timeColumnImpl{} timeColumn.timeInterfaceImpl.parent = timeColumn - timeColumn.columnImpl = newColumn(name, "", timeColumn) + timeColumn.ColumnExpressionImpl = NewColumnImpl(name, "", timeColumn) return timeColumn } @@ -197,11 +172,10 @@ type ColumnTimez interface { type timezColumnImpl struct { timezInterfaceImpl - - columnImpl + ColumnExpressionImpl } -func (i *timezColumnImpl) fromImpl(subQuery SelectTable) Projection { +func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez { newTimezColumn := TimezColumn(i.name) newTimezColumn.setTableName(i.tableName) newTimezColumn.setSubQuery(subQuery) @@ -209,15 +183,11 @@ func (i *timezColumnImpl) fromImpl(subQuery SelectTable) Projection { return newTimezColumn } -func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez { - return i.fromImpl(subQuery).(ColumnTimez) -} - // TimezColumn creates named time with time zone column. func TimezColumn(name string) ColumnTimez { timezColumn := &timezColumnImpl{} timezColumn.timezInterfaceImpl.parent = timezColumn - timezColumn.columnImpl = newColumn(name, "", timezColumn) + timezColumn.ColumnExpressionImpl = NewColumnImpl(name, "", timezColumn) return timezColumn } @@ -234,11 +204,10 @@ type ColumnTimestamp interface { type timestampColumnImpl struct { timestampInterfaceImpl - - columnImpl + ColumnExpressionImpl } -func (i *timestampColumnImpl) fromImpl(subQuery SelectTable) Projection { +func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp { newTimestampColumn := TimestampColumn(i.name) newTimestampColumn.setTableName(i.tableName) newTimestampColumn.setSubQuery(subQuery) @@ -246,15 +215,11 @@ func (i *timestampColumnImpl) fromImpl(subQuery SelectTable) Projection { return newTimestampColumn } -func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp { - return i.fromImpl(subQuery).(ColumnTimestamp) -} - // TimestampColumn creates named timestamp column func TimestampColumn(name string) ColumnTimestamp { timestampColumn := ×tampColumnImpl{} timestampColumn.timestampInterfaceImpl.parent = timestampColumn - timestampColumn.columnImpl = newColumn(name, "", timestampColumn) + timestampColumn.ColumnExpressionImpl = NewColumnImpl(name, "", timestampColumn) return timestampColumn } @@ -271,11 +236,10 @@ type ColumnTimestampz interface { type timestampzColumnImpl struct { timestampzInterfaceImpl - - columnImpl + ColumnExpressionImpl } -func (i *timestampzColumnImpl) fromImpl(subQuery SelectTable) Projection { +func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz { newTimestampzColumn := TimestampzColumn(i.name) newTimestampzColumn.setTableName(i.tableName) newTimestampzColumn.setSubQuery(subQuery) @@ -283,15 +247,11 @@ func (i *timestampzColumnImpl) fromImpl(subQuery SelectTable) Projection { return newTimestampzColumn } -func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz { - return i.fromImpl(subQuery).(ColumnTimestampz) -} - // TimestampzColumn creates named timestamp with time zone column. func TimestampzColumn(name string) ColumnTimestampz { timestampzColumn := ×tampzColumnImpl{} timestampzColumn.timestampzInterfaceImpl.parent = timestampzColumn - timestampzColumn.columnImpl = newColumn(name, "", timestampzColumn) + timestampzColumn.ColumnExpressionImpl = NewColumnImpl(name, "", timestampzColumn) return timestampzColumn } @@ -308,11 +268,10 @@ type ColumnDate interface { type dateColumnImpl struct { dateInterfaceImpl - - columnImpl + ColumnExpressionImpl } -func (i *dateColumnImpl) fromImpl(subQuery SelectTable) Projection { +func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate { newDateColumn := DateColumn(i.name) newDateColumn.setTableName(i.tableName) newDateColumn.setSubQuery(subQuery) @@ -320,14 +279,10 @@ func (i *dateColumnImpl) fromImpl(subQuery SelectTable) Projection { return newDateColumn } -func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate { - return i.fromImpl(subQuery).(ColumnDate) -} - // DateColumn creates named date column. func DateColumn(name string) ColumnDate { dateColumn := &dateColumnImpl{} dateColumn.dateInterfaceImpl.parent = dateColumn - dateColumn.columnImpl = newColumn(name, "", dateColumn) + dateColumn.ColumnExpressionImpl = NewColumnImpl(name, "", dateColumn) return dateColumn } diff --git a/internal/jet/column_types_test.go b/internal/jet/column_types_test.go index 8c00d5b..9cff309 100644 --- a/internal/jet/column_types_test.go +++ b/internal/jet/column_types_test.go @@ -43,5 +43,74 @@ func TestNewFloatColumnColumn(t *testing.T) { assertClauseSerialize(t, floatColumn2, `sub_query."table1.col_float"`) assertClauseSerialize(t, floatColumn2.EQ(Float(2.22)), `(sub_query."table1.col_float" = $1)`, float64(2.22)) assertProjectionSerialize(t, floatColumn2, `sub_query."table1.col_float" AS "table1.col_float"`) - +} + +func TestNewDateColumnColumn(t *testing.T) { + dateColumn := DateColumn("col_date").From(subQuery) + assertClauseSerialize(t, dateColumn, `sub_query."col_date"`) + assertClauseSerialize(t, dateColumn.EQ(Date(2002, 2, 3)), + `(sub_query."col_date" = $1)`, "2002-02-03") + assertProjectionSerialize(t, dateColumn, `sub_query."col_date" AS "col_date"`) + + dateColumn2 := table1ColDate.From(subQuery) + assertClauseSerialize(t, dateColumn2, `sub_query."table1.col_date"`) + assertClauseSerialize(t, dateColumn2.EQ(Date(2002, 2, 3)), + `(sub_query."table1.col_date" = $1)`, "2002-02-03") + assertProjectionSerialize(t, dateColumn2, `sub_query."table1.col_date" AS "table1.col_date"`) +} + +func TestNewTimeColumnColumn(t *testing.T) { + timeColumn := TimeColumn("col_time").From(subQuery) + assertClauseSerialize(t, timeColumn, `sub_query."col_time"`) + assertClauseSerialize(t, timeColumn.EQ(Time(1, 1, 1, 1)), + `(sub_query."col_time" = $1)`, "01:01:01.000000001") + assertProjectionSerialize(t, timeColumn, `sub_query."col_time" AS "col_time"`) + + timeColumn2 := table1ColTime.From(subQuery) + assertClauseSerialize(t, timeColumn2, `sub_query."table1.col_time"`) + assertClauseSerialize(t, timeColumn2.EQ(Time(2, 2, 2)), + `(sub_query."table1.col_time" = $1)`, "02:02:02") + assertProjectionSerialize(t, timeColumn2, `sub_query."table1.col_time" AS "table1.col_time"`) +} + +func TestNewTimezColumnColumn(t *testing.T) { + timezColumn := TimezColumn("col_timez").From(subQuery) + assertClauseSerialize(t, timezColumn, `sub_query."col_timez"`) + assertClauseSerialize(t, timezColumn.EQ(Timez(1, 1, 1, 1, "UTC")), + `(sub_query."col_timez" = $1)`, "01:01:01.000000001 UTC") + assertProjectionSerialize(t, timezColumn, `sub_query."col_timez" AS "col_timez"`) + + timezColumn2 := table1ColTimez.From(subQuery) + assertClauseSerialize(t, timezColumn2, `sub_query."table1.col_timez"`) + assertClauseSerialize(t, timezColumn2.EQ(Timez(2, 2, 2, 0, "UTC")), + `(sub_query."table1.col_timez" = $1)`, "02:02:02 UTC") + assertProjectionSerialize(t, timezColumn2, `sub_query."table1.col_timez" AS "table1.col_timez"`) +} + +func TestNewTimestampColumnColumn(t *testing.T) { + timestampColumn := TimestampColumn("col_timestamp").From(subQuery) + assertClauseSerialize(t, timestampColumn, `sub_query."col_timestamp"`) + assertClauseSerialize(t, timestampColumn.EQ(Timestamp(1, 1, 1, 1, 1, 1)), + `(sub_query."col_timestamp" = $1)`, "0001-01-01 01:01:01") + assertProjectionSerialize(t, timestampColumn, `sub_query."col_timestamp" AS "col_timestamp"`) + + timestampColumn2 := table1ColTimestamp.From(subQuery) + assertClauseSerialize(t, timestampColumn2, `sub_query."table1.col_timestamp"`) + assertClauseSerialize(t, timestampColumn2.EQ(Timestamp(2, 2, 2, 2, 2, 2)), + `(sub_query."table1.col_timestamp" = $1)`, "0002-02-02 02:02:02") + assertProjectionSerialize(t, timestampColumn2, `sub_query."table1.col_timestamp" AS "table1.col_timestamp"`) +} + +func TestNewTimestampzColumnColumn(t *testing.T) { + timestampzColumn := TimestampzColumn("col_timestampz").From(subQuery) + assertClauseSerialize(t, timestampzColumn, `sub_query."col_timestampz"`) + assertClauseSerialize(t, timestampzColumn.EQ(Timestampz(1, 1, 1, 1, 1, 1, 0, "UTC")), + `(sub_query."col_timestampz" = $1)`, "0001-01-01 01:01:01 UTC") + assertProjectionSerialize(t, timestampzColumn, `sub_query."col_timestampz" AS "col_timestampz"`) + + timestampzColumn2 := table1ColTimestampz.From(subQuery) + assertClauseSerialize(t, timestampzColumn2, `sub_query."table1.col_timestampz"`) + assertClauseSerialize(t, timestampzColumn2.EQ(Timestampz(2, 2, 2, 2, 2, 2, 0, "UTC")), + `(sub_query."table1.col_timestampz" = $1)`, "0002-02-02 02:02:02 UTC") + assertProjectionSerialize(t, timestampzColumn2, `sub_query."table1.col_timestampz" AS "table1.col_timestampz"`) } diff --git a/internal/jet/date_expression.go b/internal/jet/date_expression.go index 8b0a524..27c2035 100644 --- a/internal/jet/date_expression.go +++ b/internal/jet/date_expression.go @@ -23,43 +23,43 @@ type dateInterfaceImpl struct { } func (d *dateInterfaceImpl) EQ(rhs DateExpression) BoolExpression { - return eq(d.parent, rhs) + return Eq(d.parent, rhs) } func (d *dateInterfaceImpl) NOT_EQ(rhs DateExpression) BoolExpression { - return notEq(d.parent, rhs) + return NotEq(d.parent, rhs) } func (d *dateInterfaceImpl) IS_DISTINCT_FROM(rhs DateExpression) BoolExpression { - return isDistinctFrom(d.parent, rhs) + return IsDistinctFrom(d.parent, rhs) } func (d *dateInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs DateExpression) BoolExpression { - return isNotDistinctFrom(d.parent, rhs) + return IsNotDistinctFrom(d.parent, rhs) } func (d *dateInterfaceImpl) LT(rhs DateExpression) BoolExpression { - return lt(d.parent, rhs) + return Lt(d.parent, rhs) } func (d *dateInterfaceImpl) LT_EQ(rhs DateExpression) BoolExpression { - return ltEq(d.parent, rhs) + return LtEq(d.parent, rhs) } func (d *dateInterfaceImpl) GT(rhs DateExpression) BoolExpression { - return gt(d.parent, rhs) + return Gt(d.parent, rhs) } func (d *dateInterfaceImpl) GT_EQ(rhs DateExpression) BoolExpression { - return gtEq(d.parent, rhs) + return GtEq(d.parent, rhs) } func (d *dateInterfaceImpl) ADD(rhs Interval) TimestampExpression { - return TimestampExp(newBinaryOperatorExpression(d.parent, rhs, "+")) + return TimestampExp(Add(d.parent, rhs)) } func (d *dateInterfaceImpl) SUB(rhs Interval) TimestampExpression { - return TimestampExp(newBinaryOperatorExpression(d.parent, rhs, "-")) + return TimestampExp(Sub(d.parent, rhs)) } //---------------------------------------------------// diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 26b9186..a463b76 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -92,7 +92,8 @@ type binaryOperatorExpression struct { operator string } -func newBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) *binaryOperatorExpression { +// NewBinaryOperatorExpression creates new binaryOperatorExpression +func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) *binaryOperatorExpression { binaryExpression := &binaryOperatorExpression{ lhs: lhs, rhs: rhs, diff --git a/internal/jet/float_expression.go b/internal/jet/float_expression.go index aa821ba..20c8d3f 100644 --- a/internal/jet/float_expression.go +++ b/internal/jet/float_expression.go @@ -29,64 +29,59 @@ type floatInterfaceImpl struct { } func (n *floatInterfaceImpl) EQ(rhs FloatExpression) BoolExpression { - return eq(n.parent, rhs) + return Eq(n.parent, rhs) } func (n *floatInterfaceImpl) NOT_EQ(rhs FloatExpression) BoolExpression { - return notEq(n.parent, rhs) + return NotEq(n.parent, rhs) } func (n *floatInterfaceImpl) IS_DISTINCT_FROM(rhs FloatExpression) BoolExpression { - return isDistinctFrom(n.parent, rhs) + return IsDistinctFrom(n.parent, rhs) } func (n *floatInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs FloatExpression) BoolExpression { - return isNotDistinctFrom(n.parent, rhs) + return IsNotDistinctFrom(n.parent, rhs) } func (n *floatInterfaceImpl) GT(rhs FloatExpression) BoolExpression { - return gt(n.parent, rhs) + return Gt(n.parent, rhs) } func (n *floatInterfaceImpl) GT_EQ(rhs FloatExpression) BoolExpression { - return gtEq(n.parent, rhs) + return GtEq(n.parent, rhs) } -func (n *floatInterfaceImpl) LT(expression FloatExpression) BoolExpression { - return lt(n.parent, expression) +func (n *floatInterfaceImpl) LT(rhs FloatExpression) BoolExpression { + return Lt(n.parent, rhs) } -func (n *floatInterfaceImpl) LT_EQ(expression FloatExpression) BoolExpression { - return ltEq(n.parent, expression) +func (n *floatInterfaceImpl) LT_EQ(rhs FloatExpression) BoolExpression { + return LtEq(n.parent, rhs) } -func (n *floatInterfaceImpl) ADD(expression NumericExpression) FloatExpression { - return newBinaryFloatExpression(n.parent, expression, "+") +func (n *floatInterfaceImpl) ADD(rhs NumericExpression) FloatExpression { + return FloatExp(Add(n.parent, rhs)) } -func (n *floatInterfaceImpl) SUB(expression NumericExpression) FloatExpression { - return newBinaryFloatExpression(n.parent, expression, "-") +func (n *floatInterfaceImpl) SUB(rhs NumericExpression) FloatExpression { + return FloatExp(Sub(n.parent, rhs)) } -func (n *floatInterfaceImpl) MUL(expression NumericExpression) FloatExpression { - return newBinaryFloatExpression(n.parent, expression, "*") +func (n *floatInterfaceImpl) MUL(rhs NumericExpression) FloatExpression { + return FloatExp(Mul(n.parent, rhs)) } -func (n *floatInterfaceImpl) DIV(expression NumericExpression) FloatExpression { - return newBinaryFloatExpression(n.parent, expression, "/") +func (n *floatInterfaceImpl) DIV(rhs NumericExpression) FloatExpression { + return FloatExp(Div(n.parent, rhs)) } -func (n *floatInterfaceImpl) MOD(expression NumericExpression) FloatExpression { - return newBinaryFloatExpression(n.parent, expression, "%") +func (n *floatInterfaceImpl) MOD(rhs NumericExpression) FloatExpression { + return FloatExp(Mod(n.parent, rhs)) } -func (n *floatInterfaceImpl) POW(expression NumericExpression) FloatExpression { - return POW(n.parent, expression) -} - -//---------------------------------------------------// -func newBinaryFloatExpression(lhs, rhs Expression, operator string) FloatExpression { - return FloatExp(newBinaryOperatorExpression(lhs, rhs, operator)) +func (n *floatInterfaceImpl) POW(rhs NumericExpression) FloatExpression { + return POW(n.parent, rhs) } //---------------------------------------------------// diff --git a/internal/jet/integer_expression.go b/internal/jet/integer_expression.go index c004437..ff2a0a0 100644 --- a/internal/jet/integer_expression.go +++ b/internal/jet/integer_expression.go @@ -54,71 +54,71 @@ type integerInterfaceImpl struct { } func (i *integerInterfaceImpl) EQ(rhs IntegerExpression) BoolExpression { - return eq(i.parent, rhs) + return Eq(i.parent, rhs) } func (i *integerInterfaceImpl) NOT_EQ(rhs IntegerExpression) BoolExpression { - return notEq(i.parent, rhs) + return NotEq(i.parent, rhs) } func (i *integerInterfaceImpl) IS_DISTINCT_FROM(rhs IntegerExpression) BoolExpression { - return isDistinctFrom(i.parent, rhs) + return IsDistinctFrom(i.parent, rhs) } func (i *integerInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs IntegerExpression) BoolExpression { - return isNotDistinctFrom(i.parent, rhs) + return IsNotDistinctFrom(i.parent, rhs) } func (i *integerInterfaceImpl) GT(rhs IntegerExpression) BoolExpression { - return gt(i.parent, rhs) + return Gt(i.parent, rhs) } func (i *integerInterfaceImpl) GT_EQ(rhs IntegerExpression) BoolExpression { - return gtEq(i.parent, rhs) + return GtEq(i.parent, rhs) } -func (i *integerInterfaceImpl) LT(expression IntegerExpression) BoolExpression { - return lt(i.parent, expression) +func (i *integerInterfaceImpl) LT(rhs IntegerExpression) BoolExpression { + return Lt(i.parent, rhs) } -func (i *integerInterfaceImpl) LT_EQ(expression IntegerExpression) BoolExpression { - return ltEq(i.parent, expression) +func (i *integerInterfaceImpl) LT_EQ(rhs IntegerExpression) BoolExpression { + return LtEq(i.parent, rhs) } -func (i *integerInterfaceImpl) ADD(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerOperatorExpression(i.parent, expression, "+") +func (i *integerInterfaceImpl) ADD(rhs IntegerExpression) IntegerExpression { + return IntExp(Add(i.parent, rhs)) } -func (i *integerInterfaceImpl) SUB(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerOperatorExpression(i.parent, expression, "-") +func (i *integerInterfaceImpl) SUB(rhs IntegerExpression) IntegerExpression { + return IntExp(Sub(i.parent, rhs)) } -func (i *integerInterfaceImpl) MUL(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerOperatorExpression(i.parent, expression, "*") +func (i *integerInterfaceImpl) MUL(rhs IntegerExpression) IntegerExpression { + return IntExp(Mul(i.parent, rhs)) } -func (i *integerInterfaceImpl) DIV(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerOperatorExpression(i.parent, expression, "/") +func (i *integerInterfaceImpl) DIV(rhs IntegerExpression) IntegerExpression { + return IntExp(Div(i.parent, rhs)) } -func (i *integerInterfaceImpl) MOD(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerOperatorExpression(i.parent, expression, "%") +func (i *integerInterfaceImpl) MOD(rhs IntegerExpression) IntegerExpression { + return IntExp(Mod(i.parent, rhs)) } -func (i *integerInterfaceImpl) POW(expression IntegerExpression) IntegerExpression { - return IntExp(POW(i.parent, expression)) +func (i *integerInterfaceImpl) POW(rhs IntegerExpression) IntegerExpression { + return IntExp(POW(i.parent, rhs)) } -func (i *integerInterfaceImpl) BIT_AND(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerOperatorExpression(i.parent, expression, "&") +func (i *integerInterfaceImpl) BIT_AND(rhs IntegerExpression) IntegerExpression { + return newBinaryIntegerOperatorExpression(i.parent, rhs, "&") } -func (i *integerInterfaceImpl) BIT_OR(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerOperatorExpression(i.parent, expression, "|") +func (i *integerInterfaceImpl) BIT_OR(rhs IntegerExpression) IntegerExpression { + return newBinaryIntegerOperatorExpression(i.parent, rhs, "|") } -func (i *integerInterfaceImpl) BIT_XOR(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerOperatorExpression(i.parent, expression, "#") +func (i *integerInterfaceImpl) BIT_XOR(rhs IntegerExpression) IntegerExpression { + return newBinaryIntegerOperatorExpression(i.parent, rhs, "#") } func (i *integerInterfaceImpl) BIT_SHIFT_LEFT(intExpression IntegerExpression) IntegerExpression { @@ -131,7 +131,7 @@ func (i *integerInterfaceImpl) BIT_SHIFT_RIGHT(intExpression IntegerExpression) //---------------------------------------------------// func newBinaryIntegerOperatorExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression { - return IntExp(newBinaryOperatorExpression(lhs, rhs, operator)) + return IntExp(NewBinaryOperatorExpression(lhs, rhs, operator)) } //---------------------------------------------------// diff --git a/internal/jet/interval.go b/internal/jet/interval.go index e66ca56..fab84e0 100644 --- a/internal/jet/interval.go +++ b/internal/jet/interval.go @@ -11,22 +11,27 @@ type IsInterval interface { isInterval() } +// IsIntervalImpl is implementation of IsInterval interface +type IsIntervalImpl struct{} + +func (i *IsIntervalImpl) isInterval() {} + // NewInterval creates new interval from serializer -func NewInterval(s Serializer) Interval { - newInterval := &intervalImpl{ +func NewInterval(s Serializer) *IntervalImpl { + newInterval := &IntervalImpl{ interval: s, } return newInterval } -type intervalImpl struct { +// IntervalImpl is implementation of Interval type +type IntervalImpl struct { interval Serializer + IsIntervalImpl } -func (i intervalImpl) isInterval() {} - -func (i intervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString("INTERVAL") i.interval.serialize(statement, out, options...) } diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 499b7b4..15c8a4a 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -346,9 +346,9 @@ func (n *rawExpression) serialize(statement StatementType, out *SQLBuilder, opti // Raw can be used for any unsupported functions, operators or expressions. // For example: Raw("current_database()") -func Raw(raw string) Expression { +func Raw(raw string, parent ...Expression) Expression { rawExp := &rawExpression{Raw: raw} - rawExp.ExpressionInterfaceImpl.Parent = rawExp + rawExp.ExpressionInterfaceImpl.Parent = OptionalOrDefaultExpression(rawExp, parent...) return rawExp } diff --git a/internal/jet/operators.go b/internal/jet/operators.go index fad1e26..d17081c 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -29,44 +29,71 @@ func EXISTS(subQuery Expression) BoolExpression { return newPrefixBoolOperatorExpression(subQuery, "EXISTS") } -// Returns a representation of "a=b" -func eq(lhs, rhs Expression) BoolExpression { +// Eq returns a representation of "a=b" +func Eq(lhs, rhs Expression) BoolExpression { return newBinaryBoolOperatorExpression(lhs, rhs, "=") } -// Returns a representation of "a!=b" -func notEq(lhs, rhs Expression) BoolExpression { +// NotEq returns a representation of "a!=b" +func NotEq(lhs, rhs Expression) BoolExpression { return newBinaryBoolOperatorExpression(lhs, rhs, "!=") } -func isDistinctFrom(lhs, rhs Expression) BoolExpression { +// IsDistinctFrom returns a representation of "a IS DISTINCT FROM b" +func IsDistinctFrom(lhs, rhs Expression) BoolExpression { return newBinaryBoolOperatorExpression(lhs, rhs, "IS DISTINCT FROM") } -func isNotDistinctFrom(lhs, rhs Expression) BoolExpression { +// IsNotDistinctFrom returns a representation of "a IS NOT DISTINCT FROM b" +func IsNotDistinctFrom(lhs, rhs Expression) BoolExpression { return newBinaryBoolOperatorExpression(lhs, rhs, "IS NOT DISTINCT FROM") } -// Returns a representation of "ab" -func gt(lhs, rhs Expression) BoolExpression { +// Gt returns a representation of "a>b" +func Gt(lhs, rhs Expression) BoolExpression { return newBinaryBoolOperatorExpression(lhs, rhs, ">") } -// Returns a representation of "a>=b" -func gtEq(lhs, rhs Expression) BoolExpression { +// GtEq returns a representation of "a>=b" +func GtEq(lhs, rhs Expression) BoolExpression { return newBinaryBoolOperatorExpression(lhs, rhs, ">=") } +// Add notEq returns a representation of "a + b" +func Add(lhs, rhs Serializer) Expression { + return NewBinaryOperatorExpression(lhs, rhs, "+") +} + +// Sub notEq returns a representation of "a - b" +func Sub(lhs, rhs Serializer) Expression { + return NewBinaryOperatorExpression(lhs, rhs, "-") +} + +// Mul returns a representation of "a * b" +func Mul(lhs, rhs Serializer) Expression { + return NewBinaryOperatorExpression(lhs, rhs, "*") +} + +// Div returns a representation of "a / b" +func Div(lhs, rhs Serializer) Expression { + return NewBinaryOperatorExpression(lhs, rhs, "/") +} + +// Mod returns a representation of "a % b" +func Mod(lhs, rhs Serializer) Expression { + return NewBinaryOperatorExpression(lhs, rhs, "%") +} + // --------------- CASE operator -------------------// // CaseOperator is interface for SQL case operator diff --git a/internal/jet/string_expression.go b/internal/jet/string_expression.go index 29ceca6..3c56896 100644 --- a/internal/jet/string_expression.go +++ b/internal/jet/string_expression.go @@ -28,35 +28,35 @@ type stringInterfaceImpl struct { } func (s *stringInterfaceImpl) EQ(rhs StringExpression) BoolExpression { - return eq(s.parent, rhs) + return Eq(s.parent, rhs) } func (s *stringInterfaceImpl) NOT_EQ(rhs StringExpression) BoolExpression { - return notEq(s.parent, rhs) + return NotEq(s.parent, rhs) } func (s *stringInterfaceImpl) IS_DISTINCT_FROM(rhs StringExpression) BoolExpression { - return isDistinctFrom(s.parent, rhs) + return IsDistinctFrom(s.parent, rhs) } func (s *stringInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs StringExpression) BoolExpression { - return isNotDistinctFrom(s.parent, rhs) + return IsNotDistinctFrom(s.parent, rhs) } func (s *stringInterfaceImpl) GT(rhs StringExpression) BoolExpression { - return gt(s.parent, rhs) + return Gt(s.parent, rhs) } func (s *stringInterfaceImpl) GT_EQ(rhs StringExpression) BoolExpression { - return gtEq(s.parent, rhs) + return GtEq(s.parent, rhs) } func (s *stringInterfaceImpl) LT(rhs StringExpression) BoolExpression { - return lt(s.parent, rhs) + return Lt(s.parent, rhs) } func (s *stringInterfaceImpl) LT_EQ(rhs StringExpression) BoolExpression { - return ltEq(s.parent, rhs) + return LtEq(s.parent, rhs) } func (s *stringInterfaceImpl) CONCAT(rhs Expression) StringExpression { @@ -81,7 +81,7 @@ func (s *stringInterfaceImpl) NOT_REGEXP_LIKE(pattern StringExpression, caseSens //---------------------------------------------------// func newBinaryStringOperatorExpression(lhs, rhs Expression, operator string) StringExpression { - return StringExp(newBinaryOperatorExpression(lhs, rhs, operator)) + return StringExp(NewBinaryOperatorExpression(lhs, rhs, operator)) } //---------------------------------------------------// diff --git a/internal/jet/time_expression.go b/internal/jet/time_expression.go index b83f731..4fd7047 100644 --- a/internal/jet/time_expression.go +++ b/internal/jet/time_expression.go @@ -23,43 +23,43 @@ type timeInterfaceImpl struct { } func (t *timeInterfaceImpl) EQ(rhs TimeExpression) BoolExpression { - return eq(t.parent, rhs) + return Eq(t.parent, rhs) } func (t *timeInterfaceImpl) NOT_EQ(rhs TimeExpression) BoolExpression { - return notEq(t.parent, rhs) + return NotEq(t.parent, rhs) } func (t *timeInterfaceImpl) IS_DISTINCT_FROM(rhs TimeExpression) BoolExpression { - return isDistinctFrom(t.parent, rhs) + return IsDistinctFrom(t.parent, rhs) } func (t *timeInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs TimeExpression) BoolExpression { - return isNotDistinctFrom(t.parent, rhs) + return IsNotDistinctFrom(t.parent, rhs) } func (t *timeInterfaceImpl) LT(rhs TimeExpression) BoolExpression { - return lt(t.parent, rhs) + return Lt(t.parent, rhs) } func (t *timeInterfaceImpl) LT_EQ(rhs TimeExpression) BoolExpression { - return ltEq(t.parent, rhs) + return LtEq(t.parent, rhs) } func (t *timeInterfaceImpl) GT(rhs TimeExpression) BoolExpression { - return gt(t.parent, rhs) + return Gt(t.parent, rhs) } func (t *timeInterfaceImpl) GT_EQ(rhs TimeExpression) BoolExpression { - return gtEq(t.parent, rhs) + return GtEq(t.parent, rhs) } func (t *timeInterfaceImpl) ADD(rhs Interval) TimeExpression { - return TimeExp(newBinaryOperatorExpression(t.parent, rhs, "+")) + return TimeExp(Add(t.parent, rhs)) } func (t *timeInterfaceImpl) SUB(rhs Interval) TimeExpression { - return TimeExp(newBinaryOperatorExpression(t.parent, rhs, "-")) + return TimeExp(Sub(t.parent, rhs)) } //---------------------------------------------------// diff --git a/internal/jet/timestamp_expression.go b/internal/jet/timestamp_expression.go index 81eda61..f4cdd0b 100644 --- a/internal/jet/timestamp_expression.go +++ b/internal/jet/timestamp_expression.go @@ -23,43 +23,43 @@ type timestampInterfaceImpl struct { } func (t *timestampInterfaceImpl) EQ(rhs TimestampExpression) BoolExpression { - return eq(t.parent, rhs) + return Eq(t.parent, rhs) } func (t *timestampInterfaceImpl) NOT_EQ(rhs TimestampExpression) BoolExpression { - return notEq(t.parent, rhs) + return NotEq(t.parent, rhs) } func (t *timestampInterfaceImpl) IS_DISTINCT_FROM(rhs TimestampExpression) BoolExpression { - return isDistinctFrom(t.parent, rhs) + return IsDistinctFrom(t.parent, rhs) } func (t *timestampInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs TimestampExpression) BoolExpression { - return isNotDistinctFrom(t.parent, rhs) + return IsNotDistinctFrom(t.parent, rhs) } func (t *timestampInterfaceImpl) LT(rhs TimestampExpression) BoolExpression { - return lt(t.parent, rhs) + return Lt(t.parent, rhs) } func (t *timestampInterfaceImpl) LT_EQ(rhs TimestampExpression) BoolExpression { - return ltEq(t.parent, rhs) + return LtEq(t.parent, rhs) } func (t *timestampInterfaceImpl) GT(rhs TimestampExpression) BoolExpression { - return gt(t.parent, rhs) + return Gt(t.parent, rhs) } func (t *timestampInterfaceImpl) GT_EQ(rhs TimestampExpression) BoolExpression { - return gtEq(t.parent, rhs) + return GtEq(t.parent, rhs) } func (t *timestampInterfaceImpl) ADD(rhs Interval) TimestampExpression { - return TimestampExp(newBinaryOperatorExpression(t.parent, rhs, "+")) + return TimestampExp(Add(t.parent, rhs)) } func (t *timestampInterfaceImpl) SUB(rhs Interval) TimestampExpression { - return TimestampExp(newBinaryOperatorExpression(t.parent, rhs, "-")) + return TimestampExp(Sub(t.parent, rhs)) } //------------------------------------------------- diff --git a/internal/jet/timestampz_expression.go b/internal/jet/timestampz_expression.go index a9f8c9f..0112a3c 100644 --- a/internal/jet/timestampz_expression.go +++ b/internal/jet/timestampz_expression.go @@ -23,43 +23,43 @@ type timestampzInterfaceImpl struct { } func (t *timestampzInterfaceImpl) EQ(rhs TimestampzExpression) BoolExpression { - return eq(t.parent, rhs) + return Eq(t.parent, rhs) } func (t *timestampzInterfaceImpl) NOT_EQ(rhs TimestampzExpression) BoolExpression { - return notEq(t.parent, rhs) + return NotEq(t.parent, rhs) } func (t *timestampzInterfaceImpl) IS_DISTINCT_FROM(rhs TimestampzExpression) BoolExpression { - return isDistinctFrom(t.parent, rhs) + return IsDistinctFrom(t.parent, rhs) } func (t *timestampzInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs TimestampzExpression) BoolExpression { - return isNotDistinctFrom(t.parent, rhs) + return IsNotDistinctFrom(t.parent, rhs) } func (t *timestampzInterfaceImpl) LT(rhs TimestampzExpression) BoolExpression { - return lt(t.parent, rhs) + return Lt(t.parent, rhs) } func (t *timestampzInterfaceImpl) LT_EQ(rhs TimestampzExpression) BoolExpression { - return ltEq(t.parent, rhs) + return LtEq(t.parent, rhs) } func (t *timestampzInterfaceImpl) GT(rhs TimestampzExpression) BoolExpression { - return gt(t.parent, rhs) + return Gt(t.parent, rhs) } func (t *timestampzInterfaceImpl) GT_EQ(rhs TimestampzExpression) BoolExpression { - return gtEq(t.parent, rhs) + return GtEq(t.parent, rhs) } func (t *timestampzInterfaceImpl) ADD(rhs Interval) TimestampzExpression { - return TimestampzExp(newBinaryOperatorExpression(t.parent, rhs, "+")) + return TimestampzExp(Add(t.parent, rhs)) } func (t *timestampzInterfaceImpl) SUB(rhs Interval) TimestampzExpression { - return TimestampzExp(newBinaryOperatorExpression(t.parent, rhs, "-")) + return TimestampzExp(Sub(t.parent, rhs)) } //------------------------------------------------- diff --git a/internal/jet/timez_expression.go b/internal/jet/timez_expression.go index c791c62..d36ec80 100644 --- a/internal/jet/timez_expression.go +++ b/internal/jet/timez_expression.go @@ -23,43 +23,43 @@ type timezInterfaceImpl struct { } func (t *timezInterfaceImpl) EQ(rhs TimezExpression) BoolExpression { - return eq(t.parent, rhs) + return Eq(t.parent, rhs) } func (t *timezInterfaceImpl) NOT_EQ(rhs TimezExpression) BoolExpression { - return notEq(t.parent, rhs) + return NotEq(t.parent, rhs) } func (t *timezInterfaceImpl) IS_DISTINCT_FROM(rhs TimezExpression) BoolExpression { - return isDistinctFrom(t.parent, rhs) + return IsDistinctFrom(t.parent, rhs) } func (t *timezInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs TimezExpression) BoolExpression { - return isNotDistinctFrom(t.parent, rhs) + return IsNotDistinctFrom(t.parent, rhs) } func (t *timezInterfaceImpl) LT(rhs TimezExpression) BoolExpression { - return lt(t.parent, rhs) + return Lt(t.parent, rhs) } func (t *timezInterfaceImpl) LT_EQ(rhs TimezExpression) BoolExpression { - return ltEq(t.parent, rhs) + return LtEq(t.parent, rhs) } func (t *timezInterfaceImpl) GT(rhs TimezExpression) BoolExpression { - return gt(t.parent, rhs) + return Gt(t.parent, rhs) } func (t *timezInterfaceImpl) GT_EQ(rhs TimezExpression) BoolExpression { - return gtEq(t.parent, rhs) + return GtEq(t.parent, rhs) } func (t *timezInterfaceImpl) ADD(rhs Interval) TimezExpression { - return TimezExp(newBinaryOperatorExpression(t.parent, rhs, "+")) + return TimezExp(Add(t.parent, rhs)) } func (t *timezInterfaceImpl) SUB(rhs Interval) TimezExpression { - return TimezExp(newBinaryOperatorExpression(t.parent, rhs, "-")) + return TimezExp(Sub(t.parent, rhs)) } //---------------------------------------------------// diff --git a/internal/jet/utils.go b/internal/jet/utils.go index 58394f4..126ce32 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -187,3 +187,23 @@ func UnwidColumnList(columns []Column) []Column { return ret } + +// OptionalOrDefaultString will return first value from variable argument list str or +// defaultStr if variable argument list is empty +func OptionalOrDefaultString(defaultStr string, str ...string) string { + if len(str) > 0 { + return str[0] + } + + return defaultStr +} + +// OptionalOrDefaultExpression will return first value from variable argument list expression or +// defaultExpression if variable argument list is empty +func OptionalOrDefaultExpression(defaultExpression Expression, expression ...Expression) Expression { + if len(expression) > 0 { + return expression[0] + } + + return defaultExpression +} diff --git a/internal/jet/utils_test.go b/internal/jet/utils_test.go new file mode 100644 index 0000000..ce1935a --- /dev/null +++ b/internal/jet/utils_test.go @@ -0,0 +1,19 @@ +package jet + +import ( + "gotest.tools/assert" + "testing" +) + +func TestOptionalOrDefaultString(t *testing.T) { + assert.Equal(t, OptionalOrDefaultString("default"), "default") + assert.Equal(t, OptionalOrDefaultString("default", "optional"), "optional") +} + +func TestOptionalOrDefaultExpression(t *testing.T) { + defaultExpression := table2ColFloat + optionalExpression := table1Col1 + + assert.Equal(t, OptionalOrDefaultExpression(defaultExpression), defaultExpression) + assert.Equal(t, OptionalOrDefaultExpression(defaultExpression, optionalExpression), optionalExpression) +} diff --git a/postgres/columns.go b/postgres/columns.go index 3109bd3..c62f202 100644 --- a/postgres/columns.go +++ b/postgres/columns.go @@ -1,6 +1,8 @@ package postgres -import "github.com/go-jet/jet/internal/jet" +import ( + "github.com/go-jet/jet/internal/jet" +) // Column is common column interface for all types of columns. type Column = jet.ColumnExpression @@ -62,3 +64,34 @@ type ColumnTimestampz = jet.ColumnTimestampz // TimestampzColumn creates named timestamp with time zone column. var TimestampzColumn = jet.TimestampzColumn + +//------------------------------------------------------// + +// ColumnInterval is interface of PostgreSQL interval columns. +type ColumnInterval interface { + IntervalExpression + jet.Column + + From(subQuery SelectTable) ColumnInterval +} + +type intervalColumnImpl struct { + jet.ColumnExpressionImpl + intervalInterfaceImpl +} + +func (i *intervalColumnImpl) From(subQuery SelectTable) ColumnInterval { + newIntervalColumn := IntervalColumn(i.Name()) + jet.SetTableName(newIntervalColumn, i.TableName()) + jet.SetSubQuery(newIntervalColumn, subQuery) + + return newIntervalColumn +} + +// IntervalColumn creates named interval column. +func IntervalColumn(name string) ColumnInterval { + intervalColumn := &intervalColumnImpl{} + intervalColumn.ColumnExpressionImpl = jet.NewColumnImpl(name, "", intervalColumn) + intervalColumn.intervalInterfaceImpl.parent = intervalColumn + return intervalColumn +} diff --git a/postgres/columns_test.go b/postgres/columns_test.go new file mode 100644 index 0000000..b00c4c6 --- /dev/null +++ b/postgres/columns_test.go @@ -0,0 +1,20 @@ +package postgres + +import ( + "testing" +) + +func TestNewIntervalColumn(t *testing.T) { + subQuery := SELECT(Int(1)).AsTable("sub_query") + + subQueryIntervalColumn := IntervalColumn("col_interval").From(subQuery) + assertSerialize(t, subQueryIntervalColumn, `sub_query."col_interval"`) + assertSerialize(t, subQueryIntervalColumn.EQ(INTERVAL(2, HOUR, 10, MINUTE)), + `(sub_query."col_interval" = INTERVAL '2 HOUR 10 MINUTE')`) + assertProjectionSerialize(t, subQueryIntervalColumn, `sub_query."col_interval" AS "col_interval"`) + + subQueryIntervalColumn2 := table1ColInterval.From(subQuery) + assertSerialize(t, subQueryIntervalColumn2, `sub_query."table1.col_interval"`) + assertSerialize(t, subQueryIntervalColumn2.EQ(INTERVAL(1, DAY)), `(sub_query."table1.col_interval" = INTERVAL '1 DAY')`) + assertProjectionSerialize(t, subQueryIntervalColumn2, `sub_query."table1.col_interval" AS "table1.col_interval"`) +} diff --git a/postgres/expressions.go b/postgres/expressions.go index 0383eee..93072b8 100644 --- a/postgres/expressions.go +++ b/postgres/expressions.go @@ -12,6 +12,9 @@ type BoolExpression = jet.BoolExpression // StringExpression interface type StringExpression = jet.StringExpression +// NumericExpression interface +type NumericExpression = jet.NumericExpression + // IntegerExpression interface type IntegerExpression = jet.IntegerExpression diff --git a/postgres/interval.go b/postgres/interval_expression.go similarity index 52% rename from postgres/interval.go rename to postgres/interval_expression.go index 80592e8..50d29da 100644 --- a/postgres/interval.go +++ b/postgres/interval_expression.go @@ -27,39 +27,109 @@ const ( MILLENNIUM ) -type intervalExpressionImpl struct { - jet.Interval - jet.ExpressionInterfaceImpl -} - // IntervalExpression is representation of postgres INTERVAL type IntervalExpression interface { jet.IsInterval jet.Expression + + EQ(rhs IntervalExpression) BoolExpression + NOT_EQ(rhs IntervalExpression) BoolExpression + IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression + IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression + + LT(rhs IntervalExpression) BoolExpression + LT_EQ(rhs IntervalExpression) BoolExpression + GT(rhs IntervalExpression) BoolExpression + GT_EQ(rhs IntervalExpression) BoolExpression + + ADD(rhs IntervalExpression) IntervalExpression + SUB(rhs IntervalExpression) IntervalExpression + + MUL(rhs NumericExpression) IntervalExpression + DIV(rhs NumericExpression) IntervalExpression +} + +type intervalInterfaceImpl struct { + jet.IsIntervalImpl + + parent IntervalExpression +} + +func (i *intervalInterfaceImpl) EQ(rhs IntervalExpression) BoolExpression { + return jet.Eq(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) NOT_EQ(rhs IntervalExpression) BoolExpression { + return jet.NotEq(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) IS_DISTINCT_FROM(rhs IntervalExpression) BoolExpression { + return jet.IsDistinctFrom(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs IntervalExpression) BoolExpression { + return jet.IsNotDistinctFrom(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) LT(rhs IntervalExpression) BoolExpression { + return jet.Lt(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) LT_EQ(rhs IntervalExpression) BoolExpression { + return jet.LtEq(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) GT(rhs IntervalExpression) BoolExpression { + return jet.Gt(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) GT_EQ(rhs IntervalExpression) BoolExpression { + return jet.GtEq(i.parent, rhs) +} + +func (i *intervalInterfaceImpl) ADD(rhs IntervalExpression) IntervalExpression { + return IntervalExp(jet.Add(i.parent, rhs)) +} + +func (i *intervalInterfaceImpl) SUB(rhs IntervalExpression) IntervalExpression { + return IntervalExp(jet.Sub(i.parent, rhs)) +} + +func (i *intervalInterfaceImpl) MUL(rhs NumericExpression) IntervalExpression { + return IntervalExp(jet.Mul(i.parent, rhs)) +} + +func (i *intervalInterfaceImpl) DIV(rhs NumericExpression) IntervalExpression { + return IntervalExp(jet.Div(i.parent, rhs)) +} + +type intervalExpression struct { + jet.Expression + intervalInterfaceImpl } // INTERVAL creates new interval expression from the list of quantity-unit pairs. // For example: INTERVAL(1, DAY, 3, MINUTE) func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { - if len(quantityAndUnit)%2 != 0 { + quantityAndUnitLen := len(quantityAndUnit) + if quantityAndUnitLen == 0 || quantityAndUnitLen%2 != 0 { panic("jet: invalid number of quantity and unit fields") } fields := []string{} for i := 0; i < len(quantityAndUnit); i += 2 { - quantity := strconv.FormatFloat(float64(quantityAndUnit[i]), 'f', -1, 64) + quantity := strconv.FormatFloat(quantityAndUnit[i], 'f', -1, 64) unitString := unitToString(quantityAndUnit[i+1]) fields = append(fields, quantity+" "+unitString) } - intervalStr := fmt.Sprintf("'%s'", strings.Join(fields, " ")) + intervalStr := fmt.Sprintf("INTERVAL '%s'", strings.Join(fields, " ")) - newInterval := &intervalExpressionImpl{ - Interval: jet.NewInterval(jet.Raw(intervalStr)), - } + newInterval := &intervalExpression{} - newInterval.ExpressionInterfaceImpl.Parent = newInterval + newInterval.Expression = jet.Raw(intervalStr, newInterval) + newInterval.intervalInterfaceImpl.parent = newInterval return newInterval } @@ -136,13 +206,14 @@ func unitToString(unit quantityAndUnit) string { //---------------------------------------------------// type intervalWrapper struct { - jet.IsInterval + intervalInterfaceImpl Expression } func newIntervalExpressionWrap(expression Expression) IntervalExpression { - intervalWrap := intervalWrapper{Expression: expression} - return &intervalWrap + intervalWrap := &intervalWrapper{Expression: expression} + intervalWrap.intervalInterfaceImpl.parent = intervalWrap + return intervalWrap } // IntervalExp is interval expression wrapper around arbitrary expression. diff --git a/postgres/interval_test.go b/postgres/interval_expression_test.go similarity index 61% rename from postgres/interval_test.go rename to postgres/interval_expression_test.go index 785f1d5..f2fa9fe 100644 --- a/postgres/interval_test.go +++ b/postgres/interval_expression_test.go @@ -44,11 +44,12 @@ func TestINTERVALd(t *testing.T) { } func TestINTERVAL_InvalidParams(t *testing.T) { + assertPanicErr(t, func() { INTERVAL() }, "jet: invalid number of quantity and unit fields") assertPanicErr(t, func() { INTERVAL(1) }, "jet: invalid number of quantity and unit fields") assertPanicErr(t, func() { INTERVAL(1, 2) }, "jet: invalid INTERVAL unit type") } -func TestIntervalArithmetic(t *testing.T) { +func TestDateTimeIntervalArithmetic(t *testing.T) { assertSerialize(t, table2ColDate.ADD(INTERVAL(1, HOUR)), "(table2.col_date + INTERVAL '1 HOUR')") assertSerialize(t, table2ColDate.SUB(INTERVAL(1, HOUR)), "(table2.col_date - INTERVAL '1 HOUR')") assertSerialize(t, table2ColTime.ADD(INTERVAL(1, HOUR)), "(table2.col_time + INTERVAL '1 HOUR')") @@ -60,3 +61,24 @@ func TestIntervalArithmetic(t *testing.T) { assertSerialize(t, table2ColTimestampz.ADD(INTERVAL(1, HOUR)), "(table2.col_timestampz + INTERVAL '1 HOUR')") assertSerialize(t, table2ColTimestampz.SUB(INTERVAL(1, HOUR)), "(table2.col_timestampz - INTERVAL '1 HOUR')") } + +func TestIntervalExpressionMethods(t *testing.T) { + assertSerialize(t, table1ColInterval.EQ(table2ColInterval), "(table1.col_interval = table2.col_interval)") + assertSerialize(t, table1ColInterval.EQ(INTERVAL(10, SECOND)), "(table1.col_interval = INTERVAL '10 SECOND')") + assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)), "(table1.col_interval = INTERVAL '11 MINUTE')") + assertSerialize(t, table1ColInterval.EQ(INTERVALd(11*time.Minute)).EQ(Bool(false)), + "((table1.col_interval = INTERVAL '11 MINUTE') = $1)", false) + assertSerialize(t, table1ColInterval.NOT_EQ(table2ColInterval), "(table1.col_interval != table2.col_interval)") + assertSerialize(t, table1ColInterval.IS_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS DISTINCT FROM table2.col_interval)") + assertSerialize(t, table1ColInterval.IS_NOT_DISTINCT_FROM(table2ColInterval), "(table1.col_interval IS NOT DISTINCT FROM table2.col_interval)") + assertSerialize(t, table1ColInterval.LT(table2ColInterval), "(table1.col_interval < table2.col_interval)") + assertSerialize(t, table1ColInterval.LT_EQ(table2ColInterval), "(table1.col_interval <= table2.col_interval)") + assertSerialize(t, table1ColInterval.GT(table2ColInterval), "(table1.col_interval > table2.col_interval)") + assertSerialize(t, table1ColInterval.GT_EQ(table2ColInterval), "(table1.col_interval >= table2.col_interval)") + assertSerialize(t, table1ColInterval.ADD(table2ColInterval), "(table1.col_interval + table2.col_interval)") + assertSerialize(t, table1ColInterval.SUB(table2ColInterval), "(table1.col_interval - table2.col_interval)") + assertSerialize(t, table1ColInterval.MUL(table2ColInt), "(table1.col_interval * table2.col_int)") + assertSerialize(t, table1ColInterval.MUL(table2ColFloat), "(table1.col_interval * table2.col_float)") + assertSerialize(t, table1ColInterval.DIV(table2ColInt), "(table1.col_interval / table2.col_int)") + assertSerialize(t, table1ColInterval.DIV(table2ColFloat), "(table1.col_interval / table2.col_float)") +} diff --git a/postgres/utils_test.go b/postgres/utils_test.go index 4a80954..38c429c 100644 --- a/postgres/utils_test.go +++ b/postgres/utils_test.go @@ -17,6 +17,7 @@ var table1ColTimestamp = TimestampColumn("col_timestamp") var table1ColTimestampz = TimestampzColumn("col_timestampz") var table1ColBool = BoolColumn("col_bool") var table1ColDate = DateColumn("col_date") +var table1ColInterval = IntervalColumn("col_interval") var table1 = NewTable( "db", @@ -31,6 +32,7 @@ var table1 = NewTable( table1ColDate, table1ColTimestamp, table1ColTimestampz, + table1ColInterval, ) var table2Col3 = IntegerColumn("col3") @@ -44,6 +46,7 @@ var table2ColTimez = TimezColumn("col_timez") var table2ColTimestamp = TimestampColumn("col_timestamp") var table2ColTimestampz = TimestampzColumn("col_timestampz") var table2ColDate = DateColumn("col_date") +var table2ColInterval = IntervalColumn("col_interval") var table2 = NewTable( "db", @@ -59,6 +62,7 @@ var table2 = NewTable( table2ColDate, table2ColTimestamp, table2ColTimestampz, + table2ColInterval, ) var table3Col1 = IntegerColumn("col1") diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 7b9d1d6..fbcffd8 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -71,6 +71,152 @@ func TestAllTypesInsertQuery(t *testing.T) { assert.DeepEqual(t, dest[1], allTypesRow1) } +func TestAllTypesFromSubQuery(t *testing.T) { + + subQuery := SELECT(AllTypes.AllColumns). + FROM(AllTypes). + AsTable("allTypesSubQuery") + + mainQuery := SELECT(subQuery.AllColumns()). + FROM(subQuery). + LIMIT(2) + + assert.Equal(t, mainQuery.DebugSql(), ` +SELECT "allTypesSubQuery"."all_types.small_int_ptr" AS "all_types.small_int_ptr", + "allTypesSubQuery"."all_types.small_int" AS "all_types.small_int", + "allTypesSubQuery"."all_types.integer_ptr" AS "all_types.integer_ptr", + "allTypesSubQuery"."all_types.integer" AS "all_types.integer", + "allTypesSubQuery"."all_types.big_int_ptr" AS "all_types.big_int_ptr", + "allTypesSubQuery"."all_types.big_int" AS "all_types.big_int", + "allTypesSubQuery"."all_types.decimal_ptr" AS "all_types.decimal_ptr", + "allTypesSubQuery"."all_types.decimal" AS "all_types.decimal", + "allTypesSubQuery"."all_types.numeric_ptr" AS "all_types.numeric_ptr", + "allTypesSubQuery"."all_types.numeric" AS "all_types.numeric", + "allTypesSubQuery"."all_types.real_ptr" AS "all_types.real_ptr", + "allTypesSubQuery"."all_types.real" AS "all_types.real", + "allTypesSubQuery"."all_types.double_precision_ptr" AS "all_types.double_precision_ptr", + "allTypesSubQuery"."all_types.double_precision" AS "all_types.double_precision", + "allTypesSubQuery"."all_types.smallserial" AS "all_types.smallserial", + "allTypesSubQuery"."all_types.serial" AS "all_types.serial", + "allTypesSubQuery"."all_types.bigserial" AS "all_types.bigserial", + "allTypesSubQuery"."all_types.var_char_ptr" AS "all_types.var_char_ptr", + "allTypesSubQuery"."all_types.var_char" AS "all_types.var_char", + "allTypesSubQuery"."all_types.char_ptr" AS "all_types.char_ptr", + "allTypesSubQuery"."all_types.char" AS "all_types.char", + "allTypesSubQuery"."all_types.text_ptr" AS "all_types.text_ptr", + "allTypesSubQuery"."all_types.text" AS "all_types.text", + "allTypesSubQuery"."all_types.bytea_ptr" AS "all_types.bytea_ptr", + "allTypesSubQuery"."all_types.bytea" AS "all_types.bytea", + "allTypesSubQuery"."all_types.timestampz_ptr" AS "all_types.timestampz_ptr", + "allTypesSubQuery"."all_types.timestampz" AS "all_types.timestampz", + "allTypesSubQuery"."all_types.timestamp_ptr" AS "all_types.timestamp_ptr", + "allTypesSubQuery"."all_types.timestamp" AS "all_types.timestamp", + "allTypesSubQuery"."all_types.date_ptr" AS "all_types.date_ptr", + "allTypesSubQuery"."all_types.date" AS "all_types.date", + "allTypesSubQuery"."all_types.timez_ptr" AS "all_types.timez_ptr", + "allTypesSubQuery"."all_types.timez" AS "all_types.timez", + "allTypesSubQuery"."all_types.time_ptr" AS "all_types.time_ptr", + "allTypesSubQuery"."all_types.time" AS "all_types.time", + "allTypesSubQuery"."all_types.interval_ptr" AS "all_types.interval_ptr", + "allTypesSubQuery"."all_types.interval" AS "all_types.interval", + "allTypesSubQuery"."all_types.boolean_ptr" AS "all_types.boolean_ptr", + "allTypesSubQuery"."all_types.boolean" AS "all_types.boolean", + "allTypesSubQuery"."all_types.point_ptr" AS "all_types.point_ptr", + "allTypesSubQuery"."all_types.bit_ptr" AS "all_types.bit_ptr", + "allTypesSubQuery"."all_types.bit" AS "all_types.bit", + "allTypesSubQuery"."all_types.bit_varying_ptr" AS "all_types.bit_varying_ptr", + "allTypesSubQuery"."all_types.bit_varying" AS "all_types.bit_varying", + "allTypesSubQuery"."all_types.tsvector_ptr" AS "all_types.tsvector_ptr", + "allTypesSubQuery"."all_types.tsvector" AS "all_types.tsvector", + "allTypesSubQuery"."all_types.uuid_ptr" AS "all_types.uuid_ptr", + "allTypesSubQuery"."all_types.uuid" AS "all_types.uuid", + "allTypesSubQuery"."all_types.xml_ptr" AS "all_types.xml_ptr", + "allTypesSubQuery"."all_types.xml" AS "all_types.xml", + "allTypesSubQuery"."all_types.json_ptr" AS "all_types.json_ptr", + "allTypesSubQuery"."all_types.json" AS "all_types.json", + "allTypesSubQuery"."all_types.jsonb_ptr" AS "all_types.jsonb_ptr", + "allTypesSubQuery"."all_types.jsonb" AS "all_types.jsonb", + "allTypesSubQuery"."all_types.integer_array_ptr" AS "all_types.integer_array_ptr", + "allTypesSubQuery"."all_types.integer_array" AS "all_types.integer_array", + "allTypesSubQuery"."all_types.text_array_ptr" AS "all_types.text_array_ptr", + "allTypesSubQuery"."all_types.text_array" AS "all_types.text_array", + "allTypesSubQuery"."all_types.jsonb_array" AS "all_types.jsonb_array", + "allTypesSubQuery"."all_types.text_multi_dim_array_ptr" AS "all_types.text_multi_dim_array_ptr", + "allTypesSubQuery"."all_types.text_multi_dim_array" AS "all_types.text_multi_dim_array" +FROM ( + SELECT all_types.small_int_ptr AS "all_types.small_int_ptr", + all_types.small_int AS "all_types.small_int", + all_types.integer_ptr AS "all_types.integer_ptr", + all_types.integer AS "all_types.integer", + all_types.big_int_ptr AS "all_types.big_int_ptr", + all_types.big_int AS "all_types.big_int", + all_types.decimal_ptr AS "all_types.decimal_ptr", + all_types.decimal AS "all_types.decimal", + all_types.numeric_ptr AS "all_types.numeric_ptr", + all_types.numeric AS "all_types.numeric", + all_types.real_ptr AS "all_types.real_ptr", + all_types.real AS "all_types.real", + all_types.double_precision_ptr AS "all_types.double_precision_ptr", + all_types.double_precision AS "all_types.double_precision", + all_types.smallserial AS "all_types.smallserial", + all_types.serial AS "all_types.serial", + all_types.bigserial AS "all_types.bigserial", + all_types.var_char_ptr AS "all_types.var_char_ptr", + all_types.var_char AS "all_types.var_char", + all_types.char_ptr AS "all_types.char_ptr", + all_types.char AS "all_types.char", + all_types.text_ptr AS "all_types.text_ptr", + all_types.text AS "all_types.text", + all_types.bytea_ptr AS "all_types.bytea_ptr", + all_types.bytea AS "all_types.bytea", + all_types.timestampz_ptr AS "all_types.timestampz_ptr", + all_types.timestampz AS "all_types.timestampz", + all_types.timestamp_ptr AS "all_types.timestamp_ptr", + all_types.timestamp AS "all_types.timestamp", + all_types.date_ptr AS "all_types.date_ptr", + all_types.date AS "all_types.date", + all_types.timez_ptr AS "all_types.timez_ptr", + all_types.timez AS "all_types.timez", + all_types.time_ptr AS "all_types.time_ptr", + all_types.time AS "all_types.time", + all_types.interval_ptr AS "all_types.interval_ptr", + all_types.interval AS "all_types.interval", + all_types.boolean_ptr AS "all_types.boolean_ptr", + all_types.boolean AS "all_types.boolean", + all_types.point_ptr AS "all_types.point_ptr", + all_types.bit_ptr AS "all_types.bit_ptr", + all_types.bit AS "all_types.bit", + all_types.bit_varying_ptr AS "all_types.bit_varying_ptr", + all_types.bit_varying AS "all_types.bit_varying", + all_types.tsvector_ptr AS "all_types.tsvector_ptr", + all_types.tsvector AS "all_types.tsvector", + all_types.uuid_ptr AS "all_types.uuid_ptr", + all_types.uuid AS "all_types.uuid", + all_types.xml_ptr AS "all_types.xml_ptr", + all_types.xml AS "all_types.xml", + all_types.json_ptr AS "all_types.json_ptr", + all_types.json AS "all_types.json", + all_types.jsonb_ptr AS "all_types.jsonb_ptr", + all_types.jsonb AS "all_types.jsonb", + all_types.integer_array_ptr AS "all_types.integer_array_ptr", + all_types.integer_array AS "all_types.integer_array", + all_types.text_array_ptr AS "all_types.text_array_ptr", + all_types.text_array AS "all_types.text_array", + all_types.jsonb_array AS "all_types.jsonb_array", + all_types.text_multi_dim_array_ptr AS "all_types.text_multi_dim_array_ptr", + all_types.text_multi_dim_array AS "all_types.text_multi_dim_array" + FROM test_sample.all_types + ) AS "allTypesSubQuery" +LIMIT 2; +`) + + dest := []model.AllTypes{} + err := mainQuery.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 2) +} + func TestExpressionOperators(t *testing.T) { query := AllTypes.SELECT( AllTypes.Integer.IS_NULL().AS("result.is_null"), @@ -671,7 +817,20 @@ func TestInterval(t *testing.T) { INTERVALd(1*time.Hour), INTERVALd(24*time.Hour), INTERVALd(24*time.Hour+2*time.Hour+3*time.Minute+4*time.Second+5*time.Microsecond), - ) + + AllTypes.Interval.EQ(INTERVAL(2, HOUR, 20, MINUTE)).EQ(Bool(true)), + AllTypes.IntervalPtr.NOT_EQ(INTERVAL(2, HOUR, 20, MINUTE)).EQ(Bool(false)), + AllTypes.Interval.IS_DISTINCT_FROM(INTERVAL(2, HOUR, 20, MINUTE)).EQ(AllTypes.Boolean), + AllTypes.IntervalPtr.IS_NOT_DISTINCT_FROM(INTERVALd(10*time.Microsecond)).EQ(AllTypes.Boolean), + AllTypes.Interval.LT(AllTypes.IntervalPtr).EQ(AllTypes.BooleanPtr), + AllTypes.Interval.LT_EQ(AllTypes.IntervalPtr).EQ(AllTypes.BooleanPtr), + AllTypes.Interval.GT(AllTypes.IntervalPtr).EQ(AllTypes.BooleanPtr), + AllTypes.Interval.GT_EQ(AllTypes.IntervalPtr).EQ(AllTypes.BooleanPtr), + AllTypes.Interval.ADD(AllTypes.IntervalPtr).EQ(INTERVALd(17*time.Second)), + AllTypes.Interval.SUB(AllTypes.IntervalPtr).EQ(INTERVAL(100, MICROSECOND)), + AllTypes.IntervalPtr.MUL(Int(11)).EQ(AllTypes.Interval), + AllTypes.IntervalPtr.DIV(Float(22.222)).EQ(AllTypes.IntervalPtr), + ).FROM(AllTypes) //fmt.Println(stmt.DebugSql()) diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index b8a9ba4..5100801 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -15,7 +15,6 @@ import ( ) func TestGeneratedModel(t *testing.T) { - actor := model.Actor{} assert.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32") @@ -275,3 +274,345 @@ func newActorInfoTable() *ActorInfoTable { } } ` + +func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { + enumDir := testRoot + ".gentestdata/jetdb/test_sample/enum/" + modelDir := testRoot + ".gentestdata/jetdb/test_sample/model/" + tableDir := testRoot + ".gentestdata/jetdb/test_sample/table/" + + enumFiles, err := ioutil.ReadDir(enumDir) + assert.NilError(t, err) + + testutils.AssertFileNamesEqual(t, enumFiles, "mood.go") + testutils.AssertFileContent(t, enumDir+"mood.go", "\npackage enum", moodEnumContent) + + modelFiles, err := ioutil.ReadDir(modelDir) + assert.NilError(t, err) + + testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go", + "mood.go", "person.go", "person_phone.go", "weird_names_table.go") + + testutils.AssertFileContent(t, modelDir+"all_types.go", "\npackage model", allTypesModelContent) + + tableFiles, err := ioutil.ReadDir(tableDir) + assert.NilError(t, err) + + testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go", + "person.go", "person_phone.go", "weird_names_table.go") + + testutils.AssertFileContent(t, tableDir+"all_types.go", "\npackage table", allTypesTableContent) +} + +var moodEnumContent = ` +package enum + +import "github.com/go-jet/jet/postgres" + +var Mood = &struct { + Sad postgres.StringExpression + Ok postgres.StringExpression + Happy postgres.StringExpression +}{ + Sad: postgres.NewEnumValue("sad"), + Ok: postgres.NewEnumValue("ok"), + Happy: postgres.NewEnumValue("happy"), +} +` + +var allTypesModelContent = ` +package model + +import ( + "github.com/google/uuid" + "time" +) + +type AllTypes struct { + SmallIntPtr *int16 + SmallInt int16 + IntegerPtr *int32 + Integer int32 + BigIntPtr *int64 + BigInt int64 + DecimalPtr *float64 + Decimal float64 + NumericPtr *float64 + Numeric float64 + RealPtr *float32 + Real float32 + DoublePrecisionPtr *float64 + DoublePrecision float64 + Smallserial int16 + Serial int32 + Bigserial int64 + VarCharPtr *string + VarChar string + CharPtr *string + Char string + TextPtr *string + Text string + ByteaPtr *[]byte + Bytea []byte + TimestampzPtr *time.Time + Timestampz time.Time + TimestampPtr *time.Time + Timestamp time.Time + DatePtr *time.Time + Date time.Time + TimezPtr *time.Time + Timez time.Time + TimePtr *time.Time + Time time.Time + IntervalPtr *string + Interval string + BooleanPtr *bool + Boolean bool + PointPtr *string + BitPtr *string + Bit string + BitVaryingPtr *string + BitVarying string + TsvectorPtr *string + Tsvector string + UUIDPtr *uuid.UUID + UUID uuid.UUID + XMLPtr *string + XML string + JSONPtr *string + JSON string + JsonbPtr *string + Jsonb string + IntegerArrayPtr *string + IntegerArray string + TextArrayPtr *string + TextArray string + JsonbArray string + TextMultiDimArrayPtr *string + TextMultiDimArray string +} +` + +var allTypesTableContent = ` +package table + +import ( + "github.com/go-jet/jet/postgres" +) + +var AllTypes = newAllTypesTable() + +type AllTypesTable struct { + postgres.Table + + //Columns + SmallIntPtr postgres.ColumnInteger + SmallInt postgres.ColumnInteger + IntegerPtr postgres.ColumnInteger + Integer postgres.ColumnInteger + BigIntPtr postgres.ColumnInteger + BigInt postgres.ColumnInteger + DecimalPtr postgres.ColumnFloat + Decimal postgres.ColumnFloat + NumericPtr postgres.ColumnFloat + Numeric postgres.ColumnFloat + RealPtr postgres.ColumnFloat + Real postgres.ColumnFloat + DoublePrecisionPtr postgres.ColumnFloat + DoublePrecision postgres.ColumnFloat + Smallserial postgres.ColumnInteger + Serial postgres.ColumnInteger + Bigserial postgres.ColumnInteger + VarCharPtr postgres.ColumnString + VarChar postgres.ColumnString + CharPtr postgres.ColumnString + Char postgres.ColumnString + TextPtr postgres.ColumnString + Text postgres.ColumnString + ByteaPtr postgres.ColumnString + Bytea postgres.ColumnString + TimestampzPtr postgres.ColumnTimestampz + Timestampz postgres.ColumnTimestampz + TimestampPtr postgres.ColumnTimestamp + Timestamp postgres.ColumnTimestamp + DatePtr postgres.ColumnDate + Date postgres.ColumnDate + TimezPtr postgres.ColumnTimez + Timez postgres.ColumnTimez + TimePtr postgres.ColumnTime + Time postgres.ColumnTime + IntervalPtr postgres.ColumnInterval + Interval postgres.ColumnInterval + BooleanPtr postgres.ColumnBool + Boolean postgres.ColumnBool + PointPtr postgres.ColumnString + BitPtr postgres.ColumnString + Bit postgres.ColumnString + BitVaryingPtr postgres.ColumnString + BitVarying postgres.ColumnString + TsvectorPtr postgres.ColumnString + Tsvector postgres.ColumnString + UUIDPtr postgres.ColumnString + UUID postgres.ColumnString + XMLPtr postgres.ColumnString + XML postgres.ColumnString + JSONPtr postgres.ColumnString + JSON postgres.ColumnString + JsonbPtr postgres.ColumnString + Jsonb postgres.ColumnString + IntegerArrayPtr postgres.ColumnString + IntegerArray postgres.ColumnString + TextArrayPtr postgres.ColumnString + TextArray postgres.ColumnString + JsonbArray postgres.ColumnString + TextMultiDimArrayPtr postgres.ColumnString + TextMultiDimArray postgres.ColumnString + + AllColumns postgres.ColumnList + MutableColumns postgres.ColumnList +} + +// creates new AllTypesTable with assigned alias +func (a *AllTypesTable) AS(alias string) *AllTypesTable { + aliasTable := newAllTypesTable() + + aliasTable.Table.AS(alias) + + return aliasTable +} + +func newAllTypesTable() *AllTypesTable { + var ( + SmallIntPtrColumn = postgres.IntegerColumn("small_int_ptr") + SmallIntColumn = postgres.IntegerColumn("small_int") + IntegerPtrColumn = postgres.IntegerColumn("integer_ptr") + IntegerColumn = postgres.IntegerColumn("integer") + BigIntPtrColumn = postgres.IntegerColumn("big_int_ptr") + BigIntColumn = postgres.IntegerColumn("big_int") + DecimalPtrColumn = postgres.FloatColumn("decimal_ptr") + DecimalColumn = postgres.FloatColumn("decimal") + NumericPtrColumn = postgres.FloatColumn("numeric_ptr") + NumericColumn = postgres.FloatColumn("numeric") + RealPtrColumn = postgres.FloatColumn("real_ptr") + RealColumn = postgres.FloatColumn("real") + DoublePrecisionPtrColumn = postgres.FloatColumn("double_precision_ptr") + DoublePrecisionColumn = postgres.FloatColumn("double_precision") + SmallserialColumn = postgres.IntegerColumn("smallserial") + SerialColumn = postgres.IntegerColumn("serial") + BigserialColumn = postgres.IntegerColumn("bigserial") + VarCharPtrColumn = postgres.StringColumn("var_char_ptr") + VarCharColumn = postgres.StringColumn("var_char") + CharPtrColumn = postgres.StringColumn("char_ptr") + CharColumn = postgres.StringColumn("char") + TextPtrColumn = postgres.StringColumn("text_ptr") + TextColumn = postgres.StringColumn("text") + ByteaPtrColumn = postgres.StringColumn("bytea_ptr") + ByteaColumn = postgres.StringColumn("bytea") + TimestampzPtrColumn = postgres.TimestampzColumn("timestampz_ptr") + TimestampzColumn = postgres.TimestampzColumn("timestampz") + TimestampPtrColumn = postgres.TimestampColumn("timestamp_ptr") + TimestampColumn = postgres.TimestampColumn("timestamp") + DatePtrColumn = postgres.DateColumn("date_ptr") + DateColumn = postgres.DateColumn("date") + TimezPtrColumn = postgres.TimezColumn("timez_ptr") + TimezColumn = postgres.TimezColumn("timez") + TimePtrColumn = postgres.TimeColumn("time_ptr") + TimeColumn = postgres.TimeColumn("time") + IntervalPtrColumn = postgres.IntervalColumn("interval_ptr") + IntervalColumn = postgres.IntervalColumn("interval") + BooleanPtrColumn = postgres.BoolColumn("boolean_ptr") + BooleanColumn = postgres.BoolColumn("boolean") + PointPtrColumn = postgres.StringColumn("point_ptr") + BitPtrColumn = postgres.StringColumn("bit_ptr") + BitColumn = postgres.StringColumn("bit") + BitVaryingPtrColumn = postgres.StringColumn("bit_varying_ptr") + BitVaryingColumn = postgres.StringColumn("bit_varying") + TsvectorPtrColumn = postgres.StringColumn("tsvector_ptr") + TsvectorColumn = postgres.StringColumn("tsvector") + UUIDPtrColumn = postgres.StringColumn("uuid_ptr") + UUIDColumn = postgres.StringColumn("uuid") + XMLPtrColumn = postgres.StringColumn("xml_ptr") + XMLColumn = postgres.StringColumn("xml") + JSONPtrColumn = postgres.StringColumn("json_ptr") + JSONColumn = postgres.StringColumn("json") + JsonbPtrColumn = postgres.StringColumn("jsonb_ptr") + JsonbColumn = postgres.StringColumn("jsonb") + IntegerArrayPtrColumn = postgres.StringColumn("integer_array_ptr") + IntegerArrayColumn = postgres.StringColumn("integer_array") + TextArrayPtrColumn = postgres.StringColumn("text_array_ptr") + TextArrayColumn = postgres.StringColumn("text_array") + JsonbArrayColumn = postgres.StringColumn("jsonb_array") + TextMultiDimArrayPtrColumn = postgres.StringColumn("text_multi_dim_array_ptr") + TextMultiDimArrayColumn = postgres.StringColumn("text_multi_dim_array") + ) + + return &AllTypesTable{ + Table: postgres.NewTable("test_sample", "all_types", SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn), + + //Columns + SmallIntPtr: SmallIntPtrColumn, + SmallInt: SmallIntColumn, + IntegerPtr: IntegerPtrColumn, + Integer: IntegerColumn, + BigIntPtr: BigIntPtrColumn, + BigInt: BigIntColumn, + DecimalPtr: DecimalPtrColumn, + Decimal: DecimalColumn, + NumericPtr: NumericPtrColumn, + Numeric: NumericColumn, + RealPtr: RealPtrColumn, + Real: RealColumn, + DoublePrecisionPtr: DoublePrecisionPtrColumn, + DoublePrecision: DoublePrecisionColumn, + Smallserial: SmallserialColumn, + Serial: SerialColumn, + Bigserial: BigserialColumn, + VarCharPtr: VarCharPtrColumn, + VarChar: VarCharColumn, + CharPtr: CharPtrColumn, + Char: CharColumn, + TextPtr: TextPtrColumn, + Text: TextColumn, + ByteaPtr: ByteaPtrColumn, + Bytea: ByteaColumn, + TimestampzPtr: TimestampzPtrColumn, + Timestampz: TimestampzColumn, + TimestampPtr: TimestampPtrColumn, + Timestamp: TimestampColumn, + DatePtr: DatePtrColumn, + Date: DateColumn, + TimezPtr: TimezPtrColumn, + Timez: TimezColumn, + TimePtr: TimePtrColumn, + Time: TimeColumn, + IntervalPtr: IntervalPtrColumn, + Interval: IntervalColumn, + BooleanPtr: BooleanPtrColumn, + Boolean: BooleanColumn, + PointPtr: PointPtrColumn, + BitPtr: BitPtrColumn, + Bit: BitColumn, + BitVaryingPtr: BitVaryingPtrColumn, + BitVarying: BitVaryingColumn, + TsvectorPtr: TsvectorPtrColumn, + Tsvector: TsvectorColumn, + UUIDPtr: UUIDPtrColumn, + UUID: UUIDColumn, + XMLPtr: XMLPtrColumn, + XML: XMLColumn, + JSONPtr: JSONPtrColumn, + JSON: JSONColumn, + JsonbPtr: JsonbPtrColumn, + Jsonb: JsonbColumn, + IntegerArrayPtr: IntegerArrayPtrColumn, + IntegerArray: IntegerArrayColumn, + TextArrayPtr: TextArrayPtrColumn, + TextArray: TextArrayColumn, + JsonbArray: JsonbArrayColumn, + TextMultiDimArrayPtr: TextMultiDimArrayPtrColumn, + TextMultiDimArray: TextMultiDimArrayColumn, + + AllColumns: postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn}, + MutableColumns: postgres.ColumnList{SmallIntPtrColumn, SmallIntColumn, IntegerPtrColumn, IntegerColumn, BigIntPtrColumn, BigIntColumn, DecimalPtrColumn, DecimalColumn, NumericPtrColumn, NumericColumn, RealPtrColumn, RealColumn, DoublePrecisionPtrColumn, DoublePrecisionColumn, SmallserialColumn, SerialColumn, BigserialColumn, VarCharPtrColumn, VarCharColumn, CharPtrColumn, CharColumn, TextPtrColumn, TextColumn, ByteaPtrColumn, ByteaColumn, TimestampzPtrColumn, TimestampzColumn, TimestampPtrColumn, TimestampColumn, DatePtrColumn, DateColumn, TimezPtrColumn, TimezColumn, TimePtrColumn, TimeColumn, IntervalPtrColumn, IntervalColumn, BooleanPtrColumn, BooleanColumn, PointPtrColumn, BitPtrColumn, BitColumn, BitVaryingPtrColumn, BitVaryingColumn, TsvectorPtrColumn, TsvectorColumn, UUIDPtrColumn, UUIDColumn, XMLPtrColumn, XMLColumn, JSONPtrColumn, JSONColumn, JsonbPtrColumn, JsonbColumn, IntegerArrayPtrColumn, IntegerArrayColumn, TextArrayPtrColumn, TextArrayColumn, JsonbArrayColumn, TextMultiDimArrayPtrColumn, TextMultiDimArrayColumn}, + } +} +` diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index dda4fe7..49ab868 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -6,14 +6,19 @@ import ( _ "github.com/lib/pq" "github.com/pkg/profile" "os" + "os/exec" + "strings" "testing" ) var db *sql.DB +var testRoot string func TestMain(m *testing.M) { defer profile.Start().Stop() + setTestRoot() + var err error db, err = sql.Open("postgres", dbconfig.PostgresConnectString) if err != nil { @@ -25,3 +30,13 @@ func TestMain(m *testing.M) { os.Exit(ret) } + +func setTestRoot() { + cmd := exec.Command("git", "rev-parse", "--show-toplevel") + byteArr, err := cmd.Output() + if err != nil { + panic(err) + } + + testRoot = strings.TrimSpace(string(byteArr)) + "/tests/" +} From 3efbb0ccd985f9edb3855ea0b862705c28324e32 Mon Sep 17 00:00:00 2001 From: go-jet Date: Tue, 11 Feb 2020 10:25:13 +0100 Subject: [PATCH 14/19] Replace gotest.tools with github.com/stretchr/testify. --- .circleci/config.yml | 4 +- README.md | 2 +- internal/3rdparty/snaker/snaker_test.go | 2 +- internal/jet/clause_test.go | 2 +- internal/jet/sql_builder_test.go | 4 +- internal/jet/table_test.go | 2 +- internal/jet/testutils.go | 14 +- internal/jet/utils_test.go | 2 +- internal/testutils/test_utils.go | 44 +++-- internal/utils/utils_test.go | 2 +- mysql/insert_statement_test.go | 2 +- postgres/insert_statement_test.go | 2 +- qrm/internal/null_types_test.go | 48 ++--- qrm/utill_test.go | 28 +-- tests/mysql/alltypes_test.go | 38 ++-- tests/mysql/cast_test.go | 6 +- tests/mysql/delete_test.go | 2 +- tests/mysql/generator_test.go | 22 +-- tests/mysql/insert_test.go | 34 ++-- tests/mysql/lock_test.go | 8 +- tests/mysql/select_test.go | 48 ++--- tests/mysql/update_test.go | 10 +- tests/postgres/alltypes_test.go | 65 ++++--- tests/postgres/chinook_db_test.go | 22 +-- tests/postgres/delete_test.go | 8 +- tests/postgres/generator_test.go | 34 ++-- tests/postgres/insert_test.go | 20 +- tests/postgres/lock_test.go | 10 +- tests/postgres/northwind_test.go | 4 +- tests/postgres/sample_test.go | 26 +-- tests/postgres/scan_test.go | 244 ++++++++++++------------ tests/postgres/select_test.go | 142 +++++++------- tests/postgres/update_test.go | 12 +- tests/postgres/util_test.go | 6 +- 34 files changed, 462 insertions(+), 457 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index d3ea5d0..4907252 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -46,7 +46,7 @@ jobs: go get github.com/go-sql-driver/mysql go get github.com/pkg/profile - go get gotest.tools/assert + go get github.com/stretchr/testify/assert go get github.com/davecgh/go-spew/spew go get github.com/jstemmer/go-junit-report @@ -142,7 +142,7 @@ jobs: go get github.com/go-sql-driver/mysql go get github.com/pkg/profile - go get gotest.tools/assert + go get github.com/stretchr/testify/assert go get github.com/davecgh/go-spew/spew go get github.com/jstemmer/go-junit-report diff --git a/README.md b/README.md index 76d36be..721219a 100644 --- a/README.md +++ b/README.md @@ -560,7 +560,7 @@ At the moment Jet dependence only of: To run the tests, additional dependencies are required: - `github.com/pkg/profile` -- `gotest.tools/assert` +- `github.com/stretchr/testify` ## Versioning diff --git a/internal/3rdparty/snaker/snaker_test.go b/internal/3rdparty/snaker/snaker_test.go index b3704ea..83ae867 100644 --- a/internal/3rdparty/snaker/snaker_test.go +++ b/internal/3rdparty/snaker/snaker_test.go @@ -1,7 +1,7 @@ package snaker import ( - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" ) diff --git a/internal/jet/clause_test.go b/internal/jet/clause_test.go index 14ced64..9a86597 100644 --- a/internal/jet/clause_test.go +++ b/internal/jet/clause_test.go @@ -1,7 +1,7 @@ package jet import ( - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" ) diff --git a/internal/jet/sql_builder_test.go b/internal/jet/sql_builder_test.go index dc4a476..f7b5ade 100644 --- a/internal/jet/sql_builder_test.go +++ b/internal/jet/sql_builder_test.go @@ -2,7 +2,7 @@ package jet import ( "github.com/google/uuid" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" "time" ) @@ -30,7 +30,7 @@ func TestArgToString(t *testing.T) { assert.Equal(t, argToString(uuid.MustParse("b68dbff4-a87d-11e9-a7f2-98ded00c39c6")), "'b68dbff4-a87d-11e9-a7f2-98ded00c39c6'") time, err := time.Parse("Mon Jan 2 15:04:05 -0700 MST 2006", "Mon Jan 2 15:04:05 -0700 MST 2006") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, argToString(time), "'2006-01-02 15:04:05-07:00'") func() { diff --git a/internal/jet/table_test.go b/internal/jet/table_test.go index 30182bc..d899a8f 100644 --- a/internal/jet/table_test.go +++ b/internal/jet/table_test.go @@ -1,7 +1,7 @@ package jet import ( - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" ) diff --git a/internal/jet/testutils.go b/internal/jet/testutils.go index 545f12c..1d5009e 100644 --- a/internal/jet/testutils.go +++ b/internal/jet/testutils.go @@ -1,7 +1,7 @@ package jet import ( - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "strconv" "testing" ) @@ -56,8 +56,8 @@ func assertClauseSerialize(t *testing.T, clause Serializer, query string, args . //fmt.Println(out.Buff.String()) - assert.DeepEqual(t, out.Buff.String(), query) - assert.DeepEqual(t, out.Args, args) + assert.Equal(t, out.Buff.String(), query) + assert.Equal(t, out.Args, args) } func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) { @@ -76,14 +76,14 @@ func assertClauseDebugSerialize(t *testing.T, clause Serializer, query string, a //fmt.Println(out.Buff.String()) - assert.DeepEqual(t, out.Buff.String(), query) - assert.DeepEqual(t, out.Args, args) + assert.Equal(t, out.Buff.String(), query) + assert.Equal(t, out.Args, args) } func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) { out := SQLBuilder{Dialect: defaultDialect} projection.serializeForProjection(SelectStatementType, &out) - assert.DeepEqual(t, out.Buff.String(), query) - assert.DeepEqual(t, out.Args, args) + assert.Equal(t, out.Buff.String(), query) + assert.Equal(t, out.Args, args) } diff --git a/internal/jet/utils_test.go b/internal/jet/utils_test.go index ce1935a..e13d7ff 100644 --- a/internal/jet/utils_test.go +++ b/internal/jet/utils_test.go @@ -1,7 +1,7 @@ package jet import ( - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" ) diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index c3d3ff0..5719e75 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -7,21 +7,23 @@ import ( "github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/qrm" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "io/ioutil" "os" "path/filepath" "runtime" "testing" + + "github.com/google/go-cmp/cmp" ) // AssertExec assert statement execution for successful execution and number of rows affected func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) { res, err := stmt.Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) rows, err := res.RowsAffected() - assert.NilError(t, err) + assert.NoError(t, err) if len(rowsAffected) > 0 { assert.Equal(t, rows, rowsAffected[0]) @@ -49,7 +51,7 @@ func PrintJson(v interface{}) { // AssertJSON check if data json output is the same as expectedJSON func AssertJSON(t *testing.T, data interface{}, expectedJSON string) { jsonData, err := json.MarshalIndent(data, "", "\t") - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, "\n"+string(jsonData)+"\n", expectedJSON) } @@ -69,17 +71,17 @@ func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) { filePath := getFullPath(testRelativePath) fileJSONData, err := ioutil.ReadFile(filePath) - assert.NilError(t, err) + assert.NoError(t, err) if runtime.GOOS == "windows" { fileJSONData = bytes.Replace(fileJSONData, []byte("\r\n"), []byte("\n"), -1) } jsonData, err := json.MarshalIndent(data, "", "\t") - assert.NilError(t, err) + assert.NoError(t, err) - assert.Assert(t, string(fileJSONData) == string(jsonData)) - //assert.DeepEqual(t, string(fileJSONData), string(jsonData)) + assert.True(t, string(fileJSONData) == string(jsonData)) + //AssertDeepEqual(t, string(fileJSONData), string(jsonData)) } // AssertStatementSql check if statement Sql() is the same as expectedQuery and expectedArgs @@ -90,7 +92,7 @@ func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, if len(expectedArgs) == 0 { return } - assert.DeepEqual(t, args, expectedArgs) + AssertDeepEqual(t, args, expectedArgs) } // AssertStatementSqlErr checks if statement Sql() panics with errorStr @@ -108,7 +110,7 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st _, args := query.Sql() if len(expectedArgs) > 0 { - assert.DeepEqual(t, args, expectedArgs) + AssertDeepEqual(t, args, expectedArgs) } debuqSql := query.DebugSql() @@ -122,10 +124,10 @@ func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Seriali //fmt.Println(out.Buff.String()) - assert.DeepEqual(t, out.Buff.String(), query) + AssertDeepEqual(t, out.Buff.String(), query) if len(args) > 0 { - assert.DeepEqual(t, out.Args, args) + AssertDeepEqual(t, out.Args, args) } } @@ -134,10 +136,10 @@ func AssertDebugClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Se out := jet.SQLBuilder{Dialect: dialect, Debug: true} jet.Serialize(clause, jet.SelectStatementType, &out) - assert.DeepEqual(t, out.Buff.String(), query) + AssertDeepEqual(t, out.Buff.String(), query) if len(args) > 0 { - assert.DeepEqual(t, out.Args, args) + AssertDeepEqual(t, out.Args, args) } } @@ -167,8 +169,8 @@ func AssertProjectionSerialize(t *testing.T, dialect jet.Dialect, projection jet out := jet.SQLBuilder{Dialect: dialect} jet.SerializeForProjection(projection, jet.SelectStatementType, &out) - assert.DeepEqual(t, out.Buff.String(), query) - assert.DeepEqual(t, out.Args, args) + AssertDeepEqual(t, out.Buff.String(), query) + AssertDeepEqual(t, out.Args, args) } // AssertQueryPanicErr check if statement Query execution panics with error errString @@ -185,13 +187,13 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest inter func AssertFileContent(t *testing.T, filePath string, contentBegin string, expectedContent string) { enumFileData, err := ioutil.ReadFile(filePath) - assert.NilError(t, err) + assert.NoError(t, err) beginIndex := bytes.Index(enumFileData, []byte(contentBegin)) //fmt.Println("-"+string(enumFileData[beginIndex:])+"-") - assert.DeepEqual(t, string(enumFileData[beginIndex:]), expectedContent) + AssertDeepEqual(t, string(enumFileData[beginIndex:]), expectedContent) } // AssertFileNamesEqual check if all filesInfos are contained in fileNames @@ -205,6 +207,10 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st } for _, fileName := range fileNames { - assert.Assert(t, fileNamesMap[fileName], fileName+" does not exist.") + assert.True(t, fileNamesMap[fileName], fileName+" does not exist.") } } + +func AssertDeepEqual(t *testing.T, actual, expected interface{}) { + assert.True(t, cmp.Equal(actual, expected)) +} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go index 8ee2d49..ac57bbd 100644 --- a/internal/utils/utils_test.go +++ b/internal/utils/utils_test.go @@ -2,7 +2,7 @@ package utils import ( "fmt" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" ) diff --git a/mysql/insert_statement_test.go b/mysql/insert_statement_test.go index 0732486..6faf313 100644 --- a/mysql/insert_statement_test.go +++ b/mysql/insert_statement_test.go @@ -1,7 +1,7 @@ package mysql import ( - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" "time" ) diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index 96f275b..ccb1404 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -1,7 +1,7 @@ package postgres import ( - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" "time" ) diff --git a/qrm/internal/null_types_test.go b/qrm/internal/null_types_test.go index 70eb42f..f03d3ab 100644 --- a/qrm/internal/null_types_test.go +++ b/qrm/internal/null_types_test.go @@ -2,7 +2,7 @@ package internal import ( "fmt" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" "time" ) @@ -10,10 +10,10 @@ import ( func TestNullByteArray(t *testing.T) { var array NullByteArray - assert.NilError(t, array.Scan(nil)) + assert.NoError(t, array.Scan(nil)) assert.Equal(t, array.Valid, false) - assert.NilError(t, array.Scan([]byte("bytea"))) + assert.NoError(t, array.Scan([]byte("bytea"))) assert.Equal(t, array.Valid, true) assert.Equal(t, string(array.ByteArray), string([]byte("bytea"))) @@ -23,21 +23,21 @@ func TestNullByteArray(t *testing.T) { func TestNullTime(t *testing.T) { var array NullTime - assert.NilError(t, array.Scan(nil)) + assert.NoError(t, array.Scan(nil)) assert.Equal(t, array.Valid, false) time := time.Now() - assert.NilError(t, array.Scan(time)) + assert.NoError(t, array.Scan(time)) assert.Equal(t, array.Valid, true) value, _ := array.Value() assert.Equal(t, value, time) - assert.NilError(t, array.Scan([]byte("13:10:11"))) + assert.NoError(t, array.Scan([]byte("13:10:11"))) assert.Equal(t, array.Valid, true) value, _ = array.Value() assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") - assert.NilError(t, array.Scan("13:10:11")) + assert.NoError(t, array.Scan("13:10:11")) assert.Equal(t, array.Valid, true) value, _ = array.Value() assert.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") @@ -48,10 +48,10 @@ func TestNullTime(t *testing.T) { func TestNullInt8(t *testing.T) { var array NullInt8 - assert.NilError(t, array.Scan(nil)) + assert.NoError(t, array.Scan(nil)) assert.Equal(t, array.Valid, false) - assert.NilError(t, array.Scan(int64(11))) + assert.NoError(t, array.Scan(int64(11))) assert.Equal(t, array.Valid, true) value, _ := array.Value() assert.Equal(t, value, int8(11)) @@ -62,25 +62,25 @@ func TestNullInt8(t *testing.T) { func TestNullInt16(t *testing.T) { var array NullInt16 - assert.NilError(t, array.Scan(nil)) + assert.NoError(t, array.Scan(nil)) assert.Equal(t, array.Valid, false) - assert.NilError(t, array.Scan(int64(11))) + assert.NoError(t, array.Scan(int64(11))) assert.Equal(t, array.Valid, true) value, _ := array.Value() assert.Equal(t, value, int16(11)) - assert.NilError(t, array.Scan(int16(20))) + assert.NoError(t, array.Scan(int16(20))) assert.Equal(t, array.Valid, true) value, _ = array.Value() assert.Equal(t, value, int16(20)) - assert.NilError(t, array.Scan(int8(30))) + assert.NoError(t, array.Scan(int8(30))) assert.Equal(t, array.Valid, true) value, _ = array.Value() assert.Equal(t, value, int16(30)) - assert.NilError(t, array.Scan(uint8(30))) + assert.NoError(t, array.Scan(uint8(30))) assert.Equal(t, array.Valid, true) value, _ = array.Value() assert.Equal(t, value, int16(30)) @@ -91,35 +91,35 @@ func TestNullInt16(t *testing.T) { func TestNullInt32(t *testing.T) { var array NullInt32 - assert.NilError(t, array.Scan(nil)) + assert.NoError(t, array.Scan(nil)) assert.Equal(t, array.Valid, false) - assert.NilError(t, array.Scan(int64(11))) + assert.NoError(t, array.Scan(int64(11))) assert.Equal(t, array.Valid, true) value, _ := array.Value() assert.Equal(t, value, int32(11)) - assert.NilError(t, array.Scan(int32(32))) + assert.NoError(t, array.Scan(int32(32))) assert.Equal(t, array.Valid, true) value, _ = array.Value() assert.Equal(t, value, int32(32)) - assert.NilError(t, array.Scan(int16(20))) + assert.NoError(t, array.Scan(int16(20))) assert.Equal(t, array.Valid, true) value, _ = array.Value() assert.Equal(t, value, int32(20)) - assert.NilError(t, array.Scan(uint16(16))) + assert.NoError(t, array.Scan(uint16(16))) assert.Equal(t, array.Valid, true) value, _ = array.Value() assert.Equal(t, value, int32(16)) - assert.NilError(t, array.Scan(int8(30))) + assert.NoError(t, array.Scan(int8(30))) assert.Equal(t, array.Valid, true) value, _ = array.Value() assert.Equal(t, value, int32(30)) - assert.NilError(t, array.Scan(uint8(30))) + assert.NoError(t, array.Scan(uint8(30))) assert.Equal(t, array.Valid, true) value, _ = array.Value() assert.Equal(t, value, int32(30)) @@ -130,15 +130,15 @@ func TestNullInt32(t *testing.T) { func TestNullFloat32(t *testing.T) { var array NullFloat32 - assert.NilError(t, array.Scan(nil)) + assert.NoError(t, array.Scan(nil)) assert.Equal(t, array.Valid, false) - assert.NilError(t, array.Scan(float64(64))) + assert.NoError(t, array.Scan(float64(64))) assert.Equal(t, array.Valid, true) value, _ := array.Value() assert.Equal(t, value, float32(64)) - assert.NilError(t, array.Scan(float32(32))) + assert.NoError(t, array.Scan(float32(32))) assert.Equal(t, array.Valid, true) value, _ = array.Value() assert.Equal(t, value, float32(32)) diff --git a/qrm/utill_test.go b/qrm/utill_test.go index 168c55f..045e2b9 100644 --- a/qrm/utill_test.go +++ b/qrm/utill_test.go @@ -2,28 +2,28 @@ package qrm import ( "github.com/google/uuid" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "reflect" "testing" "time" ) func TestIsSimpleModelType(t *testing.T) { - assert.Assert(t, isSimpleModelType(reflect.TypeOf(int8(11)))) - assert.Assert(t, isSimpleModelType(reflect.TypeOf(int16(11)))) - assert.Assert(t, isSimpleModelType(reflect.TypeOf(int32(11)))) - assert.Assert(t, isSimpleModelType(reflect.TypeOf(int64(11)))) - assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint8(11)))) - assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint16(11)))) - assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint32(11)))) - assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint64(11)))) + assert.True(t, isSimpleModelType(reflect.TypeOf(int8(11)))) + assert.True(t, isSimpleModelType(reflect.TypeOf(int16(11)))) + assert.True(t, isSimpleModelType(reflect.TypeOf(int32(11)))) + assert.True(t, isSimpleModelType(reflect.TypeOf(int64(11)))) + assert.True(t, isSimpleModelType(reflect.TypeOf(uint8(11)))) + assert.True(t, isSimpleModelType(reflect.TypeOf(uint16(11)))) + assert.True(t, isSimpleModelType(reflect.TypeOf(uint32(11)))) + assert.True(t, isSimpleModelType(reflect.TypeOf(uint64(11)))) - assert.Assert(t, isSimpleModelType(reflect.TypeOf(float32(123.46)))) - assert.Assert(t, isSimpleModelType(reflect.TypeOf(float64(123.46)))) + assert.True(t, isSimpleModelType(reflect.TypeOf(float32(123.46)))) + assert.True(t, isSimpleModelType(reflect.TypeOf(float64(123.46)))) - assert.Assert(t, isSimpleModelType(reflect.TypeOf([]byte("Text")))) - assert.Assert(t, isSimpleModelType(reflect.TypeOf(time.Now()))) - assert.Assert(t, isSimpleModelType(reflect.TypeOf(uuid.New()))) + assert.True(t, isSimpleModelType(reflect.TypeOf([]byte("Text")))) + assert.True(t, isSimpleModelType(reflect.TypeOf(time.Now()))) + assert.True(t, isSimpleModelType(reflect.TypeOf(uuid.New()))) complexModelType := struct { Field1 string diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 361bb8d..98a1a4a 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -14,7 +14,7 @@ import ( . "github.com/go-jet/jet/mysql" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" ) func TestAllTypes(t *testing.T) { @@ -26,7 +26,7 @@ func TestAllTypes(t *testing.T) { LIMIT(2). Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) @@ -45,7 +45,7 @@ func TestAllTypesViewSelect(t *testing.T) { dest := []AllTypesView{} err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert @@ -74,10 +74,10 @@ func TestUUID(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) - assert.Assert(t, dest.StrUUID != nil) - assert.Assert(t, dest.UUID.String() != uuid.UUID{}.String()) - assert.Assert(t, dest.StrUUID.String() != uuid.UUID{}.String()) + assert.NoError(t, err) + assert.True(t, dest.StrUUID != nil) + assert.True(t, dest.UUID.String() != uuid.UUID{}.String()) + assert.True(t, dest.StrUUID.String() != uuid.UUID{}.String()) assert.Equal(t, dest.StrUUID.String(), dest.BinUUID.String()) } @@ -119,7 +119,7 @@ LIMIT ?; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //testutils.PrintJson(dest) @@ -210,7 +210,7 @@ FROM test_sample.all_types; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json") } @@ -307,7 +307,7 @@ LIMIT ?; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertJSONFile(t, dest, "./testdata/results/common/float_operators.json") } @@ -444,7 +444,7 @@ LIMIT ?; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //testutils.PrintJson(dest) @@ -516,7 +516,7 @@ func TestStringOperators(t *testing.T) { dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC) @@ -604,7 +604,7 @@ FROM test_sample.all_types; dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestDateExpressions(t *testing.T) { @@ -679,7 +679,7 @@ FROM test_sample.all_types; dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestDateTimeExpressions(t *testing.T) { @@ -756,7 +756,7 @@ FROM test_sample.all_types; dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestTimestampExpressions(t *testing.T) { @@ -832,13 +832,13 @@ FROM test_sample.all_types; dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestTimeLiterals(t *testing.T) { loc, err := time.LoadLocation("Europe/Berlin") - assert.NilError(t, err) + assert.NoError(t, err) var timeT = time.Date(2009, 11, 17, 20, 34, 58, 351387237, loc) @@ -877,7 +877,7 @@ LIMIT ?; } err = query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //testutils.PrintJson(dest) @@ -960,7 +960,7 @@ func TestINTERVAL(t *testing.T) { //fmt.Println(query.DebugSql()) err := query.Query(db, &struct{}{}) - assert.NilError(t, err) + assert.NoError(t, err) } var allTypesJson = ` diff --git a/tests/mysql/cast_test.go b/tests/mysql/cast_test.go index 07570a0..3ab1914 100644 --- a/tests/mysql/cast_test.go +++ b/tests/mysql/cast_test.go @@ -4,7 +4,7 @@ import ( "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" "time" ) @@ -55,9 +55,9 @@ FROM test_sample.all_types; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) - assert.DeepEqual(t, dest, Result{ + testutils.AssertDeepEqual(t, dest, Result{ As1: "test", Date1: *testutils.Date("2011-02-02"), Time: *testutils.TimeWithoutTimeZone("14:06:10"), diff --git a/tests/mysql/delete_test.go b/tests/mysql/delete_test.go index c3409a5..2e06d88 100644 --- a/tests/mysql/delete_test.go +++ b/tests/mysql/delete_test.go @@ -6,7 +6,7 @@ import ( . "github.com/go-jet/jet/mysql" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" "time" ) diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index 89a07b4..8ae9347 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -4,7 +4,7 @@ import ( "github.com/go-jet/jet/generator/mysql" "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/dbconfig" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "io/ioutil" "os" "os/exec" @@ -25,23 +25,23 @@ func TestGenerator(t *testing.T) { DBName: "dvds", }) - assert.NilError(t, err) + assert.NoError(t, err) assertGeneratedFiles(t) } err := os.RemoveAll(genTestDirRoot) - assert.NilError(t, err) + assert.NoError(t, err) } func TestCmdGenerator(t *testing.T) { goInstallJet := exec.Command("sh", "-c", "go install github.com/go-jet/jet/cmd/jet") goInstallJet.Stderr = os.Stderr err := goInstallJet.Run() - assert.NilError(t, err) + assert.NoError(t, err) err = os.RemoveAll(genTestDir3) - assert.NilError(t, err) + assert.NoError(t, err) cmd := exec.Command("jet", "-source=MySQL", "-dbname=dvds", "-host=localhost", "-port=3306", "-user=jet", "-password=jet", "-path="+genTestDir3) @@ -50,18 +50,18 @@ func TestCmdGenerator(t *testing.T) { cmd.Stdout = os.Stdout err = cmd.Run() - assert.NilError(t, err) + assert.NoError(t, err) assertGeneratedFiles(t) err = os.RemoveAll(genTestDirRoot) - assert.NilError(t, err) + assert.NoError(t, err) } func assertGeneratedFiles(t *testing.T) { // Table SQL Builder files tableSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/table") - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", @@ -71,7 +71,7 @@ func assertGeneratedFiles(t *testing.T) { // View SQL Builder files viewSQLBuilderFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/view") - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") @@ -80,14 +80,14 @@ func assertGeneratedFiles(t *testing.T) { // Enums SQL Builder files enumFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/enum") - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertFileNamesEqual(t, enumFiles, "film_rating.go", "film_list_rating.go", "nicer_but_slower_film_list_rating.go") testutils.AssertFileContent(t, genTestDir3+"/dvds/enum/film_rating.go", "\npackage enum", mpaaRatingEnumFile) // Model files modelFiles, err := ioutil.ReadDir(genTestDir3 + "/dvds/model") - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index 697bccb..4c86012 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -6,7 +6,7 @@ import ( . "github.com/go-jet/jet/mysql" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" "time" ) @@ -32,7 +32,7 @@ INSERT INTO test_sample.link (id, url, name, description) VALUES 102, "http://www.yahoo.com", "Yahoo", nil) _, err := insertQuery.Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) insertedLinks := []model.Link{} @@ -41,18 +41,18 @@ INSERT INTO test_sample.link (id, url, name, description) VALUES ORDER_BY(Link.ID). Query(db, &insertedLinks) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(insertedLinks), 3) - assert.DeepEqual(t, insertedLinks[0], postgreTutorial) + testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) - assert.DeepEqual(t, insertedLinks[1], model.Link{ + testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ ID: 101, URL: "http://www.google.com", Name: "Google", }) - assert.DeepEqual(t, insertedLinks[2], model.Link{ + testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ ID: 102, URL: "http://www.yahoo.com", Name: "Yahoo", @@ -80,7 +80,7 @@ INSERT INTO test_sample.link VALUES 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial") _, err := stmt.Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) insertedLinks := []model.Link{} @@ -89,9 +89,9 @@ INSERT INTO test_sample.link VALUES ORDER_BY(Link.ID). Query(db, &insertedLinks) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(insertedLinks), 1) - assert.DeepEqual(t, insertedLinks[0], postgreTutorial) + testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) } func TestInsertModelObject(t *testing.T) { @@ -113,7 +113,7 @@ INSERT INTO test_sample.link (url, name) VALUES testutils.AssertDebugStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go") _, err := query.Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) } func TestInsertModelObjectEmptyColumnList(t *testing.T) { @@ -136,7 +136,7 @@ INSERT INTO test_sample.link VALUES testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) _, err := query.Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) } func TestInsertModelsObject(t *testing.T) { @@ -172,7 +172,7 @@ INSERT INTO test_sample.link (url, name) VALUES "http://www.yahoo.com", "Yahoo") _, err := query.Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) } func TestInsertUsingMutableColumns(t *testing.T) { @@ -207,14 +207,14 @@ INSERT INTO test_sample.link (url, name, description) VALUES "http://www.yahoo.com", "Yahoo", nil) _, err := stmt.Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) } func TestInsertQuery(t *testing.T) { _, err := Link.DELETE(). WHERE(Link.ID.NOT_EQ(Int(1)).AND(Link.Name.EQ(String("Youtube")))). Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) var expectedSQL = ` INSERT INTO test_sample.link (url, name) ( @@ -236,7 +236,7 @@ INSERT INTO test_sample.link (url, name) ( testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1)) _, err = query.Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) youtubeLinks := []model.Link{} err = Link. @@ -244,7 +244,7 @@ INSERT INTO test_sample.link (url, name) ( WHERE(Link.Name.EQ(String("Youtube"))). Query(db, &youtubeLinks) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(youtubeLinks), 2) } @@ -283,5 +283,5 @@ func TestInsertWithExecContext(t *testing.T) { func cleanUpLinkTable(t *testing.T) { _, err := Link.DELETE().WHERE(Link.ID.GT(Int(1))).Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) } diff --git a/tests/mysql/lock_test.go b/tests/mysql/lock_test.go index 57ffbe8..e3d5749 100644 --- a/tests/mysql/lock_test.go +++ b/tests/mysql/lock_test.go @@ -4,7 +4,7 @@ import ( "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/mysql" . "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" ) @@ -16,7 +16,7 @@ LOCK TABLES dvds.customer READ; `) _, err := query.Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) } func TestLockWrite(t *testing.T) { @@ -27,7 +27,7 @@ LOCK TABLES dvds.customer WRITE; `) _, err := query.Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) } func TestUnlockTables(t *testing.T) { @@ -38,5 +38,5 @@ UNLOCK TABLES; `) _, err := query.Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) } diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index ed75bad..c47c6d4 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -7,7 +7,7 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/view" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" ) @@ -30,9 +30,9 @@ WHERE actor.actor_id = ?; actor := model.Actor{} err := query.Query(db, &actor) - assert.NilError(t, err) + assert.NoError(t, err) - assert.DeepEqual(t, actor, actor2) + testutils.AssertDeepEqual(t, actor, actor2) } var actor2 = model.Actor{ @@ -59,10 +59,10 @@ ORDER BY actor.actor_id; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 200) - assert.DeepEqual(t, dest[1], actor2) + testutils.AssertDeepEqual(t, dest[1], actor2) //testutils.PrintJson(dest) //testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json") @@ -136,7 +136,7 @@ ORDER BY payment.customer_id, SUM(payment.amount) ASC; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //testutils.PrintJson(dest) @@ -176,7 +176,7 @@ func TestSubQuery(t *testing.T) { } err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //testutils.SaveJsonFile(dest, "mysql/testdata/r_rating_films.json") testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/r_rating_films.json") @@ -229,7 +229,7 @@ LIMIT ?; dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestSelectUNION(t *testing.T) { @@ -265,7 +265,7 @@ LIMIT ?; dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestSelectUNION_ALL(t *testing.T) { @@ -308,7 +308,7 @@ OFFSET ?; dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestJoinQueryStruct(t *testing.T) { @@ -406,7 +406,7 @@ LIMIT ?; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //assert.Equal(t, len(dest), 1) //assert.Equal(t, len(dest[0].Films), 10) //assert.Equal(t, len(dest[0].Films[0].Actors), 10) @@ -450,10 +450,10 @@ FOR` tx, _ := db.Begin() _, err := query.Exec(tx) - assert.NilError(t, err) + assert.NoError(t, err) err = tx.Rollback() - assert.NilError(t, err) + assert.NoError(t, err) } for lockType, lockTypeStr := range getRowLockTestData() { @@ -464,10 +464,10 @@ FOR` tx, _ := db.Begin() _, err := query.Exec(tx) - assert.NilError(t, err) + assert.NoError(t, err) err = tx.Rollback() - assert.NilError(t, err) + assert.NoError(t, err) } if sourceIsMariaDB() { @@ -482,10 +482,10 @@ FOR` tx, _ := db.Begin() _, err := query.Exec(tx) - assert.NilError(t, err) + assert.NoError(t, err) err = tx.Rollback() - assert.NilError(t, err) + assert.NoError(t, err) } } @@ -514,7 +514,7 @@ SELECT true, dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestLockInShareMode(t *testing.T) { @@ -535,7 +535,7 @@ LOCK IN SHARE MODE; dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestWindowFunction(t *testing.T) { @@ -612,7 +612,7 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date; dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestWindowClause(t *testing.T) { @@ -649,7 +649,7 @@ ORDER BY payment.customer_id; dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestSimpleView(t *testing.T) { @@ -670,7 +670,7 @@ func TestSimpleView(t *testing.T) { var dest []ActorInfo err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 10) testutils.AssertJSON(t, dest[1:2], ` @@ -702,7 +702,7 @@ func TestJoinViewWithTable(t *testing.T) { } err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest[0].Rentals), 32) @@ -737,7 +737,7 @@ LIMIT 3; `) var dest []model.Customer err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 3) } diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index a6a07f5..c1e3f19 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -8,7 +8,7 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/table" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" "time" ) @@ -40,9 +40,9 @@ WHERE link.name = 'Bing'; WHERE(Link.Name.EQ(String("Bong"))). Query(db, &links) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(links), 1) - assert.DeepEqual(t, links[0], model.Link{ + testutils.AssertDeepEqual(t, links[0], model.Link{ ID: 204, URL: "http://bong.com", Name: "Bong", @@ -244,7 +244,7 @@ func TestUpdateWithJoin(t *testing.T) { //fmt.Println(query.DebugSql()) _, err := query.Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) } func setupLinkTableForUpdateTest(t *testing.T) { @@ -259,5 +259,5 @@ func setupLinkTableForUpdateTest(t *testing.T) { VALUES(204, "http://www.bing.com", "Bing", DEFAULT). Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) } diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index fbcffd8..0b30080 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -5,7 +5,7 @@ import ( "time" "github.com/google/uuid" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" @@ -19,23 +19,22 @@ func TestAllTypesSelect(t *testing.T) { dest := []model.AllTypes{} err := AllTypes.SELECT(AllTypes.AllColumns).Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) - assert.DeepEqual(t, dest[0], allTypesRow0) - assert.DeepEqual(t, dest[1], allTypesRow1) + testutils.AssertDeepEqual(t, dest[0], allTypesRow0) + testutils.AssertDeepEqual(t, dest[1], allTypesRow1) } func TestAllTypesViewSelect(t *testing.T) { - type AllTypesView model.AllTypes dest := []AllTypesView{} err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) - assert.DeepEqual(t, dest[0], AllTypesView(allTypesRow0)) - assert.DeepEqual(t, dest[1], AllTypesView(allTypesRow1)) + testutils.AssertDeepEqual(t, dest[0], AllTypesView(allTypesRow0)) + testutils.AssertDeepEqual(t, dest[1], AllTypesView(allTypesRow1)) } func TestAllTypesInsertModel(t *testing.T) { @@ -46,11 +45,11 @@ func TestAllTypesInsertModel(t *testing.T) { dest := []model.AllTypes{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) - assert.DeepEqual(t, dest[0], allTypesRow0) - assert.DeepEqual(t, dest[1], allTypesRow1) + testutils.AssertDeepEqual(t, dest[0], allTypesRow0) + testutils.AssertDeepEqual(t, dest[1], allTypesRow1) } func TestAllTypesInsertQuery(t *testing.T) { @@ -65,10 +64,10 @@ func TestAllTypesInsertQuery(t *testing.T) { dest := []model.AllTypes{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) - assert.DeepEqual(t, dest[0], allTypesRow0) - assert.DeepEqual(t, dest[1], allTypesRow1) + testutils.AssertDeepEqual(t, dest[0], allTypesRow0) + testutils.AssertDeepEqual(t, dest[1], allTypesRow1) } func TestAllTypesFromSubQuery(t *testing.T) { @@ -213,7 +212,7 @@ LIMIT 2; dest := []model.AllTypes{} err := mainQuery.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) } @@ -252,7 +251,7 @@ LIMIT $5; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //testutils.PrintJson(dest) @@ -321,7 +320,7 @@ func TestExpressionCast(t *testing.T) { dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestStringOperators(t *testing.T) { @@ -401,7 +400,7 @@ func TestStringOperators(t *testing.T) { dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestBoolOperators(t *testing.T) { @@ -470,7 +469,7 @@ LIMIT $5; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json") } @@ -566,7 +565,7 @@ LIMIT $35; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //testutils.PrintJson(dest) @@ -705,7 +704,7 @@ LIMIT $23; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //testutils.SaveJsonFile("./testdata/common/int_operators.json", dest) //testutils.PrintJson(dest) @@ -784,7 +783,7 @@ func TestTimeExpression(t *testing.T) { dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestInterval(t *testing.T) { @@ -835,7 +834,7 @@ func TestInterval(t *testing.T) { //fmt.Println(stmt.DebugSql()) err := stmt.Query(db, &struct{}{}) - assert.NilError(t, err) + assert.NoError(t, err) } func TestSubQueryColumnReference(t *testing.T) { @@ -987,17 +986,17 @@ FROM` dest1 := []model.AllTypes{} err := stmt1.Query(db, &dest1) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest1), 2) assert.Equal(t, dest1[0].Boolean, allTypesRow0.Boolean) assert.Equal(t, dest1[0].Integer, allTypesRow0.Integer) assert.Equal(t, dest1[0].Real, allTypesRow0.Real) assert.Equal(t, dest1[0].Text, allTypesRow0.Text) - assert.DeepEqual(t, dest1[0].Time, allTypesRow0.Time) - assert.DeepEqual(t, dest1[0].Timez, allTypesRow0.Timez) - assert.DeepEqual(t, dest1[0].Timestamp, allTypesRow0.Timestamp) - assert.DeepEqual(t, dest1[0].Timestampz, allTypesRow0.Timestampz) - assert.DeepEqual(t, dest1[0].Date, allTypesRow0.Date) + testutils.AssertDeepEqual(t, dest1[0].Time, allTypesRow0.Time) + testutils.AssertDeepEqual(t, dest1[0].Timez, allTypesRow0.Timez) + testutils.AssertDeepEqual(t, dest1[0].Timestamp, allTypesRow0.Timestamp) + testutils.AssertDeepEqual(t, dest1[0].Timestampz, allTypesRow0.Timestampz) + testutils.AssertDeepEqual(t, dest1[0].Date, allTypesRow0.Date) stmt2 := SELECT( subQuery.AllColumns(), @@ -1009,15 +1008,15 @@ FROM` dest2 := []model.AllTypes{} err = stmt2.Query(db, &dest2) - assert.NilError(t, err) - assert.DeepEqual(t, dest1, dest2) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, dest1, dest2) } } func TestTimeLiterals(t *testing.T) { loc, err := time.LoadLocation("Europe/Berlin") - assert.NilError(t, err) + assert.NoError(t, err) var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, loc) @@ -1052,7 +1051,7 @@ LIMIT $6; err = query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //testutils.PrintJson(dest) diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index 8969b0b..b0b5efe 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -7,7 +7,7 @@ import ( . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/table" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" "time" ) @@ -30,11 +30,11 @@ ORDER BY "Album"."AlbumId" ASC; err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 347) - assert.DeepEqual(t, dest[0], album1) - assert.DeepEqual(t, dest[1], album2) - assert.DeepEqual(t, dest[len(dest)-1], album347) + testutils.AssertDeepEqual(t, dest[0], album1) + testutils.AssertDeepEqual(t, dest[1], album2) + testutils.AssertDeepEqual(t, dest[len(dest)-1], album347) } func TestJoinEverything(t *testing.T) { @@ -103,7 +103,7 @@ func TestJoinEverything(t *testing.T) { err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 275) testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/joined_everything.json") } @@ -143,7 +143,7 @@ ORDER BY "Employee"."EmployeeId"; err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 8) testutils.AssertJSON(t, dest[0:2], ` [ @@ -236,11 +236,11 @@ ORDER BY "Album.AlbumId"; err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) - assert.DeepEqual(t, dest[0], album1) - assert.DeepEqual(t, dest[1], album2) + testutils.AssertDeepEqual(t, dest[0], album1) + testutils.AssertDeepEqual(t, dest[1], album2) } func TestQueryWithContext(t *testing.T) { @@ -327,7 +327,7 @@ ORDER BY "first10Artist"."Artist.ArtistId"; err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //spew.Dump(dest) } diff --git a/tests/postgres/delete_test.go b/tests/postgres/delete_test.go index 352a788..855f6dc 100644 --- a/tests/postgres/delete_test.go +++ b/tests/postgres/delete_test.go @@ -6,7 +6,7 @@ import ( . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" "time" ) @@ -48,11 +48,11 @@ RETURNING link.id AS "link.id", err := deleteStmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) - assert.DeepEqual(t, dest[0].Name, "Gmail") - assert.DeepEqual(t, dest[1].Name, "Outlook") + testutils.AssertDeepEqual(t, dest[0].Name, "Gmail") + testutils.AssertDeepEqual(t, dest[1].Name, "Outlook") } func initForDeleteTest(t *testing.T) { diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 5100801..b98633a 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -4,7 +4,7 @@ import ( "github.com/go-jet/jet/generator/postgres" "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/dbconfig" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "io/ioutil" "os" "os/exec" @@ -19,7 +19,7 @@ func TestGeneratedModel(t *testing.T) { assert.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32") actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID") - assert.Assert(t, ok) + assert.True(t, ok) assert.Equal(t, actorIDField.Tag.Get("sql"), "primary_key") assert.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string") assert.Equal(t, reflect.TypeOf(actor.LastName).String(), "string") @@ -29,12 +29,12 @@ func TestGeneratedModel(t *testing.T) { assert.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int16") filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID") - assert.Assert(t, ok) + assert.True(t, ok) assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key") assert.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int16") actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID") - assert.Assert(t, ok) + assert.True(t, ok) assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key") staff := model.Staff{} @@ -49,10 +49,10 @@ func TestCmdGenerator(t *testing.T) { goInstallJet := exec.Command("sh", "-c", "go install github.com/go-jet/jet/cmd/jet") goInstallJet.Stderr = os.Stderr err := goInstallJet.Run() - assert.NilError(t, err) + assert.NoError(t, err) err = os.RemoveAll(genTestDir2) - assert.NilError(t, err) + assert.NoError(t, err) cmd := exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost", "-port=5432", "-user=jet", "-password=jet", "-schema=dvds", "-path="+genTestDir2) @@ -60,12 +60,12 @@ func TestCmdGenerator(t *testing.T) { cmd.Stdout = os.Stdout err = cmd.Run() - assert.NilError(t, err) + assert.NoError(t, err) assertGeneratedFiles(t) err = os.RemoveAll(genTestDir2) - assert.NilError(t, err) + assert.NoError(t, err) } func TestGenerator(t *testing.T) { @@ -83,19 +83,19 @@ func TestGenerator(t *testing.T) { SchemaName: "dvds", }) - assert.NilError(t, err) + assert.NoError(t, err) assertGeneratedFiles(t) } err := os.RemoveAll(genTestDir2) - assert.NilError(t, err) + assert.NoError(t, err) } func assertGeneratedFiles(t *testing.T) { // Table SQL Builder files tableSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/table") - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", @@ -105,7 +105,7 @@ func assertGeneratedFiles(t *testing.T) { // View SQL Builder files viewSQLBuilderFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/view") - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go") @@ -114,14 +114,14 @@ func assertGeneratedFiles(t *testing.T) { // Enums SQL Builder files enumFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/enum") - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertFileNamesEqual(t, enumFiles, "mpaa_rating.go") testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/enum/mpaa_rating.go", "\npackage enum", mpaaRatingEnumFile) // Model files modelFiles, err := ioutil.ReadDir("./.gentestdata2/jetdb/dvds/model") - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", "customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go", @@ -281,13 +281,13 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { tableDir := testRoot + ".gentestdata/jetdb/test_sample/table/" enumFiles, err := ioutil.ReadDir(enumDir) - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertFileNamesEqual(t, enumFiles, "mood.go") testutils.AssertFileContent(t, enumDir+"mood.go", "\npackage enum", moodEnumContent) modelFiles, err := ioutil.ReadDir(modelDir) - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go", "mood.go", "person.go", "person_phone.go", "weird_names_table.go") @@ -295,7 +295,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { testutils.AssertFileContent(t, modelDir+"all_types.go", "\npackage model", allTypesModelContent) tableFiles, err := ioutil.ReadDir(tableDir) - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go", "person.go", "person_phone.go", "weird_names_table.go") diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index 161ffed..38c135e 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -6,7 +6,7 @@ import ( . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" "time" ) @@ -39,23 +39,23 @@ RETURNING link.id AS "link.id", err := insertQuery.Query(db, &insertedLinks) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(insertedLinks), 3) - assert.DeepEqual(t, insertedLinks[0], model.Link{ + testutils.AssertDeepEqual(t, insertedLinks[0], model.Link{ ID: 100, URL: "http://www.postgresqltutorial.com", Name: "PostgreSQL Tutorial", }) - assert.DeepEqual(t, insertedLinks[1], model.Link{ + testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ ID: 101, URL: "http://www.google.com", Name: "Google", }) - assert.DeepEqual(t, insertedLinks[2], model.Link{ + testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ ID: 102, URL: "http://www.yahoo.com", Name: "Yahoo", @@ -68,9 +68,9 @@ RETURNING link.id AS "link.id", ORDER_BY(Link.ID). Query(db, &allLinks) - assert.NilError(t, err) + assert.NoError(t, err) - assert.DeepEqual(t, insertedLinks, allLinks) + testutils.AssertDeepEqual(t, insertedLinks, allLinks) } func TestInsertEmptyColumnList(t *testing.T) { @@ -206,7 +206,7 @@ func TestInsertQuery(t *testing.T) { _, err := Link.DELETE(). WHERE(Link.ID.NOT_EQ(Int(0)).AND(Link.Name.EQ(String("Youtube")))). Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) var expectedSQL = ` INSERT INTO test_sample.link (url, name) ( @@ -236,7 +236,7 @@ RETURNING link.id AS "link.id", err = query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) youtubeLinks := []model.Link{} err = Link. @@ -244,7 +244,7 @@ RETURNING link.id AS "link.id", WHERE(Link.Name.EQ(String("Youtube"))). Query(db, &youtubeLinks) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(youtubeLinks), 2) } diff --git a/tests/postgres/lock_test.go b/tests/postgres/lock_test.go index e5ace94..acfb852 100644 --- a/tests/postgres/lock_test.go +++ b/tests/postgres/lock_test.go @@ -3,7 +3,7 @@ package postgres import ( "context" "github.com/go-jet/jet/internal/testutils" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" "time" @@ -35,11 +35,11 @@ LOCK TABLE dvds.address IN` _, err := query.Exec(tx) - assert.NilError(t, err) + assert.NoError(t, err) err = tx.Rollback() - assert.NilError(t, err) + assert.NoError(t, err) } for _, lockMode := range testData { @@ -51,11 +51,11 @@ LOCK TABLE dvds.address IN` _, err := query.Exec(tx) - assert.NilError(t, err) + assert.NoError(t, err) err = tx.Rollback() - assert.NilError(t, err) + assert.NoError(t, err) } } diff --git a/tests/postgres/northwind_test.go b/tests/postgres/northwind_test.go index 8a02665..a50122b 100644 --- a/tests/postgres/northwind_test.go +++ b/tests/postgres/northwind_test.go @@ -4,7 +4,7 @@ import ( "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/.gentestdata/jetdb/northwind/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/northwind/table" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" ) @@ -59,7 +59,7 @@ func TestNorthwindJoinEverything(t *testing.T) { } err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //jsonSave("./testdata/northwind-all.json", dest) testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/northwind-all.json") diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 41c62f8..5989d17 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -6,7 +6,7 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" "github.com/google/uuid" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" ) @@ -25,9 +25,9 @@ WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'; result := model.AllTypes{} err := query.Query(db, &result) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) - assert.DeepEqual(t, result.UUIDPtr, UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) + testutils.AssertDeepEqual(t, result.UUIDPtr, UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) } func TestUUIDComplex(t *testing.T) { @@ -46,7 +46,7 @@ func TestUUIDComplex(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) testutils.AssertJSON(t, dest, ` [ @@ -96,7 +96,7 @@ func TestUUIDComplex(t *testing.T) { } } err := singleQuery.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertJSON(t, dest, ` { @@ -132,7 +132,7 @@ func TestUUIDComplex(t *testing.T) { } err := leftQuery.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertJSON(t, dest, ` [ { @@ -194,7 +194,7 @@ FROM test_sample.person; err := query.Query(db, &result) - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertJSON(t, result, ` [ { @@ -258,9 +258,9 @@ ORDER BY employee.employee_id; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 8) - assert.DeepEqual(t, dest[0].Employee, model.Employee{ + testutils.AssertDeepEqual(t, dest[0].Employee, model.Employee{ EmployeeID: 1, FirstName: "Windy", LastName: "Hays", @@ -268,9 +268,9 @@ ORDER BY employee.employee_id; ManagerID: nil, }) - assert.Assert(t, dest[0].Manager == nil) + assert.True(t, dest[0].Manager == nil) - assert.DeepEqual(t, dest[7].Employee, model.Employee{ + testutils.AssertDeepEqual(t, dest[7].Employee, model.Employee{ EmployeeID: 8, FirstName: "Salley", LastName: "Lester", @@ -306,10 +306,10 @@ FROM test_sample."WEIRD NAMES TABLE"; err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 1) - assert.DeepEqual(t, dest[0], model.WeirdNamesTable{ + testutils.AssertDeepEqual(t, dest[0], model.WeirdNamesTable{ WeirdColumnName1: "Doe", WeirdColumnName2: "Doe", WeirdColumnName3: "Doe", diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index 11dac96..40e18f8 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -8,7 +8,7 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table" "github.com/google/uuid" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" ) @@ -53,38 +53,38 @@ func TestScanToValidDestination(t *testing.T) { dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) }) t.Run("global query function scan", func(t *testing.T) { queryStr, args := query.Sql() dest := []struct{}{} err := qrm.Query(nil, db, queryStr, args, &dest) - assert.NilError(t, err) + assert.NoError(t, err) }) t.Run("pointer to slice", func(t *testing.T) { err := query.Query(db, &[]struct{}{}) - assert.NilError(t, err) + assert.NoError(t, err) }) t.Run("pointer to slice of pointer to structs", func(t *testing.T) { err := query.Query(db, &[]*struct{}{}) - assert.NilError(t, err) + assert.NoError(t, err) }) t.Run("pointer to slice of strings", func(t *testing.T) { err := query.Query(db, &[]int32{}) - assert.NilError(t, err) + assert.NoError(t, err) }) t.Run("pointer to slice of strings", func(t *testing.T) { err := query.Query(db, &[]*int32{}) - assert.NilError(t, err) + assert.NoError(t, err) }) } @@ -99,16 +99,16 @@ func TestScanToStruct(t *testing.T) { dest := model.Inventory{} err := query.LIMIT(1).Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, inventory1, dest) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, inventory1, dest) }) t.Run("multiple structs, just first one used", func(t *testing.T) { dest := model.Inventory{} err := query.LIMIT(10).Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, inventory1, dest) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, inventory1, dest) }) t.Run("one struct", func(t *testing.T) { @@ -117,8 +117,8 @@ func TestScanToStruct(t *testing.T) { }{} err := query.LIMIT(1).Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, inventory1, dest.Inventory) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, inventory1, dest.Inventory) }) t.Run("one struct", func(t *testing.T) { @@ -127,8 +127,8 @@ func TestScanToStruct(t *testing.T) { }{} err := query.LIMIT(1).Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, inventory1, *dest.Inventory) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, inventory1, *dest.Inventory) }) t.Run("invalid dest", func(t *testing.T) { @@ -158,7 +158,7 @@ func TestScanToStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, *dest.InventoryID, int32(1)) assert.Equal(t, dest.FilmID, int16(1)) @@ -175,7 +175,7 @@ func TestScanToStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) }) t.Run("type mismatch scanner type", func(t *testing.T) { @@ -217,10 +217,10 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, dest.Inventory, inventory1) - assert.DeepEqual(t, dest.Film, film1) - assert.DeepEqual(t, dest.Store, store1) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, dest.Inventory, inventory1) + testutils.AssertDeepEqual(t, dest.Film, film1) + testutils.AssertDeepEqual(t, dest.Store, store1) }) t.Run("embedded pointer structs", func(t *testing.T) { @@ -232,10 +232,10 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, *dest.Inventory, inventory1) - assert.DeepEqual(t, *dest.Film, film1) - assert.DeepEqual(t, *dest.Store, store1) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, *dest.Inventory, inventory1) + testutils.AssertDeepEqual(t, *dest.Film, film1) + testutils.AssertDeepEqual(t, *dest.Store, store1) }) t.Run("embedded unused structs", func(t *testing.T) { @@ -246,9 +246,9 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, dest.Inventory, inventory1) - assert.DeepEqual(t, dest.Actor, model.Actor{}) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, dest.Inventory, inventory1) + testutils.AssertDeepEqual(t, dest.Actor, model.Actor{}) }) t.Run("embedded unused pointer structs", func(t *testing.T) { @@ -259,9 +259,9 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, dest.Inventory, inventory1) - assert.DeepEqual(t, dest.Actor, (*model.Actor)(nil)) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, dest.Inventory, inventory1) + testutils.AssertDeepEqual(t, dest.Actor, (*model.Actor)(nil)) }) t.Run("embedded unused pointer structs", func(t *testing.T) { @@ -272,9 +272,9 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, dest.Inventory, inventory1) - assert.DeepEqual(t, dest.Actor, (*model.Actor)(nil)) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, dest.Inventory, inventory1) + testutils.AssertDeepEqual(t, dest.Actor, (*model.Actor)(nil)) }) t.Run("embedded pointer to selected column", func(t *testing.T) { @@ -291,9 +291,9 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, dest.Inventory, inventory1) - assert.Assert(t, dest.Actor != nil) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, dest.Inventory, inventory1) + assert.True(t, dest.Actor != nil) }) t.Run("struct embedded unused pointer", func(t *testing.T) { @@ -306,9 +306,9 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, dest.Inventory, inventory1) - assert.DeepEqual(t, dest.Actor, (*struct{ model.Actor })(nil)) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, dest.Inventory, inventory1) + testutils.AssertDeepEqual(t, dest.Actor, (*struct{ model.Actor })(nil)) }) t.Run("multiple embedded unused pointer", func(t *testing.T) { @@ -322,9 +322,9 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, dest.Inventory, inventory1) - assert.DeepEqual(t, dest.Actor, (*struct { + assert.NoError(t, err) + testutils.AssertDeepEqual(t, dest.Inventory, inventory1) + testutils.AssertDeepEqual(t, dest.Actor, (*struct { model.Actor model.Language })(nil)) @@ -341,11 +341,11 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, dest.Inventory, inventory1) - assert.Assert(t, dest.Actor != nil) - assert.DeepEqual(t, dest.Actor.Actor, model.Actor{}) - assert.DeepEqual(t, dest.Actor.Film, film1) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, dest.Inventory, inventory1) + assert.True(t, dest.Actor != nil) + testutils.AssertDeepEqual(t, dest.Actor.Actor, model.Actor{}) + testutils.AssertDeepEqual(t, dest.Actor.Film, film1) }) t.Run("field not nil, deeply nested selected model", func(t *testing.T) { @@ -361,11 +361,11 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, dest.Inventory, inventory1) - assert.Assert(t, dest.Actor != nil) - assert.Assert(t, dest.Actor.Film != nil) - assert.DeepEqual(t, dest.Actor.Film.Film, &film1) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, dest.Inventory, inventory1) + assert.True(t, dest.Actor != nil) + assert.True(t, dest.Actor.Film != nil) + testutils.AssertDeepEqual(t, dest.Actor.Film.Film, &film1) }) t.Run("embedded structs", func(t *testing.T) { @@ -398,15 +398,15 @@ func TestScanToNestedStruct(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, dest.Inventory, inventory1) - assert.DeepEqual(t, dest.Film.Film, film1) - assert.DeepEqual(t, dest.Store, store1) - assert.DeepEqual(t, dest.Film.Language, language1) - assert.DeepEqual(t, dest.Film.Lang.Language, language1) - assert.DeepEqual(t, dest.Film.Lang2.Language, language1) - assert.DeepEqual(t, dest.Film.Language2, &language1) - assert.DeepEqual(t, model.Language(*dest.Film.Language3), language1) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, dest.Inventory, inventory1) + testutils.AssertDeepEqual(t, dest.Film.Film, film1) + testutils.AssertDeepEqual(t, dest.Store, store1) + testutils.AssertDeepEqual(t, dest.Film.Language, language1) + testutils.AssertDeepEqual(t, dest.Film.Lang.Language, language1) + testutils.AssertDeepEqual(t, dest.Film.Lang2.Language, language1) + testutils.AssertDeepEqual(t, dest.Film.Language2, &language1) + testutils.AssertDeepEqual(t, model.Language(*dest.Film.Language3), language1) }) } @@ -423,18 +423,18 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 10) - assert.DeepEqual(t, dest[0], inventory1) - assert.DeepEqual(t, dest[1], inventory2) + testutils.AssertDeepEqual(t, dest[0], inventory1) + testutils.AssertDeepEqual(t, dest[1], inventory2) }) t.Run("slice of ints", func(t *testing.T) { var dest []int32 err := query.Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, dest, []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, dest, []int32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) }) @@ -442,7 +442,7 @@ func TestScanToSlice(t *testing.T) { var dest []int err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) }) t.Run("slice type mismatch", func(t *testing.T) { @@ -473,9 +473,9 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, dest.Film, film1) - assert.DeepEqual(t, dest.IDs, []int32{1, 2, 3, 4, 5, 6, 7, 8}) + assert.NoError(t, err) + testutils.AssertDeepEqual(t, dest.Film, film1) + testutils.AssertDeepEqual(t, dest.IDs, []int32{1, 2, 3, 4, 5, 6, 7, 8}) }) t.Run("slice of structs with slice of ints", func(t *testing.T) { @@ -486,12 +486,12 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) - assert.DeepEqual(t, dest[0].Film, film1) - assert.DeepEqual(t, dest[0].IDs, []int32{1, 2, 3, 4, 5, 6, 7, 8}) - assert.DeepEqual(t, dest[1].Film, film2) - assert.DeepEqual(t, dest[1].IDs, []int32{9, 10}) + testutils.AssertDeepEqual(t, dest[0].Film, film1) + testutils.AssertDeepEqual(t, dest[0].IDs, []int32{1, 2, 3, 4, 5, 6, 7, 8}) + testutils.AssertDeepEqual(t, dest[1].Film, film2) + testutils.AssertDeepEqual(t, dest[1].IDs, []int32{9, 10}) }) t.Run("slice of structs with slice of pointer to ints", func(t *testing.T) { @@ -502,13 +502,13 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) - assert.DeepEqual(t, dest[0].Film, film1) - assert.DeepEqual(t, dest[0].IDs, []*int32{Int32Ptr(1), Int32Ptr(2), Int32Ptr(3), Int32Ptr(4), + testutils.AssertDeepEqual(t, dest[0].Film, film1) + testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{Int32Ptr(1), Int32Ptr(2), Int32Ptr(3), Int32Ptr(4), Int32Ptr(5), Int32Ptr(6), Int32Ptr(7), Int32Ptr(8)}) - assert.DeepEqual(t, dest[1].Film, film2) - assert.DeepEqual(t, dest[1].IDs, []*int32{Int32Ptr(9), Int32Ptr(10)}) + testutils.AssertDeepEqual(t, dest[1].Film, film2) + testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{Int32Ptr(9), Int32Ptr(10)}) }) t.Run("complex struct 1", func(t *testing.T) { @@ -520,13 +520,13 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 10) - assert.DeepEqual(t, dest[0].Inventory, inventory1) - assert.DeepEqual(t, dest[0].Film, film1) - assert.DeepEqual(t, dest[0].Store, store1) + testutils.AssertDeepEqual(t, dest[0].Inventory, inventory1) + testutils.AssertDeepEqual(t, dest[0].Film, film1) + testutils.AssertDeepEqual(t, dest[0].Store, store1) - assert.DeepEqual(t, dest[1].Inventory, inventory2) + testutils.AssertDeepEqual(t, dest[1].Inventory, inventory2) }) t.Run("complex struct 2", func(t *testing.T) { @@ -538,13 +538,13 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 10) - assert.DeepEqual(t, dest[0].Inventory, &inventory1) - assert.DeepEqual(t, dest[0].Film, film1) - assert.DeepEqual(t, dest[0].Store, &store1) + testutils.AssertDeepEqual(t, dest[0].Inventory, &inventory1) + testutils.AssertDeepEqual(t, dest[0].Film, film1) + testutils.AssertDeepEqual(t, dest[0].Store, &store1) - assert.DeepEqual(t, dest[1].Inventory, &inventory2) + testutils.AssertDeepEqual(t, dest[1].Inventory, &inventory2) }) t.Run("complex struct 3", func(t *testing.T) { @@ -558,13 +558,13 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 10) - assert.DeepEqual(t, dest[0].Inventory, inventory1) - assert.DeepEqual(t, dest[0].Film, &film1) - assert.DeepEqual(t, dest[0].Store.Store, &store1) + testutils.AssertDeepEqual(t, dest[0].Inventory, inventory1) + testutils.AssertDeepEqual(t, dest[0].Film, &film1) + testutils.AssertDeepEqual(t, dest[0].Store.Store, &store1) - assert.DeepEqual(t, dest[1].Inventory, inventory2) + testutils.AssertDeepEqual(t, dest[1].Inventory, inventory2) }) t.Run("complex struct 4", func(t *testing.T) { @@ -579,12 +579,12 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) - assert.DeepEqual(t, dest[0].Film, film1) - assert.DeepEqual(t, len(dest[0].Inventories), 8) - assert.DeepEqual(t, dest[0].Inventories[0].Inventory, inventory1) - assert.DeepEqual(t, dest[0].Inventories[0].Store, store1) + testutils.AssertDeepEqual(t, dest[0].Film, film1) + testutils.AssertDeepEqual(t, len(dest[0].Inventories), 8) + testutils.AssertDeepEqual(t, dest[0].Inventories[0].Inventory, inventory1) + testutils.AssertDeepEqual(t, dest[0].Inventories[0].Store, store1) }) t.Run("complex struct 5", func(t *testing.T) { @@ -601,14 +601,14 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) - assert.DeepEqual(t, dest[0].Film, film1) + testutils.AssertDeepEqual(t, dest[0].Film, film1) assert.Equal(t, len(dest[0].Inventories), 8) - assert.DeepEqual(t, dest[0].Inventories[0].Inventory, inventory1) - assert.Assert(t, dest[0].Inventories[0].Rentals == nil) - assert.Assert(t, dest[0].Inventories[0].Rentals2 == nil) + testutils.AssertDeepEqual(t, dest[0].Inventories[0].Inventory, inventory1) + assert.True(t, dest[0].Inventories[0].Rentals == nil) + assert.True(t, dest[0].Inventories[0].Rentals2 == nil) }) }) @@ -638,16 +638,16 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 108) - assert.DeepEqual(t, dest[100].Country, countryUk) + testutils.AssertDeepEqual(t, dest[100].Country, countryUk) assert.Equal(t, len(dest[100].Cities), 8) - assert.DeepEqual(t, dest[100].Cities[2].City, cityLondon) + testutils.AssertDeepEqual(t, dest[100].Cities[2].City, cityLondon) assert.Equal(t, len(dest[100].Cities[2].Adresses), 2) - assert.DeepEqual(t, dest[100].Cities[2].Adresses[0].Address, address256) - assert.DeepEqual(t, dest[100].Cities[2].Adresses[0].Customer, customer256) - assert.DeepEqual(t, dest[100].Cities[2].Adresses[1].Address, addres517) - assert.DeepEqual(t, dest[100].Cities[2].Adresses[1].Customer, customer512) + testutils.AssertDeepEqual(t, dest[100].Cities[2].Adresses[0].Address, address256) + testutils.AssertDeepEqual(t, dest[100].Cities[2].Adresses[0].Customer, customer256) + testutils.AssertDeepEqual(t, dest[100].Cities[2].Adresses[1].Address, addres517) + testutils.AssertDeepEqual(t, dest[100].Cities[2].Adresses[1].Customer, customer512) }) t.Run("dest1", func(t *testing.T) { @@ -667,16 +667,16 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 108) - assert.DeepEqual(t, dest[100].Country, &countryUk) + testutils.AssertDeepEqual(t, dest[100].Country, &countryUk) assert.Equal(t, len(dest[100].Cities), 8) - assert.DeepEqual(t, dest[100].Cities[2].City, &cityLondon) + testutils.AssertDeepEqual(t, dest[100].Cities[2].City, &cityLondon) assert.Equal(t, len(*dest[100].Cities[2].Adresses), 2) - assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[0].Address, &address256) - assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[0].Customer, &customer256) - assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[1].Address, &addres517) - assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[1].Customer, &customer512) + testutils.AssertDeepEqual(t, (*dest[100].Cities[2].Adresses)[0].Address, &address256) + testutils.AssertDeepEqual(t, (*dest[100].Cities[2].Adresses)[0].Customer, &customer256) + testutils.AssertDeepEqual(t, (*dest[100].Cities[2].Adresses)[1].Address, &addres517) + testutils.AssertDeepEqual(t, (*dest[100].Cities[2].Adresses)[1].Customer, &customer512) }) }) @@ -716,8 +716,8 @@ func TestStructScanAllNull(t *testing.T) { err := query.Query(db, &dest) - assert.NilError(t, err) - assert.DeepEqual(t, dest, struct { + assert.NoError(t, err) + testutils.AssertDeepEqual(t, dest, struct { Null1 *int Null2 *int }{}) diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 35ba373..d354360 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -8,7 +8,7 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/table" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/view" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" "time" ) @@ -33,7 +33,7 @@ WHERE actor.actor_id = 2; actor := model.Actor{} err := query.Query(db, &actor) - assert.NilError(t, err) + assert.NoError(t, err) expectedActor := model.Actor{ ActorID: 2, @@ -42,7 +42,7 @@ WHERE actor.actor_id = 2; LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:47:57.62", 2), } - assert.DeepEqual(t, actor, expectedActor) + testutils.AssertDeepEqual(t, actor, expectedActor) } func TestClassicSelect(t *testing.T) { @@ -84,7 +84,7 @@ LIMIT 30; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 30) } @@ -110,13 +110,13 @@ ORDER BY customer.customer_id ASC; testutils.AssertDebugStatementSql(t, query, expectedSQL) err := query.Query(db, &customers) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(customers), 599) - assert.DeepEqual(t, customer0, customers[0]) - assert.DeepEqual(t, customer1, customers[1]) - assert.DeepEqual(t, lastCustomer, customers[598]) + testutils.AssertDeepEqual(t, customer0, customers[0]) + testutils.AssertDeepEqual(t, customer1, customers[1]) + testutils.AssertDeepEqual(t, lastCustomer, customers[598]) } func TestSelectAndUnionInProjection(t *testing.T) { @@ -164,7 +164,7 @@ LIMIT 12; dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestJoinQueryStruct(t *testing.T) { @@ -253,7 +253,7 @@ LIMIT 1000; err := query.Query(db, &languageActorFilm) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(languageActorFilm), 1) assert.Equal(t, len(languageActorFilm[0].Films), 10) assert.Equal(t, len(languageActorFilm[0].Films[0].Actors), 10) @@ -302,7 +302,7 @@ LIMIT 15; err := query.Query(db, &filmsPerLanguage) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(filmsPerLanguage), 1) assert.Equal(t, len(filmsPerLanguage[0].Film), limit) @@ -313,7 +313,7 @@ LIMIT 15; filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} err = query.Query(db, &filmsPerLanguageWithPtrs) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(filmsPerLanguage), 1) assert.Equal(t, len(filmsPerLanguage[0].Film), limit) } @@ -359,7 +359,7 @@ ORDER BY city.city_id, address.address_id, customer.customer_id; err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) assert.Equal(t, dest[0].City.City, "London") @@ -423,7 +423,7 @@ ORDER BY city.city_id, address.address_id, customer.customer_id; err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) assert.Equal(t, dest[0].Name, "London") @@ -481,7 +481,7 @@ ORDER BY city.city_id, address.address_id, customer.customer_id; err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) assert.Equal(t, dest[0].CityName, "London") @@ -538,7 +538,7 @@ ORDER BY city.city_id, address.address_id, customer.customer_id; err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) testutils.AssertJSON(t, dest, ` [ @@ -597,7 +597,7 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) { filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} err := query.Query(db, &filmsPerLanguageWithPtrs) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(filmsPerLanguageWithPtrs), 1) assert.Equal(t, len(*filmsPerLanguageWithPtrs[0].Film), int(limit)) } @@ -609,7 +609,7 @@ func TestSelect_WithoutUniqueColumnSelected(t *testing.T) { err := query.Query(db, &customers) - assert.NilError(t, err) + assert.NoError(t, err) //spew.Dump(customers) @@ -623,7 +623,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { ORDER_BY(Customer.FirstName.ASC()). Query(db, &customersAsc) - assert.NilError(t, err) + assert.NoError(t, err) firstCustomerAsc := customersAsc[0] lastCustomerAsc := customersAsc[len(customersAsc)-1] @@ -633,20 +633,20 @@ func TestSelectOrderByAscDesc(t *testing.T) { ORDER_BY(Customer.FirstName.DESC()). Query(db, &customersDesc) - assert.NilError(t, err) + assert.NoError(t, err) firstCustomerDesc := customersDesc[0] lastCustomerDesc := customersDesc[len(customersAsc)-1] - assert.DeepEqual(t, firstCustomerAsc, lastCustomerDesc) - assert.DeepEqual(t, lastCustomerAsc, firstCustomerDesc) + testutils.AssertDeepEqual(t, firstCustomerAsc, lastCustomerDesc) + testutils.AssertDeepEqual(t, lastCustomerAsc, firstCustomerDesc) customersAscDesc := []model.Customer{} err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). ORDER_BY(Customer.FirstName.ASC(), Customer.LastName.DESC()). Query(db, &customersAscDesc) - assert.NilError(t, err) + assert.NoError(t, err) customerAscDesc326 := model.Customer{ CustomerID: 67, @@ -660,8 +660,8 @@ func TestSelectOrderByAscDesc(t *testing.T) { LastName: "Knott", } - assert.DeepEqual(t, customerAscDesc326, customersAscDesc[326]) - assert.DeepEqual(t, customerAscDesc327, customersAscDesc[327]) + testutils.AssertDeepEqual(t, customerAscDesc326, customersAscDesc[326]) + testutils.AssertDeepEqual(t, customerAscDesc327, customersAscDesc[327]) } func TestSelectFullJoin(t *testing.T) { @@ -702,16 +702,16 @@ ORDER BY customer.customer_id ASC; err := query.Query(db, &allCustomersAndAddress) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(allCustomersAndAddress), 603) - assert.DeepEqual(t, allCustomersAndAddress[0].Customer, &customer0) - assert.Assert(t, allCustomersAndAddress[0].Address != nil) + testutils.AssertDeepEqual(t, allCustomersAndAddress[0].Customer, &customer0) + assert.True(t, allCustomersAndAddress[0].Address != nil) lastCustomerAddress := allCustomersAndAddress[len(allCustomersAndAddress)-1] - assert.Assert(t, lastCustomerAddress.Customer == nil) - assert.Assert(t, lastCustomerAddress.Address != nil) + assert.True(t, lastCustomerAddress.Customer == nil) + assert.True(t, lastCustomerAddress.Address != nil) } @@ -757,7 +757,7 @@ LIMIT 1000; assert.Equal(t, len(customerAddresCrosJoined), 1000) - assert.NilError(t, err) + assert.NoError(t, err) } func TestSelectSelfJoin(t *testing.T) { @@ -813,7 +813,7 @@ ORDER BY f1.film_id ASC; err := query.Query(db, &theSameLengthFilms) - assert.NilError(t, err) + assert.NoError(t, err) //spew.Dump(theSameLengthFilms) @@ -854,12 +854,12 @@ LIMIT 1000; err := query.Query(db, &films) - assert.NilError(t, err) + assert.NoError(t, err) //spew.Dump(films) assert.Equal(t, len(films), 1000) - assert.DeepEqual(t, films[0], thesameLengthFilms{"Alien Center", "Iron Moon", 46}) + testutils.AssertDeepEqual(t, films[0], thesameLengthFilms{"Alien Center", "Iron Moon", 46}) } func TestSubQuery(t *testing.T) { @@ -911,7 +911,7 @@ FROM dvds.actor err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestSelectFunctions(t *testing.T) { @@ -931,7 +931,7 @@ FROM dvds.film; err := query.Query(db, &ret) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, ret.MaxFilmRate, 4.99) } @@ -973,13 +973,13 @@ ORDER BY film.film_id ASC; maxRentalRateFilms := []model.Film{} err := query.Query(db, &maxRentalRateFilms) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(maxRentalRateFilms), 336) gRating := model.MpaaRating_G - assert.DeepEqual(t, maxRentalRateFilms[0], model.Film{ + testutils.AssertDeepEqual(t, maxRentalRateFilms[0], model.Film{ FilmID: 2, Title: "Ace Goldfinger", Description: StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), @@ -1060,7 +1060,7 @@ ORDER BY customer.customer_id, SUM(payment.amount) ASC; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //testutils.PrintJson(dest) @@ -1121,10 +1121,10 @@ ORDER BY customer_payment_sum."amount_sum" ASC; customersWithAmounts := []CustomerWithAmounts{} err := query.Query(db, &customersWithAmounts) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(customersWithAmounts), 599) - assert.DeepEqual(t, customersWithAmounts[0].Customer, &model.Customer{ + testutils.AssertDeepEqual(t, customersWithAmounts[0].Customer, &model.Customer{ CustomerID: 318, StoreID: 1, FirstName: "Brian", @@ -1145,7 +1145,7 @@ func TestSelectStaff(t *testing.T) { err := Staff.SELECT(Staff.AllColumns).Query(db, &staffs) - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertJSON(t, staffs, ` [ @@ -1203,12 +1203,12 @@ ORDER BY payment.payment_date ASC; err := query.Query(db, &payments) - assert.NilError(t, err) + assert.NoError(t, err) //spew.Dump(payments) assert.Equal(t, len(payments), 9) - assert.DeepEqual(t, payments[0], model.Payment{ + testutils.AssertDeepEqual(t, payments[0], model.Payment{ PaymentID: 17793, CustomerID: 416, StaffID: 2, @@ -1257,17 +1257,17 @@ OFFSET 20; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 10) - assert.DeepEqual(t, dest[0], model.Payment{ + testutils.AssertDeepEqual(t, dest[0], model.Payment{ PaymentID: 17523, Amount: 4.99, }) - assert.DeepEqual(t, dest[1], model.Payment{ + testutils.AssertDeepEqual(t, dest[1], model.Payment{ PaymentID: 17524, Amount: 0.99, }) - assert.DeepEqual(t, dest[9], model.Payment{ + testutils.AssertDeepEqual(t, dest[9], model.Payment{ PaymentID: 17532, Amount: 8.99, }) @@ -1283,7 +1283,7 @@ func TestAllSetOperators(t *testing.T) { dest := []model.Payment{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 20) }) @@ -1293,7 +1293,7 @@ func TestAllSetOperators(t *testing.T) { dest := []model.Payment{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 20) }) @@ -1303,7 +1303,7 @@ func TestAllSetOperators(t *testing.T) { dest := []model.Payment{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 0) }) @@ -1313,7 +1313,7 @@ func TestAllSetOperators(t *testing.T) { dest := []model.Payment{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 0) }) @@ -1323,7 +1323,7 @@ func TestAllSetOperators(t *testing.T) { dest := []model.Payment{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 10) }) @@ -1333,7 +1333,7 @@ func TestAllSetOperators(t *testing.T) { dest := []model.Payment{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 10) }) } @@ -1363,7 +1363,7 @@ LIMIT 20; err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 20) assert.Equal(t, dest[0].StaffIDNum, "TWO") assert.Equal(t, dest[1].StaffIDNum, "ONE") @@ -1396,12 +1396,12 @@ FOR` tx, _ := db.Begin() res, err := query.Exec(tx) - assert.NilError(t, err) + assert.NoError(t, err) rowsAffected, _ := res.RowsAffected() assert.Equal(t, rowsAffected, int64(3)) err = tx.Rollback() - assert.NilError(t, err) + assert.NoError(t, err) } for lockType, lockTypeStr := range getRowLockTestData() { @@ -1412,12 +1412,12 @@ FOR` tx, _ := db.Begin() res, err := query.Exec(tx) - assert.NilError(t, err) + assert.NoError(t, err) rowsAffected, _ := res.RowsAffected() assert.Equal(t, rowsAffected, int64(3)) err = tx.Rollback() - assert.NilError(t, err) + assert.NoError(t, err) } for lockType, lockTypeStr := range getRowLockTestData() { @@ -1428,12 +1428,12 @@ FOR` tx, _ := db.Begin() res, err := query.Exec(tx) - assert.NilError(t, err) + assert.NoError(t, err) rowsAffected, _ := res.RowsAffected() assert.Equal(t, rowsAffected, int64(3)) err = tx.Rollback() - assert.NilError(t, err) + assert.NoError(t, err) } } @@ -1509,7 +1509,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC; } err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //jsonSave("./testdata/quick-start-dest.json", dest) testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/quick-start-dest.json") @@ -1522,7 +1522,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC; } err = stmt.Query(db, &dest2) - assert.NilError(t, err) + assert.NoError(t, err) //jsonSave("./testdata/quick-start-dest2.json", dest2) testutils.AssertJSONFile(t, dest2, "./testdata/results/postgres/quick-start-dest2.json") @@ -1574,7 +1574,7 @@ func TestQuickStartWithSubQueries(t *testing.T) { } err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) //jsonSave("./testdata/quick-start-dest.json", dest) testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/quick-start-dest.json") @@ -1587,7 +1587,7 @@ func TestQuickStartWithSubQueries(t *testing.T) { } err = stmt.Query(db, &dest2) - assert.NilError(t, err) + assert.NoError(t, err) //jsonSave("./testdata/quick-start-dest2.json", dest2) testutils.AssertJSONFile(t, dest2, "./testdata/results/postgres/quick-start-dest2.json") @@ -1620,7 +1620,7 @@ SELECT true, dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestWindowFunction(t *testing.T) { @@ -1692,7 +1692,7 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date; dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestWindowClause(t *testing.T) { @@ -1729,7 +1729,7 @@ ORDER BY payment.customer_id; dest := []struct{}{} err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) } func TestSimpleView(t *testing.T) { @@ -1751,7 +1751,7 @@ func TestSimpleView(t *testing.T) { var dest []ActorInfo err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) testutils.AssertJSON(t, dest[1:2], ` [ @@ -1785,7 +1785,7 @@ func TestJoinViewWithTable(t *testing.T) { fmt.Println(query.DebugSql()) err := query.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 2) assert.Equal(t, len(dest[0].Rentals), 32) @@ -1820,7 +1820,7 @@ LIMIT 3; `) var dest []model.Customer err := stmt.Query(db, &dest) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(dest), 3) } diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index e3366de..57f579e 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -6,7 +6,7 @@ import ( . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" "time" ) @@ -35,9 +35,9 @@ WHERE link.name = 'Bing'; WHERE(Link.Name.EQ(String("Bong"))). Query(db, &links) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(links), 1) - assert.DeepEqual(t, links[0], model.Link{ + testutils.AssertDeepEqual(t, links[0], model.Link{ ID: 204, URL: "http://bong.com", Name: "Bong", @@ -99,7 +99,7 @@ RETURNING link.id AS "link.id", err := stmt.Query(db, &links) - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, len(links), 2) assert.Equal(t, links[0].Name, "DuckDuckGo") assert.Equal(t, links[1].Name, "DuckDuckGo") @@ -293,10 +293,10 @@ func setupLinkTableForUpdateTest(t *testing.T) { VALUES(204, "http://www.bing.com", "Bing", DEFAULT). Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) } func cleanUpLinkTable(t *testing.T) { _, err := Link.DELETE().WHERE(Link.ID.GT(Int(0))).Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) } diff --git a/tests/postgres/util_test.go b/tests/postgres/util_test.go index 3f00e7b..e7a7816 100644 --- a/tests/postgres/util_test.go +++ b/tests/postgres/util_test.go @@ -5,16 +5,16 @@ import ( "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" "github.com/google/uuid" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" "testing" ) func AssertExec(t *testing.T, stmt jet.Statement, rowsAffected int64) { res, err := stmt.Exec(db) - assert.NilError(t, err) + assert.NoError(t, err) rows, err := res.RowsAffected() - assert.NilError(t, err) + assert.NoError(t, err) assert.Equal(t, rows, rowsAffected) } From bf66e151acedced7d91c1df6328f6dae4b2f32c2 Mon Sep 17 00:00:00 2001 From: go-jet Date: Tue, 11 Feb 2020 10:33:00 +0100 Subject: [PATCH 15/19] Replace gotest.tools with github.com/stretchr/testify. --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 4907252..39a79bb 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -94,7 +94,7 @@ jobs: - run: mkdir -p $TEST_RESULTS - - run: go test -v ./... -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/qrm/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml + - run: go test -v ./... -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/qrm/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 # | go-junit-report > $TEST_RESULTS/results.xml - run: name: Upload code coverage From f154701e606ea19d6e338e3d42406690a63bcfe9 Mon Sep 17 00:00:00 2001 From: go-jet Date: Tue, 11 Feb 2020 10:36:46 +0100 Subject: [PATCH 16/19] Update circleci config.yml --- .circleci/config.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 39a79bb..b1cb643 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -47,6 +47,7 @@ jobs: go get github.com/pkg/profile go get github.com/stretchr/testify/assert + go get github.com/google/go-cmp/cmp go get github.com/davecgh/go-spew/spew go get github.com/jstemmer/go-junit-report @@ -94,7 +95,7 @@ jobs: - run: mkdir -p $TEST_RESULTS - - run: go test -v ./... -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/qrm/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 # | go-junit-report > $TEST_RESULTS/results.xml + - run: go test -v ./... -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/qrm/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml - run: name: Upload code coverage @@ -143,6 +144,7 @@ jobs: go get github.com/pkg/profile go get github.com/stretchr/testify/assert + go get github.com/google/go-cmp/cmp go get github.com/davecgh/go-spew/spew go get github.com/jstemmer/go-junit-report From 63c1fd643031651d22ced1a8410cb9669c45f69f Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 15 Feb 2020 11:20:51 +0100 Subject: [PATCH 17/19] [bug] Fix crash on generating enum sql builder files when enum contains numeric values. --- generator/internal/template/generate.go | 3 ++- generator/internal/template/templates.go | 4 ++-- internal/utils/utils.go | 10 ++++++++++ internal/utils/utils_test.go | 5 +++++ tests/postgres/generator_test.go | 25 ++++++++++++++++++++++-- 5 files changed, 42 insertions(+), 5 deletions(-) diff --git a/generator/internal/template/generate.go b/generator/internal/template/generate.go index a076bb1..3872f76 100644 --- a/generator/internal/template/generate.go +++ b/generator/internal/template/generate.go @@ -73,7 +73,8 @@ func generateGoFiles(dirPath, packageName string, template string, metaDataList func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect, params ...map[string]interface{}) ([]byte, error) { t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{ - "ToGoIdentifier": utils.ToGoIdentifier, + "ToGoIdentifier": utils.ToGoIdentifier, + "ToGoEnumValueIdentifier": utils.ToGoEnumValueIdentifier, "now": func() string { return time.Now().Format(time.RFC850) }, diff --git a/generator/internal/template/templates.go b/generator/internal/template/templates.go index e5328a3..40e4773 100644 --- a/generator/internal/template/templates.go +++ b/generator/internal/template/templates.go @@ -94,11 +94,11 @@ import "github.com/go-jet/jet/{{dialect.PackageName}}" var {{ToGoIdentifier $.Name}} = &struct { {{- range $index, $element := .Values}} - {{ToGoIdentifier $element}} {{dialect.PackageName}}.StringExpression + {{ToGoEnumValueIdentifier $.Name $element}} {{dialect.PackageName}}.StringExpression {{- end}} } { {{- range $index, $element := .Values}} - {{ToGoIdentifier $element}}: {{dialect.PackageName}}.NewEnumValue("{{$element}}"), + {{ToGoEnumValueIdentifier $.Name $element}}: {{dialect.PackageName}}.NewEnumValue("{{$element}}"), {{- end}} } ` diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 42a5c36..5301605 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -10,6 +10,7 @@ import ( "reflect" "strings" "time" + "unicode" ) // ToGoIdentifier converts database to Go identifier. @@ -17,6 +18,15 @@ func ToGoIdentifier(databaseIdentifier string) string { return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier)) } +func ToGoEnumValueIdentifier(enumName, enumValue string) string { + enumValueIdentifier := ToGoIdentifier(enumValue) + if !unicode.IsLetter([]rune(enumValueIdentifier)[0]) { + return ToGoIdentifier(enumName) + enumValueIdentifier + } + + return enumValueIdentifier +} + // ToGoFileName converts database identifier to Go file name. func ToGoFileName(databaseIdentifier string) string { return strings.ToLower(replaceInvalidChars(databaseIdentifier)) diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go index ac57bbd..787b14b 100644 --- a/internal/utils/utils_test.go +++ b/internal/utils/utils_test.go @@ -25,6 +25,11 @@ func TestToGoIdentifier(t *testing.T) { assert.Equal(t, ToGoIdentifier("My-Table"), "MyTable") } +func TestToGoEnumValueIdentifier(t *testing.T) { + assert.Equal(t, ToGoEnumValueIdentifier("enum_name", "enum_value"), "EnumValue") + assert.Equal(t, ToGoEnumValueIdentifier("NumEnum", "100"), "NumEnum100") +} + func TestErrorCatchErr(t *testing.T) { var err error diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index b98633a..52ddfac 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -283,14 +283,15 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { enumFiles, err := ioutil.ReadDir(enumDir) assert.NoError(t, err) - testutils.AssertFileNamesEqual(t, enumFiles, "mood.go") + testutils.AssertFileNamesEqual(t, enumFiles, "mood.go", "level.go") testutils.AssertFileContent(t, enumDir+"mood.go", "\npackage enum", moodEnumContent) + testutils.AssertFileContent(t, enumDir+"level.go", "\npackage enum", levelEnumContent) modelFiles, err := ioutil.ReadDir(modelDir) assert.NoError(t, err) testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go", - "mood.go", "person.go", "person_phone.go", "weird_names_table.go") + "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go") testutils.AssertFileContent(t, modelDir+"all_types.go", "\npackage model", allTypesModelContent) @@ -319,6 +320,26 @@ var Mood = &struct { } ` +var levelEnumContent = ` +package enum + +import "github.com/go-jet/jet/postgres" + +var Level = &struct { + Level1 postgres.StringExpression + Level2 postgres.StringExpression + Level3 postgres.StringExpression + Level4 postgres.StringExpression + Level5 postgres.StringExpression +}{ + Level1: postgres.NewEnumValue("1"), + Level2: postgres.NewEnumValue("2"), + Level3: postgres.NewEnumValue("3"), + Level4: postgres.NewEnumValue("4"), + Level5: postgres.NewEnumValue("5"), +} +` + var allTypesModelContent = ` package model From 3019fdbbb2d6c62af097f2fb3112500026673178 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 16 Feb 2020 10:25:21 +0100 Subject: [PATCH 18/19] [bug] Escape reserved words used as identifier. --- internal/jet/dialect.go | 20 ++++++++ internal/jet/sql_builder.go | 2 +- internal/testutils/test_utils.go | 1 + internal/utils/utils.go | 1 + postgres/dialect.go | 81 ++++++++++++++++++++++++++++++++ postgres/dialect_test.go | 18 +++++++ tests/mysql/alltypes_test.go | 50 ++++++++++++++++++++ tests/postgres/generator_test.go | 4 +- tests/postgres/sample_test.go | 51 ++++++++++++++++++++ tests/testdata | 2 +- 10 files changed, 226 insertions(+), 4 deletions(-) diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index acf03d9..434e3b8 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -1,5 +1,7 @@ package jet +import "strings" + // Dialect interface type Dialect interface { Name() string @@ -9,6 +11,7 @@ type Dialect interface { AliasQuoteChar() byte IdentifierQuoteChar() byte ArgumentPlaceholder() QueryPlaceholderFunc + IsReservedWord(name string) bool } // SerializerFunc func @@ -29,6 +32,7 @@ type DialectParams struct { AliasQuoteChar byte IdentifierQuoteChar byte ArgumentPlaceholder QueryPlaceholderFunc + ReservedWords []string } // NewDialect creates new dialect with params @@ -41,6 +45,7 @@ func NewDialect(params DialectParams) Dialect { aliasQuoteChar: params.AliasQuoteChar, identifierQuoteChar: params.IdentifierQuoteChar, argumentPlaceholder: params.ArgumentPlaceholder, + reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords), } } @@ -52,6 +57,7 @@ type dialectImpl struct { aliasQuoteChar byte identifierQuoteChar byte argumentPlaceholder QueryPlaceholderFunc + reservedWords map[string]bool supportsReturning bool } @@ -89,3 +95,17 @@ func (d *dialectImpl) IdentifierQuoteChar() byte { func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc { return d.argumentPlaceholder } + +func (d *dialectImpl) IsReservedWord(name string) bool { + _, isReservedWord := d.reservedWords[strings.ToLower(name)] + return isReservedWord +} + +func arrayOfStringsToMapOfStrings(arr []string) map[string]bool { + ret := map[string]bool{} + for _, elem := range arr { + ret[strings.ToLower(elem)] = true + } + + return ret +} diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index f71ca1e..ef7f801 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -98,7 +98,7 @@ func (s *SQLBuilder) WriteString(str string) { // WriteIdentifier adds identifier to output SQL func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) { - if shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 { + if s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 { identQuoteChar := string(s.Dialect.IdentifierQuoteChar()) s.WriteString(identQuoteChar + name + identQuoteChar) } else { diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 5719e75..f0c18e9 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -211,6 +211,7 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st } } +// AssertDeepEqual checks if actual and expected objects are deeply equal. func AssertDeepEqual(t *testing.T, actual, expected interface{}) { assert.True(t, cmp.Equal(actual, expected)) } diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 5301605..e346775 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -18,6 +18,7 @@ func ToGoIdentifier(databaseIdentifier string) string { return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier)) } +// ToGoEnumValueIdentifier converts enum value name to Go identifier name. func ToGoEnumValueIdentifier(enumName, enumValue string) string { enumValueIdentifier := ToGoIdentifier(enumValue) if !unicode.IsLetter([]rune(enumValueIdentifier)[0]) { diff --git a/postgres/dialect.go b/postgres/dialect.go index c1e8c0b..b440c5d 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -24,6 +24,7 @@ func newDialect() jet.Dialect { ArgumentPlaceholder: func(ord int) string { return "$" + strconv.Itoa(ord) }, + ReservedWords: reservedWords, } return jet.NewDialect(dialectParams) @@ -105,3 +106,83 @@ func postgresNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.Serializer jet.Serialize(expressions[1], statement, out, options...) } } + +var reservedWords = []string{ + "ALL", + "ANALYSE", + "ANALYZE", + "AND", + "ANY", + "ARRAY", + "AS", + "ASC", + "ASYMMETRIC", + "BOTH", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "CONSTRAINT", + "CREATE", + "CURRENT_CATALOG", + "CURRENT_DATE", + "CURRENT_ROLE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "DEFAULT", + "DEFERRABLE", + "DESC", + "DISTINCT", + "DO", + "ELSE", + "END", + "EXCEPT", + "FALSE", + "FETCH", + "FOR", + "FOREIGN", + "FROM", + "GRANT", + "GROUP", + "HAVING", + "IN", + "INITIALLY", + "INTERSECT", + "INTO", + "LATERAL", + "LEADING", + "LIMIT", + "LOCALTIME", + "LOCALTIMESTAMP", + "NOT", + "NULL", + "OFFSET", + "ON", + "ONLY", + "OR", + "ORDER", + "PLACING", + "PRIMARY", + "REFERENCES", + "RETURNING", + "SELECT", + "SESSION_USER", + "SOME", + "SYMMETRIC", + "TABLE", + "THEN", + "TO", + "TRAILING", + "TRUE", + "UNION", + "UNIQUE", + "USER", + "USING", + "VARIADIC", + "WHEN", + "WHERE", + "WINDOW", + "WITH", +} diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index f53587e..7bf9242 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -74,3 +74,21 @@ func TestNOT_IN(t *testing.T) { FROM db.table2 )))`, int64(12)) } + +func TestReservedWordEscaped(t *testing.T) { + var table1ColUser = IntervalColumn("user") + var table1ColVariadic = IntervalColumn("VARIADIC") + var table1ColProcedure = IntervalColumn("procedure") + + _ = NewTable( + "db", + "table1", + table1ColUser, + table1ColVariadic, + table1ColProcedure, + ) + + assertSerialize(t, table1ColUser, `table1."user"`) + assertSerialize(t, table1ColVariadic, `table1."VARIADIC"`) + assertSerialize(t, table1ColProcedure, `table1.procedure`) +} diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 98a1a4a..2c6768e 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -1095,3 +1095,53 @@ var allTypesJson = ` } ] ` + +func TestReservedWord(t *testing.T) { + stmt := SELECT(User.AllColumns). + FROM(User) + + // NOTE: A word that follows a period in a qualified name must be an identifier, so it + // need not be quoted even if it is reserved + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT user.column AS "user.column", + user.use AS "user.use", + user.ceil AS "user.ceil", + user.commit AS "user.commit", + user.create AS "user.create", + user.default AS "user.default", + user.desc AS "user.desc", + user.empty AS "user.empty", + user.float AS "user.float", + user.join AS "user.join", + user.like AS "user.like", + user.max AS "user.max", + user.rank AS "user.rank" +FROM test_sample.user; +`) + + var dest []model.User + err := stmt.Query(db, &dest) + assert.NoError(t, err) + + testutils.PrintJson(dest) + + testutils.AssertJSON(t, dest, ` +[ + { + "Column": "Column", + "Use": "CHECK", + "Ceil": "CEIL", + "Commit": "COMMIT", + "Create": "CREATE", + "Default": "DEFAULT", + "Desc": "DESC", + "Empty": "EMPTY", + "Float": "FLOAT", + "Join": "JOIN", + "Like": "LIKE", + "Max": "MAX", + "Rank": "RANK" + } +] +`) +} diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 52ddfac..54d7cc4 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -291,7 +291,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { assert.NoError(t, err) testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go", - "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go") + "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go") testutils.AssertFileContent(t, modelDir+"all_types.go", "\npackage model", allTypesModelContent) @@ -299,7 +299,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { assert.NoError(t, err) testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go", - "person.go", "person_phone.go", "weird_names_table.go") + "person.go", "person_phone.go", "weird_names_table.go", "user.go") testutils.AssertFileContent(t, tableDir+"all_types.go", "\npackage table", allTypesTableContent) } diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 5989d17..2cfbc85 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -328,3 +328,54 @@ FROM test_sample."WEIRD NAMES TABLE"; WeirdColuName16: "Doe", }) } + +func TestReserwedWordEscape(t *testing.T) { + stmt := SELECT(User.AllColumns). + FROM(User) + + //fmt.Println(stmt.DebugSql()) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT "User"."column" AS "User.column", + "User"."check" AS "User.check", + "User".ceil AS "User.ceil", + "User".commit AS "User.commit", + "User"."create" AS "User.create", + "User"."default" AS "User.default", + "User"."desc" AS "User.desc", + "User".empty AS "User.empty", + "User".float AS "User.float", + "User".join AS "User.join", + "User".like AS "User.like", + "User".max AS "User.max", + "User".rank AS "User.rank" +FROM test_sample."User"; +`) + + var dest []model.User + + err := stmt.Query(db, &dest) + assert.NoError(t, err) + + testutils.PrintJson(dest) + + testutils.AssertJSON(t, dest, ` +[ + { + "Column": "Column", + "Check": "CHECK", + "Ceil": "CEIL", + "Commit": "COMMIT", + "Create": "CREATE", + "Default": "DEFAULT", + "Desc": "DESC", + "Empty": "EMPTY", + "Float": "FLOAT", + "Join": "JOIN", + "Like": "LIKE", + "Max": "MAX", + "Rank": "RANK" + } +] +`) +} diff --git a/tests/testdata b/tests/testdata index 02e0795..889e07c 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 02e0795d1e06b959d0c564dc1e349159d57b1bf6 +Subproject commit 889e07c0ebaf6b4021e31cce29b5861eb5c8cc17 From 4c6caa403e94b8a156b477722e953b47da373a37 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 16 Feb 2020 17:35:39 +0100 Subject: [PATCH 19/19] Test sample for dynamic projection and dynamic condition. --- tests/postgres/select_test.go | 72 +++++++++++++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index d354360..b641ec0 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1792,12 +1792,21 @@ func TestJoinViewWithTable(t *testing.T) { assert.Equal(t, len(dest[1].Rentals), 27) } -func TestConditionalProjectionList(t *testing.T) { +func TestDynamicProjectionList(t *testing.T) { + + var request struct { + ColumnsToSelect []string + ShowFullName bool + } + + request.ColumnsToSelect = []string{"customer_id", "create_date"} + request.ShowFullName = true + + // ... + projectionList := ProjectionList{} - columnsToSelect := []string{"customer_id", "create_date"} - - for _, columnName := range columnsToSelect { + for _, columnName := range request.ColumnsToSelect { switch columnName { case Customer.CustomerID.Name(): projectionList = append(projectionList, Customer.CustomerID) @@ -1808,6 +1817,11 @@ func TestConditionalProjectionList(t *testing.T) { } } + var showFullName bool + if showFullName { + projectionList = append(projectionList, Customer.FirstName.CONCAT(Customer.LastName)) + } + stmt := SELECT(projectionList). FROM(Customer). LIMIT(3) @@ -1824,3 +1838,53 @@ LIMIT 3; assert.Equal(t, len(dest), 3) } + +func TestDynamicCondition(t *testing.T) { + var request struct { + CustomerID *int64 + Email *string + Active *bool + } + + request.CustomerID = Int64Ptr(1) + request.Active = BoolPtr(true) + + // ... + + condition := Bool(true) + + if request.CustomerID != nil { + condition = condition.AND(Customer.CustomerID.EQ(Int(*request.CustomerID))) + } + if request.Email != nil { + condition = condition.AND(Customer.Email.EQ(String(*request.Email))) + } + if request.Active != nil { + condition = condition.AND(Customer.Activebool.EQ(Bool(*request.Active))) + } + + stmt := SELECT(Customer.AllColumns). + FROM(Customer). + WHERE(condition) + + testutils.AssertStatementSql(t, stmt, ` +SELECT customer.customer_id AS "customer.customer_id", + customer.store_id AS "customer.store_id", + customer.first_name AS "customer.first_name", + customer.last_name AS "customer.last_name", + customer.email AS "customer.email", + customer.address_id AS "customer.address_id", + customer.activebool AS "customer.activebool", + customer.create_date AS "customer.create_date", + customer.last_update AS "customer.last_update", + customer.active AS "customer.active" +FROM dvds.customer +WHERE ($1 AND (customer.customer_id = $2)) AND (customer.activebool = $3); +`, true, int64(1), true) + + dest := []model.Customer{} + err := stmt.Query(db, &dest) + assert.NoError(t, err) + assert.Len(t, dest, 1) + testutils.AssertDeepEqual(t, dest[0], customer0) +}