diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index acf03d9..434e3b8 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -1,5 +1,7 @@ package jet +import "strings" + // Dialect interface type Dialect interface { Name() string @@ -9,6 +11,7 @@ type Dialect interface { AliasQuoteChar() byte IdentifierQuoteChar() byte ArgumentPlaceholder() QueryPlaceholderFunc + IsReservedWord(name string) bool } // SerializerFunc func @@ -29,6 +32,7 @@ type DialectParams struct { AliasQuoteChar byte IdentifierQuoteChar byte ArgumentPlaceholder QueryPlaceholderFunc + ReservedWords []string } // NewDialect creates new dialect with params @@ -41,6 +45,7 @@ func NewDialect(params DialectParams) Dialect { aliasQuoteChar: params.AliasQuoteChar, identifierQuoteChar: params.IdentifierQuoteChar, argumentPlaceholder: params.ArgumentPlaceholder, + reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords), } } @@ -52,6 +57,7 @@ type dialectImpl struct { aliasQuoteChar byte identifierQuoteChar byte argumentPlaceholder QueryPlaceholderFunc + reservedWords map[string]bool supportsReturning bool } @@ -89,3 +95,17 @@ func (d *dialectImpl) IdentifierQuoteChar() byte { func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc { return d.argumentPlaceholder } + +func (d *dialectImpl) IsReservedWord(name string) bool { + _, isReservedWord := d.reservedWords[strings.ToLower(name)] + return isReservedWord +} + +func arrayOfStringsToMapOfStrings(arr []string) map[string]bool { + ret := map[string]bool{} + for _, elem := range arr { + ret[strings.ToLower(elem)] = true + } + + return ret +} diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index f71ca1e..ef7f801 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -98,7 +98,7 @@ func (s *SQLBuilder) WriteString(str string) { // WriteIdentifier adds identifier to output SQL func (s *SQLBuilder) WriteIdentifier(name string, alwaysQuote ...bool) { - if shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 { + if s.Dialect.IsReservedWord(name) || shouldQuoteIdentifier(name) || len(alwaysQuote) > 0 { identQuoteChar := string(s.Dialect.IdentifierQuoteChar()) s.WriteString(identQuoteChar + name + identQuoteChar) } else { diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 5719e75..f0c18e9 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -211,6 +211,7 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st } } +// AssertDeepEqual checks if actual and expected objects are deeply equal. func AssertDeepEqual(t *testing.T, actual, expected interface{}) { assert.True(t, cmp.Equal(actual, expected)) } diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 5301605..e346775 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -18,6 +18,7 @@ func ToGoIdentifier(databaseIdentifier string) string { return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier)) } +// ToGoEnumValueIdentifier converts enum value name to Go identifier name. func ToGoEnumValueIdentifier(enumName, enumValue string) string { enumValueIdentifier := ToGoIdentifier(enumValue) if !unicode.IsLetter([]rune(enumValueIdentifier)[0]) { diff --git a/postgres/dialect.go b/postgres/dialect.go index c1e8c0b..b440c5d 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -24,6 +24,7 @@ func newDialect() jet.Dialect { ArgumentPlaceholder: func(ord int) string { return "$" + strconv.Itoa(ord) }, + ReservedWords: reservedWords, } return jet.NewDialect(dialectParams) @@ -105,3 +106,83 @@ func postgresNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.Serializer jet.Serialize(expressions[1], statement, out, options...) } } + +var reservedWords = []string{ + "ALL", + "ANALYSE", + "ANALYZE", + "AND", + "ANY", + "ARRAY", + "AS", + "ASC", + "ASYMMETRIC", + "BOTH", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "CONSTRAINT", + "CREATE", + "CURRENT_CATALOG", + "CURRENT_DATE", + "CURRENT_ROLE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "DEFAULT", + "DEFERRABLE", + "DESC", + "DISTINCT", + "DO", + "ELSE", + "END", + "EXCEPT", + "FALSE", + "FETCH", + "FOR", + "FOREIGN", + "FROM", + "GRANT", + "GROUP", + "HAVING", + "IN", + "INITIALLY", + "INTERSECT", + "INTO", + "LATERAL", + "LEADING", + "LIMIT", + "LOCALTIME", + "LOCALTIMESTAMP", + "NOT", + "NULL", + "OFFSET", + "ON", + "ONLY", + "OR", + "ORDER", + "PLACING", + "PRIMARY", + "REFERENCES", + "RETURNING", + "SELECT", + "SESSION_USER", + "SOME", + "SYMMETRIC", + "TABLE", + "THEN", + "TO", + "TRAILING", + "TRUE", + "UNION", + "UNIQUE", + "USER", + "USING", + "VARIADIC", + "WHEN", + "WHERE", + "WINDOW", + "WITH", +} diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index f53587e..7bf9242 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -74,3 +74,21 @@ func TestNOT_IN(t *testing.T) { FROM db.table2 )))`, int64(12)) } + +func TestReservedWordEscaped(t *testing.T) { + var table1ColUser = IntervalColumn("user") + var table1ColVariadic = IntervalColumn("VARIADIC") + var table1ColProcedure = IntervalColumn("procedure") + + _ = NewTable( + "db", + "table1", + table1ColUser, + table1ColVariadic, + table1ColProcedure, + ) + + assertSerialize(t, table1ColUser, `table1."user"`) + assertSerialize(t, table1ColVariadic, `table1."VARIADIC"`) + assertSerialize(t, table1ColProcedure, `table1.procedure`) +} diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 98a1a4a..2c6768e 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -1095,3 +1095,53 @@ var allTypesJson = ` } ] ` + +func TestReservedWord(t *testing.T) { + stmt := SELECT(User.AllColumns). + FROM(User) + + // NOTE: A word that follows a period in a qualified name must be an identifier, so it + // need not be quoted even if it is reserved + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT user.column AS "user.column", + user.use AS "user.use", + user.ceil AS "user.ceil", + user.commit AS "user.commit", + user.create AS "user.create", + user.default AS "user.default", + user.desc AS "user.desc", + user.empty AS "user.empty", + user.float AS "user.float", + user.join AS "user.join", + user.like AS "user.like", + user.max AS "user.max", + user.rank AS "user.rank" +FROM test_sample.user; +`) + + var dest []model.User + err := stmt.Query(db, &dest) + assert.NoError(t, err) + + testutils.PrintJson(dest) + + testutils.AssertJSON(t, dest, ` +[ + { + "Column": "Column", + "Use": "CHECK", + "Ceil": "CEIL", + "Commit": "COMMIT", + "Create": "CREATE", + "Default": "DEFAULT", + "Desc": "DESC", + "Empty": "EMPTY", + "Float": "FLOAT", + "Join": "JOIN", + "Like": "LIKE", + "Max": "MAX", + "Rank": "RANK" + } +] +`) +} diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 52ddfac..54d7cc4 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -291,7 +291,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { assert.NoError(t, err) testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go", - "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go") + "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go") testutils.AssertFileContent(t, modelDir+"all_types.go", "\npackage model", allTypesModelContent) @@ -299,7 +299,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { assert.NoError(t, err) testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go", - "person.go", "person_phone.go", "weird_names_table.go") + "person.go", "person_phone.go", "weird_names_table.go", "user.go") testutils.AssertFileContent(t, tableDir+"all_types.go", "\npackage table", allTypesTableContent) } diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 5989d17..2cfbc85 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -328,3 +328,54 @@ FROM test_sample."WEIRD NAMES TABLE"; WeirdColuName16: "Doe", }) } + +func TestReserwedWordEscape(t *testing.T) { + stmt := SELECT(User.AllColumns). + FROM(User) + + //fmt.Println(stmt.DebugSql()) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT "User"."column" AS "User.column", + "User"."check" AS "User.check", + "User".ceil AS "User.ceil", + "User".commit AS "User.commit", + "User"."create" AS "User.create", + "User"."default" AS "User.default", + "User"."desc" AS "User.desc", + "User".empty AS "User.empty", + "User".float AS "User.float", + "User".join AS "User.join", + "User".like AS "User.like", + "User".max AS "User.max", + "User".rank AS "User.rank" +FROM test_sample."User"; +`) + + var dest []model.User + + err := stmt.Query(db, &dest) + assert.NoError(t, err) + + testutils.PrintJson(dest) + + testutils.AssertJSON(t, dest, ` +[ + { + "Column": "Column", + "Check": "CHECK", + "Ceil": "CEIL", + "Commit": "COMMIT", + "Create": "CREATE", + "Default": "DEFAULT", + "Desc": "DESC", + "Empty": "EMPTY", + "Float": "FLOAT", + "Join": "JOIN", + "Like": "LIKE", + "Max": "MAX", + "Rank": "RANK" + } +] +`) +} diff --git a/tests/testdata b/tests/testdata index 02e0795..889e07c 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 02e0795d1e06b959d0c564dc1e349159d57b1bf6 +Subproject commit 889e07c0ebaf6b4021e31cce29b5861eb5c8cc17