[QRM] Prevent recursive scan if destination contains circular dependency.
This commit is contained in:
parent
7f54036b1a
commit
02123005c1
4 changed files with 389 additions and 17 deletions
54
qrm/qrm.go
54
qrm/qrm.go
|
|
@ -87,7 +87,7 @@ func ScanOneRowToDest(rows *sql.Rows, destPtr interface{}) error {
|
|||
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
|
||||
tempSliceValue := tempSlicePtrValue.Elem()
|
||||
|
||||
_, err = mapRowToSlice(scanContext, "", tempSlicePtrValue, nil)
|
||||
_, err = mapRowToSlice(scanContext, "", newTypeStack(), tempSlicePtrValue, nil)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to map a row, %w", err)
|
||||
|
|
@ -141,7 +141,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
|
|||
|
||||
scanContext.rowNum++
|
||||
|
||||
_, err = mapRowToSlice(scanContext, "", slicePtrValue, nil)
|
||||
_, err = mapRowToSlice(scanContext, "", newTypeStack(), slicePtrValue, nil)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -164,7 +164,12 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
|
|||
return
|
||||
}
|
||||
|
||||
func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) {
|
||||
func mapRowToSlice(
|
||||
scanContext *scanContext,
|
||||
groupKey string,
|
||||
typesVisited *typeStack,
|
||||
slicePtrValue reflect.Value,
|
||||
field *reflect.StructField) (updated bool, err error) {
|
||||
|
||||
sliceElemType := getSliceElemType(slicePtrValue)
|
||||
|
||||
|
|
@ -184,12 +189,12 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
|
|||
if ok {
|
||||
structPtrValue := getSliceElemPtrAt(slicePtrValue, index)
|
||||
|
||||
return mapRowToStruct(scanContext, groupKey, structPtrValue, field, true)
|
||||
return mapRowToStruct(scanContext, groupKey, typesVisited, structPtrValue, field, true)
|
||||
}
|
||||
|
||||
destinationStructPtr := newElemPtrValueForSlice(slicePtrValue)
|
||||
|
||||
updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field)
|
||||
updated, err = mapRowToStruct(scanContext, groupKey, typesVisited, destinationStructPtr, field)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -228,10 +233,25 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value
|
|||
return
|
||||
}
|
||||
|
||||
func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) {
|
||||
func mapRowToStruct(
|
||||
scanContext *scanContext,
|
||||
groupKey string,
|
||||
typesVisited *typeStack, // to prevent circular dependency scan
|
||||
structPtrValue reflect.Value,
|
||||
parentField *reflect.StructField,
|
||||
onlySlices ...bool, // small optimization, not to assign to already assigned struct fields
|
||||
) (updated bool, err error) {
|
||||
|
||||
mapOnlySlices := len(onlySlices) > 0
|
||||
structType := structPtrValue.Type().Elem()
|
||||
|
||||
if typesVisited.contains(&structType) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
typesVisited.push(&structType)
|
||||
defer typesVisited.pop()
|
||||
|
||||
typeInf := scanContext.getTypeInfo(structType, parentField)
|
||||
|
||||
structValue := structPtrValue.Elem()
|
||||
|
|
@ -248,7 +268,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
|
|||
|
||||
if fieldMap.complexType {
|
||||
var changed bool
|
||||
changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field)
|
||||
changed, err = mapRowToDestinationValue(scanContext, groupKey, typesVisited, fieldValue, &field)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -295,7 +315,12 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
|
|||
return
|
||||
}
|
||||
|
||||
func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest reflect.Value, structField *reflect.StructField) (updated bool, err error) {
|
||||
func mapRowToDestinationValue(
|
||||
scanContext *scanContext,
|
||||
groupKey string,
|
||||
typesVisited *typeStack,
|
||||
dest reflect.Value,
|
||||
structField *reflect.StructField) (updated bool, err error) {
|
||||
|
||||
var destPtrValue reflect.Value
|
||||
|
||||
|
|
@ -309,7 +334,7 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
|
|||
}
|
||||
}
|
||||
|
||||
updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField)
|
||||
updated, err = mapRowToDestinationPtr(scanContext, groupKey, typesVisited, destPtrValue, structField)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -322,16 +347,21 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
|
|||
return
|
||||
}
|
||||
|
||||
func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
|
||||
func mapRowToDestinationPtr(
|
||||
scanContext *scanContext,
|
||||
groupKey string,
|
||||
typesVisited *typeStack,
|
||||
destPtrValue reflect.Value,
|
||||
structField *reflect.StructField) (updated bool, err error) {
|
||||
|
||||
utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.")
|
||||
|
||||
destValueKind := destPtrValue.Elem().Kind()
|
||||
|
||||
if destValueKind == reflect.Struct {
|
||||
return mapRowToStruct(scanContext, groupKey, destPtrValue, structField)
|
||||
return mapRowToStruct(scanContext, groupKey, typesVisited, destPtrValue, structField)
|
||||
} else if destValueKind == reflect.Slice {
|
||||
return mapRowToSlice(scanContext, groupKey, destPtrValue, structField)
|
||||
return mapRowToSlice(scanContext, groupKey, typesVisited, destPtrValue, structField)
|
||||
} else {
|
||||
panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect.
|
|||
return s.constructGroupKey(groupKeyInfo)
|
||||
}
|
||||
|
||||
groupKeyInfo := s.getGroupKeyInfo(structType, structField)
|
||||
groupKeyInfo := s.getGroupKeyInfo(structType, structField, newTypeStack())
|
||||
|
||||
s.groupKeyInfoCache[mapKey] = groupKeyInfo
|
||||
|
||||
|
|
@ -144,7 +144,7 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
|
|||
return fmt.Sprintf("|ROW:%d|", s.rowNum)
|
||||
}
|
||||
|
||||
groupKeys := []string{}
|
||||
var groupKeys []string
|
||||
|
||||
for _, index := range groupKeyInfo.indexes {
|
||||
cellValue := s.rowElem(index)
|
||||
|
|
@ -153,7 +153,7 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
|
|||
groupKeys = append(groupKeys, subKey)
|
||||
}
|
||||
|
||||
subTypesGroupKeys := []string{}
|
||||
var subTypesGroupKeys []string
|
||||
for _, subType := range groupKeyInfo.subTypes {
|
||||
subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType))
|
||||
}
|
||||
|
|
@ -161,9 +161,20 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
|
|||
return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")"
|
||||
}
|
||||
|
||||
func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo {
|
||||
func (s *scanContext) getGroupKeyInfo(
|
||||
structType reflect.Type,
|
||||
parentField *reflect.StructField,
|
||||
typeVisited *typeStack) groupKeyInfo {
|
||||
|
||||
ret := groupKeyInfo{typeName: structType.Name()}
|
||||
|
||||
if typeVisited.contains(&structType) {
|
||||
return ret
|
||||
}
|
||||
|
||||
typeVisited.push(&structType)
|
||||
defer typeVisited.pop()
|
||||
|
||||
typeName := getTypeName(structType, parentField)
|
||||
primaryKeyOverwrites := parentFieldPrimaryKeyOverwrite(parentField)
|
||||
|
||||
|
|
@ -176,7 +187,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *refl
|
|||
continue
|
||||
}
|
||||
|
||||
subType := s.getGroupKeyInfo(fieldType, &field)
|
||||
subType := s.getGroupKeyInfo(fieldType, &field, typeVisited)
|
||||
|
||||
if len(subType.indexes) != 0 || len(subType.subTypes) != 0 {
|
||||
ret.subTypes = append(ret.subTypes, subType)
|
||||
|
|
|
|||
40
qrm/type_stack.go
Normal file
40
qrm/type_stack.go
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
package qrm
|
||||
|
||||
import "reflect"
|
||||
|
||||
type typeStack []*reflect.Type
|
||||
|
||||
func newTypeStack() *typeStack {
|
||||
stack := make(typeStack, 0, 20)
|
||||
return &stack
|
||||
}
|
||||
|
||||
func (s *typeStack) isEmpty() bool {
|
||||
return len(*s) == 0
|
||||
}
|
||||
|
||||
func (s *typeStack) push(t *reflect.Type) {
|
||||
*s = append(*s, t)
|
||||
}
|
||||
|
||||
func (s *typeStack) pop() bool {
|
||||
if s.isEmpty() {
|
||||
return false
|
||||
}
|
||||
*s = (*s)[:len(*s)-1]
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *typeStack) contains(t *reflect.Type) bool {
|
||||
if s.isEmpty() {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, typ := range *s {
|
||||
if *typ == *t {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue