[bug] Escape reserved words used as identifier.

This commit is contained in:
go-jet 2020-02-16 10:25:21 +01:00
parent 63c1fd6430
commit 3019fdbbb2
10 changed files with 226 additions and 4 deletions

View file

@ -1,5 +1,7 @@
package jet package jet
import "strings"
// Dialect interface // Dialect interface
type Dialect interface { type Dialect interface {
Name() string Name() string
@ -9,6 +11,7 @@ type Dialect interface {
AliasQuoteChar() byte AliasQuoteChar() byte
IdentifierQuoteChar() byte IdentifierQuoteChar() byte
ArgumentPlaceholder() QueryPlaceholderFunc ArgumentPlaceholder() QueryPlaceholderFunc
IsReservedWord(name string) bool
} }
// SerializerFunc func // SerializerFunc func
@ -29,6 +32,7 @@ type DialectParams struct {
AliasQuoteChar byte AliasQuoteChar byte
IdentifierQuoteChar byte IdentifierQuoteChar byte
ArgumentPlaceholder QueryPlaceholderFunc ArgumentPlaceholder QueryPlaceholderFunc
ReservedWords []string
} }
// NewDialect creates new dialect with params // NewDialect creates new dialect with params
@ -41,6 +45,7 @@ func NewDialect(params DialectParams) Dialect {
aliasQuoteChar: params.AliasQuoteChar, aliasQuoteChar: params.AliasQuoteChar,
identifierQuoteChar: params.IdentifierQuoteChar, identifierQuoteChar: params.IdentifierQuoteChar,
argumentPlaceholder: params.ArgumentPlaceholder, argumentPlaceholder: params.ArgumentPlaceholder,
reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords),
} }
} }
@ -52,6 +57,7 @@ type dialectImpl struct {
aliasQuoteChar byte aliasQuoteChar byte
identifierQuoteChar byte identifierQuoteChar byte
argumentPlaceholder QueryPlaceholderFunc argumentPlaceholder QueryPlaceholderFunc
reservedWords map[string]bool
supportsReturning bool supportsReturning bool
} }
@ -89,3 +95,17 @@ func (d *dialectImpl) IdentifierQuoteChar() byte {
func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc { func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc {
return d.argumentPlaceholder return d.argumentPlaceholder
} }
func (d *dialectImpl) IsReservedWord(name string) bool {
_, isReservedWord := d.reservedWords[strings.ToLower(name)]
return isReservedWord
}
func arrayOfStringsToMapOfStrings(arr []string) map[string]bool {
ret := map[string]bool{}
for _, elem := range arr {
ret[strings.ToLower(elem)] = true
}
return ret
}

View file

@ -98,7 +98,7 @@ func (s *SQLBuilder) WriteString(str string) {
// WriteIdentifier adds identifier to output SQL // WriteIdentifier adds identifier to output SQL
func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) { func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) {
if shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 { if s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 {
identQuoteChar := string(s.Dialect.IdentifierQuoteChar()) identQuoteChar := string(s.Dialect.IdentifierQuoteChar())
s.WriteString(identQuoteChar + name + identQuoteChar) s.WriteString(identQuoteChar + name + identQuoteChar)
} else { } else {

View file

@ -211,6 +211,7 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st
} }
} }
// AssertDeepEqual checks if actual and expected objects are deeply equal.
func AssertDeepEqual(t *testing.T, actual, expected interface{}) { func AssertDeepEqual(t *testing.T, actual, expected interface{}) {
assert.True(t, cmp.Equal(actual, expected)) assert.True(t, cmp.Equal(actual, expected))
} }

View file

@ -18,6 +18,7 @@ func ToGoIdentifier(databaseIdentifier string) string {
return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier)) return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier))
} }
// ToGoEnumValueIdentifier converts enum value name to Go identifier name.
func ToGoEnumValueIdentifier(enumName, enumValue string) string { func ToGoEnumValueIdentifier(enumName, enumValue string) string {
enumValueIdentifier := ToGoIdentifier(enumValue) enumValueIdentifier := ToGoIdentifier(enumValue)
if !unicode.IsLetter([]rune(enumValueIdentifier)[0]) { if !unicode.IsLetter([]rune(enumValueIdentifier)[0]) {

View file

@ -24,6 +24,7 @@ func newDialect() jet.Dialect {
ArgumentPlaceholder: func(ord int) string { ArgumentPlaceholder: func(ord int) string {
return "$" + strconv.Itoa(ord) return "$" + strconv.Itoa(ord)
}, },
ReservedWords: reservedWords,
} }
return jet.NewDialect(dialectParams) return jet.NewDialect(dialectParams)
@ -105,3 +106,83 @@ func postgresNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.Serializer
jet.Serialize(expressions[1], statement, out, options...) jet.Serialize(expressions[1], statement, out, options...)
} }
} }
var reservedWords = []string{
"ALL",
"ANALYSE",
"ANALYZE",
"AND",
"ANY",
"ARRAY",
"AS",
"ASC",
"ASYMMETRIC",
"BOTH",
"CASE",
"CAST",
"CHECK",
"COLLATE",
"COLUMN",
"CONSTRAINT",
"CREATE",
"CURRENT_CATALOG",
"CURRENT_DATE",
"CURRENT_ROLE",
"CURRENT_TIME",
"CURRENT_TIMESTAMP",
"CURRENT_USER",
"DEFAULT",
"DEFERRABLE",
"DESC",
"DISTINCT",
"DO",
"ELSE",
"END",
"EXCEPT",
"FALSE",
"FETCH",
"FOR",
"FOREIGN",
"FROM",
"GRANT",
"GROUP",
"HAVING",
"IN",
"INITIALLY",
"INTERSECT",
"INTO",
"LATERAL",
"LEADING",
"LIMIT",
"LOCALTIME",
"LOCALTIMESTAMP",
"NOT",
"NULL",
"OFFSET",
"ON",
"ONLY",
"OR",
"ORDER",
"PLACING",
"PRIMARY",
"REFERENCES",
"RETURNING",
"SELECT",
"SESSION_USER",
"SOME",
"SYMMETRIC",
"TABLE",
"THEN",
"TO",
"TRAILING",
"TRUE",
"UNION",
"UNIQUE",
"USER",
"USING",
"VARIADIC",
"WHEN",
"WHERE",
"WINDOW",
"WITH",
}

