Merge pull request #557 from go-jet/first_row_strict
Check strict scan only on first row.
This commit is contained in:
commit
95224a793f
2 changed files with 18 additions and 15 deletions
15
qrm/qrm.go
15
qrm/qrm.go
|
|
@ -241,12 +241,8 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac
|
||||||
return fmt.Errorf("jet: failed to scan a row into destination, %w", err)
|
return fmt.Errorf("jet: failed to scan a row into destination, %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if scanContext.rowNum == 1 && GlobalConfig.StrictScan {
|
if scanContext.rowNum == 1 {
|
||||||
scanContext.EnsureEveryColumnRead() // can panic
|
scanContext.ensureStrictness()
|
||||||
}
|
|
||||||
|
|
||||||
if GlobalConfig.StrictFieldMapping {
|
|
||||||
scanContext.EnsureEveryFieldMapped() // can panic
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -291,11 +287,8 @@ func queryToSlice(ctx context.Context, db Queryable, query string, args []interf
|
||||||
return scanContext.rowNum, err
|
return scanContext.rowNum, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if scanContext.rowNum == 1 && GlobalConfig.StrictScan {
|
if scanContext.rowNum == 1 {
|
||||||
scanContext.EnsureEveryColumnRead()
|
scanContext.ensureStrictness()
|
||||||
}
|
|
||||||
if scanContext.rowNum == 1 && GlobalConfig.StrictFieldMapping {
|
|
||||||
scanContext.EnsureEveryFieldMapped()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,17 @@ func NewScanContext(rows *sql.Rows) (*ScanContext, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ScanContext) EnsureEveryColumnRead() {
|
func (s *ScanContext) ensureStrictness() { // can panic
|
||||||
|
if GlobalConfig.StrictScan {
|
||||||
|
s.ensureEveryColumnRead() // can panic
|
||||||
|
}
|
||||||
|
|
||||||
|
if GlobalConfig.StrictFieldMapping {
|
||||||
|
s.ensureEveryFieldMapped() // can panic
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScanContext) ensureEveryColumnRead() {
|
||||||
var neverUsedColumns []string
|
var neverUsedColumns []string
|
||||||
|
|
||||||
for index, read := range s.columnIndexRead {
|
for index, read := range s.columnIndexRead {
|
||||||
|
|
@ -95,13 +105,13 @@ func (s *ScanContext) recordUnmappedField(structType reflect.Type, parentField *
|
||||||
|
|
||||||
fieldIdent := fmt.Sprintf("%s.%s", typeName, field.Name)
|
fieldIdent := fmt.Sprintf("%s.%s", typeName, field.Name)
|
||||||
if parentField != nil {
|
if parentField != nil {
|
||||||
fieldIdent = fmt.Sprintf("%s %s.%s", parentField.Name, typeName, field.Name)
|
fieldIdent = fmt.Sprintf("%s %s", parentField.Name, fieldIdent)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.unmappedFields = append(s.unmappedFields, fmt.Sprintf("'%s'", fieldIdent))
|
s.unmappedFields = append(s.unmappedFields, fmt.Sprintf("'%s'", fieldIdent))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ScanContext) EnsureEveryFieldMapped() {
|
func (s *ScanContext) ensureEveryFieldMapped() {
|
||||||
if len(s.unmappedFields) == 0 {
|
if len(s.unmappedFields) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -328,7 +338,7 @@ func (s *ScanContext) typeToColumnIndex(typeName, fieldName string) int {
|
||||||
// rowElemValue always returns non-ptr value,
|
// rowElemValue always returns non-ptr value,
|
||||||
// invalid value is nil
|
// invalid value is nil
|
||||||
func (s *ScanContext) rowElemValue(index int) reflect.Value {
|
func (s *ScanContext) rowElemValue(index int) reflect.Value {
|
||||||
if s.rowNum == 1 {
|
if s.rowNum == 1 && GlobalConfig.StrictScan {
|
||||||
s.columnIndexRead[index] = true
|
s.columnIndexRead[index] = true
|
||||||
}
|
}
|
||||||
scannedValue := reflect.ValueOf(s.row[index])
|
scannedValue := reflect.ValueOf(s.row[index])
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue