From 4955bfc4b5ce0ff96aa6a74a103c94defa407cc3 Mon Sep 17 00:00:00 2001 From: go-jet Date: Wed, 12 Jan 2022 19:03:50 +0100 Subject: [PATCH] Add automatic query logger function with additional execution details. --- generator/mysql/query_set.go | 6 +-- generator/postgres/query_set.go | 6 +-- generator/sqlite/query_set.go | 4 +- internal/jet/logger.go | 66 ++++++++++++++++++++++++++- internal/jet/statement.go | 76 +++++++++++++++++++++++-------- mysql/types.go | 7 +++ postgres/types.go | 9 +++- qrm/qrm.go | 32 +++++-------- sqlite/types.go | 9 +++- tests/mysql/main_test.go | 26 +++++++++++ tests/mysql/select_test.go | 6 ++- tests/postgres/chinook_db_test.go | 4 +- tests/postgres/delete_test.go | 9 +++- tests/postgres/main_test.go | 26 +++++++++++ tests/postgres/scan_test.go | 4 +- tests/postgres/update_test.go | 4 +- tests/sqlite/main_test.go | 27 +++++++++++ tests/sqlite/select_test.go | 4 +- 18 files changed, 266 insertions(+), 59 deletions(-) diff --git a/generator/mysql/query_set.go b/generator/mysql/query_set.go index 85c4278..5847be4 100644 --- a/generator/mysql/query_set.go +++ b/generator/mysql/query_set.go @@ -20,7 +20,7 @@ WHERE table_schema = ? and table_type = ?; ` var tables []metadata.Table - err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) + _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) throw.OnError(err) for i := range tables { @@ -53,7 +53,7 @@ WHERE table_schema = ? AND table_name = ? ORDER BY ordinal_position; ` var columns []metadata.Column - err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns) + _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName, schemaName, tableName}, &columns) throw.OnError(err) return columns @@ -72,7 +72,7 @@ WHERE c.table_schema = ? AND DATA_TYPE = 'enum'; Values string } - err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult) + _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &queryResult) throw.OnError(err) var ret []metadata.Enum diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index e2fb969..93e6ffb 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -19,7 +19,7 @@ WHERE table_schema = $1 and table_type = $2; ` var tables []metadata.Table - err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) + _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableType}, &tables) throw.OnError(err) for i := range tables { @@ -58,7 +58,7 @@ where table_schema = $1 and table_name = $2 order by ordinal_position; ` var columns []metadata.Column - err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns) + _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName, tableName}, &columns) throw.OnError(err) return columns @@ -76,7 +76,7 @@ ORDER BY n.nspname, t.typname, e.enumsortorder;` var result []metadata.Enum - err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result) + _, err := qrm.Query(context.Background(), db, query, []interface{}{schemaName}, &result) throw.OnError(err) return result diff --git a/generator/sqlite/query_set.go b/generator/sqlite/query_set.go index e1d5e4d..c11f210 100644 --- a/generator/sqlite/query_set.go +++ b/generator/sqlite/query_set.go @@ -28,7 +28,7 @@ func (p sqliteQuerySet) GetTablesMetaData(db *sql.DB, schemaName string, tableTy var tables []metadata.Table - err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables) + _, err := qrm.Query(context.Background(), db, query, []interface{}{sqlTableType}, &tables) throw.OnError(err) for i := range tables { @@ -47,7 +47,7 @@ func (p sqliteQuerySet) GetTableColumnsMetaData(db *sql.DB, schemaName string, t Pk int32 } - err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos) + _, err := qrm.Query(context.Background(), db, query, []interface{}{tableName}, &columnInfos) throw.OnError(err) var columns []metadata.Column diff --git a/internal/jet/logger.go b/internal/jet/logger.go index c900fc0..b883fee 100644 --- a/internal/jet/logger.go +++ b/internal/jet/logger.go @@ -1,6 +1,11 @@ package jet -import "context" +import ( + "context" + "runtime" + "strings" + "time" +) // PrintableStatement is a statement which sql query can be logged type PrintableStatement interface { @@ -8,7 +13,7 @@ type PrintableStatement interface { DebugSql() (query string) } -// LoggerFunc is a definition of a function user can implement to support automatic statement logging. +// LoggerFunc is a function user can implement to support automatic statement logging. type LoggerFunc func(ctx context.Context, statement PrintableStatement) var logger LoggerFunc @@ -17,3 +22,60 @@ var logger LoggerFunc func SetLoggerFunc(loggerFunc LoggerFunc) { logger = loggerFunc } + +func callLogger(ctx context.Context, statement Statement) { + if logger != nil { + logger(ctx, statement) + } +} + +// QueryInfo contains information about executed query +type QueryInfo struct { + Statement PrintableStatement + // Depending of statement execution method RowsProcessed is: + // - Number of rows returned for Query() and QueryContext() methods + // - RowsAffected() for Exec() and ExecContext() methods + // - Always 0 for Rows() method. + RowsProcessed int64 + Duration time.Duration + Err error +} + +// QueryLoggerFunc is a function user can implement to retrieve more information about statement executed. +type QueryLoggerFunc func(ctx context.Context, info QueryInfo) + +var queryLoggerFunc QueryLoggerFunc + +// SetQueryLoggerFunc sets automatic query logging function. +func SetQueryLoggerFunc(loggerFunc QueryLoggerFunc) { + queryLoggerFunc = loggerFunc +} + +func callQueryLoggerFunc(ctx context.Context, info QueryInfo) { + if queryLoggerFunc != nil { + queryLoggerFunc(ctx, info) + } +} + +// Caller returns information about statement caller +func (q QueryInfo) Caller() (file string, line int, function string) { + skip := 4 + // depending on execution type (Query, QueryContext, Exec, ...) looped once or twice + for { + var pc uintptr + var ok bool + + pc, file, line, ok = runtime.Caller(skip) + if !ok { + return + } + + funcDetails := runtime.FuncForPC(pc) + if !strings.Contains(funcDetails.Name(), "github.com/go-jet/jet/v2/internal") { + function = funcDetails.Name() + return + } + + skip++ + } +} diff --git a/internal/jet/statement.go b/internal/jet/statement.go index a5ae83b..b205801 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "github.com/go-jet/jet/v2/qrm" + "time" ) //Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK) @@ -21,9 +22,9 @@ type Statement interface { // Destination can be either pointer to struct or pointer to a slice. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error - //Exec executes statement over db connection/transaction without returning any rows. + // Exec executes statement over db connection/transaction without returning any rows. Exec(db qrm.DB) (sql.Result, error) - //Exec executes statement with context over db connection/transaction without returning any rows. + // ExecContext executes statement with context over db connection/transaction without returning any rows. ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error) // Rows executes statements over db connection/transaction and returns rows Rows(ctx context.Context, db qrm.DB) (*Rows, error) @@ -84,12 +85,7 @@ func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { } func (s *serializerStatementInterfaceImpl) Query(db qrm.DB, destination interface{}) error { - query, args := s.Sql() - ctx := context.Background() - - callLogger(ctx, s) - - return qrm.Query(ctx, db, query, args, destination) + return s.QueryContext(context.Background(), db, destination) } func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error { @@ -97,15 +93,25 @@ func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db callLogger(ctx, s) - return qrm.Query(ctx, db, query, args, destination) + var rowsProcessed int64 + var err error + + duration := duration(func() { + rowsProcessed, err = qrm.Query(ctx, db, query, args, destination) + }) + + callQueryLoggerFunc(ctx, QueryInfo{ + Statement: s, + RowsProcessed: rowsProcessed, + Duration: duration, + Err: err, + }) + + return err } func (s *serializerStatementInterfaceImpl) Exec(db qrm.DB) (res sql.Result, err error) { - query, args := s.Sql() - - callLogger(context.Background(), s) - - return db.Exec(query, args...) + return s.ExecContext(context.Background(), db) } func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.DB) (res sql.Result, err error) { @@ -113,7 +119,24 @@ func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db q callLogger(ctx, s) - return db.ExecContext(ctx, query, args...) + duration := duration(func() { + res, err = db.ExecContext(ctx, query, args...) + }) + + var rowsAffected int64 + + if err == nil { + rowsAffected, _ = res.RowsAffected() + } + + callQueryLoggerFunc(ctx, QueryInfo{ + Statement: s, + RowsProcessed: rowsAffected, + Duration: duration, + Err: err, + }) + + return res, err } func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) (*Rows, error) { @@ -121,7 +144,18 @@ func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) callLogger(ctx, s) - rows, err := db.QueryContext(ctx, query, args...) + var rows *sql.Rows + var err error + + duration := duration(func() { + rows, err = db.QueryContext(ctx, query, args...) + }) + + callQueryLoggerFunc(ctx, QueryInfo{ + Statement: s, + Duration: duration, + Err: err, + }) if err != nil { return nil, err @@ -130,10 +164,12 @@ func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) return &Rows{rows}, nil } -func callLogger(ctx context.Context, statement Statement) { - if logger != nil { - logger(ctx, statement) - } +func duration(f func()) time.Duration { + start := time.Now() + + f() + + return time.Now().Sub(start) } // ExpressionStatement interfacess diff --git a/mysql/types.go b/mysql/types.go index c82962f..3e15043 100644 --- a/mysql/types.go +++ b/mysql/types.go @@ -24,4 +24,11 @@ type OrderByClause = jet.OrderByClause type GroupByClause = jet.GroupByClause // SetLogger sets automatic statement logging +// Deprecated: use SetQueryLoggerFunc instead. var SetLogger = jet.SetLoggerFunc + +// SetQueryLoggerFunc sets automatic query logging function. +var SetQueryLoggerFunc = jet.SetQueryLoggerFunc + +// QueryInfo contains information about executed query +type QueryInfo = jet.QueryInfo diff --git a/postgres/types.go b/postgres/types.go index 6fed21b..1e0d7ea 100644 --- a/postgres/types.go +++ b/postgres/types.go @@ -23,5 +23,12 @@ type OrderByClause = jet.OrderByClause // GroupByClause interface to use as input for GROUP_BY type GroupByClause = jet.GroupByClause -// SetLogger sets automatic statement logging +// SetLogger sets automatic statement logging function +// Deprecated: use SetQueryLoggerFunc instead. var SetLogger = jet.SetLoggerFunc + +// SetQueryLoggerFunc sets automatic query logging function. +var SetQueryLoggerFunc = jet.SetQueryLoggerFunc + +// QueryInfo contains information about executed query +type QueryInfo = jet.QueryInfo diff --git a/qrm/qrm.go b/qrm/qrm.go index 473f8f3..3731c68 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -17,7 +17,7 @@ var ErrNoRows = errors.New("qrm: no rows in result set") // using context `ctx` into destination `destPtr`. // Destination can be either pointer to struct or pointer to slice of structs. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. -func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) error { +func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) { utils.MustBeInitializedPtr(db, "jet: db is nil") utils.MustBeInitializedPtr(destPtr, "jet: destination is nil") @@ -26,11 +26,11 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr destinationPtrType := reflect.TypeOf(destPtr) if destinationPtrType.Elem().Kind() == reflect.Slice { - _, err := queryToSlice(ctx, db, query, args, destPtr) + rowsProcessed, err := queryToSlice(ctx, db, query, args, destPtr) if err != nil { - return fmt.Errorf("jet: %w", err) + return rowsProcessed, fmt.Errorf("jet: %w", err) } - return nil + return rowsProcessed, nil } else if destinationPtrType.Elem().Kind() == reflect.Struct { tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSliceValue := tempSlicePtrValue.Elem() @@ -38,16 +38,16 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface()) if err != nil { - return fmt.Errorf("jet: %w", err) + return rowsProcessed, fmt.Errorf("jet: %w", err) } if rowsProcessed == 0 { - return ErrNoRows + return 0, ErrNoRows } // edge case when row result set contains only NULLs. if tempSliceValue.Len() == 0 { - return nil + return rowsProcessed, nil } structValue := reflect.ValueOf(destPtr).Elem() @@ -56,7 +56,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr if structValue.Type().AssignableTo(firstTempStruct.Type()) { structValue.Set(tempSliceValue.Index(0).Elem()) } - return nil + return rowsProcessed, nil } else { panic("jet: destination has to be a pointer to slice or pointer to struct") } @@ -136,7 +136,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, err = rows.Scan(scanContext.row...) if err != nil { - return + return scanContext.rowNum, err } scanContext.rowNum++ @@ -144,24 +144,16 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, _, err = mapRowToSlice(scanContext, "", newTypeStack(), slicePtrValue, nil) if err != nil { - return + return scanContext.rowNum, err } } err = rows.Close() if err != nil { - return + return scanContext.rowNum, err } - err = rows.Err() - - if err != nil { - return - } - - rowsProcessed = scanContext.rowNum - - return + return scanContext.rowNum, rows.Err() } func mapRowToSlice( diff --git a/sqlite/types.go b/sqlite/types.go index 755be1d..91002c2 100644 --- a/sqlite/types.go +++ b/sqlite/types.go @@ -23,5 +23,12 @@ type OrderByClause = jet.OrderByClause // GroupByClause interface to use as input for GROUP_BY type GroupByClause = jet.GroupByClause -// SetLogger sets automatic statement logging +// SetLogger sets automatic statement logging. +// Deprecated: use SetQueryLoggerFunc instead. var SetLogger = jet.SetLoggerFunc + +// SetQueryLoggerFunc sets automatic query logging function. +var SetQueryLoggerFunc = jet.SetQueryLoggerFunc + +// QueryInfo contains information about executed query +type QueryInfo = jet.QueryInfo diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index 6a35cb8..75d4ab5 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -8,6 +8,7 @@ import ( "github.com/go-jet/jet/v2/tests/dbconfig" "github.com/stretchr/testify/require" "math/rand" + "runtime" "time" _ "github.com/go-sql-driver/mysql" @@ -51,11 +52,21 @@ var loggedSQL string var loggedSQLArgs []interface{} var loggedDebugSQL string +var queryInfo jetmysql.QueryInfo +var callerFile string +var callerLine int +var callerFunction string + func init() { jetmysql.SetLogger(func(ctx context.Context, statement jetmysql.PrintableStatement) { loggedSQL, loggedSQLArgs = statement.Sql() loggedDebugSQL = statement.DebugSql() }) + + jetmysql.SetQueryLoggerFunc(func(ctx context.Context, info jetmysql.QueryInfo) { + queryInfo = info + callerFile, callerLine, callerFunction = info.Caller() + }) } func requireLogged(t *testing.T, statement postgres.Statement) { @@ -65,6 +76,21 @@ func requireLogged(t *testing.T, statement postgres.Statement) { require.Equal(t, loggedDebugSQL, statement.DebugSql()) } +func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) { + query, args := statement.Sql() + queryLogged, argsLogged := queryInfo.Statement.Sql() + + require.Equal(t, query, queryLogged) + require.Equal(t, args, argsLogged) + require.Equal(t, queryInfo.RowsProcessed, rowsProcessed) + + pc, file, _, _ := runtime.Caller(1) + funcDetails := runtime.FuncForPC(pc) + require.Equal(t, file, callerFile) + require.NotEmpty(t, callerLine) + require.Equal(t, funcDetails.Name(), callerFunction) +} + func skipForMariaDB(t *testing.T) { if sourceIsMariaDB() { t.SkipNow() diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 5ae7280..39f0e43 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -38,6 +38,7 @@ WHERE actor.actor_id = ?; testutils.AssertDeepEqual(t, actor, actor2) requireLogged(t, query) + requireQueryLogged(t, query, 1) } var actor2 = model.Actor{ @@ -60,9 +61,9 @@ SELECT actor.actor_id AS "actor.actor_id", FROM dvds.actor ORDER BY actor.actor_id; `) - dest := []model.Actor{} + var dest []model.Actor - err := query.Query(db, &dest) + err := query.QueryContext(context.Background(), db, &dest) require.NoError(t, err) @@ -73,6 +74,7 @@ ORDER BY actor.actor_id; //testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json") testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/all_actors.json") requireLogged(t, query) + requireQueryLogged(t, query, 200) } func TestSelectGroupByHaving(t *testing.T) { diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index f5a1701..1e58a12 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -35,6 +35,7 @@ ORDER BY "Album"."AlbumId" ASC; testutils.AssertDeepEqual(t, dest[1], album2) testutils.AssertDeepEqual(t, dest[len(dest)-1], album347) requireLogged(t, stmt) + requireQueryLogged(t, stmt, 347) } func TestJoinEverything(t *testing.T) { @@ -191,12 +192,13 @@ FROM chinook."Artist" ORDER BY "Artist"."ArtistId", "Album"."AlbumId", "Track"."TrackId", "Genre"."GenreId", "MediaType"."MediaTypeId", "Playlist"."PlaylistId", "Invoice"."InvoiceId", "Customer"."CustomerId"; `) - err := stmt.Query(db, &dest) + err := stmt.QueryContext(context.Background(), db, &dest) require.NoError(t, err) require.Equal(t, len(dest), 275) testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/joined_everything.json") requireLogged(t, stmt) + requireQueryLogged(t, stmt, 9423) } // default column aliases from sub-CTEs are bubbled up to the main query, diff --git a/tests/postgres/delete_test.go b/tests/postgres/delete_test.go index c1b9c6a..9ae9ec9 100644 --- a/tests/postgres/delete_test.go +++ b/tests/postgres/delete_test.go @@ -25,7 +25,14 @@ WHERE link.name IN ('Gmail', 'Outlook'); WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") - AssertExec(t, deleteStmt, 2) + + res, err := deleteStmt.ExecContext(context.Background(), db) + + require.NoError(t, err) + rows, err := res.RowsAffected() + require.NoError(t, err) + require.Equal(t, rows, int64(2)) + requireQueryLogged(t, deleteStmt, int64(2)) } func TestDeleteWithWhereAndReturning(t *testing.T) { diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index cc20646..b4b011b 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -7,6 +7,7 @@ import ( "github.com/go-jet/jet/v2/tests/internal/utils/repo" "math/rand" "os" + "runtime" "testing" "time" @@ -59,11 +60,21 @@ var loggedSQL string var loggedSQLArgs []interface{} var loggedDebugSQL string +var queryInfo postgres.QueryInfo +var callerFile string +var callerLine int +var callerFunction string + func init() { postgres.SetLogger(func(ctx context.Context, statement postgres.PrintableStatement) { loggedSQL, loggedSQLArgs = statement.Sql() loggedDebugSQL = statement.DebugSql() }) + + postgres.SetQueryLoggerFunc(func(ctx context.Context, info postgres.QueryInfo) { + queryInfo = info + callerFile, callerLine, callerFunction = info.Caller() + }) } func requireLogged(t *testing.T, statement postgres.Statement) { @@ -73,6 +84,21 @@ func requireLogged(t *testing.T, statement postgres.Statement) { require.Equal(t, loggedDebugSQL, statement.DebugSql()) } +func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) { + query, args := statement.Sql() + queryLogged, argsLogged := queryInfo.Statement.Sql() + + require.Equal(t, query, queryLogged) + require.Equal(t, args, argsLogged) + require.Equal(t, queryInfo.RowsProcessed, rowsProcessed) + + pc, file, _, _ := runtime.Caller(1) + funcDetails := runtime.FuncForPC(pc) + require.Equal(t, file, callerFile) + require.NotEmpty(t, callerLine) + require.Equal(t, funcDetails.Name(), callerFunction) +} + func skipForPgxDriver(t *testing.T) { if isPgxDriver() { t.SkipNow() diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index ce3cc46..61b7bec 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -62,7 +62,8 @@ func TestScanToValidDestination(t *testing.T) { t.Run("global query function scan", func(t *testing.T) { queryStr, args := oneInventoryQuery.Sql() dest := []struct{}{} - err := qrm.Query(nil, db, queryStr, args, &dest) + rowProcessed, err := qrm.Query(nil, db, queryStr, args, &dest) + require.Equal(t, rowProcessed, int64(1)) require.NoError(t, err) }) @@ -782,6 +783,7 @@ func TestRowsScan(t *testing.T) { require.NoError(t, err) requireLogged(t, stmt) + requireQueryLogged(t, stmt, 0) } func TestScanNumericToFloat(t *testing.T) { diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index 476333c..cf8a3b8 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -270,7 +270,9 @@ WHERE link.id = 201::integer; ` testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) - AssertExec(t, stmt, 1) + _, err := stmt.Exec(db) + require.NoError(t, err) + requireQueryLogged(t, stmt, 1) } func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { diff --git a/tests/sqlite/main_test.go b/tests/sqlite/main_test.go index 710f7ad..76d1d58 100644 --- a/tests/sqlite/main_test.go +++ b/tests/sqlite/main_test.go @@ -5,12 +5,14 @@ import ( "database/sql" "fmt" "github.com/go-jet/jet/v2/internal/utils/throw" + "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/sqlite" "github.com/go-jet/jet/v2/tests/dbconfig" "github.com/stretchr/testify/require" "math/rand" "os" "os/exec" + "runtime" "strings" "testing" "time" @@ -63,11 +65,36 @@ var loggedSQL string var loggedSQLArgs []interface{} var loggedDebugSQL string +var queryInfo sqlite.QueryInfo +var callerFile string +var callerLine int +var callerFunction string + func init() { sqlite.SetLogger(func(ctx context.Context, statement sqlite.PrintableStatement) { loggedSQL, loggedSQLArgs = statement.Sql() loggedDebugSQL = statement.DebugSql() }) + + sqlite.SetQueryLoggerFunc(func(ctx context.Context, info sqlite.QueryInfo) { + queryInfo = info + callerFile, callerLine, callerFunction = info.Caller() + }) +} + +func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) { + query, args := statement.Sql() + queryLogged, argsLogged := queryInfo.Statement.Sql() + + require.Equal(t, query, queryLogged) + require.Equal(t, args, argsLogged) + require.Equal(t, queryInfo.RowsProcessed, rowsProcessed) + + pc, file, _, _ := runtime.Caller(1) + funcDetails := runtime.FuncForPC(pc) + require.Equal(t, file, callerFile) + require.NotEmpty(t, callerLine) + require.Equal(t, funcDetails.Name(), callerFunction) } func requireLogged(t *testing.T, statement sqlite.Statement) { diff --git a/tests/sqlite/select_test.go b/tests/sqlite/select_test.go index ce31c76..657fb3a 100644 --- a/tests/sqlite/select_test.go +++ b/tests/sqlite/select_test.go @@ -39,6 +39,7 @@ WHERE actor.actor_id = ?; testutils.AssertDeepEqual(t, actor, actor2) requireLogged(t, query) + requireQueryLogged(t, query, 1) } var actor2 = model.Actor{ @@ -63,7 +64,7 @@ ORDER BY actor.actor_id; `) dest := []model.Actor{} - err := query.Query(db, &dest) + err := query.QueryContext(context.Background(), db, &dest) require.NoError(t, err) @@ -73,6 +74,7 @@ ORDER BY actor.actor_id; //testutils.SaveJSONFile(dest, "./testdata/results/sqlite/all_actors.json") testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/all_actors.json") requireLogged(t, query) + requireQueryLogged(t, query, 200) } func TestSelectGroupByHaving(t *testing.T) {