Generic dialect support. (MySQL and Postgres)

This commit is contained in:
go-jet 2019-07-28 14:57:02 +02:00
parent 043a0dc4c0
commit 5dda5e1e11
27 changed files with 440 additions and 92 deletions

View file

@ -28,7 +28,7 @@ func (a *alias) serializeForProjection(statement statementType, out *sqlBuilder)
} }
out.writeString("AS") out.writeString("AS")
out.writeQuotedString(a.alias) out.writeAlias(a.alias)
return nil return nil
} }

View file

@ -93,13 +93,13 @@ type binaryBoolExpression struct {
} }
func newBinaryBoolOperator(lhs, rhs Expression, operator string) BoolExpression { func newBinaryBoolOperator(lhs, rhs Expression, operator string) BoolExpression {
boolExpression := binaryBoolExpression{} binaryBoolExpression := binaryBoolExpression{}
boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator) binaryBoolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
boolExpression.expressionInterfaceImpl.parent = &boolExpression binaryBoolExpression.expressionInterfaceImpl.parent = &binaryBoolExpression
boolExpression.boolInterfaceImpl.parent = &boolExpression binaryBoolExpression.boolInterfaceImpl.parent = &binaryBoolExpression
return &boolExpression return &binaryBoolExpression
} }
//---------------------------------------------------// //---------------------------------------------------//

22
cast.go
View file

@ -44,9 +44,27 @@ func CAST(expression Expression) cast {
} }
} }
func (b *castImpl) accept(visitor visitor) {
visitor.visit(b)
b.Expression.accept(visitor)
}
func (b *castImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (b *castImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
err := b.Expression.serialize(statement, out, options...)
out.writeString("::" + b.castType) if castOverride := out.dialect.CastOverride; castOverride != nil {
return castOverride(b.Expression, b.castType)(statement, out, options...)
}
out.writeString("CAST")
err := WRAP(b.Expression).serialize(statement, out, options...)
if err != nil {
return err
}
out.writeString("AS")
out.writeString(b.castType)
return err return err
} }

View file

