From 3b0285cc4b481991ba2f9520846a8bb35a0019f6 Mon Sep 17 00:00:00 2001 From: go-jet Date: Fri, 13 May 2022 14:04:11 +0200 Subject: [PATCH] [Bug] Statement Query and Exec methods can not be used with sql.Conn --- internal/jet/statement.go | 24 ++++++++--------- qrm/db.go | 10 +++++++ qrm/qrm.go | 4 +-- tests/postgres/select_test.go | 51 ++++++++++++++++++++++++++++++++--- 4 files changed, 71 insertions(+), 18 deletions(-) diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 183aaae..11d8c95 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -11,23 +11,23 @@ import ( type Statement interface { // Sql returns parametrized sql query with list of arguments. Sql() (query string, args []interface{}) - // DebugSql returns debug query where every parametrized placeholder is replaced with its argument. + // DebugSql returns debug query where every parametrized placeholder is replaced with its argument string representation. // Do not use it in production. Use it only for debug purposes. DebugSql() (query string) - // Query executes statement over database connection/transaction db and stores row result in destination. + // Query executes statement over database connection/transaction db and stores row results in destination. // Destination can be either pointer to struct or pointer to a slice. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. - Query(db qrm.DB, destination interface{}) error + Query(db qrm.Queryable, destination interface{}) error // QueryContext executes statement with a context over database connection/transaction db and stores row result in destination. // Destination can be either pointer to struct or pointer to a slice. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. - QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error + QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error // Exec executes statement over db connection/transaction without returning any rows. - Exec(db qrm.DB) (sql.Result, error) + Exec(db qrm.Executable) (sql.Result, error) // ExecContext executes statement with context over db connection/transaction without returning any rows. - ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error) + ExecContext(ctx context.Context, db qrm.Executable) (sql.Result, error) // Rows executes statements over db connection/transaction and returns rows - Rows(ctx context.Context, db qrm.DB) (*Rows, error) + Rows(ctx context.Context, db qrm.Queryable) (*Rows, error) } // Rows wraps sql.Rows type to add query result mapping for Scan method @@ -86,11 +86,11 @@ func (s *serializerStatementInterfaceImpl) DebugSql() (query string) { return } -func (s *serializerStatementInterfaceImpl) Query(db qrm.DB, destination interface{}) error { +func (s *serializerStatementInterfaceImpl) Query(db qrm.Queryable, destination interface{}) error { return s.QueryContext(context.Background(), db, destination) } -func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error { +func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error { query, args := s.Sql() callLogger(ctx, s) @@ -112,11 +112,11 @@ func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db return err } -func (s *serializerStatementInterfaceImpl) Exec(db qrm.DB) (res sql.Result, err error) { +func (s *serializerStatementInterfaceImpl) Exec(db qrm.Executable) (res sql.Result, err error) { return s.ExecContext(context.Background(), db) } -func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.DB) (res sql.Result, err error) { +func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.Executable) (res sql.Result, err error) { query, args := s.Sql() callLogger(ctx, s) @@ -141,7 +141,7 @@ func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db q return res, err } -func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) (*Rows, error) { +func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.Queryable) (*Rows, error) { query, args := s.Sql() callLogger(ctx, s) diff --git a/qrm/db.go b/qrm/db.go index 6b319eb..1efefb1 100644 --- a/qrm/db.go +++ b/qrm/db.go @@ -13,3 +13,13 @@ type DB interface { Query(query string, args ...interface{}) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } + +// Queryable interface for sql QueryContext method +type Queryable interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) +} + +// Executable interface for sql ExecContext method +type Executable interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} diff --git a/qrm/qrm.go b/qrm/qrm.go index 50597cd..1c559f6 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -17,7 +17,7 @@ var ErrNoRows = errors.New("qrm: no rows in result set") // using context `ctx` into destination `destPtr`. // Destination can be either pointer to struct or pointer to slice of structs. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. -func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) { +func Query(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) { utils.MustBeInitializedPtr(db, "jet: db is nil") utils.MustBeInitializedPtr(destPtr, "jet: destination is nil") @@ -88,7 +88,7 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac return nil } -func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) (rowsProcessed int64, err error) { +func queryToSlice(ctx context.Context, db Queryable, query string, args []interface{}, slicePtr interface{}) (rowsProcessed int64, err error) { if ctx == nil { ctx = context.Background() } diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index b52508f..dce9b87 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1,6 +1,8 @@ package postgres import ( + "context" + "github.com/go-jet/jet/v2/qrm" "testing" "time" @@ -24,9 +26,9 @@ FROM dvds.actor WHERE actor.actor_id = 2; ` - query := Actor. - SELECT(Actor.AllColumns). + query := SELECT(Actor.AllColumns). DISTINCT(). + FROM(Actor). WHERE(Actor.ActorID.EQ(Int(2))) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(2)) @@ -44,7 +46,6 @@ WHERE actor.actor_id = 2; } testutils.AssertDeepEqual(t, actor, expectedActor) - requireLogged(t, query) } @@ -166,7 +167,7 @@ SELECT customer.customer_id AS "customer.customer_id", FROM dvds.customer ORDER BY customer.customer_id ASC; ` - customers := []model.Customer{} + var customers []model.Customer query := Customer.SELECT(Customer.AllColumns).ORDER_BY(Customer.CustomerID.ASC()) @@ -2680,6 +2681,48 @@ SELECT dvds.get_film_count(100, 120) AS "film_count"; require.Equal(t, dest.FilmCount, 165) } +func TestScanUsingConn(t *testing.T) { + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + defer conn.Close() + + stmt := SELECT(Actor.AllColumns). + FROM(Actor). + DISTINCT(). + WHERE(Actor.ActorID.EQ(Int(2))) + + var actor model.Actor + err = stmt.Query(conn, &actor) + require.NoError(t, err) + err = stmt.QueryContext(context.Background(), conn, &actor) + require.NoError(t, err) + testutils.AssertDeepEqual(t, actor, model.Actor{ + ActorID: 2, + FirstName: "Nick", + LastName: "Wahlberg", + LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:47:57.62", 2), + }) + + _, err = stmt.Exec(conn) + require.NoError(t, err) + _, err = stmt.ExecContext(context.Background(), conn) + require.NoError(t, err) + + t.Run("ensure qrm.DB still works", func(t *testing.T) { + var qrmDB qrm.DB = db + + err = stmt.Query(qrmDB, &actor) + require.NoError(t, err) + err = stmt.QueryContext(context.Background(), qrmDB, &actor) + require.NoError(t, err) + + _, err = stmt.Exec(qrmDB) + require.NoError(t, err) + _, err = stmt.ExecContext(context.Background(), qrmDB) + require.NoError(t, err) + }) +} + var customer0 = model.Customer{ CustomerID: 1, StoreID: 1,