diff --git a/postgres/cast.go b/postgres/cast.go index e9ec209..0f4b255 100644 --- a/postgres/cast.go +++ b/postgres/cast.go @@ -2,8 +2,9 @@ package postgres import ( "fmt" - "github.com/go-jet/jet/internal/jet" "strconv" + + "github.com/go-jet/jet/internal/jet" ) type cast interface { @@ -32,7 +33,7 @@ type cast interface { AS_TIME() TimeExpression // Cast expression AS text type AS_TEXT() StringExpression - + // Cast expression AS bytea type AS_BYTEA() StringExpression // Cast expression AS time with time timezone type AS_TIMEZ() TimezExpression @@ -40,6 +41,8 @@ type cast interface { AS_TIMESTAMP() TimestampExpression // Cast expression AS timestamp with timezone type AS_TIMESTAMPZ() TimestampzExpression + // Cast expression AS interval type + AS_INTERVAL() IntervalExpression } type castImpl struct { @@ -151,3 +154,8 @@ func (b *castImpl) AS_TIMESTAMP() TimestampExpression { func (b *castImpl) AS_TIMESTAMPZ() TimestampzExpression { return TimestampzExp(b.AS("timestamp with time zone")) } + +// Cast expression AS interval type +func (b *castImpl) AS_INTERVAL() IntervalExpression { + return IntervalExp(b.AS("interval")) +} diff --git a/postgres/cast_test.go b/postgres/cast_test.go index a1e4be5..e02336a 100644 --- a/postgres/cast_test.go +++ b/postgres/cast_test.go @@ -62,3 +62,10 @@ func TestExpressionCAST_AS_TIMESTAMP(t *testing.T) { func TestExpressionCAST_AS_TIMESTAMPZ(t *testing.T) { assertSerialize(t, CAST(table2Col3).AS_TIMESTAMPZ(), "table2.col3::timestamp with time zone") } + +func TestExpressionCAST_AS_INTERVAL(t *testing.T) { + assertSerialize(t, CAST(table2ColTimez).AS_INTERVAL(), "table2.col_timez::interval") + assertSerialize(t, CAST(Time(20, 11, 10)).AS_INTERVAL(), "$1::time without time zone::interval", "20:11:10") + assertSerialize(t, table2ColDate.SUB(CAST(Time(20, 11, 10)).AS_INTERVAL()), + "(table2.col_date - $1::time without time zone::interval)", "20:11:10") +} diff --git a/postgres/interval.go b/postgres/interval.go index dba2c6c..f6344fd 100644 --- a/postgres/interval.go +++ b/postgres/interval.go @@ -129,3 +129,22 @@ func unitToString(unit quantityAndUnit) string { panic("jet: invalid INTERVAL unit type") } } + +//---------------------------------------------------// + +type intervalWrapper struct { + jet.IsInterval + Expression +} + +func newIntervalExpressionWrap(expression Expression) IntervalExpression { + intervalWrap := intervalWrapper{Expression: expression} + return &intervalWrap +} + +// IntervalExp is interval expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as interval expression. +// Does not add sql cast to generated sql builder output. +func IntervalExp(expression Expression) IntervalExpression { + return newIntervalExpressionWrap(expression) +} diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index eed0695..7b9d1d6 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -8,7 +8,6 @@ import ( "gotest.tools/assert" "github.com/go-jet/jet/internal/testutils" - "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" @@ -136,22 +135,23 @@ LIMIT $5; func TestExpressionCast(t *testing.T) { query := AllTypes.SELECT( - postgres.CAST(Int(150)).AS_CHAR(12).AS("char12"), - postgres.CAST(String("TRUE")).AS_BOOL(), - postgres.CAST(String("111")).AS_SMALLINT(), - postgres.CAST(String("111")).AS_INTEGER(), - postgres.CAST(String("111")).AS_BIGINT(), - postgres.CAST(String("11.23")).AS_NUMERIC(30, 10), - postgres.CAST(String("11.23")).AS_NUMERIC(30), - postgres.CAST(String("11.23")).AS_NUMERIC(), - postgres.CAST(String("11.23")).AS_REAL(), - postgres.CAST(String("11.23")).AS_DOUBLE(), - postgres.CAST(Int(234)).AS_TEXT(), - postgres.CAST(String("1/8/1999")).AS_DATE(), - postgres.CAST(String("04:05:06.789")).AS_TIME(), - postgres.CAST(String("04:05:06 PST")).AS_TIMEZ(), - postgres.CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(), - postgres.CAST(String("January 8 04:05:06 1999 PST")).AS_TIMESTAMPZ(), + CAST(Int(150)).AS_CHAR(12).AS("char12"), + CAST(String("TRUE")).AS_BOOL(), + CAST(String("111")).AS_SMALLINT(), + CAST(String("111")).AS_INTEGER(), + CAST(String("111")).AS_BIGINT(), + CAST(String("11.23")).AS_NUMERIC(30, 10), + CAST(String("11.23")).AS_NUMERIC(30), + CAST(String("11.23")).AS_NUMERIC(), + CAST(String("11.23")).AS_REAL(), + CAST(String("11.23")).AS_DOUBLE(), + CAST(Int(234)).AS_TEXT(), + CAST(String("1/8/1999")).AS_DATE(), + CAST(String("04:05:06.789")).AS_TIME(), + CAST(String("04:05:06 PST")).AS_TIMEZ(), + CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(), + CAST(String("January 8 04:05:06 1999 PST")).AS_TIMESTAMPZ(), + CAST(String("04:05:06")).AS_INTERVAL(), TO_CHAR(AllTypes.Timestamp, String("HH12:MI:SS")), TO_CHAR(AllTypes.Integer, String("999")), @@ -361,7 +361,7 @@ func TestFloatOperators(t *testing.T) { TRUNC(ABSf(AllTypes.Decimal), Int(2)).AS("abs"), TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int(2)).AS("power"), TRUNC(SQRT(AllTypes.Decimal), Int(2)).AS("sqrt"), - TRUNC(postgres.CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int(2)).AS("cbrt"), + TRUNC(CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int(2)).AS("cbrt"), CEIL(AllTypes.Real).AS("ceil"), FLOOR(AllTypes.Real).AS("floor"), @@ -619,6 +619,8 @@ func TestTimeExpression(t *testing.T) { AllTypes.Timestampz.ADD(INTERVAL(1, HOUR)), AllTypes.Timestampz.SUB(INTERVAL(1, MINUTE)), + AllTypes.Date.SUB(CAST(String("04:05:06")).AS_INTERVAL()), + CURRENT_DATE(), CURRENT_TIME(), CURRENT_TIME(2),