[Bug] Statement Query and Exec methods can not be used with sql.Conn

This commit is contained in:
go-jet 2022-05-13 14:04:11 +02:00
parent 84dbda5948
commit 3b0285cc4b
4 changed files with 71 additions and 18 deletions

View file

@ -11,23 +11,23 @@ import (
type Statement interface { type Statement interface {
// Sql returns parametrized sql query with list of arguments. // Sql returns parametrized sql query with list of arguments.
Sql() (query string, args []interface{}) Sql() (query string, args []interface{})
// DebugSql returns debug query where every parametrized placeholder is replaced with its argument. // DebugSql returns debug query where every parametrized placeholder is replaced with its argument string representation.
// Do not use it in production. Use it only for debug purposes. // Do not use it in production. Use it only for debug purposes.
DebugSql() (query string) DebugSql() (query string)
// Query executes statement over database connection/transaction db and stores row result in destination. // Query executes statement over database connection/transaction db and stores row results in destination.
// Destination can be either pointer to struct or pointer to a slice. // Destination can be either pointer to struct or pointer to a slice.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
Query(db qrm.DB, destination interface{}) error Query(db qrm.Queryable, destination interface{}) error
// QueryContext executes statement with a context over database connection/transaction db and stores row result in destination. // QueryContext executes statement with a context over database connection/transaction db and stores row result in destination.
// Destination can be either pointer to struct or pointer to a slice. // Destination can be either pointer to struct or pointer to a slice.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error
// Exec executes statement over db connection/transaction without returning any rows. // Exec executes statement over db connection/transaction without returning any rows.
Exec(db qrm.DB) (sql.Result, error) Exec(db qrm.Executable) (sql.Result, error)
// ExecContext executes statement with context over db connection/transaction without returning any rows. // ExecContext executes statement with context over db connection/transaction without returning any rows.
ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error) ExecContext(ctx context.Context, db qrm.Executable) (sql.Result, error)
// Rows executes statements over db connection/transaction and returns rows // Rows executes statements over db connection/transaction and returns rows
Rows(ctx context.Context, db qrm.DB) (*Rows, error) Rows(ctx context.Context, db qrm.Queryable) (*Rows, error)
} }
// Rows wraps sql.Rows type to add query result mapping for Scan method // Rows wraps sql.Rows type to add query result mapping for Scan method
@ -86,11 +86,11 @@ func (s *serializerStatementInterfaceImpl) DebugSql() (query string) {
return return
} }
func (s *serializerStatementInterfaceImpl) Query(db qrm.DB, destination interface{}) error { func (s *serializerStatementInterfaceImpl) Query(db qrm.Queryable, destination interface{}) error {
return s.QueryContext(context.Background(), db, destination) return s.QueryContext(context.Background(), db, destination)
} }
func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error { func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db qrm.Queryable, destination interface{}) error {
query, args := s.Sql() query, args := s.Sql()
callLogger(ctx, s) callLogger(ctx, s)
@ -112,11 +112,11 @@ func (s *serializerStatementInterfaceImpl) QueryContext(ctx context.Context, db
return err return err
} }
func (s *serializerStatementInterfaceImpl) Exec(db qrm.DB) (res sql.Result, err error) { func (s *serializerStatementInterfaceImpl) Exec(db qrm.Executable) (res sql.Result, err error) {
return s.ExecContext(context.Background(), db) return s.ExecContext(context.Background(), db)
} }
func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.DB) (res sql.Result, err error) { func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db qrm.Executable) (res sql.Result, err error) {
query, args := s.Sql() query, args := s.Sql()
callLogger(ctx, s) callLogger(ctx, s)
@ -141,7 +141,7 @@ func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db q
return res, err return res, err
} }
func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) (*Rows, error) { func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.Queryable) (*Rows, error) {
query, args := s.Sql() query, args := s.Sql()
callLogger(ctx, s) callLogger(ctx, s)

View file

@ -13,3 +13,13 @@ type DB interface {
Query(query string, args ...interface{}) (*sql.Rows, error) Query(query string, args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
} }
// Queryable interface for sql QueryContext method
type Queryable interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
// Executable interface for sql ExecContext method
type Executable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}

View file

@ -17,7 +17,7 @@ var ErrNoRows = errors.New("qrm: no rows in result set")
// using context `ctx` into destination `destPtr`. // using context `ctx` into destination `destPtr`.
// Destination can be either pointer to struct or pointer to slice of structs. // Destination can be either pointer to struct or pointer to slice of structs.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) { func Query(ctx context.Context, db Queryable, query string, args []interface{}, destPtr interface{}) (rowsProcessed int64, err error) {
utils.MustBeInitializedPtr(db, "jet: db is nil") utils.MustBeInitializedPtr(db, "jet: db is nil")
utils.MustBeInitializedPtr(destPtr, "jet: destination is nil") utils.MustBeInitializedPtr(destPtr, "jet: destination is nil")
@ -88,7 +88,7 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac
return nil return nil
} }
func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) (rowsProcessed int64, err error) { func queryToSlice(ctx context.Context, db Queryable, query string, args []interface{}, slicePtr interface{}) (rowsProcessed int64, err error) {
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }

View file

@ -1,6 +1,8 @@
package postgres package postgres
import ( import (
"context"
"github.com/go-jet/jet/v2/qrm"
"testing" "testing"
"time" "time"
@ -24,9 +26,9 @@ FROM dvds.actor
WHERE actor.actor_id = 2; WHERE actor.actor_id = 2;
` `
query := Actor. query := SELECT(Actor.AllColumns).
SELECT(Actor.AllColumns).
DISTINCT(). DISTINCT().
FROM(Actor).
WHERE(Actor.ActorID.EQ(Int(2))) WHERE(Actor.ActorID.EQ(Int(2)))
testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(2)) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(2))
@ -44,7 +46,6 @@ WHERE actor.actor_id = 2;
} }
testutils.AssertDeepEqual(t, actor, expectedActor) testutils.AssertDeepEqual(t, actor, expectedActor)
requireLogged(t, query) requireLogged(t, query)
} }
@ -166,7 +167,7 @@ SELECT customer.customer_id AS "customer.customer_id",
FROM dvds.customer FROM dvds.customer
ORDER BY customer.customer_id ASC; ORDER BY customer.customer_id ASC;
` `
customers := []model.Customer{} var customers []model.Customer
query := Customer.SELECT(Customer.AllColumns).ORDER_BY(Customer.CustomerID.ASC()) query := Customer.SELECT(Customer.AllColumns).ORDER_BY(Customer.CustomerID.ASC())
@ -2680,6 +2681,48 @@ SELECT dvds.get_film_count(100, 120) AS "film_count";
require.Equal(t, dest.FilmCount, 165) require.Equal(t, dest.FilmCount, 165)
} }
func TestScanUsingConn(t *testing.T) {
conn, err := db.Conn(context.Background())
require.NoError(t, err)
defer conn.Close()
stmt := SELECT(Actor.AllColumns).
FROM(Actor).
DISTINCT().
WHERE(Actor.ActorID.EQ(Int(2)))
var actor model.Actor
err = stmt.Query(conn, &actor)
require.NoError(t, err)
err = stmt.QueryContext(context.Background(), conn, &actor)
require.NoError(t, err)
testutils.AssertDeepEqual(t, actor, model.Actor{
ActorID: 2,
FirstName: "Nick",
LastName: "Wahlberg",
LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:47:57.62", 2),
})
_, err = stmt.Exec(conn)
require.NoError(t, err)
_, err = stmt.ExecContext(context.Background(), conn)
require.NoError(t, err)
t.Run("ensure qrm.DB still works", func(t *testing.T) {
var qrmDB qrm.DB = db
err = stmt.Query(qrmDB, &actor)
require.NoError(t, err)
err = stmt.QueryContext(context.Background(), qrmDB, &actor)
require.NoError(t, err)
_, err = stmt.Exec(qrmDB)
require.NoError(t, err)
_, err = stmt.ExecContext(context.Background(), qrmDB)
require.NoError(t, err)
})
}
var customer0 = model.Customer{ var customer0 = model.Customer{
CustomerID: 1, CustomerID: 1,
StoreID: 1, StoreID: 1,