Skip complex expression parenthesis wrap for function parameters.

This commit is contained in:
go-jet 2022-01-10 16:57:57 +01:00
parent a506a96d6a
commit 7377e078cd
12 changed files with 102 additions and 83 deletions

View file

@ -6,7 +6,6 @@ import (
func TestBoolExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.col_bool = table2.col_bool)")
assertClauseSerializeErr(t, table1ColBool.EQ(nil), "jet: rhs is nil for '=' operator")
}
func TestBoolExpressionNOT_EQ(t *testing.T) {

View file

@ -96,7 +96,7 @@ type binaryOperatorExpression struct {
}
// NewBinaryOperatorExpression creates new binaryOperatorExpression
func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) *binaryOperatorExpression {
func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) Expression {
binaryExpression := &binaryOperatorExpression{
lhs: lhs,
rhs: rhs,
@ -109,23 +109,10 @@ func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additiona
binaryExpression.ExpressionInterfaceImpl.Parent = binaryExpression
return binaryExpression
return complexExpr(binaryExpression)
}
func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if c.lhs == nil {
panic("jet: lhs is nil for '" + c.operator + "' operator")
}
if c.rhs == nil {
panic("jet: rhs is nil for '" + c.operator + "' operator")
}
wrap := !contains(options, NoWrap)
if wrap {
out.WriteString("(")
}
if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam)
serializeOverrideFunc(statement, out, FallTrough(options)...)
@ -134,10 +121,6 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu
out.WriteString(c.operator)
c.rhs.serialize(statement, out, FallTrough(options)...)
}
if wrap {
out.WriteString(")")
}
}
// A prefix operator Expression
@ -148,27 +131,19 @@ type prefixExpression struct {
operator string
}
func newPrefixOperatorExpression(expression Expression, operator string) *prefixExpression {
func newPrefixOperatorExpression(expression Expression, operator string) Expression {
prefixExpression := &prefixExpression{
expression: expression,
operator: operator,
}
prefixExpression.ExpressionInterfaceImpl.Parent = prefixExpression
return prefixExpression
return complexExpr(prefixExpression)
}
func (p *prefixExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")
out.WriteString(p.operator)
if p.expression == nil {
panic("jet: nil prefix expression in prefix operator " + p.operator)
}
p.expression.serialize(statement, out, FallTrough(options)...)
out.WriteString(")")
}
// A postfix operator Expression
@ -191,12 +166,7 @@ func newPostfixOperatorExpression(expression Expression, operator string) *postf
}
func (p *postfixOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if p.expression == nil {
panic("jet: nil prefix expression in postfix operator " + p.operator)
}
p.expression.serialize(statement, out, FallTrough(options)...)
out.WriteString(p.operator)
}
@ -220,14 +190,10 @@ func NewBetweenOperatorExpression(expression, min, max Expression, notBetween bo
newBetweenOperator.ExpressionInterfaceImpl.Parent = newBetweenOperator
return BoolExp(newBetweenOperator)
return BoolExp(complexExpr(newBetweenOperator))
}
func (p *betweenOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, NoWrap) {
out.WriteString("(")
}
p.expression.serialize(statement, out, FallTrough(options)...)
if p.notBetween {
out.WriteString("NOT")
@ -236,8 +202,41 @@ func (p *betweenOperatorExpression) serialize(statement StatementType, out *SQLB
p.min.serialize(statement, out, FallTrough(options)...)
out.WriteString("AND")
p.max.serialize(statement, out, FallTrough(options)...)
}
type complexExpression struct {
ExpressionInterfaceImpl
expressions Expression
}
func complexExpr(expressions Expression) Expression {
complexExpression := &complexExpression{expressions: expressions}
complexExpression.ExpressionInterfaceImpl.Parent = complexExpression
return complexExpression
}
func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, NoWrap) {
out.WriteString("(")
}
s.expressions.serialize(statement, out, options...) // FallTrough here because complexExpression is just a wrapper
if !contains(options, NoWrap) {
out.WriteString(")")
}
}
type skipParenthesisWrap struct {
Expression
}
func skipWrap(expression Expression) Expression {
return &skipParenthesisWrap{expression}
}
// since the expression is a function parameter, there is no need to wrap it in parentheses
func (s *skipParenthesisWrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
s.Expression.serialize(statement, out, append(options, NoWrap)...)
}

