Add support for database enum types.

This commit is contained in:
zer0sub 2019-04-03 19:21:46 +02:00
parent 273bf1ed4c
commit 2c7a9f5058
9 changed files with 202 additions and 22 deletions

View file

@ -30,3 +30,29 @@ func generateDataModel(databaseInfo *metadata.DatabaseInfo, dirPath string) erro
return nil
}
func generateEnumTypes(databaseInfo *metadata.DatabaseInfo, dirPath string) error {
modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "model")
err := ensureDirPath(modelDirPath)
if err != nil {
return err
}
for _, enumInfo := range databaseInfo.EnumInfos {
text, err := generateTemplate(EnumModelTemplate, enumInfo)
if err != nil {
return err
}
err = saveGoFile(modelDirPath, enumInfo.Name, text)
if err != nil {
return err
}
}
return nil
}

View file

@ -53,5 +53,11 @@ func Generate(folderPath string, connectString string, databaseName, schemaName
return err
}
err = generateEnumTypes(databaseInfo, folderPath)
if err != nil {
return err
}
return nil
}

View file

@ -10,6 +10,7 @@ type ColumnInfo struct {
Name string
IsNullable bool
DataType string
EnumName string
TableInfo *TableInfo
}
@ -65,6 +66,8 @@ func (c ColumnInfo) GoBaseType() string {
return snaker.SnakeToCamel(forignKeyTable)
} else {
switch c.DataType {
case "USER-DEFINED":
return c.EnumName
case "boolean":
return "bool"
case "smallint":
@ -105,7 +108,7 @@ func (c ColumnInfo) ToGoFieldName() string {
func fetchColumnInfos(db *sql.DB, tableInfo *TableInfo) ([]ColumnInfo, error) {
query := `
SELECT column_name, is_nullable, data_type
SELECT column_name, is_nullable, data_type, udt_name
FROM information_schema.columns
where table_schema = $1 and table_name = $2
order by ordinal_position;`
@ -124,7 +127,7 @@ order by ordinal_position;`
for rows.Next() {
columnInfo := ColumnInfo{}
var isNullable string
err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType)
err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType, &columnInfo.EnumName)
columnInfo.IsNullable = isNullable == "YES"

View file

@ -8,6 +8,7 @@ type DatabaseInfo struct {
DatabaseName string
SchemaName string
TableInfos []TableInfo
EnumInfos []EnumInfo
}
func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo, error) {
@ -16,6 +17,7 @@ func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo
databaseName,
schemaName,
[]TableInfo{},
[]EnumInfo{},
}
var err error
@ -25,5 +27,11 @@ func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo
return nil, err
}
databaseInfo.EnumInfos, err = fetchEnumInfos(db, databaseInfo)
if err != nil {
return nil, err
}
return databaseInfo, nil
}

View file

@ -0,0 +1,70 @@
package metadata
import (
"database/sql"
"fmt"
)
type EnumInfo struct {
Name string
Values []string
}
func (e *EnumInfo) goValueName(index int) {
return
}
func fetchEnumInfos(db *sql.DB, databaseInfo *DatabaseInfo) ([]EnumInfo, error) {
query := `
SELECT t.typname,
e.enumlabel
FROM pg_catalog.pg_type t
JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
WHERE n.nspname = $1
ORDER BY n.nspname, t.typname, e.enumsortorder;`
//fmt.Println(query, schemaName)
rows, err := db.Query(query, &databaseInfo.SchemaName)
if err != nil {
return nil, err
}
defer rows.Close()
enumsInfosMap := map[string][]string{}
for rows.Next() {
var enumName string
var enumValue string
err = rows.Scan(&enumName, &enumValue)
if err != nil {
return nil, err
}
enumValues := enumsInfosMap[enumName]
enumValues = append(enumValues, enumValue)
enumsInfosMap[enumName] = enumValues
}
err = rows.Err()
if err != nil {
return nil, err
}
ret := []EnumInfo{}
for enumName, enumValues := range enumsInfosMap {
ret = append(ret, EnumInfo{
enumName,
enumValues,
})
}
fmt.Println("FOUND", len(ret), " enums")
return ret, nil
}

View file

@ -61,3 +61,34 @@ type {{.ToGoModelStructName}} struct {
{{- end}}
}
`
var EnumModelTemplate = `package model
import "errors"
type {{.Name}} string
const (
{{- range $index, $element := .Values}}
{{camelize $.Name}}_{{camelize $element}} {{$.Name}} = "{{$element}}"
{{- end}}
)
func (e *{{$.Name}}) Scan(value interface{}) error {
if v, ok := value.(string); !ok {
return errors.New("Invalid data for {{$.Name}} enum")
} else {
switch string(v) {
{{- range $index, $element := .Values}}
case "{{$element}}":
*e = {{camelize $.Name}}_{{camelize $element}}
{{- end}}
default:
return errors.New("Inavlid data " + string(v) + "for {{$.Name}} enum")
}
return nil
}
}
`

View file

@ -6,6 +6,7 @@ import (
"go/format"
"os"
"path/filepath"
"strings"
"text/template"
)
@ -50,7 +51,7 @@ func generateTemplate(templateText string, templateData interface{}) ([]byte, er
t, err := template.New("SqlBuilderTableTemplate").Funcs(template.FuncMap{
"camelize": func(txt string) string {
return snaker.SnakeToCamel(txt)
return snaker.SnakeToCamel(strings.Replace(txt, "-", "_", -1))
},
}).Parse(templateText)

View file

@ -419,7 +419,23 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
fieldName := field.Name
if !isDbBaseType(field.Type) {
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
cellValue := getCellValue(scanContext, tableName, fieldName, row)
if cellValue == nil {
continue
}
initializeValue(fieldValue)
scanner := fieldValue.Interface().(sql.Scanner)
err := scanner.Scan(cellValue)
if err != nil {
return err
}
} else if !isDbBaseType(field.Type) {
//var fieldValueInterface interface{}
err := mapRowToDestinationValue(scanContext, groupKey, typesProcessed, row, fieldValue, &field)
@ -427,18 +443,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
return err
}
} else {
columnName := tableName + "." + snaker.CamelToSnake(fieldName)
//columnName := snaker.CamelToSnake(fieldName)
////fmt.Println(columnName)
index := getIndex(scanContext.columnNames, columnName)
if index < 0 {
continue
}
////spew.Dump(row[index])
cellValue := cellValue(row, index)
cellValue := getCellValue(scanContext, tableName, fieldName, row)
//spew.Dump(cellValue)
//spew.Dump(rowColumnValue, fieldValue)
@ -451,6 +456,29 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
return nil
}
func initializeValue(value reflect.Value) {
newValuePtr := reflect.New(value.Type().Elem())
if value.Kind() == reflect.Ptr {
value.Set(newValuePtr)
} else {
value.Set(newValuePtr.Elem())
}
}
func getCellValue(scanContext *scanContext, tableName, fieldName string, row []interface{}) interface{} {
columnName := tableName + "." + snaker.CamelToSnake(fieldName)
//columnName := snaker.CamelToSnake(fieldName)
////fmt.Println(columnName)
index := getIndex(scanContext.columnNames, columnName)
if index < 0 {
return nil
}
return cellValue(row, index)
}
func reflectValueToString(val interface{}) string {
//spew.Dump(val)
@ -549,6 +577,7 @@ var nullTimeType = reflect.TypeOf(NullTime{})
func newScanType(columnType *sql.ColumnType) reflect.Type {
//spew.Dump(columnType)
//fmt.Println(columnType.DatabaseTypeName())
switch columnType.DatabaseTypeName() {
case "INT2":
return nullInt16Type

View file

@ -130,7 +130,7 @@ func TestSelect_ScanToSlice(t *testing.T) {
func TestJoinQuerySlice(t *testing.T) {
type FilmsPerLanguage struct {
Language *model.Language
Film *[]model.Film
Film []model.Film
}
filmsPerLanguage := []FilmsPerLanguage{}
@ -139,13 +139,13 @@ func TestJoinQuerySlice(t *testing.T) {
query := Film.
INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
SELECT(Language.AllColumns, Film.AllColumns).
Where(Film.Rating.EqL(string(model.MpaaRating_NC17))).
Limit(15)
queryStr, err := query.String()
assert.NilError(t, err)
assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id LIMIT 15`)
assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id WHERE film.rating = 'NC-17' LIMIT 15`)
//fmt.Println(queryStr)
err = query.Execute(db, &filmsPerLanguage)
@ -158,7 +158,11 @@ func TestJoinQuerySlice(t *testing.T) {
//spew.Dump(filmsPerLanguage)
assert.Equal(t, len(filmsPerLanguage), 1)
assert.Equal(t, len(*filmsPerLanguage[0].Film), limit)
assert.Equal(t, len(filmsPerLanguage[0].Film), limit)
englishFilms := filmsPerLanguage[0]
assert.Equal(t, *englishFilms.Film[0].Rating, model.MpaaRating_NC17)
//spew.Dump(filmsPerLanguage)
@ -167,7 +171,7 @@ func TestJoinQuerySlice(t *testing.T) {
assert.NilError(t, err)
assert.Equal(t, len(filmsPerLanguage), 1)
assert.Equal(t, len(*filmsPerLanguage[0].Film), limit)
assert.Equal(t, len(filmsPerLanguage[0].Film), limit)
}
func TestJoinQuerySliceWithPtrs(t *testing.T) {
@ -492,6 +496,8 @@ func TestSelectQueryScalar(t *testing.T) {
assert.Equal(t, len(maxRentalRateFilms), 336)
gRating := model.MpaaRating_G
assert.DeepEqual(t, maxRentalRateFilms[0], model.Film{
FilmID: 2,
Title: "Ace Goldfinger",
@ -501,7 +507,7 @@ func TestSelectQueryScalar(t *testing.T) {
RentalRate: 4.99,
Length: int16Ptr(48),
ReplacementCost: 12.99,
Rating: stringPtr("G"),
Rating: &gRating,
RentalDuration: 3,
LastUpdate: *timeWithoutTimeZone("2013-05-26 14:50:58.951", 3),
SpecialFeatures: stringPtr("{Trailers,\"Deleted Scenes\"}"),