Add jet generator support for SQLite

This commit is contained in:
go-jet 2021-10-21 13:21:01 +02:00
parent 3f7efb33eb
commit 51cad22809
8 changed files with 154 additions and 31 deletions

View file

@ -3,6 +3,7 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
sqlitegen "github.com/go-jet/jet/v2/generator/sqlite"
"os" "os"
"strings" "strings"
@ -12,6 +13,7 @@ import (
"github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/postgres"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
) )
var ( var (
@ -31,7 +33,7 @@ var (
) )
func init() { func init() {
flag.StringVar(&source, "source", "", "Database system name (PostgreSQL, MySQL or MariaDB)") flag.StringVar(&source, "source", "", "Database system name (PostgreSQL, MySQL, MariaDB or SQLite)")
flag.StringVar(&dsn, "dsn", "", "Data source name connection string (Example: postgresql://user@localhost:5432/otherdb?sslmode=trust)") flag.StringVar(&dsn, "dsn", "", "Data source name connection string (Example: postgresql://user@localhost:5432/otherdb?sslmode=trust)")
flag.StringVar(&host, "host", "", "Database host path (Example: localhost)") flag.StringVar(&host, "host", "", "Database host path (Example: localhost)")
@ -50,7 +52,7 @@ func main() {
flag.Usage = func() { flag.Usage = func() {
_, _ = fmt.Fprint(os.Stdout, ` _, _ = fmt.Fprint(os.Stdout, `
Jet generator 2.5.0 Jet generator 2.6.0
Usage: Usage:
-dsn string -dsn string
@ -61,8 +63,11 @@ Usage:
MySQL: https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html MySQL: https://dev.mysql.com/doc/refman/8.0/en/connecting-using-uri-or-key-value-pairs.html
Example: Example:
mysql://jet:jet@tcp(localhost:3306)/dvds mysql://jet:jet@tcp(localhost:3306)/dvds
SQLite: https://www.sqlite.org/c3ref/open.html#urifilenameexamples
Example:
file://path/to/database/file
-source string -source string
Database system name (PostgreSQL, MySQL or MariaDB) Database system name (PostgreSQL, MySQL, MariaDB or SQLite)
-host string -host string
Database host path (Example: localhost) Database host path (Example: localhost)
-port int -port int
@ -76,17 +81,18 @@ Usage:
-params string -params string
Additional connection string parameters(optional) Additional connection string parameters(optional)
-schema string -schema string
Database schema name. (default "public") (ignored for MySQL and MariaDB) Database schema name. (default "public") (ignored for MySQL, MariaDB and SQLite)
-sslmode string -sslmode string
Whether or not to use SSL(optional) (default "disable") (ignored for MySQL and MariaDB) Whether or not to use SSL(optional) (default "disable") (ignored for MySQL, MariaDB and SQLite)
-path string -path string
Destination dir for files generated. Destination dir for files generated.
Example commands: Example commands:
$ jet -source=PostgreSQL -dbname=jetdb -host=localhost -port=5432 -user=jet -password=jet -schema=dvds $ jet -source=PostgreSQL -dbname=jetdb -host=localhost -port=5432 -user=jet -password=jet -schema=./dvds
$ jet -dsn=postgresql://jet:jet@localhost:5432/jetdb -schema=dvds $ jet -dsn=postgresql://jet:jet@localhost:5432/jetdb -schema=./dvds
$ jet -source=postgres -dsn="user=jet password=jet host=localhost port=5432 dbname=jetdb" -schema=dvds $ jet -source=postgres -dsn="user=jet password=jet host=localhost port=5432 dbname=jetdb" -schema=./dvds
$ jet -source=sqlite -dsn="file://path/to/sqlite/database/file" -schema=./dvds
`) `)
} }
@ -95,7 +101,7 @@ Example commands:
if dsn == "" { if dsn == "" {
// validations for separated connection flags. // validations for separated connection flags.
if source == "" || host == "" || port == 0 || user == "" || dbName == "" { if source == "" || host == "" || port == 0 || user == "" || dbName == "" {
printErrorAndExit("\nERROR: required flag(s) missing") printErrorAndExit("ERROR: required flag(s) missing")
} }
} else { } else {
if source == "" { if source == "" {
@ -105,15 +111,14 @@ Example commands:
// validations when dsn != "" // validations when dsn != ""
if source == "" { if source == "" {
printErrorAndExit("\nERROR: required -source flag missing.") printErrorAndExit("ERROR: required -source flag missing.")
} }
} }
var err error var err error
switch strings.ToLower(strings.TrimSpace(source)) { switch strings.ToLower(strings.TrimSpace(source)) {
case strings.ToLower(postgres.Dialect.Name()), case "postgresql", "postgres":
strings.ToLower(postgres.Dialect.PackageName()):
if dsn != "" { if dsn != "" {
err = postgresgen.GenerateDSN(dsn, schemaName, destDir) err = postgresgen.GenerateDSN(dsn, schemaName, destDir)
break break
@ -132,7 +137,7 @@ Example commands:
err = postgresgen.Generate(destDir, genData) err = postgresgen.Generate(destDir, genData)
case strings.ToLower(mysql.Dialect.Name()), "mysqlx", "mariadb": case "mysql", "mysqlx", "mariadb":
if dsn != "" { if dsn != "" {
err = mysqlgen.GenerateDSN(dsn, destDir) err = mysqlgen.GenerateDSN(dsn, destDir)
break break
@ -147,9 +152,13 @@ Example commands:
} }
err = mysqlgen.Generate(destDir, dbConn) err = mysqlgen.Generate(destDir, dbConn)
case "sqlite":
if dsn == "" {
printErrorAndExit("ERROR: required -dsn flag missing.")
}
err = sqlitegen.GenerateDSN(dsn, destDir)
default: default:
fmt.Println("ERROR: unsupported source " + source + ". " + postgres.Dialect.Name() + " and " + mysql.Dialect.Name() + " are currently supported.") printErrorAndExit("ERROR: unsupported source " + source + ". " + postgres.Dialect.Name() + " and " + mysql.Dialect.Name() + " are currently supported.")
os.Exit(-4)
} }
if err != nil { if err != nil {
@ -159,12 +168,12 @@ Example commands:
} }
func printErrorAndExit(error string) { func printErrorAndExit(error string) {
fmt.Println(error) fmt.Println("\n", error)
flag.Usage() flag.Usage()
os.Exit(-2) os.Exit(-2)
} }
func detectSchema(dsn string) (source string) { func detectSchema(dsn string) string {
match := strings.SplitN(dsn, "://", 2) match := strings.SplitN(dsn, "://", 2)
if len(match) < 2 { // not found if len(match) < 2 { // not found
return "" return ""

View file

@ -8,9 +8,10 @@ import (
// TableType is type of database table(view or base) // TableType is type of database table(view or base)
type TableType string type TableType string
// SQL table types
const ( const (
baseTable TableType = "BASE TABLE" BaseTable TableType = "BASE TABLE"
viewTable TableType = "VIEW" ViewTable TableType = "VIEW"
) )
// DialectQuerySet is set of methods necessary to retrieve dialect meta data information // DialectQuerySet is set of methods necessary to retrieve dialect meta data information
@ -23,8 +24,8 @@ type DialectQuerySet interface {
func GetSchema(db *sql.DB, querySet DialectQuerySet, schemaName string) Schema { func GetSchema(db *sql.DB, querySet DialectQuerySet, schemaName string) Schema {
ret := Schema{ ret := Schema{
Name: schemaName, Name: schemaName,
TablesMetaData: querySet.GetTablesMetaData(db, schemaName, baseTable), TablesMetaData: querySet.GetTablesMetaData(db, schemaName, BaseTable),
ViewsMetaData: querySet.GetTablesMetaData(db, schemaName, viewTable), ViewsMetaData: querySet.GetTablesMetaData(db, schemaName, ViewTable),
EnumsMetaData: querySet.GetEnumsMetaData(db, schemaName), EnumsMetaData: querySet.GetEnumsMetaData(db, schemaName),
} }

View file

@ -0,0 +1,80 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm"
"strings"
)
// sqliteQuerySet is dialect query set for SQLite
type sqliteQuerySet struct{}
func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableType metadata.TableType) []metadata.Table {
query := `
SELECT name as "table.name"
FROM sqlite_master
WHERE type=? AND name != 'sqlite_sequence'
ORDER BY name;
`
sqlTableType := "table"
if tableType == metadata.ViewTable {
sqlTableType = "view"
}
var tables []metadata.Table
err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables)
throw.OnError(err)
for i := range tables {
tables[i].Columns = p.GetTableColumnsMetaData(db, schemaName, tables[i].Name)
}
return tables
}
func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, tableName string) []metadata.Column {
query := fmt.Sprintf(`select * from pragma_table_info(?);`)
var columnInfos []struct {
Name string
Type string
NotNull int32
Pk int32
}
err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos)
throw.OnError(err)
var columns []metadata.Column
for _, columnInfo := range columnInfos {
columnType := getColumnType(columnInfo.Type)
columns = append(columns, metadata.Column{
Name: columnInfo.Name,
IsPrimaryKey: columnInfo.Pk != 0,
IsNullable: columnInfo.NotNull != 1,
DataType: metadata.DataType{
Name: columnType,
Kind: metadata.BaseType,
IsUnsigned: false,
},
})
}
return columns
}
// will convert VARCHAR(10) -> VARCHAR, etc...
func getColumnType(columnType string) string {
return strings.TrimSpace(strings.Split(columnType, "(")[0])
}
func (p sqliteQuerySet) GetEnumsMetaData(db *sql.DB, schemaName string) []metadata.Enum {
return nil
}

View file

@ -0,0 +1,32 @@
package sqlite
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/generator/template"
"github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/sqlite"
)
// GenerateDSN generates jet files using dsn connection string
func GenerateDSN(dsn, destDir string, templates ...template.Template) (err error) {
defer utils.ErrorCatch(&err)
db, err := sql.Open("sqlite3", dsn)
throw.OnError(err)
defer utils.DBClose(db)
fmt.Println("Retrieving schema information...")
generatorTemplate := template.Default(sqlite.Dialect)
if len(templates) > 0 {
generatorTemplate = templates[0]
}
schemaMetadata := metadata.GetSchema(db, &sqliteQuerySet{}, "")
template.ProcessSchema(destDir, schemaMetadata, generatorTemplate)
return
}

View file

@ -74,7 +74,7 @@ func new{{tableTemplate.TypeName}}(schemaName, tableName, alias string) {{tableT
} }
` `
var tablePostgreSQLBuilderTemplate = ` var tableSQLBuilderTemplateWithEXCLUDED = `
{{define "column-list" -}} {{define "column-list" -}}
{{- range $i, $c := . }} {{- range $i, $c := . }}
{{- $field := columnField $c}} {{- $field := columnField $c}}

View file

@ -267,8 +267,8 @@ func getGoType(column metadata.Column) interface{} {
// toGoType returns model type for column info. // toGoType returns model type for column info.
func toGoType(column metadata.Column) interface{} { func toGoType(column metadata.Column) interface{} {
switch column.DataType.Name { switch strings.ToLower(column.DataType.Name) {
case "USER-DEFINED", "enum": case "user-defined", "enum":
return "" return ""
case "boolean", "bool": case "boolean", "bool":
return false return false
@ -306,10 +306,10 @@ func toGoType(column metadata.Column) interface{} {
return []byte("") return []byte("")
case "text", case "text",
"character", "bpchar", "character", "bpchar",
"character varying", "varchar", "character varying", "varchar", "nvarchar",
"tsvector", "bit", "bit varying", "varbit", "tsvector", "bit", "bit varying", "varbit",
"money", "json", "jsonb", "money", "json", "jsonb",
"xml", "point", "interval", "line", "ARRAY", "xml", "point", "interval", "line", "array",
"char", "tinytext", "mediumtext", "longtext": // MySQL "char", "tinytext", "mediumtext", "longtext": // MySQL
return "" return ""
case "real", "float4": case "real", "float4":

View file

@ -169,8 +169,8 @@ func processTableSQLBuilder(fileTypes, dirPath string,
} }
func getTableSQLBuilderTemplate(dialect jet.Dialect) string { func getTableSQLBuilderTemplate(dialect jet.Dialect) string {
if dialect.Name() == "PostgreSQL" { if dialect.Name() == "PostgreSQL" || dialect.Name() == "SQLite" {
return tablePostgreSQLBuilderTemplate return tableSQLBuilderTemplateWithEXCLUDED
} }
return tableSQLBuilderTemplate return tableSQLBuilderTemplate

View file

@ -5,6 +5,7 @@ import (
"github.com/go-jet/jet/v2/generator/metadata" "github.com/go-jet/jet/v2/generator/metadata"
"github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/internal/utils"
"path" "path"
"strings"
"unicode" "unicode"
) )
@ -137,7 +138,7 @@ func getSqlBuilderColumnType(columnMetaData metadata.Column) string {
return "String" return "String"
} }
switch columnMetaData.DataType.Name { switch strings.ToLower(columnMetaData.DataType.Name) {
case "boolean": case "boolean":
return "Bool" return "Bool"
case "smallint", "integer", "bigint", case "smallint", "integer", "bigint",
@ -157,9 +158,9 @@ func getSqlBuilderColumnType(columnMetaData metadata.Column) string {
return "Timez" return "Timez"
case "interval": case "interval":
return "Interval" return "Interval"
case "USER-DEFINED", "enum", "text", "character", "character varying", "bytea", "uuid", case "user-defined", "enum", "text", "character", "character varying", "bytea", "uuid",
"tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY",
"char", "varchar", "binary", "varbinary", "char", "varchar", "nvarchar", "binary", "varbinary",
"tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL
return "String" return "String"
case "real", "numeric", "decimal", "double precision", "float", case "real", "numeric", "decimal", "double precision", "float",