From f9efee77ff5c0220ef21227ad1829aba4870c159 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 12 Oct 2019 18:45:09 +0200 Subject: [PATCH] QRM returns sql.ErrNoRows when scanning into struct destination and query result set is empty. --- internal/jet/statement.go | 6 +- qrm/qrm.go | 98 +++++++++++++++++---------------- qrm/scan_context.go | 37 ++++++------- qrm/utill.go | 9 +-- tests/postgres/alltypes_test.go | 2 +- tests/postgres/scan_test.go | 30 ++++++++++ tests/postgres/select_test.go | 11 ---- 7 files changed, 103 insertions(+), 90 deletions(-) diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 2bd0929..ab0c655 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -15,10 +15,12 @@ type Statement interface { DebugSql() (query string) // Query executes statement over database connection db and stores row result in destination. - // Destination can be arbitrary structure + // Destination can be either pointer to struct or pointer to a slice. + // If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows. Query(db qrm.DB, destination interface{}) error // QueryContext executes statement with a context over database connection db and stores row result in destination. - // Destination can be of arbitrary structure + // Destination can be either pointer to struct or pointer to a slice. + // If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows. QueryContext(context context.Context, db qrm.DB, destination interface{}) error //Exec executes statement over db connection without returning any rows. diff --git a/qrm/qrm.go b/qrm/qrm.go index 2b25ecc..8f79c57 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -10,6 +10,7 @@ import ( // Query executes Query Result Mapping (QRM) of `query` with list of parametrized arguments `arg` over database connection `db` // using context `ctx` into destination `destPtr`. // Destination can be either pointer to struct or pointer to slice of structs. +// If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows. func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) error { utils.MustBeInitializedPtr(db, "jet: db is nil") @@ -19,21 +20,27 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr destinationPtrType := reflect.TypeOf(destPtr) if destinationPtrType.Elem().Kind() == reflect.Slice { - return queryToSlice(ctx, db, query, args, destPtr) + _, err := queryToSlice(ctx, db, query, args, destPtr) + return err } else if destinationPtrType.Elem().Kind() == reflect.Struct { tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSliceValue := tempSlicePtrValue.Elem() - err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface()) + rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface()) if err != nil { return err } - if tempSliceValue.Len() == 0 { + if rowsProcessed == 0 { return sql.ErrNoRows } + // edge case when row result set contains only NULLs. + if tempSliceValue.Len() == 0 { + return nil + } + structValue := reflect.ValueOf(destPtr).Elem() firstTempStruct := tempSliceValue.Index(0).Elem() @@ -46,7 +53,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr } } -func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) error { +func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) (rowsProcessed int64, err error) { if ctx == nil { ctx = context.Background() } @@ -54,27 +61,27 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, rows, err := db.QueryContext(ctx, query, args...) if err != nil { - return err + return } defer rows.Close() scanContext, err := newScanContext(rows) if err != nil { - return err + return } if len(scanContext.row) == 0 { - return nil + return } slicePtrValue := reflect.ValueOf(slicePtr) for rows.Next() { - err := rows.Scan(scanContext.row...) + err = rows.Scan(scanContext.row...) if err != nil { - return err + return } scanContext.rowNum++ @@ -82,22 +89,24 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, _, err = mapRowToSlice(scanContext, "", slicePtrValue, nil) if err != nil { - return err + return } } err = rows.Close() if err != nil { - return err + return } err = rows.Err() if err != nil { - return err + return } - return nil + rowsProcessed = scanContext.rowNum + + return } func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) { @@ -165,6 +174,7 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value } func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) { + mapOnlySlices := len(onlySlices) > 0 structType := structPtrValue.Type().Elem() typeInf := scanContext.getTypeInfo(structType, parentField) @@ -193,22 +203,21 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re updated = true } - } else if len(onlySlices) == 0 { - - if fieldMap.columnIndex == -1 { + } else { + if mapOnlySlices || fieldMap.columnIndex == -1 { continue } + cellValue := scanContext.rowElem(fieldMap.columnIndex) + + if cellValue == nil { + continue + } + + initializeValueIfNilPtr(fieldValue) + updated = true + if fieldMap.implementsScanner { - - cellValue := scanContext.rowElem(fieldMap.columnIndex) - - if cellValue == nil { - continue - } - - initializeValueIfNilPtr(fieldValue) - scanner := getScanner(fieldValue) err = scanner.Scan(cellValue) @@ -216,15 +225,8 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re if err != nil { panic("jet: " + err.Error() + ", " + fieldToString(&field) + " of type " + structType.String()) } - updated = true } else { - cellValue := scanContext.rowElem(fieldMap.columnIndex) - - if cellValue != nil { - updated = true - initializeValueIfNilPtr(fieldValue) - setReflectValue(reflect.ValueOf(cellValue), fieldValue) - } + setReflectValue(reflect.ValueOf(cellValue), fieldValue) } } } @@ -232,21 +234,6 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re return } -func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { - - utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.") - - destValueKind := destPtrValue.Elem().Kind() - - if destValueKind == reflect.Struct { - return mapRowToStruct(scanContext, groupKey, destPtrValue, structField) - } else if destValueKind == reflect.Slice { - return mapRowToSlice(scanContext, groupKey, destPtrValue, structField) - } else { - panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String()) - } -} - func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest reflect.Value, structField *reflect.StructField) (updated bool, err error) { var destPtrValue reflect.Value @@ -273,3 +260,18 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re return } + +func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { + + utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.") + + destValueKind := destPtrValue.Elem().Kind() + + if destValueKind == reflect.Struct { + return mapRowToStruct(scanContext, groupKey, destPtrValue, structField) + } else if destValueKind == reflect.Slice { + return mapRowToSlice(scanContext, groupKey, destPtrValue, structField) + } else { + panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String()) + } +} diff --git a/qrm/scan_context.go b/qrm/scan_context.go index 7db2c95..8141ed7 100644 --- a/qrm/scan_context.go +++ b/qrm/scan_context.go @@ -3,23 +3,19 @@ package qrm import ( "database/sql" "database/sql/driver" + "fmt" "github.com/go-jet/jet/internal/utils" "reflect" - "strconv" "strings" ) type scanContext struct { - rowNum int - - row []interface{} - uniqueDestObjectsMap map[string]int - - typeToColumnIndexMap map[string]int - - groupKeyInfoCache map[string]groupKeyInfo - - typeInfoMap map[string]typeInfo + rowNum int64 + row []interface{} + uniqueDestObjectsMap map[string]int + commonIdentToColumnIndex map[string]int + groupKeyInfoCache map[string]groupKeyInfo + typeInfoMap map[string]typeInfo } func newScanContext(rows *sql.Rows) (*scanContext, error) { @@ -35,26 +31,25 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) { return nil, err } - typeToIndexMap := map[string]int{} + commonIdentToColumnIndex := map[string]int{} for i, alias := range aliases { names := strings.SplitN(alias, ".", 2) - - goName := toCommonIdentifier(names[0]) + commonIdentifier := toCommonIdentifier(names[0]) if len(names) > 1 { - goName += "." + toCommonIdentifier(names[1]) + commonIdentifier += "." + toCommonIdentifier(names[1]) } - typeToIndexMap[strings.ToLower(goName)] = i + commonIdentToColumnIndex[commonIdentifier] = i } return &scanContext{ row: createScanValue(columnTypes), uniqueDestObjectsMap: make(map[string]int), - groupKeyInfoCache: make(map[string]groupKeyInfo), - typeToColumnIndexMap: typeToIndexMap, + groupKeyInfoCache: make(map[string]groupKeyInfo), + commonIdentToColumnIndex: commonIdentToColumnIndex, typeInfoMap: make(map[string]typeInfo), }, nil @@ -65,7 +60,7 @@ type typeInfo struct { } type fieldMapping struct { - complexType bool + complexType bool // slice or struct columnIndex int implementsScanner bool } @@ -137,7 +132,7 @@ func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect. func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { if len(groupKeyInfo.indexes) == 0 && len(groupKeyInfo.subTypes) == 0 { - return "|ROW: " + strconv.Itoa(s.rowNum) + "|" + return fmt.Sprintf("|ROW:%d|", s.rowNum) } groupKeys := []string{} @@ -204,7 +199,7 @@ func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int { key = strings.ToLower(fieldName) } - index, ok := s.typeToColumnIndexMap[key] + index, ok := s.commonIdentToColumnIndex[key] if !ok { return -1 diff --git a/qrm/utill.go b/qrm/utill.go index e23dfac..7b6edaa 100644 --- a/qrm/utill.go +++ b/qrm/utill.go @@ -33,14 +33,9 @@ func getScanner(value reflect.Value) sql.Scanner { func getSliceElemType(slicePtrValue reflect.Value) reflect.Type { sliceTypePtr := slicePtrValue.Type() + elemType := indirectType(sliceTypePtr).Elem() - elemType := sliceTypePtr.Elem().Elem() - - if elemType.Kind() == reflect.Ptr { - return elemType.Elem() - } - - return elemType + return indirectType(elemType) } func getSliceElemPtrAt(slicePtrValue reflect.Value, index int) reflect.Value { diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 9af0394..f2e8fbe 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -134,7 +134,7 @@ LIMIT $5; func TestExpressionCast(t *testing.T) { query := AllTypes.SELECT( - postgres.CAST(Int(150)).AS_CHAR(12), + postgres.CAST(Int(150)).AS_CHAR(12).AS("char12"), postgres.CAST(String("TRUE")).AS_BOOL(), postgres.CAST(String("111")).AS_SMALLINT(), postgres.CAST(String("111")).AS_INTEGER(), diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index d3ec061..f42dbd5 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -1,6 +1,7 @@ package postgres import ( + "database/sql" "fmt" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" @@ -694,6 +695,35 @@ func TestScanToSlice(t *testing.T) { }) } +func TestStructScanErrNoRows(t *testing.T) { + query := SELECT(Customer.AllColumns). + FROM(Customer). + WHERE(Customer.CustomerID.EQ(Int(-1))) + + customer := model.Customer{} + + err := query.Query(db, &customer) + + assert.Error(t, err, sql.ErrNoRows.Error()) +} + +func TestStructScanAllNull(t *testing.T) { + query := SELECT(NULL.AS("null1"), NULL.AS("null2")) + + dest := struct { + Null1 *int + Null2 *int + }{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, dest, struct { + Null1 *int + Null2 *int + }{}) +} + var address256 = model.Address{ AddressID: 256, Address: "1497 Yuzhou Drive", diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index afb7832..3cd2fba 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1,7 +1,6 @@ package postgres import ( - "database/sql" "fmt" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" @@ -1798,13 +1797,3 @@ func TestJoinViewWithTable(t *testing.T) { assert.Equal(t, len(dest[0].Rentals), 32) assert.Equal(t, len(dest[1].Rentals), 27) } - -func TestErrNoRows(t *testing.T) { - query := SELECT(Customer.AllColumns).FROM(Customer).WHERE(Customer.CustomerID.EQ(Int(-1))) - - customer := model.Customer{} - - err := query.Query(db, &customer) - - assert.Error(t, err, sql.ErrNoRows.Error()) -}