From c342f296cae9b4a426bec5886a1fc2b85a59a96e Mon Sep 17 00:00:00 2001 From: go-jet Date: Thu, 1 Aug 2019 10:39:57 +0200 Subject: [PATCH] MySQL cast expressions. Simplified. --- cast.go | 41 +++++++++++++++++++++++---- cast_test.go | 12 ++++---- mysql/mysql_cast.go | 29 ++++++------------- mysql/mysql_types.go | 2 ++ postgres/date_expression_test.go | 16 +++++------ postgres/postgres_cast.go | 48 ++++++++++++-------------------- postgres/postgres_cast_test.go | 6 +++- tests/mysql/cast_test.go | 23 +++++++++++---- tests/postgres/alltypes_test.go | 3 +- visitor.go | 4 +-- 10 files changed, 103 insertions(+), 81 deletions(-) diff --git a/cast.go b/cast.go index 759dfc5..7143395 100644 --- a/cast.go +++ b/cast.go @@ -1,24 +1,32 @@ package jet -type CastType string +import "strconv" type Cast interface { - As(castType CastType) Expression + AS(castType string) Expression + + AS_CHAR(lenght ...int) StringExpression + // Cast expression AS date type + AS_DATE() DateExpression + // Cast expression AS numeric type, using precision and optionally scale + AS_DECIMAL() FloatExpression + // Cast expression AS time type + AS_TIME() TimeExpression } type CastImpl struct { expression Expression } -func NewCastImpl(expression Expression) Cast { +func NewCastImpl(expression Expression) CastImpl { castImpl := CastImpl{ expression: expression, } - return &castImpl + return castImpl } -func (b *CastImpl) As(castType CastType) Expression { +func (b *CastImpl) AS(castType string) Expression { castExp := &castExpression{ expression: b.expression, cast: string(castType), @@ -29,6 +37,29 @@ func (b *CastImpl) As(castType CastType) Expression { return castExp } +func (b *CastImpl) AS_CHAR(lenght ...int) StringExpression { + if len(lenght) > 0 { + return StringExp(b.AS("CHAR(" + strconv.Itoa(lenght[0]) + ")")) + } + + return StringExp(b.AS("CHAR")) +} + +// Cast expression AS date type +func (b *CastImpl) AS_DATE() DateExpression { + return DateExp(b.AS("DATE")) +} + +// Cast expression AS date type +func (b *CastImpl) AS_DECIMAL() FloatExpression { + return FloatExp(b.AS("DECIMAL")) +} + +// Cast expression AS date type +func (b *CastImpl) AS_TIME() TimeExpression { + return TimeExp(b.AS("TIME")) +} + type castExpression struct { expressionInterfaceImpl diff --git a/cast_test.go b/cast_test.go index 5a15d1b..996a5f8 100644 --- a/cast_test.go +++ b/cast_test.go @@ -1,9 +1,7 @@ package jet -import "testing" - -func TestCastAS(t *testing.T) { - AssertClauseSerialize(t, NewCastImpl(Int(1)).As("boolean"), "CAST(? AS boolean)", int64(1)) - AssertClauseSerialize(t, NewCastImpl(table2Col3).As("real"), "CAST(table2.col3 AS real)") - AssertClauseSerialize(t, NewCastImpl(table2Col3.ADD(table2Col3)).As("integer"), "CAST((table2.col3 + table2.col3) AS integer)") -} +//func TestCastAS(t *testing.T) { +// AssertClauseSerialize(t, NewCastImpl(Int(1)).AS("boolean"), "CAST(? AS boolean)", int64(1)) +// AssertClauseSerialize(t, NewCastImpl(table2Col3).AS("real"), "CAST(table2.col3 AS real)") +// AssertClauseSerialize(t, NewCastImpl(table2Col3.ADD(table2Col3)).AS("integer"), "CAST((table2.col3 + table2.col3) AS integer)") +//} diff --git a/mysql/mysql_cast.go b/mysql/mysql_cast.go index 7667f75..bcb853f 100644 --- a/mysql/mysql_cast.go +++ b/mysql/mysql_cast.go @@ -5,51 +5,38 @@ import ( ) type cast interface { - AS_DATE() DateExpression - AS_TIME() TimeExpression + jet.Cast + AS_DATETIME() DateTimeExpression - AS_CHAR() StringExpression AS_SIGNED() IntegerExpression AS_UNSIGNED() IntegerExpression AS_BINARY() StringExpression } type castImpl struct { - jet.Cast + jet.CastImpl } func CAST(expr jet.Expression) cast { castImpl := &castImpl{} - castImpl.Cast = jet.NewCastImpl(expr) + castImpl.CastImpl = jet.NewCastImpl(expr) return castImpl } -func (c *castImpl) AS_DATE() DateExpression { - return jet.DateExp(c.As("DATE")) -} - func (c *castImpl) AS_DATETIME() DateTimeExpression { - return jet.TimestampExp(c.As("DATETIME")) -} - -func (c *castImpl) AS_TIME() TimeExpression { - return jet.TimeExp(c.As("TIME")) -} - -func (c *castImpl) AS_CHAR() StringExpression { - return jet.StringExp(c.As("CHAR")) + return jet.TimestampExp(c.AS("DATETIME")) } func (c *castImpl) AS_SIGNED() IntegerExpression { - return jet.IntExp(c.As("SIGNED")) + return jet.IntExp(c.AS("SIGNED")) } func (c *castImpl) AS_UNSIGNED() IntegerExpression { - return jet.IntExp(c.As("UNSIGNED")) + return jet.IntExp(c.AS("UNSIGNED")) } func (c *castImpl) AS_BINARY() StringExpression { - return jet.StringExp(c.As("BINARY")) + return jet.StringExp(c.AS("BINARY")) } diff --git a/mysql/mysql_types.go b/mysql/mysql_types.go index 0b12069..acf9e07 100644 --- a/mysql/mysql_types.go +++ b/mysql/mysql_types.go @@ -2,6 +2,8 @@ package mysql import "github.com/go-jet/jet" +type Expression jet.Expression + type ColumnBool jet.ColumnBool type BoolExpression jet.BoolExpression diff --git a/postgres/date_expression_test.go b/postgres/date_expression_test.go index 56ca426..ff37d76 100644 --- a/postgres/date_expression_test.go +++ b/postgres/date_expression_test.go @@ -9,40 +9,40 @@ var dateVar = Date(2000, 12, 30) func TestDateExpressionEQ(t *testing.T) { jet.AssertPostgreClauseSerialize(t, table1ColDate.EQ(table2ColDate), "(table1.col_date = table2.col_date)") - jet.AssertPostgreClauseSerialize(t, table1ColDate.EQ(dateVar), "(table1.col_date = $1::date)", "2000-12-30") + jet.AssertPostgreClauseSerialize(t, table1ColDate.EQ(dateVar), "(table1.col_date = $1::DATE)", "2000-12-30") } func TestDateExpressionNOT_EQ(t *testing.T) { jet.AssertPostgreClauseSerialize(t, table1ColDate.NOT_EQ(table2ColDate), "(table1.col_date != table2.col_date)") - jet.AssertPostgreClauseSerialize(t, table1ColDate.NOT_EQ(dateVar), "(table1.col_date != $1::date)", "2000-12-30") + jet.AssertPostgreClauseSerialize(t, table1ColDate.NOT_EQ(dateVar), "(table1.col_date != $1::DATE)", "2000-12-30") } func TestDateExpressionIS_DISTINCT_FROM(t *testing.T) { jet.AssertPostgreClauseSerialize(t, table1ColDate.IS_DISTINCT_FROM(table2ColDate), "(table1.col_date IS DISTINCT FROM table2.col_date)") - jet.AssertPostgreClauseSerialize(t, table1ColDate.IS_DISTINCT_FROM(dateVar), "(table1.col_date IS DISTINCT FROM $1::date)", "2000-12-30") + jet.AssertPostgreClauseSerialize(t, table1ColDate.IS_DISTINCT_FROM(dateVar), "(table1.col_date IS DISTINCT FROM $1::DATE)", "2000-12-30") } func TestDateExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { jet.AssertPostgreClauseSerialize(t, table1ColDate.IS_NOT_DISTINCT_FROM(table2ColDate), "(table1.col_date IS NOT DISTINCT FROM table2.col_date)") - jet.AssertPostgreClauseSerialize(t, table1ColDate.IS_NOT_DISTINCT_FROM(dateVar), "(table1.col_date IS NOT DISTINCT FROM $1::date)", "2000-12-30") + jet.AssertPostgreClauseSerialize(t, table1ColDate.IS_NOT_DISTINCT_FROM(dateVar), "(table1.col_date IS NOT DISTINCT FROM $1::DATE)", "2000-12-30") } func TestDateExpressionGT(t *testing.T) { jet.AssertPostgreClauseSerialize(t, table1ColDate.GT(table2ColDate), "(table1.col_date > table2.col_date)") - jet.AssertPostgreClauseSerialize(t, table1ColDate.GT(dateVar), "(table1.col_date > $1::date)", "2000-12-30") + jet.AssertPostgreClauseSerialize(t, table1ColDate.GT(dateVar), "(table1.col_date > $1::DATE)", "2000-12-30") } func TestDateExpressionGT_EQ(t *testing.T) { jet.AssertPostgreClauseSerialize(t, table1ColDate.GT_EQ(table2ColDate), "(table1.col_date >= table2.col_date)") - jet.AssertPostgreClauseSerialize(t, table1ColDate.GT_EQ(dateVar), "(table1.col_date >= $1::date)", "2000-12-30") + jet.AssertPostgreClauseSerialize(t, table1ColDate.GT_EQ(dateVar), "(table1.col_date >= $1::DATE)", "2000-12-30") } func TestDateExpressionLT(t *testing.T) { jet.AssertPostgreClauseSerialize(t, table1ColDate.LT(table2ColDate), "(table1.col_date < table2.col_date)") - jet.AssertPostgreClauseSerialize(t, table1ColDate.LT(dateVar), "(table1.col_date < $1::date)", "2000-12-30") + jet.AssertPostgreClauseSerialize(t, table1ColDate.LT(dateVar), "(table1.col_date < $1::DATE)", "2000-12-30") } func TestDateExpressionLT_EQ(t *testing.T) { jet.AssertPostgreClauseSerialize(t, table1ColDate.LT_EQ(table2ColDate), "(table1.col_date <= table2.col_date)") - jet.AssertPostgreClauseSerialize(t, table1ColDate.LT_EQ(dateVar), "(table1.col_date <= $1::date)", "2000-12-30") + jet.AssertPostgreClauseSerialize(t, table1ColDate.LT_EQ(dateVar), "(table1.col_date <= $1::DATE)", "2000-12-30") } diff --git a/postgres/postgres_cast.go b/postgres/postgres_cast.go index ae83ed1..2074f89 100644 --- a/postgres/postgres_cast.go +++ b/postgres/postgres_cast.go @@ -6,6 +6,7 @@ import ( ) type cast interface { + jet.Cast // Cast expression AS bool type AS_BOOL() BoolExpression // Cast expression AS smallint type @@ -16,18 +17,14 @@ type cast interface { AS_BIGINT() IntegerExpression // Cast expression AS numeric type, using precision and optionally scale AS_NUMERIC(precision int, scale ...int) FloatExpression - // Cast expression AS numeric type, using precision and optionally scale - AS_DECIMAL() FloatExpression + // Cast expression AS real type AS_REAL() FloatExpression // Cast expression AS double precision type AS_DOUBLE() FloatExpression // Cast expression AS text type AS_TEXT() StringExpression - // Cast expression AS date type - AS_DATE() DateExpression - // Cast expression AS time type - AS_TIME() TimeExpression + // Cast expression AS time with time timezone type AS_TIMEZ() TimezExpression // Cast expression AS timestamp type @@ -37,33 +34,33 @@ type cast interface { } type castImpl struct { - jet.Cast + jet.CastImpl } func CAST(expr jet.Expression) cast { castImpl := &castImpl{} - castImpl.Cast = jet.NewCastImpl(expr) + castImpl.CastImpl = jet.NewCastImpl(expr) return castImpl } func (b *castImpl) AS_BOOL() BoolExpression { - return jet.BoolExp(b.As("boolean")) + return jet.BoolExp(b.AS("boolean")) } func (b *castImpl) AS_SMALLINT() IntegerExpression { - return jet.IntExp(b.As("smallint")) + return jet.IntExp(b.AS("smallint")) } // Cast expression AS integer type func (b *castImpl) AS_INTEGER() IntegerExpression { - return jet.IntExp(b.As("integer")) + return jet.IntExp(b.AS("integer")) } // Cast expression AS bigint type func (b *castImpl) AS_BIGINT() IntegerExpression { - return jet.IntExp(b.As("bigint")) + return jet.IntExp(b.AS("bigint")) } // Cast expression AS numeric type, using precision and optionally scale @@ -76,49 +73,40 @@ func (b *castImpl) AS_NUMERIC(precision int, scale ...int) FloatExpression { castType = fmt.Sprintf("numeric(%d)", precision) } - return jet.FloatExp(b.As(jet.CastType(castType))) -} - -func (b *castImpl) AS_DECIMAL() FloatExpression { - return jet.FloatExp(b.As("decimal")) + return jet.FloatExp(b.AS(castType)) } // Cast expression AS real type func (b *castImpl) AS_REAL() FloatExpression { - return jet.FloatExp(b.As("real")) + return jet.FloatExp(b.AS("real")) } // Cast expression AS double precision type func (b *castImpl) AS_DOUBLE() FloatExpression { - return jet.FloatExp(b.As("double precision")) + return jet.FloatExp(b.AS("double precision")) } // Cast expression AS text type func (b *castImpl) AS_TEXT() StringExpression { - return jet.StringExp(b.As("text")) + return jet.StringExp(b.AS("text")) } // Cast expression AS date type -func (b *castImpl) AS_DATE() DateExpression { - return jet.DateExp(b.As("date")) -} - -// Cast expression AS time type -func (b *castImpl) AS_TIME() TimeExpression { - return jet.TimeExp(b.As("time without time zone")) +func (b *castImpl) AS_TIME() jet.TimeExpression { + return TimeExp(b.AS("time without time zone")) } // Cast expression AS time with time timezone type func (b *castImpl) AS_TIMEZ() TimezExpression { - return jet.TimezExp(b.As("time with time zone")) + return jet.TimezExp(b.AS("time with time zone")) } // Cast expression AS timestamp type func (b *castImpl) AS_TIMESTAMP() TimestampExpression { - return jet.TimestampExp(b.As("timestamp without time zone")) + return jet.TimestampExp(b.AS("timestamp without time zone")) } // Cast expression AS timestamp with timezone type func (b *castImpl) AS_TIMESTAMPZ() TimestampzExpression { - return jet.TimestampzExp(b.As("timestamp with time zone")) + return jet.TimestampzExp(b.AS("timestamp with time zone")) } diff --git a/postgres/postgres_cast_test.go b/postgres/postgres_cast_test.go index 725fb8d..87a9e26 100644 --- a/postgres/postgres_cast_test.go +++ b/postgres/postgres_cast_test.go @@ -5,6 +5,10 @@ import ( "testing" ) +func TestExpressionCAST_AS(t *testing.T) { + jet.AssertPostgreClauseSerialize(t, CAST(String("test")).AS("text"), `$1::text`, "test") +} + func TestExpressionCAST_AS_BOOL(t *testing.T) { jet.AssertPostgreClauseSerialize(t, CAST(Int(1)).AS_BOOL(), "$1::boolean", int64(1)) jet.AssertPostgreClauseSerialize(t, CAST(table2Col3).AS_BOOL(), "table2.col3::boolean") @@ -41,7 +45,7 @@ func TestExpressionCAST_AS_TEXT(t *testing.T) { } func TestExpressionCAST_AS_DATE(t *testing.T) { - jet.AssertPostgreClauseSerialize(t, CAST(table2Col3).AS_DATE(), "table2.col3::date") + jet.AssertPostgreClauseSerialize(t, CAST(table2Col3).AS_DATE(), "table2.col3::DATE") } func TestExpressionCAST_AS_TIME(t *testing.T) { diff --git a/tests/mysql/cast_test.go b/tests/mysql/cast_test.go index 05d2fa7..bfc8630 100644 --- a/tests/mysql/cast_test.go +++ b/tests/mysql/cast_test.go @@ -12,31 +12,40 @@ import ( func TestCast(t *testing.T) { query := SELECT( + CAST(String("test")).AS("CHAR CHARACTER SET utf8").AS("result.AS1"), CAST(String("2011-02-02")).AS_DATE().AS("result.date"), CAST(String("14:06:10")).AS_TIME().AS("result.time"), CAST(String("2011-02-02 14:06:10")).AS_DATETIME().AS("result.datetime"), - CAST(Int(150)).AS_CHAR().AS("result.char"), + + CAST(Int(150)).AS_CHAR().AS("result.char1"), + CAST(Int(150)).AS_CHAR(30).AS("result.char2"), + CAST(Int(5).SUB(Int(10))).AS_SIGNED().AS("result.signed"), CAST(Int(5).ADD(Int(10))).AS_UNSIGNED().AS("result.unsigned"), CAST(String("Some text")).AS_BINARY().AS("result.binary"), ).FROM(AllTypes) testutils.AssertStatementSql(t, query, ` -SELECT CAST(? AS DATE) AS "result.date", +SELECT CAST(? AS CHAR CHARACTER SET utf8) AS "result.AS1", + CAST(? AS DATE) AS "result.date", CAST(? AS TIME) AS "result.time", CAST(? AS DATETIME) AS "result.datetime", - CAST(? AS CHAR) AS "result.char", + CAST(? AS CHAR) AS "result.char1", + CAST(? AS CHAR(30)) AS "result.char2", CAST((? - ?) AS SIGNED) AS "result.signed", CAST((? + ?) AS UNSIGNED) AS "result.unsigned", CAST(? AS BINARY) AS "result.binary" FROM test_sample.all_types; -`, "2011-02-02", "14:06:10", "2011-02-02 14:06:10", int64(150), int64(5), int64(10), int64(5), int64(10), "Some text") +`, "test", "2011-02-02", "14:06:10", "2011-02-02 14:06:10", int64(150), int64(150), int64(5), + int64(10), int64(5), int64(10), "Some text") type Result struct { + As1 string Date time.Time Time time.Time DateTime time.Time - Char string + Char1 string + Char2 string Signed int Unsigned int Binary string @@ -49,10 +58,12 @@ FROM test_sample.all_types; assert.NilError(t, err) assert.DeepEqual(t, dest, Result{ + As1: "test", Date: *testutils.Date("2011-02-02"), Time: *testutils.TimeWithoutTimeZone("14:06:10"), DateTime: *testutils.TimestampWithoutTimeZone("2011-02-02 14:06:10", 0), - Char: "150", + Char1: "150", + Char2: "150", Signed: -5, Unsigned: 15, Binary: "Some text", diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index bc17696..558a5d0 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -65,6 +65,7 @@ func TestExpressionOperators(t *testing.T) { AllTypes.SmallIntPtr.NOT_IN(Int(11), Int(22), NULL), AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.IntegerPtr)), + postgres.CAST(Int(150)).AS_CHAR(12), postgres.CAST(String("TRUE")).AS_BOOL(), postgres.CAST(String("111")).AS_SMALLINT(), postgres.CAST(String("111")).AS_INTEGER(), @@ -320,7 +321,7 @@ SELECT (all_types.numeric = all_types.numeric) AS "eq1", TRUNC(ABS(all_types.decimal), $29) AS "abs", TRUNC(POWER(all_types.decimal, $30), $31) AS "power", TRUNC(SQRT(all_types.decimal), $32) AS "sqrt", - TRUNC(CBRT(all_types.decimal)::decimal, $33) AS "cbrt", + TRUNC(CBRT(all_types.decimal)::DECIMAL, $33) AS "cbrt", CEIL(all_types.real) AS "ceil", FLOOR(all_types.real) AS "floor", ROUND(all_types.decimal) AS "round1", diff --git a/visitor.go b/visitor.go index 6c9dc84..6532655 100644 --- a/visitor.go +++ b/visitor.go @@ -27,7 +27,7 @@ func newDialectFinder() *DialectFinder { } } -func (f *DialectFinder) dialect() Dialect { +func (f *DialectFinder) mustGetDialect() Dialect { if len(f.dialects) == 0 { panic("jet: can't detect dialect") } @@ -59,5 +59,5 @@ func detectDialect(element acceptsVisitor, dialectOverride ...Dialect) Dialect { dialectFinder := newDialectFinder() element.accept(dialectFinder) - return dialectFinder.dialect() + return dialectFinder.mustGetDialect() }