Data model generator for postgres database.

This commit is contained in:
sub0Zero 2019-03-04 19:35:49 +01:00 committed by zer0sub
parent 92edc96c9a
commit 319c9f757d
9 changed files with 505 additions and 213 deletions

View file

@ -0,0 +1,32 @@
package generator
import (
"github.com/sub0Zero/go-sqlbuilder/generator/metadata"
"path/filepath"
)
func generateDataModel(databaseInfo *metadata.DatabaseInfo, dirPath string) error {
modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "model")
err := ensureDirPath(modelDirPath)
if err != nil {
return err
}
for _, tableInfo := range databaseInfo.TableInfos {
text, err := generateTemplate(DataModelTemplate, tableInfo)
if err != nil {
return err
}
err = saveGoFile(modelDirPath, tableInfo.Name, text)
if err != nil {
return err
}
}
return nil
}

View file

@ -3,8 +3,8 @@ package generator
import (
"database/sql"
_ "github.com/lib/pq"
"github.com/serenize/snaker"
"os"
"github.com/sub0Zero/go-sqlbuilder/generator/metadata"
"path"
)
type DbConnectInfo struct {
@ -16,12 +16,10 @@ type DbConnectInfo struct {
}
func Generate(folderPath string, connectString string, databaseName, schemaName string) error {
if _, err := os.Stat(folderPath); os.IsNotExist(err) {
err := os.Mkdir(folderPath, os.ModePerm)
err := cleanUpGeneratedFiles(path.Join(folderPath, databaseName, schemaName))
if err != nil {
return err
}
if err != nil {
return err
}
db, err := sql.Open("postgres", connectString)
@ -36,128 +34,23 @@ func Generate(folderPath string, connectString string, databaseName, schemaName
return err
}
tables, err := getTablesInfo(db, schemaName)
databaseInfo, err := metadata.GetDatabaseInfo(db, databaseName, schemaName)
if err != nil {
return err
}
for _, table := range tables {
err = generateSqlBuilderModel(databaseName, schemaName, table, folderPath)
err = generateSqlBuilderModel(databaseInfo, folderPath)
if err != nil {
return err
}
if err != nil {
return err
}
err = generateDataModel(databaseInfo, folderPath)
if err != nil {
return err
}
return nil
}
type TableInfo struct {
Name string
Columns []ColumnInfo
}
func getTablesInfo(db *sql.DB, schemaName string) ([]TableInfo, error) {
tableNames, err := getListOfTables(db, schemaName)
if err != nil {
return nil, err
}
tables := []TableInfo{}
for _, tableName := range tableNames {
columns, err := getColumnInfos(db, tableName)
if err != nil {
return nil, err
}
tables = append(tables, TableInfo{tableName, columns})
}
return tables, nil
}
func getListOfTables(db *sql.DB, schemaName string) ([]string, error) {
rows, err := db.Query(`
SELECT table_name FROM information_schema.tables
where table_schema = $1 and table_type = 'BASE TABLE';`, schemaName)
if err != nil {
return nil, err
}
defer rows.Close()
tables := []string{}
for rows.Next() {
var table string
err = rows.Scan(&table)
if err != nil {
return nil, err
}
tables = append(tables, table)
}
err = rows.Err()
if err != nil {
return nil, err
}
return tables, nil
}
type ColumnInfo struct {
Name string
IsNullable bool
DataType string
}
func (c *ColumnInfo) CamelCaseName() string {
return snaker.SnakeToCamel(c.Name)
}
func getColumnInfos(db *sql.DB, tableName string) ([]ColumnInfo, error) {
query := `
SELECT column_name, is_nullable, data_type
FROM information_schema.columns
where table_name = $1
order by ordinal_position;`
//fmt.Println(query)
rows, err := db.Query(query, &tableName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []ColumnInfo{}
for rows.Next() {
columnInfo := ColumnInfo{}
var isNullable string
err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType)
columnInfo.IsNullable = isNullable == "YES"
if err != nil {
return nil, err
}
ret = append(ret, columnInfo)
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -0,0 +1,116 @@
package metadata
import (
"database/sql"
"github.com/serenize/snaker"
)
type ColumnInfo struct {
Name string
IsNullable bool
DataType string
TableInfo *TableInfo
}
func (c ColumnInfo) IsUnique() bool {
for _, uniqueColumn := range c.TableInfo.PrimaryKeys {
if uniqueColumn == c.Name {
return true
}
}
return false
}
func (c ColumnInfo) ToGoVarName() string {
return snaker.SnakeToCamelLower(c.TableInfo.Name) + snaker.SnakeToCamel(c.Name) + "Column"
}
func (c ColumnInfo) ToGoType() string {
typeStr := c.GoBaseType()
if c.IsNullable {
return "*" + typeStr
}
return typeStr
}
func (c ColumnInfo) GoBaseType() string {
if forignKeyTable, ok := c.TableInfo.ForeignTableMap[c.Name]; ok {
return snaker.SnakeToCamel(forignKeyTable)
} else {
switch c.DataType {
case "boolean":
return "bool"
case "smallint":
return "int16"
case "integer":
return "int"
case "bigint":
return "int64"
//case "date" : return "time.Time"
case "bytea":
return "[]byte"
case "text":
return "string"
default:
return "string"
}
}
}
func (c ColumnInfo) ToGoDMFieldName() string {
if forignKeyTable, ok := c.TableInfo.ForeignTableMap[c.Name]; ok {
return snaker.SnakeToCamel(forignKeyTable)
} else {
return snaker.SnakeToCamel(c.Name)
}
}
func (c ColumnInfo) ToGoFieldName() string {
return snaker.SnakeToCamel(c.Name)
}
func fetchColumnInfos(db *sql.DB, tableInfo *TableInfo) ([]ColumnInfo, error) {
query := `
SELECT column_name, is_nullable, data_type
FROM information_schema.columns
where table_schema = $1 and table_name = $2
order by ordinal_position;`
//fmt.Println(query)
rows, err := db.Query(query, tableInfo.DatabaseInfo.SchemaName, &tableInfo.Name)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []ColumnInfo{}
for rows.Next() {
columnInfo := ColumnInfo{}
var isNullable string
err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType)
columnInfo.IsNullable = isNullable == "YES"
if err != nil {
return nil, err
}
columnInfo.TableInfo = tableInfo
ret = append(ret, columnInfo)
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -0,0 +1,29 @@
package metadata
import (
"database/sql"
)
type DatabaseInfo struct {
DatabaseName string
SchemaName string
TableInfos []TableInfo
}
func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo, error) {
databaseInfo := &DatabaseInfo{
databaseName,
schemaName,
[]TableInfo{},
}
var err error
databaseInfo.TableInfos, err = fetchTableInfos(db, databaseInfo)
if err != nil {
return nil, err
}
return databaseInfo, nil
}

View file

@ -0,0 +1,172 @@
package metadata
import (
"database/sql"
"fmt"
"github.com/serenize/snaker"
"strings"
)
type TableInfo struct {
Name string
PrimaryKeys []string
ForeignTableMap map[string]string
Columns []ColumnInfo
DatabaseInfo *DatabaseInfo
}
func (t TableInfo) IsForeignKey(columnName string) bool {
_, exist := t.ForeignTableMap[columnName]
return exist
}
func (t TableInfo) ToGoModelStructName() string {
return snaker.SnakeToCamel(t.Name)
}
func (t TableInfo) ToGoVarName() string {
return snaker.SnakeToCamel(t.Name)
}
func (t TableInfo) ToGoStructName() string {
return snaker.SnakeToCamel(t.Name) + "Table"
}
func (t TableInfo) ToGoColumnFieldList(sep string) string {
columnNames := []string{}
for _, columnInfo := range t.Columns {
columnNames = append(columnNames, columnInfo.ToGoVarName())
}
return strings.Join(columnNames, sep)
}
func fetchTableInfos(db *sql.DB, databaseInfo *DatabaseInfo) ([]TableInfo, error) {
query := `
SELECT table_name
FROM information_schema.tables
where table_schema = $1 and table_type = 'BASE TABLE';`
//fmt.Println(query, schemaName)
rows, err := db.Query(query, &databaseInfo.SchemaName)
if err != nil {
return nil, err
}
defer rows.Close()
tableInfos := []TableInfo{}
for rows.Next() {
var tableName string
err = rows.Scan(&tableName)
if err != nil {
return nil, err
}
tableInfo := &TableInfo{}
tableInfo.Name = tableName
tableInfo.PrimaryKeys, err = getPrimaryKeys(db, databaseInfo.SchemaName, tableName)
if err != nil {
return nil, err
}
tableInfo.DatabaseInfo = databaseInfo
tableInfo.Columns, err = fetchColumnInfos(db, tableInfo)
if err != nil {
return nil, err
}
tableInfo.ForeignTableMap, err = getForignKeyMap(db, databaseInfo.SchemaName, tableName)
if err != nil {
return nil, err
}
tableInfos = append(tableInfos, *tableInfo)
}
fmt.Println("FOUND", len(tableInfos), "tables")
err = rows.Err()
if err != nil {
return nil, err
}
return tableInfos, nil
}
func getPrimaryKeys(db *sql.DB, schemaName, tableName string) ([]string, error) {
query := `
SELECT c.column_name
FROM information_schema.key_column_usage AS c
LEFT JOIN information_schema.table_constraints AS t
ON t.constraint_name = c.constraint_name
WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY';
`
rows, err := db.Query(query, schemaName, tableName)
if err != nil {
return nil, err
}
primaryKeys := []string{}
for rows.Next() {
primaryKey := ""
err := rows.Scan(&primaryKey)
if err != nil {
return nil, err
}
primaryKeys = append(primaryKeys, primaryKey)
}
return primaryKeys, nil
}
func getForignKeyMap(db *sql.DB, schemaName, tableName string) (map[string]string, error) {
query := `
SELECT
kcu.column_name,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM
information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_schema = $1 AND tc.table_name=$2;
`
rows, err := db.Query(query, schemaName, tableName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := map[string]string{}
for rows.Next() {
var columnName, foreignTableName, foreignColumnName string
err := rows.Scan(&columnName, &foreignTableName, &foreignColumnName)
if err != nil {
return nil, err
}
ret[columnName] = foreignTableName
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -1,91 +1,32 @@
package generator
import (
"bytes"
"github.com/serenize/snaker"
"go/format"
"os"
"github.com/sub0Zero/go-sqlbuilder/generator/metadata"
"path/filepath"
"strings"
"text/template"
)
func generateSqlBuilderModel(databaseName, schemaName string, tableInfo TableInfo, dirPath string) error {
func generateSqlBuilderModel(databaseInfo *metadata.DatabaseInfo, dirPath string) error {
modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "table")
schemaDirPath := filepath.Join(dirPath, databaseName, schemaName, "table")
err := ensureDirPath(modelDirPath)
if _, err := os.Stat(schemaDirPath); os.IsNotExist(err) {
err := os.MkdirAll(schemaDirPath, os.ModePerm)
if err != nil {
return err
}
for _, tableInfo := range databaseInfo.TableInfos {
text, err := generateTemplate(SqlBuilderTableTemplate, tableInfo)
if err != nil {
return err
}
err = saveGoFile(modelDirPath, tableInfo.Name+"_table", text)
if err != nil {
return err
}
}
t, err := template.New("TableTemplate").Funcs(template.FuncMap{
"camelize": func(txt string) string {
return snaker.SnakeToCamel(txt)
},
"columnName": columnName,
}).Parse(TableTemplate)
if err != nil {
return err
}
newGoFilePath := filepath.Join(schemaDirPath, tableInfo.Name) + ".go"
file, err := os.Create(newGoFilePath)
if err != nil {
return err
}
defer file.Close()
tableTemplate := TableTemplateData{
databaseName,
tableInfo,
}
//err = t.Execute(file, &tableTemplate)
//
//if err != nil {
// return err
//}
var buf bytes.Buffer
if err := t.Execute(&buf, &tableTemplate); err != nil {
return err
}
p, err := format.Source(buf.Bytes())
if err != nil {
return err
}
_, err = file.Write(p)
if err != nil {
return err
}
return nil
}
type TableTemplateData struct {
PackageName string
TableInfo TableInfo
}
func columnName(table, column string) string {
return snaker.SnakeToCamelLower(table) + snaker.SnakeToCamel(column) + "Column"
}
func (t *TableTemplateData) ColumnNameList(sep string) string {
columnNames := []string{}
for _, columnInfo := range t.TableInfo.Columns {
columnInfoName := columnInfo.Name
columnNames = append(columnNames, columnName(t.TableInfo.Name, columnInfoName))
}
return strings.Join(columnNames, sep)
}

View file

@ -1,30 +1,39 @@
package generator
var TableTemplate = `package table
var SqlBuilderTableTemplate = `package table
import "github.com/sub0Zero/go-sqlbuilder/sqlbuilder"
type {{camelize .TableInfo.Name}}Table struct {
type {{.ToGoStructName}} struct {
sqlbuilder.Table
//Columns
{{- range .TableInfo.Columns}}
{{camelize .Name}} sqlbuilder.NonAliasColumn
{{- range .Columns}}
{{.ToGoFieldName}} sqlbuilder.NonAliasColumn
{{- end}}
}
var {{camelize .TableInfo.Name}} = &{{camelize .TableInfo.Name}}Table{
Table: *sqlbuilder.NewTable("{{.TableInfo.Name}}", {{.ColumnNameList ", "}}),
var {{.ToGoVarName}} = &{{.ToGoStructName}}{
Table: *sqlbuilder.NewTable("{{.Name}}", {{.ToGoColumnFieldList ", "}}),
//Columns
{{- range .TableInfo.Columns}}
{{camelize .Name}}: {{columnName $.TableInfo.Name .Name}},
{{- range .Columns}}
{{.ToGoFieldName}}: {{.ToGoVarName}},
{{- end}}
}
var (
{{- range .TableInfo.Columns}}
{{columnName $.TableInfo.Name .Name}} = sqlbuilder.IntColumn("{{.Name}}", {{if .IsNullable}}sqlbuilder.Nullable{{else}}sqlbuilder.NotNullable{{end}})
{{- range .Columns}}
{{.ToGoVarName}} = sqlbuilder.IntColumn("{{.Name}}", {{if .IsNullable}}sqlbuilder.Nullable{{else}}sqlbuilder.NotNullable{{end}})
{{- end}}
)
`
var DataModelTemplate = `package model
type {{.ToGoModelStructName}} struct {
{{- range .Columns}}
{{.ToGoDMFieldName}} {{.ToGoType}} {{if .IsUnique}}` + "`sql:\"unique\"`" + ` {{end}}
{{- end}}
}
`

96
generator/utils.go Normal file
View file

@ -0,0 +1,96 @@
package generator
import (
"bytes"
"github.com/serenize/snaker"
"go/format"
"os"
"path/filepath"
"text/template"
)
func saveGoFile(dirPath, fileName string, text []byte) error {
newGoFilePath := filepath.Join(dirPath, fileName) + ".go"
file, err := os.Create(newGoFilePath)
if err != nil {
return err
}
defer file.Close()
p, err := format.Source(text)
if err != nil {
return err
}
_, err = file.Write(p)
if err != nil {
return err
}
return nil
}
func ensureDirPath(dirPath string) error {
if _, err := os.Stat(dirPath); os.IsNotExist(err) {
err := os.MkdirAll(dirPath, os.ModePerm)
if err != nil {
return err
}
}
return nil
}
func generateTemplate(templateText string, templateData interface{}) ([]byte, error) {
t, err := template.New("SqlBuilderTableTemplate").Funcs(template.FuncMap{
"camelize": func(txt string) string {
return snaker.SnakeToCamel(txt)
},
}).Parse(templateText)
if err != nil {
return nil, err
}
var buf bytes.Buffer
if err := t.Execute(&buf, templateData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func cleanUpGeneratedFiles(dir string) error {
exist, err := dirExists(dir)
if err != nil {
return err
}
if exist {
err := os.RemoveAll(dir)
if err != nil {
return err
}
}
return nil
}
func dirExists(path string) (bool, error) {
_, err := os.Stat(path)
if err == nil {
return true, nil
}
if os.IsNotExist(err) {
return false, nil
}
return true, err
}