Improve generator error handling

This commit is contained in:
go-jet 2023-07-21 13:20:44 +02:00
parent b38b63d804
commit 06ecd73f67
12 changed files with 386 additions and 176 deletions

View file

@ -1,10 +1,11 @@
package main package main
//go:generate sh -c "printf 'package main\n\nconst version = \"'%s'\"' $(git describe --tags --abbrev=0) > version.go" //go:generate sh -c "printf 'package main\n\nconst version = \"'%s'\"\n' $(git describe --tags --abbrev=0) > version.go"
import ( import (
"flag" "flag"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/utils/errfmt"
"os" "os"
"strings" "strings"
@ -155,8 +156,8 @@ func main() {
} }
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(errfmt.Trace(err))
os.Exit(-5) os.Exit(2)
} }
} }
@ -192,7 +193,7 @@ func printErrorAndExit(error string) {
fmt.Println("\n", error) fmt.Println("\n", error)
fmt.Println() fmt.Println()
flag.Usage() flag.Usage()
os.Exit(-2) os.Exit(1)
} }
func getSource() string { func getSource() string {

View file

@ -16,21 +16,36 @@ const (
// DialectQuerySet is set of methods necessary to retrieve dialect metadata information // DialectQuerySet is set of methods necessary to retrieve dialect metadata information
type DialectQuerySet interface { type DialectQuerySet interface {
GetTablesMetaData(db *sql.DB, schemaName string, tableType TableType) []Table GetTablesMetaData(db *sql.DB, schemaName string, tableType TableType) ([]Table, error)
GetEnumsMetaData(db *sql.DB, schemaName string) []Enum GetEnumsMetaData(db *sql.DB, schemaName string) ([]Enum, error)
} }
// GetSchema retrieves Schema information from database // GetSchema retrieves Schema information from database
func GetSchema(db *sql.DB, querySet DialectQuerySet, schemaName string) Schema { func GetSchema(db *sql.DB, querySet DialectQuerySet, schemaName string) (Schema, error) {
tablesMetaData, err := querySet.GetTablesMetaData(db, schemaName, BaseTable)
if err != nil {
return Schema{}, fmt.Errorf("failed to get %s tables metadata: %w", schemaName, err)
}
viewMetaData, err := querySet.GetTablesMetaData(db, schemaName, ViewTable)
if err != nil {
return Schema{}, fmt.Errorf("failed to get %s view metadata: %w", schemaName, err)
}
enumsMetaData, err := querySet.GetEnumsMetaData(db, schemaName)
if err != nil {
return Schema{}, fmt.Errorf("failed to get %s enum metadata: %w", schemaName, err)
}
ret := Schema{ ret := Schema{
Name: schemaName, Name: schemaName,
TablesMetaData: querySet.GetTablesMetaData(db, schemaName, BaseTable), TablesMetaData: tablesMetaData,
ViewsMetaData: querySet.GetTablesMetaData(db, schemaName, ViewTable), ViewsMetaData: viewMetaData,
EnumsMetaData: querySet.GetEnumsMetaData(db, schemaName), EnumsMetaData: enumsMetaData,
} }
fmt.Println(" FOUND", len(ret.TablesMetaData), "table(s),", len(ret.ViewsMetaData), "view(s),", fmt.Println(" FOUND", len(ret.TablesMetaData), "table(s),", len(ret.ViewsMetaData), "view(s),",
len(ret.EnumsMetaData), "enum(s)") len(ret.EnumsMetaData), "enum(s)")
return ret return ret, nil
} }

View file

@ -2,13 +2,13 @@ package mysql
import ( import (
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"strings" "strings"
"github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/mysql"
mysqldr "github.com/go-sql-driver/mysql" mysqldr "github.com/go-sql-driver/mysql"
) )
@ -25,26 +25,28 @@ type DBConnection struct {
} }
// Generate generates jet files at destination dir from database connection details // Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection, generatorTemplate ...template.Template) (err error) { func Generate(destDir string, dbConn DBConnection, generatorTemplate ...template.Template) error {
defer utils.ErrorCatch(&err)
connectionString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName) connectionString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName)
if dbConn.Params != "" { if dbConn.Params != "" {
connectionString += "?" + dbConn.Params connectionString += "?" + dbConn.Params
} }
db := openConnection(connectionString) db, err := openConnection(connectionString)
if err != nil {
return fmt.Errorf("failed to open db connection: %w", err)
}
defer utils.DBClose(db) defer utils.DBClose(db)
generate(db, dbConn.DBName, destDir, generatorTemplate...) err = generate(db, dbConn.DBName, destDir, generatorTemplate...)
if err != nil {
return err
}
return nil return nil
} }
// GenerateDSN opens connection via DSN string and does everything what Generate does. // GenerateDSN opens connection via DSN string and does everything what Generate does.
func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error) { func GenerateDSN(dsn, destDir string, templates ...template.Template) error {
defer utils.ErrorCatch(&err)
// Special case for go mysql driver. It does not understand schema, // Special case for go mysql driver. It does not understand schema,
// so we need to trim it before passing to generator // so we need to trim it before passing to generator
// https://github.com/go-sql-driver/mysql#dsn-data-source-name // https://github.com/go-sql-driver/mysql#dsn-data-source-name
@ -54,39 +56,59 @@ func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error
} }
cfg, err := mysqldr.ParseDSN(dsn) cfg, err := mysqldr.ParseDSN(dsn)
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to parse DSN: %w", err)
}
if cfg.DBName == "" { if cfg.DBName == "" {
panic("database name is required") return errors.New("database name is required")
} }
db := openConnection(dsn) db, err := openConnection(dsn)
if err != nil {
return fmt.Errorf("failed to open db connection: %w", err)
}
defer utils.DBClose(db) defer utils.DBClose(db)
generate(db, cfg.DBName, destDir, templates...) err = generate(db, cfg.DBName, destDir, templates...)
if err != nil {
return fmt.Errorf("failed to generate: %w", err)
}
return nil return nil
} }
func openConnection(connectionString string) *sql.DB { func openConnection(connectionString string) (*sql.DB, error) {
fmt.Println("Connecting to MySQL database...") fmt.Println("Connecting to MySQL database...")
db, err := sql.Open("mysql", connectionString) db, err := sql.Open("mysql", connectionString)
throw.OnError(err) if err != nil {
return nil, fmt.Errorf("failed to open mysql connection: %w", err)
err = db.Ping()
throw.OnError(err)
return db
} }
func generate(db *sql.DB, dbName, destDir string, templates ...template.Template) { err = db.Ping()
if err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return db, nil
}
func generate(db *sql.DB, dbName, destDir string, templates ...template.Template) error {
fmt.Println("Retrieving database information...") fmt.Println("Retrieving database information...")
// No schemas in MySQL // No schemas in MySQL
schemaMetaData := metadata.GetSchema(db, &mySqlQuerySet{}, dbName) schemaMetaData, err := metadata.GetSchema(db, &mySqlQuerySet{}, dbName)
if err != nil {
return fmt.Errorf("failed to get '%s' database metadata: %w", dbName, err)
}
genTemplate := template.Default(mysql.Dialect) genTemplate := template.Default(mysql.Dialect)
if len(templates) > 0 { if len(templates) > 0 {
genTemplate = templates[0] genTemplate = templates[0]
} }
template.ProcessSchema(destDir, schemaMetaData, genTemplate) err = template.ProcessSchema(destDir, schemaMetaData, genTemplate)
if err != nil {
return fmt.Errorf("failed to process '%s' database: %w", schemaMetaData.Name, err)
}
return nil
} }

View file

@ -3,17 +3,17 @@ package mysql
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strings" "strings"
"github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/qrm"
) )
// mySqlQuerySet is dialect query set for MySQL // mySqlQuerySet is dialect query set for MySQL
type mySqlQuerySet struct{} type mySqlQuerySet struct{}
func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table { func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) ([]metadata.Table, error) {
query := ` query := `
SELECT table_name as "table.name" SELECT table_name as "table.name"
FROM INFORMATION_SCHEMA.tables FROM INFORMATION_SCHEMA.tables
@ -23,16 +23,21 @@ ORDER BY table_name;
var tables []metadata.Table var tables []metadata.Table
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables)
throw.OnError(err) if err != nil {
return nil, fmt.Errorf("failed to query %s metadata result: %w", tableType, err)
}
for i := range tables { for i := range tables {
tables[i].Columns = m.GetTableColumnsMetaData(db, schemaName, tables[i].Name) tables[i].Columns, err = m.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
if err != nil {
return nil, fmt.Errorf("failed to get '%s' table columns metadata: %w", tables[i].Name, err)
}
} }
return tables return tables, nil
} }
func (m mySqlQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column { func (m mySqlQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) {
query := ` query := `
SELECT COLUMN_NAME AS "column.Name", SELECT COLUMN_NAME AS "column.Name",
IS_NULLABLE = "YES" AS "column.IsNullable", IS_NULLABLE = "YES" AS "column.IsNullable",
@ -57,12 +62,14 @@ ORDER BY ordinal_position;
` `
var columns []metadata.Column var columns []metadata.Column
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns) _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns)
throw.OnError(err) if err != nil {
return nil, fmt.Errorf("failed to query %s column meta data: %w", tableName, err)
return columns
} }
func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum { return columns, nil
}
func (m mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.Enum, error) {
query := ` query := `
SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ) as "name", SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ) as "name",
SUBSTRING(c.COLUMN_TYPE,5) as "values" SUBSTRING(c.COLUMN_TYPE,5) as "values"
@ -76,7 +83,9 @@ WHERE c.table_schema = ? AND DATA_TYPE = 'enum';
} }
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult) _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult)
throw.OnError(err) if err != nil {
return nil, fmt.Errorf("failed to query enums meta data: %w", err)
}
var ret []metadata.Enum var ret []metadata.Enum
@ -89,5 +98,5 @@ WHERE c.table_schema = ? AND DATA_TYPE = 'enum';
}) })
} }
return ret return ret, nil
} }

