Improve Rows scan performance

ScanContext reused between rows.Scan calls.
Simplified assign value logic.
Use complex destination for Rows test.
This commit is contained in:
go-jet 2022-02-04 12:31:08 +01:00
parent 4f29960378
commit c10244aeab
5 changed files with 114 additions and 84 deletions

View file

@ -33,11 +33,13 @@ type Statement interface {
// Rows wraps sql.Rows type to add query result mapping for Scan method // Rows wraps sql.Rows type to add query result mapping for Scan method
type Rows struct { type Rows struct {
*sql.Rows *sql.Rows
scanContext *qrm.ScanContext
} }
// Scan will map the Row values into struct destination // Scan will map the Row values into struct destination
func (r *Rows) Scan(destination interface{}) error { func (r *Rows) Scan(destination interface{}) error {
return qrm.ScanOneRowToDest(r.Rows, destination) return qrm.ScanOneRowToDest(r.scanContext, r.Rows, destination)
} }
// SerializerStatement interface // SerializerStatement interface
@ -161,7 +163,16 @@ func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB)
return nil, err return nil, err
} }
return &Rows{rows}, nil scanContext, err := qrm.NewScanContext(rows)
if err != nil {
return nil, err
}
return &Rows{
Rows: rows,
scanContext: scanContext,
}, nil
} }
func duration(f func()) time.Duration { func duration(f func()) time.Duration {

View file

@ -63,48 +63,28 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
} }
// ScanOneRowToDest will scan one row into struct destination // ScanOneRowToDest will scan one row into struct destination
func ScanOneRowToDest(rows *sql.Rows, destPtr interface{}) error { func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interface{}) error {
utils.MustBeInitializedPtr(destPtr, "jet: destination is nil") utils.MustBeInitializedPtr(destPtr, "jet: destination is nil")
utils.MustBe(destPtr, reflect.Ptr, "jet: destination has to be a pointer to slice or pointer to struct") utils.MustBe(destPtr, reflect.Ptr, "jet: destination has to be a pointer to slice or pointer to struct")
scanContext, err := newScanContext(rows)
if err != nil {
return fmt.Errorf("failed to create scan context, %w", err)
}
if len(scanContext.row) == 0 { if len(scanContext.row) == 0 {
return errors.New("empty row slice") return errors.New("empty row slice")
} }
err = rows.Scan(scanContext.row...) err := rows.Scan(scanContext.row...)
if err != nil { if err != nil {
return fmt.Errorf("rows scan error, %w", err) return fmt.Errorf("rows scan error, %w", err)
} }
destinationPtrType := reflect.TypeOf(destPtr) destValue := reflect.ValueOf(destPtr)
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem()
_, err = mapRowToSlice(scanContext, "", newTypeStack(), tempSlicePtrValue, nil) _, err = mapRowToStruct(scanContext, "", newTypeStack(), destValue, nil)
if err != nil { if err != nil {
return fmt.Errorf("failed to map a row, %w", err) return fmt.Errorf("failed to map a row, %w", err)
} }
// edge case when row result set contains only NULLs.
if tempSliceValue.Len() == 0 {
return nil
}
destValue := reflect.ValueOf(destPtr).Elem()
firstTempSliceValue := tempSliceValue.Index(0).Elem()
if destValue.Type().AssignableTo(firstTempSliceValue.Type()) {
destValue.Set(tempSliceValue.Index(0).Elem())
}
return nil return nil
} }
@ -120,7 +100,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
} }
defer rows.Close() defer rows.Close()
scanContext, err := newScanContext(rows) scanContext, err := NewScanContext(rows)
if err != nil { if err != nil {
return return
@ -157,7 +137,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
} }
func mapRowToSlice( func mapRowToSlice(
scanContext *scanContext, scanContext *ScanContext,
groupKey string, groupKey string,
typesVisited *typeStack, typesVisited *typeStack,
slicePtrValue reflect.Value, slicePtrValue reflect.Value,
@ -204,7 +184,7 @@ func mapRowToSlice(
return return
} }
func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) { func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) {
index := 0 index := 0
if field != nil { if field != nil {
typeName, columnName := getTypeAndFieldName("", *field) typeName, columnName := getTypeAndFieldName("", *field)
@ -226,7 +206,7 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value
} }
func mapRowToStruct( func mapRowToStruct(
scanContext *scanContext, scanContext *ScanContext,
groupKey string, groupKey string,
typesVisited *typeStack, // to prevent circular dependency scan typesVisited *typeStack, // to prevent circular dependency scan
structPtrValue reflect.Value, structPtrValue reflect.Value,
@ -308,7 +288,7 @@ func mapRowToStruct(
} }
func mapRowToDestinationValue( func mapRowToDestinationValue(
scanContext *scanContext, scanContext *ScanContext,
groupKey string, groupKey string,
typesVisited *typeStack, typesVisited *typeStack,
dest reflect.Value, dest reflect.Value,
@ -340,7 +320,7 @@ func mapRowToDestinationValue(
} }
func mapRowToDestinationPtr( func mapRowToDestinationPtr(
scanContext *scanContext, scanContext *ScanContext,
groupKey string, groupKey string,
typesVisited *typeStack, typesVisited *typeStack,
destPtrValue reflect.Value, destPtrValue reflect.Value,

View file

@ -7,7 +7,9 @@ import (
"strings" "strings"
) )
type scanContext struct { // ScanContext contains information about current row processed, mapping from the row to the
// destination types and type grouping information.
type ScanContext struct {
rowNum int64 rowNum int64
row []interface{} row []interface{}
uniqueDestObjectsMap map[string]int uniqueDestObjectsMap map[string]int
@ -16,7 +18,8 @@ type scanContext struct {
typeInfoMap map[string]typeInfo typeInfoMap map[string]typeInfo
} }
func newScanContext(rows *sql.Rows) (*scanContext, error) { // NewScanContext creates new ScanContext from rows
func NewScanContext(rows *sql.Rows) (*ScanContext, error) {
aliases, err := rows.Columns() aliases, err := rows.Columns()
if err != nil { if err != nil {
@ -42,7 +45,7 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
commonIdentToColumnIndex[commonIdentifier] = i commonIdentToColumnIndex[commonIdentifier] = i
} }
return &scanContext{ return &ScanContext{
row: createScanSlice(len(columnTypes)), row: createScanSlice(len(columnTypes)),
uniqueDestObjectsMap: make(map[string]int), uniqueDestObjectsMap: make(map[string]int),
@ -74,7 +77,7 @@ type fieldMapping struct {
implementsScanner bool implementsScanner bool
} }
func (s *scanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo { func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo {
typeMapKey := structType.String() typeMapKey := structType.String()
@ -120,7 +123,7 @@ type groupKeyInfo struct {
subTypes []groupKeyInfo subTypes []groupKeyInfo
} }
func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect.StructField) string { func (s *ScanContext) getGroupKey(structType reflect.Type, structField *reflect.StructField) string {
mapKey := structType.Name() mapKey := structType.Name()
@ -139,7 +142,7 @@ func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect.
return s.constructGroupKey(groupKeyInfo) return s.constructGroupKey(groupKeyInfo)
} }
func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { func (s *ScanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
if len(groupKeyInfo.indexes) == 0 && len(groupKeyInfo.subTypes) == 0 { if len(groupKeyInfo.indexes) == 0 && len(groupKeyInfo.subTypes) == 0 {
return fmt.Sprintf("|ROW:%d|", s.rowNum) return fmt.Sprintf("|ROW:%d|", s.rowNum)
} }
@ -161,7 +164,7 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")" return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")"
} }
func (s *scanContext) getGroupKeyInfo( func (s *ScanContext) getGroupKeyInfo(
structType reflect.Type, structType reflect.Type,
parentField *reflect.StructField, parentField *reflect.StructField,
typeVisited *typeStack) groupKeyInfo { typeVisited *typeStack) groupKeyInfo {
@ -210,7 +213,7 @@ func (s *scanContext) getGroupKeyInfo(
return ret return ret
} }
func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int { func (s *ScanContext) typeToColumnIndex(typeName, fieldName string) int {
var key string var key string
if typeName != "" { if typeName != "" {
@ -228,7 +231,7 @@ func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int {
return index return index
} }
func (s *scanContext) rowElem(index int) interface{} { func (s *ScanContext) rowElem(index int) interface{} {
cellValue := reflect.ValueOf(s.row[index]) cellValue := reflect.ValueOf(s.row[index])
if cellValue.IsValid() && !cellValue.IsNil() { if cellValue.IsValid() && !cellValue.IsNil() {
@ -238,7 +241,7 @@ func (s *scanContext) rowElem(index int) interface{} {
return nil return nil
} }
func (s *scanContext) rowElemValuePtr(index int) reflect.Value { func (s *ScanContext) rowElemValuePtr(index int) reflect.Value {
rowElem := s.rowElem(index) rowElem := s.rowElem(index)
rowElemValue := reflect.ValueOf(rowElem) rowElemValue := reflect.ValueOf(rowElem)

View file

@ -201,23 +201,38 @@ func isFloatType(value reflect.Type) bool {
return false return false
} }
func tryAssign(source, destination reflect.Value) error { func assignIfAssignable(source, destination reflect.Value) bool {
if source.Type() != destination.Type() &&
!isFloatType(destination.Type()) && // to preserve precision during conversion
!(isIntegerType(source.Type()) && destination.Kind() == reflect.String) && // default conversion will convert int to 1 rune string
source.Type().ConvertibleTo(destination.Type()) {
source = source.Convert(destination.Type())
}
if source.Type().AssignableTo(destination.Type()) { if source.Type().AssignableTo(destination.Type()) {
switch b := source.Interface().(type) { switch source.Type() {
case []byte: case byteArrayType:
destination.SetBytes(cloneBytes(b)) destination.SetBytes(cloneBytes(source.Interface().([]byte)))
default: default:
destination.Set(source) destination.Set(source)
} }
return true
}
return false
}
func tryAssign(source, destination reflect.Value) error {
if assignIfAssignable(source, destination) {
return nil
}
sourceType := source.Type()
destinationType := destination.Type()
if sourceType != destinationType &&
!isFloatType(destinationType) && // to preserve precision during conversion
!(isIntegerType(sourceType) && destination.Kind() == reflect.String) && // default conversion will convert int to 1 rune string
sourceType.ConvertibleTo(destinationType) {
source = source.Convert(destinationType)
}
if assignIfAssignable(source, destination) {
return nil return nil
} }
@ -302,38 +317,32 @@ func tryAssign(source, destination reflect.Value) error {
return nil return nil
} }
func setZeroValue(value reflect.Value) {
if !value.IsZero() {
value.Set(reflect.Zero(value.Type()))
}
}
func setReflectValue(source, destination reflect.Value) error { func setReflectValue(source, destination reflect.Value) error {
if source.Kind() == reflect.Ptr {
if source.IsNil() {
// source is nil, destination should be its zero value
setZeroValue(destination)
return nil
}
source = source.Elem()
}
if destination.Kind() == reflect.Ptr { if destination.Kind() == reflect.Ptr {
if destination.IsNil() { if destination.IsNil() {
initializeValueIfNilPtr(destination) initializeValueIfNilPtr(destination)
} }
if source.Kind() == reflect.Ptr { destination = destination.Elem()
if source.IsNil() {
return nil // source is nil, destination should keep its zero value
}
source = source.Elem()
}
if err := tryAssign(source, destination.Elem()); err != nil {
return err
}
} else {
if source.Kind() == reflect.Ptr {
if source.IsNil() {
return nil // source is nil, destination should keep its zero value
}
source = source.Elem()
}
if err := tryAssign(source, destination); err != nil {
return err
}
} }
return nil return tryAssign(source, destination)
} }
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool { func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {

View file

@ -951,8 +951,12 @@ func TestRowsScan(t *testing.T) {
stmt := SELECT( stmt := SELECT(
Inventory.AllColumns, Inventory.AllColumns,
Film.AllColumns,
Store.AllColumns,
).FROM( ).FROM(
Inventory, Inventory.
INNER_JOIN(Film, Film.FilmID.EQ(Inventory.FilmID)).
INNER_JOIN(Store, Store.StoreID.EQ(Inventory.StoreID)),
).ORDER_BY( ).ORDER_BY(
Inventory.InventoryID.ASC(), Inventory.InventoryID.ASC(),
) )
@ -961,19 +965,42 @@ func TestRowsScan(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
for rows.Next() { for rows.Next() {
var inventory model.Inventory var inventory struct {
model.Inventory
Film model.Film
Store model.Store
}
err = rows.Scan(&inventory) err = rows.Scan(&inventory)
require.NoError(t, err) require.NoError(t, err)
require.NotEqual(t, inventory.InventoryID, uint32(0)) require.NotEmpty(t, inventory.InventoryID)
require.NotEqual(t, inventory.FilmID, uint16(0)) require.NotEmpty(t, inventory.FilmID)
require.NotEqual(t, inventory.StoreID, uint16(0)) require.NotEmpty(t, inventory.StoreID)
require.NotEqual(t, inventory.LastUpdate, time.Time{}) require.NotEmpty(t, inventory.LastUpdate)
require.NotEmpty(t, inventory.Film.FilmID)
require.NotEmpty(t, inventory.Film.Title)
require.NotEmpty(t, inventory.Film.Description)
require.NotEmpty(t, inventory.Store.StoreID)
require.NotEmpty(t, inventory.Store.AddressID)
require.NotEmpty(t, inventory.Store.ManagerStaffID)
if inventory.InventoryID == 2103 { if inventory.InventoryID == 2103 {
require.Equal(t, inventory.FilmID, uint16(456)) require.Equal(t, inventory.FilmID, uint16(456))
require.Equal(t, inventory.StoreID, uint8(2)) require.Equal(t, inventory.StoreID, uint8(2))
require.Equal(t, inventory.LastUpdate.Format(time.RFC3339), "2006-02-15T05:09:17Z") require.Equal(t, inventory.LastUpdate.Format(time.RFC3339), "2006-02-15T05:09:17Z")
require.Equal(t, inventory.Film.FilmID, uint16(456))
require.Equal(t, inventory.Film.Title, "INCH JET")
require.Equal(t, *inventory.Film.Description, "A Fateful Saga of a Womanizer And a Student who must Defeat a Butler in A Monastery")
require.Equal(t, *inventory.Film.ReleaseYear, int16(2006))
require.Equal(t, inventory.Store.StoreID, uint8(2))
require.Equal(t, inventory.Store.ManagerStaffID, uint8(2))
require.Equal(t, inventory.Store.AddressID, uint16(2))
} }
} }