diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index bebb9a7..6609902 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -157,20 +157,16 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl return false, errors.New("Unsupported dest type: " + structField.Name + " " + structField.Type.String()) } - structGroupKey := getGroupKey(scanContext, sliceElemType, structField) - - if structGroupKey == "" { - structGroupKey = "|ROW: " + strconv.Itoa(scanContext.rowNum) + "|" - } + structGroupKey := scanContext.getGroupKey(sliceElemType, structField) groupKey = groupKey + ":" + structGroupKey - index, ok := scanContext.uniqueObjectsMap[groupKey] + index, ok := scanContext.uniqueDestObjectsMap[groupKey] if ok { structPtrValue := getSliceElemPtrAt(slicePtrValue, index) - return mapRowToStruct(scanContext, groupKey, structPtrValue, structField) + return mapRowToStruct(scanContext, groupKey, structPtrValue, structField, true) } else { destinationStructPtr := newElemPtrValueForSlice(slicePtrValue) @@ -181,7 +177,7 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl } if updated { - scanContext.uniqueObjectsMap[groupKey] = slicePtrValue.Elem().Len() + scanContext.uniqueDestObjectsMap[groupKey] = slicePtrValue.Elem().Len() err = appendElemToSlice(slicePtrValue, destinationStructPtr) if err != nil { @@ -193,54 +189,6 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl return } -func getGroupKey(scanContext *scanContext, structType reflect.Type, structField *reflect.StructField) string { - tableName, _ := getRefTableNameFrom(structField) - - if tableName == "" { - tableName = structType.Name() - } - - groupKeys := []string{} - - for i := 0; i < structType.NumField(); i++ { - field := structType.Field(i) - - if !isGoBaseType(field.Type) { - var structType reflect.Type - if field.Type.Kind() == reflect.Struct { - structType = field.Type - } else if field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct { - structType = field.Type.Elem() - } else { - continue - } - - structGroupKey := getGroupKey(scanContext, structType, &field) - - if structGroupKey != "" { - groupKeys = append(groupKeys, structGroupKey) - } - } else if tagInfo(field.Tag.Get("sql")).isPrimaryKey { - fieldName := field.Name - - cellValue := scanContext.getCellValue(tableName, fieldName) - subKey := valueToString(cellValue) - - if subKey != "" { - groupKeys = append(groupKeys, subKey) - } - } - } - - if len(groupKeys) == 0 { - return "" - } - - groupKey := "{" + structType.Name() + "(" + strings.Join(groupKeys, ",") + ")}" - - return groupKey -} - func getSliceElemType(slicePtrValue reflect.Value) reflect.Type { sliceTypePtr := slicePtrValue.Type() @@ -339,6 +287,78 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re return } +func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, structField *reflect.StructField, onlySlices ...bool) (updated bool, err error) { + structType := structPtrValue.Type().Elem() + structValue := structPtrValue.Elem() + + tableName, _ := getRefTableNameFrom(structField) + + if tableName == "" { + tableName = structType.Name() + } + + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + + fieldValue := structValue.Field(i) + columnName := field.Name + + if scannerValue, ok := implementsScanner(fieldValue); ok { + if len(onlySlices) > 0 { + continue + } + + cellValue := scanContext.getCellValue(tableName, columnName) + + if cellValue == nil { + continue + } + + initializeValueIfNilPtr(fieldValue) + + scanner := scannerValue.Interface().(sql.Scanner) + + err = scanner.Scan(cellValue) + + if err != nil { + err = fmt.Errorf("%s, at struct field: %s %s of type %s. ", err.Error(), field.Name, field.Type.String(), structType.String()) + return + } + updated = true + } else if isGoBaseType(field.Type) { + if len(onlySlices) > 0 { + continue + } + + cellValue := scanContext.getCellValue(tableName, columnName) + + if cellValue != nil { + updated = true + initializeValueIfNilPtr(fieldValue) + err = setReflectValue(reflect.ValueOf(cellValue), fieldValue) + + if err != nil { + err = fmt.Errorf("Scan: %s, at struct field: %s %s of type %s. ", err.Error(), field.Name, field.Type.String(), structType.String()) + return + } + } + } else { + var changed bool + changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field) + + if err != nil { + return + } + + if changed { + updated = true + } + } + } + + return +} + func getRefTableNameFrom(structField *reflect.StructField) (table, column string) { if structField == nil { return @@ -365,71 +385,7 @@ func getRefTableNameFrom(structField *reflect.StructField) (table, column string return fieldType.Name(), "" } -func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { - structType := structPtrValue.Type().Elem() - structValue := structPtrValue.Elem() - - tableName, _ := getRefTableNameFrom(structField) - - if tableName == "" { - tableName = structType.Name() - } - - for i := 0; i < structType.NumField(); i++ { - field := structType.Field(i) - - fieldValue := structValue.Field(i) - fieldName := field.Name - - if scannerValue, ok := implementsScanner(fieldValue); ok { - cellValue := scanContext.getCellValue(tableName, fieldName) - - if cellValue == nil { - continue - } - - initializeValueIfNil(fieldValue) - - scanner := scannerValue.Interface().(sql.Scanner) - - err = scanner.Scan(cellValue) - - if err != nil { - err = fmt.Errorf("%s, at struct field: %s %s of type %s. ", err.Error(), field.Name, field.Type.String(), structType.String()) - return - } - updated = true - } else if isGoBaseType(field.Type) { - cellValue := scanContext.getCellValue(tableName, fieldName) - - if cellValue != nil { - updated = true - initializeValueIfNil(fieldValue) - err = setReflectValue(reflect.ValueOf(cellValue), fieldValue) - - if err != nil { - err = fmt.Errorf("Scan: %s, at struct field: %s %s of type %s. ", err.Error(), field.Name, field.Type.String(), structType.String()) - return - } - } - } else { - var changed bool - changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field) - - if err != nil { - return - } - - if changed { - updated = true - } - } - } - - return -} - -func initializeValueIfNil(value reflect.Value) { +func initializeValueIfNilPtr(value reflect.Value) { if !value.IsValid() || !value.CanSet() { return } @@ -451,16 +407,19 @@ func implementsScanner(value reflect.Value) (reflect.Value, bool) { return value, false } -func valueToString(val interface{}) string { - if val == nil { - return "" - } +func valueToString(value reflect.Value) string { - value := reflect.ValueOf(val) + if !value.IsValid() { + return "nil" + } var valueInterface interface{} if value.Kind() == reflect.Ptr { - valueInterface = value.Elem().Interface() + if value.IsNil() { + return "nil" + } else { + valueInterface = value.Elem().Interface() + } } else { valueInterface = value.Interface() } @@ -572,13 +531,13 @@ func newScanType(columnType *sql.ColumnType) reflect.Type { } type scanContext struct { - rowNum int - columnNames []string + rowNum int + + row []interface{} + uniqueDestObjectsMap map[string]int - row []interface{} - uniqueObjectsMap map[string]int - groupKeyMap map[string]string columnNameIndexMap map[string]int + groupKeyInfoCache map[string]groupKeyInfo } func newScanContext(rows *sql.Rows) (*scanContext, error) { @@ -601,42 +560,131 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) { } return &scanContext{ - row: createScanValue(columnTypes), - columnNames: columnNames, - uniqueObjectsMap: make(map[string]int), - groupKeyMap: make(map[string]string), + row: createScanValue(columnTypes), + uniqueDestObjectsMap: make(map[string]int), + + groupKeyInfoCache: make(map[string]groupKeyInfo), + columnNameIndexMap: columnNameIndexMap, }, nil } -func (s *scanContext) columnIndex(structName, fieldName string) int { - if structName == "" { - name := strings.ToLower(fieldName) +func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect.StructField) string { + + mapKey := structType.Name() + + if structField != nil { + mapKey += structField.Type.String() + } + + if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok { + return s.constructGroupKey(groupKeyInfo) + } else { + groupKeyInfo := s.getGroupKeyInfo(structType, structField) + + s.groupKeyInfoCache[mapKey] = groupKeyInfo + + return s.constructGroupKey(groupKeyInfo) + } +} + +func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { + if len(groupKeyInfo.indexes) == 0 && len(groupKeyInfo.subTypes) == 0 { + return "|ROW: " + strconv.Itoa(s.rowNum) + "|" + } + + groupKeys := []string{} + + for _, index := range groupKeyInfo.indexes { + cellValue := s.rowElem(index) + subKey := valueToString(reflect.ValueOf(cellValue)) + + groupKeys = append(groupKeys, subKey) + } + + subTypesGroupKeys := []string{} + for _, subType := range groupKeyInfo.subTypes { + subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType)) + } + + return "{" + groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")}" +} + +func (s *scanContext) getGroupKeyInfo(structType reflect.Type, structField *reflect.StructField) groupKeyInfo { + tableName, _ := getRefTableNameFrom(structField) + + if tableName == "" { + tableName = structType.Name() + } + + ret := groupKeyInfo{typeName: structType.Name()} + + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + + if !isGoBaseType(field.Type) { + var structType reflect.Type + if field.Type.Kind() == reflect.Struct { + structType = field.Type + } else if field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct { + structType = field.Type.Elem() + } else { + continue + } + + subType := s.getGroupKeyInfo(structType, &field) + + if len(subType.indexes) != 0 || len(subType.subTypes) != 0 { + ret.subTypes = append(ret.subTypes, subType) + } + } else if tagInfo(field.Tag.Get("sql")).isPrimaryKey { + index := s.columnIndex(tableName, field.Name) + + if index < 0 { + continue + } + + ret.indexes = append(ret.indexes, index) + } + } + + return ret +} + +type groupKeyInfo struct { + typeName string + indexes []int + subTypes []groupKeyInfo +} + +func (s *scanContext) columnIndex(tableName, columnName string) int { + if tableName == "" { + name := strings.ToLower(columnName) if i, ok := s.columnNameIndexMap[name]; ok { return i } - name = strings.ToLower(snaker.CamelToSnake(fieldName)) + name = strings.ToLower(snaker.CamelToSnake(columnName)) if i, ok := s.columnNameIndexMap[name]; ok { return i } } else { - name := strings.ToLower(structName + "." + fieldName) + name := strings.ToLower(tableName + "." + columnName) if i, ok := s.columnNameIndexMap[name]; ok { return i } - name = strings.ToLower(structName + "." + snaker.CamelToSnake(fieldName)) + name = strings.ToLower(snaker.CamelToSnake(tableName) + "." + snaker.CamelToSnake(columnName)) if i, ok := s.columnNameIndexMap[name]; ok { return i } - name = strings.ToLower(snaker.CamelToSnake(structName) + "." + fieldName) + name = strings.ToLower(tableName + "." + snaker.CamelToSnake(columnName)) if i, ok := s.columnNameIndexMap[name]; ok { return i } - name = strings.ToLower(snaker.CamelToSnake(structName) + "." + snaker.CamelToSnake(fieldName)) + name = strings.ToLower(snaker.CamelToSnake(tableName) + "." + columnName) if i, ok := s.columnNameIndexMap[name]; ok { return i } diff --git a/tests/scan_test.go b/tests/scan_test.go index a57018f..59d7a28 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -2,6 +2,7 @@ package tests import ( "fmt" + "github.com/davecgh/go-spew/spew" . "github.com/go-jet/jet/sqlbuilder" "github.com/go-jet/jet/tests/.test_files/dvd_rental/dvds/model" . "github.com/go-jet/jet/tests/.test_files/dvd_rental/dvds/table" @@ -450,6 +451,10 @@ func TestScanToSlice(t *testing.T) { err := query.Query(db, &dest) + fmt.Println(query.DebugSql()) + + spew.Dump(dest) + assert.NilError(t, err) assert.DeepEqual(t, dest.Film, film1) assert.DeepEqual(t, dest.IDs, []int32{1, 2, 3, 4, 5, 6, 7, 8}) diff --git a/tests/select_test.go b/tests/select_test.go index b0aa3b8..3b529c9 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -204,9 +204,9 @@ FROM dvds.film_actor INNER JOIN dvds.inventory ON (inventory.film_id = film.film_id) INNER JOIN dvds.rental ON (rental.inventory_id = inventory.inventory_id) ORDER BY film.film_id ASC -LIMIT 500; +LIMIT 1000; ` - for i := 0; i < 1; i++ { + for i := 0; i < 2; i++ { query := FilmActor. INNER_JOIN(Actor, FilmActor.ActorID.EQ(Actor.ActorID)). INNER_JOIN(Film, FilmActor.FilmID.EQ(Film.FilmID)). @@ -223,9 +223,9 @@ LIMIT 500; ). //WHERE(FilmActor.ActorID.GtEqL(1).AND(FilmActor.ActorID.LtEqL(2))). ORDER_BY(Film.FilmID.ASC()). - LIMIT(500) + LIMIT(1000) - assertStatementSql(t, query, expectedSql, int64(500)) + assertStatementSql(t, query, expectedSql, int64(1000)) var languageActorFilm []struct { model.Language @@ -248,7 +248,7 @@ LIMIT 500; assert.NilError(t, err) assert.Equal(t, len(languageActorFilm), 1) - assert.Equal(t, len(languageActorFilm[0].Films), 6) + assert.Equal(t, len(languageActorFilm[0].Films), 10) assert.Equal(t, len(languageActorFilm[0].Films[0].Actors), 10) }