diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index 68f391a..01cc64f 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -139,7 +139,7 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl return } } - rowElemPtr := scanContext.rowElemPtr(index) + rowElemPtr := scanContext.rowElemValuePtr(index) if !rowElemPtr.IsNil() { updated = true @@ -411,7 +411,17 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re return } updated = true - } else if !isGoBaseType(field.Type) { + } else if isGoBaseType(field.Type) { + cellValue := getCellValue(scanContext, tableName, fieldName) + //spew.Dump(rowElem) + + //spew.Dump(rowColumnValue, fieldValue) + if cellValue != nil { + updated = true + initializeValueIfNil(fieldValue) + setReflectValue(reflect.ValueOf(cellValue), fieldValue) + } + } else { var changed bool changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field) @@ -422,16 +432,6 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re if changed { updated = true } - } else { - cellValue := getCellValue(scanContext, tableName, fieldName) - //spew.Dump(rowElem) - - //spew.Dump(rowColumnValue, fieldValue) - if cellValue != nil { - updated = true - initializeValueIfNil(fieldValue) - setReflectValue(reflect.ValueOf(cellValue), fieldValue) - } } } @@ -518,22 +518,36 @@ func isGoBaseType(objType reflect.Type) bool { return false } -func setReflectValue(source, destination reflect.Value) { +func setReflectValue(source, destination reflect.Value) error { + var sourceElem reflect.Value if destination.Kind() == reflect.Ptr { if source.Kind() == reflect.Ptr { - destination.Set(source) + sourceElem = source } else { - newDestination := reflect.New(destination.Type().Elem()) - newDestination.Elem().Set(source) - destination.Set(newDestination) + if source.CanAddr() { + sourceElem = source.Addr() + } else { + newDestination := reflect.New(destination.Type().Elem()) + newDestination.Elem().Set(source) + + sourceElem = newDestination + } } } else { if source.Kind() == reflect.Ptr { - destination.Set(source.Elem()) + sourceElem = source.Elem() } else { - destination.Set(source) + sourceElem = source } } + + if !sourceElem.Type().AssignableTo(destination.Type()) { + return errors.New("Can't set " + sourceElem.Type().String() + " to " + destination.Type().String()) + } + + destination.Set(sourceElem) + + return nil } func getIndex(list []string, text string) int { @@ -639,7 +653,7 @@ func (s *scanContext) rowElem(index int) interface{} { return value } -func (s *scanContext) rowElemPtr(index int) reflect.Value { +func (s *scanContext) rowElemValuePtr(index int) reflect.Value { rowElem := s.rowElem(index) rowElemValue := reflect.ValueOf(rowElem) diff --git a/tests/scan_test.go b/tests/scan_test.go index 20b677d..6883908 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -133,6 +133,24 @@ func TestScanToStruct(t *testing.T) { assert.Error(t, err, "Unsupported dest type: Inventory ***model.Inventory") }) + t.Run("custom struct", func(t *testing.T) { + type Inventory struct { + InventoryID *int32 `sql:"unique"` + FilmID int16 + StoreID *int16 + } + + dest := Inventory{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + + assert.Equal(t, *dest.InventoryID, int32(1)) + assert.Equal(t, dest.FilmID, int16(1)) + assert.Equal(t, *dest.StoreID, int16(1)) + }) + } func TestScanToNestedStruct(t *testing.T) {