@ -30,8 +30,9 @@ func contains(options []serializeOption, option serializeOption) bool {
} }
type sqlBuilder struct { type sqlBuilder struct {
buff bytes.Buffer dialect Dialect
args []interface{} buff bytes.Buffer
args []interface{}
lastChar byte lastChar byte
ident int ident int
@ -162,8 +163,9 @@ func isPostSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':' return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
} }
func (q *sqlBuilder) writeQuotedString(str string) { func (q *sqlBuilder) writeAlias(str string) {
q.writeString(`"` + str + `"`) aliasQuoteChar := string(q.dialect.AliasQuoteChar)
q.writeString(aliasQuoteChar + str + aliasQuoteChar)
} }
func (q *sqlBuilder) writeString(str string) { func (q *sqlBuilder) writeString(str string) {
@ -174,7 +176,8 @@ func (q *sqlBuilder) writeIdentifier(name string) {
quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -") quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -")
if quoteWrap { if quoteWrap {
q.writeString(`"` + name + `"`) identQuoteChar := string(q.dialect.IdentifierQuoteChar)
q.writeString(identQuoteChar + name + identQuoteChar)
} else { } else {
q.writeString(name) q.writeString(name)
} }
@ -194,7 +197,7 @@ func (q *sqlBuilder) insertConstantArgument(arg interface{}) {
func (q *sqlBuilder) insertParametrizedArgument(arg interface{}) { func (q *sqlBuilder) insertParametrizedArgument(arg interface{}) {
q.args = append(q.args, arg) q.args = append(q.args, arg)
argPlaceholder := "$" + strconv.Itoa(len(q.args)) argPlaceholder := q.dialect.ArgumentPlaceholder(len(q.args))
q.writeString(argPlaceholder) q.writeString(argPlaceholder)
} }

View file

@ -27,7 +27,7 @@ var (
) )
func init() { func init() {
flag.StringVar(&source, "source", jet.PostgreSQL, "Database name") flag.StringVar(&source, "source", string(jet.PostgreSQL.Name), "Database name")
flag.StringVar(&host, "host", "", "Database host path (Example: localhost)") flag.StringVar(&host, "host", "", "Database host path (Example: localhost)")
flag.IntVar(&port, "port", 0, "Database port") flag.IntVar(&port, "port", 0, "Database port")
@ -72,7 +72,7 @@ Usage of jet:
var err error var err error
switch source { switch source {
case jet.PostgreSQL: case jet.PostgreSQL.Name:
if host == "" || port == 0 || user == "" || dbName == "" || schemaName == "" { if host == "" || port == 0 || user == "" || dbName == "" || schemaName == "" {
fmt.Println("\njet: required flag missing") fmt.Println("\njet: required flag missing")
flag.Usage() flag.Usage()
@ -93,7 +93,7 @@ Usage of jet:
err = postgres.Generate(destDir, genData) err = postgres.Generate(destDir, genData)
case jet.MySql: case jet.MySql.Name:
if host == "" || port == 0 || user == "" || dbName == "" { if host == "" || port == 0 || user == "" || dbName == "" {
fmt.Println("\njet: required flag missing") fmt.Println("\njet: required flag missing")
flag.Usage() flag.Usage()

View file

@ -20,6 +20,7 @@ type Column interface {
// The base type for real materialized columns. // The base type for real materialized columns.
type columnImpl struct { type columnImpl struct {
expressionInterfaceImpl expressionInterfaceImpl
noOpVisitorImpl
name string name string
tableName string tableName string
@ -65,7 +66,7 @@ func (c *columnImpl) defaultAlias() string {
func (c *columnImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error { func (c *columnImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error {
if statement == setStatement { if statement == setStatement {
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause // set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
out.writeString(`"` + c.defaultAlias() + `"`) //always quote out.writeAlias(c.defaultAlias()) //always quote
return nil return nil
} }
@ -80,7 +81,8 @@ func (c columnImpl) serializeForProjection(statement statementType, out *sqlBuil
return err return err
} }
out.writeString(`AS "` + c.defaultAlias() + `"`) out.writeString("AS")
out.writeAlias(c.defaultAlias())
return nil return nil
} }
@ -90,7 +92,7 @@ func (c columnImpl) serialize(statement statementType, out *sqlBuilder, options
if c.subQuery != nil { if c.subQuery != nil {
out.writeIdentifier(c.subQuery.Alias()) out.writeIdentifier(c.subQuery.Alias())
out.writeByte('.') out.writeByte('.')
out.writeQuotedString(c.defaultAlias()) out.writeAlias(c.defaultAlias())
} else { } else {
if c.tableName != "" { if c.tableName != "" {
out.writeIdentifier(c.tableName) out.writeIdentifier(c.tableName)

View file

@ -38,6 +38,12 @@ func (d *deleteStatementImpl) RETURNING(projections ...projection) DeleteStateme
return d return d
} }
func (d *deleteStatementImpl) accept(visitor visitor) {
visitor.visit(d)
d.table.accept(visitor)
}
func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error { func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error {
if d == nil { if d == nil {
return errors.New("jet: delete statement is nil") return errors.New("jet: delete statement is nil")
@ -68,8 +74,10 @@ func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error {
return nil return nil
} }
func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) { func (d *deleteStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
queryData := &sqlBuilder{} queryData := &sqlBuilder{
dialect: detectDialect(d, dialect...),
}
err = d.serializeImpl(queryData) err = d.serializeImpl(queryData)
@ -81,8 +89,8 @@ func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error
return return
} }
func (d *deleteStatementImpl) DebugSql() (query string, err error) { func (d *deleteStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
return debugSql(d) return debugSql(d, dialect...)
} }
func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) error { func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) error {

97
dialects.go Normal file
View file

@ -0,0 +1,97 @@
package jet
import (
"github.com/pkg/errors"
"strconv"
)
var (
PostgreSQL = newPostgresDialect()
MySql = newMySQLDialect()
)
func newPostgresDialect() Dialect {
postgresDialect := newDialect("PostgreSQL", "postgres")
postgresDialect.OperatorOverrides["IS DISTINCT FROM"] = postgresIS_DISTINCT_FROM
postgresDialect.CastOverride = postgresCAST
postgresDialect.AliasQuoteChar = '"'
postgresDialect.IdentifierQuoteChar = '"'
postgresDialect.ArgumentPlaceholder = func(ord int) string {
return "$" + strconv.Itoa(ord)
}
return postgresDialect
}
func newMySQLDialect() Dialect {
mySQLDialect := newDialect("MySQL", "mysql")
mySQLDialect.AliasQuoteChar = '"'
mySQLDialect.IdentifierQuoteChar = '"'
mySQLDialect.ArgumentPlaceholder = func(int) string {
return "?"
}
return mySQLDialect
}
type Dialect struct {
Name string
PackageName string
OperatorOverrides map[string]serializeOverride
CastOverride castOverride
AliasQuoteChar byte
IdentifierQuoteChar byte
ArgumentPlaceholder queryPlaceholderFunc
}
func (d *Dialect) serializeOverride(operator string) serializeOverride {
return d.OperatorOverrides[operator]
}
type queryPlaceholderFunc func(ord int) string
func newDialect(name, packageName string) Dialect {
newDialect := Dialect{
Name: name,
PackageName: packageName,
}
newDialect.OperatorOverrides = make(map[string]serializeOverride)
return newDialect
}
func postgresIS_DISTINCT_FROM(expressions ...Expression) serializeFunc {
return func(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if len(expressions) != 2 {
return errors.New("Invalid number of expressions for operator")
}
if err := expressions[0].serialize(statement, out); err != nil {
return err
}
out.writeString("IS DISTINCT FROM")
if err := expressions[1].serialize(statement, out); err != nil {
return err
}
return nil
}
}
func postgresCAST(expression Expression, castType string) serializeFunc {
return func(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if err := expression.serialize(statement, out, options...); err != nil {
return err
}
out.writeString("::" + castType)
return nil
}
}
type serializeFunc func(statement statementType, out *sqlBuilder, options ...serializeOption) error
type serializeOverride func(expressions ...Expression) serializeFunc
type castOverride func(expression Expression, castType string) serializeFunc

View file

@ -3,6 +3,8 @@ package jet
type enumValue struct { type enumValue struct {
expressionInterfaceImpl expressionInterfaceImpl
stringInterfaceImpl stringInterfaceImpl
noOpVisitorImpl
name string name string
} }

View file

@ -7,6 +7,12 @@ import (
// Expression is common interface for all expressions. // Expression is common interface for all expressions.
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions. // Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
type Expression interface { type Expression interface {
acceptsVisitor
expression
}
type expression interface {
clause clause
projection projection
groupByClause groupByClause
@ -95,7 +101,12 @@ func newBinaryExpression(lhs, rhs Expression, operator string) binaryOpExpressio
return binaryExpression return binaryExpression
} }
func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (c *binaryOpExpression) accept(visitor visitor) {
c.lhs.accept(visitor)
c.rhs.accept(visitor)
}
func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) (err error) {
if c == nil { if c == nil {
return errors.New("jet: binary Expression is nil") return errors.New("jet: binary Expression is nil")
} }
@ -112,21 +123,25 @@ func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder,
out.writeString("(") out.writeString("(")
} }
if err := c.lhs.serialize(statement, out); err != nil { if dialectOveride := out.dialect.serializeOverride(c.operator); dialectOveride != nil {
return err err = dialectOveride(c.lhs, c.rhs)(statement, out, options...)
} } else {
if err := c.lhs.serialize(statement, out); err != nil {
return err
}
out.writeString(c.operator) out.writeString(c.operator)
if err := c.rhs.serialize(statement, out); err != nil { if err := c.rhs.serialize(statement, out); err != nil {
return err return err
}
} }
if wrap { if wrap {
out.writeString(")") out.writeString(")")
} }
return nil return err
} }
// A prefix operator Expression // A prefix operator Expression
@ -144,6 +159,10 @@ func newPrefixExpression(expression Expression, operator string) prefixOpExpress
return prefixExpression return prefixExpression
} }
func (p *prefixOpExpression) accept(visitor visitor) {
p.expression.accept(visitor)
}
func (p *prefixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (p *prefixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if p == nil { if p == nil {
return errors.New("jet: Prefix Expression is nil") return errors.New("jet: Prefix Expression is nil")
@ -176,6 +195,10 @@ func newPostfixOpExpression(expression Expression, operator string) postfixOpExp
return postfixOpExpression return postfixOpExpression
} }
func (p *postfixOpExpression) accept(visitor visitor) {
p.expression.accept(visitor)
}
func (p *postfixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (p *postfixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if p == nil { if p == nil {
return errors.New("jet: Postifx operator Expression is nil") return errors.New("jet: Postifx operator Expression is nil")

View file

@ -485,6 +485,14 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr
return funcExp return funcExp
} }
func (f *funcExpressionImpl) accept(visitor visitor) {
visitor.visit(f)
for _, exp := range f.expressions {
exp.accept(visitor)
}
}
func (f *funcExpressionImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (f *funcExpressionImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if f == nil { if f == nil {
return errors.New("jet: Function expressions is nil. ") return errors.New("jet: Function expressions is nil. ")

View file

@ -3,6 +3,7 @@ package template
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"github.com/go-jet/jet"
"github.com/go-jet/jet/generator/internal/metadata" "github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/internal/utils"
"path/filepath" "path/filepath"
@ -10,7 +11,7 @@ import (
"time" "time"
) )
func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect string) error { func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect jet.Dialect) error {
if len(tables) == 0 && len(enums) == 0 { if len(tables) == 0 && len(enums) == 0 {
return nil return nil
} }
@ -59,7 +60,7 @@ func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect st
} }
func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect string) error { func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) error {
modelDirPath := filepath.Join(dirPath, packageName) modelDirPath := filepath.Join(dirPath, packageName)
err := utils.EnsureDirPath(modelDirPath) err := utils.EnsureDirPath(modelDirPath)
@ -92,14 +93,14 @@ func generate(dirPath, packageName string, template string, metaDataList []metad
} }
// GenerateTemplate generates template with template text and template data. // GenerateTemplate generates template with template text and template data.
func GenerateTemplate(templateText string, templateData interface{}, dialect string) ([]byte, error) { func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect) ([]byte, error) {
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{ t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
"ToGoIdentifier": utils.ToGoIdentifier, "ToGoIdentifier": utils.ToGoIdentifier,
"now": func() string { "now": func() string {
return time.Now().Format(time.RFC850) return time.Now().Format(time.RFC850)
}, },
"dialect": func() string { "dialect": func() jet.Dialect {
return dialect return dialect
}, },
}).Parse(templateText) }).Parse(templateText)

View file

@ -22,7 +22,7 @@ package table
import ( import (
"github.com/go-jet/jet" "github.com/go-jet/jet"
"github.com/go-jet/jet/{{dialect}}" "github.com/go-jet/jet/{{dialect.PackageName}}"
) )
var {{ToGoIdentifier .Name}} = new{{.GoStructName}}() var {{ToGoIdentifier .Name}} = new{{.GoStructName}}()
@ -32,7 +32,7 @@ type {{.GoStructName}} struct {
//Columns //Columns
{{- range .Columns}} {{- range .Columns}}
{{ToGoIdentifier .Name}} {{dialect}}.Column{{.SqlBuilderColumnType}} {{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}}
{{- end}} {{- end}}
AllColumns jet.ColumnList AllColumns jet.ColumnList
@ -51,12 +51,12 @@ func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} {
func new{{.GoStructName}}() *{{.GoStructName}} { func new{{.GoStructName}}() *{{.GoStructName}} {
var ( var (
{{- range .Columns}} {{- range .Columns}}
{{ToGoIdentifier .Name}}Column = {{dialect}}.{{.SqlBuilderColumnType}}Column("{{.Name}}") {{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
{{- end}} {{- end}}
) )
return &{{.GoStructName}}{ return &{{.GoStructName}}{
Table: jet.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}), Table: jet.NewTable(jet.{{dialect.Name}}, "{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}),
//Columns //Columns
{{- range .Columns}} {{- range .Columns}}

View file

@ -3,6 +3,7 @@ package mysql
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/go-jet/jet"
"github.com/go-jet/jet/generator/internal/metadata" "github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/generator/internal/template" "github.com/go-jet/jet/generator/internal/template"
"path" "path"
@ -37,7 +38,7 @@ func Generate(destDir string, dbConn DBConnection) error {
genPath := path.Join(destDir, dbConn.DBName) genPath := path.Join(destDir, dbConn.DBName)
err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, "mysql") err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, jet.MySql)
if err != nil { if err != nil {
return err return err

View file

@ -3,6 +3,7 @@ package postgres
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/go-jet/jet"
"github.com/go-jet/jet/generator/internal/metadata" "github.com/go-jet/jet/generator/internal/metadata"
"github.com/go-jet/jet/generator/internal/template" "github.com/go-jet/jet/generator/internal/template"
"path" "path"
@ -41,7 +42,7 @@ func Generate(destDir string, dbConn DBConnection) error {
genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName) genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName)
err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, "postgres") err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, jet.PostgreSQL)
if err != nil { if err != nil {
return err return err

View file

@ -73,36 +73,44 @@ func (i *insertStatementImpl) getColumns() []column {
return i.table.columns() return i.table.columns()
} }
func (i *insertStatementImpl) DebugSql() (query string, err error) { func (i *insertStatementImpl) accept(visitor visitor) {
return debugSql(i) visitor.visit(i)
i.table.accept(visitor)
} }
func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) { func (i *insertStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
queryData := &sqlBuilder{} return debugSql(i, dialect...)
}
queryData.newLine() func (i *insertStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
queryData.writeString("INSERT INTO") out := &sqlBuilder{
dialect: detectDialect(i, dialect...),
}
out.newLine()
out.writeString("INSERT INTO")
if utils.IsNil(i.table) { if utils.IsNil(i.table) {
return "", nil, errors.New("jet: table is nil") return "", nil, errors.New("jet: table is nil")
} }
err = i.table.serialize(insertStatement, queryData) err = i.table.serialize(insertStatement, out)
if err != nil { if err != nil {
return return
} }
if len(i.columns) > 0 { if len(i.columns) > 0 {
queryData.writeString("(") out.writeString("(")
err = serializeColumnNames(i.columns, queryData) err = serializeColumnNames(i.columns, out)
if err != nil { if err != nil {
return return
} }
queryData.writeString(")") out.writeString(")")
} }
if len(i.rows) == 0 && i.query == nil { if len(i.rows) == 0 && i.query == nil {
@ -114,41 +122,41 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
} }
if len(i.rows) > 0 { if len(i.rows) > 0 {
queryData.writeString("VALUES") out.writeString("VALUES")
for rowIndex, row := range i.rows { for rowIndex, row := range i.rows {
if rowIndex > 0 { if rowIndex > 0 {
queryData.writeString(",") out.writeString(",")
} }
queryData.increaseIdent() out.increaseIdent()
queryData.newLine() out.newLine()
queryData.writeString("(") out.writeString("(")
err = serializeClauseList(insertStatement, row, queryData) err = serializeClauseList(insertStatement, row, out)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
queryData.writeByte(')') out.writeByte(')')
queryData.decreaseIdent() out.decreaseIdent()
} }
} }
if i.query != nil { if i.query != nil {
err = i.query.serialize(insertStatement, queryData) err = i.query.serialize(insertStatement, out)
if err != nil { if err != nil {
return return
} }
} }
if err = queryData.writeReturning(insertStatement, i.returning); err != nil { if err = out.writeReturning(insertStatement, i.returning); err != nil {
return return
} }
sql, args = queryData.finalize() query, args = out.finalize()
return return
} }

View file

@ -5,6 +5,8 @@ import "fmt"
// Representation of an escaped literal // Representation of an escaped literal
type literalExpression struct { type literalExpression struct {
expressionInterfaceImpl expressionInterfaceImpl
noOpVisitorImpl
value interface{} value interface{}
constant bool constant bool
} }
@ -188,6 +190,7 @@ func Date(year, month, day int) DateExpression {
//--------------------------------------------------// //--------------------------------------------------//
type nullLiteral struct { type nullLiteral struct {
expressionInterfaceImpl expressionInterfaceImpl
noOpVisitorImpl
} }
func newNullLiteral() Expression { func newNullLiteral() Expression {
@ -206,6 +209,7 @@ func (n *nullLiteral) serialize(statement statementType, out *sqlBuilder, option
//--------------------------------------------------// //--------------------------------------------------//
type starLiteral struct { type starLiteral struct {
expressionInterfaceImpl expressionInterfaceImpl
noOpVisitorImpl
} }
func newStarLiteral() Expression { func newStarLiteral() Expression {
@ -228,6 +232,12 @@ type wrap struct {
expressions []Expression expressions []Expression
} }
func (n *wrap) accept(visitor visitor) {
for _, exp := range n.expressions {
exp.accept(visitor)
}
}
func (n *wrap) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (n *wrap) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
out.writeString("(") out.writeString("(")
err := serializeExpressionList(statement, n.expressions, ", ", out) err := serializeExpressionList(statement, n.expressions, ", ", out)
@ -247,6 +257,8 @@ func WRAP(expression ...Expression) Expression {
type rawExpression struct { type rawExpression struct {
expressionInterfaceImpl expressionInterfaceImpl
noOpVisitorImpl
raw string raw string
} }

View file

@ -53,11 +53,19 @@ func (l *lockStatementImpl) NOWAIT() LockStatement {
return l return l
} }
func (l *lockStatementImpl) DebugSql() (query string, err error) { func (l *lockStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
return debugSql(l) return debugSql(l, dialect...)
} }
func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error) { func (l *lockStatementImpl) accept(visitor visitor) {
visitor.visit(l)
for _, table := range l.tables {
table.accept(visitor)
}
}
func (l *lockStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
if l == nil { if l == nil {
return "", nil, errors.New("jet: nil Statement") return "", nil, errors.New("jet: nil Statement")
} }
@ -66,7 +74,9 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error)
return "", nil, errors.New("jet: There is no table selected to be locked") return "", nil, errors.New("jet: There is no table selected to be locked")
} }
out := &sqlBuilder{} out := &sqlBuilder{
dialect: detectDialect(l, dialect...),
}
out.newLine() out.newLine()
out.writeString("LOCK TABLE") out.writeString("LOCK TABLE")

View file

@ -108,6 +108,24 @@ func (c *caseOperatorImpl) ELSE(els Expression) CaseOperator {
return c return c
} }
func (c *caseOperatorImpl) accept(visitor visitor) {
visitor.visit(c)
c.expression.accept(visitor)
for _, when := range c.when {
when.accept(visitor)
}
for _, then := range c.then {
then.accept(visitor)
}
if c.els != nil {
c.els.accept(visitor)
}
}
func (c *caseOperatorImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (c *caseOperatorImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if c == nil { if c == nil {
return errors.New("jet: Case Expression is nil. ") return errors.New("jet: Case Expression is nil. ")

View file

@ -18,7 +18,7 @@ var (
// SelectStatement is interface for SQL SELECT statements // SelectStatement is interface for SQL SELECT statements
type SelectStatement interface { type SelectStatement interface {
Statement Statement
Expression expression
DISTINCT() SelectStatement DISTINCT() SelectStatement
FROM(table ReadableTable) SelectStatement FROM(table ReadableTable) SelectStatement
@ -261,10 +261,29 @@ func (s *selectStatementImpl) serializeImpl(out *sqlBuilder) error {
return nil return nil
} }
func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error) { func (s *selectStatementImpl) accept(visitor visitor) {
queryData := sqlBuilder{} visitor.visit(s)
err = s.serializeImpl(&queryData) if s.table != nil {
s.table.accept(visitor)
}
if s.where != nil {
s.where.accept(visitor)
}
if s.having != nil {
s.having.accept(visitor)
}
}
func (s *selectStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
queryData := &sqlBuilder{
dialect: detectDialect(s, dialect...),
}
err = s.serializeImpl(queryData)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
@ -275,8 +294,8 @@ func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error
return return
} }
func (s *selectStatementImpl) DebugSql() (query string, err error) { func (s *selectStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
return debugSql(s.parent) return debugSql(s.parent, dialect...)
} }
func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) error { func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) error {

View file

@ -41,6 +41,15 @@ func (s *selectTableImpl) columns() []column {
return nil return nil
} }
func (s *selectTableImpl) accept(visitor visitor) {
visitor.visit(s)
s.selectStmt.accept(visitor)
}
func (s *selectTableImpl) dialect() Dialect {
return detectDialect(s.selectStmt)
}
func (s *selectTableImpl) AllColumns() ProjectionList { func (s *selectTableImpl) AllColumns() ProjectionList {
return s.projections return s.projections
} }

View file

@ -74,6 +74,14 @@ func newSetStatementImpl(operator string, all bool, selects []SelectStatement) S
return setStatement return setStatement
} }
func (s *setStatementImpl) accept(visitor visitor) {
visitor.visit(s)
for _, selects := range s.selects {
selects.accept(visitor)
}
}
func (s *setStatementImpl) projections() []projection { func (s *setStatementImpl) projections() []projection {
if len(s.selects) > 0 { if len(s.selects) > 0 {
return s.selects[0].projections() return s.selects[0].projections()
@ -169,8 +177,10 @@ func (s *setStatementImpl) serializeImpl(out *sqlBuilder) error {
return nil return nil
} }
func (s *setStatementImpl) Sql() (query string, args []interface{}, err error) { func (s *setStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
queryData := &sqlBuilder{} queryData := &sqlBuilder{
dialect: detectDialect(s, dialect...),
}
err = s.serializeImpl(queryData) err = s.serializeImpl(queryData)

View file

@ -4,19 +4,19 @@ import (
"context" "context"
"database/sql" "database/sql"
"github.com/go-jet/jet/execution" "github.com/go-jet/jet/execution"
"strconv"
"strings" "strings"
) )
//Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK) //Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
type Statement interface { type Statement interface {
acceptsVisitor
// Sql returns parametrized sql query with list of arguments. // Sql returns parametrized sql query with list of arguments.
// err is returned if statement is not composed correctly // err is returned if statement is not composed correctly
Sql() (query string, args []interface{}, err error) Sql(dialect ...Dialect) (query string, args []interface{}, err error)
// DebugSql returns debug query where every parametrized placeholder is replaced with its argument. // DebugSql returns debug query where every parametrized placeholder is replaced with its argument.
// Do not use it in production. Use it only for debug purposes. // Do not use it in production. Use it only for debug purposes.
// err is returned if statement is not composed correctly // err is returned if statement is not composed correctly
DebugSql() (query string, err error) DebugSql(dialect ...Dialect) (query string, err error)
// Query executes statement over database connection db and stores row result in destination. // Query executes statement over database connection db and stores row result in destination.
// Destination can be arbitrary structure // Destination can be arbitrary structure
@ -31,7 +31,8 @@ type Statement interface {
ExecContext(context context.Context, db execution.DB) (sql.Result, error) ExecContext(context context.Context, db execution.DB) (sql.Result, error)
} }
func debugSql(statement Statement) (string, error) { func debugSql(statement Statement, overrideDialect ...Dialect) (string, error) {
dialect := detectDialect(statement, overrideDialect...)
sqlQuery, args, err := statement.Sql() sqlQuery, args, err := statement.Sql()
if err != nil { if err != nil {
@ -41,7 +42,7 @@ func debugSql(statement Statement) (string, error) {
debugSQLQuery := sqlQuery debugSQLQuery := sqlQuery
for i, arg := range args { for i, arg := range args {
argPlaceholder := "$" + strconv.Itoa(i+1) argPlaceholder := dialect.ArgumentPlaceholder(i + 1)
debugSQLQuery = strings.Replace(debugSQLQuery, argPlaceholder, argToString(arg), 1) debugSQLQuery = strings.Replace(debugSQLQuery, argPlaceholder, argToString(arg), 1)
} }

View file

@ -6,6 +6,7 @@ import (
) )
type table interface { type table interface {
dialect() Dialect
columns() []column columns() []column
} }
@ -42,6 +43,7 @@ type ReadableTable interface {
table table
readableTable readableTable
clause clause
acceptsVisitor
} }
// WritableTable interface // WritableTable interface
@ -49,6 +51,7 @@ type WritableTable interface {
table table
writableTable writableTable
clause clause
acceptsVisitor
} }
// Table interface // Table interface
@ -57,6 +60,8 @@ type Table interface {
readableTable readableTable
writableTable writableTable
clause clause
acceptsVisitor
SchemaName() string SchemaName() string
TableName() string TableName() string
AS(alias string) AS(alias string)
@ -114,10 +119,11 @@ func (w *writableTableInterfaceImpl) LOCK() LockStatement {
return LOCK(w.parent) return LOCK(w.parent)
} }
// NewTable creates new table with schema name, table name and list of columns // NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, columns ...Column) Table { func NewTable(Dialect Dialect, schemaName, name string, columns ...Column) Table {
t := &tableImpl{ t := &tableImpl{
Dialect: Dialect,
schemaName: schemaName, schemaName: schemaName,
name: name, name: name,
columnList: columns, columnList: columns,
@ -136,6 +142,7 @@ type tableImpl struct {
readableTableInterfaceImpl readableTableInterfaceImpl
writableTableInterfaceImpl writableTableInterfaceImpl
Dialect Dialect
schemaName string schemaName string
name string name string
alias string alias string
@ -168,6 +175,14 @@ func (t *tableImpl) columns() []column {
return ret return ret
} }
func (t *tableImpl) dialect() Dialect {
return t.Dialect
}
func (t *tableImpl) accept(visitor visitor) {
visitor.visit(t)
}
func (t *tableImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error { func (t *tableImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
if t == nil { if t == nil {
return errors.New("jet: tableImpl is nil. ") return errors.New("jet: tableImpl is nil. ")
@ -235,6 +250,15 @@ func (t *joinTable) columns() []column {
return append(t.lhs.columns(), t.rhs.columns()...) return append(t.lhs.columns(), t.rhs.columns()...)
} }
func (t *joinTable) accept(visitor visitor) {
t.lhs.accept(visitor)
t.rhs.accept(visitor)
}
func (t *joinTable) dialect() Dialect {
return detectDialect(t)
}
func (t *joinTable) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) (err error) { func (t *joinTable) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) (err error) {
if t == nil { if t == nil {
return errors.New("jet: Join table is nil. ") return errors.New("jet: Join table is nil. ")

View file

@ -57,8 +57,15 @@ func (u *updateStatementImpl) RETURNING(projections ...projection) UpdateStateme
return u return u
} }
func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) { func (u *updateStatementImpl) accept(visitor visitor) {
out := &sqlBuilder{} visitor.visit(u)
u.table.accept(visitor)
}
func (u *updateStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
out := &sqlBuilder{
dialect: detectDialect(u, dialect...),
}
out.newLine() out.newLine()
out.writeString("UPDATE") out.writeString("UPDATE")
@ -124,12 +131,12 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
return return
} }
sql, args = out.finalize() query, args = out.finalize()
return return
} }
func (u *updateStatementImpl) DebugSql() (query string, err error) { func (u *updateStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
return debugSql(u) return debugSql(u, dialect...)
} }
func (u *updateStatementImpl) Query(db execution.DB, destination interface{}) error { func (u *updateStatementImpl) Query(db execution.DB, destination interface{}) error {

View file

@ -17,6 +17,7 @@ var table1ColBool = BoolColumn("col_bool")
var table1ColDate = DateColumn("col_date") var table1ColDate = DateColumn("col_date")
var table1 = NewTable( var table1 = NewTable(
PostgreSQL,
"db", "db",
"table1", "table1",
table1Col1, table1Col1,
@ -44,6 +45,7 @@ var table2ColTimestampz = TimestampzColumn("col_timestampz")
var table2ColDate = DateColumn("col_date") var table2ColDate = DateColumn("col_date")
var table2 = NewTable( var table2 = NewTable(
PostgreSQL,
"db", "db",
"table2", "table2",
table2Col3, table2Col3,
@ -63,6 +65,7 @@ var table3Col1 = IntegerColumn("col1")
var table3ColInt = IntegerColumn("col_int") var table3ColInt = IntegerColumn("col_int")
var table3StrCol = StringColumn("col2") var table3StrCol = StringColumn("col2")
var table3 = NewTable( var table3 = NewTable(
PostgreSQL,
"db", "db",
"table3", "table3",
table3Col1, table3Col1,
@ -70,7 +73,7 @@ var table3 = NewTable(
table3StrCol) table3StrCol)
func assertClauseSerialize(t *testing.T, clause clause, query string, args ...interface{}) { func assertClauseSerialize(t *testing.T, clause clause, query string, args ...interface{}) {
out := sqlBuilder{} out := sqlBuilder{dialect: PostgreSQL}
err := clause.serialize(selectStatement, &out) err := clause.serialize(selectStatement, &out)
assert.NilError(t, err) assert.NilError(t, err)
@ -80,7 +83,7 @@ func assertClauseSerialize(t *testing.T, clause clause, query string, args ...in
} }
func assertClauseSerializeErr(t *testing.T, clause clause, errString string) { func assertClauseSerializeErr(t *testing.T, clause clause, errString string) {
out := sqlBuilder{} out := sqlBuilder{dialect: PostgreSQL}
err := clause.serialize(selectStatement, &out) err := clause.serialize(selectStatement, &out)
//fmt.Println(out.buff.String()) //fmt.Println(out.buff.String())
@ -89,7 +92,7 @@ func assertClauseSerializeErr(t *testing.T, clause clause, errString string) {
} }
func assertProjectionSerialize(t *testing.T, projection projection, query string, args ...interface{}) { func assertProjectionSerialize(t *testing.T, projection projection, query string, args ...interface{}) {
out := sqlBuilder{} out := sqlBuilder{dialect: PostgreSQL}
err := projection.serializeForProjection(selectStatement, &out) err := projection.serializeForProjection(selectStatement, &out)
assert.NilError(t, err) assert.NilError(t, err)
@ -99,7 +102,7 @@ func assertProjectionSerialize(t *testing.T, projection projection, query string
} }
func assertStatement(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) { func assertStatement(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) {
queryStr, args, err := query.Sql() queryStr, args, err := query.Sql(PostgreSQL)
assert.NilError(t, err) assert.NilError(t, err)
//fmt.Println(queryStr) //fmt.Println(queryStr)
@ -108,7 +111,7 @@ func assertStatement(t *testing.T, query Statement, expectedQuery string, expect
} }
func assertStatementErr(t *testing.T, stmt Statement, errorStr string) { func assertStatementErr(t *testing.T, stmt Statement, errorStr string) {
_, _, err := stmt.Sql() _, _, err := stmt.Sql(PostgreSQL)
assert.Assert(t, err != nil) assert.Assert(t, err != nil)
assert.Error(t, err, errorStr) assert.Error(t, err, errorStr)

63
visitor.go Normal file
View file

@ -0,0 +1,63 @@
package jet
type visitor interface {
visit(element acceptsVisitor)
}
type acceptsVisitor interface {
accept(visitor visitor)
}
type noOpVisitorImpl struct {
}
func (n *noOpVisitorImpl) accept(visitor visitor) {
// NO OP
}
// --------------- dialect finder -----------------//
type DialectFinder struct {
dialects map[string]Dialect
}
func newDialectFinder() *DialectFinder {
return &DialectFinder{
dialects: make(map[string]Dialect),
}
}
func (f *DialectFinder) dialect() Dialect {
if len(f.dialects) == 0 {
panic("jet: can't detect dialect")
}
if len(f.dialects) > 1 {
panic("jet: more than one dialect detected")
}
for _, dialect := range f.dialects {
return dialect
}
panic("jet: internal error")
}
func (f *DialectFinder) visit(element acceptsVisitor) {
if table, ok := element.(table); ok {
dialect := table.dialect()
f.dialects[dialect.Name] = dialect
}
}
func detectDialect(element acceptsVisitor, dialectOverride ...Dialect) Dialect {
if len(dialectOverride) > 0 {
return dialectOverride[0]
}
dialectFinder := newDialectFinder()
element.accept(dialectFinder)
return dialectFinder.dialect()
}