Additional scan performance improvements

Move typeStack to ScanContext, so it is shared between rows.Scan calls.
Use string.Builder for string concatenations.
Simplify value assign logic.
Move convert value to the last assign step (needs for type conversions are rare).
This commit is contained in:
go-jet 2022-02-09 12:34:10 +01:00
parent c10244aeab
commit c86903fd1d
8 changed files with 428 additions and 174 deletions

View file

@ -74,15 +74,15 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac
err := rows.Scan(scanContext.row...)
if err != nil {
return fmt.Errorf("rows scan error, %w", err)
return fmt.Errorf("jet: rows scan error, %w", err)
}
destValue := reflect.ValueOf(destPtr)
destValuePtr := reflect.ValueOf(destPtr)
_, err = mapRowToStruct(scanContext, "", newTypeStack(), destValue, nil)
_, err = mapRowToStruct(scanContext, "", destValuePtr, nil)
if err != nil {
return fmt.Errorf("failed to map a row, %w", err)
return fmt.Errorf("jet: failed to scan a row into destination, %w", err)
}
return nil
@ -121,7 +121,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
scanContext.rowNum++
_, err = mapRowToSlice(scanContext, "", newTypeStack(), slicePtrValue, nil)
_, err = mapRowToSlice(scanContext, "", slicePtrValue, nil)
if err != nil {
return scanContext.rowNum, err
@ -139,7 +139,6 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
func mapRowToSlice(
scanContext *ScanContext,
groupKey string,
typesVisited *typeStack,
slicePtrValue reflect.Value,
field *reflect.StructField) (updated bool, err error) {
@ -154,19 +153,19 @@ func mapRowToSlice(
structGroupKey := scanContext.getGroupKey(sliceElemType, field)
groupKey = groupKey + "," + structGroupKey
groupKey = concat(groupKey, ",", structGroupKey)
index, ok := scanContext.uniqueDestObjectsMap[groupKey]
if ok {
structPtrValue := getSliceElemPtrAt(slicePtrValue, index)
return mapRowToStruct(scanContext, groupKey, typesVisited, structPtrValue, field, true)
return mapRowToStruct(scanContext, groupKey, structPtrValue, field, true)
}
destinationStructPtr := newElemPtrValueForSlice(slicePtrValue)
updated, err = mapRowToStruct(scanContext, groupKey, typesVisited, destinationStructPtr, field)
updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field)
if err != nil {
return
@ -192,7 +191,7 @@ func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value
return
}
}
rowElemPtr := scanContext.rowElemValuePtr(index)
rowElemPtr := scanContext.rowElemValueClonePtr(index)
if rowElemPtr.IsValid() && !rowElemPtr.IsNil() {
updated = true
@ -208,7 +207,6 @@ func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value
func mapRowToStruct(
scanContext *ScanContext,
groupKey string,
typesVisited *typeStack, // to prevent circular dependency scan
structPtrValue reflect.Value,
parentField *reflect.StructField,
onlySlices ...bool, // small optimization, not to assign to already assigned struct fields
@ -217,12 +215,12 @@ func mapRowToStruct(
mapOnlySlices := len(onlySlices) > 0
structType := structPtrValue.Type().Elem()
if typesVisited.contains(&structType) {
if scanContext.typesVisited.contains(&structType) {
return false, nil
}
typesVisited.push(&structType)
defer typesVisited.pop()
scanContext.typesVisited.push(&structType)
defer scanContext.typesVisited.pop()
typeInf := scanContext.getTypeInfo(structType, parentField)
@ -240,7 +238,7 @@ func mapRowToStruct(
if fieldMap.complexType {
var changed bool
changed, err = mapRowToDestinationValue(scanContext, groupKey, typesVisited, fieldValue, &field)
changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field)
if err != nil {
return
@ -251,34 +249,36 @@ func mapRowToStruct(
}
} else {
if mapOnlySlices || fieldMap.columnIndex == -1 {
if mapOnlySlices || fieldMap.rowIndex == -1 {
continue
}
cellValue := scanContext.rowElem(fieldMap.columnIndex)
scannedValue := scanContext.rowElemValue(fieldMap.rowIndex)
if cellValue == nil {
if !scannedValue.IsValid() {
setZeroValue(fieldValue) // scannedValue is nil, destination should be set to zero value
continue
}
initializeValueIfNilPtr(fieldValue)
updated = true
if fieldMap.implementsScanner {
scanner := getScanner(fieldValue)
initializeValueIfNilPtr(fieldValue)
fieldScanner := getScanner(fieldValue)
err = scanner.Scan(cellValue)
value := scannedValue.Interface()
err := fieldScanner.Scan(value)
if err != nil {
err = fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err)
return
return updated, fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, value, value, field.Name, field.Type.String(), err)
}
} else {
err = setReflectValue(reflect.ValueOf(cellValue), fieldValue)
err := assign(scannedValue, fieldValue)
if err != nil {
err = fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err)
return
return updated, fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, scannedValue.Interface(), scannedValue.Interface(),
field.Name, field.Type.String(), err)
}
}
}
@ -290,7 +290,6 @@ func mapRowToStruct(
func mapRowToDestinationValue(
scanContext *ScanContext,
groupKey string,
typesVisited *typeStack,
dest reflect.Value,
structField *reflect.StructField) (updated bool, err error) {
@ -306,7 +305,7 @@ func mapRowToDestinationValue(
}
}
updated, err = mapRowToDestinationPtr(scanContext, groupKey, typesVisited, destPtrValue, structField)
updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField)
if err != nil {
return
@ -322,7 +321,6 @@ func mapRowToDestinationValue(
func mapRowToDestinationPtr(
scanContext *ScanContext,
groupKey string,
typesVisited *typeStack,
destPtrValue reflect.Value,
structField *reflect.StructField) (updated bool, err error) {
@ -331,9 +329,9 @@ func mapRowToDestinationPtr(
destValueKind := destPtrValue.Elem().Kind()
if destValueKind == reflect.Struct {
return mapRowToStruct(scanContext, groupKey, typesVisited, destPtrValue, structField)
return mapRowToStruct(scanContext, groupKey, destPtrValue, structField)
} else if destValueKind == reflect.Slice {
return mapRowToSlice(scanContext, groupKey, typesVisited, destPtrValue, structField)
return mapRowToSlice(scanContext, groupKey, destPtrValue, structField)
} else {
panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String())
}

View file

@ -16,6 +16,8 @@ type ScanContext struct {
commonIdentToColumnIndex map[string]int
groupKeyInfoCache map[string]groupKeyInfo
typeInfoMap map[string]typeInfo
typesVisited typeStack // to prevent circular dependency scan
}
// NewScanContext creates new ScanContext from rows
@ -39,7 +41,7 @@ func NewScanContext(rows *sql.Rows) (*ScanContext, error) {
commonIdentifier := toCommonIdentifier(names[0])
if len(names) > 1 {
commonIdentifier += "." + toCommonIdentifier(names[1])
commonIdentifier = concat(commonIdentifier, ".", toCommonIdentifier(names[1]))
}
commonIdentToColumnIndex[commonIdentifier] = i
@ -53,15 +55,17 @@ func NewScanContext(rows *sql.Rows) (*ScanContext, error) {
commonIdentToColumnIndex: commonIdentToColumnIndex,
typeInfoMap: make(map[string]typeInfo),
typesVisited: newTypeStack(),
}, nil
}
func createScanSlice(columnCount int) []interface{} {
scanSlice := make([]interface{}, columnCount)
scanPtrSlice := make([]interface{}, columnCount)
for i := range scanPtrSlice {
scanPtrSlice[i] = &scanSlice[i] // if destination is pointer to interface sql.Scan will just forward driver value
var a interface{}
scanPtrSlice[i] = &a // if destination is pointer to interface sql.Scan will just forward driver value
}
return scanPtrSlice
@ -72,8 +76,8 @@ type typeInfo struct {
}
type fieldMapping struct {
complexType bool // slice or struct
columnIndex int
complexType bool // slice and struct are complex types
rowIndex int // index in ScanContext.row
implementsScanner bool
}
@ -82,7 +86,7 @@ func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.
typeMapKey := structType.String()
if parentField != nil {
typeMapKey += string(parentField.Tag)
typeMapKey = concat(typeMapKey, string(parentField.Tag))
}
if typeInfo, ok := s.typeInfoMap[typeMapKey]; ok {
@ -100,7 +104,7 @@ func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.
columnIndex := s.typeToColumnIndex(newTypeName, fieldName)
fieldMap := fieldMapping{
columnIndex: columnIndex,
rowIndex: columnIndex,
}
if implementsScannerType(field.Type) {
@ -128,14 +132,15 @@ func (s *ScanContext) getGroupKey(structType reflect.Type, structField *reflect.
mapKey := structType.Name()
if structField != nil {
mapKey += structField.Type.String()
mapKey = concat(mapKey, structField.Type.String())
}
if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok {
return s.constructGroupKey(groupKeyInfo)
}
groupKeyInfo := s.getGroupKeyInfo(structType, structField, newTypeStack())
tempTypeStack := newTypeStack()
groupKeyInfo := s.getGroupKeyInfo(structType, structField, &tempTypeStack)
s.groupKeyInfoCache[mapKey] = groupKeyInfo
@ -150,10 +155,7 @@ func (s *ScanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
var groupKeys []string
for _, index := range groupKeyInfo.indexes {
cellValue := s.rowElem(index)
subKey := valueToString(reflect.ValueOf(cellValue))
groupKeys = append(groupKeys, subKey)
groupKeys = append(groupKeys, s.rowElemToString(index))
}
var subTypesGroupKeys []string
@ -161,7 +163,7 @@ func (s *ScanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType))
}
return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")"
return concat(groupKeyInfo.typeName, "(", strings.Join(groupKeys, ","), strings.Join(subTypesGroupKeys, ","), ")")
}
func (s *ScanContext) getGroupKeyInfo(
@ -231,32 +233,36 @@ func (s *ScanContext) typeToColumnIndex(typeName, fieldName string) int {
return index
}
func (s *ScanContext) rowElem(index int) interface{} {
cellValue := reflect.ValueOf(s.row[index])
if cellValue.IsValid() && !cellValue.IsNil() {
return cellValue.Elem().Interface()
}
return nil
// rowElemValue always returns non-ptr value,
// invalid value is nil
func (s *ScanContext) rowElemValue(index int) reflect.Value {
scannedValue := reflect.ValueOf(s.row[index])
return scannedValue.Elem().Elem() // no need to check validity of Elem, because s.row[index] always contains interface in interface
}
func (s *ScanContext) rowElemValuePtr(index int) reflect.Value {
rowElem := s.rowElem(index)
rowElemValue := reflect.ValueOf(rowElem)
func (s *ScanContext) rowElemToString(index int) string {
value := s.rowElemValue(index)
if !value.IsValid() {
return "nil"
}
valueInterface := value.Interface()
if t, ok := valueInterface.(fmt.Stringer); ok {
return t.String()
}
return fmt.Sprintf("%#v", valueInterface)
}
func (s *ScanContext) rowElemValueClonePtr(index int) reflect.Value {
rowElemValue := s.rowElemValue(index)
if !rowElemValue.IsValid() {
return reflect.Value{}
}
if rowElemValue.Kind() == reflect.Ptr {
return rowElemValue
}
if rowElemValue.CanAddr() {
return rowElemValue.Addr()
}
newElem := reflect.New(rowElemValue.Type())
newElem.Elem().Set(rowElemValue)
return newElem

View file

@ -4,9 +4,9 @@ import "reflect"
type typeStack []*reflect.Type
func newTypeStack() *typeStack {
func newTypeStack() typeStack {
stack := make(typeStack, 0, 20)
return &stack
return stack
}
func (s *typeStack) isEmpty() bool {

View file

@ -18,9 +18,9 @@ func implementsScannerType(fieldType reflect.Type) bool {
return true
}
typePtr := reflect.New(fieldType).Type()
fieldTypePtr := reflect.New(fieldType).Type()
return typePtr.Implements(scannerInterfaceType)
return fieldTypePtr.Implements(scannerInterfaceType)
}
func getScanner(value reflect.Value) sql.Scanner {
@ -68,9 +68,9 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
if newSliceElemValue.Kind() == reflect.Ptr {
newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem()))
err = tryAssign(objPtrValue.Elem(), newSliceElemValue.Elem())
err = assign(objPtrValue.Elem(), newSliceElemValue.Elem())
} else {
err = tryAssign(objPtrValue.Elem(), newSliceElemValue)
err = assign(objPtrValue.Elem(), newSliceElemValue)
}
if err != nil {
@ -138,29 +138,6 @@ func initializeValueIfNilPtr(value reflect.Value) {
}
}
func valueToString(value reflect.Value) string {
if !value.IsValid() {
return "nil"
}
var valueInterface interface{}
if value.Kind() == reflect.Ptr {
if value.IsNil() {
return "nil"
}
valueInterface = value.Elem().Interface()
} else {
valueInterface = value.Interface()
}
if t, ok := valueInterface.(fmt.Stringer); ok {
return t.String()
}
return fmt.Sprintf("%#v", valueInterface)
}
var timeType = reflect.TypeOf(time.Now())
var uuidType = reflect.TypeOf(uuid.New())
var byteArrayType = reflect.TypeOf([]byte(""))
@ -180,30 +157,35 @@ func isSimpleModelType(objType reflect.Type) bool {
return objType == timeType || objType == uuidType || objType == byteArrayType
}
func isIntegerType(objType reflect.Type) bool {
objType = indirectType(objType)
// source can't be pointer
// destination can be pointer
func assign(source, destination reflect.Value) error {
if destination.Kind() == reflect.Ptr {
if destination.IsNil() {
initializeValueIfNilPtr(destination)
}
switch objType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
destination = destination.Elem()
}
return false
}
err := tryAssign(source, destination)
func isFloatType(value reflect.Type) bool {
switch value.Kind() {
case reflect.Float32, reflect.Float64:
return true
if err != nil {
// needs for the type conversions are rare, so we leave conversion as a last assign step if everything else fails
if tryConvert(source, destination) {
return nil
}
return err
}
return false
return nil
}
func assignIfAssignable(source, destination reflect.Value) bool {
if source.Type().AssignableTo(destination.Type()) {
switch source.Type() {
sourceType := source.Type()
if sourceType.AssignableTo(destination.Type()) {
switch sourceType {
case byteArrayType:
destination.SetBytes(cloneBytes(source.Interface().([]byte)))
default:
@ -215,31 +197,17 @@ func assignIfAssignable(source, destination reflect.Value) bool {
return false
}
// source and destination are non-ptr values
func tryAssign(source, destination reflect.Value) error {
if assignIfAssignable(source, destination) {
return nil
}
sourceType := source.Type()
destinationType := destination.Type()
if sourceType != destinationType &&
!isFloatType(destinationType) && // to preserve precision during conversion
!(isIntegerType(sourceType) && destination.Kind() == reflect.String) && // default conversion will convert int to 1 rune string
sourceType.ConvertibleTo(destinationType) {
source = source.Convert(destinationType)
}
if assignIfAssignable(source, destination) {
return nil
}
sourceInterface := source.Interface()
switch destination.Interface().(type) {
case bool:
switch destination.Type().Kind() {
case reflect.Bool:
var nullBool internal.NullBool
err := nullBool.Scan(sourceInterface)
@ -250,7 +218,7 @@ func tryAssign(source, destination reflect.Value) error {
destination.SetBool(nullBool.Bool)
case float32, float64:
case reflect.Float32, reflect.Float64:
var nullFloat sql.NullFloat64
err := nullFloat.Scan(sourceInterface)
@ -261,7 +229,7 @@ func tryAssign(source, destination reflect.Value) error {
if nullFloat.Valid {
destination.SetFloat(nullFloat.Float64)
}
case int, int8, int16, int32, int64:
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var integer sql.NullInt64
err := integer.Scan(sourceInterface)
@ -273,7 +241,7 @@ func tryAssign(source, destination reflect.Value) error {
destination.SetInt(integer.Int64)
}
case uint, uint8, uint16, uint32, uint64:
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
var uInt internal.NullUInt64
err := uInt.Scan(sourceInterface)
@ -286,7 +254,7 @@ func tryAssign(source, destination reflect.Value) error {
destination.SetUint(uInt.UInt64)
}
case string:
case reflect.String:
var str sql.NullString
err := str.Scan(sourceInterface)
@ -298,53 +266,44 @@ func tryAssign(source, destination reflect.Value) error {
destination.SetString(str.String)
}
case time.Time:
var nullTime internal.NullTime
err := nullTime.Scan(sourceInterface)
if err != nil {
return err
}
if nullTime.Valid {
destination.Set(reflect.ValueOf(nullTime.Time))
}
default:
return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface())
switch destination.Interface().(type) {
case time.Time:
var nullTime internal.NullTime
err := nullTime.Scan(sourceInterface)
if err != nil {
return err
}
if nullTime.Valid {
destination.Set(reflect.ValueOf(nullTime.Time))
}
default:
return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface())
}
}
return nil
}
func tryConvert(source, destination reflect.Value) bool {
destinationType := destination.Type()
if source.Type().ConvertibleTo(destinationType) {
source = source.Convert(destinationType)
return assignIfAssignable(source, destination)
}
return false
}
func setZeroValue(value reflect.Value) {
if !value.IsZero() {
value.Set(reflect.Zero(value.Type()))
}
}
func setReflectValue(source, destination reflect.Value) error {
if source.Kind() == reflect.Ptr {
if source.IsNil() {
// source is nil, destination should be its zero value
setZeroValue(destination)
return nil
}
source = source.Elem()
}
if destination.Kind() == reflect.Ptr {
if destination.IsNil() {
initializeValueIfNilPtr(destination)
}
destination = destination.Elem()
}
return tryAssign(source, destination)
}
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
if len(primaryKeyOverwrites) > 0 {
return utils.StringSliceContains(primaryKeyOverwrites, field.Name)
@ -398,3 +357,11 @@ func cloneBytes(b []byte) []byte {
copy(c, b)
return c
}
func concat(stringList ...string) string {
var b strings.Builder
for _, str := range stringList {
b.WriteString(str)
}
return b.String()
}