Skip complex expression parenthesis wrap for function parameters.
This commit is contained in:
parent
a506a96d6a
commit
7377e078cd
12 changed files with 102 additions and 83 deletions
|
|
@ -6,7 +6,6 @@ import (
|
|||
|
||||
func TestBoolExpressionEQ(t *testing.T) {
|
||||
assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.col_bool = table2.col_bool)")
|
||||
assertClauseSerializeErr(t, table1ColBool.EQ(nil), "jet: rhs is nil for '=' operator")
|
||||
}
|
||||
|
||||
func TestBoolExpressionNOT_EQ(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ type binaryOperatorExpression struct {
|
|||
}
|
||||
|
||||
// NewBinaryOperatorExpression creates new binaryOperatorExpression
|
||||
func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) *binaryOperatorExpression {
|
||||
func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) Expression {
|
||||
binaryExpression := &binaryOperatorExpression{
|
||||
lhs: lhs,
|
||||
rhs: rhs,
|
||||
|
|
@ -109,23 +109,10 @@ func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additiona
|
|||
|
||||
binaryExpression.ExpressionInterfaceImpl.Parent = binaryExpression
|
||||
|
||||
return binaryExpression
|
||||
return complexExpr(binaryExpression)
|
||||
}
|
||||
|
||||
func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
|
||||
if c.lhs == nil {
|
||||
panic("jet: lhs is nil for '" + c.operator + "' operator")
|
||||
}
|
||||
if c.rhs == nil {
|
||||
panic("jet: rhs is nil for '" + c.operator + "' operator")
|
||||
}
|
||||
|
||||
wrap := !contains(options, NoWrap)
|
||||
|
||||
if wrap {
|
||||
out.WriteString("(")
|
||||
}
|
||||
|
||||
if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil {
|
||||
serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam)
|
||||
serializeOverrideFunc(statement, out, FallTrough(options)...)
|
||||
|
|
@ -134,10 +121,6 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu
|
|||
out.WriteString(c.operator)
|
||||
c.rhs.serialize(statement, out, FallTrough(options)...)
|
||||
}
|
||||
|
||||
if wrap {
|
||||
out.WriteString(")")
|
||||
}
|
||||
}
|
||||
|
||||
// A prefix operator Expression
|
||||
|
|
@ -148,27 +131,19 @@ type prefixExpression struct {
|
|||
operator string
|
||||
}
|
||||
|
||||
func newPrefixOperatorExpression(expression Expression, operator string) *prefixExpression {
|
||||
func newPrefixOperatorExpression(expression Expression, operator string) Expression {
|
||||
prefixExpression := &prefixExpression{
|
||||
expression: expression,
|
||||
operator: operator,
|
||||
}
|
||||
prefixExpression.ExpressionInterfaceImpl.Parent = prefixExpression
|
||||
|
||||
return prefixExpression
|
||||
return complexExpr(prefixExpression)
|
||||
}
|
||||
|
||||
func (p *prefixExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
|
||||
out.WriteString("(")
|
||||
out.WriteString(p.operator)
|
||||
|
||||
if p.expression == nil {
|
||||
panic("jet: nil prefix expression in prefix operator " + p.operator)
|
||||
}
|
||||
|
||||
p.expression.serialize(statement, out, FallTrough(options)...)
|
||||
|
||||
out.WriteString(")")
|
||||
}
|
||||
|
||||
// A postfix operator Expression
|
||||
|
|
@ -191,12 +166,7 @@ func newPostfixOperatorExpression(expression Expression, operator string) *postf
|
|||
}
|
||||
|
||||
func (p *postfixOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
|
||||
if p.expression == nil {
|
||||
panic("jet: nil prefix expression in postfix operator " + p.operator)
|
||||
}
|
||||
|
||||
p.expression.serialize(statement, out, FallTrough(options)...)
|
||||
|
||||
out.WriteString(p.operator)
|
||||
}
|
||||
|
||||
|
|
@ -220,14 +190,10 @@ func NewBetweenOperatorExpression(expression, min, max Expression, notBetween bo
|
|||
|
||||
newBetweenOperator.ExpressionInterfaceImpl.Parent = newBetweenOperator
|
||||
|
||||
return BoolExp(newBetweenOperator)
|
||||
return BoolExp(complexExpr(newBetweenOperator))
|
||||
}
|
||||
|
||||
func (p *betweenOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
|
||||
if !contains(options, NoWrap) {
|
||||
out.WriteString("(")
|
||||
}
|
||||
|
||||
p.expression.serialize(statement, out, FallTrough(options)...)
|
||||
if p.notBetween {
|
||||
out.WriteString("NOT")
|
||||
|
|
@ -236,8 +202,41 @@ func (p *betweenOperatorExpression) serialize(statement StatementType, out *SQLB
|
|||
p.min.serialize(statement, out, FallTrough(options)...)
|
||||
out.WriteString("AND")
|
||||
p.max.serialize(statement, out, FallTrough(options)...)
|
||||
}
|
||||
|
||||
type complexExpression struct {
|
||||
ExpressionInterfaceImpl
|
||||
expressions Expression
|
||||
}
|
||||
|
||||
func complexExpr(expressions Expression) Expression {
|
||||
complexExpression := &complexExpression{expressions: expressions}
|
||||
complexExpression.ExpressionInterfaceImpl.Parent = complexExpression
|
||||
|
||||
return complexExpression
|
||||
}
|
||||
|
||||
func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
|
||||
if !contains(options, NoWrap) {
|
||||
out.WriteString("(")
|
||||
}
|
||||
|
||||
s.expressions.serialize(statement, out, options...) // FallTrough here because complexExpression is just a wrapper
|
||||
|
||||
if !contains(options, NoWrap) {
|
||||
out.WriteString(")")
|
||||
}
|
||||
}
|
||||
|
||||
type skipParenthesisWrap struct {
|
||||
Expression
|
||||
}
|
||||
|
||||
func skipWrap(expression Expression) Expression {
|
||||
return &skipParenthesisWrap{expression}
|
||||
}
|
||||
|
||||
// since the expression is a function parameter, there is no need to wrap it in parentheses
|
||||
func (s *skipParenthesisWrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
|
||||
s.Expression.serialize(statement, out, append(options, NoWrap)...)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,10 +4,6 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func TestInvalidExpression(t *testing.T) {
|
||||
assertClauseSerializeErr(t, table2Col3.ADD(nil), `jet: rhs is nil for '+' operator`)
|
||||
}
|
||||
|
||||
func TestExpressionIS_NULL(t *testing.T) {
|
||||
assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL")
|
||||
assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL")
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ func LOG(floatExpression FloatExpression) FloatExpression {
|
|||
// ----------------- Aggregate functions -------------------//
|
||||
|
||||
// AVG is aggregate function used to calculate avg value from numeric expression
|
||||
func AVG(numericExpression NumericExpression) floatWindowExpression {
|
||||
func AVG(numericExpression Expression) floatWindowExpression {
|
||||
return NewFloatWindowFunc("AVG", numericExpression)
|
||||
}
|
||||
|
||||
|
|
@ -594,7 +594,7 @@ type funcExpressionImpl struct {
|
|||
func NewFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl {
|
||||
funcExp := &funcExpressionImpl{
|
||||
name: name,
|
||||
expressions: expressions,
|
||||
expressions: parameters(expressions),
|
||||
}
|
||||
|
||||
if parent != nil {
|
||||
|
|
@ -606,9 +606,22 @@ func NewFunc(name string, expressions []Expression, parent Expression) *funcExpr
|
|||
return funcExp
|
||||
}
|
||||
|
||||
func parameters(expressions []Expression) []Expression {
|
||||
var ret []Expression
|
||||
|
||||
for _, expression := range expressions {
|
||||
if _, isStatement := expression.(Statement); isStatement {
|
||||
ret = append(ret, expression)
|
||||
} else {
|
||||
ret = append(ret, skipWrap(expression))
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// NewFloatWindowFunc creates new float function with name and expressions
|
||||
func newWindowFunc(name string, expressions ...Expression) windowExpression {
|
||||
|
||||
newFun := NewFunc(name, expressions, nil)
|
||||
windowExpr := newWindowExpression(newFun)
|
||||
newFun.ExpressionInterfaceImpl.Parent = windowExpr
|
||||
|
|
@ -698,12 +711,12 @@ type integerFunc struct {
|
|||
}
|
||||
|
||||
func newIntegerFunc(name string, expressions ...Expression) IntegerExpression {
|
||||
floatFunc := &integerFunc{}
|
||||
intFunc := &integerFunc{}
|
||||
|
||||
floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc)
|
||||
floatFunc.integerInterfaceImpl.parent = floatFunc
|
||||
intFunc.funcExpressionImpl = *NewFunc(name, expressions, intFunc)
|
||||
intFunc.integerInterfaceImpl.parent = intFunc
|
||||
|
||||
return floatFunc
|
||||
return intFunc
|
||||
}
|
||||
|
||||
// NewFloatWindowFunc creates new float function with name and expressions
|
||||
|
|
@ -806,7 +819,7 @@ func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc {
|
|||
return timestampzFunc
|
||||
}
|
||||
|
||||
// Func can be used to call an custom or as of yet unsupported function in the database.
|
||||
// Func can be used to call custom or unsupported database functions.
|
||||
func Func(name string, expressions ...Expression) Expression {
|
||||
return NewFunc(name, expressions, nil)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ var (
|
|||
RawDate = jet.RawDate
|
||||
)
|
||||
|
||||
// Func can be used to call an custom or as of yet unsupported function in the database.
|
||||
// Func can be used to call custom or unsupported database functions.
|
||||
var Func = jet.Func
|
||||
|
||||
// NewEnumValue creates new named enum value
|
||||
|
|
|
|||
|
|
@ -147,10 +147,10 @@ func TestSelect_NOT_EXISTS(t *testing.T) {
|
|||
))), `
|
||||
SELECT table1.col_int AS "table1.col_int"
|
||||
FROM db.table1
|
||||
WHERE (NOT (EXISTS (
|
||||
WHERE NOT (EXISTS (
|
||||
SELECT table2.col_int AS "table2.col_int"
|
||||
FROM db.table2
|
||||
WHERE table1.col_int = table2.col_int
|
||||
)));
|
||||
));
|
||||
`)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ var (
|
|||
RawDate = jet.RawDate
|
||||
)
|
||||
|
||||
// Func can be used to call an custom or as of yet unsupported function in the database.
|
||||
// Func can be used to call custom or unsupported database functions.
|
||||
var Func = jet.Func
|
||||
|
||||
// NewEnumValue creates new named enum value
|
||||
|
|
|
|||
12
postgres/functions_test.go
Normal file
12
postgres/functions_test.go
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
package postgres
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestROW(t *testing.T) {
|
||||
assertSerialize(t, ROW(SELECT(Int(1))), `ROW((
|
||||
SELECT $1
|
||||
))`)
|
||||
assertSerialize(t, ROW(Int(1), SELECT(Int(2)), Float(11.11)), `ROW($1, (
|
||||
SELECT $2
|
||||
), $3)`)
|
||||
}
|
||||
|
|
@ -90,7 +90,7 @@ var (
|
|||
RawDate = jet.RawDate
|
||||
)
|
||||
|
||||
// Func can be used to call an custom or as of yet unsupported function in the database.
|
||||
// Func can be used to call custom or unsupported database functions.
|
||||
var Func = jet.Func
|
||||
|
||||
// NewEnumValue creates new named enum value
|
||||
|
|
|
|||
|
|
@ -147,10 +147,10 @@ func TestSelect_NOT_EXISTS(t *testing.T) {
|
|||
))), `
|
||||
SELECT table1.col_int AS "table1.col_int"
|
||||
FROM db.table1
|
||||
WHERE (NOT (EXISTS (
|
||||
WHERE NOT (EXISTS (
|
||||
SELECT table2.col_int AS "table2.col_int"
|
||||
FROM db.table2
|
||||
WHERE table1.col_int = table2.col_int
|
||||
)));
|
||||
));
|
||||
`)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -264,22 +264,22 @@ SELECT (all_types.'numeric' = all_types.'numeric') AS "eq1",
|
|||
(all_types.'numeric' > ?) AS "gt2",
|
||||
(all_types.'numeric' BETWEEN ? AND all_types.'decimal') AS "between",
|
||||
(all_types.'numeric' NOT BETWEEN (all_types.'decimal' * ?) AND ?) AS "not_between",
|
||||
TRUNCATE((all_types.'decimal' + all_types.'decimal'), ?) AS "add1",
|
||||
TRUNCATE((all_types.'decimal' + ?), ?) AS "add2",
|
||||
TRUNCATE((all_types.'decimal' - all_types.decimal_ptr), ?) AS "sub1",
|
||||
TRUNCATE((all_types.'decimal' - ?), ?) AS "sub2",
|
||||
TRUNCATE((all_types.'decimal' * all_types.decimal_ptr), ?) AS "mul1",
|
||||
TRUNCATE((all_types.'decimal' * ?), ?) AS "mul2",
|
||||
TRUNCATE((all_types.'decimal' / all_types.decimal_ptr), ?) AS "div1",
|
||||
TRUNCATE((all_types.'decimal' / ?), ?) AS "div2",
|
||||
TRUNCATE((all_types.'decimal' % all_types.decimal_ptr), ?) AS "mod1",
|
||||
TRUNCATE((all_types.'decimal' % ?), ?) AS "mod2",
|
||||
TRUNCATE(all_types.'decimal' + all_types.'decimal', ?) AS "add1",
|
||||
TRUNCATE(all_types.'decimal' + ?, ?) AS "add2",
|
||||
TRUNCATE(all_types.'decimal' - all_types.decimal_ptr, ?) AS "sub1",
|
||||
TRUNCATE(all_types.'decimal' - ?, ?) AS "sub2",
|
||||
TRUNCATE(all_types.'decimal' * all_types.decimal_ptr, ?) AS "mul1",
|
||||
TRUNCATE(all_types.'decimal' * ?, ?) AS "mul2",
|
||||
TRUNCATE(all_types.'decimal' / all_types.decimal_ptr, ?) AS "div1",
|
||||
TRUNCATE(all_types.'decimal' / ?, ?) AS "div2",
|
||||
TRUNCATE(all_types.'decimal' % all_types.decimal_ptr, ?) AS "mod1",
|
||||
TRUNCATE(all_types.'decimal' % ?, ?) AS "mod2",
|
||||
TRUNCATE(POW(all_types.'decimal', all_types.decimal_ptr), ?) AS "pow1",
|
||||
TRUNCATE(POW(all_types.'decimal', ?), ?) AS "pow2",
|
||||
TRUNCATE(ABS(all_types.'decimal'), ?) AS "abs",
|
||||
TRUNCATE(POWER(all_types.'decimal', ?), ?) AS "power",
|
||||
TRUNCATE(SQRT(all_types.'decimal'), ?) AS "sqrt",
|
||||
TRUNCATE(POWER(all_types.'decimal', (? / ?)), ?) AS "cbrt",
|
||||
TRUNCATE(POWER(all_types.'decimal', ? / ?), ?) AS "cbrt",
|
||||
CEIL(all_types.'real') AS "ceil",
|
||||
FLOOR(all_types.'real') AS "floor",
|
||||
ROUND(all_types.'decimal') AS "round1",
|
||||
|
|
@ -395,7 +395,7 @@ SELECT all_types.big_int AS "all_types.big_int",
|
|||
(all_types.big_int DIV ?) AS "div2",
|
||||
(all_types.big_int % all_types.big_int) AS "mod1",
|
||||
(all_types.big_int % ?) AS "mod2",
|
||||
POW(all_types.small_int, (all_types.small_int DIV ?)) AS "pow1",
|
||||
POW(all_types.small_int, all_types.small_int DIV ?) AS "pow1",
|
||||
POW(all_types.small_int, ?) AS "pow2",
|
||||
(all_types.small_int & all_types.small_int) AS "bit_and1",
|
||||
(all_types.small_int & all_types.small_int) AS "bit_and2",
|
||||
|
|
@ -411,7 +411,7 @@ SELECT all_types.big_int AS "all_types.big_int",
|
|||
(all_types.small_int >> ?) AS "bit shift right 2",
|
||||
ABS(all_types.big_int) AS "abs",
|
||||
SQRT(ABS(all_types.big_int)) AS "sqrt",
|
||||
POWER(ABS(all_types.big_int), (? / ?)) AS "cbrt"
|
||||
POWER(ABS(all_types.big_int), ? / ?) AS "cbrt"
|
||||
FROM test_sample.all_types
|
||||
LIMIT ?;
|
||||
`, "''", "`"))
|
||||
|
|
|
|||
|
|
@ -563,16 +563,16 @@ SELECT (all_types.numeric = all_types.numeric) AS "eq1",
|
|||
(all_types.numeric > $10) AS "gt2",
|
||||
(all_types.numeric BETWEEN $11 AND all_types.decimal) AS "between",
|
||||
(all_types.numeric NOT BETWEEN (all_types.decimal * $12) AND $13) AS "not_between",
|
||||
TRUNC((all_types.decimal + all_types.decimal), $14::smallint) AS "add1",
|
||||
TRUNC((all_types.decimal + $15), $16::smallint) AS "add2",
|
||||
TRUNC((all_types.decimal - all_types.decimal_ptr), $17::integer) AS "sub1",
|
||||
TRUNC((all_types.decimal - $18), $19::smallint) AS "sub2",
|
||||
TRUNC((all_types.decimal * all_types.decimal_ptr), $20::smallint) AS "mul1",
|
||||
TRUNC((all_types.decimal * $21), $22::integer) AS "mul2",
|
||||
TRUNC((all_types.decimal / all_types.decimal_ptr), $23::integer) AS "div1",
|
||||
TRUNC((all_types.decimal / $24), $25::smallint) AS "div2",
|
||||
TRUNC((all_types.decimal % all_types.decimal_ptr), $26::smallint) AS "mod1",
|
||||
TRUNC((all_types.decimal % $27), $28::smallint) AS "mod2",
|
||||
TRUNC(all_types.decimal + all_types.decimal, $14::smallint) AS "add1",
|
||||
TRUNC(all_types.decimal + $15, $16::smallint) AS "add2",
|
||||
TRUNC(all_types.decimal - all_types.decimal_ptr, $17::integer) AS "sub1",
|
||||
TRUNC(all_types.decimal - $18, $19::smallint) AS "sub2",
|
||||
TRUNC(all_types.decimal * all_types.decimal_ptr, $20::smallint) AS "mul1",
|
||||
TRUNC(all_types.decimal * $21, $22::integer) AS "mul2",
|
||||
TRUNC(all_types.decimal / all_types.decimal_ptr, $23::integer) AS "div1",
|
||||
TRUNC(all_types.decimal / $24, $25::smallint) AS "div2",
|
||||
TRUNC(all_types.decimal % all_types.decimal_ptr, $26::smallint) AS "mod1",
|
||||
TRUNC(all_types.decimal % $27, $28::smallint) AS "mod2",
|
||||
TRUNC(POW(all_types.decimal, all_types.decimal_ptr), $29::smallint) AS "pow1",
|
||||
TRUNC(POW(all_types.decimal, $30), $31::smallint) AS "pow2",
|
||||
TRUNC(ABS(all_types.decimal), $32::smallint) AS "abs",
|
||||
|
|
@ -698,7 +698,7 @@ SELECT all_types.big_int AS "all_types.big_int",
|
|||
(all_types.big_int / $16::integer) AS "div2",
|
||||
(all_types.big_int % all_types.big_int) AS "mod1",
|
||||
(all_types.big_int % $17::bigint) AS "mod2",
|
||||
POW(all_types.small_int, (all_types.small_int / $18::smallint)) AS "pow1",
|
||||
POW(all_types.small_int, all_types.small_int / $18::smallint) AS "pow1",
|
||||
POW(all_types.small_int, $19::smallint) AS "pow2",
|
||||
(all_types.small_int & all_types.small_int) AS "bit_and1",
|
||||
(all_types.small_int & all_types.small_int) AS "bit_and2",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue