From 7377e078cdd1063898cfe760d49517e62261198f Mon Sep 17 00:00:00 2001 From: go-jet Date: Mon, 10 Jan 2022 16:57:57 +0100 Subject: [PATCH] Skip complex expression parenthesis wrap for function parameters. --- internal/jet/bool_expression_test.go | 1 - internal/jet/expression.go | 77 ++++++++++++++-------------- internal/jet/expression_test.go | 4 -- internal/jet/func_expression.go | 29 ++++++++--- mysql/expressions.go | 2 +- mysql/select_statement_test.go | 4 +- postgres/expressions.go | 2 +- postgres/functions_test.go | 12 +++++ sqlite/expressions.go | 2 +- sqlite/select_statement_test.go | 4 +- tests/mysql/alltypes_test.go | 26 +++++----- tests/postgres/alltypes_test.go | 22 ++++---- 12 files changed, 102 insertions(+), 83 deletions(-) create mode 100644 postgres/functions_test.go diff --git a/internal/jet/bool_expression_test.go b/internal/jet/bool_expression_test.go index 765cdab..c6cdbbb 100644 --- a/internal/jet/bool_expression_test.go +++ b/internal/jet/bool_expression_test.go @@ -6,7 +6,6 @@ import ( func TestBoolExpressionEQ(t *testing.T) { assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.col_bool = table2.col_bool)") - assertClauseSerializeErr(t, table1ColBool.EQ(nil), "jet: rhs is nil for '=' operator") } func TestBoolExpressionNOT_EQ(t *testing.T) { diff --git a/internal/jet/expression.go b/internal/jet/expression.go index d748dcf..e657f30 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -96,7 +96,7 @@ type binaryOperatorExpression struct { } // NewBinaryOperatorExpression creates new binaryOperatorExpression -func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) *binaryOperatorExpression { +func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additionalParam ...Expression) Expression { binaryExpression := &binaryOperatorExpression{ lhs: lhs, rhs: rhs, @@ -109,23 +109,10 @@ func NewBinaryOperatorExpression(lhs, rhs Serializer, operator string, additiona binaryExpression.ExpressionInterfaceImpl.Parent = binaryExpression - return binaryExpression + return complexExpr(binaryExpression) } func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if c.lhs == nil { - panic("jet: lhs is nil for '" + c.operator + "' operator") - } - if c.rhs == nil { - panic("jet: rhs is nil for '" + c.operator + "' operator") - } - - wrap := !contains(options, NoWrap) - - if wrap { - out.WriteString("(") - } - if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil { serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam) serializeOverrideFunc(statement, out, FallTrough(options)...) @@ -134,10 +121,6 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu out.WriteString(c.operator) c.rhs.serialize(statement, out, FallTrough(options)...) } - - if wrap { - out.WriteString(")") - } } // A prefix operator Expression @@ -148,27 +131,19 @@ type prefixExpression struct { operator string } -func newPrefixOperatorExpression(expression Expression, operator string) *prefixExpression { +func newPrefixOperatorExpression(expression Expression, operator string) Expression { prefixExpression := &prefixExpression{ expression: expression, operator: operator, } prefixExpression.ExpressionInterfaceImpl.Parent = prefixExpression - return prefixExpression + return complexExpr(prefixExpression) } func (p *prefixExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString("(") out.WriteString(p.operator) - - if p.expression == nil { - panic("jet: nil prefix expression in prefix operator " + p.operator) - } - p.expression.serialize(statement, out, FallTrough(options)...) - - out.WriteString(")") } // A postfix operator Expression @@ -191,12 +166,7 @@ func newPostfixOperatorExpression(expression Expression, operator string) *postf } func (p *postfixOpExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if p.expression == nil { - panic("jet: nil prefix expression in postfix operator " + p.operator) - } - p.expression.serialize(statement, out, FallTrough(options)...) - out.WriteString(p.operator) } @@ -220,14 +190,10 @@ func NewBetweenOperatorExpression(expression, min, max Expression, notBetween bo newBetweenOperator.ExpressionInterfaceImpl.Parent = newBetweenOperator - return BoolExp(newBetweenOperator) + return BoolExp(complexExpr(newBetweenOperator)) } func (p *betweenOperatorExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - if !contains(options, NoWrap) { - out.WriteString("(") - } - p.expression.serialize(statement, out, FallTrough(options)...) if p.notBetween { out.WriteString("NOT") @@ -236,8 +202,41 @@ func (p *betweenOperatorExpression) serialize(statement StatementType, out *SQLB p.min.serialize(statement, out, FallTrough(options)...) out.WriteString("AND") p.max.serialize(statement, out, FallTrough(options)...) +} + +type complexExpression struct { + ExpressionInterfaceImpl + expressions Expression +} + +func complexExpr(expressions Expression) Expression { + complexExpression := &complexExpression{expressions: expressions} + complexExpression.ExpressionInterfaceImpl.Parent = complexExpression + + return complexExpression +} + +func (s *complexExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + if !contains(options, NoWrap) { + out.WriteString("(") + } + + s.expressions.serialize(statement, out, options...) // FallTrough here because complexExpression is just a wrapper if !contains(options, NoWrap) { out.WriteString(")") } } + +type skipParenthesisWrap struct { + Expression +} + +func skipWrap(expression Expression) Expression { + return &skipParenthesisWrap{expression} +} + +// since the expression is a function parameter, there is no need to wrap it in parentheses +func (s *skipParenthesisWrap) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + s.Expression.serialize(statement, out, append(options, NoWrap)...) +} diff --git a/internal/jet/expression_test.go b/internal/jet/expression_test.go index 7b7bed6..74e6a05 100644 --- a/internal/jet/expression_test.go +++ b/internal/jet/expression_test.go @@ -4,10 +4,6 @@ import ( "testing" ) -func TestInvalidExpression(t *testing.T) { - assertClauseSerializeErr(t, table2Col3.ADD(nil), `jet: rhs is nil for '+' operator`) -} - func TestExpressionIS_NULL(t *testing.T) { assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL") assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL") diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 9a647e9..3e40edf 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -81,7 +81,7 @@ func LOG(floatExpression FloatExpression) FloatExpression { // ----------------- Aggregate functions -------------------// // AVG is aggregate function used to calculate avg value from numeric expression -func AVG(numericExpression NumericExpression) floatWindowExpression { +func AVG(numericExpression Expression) floatWindowExpression { return NewFloatWindowFunc("AVG", numericExpression) } @@ -594,7 +594,7 @@ type funcExpressionImpl struct { func NewFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl { funcExp := &funcExpressionImpl{ name: name, - expressions: expressions, + expressions: parameters(expressions), } if parent != nil { @@ -606,9 +606,22 @@ func NewFunc(name string, expressions []Expression, parent Expression) *funcExpr return funcExp } +func parameters(expressions []Expression) []Expression { + var ret []Expression + + for _, expression := range expressions { + if _, isStatement := expression.(Statement); isStatement { + ret = append(ret, expression) + } else { + ret = append(ret, skipWrap(expression)) + } + } + + return ret +} + // NewFloatWindowFunc creates new float function with name and expressions func newWindowFunc(name string, expressions ...Expression) windowExpression { - newFun := NewFunc(name, expressions, nil) windowExpr := newWindowExpression(newFun) newFun.ExpressionInterfaceImpl.Parent = windowExpr @@ -698,12 +711,12 @@ type integerFunc struct { } func newIntegerFunc(name string, expressions ...Expression) IntegerExpression { - floatFunc := &integerFunc{} + intFunc := &integerFunc{} - floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc) - floatFunc.integerInterfaceImpl.parent = floatFunc + intFunc.funcExpressionImpl = *NewFunc(name, expressions, intFunc) + intFunc.integerInterfaceImpl.parent = intFunc - return floatFunc + return intFunc } // NewFloatWindowFunc creates new float function with name and expressions @@ -806,7 +819,7 @@ func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc { return timestampzFunc } -// Func can be used to call an custom or as of yet unsupported function in the database. +// Func can be used to call custom or unsupported database functions. func Func(name string, expressions ...Expression) Expression { return NewFunc(name, expressions, nil) } diff --git a/mysql/expressions.go b/mysql/expressions.go index b585719..5ab9c1f 100644 --- a/mysql/expressions.go +++ b/mysql/expressions.go @@ -87,7 +87,7 @@ var ( RawDate = jet.RawDate ) -// Func can be used to call an custom or as of yet unsupported function in the database. +// Func can be used to call custom or unsupported database functions. var Func = jet.Func // NewEnumValue creates new named enum value diff --git a/mysql/select_statement_test.go b/mysql/select_statement_test.go index 630f800..37827d5 100644 --- a/mysql/select_statement_test.go +++ b/mysql/select_statement_test.go @@ -147,10 +147,10 @@ func TestSelect_NOT_EXISTS(t *testing.T) { ))), ` SELECT table1.col_int AS "table1.col_int" FROM db.table1 -WHERE (NOT (EXISTS ( +WHERE NOT (EXISTS ( SELECT table2.col_int AS "table2.col_int" FROM db.table2 WHERE table1.col_int = table2.col_int - ))); + )); `) } diff --git a/postgres/expressions.go b/postgres/expressions.go index 8d0be88..c5c2065 100644 --- a/postgres/expressions.go +++ b/postgres/expressions.go @@ -100,7 +100,7 @@ var ( RawDate = jet.RawDate ) -// Func can be used to call an custom or as of yet unsupported function in the database. +// Func can be used to call custom or unsupported database functions. var Func = jet.Func // NewEnumValue creates new named enum value diff --git a/postgres/functions_test.go b/postgres/functions_test.go new file mode 100644 index 0000000..4190f70 --- /dev/null +++ b/postgres/functions_test.go @@ -0,0 +1,12 @@ +package postgres + +import "testing" + +func TestROW(t *testing.T) { + assertSerialize(t, ROW(SELECT(Int(1))), `ROW(( + SELECT $1 +))`) + assertSerialize(t, ROW(Int(1), SELECT(Int(2)), Float(11.11)), `ROW($1, ( + SELECT $2 +), $3)`) +} diff --git a/sqlite/expressions.go b/sqlite/expressions.go index d1d4737..8d45735 100644 --- a/sqlite/expressions.go +++ b/sqlite/expressions.go @@ -90,7 +90,7 @@ var ( RawDate = jet.RawDate ) -// Func can be used to call an custom or as of yet unsupported function in the database. +// Func can be used to call custom or unsupported database functions. var Func = jet.Func // NewEnumValue creates new named enum value diff --git a/sqlite/select_statement_test.go b/sqlite/select_statement_test.go index 0ba76f0..a42fe06 100644 --- a/sqlite/select_statement_test.go +++ b/sqlite/select_statement_test.go @@ -147,10 +147,10 @@ func TestSelect_NOT_EXISTS(t *testing.T) { ))), ` SELECT table1.col_int AS "table1.col_int" FROM db.table1 -WHERE (NOT (EXISTS ( +WHERE NOT (EXISTS ( SELECT table2.col_int AS "table2.col_int" FROM db.table2 WHERE table1.col_int = table2.col_int - ))); + )); `) } diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 15a4862..428a0e6 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -264,22 +264,22 @@ SELECT (all_types.'numeric' = all_types.'numeric') AS "eq1", (all_types.'numeric' > ?) AS "gt2", (all_types.'numeric' BETWEEN ? AND all_types.'decimal') AS "between", (all_types.'numeric' NOT BETWEEN (all_types.'decimal' * ?) AND ?) AS "not_between", - TRUNCATE((all_types.'decimal' + all_types.'decimal'), ?) AS "add1", - TRUNCATE((all_types.'decimal' + ?), ?) AS "add2", - TRUNCATE((all_types.'decimal' - all_types.decimal_ptr), ?) AS "sub1", - TRUNCATE((all_types.'decimal' - ?), ?) AS "sub2", - TRUNCATE((all_types.'decimal' * all_types.decimal_ptr), ?) AS "mul1", - TRUNCATE((all_types.'decimal' * ?), ?) AS "mul2", - TRUNCATE((all_types.'decimal' / all_types.decimal_ptr), ?) AS "div1", - TRUNCATE((all_types.'decimal' / ?), ?) AS "div2", - TRUNCATE((all_types.'decimal' % all_types.decimal_ptr), ?) AS "mod1", - TRUNCATE((all_types.'decimal' % ?), ?) AS "mod2", + TRUNCATE(all_types.'decimal' + all_types.'decimal', ?) AS "add1", + TRUNCATE(all_types.'decimal' + ?, ?) AS "add2", + TRUNCATE(all_types.'decimal' - all_types.decimal_ptr, ?) AS "sub1", + TRUNCATE(all_types.'decimal' - ?, ?) AS "sub2", + TRUNCATE(all_types.'decimal' * all_types.decimal_ptr, ?) AS "mul1", + TRUNCATE(all_types.'decimal' * ?, ?) AS "mul2", + TRUNCATE(all_types.'decimal' / all_types.decimal_ptr, ?) AS "div1", + TRUNCATE(all_types.'decimal' / ?, ?) AS "div2", + TRUNCATE(all_types.'decimal' % all_types.decimal_ptr, ?) AS "mod1", + TRUNCATE(all_types.'decimal' % ?, ?) AS "mod2", TRUNCATE(POW(all_types.'decimal', all_types.decimal_ptr), ?) AS "pow1", TRUNCATE(POW(all_types.'decimal', ?), ?) AS "pow2", TRUNCATE(ABS(all_types.'decimal'), ?) AS "abs", TRUNCATE(POWER(all_types.'decimal', ?), ?) AS "power", TRUNCATE(SQRT(all_types.'decimal'), ?) AS "sqrt", - TRUNCATE(POWER(all_types.'decimal', (? / ?)), ?) AS "cbrt", + TRUNCATE(POWER(all_types.'decimal', ? / ?), ?) AS "cbrt", CEIL(all_types.'real') AS "ceil", FLOOR(all_types.'real') AS "floor", ROUND(all_types.'decimal') AS "round1", @@ -395,7 +395,7 @@ SELECT all_types.big_int AS "all_types.big_int", (all_types.big_int DIV ?) AS "div2", (all_types.big_int % all_types.big_int) AS "mod1", (all_types.big_int % ?) AS "mod2", - POW(all_types.small_int, (all_types.small_int DIV ?)) AS "pow1", + POW(all_types.small_int, all_types.small_int DIV ?) AS "pow1", POW(all_types.small_int, ?) AS "pow2", (all_types.small_int & all_types.small_int) AS "bit_and1", (all_types.small_int & all_types.small_int) AS "bit_and2", @@ -411,7 +411,7 @@ SELECT all_types.big_int AS "all_types.big_int", (all_types.small_int >> ?) AS "bit shift right 2", ABS(all_types.big_int) AS "abs", SQRT(ABS(all_types.big_int)) AS "sqrt", - POWER(ABS(all_types.big_int), (? / ?)) AS "cbrt" + POWER(ABS(all_types.big_int), ? / ?) AS "cbrt" FROM test_sample.all_types LIMIT ?; `, "''", "`")) diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 1f3f177..405ec9e 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -563,16 +563,16 @@ SELECT (all_types.numeric = all_types.numeric) AS "eq1", (all_types.numeric > $10) AS "gt2", (all_types.numeric BETWEEN $11 AND all_types.decimal) AS "between", (all_types.numeric NOT BETWEEN (all_types.decimal * $12) AND $13) AS "not_between", - TRUNC((all_types.decimal + all_types.decimal), $14::smallint) AS "add1", - TRUNC((all_types.decimal + $15), $16::smallint) AS "add2", - TRUNC((all_types.decimal - all_types.decimal_ptr), $17::integer) AS "sub1", - TRUNC((all_types.decimal - $18), $19::smallint) AS "sub2", - TRUNC((all_types.decimal * all_types.decimal_ptr), $20::smallint) AS "mul1", - TRUNC((all_types.decimal * $21), $22::integer) AS "mul2", - TRUNC((all_types.decimal / all_types.decimal_ptr), $23::integer) AS "div1", - TRUNC((all_types.decimal / $24), $25::smallint) AS "div2", - TRUNC((all_types.decimal % all_types.decimal_ptr), $26::smallint) AS "mod1", - TRUNC((all_types.decimal % $27), $28::smallint) AS "mod2", + TRUNC(all_types.decimal + all_types.decimal, $14::smallint) AS "add1", + TRUNC(all_types.decimal + $15, $16::smallint) AS "add2", + TRUNC(all_types.decimal - all_types.decimal_ptr, $17::integer) AS "sub1", + TRUNC(all_types.decimal - $18, $19::smallint) AS "sub2", + TRUNC(all_types.decimal * all_types.decimal_ptr, $20::smallint) AS "mul1", + TRUNC(all_types.decimal * $21, $22::integer) AS "mul2", + TRUNC(all_types.decimal / all_types.decimal_ptr, $23::integer) AS "div1", + TRUNC(all_types.decimal / $24, $25::smallint) AS "div2", + TRUNC(all_types.decimal % all_types.decimal_ptr, $26::smallint) AS "mod1", + TRUNC(all_types.decimal % $27, $28::smallint) AS "mod2", TRUNC(POW(all_types.decimal, all_types.decimal_ptr), $29::smallint) AS "pow1", TRUNC(POW(all_types.decimal, $30), $31::smallint) AS "pow2", TRUNC(ABS(all_types.decimal), $32::smallint) AS "abs", @@ -698,7 +698,7 @@ SELECT all_types.big_int AS "all_types.big_int", (all_types.big_int / $16::integer) AS "div2", (all_types.big_int % all_types.big_int) AS "mod1", (all_types.big_int % $17::bigint) AS "mod2", - POW(all_types.small_int, (all_types.small_int / $18::smallint)) AS "pow1", + POW(all_types.small_int, all_types.small_int / $18::smallint) AS "pow1", POW(all_types.small_int, $19::smallint) AS "pow2", (all_types.small_int & all_types.small_int) AS "bit_and1", (all_types.small_int & all_types.small_int) AS "bit_and2",