diff --git a/cmd/jet/main.go b/cmd/jet/main.go index 136e58b..9192eb6 100644 --- a/cmd/jet/main.go +++ b/cmd/jet/main.go @@ -3,19 +3,22 @@ package main import ( "flag" "fmt" + "os" + "regexp" + "strings" + mysqlgen "github.com/go-jet/jet/v2/generator/mysql" postgresgen "github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/postgres" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" - "os" - "strings" ) var ( source string + dsn string host string port int user string @@ -31,6 +34,7 @@ var ( func init() { flag.StringVar(&source, "source", "", "Database system name (PostgreSQL, MySQL or MariaDB)") + flag.StringVar(&dsn, "dsn", "", "Data source name connection string (Example: postgresql://user@localhost:5432/otherdb?sslmode=trust)") flag.StringVar(&host, "host", "", "Database host path (Example: localhost)") flag.IntVar(&port, "port", 0, "Database port") flag.StringVar(&user, "user", "", "Database user") @@ -50,6 +54,12 @@ func main() { Jet generator 2.5.0 Usage: + -dsn string + Data source name. Unified format for connecting to database. + PostgreSQL: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING + Example: postgresql://user:pass@localhost:5432/dbname + MySQL: https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html + Example: mysql://jet:jet@tcp(localhost:3306)/dvds -source string Database system name (PostgreSQL, MySQL or MariaDB) -host string @@ -75,8 +85,21 @@ Usage: flag.Parse() - if source == "" || host == "" || port == 0 || user == "" || dbName == "" { - printErrorAndExit("\nERROR: required flag(s) missing") + if dsn == "" { + // validations for separated connection flags. + if source == "" || host == "" || port == 0 || user == "" || dbName == "" { + printErrorAndExit("\nERROR: required flag(s) missing") + } + } else { + if source == "" { + // try to get source from schema + source = detectSchema(dsn) + } + + // validations when dsn != "" + if source == "" { + printErrorAndExit("\nERROR: required -source flag missing.") + } } var err error @@ -84,6 +107,10 @@ Usage: switch strings.ToLower(strings.TrimSpace(source)) { case strings.ToLower(postgres.Dialect.Name()), strings.ToLower(postgres.Dialect.PackageName()): + if dsn != "" { + err = postgresgen.GenerateDSN(dsn, schemaName, destDir) + break + } genData := postgresgen.DBConnection{ Host: host, Port: port, @@ -98,8 +125,19 @@ Usage: err = postgresgen.Generate(destDir, genData) - case strings.ToLower(mysql.Dialect.Name()), "mariadb": + case strings.ToLower(mysql.Dialect.Name()), "mysqlx", "mariadb": + if dsn != "" { + // Special case for go mysql driver. It does not understand schema, + // so we need to trim it before passing to generator + // https://github.com/go-sql-driver/mysql#dsn-data-source-name + idx := strings.Index(dsn, "://") + if idx != -1 { + dsn = dsn[idx+len("://"):] + } + err = mysqlgen.GenerateDSN(dsn, destDir) + break + } dbConn := mysqlgen.DBConnection{ Host: host, Port: port, @@ -126,3 +164,12 @@ func printErrorAndExit(error string) { flag.Usage() os.Exit(-2) } + +func detectSchema(dsn string) (source string) { + schemeRe := regexp.MustCompile(`^(.+)://.*`) + match := schemeRe.FindStringSubmatch(dsn) + if len(match) < 2 { // not found + return "" + } + return match[1] +} diff --git a/generator/mysql/mysql_generator.go b/generator/mysql/mysql_generator.go index ab00822..6de530b 100644 --- a/generator/mysql/mysql_generator.go +++ b/generator/mysql/mysql_generator.go @@ -3,11 +3,13 @@ package mysql import ( "database/sql" "fmt" + "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/mysql" + mysqldr "github.com/go-sql-driver/mysql" ) // DBConnection contains MySQL connection details @@ -25,28 +27,38 @@ type DBConnection struct { func Generate(destDir string, dbConn DBConnection, generatorTemplate ...template.Template) (err error) { defer utils.ErrorCatch(&err) - db := openConnection(dbConn) - defer utils.DBClose(db) - - fmt.Println("Retrieving database information...") - // No schemas in MySQL - schemaMetaData := metadata.GetSchema(db, &mySqlQuerySet{}, dbConn.DBName) - - genTemplate := template.Default(mysql.Dialect) - if len(generatorTemplate) > 0 { - genTemplate = generatorTemplate[0] + connectionString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName) + if dbConn.Params != "" { + connectionString += "?" + dbConn.Params } - template.ProcessSchema(destDir, schemaMetaData, genTemplate) + db := openConnection(connectionString) + defer utils.DBClose(db) + + generate(db, dbConn.DBName, destDir, generatorTemplate...) return nil } -func openConnection(dbConn DBConnection) *sql.DB { - var connectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", dbConn.User, dbConn.Password, dbConn.Host, dbConn.Port, dbConn.DBName) - if dbConn.Params != "" { - connectionString += "?" + dbConn.Params +// GenerateDSN opens connection via DSN string and does everything what Generate does. +func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error) { + defer utils.ErrorCatch(&err) + + cfg, err := mysqldr.ParseDSN(dsn) + throw.OnError(err) + if cfg.DBName == "" { + panic("database name is required") } + + db := openConnection(dsn) + defer utils.DBClose(db) + + generate(db, cfg.DBName, destDir, templates...) + + return nil +} + +func openConnection(connectionString string) *sql.DB { fmt.Println("Connecting to MySQL database: " + connectionString) db, err := sql.Open("mysql", connectionString) throw.OnError(err) @@ -56,3 +68,16 @@ func openConnection(dbConn DBConnection) *sql.DB { return db } + +func generate(db *sql.DB, dbName, destDir string, templates ...template.Template) { + fmt.Println("Retrieving database information...") + // No schemas in MySQL + schemaMetaData := metadata.GetSchema(db, &mySqlQuerySet{}, dbName) + + genTemplate := template.Default(mysql.Dialect) + if len(templates) > 0 { + genTemplate = templates[0] + } + + template.ProcessSchema(destDir, schemaMetaData, genTemplate) +} diff --git a/generator/postgres/postgres_generator.go b/generator/postgres/postgres_generator.go index ebb5420..b3dd4a6 100644 --- a/generator/postgres/postgres_generator.go +++ b/generator/postgres/postgres_generator.go @@ -3,13 +3,15 @@ package postgres import ( "database/sql" "fmt" + "path" + "strconv" + "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/template" "github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/postgres" - "path" - "strconv" + "github.com/jackc/pgconn" ) // DBConnection contains postgres connection details @@ -29,32 +31,37 @@ type DBConnection struct { func Generate(destDir string, dbConn DBConnection, genTemplate ...template.Template) (err error) { defer utils.ErrorCatch(&err) - db := openConnection(dbConn) + connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s", + dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params) + + db := openConnection(connectionString) defer utils.DBClose(db) - fmt.Println("Retrieving schema information...") - - generatorTemplate := template.Default(postgres.Dialect) - if len(genTemplate) > 0 { - generatorTemplate = genTemplate[0] - } - - schemaMetadata := metadata.GetSchema(db, &postgresQuerySet{}, dbConn.SchemaName) - - dirPath := path.Join(destDir, dbConn.DBName) - - template.ProcessSchema(dirPath, schemaMetadata, generatorTemplate) + generate(db, dbConn.DBName, dbConn.SchemaName, destDir, genTemplate...) return } -func openConnection(dbConn DBConnection) *sql.DB { - connectionString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s %s", - dbConn.Host, strconv.Itoa(dbConn.Port), dbConn.User, dbConn.Password, dbConn.DBName, dbConn.SslMode, dbConn.Params) +func GenerateDSN(dsn, schema, destDir string, templates ...template.Template) (err error) { + defer utils.ErrorCatch(&err) - fmt.Println("Connecting to postgres database: " + connectionString) + cfg, err := pgconn.ParseConfig(dsn) + throw.OnError(err) + if cfg.Database == "" { + panic("database name is required") + } + db := openConnection(dsn) + defer utils.DBClose(db) - db, err := sql.Open("postgres", connectionString) + generate(db, cfg.Database, schema, destDir, templates...) + + return +} + +func openConnection(dsn string) *sql.DB { + fmt.Println("Connecting to postgres database: " + dsn) + + db, err := sql.Open("postgres", dsn) throw.OnError(err) err = db.Ping() @@ -62,3 +69,17 @@ func openConnection(dbConn DBConnection) *sql.DB { return db } + +func generate(db *sql.DB, dbName, schema, destDir string, templates ...template.Template) { + fmt.Println("Retrieving schema information...") + generatorTemplate := template.Default(postgres.Dialect) + if len(templates) > 0 { + generatorTemplate = templates[0] + } + + schemaMetadata := metadata.GetSchema(db, &postgresQuerySet{}, schema) + + dirPath := path.Join(destDir, dbName) + + template.ProcessSchema(dirPath, schemaMetadata, generatorTemplate) +} diff --git a/go.mod b/go.mod index 9dd3e02..c349db9 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/go-sql-driver/mysql v1.5.0 github.com/google/go-cmp v0.5.0 //tests github.com/google/uuid v1.1.1 + github.com/jackc/pgconn v1.8.1 github.com/jackc/pgx/v4 v4.11.0 //tests github.com/lib/pq v1.7.0 github.com/pkg/profile v1.5.0 //tests diff --git a/go.sum b/go.sum index 47af479..26a2d4a 100644 --- a/go.sum +++ b/go.sum @@ -42,7 +42,6 @@ github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7 github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -457,7 +456,6 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac google.golang.org/grpc v1.23.1/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.26.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/tests/init/init.go b/tests/init/init.go index a28ee19..3bd7e64 100644 --- a/tests/init/init.go +++ b/tests/init/init.go @@ -4,16 +4,17 @@ import ( "database/sql" "flag" "fmt" + "io/ioutil" + "os" + "os/exec" + "strings" + "github.com/go-jet/jet/v2/generator/mysql" "github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/tests/dbconfig" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" - "io/ioutil" - "os" - "os/exec" - "strings" ) var testSuite string diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index c9dcc1a..033f699 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -1,14 +1,16 @@ package mysql import ( - "github.com/go-jet/jet/v2/generator/mysql" - "github.com/go-jet/jet/v2/internal/testutils" - "github.com/go-jet/jet/v2/tests/dbconfig" - "github.com/stretchr/testify/require" + "fmt" "io/ioutil" "os" "os/exec" "testing" + + "github.com/go-jet/jet/v2/generator/mysql" + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/tests/dbconfig" + "github.com/stretchr/testify/require" ) const genTestDirRoot = "./.gentestdata3" @@ -30,6 +32,21 @@ func TestGenerator(t *testing.T) { assertGeneratedFiles(t) } + for i := 0; i < 3; i++ { + dsn := fmt.Sprintf("%[1]s:%[2]s@tcp(%[3]s:%[4]d)/%[5]s", + dbconfig.MySQLUser, + dbconfig.MySQLPassword, + dbconfig.MySqLHost, + dbconfig.MySQLPort, + "dvds", + ) + err := mysql.GenerateDSN(dsn, genTestDir3) + + require.NoError(t, err) + + assertGeneratedFiles(t) + } + err := os.RemoveAll(genTestDirRoot) require.NoError(t, err) } @@ -51,6 +68,25 @@ func TestCmdGenerator(t *testing.T) { err = os.RemoveAll(genTestDirRoot) require.NoError(t, err) + + // check that generation via DSN works + dsn := fmt.Sprintf("mysql://%[1]s:%[2]s@tcp(%[3]s:%[4]d)/%[5]s", + dbconfig.MySQLUser, + dbconfig.MySQLPassword, + dbconfig.MySqLHost, + dbconfig.MySQLPort, + "dvds", + ) + cmd = exec.Command("jet", "-dsn="+dsn, "-path="+genTestDir3) + + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + + err = cmd.Run() + require.NoError(t, err) + + err = os.RemoveAll(genTestDirRoot) + require.NoError(t, err) } func assertGeneratedFiles(t *testing.T) { diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index d1f8a52..b1b733e 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -1,16 +1,18 @@ package postgres import ( - "github.com/go-jet/jet/v2/generator/postgres" - "github.com/go-jet/jet/v2/internal/testutils" - "github.com/go-jet/jet/v2/tests/dbconfig" - "github.com/stretchr/testify/require" + "fmt" "io/ioutil" "os" "os/exec" "reflect" "testing" + "github.com/go-jet/jet/v2/generator/postgres" + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/tests/dbconfig" + "github.com/stretchr/testify/require" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" ) @@ -61,6 +63,26 @@ func TestCmdGenerator(t *testing.T) { err = os.RemoveAll(genTestDir2) require.NoError(t, err) + + // Check that connection via DSN works + dsn := fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=disable", + dbconfig.PgUser, + dbconfig.PgPassword, + dbconfig.PgHost, + dbconfig.PgPort, + "jetdb", + ) + cmd = exec.Command("jet", "-dsn="+dsn, "-schema=dvds", "-path="+genTestDir2) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + + err = cmd.Run() + require.NoError(t, err) + + assertGeneratedFiles(t) + + err = os.RemoveAll(genTestDir2) + require.NoError(t, err) } func TestGenerator(t *testing.T) { @@ -83,6 +105,21 @@ func TestGenerator(t *testing.T) { assertGeneratedFiles(t) } + for i := 0; i < 3; i++ { + dsn := fmt.Sprintf("postgresql://%[1]s:%[2]s@%[3]s:%[4]d/%[5]s?sslmode=disable", + dbconfig.PgUser, + dbconfig.PgPassword, + dbconfig.PgHost, + dbconfig.PgPort, + dbconfig.PgDBName, + ) + err := postgres.GenerateDSN(dsn, "dvds", genTestDir2) + + require.NoError(t, err) + + assertGeneratedFiles(t) + } + err := os.RemoveAll(genTestDir2) require.NoError(t, err) }