Add support for strict scan.

If there are unused columns in query result set Query method panics.
This commit is contained in:
go-jet 2025-03-11 10:50:06 +01:00
parent cfc264221b
commit d86f14e665
9 changed files with 402 additions and 207 deletions

View file

@ -10,6 +10,21 @@ import (
"reflect"
)
// Config holds the configuration settings for QRM scanning behavior.
type Config struct {
// StrictScan, when true, causes the scanning function to panic if it encounters any
// unused columns in the SQL query result. This ensures that every column is mapped
// to a field in the destination struct.
// Does not apply to statements build with SELECT_JSON_OBJ or SELECT_JSON_ARR
StrictScan bool
}
// GlobalConfig is the package-wide configuration for SQL scanning.
// This variable should be modified only once, for instance, during application initialization.
var GlobalConfig = Config{
StrictScan: false,
}
// ErrNoRows is returned by Query when query result set is empty
var ErrNoRows = errors.New("qrm: no rows in result set")
@ -199,12 +214,16 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac
destValuePtr := reflect.ValueOf(destPtr)
scanContext.rowNum++
_, err = mapRowToStruct(scanContext, "", destValuePtr, nil)
if err != nil {
return fmt.Errorf("jet: failed to scan a row into destination, %w", err)
}
scanContext.EnsureEveryColumnRead() // can panic
return nil
}
@ -246,6 +265,10 @@ func queryToSlice(ctx context.Context, db Queryable, query string, args []interf
if err != nil {
return scanContext.rowNum, err
}
if scanContext.rowNum == 1 && GlobalConfig.StrictScan {
scanContext.EnsureEveryColumnRead()
}
}
err = rows.Close()

View file

@ -17,7 +17,9 @@ type ScanContext struct {
groupKeyInfoCache map[string]groupKeyInfo
typeInfoMap map[string]typeInfo
typesVisited typeStack // to prevent circular dependency scan
typesVisited typeStack // to prevent circular dependency scan
columnAlias []string
columnIndexRead []bool
}
// NewScanContext creates new ScanContext from rows
@ -57,9 +59,26 @@ func NewScanContext(rows *sql.Rows) (*ScanContext, error) {
typeInfoMap: make(map[string]typeInfo),
typesVisited: newTypeStack(),
columnAlias: aliases,
columnIndexRead: make([]bool, len(aliases)),
}, nil
}
func (s *ScanContext) EnsureEveryColumnRead() {
var neverUsedColumns []string
for index, read := range s.columnIndexRead {
if !read {
neverUsedColumns = append(neverUsedColumns, `'`+s.columnAlias[index]+`'`)
}
}
if len(neverUsedColumns) > 0 {
panic("jet: columns never used: " + strings.Join(neverUsedColumns, ", "))
}
}
func createScanSlice(columnCount int) []interface{} {
scanPtrSlice := make([]interface{}, columnCount)
@ -244,6 +263,9 @@ func (s *ScanContext) typeToColumnIndex(typeName, fieldName string) int {
// rowElemValue always returns non-ptr value,
// invalid value is nil
func (s *ScanContext) rowElemValue(index int) reflect.Value {
if s.rowNum == 1 {
s.columnIndexRead[index] = true
}
scannedValue := reflect.ValueOf(s.row[index])
return scannedValue.Elem().Elem() // no need to check validity of Elem, because s.row[index] always contains interface in interface
}