diff --git a/qrm/qrm.go b/qrm/qrm.go index 6fc081f..c2d1b10 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -6,8 +6,9 @@ import ( "encoding/json" "errors" "fmt" - "github.com/go-jet/jet/v2/internal/utils/must" "reflect" + + "github.com/go-jet/jet/v2/internal/utils/must" ) // Config holds the configuration settings for QRM scanning behavior. @@ -18,6 +19,13 @@ type Config struct { // Does not apply to statements build with SELECT_JSON_OBJ or SELECT_JSON_ARR StrictScan bool + // StrictFieldMapping, when true, causes the scanning function to panic if it encounters any + // destination struct fields that do not have matching columns in the SQL query result. + // This check applies only to fields that are mapped from a single column (simple/scanner/json_column). + // Complex fields (struct/slice) are excluded because they are populated recursively and can be optional. + // Does not apply to statements build with SELECT_JSON_OBJ or SELECT_JSON_ARR + StrictFieldMapping bool + // JsonUnmarshalFunc is called by the Query method to unmarshal JSON query results created by // SELECT_JSON_OBJ and SELECT_JSON_ARR statements. // It can be replaced with any implementation that matches the standard "encoding/json" `Unmarshal` function signature. @@ -28,8 +36,9 @@ type Config struct { // GlobalConfig is the package-wide configuration for SQL scanning. // This variable is not thread safe, and it should be modified only once, for instance, during application initialization. var GlobalConfig = Config{ - StrictScan: false, - JsonUnmarshalFunc: json.Unmarshal, + StrictScan: false, + StrictFieldMapping: false, + JsonUnmarshalFunc: json.Unmarshal, } // ErrNoRows is returned by Query when query result set is empty @@ -230,6 +239,9 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac } scanContext.EnsureEveryColumnRead() // can panic + if GlobalConfig.StrictFieldMapping { + scanContext.EnsureEveryFieldMapped() // can panic + } return nil } @@ -276,6 +288,9 @@ func queryToSlice(ctx context.Context, db Queryable, query string, args []interf if scanContext.rowNum == 1 && GlobalConfig.StrictScan { scanContext.EnsureEveryColumnRead() } + if scanContext.rowNum == 1 && GlobalConfig.StrictFieldMapping { + scanContext.EnsureEveryFieldMapped() + } } err = rows.Close() diff --git a/qrm/scan_context.go b/qrm/scan_context.go index 307dd60..ea529d0 100644 --- a/qrm/scan_context.go +++ b/qrm/scan_context.go @@ -20,6 +20,8 @@ type ScanContext struct { typesVisited typeStack // to prevent circular dependency scan columnAlias []string columnIndexRead []bool + + unmappedFields []string } // NewScanContext creates new ScanContext from rows @@ -79,6 +81,33 @@ func (s *ScanContext) EnsureEveryColumnRead() { } } +func (s *ScanContext) recordUnmappedField(structType reflect.Type, parentField *reflect.StructField, field reflect.StructField) { + // skip private/unsettable fields (those are ignored by mapRowToStruct anyway) + if field.PkgPath != "" { + return + } + + // NOTE: For unnamed/anonymous structs, Name() is empty, so String() is used for readability/uniqueness. + typeName := structType.String() + if structType.Name() != "" { + typeName = structType.Name() + } + + fieldIdent := fmt.Sprintf("%s.%s", typeName, field.Name) + if parentField != nil { + fieldIdent = fmt.Sprintf("%s.%s.%s", typeName, parentField.Name, field.Name) + } + + s.unmappedFields = append(s.unmappedFields, fmt.Sprintf("'%s'", fieldIdent)) +} + +func (s *ScanContext) EnsureEveryFieldMapped() { + if len(s.unmappedFields) == 0 { + return + } + panic("jet: fields never mapped: " + strings.Join(s.unmappedFields, ", ")) +} + func createScanSlice(columnCount int) []interface{} { scanPtrSlice := make([]interface{}, columnCount) @@ -144,6 +173,10 @@ func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect. fieldMap.Type = simpleType } + if GlobalConfig.StrictFieldMapping && fieldMap.rowIndex == -1 && fieldMap.Type != complexType { + s.recordUnmappedField(structType, parentField, field) + } + newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap) } diff --git a/tests/sqlite/main_test.go b/tests/sqlite/main_test.go index a113c14..15ef110 100644 --- a/tests/sqlite/main_test.go +++ b/tests/sqlite/main_test.go @@ -4,15 +4,17 @@ import ( "context" "database/sql" "fmt" + "os" + "runtime" + "testing" + "github.com/go-jet/jet/v2/internal/utils/throw" + "github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/sqlite" "github.com/go-jet/jet/v2/stmtcache" "github.com/go-jet/jet/v2/tests/dbconfig" "github.com/pkg/profile" "github.com/stretchr/testify/require" - "os" - "runtime" - "testing" _ "github.com/mattn/go-sqlite3" ) @@ -57,6 +59,20 @@ func TestMain(m *testing.M) { } +func allowUnmappedFields(f func()) { + previous := qrm.GlobalConfig.StrictFieldMapping + defer func() { qrm.GlobalConfig.StrictFieldMapping = previous }() + qrm.GlobalConfig.StrictFieldMapping = false + f() +} + +func requireStrictFieldMapping(f func()) { + previous := qrm.GlobalConfig.StrictFieldMapping + defer func() { qrm.GlobalConfig.StrictFieldMapping = previous }() + qrm.GlobalConfig.StrictFieldMapping = true + f() +} + func runCount(stmtCaching bool) int { if stmtCaching { return 4 diff --git a/tests/sqlite/select_test.go b/tests/sqlite/select_test.go index 79f1d6b..d1261da 100644 --- a/tests/sqlite/select_test.go +++ b/tests/sqlite/select_test.go @@ -2,13 +2,14 @@ package sqlite import ( "context" - "github.com/go-jet/jet/v2/internal/utils/ptr" - model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/model" - "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/table" "strings" "testing" "time" + "github.com/go-jet/jet/v2/internal/utils/ptr" + model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/model" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/table" + "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/sqlite" "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" @@ -43,6 +44,151 @@ WHERE actor.actor_id = ?; requireQueryLogged(t, query, 1) } +func TestStrictFieldMapping(t *testing.T) { + queryAll := SELECT( + Actor.AllColumns, + ).FROM( + Actor, + ).WHERE( + Actor.ActorID.EQ(Int(2)), + ).LIMIT(1) + + testutils.AssertStatementSql(t, queryAll, ` +SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" +FROM actor +WHERE actor.actor_id = ? +LIMIT ?; +`, int64(2), int64(1)) + + queryPartial := SELECT( + Actor.ActorID, + Actor.FirstName, + ).FROM( + Actor, + ).WHERE( + Actor.ActorID.EQ(Int(2)), + ).LIMIT(1) + + testutils.AssertStatementSql(t, queryPartial, ` +SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name" +FROM actor +WHERE actor.actor_id = ? +LIMIT ?; +`, int64(2), int64(1)) + + // Destination model mapped via explicit field aliases ("actor.*"). + type AliasedActor struct { + ActorID int32 `alias:"actor.actor_id"` + FirstName string `alias:"actor.first_name"` + LastName string `alias:"actor.last_name"` + LastUpdate string `alias:"actor.last_update"` + } + + t.Run("all columns scan succeeds for generated model", func(t *testing.T) { + allowUnmappedFields(func() { + var dest model.Actor + require.NoError(t, queryAll.Query(db, &dest)) + }) + requireStrictFieldMapping(func() { + var dest model.Actor + require.NoError(t, queryAll.Query(db, &dest)) + }) + }) + + t.Run("all columns scan succeeds for aliased destination", func(t *testing.T) { + allowUnmappedFields(func() { + var dest []AliasedActor + require.NoError(t, queryAll.Query(db, &dest)) + require.Len(t, dest, 1) + }) + requireStrictFieldMapping(func() { + var dest []AliasedActor + require.NoError(t, queryAll.Query(db, &dest)) + require.Len(t, dest, 1) + }) + }) + + t.Run("partial columns panics in strict mode for generated model", func(t *testing.T) { + allowUnmappedFields(func() { + var dest []model.Actor + require.NoError(t, queryPartial.Query(db, &dest)) + require.Len(t, dest, 1) + }) + requireStrictFieldMapping(func() { + require.PanicsWithValue(t, "jet: fields never mapped: 'Actor.LastName', 'Actor.LastUpdate'", func() { + var dest []model.Actor + _ = queryPartial.Query(db, &dest) + }) + }) + }) + + t.Run("partial columns panics in strict mode for aliased destination", func(t *testing.T) { + allowUnmappedFields(func() { + var dest []AliasedActor + require.NoError(t, queryPartial.Query(db, &dest)) + require.Len(t, dest, 1) + }) + requireStrictFieldMapping(func() { + require.PanicsWithValue(t, "jet: fields never mapped: 'AliasedActor.LastName', 'AliasedActor.LastUpdate'", func() { + var dest []AliasedActor + _ = queryPartial.Query(db, &dest) + }) + }) + }) + + t.Run("unexported fields are ignored by strict field mapping", func(t *testing.T) { + type Dest struct { + actorID int32 `alias:"actor.missing_column"` + } + + requireStrictFieldMapping(func() { + var dest []Dest + require.NoError(t, queryAll.Query(db, &dest)) + }) + }) + + t.Run("nested unmapped field uses parent field name in error", func(t *testing.T) { + type Inner struct { + Missing string `alias:"actor.missing_column"` + } + type Outer struct { + Child Inner + } + + requireStrictFieldMapping(func() { + require.PanicsWithValue(t, "jet: fields never mapped: 'Inner.Child.Missing'", func() { + var dest []Outer + _ = queryAll.Query(db, &dest) + }) + }) + }) + + t.Run("Rows.Scan triggers strict field mapping check", func(t *testing.T) { + type ActorLiteMissing struct { + ActorID int32 `alias:"actor.actor_id"` + FirstName string `alias:"actor.first_name"` + LastName string `alias:"actor.last_name"` + } + + requireStrictFieldMapping(func() { + rows, err := queryPartial.Rows(context.Background(), db) + require.NoError(t, err) + require.True(t, rows.Next()) + + require.PanicsWithValue(t, "jet: fields never mapped: 'ActorLiteMissing.LastName'", func() { + var dest ActorLiteMissing + _ = rows.Scan(&dest) + }) + + _ = rows.Close() + }) + }) +} + var actor2 = model.Actor{ ActorID: 2, FirstName: "NICK",