From 8c9ae77cd8bf9339253070ca9f79b4e19d90512b Mon Sep 17 00:00:00 2001 From: go-jet Date: Thu, 10 Oct 2019 16:09:44 +0200 Subject: [PATCH 1/7] Add QRM error if result set is empty when scanning into struct destination. --- qrm/qrm.go | 2 +- tests/mysql/alltypes_test.go | 15 ++++++++++----- tests/mysql/select_test.go | 21 ++++++++++++++------- tests/postgres/alltypes_test.go | 9 ++++++--- tests/postgres/scan_test.go | 6 ++++-- tests/postgres/select_test.go | 30 ++++++++++++++++++++++++++---- 6 files changed, 61 insertions(+), 22 deletions(-) diff --git a/qrm/qrm.go b/qrm/qrm.go index ebe9084..694f6a1 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -38,7 +38,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr } if tempSliceValue.Len() == 0 { - return nil + return sql.ErrNoRows } structValue := reflect.ValueOf(destPtr).Elem() diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 952ea90..f0f7406 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -516,7 +516,8 @@ func TestStringOperators(t *testing.T) { fmt.Println(query.DebugSql()) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -586,7 +587,8 @@ FROM test_sample.all_types; `, "20:34:58", "23:06:06", "22:06:06.011", "21:06:06.011111", "20:16:06", "19:26:06", "18:36:06", "17:46:06", "16:56:56", "15:16:46", "14:26:36") - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -646,7 +648,8 @@ SELECT CAST(? AS DATE), FROM test_sample.all_types; `) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -708,7 +711,8 @@ SELECT all_types.date_time = all_types.date_time, FROM test_sample.all_types; `) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -769,7 +773,8 @@ SELECT all_types.timestamp = all_types.timestamp, CURRENT_TIMESTAMP(2) FROM test_sample.all_types; `) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index c34cfd6..f404e29 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -228,7 +228,8 @@ LIMIT ?; testutils.AssertStatementSql(t, query, expectedSQL, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -263,7 +264,8 @@ LIMIT ?; testutils.AssertStatementSql(t, query2, expectedSQL, int64(1), int64(10), int64(1), int64(2), int64(1)) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -305,7 +307,8 @@ OFFSET ?; testutils.AssertStatementSql(t, query2, expectedSQL, int64(1), int64(10), int64(1), int64(2), int64(4), int64(3)) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -510,7 +513,8 @@ SELECT true, 'date'; `) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -530,7 +534,8 @@ LOCK IN SHARE MODE; testutils.AssertDebugStatementSql(t, query, expectedSQL) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -606,7 +611,8 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date; testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -641,7 +647,8 @@ ORDER BY payment.customer_id; testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index a613173..9af0394 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -170,7 +170,8 @@ func TestExpressionCast(t *testing.T) { //fmt.Println(query.DebugSql()) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -249,7 +250,8 @@ func TestStringOperators(t *testing.T) { //fmt.Println(query.DebugSql()) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -618,7 +620,8 @@ func TestTimeExpression(t *testing.T) { //fmt.Println(query.DebugSql()) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index 4a0eba4..d3ec061 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -50,14 +50,16 @@ func TestScanToInvalidDestination(t *testing.T) { func TestScanToValidDestination(t *testing.T) { t.Run("pointer to struct", func(t *testing.T) { - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) }) t.Run("global query function scan", func(t *testing.T) { queryStr, args := query.Sql() - err := qrm.Query(nil, db, queryStr, args, &struct{}{}) + dest := []struct{}{} + err := qrm.Query(nil, db, queryStr, args, &dest) assert.NilError(t, err) }) diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 401854d..afb7832 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1,6 +1,7 @@ package postgres import ( + "database/sql" "fmt" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" @@ -162,7 +163,8 @@ LIMIT 12; testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -1617,7 +1619,8 @@ SELECT true, 'date'; `) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -1688,7 +1691,8 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date; testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } @@ -1723,12 +1727,14 @@ ORDER BY payment.customer_id; testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) - err := query.Query(db, &struct{}{}) + dest := []struct{}{} + err := query.Query(db, &dest) assert.NilError(t, err) } func TestSimpleView(t *testing.T) { + query := SELECT( view.ActorInfo.AllColumns, ). @@ -1743,6 +1749,12 @@ func TestSimpleView(t *testing.T) { FilmInfo string } + //sql, args := query.Sql() + // + //row := db.QueryRow(sql, args...) + // + //row.Scan() + var dest []ActorInfo err := query.Query(db, &dest) @@ -1786,3 +1798,13 @@ func TestJoinViewWithTable(t *testing.T) { assert.Equal(t, len(dest[0].Rentals), 32) assert.Equal(t, len(dest[1].Rentals), 27) } + +func TestErrNoRows(t *testing.T) { + query := SELECT(Customer.AllColumns).FROM(Customer).WHERE(Customer.CustomerID.EQ(Int(-1))) + + customer := model.Customer{} + + err := query.Query(db, &customer) + + assert.Error(t, err, sql.ErrNoRows.Error()) +} From 3544977d7f87efc239cfabb37f67909ee385028a Mon Sep 17 00:00:00 2001 From: go-jet Date: Fri, 11 Oct 2019 10:15:36 +0200 Subject: [PATCH 2/7] QRM code refactor. --- qrm/qrm.go | 607 -------------------------------------------- qrm/scan_context.go | 246 ++++++++++++++++++ qrm/utill.go | 376 +++++++++++++++++++++++++++ 3 files changed, 622 insertions(+), 607 deletions(-) create mode 100644 qrm/scan_context.go create mode 100644 qrm/utill.go diff --git a/qrm/qrm.go b/qrm/qrm.go index 694f6a1..2b25ecc 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -3,15 +3,8 @@ package qrm import ( "context" "database/sql" - "database/sql/driver" - "fmt" "github.com/go-jet/jet/internal/utils" - "github.com/go-jet/jet/qrm/internal" - "github.com/google/uuid" "reflect" - "strconv" - "strings" - "time" ) // Query executes Query Result Mapping (QRM) of `query` with list of parametrized arguments `arg` over database connection `db` @@ -171,56 +164,6 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value return } -type typeInfo struct { - fieldMappings []fieldMapping -} - -type fieldMapping struct { - complexType bool - columnIndex int - implementsScanner bool -} - -func (s *scanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo { - - typeMapKey := structType.String() - - if parentField != nil { - typeMapKey += string(parentField.Tag) - } - - if typeInfo, ok := s.typeInfoMap[typeMapKey]; ok { - return typeInfo - } - - typeName := getTypeName(structType, parentField) - - newTypeInfo := typeInfo{} - - for i := 0; i < structType.NumField(); i++ { - field := structType.Field(i) - - newTypeName, fieldName := getTypeAndFieldName(typeName, field) - columnIndex := s.typeToColumnIndex(newTypeName, fieldName) - - fieldMap := fieldMapping{ - columnIndex: columnIndex, - } - - if implementsScannerType(field.Type) { - fieldMap.implementsScanner = true - } else if !isSimpleModelType(field.Type) { - fieldMap.complexType = true - } - - newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap) - } - - s.typeInfoMap[typeMapKey] = newTypeInfo - - return newTypeInfo -} - func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) { structType := structPtrValue.Type().Elem() @@ -330,553 +273,3 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re return } - -var scannerInterfaceType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() - -func implementsScannerType(fieldType reflect.Type) bool { - if fieldType.Implements(scannerInterfaceType) { - return true - } - - typePtr := reflect.New(fieldType).Type() - - return typePtr.Implements(scannerInterfaceType) -} - -func getScanner(value reflect.Value) sql.Scanner { - if scanner, ok := value.Interface().(sql.Scanner); ok { - return scanner - } - - return value.Addr().Interface().(sql.Scanner) -} - -func getSliceElemType(slicePtrValue reflect.Value) reflect.Type { - sliceTypePtr := slicePtrValue.Type() - - elemType := sliceTypePtr.Elem().Elem() - - if elemType.Kind() == reflect.Ptr { - return elemType.Elem() - } - - return elemType -} - -func getSliceElemPtrAt(slicePtrValue reflect.Value, index int) reflect.Value { - sliceValue := slicePtrValue.Elem() - elem := sliceValue.Index(index) - - if elem.Kind() == reflect.Ptr { - return elem - } - - return elem.Addr() -} - -func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) error { - if slicePtrValue.IsNil() { - panic("jet: internal, slice is nil") - } - sliceValue := slicePtrValue.Elem() - sliceElemType := sliceValue.Type().Elem() - - newElemValue := objPtrValue - - if sliceElemType.Kind() != reflect.Ptr { - newElemValue = objPtrValue.Elem() - } - - if newElemValue.Type().ConvertibleTo(sliceElemType) { - newElemValue = newElemValue.Convert(sliceElemType) - } - - if !newElemValue.Type().AssignableTo(sliceElemType) { - panic("jet: can't append " + newElemValue.Type().String() + " to " + sliceValue.Type().String() + " slice") - } - - sliceValue.Set(reflect.Append(sliceValue, newElemValue)) - - return nil -} - -func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value { - destinationSliceType := slicePtrValue.Type().Elem() - elemType := indirectType(destinationSliceType.Elem()) - - return reflect.New(elemType) -} - -func getTypeName(structType reflect.Type, parentField *reflect.StructField) string { - if parentField == nil { - return structType.Name() - } - - aliasTag := parentField.Tag.Get("alias") - - if aliasTag == "" { - return structType.Name() - } - - aliasParts := strings.Split(aliasTag, ".") - - return toCommonIdentifier(aliasParts[0]) -} - -func getTypeAndFieldName(structType string, field reflect.StructField) (string, string) { - aliasTag := field.Tag.Get("alias") - - if aliasTag == "" { - return structType, field.Name - } - - aliasParts := strings.Split(aliasTag, ".") - - if len(aliasParts) == 1 { - return structType, toCommonIdentifier(aliasParts[0]) - } - - return toCommonIdentifier(aliasParts[0]), toCommonIdentifier(aliasParts[1]) -} - -var replacer = strings.NewReplacer(" ", "", "-", "", "_", "") - -func toCommonIdentifier(name string) string { - return strings.ToLower(replacer.Replace(name)) -} - -func initializeValueIfNilPtr(value reflect.Value) { - - if !value.IsValid() || !value.CanSet() { - return - } - - if value.Kind() == reflect.Ptr && value.IsNil() { - value.Set(reflect.New(value.Type().Elem())) - } -} - -func valueToString(value reflect.Value) string { - - if !value.IsValid() { - return "nil" - } - - var valueInterface interface{} - if value.Kind() == reflect.Ptr { - if value.IsNil() { - return "nil" - } - valueInterface = value.Elem().Interface() - } else { - valueInterface = value.Interface() - } - - if t, ok := valueInterface.(fmt.Stringer); ok { - return t.String() - } - - return fmt.Sprintf("%#v", valueInterface) -} - -var timeType = reflect.TypeOf(time.Now()) -var uuidType = reflect.TypeOf(uuid.New()) - -func isSimpleModelType(objType reflect.Type) bool { - objType = indirectType(objType) - - switch objType.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, - reflect.Float32, reflect.Float64, - reflect.String, - reflect.Bool: - return true - case reflect.Slice: - return objType.Elem().Kind() == reflect.Uint8 //[]byte - case reflect.Struct: - return objType == timeType || objType == uuidType // time.Time || uuid.UUID - } - - return false -} - -func isIntegerType(value reflect.Type) bool { - switch value { - case int8Type, unit8Type, int16Type, uint16Type, - int32Type, uint32Type, int64Type, uint64Type: - return true - } - - return false -} - -func tryAssign(source, destination reflect.Value) bool { - if source.Type().ConvertibleTo(destination.Type()) { - source = source.Convert(destination.Type()) - } - - if isIntegerType(source.Type()) && destination.Type() == boolType { - intValue := source.Int() - - if intValue == 1 { - source = reflect.ValueOf(true) - } else if intValue == 0 { - source = reflect.ValueOf(false) - } - } - - if source.Type().AssignableTo(destination.Type()) { - destination.Set(source) - return true - } - - return false -} - -func setReflectValue(source, destination reflect.Value) { - - if tryAssign(source, destination) { - return - } - - if destination.Kind() == reflect.Ptr { - if source.Kind() == reflect.Ptr { - if !source.IsNil() { - if destination.IsNil() { - initializeValueIfNilPtr(destination) - } - - if tryAssign(source.Elem(), destination.Elem()) { - return - } - } else { - return - } - } else { - if source.CanAddr() { - source = source.Addr() - } else { - sourceCopy := reflect.New(source.Type()) - sourceCopy.Elem().Set(source) - - source = sourceCopy - } - - if tryAssign(source, destination) { - return - } - - if tryAssign(source.Elem(), destination.Elem()) { - return - } - } - } else { - if source.Kind() == reflect.Ptr { - if source.IsNil() { - return - } - source = source.Elem() - } - - if tryAssign(source, destination) { - return - } - } - - panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String()) -} - -func createScanValue(columnTypes []*sql.ColumnType) []interface{} { - values := make([]interface{}, len(columnTypes)) - - for i, sqlColumnType := range columnTypes { - columnType := newScanType(sqlColumnType) - - columnValue := reflect.New(columnType) - - values[i] = columnValue.Interface() - } - - return values -} - -var boolType = reflect.TypeOf(true) -var int8Type = reflect.TypeOf(int8(1)) -var unit8Type = reflect.TypeOf(uint8(1)) -var int16Type = reflect.TypeOf(int16(1)) -var uint16Type = reflect.TypeOf(uint16(1)) -var int32Type = reflect.TypeOf(int32(1)) -var uint32Type = reflect.TypeOf(uint32(1)) -var int64Type = reflect.TypeOf(int64(1)) -var uint64Type = reflect.TypeOf(uint64(1)) - -var nullBoolType = reflect.TypeOf(sql.NullBool{}) -var nullInt8Type = reflect.TypeOf(internal.NullInt8{}) -var nullInt16Type = reflect.TypeOf(internal.NullInt16{}) -var nullInt32Type = reflect.TypeOf(internal.NullInt32{}) -var nullInt64Type = reflect.TypeOf(sql.NullInt64{}) -var nullFloat32Type = reflect.TypeOf(internal.NullFloat32{}) -var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{}) -var nullStringType = reflect.TypeOf(sql.NullString{}) -var nullTimeType = reflect.TypeOf(internal.NullTime{}) -var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{}) - -func newScanType(columnType *sql.ColumnType) reflect.Type { - - switch columnType.DatabaseTypeName() { - case "TINYINT": - return nullInt8Type - case "INT2", "SMALLINT", "YEAR": - return nullInt16Type - case "INT4", "MEDIUMINT", "INT": - return nullInt32Type - case "INT8", "BIGINT": - return nullInt64Type - case "CHAR", "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML": - return nullStringType - case "FLOAT4": - return nullFloat32Type - case "FLOAT8", "NUMERIC", "DECIMAL", "FLOAT", "DOUBLE": - return nullFloat64Type - case "BOOL": - return nullBoolType - case "BYTEA", "BINARY", "VARBINARY", "BLOB": - return nullByteArrayType - case "DATE", "DATETIME", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ": - return nullTimeType - default: - return nullStringType - } -} - -type scanContext struct { - rowNum int - - row []interface{} - uniqueDestObjectsMap map[string]int - - typeToColumnIndexMap map[string]int - - groupKeyInfoCache map[string]groupKeyInfo - - typeInfoMap map[string]typeInfo -} - -func newScanContext(rows *sql.Rows) (*scanContext, error) { - aliases, err := rows.Columns() - - if err != nil { - return nil, err - } - - columnTypes, err := rows.ColumnTypes() - - if err != nil { - return nil, err - } - - typeToIndexMap := map[string]int{} - - for i, alias := range aliases { - names := strings.SplitN(alias, ".", 2) - - goName := toCommonIdentifier(names[0]) - - if len(names) > 1 { - goName += "." + toCommonIdentifier(names[1]) - } - - typeToIndexMap[strings.ToLower(goName)] = i - } - - return &scanContext{ - row: createScanValue(columnTypes), - uniqueDestObjectsMap: make(map[string]int), - - groupKeyInfoCache: make(map[string]groupKeyInfo), - typeToColumnIndexMap: typeToIndexMap, - - typeInfoMap: make(map[string]typeInfo), - }, nil -} - -type groupKeyInfo struct { - typeName string - indexes []int - subTypes []groupKeyInfo -} - -func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect.StructField) string { - - mapKey := structType.Name() - - if structField != nil { - mapKey += structField.Type.String() - } - - if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok { - return s.constructGroupKey(groupKeyInfo) - } - - groupKeyInfo := s.getGroupKeyInfo(structType, structField) - - s.groupKeyInfoCache[mapKey] = groupKeyInfo - - return s.constructGroupKey(groupKeyInfo) -} - -func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { - if len(groupKeyInfo.indexes) == 0 && len(groupKeyInfo.subTypes) == 0 { - return "|ROW: " + strconv.Itoa(s.rowNum) + "|" - } - - groupKeys := []string{} - - for _, index := range groupKeyInfo.indexes { - cellValue := s.rowElem(index) - subKey := valueToString(reflect.ValueOf(cellValue)) - - groupKeys = append(groupKeys, subKey) - } - - subTypesGroupKeys := []string{} - for _, subType := range groupKeyInfo.subTypes { - subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType)) - } - - return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")" -} - -func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo { - ret := groupKeyInfo{typeName: structType.Name()} - - typeName := getTypeName(structType, parentField) - primaryKeyOverwrites := parentFieldPrimaryKeyOverwrite(parentField) - - for i := 0; i < structType.NumField(); i++ { - field := structType.Field(i) - fieldType := indirectType(field.Type) - - if !isSimpleModelType(fieldType) { - if fieldType.Kind() != reflect.Struct { - continue - } - - subType := s.getGroupKeyInfo(fieldType, &field) - - if len(subType.indexes) != 0 || len(subType.subTypes) != 0 { - ret.subTypes = append(ret.subTypes, subType) - } - } else { - if isPrimaryKey(field, primaryKeyOverwrites) { - newTypeName, fieldName := getTypeAndFieldName(typeName, field) - - index := s.typeToColumnIndex(newTypeName, fieldName) - - if index < 0 { - continue - } - - ret.indexes = append(ret.indexes, index) - } - } - } - - return ret -} - -func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int { - var key string - - if typeName != "" { - key = strings.ToLower(typeName + "." + fieldName) - } else { - key = strings.ToLower(fieldName) - } - - index, ok := s.typeToColumnIndexMap[key] - - if !ok { - return -1 - } - - return index -} - -func (s *scanContext) rowElem(index int) interface{} { - - valuer, ok := s.row[index].(driver.Valuer) - - if !ok { - panic("jet: internal error, scan value doesn't implement driver.Valuer") - } - - value, err := valuer.Value() - - utils.PanicOnError(err) - - return value -} - -func (s *scanContext) rowElemValuePtr(index int) reflect.Value { - rowElem := s.rowElem(index) - rowElemValue := reflect.ValueOf(rowElem) - - if rowElemValue.Kind() == reflect.Ptr { - return rowElemValue - } - - if rowElemValue.CanAddr() { - return rowElemValue.Addr() - } - - newElem := reflect.New(rowElemValue.Type()) - newElem.Elem().Set(rowElemValue) - return newElem -} - -func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool { - if len(primaryKeyOverwrites) > 0 { - return utils.StringSliceContains(primaryKeyOverwrites, field.Name) - } - - sqlTag := field.Tag.Get("sql") - - return sqlTag == "primary_key" -} - -func parentFieldPrimaryKeyOverwrite(parentField *reflect.StructField) []string { - if parentField == nil { - return nil - } - - sqlTag := parentField.Tag.Get("sql") - - if !strings.HasPrefix(sqlTag, "primary_key") { - return nil - } - - parts := strings.Split(sqlTag, "=") - - if len(parts) < 2 { - return nil - } - - return strings.Split(parts[1], ",") -} - -func indirectType(reflectType reflect.Type) reflect.Type { - if reflectType.Kind() != reflect.Ptr { - return reflectType - } - return reflectType.Elem() -} - -func fieldToString(field *reflect.StructField) string { - if field == nil { - return "" - } - - return " at '" + field.Name + " " + field.Type.String() + "'" -} diff --git a/qrm/scan_context.go b/qrm/scan_context.go new file mode 100644 index 0000000..7db2c95 --- /dev/null +++ b/qrm/scan_context.go @@ -0,0 +1,246 @@ +package qrm + +import ( + "database/sql" + "database/sql/driver" + "github.com/go-jet/jet/internal/utils" + "reflect" + "strconv" + "strings" +) + +type scanContext struct { + rowNum int + + row []interface{} + uniqueDestObjectsMap map[string]int + + typeToColumnIndexMap map[string]int + + groupKeyInfoCache map[string]groupKeyInfo + + typeInfoMap map[string]typeInfo +} + +func newScanContext(rows *sql.Rows) (*scanContext, error) { + aliases, err := rows.Columns() + + if err != nil { + return nil, err + } + + columnTypes, err := rows.ColumnTypes() + + if err != nil { + return nil, err + } + + typeToIndexMap := map[string]int{} + + for i, alias := range aliases { + names := strings.SplitN(alias, ".", 2) + + goName := toCommonIdentifier(names[0]) + + if len(names) > 1 { + goName += "." + toCommonIdentifier(names[1]) + } + + typeToIndexMap[strings.ToLower(goName)] = i + } + + return &scanContext{ + row: createScanValue(columnTypes), + uniqueDestObjectsMap: make(map[string]int), + + groupKeyInfoCache: make(map[string]groupKeyInfo), + typeToColumnIndexMap: typeToIndexMap, + + typeInfoMap: make(map[string]typeInfo), + }, nil +} + +type typeInfo struct { + fieldMappings []fieldMapping +} + +type fieldMapping struct { + complexType bool + columnIndex int + implementsScanner bool +} + +func (s *scanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo { + + typeMapKey := structType.String() + + if parentField != nil { + typeMapKey += string(parentField.Tag) + } + + if typeInfo, ok := s.typeInfoMap[typeMapKey]; ok { + return typeInfo + } + + typeName := getTypeName(structType, parentField) + + newTypeInfo := typeInfo{} + + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + + newTypeName, fieldName := getTypeAndFieldName(typeName, field) + columnIndex := s.typeToColumnIndex(newTypeName, fieldName) + + fieldMap := fieldMapping{ + columnIndex: columnIndex, + } + + if implementsScannerType(field.Type) { + fieldMap.implementsScanner = true + } else if !isSimpleModelType(field.Type) { + fieldMap.complexType = true + } + + newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap) + } + + s.typeInfoMap[typeMapKey] = newTypeInfo + + return newTypeInfo +} + +type groupKeyInfo struct { + typeName string + indexes []int + subTypes []groupKeyInfo +} + +func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect.StructField) string { + + mapKey := structType.Name() + + if structField != nil { + mapKey += structField.Type.String() + } + + if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok { + return s.constructGroupKey(groupKeyInfo) + } + + groupKeyInfo := s.getGroupKeyInfo(structType, structField) + + s.groupKeyInfoCache[mapKey] = groupKeyInfo + + return s.constructGroupKey(groupKeyInfo) +} + +func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { + if len(groupKeyInfo.indexes) == 0 && len(groupKeyInfo.subTypes) == 0 { + return "|ROW: " + strconv.Itoa(s.rowNum) + "|" + } + + groupKeys := []string{} + + for _, index := range groupKeyInfo.indexes { + cellValue := s.rowElem(index) + subKey := valueToString(reflect.ValueOf(cellValue)) + + groupKeys = append(groupKeys, subKey) + } + + subTypesGroupKeys := []string{} + for _, subType := range groupKeyInfo.subTypes { + subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType)) + } + + return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")" +} + +func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo { + ret := groupKeyInfo{typeName: structType.Name()} + + typeName := getTypeName(structType, parentField) + primaryKeyOverwrites := parentFieldPrimaryKeyOverwrite(parentField) + + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + fieldType := indirectType(field.Type) + + if !isSimpleModelType(fieldType) { + if fieldType.Kind() != reflect.Struct { + continue + } + + subType := s.getGroupKeyInfo(fieldType, &field) + + if len(subType.indexes) != 0 || len(subType.subTypes) != 0 { + ret.subTypes = append(ret.subTypes, subType) + } + } else { + if isPrimaryKey(field, primaryKeyOverwrites) { + newTypeName, fieldName := getTypeAndFieldName(typeName, field) + + index := s.typeToColumnIndex(newTypeName, fieldName) + + if index < 0 { + continue + } + + ret.indexes = append(ret.indexes, index) + } + } + } + + return ret +} + +func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int { + var key string + + if typeName != "" { + key = strings.ToLower(typeName + "." + fieldName) + } else { + key = strings.ToLower(fieldName) + } + + index, ok := s.typeToColumnIndexMap[key] + + if !ok { + return -1 + } + + return index +} + +func (s *scanContext) rowElem(index int) interface{} { + + valuer, ok := s.row[index].(driver.Valuer) + + if !ok { + panic("jet: internal error, scan value doesn't implement driver.Valuer") + } + + value, err := valuer.Value() + + utils.PanicOnError(err) + + return value +} + +func (s *scanContext) rowElemValuePtr(index int) reflect.Value { + rowElem := s.rowElem(index) + rowElemValue := reflect.ValueOf(rowElem) + + if rowElemValue.Kind() == reflect.Ptr { + return rowElemValue + } + + if rowElemValue.CanAddr() { + return rowElemValue.Addr() + } + + newElem := reflect.New(rowElemValue.Type()) + newElem.Elem().Set(rowElemValue) + return newElem +} diff --git a/qrm/utill.go b/qrm/utill.go new file mode 100644 index 0000000..e23dfac --- /dev/null +++ b/qrm/utill.go @@ -0,0 +1,376 @@ +package qrm + +import ( + "database/sql" + "fmt" + "github.com/go-jet/jet/internal/utils" + "github.com/go-jet/jet/qrm/internal" + "github.com/google/uuid" + "reflect" + "strings" + "time" +) + +var scannerInterfaceType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +func implementsScannerType(fieldType reflect.Type) bool { + if fieldType.Implements(scannerInterfaceType) { + return true + } + + typePtr := reflect.New(fieldType).Type() + + return typePtr.Implements(scannerInterfaceType) +} + +func getScanner(value reflect.Value) sql.Scanner { + if scanner, ok := value.Interface().(sql.Scanner); ok { + return scanner + } + + return value.Addr().Interface().(sql.Scanner) +} + +func getSliceElemType(slicePtrValue reflect.Value) reflect.Type { + sliceTypePtr := slicePtrValue.Type() + + elemType := sliceTypePtr.Elem().Elem() + + if elemType.Kind() == reflect.Ptr { + return elemType.Elem() + } + + return elemType +} + +func getSliceElemPtrAt(slicePtrValue reflect.Value, index int) reflect.Value { + sliceValue := slicePtrValue.Elem() + elem := sliceValue.Index(index) + + if elem.Kind() == reflect.Ptr { + return elem + } + + return elem.Addr() +} + +func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) error { + if slicePtrValue.IsNil() { + panic("jet: internal, slice is nil") + } + sliceValue := slicePtrValue.Elem() + sliceElemType := sliceValue.Type().Elem() + + newElemValue := objPtrValue + + if sliceElemType.Kind() != reflect.Ptr { + newElemValue = objPtrValue.Elem() + } + + if newElemValue.Type().ConvertibleTo(sliceElemType) { + newElemValue = newElemValue.Convert(sliceElemType) + } + + if !newElemValue.Type().AssignableTo(sliceElemType) { + panic("jet: can't append " + newElemValue.Type().String() + " to " + sliceValue.Type().String() + " slice") + } + + sliceValue.Set(reflect.Append(sliceValue, newElemValue)) + + return nil +} + +func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value { + destinationSliceType := slicePtrValue.Type().Elem() + elemType := indirectType(destinationSliceType.Elem()) + + return reflect.New(elemType) +} + +func getTypeName(structType reflect.Type, parentField *reflect.StructField) string { + if parentField == nil { + return structType.Name() + } + + aliasTag := parentField.Tag.Get("alias") + + if aliasTag == "" { + return structType.Name() + } + + aliasParts := strings.Split(aliasTag, ".") + + return toCommonIdentifier(aliasParts[0]) +} + +func getTypeAndFieldName(structType string, field reflect.StructField) (string, string) { + aliasTag := field.Tag.Get("alias") + + if aliasTag == "" { + return structType, field.Name + } + + aliasParts := strings.Split(aliasTag, ".") + + if len(aliasParts) == 1 { + return structType, toCommonIdentifier(aliasParts[0]) + } + + return toCommonIdentifier(aliasParts[0]), toCommonIdentifier(aliasParts[1]) +} + +var replacer = strings.NewReplacer(" ", "", "-", "", "_", "") + +func toCommonIdentifier(name string) string { + return strings.ToLower(replacer.Replace(name)) +} + +func initializeValueIfNilPtr(value reflect.Value) { + + if !value.IsValid() || !value.CanSet() { + return + } + + if value.Kind() == reflect.Ptr && value.IsNil() { + value.Set(reflect.New(value.Type().Elem())) + } +} + +func valueToString(value reflect.Value) string { + + if !value.IsValid() { + return "nil" + } + + var valueInterface interface{} + if value.Kind() == reflect.Ptr { + if value.IsNil() { + return "nil" + } + valueInterface = value.Elem().Interface() + } else { + valueInterface = value.Interface() + } + + if t, ok := valueInterface.(fmt.Stringer); ok { + return t.String() + } + + return fmt.Sprintf("%#v", valueInterface) +} + +var timeType = reflect.TypeOf(time.Now()) +var uuidType = reflect.TypeOf(uuid.New()) + +func isSimpleModelType(objType reflect.Type) bool { + objType = indirectType(objType) + + switch objType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, + reflect.String, + reflect.Bool: + return true + case reflect.Slice: + return objType.Elem().Kind() == reflect.Uint8 //[]byte + case reflect.Struct: + return objType == timeType || objType == uuidType // time.Time || uuid.UUID + } + + return false +} + +func isIntegerType(value reflect.Type) bool { + switch value { + case int8Type, unit8Type, int16Type, uint16Type, + int32Type, uint32Type, int64Type, uint64Type: + return true + } + + return false +} + +func tryAssign(source, destination reflect.Value) bool { + if source.Type().ConvertibleTo(destination.Type()) { + source = source.Convert(destination.Type()) + } + + if isIntegerType(source.Type()) && destination.Type() == boolType { + intValue := source.Int() + + if intValue == 1 { + source = reflect.ValueOf(true) + } else if intValue == 0 { + source = reflect.ValueOf(false) + } + } + + if source.Type().AssignableTo(destination.Type()) { + destination.Set(source) + return true + } + + return false +} + +func setReflectValue(source, destination reflect.Value) { + + if tryAssign(source, destination) { + return + } + + if destination.Kind() == reflect.Ptr { + if source.Kind() == reflect.Ptr { + if !source.IsNil() { + if destination.IsNil() { + initializeValueIfNilPtr(destination) + } + + if tryAssign(source.Elem(), destination.Elem()) { + return + } + } else { + return + } + } else { + if source.CanAddr() { + source = source.Addr() + } else { + sourceCopy := reflect.New(source.Type()) + sourceCopy.Elem().Set(source) + + source = sourceCopy + } + + if tryAssign(source, destination) { + return + } + + if tryAssign(source.Elem(), destination.Elem()) { + return + } + } + } else { + if source.Kind() == reflect.Ptr { + if source.IsNil() { + return + } + source = source.Elem() + } + + if tryAssign(source, destination) { + return + } + } + + panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String()) +} + +func createScanValue(columnTypes []*sql.ColumnType) []interface{} { + values := make([]interface{}, len(columnTypes)) + + for i, sqlColumnType := range columnTypes { + columnType := newScanType(sqlColumnType) + + columnValue := reflect.New(columnType) + + values[i] = columnValue.Interface() + } + + return values +} + +var boolType = reflect.TypeOf(true) +var int8Type = reflect.TypeOf(int8(1)) +var unit8Type = reflect.TypeOf(uint8(1)) +var int16Type = reflect.TypeOf(int16(1)) +var uint16Type = reflect.TypeOf(uint16(1)) +var int32Type = reflect.TypeOf(int32(1)) +var uint32Type = reflect.TypeOf(uint32(1)) +var int64Type = reflect.TypeOf(int64(1)) +var uint64Type = reflect.TypeOf(uint64(1)) + +var nullBoolType = reflect.TypeOf(sql.NullBool{}) +var nullInt8Type = reflect.TypeOf(internal.NullInt8{}) +var nullInt16Type = reflect.TypeOf(internal.NullInt16{}) +var nullInt32Type = reflect.TypeOf(internal.NullInt32{}) +var nullInt64Type = reflect.TypeOf(sql.NullInt64{}) +var nullFloat32Type = reflect.TypeOf(internal.NullFloat32{}) +var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{}) +var nullStringType = reflect.TypeOf(sql.NullString{}) +var nullTimeType = reflect.TypeOf(internal.NullTime{}) +var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{}) + +func newScanType(columnType *sql.ColumnType) reflect.Type { + + switch columnType.DatabaseTypeName() { + case "TINYINT": + return nullInt8Type + case "INT2", "SMALLINT", "YEAR": + return nullInt16Type + case "INT4", "MEDIUMINT", "INT": + return nullInt32Type + case "INT8", "BIGINT": + return nullInt64Type + case "CHAR", "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML": + return nullStringType + case "FLOAT4": + return nullFloat32Type + case "FLOAT8", "NUMERIC", "DECIMAL", "FLOAT", "DOUBLE": + return nullFloat64Type + case "BOOL": + return nullBoolType + case "BYTEA", "BINARY", "VARBINARY", "BLOB": + return nullByteArrayType + case "DATE", "DATETIME", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ": + return nullTimeType + default: + return nullStringType + } +} + +func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool { + if len(primaryKeyOverwrites) > 0 { + return utils.StringSliceContains(primaryKeyOverwrites, field.Name) + } + + sqlTag := field.Tag.Get("sql") + + return sqlTag == "primary_key" +} + +func parentFieldPrimaryKeyOverwrite(parentField *reflect.StructField) []string { + if parentField == nil { + return nil + } + + sqlTag := parentField.Tag.Get("sql") + + if !strings.HasPrefix(sqlTag, "primary_key") { + return nil + } + + parts := strings.Split(sqlTag, "=") + + if len(parts) < 2 { + return nil + } + + return strings.Split(parts[1], ",") +} + +func indirectType(reflectType reflect.Type) reflect.Type { + if reflectType.Kind() != reflect.Ptr { + return reflectType + } + return reflectType.Elem() +} + +func fieldToString(field *reflect.StructField) string { + if field == nil { + return "" + } + + return " at '" + field.Name + " " + field.Type.String() + "'" +} From f9efee77ff5c0220ef21227ad1829aba4870c159 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 12 Oct 2019 18:45:09 +0200 Subject: [PATCH 3/7] QRM returns sql.ErrNoRows when scanning into struct destination and query result set is empty. --- internal/jet/statement.go | 6 +- qrm/qrm.go | 98 +++++++++++++++++---------------- qrm/scan_context.go | 37 ++++++------- qrm/utill.go | 9 +-- tests/postgres/alltypes_test.go | 2 +- tests/postgres/scan_test.go | 30 ++++++++++ tests/postgres/select_test.go | 11 ---- 7 files changed, 103 insertions(+), 90 deletions(-) diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 2bd0929..ab0c655 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -15,10 +15,12 @@ type Statement interface { DebugSql() (query string) // Query executes statement over database connection db and stores row result in destination. - // Destination can be arbitrary structure + // Destination can be either pointer to struct or pointer to a slice. + // If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows. Query(db qrm.DB, destination interface{}) error // QueryContext executes statement with a context over database connection db and stores row result in destination. - // Destination can be of arbitrary structure + // Destination can be either pointer to struct or pointer to a slice. + // If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows. QueryContext(context context.Context, db qrm.DB, destination interface{}) error //Exec executes statement over db connection without returning any rows. diff --git a/qrm/qrm.go b/qrm/qrm.go index 2b25ecc..8f79c57 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -10,6 +10,7 @@ import ( // Query executes Query Result Mapping (QRM) of `query` with list of parametrized arguments `arg` over database connection `db` // using context `ctx` into destination `destPtr`. // Destination can be either pointer to struct or pointer to slice of structs. +// If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows. func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) error { utils.MustBeInitializedPtr(db, "jet: db is nil") @@ -19,21 +20,27 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr destinationPtrType := reflect.TypeOf(destPtr) if destinationPtrType.Elem().Kind() == reflect.Slice { - return queryToSlice(ctx, db, query, args, destPtr) + _, err := queryToSlice(ctx, db, query, args, destPtr) + return err } else if destinationPtrType.Elem().Kind() == reflect.Struct { tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSliceValue := tempSlicePtrValue.Elem() - err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface()) + rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface()) if err != nil { return err } - if tempSliceValue.Len() == 0 { + if rowsProcessed == 0 { return sql.ErrNoRows } + // edge case when row result set contains only NULLs. + if tempSliceValue.Len() == 0 { + return nil + } + structValue := reflect.ValueOf(destPtr).Elem() firstTempStruct := tempSliceValue.Index(0).Elem() @@ -46,7 +53,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr } } -func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) error { +func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) (rowsProcessed int64, err error) { if ctx == nil { ctx = context.Background() } @@ -54,27 +61,27 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, rows, err := db.QueryContext(ctx, query, args...) if err != nil { - return err + return } defer rows.Close() scanContext, err := newScanContext(rows) if err != nil { - return err + return } if len(scanContext.row) == 0 { - return nil + return } slicePtrValue := reflect.ValueOf(slicePtr) for rows.Next() { - err := rows.Scan(scanContext.row...) + err = rows.Scan(scanContext.row...) if err != nil { - return err + return } scanContext.rowNum++ @@ -82,22 +89,24 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, _, err = mapRowToSlice(scanContext, "", slicePtrValue, nil) if err != nil { - return err + return } } err = rows.Close() if err != nil { - return err + return } err = rows.Err() if err != nil { - return err + return } - return nil + rowsProcessed = scanContext.rowNum + + return } func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) { @@ -165,6 +174,7 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value } func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) { + mapOnlySlices := len(onlySlices) > 0 structType := structPtrValue.Type().Elem() typeInf := scanContext.getTypeInfo(structType, parentField) @@ -193,22 +203,21 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re updated = true } - } else if len(onlySlices) == 0 { - - if fieldMap.columnIndex == -1 { + } else { + if mapOnlySlices || fieldMap.columnIndex == -1 { continue } + cellValue := scanContext.rowElem(fieldMap.columnIndex) + + if cellValue == nil { + continue + } + + initializeValueIfNilPtr(fieldValue) + updated = true + if fieldMap.implementsScanner { - - cellValue := scanContext.rowElem(fieldMap.columnIndex) - - if cellValue == nil { - continue - } - - initializeValueIfNilPtr(fieldValue) - scanner := getScanner(fieldValue) err = scanner.Scan(cellValue) @@ -216,15 +225,8 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re if err != nil { panic("jet: " + err.Error() + ", " + fieldToString(&field) + " of type " + structType.String()) } - updated = true } else { - cellValue := scanContext.rowElem(fieldMap.columnIndex) - - if cellValue != nil { - updated = true - initializeValueIfNilPtr(fieldValue) - setReflectValue(reflect.ValueOf(cellValue), fieldValue) - } + setReflectValue(reflect.ValueOf(cellValue), fieldValue) } } } @@ -232,21 +234,6 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re return } -func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { - - utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.") - - destValueKind := destPtrValue.Elem().Kind() - - if destValueKind == reflect.Struct { - return mapRowToStruct(scanContext, groupKey, destPtrValue, structField) - } else if destValueKind == reflect.Slice { - return mapRowToSlice(scanContext, groupKey, destPtrValue, structField) - } else { - panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String()) - } -} - func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest reflect.Value, structField *reflect.StructField) (updated bool, err error) { var destPtrValue reflect.Value @@ -273,3 +260,18 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re return } + +func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { + + utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.") + + destValueKind := destPtrValue.Elem().Kind() + + if destValueKind == reflect.Struct { + return mapRowToStruct(scanContext, groupKey, destPtrValue, structField) + } else if destValueKind == reflect.Slice { + return mapRowToSlice(scanContext, groupKey, destPtrValue, structField) + } else { + panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String()) + } +} diff --git a/qrm/scan_context.go b/qrm/scan_context.go index 7db2c95..8141ed7 100644 --- a/qrm/scan_context.go +++ b/qrm/scan_context.go @@ -3,23 +3,19 @@ package qrm import ( "database/sql" "database/sql/driver" + "fmt" "github.com/go-jet/jet/internal/utils" "reflect" - "strconv" "strings" ) type scanContext struct { - rowNum int - - row []interface{} - uniqueDestObjectsMap map[string]int - - typeToColumnIndexMap map[string]int - - groupKeyInfoCache map[string]groupKeyInfo - - typeInfoMap map[string]typeInfo + rowNum int64 + row []interface{} + uniqueDestObjectsMap map[string]int + commonIdentToColumnIndex map[string]int + groupKeyInfoCache map[string]groupKeyInfo + typeInfoMap map[string]typeInfo } func newScanContext(rows *sql.Rows) (*scanContext, error) { @@ -35,26 +31,25 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) { return nil, err } - typeToIndexMap := map[string]int{} + commonIdentToColumnIndex := map[string]int{} for i, alias := range aliases { names := strings.SplitN(alias, ".", 2) - - goName := toCommonIdentifier(names[0]) + commonIdentifier := toCommonIdentifier(names[0]) if len(names) > 1 { - goName += "." + toCommonIdentifier(names[1]) + commonIdentifier += "." + toCommonIdentifier(names[1]) } - typeToIndexMap[strings.ToLower(goName)] = i + commonIdentToColumnIndex[commonIdentifier] = i } return &scanContext{ row: createScanValue(columnTypes), uniqueDestObjectsMap: make(map[string]int), - groupKeyInfoCache: make(map[string]groupKeyInfo), - typeToColumnIndexMap: typeToIndexMap, + groupKeyInfoCache: make(map[string]groupKeyInfo), + commonIdentToColumnIndex: commonIdentToColumnIndex, typeInfoMap: make(map[string]typeInfo), }, nil @@ -65,7 +60,7 @@ type typeInfo struct { } type fieldMapping struct { - complexType bool + complexType bool // slice or struct columnIndex int implementsScanner bool } @@ -137,7 +132,7 @@ func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect. func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { if len(groupKeyInfo.indexes) == 0 && len(groupKeyInfo.subTypes) == 0 { - return "|ROW: " + strconv.Itoa(s.rowNum) + "|" + return fmt.Sprintf("|ROW:%d|", s.rowNum) } groupKeys := []string{} @@ -204,7 +199,7 @@ func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int { key = strings.ToLower(fieldName) } - index, ok := s.typeToColumnIndexMap[key] + index, ok := s.commonIdentToColumnIndex[key] if !ok { return -1 diff --git a/qrm/utill.go b/qrm/utill.go index e23dfac..7b6edaa 100644 --- a/qrm/utill.go +++ b/qrm/utill.go @@ -33,14 +33,9 @@ func getScanner(value reflect.Value) sql.Scanner { func getSliceElemType(slicePtrValue reflect.Value) reflect.Type { sliceTypePtr := slicePtrValue.Type() + elemType := indirectType(sliceTypePtr).Elem() - elemType := sliceTypePtr.Elem().Elem() - - if elemType.Kind() == reflect.Ptr { - return elemType.Elem() - } - - return elemType + return indirectType(elemType) } func getSliceElemPtrAt(slicePtrValue reflect.Value, index int) reflect.Value { diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 9af0394..f2e8fbe 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -134,7 +134,7 @@ LIMIT $5; func TestExpressionCast(t *testing.T) { query := AllTypes.SELECT( - postgres.CAST(Int(150)).AS_CHAR(12), + postgres.CAST(Int(150)).AS_CHAR(12).AS("char12"), postgres.CAST(String("TRUE")).AS_BOOL(), postgres.CAST(String("111")).AS_SMALLINT(), postgres.CAST(String("111")).AS_INTEGER(), diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index d3ec061..f42dbd5 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -1,6 +1,7 @@ package postgres import ( + "database/sql" "fmt" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" @@ -694,6 +695,35 @@ func TestScanToSlice(t *testing.T) { }) } +func TestStructScanErrNoRows(t *testing.T) { + query := SELECT(Customer.AllColumns). + FROM(Customer). + WHERE(Customer.CustomerID.EQ(Int(-1))) + + customer := model.Customer{} + + err := query.Query(db, &customer) + + assert.Error(t, err, sql.ErrNoRows.Error()) +} + +func TestStructScanAllNull(t *testing.T) { + query := SELECT(NULL.AS("null1"), NULL.AS("null2")) + + dest := struct { + Null1 *int + Null2 *int + }{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, dest, struct { + Null1 *int + Null2 *int + }{}) +} + var address256 = model.Address{ AddressID: 256, Address: "1497 Yuzhou Drive", diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index afb7832..3cd2fba 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1,7 +1,6 @@ package postgres import ( - "database/sql" "fmt" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" @@ -1798,13 +1797,3 @@ func TestJoinViewWithTable(t *testing.T) { assert.Equal(t, len(dest[0].Rentals), 32) assert.Equal(t, len(dest[1].Rentals), 27) } - -func TestErrNoRows(t *testing.T) { - query := SELECT(Customer.AllColumns).FROM(Customer).WHERE(Customer.CustomerID.EQ(Int(-1))) - - customer := model.Customer{} - - err := query.Query(db, &customer) - - assert.Error(t, err, sql.ErrNoRows.Error()) -} From 9a3f12ea5fbd97271de30a425d5abf767c03f401 Mon Sep 17 00:00:00 2001 From: Christian King Date: Thu, 17 Oct 2019 13:27:18 -0400 Subject: [PATCH 4/7] Fix issue with using UUID primary keys in complex return types When using a UUID as a primary key with PostgreSQL the grouping was defaulting to the row which caused incorrect results to be returned. --- qrm/utill.go | 4 +- tests/postgres/sample_test.go | 148 ++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 1 deletion(-) diff --git a/qrm/utill.go b/qrm/utill.go index 7b6edaa..baee26e 100644 --- a/qrm/utill.go +++ b/qrm/utill.go @@ -170,7 +170,9 @@ func isSimpleModelType(objType reflect.Type) bool { case reflect.Slice: return objType.Elem().Kind() == reflect.Uint8 //[]byte case reflect.Struct: - return objType == timeType || objType == uuidType // time.Time || uuid.UUID + return objType == timeType + case reflect.Array: + return objType == uuidType // uuid.UUID returns reflect.Array kind } return false diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 32502a5..41c62f8 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -30,6 +30,154 @@ WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'; assert.DeepEqual(t, result.UUIDPtr, UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) } +func TestUUIDComplex(t *testing.T) { + query := Person.INNER_JOIN(PersonPhone, PersonPhone.PersonID.EQ(Person.PersonID)). + SELECT(Person.AllColumns, PersonPhone.AllColumns). + ORDER_BY(Person.PersonID.ASC(), PersonPhone.PhoneID.ASC()) + + t.Run("slice of structs", func(t *testing.T) { + + var dest []struct { + model.Person + Phones []struct { + model.PersonPhone + } + } + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 2) + testutils.AssertJSON(t, dest, ` +[ + { + "PersonID": "b68dbff4-a87d-11e9-a7f2-98ded00c39c6", + "FirstName": "Sad", + "LastName": "John", + "Mood": "sad", + "Phones": [ + { + "PhoneID": "02b61cc4-d500-4847-bd36-111eccbc7a51", + "PhoneNumber": "212-555-1211", + "PersonID": "b68dbff4-a87d-11e9-a7f2-98ded00c39c6" + } + ] + }, + { + "PersonID": "b68dbff6-a87d-11e9-a7f2-98ded00c39c8", + "FirstName": "Ok", + "LastName": "John", + "Mood": "ok", + "Phones": [ + { + "PhoneID": "02b61cc4-d500-4847-bd36-111eccbc7a52", + "PhoneNumber": "212-555-1212", + "PersonID": "b68dbff6-a87d-11e9-a7f2-98ded00c39c8" + }, + { + "PhoneID": "02b61cc4-d500-4847-bd36-111eccbc7a53", + "PhoneNumber": "212-555-1213", + "PersonID": "b68dbff6-a87d-11e9-a7f2-98ded00c39c8" + } + ] + } +] +`) + + }) + + t.Run("single struct", func(t *testing.T) { + singleQuery := query.WHERE(Person.PersonID.EQ(String("b68dbff6-a87d-11e9-a7f2-98ded00c39c8"))) + + var dest struct { + model.Person + Phones []struct { + model.PersonPhone + } + } + err := singleQuery.Query(db, &dest) + assert.NilError(t, err) + + testutils.AssertJSON(t, dest, ` +{ + "PersonID": "b68dbff6-a87d-11e9-a7f2-98ded00c39c8", + "FirstName": "Ok", + "LastName": "John", + "Mood": "ok", + "Phones": [ + { + "PhoneID": "02b61cc4-d500-4847-bd36-111eccbc7a52", + "PhoneNumber": "212-555-1212", + "PersonID": "b68dbff6-a87d-11e9-a7f2-98ded00c39c8" + }, + { + "PhoneID": "02b61cc4-d500-4847-bd36-111eccbc7a53", + "PhoneNumber": "212-555-1213", + "PersonID": "b68dbff6-a87d-11e9-a7f2-98ded00c39c8" + } + ] +} +`) + }) + + t.Run("slice of structs left join", func(t *testing.T) { + leftQuery := Person.LEFT_JOIN(PersonPhone, PersonPhone.PersonID.EQ(Person.PersonID)). + SELECT(Person.AllColumns, PersonPhone.AllColumns). + ORDER_BY(Person.PersonID.ASC(), PersonPhone.PhoneID.ASC()) + var dest []struct { + model.Person + Phones []struct { + model.PersonPhone + } + } + err := leftQuery.Query(db, &dest) + + assert.NilError(t, err) + testutils.AssertJSON(t, dest, ` +[ + { + "PersonID": "b68dbff4-a87d-11e9-a7f2-98ded00c39c6", + "FirstName": "Sad", + "LastName": "John", + "Mood": "sad", + "Phones": [ + { + "PhoneID": "02b61cc4-d500-4847-bd36-111eccbc7a51", + "PhoneNumber": "212-555-1211", + "PersonID": "b68dbff4-a87d-11e9-a7f2-98ded00c39c6" + } + ] + }, + { + "PersonID": "b68dbff5-a87d-11e9-a7f2-98ded00c39c7", + "FirstName": "Ok", + "LastName": "John", + "Mood": "ok", + "Phones": null + }, + { + "PersonID": "b68dbff6-a87d-11e9-a7f2-98ded00c39c8", + "FirstName": "Ok", + "LastName": "John", + "Mood": "ok", + "Phones": [ + { + "PhoneID": "02b61cc4-d500-4847-bd36-111eccbc7a52", + "PhoneNumber": "212-555-1212", + "PersonID": "b68dbff6-a87d-11e9-a7f2-98ded00c39c8" + }, + { + "PhoneID": "02b61cc4-d500-4847-bd36-111eccbc7a53", + "PhoneNumber": "212-555-1213", + "PersonID": "b68dbff6-a87d-11e9-a7f2-98ded00c39c8" + } + ] + } +] +`) + }) + +} func TestEnumType(t *testing.T) { query := Person. SELECT(Person.AllColumns) From 53a76f31b4f290f7a4d6f587723ae6817afda0de Mon Sep 17 00:00:00 2001 From: go-jet Date: Fri, 18 Oct 2019 09:56:38 +0200 Subject: [PATCH 5/7] Additional qrm tests. --- .circleci/config.yml | 2 +- qrm/utill.go | 9 ++------- qrm/utill_test.go | 35 +++++++++++++++++++++++++++++++++++ tests/testdata | 2 +- 4 files changed, 39 insertions(+), 9 deletions(-) create mode 100644 qrm/utill_test.go diff --git a/.circleci/config.yml b/.circleci/config.yml index 6c73d31..d3ea5d0 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -94,7 +94,7 @@ jobs: - run: mkdir -p $TEST_RESULTS - - run: go test -v ./... -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/execution/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml + - run: go test -v ./... -coverpkg=github.com/go-jet/jet/postgres/...,github.com/go-jet/jet/mysql/...,github.com/go-jet/jet/qrm/...,github.com/go-jet/jet/generator/...,github.com/go-jet/jet/internal/... -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml - run: name: Upload code coverage diff --git a/qrm/utill.go b/qrm/utill.go index baee26e..20b8b54 100644 --- a/qrm/utill.go +++ b/qrm/utill.go @@ -156,6 +156,7 @@ func valueToString(value reflect.Value) string { var timeType = reflect.TypeOf(time.Now()) var uuidType = reflect.TypeOf(uuid.New()) +var byteArrayType = reflect.TypeOf([]byte("")) func isSimpleModelType(objType reflect.Type) bool { objType = indirectType(objType) @@ -167,15 +168,9 @@ func isSimpleModelType(objType reflect.Type) bool { reflect.String, reflect.Bool: return true - case reflect.Slice: - return objType.Elem().Kind() == reflect.Uint8 //[]byte - case reflect.Struct: - return objType == timeType - case reflect.Array: - return objType == uuidType // uuid.UUID returns reflect.Array kind } - return false + return objType == timeType || objType == uuidType || objType == byteArrayType } func isIntegerType(value reflect.Type) bool { diff --git a/qrm/utill_test.go b/qrm/utill_test.go new file mode 100644 index 0000000..e4ab53d --- /dev/null +++ b/qrm/utill_test.go @@ -0,0 +1,35 @@ +package qrm + +import ( + "github.com/google/uuid" + "gotest.tools/assert" + "reflect" + "testing" + "time" +) + +func TestIsSimpleModelType(t *testing.T) { + assert.Assert(t, isSimpleModelType(reflect.TypeOf(int8(11)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(int16(11)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(int32(11)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(int64(11)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint8(11)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint16(11)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint32(11)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(uint64(11)))) + + assert.Assert(t, isSimpleModelType(reflect.TypeOf(float32(123.46)))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(float64(123.46)))) + + assert.Assert(t, isSimpleModelType(reflect.TypeOf([]byte("Text")))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(time.Now()))) + assert.Assert(t, isSimpleModelType(reflect.TypeOf(uuid.New()))) + + complexModelType := struct { + Field1 string + Field2 string + }{} + + assert.Equal(t, isSimpleModelType(reflect.TypeOf(complexModelType)), false) + assert.Equal(t, isSimpleModelType(reflect.TypeOf(&complexModelType)), false) +} diff --git a/tests/testdata b/tests/testdata index 1f6bd8b..02e0795 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 1f6bd8bb86458019fa43b1e2cd7ae9488a7ac9a4 +Subproject commit 02e0795d1e06b959d0c564dc1e349159d57b1bf6 From f8daa1d76e9aa80b34f1b2c6c75f979c83796354 Mon Sep 17 00:00:00 2001 From: go-jet Date: Fri, 18 Oct 2019 10:09:56 +0200 Subject: [PATCH 6/7] Some linter errors. --- internal/utils/utils.go | 8 ++++++++ postgres/functions.go | 2 +- qrm/scan_context.go | 4 +--- qrm/utill.go | 5 ++--- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 9be5310..22ea1c3 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -113,6 +113,13 @@ func IsNil(v interface{}) bool { return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) } +// MustBeTrue panics when condition is false +func MustBeTrue(condition bool, errorStr string) { + if !condition { + panic(errorStr) + } +} + // MustBe panics with errorStr error, if v interface is not of reflect kind func MustBe(v interface{}, kind reflect.Kind, errorStr string) { if reflect.TypeOf(v).Kind() != kind { @@ -165,6 +172,7 @@ func ErrorCatch(err *error) { } } +// StringSliceContains checks if slice of strings contains a string func StringSliceContains(strings []string, contains string) bool { for _, str := range strings { if str == contains { diff --git a/postgres/functions.go b/postgres/functions.go index 6993de4..ddd01db 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -69,7 +69,7 @@ var COUNT = jet.COUNT // EVERY is aggregate function. Returns true if all input values are true, otherwise false var EVERY = jet.EVERY -// MAXf is aggregate function. Returns maximum value of expression across all input values +// MAX is aggregate function. Returns maximum value of expression across all input values var MAX = jet.MAX // MAXf is aggregate function. Returns maximum value of float expression across all input values diff --git a/qrm/scan_context.go b/qrm/scan_context.go index 8141ed7..e3f7f40 100644 --- a/qrm/scan_context.go +++ b/qrm/scan_context.go @@ -212,9 +212,7 @@ func (s *scanContext) rowElem(index int) interface{} { valuer, ok := s.row[index].(driver.Valuer) - if !ok { - panic("jet: internal error, scan value doesn't implement driver.Valuer") - } + utils.MustBeTrue(ok, "jet: internal error, scan value doesn't implement driver.Valuer") value, err := valuer.Value() diff --git a/qrm/utill.go b/qrm/utill.go index 20b8b54..7791f9a 100644 --- a/qrm/utill.go +++ b/qrm/utill.go @@ -50,9 +50,8 @@ func getSliceElemPtrAt(slicePtrValue reflect.Value, index int) reflect.Value { } func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) error { - if slicePtrValue.IsNil() { - panic("jet: internal, slice is nil") - } + utils.MustBeTrue(!slicePtrValue.IsNil(), "jet: internal, slice is nil") + sliceValue := slicePtrValue.Elem() sliceElemType := sliceValue.Type().Elem() From 64a51dc093ae2b48625a6d5be333d5df9e29b2a9 Mon Sep 17 00:00:00 2001 From: go-jet Date: Fri, 18 Oct 2019 10:15:08 +0200 Subject: [PATCH 7/7] QRM returns qrm.ErrNoRows when scanning into struct destination and query result set is empty. --- internal/jet/statement.go | 4 ++-- qrm/qrm.go | 9 ++++++--- tests/postgres/scan_test.go | 3 +-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/internal/jet/statement.go b/internal/jet/statement.go index ab0c655..e4ba41b 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -16,11 +16,11 @@ type Statement interface { // Query executes statement over database connection db and stores row result in destination. // Destination can be either pointer to struct or pointer to a slice. - // If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows. + // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. Query(db qrm.DB, destination interface{}) error // QueryContext executes statement with a context over database connection db and stores row result in destination. // Destination can be either pointer to struct or pointer to a slice. - // If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows. + // If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. QueryContext(context context.Context, db qrm.DB, destination interface{}) error //Exec executes statement over db connection without returning any rows. diff --git a/qrm/qrm.go b/qrm/qrm.go index 8f79c57..e7e6406 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -2,15 +2,18 @@ package qrm import ( "context" - "database/sql" + "errors" "github.com/go-jet/jet/internal/utils" "reflect" ) +// ErrNoRows is returned by Query when query result set is empty +var ErrNoRows = errors.New("qrm: no rows in result set") + // Query executes Query Result Mapping (QRM) of `query` with list of parametrized arguments `arg` over database connection `db` // using context `ctx` into destination `destPtr`. // Destination can be either pointer to struct or pointer to slice of structs. -// If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows. +// If destination is pointer to struct and query result set is empty, method returns qrm.ErrNoRows. func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) error { utils.MustBeInitializedPtr(db, "jet: db is nil") @@ -33,7 +36,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr } if rowsProcessed == 0 { - return sql.ErrNoRows + return ErrNoRows } // edge case when row result set contains only NULLs. diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index f42dbd5..11dac96 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -1,7 +1,6 @@ package postgres import ( - "database/sql" "fmt" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" @@ -704,7 +703,7 @@ func TestStructScanErrNoRows(t *testing.T) { err := query.Query(db, &customer) - assert.Error(t, err, sql.ErrNoRows.Error()) + assert.Error(t, err, qrm.ErrNoRows.Error()) } func TestStructScanAllNull(t *testing.T) {