Add support for database enum types.
This commit is contained in:
parent
273bf1ed4c
commit
2c7a9f5058
9 changed files with 202 additions and 22 deletions
|
|
@ -30,3 +30,29 @@ func generateDataModel(databaseInfo *metadata.DatabaseInfo, dirPath string) erro
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateEnumTypes(databaseInfo *metadata.DatabaseInfo, dirPath string) error {
|
||||||
|
modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "model")
|
||||||
|
|
||||||
|
err := ensureDirPath(modelDirPath)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, enumInfo := range databaseInfo.EnumInfos {
|
||||||
|
text, err := generateTemplate(EnumModelTemplate, enumInfo)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = saveGoFile(modelDirPath, enumInfo.Name, text)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -53,5 +53,11 @@ func Generate(folderPath string, connectString string, databaseName, schemaName
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = generateEnumTypes(databaseInfo, folderPath)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ type ColumnInfo struct {
|
||||||
Name string
|
Name string
|
||||||
IsNullable bool
|
IsNullable bool
|
||||||
DataType string
|
DataType string
|
||||||
|
EnumName string
|
||||||
TableInfo *TableInfo
|
TableInfo *TableInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -65,6 +66,8 @@ func (c ColumnInfo) GoBaseType() string {
|
||||||
return snaker.SnakeToCamel(forignKeyTable)
|
return snaker.SnakeToCamel(forignKeyTable)
|
||||||
} else {
|
} else {
|
||||||
switch c.DataType {
|
switch c.DataType {
|
||||||
|
case "USER-DEFINED":
|
||||||
|
return c.EnumName
|
||||||
case "boolean":
|
case "boolean":
|
||||||
return "bool"
|
return "bool"
|
||||||
case "smallint":
|
case "smallint":
|
||||||
|
|
@ -105,7 +108,7 @@ func (c ColumnInfo) ToGoFieldName() string {
|
||||||
func fetchColumnInfos(db *sql.DB, tableInfo *TableInfo) ([]ColumnInfo, error) {
|
func fetchColumnInfos(db *sql.DB, tableInfo *TableInfo) ([]ColumnInfo, error) {
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
SELECT column_name, is_nullable, data_type
|
SELECT column_name, is_nullable, data_type, udt_name
|
||||||
FROM information_schema.columns
|
FROM information_schema.columns
|
||||||
where table_schema = $1 and table_name = $2
|
where table_schema = $1 and table_name = $2
|
||||||
order by ordinal_position;`
|
order by ordinal_position;`
|
||||||
|
|
@ -124,7 +127,7 @@ order by ordinal_position;`
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
columnInfo := ColumnInfo{}
|
columnInfo := ColumnInfo{}
|
||||||
var isNullable string
|
var isNullable string
|
||||||
err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType)
|
err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType, &columnInfo.EnumName)
|
||||||
|
|
||||||
columnInfo.IsNullable = isNullable == "YES"
|
columnInfo.IsNullable = isNullable == "YES"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ type DatabaseInfo struct {
|
||||||
DatabaseName string
|
DatabaseName string
|
||||||
SchemaName string
|
SchemaName string
|
||||||
TableInfos []TableInfo
|
TableInfos []TableInfo
|
||||||
|
EnumInfos []EnumInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo, error) {
|
func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo, error) {
|
||||||
|
|
@ -16,6 +17,7 @@ func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo
|
||||||
databaseName,
|
databaseName,
|
||||||
schemaName,
|
schemaName,
|
||||||
[]TableInfo{},
|
[]TableInfo{},
|
||||||
|
[]EnumInfo{},
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
@ -25,5 +27,11 @@ func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
databaseInfo.EnumInfos, err = fetchEnumInfos(db, databaseInfo)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return databaseInfo, nil
|
return databaseInfo, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
70
generator/metadata/enum_info.go
Normal file
70
generator/metadata/enum_info.go
Normal file
|
|
@ -0,0 +1,70 @@
|
||||||
|
package metadata
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EnumInfo struct {
|
||||||
|
Name string
|
||||||
|
Values []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EnumInfo) goValueName(index int) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchEnumInfos(db *sql.DB, databaseInfo *DatabaseInfo) ([]EnumInfo, error) {
|
||||||
|
query := `
|
||||||
|
SELECT t.typname,
|
||||||
|
e.enumlabel
|
||||||
|
FROM pg_catalog.pg_type t
|
||||||
|
JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid
|
||||||
|
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
|
||||||
|
WHERE n.nspname = $1
|
||||||
|
ORDER BY n.nspname, t.typname, e.enumsortorder;`
|
||||||
|
|
||||||
|
//fmt.Println(query, schemaName)
|
||||||
|
|
||||||
|
rows, err := db.Query(query, &databaseInfo.SchemaName)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
enumsInfosMap := map[string][]string{}
|
||||||
|
for rows.Next() {
|
||||||
|
var enumName string
|
||||||
|
var enumValue string
|
||||||
|
err = rows.Scan(&enumName, &enumValue)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
enumValues := enumsInfosMap[enumName]
|
||||||
|
|
||||||
|
enumValues = append(enumValues, enumValue)
|
||||||
|
|
||||||
|
enumsInfosMap[enumName] = enumValues
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rows.Err()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := []EnumInfo{}
|
||||||
|
|
||||||
|
for enumName, enumValues := range enumsInfosMap {
|
||||||
|
ret = append(ret, EnumInfo{
|
||||||
|
enumName,
|
||||||
|
enumValues,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("FOUND", len(ret), " enums")
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
@ -61,3 +61,34 @@ type {{.ToGoModelStructName}} struct {
|
||||||
{{- end}}
|
{{- end}}
|
||||||
}
|
}
|
||||||
`
|
`
|
||||||
|
|
||||||
|
var EnumModelTemplate = `package model
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
type {{.Name}} string
|
||||||
|
|
||||||
|
const (
|
||||||
|
{{- range $index, $element := .Values}}
|
||||||
|
{{camelize $.Name}}_{{camelize $element}} {{$.Name}} = "{{$element}}"
|
||||||
|
{{- end}}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e *{{$.Name}}) Scan(value interface{}) error {
|
||||||
|
if v, ok := value.(string); !ok {
|
||||||
|
return errors.New("Invalid data for {{$.Name}} enum")
|
||||||
|
} else {
|
||||||
|
switch string(v) {
|
||||||
|
{{- range $index, $element := .Values}}
|
||||||
|
case "{{$element}}":
|
||||||
|
*e = {{camelize $.Name}}_{{camelize $element}}
|
||||||
|
{{- end}}
|
||||||
|
default:
|
||||||
|
return errors.New("Inavlid data " + string(v) + "for {{$.Name}} enum")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
`
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"go/format"
|
"go/format"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -50,7 +51,7 @@ func generateTemplate(templateText string, templateData interface{}) ([]byte, er
|
||||||
|
|
||||||
t, err := template.New("SqlBuilderTableTemplate").Funcs(template.FuncMap{
|
t, err := template.New("SqlBuilderTableTemplate").Funcs(template.FuncMap{
|
||||||
"camelize": func(txt string) string {
|
"camelize": func(txt string) string {
|
||||||
return snaker.SnakeToCamel(txt)
|
return snaker.SnakeToCamel(strings.Replace(txt, "-", "_", -1))
|
||||||
},
|
},
|
||||||
}).Parse(templateText)
|
}).Parse(templateText)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -419,7 +419,23 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
|
||||||
|
|
||||||
fieldName := field.Name
|
fieldName := field.Name
|
||||||
|
|
||||||
if !isDbBaseType(field.Type) {
|
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
||||||
|
cellValue := getCellValue(scanContext, tableName, fieldName, row)
|
||||||
|
|
||||||
|
if cellValue == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
initializeValue(fieldValue)
|
||||||
|
|
||||||
|
scanner := fieldValue.Interface().(sql.Scanner)
|
||||||
|
|
||||||
|
err := scanner.Scan(cellValue)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if !isDbBaseType(field.Type) {
|
||||||
//var fieldValueInterface interface{}
|
//var fieldValueInterface interface{}
|
||||||
err := mapRowToDestinationValue(scanContext, groupKey, typesProcessed, row, fieldValue, &field)
|
err := mapRowToDestinationValue(scanContext, groupKey, typesProcessed, row, fieldValue, &field)
|
||||||
|
|
||||||
|
|
@ -427,18 +443,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
columnName := tableName + "." + snaker.CamelToSnake(fieldName)
|
cellValue := getCellValue(scanContext, tableName, fieldName, row)
|
||||||
//columnName := snaker.CamelToSnake(fieldName)
|
|
||||||
|
|
||||||
////fmt.Println(columnName)
|
|
||||||
index := getIndex(scanContext.columnNames, columnName)
|
|
||||||
|
|
||||||
if index < 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
////spew.Dump(row[index])
|
|
||||||
|
|
||||||
cellValue := cellValue(row, index)
|
|
||||||
//spew.Dump(cellValue)
|
//spew.Dump(cellValue)
|
||||||
|
|
||||||
//spew.Dump(rowColumnValue, fieldValue)
|
//spew.Dump(rowColumnValue, fieldValue)
|
||||||
|
|
@ -451,6 +456,29 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initializeValue(value reflect.Value) {
|
||||||
|
newValuePtr := reflect.New(value.Type().Elem())
|
||||||
|
if value.Kind() == reflect.Ptr {
|
||||||
|
value.Set(newValuePtr)
|
||||||
|
} else {
|
||||||
|
value.Set(newValuePtr.Elem())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCellValue(scanContext *scanContext, tableName, fieldName string, row []interface{}) interface{} {
|
||||||
|
columnName := tableName + "." + snaker.CamelToSnake(fieldName)
|
||||||
|
//columnName := snaker.CamelToSnake(fieldName)
|
||||||
|
|
||||||
|
////fmt.Println(columnName)
|
||||||
|
index := getIndex(scanContext.columnNames, columnName)
|
||||||
|
|
||||||
|
if index < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return cellValue(row, index)
|
||||||
|
}
|
||||||
|
|
||||||
func reflectValueToString(val interface{}) string {
|
func reflectValueToString(val interface{}) string {
|
||||||
//spew.Dump(val)
|
//spew.Dump(val)
|
||||||
|
|
||||||
|
|
@ -549,6 +577,7 @@ var nullTimeType = reflect.TypeOf(NullTime{})
|
||||||
|
|
||||||
func newScanType(columnType *sql.ColumnType) reflect.Type {
|
func newScanType(columnType *sql.ColumnType) reflect.Type {
|
||||||
//spew.Dump(columnType)
|
//spew.Dump(columnType)
|
||||||
|
//fmt.Println(columnType.DatabaseTypeName())
|
||||||
switch columnType.DatabaseTypeName() {
|
switch columnType.DatabaseTypeName() {
|
||||||
case "INT2":
|
case "INT2":
|
||||||
return nullInt16Type
|
return nullInt16Type
|
||||||
|
|
|
||||||
|
|
@ -130,7 +130,7 @@ func TestSelect_ScanToSlice(t *testing.T) {
|
||||||
func TestJoinQuerySlice(t *testing.T) {
|
func TestJoinQuerySlice(t *testing.T) {
|
||||||
type FilmsPerLanguage struct {
|
type FilmsPerLanguage struct {
|
||||||
Language *model.Language
|
Language *model.Language
|
||||||
Film *[]model.Film
|
Film []model.Film
|
||||||
}
|
}
|
||||||
|
|
||||||
filmsPerLanguage := []FilmsPerLanguage{}
|
filmsPerLanguage := []FilmsPerLanguage{}
|
||||||
|
|
@ -139,13 +139,13 @@ func TestJoinQuerySlice(t *testing.T) {
|
||||||
query := Film.
|
query := Film.
|
||||||
INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
|
INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
|
||||||
SELECT(Language.AllColumns, Film.AllColumns).
|
SELECT(Language.AllColumns, Film.AllColumns).
|
||||||
|
Where(Film.Rating.EqL(string(model.MpaaRating_NC17))).
|
||||||
Limit(15)
|
Limit(15)
|
||||||
|
|
||||||
queryStr, err := query.String()
|
queryStr, err := query.String()
|
||||||
|
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id LIMIT 15`)
|
assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id WHERE film.rating = 'NC-17' LIMIT 15`)
|
||||||
|
|
||||||
//fmt.Println(queryStr)
|
//fmt.Println(queryStr)
|
||||||
|
|
||||||
err = query.Execute(db, &filmsPerLanguage)
|
err = query.Execute(db, &filmsPerLanguage)
|
||||||
|
|
@ -158,7 +158,11 @@ func TestJoinQuerySlice(t *testing.T) {
|
||||||
//spew.Dump(filmsPerLanguage)
|
//spew.Dump(filmsPerLanguage)
|
||||||
|
|
||||||
assert.Equal(t, len(filmsPerLanguage), 1)
|
assert.Equal(t, len(filmsPerLanguage), 1)
|
||||||
assert.Equal(t, len(*filmsPerLanguage[0].Film), limit)
|
assert.Equal(t, len(filmsPerLanguage[0].Film), limit)
|
||||||
|
|
||||||
|
englishFilms := filmsPerLanguage[0]
|
||||||
|
|
||||||
|
assert.Equal(t, *englishFilms.Film[0].Rating, model.MpaaRating_NC17)
|
||||||
|
|
||||||
//spew.Dump(filmsPerLanguage)
|
//spew.Dump(filmsPerLanguage)
|
||||||
|
|
||||||
|
|
@ -167,7 +171,7 @@ func TestJoinQuerySlice(t *testing.T) {
|
||||||
|
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
assert.Equal(t, len(filmsPerLanguage), 1)
|
assert.Equal(t, len(filmsPerLanguage), 1)
|
||||||
assert.Equal(t, len(*filmsPerLanguage[0].Film), limit)
|
assert.Equal(t, len(filmsPerLanguage[0].Film), limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestJoinQuerySliceWithPtrs(t *testing.T) {
|
func TestJoinQuerySliceWithPtrs(t *testing.T) {
|
||||||
|
|
@ -492,6 +496,8 @@ func TestSelectQueryScalar(t *testing.T) {
|
||||||
|
|
||||||
assert.Equal(t, len(maxRentalRateFilms), 336)
|
assert.Equal(t, len(maxRentalRateFilms), 336)
|
||||||
|
|
||||||
|
gRating := model.MpaaRating_G
|
||||||
|
|
||||||
assert.DeepEqual(t, maxRentalRateFilms[0], model.Film{
|
assert.DeepEqual(t, maxRentalRateFilms[0], model.Film{
|
||||||
FilmID: 2,
|
FilmID: 2,
|
||||||
Title: "Ace Goldfinger",
|
Title: "Ace Goldfinger",
|
||||||
|
|
@ -501,7 +507,7 @@ func TestSelectQueryScalar(t *testing.T) {
|
||||||
RentalRate: 4.99,
|
RentalRate: 4.99,
|
||||||
Length: int16Ptr(48),
|
Length: int16Ptr(48),
|
||||||
ReplacementCost: 12.99,
|
ReplacementCost: 12.99,
|
||||||
Rating: stringPtr("G"),
|
Rating: &gRating,
|
||||||
RentalDuration: 3,
|
RentalDuration: 3,
|
||||||
LastUpdate: *timeWithoutTimeZone("2013-05-26 14:50:58.951", 3),
|
LastUpdate: *timeWithoutTimeZone("2013-05-26 14:50:58.951", 3),
|
||||||
SpecialFeatures: stringPtr("{Trailers,\"Deleted Scenes\"}"),
|
SpecialFeatures: stringPtr("{Trailers,\"Deleted Scenes\"}"),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue