From 7af9072b8d9a6278f4f436bc42c6f2f1d66b09a1 Mon Sep 17 00:00:00 2001 From: go-jet Date: Fri, 14 May 2021 12:15:35 +0200 Subject: [PATCH] Allow Raw helper to accept named arguments --- internal/jet/literal_expression.go | 136 ++++++++++++++++++++++++++++- internal/testutils/test_utils.go | 2 + mysql/expressions.go | 19 +++- mysql/expressions_test.go | 56 ++++++++++++ mysql/interval.go | 2 +- mysql/select_statement_test.go | 18 ++-- postgres/expressions.go | 21 ++++- postgres/expressions_test.go | 64 ++++++++++++++ postgres/interval_expression.go | 2 +- tests/mysql/alltypes_test.go | 25 ++++-- tests/postgres/alltypes_test.go | 29 ++++-- tests/testdata | 2 +- 12 files changed, 340 insertions(+), 36 deletions(-) create mode 100644 mysql/expressions_test.go create mode 100644 postgres/expressions_test.go diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 29560e4..b5ac9f0 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -2,6 +2,8 @@ package jet import ( "fmt" + "sort" + "strings" "time" ) @@ -394,22 +396,148 @@ func WRAP(expression ...Expression) Expression { type rawExpression struct { ExpressionInterfaceImpl - Raw string + Raw string + NamedArgument map[string]interface{} + noWrap bool } func (n *rawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString(n.Raw) + raw := n.Raw + + type namedArgumentPosition struct { + Name string + Value interface{} + Position int + } + + var namedArgumentPositions []namedArgumentPosition + + for namedArg, value := range n.NamedArgument { + rawCopy := n.Raw + rawIndex := 0 + exists := false + + // one named argument can occur multiple times inside raw string + for { + index := strings.Index(rawCopy, namedArg) + if index == -1 { + break + } + + exists = true + namedArgumentPositions = append(namedArgumentPositions, namedArgumentPosition{ + Name: namedArg, + Value: value, + Position: rawIndex + index, + }) + + rawCopy = rawCopy[index+len(namedArg):] + rawIndex += index + len(namedArg) + } + + if !exists { + panic("jet: named argument '" + namedArg + "' does not appear in raw query") + } + } + + sort.Slice(namedArgumentPositions, func(i, j int) bool { + return namedArgumentPositions[i].Position < namedArgumentPositions[j].Position + }) + + for _, namedArgumentPos := range namedArgumentPositions { + // if named argument does not exists in raw string do not add argument to the list of arguments + // It can happen if the same argument occurs multiple times in postgres query. + if !strings.Contains(raw, namedArgumentPos.Name) { + continue + } + out.Args = append(out.Args, namedArgumentPos.Value) + currentArgNum := len(out.Args) + + dialectPlaceholder := out.Dialect.ArgumentPlaceholder()(currentArgNum) + // if placeholder is not unique identifier ($1, $2, etc..), we will replace just one occurence of the argument + toReplace := -1 // all occurrences + if dialectPlaceholder == "?" { + toReplace = 1 // just one occurrence + } + raw = strings.Replace(raw, namedArgumentPos.Name, dialectPlaceholder, toReplace) + } + + if !n.noWrap && !contains(options, NoWrap) { + raw = "(" + raw + ")" + } + + out.WriteString(raw) } // Raw can be used for any unsupported functions, operators or expressions. // For example: Raw("current_database()") -func Raw(raw string, parent ...Expression) Expression { - rawExp := &rawExpression{Raw: raw} +func Raw(raw string, namedArgs ...map[string]interface{}) Expression { + var namedArguments map[string]interface{} + + if len(namedArgs) > 0 { + namedArguments = namedArgs[0] + } + + rawExp := &rawExpression{ + Raw: raw, + NamedArgument: namedArguments, + } + rawExp.ExpressionInterfaceImpl.Parent = rawExp + + return rawExp +} + +// RawWithParent is a Raw constructor used for construction dialect specific expression +func RawWithParent(raw string, parent ...Expression) Expression { + rawExp := &rawExpression{ + Raw: raw, + noWrap: true, + } rawExp.ExpressionInterfaceImpl.Parent = OptionalOrDefaultExpression(rawExp, parent...) return rawExp } +// Raw helper that for integer expressions +func RawInt(raw string, namedArgs ...map[string]interface{}) IntegerExpression { + return IntExp(Raw(raw, namedArgs...)) +} + +// Raw helper that for float expressions +func RawFloat(raw string, namedArgs ...map[string]interface{}) FloatExpression { + return FloatExp(Raw(raw, namedArgs...)) +} + +// Raw helper that for string expressions +func RawString(raw string, namedArgs ...map[string]interface{}) StringExpression { + return StringExp(Raw(raw, namedArgs...)) +} + +// Raw helper that for time expressions +func RawTime(raw string, namedArgs ...map[string]interface{}) TimeExpression { + return TimeExp(Raw(raw, namedArgs...)) +} + +// Raw helper that for time with time zone expressions +func RawTimez(raw string, namedArgs ...map[string]interface{}) TimezExpression { + return TimezExp(Raw(raw, namedArgs...)) +} + +// Raw helper that for timestamp expressions +func RawTimestamp(raw string, namedArgs ...map[string]interface{}) TimestampExpression { + return TimestampExp(Raw(raw, namedArgs...)) +} + +// Raw helper that for timestamp with time zone expressions +func RawTimestampz(raw string, namedArgs ...map[string]interface{}) TimestampzExpression { + return TimestampzExp(Raw(raw, namedArgs...)) +} + +// Raw helper that for date expressions +func RawDate(raw string, namedArgs ...map[string]interface{}) DateExpression { + return DateExp(Raw(raw, namedArgs...)) +} + // UUID is a helper function to create string literal expression from uuid object // value can be any uuid type with a String method func UUID(value fmt.Stringer) StringExpression { diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 9dd6318..f849219 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -226,12 +226,14 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) { if !assert.True(t, cmp.Equal(actual, expected), msg) { printDiff(actual, expected) + t.FailNow() } } func assertQueryString(t *testing.T, actual, expected string) { if !assert.Equal(t, actual, expected) { printDiff(actual, expected) + t.FailNow() } } diff --git a/mysql/expressions.go b/mysql/expressions.go index 3ee660d..7c13939 100644 --- a/mysql/expressions.go +++ b/mysql/expressions.go @@ -70,9 +70,22 @@ var DateTimeExp = jet.TimestampExp // Does not add sql cast to generated sql builder output. var TimestampExp = jet.TimestampExp -// Raw can be used for any unsupported functions, operators or expressions. -// For example: Raw("current_database()") -var Raw = jet.Raw +// RawArgs is type used to pass optional arguments to Raw method +type RawArgs = map[string]interface{} + +var ( + // Raw can be used for any unsupported functions, operators or expressions. + // For example: Raw("current_database()") + Raw = jet.Raw + + // Raw helper methods for each of the mysql type + RawInt = jet.RawInt + RawFloat = jet.RawFloat + RawString = jet.RawString + RawTime = jet.RawTime + RawTimestamp = jet.RawTimestamp + RawDate = jet.RawDate +) // Func can be used to call an custom or as of yet unsupported function in the database. var Func = jet.Func diff --git a/mysql/expressions_test.go b/mysql/expressions_test.go new file mode 100644 index 0000000..127fccd --- /dev/null +++ b/mysql/expressions_test.go @@ -0,0 +1,56 @@ +package mysql + +import ( + "testing" + time2 "time" + + "github.com/stretchr/testify/require" +) + +func TestRaw(t *testing.T) { + assertSerialize(t, Raw("current_database()"), "(current_database())") + + assertSerialize(t, Raw(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22}), + "(? + table.colInt + ?)", 11, 22) + + assertSerialize(t, + Int(700).ADD(RawInt("#1 + table.colInt + #2", RawArgs{"#1": 11, "#2": 22})), + "(? + (? + table.colInt + ?))", + int64(700), 11, 22) +} + +func TestRawDuplicateArguments(t *testing.T) { + assertSerialize(t, Raw(":arg + table.colInt + :arg", RawArgs{":arg": 11}), + "(? + table.colInt + ?)", 11, 11) + + assertSerialize(t, Raw("#age + table.colInt + #year + #age + #year + 11", RawArgs{"#age": 11, "#year": 2000}), + "(? + table.colInt + ? + ? + ? + 11)", 11, 2000, 11, 2000) + + assertSerialize(t, Raw("#1 + all_types.integer + #2 + #1 + #2 + #3 + #4", + RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}), + `(? + all_types.integer + ? + ? + ? + ? + ?)`, 11, 22, 11, 22, 33, 44) +} + +func TestRawInvalidArguments(t *testing.T) { + defer func() { + r := recover() + require.Equal(t, "jet: named argument 'first_arg' does not appear in raw query", r) + }() + + assertSerialize(t, Raw("table.colInt + :second_arg", RawArgs{"first_arg": 11}), "(table.colInt + ?)", 22) +} + +func TestRawType(t *testing.T) { + assertSerialize(t, RawFloat("table.colInt + &float", RawArgs{"&float": 11.22}).EQ(Float(3.14)), + "((table.colInt + ?) = ?)", 11.22, 3.14) + assertSerialize(t, RawString("table.colStr || str", RawArgs{"str": "doe"}).EQ(String("john doe")), + "((table.colStr || ?) = ?)", "doe", "john doe") + + time := time2.Now() + assertSerialize(t, RawTime("table.colTime").EQ(TimeT(time)), + "((table.colTime) = CAST(? AS TIME))", time) + assertSerialize(t, RawTimestamp("table.colTimestamp").EQ(TimestampT(time)), + "((table.colTimestamp) = TIMESTAMP(?))", time) + assertSerialize(t, RawDate("table.colDate").EQ(DateT(time)), + "((table.colDate) = CAST(? AS DATE))", time) +} diff --git a/mysql/interval.go b/mysql/interval.go index f325bb8..c563855 100644 --- a/mysql/interval.go +++ b/mysql/interval.go @@ -97,7 +97,7 @@ func INTERVAL(value interface{}, unitType unitType) Interval { // INTERVALe creates new temporal interval from expresion and unit type. func INTERVALe(expr Expression, unitType unitType) Interval { return jet.NewInterval(jet.ListSerializer{ - Serializers: []jet.Serializer{expr, jet.Raw(string(unitType))}, + Serializers: []jet.Serializer{expr, jet.RawWithParent(string(unitType))}, Separator: " ", }) } diff --git a/mysql/select_statement_test.go b/mysql/select_statement_test.go index 2312a60..630f800 100644 --- a/mysql/select_statement_test.go +++ b/mysql/select_statement_test.go @@ -136,15 +136,15 @@ LOCK IN SHARE MODE; func TestSelect_NOT_EXISTS(t *testing.T) { testutils.AssertStatementSql(t, SELECT(table1ColInt). - FROM(table1). - WHERE( - NOT(EXISTS( - SELECT(table2ColInt). - FROM(table2). - WHERE( - table1ColInt.EQ(table2ColInt), - ), - ))), ` + FROM(table1). + WHERE( + NOT(EXISTS( + SELECT(table2ColInt). + FROM(table2). + WHERE( + table1ColInt.EQ(table2ColInt), + ), + ))), ` SELECT table1.col_int AS "table1.col_int" FROM db.table1 WHERE (NOT (EXISTS ( diff --git a/postgres/expressions.go b/postgres/expressions.go index 57a3335..ca6223b 100644 --- a/postgres/expressions.go +++ b/postgres/expressions.go @@ -81,9 +81,24 @@ var TimestampExp = jet.TimestampExp // Does not add sql cast to generated sql builder output. var TimestampzExp = jet.TimestampzExp -// Raw can be used for any unsupported functions, operators or expressions. -// For example: Raw("current_database()") -var Raw = jet.Raw +// RawArgs is type used to pass optional arguments to Raw method +type RawArgs = map[string]interface{} + +var ( + // Raw can be used for any unsupported functions, operators or expressions. + // For example: Raw("current_database()") + Raw = jet.Raw + + // Raw helper methods for each of the postgres type + RawInt = jet.RawInt + RawFloat = jet.RawFloat + RawString = jet.RawString + RawTime = jet.RawTime + RawTimez = jet.RawTimez + RawTimestamp = jet.RawTimestamp + RawTimestampz = jet.RawTimestampz + RawDate = jet.RawDate +) // Func can be used to call an custom or as of yet unsupported function in the database. var Func = jet.Func diff --git a/postgres/expressions_test.go b/postgres/expressions_test.go new file mode 100644 index 0000000..1e7f3c6 --- /dev/null +++ b/postgres/expressions_test.go @@ -0,0 +1,64 @@ +package postgres + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRaw(t *testing.T) { + assertSerialize(t, Raw("current_database()"), "(current_database())") + + assertSerialize(t, Raw(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22}), + "($1 + table.colInt + $2)", 11, 22) + + assertSerialize(t, + Int(700).ADD(RawInt(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22})), + "($1 + ($2 + table.colInt + $3))", + int64(700), 11, 22) +} + +func TestDuplicateArguments(t *testing.T) { + + assertSerialize(t, Raw(":arg + table.colInt + :arg", RawArgs{":arg": 11}), + "($1 + table.colInt + $1)", 11) + + assertSerialize(t, Raw("#age + table.colInt + #year + #age + #year + 11", RawArgs{"#age": 11, "#year": 2000}), + "($1 + table.colInt + $2 + $1 + $2 + 11)", 11, 2000) + + assertSerialize(t, Raw("#1 + all_types.integer + #2 + #1 + #2 + #3 + #4", + RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}), + `($1 + all_types.integer + $2 + $1 + $2 + $3 + $4)`, 11, 22, 11, 22, 33, 44) +} + +func TestRawInvalidArguments(t *testing.T) { + defer func() { + r := recover() + require.Equal(t, "jet: named argument 'first_arg' does not appear in raw query", r) + }() + + assertSerialize(t, Raw("table.colInt + :second_arg", RawArgs{ + "first_arg": 11, + "second_arg": 22, + }), "(table.colInt + $1)", 22) +} + +func TestRawHelperMethods(t *testing.T) { + assertSerialize(t, RawFloat("table.colInt + :float", RawArgs{":float": 11.22}).EQ(Float(3.14)), + "((table.colInt + $1) = $2)", 11.22, 3.14) + assertSerialize(t, RawString("table.colStr || str", RawArgs{"str": "doe"}).EQ(String("john doe")), + "((table.colStr || $1) = $2)", "doe", "john doe") + + now := time.Now() + assertSerialize(t, RawTime("table.colTime").EQ(TimeT(now)), + "((table.colTime) = $1::time without time zone)", now) + assertSerialize(t, RawTimez("table.colTime").EQ(TimezT(now)), + "((table.colTime) = $1::time with time zone)", now) + assertSerialize(t, RawTimestamp("table.colTimestamp").EQ(TimestampT(now)), + "((table.colTimestamp) = $1::timestamp without time zone)", now) + assertSerialize(t, RawTimestampz("table.colTimestampz").EQ(TimestampzT(now)), + "((table.colTimestampz) = $1::timestamp with time zone)", now) + assertSerialize(t, RawDate("table.colDate").EQ(DateT(now)), + "((table.colDate) = $1::date)", now) +} diff --git a/postgres/interval_expression.go b/postgres/interval_expression.go index df2ed60..b8468cf 100644 --- a/postgres/interval_expression.go +++ b/postgres/interval_expression.go @@ -128,7 +128,7 @@ func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { newInterval := &intervalExpression{} - newInterval.Expression = jet.Raw(intervalStr, newInterval) + newInterval.Expression = jet.RawWithParent(intervalStr, newInterval) newInterval.intervalInterfaceImpl.parent = newInterval return newInterval diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 70ca277..a85d0b4 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -89,14 +89,17 @@ func TestExpressionOperators(t *testing.T) { AllTypes.DatePtr.IS_NOT_NULL().AS("result.is_not_null"), AllTypes.SmallIntPtr.IN(Int(11), Int(22)).AS("result.in"), AllTypes.SmallIntPtr.IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.in_select"), + + Raw("CURRENT_USER()").AS("result.raw"), + Raw(":first + COALESCE(all_types.small_int_ptr, 0) + :second", RawArgs{":first": 78, ":second": 56}). + AS("result.raw_arg"), + Raw("#1 + all_types.integer + #2 + #1 + #3 + #4", RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}). + AS("result.raw_arg2"), + AllTypes.SmallIntPtr.NOT_IN(Int(11), Int(22), NULL).AS("result.not_in"), AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"), - - Raw("DATABASE()"), ).LIMIT(2) - //fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, strings.Replace(` SELECT all_types.'integer' IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", @@ -105,15 +108,17 @@ SELECT all_types.'integer' IS NULL AS "result.is_null", SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types ))) AS "result.in_select", + (CURRENT_USER()) AS "result.raw", + (? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg", + (? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2", (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", (all_types.small_int_ptr NOT IN (( SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types - ))) AS "result.not_in_select", - DATABASE() + ))) AS "result.not_in_select" FROM test_sample.all_types LIMIT ?; -`, "'", "`", -1), int64(11), int64(22), int64(11), int64(22), int64(2)) +`, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2)) var dest []struct { common.ExpressionTestResult `alias:"result.*"` @@ -132,6 +137,9 @@ LIMIT ?; "IsNotNull": true, "In": false, "InSelect": false, + "Raw": "jet@localhost", + "RawArg": 148, + "RawArg2": -1479, "NotIn": null, "NotInSelect": true }, @@ -140,6 +148,9 @@ LIMIT ?; "IsNotNull": false, "In": null, "InSelect": null, + "Raw": "jet@localhost", + "RawArg": 134, + "RawArg2": -1479, "NotIn": null, "NotInSelect": null } diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 4379e67..8bd1682 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -1,17 +1,19 @@ package postgres import ( - "github.com/stretchr/testify/require" "testing" "time" + "github.com/stretchr/testify/require" + + "github.com/google/uuid" + "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/table" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/view" "github.com/go-jet/jet/v2/tests/testdata/results/common" - "github.com/google/uuid" ) func TestAllTypesSelect(t *testing.T) { @@ -221,12 +223,16 @@ func TestExpressionOperators(t *testing.T) { AllTypes.DatePtr.IS_NOT_NULL().AS("result.is_not_null"), AllTypes.SmallIntPtr.IN(Int(11), Int(22)).AS("result.in"), AllTypes.SmallIntPtr.IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.in_select"), + + Raw("CURRENT_USER").AS("result.raw"), + Raw("#1 + COALESCE(all_types.small_int_ptr, 0) + #2", RawArgs{"#1": 78, "#2": 56}).AS("result.raw_arg"), + Raw("#1 + all_types.integer + #2 + #1 + #3 + #4", + RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}).AS("result.raw_arg2"), + AllTypes.SmallIntPtr.NOT_IN(Int(11), Int(22), NULL).AS("result.not_in"), AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"), ).LIMIT(2) - //fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, ` SELECT all_types.integer IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", @@ -235,14 +241,17 @@ SELECT all_types.integer IS NULL AS "result.is_null", SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types ))) AS "result.in_select", - (all_types.small_int_ptr NOT IN ($3, $4, NULL)) AS "result.not_in", + (CURRENT_USER) AS "result.raw", + ($3 + COALESCE(all_types.small_int_ptr, 0) + $4) AS "result.raw_arg", + ($5 + all_types.integer + $6 + $5 + $7 + $8) AS "result.raw_arg2", + (all_types.small_int_ptr NOT IN ($9, $10, NULL)) AS "result.not_in", (all_types.small_int_ptr NOT IN (( SELECT all_types.integer AS "all_types.integer" FROM test_sample.all_types ))) AS "result.not_in_select" FROM test_sample.all_types -LIMIT $5; -`, int64(11), int64(22), int64(11), int64(22), int64(2)) +LIMIT $11; +`, int64(11), int64(22), 78, 56, 11, 22, 33, 44, int64(11), int64(22), int64(2)) var dest []struct { common.ExpressionTestResult `alias:"result.*"` @@ -261,6 +270,9 @@ LIMIT $5; "IsNotNull": true, "In": false, "InSelect": false, + "Raw": "jet", + "RawArg": 148, + "RawArg2": 421, "NotIn": null, "NotInSelect": true }, @@ -269,6 +281,9 @@ LIMIT $5; "IsNotNull": false, "In": null, "InSelect": null, + "Raw": "jet", + "RawArg": 134, + "RawArg2": 421, "NotIn": null, "NotInSelect": null } diff --git a/tests/testdata b/tests/testdata index 1c97764..0d52780 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 1c977643ceb0df149fc953ad617e2a86c6ecdd65 +Subproject commit 0d52780c6510d4b1e560081a82648b85c555ce43