Generic dialect support. (MySQL and Postgres)
This commit is contained in:
parent
043a0dc4c0
commit
5dda5e1e11
27 changed files with 440 additions and 92 deletions
2
alias.go
2
alias.go
|
|
@ -28,7 +28,7 @@ func (a *alias) serializeForProjection(statement statementType, out *sqlBuilder)
|
|||
}
|
||||
|
||||
out.writeString("AS")
|
||||
out.writeQuotedString(a.alias)
|
||||
out.writeAlias(a.alias)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -93,13 +93,13 @@ type binaryBoolExpression struct {
|
|||
}
|
||||
|
||||
func newBinaryBoolOperator(lhs, rhs Expression, operator string) BoolExpression {
|
||||
boolExpression := binaryBoolExpression{}
|
||||
binaryBoolExpression := binaryBoolExpression{}
|
||||
|
||||
boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
|
||||
boolExpression.expressionInterfaceImpl.parent = &boolExpression
|
||||
boolExpression.boolInterfaceImpl.parent = &boolExpression
|
||||
binaryBoolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
|
||||
binaryBoolExpression.expressionInterfaceImpl.parent = &binaryBoolExpression
|
||||
binaryBoolExpression.boolInterfaceImpl.parent = &binaryBoolExpression
|
||||
|
||||
return &boolExpression
|
||||
return &binaryBoolExpression
|
||||
}
|
||||
|
||||
//---------------------------------------------------//
|
||||
|
|
|
|||
22
cast.go
22
cast.go
|
|
@ -44,9 +44,27 @@ func CAST(expression Expression) cast {
|
|||
}
|
||||
}
|
||||
|
||||
func (b *castImpl) accept(visitor visitor) {
|
||||
visitor.visit(b)
|
||||
|
||||
b.Expression.accept(visitor)
|
||||
}
|
||||
|
||||
func (b *castImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||
err := b.Expression.serialize(statement, out, options...)
|
||||
out.writeString("::" + b.castType)
|
||||
|
||||
if castOverride := out.dialect.CastOverride; castOverride != nil {
|
||||
return castOverride(b.Expression, b.castType)(statement, out, options...)
|
||||
}
|
||||
|
||||
out.writeString("CAST")
|
||||
err := WRAP(b.Expression).serialize(statement, out, options...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
out.writeString("AS")
|
||||
out.writeString(b.castType)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
|||
11
clause.go
11
clause.go
|
|
@ -30,6 +30,7 @@ func contains(options []serializeOption, option serializeOption) bool {
|
|||
}
|
||||
|
||||
type sqlBuilder struct {
|
||||
dialect Dialect
|
||||
buff bytes.Buffer
|
||||
args []interface{}
|
||||
|
||||
|
|
@ -162,8 +163,9 @@ func isPostSeparator(b byte) bool {
|
|||
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
|
||||
}
|
||||
|
||||
func (q *sqlBuilder) writeQuotedString(str string) {
|
||||
q.writeString(`"` + str + `"`)
|
||||
func (q *sqlBuilder) writeAlias(str string) {
|
||||
aliasQuoteChar := string(q.dialect.AliasQuoteChar)
|
||||
q.writeString(aliasQuoteChar + str + aliasQuoteChar)
|
||||
}
|
||||
|
||||
func (q *sqlBuilder) writeString(str string) {
|
||||
|
|
@ -174,7 +176,8 @@ func (q *sqlBuilder) writeIdentifier(name string) {
|
|||
quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -")
|
||||
|
||||
if quoteWrap {
|
||||
q.writeString(`"` + name + `"`)
|
||||
identQuoteChar := string(q.dialect.IdentifierQuoteChar)
|
||||
q.writeString(identQuoteChar + name + identQuoteChar)
|
||||
} else {
|
||||
q.writeString(name)
|
||||
}
|
||||
|
|
@ -194,7 +197,7 @@ func (q *sqlBuilder) insertConstantArgument(arg interface{}) {
|
|||
|
||||
func (q *sqlBuilder) insertParametrizedArgument(arg interface{}) {
|
||||
q.args = append(q.args, arg)
|
||||
argPlaceholder := "$" + strconv.Itoa(len(q.args))
|
||||
argPlaceholder := q.dialect.ArgumentPlaceholder(len(q.args))
|
||||
|
||||
q.writeString(argPlaceholder)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ var (
|
|||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&source, "source", jet.PostgreSQL, "Database name")
|
||||
flag.StringVar(&source, "source", string(jet.PostgreSQL.Name), "Database name")
|
||||
|
||||
flag.StringVar(&host, "host", "", "Database host path (Example: localhost)")
|
||||
flag.IntVar(&port, "port", 0, "Database port")
|
||||
|
|
@ -72,7 +72,7 @@ Usage of jet:
|
|||
var err error
|
||||
|
||||
switch source {
|
||||
case jet.PostgreSQL:
|
||||
case jet.PostgreSQL.Name:
|
||||
if host == "" || port == 0 || user == "" || dbName == "" || schemaName == "" {
|
||||
fmt.Println("\njet: required flag missing")
|
||||
flag.Usage()
|
||||
|
|
@ -93,7 +93,7 @@ Usage of jet:
|
|||
|
||||
err = postgres.Generate(destDir, genData)
|
||||
|
||||
case jet.MySql:
|
||||
case jet.MySql.Name:
|
||||
if host == "" || port == 0 || user == "" || dbName == "" {
|
||||
fmt.Println("\njet: required flag missing")
|
||||
flag.Usage()
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ type Column interface {
|
|||
// The base type for real materialized columns.
|
||||
type columnImpl struct {
|
||||
expressionInterfaceImpl
|
||||
noOpVisitorImpl
|
||||
|
||||
name string
|
||||
tableName string
|
||||
|
|
@ -65,7 +66,7 @@ func (c *columnImpl) defaultAlias() string {
|
|||
func (c *columnImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error {
|
||||
if statement == setStatement {
|
||||
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
|
||||
out.writeString(`"` + c.defaultAlias() + `"`) //always quote
|
||||
out.writeAlias(c.defaultAlias()) //always quote
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -80,7 +81,8 @@ func (c columnImpl) serializeForProjection(statement statementType, out *sqlBuil
|
|||
return err
|
||||
}
|
||||
|
||||
out.writeString(`AS "` + c.defaultAlias() + `"`)
|
||||
out.writeString("AS")
|
||||
out.writeAlias(c.defaultAlias())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -90,7 +92,7 @@ func (c columnImpl) serialize(statement statementType, out *sqlBuilder, options
|
|||
if c.subQuery != nil {
|
||||
out.writeIdentifier(c.subQuery.Alias())
|
||||
out.writeByte('.')
|
||||
out.writeQuotedString(c.defaultAlias())
|
||||
out.writeAlias(c.defaultAlias())
|
||||
} else {
|
||||
if c.tableName != "" {
|
||||
out.writeIdentifier(c.tableName)
|
||||
|
|
|
|||
|
|
@ -38,6 +38,12 @@ func (d *deleteStatementImpl) RETURNING(projections ...projection) DeleteStateme
|
|||
return d
|
||||
}
|
||||
|
||||
func (d *deleteStatementImpl) accept(visitor visitor) {
|
||||
visitor.visit(d)
|
||||
|
||||
d.table.accept(visitor)
|
||||
}
|
||||
|
||||
func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error {
|
||||
if d == nil {
|
||||
return errors.New("jet: delete statement is nil")
|
||||
|
|
@ -68,8 +74,10 @@ func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) {
|
||||
queryData := &sqlBuilder{}
|
||||
func (d *deleteStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
|
||||
queryData := &sqlBuilder{
|
||||
dialect: detectDialect(d, dialect...),
|
||||
}
|
||||
|
||||
err = d.serializeImpl(queryData)
|
||||
|
||||
|
|
@ -81,8 +89,8 @@ func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error
|
|||
return
|
||||
}
|
||||
|
||||
func (d *deleteStatementImpl) DebugSql() (query string, err error) {
|
||||
return debugSql(d)
|
||||
func (d *deleteStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
|
||||
return debugSql(d, dialect...)
|
||||
}
|
||||
|
||||
func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) error {
|
||||
|
|
|
|||
97
dialects.go
Normal file
97
dialects.go
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
package jet
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var (
|
||||
PostgreSQL = newPostgresDialect()
|
||||
MySql = newMySQLDialect()
|
||||
)
|
||||
|
||||
func newPostgresDialect() Dialect {
|
||||
postgresDialect := newDialect("PostgreSQL", "postgres")
|
||||
|
||||
postgresDialect.OperatorOverrides["IS DISTINCT FROM"] = postgresIS_DISTINCT_FROM
|
||||
postgresDialect.CastOverride = postgresCAST
|
||||
postgresDialect.AliasQuoteChar = '"'
|
||||
postgresDialect.IdentifierQuoteChar = '"'
|
||||
postgresDialect.ArgumentPlaceholder = func(ord int) string {
|
||||
return "$" + strconv.Itoa(ord)
|
||||
}
|
||||
|
||||
return postgresDialect
|
||||
}
|
||||
|
||||
func newMySQLDialect() Dialect {
|
||||
mySQLDialect := newDialect("MySQL", "mysql")
|
||||
|
||||
mySQLDialect.AliasQuoteChar = '"'
|
||||
mySQLDialect.IdentifierQuoteChar = '"'
|
||||
mySQLDialect.ArgumentPlaceholder = func(int) string {
|
||||
return "?"
|
||||
}
|
||||
|
||||
return mySQLDialect
|
||||
}
|
||||
|
||||
type Dialect struct {
|
||||
Name string
|
||||
PackageName string
|
||||
OperatorOverrides map[string]serializeOverride
|
||||
CastOverride castOverride
|
||||
AliasQuoteChar byte
|
||||
IdentifierQuoteChar byte
|
||||
ArgumentPlaceholder queryPlaceholderFunc
|
||||
}
|
||||
|
||||
func (d *Dialect) serializeOverride(operator string) serializeOverride {
|
||||
return d.OperatorOverrides[operator]
|
||||
}
|
||||
|
||||
type queryPlaceholderFunc func(ord int) string
|
||||
|
||||
func newDialect(name, packageName string) Dialect {
|
||||
newDialect := Dialect{
|
||||
Name: name,
|
||||
PackageName: packageName,
|
||||
}
|
||||
newDialect.OperatorOverrides = make(map[string]serializeOverride)
|
||||
|
||||
return newDialect
|
||||
}
|
||||
|
||||
func postgresIS_DISTINCT_FROM(expressions ...Expression) serializeFunc {
|
||||
return func(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||
if len(expressions) != 2 {
|
||||
return errors.New("Invalid number of expressions for operator")
|
||||
}
|
||||
if err := expressions[0].serialize(statement, out); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
out.writeString("IS DISTINCT FROM")
|
||||
|
||||
if err := expressions[1].serialize(statement, out); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func postgresCAST(expression Expression, castType string) serializeFunc {
|
||||
return func(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||
if err := expression.serialize(statement, out, options...); err != nil {
|
||||
return err
|
||||
}
|
||||
out.writeString("::" + castType)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type serializeFunc func(statement statementType, out *sqlBuilder, options ...serializeOption) error
|
||||
type serializeOverride func(expressions ...Expression) serializeFunc
|
||||
|
||||
type castOverride func(expression Expression, castType string) serializeFunc
|
||||
|
|
@ -3,6 +3,8 @@ package jet
|
|||
type enumValue struct {
|
||||
expressionInterfaceImpl
|
||||
stringInterfaceImpl
|
||||
noOpVisitorImpl
|
||||
|
||||
name string
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,12 @@ import (
|
|||
// Expression is common interface for all expressions.
|
||||
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
|
||||
type Expression interface {
|
||||
acceptsVisitor
|
||||
|
||||
expression
|
||||
}
|
||||
|
||||
type expression interface {
|
||||
clause
|
||||
projection
|
||||
groupByClause
|
||||
|
|
@ -95,7 +101,12 @@ func newBinaryExpression(lhs, rhs Expression, operator string) binaryOpExpressio
|
|||
return binaryExpression
|
||||
}
|
||||
|
||||
func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||
func (c *binaryOpExpression) accept(visitor visitor) {
|
||||
c.lhs.accept(visitor)
|
||||
c.rhs.accept(visitor)
|
||||
}
|
||||
|
||||
func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) (err error) {
|
||||
if c == nil {
|
||||
return errors.New("jet: binary Expression is nil")
|
||||
}
|
||||
|
|
@ -112,6 +123,9 @@ func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder,
|
|||
out.writeString("(")
|
||||
}
|
||||
|
||||
if dialectOveride := out.dialect.serializeOverride(c.operator); dialectOveride != nil {
|
||||
err = dialectOveride(c.lhs, c.rhs)(statement, out, options...)
|
||||
} else {
|
||||
if err := c.lhs.serialize(statement, out); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -121,12 +135,13 @@ func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder,
|
|||
if err := c.rhs.serialize(statement, out); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if wrap {
|
||||
out.writeString(")")
|
||||
}
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// A prefix operator Expression
|
||||
|
|
@ -144,6 +159,10 @@ func newPrefixExpression(expression Expression, operator string) prefixOpExpress
|
|||
return prefixExpression
|
||||
}
|
||||
|
||||
func (p *prefixOpExpression) accept(visitor visitor) {
|
||||
p.expression.accept(visitor)
|
||||
}
|
||||
|
||||
func (p *prefixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||
if p == nil {
|
||||
return errors.New("jet: Prefix Expression is nil")
|
||||
|
|
@ -176,6 +195,10 @@ func newPostfixOpExpression(expression Expression, operator string) postfixOpExp
|
|||
return postfixOpExpression
|
||||
}
|
||||
|
||||
func (p *postfixOpExpression) accept(visitor visitor) {
|
||||
p.expression.accept(visitor)
|
||||
}
|
||||
|
||||
func (p *postfixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||
if p == nil {
|
||||
return errors.New("jet: Postifx operator Expression is nil")
|
||||
|
|
|
|||
|
|
@ -485,6 +485,14 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr
|
|||
return funcExp
|
||||
}
|
||||
|
||||
func (f *funcExpressionImpl) accept(visitor visitor) {
|
||||
visitor.visit(f)
|
||||
|
||||
for _, exp := range f.expressions {
|
||||
exp.accept(visitor)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *funcExpressionImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||
if f == nil {
|
||||
return errors.New("jet: Function expressions is nil. ")
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package template
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/go-jet/jet"
|
||||
"github.com/go-jet/jet/generator/internal/metadata"
|
||||
"github.com/go-jet/jet/internal/utils"
|
||||
"path/filepath"
|
||||
|
|
@ -10,7 +11,7 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect string) error {
|
||||
func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect jet.Dialect) error {
|
||||
if len(tables) == 0 && len(enums) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
|
@ -59,7 +60,7 @@ func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect st
|
|||
|
||||
}
|
||||
|
||||
func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect string) error {
|
||||
func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) error {
|
||||
modelDirPath := filepath.Join(dirPath, packageName)
|
||||
|
||||
err := utils.EnsureDirPath(modelDirPath)
|
||||
|
|
@ -92,14 +93,14 @@ func generate(dirPath, packageName string, template string, metaDataList []metad
|
|||
}
|
||||
|
||||
// GenerateTemplate generates template with template text and template data.
|
||||
func GenerateTemplate(templateText string, templateData interface{}, dialect string) ([]byte, error) {
|
||||
func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect) ([]byte, error) {
|
||||
|
||||
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
|
||||
"ToGoIdentifier": utils.ToGoIdentifier,
|
||||
"now": func() string {
|
||||
return time.Now().Format(time.RFC850)
|
||||
},
|
||||
"dialect": func() string {
|
||||
"dialect": func() jet.Dialect {
|
||||
return dialect
|
||||
},
|
||||
}).Parse(templateText)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ package table
|
|||
|
||||
import (
|
||||
"github.com/go-jet/jet"
|
||||
"github.com/go-jet/jet/{{dialect}}"
|
||||
"github.com/go-jet/jet/{{dialect.PackageName}}"
|
||||
)
|
||||
|
||||
var {{ToGoIdentifier .Name}} = new{{.GoStructName}}()
|
||||
|
|
@ -32,7 +32,7 @@ type {{.GoStructName}} struct {
|
|||
|
||||
//Columns
|
||||
{{- range .Columns}}
|
||||
{{ToGoIdentifier .Name}} {{dialect}}.Column{{.SqlBuilderColumnType}}
|
||||
{{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}}
|
||||
{{- end}}
|
||||
|
||||
AllColumns jet.ColumnList
|
||||
|
|
@ -51,12 +51,12 @@ func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} {
|
|||
func new{{.GoStructName}}() *{{.GoStructName}} {
|
||||
var (
|
||||
{{- range .Columns}}
|
||||
{{ToGoIdentifier .Name}}Column = {{dialect}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
|
||||
{{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
|
||||
{{- end}}
|
||||
)
|
||||
|
||||
return &{{.GoStructName}}{
|
||||
Table: jet.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}),
|
||||
Table: jet.NewTable(jet.{{dialect.Name}}, "{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}),
|
||||
|
||||
//Columns
|
||||
{{- range .Columns}}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package mysql
|
|||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/go-jet/jet"
|
||||
"github.com/go-jet/jet/generator/internal/metadata"
|
||||
"github.com/go-jet/jet/generator/internal/template"
|
||||
"path"
|
||||
|
|
@ -37,7 +38,7 @@ func Generate(destDir string, dbConn DBConnection) error {
|
|||
|
||||
genPath := path.Join(destDir, dbConn.DBName)
|
||||
|
||||
err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, "mysql")
|
||||
err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, jet.MySql)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package postgres
|
|||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/go-jet/jet"
|
||||
"github.com/go-jet/jet/generator/internal/metadata"
|
||||
"github.com/go-jet/jet/generator/internal/template"
|
||||
"path"
|
||||
|
|
@ -41,7 +42,7 @@ func Generate(destDir string, dbConn DBConnection) error {
|
|||
|
||||
genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName)
|
||||
|
||||
err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, "postgres")
|
||||
err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, jet.PostgreSQL)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -73,36 +73,44 @@ func (i *insertStatementImpl) getColumns() []column {
|
|||
return i.table.columns()
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) DebugSql() (query string, err error) {
|
||||
return debugSql(i)
|
||||
func (i *insertStatementImpl) accept(visitor visitor) {
|
||||
visitor.visit(i)
|
||||
|
||||
i.table.accept(visitor)
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) {
|
||||
queryData := &sqlBuilder{}
|
||||
func (i *insertStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
|
||||
return debugSql(i, dialect...)
|
||||
}
|
||||
|
||||
queryData.newLine()
|
||||
queryData.writeString("INSERT INTO")
|
||||
func (i *insertStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
|
||||
out := &sqlBuilder{
|
||||
dialect: detectDialect(i, dialect...),
|
||||
}
|
||||
|
||||
out.newLine()
|
||||
out.writeString("INSERT INTO")
|
||||
|
||||
if utils.IsNil(i.table) {
|
||||
return "", nil, errors.New("jet: table is nil")
|
||||
}
|
||||
|
||||
err = i.table.serialize(insertStatement, queryData)
|
||||
err = i.table.serialize(insertStatement, out)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(i.columns) > 0 {
|
||||
queryData.writeString("(")
|
||||
out.writeString("(")
|
||||
|
||||
err = serializeColumnNames(i.columns, queryData)
|
||||
err = serializeColumnNames(i.columns, out)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
queryData.writeString(")")
|
||||
out.writeString(")")
|
||||
}
|
||||
|
||||
if len(i.rows) == 0 && i.query == nil {
|
||||
|
|
@ -114,41 +122,41 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
|
|||
}
|
||||
|
||||
if len(i.rows) > 0 {
|
||||
queryData.writeString("VALUES")
|
||||
out.writeString("VALUES")
|
||||
|
||||
for rowIndex, row := range i.rows {
|
||||
if rowIndex > 0 {
|
||||
queryData.writeString(",")
|
||||
out.writeString(",")
|
||||
}
|
||||
|
||||
queryData.increaseIdent()
|
||||
queryData.newLine()
|
||||
queryData.writeString("(")
|
||||
out.increaseIdent()
|
||||
out.newLine()
|
||||
out.writeString("(")
|
||||
|
||||
err = serializeClauseList(insertStatement, row, queryData)
|
||||
err = serializeClauseList(insertStatement, row, out)
|
||||
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
queryData.writeByte(')')
|
||||
queryData.decreaseIdent()
|
||||
out.writeByte(')')
|
||||
out.decreaseIdent()
|
||||
}
|
||||
}
|
||||
|
||||
if i.query != nil {
|
||||
err = i.query.serialize(insertStatement, queryData)
|
||||
err = i.query.serialize(insertStatement, out)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err = queryData.writeReturning(insertStatement, i.returning); err != nil {
|
||||
if err = out.writeReturning(insertStatement, i.returning); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
sql, args = queryData.finalize()
|
||||
query, args = out.finalize()
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import "fmt"
|
|||
// Representation of an escaped literal
|
||||
type literalExpression struct {
|
||||
expressionInterfaceImpl
|
||||
noOpVisitorImpl
|
||||
|
||||
value interface{}
|
||||
constant bool
|
||||
}
|
||||
|
|
@ -188,6 +190,7 @@ func Date(year, month, day int) DateExpression {
|
|||
//--------------------------------------------------//
|
||||
type nullLiteral struct {
|
||||
expressionInterfaceImpl
|
||||
noOpVisitorImpl
|
||||
}
|
||||
|
||||
func newNullLiteral() Expression {
|
||||
|
|
@ -206,6 +209,7 @@ func (n *nullLiteral) serialize(statement statementType, out *sqlBuilder, option
|
|||
//--------------------------------------------------//
|
||||
type starLiteral struct {
|
||||
expressionInterfaceImpl
|
||||
noOpVisitorImpl
|
||||
}
|
||||
|
||||
func newStarLiteral() Expression {
|
||||
|
|
@ -228,6 +232,12 @@ type wrap struct {
|
|||
expressions []Expression
|
||||
}
|
||||
|
||||
func (n *wrap) accept(visitor visitor) {
|
||||
for _, exp := range n.expressions {
|
||||
exp.accept(visitor)
|
||||
}
|
||||
}
|
||||
|
||||
func (n *wrap) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||
out.writeString("(")
|
||||
err := serializeExpressionList(statement, n.expressions, ", ", out)
|
||||
|
|
@ -247,6 +257,8 @@ func WRAP(expression ...Expression) Expression {
|
|||
|
||||
type rawExpression struct {
|
||||
expressionInterfaceImpl
|
||||
noOpVisitorImpl
|
||||
|
||||
raw string
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -53,11 +53,19 @@ func (l *lockStatementImpl) NOWAIT() LockStatement {
|
|||
return l
|
||||
}
|
||||
|
||||
func (l *lockStatementImpl) DebugSql() (query string, err error) {
|
||||
return debugSql(l)
|
||||
func (l *lockStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
|
||||
return debugSql(l, dialect...)
|
||||
}
|
||||
|
||||
func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error) {
|
||||
func (l *lockStatementImpl) accept(visitor visitor) {
|
||||
visitor.visit(l)
|
||||
|
||||
for _, table := range l.tables {
|
||||
table.accept(visitor)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *lockStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
|
||||
if l == nil {
|
||||
return "", nil, errors.New("jet: nil Statement")
|
||||
}
|
||||
|
|
@ -66,7 +74,9 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error)
|
|||
return "", nil, errors.New("jet: There is no table selected to be locked")
|
||||
}
|
||||
|
||||
out := &sqlBuilder{}
|
||||
out := &sqlBuilder{
|
||||
dialect: detectDialect(l, dialect...),
|
||||
}
|
||||
|
||||
out.newLine()
|
||||
out.writeString("LOCK TABLE")
|
||||
|
|
|
|||
18
operators.go
18
operators.go
|
|
@ -108,6 +108,24 @@ func (c *caseOperatorImpl) ELSE(els Expression) CaseOperator {
|
|||
return c
|
||||
}
|
||||
|
||||
func (c *caseOperatorImpl) accept(visitor visitor) {
|
||||
visitor.visit(c)
|
||||
|
||||
c.expression.accept(visitor)
|
||||
|
||||
for _, when := range c.when {
|
||||
when.accept(visitor)
|
||||
}
|
||||
|
||||
for _, then := range c.then {
|
||||
then.accept(visitor)
|
||||
}
|
||||
|
||||
if c.els != nil {
|
||||
c.els.accept(visitor)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *caseOperatorImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||
if c == nil {
|
||||
return errors.New("jet: Case Expression is nil. ")
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ var (
|
|||
// SelectStatement is interface for SQL SELECT statements
|
||||
type SelectStatement interface {
|
||||
Statement
|
||||
Expression
|
||||
expression
|
||||
|
||||
DISTINCT() SelectStatement
|
||||
FROM(table ReadableTable) SelectStatement
|
||||
|
|
@ -261,10 +261,29 @@ func (s *selectStatementImpl) serializeImpl(out *sqlBuilder) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error) {
|
||||
queryData := sqlBuilder{}
|
||||
func (s *selectStatementImpl) accept(visitor visitor) {
|
||||
visitor.visit(s)
|
||||
|
||||
err = s.serializeImpl(&queryData)
|
||||
if s.table != nil {
|
||||
s.table.accept(visitor)
|
||||
}
|
||||
|
||||
if s.where != nil {
|
||||
s.where.accept(visitor)
|
||||
}
|
||||
|
||||
if s.having != nil {
|
||||
s.having.accept(visitor)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *selectStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
|
||||
|
||||
queryData := &sqlBuilder{
|
||||
dialect: detectDialect(s, dialect...),
|
||||
}
|
||||
|
||||
err = s.serializeImpl(queryData)
|
||||
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
|
|
@ -275,8 +294,8 @@ func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error
|
|||
return
|
||||
}
|
||||
|
||||
func (s *selectStatementImpl) DebugSql() (query string, err error) {
|
||||
return debugSql(s.parent)
|
||||
func (s *selectStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
|
||||
return debugSql(s.parent, dialect...)
|
||||
}
|
||||
|
||||
func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) error {
|
||||
|
|
|
|||
|
|
@ -41,6 +41,15 @@ func (s *selectTableImpl) columns() []column {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *selectTableImpl) accept(visitor visitor) {
|
||||
visitor.visit(s)
|
||||
s.selectStmt.accept(visitor)
|
||||
}
|
||||
|
||||
func (s *selectTableImpl) dialect() Dialect {
|
||||
return detectDialect(s.selectStmt)
|
||||
}
|
||||
|
||||
func (s *selectTableImpl) AllColumns() ProjectionList {
|
||||
return s.projections
|
||||
}
|
||||
|
|
|
|||
|
|
@ -74,6 +74,14 @@ func newSetStatementImpl(operator string, all bool, selects []SelectStatement) S
|
|||
return setStatement
|
||||
}
|
||||
|
||||
func (s *setStatementImpl) accept(visitor visitor) {
|
||||
visitor.visit(s)
|
||||
|
||||
for _, selects := range s.selects {
|
||||
selects.accept(visitor)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *setStatementImpl) projections() []projection {
|
||||
if len(s.selects) > 0 {
|
||||
return s.selects[0].projections()
|
||||
|
|
@ -169,8 +177,10 @@ func (s *setStatementImpl) serializeImpl(out *sqlBuilder) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *setStatementImpl) Sql() (query string, args []interface{}, err error) {
|
||||
queryData := &sqlBuilder{}
|
||||
func (s *setStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
|
||||
queryData := &sqlBuilder{
|
||||
dialect: detectDialect(s, dialect...),
|
||||
}
|
||||
|
||||
err = s.serializeImpl(queryData)
|
||||
|
||||
|
|
|
|||
11
statement.go
11
statement.go
|
|
@ -4,19 +4,19 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"github.com/go-jet/jet/execution"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
//Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
|
||||
type Statement interface {
|
||||
acceptsVisitor
|
||||
// Sql returns parametrized sql query with list of arguments.
|
||||
// err is returned if statement is not composed correctly
|
||||
Sql() (query string, args []interface{}, err error)
|
||||
Sql(dialect ...Dialect) (query string, args []interface{}, err error)
|
||||
// DebugSql returns debug query where every parametrized placeholder is replaced with its argument.
|
||||
// Do not use it in production. Use it only for debug purposes.
|
||||
// err is returned if statement is not composed correctly
|
||||
DebugSql() (query string, err error)
|
||||
DebugSql(dialect ...Dialect) (query string, err error)
|
||||
|
||||
// Query executes statement over database connection db and stores row result in destination.
|
||||
// Destination can be arbitrary structure
|
||||
|
|
@ -31,7 +31,8 @@ type Statement interface {
|
|||
ExecContext(context context.Context, db execution.DB) (sql.Result, error)
|
||||
}
|
||||
|
||||
func debugSql(statement Statement) (string, error) {
|
||||
func debugSql(statement Statement, overrideDialect ...Dialect) (string, error) {
|
||||
dialect := detectDialect(statement, overrideDialect...)
|
||||
sqlQuery, args, err := statement.Sql()
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -41,7 +42,7 @@ func debugSql(statement Statement) (string, error) {
|
|||
debugSQLQuery := sqlQuery
|
||||
|
||||
for i, arg := range args {
|
||||
argPlaceholder := "$" + strconv.Itoa(i+1)
|
||||
argPlaceholder := dialect.ArgumentPlaceholder(i + 1)
|
||||
debugSQLQuery = strings.Replace(debugSQLQuery, argPlaceholder, argToString(arg), 1)
|
||||
}
|
||||
|
||||
|
|
|
|||
28
table.go
28
table.go
|
|
@ -6,6 +6,7 @@ import (
|
|||
)
|
||||
|
||||
type table interface {
|
||||
dialect() Dialect
|
||||
columns() []column
|
||||
}
|
||||
|
||||
|
|
@ -42,6 +43,7 @@ type ReadableTable interface {
|
|||
table
|
||||
readableTable
|
||||
clause
|
||||
acceptsVisitor
|
||||
}
|
||||
|
||||
// WritableTable interface
|
||||
|
|
@ -49,6 +51,7 @@ type WritableTable interface {
|
|||
table
|
||||
writableTable
|
||||
clause
|
||||
acceptsVisitor
|
||||
}
|
||||
|
||||
// Table interface
|
||||
|
|
@ -57,6 +60,8 @@ type Table interface {
|
|||
readableTable
|
||||
writableTable
|
||||
clause
|
||||
acceptsVisitor
|
||||
|
||||
SchemaName() string
|
||||
TableName() string
|
||||
AS(alias string)
|
||||
|
|
@ -114,10 +119,11 @@ func (w *writableTableInterfaceImpl) LOCK() LockStatement {
|
|||
return LOCK(w.parent)
|
||||
}
|
||||
|
||||
// NewTable creates new table with schema name, table name and list of columns
|
||||
func NewTable(schemaName, name string, columns ...Column) Table {
|
||||
// NewTable creates new table with schema Name, table Name and list of columns
|
||||
func NewTable(Dialect Dialect, schemaName, name string, columns ...Column) Table {
|
||||
|
||||
t := &tableImpl{
|
||||
Dialect: Dialect,
|
||||
schemaName: schemaName,
|
||||
name: name,
|
||||
columnList: columns,
|
||||
|
|
@ -136,6 +142,7 @@ type tableImpl struct {
|
|||
readableTableInterfaceImpl
|
||||
writableTableInterfaceImpl
|
||||
|
||||
Dialect Dialect
|
||||
schemaName string
|
||||
name string
|
||||
alias string
|
||||
|
|
@ -168,6 +175,14 @@ func (t *tableImpl) columns() []column {
|
|||
return ret
|
||||
}
|
||||
|
||||
func (t *tableImpl) dialect() Dialect {
|
||||
return t.Dialect
|
||||
}
|
||||
|
||||
func (t *tableImpl) accept(visitor visitor) {
|
||||
visitor.visit(t)
|
||||
}
|
||||
|
||||
func (t *tableImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||
if t == nil {
|
||||
return errors.New("jet: tableImpl is nil. ")
|
||||
|
|
@ -235,6 +250,15 @@ func (t *joinTable) columns() []column {
|
|||
return append(t.lhs.columns(), t.rhs.columns()...)
|
||||
}
|
||||
|
||||
func (t *joinTable) accept(visitor visitor) {
|
||||
t.lhs.accept(visitor)
|
||||
t.rhs.accept(visitor)
|
||||
}
|
||||
|
||||
func (t *joinTable) dialect() Dialect {
|
||||
return detectDialect(t)
|
||||
}
|
||||
|
||||
func (t *joinTable) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) (err error) {
|
||||
if t == nil {
|
||||
return errors.New("jet: Join table is nil. ")
|
||||
|
|
|
|||
|
|
@ -57,8 +57,15 @@ func (u *updateStatementImpl) RETURNING(projections ...projection) UpdateStateme
|
|||
return u
|
||||
}
|
||||
|
||||
func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) {
|
||||
out := &sqlBuilder{}
|
||||
func (u *updateStatementImpl) accept(visitor visitor) {
|
||||
visitor.visit(u)
|
||||
u.table.accept(visitor)
|
||||
}
|
||||
|
||||
func (u *updateStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
|
||||
out := &sqlBuilder{
|
||||
dialect: detectDialect(u, dialect...),
|
||||
}
|
||||
|
||||
out.newLine()
|
||||
out.writeString("UPDATE")
|
||||
|
|
@ -124,12 +131,12 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
|
|||
return
|
||||
}
|
||||
|
||||
sql, args = out.finalize()
|
||||
query, args = out.finalize()
|
||||
return
|
||||
}
|
||||
|
||||
func (u *updateStatementImpl) DebugSql() (query string, err error) {
|
||||
return debugSql(u)
|
||||
func (u *updateStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
|
||||
return debugSql(u, dialect...)
|
||||
}
|
||||
|
||||
func (u *updateStatementImpl) Query(db execution.DB, destination interface{}) error {
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ var table1ColBool = BoolColumn("col_bool")
|
|||
var table1ColDate = DateColumn("col_date")
|
||||
|
||||
var table1 = NewTable(
|
||||
PostgreSQL,
|
||||
"db",
|
||||
"table1",
|
||||
table1Col1,
|
||||
|
|
@ -44,6 +45,7 @@ var table2ColTimestampz = TimestampzColumn("col_timestampz")
|
|||
var table2ColDate = DateColumn("col_date")
|
||||
|
||||
var table2 = NewTable(
|
||||
PostgreSQL,
|
||||
"db",
|
||||
"table2",
|
||||
table2Col3,
|
||||
|
|
@ -63,6 +65,7 @@ var table3Col1 = IntegerColumn("col1")
|
|||
var table3ColInt = IntegerColumn("col_int")
|
||||
var table3StrCol = StringColumn("col2")
|
||||
var table3 = NewTable(
|
||||
PostgreSQL,
|
||||
"db",
|
||||
"table3",
|
||||
table3Col1,
|
||||
|
|
@ -70,7 +73,7 @@ var table3 = NewTable(
|
|||
table3StrCol)
|
||||
|
||||
func assertClauseSerialize(t *testing.T, clause clause, query string, args ...interface{}) {
|
||||
out := sqlBuilder{}
|
||||
out := sqlBuilder{dialect: PostgreSQL}
|
||||
err := clause.serialize(selectStatement, &out)
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
|
@ -80,7 +83,7 @@ func assertClauseSerialize(t *testing.T, clause clause, query string, args ...in
|
|||
}
|
||||
|
||||
func assertClauseSerializeErr(t *testing.T, clause clause, errString string) {
|
||||
out := sqlBuilder{}
|
||||
out := sqlBuilder{dialect: PostgreSQL}
|
||||
err := clause.serialize(selectStatement, &out)
|
||||
|
||||
//fmt.Println(out.buff.String())
|
||||
|
|
@ -89,7 +92,7 @@ func assertClauseSerializeErr(t *testing.T, clause clause, errString string) {
|
|||
}
|
||||
|
||||
func assertProjectionSerialize(t *testing.T, projection projection, query string, args ...interface{}) {
|
||||
out := sqlBuilder{}
|
||||
out := sqlBuilder{dialect: PostgreSQL}
|
||||
err := projection.serializeForProjection(selectStatement, &out)
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
|
@ -99,7 +102,7 @@ func assertProjectionSerialize(t *testing.T, projection projection, query string
|
|||
}
|
||||
|
||||
func assertStatement(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) {
|
||||
queryStr, args, err := query.Sql()
|
||||
queryStr, args, err := query.Sql(PostgreSQL)
|
||||
assert.NilError(t, err)
|
||||
|
||||
//fmt.Println(queryStr)
|
||||
|
|
@ -108,7 +111,7 @@ func assertStatement(t *testing.T, query Statement, expectedQuery string, expect
|
|||
}
|
||||
|
||||
func assertStatementErr(t *testing.T, stmt Statement, errorStr string) {
|
||||
_, _, err := stmt.Sql()
|
||||
_, _, err := stmt.Sql(PostgreSQL)
|
||||
|
||||
assert.Assert(t, err != nil)
|
||||
assert.Error(t, err, errorStr)
|
||||
|
|
|
|||
63
visitor.go
Normal file
63
visitor.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
package jet
|
||||
|
||||
type visitor interface {
|
||||
visit(element acceptsVisitor)
|
||||
}
|
||||
|
||||
type acceptsVisitor interface {
|
||||
accept(visitor visitor)
|
||||
}
|
||||
|
||||
type noOpVisitorImpl struct {
|
||||
}
|
||||
|
||||
func (n *noOpVisitorImpl) accept(visitor visitor) {
|
||||
// NO OP
|
||||
}
|
||||
|
||||
// --------------- dialect finder -----------------//
|
||||
|
||||
type DialectFinder struct {
|
||||
dialects map[string]Dialect
|
||||
}
|
||||
|
||||
func newDialectFinder() *DialectFinder {
|
||||
return &DialectFinder{
|
||||
dialects: make(map[string]Dialect),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DialectFinder) dialect() Dialect {
|
||||
if len(f.dialects) == 0 {
|
||||
panic("jet: can't detect dialect")
|
||||
}
|
||||
|
||||
if len(f.dialects) > 1 {
|
||||
panic("jet: more than one dialect detected")
|
||||
}
|
||||
|
||||
for _, dialect := range f.dialects {
|
||||
return dialect
|
||||
}
|
||||
|
||||
panic("jet: internal error")
|
||||
}
|
||||
|
||||
func (f *DialectFinder) visit(element acceptsVisitor) {
|
||||
|
||||
if table, ok := element.(table); ok {
|
||||
dialect := table.dialect()
|
||||
f.dialects[dialect.Name] = dialect
|
||||
}
|
||||
}
|
||||
|
||||
func detectDialect(element acceptsVisitor, dialectOverride ...Dialect) Dialect {
|
||||
if len(dialectOverride) > 0 {
|
||||
return dialectOverride[0]
|
||||
}
|
||||
|
||||
dialectFinder := newDialectFinder()
|
||||
element.accept(dialectFinder)
|
||||
|
||||
return dialectFinder.dialect()
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue