From 0d3ec872d6b7351aa890028c0fdba710c9e2c2f8 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 10 May 2020 11:41:07 +0200 Subject: [PATCH] Add support for automatic query logging. --- internal/jet/logger.go | 19 +++++++++++++++++++ internal/jet/statement.go | 31 +++++++++++++++++++++++-------- mysql/types.go | 6 ++++++ postgres/types.go | 6 ++++++ tests/mysql/alltypes_test.go | 1 + tests/mysql/cast_test.go | 2 ++ tests/mysql/delete_test.go | 3 +++ tests/mysql/insert_test.go | 2 ++ tests/mysql/lock_test.go | 3 +++ tests/mysql/main_test.go | 22 ++++++++++++++++++++++ tests/mysql/select_test.go | 3 +++ tests/mysql/update_test.go | 8 ++++++-- tests/postgres/alltypes_test.go | 3 +++ tests/postgres/chinook_db_test.go | 2 ++ tests/postgres/delete_test.go | 3 +++ tests/postgres/insert_test.go | 3 +++ tests/postgres/lock_test.go | 2 ++ tests/postgres/main_test.go | 21 +++++++++++++++++++++ tests/postgres/northwind_test.go | 1 + tests/postgres/sample_test.go | 3 +++ tests/postgres/select_test.go | 6 ++++++ tests/postgres/update_test.go | 15 ++++++++++++--- 22 files changed, 152 insertions(+), 13 deletions(-) create mode 100644 internal/jet/logger.go diff --git a/internal/jet/logger.go b/internal/jet/logger.go new file mode 100644 index 0000000..90818b0 --- /dev/null +++ b/internal/jet/logger.go @@ -0,0 +1,19 @@ +package jet + +import "context" + +// LoggableStatement is a statement which sql query can be logged +type LoggableStatement interface { + Sql() (query string, args []interface{}) + DebugSql() (query string) +} + +// LoggerFunc is a definition of a function user can implement to support automatic statement logging. +type LoggerFunc func(ctx context.Context, statement LoggableStatement) + +var logger LoggerFunc + +// SetLoggerFunc sets automatic statement logging +func SetLoggerFunc(loggerFunc LoggerFunc) { + logger = loggerFunc +} diff --git a/internal/jet/statement.go b/internal/jet/statement.go index beb52f1..37b2077 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -13,7 +13,6 @@ type Statement interface { // DebugSql returns debug query where every parametrized placeholder is replaced with its argument. // Do not use it in production. Use it only for debug purposes. DebugSql() (query string) - // Query executes statement over database connection db and stores row result in destination. // Destination can be either pointer to struct or pointer to a slice. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. @@ -21,12 +20,12 @@ type Statement interface { // QueryContext executes statement with a context over database connection db and stores row result in destination. // Destination can be either pointer to struct or pointer to a slice. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. - QueryContext(context context.Context, db qrm.DB, destination interface{}) error + QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error //Exec executes statement over db connection without returning any rows. Exec(db qrm.DB) (sql.Result, error) //Exec executes statement with context over db connection without returning any rows. - ExecContext(context context.Context, db qrm.DB) (sql.Result, error) + ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error) } // SerializerStatement interface @@ -75,25 +74,41 @@ func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { func (s *serializerStatementInterfaceImpl) Query(db qrm.DB, destination interface{}) error { query, args := s.Sql() + ctx := context.Background() - return qrm.Query(context.Background(), db, query, args, destination) + callLogger(ctx, s) + + return qrm.Query(ctx, db, query, args, destination) } -func (s *serializerStatementInterfaceImpl) QueryContext(context context.Context, db qrm.DB, destination interface{}) error { +func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error { query, args := s.Sql() - return qrm.Query(context, db, query, args, destination) + callLogger(ctx, s) + + return qrm.Query(ctx, db, query, args, destination) } func (s *serializerStatementInterfaceImpl) Exec(db qrm.DB) (res sql.Result, err error) { query, args := s.Sql() + + callLogger(context.Background(), s) + return db.Exec(query, args...) } -func (s *serializerStatementInterfaceImpl) ExecContext(context context.Context, db qrm.DB) (res sql.Result, err error) { +func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.DB) (res sql.Result, err error) { query, args := s.Sql() - return db.ExecContext(context, query, args...) + callLogger(ctx, s) + + return db.ExecContext(ctx, query, args...) +} + +func callLogger(ctx context.Context, statement Statement) { + if logger != nil { + logger(ctx, statement) + } } // ExpressionStatement interfacess diff --git a/mysql/types.go b/mysql/types.go index 908fce5..7e1424f 100644 --- a/mysql/types.go +++ b/mysql/types.go @@ -13,3 +13,9 @@ type ProjectionList = jet.ProjectionList // ColumnAssigment is interface wrapper around column assigment type ColumnAssigment = jet.ColumnAssigment + +// LoggableStatement is a statement which sql query can be logged +type LoggableStatement = jet.LoggableStatement + +// SetLogger sets automatic statement logging +var SetLogger = jet.SetLoggerFunc diff --git a/postgres/types.go b/postgres/types.go index 48de455..cfb52ec 100644 --- a/postgres/types.go +++ b/postgres/types.go @@ -13,3 +13,9 @@ type ProjectionList = jet.ProjectionList // ColumnAssigment is interface wrapper around column assigment type ColumnAssigment = jet.ColumnAssigment + +// LoggableStatement is a statement which sql query can be logged +type LoggableStatement = jet.LoggableStatement + +// SetLogger sets automatic statement logging +var SetLogger = jet.SetLoggerFunc diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 7d9e17e..d791d42 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -80,6 +80,7 @@ func TestUUID(t *testing.T) { require.True(t, dest.UUID.String() != uuid.UUID{}.String()) require.True(t, dest.StrUUID.String() != uuid.UUID{}.String()) require.Equal(t, dest.StrUUID.String(), dest.BinUUID.String()) + requireLogged(t, query) } func TestExpressionOperators(t *testing.T) { diff --git a/tests/mysql/cast_test.go b/tests/mysql/cast_test.go index 218665e..fda79e7 100644 --- a/tests/mysql/cast_test.go +++ b/tests/mysql/cast_test.go @@ -68,4 +68,6 @@ FROM test_sample.all_types; Unsigned: 15, Binary: "Some text", }) + + requireLogged(t, query) } diff --git a/tests/mysql/delete_test.go b/tests/mysql/delete_test.go index da91e97..90d15cc 100644 --- a/tests/mysql/delete_test.go +++ b/tests/mysql/delete_test.go @@ -24,6 +24,7 @@ WHERE link.name IN ('Gmail', 'Outlook'); testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") testutils.AssertExec(t, deleteStmt, db, 2) + requireLogged(t, deleteStmt) } func TestDeleteWithWhereOrderByLimit(t *testing.T) { @@ -43,6 +44,7 @@ LIMIT 1; testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook", int64(1)) testutils.AssertExec(t, deleteStmt, db, 1) + requireLogged(t, deleteStmt) } func TestDeleteQueryContext(t *testing.T) { @@ -61,6 +63,7 @@ func TestDeleteQueryContext(t *testing.T) { err := deleteStmt.QueryContext(ctx, db, &dest) require.Error(t, err, "context deadline exceeded") + requireLogged(t, deleteStmt) } func TestDeleteExecContext(t *testing.T) { diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index 613a655..43091b2 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -34,6 +34,7 @@ VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT _, err := insertQuery.Exec(db) require.NoError(t, err) + requireLogged(t, insertQuery) insertedLinks := []model.Link{} @@ -82,6 +83,7 @@ VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT _, err := stmt.Exec(db) require.NoError(t, err) + requireLogged(t, stmt) insertedLinks := []model.Link{} diff --git a/tests/mysql/lock_test.go b/tests/mysql/lock_test.go index 8aed571..c44c436 100644 --- a/tests/mysql/lock_test.go +++ b/tests/mysql/lock_test.go @@ -17,6 +17,7 @@ LOCK TABLES dvds.customer READ; _, err := query.Exec(db) require.NoError(t, err) + requireLogged(t, query) } func TestLockWrite(t *testing.T) { @@ -28,6 +29,7 @@ LOCK TABLES dvds.customer WRITE; _, err := query.Exec(db) require.NoError(t, err) + requireLogged(t, query) } func TestUnlockTables(t *testing.T) { @@ -39,4 +41,5 @@ UNLOCK TABLES; _, err := query.Exec(db) require.NoError(t, err) + requireLogged(t, query) } diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index c7db884..0f51875 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -1,9 +1,13 @@ package mysql import ( + "context" "database/sql" "flag" + jetmysql "github.com/go-jet/jet/mysql" + "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/dbconfig" + "github.com/stretchr/testify/require" "math/rand" "time" @@ -44,3 +48,21 @@ func TestMain(m *testing.M) { os.Exit(ret) } + +var loggedSQL string +var loggedSQLArgs []interface{} +var loggedDebugSQL string + +func init() { + jetmysql.SetLogger(func(ctx context.Context, statement jetmysql.LoggableStatement) { + loggedSQL, loggedSQLArgs = statement.Sql() + loggedDebugSQL = statement.DebugSql() + }) +} + +func requireLogged(t *testing.T, statement postgres.Statement) { + query, args := statement.Sql() + require.Equal(t, loggedSQL, query) + require.Equal(t, loggedSQLArgs, args) + require.Equal(t, loggedDebugSQL, statement.DebugSql()) +} diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index e5b748e..5fd8fdc 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -33,6 +33,7 @@ WHERE actor.actor_id = ?; require.NoError(t, err) testutils.AssertDeepEqual(t, actor, actor2) + requireLogged(t, query) } var actor2 = model.Actor{ @@ -67,6 +68,7 @@ ORDER BY actor.actor_id; //testutils.PrintJson(dest) //testutils.SaveJsonFile(dest, "mysql/testdata/all_actors.json") testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/all_actors.json") + requireLogged(t, query) } func TestSelectGroupByHaving(t *testing.T) { @@ -144,6 +146,7 @@ ORDER BY payment.customer_id, SUM(payment.amount) ASC; //testutils.SaveJsonFile(dest, "mysql/testdata/customer_payment_sum.json") testutils.AssertJSONFile(t, dest, "./testdata/results/mysql/customer_payment_sum.json") + requireLogged(t, query) } func TestSubQuery(t *testing.T) { diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index a689584..94a6716 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -30,6 +30,7 @@ WHERE link.name = 'Bing'; testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") testutils.AssertExec(t, query, db) + requireLogged(t, query) }) t.Run("new version", func(t *testing.T) { @@ -42,6 +43,7 @@ WHERE link.name = 'Bing'; testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "Bong", "http://bong.com", "Bing") testutils.AssertExec(t, stmt, db) + requireLogged(t, stmt) }) links := []model.Link{} @@ -88,6 +90,7 @@ WHERE link.name = ?; testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") testutils.AssertExec(t, query, db) + requireLogged(t, query) }) t.Run("new version", func(t *testing.T) { @@ -105,6 +108,7 @@ WHERE link.name = ?; testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") testutils.AssertExec(t, query, db) + requireLogged(t, query) }) } @@ -130,10 +134,10 @@ SET id = ?, description = ? WHERE link.id = ?; ` - fmt.Println(stmt.Sql()) testutils.AssertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) testutils.AssertExec(t, stmt, db) + requireLogged(t, stmt) } func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { @@ -165,10 +169,10 @@ WHERE link.id = 201; testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201)) testutils.AssertExec(t, stmt, db) + requireLogged(t, stmt) } func TestUpdateWithModelDataAndMutableColumns(t *testing.T) { - setupLinkTableForUpdateTest(t) link := model.Link{ diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index d4e6ab4..3fa5543 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -834,6 +834,7 @@ func TestInterval(t *testing.T) { err := stmt.Query(db, &struct{}{}) require.NoError(t, err) + requireLogged(t, stmt) } func TestSubQueryColumnReference(t *testing.T) { @@ -1009,6 +1010,7 @@ FROM` require.NoError(t, err) testutils.AssertDeepEqual(t, dest1, dest2) + requireLogged(t, stmt2) } } @@ -1062,6 +1064,7 @@ LIMIT $6; "Timestamp": "2009-11-17T20:34:58.651387Z" } `) + requireLogged(t, query) } var allTypesRow0 = model.AllTypes{ diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index 2695981..5c12010 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -35,6 +35,7 @@ ORDER BY "Album"."AlbumId" ASC; testutils.AssertDeepEqual(t, dest[0], album1) testutils.AssertDeepEqual(t, dest[1], album2) testutils.AssertDeepEqual(t, dest[len(dest)-1], album347) + requireLogged(t, stmt) } func TestJoinEverything(t *testing.T) { @@ -106,6 +107,7 @@ func TestJoinEverything(t *testing.T) { require.NoError(t, err) require.Equal(t, len(dest), 275) testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/joined_everything.json") + requireLogged(t, stmt) } func TestSelfJoin(t *testing.T) { diff --git a/tests/postgres/delete_test.go b/tests/postgres/delete_test.go index 18080fc..01104d4 100644 --- a/tests/postgres/delete_test.go +++ b/tests/postgres/delete_test.go @@ -53,6 +53,7 @@ RETURNING link.id AS "link.id", require.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0].Name, "Gmail") testutils.AssertDeepEqual(t, dest[1].Name, "Outlook") + requireLogged(t, deleteStmt) } func initForDeleteTest(t *testing.T) { @@ -80,6 +81,7 @@ func TestDeleteQueryContext(t *testing.T) { err := deleteStmt.QueryContext(ctx, db, &dest) require.Error(t, err, "context deadline exceeded") + requireLogged(t, deleteStmt) } func TestDeleteExecContext(t *testing.T) { @@ -99,4 +101,5 @@ func TestDeleteExecContext(t *testing.T) { _, err := deleteStmt.ExecContext(ctx, db) require.Error(t, err, "context deadline exceeded") + requireLogged(t, deleteStmt) } diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index a7facea..e7dac15 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -89,6 +89,7 @@ VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial") AssertExec(t, stmt, 1) + requireLogged(t, stmt) } func TestInsertOnConflict(t *testing.T) { @@ -108,6 +109,7 @@ VALUES ($1, $2, $3, $4, $5), ON CONFLICT (employee_id) DO NOTHING; `) AssertExec(t, stmt, 1) + requireLogged(t, stmt) }) t.Run("on constraint do nothing", func(t *testing.T) { @@ -125,6 +127,7 @@ VALUES ($1, $2, $3, $4, $5), ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; `) AssertExec(t, stmt, 1) + requireLogged(t, stmt) }) t.Run("do update", func(t *testing.T) { diff --git a/tests/postgres/lock_test.go b/tests/postgres/lock_test.go index ce55874..c27adf3 100644 --- a/tests/postgres/lock_test.go +++ b/tests/postgres/lock_test.go @@ -40,6 +40,7 @@ LOCK TABLE dvds.address IN` err = tx.Rollback() require.NoError(t, err) + requireLogged(t, query) } for _, lockMode := range testData { @@ -56,6 +57,7 @@ LOCK TABLE dvds.address IN` err = tx.Rollback() require.NoError(t, err) + requireLogged(t, query) } } diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index fd538d6..5fa23d5 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -1,10 +1,13 @@ package postgres import ( + "context" "database/sql" + "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/dbconfig" _ "github.com/lib/pq" "github.com/pkg/profile" + "github.com/stretchr/testify/require" "math/rand" "os" "os/exec" @@ -43,3 +46,21 @@ func setTestRoot() { testRoot = strings.TrimSpace(string(byteArr)) + "/tests/" } + +var loggedSQL string +var loggedSQLArgs []interface{} +var loggedDebugSQL string + +func init() { + postgres.SetLogger(func(ctx context.Context, statement postgres.LoggableStatement) { + loggedSQL, loggedSQLArgs = statement.Sql() + loggedDebugSQL = statement.DebugSql() + }) +} + +func requireLogged(t *testing.T, statement postgres.Statement) { + query, args := statement.Sql() + require.Equal(t, loggedSQL, query) + require.Equal(t, loggedSQLArgs, args) + require.Equal(t, loggedDebugSQL, statement.DebugSql()) +} diff --git a/tests/postgres/northwind_test.go b/tests/postgres/northwind_test.go index 80ab589..e45661a 100644 --- a/tests/postgres/northwind_test.go +++ b/tests/postgres/northwind_test.go @@ -63,4 +63,5 @@ func TestNorthwindJoinEverything(t *testing.T) { //jsonSave("./testdata/northwind-all.json", dest) testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/northwind-all.json") + requireLogged(t, stmt) } diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 698e648..b3429fc 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -28,6 +28,7 @@ WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'; require.NoError(t, err) require.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) testutils.AssertDeepEqual(t, result.UUIDPtr, testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) + requireLogged(t, query) } func TestUUIDComplex(t *testing.T) { @@ -118,6 +119,7 @@ func TestUUIDComplex(t *testing.T) { ] } `) + requireLogged(t, query) }) t.Run("slice of structs left join", func(t *testing.T) { @@ -175,6 +177,7 @@ func TestUUIDComplex(t *testing.T) { } ] `) + requireLogged(t, leftQuery) }) } diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 5c52e71..35f9803 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -43,6 +43,8 @@ WHERE actor.actor_id = 2; } testutils.AssertDeepEqual(t, actor, expectedActor) + + requireLogged(t, query) } func TestClassicSelect(t *testing.T) { @@ -86,6 +88,8 @@ LIMIT 30; require.NoError(t, err) require.Equal(t, len(dest), 30) + + requireLogged(t, query) } func TestSelect_ScanToSlice(t *testing.T) { @@ -117,6 +121,8 @@ ORDER BY customer.customer_id ASC; testutils.AssertDeepEqual(t, customer0, customers[0]) testutils.AssertDeepEqual(t, customer1, customers[1]) testutils.AssertDeepEqual(t, lastCustomer, customers[598]) + + requireLogged(t, query) } func TestSelectAndUnionInProjection(t *testing.T) { diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index 0862fb3..5a7cdcd 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -27,13 +27,15 @@ WHERE link.name = 'Bing'; `, "Bong", "http://bong.com", "Bing") testutils.AssertExec(t, query, db, 1) + requireLogged(t, query) links := []model.Link{} - err := Link. + selQuery := Link. SELECT(Link.AllColumns). - WHERE(Link.Name.IN(String("Bong"))). - Query(db, &links) + WHERE(Link.Name.IN(String("Bong"))) + + err := selQuery.Query(db, &links) require.NoError(t, err) require.Equal(t, len(links), 1) @@ -42,6 +44,7 @@ WHERE link.name = 'Bing'; URL: "http://bong.com", Name: "Bong", }) + requireLogged(t, selQuery) }) t.Run("new version", func(t *testing.T) { @@ -59,6 +62,7 @@ SET name = 'DuckDuckGo', WHERE link.name = 'Yahoo'; `) testutils.AssertExec(t, stmt, db, 1) + requireLogged(t, stmt) }) } @@ -90,6 +94,7 @@ WHERE link.name = 'Bing'; testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "Bing", "Bing") AssertExec(t, query, 1) + requireLogged(t, query) }) t.Run("new version", func(t *testing.T) { @@ -114,6 +119,9 @@ SET name = $1, ) WHERE link.name = $3; `, "Bong", "Bing", "Bing") + _, err := query.Exec(db) + require.NoError(t, err) + requireLogged(t, query) }) } @@ -146,6 +154,7 @@ RETURNING link.id AS "link.id", require.Equal(t, len(links), 2) require.Equal(t, links[0].Name, "DuckDuckGo") require.Equal(t, links[1].Name, "DuckDuckGo") + requireLogged(t, stmt) } func TestUpdateWithSelect(t *testing.T) {