Additional scan performance improvements
Move typeStack to ScanContext, so it is shared between rows.Scan calls. Use string.Builder for string concatenations. Simplify value assign logic. Move convert value to the last assign step (needs for type conversions are rare).
This commit is contained in:
parent
c10244aeab
commit
c86903fd1d
8 changed files with 428 additions and 174 deletions
60
qrm/qrm.go
60
qrm/qrm.go
|
|
@ -74,15 +74,15 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac
|
|||
err := rows.Scan(scanContext.row...)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows scan error, %w", err)
|
||||
return fmt.Errorf("jet: rows scan error, %w", err)
|
||||
}
|
||||
|
||||
destValue := reflect.ValueOf(destPtr)
|
||||
destValuePtr := reflect.ValueOf(destPtr)
|
||||
|
||||
_, err = mapRowToStruct(scanContext, "", newTypeStack(), destValue, nil)
|
||||
_, err = mapRowToStruct(scanContext, "", destValuePtr, nil)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to map a row, %w", err)
|
||||
return fmt.Errorf("jet: failed to scan a row into destination, %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
@ -121,7 +121,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
|
|||
|
||||
scanContext.rowNum++
|
||||
|
||||
_, err = mapRowToSlice(scanContext, "", newTypeStack(), slicePtrValue, nil)
|
||||
_, err = mapRowToSlice(scanContext, "", slicePtrValue, nil)
|
||||
|
||||
if err != nil {
|
||||
return scanContext.rowNum, err
|
||||
|
|
@ -139,7 +139,6 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
|
|||
func mapRowToSlice(
|
||||
scanContext *ScanContext,
|
||||
groupKey string,
|
||||
typesVisited *typeStack,
|
||||
slicePtrValue reflect.Value,
|
||||
field *reflect.StructField) (updated bool, err error) {
|
||||
|
||||
|
|
@ -154,19 +153,19 @@ func mapRowToSlice(
|
|||
|
||||
structGroupKey := scanContext.getGroupKey(sliceElemType, field)
|
||||
|
||||
groupKey = groupKey + "," + structGroupKey
|
||||
groupKey = concat(groupKey, ",", structGroupKey)
|
||||
|
||||
index, ok := scanContext.uniqueDestObjectsMap[groupKey]
|
||||
|
||||
if ok {
|
||||
structPtrValue := getSliceElemPtrAt(slicePtrValue, index)
|
||||
|
||||
return mapRowToStruct(scanContext, groupKey, typesVisited, structPtrValue, field, true)
|
||||
return mapRowToStruct(scanContext, groupKey, structPtrValue, field, true)
|
||||
}
|
||||
|
||||
destinationStructPtr := newElemPtrValueForSlice(slicePtrValue)
|
||||
|
||||
updated, err = mapRowToStruct(scanContext, groupKey, typesVisited, destinationStructPtr, field)
|
||||
updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -192,7 +191,7 @@ func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value
|
|||
return
|
||||
}
|
||||
}
|
||||
rowElemPtr := scanContext.rowElemValuePtr(index)
|
||||
rowElemPtr := scanContext.rowElemValueClonePtr(index)
|
||||
|
||||
if rowElemPtr.IsValid() && !rowElemPtr.IsNil() {
|
||||
updated = true
|
||||
|
|
@ -208,7 +207,6 @@ func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value
|
|||
func mapRowToStruct(
|
||||
scanContext *ScanContext,
|
||||
groupKey string,
|
||||
typesVisited *typeStack, // to prevent circular dependency scan
|
||||
structPtrValue reflect.Value,
|
||||
parentField *reflect.StructField,
|
||||
onlySlices ...bool, // small optimization, not to assign to already assigned struct fields
|
||||
|
|
@ -217,12 +215,12 @@ func mapRowToStruct(
|
|||
mapOnlySlices := len(onlySlices) > 0
|
||||
structType := structPtrValue.Type().Elem()
|
||||
|
||||
if typesVisited.contains(&structType) {
|
||||
if scanContext.typesVisited.contains(&structType) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
typesVisited.push(&structType)
|
||||
defer typesVisited.pop()
|
||||
scanContext.typesVisited.push(&structType)
|
||||
defer scanContext.typesVisited.pop()
|
||||
|
||||
typeInf := scanContext.getTypeInfo(structType, parentField)
|
||||
|
||||
|
|
@ -240,7 +238,7 @@ func mapRowToStruct(
|
|||
|
||||
if fieldMap.complexType {
|
||||
var changed bool
|
||||
changed, err = mapRowToDestinationValue(scanContext, groupKey, typesVisited, fieldValue, &field)
|
||||
changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -251,34 +249,36 @@ func mapRowToStruct(
|
|||
}
|
||||
|
||||
} else {
|
||||
if mapOnlySlices || fieldMap.columnIndex == -1 {
|
||||
if mapOnlySlices || fieldMap.rowIndex == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
cellValue := scanContext.rowElem(fieldMap.columnIndex)
|
||||
scannedValue := scanContext.rowElemValue(fieldMap.rowIndex)
|
||||
|
||||
if cellValue == nil {
|
||||
if !scannedValue.IsValid() {
|
||||
setZeroValue(fieldValue) // scannedValue is nil, destination should be set to zero value
|
||||
continue
|
||||
}
|
||||
|
||||
initializeValueIfNilPtr(fieldValue)
|
||||
updated = true
|
||||
|
||||
if fieldMap.implementsScanner {
|
||||
scanner := getScanner(fieldValue)
|
||||
initializeValueIfNilPtr(fieldValue)
|
||||
fieldScanner := getScanner(fieldValue)
|
||||
|
||||
err = scanner.Scan(cellValue)
|
||||
value := scannedValue.Interface()
|
||||
|
||||
err := fieldScanner.Scan(value)
|
||||
|
||||
if err != nil {
|
||||
err = fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err)
|
||||
return
|
||||
return updated, fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, value, value, field.Name, field.Type.String(), err)
|
||||
}
|
||||
} else {
|
||||
err = setReflectValue(reflect.ValueOf(cellValue), fieldValue)
|
||||
err := assign(scannedValue, fieldValue)
|
||||
|
||||
if err != nil {
|
||||
err = fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err)
|
||||
return
|
||||
return updated, fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, scannedValue.Interface(), scannedValue.Interface(),
|
||||
field.Name, field.Type.String(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -290,7 +290,6 @@ func mapRowToStruct(
|
|||
func mapRowToDestinationValue(
|
||||
scanContext *ScanContext,
|
||||
groupKey string,
|
||||
typesVisited *typeStack,
|
||||
dest reflect.Value,
|
||||
structField *reflect.StructField) (updated bool, err error) {
|
||||
|
||||
|
|
@ -306,7 +305,7 @@ func mapRowToDestinationValue(
|
|||
}
|
||||
}
|
||||
|
||||
updated, err = mapRowToDestinationPtr(scanContext, groupKey, typesVisited, destPtrValue, structField)
|
||||
updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
|
|
@ -322,7 +321,6 @@ func mapRowToDestinationValue(
|
|||
func mapRowToDestinationPtr(
|
||||
scanContext *ScanContext,
|
||||
groupKey string,
|
||||
typesVisited *typeStack,
|
||||
destPtrValue reflect.Value,
|
||||
structField *reflect.StructField) (updated bool, err error) {
|
||||
|
||||
|
|
@ -331,9 +329,9 @@ func mapRowToDestinationPtr(
|
|||
destValueKind := destPtrValue.Elem().Kind()
|
||||
|
||||
if destValueKind == reflect.Struct {
|
||||
return mapRowToStruct(scanContext, groupKey, typesVisited, destPtrValue, structField)
|
||||
return mapRowToStruct(scanContext, groupKey, destPtrValue, structField)
|
||||
} else if destValueKind == reflect.Slice {
|
||||
return mapRowToSlice(scanContext, groupKey, typesVisited, destPtrValue, structField)
|
||||
return mapRowToSlice(scanContext, groupKey, destPtrValue, structField)
|
||||
} else {
|
||||
panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ type ScanContext struct {
|
|||
commonIdentToColumnIndex map[string]int
|
||||
groupKeyInfoCache map[string]groupKeyInfo
|
||||
typeInfoMap map[string]typeInfo
|
||||
|
||||
typesVisited typeStack // to prevent circular dependency scan
|
||||
}
|
||||
|
||||
// NewScanContext creates new ScanContext from rows
|
||||
|
|
@ -39,7 +41,7 @@ func NewScanContext(rows *sql.Rows) (*ScanContext, error) {
|
|||
commonIdentifier := toCommonIdentifier(names[0])
|
||||
|
||||
if len(names) > 1 {
|
||||
commonIdentifier += "." + toCommonIdentifier(names[1])
|
||||
commonIdentifier = concat(commonIdentifier, ".", toCommonIdentifier(names[1]))
|
||||
}
|
||||
|
||||
commonIdentToColumnIndex[commonIdentifier] = i
|
||||
|
|
@ -53,15 +55,17 @@ func NewScanContext(rows *sql.Rows) (*ScanContext, error) {
|
|||
commonIdentToColumnIndex: commonIdentToColumnIndex,
|
||||
|
||||
typeInfoMap: make(map[string]typeInfo),
|
||||
|
||||
typesVisited: newTypeStack(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func createScanSlice(columnCount int) []interface{} {
|
||||
scanSlice := make([]interface{}, columnCount)
|
||||
scanPtrSlice := make([]interface{}, columnCount)
|
||||
|
||||
for i := range scanPtrSlice {
|
||||
scanPtrSlice[i] = &scanSlice[i] // if destination is pointer to interface sql.Scan will just forward driver value
|
||||
var a interface{}
|
||||
scanPtrSlice[i] = &a // if destination is pointer to interface sql.Scan will just forward driver value
|
||||
}
|
||||
|
||||
return scanPtrSlice
|
||||
|
|
@ -72,8 +76,8 @@ type typeInfo struct {
|
|||
}
|
||||
|
||||
type fieldMapping struct {
|
||||
complexType bool // slice or struct
|
||||
columnIndex int
|
||||
complexType bool // slice and struct are complex types
|
||||
rowIndex int // index in ScanContext.row
|
||||
implementsScanner bool
|
||||
}
|
||||
|
||||
|
|
@ -82,7 +86,7 @@ func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.
|
|||
typeMapKey := structType.String()
|
||||
|
||||
if parentField != nil {
|
||||
typeMapKey += string(parentField.Tag)
|
||||
typeMapKey = concat(typeMapKey, string(parentField.Tag))
|
||||
}
|
||||
|
||||
if typeInfo, ok := s.typeInfoMap[typeMapKey]; ok {
|
||||
|
|
@ -100,7 +104,7 @@ func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.
|
|||
columnIndex := s.typeToColumnIndex(newTypeName, fieldName)
|
||||
|
||||
fieldMap := fieldMapping{
|
||||
columnIndex: columnIndex,
|
||||
rowIndex: columnIndex,
|
||||
}
|
||||
|
||||
if implementsScannerType(field.Type) {
|
||||
|
|
@ -128,14 +132,15 @@ func (s *ScanContext) getGroupKey(structType reflect.Type, structField *reflect.
|
|||
mapKey := structType.Name()
|
||||
|
||||
if structField != nil {
|
||||
mapKey += structField.Type.String()
|
||||
mapKey = concat(mapKey, structField.Type.String())
|
||||
}
|
||||
|
||||
if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok {
|
||||
return s.constructGroupKey(groupKeyInfo)
|
||||
}
|
||||
|
||||
groupKeyInfo := s.getGroupKeyInfo(structType, structField, newTypeStack())
|
||||
tempTypeStack := newTypeStack()
|
||||
groupKeyInfo := s.getGroupKeyInfo(structType, structField, &tempTypeStack)
|
||||
|
||||
s.groupKeyInfoCache[mapKey] = groupKeyInfo
|
||||
|
||||
|
|
@ -150,10 +155,7 @@ func (s *ScanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
|
|||
var groupKeys []string
|
||||
|
||||
for _, index := range groupKeyInfo.indexes {
|
||||
cellValue := s.rowElem(index)
|
||||
subKey := valueToString(reflect.ValueOf(cellValue))
|
||||
|
||||
groupKeys = append(groupKeys, subKey)
|
||||
groupKeys = append(groupKeys, s.rowElemToString(index))
|
||||
}
|
||||
|
||||
var subTypesGroupKeys []string
|
||||
|
|
@ -161,7 +163,7 @@ func (s *ScanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
|
|||
subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType))
|
||||
}
|
||||
|
||||
return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")"
|
||||
return concat(groupKeyInfo.typeName, "(", strings.Join(groupKeys, ","), strings.Join(subTypesGroupKeys, ","), ")")
|
||||
}
|
||||
|
||||
func (s *ScanContext) getGroupKeyInfo(
|
||||
|
|
@ -231,32 +233,36 @@ func (s *ScanContext) typeToColumnIndex(typeName, fieldName string) int {
|
|||
return index
|
||||
}
|
||||
|
||||
func (s *ScanContext) rowElem(index int) interface{} {
|
||||
cellValue := reflect.ValueOf(s.row[index])
|
||||
|
||||
if cellValue.IsValid() && !cellValue.IsNil() {
|
||||
return cellValue.Elem().Interface()
|
||||
// rowElemValue always returns non-ptr value,
|
||||
// invalid value is nil
|
||||
func (s *ScanContext) rowElemValue(index int) reflect.Value {
|
||||
scannedValue := reflect.ValueOf(s.row[index])
|
||||
return scannedValue.Elem().Elem() // no need to check validity of Elem, because s.row[index] always contains interface in interface
|
||||
}
|
||||
|
||||
return nil
|
||||
func (s *ScanContext) rowElemToString(index int) string {
|
||||
value := s.rowElemValue(index)
|
||||
|
||||
if !value.IsValid() {
|
||||
return "nil"
|
||||
}
|
||||
|
||||
func (s *ScanContext) rowElemValuePtr(index int) reflect.Value {
|
||||
rowElem := s.rowElem(index)
|
||||
rowElemValue := reflect.ValueOf(rowElem)
|
||||
valueInterface := value.Interface()
|
||||
|
||||
if t, ok := valueInterface.(fmt.Stringer); ok {
|
||||
return t.String()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%#v", valueInterface)
|
||||
}
|
||||
|
||||
func (s *ScanContext) rowElemValueClonePtr(index int) reflect.Value {
|
||||
rowElemValue := s.rowElemValue(index)
|
||||
|
||||
if !rowElemValue.IsValid() {
|
||||
return reflect.Value{}
|
||||
}
|
||||
|
||||
if rowElemValue.Kind() == reflect.Ptr {
|
||||
return rowElemValue
|
||||
}
|
||||
|
||||
if rowElemValue.CanAddr() {
|
||||
return rowElemValue.Addr()
|
||||
}
|
||||
|
||||
newElem := reflect.New(rowElemValue.Type())
|
||||
newElem.Elem().Set(rowElemValue)
|
||||
return newElem
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@ import "reflect"
|
|||
|
||||
type typeStack []*reflect.Type
|
||||
|
||||
func newTypeStack() *typeStack {
|
||||
func newTypeStack() typeStack {
|
||||
stack := make(typeStack, 0, 20)
|
||||
return &stack
|
||||
return stack
|
||||
}
|
||||
|
||||
func (s *typeStack) isEmpty() bool {
|
||||
|
|
|
|||
139
qrm/utill.go
139
qrm/utill.go
|
|
@ -18,9 +18,9 @@ func implementsScannerType(fieldType reflect.Type) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
typePtr := reflect.New(fieldType).Type()
|
||||
fieldTypePtr := reflect.New(fieldType).Type()
|
||||
|
||||
return typePtr.Implements(scannerInterfaceType)
|
||||
return fieldTypePtr.Implements(scannerInterfaceType)
|
||||
}
|
||||
|
||||
func getScanner(value reflect.Value) sql.Scanner {
|
||||
|
|
@ -68,9 +68,9 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
|
|||
|
||||
if newSliceElemValue.Kind() == reflect.Ptr {
|
||||
newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem()))
|
||||
err = tryAssign(objPtrValue.Elem(), newSliceElemValue.Elem())
|
||||
err = assign(objPtrValue.Elem(), newSliceElemValue.Elem())
|
||||
} else {
|
||||
err = tryAssign(objPtrValue.Elem(), newSliceElemValue)
|
||||
err = assign(objPtrValue.Elem(), newSliceElemValue)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -138,29 +138,6 @@ func initializeValueIfNilPtr(value reflect.Value) {
|
|||
}
|
||||
}
|
||||
|
||||
func valueToString(value reflect.Value) string {
|
||||
|
||||
if !value.IsValid() {
|
||||
return "nil"
|
||||
}
|
||||
|
||||
var valueInterface interface{}
|
||||
if value.Kind() == reflect.Ptr {
|
||||
if value.IsNil() {
|
||||
return "nil"
|
||||
}
|
||||
valueInterface = value.Elem().Interface()
|
||||
} else {
|
||||
valueInterface = value.Interface()
|
||||
}
|
||||
|
||||
if t, ok := valueInterface.(fmt.Stringer); ok {
|
||||
return t.String()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%#v", valueInterface)
|
||||
}
|
||||
|
||||
var timeType = reflect.TypeOf(time.Now())
|
||||
var uuidType = reflect.TypeOf(uuid.New())
|
||||
var byteArrayType = reflect.TypeOf([]byte(""))
|
||||
|
|
@ -180,30 +157,35 @@ func isSimpleModelType(objType reflect.Type) bool {
|
|||
return objType == timeType || objType == uuidType || objType == byteArrayType
|
||||
}
|
||||
|
||||
func isIntegerType(objType reflect.Type) bool {
|
||||
objType = indirectType(objType)
|
||||
|
||||
switch objType.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return true
|
||||
// source can't be pointer
|
||||
// destination can be pointer
|
||||
func assign(source, destination reflect.Value) error {
|
||||
if destination.Kind() == reflect.Ptr {
|
||||
if destination.IsNil() {
|
||||
initializeValueIfNilPtr(destination)
|
||||
}
|
||||
|
||||
return false
|
||||
destination = destination.Elem()
|
||||
}
|
||||
|
||||
func isFloatType(value reflect.Type) bool {
|
||||
switch value.Kind() {
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return true
|
||||
err := tryAssign(source, destination)
|
||||
|
||||
if err != nil {
|
||||
// needs for the type conversions are rare, so we leave conversion as a last assign step if everything else fails
|
||||
if tryConvert(source, destination) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return false
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func assignIfAssignable(source, destination reflect.Value) bool {
|
||||
if source.Type().AssignableTo(destination.Type()) {
|
||||
switch source.Type() {
|
||||
sourceType := source.Type()
|
||||
if sourceType.AssignableTo(destination.Type()) {
|
||||
switch sourceType {
|
||||
case byteArrayType:
|
||||
destination.SetBytes(cloneBytes(source.Interface().([]byte)))
|
||||
default:
|
||||
|
|
@ -215,31 +197,17 @@ func assignIfAssignable(source, destination reflect.Value) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// source and destination are non-ptr values
|
||||
func tryAssign(source, destination reflect.Value) error {
|
||||
|
||||
if assignIfAssignable(source, destination) {
|
||||
return nil
|
||||
}
|
||||
|
||||
sourceType := source.Type()
|
||||
destinationType := destination.Type()
|
||||
|
||||
if sourceType != destinationType &&
|
||||
!isFloatType(destinationType) && // to preserve precision during conversion
|
||||
!(isIntegerType(sourceType) && destination.Kind() == reflect.String) && // default conversion will convert int to 1 rune string
|
||||
sourceType.ConvertibleTo(destinationType) {
|
||||
|
||||
source = source.Convert(destinationType)
|
||||
}
|
||||
|
||||
if assignIfAssignable(source, destination) {
|
||||
return nil
|
||||
}
|
||||
|
||||
sourceInterface := source.Interface()
|
||||
|
||||
switch destination.Interface().(type) {
|
||||
case bool:
|
||||
switch destination.Type().Kind() {
|
||||
case reflect.Bool:
|
||||
var nullBool internal.NullBool
|
||||
|
||||
err := nullBool.Scan(sourceInterface)
|
||||
|
|
@ -250,7 +218,7 @@ func tryAssign(source, destination reflect.Value) error {
|
|||
|
||||
destination.SetBool(nullBool.Bool)
|
||||
|
||||
case float32, float64:
|
||||
case reflect.Float32, reflect.Float64:
|
||||
var nullFloat sql.NullFloat64
|
||||
|
||||
err := nullFloat.Scan(sourceInterface)
|
||||
|
|
@ -261,7 +229,7 @@ func tryAssign(source, destination reflect.Value) error {
|
|||
if nullFloat.Valid {
|
||||
destination.SetFloat(nullFloat.Float64)
|
||||
}
|
||||
case int, int8, int16, int32, int64:
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
var integer sql.NullInt64
|
||||
|
||||
err := integer.Scan(sourceInterface)
|
||||
|
|
@ -273,7 +241,7 @@ func tryAssign(source, destination reflect.Value) error {
|
|||
destination.SetInt(integer.Int64)
|
||||
}
|
||||
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
var uInt internal.NullUInt64
|
||||
|
||||
err := uInt.Scan(sourceInterface)
|
||||
|
|
@ -286,7 +254,7 @@ func tryAssign(source, destination reflect.Value) error {
|
|||
destination.SetUint(uInt.UInt64)
|
||||
}
|
||||
|
||||
case string:
|
||||
case reflect.String:
|
||||
var str sql.NullString
|
||||
|
||||
err := str.Scan(sourceInterface)
|
||||
|
|
@ -298,6 +266,8 @@ func tryAssign(source, destination reflect.Value) error {
|
|||
destination.SetString(str.String)
|
||||
}
|
||||
|
||||
default:
|
||||
switch destination.Interface().(type) {
|
||||
case time.Time:
|
||||
var nullTime internal.NullTime
|
||||
|
||||
|
|
@ -309,42 +279,31 @@ func tryAssign(source, destination reflect.Value) error {
|
|||
if nullTime.Valid {
|
||||
destination.Set(reflect.ValueOf(nullTime.Time))
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func tryConvert(source, destination reflect.Value) bool {
|
||||
destinationType := destination.Type()
|
||||
|
||||
if source.Type().ConvertibleTo(destinationType) {
|
||||
source = source.Convert(destinationType)
|
||||
return assignIfAssignable(source, destination)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func setZeroValue(value reflect.Value) {
|
||||
if !value.IsZero() {
|
||||
value.Set(reflect.Zero(value.Type()))
|
||||
}
|
||||
}
|
||||
|
||||
func setReflectValue(source, destination reflect.Value) error {
|
||||
|
||||
if source.Kind() == reflect.Ptr {
|
||||
if source.IsNil() {
|
||||
// source is nil, destination should be its zero value
|
||||
setZeroValue(destination)
|
||||
return nil
|
||||
}
|
||||
source = source.Elem()
|
||||
}
|
||||
|
||||
if destination.Kind() == reflect.Ptr {
|
||||
if destination.IsNil() {
|
||||
initializeValueIfNilPtr(destination)
|
||||
}
|
||||
|
||||
destination = destination.Elem()
|
||||
}
|
||||
|
||||
return tryAssign(source, destination)
|
||||
}
|
||||
|
||||
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
|
||||
if len(primaryKeyOverwrites) > 0 {
|
||||
return utils.StringSliceContains(primaryKeyOverwrites, field.Name)
|
||||
|
|
@ -398,3 +357,11 @@ func cloneBytes(b []byte) []byte {
|
|||
copy(c, b)
|
||||
return c
|
||||
}
|
||||
|
||||
func concat(stringList ...string) string {
|
||||
var b strings.Builder
|
||||
for _, str := range stringList {
|
||||
b.WriteString(str)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -206,7 +206,7 @@ GROUP BY payment.customer_id;
|
|||
"RentalID": null,
|
||||
"Amount": 0,
|
||||
"PaymentDate": "0001-01-01T00:00:00Z",
|
||||
"LastUpdate": "0001-01-01T00:00:00Z",
|
||||
"LastUpdate": null,
|
||||
"Count": 8,
|
||||
"Sum": 38.92,
|
||||
"Avg": 4.865,
|
||||
|
|
@ -964,7 +964,6 @@ func TestRowsScan(t *testing.T) {
|
|||
rows, err := stmt.Rows(context.Background(), db)
|
||||
require.NoError(t, err)
|
||||
|
||||
for rows.Next() {
|
||||
var inventory struct {
|
||||
model.Inventory
|
||||
|
||||
|
|
@ -972,6 +971,7 @@ func TestRowsScan(t *testing.T) {
|
|||
Store model.Store
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&inventory)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -1056,3 +1056,50 @@ func TestScanNumericToNumber(t *testing.T) {
|
|||
require.Equal(t, number.Float32, float32(1.234568e+09))
|
||||
require.Equal(t, number.Float64, float64(1.23456789e+09))
|
||||
}
|
||||
|
||||
// scan into custom base types should be equivalent to the scan into base go types
|
||||
func TestScanIntoCustomBaseTypes(t *testing.T) {
|
||||
|
||||
type MyUint8 uint8
|
||||
type MyUint16 uint16
|
||||
type MyUint32 uint32
|
||||
type MyInt16 int16
|
||||
type MyFloat32 float32
|
||||
type MyFloat64 float64
|
||||
type MyString string
|
||||
type MyTime = time.Time
|
||||
|
||||
type film struct {
|
||||
FilmID MyUint16 `sql:"primary_key"`
|
||||
Title MyString
|
||||
Description *MyString
|
||||
ReleaseYear *MyInt16
|
||||
LanguageID MyUint8
|
||||
OriginalLanguageID *MyUint8
|
||||
RentalDuration MyUint8
|
||||
RentalRate MyFloat32
|
||||
Length *MyUint32
|
||||
ReplacementCost MyFloat64
|
||||
Rating *model.FilmRating
|
||||
SpecialFeatures *MyString
|
||||
LastUpdate MyTime
|
||||
}
|
||||
|
||||
stmt := SELECT(
|
||||
Film.AllColumns,
|
||||
).FROM(
|
||||
Film,
|
||||
).ORDER_BY(
|
||||
Film.FilmID.ASC(),
|
||||
).LIMIT(3)
|
||||
|
||||
var films []model.Film
|
||||
err := stmt.Query(db, &films)
|
||||
require.NoError(t, err)
|
||||
|
||||
var myFilms []film
|
||||
err = stmt.Query(db, &myFilms)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, testutils.ToJSON(films), testutils.ToJSON(myFilms))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -786,6 +786,123 @@ func TestRowsScan(t *testing.T) {
|
|||
requireQueryLogged(t, stmt, 0)
|
||||
}
|
||||
|
||||
func TestScanNullColumn(t *testing.T) {
|
||||
stmt := SELECT(
|
||||
Address.AllColumns,
|
||||
).FROM(
|
||||
Address,
|
||||
).WHERE(
|
||||
Address.Address2.IS_NULL(),
|
||||
)
|
||||
|
||||
var dest []model.Address
|
||||
|
||||
err := stmt.Query(db, &dest)
|
||||
require.NoError(t, err)
|
||||
testutils.AssertJSON(t, dest, `
|
||||
[
|
||||
{
|
||||
"AddressID": 1,
|
||||
"Address": "47 MySakila Drive",
|
||||
"Address2": null,
|
||||
"District": "Alberta",
|
||||
"CityID": 300,
|
||||
"PostalCode": "",
|
||||
"Phone": "",
|
||||
"LastUpdate": "2006-02-15T09:45:30Z"
|
||||
},
|
||||
{
|
||||
"AddressID": 2,
|
||||
"Address": "28 MySQL Boulevard",
|
||||
"Address2": null,
|
||||
"District": "QLD",
|
||||
"CityID": 576,
|
||||
"PostalCode": "",
|
||||
"Phone": "",
|
||||
"LastUpdate": "2006-02-15T09:45:30Z"
|
||||
},
|
||||
{
|
||||
"AddressID": 3,
|
||||
"Address": "23 Workhaven Lane",
|
||||
"Address2": null,
|
||||
"District": "Alberta",
|
||||
"CityID": 300,
|
||||
"PostalCode": "",
|
||||
"Phone": "14033335568",
|
||||
"LastUpdate": "2006-02-15T09:45:30Z"
|
||||
},
|
||||
{
|
||||
"AddressID": 4,
|
||||
"Address": "1411 Lillydale Drive",
|
||||
"Address2": null,
|
||||
"District": "QLD",
|
||||
"CityID": 576,
|
||||
"PostalCode": "",
|
||||
"Phone": "6172235589",
|
||||
"LastUpdate": "2006-02-15T09:45:30Z"
|
||||
}
|
||||
]
|
||||
`)
|
||||
}
|
||||
|
||||
func TestRowsScanSetZeroValue(t *testing.T) {
|
||||
stmt := SELECT(
|
||||
Rental.AllColumns,
|
||||
).FROM(
|
||||
Rental,
|
||||
).WHERE(
|
||||
Rental.RentalID.IN(Int(16049), Int(15966)),
|
||||
).ORDER_BY(
|
||||
Rental.RentalID.DESC(),
|
||||
)
|
||||
|
||||
rows, err := stmt.Rows(context.Background(), db)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
// destination object is used as destination for all rows scan.
|
||||
// this tests checks that ReturnedDate is set to nil with the second call
|
||||
// check qrm.setZeroValue
|
||||
var dest model.Rental
|
||||
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&dest)
|
||||
require.NoError(t, err)
|
||||
|
||||
if dest.RentalID == 16049 {
|
||||
testutils.AssertJSON(t, dest, `
|
||||
{
|
||||
"RentalID": 16049,
|
||||
"RentalDate": "2005-08-23T22:50:12Z",
|
||||
"InventoryID": 2666,
|
||||
"CustomerID": 393,
|
||||
"ReturnDate": "2005-08-30T01:01:12Z",
|
||||
"StaffID": 2,
|
||||
"LastUpdate": "2006-02-16T02:30:53Z"
|
||||
}
|
||||
`)
|
||||
} else {
|
||||
testutils.AssertJSON(t, dest, `
|
||||
{
|
||||
"RentalID": 15966,
|
||||
"RentalDate": "2006-02-14T15:16:03Z",
|
||||
"InventoryID": 4472,
|
||||
"CustomerID": 374,
|
||||
"ReturnDate": null,
|
||||
"StaffID": 1,
|
||||
"LastUpdate": "2006-02-16T02:30:53Z"
|
||||
}
|
||||
`)
|
||||
}
|
||||
}
|
||||
|
||||
err = rows.Close()
|
||||
require.NoError(t, err)
|
||||
err = rows.Err()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestScanNumericToFloat(t *testing.T) {
|
||||
type Number struct {
|
||||
Float32 float32
|
||||
|
|
@ -826,6 +943,54 @@ func TestScanNumericToIntegerError(t *testing.T) {
|
|||
|
||||
}
|
||||
|
||||
func TestScanIntoCustomBaseTypes(t *testing.T) {
|
||||
|
||||
type MyUint8 uint8
|
||||
type MyUint16 uint16
|
||||
type MyUint32 uint32
|
||||
type MyInt16 int16
|
||||
type MyFloat32 float32
|
||||
type MyFloat64 float64
|
||||
type MyString string
|
||||
type MyTime = time.Time
|
||||
|
||||
type film struct {
|
||||
FilmID MyUint16 `sql:"primary_key"`
|
||||
Title MyString
|
||||
Description *MyString
|
||||
ReleaseYear *MyInt16
|
||||
LanguageID MyUint8
|
||||
RentalDuration MyUint8
|
||||
RentalRate MyFloat32
|
||||
Length *MyUint32
|
||||
ReplacementCost MyFloat64
|
||||
Rating *model.MpaaRating
|
||||
LastUpdate MyTime
|
||||
SpecialFeatures *MyString
|
||||
Fulltext MyString
|
||||
}
|
||||
|
||||
stmt := SELECT(
|
||||
Film.AllColumns,
|
||||
).FROM(
|
||||
Film,
|
||||
).ORDER_BY(
|
||||
Film.FilmID.ASC(),
|
||||
).LIMIT(3)
|
||||
|
||||
var films []model.Film
|
||||
|
||||
err := stmt.Query(db, &films)
|
||||
require.NoError(t, err)
|
||||
|
||||
var myFilms []film
|
||||
|
||||
err = stmt.Query(db, &myFilms)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, testutils.ToJSON(films), testutils.ToJSON(myFilms))
|
||||
}
|
||||
|
||||
// QueryContext panic when the scanned value is nil and the destination is a slice of primitive
|
||||
// https://github.com/go-jet/jet/issues/91
|
||||
func TestScanToPrimitiveElementsSlice(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -2521,6 +2521,79 @@ func TestRecursionScanNx1(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
type StoreInfo struct {
|
||||
model.Store
|
||||
|
||||
Staffs ManagerInfo
|
||||
}
|
||||
|
||||
type ManagerInfo struct {
|
||||
model.Staff
|
||||
Store *StoreInfo
|
||||
}
|
||||
|
||||
func TestRecursionScan1x1(t *testing.T) {
|
||||
|
||||
stmt := SELECT(
|
||||
Store.AllColumns,
|
||||
Staff.AllColumns,
|
||||
).FROM(
|
||||
Store.
|
||||
INNER_JOIN(Staff, Staff.StaffID.EQ(Store.ManagerStaffID)),
|
||||
).ORDER_BY(
|
||||
Store.StoreID,
|
||||
)
|
||||
|
||||
var dest []StoreInfo
|
||||
|
||||
err := stmt.Query(db, &dest)
|
||||
require.NoError(t, err)
|
||||
testutils.AssertJSON(t, dest, `
|
||||
[
|
||||
{
|
||||
"StoreID": 1,
|
||||
"ManagerStaffID": 1,
|
||||
"AddressID": 1,
|
||||
"LastUpdate": "2006-02-15T09:57:12Z",
|
||||
"Staffs": {
|
||||
"StaffID": 1,
|
||||
"FirstName": "Mike",
|
||||
"LastName": "Hillyer",
|
||||
"AddressID": 3,
|
||||
"Email": "Mike.Hillyer@sakilastaff.com",
|
||||
"StoreID": 1,
|
||||
"Active": true,
|
||||
"Username": "Mike",
|
||||
"Password": "8cb2237d0679ca88db6464eac60da96345513964",
|
||||
"LastUpdate": "2006-05-16T16:13:11.79328Z",
|
||||
"Picture": "iVBORw0KWgo=",
|
||||
"Store": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"StoreID": 2,
|
||||
"ManagerStaffID": 2,
|
||||
"AddressID": 2,
|
||||
"LastUpdate": "2006-02-15T09:57:12Z",
|
||||
"Staffs": {
|
||||
"StaffID": 2,
|
||||
"FirstName": "Jon",
|
||||
"LastName": "Stephens",
|
||||
"AddressID": 4,
|
||||
"Email": "Jon.Stephens@sakilastaff.com",
|
||||
"StoreID": 2,
|
||||
"Active": true,
|
||||
"Username": "Jon",
|
||||
"Password": "8cb2237d0679ca88db6464eac60da96345513964",
|
||||
"LastUpdate": "2006-05-16T16:13:11.79328Z",
|
||||
"Picture": null,
|
||||
"Store": null
|
||||
}
|
||||
}
|
||||
]
|
||||
`)
|
||||
}
|
||||
|
||||
// In parameterized statements integer literals, like Int(num), are replaced with a placeholders. For some expressions,
|
||||
// postgres interpreter will not have enough information to deduce the type. If this is the case postgres returns an error.
|
||||
// Int8, Int16, .... functions will add automatic type cast over placeholder, so type deduction is always possible.
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package postgres
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/go-jet/jet/v2/internal/testutils"
|
||||
. "github.com/go-jet/jet/v2/postgres"
|
||||
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/northwind/model"
|
||||
|
|
@ -864,5 +863,4 @@ WHERE orders1."orders.order_id" < $1;
|
|||
err := stmt.Query(db, &dest)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, dest, 72)
|
||||
fmt.Println(len(dest))
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue