Generator clean up refactoring.
This commit is contained in:
parent
7de8c1c45e
commit
b3a52ceb31
16 changed files with 372 additions and 476 deletions
|
|
@ -7,25 +7,52 @@ import (
|
|||
"os"
|
||||
)
|
||||
|
||||
var genDirPath string
|
||||
var dbConnectionString string
|
||||
var dbName string
|
||||
var schemaName string
|
||||
var (
|
||||
host string
|
||||
port string
|
||||
user string
|
||||
password string
|
||||
sslmode string
|
||||
params string
|
||||
dbName string
|
||||
schemaName string
|
||||
|
||||
destDir string
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&genDirPath, "path", "", "Destination for generated files.")
|
||||
flag.StringVar(&dbConnectionString, "db", "", "Connection string to database server")
|
||||
flag.StringVar(&dbName, "dbName", "", "Name of the database")
|
||||
flag.StringVar(&host, "host", "", "Database host path (Example: localhost)")
|
||||
flag.StringVar(&port, "port", "", "Database port")
|
||||
flag.StringVar(&user, "user", "", "Database user")
|
||||
flag.StringVar(&password, "password", "", "The user’s password")
|
||||
flag.StringVar(&sslmode, "sslmode", "disable", "Whether or not to use SSL")
|
||||
flag.StringVar(¶ms, "params", "", "Additional connection string parameters.")
|
||||
|
||||
flag.StringVar(&dbName, "dbname", "", "name of the database")
|
||||
flag.StringVar(&schemaName, "schema", "public", "Database schema name.")
|
||||
|
||||
flag.StringVar(&destDir, "path", "", "Destination dir for generated files.")
|
||||
|
||||
flag.Parse()
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
fmt.Println(genDirPath, dbConnectionString, dbName, schemaName)
|
||||
genData := generator.GeneratorData{
|
||||
Host: host,
|
||||
Port: port,
|
||||
User: user,
|
||||
Password: password,
|
||||
SslMode: sslmode,
|
||||
Params: params,
|
||||
|
||||
err := generator.Generate(genDirPath, dbConnectionString, dbName, schemaName)
|
||||
DbName: dbName,
|
||||
SchemaName: schemaName,
|
||||
}
|
||||
|
||||
fmt.Println(destDir, genData)
|
||||
|
||||
err := generator.Generate(destDir, genData)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
|
|
|
|||
|
|
@ -1,58 +0,0 @@
|
|||
package generator
|
||||
|
||||
import (
|
||||
"github.com/sub0zero/go-sqlbuilder/generator/metadata"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func generateDataModel(databaseInfo *metadata.DatabaseInfo, dirPath string) error {
|
||||
modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "model")
|
||||
|
||||
err := ensureDirPath(modelDirPath)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, tableInfo := range databaseInfo.TableInfos {
|
||||
text, err := generateTemplate(DataModelTemplate, tableInfo)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = saveGoFile(modelDirPath, tableInfo.Name, text)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateEnumTypes(databaseInfo *metadata.DatabaseInfo, dirPath string) error {
|
||||
modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "model")
|
||||
|
||||
err := ensureDirPath(modelDirPath)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, enumInfo := range databaseInfo.EnumInfos {
|
||||
text, err := generateTemplate(EnumModelTemplate, enumInfo)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = saveGoFile(modelDirPath, enumInfo.Name, text)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -2,28 +2,32 @@ package generator
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/sub0zero/go-sqlbuilder/generator/metadata"
|
||||
"github.com/sub0zero/go-sqlbuilder/generator/postgres-metadata"
|
||||
"path"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type DbConnectInfo struct {
|
||||
host string
|
||||
port int
|
||||
user string
|
||||
password string
|
||||
dbname string
|
||||
type GeneratorData struct {
|
||||
Host string
|
||||
Port string
|
||||
User string
|
||||
Password string
|
||||
SslMode string
|
||||
Params string
|
||||
|
||||
DbName string
|
||||
SchemaName string
|
||||
}
|
||||
|
||||
func Generate(folderPath string, connectString string, databaseName, schemaName string) error {
|
||||
func Generate(destDir string, genData GeneratorData) error {
|
||||
|
||||
err := cleanUpGeneratedFiles(path.Join(folderPath, databaseName, schemaName))
|
||||
connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s",
|
||||
genData.Host, genData.Port, genData.User, genData.Password, genData.DbName, genData.SslMode, genData.Params)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", connectString)
|
||||
db, err := sql.Open("postgres", connectionString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -35,25 +39,32 @@ func Generate(folderPath string, connectString string, databaseName, schemaName
|
|||
return err
|
||||
}
|
||||
|
||||
databaseInfo, err := metadata.GetDatabaseInfo(db, databaseName, schemaName)
|
||||
err = cleanUpGeneratedFiles(path.Join(destDir, genData.DbName, genData.SchemaName))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = generateSqlBuilderModel(databaseInfo, folderPath)
|
||||
schemaInfo, err := postgres_metadata.GetSchemaInfo(db, genData.DbName, genData.SchemaName)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = generateDataModel(databaseInfo, folderPath)
|
||||
err = generate(schemaInfo, destDir, "table", sqlBuilderTableTemplate, schemaInfo.TableInfos)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = generateEnumTypes(databaseInfo, folderPath)
|
||||
//err = generateDataModel(schemaInfo, destDir)
|
||||
err = generate(schemaInfo, destDir, "model", dataModelTemplate, schemaInfo.TableInfos)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = generate(schemaInfo, destDir, "model", enumModelTemplate, schemaInfo.EnumInfos)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -61,3 +72,35 @@ func Generate(folderPath string, connectString string, databaseName, schemaName
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generate(schemaInfo postgres_metadata.SchemaInfo, dirPath, packageName string, template string, metaDataList []metadata.MetaData) error {
|
||||
modelDirPath := filepath.Join(dirPath, schemaInfo.DatabaseName, schemaInfo.Name, packageName)
|
||||
|
||||
err := ensureDirPath(modelDirPath)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
autoGenWarning, err := generateTemplate(autoGenWarningTemplate, nil)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, metaData := range metaDataList {
|
||||
text, err := generateTemplate(template, metaData)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = saveGoFile(modelDirPath, metaData.Name(), append(autoGenWarning, text...))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,37 +0,0 @@
|
|||
package metadata
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type DatabaseInfo struct {
|
||||
DatabaseName string
|
||||
SchemaName string
|
||||
TableInfos []TableInfo
|
||||
EnumInfos []EnumInfo
|
||||
}
|
||||
|
||||
func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo, error) {
|
||||
|
||||
databaseInfo := &DatabaseInfo{
|
||||
databaseName,
|
||||
schemaName,
|
||||
[]TableInfo{},
|
||||
[]EnumInfo{},
|
||||
}
|
||||
|
||||
var err error
|
||||
databaseInfo.TableInfos, err = fetchTableInfos(db, databaseInfo)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
databaseInfo.EnumInfos, err = fetchEnumInfos(db, databaseInfo)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return databaseInfo, nil
|
||||
}
|
||||
5
generator/metadata/meta_data.go
Normal file
5
generator/metadata/meta_data.go
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
package metadata
|
||||
|
||||
type MetaData interface {
|
||||
Name() string
|
||||
}
|
||||
|
|
@ -1,197 +0,0 @@
|
|||
package metadata
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/serenize/snaker"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type TableInfo struct {
|
||||
Name string
|
||||
PrimaryKeys []string
|
||||
ForeignTableMap map[string]string
|
||||
Columns []ColumnInfo
|
||||
DatabaseInfo *DatabaseInfo
|
||||
}
|
||||
|
||||
func (t TableInfo) GetImports() []string {
|
||||
imports := map[string]string{}
|
||||
|
||||
for _, column := range t.Columns {
|
||||
columnType := column.GoBaseType()
|
||||
|
||||
switch columnType {
|
||||
case "time.Time":
|
||||
imports["time.Time"] = "time"
|
||||
case "uuid.UUID":
|
||||
imports["uuid.UUID"] = "github.com/google/uuid"
|
||||
case "types.JSONText":
|
||||
imports["types.JSONText"] = "github.com/sub0zero/go-sqlbuilder/types"
|
||||
}
|
||||
}
|
||||
|
||||
ret := []string{}
|
||||
|
||||
for _, packageImport := range imports {
|
||||
ret = append(ret, packageImport)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (t TableInfo) IsForeignKey(columnName string) bool {
|
||||
_, exist := t.ForeignTableMap[columnName]
|
||||
|
||||
return exist
|
||||
}
|
||||
|
||||
func (t TableInfo) ToGoModelStructName() string {
|
||||
return snaker.SnakeToCamel(t.Name)
|
||||
}
|
||||
|
||||
func (t TableInfo) ToGoVarName() string {
|
||||
return snaker.SnakeToCamel(t.Name)
|
||||
}
|
||||
|
||||
func (t TableInfo) ToGoStructName() string {
|
||||
return snaker.SnakeToCamel(t.Name) + "Table"
|
||||
}
|
||||
|
||||
func (t TableInfo) ToGoColumnFieldList(sep string) string {
|
||||
columnNames := []string{}
|
||||
for _, columnInfo := range t.Columns {
|
||||
columnNames = append(columnNames, columnInfo.ToGoVarName())
|
||||
}
|
||||
return strings.Join(columnNames, sep)
|
||||
}
|
||||
|
||||
func fetchTableInfos(db *sql.DB, databaseInfo *DatabaseInfo) ([]TableInfo, error) {
|
||||
query := `
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
where table_schema = $1 and table_type = 'BASE TABLE';`
|
||||
|
||||
//fmt.Println(query, schemaName)
|
||||
|
||||
rows, err := db.Query(query, &databaseInfo.SchemaName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
tableInfos := []TableInfo{}
|
||||
for rows.Next() {
|
||||
var tableName string
|
||||
err = rows.Scan(&tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tableInfo := &TableInfo{}
|
||||
tableInfo.Name = tableName
|
||||
tableInfo.PrimaryKeys, err = getPrimaryKeys(db, databaseInfo.SchemaName, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tableInfo.DatabaseInfo = databaseInfo
|
||||
tableInfo.Columns, err = fetchColumnInfos(db, tableInfo)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tableInfo.ForeignTableMap, err = getForignKeyMap(db, databaseInfo.SchemaName, tableName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tableInfos = append(tableInfos, *tableInfo)
|
||||
}
|
||||
|
||||
fmt.Println("FOUND", len(tableInfos), "tables")
|
||||
|
||||
err = rows.Err()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tableInfos, nil
|
||||
}
|
||||
|
||||
func getPrimaryKeys(db *sql.DB, schemaName, tableName string) ([]string, error) {
|
||||
query := `
|
||||
SELECT c.column_name
|
||||
FROM information_schema.key_column_usage AS c
|
||||
LEFT JOIN information_schema.table_constraints AS t
|
||||
ON t.constraint_name = c.constraint_name
|
||||
WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY';
|
||||
`
|
||||
rows, err := db.Query(query, schemaName, tableName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
primaryKeys := []string{}
|
||||
|
||||
for rows.Next() {
|
||||
primaryKey := ""
|
||||
err := rows.Scan(&primaryKey)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
primaryKeys = append(primaryKeys, primaryKey)
|
||||
}
|
||||
|
||||
return primaryKeys, nil
|
||||
}
|
||||
|
||||
func getForignKeyMap(db *sql.DB, schemaName, tableName string) (map[string]string, error) {
|
||||
query := `
|
||||
SELECT
|
||||
kcu.column_name,
|
||||
ccu.table_name AS foreign_table_name,
|
||||
ccu.column_name AS foreign_column_name
|
||||
FROM
|
||||
information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_schema = $1 AND tc.table_name=$2;
|
||||
`
|
||||
rows, err := db.Query(query, schemaName, tableName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
ret := map[string]string{}
|
||||
for rows.Next() {
|
||||
var columnName, foreignTableName, foreignColumnName string
|
||||
err := rows.Scan(&columnName, &foreignTableName, &foreignColumnName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret[columnName] = foreignTableName
|
||||
}
|
||||
|
||||
err = rows.Err()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package metadata
|
||||
package postgres_metadata
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
|
@ -11,21 +11,6 @@ type ColumnInfo struct {
|
|||
IsNullable bool
|
||||
DataType string
|
||||
EnumName string
|
||||
TableInfo *TableInfo
|
||||
}
|
||||
|
||||
func (c ColumnInfo) IsUnique() bool {
|
||||
for _, uniqueColumn := range c.TableInfo.PrimaryKeys {
|
||||
if uniqueColumn == c.Name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c ColumnInfo) ToGoVarName() string {
|
||||
return snaker.SnakeToCamel(c.Name) + "Column"
|
||||
}
|
||||
|
||||
func (c ColumnInfo) ToSqlBuilderColumnType() string {
|
||||
|
|
@ -52,15 +37,6 @@ func (c ColumnInfo) ToSqlBuilderColumnType() string {
|
|||
}
|
||||
}
|
||||
|
||||
func (c ColumnInfo) ToGoType() string {
|
||||
typeStr := c.GoBaseType()
|
||||
if c.IsNullable {
|
||||
return "*" + typeStr
|
||||
}
|
||||
|
||||
return typeStr
|
||||
}
|
||||
|
||||
func (c ColumnInfo) GoBaseType() string {
|
||||
switch c.DataType {
|
||||
case "USER-DEFINED":
|
||||
|
|
@ -93,26 +69,32 @@ func (c ColumnInfo) GoBaseType() string {
|
|||
}
|
||||
}
|
||||
|
||||
func (c ColumnInfo) ToGoDMFieldName() string {
|
||||
return snaker.SnakeToCamel(c.Name)
|
||||
func (c ColumnInfo) ToGoType() string {
|
||||
typeStr := c.GoBaseType()
|
||||
if c.IsNullable {
|
||||
return "*" + typeStr
|
||||
}
|
||||
|
||||
return typeStr
|
||||
}
|
||||
|
||||
func (c ColumnInfo) ToGoFieldName() string {
|
||||
return snaker.SnakeToCamel(c.Name)
|
||||
}
|
||||
|
||||
func fetchColumnInfos(db *sql.DB, tableInfo *TableInfo) ([]ColumnInfo, error) {
|
||||
func (c ColumnInfo) ToGoVarName() string {
|
||||
return snaker.SnakeToCamel(c.Name) + "Column"
|
||||
}
|
||||
|
||||
func getColumnInfos(db *sql.DB, dbName, schemaName, tableName string) ([]ColumnInfo, error) {
|
||||
|
||||
query := `
|
||||
SELECT column_name, is_nullable, data_type, udt_name
|
||||
FROM information_schema.columns
|
||||
where table_schema = $1 and table_name = $2
|
||||
where table_catalog = $1 and table_schema = $2 and table_name = $3
|
||||
order by ordinal_position;`
|
||||
|
||||
//fmt.Println(query)
|
||||
|
||||
rows, err := db.Query(query, tableInfo.DatabaseInfo.SchemaName, &tableInfo.Name)
|
||||
rows, err := db.Query(query, dbName, schemaName, tableName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -132,8 +114,6 @@ order by ordinal_position;`
|
|||
return nil, err
|
||||
}
|
||||
|
||||
columnInfo.TableInfo = tableInfo
|
||||
|
||||
ret = append(ret, columnInfo)
|
||||
}
|
||||
|
||||
|
|
@ -1,20 +1,21 @@
|
|||
package metadata
|
||||
package postgres_metadata
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/sub0zero/go-sqlbuilder/generator/metadata"
|
||||
)
|
||||
|
||||
type EnumInfo struct {
|
||||
Name string
|
||||
name string
|
||||
Values []string
|
||||
}
|
||||
|
||||
func (e *EnumInfo) goValueName(index int) {
|
||||
return
|
||||
func (e EnumInfo) Name() string {
|
||||
return e.name
|
||||
}
|
||||
|
||||
func fetchEnumInfos(db *sql.DB, databaseInfo *DatabaseInfo) ([]EnumInfo, error) {
|
||||
func getEnumInfos(db *sql.DB, schemaName string) ([]metadata.MetaData, error) {
|
||||
query := `
|
||||
SELECT t.typname,
|
||||
e.enumlabel
|
||||
|
|
@ -24,9 +25,7 @@ FROM pg_catalog.pg_type t
|
|||
WHERE n.nspname = $1
|
||||
ORDER BY n.nspname, t.typname, e.enumsortorder;`
|
||||
|
||||
//fmt.Println(query, schemaName)
|
||||
|
||||
rows, err := db.Query(query, &databaseInfo.SchemaName)
|
||||
rows, err := db.Query(query, schemaName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -55,7 +54,7 @@ ORDER BY n.nspname, t.typname, e.enumsortorder;`
|
|||
return nil, err
|
||||
}
|
||||
|
||||
ret := []EnumInfo{}
|
||||
ret := []metadata.MetaData{}
|
||||
|
||||
for enumName, enumValues := range enumsInfosMap {
|
||||
ret = append(ret, EnumInfo{
|
||||
76
generator/postgres-metadata/schema_info.go
Normal file
76
generator/postgres-metadata/schema_info.go
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
package postgres_metadata
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/sub0zero/go-sqlbuilder/generator/metadata"
|
||||
)
|
||||
|
||||
type SchemaInfo struct {
|
||||
DatabaseName string
|
||||
Name string
|
||||
TableInfos []metadata.MetaData
|
||||
EnumInfos []metadata.MetaData
|
||||
}
|
||||
|
||||
func GetSchemaInfo(db *sql.DB, databaseName, schemaName string) (schemaInfo SchemaInfo, err error) {
|
||||
|
||||
schemaInfo.DatabaseName = databaseName
|
||||
schemaInfo.Name = schemaName
|
||||
|
||||
schemaInfo.TableInfos, err = getTableInfos(db, databaseName, schemaName)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
schemaInfo.EnumInfos, err = getEnumInfos(db, schemaName)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getTableInfos(db *sql.DB, dbName, schemaName string) ([]metadata.MetaData, error) {
|
||||
|
||||
query := `
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
where table_catalog = $1 and table_schema = $2 and table_type = 'BASE TABLE';`
|
||||
|
||||
rows, err := db.Query(query, dbName, schemaName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
ret := []metadata.MetaData{}
|
||||
for rows.Next() {
|
||||
var tableName string
|
||||
err = rows.Scan(&tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tableInfo, err := GetTableInfo(db, dbName, schemaName, tableName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret = append(ret, tableInfo)
|
||||
}
|
||||
|
||||
fmt.Println("FOUND", len(ret), "tables")
|
||||
|
||||
err = rows.Err()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
116
generator/postgres-metadata/table_info.go
Normal file
116
generator/postgres-metadata/table_info.go
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
package postgres_metadata
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/serenize/snaker"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type TableInfo struct {
|
||||
SchemaName string
|
||||
name string
|
||||
PrimaryKeys map[string]bool
|
||||
Columns []ColumnInfo
|
||||
}
|
||||
|
||||
func (t TableInfo) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t TableInfo) IsUnique(columnName string) bool {
|
||||
return t.PrimaryKeys[columnName]
|
||||
}
|
||||
|
||||
func (t TableInfo) GetImports() []string {
|
||||
imports := map[string]string{}
|
||||
|
||||
for _, column := range t.Columns {
|
||||
columnType := column.GoBaseType()
|
||||
|
||||
switch columnType {
|
||||
case "time.Time":
|
||||
imports["time.Time"] = "time"
|
||||
case "uuid.UUID":
|
||||
imports["uuid.UUID"] = "github.com/google/uuid"
|
||||
case "types.JSONText":
|
||||
imports["types.JSONText"] = "github.com/sub0zero/go-sqlbuilder/types"
|
||||
}
|
||||
}
|
||||
|
||||
ret := []string{}
|
||||
|
||||
for _, packageImport := range imports {
|
||||
ret = append(ret, packageImport)
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (t TableInfo) ToGoModelStructName() string {
|
||||
return snaker.SnakeToCamel(t.name)
|
||||
}
|
||||
|
||||
func (t TableInfo) ToGoVarName() string {
|
||||
return snaker.SnakeToCamel(t.name)
|
||||
}
|
||||
|
||||
func (t TableInfo) ToGoStructName() string {
|
||||
return snaker.SnakeToCamel(t.name) + "Table"
|
||||
}
|
||||
|
||||
func (t TableInfo) ToGoColumnFieldList(sep string) string {
|
||||
columnNames := []string{}
|
||||
for _, columnInfo := range t.Columns {
|
||||
columnNames = append(columnNames, columnInfo.ToGoVarName())
|
||||
}
|
||||
return strings.Join(columnNames, sep)
|
||||
}
|
||||
|
||||
func GetTableInfo(db *sql.DB, dbName, schemaName, tableName string) (tableInfo TableInfo, err error) {
|
||||
|
||||
tableInfo.SchemaName = schemaName
|
||||
tableInfo.name = tableName
|
||||
|
||||
tableInfo.PrimaryKeys, err = getPrimaryKeys(db, dbName, schemaName, tableName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tableInfo.Columns, err = getColumnInfos(db, dbName, schemaName, tableName)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getPrimaryKeys(db *sql.DB, dbName, schemaName, tableName string) (map[string]bool, error) {
|
||||
query := `
|
||||
SELECT c.column_name
|
||||
FROM information_schema.key_column_usage AS c
|
||||
LEFT JOIN information_schema.table_constraints AS t
|
||||
ON t.constraint_name = c.constraint_name
|
||||
WHERE t.table_catalog = $1 AND t.table_schema = $2 AND t.table_name = $3 AND t.constraint_type = 'PRIMARY KEY';
|
||||
`
|
||||
rows, err := db.Query(query, dbName, schemaName, tableName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
primaryKeyMap := map[string]bool{}
|
||||
|
||||
for rows.Next() {
|
||||
primaryKey := ""
|
||||
err := rows.Scan(&primaryKey)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
primaryKeyMap[primaryKey] = true
|
||||
}
|
||||
|
||||
return primaryKeyMap, nil
|
||||
}
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
package generator
|
||||
|
||||
import (
|
||||
"github.com/sub0zero/go-sqlbuilder/generator/metadata"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func generateSqlBuilderModel(databaseInfo *metadata.DatabaseInfo, dirPath string) error {
|
||||
modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "table")
|
||||
|
||||
err := ensureDirPath(modelDirPath)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, tableInfo := range databaseInfo.TableInfos {
|
||||
text, err := generateTemplate(SqlBuilderTableTemplate, tableInfo)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = saveGoFile(modelDirPath, tableInfo.Name+"_table", text)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
package generator
|
||||
|
||||
//import (
|
||||
// "gotest.tools/assert"
|
||||
// "testing"
|
||||
//)
|
||||
//
|
||||
//func TestGenerateSqlBuilderModel(t *testing.T) {
|
||||
// table := TableInfo{
|
||||
// "actor",
|
||||
// []ColumnInfo{
|
||||
// {"actor_id", false, "integer"},
|
||||
// {"first_name", true, "character varying"},
|
||||
// {"last_name", false, "timestamp without time zone"},
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// err := generateSqlBuilderModel("dvd_rental", table, "../../sqlbuildertest")
|
||||
//
|
||||
// assert.NilError(t, err)
|
||||
//}
|
||||
|
|
@ -1,6 +1,19 @@
|
|||
package generator
|
||||
|
||||
var SqlBuilderTableTemplate = `package table
|
||||
var autoGenWarningTemplate = `
|
||||
//
|
||||
// Code generated by sqlbuilder DO NOT EDIT.
|
||||
// Generated at {{now}}
|
||||
//
|
||||
// WARNING: Changes to this file may cause incorrect behavior and will be lost
|
||||
// if the code is regenerated
|
||||
//
|
||||
// Licence under ...
|
||||
//
|
||||
|
||||
`
|
||||
|
||||
var sqlBuilderTableTemplate = `package table
|
||||
|
||||
import (
|
||||
"github.com/sub0zero/go-sqlbuilder/sqlbuilder"
|
||||
|
|
@ -27,7 +40,7 @@ func new{{.ToGoStructName}}() *{{.ToGoStructName}} {
|
|||
)
|
||||
|
||||
return &{{.ToGoStructName}}{
|
||||
Table: *sqlbuilder.NewTable("{{.DatabaseInfo.SchemaName}}", "{{.Name}}", {{.ToGoColumnFieldList ", "}}),
|
||||
Table: *sqlbuilder.NewTable("{{.SchemaName}}", "{{.Name}}", {{.ToGoColumnFieldList ", "}}),
|
||||
|
||||
//Columns
|
||||
{{- range .Columns}}
|
||||
|
|
@ -49,7 +62,7 @@ func (a *{{.ToGoStructName}}) AS(alias string) *{{.ToGoStructName}} {
|
|||
|
||||
`
|
||||
|
||||
var DataModelTemplate = `package model
|
||||
var dataModelTemplate = `package model
|
||||
|
||||
{{ if .GetImports }}
|
||||
import (
|
||||
|
|
@ -62,12 +75,12 @@ import (
|
|||
|
||||
type {{.ToGoModelStructName}} struct {
|
||||
{{- range .Columns}}
|
||||
{{.ToGoDMFieldName}} {{.ToGoType}} {{if .IsUnique}}` + "`sql:\"unique\"`" + ` {{end}}
|
||||
{{.ToGoFieldName}} {{.ToGoType}} {{if $.IsUnique .Name}}` + "`sql:\"unique\"`" + ` {{end}}
|
||||
{{- end}}
|
||||
}
|
||||
`
|
||||
|
||||
var EnumModelTemplate = `package model
|
||||
var enumModelTemplate = `package model
|
||||
|
||||
import "errors"
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"path/filepath"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
)
|
||||
|
||||
func saveGoFile(dirPath, fileName string, text []byte) error {
|
||||
|
|
@ -49,10 +50,13 @@ func ensureDirPath(dirPath string) error {
|
|||
|
||||
func generateTemplate(templateText string, templateData interface{}) ([]byte, error) {
|
||||
|
||||
t, err := template.New("SqlBuilderTableTemplate").Funcs(template.FuncMap{
|
||||
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
|
||||
"camelize": func(txt string) string {
|
||||
return snaker.SnakeToCamel(strings.Replace(txt, "-", "_", -1))
|
||||
},
|
||||
"now": func() string {
|
||||
return time.Now().Format(time.RFC850)
|
||||
},
|
||||
}).Parse(templateText)
|
||||
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -1,25 +1,3 @@
|
|||
// A library for generating sql programmatically.
|
||||
//
|
||||
// SQL COMPATIBILITY NOTE: sqlbuilder is designed to generate valid MySQL sql
|
||||
// statements. The generated statements may not work for other sql variants.
|
||||
// For instances, the generated statements does not currently work for
|
||||
// PostgreSQL since column identifiers are escaped with backquotes.
|
||||
// Patches to support other sql flavors are welcome! (see
|
||||
// https://godropbox/issues/33 for additional details).
|
||||
//
|
||||
// Known limitations for SELECT queries:
|
||||
// - does not support subqueries (since mysql is bad at it)
|
||||
// - does not currently support join tableName alias (and hence self join)
|
||||
// - does not support NATURAL joins and join USING
|
||||
//
|
||||
// Known limitation for INSERT statements:
|
||||
// - does not support "INSERT INTO SELECT"
|
||||
//
|
||||
// Known limitation for UPDATE statements:
|
||||
// - does not support update without a WHERE clause (since it is dangerous)
|
||||
// - does not support multi-tableName update
|
||||
//
|
||||
// Known limitation for DELETE statements:
|
||||
// - does not support delete without a WHERE clause (since it is dangerous)
|
||||
// - does not support multi-tableName delete
|
||||
|
||||
package sqlbuilder
|
||||
|
|
|
|||
|
|
@ -5,20 +5,19 @@ import (
|
|||
"fmt"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/pkg/profile"
|
||||
"github.com/sub0zero/go-sqlbuilder/generator"
|
||||
"github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model"
|
||||
"gotest.tools/assert"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
folderPath = ".test_files/"
|
||||
host = "localhost"
|
||||
port = 5432
|
||||
user = "postgres"
|
||||
password = "postgres"
|
||||
dbname = "dvd_rental"
|
||||
schemaName = "dvds"
|
||||
)
|
||||
|
||||
var connectString = fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", host, port, user, password, dbname)
|
||||
|
|
@ -26,8 +25,8 @@ var db *sql.DB
|
|||
|
||||
//var tx *sql.Tx
|
||||
|
||||
//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema dvds -path .test_files
|
||||
//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema test_sample -path .test_files
|
||||
//go:generate generator -host=localhost -port=5432 -user=postgres -password=postgres -dbname=dvd_rental -schema dvds -path .test_files
|
||||
//go:generate generator -host=localhost -port=5432 -user=postgres -password=postgres -dbname=dvd_rental -sslmode=disable -schema test_sample -path .test_files
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
fmt.Println("Begin")
|
||||
|
|
@ -110,31 +109,32 @@ VALUES
|
|||
|
||||
}
|
||||
|
||||
func queryAll(t *testing.T, query string, args []interface{}) {
|
||||
rows, err := db.Query(query, args...)
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
//err := rows.Scan(scanContext.row...)
|
||||
//
|
||||
//assert.NilError(t, err)
|
||||
}
|
||||
|
||||
err = rows.Err()
|
||||
|
||||
assert.NilError(t, err)
|
||||
}
|
||||
|
||||
func TestGenerateModel(t *testing.T) {
|
||||
|
||||
err := generator.Generate(folderPath, connectString, dbname, schemaName)
|
||||
actor := model.Actor{}
|
||||
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32")
|
||||
actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID")
|
||||
assert.Assert(t, ok)
|
||||
assert.Equal(t, actorIDField.Tag.Get("sql"), "unique")
|
||||
assert.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string")
|
||||
assert.Equal(t, reflect.TypeOf(actor.LastName).String(), "string")
|
||||
assert.Equal(t, reflect.TypeOf(actor.LastUpdate).String(), "time.Time")
|
||||
|
||||
//err = generator.Generate(folderPath, connectString, dbname, "sport")
|
||||
//
|
||||
//assert.NilError(t, err)
|
||||
filmActor := model.FilmActor{}
|
||||
|
||||
assert.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int16")
|
||||
filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID")
|
||||
assert.Assert(t, ok)
|
||||
assert.Equal(t, filmIDField.Tag.Get("sql"), "unique")
|
||||
|
||||
assert.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int16")
|
||||
actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID")
|
||||
assert.Assert(t, ok)
|
||||
assert.Equal(t, filmIDField.Tag.Get("sql"), "unique")
|
||||
|
||||
staff := model.Staff{}
|
||||
|
||||
assert.Equal(t, reflect.TypeOf(staff.Email).String(), "*string")
|
||||
assert.Equal(t, reflect.TypeOf(staff.Picture).String(), "*[]uint8")
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue