Add support to retrieve Rows from statement

Rows statement method executes statements over db connection/transaction and returns Rows.
This commit is contained in:
go-jet 2021-05-16 18:46:50 +02:00
parent a5b7769589
commit 3021a6a0fd
7 changed files with 248 additions and 10 deletions

View file

@ -13,19 +13,30 @@ type Statement interface {
// DebugSql returns debug query where every parametrized placeholder is replaced with its argument. // DebugSql returns debug query where every parametrized placeholder is replaced with its argument.
// Do not use it in production. Use it only for debug purposes. // Do not use it in production. Use it only for debug purposes.
DebugSql() (query string) DebugSql() (query string)
// Query executes statement over database connection db and stores row result in destination. // Query executes statement over database connection/transaction db and stores row result in destination.
// Destination can be either pointer to struct or pointer to a slice. // Destination can be either pointer to struct or pointer to a slice.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
Query(db qrm.DB, destination interface{}) error Query(db qrm.DB, destination interface{}) error
// QueryContext executes statement with a context over database connection db and stores row result in destination. // QueryContext executes statement with a context over database connection/transaction db and stores row result in destination.
// Destination can be either pointer to struct or pointer to a slice. // Destination can be either pointer to struct or pointer to a slice.
// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows.
QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error QueryContext(ctx context.Context, db qrm.DB, destination interface{}) error
//Exec executes statement over db connection/transaction without returning any rows.
//Exec executes statement over db connection without returning any rows.
Exec(db qrm.DB) (sql.Result, error) Exec(db qrm.DB) (sql.Result, error)
//Exec executes statement with context over db connection without returning any rows. //Exec executes statement with context over db connection/transaction without returning any rows.
ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error) ExecContext(ctx context.Context, db qrm.DB) (sql.Result, error)
// Rows executes statements over db connection/transaction and returns rows
Rows(ctx context.Context, db qrm.DB) (*Rows, error)
}
// Rows wraps sql.Rows type to add query result mapping for Scan method
type Rows struct {
*sql.Rows
}
// Scan will map the Row values into struct destination
func (r *Rows) Scan(destination interface{}) error {
return qrm.ScanOneRowToDest(r.Rows, destination)
} }
// SerializerStatement interface // SerializerStatement interface
@ -99,6 +110,20 @@ func (s *serializerStatementInterfaceImpl) ExecContext(ctx context.Context, db q
return db.ExecContext(ctx, query, args...) return db.ExecContext(ctx, query, args...)
} }
func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) (*Rows, error) {
query, args := s.Sql()
callLogger(ctx, s)
rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
return &Rows{rows}, nil
}
func callLogger(ctx context.Context, statement Statement) { func callLogger(ctx context.Context, statement Statement) {
if logger != nil { if logger != nil {
logger(ctx, statement) logger(ctx, statement)

View file

@ -5,7 +5,8 @@ import (
"database/sql" "database/sql"
) )
// DB is common database interface used by jet execution // DB is common database interface used by query result mapping
// Both *sql.DB and *sql.Tx implements DB interface
type DB interface { type DB interface {
Exec(query string, args ...interface{}) (sql.Result, error) Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)

View file

@ -2,9 +2,12 @@ package qrm
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"github.com/go-jet/jet/v2/internal/utils" "fmt"
"reflect" "reflect"
"github.com/go-jet/jet/v2/internal/utils"
) )
// ErrNoRows is returned by Query when query result set is empty // ErrNoRows is returned by Query when query result set is empty
@ -56,6 +59,51 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
} }
} }
func ScanOneRowToDest(rows *sql.Rows, destPtr interface{}) error {
utils.MustBeInitializedPtr(destPtr, "jet: destination is nil")
utils.MustBe(destPtr, reflect.Ptr, "jet: destination has to be a pointer to slice or pointer to struct")
scanContext, err := newScanContext(rows)
if err != nil {
return fmt.Errorf("failed to create scan context, %w", err)
}
if len(scanContext.row) == 0 {
return errors.New("empty row slice")
}
err = rows.Scan(scanContext.row...)
if err != nil {
return fmt.Errorf("rows scan error, %w", err)
}
destinationPtrType := reflect.TypeOf(destPtr)
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem()
_, err = mapRowToSlice(scanContext, "", tempSlicePtrValue, nil)
if err != nil {
return fmt.Errorf("failed to map a row, %w", err)
}
// edge case when row result set contains only NULLs.
if tempSliceValue.Len() == 0 {
return nil
}
destValue := reflect.ValueOf(destPtr).Elem()
firstTempSliceValue := tempSliceValue.Index(0).Elem()
if destValue.Type().AssignableTo(firstTempSliceValue.Type()) {
destValue.Set(tempSliceValue.Index(0).Elem())
}
return nil
}
func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) (rowsProcessed int64, err error) { func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) (rowsProcessed int64, err error) {
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()

View file

@ -1,7 +1,9 @@
package mysql package mysql
import ( import (
"context"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -80,3 +82,42 @@ func TestRawStatementSelectWithArguments(t *testing.T) {
LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 04:34:33", 2), LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 04:34:33", 2),
}) })
} }
func TestRawStatementRows(t *testing.T) {
stmt := RawStatement(`
SELECT actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update"
FROM dvds.actor
ORDER BY actor.actor_id`)
rows, err := stmt.Rows(context.Background(), db)
require.NoError(t, err)
for rows.Next() {
var actor model.Actor
err := rows.Scan(&actor)
require.NoError(t, err)
require.NotEqual(t, actor.ActorID, int16(0))
require.NotEqual(t, actor.FirstName, "")
require.NotEqual(t, actor.LastName, "")
require.NotEqual(t, actor.LastUpdate, time.Time{})
if actor.ActorID == 54 {
require.Equal(t, actor.ActorID, uint16(54))
require.Equal(t, actor.FirstName, "PENELOPE")
require.Equal(t, actor.LastName, "PINKETT")
require.Equal(t, actor.LastUpdate.Format(time.RFC3339), "2006-02-15T04:34:33Z")
}
}
err = rows.Close()
require.NoError(t, err)
err = rows.Err()
require.NoError(t, err)
requireLogged(t, stmt)
}

View file

@ -1,8 +1,10 @@
package mysql package mysql
import ( import (
"context"
"strings" "strings"
"testing" "testing"
"time"
"github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/mysql" . "github.com/go-jet/jet/v2/mysql"
@ -887,3 +889,41 @@ LIMIT 1;
require.Equal(t, dest, dest2) require.Equal(t, dest, dest2)
}) })
} }
func TestRowsScan(t *testing.T) {
stmt := SELECT(
Inventory.AllColumns,
).FROM(
Inventory,
).ORDER_BY(
Inventory.InventoryID.ASC(),
)
rows, err := stmt.Rows(context.Background(), db)
require.NoError(t, err)
for rows.Next() {
var inventory model.Inventory
err = rows.Scan(&inventory)
require.NoError(t, err)
require.NotEqual(t, inventory.InventoryID, uint32(0))
require.NotEqual(t, inventory.FilmID, uint16(0))
require.NotEqual(t, inventory.StoreID, uint16(0))
require.NotEqual(t, inventory.LastUpdate, time.Time{})
if inventory.InventoryID == 2103 {
require.Equal(t, inventory.FilmID, uint16(456))
require.Equal(t, inventory.StoreID, uint8(2))
require.Equal(t, inventory.LastUpdate.Format(time.RFC3339), "2006-02-15T05:09:17Z")
}
}
err = rows.Close()
require.NoError(t, err)
err = rows.Err()
require.NoError(t, err)
requireLogged(t, stmt)
}

View file

@ -1,7 +1,9 @@
package postgres package postgres
import ( import (
"context"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -136,3 +138,42 @@ RETURNING link.id AS "link.id",
require.Equal(t, links[2].Name, "Google") require.Equal(t, links[2].Name, "Google")
require.Nil(t, links[2].Description) require.Nil(t, links[2].Description)
} }
func TestRawStatementRows(t *testing.T) {
stmt := RawStatement(`
SELECT actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update"
FROM dvds.actor
ORDER BY actor.actor_id`)
rows, err := stmt.Rows(context.Background(), db)
require.NoError(t, err)
for rows.Next() {
var actor model.Actor
err := rows.Scan(&actor)
require.NoError(t, err)
require.NotEqual(t, actor.ActorID, int32(0))
require.NotEqual(t, actor.FirstName, "")
require.NotEqual(t, actor.LastName, "")
require.NotEqual(t, actor.LastUpdate, time.Time{})
if actor.ActorID == 54 {
require.Equal(t, actor.ActorID, int32(54))
require.Equal(t, actor.FirstName, "Penelope")
require.Equal(t, actor.LastName, "Pinkett")
require.Equal(t, actor.LastUpdate.Format(time.RFC3339), "2013-05-26T14:47:57Z")
}
}
err = rows.Close()
require.NoError(t, err)
err = rows.Err()
require.NoError(t, err)
requireLogged(t, stmt)
}

View file

@ -1,14 +1,18 @@
package postgres package postgres
import ( import (
"context"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/postgres" . "github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/qrm"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table" . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"testing"
) )
var oneInventoryQuery = Inventory. var oneInventoryQuery = Inventory.
@ -722,6 +726,44 @@ func TestStructScanAllNull(t *testing.T) {
}{}) }{})
} }
func TestRowsScan(t *testing.T) {
stmt := SELECT(
Inventory.AllColumns,
).FROM(
Inventory,
).ORDER_BY(
Inventory.InventoryID.ASC(),
)
rows, err := stmt.Rows(context.Background(), db)
require.NoError(t, err)
for rows.Next() {
var inventory model.Inventory
err = rows.Scan(&inventory)
require.NoError(t, err)
require.NotEqual(t, inventory.InventoryID, int32(0))
require.NotEqual(t, inventory.FilmID, int16(0))
require.NotEqual(t, inventory.StoreID, int16(0))
require.NotEqual(t, inventory.LastUpdate, time.Time{})
if inventory.InventoryID == 2103 {
require.Equal(t, inventory.FilmID, int16(456))
require.Equal(t, inventory.StoreID, int16(2))
require.Equal(t, inventory.LastUpdate.Format(time.RFC3339), "2006-02-15T10:09:17Z")
}
}
err = rows.Close()
require.NoError(t, err)
err = rows.Err()
require.NoError(t, err)
requireLogged(t, stmt)
}
var address256 = model.Address{ var address256 = model.Address{
AddressID: 256, AddressID: 256,
Address: "1497 Yuzhou Drive", Address: "1497 Yuzhou Drive",