diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 6e87da8..da3650d 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -13,19 +13,30 @@ type Statement interface { // DebugSql returns debug query where every parametrized placeholder is replaced with its argument. // Do not use it in production. Use it only for debug purposes. DebugSql() (query string) - // Query executes statement over database connection db and stores row result in destination. + // Query executes statement over database connection/transaction db and stores row result in destination. // Destination can be either pointer to struct or pointer to a slice. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. Query(db qrm.DB, destination interface{}) error - // QueryContext executes statement with a context over database connection db and stores row result in destination. + // QueryContext executes statement with a context over database connection/transaction db and stores row result in destination. // Destination can be either pointer to struct or pointer to a slice. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error - - //Exec executes statement over db connection without returning any rows. + //Exec executes statement over db connection/transaction without returning any rows. Exec(db qrm.DB) (sql.Result, error) - //Exec executes statement with context over db connection without returning any rows. + //Exec executes statement with context over db connection/transaction without returning any rows. ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error) + // Rows executes statements over db connection/transaction and returns rows + Rows(ctx context.Context, db qrm.DB) (*Rows, error) +} + +// Rows wraps sql.Rows type to add query result mapping for Scan method +type Rows struct { + *sql.Rows +} + +// Scan will map the Row values into struct destination +func (r *Rows) Scan(destination interface{}) error { + return qrm.ScanOneRowToDest(r.Rows, destination) } // SerializerStatement interface @@ -99,6 +110,20 @@ func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db q return db.ExecContext(ctx, query, args...) } +func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) (*Rows, error) { + query, args := s.Sql() + + callLogger(ctx, s) + + rows, err := db.QueryContext(ctx, query, args...) + + if err != nil { + return nil, err + } + + return &Rows{rows}, nil +} + func callLogger(ctx context.Context, statement Statement) { if logger != nil { logger(ctx, statement) diff --git a/qrm/db.go b/qrm/db.go index 564819a..6b319eb 100644 --- a/qrm/db.go +++ b/qrm/db.go @@ -5,7 +5,8 @@ import ( "database/sql" ) -// DB is common database interface used by jet execution +// DB is common database interface used by query result mapping +// Both *sql.DB and *sql.Tx implements DB interface type DB interface { Exec(query string, args ...interface{}) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) diff --git a/qrm/qrm.go b/qrm/qrm.go index 7477569..52c1a28 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -2,9 +2,12 @@ package qrm import ( "context" + "database/sql" "errors" - "github.com/go-jet/jet/v2/internal/utils" + "fmt" "reflect" + + "github.com/go-jet/jet/v2/internal/utils" ) // ErrNoRows is returned by Query when query result set is empty @@ -56,6 +59,51 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr } } +func ScanOneRowToDest(rows *sql.Rows, destPtr interface{}) error { + utils.MustBeInitializedPtr(destPtr, "jet: destination is nil") + utils.MustBe(destPtr, reflect.Ptr, "jet: destination has to be a pointer to slice or pointer to struct") + + scanContext, err := newScanContext(rows) + + if err != nil { + return fmt.Errorf("failed to create scan context, %w", err) + } + + if len(scanContext.row) == 0 { + return errors.New("empty row slice") + } + + err = rows.Scan(scanContext.row...) + + if err != nil { + return fmt.Errorf("rows scan error, %w", err) + } + + destinationPtrType := reflect.TypeOf(destPtr) + tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) + tempSliceValue := tempSlicePtrValue.Elem() + + _, err = mapRowToSlice(scanContext, "", tempSlicePtrValue, nil) + + if err != nil { + return fmt.Errorf("failed to map a row, %w", err) + } + + // edge case when row result set contains only NULLs. + if tempSliceValue.Len() == 0 { + return nil + } + + destValue := reflect.ValueOf(destPtr).Elem() + firstTempSliceValue := tempSliceValue.Index(0).Elem() + + if destValue.Type().AssignableTo(firstTempSliceValue.Type()) { + destValue.Set(tempSliceValue.Index(0).Elem()) + } + + return nil +} + func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) (rowsProcessed int64, err error) { if ctx == nil { ctx = context.Background() diff --git a/tests/mysql/raw_statement_test.go b/tests/mysql/raw_statement_test.go index 9af4c46..fd4531f 100644 --- a/tests/mysql/raw_statement_test.go +++ b/tests/mysql/raw_statement_test.go @@ -1,7 +1,9 @@ package mysql import ( + "context" "testing" + "time" "github.com/stretchr/testify/require" @@ -80,3 +82,42 @@ func TestRawStatementSelectWithArguments(t *testing.T) { LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 04:34:33", 2), }) } + +func TestRawStatementRows(t *testing.T) { + stmt := RawStatement(` + SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM dvds.actor + ORDER BY actor.actor_id`) + + rows, err := stmt.Rows(context.Background(), db) + require.NoError(t, err) + + for rows.Next() { + var actor model.Actor + err := rows.Scan(&actor) + require.NoError(t, err) + + require.NotEqual(t, actor.ActorID, int16(0)) + require.NotEqual(t, actor.FirstName, "") + require.NotEqual(t, actor.LastName, "") + require.NotEqual(t, actor.LastUpdate, time.Time{}) + + if actor.ActorID == 54 { + require.Equal(t, actor.ActorID, uint16(54)) + require.Equal(t, actor.FirstName, "PENELOPE") + require.Equal(t, actor.LastName, "PINKETT") + require.Equal(t, actor.LastUpdate.Format(time.RFC3339), "2006-02-15T04:34:33Z") + } + } + + err = rows.Close() + require.NoError(t, err) + + err = rows.Err() + require.NoError(t, err) + + requireLogged(t, stmt) +} diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 1a60a42..6bbc211 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -1,8 +1,10 @@ package mysql import ( + "context" "strings" "testing" + "time" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/mysql" @@ -887,3 +889,41 @@ LIMIT 1; require.Equal(t, dest, dest2) }) } + +func TestRowsScan(t *testing.T) { + + stmt := SELECT( + Inventory.AllColumns, + ).FROM( + Inventory, + ).ORDER_BY( + Inventory.InventoryID.ASC(), + ) + + rows, err := stmt.Rows(context.Background(), db) + require.NoError(t, err) + + for rows.Next() { + var inventory model.Inventory + err = rows.Scan(&inventory) + require.NoError(t, err) + + require.NotEqual(t, inventory.InventoryID, uint32(0)) + require.NotEqual(t, inventory.FilmID, uint16(0)) + require.NotEqual(t, inventory.StoreID, uint16(0)) + require.NotEqual(t, inventory.LastUpdate, time.Time{}) + + if inventory.InventoryID == 2103 { + require.Equal(t, inventory.FilmID, uint16(456)) + require.Equal(t, inventory.StoreID, uint8(2)) + require.Equal(t, inventory.LastUpdate.Format(time.RFC3339), "2006-02-15T05:09:17Z") + } + } + + err = rows.Close() + require.NoError(t, err) + err = rows.Err() + require.NoError(t, err) + + requireLogged(t, stmt) +} diff --git a/tests/postgres/raw_statements_test.go b/tests/postgres/raw_statements_test.go index 61c3228..a193258 100644 --- a/tests/postgres/raw_statements_test.go +++ b/tests/postgres/raw_statements_test.go @@ -1,7 +1,9 @@ package postgres import ( + "context" "testing" + "time" "github.com/stretchr/testify/require" @@ -136,3 +138,42 @@ RETURNING link.id AS "link.id", require.Equal(t, links[2].Name, "Google") require.Nil(t, links[2].Description) } + +func TestRawStatementRows(t *testing.T) { + stmt := RawStatement(` + SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM dvds.actor + ORDER BY actor.actor_id`) + + rows, err := stmt.Rows(context.Background(), db) + require.NoError(t, err) + + for rows.Next() { + var actor model.Actor + err := rows.Scan(&actor) + require.NoError(t, err) + + require.NotEqual(t, actor.ActorID, int32(0)) + require.NotEqual(t, actor.FirstName, "") + require.NotEqual(t, actor.LastName, "") + require.NotEqual(t, actor.LastUpdate, time.Time{}) + + if actor.ActorID == 54 { + require.Equal(t, actor.ActorID, int32(54)) + require.Equal(t, actor.FirstName, "Penelope") + require.Equal(t, actor.LastName, "Pinkett") + require.Equal(t, actor.LastUpdate.Format(time.RFC3339), "2013-05-26T14:47:57Z") + } + } + + err = rows.Close() + require.NoError(t, err) + + err = rows.Err() + require.NoError(t, err) + + requireLogged(t, stmt) +} diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index def3097..dacdf88 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -1,14 +1,18 @@ package postgres import ( + "context" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table" - "github.com/google/uuid" - "github.com/stretchr/testify/require" - "testing" ) var oneInventoryQuery = Inventory. @@ -722,6 +726,44 @@ func TestStructScanAllNull(t *testing.T) { }{}) } +func TestRowsScan(t *testing.T) { + + stmt := SELECT( + Inventory.AllColumns, + ).FROM( + Inventory, + ).ORDER_BY( + Inventory.InventoryID.ASC(), + ) + + rows, err := stmt.Rows(context.Background(), db) + require.NoError(t, err) + + for rows.Next() { + var inventory model.Inventory + err = rows.Scan(&inventory) + require.NoError(t, err) + + require.NotEqual(t, inventory.InventoryID, int32(0)) + require.NotEqual(t, inventory.FilmID, int16(0)) + require.NotEqual(t, inventory.StoreID, int16(0)) + require.NotEqual(t, inventory.LastUpdate, time.Time{}) + + if inventory.InventoryID == 2103 { + require.Equal(t, inventory.FilmID, int16(456)) + require.Equal(t, inventory.StoreID, int16(2)) + require.Equal(t, inventory.LastUpdate.Format(time.RFC3339), "2006-02-15T10:09:17Z") + } + } + + err = rows.Close() + require.NoError(t, err) + err = rows.Err() + require.NoError(t, err) + + requireLogged(t, stmt) +} + var address256 = model.Address{ AddressID: 256, Address: "1497 Yuzhou Drive",