Add support for database enum types.
This commit is contained in:
parent
273bf1ed4c
commit
2c7a9f5058
9 changed files with 202 additions and 22 deletions
|
|
@ -30,3 +30,29 @@ func generateDataModel(databaseInfo *metadata.DatabaseInfo, dirPath string) erro
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateEnumTypes(databaseInfo *metadata.DatabaseInfo, dirPath string) error {
|
||||
modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "model")
|
||||
|
||||
err := ensureDirPath(modelDirPath)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, enumInfo := range databaseInfo.EnumInfos {
|
||||
text, err := generateTemplate(EnumModelTemplate, enumInfo)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = saveGoFile(modelDirPath, enumInfo.Name, text)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,5 +53,11 @@ func Generate(folderPath string, connectString string, databaseName, schemaName
|
|||
return err
|
||||
}
|
||||
|
||||
err = generateEnumTypes(databaseInfo, folderPath)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ type ColumnInfo struct {
|
|||
Name string
|
||||
IsNullable bool
|
||||
DataType string
|
||||
EnumName string
|
||||
TableInfo *TableInfo
|
||||
}
|
||||
|
||||
|
|
@ -65,6 +66,8 @@ func (c ColumnInfo) GoBaseType() string {
|
|||
return snaker.SnakeToCamel(forignKeyTable)
|
||||
} else {
|
||||
switch c.DataType {
|
||||
case "USER-DEFINED":
|
||||
return c.EnumName
|
||||
case "boolean":
|
||||
return "bool"
|
||||
case "smallint":
|
||||
|
|
@ -105,7 +108,7 @@ func (c ColumnInfo) ToGoFieldName() string {
|
|||
func fetchColumnInfos(db *sql.DB, tableInfo *TableInfo) ([]ColumnInfo, error) {
|
||||
|
||||
query := `
|
||||
SELECT column_name, is_nullable, data_type
|
||||
SELECT column_name, is_nullable, data_type, udt_name
|
||||
FROM information_schema.columns
|
||||
where table_schema = $1 and table_name = $2
|
||||
order by ordinal_position;`
|
||||
|
|
@ -124,7 +127,7 @@ order by ordinal_position;`
|
|||
for rows.Next() {
|
||||
columnInfo := ColumnInfo{}
|
||||
var isNullable string
|
||||
err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType)
|
||||
err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType, &columnInfo.EnumName)
|
||||
|
||||
columnInfo.IsNullable = isNullable == "YES"
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ type DatabaseInfo struct {
|
|||
DatabaseName string
|
||||
SchemaName string
|
||||
TableInfos []TableInfo
|
||||
EnumInfos []EnumInfo
|
||||
}
|
||||
|
||||
func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo, error) {
|
||||
|
|
@ -16,6 +17,7 @@ func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo
|
|||
databaseName,
|
||||
schemaName,
|
||||
[]TableInfo{},
|
||||
[]EnumInfo{},
|
||||
}
|
||||
|
||||
var err error
|
||||
|
|
@ -25,5 +27,11 @@ func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo
|
|||
return nil, err
|
||||
}
|
||||
|
||||
databaseInfo.EnumInfos, err = fetchEnumInfos(db, databaseInfo)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return databaseInfo, nil
|
||||
}
|
||||
|
|
|
|||
70
generator/metadata/enum_info.go
Normal file
70
generator/metadata/enum_info.go
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
package metadata
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type EnumInfo struct {
|
||||
Name string
|
||||
Values []string
|
||||
}
|
||||
|
||||
func (e *EnumInfo) goValueName(index int) {
|
||||
return
|
||||
}
|
||||
|
||||
func fetchEnumInfos(db *sql.DB, databaseInfo *DatabaseInfo) ([]EnumInfo, error) {
|
||||
query := `
|
||||
SELECT t.typname,
|
||||
e.enumlabel
|
||||
FROM pg_catalog.pg_type t
|
||||
JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid
|
||||
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
|
||||
WHERE n.nspname = $1
|
||||
ORDER BY n.nspname, t.typname, e.enumsortorder;`
|
||||
|
||||
//fmt.Println(query, schemaName)
|
||||
|
||||
rows, err := db.Query(query, &databaseInfo.SchemaName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
enumsInfosMap := map[string][]string{}
|
||||
for rows.Next() {
|
||||
var enumName string
|
||||
var enumValue string
|
||||
err = rows.Scan(&enumName, &enumValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
enumValues := enumsInfosMap[enumName]
|
||||
|
||||
enumValues = append(enumValues, enumValue)
|
||||
|
||||
enumsInfosMap[enumName] = enumValues
|
||||
}
|
||||
|
||||
err = rows.Err()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := []EnumInfo{}
|
||||
|
||||
for enumName, enumValues := range enumsInfosMap {
|
||||
ret = append(ret, EnumInfo{
|
||||
enumName,
|
||||
enumValues,
|
||||
})
|
||||
}
|
||||
|
||||
fmt.Println("FOUND", len(ret), " enums")
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
|
@ -61,3 +61,34 @@ type {{.ToGoModelStructName}} struct {
|
|||
{{- end}}
|
||||
}
|
||||
`
|
||||
|
||||
var EnumModelTemplate = `package model
|
||||
|
||||
import "errors"
|
||||
|
||||
type {{.Name}} string
|
||||
|
||||
const (
|
||||
{{- range $index, $element := .Values}}
|
||||
{{camelize $.Name}}_{{camelize $element}} {{$.Name}} = "{{$element}}"
|
||||
{{- end}}
|
||||
)
|
||||
|
||||
func (e *{{$.Name}}) Scan(value interface{}) error {
|
||||
if v, ok := value.(string); !ok {
|
||||
return errors.New("Invalid data for {{$.Name}} enum")
|
||||
} else {
|
||||
switch string(v) {
|
||||
{{- range $index, $element := .Values}}
|
||||
case "{{$element}}":
|
||||
*e = {{camelize $.Name}}_{{camelize $element}}
|
||||
{{- end}}
|
||||
default:
|
||||
return errors.New("Inavlid data " + string(v) + "for {{$.Name}} enum")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
`
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
"go/format"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
|
|
@ -50,7 +51,7 @@ func generateTemplate(templateText string, templateData interface{}) ([]byte, er
|
|||
|
||||
t, err := template.New("SqlBuilderTableTemplate").Funcs(template.FuncMap{
|
||||
"camelize": func(txt string) string {
|
||||
return snaker.SnakeToCamel(txt)
|
||||
return snaker.SnakeToCamel(strings.Replace(txt, "-", "_", -1))
|
||||
},
|
||||
}).Parse(templateText)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue