Add support for database enum types.

This commit is contained in:
zer0sub 2019-04-03 19:21:46 +02:00
parent 273bf1ed4c
commit 2c7a9f5058
9 changed files with 202 additions and 22 deletions

View file

@ -30,3 +30,29 @@ func generateDataModel(databaseInfo *metadata.DatabaseInfo, dirPath string) erro
return nil
}
func generateEnumTypes(databaseInfo *metadata.DatabaseInfo, dirPath string) error {
modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "model")
err := ensureDirPath(modelDirPath)
if err != nil {
return err
}
for _, enumInfo := range databaseInfo.EnumInfos {
text, err := generateTemplate(EnumModelTemplate, enumInfo)
if err != nil {
return err
}
err = saveGoFile(modelDirPath, enumInfo.Name, text)
if err != nil {
return err
}
}
return nil
}

View file

@ -53,5 +53,11 @@ func Generate(folderPath string, connectString string, databaseName, schemaName
return err
}
err = generateEnumTypes(databaseInfo, folderPath)
if err != nil {
return err
}
return nil
}

View file

@ -10,6 +10,7 @@ type ColumnInfo struct {
Name string
IsNullable bool
DataType string
EnumName string
TableInfo *TableInfo
}
@ -65,6 +66,8 @@ func (c ColumnInfo) GoBaseType() string {
return snaker.SnakeToCamel(forignKeyTable)
} else {
switch c.DataType {
case "USER-DEFINED":
return c.EnumName
case "boolean":
return "bool"
case "smallint":
@ -105,7 +108,7 @@ func (c ColumnInfo) ToGoFieldName() string {
func fetchColumnInfos(db *sql.DB, tableInfo *TableInfo) ([]ColumnInfo, error) {
query := `
SELECT column_name, is_nullable, data_type
SELECT column_name, is_nullable, data_type, udt_name
FROM information_schema.columns
where table_schema = $1 and table_name = $2
order by ordinal_position;`
@ -124,7 +127,7 @@ order by ordinal_position;`
for rows.Next() {
columnInfo := ColumnInfo{}
var isNullable string
err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType)
err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType, &columnInfo.EnumName)
columnInfo.IsNullable = isNullable == "YES"

View file

@ -8,6 +8,7 @@ type DatabaseInfo struct {
DatabaseName string
SchemaName string
TableInfos []TableInfo
EnumInfos []EnumInfo
}
func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo, error) {
@ -16,6 +17,7 @@ func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo
databaseName,
schemaName,
[]TableInfo{},
[]EnumInfo{},
}
var err error
@ -25,5 +27,11 @@ func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo
return nil, err
}
databaseInfo.EnumInfos, err = fetchEnumInfos(db, databaseInfo)
if err != nil {
return nil, err
}
return databaseInfo, nil
}

View file

@ -0,0 +1,70 @@
package metadata
import (
"database/sql"
"fmt"
)
type EnumInfo struct {
Name string
Values []string
}
func (e *EnumInfo) goValueName(index int) {
return
}
func fetchEnumInfos(db *sql.DB, databaseInfo *DatabaseInfo) ([]EnumInfo, error) {
query := `
SELECT t.typname,
e.enumlabel
FROM pg_catalog.pg_type t
JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
WHERE n.nspname = $1
ORDER BY n.nspname, t.typname, e.enumsortorder;`
//fmt.Println(query, schemaName)
rows, err := db.Query(query, &databaseInfo.SchemaName)
if err != nil {
return nil, err
}
defer rows.Close()
enumsInfosMap := map[string][]string{}
for rows.Next() {
var enumName string
var enumValue string
err = rows.Scan(&enumName, &enumValue)
if err != nil {
return nil, err
}
enumValues := enumsInfosMap[enumName]
enumValues = append(enumValues, enumValue)
enumsInfosMap[enumName] = enumValues
}
err = rows.Err()
if err != nil {
return nil, err
}
ret := []EnumInfo{}
for enumName, enumValues := range enumsInfosMap {
ret = append(ret, EnumInfo{
enumName,
enumValues,
})
}
fmt.Println("FOUND", len(ret), " enums")
return ret, nil
}

View file

@ -61,3 +61,34 @@ type {{.ToGoModelStructName}} struct {
{{- end}}
}
`
var EnumModelTemplate = `package model
import "errors"
type {{.Name}} string
const (
{{- range $index, $element := .Values}}
{{camelize $.Name}}_{{camelize $element}} {{$.Name}} = "{{$element}}"
{{- end}}
)
func (e *{{$.Name}}) Scan(value interface{}) error {
if v, ok := value.(string); !ok {
return errors.New("Invalid data for {{$.Name}} enum")
} else {
switch string(v) {
{{- range $index, $element := .Values}}
case "{{$element}}":
*e = {{camelize $.Name}}_{{camelize $element}}
{{- end}}
default:
return errors.New("Inavlid data " + string(v) + "for {{$.Name}} enum")
}
return nil
}
}
`

View file

@ -6,6 +6,7 @@ import (
"go/format"
"os"
"path/filepath"
"strings"
"text/template"
)
@ -50,7 +51,7 @@ func generateTemplate(templateText string, templateData interface{}) ([]byte, er
t, err := template.New("SqlBuilderTableTemplate").Funcs(template.FuncMap{
"camelize": func(txt string) string {
return snaker.SnakeToCamel(txt)
return snaker.SnakeToCamel(strings.Replace(txt, "-", "_", -1))
},
}).Parse(templateText)