View file

@ -10,7 +10,6 @@ import (
"github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/postgres"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
) )
@ -43,15 +42,18 @@ func Generate(destDir string, dbConn DBConnection, genTemplate ...template.Templ
} }
// GenerateDSN generates jet files using dsn connection string // GenerateDSN generates jet files using dsn connection string
func GenerateDSN(dsn, schema, destDir string, templates ...template.Template) (err error) { func GenerateDSN(dsn, schema, destDir string, templates ...template.Template) error {
defer utils.ErrorCatch(&err)
cfg, err := pgconn.ParseConfig(dsn) cfg, err := pgconn.ParseConfig(dsn)
throw.OnError(err) if err != nil {
if cfg.Database == "" { return fmt.Errorf("failed to parse config: %w", err)
panic("database name is required") }
if cfg.Database == "" {
return fmt.Errorf("database name is required")
}
db, err := openConnection(dsn)
if err != nil {
return fmt.Errorf("failed to open db connection: %w", err)
} }
db := openConnection(dsn)
defer utils.DBClose(db) defer utils.DBClose(db)
fmt.Println("Retrieving schema information...") fmt.Println("Retrieving schema information...")
@ -60,22 +62,33 @@ func GenerateDSN(dsn, schema, destDir string, templates ...template.Template) (e
generatorTemplate = templates[0] generatorTemplate = templates[0]
} }
schemaMetadata := metadata.GetSchema(db, &postgresQuerySet{}, schema) schemaMetadata, err := metadata.GetSchema(db, &postgresQuerySet{}, schema)
if err != nil {
return fmt.Errorf("failed to get '%s' schema metadata: %w", schema, err)
}
dirPath := path.Join(destDir, cfg.Database) dirPath := path.Join(destDir, cfg.Database)
template.ProcessSchema(dirPath, schemaMetadata, generatorTemplate) err = template.ProcessSchema(dirPath, schemaMetadata, generatorTemplate)
return if err != nil {
return fmt.Errorf("failed to generate schema %s: %d", schemaMetadata.Name, err)
} }
func openConnection(dsn string) *sql.DB { return nil
}
func openConnection(dsn string) (*sql.DB, error) {
fmt.Println("Connecting to postgres database...") fmt.Println("Connecting to postgres database...")
db, err := sql.Open("postgres", dsn) db, err := sql.Open("postgres", dsn)
throw.OnError(err) if err != nil {
return nil, fmt.Errorf("failed to open db connection: %w", err)
}
err = db.Ping() err = db.Ping()
throw.OnError(err) if err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
return db }
return db, nil
} }

View file

