From 51cad22809bb50f7181452501de8be07ae6de90e Mon Sep 17 00:00:00 2001 From: go-jet Date: Thu, 21 Oct 2021 13:21:01 +0200 Subject: [PATCH] Add jet generator support for SQLite --- cmd/jet/main.go | 43 +++++++----- generator/metadata/dialect_query_set.go | 9 +-- generator/sqlite/query_set.go | 80 ++++++++++++++++++++++ generator/sqlite/sqlite_generator.go | 32 +++++++++ generator/template/file_templates.go | 2 +- generator/template/model_template.go | 8 +-- generator/template/process.go | 4 +- generator/template/sql_builder_template.go | 7 +- 8 files changed, 154 insertions(+), 31 deletions(-) create mode 100644 generator/sqlite/query_set.go create mode 100644 generator/sqlite/sqlite_generator.go diff --git a/cmd/jet/main.go b/cmd/jet/main.go index b75f58b..9409f45 100644 --- a/cmd/jet/main.go +++ b/cmd/jet/main.go @@ -3,6 +3,7 @@ package main import ( "flag" "fmt" + sqlitegen "github.com/go-jet/jet/v2/generator/sqlite" "os" "strings" @@ -12,6 +13,7 @@ import ( "github.com/go-jet/jet/v2/postgres" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" ) var ( @@ -31,7 +33,7 @@ var ( ) func init() { - flag.StringVar(&source, "source", "", "Database system name (PostgreSQL, MySQL or MariaDB)") + flag.StringVar(&source, "source", "", "Database system name (PostgreSQL, MySQL, MariaDB or SQLite)") flag.StringVar(&dsn, "dsn", "", "Data source name connection string (Example: postgresql://user@localhost:5432/otherdb?sslmode=trust)") flag.StringVar(&host, "host", "", "Database host path (Example: localhost)") @@ -50,7 +52,7 @@ func main() { flag.Usage = func() { _, _ = fmt.Fprint(os.Stdout, ` -Jet generator 2.5.0 +Jet generator 2.6.0 Usage: -dsn string @@ -61,8 +63,11 @@ Usage: MySQL: https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html Example: mysql://jet:jet@tcp(localhost:3306)/dvds + SQLite: https://www.sqlite.org/c3ref/open.html#urifilenameexamples + Example: + file://path/to/database/file -source string - Database system name (PostgreSQL, MySQL or MariaDB) + Database system name (PostgreSQL, MySQL, MariaDB or SQLite) -host string Database host path (Example: localhost) -port int @@ -76,17 +81,18 @@ Usage: -params string Additional connection string parameters(optional) -schema string - Database schema name. (default "public") (ignored for MySQL and MariaDB) + Database schema name. (default "public") (ignored for MySQL, MariaDB and SQLite) -sslmode string - Whether or not to use SSL(optional) (default "disable") (ignored for MySQL and MariaDB) + Whether or not to use SSL(optional) (default "disable") (ignored for MySQL, MariaDB and SQLite) -path string Destination dir for files generated. Example commands: - $ jet -source=PostgreSQL -dbname=jetdb -host=localhost -port=5432 -user=jet -password=jet -schema=dvds - $ jet -dsn=postgresql://jet:jet@localhost:5432/jetdb -schema=dvds - $ jet -source=postgres -dsn="user=jet password=jet host=localhost port=5432 dbname=jetdb" -schema=dvds + $ jet -source=PostgreSQL -dbname=jetdb -host=localhost -port=5432 -user=jet -password=jet -schema=./dvds + $ jet -dsn=postgresql://jet:jet@localhost:5432/jetdb -schema=./dvds + $ jet -source=postgres -dsn="user=jet password=jet host=localhost port=5432 dbname=jetdb" -schema=./dvds + $ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -schema=./dvds `) } @@ -95,7 +101,7 @@ Example commands: if dsn == "" { // validations for separated connection flags. if source == "" || host == "" || port == 0 || user == "" || dbName == "" { - printErrorAndExit("\nERROR: required flag(s) missing") + printErrorAndExit("ERROR: required flag(s) missing") } } else { if source == "" { @@ -105,15 +111,14 @@ Example commands: // validations when dsn != "" if source == "" { - printErrorAndExit("\nERROR: required -source flag missing.") + printErrorAndExit("ERROR: required -source flag missing.") } } var err error switch strings.ToLower(strings.TrimSpace(source)) { - case strings.ToLower(postgres.Dialect.Name()), - strings.ToLower(postgres.Dialect.PackageName()): + case "postgresql", "postgres": if dsn != "" { err = postgresgen.GenerateDSN(dsn, schemaName, destDir) break @@ -132,7 +137,7 @@ Example commands: err = postgresgen.Generate(destDir, genData) - case strings.ToLower(mysql.Dialect.Name()), "mysqlx", "mariadb": + case "mysql", "mysqlx", "mariadb": if dsn != "" { err = mysqlgen.GenerateDSN(dsn, destDir) break @@ -147,9 +152,13 @@ Example commands: } err = mysqlgen.Generate(destDir, dbConn) + case "sqlite": + if dsn == "" { + printErrorAndExit("ERROR: required -dsn flag missing.") + } + err = sqlitegen.GenerateDSN(dsn, destDir) default: - fmt.Println("ERROR: unsupported source " + source + ". " + postgres.Dialect.Name() + " and " + mysql.Dialect.Name() + " are currently supported.") - os.Exit(-4) + printErrorAndExit("ERROR: unsupported source " + source + ". " + postgres.Dialect.Name() + " and " + mysql.Dialect.Name() + " are currently supported.") } if err != nil { @@ -159,12 +168,12 @@ Example commands: } func printErrorAndExit(error string) { - fmt.Println(error) + fmt.Println("\n", error) flag.Usage() os.Exit(-2) } -func detectSchema(dsn string) (source string) { +func detectSchema(dsn string) string { match := strings.SplitN(dsn, "://", 2) if len(match) < 2 { // not found return "" diff --git a/generator/metadata/dialect_query_set.go b/generator/metadata/dialect_query_set.go index 036e4d5..66a32a6 100644 --- a/generator/metadata/dialect_query_set.go +++ b/generator/metadata/dialect_query_set.go @@ -8,9 +8,10 @@ import ( // TableType is type of database table(view or base) type TableType string +// SQL table types const ( - baseTable TableType = "BASE TABLE" - viewTable TableType = "VIEW" + BaseTable TableType = "BASE TABLE" + ViewTable TableType = "VIEW" ) // DialectQuerySet is set of methods necessary to retrieve dialect meta data information @@ -23,8 +24,8 @@ type DialectQuerySet interface { func GetSchema(db *sql.DB, querySet DialectQuerySet, schemaName string) Schema { ret := Schema{ Name: schemaName, - TablesMetaData: querySet.GetTablesMetaData(db, schemaName, baseTable), - ViewsMetaData: querySet.GetTablesMetaData(db, schemaName, viewTable), + TablesMetaData: querySet.GetTablesMetaData(db, schemaName, BaseTable), + ViewsMetaData: querySet.GetTablesMetaData(db, schemaName, ViewTable), EnumsMetaData: querySet.GetEnumsMetaData(db, schemaName), } diff --git a/generator/sqlite/query_set.go b/generator/sqlite/query_set.go new file mode 100644 index 0000000..e1d5e4d --- /dev/null +++ b/generator/sqlite/query_set.go @@ -0,0 +1,80 @@ +package sqlite + +import ( + "context" + "database/sql" + "fmt" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/internal/utils/throw" + "github.com/go-jet/jet/v2/qrm" + "strings" +) + +// sqliteQuerySet is dialect query set for SQLite +type sqliteQuerySet struct{} + +func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table { + query := ` + SELECT name as "table.name" + FROM sqlite_master + WHERE type=? AND name != 'sqlite_sequence' + ORDER BY name; +` + sqlTableType := "table" + + if tableType == metadata.ViewTable { + sqlTableType = "view" + } + + var tables []metadata.Table + + err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables) + throw.OnError(err) + + for i := range tables { + tables[i].Columns = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name) + } + + return tables +} + +func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column { + query := fmt.Sprintf(`select * from pragma_table_info(?);`) + var columnInfos []struct { + Name string + Type string + NotNull int32 + Pk int32 + } + + err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos) + throw.OnError(err) + + var columns []metadata.Column + + for _, columnInfo := range columnInfos { + columnType := getColumnType(columnInfo.Type) + + columns = append(columns, metadata.Column{ + Name: columnInfo.Name, + IsPrimaryKey: columnInfo.Pk != 0, + IsNullable: columnInfo.NotNull != 1, + DataType: metadata.DataType{ + Name: columnType, + Kind: metadata.BaseType, + IsUnsigned: false, + }, + }) + } + + return columns +} + +// will convert VARCHAR(10) -> VARCHAR, etc... +func getColumnType(columnType string) string { + return strings.TrimSpace(strings.Split(columnType, "(")[0]) +} + +func (p sqliteQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum { + return nil +} diff --git a/generator/sqlite/sqlite_generator.go b/generator/sqlite/sqlite_generator.go new file mode 100644 index 0000000..7887394 --- /dev/null +++ b/generator/sqlite/sqlite_generator.go @@ -0,0 +1,32 @@ +package sqlite + +import ( + "database/sql" + "fmt" + "github.com/go-jet/jet/v2/generator/metadata" + "github.com/go-jet/jet/v2/generator/template" + "github.com/go-jet/jet/v2/internal/utils" + "github.com/go-jet/jet/v2/internal/utils/throw" + "github.com/go-jet/jet/v2/sqlite" +) + +// GenerateDSN generates jet files using dsn connection string +func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error) { + defer utils.ErrorCatch(&err) + + db, err := sql.Open("sqlite3", dsn) + throw.OnError(err) + defer utils.DBClose(db) + + fmt.Println("Retrieving schema information...") + + generatorTemplate := template.Default(sqlite.Dialect) + if len(templates) > 0 { + generatorTemplate = templates[0] + } + + schemaMetadata := metadata.GetSchema(db, &sqliteQuerySet{}, "") + + template.ProcessSchema(destDir, schemaMetadata, generatorTemplate) + return +} diff --git a/generator/template/file_templates.go b/generator/template/file_templates.go index 8b5c5b0..e3020ce 100644 --- a/generator/template/file_templates.go +++ b/generator/template/file_templates.go @@ -74,7 +74,7 @@ func new{{tableTemplate.TypeName}}(schemaName, tableName, alias string) {{tableT } ` -var tablePostgreSQLBuilderTemplate = ` +var tableSQLBuilderTemplateWithEXCLUDED = ` {{define "column-list" -}} {{- range $i, $c := . }} {{- $field := columnField $c}} diff --git a/generator/template/model_template.go b/generator/template/model_template.go index 732cc2f..032afc8 100644 --- a/generator/template/model_template.go +++ b/generator/template/model_template.go @@ -267,8 +267,8 @@ func getGoType(column metadata.Column) interface{} { // toGoType returns model type for column info. func toGoType(column metadata.Column) interface{} { - switch column.DataType.Name { - case "USER-DEFINED", "enum": + switch strings.ToLower(column.DataType.Name) { + case "user-defined", "enum": return "" case "boolean", "bool": return false @@ -306,10 +306,10 @@ func toGoType(column metadata.Column) interface{} { return []byte("") case "text", "character", "bpchar", - "character varying", "varchar", + "character varying", "varchar", "nvarchar", "tsvector", "bit", "bit varying", "varbit", "money", "json", "jsonb", - "xml", "point", "interval", "line", "ARRAY", + "xml", "point", "interval", "line", "array", "char", "tinytext", "mediumtext", "longtext": // MySQL return "" case "real", "float4": diff --git a/generator/template/process.go b/generator/template/process.go index ff3775e..46a598d 100644 --- a/generator/template/process.go +++ b/generator/template/process.go @@ -169,8 +169,8 @@ func processTableSQLBuilder(fileTypes, dirPath string, } func getTableSQLBuilderTemplate(dialect jet.Dialect) string { - if dialect.Name() == "PostgreSQL" { - return tablePostgreSQLBuilderTemplate + if dialect.Name() == "PostgreSQL" || dialect.Name() == "SQLite" { + return tableSQLBuilderTemplateWithEXCLUDED } return tableSQLBuilderTemplate diff --git a/generator/template/sql_builder_template.go b/generator/template/sql_builder_template.go index 8d2d18c..099c0e3 100644 --- a/generator/template/sql_builder_template.go +++ b/generator/template/sql_builder_template.go @@ -5,6 +5,7 @@ import ( "github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/internal/utils" "path" + "strings" "unicode" ) @@ -137,7 +138,7 @@ func getSqlBuilderColumnType(columnMetaData metadata.Column) string { return "String" } - switch columnMetaData.DataType.Name { + switch strings.ToLower(columnMetaData.DataType.Name) { case "boolean": return "Bool" case "smallint", "integer", "bigint", @@ -157,9 +158,9 @@ func getSqlBuilderColumnType(columnMetaData metadata.Column) string { return "Timez" case "interval": return "Interval" - case "USER-DEFINED", "enum", "text", "character", "character varying", "bytea", "uuid", + case "user-defined", "enum", "text", "character", "character varying", "bytea", "uuid", "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", - "char", "varchar", "binary", "varbinary", + "char", "varchar", "nvarchar", "binary", "varbinary", "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL return "String" case "real", "numeric", "decimal", "double precision", "float",