From 5dda5e1e11c8f9b021a69d3e61c326d90f213f2f Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 28 Jul 2019 14:57:02 +0200 Subject: [PATCH] Generic dialect support. (MySQL and Postgres) --- alias.go | 2 +- bool_expression.go | 10 +-- cast.go | 22 +++++- clause.go | 15 ++-- cmd/jet/main.go | 6 +- column.go | 8 +- delete_statement.go | 16 +++- dialects.go | 97 ++++++++++++++++++++++++ enum_value.go | 2 + expression.go | 39 ++++++++-- func_expression.go | 8 ++ generator/internal/template/generate.go | 9 ++- generator/internal/template/templates.go | 8 +- generator/mysql/mysql_generator.go | 3 +- generator/postgres/postgres_generator.go | 3 +- insert_statement.go | 50 +++++++----- literal_expression.go | 12 +++ lock_statement.go | 18 ++++- operators.go | 18 +++++ select_statement.go | 31 ++++++-- select_table.go | 9 +++ set_statement.go | 14 +++- statement.go | 11 +-- table.go | 28 ++++++- update_statement.go | 17 +++-- utils_test.go | 13 ++-- visitor.go | 63 +++++++++++++++ 27 files changed, 440 insertions(+), 92 deletions(-) create mode 100644 dialects.go create mode 100644 visitor.go diff --git a/alias.go b/alias.go index 29c021c..d19198b 100644 --- a/alias.go +++ b/alias.go @@ -28,7 +28,7 @@ func (a *alias) serializeForProjection(statement statementType, out *sqlBuilder) } out.writeString("AS") - out.writeQuotedString(a.alias) + out.writeAlias(a.alias) return nil } diff --git a/bool_expression.go b/bool_expression.go index d6b359b..efbae3c 100644 --- a/bool_expression.go +++ b/bool_expression.go @@ -93,13 +93,13 @@ type binaryBoolExpression struct { } func newBinaryBoolOperator(lhs, rhs Expression, operator string) BoolExpression { - boolExpression := binaryBoolExpression{} + binaryBoolExpression := binaryBoolExpression{} - boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) - boolExpression.expressionInterfaceImpl.parent = &boolExpression - boolExpression.boolInterfaceImpl.parent = &boolExpression + binaryBoolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) + binaryBoolExpression.expressionInterfaceImpl.parent = &binaryBoolExpression + binaryBoolExpression.boolInterfaceImpl.parent = &binaryBoolExpression - return &boolExpression + return &binaryBoolExpression } //---------------------------------------------------// diff --git a/cast.go b/cast.go index b98af0f..2ff6ccd 100644 --- a/cast.go +++ b/cast.go @@ -44,9 +44,27 @@ func CAST(expression Expression) cast { } } +func (b *castImpl) accept(visitor visitor) { + visitor.visit(b) + + b.Expression.accept(visitor) +} + func (b *castImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { - err := b.Expression.serialize(statement, out, options...) - out.writeString("::" + b.castType) + + if castOverride := out.dialect.CastOverride; castOverride != nil { + return castOverride(b.Expression, b.castType)(statement, out, options...) + } + + out.writeString("CAST") + err := WRAP(b.Expression).serialize(statement, out, options...) + if err != nil { + return err + } + + out.writeString("AS") + out.writeString(b.castType) + return err } diff --git a/clause.go b/clause.go index e007470..2e5cdcf 100644 --- a/clause.go +++ b/clause.go @@ -30,8 +30,9 @@ func contains(options []serializeOption, option serializeOption) bool { } type sqlBuilder struct { - buff bytes.Buffer - args []interface{} + dialect Dialect + buff bytes.Buffer + args []interface{} lastChar byte ident int @@ -162,8 +163,9 @@ func isPostSeparator(b byte) bool { return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':' } -func (q *sqlBuilder) writeQuotedString(str string) { - q.writeString(`"` + str + `"`) +func (q *sqlBuilder) writeAlias(str string) { + aliasQuoteChar := string(q.dialect.AliasQuoteChar) + q.writeString(aliasQuoteChar + str + aliasQuoteChar) } func (q *sqlBuilder) writeString(str string) { @@ -174,7 +176,8 @@ func (q *sqlBuilder) writeIdentifier(name string) { quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -") if quoteWrap { - q.writeString(`"` + name + `"`) + identQuoteChar := string(q.dialect.IdentifierQuoteChar) + q.writeString(identQuoteChar + name + identQuoteChar) } else { q.writeString(name) } @@ -194,7 +197,7 @@ func (q *sqlBuilder) insertConstantArgument(arg interface{}) { func (q *sqlBuilder) insertParametrizedArgument(arg interface{}) { q.args = append(q.args, arg) - argPlaceholder := "$" + strconv.Itoa(len(q.args)) + argPlaceholder := q.dialect.ArgumentPlaceholder(len(q.args)) q.writeString(argPlaceholder) } diff --git a/cmd/jet/main.go b/cmd/jet/main.go index 8c06c49..1554b2a 100644 --- a/cmd/jet/main.go +++ b/cmd/jet/main.go @@ -27,7 +27,7 @@ var ( ) func init() { - flag.StringVar(&source, "source", jet.PostgreSQL, "Database name") + flag.StringVar(&source, "source", string(jet.PostgreSQL.Name), "Database name") flag.StringVar(&host, "host", "", "Database host path (Example: localhost)") flag.IntVar(&port, "port", 0, "Database port") @@ -72,7 +72,7 @@ Usage of jet: var err error switch source { - case jet.PostgreSQL: + case jet.PostgreSQL.Name: if host == "" || port == 0 || user == "" || dbName == "" || schemaName == "" { fmt.Println("\njet: required flag missing") flag.Usage() @@ -93,7 +93,7 @@ Usage of jet: err = postgres.Generate(destDir, genData) - case jet.MySql: + case jet.MySql.Name: if host == "" || port == 0 || user == "" || dbName == "" { fmt.Println("\njet: required flag missing") flag.Usage() diff --git a/column.go b/column.go index f2a035c..0c24c63 100644 --- a/column.go +++ b/column.go @@ -20,6 +20,7 @@ type Column interface { // The base type for real materialized columns. type columnImpl struct { expressionInterfaceImpl + noOpVisitorImpl name string tableName string @@ -65,7 +66,7 @@ func (c *columnImpl) defaultAlias() string { func (c *columnImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error { if statement == setStatement { // set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause - out.writeString(`"` + c.defaultAlias() + `"`) //always quote + out.writeAlias(c.defaultAlias()) //always quote return nil } @@ -80,7 +81,8 @@ func (c columnImpl) serializeForProjection(statement statementType, out *sqlBuil return err } - out.writeString(`AS "` + c.defaultAlias() + `"`) + out.writeString("AS") + out.writeAlias(c.defaultAlias()) return nil } @@ -90,7 +92,7 @@ func (c columnImpl) serialize(statement statementType, out *sqlBuilder, options if c.subQuery != nil { out.writeIdentifier(c.subQuery.Alias()) out.writeByte('.') - out.writeQuotedString(c.defaultAlias()) + out.writeAlias(c.defaultAlias()) } else { if c.tableName != "" { out.writeIdentifier(c.tableName) diff --git a/delete_statement.go b/delete_statement.go index 98bf1dd..ec0c09c 100644 --- a/delete_statement.go +++ b/delete_statement.go @@ -38,6 +38,12 @@ func (d *deleteStatementImpl) RETURNING(projections ...projection) DeleteStateme return d } +func (d *deleteStatementImpl) accept(visitor visitor) { + visitor.visit(d) + + d.table.accept(visitor) +} + func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error { if d == nil { return errors.New("jet: delete statement is nil") @@ -68,8 +74,10 @@ func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error { return nil } -func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) { - queryData := &sqlBuilder{} +func (d *deleteStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { + queryData := &sqlBuilder{ + dialect: detectDialect(d, dialect...), + } err = d.serializeImpl(queryData) @@ -81,8 +89,8 @@ func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error return } -func (d *deleteStatementImpl) DebugSql() (query string, err error) { - return debugSql(d) +func (d *deleteStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) { + return debugSql(d, dialect...) } func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) error { diff --git a/dialects.go b/dialects.go new file mode 100644 index 0000000..8bb7449 --- /dev/null +++ b/dialects.go @@ -0,0 +1,97 @@ +package jet + +import ( + "github.com/pkg/errors" + "strconv" +) + +var ( + PostgreSQL = newPostgresDialect() + MySql = newMySQLDialect() +) + +func newPostgresDialect() Dialect { + postgresDialect := newDialect("PostgreSQL", "postgres") + + postgresDialect.OperatorOverrides["IS DISTINCT FROM"] = postgresIS_DISTINCT_FROM + postgresDialect.CastOverride = postgresCAST + postgresDialect.AliasQuoteChar = '"' + postgresDialect.IdentifierQuoteChar = '"' + postgresDialect.ArgumentPlaceholder = func(ord int) string { + return "$" + strconv.Itoa(ord) + } + + return postgresDialect +} + +func newMySQLDialect() Dialect { + mySQLDialect := newDialect("MySQL", "mysql") + + mySQLDialect.AliasQuoteChar = '"' + mySQLDialect.IdentifierQuoteChar = '"' + mySQLDialect.ArgumentPlaceholder = func(int) string { + return "?" + } + + return mySQLDialect +} + +type Dialect struct { + Name string + PackageName string + OperatorOverrides map[string]serializeOverride + CastOverride castOverride + AliasQuoteChar byte + IdentifierQuoteChar byte + ArgumentPlaceholder queryPlaceholderFunc +} + +func (d *Dialect) serializeOverride(operator string) serializeOverride { + return d.OperatorOverrides[operator] +} + +type queryPlaceholderFunc func(ord int) string + +func newDialect(name, packageName string) Dialect { + newDialect := Dialect{ + Name: name, + PackageName: packageName, + } + newDialect.OperatorOverrides = make(map[string]serializeOverride) + + return newDialect +} + +func postgresIS_DISTINCT_FROM(expressions ...Expression) serializeFunc { + return func(statement statementType, out *sqlBuilder, options ...serializeOption) error { + if len(expressions) != 2 { + return errors.New("Invalid number of expressions for operator") + } + if err := expressions[0].serialize(statement, out); err != nil { + return err + } + + out.writeString("IS DISTINCT FROM") + + if err := expressions[1].serialize(statement, out); err != nil { + return err + } + + return nil + } +} + +func postgresCAST(expression Expression, castType string) serializeFunc { + return func(statement statementType, out *sqlBuilder, options ...serializeOption) error { + if err := expression.serialize(statement, out, options...); err != nil { + return err + } + out.writeString("::" + castType) + return nil + } +} + +type serializeFunc func(statement statementType, out *sqlBuilder, options ...serializeOption) error +type serializeOverride func(expressions ...Expression) serializeFunc + +type castOverride func(expression Expression, castType string) serializeFunc diff --git a/enum_value.go b/enum_value.go index 7691f96..0828517 100644 --- a/enum_value.go +++ b/enum_value.go @@ -3,6 +3,8 @@ package jet type enumValue struct { expressionInterfaceImpl stringInterfaceImpl + noOpVisitorImpl + name string } diff --git a/expression.go b/expression.go index 81fc67e..d35b2e3 100644 --- a/expression.go +++ b/expression.go @@ -7,6 +7,12 @@ import ( // Expression is common interface for all expressions. // Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions. type Expression interface { + acceptsVisitor + + expression +} + +type expression interface { clause projection groupByClause @@ -95,7 +101,12 @@ func newBinaryExpression(lhs, rhs Expression, operator string) binaryOpExpressio return binaryExpression } -func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { +func (c *binaryOpExpression) accept(visitor visitor) { + c.lhs.accept(visitor) + c.rhs.accept(visitor) +} + +func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) (err error) { if c == nil { return errors.New("jet: binary Expression is nil") } @@ -112,21 +123,25 @@ func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder, out.writeString("(") } - if err := c.lhs.serialize(statement, out); err != nil { - return err - } + if dialectOveride := out.dialect.serializeOverride(c.operator); dialectOveride != nil { + err = dialectOveride(c.lhs, c.rhs)(statement, out, options...) + } else { + if err := c.lhs.serialize(statement, out); err != nil { + return err + } - out.writeString(c.operator) + out.writeString(c.operator) - if err := c.rhs.serialize(statement, out); err != nil { - return err + if err := c.rhs.serialize(statement, out); err != nil { + return err + } } if wrap { out.writeString(")") } - return nil + return err } // A prefix operator Expression @@ -144,6 +159,10 @@ func newPrefixExpression(expression Expression, operator string) prefixOpExpress return prefixExpression } +func (p *prefixOpExpression) accept(visitor visitor) { + p.expression.accept(visitor) +} + func (p *prefixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { if p == nil { return errors.New("jet: Prefix Expression is nil") @@ -176,6 +195,10 @@ func newPostfixOpExpression(expression Expression, operator string) postfixOpExp return postfixOpExpression } +func (p *postfixOpExpression) accept(visitor visitor) { + p.expression.accept(visitor) +} + func (p *postfixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { if p == nil { return errors.New("jet: Postifx operator Expression is nil") diff --git a/func_expression.go b/func_expression.go index 91fa61a..cbd1427 100644 --- a/func_expression.go +++ b/func_expression.go @@ -485,6 +485,14 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr return funcExp } +func (f *funcExpressionImpl) accept(visitor visitor) { + visitor.visit(f) + + for _, exp := range f.expressions { + exp.accept(visitor) + } +} + func (f *funcExpressionImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { if f == nil { return errors.New("jet: Function expressions is nil. ") diff --git a/generator/internal/template/generate.go b/generator/internal/template/generate.go index 628ea36..366e907 100644 --- a/generator/internal/template/generate.go +++ b/generator/internal/template/generate.go @@ -3,6 +3,7 @@ package template import ( "bytes" "fmt" + "github.com/go-jet/jet" "github.com/go-jet/jet/generator/internal/metadata" "github.com/go-jet/jet/internal/utils" "path/filepath" @@ -10,7 +11,7 @@ import ( "time" ) -func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect string) error { +func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect jet.Dialect) error { if len(tables) == 0 && len(enums) == 0 { return nil } @@ -59,7 +60,7 @@ func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect st } -func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect string) error { +func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) error { modelDirPath := filepath.Join(dirPath, packageName) err := utils.EnsureDirPath(modelDirPath) @@ -92,14 +93,14 @@ func generate(dirPath, packageName string, template string, metaDataList []metad } // GenerateTemplate generates template with template text and template data. -func GenerateTemplate(templateText string, templateData interface{}, dialect string) ([]byte, error) { +func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect) ([]byte, error) { t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{ "ToGoIdentifier": utils.ToGoIdentifier, "now": func() string { return time.Now().Format(time.RFC850) }, - "dialect": func() string { + "dialect": func() jet.Dialect { return dialect }, }).Parse(templateText) diff --git a/generator/internal/template/templates.go b/generator/internal/template/templates.go index c54c9b3..c066e6b 100644 --- a/generator/internal/template/templates.go +++ b/generator/internal/template/templates.go @@ -22,7 +22,7 @@ package table import ( "github.com/go-jet/jet" - "github.com/go-jet/jet/{{dialect}}" + "github.com/go-jet/jet/{{dialect.PackageName}}" ) var {{ToGoIdentifier .Name}} = new{{.GoStructName}}() @@ -32,7 +32,7 @@ type {{.GoStructName}} struct { //Columns {{- range .Columns}} - {{ToGoIdentifier .Name}} {{dialect}}.Column{{.SqlBuilderColumnType}} + {{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}} {{- end}} AllColumns jet.ColumnList @@ -51,12 +51,12 @@ func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} { func new{{.GoStructName}}() *{{.GoStructName}} { var ( {{- range .Columns}} - {{ToGoIdentifier .Name}}Column = {{dialect}}.{{.SqlBuilderColumnType}}Column("{{.Name}}") + {{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}") {{- end}} ) return &{{.GoStructName}}{ - Table: jet.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}), + Table: jet.NewTable(jet.{{dialect.Name}}, "{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}), //Columns {{- range .Columns}} diff --git a/generator/mysql/mysql_generator.go b/generator/mysql/mysql_generator.go index d3a8c65..1765efc 100644 --- a/generator/mysql/mysql_generator.go +++ b/generator/mysql/mysql_generator.go @@ -3,6 +3,7 @@ package mysql import ( "database/sql" "fmt" + "github.com/go-jet/jet" "github.com/go-jet/jet/generator/internal/metadata" "github.com/go-jet/jet/generator/internal/template" "path" @@ -37,7 +38,7 @@ func Generate(destDir string, dbConn DBConnection) error { genPath := path.Join(destDir, dbConn.DBName) - err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, "mysql") + err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, jet.MySql) if err != nil { return err diff --git a/generator/postgres/postgres_generator.go b/generator/postgres/postgres_generator.go index 18f3601..727ca94 100644 --- a/generator/postgres/postgres_generator.go +++ b/generator/postgres/postgres_generator.go @@ -3,6 +3,7 @@ package postgres import ( "database/sql" "fmt" + "github.com/go-jet/jet" "github.com/go-jet/jet/generator/internal/metadata" "github.com/go-jet/jet/generator/internal/template" "path" @@ -41,7 +42,7 @@ func Generate(destDir string, dbConn DBConnection) error { genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName) - err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, "postgres") + err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, jet.PostgreSQL) if err != nil { return err diff --git a/insert_statement.go b/insert_statement.go index de48064..2a49543 100644 --- a/insert_statement.go +++ b/insert_statement.go @@ -73,36 +73,44 @@ func (i *insertStatementImpl) getColumns() []column { return i.table.columns() } -func (i *insertStatementImpl) DebugSql() (query string, err error) { - return debugSql(i) +func (i *insertStatementImpl) accept(visitor visitor) { + visitor.visit(i) + + i.table.accept(visitor) } -func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) { - queryData := &sqlBuilder{} +func (i *insertStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) { + return debugSql(i, dialect...) +} - queryData.newLine() - queryData.writeString("INSERT INTO") +func (i *insertStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { + out := &sqlBuilder{ + dialect: detectDialect(i, dialect...), + } + + out.newLine() + out.writeString("INSERT INTO") if utils.IsNil(i.table) { return "", nil, errors.New("jet: table is nil") } - err = i.table.serialize(insertStatement, queryData) + err = i.table.serialize(insertStatement, out) if err != nil { return } if len(i.columns) > 0 { - queryData.writeString("(") + out.writeString("(") - err = serializeColumnNames(i.columns, queryData) + err = serializeColumnNames(i.columns, out) if err != nil { return } - queryData.writeString(")") + out.writeString(")") } if len(i.rows) == 0 && i.query == nil { @@ -114,41 +122,41 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) } if len(i.rows) > 0 { - queryData.writeString("VALUES") + out.writeString("VALUES") for rowIndex, row := range i.rows { if rowIndex > 0 { - queryData.writeString(",") + out.writeString(",") } - queryData.increaseIdent() - queryData.newLine() - queryData.writeString("(") + out.increaseIdent() + out.newLine() + out.writeString("(") - err = serializeClauseList(insertStatement, row, queryData) + err = serializeClauseList(insertStatement, row, out) if err != nil { return "", nil, err } - queryData.writeByte(')') - queryData.decreaseIdent() + out.writeByte(')') + out.decreaseIdent() } } if i.query != nil { - err = i.query.serialize(insertStatement, queryData) + err = i.query.serialize(insertStatement, out) if err != nil { return } } - if err = queryData.writeReturning(insertStatement, i.returning); err != nil { + if err = out.writeReturning(insertStatement, i.returning); err != nil { return } - sql, args = queryData.finalize() + query, args = out.finalize() return } diff --git a/literal_expression.go b/literal_expression.go index 27dca03..3195767 100644 --- a/literal_expression.go +++ b/literal_expression.go @@ -5,6 +5,8 @@ import "fmt" // Representation of an escaped literal type literalExpression struct { expressionInterfaceImpl + noOpVisitorImpl + value interface{} constant bool } @@ -188,6 +190,7 @@ func Date(year, month, day int) DateExpression { //--------------------------------------------------// type nullLiteral struct { expressionInterfaceImpl + noOpVisitorImpl } func newNullLiteral() Expression { @@ -206,6 +209,7 @@ func (n *nullLiteral) serialize(statement statementType, out *sqlBuilder, option //--------------------------------------------------// type starLiteral struct { expressionInterfaceImpl + noOpVisitorImpl } func newStarLiteral() Expression { @@ -228,6 +232,12 @@ type wrap struct { expressions []Expression } +func (n *wrap) accept(visitor visitor) { + for _, exp := range n.expressions { + exp.accept(visitor) + } +} + func (n *wrap) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { out.writeString("(") err := serializeExpressionList(statement, n.expressions, ", ", out) @@ -247,6 +257,8 @@ func WRAP(expression ...Expression) Expression { type rawExpression struct { expressionInterfaceImpl + noOpVisitorImpl + raw string } diff --git a/lock_statement.go b/lock_statement.go index da249ec..ba7105e 100644 --- a/lock_statement.go +++ b/lock_statement.go @@ -53,11 +53,19 @@ func (l *lockStatementImpl) NOWAIT() LockStatement { return l } -func (l *lockStatementImpl) DebugSql() (query string, err error) { - return debugSql(l) +func (l *lockStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) { + return debugSql(l, dialect...) } -func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error) { +func (l *lockStatementImpl) accept(visitor visitor) { + visitor.visit(l) + + for _, table := range l.tables { + table.accept(visitor) + } +} + +func (l *lockStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { if l == nil { return "", nil, errors.New("jet: nil Statement") } @@ -66,7 +74,9 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error) return "", nil, errors.New("jet: There is no table selected to be locked") } - out := &sqlBuilder{} + out := &sqlBuilder{ + dialect: detectDialect(l, dialect...), + } out.newLine() out.writeString("LOCK TABLE") diff --git a/operators.go b/operators.go index 9e5187f..7ac6247 100644 --- a/operators.go +++ b/operators.go @@ -108,6 +108,24 @@ func (c *caseOperatorImpl) ELSE(els Expression) CaseOperator { return c } +func (c *caseOperatorImpl) accept(visitor visitor) { + visitor.visit(c) + + c.expression.accept(visitor) + + for _, when := range c.when { + when.accept(visitor) + } + + for _, then := range c.then { + then.accept(visitor) + } + + if c.els != nil { + c.els.accept(visitor) + } +} + func (c *caseOperatorImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { if c == nil { return errors.New("jet: Case Expression is nil. ") diff --git a/select_statement.go b/select_statement.go index 25083ac..a1151c9 100644 --- a/select_statement.go +++ b/select_statement.go @@ -18,7 +18,7 @@ var ( // SelectStatement is interface for SQL SELECT statements type SelectStatement interface { Statement - Expression + expression DISTINCT() SelectStatement FROM(table ReadableTable) SelectStatement @@ -261,10 +261,29 @@ func (s *selectStatementImpl) serializeImpl(out *sqlBuilder) error { return nil } -func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error) { - queryData := sqlBuilder{} +func (s *selectStatementImpl) accept(visitor visitor) { + visitor.visit(s) - err = s.serializeImpl(&queryData) + if s.table != nil { + s.table.accept(visitor) + } + + if s.where != nil { + s.where.accept(visitor) + } + + if s.having != nil { + s.having.accept(visitor) + } +} + +func (s *selectStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { + + queryData := &sqlBuilder{ + dialect: detectDialect(s, dialect...), + } + + err = s.serializeImpl(queryData) if err != nil { return "", nil, err @@ -275,8 +294,8 @@ func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error return } -func (s *selectStatementImpl) DebugSql() (query string, err error) { - return debugSql(s.parent) +func (s *selectStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) { + return debugSql(s.parent, dialect...) } func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) error { diff --git a/select_table.go b/select_table.go index 2d7c901..00e5fdc 100644 --- a/select_table.go +++ b/select_table.go @@ -41,6 +41,15 @@ func (s *selectTableImpl) columns() []column { return nil } +func (s *selectTableImpl) accept(visitor visitor) { + visitor.visit(s) + s.selectStmt.accept(visitor) +} + +func (s *selectTableImpl) dialect() Dialect { + return detectDialect(s.selectStmt) +} + func (s *selectTableImpl) AllColumns() ProjectionList { return s.projections } diff --git a/set_statement.go b/set_statement.go index 54e5e77..c81fca5 100644 --- a/set_statement.go +++ b/set_statement.go @@ -74,6 +74,14 @@ func newSetStatementImpl(operator string, all bool, selects []SelectStatement) S return setStatement } +func (s *setStatementImpl) accept(visitor visitor) { + visitor.visit(s) + + for _, selects := range s.selects { + selects.accept(visitor) + } +} + func (s *setStatementImpl) projections() []projection { if len(s.selects) > 0 { return s.selects[0].projections() @@ -169,8 +177,10 @@ func (s *setStatementImpl) serializeImpl(out *sqlBuilder) error { return nil } -func (s *setStatementImpl) Sql() (query string, args []interface{}, err error) { - queryData := &sqlBuilder{} +func (s *setStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { + queryData := &sqlBuilder{ + dialect: detectDialect(s, dialect...), + } err = s.serializeImpl(queryData) diff --git a/statement.go b/statement.go index e7df331..a5ece84 100644 --- a/statement.go +++ b/statement.go @@ -4,19 +4,19 @@ import ( "context" "database/sql" "github.com/go-jet/jet/execution" - "strconv" "strings" ) //Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK) type Statement interface { + acceptsVisitor // Sql returns parametrized sql query with list of arguments. // err is returned if statement is not composed correctly - Sql() (query string, args []interface{}, err error) + Sql(dialect ...Dialect) (query string, args []interface{}, err error) // DebugSql returns debug query where every parametrized placeholder is replaced with its argument. // Do not use it in production. Use it only for debug purposes. // err is returned if statement is not composed correctly - DebugSql() (query string, err error) + DebugSql(dialect ...Dialect) (query string, err error) // Query executes statement over database connection db and stores row result in destination. // Destination can be arbitrary structure @@ -31,7 +31,8 @@ type Statement interface { ExecContext(context context.Context, db execution.DB) (sql.Result, error) } -func debugSql(statement Statement) (string, error) { +func debugSql(statement Statement, overrideDialect ...Dialect) (string, error) { + dialect := detectDialect(statement, overrideDialect...) sqlQuery, args, err := statement.Sql() if err != nil { @@ -41,7 +42,7 @@ func debugSql(statement Statement) (string, error) { debugSQLQuery := sqlQuery for i, arg := range args { - argPlaceholder := "$" + strconv.Itoa(i+1) + argPlaceholder := dialect.ArgumentPlaceholder(i + 1) debugSQLQuery = strings.Replace(debugSQLQuery, argPlaceholder, argToString(arg), 1) } diff --git a/table.go b/table.go index 1e12370..413087a 100644 --- a/table.go +++ b/table.go @@ -6,6 +6,7 @@ import ( ) type table interface { + dialect() Dialect columns() []column } @@ -42,6 +43,7 @@ type ReadableTable interface { table readableTable clause + acceptsVisitor } // WritableTable interface @@ -49,6 +51,7 @@ type WritableTable interface { table writableTable clause + acceptsVisitor } // Table interface @@ -57,6 +60,8 @@ type Table interface { readableTable writableTable clause + acceptsVisitor + SchemaName() string TableName() string AS(alias string) @@ -114,10 +119,11 @@ func (w *writableTableInterfaceImpl) LOCK() LockStatement { return LOCK(w.parent) } -// NewTable creates new table with schema name, table name and list of columns -func NewTable(schemaName, name string, columns ...Column) Table { +// NewTable creates new table with schema Name, table Name and list of columns +func NewTable(Dialect Dialect, schemaName, name string, columns ...Column) Table { t := &tableImpl{ + Dialect: Dialect, schemaName: schemaName, name: name, columnList: columns, @@ -136,6 +142,7 @@ type tableImpl struct { readableTableInterfaceImpl writableTableInterfaceImpl + Dialect Dialect schemaName string name string alias string @@ -168,6 +175,14 @@ func (t *tableImpl) columns() []column { return ret } +func (t *tableImpl) dialect() Dialect { + return t.Dialect +} + +func (t *tableImpl) accept(visitor visitor) { + visitor.visit(t) +} + func (t *tableImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { if t == nil { return errors.New("jet: tableImpl is nil. ") @@ -235,6 +250,15 @@ func (t *joinTable) columns() []column { return append(t.lhs.columns(), t.rhs.columns()...) } +func (t *joinTable) accept(visitor visitor) { + t.lhs.accept(visitor) + t.rhs.accept(visitor) +} + +func (t *joinTable) dialect() Dialect { + return detectDialect(t) +} + func (t *joinTable) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) (err error) { if t == nil { return errors.New("jet: Join table is nil. ") diff --git a/update_statement.go b/update_statement.go index b124a42..ee3bfae 100644 --- a/update_statement.go +++ b/update_statement.go @@ -57,8 +57,15 @@ func (u *updateStatementImpl) RETURNING(projections ...projection) UpdateStateme return u } -func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) { - out := &sqlBuilder{} +func (u *updateStatementImpl) accept(visitor visitor) { + visitor.visit(u) + u.table.accept(visitor) +} + +func (u *updateStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { + out := &sqlBuilder{ + dialect: detectDialect(u, dialect...), + } out.newLine() out.writeString("UPDATE") @@ -124,12 +131,12 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) return } - sql, args = out.finalize() + query, args = out.finalize() return } -func (u *updateStatementImpl) DebugSql() (query string, err error) { - return debugSql(u) +func (u *updateStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) { + return debugSql(u, dialect...) } func (u *updateStatementImpl) Query(db execution.DB, destination interface{}) error { diff --git a/utils_test.go b/utils_test.go index 2f0ea75..f37b833 100644 --- a/utils_test.go +++ b/utils_test.go @@ -17,6 +17,7 @@ var table1ColBool = BoolColumn("col_bool") var table1ColDate = DateColumn("col_date") var table1 = NewTable( + PostgreSQL, "db", "table1", table1Col1, @@ -44,6 +45,7 @@ var table2ColTimestampz = TimestampzColumn("col_timestampz") var table2ColDate = DateColumn("col_date") var table2 = NewTable( + PostgreSQL, "db", "table2", table2Col3, @@ -63,6 +65,7 @@ var table3Col1 = IntegerColumn("col1") var table3ColInt = IntegerColumn("col_int") var table3StrCol = StringColumn("col2") var table3 = NewTable( + PostgreSQL, "db", "table3", table3Col1, @@ -70,7 +73,7 @@ var table3 = NewTable( table3StrCol) func assertClauseSerialize(t *testing.T, clause clause, query string, args ...interface{}) { - out := sqlBuilder{} + out := sqlBuilder{dialect: PostgreSQL} err := clause.serialize(selectStatement, &out) assert.NilError(t, err) @@ -80,7 +83,7 @@ func assertClauseSerialize(t *testing.T, clause clause, query string, args ...in } func assertClauseSerializeErr(t *testing.T, clause clause, errString string) { - out := sqlBuilder{} + out := sqlBuilder{dialect: PostgreSQL} err := clause.serialize(selectStatement, &out) //fmt.Println(out.buff.String()) @@ -89,7 +92,7 @@ func assertClauseSerializeErr(t *testing.T, clause clause, errString string) { } func assertProjectionSerialize(t *testing.T, projection projection, query string, args ...interface{}) { - out := sqlBuilder{} + out := sqlBuilder{dialect: PostgreSQL} err := projection.serializeForProjection(selectStatement, &out) assert.NilError(t, err) @@ -99,7 +102,7 @@ func assertProjectionSerialize(t *testing.T, projection projection, query string } func assertStatement(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) { - queryStr, args, err := query.Sql() + queryStr, args, err := query.Sql(PostgreSQL) assert.NilError(t, err) //fmt.Println(queryStr) @@ -108,7 +111,7 @@ func assertStatement(t *testing.T, query Statement, expectedQuery string, expect } func assertStatementErr(t *testing.T, stmt Statement, errorStr string) { - _, _, err := stmt.Sql() + _, _, err := stmt.Sql(PostgreSQL) assert.Assert(t, err != nil) assert.Error(t, err, errorStr) diff --git a/visitor.go b/visitor.go new file mode 100644 index 0000000..6c9dc84 --- /dev/null +++ b/visitor.go @@ -0,0 +1,63 @@ +package jet + +type visitor interface { + visit(element acceptsVisitor) +} + +type acceptsVisitor interface { + accept(visitor visitor) +} + +type noOpVisitorImpl struct { +} + +func (n *noOpVisitorImpl) accept(visitor visitor) { + // NO OP +} + +// --------------- dialect finder -----------------// + +type DialectFinder struct { + dialects map[string]Dialect +} + +func newDialectFinder() *DialectFinder { + return &DialectFinder{ + dialects: make(map[string]Dialect), + } +} + +func (f *DialectFinder) dialect() Dialect { + if len(f.dialects) == 0 { + panic("jet: can't detect dialect") + } + + if len(f.dialects) > 1 { + panic("jet: more than one dialect detected") + } + + for _, dialect := range f.dialects { + return dialect + } + + panic("jet: internal error") +} + +func (f *DialectFinder) visit(element acceptsVisitor) { + + if table, ok := element.(table); ok { + dialect := table.dialect() + f.dialects[dialect.Name] = dialect + } +} + +func detectDialect(element acceptsVisitor, dialectOverride ...Dialect) Dialect { + if len(dialectOverride) > 0 { + return dialectOverride[0] + } + + dialectFinder := newDialectFinder() + element.accept(dialectFinder) + + return dialectFinder.dialect() +}