Skip complex expression parenthesis wrap for function parameters.

This commit is contained in:
go-jet 2022-01-10 16:57:57 +01:00
parent a506a96d6a
commit 7377e078cd
12 changed files with 102 additions and 83 deletions

View file

@ -6,7 +6,6 @@ import (
func TestBoolExpressionEQ(t *testing.T) {
assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.col_bool = table2.col_bool)")
assertClauseSerializeErr(t, table1ColBool.EQ(nil), "jet: rhs is nil for '=' operator")
}
func TestBoolExpressionNOT_EQ(t *testing.T) {

View file

@ -96,7 +96,7 @@ type binaryOperatorExpression struct {
}
// NewBinaryOperatorExpression creates new binaryOperatorExpression
func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) *binaryOperatorExpression {
func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) Expression {
binaryExpression := &binaryOperatorExpression{
lhs: lhs,
rhs: rhs,
@ -109,23 +109,10 @@ func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additiona
binaryExpression.ExpressionInterfaceImpl.Parent = binaryExpression
return binaryExpression
return complexExpr(binaryExpression)
}
func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if c.lhs == nil {
panic("jet: lhs is nil for '" + c.operator + "' operator")
}
if c.rhs == nil {
panic("jet: rhs is nil for '" + c.operator + "' operator")
}
wrap := !contains(options, NoWrap)
if wrap {
out.WriteString("(")
}
if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil {
serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam)
serializeOverrideFunc(statement, out, FallTrough(options)...)
@ -134,10 +121,6 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu
out.WriteString(c.operator)
c.rhs.serialize(statement, out, FallTrough(options)...)
}
if wrap {
out.WriteString(")")
}
}
// A prefix operator Expression
@ -148,27 +131,19 @@ type prefixExpression struct {
operator string
}
func newPrefixOperatorExpression(expression Expression, operator string) *prefixExpression {
func newPrefixOperatorExpression(expression Expression, operator string) Expression {
prefixExpression := &prefixExpression{
expression: expression,
operator: operator,
}
prefixExpression.ExpressionInterfaceImpl.Parent = prefixExpression
return prefixExpression
return complexExpr(prefixExpression)
}
func (p *prefixExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")
out.WriteString(p.operator)
if p.expression == nil {
panic("jet: nil prefix expression in prefix operator " + p.operator)
}
p.expression.serialize(statement, out, FallTrough(options)...)
out.WriteString(")")
}
// A postfix operator Expression
@ -191,12 +166,7 @@ func newPostfixOperatorExpression(expression Expression, operator string) *postf
}
func (p *postfixOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if p.expression == nil {
panic("jet: nil prefix expression in postfix operator " + p.operator)
}
p.expression.serialize(statement, out, FallTrough(options)...)
out.WriteString(p.operator)
}
@ -220,14 +190,10 @@ func NewBetweenOperatorExpression(expression, min, max Expression, notBetween bo
newBetweenOperator.ExpressionInterfaceImpl.Parent = newBetweenOperator
return BoolExp(newBetweenOperator)
return BoolExp(complexExpr(newBetweenOperator))
}
func (p *betweenOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, NoWrap) {
out.WriteString("(")
}
p.expression.serialize(statement, out, FallTrough(options)...)
if p.notBetween {
out.WriteString("NOT")
@ -236,8 +202,41 @@ func (p *betweenOperatorExpression) serialize(statement StatementType, out *SQLB
p.min.serialize(statement, out, FallTrough(options)...)
out.WriteString("AND")
p.max.serialize(statement, out, FallTrough(options)...)
}
type complexExpression struct {
ExpressionInterfaceImpl
expressions Expression
}
func complexExpr(expressions Expression) Expression {
complexExpression := &complexExpression{expressions: expressions}
complexExpression.ExpressionInterfaceImpl.Parent = complexExpression
return complexExpression
}
func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, NoWrap) {
out.WriteString("(")
}
s.expressions.serialize(statement, out, options...) // FallTrough here because complexExpression is just a wrapper
if !contains(options, NoWrap) {
out.WriteString(")")
}
}
type skipParenthesisWrap struct {
Expression
}
func skipWrap(expression Expression) Expression {
return &skipParenthesisWrap{expression}
}
// since the expression is a function parameter, there is no need to wrap it in parentheses
func (s *skipParenthesisWrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
s.Expression.serialize(statement, out, append(options, NoWrap)...)
}

View file

@ -4,10 +4,6 @@ import (
"testing"
)
func TestInvalidExpression(t *testing.T) {
assertClauseSerializeErr(t, table2Col3.ADD(nil), `jet: rhs is nil for '+' operator`)
}
func TestExpressionIS_NULL(t *testing.T) {
assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL")
assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL")

View file

@ -81,7 +81,7 @@ func LOG(floatExpression FloatExpression) FloatExpression {
// ----------------- Aggregate functions -------------------//
// AVG is aggregate function used to calculate avg value from numeric expression
func AVG(numericExpression NumericExpression) floatWindowExpression {
func AVG(numericExpression Expression) floatWindowExpression {
return NewFloatWindowFunc("AVG", numericExpression)
}
@ -594,7 +594,7 @@ type funcExpressionImpl struct {
func NewFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl {
funcExp := &funcExpressionImpl{
name: name,
expressions: expressions,
expressions: parameters(expressions),
}
if parent != nil {
@ -606,9 +606,22 @@ func NewFunc(name string, expressions []Expression, parent Expression) *funcExpr
return funcExp
}
func parameters(expressions []Expression) []Expression {
var ret []Expression
for _, expression := range expressions {
if _, isStatement := expression.(Statement); isStatement {
ret = append(ret, expression)
} else {
ret = append(ret, skipWrap(expression))
}
}
return ret
}
// NewFloatWindowFunc creates new float function with name and expressions
func newWindowFunc(name string, expressions ...Expression) windowExpression {
newFun := NewFunc(name, expressions, nil)
windowExpr := newWindowExpression(newFun)
newFun.ExpressionInterfaceImpl.Parent = windowExpr
@ -698,12 +711,12 @@ type integerFunc struct {
}
func newIntegerFunc(name string, expressions ...Expression) IntegerExpression {
floatFunc := &integerFunc{}
intFunc := &integerFunc{}
floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc)
floatFunc.integerInterfaceImpl.parent = floatFunc
intFunc.funcExpressionImpl = *NewFunc(name, expressions, intFunc)
intFunc.integerInterfaceImpl.parent = intFunc
return floatFunc
return intFunc
}
// NewFloatWindowFunc creates new float function with name and expressions
@ -806,7 +819,7 @@ func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc {
return timestampzFunc
}
// Func can be used to call an custom or as of yet unsupported function in the database.
// Func can be used to call custom or unsupported database functions.
func Func(name string, expressions ...Expression) Expression {
return NewFunc(name, expressions, nil)
}