diff --git a/internal/jet/array_expression.go b/internal/jet/array_expression.go index b054baa..8d4cf02 100644 --- a/internal/jet/array_expression.go +++ b/internal/jet/array_expression.go @@ -71,7 +71,7 @@ func (a arrayInterfaceImpl[E]) CONCAT_ELEMENT(rhs E) Array[E] { } func (a arrayInterfaceImpl[E]) AT(at IntegerExpression) E { - return CastToArrayElemType[E](a.parent, CustomExpression(a.parent, Token("["), at, Token("]"))) + return CastToArrayElemType[E](a.parent, AtomicCustomExpression(a.parent, Token("["), at, Token("]"))) } type arrayExpressionWrapper[E Expression] struct { @@ -126,12 +126,8 @@ func CastToArrayElemType[E Expression](array Array[E], exp Expression) E { // ARRAY constructor builds an array value using list of expressions. func ARRAY[E Expression](elems ...E) Array[E] { - var args = make([]Serializer, len(elems)) - for i, each := range elems { - args[i] = each - } - return ArrayExp[E](CustomExpression(Token("ARRAY["), ListSerializer{ - Serializers: args, + return ArrayExp[E](AtomicCustomExpression(Token("ARRAY["), ListSerializer{ + Serializers: ToSerializerList(elems), Separator: ",", }, Token("]"))) } diff --git a/internal/jet/bool_expression_test.go b/internal/jet/bool_expression_test.go index c6cdbbb..f4c1770 100644 --- a/internal/jet/bool_expression_test.go +++ b/internal/jet/bool_expression_test.go @@ -6,6 +6,7 @@ import ( func TestBoolExpressionEQ(t *testing.T) { assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.col_bool = table2.col_bool)") + assertClauseSerialize(t, Bool(true).EQ(String("foo").IS_NOT_NULL()), `($1 = ($2 IS NOT NULL))`, true, "foo") } func TestBoolExpressionNOT_EQ(t *testing.T) { @@ -24,31 +25,31 @@ func TestBoolExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { } func TestBoolExpressionIS_TRUE(t *testing.T) { - assertClauseSerialize(t, table1ColBool.IS_TRUE(), "table1.col_bool IS TRUE") + assertClauseSerialize(t, table1ColBool.IS_TRUE(), "(table1.col_bool IS TRUE)") assertClauseSerialize(t, (Int(2).EQ(table1ColInt)).IS_TRUE(), - `($1 = table1.col_int) IS TRUE`, int64(2)) + `(($1 = table1.col_int) IS TRUE)`, int64(2)) assertClauseSerialize(t, (Int(2).EQ(table1ColInt)).IS_TRUE().AND(Int(4).EQ(table2ColInt)), - `(($1 = table1.col_int) IS TRUE AND ($2 = table2.col_int))`, int64(2), int64(4)) + `((($1 = table1.col_int) IS TRUE) AND ($2 = table2.col_int))`, int64(2), int64(4)) } func TestBoolExpressionIS_NOT_TRUE(t *testing.T) { - assertClauseSerialize(t, table1ColBool.IS_NOT_TRUE(), "table1.col_bool IS NOT TRUE") + assertClauseSerialize(t, table1ColBool.IS_NOT_TRUE(), "(table1.col_bool IS NOT TRUE)") } func TestBoolExpressionIS_FALSE(t *testing.T) { - assertClauseSerialize(t, table1ColBool.IS_FALSE(), "table1.col_bool IS FALSE") + assertClauseSerialize(t, table1ColBool.IS_FALSE(), "(table1.col_bool IS FALSE)") } func TestBoolExpressionIS_NOT_FALSE(t *testing.T) { - assertClauseSerialize(t, table1ColBool.IS_NOT_FALSE(), "table1.col_bool IS NOT FALSE") + assertClauseSerialize(t, table1ColBool.IS_NOT_FALSE(), "(table1.col_bool IS NOT FALSE)") } func TestBoolExpressionIS_UNKNOWN(t *testing.T) { - assertClauseSerialize(t, table1ColBool.IS_UNKNOWN(), "table1.col_bool IS UNKNOWN") + assertClauseSerialize(t, table1ColBool.IS_UNKNOWN(), "(table1.col_bool IS UNKNOWN)") } func TestBoolExpressionIS_NOT_UNKNOWN(t *testing.T) { - assertClauseSerialize(t, table1ColBool.IS_NOT_UNKNOWN(), "table1.col_bool IS NOT UNKNOWN") + assertClauseSerialize(t, table1ColBool.IS_NOT_UNKNOWN(), "(table1.col_bool IS NOT UNKNOWN)") } func TestBinaryBoolExpression(t *testing.T) { @@ -72,5 +73,5 @@ func TestBoolLiteral(t *testing.T) { func TestBoolExp(t *testing.T) { assertClauseSerialize(t, BoolExp(String("true")), "$1", "true") - assertClauseSerialize(t, BoolExp(String("true")).IS_TRUE(), "$1 IS TRUE", "true") + assertClauseSerialize(t, BoolExp(String("true")).IS_TRUE(), "($1 IS TRUE)", "true") } diff --git a/internal/jet/cast.go b/internal/jet/cast.go deleted file mode 100644 index 311ab88..0000000 --- a/internal/jet/cast.go +++ /dev/null @@ -1,53 +0,0 @@ -package jet - -// Cast interface -type Cast interface { - AS(castType string) Expression -} - -type castImpl struct { - expression Expression -} - -// NewCastImpl creates new generic cast -func NewCastImpl(expression Expression) Cast { - castImpl := castImpl{ - expression: expression, - } - - return &castImpl -} - -func (b *castImpl) AS(castType string) Expression { - castExp := &castExpression{ - expression: b.expression, - cast: string(castType), - } - - castExp.ExpressionInterfaceImpl.Root = castExp - - return castExp -} - -type castExpression struct { - ExpressionInterfaceImpl - - expression Expression - cast string -} - -func (b *castExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - - expression := b.expression - castType := b.cast - - if castOverride := out.Dialect.OperatorSerializeOverride("CAST"); castOverride != nil { - castOverride(expression, String(castType))(statement, out, FallTrough(options)...) - return - } - - out.WriteString("CAST(") - expression.serialize(statement, out, FallTrough(options)...) - out.WriteString("AS") - out.WriteString(castType + ")") -} diff --git a/internal/jet/cast_test.go b/internal/jet/cast_test.go deleted file mode 100644 index b72cede..0000000 --- a/internal/jet/cast_test.go +++ /dev/null @@ -1,11 +0,0 @@ -package jet - -import ( - "testing" -) - -func TestCastAS(t *testing.T) { - assertClauseSerialize(t, NewCastImpl(Int(1)).AS("boolean"), "CAST($1 AS boolean)", int64(1)) - assertClauseSerialize(t, NewCastImpl(table2Col3).AS("real"), "CAST(table2.col3 AS real)") - assertClauseSerialize(t, NewCastImpl(table2Col3.ADD(table2Col3)).AS("integer"), "CAST((table2.col3 + table2.col3) AS integer)") -} diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index 678e044..5becc73 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -9,7 +9,6 @@ type Dialect interface { Name() string PackageName() string OperatorSerializeOverride(operator string) SerializeOverride - FunctionSerializeOverride(function string) SerializeOverride AliasQuoteChar() byte IdentifierQuoteChar() byte ArgumentPlaceholder() QueryPlaceholderFunc @@ -18,6 +17,7 @@ type Dialect interface { SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc ValuesDefaultColumnName(index int) string JsonValueEncode(expr Expression) Expression + RegexpLike(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) SerializerFunc } // SerializerFunc func @@ -34,7 +34,6 @@ type DialectParams struct { Name string PackageName string OperatorSerializeOverrides map[string]SerializeOverride - FunctionSerializeOverrides map[string]SerializeOverride AliasQuoteChar byte IdentifierQuoteChar byte ArgumentPlaceholder QueryPlaceholderFunc @@ -43,6 +42,7 @@ type DialectParams struct { SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc ValuesDefaultColumnName func(index int) string JsonValueEncode func(expr Expression) Expression + RegexpLike func(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) SerializerFunc } // NewDialect creates new dialect with params @@ -51,7 +51,6 @@ func NewDialect(params DialectParams) Dialect { name: params.Name, packageName: params.PackageName, operatorSerializeOverrides: params.OperatorSerializeOverrides, - functionSerializeOverrides: params.FunctionSerializeOverrides, aliasQuoteChar: params.AliasQuoteChar, identifierQuoteChar: params.IdentifierQuoteChar, argumentPlaceholder: params.ArgumentPlaceholder, @@ -60,6 +59,7 @@ func NewDialect(params DialectParams) Dialect { serializeOrderBy: params.SerializeOrderBy, valuesDefaultColumnName: params.ValuesDefaultColumnName, jsonValueEncode: params.JsonValueEncode, + regexpLike: params.RegexpLike, } } @@ -67,7 +67,6 @@ type dialectImpl struct { name string packageName string operatorSerializeOverrides map[string]SerializeOverride - functionSerializeOverrides map[string]SerializeOverride aliasQuoteChar byte identifierQuoteChar byte argumentPlaceholder QueryPlaceholderFunc @@ -76,6 +75,7 @@ type dialectImpl struct { serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc valuesDefaultColumnName func(index int) string jsonValueEncode func(expr Expression) Expression + regexpLike func(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) SerializerFunc } func (d *dialectImpl) Name() string { @@ -93,13 +93,6 @@ func (d *dialectImpl) OperatorSerializeOverride(operator string) SerializeOverri return d.operatorSerializeOverrides[operator] } -func (d *dialectImpl) FunctionSerializeOverride(function string) SerializeOverride { - if d.functionSerializeOverrides == nil { - return nil - } - return d.functionSerializeOverrides[function] -} - func (d *dialectImpl) AliasQuoteChar() byte { return d.aliasQuoteChar } @@ -133,6 +126,21 @@ func (d *dialectImpl) JsonValueEncode(expr Expression) Expression { return d.jsonValueEncode(expr) } +func (d *dialectImpl) RegexpLike(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) SerializerFunc { + if d.regexpLike != nil { + return d.regexpLike(str, not, pattern, caseSensitive) + } + + return func(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + str.serialize(statement, out, FallTrough(options)...) + if not { + out.WriteString("NOT") + } + out.WriteString("REGEXP") + pattern.serialize(statement, out, FallTrough(options)...) + } +} + func arrayOfStringsToMapOfStrings(arr []string) map[string]bool { ret := map[string]bool{} for _, elem := range arr { diff --git a/internal/jet/enum_value.go b/internal/jet/enum_value.go index 4864c74..5158cef 100644 --- a/internal/jet/enum_value.go +++ b/internal/jet/enum_value.go @@ -1,22 +1,16 @@ package jet -type enumValue struct { - ExpressionInterfaceImpl - stringInterfaceImpl - - name string -} - // NewEnumValue creates new named enum value func NewEnumValue(name string) StringExpression { - enumValue := &enumValue{name: name} - - enumValue.ExpressionInterfaceImpl.Root = enumValue - enumValue.stringInterfaceImpl.root = enumValue - - return enumValue + return StringExp(newExpression( + enumValueSerializer{name: name}, + )) } -func (e enumValue) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +type enumValueSerializer struct { + name string +} + +func (e enumValueSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.insertConstantArgument(e.name) } diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 62e5aff..2694b7a 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -118,92 +118,84 @@ func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, o e.Root.serialize(statement, out, NoWrap) } -// Representation of binary operations (e.g. comparisons, arithmetic) -type binaryOperatorExpression struct { +type expression struct { ExpressionInterfaceImpl + Serializer +} +func newExpression(serializer Serializer) Expression { + expr := &expression{ + Serializer: serializer, + } + + expr.ExpressionInterfaceImpl.Root = expr + + return expr +} + +// Representation of binary operations (e.g. comparisons, arithmetic) +type binaryOperatorSerializer struct { lhs, rhs Serializer additionalParam Serializer operator string } +func (c *binaryOperatorSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) { + if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil { + serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam) + serializeOverrideFunc(statement, out, FallTrough(options)...) + } else { + c.lhs.serialize(statement, out, FallTrough(options)...) + out.WriteString(c.operator) + c.rhs.serialize(statement, out, FallTrough(options)...) + } + }) + +} + // NewBinaryOperatorExpression creates new binaryOperatorExpression func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) Expression { - binaryExpression := &binaryOperatorExpression{ - lhs: lhs, - rhs: rhs, - operator: operator, - } - - if len(additionalParam) > 0 { - binaryExpression.additionalParam = additionalParam[0] - } - - binaryExpression.ExpressionInterfaceImpl.Root = binaryExpression - - return complexExpr(binaryExpression) + return newExpression(&binaryOperatorSerializer{ + lhs: lhs, + rhs: rhs, + additionalParam: OptionalOrDefault(additionalParam, nil), + operator: operator, + }) } -func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil { - serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam) - serializeOverrideFunc(statement, out, FallTrough(options)...) - } else { - c.lhs.serialize(statement, out, FallTrough(options)...) - out.WriteString(c.operator) - c.rhs.serialize(statement, out, FallTrough(options)...) - } -} - -type expressionListOperator struct { - ExpressionInterfaceImpl - +type serializersWithOperator struct { operator string - expressions []Expression + serializers []Serializer } -func newExpressionListOperator(operator string, expressions ...Expression) *expressionListOperator { - ret := &expressionListOperator{ - operator: operator, - expressions: expressions, - } - - ret.ExpressionInterfaceImpl.Root = ret - - return ret -} - -func newBoolExpressionListOperator(operator string, expressions ...BoolExpression) BoolExpression { - return BoolExp(newExpressionListOperator(operator, ToExpressionList(expressions)...)) -} - -func (elo *expressionListOperator) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if len(elo.expressions) == 0 { +func (s *serializersWithOperator) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + if len(s.serializers) == 0 { panic("jet: syntax error, expression list empty") } - shouldWrap := len(elo.expressions) > 1 + shouldWrap := len(s.serializers) > 1 if shouldWrap { out.WriteByte('(') out.IncreaseIdent(tabSize) out.NewLine() } - for i, expression := range elo.expressions { + for i, expression := range s.serializers { if i == 1 { out.IncreaseIdent(tabSize) } if i > 0 { out.NewLine() - out.WriteString(elo.operator) + out.WriteString(s.operator) } - out.IncreaseIdent(len(elo.operator) + 1) + out.IncreaseIdent(len(s.operator) + 1) expression.serialize(statement, out, FallTrough(options)...) - out.DecreaseIdent(len(elo.operator) + 1) + out.DecreaseIdent(len(s.operator) + 1) } - if len(elo.expressions) > 1 { + if len(s.serializers) > 1 { out.DecreaseIdent(tabSize) } @@ -214,130 +206,47 @@ func (elo *expressionListOperator) serialize(statement StatementType, out *SQLBu } } -// A prefix operator Expression -type prefixExpression struct { - ExpressionInterfaceImpl - - expression Expression - operator string +func newBoolExpressionListOperator(operator string, expressions []BoolExpression) BoolExpression { + return BoolExp(newExpression(&serializersWithOperator{ + operator: operator, + serializers: ToSerializerList(expressions), + })) } func newPrefixOperatorExpression(expression Expression, operator string) Expression { - prefixExpression := &prefixExpression{ - expression: expression, - operator: operator, - } - prefixExpression.ExpressionInterfaceImpl.Root = prefixExpression - - return complexExpr(prefixExpression) + return CustomExpression(Token(operator), expression) } -func (p *prefixExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString(p.operator) - p.expression.serialize(statement, out, FallTrough(options)...) +func newPostfixOperatorExpression(expression Expression, operator string) Expression { + return CustomExpression(expression, Token(operator)) } -// A postfix operator Expression -type postfixOpExpression struct { - ExpressionInterfaceImpl - - expression Expression - operator string -} - -func newPostfixOperatorExpression(expression Expression, operator string) *postfixOpExpression { - postfixOpExpression := &postfixOpExpression{ - expression: expression, - operator: operator, - } - - postfixOpExpression.ExpressionInterfaceImpl.Root = postfixOpExpression - - return postfixOpExpression -} - -func (p *postfixOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - p.expression.serialize(statement, out, FallTrough(options)...) - out.WriteString(p.operator) -} - -type betweenOperatorExpression struct { - ExpressionInterfaceImpl - +type betweenOperatorSerializer struct { expression Expression notBetween bool min Expression max Expression } +func (b *betweenOperatorSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) { + b.expression.serialize(statement, out, FallTrough(options)...) + if b.notBetween { + out.WriteString("NOT") + } + out.WriteString("BETWEEN") + b.min.serialize(statement, out, FallTrough(options)...) + out.WriteString("AND") + b.max.serialize(statement, out, FallTrough(options)...) + }) +} + // NewBetweenOperatorExpression creates new BETWEEN operator expression func NewBetweenOperatorExpression(expression, min, max Expression, notBetween bool) BoolExpression { - newBetweenOperator := &betweenOperatorExpression{ + return BoolExp(newExpression(&betweenOperatorSerializer{ expression: expression, notBetween: notBetween, min: min, max: max, - } - - newBetweenOperator.ExpressionInterfaceImpl.Root = newBetweenOperator - - return BoolExp(complexExpr(newBetweenOperator)) -} - -func (p *betweenOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - p.expression.serialize(statement, out, FallTrough(options)...) - if p.notBetween { - out.WriteString("NOT") - } - out.WriteString("BETWEEN") - p.min.serialize(statement, out, FallTrough(options)...) - out.WriteString("AND") - p.max.serialize(statement, out, FallTrough(options)...) -} - -type customExpression struct { - ExpressionInterfaceImpl - parts []Serializer -} - -func CustomExpression(parts ...Serializer) Expression { - ret := customExpression{ - parts: parts, - } - ret.ExpressionInterfaceImpl.Root = &ret - return &ret -} - -func (c *customExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - for _, expression := range c.parts { - expression.serialize(statement, out, options...) - } -} - -type complexExpression struct { - ExpressionInterfaceImpl - expressions Expression -} - -func complexExpr(expression Expression) Expression { - complexExpression := &complexExpression{expressions: expression} - complexExpression.ExpressionInterfaceImpl.Root = complexExpression - - return complexExpression -} - -func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if !contains(options, NoWrap) { - out.WriteString("(") - } - - s.expressions.serialize(statement, out, options...) // FallTrough here because complexExpression is just a wrapper - - if !contains(options, NoWrap) { - out.WriteString(")") - } -} - -func wrap(expressions ...Expression) Expression { - return NewFunc("", expressions, nil) + })) } diff --git a/internal/jet/expression_test.go b/internal/jet/expression_test.go index 74e6a05..0160150 100644 --- a/internal/jet/expression_test.go +++ b/internal/jet/expression_test.go @@ -5,13 +5,13 @@ import ( ) func TestExpressionIS_NULL(t *testing.T) { - assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL") - assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL") + assertClauseSerialize(t, table2Col3.IS_NULL(), "(table2.col3 IS NULL)") + assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "((table2.col3 + table2.col3) IS NULL)") } func TestExpressionIS_NOT_NULL(t *testing.T) { - assertClauseSerialize(t, table2Col3.IS_NOT_NULL(), "table2.col3 IS NOT NULL") - assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NOT_NULL(), "(table2.col3 + table2.col3) IS NOT NULL") + assertClauseSerialize(t, table2Col3.IS_NOT_NULL(), "(table2.col3 IS NOT NULL)") + assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NOT_NULL(), "((table2.col3 + table2.col3) IS NOT NULL)") } func TestExpressionIS_DISTINCT_FROM(t *testing.T) { diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 4ea2a4f..c6c598f 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -3,13 +3,13 @@ package jet // AND function adds AND operator between expressions. This function can be used, instead of method AND, // to have a better inlining of a complex condition in the Go code and in the generated SQL. func AND(expressions ...BoolExpression) BoolExpression { - return newBoolExpressionListOperator("AND", expressions...) + return newBoolExpressionListOperator("AND", expressions) } // OR function adds OR operator between expressions. This function can be used, instead of method OR, // to have a better inlining of a complex condition in the Go code and in the generated SQL. func OR(expressions ...BoolExpression) BoolExpression { - return newBoolExpressionListOperator("OR", expressions...) + return newBoolExpressionListOperator("OR", expressions) } // ------------------ Mathematical functions ---------------// @@ -244,7 +244,7 @@ func leadLagImpl(name string, expr Expression, offsetAndDefault ...interface{}) defaultValue, ok = offsetAndDefault[1].(Expression) if !ok { - defaultValue = literal(offsetAndDefault[1]) + defaultValue = Literal(offsetAndDefault[1]) } params = append(params, FixedLiteral(offset), defaultValue) @@ -484,12 +484,12 @@ func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType // LOWER_BOUND returns range expressions lower bound. Returns null if range is empty or the requested bound is infinite. func LOWER_BOUND[T Expression](rangeExpression Range[T]) T { - return rangeTypeCaster[T](rangeExpression, NewFunc("LOWER", []Expression{rangeExpression}, nil)) + return rangeTypeCaster[T](rangeExpression, newFunc("LOWER", []Expression{rangeExpression})) } // UPPER_BOUND returns range expressions upper bound. Returns null if range is empty or the requested bound is infinite. func UPPER_BOUND[T Expression](rangeExpression Range[T]) T { - return rangeTypeCaster[T](rangeExpression, NewFunc("UPPER", []Expression{rangeExpression}, nil)) + return rangeTypeCaster[T](rangeExpression, newFunc("UPPER", []Expression{rangeExpression})) } func rangeTypeCaster[T Expression](rangeExpression Range[T], exp Expression) T { @@ -543,7 +543,7 @@ func TO_CHAR(expression Expression, format StringExpression) StringExpression { // TO_DATE converts string to date using format func TO_DATE(dateStr, format StringExpression) DateExpression { - return NewDateFunc("TO_DATE", dateStr, format) + return DateExp(newFunc("TO_DATE", []Expression{dateStr, format})) } // TO_NUMBER converts string to numeric using format @@ -560,74 +560,47 @@ func TO_TIMESTAMP(timestampzStr, format StringExpression) TimestampzExpression { // EXTRACT extracts time component from time expression func EXTRACT(field string, from Expression) Expression { - return CustomExpression(Token("EXTRACT("), Token(field), Token("FROM"), from, Token(")")) + return AtomicCustomExpression(Token("EXTRACT("), Token(field), Token("FROM"), from, Token(")")) } // CURRENT_DATE returns current date func CURRENT_DATE() DateExpression { - dateFunc := NewDateFunc("CURRENT_DATE") - dateFunc.noBrackets = true - return dateFunc + return DateKeyword("CURRENT_DATE") } // CURRENT_TIME returns current time with time zone func CURRENT_TIME(precision ...int) TimezExpression { - var timezFunc *timezFunc - if len(precision) > 0 { - timezFunc = newTimezFunc("CURRENT_TIME", FixedLiteral(precision[0])) - } else { - timezFunc = newTimezFunc("CURRENT_TIME") + return newTimezFunc("CURRENT_TIME", FixedLiteral(precision[0])) } - timezFunc.noBrackets = true - - return timezFunc + return TimezKeyword("CURRENT_TIME") } // CURRENT_TIMESTAMP returns current timestamp with time zone func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression { - var timestampzFunc *timestampzFunc - if len(precision) > 0 { - timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", FixedLiteral(precision[0])) - } else { - timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP") + return newTimestampzFunc("CURRENT_TIMESTAMP", FixedLiteral(precision[0])) } - timestampzFunc.noBrackets = true - - return timestampzFunc + return TimestampzKeyword("CURRENT_TIMESTAMP") } // LOCALTIME returns local time of day using optional precision func LOCALTIME(precision ...int) TimeExpression { - var timeFunc *timeFunc - if len(precision) > 0 { - timeFunc = NewTimeFunc("LOCALTIME", FixedLiteral(precision[0])) - } else { - timeFunc = NewTimeFunc("LOCALTIME") + return NewTimeFunc("LOCALTIME", FixedLiteral(precision[0])) } - timeFunc.noBrackets = true - - return timeFunc + return TimeKeyword("LOCALTIME") } // LOCALTIMESTAMP returns current date and time using optional precision func LOCALTIMESTAMP(precision ...int) TimestampExpression { - var timestampFunc *timestampFunc - if len(precision) > 0 { - timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", FixedLiteral(precision[0])) - } else { - timestampFunc = NewTimestampFunc("LOCALTIMESTAMP") + return NewTimestampFunc("LOCALTIMESTAMP", FixedLiteral(precision[0])) } - - timestampFunc.noBrackets = true - - return timestampFunc + return TimestampKeyword("LOCALTIMESTAMP") } // NOW returns current date and time @@ -641,74 +614,53 @@ func NOW() TimestampzExpression { func COALESCE(value Expression, values ...Expression) Expression { var allValues = []Expression{value} allValues = append(allValues, values...) - return NewFunc("COALESCE", allValues, nil) + return newFunc("COALESCE", allValues) } // NULLIF function returns a null value if value1 equals value2; otherwise it returns value1. func NULLIF(value1, value2 Expression) Expression { - return NewFunc("NULLIF", []Expression{value1, value2}, nil) + return newFunc("NULLIF", []Expression{value1, value2}) } // GREATEST selects the largest value from a list of expressions func GREATEST(value Expression, values ...Expression) Expression { var allValues = []Expression{value} allValues = append(allValues, values...) - return NewFunc("GREATEST", allValues, nil) + return newFunc("GREATEST", allValues) } // LEAST selects the smallest value from a list of expressions func LEAST(value Expression, values ...Expression) Expression { var allValues = []Expression{value} allValues = append(allValues, values...) - return NewFunc("LEAST", allValues, nil) + return newFunc("LEAST", allValues) } //--------------------------------------------------------------------// -type funcExpressionImpl struct { - ExpressionInterfaceImpl +// newFunc creates new function with name and expressions parameters +func newFunc(name string, expressions []Expression) Expression { + return newExpression(&funcSerializer{ + name: name, + parameters: expressions, + }) +} +type funcSerializer struct { name string parameters parametersSerializer - noBrackets bool } -// NewFunc creates new function with name and expressions parameters -func NewFunc(name string, expressions []Expression, root Expression) *funcExpressionImpl { - funcExp := &funcExpressionImpl{ - name: name, - parameters: parametersSerializer(expressions), - } - - if root != nil { - funcExp.ExpressionInterfaceImpl.Root = root - } else { - funcExp.ExpressionInterfaceImpl.Root = funcExp - } - - return funcExp -} - -func (f *funcExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil { - serializeOverrideFunc := serializeOverride(ExpressionListToSerializerList(f.parameters)...) - serializeOverrideFunc(statement, out, FallTrough(options)...) - return - } - - addBrackets := !f.noBrackets || len(f.parameters) > 0 - - if addBrackets { - out.WriteString(f.name + "(") - } else { - out.WriteString(f.name) - } +func (f *funcSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteString(f.name + "(") f.parameters.serialize(statement, out, options...) - if addBrackets { - out.WriteString(")") - } + out.WriteString(")") +} + +func newBoolFunc(name string, expressions ...Expression) BoolExpression { + return BoolExp(newFunc(name, expressions)) } type parametersSerializer []Expression @@ -730,208 +682,83 @@ func (p parametersSerializer) serialize(statement StatementType, out *SQLBuilder // NewFloatWindowFunc creates new float function with name and expressions func newWindowFunc(name string, expressions ...Expression) windowExpression { - newFun := NewFunc(name, expressions, nil) - windowExpr := newWindowExpression(newFun) - newFun.ExpressionInterfaceImpl.Root = windowExpr - - return windowExpr -} - -type boolFunc struct { - funcExpressionImpl - boolInterfaceImpl -} - -func newBoolFunc(name string, expressions ...Expression) BoolExpression { - boolFunc := &boolFunc{} - - boolFunc.funcExpressionImpl = *NewFunc(name, expressions, boolFunc) - boolFunc.boolInterfaceImpl.root = boolFunc - boolFunc.ExpressionInterfaceImpl.Root = boolFunc - - return boolFunc + return newWindowExpression(newFunc(name, expressions)) } // NewFloatWindowFunc creates new float function with name and expressions func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpression { - boolFunc := &boolFunc{} - - boolFunc.funcExpressionImpl = *NewFunc(name, expressions, boolFunc) - intWindowFunc := newBoolWindowExpression(boolFunc) - boolFunc.boolInterfaceImpl.root = intWindowFunc - boolFunc.ExpressionInterfaceImpl.Root = intWindowFunc - - return intWindowFunc -} - -type floatFunc struct { - funcExpressionImpl - floatInterfaceImpl + return newBoolWindowExpression(BoolExp(newFunc(name, expressions))) } // NewFloatFunc creates new float function with name and expressions func NewFloatFunc(name string, expressions ...Expression) FloatExpression { - floatFunc := &floatFunc{} - - floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc) - floatFunc.floatInterfaceImpl.root = floatFunc - - return floatFunc + return FloatExp(newFunc(name, expressions)) } // NewFloatWindowFunc creates new float function with name and expressions func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpression { - floatFunc := &floatFunc{} - - floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc) - floatWindowFunc := newFloatWindowExpression(floatFunc) - floatFunc.floatInterfaceImpl.root = floatWindowFunc - floatFunc.ExpressionInterfaceImpl.Root = floatWindowFunc - - return floatWindowFunc -} - -type integerFunc struct { - funcExpressionImpl - integerInterfaceImpl + return newFloatWindowExpression(FloatExp(newFunc(name, expressions))) } func newIntegerFunc(name string, expressions ...Expression) IntegerExpression { - intFunc := &integerFunc{} - - intFunc.funcExpressionImpl = *NewFunc(name, expressions, intFunc) - intFunc.integerInterfaceImpl.root = intFunc - - return intFunc + return IntExp(newFunc(name, expressions)) } // NewFloatWindowFunc creates new float function with name and expressions func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowExpression { - integerFunc := &integerFunc{} - - integerFunc.funcExpressionImpl = *NewFunc(name, expressions, integerFunc) - intWindowFunc := newIntegerWindowExpression(integerFunc) - integerFunc.integerInterfaceImpl.root = intWindowFunc - integerFunc.ExpressionInterfaceImpl.Root = intWindowFunc - - return intWindowFunc -} - -type stringFunc struct { - funcExpressionImpl - stringInterfaceImpl + return newIntegerWindowExpression(IntExp(newFunc(name, expressions))) } // NewStringFunc creates new string function with name and expression parameters func NewStringFunc(name string, expressions ...Expression) StringExpression { - stringFunc := &stringFunc{} - - stringFunc.funcExpressionImpl = *NewFunc(name, expressions, stringFunc) - stringFunc.stringInterfaceImpl.root = stringFunc - - return stringFunc -} - -type dateFunc struct { - funcExpressionImpl - dateInterfaceImpl -} - -// NewDateFunc creates new date function with name and expression parameters -func NewDateFunc(name string, expressions ...Expression) *dateFunc { - dateFunc := &dateFunc{} - - dateFunc.funcExpressionImpl = *NewFunc(name, expressions, dateFunc) - dateFunc.dateInterfaceImpl.root = dateFunc - - return dateFunc -} - -type timeFunc struct { - funcExpressionImpl - timeInterfaceImpl + return StringExp(newFunc(name, expressions)) } // NewTimeFunc creates new time function with name and expression parameters -func NewTimeFunc(name string, expressions ...Expression) *timeFunc { - timeFun := &timeFunc{} - - timeFun.funcExpressionImpl = *NewFunc(name, expressions, timeFun) - timeFun.timeInterfaceImpl.root = timeFun - - return timeFun +func NewTimeFunc(name string, expressions ...Expression) TimeExpression { + return TimeExp(newFunc(name, expressions)) } -type timezFunc struct { - funcExpressionImpl - timezInterfaceImpl -} - -func newTimezFunc(name string, expressions ...Expression) *timezFunc { - timezFun := &timezFunc{} - - timezFun.funcExpressionImpl = *NewFunc(name, expressions, timezFun) - timezFun.timezInterfaceImpl.root = timezFun - - return timezFun -} - -type timestampFunc struct { - funcExpressionImpl - timestampInterfaceImpl +func newTimezFunc(name string, expressions ...Expression) TimezExpression { + return TimezExp(newFunc(name, expressions)) } // NewTimestampFunc creates new timestamp function with name and expressions -func NewTimestampFunc(name string, expressions ...Expression) *timestampFunc { - timestampFunc := ×tampFunc{} - - timestampFunc.funcExpressionImpl = *NewFunc(name, expressions, timestampFunc) - timestampFunc.timestampInterfaceImpl.root = timestampFunc - - return timestampFunc +func NewTimestampFunc(name string, expressions ...Expression) TimestampExpression { + return TimestampExp(newFunc(name, expressions)) } -type timestampzFunc struct { - funcExpressionImpl - timestampzInterfaceImpl -} - -func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc { - timestampzFunc := ×tampzFunc{} - - timestampzFunc.funcExpressionImpl = *NewFunc(name, expressions, timestampzFunc) - timestampzFunc.timestampzInterfaceImpl.root = timestampzFunc - - return timestampzFunc +func newTimestampzFunc(name string, expressions ...Expression) TimestampzExpression { + return TimestampzExp(newFunc(name, expressions)) } // Func can be used to call custom or unsupported database functions. func Func(name string, expressions ...Expression) Expression { - return NewFunc(name, expressions, nil) + return newFunc(name, expressions) } func NumRange(lowNum, highNum NumericExpression, bounds ...StringExpression) Range[NumericExpression] { - return NumRangeExp(NewFunc("numrange", rangeFuncParamCombiner(lowNum, highNum, bounds...), nil)) + return NumRangeExp(newFunc("numrange", rangeFuncParamCombiner(lowNum, highNum, bounds...))) } func Int4Range(lowNum, highNum IntegerExpression, bounds ...StringExpression) Range[Int4Expression] { - return Int4RangeExp(NewFunc("int4range", rangeFuncParamCombiner(lowNum, highNum, bounds...), nil)) + return Int4RangeExp(newFunc("int4range", rangeFuncParamCombiner(lowNum, highNum, bounds...))) } func Int8Range(lowNum, highNum Int8Expression, bounds ...StringExpression) Range[Int8Expression] { - return Int8RangeExp(NewFunc("int8range", rangeFuncParamCombiner(lowNum, highNum, bounds...), nil)) + return Int8RangeExp(newFunc("int8range", rangeFuncParamCombiner(lowNum, highNum, bounds...))) } func TsRange(lowTs, highTs TimestampExpression, bounds ...StringExpression) Range[TimestampExpression] { - return TsRangeExp(NewFunc("tsrange", rangeFuncParamCombiner(lowTs, highTs, bounds...), nil)) + return TsRangeExp(newFunc("tsrange", rangeFuncParamCombiner(lowTs, highTs, bounds...))) } func TstzRange(lowTs, highTs TimestampzExpression, bounds ...StringExpression) Range[TimestampzExpression] { - return TstzRangeExp(NewFunc("tstzrange", rangeFuncParamCombiner(lowTs, highTs, bounds...), nil)) + return TstzRangeExp(newFunc("tstzrange", rangeFuncParamCombiner(lowTs, highTs, bounds...))) } func DateRange(lowTs, highTs DateExpression, bounds ...StringExpression) Range[DateExpression] { - return DateRangeExp(NewFunc("daterange", rangeFuncParamCombiner(lowTs, highTs, bounds...), nil)) + return DateRangeExp(newFunc("daterange", rangeFuncParamCombiner(lowTs, highTs, bounds...))) } func rangeFuncParamCombiner(low, high Expression, bounds ...StringExpression) []Expression { @@ -941,3 +768,23 @@ func rangeFuncParamCombiner(low, high Expression, bounds ...StringExpression) [] } return exp } + +func TimeKeyword(name string) TimeExpression { + return TimeExp(newExpression(Keyword(name))) +} + +func TimezKeyword(name string) TimezExpression { + return TimezExp(newExpression(Keyword(name))) +} + +func TimestampKeyword(name string) TimestampExpression { + return TimestampExp(newExpression(Keyword(name))) +} + +func TimestampzKeyword(name string) TimestampzExpression { + return TimestampzExp(newExpression(Keyword(name))) +} + +func DateKeyword(name string) DateExpression { + return DateExp(newExpression(Keyword(name))) +} diff --git a/internal/jet/func_expression_test.go b/internal/jet/func_expression_test.go index ca32589..eeef5f9 100644 --- a/internal/jet/func_expression_test.go +++ b/internal/jet/func_expression_test.go @@ -6,7 +6,7 @@ import ( func TestAND(t *testing.T) { assertClauseSerializeErr(t, AND(), "jet: syntax error, expression list empty") - assertClauseSerialize(t, AND(table1ColInt.IS_NULL()), `table1.col_int IS NULL`) // IS NULL doesn't add parenthesis + assertClauseSerialize(t, AND(table1ColInt.IS_NULL()), `(table1.col_int IS NULL)`) // IS NULL doesn't add parenthesis assertClauseSerialize(t, AND(table1ColInt.LT(Int(11))), `(table1.col_int < $1)`, int64(11)) assertClauseSerialize(t, AND(table1ColInt.GT(Int(11)), table1ColFloat.EQ(Float(0))), `( @@ -17,7 +17,7 @@ func TestAND(t *testing.T) { func TestOR(t *testing.T) { assertClauseSerializeErr(t, OR(), "jet: syntax error, expression list empty") - assertClauseSerialize(t, OR(table1ColInt.IS_NULL()), `table1.col_int IS NULL`) // IS NULL doesn't add parenthesis + assertClauseSerialize(t, OR(table1ColInt.IS_NULL()), `(table1.col_int IS NULL)`) // IS NULL doesn't add parenthesis assertClauseSerialize(t, OR(table1ColInt.LT(Int(11))), `(table1.col_int < $1)`, int64(11)) assertClauseSerialize(t, OR(table1ColInt.GT(Int(11)), table1ColFloat.EQ(Float(0))), `( @@ -205,7 +205,7 @@ func TestFunc(t *testing.T) { func Test_rangePointCaster(t *testing.T) { mainRange := Int8Range(Int8(10), Int8(12)) - exp := NewFunc("UPPER", []Expression{mainRange}, nil) + exp := newFunc("UPPER", []Expression{mainRange}) got := rangeTypeCaster(mainRange, exp) _, ok := got.(IntegerExpression) diff --git a/internal/jet/integer_expression_test.go b/internal/jet/integer_expression_test.go index a20981b..e1c850c 100644 --- a/internal/jet/integer_expression_test.go +++ b/internal/jet/integer_expression_test.go @@ -66,7 +66,7 @@ func TestIntExpressionPOW(t *testing.T) { func TestIntExpressionBIT_NOT(t *testing.T) { assertClauseSerialize(t, BIT_NOT(table2ColInt), "(~ table2.col_int)") - assertClauseSerialize(t, BIT_NOT(Int(11)), "(~ 11)") + assertClauseSerialize(t, BIT_NOT(Int(11)), "(~ $1)", int64(11)) } func TestIntExpressionBIT_AND(t *testing.T) { diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index b8baa51..f7ceedf 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -5,48 +5,12 @@ import ( "time" ) -// LiteralExpression is representation of an escaped literal -type LiteralExpression interface { - Expression - - Value() interface{} - SetConstant(constant bool) -} - -type literalExpressionImpl struct { - ExpressionInterfaceImpl - +type literalSerializer struct { value interface{} constant bool } -func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl { - exp := literalExpressionImpl{value: value} - - if len(optionalConstant) > 0 { - exp.constant = optionalConstant[0] - } - - exp.ExpressionInterfaceImpl.Root = &exp - - return &exp -} - -// Literal is injected directly to SQL query, and does not appear in parametrized argument list. -func Literal(value interface{}) *literalExpressionImpl { - exp := literal(value) - return exp -} - -// FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list. -func FixedLiteral(value interface{}) *literalExpressionImpl { - exp := literal(value) - exp.constant = true - - return exp -} - -func (l *literalExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (l *literalSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { if l.constant { out.insertConstantArgument(l.value) } else { @@ -54,260 +18,145 @@ func (l *literalExpressionImpl) serialize(statement StatementType, out *SQLBuild } } -func (l *literalExpressionImpl) Value() interface{} { - return l.value +// Literal is injected directly to SQL query, and does not appear in parametrized argument list. +func Literal(value interface{}) Expression { + return newExpression(&literalSerializer{ + value: value, + constant: false, + }) } -func (l *literalExpressionImpl) SetConstant(constant bool) { - l.constant = constant -} - -type integerLiteralExpression struct { - literalExpressionImpl - integerInterfaceImpl -} - -func intLiteral(value interface{}) IntegerExpression { - numLiteral := &integerLiteralExpression{} - - numLiteral.literalExpressionImpl = *literal(value) - - numLiteral.literalExpressionImpl.Root = numLiteral - numLiteral.integerInterfaceImpl.root = numLiteral - - return numLiteral +// FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list. +func FixedLiteral(value interface{}) Expression { + return newExpression(&literalSerializer{ + value: value, + constant: true, + }) } // Int creates a new 64 bit signed integer literal func Int(value int64) IntegerExpression { - return intLiteral(value) + return IntExp(Literal(value)) } // Int8 creates a new 8 bit signed integer literal func Int8(value int8) IntegerExpression { - return intLiteral(value) + return IntExp(Literal(value)) } // Int16 creates a new 16 bit signed integer literal func Int16(value int16) IntegerExpression { - return intLiteral(value) + return IntExp(Literal(value)) } // Int32 creates a new 32 bit signed integer literal func Int32(value int32) IntegerExpression { - return intLiteral(value) + return IntExp(Literal(value)) } // Uint8 creates a new 8 bit unsigned integer literal func Uint8(value uint8) IntegerExpression { - return intLiteral(value) + return IntExp(Literal(value)) } // Uint16 creates a new 16 bit unsigned integer literal func Uint16(value uint16) IntegerExpression { - return intLiteral(value) + return IntExp(Literal(value)) } // Uint32 creates a new 32 bit unsigned integer literal func Uint32(value uint32) IntegerExpression { - return intLiteral(value) + return IntExp(Literal(value)) } // Uint64 creates a new 64 bit unsigned integer literal func Uint64(value uint64) IntegerExpression { - return intLiteral(value) -} - -// ---------------------------------------------------// -type boolLiteralExpression struct { - boolInterfaceImpl - literalExpressionImpl + return IntExp(Literal(value)) } // Bool creates new bool literal expression func Bool(value bool) BoolExpression { - boolLiteralExpression := boolLiteralExpression{} - - boolLiteralExpression.literalExpressionImpl = *literal(value) - boolLiteralExpression.boolInterfaceImpl.root = &boolLiteralExpression - - return &boolLiteralExpression -} - -// ---------------------------------------------------// -type floatLiteral struct { - floatInterfaceImpl - literalExpressionImpl + return BoolExp(Literal(value)) } // Float creates new float literal from float64 value func Float(value float64) FloatExpression { - floatLiteral := floatLiteral{} - floatLiteral.literalExpressionImpl = *literal(value) - - floatLiteral.floatInterfaceImpl.root = &floatLiteral - - return &floatLiteral + return FloatExp(Literal(value)) } // Decimal creates new float literal from string value func Decimal(value string) FloatExpression { - floatLiteral := floatLiteral{} - floatLiteral.literalExpressionImpl = *literal(value) - - floatLiteral.floatInterfaceImpl.root = &floatLiteral - - return &floatLiteral -} - -// ---------------------------------------------------// -type stringLiteral struct { - stringInterfaceImpl - literalExpressionImpl + return FloatExp(Literal(value)) } // String creates new string literal expression func String(value string) StringExpression { - stringLiteral := stringLiteral{} - stringLiteral.literalExpressionImpl = *literal(value) - - stringLiteral.stringInterfaceImpl.root = &stringLiteral - - return &stringLiteral -} - -//---------------------------------------------------// - -type timeLiteral struct { - timeInterfaceImpl - literalExpressionImpl + return StringExp(Literal(value)) } // Time creates new time literal expression func Time(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression { - timeLiteral := &timeLiteral{} timeStr := fmt.Sprintf("%02d:%02d:%02d", hour, minute, second) timeStr += formatNanoseconds(nanoseconds...) - timeLiteral.literalExpressionImpl = *literal(timeStr) - timeLiteral.timeInterfaceImpl.root = timeLiteral - - return timeLiteral + return TimeExp(Literal(timeStr)) } // TimeT creates new time literal expression from time.Time object func TimeT(t time.Time) TimeExpression { - timeLiteral := &timeLiteral{} - timeLiteral.literalExpressionImpl = *literal(t) - timeLiteral.timeInterfaceImpl.root = timeLiteral - - return timeLiteral -} - -//---------------------------------------------------// - -type timezLiteral struct { - timezInterfaceImpl - literalExpressionImpl + return TimeExp(Literal(t)) } // Timez creates new time with time zone literal expression func Timez(hour, minute, second int, nanoseconds time.Duration, timezone string) TimezExpression { - timezLiteral := timezLiteral{} timeStr := fmt.Sprintf("%02d:%02d:%02d", hour, minute, second) timeStr += formatNanoseconds(nanoseconds) timeStr += " " + timezone - timezLiteral.literalExpressionImpl = *literal(timeStr) - return TimezExp(literal(timeStr)) + return TimezExp(Literal(timeStr)) } // TimezT creates new time with time zone literal expression from time.Time object func TimezT(t time.Time) TimezExpression { - timeLiteral := &timezLiteral{} - timeLiteral.literalExpressionImpl = *literal(t) - timeLiteral.timezInterfaceImpl.root = timeLiteral - - return timeLiteral -} - -//---------------------------------------------------// - -type timestampLiteral struct { - timestampInterfaceImpl - literalExpressionImpl + return TimezExp(Literal(t)) } // Timestamp creates new timestamp literal expression func Timestamp(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) TimestampExpression { - timestamp := ×tampLiteral{} timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second) timeStr += formatNanoseconds(nanoseconds...) - timestamp.literalExpressionImpl = *literal(timeStr) - timestamp.timestampInterfaceImpl.root = timestamp - return timestamp + + return TimestampExp(Literal(timeStr)) } // TimestampT creates new timestamp literal expression from time.Time object func TimestampT(t time.Time) TimestampExpression { - timestamp := ×tampLiteral{} - timestamp.literalExpressionImpl = *literal(t) - timestamp.timestampInterfaceImpl.root = timestamp - return timestamp -} - -//---------------------------------------------------// - -type timestampzLiteral struct { - timestampzInterfaceImpl - literalExpressionImpl + return TimestampExp(Literal(t)) } // Timestampz creates new timestamp with time zone literal expression func Timestampz(year int, month time.Month, day, hour, minute, second int, nanoseconds time.Duration, timezone string) TimestampzExpression { - timestamp := ×tampzLiteral{} timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second) timeStr += formatNanoseconds(nanoseconds) timeStr += " " + timezone - timestamp.literalExpressionImpl = *literal(timeStr) - timestamp.timestampzInterfaceImpl.root = timestamp - return timestamp + return TimestampzExp(Literal(timeStr)) } // TimestampzT creates new timestamp literal expression from time.Time object func TimestampzT(t time.Time) TimestampzExpression { - timestamp := ×tampzLiteral{} - timestamp.literalExpressionImpl = *literal(t) - timestamp.timestampzInterfaceImpl.root = timestamp - return timestamp -} - -//---------------------------------------------------// - -type dateLiteral struct { - dateInterfaceImpl - literalExpressionImpl + return TimestampzExp(Literal(t)) } // Date creates new date literal expression func Date(year int, month time.Month, day int) DateExpression { - dateLiteral := &dateLiteral{} - timeStr := fmt.Sprintf("%04d-%02d-%02d", year, month, day) - dateLiteral.literalExpressionImpl = *literal(timeStr) - dateLiteral.dateInterfaceImpl.root = dateLiteral - - return dateLiteral + return DateExp(Literal(timeStr)) } // DateT creates new date literal expression from time.Time object func DateT(t time.Time) DateExpression { - dateLiteral := &dateLiteral{} - dateLiteral.literalExpressionImpl = *literal(t) - dateLiteral.dateInterfaceImpl.root = dateLiteral - - return dateLiteral + return DateExp(Literal(t)) } func formatNanoseconds(nanoseconds ...time.Duration) string { @@ -330,86 +179,35 @@ func formatNanoseconds(nanoseconds ...time.Duration) string { var ( // NULL is jet equivalent of SQL NULL - NULL = newNullLiteral() + NULL = newExpression(Keyword("NULL")) // STAR is jet equivalent of SQL * - STAR = newStarLiteral() + STAR = newExpression(Keyword("*")) // PLUS_INFINITY is jet equivalent for sql infinity PLUS_INFINITY = String("infinity") // MINUS_INFINITY is jet equivalent for sql -infinity MINUS_INFINITY = String("-infinity") ) -type nullLiteral struct { - ExpressionInterfaceImpl -} - -func newNullLiteral() Expression { - nullExpression := &nullLiteral{} - - nullExpression.ExpressionInterfaceImpl.Root = nullExpression - - return nullExpression -} - -func (n *nullLiteral) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString("NULL") -} - -// --------------------------------------------------// -type starLiteral struct { - ExpressionInterfaceImpl -} - -func newStarLiteral() Expression { - starExpression := &starLiteral{} - - starExpression.ExpressionInterfaceImpl.Root = starExpression - - return starExpression -} - -func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString("*") -} - //---------------------------------------------------// -type rawExpression struct { - ExpressionInterfaceImpl - +type rawSerializer struct { Raw string NamedArgument map[string]interface{} - noWrap bool } -func (n *rawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if !n.noWrap && !contains(options, NoWrap) { - out.WriteByte('(') - } - - out.insertRawQuery(n.Raw, n.NamedArgument) - - if !n.noWrap && !contains(options, NoWrap) { - out.WriteByte(')') - } +func (n *rawSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) { + out.insertRawQuery(n.Raw, n.NamedArgument) + }) } // Raw can be used for any unsupported functions, operators or expressions. // For example: Raw("current_database()") func Raw(raw string, namedArgs ...map[string]interface{}) Expression { - var namedArguments map[string]interface{} - - if len(namedArgs) > 0 { - namedArguments = namedArgs[0] - } - - rawExp := &rawExpression{ + return newExpression(&rawSerializer{ Raw: raw, - NamedArgument: namedArguments, - } - rawExp.ExpressionInterfaceImpl.Root = rawExp - - return rawExp + NamedArgument: singleOptional(namedArgs), + }) } // RawBool helper that for raw string boolean expressions diff --git a/internal/jet/operators.go b/internal/jet/operators.go index b56cb1b..277fb2c 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -16,9 +16,6 @@ func NOT(exp BoolExpression) BoolExpression { // BIT_NOT inverts every bit in integer expression result func BIT_NOT(expr IntegerExpression) IntegerExpression { - if literalExp, ok := expr.(LiteralExpression); ok { - literalExp.SetConstant(true) - } return newPrefixIntegerOperatorExpression(expr, "~") } @@ -131,10 +128,8 @@ type caseOperatorImpl struct { // CASE create CASE operator with optional list of expressions func CASE(expression ...Expression) CaseOperator { - caseExp := &caseOperatorImpl{} - - if len(expression) > 0 { - caseExp.expression = expression[0] + caseExp := &caseOperatorImpl{ + expression: singleOptional(expression), } caseExp.ExpressionInterfaceImpl.Root = caseExp diff --git a/internal/jet/order_set_aggregate_functions.go b/internal/jet/order_set_aggregate_functions.go index 288f8a3..952afa0 100644 --- a/internal/jet/order_set_aggregate_functions.go +++ b/internal/jet/order_set_aggregate_functions.go @@ -34,25 +34,20 @@ func newOrderSetAggregateFunction(name string, fraction FloatExpression) *OrderS // WITHIN_GROUP_ORDER_BY specifies ordered set of aggregated argument values func (p *OrderSetAggregateFunc) WITHIN_GROUP_ORDER_BY(orderBy OrderByClause) Expression { p.orderBy = ORDER_BY(orderBy) - return newOrderSetAggregateFuncExpression(*p) + return newOrderSetAggregateFuncExpression(p) } -func newOrderSetAggregateFuncExpression(aggFunc OrderSetAggregateFunc) *orderSetAggregateFuncExpression { - ret := &orderSetAggregateFuncExpression{ +func newOrderSetAggregateFuncExpression(aggFunc *OrderSetAggregateFunc) Expression { + return newExpression(&orderSetAggregateFuncSerializer{ OrderSetAggregateFunc: aggFunc, - } - - ret.ExpressionInterfaceImpl.Root = ret - - return ret + }) } -type orderSetAggregateFuncExpression struct { - ExpressionInterfaceImpl - OrderSetAggregateFunc +type orderSetAggregateFuncSerializer struct { + *OrderSetAggregateFunc } -func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (p *orderSetAggregateFuncSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString(p.name) if p.fraction != nil { diff --git a/internal/jet/raw_statement.go b/internal/jet/raw_statement.go index ec551f0..7d73c6f 100644 --- a/internal/jet/raw_statement.go +++ b/internal/jet/raw_statement.go @@ -15,11 +15,8 @@ func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string] statementType: "", root: nil, }, - RawQuery: rawQuery, - } - - if len(namedArgument) > 0 { - newRawStatement.NamedArguments = namedArgument[0] + RawQuery: rawQuery, + NamedArguments: singleOptional(namedArgument), } newRawStatement.root = &newRawStatement diff --git a/internal/jet/row_expression.go b/internal/jet/row_expression.go index 6df1c60..5df6199 100644 --- a/internal/jet/row_expression.go +++ b/internal/jet/row_expression.go @@ -74,7 +74,7 @@ func newRowExpression(name string, dialect Dialect, expressions ...Expression) R ret := &rowExpressionWrapper{} ret.rowInterfaceImpl.root = ret - ret.Expression = NewFunc(name, expressions, ret) + ret.Expression = newFunc(name, expressions) ret.dialect = dialect ret.expressions = expressions diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index d876f36..d6e09db 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -1,5 +1,7 @@ package jet +import "slices" + // SerializeOption type type SerializeOption int @@ -73,6 +75,12 @@ func FallTrough(options []SerializeOption) []SerializeOption { return ret } +func without(options []SerializeOption, option SerializeOption) []SerializeOption { + return slices.DeleteFunc(options, func(elem SerializeOption) bool { + return elem == option + }) +} + // ListSerializer serializes list of serializers with separator type ListSerializer struct { Serializers []Serializer @@ -109,3 +117,54 @@ type Token string func (t Token) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString(string(t)) } + +// CustomExpression creates new custom expression. When serialized may require parentheses +// depending on context. +func CustomExpression(parts ...Serializer) Expression { + return newExpression(&customSerializer{ + parts: parts, + }) +} + +// AtomicCustomExpression creates new custom expression. When serialized does not require parentheses. +func AtomicCustomExpression(parts ...Serializer) Expression { + return newExpression(&customSerializer{ + parts: parts, + atomic: true, + }) +} + +type customSerializer struct { + parts []Serializer + atomic bool +} + +func (c *customSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + if c.atomic { + for _, expr := range c.parts { + expr.serialize(statement, out, without(options, NoWrap)...) + } + } else { + optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) { + for _, expr := range c.parts { + expr.serialize(statement, out, options...) + } + }) + } +} + +func optionalWrap(out *SQLBuilder, options []SerializeOption, ser func(out *SQLBuilder, options []SerializeOption)) { + if !contains(options, NoWrap) { + out.WriteString("(") + } + + ser(out, without(options, NoWrap)) + + if !contains(options, NoWrap) { + out.WriteString(")") + } +} + +func wrap(expressions ...Expression) Expression { + return newFunc("", expressions) +} diff --git a/internal/jet/string_expression.go b/internal/jet/string_expression.go index 5d373ca..d032aad 100644 --- a/internal/jet/string_expression.go +++ b/internal/jet/string_expression.go @@ -85,11 +85,33 @@ func (s *stringInterfaceImpl) NOT_LIKE(pattern StringExpression) BoolExpression } func (s *stringInterfaceImpl) REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression { - return newBinaryBoolOperatorExpression(s.root, pattern, StringRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) + return BoolExp(newExpression(®expLikeSerializer{ + str: s.root, + pattern: pattern, + caseSensitive: len(caseSensitive) > 0 && caseSensitive[0], + })) } func (s *stringInterfaceImpl) NOT_REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression { - return newBinaryBoolOperatorExpression(s.root, pattern, StringNotRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) + return BoolExp(newExpression(®expLikeSerializer{ + not: true, + str: s.root, + pattern: pattern, + caseSensitive: len(caseSensitive) > 0 && caseSensitive[0], + })) +} + +type regexpLikeSerializer struct { + not bool + str StringExpression + pattern StringExpression + caseSensitive bool +} + +func (r *regexpLikeSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) { + out.Dialect.RegexpLike(r.str, r.not, r.pattern, r.caseSensitive)(statement, out, options...) + }) } // ---------------------------------------------------// diff --git a/internal/jet/utils.go b/internal/jet/utils.go index b6f5355..f502970 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -121,12 +121,12 @@ func SerializeColumnExpressionNames(columns []ColumnExpression, out *SQLBuilder) } } -// ExpressionListToSerializerList converts list of expressions to list of serializers -func ExpressionListToSerializerList(expressions []Expression) []Serializer { - var ret []Serializer +// ToSerializerList converts list of expressions to list of serializers +func ToSerializerList[T Serializer](elems []T) []Serializer { + ret := make([]Serializer, len(elems)) - for _, expr := range expressions { - ret = append(ret, expr) + for i, ser := range elems { + ret[i] = ser } return ret @@ -134,10 +134,10 @@ func ExpressionListToSerializerList(expressions []Expression) []Serializer { // ToExpressionList converts list of any expressions to list of expressions func ToExpressionList[T Expression](expressions []T) []Expression { - var ret []Expression + ret := make([]Expression, len(expressions)) - for _, expression := range expressions { - ret = append(ret, expression) + for i, expr := range expressions { + ret[i] = expr } return ret @@ -145,10 +145,10 @@ func ToExpressionList[T Expression](expressions []T) []Expression { // ColumnListToProjectionList func func ColumnListToProjectionList(columns []ColumnExpression) []Projection { - var ret []Projection + ret := make([]Projection, len(columns)) - for _, column := range columns { - ret = append(ret, column) + for i, column := range columns { + ret[i] = column } return ret @@ -160,7 +160,7 @@ func ToSerializerValue(value interface{}) Serializer { return clause } - return literal(value) + return Literal(value) } // UnwindRowFromModel func @@ -189,7 +189,7 @@ func UnwindRowFromModel(columns []Column, data interface{}) []Serializer { field = reflect.Indirect(structField).Interface() } - row[i] = literal(field) + row[i] = Literal(field) } return row @@ -252,11 +252,11 @@ func OptionalOrDefaultString(defaultStr string, str ...string) string { return defaultStr } -// OptionalOrDefaultExpression will return first value from variable argument list expression or +// OptionalOrDefault will return first value from variable argument list expression or // defaultExpression if variable argument list is empty -func OptionalOrDefaultExpression(defaultExpression Expression, expression ...Expression) Expression { - if len(expression) > 0 { - return expression[0] +func OptionalOrDefault(expressions []Expression, defaultExpression Expression) Expression { + if len(expressions) > 0 { + return expressions[0] } return defaultExpression @@ -292,3 +292,13 @@ func joinAlias(tableAlias, columnAlias string) string { } return strings.TrimRight(tableAlias, ".*") + "." + columnAlias } + +func singleOptional[T any](value []T) T { + if len(value) > 0 { + return value[0] + } + + var def T + + return def +} diff --git a/internal/jet/utils_test.go b/internal/jet/utils_test.go index 86feff1..0317457 100644 --- a/internal/jet/utils_test.go +++ b/internal/jet/utils_test.go @@ -12,11 +12,12 @@ func TestOptionalOrDefaultString(t *testing.T) { } func TestOptionalOrDefaultExpression(t *testing.T) { - defaultExpression := table2ColFloat + defaultExpression := []Expression{table2ColFloat} optionalExpression := table1Col1 - require.Equal(t, OptionalOrDefaultExpression(defaultExpression), defaultExpression) - require.Equal(t, OptionalOrDefaultExpression(defaultExpression, optionalExpression), optionalExpression) + require.Equal(t, OptionalOrDefault(defaultExpression, nil), table2ColFloat) + require.Equal(t, OptionalOrDefault(defaultExpression, optionalExpression), table2ColFloat) + require.Equal(t, OptionalOrDefault(nil, optionalExpression), table1Col1) } func TestJoinAlias(t *testing.T) { diff --git a/internal/jet/window_expression.go b/internal/jet/window_expression.go index e5f18c1..0bbfc60 100644 --- a/internal/jet/window_expression.go +++ b/internal/jet/window_expression.go @@ -28,12 +28,13 @@ type windowExpression interface { OVER(window ...Window) Expression } -func newWindowExpression(Exp Expression) windowExpression { +func newWindowExpression(exp Expression) windowExpression { newExp := &windowExpressionImpl{ - Expression: Exp, + Expression: exp, } - newExp.commonWindowImpl.expression = Exp + newExp.commonWindowImpl.expression = exp + exp.setRoot(newExp) return newExp } @@ -65,6 +66,7 @@ func newFloatWindowExpression(floatExp FloatExpression) floatWindowExpression { } newExp.commonWindowImpl.expression = floatExp + floatExp.setRoot(newExp) return newExp } @@ -96,6 +98,7 @@ func newIntegerWindowExpression(intExp IntegerExpression) integerWindowExpressio } newExp.commonWindowImpl.expression = intExp + intExp.setRoot(newExp) return newExp } @@ -127,6 +130,7 @@ func newBoolWindowExpression(boolExp BoolExpression) boolWindowExpression { } newExp.commonWindowImpl.expression = boolExp + boolExp.setRoot(newExp) return newExp } diff --git a/mysql/cast.go b/mysql/cast.go index fbce06c..05dc61c 100644 --- a/mysql/cast.go +++ b/mysql/cast.go @@ -1,25 +1,25 @@ package mysql import ( - "github.com/go-jet/jet/v2/internal/jet" "strconv" + + "github.com/go-jet/jet/v2/internal/jet" ) -type cast struct { - jet.Cast +// CAST function converts an expr (of any type) into later specified datatype. +func CAST(expr Expression) *cast { + return &cast{ + expr: expr, + } } -// CAST function converts a expr (of any type) into latter specified datatype. -func CAST(expr Expression) *cast { - ret := &cast{} - ret.Cast = jet.NewCastImpl(expr) - - return ret +type cast struct { + expr Expression } // AS casts expressions to castType func (c *cast) AS(castType string) Expression { - return c.Cast.AS(castType) + return jet.AtomicCustomExpression(Token("CAST("), c.expr, Token("AS "+castType+")")) } // AS_DATETIME cast expression to DATETIME type diff --git a/mysql/dialect.go b/mysql/dialect.go index bac3419..71a8e38 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -12,8 +12,6 @@ var Dialect = newDialect() func newDialect() jet.Dialect { operatorSerializeOverrides := map[string]jet.SerializeOverride{} - operatorSerializeOverrides[jet.StringRegexpLikeOperator] = mysqlREGEXPLIKEoperator - operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = mysqlNOTREGEXPLIKEoperator operatorSerializeOverrides["IS DISTINCT FROM"] = mysqlISDISTINCTFROM operatorSerializeOverrides["IS NOT DISTINCT FROM"] = mysqlISNOTDISTINCTFROM operatorSerializeOverrides["/"] = mysqlDivision @@ -42,16 +40,17 @@ func newDialect() jet.Dialect { // CustomExpression used bellow (instead DATE_FORMAT function) so that only expr is parametrized case TimestampExpression: - return CustomExpression(Token("DATE_FORMAT("), e, Token(",'%Y-%m-%dT%H:%i:%s.%fZ')")) + return jet.AtomicCustomExpression(Token("DATE_FORMAT("), e, Token(",'%Y-%m-%dT%H:%i:%s.%fZ')")) case TimeExpression: - return CustomExpression(Token("CONCAT('0000-01-01T', DATE_FORMAT("), e, Token(",'%H:%i:%s.%fZ'))")) + return jet.AtomicCustomExpression(Token("CONCAT('0000-01-01T', DATE_FORMAT("), e, Token(",'%H:%i:%s.%fZ'))")) case DateExpression: - return CustomExpression(Token("CONCAT(DATE_FORMAT("), e, Token(",'%Y-%m-%d')"), Token(", 'T00:00:00Z')")) + return jet.AtomicCustomExpression(Token("CONCAT(DATE_FORMAT("), e, Token(",'%Y-%m-%d')"), Token(", 'T00:00:00Z')")) case BoolExpression: return CustomExpression(e, Token(" = 1")) } return expr }, + RegexpLike: regexpLikeOperator, } return jet.NewDialect(mySQLDialectParams) @@ -144,20 +143,12 @@ func mysqlISDISTINCTFROM(expressions ...jet.Serializer) jet.SerializerFunc { } } -func mysqlREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { +func regexpLikeOperator(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { - if len(expressions) < 2 { - panic("jet: invalid number of expressions for operator") - } + jet.Serialize(str, statement, out, options...) - jet.Serialize(expressions[0], statement, out, options...) - - caseSensitive := false - - if len(expressions) >= 3 { - if stringLiteral, ok := expressions[2].(jet.LiteralExpression); ok { - caseSensitive = stringLiteral.Value().(bool) - } + if not { + out.WriteString("NOT") } out.WriteString("REGEXP") @@ -166,33 +157,7 @@ func mysqlREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { out.WriteString("BINARY") } - jet.Serialize(expressions[1], statement, out, options...) - } -} - -func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { - return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { - if len(expressions) < 2 { - panic("jet: invalid number of expressions for operator") - } - - jet.Serialize(expressions[0], statement, out, options...) - - caseSensitive := false - - if len(expressions) >= 3 { - if stringLiteral, ok := expressions[2].(jet.LiteralExpression); ok { - caseSensitive = stringLiteral.Value().(bool) - } - } - - out.WriteString("NOT REGEXP") - - if caseSensitive { - out.WriteString("BINARY") - } - - jet.Serialize(expressions[1], statement, out, options...) + jet.Serialize(pattern, statement, out, options...) } } diff --git a/mysql/expressions_test.go b/mysql/expressions_test.go index 39174b6..b51ee54 100644 --- a/mysql/expressions_test.go +++ b/mysql/expressions_test.go @@ -48,7 +48,7 @@ func TestRawInvalidArguments(t *testing.T) { func TestRawType(t *testing.T) { assertSerialize(t, RawBool("table.colInt < :float", RawArgs{":float": 11.22}).IS_FALSE(), - "(table.colInt < ?) IS FALSE", 11.22) + "((table.colInt < ?) IS FALSE)", 11.22) assertSerialize(t, RawFloat("table.colInt + &float", RawArgs{"&float": 11.22}).EQ(Float(3.14)), "((table.colInt + ?) = ?)", 11.22, 3.14) diff --git a/mysql/interval_literal.go b/mysql/interval_literal.go index 4fff705..a67457e 100644 --- a/mysql/interval_literal.go +++ b/mysql/interval_literal.go @@ -2,10 +2,11 @@ package mysql import ( "fmt" - "github.com/go-jet/jet/v2/internal/utils/datetime" "regexp" "time" + "github.com/go-jet/jet/v2/internal/utils/datetime" + "github.com/go-jet/jet/v2/internal/jet" ) @@ -98,7 +99,7 @@ func INTERVAL(value interface{}, unitType unitType) Interval { // INTERVALe creates new temporal interval from expresion and unit type. func INTERVALe(expr Expression, unitType unitType) Interval { - return jet.IntervalExp(CustomExpression(Token("INTERVAL"), expr, Token(unitType))) + return jet.IntervalExp(jet.AtomicCustomExpression(Token("INTERVAL"), expr, Token(unitType))) } // INTERVALd creates new temporal interval from time.Duration diff --git a/postgres/cast.go b/postgres/cast.go index 46abf39..42e1a21 100644 --- a/postgres/cast.go +++ b/postgres/cast.go @@ -7,21 +7,19 @@ import ( "github.com/go-jet/jet/v2/internal/jet" ) -type cast struct { - jet.Cast -} - // CAST function converts an expr (of any type) into later specified datatype. func CAST(expr Expression) *cast { - ret := &cast{} - ret.Cast = jet.NewCastImpl(expr) - - return ret + return &cast{ + expr: expr, + } +} + +type cast struct { + expr Expression } -// AS casts expression as castType func (b *cast) AS(castType string) Expression { - return b.Cast.AS(castType) + return jet.AtomicCustomExpression(b.expr, Token("::"+castType)) } // AS_BOOL casts expression as bool type diff --git a/postgres/dialect.go b/postgres/dialect.go index 9ffa5f7..ebcfa2e 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -3,8 +3,9 @@ package postgres import ( "encoding/hex" "fmt" - "github.com/go-jet/jet/v2/internal/jet" "strconv" + + "github.com/go-jet/jet/v2/internal/jet" ) // Dialect is implementation of postgres dialect for SQL Builder serialisation. @@ -12,15 +13,10 @@ var Dialect = newDialect() func newDialect() jet.Dialect { - operatorSerializeOverrides := map[string]jet.SerializeOverride{} - operatorSerializeOverrides[jet.StringRegexpLikeOperator] = postgresREGEXPLIKEoperator - operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = postgresNOTREGEXPLIKEoperator - operatorSerializeOverrides["CAST"] = postgresCAST - dialectParams := jet.DialectParams{ Name: "PostgreSQL", PackageName: "postgres", - OperatorSerializeOverrides: operatorSerializeOverrides, + OperatorSerializeOverrides: nil, AliasQuoteChar: '"', IdentifierQuoteChar: '"', ArgumentPlaceholder: func(ord int) string { @@ -42,12 +38,13 @@ func newDialect() jet.Dialect { case TimezExpression: return CustomExpression(Token("'0000-01-01T' || to_char('2000-10-10'::date + "), e, Token(`, 'HH24:MI:SS.USTZH:TZM')`)) case TimestampExpression: - return CustomExpression(Token("to_char("), e, Token(`, 'YYYY-MM-DD"T"HH24:MI:SS.USZ')`)) + return jet.AtomicCustomExpression(Token("to_char("), e, Token(`, 'YYYY-MM-DD"T"HH24:MI:SS.USZ')`)) case DateExpression: return CustomExpression(Token("to_char("), e, Token(`::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z'`)) } return expr }, + RegexpLike: regexpLike, } return jet.NewDialect(dialectParams) @@ -62,80 +59,23 @@ func argumentToString(value any) (string, bool) { return "", false } -func postgresCAST(expressions ...jet.Serializer) jet.SerializerFunc { +func regexpLike(str jet.StringExpression, not bool, pattern jet.StringExpression, caseSensitive bool) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { - if len(expressions) < 2 { - panic("jet: invalid number of expressions for operator") - } + jet.Serialize(str, statement, out, options...) - expression := expressions[0] + var notOperator string - litExpr, ok := expressions[1].(jet.LiteralExpression) - - if !ok { - panic("jet: cast invalid cast type") - } - - castType, ok := litExpr.Value().(string) - - if !ok { - panic("jet: cast type is not string") - } - - jet.Serialize(expression, statement, out, options...) - out.WriteString("::" + castType) - } -} - -func postgresREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { - return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { - if len(expressions) < 2 { - panic("jet: invalid number of expressions for operator") - } - - jet.Serialize(expressions[0], statement, out, options...) - - caseSensitive := false - - if len(expressions) >= 3 { - if stringLiteral, ok := expressions[2].(jet.LiteralExpression); ok { - caseSensitive = stringLiteral.Value().(bool) - } + if not { + notOperator = "!" } if caseSensitive { - out.WriteString("~") + out.WriteString(notOperator + "~") } else { - out.WriteString("~*") + out.WriteString(notOperator + "~*") } - jet.Serialize(expressions[1], statement, out, options...) - } -} - -func postgresNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { - return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { - if len(expressions) < 2 { - panic("jet: invalid number of expressions for operator") - } - - jet.Serialize(expressions[0], statement, out, options...) - - caseSensitive := false - - if len(expressions) >= 3 { - if stringLiteral, ok := expressions[2].(jet.LiteralExpression); ok { - caseSensitive = stringLiteral.Value().(bool) - } - } - - if caseSensitive { - out.WriteString("!~") - } else { - out.WriteString("!~*") - } - - jet.Serialize(expressions[1], statement, out, options...) + jet.Serialize(pattern, statement, out, options...) } } diff --git a/postgres/expressions_test.go b/postgres/expressions_test.go index 3ab361d..f8288c8 100644 --- a/postgres/expressions_test.go +++ b/postgres/expressions_test.go @@ -58,7 +58,7 @@ func TestRawInvalidArguments(t *testing.T) { func TestRawHelperMethods(t *testing.T) { assertSerialize(t, RawBool("table.colInt < :float", RawArgs{":float": 11.22}).IS_FALSE(), - "(table.colInt < $1) IS FALSE", 11.22) + "((table.colInt < $1) IS FALSE)", 11.22) assertSerialize(t, RawFloat("table.colInt + :float", RawArgs{":float": 11.22}).EQ(Float(3.14)), "((table.colInt + $1) = $2)", 11.22, 3.14) diff --git a/postgres/functions.go b/postgres/functions.go index 14a5bd7..cfd53b8 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -184,12 +184,12 @@ var CHR = jet.CHR // CONCAT adds two or more expressions together var CONCAT = func(expressions ...Expression) StringExpression { - return jet.CONCAT(explicitLiteralCasts(expressions...)...) + return jet.CONCAT(expressions...) } // CONCAT_WS adds two or more expressions together with a separator. func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression { - return jet.CONCAT_WS(explicitLiteralCast(separator), explicitLiteralCasts(expressions...)...) + return jet.CONCAT_WS(separator, expressions...) } // Character encodings for CONVERT, CONVERT_FROM and CONVERT_TO functions @@ -239,7 +239,7 @@ var DECODE = jet.DECODE // FORMAT formats the arguments according to a format string. This function is similar to the C function sprintf. func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression { - return jet.FORMAT(formatStr, explicitLiteralCasts(formatArgs...)...) + return jet.FORMAT(formatStr, formatArgs...) } // INITCAP converts the first letter of each word to upper case @@ -552,10 +552,10 @@ func DATE_TRUNC(field unit, source Expression, timezone ...string) TimestampExpr // GENERATE_SERIES generates a series of values from start to stop, with a step size of step. func GENERATE_SERIES(start Expression, stop Expression, step ...Expression) Expression { if len(step) > 0 { - return jet.NewFunc("GENERATE_SERIES", []Expression{start, stop, step[0]}, nil) + return Func("GENERATE_SERIES", start, stop, step[0]) } - return jet.NewFunc("GENERATE_SERIES", []Expression{start, stop}, nil) + return Func("GENERATE_SERIES", start, stop) } // --------------- Conditional Expressions Functions -------------// @@ -578,55 +578,19 @@ var EXISTS = jet.EXISTS // CASE create CASE operator with optional list of expressions var CASE = jet.CASE -func explicitLiteralCasts(expressions ...Expression) []jet.Expression { - ret := []jet.Expression{} - - for _, exp := range expressions { - ret = append(ret, explicitLiteralCast(exp)) - } - - return ret -} - -func explicitLiteralCast(expresion Expression) jet.Expression { - if _, ok := expresion.(jet.LiteralExpression); !ok { - return expresion - } - - switch expresion.(type) { - case jet.BoolExpression: - return CAST(expresion).AS_BOOL() - case jet.IntegerExpression: - return CAST(expresion).AS_INTEGER() - case jet.FloatExpression: - return CAST(expresion).AS_NUMERIC() - case jet.StringExpression: - return CAST(expresion).AS_TEXT() - } - - return expresion -} - // MODE computes the most frequent value of the aggregated argument var MODE = jet.MODE // PERCENTILE_CONT computes a value corresponding to the specified fraction within the ordered set of // aggregated argument values. This will interpolate between adjacent input items if needed. func PERCENTILE_CONT(fraction FloatExpression) *jet.OrderSetAggregateFunc { - return jet.PERCENTILE_CONT(castFloatLiteral(fraction)) + return jet.PERCENTILE_CONT(fraction) } // PERCENTILE_DISC computes the first value within the ordered set of aggregated argument values whose position // in the ordering equals or exceeds the specified fraction. The aggregated argument must be of a sortable type. func PERCENTILE_DISC(fraction FloatExpression) *jet.OrderSetAggregateFunc { - return jet.PERCENTILE_DISC(castFloatLiteral(fraction)) -} - -func castFloatLiteral(fraction FloatExpression) FloatExpression { - if _, ok := fraction.(jet.LiteralExpression); ok { - return CAST(fraction).AS_DOUBLE() // to make postgres aware of the type - } - return fraction + return jet.PERCENTILE_DISC(fraction) } // ----------------- Group By operators --------------------------// diff --git a/postgres/interval_literal.go b/postgres/interval_literal.go index c36f539..3b320ca 100644 --- a/postgres/interval_literal.go +++ b/postgres/interval_literal.go @@ -2,10 +2,12 @@ package postgres import ( "fmt" - "github.com/go-jet/jet/v2/internal/utils/datetime" "strconv" "strings" "time" + + "github.com/go-jet/jet/v2/internal/jet" + "github.com/go-jet/jet/v2/internal/utils/datetime" ) type quantityAndUnit = float64 @@ -44,7 +46,7 @@ func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { fields = append(fields, quantity+" "+unitString) } - return IntervalExp(CustomExpression(Token(fmt.Sprintf("INTERVAL '%s'", strings.Join(fields, " "))))) + return IntervalExp(jet.AtomicCustomExpression(Token(fmt.Sprintf("INTERVAL '%s'", strings.Join(fields, " "))))) } // INTERVALd creates interval expression from time.Duration diff --git a/postgres/interval_literal_test.go b/postgres/interval_literal_test.go index 8d6e647..ef4ae1d 100644 --- a/postgres/interval_literal_test.go +++ b/postgres/interval_literal_test.go @@ -23,7 +23,7 @@ func TestINTERVAL(t *testing.T) { assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH, 20, DAY), "INTERVAL '1 YEAR 10 MONTH 20 DAY'") assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH, 20, DAY, 3, HOUR), "INTERVAL '1 YEAR 10 MONTH 20 DAY 3 HOUR'") - assertSerialize(t, INTERVAL(1, YEAR).IS_NOT_NULL(), "INTERVAL '1 YEAR' IS NOT NULL") + assertSerialize(t, INTERVAL(1, YEAR).IS_NOT_NULL(), "(INTERVAL '1 YEAR' IS NOT NULL)") assertProjectionSerialize(t, INTERVAL(1, YEAR).AS("one year"), `INTERVAL '1 YEAR' AS "one year"`) f := 5.2 diff --git a/postgres/literal.go b/postgres/literal.go index 1e93110..046231f 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -2,9 +2,10 @@ package postgres import ( "fmt" - "github.com/lib/pq" "time" + "github.com/lib/pq" + "github.com/go-jet/jet/v2/internal/jet" ) diff --git a/sqlite/cast.go b/sqlite/cast.go index fb74820..6f53853 100644 --- a/sqlite/cast.go +++ b/sqlite/cast.go @@ -4,21 +4,20 @@ import ( "github.com/go-jet/jet/v2/internal/jet" ) -type cast struct { - jet.Cast +// CAST function converts an expr (of any type) into later specified datatype. +func CAST(expr Expression) *cast { + return &cast{ + expr: expr, + } } -// CAST function converts a expr (of any type) into latter specified datatype. -func CAST(expr Expression) *cast { - ret := &cast{} - ret.Cast = jet.NewCastImpl(expr) - - return ret +type cast struct { + expr Expression } // AS casts expressions to castType func (c *cast) AS(castType string) Expression { - return c.Cast.AS(castType) + return jet.AtomicCustomExpression(Token("CAST("), c.expr, Token("AS "+castType+")")) } // AS_TEXT cast expression to TEXT type diff --git a/sqlite/expressions_test.go b/sqlite/expressions_test.go index 0f04df6..922118a 100644 --- a/sqlite/expressions_test.go +++ b/sqlite/expressions_test.go @@ -1,8 +1,9 @@ package sqlite import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestRaw(t *testing.T) { @@ -46,7 +47,7 @@ func TestRawInvalidArguments(t *testing.T) { func TestRawType(t *testing.T) { assertSerialize(t, RawBool("table.colInt < :float", RawArgs{":float": 11.22}).IS_FALSE(), - "(table.colInt < ?) IS FALSE", 11.22) + "((table.colInt < ?) IS FALSE)", 11.22) assertSerialize(t, RawFloat("table.colInt + &float", RawArgs{"&float": 11.22}).EQ(Float(3.14)), "((table.colInt + ?) = ?)", 11.22, 3.14) diff --git a/sqlite/functions.go b/sqlite/functions.go index 92139b4..f2578c5 100644 --- a/sqlite/functions.go +++ b/sqlite/functions.go @@ -2,8 +2,9 @@ package sqlite import ( "fmt" - "github.com/go-jet/jet/v2/internal/jet" "time" + + "github.com/go-jet/jet/v2/internal/jet" ) // This functions can be used, instead of its method counterparts, to have a better indentation of a complex condition @@ -297,7 +298,7 @@ func modifier(modifierName string) func(value float64) Expression { func DATE(timeValue interface{}, modifiers ...Expression) DateExpression { exprList := getFuncExprList(timeValue, modifiers...) - return jet.NewDateFunc("DATE", exprList...) + return DateExp(Func("DATE", exprList...)) } // TIME function creates new time from time-value and zero or more time modifiers diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index ae3de1b..33f4377 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -1,13 +1,14 @@ package mysql import ( - "github.com/go-jet/jet/v2/internal/utils/ptr" - "github.com/shopspring/decimal" - "github.com/stretchr/testify/require" "strings" "testing" "time" + "github.com/go-jet/jet/v2/internal/utils/ptr" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" + "github.com/google/uuid" "github.com/go-jet/jet/v2/internal/testutils" @@ -53,8 +54,8 @@ func TestAllTypesJSON(t *testing.T) { testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(` SELECT JSON_ARRAYAGG(JSON_OBJECT( 'id', all_types.id, - 'boolean', all_types.boolean = 1, - 'booleanPtr', all_types.boolean_ptr = 1, + 'boolean', (all_types.boolean = 1), + 'booleanPtr', (all_types.boolean_ptr = 1), 'tinyInt', all_types.tiny_int, 'uTinyInt', all_types.u_tiny_int, 'smallInt', all_types.small_int, @@ -190,8 +191,8 @@ func TestExpressionOperators(t *testing.T) { ).LIMIT(2) testutils.AssertStatementSql(t, query, strings.Replace(` -SELECT all_types.'integer' IS NULL AS "result.is_null", - all_types.date_ptr IS NOT NULL AS "result.is_not_null", +SELECT (all_types.'integer' IS NULL) AS "result.is_null", + (all_types.date_ptr IS NOT NULL) AS "result.is_not_null", (all_types.small_int_ptr IN (?, ?)) AS "result.in", (all_types.small_int_ptr IN (( SELECT all_types.'integer' AS "all_types.integer" @@ -259,12 +260,12 @@ SELECT (all_types.boolean = all_types.boolean_ptr) AS "EQ1", (NOT(all_types.boolean <=> ?)) AS "distinct2", (all_types.boolean <=> all_types.boolean_ptr) AS "not_distinct_1", (all_types.boolean <=> ?) AS "NOTDISTINCT2", - all_types.boolean IS TRUE AS "ISTRUE", - all_types.boolean IS NOT TRUE AS "isnottrue", - all_types.boolean IS FALSE AS "is_False", - all_types.boolean IS NOT FALSE AS "is not false", - all_types.boolean IS UNKNOWN AS "is unknown", - all_types.boolean IS NOT UNKNOWN AS "is_not_unknown", + (all_types.boolean IS TRUE) AS "ISTRUE", + (all_types.boolean IS NOT TRUE) AS "isnottrue", + (all_types.boolean IS FALSE) AS "is_False", + (all_types.boolean IS NOT FALSE) AS "is not false", + (all_types.boolean IS UNKNOWN) AS "is unknown", + (all_types.boolean IS NOT UNKNOWN) AS "is_not_unknown", ((all_types.boolean AND all_types.boolean) = (all_types.boolean AND all_types.boolean)) AS "complex1", ((all_types.boolean OR all_types.boolean) = (all_types.boolean AND all_types.boolean)) AS "complex2" FROM test_sample.all_types; @@ -1143,7 +1144,7 @@ SELECT EXTRACT(MICROSECOND FROM CAST(? AS TIME)), EXTRACT(HOUR FROM all_types.timestamp), EXTRACT(DAY FROM all_types.date), EXTRACT(WEEK FROM all_types.timestamp), - EXTRACT(MONTH FROM all_types.timestamp + INTERVAL 1 DAY), + EXTRACT(MONTH FROM (all_types.timestamp + INTERVAL 1 DAY)), EXTRACT(QUARTER FROM all_types.timestamp), EXTRACT(YEAR FROM all_types.timestamp) = ?, EXTRACT(SECOND_MICROSECOND FROM all_types.time), @@ -1305,7 +1306,7 @@ FROM ( testutils.AssertDebugStatementSql(t, stmtJson, strings.ReplaceAll(` SELECT JSON_ARRAYAGG(JSON_OBJECT( - 'boolean', sub_query.''all_types.boolean'' = 1, + 'boolean', (sub_query.''all_types.boolean'' = 1), 'integer', sub_query.''all_types.integer'', 'double', sub_query.''all_types.double'', 'text', sub_query.''all_types.text'', diff --git a/tests/mysql/select_json_test.go b/tests/mysql/select_json_test.go index 390bfe3..d2994db 100644 --- a/tests/mysql/select_json_test.go +++ b/tests/mysql/select_json_test.go @@ -2,11 +2,12 @@ package mysql import ( "context" - "github.com/go-jet/jet/v2/qrm" "slices" "strings" "testing" + "github.com/go-jet/jet/v2/qrm" + "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model" @@ -410,7 +411,7 @@ SELECT JSON_ARRAYAGG(JSON_OBJECT( 'lastName', customers_info.''customer.last_name'', 'email', customers_info.''customer.email'', 'addressID', customers_info.''customer.address_id'', - 'active', customers_info.''customer.active'' = 1, + 'active', (customers_info.''customer.active'' = 1), 'createDate', DATE_FORMAT(customers_info.''customer.create_date'','%Y-%m-%dT%H:%i:%s.%fZ'), 'lastUpdate', DATE_FORMAT(customers_info.''customer.last_update'','%Y-%m-%dT%H:%i:%s.%fZ'), 'amount', ( diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 1f00520..46262ff 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -3,13 +3,14 @@ package postgres import ( "encoding/base64" "fmt" + "math" + "testing" + "time" + "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/go-jet/jet/v2/qrm" "github.com/lib/pq" "github.com/stretchr/testify/assert" - "math" - "testing" - "time" "github.com/stretchr/testify/require" @@ -89,12 +90,12 @@ FROM ( all_types.timestampz AS "timestampz", to_char(all_types.timestamp_ptr, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestampPtr", to_char(all_types.timestamp, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestamp", - to_char(all_types.date_ptr::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "datePtr", - to_char(all_types.date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "date", - '0000-01-01T' || to_char('2000-10-10'::date + all_types.timez_ptr, 'HH24:MI:SS.USTZH:TZM') AS "timezPtr", - '0000-01-01T' || to_char('2000-10-10'::date + all_types.timez, 'HH24:MI:SS.USTZH:TZM') AS "timez", - '0000-01-01T' || to_char('2000-10-10'::date + all_types.time_ptr, 'HH24:MI:SS.USZ') AS "timePtr", - '0000-01-01T' || to_char('2000-10-10'::date + all_types.time, 'HH24:MI:SS.USZ') AS "time", + (to_char(all_types.date_ptr::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z') AS "datePtr", + (to_char(all_types.date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z') AS "date", + ('0000-01-01T' || to_char('2000-10-10'::date + all_types.timez_ptr, 'HH24:MI:SS.USTZH:TZM')) AS "timezPtr", + ('0000-01-01T' || to_char('2000-10-10'::date + all_types.timez, 'HH24:MI:SS.USTZH:TZM')) AS "timez", + ('0000-01-01T' || to_char('2000-10-10'::date + all_types.time_ptr, 'HH24:MI:SS.USZ')) AS "timePtr", + ('0000-01-01T' || to_char('2000-10-10'::date + all_types.time, 'HH24:MI:SS.USZ')) AS "time", all_types.interval_ptr AS "intervalPtr", all_types.interval AS "interval", all_types.boolean_ptr AS "booleanPtr", @@ -480,11 +481,14 @@ func TestExpressionOperators(t *testing.T) { AllTypes.SmallIntPtr.NOT_IN(Int(11), Int16(22), NULL).AS("result.not_in"), AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"), + + Bool(true).EQ(String("foo").IS_NOT_NULL()), + Bool(true).EQ(String("foo").IS_NOT_NULL()).AS("complex"), ).LIMIT(2) testutils.AssertStatementSql(t, query, ` -SELECT all_types.integer IS NULL AS "result.is_null", - all_types.date_ptr IS NOT NULL AS "result.is_not_null", +SELECT (all_types.integer IS NULL) AS "result.is_null", + (all_types.date_ptr IS NOT NULL) AS "result.is_not_null", (all_types.small_int_ptr IN ($1::smallint, $2::smallint)) AS "result.in", (all_types.small_int_ptr IN (( SELECT all_types.integer AS "all_types.integer" @@ -497,18 +501,22 @@ SELECT all_types.integer IS NULL AS "result.is_null", (all_types.small_int_ptr NOT IN (( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types - ))) AS "result.not_in_select" + ))) AS "result.not_in_select", + $11::boolean = ($12::text IS NOT NULL), + ($13::boolean = ($14::text IS NOT NULL)) AS "complex" FROM test_sample.all_types -LIMIT $11; -`, int8(11), int8(22), 78, 56, 11, 22, 33, 44, int64(11), int16(22), int64(2)) +LIMIT $15; +`, int8(11), int8(22), 78, 56, 11, 22, 33, 44, int64(11), int16(22), true, "foo", true, "foo", int64(2)) var dest []struct { common.ExpressionTestResult `alias:"result.*"` } - err := query.Query(db, &dest) + allowUnusedColumns(func() { + err := query.Query(db, &dest) + require.NoError(t, err) + }) - require.NoError(t, err) testutils.AssertJSON(t, dest, ` [ { @@ -640,10 +648,10 @@ func TestStringOperators(t *testing.T) { LTRIM(String("Ltrim"), String("A")), RTRIM(String("rtrim")), RTRIM(AllTypes.VarChar, String("B")), - CHR(Int(65)), - CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, String("aaa"), Int(1)), - CONCAT(Bool(false), Int(1), Float(22.2), String("test test")), - CONCAT_WS(String("string1"), Int(1), Float(11.22), String("bytea"), Bool(false)), //Float(11.12)), + CHR(Int8(65)), + CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, Text("aaa"), Int8(1)), + CONCAT(Bool(false), Int16(1), Real(22.2), Text("test test")), + CONCAT_WS(Text("string1"), Int64(1), Real(11.22), Text("bytea"), Bool(false)), //Float(11.12)), CONVERT(Bytea("bytea"), UTF8, LATIN1), CONVERT(AllTypes.Bytea, UTF8, LATIN1), CONVERT_FROM(Bytea("text_in_utf8"), UTF8), @@ -904,12 +912,12 @@ SELECT (all_types.boolean = all_types.boolean_ptr) AS "EQ1", (all_types.boolean IS DISTINCT FROM $3::boolean) AS "distinct2", (all_types.boolean IS NOT DISTINCT FROM all_types.boolean_ptr) AS "not_distinct_1", (all_types.boolean IS NOT DISTINCT FROM $4::boolean) AS "NOTDISTINCT2", - all_types.boolean IS TRUE AS "ISTRUE", - all_types.boolean IS NOT TRUE AS "isnottrue", - all_types.boolean IS FALSE AS "is_False", - all_types.boolean IS NOT FALSE AS "is not false", - all_types.boolean IS UNKNOWN AS "is unknown", - all_types.boolean IS NOT UNKNOWN AS "is_not_unknown", + (all_types.boolean IS TRUE) AS "ISTRUE", + (all_types.boolean IS NOT TRUE) AS "isnottrue", + (all_types.boolean IS FALSE) AS "is_False", + (all_types.boolean IS NOT FALSE) AS "is not false", + (all_types.boolean IS UNKNOWN) AS "is unknown", + (all_types.boolean IS NOT UNKNOWN) AS "is_not_unknown", ((all_types.boolean AND all_types.boolean) = (all_types.boolean AND all_types.boolean)) AS "complex1", ((all_types.boolean OR all_types.boolean) = (all_types.boolean AND all_types.boolean)) AS "complex2" FROM test_sample.all_types @@ -1109,8 +1117,8 @@ func TestIntegerOperators(t *testing.T) { AllTypes.SmallInt.BIT_XOR(AllTypes.SmallInt).AS("bit xor 1"), AllTypes.SmallInt.BIT_XOR(Int(11)).AS("bit xor 2"), - BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"), - BIT_NOT(Int(-11)).AS("bit_not_2"), + BIT_NOT(Int32(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"), + BIT_NOT(Int32(-11)).AS("bit_not_2"), AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int8(2))).AS("bit shift left 1"), AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"), @@ -1122,8 +1130,6 @@ func TestIntegerOperators(t *testing.T) { CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"), ).LIMIT(2) - // fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, ` SELECT all_types.big_int AS "all_types.big_int", all_types.big_int_ptr AS "all_types.big_int_ptr", @@ -1165,17 +1171,17 @@ SELECT all_types.big_int AS "all_types.big_int", (all_types.small_int | $20) AS "bit or 2", (all_types.small_int # all_types.small_int) AS "bit xor 1", (all_types.small_int # $21) AS "bit xor 2", - (~ ($22 * all_types.small_int)) AS "bit_not_1", - (~ -11) AS "bit_not_2", - (all_types.small_int << (all_types.small_int / $23::smallint)) AS "bit shift left 1", - (all_types.small_int << $24) AS "bit shift left 2", - (all_types.small_int >> (all_types.small_int / $25)) AS "bit shift right 1", - (all_types.small_int >> $26) AS "bit shift right 2", + (~ ($22::integer * all_types.small_int)) AS "bit_not_1", + (~ $23::integer) AS "bit_not_2", + (all_types.small_int << (all_types.small_int / $24::smallint)) AS "bit shift left 1", + (all_types.small_int << $25) AS "bit shift left 2", + (all_types.small_int >> (all_types.small_int / $26)) AS "bit shift right 1", + (all_types.small_int >> $27) AS "bit shift right 2", ABS(all_types.big_int) AS "abs", SQRT(ABS(all_types.big_int)) AS "sqrt", CBRT(ABS(all_types.big_int)) AS "cbrt" FROM test_sample.all_types -LIMIT $27; +LIMIT $28; `) var dest []struct { @@ -1267,7 +1273,72 @@ func TestTimeExpression(t *testing.T) { NOW(), ) - // fmt.Println(query.DebugSql()) + testutils.AssertStatementSql(t, query, ` +SELECT all_types.time = all_types.time, + all_types.time = $1::time without time zone, + all_types.timez = all_types.timez_ptr, + all_types.timez = $2::time with time zone, + all_types.timestamp = all_types.timestamp_ptr, + all_types.timestamp = $3::timestamp without time zone, + all_types.timestampz = all_types.timestampz_ptr, + all_types.timestampz = $4::timestamp with time zone, + all_types.date = all_types.date_ptr, + all_types.date = $5::date, + all_types.time != all_types.time, + all_types.time != $6::time without time zone, + all_types.timez != all_types.timez_ptr, + all_types.timez != $7::time with time zone, + all_types.timestamp != all_types.timestamp_ptr, + all_types.timestamp != $8::timestamp without time zone, + all_types.timestampz != all_types.timestampz_ptr, + all_types.timestampz != $9::timestamp with time zone, + all_types.date != all_types.date_ptr, + all_types.date != $10::date, + all_types.time IS DISTINCT FROM all_types.time, + all_types.time IS DISTINCT FROM $11::time without time zone, + all_types.time IS NOT DISTINCT FROM all_types.time, + all_types.time IS NOT DISTINCT FROM $12::time without time zone, + all_types.time < all_types.time, + all_types.time < $13::time without time zone, + all_types.time <= all_types.time, + all_types.time <= $14::time without time zone, + all_types.time > all_types.time, + all_types.time > $15::time without time zone, + all_types.time >= all_types.time, + all_types.time >= $16::time without time zone, + all_types.time BETWEEN $17::time without time zone AND $18::time without time zone, + all_types.time NOT BETWEEN all_types.time_ptr AND (all_types.time + INTERVAL '2 HOUR'), + all_types.date + INTERVAL '1 HOUR', + all_types.date - INTERVAL '1 MINUTE', + all_types.time + INTERVAL '1 HOUR', + all_types.time - INTERVAL '1 MINUTE', + all_types.timez + INTERVAL '1 HOUR', + all_types.timez - INTERVAL '1 MINUTE', + all_types.timez BETWEEN $19::time with time zone AND all_types.timez_ptr, + all_types.timez NOT BETWEEN all_types.timez AND $20::time with time zone, + all_types.timestamp + INTERVAL '1 HOUR', + all_types.timestamp - INTERVAL '1 MINUTE', + all_types.timestamp BETWEEN all_types.timestamp_ptr AND $21::timestamp without time zone, + all_types.timestamp NOT BETWEEN $22::timestamp without time zone AND all_types.timestamp_ptr, + all_types.timestampz + INTERVAL '1 HOUR', + all_types.timestampz - INTERVAL '1 MINUTE', + all_types.timestamp BETWEEN all_types.timestamp_ptr AND $23::timestamp without time zone, + all_types.timestamp NOT BETWEEN all_types.timestamp_ptr AND $24::timestamp without time zone, + all_types.date - $25::text::interval, + all_types.date BETWEEN $26::date AND $27::date, + all_types.date NOT BETWEEN all_types.date_ptr AND $28::date, + CURRENT_DATE, + CURRENT_TIME, + CURRENT_TIME(2), + CURRENT_TIMESTAMP, + CURRENT_TIMESTAMP(1), + LOCALTIME, + LOCALTIME(11), + LOCALTIMESTAMP, + LOCALTIMESTAMP(4), + NOW() +FROM test_sample.all_types; +`) var dest []struct{} @@ -1339,16 +1410,16 @@ SELECT $1::time without time zone AS "time", ( SELECT row_to_json(json_records) AS "json_json" FROM ( - SELECT '0000-01-01T' || to_char('2000-10-10'::date + $11::time without time zone, 'HH24:MI:SS.USZ') AS "time", - '0000-01-01T' || to_char('2000-10-10'::date + $12::time without time zone, 'HH24:MI:SS.USZ') AS "timeWithNanoSeconds", - '0000-01-01T' || to_char('2000-10-10'::date + $13::time with time zone, 'HH24:MI:SS.USTZH:TZM') AS "timez", - '0000-01-01T' || to_char('2000-10-10'::date + $14::time with time zone, 'HH24:MI:SS.USTZH:TZM') AS "timezWithNanoSeconds", + SELECT ('0000-01-01T' || to_char('2000-10-10'::date + $11::time without time zone, 'HH24:MI:SS.USZ')) AS "time", + ('0000-01-01T' || to_char('2000-10-10'::date + $12::time without time zone, 'HH24:MI:SS.USZ')) AS "timeWithNanoSeconds", + ('0000-01-01T' || to_char('2000-10-10'::date + $13::time with time zone, 'HH24:MI:SS.USTZH:TZM')) AS "timez", + ('0000-01-01T' || to_char('2000-10-10'::date + $14::time with time zone, 'HH24:MI:SS.USTZH:TZM')) AS "timezWithNanoSeconds", to_char($15::timestamp without time zone, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestamp", to_char($16::timestamp without time zone, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestampWithNanoSeconds", $17::timestamp with time zone AS "timestampz", $18::timestamp with time zone AS "timestampzWithNanoSeconds", - to_char($19::date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "date", - '0000-01-01T' || to_char('2000-10-10'::date + ($20::time without time zone + INTERVAL '2 HOUR'), 'HH24:MI:SS.USZ') AS "timeExpression" + (to_char($19::date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z') AS "date", + ('0000-01-01T' || to_char('2000-10-10'::date + ($20::time without time zone + INTERVAL '2 HOUR'), 'HH24:MI:SS.USZ')) AS "timeExpression" ) AS json_records ) AS "json"; `) @@ -1626,21 +1697,21 @@ FROM test_sample.all_types; func TestTimeEXTRACT(t *testing.T) { stmt := SELECT( - EXTRACT(CENTURY, AllTypes.Timestampz), + EXTRACT(CENTURY, AllTypes.Timestampz).AS("century"), EXTRACT(DAY, AllTypes.Timestamp), EXTRACT(DECADE, AllTypes.Date), EXTRACT(DOW, AllTypes.TimestampzPtr), - EXTRACT(DOY, DateT(time.Now())), + EXTRACT(DOY, DateT(time.Now())).AS("date"), EXTRACT(EPOCH, TimestampT(time.Now())), - EXTRACT(HOUR, AllTypes.Time.ADD(INTERVAL(1, HOUR))), - EXTRACT(ISODOW, AllTypes.Timestampz), + EXTRACT(HOUR, AllTypes.Time.ADD(INTERVAL(1, HOUR))).AS("hour"), + EXTRACT(ISODOW, AllTypes.Date.SUB(INTERVAL(1, DAY))), EXTRACT(ISOYEAR, AllTypes.Timestampz), - EXTRACT(JULIAN, AllTypes.Timestampz).EQ(Float(3456.123)), - EXTRACT(MICROSECOND, AllTypes.Timestampz), + EXTRACT(JULIAN, AllTypes.Timestampz).EQ(Float(3456.123)).AS("microsecond_equal"), + EXTRACT(MICROSECOND, AllTypes.Timestampz).EQ(Float(123.001)), EXTRACT(MILLENNIUM, AllTypes.Timestampz), EXTRACT(MILLISECOND, AllTypes.Timez), - EXTRACT(MINUTE, INTERVAL(1, HOUR, 2, MINUTE)), - EXTRACT(MONTH, AllTypes.Timestampz), + EXTRACT(MINUTE, INTERVAL(1, HOUR, 2, MINUTE)).AS("minute_interval"), + EXTRACT(MONTH, INTERVAL(11, DAY)), EXTRACT(QUARTER, AllTypes.Timestampz), EXTRACT(SECOND, AllTypes.Timestampz), EXTRACT(TIMEZONE, AllTypes.Timestampz), @@ -1652,24 +1723,24 @@ func TestTimeEXTRACT(t *testing.T) { AllTypes, ) - // fmt.Println(stmt.Sql()) + //fmt.Println(stmt.Sql()) testutils.AssertStatementSql(t, stmt, ` -SELECT EXTRACT(CENTURY FROM all_types.timestampz), +SELECT EXTRACT(CENTURY FROM all_types.timestampz) AS "century", EXTRACT(DAY FROM all_types.timestamp), EXTRACT(DECADE FROM all_types.date), EXTRACT(DOW FROM all_types.timestampz_ptr), - EXTRACT(DOY FROM $1::date), + EXTRACT(DOY FROM $1::date) AS "date", EXTRACT(EPOCH FROM $2::timestamp without time zone), - EXTRACT(HOUR FROM all_types.time + INTERVAL '1 HOUR'), - EXTRACT(ISODOW FROM all_types.timestampz), + EXTRACT(HOUR FROM (all_types.time + INTERVAL '1 HOUR')) AS "hour", + EXTRACT(ISODOW FROM (all_types.date - INTERVAL '1 DAY')), EXTRACT(ISOYEAR FROM all_types.timestampz), - EXTRACT(JULIAN FROM all_types.timestampz) = $3, - EXTRACT(MICROSECOND FROM all_types.timestampz), + (EXTRACT(JULIAN FROM all_types.timestampz) = $3) AS "microsecond_equal", + EXTRACT(MICROSECOND FROM all_types.timestampz) = $4, EXTRACT(MILLENNIUM FROM all_types.timestampz), EXTRACT(MILLISECOND FROM all_types.timez), - EXTRACT(MINUTE FROM INTERVAL '1 HOUR 2 MINUTE'), - EXTRACT(MONTH FROM all_types.timestampz), + EXTRACT(MINUTE FROM INTERVAL '1 HOUR 2 MINUTE') AS "minute_interval", + EXTRACT(MONTH FROM INTERVAL '11 DAY'), EXTRACT(QUARTER FROM all_types.timestampz), EXTRACT(SECOND FROM all_types.timestampz), EXTRACT(TIMEZONE FROM all_types.timestampz), @@ -1947,9 +2018,9 @@ FROM ( "subQuery"."all_types.integer" AS "integer", "subQuery"."all_types.double_precision" AS "doublePrecision", "subQuery"."all_types.text" AS "text", - to_char("subQuery"."all_types.date"::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "date", - '0000-01-01T' || to_char('2000-10-10'::date + "subQuery"."all_types.time", 'HH24:MI:SS.USZ') AS "time", - '0000-01-01T' || to_char('2000-10-10'::date + "subQuery"."all_types.timez", 'HH24:MI:SS.USTZH:TZM') AS "timez", + (to_char("subQuery"."all_types.date"::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z') AS "date", + ('0000-01-01T' || to_char('2000-10-10'::date + "subQuery"."all_types.time", 'HH24:MI:SS.USZ')) AS "time", + ('0000-01-01T' || to_char('2000-10-10'::date + "subQuery"."all_types.timez", 'HH24:MI:SS.USTZH:TZM')) AS "timez", to_char("subQuery"."all_types.timestamp", 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "timestamp", "subQuery"."all_types.timestampz" AS "timestampz", "subQuery"."all_types.interval" AS "interval", diff --git a/tests/postgres/array_test.go b/tests/postgres/array_test.go index b1677c8..95317f3 100644 --- a/tests/postgres/array_test.go +++ b/tests/postgres/array_test.go @@ -1,6 +1,9 @@ package postgres import ( + "testing" + "time" + "github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/utils/ptr" . "github.com/go-jet/jet/v2/postgres" @@ -11,8 +14,6 @@ import ( "github.com/google/uuid" "github.com/lib/pq" "github.com/stretchr/testify/require" - "testing" - "time" ) func TestArraySelect(t *testing.T) { @@ -158,8 +159,8 @@ SELECT $1::boolean[] AS "bool_array", (sample_arrays.bool_array = $13::boolean[]) AS "bool_eq", (sample_arrays.text_array = sample_arrays.text_array) AS "text_eq", (sample_arrays.text_array != $14::text[]) AS "text_neq", - (sample_arrays.int4_array < $15::integer[]) IS TRUE AS "int4_lt", - (sample_arrays.int8_array <= $16::bigint[]) IS FALSE AS "int8_lteq", + ((sample_arrays.int4_array < $15::integer[]) IS TRUE) AS "int4_lt", + ((sample_arrays.int8_array <= $16::bigint[]) IS FALSE) AS "int8_lteq", (sample_arrays.real_array > $17::real[]) AS "decimal_gt", (sample_arrays.double_array >= $18::double precision[]) AS "numeric_gt_eq", (sample_arrays.bytea_array @> $19::bytea[]) AS "bytea_contains", diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index 3dad8c6..9e491b4 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -2,6 +2,9 @@ package postgres import ( "context" + "testing" + "time" + "github.com/bytedance/sonic" "github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/utils/ptr" @@ -10,8 +13,6 @@ import ( . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook/table" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook2/table" "github.com/stretchr/testify/require" - "testing" - "time" ) func TestSelectAlbum(t *testing.T) { @@ -1301,13 +1302,13 @@ func TestAggregateFunc(t *testing.T) { skipForCockroachDB(t) stmt := SELECT( - PERCENTILE_DISC(Float(0.1)).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceId).AS("percentile_disc_1"), + PERCENTILE_DISC(Double(0.1)).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceId).AS("percentile_disc_1"), PERCENTILE_DISC(Invoice.Total.DIV(Float(100))).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceDate.ASC()).AS("percentile_disc_2"), PERCENTILE_DISC(RawFloat("(select array_agg(s) from generate_series(0, 1, 0.2) as s)")). WITHIN_GROUP_ORDER_BY(Invoice.BillingAddress.DESC()).AS("percentile_disc_3"), - PERCENTILE_CONT(Float(0.3)).WITHIN_GROUP_ORDER_BY(Invoice.Total).AS("percentile_cont_1"), - PERCENTILE_CONT(Float(0.2)).WITHIN_GROUP_ORDER_BY(INTERVAL(1, HOUR).DESC()).AS("percentile_cont_interval"), + PERCENTILE_CONT(Double(0.3)).WITHIN_GROUP_ORDER_BY(Invoice.Total).AS("percentile_cont_1"), + PERCENTILE_CONT(Double(0.2)).WITHIN_GROUP_ORDER_BY(INTERVAL(1, HOUR).DESC()).AS("percentile_cont_interval"), MODE().WITHIN_GROUP_ORDER_BY(Invoice.BillingPostalCode.DESC()).AS("mode_1"), ).FROM( diff --git a/tests/postgres/range_test.go b/tests/postgres/range_test.go index cd2e583..7eda955 100644 --- a/tests/postgres/range_test.go +++ b/tests/postgres/range_test.go @@ -4,14 +4,15 @@ package postgres import ( + "math/big" + "testing" + "time" + "github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/qrm" "github.com/google/go-cmp/cmp" "github.com/jackc/pgtype" "github.com/stretchr/testify/require" - "math/big" - "testing" - "time" . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model" @@ -81,8 +82,8 @@ SELECT sample_ranges.date_range AS "sample_ranges.date_range", (sample_ranges.int4_range = sample_ranges.int4_range) AS "sample.int4eq", (sample_ranges.int8_range = int8range($1, $2, $3::text)) AS "sample.int8eq", (sample_ranges.int4_range != int4range($4, $5)) AS "sample.int4neq", - (sample_ranges.num_range < numrange($6, $7)) IS TRUE AS "sample.num_lt", - (sample_ranges.date_range <= daterange($8::date, $9)) IS FALSE AS "sample.date_lteq", + ((sample_ranges.num_range < numrange($6, $7)) IS TRUE) AS "sample.num_lt", + ((sample_ranges.date_range <= daterange($8::date, $9)) IS FALSE) AS "sample.date_lteq", (sample_ranges.timestamp_range > tsrange($10::timestamp without time zone, $11::timestamp without time zone)) AS "sample.ts_gt", (sample_ranges.timestampz_range >= tstzrange($12, $13::timestamp with time zone)) AS "sample.tstz_gteq", (sample_ranges.int4_range @> $14::integer) AS "sample.int4cont", diff --git a/tests/postgres/select_json_test.go b/tests/postgres/select_json_test.go index 3ceadf6..634f93a 100644 --- a/tests/postgres/select_json_test.go +++ b/tests/postgres/select_json_test.go @@ -1,13 +1,14 @@ package postgres import ( + "testing" + "time" + "github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/view" "github.com/stretchr/testify/require" - "testing" - "time" . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" @@ -155,7 +156,7 @@ FROM ( customer.email AS "email", customer.address_id AS "addressID", customer.activebool AS "activebool", - to_char(customer.create_date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "createDate", + (to_char(customer.create_date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z') AS "createDate", to_char(customer.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate", customer.active AS "active", ( @@ -307,7 +308,7 @@ FROM ( customer.email AS "email", customer.address_id AS "addressID", customer.activebool AS "activebool", - to_char(customer.create_date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "createDate", + (to_char(customer.create_date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z') AS "createDate", to_char(customer.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate", customer.active AS "active", ( @@ -522,7 +523,7 @@ RETURNING rental.rental_id AS "rental.rental_id", customer.email AS "email", customer.address_id AS "addressID", customer.activebool AS "activebool", - to_char(customer.create_date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z' AS "createDate", + (to_char(customer.create_date::timestamp, 'YYYY-MM-DD') || 'T00:00:00Z') AS "createDate", to_char(customer.last_update, 'YYYY-MM-DD"T"HH24:MI:SS.USZ') AS "lastUpdate", customer.active AS "active" FROM dvds.customer diff --git a/tests/sqlite/alltypes_test.go b/tests/sqlite/alltypes_test.go index b676bb4..aa6ad08 100644 --- a/tests/sqlite/alltypes_test.go +++ b/tests/sqlite/alltypes_test.go @@ -2,6 +2,10 @@ package sqlite import ( "encoding/hex" + "strings" + "testing" + "time" + "github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/utils/ptr" . "github.com/go-jet/jet/v2/sqlite" @@ -12,9 +16,6 @@ import ( "github.com/google/uuid" "github.com/shopspring/decimal" "github.com/stretchr/testify/require" - "strings" - "testing" - "time" ) func TestAllTypes(t *testing.T) { @@ -232,8 +233,8 @@ func TestExpressionOperators(t *testing.T) { ).LIMIT(2) testutils.AssertStatementSql(t, query, strings.Replace(` -SELECT all_types.integer IS NULL AS "result.is_null", - all_types.date_ptr IS NOT NULL AS "result.is_not_null", +SELECT (all_types.integer IS NULL) AS "result.is_null", + (all_types.date_ptr IS NOT NULL) AS "result.is_not_null", (all_types.small_int_ptr IN (?, ?)) AS "result.in", (all_types.small_int_ptr IN (( SELECT all_types.integer AS "all_types.integer" @@ -299,12 +300,12 @@ SELECT (all_types.boolean = all_types.boolean_ptr) AS "EQ1", (all_types.boolean IS NOT ?) AS "distinct2", (all_types.boolean IS all_types.boolean_ptr) AS "not_distinct_1", (all_types.boolean IS ?) AS "NOTDISTINCT2", - all_types.boolean IS TRUE AS "ISTRUE", - all_types.boolean IS NOT TRUE AS "isnottrue", - all_types.boolean IS FALSE AS "is_False", - all_types.boolean IS NOT FALSE AS "is not false", - all_types.boolean IS NULL AS "is unknown", - all_types.boolean IS NOT NULL AS "is_not_unknown", + (all_types.boolean IS TRUE) AS "ISTRUE", + (all_types.boolean IS NOT TRUE) AS "isnottrue", + (all_types.boolean IS FALSE) AS "is_False", + (all_types.boolean IS NOT FALSE) AS "is not false", + (all_types.boolean IS NULL) AS "is unknown", + (all_types.boolean IS NOT NULL) AS "is_not_unknown", ((all_types.boolean AND all_types.boolean) = (all_types.boolean AND all_types.boolean)) AS "complex1", ((all_types.boolean OR all_types.boolean) = (all_types.boolean AND all_types.boolean)) AS "complex2" FROM all_types; diff --git a/tests/sqlite/update_test.go b/tests/sqlite/update_test.go index ad491d2..85cc33b 100644 --- a/tests/sqlite/update_test.go +++ b/tests/sqlite/update_test.go @@ -2,11 +2,12 @@ package sqlite import ( "context" + "testing" + "time" + "github.com/go-jet/jet/v2/qrm" model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table" - "testing" - "time" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/sqlite" @@ -156,7 +157,7 @@ RETURNING link.id AS "link.id", (link.id + ?) AS "dest.binary_operator", CAST(link.id AS TEXT) AS "dest.cast_operator", (link.name LIKE ?) AS "dest.like_operator", - link.description IS NULL AS "dest.is_null", + (link.description IS NULL) AS "dest.is_null", (CASE link.name WHEN ? THEN ? WHEN ? THEN ? ELSE ? END) AS "dest.case_operator"; ` testutils.AssertStatementSql(t, stmt, expectedSQL, int32(20), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(20),