Sql builder generator for postgres database.

This commit is contained in:
sub0Zero 2019-03-03 17:54:43 +01:00 committed by zer0sub
parent 3190d6f933
commit 92edc96c9a
10 changed files with 403 additions and 9 deletions

18
.gitignore vendored Normal file
View file

@ -0,0 +1,18 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, build with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Idea
.idea
# Test files
.test_files

36
cmd/generator/main.go Normal file
View file

@ -0,0 +1,36 @@
package main
import (
"flag"
"fmt"
"github.com/sub0Zero/go-sqlbuilder/generator"
"os"
)
var genDirPath string
var dbConnectionString string
var dbName string
var schemaName string
func init() {
flag.StringVar(&genDirPath, "path", "", "Destination for generated files.")
flag.StringVar(&dbConnectionString, "db", "", "Connection string to database server")
flag.StringVar(&dbName, "dbName", "", "Name of the database")
flag.StringVar(&schemaName, "schema", "public", "Database schema name.")
flag.Parse()
}
func main() {
fmt.Println(genDirPath, dbConnectionString, dbName, schemaName)
err := generator.Generate(genDirPath, dbConnectionString, dbName, schemaName)
if err != nil {
fmt.Println(err.Error())
os.Exit(-1)
}
fmt.Println("SUCCESS")
}

163
generator/generator.go Normal file
View file

@ -0,0 +1,163 @@
package generator
import (
"database/sql"
_ "github.com/lib/pq"
"github.com/serenize/snaker"
"os"
)
type DbConnectInfo struct {
host string
port int
user string
password string
dbname string
}
func Generate(folderPath string, connectString string, databaseName, schemaName string) error {
if _, err := os.Stat(folderPath); os.IsNotExist(err) {
err := os.Mkdir(folderPath, os.ModePerm)
if err != nil {
return err
}
}
db, err := sql.Open("postgres", connectString)
if err != nil {
return err
}
defer db.Close()
err = db.Ping()
if err != nil {
return err
}
tables, err := getTablesInfo(db, schemaName)
if err != nil {
return err
}
for _, table := range tables {
err = generateSqlBuilderModel(databaseName, schemaName, table, folderPath)
if err != nil {
return err
}
}
return nil
}
type TableInfo struct {
Name string
Columns []ColumnInfo
}
func getTablesInfo(db *sql.DB, schemaName string) ([]TableInfo, error) {
tableNames, err := getListOfTables(db, schemaName)
if err != nil {
return nil, err
}
tables := []TableInfo{}
for _, tableName := range tableNames {
columns, err := getColumnInfos(db, tableName)
if err != nil {
return nil, err
}
tables = append(tables, TableInfo{tableName, columns})
}
return tables, nil
}
func getListOfTables(db *sql.DB, schemaName string) ([]string, error) {
rows, err := db.Query(`
SELECT table_name FROM information_schema.tables
where table_schema = $1 and table_type = 'BASE TABLE';`, schemaName)
if err != nil {
return nil, err
}
defer rows.Close()
tables := []string{}
for rows.Next() {
var table string
err = rows.Scan(&table)
if err != nil {
return nil, err
}
tables = append(tables, table)
}
err = rows.Err()
if err != nil {
return nil, err
}
return tables, nil
}
type ColumnInfo struct {
Name string
IsNullable bool
DataType string
}
func (c *ColumnInfo) CamelCaseName() string {
return snaker.SnakeToCamel(c.Name)
}
func getColumnInfos(db *sql.DB, tableName string) ([]ColumnInfo, error) {
query := `
SELECT column_name, is_nullable, data_type
FROM information_schema.columns
where table_name = $1
order by ordinal_position;`
//fmt.Println(query)
rows, err := db.Query(query, &tableName)
if err != nil {
return nil, err
}
defer rows.Close()
ret := []ColumnInfo{}
for rows.Next() {
columnInfo := ColumnInfo{}
var isNullable string
err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType)
columnInfo.IsNullable = isNullable == "YES"
if err != nil {
return nil, err
}
ret = append(ret, columnInfo)
}
err = rows.Err()
if err != nil {
return nil, err
}
return ret, nil
}

View file

@ -0,0 +1,91 @@
package generator
import (
"bytes"
"github.com/serenize/snaker"
"go/format"
"os"
"path/filepath"
"strings"
"text/template"
)
func generateSqlBuilderModel(databaseName, schemaName string, tableInfo TableInfo, dirPath string) error {
schemaDirPath := filepath.Join(dirPath, databaseName, schemaName, "table")
if _, err := os.Stat(schemaDirPath); os.IsNotExist(err) {
err := os.MkdirAll(schemaDirPath, os.ModePerm)
if err != nil {
return err
}
}
t, err := template.New("TableTemplate").Funcs(template.FuncMap{
"camelize": func(txt string) string {
return snaker.SnakeToCamel(txt)
},
"columnName": columnName,
}).Parse(TableTemplate)
if err != nil {
return err
}
newGoFilePath := filepath.Join(schemaDirPath, tableInfo.Name) + ".go"
file, err := os.Create(newGoFilePath)
if err != nil {
return err
}
defer file.Close()
tableTemplate := TableTemplateData{
databaseName,
tableInfo,
}
//err = t.Execute(file, &tableTemplate)
//
//if err != nil {
// return err
//}
var buf bytes.Buffer
if err := t.Execute(&buf, &tableTemplate); err != nil {
return err
}
p, err := format.Source(buf.Bytes())
if err != nil {
return err
}
_, err = file.Write(p)
if err != nil {
return err
}
return nil
}
type TableTemplateData struct {
PackageName string
TableInfo TableInfo
}
func columnName(table, column string) string {
return snaker.SnakeToCamelLower(table) + snaker.SnakeToCamel(column) + "Column"
}
func (t *TableTemplateData) ColumnNameList(sep string) string {
columnNames := []string{}
for _, columnInfo := range t.TableInfo.Columns {
columnInfoName := columnInfo.Name
columnNames = append(columnNames, columnName(t.TableInfo.Name, columnInfoName))
}
return strings.Join(columnNames, sep)
}

