Add the ability to fully customize jet generated files.

This commit is contained in:
go-jet 2021-07-27 17:39:21 +02:00
parent caa81930dc
commit 8864667f47
40 changed files with 2274 additions and 882 deletions

View file

@ -1,168 +0,0 @@
package metadata
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/v2/internal/utils"
"strings"
)
// ColumnMetaData struct
type ColumnMetaData struct {
Name string
IsNullable bool
DataType string
EnumName string
IsUnsigned bool
SqlBuilderColumnType string
GoBaseType string
GoModelType string
}
// NewColumnMetaData create new column meta data that describes one column in SQL database
func NewColumnMetaData(name string, isNullable bool, dataType string, enumName string, isUnsigned bool) ColumnMetaData {
columnMetaData := ColumnMetaData{
Name: name,
IsNullable: isNullable,
DataType: dataType,
EnumName: enumName,
IsUnsigned: isUnsigned,
}
columnMetaData.SqlBuilderColumnType = columnMetaData.getSqlBuilderColumnType()
columnMetaData.GoBaseType = columnMetaData.getGoBaseType()
columnMetaData.GoModelType = columnMetaData.getGoModelType()
return columnMetaData
}
// getSqlBuilderColumnType returns type of jet sql builder column
func (c ColumnMetaData) getSqlBuilderColumnType() string {
switch c.DataType {
case "boolean":
return "Bool"
case "smallint", "integer", "bigint",
"tinyint", "mediumint", "int", "year": //MySQL
return "Integer"
case "date":
return "Date"
case "timestamp without time zone",
"timestamp", "datetime": //MySQL:
return "Timestamp"
case "timestamp with time zone":
return "Timestampz"
case "time without time zone",
"time": //MySQL
return "Time"
case "time with time zone":
return "Timez"
case "interval":
return "Interval"
case "USER-DEFINED", "enum", "text", "character", "character varying", "bytea", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY",
"char", "varchar", "binary", "varbinary",
"tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL
return "String"
case "real", "numeric", "decimal", "double precision", "float",
"double": // MySQL
return "Float"
default:
fmt.Println("- [SQL Builder] Unsupported sql column '" + c.Name + " " + c.DataType + "', using StringColumn instead.")
return "String"
}
}
// getGoBaseType returns model type for column info.
func (c ColumnMetaData) getGoBaseType() string {
switch c.DataType {
case "USER-DEFINED", "enum":
return utils.ToGoIdentifier(c.EnumName)
case "boolean":
return "bool"
case "tinyint":
return "int8"
case "smallint",
"year":
return "int16"
case "integer",
"mediumint", "int": //MySQL
return "int32"
case "bigint":
return "int64"
case "date", "timestamp without time zone", "timestamp with time zone", "time with time zone", "time without time zone",
"timestamp", "datetime", "time": // MySQL
return "time.Time"
case "bytea",
"binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": //MySQL
return "[]byte"
case "text", "character", "character varying", "tsvector", "bit", "bit varying", "money", "json", "jsonb",
"xml", "point", "interval", "line", "ARRAY",
"char", "varchar", "tinytext", "mediumtext", "longtext": // MySQL
return "string"
case "real":
return "float32"
case "numeric", "decimal", "double precision", "float",
"double": // MySQL
return "float64"
case "uuid":
return "uuid.UUID"
default:
fmt.Println("- [Model ] Unsupported sql column '" + c.Name + " " + c.DataType + "', using string instead.")
return "string"
}
}
// GoModelType returns model type for column info with optional pointer if
// column can be NULL.
func (c ColumnMetaData) getGoModelType() string {
typeStr := c.GoBaseType
if strings.Contains(typeStr, "int") && c.IsUnsigned {
typeStr = "u" + typeStr
}
if c.IsNullable {
return "*" + typeStr
}
return typeStr
}
// GoModelTag returns model field tag for column
func (c ColumnMetaData) GoModelTag(isPrimaryKey bool) string {
tags := []string{}
if isPrimaryKey {
tags = append(tags, "primary_key")
}
if len(tags) > 0 {
return "`sql:\"" + strings.Join(tags, ",") + "\"`"
}
return ""
}
func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) []ColumnMetaData {
rows, err := db.Query(querySet.ListOfColumnsQuery(), schemaName, tableName)
utils.PanicOnError(err)
defer rows.Close()
ret := []ColumnMetaData{}
for rows.Next() {
var name, isNullable, dataType, enumName string
var isUnsigned bool
err := rows.Scan(&name, &isNullable, &dataType, &enumName, &isUnsigned)
utils.PanicOnError(err)
ret = append(ret, NewColumnMetaData(name, isNullable == "YES", dataType, enumName, isUnsigned))
}
err = rows.Err()
utils.PanicOnError(err)
return ret
}

View file

@ -1,15 +0,0 @@
package metadata
import (
"database/sql"
)
// DialectQuerySet is set of methods necessary to retrieve dialect meta data information
type DialectQuerySet interface {
ListOfTablesQuery() string
PrimaryKeysQuery() string
ListOfColumnsQuery() string
ListOfEnumsQuery() string
GetEnumsMetaData(db *sql.DB, schemaName string) []MetaData
}

View file

@ -1,12 +0,0 @@
package metadata
// EnumMetaData struct
type EnumMetaData struct {
EnumName string
Values []string
}
// Name returns enum name
func (e EnumMetaData) Name() string {
return e.EnumName
}

View file

@ -1,6 +0,0 @@
package metadata
// MetaData interface
type MetaData interface {
Name() string
}

View file

@ -1,61 +0,0 @@
package metadata
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/v2/internal/utils"
)
// SchemaMetaData struct
type SchemaMetaData struct {
TablesMetaData []MetaData
ViewsMetaData []MetaData
EnumsMetaData []MetaData
}
// IsEmpty returns true if schema info does not contain any table, views or enums metadata
func (s SchemaMetaData) IsEmpty() bool {
return len(s.TablesMetaData) == 0 && len(s.ViewsMetaData) == 0 && len(s.EnumsMetaData) == 0
}
const (
baseTable = "BASE TABLE"
view = "VIEW"
)
// GetSchemaMetaData returns schema information from db connection.
func GetSchemaMetaData(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData) {
schemaInfo.TablesMetaData = getTablesMetaData(db, querySet, schemaName, baseTable)
schemaInfo.ViewsMetaData = getTablesMetaData(db, querySet, schemaName, view)
schemaInfo.EnumsMetaData = querySet.GetEnumsMetaData(db, schemaName)
fmt.Println(" FOUND", len(schemaInfo.TablesMetaData), "table(s),", len(schemaInfo.ViewsMetaData), "view(s),",
len(schemaInfo.EnumsMetaData), "enum(s)")
return
}
func getTablesMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableType string) []MetaData {
rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName, tableType)
utils.PanicOnError(err)
defer rows.Close()
ret := []MetaData{}
for rows.Next() {
var tableName string
err = rows.Scan(&tableName)
utils.PanicOnError(err)
tableInfo := GetTableMetaData(db, querySet, schemaName, tableName)
ret = append(ret, tableInfo)
}
err = rows.Err()
utils.PanicOnError(err)
return ret
}

View file

@ -1,103 +0,0 @@
package metadata
import (
"database/sql"
"github.com/go-jet/jet/v2/internal/utils"
"strings"
)
// TableMetaData metadata struct
type TableMetaData struct {
SchemaName string
name string
PrimaryKeys map[string]bool
Columns []ColumnMetaData
}
// Name returns table info name
func (t TableMetaData) Name() string {
return t.name
}
// IsPrimaryKey returns if column is a part of primary key
func (t TableMetaData) IsPrimaryKey(column string) bool {
return t.PrimaryKeys[column]
}
// MutableColumns returns list of mutable columns for table
func (t TableMetaData) MutableColumns() []ColumnMetaData {
ret := []ColumnMetaData{}
for _, column := range t.Columns {
if t.IsPrimaryKey(column.Name) {
continue
}
ret = append(ret, column)
}
return ret
}
// GetImports returns model imports for table.
func (t TableMetaData) GetImports() []string {
imports := map[string]string{}
for _, column := range t.Columns {
columnType := column.GoBaseType
switch columnType {
case "time.Time":
imports["time.Time"] = "time"
case "uuid.UUID":
imports["uuid.UUID"] = "github.com/google/uuid"
}
}
ret := []string{}
for _, packageImport := range imports {
ret = append(ret, packageImport)
}
return ret
}
// GoStructName returns go struct name for sql builder
func (t TableMetaData) GoStructName() string {
return utils.ToGoIdentifier(t.name) + "Table"
}
// GoStructImplName returns go struct impl name for sql builder
func (t TableMetaData) GoStructImplName() string {
name := utils.ToGoIdentifier(t.name) + "Table"
return string(strings.ToLower(name)[0]) + name[1:]
}
// GetTableMetaData returns table info metadata
func GetTableMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData) {
tableInfo.SchemaName = schemaName
tableInfo.name = tableName
tableInfo.PrimaryKeys = getPrimaryKeys(db, querySet, schemaName, tableName)
tableInfo.Columns = getColumnsMetaData(db, querySet, schemaName, tableName)
return
}
func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) map[string]bool {
rows, err := db.Query(querySet.PrimaryKeysQuery(), schemaName, tableName)
utils.PanicOnError(err)
primaryKeyMap := map[string]bool{}
for rows.Next() {
primaryKey := ""
err := rows.Scan(&primaryKey)
utils.PanicOnError(err)
primaryKeyMap[primaryKey] = true
}
return primaryKeyMap
}

View file

@ -1,107 +0,0 @@
package template
import (
"bytes"
"fmt"
"github.com/go-jet/jet/v2/generator/internal/metadata"
"github.com/go-jet/jet/v2/internal/jet"
"github.com/go-jet/jet/v2/internal/utils"
"path/filepath"
"text/template"
)
// GenerateFiles generates Go files from tables and enums metadata
func GenerateFiles(destDir string, schemaInfo metadata.SchemaMetaData, dialect jet.Dialect) {
if schemaInfo.IsEmpty() {
return
}
fmt.Println("Destination directory:", destDir)
fmt.Println("Cleaning up destination directory...")
err := utils.CleanUpGeneratedFiles(destDir)
utils.PanicOnError(err)
tableSQLBuilderTemplate := getTableSQLBuilderTemplate(dialect)
generateSQLBuilderFiles(destDir, "table", tableSQLBuilderTemplate, schemaInfo.TablesMetaData, dialect)
generateSQLBuilderFiles(destDir, "view", tableSQLBuilderTemplate, schemaInfo.ViewsMetaData, dialect)
generateSQLBuilderFiles(destDir, "enum", enumSQLBuilderTemplate, schemaInfo.EnumsMetaData, dialect)
generateModelFiles(destDir, "table", tableModelTemplate, schemaInfo.TablesMetaData, dialect)
generateModelFiles(destDir, "view", tableModelTemplate, schemaInfo.ViewsMetaData, dialect)
generateModelFiles(destDir, "enum", enumModelTemplate, schemaInfo.EnumsMetaData, dialect)
fmt.Println("Done")
}
func getTableSQLBuilderTemplate(dialect jet.Dialect) string {
if dialect.Name() == "PostgreSQL" {
return tablePostgreSQLBuilderTemplate
}
return tableSQLBuilderTemplate
}
func generateSQLBuilderFiles(destDir, fileTypes, sqlBuilderTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) {
if len(metaData) == 0 {
return
}
fmt.Printf("Generating %s sql builder files...\n", fileTypes)
generateGoFiles(destDir, fileTypes, sqlBuilderTemplate, metaData, dialect)
}
func generateModelFiles(destDir, fileTypes, modelTemplate string, metaData []metadata.MetaData, dialect jet.Dialect) {
if len(metaData) == 0 {
return
}
fmt.Printf("Generating %s model files...\n", fileTypes)
generateGoFiles(destDir, "model", modelTemplate, metaData, dialect)
}
func generateGoFiles(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) {
modelDirPath := filepath.Join(dirPath, packageName)
err := utils.EnsureDirPath(modelDirPath)
utils.PanicOnError(err)
autoGenWarning, err := GenerateTemplate(autoGenWarningTemplate, nil, dialect)
utils.PanicOnError(err)
for _, metaData := range metaDataList {
text, err := GenerateTemplate(template, metaData, dialect, map[string]interface{}{"package": packageName})
utils.PanicOnError(err)
err = utils.SaveGoFile(modelDirPath, utils.ToGoFileName(metaData.Name()), append(autoGenWarning, text...))
utils.PanicOnError(err)
}
return
}
// GenerateTemplate generates template with template text and template data.
func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect, params ...map[string]interface{}) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
"ToGoIdentifier": utils.ToGoIdentifier,
"ToGoEnumValueIdentifier": utils.ToGoEnumValueIdentifier,
"dialect": func() jet.Dialect {
return dialect
},
"param": func(name string) interface{} {
if len(params) > 0 {
return params[0][name]
}
return ""
},
}).Parse(templateText)
if err != nil {
return nil, err
}
var buf bytes.Buffer
if err := t.Execute(&buf, templateData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View file

@ -1,213 +0,0 @@
package template
var autoGenWarningTemplate = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
`
var tableSQLBuilderTemplate = `
{{define "column-list" -}}
{{- range $i, $c := . }}
{{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column
{{- end}}
{{- end}}
package {{param "package"}}
import (
"github.com/go-jet/jet/v2/{{dialect.PackageName}}"
)
var {{ToGoIdentifier .Name}} = new{{.GoStructName}}("{{.SchemaName}}", "{{.Name}}", "")
type {{.GoStructName}} struct {
{{dialect.PackageName}}.Table
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}}
{{- end}}
AllColumns {{dialect.PackageName}}.ColumnList
MutableColumns {{dialect.PackageName}}.ColumnList
}
// AS creates new {{.GoStructName}} with assigned alias
func (a {{.GoStructName}}) AS(alias string) {{.GoStructName}} {
return new{{.GoStructName}}(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new {{.GoStructName}} with assigned schema name
func (a {{.GoStructName}}) FromSchema(schemaName string) {{.GoStructName}} {
return new{{.GoStructName}}(schemaName, a.TableName(), a.Alias())
}
func new{{.GoStructName}}(schemaName, tableName, alias string) {{.GoStructName}} {
var (
{{- range .Columns}}
{{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
{{- end}}
allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }
mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }
)
return {{.GoStructName}}{
Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column,
{{- end}}
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`
var tablePostgreSQLBuilderTemplate = `
{{define "column-list" -}}
{{- range $i, $c := . }}
{{- if gt $i 0 }}, {{end}}{{ToGoIdentifier $c.Name}}Column
{{- end}}
{{- end}}
package {{param "package"}}
import (
"github.com/go-jet/jet/v2/{{dialect.PackageName}}"
)
var {{ToGoIdentifier .Name}} = new{{.GoStructName}}("{{.SchemaName}}", "{{.Name}}", "")
type {{.GoStructImplName}} struct {
{{dialect.PackageName}}.Table
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}}
{{- end}}
AllColumns {{dialect.PackageName}}.ColumnList
MutableColumns {{dialect.PackageName}}.ColumnList
}
type {{.GoStructName}} struct {
{{.GoStructImplName}}
EXCLUDED {{.GoStructImplName}}
}
// AS creates new {{.GoStructName}} with assigned alias
func (a {{.GoStructName}}) AS(alias string) *{{.GoStructName}} {
return new{{.GoStructName}}(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new {{.GoStructName}} with assigned schema name
func (a {{.GoStructName}}) FromSchema(schemaName string) *{{.GoStructName}} {
return new{{.GoStructName}}(schemaName, a.TableName(), a.Alias())
}
func new{{.GoStructName}}(schemaName, tableName, alias string) *{{.GoStructName}} {
return &{{.GoStructName}}{
{{.GoStructImplName}}: new{{.GoStructName}}Impl(schemaName, tableName, alias),
EXCLUDED: new{{.GoStructName}}Impl("", "excluded", ""),
}
}
func new{{.GoStructName}}Impl(schemaName, tableName, alias string) {{.GoStructImplName}} {
var (
{{- range .Columns}}
{{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
{{- end}}
allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }
mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }
)
return {{.GoStructImplName}}{
Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
{{- range .Columns}}
{{ToGoIdentifier .Name}}: {{ToGoIdentifier .Name}}Column,
{{- end}}
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`
var tableModelTemplate = `package model
{{ if .GetImports }}
import (
{{- range .GetImports}}
"{{.}}"
{{- end}}
)
{{end}}
type {{ToGoIdentifier .Name}} struct {
{{- range .Columns}}
{{ToGoIdentifier .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsPrimaryKey .Name)}}" + `
{{- end}}
}
`
var enumSQLBuilderTemplate = `package enum
import "github.com/go-jet/jet/v2/{{dialect.PackageName}}"
var {{ToGoIdentifier $.Name}} = &struct {
{{- range $index, $element := .Values}}
{{ToGoEnumValueIdentifier $.Name $element}} {{dialect.PackageName}}.StringExpression
{{- end}}
} {
{{- range $index, $element := .Values}}
{{ToGoEnumValueIdentifier $.Name $element}}: {{dialect.PackageName}}.NewEnumValue("{{$element}}"),
{{- end}}
}
`
var enumModelTemplate = `package model
import "errors"
type {{ToGoIdentifier $.Name}} string
const (
{{- range $index, $element := .Values}}
{{ToGoIdentifier $.Name}}_{{ToGoIdentifier $element}} {{ToGoIdentifier $.Name}} = "{{$element}}"
{{- end}}
)
func (e *{{ToGoIdentifier $.Name}}) Scan(value interface{}) error {
if v, ok := value.(string); !ok {
return errors.New("jet: Invalid data for {{ToGoIdentifier $.Name}} enum")
} else {
switch string(v) {
{{- range $index, $element := .Values}}
case "{{$element}}":
*e = {{ToGoIdentifier $.Name}}_{{ToGoIdentifier $element}}
{{- end}}
default:
return errors.New("jet: Inavlid data " + string(v) + "for {{ToGoIdentifier $.Name}} enum")
}
return nil
}
}
func (e {{ToGoIdentifier $.Name}}) String() string {
return string(e)
}
`

View file

@ -0,0 +1,27 @@
package metadata
// Column struct
type Column struct {
Name string
IsPrimaryKey bool
IsNullable bool
DataType DataType
}
// DataTypeKind is database type kind(base, enum, user-defined, array)
type DataTypeKind string
// DataTypeKind possible values
const (
BaseType DataTypeKind = "base"
EnumType DataTypeKind = "enum"
UserDefinedType DataTypeKind = "user-defined"
ArrayType DataTypeKind = "array"
)
// DataType contains information about column data type
type DataType struct {
Name string
Kind DataTypeKind
IsUnsigned bool
}

View file

@ -0,0 +1,35 @@
package metadata
import (
"database/sql"
"fmt"
)
// TableType is type of database table(view or base)
type TableType string
const (
baseTable TableType = "BASE TABLE"
viewTable TableType = "VIEW"
)
// DialectQuerySet is set of methods necessary to retrieve dialect meta data information
type DialectQuerySet interface {
GetTablesMetaData(db *sql.DB, schemaName string, tableType TableType) []Table
GetEnumsMetaData(db *sql.DB, schemaName string) []Enum
}
// GetSchema retrieves Schema information from database
func GetSchema(db *sql.DB, querySet DialectQuerySet, schemaName string) Schema {
ret := Schema{
Name: schemaName,
TablesMetaData: querySet.GetTablesMetaData(db, schemaName, baseTable),
ViewsMetaData: querySet.GetTablesMetaData(db, schemaName, viewTable),
EnumsMetaData: querySet.GetEnumsMetaData(db, schemaName),
}
fmt.Println(" FOUND", len(ret.TablesMetaData), "table(s),", len(ret.ViewsMetaData), "view(s),",
len(ret.EnumsMetaData), "enum(s)")
return ret
}

View file

@ -0,0 +1,7 @@
package metadata
// Enum metadata struct
type Enum struct {
Name string `sql:"primary_key"`
Values []string
}

View file

@ -0,0 +1,14 @@
package metadata
// Schema struct
type Schema struct {
Name string
TablesMetaData []Table
ViewsMetaData []Table
EnumsMetaData []Enum
}
// IsEmpty returns true if schema info does not contain any table, views or enums metadata
func (s Schema) IsEmpty() bool {
return len(s.TablesMetaData) == 0 && len(s.ViewsMetaData) == 0 && len(s.EnumsMetaData) == 0
}

View file

@ -0,0 +1,22 @@
package metadata
// Table metadata struct
type Table struct {
Name string
Columns []Column
}
// MutableColumns returns list of mutable columns for table
func (t Table) MutableColumns() []Column {
var ret []Column
for _, column := range t.Columns {
if column.IsPrimaryKey {
continue
}
ret = append(ret, column)
}
return ret
}

View file

@ -3,11 +3,11 @@ package mysql
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/go-jet/jet/v2/generator/internal/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/internal/template" "github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/mysql"
"path"
) )
// DBConnection contains MySQL connection details // DBConnection contains MySQL connection details
@ -22,7 +22,7 @@ type DBConnection struct {
} }
// Generate generates jet files at destination dir from database connection details // Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) (err error) { func Generate(destDir string, dbConn DBConnection, generatorTemplate ...template.Template) (err error) {
defer utils.ErrorCatch(&err) defer utils.ErrorCatch(&err)
db := openConnection(dbConn) db := openConnection(dbConn)
@ -30,11 +30,14 @@ func Generate(destDir string, dbConn DBConnection) (err error) {
fmt.Println("Retrieving database information...") fmt.Println("Retrieving database information...")
// No schemas in MySQL // No schemas in MySQL
dbInfo := metadata.GetSchemaMetaData(db, dbConn.DBName, &mySqlQuerySet{}) schemaMetaData := metadata.GetSchema(db, &mySqlQuerySet{}, dbConn.DBName)
genPath := path.Join(destDir, dbConn.DBName) genTemplate := template.Default(mysql.Dialect)
if len(generatorTemplate) > 0 {
genTemplate = generatorTemplate[0]
}
template.GenerateFiles(genPath, dbInfo, mysql.Dialect) template.ProcessSchema(destDir, schemaMetaData, genTemplate)
return nil return nil
} }
@ -46,10 +49,10 @@ func openConnection(dbConn DBConnection) *sql.DB {
} }
fmt.Println("Connecting to MySQL database: " + connectionString) fmt.Println("Connecting to MySQL database: " + connectionString)
db, err := sql.Open("mysql", connectionString) db, err := sql.Open("mysql", connectionString)
utils.PanicOnError(err) throw.OnError(err)
err = db.Ping() err = db.Ping()
utils.PanicOnError(err) throw.OnError(err)
return db return db
} }

View file

@ -1,81 +1,91 @@
package mysql package mysql
import ( import (
"context"
"database/sql" "database/sql"
"github.com/go-jet/jet/v2/generator/internal/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm"
"strings" "strings"
) )
// mySqlQuerySet is dialect query set for MySQL // mySqlQuerySet is dialect query set for MySQL
type mySqlQuerySet struct{} type mySqlQuerySet struct{}
func (m *mySqlQuerySet) ListOfTablesQuery() string { func (m mySqlQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table {
return ` query := `
SELECT table_name SELECT table_name as "table.name"
FROM INFORMATION_SCHEMA.tables FROM INFORMATION_SCHEMA.tables
WHERE table_schema = ? and table_type = ?; WHERE table_schema = ? and table_type = ?;
` `
var tables []metadata.Table
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables)
throw.OnError(err)
for i := range tables {
tables[i].Columns = m.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
}
return tables
} }
func (m *mySqlQuerySet) PrimaryKeysQuery() string { func (m mySqlQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column {
return ` query := `
SELECT k.column_name WITH primaryKeys AS (
FROM information_schema.table_constraints t SELECT k.column_name
JOIN information_schema.key_column_usage k FROM information_schema.table_constraints t
USING(constraint_name,table_schema,table_name) JOIN information_schema.key_column_usage k USING(constraint_name,table_schema,table_name)
WHERE t.constraint_type='PRIMARY KEY' WHERE table_schema = ? AND table_name = ? AND t.constraint_type='PRIMARY KEY'
AND t.table_schema= ? )
AND t.table_name= ?; SELECT COLUMN_NAME AS "column.Name",
` IS_NULLABLE = "YES" AS "column.IsNullable",
} (EXISTS(SELECT 1 FROM primaryKeys AS pk WHERE pk.column_name = columns.column_name)) AS "column.IsPrimaryKey",
IF (COLUMN_TYPE = 'tinyint(1)',
func (m *mySqlQuerySet) ListOfColumnsQuery() string { 'boolean',
return ` IF (DATA_TYPE='enum',
SELECT COLUMN_NAME, CONCAT(TABLE_NAME, '_', COLUMN_NAME),
IS_NULLABLE, IF(COLUMN_TYPE = 'tinyint(1)', 'boolean', DATA_TYPE), DATA_TYPE)
IF(DATA_TYPE = 'enum', CONCAT(TABLE_NAME, '_', COLUMN_NAME), ''), ) AS "dataType.Name",
COLUMN_TYPE LIKE '%unsigned%' IF (DATA_TYPE = 'enum', 'enum', 'base') AS "dataType.Kind",
COLUMN_TYPE LIKE '%unsigned%' AS "dataType.IsUnsigned"
FROM information_schema.columns FROM information_schema.columns
WHERE table_schema = ? and table_name = ? WHERE table_schema = ? AND table_name = ?
ORDER BY ordinal_position; ORDER BY ordinal_position;
` `
var columns []metadata.Column
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns)
throw.OnError(err)
return columns
} }
func (m *mySqlQuerySet) ListOfEnumsQuery() string { func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum {
return ` query := `
SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ), SUBSTRING(c.COLUMN_TYPE,5) SELECT (CASE c.DATA_TYPE WHEN 'enum' then CONCAT(c.TABLE_NAME, '_', c.COLUMN_NAME) ELSE '' END ) as "name",
SUBSTRING(c.COLUMN_TYPE,5) as "values"
FROM information_schema.columns as c FROM information_schema.columns as c
INNER JOIN information_schema.tables as t on (t.table_schema = c.table_schema AND t.table_name = c.table_name) INNER JOIN information_schema.tables as t on (t.table_schema = c.table_schema AND t.table_name = c.table_name)
WHERE c.table_schema = ? AND DATA_TYPE = 'enum'; WHERE c.table_schema = ? AND DATA_TYPE = 'enum';
` `
} var queryResult []struct {
Name string
Values string
}
func (m *mySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData { err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult)
throw.OnError(err)
rows, err := db.Query(m.ListOfEnumsQuery(), schemaName) var ret []metadata.Enum
utils.PanicOnError(err)
defer rows.Close()
ret := []metadata.MetaData{} for _, result := range queryResult {
enumValues := strings.Replace(result.Values[1:len(result.Values)-1], "'", "", -1)
for rows.Next() { ret = append(ret, metadata.Enum{
var enumName string Name: result.Name,
var enumValues string Values: strings.Split(enumValues, ","),
err = rows.Scan(&enumName, &enumValues)
utils.PanicOnError(err)
enumValues = strings.Replace(enumValues[1:len(enumValues)-1], "'", "", -1)
ret = append(ret, metadata.EnumMetaData{
EnumName: enumName,
Values: strings.Split(enumValues, ","),
}) })
} }
err = rows.Err()
utils.PanicOnError(err)
return ret return ret
} }

View file

@ -3,9 +3,10 @@ package postgres
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/go-jet/jet/v2/generator/internal/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/internal/template" "github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/postgres"
"path" "path"
"strconv" "strconv"
@ -25,38 +26,39 @@ type DBConnection struct {
} }
// Generate generates jet files at destination dir from database connection details // Generate generates jet files at destination dir from database connection details
func Generate(destDir string, dbConn DBConnection) (err error) { func Generate(destDir string, dbConn DBConnection, genTemplate ...template.Template) (err error) {
defer utils.ErrorCatch(&err) defer utils.ErrorCatch(&err)
db, err := openConnection(dbConn) db := openConnection(dbConn)
utils.PanicOnError(err)
defer utils.DBClose(db) defer utils.DBClose(db)
fmt.Println("Retrieving schema information...") fmt.Println("Retrieving schema information...")
schemaInfo := metadata.GetSchemaMetaData(db, dbConn.SchemaName, &postgresQuerySet{})
genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName) generatorTemplate := template.Default(postgres.Dialect)
template.GenerateFiles(genPath, schemaInfo, postgres.Dialect) if len(genTemplate) > 0 {
generatorTemplate = genTemplate[0]
}
schemaMetadata := metadata.GetSchema(db, &postgresQuerySet{}, dbConn.SchemaName)
dirPath := path.Join(destDir, dbConn.DBName)
template.ProcessSchema(dirPath, schemaMetadata, generatorTemplate)
return return
} }
func openConnection(dbConn DBConnection) (*sql.DB, error) { func openConnection(dbConn DBConnection) *sql.DB {
connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s", connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s",
dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params) dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params)
fmt.Println("Connecting to postgres database: " + connectionString) fmt.Println("Connecting to postgres database: " + connectionString)
db, err := sql.Open("postgres", connectionString) db, err := sql.Open("postgres", connectionString)
if err != nil { throw.OnError(err)
return nil, err
}
err = db.Ping() err = db.Ping()
throw.OnError(err)
if err != nil { return db
return nil, err
}
return db, nil
} }

View file

@ -1,81 +1,83 @@
package postgres package postgres
import ( import (
"context"
"database/sql" "database/sql"
"github.com/go-jet/jet/v2/generator/internal/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm"
) )
// postgresQuerySet is dialect query set for PostgreSQL // postgresQuerySet is dialect query set for PostgreSQL
type postgresQuerySet struct{} type postgresQuerySet struct{}
func (p *postgresQuerySet) ListOfTablesQuery() string { func (p postgresQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table {
return ` query := `
SELECT table_name SELECT table_name as "table.name"
FROM information_schema.tables FROM information_schema.tables
where table_schema = $1 and table_type = $2; WHERE table_schema = $1 and table_type = $2;
` `
var tables []metadata.Table
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables)
throw.OnError(err)
for i := range tables {
tables[i].Columns = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
}
return tables
} }
func (p *postgresQuerySet) PrimaryKeysQuery() string { func (p postgresQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column {
return ` query := `
SELECT c.column_name WITH primaryKeys AS (
FROM information_schema.key_column_usage AS c SELECT column_name
LEFT JOIN information_schema.table_constraints AS t FROM information_schema.key_column_usage AS c
ON t.constraint_name = c.constraint_name LEFT JOIN information_schema.table_constraints AS t
WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY'; ON t.constraint_name = c.constraint_name
` WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY'
} )
SELECT column_name as "column.Name",
func (p *postgresQuerySet) ListOfColumnsQuery() string { is_nullable = 'YES' as "column.isNullable",
return ` (EXISTS(SELECT 1 from primaryKeys as pk where pk.column_name = columns.column_name)) as "column.IsPrimaryKey",
SELECT column_name, is_nullable, data_type, udt_name, FALSE dataType.kind as "dataType.Kind",
FROM information_schema.columns (case dataType.Kind when 'base' then data_type else LTRIM(udt_name, '_') end) as "dataType.Name",
FALSE as "dataType.isUnsigned"
FROM information_schema.columns,
LATERAL (select (case data_type
when 'ARRAY' then 'array'
when 'USER-DEFINED' then
case (select typtype from pg_type where typname = columns.udt_name)
when 'e' then 'enum'
else 'user-defined'
end
else 'base'
end) as Kind) as dataType
where table_schema = $1 and table_name = $2 where table_schema = $1 and table_name = $2
order by ordinal_position;` order by ordinal_position;
`
var columns []metadata.Column
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns)
throw.OnError(err)
return columns
} }
func (p *postgresQuerySet) ListOfEnumsQuery() string { func (p postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum {
return ` query := `
SELECT t.typname, SELECT t.typname as "enum.name",
e.enumlabel e.enumlabel as "values"
FROM pg_catalog.pg_type t FROM pg_catalog.pg_type t
JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid JOIN pg_catalog.pg_enum e on t.oid = e.enumtypid
JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
WHERE n.nspname = $1 WHERE n.nspname = $1
ORDER BY n.nspname, t.typname, e.enumsortorder;` ORDER BY n.nspname, t.typname, e.enumsortorder;`
}
var result []metadata.Enum
func (p *postgresQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.MetaData {
rows, err := db.Query(p.ListOfEnumsQuery(), schemaName) err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result)
utils.PanicOnError(err) throw.OnError(err)
defer rows.Close()
return result
enumsInfosMap := map[string][]string{}
for rows.Next() {
var enumName string
var enumValue string
err = rows.Scan(&enumName, &enumValue)
utils.PanicOnError(err)
enumValues := enumsInfosMap[enumName]
enumValues = append(enumValues, enumValue)
enumsInfosMap[enumName] = enumValues
}
err = rows.Err()
utils.PanicOnError(err)
ret := []metadata.MetaData{}
for enumName, enumValues := range enumsInfosMap {
ret = append(ret, metadata.EnumMetaData{
EnumName: enumName,
Values: enumValues,
})
}
return ret
} }

View file

@ -0,0 +1,223 @@
package template
var autoGenWarningTemplate = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
`
var tableSQLBuilderTemplate = `
{{define "column-list" -}}
{{- range $i, $c := . }}
{{- $field := columnField $c}}
{{- if gt $i 0 }}, {{end}}{{$field.Name}}Column
{{- end}}
{{- end}}
package {{package}}
import (
"github.com/go-jet/jet/v2/{{dialect.PackageName}}"
)
var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "")
type {{tableTemplate.TypeName}} struct {
{{dialect.PackageName}}.Table
//Columns
{{- range $i, $c := .Columns}}
{{- $field := columnField $c}}
{{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}}
{{- end}}
AllColumns {{dialect.PackageName}}.ColumnList
MutableColumns {{dialect.PackageName}}.ColumnList
}
// AS creates new {{tableTemplate.TypeName}} with assigned alias
func (a {{tableTemplate.TypeName}}) AS(alias string) {{tableTemplate.TypeName}} {
return new{{tableTemplate.TypeName}}(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new {{tableTemplate.TypeName}} with assigned schema name
func (a {{tableTemplate.TypeName}}) FromSchema(schemaName string) {{tableTemplate.TypeName}} {
return new{{tableTemplate.TypeName}}(schemaName, a.TableName(), a.Alias())
}
func new{{tableTemplate.TypeName}}(schemaName, tableName, alias string) {{tableTemplate.TypeName}} {
var (
{{- range $i, $c := .Columns}}
{{- $field := columnField $c}}
{{$field.Name}}Column = {{dialect.PackageName}}.{{$field.Type}}Column("{{$c.Name}}")
{{- end}}
allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }
mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }
)
return {{tableTemplate.TypeName}}{
Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
{{- range $i, $c := .Columns}}
{{- $field := columnField $c}}
{{$field.Name}}: {{$field.Name}}Column,
{{- end}}
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`
var tablePostgreSQLBuilderTemplate = `
{{define "column-list" -}}
{{- range $i, $c := . }}
{{- $field := columnField $c}}
{{- if gt $i 0 }}, {{end}}{{$field.Name}}Column
{{- end}}
{{- end}}
package {{package}}
import (
"github.com/go-jet/jet/v2/{{dialect.PackageName}}"
)
var {{tableTemplate.InstanceName}} = new{{tableTemplate.TypeName}}("{{schemaName}}", "{{.Name}}", "")
type {{structImplName}} struct {
{{dialect.PackageName}}.Table
//Columns
{{- range $i, $c := .Columns}}
{{- $field := columnField $c}}
{{$field.Name}} {{dialect.PackageName}}.Column{{$field.Type}}
{{- end}}
AllColumns {{dialect.PackageName}}.ColumnList
MutableColumns {{dialect.PackageName}}.ColumnList
}
type {{tableTemplate.TypeName}} struct {
{{structImplName}}
EXCLUDED {{structImplName}}
}
// AS creates new {{tableTemplate.TypeName}} with assigned alias
func (a {{tableTemplate.TypeName}}) AS(alias string) *{{tableTemplate.TypeName}} {
return new{{tableTemplate.TypeName}}(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new {{tableTemplate.TypeName}} with assigned schema name
func (a {{tableTemplate.TypeName}}) FromSchema(schemaName string) *{{tableTemplate.TypeName}} {
return new{{tableTemplate.TypeName}}(schemaName, a.TableName(), a.Alias())
}
func new{{tableTemplate.TypeName}}(schemaName, tableName, alias string) *{{tableTemplate.TypeName}} {
return &{{tableTemplate.TypeName}}{
{{structImplName}}: new{{tableTemplate.TypeName}}Impl(schemaName, tableName, alias),
EXCLUDED: new{{tableTemplate.TypeName}}Impl("", "excluded", ""),
}
}
func new{{tableTemplate.TypeName}}Impl(schemaName, tableName, alias string) {{structImplName}} {
var (
{{- range $i, $c := .Columns}}
{{- $field := columnField $c}}
{{$field.Name}}Column = {{dialect.PackageName}}.{{$field.Type}}Column("{{$c.Name}}")
{{- end}}
allColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .Columns}} }
mutableColumns = {{dialect.PackageName}}.ColumnList{ {{template "column-list" .MutableColumns}} }
)
return {{structImplName}}{
Table: {{dialect.PackageName}}.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
{{- range $i, $c := .Columns}}
{{- $field := columnField $c}}
{{$field.Name}}: {{$field.Name}}Column,
{{- end}}
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`
var tableModelFileTemplate = `package {{package}}
{{ with modelImports }}
import (
{{- range .}}
"{{.}}"
{{- end}}
)
{{end}}
{{$modelTableTemplate := tableTemplate}}
type {{$modelTableTemplate.TypeName}} struct {
{{- range .Columns}}
{{- $field := structField .}}
{{$field.Name}} {{$field.Type.Name}} ` + "{{$field.TagsString}}" + `
{{- end}}
}
`
var enumSQLBuilderTemplate = `package {{package}}
import "github.com/go-jet/jet/v2/{{dialect.PackageName}}"
var {{enumTemplate.InstanceName}} = &struct {
{{- range $index, $value := .Values}}
{{enumValueName $value}} {{dialect.PackageName}}.StringExpression
{{- end}}
} {
{{- range $index, $value := .Values}}
{{enumValueName $value}}: {{dialect.PackageName}}.NewEnumValue("{{$value}}"),
{{- end}}
}
`
var enumModelTemplate = `package {{package}}
{{- $enumTemplate := enumTemplate}}
import "errors"
type {{$enumTemplate.TypeName}} string
const (
{{- range $_, $value := .Values}}
{{valueName $value}} {{$enumTemplate.TypeName}} = "{{$value}}"
{{- end}}
)
func (e *{{$enumTemplate.TypeName}}) Scan(value interface{}) error {
if v, ok := value.(string); !ok {
return errors.New("jet: Invalid scan value for {{$enumTemplate.TypeName}} enum. Enum value has to be of type string")
} else {
switch string(v) {
{{- range $_, $value := .Values}}
case "{{$value}}":
*e = {{valueName $value}}
{{- end}}
default:
return errors.New("jet: Invalid scan value '" + string(v) + "' for {{$enumTemplate.TypeName}} enum")
}
return nil
}
}
func (e {{$enumTemplate.TypeName}}) String() string {
return string(e)
}
`

View file

@ -0,0 +1,60 @@
package template
import (
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/jet"
)
// Template is generator template used for file generation
type Template struct {
Dialect jet.Dialect
Schema func(schemaMetaData metadata.Schema) Schema
}
// Default is default generator template implementation
func Default(dialect jet.Dialect) Template {
return Template{
Dialect: dialect,
Schema: DefaultSchema,
}
}
// UseSchema replaces current schema generate function with a new implementation and returns new generator template
func (t Template) UseSchema(schemaFunc func(schemaMetaData metadata.Schema) Schema) Template {
t.Schema = schemaFunc
return t
}
// Schema is schema generator template used to generate schema(model and sql builder) files
type Schema struct {
Path string
Model Model
SQLBuilder SQLBuilder
}
// UsePath replaces path and returns new schema template
func (s Schema) UsePath(path string) Schema {
s.Path = path
return s
}
// UseModel returns new schema template with replaced template for model files generation
func (s Schema) UseModel(model Model) Schema {
s.Model = model
return s
}
// UseSQLBuilder returns new schema with replaced template for sql builder files generation
func (s Schema) UseSQLBuilder(sqlBuilder SQLBuilder) Schema {
s.SQLBuilder = sqlBuilder
return s
}
// DefaultSchema returns default schema template implementation
func DefaultSchema(schemaMetaData metadata.Schema) Schema {
return Schema{
Path: schemaMetaData.Name,
Model: DefaultModel(),
SQLBuilder: DefaultSQLBuilder(),
}
}

View file

@ -0,0 +1,327 @@
package template
import (
"fmt"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils"
"github.com/google/uuid"
"path"
"reflect"
"strings"
"time"
)
// Model is template for model files generation
type Model struct {
Skip bool
Path string
Table func(table metadata.Table) TableModel
View func(table metadata.Table) ViewModel
Enum func(enum metadata.Enum) EnumModel
}
// PackageName returns package name of model types
func (m Model) PackageName() string {
return path.Base(m.Path)
}
// UsePath returns new Model template with replaced file path
func (m Model) UsePath(path string) Model {
m.Path = path
return m
}
// UseTable returns new Model template with replaced template for table model files generation
func (m Model) UseTable(tableModelFunc func(table metadata.Table) TableModel) Model {
m.Table = tableModelFunc
return m
}
// UseView returns new Model template with replaced template for view model files generation
func (m Model) UseView(tableModelFunc func(table metadata.Table) TableModel) Model {
m.View = tableModelFunc
return m
}
// UseEnum returns new Model template with replaced template for enum model files generation
func (m Model) UseEnum(enumFunc func(enumMetaData metadata.Enum) EnumModel) Model {
m.Enum = enumFunc
return m
}
// DefaultModel returns default Model template implementation
func DefaultModel() Model {
return Model{
Skip: false,
Path: "/model",
Table: DefaultTableModel,
View: DefaultViewModel,
Enum: DefaultEnumModel,
}
}
// TableModel is template for table model files generation
type TableModel struct {
Skip bool
FileName string
TypeName string
Field func(columnMetaData metadata.Column) TableModelField
}
// ViewModel is template for view model files generation
type ViewModel = TableModel
// DefaultViewModel is default view template implementation
var DefaultViewModel = DefaultTableModel
// DefaultTableModel is default table template implementation
func DefaultTableModel(tableMetaData metadata.Table) TableModel {
return TableModel{
FileName: utils.ToGoFileName(tableMetaData.Name),
TypeName: utils.ToGoIdentifier(tableMetaData.Name),
Field: DefaultTableModelField,
}
}
// UseFileName returns new TableModel with new file name set
func (t TableModel) UseFileName(fileName string) TableModel {
t.FileName = fileName
return t
}
// UseTypeName returns new TableModel with new type name set
func (t TableModel) UseTypeName(typeName string) TableModel {
t.TypeName = typeName
return t
}
// UseField returns new TableModel with new TableModelField template function
func (t TableModel) UseField(structFieldFunc func(columnMetaData metadata.Column) TableModelField) TableModel {
t.Field = structFieldFunc
return t
}
func getTableModelImports(modelType TableModel, tableMetaData metadata.Table) []string {
importPaths := map[string]bool{}
for _, columnMetaData := range tableMetaData.Columns {
field := modelType.Field(columnMetaData)
importPath := field.Type.ImportPath
if importPath != "" {
importPaths[importPath] = true
}
}
var ret []string
for importPath := range importPaths {
ret = append(ret, importPath)
}
return ret
}
// EnumModel is template for enum model files generation
type EnumModel struct {
Skip bool
FileName string
TypeName string
ValueName func(value string) string
}
// UseFileName returns new EnumModel with new file name set
func (em EnumModel) UseFileName(fileName string) EnumModel {
em.FileName = fileName
return em
}
// UseTypeName returns new EnumModel with new type name set
func (em EnumModel) UseTypeName(typeName string) EnumModel {
em.TypeName = typeName
return em
}
// DefaultEnumModel returns default implementation for EnumModel
func DefaultEnumModel(enumMetaData metadata.Enum) EnumModel {
typeName := utils.ToGoIdentifier(enumMetaData.Name)
return EnumModel{
FileName: utils.ToGoFileName(enumMetaData.Name),
TypeName: typeName,
ValueName: func(value string) string {
return typeName + "_" + utils.ToGoIdentifier(value)
},
}
}
// TableModelField is template for table model field generation
type TableModelField struct {
Name string
Type Type
Tags []string
}
// DefaultTableModelField returns default TableModelField implementation
func DefaultTableModelField(columnMetaData metadata.Column) TableModelField {
var tags []string
if columnMetaData.IsPrimaryKey {
tags = append(tags, `sql:"primary_key"`)
}
return TableModelField{
Name: utils.ToGoIdentifier(columnMetaData.Name),
Type: getType(columnMetaData),
Tags: tags,
}
}
// UseType returns new TypeModelField with a new field type set
func (f TableModelField) UseType(t Type) TableModelField {
f.Type = t
return f
}
// UseName returns new TableModelField implementation with new field name set
func (f TableModelField) UseName(name string) TableModelField {
f.Name = name
return f
}
// UseTags returns new TableModelField implementation with additional tags added.
func (f TableModelField) UseTags(tags ...string) TableModelField {
f.Tags = append(f.Tags, tags...)
return f
}
// TagsString returns tags string representation
func (f TableModelField) TagsString() string {
if len(f.Tags) == 0 {
return ""
}
return fmt.Sprintf("`%s`", strings.Join(f.Tags, " "))
}
// Type represents type of the struct field
type Type struct {
ImportPath string
Name string
}
// NewType creates new type for dummy object
func NewType(dummyObject interface{}) Type {
return Type{
ImportPath: getImportPath(dummyObject),
Name: getTypeName(dummyObject),
}
}
func getTypeName(t interface{}) string {
typeStr := reflect.TypeOf(t).String()
typeStr = strings.Replace(typeStr, "[]uint8", "[]byte", -1)
return typeStr
}
func getImportPath(dummyData interface{}) string {
dataType := reflect.TypeOf(dummyData)
if dataType.Kind() == reflect.Ptr {
return dataType.Elem().PkgPath()
}
return dataType.PkgPath()
}
func getType(columnMetadata metadata.Column) Type {
userDefinedType := getUserDefinedType(columnMetadata)
if userDefinedType != "" {
if columnMetadata.IsNullable {
return Type{Name: "*" + userDefinedType}
}
return Type{Name: userDefinedType}
}
return NewType(getGoType(columnMetadata))
}
func getUserDefinedType(column metadata.Column) string {
switch column.DataType.Kind {
case metadata.EnumType:
return utils.ToGoIdentifier(column.DataType.Name)
case metadata.UserDefinedType, metadata.ArrayType:
return "string"
}
return ""
}
func getGoType(column metadata.Column) interface{} {
defaultGoType := toGoType(column)
if column.IsNullable {
return reflect.New(reflect.TypeOf(defaultGoType)).Interface()
}
return defaultGoType
}
// toGoType returns model type for column info.
func toGoType(column metadata.Column) interface{} {
switch column.DataType.Name {
case "USER-DEFINED", "enum":
return ""
case "boolean", "bool":
return false
case "tinyint":
if column.DataType.IsUnsigned {
return uint8(0)
}
return int8(0)
case "smallint", "int2",
"year":
if column.DataType.IsUnsigned {
return uint16(0)
}
return int16(0)
case "integer", "int4",
"mediumint", "int": //MySQL
if column.DataType.IsUnsigned {
return uint32(0)
}
return int32(0)
case "bigint", "int8":
if column.DataType.IsUnsigned {
return uint64(0)
}
return int64(0)
case "date",
"timestamp without time zone", "timestamp",
"timestamp with time zone", "timestamptz",
"time without time zone", "time",
"time with time zone", "timetz",
"datetime": // MySQL
return time.Time{}
case "bytea",
"binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob": //MySQL
return []byte("")
case "text",
"character", "bpchar",
"character varying", "varchar",
"tsvector", "bit", "bit varying", "varbit",
"money", "json", "jsonb",
"xml", "point", "interval", "line", "ARRAY",
"char", "tinytext", "mediumtext", "longtext": // MySQL
return ""
case "real", "float4":
return float32(0.0)
case "numeric", "decimal",
"double precision", "float8", "float",
"double": // MySQL
return float64(0.0)
case "uuid":
return uuid.UUID{}
default:
fmt.Println("- [Model ] Unsupported sql column '" + column.Name + " " + column.DataType.Name + "', using string instead.")
return ""
}
}

View file

@ -0,0 +1,45 @@
package template
import (
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/stretchr/testify/require"
"testing"
)
func Test_TableModelField(t *testing.T) {
require.Equal(t, DefaultTableModelField(metadata.Column{
Name: "col_name",
IsPrimaryKey: true,
IsNullable: true,
DataType: metadata.DataType{
Name: "smallint",
Kind: "base",
IsUnsigned: true,
},
}), TableModelField{
Name: "ColName",
Type: Type{
ImportPath: "",
Name: "*uint16",
},
Tags: []string{"sql:\"primary_key\""},
})
require.Equal(t, DefaultTableModelField(metadata.Column{
Name: "time_column_1",
IsPrimaryKey: false,
IsNullable: true,
DataType: metadata.DataType{
Name: "timestamp with time zone",
Kind: "base",
IsUnsigned: false,
},
}), TableModelField{
Name: "TimeColumn1",
Type: Type{
ImportPath: "time",
Name: "*time.Time",
},
Tags: nil,
})
}

View file

@ -0,0 +1,269 @@
package template
import (
"bytes"
"fmt"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/jet"
"github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
"path"
"strings"
"text/template"
)
// ProcessSchema will process schema metadata and constructs go files using generator Template
func ProcessSchema(dirPath string, schemaMetaData metadata.Schema, generatorTemplate Template) {
if schemaMetaData.IsEmpty() {
return
}
schemaTemplate := generatorTemplate.Schema(schemaMetaData)
schemaPath := path.Join(dirPath, schemaTemplate.Path)
fmt.Println("Destination directory:", schemaPath)
fmt.Println("Cleaning up destination directory...")
err := utils.CleanUpGeneratedFiles(schemaPath)
throw.OnError(err)
processModel(schemaPath, schemaMetaData, schemaTemplate)
processSQLBuilder(schemaPath, generatorTemplate.Dialect, schemaMetaData, schemaTemplate)
}
func processModel(dirPath string, schemaMetaData metadata.Schema, schemaTemplate Schema) {
modelTemplate := schemaTemplate.Model
if modelTemplate.Skip {
fmt.Println("Skipping the generation of model types.")
return
}
modelDirPath := path.Join(dirPath, modelTemplate.Path)
err := utils.EnsureDirPath(modelDirPath)
throw.OnError(err)
processTableModels("table", modelDirPath, schemaMetaData.TablesMetaData, modelTemplate)
processTableModels("view", modelDirPath, schemaMetaData.ViewsMetaData, modelTemplate)
processEnumModels(modelDirPath, schemaMetaData.EnumsMetaData, modelTemplate)
}
func processSQLBuilder(dirPath string, dialect jet.Dialect, schemaMetaData metadata.Schema, schemaTemplate Schema) {
sqlBuilderTemplate := schemaTemplate.SQLBuilder
if sqlBuilderTemplate.Skip {
fmt.Println("Skipping the generation of SQL Builder types.")
return
}
sqlBuilderPath := path.Join(dirPath, sqlBuilderTemplate.Path)
processTableSQLBuilder("table", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.TablesMetaData, sqlBuilderTemplate)
processTableSQLBuilder("view", sqlBuilderPath, dialect, schemaMetaData, schemaMetaData.ViewsMetaData, sqlBuilderTemplate)
processEnumSQLBuilder(sqlBuilderPath, dialect, schemaMetaData.EnumsMetaData, sqlBuilderTemplate)
}
func processEnumSQLBuilder(dirPath string, dialect jet.Dialect, enumsMetaData []metadata.Enum, sqlBuilder SQLBuilder) {
if len(enumsMetaData) == 0 {
return
}
fmt.Printf("Generating enum sql builder files\n")
for _, enumMetaData := range enumsMetaData {
enumTemplate := sqlBuilder.Enum(enumMetaData)
if enumTemplate.Skip {
continue
}
enumSQLBuilderPath := path.Join(dirPath, enumTemplate.Path)
err := utils.EnsureDirPath(enumSQLBuilderPath)
throw.OnError(err)
text, err := generateTemplate(
autoGenWarningTemplate+enumSQLBuilderTemplate,
enumMetaData,
template.FuncMap{
"package": func() string {
return enumTemplate.PackageName()
},
"dialect": func() jet.Dialect {
return dialect
},
"enumTemplate": func() EnumSQLBuilder {
return enumTemplate
},
"enumValueName": func(enumValue string) string {
return enumTemplate.ValueName(enumValue)
},
})
throw.OnError(err)
err = utils.SaveGoFile(enumSQLBuilderPath, enumTemplate.FileName, text)
throw.OnError(err)
}
}
func processTableSQLBuilder(fileTypes, dirPath string,
dialect jet.Dialect,
schemaMetaData metadata.Schema,
tablesMetaData []metadata.Table,
sqlBuilderTemplate SQLBuilder) {
if len(tablesMetaData) == 0 {
return
}
fmt.Printf("Generating %s sql builder files\n", fileTypes)
for _, tableMetaData := range tablesMetaData {
var tableSQLBuilderTemplate TableSQLBuilder
if fileTypes == "view" {
tableSQLBuilderTemplate = sqlBuilderTemplate.View(tableMetaData)
} else {
tableSQLBuilderTemplate = sqlBuilderTemplate.Table(tableMetaData)
}
if tableSQLBuilderTemplate.Skip {
continue
}
tableSQLBuilderPath := path.Join(dirPath, tableSQLBuilderTemplate.Path)
err := utils.EnsureDirPath(tableSQLBuilderPath)
throw.OnError(err)
text, err := generateTemplate(
autoGenWarningTemplate+getTableSQLBuilderTemplate(dialect),
tableMetaData,
template.FuncMap{
"package": func() string {
return tableSQLBuilderTemplate.PackageName()
},
"dialect": func() jet.Dialect {
return dialect
},
"schemaName": func() string {
return schemaMetaData.Name
},
"tableTemplate": func() TableSQLBuilder {
return tableSQLBuilderTemplate
},
"structImplName": func() string { // postgres only
structName := tableSQLBuilderTemplate.TypeName
return string(strings.ToLower(structName)[0]) + structName[1:]
},
"columnField": func(columnMetaData metadata.Column) TableSQLBuilderColumn {
return tableSQLBuilderTemplate.Column(columnMetaData)
},
})
throw.OnError(err)
err = utils.SaveGoFile(tableSQLBuilderPath, tableSQLBuilderTemplate.FileName, text)
throw.OnError(err)
}
}
func getTableSQLBuilderTemplate(dialect jet.Dialect) string {
if dialect.Name() == "PostgreSQL" {
return tablePostgreSQLBuilderTemplate
}
return tableSQLBuilderTemplate
}
func processTableModels(fileTypes, modelDirPath string, tablesMetaData []metadata.Table, modelTemplate Model) {
if len(tablesMetaData) == 0 {
return
}
fmt.Printf("Generating %s model files...\n", fileTypes)
for _, tableMetaData := range tablesMetaData {
var tableTemplate TableModel
if fileTypes == "table" {
tableTemplate = modelTemplate.Table(tableMetaData)
} else {
tableTemplate = modelTemplate.View(tableMetaData)
}
if tableTemplate.Skip {
continue
}
text, err := generateTemplate(
autoGenWarningTemplate+tableModelFileTemplate,
tableMetaData,
template.FuncMap{
"package": func() string {
return modelTemplate.PackageName()
},
"modelImports": func() []string {
return getTableModelImports(tableTemplate, tableMetaData)
},
"tableTemplate": func() TableModel {
return tableTemplate
},
"structField": func(columnMetaData metadata.Column) TableModelField {
return tableTemplate.Field(columnMetaData)
},
})
throw.OnError(err)
err = utils.SaveGoFile(modelDirPath, tableTemplate.FileName, text)
throw.OnError(err)
}
}
func processEnumModels(modelDir string, enumsMetaData []metadata.Enum, modelTemplate Model) {
if len(enumsMetaData) == 0 {
return
}
fmt.Print("Generating enum model files...\n")
for _, enumMetaData := range enumsMetaData {
enumTemplate := modelTemplate.Enum(enumMetaData)
if enumTemplate.Skip {
continue
}
text, err := generateTemplate(
autoGenWarningTemplate+enumModelTemplate,
enumMetaData,
template.FuncMap{
"package": func() string {
return modelTemplate.PackageName()
},
"enumTemplate": func() EnumModel {
return enumTemplate
},
"valueName": func(value string) string {
return enumTemplate.ValueName(value)
},
})
throw.OnError(err)
err = utils.SaveGoFile(modelDir, enumTemplate.FileName, text)
throw.OnError(err)
}
}
func generateTemplate(templateText string, templateData interface{}, funcMap template.FuncMap) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(funcMap).Parse(templateText)
if err != nil {
return nil, err
}
var buf bytes.Buffer
if err := t.Execute(&buf, templateData); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View file

@ -0,0 +1,225 @@
package template
import (
"fmt"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils"
"path"
"unicode"
)
// SQLBuilder is template for generating sql builder files
type SQLBuilder struct {
Skip bool
Path string
Table func(table metadata.Table) TableSQLBuilder
View func(view metadata.Table) TableSQLBuilder
Enum func(enum metadata.Enum) EnumSQLBuilder
}
// DefaultSQLBuilder returns default SQLBuilder implementation
func DefaultSQLBuilder() SQLBuilder {
return SQLBuilder{
Path: "",
Table: DefaultTableSQLBuilder,
View: DefaultViewSQLBuilder,
Enum: DefaultEnumSQLBuilder,
}
}
// UsePath returns new SQLBuilder with new relative path set
func (sb SQLBuilder) UsePath(path string) SQLBuilder {
sb.Path = path
return sb
}
// UseTable returns new SQLBuilder with new TableSQLBuilder template function set
func (sb SQLBuilder) UseTable(tableFunc func(table metadata.Table) TableSQLBuilder) SQLBuilder {
sb.Table = tableFunc
return sb
}
// UseView returns new SQLBuilder with new ViewSQLBuilder template function set
func (sb SQLBuilder) UseView(viewFunc func(table metadata.Table) ViewSQLBuilder) SQLBuilder {
sb.View = viewFunc
return sb
}
// UseEnum returns new SQLBuilder with new EnumSQLBuilder template function set
func (sb SQLBuilder) UseEnum(enumFunc func(enum metadata.Enum) EnumSQLBuilder) SQLBuilder {
sb.Enum = enumFunc
return sb
}
// TableSQLBuilder is template for generating table SQLBuilder files
type TableSQLBuilder struct {
Skip bool
Path string
FileName string
InstanceName string
TypeName string
Column func(columnMetaData metadata.Column) TableSQLBuilderColumn
}
// ViewSQLBuilder is template for generating view SQLBuilder files
type ViewSQLBuilder = TableSQLBuilder
// DefaultTableSQLBuilder returns default implementation for TableSQLBuilder
func DefaultTableSQLBuilder(tableMetaData metadata.Table) TableSQLBuilder {
return TableSQLBuilder{
Path: "/table",
FileName: utils.ToGoFileName(tableMetaData.Name),
InstanceName: utils.ToGoIdentifier(tableMetaData.Name),
TypeName: utils.ToGoIdentifier(tableMetaData.Name) + "Table",
Column: DefaultTableSQLBuilderColumn,
}
}
// DefaultViewSQLBuilder returns default implementation for ViewSQLBuilder
func DefaultViewSQLBuilder(viewMetaData metadata.Table) ViewSQLBuilder {
tableSQLBuilder := DefaultTableSQLBuilder(viewMetaData)
tableSQLBuilder.Path = "/view"
return tableSQLBuilder
}
// PackageName returns package name of table sql builder types
func (tb TableSQLBuilder) PackageName() string {
return path.Base(tb.Path)
}
// UsePath returns new TableSQLBuilder with new relative path set
func (tb TableSQLBuilder) UsePath(path string) TableSQLBuilder {
tb.Path = path
return tb
}
// UseFileName returns new TableSQLBuilder with new file name set
func (tb TableSQLBuilder) UseFileName(name string) TableSQLBuilder {
tb.FileName = name
return tb
}
// UseInstanceName returns new TableSQLBuilder with new instance name set
func (tb TableSQLBuilder) UseInstanceName(name string) TableSQLBuilder {
tb.InstanceName = name
return tb
}
// UseTypeName returns new TableSQLBuilder with new type name set
func (tb TableSQLBuilder) UseTypeName(name string) TableSQLBuilder {
tb.TypeName = name
return tb
}
// UseColumn returns new TableSQLBuilder with new column template function set
func (tb TableSQLBuilder) UseColumn(columnsFunc func(column metadata.Column) TableSQLBuilderColumn) TableSQLBuilder {
tb.Column = columnsFunc
return tb
}
// TableSQLBuilderColumn is template for table sql builder column
type TableSQLBuilderColumn struct {
Name string
Type string
}
// DefaultTableSQLBuilderColumn returns default implementation of TableSQLBuilderColumn
func DefaultTableSQLBuilderColumn(columnMetaData metadata.Column) TableSQLBuilderColumn {
return TableSQLBuilderColumn{
Name: utils.ToGoIdentifier(columnMetaData.Name),
Type: getSqlBuilderColumnType(columnMetaData),
}
}
// getSqlBuilderColumnType returns type of jet sql builder column
func getSqlBuilderColumnType(columnMetaData metadata.Column) string {
if columnMetaData.DataType.Kind != metadata.BaseType {
return "String"
}
switch columnMetaData.DataType.Name {
case "boolean":
return "Bool"
case "smallint", "integer", "bigint",
"tinyint", "mediumint", "int", "year": //MySQL
return "Integer"
case "date":
return "Date"
case "timestamp without time zone",
"timestamp", "datetime": //MySQL:
return "Timestamp"
case "timestamp with time zone":
return "Timestampz"
case "time without time zone",
"time": //MySQL
return "Time"
case "time with time zone":
return "Timez"
case "interval":
return "Interval"
case "USER-DEFINED", "enum", "text", "character", "character varying", "bytea", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY",
"char", "varchar", "binary", "varbinary",
"tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL
return "String"
case "real", "numeric", "decimal", "double precision", "float",
"double": // MySQL
return "Float"
default:
fmt.Println("- [SQL Builder] Unsupported sql column '" + columnMetaData.Name + " " + columnMetaData.DataType.Name + "', using StringColumn instead.")
return "String"
}
}
// EnumSQLBuilder is template for generating enum SQLBuilder files
type EnumSQLBuilder struct {
Skip bool
Path string
FileName string
InstanceName string
ValueName func(enumValue string) string
}
// DefaultEnumSQLBuilder returns default implementation of EnumSQLBuilder
func DefaultEnumSQLBuilder(enumMetaData metadata.Enum) EnumSQLBuilder {
return EnumSQLBuilder{
Path: "/enum",
FileName: utils.ToGoFileName(enumMetaData.Name),
InstanceName: utils.ToGoIdentifier(enumMetaData.Name),
ValueName: func(enumValue string) string {
return defaultEnumValueName(enumMetaData.Name, enumValue)
},
}
}
// PackageName returns enum sql builder package name
func (e EnumSQLBuilder) PackageName() string {
return path.Base(e.Path)
}
// UsePath returns new EnumSQLBuilder with new path set
func (e EnumSQLBuilder) UsePath(path string) EnumSQLBuilder {
e.Path = path
return e
}
// UseFileName returns new EnumSQLBuilder with new file name set
func (e EnumSQLBuilder) UseFileName(name string) EnumSQLBuilder {
e.FileName = name
return e
}
// UseInstanceName returns new EnumSQLBuilder with instance name set
func (e EnumSQLBuilder) UseInstanceName(name string) EnumSQLBuilder {
e.InstanceName = name
return e
}
func defaultEnumValueName(enumName, enumValue string) string {
enumValueName := utils.ToGoIdentifier(enumValue)
if !unicode.IsLetter([]rune(enumValueName)[0]) {
return utils.ToGoIdentifier(enumName) + enumValueName
}
return enumValueName
}

View file

@ -0,0 +1,11 @@
package template
import (
"github.com/stretchr/testify/require"
"testing"
)
func TestToGoEnumValueIdentifier(t *testing.T) {
require.Equal(t, defaultEnumValueName("enum_name", "enum_value"), "EnumValue")
require.Equal(t, defaultEnumValueName("NumEnum", "100"), "NumEnum100")
}

View file

@ -9,8 +9,12 @@ import (
) )
// SnakeToCamel returns a string converted from snake case to uppercase // SnakeToCamel returns a string converted from snake case to uppercase
func SnakeToCamel(s string) string { func SnakeToCamel(s string, firstLetterUppercase ...bool) string {
return snakeToCamel(s, true) upperCase := true
if len(firstLetterUppercase) > 0 {
upperCase = firstLetterUppercase[0]
}
return snakeToCamel(s, upperCase)
} }
func snakeToCamel(s string, upperCase bool) string { func snakeToCamel(s string, upperCase bool) string {

View file

@ -5,7 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/jet"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/qrm"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -66,7 +66,7 @@ func SaveJSONFile(v interface{}, testRelativePath string) {
filePath := getFullPath(testRelativePath) filePath := getFullPath(testRelativePath)
err := ioutil.WriteFile(filePath, jsonText, 0644) err := ioutil.WriteFile(filePath, jsonText, 0644)
utils.PanicOnError(err) throw.OnError(err)
} }
// AssertJSONFile check if data json representation is the same as json at testRelativePath // AssertJSONFile check if data json representation is the same as json at testRelativePath

View file

@ -1,7 +1,7 @@
package testutils package testutils
import ( import (
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils/throw"
"strings" "strings"
"time" "time"
) )
@ -10,7 +10,7 @@ import (
func Date(t string) *time.Time { func Date(t string) *time.Time {
newTime, err := time.Parse("2006-01-02", t) newTime, err := time.Parse("2006-01-02", t)
utils.PanicOnError(err) throw.OnError(err)
return &newTime return &newTime
} }
@ -26,7 +26,7 @@ func TimestampWithoutTimeZone(t string, precision int) *time.Time {
newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000") newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000")
utils.PanicOnError(err) throw.OnError(err)
return &newTime return &newTime
} }
@ -35,7 +35,7 @@ func TimestampWithoutTimeZone(t string, precision int) *time.Time {
func TimeWithoutTimeZone(t string) *time.Time { func TimeWithoutTimeZone(t string) *time.Time {
newTime, err := time.Parse("15:04:05", t) newTime, err := time.Parse("15:04:05", t)
utils.PanicOnError(err) throw.OnError(err)
return &newTime return &newTime
} }
@ -44,7 +44,7 @@ func TimeWithoutTimeZone(t string) *time.Time {
func TimeWithTimeZone(t string) *time.Time { func TimeWithTimeZone(t string) *time.Time {
newTimez, err := time.Parse("15:04:05 -0700", t) newTimez, err := time.Parse("15:04:05 -0700", t)
utils.PanicOnError(err) throw.OnError(err)
return &newTimez return &newTimez
} }
@ -60,7 +60,7 @@ func TimestampWithTimeZone(t string, precision int) *time.Time {
newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t) newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t)
utils.PanicOnError(err) throw.OnError(err)
return &newTime return &newTime
} }

View file

@ -0,0 +1,8 @@
package throw
// OnError will panic if err is not nill
func OnError(err error) {
if err != nil {
panic(err)
}
}

View file

@ -10,7 +10,6 @@ import (
"reflect" "reflect"
"strings" "strings"
"time" "time"
"unicode"
) )
// ToGoIdentifier converts database to Go identifier. // ToGoIdentifier converts database to Go identifier.
@ -18,16 +17,6 @@ func ToGoIdentifier(databaseIdentifier string) string {
return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier)) return snaker.SnakeToCamel(replaceInvalidChars(databaseIdentifier))
} }
// ToGoEnumValueIdentifier converts enum value name to Go identifier name.
func ToGoEnumValueIdentifier(enumName, enumValue string) string {
enumValueIdentifier := ToGoIdentifier(enumValue)
if !unicode.IsLetter([]rune(enumValueIdentifier)[0]) {
return ToGoIdentifier(enumName) + enumValueIdentifier
}
return enumValueIdentifier
}
// ToGoFileName converts database identifier to Go file name. // ToGoFileName converts database identifier to Go file name.
func ToGoFileName(databaseIdentifier string) string { func ToGoFileName(databaseIdentifier string) string {
return strings.ToLower(replaceInvalidChars(databaseIdentifier)) return strings.ToLower(replaceInvalidChars(databaseIdentifier))
@ -35,7 +24,11 @@ func ToGoFileName(databaseIdentifier string) string {
// SaveGoFile saves go file at folder dir, with name fileName and contents text. // SaveGoFile saves go file at folder dir, with name fileName and contents text.
func SaveGoFile(dirPath, fileName string, text []byte) error { func SaveGoFile(dirPath, fileName string, text []byte) error {
newGoFilePath := filepath.Join(dirPath, fileName) + ".go" newGoFilePath := filepath.Join(dirPath, fileName)
if !strings.HasSuffix(newGoFilePath, ".go") {
newGoFilePath += ".go"
}
file, err := os.Create(newGoFilePath) file, err := os.Create(newGoFilePath)
@ -160,13 +153,6 @@ func MustBeInitializedPtr(val interface{}, errorStr string) {
} }
} }
// PanicOnError panics if err is not nil
func PanicOnError(err error) {
if err != nil {
panic(err)
}
}
// ErrorCatch is used in defer to recover from panics and to set err // ErrorCatch is used in defer to recover from panics and to set err
func ErrorCatch(err *error) { func ErrorCatch(err *error) {
recovered := recover() recovered := recover()

View file

@ -25,11 +25,6 @@ func TestToGoIdentifier(t *testing.T) {
require.Equal(t, ToGoIdentifier("My-Table"), "MyTable") require.Equal(t, ToGoIdentifier("My-Table"), "MyTable")
} }
func TestToGoEnumValueIdentifier(t *testing.T) {
require.Equal(t, ToGoEnumValueIdentifier("enum_name", "enum_value"), "EnumValue")
require.Equal(t, ToGoEnumValueIdentifier("NumEnum", "100"), "NumEnum100")
}
func TestErrorCatchErr(t *testing.T) { func TestErrorCatchErr(t *testing.T) {
var err error var err error

View file

@ -5,6 +5,7 @@ import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
"reflect" "reflect"
"strings" "strings"
) )
@ -216,7 +217,7 @@ func (s *scanContext) rowElem(index int) interface{} {
value, err := valuer.Value() value, err := valuer.Value()
utils.PanicOnError(err) throw.OnError(err)
return value return value
} }

View file

@ -4,15 +4,15 @@ import "fmt"
// Postgres test database connection parameters // Postgres test database connection parameters
const ( const (
Host = "localhost" PgHost = "localhost"
Port = 5432 PgPort = 5432
User = "jet" PgUser = "jet"
Password = "jet" PgPassword = "jet"
DBName = "jetdb" PgDBName = "jetdb"
) )
// PostgresConnectString is PostgreSQL test database connection string // PostgresConnectString is PostgreSQL test database connection string
var PostgresConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName) var PostgresConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", PgHost, PgPort, PgUser, PgPassword, PgDBName)
// MySQL test database connection parameters // MySQL test database connection parameters
const ( const (

View file

@ -6,7 +6,7 @@ import (
"fmt" "fmt"
"github.com/go-jet/jet/v2/generator/mysql" "github.com/go-jet/jet/v2/generator/mysql"
"github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/generator/postgres"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/tests/dbconfig" "github.com/go-jet/jet/v2/tests/dbconfig"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/lib/pq"
@ -62,7 +62,7 @@ func initMySQLDB() {
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
err := cmd.Run() err := cmd.Run()
utils.PanicOnError(err) throw.OnError(err)
err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{ err = mysql.Generate("./.gentestdata/mysql", mysql.DBConnection{
Host: dbconfig.MySqLHost, Host: dbconfig.MySqLHost,
@ -72,7 +72,7 @@ func initMySQLDB() {
DBName: dbName, DBName: dbName,
}) })
utils.PanicOnError(err) throw.OnError(err)
} }
} }
@ -99,24 +99,24 @@ func initPostgresDB() {
execFile(db, "./testdata/init/postgres/"+schemaName+".sql") execFile(db, "./testdata/init/postgres/"+schemaName+".sql")
err = postgres.Generate("./.gentestdata", postgres.DBConnection{ err = postgres.Generate("./.gentestdata", postgres.DBConnection{
Host: dbconfig.Host, Host: dbconfig.PgHost,
Port: 5432, Port: 5432,
User: dbconfig.User, User: dbconfig.PgUser,
Password: dbconfig.Password, Password: dbconfig.PgPassword,
DBName: dbconfig.DBName, DBName: dbconfig.PgDBName,
SchemaName: schemaName, SchemaName: schemaName,
SslMode: "disable", SslMode: "disable",
}) })
utils.PanicOnError(err) throw.OnError(err)
} }
} }
func execFile(db *sql.DB, sqlFilePath string) { func execFile(db *sql.DB, sqlFilePath string) {
testSampleSql, err := ioutil.ReadFile(sqlFilePath) testSampleSql, err := ioutil.ReadFile(sqlFilePath)
utils.PanicOnError(err) throw.OnError(err)
_, err = db.Exec(string(testSampleSql)) _, err = db.Exec(string(testSampleSql))
utils.PanicOnError(err) throw.OnError(err)
} }
func printOnError(err error) { func printOnError(err error) {

View file

@ -0,0 +1,25 @@
package file
import (
"github.com/stretchr/testify/require"
"io/ioutil"
"os"
"path"
"testing"
)
// Exists expects file to exist on path constructed from pathElems and returns content of the file
func Exists(t *testing.T, pathElems ...string) (fileContent string) {
modelFilePath := path.Join(pathElems...)
file, err := ioutil.ReadFile(modelFilePath)
require.Nil(t, err)
require.NotEmpty(t, file)
return string(file)
}
// NotExists expects file not to exist on path constructed from pathElems
func NotExists(t *testing.T, pathElems ...string) {
modelFilePath := path.Join(pathElems...)
_, err := ioutil.ReadFile(modelFilePath)
require.True(t, os.IsNotExist(err))
}

View file

@ -467,10 +467,10 @@ func TestStringOperators(t *testing.T) {
AllTypes.Text.NOT_LIKE(String("_b_")), AllTypes.Text.NOT_LIKE(String("_b_")),
AllTypes.Text.REGEXP_LIKE(String("aba")), AllTypes.Text.REGEXP_LIKE(String("aba")),
AllTypes.Text.REGEXP_LIKE(String("aba"), false), AllTypes.Text.REGEXP_LIKE(String("aba"), false),
String("ABA").REGEXP_LIKE(String("aba"), true), //String("ABA").REGEXP_LIKE(String("aba"), true),
AllTypes.Text.NOT_REGEXP_LIKE(String("aba")), AllTypes.Text.NOT_REGEXP_LIKE(String("aba")),
AllTypes.Text.NOT_REGEXP_LIKE(String("aba"), false), AllTypes.Text.NOT_REGEXP_LIKE(String("aba"), false),
String("ABA").NOT_REGEXP_LIKE(String("aba"), true), //String("ABA").NOT_REGEXP_LIKE(String("aba"), true),
BIT_LENGTH(AllTypes.Text), BIT_LENGTH(AllTypes.Text),
CHAR_LENGTH(AllTypes.Char), CHAR_LENGTH(AllTypes.Char),

View file

@ -0,0 +1,389 @@
package mysql
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/v2/generator/metadata"
mysql2 "github.com/go-jet/jet/v2/generator/mysql"
"github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/3rdparty/snaker"
"github.com/go-jet/jet/v2/internal/utils"
postgres2 "github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/dbconfig"
file2 "github.com/go-jet/jet/v2/tests/internal/utils/file"
"github.com/stretchr/testify/require"
"path"
"testing"
)
const tempTestDir = "./.tempTestDir"
var defaultModelPath = path.Join(tempTestDir, "dvds/model")
var defaultActorModelFilePath = path.Join(tempTestDir, "dvds/model", "actor.go")
var defaultTableSQLBuilderFilePath = path.Join(tempTestDir, "dvds/table")
var defaultViewSQLBuilderFilePath = path.Join(tempTestDir, "dvds/view")
var defaultEnumSQLBuilderFilePath = path.Join(tempTestDir, "dvds/enum")
var defaultActorSQLBuilderFilePath = path.Join(tempTestDir, "dvds/table", "actor.go")
var dbConnection = mysql2.DBConnection{
Host: dbconfig.MySqLHost,
Port: dbconfig.MySQLPort,
User: dbconfig.MySQLUser,
Password: dbconfig.MySQLPassword,
DBName: "dvds",
}
func TestGeneratorTemplate_Schema_ChangePath(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).UsePath("new/schema/path")
}),
)
require.Nil(t, err)
file2.Exists(t, tempTestDir, "new/schema/path/model/actor.go")
file2.Exists(t, tempTestDir, "new/schema/path/table/actor.go")
file2.Exists(t, tempTestDir, "new/schema/path/view/actor_info.go")
file2.Exists(t, tempTestDir, "new/schema/path/enum/film_rating.go")
}
func TestGeneratorTemplate_Model_SkipGeneration(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseModel(template.Model{
Skip: true,
})
}),
)
require.Nil(t, err)
file2.NotExists(t, defaultActorModelFilePath)
file2.Exists(t, defaultTableSQLBuilderFilePath, "actor.go")
file2.Exists(t, defaultViewSQLBuilderFilePath, "actor_info.go")
file2.Exists(t, defaultEnumSQLBuilderFilePath, "film_rating.go")
}
func TestGeneratorTemplate_SQLBuilder_SkipGeneration(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseSQLBuilder(template.SQLBuilder{
Skip: true,
})
}),
)
require.Nil(t, err)
file2.Exists(t, defaultActorModelFilePath)
file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go")
file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go")
file2.NotExists(t, defaultEnumSQLBuilderFilePath, "film_rating.go")
}
func TestGeneratorTemplate_Model_ChangePath(t *testing.T) {
const newModelPath = "/new/model/path"
err := mysql2.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseModel(template.DefaultModel().UsePath(newModelPath))
}),
)
require.Nil(t, err)
file2.Exists(t, tempTestDir, "dvds", newModelPath, "actor.go")
file2.NotExists(t, defaultActorModelFilePath)
}
func TestGeneratorTemplate_SQLBuilder_ChangePath(t *testing.T) {
const newModelPath = "/new/sql-builder/path"
err := mysql2.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseSQLBuilder(template.DefaultSQLBuilder().UsePath(newModelPath))
}),
)
require.Nil(t, err)
file2.Exists(t, tempTestDir, "dvds", newModelPath, "table", "actor.go")
file2.Exists(t, tempTestDir, "dvds", newModelPath, "view", "actor_info.go")
file2.Exists(t, tempTestDir, "dvds", newModelPath, "enum", "film_rating.go")
file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go")
file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go")
file2.NotExists(t, defaultEnumSQLBuilderFilePath, "film_rating.go")
}
func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseModel(template.DefaultModel().
UseTable(func(table metadata.Table) template.TableModel {
return template.DefaultTableModel(table).
UseFileName(schemaMetaData.Name + "_" + table.Name).
UseTypeName(utils.ToGoIdentifier(table.Name) + "Table")
}).
UseView(func(table metadata.Table) template.ViewModel {
return template.DefaultViewModel(table).
UseFileName(schemaMetaData.Name + "_" + table.Name + "_view").
UseTypeName(utils.ToGoIdentifier(table.Name) + "View")
}).
UseEnum(func(enumMetaData metadata.Enum) template.EnumModel {
return template.DefaultEnumModel(enumMetaData).
UseFileName(enumMetaData.Name + "_enum").
UseTypeName(utils.ToGoIdentifier(enumMetaData.Name) + "Enum")
}),
)
}),
)
require.Nil(t, err)
actor := file2.Exists(t, defaultModelPath, "dvds_actor.go")
require.Contains(t, actor, "type ActorTable struct {")
actorInfo := file2.Exists(t, defaultModelPath, "dvds_actor_info_view.go")
require.Contains(t, actorInfo, "type ActorInfoView struct {")
mpaaRating := file2.Exists(t, defaultModelPath, "film_rating_enum.go")
require.Contains(t, mpaaRating, "type FilmRatingEnum string")
}
func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseModel(template.DefaultModel().
UseTable(func(table metadata.Table) template.TableModel {
return template.TableModel{
Skip: true,
}
}).
UseEnum(func(enumMetaData metadata.Enum) template.EnumModel {
return template.EnumModel{
Skip: true,
}
}),
)
}),
)
require.Nil(t, err)
file2.NotExists(t, defaultModelPath, "actor.go")
file2.Exists(t, defaultModelPath, "actor_info.go")
file2.NotExists(t, defaultModelPath, "film_rating.go")
}
func TestGeneratorTemplate_SQLBuilder_SkipTableAndEnum(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseSQLBuilder(template.DefaultSQLBuilder().
UseTable(func(table metadata.Table) template.TableSQLBuilder {
return template.TableSQLBuilder{
Skip: true,
}
}).
UseView(func(table metadata.Table) template.TableSQLBuilder {
return template.TableSQLBuilder{
Skip: true,
}
}).
UseEnum(func(enumMetaData metadata.Enum) template.EnumSQLBuilder {
return template.EnumSQLBuilder{
Skip: true,
}
}),
)
}),
)
require.Nil(t, err)
file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go")
file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go")
file2.NotExists(t, defaultEnumSQLBuilderFilePath, "film_rating.go")
}
func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseSQLBuilder(template.DefaultSQLBuilder().
UseTable(func(table metadata.Table) template.TableSQLBuilder {
return template.DefaultTableSQLBuilder(table).
UseFileName(schemaMetaData.Name + "_" + table.Name + "_table").
UseTypeName(utils.ToGoIdentifier(table.Name) + "TableSQLBuilder").
UseInstanceName("T_" + utils.ToGoIdentifier(table.Name))
}).
UseView(func(table metadata.Table) template.ViewSQLBuilder {
return template.DefaultViewSQLBuilder(table).
UseFileName(schemaMetaData.Name + "_" + table.Name + "_view").
UseTypeName(utils.ToGoIdentifier(table.Name) + "ViewSQLBuilder").
UseInstanceName("V_" + utils.ToGoIdentifier(table.Name))
}).
UseEnum(func(enum metadata.Enum) template.EnumSQLBuilder {
return template.DefaultEnumSQLBuilder(enum).
UseFileName(schemaMetaData.Name + "_" + enum.Name + "_enum").
UseInstanceName(utils.ToGoIdentifier(enum.Name) + "EnumSQLBuilder")
}),
)
}),
)
require.Nil(t, err)
actor := file2.Exists(t, defaultTableSQLBuilderFilePath, "dvds_actor_table.go")
require.Contains(t, actor, "type ActorTableSQLBuilder struct {")
require.Contains(t, actor, "var T_Actor = newActorTableSQLBuilder(\"dvds\", \"actor\", \"\")")
actorInfo := file2.Exists(t, defaultViewSQLBuilderFilePath, "dvds_actor_info_view.go")
require.Contains(t, actorInfo, "type ActorInfoViewSQLBuilder struct {")
require.Contains(t, actorInfo, "var V_ActorInfo = newActorInfoViewSQLBuilder(\"dvds\", \"actor_info\", \"\")")
mpaaRating := file2.Exists(t, defaultEnumSQLBuilderFilePath, "dvds_film_rating_enum.go")
require.Contains(t, mpaaRating, "var FilmRatingEnumSQLBuilder = &struct {")
}
func TestGeneratorTemplate_Model_AddTags(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseModel(template.DefaultModel().
UseTable(func(table metadata.Table) template.TableModel {
return template.DefaultTableModel(table).
UseField(func(columnMetaData metadata.Column) template.TableModelField {
defaultTableModelField := template.DefaultTableModelField(columnMetaData)
return defaultTableModelField.UseTags(
fmt.Sprintf(`json:"%s"`, snaker.SnakeToCamel(columnMetaData.Name, false)),
fmt.Sprintf(`xml:"%s"`, columnMetaData.Name),
)
})
}).
UseView(func(table metadata.Table) template.ViewModel {
return template.DefaultViewModel(table).
UseField(func(columnMetaData metadata.Column) template.TableModelField {
defaultTableModelField := template.DefaultTableModelField(columnMetaData)
if table.Name == "actor_info" && columnMetaData.Name == "actor_id" {
return defaultTableModelField.UseTags(`sql:"primary_key"`)
}
return defaultTableModelField
})
}),
)
}),
)
require.Nil(t, err)
actor := file2.Exists(t, defaultActorModelFilePath)
require.Contains(t, actor, "ActorID uint16 `sql:\"primary_key\" json:\"actorID\" xml:\"actor_id\"`")
require.Contains(t, actor, "FirstName string `json:\"firstName\" xml:\"first_name\"`")
actorInfo := file2.Exists(t, defaultModelPath, "actor_info.go")
require.Contains(t, actorInfo, "ActorID uint16 `sql:\"primary_key\"`")
}
func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseModel(template.DefaultModel().
UseTable(func(table metadata.Table) template.TableModel {
return template.DefaultTableModel(table).
UseField(func(columnMetaData metadata.Column) template.TableModelField {
defaultTableModelField := template.DefaultTableModelField(columnMetaData)
switch defaultTableModelField.Type.Name {
case "*string":
defaultTableModelField.Type = template.NewType(sql.NullString{})
case "*int32":
defaultTableModelField.Type = template.NewType(sql.NullInt32{})
case "*int64":
defaultTableModelField.Type = template.NewType(sql.NullInt64{})
case "*bool":
defaultTableModelField.Type = template.NewType(sql.NullBool{})
case "*float64":
defaultTableModelField.Type = template.NewType(sql.NullFloat64{})
case "*time.Time":
defaultTableModelField.Type = template.NewType(sql.NullTime{})
}
return defaultTableModelField
})
}),
)
}),
)
require.Nil(t, err)
data := file2.Exists(t, defaultModelPath, "film.go")
require.Contains(t, data, "\"database/sql\"")
require.Contains(t, data, "Description sql.NullString")
require.Contains(t, data, "ReleaseYear *int16")
require.Contains(t, data, "SpecialFeatures sql.NullString")
}
func TestGeneratorTemplate_SQLBuilder_ChangeColumnTypes(t *testing.T) {
err := mysql2.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseSQLBuilder(template.DefaultSQLBuilder().
UseTable(func(table metadata.Table) template.TableSQLBuilder {
return template.DefaultTableSQLBuilder(table).
UseColumn(func(column metadata.Column) template.TableSQLBuilderColumn {
defaultColumn := template.DefaultTableSQLBuilderColumn(column)
if defaultColumn.Name == "ActorID" {
defaultColumn.Type = "String"
}
return defaultColumn
})
}),
)
}),
)
require.Nil(t, err)
actor := file2.Exists(t, defaultActorSQLBuilderFilePath)
require.Contains(t, actor, "ActorID postgres.ColumnString")
}

View file

@ -66,9 +66,7 @@ func TestUpdateWithSubQueries(t *testing.T) {
expectedSQL := ` expectedSQL := `
UPDATE test_sample.link UPDATE test_sample.link
SET name = ( SET name = ?,
SELECT ?
),
url = ( url = (
SELECT link2.url AS "link2.url" SELECT link2.url AS "link2.url"
FROM test_sample.link2 FROM test_sample.link2
@ -80,7 +78,7 @@ WHERE link.name = ?;
query := Link. query := Link.
UPDATE(Link.Name, Link.URL). UPDATE(Link.Name, Link.URL).
SET( SET(
SELECT(String("Bong")), String("Bong"),
SELECT(Link2.URL). SELECT(Link2.URL).
FROM(Link2). FROM(Link2).
WHERE(Link2.Name.EQ(String("Youtube"))), WHERE(Link2.Name.EQ(String("Youtube"))),
@ -96,7 +94,7 @@ WHERE link.name = ?;
query := Link. query := Link.
UPDATE(). UPDATE().
SET( SET(
Link.Name.SET(StringExp(SELECT(String("Bong")))), Link.Name.SET(String("Bong")),
Link.URL.SET(StringExp( Link.URL.SET(StringExp(
SELECT(Link2.URL). SELECT(Link2.URL).
FROM(Link2). FROM(Link2).

View file

@ -0,0 +1,387 @@
package postgres
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/postgres"
"github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/3rdparty/snaker"
"github.com/go-jet/jet/v2/internal/utils"
postgres2 "github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/dbconfig"
file2 "github.com/go-jet/jet/v2/tests/internal/utils/file"
"github.com/stretchr/testify/require"
"path"
"testing"
)
const tempTestDir = "./.tempTestDir"
var defaultModelPath = path.Join(tempTestDir, "jetdb/dvds/model")
var defaultActorModelFilePath = path.Join(tempTestDir, "jetdb/dvds/model", "actor.go")
var defaultTableSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/table")
var defaultViewSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/view")
var defaultEnumSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/enum")
var defaultActorSQLBuilderFilePath = path.Join(tempTestDir, "jetdb/dvds/table", "actor.go")
var dbConnection = postgres.DBConnection{
Host: dbconfig.PgHost,
Port: 5432,
User: dbconfig.PgUser,
Password: dbconfig.PgPassword,
DBName: dbconfig.PgDBName,
SchemaName: "dvds",
SslMode: "disable",
}
func TestGeneratorTemplate_Schema_ChangePath(t *testing.T) {
err := postgres.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).UsePath("new/schema/path")
}),
)
require.Nil(t, err)
file2.Exists(t, tempTestDir, "jetdb/new/schema/path/model/actor.go")
file2.Exists(t, tempTestDir, "jetdb/new/schema/path/table/actor.go")
file2.Exists(t, tempTestDir, "jetdb/new/schema/path/view/actor_info.go")
file2.Exists(t, tempTestDir, "jetdb/new/schema/path/enum/mpaa_rating.go")
}
func TestGeneratorTemplate_Model_SkipGeneration(t *testing.T) {
err := postgres.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseModel(template.Model{
Skip: true,
})
}),
)
require.Nil(t, err)
file2.NotExists(t, defaultActorModelFilePath)
}
func TestGeneratorTemplate_SQLBuilder_SkipGeneration(t *testing.T) {
err := postgres.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseSQLBuilder(template.SQLBuilder{
Skip: true,
})
}),
)
require.Nil(t, err)
file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go")
file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go")
file2.NotExists(t, defaultEnumSQLBuilderFilePath, "mpaa_rating.go")
}
func TestGeneratorTemplate_Model_ChangePath(t *testing.T) {
const newModelPath = "/new/model/path"
err := postgres.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseModel(template.DefaultModel().UsePath(newModelPath))
}),
)
require.Nil(t, err)
file2.Exists(t, tempTestDir, "jetdb", "dvds", newModelPath, "actor.go")
file2.NotExists(t, defaultActorModelFilePath)
}
func TestGeneratorTemplate_SQLBuilder_ChangePath(t *testing.T) {
const newModelPath = "/new/sql-builder/path"
err := postgres.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseSQLBuilder(template.DefaultSQLBuilder().UsePath(newModelPath))
}),
)
require.Nil(t, err)
file2.Exists(t, tempTestDir, "jetdb", "dvds", newModelPath, "table", "actor.go")
file2.Exists(t, tempTestDir, "jetdb", "dvds", newModelPath, "view", "actor_info.go")
file2.Exists(t, tempTestDir, "jetdb", "dvds", newModelPath, "enum", "mpaa_rating.go")
file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go")
file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go")
file2.NotExists(t, defaultEnumSQLBuilderFilePath, "mpaa_rating.go")
}
func TestGeneratorTemplate_Model_RenameFilesAndTypes(t *testing.T) {
err := postgres.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseModel(template.DefaultModel().
UseTable(func(table metadata.Table) template.TableModel {
return template.DefaultTableModel(table).
UseFileName(schemaMetaData.Name + "_" + table.Name).
UseTypeName(utils.ToGoIdentifier(table.Name) + "Table")
}).
UseView(func(table metadata.Table) template.ViewModel {
return template.DefaultViewModel(table).
UseFileName(schemaMetaData.Name + "_" + table.Name + "_view").
UseTypeName(utils.ToGoIdentifier(table.Name) + "View")
}).
UseEnum(func(enumMetaData metadata.Enum) template.EnumModel {
return template.DefaultEnumModel(enumMetaData).
UseFileName(enumMetaData.Name + "_enum").
UseTypeName(utils.ToGoIdentifier(enumMetaData.Name) + "Enum")
}),
)
}),
)
require.Nil(t, err)
actor := file2.Exists(t, defaultModelPath, "dvds_actor.go")
require.Contains(t, actor, "type ActorTable struct {")
actorInfo := file2.Exists(t, defaultModelPath, "dvds_actor_info_view.go")
require.Contains(t, actorInfo, "type ActorInfoView struct {")
mpaaRating := file2.Exists(t, defaultModelPath, "mpaa_rating_enum.go")
require.Contains(t, mpaaRating, "type MpaaRatingEnum string")
}
func TestGeneratorTemplate_Model_SkipTableAndEnum(t *testing.T) {
err := postgres.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseModel(template.DefaultModel().
UseTable(func(table metadata.Table) template.TableModel {
return template.TableModel{
Skip: true,
}
}).
UseEnum(func(enumMetaData metadata.Enum) template.EnumModel {
return template.EnumModel{
Skip: true,
}
}),
)
}),
)
require.Nil(t, err)
file2.NotExists(t, defaultModelPath, "actor.go")
file2.Exists(t, defaultModelPath, "actor_info.go")
file2.NotExists(t, defaultModelPath, "mpaa_rating.go")
}
func TestGeneratorTemplate_SQLBuilder_SkipTableAndEnum(t *testing.T) {
err := postgres.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseSQLBuilder(template.DefaultSQLBuilder().
UseTable(func(table metadata.Table) template.TableSQLBuilder {
return template.TableSQLBuilder{
Skip: true,
}
}).
UseView(func(table metadata.Table) template.TableSQLBuilder {
return template.TableSQLBuilder{
Skip: true,
}
}).
UseEnum(func(enumMetaData metadata.Enum) template.EnumSQLBuilder {
return template.EnumSQLBuilder{
Skip: true,
}
}),
)
}),
)
require.Nil(t, err)
file2.NotExists(t, defaultTableSQLBuilderFilePath, "actor.go")
file2.NotExists(t, defaultViewSQLBuilderFilePath, "actor_info.go")
file2.NotExists(t, defaultEnumSQLBuilderFilePath, "mpaa_rating.go")
}
func TestGeneratorTemplate_SQLBuilder_ChangeTypeAndFileName(t *testing.T) {
err := postgres.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseSQLBuilder(template.DefaultSQLBuilder().
UseTable(func(table metadata.Table) template.TableSQLBuilder {
return template.DefaultTableSQLBuilder(table).
UseFileName(schemaMetaData.Name + "_" + table.Name + "_table").
UseTypeName(utils.ToGoIdentifier(table.Name) + "TableSQLBuilder").
UseInstanceName("T_" + utils.ToGoIdentifier(table.Name))
}).
UseView(func(table metadata.Table) template.ViewSQLBuilder {
return template.DefaultViewSQLBuilder(table).
UseFileName(schemaMetaData.Name + "_" + table.Name + "_view").
UseTypeName(utils.ToGoIdentifier(table.Name) + "ViewSQLBuilder").
UseInstanceName("V_" + utils.ToGoIdentifier(table.Name))
}).
UseEnum(func(enum metadata.Enum) template.EnumSQLBuilder {
return template.DefaultEnumSQLBuilder(enum).
UseFileName(schemaMetaData.Name + "_" + enum.Name + "_enum").
UseInstanceName(utils.ToGoIdentifier(enum.Name) + "EnumSQLBuilder")
}),
)
}),
)
require.Nil(t, err)
actor := file2.Exists(t, defaultTableSQLBuilderFilePath, "dvds_actor_table.go")
require.Contains(t, actor, "type ActorTableSQLBuilder struct {")
require.Contains(t, actor, "var T_Actor = newActorTableSQLBuilder(\"dvds\", \"actor\", \"\")")
actorInfo := file2.Exists(t, defaultViewSQLBuilderFilePath, "dvds_actor_info_view.go")
require.Contains(t, actorInfo, "type ActorInfoViewSQLBuilder struct {")
require.Contains(t, actorInfo, "var V_ActorInfo = newActorInfoViewSQLBuilder(\"dvds\", \"actor_info\", \"\")")
mpaaRating := file2.Exists(t, defaultEnumSQLBuilderFilePath, "dvds_mpaa_rating_enum.go")
require.Contains(t, mpaaRating, "var MpaaRatingEnumSQLBuilder = &struct {")
}
func TestGeneratorTemplate_Model_AddTags(t *testing.T) {
err := postgres.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseModel(template.DefaultModel().
UseTable(func(table metadata.Table) template.TableModel {
return template.DefaultTableModel(table).
UseField(func(columnMetaData metadata.Column) template.TableModelField {
defaultTableModelField := template.DefaultTableModelField(columnMetaData)
return defaultTableModelField.UseTags(
fmt.Sprintf(`json:"%s"`, snaker.SnakeToCamel(columnMetaData.Name, false)),
fmt.Sprintf(`xml:"%s"`, columnMetaData.Name),
)
})
}).
UseView(func(table metadata.Table) template.ViewModel {
return template.DefaultViewModel(table).
UseField(func(columnMetaData metadata.Column) template.TableModelField {
defaultTableModelField := template.DefaultTableModelField(columnMetaData)
if table.Name == "actor_info" && columnMetaData.Name == "actor_id" {
return defaultTableModelField.UseTags(`sql:"primary_key"`)
}
return defaultTableModelField
})
}),
)
}),
)
require.Nil(t, err)
actor := file2.Exists(t, defaultActorModelFilePath)
require.Contains(t, actor, "ActorID int32 `sql:\"primary_key\" json:\"actorID\" xml:\"actor_id\"`")
require.Contains(t, actor, "FirstName string `json:\"firstName\" xml:\"first_name\"`")
actorInfo := file2.Exists(t, defaultModelPath, "actor_info.go")
require.Contains(t, actorInfo, "ActorID *int32 `sql:\"primary_key\"`")
}
func TestGeneratorTemplate_Model_ChangeFieldTypes(t *testing.T) {
err := postgres.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseModel(template.DefaultModel().
UseTable(func(table metadata.Table) template.TableModel {
return template.DefaultTableModel(table).
UseField(func(columnMetaData metadata.Column) template.TableModelField {
defaultTableModelField := template.DefaultTableModelField(columnMetaData)
switch defaultTableModelField.Type.Name {
case "*string":
defaultTableModelField.Type = template.NewType(sql.NullString{})
case "*int32":
defaultTableModelField.Type = template.NewType(sql.NullInt32{})
case "*int64":
defaultTableModelField.Type = template.NewType(sql.NullInt64{})
case "*bool":
defaultTableModelField.Type = template.NewType(sql.NullBool{})
case "*float64":
defaultTableModelField.Type = template.NewType(sql.NullFloat64{})
case "*time.Time":
defaultTableModelField.Type = template.NewType(sql.NullTime{})
}
return defaultTableModelField
})
}),
)
}),
)
require.Nil(t, err)
data := file2.Exists(t, defaultModelPath, "film.go")
require.Contains(t, data, "\"database/sql\"")
require.Contains(t, data, "Description sql.NullString")
require.Contains(t, data, "ReleaseYear sql.NullInt32")
require.Contains(t, data, "SpecialFeatures sql.NullString")
}
func TestGeneratorTemplate_SQLBuilder_ChangeColumnTypes(t *testing.T) {
err := postgres.Generate(
tempTestDir,
dbConnection,
template.Default(postgres2.Dialect).
UseSchema(func(schemaMetaData metadata.Schema) template.Schema {
return template.DefaultSchema(schemaMetaData).
UseSQLBuilder(template.DefaultSQLBuilder().
UseTable(func(table metadata.Table) template.TableSQLBuilder {
return template.DefaultTableSQLBuilder(table).
UseColumn(func(column metadata.Column) template.TableSQLBuilderColumn {
defaultColumn := template.DefaultTableSQLBuilderColumn(column)
if defaultColumn.Name == "ActorID" {
defaultColumn.Type = "String"
}
return defaultColumn
})
}),
)
}),
)
require.Nil(t, err)
actor := file2.Exists(t, defaultActorSQLBuilderFilePath)
require.Contains(t, actor, "ActorID postgres.ColumnString")
}

View file

@ -67,14 +67,14 @@ func TestGenerator(t *testing.T) {
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
err := postgres.Generate(genTestDir2, postgres.DBConnection{ err := postgres.Generate(genTestDir2, postgres.DBConnection{
Host: dbconfig.Host, Host: dbconfig.PgHost,
Port: dbconfig.Port, Port: dbconfig.PgPort,
User: dbconfig.User, User: dbconfig.PgUser,
Password: dbconfig.Password, Password: dbconfig.PgPassword,
SslMode: "disable", SslMode: "disable",
Params: "", Params: "",
DBName: dbconfig.DBName, DBName: dbconfig.PgDBName,
SchemaName: "dvds", SchemaName: "dvds",
}) })

View file

@ -32,6 +32,8 @@ func TestMain(m *testing.M) {
setTestRoot() setTestRoot()
for _, driverName := range []string{"postgres", "pgx"} { for _, driverName := range []string{"postgres", "pgx"} {
fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName)
func() { func() {
var err error var err error
db, err = sql.Open(driverName, dbconfig.PostgresConnectString) db, err = sql.Open(driverName, dbconfig.PostgresConnectString)