Add automatic query logger function with additional execution details.

This commit is contained in:
go-jet 2022-01-12 19:03:50 +01:00
parent 7377e078cd
commit 4955bfc4b5
18 changed files with 266 additions and 59 deletions

View file

@ -20,7 +20,7 @@ WHERE table_schema = ? and table_type = ?;
` `
var tables []metadata.Table var tables []metadata.Table
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables)
throw.OnError(err) throw.OnError(err)
for i := range tables { for i := range tables {
@ -53,7 +53,7 @@ WHERE table_schema = ? AND table_name = ?
ORDER BY ordinal_position; ORDER BY ordinal_position;
` `
var columns []metadata.Column var columns []metadata.Column
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns) _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns)
throw.OnError(err) throw.OnError(err)
return columns return columns
@ -72,7 +72,7 @@ WHERE c.table_schema = ? AND DATA_TYPE = 'enum';
Values string Values string
} }
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult) _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult)
throw.OnError(err) throw.OnError(err)
var ret []metadata.Enum var ret []metadata.Enum

View file

@ -19,7 +19,7 @@ WHERE table_schema = $1 and table_type = $2;
` `
var tables []metadata.Table var tables []metadata.Table
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables)
throw.OnError(err) throw.OnError(err)
for i := range tables { for i := range tables {
@ -58,7 +58,7 @@ where table_schema = $1 and table_name = $2
order by ordinal_position; order by ordinal_position;
` `
var columns []metadata.Column var columns []metadata.Column
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns) _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns)
throw.OnError(err) throw.OnError(err)
return columns return columns
@ -76,7 +76,7 @@ ORDER BY n.nspname, t.typname, e.enumsortorder;`
var result []metadata.Enum var result []metadata.Enum
err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result) _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result)
throw.OnError(err) throw.OnError(err)
return result return result

View file

@ -28,7 +28,7 @@ func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTy
var tables []metadata.Table var tables []metadata.Table
err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables) _, err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables)
throw.OnError(err) throw.OnError(err)
for i := range tables { for i := range tables {
@ -47,7 +47,7 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t
Pk int32 Pk int32
} }
err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos) _, err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos)
throw.OnError(err) throw.OnError(err)
var columns []metadata.Column var columns []metadata.Column

View file

@ -1,6 +1,11 @@
package jet package jet
import "context" import (
"context"
"runtime"
"strings"
"time"
)
// PrintableStatement is a statement which sql query can be logged // PrintableStatement is a statement which sql query can be logged
type PrintableStatement interface { type PrintableStatement interface {
@ -8,7 +13,7 @@ type PrintableStatement interface {
DebugSql() (query string) DebugSql() (query string)
} }
// LoggerFunc is a definition of a function user can implement to support automatic statement logging. // LoggerFunc is a function user can implement to support automatic statement logging.
type LoggerFunc func(ctx context.Context, statement PrintableStatement) type LoggerFunc func(ctx context.Context, statement PrintableStatement)
var logger LoggerFunc var logger LoggerFunc
@ -17,3 +22,60 @@ var logger LoggerFunc
func SetLoggerFunc(loggerFunc LoggerFunc) { func SetLoggerFunc(loggerFunc LoggerFunc) {
logger = loggerFunc logger = loggerFunc
} }
func callLogger(ctx context.Context, statement Statement) {
if logger != nil {
logger(ctx, statement)
}
}
// QueryInfo contains information about executed query
type QueryInfo struct {
Statement PrintableStatement
// Depending of statement execution method RowsProcessed is:
// - Number of rows returned for Query() and QueryContext() methods
// - RowsAffected() for Exec() and ExecContext() methods
// - Always 0 for Rows() method.
RowsProcessed int64
Duration time.Duration
Err error
}
// QueryLoggerFunc is a function user can implement to retrieve more information about statement executed.
type QueryLoggerFunc func(ctx context.Context, info QueryInfo)
var queryLoggerFunc QueryLoggerFunc
// SetQueryLoggerFunc sets automatic query logging function.
func SetQueryLoggerFunc(loggerFunc QueryLoggerFunc) {
queryLoggerFunc = loggerFunc
}
func callQueryLoggerFunc(ctx context.Context, info QueryInfo) {
if queryLoggerFunc != nil {
queryLoggerFunc(ctx, info)
}
}
// Caller returns information about statement caller
func (q QueryInfo) Caller() (file string, line int, function string) {
skip := 4
// depending on execution type (Query, QueryContext, Exec, ...) looped once or twice
for {
var pc uintptr
var ok bool
pc, file, line, ok = runtime.Caller(skip)
if !ok {
return
}
funcDetails := runtime.FuncForPC(pc)
if !strings.Contains(funcDetails.Name(), "github.com/go-jet/jet/v2/internal") {
function = funcDetails.Name()
return
}
skip++
}
}

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/qrm"
"time"
) )
//Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK) //Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
@ -21,9 +22,9 @@ type Statement interface {
// Destination can be either pointer to struct or pointer to a slice. // Destination can be either pointer to struct or pointer to a slice.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error
//Exec executes statement over db connection/transaction without returning any rows. // Exec executes statement over db connection/transaction without returning any rows.
Exec(db qrm.DB) (sql.Result, error) Exec(db qrm.DB) (sql.Result, error)
//Exec executes statement with context over db connection/transaction without returning any rows. // ExecContext executes statement with context over db connection/transaction without returning any rows.
ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error) ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error)
// Rows executes statements over db connection/transaction and returns rows // Rows executes statements over db connection/transaction and returns rows
Rows(ctx context.Context, db qrm.DB) (*Rows, error) Rows(ctx context.Context, db qrm.DB) (*Rows, error)
@ -84,12 +85,7 @@ func (s *serializerStatementInterfaceImpl) DebugSql() (query string) {
} }
func (s *serializerStatementInterfaceImpl) Query(db qrm.DB, destination interface{}) error { func (s *serializerStatementInterfaceImpl) Query(db qrm.DB, destination interface{}) error {
query, args := s.Sql() return s.QueryContext(context.Background(), db, destination)
ctx := context.Background()
callLogger(ctx, s)
return qrm.Query(ctx, db, query, args, destination)
} }
func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error { func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error {
@ -97,15 +93,25 @@ func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db
callLogger(ctx, s) callLogger(ctx, s)
return qrm.Query(ctx, db, query, args, destination) var rowsProcessed int64
var err error
duration := duration(func() {
rowsProcessed, err = qrm.Query(ctx, db, query, args, destination)
})
callQueryLoggerFunc(ctx, QueryInfo{
Statement: s,
RowsProcessed: rowsProcessed,
Duration: duration,
Err: err,
})
return err
} }
func (s *serializerStatementInterfaceImpl) Exec(db qrm.DB) (res sql.Result, err error) { func (s *serializerStatementInterfaceImpl) Exec(db qrm.DB) (res sql.Result, err error) {
query, args := s.Sql() return s.ExecContext(context.Background(), db)
callLogger(context.Background(), s)
return db.Exec(query, args...)
} }
func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.DB) (res sql.Result, err error) { func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.DB) (res sql.Result, err error) {
@ -113,7 +119,24 @@ func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db q
callLogger(ctx, s) callLogger(ctx, s)
return db.ExecContext(ctx, query, args...) duration := duration(func() {
res, err = db.ExecContext(ctx, query, args...)
})
var rowsAffected int64
if err == nil {
rowsAffected, _ = res.RowsAffected()
}
callQueryLoggerFunc(ctx, QueryInfo{
Statement: s,
RowsProcessed: rowsAffected,
Duration: duration,
Err: err,
})
return res, err
} }
func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) (*Rows, error) { func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) (*Rows, error) {
@ -121,7 +144,18 @@ func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB)
callLogger(ctx, s) callLogger(ctx, s)
rows, err := db.QueryContext(ctx, query, args...) var rows *sql.Rows
var err error
duration := duration(func() {
rows, err = db.QueryContext(ctx, query, args...)
})
callQueryLoggerFunc(ctx, QueryInfo{
Statement: s,
Duration: duration,
Err: err,
})
if err != nil { if err != nil {
return nil, err return nil, err
@ -130,10 +164,12 @@ func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB)
return &Rows{rows}, nil return &Rows{rows}, nil
} }
func callLogger(ctx context.Context, statement Statement) { func duration(f func()) time.Duration {
if logger != nil { start := time.Now()
logger(ctx, statement)
} f()
return time.Now().Sub(start)
} }
// ExpressionStatement interfacess // ExpressionStatement interfacess

View file

@ -24,4 +24,11 @@ type OrderByClause = jet.OrderByClause
type GroupByClause = jet.GroupByClause type GroupByClause = jet.GroupByClause
// SetLogger sets automatic statement logging // SetLogger sets automatic statement logging
// Deprecated: use SetQueryLoggerFunc instead.
var SetLogger = jet.SetLoggerFunc var SetLogger = jet.SetLoggerFunc
// SetQueryLoggerFunc sets automatic query logging function.
var SetQueryLoggerFunc = jet.SetQueryLoggerFunc
// QueryInfo contains information about executed query
type QueryInfo = jet.QueryInfo

View file

@ -23,5 +23,12 @@ type OrderByClause = jet.OrderByClause
// GroupByClause interface to use as input for GROUP_BY // GroupByClause interface to use as input for GROUP_BY
type GroupByClause = jet.GroupByClause type GroupByClause = jet.GroupByClause
// SetLogger sets automatic statement logging // SetLogger sets automatic statement logging function
// Deprecated: use SetQueryLoggerFunc instead.
var SetLogger = jet.SetLoggerFunc var SetLogger = jet.SetLoggerFunc
// SetQueryLoggerFunc sets automatic query logging function.
var SetQueryLoggerFunc = jet.SetQueryLoggerFunc
// QueryInfo contains information about executed query
type QueryInfo = jet.QueryInfo

View file

@ -17,7 +17,7 @@ var ErrNoRows = errors.New("qrm: no rows in result set")
// using context `ctx` into destination `destPtr`. // using context `ctx` into destination `destPtr`.
// Destination can be either pointer to struct or pointer to slice of structs. // Destination can be either pointer to struct or pointer to slice of structs.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) error { func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) {
utils.MustBeInitializedPtr(db, "jet: db is nil") utils.MustBeInitializedPtr(db, "jet: db is nil")
utils.MustBeInitializedPtr(destPtr, "jet: destination is nil") utils.MustBeInitializedPtr(destPtr, "jet: destination is nil")
@ -26,11 +26,11 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
destinationPtrType := reflect.TypeOf(destPtr) destinationPtrType := reflect.TypeOf(destPtr)
if destinationPtrType.Elem().Kind() == reflect.Slice { if destinationPtrType.Elem().Kind() == reflect.Slice {
_, err := queryToSlice(ctx, db, query, args, destPtr) rowsProcessed, err := queryToSlice(ctx, db, query, args, destPtr)
if err != nil { if err != nil {
return fmt.Errorf("jet: %w", err) return rowsProcessed, fmt.Errorf("jet: %w", err)
} }
return nil return rowsProcessed, nil
} else if destinationPtrType.Elem().Kind() == reflect.Struct { } else if destinationPtrType.Elem().Kind() == reflect.Struct {
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem() tempSliceValue := tempSlicePtrValue.Elem()
@ -38,16 +38,16 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface()) rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface())
if err != nil { if err != nil {
return fmt.Errorf("jet: %w", err) return rowsProcessed, fmt.Errorf("jet: %w", err)
} }
if rowsProcessed == 0 { if rowsProcessed == 0 {
return ErrNoRows return 0, ErrNoRows
} }
// edge case when row result set contains only NULLs. // edge case when row result set contains only NULLs.
if tempSliceValue.Len() == 0 { if tempSliceValue.Len() == 0 {
return nil return rowsProcessed, nil
} }
structValue := reflect.ValueOf(destPtr).Elem() structValue := reflect.ValueOf(destPtr).Elem()
@ -56,7 +56,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
if structValue.Type().AssignableTo(firstTempStruct.Type()) { if structValue.Type().AssignableTo(firstTempStruct.Type()) {
structValue.Set(tempSliceValue.Index(0).Elem()) structValue.Set(tempSliceValue.Index(0).Elem())
} }
return nil return rowsProcessed, nil
} else { } else {
panic("jet: destination has to be a pointer to slice or pointer to struct") panic("jet: destination has to be a pointer to slice or pointer to struct")
} }
@ -136,7 +136,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
err = rows.Scan(scanContext.row...) err = rows.Scan(scanContext.row...)
if err != nil { if err != nil {
return return scanContext.rowNum, err
} }
scanContext.rowNum++ scanContext.rowNum++
@ -144,24 +144,16 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
_, err = mapRowToSlice(scanContext, "", newTypeStack(), slicePtrValue, nil) _, err = mapRowToSlice(scanContext, "", newTypeStack(), slicePtrValue, nil)
if err != nil { if err != nil {
return return scanContext.rowNum, err
} }
} }
err = rows.Close() err = rows.Close()
if err != nil { if err != nil {
return return scanContext.rowNum, err
} }
err = rows.Err() return scanContext.rowNum, rows.Err()
if err != nil {
return
}
rowsProcessed = scanContext.rowNum
return
} }
func mapRowToSlice( func mapRowToSlice(

View file

@ -23,5 +23,12 @@ type OrderByClause = jet.OrderByClause
// GroupByClause interface to use as input for GROUP_BY // GroupByClause interface to use as input for GROUP_BY
type GroupByClause = jet.GroupByClause type GroupByClause = jet.GroupByClause
// SetLogger sets automatic statement logging // SetLogger sets automatic statement logging.
// Deprecated: use SetQueryLoggerFunc instead.
var SetLogger = jet.SetLoggerFunc var SetLogger = jet.SetLoggerFunc
// SetQueryLoggerFunc sets automatic query logging function.
var SetQueryLoggerFunc = jet.SetQueryLoggerFunc
// QueryInfo contains information about executed query
type QueryInfo = jet.QueryInfo

View file

@ -8,6 +8,7 @@ import (
"github.com/go-jet/jet/v2/tests/dbconfig" "github.com/go-jet/jet/v2/tests/dbconfig"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"math/rand" "math/rand"
"runtime"
"time" "time"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
@ -51,11 +52,21 @@ var loggedSQL string
var loggedSQLArgs []interface{} var loggedSQLArgs []interface{}
var loggedDebugSQL string var loggedDebugSQL string
var queryInfo jetmysql.QueryInfo
var callerFile string
var callerLine int
var callerFunction string
func init() { func init() {
jetmysql.SetLogger(func(ctx context.Context, statement jetmysql.PrintableStatement) { jetmysql.SetLogger(func(ctx context.Context, statement jetmysql.PrintableStatement) {
loggedSQL, loggedSQLArgs = statement.Sql() loggedSQL, loggedSQLArgs = statement.Sql()
loggedDebugSQL = statement.DebugSql() loggedDebugSQL = statement.DebugSql()
}) })
jetmysql.SetQueryLoggerFunc(func(ctx context.Context, info jetmysql.QueryInfo) {
queryInfo = info
callerFile, callerLine, callerFunction = info.Caller()
})
} }
func requireLogged(t *testing.T, statement postgres.Statement) { func requireLogged(t *testing.T, statement postgres.Statement) {
@ -65,6 +76,21 @@ func requireLogged(t *testing.T, statement postgres.Statement) {
require.Equal(t, loggedDebugSQL, statement.DebugSql()) require.Equal(t, loggedDebugSQL, statement.DebugSql())
} }
func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) {
query, args := statement.Sql()
queryLogged, argsLogged := queryInfo.Statement.Sql()
require.Equal(t, query, queryLogged)
require.Equal(t, args, argsLogged)
require.Equal(t, queryInfo.RowsProcessed, rowsProcessed)
pc, file, _, _ := runtime.Caller(1)
funcDetails := runtime.FuncForPC(pc)
require.Equal(t, file, callerFile)
require.NotEmpty(t, callerLine)
require.Equal(t, funcDetails.Name(), callerFunction)
}
func skipForMariaDB(t *testing.T) { func skipForMariaDB(t *testing.T) {
if sourceIsMariaDB() { if sourceIsMariaDB() {
t.SkipNow() t.SkipNow()

View file

@ -38,6 +38,7 @@ WHERE actor.actor_id = ?;
testutils.AssertDeepEqual(t, actor, actor2) testutils.AssertDeepEqual(t, actor, actor2)
requireLogged(t, query) requireLogged(t, query)
requireQueryLogged(t, query, 1)
} }
var actor2 = model.Actor{ var actor2 = model.Actor{
@ -60,9 +61,9 @@ SELECT actor.actor_id AS "actor.actor_id",
FROM dvds.actor FROM dvds.actor
ORDER BY actor.actor_id; ORDER BY actor.actor_id;
`) `)
dest := []model.Actor{} var dest []model.Actor
err := query.Query(db, &dest) err := query.QueryContext(context.Background(), db, &dest)
require.NoError(t, err) require.NoError(t, err)
@ -73,6 +74,7 @@ ORDER BY actor.actor_id;
//testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json") //testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/all_actors.json") testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/all_actors.json")
requireLogged(t, query) requireLogged(t, query)
requireQueryLogged(t, query, 200)
} }
func TestSelectGroupByHaving(t *testing.T) { func TestSelectGroupByHaving(t *testing.T) {

View file

@ -35,6 +35,7 @@ ORDER BY "Album"."AlbumId" ASC;
testutils.AssertDeepEqual(t, dest[1], album2) testutils.AssertDeepEqual(t, dest[1], album2)
testutils.AssertDeepEqual(t, dest[len(dest)-1], album347) testutils.AssertDeepEqual(t, dest[len(dest)-1], album347)
requireLogged(t, stmt) requireLogged(t, stmt)
requireQueryLogged(t, stmt, 347)
} }
func TestJoinEverything(t *testing.T) { func TestJoinEverything(t *testing.T) {
@ -191,12 +192,13 @@ FROM chinook."Artist"
ORDER BY "Artist"."ArtistId", "Album"."AlbumId", "Track"."TrackId", "Genre"."GenreId", "MediaType"."MediaTypeId", "Playlist"."PlaylistId", "Invoice"."InvoiceId", "Customer"."CustomerId"; ORDER BY "Artist"."ArtistId", "Album"."AlbumId", "Track"."TrackId", "Genre"."GenreId", "MediaType"."MediaTypeId", "Playlist"."PlaylistId", "Invoice"."InvoiceId", "Customer"."CustomerId";
`) `)
err := stmt.Query(db, &dest) err := stmt.QueryContext(context.Background(), db, &dest)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(dest), 275) require.Equal(t, len(dest), 275)
testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/joined_everything.json") testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/joined_everything.json")
requireLogged(t, stmt) requireLogged(t, stmt)
requireQueryLogged(t, stmt, 9423)
} }
// default column aliases from sub-CTEs are bubbled up to the main query, // default column aliases from sub-CTEs are bubbled up to the main query,

View file

@ -25,7 +25,14 @@ WHERE link.name IN ('Gmail', 'Outlook');
WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) WHERE(Link.Name.IN(String("Gmail"), String("Outlook")))
testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook")
AssertExec(t, deleteStmt, 2)
res, err := deleteStmt.ExecContext(context.Background(), db)
require.NoError(t, err)
rows, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, rows, int64(2))
requireQueryLogged(t, deleteStmt, int64(2))
} }
func TestDeleteWithWhereAndReturning(t *testing.T) { func TestDeleteWithWhereAndReturning(t *testing.T) {

View file

@ -7,6 +7,7 @@ import (
"github.com/go-jet/jet/v2/tests/internal/utils/repo" "github.com/go-jet/jet/v2/tests/internal/utils/repo"
"math/rand" "math/rand"
"os" "os"
"runtime"
"testing" "testing"
"time" "time"
@ -59,11 +60,21 @@ var loggedSQL string
var loggedSQLArgs []interface{} var loggedSQLArgs []interface{}
var loggedDebugSQL string var loggedDebugSQL string
var queryInfo postgres.QueryInfo
var callerFile string
var callerLine int
var callerFunction string
func init() { func init() {
postgres.SetLogger(func(ctx context.Context, statement postgres.PrintableStatement) { postgres.SetLogger(func(ctx context.Context, statement postgres.PrintableStatement) {
loggedSQL, loggedSQLArgs = statement.Sql() loggedSQL, loggedSQLArgs = statement.Sql()
loggedDebugSQL = statement.DebugSql() loggedDebugSQL = statement.DebugSql()
}) })
postgres.SetQueryLoggerFunc(func(ctx context.Context, info postgres.QueryInfo) {
queryInfo = info
callerFile, callerLine, callerFunction = info.Caller()
})
} }
func requireLogged(t *testing.T, statement postgres.Statement) { func requireLogged(t *testing.T, statement postgres.Statement) {
@ -73,6 +84,21 @@ func requireLogged(t *testing.T, statement postgres.Statement) {
require.Equal(t, loggedDebugSQL, statement.DebugSql()) require.Equal(t, loggedDebugSQL, statement.DebugSql())
} }
func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) {
query, args := statement.Sql()
queryLogged, argsLogged := queryInfo.Statement.Sql()
require.Equal(t, query, queryLogged)
require.Equal(t, args, argsLogged)
require.Equal(t, queryInfo.RowsProcessed, rowsProcessed)
pc, file, _, _ := runtime.Caller(1)
funcDetails := runtime.FuncForPC(pc)
require.Equal(t, file, callerFile)
require.NotEmpty(t, callerLine)
require.Equal(t, funcDetails.Name(), callerFunction)
}
func skipForPgxDriver(t *testing.T) { func skipForPgxDriver(t *testing.T) {
if isPgxDriver() { if isPgxDriver() {
t.SkipNow() t.SkipNow()

View file

@ -62,7 +62,8 @@ func TestScanToValidDestination(t *testing.T) {
t.Run("global query function scan", func(t *testing.T) { t.Run("global query function scan", func(t *testing.T) {
queryStr, args := oneInventoryQuery.Sql() queryStr, args := oneInventoryQuery.Sql()
dest := []struct{}{} dest := []struct{}{}
err := qrm.Query(nil, db, queryStr, args, &dest) rowProcessed, err := qrm.Query(nil, db, queryStr, args, &dest)
require.Equal(t, rowProcessed, int64(1))
require.NoError(t, err) require.NoError(t, err)
}) })
@ -782,6 +783,7 @@ func TestRowsScan(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
requireLogged(t, stmt) requireLogged(t, stmt)
requireQueryLogged(t, stmt, 0)
} }
func TestScanNumericToFloat(t *testing.T) { func TestScanNumericToFloat(t *testing.T) {

View file

@ -270,7 +270,9 @@ WHERE link.id = 201::integer;
` `
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201))
AssertExec(t, stmt, 1) _, err := stmt.Exec(db)
require.NoError(t, err)
requireQueryLogged(t, stmt, 1)
} }
func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {

View file

@ -5,12 +5,14 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/sqlite" "github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/dbconfig" "github.com/go-jet/jet/v2/tests/dbconfig"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"math/rand" "math/rand"
"os" "os"
"os/exec" "os/exec"
"runtime"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -63,11 +65,36 @@ var loggedSQL string
var loggedSQLArgs []interface{} var loggedSQLArgs []interface{}
var loggedDebugSQL string var loggedDebugSQL string
var queryInfo sqlite.QueryInfo
var callerFile string
var callerLine int
var callerFunction string
func init() { func init() {
sqlite.SetLogger(func(ctx context.Context, statement sqlite.PrintableStatement) { sqlite.SetLogger(func(ctx context.Context, statement sqlite.PrintableStatement) {
loggedSQL, loggedSQLArgs = statement.Sql() loggedSQL, loggedSQLArgs = statement.Sql()
loggedDebugSQL = statement.DebugSql() loggedDebugSQL = statement.DebugSql()
}) })
sqlite.SetQueryLoggerFunc(func(ctx context.Context, info sqlite.QueryInfo) {
queryInfo = info
callerFile, callerLine, callerFunction = info.Caller()
})
}
func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) {
query, args := statement.Sql()
queryLogged, argsLogged := queryInfo.Statement.Sql()
require.Equal(t, query, queryLogged)
require.Equal(t, args, argsLogged)
require.Equal(t, queryInfo.RowsProcessed, rowsProcessed)
pc, file, _, _ := runtime.Caller(1)
funcDetails := runtime.FuncForPC(pc)
require.Equal(t, file, callerFile)
require.NotEmpty(t, callerLine)
require.Equal(t, funcDetails.Name(), callerFunction)
} }
func requireLogged(t *testing.T, statement sqlite.Statement) { func requireLogged(t *testing.T, statement sqlite.Statement) {

View file

@ -39,6 +39,7 @@ WHERE actor.actor_id = ?;
testutils.AssertDeepEqual(t, actor, actor2) testutils.AssertDeepEqual(t, actor, actor2)
requireLogged(t, query) requireLogged(t, query)
requireQueryLogged(t, query, 1)
} }
var actor2 = model.Actor{ var actor2 = model.Actor{
@ -63,7 +64,7 @@ ORDER BY actor.actor_id;
`) `)
dest := []model.Actor{} dest := []model.Actor{}
err := query.Query(db, &dest) err := query.QueryContext(context.Background(), db, &dest)
require.NoError(t, err) require.NoError(t, err)
@ -73,6 +74,7 @@ ORDER BY actor.actor_id;
//testutils.SaveJSONFile(dest, "./testdata/results/sqlite/all_actors.json") //testutils.SaveJSONFile(dest, "./testdata/results/sqlite/all_actors.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/all_actors.json") testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/all_actors.json")
requireLogged(t, query) requireLogged(t, query)
requireQueryLogged(t, query, 200)
} }
func TestSelectGroupByHaving(t *testing.T) { func TestSelectGroupByHaving(t *testing.T) {