diff --git a/internal/jet/bool_expression.go b/internal/jet/bool_expression.go index 5bdda95..1a05ab6 100644 --- a/internal/jet/bool_expression.go +++ b/internal/jet/bool_expression.go @@ -53,93 +53,53 @@ func (b *boolInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs BoolExpression) BoolExpress } func (b *boolInterfaceImpl) AND(expression BoolExpression) BoolExpression { - return newBinaryBoolOperator(b.parent, expression, "AND") + return newBinaryBoolOperatorExpression(b.parent, expression, "AND") } func (b *boolInterfaceImpl) OR(expression BoolExpression) BoolExpression { - return newBinaryBoolOperator(b.parent, expression, "OR") + return newBinaryBoolOperatorExpression(b.parent, expression, "OR") } func (b *boolInterfaceImpl) IS_TRUE() BoolExpression { - return newPostifxBoolExpression(b.parent, "IS TRUE") + return newPostfixBoolOperatorExpression(b.parent, "IS TRUE") } func (b *boolInterfaceImpl) IS_NOT_TRUE() BoolExpression { - return newPostifxBoolExpression(b.parent, "IS NOT TRUE") + return newPostfixBoolOperatorExpression(b.parent, "IS NOT TRUE") } func (b *boolInterfaceImpl) IS_FALSE() BoolExpression { - return newPostifxBoolExpression(b.parent, "IS FALSE") + return newPostfixBoolOperatorExpression(b.parent, "IS FALSE") } func (b *boolInterfaceImpl) IS_NOT_FALSE() BoolExpression { - return newPostifxBoolExpression(b.parent, "IS NOT FALSE") + return newPostfixBoolOperatorExpression(b.parent, "IS NOT FALSE") } func (b *boolInterfaceImpl) IS_UNKNOWN() BoolExpression { - return newPostifxBoolExpression(b.parent, "IS UNKNOWN") + return newPostfixBoolOperatorExpression(b.parent, "IS UNKNOWN") } func (b *boolInterfaceImpl) IS_NOT_UNKNOWN() BoolExpression { - return newPostifxBoolExpression(b.parent, "IS NOT UNKNOWN") + return newPostfixBoolOperatorExpression(b.parent, "IS NOT UNKNOWN") } //---------------------------------------------------// -type binaryBoolExpression struct { - expressionInterfaceImpl - boolInterfaceImpl - - binaryOpExpression -} - -func newBinaryBoolOperator(lhs, rhs Expression, operator string, additionalParams ...Expression) BoolExpression { - binaryBoolExpression := binaryBoolExpression{} - - binaryBoolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator, additionalParams...) - binaryBoolExpression.expressionInterfaceImpl.Parent = &binaryBoolExpression - binaryBoolExpression.boolInterfaceImpl.parent = &binaryBoolExpression - - return &binaryBoolExpression +func newBinaryBoolOperatorExpression(lhs, rhs Expression, operator string, additionalParams ...Expression) BoolExpression { + return BoolExp(newBinaryOperatorExpression(lhs, rhs, operator, additionalParams...)) } //---------------------------------------------------// -type prefixBoolExpression struct { - expressionInterfaceImpl - boolInterfaceImpl - - prefixOpExpression -} - -func newPrefixBoolOperator(expression Expression, operator string) BoolExpression { - exp := prefixBoolExpression{} - exp.prefixOpExpression = newPrefixExpression(expression, operator) - - exp.expressionInterfaceImpl.Parent = &exp - exp.boolInterfaceImpl.parent = &exp - - return &exp +func newPrefixBoolOperatorExpression(expression Expression, operator string) BoolExpression { + return BoolExp(newPrefixOperatorExpression(expression, operator)) } //---------------------------------------------------// -type postfixBoolOpExpression struct { - expressionInterfaceImpl - boolInterfaceImpl - - postfixOpExpression -} - -func newPostifxBoolExpression(expression Expression, operator string) BoolExpression { - exp := postfixBoolOpExpression{} - exp.postfixOpExpression = newPostfixOpExpression(expression, operator) - - exp.expressionInterfaceImpl.Parent = &exp - exp.boolInterfaceImpl.parent = &exp - - return &exp +func newPostfixBoolOperatorExpression(expression Expression, operator string) BoolExpression { + return BoolExp(newPostfixOperatorExpression(expression, operator)) } //---------------------------------------------------// - type boolExpressionWrapper struct { boolInterfaceImpl Expression diff --git a/internal/jet/cast.go b/internal/jet/cast.go index 84e1962..c5fe9a7 100644 --- a/internal/jet/cast.go +++ b/internal/jet/cast.go @@ -24,13 +24,13 @@ func (b *castImpl) AS(castType string) Expression { cast: string(castType), } - castExp.expressionInterfaceImpl.Parent = castExp + castExp.ExpressionInterfaceImpl.Parent = castExp return castExp } type castExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl expression Expression cast string diff --git a/internal/jet/column.go b/internal/jet/column.go index d1422d4..c7a5f41 100644 --- a/internal/jet/column.go +++ b/internal/jet/column.go @@ -20,7 +20,7 @@ type ColumnExpression interface { // The base type for real materialized columns. type columnImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl name string tableName string @@ -34,7 +34,7 @@ func newColumn(name string, tableName string, parent ColumnExpression) columnImp tableName: tableName, } - bc.expressionInterfaceImpl.Parent = parent + bc.ExpressionInterfaceImpl.Parent = parent return bc } diff --git a/internal/jet/column_test.go b/internal/jet/column_test.go index 2159c68..ca3f5f6 100644 --- a/internal/jet/column_test.go +++ b/internal/jet/column_test.go @@ -4,7 +4,7 @@ import "testing" func TestColumn(t *testing.T) { column := newColumn("col", "", nil) - column.expressionInterfaceImpl.Parent = &column + column.ExpressionInterfaceImpl.Parent = &column assertClauseSerialize(t, column, "col") column.setTableName("table1") diff --git a/internal/jet/date_expression.go b/internal/jet/date_expression.go index 357e8a5..8b0a524 100644 --- a/internal/jet/date_expression.go +++ b/internal/jet/date_expression.go @@ -13,42 +13,53 @@ type DateExpression interface { LT_EQ(rhs DateExpression) BoolExpression GT(rhs DateExpression) BoolExpression GT_EQ(rhs DateExpression) BoolExpression + + ADD(rhs Interval) TimestampExpression + SUB(rhs Interval) TimestampExpression } type dateInterfaceImpl struct { parent DateExpression } -func (t *dateInterfaceImpl) EQ(rhs DateExpression) BoolExpression { - return eq(t.parent, rhs) +func (d *dateInterfaceImpl) EQ(rhs DateExpression) BoolExpression { + return eq(d.parent, rhs) } -func (t *dateInterfaceImpl) NOT_EQ(rhs DateExpression) BoolExpression { - return notEq(t.parent, rhs) +func (d *dateInterfaceImpl) NOT_EQ(rhs DateExpression) BoolExpression { + return notEq(d.parent, rhs) } -func (t *dateInterfaceImpl) IS_DISTINCT_FROM(rhs DateExpression) BoolExpression { - return isDistinctFrom(t.parent, rhs) +func (d *dateInterfaceImpl) IS_DISTINCT_FROM(rhs DateExpression) BoolExpression { + return isDistinctFrom(d.parent, rhs) } -func (t *dateInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs DateExpression) BoolExpression { - return isNotDistinctFrom(t.parent, rhs) +func (d *dateInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs DateExpression) BoolExpression { + return isNotDistinctFrom(d.parent, rhs) } -func (t *dateInterfaceImpl) LT(rhs DateExpression) BoolExpression { - return lt(t.parent, rhs) +func (d *dateInterfaceImpl) LT(rhs DateExpression) BoolExpression { + return lt(d.parent, rhs) } -func (t *dateInterfaceImpl) LT_EQ(rhs DateExpression) BoolExpression { - return ltEq(t.parent, rhs) +func (d *dateInterfaceImpl) LT_EQ(rhs DateExpression) BoolExpression { + return ltEq(d.parent, rhs) } -func (t *dateInterfaceImpl) GT(rhs DateExpression) BoolExpression { - return gt(t.parent, rhs) +func (d *dateInterfaceImpl) GT(rhs DateExpression) BoolExpression { + return gt(d.parent, rhs) } -func (t *dateInterfaceImpl) GT_EQ(rhs DateExpression) BoolExpression { - return gtEq(t.parent, rhs) +func (d *dateInterfaceImpl) GT_EQ(rhs DateExpression) BoolExpression { + return gtEq(d.parent, rhs) +} + +func (d *dateInterfaceImpl) ADD(rhs Interval) TimestampExpression { + return TimestampExp(newBinaryOperatorExpression(d.parent, rhs, "+")) +} + +func (d *dateInterfaceImpl) SUB(rhs Interval) TimestampExpression { + return TimestampExp(newBinaryOperatorExpression(d.parent, rhs, "-")) } //---------------------------------------------------// diff --git a/internal/jet/date_expression_test.go b/internal/jet/date_expression_test.go new file mode 100644 index 0000000..14fdd76 --- /dev/null +++ b/internal/jet/date_expression_test.go @@ -0,0 +1,13 @@ +package jet + +import ( + "testing" +) + +func TestDateArithmetic(t *testing.T) { + timestamp := Timestamp(2000, 1, 1, 0, 0, 0) + assertClauseDebugSerialize(t, table1ColDate.ADD(NewInterval(String("1 HOUR"))).EQ(timestamp), + "((table1.col_date + INTERVAL '1 HOUR') = '2000-01-01 00:00:00')") + assertClauseDebugSerialize(t, table1ColDate.SUB(NewInterval(String("1 HOUR"))).EQ(timestamp), + "((table1.col_date - INTERVAL '1 HOUR') = '2000-01-01 00:00:00')") +} diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index e46000a..acf03d9 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -11,11 +11,11 @@ type Dialect interface { ArgumentPlaceholder() QueryPlaceholderFunc } -// SerializeFunc func -type SerializeFunc func(statement StatementType, out *SQLBuilder, options ...SerializeOption) +// SerializerFunc func +type SerializerFunc func(statement StatementType, out *SQLBuilder, options ...SerializeOption) // SerializeOverride func -type SerializeOverride func(expressions ...Expression) SerializeFunc +type SerializeOverride func(expressions ...Serializer) SerializerFunc // QueryPlaceholderFunc func type QueryPlaceholderFunc func(ord int) string diff --git a/internal/jet/enum_value.go b/internal/jet/enum_value.go index 5bd609d..17e8c74 100644 --- a/internal/jet/enum_value.go +++ b/internal/jet/enum_value.go @@ -1,7 +1,7 @@ package jet type enumValue struct { - expressionInterfaceImpl + ExpressionInterfaceImpl stringInterfaceImpl name string @@ -11,7 +11,7 @@ type enumValue struct { func NewEnumValue(name string) StringExpression { enumValue := &enumValue{name: name} - enumValue.expressionInterfaceImpl.Parent = enumValue + enumValue.ExpressionInterfaceImpl.Parent = enumValue enumValue.stringInterfaceImpl.parent = enumValue return enumValue diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 4ded89f..26b9186 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -8,82 +8,92 @@ type Expression interface { GroupByClause OrderByClause - // Test expression whether it is a NULL value. + // IS_NULL tests expression whether it is a NULL value. IS_NULL() BoolExpression - // Test expression whether it is a non-NULL value. + // IS_NOT_NULL tests expression whether it is a non-NULL value. IS_NOT_NULL() BoolExpression - // Check if this expressions matches any in expressions list + // IN checks if this expressions matches any in expressions list IN(expressions ...Expression) BoolExpression - // Check if this expressions is different of all expressions in expressions list + // NOT_IN checks if this expressions is different of all expressions in expressions list NOT_IN(expressions ...Expression) BoolExpression - // The temporary alias name to assign to the expression + // AS the temporary alias name to assign to the expression AS(alias string) Projection - // Expression will be used to sort query result in ascending order + // ASC expression will be used to sort query result in ascending order ASC() OrderByClause - // Expression will be used to sort query result in ascending order + // DESC expression will be used to sort query result in ascending order DESC() OrderByClause } -type expressionInterfaceImpl struct { +// ExpressionInterfaceImpl implements Expression interface methods +type ExpressionInterfaceImpl struct { Parent Expression } -func (e *expressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection { +func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection { return e.Parent } -func (e *expressionInterfaceImpl) IS_NULL() BoolExpression { - return newPostifxBoolExpression(e.Parent, "IS NULL") +// IS_NULL tests expression whether it is a NULL value. +func (e *ExpressionInterfaceImpl) IS_NULL() BoolExpression { + return newPostfixBoolOperatorExpression(e.Parent, "IS NULL") } -func (e *expressionInterfaceImpl) IS_NOT_NULL() BoolExpression { - return newPostifxBoolExpression(e.Parent, "IS NOT NULL") +// IS_NOT_NULL tests expression whether it is a non-NULL value. +func (e *ExpressionInterfaceImpl) IS_NOT_NULL() BoolExpression { + return newPostfixBoolOperatorExpression(e.Parent, "IS NOT NULL") } -func (e *expressionInterfaceImpl) IN(expressions ...Expression) BoolExpression { - return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "IN") +// IN checks if this expressions matches any in expressions list +func (e *ExpressionInterfaceImpl) IN(expressions ...Expression) BoolExpression { + return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "IN") } -func (e *expressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression { - return newBinaryBoolOperator(e.Parent, WRAP(expressions...), "NOT IN") +// NOT_IN checks if this expressions is different of all expressions in expressions list +func (e *ExpressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpression { + return newBinaryBoolOperatorExpression(e.Parent, WRAP(expressions...), "NOT IN") } -func (e *expressionInterfaceImpl) AS(alias string) Projection { +// AS the temporary alias name to assign to the expression +func (e *ExpressionInterfaceImpl) AS(alias string) Projection { return newAlias(e.Parent, alias) } -func (e *expressionInterfaceImpl) ASC() OrderByClause { +// ASC expression will be used to sort query result in ascending order +func (e *ExpressionInterfaceImpl) ASC() OrderByClause { return newOrderByClause(e.Parent, true) } -func (e *expressionInterfaceImpl) DESC() OrderByClause { +// DESC expression will be used to sort query result in ascending order +func (e *ExpressionInterfaceImpl) DESC() OrderByClause { return newOrderByClause(e.Parent, false) } -func (e *expressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SQLBuilder) { +func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SQLBuilder) { e.Parent.serialize(statement, out, noWrap) } -func (e *expressionInterfaceImpl) serializeForProjection(statement StatementType, out *SQLBuilder) { +func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType, out *SQLBuilder) { e.Parent.serialize(statement, out, noWrap) } -func (e *expressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { +func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SQLBuilder) { e.Parent.serialize(statement, out, noWrap) } // Representation of binary operations (e.g. comparisons, arithmetic) -type binaryOpExpression struct { - lhs, rhs Expression - additionalParam Expression +type binaryOperatorExpression struct { + ExpressionInterfaceImpl + + lhs, rhs Serializer + additionalParam Serializer operator string } -func newBinaryExpression(lhs, rhs Expression, operator string, additionalParam ...Expression) binaryOpExpression { - binaryExpression := binaryOpExpression{ +func newBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) *binaryOperatorExpression { + binaryExpression := &binaryOperatorExpression{ lhs: lhs, rhs: rhs, operator: operator, @@ -93,10 +103,12 @@ func newBinaryExpression(lhs, rhs Expression, operator string, additionalParam . binaryExpression.additionalParam = additionalParam[0] } + binaryExpression.ExpressionInterfaceImpl.Parent = binaryExpression + return binaryExpression } -func (c *binaryOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { if c.lhs == nil { panic("jet: lhs is nil for '" + c.operator + "' operator") } @@ -125,21 +137,24 @@ func (c *binaryOpExpression) serialize(statement StatementType, out *SQLBuilder, } // A prefix operator Expression -type prefixOpExpression struct { +type prefixExpression struct { + ExpressionInterfaceImpl + expression Expression operator string } -func newPrefixExpression(expression Expression, operator string) prefixOpExpression { - prefixExpression := prefixOpExpression{ +func newPrefixOperatorExpression(expression Expression, operator string) *prefixExpression { + prefixExpression := &prefixExpression{ expression: expression, operator: operator, } + prefixExpression.ExpressionInterfaceImpl.Parent = prefixExpression return prefixExpression } -func (p *prefixOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (p *prefixExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString("(") out.WriteString(p.operator) @@ -152,18 +167,22 @@ func (p *prefixOpExpression) serialize(statement StatementType, out *SQLBuilder, out.WriteString(")") } -// A postifx operator Expression +// A postfix operator Expression type postfixOpExpression struct { + ExpressionInterfaceImpl + expression Expression operator string } -func newPostfixOpExpression(expression Expression, operator string) postfixOpExpression { - postfixOpExpression := postfixOpExpression{ +func newPostfixOperatorExpression(expression Expression, operator string) *postfixOpExpression { + postfixOpExpression := &postfixOpExpression{ expression: expression, operator: operator, } + postfixOpExpression.ExpressionInterfaceImpl.Parent = postfixOpExpression + return postfixOpExpression } diff --git a/internal/jet/float_expression.go b/internal/jet/float_expression.go index c2ec535..aa821ba 100644 --- a/internal/jet/float_expression.go +++ b/internal/jet/float_expression.go @@ -85,22 +85,8 @@ func (n *floatInterfaceImpl) POW(expression NumericExpression) FloatExpression { } //---------------------------------------------------// -type binaryFloatExpression struct { - expressionInterfaceImpl - floatInterfaceImpl - - binaryOpExpression -} - func newBinaryFloatExpression(lhs, rhs Expression, operator string) FloatExpression { - floatExpression := binaryFloatExpression{} - - floatExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) - - floatExpression.expressionInterfaceImpl.Parent = &floatExpression - floatExpression.floatInterfaceImpl.parent = &floatExpression - - return &floatExpression + return FloatExp(newBinaryOperatorExpression(lhs, rhs, operator)) } //---------------------------------------------------// diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 3b334c6..f38c9a2 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -578,7 +578,7 @@ func LEAST(value Expression, values ...Expression) Expression { //--------------------------------------------------------------------// type funcExpressionImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl name string expressions []Expression @@ -592,9 +592,9 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr } if parent != nil { - funcExp.expressionInterfaceImpl.Parent = parent + funcExp.ExpressionInterfaceImpl.Parent = parent } else { - funcExp.expressionInterfaceImpl.Parent = funcExp + funcExp.ExpressionInterfaceImpl.Parent = funcExp } return funcExp @@ -605,14 +605,14 @@ func newWindowFunc(name string, expressions ...Expression) windowExpression { newFun := newFunc(name, expressions, nil) windowExpr := newWindowExpression(newFun) - newFun.expressionInterfaceImpl.Parent = windowExpr + newFun.ExpressionInterfaceImpl.Parent = windowExpr return windowExpr } func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil { - serializeOverrideFunc := serializeOverride(f.expressions...) + serializeOverrideFunc := serializeOverride(ExpressionListToSerializerList(f.expressions)...) serializeOverrideFunc(statement, out, options...) return } @@ -642,7 +642,7 @@ func newBoolFunc(name string, expressions ...Expression) BoolExpression { boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) boolFunc.boolInterfaceImpl.parent = boolFunc - boolFunc.expressionInterfaceImpl.Parent = boolFunc + boolFunc.ExpressionInterfaceImpl.Parent = boolFunc return boolFunc } @@ -654,7 +654,7 @@ func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpress boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) intWindowFunc := newBoolWindowExpression(boolFunc) boolFunc.boolInterfaceImpl.parent = intWindowFunc - boolFunc.expressionInterfaceImpl.Parent = intWindowFunc + boolFunc.ExpressionInterfaceImpl.Parent = intWindowFunc return intWindowFunc } @@ -681,7 +681,7 @@ func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpre floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) floatWindowFunc := newFloatWindowExpression(floatFunc) floatFunc.floatInterfaceImpl.parent = floatWindowFunc - floatFunc.expressionInterfaceImpl.Parent = floatWindowFunc + floatFunc.ExpressionInterfaceImpl.Parent = floatWindowFunc return floatWindowFunc } @@ -707,7 +707,7 @@ func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowE integerFunc.funcExpressionImpl = *newFunc(name, expressions, integerFunc) intWindowFunc := newIntegerWindowExpression(integerFunc) integerFunc.integerInterfaceImpl.parent = intWindowFunc - integerFunc.expressionInterfaceImpl.Parent = intWindowFunc + integerFunc.ExpressionInterfaceImpl.Parent = intWindowFunc return intWindowFunc } diff --git a/internal/jet/integer_expression.go b/internal/jet/integer_expression.go index 74a3927..c004437 100644 --- a/internal/jet/integer_expression.go +++ b/internal/jet/integer_expression.go @@ -86,23 +86,23 @@ func (i *integerInterfaceImpl) LT_EQ(expression IntegerExpression) BoolExpressio } func (i *integerInterfaceImpl) ADD(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "+") + return newBinaryIntegerOperatorExpression(i.parent, expression, "+") } func (i *integerInterfaceImpl) SUB(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "-") + return newBinaryIntegerOperatorExpression(i.parent, expression, "-") } func (i *integerInterfaceImpl) MUL(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "*") + return newBinaryIntegerOperatorExpression(i.parent, expression, "*") } func (i *integerInterfaceImpl) DIV(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "/") + return newBinaryIntegerOperatorExpression(i.parent, expression, "/") } func (i *integerInterfaceImpl) MOD(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "%") + return newBinaryIntegerOperatorExpression(i.parent, expression, "%") } func (i *integerInterfaceImpl) POW(expression IntegerExpression) IntegerExpression { @@ -110,78 +110,33 @@ func (i *integerInterfaceImpl) POW(expression IntegerExpression) IntegerExpressi } func (i *integerInterfaceImpl) BIT_AND(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "&") + return newBinaryIntegerOperatorExpression(i.parent, expression, "&") } func (i *integerInterfaceImpl) BIT_OR(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "|") + return newBinaryIntegerOperatorExpression(i.parent, expression, "|") } func (i *integerInterfaceImpl) BIT_XOR(expression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, expression, "#") + return newBinaryIntegerOperatorExpression(i.parent, expression, "#") } func (i *integerInterfaceImpl) BIT_SHIFT_LEFT(intExpression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, intExpression, "<<") + return newBinaryIntegerOperatorExpression(i.parent, intExpression, "<<") } func (i *integerInterfaceImpl) BIT_SHIFT_RIGHT(intExpression IntegerExpression) IntegerExpression { - return newBinaryIntegerExpression(i.parent, intExpression, ">>") + return newBinaryIntegerOperatorExpression(i.parent, intExpression, ">>") } //---------------------------------------------------// -type binaryIntegerExpression struct { - expressionInterfaceImpl - integerInterfaceImpl - - binaryOpExpression -} - -func newBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression { - integerExpression := binaryIntegerExpression{} - - integerExpression.expressionInterfaceImpl.Parent = &integerExpression - integerExpression.integerInterfaceImpl.parent = &integerExpression - - integerExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) - - return &integerExpression +func newBinaryIntegerOperatorExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression { + return IntExp(newBinaryOperatorExpression(lhs, rhs, operator)) } //---------------------------------------------------// -type prefixIntegerOpExpression struct { - expressionInterfaceImpl - integerInterfaceImpl - - prefixOpExpression -} - -func newPrefixIntegerOperator(expression IntegerExpression, operator string) IntegerExpression { - integerExpression := prefixIntegerOpExpression{} - integerExpression.prefixOpExpression = newPrefixExpression(expression, operator) - - integerExpression.expressionInterfaceImpl.Parent = &integerExpression - integerExpression.integerInterfaceImpl.parent = &integerExpression - - return &integerExpression -} - -//---------------------------------------------------// -type prefixFloatOpExpression struct { - expressionInterfaceImpl - floatInterfaceImpl - - prefixOpExpression -} - -func newPrefixFloatOperator(expression FloatExpression, operator string) FloatExpression { - floatOpExpression := prefixFloatOpExpression{} - floatOpExpression.prefixOpExpression = newPrefixExpression(expression, operator) - - floatOpExpression.expressionInterfaceImpl.Parent = &floatOpExpression - floatOpExpression.floatInterfaceImpl.parent = &floatOpExpression - - return &floatOpExpression +func newPrefixIntegerOperatorExpression(expression IntegerExpression, operator string) IntegerExpression { + return IntExp(newPrefixOperatorExpression(expression, operator)) } //---------------------------------------------------// diff --git a/internal/jet/interval.go b/internal/jet/interval.go new file mode 100644 index 0000000..e66ca56 --- /dev/null +++ b/internal/jet/interval.go @@ -0,0 +1,32 @@ +package jet + +// Interval is internal common representation of sql interval +type Interval interface { + Serializer + IsInterval +} + +// IsInterval interface +type IsInterval interface { + isInterval() +} + +// NewInterval creates new interval from serializer +func NewInterval(s Serializer) Interval { + newInterval := &intervalImpl{ + interval: s, + } + + return newInterval +} + +type intervalImpl struct { + interval Serializer +} + +func (i intervalImpl) isInterval() {} + +func (i intervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteString("INTERVAL") + i.interval.serialize(statement, out, options...) +} diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 68fb429..499b7b4 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -14,7 +14,7 @@ type LiteralExpression interface { } type literalExpressionImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl value interface{} constant bool @@ -27,11 +27,17 @@ func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl exp.constant = optionalConstant[0] } - exp.expressionInterfaceImpl.Parent = &exp + exp.ExpressionInterfaceImpl.Parent = &exp return &exp } +// Literal is injected directly to SQL query, and does not appear in parametrized argument list. +func Literal(value interface{}) *literalExpressionImpl { + exp := literal(value) + return exp +} + // FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list. func FixedLiteral(value interface{}) *literalExpressionImpl { exp := literal(value) @@ -273,13 +279,13 @@ func formatNanoseconds(nanoseconds ...time.Duration) string { //--------------------------------------------------// type nullLiteral struct { - expressionInterfaceImpl + ExpressionInterfaceImpl } func newNullLiteral() Expression { nullExpression := &nullLiteral{} - nullExpression.expressionInterfaceImpl.Parent = nullExpression + nullExpression.ExpressionInterfaceImpl.Parent = nullExpression return nullExpression } @@ -290,13 +296,13 @@ func (n *nullLiteral) serialize(statement StatementType, out *SQLBuilder, option //--------------------------------------------------// type starLiteral struct { - expressionInterfaceImpl + ExpressionInterfaceImpl } func newStarLiteral() Expression { starExpression := &starLiteral{} - starExpression.expressionInterfaceImpl.Parent = starExpression + starExpression.ExpressionInterfaceImpl.Parent = starExpression return starExpression } @@ -308,7 +314,7 @@ func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, option //---------------------------------------------------// type wrap struct { - expressionInterfaceImpl + ExpressionInterfaceImpl expressions []Expression } @@ -321,7 +327,7 @@ func (n *wrap) serialize(statement StatementType, out *SQLBuilder, options ...Se // WRAP wraps list of expressions with brackets '(' and ')' func WRAP(expression ...Expression) Expression { wrap := &wrap{expressions: expression} - wrap.expressionInterfaceImpl.Parent = wrap + wrap.ExpressionInterfaceImpl.Parent = wrap return wrap } @@ -329,20 +335,20 @@ func WRAP(expression ...Expression) Expression { //---------------------------------------------------// type rawExpression struct { - expressionInterfaceImpl + ExpressionInterfaceImpl - raw string + Raw string } func (n *rawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString(n.raw) + out.WriteString(n.Raw) } // Raw can be used for any unsupported functions, operators or expressions. // For example: Raw("current_database()") func Raw(raw string) Expression { - rawExp := &rawExpression{raw: raw} - rawExp.expressionInterfaceImpl.Parent = rawExp + rawExp := &rawExpression{Raw: raw} + rawExp.ExpressionInterfaceImpl.Parent = rawExp return rawExp } diff --git a/internal/jet/operators.go b/internal/jet/operators.go index 4a4a32d..fad1e26 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -11,7 +11,7 @@ const ( // NOT returns negation of bool expression result func NOT(exp BoolExpression) BoolExpression { - return newPrefixBoolOperator(exp, "NOT") + return newPrefixBoolOperatorExpression(exp, "NOT") } // BIT_NOT inverts every bit in integer expression result @@ -19,52 +19,52 @@ func BIT_NOT(expr IntegerExpression) IntegerExpression { if literalExp, ok := expr.(LiteralExpression); ok { literalExp.SetConstant(true) } - return newPrefixIntegerOperator(expr, "~") + return newPrefixIntegerOperatorExpression(expr, "~") } //----------- Comparison operators ---------------// // EXISTS checks for existence of the rows in subQuery func EXISTS(subQuery Expression) BoolExpression { - return newPrefixBoolOperator(subQuery, "EXISTS") + return newPrefixBoolOperatorExpression(subQuery, "EXISTS") } // Returns a representation of "a=b" func eq(lhs, rhs Expression) BoolExpression { - return newBinaryBoolOperator(lhs, rhs, "=") + return newBinaryBoolOperatorExpression(lhs, rhs, "=") } // Returns a representation of "a!=b" func notEq(lhs, rhs Expression) BoolExpression { - return newBinaryBoolOperator(lhs, rhs, "!=") + return newBinaryBoolOperatorExpression(lhs, rhs, "!=") } func isDistinctFrom(lhs, rhs Expression) BoolExpression { - return newBinaryBoolOperator(lhs, rhs, "IS DISTINCT FROM") + return newBinaryBoolOperatorExpression(lhs, rhs, "IS DISTINCT FROM") } func isNotDistinctFrom(lhs, rhs Expression) BoolExpression { - return newBinaryBoolOperator(lhs, rhs, "IS NOT DISTINCT FROM") + return newBinaryBoolOperatorExpression(lhs, rhs, "IS NOT DISTINCT FROM") } // Returns a representation of "ab" func gt(lhs, rhs Expression) BoolExpression { - return newBinaryBoolOperator(lhs, rhs, ">") + return newBinaryBoolOperatorExpression(lhs, rhs, ">") } // Returns a representation of "a>=b" func gtEq(lhs, rhs Expression) BoolExpression { - return newBinaryBoolOperator(lhs, rhs, ">=") + return newBinaryBoolOperatorExpression(lhs, rhs, ">=") } // --------------- CASE operator -------------------// @@ -79,7 +79,7 @@ type CaseOperator interface { } type caseOperatorImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl expression Expression when []Expression @@ -95,7 +95,7 @@ func CASE(expression ...Expression) CaseOperator { caseExp.expression = expression[0] } - caseExp.expressionInterfaceImpl.Parent = caseExp + caseExp.ExpressionInterfaceImpl.Parent = caseExp return caseExp } diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index 585d7db..dc661d7 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -41,3 +41,18 @@ func contains(options []SerializeOption, option SerializeOption) bool { return false } + +// ListSerializer serializes list of serializers with separator +type ListSerializer struct { + Serializers []Serializer + Separator string +} + +func (s ListSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + for i, ser := range s.Serializers { + if i > 0 { + out.WriteString(s.Separator) + } + ser.serialize(statement, out) + } +} diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 4eaf626..95bd0b6 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -22,7 +22,7 @@ type SQLBuilder struct { lastChar byte ident int - debug bool + Debug bool } const defaultIdent = 5 @@ -120,7 +120,7 @@ func (s *SQLBuilder) insertConstantArgument(arg interface{}) { } func (s *SQLBuilder) insertParametrizedArgument(arg interface{}) { - if s.debug { + if s.Debug { s.insertConstantArgument(arg) return } @@ -142,12 +142,8 @@ func argToString(value interface{}) string { return "TRUE" } return "FALSE" - case int: - return strconv.FormatInt(int64(bindVal), 10) - case int32: - return strconv.FormatInt(int64(bindVal), 10) - case int64: - return strconv.FormatInt(bindVal, 10) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return integerTypesToString(bindVal) case float32: return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) @@ -167,6 +163,33 @@ func argToString(value interface{}) string { } } +func integerTypesToString(value interface{}) string { + switch bindVal := value.(type) { + case bool: + case int: + return strconv.FormatInt(int64(bindVal), 10) + case uint: + return strconv.FormatUint(uint64(bindVal), 10) + case int8: + return strconv.FormatInt(int64(bindVal), 10) + case uint8: + return strconv.FormatUint(uint64(bindVal), 10) + case int16: + return strconv.FormatInt(int64(bindVal), 10) + case uint16: + return strconv.FormatUint(uint64(bindVal), 10) + case int32: + return strconv.FormatInt(int64(bindVal), 10) + case uint32: + return strconv.FormatUint(uint64(bindVal), 10) + case int64: + return strconv.FormatInt(bindVal, 10) + case uint64: + return strconv.FormatUint(bindVal, 10) + } + panic("jet: Unsupported integer type: " + reflect.TypeOf(value).String()) +} + func shouldQuoteIdentifier(identifier string) bool { for _, c := range identifier { if unicode.IsNumber(c) || c == '_' { diff --git a/internal/jet/serializer_test.go b/internal/jet/sql_builder_test.go similarity index 70% rename from internal/jet/serializer_test.go rename to internal/jet/sql_builder_test.go index 6d2fd4a..dc4a476 100644 --- a/internal/jet/serializer_test.go +++ b/internal/jet/sql_builder_test.go @@ -12,8 +12,16 @@ func TestArgToString(t *testing.T) { assert.Equal(t, argToString(false), "FALSE") assert.Equal(t, argToString(int(-32)), "-32") - assert.Equal(t, argToString(int32(-32)), "-32") + assert.Equal(t, argToString(uint(32)), "32") + assert.Equal(t, argToString(int8(-43)), "-43") + assert.Equal(t, argToString(uint8(43)), "43") + assert.Equal(t, argToString(int16(-54)), "-54") + assert.Equal(t, argToString(uint16(54)), "54") + assert.Equal(t, argToString(int32(-65)), "-65") + assert.Equal(t, argToString(uint32(65)), "65") assert.Equal(t, argToString(int64(-64)), "-64") + assert.Equal(t, argToString(uint64(64)), "64") + assert.Equal(t, argToString(float32(2.0)), "2") assert.Equal(t, argToString(float64(1.11)), "1.11") assert.Equal(t, argToString("john"), "'john'") diff --git a/internal/jet/statement.go b/internal/jet/statement.go index e4ba41b..3b0638d 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -65,7 +65,7 @@ func (s *serializerStatementInterfaceImpl) Sql() (query string, args []interface } func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { - sqlBuilder := &SQLBuilder{Dialect: s.dialect, debug: true} + sqlBuilder := &SQLBuilder{Dialect: s.dialect, Debug: true} s.parent.serialize(s.statementType, sqlBuilder, noWrap) @@ -106,7 +106,7 @@ type ExpressionStatement interface { // NewExpressionStatementImpl creates new expression statement func NewExpressionStatementImpl(Dialect Dialect, statementType StatementType, parent ExpressionStatement, clauses ...Clause) ExpressionStatement { return &expressionStatementImpl{ - expressionInterfaceImpl{Parent: parent}, + ExpressionInterfaceImpl{Parent: parent}, statementImpl{ serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ parent: parent, @@ -119,7 +119,7 @@ func NewExpressionStatementImpl(Dialect Dialect, statementType StatementType, pa } type expressionStatementImpl struct { - expressionInterfaceImpl + ExpressionInterfaceImpl statementImpl } diff --git a/internal/jet/string_expression.go b/internal/jet/string_expression.go index b0351ea..29ceca6 100644 --- a/internal/jet/string_expression.go +++ b/internal/jet/string_expression.go @@ -60,42 +60,28 @@ func (s *stringInterfaceImpl) LT_EQ(rhs StringExpression) BoolExpression { } func (s *stringInterfaceImpl) CONCAT(rhs Expression) StringExpression { - return newBinaryStringExpression(s.parent, rhs, StringConcatOperator) + return newBinaryStringOperatorExpression(s.parent, rhs, StringConcatOperator) } func (s *stringInterfaceImpl) LIKE(pattern StringExpression) BoolExpression { - return newBinaryBoolOperator(s.parent, pattern, "LIKE") + return newBinaryBoolOperatorExpression(s.parent, pattern, "LIKE") } func (s *stringInterfaceImpl) NOT_LIKE(pattern StringExpression) BoolExpression { - return newBinaryBoolOperator(s.parent, pattern, "NOT LIKE") + return newBinaryBoolOperatorExpression(s.parent, pattern, "NOT LIKE") } func (s *stringInterfaceImpl) REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression { - return newBinaryBoolOperator(s.parent, pattern, StringRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) + return newBinaryBoolOperatorExpression(s.parent, pattern, StringRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) } func (s *stringInterfaceImpl) NOT_REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression { - return newBinaryBoolOperator(s.parent, pattern, StringNotRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) + return newBinaryBoolOperatorExpression(s.parent, pattern, StringNotRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) } //---------------------------------------------------// - -type binaryStringExpression struct { - expressionInterfaceImpl - stringInterfaceImpl - - binaryOpExpression -} - -func newBinaryStringExpression(lhs, rhs Expression, operator string) StringExpression { - boolExpression := binaryStringExpression{} - - boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) - boolExpression.expressionInterfaceImpl.Parent = &boolExpression - boolExpression.stringInterfaceImpl.parent = &boolExpression - - return &boolExpression +func newBinaryStringOperatorExpression(lhs, rhs Expression, operator string) StringExpression { + return StringExp(newBinaryOperatorExpression(lhs, rhs, operator)) } //---------------------------------------------------// diff --git a/internal/jet/testutils.go b/internal/jet/testutils.go index 3c0a969..545f12c 100644 --- a/internal/jet/testutils.go +++ b/internal/jet/testutils.go @@ -14,36 +14,40 @@ var defaultDialect = NewDialect(DialectParams{ // just for tests }, }) -var table1Col1 = IntegerColumn("col1") -var table1ColInt = IntegerColumn("col_int") -var table1ColFloat = FloatColumn("col_float") -var table1Col3 = IntegerColumn("col3") -var table1ColTime = TimeColumn("col_time") -var table1ColTimez = TimezColumn("col_timez") -var table1ColTimestamp = TimestampColumn("col_timestamp") -var table1ColTimestampz = TimestampzColumn("col_timestampz") -var table1ColBool = BoolColumn("col_bool") -var table1ColDate = DateColumn("col_date") - +var ( + table1Col1 = IntegerColumn("col1") + table1ColInt = IntegerColumn("col_int") + table1ColFloat = FloatColumn("col_float") + table1Col3 = IntegerColumn("col3") + table1ColTime = TimeColumn("col_time") + table1ColTimez = TimezColumn("col_timez") + table1ColTimestamp = TimestampColumn("col_timestamp") + table1ColTimestampz = TimestampzColumn("col_timestampz") + table1ColBool = BoolColumn("col_bool") + table1ColDate = DateColumn("col_date") +) var table1 = NewTable("db", "table1", table1Col1, table1ColInt, table1ColFloat, table1Col3, table1ColTime, table1ColTimez, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTimestampz) -var table2Col3 = IntegerColumn("col3") -var table2Col4 = IntegerColumn("col4") -var table2ColInt = IntegerColumn("col_int") -var table2ColFloat = FloatColumn("col_float") -var table2ColStr = StringColumn("col_str") -var table2ColBool = BoolColumn("col_bool") -var table2ColTime = TimeColumn("col_time") -var table2ColTimez = TimezColumn("col_timez") -var table2ColTimestamp = TimestampColumn("col_timestamp") -var table2ColTimestampz = TimestampzColumn("col_timestampz") -var table2ColDate = DateColumn("col_date") - +var ( + table2Col3 = IntegerColumn("col3") + table2Col4 = IntegerColumn("col4") + table2ColInt = IntegerColumn("col_int") + table2ColFloat = FloatColumn("col_float") + table2ColStr = StringColumn("col_str") + table2ColBool = BoolColumn("col_bool") + table2ColTime = TimeColumn("col_time") + table2ColTimez = TimezColumn("col_timez") + table2ColTimestamp = TimestampColumn("col_timestamp") + table2ColTimestampz = TimestampzColumn("col_timestampz") + table2ColDate = DateColumn("col_date") +) var table2 = NewTable("db", "table2", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColTime, table2ColTimez, table2ColDate, table2ColTimestamp, table2ColTimestampz) -var table3Col1 = IntegerColumn("col1") -var table3ColInt = IntegerColumn("col_int") -var table3StrCol = StringColumn("col2") +var ( + table3Col1 = IntegerColumn("col1") + table3ColInt = IntegerColumn("col_int") + table3StrCol = StringColumn("col2") +) var table3 = NewTable("db", "table3", table3Col1, table3ColInt, table3StrCol) func assertClauseSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) { @@ -67,7 +71,7 @@ func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) } func assertClauseDebugSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) { - out := SQLBuilder{Dialect: defaultDialect, debug: true} + out := SQLBuilder{Dialect: defaultDialect, Debug: true} clause.serialize(SelectStatementType, &out) //fmt.Println(out.Buff.String()) diff --git a/internal/jet/time_expression.go b/internal/jet/time_expression.go index 779d37f..b83f731 100644 --- a/internal/jet/time_expression.go +++ b/internal/jet/time_expression.go @@ -13,6 +13,9 @@ type TimeExpression interface { LT_EQ(rhs TimeExpression) BoolExpression GT(rhs TimeExpression) BoolExpression GT_EQ(rhs TimeExpression) BoolExpression + + ADD(rhs Interval) TimeExpression + SUB(rhs Interval) TimeExpression } type timeInterfaceImpl struct { @@ -51,23 +54,13 @@ func (t *timeInterfaceImpl) GT_EQ(rhs TimeExpression) BoolExpression { return gtEq(t.parent, rhs) } -//---------------------------------------------------// -type prefixTimeExpression struct { - expressionInterfaceImpl - timeInterfaceImpl - - prefixOpExpression +func (t *timeInterfaceImpl) ADD(rhs Interval) TimeExpression { + return TimeExp(newBinaryOperatorExpression(t.parent, rhs, "+")) } -//func newPrefixTimeExpression(operator string, expression Expression) TimeExpression { -// timeExpr := prefixTimeExpression{} -// timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) -// -// timeExpr.expressionInterfaceImpl.parent = &timeExpr -// timeExpr.timeInterfaceImpl.parent = &timeExpr -// -// return &timeExpr -//} +func (t *timeInterfaceImpl) SUB(rhs Interval) TimeExpression { + return TimeExp(newBinaryOperatorExpression(t.parent, rhs, "-")) +} //---------------------------------------------------// diff --git a/internal/jet/time_expression_test.go b/internal/jet/time_expression_test.go index 2b3d015..61ee29f 100644 --- a/internal/jet/time_expression_test.go +++ b/internal/jet/time_expression_test.go @@ -52,3 +52,11 @@ func TestTimeExp(t *testing.T) { assertClauseSerialize(t, TimeExp(table1ColFloat).LT(Time(1, 1, 1, 1*time.Millisecond)), "(table1.col_float < $1)", string("01:01:01.001")) } + +func TestTimeArithmetic(t *testing.T) { + time := Time(10, 20, 3) + assertClauseDebugSerialize(t, table1ColTime.ADD(NewInterval(String("1 HOUR"))).EQ(time), + "((table1.col_time + INTERVAL '1 HOUR') = '10:20:03')") + assertClauseDebugSerialize(t, table1ColTime.SUB(NewInterval(String("1 HOUR"))).EQ(time), + "((table1.col_time - INTERVAL '1 HOUR') = '10:20:03')") +} diff --git a/internal/jet/timestamp_expression.go b/internal/jet/timestamp_expression.go index f76c27a..81eda61 100644 --- a/internal/jet/timestamp_expression.go +++ b/internal/jet/timestamp_expression.go @@ -13,6 +13,9 @@ type TimestampExpression interface { LT_EQ(rhs TimestampExpression) BoolExpression GT(rhs TimestampExpression) BoolExpression GT_EQ(rhs TimestampExpression) BoolExpression + + ADD(rhs Interval) TimestampExpression + SUB(rhs Interval) TimestampExpression } type timestampInterfaceImpl struct { @@ -51,6 +54,14 @@ func (t *timestampInterfaceImpl) GT_EQ(rhs TimestampExpression) BoolExpression { return gtEq(t.parent, rhs) } +func (t *timestampInterfaceImpl) ADD(rhs Interval) TimestampExpression { + return TimestampExp(newBinaryOperatorExpression(t.parent, rhs, "+")) +} + +func (t *timestampInterfaceImpl) SUB(rhs Interval) TimestampExpression { + return TimestampExp(newBinaryOperatorExpression(t.parent, rhs, "-")) +} + //------------------------------------------------- type timestampExpressionWrapper struct { diff --git a/internal/jet/timestamp_expression_test.go b/internal/jet/timestamp_expression_test.go index 9a9ceb4..e34d8dd 100644 --- a/internal/jet/timestamp_expression_test.go +++ b/internal/jet/timestamp_expression_test.go @@ -53,3 +53,11 @@ func TestTimestampExp(t *testing.T) { assertClauseSerialize(t, TimestampExp(table1ColFloat).LT(timestamp), "(table1.col_float < $1)", "2000-01-31 10:20:00.003") } + +func TestTimestampArithmetic(t *testing.T) { + timestamp := Timestamp(2000, 1, 1, 0, 0, 0) + assertClauseDebugSerialize(t, table1ColTimestamp.ADD(NewInterval(String("1 HOUR"))).EQ(timestamp), + "((table1.col_timestamp + INTERVAL '1 HOUR') = '2000-01-01 00:00:00')") + assertClauseDebugSerialize(t, table1ColTimestamp.SUB(NewInterval(String("1 HOUR"))).EQ(timestamp), + "((table1.col_timestamp - INTERVAL '1 HOUR') = '2000-01-01 00:00:00')") +} diff --git a/internal/jet/timestampz_expression.go b/internal/jet/timestampz_expression.go index 4f3e6ec..a9f8c9f 100644 --- a/internal/jet/timestampz_expression.go +++ b/internal/jet/timestampz_expression.go @@ -13,6 +13,9 @@ type TimestampzExpression interface { LT_EQ(rhs TimestampzExpression) BoolExpression GT(rhs TimestampzExpression) BoolExpression GT_EQ(rhs TimestampzExpression) BoolExpression + + ADD(rhs Interval) TimestampzExpression + SUB(rhs Interval) TimestampzExpression } type timestampzInterfaceImpl struct { @@ -51,13 +54,12 @@ func (t *timestampzInterfaceImpl) GT_EQ(rhs TimestampzExpression) BoolExpression return gtEq(t.parent, rhs) } -//---------------------------------------------------// +func (t *timestampzInterfaceImpl) ADD(rhs Interval) TimestampzExpression { + return TimestampzExp(newBinaryOperatorExpression(t.parent, rhs, "+")) +} -type prefixTimestampzOperator struct { - expressionInterfaceImpl - timestampzInterfaceImpl - - prefixOpExpression +func (t *timestampzInterfaceImpl) SUB(rhs Interval) TimestampzExpression { + return TimestampzExp(newBinaryOperatorExpression(t.parent, rhs, "-")) } //------------------------------------------------- diff --git a/internal/jet/timestampz_expression_test.go b/internal/jet/timestampz_expression_test.go index 6880c93..1ff1eac 100644 --- a/internal/jet/timestampz_expression_test.go +++ b/internal/jet/timestampz_expression_test.go @@ -53,3 +53,11 @@ func TestTimestampzExp(t *testing.T) { assertClauseSerialize(t, TimestampzExp(table1ColFloat).LT(timestampz), "(table1.col_float < $1)", "2000-01-31 10:20:05.000023 +200") } + +func TestTimestampzArithmetic(t *testing.T) { + timestampz := Timestampz(2000, 1, 1, 0, 0, 0, 100, "UTC") + assertClauseDebugSerialize(t, table1ColTimestampz.ADD(NewInterval(String("1 HOUR"))).EQ(timestampz), + "((table1.col_timestampz + INTERVAL '1 HOUR') = '2000-01-01 00:00:00.0000001 UTC')") + assertClauseDebugSerialize(t, table1ColTimestampz.SUB(NewInterval(String("1 HOUR"))).EQ(timestampz), + "((table1.col_timestampz - INTERVAL '1 HOUR') = '2000-01-01 00:00:00.0000001 UTC')") +} diff --git a/internal/jet/timez_expression.go b/internal/jet/timez_expression.go index 36b5c8f..c791c62 100644 --- a/internal/jet/timez_expression.go +++ b/internal/jet/timez_expression.go @@ -4,23 +4,18 @@ package jet type TimezExpression interface { Expression - //EQ EQ(rhs TimezExpression) BoolExpression - //NOT_EQ NOT_EQ(rhs TimezExpression) BoolExpression - //IS_DISTINCT_FROM IS_DISTINCT_FROM(rhs TimezExpression) BoolExpression - //IS_NOT_DISTINCT_FROM IS_NOT_DISTINCT_FROM(rhs TimezExpression) BoolExpression - //LT LT(rhs TimezExpression) BoolExpression - //LT_EQ LT_EQ(rhs TimezExpression) BoolExpression - //GT GT(rhs TimezExpression) BoolExpression - //GT_EQ GT_EQ(rhs TimezExpression) BoolExpression + + ADD(rhs Interval) TimezExpression + SUB(rhs Interval) TimezExpression } type timezInterfaceImpl struct { @@ -59,23 +54,13 @@ func (t *timezInterfaceImpl) GT_EQ(rhs TimezExpression) BoolExpression { return gtEq(t.parent, rhs) } -//---------------------------------------------------// -type prefixTimezExpression struct { - expressionInterfaceImpl - timezInterfaceImpl - - prefixOpExpression +func (t *timezInterfaceImpl) ADD(rhs Interval) TimezExpression { + return TimezExp(newBinaryOperatorExpression(t.parent, rhs, "+")) } -//func newPrefixTimezExpression(operator string, expression Expression) TimezExpression { -// timeExpr := prefixTimezExpression{} -// timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) -// -// timeExpr.expressionInterfaceImpl.parent = &timeExpr -// timeExpr.timezInterfaceImpl.parent = &timeExpr -// -// return &timeExpr -//} +func (t *timezInterfaceImpl) SUB(rhs Interval) TimezExpression { + return TimezExp(newBinaryOperatorExpression(t.parent, rhs, "-")) +} //---------------------------------------------------// diff --git a/internal/jet/timez_expression_test.go b/internal/jet/timez_expression_test.go index 2a0312a..9f21c08 100644 --- a/internal/jet/timez_expression_test.go +++ b/internal/jet/timez_expression_test.go @@ -1,6 +1,8 @@ package jet -import "testing" +import ( + "testing" +) var timezVar = Timez(10, 20, 0, 0, "+4:00") @@ -49,3 +51,11 @@ func TestTimezExp(t *testing.T) { assertClauseSerialize(t, TimezExp(table1ColFloat).LT(Timez(1, 1, 1, 1, "+4:00")), "(table1.col_float < $1)", string("01:01:01.000000001 +4:00")) } + +func TestTimezArithmetic(t *testing.T) { + timez := Timez(0, 0, 0, 100, "UTC") + assertClauseDebugSerialize(t, table1ColTimez.ADD(NewInterval(String("1 HOUR"))).EQ(timez), + "((table1.col_timez + INTERVAL '1 HOUR') = '00:00:00.0000001 UTC')") + assertClauseDebugSerialize(t, table1ColTimez.SUB(NewInterval(String("1 HOUR"))).EQ(timez), + "((table1.col_timez - INTERVAL '1 HOUR') = '00:00:00.0000001 UTC')") +} diff --git a/internal/jet/utils.go b/internal/jet/utils.go index fdaf1f6..58394f4 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -63,6 +63,17 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) { } } +// ExpressionListToSerializerList converts list of expressions to list of serializers +func ExpressionListToSerializerList(expressions []Expression) []Serializer { + var ret []Serializer + + for _, expr := range expressions { + ret = append(ret, expr) + } + + return ret +} + // ColumnListToProjectionList func func ColumnListToProjectionList(columns []ColumnExpression) []Projection { var ret []Projection diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 7948504..c3d3ff0 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -129,6 +129,28 @@ func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Seriali } } +// AssertDebugClauseSerialize checks if clause serialize produces expected debug query and args +func AssertDebugClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { + out := jet.SQLBuilder{Dialect: dialect, Debug: true} + jet.Serialize(clause, jet.SelectStatementType, &out) + + assert.DeepEqual(t, out.Buff.String(), query) + + if len(args) > 0 { + assert.DeepEqual(t, out.Args, args) + } +} + +// AssertPanicErr checks if running a function fun produces a panic with errorStr string +func AssertPanicErr(t *testing.T, fun func(), errorStr string) { + defer func() { + r := recover() + assert.Equal(t, r, errorStr) + }() + + fun() +} + // AssertClauseSerializeErr check if clause serialize panics with errString func AssertClauseSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) { defer func() { diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 22ea1c3..42a5c36 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -9,6 +9,7 @@ import ( "path/filepath" "reflect" "strings" + "time" ) // ToGoIdentifier converts database to Go identifier. @@ -182,3 +183,22 @@ func StringSliceContains(strings []string, contains string) bool { return false } + +// ExtractDateTimeComponents extracts number of days, hours, minutes, seconds, microseconds from duration +func ExtractDateTimeComponents(duration time.Duration) (days, hours, minutes, seconds, microseconds int64) { + days = int64(duration / (24 * time.Hour)) + reminder := duration % (24 * time.Hour) + + hours = int64(reminder / time.Hour) + reminder = reminder % time.Hour + + minutes = int64(reminder / time.Minute) + reminder = reminder % time.Minute + + seconds = int64(reminder / time.Second) + reminder = reminder % time.Second + + microseconds = int64(reminder / time.Microsecond) + + return +} diff --git a/mysql/cast_test.go b/mysql/cast_test.go index cc1a809..170cde8 100644 --- a/mysql/cast_test.go +++ b/mysql/cast_test.go @@ -5,14 +5,14 @@ import ( ) func TestCAST(t *testing.T) { - assertClauseSerialize(t, CAST(Float(11.22)).AS("bigint"), `CAST(? AS bigint)`) - assertClauseSerialize(t, CAST(Int(22)).AS_CHAR(), `CAST(? AS CHAR)`) - assertClauseSerialize(t, CAST(Int(22)).AS_CHAR(10), `CAST(? AS CHAR(10))`) - assertClauseSerialize(t, CAST(Int(22)).AS_DATE(), `CAST(? AS DATE)`) - assertClauseSerialize(t, CAST(Int(22)).AS_DECIMAL(), `CAST(? AS DECIMAL)`) - assertClauseSerialize(t, CAST(Int(22)).AS_TIME(), `CAST(? AS TIME)`) - assertClauseSerialize(t, CAST(Int(22)).AS_DATETIME(), `CAST(? AS DATETIME)`) - assertClauseSerialize(t, CAST(Int(22)).AS_SIGNED(), `CAST(? AS SIGNED)`) - assertClauseSerialize(t, CAST(Int(22)).AS_UNSIGNED(), `CAST(? AS UNSIGNED)`) - assertClauseSerialize(t, CAST(Int(22)).AS_BINARY(), `CAST(? AS BINARY)`) + assertSerialize(t, CAST(Float(11.22)).AS("bigint"), `CAST(? AS bigint)`) + assertSerialize(t, CAST(Int(22)).AS_CHAR(), `CAST(? AS CHAR)`) + assertSerialize(t, CAST(Int(22)).AS_CHAR(10), `CAST(? AS CHAR(10))`) + assertSerialize(t, CAST(Int(22)).AS_DATE(), `CAST(? AS DATE)`) + assertSerialize(t, CAST(Int(22)).AS_DECIMAL(), `CAST(? AS DECIMAL)`) + assertSerialize(t, CAST(Int(22)).AS_TIME(), `CAST(? AS TIME)`) + assertSerialize(t, CAST(Int(22)).AS_DATETIME(), `CAST(? AS DATETIME)`) + assertSerialize(t, CAST(Int(22)).AS_SIGNED(), `CAST(? AS SIGNED)`) + assertSerialize(t, CAST(Int(22)).AS_UNSIGNED(), `CAST(? AS UNSIGNED)`) + assertSerialize(t, CAST(Int(22)).AS_BINARY(), `CAST(? AS BINARY)`) } diff --git a/mysql/dialect.go b/mysql/dialect.go index 45509a7..cfd452a 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -8,7 +8,6 @@ import ( var Dialect = newDialect() func newDialect() jet.Dialect { - operatorSerializeOverrides := map[string]jet.SerializeOverride{} operatorSerializeOverrides[jet.StringRegexpLikeOperator] = mysqlREGEXPLIKEoperator operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = mysqlNOTREGEXPLIKEoperator @@ -32,7 +31,7 @@ func newDialect() jet.Dialect { return jet.NewDialect(mySQLDialectParams) } -func mysqlBitXor(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlBitXor(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator XOR") @@ -49,7 +48,7 @@ func mysqlBitXor(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysqlCONCAToperator(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlCONCAToperator(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator CONCAT") @@ -66,7 +65,7 @@ func mysqlCONCAToperator(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysqlDivision(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlDivision(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator DIV") @@ -90,7 +89,7 @@ func mysqlDivision(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysqlISNOTDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlISNOTDISTINCTFROM(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") @@ -102,7 +101,7 @@ func mysqlISNOTDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysqlISDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlISDISTINCTFROM(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { out.WriteString("NOT(") mysqlISNOTDISTINCTFROM(expressions...)(statement, out, options...) @@ -110,7 +109,7 @@ func mysqlISDISTINCTFROM(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysqlREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") @@ -136,7 +135,7 @@ func mysqlREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { } } -func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { +func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") diff --git a/mysql/dialect_test.go b/mysql/dialect_test.go index 936277a..92ece4d 100644 --- a/mysql/dialect_test.go +++ b/mysql/dialect_test.go @@ -5,37 +5,37 @@ import ( ) func TestBoolExpressionIS_DISTINCT_FROM(t *testing.T) { - assertClauseSerialize(t, table1ColBool.IS_DISTINCT_FROM(table2ColBool), "(NOT(table1.col_bool <=> table2.col_bool))") - assertClauseSerialize(t, table1ColBool.IS_DISTINCT_FROM(Bool(false)), "(NOT(table1.col_bool <=> ?))", false) + assertSerialize(t, table1ColBool.IS_DISTINCT_FROM(table2ColBool), "(NOT(table1.col_bool <=> table2.col_bool))") + assertSerialize(t, table1ColBool.IS_DISTINCT_FROM(Bool(false)), "(NOT(table1.col_bool <=> ?))", false) } func TestBoolExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { - assertClauseSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(table2ColBool), "(table1.col_bool <=> table2.col_bool)") - assertClauseSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(Bool(false)), "(table1.col_bool <=> ?)", false) + assertSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(table2ColBool), "(table1.col_bool <=> table2.col_bool)") + assertSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(Bool(false)), "(table1.col_bool <=> ?)", false) } func TestBoolLiteral(t *testing.T) { - assertClauseSerialize(t, Bool(true), "?", true) - assertClauseSerialize(t, Bool(false), "?", false) + assertSerialize(t, Bool(true), "?", true) + assertSerialize(t, Bool(false), "?", false) } func TestIntegerExpressionDIV(t *testing.T) { - assertClauseSerialize(t, table1ColInt.DIV(table2ColInt), "(table1.col_int DIV table2.col_int)") - assertClauseSerialize(t, table1ColInt.DIV(Int(11)), "(table1.col_int DIV ?)", int64(11)) + assertSerialize(t, table1ColInt.DIV(table2ColInt), "(table1.col_int DIV table2.col_int)") + assertSerialize(t, table1ColInt.DIV(Int(11)), "(table1.col_int DIV ?)", int64(11)) } func TestIntExpressionPOW(t *testing.T) { - assertClauseSerialize(t, table1ColInt.POW(table2ColInt), "POW(table1.col_int, table2.col_int)") - assertClauseSerialize(t, table1ColInt.POW(Int(11)), "POW(table1.col_int, ?)", int64(11)) + assertSerialize(t, table1ColInt.POW(table2ColInt), "POW(table1.col_int, table2.col_int)") + assertSerialize(t, table1ColInt.POW(Int(11)), "POW(table1.col_int, ?)", int64(11)) } func TestIntExpressionBIT_XOR(t *testing.T) { - assertClauseSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "(table1.col_int ^ table2.col_int)") - assertClauseSerialize(t, table1ColInt.BIT_XOR(Int(11)), "(table1.col_int ^ ?)", int64(11)) + assertSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "(table1.col_int ^ table2.col_int)") + assertSerialize(t, table1ColInt.BIT_XOR(Int(11)), "(table1.col_int ^ ?)", int64(11)) } func TestExists(t *testing.T) { - assertClauseSerialize(t, EXISTS( + assertSerialize(t, EXISTS( table2. SELECT(Int(1)). WHERE(table1Col1.EQ(table2Col3)), @@ -48,15 +48,15 @@ func TestExists(t *testing.T) { } func TestString_REGEXP_LIKE_operator(t *testing.T) { - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 REGEXP table2.col_str)") - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 REGEXP ?)", "JOHN") - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 REGEXP ?)", "JOHN") - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 REGEXP BINARY ?)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 REGEXP table2.col_str)") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 REGEXP ?)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 REGEXP ?)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 REGEXP BINARY ?)", "JOHN") } func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) { - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 NOT REGEXP table2.col_str)") - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 NOT REGEXP ?)", "JOHN") - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 NOT REGEXP ?)", "JOHN") - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 NOT REGEXP BINARY ?)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 NOT REGEXP table2.col_str)") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 NOT REGEXP ?)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 NOT REGEXP ?)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 NOT REGEXP BINARY ?)", "JOHN") } diff --git a/mysql/interval.go b/mysql/interval.go new file mode 100644 index 0000000..478e9c4 --- /dev/null +++ b/mysql/interval.go @@ -0,0 +1,195 @@ +package mysql + +import ( + "fmt" + "regexp" + "time" + + "github.com/go-jet/jet/internal/jet" + "github.com/go-jet/jet/internal/utils" +) + +type unitType string + +// List of interval unit types for MySQL +const ( + MICROSECOND unitType = "MICROSECOND" + SECOND = "SECOND" + MINUTE = "MINUTE" + HOUR = "HOUR" + DAY = "DAY" + WEEK = "WEEK" + MONTH = "MONTH" + QUARTER = "QUARTER" + YEAR = "YEAR" + SECOND_MICROSECOND = "SECOND_MICROSECOND" + MINUTE_MICROSECOND = "MINUTE_MICROSECOND" + MINUTE_SECOND = "MINUTE_SECOND" + HOUR_MICROSECOND = "HOUR_MICROSECOND" + HOUR_SECOND = "HOUR_SECOND" + HOUR_MINUTE = "HOUR_MINUTE" + DAY_MICROSECOND = "DAY_MICROSECOND" + DAY_SECOND = "DAY_SECOND" + DAY_MINUTE = "DAY_MINUTE" + DAY_HOUR = "DAY_HOUR" + YEAR_MONTH = "YEAR_MONTH" +) + +// Interval is representation of MySQL interval +type Interval = jet.Interval + +// INTERVAL creates new Interval type. +// In a case of MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR unit type +// value parameter should be number. For example: INTERVAL(1, DAY) +// In a case of other unit types, value should be string with appropriate format. +// For example: INTERVAL("10:08:50", HOUR_SECOND) +func INTERVAL(value interface{}, unitType unitType) Interval { + switch unitType { + case MICROSECOND, SECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR: + if !isNumericType(value) { + panic("jet: INTERVAL invalid value type. Numeric type expected") + } + return INTERVALe(jet.FixedLiteral(value), unitType) + default: + strValue, ok := value.(string) + + if !ok { + panic("jet: INTERNAL invalid value type. String type expected") + } + + var regexp *regexp.Regexp + + switch unitType { + case SECOND_MICROSECOND: + regexp = regexSecondMicrosecond + case MINUTE_MICROSECOND: + regexp = regexMinuteMicrosecond + case MINUTE_SECOND: + regexp = regexMinuteSecond + case HOUR_MICROSECOND: + regexp = regexHourMicrosecond + case HOUR_SECOND: + regexp = regexHourSecond + case HOUR_MINUTE: + regexp = regexHourMinute + case DAY_MICROSECOND: + regexp = regexDayMicrosecond + case DAY_SECOND: + regexp = regexDaySecond + case DAY_MINUTE: + regexp = regexDayMinute + case DAY_HOUR: + regexp = regexDayHour + case YEAR_MONTH: + regexp = regexYearMonth + default: + panic("jet: INTERVAL invalid unit type") + } + + if !regexp.MatchString(strValue) { + panic("jet: INTERVAL invalid format") + } + + return INTERVALe(jet.Literal(value), unitType) + } +} + +// INTERVALe creates new Interval type from expresion and unit type. +func INTERVALe(expr Expression, unitType unitType) Interval { + return jet.NewInterval(jet.ListSerializer{ + Serializers: []jet.Serializer{expr, jet.Raw(string(unitType))}, + Separator: " ", + }) +} + +// INTERVALd returns a interval representation from duration +func INTERVALd(duration time.Duration) Interval { + var sign int64 = 1 + if duration < 0 { + sign = -1 + duration = -duration + } + + days, hours, minutes, sec, microsec := utils.ExtractDateTimeComponents(duration) + + if days != 0 { + switch { + case microsec > 0: + intervalStr := fmt.Sprintf("%d %02d:%02d:%02d.%06d", sign*days, hours, minutes, sec, microsec) + return INTERVAL(intervalStr, DAY_MICROSECOND) + case sec > 0: + intervalStr := fmt.Sprintf("%d %02d:%02d:%02d", sign*days, hours, minutes, sec) + return INTERVAL(intervalStr, DAY_SECOND) + case minutes > 0: + intervalStr := fmt.Sprintf("%d %02d:%02d", sign*days, hours, minutes) + return INTERVAL(intervalStr, DAY_MINUTE) + case hours > 0: + intervalStr := fmt.Sprintf("%d %02d", sign*days, hours) + return INTERVAL(intervalStr, DAY_HOUR) + default: + return INTERVAL(sign*days, DAY) + } + } + + if hours != 0 { + switch { + case microsec > 0: + intervalStr := fmt.Sprintf("%02d:%02d:%02d.%06d", sign*hours, minutes, sec, microsec) + return INTERVAL(intervalStr, HOUR_MICROSECOND) + case sec > 0: + intervalStr := fmt.Sprintf("%02d:%02d:%02d", sign*hours, minutes, sec) + return INTERVAL(intervalStr, HOUR_SECOND) + case minutes > 0: + intervalStr := fmt.Sprintf("%02d:%02d", sign*hours, minutes) + return INTERVAL(intervalStr, HOUR_MINUTE) + default: + return INTERVAL(sign*hours, HOUR) + } + } + + if minutes != 0 { + switch { + case microsec > 0: + intervalStr := fmt.Sprintf("%02d:%02d.%06d", sign*minutes, sec, microsec) + return INTERVAL(intervalStr, MINUTE_MICROSECOND) + case sec > 0: + intervalStr := fmt.Sprintf("%02d:%02d", sign*minutes, sec) + return INTERVAL(intervalStr, MINUTE_SECOND) + default: + return INTERVAL(sign*minutes, MINUTE) + } + } + + if sec != 0 { + if microsec > 0 { + intervalStr := fmt.Sprintf("%02d.%06d", sign*sec, microsec) + return INTERVAL(intervalStr, SECOND_MICROSECOND) + } + return INTERVAL(sign*sec, SECOND) + } + + return INTERVAL(sign*microsec, MICROSECOND) +} + +var ( + regexSecondMicrosecond = regexp.MustCompile(`^-?\d{1,2}\.\d+$`) //'SECONDS.MICROSECONDS' + regexMinuteMicrosecond = regexp.MustCompile(`^-?\d{1,2}:\d{2}\.\d+$`) //'MINUTE:SECONDS.MICROSECONDS' + regexMinuteSecond = regexp.MustCompile(`^-?\d{1,2}:\d{2}$`) //'MINUTE:SECONDS' + regexHourMicrosecond = regexp.MustCompile(`^-?\d{1,2}:\d{2}:\d{2}\.\d+$`) //'HOUR:MINUTE:SECONDS.MICROSECONDS' + regexHourSecond = regexp.MustCompile(`^-?\d{1,2}:\d{2}:\d{2}$`) //'HOUR:MINUTE:SECONDS' + regexHourMinute = regexp.MustCompile(`^-?\d{1,2}:\d{2}$`) //'HOUR:MINUTE' + regexDayMicrosecond = regexp.MustCompile(`^-?\d+ \d{1,2}:\d{2}:\d{2}.\d+$`) //'DAY HOUR:MINUTE:SECONDS' + regexDaySecond = regexp.MustCompile(`^-?\d+ \d{1,2}:\d{2}:\d{2}$`) //'DAY HOUR:MINUTE:SECONDS' + regexDayMinute = regexp.MustCompile(`^-?\d+ \d{1,2}:\d{2}$`) //'DAY HOUR:MINUTE' + regexDayHour = regexp.MustCompile(`^-?\d+ \d{1,2}$`) //'DAY HOUR:MINUTE' + regexYearMonth = regexp.MustCompile(`^-?\d+-\d{1,2}$`) //'YEAR-MONTH' +) + +func isNumericType(value interface{}) bool { + switch value.(type) { + case float64, float32, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return true + default: + return false + } +} diff --git a/mysql/interval_test.go b/mysql/interval_test.go new file mode 100644 index 0000000..c88b808 --- /dev/null +++ b/mysql/interval_test.go @@ -0,0 +1,99 @@ +package mysql + +import ( + "testing" + "time" +) + +func TestINTERVAL(t *testing.T) { + assertSerialize(t, INTERVAL("3-2", YEAR_MONTH), "INTERVAL ? YEAR_MONTH") + assertDebugSerialize(t, INTERVAL("3-2", YEAR_MONTH), "INTERVAL '3-2' YEAR_MONTH") + assertDebugSerialize(t, INTERVAL("-3-2", YEAR_MONTH), "INTERVAL '-3-2' YEAR_MONTH") + assertDebugSerialize(t, INTERVAL("10 25", DAY_HOUR), "INTERVAL '10 25' DAY_HOUR") + assertDebugSerialize(t, INTERVAL("-10 25", DAY_HOUR), "INTERVAL '-10 25' DAY_HOUR") + assertDebugSerialize(t, INTERVAL("10 25:15", DAY_MINUTE), "INTERVAL '10 25:15' DAY_MINUTE") + assertDebugSerialize(t, INTERVAL("-10 25:15", DAY_MINUTE), "INTERVAL '-10 25:15' DAY_MINUTE") + assertDebugSerialize(t, INTERVAL("10 25:15:08", DAY_SECOND), "INTERVAL '10 25:15:08' DAY_SECOND") + assertDebugSerialize(t, INTERVAL("-10 25:15:08", DAY_SECOND), "INTERVAL '-10 25:15:08' DAY_SECOND") + assertDebugSerialize(t, INTERVAL("10 25:15:08.000100", DAY_MICROSECOND), "INTERVAL '10 25:15:08.000100' DAY_MICROSECOND") + assertDebugSerialize(t, INTERVAL("-10 25:15:08.000100", DAY_MICROSECOND), "INTERVAL '-10 25:15:08.000100' DAY_MICROSECOND") + assertDebugSerialize(t, INTERVAL("15:08", HOUR_MINUTE), "INTERVAL '15:08' HOUR_MINUTE") + assertDebugSerialize(t, INTERVAL("-15:08", HOUR_MINUTE), "INTERVAL '-15:08' HOUR_MINUTE") + assertDebugSerialize(t, INTERVAL("15:08", HOUR_MINUTE), "INTERVAL '15:08' HOUR_MINUTE") + assertDebugSerialize(t, INTERVAL("-15:08", HOUR_MINUTE), "INTERVAL '-15:08' HOUR_MINUTE") + assertDebugSerialize(t, INTERVAL("15:08:03", HOUR_SECOND), "INTERVAL '15:08:03' HOUR_SECOND") + assertDebugSerialize(t, INTERVAL("-15:08:03", HOUR_SECOND), "INTERVAL '-15:08:03' HOUR_SECOND") + assertDebugSerialize(t, INTERVAL("25:15:08.000100", HOUR_MICROSECOND), "INTERVAL '25:15:08.000100' HOUR_MICROSECOND") + assertDebugSerialize(t, INTERVAL("-25:15:08.000100", HOUR_MICROSECOND), "INTERVAL '-25:15:08.000100' HOUR_MICROSECOND") + assertDebugSerialize(t, INTERVAL("08:03", MINUTE_SECOND), "INTERVAL '08:03' MINUTE_SECOND") + assertDebugSerialize(t, INTERVAL("-08:03", MINUTE_SECOND), "INTERVAL '-08:03' MINUTE_SECOND") + assertDebugSerialize(t, INTERVAL("15:08.000100", MINUTE_MICROSECOND), "INTERVAL '15:08.000100' MINUTE_MICROSECOND") + assertDebugSerialize(t, INTERVAL("-15:08.000100", MINUTE_MICROSECOND), "INTERVAL '-15:08.000100' MINUTE_MICROSECOND") + assertDebugSerialize(t, INTERVAL("08.000100", SECOND_MICROSECOND), "INTERVAL '08.000100' SECOND_MICROSECOND") + assertDebugSerialize(t, INTERVAL("-08.000100", SECOND_MICROSECOND), "INTERVAL '-08.000100' SECOND_MICROSECOND") + + assertSerialize(t, INTERVAL(15, SECOND), "INTERVAL 15 SECOND") + assertSerialize(t, INTERVAL(1, MICROSECOND), "INTERVAL 1 MICROSECOND") + assertSerialize(t, INTERVAL(2, MINUTE), "INTERVAL 2 MINUTE") + assertSerialize(t, INTERVAL(3, HOUR), "INTERVAL 3 HOUR") + assertSerialize(t, INTERVAL(4, DAY), "INTERVAL 4 DAY") + assertSerialize(t, INTERVAL(5, MONTH), "INTERVAL 5 MONTH") + assertSerialize(t, INTERVAL(6, YEAR), "INTERVAL 6 YEAR") + assertSerialize(t, INTERVAL(-6, YEAR), "INTERVAL -6 YEAR") + + assertSerialize(t, INTERVAL(uint(6), YEAR), "INTERVAL 6 YEAR") + assertSerialize(t, INTERVAL(int16(7), YEAR), "INTERVAL 7 YEAR") + assertSerialize(t, INTERVAL(3.5, YEAR), "INTERVAL 3.5 YEAR") +} + +func TestINTERVAL_InvalidUnitType(t *testing.T) { + assertPanicErr(t, func() { INTERVAL("11", HOUR) }, "jet: INTERVAL invalid value type. Numeric type expected") + assertPanicErr(t, func() { INTERVAL("11", YEAR_MONTH) }, "jet: INTERVAL invalid format") + assertPanicErr(t, func() { INTERVAL("11+11", YEAR_MONTH) }, "jet: INTERVAL invalid format") + assertPanicErr(t, func() { INTERVAL(156.11, YEAR_MONTH) }, "jet: INTERNAL invalid value type. String type expected") +} + +func TestINTERVALd(t *testing.T) { + assertDebugSerialize(t, INTERVALd(3*time.Microsecond), "INTERVAL 3 MICROSECOND") + assertDebugSerialize(t, INTERVALd(-1*time.Microsecond), "INTERVAL -1 MICROSECOND") + + assertDebugSerialize(t, INTERVALd(3*time.Second), "INTERVAL 3 SECOND") + assertDebugSerialize(t, INTERVALd(3*time.Second+4*time.Microsecond), "INTERVAL '03.000004' SECOND_MICROSECOND") + assertDebugSerialize(t, INTERVALd(-1*time.Second), "INTERVAL -1 SECOND") + + assertDebugSerialize(t, INTERVALd(3*time.Minute), "INTERVAL 3 MINUTE") + assertDebugSerialize(t, INTERVALd(3*time.Minute+4*time.Second), "INTERVAL '03:04' MINUTE_SECOND") + assertDebugSerialize(t, INTERVALd(3*time.Minute+4*time.Second+5*time.Microsecond), "INTERVAL '03:04.000005' MINUTE_MICROSECOND") + assertDebugSerialize(t, INTERVALd(-11*time.Minute), "INTERVAL -11 MINUTE") + assertDebugSerialize(t, INTERVALd(-11*time.Minute-22*time.Second), "INTERVAL '-11:22' MINUTE_SECOND") + + assertDebugSerialize(t, INTERVALd(3*time.Hour), "INTERVAL 3 HOUR") + assertDebugSerialize(t, INTERVALd(3*time.Hour+4*time.Minute), "INTERVAL '03:04' HOUR_MINUTE") + assertDebugSerialize(t, INTERVALd(3*time.Hour+4*time.Minute+5*time.Second), "INTERVAL '03:04:05' HOUR_SECOND") + assertDebugSerialize(t, INTERVALd(3*time.Hour+4*time.Minute+5*time.Second+6*time.Millisecond), "INTERVAL '03:04:05.006000' HOUR_MICROSECOND") + assertDebugSerialize(t, INTERVALd(-11*time.Hour), "INTERVAL -11 HOUR") + assertDebugSerialize(t, INTERVALd(-11*time.Hour-22*time.Minute), "INTERVAL '-11:22' HOUR_MINUTE") + + assertDebugSerialize(t, INTERVALd(3*24*time.Hour), "INTERVAL 3 DAY") + assertDebugSerialize(t, INTERVALd(3*24*time.Hour+4*time.Hour), "INTERVAL '3 04' DAY_HOUR") + assertDebugSerialize(t, INTERVALd(3*24*time.Hour+4*time.Hour+5*time.Minute), "INTERVAL '3 04:05' DAY_MINUTE") + assertDebugSerialize(t, INTERVALd(3*24*time.Hour+4*time.Hour+5*time.Minute+6*time.Second), "INTERVAL '3 04:05:06' DAY_SECOND") + assertDebugSerialize(t, INTERVALd(3*24*time.Hour+4*time.Hour+5*time.Minute+6*time.Second+7*time.Microsecond), "INTERVAL '3 04:05:06.000007' DAY_MICROSECOND") + + assertDebugSerialize(t, INTERVALd(-11*24*time.Hour), "INTERVAL -11 DAY") + + assertDebugSerialize(t, INTERVALd(1*time.Hour+2*time.Minute+3*time.Second+345*time.Microsecond), "INTERVAL '01:02:03.000345' HOUR_MICROSECOND") + assertDebugSerialize(t, INTERVALd(-1*(1*time.Hour+2*time.Minute+3*time.Second+345*time.Microsecond)), "INTERVAL '-1:02:03.000345' HOUR_MICROSECOND") +} + +func TestINTERVALe(t *testing.T) { + assertSerialize(t, INTERVALe(table1ColFloat, MICROSECOND), "INTERVAL table1.col_float MICROSECOND") + assertSerialize(t, INTERVALe(table1ColFloat, SECOND), "INTERVAL table1.col_float SECOND") + assertSerialize(t, INTERVALe(table1ColFloat, MINUTE), "INTERVAL table1.col_float MINUTE") + assertSerialize(t, INTERVALe(table1ColFloat, HOUR), "INTERVAL table1.col_float HOUR") + assertSerialize(t, INTERVALe(table1ColFloat, DAY), "INTERVAL table1.col_float DAY") + assertSerialize(t, INTERVALe(table1ColFloat, WEEK), "INTERVAL table1.col_float WEEK") + assertSerialize(t, INTERVALe(table1ColFloat, MONTH), "INTERVAL table1.col_float MONTH") + assertSerialize(t, INTERVALe(table1ColFloat, QUARTER), "INTERVAL table1.col_float QUARTER") + assertSerialize(t, INTERVALe(table1ColFloat, YEAR), "INTERVAL table1.col_float YEAR") +} diff --git a/mysql/literal_test.go b/mysql/literal_test.go index f677d11..09d331e 100644 --- a/mysql/literal_test.go +++ b/mysql/literal_test.go @@ -6,37 +6,37 @@ import ( ) func TestBool(t *testing.T) { - assertClauseSerialize(t, Bool(false), `?`, false) + assertSerialize(t, Bool(false), `?`, false) } func TestInt(t *testing.T) { - assertClauseSerialize(t, Int(11), `?`, int64(11)) + assertSerialize(t, Int(11), `?`, int64(11)) } func TestFloat(t *testing.T) { - assertClauseSerialize(t, Float(12.34), `?`, float64(12.34)) + assertSerialize(t, Float(12.34), `?`, float64(12.34)) } func TestString(t *testing.T) { - assertClauseSerialize(t, String("Some text"), `?`, "Some text") + assertSerialize(t, String("Some text"), `?`, "Some text") } func TestDate(t *testing.T) { - assertClauseSerialize(t, Date(2014, time.January, 2), `CAST(? AS DATE)`, "2014-01-02") - assertClauseSerialize(t, DateT(time.Now()), `CAST(? AS DATE)`) + assertSerialize(t, Date(2014, time.January, 2), `CAST(? AS DATE)`, "2014-01-02") + assertSerialize(t, DateT(time.Now()), `CAST(? AS DATE)`) } func TestTime(t *testing.T) { - assertClauseSerialize(t, Time(10, 15, 30), `CAST(? AS TIME)`, "10:15:30") - assertClauseSerialize(t, TimeT(time.Now()), `CAST(? AS TIME)`) + assertSerialize(t, Time(10, 15, 30), `CAST(? AS TIME)`, "10:15:30") + assertSerialize(t, TimeT(time.Now()), `CAST(? AS TIME)`) } func TestDateTime(t *testing.T) { - assertClauseSerialize(t, DateTime(2010, time.March, 30, 10, 15, 30), `CAST(? AS DATETIME)`, "2010-03-30 10:15:30") - assertClauseSerialize(t, DateTimeT(time.Now()), `CAST(? AS DATETIME)`) + assertSerialize(t, DateTime(2010, time.March, 30, 10, 15, 30), `CAST(? AS DATETIME)`, "2010-03-30 10:15:30") + assertSerialize(t, DateTimeT(time.Now()), `CAST(? AS DATETIME)`) } func TestTimestamp(t *testing.T) { - assertClauseSerialize(t, Timestamp(2010, time.March, 30, 10, 15, 30), `TIMESTAMP(?)`, "2010-03-30 10:15:30") - assertClauseSerialize(t, TimestampT(time.Now()), `TIMESTAMP(?)`) + assertSerialize(t, Timestamp(2010, time.March, 30, 10, 15, 30), `TIMESTAMP(?)`, "2010-03-30 10:15:30") + assertSerialize(t, TimestampT(time.Now()), `TIMESTAMP(?)`) } diff --git a/mysql/table_test.go b/mysql/table_test.go index da45f36..3894378 100644 --- a/mysql/table_test.go +++ b/mysql/table_test.go @@ -12,17 +12,17 @@ func TestJoinNilInputs(t *testing.T) { } func TestINNER_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)). INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 INNER JOIN db.table2 ON (table1.col_int = table2.col_int) INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. INNER_JOIN(table2, table1ColInt.EQ(Int(1))). INNER_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -31,17 +31,17 @@ INNER JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) } func TestLEFT_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)). LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 LEFT JOIN db.table2 ON (table1.col_int = table2.col_int) LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. LEFT_JOIN(table2, table1ColInt.EQ(Int(1))). LEFT_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -50,17 +50,17 @@ LEFT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) } func TestRIGHT_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)). RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int) RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))). RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -69,17 +69,17 @@ RIGHT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) } func TestFULL_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)). FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 FULL JOIN db.table2 ON (table1.col_int = table2.col_int) FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. FULL_JOIN(table2, table1ColInt.EQ(Int(1))). FULL_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -88,11 +88,11 @@ FULL JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) } func TestCROSS_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. CROSS_JOIN(table2), `db.table1 CROSS JOIN db.table2`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. CROSS_JOIN(table2). CROSS_JOIN(table3), `db.table1 diff --git a/mysql/utils_test.go b/mysql/utils_test.go index 5804a07..1cc42f1 100644 --- a/mysql/utils_test.go +++ b/mysql/utils_test.go @@ -58,10 +58,14 @@ var table3 = NewTable( table3ColInt, table3StrCol) -func assertClauseSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { +func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) } +func assertDebugSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { + testutils.AssertDebugClauseSerialize(t, Dialect, clause, query, args...) +} + func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) { testutils.AssertClauseSerializeErr(t, Dialect, clause, errString) } @@ -70,5 +74,6 @@ func assertProjectionSerialize(t *testing.T, projection jet.Projection, query st testutils.AssertProjectionSerialize(t, Dialect, projection, query, args...) } +var assertPanicErr = testutils.AssertPanicErr var assertStatementSql = testutils.AssertStatementSql var assertStatementSqlErr = testutils.AssertStatementSqlErr diff --git a/postgres/cast.go b/postgres/cast.go index e9ec209..0f4b255 100644 --- a/postgres/cast.go +++ b/postgres/cast.go @@ -2,8 +2,9 @@ package postgres import ( "fmt" - "github.com/go-jet/jet/internal/jet" "strconv" + + "github.com/go-jet/jet/internal/jet" ) type cast interface { @@ -32,7 +33,7 @@ type cast interface { AS_TIME() TimeExpression // Cast expression AS text type AS_TEXT() StringExpression - + // Cast expression AS bytea type AS_BYTEA() StringExpression // Cast expression AS time with time timezone type AS_TIMEZ() TimezExpression @@ -40,6 +41,8 @@ type cast interface { AS_TIMESTAMP() TimestampExpression // Cast expression AS timestamp with timezone type AS_TIMESTAMPZ() TimestampzExpression + // Cast expression AS interval type + AS_INTERVAL() IntervalExpression } type castImpl struct { @@ -151,3 +154,8 @@ func (b *castImpl) AS_TIMESTAMP() TimestampExpression { func (b *castImpl) AS_TIMESTAMPZ() TimestampzExpression { return TimestampzExp(b.AS("timestamp with time zone")) } + +// Cast expression AS interval type +func (b *castImpl) AS_INTERVAL() IntervalExpression { + return IntervalExp(b.AS("interval")) +} diff --git a/postgres/cast_test.go b/postgres/cast_test.go index 4537784..e02336a 100644 --- a/postgres/cast_test.go +++ b/postgres/cast_test.go @@ -5,60 +5,67 @@ import ( ) func TestExpressionCAST_AS(t *testing.T) { - assertClauseSerialize(t, CAST(String("test")).AS("text"), `$1::text`, "test") + assertSerialize(t, CAST(String("test")).AS("text"), `$1::text`, "test") } func TestExpressionCAST_AS_BOOL(t *testing.T) { - assertClauseSerialize(t, CAST(Int(1)).AS_BOOL(), "$1::boolean", int64(1)) - assertClauseSerialize(t, CAST(table2Col3).AS_BOOL(), "table2.col3::boolean") - assertClauseSerialize(t, CAST(table2Col3.ADD(table2Col3)).AS_BOOL(), "(table2.col3 + table2.col3)::boolean") + assertSerialize(t, CAST(Int(1)).AS_BOOL(), "$1::boolean", int64(1)) + assertSerialize(t, CAST(table2Col3).AS_BOOL(), "table2.col3::boolean") + assertSerialize(t, CAST(table2Col3.ADD(table2Col3)).AS_BOOL(), "(table2.col3 + table2.col3)::boolean") } func TestExpressionCAST_AS_SMALLINT(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_SMALLINT(), "table2.col3::smallint") + assertSerialize(t, CAST(table2Col3).AS_SMALLINT(), "table2.col3::smallint") } func TestExpressionCAST_AS_INTEGER(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_INTEGER(), "table2.col3::integer") + assertSerialize(t, CAST(table2Col3).AS_INTEGER(), "table2.col3::integer") } func TestExpressionCAST_AS_BIGINT(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_BIGINT(), "table2.col3::bigint") + assertSerialize(t, CAST(table2Col3).AS_BIGINT(), "table2.col3::bigint") } func TestExpressionCAST_AS_NUMERIC(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_NUMERIC(11, 11), "table2.col3::numeric(11, 11)") - assertClauseSerialize(t, CAST(table2Col3).AS_NUMERIC(11), "table2.col3::numeric(11)") + assertSerialize(t, CAST(table2Col3).AS_NUMERIC(11, 11), "table2.col3::numeric(11, 11)") + assertSerialize(t, CAST(table2Col3).AS_NUMERIC(11), "table2.col3::numeric(11)") } func TestExpressionCAST_AS_REAL(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_REAL(), "table2.col3::real") + assertSerialize(t, CAST(table2Col3).AS_REAL(), "table2.col3::real") } func TestExpressionCAST_AS_DOUBLE(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_DOUBLE(), "table2.col3::double precision") + assertSerialize(t, CAST(table2Col3).AS_DOUBLE(), "table2.col3::double precision") } func TestExpressionCAST_AS_TEXT(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_TEXT(), "table2.col3::text") + assertSerialize(t, CAST(table2Col3).AS_TEXT(), "table2.col3::text") } func TestExpressionCAST_AS_DATE(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_DATE(), "table2.col3::date") + assertSerialize(t, CAST(table2Col3).AS_DATE(), "table2.col3::date") } func TestExpressionCAST_AS_TIME(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_TIME(), "table2.col3::time without time zone") + assertSerialize(t, CAST(table2Col3).AS_TIME(), "table2.col3::time without time zone") } func TestExpressionCAST_AS_TIMEZ(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_TIMEZ(), "table2.col3::time with time zone") + assertSerialize(t, CAST(table2Col3).AS_TIMEZ(), "table2.col3::time with time zone") } func TestExpressionCAST_AS_TIMESTAMP(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_TIMESTAMP(), "table2.col3::timestamp without time zone") + assertSerialize(t, CAST(table2Col3).AS_TIMESTAMP(), "table2.col3::timestamp without time zone") } func TestExpressionCAST_AS_TIMESTAMPZ(t *testing.T) { - assertClauseSerialize(t, CAST(table2Col3).AS_TIMESTAMPZ(), "table2.col3::timestamp with time zone") + assertSerialize(t, CAST(table2Col3).AS_TIMESTAMPZ(), "table2.col3::timestamp with time zone") +} + +func TestExpressionCAST_AS_INTERVAL(t *testing.T) { + assertSerialize(t, CAST(table2ColTimez).AS_INTERVAL(), "table2.col_timez::interval") + assertSerialize(t, CAST(Time(20, 11, 10)).AS_INTERVAL(), "$1::time without time zone::interval", "20:11:10") + assertSerialize(t, table2ColDate.SUB(CAST(Time(20, 11, 10)).AS_INTERVAL()), + "(table2.col_date - $1::time without time zone::interval)", "20:11:10") } diff --git a/postgres/dialect.go b/postgres/dialect.go index 114e5a6..c1e8c0b 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -29,7 +29,7 @@ func newDialect() jet.Dialect { return jet.NewDialect(dialectParams) } -func postgresCAST(expressions ...jet.Expression) jet.SerializeFunc { +func postgresCAST(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") @@ -54,7 +54,7 @@ func postgresCAST(expressions ...jet.Expression) jet.SerializeFunc { } } -func postgresREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { +func postgresREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") @@ -80,7 +80,7 @@ func postgresREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc } } -func postgresNOTREGEXPLIKEoperator(expressions ...jet.Expression) jet.SerializeFunc { +func postgresNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { panic("jet: invalid number of expressions for operator") diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index b6061b7..f53587e 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -3,21 +3,21 @@ package postgres import "testing" func TestString_REGEXP_LIKE_operator(t *testing.T) { - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 ~* table2.col_str)") - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 ~* $1)", "JOHN") - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 ~* $1)", "JOHN") - assertClauseSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 ~ $1)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 ~* table2.col_str)") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 ~* $1)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 ~* $1)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 ~ $1)", "JOHN") } func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) { - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 !~* table2.col_str)") - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 !~* $1)", "JOHN") - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 !~* $1)", "JOHN") - assertClauseSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 !~ $1)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 !~* table2.col_str)") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 !~* $1)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 !~* $1)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 !~ $1)", "JOHN") } func TestExists(t *testing.T) { - assertClauseSerialize(t, EXISTS( + assertSerialize(t, EXISTS( table2. SELECT(Int(1)). WHERE(table1Col1.EQ(table2Col3)), @@ -27,17 +27,31 @@ func TestExists(t *testing.T) { FROM db.table2 WHERE table1.col1 = table2.col3 ))`, int64(1)) + + assertSerialize(t, EXISTS( + SELECT(Int(1)), + ).EQ(Bool(true)), + `((EXISTS ( + SELECT $1 +)) = $2)`, int64(1), true) + + assertProjectionSerialize(t, EXISTS( + SELECT(Int(1)), + ).AS("exists"), + `(EXISTS ( + SELECT $1 +)) AS "exists"`, int64(1)) } func TestIN(t *testing.T) { - assertClauseSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), + assertSerialize(t, Float(1.11).IN(table1.SELECT(table1Col1)), `($1 IN (( SELECT table1.col1 AS "table1.col1" FROM db.table1 )))`, float64(1.11)) - assertClauseSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), + assertSerialize(t, ROW(Int(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)), `(ROW($1, table1.col1) IN (( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" @@ -47,13 +61,13 @@ func TestIN(t *testing.T) { func TestNOT_IN(t *testing.T) { - assertClauseSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), + assertSerialize(t, Float(1.11).NOT_IN(table1.SELECT(table1Col1)), `($1 NOT IN (( SELECT table1.col1 AS "table1.col1" FROM db.table1 )))`, float64(1.11)) - assertClauseSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), + assertSerialize(t, ROW(Int(12), table1Col1).NOT_IN(table2.SELECT(table2Col3, table3Col1)), `(ROW($1, table1.col1) NOT IN (( SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" diff --git a/postgres/interval.go b/postgres/interval.go new file mode 100644 index 0000000..de659e7 --- /dev/null +++ b/postgres/interval.go @@ -0,0 +1,153 @@ +package postgres + +import ( + "fmt" + "github.com/go-jet/jet/internal/jet" + "github.com/go-jet/jet/internal/utils" + "strconv" + "strings" + "time" +) + +type quantityAndUnit = float64 + +// Interval unit types +const ( + YEAR quantityAndUnit = 123456789 + iota + MONTH + WEEK + DAY + HOUR + MINUTE + SECOND + MILLISECOND + MICROSECOND + DECADE + CENTURY + MILLENNIUM +) + +type intervalExpressionImpl struct { + jet.Interval + jet.ExpressionInterfaceImpl +} + +// IntervalExpression is representation of postgres INTERVAL +type IntervalExpression interface { + jet.IsInterval + jet.Expression +} + +// INTERVAL creates new interval expression from the list of quantity-unit pairs. +// For example: INTERVAL(1, DAY, 3, MINUTE) +func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { + if len(quantityAndUnit)%2 != 0 { + panic("jet: invalid number of quantity and unit fields") + } + + fields := []string{} + + for i := 0; i < len(quantityAndUnit); i += 2 { + quantity := strconv.FormatFloat(float64(quantityAndUnit[i]), 'f', -1, 64) + unitString := unitToString(quantityAndUnit[i+1]) + fields = append(fields, quantity+" "+unitString) + } + + intervalStr := fmt.Sprintf("'%s'", strings.Join(fields, " ")) + + newInterval := &intervalExpressionImpl{ + Interval: jet.NewInterval(jet.Raw(intervalStr)), + } + + newInterval.ExpressionInterfaceImpl.Parent = newInterval + + return newInterval +} + +// INTERVALd creates interval expression from duration +func INTERVALd(duration time.Duration) IntervalExpression { + days, hours, minutes, seconds, microseconds := utils.ExtractDateTimeComponents(duration) + + quantityAndUnits := []quantityAndUnit{} + + if days > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(days)) + quantityAndUnits = append(quantityAndUnits, DAY) + } + + if hours > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(hours)) + quantityAndUnits = append(quantityAndUnits, HOUR) + } + + if minutes > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(minutes)) + quantityAndUnits = append(quantityAndUnits, MINUTE) + } + + if seconds > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(seconds)) + quantityAndUnits = append(quantityAndUnits, SECOND) + } + + if microseconds > 0 { + quantityAndUnits = append(quantityAndUnits, quantityAndUnit(microseconds)) + quantityAndUnits = append(quantityAndUnits, MICROSECOND) + } + + if len(quantityAndUnits) == 0 { + return INTERVAL(0, MICROSECOND) + } + + return INTERVAL(quantityAndUnits...) +} + +func unitToString(unit quantityAndUnit) string { + switch unit { + case YEAR: + return "YEAR" + case MONTH: + return "MONTH" + case WEEK: + return "WEEK" + case DAY: + return "DAY" + case HOUR: + return "HOUR" + case MINUTE: + return "MINUTE" + case SECOND: + return "SECOND" + case MILLISECOND: + return "MILLISECOND" + case MICROSECOND: + return "MICROSECOND" + case DECADE: + return "DECADE" + case CENTURY: + return "CENTURY" + case MILLENNIUM: + return "MILLENNIUM" + default: + panic("jet: invalid INTERVAL unit type") + } +} + +//---------------------------------------------------// + +type intervalWrapper struct { + jet.IsInterval + Expression +} + +func newIntervalExpressionWrap(expression Expression) IntervalExpression { + intervalWrap := intervalWrapper{Expression: expression} + return &intervalWrap +} + +// IntervalExp is interval expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as interval expression. +// Does not add sql cast to generated sql builder output. +func IntervalExp(expression Expression) IntervalExpression { + return newIntervalExpressionWrap(expression) +} diff --git a/postgres/interval_test.go b/postgres/interval_test.go new file mode 100644 index 0000000..785f1d5 --- /dev/null +++ b/postgres/interval_test.go @@ -0,0 +1,62 @@ +package postgres + +import ( + "testing" + "time" +) + +func TestINTERVAL(t *testing.T) { + assertSerialize(t, INTERVAL(1, YEAR), "INTERVAL '1 YEAR'") + assertSerialize(t, INTERVAL(1, MONTH), "INTERVAL '1 MONTH'") + assertSerialize(t, INTERVAL(1, WEEK), "INTERVAL '1 WEEK'") + assertSerialize(t, INTERVAL(1, DAY), "INTERVAL '1 DAY'") + assertSerialize(t, INTERVAL(1, HOUR), "INTERVAL '1 HOUR'") + assertSerialize(t, INTERVAL(1, MINUTE), "INTERVAL '1 MINUTE'") + assertSerialize(t, INTERVAL(1, SECOND), "INTERVAL '1 SECOND'") + assertSerialize(t, INTERVAL(1, MILLISECOND), "INTERVAL '1 MILLISECOND'") + assertSerialize(t, INTERVAL(1, MICROSECOND), "INTERVAL '1 MICROSECOND'") + assertSerialize(t, INTERVAL(1, DECADE), "INTERVAL '1 DECADE'") + assertSerialize(t, INTERVAL(1, CENTURY), "INTERVAL '1 CENTURY'") + assertSerialize(t, INTERVAL(1, MILLENNIUM), "INTERVAL '1 MILLENNIUM'") + + assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH), "INTERVAL '1 YEAR 10 MONTH'") + assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH, 20, DAY), "INTERVAL '1 YEAR 10 MONTH 20 DAY'") + assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH, 20, DAY, 3, HOUR), "INTERVAL '1 YEAR 10 MONTH 20 DAY 3 HOUR'") + + assertSerialize(t, INTERVAL(1, YEAR).IS_NOT_NULL(), "INTERVAL '1 YEAR' IS NOT NULL") + assertProjectionSerialize(t, INTERVAL(1, YEAR).AS("one year"), `INTERVAL '1 YEAR' AS "one year"`) + + f := 5.2 + assertSerialize(t, INTERVAL(f, YEAR), "INTERVAL '5.2 YEAR'") +} + +func TestINTERVALd(t *testing.T) { + assertSerialize(t, INTERVALd(0), "INTERVAL '0 MICROSECOND'") + assertSerialize(t, INTERVALd(1*time.Microsecond), "INTERVAL '1 MICROSECOND'") + assertSerialize(t, INTERVALd(1*time.Millisecond), "INTERVAL '1000 MICROSECOND'") + assertSerialize(t, INTERVALd(1*time.Second), "INTERVAL '1 SECOND'") + assertSerialize(t, INTERVALd(1*time.Minute), "INTERVAL '1 MINUTE'") + assertSerialize(t, INTERVALd(1*time.Hour), "INTERVAL '1 HOUR'") + assertSerialize(t, INTERVALd(24*time.Hour), "INTERVAL '1 DAY'") + + assertSerialize(t, INTERVALd(24*time.Hour+2*time.Hour+3*time.Minute+4*time.Second+5*time.Microsecond), + "INTERVAL '1 DAY 2 HOUR 3 MINUTE 4 SECOND 5 MICROSECOND'") +} + +func TestINTERVAL_InvalidParams(t *testing.T) { + assertPanicErr(t, func() { INTERVAL(1) }, "jet: invalid number of quantity and unit fields") + assertPanicErr(t, func() { INTERVAL(1, 2) }, "jet: invalid INTERVAL unit type") +} + +func TestIntervalArithmetic(t *testing.T) { + assertSerialize(t, table2ColDate.ADD(INTERVAL(1, HOUR)), "(table2.col_date + INTERVAL '1 HOUR')") + assertSerialize(t, table2ColDate.SUB(INTERVAL(1, HOUR)), "(table2.col_date - INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTime.ADD(INTERVAL(1, HOUR)), "(table2.col_time + INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTime.SUB(INTERVAL(1, HOUR)), "(table2.col_time - INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTimez.ADD(INTERVAL(1, HOUR)), "(table2.col_timez + INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTimez.SUB(INTERVAL(1, HOUR)), "(table2.col_timez - INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTimestamp.ADD(INTERVAL(1, HOUR)), "(table2.col_timestamp + INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTimestamp.SUB(INTERVAL(1, HOUR)), "(table2.col_timestamp - INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTimestampz.ADD(INTERVAL(1, HOUR)), "(table2.col_timestampz + INTERVAL '1 HOUR')") + assertSerialize(t, table2ColTimestampz.SUB(INTERVAL(1, HOUR)), "(table2.col_timestampz - INTERVAL '1 HOUR')") +} diff --git a/postgres/literal_test.go b/postgres/literal_test.go index f30ef01..5206aaa 100644 --- a/postgres/literal_test.go +++ b/postgres/literal_test.go @@ -6,45 +6,45 @@ import ( ) func TestBool(t *testing.T) { - assertClauseSerialize(t, Bool(false), `$1`, false) + assertSerialize(t, Bool(false), `$1`, false) } func TestInt(t *testing.T) { - assertClauseSerialize(t, Int(11), `$1`, int64(11)) + assertSerialize(t, Int(11), `$1`, int64(11)) } func TestFloat(t *testing.T) { - assertClauseSerialize(t, Float(12.34), `$1`, float64(12.34)) + assertSerialize(t, Float(12.34), `$1`, float64(12.34)) } func TestString(t *testing.T) { - assertClauseSerialize(t, String("Some text"), `$1`, "Some text") + assertSerialize(t, String("Some text"), `$1`, "Some text") } func TestDate(t *testing.T) { - assertClauseSerialize(t, Date(2014, time.January, 2), `$1::date`, "2014-01-02") - assertClauseSerialize(t, DateT(time.Now()), `$1::date`) + assertSerialize(t, Date(2014, time.January, 2), `$1::date`, "2014-01-02") + assertSerialize(t, DateT(time.Now()), `$1::date`) } func TestTime(t *testing.T) { - assertClauseSerialize(t, Time(10, 15, 30), `$1::time without time zone`, "10:15:30") - assertClauseSerialize(t, TimeT(time.Now()), `$1::time without time zone`) + assertSerialize(t, Time(10, 15, 30), `$1::time without time zone`, "10:15:30") + assertSerialize(t, TimeT(time.Now()), `$1::time without time zone`) } func TestTimez(t *testing.T) { - assertClauseSerialize(t, Timez(10, 15, 30, 0, "UTC"), + assertSerialize(t, Timez(10, 15, 30, 0, "UTC"), `$1::time with time zone`, "10:15:30 UTC") - assertClauseSerialize(t, TimezT(time.Now()), `$1::time with time zone`) + assertSerialize(t, TimezT(time.Now()), `$1::time with time zone`) } func TestTimestamp(t *testing.T) { - assertClauseSerialize(t, Timestamp(2010, time.March, 30, 10, 15, 30), + assertSerialize(t, Timestamp(2010, time.March, 30, 10, 15, 30), `$1::timestamp without time zone`, "2010-03-30 10:15:30") - assertClauseSerialize(t, TimestampT(time.Now()), `$1::timestamp without time zone`) + assertSerialize(t, TimestampT(time.Now()), `$1::timestamp without time zone`) } func TestTimestampz(t *testing.T) { - assertClauseSerialize(t, Timestampz(2010, time.March, 30, 10, 15, 30, 0, "UTC"), + assertSerialize(t, Timestampz(2010, time.March, 30, 10, 15, 30, 0, "UTC"), `$1::timestamp with time zone`, "2010-03-30 10:15:30 UTC") - assertClauseSerialize(t, TimestampzT(time.Now()), `$1::timestamp with time zone`) + assertSerialize(t, TimestampzT(time.Now()), `$1::timestamp with time zone`) } diff --git a/postgres/table_test.go b/postgres/table_test.go index 6573b02..43aa096 100644 --- a/postgres/table_test.go +++ b/postgres/table_test.go @@ -12,17 +12,17 @@ func TestJoinNilInputs(t *testing.T) { } func TestINNER_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)). INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 INNER JOIN db.table2 ON (table1.col_int = table2.col_int) INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. INNER_JOIN(table2, table1ColInt.EQ(Int(1))). INNER_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -31,17 +31,17 @@ INNER JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2)) } func TestLEFT_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)). LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 LEFT JOIN db.table2 ON (table1.col_int = table2.col_int) LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. LEFT_JOIN(table2, table1ColInt.EQ(Int(1))). LEFT_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -50,17 +50,17 @@ LEFT JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2)) } func TestRIGHT_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)). RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int) RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))). RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -69,17 +69,17 @@ RIGHT JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2)) } func TestFULL_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)), `db.table1 FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)). FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)), `db.table1 FULL JOIN db.table2 ON (table1.col_int = table2.col_int) FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. FULL_JOIN(table2, table1ColInt.EQ(Int(1))). FULL_JOIN(table3, table1ColInt.EQ(Int(2))), `db.table1 @@ -88,11 +88,11 @@ FULL JOIN db.table3 ON (table1.col_int = $2)`, int64(1), int64(2)) } func TestCROSS_JOIN(t *testing.T) { - assertClauseSerialize(t, table1. + assertSerialize(t, table1. CROSS_JOIN(table2), `db.table1 CROSS JOIN db.table2`) - assertClauseSerialize(t, table1. + assertSerialize(t, table1. CROSS_JOIN(table2). CROSS_JOIN(table3), `db.table1 diff --git a/postgres/utils_test.go b/postgres/utils_test.go index c65d5b6..4a80954 100644 --- a/postgres/utils_test.go +++ b/postgres/utils_test.go @@ -1,9 +1,10 @@ package postgres import ( + "testing" + "github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/testutils" - "testing" ) var table1Col1 = IntegerColumn("col1") @@ -70,7 +71,7 @@ var table3 = NewTable( table3ColInt, table3StrCol) -func assertClauseSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { +func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) } @@ -84,3 +85,4 @@ func assertProjectionSerialize(t *testing.T, projection jet.Projection, query st var assertStatementSql = testutils.AssertStatementSql var assertStatementSqlErr = testutils.AssertStatementSqlErr +var assertPanicErr = testutils.AssertPanicErr diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index f0f7406..361bb8d 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -1,19 +1,20 @@ package mysql import ( - "fmt" + "testing" + "time" + + "github.com/google/uuid" + "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/view" "github.com/go-jet/jet/tests/testdata/results/common" - "github.com/google/uuid" - "time" . "github.com/go-jet/jet/mysql" "gotest.tools/assert" - "testing" ) func TestAllTypes(t *testing.T) { @@ -506,15 +507,11 @@ func TestStringOperators(t *testing.T) { REGEXP_LIKE(AllTypes.Text, String("aba"), "i"), }...) } - //_, args, _ := query.Sql() - - //fmt.Println(query.Sql()) - //fmt.Println(args[15]) query := SELECT(projectionList[0], projectionList[1:]...). FROM(AllTypes) - fmt.Println(query.DebugSql()) + //fmt.Println(query.DebugSql()) dest := []struct{}{} err := query.Query(db, &dest) @@ -555,32 +552,49 @@ func TestTimeExpressions(t *testing.T) { AllTypes.Time.GT_EQ(AllTypes.Time), AllTypes.Time.GT_EQ(Time(14, 26, 36)), + AllTypes.Time.ADD(INTERVAL(10, MINUTE)), + AllTypes.Time.ADD(INTERVALe(AllTypes.Integer, MINUTE)), + AllTypes.Time.ADD(INTERVALd(3*time.Hour)), + + AllTypes.Time.SUB(INTERVAL(20, MINUTE)), + AllTypes.Time.SUB(INTERVALe(AllTypes.SmallInt, MINUTE)), + AllTypes.Time.SUB(INTERVALd(3*time.Minute)), + + AllTypes.Time.ADD(INTERVAL(20, MINUTE)).SUB(INTERVAL(11, HOUR)), + CURRENT_TIME(), CURRENT_TIME(3), ) - //fmt.Println(query.Sql()) + //fmt.Println(query.DebugSql()) - testutils.AssertStatementSql(t, query, ` -SELECT CAST(? AS TIME), + testutils.AssertDebugStatementSql(t, query, ` +SELECT CAST('20:34:58' AS TIME), all_types.time = all_types.time, - all_types.time = CAST(? AS TIME), - all_types.time = CAST(? AS TIME), - all_types.time = CAST(? AS TIME), + all_types.time = CAST('23:06:06' AS TIME), + all_types.time = CAST('22:06:06.011' AS TIME), + all_types.time = CAST('21:06:06.011111' AS TIME), all_types.time_ptr != all_types.time, - all_types.time_ptr != CAST(? AS TIME), + all_types.time_ptr != CAST('20:16:06' AS TIME), NOT(all_types.time <=> all_types.time), - NOT(all_types.time <=> CAST(? AS TIME)), + NOT(all_types.time <=> CAST('19:26:06' AS TIME)), all_types.time <=> all_types.time, - all_types.time <=> CAST(? AS TIME), + all_types.time <=> CAST('18:36:06' AS TIME), all_types.time < all_types.time, - all_types.time < CAST(? AS TIME), + all_types.time < CAST('17:46:06' AS TIME), all_types.time <= all_types.time, - all_types.time <= CAST(? AS TIME), + all_types.time <= CAST('16:56:56' AS TIME), all_types.time > all_types.time, - all_types.time > CAST(? AS TIME), + all_types.time > CAST('15:16:46' AS TIME), all_types.time >= all_types.time, - all_types.time >= CAST(? AS TIME), + all_types.time >= CAST('14:26:36' AS TIME), + all_types.time + INTERVAL 10 MINUTE, + all_types.time + INTERVAL all_types.integer MINUTE, + all_types.time + INTERVAL 3 HOUR, + all_types.time - INTERVAL 20 MINUTE, + all_types.time - INTERVAL all_types.small_int MINUTE, + all_types.time - INTERVAL 3 MINUTE, + (all_types.time + INTERVAL 20 MINUTE) - INTERVAL 11 HOUR, CURRENT_TIME, CURRENT_TIME(3) FROM test_sample.all_types; @@ -621,10 +635,18 @@ func TestDateExpressions(t *testing.T) { AllTypes.Date.GT_EQ(AllTypes.Date), AllTypes.Date.GT_EQ(Date(2019, 2, 3)), + AllTypes.Date.ADD(INTERVAL("10:20.000100", MINUTE_MICROSECOND)), + AllTypes.Date.ADD(INTERVALe(AllTypes.BigInt, MINUTE)), + AllTypes.Date.ADD(INTERVALd(15*time.Hour)), + + AllTypes.Date.SUB(INTERVAL(20, MINUTE)), + AllTypes.Date.SUB(INTERVALe(AllTypes.SmallInt, MINUTE)), + AllTypes.Date.SUB(INTERVALd(3*time.Minute)), + CURRENT_DATE(), ) - //fmt.Println(query.Sql()) + //fmt.Println(query.DebugSql()) testutils.AssertStatementSql(t, query, ` SELECT CAST(? AS DATE), @@ -644,6 +666,12 @@ SELECT CAST(? AS DATE), all_types.date > CAST(? AS DATE), all_types.date >= all_types.date, all_types.date >= CAST(? AS DATE), + all_types.date + INTERVAL ? MINUTE_MICROSECOND, + all_types.date + INTERVAL all_types.big_int MINUTE, + all_types.date + INTERVAL 15 HOUR, + all_types.date - INTERVAL 20 MINUTE, + all_types.date - INTERVAL all_types.small_int MINUTE, + all_types.date - INTERVAL 3 MINUTE, CURRENT_DATE FROM test_sample.all_types; `) @@ -683,11 +711,19 @@ func TestDateTimeExpressions(t *testing.T) { AllTypes.DateTime.GT_EQ(AllTypes.DateTime), AllTypes.DateTime.GT_EQ(dateTime), + AllTypes.DateTime.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), + AllTypes.DateTime.ADD(INTERVALe(AllTypes.BigInt, HOUR)), + AllTypes.DateTime.ADD(INTERVALd(2*time.Hour)), + + AllTypes.DateTime.SUB(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), + AllTypes.DateTime.SUB(INTERVALe(AllTypes.IntegerPtr, HOUR)), + AllTypes.DateTime.SUB(INTERVALd(3*time.Hour)), + NOW(), NOW(1), ) - //fmt.Println(query.DebugSql()) + //Println(query.DebugSql()) testutils.AssertDebugStatementSql(t, query, ` SELECT all_types.date_time = all_types.date_time, @@ -706,6 +742,12 @@ SELECT all_types.date_time = all_types.date_time, all_types.date_time > CAST('2019-06-06 10:02:46' AS DATETIME), all_types.date_time >= all_types.date_time, all_types.date_time >= CAST('2019-06-06 10:02:46' AS DATETIME), + all_types.date_time + INTERVAL '05:10:20.000100' HOUR_MICROSECOND, + all_types.date_time + INTERVAL all_types.big_int HOUR, + all_types.date_time + INTERVAL 2 HOUR, + all_types.date_time - INTERVAL '05:10:20.000100' HOUR_MICROSECOND, + all_types.date_time - INTERVAL all_types.integer_ptr HOUR, + all_types.date_time - INTERVAL 3 HOUR, NOW(), NOW(1) FROM test_sample.all_types; @@ -746,6 +788,14 @@ func TestTimestampExpressions(t *testing.T) { AllTypes.Timestamp.GT_EQ(AllTypes.Timestamp), AllTypes.Timestamp.GT_EQ(timestamp), + AllTypes.Timestamp.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), + AllTypes.Timestamp.ADD(INTERVALe(AllTypes.BigInt, HOUR)), + AllTypes.Timestamp.ADD(INTERVALd(2*time.Hour)), + + AllTypes.Timestamp.SUB(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), + AllTypes.Timestamp.SUB(INTERVALe(AllTypes.IntegerPtr, HOUR)), + AllTypes.Timestamp.SUB(INTERVALd(3*time.Hour)), + CURRENT_TIMESTAMP(), CURRENT_TIMESTAMP(2), ) @@ -769,6 +819,12 @@ SELECT all_types.timestamp = all_types.timestamp, all_types.timestamp > TIMESTAMP('2019-06-06 10:02:46'), all_types.timestamp >= all_types.timestamp, all_types.timestamp >= TIMESTAMP('2019-06-06 10:02:46'), + all_types.timestamp + INTERVAL '05:10:20.000100' HOUR_MICROSECOND, + all_types.timestamp + INTERVAL all_types.big_int HOUR, + all_types.timestamp + INTERVAL 2 HOUR, + all_types.timestamp - INTERVAL '05:10:20.000100' HOUR_MICROSECOND, + all_types.timestamp - INTERVAL all_types.integer_ptr HOUR, + all_types.timestamp - INTERVAL 3 HOUR, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP(2) FROM test_sample.all_types; @@ -853,6 +909,60 @@ LIMIT ?; } +func TestINTERVAL(t *testing.T) { + query := SELECT( + Date(2000, 2, 10).ADD(INTERVAL(1, MICROSECOND)). + EQ(Timestamp(2000, 2, 10, 0, 0, 0, 1*time.Microsecond)), + Date(2000, 2, 10).SUB(INTERVAL(2, SECOND)), + Date(2000, 2, 10).ADD(INTERVAL(3, MINUTE)), + Date(2000, 2, 10).SUB(INTERVAL(4, HOUR)), + Date(2000, 2, 10).ADD(INTERVAL(5, DAY)), + Date(2000, 2, 10).SUB(INTERVAL(6, MONTH)), + Date(2000, 2, 10).ADD(INTERVAL(7, YEAR)), + Date(2000, 2, 10).ADD(INTERVAL(-7, YEAR)), + Date(2000, 2, 10).ADD(INTERVAL("20.0000100", SECOND_MICROSECOND)), + Date(2000, 2, 10).SUB(INTERVAL("02:20.0000100", MINUTE_MICROSECOND)), + Date(2000, 2, 10).SUB(INTERVAL("11:02:20.0000100", HOUR_MICROSECOND)), + Date(2000, 2, 10).SUB(INTERVAL("100 11:02:20.0000100", DAY_MICROSECOND)), + Date(2000, 2, 10).SUB(INTERVAL("11:02", MINUTE_SECOND)), + Date(2000, 2, 10).SUB(INTERVAL("11:02:20", HOUR_SECOND)), + Date(2000, 2, 10).SUB(INTERVAL("11:02", HOUR_MINUTE)), + Date(2000, 2, 10).SUB(INTERVAL("11 02:03:04", DAY_SECOND)), + Date(2000, 2, 10).SUB(INTERVAL("11 02:03", DAY_MINUTE)), + Date(2000, 2, 10).SUB(INTERVAL("11 2", DAY_HOUR)), + Date(2000, 2, 10).SUB(INTERVAL("2000-2", YEAR_MONTH)), + + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, MICROSECOND)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, SECOND)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, MINUTE)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, HOUR)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, DAY)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, WEEK)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, MONTH)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, QUARTER)), + Date(2000, 2, 10).SUB(INTERVALe(AllTypes.IntegerPtr, YEAR)), + + Date(2000, 2, 10).SUB(INTERVALd(3*time.Microsecond)), + Date(2000, 2, 10).SUB(INTERVALd(-3*time.Microsecond)), + Date(2000, 2, 10).SUB(INTERVALd(3*time.Second)), + Date(2000, 2, 10).SUB(INTERVALd(3*time.Second+4*time.Microsecond)), + Date(2000, 2, 10).SUB(INTERVALd(3*time.Minute+4*time.Second+5*time.Microsecond)), + Date(2000, 2, 10).SUB(INTERVALd(3*time.Hour+4*time.Minute+5*time.Second+6*time.Microsecond)), + Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour+3*time.Hour+4*time.Minute+5*time.Second+6*time.Microsecond)), + Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour+3*time.Hour+4*time.Minute+5*time.Second)), + Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour+3*time.Hour+4*time.Minute)), + Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour+3*time.Hour)), + Date(2000, 2, 10).SUB(INTERVALd(2*24*time.Hour)), + Date(2000, 2, 10).SUB(INTERVALd(3*time.Hour)), + Date(2000, 2, 10).SUB(INTERVALd(1*time.Hour+2*time.Minute+3*time.Second+345*time.Microsecond)), + ).FROM(AllTypes) + + //fmt.Println(query.DebugSql()) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) +} + var allTypesJson = ` [ { diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index f404e29..952eb63 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -1,7 +1,6 @@ package mysql import ( - "fmt" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/mysql" "github.com/go-jet/jet/tests/.gentestdata/mysql/dvds/enum" @@ -607,7 +606,7 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date; ).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate). WHERE(Payment.PaymentID.LT(Int(10))) - fmt.Println(query.Sql()) + //fmt.Println(query.Sql()) testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) @@ -643,7 +642,7 @@ ORDER BY payment.customer_id; WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)). ORDER_BY(Payment.CustomerID) - fmt.Println(query.Sql()) + //fmt.Println(query.Sql()) testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index f2e8fbe..7b9d1d6 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -1,17 +1,18 @@ package postgres import ( + "testing" + "time" + + "github.com/google/uuid" + "gotest.tools/assert" + "github.com/go-jet/jet/internal/testutils" - "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/view" "github.com/go-jet/jet/tests/testdata/results/common" - "github.com/google/uuid" - "gotest.tools/assert" - "testing" - "time" ) func TestAllTypesSelect(t *testing.T) { @@ -134,22 +135,23 @@ LIMIT $5; func TestExpressionCast(t *testing.T) { query := AllTypes.SELECT( - postgres.CAST(Int(150)).AS_CHAR(12).AS("char12"), - postgres.CAST(String("TRUE")).AS_BOOL(), - postgres.CAST(String("111")).AS_SMALLINT(), - postgres.CAST(String("111")).AS_INTEGER(), - postgres.CAST(String("111")).AS_BIGINT(), - postgres.CAST(String("11.23")).AS_NUMERIC(30, 10), - postgres.CAST(String("11.23")).AS_NUMERIC(30), - postgres.CAST(String("11.23")).AS_NUMERIC(), - postgres.CAST(String("11.23")).AS_REAL(), - postgres.CAST(String("11.23")).AS_DOUBLE(), - postgres.CAST(Int(234)).AS_TEXT(), - postgres.CAST(String("1/8/1999")).AS_DATE(), - postgres.CAST(String("04:05:06.789")).AS_TIME(), - postgres.CAST(String("04:05:06 PST")).AS_TIMEZ(), - postgres.CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(), - postgres.CAST(String("January 8 04:05:06 1999 PST")).AS_TIMESTAMPZ(), + CAST(Int(150)).AS_CHAR(12).AS("char12"), + CAST(String("TRUE")).AS_BOOL(), + CAST(String("111")).AS_SMALLINT(), + CAST(String("111")).AS_INTEGER(), + CAST(String("111")).AS_BIGINT(), + CAST(String("11.23")).AS_NUMERIC(30, 10), + CAST(String("11.23")).AS_NUMERIC(30), + CAST(String("11.23")).AS_NUMERIC(), + CAST(String("11.23")).AS_REAL(), + CAST(String("11.23")).AS_DOUBLE(), + CAST(Int(234)).AS_TEXT(), + CAST(String("1/8/1999")).AS_DATE(), + CAST(String("04:05:06.789")).AS_TIME(), + CAST(String("04:05:06 PST")).AS_TIMEZ(), + CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(), + CAST(String("January 8 04:05:06 1999 PST")).AS_TIMESTAMPZ(), + CAST(String("04:05:06")).AS_INTERVAL(), TO_CHAR(AllTypes.Timestamp, String("HH12:MI:SS")), TO_CHAR(AllTypes.Integer, String("999")), @@ -359,7 +361,7 @@ func TestFloatOperators(t *testing.T) { TRUNC(ABSf(AllTypes.Decimal), Int(2)).AS("abs"), TRUNC(POWER(AllTypes.Decimal, Float(2.1)), Int(2)).AS("power"), TRUNC(SQRT(AllTypes.Decimal), Int(2)).AS("sqrt"), - TRUNC(postgres.CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int(2)).AS("cbrt"), + TRUNC(CAST(CBRT(AllTypes.Decimal)).AS_DECIMAL(), Int(2)).AS("cbrt"), CEIL(AllTypes.Real).AS("ceil"), FLOOR(AllTypes.Real).AS("floor"), @@ -606,6 +608,19 @@ func TestTimeExpression(t *testing.T) { AllTypes.Time.GT_EQ(AllTypes.Time), AllTypes.Time.GT_EQ(Time(23, 6, 6, 1)), + AllTypes.Date.ADD(INTERVAL(1, HOUR)), + AllTypes.Date.SUB(INTERVAL(1, MINUTE)), + AllTypes.Time.ADD(INTERVAL(1, HOUR)), + AllTypes.Time.SUB(INTERVAL(1, MINUTE)), + AllTypes.Timez.ADD(INTERVAL(1, HOUR)), + AllTypes.Timez.SUB(INTERVAL(1, MINUTE)), + AllTypes.Timestamp.ADD(INTERVAL(1, HOUR)), + AllTypes.Timestamp.SUB(INTERVAL(1, MINUTE)), + AllTypes.Timestampz.ADD(INTERVAL(1, HOUR)), + AllTypes.Timestampz.SUB(INTERVAL(1, MINUTE)), + + AllTypes.Date.SUB(CAST(String("04:05:06")).AS_INTERVAL()), + CURRENT_DATE(), CURRENT_TIME(), CURRENT_TIME(2), @@ -626,6 +641,44 @@ func TestTimeExpression(t *testing.T) { assert.NilError(t, err) } +func TestInterval(t *testing.T) { + stmt := SELECT( + INTERVAL(1, YEAR), + INTERVAL(1, MONTH), + INTERVAL(1, WEEK), + INTERVAL(1, DAY), + INTERVAL(1, HOUR), + INTERVAL(1, MINUTE), + INTERVAL(1, SECOND), + INTERVAL(1, MILLISECOND), + INTERVAL(1, MICROSECOND), + INTERVAL(1, DECADE), + INTERVAL(1, CENTURY), + INTERVAL(1, MILLENNIUM), + + INTERVAL(1, YEAR, 10, MONTH), + INTERVAL(1, YEAR, 10, MONTH, 20, DAY), + INTERVAL(1, YEAR, 10, MONTH, 20, DAY, 3, HOUR), + + INTERVAL(1, YEAR).IS_NOT_NULL(), + INTERVAL(1, YEAR).AS("one year"), + + INTERVALd(0), + INTERVALd(1*time.Microsecond), + INTERVALd(1*time.Millisecond), + INTERVALd(1*time.Second), + INTERVALd(1*time.Minute), + INTERVALd(1*time.Hour), + INTERVALd(24*time.Hour), + INTERVALd(24*time.Hour+2*time.Hour+3*time.Minute+4*time.Second+5*time.Microsecond), + ) + + //fmt.Println(stmt.DebugSql()) + + err := stmt.Query(db, &struct{}{}) + assert.NilError(t, err) +} + func TestSubQueryColumnReference(t *testing.T) { type expected struct { diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 3cd2fba..2ac6c36 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -158,7 +158,7 @@ LIMIT 12; ). LIMIT(12) - fmt.Println(query.DebugSql()) + //fmt.Println(query.DebugSql()) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) @@ -1686,7 +1686,7 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date; ).GROUP_BY(Payment.Amount, Payment.CustomerID, Payment.PaymentDate). WHERE(Payment.PaymentID.LT(Int(10))) - fmt.Println(query.Sql()) + //fmt.Println(query.Sql()) testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) @@ -1722,7 +1722,7 @@ ORDER BY payment.customer_id; WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)). ORDER_BY(Payment.CustomerID) - fmt.Println(query.Sql()) + //fmt.Println(query.Sql()) testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) @@ -1748,12 +1748,6 @@ func TestSimpleView(t *testing.T) { FilmInfo string } - //sql, args := query.Sql() - // - //row := db.QueryRow(sql, args...) - // - //row.Scan() - var dest []ActorInfo err := query.Query(db, &dest)