diff --git a/.circleci/config.yml b/.circleci/config.yml index 06d7df9..c040498 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -28,7 +28,7 @@ jobs: command: | go get github.com/google/uuid go get github.com/lib/pq - go get github.com/serenize/snaker + go get github.com/pkg/profile go get gotest.tools/assert go get github.com/davecgh/go-spew/spew diff --git a/clause.go b/clause.go index ec6b014..d2a08e9 100644 --- a/clause.go +++ b/clause.go @@ -171,7 +171,7 @@ func (q *queryData) writeString(str string) { } func (q *queryData) writeIdentifier(name string) { - quoteWrap := name != strings.ToLower(name) || strings.Contains(name, ".") + quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -") if quoteWrap { q.writeString(`"` + name + `"`) diff --git a/execution/execution.go b/execution/execution.go index 587d489..d836770 100644 --- a/execution/execution.go +++ b/execution/execution.go @@ -7,7 +7,7 @@ import ( "errors" "fmt" "github.com/go-jet/jet/execution/internal" - "github.com/serenize/snaker" + "github.com/go-jet/jet/internal/util" "reflect" "strconv" "strings" @@ -139,10 +139,7 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl if isGoBaseType(sliceElemType) { index := 0 if structField != nil { - tableName, columnName := getRefAlias(structField) - index = scanContext.columnIndex(tableName, columnName) - - if index < 0 { + if index = scanContext.aliasColumnIndex(structField.Tag.Get("alias")); index < 0 { return } } @@ -293,28 +290,24 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re return } -func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, structField *reflect.StructField, onlySlices ...bool) (updated bool, err error) { +func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) { structType := structPtrValue.Type().Elem() structValue := structPtrValue.Elem() - tableName, _ := getRefAlias(structField) - - if tableName == "" { - tableName = structType.Name() - } + typeName := getTypeName(structType, parentField) for i := 0; i < structType.NumField(); i++ { field := structType.Field(i) fieldValue := structValue.Field(i) - columnName := field.Name + fieldName := field.Name if scannerValue, ok := implementsScanner(fieldValue); ok { if len(onlySlices) > 0 { continue } - cellValue := scanContext.getCellValue(tableName, columnName) + cellValue := scanContext.getCellValue(typeName, fieldName) if cellValue == nil { continue @@ -336,7 +329,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re continue } - cellValue := scanContext.getCellValue(tableName, columnName) + cellValue := scanContext.getCellValue(typeName, fieldName) if cellValue != nil { updated = true @@ -365,26 +358,20 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re return } -func getRefAlias(structField *reflect.StructField) (table, column string) { - if structField == nil { - return +func getTypeName(structType reflect.Type, parentField *reflect.StructField) string { + if parentField == nil { + return structType.Name() } - aliasTag := structField.Tag.Get("alias") + aliasTag := parentField.Tag.Get("alias") if aliasTag == "" { - return + return structType.Name() } aliasParts := strings.Split(aliasTag, ".") - table = aliasParts[0] - - if len(aliasParts) > 1 { - column = aliasParts[1] - } - - return + return aliasParts[0] } func initializeValueIfNilPtr(value reflect.Value) { @@ -533,12 +520,14 @@ type scanContext struct { row []interface{} uniqueDestObjectsMap map[string]int - columnNameIndexMap map[string]int - groupKeyInfoCache map[string]groupKeyInfo + aliasIndexMap map[string]int + goNameMap map[string]int + + groupKeyInfoCache map[string]groupKeyInfo } func newScanContext(rows *sql.Rows) (*scanContext, error) { - columnNames, err := rows.Columns() + aliases, err := rows.Columns() if err != nil { return nil, err @@ -550,10 +539,24 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) { return nil, err } - columnNameIndexMap := map[string]int{} + aliasIndexMap := map[string]int{} - for i, columnName := range columnNames { - columnNameIndexMap[strings.ToLower(columnName)] = i + for i, columnName := range aliases { + aliasIndexMap[strings.ToLower(columnName)] = i + } + + goNamesMap := map[string]int{} + + for i, alias := range aliases { + names := strings.SplitN(alias, ".", 2) + + goName := util.ToGoIdentifier(names[0]) + + if len(names) > 1 { + goName += "." + util.ToGoIdentifier(names[1]) + } + + goNamesMap[strings.ToLower(goName)] = i } return &scanContext{ @@ -561,8 +564,8 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) { uniqueDestObjectsMap: make(map[string]int), groupKeyInfoCache: make(map[string]groupKeyInfo), - - columnNameIndexMap: columnNameIndexMap, + aliasIndexMap: aliasIndexMap, + goNameMap: goNamesMap, }, nil } @@ -607,12 +610,8 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { return "{" + groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")}" } -func (s *scanContext) getGroupKeyInfo(structType reflect.Type, structField *reflect.StructField) groupKeyInfo { - tableName, _ := getRefAlias(structField) - - if tableName == "" { - tableName = structType.Name() - } +func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo { + typeName := getTypeName(structType, parentField) ret := groupKeyInfo{typeName: structType.Name()} @@ -635,7 +634,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, structField *refl ret.subTypes = append(ret.subTypes, subType) } } else if isPrimaryKey(field) { - index := s.columnIndex(tableName, field.Name) + index := s.typeColumnIndex(typeName, field.Name) if index < 0 { continue @@ -654,47 +653,36 @@ type groupKeyInfo struct { subTypes []groupKeyInfo } -func (s *scanContext) columnIndex(tableName, columnName string) int { - if tableName == "" { - name := strings.ToLower(columnName) - if i, ok := s.columnNameIndexMap[name]; ok { - return i - } +func (s *scanContext) aliasColumnIndex(alias string) int { + index, ok := s.aliasIndexMap[alias] - name = strings.ToLower(snaker.CamelToSnake(columnName)) - if i, ok := s.columnNameIndexMap[name]; ok { - return i - } - } else { - name := strings.ToLower(tableName + "." + columnName) - if i, ok := s.columnNameIndexMap[name]; ok { - return i - } - - snakedTableName := snaker.CamelToSnake(tableName) - snakedColumnName := snaker.CamelToSnake(columnName) - - name = strings.ToLower(snakedTableName + "." + snakedColumnName) - if i, ok := s.columnNameIndexMap[name]; ok { - return i - } - - name = strings.ToLower(tableName + "." + snakedColumnName) - if i, ok := s.columnNameIndexMap[name]; ok { - return i - } - - name = strings.ToLower(snakedTableName + "." + columnName) - if i, ok := s.columnNameIndexMap[name]; ok { - return i - } + if !ok { + return -1 } - return -1 + return index } -func (s *scanContext) getCellValue(tableName, fieldName string) interface{} { - index := s.columnIndex(tableName, fieldName) +func (s *scanContext) typeColumnIndex(typeName, fieldName string) int { + var key string + + if typeName != "" { + key = strings.ToLower(typeName + "." + fieldName) + } else { + key = strings.ToLower(fieldName) + } + + index, ok := s.goNameMap[key] + + if !ok { + return -1 + } + + return index +} + +func (s *scanContext) getCellValue(typeName, fieldName string) interface{} { + index := s.typeColumnIndex(typeName, fieldName) if index < 0 { return nil diff --git a/generator/internal/metadata/postgres-metadata/column_info.go b/generator/internal/metadata/postgres-metadata/column_info.go index 2e825a8..20e168d 100644 --- a/generator/internal/metadata/postgres-metadata/column_info.go +++ b/generator/internal/metadata/postgres-metadata/column_info.go @@ -3,7 +3,7 @@ package postgres_metadata import ( "database/sql" "fmt" - "github.com/serenize/snaker" + "github.com/go-jet/jet/internal/util" "strings" ) @@ -44,7 +44,7 @@ func (c ColumnInfo) SqlBuilderColumnType() string { func (c ColumnInfo) GoBaseType() string { switch c.DataType { case "USER-DEFINED": - return snaker.SnakeToCamel(c.EnumName) + return util.ToGoIdentifier(c.EnumName) case "boolean": return "bool" case "smallint": diff --git a/generator/internal/metadata/postgres-metadata/table_info.go b/generator/internal/metadata/postgres-metadata/table_info.go index 43e8265..3b7c67e 100644 --- a/generator/internal/metadata/postgres-metadata/table_info.go +++ b/generator/internal/metadata/postgres-metadata/table_info.go @@ -2,7 +2,7 @@ package postgres_metadata import ( "database/sql" - "github.com/serenize/snaker" + "github.com/go-jet/jet/internal/util" ) type TableInfo struct { @@ -58,7 +58,7 @@ func (t TableInfo) GetImports() []string { } func (t TableInfo) GoStructName() string { - return snaker.SnakeToCamel(t.name) + "Table" + return util.ToGoIdentifier(t.name) + "Table" } func GetTableInfo(db *sql.DB, dbName, schemaName, tableName string) (tableInfo TableInfo, err error) { diff --git a/generator/internal/utils/utils.go b/generator/internal/utils/utils.go index 343f4d5..f32147c 100644 --- a/generator/internal/utils/utils.go +++ b/generator/internal/utils/utils.go @@ -2,11 +2,10 @@ package utils import ( "bytes" - "github.com/serenize/snaker" + "github.com/go-jet/jet/internal/util" "go/format" "os" "path/filepath" - "strings" "text/template" "time" ) @@ -51,9 +50,7 @@ func EnsureDirPath(dirPath string) error { func GenerateTemplate(templateText string, templateData interface{}) ([]byte, error) { t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{ - "camelize": func(txt string) string { - return snaker.SnakeToCamel(strings.Replace(txt, "-", "_", -1)) - }, + "ToGoIdentifier": util.ToGoIdentifier, "now": func() string { return time.Now().Format(time.RFC850) }, diff --git a/generator/postgresgen/generator.go b/generator/postgresgen/generator.go index 7d48866..dee6026 100644 --- a/generator/postgresgen/generator.go +++ b/generator/postgresgen/generator.go @@ -6,6 +6,7 @@ import ( "github.com/go-jet/jet/generator/internal/metadata" "github.com/go-jet/jet/generator/internal/metadata/postgres-metadata" "github.com/go-jet/jet/generator/internal/utils" + "github.com/go-jet/jet/internal/util" _ "github.com/lib/pq" "path" "path/filepath" @@ -115,7 +116,7 @@ func generate(schemaInfo postgres_metadata.SchemaInfo, dirPath, packageName stri return err } - err = utils.SaveGoFile(modelDirPath, metaData.Name(), append(autoGenWarning, text...)) + err = utils.SaveGoFile(modelDirPath, util.ToGoFileName(metaData.Name()), append(autoGenWarning, text...)) if err != nil { return err diff --git a/generator/postgresgen/templates.go b/generator/postgresgen/templates.go index 267168a..3e89d59 100644 --- a/generator/postgresgen/templates.go +++ b/generator/postgresgen/templates.go @@ -2,13 +2,13 @@ package postgresgen var autoGenWarningTemplate = ` // -// Code generated by go-jet DO NOT EDIT. +// Code generated by jetgen DO NOT EDIT. // Generated at {{now}} // // WARNING: Changes to this file may cause incorrect behavior and will be lost // if the code is regenerated // -// Licence under ... +// Licence under github.com/go-jet/jet/LICENSE // ` @@ -16,7 +16,7 @@ var autoGenWarningTemplate = ` var sqlBuilderTableTemplate = ` {{define "column-list" -}} {{- range $i, $c := . }} - {{- if gt $i 0 }}, {{end}}{{camelize $c.Name}}Column + {{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column {{- end}} {{- end}} @@ -26,14 +26,14 @@ import ( "github.com/go-jet/jet" ) -var {{camelize .Name}} = new{{.GoStructName}}() +var {{ToGoIdentifier .Name}} = new{{.GoStructName}}() type {{.GoStructName}} struct { jet.Table //Columns {{- range .Columns}} - {{camelize .Name}} jet.Column{{.SqlBuilderColumnType}} + {{ToGoIdentifier .Name}} jet.Column{{.SqlBuilderColumnType}} {{- end}} AllColumns jet.ColumnList @@ -52,7 +52,7 @@ func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} { func new{{.GoStructName}}() *{{.GoStructName}} { var ( {{- range .Columns}} - {{camelize .Name}}Column = jet.{{.SqlBuilderColumnType}}Column("{{.Name}}") + {{ToGoIdentifier .Name}}Column = jet.{{.SqlBuilderColumnType}}Column("{{.Name}}") {{- end}} ) @@ -61,7 +61,7 @@ func new{{.GoStructName}}() *{{.GoStructName}} { //Columns {{- range .Columns}} - {{camelize .Name}}: {{camelize .Name}}Column, + {{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column, {{- end}} AllColumns: jet.ColumnList{ {{template "column-list" .Columns}} }, @@ -82,9 +82,9 @@ import ( {{end}} -type {{camelize .Name}} struct { +type {{ToGoIdentifier .Name}} struct { {{- range .Columns}} - {{camelize .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsPrimaryKey .Name)}}" + ` + {{ToGoIdentifier .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsPrimaryKey .Name)}}" + ` {{- end}} } ` @@ -93,32 +93,32 @@ var enumModelTemplate = `package model import "errors" -type {{camelize $.Name}} string +type {{ToGoIdentifier $.Name}} string const ( {{- range $index, $element := .Values}} - {{camelize $.Name}}_{{camelize $element}} {{camelize $.Name}} = "{{$element}}" + {{ToGoIdentifier $.Name}}_{{ToGoIdentifier $element}} {{ToGoIdentifier $.Name}} = "{{$element}}" {{- end}} ) -func (e *{{camelize $.Name}}) Scan(value interface{}) error { +func (e *{{ToGoIdentifier $.Name}}) Scan(value interface{}) error { if v, ok := value.(string); !ok { - return errors.New("Invalid data for {{camelize $.Name}} enum") + return errors.New("Invalid data for {{ToGoIdentifier $.Name}} enum") } else { switch string(v) { {{- range $index, $element := .Values}} case "{{$element}}": - *e = {{camelize $.Name}}_{{camelize $element}} + *e = {{ToGoIdentifier $.Name}}_{{ToGoIdentifier $element}} {{- end}} default: - return errors.New("Inavlid data " + string(v) + "for {{camelize $.Name}} enum") + return errors.New("Inavlid data " + string(v) + "for {{ToGoIdentifier $.Name}} enum") } return nil } } -func (e {{camelize $.Name}}) String() string { +func (e {{ToGoIdentifier $.Name}}) String() string { return string(e) } @@ -127,13 +127,13 @@ var enumTypeTemplate = `package enum import "github.com/go-jet/jet" -var {{camelize $.Name}} = &struct { +var {{ToGoIdentifier $.Name}} = &struct { {{- range $index, $element := .Values}} - {{camelize $element}} jet.StringExpression + {{ToGoIdentifier $element}} jet.StringExpression {{- end}} } { {{- range $index, $element := .Values}} - {{camelize $element}}: jet.NewEnumValue("{{$element}}"), + {{ToGoIdentifier $element}}: jet.NewEnumValue("{{$element}}"), {{- end}} } ` diff --git a/internal/3rdparty/snaker/snaker.go b/internal/3rdparty/snaker/snaker.go new file mode 100644 index 0000000..bb20fee --- /dev/null +++ b/internal/3rdparty/snaker/snaker.go @@ -0,0 +1,192 @@ +package snaker + +// Package snaker provides methods to convert CamelCase names to snake_case and back. +// It considers the list of allowed initialsms used by github.com/golang/lint/golint (e.g. ID or HTTP) + +import ( + "strings" + "unicode" +) + +// CamelToSnake converts a given string to snake case +func CamelToSnake(s string) string { + var result string + var words []string + var lastPos int + rs := []rune(s) + + for i := 0; i < len(rs); i++ { + if i > 0 && unicode.IsUpper(rs[i]) { + if initialism := startsWithInitialism(s[lastPos:]); initialism != "" { + words = append(words, initialism) + + i += len(initialism) - 1 + lastPos = i + continue + } + + words = append(words, s[lastPos:i]) + lastPos = i + } + } + + // append the last word + if s[lastPos:] != "" { + words = append(words, s[lastPos:]) + } + + for k, word := range words { + if k > 0 { + result += "_" + } + + result += strings.ToLower(word) + } + + return result +} + +func snakeToCamel(s string, upperCase bool) string { + if len(s) == 0 { + return s + } + var result string + + words := strings.Split(s, "_") + + //// if there is no underscore, first try commons and then just return + //if len(words) == 1 { + // if exception := snakeToCamelExceptions[words[0]]; len(exception) > 0 { + // return exception + // } + // + // if upperCase { + // if upper := strings.ToUpper(words[0]); commonInitialisms[upper] { + // return upper + // } + // } + // + // w := []rune(s) + // if upperCase { + // w[0] = unicode.ToUpper(w[0]) + // } else { + // w[0] = unicode.ToLower(w[0]) + // } + // + // return string(w) + //} + + for i, word := range words { + if exception := snakeToCamelExceptions[word]; len(exception) > 0 { + result += exception + continue + } + + if upperCase || i > 0 { + if upper := strings.ToUpper(word); commonInitialisms[upper] { + result += upper + continue + } + } + + if upperCase || i > 0 { + result += camelizeWord(word, len(words) > 1) + } else { + result += word + } + } + + return result +} + +func camelizeWord(word string, force bool) string { + runes := []rune(word) + + for i, r := range runes { + if i == 0 { + runes[i] = unicode.ToUpper(r) + } else { + if !force && unicode.IsLower(r) { // already camelCase + return string(runes) + } + + runes[i] = unicode.ToLower(r) + } + } + + return string(runes) +} + +// SnakeToCamel returns a string converted from snake case to uppercase +func SnakeToCamel(s string) string { + return snakeToCamel(s, true) +} + +// SnakeToCamelLower returns a string converted from snake case to lowercase +func SnakeToCamelLower(s string) string { + return snakeToCamel(s, false) +} + +// startsWithInitialism returns the initialism if the given string begins with it +func startsWithInitialism(s string) string { + var initialism string + // the longest initialism is 5 char, the shortest 2 + for i := 1; i <= 5; i++ { + if len(s) > i-1 && commonInitialisms[s[:i]] { + initialism = s[:i] + } + } + return initialism +} + +// commonInitialisms, taken from +// https://github.com/golang/lint/blob/206c0f020eba0f7fbcfbc467a5eb808037df2ed6/lint.go#L731 +var commonInitialisms = map[string]bool{ + "ACL": true, + "API": true, + "ASCII": true, + "CPU": true, + "CSS": true, + "DNS": true, + "EOF": true, + "ETA": true, + "GPU": true, + "GUID": true, + "HTML": true, + "HTTP": true, + "HTTPS": true, + "ID": true, + "IP": true, + "JSON": true, + "LHS": true, + "OS": true, + "QPS": true, + "RAM": true, + "RHS": true, + "RPC": true, + "SLA": true, + "SMTP": true, + "SQL": true, + "SSH": true, + "TCP": true, + "TLS": true, + "TTL": true, + "UDP": true, + "UI": true, + "UID": true, + "UUID": true, + "URI": true, + "URL": true, + "UTF8": true, + "VM": true, + "XML": true, + "XMPP": true, + "XSRF": true, + "XSS": true, + "OAuth": true, +} + +// add exceptions here for things that are not automatically convertable +var snakeToCamelExceptions = map[string]string{ + "oauth": "OAuth", +} diff --git a/internal/3rdparty/snaker/snaker_suite_test.go b/internal/3rdparty/snaker/snaker_suite_test.go new file mode 100644 index 0000000..da69837 --- /dev/null +++ b/internal/3rdparty/snaker/snaker_suite_test.go @@ -0,0 +1,13 @@ +package snaker + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "testing" +) + +func TestDb(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Snaker Suite") +} diff --git a/internal/3rdparty/snaker/snaker_test.go b/internal/3rdparty/snaker/snaker_test.go new file mode 100644 index 0000000..443409a --- /dev/null +++ b/internal/3rdparty/snaker/snaker_test.go @@ -0,0 +1,121 @@ +package snaker + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Snaker", func() { + Describe("CamelToSnake test", func() { + It("should return an empty string on an empty input", func() { + Expect(CamelToSnake("")).To(Equal("")) + }) + + It("should work with one word", func() { + Expect(CamelToSnake("One")).To(Equal("one")) + }) + + It("should return an uppercase string as seperate words", func() { + Expect(CamelToSnake("ONE")).To(Equal("o_n_e")) + }) + + It("should return ID as lowercase", func() { + Expect(CamelToSnake("ID")).To(Equal("id")) + }) + + It("should work with a single lowercase character", func() { + Expect(CamelToSnake("i")).To(Equal("i")) + }) + + It("should work with a single uppcase character", func() { + Expect(CamelToSnake("I")).To(Equal("i")) + }) + + It("should return a long text as expected", func() { + Expect(CamelToSnake("ThisHasToBeConvertedCorrectlyID")).To( + Equal("this_has_to_be_converted_correctly_id")) + }) + + It("should return the text as expected if the initialism is in the middle", func() { + Expect(CamelToSnake("ThisIDIsFine")).To(Equal("this_id_is_fine")) + }) + + It("should work with long initialism", func() { + Expect(CamelToSnake("ThisHTTPSConnection")).To(Equal("this_https_connection")) + }) + + It("should work with multi initialisms", func() { + Expect(CamelToSnake("HelloHTTPSConnectionID")).To(Equal("hello_https_connection_id")) + }) + + It("sould work with concat initialisms", func() { + Expect(CamelToSnake("HTTPSID")).To(Equal("https_id")) + }) + + It("sould work with initialism where only certain characters are uppercase", func() { + Expect(CamelToSnake("OAuthClient")).To(Equal("oauth_client")) + }) + }) + + Describe("SnakeToCamel test", func() { + It("should return an empty string on an empty input", func() { + Expect(SnakeToCamel("")).To(Equal("")) + }) + + It("should not blow up on trailing _", func() { + Expect(SnakeToCamel("potato_")).To(Equal("Potato")) + }) + + It("should return a snaked text as camel case", func() { + Expect(SnakeToCamel("this_has_to_be_uppercased")).To( + Equal("ThisHasToBeUppercased")) + }) + + It("should return a snaked text as camel case, except the word ID", func() { + Expect(SnakeToCamel("this_is_an_id")).To(Equal("ThisIsAnID")) + }) + + It("should return 'id' not as uppercase", func() { + Expect(SnakeToCamel("this_is_an_identifier")).To(Equal("ThisIsAnIdentifier")) + }) + + It("should simply work with id", func() { + Expect(SnakeToCamel("id")).To(Equal("ID")) + }) + + It("sould work with initialism where only certain characters are uppercase", func() { + Expect(SnakeToCamel("oauth_client")).To(Equal("OAuthClient")) + }) + }) + + Describe("SnakeToCamelLower test", func() { + It("should return an empty string on an empty input", func() { + Ω(SnakeToCamelLower("")).To(Equal("")) + }) + + It("should not blow up on trailing _", func() { + Ω(SnakeToCamelLower("potato_")).To(Equal("potato")) + }) + + It("should return a snaked text as camel case", func() { + Ω(SnakeToCamelLower("this_has_to_be_uppercased")).To( + Equal("thisHasToBeUppercased")) + }) + + It("should return a snaked text as camel case, except the word ID", func() { + Ω(SnakeToCamelLower("this_is_an_id")).To(Equal("thisIsAnID")) + }) + + It("should return 'id' not as uppercase", func() { + Ω(SnakeToCamelLower("this_is_an_identifier")).To(Equal("thisIsAnIdentifier")) + }) + + It("should simply work with id", func() { + Ω(SnakeToCamelLower("id")).To(Equal("id")) + }) + + It("should simply work with leading id", func() { + Ω(SnakeToCamelLower("id_me_please")).To(Equal("idMePlease")) + }) + }) +}) diff --git a/internal/util/utils.go b/internal/util/utils.go new file mode 100644 index 0000000..cb9f989 --- /dev/null +++ b/internal/util/utils.go @@ -0,0 +1,23 @@ +package util + +import ( + "github.com/go-jet/jet/internal/3rdparty/snaker" + "strings" +) + +func ToGoIdentifier(databaseIdentifier string) string { + if len(databaseIdentifier) == 0 { + return databaseIdentifier + } + databaseIdentifier = strings.ReplaceAll(databaseIdentifier, " ", "_") + databaseIdentifier = strings.ReplaceAll(databaseIdentifier, "-", "_") + + return snaker.SnakeToCamel(databaseIdentifier) +} + +func ToGoFileName(databaseIdentifier string) string { + databaseIdentifier = strings.ReplaceAll(databaseIdentifier, " ", "_") + databaseIdentifier = strings.ReplaceAll(databaseIdentifier, "-", "_") + + return strings.ToLower(databaseIdentifier) +} diff --git a/internal/util/utils_test.go b/internal/util/utils_test.go new file mode 100644 index 0000000..e77b356 --- /dev/null +++ b/internal/util/utils_test.go @@ -0,0 +1,24 @@ +package util + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestToGoIdentifier(t *testing.T) { + assert.Equal(t, ToGoIdentifier("uuid"), "UUID") + assert.Equal(t, ToGoIdentifier("col1"), "Col1") + assert.Equal(t, ToGoIdentifier("PG-13"), "Pg13") + assert.Equal(t, ToGoIdentifier("13_pg"), "13Pg") + + assert.Equal(t, ToGoIdentifier("mytable"), "Mytable") + assert.Equal(t, ToGoIdentifier("MYTABLE"), "Mytable") + assert.Equal(t, ToGoIdentifier("MyTaBlE"), "MyTaBlE") + assert.Equal(t, ToGoIdentifier("myTaBlE"), "MyTaBlE") + + assert.Equal(t, ToGoIdentifier("my_table"), "MyTable") + assert.Equal(t, ToGoIdentifier("MY_TABLE"), "MyTable") + assert.Equal(t, ToGoIdentifier("My_Table"), "MyTable") + assert.Equal(t, ToGoIdentifier("My Table"), "MyTable") + assert.Equal(t, ToGoIdentifier("My-Table"), "MyTable") +} diff --git a/table.go b/table.go index a7dfb2e..8d469d6 100644 --- a/table.go +++ b/table.go @@ -98,12 +98,6 @@ type writableTableInterfaceImpl struct { } func (w *writableTableInterfaceImpl) INSERT(columns ...column) InsertStatement { - //columnList := unwidColumnList(columns) - // - //if len(columns) == 0 { - // columnList = w.parent.columns() - //} - return newInsertStatement(w.parent, unwidColumnList(columns)) } diff --git a/tests/init/data/test_sample.sql b/tests/init/data/test_sample.sql index 5ac455b..c44e276 100644 --- a/tests/init/data/test_sample.sql +++ b/tests/init/data/test_sample.sql @@ -198,4 +198,34 @@ CREATE TABLE test_sample.person( DROP TYPE IF EXISTS test_sample.MOOD CASCADE; -CREATE TYPE test_sample.MOOD AS ENUM ('sad', 'ok', 'happy'); \ No newline at end of file +CREATE TYPE test_sample.MOOD AS ENUM ('sad', 'ok', 'happy'); + + +-- WEIRD TABLE NAMES -------------- + +DROP TABLE IF EXISTS test_sample."WEIRD NAMES TABLE"; + +CREATE TABLE test_sample."WEIRD NAMES TABLE"( + "weird_column_name1" varchar(100) NOT NULL, + "Weird_Column_Name2" varchar(100) NOT NULL, + "wEiRd_cOluMn_nAmE3" varchar(100) NOT NULL, + "WeIrd_CoLuMN_Name4" varchar(100) NOT NULL, + "WEIRD_COLUMN_NAME5" varchar(100) NOT NULL, + + "WeirdColumnName6" varchar(100) NOT NULL, + "weirdColumnName7" varchar(100) NOT NULL, + "weirdcolumnname8" varchar(100), + + "weird col name9" varchar(100) NOT NULL, + "wEiRd cOlu nAmE10" varchar(100) NOT NULL, + "WEIRD COLU NAME11" varchar(100) NOT NULL, + "Weird Colu Name12" varchar(100) NOT NULL, + + "weird-col-name13" varchar(100) NOT NULL, + "wEiRd-cOlu-nAmE14" varchar(100) NOT NULL, + "WEIRD-COLU-NAME15" varchar(100) NOT NULL, + "Weird-Colu-Name16" varchar(100) NOT NULL +); + +INSERT INTO test_sample."WEIRD NAMES TABLE" +VALUES ('Doe', 'Doe', 'Doe', 'Doe','Doe', 'Doe', 'Doe', 'Doe','Doe', 'Doe', 'Doe', 'Doe','Doe', 'Doe', 'Doe', 'Doe'); \ No newline at end of file diff --git a/tests/sample_test.go b/tests/sample_test.go index 93590c1..b7663b2 100644 --- a/tests/sample_test.go +++ b/tests/sample_test.go @@ -116,3 +116,35 @@ ORDER BY employee.employee_id; ManagerID: int32Ptr(3), }) } + +func TestWierdNamesTable(t *testing.T) { + stmt := WeirdNamesTable.SELECT(WeirdNamesTable.AllColumns) + + fmt.Println(stmt.DebugSql()) + + dest := []model.WeirdNamesTable{} + + err := stmt.Query(db, &dest) + + assert.NilError(t, err) + + assert.Equal(t, len(dest), 1) + assert.DeepEqual(t, dest[0], model.WeirdNamesTable{ + WeirdColumnName1: "Doe", + WeirdColumnName2: "Doe", + WeirdColumnName3: "Doe", + WeirdColumnName4: "Doe", + WeirdColumnName5: "Doe", + WeirdColumnName6: "Doe", + WeirdColumnName7: "Doe", + Weirdcolumnname8: stringPtr("Doe"), + WeirdColName9: "Doe", + WeirdColuName10: "Doe", + WeirdColuName11: "Doe", + WeirdColuName12: "Doe", + WeirdColName13: "Doe", + WeirdColuName14: "Doe", + WeirdColuName15: "Doe", + WeirdColuName16: "Doe", + }) +} diff --git a/tests/scan_test.go b/tests/scan_test.go index 75fc8f7..631c093 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -467,7 +467,7 @@ func TestScanToSlice(t *testing.T) { t.Run("slice of structs with slice of ints", func(t *testing.T) { var dest []struct { model.Film - IDs []int32 `alias:"Inventory.inventory_id"` + IDs []int32 `alias:"inventory.inventory_id"` } err := query.Query(db, &dest) @@ -483,7 +483,7 @@ func TestScanToSlice(t *testing.T) { t.Run("slice of structs with slice of pointer to ints", func(t *testing.T) { var dest []struct { model.Film - IDs []*int32 `alias:"inventory.InventoryId"` + IDs []*int32 `alias:"inventory.inventory_id"` } err := query.Query(db, &dest) @@ -796,7 +796,7 @@ var store1 = model.Store{ LastUpdate: *timestampWithoutTimeZone("2006-02-15 09:57:12", 0), } -var pgRating = model.MpaaRating_PG +var pgRating = model.MpaaRating_Pg var gRating = model.MpaaRating_G var language1 = model.Language{ diff --git a/tests/select_test.go b/tests/select_test.go index cab7198..123d988 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -291,7 +291,7 @@ LIMIT 15; query := Film. INNER_JOIN(Language, Film.LanguageID.EQ(Language.LanguageID)). SELECT(Language.AllColumns, Film.AllColumns). - WHERE(Film.Rating.EQ(enum.MpaaRating.NC17)). + WHERE(Film.Rating.EQ(enum.MpaaRating.Nc17)). LIMIT(15) assertStatementSql(t, query, expectedSql, int64(15)) @@ -304,7 +304,7 @@ LIMIT 15; englishFilms := filmsPerLanguage[0] - assert.Equal(t, *englishFilms.Film[0].Rating, model.MpaaRating_NC17) + assert.Equal(t, *englishFilms.Film[0].Rating, model.MpaaRating_Nc17) filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} err = query.Query(db, &filmsPerLanguageWithPtrs) @@ -1237,6 +1237,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC; } func TestQuickStartWithSubQueries(t *testing.T) { + filmLogerThan180 := Film. SELECT(Film.AllColumns). WHERE(Film.Length.GT(Int(180))). diff --git a/utils.go b/utils.go index 982c986..cfc9bff 100644 --- a/utils.go +++ b/utils.go @@ -2,7 +2,7 @@ package jet import ( "errors" - "github.com/serenize/snaker" + "github.com/go-jet/jet/internal/util" "reflect" "strings" ) @@ -145,7 +145,7 @@ func unwindRowFromModel(columns []column, data interface{}) []clause { for _, column := range columns { columnName := column.Name() - structFieldName := snaker.SnakeToCamel(columnName) + structFieldName := util.ToGoIdentifier(columnName) structField := structValue.FieldByName(structFieldName)