Generator clean up refactoring.

This commit is contained in:
zer0sub 2019-05-24 13:13:13 +02:00
parent 7de8c1c45e
commit b3a52ceb31
16 changed files with 372 additions and 476 deletions

View file

@ -7,25 +7,52 @@ import (
"os" "os"
) )
var genDirPath string var (
var dbConnectionString string host string
var dbName string port string
var schemaName string user string
password string
sslmode string
params string
dbName string
schemaName string
destDir string
)
func init() { func init() {
flag.StringVar(&genDirPath, "path", "", "Destination for generated files.") flag.StringVar(&host, "host", "", "Database host path (Example: localhost)")
flag.StringVar(&dbConnectionString, "db", "", "Connection string to database server") flag.StringVar(&port, "port", "", "Database port")
flag.StringVar(&dbName, "dbName", "", "Name of the database") flag.StringVar(&user, "user", "", "Database user")
flag.StringVar(&password, "password", "", "The users password")
flag.StringVar(&sslmode, "sslmode", "disable", "Whether or not to use SSL")
flag.StringVar(&params, "params", "", "Additional connection string parameters.")
flag.StringVar(&dbName, "dbname", "", "name of the database")
flag.StringVar(&schemaName, "schema", "public", "Database schema name.") flag.StringVar(&schemaName, "schema", "public", "Database schema name.")
flag.StringVar(&destDir, "path", "", "Destination dir for generated files.")
flag.Parse() flag.Parse()
} }
func main() { func main() {
fmt.Println(genDirPath, dbConnectionString, dbName, schemaName) genData := generator.GeneratorData{
Host: host,
Port: port,
User: user,
Password: password,
SslMode: sslmode,
Params: params,
err := generator.Generate(genDirPath, dbConnectionString, dbName, schemaName) DbName: dbName,
SchemaName: schemaName,
}
fmt.Println(destDir, genData)
err := generator.Generate(destDir, genData)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())

View file

@ -1,58 +0,0 @@
package generator
import (
"github.com/sub0zero/go-sqlbuilder/generator/metadata"
"path/filepath"
)
func generateDataModel(databaseInfo *metadata.DatabaseInfo, dirPath string) error {
modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "model")
err := ensureDirPath(modelDirPath)
if err != nil {
return err
}
for _, tableInfo := range databaseInfo.TableInfos {
text, err := generateTemplate(DataModelTemplate, tableInfo)
if err != nil {
return err
}
err = saveGoFile(modelDirPath, tableInfo.Name, text)
if err != nil {
return err
}
}
return nil
}
func generateEnumTypes(databaseInfo *metadata.DatabaseInfo, dirPath string) error {
modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "model")
err := ensureDirPath(modelDirPath)
if err != nil {
return err
}
for _, enumInfo := range databaseInfo.EnumInfos {
text, err := generateTemplate(EnumModelTemplate, enumInfo)
if err != nil {
return err
}
err = saveGoFile(modelDirPath, enumInfo.Name, text)
if err != nil {
return err
}
}
return nil
}

View file