View file

@ -4,10 +4,6 @@ import (
"testing"
)
func TestInvalidExpression(t *testing.T) {
assertClauseSerializeErr(t, table2Col3.ADD(nil), `jet: rhs is nil for '+' operator`)
}
func TestExpressionIS_NULL(t *testing.T) {
assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL")
assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL")

View file

@ -81,7 +81,7 @@ func LOG(floatExpression FloatExpression) FloatExpression {
// ----------------- Aggregate functions -------------------//
// AVG is aggregate function used to calculate avg value from numeric expression
func AVG(numericExpression NumericExpression) floatWindowExpression {
func AVG(numericExpression Expression) floatWindowExpression {
return NewFloatWindowFunc("AVG", numericExpression)
}
@ -594,7 +594,7 @@ type funcExpressionImpl struct {
func NewFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl {
funcExp := &funcExpressionImpl{
name: name,
expressions: expressions,
expressions: parameters(expressions),
}
if parent != nil {
@ -606,9 +606,22 @@ func NewFunc(name string, expressions []Expression, parent Expression) *funcExpr
return funcExp
}
func parameters(expressions []Expression) []Expression {
var ret []Expression
for _, expression := range expressions {
if _, isStatement := expression.(Statement); isStatement {
ret = append(ret, expression)
} else {
ret = append(ret, skipWrap(expression))
}
}
return ret
}
// NewFloatWindowFunc creates new float function with name and expressions
func newWindowFunc(name string, expressions ...Expression) windowExpression {
newFun := NewFunc(name, expressions, nil)
windowExpr := newWindowExpression(newFun)
newFun.ExpressionInterfaceImpl.Parent = windowExpr
@ -698,12 +711,12 @@ type integerFunc struct {
}
func newIntegerFunc(name string, expressions ...Expression) IntegerExpression {
floatFunc := &integerFunc{}
intFunc := &integerFunc{}
floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc)
floatFunc.integerInterfaceImpl.parent = floatFunc
intFunc.funcExpressionImpl = *NewFunc(name, expressions, intFunc)
intFunc.integerInterfaceImpl.parent = intFunc
return floatFunc
return intFunc
}
// NewFloatWindowFunc creates new float function with name and expressions
@ -806,7 +819,7 @@ func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc {
return timestampzFunc
}
// Func can be used to call an custom or as of yet unsupported function in the database.
// Func can be used to call custom or unsupported database functions.
func Func(name string, expressions ...Expression) Expression {
return NewFunc(name, expressions, nil)
}

View file

@ -87,7 +87,7 @@ var (
RawDate = jet.RawDate
)
// Func can be used to call an custom or as of yet unsupported function in the database.
// Func can be used to call custom or unsupported database functions.
var Func = jet.Func
// NewEnumValue creates new named enum value

View file

@ -147,10 +147,10 @@ func TestSelect_NOT_EXISTS(t *testing.T) {
))), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE (NOT (EXISTS (
WHERE NOT (EXISTS (
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
WHERE table1.col_int = table2.col_int
)));
));
`)
}

View file

@ -100,7 +100,7 @@ var (
RawDate = jet.RawDate
)
// Func can be used to call an custom or as of yet unsupported function in the database.
// Func can be used to call custom or unsupported database functions.
var Func = jet.Func
// NewEnumValue creates new named enum value

View file

@ -0,0 +1,12 @@
package postgres
import "testing"
func TestROW(t *testing.T) {
assertSerialize(t, ROW(SELECT(Int(1))), `ROW((
SELECT $1
))`)
assertSerialize(t, ROW(Int(1), SELECT(Int(2)), Float(11.11)), `ROW($1, (
SELECT $2
), $3)`)
}

View file

@ -90,7 +90,7 @@ var (
RawDate = jet.RawDate
)
// Func can be used to call an custom or as of yet unsupported function in the database.
// Func can be used to call custom or unsupported database functions.
var Func = jet.Func
// NewEnumValue creates new named enum value

View file

@ -147,10 +147,10 @@ func TestSelect_NOT_EXISTS(t *testing.T) {
))), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE (NOT (EXISTS (
WHERE NOT (EXISTS (
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
WHERE table1.col_int = table2.col_int
)));
));
`)
}

View file

@ -264,22 +264,22 @@ SELECT (all_types.'numeric' = all_types.'numeric') AS "eq1",
(all_types.'numeric' > ?) AS "gt2",
(all_types.'numeric' BETWEEN ? AND all_types.'decimal') AS "between",
(all_types.'numeric' NOT BETWEEN (all_types.'decimal' * ?) AND ?) AS "not_between",
TRUNCATE((all_types.'decimal' + all_types.'decimal'), ?) AS "add1",
TRUNCATE((all_types.'decimal' + ?), ?) AS "add2",
TRUNCATE((all_types.'decimal' - all_types.decimal_ptr), ?) AS "sub1",
TRUNCATE((all_types.'decimal' - ?), ?) AS "sub2",
TRUNCATE((all_types.'decimal' * all_types.decimal_ptr), ?) AS "mul1",
TRUNCATE((all_types.'decimal' * ?), ?) AS "mul2",
TRUNCATE((all_types.'decimal' / all_types.decimal_ptr), ?) AS "div1",
TRUNCATE((all_types.'decimal' / ?), ?) AS "div2",
TRUNCATE((all_types.'decimal' % all_types.decimal_ptr), ?) AS "mod1",
TRUNCATE((all_types.'decimal' % ?), ?) AS "mod2",
TRUNCATE(all_types.'decimal' + all_types.'decimal', ?) AS "add1",
TRUNCATE(all_types.'decimal' + ?, ?) AS "add2",
TRUNCATE(all_types.'decimal' - all_types.decimal_ptr, ?) AS "sub1",
TRUNCATE(all_types.'decimal' - ?, ?) AS "sub2",
TRUNCATE(all_types.'decimal' * all_types.decimal_ptr, ?) AS "mul1",
TRUNCATE(all_types.'decimal' * ?, ?) AS "mul2",
TRUNCATE(all_types.'decimal' / all_types.decimal_ptr, ?) AS "div1",
TRUNCATE(all_types.'decimal' / ?, ?) AS "div2",
TRUNCATE(all_types.'decimal' % all_types.decimal_ptr, ?) AS "mod1",
TRUNCATE(all_types.'decimal' % ?, ?) AS "mod2",
TRUNCATE(POW(all_types.'decimal', all_types.decimal_ptr), ?) AS "pow1",
TRUNCATE(POW(all_types.'decimal', ?), ?) AS "pow2",
TRUNCATE(ABS(all_types.'decimal'), ?) AS "abs",
TRUNCATE(POWER(all_types.'decimal', ?), ?) AS "power",
TRUNCATE(SQRT(all_types.'decimal'), ?) AS "sqrt",
TRUNCATE(POWER(all_types.'decimal', (? / ?)), ?) AS "cbrt",
TRUNCATE(POWER(all_types.'decimal', ? / ?), ?) AS "cbrt",
CEIL(all_types.'real') AS "ceil",
FLOOR(all_types.'real') AS "floor",
ROUND(all_types.'decimal') AS "round1",
@ -395,7 +395,7 @@ SELECT all_types.big_int AS "all_types.big_int",
(all_types.big_int DIV ?) AS "div2",
(all_types.big_int % all_types.big_int) AS "mod1",
(all_types.big_int % ?) AS "mod2",
POW(all_types.small_int, (all_types.small_int DIV ?)) AS "pow1",
POW(all_types.small_int, all_types.small_int DIV ?) AS "pow1",
POW(all_types.small_int, ?) AS "pow2",
(all_types.small_int & all_types.small_int) AS "bit_and1",
(all_types.small_int & all_types.small_int) AS "bit_and2",
@ -411,7 +411,7 @@ SELECT all_types.big_int AS "all_types.big_int",
(all_types.small_int >> ?) AS "bit shift right 2",
ABS(all_types.big_int) AS "abs",
SQRT(ABS(all_types.big_int)) AS "sqrt",
POWER(ABS(all_types.big_int), (? / ?)) AS "cbrt"
POWER(ABS(all_types.big_int), ? / ?) AS "cbrt"
FROM test_sample.all_types
LIMIT ?;
`, "''", "`"))

View file

@ -563,16 +563,16 @@ SELECT (all_types.numeric = all_types.numeric) AS "eq1",
(all_types.numeric > $10) AS "gt2",
(all_types.numeric BETWEEN $11 AND all_types.decimal) AS "between",
(all_types.numeric NOT BETWEEN (all_types.decimal * $12) AND $13) AS "not_between",
TRUNC((all_types.decimal + all_types.decimal), $14::smallint) AS "add1",
TRUNC((all_types.decimal + $15), $16::smallint) AS "add2",
TRUNC((all_types.decimal - all_types.decimal_ptr), $17::integer) AS "sub1",
TRUNC((all_types.decimal - $18), $19::smallint) AS "sub2",
TRUNC((all_types.decimal * all_types.decimal_ptr), $20::smallint) AS "mul1",
TRUNC((all_types.decimal * $21), $22::integer) AS "mul2",
TRUNC((all_types.decimal / all_types.decimal_ptr), $23::integer) AS "div1",
TRUNC((all_types.decimal / $24), $25::smallint) AS "div2",
TRUNC((all_types.decimal % all_types.decimal_ptr), $26::smallint) AS "mod1",
TRUNC((all_types.decimal % $27), $28::smallint) AS "mod2",
TRUNC(all_types.decimal + all_types.decimal, $14::smallint) AS "add1",
TRUNC(all_types.decimal + $15, $16::smallint) AS "add2",
TRUNC(all_types.decimal - all_types.decimal_ptr, $17::integer) AS "sub1",
TRUNC(all_types.decimal - $18, $19::smallint) AS "sub2",
TRUNC(all_types.decimal * all_types.decimal_ptr, $20::smallint) AS "mul1",
TRUNC(all_types.decimal * $21, $22::integer) AS "mul2",
TRUNC(all_types.decimal / all_types.decimal_ptr, $23::integer) AS "div1",
TRUNC(all_types.decimal / $24, $25::smallint) AS "div2",
TRUNC(all_types.decimal % all_types.decimal_ptr, $26::smallint) AS "mod1",
TRUNC(all_types.decimal % $27, $28::smallint) AS "mod2",
TRUNC(POW(all_types.decimal, all_types.decimal_ptr), $29::smallint) AS "pow1",
TRUNC(POW(all_types.decimal, $30), $31::smallint) AS "pow2",
TRUNC(ABS(all_types.decimal), $32::smallint) AS "abs",
@ -698,7 +698,7 @@ SELECT all_types.big_int AS "all_types.big_int",
(all_types.big_int / $16::integer) AS "div2",
(all_types.big_int % all_types.big_int) AS "mod1",
(all_types.big_int % $17::bigint) AS "mod2",
POW(all_types.small_int, (all_types.small_int / $18::smallint)) AS "pow1",
POW(all_types.small_int, all_types.small_int / $18::smallint) AS "pow1",
POW(all_types.small_int, $19::smallint) AS "pow2",
(all_types.small_int & all_types.small_int) AS "bit_and1",
(all_types.small_int & all_types.small_int) AS "bit_and2",