From 319c9f757d060af8968fb8f7918c4911c5255a3e Mon Sep 17 00:00:00 2001 From: sub0Zero Date: Mon, 4 Mar 2019 19:35:49 +0100 Subject: [PATCH] Data model generator for postgres database. --- generator/datamodel_generator.go | 32 ++++++ generator/generator.go | 137 +++------------------- generator/metadata/column_info.go | 116 +++++++++++++++++++ generator/metadata/database_info.go | 29 +++++ generator/metadata/table_info.go | 172 ++++++++++++++++++++++++++++ generator/sqlbuilder_generator.go | 91 +++------------ generator/templates.go | 29 +++-- generator/utils.go | 96 ++++++++++++++++ tests/generator_test.go | 16 ++- 9 files changed, 505 insertions(+), 213 deletions(-) create mode 100644 generator/datamodel_generator.go create mode 100644 generator/metadata/column_info.go create mode 100644 generator/metadata/database_info.go create mode 100644 generator/metadata/table_info.go create mode 100644 generator/utils.go diff --git a/generator/datamodel_generator.go b/generator/datamodel_generator.go new file mode 100644 index 0000000..230841f --- /dev/null +++ b/generator/datamodel_generator.go @@ -0,0 +1,32 @@ +package generator + +import ( + "github.com/sub0Zero/go-sqlbuilder/generator/metadata" + "path/filepath" +) + +func generateDataModel(databaseInfo *metadata.DatabaseInfo, dirPath string) error { + modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "model") + + err := ensureDirPath(modelDirPath) + + if err != nil { + return err + } + + for _, tableInfo := range databaseInfo.TableInfos { + text, err := generateTemplate(DataModelTemplate, tableInfo) + + if err != nil { + return err + } + + err = saveGoFile(modelDirPath, tableInfo.Name, text) + + if err != nil { + return err + } + } + + return nil +} diff --git a/generator/generator.go b/generator/generator.go index 624a77b..53b7450 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -3,8 +3,8 @@ package generator import ( "database/sql" _ "github.com/lib/pq" - "github.com/serenize/snaker" - "os" + "github.com/sub0Zero/go-sqlbuilder/generator/metadata" + "path" ) type DbConnectInfo struct { @@ -16,12 +16,10 @@ type DbConnectInfo struct { } func Generate(folderPath string, connectString string, databaseName, schemaName string) error { - if _, err := os.Stat(folderPath); os.IsNotExist(err) { - err := os.Mkdir(folderPath, os.ModePerm) + err := cleanUpGeneratedFiles(path.Join(folderPath, databaseName, schemaName)) - if err != nil { - return err - } + if err != nil { + return err } db, err := sql.Open("postgres", connectString) @@ -36,128 +34,23 @@ func Generate(folderPath string, connectString string, databaseName, schemaName return err } - tables, err := getTablesInfo(db, schemaName) + databaseInfo, err := metadata.GetDatabaseInfo(db, databaseName, schemaName) if err != nil { return err } - for _, table := range tables { - err = generateSqlBuilderModel(databaseName, schemaName, table, folderPath) + err = generateSqlBuilderModel(databaseInfo, folderPath) - if err != nil { - return err - } + if err != nil { + return err + } + + err = generateDataModel(databaseInfo, folderPath) + + if err != nil { + return err } return nil } - -type TableInfo struct { - Name string - Columns []ColumnInfo -} - -func getTablesInfo(db *sql.DB, schemaName string) ([]TableInfo, error) { - tableNames, err := getListOfTables(db, schemaName) - - if err != nil { - return nil, err - } - - tables := []TableInfo{} - for _, tableName := range tableNames { - columns, err := getColumnInfos(db, tableName) - - if err != nil { - return nil, err - } - - tables = append(tables, TableInfo{tableName, columns}) - } - - return tables, nil -} - -func getListOfTables(db *sql.DB, schemaName string) ([]string, error) { - - rows, err := db.Query(` -SELECT table_name FROM information_schema.tables -where table_schema = $1 and table_type = 'BASE TABLE';`, schemaName) - - if err != nil { - return nil, err - } - defer rows.Close() - - tables := []string{} - for rows.Next() { - var table string - err = rows.Scan(&table) - if err != nil { - return nil, err - } - - tables = append(tables, table) - } - - err = rows.Err() - - if err != nil { - return nil, err - } - - return tables, nil -} - -type ColumnInfo struct { - Name string - IsNullable bool - DataType string -} - -func (c *ColumnInfo) CamelCaseName() string { - return snaker.SnakeToCamel(c.Name) -} - -func getColumnInfos(db *sql.DB, tableName string) ([]ColumnInfo, error) { - - query := ` -SELECT column_name, is_nullable, data_type -FROM information_schema.columns -where table_name = $1 -order by ordinal_position;` - - //fmt.Println(query) - - rows, err := db.Query(query, &tableName) - - if err != nil { - return nil, err - } - defer rows.Close() - - ret := []ColumnInfo{} - - for rows.Next() { - columnInfo := ColumnInfo{} - var isNullable string - err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType) - - columnInfo.IsNullable = isNullable == "YES" - - if err != nil { - return nil, err - } - - ret = append(ret, columnInfo) - } - - err = rows.Err() - - if err != nil { - return nil, err - } - - return ret, nil -} diff --git a/generator/metadata/column_info.go b/generator/metadata/column_info.go new file mode 100644 index 0000000..84dd754 --- /dev/null +++ b/generator/metadata/column_info.go @@ -0,0 +1,116 @@ +package metadata + +import ( + "database/sql" + "github.com/serenize/snaker" +) + +type ColumnInfo struct { + Name string + IsNullable bool + DataType string + TableInfo *TableInfo +} + +func (c ColumnInfo) IsUnique() bool { + for _, uniqueColumn := range c.TableInfo.PrimaryKeys { + if uniqueColumn == c.Name { + return true + } + } + + return false +} + +func (c ColumnInfo) ToGoVarName() string { + return snaker.SnakeToCamelLower(c.TableInfo.Name) + snaker.SnakeToCamel(c.Name) + "Column" +} + +func (c ColumnInfo) ToGoType() string { + typeStr := c.GoBaseType() + if c.IsNullable { + return "*" + typeStr + } + + return typeStr +} + +func (c ColumnInfo) GoBaseType() string { + if forignKeyTable, ok := c.TableInfo.ForeignTableMap[c.Name]; ok { + return snaker.SnakeToCamel(forignKeyTable) + } else { + switch c.DataType { + case "boolean": + return "bool" + case "smallint": + return "int16" + case "integer": + return "int" + case "bigint": + return "int64" + //case "date" : return "time.Time" + case "bytea": + return "[]byte" + case "text": + return "string" + default: + return "string" + } + } +} + +func (c ColumnInfo) ToGoDMFieldName() string { + if forignKeyTable, ok := c.TableInfo.ForeignTableMap[c.Name]; ok { + return snaker.SnakeToCamel(forignKeyTable) + } else { + return snaker.SnakeToCamel(c.Name) + } +} + +func (c ColumnInfo) ToGoFieldName() string { + return snaker.SnakeToCamel(c.Name) +} + +func fetchColumnInfos(db *sql.DB, tableInfo *TableInfo) ([]ColumnInfo, error) { + + query := ` +SELECT column_name, is_nullable, data_type +FROM information_schema.columns +where table_schema = $1 and table_name = $2 +order by ordinal_position;` + + //fmt.Println(query) + + rows, err := db.Query(query, tableInfo.DatabaseInfo.SchemaName, &tableInfo.Name) + + if err != nil { + return nil, err + } + defer rows.Close() + + ret := []ColumnInfo{} + + for rows.Next() { + columnInfo := ColumnInfo{} + var isNullable string + err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType) + + columnInfo.IsNullable = isNullable == "YES" + + if err != nil { + return nil, err + } + + columnInfo.TableInfo = tableInfo + + ret = append(ret, columnInfo) + } + + err = rows.Err() + + if err != nil { + return nil, err + } + + return ret, nil +} diff --git a/generator/metadata/database_info.go b/generator/metadata/database_info.go new file mode 100644 index 0000000..658df3e --- /dev/null +++ b/generator/metadata/database_info.go @@ -0,0 +1,29 @@ +package metadata + +import ( + "database/sql" +) + +type DatabaseInfo struct { + DatabaseName string + SchemaName string + TableInfos []TableInfo +} + +func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo, error) { + + databaseInfo := &DatabaseInfo{ + databaseName, + schemaName, + []TableInfo{}, + } + + var err error + databaseInfo.TableInfos, err = fetchTableInfos(db, databaseInfo) + + if err != nil { + return nil, err + } + + return databaseInfo, nil +} diff --git a/generator/metadata/table_info.go b/generator/metadata/table_info.go new file mode 100644 index 0000000..b667948 --- /dev/null +++ b/generator/metadata/table_info.go @@ -0,0 +1,172 @@ +package metadata + +import ( + "database/sql" + "fmt" + "github.com/serenize/snaker" + "strings" +) + +type TableInfo struct { + Name string + PrimaryKeys []string + ForeignTableMap map[string]string + Columns []ColumnInfo + DatabaseInfo *DatabaseInfo +} + +func (t TableInfo) IsForeignKey(columnName string) bool { + _, exist := t.ForeignTableMap[columnName] + + return exist +} + +func (t TableInfo) ToGoModelStructName() string { + return snaker.SnakeToCamel(t.Name) +} + +func (t TableInfo) ToGoVarName() string { + return snaker.SnakeToCamel(t.Name) +} + +func (t TableInfo) ToGoStructName() string { + return snaker.SnakeToCamel(t.Name) + "Table" +} + +func (t TableInfo) ToGoColumnFieldList(sep string) string { + columnNames := []string{} + for _, columnInfo := range t.Columns { + columnNames = append(columnNames, columnInfo.ToGoVarName()) + } + return strings.Join(columnNames, sep) +} + +func fetchTableInfos(db *sql.DB, databaseInfo *DatabaseInfo) ([]TableInfo, error) { + query := ` +SELECT table_name +FROM information_schema.tables +where table_schema = $1 and table_type = 'BASE TABLE';` + + //fmt.Println(query, schemaName) + + rows, err := db.Query(query, &databaseInfo.SchemaName) + + if err != nil { + return nil, err + } + defer rows.Close() + + tableInfos := []TableInfo{} + for rows.Next() { + var tableName string + err = rows.Scan(&tableName) + if err != nil { + return nil, err + } + + tableInfo := &TableInfo{} + tableInfo.Name = tableName + tableInfo.PrimaryKeys, err = getPrimaryKeys(db, databaseInfo.SchemaName, tableName) + if err != nil { + return nil, err + } + tableInfo.DatabaseInfo = databaseInfo + tableInfo.Columns, err = fetchColumnInfos(db, tableInfo) + + if err != nil { + return nil, err + } + + tableInfo.ForeignTableMap, err = getForignKeyMap(db, databaseInfo.SchemaName, tableName) + + if err != nil { + return nil, err + } + + tableInfos = append(tableInfos, *tableInfo) + } + + fmt.Println("FOUND", len(tableInfos), "tables") + + err = rows.Err() + + if err != nil { + return nil, err + } + + return tableInfos, nil +} + +func getPrimaryKeys(db *sql.DB, schemaName, tableName string) ([]string, error) { + query := ` +SELECT c.column_name +FROM information_schema.key_column_usage AS c +LEFT JOIN information_schema.table_constraints AS t +ON t.constraint_name = c.constraint_name +WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY'; +` + rows, err := db.Query(query, schemaName, tableName) + + if err != nil { + return nil, err + } + + primaryKeys := []string{} + + for rows.Next() { + primaryKey := "" + err := rows.Scan(&primaryKey) + + if err != nil { + return nil, err + } + + primaryKeys = append(primaryKeys, primaryKey) + } + + return primaryKeys, nil +} + +func getForignKeyMap(db *sql.DB, schemaName, tableName string) (map[string]string, error) { + query := ` +SELECT + kcu.column_name, + ccu.table_name AS foreign_table_name, + ccu.column_name AS foreign_column_name +FROM + information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema +WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_schema = $1 AND tc.table_name=$2; +` + rows, err := db.Query(query, schemaName, tableName) + + if err != nil { + return nil, err + } + defer rows.Close() + + ret := map[string]string{} + for rows.Next() { + var columnName, foreignTableName, foreignColumnName string + err := rows.Scan(&columnName, &foreignTableName, &foreignColumnName) + + if err != nil { + return nil, err + } + + ret[columnName] = foreignTableName + } + + err = rows.Err() + + if err != nil { + return nil, err + } + + return ret, nil +} diff --git a/generator/sqlbuilder_generator.go b/generator/sqlbuilder_generator.go index e35ae17..051de05 100644 --- a/generator/sqlbuilder_generator.go +++ b/generator/sqlbuilder_generator.go @@ -1,91 +1,32 @@ package generator import ( - "bytes" - "github.com/serenize/snaker" - "go/format" - "os" + "github.com/sub0Zero/go-sqlbuilder/generator/metadata" "path/filepath" - "strings" - "text/template" ) -func generateSqlBuilderModel(databaseName, schemaName string, tableInfo TableInfo, dirPath string) error { +func generateSqlBuilderModel(databaseInfo *metadata.DatabaseInfo, dirPath string) error { + modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "table") - schemaDirPath := filepath.Join(dirPath, databaseName, schemaName, "table") + err := ensureDirPath(modelDirPath) - if _, err := os.Stat(schemaDirPath); os.IsNotExist(err) { - err := os.MkdirAll(schemaDirPath, os.ModePerm) + if err != nil { + return err + } + + for _, tableInfo := range databaseInfo.TableInfos { + text, err := generateTemplate(SqlBuilderTableTemplate, tableInfo) + + if err != nil { + return err + } + + err = saveGoFile(modelDirPath, tableInfo.Name+"_table", text) if err != nil { return err } } - t, err := template.New("TableTemplate").Funcs(template.FuncMap{ - "camelize": func(txt string) string { - return snaker.SnakeToCamel(txt) - }, - "columnName": columnName, - }).Parse(TableTemplate) - - if err != nil { - return err - } - - newGoFilePath := filepath.Join(schemaDirPath, tableInfo.Name) + ".go" - - file, err := os.Create(newGoFilePath) - - if err != nil { - return err - } - - defer file.Close() - - tableTemplate := TableTemplateData{ - databaseName, - tableInfo, - } - - //err = t.Execute(file, &tableTemplate) - // - //if err != nil { - // return err - //} - - var buf bytes.Buffer - if err := t.Execute(&buf, &tableTemplate); err != nil { - return err - } - p, err := format.Source(buf.Bytes()) - if err != nil { - return err - } - - _, err = file.Write(p) - - if err != nil { - return err - } - return nil } - -type TableTemplateData struct { - PackageName string - TableInfo TableInfo -} - -func columnName(table, column string) string { - return snaker.SnakeToCamelLower(table) + snaker.SnakeToCamel(column) + "Column" -} - -func (t *TableTemplateData) ColumnNameList(sep string) string { - columnNames := []string{} - for _, columnInfo := range t.TableInfo.Columns { - columnInfoName := columnInfo.Name - columnNames = append(columnNames, columnName(t.TableInfo.Name, columnInfoName)) - } - return strings.Join(columnNames, sep) -} diff --git a/generator/templates.go b/generator/templates.go index cec76b4..626f1d8 100644 --- a/generator/templates.go +++ b/generator/templates.go @@ -1,30 +1,39 @@ package generator -var TableTemplate = `package table +var SqlBuilderTableTemplate = `package table import "github.com/sub0Zero/go-sqlbuilder/sqlbuilder" -type {{camelize .TableInfo.Name}}Table struct { +type {{.ToGoStructName}} struct { sqlbuilder.Table //Columns -{{- range .TableInfo.Columns}} - {{camelize .Name}} sqlbuilder.NonAliasColumn +{{- range .Columns}} + {{.ToGoFieldName}} sqlbuilder.NonAliasColumn {{- end}} } -var {{camelize .TableInfo.Name}} = &{{camelize .TableInfo.Name}}Table{ - Table: *sqlbuilder.NewTable("{{.TableInfo.Name}}", {{.ColumnNameList ", "}}), +var {{.ToGoVarName}} = &{{.ToGoStructName}}{ + Table: *sqlbuilder.NewTable("{{.Name}}", {{.ToGoColumnFieldList ", "}}), //Columns -{{- range .TableInfo.Columns}} - {{camelize .Name}}: {{columnName $.TableInfo.Name .Name}}, +{{- range .Columns}} + {{.ToGoFieldName}}: {{.ToGoVarName}}, {{- end}} } var ( -{{- range .TableInfo.Columns}} - {{columnName $.TableInfo.Name .Name}} = sqlbuilder.IntColumn("{{.Name}}", {{if .IsNullable}}sqlbuilder.Nullable{{else}}sqlbuilder.NotNullable{{end}}) +{{- range .Columns}} + {{.ToGoVarName}} = sqlbuilder.IntColumn("{{.Name}}", {{if .IsNullable}}sqlbuilder.Nullable{{else}}sqlbuilder.NotNullable{{end}}) {{- end}} ) ` + +var DataModelTemplate = `package model + +type {{.ToGoModelStructName}} struct { +{{- range .Columns}} + {{.ToGoDMFieldName}} {{.ToGoType}} {{if .IsUnique}}` + "`sql:\"unique\"`" + ` {{end}} +{{- end}} +} +` diff --git a/generator/utils.go b/generator/utils.go new file mode 100644 index 0000000..95742a6 --- /dev/null +++ b/generator/utils.go @@ -0,0 +1,96 @@ +package generator + +import ( + "bytes" + "github.com/serenize/snaker" + "go/format" + "os" + "path/filepath" + "text/template" +) + +func saveGoFile(dirPath, fileName string, text []byte) error { + newGoFilePath := filepath.Join(dirPath, fileName) + ".go" + + file, err := os.Create(newGoFilePath) + + if err != nil { + return err + } + + defer file.Close() + + p, err := format.Source(text) + if err != nil { + return err + } + + _, err = file.Write(p) + + if err != nil { + return err + } + + return nil +} + +func ensureDirPath(dirPath string) error { + if _, err := os.Stat(dirPath); os.IsNotExist(err) { + err := os.MkdirAll(dirPath, os.ModePerm) + + if err != nil { + return err + } + } + + return nil +} + +func generateTemplate(templateText string, templateData interface{}) ([]byte, error) { + + t, err := template.New("SqlBuilderTableTemplate").Funcs(template.FuncMap{ + "camelize": func(txt string) string { + return snaker.SnakeToCamel(txt) + }, + }).Parse(templateText) + + if err != nil { + return nil, err + } + + var buf bytes.Buffer + if err := t.Execute(&buf, templateData); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func cleanUpGeneratedFiles(dir string) error { + exist, err := dirExists(dir) + + if err != nil { + return err + } + + if exist { + err := os.RemoveAll(dir) + + if err != nil { + return err + } + } + + return nil +} + +func dirExists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return true, err +} diff --git a/tests/generator_test.go b/tests/generator_test.go index 4ff1cb1..3bf31a7 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -2,9 +2,9 @@ package tests import ( "fmt" + . "github.com/sub0Zero/.test_files/dvd_rental/dvds/table" "github.com/sub0Zero/go-sqlbuilder/generator" . "github.com/sub0Zero/go-sqlbuilder/sqlbuilder" - "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/public/table" "gotest.tools/assert" "testing" ) @@ -16,10 +16,10 @@ var ( user = "postgres" password = "postgres" dbname = "dvd_rental" - schemaName = "public" + schemaName = "dvds" ) -//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema public -path .test_files +//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema dvds -path .test_files func TestGenerateModel(t *testing.T) { connectString := fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", @@ -28,13 +28,17 @@ func TestGenerateModel(t *testing.T) { err := generator.Generate(folderPath, connectString, dbname, schemaName) assert.NilError(t, err) + + err = generator.Generate(folderPath, connectString, dbname, "sport") + + assert.NilError(t, err) } func TestSelectQuery(t *testing.T) { - query, err := table.Actor.InnerJoinOn(table.Store, Eq(table.Actor.ActorID, table.Store.StoreID)). - Select(table.Store.StoreID, table.Store.AddressID, table.Actor.ActorID).String(schemaName) + query, err := Actor.InnerJoinOn(Store, Eq(Actor.ActorID, Store.StoreID)). + Select(Store.StoreID, Store.AddressID, Actor.ActorID).String(schemaName) assert.NilError(t, err) - assert.Equal(t, query, "SELECT store.store_id,store.address_id,actor.actor_id FROM public.actor JOIN public.store ON actor.actor_id=store.store_id") + assert.Equal(t, query, "SELECT store.store_id,store.address_id,actor.actor_id FROM dvds.actor JOIN dvds.store ON actor.actor_id=store.store_id") }