diff --git a/qrm/qrm.go b/qrm/qrm.go index 1ffd36f..50597cd 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -74,15 +74,15 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac err := rows.Scan(scanContext.row...) if err != nil { - return fmt.Errorf("rows scan error, %w", err) + return fmt.Errorf("jet: rows scan error, %w", err) } - destValue := reflect.ValueOf(destPtr) + destValuePtr := reflect.ValueOf(destPtr) - _, err = mapRowToStruct(scanContext, "", newTypeStack(), destValue, nil) + _, err = mapRowToStruct(scanContext, "", destValuePtr, nil) if err != nil { - return fmt.Errorf("failed to map a row, %w", err) + return fmt.Errorf("jet: failed to scan a row into destination, %w", err) } return nil @@ -121,7 +121,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, scanContext.rowNum++ - _, err = mapRowToSlice(scanContext, "", newTypeStack(), slicePtrValue, nil) + _, err = mapRowToSlice(scanContext, "", slicePtrValue, nil) if err != nil { return scanContext.rowNum, err @@ -139,7 +139,6 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, func mapRowToSlice( scanContext *ScanContext, groupKey string, - typesVisited *typeStack, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) { @@ -154,19 +153,19 @@ func mapRowToSlice( structGroupKey := scanContext.getGroupKey(sliceElemType, field) - groupKey = groupKey + "," + structGroupKey + groupKey = concat(groupKey, ",", structGroupKey) index, ok := scanContext.uniqueDestObjectsMap[groupKey] if ok { structPtrValue := getSliceElemPtrAt(slicePtrValue, index) - return mapRowToStruct(scanContext, groupKey, typesVisited, structPtrValue, field, true) + return mapRowToStruct(scanContext, groupKey, structPtrValue, field, true) } destinationStructPtr := newElemPtrValueForSlice(slicePtrValue) - updated, err = mapRowToStruct(scanContext, groupKey, typesVisited, destinationStructPtr, field) + updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field) if err != nil { return @@ -192,7 +191,7 @@ func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value return } } - rowElemPtr := scanContext.rowElemValuePtr(index) + rowElemPtr := scanContext.rowElemValueClonePtr(index) if rowElemPtr.IsValid() && !rowElemPtr.IsNil() { updated = true @@ -208,7 +207,6 @@ func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value func mapRowToStruct( scanContext *ScanContext, groupKey string, - typesVisited *typeStack, // to prevent circular dependency scan structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool, // small optimization, not to assign to already assigned struct fields @@ -217,12 +215,12 @@ func mapRowToStruct( mapOnlySlices := len(onlySlices) > 0 structType := structPtrValue.Type().Elem() - if typesVisited.contains(&structType) { + if scanContext.typesVisited.contains(&structType) { return false, nil } - typesVisited.push(&structType) - defer typesVisited.pop() + scanContext.typesVisited.push(&structType) + defer scanContext.typesVisited.pop() typeInf := scanContext.getTypeInfo(structType, parentField) @@ -240,7 +238,7 @@ func mapRowToStruct( if fieldMap.complexType { var changed bool - changed, err = mapRowToDestinationValue(scanContext, groupKey, typesVisited, fieldValue, &field) + changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field) if err != nil { return @@ -251,34 +249,36 @@ func mapRowToStruct( } } else { - if mapOnlySlices || fieldMap.columnIndex == -1 { + if mapOnlySlices || fieldMap.rowIndex == -1 { continue } - cellValue := scanContext.rowElem(fieldMap.columnIndex) + scannedValue := scanContext.rowElemValue(fieldMap.rowIndex) - if cellValue == nil { + if !scannedValue.IsValid() { + setZeroValue(fieldValue) // scannedValue is nil, destination should be set to zero value continue } - initializeValueIfNilPtr(fieldValue) updated = true if fieldMap.implementsScanner { - scanner := getScanner(fieldValue) + initializeValueIfNilPtr(fieldValue) + fieldScanner := getScanner(fieldValue) - err = scanner.Scan(cellValue) + value := scannedValue.Interface() + + err := fieldScanner.Scan(value) if err != nil { - err = fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err) - return + return updated, fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, value, value, field.Name, field.Type.String(), err) } } else { - err = setReflectValue(reflect.ValueOf(cellValue), fieldValue) + err := assign(scannedValue, fieldValue) if err != nil { - err = fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err) - return + return updated, fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, scannedValue.Interface(), scannedValue.Interface(), + field.Name, field.Type.String(), err) } } } @@ -290,7 +290,6 @@ func mapRowToStruct( func mapRowToDestinationValue( scanContext *ScanContext, groupKey string, - typesVisited *typeStack, dest reflect.Value, structField *reflect.StructField) (updated bool, err error) { @@ -306,7 +305,7 @@ func mapRowToDestinationValue( } } - updated, err = mapRowToDestinationPtr(scanContext, groupKey, typesVisited, destPtrValue, structField) + updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField) if err != nil { return @@ -322,7 +321,6 @@ func mapRowToDestinationValue( func mapRowToDestinationPtr( scanContext *ScanContext, groupKey string, - typesVisited *typeStack, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { @@ -331,9 +329,9 @@ func mapRowToDestinationPtr( destValueKind := destPtrValue.Elem().Kind() if destValueKind == reflect.Struct { - return mapRowToStruct(scanContext, groupKey, typesVisited, destPtrValue, structField) + return mapRowToStruct(scanContext, groupKey, destPtrValue, structField) } else if destValueKind == reflect.Slice { - return mapRowToSlice(scanContext, groupKey, typesVisited, destPtrValue, structField) + return mapRowToSlice(scanContext, groupKey, destPtrValue, structField) } else { panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String()) } diff --git a/qrm/scan_context.go b/qrm/scan_context.go index 01cfe53..fa99b5a 100644 --- a/qrm/scan_context.go +++ b/qrm/scan_context.go @@ -16,6 +16,8 @@ type ScanContext struct { commonIdentToColumnIndex map[string]int groupKeyInfoCache map[string]groupKeyInfo typeInfoMap map[string]typeInfo + + typesVisited typeStack // to prevent circular dependency scan } // NewScanContext creates new ScanContext from rows @@ -39,7 +41,7 @@ func NewScanContext(rows *sql.Rows) (*ScanContext, error) { commonIdentifier := toCommonIdentifier(names[0]) if len(names) > 1 { - commonIdentifier += "." + toCommonIdentifier(names[1]) + commonIdentifier = concat(commonIdentifier, ".", toCommonIdentifier(names[1])) } commonIdentToColumnIndex[commonIdentifier] = i @@ -53,15 +55,17 @@ func NewScanContext(rows *sql.Rows) (*ScanContext, error) { commonIdentToColumnIndex: commonIdentToColumnIndex, typeInfoMap: make(map[string]typeInfo), + + typesVisited: newTypeStack(), }, nil } func createScanSlice(columnCount int) []interface{} { - scanSlice := make([]interface{}, columnCount) scanPtrSlice := make([]interface{}, columnCount) for i := range scanPtrSlice { - scanPtrSlice[i] = &scanSlice[i] // if destination is pointer to interface sql.Scan will just forward driver value + var a interface{} + scanPtrSlice[i] = &a // if destination is pointer to interface sql.Scan will just forward driver value } return scanPtrSlice @@ -72,8 +76,8 @@ type typeInfo struct { } type fieldMapping struct { - complexType bool // slice or struct - columnIndex int + complexType bool // slice and struct are complex types + rowIndex int // index in ScanContext.row implementsScanner bool } @@ -82,7 +86,7 @@ func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect. typeMapKey := structType.String() if parentField != nil { - typeMapKey += string(parentField.Tag) + typeMapKey = concat(typeMapKey, string(parentField.Tag)) } if typeInfo, ok := s.typeInfoMap[typeMapKey]; ok { @@ -100,7 +104,7 @@ func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect. columnIndex := s.typeToColumnIndex(newTypeName, fieldName) fieldMap := fieldMapping{ - columnIndex: columnIndex, + rowIndex: columnIndex, } if implementsScannerType(field.Type) { @@ -128,14 +132,15 @@ func (s *ScanContext) getGroupKey(structType reflect.Type, structField *reflect. mapKey := structType.Name() if structField != nil { - mapKey += structField.Type.String() + mapKey = concat(mapKey, structField.Type.String()) } if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok { return s.constructGroupKey(groupKeyInfo) } - groupKeyInfo := s.getGroupKeyInfo(structType, structField, newTypeStack()) + tempTypeStack := newTypeStack() + groupKeyInfo := s.getGroupKeyInfo(structType, structField, &tempTypeStack) s.groupKeyInfoCache[mapKey] = groupKeyInfo @@ -150,10 +155,7 @@ func (s *ScanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { var groupKeys []string for _, index := range groupKeyInfo.indexes { - cellValue := s.rowElem(index) - subKey := valueToString(reflect.ValueOf(cellValue)) - - groupKeys = append(groupKeys, subKey) + groupKeys = append(groupKeys, s.rowElemToString(index)) } var subTypesGroupKeys []string @@ -161,7 +163,7 @@ func (s *ScanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType)) } - return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")" + return concat(groupKeyInfo.typeName, "(", strings.Join(groupKeys, ","), strings.Join(subTypesGroupKeys, ","), ")") } func (s *ScanContext) getGroupKeyInfo( @@ -231,32 +233,36 @@ func (s *ScanContext) typeToColumnIndex(typeName, fieldName string) int { return index } -func (s *ScanContext) rowElem(index int) interface{} { - cellValue := reflect.ValueOf(s.row[index]) - - if cellValue.IsValid() && !cellValue.IsNil() { - return cellValue.Elem().Interface() - } - - return nil +// rowElemValue always returns non-ptr value, +// invalid value is nil +func (s *ScanContext) rowElemValue(index int) reflect.Value { + scannedValue := reflect.ValueOf(s.row[index]) + return scannedValue.Elem().Elem() // no need to check validity of Elem, because s.row[index] always contains interface in interface } -func (s *ScanContext) rowElemValuePtr(index int) reflect.Value { - rowElem := s.rowElem(index) - rowElemValue := reflect.ValueOf(rowElem) +func (s *ScanContext) rowElemToString(index int) string { + value := s.rowElemValue(index) + + if !value.IsValid() { + return "nil" + } + + valueInterface := value.Interface() + + if t, ok := valueInterface.(fmt.Stringer); ok { + return t.String() + } + + return fmt.Sprintf("%#v", valueInterface) +} + +func (s *ScanContext) rowElemValueClonePtr(index int) reflect.Value { + rowElemValue := s.rowElemValue(index) if !rowElemValue.IsValid() { return reflect.Value{} } - if rowElemValue.Kind() == reflect.Ptr { - return rowElemValue - } - - if rowElemValue.CanAddr() { - return rowElemValue.Addr() - } - newElem := reflect.New(rowElemValue.Type()) newElem.Elem().Set(rowElemValue) return newElem diff --git a/qrm/type_stack.go b/qrm/type_stack.go index 235c06e..2bdf799 100644 --- a/qrm/type_stack.go +++ b/qrm/type_stack.go @@ -4,9 +4,9 @@ import "reflect" type typeStack []*reflect.Type -func newTypeStack() *typeStack { +func newTypeStack() typeStack { stack := make(typeStack, 0, 20) - return &stack + return stack } func (s *typeStack) isEmpty() bool { diff --git a/qrm/utill.go b/qrm/utill.go index ca0db61..dfb9a69 100644 --- a/qrm/utill.go +++ b/qrm/utill.go @@ -18,9 +18,9 @@ func implementsScannerType(fieldType reflect.Type) bool { return true } - typePtr := reflect.New(fieldType).Type() + fieldTypePtr := reflect.New(fieldType).Type() - return typePtr.Implements(scannerInterfaceType) + return fieldTypePtr.Implements(scannerInterfaceType) } func getScanner(value reflect.Value) sql.Scanner { @@ -68,9 +68,9 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e if newSliceElemValue.Kind() == reflect.Ptr { newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem())) - err = tryAssign(objPtrValue.Elem(), newSliceElemValue.Elem()) + err = assign(objPtrValue.Elem(), newSliceElemValue.Elem()) } else { - err = tryAssign(objPtrValue.Elem(), newSliceElemValue) + err = assign(objPtrValue.Elem(), newSliceElemValue) } if err != nil { @@ -138,29 +138,6 @@ func initializeValueIfNilPtr(value reflect.Value) { } } -func valueToString(value reflect.Value) string { - - if !value.IsValid() { - return "nil" - } - - var valueInterface interface{} - if value.Kind() == reflect.Ptr { - if value.IsNil() { - return "nil" - } - valueInterface = value.Elem().Interface() - } else { - valueInterface = value.Interface() - } - - if t, ok := valueInterface.(fmt.Stringer); ok { - return t.String() - } - - return fmt.Sprintf("%#v", valueInterface) -} - var timeType = reflect.TypeOf(time.Now()) var uuidType = reflect.TypeOf(uuid.New()) var byteArrayType = reflect.TypeOf([]byte("")) @@ -180,30 +157,35 @@ func isSimpleModelType(objType reflect.Type) bool { return objType == timeType || objType == uuidType || objType == byteArrayType } -func isIntegerType(objType reflect.Type) bool { - objType = indirectType(objType) +// source can't be pointer +// destination can be pointer +func assign(source, destination reflect.Value) error { + if destination.Kind() == reflect.Ptr { + if destination.IsNil() { + initializeValueIfNilPtr(destination) + } - switch objType.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return true + destination = destination.Elem() } - return false -} + err := tryAssign(source, destination) -func isFloatType(value reflect.Type) bool { - switch value.Kind() { - case reflect.Float32, reflect.Float64: - return true + if err != nil { + // needs for the type conversions are rare, so we leave conversion as a last assign step if everything else fails + if tryConvert(source, destination) { + return nil + } + + return err } - return false + return nil } func assignIfAssignable(source, destination reflect.Value) bool { - if source.Type().AssignableTo(destination.Type()) { - switch source.Type() { + sourceType := source.Type() + if sourceType.AssignableTo(destination.Type()) { + switch sourceType { case byteArrayType: destination.SetBytes(cloneBytes(source.Interface().([]byte))) default: @@ -215,31 +197,17 @@ func assignIfAssignable(source, destination reflect.Value) bool { return false } +// source and destination are non-ptr values func tryAssign(source, destination reflect.Value) error { if assignIfAssignable(source, destination) { return nil } - sourceType := source.Type() - destinationType := destination.Type() - - if sourceType != destinationType && - !isFloatType(destinationType) && // to preserve precision during conversion - !(isIntegerType(sourceType) && destination.Kind() == reflect.String) && // default conversion will convert int to 1 rune string - sourceType.ConvertibleTo(destinationType) { - - source = source.Convert(destinationType) - } - - if assignIfAssignable(source, destination) { - return nil - } - sourceInterface := source.Interface() - switch destination.Interface().(type) { - case bool: + switch destination.Type().Kind() { + case reflect.Bool: var nullBool internal.NullBool err := nullBool.Scan(sourceInterface) @@ -250,7 +218,7 @@ func tryAssign(source, destination reflect.Value) error { destination.SetBool(nullBool.Bool) - case float32, float64: + case reflect.Float32, reflect.Float64: var nullFloat sql.NullFloat64 err := nullFloat.Scan(sourceInterface) @@ -261,7 +229,7 @@ func tryAssign(source, destination reflect.Value) error { if nullFloat.Valid { destination.SetFloat(nullFloat.Float64) } - case int, int8, int16, int32, int64: + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: var integer sql.NullInt64 err := integer.Scan(sourceInterface) @@ -273,7 +241,7 @@ func tryAssign(source, destination reflect.Value) error { destination.SetInt(integer.Int64) } - case uint, uint8, uint16, uint32, uint64: + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: var uInt internal.NullUInt64 err := uInt.Scan(sourceInterface) @@ -286,7 +254,7 @@ func tryAssign(source, destination reflect.Value) error { destination.SetUint(uInt.UInt64) } - case string: + case reflect.String: var str sql.NullString err := str.Scan(sourceInterface) @@ -298,53 +266,44 @@ func tryAssign(source, destination reflect.Value) error { destination.SetString(str.String) } - case time.Time: - var nullTime internal.NullTime - - err := nullTime.Scan(sourceInterface) - if err != nil { - return err - } - - if nullTime.Valid { - destination.Set(reflect.ValueOf(nullTime.Time)) - } - default: - return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface()) + switch destination.Interface().(type) { + case time.Time: + var nullTime internal.NullTime + + err := nullTime.Scan(sourceInterface) + if err != nil { + return err + } + + if nullTime.Valid { + destination.Set(reflect.ValueOf(nullTime.Time)) + } + default: + return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface()) + } } return nil } +func tryConvert(source, destination reflect.Value) bool { + destinationType := destination.Type() + + if source.Type().ConvertibleTo(destinationType) { + source = source.Convert(destinationType) + return assignIfAssignable(source, destination) + } + + return false +} + func setZeroValue(value reflect.Value) { if !value.IsZero() { value.Set(reflect.Zero(value.Type())) } } -func setReflectValue(source, destination reflect.Value) error { - - if source.Kind() == reflect.Ptr { - if source.IsNil() { - // source is nil, destination should be its zero value - setZeroValue(destination) - return nil - } - source = source.Elem() - } - - if destination.Kind() == reflect.Ptr { - if destination.IsNil() { - initializeValueIfNilPtr(destination) - } - - destination = destination.Elem() - } - - return tryAssign(source, destination) -} - func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool { if len(primaryKeyOverwrites) > 0 { return utils.StringSliceContains(primaryKeyOverwrites, field.Name) @@ -398,3 +357,11 @@ func cloneBytes(b []byte) []byte { copy(c, b) return c } + +func concat(stringList ...string) string { + var b strings.Builder + for _, str := range stringList { + b.WriteString(str) + } + return b.String() +} diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index e60a7d4..8bd028a 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -206,7 +206,7 @@ GROUP BY payment.customer_id; "RentalID": null, "Amount": 0, "PaymentDate": "0001-01-01T00:00:00Z", - "LastUpdate": "0001-01-01T00:00:00Z", + "LastUpdate": null, "Count": 8, "Sum": 38.92, "Avg": 4.865, @@ -964,14 +964,14 @@ func TestRowsScan(t *testing.T) { rows, err := stmt.Rows(context.Background(), db) require.NoError(t, err) + var inventory struct { + model.Inventory + + Film model.Film + Store model.Store + } + for rows.Next() { - var inventory struct { - model.Inventory - - Film model.Film - Store model.Store - } - err = rows.Scan(&inventory) require.NoError(t, err) @@ -1056,3 +1056,50 @@ func TestScanNumericToNumber(t *testing.T) { require.Equal(t, number.Float32, float32(1.234568e+09)) require.Equal(t, number.Float64, float64(1.23456789e+09)) } + +// scan into custom base types should be equivalent to the scan into base go types +func TestScanIntoCustomBaseTypes(t *testing.T) { + + type MyUint8 uint8 + type MyUint16 uint16 + type MyUint32 uint32 + type MyInt16 int16 + type MyFloat32 float32 + type MyFloat64 float64 + type MyString string + type MyTime = time.Time + + type film struct { + FilmID MyUint16 `sql:"primary_key"` + Title MyString + Description *MyString + ReleaseYear *MyInt16 + LanguageID MyUint8 + OriginalLanguageID *MyUint8 + RentalDuration MyUint8 + RentalRate MyFloat32 + Length *MyUint32 + ReplacementCost MyFloat64 + Rating *model.FilmRating + SpecialFeatures *MyString + LastUpdate MyTime + } + + stmt := SELECT( + Film.AllColumns, + ).FROM( + Film, + ).ORDER_BY( + Film.FilmID.ASC(), + ).LIMIT(3) + + var films []model.Film + err := stmt.Query(db, &films) + require.NoError(t, err) + + var myFilms []film + err = stmt.Query(db, &myFilms) + require.NoError(t, err) + + require.Equal(t, testutils.ToJSON(films), testutils.ToJSON(myFilms)) +} diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index 61b7bec..3078709 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -786,6 +786,123 @@ func TestRowsScan(t *testing.T) { requireQueryLogged(t, stmt, 0) } +func TestScanNullColumn(t *testing.T) { + stmt := SELECT( + Address.AllColumns, + ).FROM( + Address, + ).WHERE( + Address.Address2.IS_NULL(), + ) + + var dest []model.Address + + err := stmt.Query(db, &dest) + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "AddressID": 1, + "Address": "47 MySakila Drive", + "Address2": null, + "District": "Alberta", + "CityID": 300, + "PostalCode": "", + "Phone": "", + "LastUpdate": "2006-02-15T09:45:30Z" + }, + { + "AddressID": 2, + "Address": "28 MySQL Boulevard", + "Address2": null, + "District": "QLD", + "CityID": 576, + "PostalCode": "", + "Phone": "", + "LastUpdate": "2006-02-15T09:45:30Z" + }, + { + "AddressID": 3, + "Address": "23 Workhaven Lane", + "Address2": null, + "District": "Alberta", + "CityID": 300, + "PostalCode": "", + "Phone": "14033335568", + "LastUpdate": "2006-02-15T09:45:30Z" + }, + { + "AddressID": 4, + "Address": "1411 Lillydale Drive", + "Address2": null, + "District": "QLD", + "CityID": 576, + "PostalCode": "", + "Phone": "6172235589", + "LastUpdate": "2006-02-15T09:45:30Z" + } +] +`) +} + +func TestRowsScanSetZeroValue(t *testing.T) { + stmt := SELECT( + Rental.AllColumns, + ).FROM( + Rental, + ).WHERE( + Rental.RentalID.IN(Int(16049), Int(15966)), + ).ORDER_BY( + Rental.RentalID.DESC(), + ) + + rows, err := stmt.Rows(context.Background(), db) + require.NoError(t, err) + + defer rows.Close() + + // destination object is used as destination for all rows scan. + // this tests checks that ReturnedDate is set to nil with the second call + // check qrm.setZeroValue + var dest model.Rental + + for rows.Next() { + err := rows.Scan(&dest) + require.NoError(t, err) + + if dest.RentalID == 16049 { + testutils.AssertJSON(t, dest, ` +{ + "RentalID": 16049, + "RentalDate": "2005-08-23T22:50:12Z", + "InventoryID": 2666, + "CustomerID": 393, + "ReturnDate": "2005-08-30T01:01:12Z", + "StaffID": 2, + "LastUpdate": "2006-02-16T02:30:53Z" +} +`) + } else { + testutils.AssertJSON(t, dest, ` +{ + "RentalID": 15966, + "RentalDate": "2006-02-14T15:16:03Z", + "InventoryID": 4472, + "CustomerID": 374, + "ReturnDate": null, + "StaffID": 1, + "LastUpdate": "2006-02-16T02:30:53Z" +} +`) + } + } + + err = rows.Close() + require.NoError(t, err) + err = rows.Err() + require.NoError(t, err) +} + func TestScanNumericToFloat(t *testing.T) { type Number struct { Float32 float32 @@ -826,6 +943,54 @@ func TestScanNumericToIntegerError(t *testing.T) { } +func TestScanIntoCustomBaseTypes(t *testing.T) { + + type MyUint8 uint8 + type MyUint16 uint16 + type MyUint32 uint32 + type MyInt16 int16 + type MyFloat32 float32 + type MyFloat64 float64 + type MyString string + type MyTime = time.Time + + type film struct { + FilmID MyUint16 `sql:"primary_key"` + Title MyString + Description *MyString + ReleaseYear *MyInt16 + LanguageID MyUint8 + RentalDuration MyUint8 + RentalRate MyFloat32 + Length *MyUint32 + ReplacementCost MyFloat64 + Rating *model.MpaaRating + LastUpdate MyTime + SpecialFeatures *MyString + Fulltext MyString + } + + stmt := SELECT( + Film.AllColumns, + ).FROM( + Film, + ).ORDER_BY( + Film.FilmID.ASC(), + ).LIMIT(3) + + var films []model.Film + + err := stmt.Query(db, &films) + require.NoError(t, err) + + var myFilms []film + + err = stmt.Query(db, &myFilms) + require.NoError(t, err) + + require.Equal(t, testutils.ToJSON(films), testutils.ToJSON(myFilms)) +} + // QueryContext panic when the scanned value is nil and the destination is a slice of primitive // https://github.com/go-jet/jet/issues/91 func TestScanToPrimitiveElementsSlice(t *testing.T) { diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index b3d3e63..e95d92a 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -2521,6 +2521,79 @@ func TestRecursionScanNx1(t *testing.T) { }) } +type StoreInfo struct { + model.Store + + Staffs ManagerInfo +} + +type ManagerInfo struct { + model.Staff + Store *StoreInfo +} + +func TestRecursionScan1x1(t *testing.T) { + + stmt := SELECT( + Store.AllColumns, + Staff.AllColumns, + ).FROM( + Store. + INNER_JOIN(Staff, Staff.StaffID.EQ(Store.ManagerStaffID)), + ).ORDER_BY( + Store.StoreID, + ) + + var dest []StoreInfo + + err := stmt.Query(db, &dest) + require.NoError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "StoreID": 1, + "ManagerStaffID": 1, + "AddressID": 1, + "LastUpdate": "2006-02-15T09:57:12Z", + "Staffs": { + "StaffID": 1, + "FirstName": "Mike", + "LastName": "Hillyer", + "AddressID": 3, + "Email": "Mike.Hillyer@sakilastaff.com", + "StoreID": 1, + "Active": true, + "Username": "Mike", + "Password": "8cb2237d0679ca88db6464eac60da96345513964", + "LastUpdate": "2006-05-16T16:13:11.79328Z", + "Picture": "iVBORw0KWgo=", + "Store": null + } + }, + { + "StoreID": 2, + "ManagerStaffID": 2, + "AddressID": 2, + "LastUpdate": "2006-02-15T09:57:12Z", + "Staffs": { + "StaffID": 2, + "FirstName": "Jon", + "LastName": "Stephens", + "AddressID": 4, + "Email": "Jon.Stephens@sakilastaff.com", + "StoreID": 2, + "Active": true, + "Username": "Jon", + "Password": "8cb2237d0679ca88db6464eac60da96345513964", + "LastUpdate": "2006-05-16T16:13:11.79328Z", + "Picture": null, + "Store": null + } + } +] +`) +} + // In parameterized statements integer literals, like Int(num), are replaced with a placeholders. For some expressions, // postgres interpreter will not have enough information to deduce the type. If this is the case postgres returns an error. // Int8, Int16, .... functions will add automatic type cast over placeholder, so type deduction is always possible. diff --git a/tests/postgres/with_test.go b/tests/postgres/with_test.go index 0b47e9a..e6d23ee 100644 --- a/tests/postgres/with_test.go +++ b/tests/postgres/with_test.go @@ -2,7 +2,6 @@ package postgres import ( "context" - "fmt" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/northwind/model" @@ -864,5 +863,4 @@ WHERE orders1."orders.order_id" < $1; err := stmt.Query(db, &dest) require.NoError(t, err) require.Len(t, dest, 72) - fmt.Println(len(dest)) }