Merge pull request #542 from k4n4ry/feat/strict-field-mapping

feat: add StrictFieldMapping config
This commit is contained in:
go-jet 2026-01-31 14:04:09 +01:00 committed by GitHub
commit eaaa328580
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 344 additions and 9 deletions

View file

@ -6,8 +6,9 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/utils/must"
"reflect" "reflect"
"github.com/go-jet/jet/v2/internal/utils/must"
) )
// Config holds the configuration settings for QRM scanning behavior. // Config holds the configuration settings for QRM scanning behavior.
@ -18,6 +19,16 @@ type Config struct {
// Does not apply to statements build with SELECT_JSON_OBJ or SELECT_JSON_ARR // Does not apply to statements build with SELECT_JSON_OBJ or SELECT_JSON_ARR
StrictScan bool StrictScan bool
// StrictFieldMapping, when true, causes the scanning function to panic if it encounters any
// destination struct fields that do not have matching columns in the SQL query result.
//
// Optional fields:
// If a destination field (including struct/slice fields) is not always selected by a query,
// it can be marked as optional using `qrm:"optional"`. When StrictFieldMapping is enabled,
// unmapped fields under an optional field will not trigger a panic.
// Does not apply to statements build with SELECT_JSON_OBJ or SELECT_JSON_ARR
StrictFieldMapping bool
// JsonUnmarshalFunc is called by the Query method to unmarshal JSON query results created by // JsonUnmarshalFunc is called by the Query method to unmarshal JSON query results created by
// SELECT_JSON_OBJ and SELECT_JSON_ARR statements. // SELECT_JSON_OBJ and SELECT_JSON_ARR statements.
// It can be replaced with any implementation that matches the standard "encoding/json" `Unmarshal` function signature. // It can be replaced with any implementation that matches the standard "encoding/json" `Unmarshal` function signature.
@ -29,6 +40,7 @@ type Config struct {
// This variable is not thread safe, and it should be modified only once, for instance, during application initialization. // This variable is not thread safe, and it should be modified only once, for instance, during application initialization.
var GlobalConfig = Config{ var GlobalConfig = Config{
StrictScan: false, StrictScan: false,
StrictFieldMapping: false,
JsonUnmarshalFunc: json.Unmarshal, JsonUnmarshalFunc: json.Unmarshal,
} }
@ -233,6 +245,10 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac
scanContext.EnsureEveryColumnRead() // can panic scanContext.EnsureEveryColumnRead() // can panic
} }
if GlobalConfig.StrictFieldMapping {
scanContext.EnsureEveryFieldMapped() // can panic
}
return nil return nil
} }
@ -278,6 +294,9 @@ func queryToSlice(ctx context.Context, db Queryable, query string, args []interf
if scanContext.rowNum == 1 && GlobalConfig.StrictScan { if scanContext.rowNum == 1 && GlobalConfig.StrictScan {
scanContext.EnsureEveryColumnRead() scanContext.EnsureEveryColumnRead()
} }
if scanContext.rowNum == 1 && GlobalConfig.StrictFieldMapping {
scanContext.EnsureEveryFieldMapped()
}
} }
err = rows.Close() err = rows.Close()

View file

@ -20,6 +20,8 @@ type ScanContext struct {
typesVisited typeStack // to prevent circular dependency scan typesVisited typeStack // to prevent circular dependency scan
columnAlias []string columnAlias []string
columnIndexRead []bool columnIndexRead []bool
unmappedFields []string
} }
// NewScanContext creates new ScanContext from rows // NewScanContext creates new ScanContext from rows
@ -79,6 +81,65 @@ func (s *ScanContext) EnsureEveryColumnRead() {
} }
} }
func (s *ScanContext) recordUnmappedField(structType reflect.Type, parentField *reflect.StructField, field reflect.StructField) {
// skip private/unsettable fields (those are ignored by mapRowToStruct anyway)
if !field.IsExported() {
return
}
// NOTE: For unnamed/anonymous structs, Name() is empty, so String() is used for readability/uniqueness.
typeName := structType.Name()
if typeName == "" {
typeName = structType.String()
}
fieldIdent := fmt.Sprintf("%s.%s", typeName, field.Name)
if parentField != nil {
fieldIdent = fmt.Sprintf("%s %s.%s", parentField.Name, typeName, field.Name)
}
s.unmappedFields = append(s.unmappedFields, fmt.Sprintf("'%s'", fieldIdent))
}
func (s *ScanContext) EnsureEveryFieldMapped() {
if len(s.unmappedFields) == 0 {
return
}
panic("jet: fields never mapped: " + strings.Join(s.unmappedFields, ", "))
}
func isOptionalQrmField(field *reflect.StructField) bool {
if field == nil {
return false
}
tag := field.Tag.Get("qrm")
if tag == "" {
return false
}
for _, part := range strings.Split(tag, ",") {
if strings.TrimSpace(part) == "optional" {
return true
}
}
return false
}
func shouldRecordUnmappedField(parentField *reflect.StructField, field reflect.StructField, fieldMap fieldMapping) bool {
if !GlobalConfig.StrictFieldMapping {
return false
}
if fieldMap.Type == complexType {
return false
}
if fieldMap.rowIndex != -1 {
return false
}
if isOptionalQrmField(parentField) || isOptionalQrmField(&field) {
return false
}
return true
}
func createScanSlice(columnCount int) []interface{} { func createScanSlice(columnCount int) []interface{} {
scanPtrSlice := make([]interface{}, columnCount) scanPtrSlice := make([]interface{}, columnCount)
@ -144,6 +205,10 @@ func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.
fieldMap.Type = simpleType fieldMap.Type = simpleType
} }
if shouldRecordUnmappedField(parentField, field, fieldMap) {
s.recordUnmappedField(structType, parentField, field)
}
newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap) newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap)
} }

