Add RawStatement support

RawStatement method creates new sql statements from raw query and optional map of named arguments.
This commit is contained in:
go-jet 2021-05-15 11:54:41 +02:00
parent e95a2385ee
commit a5b7769589
11 changed files with 393 additions and 78 deletions

View file

@ -2,8 +2,6 @@ package jet
import (
"fmt"
"sort"
"strings"
"time"
)
@ -402,71 +400,15 @@ type rawExpression struct {
}
func (n *rawExpression) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
raw := n.Raw
type namedArgumentPosition struct {
Name string
Value interface{}
Position int
if !n.noWrap && !contains(options, NoWrap) {
out.WriteByte('(')
}
var namedArgumentPositions []namedArgumentPosition
for namedArg, value := range n.NamedArgument {
rawCopy := n.Raw
rawIndex := 0
exists := false
// one named argument can occur multiple times inside raw string
for {
index := strings.Index(rawCopy, namedArg)
if index == -1 {
break
}
exists = true
namedArgumentPositions = append(namedArgumentPositions, namedArgumentPosition{
Name: namedArg,
Value: value,
Position: rawIndex + index,
})
rawCopy = rawCopy[index+len(namedArg):]
rawIndex += index + len(namedArg)
}
if !exists {
panic("jet: named argument '" + namedArg + "' does not appear in raw query")
}
}
sort.Slice(namedArgumentPositions, func(i, j int) bool {
return namedArgumentPositions[i].Position < namedArgumentPositions[j].Position
})
for _, namedArgumentPos := range namedArgumentPositions {
// if named argument does not exists in raw string do not add argument to the list of arguments
// It can happen if the same argument occurs multiple times in postgres query.
if !strings.Contains(raw, namedArgumentPos.Name) {
continue
}
out.Args = append(out.Args, namedArgumentPos.Value)
currentArgNum := len(out.Args)
dialectPlaceholder := out.Dialect.ArgumentPlaceholder()(currentArgNum)
// if placeholder is not unique identifier ($1, $2, etc..), we will replace just one occurence of the argument
toReplace := -1 // all occurrences
if dialectPlaceholder == "?" {
toReplace = 1 // just one occurrence
}
raw = strings.Replace(raw, namedArgumentPos.Name, dialectPlaceholder, toReplace)
}
out.insertRawQuery(n.Raw, n.NamedArgument)
if !n.noWrap && !contains(options, NoWrap) {
raw = "(" + raw + ")"
out.WriteByte(')')
}
out.WriteString(raw)
}
// Raw can be used for any unsupported functions, operators or expressions.

View file

@ -0,0 +1,47 @@
package jet
type rawStatementImpl struct {
serializerStatementInterfaceImpl
RawQuery string
NamedArguments map[string]interface{}
}
// RawStatement creates new sql statements from raw query and optional map of named arguments
func RawStatement(dialect Dialect, rawQuery string, namedArgument ...map[string]interface{}) Statement {
newRawStatement := rawStatementImpl{
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
dialect: dialect,
statementType: "",
parent: nil,
},
RawQuery: rawQuery,
}
if len(namedArgument) > 0 {
newRawStatement.NamedArguments = namedArgument[0]
}
newRawStatement.parent = &newRawStatement
return &newRawStatement
}
func (s *rawStatementImpl) projections() ProjectionList {
return nil
}
func (s *rawStatementImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if !contains(options, NoWrap) {
out.WriteString("(")
out.IncreaseIdent()
}
out.insertRawQuery(s.RawQuery, s.NamedArguments)
if !contains(options, NoWrap) {
out.DecreaseIdent()
out.NewLine()
out.WriteString(")")
}
}

View file

@ -7,6 +7,7 @@ import (
"github.com/go-jet/jet/v2/internal/utils"
"github.com/google/uuid"
"reflect"
"sort"
"strconv"
"strings"
"time"
@ -135,6 +136,73 @@ func (s *SQLBuilder) insertParametrizedArgument(arg interface{}) {
s.WriteString(argPlaceholder)
}
func (s *SQLBuilder) insertRawQuery(raw string, namedArg map[string]interface{}) {
type namedArgumentPosition struct {
Name string
Value interface{}
Position int
}
var namedArgumentPositions []namedArgumentPosition
for namedArg, value := range namedArg {
rawCopy := raw
rawIndex := 0
exists := false
// one named argument can occur multiple times inside raw string
for {
index := strings.Index(rawCopy, namedArg)
if index == -1 {
break
}
exists = true
namedArgumentPositions = append(namedArgumentPositions, namedArgumentPosition{
Name: namedArg,
Value: value,
Position: rawIndex + index,
})
rawCopy = rawCopy[index+len(namedArg):]
rawIndex += index + len(namedArg)
}
if !exists {
panic("jet: named argument '" + namedArg + "' does not appear in raw query")
}
}
sort.Slice(namedArgumentPositions, func(i, j int) bool {
return namedArgumentPositions[i].Position < namedArgumentPositions[j].Position
})
for _, namedArgumentPos := range namedArgumentPositions {
// if named argument does not exists in raw string do not add argument to the list of arguments
// It can happen if the same argument occurs multiple times in postgres query.
if !strings.Contains(raw, namedArgumentPos.Name) {
continue
}
s.Args = append(s.Args, namedArgumentPos.Value)
currentArgNum := len(s.Args)
placeholder := s.Dialect.ArgumentPlaceholder()(currentArgNum)
// if placeholder is not unique identifier ($1, $2, etc..), we will replace just one occurrence of the argument
toReplace := -1 // all occurrences
if placeholder == "?" {
toReplace = 1 // just one occurrence
}
if s.Debug {
placeholder = argToString(namedArgumentPos.Value)
}
raw = strings.Replace(raw, namedArgumentPos.Name, placeholder, toReplace)
}
s.WriteString(raw)
}
func argToString(value interface{}) string {
if utils.IsNil(value) {
return "NULL"

View file

@ -116,8 +116,8 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st
AssertDeepEqual(t, args, expectedArgs, "arguments are not equal")
}
debuqSql := query.DebugSql()
assertQueryString(t, debuqSql, expectedQuery)
debugSql := query.DebugSql()
assertQueryString(t, debugSql, expectedQuery)
}
// AssertSerialize checks if clause serialize produces expected query and args
@ -134,18 +134,6 @@ func AssertSerialize(t *testing.T, dialect jet.Dialect, serializer jet.Serialize
}
}
// AssertClauseSerialize checks if clause serialize produces expected query and args
func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Clause, query string, args ...interface{}) {
out := jet.SQLBuilder{Dialect: dialect}
clause.Serialize(jet.SelectStatementType, &out)
require.Equal(t, out.Buff.String(), query)
if len(args) > 0 {
AssertDeepEqual(t, out.Args, args)
}
}
// AssertDebugSerialize checks if clause serialize produces expected debug query and args
func AssertDebugSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) {
out := jet.SQLBuilder{Dialect: dialect, Debug: true}
@ -158,6 +146,18 @@ func AssertDebugSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializ
}
}
// AssertClauseSerialize checks if clause serialize produces expected query and args
func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Clause, query string, args ...interface{}) {
out := jet.SQLBuilder{Dialect: dialect}
clause.Serialize(jet.SelectStatementType, &out)
require.Equal(t, out.Buff.String(), query)
if len(args) > 0 {
AssertDeepEqual(t, out.Args, args)
}
}
// AssertPanicErr checks if running a function fun produces a panic with errorStr string
func AssertPanicErr(t *testing.T, fun func(), errorStr string) {
defer func() {