diff --git a/internal/jet/statement.go b/internal/jet/statement.go index b205801..183aaae 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -33,11 +33,13 @@ type Statement interface { // Rows wraps sql.Rows type to add query result mapping for Scan method type Rows struct { *sql.Rows + + scanContext *qrm.ScanContext } // Scan will map the Row values into struct destination func (r *Rows) Scan(destination interface{}) error { - return qrm.ScanOneRowToDest(r.Rows, destination) + return qrm.ScanOneRowToDest(r.scanContext, r.Rows, destination) } // SerializerStatement interface @@ -161,7 +163,16 @@ func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB) return nil, err } - return &Rows{rows}, nil + scanContext, err := qrm.NewScanContext(rows) + + if err != nil { + return nil, err + } + + return &Rows{ + Rows: rows, + scanContext: scanContext, + }, nil } func duration(f func()) time.Duration { diff --git a/qrm/qrm.go b/qrm/qrm.go index 3731c68..1ffd36f 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -63,48 +63,28 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr } // ScanOneRowToDest will scan one row into struct destination -func ScanOneRowToDest(rows *sql.Rows, destPtr interface{}) error { +func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interface{}) error { utils.MustBeInitializedPtr(destPtr, "jet: destination is nil") utils.MustBe(destPtr, reflect.Ptr, "jet: destination has to be a pointer to slice or pointer to struct") - scanContext, err := newScanContext(rows) - - if err != nil { - return fmt.Errorf("failed to create scan context, %w", err) - } - if len(scanContext.row) == 0 { return errors.New("empty row slice") } - err = rows.Scan(scanContext.row...) + err := rows.Scan(scanContext.row...) if err != nil { return fmt.Errorf("rows scan error, %w", err) } - destinationPtrType := reflect.TypeOf(destPtr) - tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) - tempSliceValue := tempSlicePtrValue.Elem() + destValue := reflect.ValueOf(destPtr) - _, err = mapRowToSlice(scanContext, "", newTypeStack(), tempSlicePtrValue, nil) + _, err = mapRowToStruct(scanContext, "", newTypeStack(), destValue, nil) if err != nil { return fmt.Errorf("failed to map a row, %w", err) } - // edge case when row result set contains only NULLs. - if tempSliceValue.Len() == 0 { - return nil - } - - destValue := reflect.ValueOf(destPtr).Elem() - firstTempSliceValue := tempSliceValue.Index(0).Elem() - - if destValue.Type().AssignableTo(firstTempSliceValue.Type()) { - destValue.Set(tempSliceValue.Index(0).Elem()) - } - return nil } @@ -120,7 +100,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, } defer rows.Close() - scanContext, err := newScanContext(rows) + scanContext, err := NewScanContext(rows) if err != nil { return @@ -157,7 +137,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, } func mapRowToSlice( - scanContext *scanContext, + scanContext *ScanContext, groupKey string, typesVisited *typeStack, slicePtrValue reflect.Value, @@ -204,7 +184,7 @@ func mapRowToSlice( return } -func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) { +func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) { index := 0 if field != nil { typeName, columnName := getTypeAndFieldName("", *field) @@ -226,7 +206,7 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value } func mapRowToStruct( - scanContext *scanContext, + scanContext *ScanContext, groupKey string, typesVisited *typeStack, // to prevent circular dependency scan structPtrValue reflect.Value, @@ -308,7 +288,7 @@ func mapRowToStruct( } func mapRowToDestinationValue( - scanContext *scanContext, + scanContext *ScanContext, groupKey string, typesVisited *typeStack, dest reflect.Value, @@ -340,7 +320,7 @@ func mapRowToDestinationValue( } func mapRowToDestinationPtr( - scanContext *scanContext, + scanContext *ScanContext, groupKey string, typesVisited *typeStack, destPtrValue reflect.Value, diff --git a/qrm/scan_context.go b/qrm/scan_context.go index 61feb75..01cfe53 100644 --- a/qrm/scan_context.go +++ b/qrm/scan_context.go @@ -7,7 +7,9 @@ import ( "strings" ) -type scanContext struct { +// ScanContext contains information about current row processed, mapping from the row to the +// destination types and type grouping information. +type ScanContext struct { rowNum int64 row []interface{} uniqueDestObjectsMap map[string]int @@ -16,7 +18,8 @@ type scanContext struct { typeInfoMap map[string]typeInfo } -func newScanContext(rows *sql.Rows) (*scanContext, error) { +// NewScanContext creates new ScanContext from rows +func NewScanContext(rows *sql.Rows) (*ScanContext, error) { aliases, err := rows.Columns() if err != nil { @@ -42,7 +45,7 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) { commonIdentToColumnIndex[commonIdentifier] = i } - return &scanContext{ + return &ScanContext{ row: createScanSlice(len(columnTypes)), uniqueDestObjectsMap: make(map[string]int), @@ -74,7 +77,7 @@ type fieldMapping struct { implementsScanner bool } -func (s *scanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo { +func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo { typeMapKey := structType.String() @@ -120,7 +123,7 @@ type groupKeyInfo struct { subTypes []groupKeyInfo } -func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect.StructField) string { +func (s *ScanContext) getGroupKey(structType reflect.Type, structField *reflect.StructField) string { mapKey := structType.Name() @@ -139,7 +142,7 @@ func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect. return s.constructGroupKey(groupKeyInfo) } -func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { +func (s *ScanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { if len(groupKeyInfo.indexes) == 0 && len(groupKeyInfo.subTypes) == 0 { return fmt.Sprintf("|ROW:%d|", s.rowNum) } @@ -161,7 +164,7 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")" } -func (s *scanContext) getGroupKeyInfo( +func (s *ScanContext) getGroupKeyInfo( structType reflect.Type, parentField *reflect.StructField, typeVisited *typeStack) groupKeyInfo { @@ -210,7 +213,7 @@ func (s *scanContext) getGroupKeyInfo( return ret } -func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int { +func (s *ScanContext) typeToColumnIndex(typeName, fieldName string) int { var key string if typeName != "" { @@ -228,7 +231,7 @@ func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int { return index } -func (s *scanContext) rowElem(index int) interface{} { +func (s *ScanContext) rowElem(index int) interface{} { cellValue := reflect.ValueOf(s.row[index]) if cellValue.IsValid() && !cellValue.IsNil() { @@ -238,7 +241,7 @@ func (s *scanContext) rowElem(index int) interface{} { return nil } -func (s *scanContext) rowElemValuePtr(index int) reflect.Value { +func (s *ScanContext) rowElemValuePtr(index int) reflect.Value { rowElem := s.rowElem(index) rowElemValue := reflect.ValueOf(rowElem) diff --git a/qrm/utill.go b/qrm/utill.go index 6926c42..ca0db61 100644 --- a/qrm/utill.go +++ b/qrm/utill.go @@ -201,23 +201,38 @@ func isFloatType(value reflect.Type) bool { return false } -func tryAssign(source, destination reflect.Value) error { - - if source.Type() != destination.Type() && - !isFloatType(destination.Type()) && // to preserve precision during conversion - !(isIntegerType(source.Type()) && destination.Kind() == reflect.String) && // default conversion will convert int to 1 rune string - source.Type().ConvertibleTo(destination.Type()) { - - source = source.Convert(destination.Type()) - } - +func assignIfAssignable(source, destination reflect.Value) bool { if source.Type().AssignableTo(destination.Type()) { - switch b := source.Interface().(type) { - case []byte: - destination.SetBytes(cloneBytes(b)) + switch source.Type() { + case byteArrayType: + destination.SetBytes(cloneBytes(source.Interface().([]byte))) default: destination.Set(source) } + return true + } + + return false +} + +func tryAssign(source, destination reflect.Value) error { + + if assignIfAssignable(source, destination) { + return nil + } + + sourceType := source.Type() + destinationType := destination.Type() + + if sourceType != destinationType && + !isFloatType(destinationType) && // to preserve precision during conversion + !(isIntegerType(sourceType) && destination.Kind() == reflect.String) && // default conversion will convert int to 1 rune string + sourceType.ConvertibleTo(destinationType) { + + source = source.Convert(destinationType) + } + + if assignIfAssignable(source, destination) { return nil } @@ -302,38 +317,32 @@ func tryAssign(source, destination reflect.Value) error { return nil } +func setZeroValue(value reflect.Value) { + if !value.IsZero() { + value.Set(reflect.Zero(value.Type())) + } +} + func setReflectValue(source, destination reflect.Value) error { + if source.Kind() == reflect.Ptr { + if source.IsNil() { + // source is nil, destination should be its zero value + setZeroValue(destination) + return nil + } + source = source.Elem() + } + if destination.Kind() == reflect.Ptr { if destination.IsNil() { initializeValueIfNilPtr(destination) } - if source.Kind() == reflect.Ptr { - if source.IsNil() { - return nil // source is nil, destination should keep its zero value - } - source = source.Elem() - } - - if err := tryAssign(source, destination.Elem()); err != nil { - return err - } - - } else { - if source.Kind() == reflect.Ptr { - if source.IsNil() { - return nil // source is nil, destination should keep its zero value - } - source = source.Elem() - } - - if err := tryAssign(source, destination); err != nil { - return err - } + destination = destination.Elem() } - return nil + return tryAssign(source, destination) } func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool { diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 39f0e43..e60a7d4 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -951,8 +951,12 @@ func TestRowsScan(t *testing.T) { stmt := SELECT( Inventory.AllColumns, + Film.AllColumns, + Store.AllColumns, ).FROM( - Inventory, + Inventory. + INNER_JOIN(Film, Film.FilmID.EQ(Inventory.FilmID)). + INNER_JOIN(Store, Store.StoreID.EQ(Inventory.StoreID)), ).ORDER_BY( Inventory.InventoryID.ASC(), ) @@ -961,19 +965,42 @@ func TestRowsScan(t *testing.T) { require.NoError(t, err) for rows.Next() { - var inventory model.Inventory + var inventory struct { + model.Inventory + + Film model.Film + Store model.Store + } + err = rows.Scan(&inventory) require.NoError(t, err) - require.NotEqual(t, inventory.InventoryID, uint32(0)) - require.NotEqual(t, inventory.FilmID, uint16(0)) - require.NotEqual(t, inventory.StoreID, uint16(0)) - require.NotEqual(t, inventory.LastUpdate, time.Time{}) + require.NotEmpty(t, inventory.InventoryID) + require.NotEmpty(t, inventory.FilmID) + require.NotEmpty(t, inventory.StoreID) + require.NotEmpty(t, inventory.LastUpdate) + + require.NotEmpty(t, inventory.Film.FilmID) + require.NotEmpty(t, inventory.Film.Title) + require.NotEmpty(t, inventory.Film.Description) + + require.NotEmpty(t, inventory.Store.StoreID) + require.NotEmpty(t, inventory.Store.AddressID) + require.NotEmpty(t, inventory.Store.ManagerStaffID) if inventory.InventoryID == 2103 { require.Equal(t, inventory.FilmID, uint16(456)) require.Equal(t, inventory.StoreID, uint8(2)) require.Equal(t, inventory.LastUpdate.Format(time.RFC3339), "2006-02-15T05:09:17Z") + + require.Equal(t, inventory.Film.FilmID, uint16(456)) + require.Equal(t, inventory.Film.Title, "INCH JET") + require.Equal(t, *inventory.Film.Description, "A Fateful Saga of a Womanizer And a Student who must Defeat a Butler in A Monastery") + require.Equal(t, *inventory.Film.ReleaseYear, int16(2006)) + + require.Equal(t, inventory.Store.StoreID, uint8(2)) + require.Equal(t, inventory.Store.ManagerStaffID, uint8(2)) + require.Equal(t, inventory.Store.AddressID, uint16(2)) } }