diff --git a/generator/internal/metadata/column_meta_data.go b/generator/internal/metadata/column_meta_data.go index c1fdd10..69a16f7 100644 --- a/generator/internal/metadata/column_meta_data.go +++ b/generator/internal/metadata/column_meta_data.go @@ -142,13 +142,10 @@ func (c ColumnMetaData) GoModelTag(isPrimaryKey bool) string { return "" } -func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) ([]ColumnMetaData, error) { +func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) []ColumnMetaData { rows, err := db.Query(querySet.ListOfColumnsQuery(), schemaName, tableName) - - if err != nil { - return nil, err - } + utils.PanicOnError(err) defer rows.Close() ret := []ColumnMetaData{} @@ -157,19 +154,13 @@ func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableN var name, isNullable, dataType, enumName string var isUnsigned bool err := rows.Scan(&name, &isNullable, &dataType, &enumName, &isUnsigned) - - if err != nil { - return nil, err - } + utils.PanicOnError(err) ret = append(ret, NewColumnMetaData(name, isNullable == "YES", dataType, enumName, isUnsigned)) } err = rows.Err() + utils.PanicOnError(err) - if err != nil { - return nil, err - } - - return ret, nil + return ret } diff --git a/generator/internal/metadata/dialect_query_set.go b/generator/internal/metadata/dialect_query_set.go index 6cc9834..6c91825 100644 --- a/generator/internal/metadata/dialect_query_set.go +++ b/generator/internal/metadata/dialect_query_set.go @@ -11,5 +11,5 @@ type DialectQuerySet interface { ListOfColumnsQuery() string ListOfEnumsQuery() string - GetEnumsMetaData(db *sql.DB, schemaName string) ([]MetaData, error) + GetEnumsMetaData(db *sql.DB, schemaName string) []MetaData } diff --git a/generator/internal/metadata/schema_meta_data.go b/generator/internal/metadata/schema_meta_data.go index 0c54d7d..836745b 100644 --- a/generator/internal/metadata/schema_meta_data.go +++ b/generator/internal/metadata/schema_meta_data.go @@ -3,6 +3,7 @@ package metadata import ( "database/sql" "fmt" + "github.com/go-jet/jet/internal/utils" ) // SchemaMetaData struct @@ -23,25 +24,11 @@ const ( ) // GetSchemaMetaData returns schema information from db connection. -func GetSchemaMetaData(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData, err error) { +func GetSchemaMetaData(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData) { - schemaInfo.TablesMetaData, err = getTablesMetaData(db, querySet, schemaName, baseTable) - - if err != nil { - return - } - - schemaInfo.ViewsMetaData, err = getTablesMetaData(db, querySet, schemaName, view) - - if err != nil { - return - } - - schemaInfo.EnumsMetaData, err = querySet.GetEnumsMetaData(db, schemaName) - - if err != nil { - return - } + schemaInfo.TablesMetaData = getTablesMetaData(db, querySet, schemaName, baseTable) + schemaInfo.ViewsMetaData = getTablesMetaData(db, querySet, schemaName, view) + schemaInfo.EnumsMetaData = querySet.GetEnumsMetaData(db, schemaName) fmt.Println(" FOUND", len(schemaInfo.TablesMetaData), "table(s),", len(schemaInfo.ViewsMetaData), "view(s),", len(schemaInfo.EnumsMetaData), "enum(s)") @@ -49,13 +36,10 @@ func GetSchemaMetaData(db *sql.DB, schemaName string, querySet DialectQuerySet) return } -func getTablesMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableType string) ([]MetaData, error) { +func getTablesMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableType string) []MetaData { rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName, tableType) - - if err != nil { - return nil, err - } + utils.PanicOnError(err) defer rows.Close() ret := []MetaData{} @@ -63,24 +47,15 @@ func getTablesMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableTy var tableName string err = rows.Scan(&tableName) - if err != nil { - return nil, err - } + utils.PanicOnError(err) - tableInfo, err := GetTableMetaData(db, querySet, schemaName, tableName) - - if err != nil { - return nil, err - } + tableInfo := GetTableMetaData(db, querySet, schemaName, tableName) ret = append(ret, tableInfo) } err = rows.Err() + utils.PanicOnError(err) - if err != nil { - return nil, err - } - - return ret, nil + return ret } diff --git a/generator/internal/metadata/table_meta_data.go b/generator/internal/metadata/table_meta_data.go index eea0604..cb738fa 100644 --- a/generator/internal/metadata/table_meta_data.go +++ b/generator/internal/metadata/table_meta_data.go @@ -68,45 +68,31 @@ func (t TableMetaData) GoStructName() string { } // GetTableMetaData returns table info metadata -func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData, err error) { +func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData) { tableInfo.SchemaName = schemaName tableInfo.name = tableName - tableInfo.PrimaryKeys, err = getPrimaryKeys(db, querySet, schemaName, tableName) - if err != nil { - return - } - - tableInfo.Columns, err = getColumnsMetaData(db, querySet, schemaName, tableName) - - if err != nil { - return - } + tableInfo.PrimaryKeys = getPrimaryKeys(db, querySet, schemaName, tableName) + tableInfo.Columns = getColumnsMetaData(db, querySet, schemaName, tableName) return } -func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (map[string]bool, error) { +func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) map[string]bool { rows, err := db.Query(querySet.PrimaryKeysQuery(), schemaName, tableName) - - if err != nil { - return nil, err - } + utils.PanicOnError(err) primaryKeyMap := map[string]bool{} for rows.Next() { primaryKey := "" err := rows.Scan(&primaryKey) - - if err != nil { - return nil, err - } + utils.PanicOnError(err) primaryKeyMap[primaryKey] = true } - return primaryKeyMap, nil + return primaryKeyMap } diff --git a/generator/internal/template/generate.go b/generator/internal/template/generate.go index c545335..a076bb1 100644 --- a/generator/internal/template/generate.go +++ b/generator/internal/template/generate.go @@ -12,106 +12,61 @@ import ( ) // GenerateFiles generates Go files from tables and enums metadata -func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect jet.Dialect) error { +func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect jet.Dialect) { if schemaInfo.IsEmpty() { - return nil + return } fmt.Println("Destination directory:", destDir) fmt.Println("Cleaning up destination directory...") err := utils.CleanUpGeneratedFiles(destDir) + utils.PanicOnError(err) - if err != nil { - return err - } + generateSQLBuilderFiles(destDir, "table", tableSQLBuilderTemplate, schemaInfo.TablesMetaData, dialect) + generateSQLBuilderFiles(destDir, "view", tableSQLBuilderTemplate, schemaInfo.ViewsMetaData, dialect) + generateSQLBuilderFiles(destDir, "enum", enumSQLBuilderTemplate, schemaInfo.EnumsMetaData, dialect) - err = generateSQLBuilderFiles(destDir, "table", tableSQLBuilderTemplate, schemaInfo.TablesMetaData, dialect) - - if err != nil { - return err - } - - err = generateSQLBuilderFiles(destDir, "view", tableSQLBuilderTemplate, schemaInfo.ViewsMetaData, dialect) - - if err != nil { - return err - } - - err = generateSQLBuilderFiles(destDir, "enum", enumSQLBuilderTemplate, schemaInfo.EnumsMetaData, dialect) - - if err != nil { - return err - } - - err = generateModelFiles(destDir, "table", tableModelTemplate, schemaInfo.TablesMetaData, dialect) - - if err != nil { - return err - } - - err = generateModelFiles(destDir, "view", tableModelTemplate, schemaInfo.ViewsMetaData, dialect) - - if err != nil { - return err - } - - err = generateModelFiles(destDir, "enum", enumModelTemplate, schemaInfo.EnumsMetaData, dialect) - - if err != nil { - return err - } + generateModelFiles(destDir, "table", tableModelTemplate, schemaInfo.TablesMetaData, dialect) + generateModelFiles(destDir, "view", tableModelTemplate, schemaInfo.ViewsMetaData, dialect) + generateModelFiles(destDir, "enum", enumModelTemplate, schemaInfo.EnumsMetaData, dialect) fmt.Println("Done") - - return nil } -func generateSQLBuilderFiles(destDir, fileTypes, sqlBuilderTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) error { +func generateSQLBuilderFiles(destDir, fileTypes, sqlBuilderTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) { if len(metaData) == 0 { - return nil + return } fmt.Printf("Generating %s sql builder files...\n", fileTypes) - return generateGoFiles(destDir, fileTypes, sqlBuilderTemplate, metaData, dialect) + generateGoFiles(destDir, fileTypes, sqlBuilderTemplate, metaData, dialect) } -func generateModelFiles(destDir, fileTypes, modelTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) error { +func generateModelFiles(destDir, fileTypes, modelTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) { if len(metaData) == 0 { - return nil + return } fmt.Printf("Generating %s model files...\n", fileTypes) - return generateGoFiles(destDir, "model", modelTemplate, metaData, dialect) + generateGoFiles(destDir, "model", modelTemplate, metaData, dialect) } -func generateGoFiles(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) error { +func generateGoFiles(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) { modelDirPath := filepath.Join(dirPath, packageName) err := utils.EnsureDirPath(modelDirPath) - - if err != nil { - return err - } + utils.PanicOnError(err) autoGenWarning, err := GenerateTemplate(autoGenWarningTemplate, nil, dialect) - - if err != nil { - return err - } + utils.PanicOnError(err) for _, metaData := range metaDataList { text, err := GenerateTemplate(template, metaData, dialect, map[string]interface{}{"package": packageName}) - - if err != nil { - return err - } + utils.PanicOnError(err) err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...)) - - if err != nil { - return err - } + utils.PanicOnError(err) } - return nil + return } // GenerateTemplate generates template with template text and template data. diff --git a/generator/mysql/mysql_generator.go b/generator/mysql/mysql_generator.go index 4a4b4ca..75405ea 100644 --- a/generator/mysql/mysql_generator.go +++ b/generator/mysql/mysql_generator.go @@ -22,50 +22,34 @@ type DBConnection struct { } // Generate generates jet files at destination dir from database connection details -func Generate(destDir string, dbConn DBConnection) error { - db, err := openConnection(dbConn) - if err != nil { - return err - } +func Generate(destDir string, dbConn DBConnection) (err error) { + defer utils.ErrorCatch(&err) + + db := openConnection(dbConn) defer utils.DBClose(db) fmt.Println("Retrieving database information...") // No schemas in MySQL - dbInfo, err := metadata.GetSchemaMetaData(db, dbConn.DBName, &mySqlQuerySet{}) - - if err != nil { - return err - } + dbInfo := metadata.GetSchemaMetaData(db, dbConn.DBName, &mySqlQuerySet{}) genPath := path.Join(destDir, dbConn.DBName) - err = template.GenerateFiles(genPath, dbInfo, mysql.Dialect) - - if err != nil { - return err - } + template.GenerateFiles(genPath, dbInfo, mysql.Dialect) return nil } -func openConnection(dbConn DBConnection) (*sql.DB, error) { +func openConnection(dbConn DBConnection) *sql.DB { var connectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName) if dbConn.Params != "" { connectionString += "?" + dbConn.Params } - db, err := sql.Open("mysql", connectionString) - fmt.Println("Connecting to MySQL database: " + connectionString) - - if err != nil { - return nil, err - } + db, err := sql.Open("mysql", connectionString) + utils.PanicOnError(err) err = db.Ping() + utils.PanicOnError(err) - if err != nil { - return nil, err - } - - return db, nil + return db } diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index 3d146d4..a1ad8ec 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -3,6 +3,7 @@ package mysql import ( "database/sql" "github.com/go-jet/jet/generator/internal/metadata" + "github.com/go-jet/jet/internal/utils" "strings" ) @@ -50,13 +51,10 @@ WHERE c.table_schema = ? AND DATA_TYPE = 'enum'; ` } -func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.MetaData, error) { +func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData { rows, err := db.Query(m.ListOfEnumsQuery(), schemaName) - - if err != nil { - return nil, err - } + utils.PanicOnError(err) defer rows.Close() ret := []metadata.MetaData{} @@ -65,9 +63,7 @@ func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metad var enumName string var enumValues string err = rows.Scan(&enumName, &enumValues) - if err != nil { - return nil, err - } + utils.PanicOnError(err) enumValues = strings.Replace(enumValues[1:len(enumValues)-1], "'", "", -1) @@ -78,11 +74,8 @@ func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metad } err = rows.Err() + utils.PanicOnError(err) - if err != nil { - return nil, err - } - - return ret, nil + return ret } diff --git a/generator/postgres/postgres_generator.go b/generator/postgres/postgres_generator.go index d46d6f0..392a00b 100644 --- a/generator/postgres/postgres_generator.go +++ b/generator/postgres/postgres_generator.go @@ -25,31 +25,20 @@ type DBConnection struct { } // Generate generates jet files at destination dir from database connection details -func Generate(destDir string, dbConn DBConnection) error { +func Generate(destDir string, dbConn DBConnection) (err error) { + defer utils.ErrorCatch(&err) db, err := openConnection(dbConn) + utils.PanicOnError(err) defer utils.DBClose(db) - if err != nil { - return err - } - fmt.Println("Retrieving schema information...") - schemaInfo, err := metadata.GetSchemaMetaData(db, dbConn.SchemaName, &postgresQuerySet{}) - - if err != nil { - return err - } + schemaInfo := metadata.GetSchemaMetaData(db, dbConn.SchemaName, &postgresQuerySet{}) genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName) + template.GenerateFiles(genPath, schemaInfo, postgres.Dialect) - err = template.GenerateFiles(genPath, schemaInfo, postgres.Dialect) - - if err != nil { - return err - } - - return nil + return } func openConnection(dbConn DBConnection) (*sql.DB, error) { diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index f1c7457..ce4a083 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -3,6 +3,7 @@ package postgres import ( "database/sql" "github.com/go-jet/jet/generator/internal/metadata" + "github.com/go-jet/jet/internal/utils" ) // postgresQuerySet is dialect query set for PostgreSQL @@ -45,12 +46,9 @@ WHERE n.nspname = $1 ORDER BY n.nspname, t.typname, e.enumsortorder;` } -func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.MetaData, error) { +func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData { rows, err := db.Query(p.ListOfEnumsQuery(), schemaName) - - if err != nil { - return nil, err - } + utils.PanicOnError(err) defer rows.Close() enumsInfosMap := map[string][]string{} @@ -58,9 +56,7 @@ func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]me var enumName string var enumValue string err = rows.Scan(&enumName, &enumValue) - if err != nil { - return nil, err - } + utils.PanicOnError(err) enumValues := enumsInfosMap[enumName] @@ -70,10 +66,7 @@ func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]me } err = rows.Err() - - if err != nil { - return nil, err - } + utils.PanicOnError(err) ret := []metadata.MetaData{} @@ -84,5 +77,5 @@ func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]me }) } - return ret, nil + return ret } diff --git a/internal/utils/utils.go b/internal/utils/utils.go index ecc6471..a091973 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -2,6 +2,7 @@ package utils import ( "database/sql" + "fmt" "github.com/go-jet/jet/internal/3rdparty/snaker" "go/format" "os" @@ -146,3 +147,20 @@ func PanicOnError(err error) { panic(err) } } + +// ErrorCatch is used in defer to recover from panics and to set err +func ErrorCatch(err *error) { + recovered := recover() + + if recovered == nil { + return + } + + recoveredErr, isError := recovered.(error) + + if isError { + *err = recoveredErr + } else { + *err = fmt.Errorf("%v", recovered) + } +}