From 02123005c1a243bc6c787160030e96114908eec2 Mon Sep 17 00:00:00 2001 From: go-jet Date: Tue, 7 Dec 2021 17:16:10 +0100 Subject: [PATCH] [QRM] Prevent recursive scan if destination contains circular dependency. --- qrm/qrm.go | 54 +++++-- qrm/scan_context.go | 21 ++- qrm/type_stack.go | 40 +++++ tests/postgres/select_test.go | 291 ++++++++++++++++++++++++++++++++++ 4 files changed, 389 insertions(+), 17 deletions(-) create mode 100644 qrm/type_stack.go diff --git a/qrm/qrm.go b/qrm/qrm.go index 4502402..473f8f3 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -87,7 +87,7 @@ func ScanOneRowToDest(rows *sql.Rows, destPtr interface{}) error { tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSliceValue := tempSlicePtrValue.Elem() - _, err = mapRowToSlice(scanContext, "", tempSlicePtrValue, nil) + _, err = mapRowToSlice(scanContext, "", newTypeStack(), tempSlicePtrValue, nil) if err != nil { return fmt.Errorf("failed to map a row, %w", err) @@ -141,7 +141,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, scanContext.rowNum++ - _, err = mapRowToSlice(scanContext, "", slicePtrValue, nil) + _, err = mapRowToSlice(scanContext, "", newTypeStack(), slicePtrValue, nil) if err != nil { return @@ -164,7 +164,12 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, return } -func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) { +func mapRowToSlice( + scanContext *scanContext, + groupKey string, + typesVisited *typeStack, + slicePtrValue reflect.Value, + field *reflect.StructField) (updated bool, err error) { sliceElemType := getSliceElemType(slicePtrValue) @@ -184,12 +189,12 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl if ok { structPtrValue := getSliceElemPtrAt(slicePtrValue, index) - return mapRowToStruct(scanContext, groupKey, structPtrValue, field, true) + return mapRowToStruct(scanContext, groupKey, typesVisited, structPtrValue, field, true) } destinationStructPtr := newElemPtrValueForSlice(slicePtrValue) - updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field) + updated, err = mapRowToStruct(scanContext, groupKey, typesVisited, destinationStructPtr, field) if err != nil { return @@ -228,10 +233,25 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value return } -func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) { +func mapRowToStruct( + scanContext *scanContext, + groupKey string, + typesVisited *typeStack, // to prevent circular dependency scan + structPtrValue reflect.Value, + parentField *reflect.StructField, + onlySlices ...bool, // small optimization, not to assign to already assigned struct fields +) (updated bool, err error) { + mapOnlySlices := len(onlySlices) > 0 structType := structPtrValue.Type().Elem() + if typesVisited.contains(&structType) { + return false, nil + } + + typesVisited.push(&structType) + defer typesVisited.pop() + typeInf := scanContext.getTypeInfo(structType, parentField) structValue := structPtrValue.Elem() @@ -248,7 +268,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re if fieldMap.complexType { var changed bool - changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field) + changed, err = mapRowToDestinationValue(scanContext, groupKey, typesVisited, fieldValue, &field) if err != nil { return @@ -295,7 +315,12 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re return } -func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest reflect.Value, structField *reflect.StructField) (updated bool, err error) { +func mapRowToDestinationValue( + scanContext *scanContext, + groupKey string, + typesVisited *typeStack, + dest reflect.Value, + structField *reflect.StructField) (updated bool, err error) { var destPtrValue reflect.Value @@ -309,7 +334,7 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re } } - updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField) + updated, err = mapRowToDestinationPtr(scanContext, groupKey, typesVisited, destPtrValue, structField) if err != nil { return @@ -322,16 +347,21 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re return } -func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { +func mapRowToDestinationPtr( + scanContext *scanContext, + groupKey string, + typesVisited *typeStack, + destPtrValue reflect.Value, + structField *reflect.StructField) (updated bool, err error) { utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.") destValueKind := destPtrValue.Elem().Kind() if destValueKind == reflect.Struct { - return mapRowToStruct(scanContext, groupKey, destPtrValue, structField) + return mapRowToStruct(scanContext, groupKey, typesVisited, destPtrValue, structField) } else if destValueKind == reflect.Slice { - return mapRowToSlice(scanContext, groupKey, destPtrValue, structField) + return mapRowToSlice(scanContext, groupKey, typesVisited, destPtrValue, structField) } else { panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String()) } diff --git a/qrm/scan_context.go b/qrm/scan_context.go index dbc4b87..61feb75 100644 --- a/qrm/scan_context.go +++ b/qrm/scan_context.go @@ -132,7 +132,7 @@ func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect. return s.constructGroupKey(groupKeyInfo) } - groupKeyInfo := s.getGroupKeyInfo(structType, structField) + groupKeyInfo := s.getGroupKeyInfo(structType, structField, newTypeStack()) s.groupKeyInfoCache[mapKey] = groupKeyInfo @@ -144,7 +144,7 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { return fmt.Sprintf("|ROW:%d|", s.rowNum) } - groupKeys := []string{} + var groupKeys []string for _, index := range groupKeyInfo.indexes { cellValue := s.rowElem(index) @@ -153,7 +153,7 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { groupKeys = append(groupKeys, subKey) } - subTypesGroupKeys := []string{} + var subTypesGroupKeys []string for _, subType := range groupKeyInfo.subTypes { subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType)) } @@ -161,9 +161,20 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")" } -func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo { +func (s *scanContext) getGroupKeyInfo( + structType reflect.Type, + parentField *reflect.StructField, + typeVisited *typeStack) groupKeyInfo { + ret := groupKeyInfo{typeName: structType.Name()} + if typeVisited.contains(&structType) { + return ret + } + + typeVisited.push(&structType) + defer typeVisited.pop() + typeName := getTypeName(structType, parentField) primaryKeyOverwrites := parentFieldPrimaryKeyOverwrite(parentField) @@ -176,7 +187,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *refl continue } - subType := s.getGroupKeyInfo(fieldType, &field) + subType := s.getGroupKeyInfo(fieldType, &field, typeVisited) if len(subType.indexes) != 0 || len(subType.subTypes) != 0 { ret.subTypes = append(ret.subTypes, subType) diff --git a/qrm/type_stack.go b/qrm/type_stack.go new file mode 100644 index 0000000..235c06e --- /dev/null +++ b/qrm/type_stack.go @@ -0,0 +1,40 @@ +package qrm + +import "reflect" + +type typeStack []*reflect.Type + +func newTypeStack() *typeStack { + stack := make(typeStack, 0, 20) + return &stack +} + +func (s *typeStack) isEmpty() bool { + return len(*s) == 0 +} + +func (s *typeStack) push(t *reflect.Type) { + *s = append(*s, t) +} + +func (s *typeStack) pop() bool { + if s.isEmpty() { + return false + } + *s = (*s)[:len(*s)-1] + return true +} + +func (s *typeStack) contains(t *reflect.Type) bool { + if s.isEmpty() { + return false + } + + for _, typ := range *s { + if *typ == *t { + return true + } + } + + return false +} diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index b75db2c..988499a 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -2112,3 +2112,294 @@ FROM dvds.address; require.Len(t, dest, 603) }) } + +type FilmWrap struct { + model.Film + + Actors []ActorWrap +} + +type ActorWrap struct { + model.Actor + + Films []FilmWrap +} + +func TestRecursionScanNxM(t *testing.T) { + + stmt := SELECT( + Actor.AllColumns, + Film.AllColumns, + ).FROM( + Actor. + INNER_JOIN(FilmActor, Actor.ActorID.EQ(FilmActor.ActorID)). + INNER_JOIN(Film, Film.FilmID.EQ(FilmActor.FilmID)), + ).ORDER_BY( + Actor.ActorID, + Film.FilmID, + ).LIMIT(100) + + t.Run("film->actors", func(t *testing.T) { + var films []FilmWrap + err := stmt.Query(db, &films) + + require.NoError(t, err) + require.Len(t, films, 95) + testutils.AssertJSON(t, films[:2], ` +[ + { + "FilmID": 1, + "Title": "Academy Dinosaur", + "Description": "A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies", + "ReleaseYear": 2006, + "LanguageID": 1, + "RentalDuration": 6, + "RentalRate": 0.99, + "Length": 86, + "ReplacementCost": 20.99, + "Rating": "PG", + "LastUpdate": "2013-05-26T14:50:58.951Z", + "SpecialFeatures": "{\"Deleted Scenes\",\"Behind the Scenes\"}", + "Fulltext": "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", + "Actors": [ + { + "ActorID": 1, + "FirstName": "Penelope", + "LastName": "Guiness", + "LastUpdate": "2013-05-26T14:47:57.62Z", + "Films": null + } + ] + }, + { + "FilmID": 23, + "Title": "Anaconda Confessions", + "Description": "A Lacklusture Display of a Dentist And a Dentist who must Fight a Girl in Australia", + "ReleaseYear": 2006, + "LanguageID": 1, + "RentalDuration": 3, + "RentalRate": 0.99, + "Length": 92, + "ReplacementCost": 9.99, + "Rating": "R", + "LastUpdate": "2013-05-26T14:50:58.951Z", + "SpecialFeatures": "{Trailers,\"Deleted Scenes\"}", + "Fulltext": "'anaconda':1 'australia':18 'confess':2 'dentist':8,11 'display':5 'fight':14 'girl':16 'lacklustur':4 'must':13", + "Actors": [ + { + "ActorID": 1, + "FirstName": "Penelope", + "LastName": "Guiness", + "LastUpdate": "2013-05-26T14:47:57.62Z", + "Films": null + }, + { + "ActorID": 4, + "FirstName": "Jennifer", + "LastName": "Davis", + "LastUpdate": "2013-05-26T14:47:57.62Z", + "Films": null + } + ] + } +] +`) + + }) + + t.Run("actors->films", func(t *testing.T) { + var actors []ActorWrap + + err := stmt.Query(db, &actors) + + require.NoError(t, err) + require.Equal(t, len(actors), 5) + require.Equal(t, actors[0].ActorID, int32(1)) + require.Equal(t, actors[0].FirstName, "Penelope") + require.Len(t, actors[0].Films, 19) + testutils.AssertJSON(t, actors[0].Films[:2], ` +[ + { + "FilmID": 1, + "Title": "Academy Dinosaur", + "Description": "A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies", + "ReleaseYear": 2006, + "LanguageID": 1, + "RentalDuration": 6, + "RentalRate": 0.99, + "Length": 86, + "ReplacementCost": 20.99, + "Rating": "PG", + "LastUpdate": "2013-05-26T14:50:58.951Z", + "SpecialFeatures": "{\"Deleted Scenes\",\"Behind the Scenes\"}", + "Fulltext": "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", + "Actors": null + }, + { + "FilmID": 23, + "Title": "Anaconda Confessions", + "Description": "A Lacklusture Display of a Dentist And a Dentist who must Fight a Girl in Australia", + "ReleaseYear": 2006, + "LanguageID": 1, + "RentalDuration": 3, + "RentalRate": 0.99, + "Length": 92, + "ReplacementCost": 9.99, + "Rating": "R", + "LastUpdate": "2013-05-26T14:50:58.951Z", + "SpecialFeatures": "{Trailers,\"Deleted Scenes\"}", + "Fulltext": "'anaconda':1 'australia':18 'confess':2 'dentist':8,11 'display':5 'fight':14 'girl':16 'lacklustur':4 'must':13", + "Actors": null + } +] +`) + }) +} + +type StoreWrap struct { + model.Store + + Staffs []StaffWrap +} + +type StaffWrap struct { + model.Staff + + Store StoreWrap +} + +func TestRecursionScanNx1(t *testing.T) { + stmt := SELECT( + Store.AllColumns, + Staff.AllColumns, + ).FROM( + Store. + INNER_JOIN(Staff, Staff.StoreID.EQ(Store.StoreID)), + ).ORDER_BY( + Store.StoreID, + Staff.StaffID, + ) + + t.Run("store->staff", func(t *testing.T) { + var stores []StoreWrap + + err := stmt.Query(db, &stores) + + require.NoError(t, err) + require.Len(t, stores, 2) + + testutils.AssertJSON(t, stores, ` +[ + { + "StoreID": 1, + "ManagerStaffID": 1, + "AddressID": 1, + "LastUpdate": "2006-02-15T09:57:12Z", + "Staffs": [ + { + "StaffID": 1, + "FirstName": "Mike", + "LastName": "Hillyer", + "AddressID": 3, + "Email": "Mike.Hillyer@sakilastaff.com", + "StoreID": 1, + "Active": true, + "Username": "Mike", + "Password": "8cb2237d0679ca88db6464eac60da96345513964", + "LastUpdate": "2006-05-16T16:13:11.79328Z", + "Picture": "iVBORw0KWgo=", + "Store": { + "StoreID": 0, + "ManagerStaffID": 0, + "AddressID": 0, + "LastUpdate": "0001-01-01T00:00:00Z", + "Staffs": null + } + } + ] + }, + { + "StoreID": 2, + "ManagerStaffID": 2, + "AddressID": 2, + "LastUpdate": "2006-02-15T09:57:12Z", + "Staffs": [ + { + "StaffID": 2, + "FirstName": "Jon", + "LastName": "Stephens", + "AddressID": 4, + "Email": "Jon.Stephens@sakilastaff.com", + "StoreID": 2, + "Active": true, + "Username": "Jon", + "Password": "8cb2237d0679ca88db6464eac60da96345513964", + "LastUpdate": "2006-05-16T16:13:11.79328Z", + "Picture": null, + "Store": { + "StoreID": 0, + "ManagerStaffID": 0, + "AddressID": 0, + "LastUpdate": "0001-01-01T00:00:00Z", + "Staffs": null + } + } + ] + } +] +`) + }) + + t.Run("staff->store", func(t *testing.T) { + + var staffs []StaffWrap + + err := stmt.Query(db, &staffs) + require.NoError(t, err) + + testutils.AssertJSON(t, staffs, ` +[ + { + "StaffID": 1, + "FirstName": "Mike", + "LastName": "Hillyer", + "AddressID": 3, + "Email": "Mike.Hillyer@sakilastaff.com", + "StoreID": 1, + "Active": true, + "Username": "Mike", + "Password": "8cb2237d0679ca88db6464eac60da96345513964", + "LastUpdate": "2006-05-16T16:13:11.79328Z", + "Picture": "iVBORw0KWgo=", + "Store": { + "StoreID": 1, + "ManagerStaffID": 1, + "AddressID": 1, + "LastUpdate": "2006-02-15T09:57:12Z", + "Staffs": null + } + }, + { + "StaffID": 2, + "FirstName": "Jon", + "LastName": "Stephens", + "AddressID": 4, + "Email": "Jon.Stephens@sakilastaff.com", + "StoreID": 2, + "Active": true, + "Username": "Jon", + "Password": "8cb2237d0679ca88db6464eac60da96345513964", + "LastUpdate": "2006-05-16T16:13:11.79328Z", + "Picture": null, + "Store": { + "StoreID": 2, + "ManagerStaffID": 2, + "AddressID": 2, + "LastUpdate": "2006-02-15T09:57:12Z", + "Staffs": null + } + } +] +`) + }) +}