feat: add StrictFieldMapping config
This commit is contained in:
parent
adef2f9b1a
commit
f33c2ee357
4 changed files with 219 additions and 9 deletions
21
qrm/qrm.go
21
qrm/qrm.go
|
|
@ -6,8 +6,9 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/go-jet/jet/v2/internal/utils/must"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/go-jet/jet/v2/internal/utils/must"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config holds the configuration settings for QRM scanning behavior.
|
// Config holds the configuration settings for QRM scanning behavior.
|
||||||
|
|
@ -18,6 +19,13 @@ type Config struct {
|
||||||
// Does not apply to statements build with SELECT_JSON_OBJ or SELECT_JSON_ARR
|
// Does not apply to statements build with SELECT_JSON_OBJ or SELECT_JSON_ARR
|
||||||
StrictScan bool
|
StrictScan bool
|
||||||
|
|
||||||
|
// StrictFieldMapping, when true, causes the scanning function to panic if it encounters any
|
||||||
|
// destination struct fields that do not have matching columns in the SQL query result.
|
||||||
|
// This check applies only to fields that are mapped from a single column (simple/scanner/json_column).
|
||||||
|
// Complex fields (struct/slice) are excluded because they are populated recursively and can be optional.
|
||||||
|
// Does not apply to statements build with SELECT_JSON_OBJ or SELECT_JSON_ARR
|
||||||
|
StrictFieldMapping bool
|
||||||
|
|
||||||
// JsonUnmarshalFunc is called by the Query method to unmarshal JSON query results created by
|
// JsonUnmarshalFunc is called by the Query method to unmarshal JSON query results created by
|
||||||
// SELECT_JSON_OBJ and SELECT_JSON_ARR statements.
|
// SELECT_JSON_OBJ and SELECT_JSON_ARR statements.
|
||||||
// It can be replaced with any implementation that matches the standard "encoding/json" `Unmarshal` function signature.
|
// It can be replaced with any implementation that matches the standard "encoding/json" `Unmarshal` function signature.
|
||||||
|
|
@ -28,8 +36,9 @@ type Config struct {
|
||||||
// GlobalConfig is the package-wide configuration for SQL scanning.
|
// GlobalConfig is the package-wide configuration for SQL scanning.
|
||||||
// This variable is not thread safe, and it should be modified only once, for instance, during application initialization.
|
// This variable is not thread safe, and it should be modified only once, for instance, during application initialization.
|
||||||
var GlobalConfig = Config{
|
var GlobalConfig = Config{
|
||||||
StrictScan: false,
|
StrictScan: false,
|
||||||
JsonUnmarshalFunc: json.Unmarshal,
|
StrictFieldMapping: false,
|
||||||
|
JsonUnmarshalFunc: json.Unmarshal,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrNoRows is returned by Query when query result set is empty
|
// ErrNoRows is returned by Query when query result set is empty
|
||||||
|
|
@ -230,6 +239,9 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac
|
||||||
}
|
}
|
||||||
|
|
||||||
scanContext.EnsureEveryColumnRead() // can panic
|
scanContext.EnsureEveryColumnRead() // can panic
|
||||||
|
if GlobalConfig.StrictFieldMapping {
|
||||||
|
scanContext.EnsureEveryFieldMapped() // can panic
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -276,6 +288,9 @@ func queryToSlice(ctx context.Context, db Queryable, query string, args []interf
|
||||||
if scanContext.rowNum == 1 && GlobalConfig.StrictScan {
|
if scanContext.rowNum == 1 && GlobalConfig.StrictScan {
|
||||||
scanContext.EnsureEveryColumnRead()
|
scanContext.EnsureEveryColumnRead()
|
||||||
}
|
}
|
||||||
|
if scanContext.rowNum == 1 && GlobalConfig.StrictFieldMapping {
|
||||||
|
scanContext.EnsureEveryFieldMapped()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = rows.Close()
|
err = rows.Close()
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,8 @@ type ScanContext struct {
|
||||||
typesVisited typeStack // to prevent circular dependency scan
|
typesVisited typeStack // to prevent circular dependency scan
|
||||||
columnAlias []string
|
columnAlias []string
|
||||||
columnIndexRead []bool
|
columnIndexRead []bool
|
||||||
|
|
||||||
|
unmappedFields []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewScanContext creates new ScanContext from rows
|
// NewScanContext creates new ScanContext from rows
|
||||||
|
|
@ -79,6 +81,33 @@ func (s *ScanContext) EnsureEveryColumnRead() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ScanContext) recordUnmappedField(structType reflect.Type, parentField *reflect.StructField, field reflect.StructField) {
|
||||||
|
// skip private/unsettable fields (those are ignored by mapRowToStruct anyway)
|
||||||
|
if field.PkgPath != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: For unnamed/anonymous structs, Name() is empty, so String() is used for readability/uniqueness.
|
||||||
|
typeName := structType.String()
|
||||||
|
if structType.Name() != "" {
|
||||||
|
typeName = structType.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldIdent := fmt.Sprintf("%s.%s", typeName, field.Name)
|
||||||
|
if parentField != nil {
|
||||||
|
fieldIdent = fmt.Sprintf("%s.%s.%s", typeName, parentField.Name, field.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.unmappedFields = append(s.unmappedFields, fmt.Sprintf("'%s'", fieldIdent))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ScanContext) EnsureEveryFieldMapped() {
|
||||||
|
if len(s.unmappedFields) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
panic("jet: fields never mapped: " + strings.Join(s.unmappedFields, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
func createScanSlice(columnCount int) []interface{} {
|
func createScanSlice(columnCount int) []interface{} {
|
||||||
scanPtrSlice := make([]interface{}, columnCount)
|
scanPtrSlice := make([]interface{}, columnCount)
|
||||||
|
|
||||||
|
|
@ -144,6 +173,10 @@ func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.
|
||||||
fieldMap.Type = simpleType
|
fieldMap.Type = simpleType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if GlobalConfig.StrictFieldMapping && fieldMap.rowIndex == -1 && fieldMap.Type != complexType {
|
||||||
|
s.recordUnmappedField(structType, parentField, field)
|
||||||
|
}
|
||||||
|
|
||||||
newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap)
|
newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,15 +4,17 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
|
||||||
"github.com/go-jet/jet/v2/internal/utils/throw"
|
"github.com/go-jet/jet/v2/internal/utils/throw"
|
||||||
|
"github.com/go-jet/jet/v2/qrm"
|
||||||
"github.com/go-jet/jet/v2/sqlite"
|
"github.com/go-jet/jet/v2/sqlite"
|
||||||
"github.com/go-jet/jet/v2/stmtcache"
|
"github.com/go-jet/jet/v2/stmtcache"
|
||||||
"github.com/go-jet/jet/v2/tests/dbconfig"
|
"github.com/go-jet/jet/v2/tests/dbconfig"
|
||||||
"github.com/pkg/profile"
|
"github.com/pkg/profile"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
@ -57,6 +59,20 @@ func TestMain(m *testing.M) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func allowUnmappedFields(f func()) {
|
||||||
|
previous := qrm.GlobalConfig.StrictFieldMapping
|
||||||
|
defer func() { qrm.GlobalConfig.StrictFieldMapping = previous }()
|
||||||
|
qrm.GlobalConfig.StrictFieldMapping = false
|
||||||
|
f()
|
||||||
|
}
|
||||||
|
|
||||||
|
func requireStrictFieldMapping(f func()) {
|
||||||
|
previous := qrm.GlobalConfig.StrictFieldMapping
|
||||||
|
defer func() { qrm.GlobalConfig.StrictFieldMapping = previous }()
|
||||||
|
qrm.GlobalConfig.StrictFieldMapping = true
|
||||||
|
f()
|
||||||
|
}
|
||||||
|
|
||||||
func runCount(stmtCaching bool) int {
|
func runCount(stmtCaching bool) int {
|
||||||
if stmtCaching {
|
if stmtCaching {
|
||||||
return 4
|
return 4
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,14 @@ package sqlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/go-jet/jet/v2/internal/utils/ptr"
|
|
||||||
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/model"
|
|
||||||
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/table"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-jet/jet/v2/internal/utils/ptr"
|
||||||
|
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/model"
|
||||||
|
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/table"
|
||||||
|
|
||||||
"github.com/go-jet/jet/v2/internal/testutils"
|
"github.com/go-jet/jet/v2/internal/testutils"
|
||||||
. "github.com/go-jet/jet/v2/sqlite"
|
. "github.com/go-jet/jet/v2/sqlite"
|
||||||
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model"
|
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model"
|
||||||
|
|
@ -43,6 +44,151 @@ WHERE actor.actor_id = ?;
|
||||||
requireQueryLogged(t, query, 1)
|
requireQueryLogged(t, query, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStrictFieldMapping(t *testing.T) {
|
||||||
|
queryAll := SELECT(
|
||||||
|
Actor.AllColumns,
|
||||||
|
).FROM(
|
||||||
|
Actor,
|
||||||
|
).WHERE(
|
||||||
|
Actor.ActorID.EQ(Int(2)),
|
||||||
|
).LIMIT(1)
|
||||||
|
|
||||||
|
testutils.AssertStatementSql(t, queryAll, `
|
||||||
|
SELECT actor.actor_id AS "actor.actor_id",
|
||||||
|
actor.first_name AS "actor.first_name",
|
||||||
|
actor.last_name AS "actor.last_name",
|
||||||
|
actor.last_update AS "actor.last_update"
|
||||||
|
FROM actor
|
||||||
|
WHERE actor.actor_id = ?
|
||||||
|
LIMIT ?;
|
||||||
|
`, int64(2), int64(1))
|
||||||
|
|
||||||
|
queryPartial := SELECT(
|
||||||
|
Actor.ActorID,
|
||||||
|
Actor.FirstName,
|
||||||
|
).FROM(
|
||||||
|
Actor,
|
||||||
|
).WHERE(
|
||||||
|
Actor.ActorID.EQ(Int(2)),
|
||||||
|
).LIMIT(1)
|
||||||
|
|
||||||
|
testutils.AssertStatementSql(t, queryPartial, `
|
||||||
|
SELECT actor.actor_id AS "actor.actor_id",
|
||||||
|
actor.first_name AS "actor.first_name"
|
||||||
|
FROM actor
|
||||||
|
WHERE actor.actor_id = ?
|
||||||
|
LIMIT ?;
|
||||||
|
`, int64(2), int64(1))
|
||||||
|
|
||||||
|
// Destination model mapped via explicit field aliases ("actor.*").
|
||||||
|
type AliasedActor struct {
|
||||||
|
ActorID int32 `alias:"actor.actor_id"`
|
||||||
|
FirstName string `alias:"actor.first_name"`
|
||||||
|
LastName string `alias:"actor.last_name"`
|
||||||
|
LastUpdate string `alias:"actor.last_update"`
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("all columns scan succeeds for generated model", func(t *testing.T) {
|
||||||
|
allowUnmappedFields(func() {
|
||||||
|
var dest model.Actor
|
||||||
|
require.NoError(t, queryAll.Query(db, &dest))
|
||||||
|
})
|
||||||
|
requireStrictFieldMapping(func() {
|
||||||
|
var dest model.Actor
|
||||||
|
require.NoError(t, queryAll.Query(db, &dest))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("all columns scan succeeds for aliased destination", func(t *testing.T) {
|
||||||
|
allowUnmappedFields(func() {
|
||||||
|
var dest []AliasedActor
|
||||||
|
require.NoError(t, queryAll.Query(db, &dest))
|
||||||
|
require.Len(t, dest, 1)
|
||||||
|
})
|
||||||
|
requireStrictFieldMapping(func() {
|
||||||
|
var dest []AliasedActor
|
||||||
|
require.NoError(t, queryAll.Query(db, &dest))
|
||||||
|
require.Len(t, dest, 1)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("partial columns panics in strict mode for generated model", func(t *testing.T) {
|
||||||
|
allowUnmappedFields(func() {
|
||||||
|
var dest []model.Actor
|
||||||
|
require.NoError(t, queryPartial.Query(db, &dest))
|
||||||
|
require.Len(t, dest, 1)
|
||||||
|
})
|
||||||
|
requireStrictFieldMapping(func() {
|
||||||
|
require.PanicsWithValue(t, "jet: fields never mapped: 'Actor.LastName', 'Actor.LastUpdate'", func() {
|
||||||
|
var dest []model.Actor
|
||||||
|
_ = queryPartial.Query(db, &dest)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("partial columns panics in strict mode for aliased destination", func(t *testing.T) {
|
||||||
|
allowUnmappedFields(func() {
|
||||||
|
var dest []AliasedActor
|
||||||
|
require.NoError(t, queryPartial.Query(db, &dest))
|
||||||
|
require.Len(t, dest, 1)
|
||||||
|
})
|
||||||
|
requireStrictFieldMapping(func() {
|
||||||
|
require.PanicsWithValue(t, "jet: fields never mapped: 'AliasedActor.LastName', 'AliasedActor.LastUpdate'", func() {
|
||||||
|
var dest []AliasedActor
|
||||||
|
_ = queryPartial.Query(db, &dest)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unexported fields are ignored by strict field mapping", func(t *testing.T) {
|
||||||
|
type Dest struct {
|
||||||
|
actorID int32 `alias:"actor.missing_column"`
|
||||||
|
}
|
||||||
|
|
||||||
|
requireStrictFieldMapping(func() {
|
||||||
|
var dest []Dest
|
||||||
|
require.NoError(t, queryAll.Query(db, &dest))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nested unmapped field uses parent field name in error", func(t *testing.T) {
|
||||||
|
type Inner struct {
|
||||||
|
Missing string `alias:"actor.missing_column"`
|
||||||
|
}
|
||||||
|
type Outer struct {
|
||||||
|
Child Inner
|
||||||
|
}
|
||||||
|
|
||||||
|
requireStrictFieldMapping(func() {
|
||||||
|
require.PanicsWithValue(t, "jet: fields never mapped: 'Inner.Child.Missing'", func() {
|
||||||
|
var dest []Outer
|
||||||
|
_ = queryAll.Query(db, &dest)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Rows.Scan triggers strict field mapping check", func(t *testing.T) {
|
||||||
|
type ActorLiteMissing struct {
|
||||||
|
ActorID int32 `alias:"actor.actor_id"`
|
||||||
|
FirstName string `alias:"actor.first_name"`
|
||||||
|
LastName string `alias:"actor.last_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
requireStrictFieldMapping(func() {
|
||||||
|
rows, err := queryPartial.Rows(context.Background(), db)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, rows.Next())
|
||||||
|
|
||||||
|
require.PanicsWithValue(t, "jet: fields never mapped: 'ActorLiteMissing.LastName'", func() {
|
||||||
|
var dest ActorLiteMissing
|
||||||
|
_ = rows.Scan(&dest)
|
||||||
|
})
|
||||||
|
|
||||||
|
_ = rows.Close()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
var actor2 = model.Actor{
|
var actor2 = model.Actor{
|
||||||
ActorID: 2,
|
ActorID: 2,
|
||||||
FirstName: "NICK",
|
FirstName: "NICK",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue