Merge pull request #424 from jamius19/pkg-customization

Added CLI Flags for Package name customization for Model, Table, View and Enum
This commit is contained in:
go-jet 2024-11-26 11:46:52 +01:00 committed by GitHub
commit 072b58cd6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 209 additions and 41 deletions

View file

@ -42,7 +42,11 @@ var (
ignoreViews string ignoreViews string
ignoreEnums string ignoreEnums string
destDir string destDir string
modelPkg string
tablePkg string
viewPkg string
enumPkg string
) )
func init() { func init() {
@ -66,11 +70,15 @@ func init() {
flag.StringVar(&schemaName, "schema", "public", `Database schema name. (default "public")(PostgreSQL only)`) flag.StringVar(&schemaName, "schema", "public", `Database schema name. (default "public")(PostgreSQL only)`)
flag.StringVar(&params, "params", "", "Additional connection string parameters(optional). Used only if dsn is not set.") flag.StringVar(&params, "params", "", "Additional connection string parameters(optional). Used only if dsn is not set.")
flag.StringVar(&sslmode, "sslmode", "disable", `Whether or not to use SSL. Used only if dsn is not set. (optional)(default "disable")(PostgreSQL only)`) flag.StringVar(&sslmode, "sslmode", "disable", `Whether or not to use SSL. Used only if dsn is not set. (optional)(default "disable")(PostgreSQL only)`)
flag.StringVar(&ignoreTables, "ignore-tables", "", `Comma-separated list of tables to ignore`) flag.StringVar(&ignoreTables, "ignore-tables", "", `Comma-separated list of tables to ignore.`)
flag.StringVar(&ignoreViews, "ignore-views", "", `Comma-separated list of views to ignore`) flag.StringVar(&ignoreViews, "ignore-views", "", `Comma-separated list of views to ignore.`)
flag.StringVar(&ignoreEnums, "ignore-enums", "", `Comma-separated list of enums to ignore`) flag.StringVar(&ignoreEnums, "ignore-enums", "", `Comma-separated list of enums to ignore.`)
flag.StringVar(&destDir, "path", "", "Destination dir for files generated.") flag.StringVar(&destDir, "path", "", "Destination directory for files generated.")
flag.StringVar(&modelPkg, "rel-model-path", "model", "Relative path for the Model files package from the destination directory.")
flag.StringVar(&tablePkg, "rel-table-path", "table", "Relative path for the Table files package from the destination directory.")
flag.StringVar(&viewPkg, "rel-view-path", "view", "Relative path for the View files package from the destination directory.")
flag.StringVar(&enumPkg, "rel-enum-path", "enum", "Relative path for the Enum files package from the destination directory.")
} }
func main() { func main() {
@ -170,6 +178,7 @@ func usage() {
"source", "dsn", "host", "port", "user", "password", "dbname", "schema", "params", "sslmode", "source", "dsn", "host", "port", "user", "password", "dbname", "schema", "params", "sslmode",
"path", "path",
"ignore-tables", "ignore-views", "ignore-enums", "ignore-tables", "ignore-views", "ignore-enums",
"rel-model-path", "rel-table-path", "rel-view-path", "rel-enum-path",
} }
for _, name := range order { for _, name := range order {
@ -186,6 +195,7 @@ func usage() {
$ jet -source=postgres -dsn="user=jet password=jet host=localhost port=5432 dbname=jetdb" -schema=dvds -path=./gen $ jet -source=postgres -dsn="user=jet password=jet host=localhost port=5432 dbname=jetdb" -schema=dvds -path=./gen
$ jet -source=mysql -host=localhost -port=3306 -user=jet -password=jet -dbname=jetdb -path=./gen $ jet -source=mysql -host=localhost -port=3306 -user=jet -password=jet -dbname=jetdb -path=./gen
$ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -path=./gen $ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -path=./gen
$ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -path=./gen -rel-model-path=./entity
`) `)
} }
@ -246,7 +256,7 @@ func genTemplate(dialect jet.Dialect, ignoreTables []string, ignoreViews []strin
return template.Default(dialect). return template.Default(dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema { UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData). return template.DefaultSchema(schemaMetaData).
UseModel(template.DefaultModel(). UseModel(template.DefaultModel().UsePath(modelPkg).
UseTable(func(table metadata.Table) template.TableModel { UseTable(func(table metadata.Table) template.TableModel {
if shouldSkipTable(table) { if shouldSkipTable(table) {
return template.TableModel{Skip: true} return template.TableModel{Skip: true}
@ -271,19 +281,22 @@ func genTemplate(dialect jet.Dialect, ignoreTables []string, ignoreViews []strin
if shouldSkipTable(table) { if shouldSkipTable(table) {
return template.TableSQLBuilder{Skip: true} return template.TableSQLBuilder{Skip: true}
} }
return template.DefaultTableSQLBuilder(table)
return template.DefaultTableSQLBuilder(table).UsePath(tablePkg)
}). }).
UseView(func(table metadata.Table) template.ViewSQLBuilder { UseView(func(table metadata.Table) template.ViewSQLBuilder {
if shouldSkipView(table) { if shouldSkipView(table) {
return template.ViewSQLBuilder{Skip: true} return template.ViewSQLBuilder{Skip: true}
} }
return template.DefaultViewSQLBuilder(table)
return template.DefaultViewSQLBuilder(table).UsePath(viewPkg)
}). }).
UseEnum(func(enum metadata.Enum) template.EnumSQLBuilder { UseEnum(func(enum metadata.Enum) template.EnumSQLBuilder {
if shouldSkipEnum(enum) { if shouldSkipEnum(enum) {
return template.EnumSQLBuilder{Skip: true} return template.EnumSQLBuilder{Skip: true}
} }
return template.DefaultEnumSQLBuilder(enum)
return template.DefaultEnumSQLBuilder(enum).UsePath(enumPkg)
}), }),
) )
}) })

View file

@ -4,7 +4,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"net/url" "net/url"
"path" "path/filepath"
"strconv" "strconv"
"github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/metadata"
@ -66,7 +66,7 @@ func GenerateDSN(dsn, schema, destDir string, templates ...template.Template) er
return fmt.Errorf("failed to get '%s' schema metadata: %w", schema, err) return fmt.Errorf("failed to get '%s' schema metadata: %w", schema, err)
} }
dirPath := path.Join(destDir, cfg.Database) dirPath := filepath.Join(destDir, cfg.Database)
err = template.ProcessSchema(dirPath, schemaMetadata, generatorTemplate) err = template.ProcessSchema(dirPath, schemaMetadata, generatorTemplate)
if err != nil { if err != nil {

View file

@ -6,7 +6,7 @@ import (
"github.com/go-jet/jet/v2/internal/utils/dbidentifier" "github.com/go-jet/jet/v2/internal/utils/dbidentifier"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgtype" "github.com/jackc/pgtype"
"path" "path/filepath"
"reflect" "reflect"
"strings" "strings"
"time" "time"
@ -23,7 +23,7 @@ type Model struct {
// PackageName returns package name of model types // PackageName returns package name of model types
func (m Model) PackageName() string { func (m Model) PackageName() string {
return path.Base(m.Path) return filepath.Base(m.Path)
} }
// UsePath returns new Model template with replaced file path // UsePath returns new Model template with replaced file path

View file

@ -5,7 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/utils/filesys" "github.com/go-jet/jet/v2/internal/utils/filesys"
"path" "path/filepath"
"strings" "strings"
"text/template" "text/template"
@ -20,7 +20,7 @@ func ProcessSchema(dirPath string, schemaMetaData metadata.Schema, generatorTemp
} }
schemaTemplate := generatorTemplate.Schema(schemaMetaData) schemaTemplate := generatorTemplate.Schema(schemaMetaData)
schemaPath := path.Join(dirPath, schemaTemplate.Path) schemaPath := filepath.Join(dirPath, schemaTemplate.Path)
fmt.Println("Destination directory:", schemaPath) fmt.Println("Destination directory:", schemaPath)
fmt.Println("Cleaning up destination directory...") fmt.Println("Cleaning up destination directory...")
@ -50,7 +50,7 @@ func processModel(dirPath string, schemaMetaData metadata.Schema, schemaTemplate
return nil return nil
} }
modelDirPath := path.Join(dirPath, modelTemplate.Path) modelDirPath := filepath.Join(dirPath, modelTemplate.Path)
err := filesys.EnsureDirPathExist(modelDirPath) err := filesys.EnsureDirPathExist(modelDirPath)
if err != nil { if err != nil {
@ -83,7 +83,7 @@ func processSQLBuilder(dirPath string, dialect jet.Dialect, schemaMetaData metad
return nil return nil
} }
sqlBuilderPath := path.Join(dirPath, sqlBuilderTemplate.Path) sqlBuilderPath := filepath.Join(dirPath, sqlBuilderTemplate.Path)
err := processTableSQLBuilder("table", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.TablesMetaData, sqlBuilderTemplate) err := processTableSQLBuilder("table", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.TablesMetaData, sqlBuilderTemplate)
if err != nil { if err != nil {
@ -117,7 +117,7 @@ func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []
continue continue
} }
enumSQLBuilderPath := path.Join(dirPath, enumTemplate.Path) enumSQLBuilderPath := filepath.Join(dirPath, enumTemplate.Path)
err := filesys.EnsureDirPathExist(enumSQLBuilderPath) err := filesys.EnsureDirPathExist(enumSQLBuilderPath)
if err != nil { if err != nil {
@ -182,7 +182,7 @@ func processTableSQLBuilder(fileTypes, dirPath string,
continue continue
} }
tableSQLBuilderPath := path.Join(dirPath, tableSQLBuilder.Path) tableSQLBuilderPath := filepath.Join(dirPath, tableSQLBuilder.Path)
err := filesys.EnsureDirPathExist(tableSQLBuilderPath) err := filesys.EnsureDirPathExist(tableSQLBuilderPath)
if err != nil { if err != nil {
@ -255,7 +255,7 @@ func generateUseSchemaFunc(dirPath, fileTypes string, builders []TableSQLBuilder
return fmt.Errorf("failed to generate use schema template: %w", err) return fmt.Errorf("failed to generate use schema template: %w", err)
} }
basePath := path.Join(dirPath, builders[0].Path) basePath := filepath.Join(dirPath, builders[0].Path)
fileName := fileTypes + "_use_schema" fileName := fileTypes + "_use_schema"
err = filesys.FormatAndSaveGoFile(basePath, fileName, text) err = filesys.FormatAndSaveGoFile(basePath, fileName, text)

View file

@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils/dbidentifier" "github.com/go-jet/jet/v2/internal/utils/dbidentifier"
"path" "path/filepath"
"slices" "slices"
"strings" "strings"
"unicode" "unicode"
@ -90,7 +90,7 @@ func DefaultViewSQLBuilder(viewMetaData metadata.Table) ViewSQLBuilder {
// PackageName returns package name of table sql builder types // PackageName returns package name of table sql builder types
func (tb TableSQLBuilder) PackageName() string { func (tb TableSQLBuilder) PackageName() string {
return path.Base(tb.Path) return filepath.Base(tb.Path)
} }
// UsePath returns new TableSQLBuilder with new relative path set // UsePath returns new TableSQLBuilder with new relative path set
@ -228,7 +228,7 @@ func DefaultEnumSQLBuilder(enumMetaData metadata.Enum) EnumSQLBuilder {
// PackageName returns enum sql builder package name // PackageName returns enum sql builder package name
func (e EnumSQLBuilder) PackageName() string { func (e EnumSQLBuilder) PackageName() string {
return path.Base(e.Path) return filepath.Base(e.Path)
} }
// UsePath returns new EnumSQLBuilder with new path set // UsePath returns new EnumSQLBuilder with new path set

View file

@ -3,13 +3,13 @@ package file
import ( import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"os" "os"
"path" "path/filepath"
"testing" "testing"
) )
// Exists expects file to exist on path constructed from pathElems and returns content of the file // Exists expects file to exist on path constructed from pathElems and returns content of the file
func Exists(t *testing.T, pathElems ...string) (fileContent string) { func Exists(t *testing.T, pathElems ...string) (fileContent string) {
modelFilePath := path.Join(pathElems...) modelFilePath := filepath.Join(pathElems...)
file, err := os.ReadFile(modelFilePath) // #nosec G304 file, err := os.ReadFile(modelFilePath) // #nosec G304
require.Nil(t, err) require.Nil(t, err)
require.NotEmpty(t, file) require.NotEmpty(t, file)
@ -18,7 +18,7 @@ func Exists(t *testing.T, pathElems ...string) (fileContent string) {
// NotExists expects file not to exist on path constructed from pathElems // NotExists expects file not to exist on path constructed from pathElems
func NotExists(t *testing.T, pathElems ...string) { func NotExists(t *testing.T, pathElems ...string) {
modelFilePath := path.Join(pathElems...) modelFilePath := filepath.Join(pathElems...)
_, err := os.ReadFile(modelFilePath) // #nosec G304 _, err := os.ReadFile(modelFilePath) // #nosec G304
require.True(t, os.IsNotExist(err)) require.True(t, os.IsNotExist(err))
} }

View file

@ -12,18 +12,18 @@ import (
"github.com/go-jet/jet/v2/tests/dbconfig" "github.com/go-jet/jet/v2/tests/dbconfig"
file2 "github.com/go-jet/jet/v2/tests/internal/utils/file" file2 "github.com/go-jet/jet/v2/tests/internal/utils/file"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"path" "path/filepath"
"testing" "testing"
) )
const tempTestDir = "./.tempTestDir" const tempTestDir = "./.tempTestDir"
var defaultModelPath = path.Join(tempTestDir, "dvds/model") var defaultModelPath = filepath.Join(tempTestDir, "dvds/model")
var defaultActorModelFilePath = path.Join(tempTestDir, "dvds/model", "actor.go") var defaultActorModelFilePath = filepath.Join(tempTestDir, "dvds/model", "actor.go")
var defaultTableSQLBuilderFilePath = path.Join(tempTestDir, "dvds/table") var defaultTableSQLBuilderFilePath = filepath.Join(tempTestDir, "dvds/table")
var defaultViewSQLBuilderFilePath = path.Join(tempTestDir, "dvds/view") var defaultViewSQLBuilderFilePath = filepath.Join(tempTestDir, "dvds/view")
var defaultEnumSQLBuilderFilePath = path.Join(tempTestDir, "dvds/enum") var defaultEnumSQLBuilderFilePath = filepath.Join(tempTestDir, "dvds/enum")
var defaultActorSQLBuilderFilePath = path.Join(tempTestDir, "dvds/table", "actor.go") var defaultActorSQLBuilderFilePath = filepath.Join(tempTestDir, "dvds/table", "actor.go")
func dbConnection(dbName string) mysql2.DBConnection { func dbConnection(dbName string) mysql2.DBConnection {
if sourceIsMariaDB() { if sourceIsMariaDB() {

View file

@ -3,7 +3,7 @@ package postgres
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"path" "path/filepath"
"testing" "testing"
"github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/metadata"
@ -20,13 +20,13 @@ import (
const tempTestDir = "./.tempTestDir" const tempTestDir = "./.tempTestDir"
var defaultModelPath = path.Join(tempTestDir, "jetdb/dvds/model") var defaultModelPath = filepath.Join(tempTestDir, "jetdb/dvds/model")
var defaultSqlBuilderPath = path.Join(tempTestDir, "jetdb/dvds/table") var defaultSqlBuilderPath = filepath.Join(tempTestDir, "jetdb/dvds/table")
var defaultActorModelFilePath = path.Join(tempTestDir, "jetdb/dvds/model", "actor.go") var defaultActorModelFilePath = filepath.Join(tempTestDir, "jetdb/dvds/model", "actor.go")
var defaultTableSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/table") var defaultTableSQLBuilderFilePath = filepath.Join(tempTestDir, "jetdb/dvds/table")
var defaultViewSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/view") var defaultViewSQLBuilderFilePath = filepath.Join(tempTestDir, "jetdb/dvds/view")
var defaultEnumSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/enum") var defaultEnumSQLBuilderFilePath = filepath.Join(tempTestDir, "jetdb/dvds/enum")
var defaultActorSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/table", "actor.go") var defaultActorSQLBuilderFilePath = filepath.Join(tempTestDir, "jetdb/dvds/table", "actor.go")
var dbConnection = postgres.DBConnection{ var dbConnection = postgres.DBConnection{
Host: dbconfig.PgHost, Host: dbconfig.PgHost,

View file

@ -6,6 +6,7 @@ import (
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"reflect" "reflect"
"regexp"
"strconv" "strconv"
"testing" "testing"
@ -108,6 +109,76 @@ func TestCmdGenerator(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
func TestCmdGeneratorWithPkgNames(t *testing.T) {
err := os.RemoveAll(genTestDir2)
require.NoError(t, err)
// Testing with custom package paths
modelPath := "./newmodel"
tablePath := "./newtable"
viewPath := "./newview"
enumPath := "./newenum"
cmd := exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost",
"-port="+strconv.Itoa(dbconfig.PgPort),
"-user=jet",
"-password=jet",
"-schema=dvds",
"-path="+genTestDir2,
"-rel-model-path="+modelPath,
"-rel-table-path="+tablePath,
"-rel-view-path="+viewPath,
"-rel-enum-path="+enumPath)
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
err = cmd.Run()
require.NoError(t, err)
assertGeneratedFilesWithPkgNames(
t,
modelPath,
tablePath,
viewPath,
enumPath,
)
err = os.RemoveAll(genTestDir2)
require.NoError(t, err)
// Testing with nested paths
modelPath = "./db/newmodel"
tablePath = "./db/newtable"
viewPath = "./db/newview"
enumPath = "./db/newenum"
cmd = exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost",
"-port="+strconv.Itoa(dbconfig.PgPort),
"-user=jet",
"-password=jet",
"-schema=dvds",
"-path="+genTestDir2,
"-rel-model-path="+modelPath,
"-rel-table-path="+tablePath,
"-rel-view-path="+viewPath,
"-rel-enum-path="+enumPath)
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
err = cmd.Run()
require.NoError(t, err)
assertGeneratedFilesWithPkgNames(
t,
modelPath,
tablePath,
viewPath,
enumPath,
)
}
func TestGeneratorIgnoreTables(t *testing.T) { func TestGeneratorIgnoreTables(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -311,6 +382,90 @@ func assertGeneratedFiles(t *testing.T) {
testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", actorModelFile) testutils.AssertFileContent(t, "./.gentestdata2/jetdb/dvds/model/actor.go", actorModelFile)
} }
func assertGeneratedFilesWithPkgNames(t *testing.T, modelPkgPath, tablePkgPath, viewPkgPath, enumPkgPath string) {
// We can get the package names from the base of the package paths for
// replacing package names in the default file content strings
modelPkg := filepath.Base(modelPkgPath)
tablePkg := filepath.Base(tablePkgPath)
viewPkg := filepath.Base(viewPkgPath)
enumPkg := filepath.Base(enumPkgPath)
// Table SQL Builder files
testutils.AssertFileNamesEqual(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", tablePkgPath),
"actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go", "table_use_schema.go",
)
testutils.AssertFileContent(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", tablePkgPath, "actor.go"),
getFileContentWithNewPkg(tablePkg, actorSQLBuilderFile),
)
testutils.AssertFileContent(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", tablePkgPath, "table_use_schema.go"),
getFileContentWithNewPkg(tablePkg, tableUseSchemaFile),
)
// View SQL Builder files
testutils.AssertFileNamesEqual(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", viewPkgPath),
"actor_info.go", "film_list.go", "nicer_but_slower_film_list.go",
"sales_by_film_category.go", "customer_list.go", "sales_by_store.go", "staff_list.go", "view_use_schema.go",
)
testutils.AssertFileContent(t,
filepath.Join("./.gentestdata2/jetdb/dvds/", viewPkgPath, "actor_info.go"),
getFileContentWithNewPkg(viewPkg, actorInfoSQLBuilderFile),
)
testutils.AssertFileContent(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", viewPkgPath, "view_use_schema.go"),
getFileContentWithNewPkg(viewPkg, viewUseSchemaFile),
)
// Enums SQL Builder files
testutils.AssertFileNamesEqual(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", enumPkgPath),
"mpaa_rating.go",
)
testutils.AssertFileContent(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", enumPkgPath, "mpaa_rating.go"),
getFileContentWithNewPkg(enumPkg, mpaaRatingEnumFile),
)
// Model files
testutils.AssertFileNamesEqual(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", modelPkgPath),
"actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go", "mpaa_rating.go",
"actor_info.go", "film_list.go", "nicer_but_slower_film_list.go", "sales_by_film_category.go",
"customer_list.go", "sales_by_store.go", "staff_list.go",
)
testutils.AssertFileContent(
t,
filepath.Join("./.gentestdata2/jetdb/dvds/", modelPkgPath, "actor.go"),
getFileContentWithNewPkg(modelPkg, actorModelFile),
)
}
func getFileContentWithNewPkg(pkgName, fileContent string) string {
regex := regexp.MustCompile(`package \w+`)
return regex.ReplaceAllString(fileContent, "package "+pkgName)
}
var mpaaRatingEnumFile = ` var mpaaRatingEnumFile = `
// //
// Code generated by go-jet DO NOT EDIT. // Code generated by go-jet DO NOT EDIT.