From a5b77695894d6216c4d80153b9d54c671f06d4cb Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 15 May 2021 11:54:41 +0200 Subject: [PATCH] Add RawStatement support RawStatement method creates new sql statements from raw query and optional map of named arguments. --- internal/jet/literal_expression.go | 66 +----------- internal/jet/raw_statement.go | 47 +++++++++ internal/jet/sql_builder.go | 68 +++++++++++++ internal/testutils/test_utils.go | 28 +++--- mysql/expressions_test.go | 6 ++ mysql/statement.go | 8 ++ postgres/expressions_test.go | 16 ++- postgres/statement.go | 8 ++ postgres/utils_test.go | 4 + tests/mysql/raw_statement_test.go | 82 +++++++++++++++ tests/postgres/raw_statements_test.go | 138 ++++++++++++++++++++++++++ 11 files changed, 393 insertions(+), 78 deletions(-) create mode 100644 internal/jet/raw_statement.go create mode 100644 mysql/statement.go create mode 100644 postgres/statement.go create mode 100644 tests/mysql/raw_statement_test.go create mode 100644 tests/postgres/raw_statements_test.go diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index b5ac9f0..9ce9a81 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -2,8 +2,6 @@ package jet import ( "fmt" - "sort" - "strings" "time" ) @@ -402,71 +400,15 @@ type rawExpression struct { } func (n *rawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - raw := n.Raw - - type namedArgumentPosition struct { - Name string - Value interface{} - Position int + if !n.noWrap && !contains(options, NoWrap) { + out.WriteByte('(') } - var namedArgumentPositions []namedArgumentPosition - - for namedArg, value := range n.NamedArgument { - rawCopy := n.Raw - rawIndex := 0 - exists := false - - // one named argument can occur multiple times inside raw string - for { - index := strings.Index(rawCopy, namedArg) - if index == -1 { - break - } - - exists = true - namedArgumentPositions = append(namedArgumentPositions, namedArgumentPosition{ - Name: namedArg, - Value: value, - Position: rawIndex + index, - }) - - rawCopy = rawCopy[index+len(namedArg):] - rawIndex += index + len(namedArg) - } - - if !exists { - panic("jet: named argument '" + namedArg + "' does not appear in raw query") - } - } - - sort.Slice(namedArgumentPositions, func(i, j int) bool { - return namedArgumentPositions[i].Position < namedArgumentPositions[j].Position - }) - - for _, namedArgumentPos := range namedArgumentPositions { - // if named argument does not exists in raw string do not add argument to the list of arguments - // It can happen if the same argument occurs multiple times in postgres query. - if !strings.Contains(raw, namedArgumentPos.Name) { - continue - } - out.Args = append(out.Args, namedArgumentPos.Value) - currentArgNum := len(out.Args) - - dialectPlaceholder := out.Dialect.ArgumentPlaceholder()(currentArgNum) - // if placeholder is not unique identifier ($1, $2, etc..), we will replace just one occurence of the argument - toReplace := -1 // all occurrences - if dialectPlaceholder == "?" { - toReplace = 1 // just one occurrence - } - raw = strings.Replace(raw, namedArgumentPos.Name, dialectPlaceholder, toReplace) - } + out.insertRawQuery(n.Raw, n.NamedArgument) if !n.noWrap && !contains(options, NoWrap) { - raw = "(" + raw + ")" + out.WriteByte(')') } - - out.WriteString(raw) } // Raw can be used for any unsupported functions, operators or expressions. diff --git a/internal/jet/raw_statement.go b/internal/jet/raw_statement.go new file mode 100644 index 0000000..191c7b4 --- /dev/null +++ b/internal/jet/raw_statement.go @@ -0,0 +1,47 @@ +package jet + +type rawStatementImpl struct { + serializerStatementInterfaceImpl + + RawQuery string + NamedArguments map[string]interface{} +} + +// RawStatement creates new sql statements from raw query and optional map of named arguments +func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) Statement { + newRawStatement := rawStatementImpl{ + serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ + dialect: dialect, + statementType: "", + parent: nil, + }, + RawQuery: rawQuery, + } + + if len(namedArgument) > 0 { + newRawStatement.NamedArguments = namedArgument[0] + } + + newRawStatement.parent = &newRawStatement + + return &newRawStatement +} + +func (s *rawStatementImpl) projections() ProjectionList { + return nil +} + +func (s *rawStatementImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + if !contains(options, NoWrap) { + out.WriteString("(") + out.IncreaseIdent() + } + + out.insertRawQuery(s.RawQuery, s.NamedArguments) + + if !contains(options, NoWrap) { + out.DecreaseIdent() + out.NewLine() + out.WriteString(")") + } +} diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index bca078d..6241fee 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -7,6 +7,7 @@ import ( "github.com/go-jet/jet/v2/internal/utils" "github.com/google/uuid" "reflect" + "sort" "strconv" "strings" "time" @@ -135,6 +136,73 @@ func (s *SQLBuilder) insertParametrizedArgument(arg interface{}) { s.WriteString(argPlaceholder) } +func (s *SQLBuilder) insertRawQuery(raw string, namedArg map[string]interface{}) { + type namedArgumentPosition struct { + Name string + Value interface{} + Position int + } + + var namedArgumentPositions []namedArgumentPosition + + for namedArg, value := range namedArg { + rawCopy := raw + rawIndex := 0 + exists := false + + // one named argument can occur multiple times inside raw string + for { + index := strings.Index(rawCopy, namedArg) + if index == -1 { + break + } + + exists = true + namedArgumentPositions = append(namedArgumentPositions, namedArgumentPosition{ + Name: namedArg, + Value: value, + Position: rawIndex + index, + }) + + rawCopy = rawCopy[index+len(namedArg):] + rawIndex += index + len(namedArg) + } + + if !exists { + panic("jet: named argument '" + namedArg + "' does not appear in raw query") + } + } + + sort.Slice(namedArgumentPositions, func(i, j int) bool { + return namedArgumentPositions[i].Position < namedArgumentPositions[j].Position + }) + + for _, namedArgumentPos := range namedArgumentPositions { + // if named argument does not exists in raw string do not add argument to the list of arguments + // It can happen if the same argument occurs multiple times in postgres query. + if !strings.Contains(raw, namedArgumentPos.Name) { + continue + } + s.Args = append(s.Args, namedArgumentPos.Value) + currentArgNum := len(s.Args) + + placeholder := s.Dialect.ArgumentPlaceholder()(currentArgNum) + // if placeholder is not unique identifier ($1, $2, etc..), we will replace just one occurrence of the argument + toReplace := -1 // all occurrences + if placeholder == "?" { + toReplace = 1 // just one occurrence + } + + if s.Debug { + placeholder = argToString(namedArgumentPos.Value) + } + + raw = strings.Replace(raw, namedArgumentPos.Name, placeholder, toReplace) + } + + s.WriteString(raw) +} + func argToString(value interface{}) string { if utils.IsNil(value) { return "NULL" diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index f849219..dd5e790 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -116,8 +116,8 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st AssertDeepEqual(t, args, expectedArgs, "arguments are not equal") } - debuqSql := query.DebugSql() - assertQueryString(t, debuqSql, expectedQuery) + debugSql := query.DebugSql() + assertQueryString(t, debugSql, expectedQuery) } // AssertSerialize checks if clause serialize produces expected query and args @@ -134,18 +134,6 @@ func AssertSerialize(t *testing.T, dialect jet.Dialect, serializer jet.Serialize } } -// AssertClauseSerialize checks if clause serialize produces expected query and args -func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Clause, query string, args ...interface{}) { - out := jet.SQLBuilder{Dialect: dialect} - clause.Serialize(jet.SelectStatementType, &out) - - require.Equal(t, out.Buff.String(), query) - - if len(args) > 0 { - AssertDeepEqual(t, out.Args, args) - } -} - // AssertDebugSerialize checks if clause serialize produces expected debug query and args func AssertDebugSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { out := jet.SQLBuilder{Dialect: dialect, Debug: true} @@ -158,6 +146,18 @@ func AssertDebugSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializ } } +// AssertClauseSerialize checks if clause serialize produces expected query and args +func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Clause, query string, args ...interface{}) { + out := jet.SQLBuilder{Dialect: dialect} + clause.Serialize(jet.SelectStatementType, &out) + + require.Equal(t, out.Buff.String(), query) + + if len(args) > 0 { + AssertDeepEqual(t, out.Args, args) + } +} + // AssertPanicErr checks if running a function fun produces a panic with errorStr string func AssertPanicErr(t *testing.T, fun func(), errorStr string) { defer func() { diff --git a/mysql/expressions_test.go b/mysql/expressions_test.go index 127fccd..2826c08 100644 --- a/mysql/expressions_test.go +++ b/mysql/expressions_test.go @@ -9,14 +9,20 @@ import ( func TestRaw(t *testing.T) { assertSerialize(t, Raw("current_database()"), "(current_database())") + assertDebugSerialize(t, Raw("current_database()"), "(current_database())") assertSerialize(t, Raw(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22}), "(? + table.colInt + ?)", 11, 22) + assertDebugSerialize(t, Raw(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22}), + "(11 + table.colInt + 22)") assertSerialize(t, Int(700).ADD(RawInt("#1 + table.colInt + #2", RawArgs{"#1": 11, "#2": 22})), "(? + (? + table.colInt + ?))", int64(700), 11, 22) + assertDebugSerialize(t, + Int(700).ADD(RawInt("#1 + table.colInt + #2", RawArgs{"#1": 11, "#2": 22})), + "(700 + (11 + table.colInt + 22))") } func TestRawDuplicateArguments(t *testing.T) { diff --git a/mysql/statement.go b/mysql/statement.go new file mode 100644 index 0000000..073adce --- /dev/null +++ b/mysql/statement.go @@ -0,0 +1,8 @@ +package mysql + +import "github.com/go-jet/jet/v2/internal/jet" + +// RawStatement creates new sql statements from raw query and optional map of named arguments +func RawStatement(rawQuery string, namedArguments ...RawArgs) Statement { + return jet.RawStatement(Dialect, rawQuery, namedArguments...) +} diff --git a/postgres/expressions_test.go b/postgres/expressions_test.go index 1e7f3c6..77c3dee 100644 --- a/postgres/expressions_test.go +++ b/postgres/expressions_test.go @@ -9,27 +9,39 @@ import ( func TestRaw(t *testing.T) { assertSerialize(t, Raw("current_database()"), "(current_database())") + assertDebugSerialize(t, Raw("current_database()"), "(current_database())") assertSerialize(t, Raw(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22}), "($1 + table.colInt + $2)", 11, 22) + assertDebugSerialize(t, Raw(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22}), + "(11 + table.colInt + 22)") assertSerialize(t, Int(700).ADD(RawInt(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22})), "($1 + ($2 + table.colInt + $3))", int64(700), 11, 22) + assertDebugSerialize(t, + Int(700).ADD(RawInt(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22})), + "(700 + (11 + table.colInt + 22))") } func TestDuplicateArguments(t *testing.T) { - assertSerialize(t, Raw(":arg + table.colInt + :arg", RawArgs{":arg": 11}), "($1 + table.colInt + $1)", 11) + assertDebugSerialize(t, Raw(":arg + table.colInt + :arg", RawArgs{":arg": 11}), + "(11 + table.colInt + 11)") assertSerialize(t, Raw("#age + table.colInt + #year + #age + #year + 11", RawArgs{"#age": 11, "#year": 2000}), "($1 + table.colInt + $2 + $1 + $2 + 11)", 11, 2000) + assertDebugSerialize(t, Raw("#age + table.colInt + #year + #age + #year + 11", RawArgs{"#age": 11, "#year": 2000}), + "(11 + table.colInt + 2000 + 11 + 2000 + 11)") assertSerialize(t, Raw("#1 + all_types.integer + #2 + #1 + #2 + #3 + #4", RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}), - `($1 + all_types.integer + $2 + $1 + $2 + $3 + $4)`, 11, 22, 11, 22, 33, 44) + `($1 + all_types.integer + $2 + $1 + $2 + $3 + $4)`, 11, 22, 33, 44) + assertDebugSerialize(t, Raw("#1 + all_types.integer + #2 + #1 + #2 + #3 + #4", + RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}), + `(11 + all_types.integer + 22 + 11 + 22 + 33 + 44)`) } func TestRawInvalidArguments(t *testing.T) { diff --git a/postgres/statement.go b/postgres/statement.go new file mode 100644 index 0000000..d10bd65 --- /dev/null +++ b/postgres/statement.go @@ -0,0 +1,8 @@ +package postgres + +import "github.com/go-jet/jet/v2/internal/jet" + +// RawStatement creates new sql statements from raw query and optional map of named arguments +func RawStatement(rawQuery string, namedArguments ...RawArgs) Statement { + return jet.RawStatement(Dialect, rawQuery, namedArguments...) +} diff --git a/postgres/utils_test.go b/postgres/utils_test.go index bd59b43..292d7e4 100644 --- a/postgres/utils_test.go +++ b/postgres/utils_test.go @@ -58,6 +58,10 @@ func assertSerialize(t *testing.T, serializer jet.Serializer, query string, args testutils.AssertSerialize(t, Dialect, serializer, query, args...) } +func assertDebugSerialize(t *testing.T, serializer jet.Serializer, query string, args ...interface{}) { + testutils.AssertDebugSerialize(t, Dialect, serializer, query, args...) +} + func assertClauseSerialize(t *testing.T, clause jet.Clause, query string, args ...interface{}) { testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) } diff --git a/tests/mysql/raw_statement_test.go b/tests/mysql/raw_statement_test.go new file mode 100644 index 0000000..9af4c46 --- /dev/null +++ b/tests/mysql/raw_statement_test.go @@ -0,0 +1,82 @@ +package mysql + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model" + + . "github.com/go-jet/jet/v2/mysql" +) + +func TestRawStatementSelect(t *testing.T) { + stmt := RawStatement(` + SELECT actor.first_name AS "actor.first_name" + FROM dvds.actor + WHERE actor.actor_id = 2`) + + testutils.AssertStatementSql(t, stmt, ` + SELECT actor.first_name AS "actor.first_name" + FROM dvds.actor + WHERE actor.actor_id = 2; +`) + testutils.AssertDebugStatementSql(t, stmt, ` + SELECT actor.first_name AS "actor.first_name" + FROM dvds.actor + WHERE actor.actor_id = 2; +`) + var actor model.Actor + err := stmt.Query(db, &actor) + require.NoError(t, err) + require.Equal(t, actor.FirstName, "NICK") +} + +func TestRawStatementSelectWithArguments(t *testing.T) { + stmt := RawStatement(` + SELECT DISTINCT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM dvds.actor + WHERE actor.actor_id IN (#actorID1, #actorID2, #actorID3) AND ((#actorID1 / #actorID2) <> (#actorID2 * #actorID3)) + ORDER BY actor.actor_id`, + RawArgs{ + "#actorID1": int64(1), + "#actorID2": int64(2), + "#actorID3": int64(3), + }, + ) + + testutils.AssertStatementSql(t, stmt, ` + SELECT DISTINCT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM dvds.actor + WHERE actor.actor_id IN (?, ?, ?) AND ((? / ?) <> (? * ?)) + ORDER BY actor.actor_id; +`, int64(1), int64(2), int64(3), int64(1), int64(2), int64(2), int64(3)) + + testutils.AssertDebugStatementSql(t, stmt, ` + SELECT DISTINCT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM dvds.actor + WHERE actor.actor_id IN (1, 2, 3) AND ((1 / 2) <> (2 * 3)) + ORDER BY actor.actor_id; +`) + + var actor []model.Actor + err := stmt.Query(db, &actor) + require.NoError(t, err) + + testutils.AssertDeepEqual(t, actor[1], model.Actor{ + ActorID: 2, + FirstName: "NICK", + LastName: "WAHLBERG", + LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 04:34:33", 2), + }) +} diff --git a/tests/postgres/raw_statements_test.go b/tests/postgres/raw_statements_test.go new file mode 100644 index 0000000..61c3228 --- /dev/null +++ b/tests/postgres/raw_statements_test.go @@ -0,0 +1,138 @@ +package postgres + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" + model2 "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model" + + . "github.com/go-jet/jet/v2/postgres" +) + +func TestRawStatementSelect(t *testing.T) { + stmt := RawStatement(` + SELECT actor.first_name AS "actor.first_name" + FROM dvds.actor + WHERE actor.actor_id = 2`) + + testutils.AssertStatementSql(t, stmt, ` + SELECT actor.first_name AS "actor.first_name" + FROM dvds.actor + WHERE actor.actor_id = 2; +`) + testutils.AssertDebugStatementSql(t, stmt, ` + SELECT actor.first_name AS "actor.first_name" + FROM dvds.actor + WHERE actor.actor_id = 2; +`) + var actor model.Actor + err := stmt.Query(db, &actor) + require.NoError(t, err) + require.Equal(t, actor.FirstName, "Nick") +} + +func TestRawStatementSelectWithArguments(t *testing.T) { + stmt := RawStatement(` + SELECT DISTINCT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM dvds.actor + WHERE actor.actor_id IN (#actorID1, #actorID2, #actorID3) AND ((#actorID1 / #actorID2) <> (#actorID2 * #actorID3)) + ORDER BY actor.actor_id`, + RawArgs{ + "#actorID1": int64(1), + "#actorID2": int64(2), + "#actorID3": int64(3), + }, + ) + + testutils.AssertStatementSql(t, stmt, ` + SELECT DISTINCT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM dvds.actor + WHERE actor.actor_id IN ($1, $2, $3) AND (($1 / $2) <> ($2 * $3)) + ORDER BY actor.actor_id; +`, int64(1), int64(2), int64(3)) + + testutils.AssertDebugStatementSql(t, stmt, ` + SELECT DISTINCT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM dvds.actor + WHERE actor.actor_id IN (1, 2, 3) AND ((1 / 2) <> (2 * 3)) + ORDER BY actor.actor_id; +`) + + var actor []model.Actor + err := stmt.Query(db, &actor) + require.NoError(t, err) + + testutils.AssertDeepEqual(t, actor[1], model.Actor{ + ActorID: 2, + FirstName: "Nick", + LastName: "Wahlberg", + LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:47:57.62", 2), + }) +} + +func TestRawInsert(t *testing.T) { + cleanUpLinkTable(t) + + stmt := RawStatement(` +INSERT INTO test_sample.link (id, url, name, description) +VALUES (@id1, @url1, @name1, DEFAULT), + (200, @url1, @name1, NULL), + (@id2, @url2, @name2, DEFAULT), + (@id3, @url3, @name3, NULL) +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"`, + RawArgs{ + "@id1": 100, "@url1": "http://www.postgresqltutorial.com", "@name1": "PostgreSQL Tutorial", + "@id2": 101, "@url2": "http://www.google.com", "@name2": "Google", + "@id3": 102, "@url3": "http://www.yahoo.com", "@name3": "Yahoo", + }) + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO test_sample.link (id, url, name, description) +VALUES ($1, $2, $3, DEFAULT), + (200, $2, $3, NULL), + ($4, $5, $6, DEFAULT), + ($7, $8, $9, NULL) +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"; +`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", + 101, "http://www.google.com", "Google", + 102, "http://www.yahoo.com", "Yahoo") + + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO test_sample.link (id, url, name, description) +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + (200, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL), + (101, 'http://www.google.com', 'Google', DEFAULT), + (102, 'http://www.yahoo.com', 'Yahoo', NULL) +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"; +`) + + var links []model2.Link + err := stmt.Query(db, &links) + require.NoError(t, err) + require.Len(t, links, 4) + require.Equal(t, links[0].ID, int32(100)) + require.Equal(t, links[1].URL, "http://www.postgresqltutorial.com") + require.Equal(t, links[2].Name, "Google") + require.Nil(t, links[2].Description) +}