diff --git a/generator/template/model_template.go b/generator/template/model_template.go index 84f46d7..2b4ec8a 100644 --- a/generator/template/model_template.go +++ b/generator/template/model_template.go @@ -106,10 +106,10 @@ func getTableModelImports(modelType TableModel, tableMetaData metadata.Table) [] importPaths := map[string]bool{} for _, columnMetaData := range tableMetaData.Columns { field := modelType.Field(columnMetaData) - importPath := field.Type.ImportPath - - if importPath != "" { - importPaths[importPath] = true + for _, importPath := range append([]string{field.Type.ImportPath}, field.Type.AdditionalImportPaths...) { + if importPath != "" { + importPaths[importPath] = true + } } } @@ -207,8 +207,9 @@ func (f TableModelField) TagsString() string { // Type represents type of the struct field type Type struct { - ImportPath string - Name string + ImportPath string + AdditionalImportPaths []string + Name string } // NewType creates new type for dummy object diff --git a/tests/mysql/generator_template_test.go b/tests/mysql/generator_template_test.go index 8c539e2..18055c1 100644 --- a/tests/mysql/generator_template_test.go +++ b/tests/mysql/generator_template_test.go @@ -8,7 +8,7 @@ import ( "github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/internal/3rdparty/snaker" "github.com/go-jet/jet/v2/internal/utils/dbidentifier" - postgres2 "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/tests/dbconfig" file2 "github.com/go-jet/jet/v2/tests/internal/utils/file" "github.com/stretchr/testify/require" @@ -49,7 +49,7 @@ func TestGeneratorTemplate_Schema_ChangePath(t *testing.T) { err := mysql2.Generate( tempTestDir, dbConnection("dvds"), - template.Default(postgres2.Dialect). + template.Default(mysql.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData).UsePath("new/schema/path") }), @@ -67,7 +67,7 @@ func TestGeneratorTemplate_Model_SkipGeneration(t *testing.T) { err := mysql2.Generate( tempTestDir, dbConnection("dvds"), - template.Default(postgres2.Dialect). + template.Default(mysql.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). UseModel(template.Model{ @@ -88,7 +88,7 @@ func TestGeneratorTemplate_SQLBuilder_SkipGeneration(t *testing.T) { err := mysql2.Generate( tempTestDir, dbConnection("dvds"), - template.Default(postgres2.Dialect). + template.Default(mysql.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). UseSQLBuilder(template.SQLBuilder{ @@ -111,7 +111,7 @@ func TestGeneratorTemplate_Model_ChangePath(t *testing.T) { err := mysql2.Generate( tempTestDir, dbConnection("dvds"), - template.Default(postgres2.Dialect). + template.Default(mysql.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). UseModel(template.DefaultModel().UsePath(newModelPath)) @@ -129,7 +129,7 @@ func TestGeneratorTemplate_SQLBuilder_ChangePath(t *testing.T) { err := mysql2.Generate( tempTestDir, dbConnection("dvds"), - template.Default(postgres2.Dialect). + template.Default(mysql.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). UseSQLBuilder(template.DefaultSQLBuilder().UsePath(newModelPath)) @@ -150,7 +150,7 @@ func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) { err := mysql2.Generate( tempTestDir, dbConnection("dvds"), - template.Default(postgres2.Dialect). + template.Default(mysql.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). UseModel(template.DefaultModel(). @@ -188,7 +188,7 @@ func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) { err := mysql2.Generate( tempTestDir, dbConnection("dvds"), - template.Default(postgres2.Dialect). + template.Default(mysql.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). UseModel(template.DefaultModel(). @@ -216,7 +216,7 @@ func TestGeneratorTemplate_SQLBuilder_SkipTableAndEnum(t *testing.T) { err := mysql2.Generate( tempTestDir, dbConnection("dvds"), - template.Default(postgres2.Dialect). + template.Default(mysql.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). UseSQLBuilder(template.DefaultSQLBuilder(). @@ -249,7 +249,7 @@ func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) { err := mysql2.Generate( tempTestDir, dbConnection("dvds"), - template.Default(postgres2.Dialect). + template.Default(mysql.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). UseSQLBuilder(template.DefaultSQLBuilder(). @@ -289,7 +289,7 @@ func TestGeneratorTemplate_SQLBuilder_DefaultAlias(t *testing.T) { err := mysql2.Generate( tempTestDir, dbConnection("dvds"), - template.Default(postgres2.Dialect). + template.Default(mysql.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). UseSQLBuilder(template.DefaultSQLBuilder(). @@ -313,7 +313,7 @@ func TestGeneratorTemplate_Model_AddTags(t *testing.T) { err := mysql2.Generate( tempTestDir, dbConnection("dvds"), - template.Default(postgres2.Dialect). + template.Default(mysql.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). UseModel(template.DefaultModel(). @@ -354,7 +354,7 @@ func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) { err := mysql2.Generate( tempTestDir, dbConnection("dvds"), - template.Default(postgres2.Dialect). + template.Default(mysql.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). UseModel(template.DefaultModel(). @@ -376,6 +376,12 @@ func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) { defaultTableModelField.Type = template.NewType(sql.NullFloat64{}) case "*time.Time": defaultTableModelField.Type = template.NewType(sql.NullTime{}) + case "time.Time": + defaultTableModelField.Type = template.Type{ + ImportPath: "database/sql", + AdditionalImportPaths: []string{"github.com/google/uuid"}, + Name: "sql.Null[uuid.UUID]", + } } return defaultTableModelField }) @@ -387,17 +393,19 @@ func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) { require.Nil(t, err) data := file2.Exists(t, defaultModelPath, "film.go") - require.Contains(t, data, "\"database/sql\"") + require.Contains(t, data, `database/sql"`) + require.Contains(t, data, `"github.com/google/uuid"`) require.Contains(t, data, "Description sql.NullString") require.Contains(t, data, "ReleaseYear *int16") require.Contains(t, data, "SpecialFeatures sql.NullString") + require.Contains(t, data, "LastUpdate sql.Null[uuid.UUID]") } func TestGeneratorTemplate_SQLBuilder_ChangeColumnTypes(t *testing.T) { err := mysql2.Generate( tempTestDir, dbConnection("dvds"), - template.Default(postgres2.Dialect). + template.Default(mysql.Dialect). UseSchema(func(schemaMetaData metadata.Schema) template.Schema { return template.DefaultSchema(schemaMetaData). UseSQLBuilder(template.DefaultSQLBuilder(). @@ -420,5 +428,5 @@ func TestGeneratorTemplate_SQLBuilder_ChangeColumnTypes(t *testing.T) { require.Nil(t, err) actor := file2.Exists(t, defaultActorSQLBuilderFilePath) - require.Contains(t, actor, "ActorID postgres.ColumnString") + require.Contains(t, actor, "ActorID mysql.ColumnString") } diff --git a/tests/postgres/generator_template_test.go b/tests/postgres/generator_template_test.go index e518db7..a63c3d2 100644 --- a/tests/postgres/generator_template_test.go +++ b/tests/postgres/generator_template_test.go @@ -3,9 +3,6 @@ package postgres import ( "database/sql" "fmt" - "path/filepath" - "testing" - "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/generator/template" @@ -16,6 +13,8 @@ import ( "github.com/go-jet/jet/v2/tests/dbconfig" file2 "github.com/go-jet/jet/v2/tests/internal/utils/file" "github.com/stretchr/testify/require" + "path/filepath" + "testing" ) const tempTestDir = "./.tempTestDir" @@ -432,6 +431,12 @@ func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) { defaultTableModelField.Type = template.NewType(sql.NullFloat64{}) case "*time.Time": defaultTableModelField.Type = template.NewType(sql.NullTime{}) + case "time.Time": + defaultTableModelField.Type = template.Type{ + ImportPath: "database/sql", + AdditionalImportPaths: []string{"github.com/google/uuid"}, + Name: "sql.Null[uuid.UUID]", + } } return defaultTableModelField }) @@ -443,10 +448,12 @@ func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) { require.Nil(t, err) data := file2.Exists(t, defaultModelPath, "film.go") - require.Contains(t, data, "\"database/sql\"") + require.Contains(t, data, `"database/sql"`) + require.Contains(t, data, `"github.com/google/uuid"`) require.Contains(t, data, "Description sql.NullString") require.Contains(t, data, "ReleaseYear sql.NullInt32") require.Contains(t, data, "SpecialFeatures sql.NullString") + require.Contains(t, data, "LastUpdate sql.Null[uuid.UUID]") } func TestGeneratorTemplate_SQLBuilder_ChangeColumnTypes(t *testing.T) {