[bug] Escape reserved words used as identifier.
This commit is contained in:
parent
63c1fd6430
commit
3019fdbbb2
10 changed files with 226 additions and 4 deletions
|
|
@ -1,5 +1,7 @@
|
|||
package jet
|
||||
|
||||
import "strings"
|
||||
|
||||
// Dialect interface
|
||||
type Dialect interface {
|
||||
Name() string
|
||||
|
|
@ -9,6 +11,7 @@ type Dialect interface {
|
|||
AliasQuoteChar() byte
|
||||
IdentifierQuoteChar() byte
|
||||
ArgumentPlaceholder() QueryPlaceholderFunc
|
||||
IsReservedWord(name string) bool
|
||||
}
|
||||
|
||||
// SerializerFunc func
|
||||
|
|
@ -29,6 +32,7 @@ type DialectParams struct {
|
|||
AliasQuoteChar byte
|
||||
IdentifierQuoteChar byte
|
||||
ArgumentPlaceholder QueryPlaceholderFunc
|
||||
ReservedWords []string
|
||||
}
|
||||
|
||||
// NewDialect creates new dialect with params
|
||||
|
|
@ -41,6 +45,7 @@ func NewDialect(params DialectParams) Dialect {
|
|||
aliasQuoteChar: params.AliasQuoteChar,
|
||||
identifierQuoteChar: params.IdentifierQuoteChar,
|
||||
argumentPlaceholder: params.ArgumentPlaceholder,
|
||||
reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -52,6 +57,7 @@ type dialectImpl struct {
|
|||
aliasQuoteChar byte
|
||||
identifierQuoteChar byte
|
||||
argumentPlaceholder QueryPlaceholderFunc
|
||||
reservedWords map[string]bool
|
||||
|
||||
supportsReturning bool
|
||||
}
|
||||
|
|
@ -89,3 +95,17 @@ func (d *dialectImpl) IdentifierQuoteChar() byte {
|
|||
func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc {
|
||||
return d.argumentPlaceholder
|
||||
}
|
||||
|
||||
func (d *dialectImpl) IsReservedWord(name string) bool {
|
||||
_, isReservedWord := d.reservedWords[strings.ToLower(name)]
|
||||
return isReservedWord
|
||||
}
|
||||
|
||||
func arrayOfStringsToMapOfStrings(arr []string) map[string]bool {
|
||||
ret := map[string]bool{}
|
||||
for _, elem := range arr {
|
||||
ret[strings.ToLower(elem)] = true
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ func (s *SQLBuilder) WriteString(str string) {
|
|||
|
||||
// WriteIdentifier adds identifier to output SQL
|
||||
func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
|
||||
if shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 {
|
||||
if s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 {
|
||||
identQuoteChar := string(s.Dialect.IdentifierQuoteChar())
|
||||
s.WriteString(identQuoteChar + name + identQuoteChar)
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -211,6 +211,7 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st
|
|||
}
|
||||
}
|
||||
|
||||
// AssertDeepEqual checks if actual and expected objects are deeply equal.
|
||||
func AssertDeepEqual(t *testing.T, actual, expected interface{}) {
|
||||
assert.True(t, cmp.Equal(actual, expected))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ func ToGoIdentifier(databaseIdentifier string) string {
|
|||
return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier))
|
||||
}
|
||||
|
||||
// ToGoEnumValueIdentifier converts enum value name to Go identifier name.
|
||||
func ToGoEnumValueIdentifier(enumName, enumValue string) string {
|
||||
enumValueIdentifier := ToGoIdentifier(enumValue)
|
||||
if !unicode.IsLetter([]rune(enumValueIdentifier)[0]) {
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ func newDialect() jet.Dialect {
|
|||
ArgumentPlaceholder: func(ord int) string {
|
||||
return "$" + strconv.Itoa(ord)
|
||||
},
|
||||
ReservedWords: reservedWords,
|
||||
}
|
||||
|
||||
return jet.NewDialect(dialectParams)
|
||||
|
|
@ -105,3 +106,83 @@ func postgresNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.Serializer
|
|||
jet.Serialize(expressions[1], statement, out, options...)
|
||||
}
|
||||
}
|
||||
|
||||
var reservedWords = []string{
|
||||
"ALL",
|
||||
"ANALYSE",
|
||||
"ANALYZE",
|
||||
"AND",
|
||||
"ANY",
|
||||
"ARRAY",
|
||||
"AS",
|
||||
"ASC",
|
||||
"ASYMMETRIC",
|
||||
"BOTH",
|
||||
"CASE",
|
||||
"CAST",
|
||||
"CHECK",
|
||||
"COLLATE",
|
||||
"COLUMN",
|
||||
"CONSTRAINT",
|
||||
"CREATE",
|
||||
"CURRENT_CATALOG",
|
||||
"CURRENT_DATE",
|
||||
"CURRENT_ROLE",
|
||||
"CURRENT_TIME",
|
||||
"CURRENT_TIMESTAMP",
|
||||
"CURRENT_USER",
|
||||
"DEFAULT",
|
||||
"DEFERRABLE",
|
||||
"DESC",
|
||||
"DISTINCT",
|
||||
"DO",
|
||||
"ELSE",
|
||||
"END",
|
||||
"EXCEPT",
|
||||
"FALSE",
|
||||
"FETCH",
|
||||
"FOR",
|
||||
"FOREIGN",
|
||||
"FROM",
|
||||
"GRANT",
|
||||
"GROUP",
|
||||
"HAVING",
|
||||
"IN",
|
||||
"INITIALLY",
|
||||
"INTERSECT",
|
||||
"INTO",
|
||||
"LATERAL",
|
||||
"LEADING",
|
||||
"LIMIT",
|
||||
"LOCALTIME",
|
||||
"LOCALTIMESTAMP",
|
||||
"NOT",
|
||||
"NULL",
|
||||
"OFFSET",
|
||||
"ON",
|
||||
"ONLY",
|
||||
"OR",
|
||||
"ORDER",
|
||||
"PLACING",
|
||||
"PRIMARY",
|
||||
"REFERENCES",
|
||||
"RETURNING",
|
||||
"SELECT",
|
||||
"SESSION_USER",
|
||||
"SOME",
|
||||
"SYMMETRIC",
|
||||
"TABLE",
|
||||
"THEN",
|
||||
"TO",
|
||||
"TRAILING",
|
||||
"TRUE",
|
||||
"UNION",
|
||||
"UNIQUE",
|
||||
"USER",
|
||||
"USING",
|
||||
"VARIADIC",
|
||||
"WHEN",
|
||||
"WHERE",
|
||||
"WINDOW",
|
||||
"WITH",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -74,3 +74,21 @@ func TestNOT_IN(t *testing.T) {
|
|||
FROM db.table2
|
||||
)))`, int64(12))
|
||||
}
|
||||
|
||||
func TestReservedWordEscaped(t *testing.T) {
|
||||
var table1ColUser = IntervalColumn("user")
|
||||
var table1ColVariadic = IntervalColumn("VARIADIC")
|
||||
var table1ColProcedure = IntervalColumn("procedure")
|
||||
|
||||
_ = NewTable(
|
||||
"db",
|
||||
"table1",
|
||||
table1ColUser,
|
||||
table1ColVariadic,
|
||||
table1ColProcedure,
|
||||
)
|
||||
|
||||
assertSerialize(t, table1ColUser, `table1."user"`)
|
||||
assertSerialize(t, table1ColVariadic, `table1."VARIADIC"`)
|
||||
assertSerialize(t, table1ColProcedure, `table1.procedure`)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1095,3 +1095,53 @@ var allTypesJson = `
|
|||
}
|
||||
]
|
||||
`
|
||||
|
||||
func TestReservedWord(t *testing.T) {
|
||||
stmt := SELECT(User.AllColumns).
|
||||
FROM(User)
|
||||
|
||||
// NOTE: A word that follows a period in a qualified name must be an identifier, so it
|
||||
// need not be quoted even if it is reserved
|
||||
testutils.AssertDebugStatementSql(t, stmt, `
|
||||
SELECT user.column AS "user.column",
|
||||
user.use AS "user.use",
|
||||
user.ceil AS "user.ceil",
|
||||
user.commit AS "user.commit",
|
||||
user.create AS "user.create",
|
||||
user.default AS "user.default",
|
||||
user.desc AS "user.desc",
|
||||
user.empty AS "user.empty",
|
||||
user.float AS "user.float",
|
||||
user.join AS "user.join",
|
||||
user.like AS "user.like",
|
||||
user.max AS "user.max",
|
||||
user.rank AS "user.rank"
|
||||
FROM test_sample.user;
|
||||
`)
|
||||
|
||||
var dest []model.User
|
||||
err := stmt.Query(db, &dest)
|
||||
assert.NoError(t, err)
|
||||
|
||||
testutils.PrintJson(dest)
|
||||
|
||||
testutils.AssertJSON(t, dest, `
|
||||
[
|
||||
{
|
||||
"Column": "Column",
|
||||
"Use": "CHECK",
|
||||
"Ceil": "CEIL",
|
||||
"Commit": "COMMIT",
|
||||
"Create": "CREATE",
|
||||
"Default": "DEFAULT",
|
||||
"Desc": "DESC",
|
||||
"Empty": "EMPTY",
|
||||
"Float": "FLOAT",
|
||||
"Join": "JOIN",
|
||||
"Like": "LIKE",
|
||||
"Max": "MAX",
|
||||
"Rank": "RANK"
|
||||
}
|
||||
]
|
||||
`)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -291,7 +291,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go",
|
||||
"mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go")
|
||||
"mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go")
|
||||
|
||||
testutils.AssertFileContent(t, modelDir+"all_types.go", "\npackage model", allTypesModelContent)
|
||||
|
||||
|
|
@ -299,7 +299,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go",
|
||||
"person.go", "person_phone.go", "weird_names_table.go")
|
||||
"person.go", "person_phone.go", "weird_names_table.go", "user.go")
|
||||
|
||||
testutils.AssertFileContent(t, tableDir+"all_types.go", "\npackage table", allTypesTableContent)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -328,3 +328,54 @@ FROM test_sample."WEIRD NAMES TABLE";
|
|||
WeirdColuName16: "Doe",
|
||||
})
|
||||
}
|
||||
|
||||
func TestReserwedWordEscape(t *testing.T) {
|
||||
stmt := SELECT(User.AllColumns).
|
||||
FROM(User)
|
||||
|
||||
//fmt.Println(stmt.DebugSql())
|
||||
|
||||
testutils.AssertDebugStatementSql(t, stmt, `
|
||||
SELECT "User"."column" AS "User.column",
|
||||
"User"."check" AS "User.check",
|
||||
"User".ceil AS "User.ceil",
|
||||
"User".commit AS "User.commit",
|
||||
"User"."create" AS "User.create",
|
||||
"User"."default" AS "User.default",
|
||||
"User"."desc" AS "User.desc",
|
||||
"User".empty AS "User.empty",
|
||||
"User".float AS "User.float",
|
||||
"User".join AS "User.join",
|
||||
"User".like AS "User.like",
|
||||
"User".max AS "User.max",
|
||||
"User".rank AS "User.rank"
|
||||
FROM test_sample."User";
|
||||
`)
|
||||
|
||||
var dest []model.User
|
||||
|
||||
err := stmt.Query(db, &dest)
|
||||
assert.NoError(t, err)
|
||||
|
||||
testutils.PrintJson(dest)
|
||||
|
||||
testutils.AssertJSON(t, dest, `
|
||||
[
|
||||
{
|
||||
"Column": "Column",
|
||||
"Check": "CHECK",
|
||||
"Ceil": "CEIL",
|
||||
"Commit": "COMMIT",
|
||||
"Create": "CREATE",
|
||||
"Default": "DEFAULT",
|
||||
"Desc": "DESC",
|
||||
"Empty": "EMPTY",
|
||||
"Float": "FLOAT",
|
||||
"Join": "JOIN",
|
||||
"Like": "LIKE",
|
||||
"Max": "MAX",
|
||||
"Rank": "RANK"
|
||||
}
|
||||
]
|
||||
`)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
Subproject commit 02e0795d1e06b959d0c564dc1e349159d57b1bf6
|
||||
Subproject commit 889e07c0ebaf6b4021e31cce29b5861eb5c8cc17
|
||||
Loading…
Add table
Add a link
Reference in a new issue