diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e9df959 --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Idea +.idea + +# Test files +.test_files \ No newline at end of file diff --git a/cmd/generator/main.go b/cmd/generator/main.go new file mode 100644 index 0000000..74cb087 --- /dev/null +++ b/cmd/generator/main.go @@ -0,0 +1,36 @@ +package main + +import ( + "flag" + "fmt" + "github.com/sub0Zero/go-sqlbuilder/generator" + "os" +) + +var genDirPath string +var dbConnectionString string +var dbName string +var schemaName string + +func init() { + flag.StringVar(&genDirPath, "path", "", "Destination for generated files.") + flag.StringVar(&dbConnectionString, "db", "", "Connection string to database server") + flag.StringVar(&dbName, "dbName", "", "Name of the database") + flag.StringVar(&schemaName, "schema", "public", "Database schema name.") + + flag.Parse() +} + +func main() { + + fmt.Println(genDirPath, dbConnectionString, dbName, schemaName) + + err := generator.Generate(genDirPath, dbConnectionString, dbName, schemaName) + + if err != nil { + fmt.Println(err.Error()) + os.Exit(-1) + } + + fmt.Println("SUCCESS") +} diff --git a/generator/generator.go b/generator/generator.go new file mode 100644 index 0000000..624a77b --- /dev/null +++ b/generator/generator.go @@ -0,0 +1,163 @@ +package generator + +import ( + "database/sql" + _ "github.com/lib/pq" + "github.com/serenize/snaker" + "os" +) + +type DbConnectInfo struct { + host string + port int + user string + password string + dbname string +} + +func Generate(folderPath string, connectString string, databaseName, schemaName string) error { + if _, err := os.Stat(folderPath); os.IsNotExist(err) { + err := os.Mkdir(folderPath, os.ModePerm) + + if err != nil { + return err + } + } + + db, err := sql.Open("postgres", connectString) + if err != nil { + return err + } + defer db.Close() + + err = db.Ping() + + if err != nil { + return err + } + + tables, err := getTablesInfo(db, schemaName) + + if err != nil { + return err + } + + for _, table := range tables { + err = generateSqlBuilderModel(databaseName, schemaName, table, folderPath) + + if err != nil { + return err + } + } + + return nil +} + +type TableInfo struct { + Name string + Columns []ColumnInfo +} + +func getTablesInfo(db *sql.DB, schemaName string) ([]TableInfo, error) { + tableNames, err := getListOfTables(db, schemaName) + + if err != nil { + return nil, err + } + + tables := []TableInfo{} + for _, tableName := range tableNames { + columns, err := getColumnInfos(db, tableName) + + if err != nil { + return nil, err + } + + tables = append(tables, TableInfo{tableName, columns}) + } + + return tables, nil +} + +func getListOfTables(db *sql.DB, schemaName string) ([]string, error) { + + rows, err := db.Query(` +SELECT table_name FROM information_schema.tables +where table_schema = $1 and table_type = 'BASE TABLE';`, schemaName) + + if err != nil { + return nil, err + } + defer rows.Close() + + tables := []string{} + for rows.Next() { + var table string + err = rows.Scan(&table) + if err != nil { + return nil, err + } + + tables = append(tables, table) + } + + err = rows.Err() + + if err != nil { + return nil, err + } + + return tables, nil +} + +type ColumnInfo struct { + Name string + IsNullable bool + DataType string +} + +func (c *ColumnInfo) CamelCaseName() string { + return snaker.SnakeToCamel(c.Name) +} + +func getColumnInfos(db *sql.DB, tableName string) ([]ColumnInfo, error) { + + query := ` +SELECT column_name, is_nullable, data_type +FROM information_schema.columns +where table_name = $1 +order by ordinal_position;` + + //fmt.Println(query) + + rows, err := db.Query(query, &tableName) + + if err != nil { + return nil, err + } + defer rows.Close() + + ret := []ColumnInfo{} + + for rows.Next() { + columnInfo := ColumnInfo{} + var isNullable string + err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType) + + columnInfo.IsNullable = isNullable == "YES" + + if err != nil { + return nil, err + } + + ret = append(ret, columnInfo) + } + + err = rows.Err() + + if err != nil { + return nil, err + } + + return ret, nil +} diff --git a/generator/sqlbuilder_generator.go b/generator/sqlbuilder_generator.go new file mode 100644 index 0000000..e35ae17 --- /dev/null +++ b/generator/sqlbuilder_generator.go @@ -0,0 +1,91 @@ +package generator + +import ( + "bytes" + "github.com/serenize/snaker" + "go/format" + "os" + "path/filepath" + "strings" + "text/template" +) + +func generateSqlBuilderModel(databaseName, schemaName string, tableInfo TableInfo, dirPath string) error { + + schemaDirPath := filepath.Join(dirPath, databaseName, schemaName, "table") + + if _, err := os.Stat(schemaDirPath); os.IsNotExist(err) { + err := os.MkdirAll(schemaDirPath, os.ModePerm) + + if err != nil { + return err + } + } + + t, err := template.New("TableTemplate").Funcs(template.FuncMap{ + "camelize": func(txt string) string { + return snaker.SnakeToCamel(txt) + }, + "columnName": columnName, + }).Parse(TableTemplate) + + if err != nil { + return err + } + + newGoFilePath := filepath.Join(schemaDirPath, tableInfo.Name) + ".go" + + file, err := os.Create(newGoFilePath) + + if err != nil { + return err + } + + defer file.Close() + + tableTemplate := TableTemplateData{ + databaseName, + tableInfo, + } + + //err = t.Execute(file, &tableTemplate) + // + //if err != nil { + // return err + //} + + var buf bytes.Buffer + if err := t.Execute(&buf, &tableTemplate); err != nil { + return err + } + p, err := format.Source(buf.Bytes()) + if err != nil { + return err + } + + _, err = file.Write(p) + + if err != nil { + return err + } + + return nil +} + +type TableTemplateData struct { + PackageName string + TableInfo TableInfo +} + +func columnName(table, column string) string { + return snaker.SnakeToCamelLower(table) + snaker.SnakeToCamel(column) + "Column" +} + +func (t *TableTemplateData) ColumnNameList(sep string) string { + columnNames := []string{} + for _, columnInfo := range t.TableInfo.Columns { + columnInfoName := columnInfo.Name + columnNames = append(columnNames, columnName(t.TableInfo.Name, columnInfoName)) + } + return strings.Join(columnNames, sep) +} diff --git a/generator/sqlbuilder_generator_test.go b/generator/sqlbuilder_generator_test.go new file mode 100644 index 0000000..ab38b1b --- /dev/null +++ b/generator/sqlbuilder_generator_test.go @@ -0,0 +1,21 @@ +package generator + +//import ( +// "gotest.tools/assert" +// "testing" +//) +// +//func TestGenerateSqlBuilderModel(t *testing.T) { +// table := TableInfo{ +// "actor", +// []ColumnInfo{ +// {"actor_id", false, "integer"}, +// {"first_name", true, "character varying"}, +// {"last_name", false, "timestamp without time zone"}, +// }, +// } +// +// err := generateSqlBuilderModel("dvd_rental", table, "../../sqlbuildertest") +// +// assert.NilError(t, err) +//} diff --git a/generator/templates.go b/generator/templates.go new file mode 100644 index 0000000..cec76b4 --- /dev/null +++ b/generator/templates.go @@ -0,0 +1,30 @@ +package generator + +var TableTemplate = `package table + +import "github.com/sub0Zero/go-sqlbuilder/sqlbuilder" + +type {{camelize .TableInfo.Name}}Table struct { + sqlbuilder.Table + + //Columns +{{- range .TableInfo.Columns}} + {{camelize .Name}} sqlbuilder.NonAliasColumn +{{- end}} +} + +var {{camelize .TableInfo.Name}} = &{{camelize .TableInfo.Name}}Table{ + Table: *sqlbuilder.NewTable("{{.TableInfo.Name}}", {{.ColumnNameList ", "}}), + + //Columns +{{- range .TableInfo.Columns}} + {{camelize .Name}}: {{columnName $.TableInfo.Name .Name}}, +{{- end}} +} + +var ( +{{- range .TableInfo.Columns}} + {{columnName $.TableInfo.Name .Name}} = sqlbuilder.IntColumn("{{.Name}}", {{if .IsNullable}}sqlbuilder.Nullable{{else}}sqlbuilder.NotNullable{{end}}) +{{- end}} +) +` diff --git a/sqlbuilder/column.go b/sqlbuilder/column.go index 8aae3ac..dc511a9 100644 --- a/sqlbuilder/column.go +++ b/sqlbuilder/column.go @@ -75,13 +75,10 @@ func (c *baseColumn) setTableName(table string) error { func (c *baseColumn) SerializeSqlForColumnList(out *bytes.Buffer) error { if c.table != "" { - _ = out.WriteByte('`') _, _ = out.WriteString(c.table) - _, _ = out.WriteString("`.") + _, _ = out.WriteString(".") } - _, _ = out.WriteString("`") _, _ = out.WriteString(c.name) - _ = out.WriteByte('`') return nil } @@ -147,7 +144,7 @@ func DateTimeColumn(name string, nullable NullableColumn) NonAliasColumn { return dc } -type integerColumn struct { +type IntegerColumn struct { baseColumn isExpression } @@ -158,7 +155,7 @@ func IntColumn(name string, nullable NullableColumn) NonAliasColumn { if !validIdentifierName(name) { panic("Invalid column name in int column") } - ic := &integerColumn{} + ic := &IntegerColumn{} ic.name = name ic.nullable = nullable return ic diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index dfcc4a9..ba01b36 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -131,11 +131,9 @@ func (t *Table) ForceIndex(index string) *Table { // Generates the sql string for the current table expression. Note: the // generated string may not be a valid/executable sql statement. func (t *Table) SerializeSql(database string, out *bytes.Buffer) error { - _, _ = out.WriteString("`") _, _ = out.WriteString(database) - _, _ = out.WriteString("`.`") + _, _ = out.WriteString(".") _, _ = out.WriteString(t.Name()) - _, _ = out.WriteString("`") if t.forcedIndex != "" { if !validIdentifierName(t.forcedIndex) { diff --git a/tests/dvdrental.tar b/tests/dvdrental.tar new file mode 100644 index 0000000..ebcef4f Binary files /dev/null and b/tests/dvdrental.tar differ diff --git a/tests/generator_test.go b/tests/generator_test.go new file mode 100644 index 0000000..4ff1cb1 --- /dev/null +++ b/tests/generator_test.go @@ -0,0 +1,40 @@ +package tests + +import ( + "fmt" + "github.com/sub0Zero/go-sqlbuilder/generator" + . "github.com/sub0Zero/go-sqlbuilder/sqlbuilder" + "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/public/table" + "gotest.tools/assert" + "testing" +) + +var ( + folderPath = ".test_files/" + host = "localhost" + port = 5432 + user = "postgres" + password = "postgres" + dbname = "dvd_rental" + schemaName = "public" +) + +//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema public -path .test_files + +func TestGenerateModel(t *testing.T) { + connectString := fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", + host, port, user, password, dbname) + + err := generator.Generate(folderPath, connectString, dbname, schemaName) + + assert.NilError(t, err) +} + +func TestSelectQuery(t *testing.T) { + query, err := table.Actor.InnerJoinOn(table.Store, Eq(table.Actor.ActorID, table.Store.StoreID)). + Select(table.Store.StoreID, table.Store.AddressID, table.Actor.ActorID).String(schemaName) + + assert.NilError(t, err) + + assert.Equal(t, query, "SELECT store.store_id,store.address_id,actor.actor_id FROM public.actor JOIN public.store ON actor.actor_id=store.store_id") +}