[bug] Escape reserved words used as identifier.
This commit is contained in:
parent
63c1fd6430
commit
3019fdbbb2
10 changed files with 226 additions and 4 deletions
|
|
@ -1,5 +1,7 @@
|
||||||
package jet
|
package jet
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
// Dialect interface
|
// Dialect interface
|
||||||
type Dialect interface {
|
type Dialect interface {
|
||||||
Name() string
|
Name() string
|
||||||
|
|
@ -9,6 +11,7 @@ type Dialect interface {
|
||||||
AliasQuoteChar() byte
|
AliasQuoteChar() byte
|
||||||
IdentifierQuoteChar() byte
|
IdentifierQuoteChar() byte
|
||||||
ArgumentPlaceholder() QueryPlaceholderFunc
|
ArgumentPlaceholder() QueryPlaceholderFunc
|
||||||
|
IsReservedWord(name string) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// SerializerFunc func
|
// SerializerFunc func
|
||||||
|
|
@ -29,6 +32,7 @@ type DialectParams struct {
|
||||||
AliasQuoteChar byte
|
AliasQuoteChar byte
|
||||||
IdentifierQuoteChar byte
|
IdentifierQuoteChar byte
|
||||||
ArgumentPlaceholder QueryPlaceholderFunc
|
ArgumentPlaceholder QueryPlaceholderFunc
|
||||||
|
ReservedWords []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDialect creates new dialect with params
|
// NewDialect creates new dialect with params
|
||||||
|
|
@ -41,6 +45,7 @@ func NewDialect(params DialectParams) Dialect {
|
||||||
aliasQuoteChar: params.AliasQuoteChar,
|
aliasQuoteChar: params.AliasQuoteChar,
|
||||||
identifierQuoteChar: params.IdentifierQuoteChar,
|
identifierQuoteChar: params.IdentifierQuoteChar,
|
||||||
argumentPlaceholder: params.ArgumentPlaceholder,
|
argumentPlaceholder: params.ArgumentPlaceholder,
|
||||||
|
reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -52,6 +57,7 @@ type dialectImpl struct {
|
||||||
aliasQuoteChar byte
|
aliasQuoteChar byte
|
||||||
identifierQuoteChar byte
|
identifierQuoteChar byte
|
||||||
argumentPlaceholder QueryPlaceholderFunc
|
argumentPlaceholder QueryPlaceholderFunc
|
||||||
|
reservedWords map[string]bool
|
||||||
|
|
||||||
supportsReturning bool
|
supportsReturning bool
|
||||||
}
|
}
|
||||||
|
|
@ -89,3 +95,17 @@ func (d *dialectImpl) IdentifierQuoteChar() byte {
|
||||||
func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc {
|
func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc {
|
||||||
return d.argumentPlaceholder
|
return d.argumentPlaceholder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *dialectImpl) IsReservedWord(name string) bool {
|
||||||
|
_, isReservedWord := d.reservedWords[strings.ToLower(name)]
|
||||||
|
return isReservedWord
|
||||||
|
}
|
||||||
|
|
||||||
|
func arrayOfStringsToMapOfStrings(arr []string) map[string]bool {
|
||||||
|
ret := map[string]bool{}
|
||||||
|
for _, elem := range arr {
|
||||||
|
ret[strings.ToLower(elem)] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,7 @@ func (s *SQLBuilder) WriteString(str string) {
|
||||||
|
|
||||||
// WriteIdentifier adds identifier to output SQL
|
// WriteIdentifier adds identifier to output SQL
|
||||||
func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
|
func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
|
||||||
if shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 {
|
if s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 {
|
||||||
identQuoteChar := string(s.Dialect.IdentifierQuoteChar())
|
identQuoteChar := string(s.Dialect.IdentifierQuoteChar())
|
||||||
s.WriteString(identQuoteChar + name + identQuoteChar)
|
s.WriteString(identQuoteChar + name + identQuoteChar)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -211,6 +211,7 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AssertDeepEqual checks if actual and expected objects are deeply equal.
|
||||||
func AssertDeepEqual(t *testing.T, actual, expected interface{}) {
|
func AssertDeepEqual(t *testing.T, actual, expected interface{}) {
|
||||||
assert.True(t, cmp.Equal(actual, expected))
|
assert.True(t, cmp.Equal(actual, expected))
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ func ToGoIdentifier(databaseIdentifier string) string {
|
||||||
return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier))
|
return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToGoEnumValueIdentifier converts enum value name to Go identifier name.
|
||||||
func ToGoEnumValueIdentifier(enumName, enumValue string) string {
|
func ToGoEnumValueIdentifier(enumName, enumValue string) string {
|
||||||
enumValueIdentifier := ToGoIdentifier(enumValue)
|
enumValueIdentifier := ToGoIdentifier(enumValue)
|
||||||
if !unicode.IsLetter([]rune(enumValueIdentifier)[0]) {
|
if !unicode.IsLetter([]rune(enumValueIdentifier)[0]) {
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ func newDialect() jet.Dialect {
|
||||||
ArgumentPlaceholder: func(ord int) string {
|
ArgumentPlaceholder: func(ord int) string {
|
||||||
return "$" + strconv.Itoa(ord)
|
return "$" + strconv.Itoa(ord)
|
||||||
},
|
},
|
||||||
|
ReservedWords: reservedWords,
|
||||||
}
|
}
|
||||||
|
|
||||||
return jet.NewDialect(dialectParams)
|
return jet.NewDialect(dialectParams)
|
||||||
|
|
@ -105,3 +106,83 @@ func postgresNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.Serializer
|
||||||
jet.Serialize(expressions[1], statement, out, options...)
|
jet.Serialize(expressions[1], statement, out, options...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var reservedWords = []string{
|
||||||
|
"ALL",
|
||||||
|
"ANALYSE",
|
||||||
|
"ANALYZE",
|
||||||
|
"AND",
|
||||||
|
"ANY",
|
||||||
|
"ARRAY",
|
||||||
|
"AS",
|
||||||
|
"ASC",
|
||||||
|
"ASYMMETRIC",
|
||||||
|
"BOTH",
|
||||||
|
"CASE",
|
||||||
|
"CAST",
|
||||||
|
"CHECK",
|
||||||
|
"COLLATE",
|
||||||
|
"COLUMN",
|
||||||
|
"CONSTRAINT",
|
||||||
|
"CREATE",
|
||||||
|
"CURRENT_CATALOG",
|
||||||
|
"CURRENT_DATE",
|
||||||
|
"CURRENT_ROLE",
|
||||||
|
"CURRENT_TIME",
|
||||||
|
"CURRENT_TIMESTAMP",
|
||||||
|
"CURRENT_USER",
|
||||||
|
"DEFAULT",
|
||||||
|
"DEFERRABLE",
|
||||||
|
"DESC",
|
||||||
|
"DISTINCT",
|
||||||
|
"DO",
|
||||||
|
"ELSE",
|
||||||
|
"END",
|
||||||
|
"EXCEPT",
|
||||||
|
"FALSE",
|
||||||
|
"FETCH",
|
||||||
|
"FOR",
|
||||||
|
"FOREIGN",
|
||||||
|
"FROM",
|
||||||
|
"GRANT",
|
||||||
|
"GROUP",
|
||||||
|
"HAVING",
|
||||||
|
"IN",
|
||||||
|
"INITIALLY",
|
||||||
|
"INTERSECT",
|
||||||
|
"INTO",
|
||||||
|
"LATERAL",
|
||||||
|
"LEADING",
|
||||||
|
"LIMIT",
|
||||||
|
"LOCALTIME",
|
||||||
|
"LOCALTIMESTAMP",
|
||||||
|
"NOT",
|
||||||
|
"NULL",
|
||||||
|
"OFFSET",
|
||||||
|
"ON",
|
||||||
|
"ONLY",
|
||||||
|
"OR",
|
||||||
|
"ORDER",
|
||||||
|
"PLACING",
|
||||||
|
"PRIMARY",
|
||||||
|
"REFERENCES",
|
||||||
|
"RETURNING",
|
||||||
|
"SELECT",
|
||||||
|
"SESSION_USER",
|
||||||
|
"SOME",
|
||||||
|
"SYMMETRIC",
|
||||||
|
"TABLE",
|
||||||
|
"THEN",
|
||||||
|
"TO",
|
||||||
|
"TRAILING",
|
||||||
|
"TRUE",
|
||||||
|
"UNION",
|
||||||
|
"UNIQUE",
|
||||||
|
"USER",
|
||||||
|
"USING",
|
||||||
|
"VARIADIC",
|
||||||
|
"WHEN",
|
||||||
|
"WHERE",
|
||||||
|
"WINDOW",
|
||||||
|
"WITH",
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -74,3 +74,21 @@ func TestNOT_IN(t *testing.T) {
|
||||||
FROM db.table2
|
FROM db.table2
|
||||||
)))`, int64(12))
|
)))`, int64(12))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReservedWordEscaped(t *testing.T) {
|
||||||
|
var table1ColUser = IntervalColumn("user")
|
||||||
|
var table1ColVariadic = IntervalColumn("VARIADIC")
|
||||||
|
var table1ColProcedure = IntervalColumn("procedure")
|
||||||
|
|
||||||
|
_ = NewTable(
|
||||||
|
"db",
|
||||||
|
"table1",
|
||||||
|
table1ColUser,
|
||||||
|
table1ColVariadic,
|
||||||
|
table1ColProcedure,
|
||||||
|
)
|
||||||
|
|
||||||
|
assertSerialize(t, table1ColUser, `table1."user"`)
|
||||||
|
assertSerialize(t, table1ColVariadic, `table1."VARIADIC"`)
|
||||||
|
assertSerialize(t, table1ColProcedure, `table1.procedure`)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1095,3 +1095,53 @@ var allTypesJson = `
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
`
|
`
|
||||||
|
|
||||||
|
func TestReservedWord(t *testing.T) {
|
||||||
|
stmt := SELECT(User.AllColumns).
|
||||||
|
FROM(User)
|
||||||
|
|
||||||
|
// NOTE: A word that follows a period in a qualified name must be an identifier, so it
|
||||||
|
// need not be quoted even if it is reserved
|
||||||
|
testutils.AssertDebugStatementSql(t, stmt, `
|
||||||
|
SELECT user.column AS "user.column",
|
||||||
|
user.use AS "user.use",
|
||||||
|
user.ceil AS "user.ceil",
|
||||||
|
user.commit AS "user.commit",
|
||||||
|
user.create AS "user.create",
|
||||||
|
user.default AS "user.default",
|
||||||
|
user.desc AS "user.desc",
|
||||||
|
user.empty AS "user.empty",
|
||||||
|
user.float AS "user.float",
|
||||||
|
user.join AS "user.join",
|
||||||
|
user.like AS "user.like",
|
||||||
|
user.max AS "user.max",
|
||||||
|
user.rank AS "user.rank"
|
||||||
|
FROM test_sample.user;
|
||||||
|
`)
|
||||||
|
|
||||||
|
var dest []model.User
|
||||||
|
err := stmt.Query(db, &dest)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
testutils.PrintJson(dest)
|
||||||
|
|
||||||
|
testutils.AssertJSON(t, dest, `
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"Column": "Column",
|
||||||
|
"Use": "CHECK",
|
||||||
|
"Ceil": "CEIL",
|
||||||
|
"Commit": "COMMIT",
|
||||||
|
"Create": "CREATE",
|
||||||
|
"Default": "DEFAULT",
|
||||||
|
"Desc": "DESC",
|
||||||
|
"Empty": "EMPTY",
|
||||||
|
"Float": "FLOAT",
|
||||||
|
"Join": "JOIN",
|
||||||
|
"Like": "LIKE",
|
||||||
|
"Max": "MAX",
|
||||||
|
"Rank": "RANK"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -291,7 +291,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go",
|
testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go",
|
||||||
"mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go")
|
"mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go")
|
||||||
|
|
||||||
testutils.AssertFileContent(t, modelDir+"all_types.go", "\npackage model", allTypesModelContent)
|
testutils.AssertFileContent(t, modelDir+"all_types.go", "\npackage model", allTypesModelContent)
|
||||||
|
|
||||||
|
|
@ -299,7 +299,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go",
|
testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go",
|
||||||
"person.go", "person_phone.go", "weird_names_table.go")
|
"person.go", "person_phone.go", "weird_names_table.go", "user.go")
|
||||||
|
|
||||||
testutils.AssertFileContent(t, tableDir+"all_types.go", "\npackage table", allTypesTableContent)
|
testutils.AssertFileContent(t, tableDir+"all_types.go", "\npackage table", allTypesTableContent)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -328,3 +328,54 @@ FROM test_sample."WEIRD NAMES TABLE";
|
||||||
WeirdColuName16: "Doe",
|
WeirdColuName16: "Doe",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReserwedWordEscape(t *testing.T) {
|
||||||
|
stmt := SELECT(User.AllColumns).
|
||||||
|
FROM(User)
|
||||||
|
|
||||||
|
//fmt.Println(stmt.DebugSql())
|
||||||
|
|
||||||
|
testutils.AssertDebugStatementSql(t, stmt, `
|
||||||
|
SELECT "User"."column" AS "User.column",
|
||||||
|
"User"."check" AS "User.check",
|
||||||
|
"User".ceil AS "User.ceil",
|
||||||
|
"User".commit AS "User.commit",
|
||||||
|
"User"."create" AS "User.create",
|
||||||
|
"User"."default" AS "User.default",
|
||||||
|
"User"."desc" AS "User.desc",
|
||||||
|
"User".empty AS "User.empty",
|
||||||
|
"User".float AS "User.float",
|
||||||
|
"User".join AS "User.join",
|
||||||
|
"User".like AS "User.like",
|
||||||
|
"User".max AS "User.max",
|
||||||
|
"User".rank AS "User.rank"
|
||||||
|
FROM test_sample."User";
|
||||||
|
`)
|
||||||
|
|
||||||
|
var dest []model.User
|
||||||
|
|
||||||
|
err := stmt.Query(db, &dest)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
testutils.PrintJson(dest)
|
||||||
|
|
||||||
|
testutils.AssertJSON(t, dest, `
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"Column": "Column",
|
||||||
|
"Check": "CHECK",
|
||||||
|
"Ceil": "CEIL",
|
||||||
|
"Commit": "COMMIT",
|
||||||
|
"Create": "CREATE",
|
||||||
|
"Default": "DEFAULT",
|
||||||
|
"Desc": "DESC",
|
||||||
|
"Empty": "EMPTY",
|
||||||
|
"Float": "FLOAT",
|
||||||
|
"Join": "JOIN",
|
||||||
|
"Like": "LIKE",
|
||||||
|
"Max": "MAX",
|
||||||
|
"Rank": "RANK"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
Subproject commit 02e0795d1e06b959d0c564dc1e349159d57b1bf6
|
Subproject commit 889e07c0ebaf6b4021e31cce29b5861eb5c8cc17
|
||||||
Loading…
Add table
Add a link
Reference in a new issue