Improve generator error handling
This commit is contained in:
parent
b38b63d804
commit
06ecd73f67
12 changed files with 386 additions and 176 deletions
|
|
@ -1,10 +1,11 @@
|
|||
package main
|
||||
|
||||
//go:generate sh -c "printf 'package main\n\nconst version = \"'%s'\"' $(git describe --tags --abbrev=0) > version.go"
|
||||
//go:generate sh -c "printf 'package main\n\nconst version = \"'%s'\"\n' $(git describe --tags --abbrev=0) > version.go"
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/go-jet/jet/v2/internal/utils/errfmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
|
|
@ -155,8 +156,8 @@ func main() {
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
os.Exit(-5)
|
||||
fmt.Println(errfmt.Trace(err))
|
||||
os.Exit(2)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -192,7 +193,7 @@ func printErrorAndExit(error string) {
|
|||
fmt.Println("\n", error)
|
||||
fmt.Println()
|
||||
flag.Usage()
|
||||
os.Exit(-2)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func getSource() string {
|
||||
|
|
|
|||
|
|
@ -14,23 +14,38 @@ const (
|
|||
ViewTable TableType = "VIEW"
|
||||
)
|
||||
|
||||
// DialectQuerySet is set of methods necessary to retrieve dialect meta data information
|
||||
// DialectQuerySet is set of methods necessary to retrieve dialect metadata information
|
||||
type DialectQuerySet interface {
|
||||
GetTablesMetaData(db *sql.DB, schemaName string, tableType TableType) []Table
|
||||
GetEnumsMetaData(db *sql.DB, schemaName string) []Enum
|
||||
GetTablesMetaData(db *sql.DB, schemaName string, tableType TableType) ([]Table, error)
|
||||
GetEnumsMetaData(db *sql.DB, schemaName string) ([]Enum, error)
|
||||
}
|
||||
|
||||
// GetSchema retrieves Schema information from database
|
||||
func GetSchema(db *sql.DB, querySet DialectQuerySet, schemaName string) Schema {
|
||||
func GetSchema(db *sql.DB, querySet DialectQuerySet, schemaName string) (Schema, error) {
|
||||
tablesMetaData, err := querySet.GetTablesMetaData(db, schemaName, BaseTable)
|
||||
if err != nil {
|
||||
return Schema{}, fmt.Errorf("failed to get %s tables metadata: %w", schemaName, err)
|
||||
}
|
||||
|
||||
viewMetaData, err := querySet.GetTablesMetaData(db, schemaName, ViewTable)
|
||||
if err != nil {
|
||||
return Schema{}, fmt.Errorf("failed to get %s view metadata: %w", schemaName, err)
|
||||
}
|
||||
|
||||
enumsMetaData, err := querySet.GetEnumsMetaData(db, schemaName)
|
||||
if err != nil {
|
||||
return Schema{}, fmt.Errorf("failed to get %s enum metadata: %w", schemaName, err)
|
||||
}
|
||||
|
||||
ret := Schema{
|
||||
Name: schemaName,
|
||||
TablesMetaData: querySet.GetTablesMetaData(db, schemaName, BaseTable),
|
||||
ViewsMetaData: querySet.GetTablesMetaData(db, schemaName, ViewTable),
|
||||
EnumsMetaData: querySet.GetEnumsMetaData(db, schemaName),
|
||||
TablesMetaData: tablesMetaData,
|
||||
ViewsMetaData: viewMetaData,
|
||||
EnumsMetaData: enumsMetaData,
|
||||
}
|
||||
|
||||
fmt.Println(" FOUND", len(ret.TablesMetaData), "table(s),", len(ret.ViewsMetaData), "view(s),",
|
||||
len(ret.EnumsMetaData), "enum(s)")
|
||||
|
||||
return ret
|
||||
return ret, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ package mysql
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/go-jet/jet/v2/generator/metadata"
|
||||
"github.com/go-jet/jet/v2/generator/template"
|
||||
"github.com/go-jet/jet/v2/internal/utils"
|
||||
"github.com/go-jet/jet/v2/internal/utils/throw"
|
||||
"github.com/go-jet/jet/v2/mysql"
|
||||
mysqldr "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
|
@ -25,26 +25,28 @@ type DBConnection struct {
|
|||
}
|
||||
|
||||
// Generate generates jet files at destination dir from database connection details
|
||||
func Generate(destDir string, dbConn DBConnection, generatorTemplate ...template.Template) (err error) {
|
||||
defer utils.ErrorCatch(&err)
|
||||
|
||||
func Generate(destDir string, dbConn DBConnection, generatorTemplate ...template.Template) error {
|
||||
connectionString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName)
|
||||
if dbConn.Params != "" {
|
||||
connectionString += "?" + dbConn.Params
|
||||
}
|
||||
|
||||
db := openConnection(connectionString)
|
||||
db, err := openConnection(connectionString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open db connection: %w", err)
|
||||
}
|
||||
defer utils.DBClose(db)
|
||||
|
||||
generate(db, dbConn.DBName, destDir, generatorTemplate...)
|
||||
err = generate(db, dbConn.DBName, destDir, generatorTemplate...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateDSN opens connection via DSN string and does everything what Generate does.
|
||||
func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error) {
|
||||
defer utils.ErrorCatch(&err)
|
||||
|
||||
func GenerateDSN(dsn, destDir string, templates ...template.Template) error {
|
||||
// Special case for go mysql driver. It does not understand schema,
|
||||
// so we need to trim it before passing to generator
|
||||
// https://github.com/go-sql-driver/mysql#dsn-data-source-name
|
||||
|
|
@ -54,39 +56,59 @@ func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error
|
|||
}
|
||||
|
||||
cfg, err := mysqldr.ParseDSN(dsn)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse DSN: %w", err)
|
||||
}
|
||||
if cfg.DBName == "" {
|
||||
panic("database name is required")
|
||||
return errors.New("database name is required")
|
||||
}
|
||||
|
||||
db := openConnection(dsn)
|
||||
db, err := openConnection(dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open db connection: %w", err)
|
||||
}
|
||||
defer utils.DBClose(db)
|
||||
|
||||
generate(db, cfg.DBName, destDir, templates...)
|
||||
err = generate(db, cfg.DBName, destDir, templates...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func openConnection(connectionString string) *sql.DB {
|
||||
func openConnection(connectionString string) (*sql.DB, error) {
|
||||
fmt.Println("Connecting to MySQL database...")
|
||||
db, err := sql.Open("mysql", connectionString)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open mysql connection: %w", err)
|
||||
}
|
||||
|
||||
err = db.Ping()
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return db
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func generate(db *sql.DB, dbName, destDir string, templates ...template.Template) {
|
||||
func generate(db *sql.DB, dbName, destDir string, templates ...template.Template) error {
|
||||
fmt.Println("Retrieving database information...")
|
||||
// No schemas in MySQL
|
||||
schemaMetaData := metadata.GetSchema(db, &mySqlQuerySet{}, dbName)
|
||||
schemaMetaData, err := metadata.GetSchema(db, &mySqlQuerySet{}, dbName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get '%s' database metadata: %w", dbName, err)
|
||||
}
|
||||
|
||||
genTemplate := template.Default(mysql.Dialect)
|
||||
if len(templates) > 0 {
|
||||
genTemplate = templates[0]
|
||||
}
|
||||
|
||||
template.ProcessSchema(destDir, schemaMetaData, genTemplate)
|
||||
err = template.ProcessSchema(destDir, schemaMetaData, genTemplate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process '%s' database: %w", schemaMetaData.Name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,17 +3,17 @@ package mysql
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/go-jet/jet/v2/generator/metadata"
|
||||
"github.com/go-jet/jet/v2/internal/utils/throw"
|
||||
"github.com/go-jet/jet/v2/qrm"
|
||||
)
|
||||
|
||||
// mySqlQuerySet is dialect query set for MySQL
|
||||
type mySqlQuerySet struct{}
|
||||
|
||||
func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table {
|
||||
func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) ([]metadata.Table, error) {
|
||||
query := `
|
||||
SELECT table_name as "table.name"
|
||||
FROM INFORMATION_SCHEMA.tables
|
||||
|
|
@ -23,16 +23,21 @@ ORDER BY table_name;
|
|||
var tables []metadata.Table
|
||||
|
||||
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables)
|
||||
throw.OnError(err)
|
||||
|
||||
for i := range tables {
|
||||
tables[i].Columns = m.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query %s metadata result: %w", tableType, err)
|
||||
}
|
||||
|
||||
return tables
|
||||
for i := range tables {
|
||||
tables[i].Columns, err = m.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get '%s' table columns metadata: %w", tables[i].Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func (m mySqlQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column {
|
||||
func (m mySqlQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) {
|
||||
query := `
|
||||
SELECT COLUMN_NAME AS "column.Name",
|
||||
IS_NULLABLE = "YES" AS "column.IsNullable",
|
||||
|
|
@ -57,12 +62,14 @@ ORDER BY ordinal_position;
|
|||
`
|
||||
var columns []metadata.Column
|
||||
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query %s column meta data: %w", tableName, err)
|
||||
}
|
||||
|
||||
return columns
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum {
|
||||
func (m mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.Enum, error) {
|
||||
query := `
|
||||
SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ) as "name",
|
||||
SUBSTRING(c.COLUMN_TYPE,5) as "values"
|
||||
|
|
@ -76,7 +83,9 @@ WHERE c.table_schema = ? AND DATA_TYPE = 'enum';
|
|||
}
|
||||
|
||||
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query enums meta data: %w", err)
|
||||
}
|
||||
|
||||
var ret []metadata.Enum
|
||||
|
||||
|
|
@ -89,5 +98,5 @@ WHERE c.table_schema = ? AND DATA_TYPE = 'enum';
|
|||
})
|
||||
}
|
||||
|
||||
return ret
|
||||
return ret, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import (
|
|||
"github.com/go-jet/jet/v2/generator/metadata"
|
||||
"github.com/go-jet/jet/v2/generator/template"
|
||||
"github.com/go-jet/jet/v2/internal/utils"
|
||||
"github.com/go-jet/jet/v2/internal/utils/throw"
|
||||
"github.com/go-jet/jet/v2/postgres"
|
||||
"github.com/jackc/pgconn"
|
||||
)
|
||||
|
|
@ -43,15 +42,18 @@ func Generate(destDir string, dbConn DBConnection, genTemplate ...template.Templ
|
|||
}
|
||||
|
||||
// GenerateDSN generates jet files using dsn connection string
|
||||
func GenerateDSN(dsn, schema, destDir string, templates ...template.Template) (err error) {
|
||||
defer utils.ErrorCatch(&err)
|
||||
|
||||
func GenerateDSN(dsn, schema, destDir string, templates ...template.Template) error {
|
||||
cfg, err := pgconn.ParseConfig(dsn)
|
||||
throw.OnError(err)
|
||||
if cfg.Database == "" {
|
||||
panic("database name is required")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse config: %w", err)
|
||||
}
|
||||
if cfg.Database == "" {
|
||||
return fmt.Errorf("database name is required")
|
||||
}
|
||||
db, err := openConnection(dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open db connection: %w", err)
|
||||
}
|
||||
db := openConnection(dsn)
|
||||
defer utils.DBClose(db)
|
||||
|
||||
fmt.Println("Retrieving schema information...")
|
||||
|
|
@ -60,22 +62,33 @@ func GenerateDSN(dsn, schema, destDir string, templates ...template.Template) (e
|
|||
generatorTemplate = templates[0]
|
||||
}
|
||||
|
||||
schemaMetadata := metadata.GetSchema(db, &postgresQuerySet{}, schema)
|
||||
schemaMetadata, err := metadata.GetSchema(db, &postgresQuerySet{}, schema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get '%s' schema metadata: %w", schema, err)
|
||||
}
|
||||
|
||||
dirPath := path.Join(destDir, cfg.Database)
|
||||
|
||||
template.ProcessSchema(dirPath, schemaMetadata, generatorTemplate)
|
||||
return
|
||||
err = template.ProcessSchema(dirPath, schemaMetadata, generatorTemplate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate schema %s: %d", schemaMetadata.Name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func openConnection(dsn string) *sql.DB {
|
||||
func openConnection(dsn string) (*sql.DB, error) {
|
||||
fmt.Println("Connecting to postgres database...")
|
||||
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open db connection: %w", err)
|
||||
}
|
||||
|
||||
err = db.Ping()
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return db
|
||||
return db, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,16 +3,16 @@ package postgres
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-jet/jet/v2/generator/metadata"
|
||||
"github.com/go-jet/jet/v2/internal/utils/throw"
|
||||
"github.com/go-jet/jet/v2/qrm"
|
||||
)
|
||||
|
||||
// postgresQuerySet is dialect query set for PostgreSQL
|
||||
type postgresQuerySet struct{}
|
||||
|
||||
func (p postgresQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table {
|
||||
func (p postgresQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) ([]metadata.Table, error) {
|
||||
query := `
|
||||
SELECT table_name as "table.name"
|
||||
FROM information_schema.tables
|
||||
|
|
@ -22,16 +22,21 @@ ORDER BY table_name;
|
|||
var tables []metadata.Table
|
||||
|
||||
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables)
|
||||
throw.OnError(err)
|
||||
|
||||
for i := range tables {
|
||||
tables[i].Columns = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query %s metadata: %w", tableType, err)
|
||||
}
|
||||
|
||||
return tables
|
||||
for i := range tables {
|
||||
tables[i].Columns, err = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query %s columns metadata: %w", tableType, err)
|
||||
}
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func (p postgresQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column {
|
||||
func (p postgresQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) {
|
||||
query := `
|
||||
WITH primaryKeys AS (
|
||||
SELECT column_name
|
||||
|
|
@ -67,12 +72,14 @@ order by ordinal_position;
|
|||
`
|
||||
var columns []metadata.Column
|
||||
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query '%s' columns metadata: %w", tableName, err)
|
||||
}
|
||||
|
||||
return columns
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
func (p postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum {
|
||||
func (p postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.Enum, error) {
|
||||
query := `
|
||||
SELECT t.typname as "enum.name",
|
||||
e.enumlabel as "values"
|
||||
|
|
@ -85,7 +92,9 @@ ORDER BY n.nspname, t.typname, e.enumsortorder;`
|
|||
var result []metadata.Enum
|
||||
|
||||
_, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query enums metadata for schema '%s': %w", schemaName, err)
|
||||
}
|
||||
|
||||
return result
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import (
|
|||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/go-jet/jet/v2/generator/metadata"
|
||||
"github.com/go-jet/jet/v2/internal/utils/throw"
|
||||
"github.com/go-jet/jet/v2/qrm"
|
||||
"strings"
|
||||
)
|
||||
|
|
@ -13,7 +12,7 @@ import (
|
|||
// sqliteQuerySet is dialect query set for SQLite
|
||||
type sqliteQuerySet struct{}
|
||||
|
||||
func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table {
|
||||
func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) ([]metadata.Table, error) {
|
||||
query := `
|
||||
SELECT name as "table.name"
|
||||
FROM sqlite_master
|
||||
|
|
@ -29,16 +28,21 @@ func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTy
|
|||
var tables []metadata.Table
|
||||
|
||||
_, err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables)
|
||||
throw.OnError(err)
|
||||
|
||||
for i := range tables {
|
||||
tables[i].Columns = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query %s metadata: %w", schemaName, err)
|
||||
}
|
||||
|
||||
return tables
|
||||
for i := range tables {
|
||||
tables[i].Columns, err = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query column metadata: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column {
|
||||
func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) ([]metadata.Column, error) {
|
||||
query := fmt.Sprintf(`select * from pragma_table_info(?);`)
|
||||
var columnInfos []struct {
|
||||
Name string
|
||||
|
|
@ -48,7 +52,9 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t
|
|||
}
|
||||
|
||||
_, err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query '%s' column metadata: %w", tableName, err)
|
||||
}
|
||||
|
||||
var columns []metadata.Column
|
||||
|
||||
|
|
@ -67,7 +73,7 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t
|
|||
})
|
||||
}
|
||||
|
||||
return columns
|
||||
return columns, nil
|
||||
}
|
||||
|
||||
// will convert VARCHAR(10) -> VARCHAR, etc...
|
||||
|
|
@ -75,6 +81,6 @@ func getColumnType(columnType string) string {
|
|||
return strings.TrimSpace(strings.Split(columnType, "(")[0])
|
||||
}
|
||||
|
||||
func (p sqliteQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum {
|
||||
return nil
|
||||
func (p sqliteQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]metadata.Enum, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,16 +6,15 @@ import (
|
|||
"github.com/go-jet/jet/v2/generator/metadata"
|
||||
"github.com/go-jet/jet/v2/generator/template"
|
||||
"github.com/go-jet/jet/v2/internal/utils"
|
||||
"github.com/go-jet/jet/v2/internal/utils/throw"
|
||||
"github.com/go-jet/jet/v2/sqlite"
|
||||
)
|
||||
|
||||
// GenerateDSN generates jet files using dsn connection string
|
||||
func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error) {
|
||||
defer utils.ErrorCatch(&err)
|
||||
|
||||
func GenerateDSN(dsn, destDir string, templates ...template.Template) error {
|
||||
db, err := sql.Open("sqlite3", dsn)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open sqlite connection: %w", err)
|
||||
}
|
||||
defer utils.DBClose(db)
|
||||
|
||||
fmt.Println("Retrieving schema information...")
|
||||
|
|
@ -25,8 +24,15 @@ func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error
|
|||
generatorTemplate = templates[0]
|
||||
}
|
||||
|
||||
schemaMetadata := metadata.GetSchema(db, &sqliteQuerySet{}, "")
|
||||
schemaMetadata, err := metadata.GetSchema(db, &sqliteQuerySet{}, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query database metadata: %w", err)
|
||||
}
|
||||
|
||||
template.ProcessSchema(destDir, schemaMetadata, generatorTemplate)
|
||||
return
|
||||
err = template.ProcessSchema(destDir, schemaMetadata, generatorTemplate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process database %s: %w", schemaMetadata.Name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package template
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
|
|
@ -10,13 +11,12 @@ import (
|
|||
"github.com/go-jet/jet/v2/generator/metadata"
|
||||
"github.com/go-jet/jet/v2/internal/jet"
|
||||
"github.com/go-jet/jet/v2/internal/utils"
|
||||
"github.com/go-jet/jet/v2/internal/utils/throw"
|
||||
)
|
||||
|
||||
// ProcessSchema will process schema metadata and constructs go files using generator Template
|
||||
func ProcessSchema(dirPath string, schemaMetaData metadata.Schema, generatorTemplate Template) {
|
||||
func ProcessSchema(dirPath string, schemaMetaData metadata.Schema, generatorTemplate Template) error {
|
||||
if schemaMetaData.IsEmpty() {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
schemaTemplate := generatorTemplate.Schema(schemaMetaData)
|
||||
|
|
@ -25,48 +25,87 @@ func ProcessSchema(dirPath string, schemaMetaData metadata.Schema, generatorTemp
|
|||
fmt.Println("Destination directory:", schemaPath)
|
||||
fmt.Println("Cleaning up destination directory...")
|
||||
err := utils.CleanUpGeneratedFiles(schemaPath)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return errors.New("failed to cleanup generated files")
|
||||
}
|
||||
|
||||
processModel(schemaPath, schemaMetaData, schemaTemplate)
|
||||
processSQLBuilder(schemaPath, generatorTemplate.Dialect, schemaMetaData, schemaTemplate)
|
||||
err = processModel(schemaPath, schemaMetaData, schemaTemplate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate model types: %w", err)
|
||||
}
|
||||
|
||||
err = processSQLBuilder(schemaPath, generatorTemplate.Dialect, schemaMetaData, schemaTemplate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate sql builder types: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processModel(dirPath string, schemaMetaData metadata.Schema, schemaTemplate Schema) {
|
||||
func processModel(dirPath string, schemaMetaData metadata.Schema, schemaTemplate Schema) error {
|
||||
modelTemplate := schemaTemplate.Model
|
||||
|
||||
if modelTemplate.Skip {
|
||||
fmt.Println("Skipping the generation of model types.")
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
modelDirPath := path.Join(dirPath, modelTemplate.Path)
|
||||
|
||||
err := utils.EnsureDirPath(modelDirPath)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("destination dir path does not exist: %w", err)
|
||||
}
|
||||
|
||||
processTableModels("table", modelDirPath, schemaMetaData.TablesMetaData, modelTemplate)
|
||||
processTableModels("view", modelDirPath, schemaMetaData.ViewsMetaData, modelTemplate)
|
||||
processEnumModels(modelDirPath, schemaMetaData.EnumsMetaData, modelTemplate)
|
||||
err = processTableModels("table", modelDirPath, schemaMetaData.TablesMetaData, modelTemplate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate table model types: %w", err)
|
||||
}
|
||||
|
||||
err = processTableModels("view", modelDirPath, schemaMetaData.ViewsMetaData, modelTemplate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate view model types: %w", err)
|
||||
}
|
||||
|
||||
err = processEnumModels(modelDirPath, schemaMetaData.EnumsMetaData, modelTemplate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process enum types: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processSQLBuilder(dirPath string, dialect jet.Dialect, schemaMetaData metadata.Schema, schemaTemplate Schema) {
|
||||
func processSQLBuilder(dirPath string, dialect jet.Dialect, schemaMetaData metadata.Schema, schemaTemplate Schema) error {
|
||||
sqlBuilderTemplate := schemaTemplate.SQLBuilder
|
||||
|
||||
if sqlBuilderTemplate.Skip {
|
||||
fmt.Println("Skipping the generation of SQL Builder types.")
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
sqlBuilderPath := path.Join(dirPath, sqlBuilderTemplate.Path)
|
||||
|
||||
processTableSQLBuilder("table", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.TablesMetaData, sqlBuilderTemplate)
|
||||
processTableSQLBuilder("view", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.ViewsMetaData, sqlBuilderTemplate)
|
||||
processEnumSQLBuilder(sqlBuilderPath, dialect, schemaMetaData.EnumsMetaData, sqlBuilderTemplate)
|
||||
err := processTableSQLBuilder("table", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.TablesMetaData, sqlBuilderTemplate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process table sql builder types: %w", err)
|
||||
}
|
||||
|
||||
err = processTableSQLBuilder("view", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.ViewsMetaData, sqlBuilderTemplate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process view sql builder types: %w", err)
|
||||
}
|
||||
|
||||
err = processEnumSQLBuilder(sqlBuilderPath, dialect, schemaMetaData.EnumsMetaData, sqlBuilderTemplate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process enum types: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []metadata.Enum, sqlBuilder SQLBuilder) {
|
||||
func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []metadata.Enum, sqlBuilder SQLBuilder) error {
|
||||
if len(enumsMetaData) == 0 {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("Generating enum sql builder files\n")
|
||||
|
|
@ -81,7 +120,9 @@ func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []
|
|||
enumSQLBuilderPath := path.Join(dirPath, enumTemplate.Path)
|
||||
|
||||
err := utils.EnsureDirPath(enumSQLBuilderPath)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create enum sql builder directory - %s: %w", enumSQLBuilderPath, err)
|
||||
}
|
||||
|
||||
text, err := generateTemplate(
|
||||
autoGenWarningTemplate+enumSQLBuilderTemplate,
|
||||
|
|
@ -100,21 +141,27 @@ func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []
|
|||
return enumTemplate.ValueName(enumValue)
|
||||
},
|
||||
})
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generete enum type %s: %w", enumTemplate.FileName, err)
|
||||
}
|
||||
|
||||
err = utils.SaveGoFile(enumSQLBuilderPath, enumTemplate.FileName, text)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to format and save '%s' enum type : %w", enumTemplate.FileName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processTableSQLBuilder(fileTypes, dirPath string,
|
||||
dialect jet.Dialect,
|
||||
schemaMetaData metadata.Schema,
|
||||
tablesMetaData []metadata.Table,
|
||||
sqlBuilderTemplate SQLBuilder) {
|
||||
sqlBuilderTemplate SQLBuilder) error {
|
||||
|
||||
if len(tablesMetaData) == 0 {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("Generating %s sql builder files\n", fileTypes)
|
||||
|
|
@ -122,7 +169,6 @@ func processTableSQLBuilder(fileTypes, dirPath string,
|
|||
var generatedBuilders []TableSQLBuilder
|
||||
|
||||
for _, tableMetaData := range tablesMetaData {
|
||||
|
||||
var tableSQLBuilder TableSQLBuilder
|
||||
|
||||
if fileTypes == "view" {
|
||||
|
|
@ -138,7 +184,9 @@ func processTableSQLBuilder(fileTypes, dirPath string,
|
|||
tableSQLBuilderPath := path.Join(dirPath, tableSQLBuilder.Path)
|
||||
|
||||
err := utils.EnsureDirPath(tableSQLBuilderPath)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create table sql builder directory - %s: %w", tableSQLBuilderPath, err)
|
||||
}
|
||||
|
||||
text, err := generateTemplate(
|
||||
autoGenWarningTemplate+tableSQLBuilderTemplate,
|
||||
|
|
@ -168,20 +216,30 @@ func processTableSQLBuilder(fileTypes, dirPath string,
|
|||
return insertedRowAlias(dialect)
|
||||
},
|
||||
})
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate table sql builder type %s: %w", tableSQLBuilder.TypeName, err)
|
||||
}
|
||||
|
||||
err = utils.SaveGoFile(tableSQLBuilderPath, tableSQLBuilder.FileName, text)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to format and save generated sql builder type '%s': %w", tableSQLBuilder.FileName, err)
|
||||
}
|
||||
|
||||
generatedBuilders = append(generatedBuilders, tableSQLBuilder)
|
||||
}
|
||||
|
||||
if len(generatedBuilders) > 0 {
|
||||
generateUseSchemaFunc(dirPath, fileTypes, generatedBuilders)
|
||||
err := generateUseSchemaFunc(dirPath, fileTypes, generatedBuilders)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate UseSchema function")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateUseSchemaFunc(dirPath, fileTypes string, builders []TableSQLBuilder) {
|
||||
func generateUseSchemaFunc(dirPath, fileTypes string, builders []TableSQLBuilder) error {
|
||||
if len(builders) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
text, err := generateTemplate(
|
||||
autoGenWarningTemplate+tableSqlBuilderSetSchemaTemplate,
|
||||
|
|
@ -191,13 +249,19 @@ func generateUseSchemaFunc(dirPath, fileTypes string, builders []TableSQLBuilder
|
|||
"type": func() string { return fileTypes },
|
||||
},
|
||||
)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate use schema template: %w", err)
|
||||
}
|
||||
|
||||
basePath := path.Join(dirPath, builders[0].Path)
|
||||
fileName := fileTypes + "_use_schema"
|
||||
|
||||
err = utils.SaveGoFile(basePath, fileName, text)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save %s file: %w", fileName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertedRowAlias(dialect jet.Dialect) string {
|
||||
|
|
@ -208,9 +272,9 @@ func insertedRowAlias(dialect jet.Dialect) string {
|
|||
return "excluded"
|
||||
}
|
||||
|
||||
func processTableModels(fileTypes, modelDirPath string, tablesMetaData []metadata.Table, modelTemplate Model) {
|
||||
func processTableModels(fileTypes, modelDirPath string, tablesMetaData []metadata.Table, modelTemplate Model) error {
|
||||
if len(tablesMetaData) == 0 {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
fmt.Printf("Generating %s model files...\n", fileTypes)
|
||||
|
||||
|
|
@ -244,16 +308,22 @@ func processTableModels(fileTypes, modelDirPath string, tablesMetaData []metadat
|
|||
return tableTemplate.Field(columnMetaData)
|
||||
},
|
||||
})
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate model type '%s': %w", tableMetaData.Name, err)
|
||||
}
|
||||
|
||||
err = utils.SaveGoFile(modelDirPath, tableTemplate.FileName, text)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save '%s' model type: %w", tableTemplate.FileName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processEnumModels(modelDir string, enumsMetaData []metadata.Enum, modelTemplate Model) {
|
||||
func processEnumModels(modelDir string, enumsMetaData []metadata.Enum, modelTemplate Model) error {
|
||||
if len(enumsMetaData) == 0 {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
fmt.Print("Generating enum model files...\n")
|
||||
|
||||
|
|
@ -278,23 +348,30 @@ func processEnumModels(modelDir string, enumsMetaData []metadata.Enum, modelTemp
|
|||
return enumTemplate.ValueName(value)
|
||||
},
|
||||
})
|
||||
throw.OnError(err)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate enum type '%s': %w", enumMetaData.Name, err)
|
||||
}
|
||||
|
||||
err = utils.SaveGoFile(modelDir, enumTemplate.FileName, text)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save '%s' enum type: %w", enumTemplate.FileName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateTemplate(templateText string, templateData interface{}, funcMap template.FuncMap) ([]byte, error) {
|
||||
t, err := template.New("sqlBuilderTableTemplate").Funcs(funcMap).Parse(templateText)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to parse template: %w", err)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := t.Execute(&buf, templateData); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to generate template: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
|
|
|
|||
10
internal/utils/errfmt/errfmt.go
Normal file
10
internal/utils/errfmt/errfmt.go
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
package errfmt
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Trace returns well formatted wrapped error trace string
|
||||
func Trace(err error) string {
|
||||
return "Error trace:\n" + " - " + strings.Replace(err.Error(), ": ", ":\n - ", -1)
|
||||
}
|
||||
|
|
@ -39,14 +39,16 @@ func SaveGoFile(dirPath, fileName string, text []byte) error {
|
|||
defer file.Close()
|
||||
|
||||
p, err := format.Source(text)
|
||||
|
||||
// if there is a format error we will write unformulated text for debug purposes
|
||||
if err != nil {
|
||||
return err
|
||||
file.Write(text)
|
||||
return fmt.Errorf("failed to format '%s', check '%s' for syntax errors: %w", fileName, newGoFilePath, err)
|
||||
}
|
||||
|
||||
_, err = file.Write(p)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to save '%s' file: %w", newGoFilePath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -58,7 +60,7 @@ func EnsureDirPath(dirPath string) error {
|
|||
err := os.MkdirAll(dirPath, os.ModePerm)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("can't create directory - %s: %w", dirPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,13 +8,13 @@ import (
|
|||
"github.com/go-jet/jet/v2/generator/mysql"
|
||||
"github.com/go-jet/jet/v2/generator/postgres"
|
||||
"github.com/go-jet/jet/v2/generator/sqlite"
|
||||
"github.com/go-jet/jet/v2/internal/utils/errfmt"
|
||||
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/go-jet/jet/v2/internal/utils/throw"
|
||||
"github.com/go-jet/jet/v2/tests/dbconfig"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
|
|
@ -40,39 +40,65 @@ const (
|
|||
)
|
||||
|
||||
func main() {
|
||||
var err error
|
||||
|
||||
switch strings.ToLower(testSuite) {
|
||||
case Postgres:
|
||||
initPostgresDB(Postgres, dbconfig.PostgresConnectString)
|
||||
err = initPostgresDB(Postgres, dbconfig.PostgresConnectString)
|
||||
case Cockroach:
|
||||
initPostgresDB(Cockroach, dbconfig.CockroachConnectString)
|
||||
err = initPostgresDB(Cockroach, dbconfig.CockroachConnectString)
|
||||
case MySql:
|
||||
initMySQLDB(false)
|
||||
err = initMySQLDB(false)
|
||||
case MariaDB:
|
||||
initMySQLDB(true)
|
||||
err = initMySQLDB(true)
|
||||
case Sqlite:
|
||||
initSQLiteDB()
|
||||
err = initSQLiteDB()
|
||||
case "all":
|
||||
initPostgresDB(Cockroach, dbconfig.CockroachConnectString)
|
||||
initPostgresDB(Postgres, dbconfig.PostgresConnectString)
|
||||
initMySQLDB(false)
|
||||
initMySQLDB(true)
|
||||
initSQLiteDB()
|
||||
err = initPostgresDB(Cockroach, dbconfig.CockroachConnectString)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = initPostgresDB(Postgres, dbconfig.PostgresConnectString)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = initMySQLDB(false)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
err = initMySQLDB(true)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = initSQLiteDB()
|
||||
default:
|
||||
panic("invalid testsuite flag. Test suite name (postgres, mysql, mariadb, cockroach, sqlite or all)")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Println(errfmt.Trace(err))
|
||||
}
|
||||
}
|
||||
|
||||
func initSQLiteDB() {
|
||||
func initSQLiteDB() error {
|
||||
err := sqlite.GenerateDSN(dbconfig.SakilaDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/sakila"))
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate sqlite sakila database types: %w", err)
|
||||
}
|
||||
err = sqlite.GenerateDSN(dbconfig.ChinookDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/chinook"))
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate sqlite chinook database types: %w", err)
|
||||
}
|
||||
err = sqlite.GenerateDSN(dbconfig.TestSampleDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/test_sample"))
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate sqlite test_sample database types: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func initMySQLDB(isMariaDB bool) {
|
||||
func initMySQLDB(isMariaDB bool) error {
|
||||
|
||||
mySQLDBs := []string{
|
||||
"dvds",
|
||||
|
|
@ -104,7 +130,9 @@ func initMySQLDB(isMariaDB bool) {
|
|||
cmd.Stdout = os.Stdout
|
||||
|
||||
err := cmd.Run()
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize mysql database %s: %w", dbName, err)
|
||||
}
|
||||
|
||||
err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{
|
||||
Host: host,
|
||||
|
|
@ -114,19 +142,20 @@ func initMySQLDB(isMariaDB bool) {
|
|||
DBName: dbName,
|
||||
})
|
||||
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate jet types for '%s' database: %w", dbName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func initPostgresDB(dbType string, connectionString string) {
|
||||
func initPostgresDB(dbType string, connectionString string) error {
|
||||
db, err := sql.Open("postgres", connectionString)
|
||||
if err != nil {
|
||||
panic("Failed to connect to test db: " + err.Error())
|
||||
return fmt.Errorf("failed to open '%s' db connection '%s': %w", dbType, connectionString, err)
|
||||
}
|
||||
defer func() {
|
||||
err := db.Close()
|
||||
printOnError(err)
|
||||
}()
|
||||
defer db.Close()
|
||||
|
||||
schemaNames := []string{
|
||||
"northwind",
|
||||
|
|
@ -139,31 +168,43 @@ func initPostgresDB(dbType string, connectionString string) {
|
|||
for _, schemaName := range schemaNames {
|
||||
fmt.Println("\nInitializing", schemaName, "schema...")
|
||||
|
||||
execFile(db, fmt.Sprintf("./testdata/init/%s/%s.sql", dbType, schemaName))
|
||||
err = execFile(db, fmt.Sprintf("./testdata/init/%s/%s.sql", dbType, schemaName))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute sql file: %w", err)
|
||||
}
|
||||
|
||||
err = postgres.GenerateDSN(connectionString, schemaName, "./.gentestdata")
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate jet types: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func execFile(db *sql.DB, sqlFilePath string) {
|
||||
func execFile(db *sql.DB, sqlFilePath string) error {
|
||||
testSampleSql, err := ioutil.ReadFile(sqlFilePath)
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read sql file - %s: %w", sqlFilePath, err)
|
||||
}
|
||||
|
||||
err = execInTx(db, func(tx *sql.Tx) error {
|
||||
_, err := tx.Exec(string(testSampleSql))
|
||||
return err
|
||||
})
|
||||
throw.OnError(err)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute sql file - %s: %w", sqlFilePath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func execInTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
||||
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{
|
||||
Isolation: sql.LevelReadUncommitted, // to speed up initialization of test database
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to start transaction: %w", err)
|
||||
}
|
||||
|
||||
err = f(tx)
|
||||
|
|
@ -173,11 +214,10 @@ func execInTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
|||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func printOnError(err error) {
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
return fmt.Errorf("failed to commit transaction")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue