jet/qrm/scan_context.go

333 lines
8.1 KiB
Go
Raw Normal View History

2019-10-11 10:15:36 +02:00
package qrm
import (
"database/sql"
"fmt"
2019-10-11 10:15:36 +02:00
"reflect"
"strings"
)
// ScanContext contains information about current row processed, mapping from the row to the
// destination types and type grouping information.
type ScanContext struct {
rowNum int64
row []interface{}
uniqueDestObjectsMap map[string]int
commonIdentToColumnIndex map[string]int
groupKeyInfoCache map[string]groupKeyInfo
typeInfoMap map[string]typeInfo
typesVisited typeStack // to prevent circular dependency scan
columnAlias []string
columnIndexRead []bool
2025-12-29 23:27:27 +09:00
unmappedFields []string
2019-10-11 10:15:36 +02:00
}
// NewScanContext creates new ScanContext from rows
func NewScanContext(rows *sql.Rows) (*ScanContext, error) {
2019-10-11 10:15:36 +02:00
aliases, err := rows.Columns()
if err != nil {
return nil, err
}
columnTypes, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
commonIdentToColumnIndex := map[string]int{}
2019-10-11 10:15:36 +02:00
for i, alias := range aliases {
names := strings.SplitN(alias, ".", 2)
commonIdentifier := toCommonIdentifier(names[0])
2019-10-11 10:15:36 +02:00
if len(names) > 1 {
commonIdentifier = concat(commonIdentifier, ".", toCommonIdentifier(names[1]))
2019-10-11 10:15:36 +02:00
}
commonIdentToColumnIndex[commonIdentifier] = i
2019-10-11 10:15:36 +02:00
}
return &ScanContext{
row: createScanSlice(len(columnTypes)),
2019-10-11 10:15:36 +02:00
uniqueDestObjectsMap: make(map[string]int),
groupKeyInfoCache: make(map[string]groupKeyInfo),
commonIdentToColumnIndex: commonIdentToColumnIndex,
2019-10-11 10:15:36 +02:00
typeInfoMap: make(map[string]typeInfo),
typesVisited: newTypeStack(),
columnAlias: aliases,
columnIndexRead: make([]bool, len(aliases)),
2019-10-11 10:15:36 +02:00
}, nil
}
// EnsureEveryColumnRead panics when at least one column was never marked as
// read. The panic message lists the single-quoted aliases of those columns,
// in column order.
func (s *ScanContext) EnsureEveryColumnRead() {
	var unread []string

	for i := range s.columnIndexRead {
		if s.columnIndexRead[i] {
			continue
		}
		unread = append(unread, `'`+s.columnAlias[i]+`'`)
	}

	if len(unread) == 0 {
		return
	}

	panic("jet: columns never used: " + strings.Join(unread, ", "))
}
2025-12-29 23:27:27 +09:00
// recordUnmappedField remembers field of structType as having no column.
//
// The stored identifier is 'Type.Field', or 'Type.Parent.Field' when
// parentField is given. Fields with a non-empty PkgPath are skipped.
func (s *ScanContext) recordUnmappedField(structType reflect.Type, parentField *reflect.StructField, field reflect.StructField) {
	// A non-empty PkgPath marks an unexported field; per the original note
	// these are ignored by mapRowToStruct anyway.
	if field.PkgPath != "" {
		return
	}

	// Unnamed/anonymous structs have an empty Name(), so fall back to
	// String() to keep the identifier readable and unique.
	owner := structType.Name()
	if owner == "" {
		owner = structType.String()
	}

	var ident string
	if parentField == nil {
		ident = fmt.Sprintf("%s.%s", owner, field.Name)
	} else {
		ident = fmt.Sprintf("%s.%s.%s", owner, parentField.Name, field.Name)
	}

	s.unmappedFields = append(s.unmappedFields, fmt.Sprintf("'%s'", ident))
}
// EnsureEveryFieldMapped panics when any unmapped field has been recorded.
// The panic message lists the recorded identifiers in recording order.
func (s *ScanContext) EnsureEveryFieldMapped() {
	if len(s.unmappedFields) > 0 {
		panic("jet: fields never mapped: " + strings.Join(s.unmappedFields, ", "))
	}
}
// createScanSlice returns columnCount scan destinations, each one a distinct
// pointer to an empty interface value.
func createScanSlice(columnCount int) []interface{} {
	destinations := make([]interface{}, 0, columnCount)

	for len(destinations) < columnCount {
		// If the destination is a pointer to an interface, sql.Scan will just
		// forward the driver value.
		destinations = append(destinations, new(interface{}))
	}

	return destinations
}
2019-10-11 10:15:36 +02:00
// typeInfo holds the field mappings computed by getTypeInfo for one
// destination struct type.
type typeInfo struct {
// fieldMappings has one entry per struct field, appended in field-index order.
fieldMappings []fieldMapping
}
// fieldMappingType classifies a destination struct field; getTypeInfo assigns
// exactly one of the values below to every field.
type fieldMappingType int
const (
// simpleType is assigned when isSimpleModelType reports true for the field type
// and neither jsonUnmarshal nor implementsScanner applies.
simpleType fieldMappingType = iota
complexType // slice and struct are complex types supported
// implementsScanner is assigned when implementsScannerType reports true for the field type.
implementsScanner
// jsonUnmarshal is assigned when getTypeAndFieldName flags the field as a JSON unmarshaler.
jsonUnmarshal
)
2019-10-11 10:15:36 +02:00
type fieldMapping struct {
rowIndex int // index in ScanContext.row
Type fieldMappingType
2019-10-11 10:15:36 +02:00
}
func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo {
2019-10-11 10:15:36 +02:00
typeMapKey := structType.String()
if parentField != nil {
typeMapKey = concat(typeMapKey, string(parentField.Tag))
2019-10-11 10:15:36 +02:00
}
if typeInfo, ok := s.typeInfoMap[typeMapKey]; ok {
return typeInfo
}
typeName := getTypeName(structType, parentField)
newTypeInfo := typeInfo{}
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
newTypeName, fieldName, jsonUnmarshaler := getTypeAndFieldName(typeName, field)
2019-10-11 10:15:36 +02:00
columnIndex := s.typeToColumnIndex(newTypeName, fieldName)
fieldMap := fieldMapping{
rowIndex: columnIndex,
2019-10-11 10:15:36 +02:00
}
if jsonUnmarshaler {
fieldMap.Type = jsonUnmarshal
} else if implementsScannerType(field.Type) {
fieldMap.Type = implementsScanner
2019-10-11 10:15:36 +02:00
} else if !isSimpleModelType(field.Type) {
fieldMap.Type = complexType
} else {
fieldMap.Type = simpleType
2019-10-11 10:15:36 +02:00
}
2025-12-29 23:27:27 +09:00
if GlobalConfig.StrictFieldMapping && fieldMap.rowIndex == -1 && fieldMap.Type != complexType {
s.recordUnmappedField(structType, parentField, field)
}
2019-10-11 10:15:36 +02:00
newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap)
}
s.typeInfoMap[typeMapKey] = newTypeInfo
return newTypeInfo
}
// groupKeyInfo is the cached layout from which constructGroupKey renders the
// group key of a struct type for the current row.
type groupKeyInfo struct {
	// typeName is the struct type's Name(); it prefixes the rendered key.
	typeName string
	// pkIndexes lists the column indexes of the primary key fields that were
	// found in the result set.
	pkIndexes []int
	// subTypes holds the layouts of nested struct fields that contribute at
	// least one primary key index or nested layout of their own.
	subTypes []groupKeyInfo
}
func (s *ScanContext) getGroupKey(structType reflect.Type, structField *reflect.StructField) string {
2019-10-11 10:15:36 +02:00
mapKey := structType.Name()
if structField != nil {
mapKey = concat(mapKey, structField.Type.String(), string(structField.Tag))
2019-10-11 10:15:36 +02:00
}
if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok {
return s.constructGroupKey(groupKeyInfo)
}
tempTypeStack := newTypeStack()
groupKeyInfo := s.getGroupKeyInfo(structType, structField, &tempTypeStack)
2019-10-11 10:15:36 +02:00
s.groupKeyInfoCache[mapKey] = groupKeyInfo
return s.constructGroupKey(groupKeyInfo)
}
func (s *ScanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
2023-04-14 12:20:36 +02:00
if len(groupKeyInfo.pkIndexes) == 0 && len(groupKeyInfo.subTypes) == 0 {
return fmt.Sprintf("|ROW:%d|", s.rowNum)
2019-10-11 10:15:36 +02:00
}
var groupKeys []string
2019-10-11 10:15:36 +02:00
2023-04-14 12:20:36 +02:00
for _, index := range groupKeyInfo.pkIndexes {
groupKeys = append(groupKeys, s.rowElemToString(index))
2019-10-11 10:15:36 +02:00
}
var subTypesGroupKeys []string
2019-10-11 10:15:36 +02:00
for _, subType := range groupKeyInfo.subTypes {
subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType))
}
return concat(groupKeyInfo.typeName, "(", strings.Join(groupKeys, ","), strings.Join(subTypesGroupKeys, ","), ")")
2019-10-11 10:15:36 +02:00
}
func (s *ScanContext) getGroupKeyInfo(
structType reflect.Type,
parentField *reflect.StructField,
typeVisited *typeStack) groupKeyInfo {
2019-10-11 10:15:36 +02:00
ret := groupKeyInfo{typeName: structType.Name()}
if typeVisited.contains(&structType) {
return ret
}
typeVisited.push(&structType)
defer typeVisited.pop()
2019-10-11 10:15:36 +02:00
typeName := getTypeName(structType, parentField)
primaryKeyOverwrites := parentFieldPrimaryKeyOverwrite(parentField)
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
fieldType := indirectType(field.Type)
2022-12-16 23:17:26 +01:00
if isPrimaryKey(field, primaryKeyOverwrites) {
newTypeName, fieldName, _ := getTypeAndFieldName(typeName, field)
2019-10-11 10:15:36 +02:00
2023-04-14 12:20:36 +02:00
pkIndex := s.typeToColumnIndex(newTypeName, fieldName)
2019-10-11 10:15:36 +02:00
2023-04-14 12:20:36 +02:00
if pkIndex < 0 {
continue
2019-10-11 10:15:36 +02:00
}
2023-04-14 12:20:36 +02:00
ret.pkIndexes = append(ret.pkIndexes, pkIndex)
2023-04-14 12:20:36 +02:00
} else if fieldType.Kind() == reflect.Struct && fieldType != timeType {
2022-12-16 23:17:26 +01:00
subType := s.getGroupKeyInfo(fieldType, &field, typeVisited)
2023-04-14 12:20:36 +02:00
if len(subType.pkIndexes) != 0 || len(subType.subTypes) != 0 {
2022-12-16 23:17:26 +01:00
ret.subTypes = append(ret.subTypes, subType)
}
2019-10-11 10:15:36 +02:00
}
}
return ret
}
func (s *ScanContext) typeToColumnIndex(typeName, fieldName string) int {
2019-10-11 10:15:36 +02:00
var key string
if typeName != "" {
key = strings.ToLower(typeName + "." + fieldName)
} else {
key = strings.ToLower(fieldName)
}
index, ok := s.commonIdentToColumnIndex[key]
2019-10-11 10:15:36 +02:00
if !ok {
return -1
}
return index
}
// rowElemValue returns the scanned value of the column at index as a non-ptr
// reflect.Value; a nil column value yields the invalid (zero) reflect.Value.
// While the first row is being processed the column is also marked as read.
func (s *ScanContext) rowElemValue(index int) reflect.Value {
	if s.rowNum == 1 {
		s.columnIndexRead[index] = true
	}

	// s.row[index] always contains an interface in an interface, so neither
	// Elem call needs a validity check.
	destination := reflect.ValueOf(s.row[index])
	held := destination.Elem()

	return held.Elem()
}
func (s *ScanContext) rowElemToString(index int) string {
value := s.rowElemValue(index)
if !value.IsValid() {
return "nil"
}
valueInterface := value.Interface()
2019-10-11 10:15:36 +02:00
if t, ok := valueInterface.(fmt.Stringer); ok {
return t.String()
}
2019-10-11 10:15:36 +02:00
return fmt.Sprintf("%#v", valueInterface)
2019-10-11 10:15:36 +02:00
}
func (s *ScanContext) rowElemValueClonePtr(index int) reflect.Value {
rowElemValue := s.rowElemValue(index)
2019-10-11 10:15:36 +02:00
if !rowElemValue.IsValid() {
return reflect.Value{}
}
newElem := reflect.New(rowElemValue.Type())
newElem.Elem().Set(rowElemValue)
return newElem
2019-10-11 10:15:36 +02:00
}