@ -3,16 +3,16 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/qrm"
) )
// postgresQuerySet is dialect query set for PostgreSQL // postgresQuerySet is dialect query set for PostgreSQL
type postgresQuerySet struct{} type postgresQuerySet struct{}
func (p postgresQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table { func (p postgresQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) ([]metadata.Table, error) {
query := ` query := `
SELECT table_name as "table.name" SELECT table_name as "table.name"
FROM information_schema.tables FROM information_schema.tables
@ -22,16 +22,21 @@ ORDER BY table_name;
var tables []metadata.Table var tables []metadata.Table
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables)
throw.OnError(err) if err != nil {
return nil, fmt.Errorf("failed to query %s metadata: %w", tableType, err)
}
for i := range tables { for i := range tables {
tables[i].Columns = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name) tables[i].Columns, err = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
if err != nil {
return nil, fmt.Errorf("failed to query %s columns metadata: %w", tableType, err)
}
} }
return tables return tables, nil
} }
func (p postgresQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column { func (p postgresQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) {
query := ` query := `
WITH primaryKeys AS ( WITH primaryKeys AS (
SELECT column_name SELECT column_name
@ -67,12 +72,14 @@ order by ordinal_position;
` `
var columns []metadata.Column var columns []metadata.Column
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns) _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns)
throw.OnError(err) if err != nil {
return nil, fmt.Errorf("failed to query '%s' columns metadata: %w", tableName, err)
return columns
} }
func (p postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum { return columns, nil
}
func (p postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.Enum, error) {
query := ` query := `
SELECT t.typname as "enum.name", SELECT t.typname as "enum.name",
e.enumlabel as "values" e.enumlabel as "values"
@ -85,7 +92,9 @@ ORDER BY n.nspname, t.typname, e.enumsortorder;`
var result []metadata.Enum var result []metadata.Enum
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result) _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result)
throw.OnError(err) if err != nil {
return nil, fmt.Errorf("failed to query enums metadata for schema '%s': %w", schemaName, err)
return result }
return result, nil
} }

View file

@ -5,7 +5,6 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/qrm"
"strings" "strings"
) )
@ -13,7 +12,7 @@ import (
// sqliteQuerySet is dialect query set for SQLite // sqliteQuerySet is dialect query set for SQLite
type sqliteQuerySet struct{} type sqliteQuerySet struct{}
func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table { func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) ([]metadata.Table, error) {
query := ` query := `
SELECT name as "table.name" SELECT name as "table.name"
FROM sqlite_master FROM sqlite_master
@ -29,16 +28,21 @@ func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTy
var tables []metadata.Table var tables []metadata.Table
_, err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables) _, err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables)
throw.OnError(err) if err != nil {
return nil, fmt.Errorf("failed to query %s metadata: %w", schemaName, err)
}
for i := range tables { for i := range tables {
tables[i].Columns = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name) tables[i].Columns, err = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
if err != nil {
return nil, fmt.Errorf("failed to query column metadata: %w", err)
}
} }
return tables return tables, nil
} }
func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column { func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) {
query := fmt.Sprintf(`select * from pragma_table_info(?);`) query := fmt.Sprintf(`select * from pragma_table_info(?);`)
var columnInfos []struct { var columnInfos []struct {
Name string Name string
@ -48,7 +52,9 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t
} }
_, err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos) _, err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos)
throw.OnError(err) if err != nil {
return nil, fmt.Errorf("failed to query '%s' column metadata: %w", tableName, err)
}
var columns []metadata.Column var columns []metadata.Column
@ -67,7 +73,7 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t
}) })
} }
return columns return columns, nil
} }
// will convert VARCHAR(10) -> VARCHAR, etc... // will convert VARCHAR(10) -> VARCHAR, etc...
@ -75,6 +81,6 @@ func getColumnType(columnType string) string {
return strings.TrimSpace(strings.Split(columnType, "(")[0]) return strings.TrimSpace(strings.Split(columnType, "(")[0])
} }
func (p sqliteQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum { func (p sqliteQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.Enum, error) {
return nil return nil, nil
} }

View file

@ -6,16 +6,15 @@ import (
"github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/sqlite" "github.com/go-jet/jet/v2/sqlite"
) )
// GenerateDSN generates jet files using dsn connection string // GenerateDSN generates jet files using dsn connection string
func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error) { func GenerateDSN(dsn, destDir string, templates ...template.Template) error {
defer utils.ErrorCatch(&err)
db, err := sql.Open("sqlite3", dsn) db, err := sql.Open("sqlite3", dsn)
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to open sqlite connection: %w", err)
}
defer utils.DBClose(db) defer utils.DBClose(db)
fmt.Println("Retrieving schema information...") fmt.Println("Retrieving schema information...")
@ -25,8 +24,15 @@ func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error
generatorTemplate = templates[0] generatorTemplate = templates[0]
} }
schemaMetadata := metadata.GetSchema(db, &sqliteQuerySet{}, "") schemaMetadata, err := metadata.GetSchema(db, &sqliteQuerySet{}, "")
if err != nil {
template.ProcessSchema(destDir, schemaMetadata, generatorTemplate) return fmt.Errorf("failed to query database metadata: %w", err)
return }
err = template.ProcessSchema(destDir, schemaMetadata, generatorTemplate)
if err != nil {
return fmt.Errorf("failed to process database %s: %w", schemaMetadata.Name, err)
}
return nil
} }

View file

@ -2,6 +2,7 @@ package template
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"path" "path"
"strings" "strings"
@ -10,13 +11,12 @@ import (
"github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/jet"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
) )
// ProcessSchema will process schema metadata and constructs go files using generator Template // ProcessSchema will process schema metadata and constructs go files using generator Template
func ProcessSchema(dirPath string, schemaMetaData metadata.Schema, generatorTemplate Template) { func ProcessSchema(dirPath string, schemaMetaData metadata.Schema, generatorTemplate Template) error {
if schemaMetaData.IsEmpty() { if schemaMetaData.IsEmpty() {
return return nil
} }
schemaTemplate := generatorTemplate.Schema(schemaMetaData) schemaTemplate := generatorTemplate.Schema(schemaMetaData)
@ -25,48 +25,87 @@ func ProcessSchema(dirPath string, schemaMetaData metadata.Schema, generatorTemp
fmt.Println("Destination directory:", schemaPath) fmt.Println("Destination directory:", schemaPath)
fmt.Println("Cleaning up destination directory...") fmt.Println("Cleaning up destination directory...")
err := utils.CleanUpGeneratedFiles(schemaPath) err := utils.CleanUpGeneratedFiles(schemaPath)
throw.OnError(err) if err != nil {
return errors.New("failed to cleanup generated files")
processModel(schemaPath, schemaMetaData, schemaTemplate)
processSQLBuilder(schemaPath, generatorTemplate.Dialect, schemaMetaData, schemaTemplate)
} }
func processModel(dirPath string, schemaMetaData metadata.Schema, schemaTemplate Schema) { err = processModel(schemaPath, schemaMetaData, schemaTemplate)
if err != nil {
return fmt.Errorf("failed to generate model types: %w", err)
}
err = processSQLBuilder(schemaPath, generatorTemplate.Dialect, schemaMetaData, schemaTemplate)
if err != nil {
return fmt.Errorf("failed to generate sql builder types: %w", err)
}
return nil
}
func processModel(dirPath string, schemaMetaData metadata.Schema, schemaTemplate Schema) error {
modelTemplate := schemaTemplate.Model modelTemplate := schemaTemplate.Model
if modelTemplate.Skip { if modelTemplate.Skip {
fmt.Println("Skipping the generation of model types.") fmt.Println("Skipping the generation of model types.")
return return nil
} }
modelDirPath := path.Join(dirPath, modelTemplate.Path) modelDirPath := path.Join(dirPath, modelTemplate.Path)
err := utils.EnsureDirPath(modelDirPath) err := utils.EnsureDirPath(modelDirPath)
throw.OnError(err) if err != nil {
return fmt.Errorf("destination dir path does not exist: %w", err)
processTableModels("table", modelDirPath, schemaMetaData.TablesMetaData, modelTemplate)
processTableModels("view", modelDirPath, schemaMetaData.ViewsMetaData, modelTemplate)
processEnumModels(modelDirPath, schemaMetaData.EnumsMetaData, modelTemplate)
} }
func processSQLBuilder(dirPath string, dialect jet.Dialect, schemaMetaData metadata.Schema, schemaTemplate Schema) { err = processTableModels("table", modelDirPath, schemaMetaData.TablesMetaData, modelTemplate)
if err != nil {
return fmt.Errorf("failed to generate table model types: %w", err)
}
err = processTableModels("view", modelDirPath, schemaMetaData.ViewsMetaData, modelTemplate)
if err != nil {
return fmt.Errorf("failed to generate view model types: %w", err)
}
err = processEnumModels(modelDirPath, schemaMetaData.EnumsMetaData, modelTemplate)
if err != nil {
return fmt.Errorf("failed to process enum types: %w", err)
}
return nil
}
func processSQLBuilder(dirPath string, dialect jet.Dialect, schemaMetaData metadata.Schema, schemaTemplate Schema) error {
sqlBuilderTemplate := schemaTemplate.SQLBuilder sqlBuilderTemplate := schemaTemplate.SQLBuilder
if sqlBuilderTemplate.Skip { if sqlBuilderTemplate.Skip {
fmt.Println("Skipping the generation of SQL Builder types.") fmt.Println("Skipping the generation of SQL Builder types.")
return return nil
} }
sqlBuilderPath := path.Join(dirPath, sqlBuilderTemplate.Path) sqlBuilderPath := path.Join(dirPath, sqlBuilderTemplate.Path)
processTableSQLBuilder("table", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.TablesMetaData, sqlBuilderTemplate) err := processTableSQLBuilder("table", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.TablesMetaData, sqlBuilderTemplate)
processTableSQLBuilder("view", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.ViewsMetaData, sqlBuilderTemplate) if err != nil {
processEnumSQLBuilder(sqlBuilderPath, dialect, schemaMetaData.EnumsMetaData, sqlBuilderTemplate) return fmt.Errorf("failed to process table sql builder types: %w", err)
} }
func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []metadata.Enum, sqlBuilder SQLBuilder) { err = processTableSQLBuilder("view", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.ViewsMetaData, sqlBuilderTemplate)
if err != nil {
return fmt.Errorf("failed to process view sql builder types: %w", err)
}
err = processEnumSQLBuilder(sqlBuilderPath, dialect, schemaMetaData.EnumsMetaData, sqlBuilderTemplate)
if err != nil {
return fmt.Errorf("failed to process enum types: %w", err)
}
return nil
}
func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []metadata.Enum, sqlBuilder SQLBuilder) error {
if len(enumsMetaData) == 0 { if len(enumsMetaData) == 0 {
return return nil
} }
fmt.Printf("Generating enum sql builder files\n") fmt.Printf("Generating enum sql builder files\n")
@ -81,7 +120,9 @@ func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []
enumSQLBuilderPath := path.Join(dirPath, enumTemplate.Path) enumSQLBuilderPath := path.Join(dirPath, enumTemplate.Path)
err := utils.EnsureDirPath(enumSQLBuilderPath) err := utils.EnsureDirPath(enumSQLBuilderPath)
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to create enum sql builder directory - %s: %w", enumSQLBuilderPath, err)
}
text, err := generateTemplate( text, err := generateTemplate(
autoGenWarningTemplate+enumSQLBuilderTemplate, autoGenWarningTemplate+enumSQLBuilderTemplate,
@ -100,21 +141,27 @@ func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []
return enumTemplate.ValueName(enumValue) return enumTemplate.ValueName(enumValue)
}, },
}) })
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to generete enum type %s: %w", enumTemplate.FileName, err)
}
err = utils.SaveGoFile(enumSQLBuilderPath, enumTemplate.FileName, text) err = utils.SaveGoFile(enumSQLBuilderPath, enumTemplate.FileName, text)
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to format and save '%s' enum type : %w", enumTemplate.FileName, err)
} }
} }
return nil
}
func processTableSQLBuilder(fileTypes, dirPath string, func processTableSQLBuilder(fileTypes, dirPath string,
dialect jet.Dialect, dialect jet.Dialect,
schemaMetaData metadata.Schema, schemaMetaData metadata.Schema,
tablesMetaData []metadata.Table, tablesMetaData []metadata.Table,
sqlBuilderTemplate SQLBuilder) { sqlBuilderTemplate SQLBuilder) error {
if len(tablesMetaData) == 0 { if len(tablesMetaData) == 0 {
return return nil
} }
fmt.Printf("Generating %s sql builder files\n", fileTypes) fmt.Printf("Generating %s sql builder files\n", fileTypes)
@ -122,7 +169,6 @@ func processTableSQLBuilder(fileTypes, dirPath string,
var generatedBuilders []TableSQLBuilder var generatedBuilders []TableSQLBuilder
for _, tableMetaData := range tablesMetaData { for _, tableMetaData := range tablesMetaData {
var tableSQLBuilder TableSQLBuilder var tableSQLBuilder TableSQLBuilder
if fileTypes == "view" { if fileTypes == "view" {
@ -138,7 +184,9 @@ func processTableSQLBuilder(fileTypes, dirPath string,
tableSQLBuilderPath := path.Join(dirPath, tableSQLBuilder.Path) tableSQLBuilderPath := path.Join(dirPath, tableSQLBuilder.Path)
err := utils.EnsureDirPath(tableSQLBuilderPath) err := utils.EnsureDirPath(tableSQLBuilderPath)
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to create table sql builder directory - %s: %w", tableSQLBuilderPath, err)
}
text, err := generateTemplate( text, err := generateTemplate(
autoGenWarningTemplate+tableSQLBuilderTemplate, autoGenWarningTemplate+tableSQLBuilderTemplate,
@ -168,20 +216,30 @@ func processTableSQLBuilder(fileTypes, dirPath string,
return insertedRowAlias(dialect) return insertedRowAlias(dialect)
}, },
}) })
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to generate table sql builder type %s: %w", tableSQLBuilder.TypeName, err)
}
err = utils.SaveGoFile(tableSQLBuilderPath, tableSQLBuilder.FileName, text) err = utils.SaveGoFile(tableSQLBuilderPath, tableSQLBuilder.FileName, text)
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to format and save generated sql builder type '%s': %w", tableSQLBuilder.FileName, err)
}
generatedBuilders = append(generatedBuilders, tableSQLBuilder) generatedBuilders = append(generatedBuilders, tableSQLBuilder)
} }
if len(generatedBuilders) > 0 { err := generateUseSchemaFunc(dirPath, fileTypes, generatedBuilders)
generateUseSchemaFunc(dirPath, fileTypes, generatedBuilders) if err != nil {
} return fmt.Errorf("failed to generate UseSchema function")
} }
func generateUseSchemaFunc(dirPath, fileTypes string, builders []TableSQLBuilder) { return nil
}
func generateUseSchemaFunc(dirPath, fileTypes string, builders []TableSQLBuilder) error {
if len(builders) == 0 {
return nil
}
text, err := generateTemplate( text, err := generateTemplate(
autoGenWarningTemplate+tableSqlBuilderSetSchemaTemplate, autoGenWarningTemplate+tableSqlBuilderSetSchemaTemplate,
@ -191,13 +249,19 @@ func generateUseSchemaFunc(dirPath, fileTypes string, builders []TableSQLBuilder
"type": func() string { return fileTypes }, "type": func() string { return fileTypes },
}, },
) )
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to generate use schema template: %w", err)
}
basePath := path.Join(dirPath, builders[0].Path) basePath := path.Join(dirPath, builders[0].Path)
fileName := fileTypes + "_use_schema" fileName := fileTypes + "_use_schema"
err = utils.SaveGoFile(basePath, fileName, text) err = utils.SaveGoFile(basePath, fileName, text)
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to save %s file: %w", fileName, err)
}
return nil
} }
func insertedRowAlias(dialect jet.Dialect) string { func insertedRowAlias(dialect jet.Dialect) string {
@ -208,9 +272,9 @@ func insertedRowAlias(dialect jet.Dialect) string {
return "excluded" return "excluded"
} }
func processTableModels(fileTypes, modelDirPath string, tablesMetaData []metadata.Table, modelTemplate Model) { func processTableModels(fileTypes, modelDirPath string, tablesMetaData []metadata.Table, modelTemplate Model) error {
if len(tablesMetaData) == 0 { if len(tablesMetaData) == 0 {
return return nil
} }
fmt.Printf("Generating %s model files...\n", fileTypes) fmt.Printf("Generating %s model files...\n", fileTypes)
@ -244,16 +308,22 @@ func processTableModels(fileTypes, modelDirPath string, tablesMetaData []metadat
return tableTemplate.Field(columnMetaData) return tableTemplate.Field(columnMetaData)
}, },
}) })
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to generate model type '%s': %w", tableMetaData.Name, err)
}
err = utils.SaveGoFile(modelDirPath, tableTemplate.FileName, text) err = utils.SaveGoFile(modelDirPath, tableTemplate.FileName, text)
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to save '%s' model type: %w", tableTemplate.FileName, err)
} }
} }
func processEnumModels(modelDir string, enumsMetaData []metadata.Enum, modelTemplate Model) { return nil
}
func processEnumModels(modelDir string, enumsMetaData []metadata.Enum, modelTemplate Model) error {
if len(enumsMetaData) == 0 { if len(enumsMetaData) == 0 {
return return nil
} }
fmt.Print("Generating enum model files...\n") fmt.Print("Generating enum model files...\n")
@ -278,23 +348,30 @@ func processEnumModels(modelDir string, enumsMetaData []metadata.Enum, modelTemp
return enumTemplate.ValueName(value) return enumTemplate.ValueName(value)
}, },
}) })
throw.OnError(err)
if err != nil {
return fmt.Errorf("failed to generate enum type '%s': %w", enumMetaData.Name, err)
}
err = utils.SaveGoFile(modelDir, enumTemplate.FileName, text) err = utils.SaveGoFile(modelDir, enumTemplate.FileName, text)
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to save '%s' enum type: %w", enumTemplate.FileName, err)
} }
} }
return nil
}
func generateTemplate(templateText string, templateData interface{}, funcMap template.FuncMap) ([]byte, error) { func generateTemplate(templateText string, templateData interface{}, funcMap template.FuncMap) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(funcMap).Parse(templateText) t, err := template.New("sqlBuilderTableTemplate").Funcs(funcMap).Parse(templateText)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to parse template: %w", err)
} }
var buf bytes.Buffer var buf bytes.Buffer
if err := t.Execute(&buf, templateData); err != nil { if err := t.Execute(&buf, templateData); err != nil {
return nil, err return nil, fmt.Errorf("failed to generate template: %w", err)
} }
return buf.Bytes(), nil return buf.Bytes(), nil

View file

@ -0,0 +1,10 @@
package errfmt
import (
"strings"
)
// Trace returns well formatted wrapped error trace string
func Trace(err error) string {
return "Error trace:\n" + " - " + strings.Replace(err.Error(), ": ", ":\n - ", -1)
}

View file

@ -39,14 +39,16 @@ func SaveGoFile(dirPath, fileName string, text []byte) error {
defer file.Close() defer file.Close()
p, err := format.Source(text) p, err := format.Source(text)
// if there is a format error we will write unformulated text for debug purposes
if err != nil { if err != nil {
return err file.Write(text)
return fmt.Errorf("failed to format '%s', check '%s' for syntax errors: %w", fileName, newGoFilePath, err)
} }
_, err = file.Write(p) _, err = file.Write(p)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to save '%s' file: %w", newGoFilePath, err)
} }
return nil return nil
@ -58,7 +60,7 @@ func EnsureDirPath(dirPath string) error {
err := os.MkdirAll(dirPath, os.ModePerm) err := os.MkdirAll(dirPath, os.ModePerm)
if err != nil { if err != nil {
return err return fmt.Errorf("can't create directory - %s: %w", dirPath, err)
} }
} }

View file

@ -8,13 +8,13 @@ import (
"github.com/go-jet/jet/v2/generator/mysql" "github.com/go-jet/jet/v2/generator/mysql"
"github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/generator/postgres"
"github.com/go-jet/jet/v2/generator/sqlite" "github.com/go-jet/jet/v2/generator/sqlite"
"github.com/go-jet/jet/v2/internal/utils/errfmt"
"github.com/go-jet/jet/v2/tests/internal/utils/repo" "github.com/go-jet/jet/v2/tests/internal/utils/repo"
"io/ioutil" "io/ioutil"
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/tests/dbconfig" "github.com/go-jet/jet/v2/tests/dbconfig"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/jackc/pgx/v4/stdlib" _ "github.com/jackc/pgx/v4/stdlib"
@ -40,39 +40,65 @@ const (
) )
func main() { func main() {
var err error
switch strings.ToLower(testSuite) { switch strings.ToLower(testSuite) {
case Postgres: case Postgres:
initPostgresDB(Postgres, dbconfig.PostgresConnectString) err = initPostgresDB(Postgres, dbconfig.PostgresConnectString)
case Cockroach: case Cockroach:
initPostgresDB(Cockroach, dbconfig.CockroachConnectString) err = initPostgresDB(Cockroach, dbconfig.CockroachConnectString)
case MySql: case MySql:
initMySQLDB(false) err = initMySQLDB(false)
case MariaDB: case MariaDB:
initMySQLDB(true) err = initMySQLDB(true)
case Sqlite: case Sqlite:
initSQLiteDB() err = initSQLiteDB()
case "all": case "all":
initPostgresDB(Cockroach, dbconfig.CockroachConnectString) err = initPostgresDB(Cockroach, dbconfig.CockroachConnectString)
initPostgresDB(Postgres, dbconfig.PostgresConnectString) if err != nil {
initMySQLDB(false) break
initMySQLDB(true) }
initSQLiteDB() err = initPostgresDB(Postgres, dbconfig.PostgresConnectString)
if err != nil {
break
}
err = initMySQLDB(false)
if err != nil {
break
}
err = initMySQLDB(true)
if err != nil {
break
}
err = initSQLiteDB()
default: default:
panic("invalid testsuite flag. Test suite name (postgres, mysql, mariadb, cockroach, sqlite or all)") panic("invalid testsuite flag. Test suite name (postgres, mysql, mariadb, cockroach, sqlite or all)")
} }
if err != nil {
fmt.Println(errfmt.Trace(err))
}
} }
func initSQLiteDB() { func initSQLiteDB() error {
err := sqlite.GenerateDSN(dbconfig.SakilaDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/sakila")) err := sqlite.GenerateDSN(dbconfig.SakilaDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/sakila"))
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to generate sqlite sakila database types: %w", err)
}
err = sqlite.GenerateDSN(dbconfig.ChinookDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/chinook")) err = sqlite.GenerateDSN(dbconfig.ChinookDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/chinook"))
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to generate sqlite chinook database types: %w", err)
}
err = sqlite.GenerateDSN(dbconfig.TestSampleDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/test_sample")) err = sqlite.GenerateDSN(dbconfig.TestSampleDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/test_sample"))
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to generate sqlite test_sample database types: %w", err)
} }
func initMySQLDB(isMariaDB bool) { return nil
}
func initMySQLDB(isMariaDB bool) error {
mySQLDBs := []string{ mySQLDBs := []string{
"dvds", "dvds",
@ -104,7 +130,9 @@ func initMySQLDB(isMariaDB bool) {
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
err := cmd.Run() err := cmd.Run()
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to initialize mysql database %s: %w", dbName, err)
}
err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{ err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{
Host: host, Host: host,
@ -114,19 +142,20 @@ func initMySQLDB(isMariaDB bool) {
DBName: dbName, DBName: dbName,
}) })
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to generate jet types for '%s' database: %w", dbName, err)
} }
} }
func initPostgresDB(dbType string, connectionString string) { return nil
}
func initPostgresDB(dbType string, connectionString string) error {
db, err := sql.Open("postgres", connectionString) db, err := sql.Open("postgres", connectionString)
if err != nil { if err != nil {
panic("Failed to connect to test db: " + err.Error()) return fmt.Errorf("failed to open '%s' db connection '%s': %w", dbType, connectionString, err)
} }
defer func() { defer db.Close()
err := db.Close()
printOnError(err)
}()
schemaNames := []string{ schemaNames := []string{
"northwind", "northwind",
@ -139,31 +168,43 @@ func initPostgresDB(dbType string, connectionString string) {
for _, schemaName := range schemaNames { for _, schemaName := range schemaNames {
fmt.Println("\nInitializing", schemaName, "schema...") fmt.Println("\nInitializing", schemaName, "schema...")
execFile(db, fmt.Sprintf("./testdata/init/%s/%s.sql", dbType, schemaName)) err = execFile(db, fmt.Sprintf("./testdata/init/%s/%s.sql", dbType, schemaName))
if err != nil {
return fmt.Errorf("failed to execute sql file: %w", err)
}
err = postgres.GenerateDSN(connectionString, schemaName, "./.gentestdata") err = postgres.GenerateDSN(connectionString, schemaName, "./.gentestdata")
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to generate jet types: %w", err)
} }
} }
func execFile(db *sql.DB, sqlFilePath string) { return nil
}
func execFile(db *sql.DB, sqlFilePath string) error {
testSampleSql, err := ioutil.ReadFile(sqlFilePath) testSampleSql, err := ioutil.ReadFile(sqlFilePath)
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to read sql file - %s: %w", sqlFilePath, err)
}
err = execInTx(db, func(tx *sql.Tx) error { err = execInTx(db, func(tx *sql.Tx) error {
_, err := tx.Exec(string(testSampleSql)) _, err := tx.Exec(string(testSampleSql))
return err return err
}) })
throw.OnError(err) if err != nil {
return fmt.Errorf("failed to execute sql file - %s: %w", sqlFilePath, err)
}
return nil
} }
func execInTx(db *sql.DB, f func(tx *sql.Tx) error) error { func execInTx(db *sql.DB, f func(tx *sql.Tx) error) error {
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ tx, err := db.BeginTx(context.Background(), &sql.TxOptions{
Isolation: sql.LevelReadUncommitted, // to speed up initialization of test database Isolation: sql.LevelReadUncommitted, // to speed up initialization of test database
}) })
if err != nil { if err != nil {
return err return fmt.Errorf("failed to start transaction: %w", err)
} }
err = f(tx) err = f(tx)
@ -173,11 +214,10 @@ func execInTx(db *sql.DB, f func(tx *sql.Tx) error) error {
return err return err
} }
return tx.Commit() err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit transaction")
} }
func printOnError(err error) { return nil
if err != nil {
fmt.Println(err.Error())
}
} }