View file

@ -0,0 +1,21 @@
package generator
//import (
// "gotest.tools/assert"
// "testing"
//)
//
//func TestGenerateSqlBuilderModel(t *testing.T) {
// table := TableInfo{
// "actor",
// []ColumnInfo{
// {"actor_id", false, "integer"},
// {"first_name", true, "character varying"},
// {"last_name", false, "timestamp without time zone"},
// },
// }
//
// err := generateSqlBuilderModel("dvd_rental", table, "../../sqlbuildertest")
//
// assert.NilError(t, err)
//}

30
generator/templates.go Normal file
View file

@ -0,0 +1,30 @@
package generator
var TableTemplate = `package table
import "github.com/sub0Zero/go-sqlbuilder/sqlbuilder"
type {{camelize .TableInfo.Name}}Table struct {
sqlbuilder.Table
//Columns
{{- range .TableInfo.Columns}}
{{camelize .Name}} sqlbuilder.NonAliasColumn
{{- end}}
}
var {{camelize .TableInfo.Name}} = &{{camelize .TableInfo.Name}}Table{
Table: *sqlbuilder.NewTable("{{.TableInfo.Name}}", {{.ColumnNameList ", "}}),
//Columns
{{- range .TableInfo.Columns}}
{{camelize .Name}}: {{columnName $.TableInfo.Name .Name}},
{{- end}}
}
var (
{{- range .TableInfo.Columns}}
{{columnName $.TableInfo.Name .Name}} = sqlbuilder.IntColumn("{{.Name}}", {{if .IsNullable}}sqlbuilder.Nullable{{else}}sqlbuilder.NotNullable{{end}})
{{- end}}
)
`

View file

@ -75,13 +75,10 @@ func (c *baseColumn) setTableName(table string) error {
func (c *baseColumn) SerializeSqlForColumnList(out *bytes.Buffer) error {
if c.table != "" {
_ = out.WriteByte('`')
_, _ = out.WriteString(c.table)
_, _ = out.WriteString("`.")
_, _ = out.WriteString(".")
}
_, _ = out.WriteString("`")
_, _ = out.WriteString(c.name)
_ = out.WriteByte('`')
return nil
}
@ -147,7 +144,7 @@ func DateTimeColumn(name string, nullable NullableColumn) NonAliasColumn {
return dc
}
type integerColumn struct {
type IntegerColumn struct {
baseColumn
isExpression
}
@ -158,7 +155,7 @@ func IntColumn(name string, nullable NullableColumn) NonAliasColumn {
if !validIdentifierName(name) {
panic("Invalid column name in int column")
}
ic := &integerColumn{}
ic := &IntegerColumn{}
ic.name = name
ic.nullable = nullable
return ic

View file

@ -131,11 +131,9 @@ func (t *Table) ForceIndex(index string) *Table {
// Generates the sql string for the current table expression. Note: the
// generated string may not be a valid/executable sql statement.
func (t *Table) SerializeSql(database string, out *bytes.Buffer) error {
_, _ = out.WriteString("`")
_, _ = out.WriteString(database)
_, _ = out.WriteString("`.`")
_, _ = out.WriteString(".")
_, _ = out.WriteString(t.Name())
_, _ = out.WriteString("`")
if t.forcedIndex != "" {
if !validIdentifierName(t.forcedIndex) {

BIN
tests/dvdrental.tar Normal file

Binary file not shown.

40
tests/generator_test.go Normal file
View file

@ -0,0 +1,40 @@
package tests
import (
"fmt"
"github.com/sub0Zero/go-sqlbuilder/generator"
. "github.com/sub0Zero/go-sqlbuilder/sqlbuilder"
"github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/public/table"
"gotest.tools/assert"
"testing"
)
var (
folderPath = ".test_files/"
host = "localhost"
port = 5432
user = "postgres"
password = "postgres"
dbname = "dvd_rental"
schemaName = "public"
)
//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema public -path .test_files
func TestGenerateModel(t *testing.T) {
connectString := fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable",
host, port, user, password, dbname)
err := generator.Generate(folderPath, connectString, dbname, schemaName)
assert.NilError(t, err)
}
func TestSelectQuery(t *testing.T) {
query, err := table.Actor.InnerJoinOn(table.Store, Eq(table.Actor.ActorID, table.Store.StoreID)).
Select(table.Store.StoreID, table.Store.AddressID, table.Actor.ActorID).String(schemaName)
assert.NilError(t, err)
assert.Equal(t, query, "SELECT store.store_id,store.address_id,actor.actor_id FROM public.actor JOIN public.store ON actor.actor_id=store.store_id")
}