@ -2,28 +2,32 @@ package generator
import ( import (
"database/sql" "database/sql"
"fmt"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/sub0zero/go-sqlbuilder/generator/metadata" "github.com/sub0zero/go-sqlbuilder/generator/metadata"
"github.com/sub0zero/go-sqlbuilder/generator/postgres-metadata"
"path" "path"
"path/filepath"
) )
type DbConnectInfo struct { type GeneratorData struct {
host string Host string
port int Port string
user string User string
password string Password string
dbname string SslMode string
Params string
DbName string
SchemaName string
} }
func Generate(folderPath string, connectString string, databaseName, schemaName string) error { func Generate(destDir string, genData GeneratorData) error {
err := cleanUpGeneratedFiles(path.Join(folderPath, databaseName, schemaName)) connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s",
genData.Host, genData.Port, genData.User, genData.Password, genData.DbName, genData.SslMode, genData.Params)
if err != nil { db, err := sql.Open("postgres", connectionString)
return err
}
db, err := sql.Open("postgres", connectString)
if err != nil { if err != nil {
return err return err
} }
@ -35,25 +39,32 @@ func Generate(folderPath string, connectString string, databaseName, schemaName
return err return err
} }
databaseInfo, err := metadata.GetDatabaseInfo(db, databaseName, schemaName) err = cleanUpGeneratedFiles(path.Join(destDir, genData.DbName, genData.SchemaName))
if err != nil { if err != nil {
return err return err
} }
err = generateSqlBuilderModel(databaseInfo, folderPath) schemaInfo, err := postgres_metadata.GetSchemaInfo(db, genData.DbName, genData.SchemaName)
if err != nil { if err != nil {
return err return err
} }
err = generateDataModel(databaseInfo, folderPath) err = generate(schemaInfo, destDir, "table", sqlBuilderTableTemplate, schemaInfo.TableInfos)
if err != nil { if err != nil {
return err return err
} }
err = generateEnumTypes(databaseInfo, folderPath) //err = generateDataModel(schemaInfo, destDir)
err = generate(schemaInfo, destDir, "model", dataModelTemplate, schemaInfo.TableInfos)
if err != nil {
return err
}
err = generate(schemaInfo, destDir, "model", enumModelTemplate, schemaInfo.EnumInfos)
if err != nil { if err != nil {
return err return err
@ -61,3 +72,35 @@ func Generate(folderPath string, connectString string, databaseName, schemaName
return nil return nil
} }
func generate(schemaInfo postgres_metadata.SchemaInfo, dirPath, packageName string, template string, metaDataList []metadata.MetaData) error {
modelDirPath := filepath.Join(dirPath, schemaInfo.DatabaseName, schemaInfo.Name, packageName)
err := ensureDirPath(modelDirPath)
if err != nil {
return err
}
autoGenWarning, err := generateTemplate(autoGenWarningTemplate, nil)
if err != nil {
return err
}
for _, metaData := range metaDataList {
text, err := generateTemplate(template, metaData)
if err != nil {
return err
}
err = saveGoFile(modelDirPath, metaData.Name(), append(autoGenWarning, text...))
if err != nil {
return err
}
}
return nil
}

View file

@ -1,37 +0,0 @@
package metadata
import (
"database/sql"
)
type DatabaseInfo struct {
DatabaseName string
SchemaName string
TableInfos []TableInfo
EnumInfos []EnumInfo
}
func GetDatabaseInfo(db *sql.DB, databaseName, schemaName string) (*DatabaseInfo, error) {
databaseInfo := &DatabaseInfo{
databaseName,
schemaName,
[]TableInfo{},
[]EnumInfo{},
}
var err error
databaseInfo.TableInfos, err = fetchTableInfos(db, databaseInfo)
if err != nil {
return nil, err
}
databaseInfo.EnumInfos, err = fetchEnumInfos(db, databaseInfo)
if err != nil {
return nil, err
}
return databaseInfo, nil
}

View file

@ -0,0 +1,5 @@
package metadata
type MetaData interface {
Name() string
}

View file

@ -1,197 +0,0 @@
package metadata
import (
"database/sql"
"fmt"
"github.com/serenize/snaker"
"strings"
)
type TableInfo struct {
Name string
PrimaryKeys []string
ForeignTableMap map[string]string
Columns []ColumnInfo
DatabaseInfo *DatabaseInfo
}
func (t TableInfo) GetImports() []string {
imports := map[string]string{}
for _, column := range t.Columns {
columnType := column.GoBaseType()
switch columnType {
case "time.Time":
imports["time.Time"] = "time"
case "uuid.UUID":
imports["uuid.UUID"] = "github.com/google/uuid"
case "types.JSONText":
imports["types.JSONText"] = "github.com/sub0zero/go-sqlbuilder/types"
}
}
ret := []string{}
for _, packageImport := range imports {
ret = append(ret, packageImport)
}
return ret
}
func (t TableInfo) IsForeignKey(columnName string) bool {
_, exist := t.ForeignTableMap[columnName]
return exist
}
func (t TableInfo) ToGoModelStructName() string {
return snaker.SnakeToCamel(t.Name)
}
func (t TableInfo) ToGoVarName() string {
return snaker.SnakeToCamel(t.Name)
}
func (t TableInfo) ToGoStructName() string {
return snaker.SnakeToCamel(t.Name) + "Table"
}
func (t TableInfo) ToGoColumnFieldList(sep string) string {
columnNames := []string{}
for _, columnInfo := range t.Columns {
columnNames = append(columnNames, columnInfo.ToGoVarName())
}
return strings.Join(columnNames, sep)
}
func fetchTableInfos(db *sql.DB, databaseInfo *DatabaseInfo) ([]TableInfo, error) {
query := `
SELECT table_name
FROM information_schema.tables
where table_schema = $1 and table_type = 'BASE TABLE';`
//fmt.Println(query, schemaName)
rows, err := db.Query(query, &databaseInfo.SchemaName)
if err != nil {
return nil, err
}
defer rows.Close()
tableInfos := []TableInfo{}
for rows.Next() {
var tableName string
err = rows.Scan(&tableName)
if err != nil {
return nil, err
}
tableInfo := &TableInfo{}
tableInfo.Name = tableName
tableInfo.PrimaryKeys, err = getPrimaryKeys(db, databaseInfo.SchemaName, tableName)
if err != nil {
return nil, err
}
tableInfo.DatabaseInfo = databaseInfo
tableInfo.Columns, err = fetchColumnInfos(db, tableInfo)
if err != nil {
return nil, err
}
tableInfo.ForeignTableMap, err = getForignKeyMap(db, databaseInfo.SchemaName, tableName)
if err != nil {
return nil, err
}
tableInfos = append(tableInfos, *tableInfo)
}
fmt.Println("FOUND", len(tableInfos), "tables")
err = rows.Err()
if err != nil {
return nil, err
}
return tableInfos, nil
}
func getPrimaryKeys(db *sql.DB, schemaName, tableName string) ([]string, error) {
query := `
SELECT c.column_name
FROM information_schema.key_column_usage AS c
LEFT JOIN information_schema.table_constraints AS t
ON t.constraint_name = c.constraint_name
WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY';
`
rows, err := db.Query(query, schemaName, tableName)
if err != nil {
return nil, err
}
primaryKeys := []string{}
for rows.Next() {
primaryKey := ""
err := rows.Scan(&primaryKey)
if err != nil {
return nil, err
}
primaryKeys = append(primaryKeys, primaryKey)
}
return primaryKeys, nil
}
func getForignKeyMap(db *sql.DB, schemaName, tableName string) (map[string]string, error) {
query := `
SELECT
kcu.column_name,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name
FROM
information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_schema = $1 AND tc.table_name=$2;
`
rows, err := db.Query(query, schemaName, tableName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := map[string]string{}
for rows.Next() {
var columnName, foreignTableName, foreignColumnName string
err := rows.Scan(&columnName, &foreignTableName, &foreignColumnName)
if err != nil {
return nil, err
}
ret[columnName] = foreignTableName
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -1,4 +1,4 @@
package metadata package postgres_metadata
import ( import (
"database/sql" "database/sql"
@ -11,21 +11,6 @@ type ColumnInfo struct {
IsNullable bool IsNullable bool
DataType string DataType string
EnumName string EnumName string
TableInfo *TableInfo
}
func (c ColumnInfo) IsUnique() bool {
for _, uniqueColumn := range c.TableInfo.PrimaryKeys {
if uniqueColumn == c.Name {
return true
}
}
return false
}
func (c ColumnInfo) ToGoVarName() string {
return snaker.SnakeToCamel(c.Name) + "Column"
} }
func (c ColumnInfo) ToSqlBuilderColumnType() string { func (c ColumnInfo) ToSqlBuilderColumnType() string {
@ -52,15 +37,6 @@ func (c ColumnInfo) ToSqlBuilderColumnType() string {
} }
} }
func (c ColumnInfo) ToGoType() string {
typeStr := c.GoBaseType()
if c.IsNullable {
return "*" + typeStr
}
return typeStr
}
func (c ColumnInfo) GoBaseType() string { func (c ColumnInfo) GoBaseType() string {
switch c.DataType { switch c.DataType {
case "USER-DEFINED": case "USER-DEFINED":
@ -93,26 +69,32 @@ func (c ColumnInfo) GoBaseType() string {
} }
} }
func (c ColumnInfo) ToGoDMFieldName() string { func (c ColumnInfo) ToGoType() string {
return snaker.SnakeToCamel(c.Name) typeStr := c.GoBaseType()
if c.IsNullable {
return "*" + typeStr
}
return typeStr
} }
func (c ColumnInfo) ToGoFieldName() string { func (c ColumnInfo) ToGoFieldName() string {
return snaker.SnakeToCamel(c.Name) return snaker.SnakeToCamel(c.Name)
} }
func fetchColumnInfos(db *sql.DB, tableInfo *TableInfo) ([]ColumnInfo, error) { func (c ColumnInfo) ToGoVarName() string {
return snaker.SnakeToCamel(c.Name) + "Column"
}
func getColumnInfos(db *sql.DB, dbName, schemaName, tableName string) ([]ColumnInfo, error) {
query := ` query := `
SELECT column_name, is_nullable, data_type, udt_name SELECT column_name, is_nullable, data_type, udt_name
FROM information_schema.columns FROM information_schema.columns
where table_schema = $1 and table_name = $2 where table_catalog = $1 and table_schema = $2 and table_name = $3
order by ordinal_position;` order by ordinal_position;`
//fmt.Println(query) rows, err := db.Query(query, dbName, schemaName, tableName)
rows, err := db.Query(query, tableInfo.DatabaseInfo.SchemaName, &tableInfo.Name)
if err != nil { if err != nil {
return nil, err return nil, err
@ -132,8 +114,6 @@ order by ordinal_position;`
return nil, err return nil, err
} }
columnInfo.TableInfo = tableInfo
ret = append(ret, columnInfo) ret = append(ret, columnInfo)
} }

View file

@ -1,20 +1,21 @@
package metadata package postgres_metadata
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/sub0zero/go-sqlbuilder/generator/metadata"
) )
type EnumInfo struct { type EnumInfo struct {
Name string name string
Values []string Values []string
} }
func (e *EnumInfo) goValueName(index int) { func (e EnumInfo) Name() string {
return return e.name
} }
func fetchEnumInfos(db *sql.DB, databaseInfo *DatabaseInfo) ([]EnumInfo, error) { func getEnumInfos(db *sql.DB, schemaName string) ([]metadata.MetaData, error) {
query := ` query := `
SELECT t.typname, SELECT t.typname,
e.enumlabel e.enumlabel
@ -24,9 +25,7 @@ FROM pg_catalog.pg_type t
WHERE n.nspname = $1 WHERE n.nspname = $1
ORDER BY n.nspname, t.typname, e.enumsortorder;` ORDER BY n.nspname, t.typname, e.enumsortorder;`
//fmt.Println(query, schemaName) rows, err := db.Query(query, schemaName)
rows, err := db.Query(query, &databaseInfo.SchemaName)
if err != nil { if err != nil {
return nil, err return nil, err
@ -55,7 +54,7 @@ ORDER BY n.nspname, t.typname, e.enumsortorder;`
return nil, err return nil, err
} }
ret := []EnumInfo{} ret := []metadata.MetaData{}
for enumName, enumValues := range enumsInfosMap { for enumName, enumValues := range enumsInfosMap {
ret = append(ret, EnumInfo{ ret = append(ret, EnumInfo{

View file

@ -0,0 +1,76 @@
package postgres_metadata
import (
"database/sql"
"fmt"
"github.com/sub0zero/go-sqlbuilder/generator/metadata"
)
type SchemaInfo struct {
DatabaseName string
Name string
TableInfos []metadata.MetaData
EnumInfos []metadata.MetaData
}
func GetSchemaInfo(db *sql.DB, databaseName, schemaName string) (schemaInfo SchemaInfo, err error) {
schemaInfo.DatabaseName = databaseName
schemaInfo.Name = schemaName
schemaInfo.TableInfos, err = getTableInfos(db, databaseName, schemaName)
if err != nil {
return
}
schemaInfo.EnumInfos, err = getEnumInfos(db, schemaName)
if err != nil {
return
}
return
}
func getTableInfos(db *sql.DB, dbName, schemaName string) ([]metadata.MetaData, error) {
query := `
SELECT table_name
FROM information_schema.tables
where table_catalog = $1 and table_schema = $2 and table_type = 'BASE TABLE';`
rows, err := db.Query(query, dbName, schemaName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []metadata.MetaData{}
for rows.Next() {
var tableName string
err = rows.Scan(&tableName)
if err != nil {
return nil, err
}
tableInfo, err := GetTableInfo(db, dbName, schemaName, tableName)
if err != nil {
return nil, err
}
ret = append(ret, tableInfo)
}
fmt.Println("FOUND", len(ret), "tables")
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -0,0 +1,116 @@
package postgres_metadata
import (
"database/sql"
"github.com/serenize/snaker"
"strings"
)
type TableInfo struct {
SchemaName string
name string
PrimaryKeys map[string]bool
Columns []ColumnInfo
}
func (t TableInfo) Name() string {
return t.name
}
func (t TableInfo) IsUnique(columnName string) bool {
return t.PrimaryKeys[columnName]
}
func (t TableInfo) GetImports() []string {
imports := map[string]string{}
for _, column := range t.Columns {
columnType := column.GoBaseType()
switch columnType {
case "time.Time":
imports["time.Time"] = "time"
case "uuid.UUID":
imports["uuid.UUID"] = "github.com/google/uuid"
case "types.JSONText":
imports["types.JSONText"] = "github.com/sub0zero/go-sqlbuilder/types"
}
}
ret := []string{}
for _, packageImport := range imports {
ret = append(ret, packageImport)
}
return ret
}
func (t TableInfo) ToGoModelStructName() string {
return snaker.SnakeToCamel(t.name)
}
func (t TableInfo) ToGoVarName() string {
return snaker.SnakeToCamel(t.name)
}
func (t TableInfo) ToGoStructName() string {
return snaker.SnakeToCamel(t.name) + "Table"
}
func (t TableInfo) ToGoColumnFieldList(sep string) string {
columnNames := []string{}
for _, columnInfo := range t.Columns {
columnNames = append(columnNames, columnInfo.ToGoVarName())
}
return strings.Join(columnNames, sep)
}
func GetTableInfo(db *sql.DB, dbName, schemaName, tableName string) (tableInfo TableInfo, err error) {
tableInfo.SchemaName = schemaName
tableInfo.name = tableName
tableInfo.PrimaryKeys, err = getPrimaryKeys(db, dbName, schemaName, tableName)
if err != nil {
return
}
tableInfo.Columns, err = getColumnInfos(db, dbName, schemaName, tableName)
if err != nil {
return
}
return
}
func getPrimaryKeys(db *sql.DB, dbName, schemaName, tableName string) (map[string]bool, error) {
query := `
SELECT c.column_name
FROM information_schema.key_column_usage AS c
LEFT JOIN information_schema.table_constraints AS t
ON t.constraint_name = c.constraint_name
WHERE t.table_catalog = $1 AND t.table_schema = $2 AND t.table_name = $3 AND t.constraint_type = 'PRIMARY KEY';
`
rows, err := db.Query(query, dbName, schemaName, tableName)
if err != nil {
return nil, err
}
primaryKeyMap := map[string]bool{}
for rows.Next() {
primaryKey := ""
err := rows.Scan(&primaryKey)
if err != nil {
return nil, err
}
primaryKeyMap[primaryKey] = true
}
return primaryKeyMap, nil
}

View file

@ -1,32 +0,0 @@
package generator
import (
"github.com/sub0zero/go-sqlbuilder/generator/metadata"
"path/filepath"
)
func generateSqlBuilderModel(databaseInfo *metadata.DatabaseInfo, dirPath string) error {
modelDirPath := filepath.Join(dirPath, databaseInfo.DatabaseName, databaseInfo.SchemaName, "table")
err := ensureDirPath(modelDirPath)
if err != nil {
return err
}
for _, tableInfo := range databaseInfo.TableInfos {
text, err := generateTemplate(SqlBuilderTableTemplate, tableInfo)
if err != nil {
return err
}
err = saveGoFile(modelDirPath, tableInfo.Name+"_table", text)
if err != nil {
return err
}
}
return nil
}

View file

@ -1,21 +0,0 @@
package generator
//import (
// "gotest.tools/assert"
// "testing"
//)
//
//func TestGenerateSqlBuilderModel(t *testing.T) {
// table := TableInfo{
// "actor",
// []ColumnInfo{
// {"actor_id", false, "integer"},
// {"first_name", true, "character varying"},
// {"last_name", false, "timestamp without time zone"},
// },
// }
//
// err := generateSqlBuilderModel("dvd_rental", table, "../../sqlbuildertest")
//
// assert.NilError(t, err)
//}

View file

@ -1,6 +1,19 @@
package generator package generator
var SqlBuilderTableTemplate = `package table var autoGenWarningTemplate = `
//
// Code generated by sqlbuilder DO NOT EDIT.
// Generated at {{now}}
//
// WARNING: Changes to this file may cause incorrect behavior and will be lost
// if the code is regenerated
//
// Licence under ...
//
`
var sqlBuilderTableTemplate = `package table
import ( import (
"github.com/sub0zero/go-sqlbuilder/sqlbuilder" "github.com/sub0zero/go-sqlbuilder/sqlbuilder"
@ -27,7 +40,7 @@ func new{{.ToGoStructName}}() *{{.ToGoStructName}} {
) )
return &{{.ToGoStructName}}{ return &{{.ToGoStructName}}{
Table: *sqlbuilder.NewTable("{{.DatabaseInfo.SchemaName}}", "{{.Name}}", {{.ToGoColumnFieldList ", "}}), Table: *sqlbuilder.NewTable("{{.SchemaName}}", "{{.Name}}", {{.ToGoColumnFieldList ", "}}),
//Columns //Columns
{{- range .Columns}} {{- range .Columns}}
@ -49,7 +62,7 @@ func (a *{{.ToGoStructName}}) AS(alias string) *{{.ToGoStructName}} {
` `
var DataModelTemplate = `package model var dataModelTemplate = `package model
{{ if .GetImports }} {{ if .GetImports }}
import ( import (
@ -62,12 +75,12 @@ import (
type {{.ToGoModelStructName}} struct { type {{.ToGoModelStructName}} struct {
{{- range .Columns}} {{- range .Columns}}
{{.ToGoDMFieldName}} {{.ToGoType}} {{if .IsUnique}}` + "`sql:\"unique\"`" + ` {{end}} {{.ToGoFieldName}} {{.ToGoType}} {{if $.IsUnique .Name}}` + "`sql:\"unique\"`" + ` {{end}}
{{- end}} {{- end}}
} }
` `
var EnumModelTemplate = `package model var enumModelTemplate = `package model
import "errors" import "errors"

View file

@ -8,6 +8,7 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"text/template" "text/template"
"time"
) )
func saveGoFile(dirPath, fileName string, text []byte) error { func saveGoFile(dirPath, fileName string, text []byte) error {
@ -49,10 +50,13 @@ func ensureDirPath(dirPath string) error {
func generateTemplate(templateText string, templateData interface{}) ([]byte, error) { func generateTemplate(templateText string, templateData interface{}) ([]byte, error) {
t, err := template.New("SqlBuilderTableTemplate").Funcs(template.FuncMap{ t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
"camelize": func(txt string) string { "camelize": func(txt string) string {
return snaker.SnakeToCamel(strings.Replace(txt, "-", "_", -1)) return snaker.SnakeToCamel(strings.Replace(txt, "-", "_", -1))
}, },
"now": func() string {
return time.Now().Format(time.RFC850)
},
}).Parse(templateText) }).Parse(templateText)
if err != nil { if err != nil {

View file

@ -1,25 +1,3 @@
// A library for generating sql programmatically.
// //
// SQL COMPATIBILITY NOTE: sqlbuilder is designed to generate valid MySQL sql
// statements. The generated statements may not work for other sql variants.
// For instances, the generated statements does not currently work for
// PostgreSQL since column identifiers are escaped with backquotes.
// Patches to support other sql flavors are welcome! (see
// https://godropbox/issues/33 for additional details).
//
// Known limitations for SELECT queries:
// - does not support subqueries (since mysql is bad at it)
// - does not currently support join tableName alias (and hence self join)
// - does not support NATURAL joins and join USING
//
// Known limitation for INSERT statements:
// - does not support "INSERT INTO SELECT"
//
// Known limitation for UPDATE statements:
// - does not support update without a WHERE clause (since it is dangerous)
// - does not support multi-tableName update
//
// Known limitation for DELETE statements:
// - does not support delete without a WHERE clause (since it is dangerous)
// - does not support multi-tableName delete
package sqlbuilder package sqlbuilder

View file

@ -5,20 +5,19 @@ import (
"fmt" "fmt"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/pkg/profile" "github.com/pkg/profile"
"github.com/sub0zero/go-sqlbuilder/generator" "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model"
"gotest.tools/assert" "gotest.tools/assert"
"os" "os"
"reflect"
"testing" "testing"
) )
const ( const (
folderPath = ".test_files/" host = "localhost"
host = "localhost" port = 5432
port = 5432 user = "postgres"
user = "postgres" password = "postgres"
password = "postgres" dbname = "dvd_rental"
dbname = "dvd_rental"
schemaName = "dvds"
) )
var connectString = fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", host, port, user, password, dbname) var connectString = fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", host, port, user, password, dbname)
@ -26,8 +25,8 @@ var db *sql.DB
//var tx *sql.Tx //var tx *sql.Tx
//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema dvds -path .test_files //go:generate generator -host=localhost -port=5432 -user=postgres -password=postgres -dbname=dvd_rental -schema dvds -path .test_files
//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema test_sample -path .test_files //go:generate generator -host=localhost -port=5432 -user=postgres -password=postgres -dbname=dvd_rental -sslmode=disable -schema test_sample -path .test_files
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
fmt.Println("Begin") fmt.Println("Begin")
@ -110,31 +109,32 @@ VALUES
} }
func queryAll(t *testing.T, query string, args []interface{}) {
rows, err := db.Query(query, args...)
assert.NilError(t, err)
defer rows.Close()
for rows.Next() {
//err := rows.Scan(scanContext.row...)
//
//assert.NilError(t, err)
}
err = rows.Err()
assert.NilError(t, err)
}
func TestGenerateModel(t *testing.T) { func TestGenerateModel(t *testing.T) {
err := generator.Generate(folderPath, connectString, dbname, schemaName) actor := model.Actor{}
assert.NilError(t, err) assert.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32")
actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID")
assert.Assert(t, ok)
assert.Equal(t, actorIDField.Tag.Get("sql"), "unique")
assert.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string")
assert.Equal(t, reflect.TypeOf(actor.LastName).String(), "string")
assert.Equal(t, reflect.TypeOf(actor.LastUpdate).String(), "time.Time")
//err = generator.Generate(folderPath, connectString, dbname, "sport") filmActor := model.FilmActor{}
//
//assert.NilError(t, err) assert.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int16")
filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID")
assert.Assert(t, ok)
assert.Equal(t, filmIDField.Tag.Get("sql"), "unique")
assert.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int16")
actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID")
assert.Assert(t, ok)
assert.Equal(t, filmIDField.Tag.Get("sql"), "unique")
staff := model.Staff{}
assert.Equal(t, reflect.TypeOf(staff.Email).String(), "*string")
assert.Equal(t, reflect.TypeOf(staff.Picture).String(), "*[]uint8")
} }