diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index 71720d1..acf03d9 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -14,7 +14,7 @@ type Dialect interface { // SerializerFunc func type SerializerFunc func(statement StatementType, out *SQLBuilder, options ...SerializeOption) -//// SerializeOverride func +// SerializeOverride func type SerializeOverride func(expressions ...Serializer) SerializerFunc // QueryPlaceholderFunc func diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 379bd70..26b9186 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -8,25 +8,26 @@ type Expression interface { GroupByClause OrderByClause - // Test expression whether it is a NULL value. + // IS_NULL tests expression whether it is a NULL value. IS_NULL() BoolExpression - // Test expression whether it is a non-NULL value. + // IS_NOT_NULL tests expression whether it is a non-NULL value. IS_NOT_NULL() BoolExpression - // Check if this expressions matches any in expressions list + // IN checks if this expressions matches any in expressions list IN(expressions ...Expression) BoolExpression - // Check if this expressions is different of all expressions in expressions list + // NOT_IN checks if this expressions is different of all expressions in expressions list NOT_IN(expressions ...Expression) BoolExpression - // The temporary alias name to assign to the expression + // AS the temporary alias name to assign to the expression AS(alias string) Projection - // Expression will be used to sort query result in ascending order + // ASC expression will be used to sort query result in ascending order ASC() OrderByClause - // Expression will be used to sort query result in ascending order + // DESC expression will be used to sort query result in ascending order DESC() OrderByClause } +// ExpressionInterfaceImpl implements Expression interface methods type ExpressionInterfaceImpl struct { Parent Expression } @@ -35,30 +36,37 @@ func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection { return e.Parent } +// IS_NULL tests expression whether it is a NULL value. func (e *ExpressionInterfaceImpl) IS_NULL() BoolExpression { return newPostfixBoolOperatorExpression(e.Parent, "IS NULL") } +// IS_NOT_NULL tests expression whether it is a non-NULL value. func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression { return newPostfixBoolOperatorExpression(e.Parent, "IS NOT NULL") } +// IN checks if this expressions matches any in expressions list func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression { return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "IN") } +// NOT_IN checks if this expressions is different of all expressions in expressions list func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression { return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "NOT IN") } +// AS the temporary alias name to assign to the expression func (e *ExpressionInterfaceImpl) AS(alias string) Projection { return newAlias(e.Parent, alias) } +// ASC expression will be used to sort query result in ascending order func (e *ExpressionInterfaceImpl) ASC() OrderByClause { return newOrderByClause(e.Parent, true) } +// DESC expression will be used to sort query result in ascending order func (e *ExpressionInterfaceImpl) DESC() OrderByClause { return newOrderByClause(e.Parent, false) } diff --git a/internal/jet/interval.go b/internal/jet/interval.go index 705f2be..e66ca56 100644 --- a/internal/jet/interval.go +++ b/internal/jet/interval.go @@ -1,14 +1,17 @@ package jet +// Interval is internal common representation of sql interval type Interval interface { Serializer IsInterval } +// IsInterval interface type IsInterval interface { isInterval() } +// NewInterval creates new interval from serializer func NewInterval(s Serializer) Interval { newInterval := &intervalImpl{ interval: s, diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 6983b8f..499b7b4 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -334,20 +334,20 @@ func WRAP(expression ...Expression) Expression { //---------------------------------------------------// -type RawExpression struct { +type rawExpression struct { ExpressionInterfaceImpl Raw string } -func (n *RawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (n *rawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString(n.Raw) } // Raw can be used for any unsupported functions, operators or expressions. // For example: Raw("current_database()") func Raw(raw string) Expression { - rawExp := &RawExpression{Raw: raw} + rawExp := &rawExpression{Raw: raw} rawExp.ExpressionInterfaceImpl.Parent = rawExp return rawExp diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index d5ff4b9..dc661d7 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -42,6 +42,7 @@ func contains(options []SerializeOption, option SerializeOption) bool { return false } +// ListSerializer serializes list of serializers with separator type ListSerializer struct { Serializers []Serializer Separator string diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 3b34ab2..95bd0b6 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -142,12 +142,8 @@ func argToString(value interface{}) string { return "TRUE" } return "FALSE" - case int: - return strconv.FormatInt(int64(bindVal), 10) - case int32: - return strconv.FormatInt(int64(bindVal), 10) - case int64: - return strconv.FormatInt(bindVal, 10) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return integerTypesToString(bindVal) case float32: return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) @@ -167,6 +163,33 @@ func argToString(value interface{}) string { } } +func integerTypesToString(value interface{}) string { + switch bindVal := value.(type) { + case bool: + case int: + return strconv.FormatInt(int64(bindVal), 10) + case uint: + return strconv.FormatUint(uint64(bindVal), 10) + case int8: + return strconv.FormatInt(int64(bindVal), 10) + case uint8: + return strconv.FormatUint(uint64(bindVal), 10) + case int16: + return strconv.FormatInt(int64(bindVal), 10) + case uint16: + return strconv.FormatUint(uint64(bindVal), 10) + case int32: + return strconv.FormatInt(int64(bindVal), 10) + case uint32: + return strconv.FormatUint(uint64(bindVal), 10) + case int64: + return strconv.FormatInt(bindVal, 10) + case uint64: + return strconv.FormatUint(bindVal, 10) + } + panic("jet: Unsupported integer type: " + reflect.TypeOf(value).String()) +} + func shouldQuoteIdentifier(identifier string) bool { for _, c := range identifier { if unicode.IsNumber(c) || c == '_' { diff --git a/internal/jet/serializer_test.go b/internal/jet/sql_builder_test.go similarity index 70% rename from internal/jet/serializer_test.go rename to internal/jet/sql_builder_test.go index 6d2fd4a..dc4a476 100644 --- a/internal/jet/serializer_test.go +++ b/internal/jet/sql_builder_test.go @@ -12,8 +12,16 @@ func TestArgToString(t *testing.T) { assert.Equal(t, argToString(false), "FALSE") assert.Equal(t, argToString(int(-32)), "-32") - assert.Equal(t, argToString(int32(-32)), "-32") + assert.Equal(t, argToString(uint(32)), "32") + assert.Equal(t, argToString(int8(-43)), "-43") + assert.Equal(t, argToString(uint8(43)), "43") + assert.Equal(t, argToString(int16(-54)), "-54") + assert.Equal(t, argToString(uint16(54)), "54") + assert.Equal(t, argToString(int32(-65)), "-65") + assert.Equal(t, argToString(uint32(65)), "65") assert.Equal(t, argToString(int64(-64)), "-64") + assert.Equal(t, argToString(uint64(64)), "64") + assert.Equal(t, argToString(float32(2.0)), "2") assert.Equal(t, argToString(float64(1.11)), "1.11") assert.Equal(t, argToString("john"), "'john'") diff --git a/internal/jet/utils.go b/internal/jet/utils.go index 67c2c6c..58394f4 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -63,6 +63,7 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) { } } +// ExpressionListToSerializerList converts list of expressions to list of serializers func ExpressionListToSerializerList(expressions []Expression) []Serializer { var ret []Serializer diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 1b19a20..c3d3ff0 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -129,7 +129,7 @@ func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Seriali } } -// AssertClauseSerialize checks if clause serialize produces expected query and args +// AssertDebugClauseSerialize checks if clause serialize produces expected debug query and args func AssertDebugClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { out := jet.SQLBuilder{Dialect: dialect, Debug: true} jet.Serialize(clause, jet.SelectStatementType, &out) diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 97ac48a..42a5c36 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -184,6 +184,7 @@ func StringSliceContains(strings []string, contains string) bool { return false } +// ExtractDateTimeComponents extracts number of days, hours, minutes, seconds, microseconds from duration func ExtractDateTimeComponents(duration time.Duration) (days, hours, minutes, seconds, microseconds int64) { days = int64(duration / (24 * time.Hour)) reminder := duration % (24 * time.Hour) diff --git a/mysql/interval.go b/mysql/interval.go index 57c2cfc..478e9c4 100644 --- a/mysql/interval.go +++ b/mysql/interval.go @@ -9,10 +9,11 @@ import ( "github.com/go-jet/jet/internal/utils" ) -type UnitType string +type unitType string +// List of interval unit types for MySQL const ( - MICROSECOND UnitType = "MICROSECOND" + MICROSECOND unitType = "MICROSECOND" SECOND = "SECOND" MINUTE = "MINUTE" HOUR = "HOUR" @@ -34,9 +35,15 @@ const ( YEAR_MONTH = "YEAR_MONTH" ) +// Interval is representation of MySQL interval type Interval = jet.Interval -func INTERVAL(value interface{}, unitType UnitType) Interval { +// INTERVAL creates new Interval type. +// In a case of MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR unit type +// value parameter should be number. For example: INTERVAL(1, DAY) +// In a case of other unit types, value should be string with appropriate format. +// For example: INTERVAL("10:08:50", HOUR_SECOND) +func INTERVAL(value interface{}, unitType unitType) Interval { switch unitType { case MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR: if !isNumericType(value) { @@ -87,14 +94,15 @@ func INTERVAL(value interface{}, unitType UnitType) Interval { } } -func INTERVALe(expr Expression, unitType UnitType) Interval { +// INTERVALe creates new Interval type from expresion and unit type. +func INTERVALe(expr Expression, unitType unitType) Interval { return jet.NewInterval(jet.ListSerializer{ Serializers: []jet.Serializer{expr, jet.Raw(string(unitType))}, Separator: " ", }) } -// INTERVALd returns a representation of duration as MySQL INTERVAL +// INTERVALd returns a interval representation from duration func INTERVALd(duration time.Duration) Interval { var sign int64 = 1 if duration < 0 { @@ -179,7 +187,7 @@ var ( func isNumericType(value interface{}) bool { switch value.(type) { - case float64, float32, int16, int32, int64, uint16, uint32, uint64, int, uint: + case float64, float32, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: return true default: return false diff --git a/mysql/interval_test.go b/mysql/interval_test.go index a7c16f8..c88b808 100644 --- a/mysql/interval_test.go +++ b/mysql/interval_test.go @@ -32,14 +32,18 @@ func TestINTERVAL(t *testing.T) { assertDebugSerialize(t, INTERVAL("08.000100", SECOND_MICROSECOND), "INTERVAL '08.000100' SECOND_MICROSECOND") assertDebugSerialize(t, INTERVAL("-08.000100", SECOND_MICROSECOND), "INTERVAL '-08.000100' SECOND_MICROSECOND") - assertDebugSerialize(t, INTERVAL(15, SECOND), "INTERVAL 15 SECOND") - assertDebugSerialize(t, INTERVAL(1, MICROSECOND), "INTERVAL 1 MICROSECOND") - assertDebugSerialize(t, INTERVAL(2, MINUTE), "INTERVAL 2 MINUTE") - assertDebugSerialize(t, INTERVAL(3, HOUR), "INTERVAL 3 HOUR") - assertDebugSerialize(t, INTERVAL(4, DAY), "INTERVAL 4 DAY") - assertDebugSerialize(t, INTERVAL(5, MONTH), "INTERVAL 5 MONTH") - assertDebugSerialize(t, INTERVAL(6, YEAR), "INTERVAL 6 YEAR") - assertDebugSerialize(t, INTERVAL(-6, YEAR), "INTERVAL -6 YEAR") + assertSerialize(t, INTERVAL(15, SECOND), "INTERVAL 15 SECOND") + assertSerialize(t, INTERVAL(1, MICROSECOND), "INTERVAL 1 MICROSECOND") + assertSerialize(t, INTERVAL(2, MINUTE), "INTERVAL 2 MINUTE") + assertSerialize(t, INTERVAL(3, HOUR), "INTERVAL 3 HOUR") + assertSerialize(t, INTERVAL(4, DAY), "INTERVAL 4 DAY") + assertSerialize(t, INTERVAL(5, MONTH), "INTERVAL 5 MONTH") + assertSerialize(t, INTERVAL(6, YEAR), "INTERVAL 6 YEAR") + assertSerialize(t, INTERVAL(-6, YEAR), "INTERVAL -6 YEAR") + + assertSerialize(t, INTERVAL(uint(6), YEAR), "INTERVAL 6 YEAR") + assertSerialize(t, INTERVAL(int16(7), YEAR), "INTERVAL 7 YEAR") + assertSerialize(t, INTERVAL(3.5, YEAR), "INTERVAL 3.5 YEAR") } func TestINTERVAL_InvalidUnitType(t *testing.T) { diff --git a/postgres/interval.go b/postgres/interval.go index f6344fd..de659e7 100644 --- a/postgres/interval.go +++ b/postgres/interval.go @@ -9,12 +9,11 @@ import ( "time" ) -type quantityAndUnit float64 +type quantityAndUnit = float64 +// Interval unit types const ( - pow2_32 = -4.294967296e+09 - - YEAR quantityAndUnit = pow2_32 + iota + YEAR quantityAndUnit = 123456789 + iota MONTH WEEK DAY @@ -33,11 +32,14 @@ type intervalExpressionImpl struct { jet.ExpressionInterfaceImpl } +// IntervalExpression is representation of postgres INTERVAL type IntervalExpression interface { jet.IsInterval jet.Expression } +// INTERVAL creates new interval expression from the list of quantity-unit pairs. +// For example: INTERVAL(1, DAY, 3, MINUTE) func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { if len(quantityAndUnit)%2 != 0 { panic("jet: invalid number of quantity and unit fields") @@ -62,6 +64,7 @@ func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { return newInterval } +// INTERVALd creates interval expression from duration func INTERVALd(duration time.Duration) IntervalExpression { days, hours, minutes, seconds, microseconds := utils.ExtractDateTimeComponents(duration) diff --git a/postgres/interval_test.go b/postgres/interval_test.go index 1d1e3de..785f1d5 100644 --- a/postgres/interval_test.go +++ b/postgres/interval_test.go @@ -25,6 +25,9 @@ func TestINTERVAL(t *testing.T) { assertSerialize(t, INTERVAL(1, YEAR).IS_NOT_NULL(), "INTERVAL '1 YEAR' IS NOT NULL") assertProjectionSerialize(t, INTERVAL(1, YEAR).AS("one year"), `INTERVAL '1 YEAR' AS "one year"`) + + f := 5.2 + assertSerialize(t, INTERVAL(f, YEAR), "INTERVAL '5.2 YEAR'") } func TestINTERVALd(t *testing.T) {