From 2c7a9f5058b70921ef9fd824b318f363dab0ce07 Mon Sep 17 00:00:00 2001 From: zer0sub Date: Wed, 3 Apr 2019 19:21:46 +0200 Subject: [PATCH] Add support for database enum types. --- generator/datamodel_generator.go | 26 +++++++++++ generator/generator.go | 6 +++ generator/metadata/column_info.go | 7 ++- generator/metadata/database_info.go | 8 ++++ generator/metadata/enum_info.go | 70 +++++++++++++++++++++++++++++ generator/templates.go | 31 +++++++++++++ generator/utils.go | 3 +- sqlbuilder/execution/execution.go | 55 +++++++++++++++++------ tests/generator_test.go | 18 +++++--- 9 files changed, 202 insertions(+), 22 deletions(-) create mode 100644 generator/metadata/enum_info.go diff --git a/generator/datamodel_generator.go b/generator/datamodel_generator.go index 230841f..a6a8550 100644 --- a/generator/datamodel_generator.go +++ b/generator/datamodel_generator.go @@ -30,3 +30,29 @@ func generateDataModel(databaseInfo *metadata.DatabaseInfo, dirPath string) erro return nil } + +func generateEnumTypes(databaseInfo *metadata.DatabaseInfo, dirPath string) error { + modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "model") + + err := ensureDirPath(modelDirPath) + + if err != nil { + return err + } + + for _, enumInfo := range databaseInfo.EnumInfos { + text, err := generateTemplate(EnumModelTemplate, enumInfo) + + if err != nil { + return err + } + + err = saveGoFile(modelDirPath, enumInfo.Name, text) + + if err != nil { + return err + } + } + + return nil +} diff --git a/generator/generator.go b/generator/generator.go index d69448f..71ac5e8 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -53,5 +53,11 @@ func Generate(folderPath string, connectString string, databaseName, schemaName return err } + err = generateEnumTypes(databaseInfo, folderPath) + + if err != nil { + return err + } + return nil } diff --git a/generator/metadata/column_info.go b/generator/metadata/column_info.go index 7d54b90..4197ae8 100644 --- a/generator/metadata/column_info.go +++ b/generator/metadata/column_info.go @@ -10,6 +10,7 @@ type ColumnInfo struct { Name string IsNullable bool DataType string + EnumName string TableInfo *TableInfo } @@ -65,6 +66,8 @@ func (c ColumnInfo) GoBaseType() string { return snaker.SnakeToCamel(forignKeyTable) } else { switch c.DataType { + case "USER-DEFINED": + return c.EnumName case "boolean": return "bool" case "smallint": @@ -105,7 +108,7 @@ func (c ColumnInfo) ToGoFieldName() string { func fetchColumnInfos(db *sql.DB, tableInfo *TableInfo) ([]ColumnInfo, error) { query := ` -SELECT column_name, is_nullable, data_type +SELECT column_name, is_nullable, data_type, udt_name FROM information_schema.columns where table_schema = $1 and table_name = $2 order by ordinal_position;` @@ -124,7 +127,7 @@ order by ordinal_position;` for rows.Next() { columnInfo := ColumnInfo{} var isNullable string - err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType) + err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType, &columnInfo.EnumName) columnInfo.IsNullable = isNullable == "YES" diff --git a/generator/metadata/database_info.go b/generator/metadata/database_info.go index 658df3e..9bbba9b 100644 --- a/generator/metadata/database_info.go +++ b/generator/metadata/database_info.go @@ -8,6 +8,7 @@ type DatabaseInfo struct { DatabaseName string SchemaName string TableInfos []TableInfo + EnumInfos []EnumInfo } func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo, error) { @@ -16,6 +17,7 @@ func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo databaseName, schemaName, []TableInfo{}, + []EnumInfo{}, } var err error @@ -25,5 +27,11 @@ func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo return nil, err } + databaseInfo.EnumInfos, err = fetchEnumInfos(db, databaseInfo) + + if err != nil { + return nil, err + } + return databaseInfo, nil } diff --git a/generator/metadata/enum_info.go b/generator/metadata/enum_info.go new file mode 100644 index 0000000..0f129b2 --- /dev/null +++ b/generator/metadata/enum_info.go @@ -0,0 +1,70 @@ +package metadata + +import ( + "database/sql" + "fmt" +) + +type EnumInfo struct { + Name string + Values []string +} + +func (e *EnumInfo) goValueName(index int) { + return +} + +func fetchEnumInfos(db *sql.DB, databaseInfo *DatabaseInfo) ([]EnumInfo, error) { + query := ` +SELECT t.typname, + e.enumlabel +FROM pg_catalog.pg_type t + JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid + JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace +WHERE n.nspname = $1 +ORDER BY n.nspname, t.typname, e.enumsortorder;` + + //fmt.Println(query, schemaName) + + rows, err := db.Query(query, &databaseInfo.SchemaName) + + if err != nil { + return nil, err + } + defer rows.Close() + + enumsInfosMap := map[string][]string{} + for rows.Next() { + var enumName string + var enumValue string + err = rows.Scan(&enumName, &enumValue) + if err != nil { + return nil, err + } + + enumValues := enumsInfosMap[enumName] + + enumValues = append(enumValues, enumValue) + + enumsInfosMap[enumName] = enumValues + } + + err = rows.Err() + + if err != nil { + return nil, err + } + + ret := []EnumInfo{} + + for enumName, enumValues := range enumsInfosMap { + ret = append(ret, EnumInfo{ + enumName, + enumValues, + }) + } + + fmt.Println("FOUND", len(ret), " enums") + + return ret, nil +} diff --git a/generator/templates.go b/generator/templates.go index d7ec706..d8672dc 100644 --- a/generator/templates.go +++ b/generator/templates.go @@ -61,3 +61,34 @@ type {{.ToGoModelStructName}} struct { {{- end}} } ` + +var EnumModelTemplate = `package model + +import "errors" + +type {{.Name}} string + +const ( +{{- range $index, $element := .Values}} + {{camelize $.Name}}_{{camelize $element}} {{$.Name}} = "{{$element}}" +{{- end}} +) + +func (e *{{$.Name}}) Scan(value interface{}) error { + if v, ok := value.(string); !ok { + return errors.New("Invalid data for {{$.Name}} enum") + } else { + switch string(v) { +{{- range $index, $element := .Values}} + case "{{$element}}": + *e = {{camelize $.Name}}_{{camelize $element}} +{{- end}} + default: + return errors.New("Inavlid data " + string(v) + "for {{$.Name}} enum") + } + + return nil + } +} + +` diff --git a/generator/utils.go b/generator/utils.go index 95742a6..cce2997 100644 --- a/generator/utils.go +++ b/generator/utils.go @@ -6,6 +6,7 @@ import ( "go/format" "os" "path/filepath" + "strings" "text/template" ) @@ -50,7 +51,7 @@ func generateTemplate(templateText string, templateData interface{}) ([]byte, er t, err := template.New("SqlBuilderTableTemplate").Funcs(template.FuncMap{ "camelize": func(txt string) string { - return snaker.SnakeToCamel(txt) + return snaker.SnakeToCamel(strings.Replace(txt, "-", "_", -1)) }, }).Parse(templateText) diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index 5a6db84..4bcbbcc 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -419,7 +419,23 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma fieldName := field.Name - if !isDbBaseType(field.Type) { + if _, ok := fieldValue.Interface().(sql.Scanner); ok { + cellValue := getCellValue(scanContext, tableName, fieldName, row) + + if cellValue == nil { + continue + } + + initializeValue(fieldValue) + + scanner := fieldValue.Interface().(sql.Scanner) + + err := scanner.Scan(cellValue) + + if err != nil { + return err + } + } else if !isDbBaseType(field.Type) { //var fieldValueInterface interface{} err := mapRowToDestinationValue(scanContext, groupKey, typesProcessed, row, fieldValue, &field) @@ -427,18 +443,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma return err } } else { - columnName := tableName + "." + snaker.CamelToSnake(fieldName) - //columnName := snaker.CamelToSnake(fieldName) - - ////fmt.Println(columnName) - index := getIndex(scanContext.columnNames, columnName) - - if index < 0 { - continue - } - ////spew.Dump(row[index]) - - cellValue := cellValue(row, index) + cellValue := getCellValue(scanContext, tableName, fieldName, row) //spew.Dump(cellValue) //spew.Dump(rowColumnValue, fieldValue) @@ -451,6 +456,29 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma return nil } +func initializeValue(value reflect.Value) { + newValuePtr := reflect.New(value.Type().Elem()) + if value.Kind() == reflect.Ptr { + value.Set(newValuePtr) + } else { + value.Set(newValuePtr.Elem()) + } +} + +func getCellValue(scanContext *scanContext, tableName, fieldName string, row []interface{}) interface{} { + columnName := tableName + "." + snaker.CamelToSnake(fieldName) + //columnName := snaker.CamelToSnake(fieldName) + + ////fmt.Println(columnName) + index := getIndex(scanContext.columnNames, columnName) + + if index < 0 { + return nil + } + + return cellValue(row, index) +} + func reflectValueToString(val interface{}) string { //spew.Dump(val) @@ -549,6 +577,7 @@ var nullTimeType = reflect.TypeOf(NullTime{}) func newScanType(columnType *sql.ColumnType) reflect.Type { //spew.Dump(columnType) + //fmt.Println(columnType.DatabaseTypeName()) switch columnType.DatabaseTypeName() { case "INT2": return nullInt16Type diff --git a/tests/generator_test.go b/tests/generator_test.go index 63ca2e3..4b77b84 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -130,7 +130,7 @@ func TestSelect_ScanToSlice(t *testing.T) { func TestJoinQuerySlice(t *testing.T) { type FilmsPerLanguage struct { Language *model.Language - Film *[]model.Film + Film []model.Film } filmsPerLanguage := []FilmsPerLanguage{} @@ -139,13 +139,13 @@ func TestJoinQuerySlice(t *testing.T) { query := Film. INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). SELECT(Language.AllColumns, Film.AllColumns). + Where(Film.Rating.EqL(string(model.MpaaRating_NC17))). Limit(15) queryStr, err := query.String() assert.NilError(t, err) - assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id LIMIT 15`) - + assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id WHERE film.rating = 'NC-17' LIMIT 15`) //fmt.Println(queryStr) err = query.Execute(db, &filmsPerLanguage) @@ -158,7 +158,11 @@ func TestJoinQuerySlice(t *testing.T) { //spew.Dump(filmsPerLanguage) assert.Equal(t, len(filmsPerLanguage), 1) - assert.Equal(t, len(*filmsPerLanguage[0].Film), limit) + assert.Equal(t, len(filmsPerLanguage[0].Film), limit) + + englishFilms := filmsPerLanguage[0] + + assert.Equal(t, *englishFilms.Film[0].Rating, model.MpaaRating_NC17) //spew.Dump(filmsPerLanguage) @@ -167,7 +171,7 @@ func TestJoinQuerySlice(t *testing.T) { assert.NilError(t, err) assert.Equal(t, len(filmsPerLanguage), 1) - assert.Equal(t, len(*filmsPerLanguage[0].Film), limit) + assert.Equal(t, len(filmsPerLanguage[0].Film), limit) } func TestJoinQuerySliceWithPtrs(t *testing.T) { @@ -492,6 +496,8 @@ func TestSelectQueryScalar(t *testing.T) { assert.Equal(t, len(maxRentalRateFilms), 336) + gRating := model.MpaaRating_G + assert.DeepEqual(t, maxRentalRateFilms[0], model.Film{ FilmID: 2, Title: "Ace Goldfinger", @@ -501,7 +507,7 @@ func TestSelectQueryScalar(t *testing.T) { RentalRate: 4.99, Length: int16Ptr(48), ReplacementCost: 12.99, - Rating: stringPtr("G"), + Rating: &gRating, RentalDuration: 3, LastUpdate: *timeWithoutTimeZone("2013-05-26 14:50:58.951", 3), SpecialFeatures: stringPtr("{Trailers,\"Deleted Scenes\"}"),