From 8864667f4772610525a64b961a8a3582cdd7782c Mon Sep 17 00:00:00 2001 From: go-jet Date: Tue, 27 Jul 2021 17:39:21 +0200 Subject: [PATCH] Add the ability to fully customize jet generated files. --- .../internal/metadata/column_meta_data.go | 168 -------- .../internal/metadata/dialect_query_set.go | 15 - generator/internal/metadata/enum_meta_data.go | 12 - generator/internal/metadata/meta_data.go | 6 - .../internal/metadata/schema_meta_data.go | 61 --- .../internal/metadata/table_meta_data.go | 103 ----- generator/internal/template/generate.go | 107 ----- generator/internal/template/templates.go | 213 ---------- generator/metadata/column_meta_data.go | 27 ++ generator/metadata/dialect_query_set.go | 35 ++ generator/metadata/enum_meta_data.go | 7 + generator/metadata/schema_meta_data.go | 14 + generator/metadata/table_meta_data.go | 22 + generator/mysql/mysql_generator.go | 21 +- generator/mysql/query_set.go | 108 ++--- generator/postgres/postgres_generator.go | 36 +- generator/postgres/query_set.go | 120 +++--- generator/template/file_templates.go | 223 ++++++++++ generator/template/generator_template.go | 60 +++ generator/template/model_template.go | 327 +++++++++++++++ generator/template/model_template_test.go | 45 ++ generator/template/process.go | 269 ++++++++++++ generator/template/sql_builder_template.go | 225 ++++++++++ .../template/sql_builder_template_test.go | 11 + internal/3rdparty/snaker/snaker.go | 8 +- internal/testutils/test_utils.go | 4 +- internal/testutils/time_utils.go | 12 +- internal/utils/throw/throw.go | 8 + internal/utils/utils.go | 24 +- internal/utils/utils_test.go | 5 - qrm/scan_context.go | 3 +- tests/dbconfig/dbconfig.go | 12 +- tests/init/init.go | 20 +- tests/internal/utils/file/file.go | 25 ++ tests/mysql/alltypes_test.go | 4 +- tests/mysql/generator_template_test.go | 389 ++++++++++++++++++ tests/mysql/update_test.go | 8 +- tests/postgres/generator_template_test.go | 387 +++++++++++++++++ tests/postgres/generator_test.go | 10 +- tests/postgres/main_test.go | 2 + 40 files changed, 2274 insertions(+), 882 deletions(-) delete mode 100644 generator/internal/metadata/column_meta_data.go delete mode 100644 generator/internal/metadata/dialect_query_set.go delete mode 100644 generator/internal/metadata/enum_meta_data.go delete mode 100644 generator/internal/metadata/meta_data.go delete mode 100644 generator/internal/metadata/schema_meta_data.go delete mode 100644 generator/internal/metadata/table_meta_data.go delete mode 100644 generator/internal/template/generate.go delete mode 100644 generator/internal/template/templates.go create mode 100644 generator/metadata/column_meta_data.go create mode 100644 generator/metadata/dialect_query_set.go create mode 100644 generator/metadata/enum_meta_data.go create mode 100644 generator/metadata/schema_meta_data.go create mode 100644 generator/metadata/table_meta_data.go create mode 100644 generator/template/file_templates.go create mode 100644 generator/template/generator_template.go create mode 100644 generator/template/model_template.go create mode 100644 generator/template/model_template_test.go create mode 100644 generator/template/process.go create mode 100644 generator/template/sql_builder_template.go create mode 100644 generator/template/sql_builder_template_test.go create mode 100644 internal/utils/throw/throw.go create mode 100644 tests/internal/utils/file/file.go create mode 100644 tests/mysql/generator_template_test.go create mode 100644 tests/postgres/generator_template_test.go diff --git a/generator/internal/metadata/column_meta_data.go b/generator/internal/metadata/column_meta_data.go deleted file mode 100644 index dceb7c0..0000000 --- a/generator/internal/metadata/column_meta_data.go +++ /dev/null @@ -1,168 +0,0 @@ -package metadata - -import ( - "database/sql" - "fmt" - "github.com/go-jet/jet/v2/internal/utils" - "strings" -) - -// ColumnMetaData struct -type ColumnMetaData struct { - Name string - IsNullable bool - DataType string - EnumName string - IsUnsigned bool - - SqlBuilderColumnType string - GoBaseType string - GoModelType string -} - -// NewColumnMetaData create new column meta data that describes one column in SQL database -func NewColumnMetaData(name string, isNullable bool, dataType string, enumName string, isUnsigned bool) ColumnMetaData { - columnMetaData := ColumnMetaData{ - Name: name, - IsNullable: isNullable, - DataType: dataType, - EnumName: enumName, - IsUnsigned: isUnsigned, - } - - columnMetaData.SqlBuilderColumnType = columnMetaData.getSqlBuilderColumnType() - columnMetaData.GoBaseType = columnMetaData.getGoBaseType() - columnMetaData.GoModelType = columnMetaData.getGoModelType() - - return columnMetaData -} - -// getSqlBuilderColumnType returns type of jet sql builder column -func (c ColumnMetaData) getSqlBuilderColumnType() string { - switch c.DataType { - case "boolean": - return "Bool" - case "smallint", "integer", "bigint", - "tinyint", "mediumint", "int", "year": //MySQL - return "Integer" - case "date": - return "Date" - case "timestamp without time zone", - "timestamp", "datetime": //MySQL: - return "Timestamp" - case "timestamp with time zone": - return "Timestampz" - case "time without time zone", - "time": //MySQL - return "Time" - case "time with time zone": - return "Timez" - case "interval": - return "Interval" - case "USER-DEFINED", "enum", "text", "character", "character varying", "bytea", "uuid", - "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", - "char", "varchar", "binary", "varbinary", - "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL - return "String" - case "real", "numeric", "decimal", "double precision", "float", - "double": // MySQL - return "Float" - default: - fmt.Println("- [SQL Builder] Unsupported sql column '" + c.Name + " " + c.DataType + "', using StringColumn instead.") - return "String" - } -} - -// getGoBaseType returns model type for column info. -func (c ColumnMetaData) getGoBaseType() string { - switch c.DataType { - case "USER-DEFINED", "enum": - return utils.ToGoIdentifier(c.EnumName) - case "boolean": - return "bool" - case "tinyint": - return "int8" - case "smallint", - "year": - return "int16" - case "integer", - "mediumint", "int": //MySQL - return "int32" - case "bigint": - return "int64" - case "date", "timestamp without time zone", "timestamp with time zone", "time with time zone", "time without time zone", - "timestamp", "datetime", "time": // MySQL - return "time.Time" - case "bytea", - "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": //MySQL - return "[]byte" - case "text", "character", "character varying", "tsvector", "bit", "bit varying", "money", "json", "jsonb", - "xml", "point", "interval", "line", "ARRAY", - "char", "varchar", "tinytext", "mediumtext", "longtext": // MySQL - return "string" - case "real": - return "float32" - case "numeric", "decimal", "double precision", "float", - "double": // MySQL - return "float64" - case "uuid": - return "uuid.UUID" - default: - fmt.Println("- [Model ] Unsupported sql column '" + c.Name + " " + c.DataType + "', using string instead.") - return "string" - } -} - -// GoModelType returns model type for column info with optional pointer if -// column can be NULL. -func (c ColumnMetaData) getGoModelType() string { - typeStr := c.GoBaseType - - if strings.Contains(typeStr, "int") && c.IsUnsigned { - typeStr = "u" + typeStr - } - - if c.IsNullable { - return "*" + typeStr - } - - return typeStr -} - -// GoModelTag returns model field tag for column -func (c ColumnMetaData) GoModelTag(isPrimaryKey bool) string { - tags := []string{} - - if isPrimaryKey { - tags = append(tags, "primary_key") - } - - if len(tags) > 0 { - return "`sql:\"" + strings.Join(tags, ",") + "\"`" - } - - return "" -} - -func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) []ColumnMetaData { - - rows, err := db.Query(querySet.ListOfColumnsQuery(), schemaName, tableName) - utils.PanicOnError(err) - defer rows.Close() - - ret := []ColumnMetaData{} - - for rows.Next() { - var name, isNullable, dataType, enumName string - var isUnsigned bool - err := rows.Scan(&name, &isNullable, &dataType, &enumName, &isUnsigned) - utils.PanicOnError(err) - - ret = append(ret, NewColumnMetaData(name, isNullable == "YES", dataType, enumName, isUnsigned)) - } - - err = rows.Err() - utils.PanicOnError(err) - - return ret -} diff --git a/generator/internal/metadata/dialect_query_set.go b/generator/internal/metadata/dialect_query_set.go deleted file mode 100644 index 6c91825..0000000 --- a/generator/internal/metadata/dialect_query_set.go +++ /dev/null @@ -1,15 +0,0 @@ -package metadata - -import ( - "database/sql" -) - -// DialectQuerySet is set of methods necessary to retrieve dialect meta data information -type DialectQuerySet interface { - ListOfTablesQuery() string - PrimaryKeysQuery() string - ListOfColumnsQuery() string - ListOfEnumsQuery() string - - GetEnumsMetaData(db *sql.DB, schemaName string) []MetaData -} diff --git a/generator/internal/metadata/enum_meta_data.go b/generator/internal/metadata/enum_meta_data.go deleted file mode 100644 index 8479c60..0000000 --- a/generator/internal/metadata/enum_meta_data.go +++ /dev/null @@ -1,12 +0,0 @@ -package metadata - -// EnumMetaData struct -type EnumMetaData struct { - EnumName string - Values []string -} - -// Name returns enum name -func (e EnumMetaData) Name() string { - return e.EnumName -} diff --git a/generator/internal/metadata/meta_data.go b/generator/internal/metadata/meta_data.go deleted file mode 100644 index 17d2f5c..0000000 --- a/generator/internal/metadata/meta_data.go +++ /dev/null @@ -1,6 +0,0 @@ -package metadata - -// MetaData interface -type MetaData interface { - Name() string -} diff --git a/generator/internal/metadata/schema_meta_data.go b/generator/internal/metadata/schema_meta_data.go deleted file mode 100644 index bc85511..0000000 --- a/generator/internal/metadata/schema_meta_data.go +++ /dev/null @@ -1,61 +0,0 @@ -package metadata - -import ( - "database/sql" - "fmt" - "github.com/go-jet/jet/v2/internal/utils" -) - -// SchemaMetaData struct -type SchemaMetaData struct { - TablesMetaData []MetaData - ViewsMetaData []MetaData - EnumsMetaData []MetaData -} - -// IsEmpty returns true if schema info does not contain any table, views or enums metadata -func (s SchemaMetaData) IsEmpty() bool { - return len(s.TablesMetaData) == 0 && len(s.ViewsMetaData) == 0 && len(s.EnumsMetaData) == 0 -} - -const ( - baseTable = "BASE TABLE" - view = "VIEW" -) - -// GetSchemaMetaData returns schema information from db connection. -func GetSchemaMetaData(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData) { - - schemaInfo.TablesMetaData = getTablesMetaData(db, querySet, schemaName, baseTable) - schemaInfo.ViewsMetaData = getTablesMetaData(db, querySet, schemaName, view) - schemaInfo.EnumsMetaData = querySet.GetEnumsMetaData(db, schemaName) - - fmt.Println(" FOUND", len(schemaInfo.TablesMetaData), "table(s),", len(schemaInfo.ViewsMetaData), "view(s),", - len(schemaInfo.EnumsMetaData), "enum(s)") - - return -} - -func getTablesMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableType string) []MetaData { - - rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName, tableType) - utils.PanicOnError(err) - defer rows.Close() - - ret := []MetaData{} - for rows.Next() { - var tableName string - - err = rows.Scan(&tableName) - utils.PanicOnError(err) - - tableInfo := GetTableMetaData(db, querySet, schemaName, tableName) - - ret = append(ret, tableInfo) - } - - err = rows.Err() - utils.PanicOnError(err) - - return ret -} diff --git a/generator/internal/metadata/table_meta_data.go b/generator/internal/metadata/table_meta_data.go deleted file mode 100644 index c106dd4..0000000 --- a/generator/internal/metadata/table_meta_data.go +++ /dev/null @@ -1,103 +0,0 @@ -package metadata - -import ( - "database/sql" - "github.com/go-jet/jet/v2/internal/utils" - "strings" -) - -// TableMetaData metadata struct -type TableMetaData struct { - SchemaName string - name string - PrimaryKeys map[string]bool - Columns []ColumnMetaData -} - -// Name returns table info name -func (t TableMetaData) Name() string { - return t.name -} - -// IsPrimaryKey returns if column is a part of primary key -func (t TableMetaData) IsPrimaryKey(column string) bool { - return t.PrimaryKeys[column] -} - -// MutableColumns returns list of mutable columns for table -func (t TableMetaData) MutableColumns() []ColumnMetaData { - ret := []ColumnMetaData{} - - for _, column := range t.Columns { - if t.IsPrimaryKey(column.Name) { - continue - } - - ret = append(ret, column) - } - - return ret -} - -// GetImports returns model imports for table. -func (t TableMetaData) GetImports() []string { - imports := map[string]string{} - - for _, column := range t.Columns { - columnType := column.GoBaseType - - switch columnType { - case "time.Time": - imports["time.Time"] = "time" - case "uuid.UUID": - imports["uuid.UUID"] = "github.com/google/uuid" - } - } - - ret := []string{} - - for _, packageImport := range imports { - ret = append(ret, packageImport) - } - - return ret -} - -// GoStructName returns go struct name for sql builder -func (t TableMetaData) GoStructName() string { - return utils.ToGoIdentifier(t.name) + "Table" -} - -// GoStructImplName returns go struct impl name for sql builder -func (t TableMetaData) GoStructImplName() string { - name := utils.ToGoIdentifier(t.name) + "Table" - return string(strings.ToLower(name)[0]) + name[1:] -} - -// GetTableMetaData returns table info metadata -func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData) { - tableInfo.SchemaName = schemaName - tableInfo.name = tableName - - tableInfo.PrimaryKeys = getPrimaryKeys(db, querySet, schemaName, tableName) - tableInfo.Columns = getColumnsMetaData(db, querySet, schemaName, tableName) - return -} - -func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) map[string]bool { - - rows, err := db.Query(querySet.PrimaryKeysQuery(), schemaName, tableName) - utils.PanicOnError(err) - - primaryKeyMap := map[string]bool{} - - for rows.Next() { - primaryKey := "" - err := rows.Scan(&primaryKey) - utils.PanicOnError(err) - - primaryKeyMap[primaryKey] = true - } - - return primaryKeyMap -} diff --git a/generator/internal/template/generate.go b/generator/internal/template/generate.go deleted file mode 100644 index 34e9ca1..0000000 --- a/generator/internal/template/generate.go +++ /dev/null @@ -1,107 +0,0 @@ -package template - -import ( - "bytes" - "fmt" - "github.com/go-jet/jet/v2/generator/internal/metadata" - "github.com/go-jet/jet/v2/internal/jet" - "github.com/go-jet/jet/v2/internal/utils" - "path/filepath" - "text/template" -) - -// GenerateFiles generates Go files from tables and enums metadata -func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect jet.Dialect) { - if schemaInfo.IsEmpty() { - return - } - - fmt.Println("Destination directory:", destDir) - fmt.Println("Cleaning up destination directory...") - err := utils.CleanUpGeneratedFiles(destDir) - utils.PanicOnError(err) - - tableSQLBuilderTemplate := getTableSQLBuilderTemplate(dialect) - generateSQLBuilderFiles(destDir, "table", tableSQLBuilderTemplate, schemaInfo.TablesMetaData, dialect) - generateSQLBuilderFiles(destDir, "view", tableSQLBuilderTemplate, schemaInfo.ViewsMetaData, dialect) - generateSQLBuilderFiles(destDir, "enum", enumSQLBuilderTemplate, schemaInfo.EnumsMetaData, dialect) - - generateModelFiles(destDir, "table", tableModelTemplate, schemaInfo.TablesMetaData, dialect) - generateModelFiles(destDir, "view", tableModelTemplate, schemaInfo.ViewsMetaData, dialect) - generateModelFiles(destDir, "enum", enumModelTemplate, schemaInfo.EnumsMetaData, dialect) - - fmt.Println("Done") -} - -func getTableSQLBuilderTemplate(dialect jet.Dialect) string { - if dialect.Name() == "PostgreSQL" { - return tablePostgreSQLBuilderTemplate - } - - return tableSQLBuilderTemplate -} - -func generateSQLBuilderFiles(destDir, fileTypes, sqlBuilderTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) { - if len(metaData) == 0 { - return - } - fmt.Printf("Generating %s sql builder files...\n", fileTypes) - generateGoFiles(destDir, fileTypes, sqlBuilderTemplate, metaData, dialect) -} - -func generateModelFiles(destDir, fileTypes, modelTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) { - if len(metaData) == 0 { - return - } - fmt.Printf("Generating %s model files...\n", fileTypes) - generateGoFiles(destDir, "model", modelTemplate, metaData, dialect) -} - -func generateGoFiles(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) { - modelDirPath := filepath.Join(dirPath, packageName) - - err := utils.EnsureDirPath(modelDirPath) - utils.PanicOnError(err) - - autoGenWarning, err := GenerateTemplate(autoGenWarningTemplate, nil, dialect) - utils.PanicOnError(err) - - for _, metaData := range metaDataList { - text, err := GenerateTemplate(template, metaData, dialect, map[string]interface{}{"package": packageName}) - utils.PanicOnError(err) - - err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...)) - utils.PanicOnError(err) - } - - return -} - -// GenerateTemplate generates template with template text and template data. -func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect, params ...map[string]interface{}) ([]byte, error) { - - t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{ - "ToGoIdentifier": utils.ToGoIdentifier, - "ToGoEnumValueIdentifier": utils.ToGoEnumValueIdentifier, - "dialect": func() jet.Dialect { - return dialect - }, - "param": func(name string) interface{} { - if len(params) > 0 { - return params[0][name] - } - return "" - }, - }).Parse(templateText) - - if err != nil { - return nil, err - } - - var buf bytes.Buffer - if err := t.Execute(&buf, templateData); err != nil { - return nil, err - } - - return buf.Bytes(), nil -} diff --git a/generator/internal/template/templates.go b/generator/internal/template/templates.go deleted file mode 100644 index 186ade0..0000000 --- a/generator/internal/template/templates.go +++ /dev/null @@ -1,213 +0,0 @@ -package template - -var autoGenWarningTemplate = ` -// -// Code generated by go-jet DO NOT EDIT. -// -// WARNING: Changes to this file may cause incorrect behavior -// and will be lost if the code is regenerated -// - -` - -var tableSQLBuilderTemplate = ` -{{define "column-list" -}} - {{- range $i, $c := . }} - {{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column - {{- end}} -{{- end}} - -package {{param "package"}} - -import ( - "github.com/go-jet/jet/v2/{{dialect.PackageName}}" -) - -var {{ToGoIdentifier .Name}} = new{{.GoStructName}}("{{.SchemaName}}", "{{.Name}}", "") - -type {{.GoStructName}} struct { - {{dialect.PackageName}}.Table - - //Columns -{{- range .Columns}} - {{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}} -{{- end}} - - AllColumns {{dialect.PackageName}}.ColumnList - MutableColumns {{dialect.PackageName}}.ColumnList -} - -// AS creates new {{.GoStructName}} with assigned alias -func (a {{.GoStructName}}) AS(alias string) {{.GoStructName}} { - return new{{.GoStructName}}(a.SchemaName(), a.TableName(), alias) -} - -// Schema creates new {{.GoStructName}} with assigned schema name -func (a {{.GoStructName}}) FromSchema(schemaName string) {{.GoStructName}} { - return new{{.GoStructName}}(schemaName, a.TableName(), a.Alias()) -} - -func new{{.GoStructName}}(schemaName, tableName, alias string) {{.GoStructName}} { - var ( - {{- range .Columns}} - {{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}") - {{- end}} - allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} } - mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} } - ) - - return {{.GoStructName}}{ - Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...), - - //Columns -{{- range .Columns}} - {{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column, -{{- end}} - - AllColumns: allColumns, - MutableColumns: mutableColumns, - } -} -` - -var tablePostgreSQLBuilderTemplate = ` -{{define "column-list" -}} - {{- range $i, $c := . }} - {{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column - {{- end}} -{{- end}} - -package {{param "package"}} - -import ( - "github.com/go-jet/jet/v2/{{dialect.PackageName}}" -) - -var {{ToGoIdentifier .Name}} = new{{.GoStructName}}("{{.SchemaName}}", "{{.Name}}", "") - -type {{.GoStructImplName}} struct { - {{dialect.PackageName}}.Table - - //Columns -{{- range .Columns}} - {{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}} -{{- end}} - - AllColumns {{dialect.PackageName}}.ColumnList - MutableColumns {{dialect.PackageName}}.ColumnList -} - -type {{.GoStructName}} struct { - {{.GoStructImplName}} - - EXCLUDED {{.GoStructImplName}} -} - -// AS creates new {{.GoStructName}} with assigned alias -func (a {{.GoStructName}}) AS(alias string) *{{.GoStructName}} { - return new{{.GoStructName}}(a.SchemaName(), a.TableName(), alias) -} - -// Schema creates new {{.GoStructName}} with assigned schema name -func (a {{.GoStructName}}) FromSchema(schemaName string) *{{.GoStructName}} { - return new{{.GoStructName}}(schemaName, a.TableName(), a.Alias()) -} - -func new{{.GoStructName}}(schemaName, tableName, alias string) *{{.GoStructName}} { - return &{{.GoStructName}}{ - {{.GoStructImplName}}: new{{.GoStructName}}Impl(schemaName, tableName, alias), - EXCLUDED: new{{.GoStructName}}Impl("", "excluded", ""), - } -} - -func new{{.GoStructName}}Impl(schemaName, tableName, alias string) {{.GoStructImplName}} { - var ( - {{- range .Columns}} - {{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}") - {{- end}} - allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} } - mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} } - ) - - return {{.GoStructImplName}}{ - Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...), - - //Columns -{{- range .Columns}} - {{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column, -{{- end}} - - AllColumns: allColumns, - MutableColumns: mutableColumns, - } -} -` - -var tableModelTemplate = `package model - -{{ if .GetImports }} -import ( -{{- range .GetImports}} - "{{.}}" -{{- end}} -) -{{end}} - - -type {{ToGoIdentifier .Name}} struct { -{{- range .Columns}} - {{ToGoIdentifier .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsPrimaryKey .Name)}}" + ` -{{- end}} -} - - -` -var enumSQLBuilderTemplate = `package enum - -import "github.com/go-jet/jet/v2/{{dialect.PackageName}}" - -var {{ToGoIdentifier $.Name}} = &struct { -{{- range $index, $element := .Values}} - {{ToGoEnumValueIdentifier $.Name $element}} {{dialect.PackageName}}.StringExpression -{{- end}} -} { -{{- range $index, $element := .Values}} - {{ToGoEnumValueIdentifier $.Name $element}}: {{dialect.PackageName}}.NewEnumValue("{{$element}}"), -{{- end}} -} -` - -var enumModelTemplate = `package model - -import "errors" - -type {{ToGoIdentifier $.Name}} string - -const ( -{{- range $index, $element := .Values}} - {{ToGoIdentifier $.Name}}_{{ToGoIdentifier $element}} {{ToGoIdentifier $.Name}} = "{{$element}}" -{{- end}} -) - -func (e *{{ToGoIdentifier $.Name}}) Scan(value interface{}) error { - if v, ok := value.(string); !ok { - return errors.New("jet: Invalid data for {{ToGoIdentifier $.Name}} enum") - } else { - switch string(v) { -{{- range $index, $element := .Values}} - case "{{$element}}": - *e = {{ToGoIdentifier $.Name}}_{{ToGoIdentifier $element}} -{{- end}} - default: - return errors.New("jet: Inavlid data " + string(v) + "for {{ToGoIdentifier $.Name}} enum") - } - - return nil - } -} - -func (e {{ToGoIdentifier $.Name}}) String() string { - return string(e) -} - -` diff --git a/generator/metadata/column_meta_data.go b/generator/metadata/column_meta_data.go new file mode 100644 index 0000000..74184b6 --- /dev/null +++ b/generator/metadata/column_meta_data.go @@ -0,0 +1,27 @@ +package metadata + +// Column struct +type Column struct { + Name string + IsPrimaryKey bool + IsNullable bool + DataType DataType +} + +// DataTypeKind is database type kind(base, enum, user-defined, array) +type DataTypeKind string + +// DataTypeKind possible values +const ( + BaseType DataTypeKind = "base" + EnumType DataTypeKind = "enum" + UserDefinedType DataTypeKind = "user-defined" + ArrayType DataTypeKind = "array" +) + +// DataType contains information about column data type +type DataType struct { + Name string + Kind DataTypeKind + IsUnsigned bool +} diff --git a/generator/metadata/dialect_query_set.go b/generator/metadata/dialect_query_set.go new file mode 100644 index 0000000..036e4d5 --- /dev/null +++ b/generator/metadata/dialect_query_set.go @@ -0,0 +1,35 @@ +package metadata + +import ( + "database/sql" + "fmt" +) + +// TableType is type of database table(view or base) +type TableType string + +const ( + baseTable TableType = "BASE TABLE" + viewTable TableType = "VIEW" +) + +// DialectQuerySet is set of methods necessary to retrieve dialect meta data information +type DialectQuerySet interface { + GetTablesMetaData(db *sql.DB, schemaName string, tableType TableType) []Table + GetEnumsMetaData(db *sql.DB, schemaName string) []Enum +} + +// GetSchema retrieves Schema information from database +func GetSchema(db *sql.DB, querySet DialectQuerySet, schemaName string) Schema { + ret := Schema{ + Name: schemaName, + TablesMetaData: querySet.GetTablesMetaData(db, schemaName, baseTable), + ViewsMetaData: querySet.GetTablesMetaData(db, schemaName, viewTable), + EnumsMetaData: querySet.GetEnumsMetaData(db, schemaName), + } + + fmt.Println(" FOUND", len(ret.TablesMetaData), "table(s),", len(ret.ViewsMetaData), "view(s),", + len(ret.EnumsMetaData), "enum(s)") + + return ret +} diff --git a/generator/metadata/enum_meta_data.go b/generator/metadata/enum_meta_data.go new file mode 100644 index 0000000..7aea3d6 --- /dev/null +++ b/generator/metadata/enum_meta_data.go @@ -0,0 +1,7 @@ +package metadata + +// Enum metadata struct +type Enum struct { + Name string `sql:"primary_key"` + Values []string +} diff --git a/generator/metadata/schema_meta_data.go b/generator/metadata/schema_meta_data.go new file mode 100644 index 0000000..c4c505a --- /dev/null +++ b/generator/metadata/schema_meta_data.go @@ -0,0 +1,14 @@ +package metadata + +// Schema struct +type Schema struct { + Name string + TablesMetaData []Table + ViewsMetaData []Table + EnumsMetaData []Enum +} + +// IsEmpty returns true if schema info does not contain any table, views or enums metadata +func (s Schema) IsEmpty() bool { + return len(s.TablesMetaData) == 0 && len(s.ViewsMetaData) == 0 && len(s.EnumsMetaData) == 0 +} diff --git a/generator/metadata/table_meta_data.go b/generator/metadata/table_meta_data.go new file mode 100644 index 0000000..6479dc2 --- /dev/null +++ b/generator/metadata/table_meta_data.go @@ -0,0 +1,22 @@ +package metadata + +// Table metadata struct +type Table struct { + Name string + Columns []Column +} + +// MutableColumns returns list of mutable columns for table +func (t Table) MutableColumns() []Column { + var ret []Column + + for _, column := range t.Columns { + if column.IsPrimaryKey { + continue + } + + ret = append(ret, column) + } + + return ret +} diff --git a/generator/mysql/mysql_generator.go b/generator/mysql/mysql_generator.go index 7f5d99a..ab00822 100644 --- a/generator/mysql/mysql_generator.go +++ b/generator/mysql/mysql_generator.go @@ -3,11 +3,11 @@ package mysql import ( "database/sql" "fmt" - "github.com/go-jet/jet/v2/generator/internal/metadata" - "github.com/go-jet/jet/v2/generator/internal/template" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/mysql" - "path" ) // DBConnection contains MySQL connection details @@ -22,7 +22,7 @@ type DBConnection struct { } // Generate generates jet files at destination dir from database connection details -func Generate(destDir string, dbConn DBConnection) (err error) { +func Generate(destDir string, dbConn DBConnection, generatorTemplate ...template.Template) (err error) { defer utils.ErrorCatch(&err) db := openConnection(dbConn) @@ -30,11 +30,14 @@ func Generate(destDir string, dbConn DBConnection) (err error) { fmt.Println("Retrieving database information...") // No schemas in MySQL - dbInfo := metadata.GetSchemaMetaData(db, dbConn.DBName, &mySqlQuerySet{}) + schemaMetaData := metadata.GetSchema(db, &mySqlQuerySet{}, dbConn.DBName) - genPath := path.Join(destDir, dbConn.DBName) + genTemplate := template.Default(mysql.Dialect) + if len(generatorTemplate) > 0 { + genTemplate = generatorTemplate[0] + } - template.GenerateFiles(genPath, dbInfo, mysql.Dialect) + template.ProcessSchema(destDir, schemaMetaData, genTemplate) return nil } @@ -46,10 +49,10 @@ func openConnection(dbConn DBConnection) *sql.DB { } fmt.Println("Connecting to MySQL database: " + connectionString) db, err := sql.Open("mysql", connectionString) - utils.PanicOnError(err) + throw.OnError(err) err = db.Ping() - utils.PanicOnError(err) + throw.OnError(err) return db } diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index 1b4e2b2..a409eb7 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -1,81 +1,91 @@ package mysql import ( + "context" "database/sql" - "github.com/go-jet/jet/v2/generator/internal/metadata" - "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/utils/throw" + "github.com/go-jet/jet/v2/qrm" "strings" ) // mySqlQuerySet is dialect query set for MySQL type mySqlQuerySet struct{} -func (m *mySqlQuerySet) ListOfTablesQuery() string { - return ` -SELECT table_name +func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table { + query := ` +SELECT table_name as "table.name" FROM INFORMATION_SCHEMA.tables WHERE table_schema = ? and table_type = ?; ` + var tables []metadata.Table + + err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) + throw.OnError(err) + + for i := range tables { + tables[i].Columns = m.GetTableColumnsMetaData(db, schemaName, tables[i].Name) + } + + return tables } -func (m *mySqlQuerySet) PrimaryKeysQuery() string { - return ` -SELECT k.column_name -FROM information_schema.table_constraints t -JOIN information_schema.key_column_usage k -USING(constraint_name,table_schema,table_name) -WHERE t.constraint_type='PRIMARY KEY' - AND t.table_schema= ? - AND t.table_name= ?; -` -} - -func (m *mySqlQuerySet) ListOfColumnsQuery() string { - return ` -SELECT COLUMN_NAME, - IS_NULLABLE, IF(COLUMN_TYPE = 'tinyint(1)', 'boolean', DATA_TYPE), - IF(DATA_TYPE = 'enum', CONCAT(TABLE_NAME, '_', COLUMN_NAME), ''), - COLUMN_TYPE LIKE '%unsigned%' -FROM information_schema.columns -WHERE table_schema = ? and table_name = ? +func (m mySqlQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column { + query := ` +WITH primaryKeys AS ( + SELECT k.column_name + FROM information_schema.table_constraints t + JOIN information_schema.key_column_usage k USING(constraint_name,table_schema,table_name) + WHERE table_schema = ? AND table_name = ? AND t.constraint_type='PRIMARY KEY' +) +SELECT COLUMN_NAME AS "column.Name", + IS_NULLABLE = "YES" AS "column.IsNullable", + (EXISTS(SELECT 1 FROM primaryKeys AS pk WHERE pk.column_name = columns.column_name)) AS "column.IsPrimaryKey", + IF (COLUMN_TYPE = 'tinyint(1)', + 'boolean', + IF (DATA_TYPE='enum', + CONCAT(TABLE_NAME, '_', COLUMN_NAME), + DATA_TYPE) + ) AS "dataType.Name", + IF (DATA_TYPE = 'enum', 'enum', 'base') AS "dataType.Kind", + COLUMN_TYPE LIKE '%unsigned%' AS "dataType.IsUnsigned" +FROM information_schema.columns +WHERE table_schema = ? AND table_name = ? ORDER BY ordinal_position; ` + var columns []metadata.Column + err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns) + throw.OnError(err) + + return columns } -func (m *mySqlQuerySet) ListOfEnumsQuery() string { - return ` -SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ), SUBSTRING(c.COLUMN_TYPE,5) +func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum { + query := ` +SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ) as "name", + SUBSTRING(c.COLUMN_TYPE,5) as "values" FROM information_schema.columns as c INNER JOIN information_schema.tables as t on (t.table_schema = c.table_schema AND t.table_name = c.table_name) WHERE c.table_schema = ? AND DATA_TYPE = 'enum'; ` -} + var queryResult []struct { + Name string + Values string + } -func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData { + err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult) + throw.OnError(err) - rows, err := db.Query(m.ListOfEnumsQuery(), schemaName) - utils.PanicOnError(err) - defer rows.Close() + var ret []metadata.Enum - ret := []metadata.MetaData{} + for _, result := range queryResult { + enumValues := strings.Replace(result.Values[1:len(result.Values)-1], "'", "", -1) - for rows.Next() { - var enumName string - var enumValues string - err = rows.Scan(&enumName, &enumValues) - utils.PanicOnError(err) - - enumValues = strings.Replace(enumValues[1:len(enumValues)-1], "'", "", -1) - - ret = append(ret, metadata.EnumMetaData{ - EnumName: enumName, - Values: strings.Split(enumValues, ","), + ret = append(ret, metadata.Enum{ + Name: result.Name, + Values: strings.Split(enumValues, ","), }) } - err = rows.Err() - utils.PanicOnError(err) - return ret - } diff --git a/generator/postgres/postgres_generator.go b/generator/postgres/postgres_generator.go index 970fd2d..ebb5420 100644 --- a/generator/postgres/postgres_generator.go +++ b/generator/postgres/postgres_generator.go @@ -3,9 +3,10 @@ package postgres import ( "database/sql" "fmt" - "github.com/go-jet/jet/v2/generator/internal/metadata" - "github.com/go-jet/jet/v2/generator/internal/template" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/postgres" "path" "strconv" @@ -25,38 +26,39 @@ type DBConnection struct { } // Generate generates jet files at destination dir from database connection details -func Generate(destDir string, dbConn DBConnection) (err error) { +func Generate(destDir string, dbConn DBConnection, genTemplate ...template.Template) (err error) { defer utils.ErrorCatch(&err) - db, err := openConnection(dbConn) - utils.PanicOnError(err) + db := openConnection(dbConn) defer utils.DBClose(db) fmt.Println("Retrieving schema information...") - schemaInfo := metadata.GetSchemaMetaData(db, dbConn.SchemaName, &postgresQuerySet{}) - genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName) - template.GenerateFiles(genPath, schemaInfo, postgres.Dialect) + generatorTemplate := template.Default(postgres.Dialect) + if len(genTemplate) > 0 { + generatorTemplate = genTemplate[0] + } + + schemaMetadata := metadata.GetSchema(db, &postgresQuerySet{}, dbConn.SchemaName) + + dirPath := path.Join(destDir, dbConn.DBName) + + template.ProcessSchema(dirPath, schemaMetadata, generatorTemplate) return } -func openConnection(dbConn DBConnection) (*sql.DB, error) { +func openConnection(dbConn DBConnection) *sql.DB { connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s", dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params) fmt.Println("Connecting to postgres database: " + connectionString) db, err := sql.Open("postgres", connectionString) - if err != nil { - return nil, err - } + throw.OnError(err) err = db.Ping() + throw.OnError(err) - if err != nil { - return nil, err - } - - return db, nil + return db } diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index 0fc8fdc..e2fb969 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -1,81 +1,83 @@ package postgres import ( + "context" "database/sql" - "github.com/go-jet/jet/v2/generator/internal/metadata" - "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/utils/throw" + "github.com/go-jet/jet/v2/qrm" ) // postgresQuerySet is dialect query set for PostgreSQL type postgresQuerySet struct{} -func (p *postgresQuerySet) ListOfTablesQuery() string { - return ` -SELECT table_name +func (p postgresQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table { + query := ` +SELECT table_name as "table.name" FROM information_schema.tables -where table_schema = $1 and table_type = $2; +WHERE table_schema = $1 and table_type = $2; ` + var tables []metadata.Table + + err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) + throw.OnError(err) + + for i := range tables { + tables[i].Columns = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name) + } + + return tables } -func (p *postgresQuerySet) PrimaryKeysQuery() string { - return ` -SELECT c.column_name -FROM information_schema.key_column_usage AS c -LEFT JOIN information_schema.table_constraints AS t -ON t.constraint_name = c.constraint_name -WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY'; -` -} - -func (p *postgresQuerySet) ListOfColumnsQuery() string { - return ` -SELECT column_name, is_nullable, data_type, udt_name, FALSE -FROM information_schema.columns +func (p postgresQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column { + query := ` +WITH primaryKeys AS ( + SELECT column_name + FROM information_schema.key_column_usage AS c + LEFT JOIN information_schema.table_constraints AS t + ON t.constraint_name = c.constraint_name + WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY' +) +SELECT column_name as "column.Name", + is_nullable = 'YES' as "column.isNullable", + (EXISTS(SELECT 1 from primaryKeys as pk where pk.column_name = columns.column_name)) as "column.IsPrimaryKey", + dataType.kind as "dataType.Kind", + (case dataType.Kind when 'base' then data_type else LTRIM(udt_name, '_') end) as "dataType.Name", + FALSE as "dataType.isUnsigned" +FROM information_schema.columns, + LATERAL (select (case data_type + when 'ARRAY' then 'array' + when 'USER-DEFINED' then + case (select typtype from pg_type where typname = columns.udt_name) + when 'e' then 'enum' + else 'user-defined' + end + else 'base' + end) as Kind) as dataType where table_schema = $1 and table_name = $2 -order by ordinal_position;` +order by ordinal_position; +` + var columns []metadata.Column + err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns) + throw.OnError(err) + + return columns } -func (p *postgresQuerySet) ListOfEnumsQuery() string { - return ` -SELECT t.typname, - e.enumlabel +func (p postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum { + query := ` +SELECT t.typname as "enum.name", + e.enumlabel as "values" FROM pg_catalog.pg_type t JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace WHERE n.nspname = $1 ORDER BY n.nspname, t.typname, e.enumsortorder;` -} - -func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData { - rows, err := db.Query(p.ListOfEnumsQuery(), schemaName) - utils.PanicOnError(err) - defer rows.Close() - - enumsInfosMap := map[string][]string{} - for rows.Next() { - var enumName string - var enumValue string - err = rows.Scan(&enumName, &enumValue) - utils.PanicOnError(err) - - enumValues := enumsInfosMap[enumName] - - enumValues = append(enumValues, enumValue) - - enumsInfosMap[enumName] = enumValues - } - - err = rows.Err() - utils.PanicOnError(err) - - ret := []metadata.MetaData{} - - for enumName, enumValues := range enumsInfosMap { - ret = append(ret, metadata.EnumMetaData{ - EnumName: enumName, - Values: enumValues, - }) - } - - return ret + + var result []metadata.Enum + + err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result) + throw.OnError(err) + + return result } diff --git a/generator/template/file_templates.go b/generator/template/file_templates.go new file mode 100644 index 0000000..a738ea6 --- /dev/null +++ b/generator/template/file_templates.go @@ -0,0 +1,223 @@ +package template + +var autoGenWarningTemplate = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +` + +var tableSQLBuilderTemplate = ` +{{define "column-list" -}} + {{- range $i, $c := . }} +{{- $field := columnField $c}} + {{- if gt $i 0 }}, {{end}}{{$field.Name}}Column + {{- end}} +{{- end}} + +package {{package}} + +import ( + "github.com/go-jet/jet/v2/{{dialect.PackageName}}" +) + +var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "") + +type {{tableTemplate.TypeName}} struct { + {{dialect.PackageName}}.Table + + //Columns +{{- range $i, $c := .Columns}} +{{- $field := columnField $c}} + {{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}} +{{- end}} + + AllColumns {{dialect.PackageName}}.ColumnList + MutableColumns {{dialect.PackageName}}.ColumnList +} + +// AS creates new {{tableTemplate.TypeName}} with assigned alias +func (a {{tableTemplate.TypeName}}) AS(alias string) {{tableTemplate.TypeName}} { + return new{{tableTemplate.TypeName}}(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new {{tableTemplate.TypeName}} with assigned schema name +func (a {{tableTemplate.TypeName}}) FromSchema(schemaName string) {{tableTemplate.TypeName}} { + return new{{tableTemplate.TypeName}}(schemaName, a.TableName(), a.Alias()) +} + +func new{{tableTemplate.TypeName}}(schemaName, tableName, alias string) {{tableTemplate.TypeName}} { + var ( +{{- range $i, $c := .Columns}} +{{- $field := columnField $c}} + {{$field.Name}}Column = {{dialect.PackageName}}.{{$field.Type}}Column("{{$c.Name}}") +{{- end}} + allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} } + mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} } + ) + + return {{tableTemplate.TypeName}}{ + Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns +{{- range $i, $c := .Columns}} +{{- $field := columnField $c}} + {{$field.Name}}: {{$field.Name}}Column, +{{- end}} + + AllColumns: allColumns, + MutableColumns: mutableColumns, + } +} +` + +var tablePostgreSQLBuilderTemplate = ` +{{define "column-list" -}} + {{- range $i, $c := . }} +{{- $field := columnField $c}} + {{- if gt $i 0 }}, {{end}}{{$field.Name}}Column + {{- end}} +{{- end}} + +package {{package}} + +import ( + "github.com/go-jet/jet/v2/{{dialect.PackageName}}" +) + +var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "") + +type {{structImplName}} struct { + {{dialect.PackageName}}.Table + + //Columns +{{- range $i, $c := .Columns}} +{{- $field := columnField $c}} + {{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}} +{{- end}} + + AllColumns {{dialect.PackageName}}.ColumnList + MutableColumns {{dialect.PackageName}}.ColumnList +} + +type {{tableTemplate.TypeName}} struct { + {{structImplName}} + + EXCLUDED {{structImplName}} +} + +// AS creates new {{tableTemplate.TypeName}} with assigned alias +func (a {{tableTemplate.TypeName}}) AS(alias string) *{{tableTemplate.TypeName}} { + return new{{tableTemplate.TypeName}}(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new {{tableTemplate.TypeName}} with assigned schema name +func (a {{tableTemplate.TypeName}}) FromSchema(schemaName string) *{{tableTemplate.TypeName}} { + return new{{tableTemplate.TypeName}}(schemaName, a.TableName(), a.Alias()) +} + +func new{{tableTemplate.TypeName}}(schemaName, tableName, alias string) *{{tableTemplate.TypeName}} { + return &{{tableTemplate.TypeName}}{ + {{structImplName}}: new{{tableTemplate.TypeName}}Impl(schemaName, tableName, alias), + EXCLUDED: new{{tableTemplate.TypeName}}Impl("", "excluded", ""), + } +} + +func new{{tableTemplate.TypeName}}Impl(schemaName, tableName, alias string) {{structImplName}} { + var ( +{{- range $i, $c := .Columns}} +{{- $field := columnField $c}} + {{$field.Name}}Column = {{dialect.PackageName}}.{{$field.Type}}Column("{{$c.Name}}") +{{- end}} + allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} } + mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} } + ) + + return {{structImplName}}{ + Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns +{{- range $i, $c := .Columns}} +{{- $field := columnField $c}} + {{$field.Name}}: {{$field.Name}}Column, +{{- end}} + + AllColumns: allColumns, + MutableColumns: mutableColumns, + } +} +` + +var tableModelFileTemplate = `package {{package}} + +{{ with modelImports }} +import ( +{{- range .}} + "{{.}}" +{{- end}} +) +{{end}} + +{{$modelTableTemplate := tableTemplate}} +type {{$modelTableTemplate.TypeName}} struct { +{{- range .Columns}} +{{- $field := structField .}} + {{$field.Name}} {{$field.Type.Name}} ` + "{{$field.TagsString}}" + ` +{{- end}} +} + +` + +var enumSQLBuilderTemplate = `package {{package}} + +import "github.com/go-jet/jet/v2/{{dialect.PackageName}}" + +var {{enumTemplate.InstanceName}} = &struct { +{{- range $index, $value := .Values}} + {{enumValueName $value}} {{dialect.PackageName}}.StringExpression +{{- end}} +} { +{{- range $index, $value := .Values}} + {{enumValueName $value}}: {{dialect.PackageName}}.NewEnumValue("{{$value}}"), +{{- end}} +} +` + +var enumModelTemplate = `package {{package}} +{{- $enumTemplate := enumTemplate}} + +import "errors" + +type {{$enumTemplate.TypeName}} string + +const ( +{{- range $_, $value := .Values}} + {{valueName $value}} {{$enumTemplate.TypeName}} = "{{$value}}" +{{- end}} +) + +func (e *{{$enumTemplate.TypeName}}) Scan(value interface{}) error { + if v, ok := value.(string); !ok { + return errors.New("jet: Invalid scan value for {{$enumTemplate.TypeName}} enum. Enum value has to be of type string") + } else { + switch string(v) { +{{- range $_, $value := .Values}} + case "{{$value}}": + *e = {{valueName $value}} +{{- end}} + default: + return errors.New("jet: Invalid scan value '" + string(v) + "' for {{$enumTemplate.TypeName}} enum") + } + + return nil + } +} + +func (e {{$enumTemplate.TypeName}}) String() string { + return string(e) +} + +` diff --git a/generator/template/generator_template.go b/generator/template/generator_template.go new file mode 100644 index 0000000..38e8fbb --- /dev/null +++ b/generator/template/generator_template.go @@ -0,0 +1,60 @@ +package template + +import ( + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/jet" +) + +// Template is generator template used for file generation +type Template struct { + Dialect jet.Dialect + Schema func(schemaMetaData metadata.Schema) Schema +} + +// Default is default generator template implementation +func Default(dialect jet.Dialect) Template { + return Template{ + Dialect: dialect, + Schema: DefaultSchema, + } +} + +// UseSchema replaces current schema generate function with a new implementation and returns new generator template +func (t Template) UseSchema(schemaFunc func(schemaMetaData metadata.Schema) Schema) Template { + t.Schema = schemaFunc + return t +} + +// Schema is schema generator template used to generate schema(model and sql builder) files +type Schema struct { + Path string + Model Model + SQLBuilder SQLBuilder +} + +// UsePath replaces path and returns new schema template +func (s Schema) UsePath(path string) Schema { + s.Path = path + return s +} + +// UseModel returns new schema template with replaced template for model files generation +func (s Schema) UseModel(model Model) Schema { + s.Model = model + return s +} + +// UseSQLBuilder returns new schema with replaced template for sql builder files generation +func (s Schema) UseSQLBuilder(sqlBuilder SQLBuilder) Schema { + s.SQLBuilder = sqlBuilder + return s +} + +// DefaultSchema returns default schema template implementation +func DefaultSchema(schemaMetaData metadata.Schema) Schema { + return Schema{ + Path: schemaMetaData.Name, + Model: DefaultModel(), + SQLBuilder: DefaultSQLBuilder(), + } +} diff --git a/generator/template/model_template.go b/generator/template/model_template.go new file mode 100644 index 0000000..732cc2f --- /dev/null +++ b/generator/template/model_template.go @@ -0,0 +1,327 @@ +package template + +import ( + "fmt" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/utils" + "github.com/google/uuid" + "path" + "reflect" + "strings" + "time" +) + +// Model is template for model files generation +type Model struct { + Skip bool + Path string + Table func(table metadata.Table) TableModel + View func(table metadata.Table) ViewModel + Enum func(enum metadata.Enum) EnumModel +} + +// PackageName returns package name of model types +func (m Model) PackageName() string { + return path.Base(m.Path) +} + +// UsePath returns new Model template with replaced file path +func (m Model) UsePath(path string) Model { + m.Path = path + return m +} + +// UseTable returns new Model template with replaced template for table model files generation +func (m Model) UseTable(tableModelFunc func(table metadata.Table) TableModel) Model { + m.Table = tableModelFunc + return m +} + +// UseView returns new Model template with replaced template for view model files generation +func (m Model) UseView(tableModelFunc func(table metadata.Table) TableModel) Model { + m.View = tableModelFunc + return m +} + +// UseEnum returns new Model template with replaced template for enum model files generation +func (m Model) UseEnum(enumFunc func(enumMetaData metadata.Enum) EnumModel) Model { + m.Enum = enumFunc + return m +} + +// DefaultModel returns default Model template implementation +func DefaultModel() Model { + return Model{ + Skip: false, + Path: "/model", + Table: DefaultTableModel, + View: DefaultViewModel, + Enum: DefaultEnumModel, + } +} + +// TableModel is template for table model files generation +type TableModel struct { + Skip bool + FileName string + TypeName string + Field func(columnMetaData metadata.Column) TableModelField +} + +// ViewModel is template for view model files generation +type ViewModel = TableModel + +// DefaultViewModel is default view template implementation +var DefaultViewModel = DefaultTableModel + +// DefaultTableModel is default table template implementation +func DefaultTableModel(tableMetaData metadata.Table) TableModel { + return TableModel{ + FileName: utils.ToGoFileName(tableMetaData.Name), + TypeName: utils.ToGoIdentifier(tableMetaData.Name), + Field: DefaultTableModelField, + } +} + +// UseFileName returns new TableModel with new file name set +func (t TableModel) UseFileName(fileName string) TableModel { + t.FileName = fileName + return t +} + +// UseTypeName returns new TableModel with new type name set +func (t TableModel) UseTypeName(typeName string) TableModel { + t.TypeName = typeName + return t +} + +// UseField returns new TableModel with new TableModelField template function +func (t TableModel) UseField(structFieldFunc func(columnMetaData metadata.Column) TableModelField) TableModel { + t.Field = structFieldFunc + return t +} + +func getTableModelImports(modelType TableModel, tableMetaData metadata.Table) []string { + importPaths := map[string]bool{} + for _, columnMetaData := range tableMetaData.Columns { + field := modelType.Field(columnMetaData) + importPath := field.Type.ImportPath + + if importPath != "" { + importPaths[importPath] = true + } + } + + var ret []string + for importPath := range importPaths { + ret = append(ret, importPath) + } + + return ret +} + +// EnumModel is template for enum model files generation +type EnumModel struct { + Skip bool + FileName string + TypeName string + ValueName func(value string) string +} + +// UseFileName returns new EnumModel with new file name set +func (em EnumModel) UseFileName(fileName string) EnumModel { + em.FileName = fileName + return em +} + +// UseTypeName returns new EnumModel with new type name set +func (em EnumModel) UseTypeName(typeName string) EnumModel { + em.TypeName = typeName + return em +} + +// DefaultEnumModel returns default implementation for EnumModel +func DefaultEnumModel(enumMetaData metadata.Enum) EnumModel { + typeName := utils.ToGoIdentifier(enumMetaData.Name) + + return EnumModel{ + FileName: utils.ToGoFileName(enumMetaData.Name), + TypeName: typeName, + ValueName: func(value string) string { + return typeName + "_" + utils.ToGoIdentifier(value) + }, + } +} + +// TableModelField is template for table model field generation +type TableModelField struct { + Name string + Type Type + Tags []string +} + +// DefaultTableModelField returns default TableModelField implementation +func DefaultTableModelField(columnMetaData metadata.Column) TableModelField { + var tags []string + + if columnMetaData.IsPrimaryKey { + tags = append(tags, `sql:"primary_key"`) + } + + return TableModelField{ + Name: utils.ToGoIdentifier(columnMetaData.Name), + Type: getType(columnMetaData), + Tags: tags, + } +} + +// UseType returns new TypeModelField with a new field type set +func (f TableModelField) UseType(t Type) TableModelField { + f.Type = t + return f +} + +// UseName returns new TableModelField implementation with new field name set +func (f TableModelField) UseName(name string) TableModelField { + f.Name = name + return f +} + +// UseTags returns new TableModelField implementation with additional tags added. +func (f TableModelField) UseTags(tags ...string) TableModelField { + f.Tags = append(f.Tags, tags...) + return f +} + +// TagsString returns tags string representation +func (f TableModelField) TagsString() string { + if len(f.Tags) == 0 { + return "" + } + + return fmt.Sprintf("`%s`", strings.Join(f.Tags, " ")) +} + +// Type represents type of the struct field +type Type struct { + ImportPath string + Name string +} + +// NewType creates new type for dummy object +func NewType(dummyObject interface{}) Type { + return Type{ + ImportPath: getImportPath(dummyObject), + Name: getTypeName(dummyObject), + } +} + +func getTypeName(t interface{}) string { + typeStr := reflect.TypeOf(t).String() + typeStr = strings.Replace(typeStr, "[]uint8", "[]byte", -1) + + return typeStr +} + +func getImportPath(dummyData interface{}) string { + dataType := reflect.TypeOf(dummyData) + if dataType.Kind() == reflect.Ptr { + return dataType.Elem().PkgPath() + } + return dataType.PkgPath() +} + +func getType(columnMetadata metadata.Column) Type { + userDefinedType := getUserDefinedType(columnMetadata) + + if userDefinedType != "" { + if columnMetadata.IsNullable { + return Type{Name: "*" + userDefinedType} + } + return Type{Name: userDefinedType} + } + + return NewType(getGoType(columnMetadata)) +} + +func getUserDefinedType(column metadata.Column) string { + switch column.DataType.Kind { + case metadata.EnumType: + return utils.ToGoIdentifier(column.DataType.Name) + case metadata.UserDefinedType, metadata.ArrayType: + return "string" + } + + return "" +} + +func getGoType(column metadata.Column) interface{} { + defaultGoType := toGoType(column) + + if column.IsNullable { + return reflect.New(reflect.TypeOf(defaultGoType)).Interface() + } + + return defaultGoType +} + +// toGoType returns model type for column info. +func toGoType(column metadata.Column) interface{} { + switch column.DataType.Name { + case "USER-DEFINED", "enum": + return "" + case "boolean", "bool": + return false + case "tinyint": + if column.DataType.IsUnsigned { + return uint8(0) + } + return int8(0) + case "smallint", "int2", + "year": + if column.DataType.IsUnsigned { + return uint16(0) + } + return int16(0) + case "integer", "int4", + "mediumint", "int": //MySQL + if column.DataType.IsUnsigned { + return uint32(0) + } + return int32(0) + case "bigint", "int8": + if column.DataType.IsUnsigned { + return uint64(0) + } + return int64(0) + case "date", + "timestamp without time zone", "timestamp", + "timestamp with time zone", "timestamptz", + "time without time zone", "time", + "time with time zone", "timetz", + "datetime": // MySQL + return time.Time{} + case "bytea", + "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": //MySQL + return []byte("") + case "text", + "character", "bpchar", + "character varying", "varchar", + "tsvector", "bit", "bit varying", "varbit", + "money", "json", "jsonb", + "xml", "point", "interval", "line", "ARRAY", + "char", "tinytext", "mediumtext", "longtext": // MySQL + return "" + case "real", "float4": + return float32(0.0) + case "numeric", "decimal", + "double precision", "float8", "float", + "double": // MySQL + return float64(0.0) + case "uuid": + return uuid.UUID{} + default: + fmt.Println("- [Model ] Unsupported sql column '" + column.Name + " " + column.DataType.Name + "', using string instead.") + return "" + } +} diff --git a/generator/template/model_template_test.go b/generator/template/model_template_test.go new file mode 100644 index 0000000..a7bbe28 --- /dev/null +++ b/generator/template/model_template_test.go @@ -0,0 +1,45 @@ +package template + +import ( + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/stretchr/testify/require" + "testing" +) + +func Test_TableModelField(t *testing.T) { + require.Equal(t, DefaultTableModelField(metadata.Column{ + Name: "col_name", + IsPrimaryKey: true, + IsNullable: true, + DataType: metadata.DataType{ + Name: "smallint", + Kind: "base", + IsUnsigned: true, + }, + }), TableModelField{ + Name: "ColName", + Type: Type{ + ImportPath: "", + Name: "*uint16", + }, + Tags: []string{"sql:\"primary_key\""}, + }) + + require.Equal(t, DefaultTableModelField(metadata.Column{ + Name: "time_column_1", + IsPrimaryKey: false, + IsNullable: true, + DataType: metadata.DataType{ + Name: "timestamp with time zone", + Kind: "base", + IsUnsigned: false, + }, + }), TableModelField{ + Name: "TimeColumn1", + Type: Type{ + ImportPath: "time", + Name: "*time.Time", + }, + Tags: nil, + }) +} diff --git a/generator/template/process.go b/generator/template/process.go new file mode 100644 index 0000000..ff3775e --- /dev/null +++ b/generator/template/process.go @@ -0,0 +1,269 @@ +package template + +import ( + "bytes" + "fmt" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/jet" + "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" + "path" + "strings" + "text/template" +) + +// ProcessSchema will process schema metadata and constructs go files using generator Template +func ProcessSchema(dirPath string, schemaMetaData metadata.Schema, generatorTemplate Template) { + if schemaMetaData.IsEmpty() { + return + } + + schemaTemplate := generatorTemplate.Schema(schemaMetaData) + schemaPath := path.Join(dirPath, schemaTemplate.Path) + + fmt.Println("Destination directory:", schemaPath) + fmt.Println("Cleaning up destination directory...") + err := utils.CleanUpGeneratedFiles(schemaPath) + throw.OnError(err) + + processModel(schemaPath, schemaMetaData, schemaTemplate) + processSQLBuilder(schemaPath, generatorTemplate.Dialect, schemaMetaData, schemaTemplate) +} + +func processModel(dirPath string, schemaMetaData metadata.Schema, schemaTemplate Schema) { + modelTemplate := schemaTemplate.Model + + if modelTemplate.Skip { + fmt.Println("Skipping the generation of model types.") + return + } + + modelDirPath := path.Join(dirPath, modelTemplate.Path) + + err := utils.EnsureDirPath(modelDirPath) + throw.OnError(err) + + processTableModels("table", modelDirPath, schemaMetaData.TablesMetaData, modelTemplate) + processTableModels("view", modelDirPath, schemaMetaData.ViewsMetaData, modelTemplate) + processEnumModels(modelDirPath, schemaMetaData.EnumsMetaData, modelTemplate) +} + +func processSQLBuilder(dirPath string, dialect jet.Dialect, schemaMetaData metadata.Schema, schemaTemplate Schema) { + sqlBuilderTemplate := schemaTemplate.SQLBuilder + + if sqlBuilderTemplate.Skip { + fmt.Println("Skipping the generation of SQL Builder types.") + return + } + + sqlBuilderPath := path.Join(dirPath, sqlBuilderTemplate.Path) + + processTableSQLBuilder("table", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.TablesMetaData, sqlBuilderTemplate) + processTableSQLBuilder("view", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.ViewsMetaData, sqlBuilderTemplate) + processEnumSQLBuilder(sqlBuilderPath, dialect, schemaMetaData.EnumsMetaData, sqlBuilderTemplate) +} + +func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []metadata.Enum, sqlBuilder SQLBuilder) { + if len(enumsMetaData) == 0 { + return + } + + fmt.Printf("Generating enum sql builder files\n") + + for _, enumMetaData := range enumsMetaData { + enumTemplate := sqlBuilder.Enum(enumMetaData) + + if enumTemplate.Skip { + continue + } + + enumSQLBuilderPath := path.Join(dirPath, enumTemplate.Path) + + err := utils.EnsureDirPath(enumSQLBuilderPath) + throw.OnError(err) + + text, err := generateTemplate( + autoGenWarningTemplate+enumSQLBuilderTemplate, + enumMetaData, + template.FuncMap{ + "package": func() string { + return enumTemplate.PackageName() + }, + "dialect": func() jet.Dialect { + return dialect + }, + "enumTemplate": func() EnumSQLBuilder { + return enumTemplate + }, + "enumValueName": func(enumValue string) string { + return enumTemplate.ValueName(enumValue) + }, + }) + throw.OnError(err) + + err = utils.SaveGoFile(enumSQLBuilderPath, enumTemplate.FileName, text) + throw.OnError(err) + } +} + +func processTableSQLBuilder(fileTypes, dirPath string, + dialect jet.Dialect, + schemaMetaData metadata.Schema, + tablesMetaData []metadata.Table, + sqlBuilderTemplate SQLBuilder) { + + if len(tablesMetaData) == 0 { + return + } + + fmt.Printf("Generating %s sql builder files\n", fileTypes) + + for _, tableMetaData := range tablesMetaData { + + var tableSQLBuilderTemplate TableSQLBuilder + + if fileTypes == "view" { + tableSQLBuilderTemplate = sqlBuilderTemplate.View(tableMetaData) + } else { + tableSQLBuilderTemplate = sqlBuilderTemplate.Table(tableMetaData) + } + + if tableSQLBuilderTemplate.Skip { + continue + } + + tableSQLBuilderPath := path.Join(dirPath, tableSQLBuilderTemplate.Path) + + err := utils.EnsureDirPath(tableSQLBuilderPath) + throw.OnError(err) + + text, err := generateTemplate( + autoGenWarningTemplate+getTableSQLBuilderTemplate(dialect), + tableMetaData, + template.FuncMap{ + "package": func() string { + return tableSQLBuilderTemplate.PackageName() + }, + "dialect": func() jet.Dialect { + return dialect + }, + "schemaName": func() string { + return schemaMetaData.Name + }, + "tableTemplate": func() TableSQLBuilder { + return tableSQLBuilderTemplate + }, + "structImplName": func() string { // postgres only + structName := tableSQLBuilderTemplate.TypeName + return string(strings.ToLower(structName)[0]) + structName[1:] + }, + "columnField": func(columnMetaData metadata.Column) TableSQLBuilderColumn { + return tableSQLBuilderTemplate.Column(columnMetaData) + }, + }) + throw.OnError(err) + + err = utils.SaveGoFile(tableSQLBuilderPath, tableSQLBuilderTemplate.FileName, text) + throw.OnError(err) + } +} + +func getTableSQLBuilderTemplate(dialect jet.Dialect) string { + if dialect.Name() == "PostgreSQL" { + return tablePostgreSQLBuilderTemplate + } + + return tableSQLBuilderTemplate +} + +func processTableModels(fileTypes, modelDirPath string, tablesMetaData []metadata.Table, modelTemplate Model) { + if len(tablesMetaData) == 0 { + return + } + fmt.Printf("Generating %s model files...\n", fileTypes) + + for _, tableMetaData := range tablesMetaData { + var tableTemplate TableModel + + if fileTypes == "table" { + tableTemplate = modelTemplate.Table(tableMetaData) + } else { + tableTemplate = modelTemplate.View(tableMetaData) + } + + if tableTemplate.Skip { + continue + } + + text, err := generateTemplate( + autoGenWarningTemplate+tableModelFileTemplate, + tableMetaData, + template.FuncMap{ + "package": func() string { + return modelTemplate.PackageName() + }, + "modelImports": func() []string { + return getTableModelImports(tableTemplate, tableMetaData) + }, + "tableTemplate": func() TableModel { + return tableTemplate + }, + "structField": func(columnMetaData metadata.Column) TableModelField { + return tableTemplate.Field(columnMetaData) + }, + }) + throw.OnError(err) + + err = utils.SaveGoFile(modelDirPath, tableTemplate.FileName, text) + throw.OnError(err) + } +} + +func processEnumModels(modelDir string, enumsMetaData []metadata.Enum, modelTemplate Model) { + if len(enumsMetaData) == 0 { + return + } + fmt.Print("Generating enum model files...\n") + + for _, enumMetaData := range enumsMetaData { + enumTemplate := modelTemplate.Enum(enumMetaData) + + if enumTemplate.Skip { + continue + } + + text, err := generateTemplate( + autoGenWarningTemplate+enumModelTemplate, + enumMetaData, + template.FuncMap{ + "package": func() string { + return modelTemplate.PackageName() + }, + "enumTemplate": func() EnumModel { + return enumTemplate + }, + "valueName": func(value string) string { + return enumTemplate.ValueName(value) + }, + }) + throw.OnError(err) + + err = utils.SaveGoFile(modelDir, enumTemplate.FileName, text) + throw.OnError(err) + } +} + +func generateTemplate(templateText string, templateData interface{}, funcMap template.FuncMap) ([]byte, error) { + t, err := template.New("sqlBuilderTableTemplate").Funcs(funcMap).Parse(templateText) + + if err != nil { + return nil, err + } + + var buf bytes.Buffer + if err := t.Execute(&buf, templateData); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} diff --git a/generator/template/sql_builder_template.go b/generator/template/sql_builder_template.go new file mode 100644 index 0000000..8d2d18c --- /dev/null +++ b/generator/template/sql_builder_template.go @@ -0,0 +1,225 @@ +package template + +import ( + "fmt" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/utils" + "path" + "unicode" +) + +// SQLBuilder is template for generating sql builder files +type SQLBuilder struct { + Skip bool + Path string + Table func(table metadata.Table) TableSQLBuilder + View func(view metadata.Table) TableSQLBuilder + Enum func(enum metadata.Enum) EnumSQLBuilder +} + +// DefaultSQLBuilder returns default SQLBuilder implementation +func DefaultSQLBuilder() SQLBuilder { + return SQLBuilder{ + Path: "", + Table: DefaultTableSQLBuilder, + View: DefaultViewSQLBuilder, + Enum: DefaultEnumSQLBuilder, + } +} + +// UsePath returns new SQLBuilder with new relative path set +func (sb SQLBuilder) UsePath(path string) SQLBuilder { + sb.Path = path + return sb +} + +// UseTable returns new SQLBuilder with new TableSQLBuilder template function set +func (sb SQLBuilder) UseTable(tableFunc func(table metadata.Table) TableSQLBuilder) SQLBuilder { + sb.Table = tableFunc + return sb +} + +// UseView returns new SQLBuilder with new ViewSQLBuilder template function set +func (sb SQLBuilder) UseView(viewFunc func(table metadata.Table) ViewSQLBuilder) SQLBuilder { + sb.View = viewFunc + return sb +} + +// UseEnum returns new SQLBuilder with new EnumSQLBuilder template function set +func (sb SQLBuilder) UseEnum(enumFunc func(enum metadata.Enum) EnumSQLBuilder) SQLBuilder { + sb.Enum = enumFunc + return sb +} + +// TableSQLBuilder is template for generating table SQLBuilder files +type TableSQLBuilder struct { + Skip bool + Path string + FileName string + InstanceName string + TypeName string + Column func(columnMetaData metadata.Column) TableSQLBuilderColumn +} + +// ViewSQLBuilder is template for generating view SQLBuilder files +type ViewSQLBuilder = TableSQLBuilder + +// DefaultTableSQLBuilder returns default implementation for TableSQLBuilder +func DefaultTableSQLBuilder(tableMetaData metadata.Table) TableSQLBuilder { + return TableSQLBuilder{ + Path: "/table", + FileName: utils.ToGoFileName(tableMetaData.Name), + InstanceName: utils.ToGoIdentifier(tableMetaData.Name), + TypeName: utils.ToGoIdentifier(tableMetaData.Name) + "Table", + Column: DefaultTableSQLBuilderColumn, + } +} + +// DefaultViewSQLBuilder returns default implementation for ViewSQLBuilder +func DefaultViewSQLBuilder(viewMetaData metadata.Table) ViewSQLBuilder { + tableSQLBuilder := DefaultTableSQLBuilder(viewMetaData) + tableSQLBuilder.Path = "/view" + return tableSQLBuilder +} + +// PackageName returns package name of table sql builder types +func (tb TableSQLBuilder) PackageName() string { + return path.Base(tb.Path) +} + +// UsePath returns new TableSQLBuilder with new relative path set +func (tb TableSQLBuilder) UsePath(path string) TableSQLBuilder { + tb.Path = path + return tb +} + +// UseFileName returns new TableSQLBuilder with new file name set +func (tb TableSQLBuilder) UseFileName(name string) TableSQLBuilder { + tb.FileName = name + return tb +} + +// UseInstanceName returns new TableSQLBuilder with new instance name set +func (tb TableSQLBuilder) UseInstanceName(name string) TableSQLBuilder { + tb.InstanceName = name + return tb +} + +// UseTypeName returns new TableSQLBuilder with new type name set +func (tb TableSQLBuilder) UseTypeName(name string) TableSQLBuilder { + tb.TypeName = name + return tb +} + +// UseColumn returns new TableSQLBuilder with new column template function set +func (tb TableSQLBuilder) UseColumn(columnsFunc func(column metadata.Column) TableSQLBuilderColumn) TableSQLBuilder { + tb.Column = columnsFunc + return tb +} + +// TableSQLBuilderColumn is template for table sql builder column +type TableSQLBuilderColumn struct { + Name string + Type string +} + +// DefaultTableSQLBuilderColumn returns default implementation of TableSQLBuilderColumn +func DefaultTableSQLBuilderColumn(columnMetaData metadata.Column) TableSQLBuilderColumn { + return TableSQLBuilderColumn{ + Name: utils.ToGoIdentifier(columnMetaData.Name), + Type: getSqlBuilderColumnType(columnMetaData), + } +} + +// getSqlBuilderColumnType returns type of jet sql builder column +func getSqlBuilderColumnType(columnMetaData metadata.Column) string { + if columnMetaData.DataType.Kind != metadata.BaseType { + return "String" + } + + switch columnMetaData.DataType.Name { + case "boolean": + return "Bool" + case "smallint", "integer", "bigint", + "tinyint", "mediumint", "int", "year": //MySQL + return "Integer" + case "date": + return "Date" + case "timestamp without time zone", + "timestamp", "datetime": //MySQL: + return "Timestamp" + case "timestamp with time zone": + return "Timestampz" + case "time without time zone", + "time": //MySQL + return "Time" + case "time with time zone": + return "Timez" + case "interval": + return "Interval" + case "USER-DEFINED", "enum", "text", "character", "character varying", "bytea", "uuid", + "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", + "char", "varchar", "binary", "varbinary", + "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL + return "String" + case "real", "numeric", "decimal", "double precision", "float", + "double": // MySQL + return "Float" + default: + fmt.Println("- [SQL Builder] Unsupported sql column '" + columnMetaData.Name + " " + columnMetaData.DataType.Name + "', using StringColumn instead.") + return "String" + } +} + +// EnumSQLBuilder is template for generating enum SQLBuilder files +type EnumSQLBuilder struct { + Skip bool + Path string + FileName string + InstanceName string + ValueName func(enumValue string) string +} + +// DefaultEnumSQLBuilder returns default implementation of EnumSQLBuilder +func DefaultEnumSQLBuilder(enumMetaData metadata.Enum) EnumSQLBuilder { + return EnumSQLBuilder{ + Path: "/enum", + FileName: utils.ToGoFileName(enumMetaData.Name), + InstanceName: utils.ToGoIdentifier(enumMetaData.Name), + ValueName: func(enumValue string) string { + return defaultEnumValueName(enumMetaData.Name, enumValue) + }, + } +} + +// PackageName returns enum sql builder package name +func (e EnumSQLBuilder) PackageName() string { + return path.Base(e.Path) +} + +// UsePath returns new EnumSQLBuilder with new path set +func (e EnumSQLBuilder) UsePath(path string) EnumSQLBuilder { + e.Path = path + return e +} + +// UseFileName returns new EnumSQLBuilder with new file name set +func (e EnumSQLBuilder) UseFileName(name string) EnumSQLBuilder { + e.FileName = name + return e +} + +// UseInstanceName returns new EnumSQLBuilder with instance name set +func (e EnumSQLBuilder) UseInstanceName(name string) EnumSQLBuilder { + e.InstanceName = name + return e +} + +func defaultEnumValueName(enumName, enumValue string) string { + enumValueName := utils.ToGoIdentifier(enumValue) + if !unicode.IsLetter([]rune(enumValueName)[0]) { + return utils.ToGoIdentifier(enumName) + enumValueName + } + + return enumValueName +} diff --git a/generator/template/sql_builder_template_test.go b/generator/template/sql_builder_template_test.go new file mode 100644 index 0000000..b3719d7 --- /dev/null +++ b/generator/template/sql_builder_template_test.go @@ -0,0 +1,11 @@ +package template + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestToGoEnumValueIdentifier(t *testing.T) { + require.Equal(t, defaultEnumValueName("enum_name", "enum_value"), "EnumValue") + require.Equal(t, defaultEnumValueName("NumEnum", "100"), "NumEnum100") +} diff --git a/internal/3rdparty/snaker/snaker.go b/internal/3rdparty/snaker/snaker.go index aadd928..32a19e6 100644 --- a/internal/3rdparty/snaker/snaker.go +++ b/internal/3rdparty/snaker/snaker.go @@ -9,8 +9,12 @@ import ( ) // SnakeToCamel returns a string converted from snake case to uppercase -func SnakeToCamel(s string) string { - return snakeToCamel(s, true) +func SnakeToCamel(s string, firstLetterUppercase ...bool) string { + upperCase := true + if len(firstLetterUppercase) > 0 { + upperCase = firstLetterUppercase[0] + } + return snakeToCamel(s, upperCase) } func snakeToCamel(s string, upperCase bool) string { diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index dd5e790..4158977 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -5,7 +5,7 @@ import ( "encoding/json" "fmt" "github.com/go-jet/jet/v2/internal/jet" - "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/qrm" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -66,7 +66,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) { filePath := getFullPath(testRelativePath) err := ioutil.WriteFile(filePath, jsonText, 0644) - utils.PanicOnError(err) + throw.OnError(err) } // AssertJSONFile check if data json representation is the same as json at testRelativePath diff --git a/internal/testutils/time_utils.go b/internal/testutils/time_utils.go index 2bf6653..b48129c 100644 --- a/internal/testutils/time_utils.go +++ b/internal/testutils/time_utils.go @@ -1,7 +1,7 @@ package testutils import ( - "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" "strings" "time" ) @@ -10,7 +10,7 @@ import ( func Date(t string) *time.Time { newTime, err := time.Parse("2006-01-02", t) - utils.PanicOnError(err) + throw.OnError(err) return &newTime } @@ -26,7 +26,7 @@ func TimestampWithoutTimeZone(t string, precision int) *time.Time { newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000") - utils.PanicOnError(err) + throw.OnError(err) return &newTime } @@ -35,7 +35,7 @@ func TimestampWithoutTimeZone(t string, precision int) *time.Time { func TimeWithoutTimeZone(t string) *time.Time { newTime, err := time.Parse("15:04:05", t) - utils.PanicOnError(err) + throw.OnError(err) return &newTime } @@ -44,7 +44,7 @@ func TimeWithoutTimeZone(t string) *time.Time { func TimeWithTimeZone(t string) *time.Time { newTimez, err := time.Parse("15:04:05 -0700", t) - utils.PanicOnError(err) + throw.OnError(err) return &newTimez } @@ -60,7 +60,7 @@ func TimestampWithTimeZone(t string, precision int) *time.Time { newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t) - utils.PanicOnError(err) + throw.OnError(err) return &newTime } diff --git a/internal/utils/throw/throw.go b/internal/utils/throw/throw.go new file mode 100644 index 0000000..9595c8b --- /dev/null +++ b/internal/utils/throw/throw.go @@ -0,0 +1,8 @@ +package throw + +// OnError will panic if err is not nill +func OnError(err error) { + if err != nil { + panic(err) + } +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 55005b4..6f6f178 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -10,7 +10,6 @@ import ( "reflect" "strings" "time" - "unicode" ) // ToGoIdentifier converts database to Go identifier. @@ -18,16 +17,6 @@ func ToGoIdentifier(databaseIdentifier string) string { return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier)) } -// ToGoEnumValueIdentifier converts enum value name to Go identifier name. -func ToGoEnumValueIdentifier(enumName, enumValue string) string { - enumValueIdentifier := ToGoIdentifier(enumValue) - if !unicode.IsLetter([]rune(enumValueIdentifier)[0]) { - return ToGoIdentifier(enumName) + enumValueIdentifier - } - - return enumValueIdentifier -} - // ToGoFileName converts database identifier to Go file name. func ToGoFileName(databaseIdentifier string) string { return strings.ToLower(replaceInvalidChars(databaseIdentifier)) @@ -35,7 +24,11 @@ func ToGoFileName(databaseIdentifier string) string { // SaveGoFile saves go file at folder dir, with name fileName and contents text. func SaveGoFile(dirPath, fileName string, text []byte) error { - newGoFilePath := filepath.Join(dirPath, fileName) + ".go" + newGoFilePath := filepath.Join(dirPath, fileName) + + if !strings.HasSuffix(newGoFilePath, ".go") { + newGoFilePath += ".go" + } file, err := os.Create(newGoFilePath) @@ -160,13 +153,6 @@ func MustBeInitializedPtr(val interface{}, errorStr string) { } } -// PanicOnError panics if err is not nil -func PanicOnError(err error) { - if err != nil { - panic(err) - } -} - // ErrorCatch is used in defer to recover from panics and to set err func ErrorCatch(err *error) { recovered := recover() diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go index f2b4f84..f374929 100644 --- a/internal/utils/utils_test.go +++ b/internal/utils/utils_test.go @@ -25,11 +25,6 @@ func TestToGoIdentifier(t *testing.T) { require.Equal(t, ToGoIdentifier("My-Table"), "MyTable") } -func TestToGoEnumValueIdentifier(t *testing.T) { - require.Equal(t, ToGoEnumValueIdentifier("enum_name", "enum_value"), "EnumValue") - require.Equal(t, ToGoEnumValueIdentifier("NumEnum", "100"), "NumEnum100") -} - func TestErrorCatchErr(t *testing.T) { var err error diff --git a/qrm/scan_context.go b/qrm/scan_context.go index 9d9e059..6e2b4c8 100644 --- a/qrm/scan_context.go +++ b/qrm/scan_context.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "fmt" "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" "reflect" "strings" ) @@ -216,7 +217,7 @@ func (s *scanContext) rowElem(index int) interface{} { value, err := valuer.Value() - utils.PanicOnError(err) + throw.OnError(err) return value } diff --git a/tests/dbconfig/dbconfig.go b/tests/dbconfig/dbconfig.go index cf48420..0481252 100644 --- a/tests/dbconfig/dbconfig.go +++ b/tests/dbconfig/dbconfig.go @@ -4,15 +4,15 @@ import "fmt" // Postgres test database connection parameters const ( - Host = "localhost" - Port = 5432 - User = "jet" - Password = "jet" - DBName = "jetdb" + PgHost = "localhost" + PgPort = 5432 + PgUser = "jet" + PgPassword = "jet" + PgDBName = "jetdb" ) // PostgresConnectString is PostgreSQL test database connection string -var PostgresConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName) +var PostgresConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", PgHost, PgPort, PgUser, PgPassword, PgDBName) // MySQL test database connection parameters const ( diff --git a/tests/init/init.go b/tests/init/init.go index a2f6eb3..a304804 100644 --- a/tests/init/init.go +++ b/tests/init/init.go @@ -6,7 +6,7 @@ import ( "fmt" "github.com/go-jet/jet/v2/generator/mysql" "github.com/go-jet/jet/v2/generator/postgres" - "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/tests/dbconfig" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" @@ -62,7 +62,7 @@ func initMySQLDB() { cmd.Stdout = os.Stdout err := cmd.Run() - utils.PanicOnError(err) + throw.OnError(err) err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{ Host: dbconfig.MySqLHost, @@ -72,7 +72,7 @@ func initMySQLDB() { DBName: dbName, }) - utils.PanicOnError(err) + throw.OnError(err) } } @@ -99,24 +99,24 @@ func initPostgresDB() { execFile(db, "./testdata/init/postgres/"+schemaName+".sql") err = postgres.Generate("./.gentestdata", postgres.DBConnection{ - Host: dbconfig.Host, + Host: dbconfig.PgHost, Port: 5432, - User: dbconfig.User, - Password: dbconfig.Password, - DBName: dbconfig.DBName, + User: dbconfig.PgUser, + Password: dbconfig.PgPassword, + DBName: dbconfig.PgDBName, SchemaName: schemaName, SslMode: "disable", }) - utils.PanicOnError(err) + throw.OnError(err) } } func execFile(db *sql.DB, sqlFilePath string) { testSampleSql, err := ioutil.ReadFile(sqlFilePath) - utils.PanicOnError(err) + throw.OnError(err) _, err = db.Exec(string(testSampleSql)) - utils.PanicOnError(err) + throw.OnError(err) } func printOnError(err error) { diff --git a/tests/internal/utils/file/file.go b/tests/internal/utils/file/file.go new file mode 100644 index 0000000..6d08d22 --- /dev/null +++ b/tests/internal/utils/file/file.go @@ -0,0 +1,25 @@ +package file + +import ( + "github.com/stretchr/testify/require" + "io/ioutil" + "os" + "path" + "testing" +) + +// Exists expects file to exist on path constructed from pathElems and returns content of the file +func Exists(t *testing.T, pathElems ...string) (fileContent string) { + modelFilePath := path.Join(pathElems...) + file, err := ioutil.ReadFile(modelFilePath) + require.Nil(t, err) + require.NotEmpty(t, file) + return string(file) +} + +// NotExists expects file not to exist on path constructed from pathElems +func NotExists(t *testing.T, pathElems ...string) { + modelFilePath := path.Join(pathElems...) + _, err := ioutil.ReadFile(modelFilePath) + require.True(t, os.IsNotExist(err)) +} diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index d96c1d3..b12585e 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -467,10 +467,10 @@ func TestStringOperators(t *testing.T) { AllTypes.Text.NOT_LIKE(String("_b_")), AllTypes.Text.REGEXP_LIKE(String("aba")), AllTypes.Text.REGEXP_LIKE(String("aba"), false), - String("ABA").REGEXP_LIKE(String("aba"), true), + //String("ABA").REGEXP_LIKE(String("aba"), true), AllTypes.Text.NOT_REGEXP_LIKE(String("aba")), AllTypes.Text.NOT_REGEXP_LIKE(String("aba"), false), - String("ABA").NOT_REGEXP_LIKE(String("aba"), true), + //String("ABA").NOT_REGEXP_LIKE(String("aba"), true), BIT_LENGTH(AllTypes.Text), CHAR_LENGTH(AllTypes.Char), diff --git a/tests/mysql/generator_template_test.go b/tests/mysql/generator_template_test.go new file mode 100644 index 0000000..e915e0f --- /dev/null +++ b/tests/mysql/generator_template_test.go @@ -0,0 +1,389 @@ +package mysql + +import ( + "database/sql" + "fmt" + "github.com/go-jet/jet/v2/generator/metadata" + mysql2 "github.com/go-jet/jet/v2/generator/mysql" + "github.com/go-jet/jet/v2/generator/template" + "github.com/go-jet/jet/v2/internal/3rdparty/snaker" + "github.com/go-jet/jet/v2/internal/utils" + postgres2 "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/tests/dbconfig" + file2 "github.com/go-jet/jet/v2/tests/internal/utils/file" + "github.com/stretchr/testify/require" + "path" + "testing" +) + +const tempTestDir = "./.tempTestDir" + +var defaultModelPath = path.Join(tempTestDir, "dvds/model") +var defaultActorModelFilePath = path.Join(tempTestDir, "dvds/model", "actor.go") +var defaultTableSQLBuilderFilePath = path.Join(tempTestDir, "dvds/table") +var defaultViewSQLBuilderFilePath = path.Join(tempTestDir, "dvds/view") +var defaultEnumSQLBuilderFilePath = path.Join(tempTestDir, "dvds/enum") +var defaultActorSQLBuilderFilePath = path.Join(tempTestDir, "dvds/table", "actor.go") + +var dbConnection = mysql2.DBConnection{ + Host: dbconfig.MySqLHost, + Port: dbconfig.MySQLPort, + User: dbconfig.MySQLUser, + Password: dbconfig.MySQLPassword, + DBName: "dvds", +} + +func TestGeneratorTemplate_Schema_ChangePath(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData).UsePath("new/schema/path") + }), + ) + + require.Nil(t, err) + + file2.Exists(t, tempTestDir, "new/schema/path/model/actor.go") + file2.Exists(t, tempTestDir, "new/schema/path/table/actor.go") + file2.Exists(t, tempTestDir, "new/schema/path/view/actor_info.go") + file2.Exists(t, tempTestDir, "new/schema/path/enum/film_rating.go") +} + +func TestGeneratorTemplate_Model_SkipGeneration(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.Model{ + Skip: true, + }) + }), + ) + + require.Nil(t, err) + + file2.NotExists(t, defaultActorModelFilePath) + file2.Exists(t, defaultTableSQLBuilderFilePath, "actor.go") + file2.Exists(t, defaultViewSQLBuilderFilePath, "actor_info.go") + file2.Exists(t, defaultEnumSQLBuilderFilePath, "film_rating.go") +} + +func TestGeneratorTemplate_SQLBuilder_SkipGeneration(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.SQLBuilder{ + Skip: true, + }) + }), + ) + + require.Nil(t, err) + + file2.Exists(t, defaultActorModelFilePath) + file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go") + file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go") + file2.NotExists(t, defaultEnumSQLBuilderFilePath, "film_rating.go") +} + +func TestGeneratorTemplate_Model_ChangePath(t *testing.T) { + const newModelPath = "/new/model/path" + + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel().UsePath(newModelPath)) + }), + ) + require.Nil(t, err) + + file2.Exists(t, tempTestDir, "dvds", newModelPath, "actor.go") + file2.NotExists(t, defaultActorModelFilePath) +} + +func TestGeneratorTemplate_SQLBuilder_ChangePath(t *testing.T) { + const newModelPath = "/new/sql-builder/path" + + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder().UsePath(newModelPath)) + }), + ) + require.Nil(t, err) + + file2.Exists(t, tempTestDir, "dvds", newModelPath, "table", "actor.go") + file2.Exists(t, tempTestDir, "dvds", newModelPath, "view", "actor_info.go") + file2.Exists(t, tempTestDir, "dvds", newModelPath, "enum", "film_rating.go") + + file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go") + file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go") + file2.NotExists(t, defaultEnumSQLBuilderFilePath, "film_rating.go") +} + +func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.DefaultTableModel(table). + UseFileName(schemaMetaData.Name + "_" + table.Name). + UseTypeName(utils.ToGoIdentifier(table.Name) + "Table") + }). + UseView(func(table metadata.Table) template.ViewModel { + return template.DefaultViewModel(table). + UseFileName(schemaMetaData.Name + "_" + table.Name + "_view"). + UseTypeName(utils.ToGoIdentifier(table.Name) + "View") + }). + UseEnum(func(enumMetaData metadata.Enum) template.EnumModel { + return template.DefaultEnumModel(enumMetaData). + UseFileName(enumMetaData.Name + "_enum"). + UseTypeName(utils.ToGoIdentifier(enumMetaData.Name) + "Enum") + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultModelPath, "dvds_actor.go") + require.Contains(t, actor, "type ActorTable struct {") + + actorInfo := file2.Exists(t, defaultModelPath, "dvds_actor_info_view.go") + require.Contains(t, actorInfo, "type ActorInfoView struct {") + + mpaaRating := file2.Exists(t, defaultModelPath, "film_rating_enum.go") + require.Contains(t, mpaaRating, "type FilmRatingEnum string") +} + +func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.TableModel{ + Skip: true, + } + }). + UseEnum(func(enumMetaData metadata.Enum) template.EnumModel { + return template.EnumModel{ + Skip: true, + } + }), + ) + }), + ) + require.Nil(t, err) + + file2.NotExists(t, defaultModelPath, "actor.go") + file2.Exists(t, defaultModelPath, "actor_info.go") + file2.NotExists(t, defaultModelPath, "film_rating.go") +} + +func TestGeneratorTemplate_SQLBuilder_SkipTableAndEnum(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + return template.TableSQLBuilder{ + Skip: true, + } + }). + UseView(func(table metadata.Table) template.TableSQLBuilder { + return template.TableSQLBuilder{ + Skip: true, + } + }). + UseEnum(func(enumMetaData metadata.Enum) template.EnumSQLBuilder { + return template.EnumSQLBuilder{ + Skip: true, + } + }), + ) + }), + ) + require.Nil(t, err) + + file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go") + file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go") + file2.NotExists(t, defaultEnumSQLBuilderFilePath, "film_rating.go") +} + +func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + return template.DefaultTableSQLBuilder(table). + UseFileName(schemaMetaData.Name + "_" + table.Name + "_table"). + UseTypeName(utils.ToGoIdentifier(table.Name) + "TableSQLBuilder"). + UseInstanceName("T_" + utils.ToGoIdentifier(table.Name)) + }). + UseView(func(table metadata.Table) template.ViewSQLBuilder { + return template.DefaultViewSQLBuilder(table). + UseFileName(schemaMetaData.Name + "_" + table.Name + "_view"). + UseTypeName(utils.ToGoIdentifier(table.Name) + "ViewSQLBuilder"). + UseInstanceName("V_" + utils.ToGoIdentifier(table.Name)) + }). + UseEnum(func(enum metadata.Enum) template.EnumSQLBuilder { + return template.DefaultEnumSQLBuilder(enum). + UseFileName(schemaMetaData.Name + "_" + enum.Name + "_enum"). + UseInstanceName(utils.ToGoIdentifier(enum.Name) + "EnumSQLBuilder") + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultTableSQLBuilderFilePath, "dvds_actor_table.go") + require.Contains(t, actor, "type ActorTableSQLBuilder struct {") + require.Contains(t, actor, "var T_Actor = newActorTableSQLBuilder(\"dvds\", \"actor\", \"\")") + actorInfo := file2.Exists(t, defaultViewSQLBuilderFilePath, "dvds_actor_info_view.go") + require.Contains(t, actorInfo, "type ActorInfoViewSQLBuilder struct {") + require.Contains(t, actorInfo, "var V_ActorInfo = newActorInfoViewSQLBuilder(\"dvds\", \"actor_info\", \"\")") + mpaaRating := file2.Exists(t, defaultEnumSQLBuilderFilePath, "dvds_film_rating_enum.go") + require.Contains(t, mpaaRating, "var FilmRatingEnumSQLBuilder = &struct {") +} + +func TestGeneratorTemplate_Model_AddTags(t *testing.T) { + + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.DefaultTableModel(table). + UseField(func(columnMetaData metadata.Column) template.TableModelField { + defaultTableModelField := template.DefaultTableModelField(columnMetaData) + return defaultTableModelField.UseTags( + fmt.Sprintf(`json:"%s"`, snaker.SnakeToCamel(columnMetaData.Name, false)), + fmt.Sprintf(`xml:"%s"`, columnMetaData.Name), + ) + }) + }). + UseView(func(table metadata.Table) template.ViewModel { + return template.DefaultViewModel(table). + UseField(func(columnMetaData metadata.Column) template.TableModelField { + defaultTableModelField := template.DefaultTableModelField(columnMetaData) + if table.Name == "actor_info" && columnMetaData.Name == "actor_id" { + return defaultTableModelField.UseTags(`sql:"primary_key"`) + } + return defaultTableModelField + }) + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultActorModelFilePath) + require.Contains(t, actor, "ActorID uint16 `sql:\"primary_key\" json:\"actorID\" xml:\"actor_id\"`") + require.Contains(t, actor, "FirstName string `json:\"firstName\" xml:\"first_name\"`") + + actorInfo := file2.Exists(t, defaultModelPath, "actor_info.go") + require.Contains(t, actorInfo, "ActorID uint16 `sql:\"primary_key\"`") +} + +func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.DefaultTableModel(table). + UseField(func(columnMetaData metadata.Column) template.TableModelField { + defaultTableModelField := template.DefaultTableModelField(columnMetaData) + + switch defaultTableModelField.Type.Name { + case "*string": + defaultTableModelField.Type = template.NewType(sql.NullString{}) + case "*int32": + defaultTableModelField.Type = template.NewType(sql.NullInt32{}) + case "*int64": + defaultTableModelField.Type = template.NewType(sql.NullInt64{}) + case "*bool": + defaultTableModelField.Type = template.NewType(sql.NullBool{}) + case "*float64": + defaultTableModelField.Type = template.NewType(sql.NullFloat64{}) + case "*time.Time": + defaultTableModelField.Type = template.NewType(sql.NullTime{}) + } + return defaultTableModelField + }) + }), + ) + }), + ) + + require.Nil(t, err) + + data := file2.Exists(t, defaultModelPath, "film.go") + require.Contains(t, data, "\"database/sql\"") + require.Contains(t, data, "Description sql.NullString") + require.Contains(t, data, "ReleaseYear *int16") + require.Contains(t, data, "SpecialFeatures sql.NullString") +} + +func TestGeneratorTemplate_SQLBuilder_ChangeColumnTypes(t *testing.T) { + err := mysql2.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + return template.DefaultTableSQLBuilder(table). + UseColumn(func(column metadata.Column) template.TableSQLBuilderColumn { + defaultColumn := template.DefaultTableSQLBuilderColumn(column) + + if defaultColumn.Name == "ActorID" { + defaultColumn.Type = "String" + } + + return defaultColumn + }) + }), + ) + }), + ) + + require.Nil(t, err) + + actor := file2.Exists(t, defaultActorSQLBuilderFilePath) + require.Contains(t, actor, "ActorID postgres.ColumnString") +} diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index 3b6aa75..281e17b 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -66,9 +66,7 @@ func TestUpdateWithSubQueries(t *testing.T) { expectedSQL := ` UPDATE test_sample.link -SET name = ( - SELECT ? - ), +SET name = ?, url = ( SELECT link2.url AS "link2.url" FROM test_sample.link2 @@ -80,7 +78,7 @@ WHERE link.name = ?; query := Link. UPDATE(Link.Name, Link.URL). SET( - SELECT(String("Bong")), + String("Bong"), SELECT(Link2.URL). FROM(Link2). WHERE(Link2.Name.EQ(String("Youtube"))), @@ -96,7 +94,7 @@ WHERE link.name = ?; query := Link. UPDATE(). SET( - Link.Name.SET(StringExp(SELECT(String("Bong")))), + Link.Name.SET(String("Bong")), Link.URL.SET(StringExp( SELECT(Link2.URL). FROM(Link2). diff --git a/tests/postgres/generator_template_test.go b/tests/postgres/generator_template_test.go new file mode 100644 index 0000000..85dd01e --- /dev/null +++ b/tests/postgres/generator_template_test.go @@ -0,0 +1,387 @@ +package postgres + +import ( + "database/sql" + "fmt" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/generator/postgres" + "github.com/go-jet/jet/v2/generator/template" + "github.com/go-jet/jet/v2/internal/3rdparty/snaker" + "github.com/go-jet/jet/v2/internal/utils" + postgres2 "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/tests/dbconfig" + file2 "github.com/go-jet/jet/v2/tests/internal/utils/file" + "github.com/stretchr/testify/require" + "path" + "testing" +) + +const tempTestDir = "./.tempTestDir" + +var defaultModelPath = path.Join(tempTestDir, "jetdb/dvds/model") +var defaultActorModelFilePath = path.Join(tempTestDir, "jetdb/dvds/model", "actor.go") +var defaultTableSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/table") +var defaultViewSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/view") +var defaultEnumSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/enum") +var defaultActorSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/table", "actor.go") + +var dbConnection = postgres.DBConnection{ + Host: dbconfig.PgHost, + Port: 5432, + User: dbconfig.PgUser, + Password: dbconfig.PgPassword, + DBName: dbconfig.PgDBName, + SchemaName: "dvds", + SslMode: "disable", +} + +func TestGeneratorTemplate_Schema_ChangePath(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData).UsePath("new/schema/path") + }), + ) + + require.Nil(t, err) + + file2.Exists(t, tempTestDir, "jetdb/new/schema/path/model/actor.go") + file2.Exists(t, tempTestDir, "jetdb/new/schema/path/table/actor.go") + file2.Exists(t, tempTestDir, "jetdb/new/schema/path/view/actor_info.go") + file2.Exists(t, tempTestDir, "jetdb/new/schema/path/enum/mpaa_rating.go") +} + +func TestGeneratorTemplate_Model_SkipGeneration(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.Model{ + Skip: true, + }) + }), + ) + + require.Nil(t, err) + + file2.NotExists(t, defaultActorModelFilePath) +} + +func TestGeneratorTemplate_SQLBuilder_SkipGeneration(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.SQLBuilder{ + Skip: true, + }) + }), + ) + + require.Nil(t, err) + + file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go") + file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go") + file2.NotExists(t, defaultEnumSQLBuilderFilePath, "mpaa_rating.go") +} + +func TestGeneratorTemplate_Model_ChangePath(t *testing.T) { + const newModelPath = "/new/model/path" + + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel().UsePath(newModelPath)) + }), + ) + require.Nil(t, err) + + file2.Exists(t, tempTestDir, "jetdb", "dvds", newModelPath, "actor.go") + file2.NotExists(t, defaultActorModelFilePath) +} + +func TestGeneratorTemplate_SQLBuilder_ChangePath(t *testing.T) { + const newModelPath = "/new/sql-builder/path" + + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder().UsePath(newModelPath)) + }), + ) + require.Nil(t, err) + + file2.Exists(t, tempTestDir, "jetdb", "dvds", newModelPath, "table", "actor.go") + file2.Exists(t, tempTestDir, "jetdb", "dvds", newModelPath, "view", "actor_info.go") + file2.Exists(t, tempTestDir, "jetdb", "dvds", newModelPath, "enum", "mpaa_rating.go") + + file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go") + file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go") + file2.NotExists(t, defaultEnumSQLBuilderFilePath, "mpaa_rating.go") +} + +func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.DefaultTableModel(table). + UseFileName(schemaMetaData.Name + "_" + table.Name). + UseTypeName(utils.ToGoIdentifier(table.Name) + "Table") + }). + UseView(func(table metadata.Table) template.ViewModel { + return template.DefaultViewModel(table). + UseFileName(schemaMetaData.Name + "_" + table.Name + "_view"). + UseTypeName(utils.ToGoIdentifier(table.Name) + "View") + }). + UseEnum(func(enumMetaData metadata.Enum) template.EnumModel { + return template.DefaultEnumModel(enumMetaData). + UseFileName(enumMetaData.Name + "_enum"). + UseTypeName(utils.ToGoIdentifier(enumMetaData.Name) + "Enum") + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultModelPath, "dvds_actor.go") + require.Contains(t, actor, "type ActorTable struct {") + + actorInfo := file2.Exists(t, defaultModelPath, "dvds_actor_info_view.go") + require.Contains(t, actorInfo, "type ActorInfoView struct {") + + mpaaRating := file2.Exists(t, defaultModelPath, "mpaa_rating_enum.go") + require.Contains(t, mpaaRating, "type MpaaRatingEnum string") +} + +func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.TableModel{ + Skip: true, + } + }). + UseEnum(func(enumMetaData metadata.Enum) template.EnumModel { + return template.EnumModel{ + Skip: true, + } + }), + ) + }), + ) + require.Nil(t, err) + + file2.NotExists(t, defaultModelPath, "actor.go") + file2.Exists(t, defaultModelPath, "actor_info.go") + file2.NotExists(t, defaultModelPath, "mpaa_rating.go") +} + +func TestGeneratorTemplate_SQLBuilder_SkipTableAndEnum(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + return template.TableSQLBuilder{ + Skip: true, + } + }). + UseView(func(table metadata.Table) template.TableSQLBuilder { + return template.TableSQLBuilder{ + Skip: true, + } + }). + UseEnum(func(enumMetaData metadata.Enum) template.EnumSQLBuilder { + return template.EnumSQLBuilder{ + Skip: true, + } + }), + ) + }), + ) + require.Nil(t, err) + + file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go") + file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go") + file2.NotExists(t, defaultEnumSQLBuilderFilePath, "mpaa_rating.go") +} + +func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + return template.DefaultTableSQLBuilder(table). + UseFileName(schemaMetaData.Name + "_" + table.Name + "_table"). + UseTypeName(utils.ToGoIdentifier(table.Name) + "TableSQLBuilder"). + UseInstanceName("T_" + utils.ToGoIdentifier(table.Name)) + }). + UseView(func(table metadata.Table) template.ViewSQLBuilder { + return template.DefaultViewSQLBuilder(table). + UseFileName(schemaMetaData.Name + "_" + table.Name + "_view"). + UseTypeName(utils.ToGoIdentifier(table.Name) + "ViewSQLBuilder"). + UseInstanceName("V_" + utils.ToGoIdentifier(table.Name)) + }). + UseEnum(func(enum metadata.Enum) template.EnumSQLBuilder { + return template.DefaultEnumSQLBuilder(enum). + UseFileName(schemaMetaData.Name + "_" + enum.Name + "_enum"). + UseInstanceName(utils.ToGoIdentifier(enum.Name) + "EnumSQLBuilder") + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultTableSQLBuilderFilePath, "dvds_actor_table.go") + require.Contains(t, actor, "type ActorTableSQLBuilder struct {") + require.Contains(t, actor, "var T_Actor = newActorTableSQLBuilder(\"dvds\", \"actor\", \"\")") + actorInfo := file2.Exists(t, defaultViewSQLBuilderFilePath, "dvds_actor_info_view.go") + require.Contains(t, actorInfo, "type ActorInfoViewSQLBuilder struct {") + require.Contains(t, actorInfo, "var V_ActorInfo = newActorInfoViewSQLBuilder(\"dvds\", \"actor_info\", \"\")") + mpaaRating := file2.Exists(t, defaultEnumSQLBuilderFilePath, "dvds_mpaa_rating_enum.go") + require.Contains(t, mpaaRating, "var MpaaRatingEnumSQLBuilder = &struct {") +} + +func TestGeneratorTemplate_Model_AddTags(t *testing.T) { + + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.DefaultTableModel(table). + UseField(func(columnMetaData metadata.Column) template.TableModelField { + defaultTableModelField := template.DefaultTableModelField(columnMetaData) + return defaultTableModelField.UseTags( + fmt.Sprintf(`json:"%s"`, snaker.SnakeToCamel(columnMetaData.Name, false)), + fmt.Sprintf(`xml:"%s"`, columnMetaData.Name), + ) + }) + }). + UseView(func(table metadata.Table) template.ViewModel { + return template.DefaultViewModel(table). + UseField(func(columnMetaData metadata.Column) template.TableModelField { + defaultTableModelField := template.DefaultTableModelField(columnMetaData) + if table.Name == "actor_info" && columnMetaData.Name == "actor_id" { + return defaultTableModelField.UseTags(`sql:"primary_key"`) + } + return defaultTableModelField + }) + }), + ) + }), + ) + require.Nil(t, err) + + actor := file2.Exists(t, defaultActorModelFilePath) + require.Contains(t, actor, "ActorID int32 `sql:\"primary_key\" json:\"actorID\" xml:\"actor_id\"`") + require.Contains(t, actor, "FirstName string `json:\"firstName\" xml:\"first_name\"`") + + actorInfo := file2.Exists(t, defaultModelPath, "actor_info.go") + require.Contains(t, actorInfo, "ActorID *int32 `sql:\"primary_key\"`") +} + +func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseModel(template.DefaultModel(). + UseTable(func(table metadata.Table) template.TableModel { + return template.DefaultTableModel(table). + UseField(func(columnMetaData metadata.Column) template.TableModelField { + defaultTableModelField := template.DefaultTableModelField(columnMetaData) + + switch defaultTableModelField.Type.Name { + case "*string": + defaultTableModelField.Type = template.NewType(sql.NullString{}) + case "*int32": + defaultTableModelField.Type = template.NewType(sql.NullInt32{}) + case "*int64": + defaultTableModelField.Type = template.NewType(sql.NullInt64{}) + case "*bool": + defaultTableModelField.Type = template.NewType(sql.NullBool{}) + case "*float64": + defaultTableModelField.Type = template.NewType(sql.NullFloat64{}) + case "*time.Time": + defaultTableModelField.Type = template.NewType(sql.NullTime{}) + } + return defaultTableModelField + }) + }), + ) + }), + ) + + require.Nil(t, err) + + data := file2.Exists(t, defaultModelPath, "film.go") + require.Contains(t, data, "\"database/sql\"") + require.Contains(t, data, "Description sql.NullString") + require.Contains(t, data, "ReleaseYear sql.NullInt32") + require.Contains(t, data, "SpecialFeatures sql.NullString") +} + +func TestGeneratorTemplate_SQLBuilder_ChangeColumnTypes(t *testing.T) { + err := postgres.Generate( + tempTestDir, + dbConnection, + template.Default(postgres2.Dialect). + UseSchema(func(schemaMetaData metadata.Schema) template.Schema { + return template.DefaultSchema(schemaMetaData). + UseSQLBuilder(template.DefaultSQLBuilder(). + UseTable(func(table metadata.Table) template.TableSQLBuilder { + return template.DefaultTableSQLBuilder(table). + UseColumn(func(column metadata.Column) template.TableSQLBuilderColumn { + defaultColumn := template.DefaultTableSQLBuilderColumn(column) + + if defaultColumn.Name == "ActorID" { + defaultColumn.Type = "String" + } + + return defaultColumn + }) + }), + ) + }), + ) + + require.Nil(t, err) + + actor := file2.Exists(t, defaultActorSQLBuilderFilePath) + require.Contains(t, actor, "ActorID postgres.ColumnString") +} diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 0571157..d1f8a52 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -67,14 +67,14 @@ func TestGenerator(t *testing.T) { for i := 0; i < 3; i++ { err := postgres.Generate(genTestDir2, postgres.DBConnection{ - Host: dbconfig.Host, - Port: dbconfig.Port, - User: dbconfig.User, - Password: dbconfig.Password, + Host: dbconfig.PgHost, + Port: dbconfig.PgPort, + User: dbconfig.PgUser, + Password: dbconfig.PgPassword, SslMode: "disable", Params: "", - DBName: dbconfig.DBName, + DBName: dbconfig.PgDBName, SchemaName: "dvds", }) diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index ef9337c..e06c985 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -32,6 +32,8 @@ func TestMain(m *testing.M) { setTestRoot() for _, driverName := range []string{"postgres", "pgx"} { + fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName) + func() { var err error db, err = sql.Open(driverName, dbconfig.PostgresConnectString)