DebugSQL refactor.

This commit is contained in:
go-jet 2019-08-11 18:44:58 +02:00
parent 0dd976601e
commit 736a650241
6 changed files with 112 additions and 143 deletions

View file

@ -16,126 +16,124 @@ type SqlBuilder struct {
lastChar byte lastChar byte
ident int ident int
}
func (s *SqlBuilder) DebugSQL() string { debug bool
return queryStringToDebugString(s.Buff.String(), s.Args, s.Dialect)
} }
const defaultIdent = 5 const defaultIdent = 5
func (q *SqlBuilder) IncreaseIdent(ident ...int) { func (s *SqlBuilder) IncreaseIdent(ident ...int) {
if len(ident) > 0 { if len(ident) > 0 {
q.ident += ident[0] s.ident += ident[0]
} else { } else {
q.ident += defaultIdent s.ident += defaultIdent
} }
} }
func (q *SqlBuilder) DecreaseIdent(ident ...int) { func (s *SqlBuilder) DecreaseIdent(ident ...int) {
toDecrease := defaultIdent toDecrease := defaultIdent
if len(ident) > 0 { if len(ident) > 0 {
toDecrease = ident[0] toDecrease = ident[0]
} }
if q.ident < toDecrease { if s.ident < toDecrease {
q.ident = 0 s.ident = 0
} }
q.ident -= toDecrease s.ident -= toDecrease
} }
func (q *SqlBuilder) WriteProjections(statement StatementType, projections []Projection) error { func (s *SqlBuilder) WriteProjections(statement StatementType, projections []Projection) error {
q.IncreaseIdent() s.IncreaseIdent()
err := SerializeProjectionList(statement, projections, q) err := SerializeProjectionList(statement, projections, s)
q.DecreaseIdent() s.DecreaseIdent()
return err return err
} }
func (q *SqlBuilder) writeFrom(statement StatementType, table Serializer) error { func (s *SqlBuilder) writeFrom(statement StatementType, table Serializer) error {
q.NewLine() s.NewLine()
q.WriteString("FROM") s.WriteString("FROM")
q.IncreaseIdent() s.IncreaseIdent()
err := table.serialize(statement, q) err := table.serialize(statement, s)
q.DecreaseIdent() s.DecreaseIdent()
return err return err
} }
func (q *SqlBuilder) writeWhere(statement StatementType, where Expression) error { func (s *SqlBuilder) writeWhere(statement StatementType, where Expression) error {
q.NewLine() s.NewLine()
q.WriteString("WHERE") s.WriteString("WHERE")
q.IncreaseIdent() s.IncreaseIdent()
err := where.serialize(statement, q, noWrap) err := where.serialize(statement, s, noWrap)
q.DecreaseIdent() s.DecreaseIdent()
return err return err
} }
func (q *SqlBuilder) writeGroupBy(statement StatementType, groupBy []GroupByClause) error { func (s *SqlBuilder) writeGroupBy(statement StatementType, groupBy []GroupByClause) error {
q.NewLine() s.NewLine()
q.WriteString("GROUP BY") s.WriteString("GROUP BY")
q.IncreaseIdent() s.IncreaseIdent()
err := serializeGroupByClauseList(statement, groupBy, q) err := serializeGroupByClauseList(statement, groupBy, s)
q.DecreaseIdent() s.DecreaseIdent()
return err return err
} }
func (q *SqlBuilder) writeOrderBy(statement StatementType, orderBy []OrderByClause) error { func (s *SqlBuilder) writeOrderBy(statement StatementType, orderBy []OrderByClause) error {
q.NewLine() s.NewLine()
q.WriteString("ORDER BY") s.WriteString("ORDER BY")
q.IncreaseIdent() s.IncreaseIdent()
err := serializeOrderByClauseList(statement, orderBy, q) err := serializeOrderByClauseList(statement, orderBy, s)
q.DecreaseIdent() s.DecreaseIdent()
return err return err
} }
func (q *SqlBuilder) writeHaving(statement StatementType, having Expression) error { func (s *SqlBuilder) writeHaving(statement StatementType, having Expression) error {
q.NewLine() s.NewLine()
q.WriteString("HAVING") s.WriteString("HAVING")
q.IncreaseIdent() s.IncreaseIdent()
err := having.serialize(statement, q, noWrap) err := having.serialize(statement, s, noWrap)
q.DecreaseIdent() s.DecreaseIdent()
return err return err
} }
func (q *SqlBuilder) WriteReturning(statement StatementType, returning []Projection) error { func (s *SqlBuilder) WriteReturning(statement StatementType, returning []Projection) error {
if len(returning) == 0 { if len(returning) == 0 {
return nil return nil
} }
q.NewLine() s.NewLine()
q.WriteString("RETURNING") s.WriteString("RETURNING")
q.IncreaseIdent() s.IncreaseIdent()
return q.WriteProjections(statement, returning) return s.WriteProjections(statement, returning)
} }
func (q *SqlBuilder) NewLine() { func (s *SqlBuilder) NewLine() {
q.write([]byte{'\n'}) s.write([]byte{'\n'})
q.write(bytes.Repeat([]byte{' '}, q.ident)) s.write(bytes.Repeat([]byte{' '}, s.ident))
} }
func (q *SqlBuilder) write(data []byte) { func (s *SqlBuilder) write(data []byte) {
if len(data) == 0 { if len(data) == 0 {
return return
} }
if !isPreSeparator(q.lastChar) && !isPostSeparator(data[0]) && q.Buff.Len() > 0 { if !isPreSeparator(s.lastChar) && !isPostSeparator(data[0]) && s.Buff.Len() > 0 {
q.Buff.WriteByte(' ') s.Buff.WriteByte(' ')
} }
q.Buff.Write(data) s.Buff.Write(data)
q.lastChar = data[len(data)-1] s.lastChar = data[len(data)-1]
} }
func isPreSeparator(b byte) bool { func isPreSeparator(b byte) bool {
@ -146,43 +144,48 @@ func isPostSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':' return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
} }
func (q *SqlBuilder) writeAlias(str string) { func (s *SqlBuilder) writeAlias(str string) {
aliasQuoteChar := string(q.Dialect.AliasQuoteChar()) aliasQuoteChar := string(s.Dialect.AliasQuoteChar())
q.WriteString(aliasQuoteChar + str + aliasQuoteChar) s.WriteString(aliasQuoteChar + str + aliasQuoteChar)
} }
func (q *SqlBuilder) WriteString(str string) { func (s *SqlBuilder) WriteString(str string) {
q.write([]byte(str)) s.write([]byte(str))
} }
func (q *SqlBuilder) writeIdentifier(name string, alwaysQuote ...bool) { func (s *SqlBuilder) writeIdentifier(name string, alwaysQuote ...bool) {
quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -") quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -")
if quoteWrap || len(alwaysQuote) > 0 { if quoteWrap || len(alwaysQuote) > 0 {
identQuoteChar := string(q.Dialect.IdentifierQuoteChar()) identQuoteChar := string(s.Dialect.IdentifierQuoteChar())
q.WriteString(identQuoteChar + name + identQuoteChar) s.WriteString(identQuoteChar + name + identQuoteChar)
} else { } else {
q.WriteString(name) s.WriteString(name)
} }
} }
func (q *SqlBuilder) writeByte(b byte) { func (s *SqlBuilder) writeByte(b byte) {
q.write([]byte{b}) s.write([]byte{b})
} }
func (q *SqlBuilder) finalize() (string, []interface{}) { func (s *SqlBuilder) finalize() (string, []interface{}) {
return q.Buff.String() + ";\n", q.Args return s.Buff.String() + ";\n", s.Args
} }
func (q *SqlBuilder) insertConstantArgument(arg interface{}) { func (s *SqlBuilder) insertConstantArgument(arg interface{}) {
q.WriteString(argToString(arg)) s.WriteString(argToString(arg))
} }
func (q *SqlBuilder) insertParametrizedArgument(arg interface{}) { func (s *SqlBuilder) insertParametrizedArgument(arg interface{}) {
q.Args = append(q.Args, arg) if s.debug {
argPlaceholder := q.Dialect.ArgumentPlaceholder()(len(q.Args)) s.insertConstantArgument(arg)
return
}
q.WriteString(argPlaceholder) s.Args = append(s.Args, arg)
argPlaceholder := s.Dialect.ArgumentPlaceholder()(len(s.Args))
s.WriteString(argPlaceholder)
} }
func argToString(value interface{}) string { func argToString(value interface{}) string {

View file

@ -5,7 +5,6 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"github.com/go-jet/jet/execution" "github.com/go-jet/jet/execution"
"strings"
) )
//Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK) //Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
@ -13,11 +12,11 @@ type Statement interface {
acceptsVisitor acceptsVisitor
// Sql returns parametrized sql query with list of arguments. // Sql returns parametrized sql query with list of arguments.
// err is returned if statement is not composed correctly // err is returned if statement is not composed correctly
Sql(dialect ...Dialect) (query string, args []interface{}, err error) Sql() (query string, args []interface{}, err error)
// DebugSql returns debug query where every parametrized placeholder is replaced with its argument. // DebugSql returns debug query where every parametrized placeholder is replaced with its argument.
// Do not use it in production. Use it only for debug purposes. // Do not use it in production. Use it only for debug purposes.
// err is returned if statement is not composed correctly // err is returned if statement is not composed correctly
DebugSql(dialect ...Dialect) (query string, err error) DebugSql() (query string, err error)
// Query executes statement over database connection db and stores row result in destination. // Query executes statement over database connection db and stores row result in destination.
// Destination can be arbitrary structure // Destination can be arbitrary structure
@ -49,16 +48,16 @@ type HasProjections interface {
type SerializerStatementInterfaceImpl struct { type SerializerStatementInterfaceImpl struct {
noOpVisitorImpl noOpVisitorImpl
Parent SerializerStatement dialect Dialect
Dialect Dialect statementType StatementType
StatementType StatementType parent SerializerStatement
} }
func (s *SerializerStatementInterfaceImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { func (s *SerializerStatementInterfaceImpl) Sql() (query string, args []interface{}, err error) {
queryData := &SqlBuilder{Dialect: s.Dialect} queryData := &SqlBuilder{Dialect: s.dialect}
err = s.Parent.serialize(s.StatementType, queryData, noWrap) err = s.parent.serialize(s.statementType, queryData, noWrap)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
@ -69,58 +68,22 @@ func (s *SerializerStatementInterfaceImpl) Sql(dialect ...Dialect) (query string
return return
} }
func (s *SerializerStatementInterfaceImpl) DebugSql(dialect ...Dialect) (query string, err error) { func (s *SerializerStatementInterfaceImpl) DebugSql() (query string, err error) {
return debugSql(s.Parent, s.Dialect) sqlBuilder := &SqlBuilder{Dialect: s.dialect, debug: true}
}
func (s *SerializerStatementInterfaceImpl) Query(db execution.DB, destination interface{}) error { err = s.parent.serialize(s.statementType, sqlBuilder, noWrap)
return query(s.Parent, db, destination)
}
func (s *SerializerStatementInterfaceImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
return queryContext(context, s.Parent, db, destination)
}
func (s *SerializerStatementInterfaceImpl) Exec(db execution.DB) (res sql.Result, err error) {
return exec(s.Parent, db)
}
func (s *SerializerStatementInterfaceImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
return execContext(context, s.Parent, db)
}
func debugSql(statement Statement, overrideDialect ...Dialect) (string, error) {
dialect := detectDialect(statement, overrideDialect...)
sqlQuery, args, err := statement.Sql(dialect)
if err != nil { if err != nil {
return "", err return "", err
} }
//debugSQLQuery := sqlQuery query, _ = sqlBuilder.finalize()
//
//for i, arg := range args { return
// argPlaceholder := dialect.ArgumentPlaceholder()(i + 1)
// debugSQLQuery = strings.Replace(debugSQLQuery, argPlaceholder, argToString(arg), 1)
//}
//
//return debugSQLQuery, nil
return queryStringToDebugString(sqlQuery, args, dialect), nil
} }
func queryStringToDebugString(sqlQuery string, args []interface{}, dialect Dialect) string { func (s *SerializerStatementInterfaceImpl) Query(db execution.DB, destination interface{}) error {
debugSQLQuery := sqlQuery query, args, err := s.Sql()
for i, arg := range args {
argPlaceholder := dialect.ArgumentPlaceholder()(i + 1)
debugSQLQuery = strings.Replace(debugSQLQuery, argPlaceholder, argToString(arg), 1)
}
return debugSQLQuery
}
func query(statement Statement, db execution.DB, destination interface{}) error {
query, args, err := statement.Sql()
if err != nil { if err != nil {
return err return err
@ -129,8 +92,8 @@ func query(statement Statement, db execution.DB, destination interface{}) error
return execution.Query(context.Background(), db, query, args, destination) return execution.Query(context.Background(), db, query, args, destination)
} }
func queryContext(context context.Context, statement Statement, db execution.DB, destination interface{}) error { func (s *SerializerStatementInterfaceImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error {
query, args, err := statement.Sql() query, args, err := s.Sql()
if err != nil { if err != nil {
return err return err
@ -139,8 +102,8 @@ func queryContext(context context.Context, statement Statement, db execution.DB,
return execution.Query(context, db, query, args, destination) return execution.Query(context, db, query, args, destination)
} }
func exec(statement Statement, db execution.DB) (res sql.Result, err error) { func (s *SerializerStatementInterfaceImpl) Exec(db execution.DB) (res sql.Result, err error) {
query, args, err := statement.Sql() query, args, err := s.Sql()
if err != nil { if err != nil {
return return
@ -149,8 +112,8 @@ func exec(statement Statement, db execution.DB) (res sql.Result, err error) {
return db.Exec(query, args...) return db.Exec(query, args...)
} }
func execContext(context context.Context, statement Statement, db execution.DB) (res sql.Result, err error) { func (s *SerializerStatementInterfaceImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) {
query, args, err := statement.Sql() query, args, err := s.Sql()
if err != nil { if err != nil {
return return
@ -171,9 +134,9 @@ func (s *ExpressionStatementImpl) serializeForProjection(statement StatementType
func NewStatementImpl(Dialect Dialect, statementType StatementType, parent SerializerStatement, clauses ...Clause) StatementImpl { func NewStatementImpl(Dialect Dialect, statementType StatementType, parent SerializerStatement, clauses ...Clause) StatementImpl {
return StatementImpl{ return StatementImpl{
SerializerStatementInterfaceImpl: SerializerStatementInterfaceImpl{ SerializerStatementInterfaceImpl: SerializerStatementInterfaceImpl{
Parent: parent, parent: parent,
Dialect: Dialect, dialect: Dialect,
StatementType: statementType, statementType: statementType,
}, },
Clauses: clauses, Clauses: clauses,
} }

View file

@ -86,7 +86,7 @@ func assertStatement(t *testing.T, query Statement, expectedQuery string, expect
} }
func assertStatementErr(t *testing.T, stmt Statement, errorStr string) { func assertStatementErr(t *testing.T, stmt Statement, errorStr string) {
_, _, err := stmt.Sql(DefaultDialect) _, _, err := stmt.Sql()
assert.Assert(t, err != nil) assert.Assert(t, err != nil)
assert.Error(t, err, errorStr) assert.Error(t, err, errorStr)

View file

@ -63,8 +63,8 @@ type prefixTimeExpression struct {
// timeExpr := prefixTimeExpression{} // timeExpr := prefixTimeExpression{}
// timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) // timeExpr.prefixOpExpression = newPrefixExpression(expression, operator)
// //
// timeExpr.ExpressionInterfaceImpl.Parent = &timeExpr // timeExpr.ExpressionInterfaceImpl.parent = &timeExpr
// timeExpr.timeInterfaceImpl.Parent = &timeExpr // timeExpr.timeInterfaceImpl.parent = &timeExpr
// //
// return &timeExpr // return &timeExpr
//} //}

View file

@ -71,8 +71,8 @@ type prefixTimezExpression struct {
// timeExpr := prefixTimezExpression{} // timeExpr := prefixTimezExpression{}
// timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) // timeExpr.prefixOpExpression = newPrefixExpression(expression, operator)
// //
// timeExpr.ExpressionInterfaceImpl.Parent = &timeExpr // timeExpr.ExpressionInterfaceImpl.parent = &timeExpr
// timeExpr.timezInterfaceImpl.Parent = &timeExpr // timeExpr.timezInterfaceImpl.parent = &timeExpr
// //
// return &timeExpr // return &timeExpr
//} //}

View file

@ -2,6 +2,7 @@ package postgres
import ( import (
"context" "context"
"fmt"
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/model"
@ -16,6 +17,8 @@ func TestSelect(t *testing.T) {
SELECT(Album.AllColumns). SELECT(Album.AllColumns).
ORDER_BY(Album.AlbumId.ASC()) ORDER_BY(Album.AlbumId.ASC())
fmt.Println(stmt.DebugSql())
testutils.AssertDebugStatementSql(t, stmt, ` testutils.AssertDebugStatementSql(t, stmt, `
SELECT "Album"."AlbumId" AS "Album.AlbumId", SELECT "Album"."AlbumId" AS "Album.AlbumId",
"Album"."Title" AS "Album.Title", "Album"."Title" AS "Album.Title",