diff --git a/internal/jet/cast.go b/internal/jet/cast.go deleted file mode 100644 index 311ab88..0000000 --- a/internal/jet/cast.go +++ /dev/null @@ -1,53 +0,0 @@ -package jet - -// Cast interface -type Cast interface { - AS(castType string) Expression -} - -type castImpl struct { - expression Expression -} - -// NewCastImpl creates new generic cast -func NewCastImpl(expression Expression) Cast { - castImpl := castImpl{ - expression: expression, - } - - return &castImpl -} - -func (b *castImpl) AS(castType string) Expression { - castExp := &castExpression{ - expression: b.expression, - cast: string(castType), - } - - castExp.ExpressionInterfaceImpl.Root = castExp - - return castExp -} - -type castExpression struct { - ExpressionInterfaceImpl - - expression Expression - cast string -} - -func (b *castExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - - expression := b.expression - castType := b.cast - - if castOverride := out.Dialect.OperatorSerializeOverride("CAST"); castOverride != nil { - castOverride(expression, String(castType))(statement, out, FallTrough(options)...) - return - } - - out.WriteString("CAST(") - expression.serialize(statement, out, FallTrough(options)...) - out.WriteString("AS") - out.WriteString(castType + ")") -} diff --git a/internal/jet/cast_test.go b/internal/jet/cast_test.go deleted file mode 100644 index b72cede..0000000 --- a/internal/jet/cast_test.go +++ /dev/null @@ -1,11 +0,0 @@ -package jet - -import ( - "testing" -) - -func TestCastAS(t *testing.T) { - assertClauseSerialize(t, NewCastImpl(Int(1)).AS("boolean"), "CAST($1 AS boolean)", int64(1)) - assertClauseSerialize(t, NewCastImpl(table2Col3).AS("real"), "CAST(table2.col3 AS real)") - assertClauseSerialize(t, NewCastImpl(table2Col3.ADD(table2Col3)).AS("integer"), "CAST((table2.col3 + table2.col3) AS integer)") -} diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index 678e044..83b6fe4 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -18,6 +18,7 @@ type Dialect interface { SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc ValuesDefaultColumnName(index int) string JsonValueEncode(expr Expression) Expression + RegexpLike(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) SerializerFunc } // SerializerFunc func @@ -43,6 +44,7 @@ type DialectParams struct { SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc ValuesDefaultColumnName func(index int) string JsonValueEncode func(expr Expression) Expression + RegexpLike func(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) SerializerFunc } // NewDialect creates new dialect with params @@ -60,6 +62,7 @@ func NewDialect(params DialectParams) Dialect { serializeOrderBy: params.SerializeOrderBy, valuesDefaultColumnName: params.ValuesDefaultColumnName, jsonValueEncode: params.JsonValueEncode, + regexpLike: params.RegexpLike, } } @@ -76,6 +79,7 @@ type dialectImpl struct { serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc valuesDefaultColumnName func(index int) string jsonValueEncode func(expr Expression) Expression + regexpLike func(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) SerializerFunc } func (d *dialectImpl) Name() string { @@ -133,6 +137,21 @@ func (d *dialectImpl) JsonValueEncode(expr Expression) Expression { return d.jsonValueEncode(expr) } +func (d *dialectImpl) RegexpLike(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) SerializerFunc { + if d.regexpLike != nil { + return d.regexpLike(str, not, pattern, caseSensitive) + } + + return func(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + str.serialize(statement, out, FallTrough(options)...) + if not { + out.WriteString("NOT") + } + out.WriteString("REGEXP") + pattern.serialize(statement, out, FallTrough(options)...) + } +} + func arrayOfStringsToMapOfStrings(arr []string) map[string]bool { ret := map[string]bool{} for _, elem := range arr { diff --git a/internal/jet/expression.go b/internal/jet/expression.go index e915852..2694b7a 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -141,24 +141,27 @@ type binaryOperatorSerializer struct { } func (c *binaryOperatorSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil { - serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam) - serializeOverrideFunc(statement, out, FallTrough(options)...) - } else { - c.lhs.serialize(statement, out, FallTrough(options)...) - out.WriteString(c.operator) - c.rhs.serialize(statement, out, FallTrough(options)...) - } + optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) { + if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil { + serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam) + serializeOverrideFunc(statement, out, FallTrough(options)...) + } else { + c.lhs.serialize(statement, out, FallTrough(options)...) + out.WriteString(c.operator) + c.rhs.serialize(statement, out, FallTrough(options)...) + } + }) + } // NewBinaryOperatorExpression creates new binaryOperatorExpression func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) Expression { - return newExpression(optionalWrap(&binaryOperatorSerializer{ + return newExpression(&binaryOperatorSerializer{ lhs: lhs, rhs: rhs, additionalParam: OptionalOrDefault(additionalParam, nil), operator: operator, - })) + }) } type serializersWithOperator struct { @@ -226,24 +229,24 @@ type betweenOperatorSerializer struct { } func (b *betweenOperatorSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - b.expression.serialize(statement, out, FallTrough(options)...) - if b.notBetween { - out.WriteString("NOT") - } - out.WriteString("BETWEEN") - b.min.serialize(statement, out, FallTrough(options)...) - out.WriteString("AND") - b.max.serialize(statement, out, FallTrough(options)...) + optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) { + b.expression.serialize(statement, out, FallTrough(options)...) + if b.notBetween { + out.WriteString("NOT") + } + out.WriteString("BETWEEN") + b.min.serialize(statement, out, FallTrough(options)...) + out.WriteString("AND") + b.max.serialize(statement, out, FallTrough(options)...) + }) } // NewBetweenOperatorExpression creates new BETWEEN operator expression func NewBetweenOperatorExpression(expression, min, max Expression, notBetween bool) BoolExpression { - return BoolExp(newExpression( - optionalWrap(&betweenOperatorSerializer{ - expression: expression, - notBetween: notBetween, - min: min, - max: max, - }), - )) + return BoolExp(newExpression(&betweenOperatorSerializer{ + expression: expression, + notBetween: notBetween, + min: min, + max: max, + })) } diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index af256f3..38363b5 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -244,7 +244,7 @@ func leadLagImpl(name string, expr Expression, offsetAndDefault ...interface{}) defaultValue, ok = offsetAndDefault[1].(Expression) if !ok { - defaultValue = literal(offsetAndDefault[1]) + defaultValue = Literal(offsetAndDefault[1]) } params = append(params, FixedLiteral(offset), defaultValue) diff --git a/internal/jet/integer_expression_test.go b/internal/jet/integer_expression_test.go index a20981b..e1c850c 100644 --- a/internal/jet/integer_expression_test.go +++ b/internal/jet/integer_expression_test.go @@ -66,7 +66,7 @@ func TestIntExpressionPOW(t *testing.T) { func TestIntExpressionBIT_NOT(t *testing.T) { assertClauseSerialize(t, BIT_NOT(table2ColInt), "(~ table2.col_int)") - assertClauseSerialize(t, BIT_NOT(Int(11)), "(~ 11)") + assertClauseSerialize(t, BIT_NOT(Int(11)), "(~ $1)", int64(11)) } func TestIntExpressionBIT_AND(t *testing.T) { diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index b8baa51..f7ceedf 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -5,48 +5,12 @@ import ( "time" ) -// LiteralExpression is representation of an escaped literal -type LiteralExpression interface { - Expression - - Value() interface{} - SetConstant(constant bool) -} - -type literalExpressionImpl struct { - ExpressionInterfaceImpl - +type literalSerializer struct { value interface{} constant bool } -func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl { - exp := literalExpressionImpl{value: value} - - if len(optionalConstant) > 0 { - exp.constant = optionalConstant[0] - } - - exp.ExpressionInterfaceImpl.Root = &exp - - return &exp -} - -// Literal is injected directly to SQL query, and does not appear in parametrized argument list. -func Literal(value interface{}) *literalExpressionImpl { - exp := literal(value) - return exp -} - -// FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list. -func FixedLiteral(value interface{}) *literalExpressionImpl { - exp := literal(value) - exp.constant = true - - return exp -} - -func (l *literalExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (l *literalSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { if l.constant { out.insertConstantArgument(l.value) } else { @@ -54,260 +18,145 @@ func (l *literalExpressionImpl) serialize(statement StatementType, out *SQLBuild } } -func (l *literalExpressionImpl) Value() interface{} { - return l.value +// Literal is injected directly to SQL query, and does not appear in parametrized argument list. +func Literal(value interface{}) Expression { + return newExpression(&literalSerializer{ + value: value, + constant: false, + }) } -func (l *literalExpressionImpl) SetConstant(constant bool) { - l.constant = constant -} - -type integerLiteralExpression struct { - literalExpressionImpl - integerInterfaceImpl -} - -func intLiteral(value interface{}) IntegerExpression { - numLiteral := &integerLiteralExpression{} - - numLiteral.literalExpressionImpl = *literal(value) - - numLiteral.literalExpressionImpl.Root = numLiteral - numLiteral.integerInterfaceImpl.root = numLiteral - - return numLiteral +// FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list. +func FixedLiteral(value interface{}) Expression { + return newExpression(&literalSerializer{ + value: value, + constant: true, + }) } // Int creates a new 64 bit signed integer literal func Int(value int64) IntegerExpression { - return intLiteral(value) + return IntExp(Literal(value)) } // Int8 creates a new 8 bit signed integer literal func Int8(value int8) IntegerExpression { - return intLiteral(value) + return IntExp(Literal(value)) } // Int16 creates a new 16 bit signed integer literal func Int16(value int16) IntegerExpression { - return intLiteral(value) + return IntExp(Literal(value)) } // Int32 creates a new 32 bit signed integer literal func Int32(value int32) IntegerExpression { - return intLiteral(value) + return IntExp(Literal(value)) } // Uint8 creates a new 8 bit unsigned integer literal func Uint8(value uint8) IntegerExpression { - return intLiteral(value) + return IntExp(Literal(value)) } // Uint16 creates a new 16 bit unsigned integer literal func Uint16(value uint16) IntegerExpression { - return intLiteral(value) + return IntExp(Literal(value)) } // Uint32 creates a new 32 bit unsigned integer literal func Uint32(value uint32) IntegerExpression { - return intLiteral(value) + return IntExp(Literal(value)) } // Uint64 creates a new 64 bit unsigned integer literal func Uint64(value uint64) IntegerExpression { - return intLiteral(value) -} - -// ---------------------------------------------------// -type boolLiteralExpression struct { - boolInterfaceImpl - literalExpressionImpl + return IntExp(Literal(value)) } // Bool creates new bool literal expression func Bool(value bool) BoolExpression { - boolLiteralExpression := boolLiteralExpression{} - - boolLiteralExpression.literalExpressionImpl = *literal(value) - boolLiteralExpression.boolInterfaceImpl.root = &boolLiteralExpression - - return &boolLiteralExpression -} - -// ---------------------------------------------------// -type floatLiteral struct { - floatInterfaceImpl - literalExpressionImpl + return BoolExp(Literal(value)) } // Float creates new float literal from float64 value func Float(value float64) FloatExpression { - floatLiteral := floatLiteral{} - floatLiteral.literalExpressionImpl = *literal(value) - - floatLiteral.floatInterfaceImpl.root = &floatLiteral - - return &floatLiteral + return FloatExp(Literal(value)) } // Decimal creates new float literal from string value func Decimal(value string) FloatExpression { - floatLiteral := floatLiteral{} - floatLiteral.literalExpressionImpl = *literal(value) - - floatLiteral.floatInterfaceImpl.root = &floatLiteral - - return &floatLiteral -} - -// ---------------------------------------------------// -type stringLiteral struct { - stringInterfaceImpl - literalExpressionImpl + return FloatExp(Literal(value)) } // String creates new string literal expression func String(value string) StringExpression { - stringLiteral := stringLiteral{} - stringLiteral.literalExpressionImpl = *literal(value) - - stringLiteral.stringInterfaceImpl.root = &stringLiteral - - return &stringLiteral -} - -//---------------------------------------------------// - -type timeLiteral struct { - timeInterfaceImpl - literalExpressionImpl + return StringExp(Literal(value)) } // Time creates new time literal expression func Time(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression { - timeLiteral := &timeLiteral{} timeStr := fmt.Sprintf("%02d:%02d:%02d", hour, minute, second) timeStr += formatNanoseconds(nanoseconds...) - timeLiteral.literalExpressionImpl = *literal(timeStr) - timeLiteral.timeInterfaceImpl.root = timeLiteral - - return timeLiteral + return TimeExp(Literal(timeStr)) } // TimeT creates new time literal expression from time.Time object func TimeT(t time.Time) TimeExpression { - timeLiteral := &timeLiteral{} - timeLiteral.literalExpressionImpl = *literal(t) - timeLiteral.timeInterfaceImpl.root = timeLiteral - - return timeLiteral -} - -//---------------------------------------------------// - -type timezLiteral struct { - timezInterfaceImpl - literalExpressionImpl + return TimeExp(Literal(t)) } // Timez creates new time with time zone literal expression func Timez(hour, minute, second int, nanoseconds time.Duration, timezone string) TimezExpression { - timezLiteral := timezLiteral{} timeStr := fmt.Sprintf("%02d:%02d:%02d", hour, minute, second) timeStr += formatNanoseconds(nanoseconds) timeStr += " " + timezone - timezLiteral.literalExpressionImpl = *literal(timeStr) - return TimezExp(literal(timeStr)) + return TimezExp(Literal(timeStr)) } // TimezT creates new time with time zone literal expression from time.Time object func TimezT(t time.Time) TimezExpression { - timeLiteral := &timezLiteral{} - timeLiteral.literalExpressionImpl = *literal(t) - timeLiteral.timezInterfaceImpl.root = timeLiteral - - return timeLiteral -} - -//---------------------------------------------------// - -type timestampLiteral struct { - timestampInterfaceImpl - literalExpressionImpl + return TimezExp(Literal(t)) } // Timestamp creates new timestamp literal expression func Timestamp(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) TimestampExpression { - timestamp := ×tampLiteral{} timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second) timeStr += formatNanoseconds(nanoseconds...) - timestamp.literalExpressionImpl = *literal(timeStr) - timestamp.timestampInterfaceImpl.root = timestamp - return timestamp + + return TimestampExp(Literal(timeStr)) } // TimestampT creates new timestamp literal expression from time.Time object func TimestampT(t time.Time) TimestampExpression { - timestamp := ×tampLiteral{} - timestamp.literalExpressionImpl = *literal(t) - timestamp.timestampInterfaceImpl.root = timestamp - return timestamp -} - -//---------------------------------------------------// - -type timestampzLiteral struct { - timestampzInterfaceImpl - literalExpressionImpl + return TimestampExp(Literal(t)) } // Timestampz creates new timestamp with time zone literal expression func Timestampz(year int, month time.Month, day, hour, minute, second int, nanoseconds time.Duration, timezone string) TimestampzExpression { - timestamp := ×tampzLiteral{} timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second) timeStr += formatNanoseconds(nanoseconds) timeStr += " " + timezone - timestamp.literalExpressionImpl = *literal(timeStr) - timestamp.timestampzInterfaceImpl.root = timestamp - return timestamp + return TimestampzExp(Literal(timeStr)) } // TimestampzT creates new timestamp literal expression from time.Time object func TimestampzT(t time.Time) TimestampzExpression { - timestamp := ×tampzLiteral{} - timestamp.literalExpressionImpl = *literal(t) - timestamp.timestampzInterfaceImpl.root = timestamp - return timestamp -} - -//---------------------------------------------------// - -type dateLiteral struct { - dateInterfaceImpl - literalExpressionImpl + return TimestampzExp(Literal(t)) } // Date creates new date literal expression func Date(year int, month time.Month, day int) DateExpression { - dateLiteral := &dateLiteral{} - timeStr := fmt.Sprintf("%04d-%02d-%02d", year, month, day) - dateLiteral.literalExpressionImpl = *literal(timeStr) - dateLiteral.dateInterfaceImpl.root = dateLiteral - - return dateLiteral + return DateExp(Literal(timeStr)) } // DateT creates new date literal expression from time.Time object func DateT(t time.Time) DateExpression { - dateLiteral := &dateLiteral{} - dateLiteral.literalExpressionImpl = *literal(t) - dateLiteral.dateInterfaceImpl.root = dateLiteral - - return dateLiteral + return DateExp(Literal(t)) } func formatNanoseconds(nanoseconds ...time.Duration) string { @@ -330,86 +179,35 @@ func formatNanoseconds(nanoseconds ...time.Duration) string { var ( // NULL is jet equivalent of SQL NULL - NULL = newNullLiteral() + NULL = newExpression(Keyword("NULL")) // STAR is jet equivalent of SQL * - STAR = newStarLiteral() + STAR = newExpression(Keyword("*")) // PLUS_INFINITY is jet equivalent for sql infinity PLUS_INFINITY = String("infinity") // MINUS_INFINITY is jet equivalent for sql -infinity MINUS_INFINITY = String("-infinity") ) -type nullLiteral struct { - ExpressionInterfaceImpl -} - -func newNullLiteral() Expression { - nullExpression := &nullLiteral{} - - nullExpression.ExpressionInterfaceImpl.Root = nullExpression - - return nullExpression -} - -func (n *nullLiteral) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString("NULL") -} - -// --------------------------------------------------// -type starLiteral struct { - ExpressionInterfaceImpl -} - -func newStarLiteral() Expression { - starExpression := &starLiteral{} - - starExpression.ExpressionInterfaceImpl.Root = starExpression - - return starExpression -} - -func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString("*") -} - //---------------------------------------------------// -type rawExpression struct { - ExpressionInterfaceImpl - +type rawSerializer struct { Raw string NamedArgument map[string]interface{} - noWrap bool } -func (n *rawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if !n.noWrap && !contains(options, NoWrap) { - out.WriteByte('(') - } - - out.insertRawQuery(n.Raw, n.NamedArgument) - - if !n.noWrap && !contains(options, NoWrap) { - out.WriteByte(')') - } +func (n *rawSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) { + out.insertRawQuery(n.Raw, n.NamedArgument) + }) } // Raw can be used for any unsupported functions, operators or expressions. // For example: Raw("current_database()") func Raw(raw string, namedArgs ...map[string]interface{}) Expression { - var namedArguments map[string]interface{} - - if len(namedArgs) > 0 { - namedArguments = namedArgs[0] - } - - rawExp := &rawExpression{ + return newExpression(&rawSerializer{ Raw: raw, - NamedArgument: namedArguments, - } - rawExp.ExpressionInterfaceImpl.Root = rawExp - - return rawExp + NamedArgument: singleOptional(namedArgs), + }) } // RawBool helper that for raw string boolean expressions diff --git a/internal/jet/operators.go b/internal/jet/operators.go index b56cb1b..277fb2c 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -16,9 +16,6 @@ func NOT(exp BoolExpression) BoolExpression { // BIT_NOT inverts every bit in integer expression result func BIT_NOT(expr IntegerExpression) IntegerExpression { - if literalExp, ok := expr.(LiteralExpression); ok { - literalExp.SetConstant(true) - } return newPrefixIntegerOperatorExpression(expr, "~") } @@ -131,10 +128,8 @@ type caseOperatorImpl struct { // CASE create CASE operator with optional list of expressions func CASE(expression ...Expression) CaseOperator { - caseExp := &caseOperatorImpl{} - - if len(expression) > 0 { - caseExp.expression = expression[0] + caseExp := &caseOperatorImpl{ + expression: singleOptional(expression), } caseExp.ExpressionInterfaceImpl.Root = caseExp diff --git a/internal/jet/order_set_aggregate_functions.go b/internal/jet/order_set_aggregate_functions.go index 288f8a3..952afa0 100644 --- a/internal/jet/order_set_aggregate_functions.go +++ b/internal/jet/order_set_aggregate_functions.go @@ -34,25 +34,20 @@ func newOrderSetAggregateFunction(name string, fraction FloatExpression) *OrderS // WITHIN_GROUP_ORDER_BY specifies ordered set of aggregated argument values func (p *OrderSetAggregateFunc) WITHIN_GROUP_ORDER_BY(orderBy OrderByClause) Expression { p.orderBy = ORDER_BY(orderBy) - return newOrderSetAggregateFuncExpression(*p) + return newOrderSetAggregateFuncExpression(p) } -func newOrderSetAggregateFuncExpression(aggFunc OrderSetAggregateFunc) *orderSetAggregateFuncExpression { - ret := &orderSetAggregateFuncExpression{ +func newOrderSetAggregateFuncExpression(aggFunc *OrderSetAggregateFunc) Expression { + return newExpression(&orderSetAggregateFuncSerializer{ OrderSetAggregateFunc: aggFunc, - } - - ret.ExpressionInterfaceImpl.Root = ret - - return ret + }) } -type orderSetAggregateFuncExpression struct { - ExpressionInterfaceImpl - OrderSetAggregateFunc +type orderSetAggregateFuncSerializer struct { + *OrderSetAggregateFunc } -func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func (p *orderSetAggregateFuncSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString(p.name) if p.fraction != nil { diff --git a/internal/jet/raw_statement.go b/internal/jet/raw_statement.go index ec551f0..7d73c6f 100644 --- a/internal/jet/raw_statement.go +++ b/internal/jet/raw_statement.go @@ -15,11 +15,8 @@ func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string] statementType: "", root: nil, }, - RawQuery: rawQuery, - } - - if len(namedArgument) > 0 { - newRawStatement.NamedArguments = namedArgument[0] + RawQuery: rawQuery, + NamedArguments: singleOptional(namedArgument), } newRawStatement.root = &newRawStatement diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index 7f49745..d6e09db 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -121,64 +121,50 @@ func (t Token) serialize(statement StatementType, out *SQLBuilder, options ...Se // CustomExpression creates new custom expression. When serialized may require parentheses // depending on context. func CustomExpression(parts ...Serializer) Expression { - return newExpression(optionalWrap(&customSerializer{ + return newExpression(&customSerializer{ parts: parts, - })) + }) +} + +// AtomicCustomExpression creates new custom expression. When serialized does not require parentheses. +func AtomicCustomExpression(parts ...Serializer) Expression { + return newExpression(&customSerializer{ + parts: parts, + atomic: true, + }) } type customSerializer struct { - parts []Serializer + parts []Serializer + atomic bool } func (c *customSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - for _, expression := range c.parts { - expression.serialize(statement, out, options...) + if c.atomic { + for _, expr := range c.parts { + expr.serialize(statement, out, without(options, NoWrap)...) + } + } else { + optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) { + for _, expr := range c.parts { + expr.serialize(statement, out, options...) + } + }) } } -type optionalWrapSerializer struct { - serializer []Serializer -} - -func optionalWrap(serializer ...Serializer) Serializer { - return &optionalWrapSerializer{serializer: serializer} -} - -func (s *optionalWrapSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { +func optionalWrap(out *SQLBuilder, options []SerializeOption, ser func(out *SQLBuilder, options []SerializeOption)) { if !contains(options, NoWrap) { out.WriteString("(") } - for _, ser := range s.serializer { - ser.serialize(statement, out, without(options, NoWrap)...) - } + ser(out, without(options, NoWrap)) if !contains(options, NoWrap) { out.WriteString(")") } } -// AtomicCustomExpression creates new custom expression. When serialized does not require parentheses. -func AtomicCustomExpression(parts ...Serializer) Expression { - return newExpression(noWrap(&customSerializer{ - parts: parts, - })) -} - -type noWrapSerializer struct { - serializer []Serializer -} - -func noWrap(serializer ...Serializer) Serializer { - return &noWrapSerializer{serializer: serializer} -} - -func (s *noWrapSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - for _, ser := range s.serializer { - ser.serialize(statement, out, without(options, NoWrap)...) - } -} - func wrap(expressions ...Expression) Expression { return newFunc("", expressions) } diff --git a/internal/jet/string_expression.go b/internal/jet/string_expression.go index 5d373ca..d032aad 100644 --- a/internal/jet/string_expression.go +++ b/internal/jet/string_expression.go @@ -85,11 +85,33 @@ func (s *stringInterfaceImpl) NOT_LIKE(pattern StringExpression) BoolExpression } func (s *stringInterfaceImpl) REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression { - return newBinaryBoolOperatorExpression(s.root, pattern, StringRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) + return BoolExp(newExpression(®expLikeSerializer{ + str: s.root, + pattern: pattern, + caseSensitive: len(caseSensitive) > 0 && caseSensitive[0], + })) } func (s *stringInterfaceImpl) NOT_REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression { - return newBinaryBoolOperatorExpression(s.root, pattern, StringNotRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) + return BoolExp(newExpression(®expLikeSerializer{ + not: true, + str: s.root, + pattern: pattern, + caseSensitive: len(caseSensitive) > 0 && caseSensitive[0], + })) +} + +type regexpLikeSerializer struct { + not bool + str StringExpression + pattern StringExpression + caseSensitive bool +} + +func (r *regexpLikeSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) { + out.Dialect.RegexpLike(r.str, r.not, r.pattern, r.caseSensitive)(statement, out, options...) + }) } // ---------------------------------------------------// diff --git a/internal/jet/utils.go b/internal/jet/utils.go index aef2c8b..f502970 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -160,7 +160,7 @@ func ToSerializerValue(value interface{}) Serializer { return clause } - return literal(value) + return Literal(value) } // UnwindRowFromModel func @@ -189,7 +189,7 @@ func UnwindRowFromModel(columns []Column, data interface{}) []Serializer { field = reflect.Indirect(structField).Interface() } - row[i] = literal(field) + row[i] = Literal(field) } return row @@ -293,7 +293,7 @@ func joinAlias(tableAlias, columnAlias string) string { return strings.TrimRight(tableAlias, ".*") + "." + columnAlias } -func optional[T any](value []T) T { +func singleOptional[T any](value []T) T { if len(value) > 0 { return value[0] } diff --git a/mysql/cast.go b/mysql/cast.go index fbce06c..05dc61c 100644 --- a/mysql/cast.go +++ b/mysql/cast.go @@ -1,25 +1,25 @@ package mysql import ( - "github.com/go-jet/jet/v2/internal/jet" "strconv" + + "github.com/go-jet/jet/v2/internal/jet" ) -type cast struct { - jet.Cast +// CAST function converts an expr (of any type) into later specified datatype. +func CAST(expr Expression) *cast { + return &cast{ + expr: expr, + } } -// CAST function converts a expr (of any type) into latter specified datatype. -func CAST(expr Expression) *cast { - ret := &cast{} - ret.Cast = jet.NewCastImpl(expr) - - return ret +type cast struct { + expr Expression } // AS casts expressions to castType func (c *cast) AS(castType string) Expression { - return c.Cast.AS(castType) + return jet.AtomicCustomExpression(Token("CAST("), c.expr, Token("AS "+castType+")")) } // AS_DATETIME cast expression to DATETIME type diff --git a/mysql/dialect.go b/mysql/dialect.go index 23947df..71a8e38 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -12,8 +12,6 @@ var Dialect = newDialect() func newDialect() jet.Dialect { operatorSerializeOverrides := map[string]jet.SerializeOverride{} - operatorSerializeOverrides[jet.StringRegexpLikeOperator] = mysqlREGEXPLIKEoperator - operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = mysqlNOTREGEXPLIKEoperator operatorSerializeOverrides["IS DISTINCT FROM"] = mysqlISDISTINCTFROM operatorSerializeOverrides["IS NOT DISTINCT FROM"] = mysqlISNOTDISTINCTFROM operatorSerializeOverrides["/"] = mysqlDivision @@ -52,6 +50,7 @@ func newDialect() jet.Dialect { } return expr }, + RegexpLike: regexpLikeOperator, } return jet.NewDialect(mySQLDialectParams) @@ -144,20 +143,12 @@ func mysqlISDISTINCTFROM(expressions ...jet.Serializer) jet.SerializerFunc { } } -func mysqlREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { +func regexpLikeOperator(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { - if len(expressions) < 2 { - panic("jet: invalid number of expressions for operator") - } + jet.Serialize(str, statement, out, options...) - jet.Serialize(expressions[0], statement, out, options...) - - caseSensitive := false - - if len(expressions) >= 3 { - if stringLiteral, ok := expressions[2].(jet.LiteralExpression); ok { - caseSensitive = stringLiteral.Value().(bool) - } + if not { + out.WriteString("NOT") } out.WriteString("REGEXP") @@ -166,33 +157,7 @@ func mysqlREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { out.WriteString("BINARY") } - jet.Serialize(expressions[1], statement, out, options...) - } -} - -func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { - return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { - if len(expressions) < 2 { - panic("jet: invalid number of expressions for operator") - } - - jet.Serialize(expressions[0], statement, out, options...) - - caseSensitive := false - - if len(expressions) >= 3 { - if stringLiteral, ok := expressions[2].(jet.LiteralExpression); ok { - caseSensitive = stringLiteral.Value().(bool) - } - } - - out.WriteString("NOT REGEXP") - - if caseSensitive { - out.WriteString("BINARY") - } - - jet.Serialize(expressions[1], statement, out, options...) + jet.Serialize(pattern, statement, out, options...) } } diff --git a/mysql/expressions_test.go b/mysql/expressions_test.go index 39174b6..b51ee54 100644 --- a/mysql/expressions_test.go +++ b/mysql/expressions_test.go @@ -48,7 +48,7 @@ func TestRawInvalidArguments(t *testing.T) { func TestRawType(t *testing.T) { assertSerialize(t, RawBool("table.colInt < :float", RawArgs{":float": 11.22}).IS_FALSE(), - "(table.colInt < ?) IS FALSE", 11.22) + "((table.colInt < ?) IS FALSE)", 11.22) assertSerialize(t, RawFloat("table.colInt + &float", RawArgs{"&float": 11.22}).EQ(Float(3.14)), "((table.colInt + ?) = ?)", 11.22, 3.14) diff --git a/postgres/cast.go b/postgres/cast.go index 46abf39..42e1a21 100644 --- a/postgres/cast.go +++ b/postgres/cast.go @@ -7,21 +7,19 @@ import ( "github.com/go-jet/jet/v2/internal/jet" ) -type cast struct { - jet.Cast -} - // CAST function converts an expr (of any type) into later specified datatype. func CAST(expr Expression) *cast { - ret := &cast{} - ret.Cast = jet.NewCastImpl(expr) - - return ret + return &cast{ + expr: expr, + } +} + +type cast struct { + expr Expression } -// AS casts expression as castType func (b *cast) AS(castType string) Expression { - return b.Cast.AS(castType) + return jet.AtomicCustomExpression(b.expr, Token("::"+castType)) } // AS_BOOL casts expression as bool type diff --git a/postgres/dialect.go b/postgres/dialect.go index c5f5085..ebcfa2e 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -13,15 +13,10 @@ var Dialect = newDialect() func newDialect() jet.Dialect { - operatorSerializeOverrides := map[string]jet.SerializeOverride{} - operatorSerializeOverrides[jet.StringRegexpLikeOperator] = postgresREGEXPLIKEoperator - operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = postgresNOTREGEXPLIKEoperator - operatorSerializeOverrides["CAST"] = postgresCAST - dialectParams := jet.DialectParams{ Name: "PostgreSQL", PackageName: "postgres", - OperatorSerializeOverrides: operatorSerializeOverrides, + OperatorSerializeOverrides: nil, AliasQuoteChar: '"', IdentifierQuoteChar: '"', ArgumentPlaceholder: func(ord int) string { @@ -49,6 +44,7 @@ func newDialect() jet.Dialect { } return expr }, + RegexpLike: regexpLike, } return jet.NewDialect(dialectParams) @@ -63,80 +59,23 @@ func argumentToString(value any) (string, bool) { return "", false } -func postgresCAST(expressions ...jet.Serializer) jet.SerializerFunc { +func regexpLike(str jet.StringExpression, not bool, pattern jet.StringExpression, caseSensitive bool) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { - if len(expressions) < 2 { - panic("jet: invalid number of expressions for operator") - } + jet.Serialize(str, statement, out, options...) - expression := expressions[0] + var notOperator string - litExpr, ok := expressions[1].(jet.LiteralExpression) - - if !ok { - panic("jet: cast invalid cast type") - } - - castType, ok := litExpr.Value().(string) - - if !ok { - panic("jet: cast type is not string") - } - - jet.Serialize(expression, statement, out, options...) - out.WriteString("::" + castType) - } -} - -func postgresREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { - return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { - if len(expressions) < 2 { - panic("jet: invalid number of expressions for operator") - } - - jet.Serialize(expressions[0], statement, out, options...) - - caseSensitive := false - - if len(expressions) >= 3 { - if stringLiteral, ok := expressions[2].(jet.LiteralExpression); ok { - caseSensitive = stringLiteral.Value().(bool) - } + if not { + notOperator = "!" } if caseSensitive { - out.WriteString("~") + out.WriteString(notOperator + "~") } else { - out.WriteString("~*") + out.WriteString(notOperator + "~*") } - jet.Serialize(expressions[1], statement, out, options...) - } -} - -func postgresNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { - return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { - if len(expressions) < 2 { - panic("jet: invalid number of expressions for operator") - } - - jet.Serialize(expressions[0], statement, out, options...) - - caseSensitive := false - - if len(expressions) >= 3 { - if stringLiteral, ok := expressions[2].(jet.LiteralExpression); ok { - caseSensitive = stringLiteral.Value().(bool) - } - } - - if caseSensitive { - out.WriteString("!~") - } else { - out.WriteString("!~*") - } - - jet.Serialize(expressions[1], statement, out, options...) + jet.Serialize(pattern, statement, out, options...) } } diff --git a/postgres/expressions_test.go b/postgres/expressions_test.go index 3ab361d..f8288c8 100644 --- a/postgres/expressions_test.go +++ b/postgres/expressions_test.go @@ -58,7 +58,7 @@ func TestRawInvalidArguments(t *testing.T) { func TestRawHelperMethods(t *testing.T) { assertSerialize(t, RawBool("table.colInt < :float", RawArgs{":float": 11.22}).IS_FALSE(), - "(table.colInt < $1) IS FALSE", 11.22) + "((table.colInt < $1) IS FALSE)", 11.22) assertSerialize(t, RawFloat("table.colInt + :float", RawArgs{":float": 11.22}).EQ(Float(3.14)), "((table.colInt + $1) = $2)", 11.22, 3.14) diff --git a/postgres/functions.go b/postgres/functions.go index b76637b..cfd53b8 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -184,12 +184,12 @@ var CHR = jet.CHR // CONCAT adds two or more expressions together var CONCAT = func(expressions ...Expression) StringExpression { - return jet.CONCAT(explicitLiteralCasts(expressions...)...) + return jet.CONCAT(expressions...) } // CONCAT_WS adds two or more expressions together with a separator. func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression { - return jet.CONCAT_WS(explicitLiteralCast(separator), explicitLiteralCasts(expressions...)...) + return jet.CONCAT_WS(separator, expressions...) } // Character encodings for CONVERT, CONVERT_FROM and CONVERT_TO functions @@ -239,7 +239,7 @@ var DECODE = jet.DECODE // FORMAT formats the arguments according to a format string. This function is similar to the C function sprintf. func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression { - return jet.FORMAT(formatStr, explicitLiteralCasts(formatArgs...)...) + return jet.FORMAT(formatStr, formatArgs...) } // INITCAP converts the first letter of each word to upper case @@ -578,55 +578,19 @@ var EXISTS = jet.EXISTS // CASE create CASE operator with optional list of expressions var CASE = jet.CASE -func explicitLiteralCasts(expressions ...Expression) []jet.Expression { - ret := []jet.Expression{} - - for _, exp := range expressions { - ret = append(ret, explicitLiteralCast(exp)) - } - - return ret -} - -func explicitLiteralCast(expresion Expression) jet.Expression { - if _, ok := expresion.(jet.LiteralExpression); !ok { - return expresion - } - - switch expresion.(type) { - case jet.BoolExpression: - return CAST(expresion).AS_BOOL() - case jet.IntegerExpression: - return CAST(expresion).AS_INTEGER() - case jet.FloatExpression: - return CAST(expresion).AS_NUMERIC() - case jet.StringExpression: - return CAST(expresion).AS_TEXT() - } - - return expresion -} - // MODE computes the most frequent value of the aggregated argument var MODE = jet.MODE // PERCENTILE_CONT computes a value corresponding to the specified fraction within the ordered set of // aggregated argument values. This will interpolate between adjacent input items if needed. func PERCENTILE_CONT(fraction FloatExpression) *jet.OrderSetAggregateFunc { - return jet.PERCENTILE_CONT(castFloatLiteral(fraction)) + return jet.PERCENTILE_CONT(fraction) } // PERCENTILE_DISC computes the first value within the ordered set of aggregated argument values whose position // in the ordering equals or exceeds the specified fraction. The aggregated argument must be of a sortable type. func PERCENTILE_DISC(fraction FloatExpression) *jet.OrderSetAggregateFunc { - return jet.PERCENTILE_DISC(castFloatLiteral(fraction)) -} - -func castFloatLiteral(fraction FloatExpression) FloatExpression { - if _, ok := fraction.(jet.LiteralExpression); ok { - return CAST(fraction).AS_DOUBLE() // to make postgres aware of the type - } - return fraction + return jet.PERCENTILE_DISC(fraction) } // ----------------- Group By operators --------------------------// diff --git a/postgres/interval_literal_test.go b/postgres/interval_literal_test.go index 8d6e647..ef4ae1d 100644 --- a/postgres/interval_literal_test.go +++ b/postgres/interval_literal_test.go @@ -23,7 +23,7 @@ func TestINTERVAL(t *testing.T) { assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH, 20, DAY), "INTERVAL '1 YEAR 10 MONTH 20 DAY'") assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH, 20, DAY, 3, HOUR), "INTERVAL '1 YEAR 10 MONTH 20 DAY 3 HOUR'") - assertSerialize(t, INTERVAL(1, YEAR).IS_NOT_NULL(), "INTERVAL '1 YEAR' IS NOT NULL") + assertSerialize(t, INTERVAL(1, YEAR).IS_NOT_NULL(), "(INTERVAL '1 YEAR' IS NOT NULL)") assertProjectionSerialize(t, INTERVAL(1, YEAR).AS("one year"), `INTERVAL '1 YEAR' AS "one year"`) f := 5.2 diff --git a/postgres/literal.go b/postgres/literal.go index 1e93110..046231f 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -2,9 +2,10 @@ package postgres import ( "fmt" - "github.com/lib/pq" "time" + "github.com/lib/pq" + "github.com/go-jet/jet/v2/internal/jet" ) diff --git a/sqlite/cast.go b/sqlite/cast.go index fb74820..6f53853 100644 --- a/sqlite/cast.go +++ b/sqlite/cast.go @@ -4,21 +4,20 @@ import ( "github.com/go-jet/jet/v2/internal/jet" ) -type cast struct { - jet.Cast +// CAST function converts an expr (of any type) into later specified datatype. +func CAST(expr Expression) *cast { + return &cast{ + expr: expr, + } } -// CAST function converts a expr (of any type) into latter specified datatype. -func CAST(expr Expression) *cast { - ret := &cast{} - ret.Cast = jet.NewCastImpl(expr) - - return ret +type cast struct { + expr Expression } // AS casts expressions to castType func (c *cast) AS(castType string) Expression { - return c.Cast.AS(castType) + return jet.AtomicCustomExpression(Token("CAST("), c.expr, Token("AS "+castType+")")) } // AS_TEXT cast expression to TEXT type diff --git a/sqlite/expressions_test.go b/sqlite/expressions_test.go index 0f04df6..922118a 100644 --- a/sqlite/expressions_test.go +++ b/sqlite/expressions_test.go @@ -1,8 +1,9 @@ package sqlite import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestRaw(t *testing.T) { @@ -46,7 +47,7 @@ func TestRawInvalidArguments(t *testing.T) { func TestRawType(t *testing.T) { assertSerialize(t, RawBool("table.colInt < :float", RawArgs{":float": 11.22}).IS_FALSE(), - "(table.colInt < ?) IS FALSE", 11.22) + "((table.colInt < ?) IS FALSE)", 11.22) assertSerialize(t, RawFloat("table.colInt + &float", RawArgs{"&float": 11.22}).EQ(Float(3.14)), "((table.colInt + ?) = ?)", 11.22, 3.14) diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 59d1dd0..46262ff 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -648,10 +648,10 @@ func TestStringOperators(t *testing.T) { LTRIM(String("Ltrim"), String("A")), RTRIM(String("rtrim")), RTRIM(AllTypes.VarChar, String("B")), - CHR(Int(65)), - CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, String("aaa"), Int(1)), - CONCAT(Bool(false), Int(1), Float(22.2), String("test test")), - CONCAT_WS(String("string1"), Int(1), Float(11.22), String("bytea"), Bool(false)), //Float(11.12)), + CHR(Int8(65)), + CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, Text("aaa"), Int8(1)), + CONCAT(Bool(false), Int16(1), Real(22.2), Text("test test")), + CONCAT_WS(Text("string1"), Int64(1), Real(11.22), Text("bytea"), Bool(false)), //Float(11.12)), CONVERT(Bytea("bytea"), UTF8, LATIN1), CONVERT(AllTypes.Bytea, UTF8, LATIN1), CONVERT_FROM(Bytea("text_in_utf8"), UTF8), @@ -1117,8 +1117,8 @@ func TestIntegerOperators(t *testing.T) { AllTypes.SmallInt.BIT_XOR(AllTypes.SmallInt).AS("bit xor 1"), AllTypes.SmallInt.BIT_XOR(Int(11)).AS("bit xor 2"), - BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"), - BIT_NOT(Int(-11)).AS("bit_not_2"), + BIT_NOT(Int32(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"), + BIT_NOT(Int32(-11)).AS("bit_not_2"), AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int8(2))).AS("bit shift left 1"), AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"), @@ -1130,8 +1130,6 @@ func TestIntegerOperators(t *testing.T) { CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"), ).LIMIT(2) - // fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, ` SELECT all_types.big_int AS "all_types.big_int", all_types.big_int_ptr AS "all_types.big_int_ptr", @@ -1173,17 +1171,17 @@ SELECT all_types.big_int AS "all_types.big_int", (all_types.small_int | $20) AS "bit or 2", (all_types.small_int # all_types.small_int) AS "bit xor 1", (all_types.small_int # $21) AS "bit xor 2", - (~ ($22 * all_types.small_int)) AS "bit_not_1", - (~ -11) AS "bit_not_2", - (all_types.small_int << (all_types.small_int / $23::smallint)) AS "bit shift left 1", - (all_types.small_int << $24) AS "bit shift left 2", - (all_types.small_int >> (all_types.small_int / $25)) AS "bit shift right 1", - (all_types.small_int >> $26) AS "bit shift right 2", + (~ ($22::integer * all_types.small_int)) AS "bit_not_1", + (~ $23::integer) AS "bit_not_2", + (all_types.small_int << (all_types.small_int / $24::smallint)) AS "bit shift left 1", + (all_types.small_int << $25) AS "bit shift left 2", + (all_types.small_int >> (all_types.small_int / $26)) AS "bit shift right 1", + (all_types.small_int >> $27) AS "bit shift right 2", ABS(all_types.big_int) AS "abs", SQRT(ABS(all_types.big_int)) AS "sqrt", CBRT(ABS(all_types.big_int)) AS "cbrt" FROM test_sample.all_types -LIMIT $27; +LIMIT $28; `) var dest []struct { diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index 3dad8c6..9e491b4 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -2,6 +2,9 @@ package postgres import ( "context" + "testing" + "time" + "github.com/bytedance/sonic" "github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/utils/ptr" @@ -10,8 +13,6 @@ import ( . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook/table" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook2/table" "github.com/stretchr/testify/require" - "testing" - "time" ) func TestSelectAlbum(t *testing.T) { @@ -1301,13 +1302,13 @@ func TestAggregateFunc(t *testing.T) { skipForCockroachDB(t) stmt := SELECT( - PERCENTILE_DISC(Float(0.1)).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceId).AS("percentile_disc_1"), + PERCENTILE_DISC(Double(0.1)).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceId).AS("percentile_disc_1"), PERCENTILE_DISC(Invoice.Total.DIV(Float(100))).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceDate.ASC()).AS("percentile_disc_2"), PERCENTILE_DISC(RawFloat("(select array_agg(s) from generate_series(0, 1, 0.2) as s)")). WITHIN_GROUP_ORDER_BY(Invoice.BillingAddress.DESC()).AS("percentile_disc_3"), - PERCENTILE_CONT(Float(0.3)).WITHIN_GROUP_ORDER_BY(Invoice.Total).AS("percentile_cont_1"), - PERCENTILE_CONT(Float(0.2)).WITHIN_GROUP_ORDER_BY(INTERVAL(1, HOUR).DESC()).AS("percentile_cont_interval"), + PERCENTILE_CONT(Double(0.3)).WITHIN_GROUP_ORDER_BY(Invoice.Total).AS("percentile_cont_1"), + PERCENTILE_CONT(Double(0.2)).WITHIN_GROUP_ORDER_BY(INTERVAL(1, HOUR).DESC()).AS("percentile_cont_interval"), MODE().WITHIN_GROUP_ORDER_BY(Invoice.BillingPostalCode.DESC()).AS("mode_1"), ).FROM(