View file

@ -4,15 +4,17 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"os"
"runtime"
"testing"
"github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/qrm"
"github.com/go-jet/jet/v2/sqlite" "github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/stmtcache" "github.com/go-jet/jet/v2/stmtcache"
"github.com/go-jet/jet/v2/tests/dbconfig" "github.com/go-jet/jet/v2/tests/dbconfig"
"github.com/pkg/profile" "github.com/pkg/profile"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"os"
"runtime"
"testing"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
@ -57,6 +59,20 @@ func TestMain(m *testing.M) {
} }
func allowUnmappedFields(f func()) {
previous := qrm.GlobalConfig.StrictFieldMapping
defer func() { qrm.GlobalConfig.StrictFieldMapping = previous }()
qrm.GlobalConfig.StrictFieldMapping = false
f()
}
func requireStrictFieldMapping(f func()) {
previous := qrm.GlobalConfig.StrictFieldMapping
defer func() { qrm.GlobalConfig.StrictFieldMapping = previous }()
qrm.GlobalConfig.StrictFieldMapping = true
f()
}
func runCount(stmtCaching bool) int { func runCount(stmtCaching bool) int {
if stmtCaching { if stmtCaching {
return 4 return 4

View file

@ -2,13 +2,14 @@ package sqlite
import ( import (
"context" "context"
"github.com/go-jet/jet/v2/internal/utils/ptr"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/model"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/table"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/go-jet/jet/v2/internal/utils/ptr"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/model"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/table"
"github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/sqlite" . "github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model"
@ -43,6 +44,240 @@ WHERE actor.actor_id = ?;
requireQueryLogged(t, query, 1) requireQueryLogged(t, query, 1)
} }
func TestStrictFieldMapping(t *testing.T) {
queryAll := SELECT(
Actor.AllColumns,
).FROM(
Actor,
).WHERE(
Actor.ActorID.EQ(Int(2)),
).LIMIT(1)
testutils.AssertStatementSql(t, queryAll, `
SELECT actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update"
FROM actor
WHERE actor.actor_id = ?
LIMIT ?;
`, int64(2), int64(1))
queryPartial := SELECT(
Actor.ActorID,
Actor.FirstName,
).FROM(
Actor,
).WHERE(
Actor.ActorID.EQ(Int(2)),
).LIMIT(1)
testutils.AssertStatementSql(t, queryPartial, `
SELECT actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name"
FROM actor
WHERE actor.actor_id = ?
LIMIT ?;
`, int64(2), int64(1))
// Destination model mapped via explicit field aliases ("actor.*").
type AliasedActor struct {
ActorID int32 `alias:"actor.actor_id"`
FirstName string `alias:"actor.first_name"`
LastName string `alias:"actor.last_name"`
LastUpdate string `alias:"actor.last_update"`
}
t.Run("all columns scan succeeds for generated model", func(t *testing.T) {
allowUnmappedFields(func() {
var dest model.Actor
require.NoError(t, queryAll.Query(db, &dest))
})
requireStrictFieldMapping(func() {
var dest model.Actor
require.NoError(t, queryAll.Query(db, &dest))
})
})
t.Run("all columns scan succeeds for aliased destination", func(t *testing.T) {
allowUnmappedFields(func() {
var dest []AliasedActor
require.NoError(t, queryAll.Query(db, &dest))
require.Len(t, dest, 1)
})
requireStrictFieldMapping(func() {
var dest []AliasedActor
require.NoError(t, queryAll.Query(db, &dest))
require.Len(t, dest, 1)
})
})
t.Run("partial columns panics in strict mode for generated model", func(t *testing.T) {
allowUnmappedFields(func() {
var dest []model.Actor
require.NoError(t, queryPartial.Query(db, &dest))
require.Len(t, dest, 1)
})
requireStrictFieldMapping(func() {
require.PanicsWithValue(t, "jet: fields never mapped: 'Actor.LastName', 'Actor.LastUpdate'", func() {
var dest []model.Actor
_ = queryPartial.Query(db, &dest)
})
})
})
t.Run("partial columns panics in strict mode for aliased destination", func(t *testing.T) {
allowUnmappedFields(func() {
var dest []AliasedActor
require.NoError(t, queryPartial.Query(db, &dest))
require.Len(t, dest, 1)
})
requireStrictFieldMapping(func() {
require.PanicsWithValue(t, "jet: fields never mapped: 'AliasedActor.LastName', 'AliasedActor.LastUpdate'", func() {
var dest []AliasedActor
_ = queryPartial.Query(db, &dest)
})
})
})
t.Run("unexported fields are ignored by strict field mapping", func(t *testing.T) {
type Dest struct {
actorID int32 `alias:"actor.missing_column"`
}
requireStrictFieldMapping(func() {
var dest []Dest
require.NoError(t, queryAll.Query(db, &dest))
})
})
t.Run("nested unmapped field uses parent field name in error", func(t *testing.T) {
type Inner struct {
Missing string `alias:"actor.missing_column"`
}
type Outer struct {
Child Inner
}
requireStrictFieldMapping(func() {
require.PanicsWithValue(t, "jet: fields never mapped: 'Child Inner.Missing'", func() {
var dest []Outer
_ = queryAll.Query(db, &dest)
})
})
})
t.Run("Rows.Scan triggers strict field mapping check", func(t *testing.T) {
type ActorLiteMissing struct {
ActorID int32 `alias:"actor.actor_id"`
FirstName string `alias:"actor.first_name"`
LastName string `alias:"actor.last_name"`
}
requireStrictFieldMapping(func() {
rows, err := queryPartial.Rows(context.Background(), db)
require.NoError(t, err)
require.True(t, rows.Next())
require.PanicsWithValue(t, "jet: fields never mapped: 'ActorLiteMissing.LastName'", func() {
var dest ActorLiteMissing
_ = rows.Scan(&dest)
})
_ = rows.Close()
})
})
t.Run("missing joined table columns panics for nested struct field", func(t *testing.T) {
filmOnly := SELECT(Film.AllColumns).FROM(Film).LIMIT(1)
type Dest struct {
model.Film
Actor model.Actor
}
allowUnmappedFields(func() {
var dest []Dest
require.NoError(t, filmOnly.Query(db, &dest))
require.Len(t, dest, 1)
})
requireStrictFieldMapping(func() {
require.PanicsWithValue(t, "jet: fields never mapped: 'Actor Actor.ActorID', 'Actor Actor.FirstName', 'Actor Actor.LastName', 'Actor Actor.LastUpdate'", func() {
var dest []Dest
_ = filmOnly.Query(db, &dest)
})
})
})
t.Run("missing joined table columns do not panic when nested struct field is optional", func(t *testing.T) {
filmOnly := SELECT(Film.AllColumns).FROM(Film).LIMIT(1)
type Dest struct {
model.Film
Actor model.Actor `qrm:"optional"`
}
requireStrictFieldMapping(func() {
var dest []Dest
require.NoError(t, filmOnly.Query(db, &dest))
require.Len(t, dest, 1)
})
})
t.Run("missing joined table columns panics for nested slice field", func(t *testing.T) {
filmOnly := SELECT(Film.AllColumns).FROM(Film).LIMIT(1)
type Dest struct {
model.Film
Actor []model.Actor
}
allowUnmappedFields(func() {
var dest []Dest
require.NoError(t, filmOnly.Query(db, &dest))
require.Len(t, dest, 1)
})
requireStrictFieldMapping(func() {
require.PanicsWithValue(t, "jet: fields never mapped: 'Actor Actor.ActorID', 'Actor Actor.FirstName', 'Actor Actor.LastName', 'Actor Actor.LastUpdate'", func() {
var dest []Dest
_ = filmOnly.Query(db, &dest)
})
})
})
t.Run("missing joined table columns do not panic when nested slice field is optional", func(t *testing.T) {
filmOnly := SELECT(Film.AllColumns).FROM(Film).LIMIT(1)
type Dest struct {
model.Film
Actor []model.Actor `qrm:"optional"`
}
requireStrictFieldMapping(func() {
var dest []Dest
require.NoError(t, filmOnly.Query(db, &dest))
require.Len(t, dest, 1)
})
})
t.Run("optional tag skips strict field mapping for missing simple field", func(t *testing.T) {
query := SELECT(Actor.ActorID).FROM(Actor).WHERE(Actor.ActorID.EQ(Int(2))).LIMIT(1)
type DestOptional struct {
ActorID int32 `alias:"actor.actor_id"`
OptionalMissing string `alias:"actor.missing_column" qrm:"optional"`
}
requireStrictFieldMapping(func() {
var dest []DestOptional
require.NoError(t, query.Query(db, &dest))
require.Len(t, dest, 1)
})
})
}
var actor2 = model.Actor{ var actor2 = model.Actor{
ActorID: 2, ActorID: 2,
FirstName: "NICK", FirstName: "NICK",