Simplify literal expressions.

This commit is contained in:
go-jet 2026-02-02 13:21:35 +01:00
parent 4995a90483
commit 0e495a279e
26 changed files with 233 additions and 616 deletions

View file

@ -1,53 +0,0 @@
package jet
// Cast interface
type Cast interface {
AS(castType string) Expression
}
type castImpl struct {
expression Expression
}
// NewCastImpl creates new generic cast
func NewCastImpl(expression Expression) Cast {
castImpl := castImpl{
expression: expression,
}
return &castImpl
}
func (b *castImpl) AS(castType string) Expression {
castExp := &castExpression{
expression: b.expression,
cast: string(castType),
}
castExp.ExpressionInterfaceImpl.Root = castExp
return castExp
}
type castExpression struct {
ExpressionInterfaceImpl
expression Expression
cast string
}
func (b *castExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
expression := b.expression
castType := b.cast
if castOverride := out.Dialect.OperatorSerializeOverride("CAST"); castOverride != nil {
castOverride(expression, String(castType))(statement, out, FallTrough(options)...)
return
}
out.WriteString("CAST(")
expression.serialize(statement, out, FallTrough(options)...)
out.WriteString("AS")
out.WriteString(castType + ")")
}

View file

@ -1,11 +0,0 @@
package jet
import (
"testing"
)
func TestCastAS(t *testing.T) {
assertClauseSerialize(t, NewCastImpl(Int(1)).AS("boolean"), "CAST($1 AS boolean)", int64(1))
assertClauseSerialize(t, NewCastImpl(table2Col3).AS("real"), "CAST(table2.col3 AS real)")
assertClauseSerialize(t, NewCastImpl(table2Col3.ADD(table2Col3)).AS("integer"), "CAST((table2.col3 + table2.col3) AS integer)")
}

View file

@ -18,6 +18,7 @@ type Dialect interface {
SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
ValuesDefaultColumnName(index int) string ValuesDefaultColumnName(index int) string
JsonValueEncode(expr Expression) Expression JsonValueEncode(expr Expression) Expression
RegexpLike(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) SerializerFunc
} }
// SerializerFunc func // SerializerFunc func
@ -43,6 +44,7 @@ type DialectParams struct {
SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
ValuesDefaultColumnName func(index int) string ValuesDefaultColumnName func(index int) string
JsonValueEncode func(expr Expression) Expression JsonValueEncode func(expr Expression) Expression
RegexpLike func(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) SerializerFunc
} }
// NewDialect creates new dialect with params // NewDialect creates new dialect with params
@ -60,6 +62,7 @@ func NewDialect(params DialectParams) Dialect {
serializeOrderBy: params.SerializeOrderBy, serializeOrderBy: params.SerializeOrderBy,
valuesDefaultColumnName: params.ValuesDefaultColumnName, valuesDefaultColumnName: params.ValuesDefaultColumnName,
jsonValueEncode: params.JsonValueEncode, jsonValueEncode: params.JsonValueEncode,
regexpLike: params.RegexpLike,
} }
} }
@ -76,6 +79,7 @@ type dialectImpl struct {
serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc
valuesDefaultColumnName func(index int) string valuesDefaultColumnName func(index int) string
jsonValueEncode func(expr Expression) Expression jsonValueEncode func(expr Expression) Expression
regexpLike func(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) SerializerFunc
} }
func (d *dialectImpl) Name() string { func (d *dialectImpl) Name() string {
@ -133,6 +137,21 @@ func (d *dialectImpl) JsonValueEncode(expr Expression) Expression {
return d.jsonValueEncode(expr) return d.jsonValueEncode(expr)
} }
func (d *dialectImpl) RegexpLike(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) SerializerFunc {
if d.regexpLike != nil {
return d.regexpLike(str, not, pattern, caseSensitive)
}
return func(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
str.serialize(statement, out, FallTrough(options)...)
if not {
out.WriteString("NOT")
}
out.WriteString("REGEXP")
pattern.serialize(statement, out, FallTrough(options)...)
}
}
func arrayOfStringsToMapOfStrings(arr []string) map[string]bool { func arrayOfStringsToMapOfStrings(arr []string) map[string]bool {
ret := map[string]bool{} ret := map[string]bool{}
for _, elem := range arr { for _, elem := range arr {

View file

@ -141,6 +141,7 @@ type binaryOperatorSerializer struct {
} }
func (c *binaryOperatorSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (c *binaryOperatorSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) {
if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil { if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam) serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam)
serializeOverrideFunc(statement, out, FallTrough(options)...) serializeOverrideFunc(statement, out, FallTrough(options)...)
@ -149,16 +150,18 @@ func (c *binaryOperatorSerializer) serialize(statement StatementType, out *SQLBu
out.WriteString(c.operator) out.WriteString(c.operator)
c.rhs.serialize(statement, out, FallTrough(options)...) c.rhs.serialize(statement, out, FallTrough(options)...)
} }
})
} }
// NewBinaryOperatorExpression creates new binaryOperatorExpression // NewBinaryOperatorExpression creates new binaryOperatorExpression
func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) Expression { func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) Expression {
return newExpression(optionalWrap(&binaryOperatorSerializer{ return newExpression(&binaryOperatorSerializer{
lhs: lhs, lhs: lhs,
rhs: rhs, rhs: rhs,
additionalParam: OptionalOrDefault(additionalParam, nil), additionalParam: OptionalOrDefault(additionalParam, nil),
operator: operator, operator: operator,
})) })
} }
type serializersWithOperator struct { type serializersWithOperator struct {
@ -226,6 +229,7 @@ type betweenOperatorSerializer struct {
} }
func (b *betweenOperatorSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (b *betweenOperatorSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) {
b.expression.serialize(statement, out, FallTrough(options)...) b.expression.serialize(statement, out, FallTrough(options)...)
if b.notBetween { if b.notBetween {
out.WriteString("NOT") out.WriteString("NOT")
@ -234,16 +238,15 @@ func (b *betweenOperatorSerializer) serialize(statement StatementType, out *SQLB
b.min.serialize(statement, out, FallTrough(options)...) b.min.serialize(statement, out, FallTrough(options)...)
out.WriteString("AND") out.WriteString("AND")
b.max.serialize(statement, out, FallTrough(options)...) b.max.serialize(statement, out, FallTrough(options)...)
})
} }
// NewBetweenOperatorExpression creates new BETWEEN operator expression // NewBetweenOperatorExpression creates new BETWEEN operator expression
func NewBetweenOperatorExpression(expression, min, max Expression, notBetween bool) BoolExpression { func NewBetweenOperatorExpression(expression, min, max Expression, notBetween bool) BoolExpression {
return BoolExp(newExpression( return BoolExp(newExpression(&betweenOperatorSerializer{
optionalWrap(&betweenOperatorSerializer{
expression: expression, expression: expression,
notBetween: notBetween, notBetween: notBetween,
min: min, min: min,
max: max, max: max,
}), }))
))
} }

View file

@ -244,7 +244,7 @@ func leadLagImpl(name string, expr Expression, offsetAndDefault ...interface{})
defaultValue, ok = offsetAndDefault[1].(Expression) defaultValue, ok = offsetAndDefault[1].(Expression)
if !ok { if !ok {
defaultValue = literal(offsetAndDefault[1]) defaultValue = Literal(offsetAndDefault[1])
} }
params = append(params, FixedLiteral(offset), defaultValue) params = append(params, FixedLiteral(offset), defaultValue)

View file

@ -66,7 +66,7 @@ func TestIntExpressionPOW(t *testing.T) {
func TestIntExpressionBIT_NOT(t *testing.T) { func TestIntExpressionBIT_NOT(t *testing.T) {
assertClauseSerialize(t, BIT_NOT(table2ColInt), "(~ table2.col_int)") assertClauseSerialize(t, BIT_NOT(table2ColInt), "(~ table2.col_int)")
assertClauseSerialize(t, BIT_NOT(Int(11)), "(~ 11)") assertClauseSerialize(t, BIT_NOT(Int(11)), "(~ $1)", int64(11))
} }
func TestIntExpressionBIT_AND(t *testing.T) { func TestIntExpressionBIT_AND(t *testing.T) {

View file

@ -5,48 +5,12 @@ import (
"time" "time"
) )
// LiteralExpression is representation of an escaped literal type literalSerializer struct {
type LiteralExpression interface {
Expression
Value() interface{}
SetConstant(constant bool)
}
type literalExpressionImpl struct {
ExpressionInterfaceImpl
value interface{} value interface{}
constant bool constant bool
} }
func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl { func (l *literalSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
exp := literalExpressionImpl{value: value}
if len(optionalConstant) > 0 {
exp.constant = optionalConstant[0]
}
exp.ExpressionInterfaceImpl.Root = &exp
return &exp
}
// Literal is injected directly to SQL query, and does not appear in parametrized argument list.
func Literal(value interface{}) *literalExpressionImpl {
exp := literal(value)
return exp
}
// FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list.
func FixedLiteral(value interface{}) *literalExpressionImpl {
exp := literal(value)
exp.constant = true
return exp
}
func (l *literalExpressionImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if l.constant { if l.constant {
out.insertConstantArgument(l.value) out.insertConstantArgument(l.value)
} else { } else {
@ -54,260 +18,145 @@ func (l *literalExpressionImpl) serialize(statement StatementType, out *SQLBuild
} }
} }
func (l *literalExpressionImpl) Value() interface{} { // Literal is injected directly to SQL query, and does not appear in parametrized argument list.
return l.value func Literal(value interface{}) Expression {
return newExpression(&literalSerializer{
value: value,
constant: false,
})
} }
func (l *literalExpressionImpl) SetConstant(constant bool) { // FixedLiteral is injected directly to SQL query, and does not appear in parametrized argument list.
l.constant = constant func FixedLiteral(value interface{}) Expression {
} return newExpression(&literalSerializer{
value: value,
type integerLiteralExpression struct { constant: true,
literalExpressionImpl })
integerInterfaceImpl
}
func intLiteral(value interface{}) IntegerExpression {
numLiteral := &integerLiteralExpression{}
numLiteral.literalExpressionImpl = *literal(value)
numLiteral.literalExpressionImpl.Root = numLiteral
numLiteral.integerInterfaceImpl.root = numLiteral
return numLiteral
} }
// Int creates a new 64 bit signed integer literal // Int creates a new 64 bit signed integer literal
func Int(value int64) IntegerExpression { func Int(value int64) IntegerExpression {
return intLiteral(value) return IntExp(Literal(value))
} }
// Int8 creates a new 8 bit signed integer literal // Int8 creates a new 8 bit signed integer literal
func Int8(value int8) IntegerExpression { func Int8(value int8) IntegerExpression {
return intLiteral(value) return IntExp(Literal(value))
} }
// Int16 creates a new 16 bit signed integer literal // Int16 creates a new 16 bit signed integer literal
func Int16(value int16) IntegerExpression { func Int16(value int16) IntegerExpression {
return intLiteral(value) return IntExp(Literal(value))
} }
// Int32 creates a new 32 bit signed integer literal // Int32 creates a new 32 bit signed integer literal
func Int32(value int32) IntegerExpression { func Int32(value int32) IntegerExpression {
return intLiteral(value) return IntExp(Literal(value))
} }
// Uint8 creates a new 8 bit unsigned integer literal // Uint8 creates a new 8 bit unsigned integer literal
func Uint8(value uint8) IntegerExpression { func Uint8(value uint8) IntegerExpression {
return intLiteral(value) return IntExp(Literal(value))
} }
// Uint16 creates a new 16 bit unsigned integer literal // Uint16 creates a new 16 bit unsigned integer literal
func Uint16(value uint16) IntegerExpression { func Uint16(value uint16) IntegerExpression {
return intLiteral(value) return IntExp(Literal(value))
} }
// Uint32 creates a new 32 bit unsigned integer literal // Uint32 creates a new 32 bit unsigned integer literal
func Uint32(value uint32) IntegerExpression { func Uint32(value uint32) IntegerExpression {
return intLiteral(value) return IntExp(Literal(value))
} }
// Uint64 creates a new 64 bit unsigned integer literal // Uint64 creates a new 64 bit unsigned integer literal
func Uint64(value uint64) IntegerExpression { func Uint64(value uint64) IntegerExpression {
return intLiteral(value) return IntExp(Literal(value))
}
// ---------------------------------------------------//
type boolLiteralExpression struct {
boolInterfaceImpl
literalExpressionImpl
} }
// Bool creates new bool literal expression // Bool creates new bool literal expression
func Bool(value bool) BoolExpression { func Bool(value bool) BoolExpression {
boolLiteralExpression := boolLiteralExpression{} return BoolExp(Literal(value))
boolLiteralExpression.literalExpressionImpl = *literal(value)
boolLiteralExpression.boolInterfaceImpl.root = &boolLiteralExpression
return &boolLiteralExpression
}
// ---------------------------------------------------//
type floatLiteral struct {
floatInterfaceImpl
literalExpressionImpl
} }
// Float creates new float literal from float64 value // Float creates new float literal from float64 value
func Float(value float64) FloatExpression { func Float(value float64) FloatExpression {
floatLiteral := floatLiteral{} return FloatExp(Literal(value))
floatLiteral.literalExpressionImpl = *literal(value)
floatLiteral.floatInterfaceImpl.root = &floatLiteral
return &floatLiteral
} }
// Decimal creates new float literal from string value // Decimal creates new float literal from string value
func Decimal(value string) FloatExpression { func Decimal(value string) FloatExpression {
floatLiteral := floatLiteral{} return FloatExp(Literal(value))
floatLiteral.literalExpressionImpl = *literal(value)
floatLiteral.floatInterfaceImpl.root = &floatLiteral
return &floatLiteral
}
// ---------------------------------------------------//
type stringLiteral struct {
stringInterfaceImpl
literalExpressionImpl
} }
// String creates new string literal expression // String creates new string literal expression
func String(value string) StringExpression { func String(value string) StringExpression {
stringLiteral := stringLiteral{} return StringExp(Literal(value))
stringLiteral.literalExpressionImpl = *literal(value)
stringLiteral.stringInterfaceImpl.root = &stringLiteral
return &stringLiteral
}
//---------------------------------------------------//
type timeLiteral struct {
timeInterfaceImpl
literalExpressionImpl
} }
// Time creates new time literal expression // Time creates new time literal expression
func Time(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression { func Time(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression {
timeLiteral := &timeLiteral{}
timeStr := fmt.Sprintf("%02d:%02d:%02d", hour, minute, second) timeStr := fmt.Sprintf("%02d:%02d:%02d", hour, minute, second)
timeStr += formatNanoseconds(nanoseconds...) timeStr += formatNanoseconds(nanoseconds...)
timeLiteral.literalExpressionImpl = *literal(timeStr)
timeLiteral.timeInterfaceImpl.root = timeLiteral return TimeExp(Literal(timeStr))
return timeLiteral
} }
// TimeT creates new time literal expression from time.Time object // TimeT creates new time literal expression from time.Time object
func TimeT(t time.Time) TimeExpression { func TimeT(t time.Time) TimeExpression {
timeLiteral := &timeLiteral{} return TimeExp(Literal(t))
timeLiteral.literalExpressionImpl = *literal(t)
timeLiteral.timeInterfaceImpl.root = timeLiteral
return timeLiteral
}
//---------------------------------------------------//
type timezLiteral struct {
timezInterfaceImpl
literalExpressionImpl
} }
// Timez creates new time with time zone literal expression // Timez creates new time with time zone literal expression
func Timez(hour, minute, second int, nanoseconds time.Duration, timezone string) TimezExpression { func Timez(hour, minute, second int, nanoseconds time.Duration, timezone string) TimezExpression {
timezLiteral := timezLiteral{}
timeStr := fmt.Sprintf("%02d:%02d:%02d", hour, minute, second) timeStr := fmt.Sprintf("%02d:%02d:%02d", hour, minute, second)
timeStr += formatNanoseconds(nanoseconds) timeStr += formatNanoseconds(nanoseconds)
timeStr += " " + timezone timeStr += " " + timezone
timezLiteral.literalExpressionImpl = *literal(timeStr)
return TimezExp(literal(timeStr)) return TimezExp(Literal(timeStr))
} }
// TimezT creates new time with time zone literal expression from time.Time object // TimezT creates new time with time zone literal expression from time.Time object
func TimezT(t time.Time) TimezExpression { func TimezT(t time.Time) TimezExpression {
timeLiteral := &timezLiteral{} return TimezExp(Literal(t))
timeLiteral.literalExpressionImpl = *literal(t)
timeLiteral.timezInterfaceImpl.root = timeLiteral
return timeLiteral
}
//---------------------------------------------------//
type timestampLiteral struct {
timestampInterfaceImpl
literalExpressionImpl
} }
// Timestamp creates new timestamp literal expression // Timestamp creates new timestamp literal expression
func Timestamp(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) TimestampExpression { func Timestamp(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) TimestampExpression {
timestamp := &timestampLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second) timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second)
timeStr += formatNanoseconds(nanoseconds...) timeStr += formatNanoseconds(nanoseconds...)
timestamp.literalExpressionImpl = *literal(timeStr)
timestamp.timestampInterfaceImpl.root = timestamp return TimestampExp(Literal(timeStr))
return timestamp
} }
// TimestampT creates new timestamp literal expression from time.Time object // TimestampT creates new timestamp literal expression from time.Time object
func TimestampT(t time.Time) TimestampExpression { func TimestampT(t time.Time) TimestampExpression {
timestamp := &timestampLiteral{} return TimestampExp(Literal(t))
timestamp.literalExpressionImpl = *literal(t)
timestamp.timestampInterfaceImpl.root = timestamp
return timestamp
}
//---------------------------------------------------//
type timestampzLiteral struct {
timestampzInterfaceImpl
literalExpressionImpl
} }
// Timestampz creates new timestamp with time zone literal expression // Timestampz creates new timestamp with time zone literal expression
func Timestampz(year int, month time.Month, day, hour, minute, second int, nanoseconds time.Duration, timezone string) TimestampzExpression { func Timestampz(year int, month time.Month, day, hour, minute, second int, nanoseconds time.Duration, timezone string) TimestampzExpression {
timestamp := &timestampzLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second) timeStr := fmt.Sprintf("%04d-%02d-%02d %02d:%02d:%02d", year, month, day, hour, minute, second)
timeStr += formatNanoseconds(nanoseconds) timeStr += formatNanoseconds(nanoseconds)
timeStr += " " + timezone timeStr += " " + timezone
timestamp.literalExpressionImpl = *literal(timeStr) return TimestampzExp(Literal(timeStr))
timestamp.timestampzInterfaceImpl.root = timestamp
return timestamp
} }
// TimestampzT creates new timestamp literal expression from time.Time object // TimestampzT creates new timestamp literal expression from time.Time object
func TimestampzT(t time.Time) TimestampzExpression { func TimestampzT(t time.Time) TimestampzExpression {
timestamp := &timestampzLiteral{} return TimestampzExp(Literal(t))
timestamp.literalExpressionImpl = *literal(t)
timestamp.timestampzInterfaceImpl.root = timestamp
return timestamp
}
//---------------------------------------------------//
type dateLiteral struct {
dateInterfaceImpl
literalExpressionImpl
} }
// Date creates new date literal expression // Date creates new date literal expression
func Date(year int, month time.Month, day int) DateExpression { func Date(year int, month time.Month, day int) DateExpression {
dateLiteral := &dateLiteral{}
timeStr := fmt.Sprintf("%04d-%02d-%02d", year, month, day) timeStr := fmt.Sprintf("%04d-%02d-%02d", year, month, day)
dateLiteral.literalExpressionImpl = *literal(timeStr) return DateExp(Literal(timeStr))
dateLiteral.dateInterfaceImpl.root = dateLiteral
return dateLiteral
} }
// DateT creates new date literal expression from time.Time object // DateT creates new date literal expression from time.Time object
func DateT(t time.Time) DateExpression { func DateT(t time.Time) DateExpression {
dateLiteral := &dateLiteral{} return DateExp(Literal(t))
dateLiteral.literalExpressionImpl = *literal(t)
dateLiteral.dateInterfaceImpl.root = dateLiteral
return dateLiteral
} }
func formatNanoseconds(nanoseconds ...time.Duration) string { func formatNanoseconds(nanoseconds ...time.Duration) string {
@ -330,86 +179,35 @@ func formatNanoseconds(nanoseconds ...time.Duration) string {
var ( var (
// NULL is jet equivalent of SQL NULL // NULL is jet equivalent of SQL NULL
NULL = newNullLiteral() NULL = newExpression(Keyword("NULL"))
// STAR is jet equivalent of SQL * // STAR is jet equivalent of SQL *
STAR = newStarLiteral() STAR = newExpression(Keyword("*"))
// PLUS_INFINITY is jet equivalent for sql infinity // PLUS_INFINITY is jet equivalent for sql infinity
PLUS_INFINITY = String("infinity") PLUS_INFINITY = String("infinity")
// MINUS_INFINITY is jet equivalent for sql -infinity // MINUS_INFINITY is jet equivalent for sql -infinity
MINUS_INFINITY = String("-infinity") MINUS_INFINITY = String("-infinity")
) )
type nullLiteral struct {
ExpressionInterfaceImpl
}
func newNullLiteral() Expression {
nullExpression := &nullLiteral{}
nullExpression.ExpressionInterfaceImpl.Root = nullExpression
return nullExpression
}
func (n *nullLiteral) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("NULL")
}
// --------------------------------------------------//
type starLiteral struct {
ExpressionInterfaceImpl
}
func newStarLiteral() Expression {
starExpression := &starLiteral{}
starExpression.ExpressionInterfaceImpl.Root = starExpression
return starExpression
}
func (n *starLiteral) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("*")
}
//---------------------------------------------------// //---------------------------------------------------//
type rawExpression struct { type rawSerializer struct {
ExpressionInterfaceImpl
Raw string Raw string
NamedArgument map[string]interface{} NamedArgument map[string]interface{}
noWrap bool
}
func (n *rawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !n.noWrap && !contains(options, NoWrap) {
out.WriteByte('(')
} }
func (n *rawSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) {
out.insertRawQuery(n.Raw, n.NamedArgument) out.insertRawQuery(n.Raw, n.NamedArgument)
})
if !n.noWrap && !contains(options, NoWrap) {
out.WriteByte(')')
}
} }
// Raw can be used for any unsupported functions, operators or expressions. // Raw can be used for any unsupported functions, operators or expressions.
// For example: Raw("current_database()") // For example: Raw("current_database()")
func Raw(raw string, namedArgs ...map[string]interface{}) Expression { func Raw(raw string, namedArgs ...map[string]interface{}) Expression {
var namedArguments map[string]interface{} return newExpression(&rawSerializer{
if len(namedArgs) > 0 {
namedArguments = namedArgs[0]
}
rawExp := &rawExpression{
Raw: raw, Raw: raw,
NamedArgument: namedArguments, NamedArgument: singleOptional(namedArgs),
} })
rawExp.ExpressionInterfaceImpl.Root = rawExp
return rawExp
} }
// RawBool helper that for raw string boolean expressions // RawBool helper that for raw string boolean expressions

View file

@ -16,9 +16,6 @@ func NOT(exp BoolExpression) BoolExpression {
// BIT_NOT inverts every bit in integer expression result // BIT_NOT inverts every bit in integer expression result
func BIT_NOT(expr IntegerExpression) IntegerExpression { func BIT_NOT(expr IntegerExpression) IntegerExpression {
if literalExp, ok := expr.(LiteralExpression); ok {
literalExp.SetConstant(true)
}
return newPrefixIntegerOperatorExpression(expr, "~") return newPrefixIntegerOperatorExpression(expr, "~")
} }
@ -131,10 +128,8 @@ type caseOperatorImpl struct {
// CASE create CASE operator with optional list of expressions // CASE create CASE operator with optional list of expressions
func CASE(expression ...Expression) CaseOperator { func CASE(expression ...Expression) CaseOperator {
caseExp := &caseOperatorImpl{} caseExp := &caseOperatorImpl{
expression: singleOptional(expression),
if len(expression) > 0 {
caseExp.expression = expression[0]
} }
caseExp.ExpressionInterfaceImpl.Root = caseExp caseExp.ExpressionInterfaceImpl.Root = caseExp

View file

@ -34,25 +34,20 @@ func newOrderSetAggregateFunction(name string, fraction FloatExpression) *OrderS
// WITHIN_GROUP_ORDER_BY specifies ordered set of aggregated argument values // WITHIN_GROUP_ORDER_BY specifies ordered set of aggregated argument values
func (p *OrderSetAggregateFunc) WITHIN_GROUP_ORDER_BY(orderBy OrderByClause) Expression { func (p *OrderSetAggregateFunc) WITHIN_GROUP_ORDER_BY(orderBy OrderByClause) Expression {
p.orderBy = ORDER_BY(orderBy) p.orderBy = ORDER_BY(orderBy)
return newOrderSetAggregateFuncExpression(*p) return newOrderSetAggregateFuncExpression(p)
} }
func newOrderSetAggregateFuncExpression(aggFunc OrderSetAggregateFunc) *orderSetAggregateFuncExpression { func newOrderSetAggregateFuncExpression(aggFunc *OrderSetAggregateFunc) Expression {
ret := &orderSetAggregateFuncExpression{ return newExpression(&orderSetAggregateFuncSerializer{
OrderSetAggregateFunc: aggFunc, OrderSetAggregateFunc: aggFunc,
})
} }
ret.ExpressionInterfaceImpl.Root = ret type orderSetAggregateFuncSerializer struct {
*OrderSetAggregateFunc
return ret
} }
type orderSetAggregateFuncExpression struct { func (p *orderSetAggregateFuncSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
ExpressionInterfaceImpl
OrderSetAggregateFunc
}
func (p *orderSetAggregateFuncExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString(p.name) out.WriteString(p.name)
if p.fraction != nil { if p.fraction != nil {

View file

@ -16,10 +16,7 @@ func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]
root: nil, root: nil,
}, },
RawQuery: rawQuery, RawQuery: rawQuery,
} NamedArguments: singleOptional(namedArgument),
if len(namedArgument) > 0 {
newRawStatement.NamedArguments = namedArgument[0]
} }
newRawStatement.root = &newRawStatement newRawStatement.root = &newRawStatement

View file

@ -121,64 +121,50 @@ func (t Token) serialize(statement StatementType, out *SQLBuilder, options ...Se
// CustomExpression creates new custom expression. When serialized may require parentheses // CustomExpression creates new custom expression. When serialized may require parentheses
// depending on context. // depending on context.
func CustomExpression(parts ...Serializer) Expression { func CustomExpression(parts ...Serializer) Expression {
return newExpression(optionalWrap(&customSerializer{ return newExpression(&customSerializer{
parts: parts, parts: parts,
})) })
}
// AtomicCustomExpression creates new custom expression. When serialized does not require parentheses.
func AtomicCustomExpression(parts ...Serializer) Expression {
return newExpression(&customSerializer{
parts: parts,
atomic: true,
})
} }
type customSerializer struct { type customSerializer struct {
parts []Serializer parts []Serializer
atomic bool
} }
func (c *customSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (c *customSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
for _, expression := range c.parts { if c.atomic {
expression.serialize(statement, out, options...) for _, expr := range c.parts {
expr.serialize(statement, out, without(options, NoWrap)...)
}
} else {
optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) {
for _, expr := range c.parts {
expr.serialize(statement, out, options...)
}
})
} }
} }
type optionalWrapSerializer struct { func optionalWrap(out *SQLBuilder, options []SerializeOption, ser func(out *SQLBuilder, options []SerializeOption)) {
serializer []Serializer
}
func optionalWrap(serializer ...Serializer) Serializer {
return &optionalWrapSerializer{serializer: serializer}
}
func (s *optionalWrapSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, NoWrap) { if !contains(options, NoWrap) {
out.WriteString("(") out.WriteString("(")
} }
for _, ser := range s.serializer { ser(out, without(options, NoWrap))
ser.serialize(statement, out, without(options, NoWrap)...)
}
if !contains(options, NoWrap) { if !contains(options, NoWrap) {
out.WriteString(")") out.WriteString(")")
} }
} }
// AtomicCustomExpression creates new custom expression. When serialized does not require parentheses.
func AtomicCustomExpression(parts ...Serializer) Expression {
return newExpression(noWrap(&customSerializer{
parts: parts,
}))
}
type noWrapSerializer struct {
serializer []Serializer
}
func noWrap(serializer ...Serializer) Serializer {
return &noWrapSerializer{serializer: serializer}
}
func (s *noWrapSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
for _, ser := range s.serializer {
ser.serialize(statement, out, without(options, NoWrap)...)
}
}
func wrap(expressions ...Expression) Expression { func wrap(expressions ...Expression) Expression {
return newFunc("", expressions) return newFunc("", expressions)
} }

View file

@ -85,11 +85,33 @@ func (s *stringInterfaceImpl) NOT_LIKE(pattern StringExpression) BoolExpression
} }
func (s *stringInterfaceImpl) REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression { func (s *stringInterfaceImpl) REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression {
return newBinaryBoolOperatorExpression(s.root, pattern, StringRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) return BoolExp(newExpression(&regexpLikeSerializer{
str: s.root,
pattern: pattern,
caseSensitive: len(caseSensitive) > 0 && caseSensitive[0],
}))
} }
func (s *stringInterfaceImpl) NOT_REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression { func (s *stringInterfaceImpl) NOT_REGEXP_LIKE(pattern StringExpression, caseSensitive ...bool) BoolExpression {
return newBinaryBoolOperatorExpression(s.root, pattern, StringNotRegexpLikeOperator, Bool(len(caseSensitive) > 0 && caseSensitive[0])) return BoolExp(newExpression(&regexpLikeSerializer{
not: true,
str: s.root,
pattern: pattern,
caseSensitive: len(caseSensitive) > 0 && caseSensitive[0],
}))
}
type regexpLikeSerializer struct {
not bool
str StringExpression
pattern StringExpression
caseSensitive bool
}
func (r *regexpLikeSerializer) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
optionalWrap(out, options, func(out *SQLBuilder, options []SerializeOption) {
out.Dialect.RegexpLike(r.str, r.not, r.pattern, r.caseSensitive)(statement, out, options...)
})
} }
// ---------------------------------------------------// // ---------------------------------------------------//

View file

@ -160,7 +160,7 @@ func ToSerializerValue(value interface{}) Serializer {
return clause return clause
} }
return literal(value) return Literal(value)
} }
// UnwindRowFromModel func // UnwindRowFromModel func
@ -189,7 +189,7 @@ func UnwindRowFromModel(columns []Column, data interface{}) []Serializer {
field = reflect.Indirect(structField).Interface() field = reflect.Indirect(structField).Interface()
} }
row[i] = literal(field) row[i] = Literal(field)
} }
return row return row
@ -293,7 +293,7 @@ func joinAlias(tableAlias, columnAlias string) string {
return strings.TrimRight(tableAlias, ".*") + "." + columnAlias return strings.TrimRight(tableAlias, ".*") + "." + columnAlias
} }
func optional[T any](value []T) T { func singleOptional[T any](value []T) T {
if len(value) > 0 { if len(value) > 0 {
return value[0] return value[0]
} }

View file

@ -1,25 +1,25 @@
package mysql package mysql
import ( import (
"github.com/go-jet/jet/v2/internal/jet"
"strconv" "strconv"
"github.com/go-jet/jet/v2/internal/jet"
) )
type cast struct { // CAST function converts an expr (of any type) into later specified datatype.
jet.Cast func CAST(expr Expression) *cast {
return &cast{
expr: expr,
}
} }
// CAST function converts a expr (of any type) into latter specified datatype. type cast struct {
func CAST(expr Expression) *cast { expr Expression
ret := &cast{}
ret.Cast = jet.NewCastImpl(expr)
return ret
} }
// AS casts expressions to castType // AS casts expressions to castType
func (c *cast) AS(castType string) Expression { func (c *cast) AS(castType string) Expression {
return c.Cast.AS(castType) return jet.AtomicCustomExpression(Token("CAST("), c.expr, Token("AS "+castType+")"))
} }
// AS_DATETIME cast expression to DATETIME type // AS_DATETIME cast expression to DATETIME type

View file

@ -12,8 +12,6 @@ var Dialect = newDialect()
func newDialect() jet.Dialect { func newDialect() jet.Dialect {
operatorSerializeOverrides := map[string]jet.SerializeOverride{} operatorSerializeOverrides := map[string]jet.SerializeOverride{}
operatorSerializeOverrides[jet.StringRegexpLikeOperator] = mysqlREGEXPLIKEoperator
operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = mysqlNOTREGEXPLIKEoperator
operatorSerializeOverrides["IS DISTINCT FROM"] = mysqlISDISTINCTFROM operatorSerializeOverrides["IS DISTINCT FROM"] = mysqlISDISTINCTFROM
operatorSerializeOverrides["IS NOT DISTINCT FROM"] = mysqlISNOTDISTINCTFROM operatorSerializeOverrides["IS NOT DISTINCT FROM"] = mysqlISNOTDISTINCTFROM
operatorSerializeOverrides["/"] = mysqlDivision operatorSerializeOverrides["/"] = mysqlDivision
@ -52,6 +50,7 @@ func newDialect() jet.Dialect {
} }
return expr return expr
}, },
RegexpLike: regexpLikeOperator,
} }
return jet.NewDialect(mySQLDialectParams) return jet.NewDialect(mySQLDialectParams)
@ -144,20 +143,12 @@ func mysqlISDISTINCTFROM(expressions ...jet.Serializer) jet.SerializerFunc {
} }
} }
func mysqlREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc { func regexpLikeOperator(str StringExpression, not bool, pattern StringExpression, caseSensitive bool) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { jet.Serialize(str, statement, out, options...)
panic("jet: invalid number of expressions for operator")
}
jet.Serialize(expressions[0], statement, out, options...) if not {
out.WriteString("NOT")
caseSensitive := false
if len(expressions) >= 3 {
if stringLiteral, ok := expressions[2].(jet.LiteralExpression); ok {
caseSensitive = stringLiteral.Value().(bool)
}
} }
out.WriteString("REGEXP") out.WriteString("REGEXP")
@ -166,33 +157,7 @@ func mysqlREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc {
out.WriteString("BINARY") out.WriteString("BINARY")
} }
jet.Serialize(expressions[1], statement, out, options...) jet.Serialize(pattern, statement, out, options...)
}
}
func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator")
}
jet.Serialize(expressions[0], statement, out, options...)
caseSensitive := false
if len(expressions) >= 3 {
if stringLiteral, ok := expressions[2].(jet.LiteralExpression); ok {
caseSensitive = stringLiteral.Value().(bool)
}
}
out.WriteString("NOT REGEXP")
if caseSensitive {
out.WriteString("BINARY")
}
jet.Serialize(expressions[1], statement, out, options...)
} }
} }

View file

@ -48,7 +48,7 @@ func TestRawInvalidArguments(t *testing.T) {
func TestRawType(t *testing.T) { func TestRawType(t *testing.T) {
assertSerialize(t, RawBool("table.colInt < :float", RawArgs{":float": 11.22}).IS_FALSE(), assertSerialize(t, RawBool("table.colInt < :float", RawArgs{":float": 11.22}).IS_FALSE(),
"(table.colInt < ?) IS FALSE", 11.22) "((table.colInt < ?) IS FALSE)", 11.22)
assertSerialize(t, RawFloat("table.colInt + &float", RawArgs{"&float": 11.22}).EQ(Float(3.14)), assertSerialize(t, RawFloat("table.colInt + &float", RawArgs{"&float": 11.22}).EQ(Float(3.14)),
"((table.colInt + ?) = ?)", 11.22, 3.14) "((table.colInt + ?) = ?)", 11.22, 3.14)

View file

@ -7,21 +7,19 @@ import (
"github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/jet"
) )
type cast struct {
jet.Cast
}
// CAST function converts an expr (of any type) into later specified datatype. // CAST function converts an expr (of any type) into later specified datatype.
func CAST(expr Expression) *cast { func CAST(expr Expression) *cast {
ret := &cast{} return &cast{
ret.Cast = jet.NewCastImpl(expr) expr: expr,
}
return ret }
type cast struct {
expr Expression
} }
// AS casts expression as castType
func (b *cast) AS(castType string) Expression { func (b *cast) AS(castType string) Expression {
return b.Cast.AS(castType) return jet.AtomicCustomExpression(b.expr, Token("::"+castType))
} }
// AS_BOOL casts expression as bool type // AS_BOOL casts expression as bool type

View file

@ -13,15 +13,10 @@ var Dialect = newDialect()
func newDialect() jet.Dialect { func newDialect() jet.Dialect {
operatorSerializeOverrides := map[string]jet.SerializeOverride{}
operatorSerializeOverrides[jet.StringRegexpLikeOperator] = postgresREGEXPLIKEoperator
operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = postgresNOTREGEXPLIKEoperator
operatorSerializeOverrides["CAST"] = postgresCAST
dialectParams := jet.DialectParams{ dialectParams := jet.DialectParams{
Name: "PostgreSQL", Name: "PostgreSQL",
PackageName: "postgres", PackageName: "postgres",
OperatorSerializeOverrides: operatorSerializeOverrides, OperatorSerializeOverrides: nil,
AliasQuoteChar: '"', AliasQuoteChar: '"',
IdentifierQuoteChar: '"', IdentifierQuoteChar: '"',
ArgumentPlaceholder: func(ord int) string { ArgumentPlaceholder: func(ord int) string {
@ -49,6 +44,7 @@ func newDialect() jet.Dialect {
} }
return expr return expr
}, },
RegexpLike: regexpLike,
} }
return jet.NewDialect(dialectParams) return jet.NewDialect(dialectParams)
@ -63,80 +59,23 @@ func argumentToString(value any) (string, bool) {
return "", false return "", false
} }
func postgresCAST(expressions ...jet.Serializer) jet.SerializerFunc { func regexpLike(str jet.StringExpression, not bool, pattern jet.StringExpression, caseSensitive bool) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 { jet.Serialize(str, statement, out, options...)
panic("jet: invalid number of expressions for operator")
}
expression := expressions[0] var notOperator string
litExpr, ok := expressions[1].(jet.LiteralExpression) if not {
notOperator = "!"
if !ok {
panic("jet: cast invalid cast type")
}
castType, ok := litExpr.Value().(string)
if !ok {
panic("jet: cast type is not string")
}
jet.Serialize(expression, statement, out, options...)
out.WriteString("::" + castType)
}
}
func postgresREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator")
}
jet.Serialize(expressions[0], statement, out, options...)
caseSensitive := false
if len(expressions) >= 3 {
if stringLiteral, ok := expressions[2].(jet.LiteralExpression); ok {
caseSensitive = stringLiteral.Value().(bool)
}
} }
if caseSensitive { if caseSensitive {
out.WriteString("~") out.WriteString(notOperator + "~")
} else { } else {
out.WriteString("~*") out.WriteString(notOperator + "~*")
} }
jet.Serialize(expressions[1], statement, out, options...) jet.Serialize(pattern, statement, out, options...)
}
}
func postgresNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator")
}
jet.Serialize(expressions[0], statement, out, options...)
caseSensitive := false
if len(expressions) >= 3 {
if stringLiteral, ok := expressions[2].(jet.LiteralExpression); ok {
caseSensitive = stringLiteral.Value().(bool)
}
}
if caseSensitive {
out.WriteString("!~")
} else {
out.WriteString("!~*")
}
jet.Serialize(expressions[1], statement, out, options...)
} }
} }

View file

@ -58,7 +58,7 @@ func TestRawInvalidArguments(t *testing.T) {
func TestRawHelperMethods(t *testing.T) { func TestRawHelperMethods(t *testing.T) {
assertSerialize(t, RawBool("table.colInt < :float", RawArgs{":float": 11.22}).IS_FALSE(), assertSerialize(t, RawBool("table.colInt < :float", RawArgs{":float": 11.22}).IS_FALSE(),
"(table.colInt < $1) IS FALSE", 11.22) "((table.colInt < $1) IS FALSE)", 11.22)
assertSerialize(t, RawFloat("table.colInt + :float", RawArgs{":float": 11.22}).EQ(Float(3.14)), assertSerialize(t, RawFloat("table.colInt + :float", RawArgs{":float": 11.22}).EQ(Float(3.14)),
"((table.colInt + $1) = $2)", 11.22, 3.14) "((table.colInt + $1) = $2)", 11.22, 3.14)

View file

@ -184,12 +184,12 @@ var CHR = jet.CHR
// CONCAT adds two or more expressions together // CONCAT adds two or more expressions together
var CONCAT = func(expressions ...Expression) StringExpression { var CONCAT = func(expressions ...Expression) StringExpression {
return jet.CONCAT(explicitLiteralCasts(expressions...)...) return jet.CONCAT(expressions...)
} }
// CONCAT_WS adds two or more expressions together with a separator. // CONCAT_WS adds two or more expressions together with a separator.
func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression { func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression {
return jet.CONCAT_WS(explicitLiteralCast(separator), explicitLiteralCasts(expressions...)...) return jet.CONCAT_WS(separator, expressions...)
} }
// Character encodings for CONVERT, CONVERT_FROM and CONVERT_TO functions // Character encodings for CONVERT, CONVERT_FROM and CONVERT_TO functions
@ -239,7 +239,7 @@ var DECODE = jet.DECODE
// FORMAT formats the arguments according to a format string. This function is similar to the C function sprintf. // FORMAT formats the arguments according to a format string. This function is similar to the C function sprintf.
func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression { func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression {
return jet.FORMAT(formatStr, explicitLiteralCasts(formatArgs...)...) return jet.FORMAT(formatStr, formatArgs...)
} }
// INITCAP converts the first letter of each word to upper case // INITCAP converts the first letter of each word to upper case
@ -578,55 +578,19 @@ var EXISTS = jet.EXISTS
// CASE create CASE operator with optional list of expressions // CASE create CASE operator with optional list of expressions
var CASE = jet.CASE var CASE = jet.CASE
func explicitLiteralCasts(expressions ...Expression) []jet.Expression {
ret := []jet.Expression{}
for _, exp := range expressions {
ret = append(ret, explicitLiteralCast(exp))
}
return ret
}
func explicitLiteralCast(expresion Expression) jet.Expression {
if _, ok := expresion.(jet.LiteralExpression); !ok {
return expresion
}
switch expresion.(type) {
case jet.BoolExpression:
return CAST(expresion).AS_BOOL()
case jet.IntegerExpression:
return CAST(expresion).AS_INTEGER()
case jet.FloatExpression:
return CAST(expresion).AS_NUMERIC()
case jet.StringExpression:
return CAST(expresion).AS_TEXT()
}
return expresion
}
// MODE computes the most frequent value of the aggregated argument // MODE computes the most frequent value of the aggregated argument
var MODE = jet.MODE var MODE = jet.MODE
// PERCENTILE_CONT computes a value corresponding to the specified fraction within the ordered set of // PERCENTILE_CONT computes a value corresponding to the specified fraction within the ordered set of
// aggregated argument values. This will interpolate between adjacent input items if needed. // aggregated argument values. This will interpolate between adjacent input items if needed.
func PERCENTILE_CONT(fraction FloatExpression) *jet.OrderSetAggregateFunc { func PERCENTILE_CONT(fraction FloatExpression) *jet.OrderSetAggregateFunc {
return jet.PERCENTILE_CONT(castFloatLiteral(fraction)) return jet.PERCENTILE_CONT(fraction)
} }
// PERCENTILE_DISC computes the first value within the ordered set of aggregated argument values whose position // PERCENTILE_DISC computes the first value within the ordered set of aggregated argument values whose position
// in the ordering equals or exceeds the specified fraction. The aggregated argument must be of a sortable type. // in the ordering equals or exceeds the specified fraction. The aggregated argument must be of a sortable type.
func PERCENTILE_DISC(fraction FloatExpression) *jet.OrderSetAggregateFunc { func PERCENTILE_DISC(fraction FloatExpression) *jet.OrderSetAggregateFunc {
return jet.PERCENTILE_DISC(castFloatLiteral(fraction)) return jet.PERCENTILE_DISC(fraction)
}
func castFloatLiteral(fraction FloatExpression) FloatExpression {
if _, ok := fraction.(jet.LiteralExpression); ok {
return CAST(fraction).AS_DOUBLE() // to make postgres aware of the type
}
return fraction
} }
// ----------------- Group By operators --------------------------// // ----------------- Group By operators --------------------------//

View file

@ -23,7 +23,7 @@ func TestINTERVAL(t *testing.T) {
assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH, 20, DAY), "INTERVAL '1 YEAR 10 MONTH 20 DAY'") assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH, 20, DAY), "INTERVAL '1 YEAR 10 MONTH 20 DAY'")
assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH, 20, DAY, 3, HOUR), "INTERVAL '1 YEAR 10 MONTH 20 DAY 3 HOUR'") assertSerialize(t, INTERVAL(1, YEAR, 10, MONTH, 20, DAY, 3, HOUR), "INTERVAL '1 YEAR 10 MONTH 20 DAY 3 HOUR'")
assertSerialize(t, INTERVAL(1, YEAR).IS_NOT_NULL(), "INTERVAL '1 YEAR' IS NOT NULL") assertSerialize(t, INTERVAL(1, YEAR).IS_NOT_NULL(), "(INTERVAL '1 YEAR' IS NOT NULL)")
assertProjectionSerialize(t, INTERVAL(1, YEAR).AS("one year"), `INTERVAL '1 YEAR' AS "one year"`) assertProjectionSerialize(t, INTERVAL(1, YEAR).AS("one year"), `INTERVAL '1 YEAR' AS "one year"`)
f := 5.2 f := 5.2

View file

@ -2,9 +2,10 @@ package postgres
import ( import (
"fmt" "fmt"
"github.com/lib/pq"
"time" "time"
"github.com/lib/pq"
"github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/jet"
) )

View file

@ -4,21 +4,20 @@ import (
"github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/jet"
) )
type cast struct { // CAST function converts an expr (of any type) into later specified datatype.
jet.Cast func CAST(expr Expression) *cast {
return &cast{
expr: expr,
}
} }
// CAST function converts a expr (of any type) into latter specified datatype. type cast struct {
func CAST(expr Expression) *cast { expr Expression
ret := &cast{}
ret.Cast = jet.NewCastImpl(expr)
return ret
} }
// AS casts expressions to castType // AS casts expressions to castType
func (c *cast) AS(castType string) Expression { func (c *cast) AS(castType string) Expression {
return c.Cast.AS(castType) return jet.AtomicCustomExpression(Token("CAST("), c.expr, Token("AS "+castType+")"))
} }
// AS_TEXT cast expression to TEXT type // AS_TEXT cast expression to TEXT type

View file

@ -1,8 +1,9 @@
package sqlite package sqlite
import ( import (
"github.com/stretchr/testify/require"
"testing" "testing"
"github.com/stretchr/testify/require"
) )
func TestRaw(t *testing.T) { func TestRaw(t *testing.T) {
@ -46,7 +47,7 @@ func TestRawInvalidArguments(t *testing.T) {
func TestRawType(t *testing.T) { func TestRawType(t *testing.T) {
assertSerialize(t, RawBool("table.colInt < :float", RawArgs{":float": 11.22}).IS_FALSE(), assertSerialize(t, RawBool("table.colInt < :float", RawArgs{":float": 11.22}).IS_FALSE(),
"(table.colInt < ?) IS FALSE", 11.22) "((table.colInt < ?) IS FALSE)", 11.22)
assertSerialize(t, RawFloat("table.colInt + &float", RawArgs{"&float": 11.22}).EQ(Float(3.14)), assertSerialize(t, RawFloat("table.colInt + &float", RawArgs{"&float": 11.22}).EQ(Float(3.14)),
"((table.colInt + ?) = ?)", 11.22, 3.14) "((table.colInt + ?) = ?)", 11.22, 3.14)

View file

@ -648,10 +648,10 @@ func TestStringOperators(t *testing.T) {
LTRIM(String("Ltrim"), String("A")), LTRIM(String("Ltrim"), String("A")),
RTRIM(String("rtrim")), RTRIM(String("rtrim")),
RTRIM(AllTypes.VarChar, String("B")), RTRIM(AllTypes.VarChar, String("B")),
CHR(Int(65)), CHR(Int8(65)),
CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, String("aaa"), Int(1)), CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, Text("aaa"), Int8(1)),
CONCAT(Bool(false), Int(1), Float(22.2), String("test test")), CONCAT(Bool(false), Int16(1), Real(22.2), Text("test test")),
CONCAT_WS(String("string1"), Int(1), Float(11.22), String("bytea"), Bool(false)), //Float(11.12)), CONCAT_WS(Text("string1"), Int64(1), Real(11.22), Text("bytea"), Bool(false)), //Float(11.12)),
CONVERT(Bytea("bytea"), UTF8, LATIN1), CONVERT(Bytea("bytea"), UTF8, LATIN1),
CONVERT(AllTypes.Bytea, UTF8, LATIN1), CONVERT(AllTypes.Bytea, UTF8, LATIN1),
CONVERT_FROM(Bytea("text_in_utf8"), UTF8), CONVERT_FROM(Bytea("text_in_utf8"), UTF8),
@ -1117,8 +1117,8 @@ func TestIntegerOperators(t *testing.T) {
AllTypes.SmallInt.BIT_XOR(AllTypes.SmallInt).AS("bit xor 1"), AllTypes.SmallInt.BIT_XOR(AllTypes.SmallInt).AS("bit xor 1"),
AllTypes.SmallInt.BIT_XOR(Int(11)).AS("bit xor 2"), AllTypes.SmallInt.BIT_XOR(Int(11)).AS("bit xor 2"),
BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"), BIT_NOT(Int32(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"),
BIT_NOT(Int(-11)).AS("bit_not_2"), BIT_NOT(Int32(-11)).AS("bit_not_2"),
AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int8(2))).AS("bit shift left 1"), AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int8(2))).AS("bit shift left 1"),
AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"), AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"),
@ -1130,8 +1130,6 @@ func TestIntegerOperators(t *testing.T) {
CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"), CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"),
).LIMIT(2) ).LIMIT(2)
// fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, ` testutils.AssertStatementSql(t, query, `
SELECT all_types.big_int AS "all_types.big_int", SELECT all_types.big_int AS "all_types.big_int",
all_types.big_int_ptr AS "all_types.big_int_ptr", all_types.big_int_ptr AS "all_types.big_int_ptr",
@ -1173,17 +1171,17 @@ SELECT all_types.big_int AS "all_types.big_int",
(all_types.small_int | $20) AS "bit or 2", (all_types.small_int | $20) AS "bit or 2",
(all_types.small_int # all_types.small_int) AS "bit xor 1", (all_types.small_int # all_types.small_int) AS "bit xor 1",
(all_types.small_int # $21) AS "bit xor 2", (all_types.small_int # $21) AS "bit xor 2",
(~ ($22 * all_types.small_int)) AS "bit_not_1", (~ ($22::integer * all_types.small_int)) AS "bit_not_1",
(~ -11) AS "bit_not_2", (~ $23::integer) AS "bit_not_2",
(all_types.small_int << (all_types.small_int / $23::smallint)) AS "bit shift left 1", (all_types.small_int << (all_types.small_int / $24::smallint)) AS "bit shift left 1",
(all_types.small_int << $24) AS "bit shift left 2", (all_types.small_int << $25) AS "bit shift left 2",
(all_types.small_int >> (all_types.small_int / $25)) AS "bit shift right 1", (all_types.small_int >> (all_types.small_int / $26)) AS "bit shift right 1",
(all_types.small_int >> $26) AS "bit shift right 2", (all_types.small_int >> $27) AS "bit shift right 2",
ABS(all_types.big_int) AS "abs", ABS(all_types.big_int) AS "abs",
SQRT(ABS(all_types.big_int)) AS "sqrt", SQRT(ABS(all_types.big_int)) AS "sqrt",
CBRT(ABS(all_types.big_int)) AS "cbrt" CBRT(ABS(all_types.big_int)) AS "cbrt"
FROM test_sample.all_types FROM test_sample.all_types
LIMIT $27; LIMIT $28;
`) `)
var dest []struct { var dest []struct {

View file

@ -2,6 +2,9 @@ package postgres
import ( import (
"context" "context"
"testing"
"time"
"github.com/bytedance/sonic" "github.com/bytedance/sonic"
"github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/testutils"
"github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/go-jet/jet/v2/internal/utils/ptr"
@ -10,8 +13,6 @@ import (
. "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook/table" . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook/table"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook2/table" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/chinook2/table"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"testing"
"time"
) )
func TestSelectAlbum(t *testing.T) { func TestSelectAlbum(t *testing.T) {
@ -1301,13 +1302,13 @@ func TestAggregateFunc(t *testing.T) {
skipForCockroachDB(t) skipForCockroachDB(t)
stmt := SELECT( stmt := SELECT(
PERCENTILE_DISC(Float(0.1)).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceId).AS("percentile_disc_1"), PERCENTILE_DISC(Double(0.1)).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceId).AS("percentile_disc_1"),
PERCENTILE_DISC(Invoice.Total.DIV(Float(100))).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceDate.ASC()).AS("percentile_disc_2"), PERCENTILE_DISC(Invoice.Total.DIV(Float(100))).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceDate.ASC()).AS("percentile_disc_2"),
PERCENTILE_DISC(RawFloat("(select array_agg(s) from generate_series(0, 1, 0.2) as s)")). PERCENTILE_DISC(RawFloat("(select array_agg(s) from generate_series(0, 1, 0.2) as s)")).
WITHIN_GROUP_ORDER_BY(Invoice.BillingAddress.DESC()).AS("percentile_disc_3"), WITHIN_GROUP_ORDER_BY(Invoice.BillingAddress.DESC()).AS("percentile_disc_3"),
PERCENTILE_CONT(Float(0.3)).WITHIN_GROUP_ORDER_BY(Invoice.Total).AS("percentile_cont_1"), PERCENTILE_CONT(Double(0.3)).WITHIN_GROUP_ORDER_BY(Invoice.Total).AS("percentile_cont_1"),
PERCENTILE_CONT(Float(0.2)).WITHIN_GROUP_ORDER_BY(INTERVAL(1, HOUR).DESC()).AS("percentile_cont_interval"), PERCENTILE_CONT(Double(0.2)).WITHIN_GROUP_ORDER_BY(INTERVAL(1, HOUR).DESC()).AS("percentile_cont_interval"),
MODE().WITHIN_GROUP_ORDER_BY(Invoice.BillingPostalCode.DESC()).AS("mode_1"), MODE().WITHIN_GROUP_ORDER_BY(Invoice.BillingPostalCode.DESC()).AS("mode_1"),
).FROM( ).FROM(