diff --git a/README.md b/README.md index 2dc87fb..737b1e2 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Jet Sql Builder for Postgresql(soon MySql and OracleSql) +# Jet - Sql Builder for Postgresql [![CircleCI](https://circleci.com/gh/go-jet/jet/tree/develop.svg?style=svg&circle-token=97f255c6a4a3ab6590ea2e9195eb3ebf9f97b4a7)](https://circleci.com/gh/go-jet/jet/tree/develop) diff --git a/cmd/generator/main.go b/cmd/jetgen/main.go similarity index 86% rename from cmd/generator/main.go rename to cmd/jetgen/main.go index 8ded84f..535d3db 100644 --- a/cmd/generator/main.go +++ b/cmd/jetgen/main.go @@ -3,7 +3,7 @@ package main import ( "flag" "fmt" - "github.com/go-jet/jet/generator" + "github.com/go-jet/jet/generator/postgresgen" "os" ) @@ -38,7 +38,7 @@ func init() { func main() { - genData := generator.GeneratorData{ + genData := postgresgen.DBConnection{ Host: host, Port: port, User: user, @@ -50,14 +50,10 @@ func main() { SchemaName: schemaName, } - fmt.Println(destDir, genData) - - err := generator.Generate(destDir, genData) + err := postgresgen.Generate(destDir, genData) if err != nil { fmt.Println(err.Error()) os.Exit(-1) } - - fmt.Println("SUCCESS") } diff --git a/delete_statement.go b/delete_statement.go index e3ed399..8259fec 100644 --- a/delete_statement.go +++ b/delete_statement.go @@ -72,18 +72,18 @@ func (d *deleteStatementImpl) DebugSql() (query string, err error) { return DebugSql(d) } -func (d *deleteStatementImpl) Query(db execution.Db, destination interface{}) error { +func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) error { return Query(d, db, destination) } -func (d *deleteStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error { +func (d *deleteStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error { return QueryContext(d, db, context, destination) } -func (d *deleteStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { +func (d *deleteStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { return Exec(d, db) } -func (d *deleteStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) { +func (d *deleteStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) { return ExecContext(d, db, context) } diff --git a/execution/db.go b/execution/db.go index ab55d2d..a18f8c5 100644 --- a/execution/db.go +++ b/execution/db.go @@ -5,7 +5,7 @@ import ( "database/sql" ) -type Db interface { +type DB interface { Exec(query string, args ...interface{}) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) Query(query string, args ...interface{}) (*sql.Rows, error) diff --git a/execution/execution.go b/execution/execution.go index 2e2d595..0241dc8 100644 --- a/execution/execution.go +++ b/execution/execution.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "errors" "fmt" + "github.com/go-jet/jet/execution/internal" "github.com/serenize/snaker" "reflect" "strconv" @@ -13,7 +14,7 @@ import ( "time" ) -func Query(db Db, context context.Context, query string, args []interface{}, destinationPtr interface{}) error { +func Query(db DB, context context.Context, query string, args []interface{}, destinationPtr interface{}) error { if destinationPtr == nil { return errors.New("Destination is nil. ") @@ -54,7 +55,7 @@ func Query(db Db, context context.Context, query string, args []interface{}, des } } -func queryToSlice(db Db, ctx context.Context, query string, args []interface{}, slicePtr interface{}) error { +func queryToSlice(db DB, ctx context.Context, query string, args []interface{}, slicePtr interface{}) error { if db == nil { return errors.New("db is nil") } @@ -494,15 +495,15 @@ func createScanValue(columnTypes []*sql.ColumnType) []interface{} { return values } -var nullFloatType = reflect.TypeOf(NullFloat32{}) +var nullFloatType = reflect.TypeOf(internal.NullFloat32{}) var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{}) -var nullInt16Type = reflect.TypeOf(NullInt16{}) -var nullInt32Type = reflect.TypeOf(NullInt32{}) +var nullInt16Type = reflect.TypeOf(internal.NullInt16{}) +var nullInt32Type = reflect.TypeOf(internal.NullInt32{}) var nullInt64Type = reflect.TypeOf(sql.NullInt64{}) var nullStringType = reflect.TypeOf(sql.NullString{}) var nullBoolType = reflect.TypeOf(sql.NullBool{}) -var nullTimeType = reflect.TypeOf(NullTime{}) -var nullByteArrayType = reflect.TypeOf(NullByteArray{}) +var nullTimeType = reflect.TypeOf(internal.NullTime{}) +var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{}) func newScanType(columnType *sql.ColumnType) reflect.Type { switch columnType.DatabaseTypeName() { diff --git a/execution/null_types.go b/execution/internal/null_types.go similarity index 99% rename from execution/null_types.go rename to execution/internal/null_types.go index 0cb57b6..0d18120 100644 --- a/execution/null_types.go +++ b/execution/internal/null_types.go @@ -1,4 +1,4 @@ -package execution +package internal import ( "database/sql/driver" diff --git a/generator/metadata/meta_data.go b/generator/internal/metadata/meta_data.go similarity index 100% rename from generator/metadata/meta_data.go rename to generator/internal/metadata/meta_data.go diff --git a/generator/postgres-metadata/column_info.go b/generator/internal/metadata/postgres-metadata/column_info.go similarity index 100% rename from generator/postgres-metadata/column_info.go rename to generator/internal/metadata/postgres-metadata/column_info.go diff --git a/generator/postgres-metadata/enum_info.go b/generator/internal/metadata/postgres-metadata/enum_info.go similarity index 92% rename from generator/postgres-metadata/enum_info.go rename to generator/internal/metadata/postgres-metadata/enum_info.go index 5de3d78..98d6235 100644 --- a/generator/postgres-metadata/enum_info.go +++ b/generator/internal/metadata/postgres-metadata/enum_info.go @@ -2,8 +2,7 @@ package postgres_metadata import ( "database/sql" - "fmt" - "github.com/go-jet/jet/generator/metadata" + "github.com/go-jet/jet/generator/internal/metadata" ) type EnumInfo struct { @@ -63,7 +62,5 @@ ORDER BY n.nspname, t.typname, e.enumsortorder;` }) } - fmt.Println("FOUND", len(ret), " enums") - return ret, nil } diff --git a/generator/postgres-metadata/schema_info.go b/generator/internal/metadata/postgres-metadata/schema_info.go similarity index 93% rename from generator/postgres-metadata/schema_info.go rename to generator/internal/metadata/postgres-metadata/schema_info.go index 3c67c76..1d63296 100644 --- a/generator/postgres-metadata/schema_info.go +++ b/generator/internal/metadata/postgres-metadata/schema_info.go @@ -2,8 +2,7 @@ package postgres_metadata import ( "database/sql" - "fmt" - "github.com/go-jet/jet/generator/metadata" + "github.com/go-jet/jet/generator/internal/metadata" ) type SchemaInfo struct { @@ -65,8 +64,6 @@ where table_catalog = $1 and table_schema = $2 and table_type = 'BASE TABLE'; ret = append(ret, tableInfo) } - fmt.Println("FOUND", len(ret), "tables") - err = rows.Err() if err != nil { diff --git a/generator/postgres-metadata/table_info.go b/generator/internal/metadata/postgres-metadata/table_info.go similarity index 100% rename from generator/postgres-metadata/table_info.go rename to generator/internal/metadata/postgres-metadata/table_info.go diff --git a/generator/utils.go b/generator/internal/utils/utils.go similarity index 80% rename from generator/utils.go rename to generator/internal/utils/utils.go index 9e1556b..343f4d5 100644 --- a/generator/utils.go +++ b/generator/internal/utils/utils.go @@ -1,4 +1,4 @@ -package generator +package utils import ( "bytes" @@ -11,7 +11,7 @@ import ( "time" ) -func saveGoFile(dirPath, fileName string, text []byte) error { +func SaveGoFile(dirPath, fileName string, text []byte) error { newGoFilePath := filepath.Join(dirPath, fileName) + ".go" file, err := os.Create(newGoFilePath) @@ -36,7 +36,7 @@ func saveGoFile(dirPath, fileName string, text []byte) error { return nil } -func ensureDirPath(dirPath string) error { +func EnsureDirPath(dirPath string) error { if _, err := os.Stat(dirPath); os.IsNotExist(err) { err := os.MkdirAll(dirPath, os.ModePerm) @@ -48,7 +48,7 @@ func ensureDirPath(dirPath string) error { return nil } -func generateTemplate(templateText string, templateData interface{}) ([]byte, error) { +func GenerateTemplate(templateText string, templateData interface{}) ([]byte, error) { t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{ "camelize": func(txt string) string { @@ -71,8 +71,8 @@ func generateTemplate(templateText string, templateData interface{}) ([]byte, er return buf.Bytes(), nil } -func cleanUpGeneratedFiles(dir string) error { - exist, err := dirExists(dir) +func CleanUpGeneratedFiles(dir string) error { + exist, err := DirExists(dir) if err != nil { return err @@ -89,7 +89,7 @@ func cleanUpGeneratedFiles(dir string) error { return nil } -func dirExists(path string) (bool, error) { +func DirExists(path string) (bool, error) { _, err := os.Stat(path) if err == nil { return true, nil diff --git a/generator/generator.go b/generator/postgresgen/generator.go similarity index 50% rename from generator/generator.go rename to generator/postgresgen/generator.go index 8b2fdbf..7d48866 100644 --- a/generator/generator.go +++ b/generator/postgresgen/generator.go @@ -1,16 +1,17 @@ -package generator +package postgresgen import ( "database/sql" "fmt" - "github.com/go-jet/jet/generator/metadata" - "github.com/go-jet/jet/generator/postgres-metadata" + "github.com/go-jet/jet/generator/internal/metadata" + "github.com/go-jet/jet/generator/internal/metadata/postgres-metadata" + "github.com/go-jet/jet/generator/internal/utils" _ "github.com/lib/pq" "path" "path/filepath" ) -type GeneratorData struct { +type DBConnection struct { Host string Port string User string @@ -22,11 +23,13 @@ type GeneratorData struct { SchemaName string } -func Generate(destDir string, genData GeneratorData) error { +func Generate(destDir string, genData DBConnection) error { connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s", genData.Host, genData.Port, genData.User, genData.Password, genData.DBName, genData.SslMode, genData.Params) + fmt.Println("Connecting to postgres database: " + connectionString) + db, err := sql.Open("postgres", connectionString) if err != nil { return err @@ -39,42 +42,53 @@ func Generate(destDir string, genData GeneratorData) error { return err } - err = cleanUpGeneratedFiles(path.Join(destDir, genData.DBName, genData.SchemaName)) - - if err != nil { - return err - } - + fmt.Println("Retrieving schema information...") schemaInfo, err := postgres_metadata.GetSchemaInfo(db, genData.DBName, genData.SchemaName) if err != nil { return err } + fmt.Println(" FOUND", len(schemaInfo.TableInfos), " table(s), ", len(schemaInfo.EnumInfos), " enum(s)") + + fmt.Println("Cleaning up destination directory...") + err = utils.CleanUpGeneratedFiles(path.Join(destDir, genData.DBName, genData.SchemaName)) + + if err != nil { + return err + } + + fmt.Println("Generating table sql builder files...") err = generate(schemaInfo, destDir, "table", sqlBuilderTableTemplate, schemaInfo.TableInfos) if err != nil { return err } - //err = generateDataModel(schemaInfo, destDir) + fmt.Println("Generating table model files...") err = generate(schemaInfo, destDir, "model", dataModelTemplate, schemaInfo.TableInfos) if err != nil { return err } - err = generate(schemaInfo, destDir, "model", enumModelTemplate, schemaInfo.EnumInfos) + if len(schemaInfo.EnumInfos) > 0 { + fmt.Println("Generating enum sql builder files...") + err = generate(schemaInfo, destDir, "enum", enumTypeTemplate, schemaInfo.EnumInfos) - if err != nil { - return err + if err != nil { + return err + } + + fmt.Println("Generating enum model files...") + err = generate(schemaInfo, destDir, "model", enumModelTemplate, schemaInfo.EnumInfos) + + if err != nil { + return err + } } - err = generate(schemaInfo, destDir, "enum", enumTypeTemplate, schemaInfo.EnumInfos) - - if err != nil { - return err - } + fmt.Println("Done") return nil } @@ -82,26 +96,26 @@ func Generate(destDir string, genData GeneratorData) error { func generate(schemaInfo postgres_metadata.SchemaInfo, dirPath, packageName string, template string, metaDataList []metadata.MetaData) error { modelDirPath := filepath.Join(dirPath, schemaInfo.DatabaseName, schemaInfo.Name, packageName) - err := ensureDirPath(modelDirPath) + err := utils.EnsureDirPath(modelDirPath) if err != nil { return err } - autoGenWarning, err := generateTemplate(autoGenWarningTemplate, nil) + autoGenWarning, err := utils.GenerateTemplate(autoGenWarningTemplate, nil) if err != nil { return err } for _, metaData := range metaDataList { - text, err := generateTemplate(template, metaData) + text, err := utils.GenerateTemplate(template, metaData) if err != nil { return err } - err = saveGoFile(modelDirPath, metaData.Name(), append(autoGenWarning, text...)) + err = utils.SaveGoFile(modelDirPath, metaData.Name(), append(autoGenWarning, text...)) if err != nil { return err diff --git a/generator/templates.go b/generator/postgresgen/templates.go similarity index 99% rename from generator/templates.go rename to generator/postgresgen/templates.go index 5a96d33..177e623 100644 --- a/generator/templates.go +++ b/generator/postgresgen/templates.go @@ -1,4 +1,4 @@ -package generator +package postgresgen var autoGenWarningTemplate = ` // diff --git a/insert_statement.go b/insert_statement.go index c84e478..77d8f6d 100644 --- a/insert_statement.go +++ b/insert_statement.go @@ -146,18 +146,18 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) return } -func (i *insertStatementImpl) Query(db execution.Db, destination interface{}) error { +func (i *insertStatementImpl) Query(db execution.DB, destination interface{}) error { return Query(i, db, destination) } -func (i *insertStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error { +func (i *insertStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error { return QueryContext(i, db, context, destination) } -func (i *insertStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { +func (i *insertStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { return Exec(i, db) } -func (i *insertStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) { +func (i *insertStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) { return ExecContext(i, db, context) } diff --git a/lock_statement.go b/lock_statement.go index 98238b6..6f0997a 100644 --- a/lock_statement.go +++ b/lock_statement.go @@ -93,18 +93,18 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error) return } -func (l *lockStatementImpl) Query(db execution.Db, destination interface{}) error { +func (l *lockStatementImpl) Query(db execution.DB, destination interface{}) error { return Query(l, db, destination) } -func (l *lockStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error { +func (l *lockStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error { return QueryContext(l, db, context, destination) } -func (l *lockStatementImpl) Exec(db execution.Db) (sql.Result, error) { +func (l *lockStatementImpl) Exec(db execution.DB) (sql.Result, error) { return Exec(l, db) } -func (l *lockStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) { +func (l *lockStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) { return ExecContext(l, db, context) } diff --git a/select_statement.go b/select_statement.go index 24fadb2..356a3c7 100644 --- a/select_statement.go +++ b/select_statement.go @@ -289,18 +289,18 @@ func (s *selectLockImpl) serialize(statement statementType, out *queryData, opti return nil } -func (s *selectStatementImpl) Query(db execution.Db, destination interface{}) error { +func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) error { return Query(s, db, destination) } -func (s *selectStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error { +func (s *selectStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error { return QueryContext(s, db, context, destination) } -func (s *selectStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { +func (s *selectStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { return Exec(s, db) } -func (s *selectStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) { +func (s *selectStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) { return ExecContext(s, db, context) } diff --git a/set_statement.go b/set_statement.go index bbe738a..46997c5 100644 --- a/set_statement.go +++ b/set_statement.go @@ -203,18 +203,18 @@ func (s *setStatementImpl) DebugSql() (query string, err error) { return DebugSql(s) } -func (s *setStatementImpl) Query(db execution.Db, destination interface{}) error { +func (s *setStatementImpl) Query(db execution.DB, destination interface{}) error { return Query(s, db, destination) } -func (s *setStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error { +func (s *setStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error { return QueryContext(s, db, context, destination) } -func (s *setStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { +func (s *setStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { return Exec(s, db) } -func (s *setStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) { +func (s *setStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) { return ExecContext(s, db, context) } diff --git a/statement.go b/statement.go index d8d67e7..339a3d6 100644 --- a/statement.go +++ b/statement.go @@ -14,11 +14,11 @@ type Statement interface { DebugSql() (query string, err error) - Query(db execution.Db, destination interface{}) error - QueryContext(db execution.Db, context context.Context, destination interface{}) error + Query(db execution.DB, destination interface{}) error + QueryContext(db execution.DB, context context.Context, destination interface{}) error - Exec(db execution.Db) (sql.Result, error) - ExecContext(db execution.Db, context context.Context) (sql.Result, error) + Exec(db execution.DB) (sql.Result, error) + ExecContext(db execution.DB, context context.Context) (sql.Result, error) } func DebugSql(statement Statement) (string, error) { @@ -38,7 +38,7 @@ func DebugSql(statement Statement) (string, error) { return debugSqlQuery, nil } -func Query(statement Statement, db execution.Db, destination interface{}) error { +func Query(statement Statement, db execution.DB, destination interface{}) error { query, args, err := statement.Sql() if err != nil { @@ -48,7 +48,7 @@ func Query(statement Statement, db execution.Db, destination interface{}) error return execution.Query(db, context.Background(), query, args, destination) } -func QueryContext(statement Statement, db execution.Db, context context.Context, destination interface{}) error { +func QueryContext(statement Statement, db execution.DB, context context.Context, destination interface{}) error { query, args, err := statement.Sql() if err != nil { @@ -58,7 +58,7 @@ func QueryContext(statement Statement, db execution.Db, context context.Context, return execution.Query(db, context, query, args, destination) } -func Exec(statement Statement, db execution.Db) (res sql.Result, err error) { +func Exec(statement Statement, db execution.DB) (res sql.Result, err error) { query, args, err := statement.Sql() if err != nil { @@ -68,7 +68,7 @@ func Exec(statement Statement, db execution.Db) (res sql.Result, err error) { return db.Exec(query, args...) } -func ExecContext(statement Statement, db execution.Db, context context.Context) (res sql.Result, err error) { +func ExecContext(statement Statement, db execution.DB, context context.Context) (res sql.Result, err error) { query, args, err := statement.Sql() if err != nil { diff --git a/tests/chinook_db_test.go b/tests/chinook_db_test.go index 3dc6b6e..524f20a 100644 --- a/tests/chinook_db_test.go +++ b/tests/chinook_db_test.go @@ -91,7 +91,7 @@ func TestJoinEverything(t *testing.T) { Customer struct { // customer data for invoice model.Customer - Employee *struct { + Employee *struct { // employee data for customer if exists model.Employee Manager *model.Employee diff --git a/tests/init/init.go b/tests/init/init.go index 388b376..6e0fcc6 100644 --- a/tests/init/init.go +++ b/tests/init/init.go @@ -3,7 +3,7 @@ package main import ( "database/sql" "fmt" - "github.com/go-jet/jet/generator" + postgres_generator "github.com/go-jet/jet/generator/postgresgen" "github.com/go-jet/jet/tests/dbconfig" "io/ioutil" ) @@ -33,7 +33,7 @@ func main() { _, err = db.Exec(string(testSampleSql)) - err = generator.Generate("./.test_files", generator.GeneratorData{ + err = postgres_generator.Generate("./.test_files", postgres_generator.DBConnection{ Host: dbconfig.Host, Port: "5432", User: dbconfig.User, diff --git a/update_statement.go b/update_statement.go index c63d154..32cbf67 100644 --- a/update_statement.go +++ b/update_statement.go @@ -139,18 +139,18 @@ func (u *updateStatementImpl) DebugSql() (query string, err error) { return DebugSql(u) } -func (u *updateStatementImpl) Query(db execution.Db, destination interface{}) error { +func (u *updateStatementImpl) Query(db execution.DB, destination interface{}) error { return Query(u, db, destination) } -func (u *updateStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error { +func (u *updateStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error { return QueryContext(u, db, context, destination) } -func (u *updateStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { +func (u *updateStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { return Exec(u, db) } -func (u *updateStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) { +func (u *updateStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) { return ExecContext(u, db, context) }