diff --git a/cmd/jet/main.go b/cmd/jet/main.go index 56b0322..e3335cf 100644 --- a/cmd/jet/main.go +++ b/cmd/jet/main.go @@ -10,6 +10,7 @@ import ( _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" "os" + "strings" ) var ( @@ -28,16 +29,16 @@ var ( ) func init() { - flag.StringVar(&source, "source", postgres.Dialect.Name(), "Database name") + flag.StringVar(&source, "source", "", "Database system name (PostgreSQL or MySQL)") flag.StringVar(&host, "host", "", "Database host path (Example: localhost)") flag.IntVar(&port, "port", 0, "Database port") flag.StringVar(&user, "user", "", "Database user") flag.StringVar(&password, "password", "", "The user’s password") - flag.StringVar(&sslmode, "sslmode", "disable", "Whether or not to use SSL(optional)") flag.StringVar(¶ms, "params", "", "Additional connection string parameters(optional)") - flag.StringVar(&dbName, "dbname", "", "name of the database") - flag.StringVar(&schemaName, "schema", "public", "Database schema name.") + flag.StringVar(&dbName, "dbname", "", "Database name") + flag.StringVar(&schemaName, "schema", "public", `Database schema name. (default "public") (ignored for MySQL)`) + flag.StringVar(&sslmode, "sslmode", "disable", `Whether or not to use SSL(optional)(default "disable") (ignored for MySQL)`) flag.StringVar(&destDir, "path", "", "Destination dir for files generated.") } @@ -46,7 +47,11 @@ func main() { flag.Usage = func() { _, _ = fmt.Fprint(os.Stdout, ` -Usage of jet: +Jet generator 2.0.0 + +Usage: + -source string + Database system name (PostgreSQL or MySQL) -host string Database host path (Example: localhost) -port int @@ -56,13 +61,13 @@ Usage of jet: -password string The user’s password -dbname string - name of the database + Database name -params string Additional connection string parameters(optional) -schema string - Database schema name. (default "public") + Database schema name. (default "public") (ignored for MySQL) -sslmode string - Whether or not to use SSL(optional) (default "disable") + Whether or not to use SSL(optional) (default "disable") (ignored for MySQL) -path string Destination dir for files generated. `) @@ -70,16 +75,14 @@ Usage of jet: flag.Parse() + if source == "" || host == "" || port == 0 || user == "" || dbName == "" { + printErrorAndExit("\nERROR: required flag(s) missing") + } + var err error - switch source { - case postgres.Dialect.Name(): - if host == "" || port == 0 || user == "" || dbName == "" || schemaName == "" { - fmt.Println("\njet: required flag missing") - flag.Usage() - os.Exit(-2) - } - + switch strings.ToLower(strings.TrimSpace(source)) { + case strings.ToLower(postgres.Dialect.Name()): genData := postgresgen.DBConnection{ Host: host, Port: port, @@ -94,12 +97,7 @@ Usage of jet: err = postgresgen.Generate(destDir, genData) - case mysql.Dialect.Name(): - if host == "" || port == 0 || user == "" || dbName == "" { - fmt.Println("\njet: required flag missing") - flag.Usage() - os.Exit(-2) - } + case strings.ToLower(mysql.Dialect.Name()): dbConn := mysqlgen.DBConnection{ Host: host, @@ -112,10 +110,19 @@ Usage of jet: } err = mysqlgen.Generate(destDir, dbConn) + default: + fmt.Println("ERROR: unsupported source " + source + ". " + postgres.Dialect.Name() + " and " + mysql.Dialect.Name() + " are currently supported.") + os.Exit(-4) } if err != nil { fmt.Println(err.Error()) - os.Exit(-1) + os.Exit(-5) } } + +func printErrorAndExit(error string) { + fmt.Println(error) + flag.Usage() + os.Exit(-2) +} diff --git a/generator/internal/metadata/column_info.go b/generator/internal/metadata/column_meta_data.go similarity index 64% rename from generator/internal/metadata/column_info.go rename to generator/internal/metadata/column_meta_data.go index 988355c..7bc169e 100644 --- a/generator/internal/metadata/column_info.go +++ b/generator/internal/metadata/column_meta_data.go @@ -7,17 +7,37 @@ import ( "strings" ) -// ColumnInfo metadata struct -type ColumnInfo struct { +// ColumnMetaData struct +type ColumnMetaData struct { Name string IsNullable bool DataType string - IsUnsigned bool EnumName string + IsUnsigned bool + + SqlBuilderColumnType string + GoBaseType string + GoModelType string } -// SqlBuilderColumnType returns type of jet sql builder column -func (c ColumnInfo) SqlBuilderColumnType() string { +func NewColumnMetaData(name string, isNullable bool, dataType string, enumName string, isUnsigned bool) ColumnMetaData { + columnMetaData := ColumnMetaData{ + Name: name, + IsNullable: isNullable, + DataType: dataType, + EnumName: enumName, + IsUnsigned: isUnsigned, + } + + columnMetaData.SqlBuilderColumnType = columnMetaData.getSqlBuilderColumnType() + columnMetaData.GoBaseType = columnMetaData.getGoBaseType() + columnMetaData.GoModelType = columnMetaData.getGoModelType() + + return columnMetaData +} + +// getSqlBuilderColumnType returns type of jet sql builder column +func (c ColumnMetaData) getSqlBuilderColumnType() string { switch c.DataType { case "boolean": return "Bool" @@ -45,13 +65,13 @@ func (c ColumnInfo) SqlBuilderColumnType() string { "double": // MySQL return "Float" default: - fmt.Println("Unsupported sql type: " + c.DataType + ", using string column instead for sql builder.") + fmt.Println("- [SQL Builder] Unsupported sql column '" + c.Name + " " + c.DataType + "', using StringColumn instead.") return "String" } } -// GoBaseType returns model type for column info. -func (c ColumnInfo) GoBaseType() string { +// getGoBaseType returns model type for column info. +func (c ColumnMetaData) getGoBaseType() string { switch c.DataType { case "USER-DEFINED", "enum": return utils.ToGoIdentifier(c.EnumName) @@ -85,15 +105,15 @@ func (c ColumnInfo) GoBaseType() string { case "uuid": return "uuid.UUID" default: - fmt.Println("Unsupported sql type: " + c.DataType + ", using string instead for model type.") + fmt.Println("- [Model ] Unsupported sql column '" + c.Name + " " + c.DataType + "', using string instead.") return "string" } } // GoModelType returns model type for column info with optional pointer if // column can be NULL. -func (c ColumnInfo) GoModelType() string { - typeStr := c.GoBaseType() +func (c ColumnMetaData) getGoModelType() string { + typeStr := c.GoBaseType if strings.Contains(typeStr, "int") && c.IsUnsigned { typeStr = "u" + typeStr @@ -107,7 +127,7 @@ func (c ColumnInfo) GoModelType() string { } // GoModelTag returns model field tag for column -func (c ColumnInfo) GoModelTag(isPrimaryKey bool) string { +func (c ColumnMetaData) GoModelTag(isPrimaryKey bool) string { tags := []string{} if isPrimaryKey { @@ -121,7 +141,7 @@ func (c ColumnInfo) GoModelTag(isPrimaryKey bool) string { return "" } -func getColumnInfos(db *sql.DB, querySet MetaDataQuerySet, schemaName, tableName string) ([]ColumnInfo, error) { +func getColumnsMetaData(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) ([]ColumnMetaData, error) { rows, err := db.Query(querySet.ListOfColumnsQuery(), schemaName, tableName) @@ -130,20 +150,18 @@ func getColumnInfos(db *sql.DB, querySet MetaDataQuerySet, schemaName, tableName } defer rows.Close() - ret := []ColumnInfo{} + ret := []ColumnMetaData{} for rows.Next() { - columnInfo := ColumnInfo{} - var isNullable string - err := rows.Scan(&columnInfo.Name, &isNullable, &columnInfo.DataType, &columnInfo.EnumName, &columnInfo.IsUnsigned) - - columnInfo.IsNullable = isNullable == "YES" + var name, isNullable, dataType, enumName string + var isUnsigned bool + err := rows.Scan(&name, &isNullable, &dataType, &enumName, &isUnsigned) if err != nil { return nil, err } - ret = append(ret, columnInfo) + ret = append(ret, NewColumnMetaData(name, isNullable == "YES", dataType, enumName, isUnsigned)) } err = rows.Err() diff --git a/generator/internal/metadata/query.go b/generator/internal/metadata/dialect_query_set.go similarity index 97% rename from generator/internal/metadata/query.go rename to generator/internal/metadata/dialect_query_set.go index 0c97296..96867be 100644 --- a/generator/internal/metadata/query.go +++ b/generator/internal/metadata/dialect_query_set.go @@ -5,7 +5,7 @@ import ( "strings" ) -type MetaDataQuerySet interface { +type DialectQuerySet interface { ListOfTablesQuery() string PrimaryKeysQuery() string ListOfColumnsQuery() string @@ -123,7 +123,7 @@ func (m *MySqlQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) ([]MetaD enumValues = strings.Replace(enumValues[1:len(enumValues)-1], "'", "", -1) - ret = append(ret, EnumInfo{ + ret = append(ret, EnumMetaData{ name: enumName, Values: strings.Split(enumValues, ","), }) diff --git a/generator/internal/metadata/enum_info.go b/generator/internal/metadata/enum_meta_data.go similarity index 77% rename from generator/internal/metadata/enum_info.go rename to generator/internal/metadata/enum_meta_data.go index 21881be..806325f 100644 --- a/generator/internal/metadata/enum_info.go +++ b/generator/internal/metadata/enum_meta_data.go @@ -4,18 +4,18 @@ import ( "database/sql" ) -// EnumInfo struct -type EnumInfo struct { +// EnumMetaData struct +type EnumMetaData struct { name string Values []string } // Name returns enum name -func (e EnumInfo) Name() string { +func (e EnumMetaData) Name() string { return e.name } -func getEnumInfos(db *sql.DB, querySet MetaDataQuerySet, schemaName string) ([]MetaData, error) { +func getEnumInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]MetaData, error) { rows, err := db.Query(querySet.ListOfEnumsQuery(), schemaName) @@ -49,7 +49,7 @@ func getEnumInfos(db *sql.DB, querySet MetaDataQuerySet, schemaName string) ([]M ret := []MetaData{} for enumName, enumValues := range enumsInfosMap { - ret = append(ret, EnumInfo{ + ret = append(ret, EnumMetaData{ enumName, enumValues, }) diff --git a/generator/internal/metadata/schema_info.go b/generator/internal/metadata/schema_meta_data.go similarity index 78% rename from generator/internal/metadata/schema_info.go rename to generator/internal/metadata/schema_meta_data.go index 6bec5ec..0a05d6a 100644 --- a/generator/internal/metadata/schema_info.go +++ b/generator/internal/metadata/schema_meta_data.go @@ -5,14 +5,14 @@ import ( "fmt" ) -// SchemaInfo metadata struct -type SchemaInfo struct { +// SchemaMetaData struct +type SchemaMetaData struct { TableInfos []MetaData EnumInfos []MetaData } // GetSchemaInfo returns schema information from db connection. -func GetSchemaInfo(db *sql.DB, schemaName string, querySet MetaDataQuerySet) (schemaInfo SchemaInfo, err error) { +func GetSchemaInfo(db *sql.DB, schemaName string, querySet DialectQuerySet) (schemaInfo SchemaMetaData, err error) { schemaInfo.TableInfos, err = getTableInfos(db, querySet, schemaName) @@ -31,7 +31,7 @@ func GetSchemaInfo(db *sql.DB, schemaName string, querySet MetaDataQuerySet) (sc return } -func getTableInfos(db *sql.DB, querySet MetaDataQuerySet, schemaName string) ([]MetaData, error) { +func getTableInfos(db *sql.DB, querySet DialectQuerySet, schemaName string) ([]MetaData, error) { rows, err := db.Query(querySet.ListOfTablesQuery(), schemaName) diff --git a/generator/internal/metadata/table_info.go b/generator/internal/metadata/table_meta_data.go similarity index 68% rename from generator/internal/metadata/table_info.go rename to generator/internal/metadata/table_meta_data.go index 9d43d4a..c2f4b23 100644 --- a/generator/internal/metadata/table_info.go +++ b/generator/internal/metadata/table_meta_data.go @@ -5,27 +5,27 @@ import ( "github.com/go-jet/jet/internal/utils" ) -// TableInfo metadata struct -type TableInfo struct { +// TableMetaData metadata struct +type TableMetaData struct { SchemaName string name string PrimaryKeys map[string]bool - Columns []ColumnInfo + Columns []ColumnMetaData } // Name returns table info name -func (t TableInfo) Name() string { +func (t TableMetaData) Name() string { return t.name } // IsPrimaryKey returns if column is a part of primary key -func (t TableInfo) IsPrimaryKey(column string) bool { +func (t TableMetaData) IsPrimaryKey(column string) bool { return t.PrimaryKeys[column] } // MutableColumns returns list of mutable columns for table -func (t TableInfo) MutableColumns() []ColumnInfo { - ret := []ColumnInfo{} +func (t TableMetaData) MutableColumns() []ColumnMetaData { + ret := []ColumnMetaData{} for _, column := range t.Columns { if t.IsPrimaryKey(column.Name) { @@ -39,11 +39,11 @@ func (t TableInfo) MutableColumns() []ColumnInfo { } // GetImports returns model imports for table. -func (t TableInfo) GetImports() []string { +func (t TableMetaData) GetImports() []string { imports := map[string]string{} for _, column := range t.Columns { - columnType := column.GoBaseType() + columnType := column.GoBaseType switch columnType { case "time.Time": @@ -63,12 +63,12 @@ func (t TableInfo) GetImports() []string { } // GoStructName returns go struct name for sql builder -func (t TableInfo) GoStructName() string { +func (t TableMetaData) GoStructName() string { return utils.ToGoIdentifier(t.name) + "Table" } // GetTableInfo returns table info metadata -func GetTableInfo(db *sql.DB, querySet MetaDataQuerySet, schemaName, tableName string) (tableInfo TableInfo, err error) { +func GetTableInfo(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (tableInfo TableMetaData, err error) { tableInfo.SchemaName = schemaName tableInfo.name = tableName @@ -78,7 +78,7 @@ func GetTableInfo(db *sql.DB, querySet MetaDataQuerySet, schemaName, tableName s return } - tableInfo.Columns, err = getColumnInfos(db, querySet, schemaName, tableName) + tableInfo.Columns, err = getColumnsMetaData(db, querySet, schemaName, tableName) if err != nil { return @@ -87,7 +87,7 @@ func GetTableInfo(db *sql.DB, querySet MetaDataQuerySet, schemaName, tableName s return } -func getPrimaryKeys(db *sql.DB, querySet MetaDataQuerySet, schemaName, tableName string) (map[string]bool, error) { +func getPrimaryKeys(db *sql.DB, querySet DialectQuerySet, schemaName, tableName string) (map[string]bool, error) { rows, err := db.Query(querySet.PrimaryKeysQuery(), schemaName, tableName) diff --git a/generator/mysql/mysql_generator.go b/generator/mysql/mysql_generator.go index a88b7da..a893e3c 100644 --- a/generator/mysql/mysql_generator.go +++ b/generator/mysql/mysql_generator.go @@ -50,6 +50,9 @@ func Generate(destDir string, dbConn DBConnection) error { func openConnection(dbConn DBConnection) (*sql.DB, error) { var connectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName) + if dbConn.Params != "" { + connectionString += "?" + dbConn.Params + } db, err := sql.Open("mysql", connectionString) fmt.Println("Connecting to MySQL database: " + connectionString) diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index f09c324..d73a850 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -55,8 +55,10 @@ func TestCmdGenerator(t *testing.T) { err = os.RemoveAll(genTestDir2) assert.NilError(t, err) - cmd := exec.Command("jet", "-dbname=jetdb", "-host=localhost", "-port=5432", + cmd := exec.Command("jet", "-source=PostgreSQL", "-dbname=jetdb", "-host=localhost", "-port=5432", "-user=jet", "-password=jet", "-schema=dvds", "-path="+genTestDir2) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout err = cmd.Run() assert.NilError(t, err)