Generic dialect support. (MySQL and Postgres)
This commit is contained in:
parent
043a0dc4c0
commit
5dda5e1e11
27 changed files with 440 additions and 92 deletions
2
alias.go
2
alias.go
|
|
@ -28,7 +28,7 @@ func (a *alias) serializeForProjection(statement statementType, out *sqlBuilder)
|
||||||
}
|
}
|
||||||
|
|
||||||
out.writeString("AS")
|
out.writeString("AS")
|
||||||
out.writeQuotedString(a.alias)
|
out.writeAlias(a.alias)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -93,13 +93,13 @@ type binaryBoolExpression struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newBinaryBoolOperator(lhs, rhs Expression, operator string) BoolExpression {
|
func newBinaryBoolOperator(lhs, rhs Expression, operator string) BoolExpression {
|
||||||
boolExpression := binaryBoolExpression{}
|
binaryBoolExpression := binaryBoolExpression{}
|
||||||
|
|
||||||
boolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
|
binaryBoolExpression.binaryOpExpression = newBinaryExpression(lhs, rhs, operator)
|
||||||
boolExpression.expressionInterfaceImpl.parent = &boolExpression
|
binaryBoolExpression.expressionInterfaceImpl.parent = &binaryBoolExpression
|
||||||
boolExpression.boolInterfaceImpl.parent = &boolExpression
|
binaryBoolExpression.boolInterfaceImpl.parent = &binaryBoolExpression
|
||||||
|
|
||||||
return &boolExpression
|
return &binaryBoolExpression
|
||||||
}
|
}
|
||||||
|
|
||||||
//---------------------------------------------------//
|
//---------------------------------------------------//
|
||||||
|
|
|
||||||
22
cast.go
22
cast.go
|
|
@ -44,9 +44,27 @@ func CAST(expression Expression) cast {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *castImpl) accept(visitor visitor) {
|
||||||
|
visitor.visit(b)
|
||||||
|
|
||||||
|
b.Expression.accept(visitor)
|
||||||
|
}
|
||||||
|
|
||||||
func (b *castImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
func (b *castImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||||
err := b.Expression.serialize(statement, out, options...)
|
|
||||||
out.writeString("::" + b.castType)
|
if castOverride := out.dialect.CastOverride; castOverride != nil {
|
||||||
|
return castOverride(b.Expression, b.castType)(statement, out, options...)
|
||||||
|
}
|
||||||
|
|
||||||
|
out.writeString("CAST")
|
||||||
|
err := WRAP(b.Expression).serialize(statement, out, options...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
out.writeString("AS")
|
||||||
|
out.writeString(b.castType)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
15
clause.go
15
clause.go
|
|
@ -30,8 +30,9 @@ func contains(options []serializeOption, option serializeOption) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
type sqlBuilder struct {
|
type sqlBuilder struct {
|
||||||
buff bytes.Buffer
|
dialect Dialect
|
||||||
args []interface{}
|
buff bytes.Buffer
|
||||||
|
args []interface{}
|
||||||
|
|
||||||
lastChar byte
|
lastChar byte
|
||||||
ident int
|
ident int
|
||||||
|
|
@ -162,8 +163,9 @@ func isPostSeparator(b byte) bool {
|
||||||
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
|
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *sqlBuilder) writeQuotedString(str string) {
|
func (q *sqlBuilder) writeAlias(str string) {
|
||||||
q.writeString(`"` + str + `"`)
|
aliasQuoteChar := string(q.dialect.AliasQuoteChar)
|
||||||
|
q.writeString(aliasQuoteChar + str + aliasQuoteChar)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *sqlBuilder) writeString(str string) {
|
func (q *sqlBuilder) writeString(str string) {
|
||||||
|
|
@ -174,7 +176,8 @@ func (q *sqlBuilder) writeIdentifier(name string) {
|
||||||
quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -")
|
quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -")
|
||||||
|
|
||||||
if quoteWrap {
|
if quoteWrap {
|
||||||
q.writeString(`"` + name + `"`)
|
identQuoteChar := string(q.dialect.IdentifierQuoteChar)
|
||||||
|
q.writeString(identQuoteChar + name + identQuoteChar)
|
||||||
} else {
|
} else {
|
||||||
q.writeString(name)
|
q.writeString(name)
|
||||||
}
|
}
|
||||||
|
|
@ -194,7 +197,7 @@ func (q *sqlBuilder) insertConstantArgument(arg interface{}) {
|
||||||
|
|
||||||
func (q *sqlBuilder) insertParametrizedArgument(arg interface{}) {
|
func (q *sqlBuilder) insertParametrizedArgument(arg interface{}) {
|
||||||
q.args = append(q.args, arg)
|
q.args = append(q.args, arg)
|
||||||
argPlaceholder := "$" + strconv.Itoa(len(q.args))
|
argPlaceholder := q.dialect.ArgumentPlaceholder(len(q.args))
|
||||||
|
|
||||||
q.writeString(argPlaceholder)
|
q.writeString(argPlaceholder)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
flag.StringVar(&source, "source", jet.PostgreSQL, "Database name")
|
flag.StringVar(&source, "source", string(jet.PostgreSQL.Name), "Database name")
|
||||||
|
|
||||||
flag.StringVar(&host, "host", "", "Database host path (Example: localhost)")
|
flag.StringVar(&host, "host", "", "Database host path (Example: localhost)")
|
||||||
flag.IntVar(&port, "port", 0, "Database port")
|
flag.IntVar(&port, "port", 0, "Database port")
|
||||||
|
|
@ -72,7 +72,7 @@ Usage of jet:
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
switch source {
|
switch source {
|
||||||
case jet.PostgreSQL:
|
case jet.PostgreSQL.Name:
|
||||||
if host == "" || port == 0 || user == "" || dbName == "" || schemaName == "" {
|
if host == "" || port == 0 || user == "" || dbName == "" || schemaName == "" {
|
||||||
fmt.Println("\njet: required flag missing")
|
fmt.Println("\njet: required flag missing")
|
||||||
flag.Usage()
|
flag.Usage()
|
||||||
|
|
@ -93,7 +93,7 @@ Usage of jet:
|
||||||
|
|
||||||
err = postgres.Generate(destDir, genData)
|
err = postgres.Generate(destDir, genData)
|
||||||
|
|
||||||
case jet.MySql:
|
case jet.MySql.Name:
|
||||||
if host == "" || port == 0 || user == "" || dbName == "" {
|
if host == "" || port == 0 || user == "" || dbName == "" {
|
||||||
fmt.Println("\njet: required flag missing")
|
fmt.Println("\njet: required flag missing")
|
||||||
flag.Usage()
|
flag.Usage()
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ type Column interface {
|
||||||
// The base type for real materialized columns.
|
// The base type for real materialized columns.
|
||||||
type columnImpl struct {
|
type columnImpl struct {
|
||||||
expressionInterfaceImpl
|
expressionInterfaceImpl
|
||||||
|
noOpVisitorImpl
|
||||||
|
|
||||||
name string
|
name string
|
||||||
tableName string
|
tableName string
|
||||||
|
|
@ -65,7 +66,7 @@ func (c *columnImpl) defaultAlias() string {
|
||||||
func (c *columnImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error {
|
func (c *columnImpl) serializeForOrderBy(statement statementType, out *sqlBuilder) error {
|
||||||
if statement == setStatement {
|
if statement == setStatement {
|
||||||
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
|
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
|
||||||
out.writeString(`"` + c.defaultAlias() + `"`) //always quote
|
out.writeAlias(c.defaultAlias()) //always quote
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -80,7 +81,8 @@ func (c columnImpl) serializeForProjection(statement statementType, out *sqlBuil
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
out.writeString(`AS "` + c.defaultAlias() + `"`)
|
out.writeString("AS")
|
||||||
|
out.writeAlias(c.defaultAlias())
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -90,7 +92,7 @@ func (c columnImpl) serialize(statement statementType, out *sqlBuilder, options
|
||||||
if c.subQuery != nil {
|
if c.subQuery != nil {
|
||||||
out.writeIdentifier(c.subQuery.Alias())
|
out.writeIdentifier(c.subQuery.Alias())
|
||||||
out.writeByte('.')
|
out.writeByte('.')
|
||||||
out.writeQuotedString(c.defaultAlias())
|
out.writeAlias(c.defaultAlias())
|
||||||
} else {
|
} else {
|
||||||
if c.tableName != "" {
|
if c.tableName != "" {
|
||||||
out.writeIdentifier(c.tableName)
|
out.writeIdentifier(c.tableName)
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,12 @@ func (d *deleteStatementImpl) RETURNING(projections ...projection) DeleteStateme
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *deleteStatementImpl) accept(visitor visitor) {
|
||||||
|
visitor.visit(d)
|
||||||
|
|
||||||
|
d.table.accept(visitor)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error {
|
func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error {
|
||||||
if d == nil {
|
if d == nil {
|
||||||
return errors.New("jet: delete statement is nil")
|
return errors.New("jet: delete statement is nil")
|
||||||
|
|
@ -68,8 +74,10 @@ func (d *deleteStatementImpl) serializeImpl(out *sqlBuilder) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) {
|
func (d *deleteStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
|
||||||
queryData := &sqlBuilder{}
|
queryData := &sqlBuilder{
|
||||||
|
dialect: detectDialect(d, dialect...),
|
||||||
|
}
|
||||||
|
|
||||||
err = d.serializeImpl(queryData)
|
err = d.serializeImpl(queryData)
|
||||||
|
|
||||||
|
|
@ -81,8 +89,8 @@ func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *deleteStatementImpl) DebugSql() (query string, err error) {
|
func (d *deleteStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
|
||||||
return debugSql(d)
|
return debugSql(d, dialect...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) error {
|
func (d *deleteStatementImpl) Query(db execution.DB, destination interface{}) error {
|
||||||
|
|
|
||||||
97
dialects.go
Normal file
97
dialects.go
Normal file
|
|
@ -0,0 +1,97 @@
|
||||||
|
package jet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
PostgreSQL = newPostgresDialect()
|
||||||
|
MySql = newMySQLDialect()
|
||||||
|
)
|
||||||
|
|
||||||
|
func newPostgresDialect() Dialect {
|
||||||
|
postgresDialect := newDialect("PostgreSQL", "postgres")
|
||||||
|
|
||||||
|
postgresDialect.OperatorOverrides["IS DISTINCT FROM"] = postgresIS_DISTINCT_FROM
|
||||||
|
postgresDialect.CastOverride = postgresCAST
|
||||||
|
postgresDialect.AliasQuoteChar = '"'
|
||||||
|
postgresDialect.IdentifierQuoteChar = '"'
|
||||||
|
postgresDialect.ArgumentPlaceholder = func(ord int) string {
|
||||||
|
return "$" + strconv.Itoa(ord)
|
||||||
|
}
|
||||||
|
|
||||||
|
return postgresDialect
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMySQLDialect() Dialect {
|
||||||
|
mySQLDialect := newDialect("MySQL", "mysql")
|
||||||
|
|
||||||
|
mySQLDialect.AliasQuoteChar = '"'
|
||||||
|
mySQLDialect.IdentifierQuoteChar = '"'
|
||||||
|
mySQLDialect.ArgumentPlaceholder = func(int) string {
|
||||||
|
return "?"
|
||||||
|
}
|
||||||
|
|
||||||
|
return mySQLDialect
|
||||||
|
}
|
||||||
|
|
||||||
|
type Dialect struct {
|
||||||
|
Name string
|
||||||
|
PackageName string
|
||||||
|
OperatorOverrides map[string]serializeOverride
|
||||||
|
CastOverride castOverride
|
||||||
|
AliasQuoteChar byte
|
||||||
|
IdentifierQuoteChar byte
|
||||||
|
ArgumentPlaceholder queryPlaceholderFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dialect) serializeOverride(operator string) serializeOverride {
|
||||||
|
return d.OperatorOverrides[operator]
|
||||||
|
}
|
||||||
|
|
||||||
|
type queryPlaceholderFunc func(ord int) string
|
||||||
|
|
||||||
|
func newDialect(name, packageName string) Dialect {
|
||||||
|
newDialect := Dialect{
|
||||||
|
Name: name,
|
||||||
|
PackageName: packageName,
|
||||||
|
}
|
||||||
|
newDialect.OperatorOverrides = make(map[string]serializeOverride)
|
||||||
|
|
||||||
|
return newDialect
|
||||||
|
}
|
||||||
|
|
||||||
|
func postgresIS_DISTINCT_FROM(expressions ...Expression) serializeFunc {
|
||||||
|
return func(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||||
|
if len(expressions) != 2 {
|
||||||
|
return errors.New("Invalid number of expressions for operator")
|
||||||
|
}
|
||||||
|
if err := expressions[0].serialize(statement, out); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
out.writeString("IS DISTINCT FROM")
|
||||||
|
|
||||||
|
if err := expressions[1].serialize(statement, out); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func postgresCAST(expression Expression, castType string) serializeFunc {
|
||||||
|
return func(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||||
|
if err := expression.serialize(statement, out, options...); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
out.writeString("::" + castType)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type serializeFunc func(statement statementType, out *sqlBuilder, options ...serializeOption) error
|
||||||
|
type serializeOverride func(expressions ...Expression) serializeFunc
|
||||||
|
|
||||||
|
type castOverride func(expression Expression, castType string) serializeFunc
|
||||||
|
|
@ -3,6 +3,8 @@ package jet
|
||||||
type enumValue struct {
|
type enumValue struct {
|
||||||
expressionInterfaceImpl
|
expressionInterfaceImpl
|
||||||
stringInterfaceImpl
|
stringInterfaceImpl
|
||||||
|
noOpVisitorImpl
|
||||||
|
|
||||||
name string
|
name string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,12 @@ import (
|
||||||
// Expression is common interface for all expressions.
|
// Expression is common interface for all expressions.
|
||||||
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
|
// Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions.
|
||||||
type Expression interface {
|
type Expression interface {
|
||||||
|
acceptsVisitor
|
||||||
|
|
||||||
|
expression
|
||||||
|
}
|
||||||
|
|
||||||
|
type expression interface {
|
||||||
clause
|
clause
|
||||||
projection
|
projection
|
||||||
groupByClause
|
groupByClause
|
||||||
|
|
@ -95,7 +101,12 @@ func newBinaryExpression(lhs, rhs Expression, operator string) binaryOpExpressio
|
||||||
return binaryExpression
|
return binaryExpression
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
func (c *binaryOpExpression) accept(visitor visitor) {
|
||||||
|
c.lhs.accept(visitor)
|
||||||
|
c.rhs.accept(visitor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) (err error) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return errors.New("jet: binary Expression is nil")
|
return errors.New("jet: binary Expression is nil")
|
||||||
}
|
}
|
||||||
|
|
@ -112,21 +123,25 @@ func (c *binaryOpExpression) serialize(statement statementType, out *sqlBuilder,
|
||||||
out.writeString("(")
|
out.writeString("(")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.lhs.serialize(statement, out); err != nil {
|
if dialectOveride := out.dialect.serializeOverride(c.operator); dialectOveride != nil {
|
||||||
return err
|
err = dialectOveride(c.lhs, c.rhs)(statement, out, options...)
|
||||||
}
|
} else {
|
||||||
|
if err := c.lhs.serialize(statement, out); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
out.writeString(c.operator)
|
out.writeString(c.operator)
|
||||||
|
|
||||||
if err := c.rhs.serialize(statement, out); err != nil {
|
if err := c.rhs.serialize(statement, out); err != nil {
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if wrap {
|
if wrap {
|
||||||
out.writeString(")")
|
out.writeString(")")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// A prefix operator Expression
|
// A prefix operator Expression
|
||||||
|
|
@ -144,6 +159,10 @@ func newPrefixExpression(expression Expression, operator string) prefixOpExpress
|
||||||
return prefixExpression
|
return prefixExpression
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *prefixOpExpression) accept(visitor visitor) {
|
||||||
|
p.expression.accept(visitor)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *prefixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
func (p *prefixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
return errors.New("jet: Prefix Expression is nil")
|
return errors.New("jet: Prefix Expression is nil")
|
||||||
|
|
@ -176,6 +195,10 @@ func newPostfixOpExpression(expression Expression, operator string) postfixOpExp
|
||||||
return postfixOpExpression
|
return postfixOpExpression
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *postfixOpExpression) accept(visitor visitor) {
|
||||||
|
p.expression.accept(visitor)
|
||||||
|
}
|
||||||
|
|
||||||
func (p *postfixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
func (p *postfixOpExpression) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
return errors.New("jet: Postifx operator Expression is nil")
|
return errors.New("jet: Postifx operator Expression is nil")
|
||||||
|
|
|
||||||
|
|
@ -485,6 +485,14 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr
|
||||||
return funcExp
|
return funcExp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *funcExpressionImpl) accept(visitor visitor) {
|
||||||
|
visitor.visit(f)
|
||||||
|
|
||||||
|
for _, exp := range f.expressions {
|
||||||
|
exp.accept(visitor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (f *funcExpressionImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
func (f *funcExpressionImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||||
if f == nil {
|
if f == nil {
|
||||||
return errors.New("jet: Function expressions is nil. ")
|
return errors.New("jet: Function expressions is nil. ")
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package template
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/go-jet/jet"
|
||||||
"github.com/go-jet/jet/generator/internal/metadata"
|
"github.com/go-jet/jet/generator/internal/metadata"
|
||||||
"github.com/go-jet/jet/internal/utils"
|
"github.com/go-jet/jet/internal/utils"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
@ -10,7 +11,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect string) error {
|
func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect jet.Dialect) error {
|
||||||
if len(tables) == 0 && len(enums) == 0 {
|
if len(tables) == 0 && len(enums) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -59,7 +60,7 @@ func GenerateFiles(destDir string, tables, enums []metadata.MetaData, dialect st
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect string) error {
|
func generate(dirPath, packageName string, template string, metaDataList []metadata.MetaData, dialect jet.Dialect) error {
|
||||||
modelDirPath := filepath.Join(dirPath, packageName)
|
modelDirPath := filepath.Join(dirPath, packageName)
|
||||||
|
|
||||||
err := utils.EnsureDirPath(modelDirPath)
|
err := utils.EnsureDirPath(modelDirPath)
|
||||||
|
|
@ -92,14 +93,14 @@ func generate(dirPath, packageName string, template string, metaDataList []metad
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateTemplate generates template with template text and template data.
|
// GenerateTemplate generates template with template text and template data.
|
||||||
func GenerateTemplate(templateText string, templateData interface{}, dialect string) ([]byte, error) {
|
func GenerateTemplate(templateText string, templateData interface{}, dialect jet.Dialect) ([]byte, error) {
|
||||||
|
|
||||||
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
|
t, err := template.New("sqlBuilderTableTemplate").Funcs(template.FuncMap{
|
||||||
"ToGoIdentifier": utils.ToGoIdentifier,
|
"ToGoIdentifier": utils.ToGoIdentifier,
|
||||||
"now": func() string {
|
"now": func() string {
|
||||||
return time.Now().Format(time.RFC850)
|
return time.Now().Format(time.RFC850)
|
||||||
},
|
},
|
||||||
"dialect": func() string {
|
"dialect": func() jet.Dialect {
|
||||||
return dialect
|
return dialect
|
||||||
},
|
},
|
||||||
}).Parse(templateText)
|
}).Parse(templateText)
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ package table
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/go-jet/jet"
|
"github.com/go-jet/jet"
|
||||||
"github.com/go-jet/jet/{{dialect}}"
|
"github.com/go-jet/jet/{{dialect.PackageName}}"
|
||||||
)
|
)
|
||||||
|
|
||||||
var {{ToGoIdentifier .Name}} = new{{.GoStructName}}()
|
var {{ToGoIdentifier .Name}} = new{{.GoStructName}}()
|
||||||
|
|
@ -32,7 +32,7 @@ type {{.GoStructName}} struct {
|
||||||
|
|
||||||
//Columns
|
//Columns
|
||||||
{{- range .Columns}}
|
{{- range .Columns}}
|
||||||
{{ToGoIdentifier .Name}} {{dialect}}.Column{{.SqlBuilderColumnType}}
|
{{ToGoIdentifier .Name}} {{dialect.PackageName}}.Column{{.SqlBuilderColumnType}}
|
||||||
{{- end}}
|
{{- end}}
|
||||||
|
|
||||||
AllColumns jet.ColumnList
|
AllColumns jet.ColumnList
|
||||||
|
|
@ -51,12 +51,12 @@ func (a *{{.GoStructName}}) AS(alias string) *{{.GoStructName}} {
|
||||||
func new{{.GoStructName}}() *{{.GoStructName}} {
|
func new{{.GoStructName}}() *{{.GoStructName}} {
|
||||||
var (
|
var (
|
||||||
{{- range .Columns}}
|
{{- range .Columns}}
|
||||||
{{ToGoIdentifier .Name}}Column = {{dialect}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
|
{{ToGoIdentifier .Name}}Column = {{dialect.PackageName}}.{{.SqlBuilderColumnType}}Column("{{.Name}}")
|
||||||
{{- end}}
|
{{- end}}
|
||||||
)
|
)
|
||||||
|
|
||||||
return &{{.GoStructName}}{
|
return &{{.GoStructName}}{
|
||||||
Table: jet.NewTable("{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}),
|
Table: jet.NewTable(jet.{{dialect.Name}}, "{{.SchemaName}}", "{{.Name}}", {{template "column-list" .Columns}}),
|
||||||
|
|
||||||
//Columns
|
//Columns
|
||||||
{{- range .Columns}}
|
{{- range .Columns}}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package mysql
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/go-jet/jet"
|
||||||
"github.com/go-jet/jet/generator/internal/metadata"
|
"github.com/go-jet/jet/generator/internal/metadata"
|
||||||
"github.com/go-jet/jet/generator/internal/template"
|
"github.com/go-jet/jet/generator/internal/template"
|
||||||
"path"
|
"path"
|
||||||
|
|
@ -37,7 +38,7 @@ func Generate(destDir string, dbConn DBConnection) error {
|
||||||
|
|
||||||
genPath := path.Join(destDir, dbConn.DBName)
|
genPath := path.Join(destDir, dbConn.DBName)
|
||||||
|
|
||||||
err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, "mysql")
|
err = template.GenerateFiles(genPath, dbInfo.TableInfos, dbInfo.EnumInfos, jet.MySql)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package postgres
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/go-jet/jet"
|
||||||
"github.com/go-jet/jet/generator/internal/metadata"
|
"github.com/go-jet/jet/generator/internal/metadata"
|
||||||
"github.com/go-jet/jet/generator/internal/template"
|
"github.com/go-jet/jet/generator/internal/template"
|
||||||
"path"
|
"path"
|
||||||
|
|
@ -41,7 +42,7 @@ func Generate(destDir string, dbConn DBConnection) error {
|
||||||
|
|
||||||
genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName)
|
genPath := path.Join(destDir, dbConn.DBName, dbConn.SchemaName)
|
||||||
|
|
||||||
err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, "postgres")
|
err = template.GenerateFiles(genPath, schemaInfo.TableInfos, schemaInfo.EnumInfos, jet.PostgreSQL)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
||||||
|
|
@ -73,36 +73,44 @@ func (i *insertStatementImpl) getColumns() []column {
|
||||||
return i.table.columns()
|
return i.table.columns()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *insertStatementImpl) DebugSql() (query string, err error) {
|
func (i *insertStatementImpl) accept(visitor visitor) {
|
||||||
return debugSql(i)
|
visitor.visit(i)
|
||||||
|
|
||||||
|
i.table.accept(visitor)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) {
|
func (i *insertStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
|
||||||
queryData := &sqlBuilder{}
|
return debugSql(i, dialect...)
|
||||||
|
}
|
||||||
|
|
||||||
queryData.newLine()
|
func (i *insertStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
|
||||||
queryData.writeString("INSERT INTO")
|
out := &sqlBuilder{
|
||||||
|
dialect: detectDialect(i, dialect...),
|
||||||
|
}
|
||||||
|
|
||||||
|
out.newLine()
|
||||||
|
out.writeString("INSERT INTO")
|
||||||
|
|
||||||
if utils.IsNil(i.table) {
|
if utils.IsNil(i.table) {
|
||||||
return "", nil, errors.New("jet: table is nil")
|
return "", nil, errors.New("jet: table is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = i.table.serialize(insertStatement, queryData)
|
err = i.table.serialize(insertStatement, out)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(i.columns) > 0 {
|
if len(i.columns) > 0 {
|
||||||
queryData.writeString("(")
|
out.writeString("(")
|
||||||
|
|
||||||
err = serializeColumnNames(i.columns, queryData)
|
err = serializeColumnNames(i.columns, out)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
queryData.writeString(")")
|
out.writeString(")")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(i.rows) == 0 && i.query == nil {
|
if len(i.rows) == 0 && i.query == nil {
|
||||||
|
|
@ -114,41 +122,41 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(i.rows) > 0 {
|
if len(i.rows) > 0 {
|
||||||
queryData.writeString("VALUES")
|
out.writeString("VALUES")
|
||||||
|
|
||||||
for rowIndex, row := range i.rows {
|
for rowIndex, row := range i.rows {
|
||||||
if rowIndex > 0 {
|
if rowIndex > 0 {
|
||||||
queryData.writeString(",")
|
out.writeString(",")
|
||||||
}
|
}
|
||||||
|
|
||||||
queryData.increaseIdent()
|
out.increaseIdent()
|
||||||
queryData.newLine()
|
out.newLine()
|
||||||
queryData.writeString("(")
|
out.writeString("(")
|
||||||
|
|
||||||
err = serializeClauseList(insertStatement, row, queryData)
|
err = serializeClauseList(insertStatement, row, out)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
queryData.writeByte(')')
|
out.writeByte(')')
|
||||||
queryData.decreaseIdent()
|
out.decreaseIdent()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if i.query != nil {
|
if i.query != nil {
|
||||||
err = i.query.serialize(insertStatement, queryData)
|
err = i.query.serialize(insertStatement, out)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = queryData.writeReturning(insertStatement, i.returning); err != nil {
|
if err = out.writeReturning(insertStatement, i.returning); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sql, args = queryData.finalize()
|
query, args = out.finalize()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ import "fmt"
|
||||||
// Representation of an escaped literal
|
// Representation of an escaped literal
|
||||||
type literalExpression struct {
|
type literalExpression struct {
|
||||||
expressionInterfaceImpl
|
expressionInterfaceImpl
|
||||||
|
noOpVisitorImpl
|
||||||
|
|
||||||
value interface{}
|
value interface{}
|
||||||
constant bool
|
constant bool
|
||||||
}
|
}
|
||||||
|
|
@ -188,6 +190,7 @@ func Date(year, month, day int) DateExpression {
|
||||||
//--------------------------------------------------//
|
//--------------------------------------------------//
|
||||||
type nullLiteral struct {
|
type nullLiteral struct {
|
||||||
expressionInterfaceImpl
|
expressionInterfaceImpl
|
||||||
|
noOpVisitorImpl
|
||||||
}
|
}
|
||||||
|
|
||||||
func newNullLiteral() Expression {
|
func newNullLiteral() Expression {
|
||||||
|
|
@ -206,6 +209,7 @@ func (n *nullLiteral) serialize(statement statementType, out *sqlBuilder, option
|
||||||
//--------------------------------------------------//
|
//--------------------------------------------------//
|
||||||
type starLiteral struct {
|
type starLiteral struct {
|
||||||
expressionInterfaceImpl
|
expressionInterfaceImpl
|
||||||
|
noOpVisitorImpl
|
||||||
}
|
}
|
||||||
|
|
||||||
func newStarLiteral() Expression {
|
func newStarLiteral() Expression {
|
||||||
|
|
@ -228,6 +232,12 @@ type wrap struct {
|
||||||
expressions []Expression
|
expressions []Expression
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *wrap) accept(visitor visitor) {
|
||||||
|
for _, exp := range n.expressions {
|
||||||
|
exp.accept(visitor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (n *wrap) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
func (n *wrap) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||||
out.writeString("(")
|
out.writeString("(")
|
||||||
err := serializeExpressionList(statement, n.expressions, ", ", out)
|
err := serializeExpressionList(statement, n.expressions, ", ", out)
|
||||||
|
|
@ -247,6 +257,8 @@ func WRAP(expression ...Expression) Expression {
|
||||||
|
|
||||||
type rawExpression struct {
|
type rawExpression struct {
|
||||||
expressionInterfaceImpl
|
expressionInterfaceImpl
|
||||||
|
noOpVisitorImpl
|
||||||
|
|
||||||
raw string
|
raw string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -53,11 +53,19 @@ func (l *lockStatementImpl) NOWAIT() LockStatement {
|
||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *lockStatementImpl) DebugSql() (query string, err error) {
|
func (l *lockStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
|
||||||
return debugSql(l)
|
return debugSql(l, dialect...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error) {
|
func (l *lockStatementImpl) accept(visitor visitor) {
|
||||||
|
visitor.visit(l)
|
||||||
|
|
||||||
|
for _, table := range l.tables {
|
||||||
|
table.accept(visitor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *lockStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
|
||||||
if l == nil {
|
if l == nil {
|
||||||
return "", nil, errors.New("jet: nil Statement")
|
return "", nil, errors.New("jet: nil Statement")
|
||||||
}
|
}
|
||||||
|
|
@ -66,7 +74,9 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error)
|
||||||
return "", nil, errors.New("jet: There is no table selected to be locked")
|
return "", nil, errors.New("jet: There is no table selected to be locked")
|
||||||
}
|
}
|
||||||
|
|
||||||
out := &sqlBuilder{}
|
out := &sqlBuilder{
|
||||||
|
dialect: detectDialect(l, dialect...),
|
||||||
|
}
|
||||||
|
|
||||||
out.newLine()
|
out.newLine()
|
||||||
out.writeString("LOCK TABLE")
|
out.writeString("LOCK TABLE")
|
||||||
|
|
|
||||||
18
operators.go
18
operators.go
|
|
@ -108,6 +108,24 @@ func (c *caseOperatorImpl) ELSE(els Expression) CaseOperator {
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *caseOperatorImpl) accept(visitor visitor) {
|
||||||
|
visitor.visit(c)
|
||||||
|
|
||||||
|
c.expression.accept(visitor)
|
||||||
|
|
||||||
|
for _, when := range c.when {
|
||||||
|
when.accept(visitor)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, then := range c.then {
|
||||||
|
then.accept(visitor)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.els != nil {
|
||||||
|
c.els.accept(visitor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *caseOperatorImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
func (c *caseOperatorImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return errors.New("jet: Case Expression is nil. ")
|
return errors.New("jet: Case Expression is nil. ")
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ var (
|
||||||
// SelectStatement is interface for SQL SELECT statements
|
// SelectStatement is interface for SQL SELECT statements
|
||||||
type SelectStatement interface {
|
type SelectStatement interface {
|
||||||
Statement
|
Statement
|
||||||
Expression
|
expression
|
||||||
|
|
||||||
DISTINCT() SelectStatement
|
DISTINCT() SelectStatement
|
||||||
FROM(table ReadableTable) SelectStatement
|
FROM(table ReadableTable) SelectStatement
|
||||||
|
|
@ -261,10 +261,29 @@ func (s *selectStatementImpl) serializeImpl(out *sqlBuilder) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error) {
|
func (s *selectStatementImpl) accept(visitor visitor) {
|
||||||
queryData := sqlBuilder{}
|
visitor.visit(s)
|
||||||
|
|
||||||
err = s.serializeImpl(&queryData)
|
if s.table != nil {
|
||||||
|
s.table.accept(visitor)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.where != nil {
|
||||||
|
s.where.accept(visitor)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.having != nil {
|
||||||
|
s.having.accept(visitor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *selectStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
|
||||||
|
|
||||||
|
queryData := &sqlBuilder{
|
||||||
|
dialect: detectDialect(s, dialect...),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.serializeImpl(queryData)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
|
|
@ -275,8 +294,8 @@ func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *selectStatementImpl) DebugSql() (query string, err error) {
|
func (s *selectStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
|
||||||
return debugSql(s.parent)
|
return debugSql(s.parent, dialect...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) error {
|
func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) error {
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,15 @@ func (s *selectTableImpl) columns() []column {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *selectTableImpl) accept(visitor visitor) {
|
||||||
|
visitor.visit(s)
|
||||||
|
s.selectStmt.accept(visitor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *selectTableImpl) dialect() Dialect {
|
||||||
|
return detectDialect(s.selectStmt)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *selectTableImpl) AllColumns() ProjectionList {
|
func (s *selectTableImpl) AllColumns() ProjectionList {
|
||||||
return s.projections
|
return s.projections
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -74,6 +74,14 @@ func newSetStatementImpl(operator string, all bool, selects []SelectStatement) S
|
||||||
return setStatement
|
return setStatement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *setStatementImpl) accept(visitor visitor) {
|
||||||
|
visitor.visit(s)
|
||||||
|
|
||||||
|
for _, selects := range s.selects {
|
||||||
|
selects.accept(visitor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *setStatementImpl) projections() []projection {
|
func (s *setStatementImpl) projections() []projection {
|
||||||
if len(s.selects) > 0 {
|
if len(s.selects) > 0 {
|
||||||
return s.selects[0].projections()
|
return s.selects[0].projections()
|
||||||
|
|
@ -169,8 +177,10 @@ func (s *setStatementImpl) serializeImpl(out *sqlBuilder) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *setStatementImpl) Sql() (query string, args []interface{}, err error) {
|
func (s *setStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
|
||||||
queryData := &sqlBuilder{}
|
queryData := &sqlBuilder{
|
||||||
|
dialect: detectDialect(s, dialect...),
|
||||||
|
}
|
||||||
|
|
||||||
err = s.serializeImpl(queryData)
|
err = s.serializeImpl(queryData)
|
||||||
|
|
||||||
|
|
|
||||||
11
statement.go
11
statement.go
|
|
@ -4,19 +4,19 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"github.com/go-jet/jet/execution"
|
"github.com/go-jet/jet/execution"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
//Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
|
//Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
|
||||||
type Statement interface {
|
type Statement interface {
|
||||||
|
acceptsVisitor
|
||||||
// Sql returns parametrized sql query with list of arguments.
|
// Sql returns parametrized sql query with list of arguments.
|
||||||
// err is returned if statement is not composed correctly
|
// err is returned if statement is not composed correctly
|
||||||
Sql() (query string, args []interface{}, err error)
|
Sql(dialect ...Dialect) (query string, args []interface{}, err error)
|
||||||
// DebugSql returns debug query where every parametrized placeholder is replaced with its argument.
|
// DebugSql returns debug query where every parametrized placeholder is replaced with its argument.
|
||||||
// Do not use it in production. Use it only for debug purposes.
|
// Do not use it in production. Use it only for debug purposes.
|
||||||
// err is returned if statement is not composed correctly
|
// err is returned if statement is not composed correctly
|
||||||
DebugSql() (query string, err error)
|
DebugSql(dialect ...Dialect) (query string, err error)
|
||||||
|
|
||||||
// Query executes statement over database connection db and stores row result in destination.
|
// Query executes statement over database connection db and stores row result in destination.
|
||||||
// Destination can be arbitrary structure
|
// Destination can be arbitrary structure
|
||||||
|
|
@ -31,7 +31,8 @@ type Statement interface {
|
||||||
ExecContext(context context.Context, db execution.DB) (sql.Result, error)
|
ExecContext(context context.Context, db execution.DB) (sql.Result, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func debugSql(statement Statement) (string, error) {
|
func debugSql(statement Statement, overrideDialect ...Dialect) (string, error) {
|
||||||
|
dialect := detectDialect(statement, overrideDialect...)
|
||||||
sqlQuery, args, err := statement.Sql()
|
sqlQuery, args, err := statement.Sql()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -41,7 +42,7 @@ func debugSql(statement Statement) (string, error) {
|
||||||
debugSQLQuery := sqlQuery
|
debugSQLQuery := sqlQuery
|
||||||
|
|
||||||
for i, arg := range args {
|
for i, arg := range args {
|
||||||
argPlaceholder := "$" + strconv.Itoa(i+1)
|
argPlaceholder := dialect.ArgumentPlaceholder(i + 1)
|
||||||
debugSQLQuery = strings.Replace(debugSQLQuery, argPlaceholder, argToString(arg), 1)
|
debugSQLQuery = strings.Replace(debugSQLQuery, argPlaceholder, argToString(arg), 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
28
table.go
28
table.go
|
|
@ -6,6 +6,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type table interface {
|
type table interface {
|
||||||
|
dialect() Dialect
|
||||||
columns() []column
|
columns() []column
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -42,6 +43,7 @@ type ReadableTable interface {
|
||||||
table
|
table
|
||||||
readableTable
|
readableTable
|
||||||
clause
|
clause
|
||||||
|
acceptsVisitor
|
||||||
}
|
}
|
||||||
|
|
||||||
// WritableTable interface
|
// WritableTable interface
|
||||||
|
|
@ -49,6 +51,7 @@ type WritableTable interface {
|
||||||
table
|
table
|
||||||
writableTable
|
writableTable
|
||||||
clause
|
clause
|
||||||
|
acceptsVisitor
|
||||||
}
|
}
|
||||||
|
|
||||||
// Table interface
|
// Table interface
|
||||||
|
|
@ -57,6 +60,8 @@ type Table interface {
|
||||||
readableTable
|
readableTable
|
||||||
writableTable
|
writableTable
|
||||||
clause
|
clause
|
||||||
|
acceptsVisitor
|
||||||
|
|
||||||
SchemaName() string
|
SchemaName() string
|
||||||
TableName() string
|
TableName() string
|
||||||
AS(alias string)
|
AS(alias string)
|
||||||
|
|
@ -114,10 +119,11 @@ func (w *writableTableInterfaceImpl) LOCK() LockStatement {
|
||||||
return LOCK(w.parent)
|
return LOCK(w.parent)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTable creates new table with schema name, table name and list of columns
|
// NewTable creates new table with schema Name, table Name and list of columns
|
||||||
func NewTable(schemaName, name string, columns ...Column) Table {
|
func NewTable(Dialect Dialect, schemaName, name string, columns ...Column) Table {
|
||||||
|
|
||||||
t := &tableImpl{
|
t := &tableImpl{
|
||||||
|
Dialect: Dialect,
|
||||||
schemaName: schemaName,
|
schemaName: schemaName,
|
||||||
name: name,
|
name: name,
|
||||||
columnList: columns,
|
columnList: columns,
|
||||||
|
|
@ -136,6 +142,7 @@ type tableImpl struct {
|
||||||
readableTableInterfaceImpl
|
readableTableInterfaceImpl
|
||||||
writableTableInterfaceImpl
|
writableTableInterfaceImpl
|
||||||
|
|
||||||
|
Dialect Dialect
|
||||||
schemaName string
|
schemaName string
|
||||||
name string
|
name string
|
||||||
alias string
|
alias string
|
||||||
|
|
@ -168,6 +175,14 @@ func (t *tableImpl) columns() []column {
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *tableImpl) dialect() Dialect {
|
||||||
|
return t.Dialect
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tableImpl) accept(visitor visitor) {
|
||||||
|
visitor.visit(t)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *tableImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
func (t *tableImpl) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) error {
|
||||||
if t == nil {
|
if t == nil {
|
||||||
return errors.New("jet: tableImpl is nil. ")
|
return errors.New("jet: tableImpl is nil. ")
|
||||||
|
|
@ -235,6 +250,15 @@ func (t *joinTable) columns() []column {
|
||||||
return append(t.lhs.columns(), t.rhs.columns()...)
|
return append(t.lhs.columns(), t.rhs.columns()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *joinTable) accept(visitor visitor) {
|
||||||
|
t.lhs.accept(visitor)
|
||||||
|
t.rhs.accept(visitor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *joinTable) dialect() Dialect {
|
||||||
|
return detectDialect(t)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *joinTable) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) (err error) {
|
func (t *joinTable) serialize(statement statementType, out *sqlBuilder, options ...serializeOption) (err error) {
|
||||||
if t == nil {
|
if t == nil {
|
||||||
return errors.New("jet: Join table is nil. ")
|
return errors.New("jet: Join table is nil. ")
|
||||||
|
|
|
||||||
|
|
@ -57,8 +57,15 @@ func (u *updateStatementImpl) RETURNING(projections ...projection) UpdateStateme
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) {
|
func (u *updateStatementImpl) accept(visitor visitor) {
|
||||||
out := &sqlBuilder{}
|
visitor.visit(u)
|
||||||
|
u.table.accept(visitor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *updateStatementImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) {
|
||||||
|
out := &sqlBuilder{
|
||||||
|
dialect: detectDialect(u, dialect...),
|
||||||
|
}
|
||||||
|
|
||||||
out.newLine()
|
out.newLine()
|
||||||
out.writeString("UPDATE")
|
out.writeString("UPDATE")
|
||||||
|
|
@ -124,12 +131,12 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sql, args = out.finalize()
|
query, args = out.finalize()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *updateStatementImpl) DebugSql() (query string, err error) {
|
func (u *updateStatementImpl) DebugSql(dialect ...Dialect) (query string, err error) {
|
||||||
return debugSql(u)
|
return debugSql(u, dialect...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *updateStatementImpl) Query(db execution.DB, destination interface{}) error {
|
func (u *updateStatementImpl) Query(db execution.DB, destination interface{}) error {
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ var table1ColBool = BoolColumn("col_bool")
|
||||||
var table1ColDate = DateColumn("col_date")
|
var table1ColDate = DateColumn("col_date")
|
||||||
|
|
||||||
var table1 = NewTable(
|
var table1 = NewTable(
|
||||||
|
PostgreSQL,
|
||||||
"db",
|
"db",
|
||||||
"table1",
|
"table1",
|
||||||
table1Col1,
|
table1Col1,
|
||||||
|
|
@ -44,6 +45,7 @@ var table2ColTimestampz = TimestampzColumn("col_timestampz")
|
||||||
var table2ColDate = DateColumn("col_date")
|
var table2ColDate = DateColumn("col_date")
|
||||||
|
|
||||||
var table2 = NewTable(
|
var table2 = NewTable(
|
||||||
|
PostgreSQL,
|
||||||
"db",
|
"db",
|
||||||
"table2",
|
"table2",
|
||||||
table2Col3,
|
table2Col3,
|
||||||
|
|
@ -63,6 +65,7 @@ var table3Col1 = IntegerColumn("col1")
|
||||||
var table3ColInt = IntegerColumn("col_int")
|
var table3ColInt = IntegerColumn("col_int")
|
||||||
var table3StrCol = StringColumn("col2")
|
var table3StrCol = StringColumn("col2")
|
||||||
var table3 = NewTable(
|
var table3 = NewTable(
|
||||||
|
PostgreSQL,
|
||||||
"db",
|
"db",
|
||||||
"table3",
|
"table3",
|
||||||
table3Col1,
|
table3Col1,
|
||||||
|
|
@ -70,7 +73,7 @@ var table3 = NewTable(
|
||||||
table3StrCol)
|
table3StrCol)
|
||||||
|
|
||||||
func assertClauseSerialize(t *testing.T, clause clause, query string, args ...interface{}) {
|
func assertClauseSerialize(t *testing.T, clause clause, query string, args ...interface{}) {
|
||||||
out := sqlBuilder{}
|
out := sqlBuilder{dialect: PostgreSQL}
|
||||||
err := clause.serialize(selectStatement, &out)
|
err := clause.serialize(selectStatement, &out)
|
||||||
|
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
|
|
@ -80,7 +83,7 @@ func assertClauseSerialize(t *testing.T, clause clause, query string, args ...in
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertClauseSerializeErr(t *testing.T, clause clause, errString string) {
|
func assertClauseSerializeErr(t *testing.T, clause clause, errString string) {
|
||||||
out := sqlBuilder{}
|
out := sqlBuilder{dialect: PostgreSQL}
|
||||||
err := clause.serialize(selectStatement, &out)
|
err := clause.serialize(selectStatement, &out)
|
||||||
|
|
||||||
//fmt.Println(out.buff.String())
|
//fmt.Println(out.buff.String())
|
||||||
|
|
@ -89,7 +92,7 @@ func assertClauseSerializeErr(t *testing.T, clause clause, errString string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertProjectionSerialize(t *testing.T, projection projection, query string, args ...interface{}) {
|
func assertProjectionSerialize(t *testing.T, projection projection, query string, args ...interface{}) {
|
||||||
out := sqlBuilder{}
|
out := sqlBuilder{dialect: PostgreSQL}
|
||||||
err := projection.serializeForProjection(selectStatement, &out)
|
err := projection.serializeForProjection(selectStatement, &out)
|
||||||
|
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
|
|
@ -99,7 +102,7 @@ func assertProjectionSerialize(t *testing.T, projection projection, query string
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertStatement(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) {
|
func assertStatement(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) {
|
||||||
queryStr, args, err := query.Sql()
|
queryStr, args, err := query.Sql(PostgreSQL)
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
|
|
||||||
//fmt.Println(queryStr)
|
//fmt.Println(queryStr)
|
||||||
|
|
@ -108,7 +111,7 @@ func assertStatement(t *testing.T, query Statement, expectedQuery string, expect
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertStatementErr(t *testing.T, stmt Statement, errorStr string) {
|
func assertStatementErr(t *testing.T, stmt Statement, errorStr string) {
|
||||||
_, _, err := stmt.Sql()
|
_, _, err := stmt.Sql(PostgreSQL)
|
||||||
|
|
||||||
assert.Assert(t, err != nil)
|
assert.Assert(t, err != nil)
|
||||||
assert.Error(t, err, errorStr)
|
assert.Error(t, err, errorStr)
|
||||||
|
|
|
||||||
63
visitor.go
Normal file
63
visitor.go
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
package jet
|
||||||
|
|
||||||
|
type visitor interface {
|
||||||
|
visit(element acceptsVisitor)
|
||||||
|
}
|
||||||
|
|
||||||
|
type acceptsVisitor interface {
|
||||||
|
accept(visitor visitor)
|
||||||
|
}
|
||||||
|
|
||||||
|
type noOpVisitorImpl struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *noOpVisitorImpl) accept(visitor visitor) {
|
||||||
|
// NO OP
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------- dialect finder -----------------//
|
||||||
|
|
||||||
|
type DialectFinder struct {
|
||||||
|
dialects map[string]Dialect
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDialectFinder() *DialectFinder {
|
||||||
|
return &DialectFinder{
|
||||||
|
dialects: make(map[string]Dialect),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DialectFinder) dialect() Dialect {
|
||||||
|
if len(f.dialects) == 0 {
|
||||||
|
panic("jet: can't detect dialect")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(f.dialects) > 1 {
|
||||||
|
panic("jet: more than one dialect detected")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dialect := range f.dialects {
|
||||||
|
return dialect
|
||||||
|
}
|
||||||
|
|
||||||
|
panic("jet: internal error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DialectFinder) visit(element acceptsVisitor) {
|
||||||
|
|
||||||
|
if table, ok := element.(table); ok {
|
||||||
|
dialect := table.dialect()
|
||||||
|
f.dialects[dialect.Name] = dialect
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func detectDialect(element acceptsVisitor, dialectOverride ...Dialect) Dialect {
|
||||||
|
if len(dialectOverride) > 0 {
|
||||||
|
return dialectOverride[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
dialectFinder := newDialectFinder()
|
||||||
|
element.accept(dialectFinder)
|
||||||
|
|
||||||
|
return dialectFinder.dialect()
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue