Merge pull request #557 from go-jet/first_row_strict
Check strict scan only on first row.
This commit is contained in:
commit
95224a793f
2 changed files with 18 additions and 15 deletions
15
qrm/qrm.go
15
qrm/qrm.go
|
|
@ -241,12 +241,8 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac
|
|||
return fmt.Errorf("jet: failed to scan a row into destination, %w", err)
|
||||
}
|
||||
|
||||
if scanContext.rowNum == 1 && GlobalConfig.StrictScan {
|
||||
scanContext.EnsureEveryColumnRead() // can panic
|
||||
}
|
||||
|
||||
if GlobalConfig.StrictFieldMapping {
|
||||
scanContext.EnsureEveryFieldMapped() // can panic
|
||||
if scanContext.rowNum == 1 {
|
||||
scanContext.ensureStrictness()
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -291,11 +287,8 @@ func queryToSlice(ctx context.Context, db Queryable, query string, args []interf
|
|||
return scanContext.rowNum, err
|
||||
}
|
||||
|
||||
if scanContext.rowNum == 1 && GlobalConfig.StrictScan {
|
||||
scanContext.EnsureEveryColumnRead()
|
||||
}
|
||||
if scanContext.rowNum == 1 && GlobalConfig.StrictFieldMapping {
|
||||
scanContext.EnsureEveryFieldMapped()
|
||||
if scanContext.rowNum == 1 {
|
||||
scanContext.ensureStrictness()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -67,7 +67,17 @@ func NewScanContext(rows *sql.Rows) (*ScanContext, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (s *ScanContext) EnsureEveryColumnRead() {
|
||||
func (s *ScanContext) ensureStrictness() { // can panic
|
||||
if GlobalConfig.StrictScan {
|
||||
s.ensureEveryColumnRead() // can panic
|
||||
}
|
||||
|
||||
if GlobalConfig.StrictFieldMapping {
|
||||
s.ensureEveryFieldMapped() // can panic
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ScanContext) ensureEveryColumnRead() {
|
||||
var neverUsedColumns []string
|
||||
|
||||
for index, read := range s.columnIndexRead {
|
||||
|
|
@ -95,13 +105,13 @@ func (s *ScanContext) recordUnmappedField(structType reflect.Type, parentField *
|
|||
|
||||
fieldIdent := fmt.Sprintf("%s.%s", typeName, field.Name)
|
||||
if parentField != nil {
|
||||
fieldIdent = fmt.Sprintf("%s %s.%s", parentField.Name, typeName, field.Name)
|
||||
fieldIdent = fmt.Sprintf("%s %s", parentField.Name, fieldIdent)
|
||||
}
|
||||
|
||||
s.unmappedFields = append(s.unmappedFields, fmt.Sprintf("'%s'", fieldIdent))
|
||||
}
|
||||
|
||||
func (s *ScanContext) EnsureEveryFieldMapped() {
|
||||
func (s *ScanContext) ensureEveryFieldMapped() {
|
||||
if len(s.unmappedFields) == 0 {
|
||||
return
|
||||
}
|
||||
|
|
@ -328,7 +338,7 @@ func (s *ScanContext) typeToColumnIndex(typeName, fieldName string) int {
|
|||
// rowElemValue always returns non-ptr value,
|
||||
// invalid value is nil
|
||||
func (s *ScanContext) rowElemValue(index int) reflect.Value {
|
||||
if s.rowNum == 1 {
|
||||
if s.rowNum == 1 && GlobalConfig.StrictScan {
|
||||
s.columnIndexRead[index] = true
|
||||
}
|
||||
scannedValue := reflect.ValueOf(s.row[index])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue