[QRM] Prevent recursive scan if destination contains circular dependency.

This commit is contained in:
go-jet 2021-12-07 17:16:10 +01:00
parent 7f54036b1a
commit 02123005c1
4 changed files with 389 additions and 17 deletions

View file

@ -87,7 +87,7 @@ func ScanOneRowToDest(rows *sql.Rows, destPtr interface{}) error {
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem()
_, err = mapRowToSlice(scanContext, "", tempSlicePtrValue, nil)
_, err = mapRowToSlice(scanContext, "", newTypeStack(), tempSlicePtrValue, nil)
if err != nil {
return fmt.Errorf("failed to map a row, %w", err)
@ -141,7 +141,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
scanContext.rowNum++
_, err = mapRowToSlice(scanContext, "", slicePtrValue, nil)
_, err = mapRowToSlice(scanContext, "", newTypeStack(), slicePtrValue, nil)
if err != nil {
return
@ -164,7 +164,12 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
return
}
func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) {
func mapRowToSlice(
scanContext *scanContext,
groupKey string,
typesVisited *typeStack,
slicePtrValue reflect.Value,
field *reflect.StructField) (updated bool, err error) {
sliceElemType := getSliceElemType(slicePtrValue)
@ -184,12 +189,12 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
if ok {
structPtrValue := getSliceElemPtrAt(slicePtrValue, index)
return mapRowToStruct(scanContext, groupKey, structPtrValue, field, true)
return mapRowToStruct(scanContext, groupKey, typesVisited, structPtrValue, field, true)
}
destinationStructPtr := newElemPtrValueForSlice(slicePtrValue)
updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field)
updated, err = mapRowToStruct(scanContext, groupKey, typesVisited, destinationStructPtr, field)
if err != nil {
return
@ -228,10 +233,25 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value
return
}
func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) {
func mapRowToStruct(
scanContext *scanContext,
groupKey string,
typesVisited *typeStack, // to prevent circular dependency scan
structPtrValue reflect.Value,
parentField *reflect.StructField,
onlySlices ...bool, // small optimization, not to assign to already assigned struct fields
) (updated bool, err error) {
mapOnlySlices := len(onlySlices) > 0
structType := structPtrValue.Type().Elem()
if typesVisited.contains(&structType) {
return false, nil
}
typesVisited.push(&structType)
defer typesVisited.pop()
typeInf := scanContext.getTypeInfo(structType, parentField)
structValue := structPtrValue.Elem()
@ -248,7 +268,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
if fieldMap.complexType {
var changed bool
changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field)
changed, err = mapRowToDestinationValue(scanContext, groupKey, typesVisited, fieldValue, &field)
if err != nil {
return
@ -295,7 +315,12 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
return
}
func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest reflect.Value, structField *reflect.StructField) (updated bool, err error) {
func mapRowToDestinationValue(
scanContext *scanContext,
groupKey string,
typesVisited *typeStack,
dest reflect.Value,
structField *reflect.StructField) (updated bool, err error) {
var destPtrValue reflect.Value
@ -309,7 +334,7 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
}
}
updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField)
updated, err = mapRowToDestinationPtr(scanContext, groupKey, typesVisited, destPtrValue, structField)
if err != nil {
return
@ -322,16 +347,21 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
return
}
func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
func mapRowToDestinationPtr(
scanContext *scanContext,
groupKey string,
typesVisited *typeStack,
destPtrValue reflect.Value,
structField *reflect.StructField) (updated bool, err error) {
utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.")
destValueKind := destPtrValue.Elem().Kind()
if destValueKind == reflect.Struct {
return mapRowToStruct(scanContext, groupKey, destPtrValue, structField)
return mapRowToStruct(scanContext, groupKey, typesVisited, destPtrValue, structField)
} else if destValueKind == reflect.Slice {
return mapRowToSlice(scanContext, groupKey, destPtrValue, structField)
return mapRowToSlice(scanContext, groupKey, typesVisited, destPtrValue, structField)
} else {
panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String())
}