From 736a6502414ddf6f7e1a965613cb2a3a6d6aea9c Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 11 Aug 2019 18:44:58 +0200 Subject: [PATCH] DebugSQL refactor. --- internal/jet/sql_builder.go | 155 +++++++++++++++--------------- internal/jet/statement.go | 87 +++++------------ internal/jet/testutils.go | 2 +- internal/jet/time_expression.go | 4 +- internal/jet/timez_expression.go | 4 +- tests/postgres/chinook_db_test.go | 3 + 6 files changed, 112 insertions(+), 143 deletions(-) diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index ec0f449..f5ec496 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -16,126 +16,124 @@ type SqlBuilder struct { lastChar byte ident int -} -func (s *SqlBuilder) DebugSQL() string { - return queryStringToDebugString(s.Buff.String(), s.Args, s.Dialect) + debug bool } const defaultIdent = 5 -func (q *SqlBuilder) IncreaseIdent(ident ...int) { +func (s *SqlBuilder) IncreaseIdent(ident ...int) { if len(ident) > 0 { - q.ident += ident[0] + s.ident += ident[0] } else { - q.ident += defaultIdent + s.ident += defaultIdent } } -func (q *SqlBuilder) DecreaseIdent(ident ...int) { +func (s *SqlBuilder) DecreaseIdent(ident ...int) { toDecrease := defaultIdent if len(ident) > 0 { toDecrease = ident[0] } - if q.ident < toDecrease { - q.ident = 0 + if s.ident < toDecrease { + s.ident = 0 } - q.ident -= toDecrease + s.ident -= toDecrease } -func (q *SqlBuilder) WriteProjections(statement StatementType, projections []Projection) error { - q.IncreaseIdent() - err := SerializeProjectionList(statement, projections, q) - q.DecreaseIdent() +func (s *SqlBuilder) WriteProjections(statement StatementType, projections []Projection) error { + s.IncreaseIdent() + err := SerializeProjectionList(statement, projections, s) + s.DecreaseIdent() return err } -func (q *SqlBuilder) writeFrom(statement StatementType, table Serializer) error { - q.NewLine() - q.WriteString("FROM") +func (s *SqlBuilder) writeFrom(statement StatementType, table Serializer) error { + s.NewLine() + s.WriteString("FROM") - q.IncreaseIdent() - err := table.serialize(statement, q) - q.DecreaseIdent() + s.IncreaseIdent() + err := table.serialize(statement, s) + s.DecreaseIdent() return err } -func (q *SqlBuilder) writeWhere(statement StatementType, where Expression) error { - q.NewLine() - q.WriteString("WHERE") +func (s *SqlBuilder) writeWhere(statement StatementType, where Expression) error { + s.NewLine() + s.WriteString("WHERE") - q.IncreaseIdent() - err := where.serialize(statement, q, noWrap) - q.DecreaseIdent() + s.IncreaseIdent() + err := where.serialize(statement, s, noWrap) + s.DecreaseIdent() return err } -func (q *SqlBuilder) writeGroupBy(statement StatementType, groupBy []GroupByClause) error { - q.NewLine() - q.WriteString("GROUP BY") +func (s *SqlBuilder) writeGroupBy(statement StatementType, groupBy []GroupByClause) error { + s.NewLine() + s.WriteString("GROUP BY") - q.IncreaseIdent() - err := serializeGroupByClauseList(statement, groupBy, q) - q.DecreaseIdent() + s.IncreaseIdent() + err := serializeGroupByClauseList(statement, groupBy, s) + s.DecreaseIdent() return err } -func (q *SqlBuilder) writeOrderBy(statement StatementType, orderBy []OrderByClause) error { - q.NewLine() - q.WriteString("ORDER BY") +func (s *SqlBuilder) writeOrderBy(statement StatementType, orderBy []OrderByClause) error { + s.NewLine() + s.WriteString("ORDER BY") - q.IncreaseIdent() - err := serializeOrderByClauseList(statement, orderBy, q) - q.DecreaseIdent() + s.IncreaseIdent() + err := serializeOrderByClauseList(statement, orderBy, s) + s.DecreaseIdent() return err } -func (q *SqlBuilder) writeHaving(statement StatementType, having Expression) error { - q.NewLine() - q.WriteString("HAVING") +func (s *SqlBuilder) writeHaving(statement StatementType, having Expression) error { + s.NewLine() + s.WriteString("HAVING") - q.IncreaseIdent() - err := having.serialize(statement, q, noWrap) - q.DecreaseIdent() + s.IncreaseIdent() + err := having.serialize(statement, s, noWrap) + s.DecreaseIdent() return err } -func (q *SqlBuilder) WriteReturning(statement StatementType, returning []Projection) error { +func (s *SqlBuilder) WriteReturning(statement StatementType, returning []Projection) error { if len(returning) == 0 { return nil } - q.NewLine() - q.WriteString("RETURNING") - q.IncreaseIdent() + s.NewLine() + s.WriteString("RETURNING") + s.IncreaseIdent() - return q.WriteProjections(statement, returning) + return s.WriteProjections(statement, returning) } -func (q *SqlBuilder) NewLine() { - q.write([]byte{'\n'}) - q.write(bytes.Repeat([]byte{' '}, q.ident)) +func (s *SqlBuilder) NewLine() { + s.write([]byte{'\n'}) + s.write(bytes.Repeat([]byte{' '}, s.ident)) } -func (q *SqlBuilder) write(data []byte) { +func (s *SqlBuilder) write(data []byte) { if len(data) == 0 { return } - if !isPreSeparator(q.lastChar) && !isPostSeparator(data[0]) && q.Buff.Len() > 0 { - q.Buff.WriteByte(' ') + if !isPreSeparator(s.lastChar) && !isPostSeparator(data[0]) && s.Buff.Len() > 0 { + s.Buff.WriteByte(' ') } - q.Buff.Write(data) - q.lastChar = data[len(data)-1] + s.Buff.Write(data) + s.lastChar = data[len(data)-1] } func isPreSeparator(b byte) bool { @@ -146,43 +144,48 @@ func isPostSeparator(b byte) bool { return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':' } -func (q *SqlBuilder) writeAlias(str string) { - aliasQuoteChar := string(q.Dialect.AliasQuoteChar()) - q.WriteString(aliasQuoteChar + str + aliasQuoteChar) +func (s *SqlBuilder) writeAlias(str string) { + aliasQuoteChar := string(s.Dialect.AliasQuoteChar()) + s.WriteString(aliasQuoteChar + str + aliasQuoteChar) } -func (q *SqlBuilder) WriteString(str string) { - q.write([]byte(str)) +func (s *SqlBuilder) WriteString(str string) { + s.write([]byte(str)) } -func (q *SqlBuilder) writeIdentifier(name string, alwaysQuote ...bool) { +func (s *SqlBuilder) writeIdentifier(name string, alwaysQuote ...bool) { quoteWrap := name != strings.ToLower(name) || strings.ContainsAny(name, ". -") if quoteWrap || len(alwaysQuote) > 0 { - identQuoteChar := string(q.Dialect.IdentifierQuoteChar()) - q.WriteString(identQuoteChar + name + identQuoteChar) + identQuoteChar := string(s.Dialect.IdentifierQuoteChar()) + s.WriteString(identQuoteChar + name + identQuoteChar) } else { - q.WriteString(name) + s.WriteString(name) } } -func (q *SqlBuilder) writeByte(b byte) { - q.write([]byte{b}) +func (s *SqlBuilder) writeByte(b byte) { + s.write([]byte{b}) } -func (q *SqlBuilder) finalize() (string, []interface{}) { - return q.Buff.String() + ";\n", q.Args +func (s *SqlBuilder) finalize() (string, []interface{}) { + return s.Buff.String() + ";\n", s.Args } -func (q *SqlBuilder) insertConstantArgument(arg interface{}) { - q.WriteString(argToString(arg)) +func (s *SqlBuilder) insertConstantArgument(arg interface{}) { + s.WriteString(argToString(arg)) } -func (q *SqlBuilder) insertParametrizedArgument(arg interface{}) { - q.Args = append(q.Args, arg) - argPlaceholder := q.Dialect.ArgumentPlaceholder()(len(q.Args)) +func (s *SqlBuilder) insertParametrizedArgument(arg interface{}) { + if s.debug { + s.insertConstantArgument(arg) + return + } - q.WriteString(argPlaceholder) + s.Args = append(s.Args, arg) + argPlaceholder := s.Dialect.ArgumentPlaceholder()(len(s.Args)) + + s.WriteString(argPlaceholder) } func argToString(value interface{}) string { diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 4a27b44..c424e1f 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -5,7 +5,6 @@ import ( "database/sql" "errors" "github.com/go-jet/jet/execution" - "strings" ) //Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK) @@ -13,11 +12,11 @@ type Statement interface { acceptsVisitor // Sql returns parametrized sql query with list of arguments. // err is returned if statement is not composed correctly - Sql(dialect ...Dialect) (query string, args []interface{}, err error) + Sql() (query string, args []interface{}, err error) // DebugSql returns debug query where every parametrized placeholder is replaced with its argument. // Do not use it in production. Use it only for debug purposes. // err is returned if statement is not composed correctly - DebugSql(dialect ...Dialect) (query string, err error) + DebugSql() (query string, err error) // Query executes statement over database connection db and stores row result in destination. // Destination can be arbitrary structure @@ -49,16 +48,16 @@ type HasProjections interface { type SerializerStatementInterfaceImpl struct { noOpVisitorImpl - Parent SerializerStatement - Dialect Dialect - StatementType StatementType + dialect Dialect + statementType StatementType + parent SerializerStatement } -func (s *SerializerStatementInterfaceImpl) Sql(dialect ...Dialect) (query string, args []interface{}, err error) { +func (s *SerializerStatementInterfaceImpl) Sql() (query string, args []interface{}, err error) { - queryData := &SqlBuilder{Dialect: s.Dialect} + queryData := &SqlBuilder{Dialect: s.dialect} - err = s.Parent.serialize(s.StatementType, queryData, noWrap) + err = s.parent.serialize(s.statementType, queryData, noWrap) if err != nil { return "", nil, err @@ -69,58 +68,22 @@ func (s *SerializerStatementInterfaceImpl) Sql(dialect ...Dialect) (query string return } -func (s *SerializerStatementInterfaceImpl) DebugSql(dialect ...Dialect) (query string, err error) { - return debugSql(s.Parent, s.Dialect) -} +func (s *SerializerStatementInterfaceImpl) DebugSql() (query string, err error) { + sqlBuilder := &SqlBuilder{Dialect: s.dialect, debug: true} -func (s *SerializerStatementInterfaceImpl) Query(db execution.DB, destination interface{}) error { - return query(s.Parent, db, destination) -} - -func (s *SerializerStatementInterfaceImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error { - return queryContext(context, s.Parent, db, destination) -} - -func (s *SerializerStatementInterfaceImpl) Exec(db execution.DB) (res sql.Result, err error) { - return exec(s.Parent, db) -} - -func (s *SerializerStatementInterfaceImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) { - return execContext(context, s.Parent, db) -} - -func debugSql(statement Statement, overrideDialect ...Dialect) (string, error) { - dialect := detectDialect(statement, overrideDialect...) - sqlQuery, args, err := statement.Sql(dialect) + err = s.parent.serialize(s.statementType, sqlBuilder, noWrap) if err != nil { return "", err } - //debugSQLQuery := sqlQuery - // - //for i, arg := range args { - // argPlaceholder := dialect.ArgumentPlaceholder()(i + 1) - // debugSQLQuery = strings.Replace(debugSQLQuery, argPlaceholder, argToString(arg), 1) - //} - // - //return debugSQLQuery, nil - return queryStringToDebugString(sqlQuery, args, dialect), nil + query, _ = sqlBuilder.finalize() + + return } -func queryStringToDebugString(sqlQuery string, args []interface{}, dialect Dialect) string { - debugSQLQuery := sqlQuery - - for i, arg := range args { - argPlaceholder := dialect.ArgumentPlaceholder()(i + 1) - debugSQLQuery = strings.Replace(debugSQLQuery, argPlaceholder, argToString(arg), 1) - } - - return debugSQLQuery -} - -func query(statement Statement, db execution.DB, destination interface{}) error { - query, args, err := statement.Sql() +func (s *SerializerStatementInterfaceImpl) Query(db execution.DB, destination interface{}) error { + query, args, err := s.Sql() if err != nil { return err @@ -129,8 +92,8 @@ func query(statement Statement, db execution.DB, destination interface{}) error return execution.Query(context.Background(), db, query, args, destination) } -func queryContext(context context.Context, statement Statement, db execution.DB, destination interface{}) error { - query, args, err := statement.Sql() +func (s *SerializerStatementInterfaceImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error { + query, args, err := s.Sql() if err != nil { return err @@ -139,8 +102,8 @@ func queryContext(context context.Context, statement Statement, db execution.DB, return execution.Query(context, db, query, args, destination) } -func exec(statement Statement, db execution.DB) (res sql.Result, err error) { - query, args, err := statement.Sql() +func (s *SerializerStatementInterfaceImpl) Exec(db execution.DB) (res sql.Result, err error) { + query, args, err := s.Sql() if err != nil { return @@ -149,8 +112,8 @@ func exec(statement Statement, db execution.DB) (res sql.Result, err error) { return db.Exec(query, args...) } -func execContext(context context.Context, statement Statement, db execution.DB) (res sql.Result, err error) { - query, args, err := statement.Sql() +func (s *SerializerStatementInterfaceImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) { + query, args, err := s.Sql() if err != nil { return @@ -171,9 +134,9 @@ func (s *ExpressionStatementImpl) serializeForProjection(statement StatementType func NewStatementImpl(Dialect Dialect, statementType StatementType, parent SerializerStatement, clauses ...Clause) StatementImpl { return StatementImpl{ SerializerStatementInterfaceImpl: SerializerStatementInterfaceImpl{ - Parent: parent, - Dialect: Dialect, - StatementType: statementType, + parent: parent, + dialect: Dialect, + statementType: statementType, }, Clauses: clauses, } diff --git a/internal/jet/testutils.go b/internal/jet/testutils.go index 8d71556..9680e95 100644 --- a/internal/jet/testutils.go +++ b/internal/jet/testutils.go @@ -86,7 +86,7 @@ func assertStatement(t *testing.T, query Statement, expectedQuery string, expect } func assertStatementErr(t *testing.T, stmt Statement, errorStr string) { - _, _, err := stmt.Sql(DefaultDialect) + _, _, err := stmt.Sql() assert.Assert(t, err != nil) assert.Error(t, err, errorStr) diff --git a/internal/jet/time_expression.go b/internal/jet/time_expression.go index 03ce827..0fadfc5 100644 --- a/internal/jet/time_expression.go +++ b/internal/jet/time_expression.go @@ -63,8 +63,8 @@ type prefixTimeExpression struct { // timeExpr := prefixTimeExpression{} // timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) // -// timeExpr.ExpressionInterfaceImpl.Parent = &timeExpr -// timeExpr.timeInterfaceImpl.Parent = &timeExpr +// timeExpr.ExpressionInterfaceImpl.parent = &timeExpr +// timeExpr.timeInterfaceImpl.parent = &timeExpr // // return &timeExpr //} diff --git a/internal/jet/timez_expression.go b/internal/jet/timez_expression.go index 529c92e..20c3d6d 100644 --- a/internal/jet/timez_expression.go +++ b/internal/jet/timez_expression.go @@ -71,8 +71,8 @@ type prefixTimezExpression struct { // timeExpr := prefixTimezExpression{} // timeExpr.prefixOpExpression = newPrefixExpression(expression, operator) // -// timeExpr.ExpressionInterfaceImpl.Parent = &timeExpr -// timeExpr.timezInterfaceImpl.Parent = &timeExpr +// timeExpr.ExpressionInterfaceImpl.parent = &timeExpr +// timeExpr.timezInterfaceImpl.parent = &timeExpr // // return &timeExpr //} diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index 1d18ebd..5b7cf01 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "fmt" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/model" @@ -16,6 +17,8 @@ func TestSelect(t *testing.T) { SELECT(Album.AllColumns). ORDER_BY(Album.AlbumId.ASC()) + fmt.Println(stmt.DebugSql()) + testutils.AssertDebugStatementSql(t, stmt, ` SELECT "Album"."AlbumId" AS "Album.AlbumId", "Album"."Title" AS "Album.Title",