diff --git a/qrm/qrm.go b/qrm/qrm.go index cc93bae..cf44044 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -241,12 +241,8 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac return fmt.Errorf("jet: failed to scan a row into destination, %w", err) } - if scanContext.rowNum == 1 && GlobalConfig.StrictScan { - scanContext.EnsureEveryColumnRead() // can panic - } - - if GlobalConfig.StrictFieldMapping { - scanContext.EnsureEveryFieldMapped() // can panic + if scanContext.rowNum == 1 { + scanContext.ensureStrictness() } return nil @@ -291,11 +287,8 @@ func queryToSlice(ctx context.Context, db Queryable, query string, args []interf return scanContext.rowNum, err } - if scanContext.rowNum == 1 && GlobalConfig.StrictScan { - scanContext.EnsureEveryColumnRead() - } - if scanContext.rowNum == 1 && GlobalConfig.StrictFieldMapping { - scanContext.EnsureEveryFieldMapped() + if scanContext.rowNum == 1 { + scanContext.ensureStrictness() } } diff --git a/qrm/scan_context.go b/qrm/scan_context.go index 9841270..f9eb148 100644 --- a/qrm/scan_context.go +++ b/qrm/scan_context.go @@ -67,7 +67,17 @@ func NewScanContext(rows *sql.Rows) (*ScanContext, error) { }, nil } -func (s *ScanContext) EnsureEveryColumnRead() { +func (s *ScanContext) ensureStrictness() { // can panic + if GlobalConfig.StrictScan { + s.ensureEveryColumnRead() // can panic + } + + if GlobalConfig.StrictFieldMapping { + s.ensureEveryFieldMapped() // can panic + } +} + +func (s *ScanContext) ensureEveryColumnRead() { var neverUsedColumns []string for index, read := range s.columnIndexRead { @@ -95,13 +105,13 @@ func (s *ScanContext) recordUnmappedField(structType reflect.Type, parentField * fieldIdent := fmt.Sprintf("%s.%s", typeName, field.Name) if parentField != nil { - fieldIdent = fmt.Sprintf("%s %s.%s", parentField.Name, typeName, field.Name) + fieldIdent = fmt.Sprintf("%s %s", parentField.Name, fieldIdent) } s.unmappedFields = append(s.unmappedFields, fmt.Sprintf("'%s'", fieldIdent)) } -func (s *ScanContext) EnsureEveryFieldMapped() { +func (s *ScanContext) ensureEveryFieldMapped() { if len(s.unmappedFields) == 0 { return } @@ -328,7 +338,7 @@ func (s *ScanContext) typeToColumnIndex(typeName, fieldName string) int { // rowElemValue always returns non-ptr value, // invalid value is nil func (s *ScanContext) rowElemValue(index int) reflect.Value { - if s.rowNum == 1 { + if s.rowNum == 1 && GlobalConfig.StrictScan { s.columnIndexRead[index] = true } scannedValue := reflect.ValueOf(s.row[index])