Merge pull request #557 from go-jet/first_row_strict

Check strict scan only on first row.
This commit is contained in:
go-jet 2026-01-31 14:42:06 +01:00 committed by GitHub
commit 95224a793f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 18 additions and 15 deletions

View file

@ -241,12 +241,8 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac
return fmt.Errorf("jet: failed to scan a row into destination, %w", err)
}
if scanContext.rowNum == 1 && GlobalConfig.StrictScan {
scanContext.EnsureEveryColumnRead() // can panic
}
if GlobalConfig.StrictFieldMapping {
scanContext.EnsureEveryFieldMapped() // can panic
if scanContext.rowNum == 1 {
scanContext.ensureStrictness()
}
return nil
@ -291,11 +287,8 @@ func queryToSlice(ctx context.Context, db Queryable, query string, args []interf
return scanContext.rowNum, err
}
if scanContext.rowNum == 1 && GlobalConfig.StrictScan {
scanContext.EnsureEveryColumnRead()
}
if scanContext.rowNum == 1 && GlobalConfig.StrictFieldMapping {
scanContext.EnsureEveryFieldMapped()
if scanContext.rowNum == 1 {
scanContext.ensureStrictness()
}
}

View file

@ -67,7 +67,17 @@ func NewScanContext(rows *sql.Rows) (*ScanContext, error) {
}, nil
}
func (s *ScanContext) EnsureEveryColumnRead() {
func (s *ScanContext) ensureStrictness() { // can panic
if GlobalConfig.StrictScan {
s.ensureEveryColumnRead() // can panic
}
if GlobalConfig.StrictFieldMapping {
s.ensureEveryFieldMapped() // can panic
}
}
func (s *ScanContext) ensureEveryColumnRead() {
var neverUsedColumns []string
for index, read := range s.columnIndexRead {
@ -95,13 +105,13 @@ func (s *ScanContext) recordUnmappedField(structType reflect.Type, parentField *
fieldIdent := fmt.Sprintf("%s.%s", typeName, field.Name)
if parentField != nil {
fieldIdent = fmt.Sprintf("%s %s.%s", parentField.Name, typeName, field.Name)
fieldIdent = fmt.Sprintf("%s %s", parentField.Name, fieldIdent)
}
s.unmappedFields = append(s.unmappedFields, fmt.Sprintf("'%s'", fieldIdent))
}
func (s *ScanContext) EnsureEveryFieldMapped() {
func (s *ScanContext) ensureEveryFieldMapped() {
if len(s.unmappedFields) == 0 {
return
}
@ -328,7 +338,7 @@ func (s *ScanContext) typeToColumnIndex(typeName, fieldName string) int {
// rowElemValue always returns non-ptr value,
// invalid value is nil
func (s *ScanContext) rowElemValue(index int) reflect.Value {
if s.rowNum == 1 {
if s.rowNum == 1 && GlobalConfig.StrictScan {
s.columnIndexRead[index] = true
}
scannedValue := reflect.ValueOf(s.row[index])