View file

@ -74,3 +74,21 @@ func TestNOT_IN(t *testing.T) {
FROM db.table2 FROM db.table2
)))`, int64(12)) )))`, int64(12))
} }
func TestReservedWordEscaped(t *testing.T) {
var table1ColUser = IntervalColumn("user")
var table1ColVariadic = IntervalColumn("VARIADIC")
var table1ColProcedure = IntervalColumn("procedure")
_ = NewTable(
"db",
"table1",
table1ColUser,
table1ColVariadic,
table1ColProcedure,
)
assertSerialize(t, table1ColUser, `table1."user"`)
assertSerialize(t, table1ColVariadic, `table1."VARIADIC"`)
assertSerialize(t, table1ColProcedure, `table1.procedure`)
}

View file

@ -1095,3 +1095,53 @@ var allTypesJson = `
} }
] ]
` `
func TestReservedWord(t *testing.T) {
stmt := SELECT(User.AllColumns).
FROM(User)
// NOTE: A word that follows a period in a qualified name must be an identifier, so it
// need not be quoted even if it is reserved
testutils.AssertDebugStatementSql(t, stmt, `
SELECT user.column AS "user.column",
user.use AS "user.use",
user.ceil AS "user.ceil",
user.commit AS "user.commit",
user.create AS "user.create",
user.default AS "user.default",
user.desc AS "user.desc",
user.empty AS "user.empty",
user.float AS "user.float",
user.join AS "user.join",
user.like AS "user.like",
user.max AS "user.max",
user.rank AS "user.rank"
FROM test_sample.user;
`)
var dest []model.User
err := stmt.Query(db, &dest)
assert.NoError(t, err)
testutils.PrintJson(dest)
testutils.AssertJSON(t, dest, `
[
{
"Column": "Column",
"Use": "CHECK",
"Ceil": "CEIL",
"Commit": "COMMIT",
"Create": "CREATE",
"Default": "DEFAULT",
"Desc": "DESC",
"Empty": "EMPTY",
"Float": "FLOAT",
"Join": "JOIN",
"Like": "LIKE",
"Max": "MAX",
"Rank": "RANK"
}
]
`)
}

View file

@ -291,7 +291,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go", testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go",
"mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go") "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go")
testutils.AssertFileContent(t, modelDir+"all_types.go", "\npackage model", allTypesModelContent) testutils.AssertFileContent(t, modelDir+"all_types.go", "\npackage model", allTypesModelContent)
@ -299,7 +299,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go", testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go",
"person.go", "person_phone.go", "weird_names_table.go") "person.go", "person_phone.go", "weird_names_table.go", "user.go")
testutils.AssertFileContent(t, tableDir+"all_types.go", "\npackage table", allTypesTableContent) testutils.AssertFileContent(t, tableDir+"all_types.go", "\npackage table", allTypesTableContent)
} }

View file

@ -328,3 +328,54 @@ FROM test_sample."WEIRD NAMES TABLE";
WeirdColuName16: "Doe", WeirdColuName16: "Doe",
}) })
} }
func TestReserwedWordEscape(t *testing.T) {
stmt := SELECT(User.AllColumns).
FROM(User)
//fmt.Println(stmt.DebugSql())
testutils.AssertDebugStatementSql(t, stmt, `
SELECT "User"."column" AS "User.column",
"User"."check" AS "User.check",
"User".ceil AS "User.ceil",
"User".commit AS "User.commit",
"User"."create" AS "User.create",
"User"."default" AS "User.default",
"User"."desc" AS "User.desc",
"User".empty AS "User.empty",
"User".float AS "User.float",
"User".join AS "User.join",
"User".like AS "User.like",
"User".max AS "User.max",
"User".rank AS "User.rank"
FROM test_sample."User";
`)
var dest []model.User
err := stmt.Query(db, &dest)
assert.NoError(t, err)
testutils.PrintJson(dest)
testutils.AssertJSON(t, dest, `
[
{
"Column": "Column",
"Check": "CHECK",
"Ceil": "CEIL",
"Commit": "COMMIT",
"Create": "CREATE",
"Default": "DEFAULT",
"Desc": "DESC",
"Empty": "EMPTY",
"Float": "FLOAT",
"Join": "JOIN",
"Like": "LIKE",
"Max": "MAX",
"Rank": "RANK"
}
]
`)
}

@ -1 +1 @@
Subproject commit 02e0795d1e06b959d0c564dc1e349159d57b1bf6 Subproject commit 889e07c0ebaf6b4021e31cce29b5861eb5c8cc17