Additional scan performance improvements

Move typeStack to ScanContext, so it is shared between rows.Scan calls.
Use string.Builder for string concatenations.
Simplify value assign logic.
Move convert value to the last assign step (needs for type conversions are rare).
This commit is contained in:
go-jet 2022-02-09 12:34:10 +01:00
parent c10244aeab
commit c86903fd1d
8 changed files with 428 additions and 174 deletions

View file

@ -74,15 +74,15 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac
err := rows.Scan(scanContext.row...) err := rows.Scan(scanContext.row...)
if err != nil { if err != nil {
return fmt.Errorf("rows scan error, %w", err) return fmt.Errorf("jet: rows scan error, %w", err)
} }
destValue := reflect.ValueOf(destPtr) destValuePtr := reflect.ValueOf(destPtr)
_, err = mapRowToStruct(scanContext, "", newTypeStack(), destValue, nil) _, err = mapRowToStruct(scanContext, "", destValuePtr, nil)
if err != nil { if err != nil {
return fmt.Errorf("failed to map a row, %w", err) return fmt.Errorf("jet: failed to scan a row into destination, %w", err)
} }
return nil return nil
@ -121,7 +121,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
scanContext.rowNum++ scanContext.rowNum++
_, err = mapRowToSlice(scanContext, "", newTypeStack(), slicePtrValue, nil) _, err = mapRowToSlice(scanContext, "", slicePtrValue, nil)
if err != nil { if err != nil {
return scanContext.rowNum, err return scanContext.rowNum, err
@ -139,7 +139,6 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
func mapRowToSlice( func mapRowToSlice(
scanContext *ScanContext, scanContext *ScanContext,
groupKey string, groupKey string,
typesVisited *typeStack,
slicePtrValue reflect.Value, slicePtrValue reflect.Value,
field *reflect.StructField) (updated bool, err error) { field *reflect.StructField) (updated bool, err error) {
@ -154,19 +153,19 @@ func mapRowToSlice(
structGroupKey := scanContext.getGroupKey(sliceElemType, field) structGroupKey := scanContext.getGroupKey(sliceElemType, field)
groupKey = groupKey + "," + structGroupKey groupKey = concat(groupKey, ",", structGroupKey)
index, ok := scanContext.uniqueDestObjectsMap[groupKey] index, ok := scanContext.uniqueDestObjectsMap[groupKey]
if ok { if ok {
structPtrValue := getSliceElemPtrAt(slicePtrValue, index) structPtrValue := getSliceElemPtrAt(slicePtrValue, index)
return mapRowToStruct(scanContext, groupKey, typesVisited, structPtrValue, field, true) return mapRowToStruct(scanContext, groupKey, structPtrValue, field, true)
} }
destinationStructPtr := newElemPtrValueForSlice(slicePtrValue) destinationStructPtr := newElemPtrValueForSlice(slicePtrValue)
updated, err = mapRowToStruct(scanContext, groupKey, typesVisited, destinationStructPtr, field) updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field)
if err != nil { if err != nil {
return return
@ -192,7 +191,7 @@ func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value
return return
} }
} }
rowElemPtr := scanContext.rowElemValuePtr(index) rowElemPtr := scanContext.rowElemValueClonePtr(index)
if rowElemPtr.IsValid() && !rowElemPtr.IsNil() { if rowElemPtr.IsValid() && !rowElemPtr.IsNil() {
updated = true updated = true
@ -208,7 +207,6 @@ func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value
func mapRowToStruct( func mapRowToStruct(
scanContext *ScanContext, scanContext *ScanContext,
groupKey string, groupKey string,
typesVisited *typeStack, // to prevent circular dependency scan
structPtrValue reflect.Value, structPtrValue reflect.Value,
parentField *reflect.StructField, parentField *reflect.StructField,
onlySlices ...bool, // small optimization, not to assign to already assigned struct fields onlySlices ...bool, // small optimization, not to assign to already assigned struct fields
@ -217,12 +215,12 @@ func mapRowToStruct(
mapOnlySlices := len(onlySlices) > 0 mapOnlySlices := len(onlySlices) > 0
structType := structPtrValue.Type().Elem() structType := structPtrValue.Type().Elem()
if typesVisited.contains(&structType) { if scanContext.typesVisited.contains(&structType) {
return false, nil return false, nil
} }
typesVisited.push(&structType) scanContext.typesVisited.push(&structType)
defer typesVisited.pop() defer scanContext.typesVisited.pop()
typeInf := scanContext.getTypeInfo(structType, parentField) typeInf := scanContext.getTypeInfo(structType, parentField)
@ -240,7 +238,7 @@ func mapRowToStruct(
if fieldMap.complexType { if fieldMap.complexType {
var changed bool var changed bool
changed, err = mapRowToDestinationValue(scanContext, groupKey, typesVisited, fieldValue, &field) changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field)
if err != nil { if err != nil {
return return
@ -251,34 +249,36 @@ func mapRowToStruct(
} }
} else { } else {
if mapOnlySlices || fieldMap.columnIndex == -1 { if mapOnlySlices || fieldMap.rowIndex == -1 {
continue continue
} }
cellValue := scanContext.rowElem(fieldMap.columnIndex) scannedValue := scanContext.rowElemValue(fieldMap.rowIndex)
if cellValue == nil { if !scannedValue.IsValid() {
setZeroValue(fieldValue) // scannedValue is nil, destination should be set to zero value
continue continue
} }
initializeValueIfNilPtr(fieldValue)
updated = true updated = true
if fieldMap.implementsScanner { if fieldMap.implementsScanner {
scanner := getScanner(fieldValue) initializeValueIfNilPtr(fieldValue)
fieldScanner := getScanner(fieldValue)
err = scanner.Scan(cellValue) value := scannedValue.Interface()
err := fieldScanner.Scan(value)
if err != nil { if err != nil {
err = fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err) return updated, fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, value, value, field.Name, field.Type.String(), err)
return
} }
} else { } else {
err = setReflectValue(reflect.ValueOf(cellValue), fieldValue) err := assign(scannedValue, fieldValue)
if err != nil { if err != nil {
err = fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err) return updated, fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, scannedValue.Interface(), scannedValue.Interface(),
return field.Name, field.Type.String(), err)
} }
} }
} }
@ -290,7 +290,6 @@ func mapRowToStruct(
func mapRowToDestinationValue( func mapRowToDestinationValue(
scanContext *ScanContext, scanContext *ScanContext,
groupKey string, groupKey string,
typesVisited *typeStack,
dest reflect.Value, dest reflect.Value,
structField *reflect.StructField) (updated bool, err error) { structField *reflect.StructField) (updated bool, err error) {
@ -306,7 +305,7 @@ func mapRowToDestinationValue(
} }
} }
updated, err = mapRowToDestinationPtr(scanContext, groupKey, typesVisited, destPtrValue, structField) updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField)
if err != nil { if err != nil {
return return
@ -322,7 +321,6 @@ func mapRowToDestinationValue(
func mapRowToDestinationPtr( func mapRowToDestinationPtr(
scanContext *ScanContext, scanContext *ScanContext,
groupKey string, groupKey string,
typesVisited *typeStack,
destPtrValue reflect.Value, destPtrValue reflect.Value,
structField *reflect.StructField) (updated bool, err error) { structField *reflect.StructField) (updated bool, err error) {
@ -331,9 +329,9 @@ func mapRowToDestinationPtr(
destValueKind := destPtrValue.Elem().Kind() destValueKind := destPtrValue.Elem().Kind()
if destValueKind == reflect.Struct { if destValueKind == reflect.Struct {
return mapRowToStruct(scanContext, groupKey, typesVisited, destPtrValue, structField) return mapRowToStruct(scanContext, groupKey, destPtrValue, structField)
} else if destValueKind == reflect.Slice { } else if destValueKind == reflect.Slice {
return mapRowToSlice(scanContext, groupKey, typesVisited, destPtrValue, structField) return mapRowToSlice(scanContext, groupKey, destPtrValue, structField)
} else { } else {
panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String()) panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String())
} }

View file

@ -16,6 +16,8 @@ type ScanContext struct {
commonIdentToColumnIndex map[string]int commonIdentToColumnIndex map[string]int
groupKeyInfoCache map[string]groupKeyInfo groupKeyInfoCache map[string]groupKeyInfo
typeInfoMap map[string]typeInfo typeInfoMap map[string]typeInfo
typesVisited typeStack // to prevent circular dependency scan
} }
// NewScanContext creates new ScanContext from rows // NewScanContext creates new ScanContext from rows
@ -39,7 +41,7 @@ func NewScanContext(rows *sql.Rows) (*ScanContext, error) {
commonIdentifier := toCommonIdentifier(names[0]) commonIdentifier := toCommonIdentifier(names[0])
if len(names) > 1 { if len(names) > 1 {
commonIdentifier += "." + toCommonIdentifier(names[1]) commonIdentifier = concat(commonIdentifier, ".", toCommonIdentifier(names[1]))
} }
commonIdentToColumnIndex[commonIdentifier] = i commonIdentToColumnIndex[commonIdentifier] = i
@ -53,15 +55,17 @@ func NewScanContext(rows *sql.Rows) (*ScanContext, error) {
commonIdentToColumnIndex: commonIdentToColumnIndex, commonIdentToColumnIndex: commonIdentToColumnIndex,
typeInfoMap: make(map[string]typeInfo), typeInfoMap: make(map[string]typeInfo),
typesVisited: newTypeStack(),
}, nil }, nil
} }
func createScanSlice(columnCount int) []interface{} { func createScanSlice(columnCount int) []interface{} {
scanSlice := make([]interface{}, columnCount)
scanPtrSlice := make([]interface{}, columnCount) scanPtrSlice := make([]interface{}, columnCount)
for i := range scanPtrSlice { for i := range scanPtrSlice {
scanPtrSlice[i] = &scanSlice[i] // if destination is pointer to interface sql.Scan will just forward driver value var a interface{}
scanPtrSlice[i] = &a // if destination is pointer to interface sql.Scan will just forward driver value
} }
return scanPtrSlice return scanPtrSlice
@ -72,8 +76,8 @@ type typeInfo struct {
} }
type fieldMapping struct { type fieldMapping struct {
complexType bool // slice or struct complexType bool // slice and struct are complex types
columnIndex int rowIndex int // index in ScanContext.row
implementsScanner bool implementsScanner bool
} }
@ -82,7 +86,7 @@ func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.
typeMapKey := structType.String() typeMapKey := structType.String()
if parentField != nil { if parentField != nil {
typeMapKey += string(parentField.Tag) typeMapKey = concat(typeMapKey, string(parentField.Tag))
} }
if typeInfo, ok := s.typeInfoMap[typeMapKey]; ok { if typeInfo, ok := s.typeInfoMap[typeMapKey]; ok {
@ -100,7 +104,7 @@ func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.
columnIndex := s.typeToColumnIndex(newTypeName, fieldName) columnIndex := s.typeToColumnIndex(newTypeName, fieldName)
fieldMap := fieldMapping{ fieldMap := fieldMapping{
columnIndex: columnIndex, rowIndex: columnIndex,
} }
if implementsScannerType(field.Type) { if implementsScannerType(field.Type) {
@ -128,14 +132,15 @@ func (s *ScanContext) getGroupKey(structType reflect.Type, structField *reflect.
mapKey := structType.Name() mapKey := structType.Name()
if structField != nil { if structField != nil {
mapKey += structField.Type.String() mapKey = concat(mapKey, structField.Type.String())
} }
if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok { if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok {
return s.constructGroupKey(groupKeyInfo) return s.constructGroupKey(groupKeyInfo)
} }
groupKeyInfo := s.getGroupKeyInfo(structType, structField, newTypeStack()) tempTypeStack := newTypeStack()
groupKeyInfo := s.getGroupKeyInfo(structType, structField, &tempTypeStack)
s.groupKeyInfoCache[mapKey] = groupKeyInfo s.groupKeyInfoCache[mapKey] = groupKeyInfo
@ -150,10 +155,7 @@ func (s *ScanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
var groupKeys []string var groupKeys []string
for _, index := range groupKeyInfo.indexes { for _, index := range groupKeyInfo.indexes {
cellValue := s.rowElem(index) groupKeys = append(groupKeys, s.rowElemToString(index))
subKey := valueToString(reflect.ValueOf(cellValue))
groupKeys = append(groupKeys, subKey)
} }
var subTypesGroupKeys []string var subTypesGroupKeys []string
@ -161,7 +163,7 @@ func (s *ScanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType)) subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType))
} }
return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")" return concat(groupKeyInfo.typeName, "(", strings.Join(groupKeys, ","), strings.Join(subTypesGroupKeys, ","), ")")
} }
func (s *ScanContext) getGroupKeyInfo( func (s *ScanContext) getGroupKeyInfo(
@ -231,32 +233,36 @@ func (s *ScanContext) typeToColumnIndex(typeName, fieldName string) int {
return index return index
} }
func (s *ScanContext) rowElem(index int) interface{} { // rowElemValue always returns non-ptr value,
cellValue := reflect.ValueOf(s.row[index]) // invalid value is nil
func (s *ScanContext) rowElemValue(index int) reflect.Value {
if cellValue.IsValid() && !cellValue.IsNil() { scannedValue := reflect.ValueOf(s.row[index])
return cellValue.Elem().Interface() return scannedValue.Elem().Elem() // no need to check validity of Elem, because s.row[index] always contains interface in interface
}
return nil
} }
func (s *ScanContext) rowElemValuePtr(index int) reflect.Value { func (s *ScanContext) rowElemToString(index int) string {
rowElem := s.rowElem(index) value := s.rowElemValue(index)
rowElemValue := reflect.ValueOf(rowElem)
if !value.IsValid() {
return "nil"
}
valueInterface := value.Interface()
if t, ok := valueInterface.(fmt.Stringer); ok {
return t.String()
}
return fmt.Sprintf("%#v", valueInterface)
}
func (s *ScanContext) rowElemValueClonePtr(index int) reflect.Value {
rowElemValue := s.rowElemValue(index)
if !rowElemValue.IsValid() { if !rowElemValue.IsValid() {
return reflect.Value{} return reflect.Value{}
} }
if rowElemValue.Kind() == reflect.Ptr {
return rowElemValue
}
if rowElemValue.CanAddr() {
return rowElemValue.Addr()
}
newElem := reflect.New(rowElemValue.Type()) newElem := reflect.New(rowElemValue.Type())
newElem.Elem().Set(rowElemValue) newElem.Elem().Set(rowElemValue)
return newElem return newElem

View file

@ -4,9 +4,9 @@ import "reflect"
type typeStack []*reflect.Type type typeStack []*reflect.Type
func newTypeStack() *typeStack { func newTypeStack() typeStack {
stack := make(typeStack, 0, 20) stack := make(typeStack, 0, 20)
return &stack return stack
} }
func (s *typeStack) isEmpty() bool { func (s *typeStack) isEmpty() bool {

View file

@ -18,9 +18,9 @@ func implementsScannerType(fieldType reflect.Type) bool {
return true return true
} }
typePtr := reflect.New(fieldType).Type() fieldTypePtr := reflect.New(fieldType).Type()
return typePtr.Implements(scannerInterfaceType) return fieldTypePtr.Implements(scannerInterfaceType)
} }
func getScanner(value reflect.Value) sql.Scanner { func getScanner(value reflect.Value) sql.Scanner {
@ -68,9 +68,9 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
if newSliceElemValue.Kind() == reflect.Ptr { if newSliceElemValue.Kind() == reflect.Ptr {
newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem())) newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem()))
err = tryAssign(objPtrValue.Elem(), newSliceElemValue.Elem()) err = assign(objPtrValue.Elem(), newSliceElemValue.Elem())
} else { } else {
err = tryAssign(objPtrValue.Elem(), newSliceElemValue) err = assign(objPtrValue.Elem(), newSliceElemValue)
} }
if err != nil { if err != nil {
@ -138,29 +138,6 @@ func initializeValueIfNilPtr(value reflect.Value) {
} }
} }
func valueToString(value reflect.Value) string {
if !value.IsValid() {
return "nil"
}
var valueInterface interface{}
if value.Kind() == reflect.Ptr {
if value.IsNil() {
return "nil"
}
valueInterface = value.Elem().Interface()
} else {
valueInterface = value.Interface()
}
if t, ok := valueInterface.(fmt.Stringer); ok {
return t.String()
}
return fmt.Sprintf("%#v", valueInterface)
}
var timeType = reflect.TypeOf(time.Now()) var timeType = reflect.TypeOf(time.Now())
var uuidType = reflect.TypeOf(uuid.New()) var uuidType = reflect.TypeOf(uuid.New())
var byteArrayType = reflect.TypeOf([]byte("")) var byteArrayType = reflect.TypeOf([]byte(""))
@ -180,30 +157,35 @@ func isSimpleModelType(objType reflect.Type) bool {
return objType == timeType || objType == uuidType || objType == byteArrayType return objType == timeType || objType == uuidType || objType == byteArrayType
} }
func isIntegerType(objType reflect.Type) bool { // source can't be pointer
objType = indirectType(objType) // destination can be pointer
func assign(source, destination reflect.Value) error {
switch objType.Kind() { if destination.Kind() == reflect.Ptr {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, if destination.IsNil() {
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: initializeValueIfNilPtr(destination)
return true
} }
return false destination = destination.Elem()
}
func isFloatType(value reflect.Type) bool {
switch value.Kind() {
case reflect.Float32, reflect.Float64:
return true
} }
return false err := tryAssign(source, destination)
if err != nil {
// needs for the type conversions are rare, so we leave conversion as a last assign step if everything else fails
if tryConvert(source, destination) {
return nil
}
return err
}
return nil
} }
func assignIfAssignable(source, destination reflect.Value) bool { func assignIfAssignable(source, destination reflect.Value) bool {
if source.Type().AssignableTo(destination.Type()) { sourceType := source.Type()
switch source.Type() { if sourceType.AssignableTo(destination.Type()) {
switch sourceType {
case byteArrayType: case byteArrayType:
destination.SetBytes(cloneBytes(source.Interface().([]byte))) destination.SetBytes(cloneBytes(source.Interface().([]byte)))
default: default:
@ -215,31 +197,17 @@ func assignIfAssignable(source, destination reflect.Value) bool {
return false return false
} }
// source and destination are non-ptr values
func tryAssign(source, destination reflect.Value) error { func tryAssign(source, destination reflect.Value) error {
if assignIfAssignable(source, destination) { if assignIfAssignable(source, destination) {
return nil return nil
} }
sourceType := source.Type()
destinationType := destination.Type()
if sourceType != destinationType &&
!isFloatType(destinationType) && // to preserve precision during conversion
!(isIntegerType(sourceType) && destination.Kind() == reflect.String) && // default conversion will convert int to 1 rune string
sourceType.ConvertibleTo(destinationType) {
source = source.Convert(destinationType)
}
if assignIfAssignable(source, destination) {
return nil
}
sourceInterface := source.Interface() sourceInterface := source.Interface()
switch destination.Interface().(type) { switch destination.Type().Kind() {
case bool: case reflect.Bool:
var nullBool internal.NullBool var nullBool internal.NullBool
err := nullBool.Scan(sourceInterface) err := nullBool.Scan(sourceInterface)
@ -250,7 +218,7 @@ func tryAssign(source, destination reflect.Value) error {
destination.SetBool(nullBool.Bool) destination.SetBool(nullBool.Bool)
case float32, float64: case reflect.Float32, reflect.Float64:
var nullFloat sql.NullFloat64 var nullFloat sql.NullFloat64
err := nullFloat.Scan(sourceInterface) err := nullFloat.Scan(sourceInterface)
@ -261,7 +229,7 @@ func tryAssign(source, destination reflect.Value) error {
if nullFloat.Valid { if nullFloat.Valid {
destination.SetFloat(nullFloat.Float64) destination.SetFloat(nullFloat.Float64)
} }
case int, int8, int16, int32, int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var integer sql.NullInt64 var integer sql.NullInt64
err := integer.Scan(sourceInterface) err := integer.Scan(sourceInterface)
@ -273,7 +241,7 @@ func tryAssign(source, destination reflect.Value) error {
destination.SetInt(integer.Int64) destination.SetInt(integer.Int64)
} }
case uint, uint8, uint16, uint32, uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
var uInt internal.NullUInt64 var uInt internal.NullUInt64
err := uInt.Scan(sourceInterface) err := uInt.Scan(sourceInterface)
@ -286,7 +254,7 @@ func tryAssign(source, destination reflect.Value) error {
destination.SetUint(uInt.UInt64) destination.SetUint(uInt.UInt64)
} }
case string: case reflect.String:
var str sql.NullString var str sql.NullString
err := str.Scan(sourceInterface) err := str.Scan(sourceInterface)
@ -298,6 +266,8 @@ func tryAssign(source, destination reflect.Value) error {
destination.SetString(str.String) destination.SetString(str.String)
} }
default:
switch destination.Interface().(type) {
case time.Time: case time.Time:
var nullTime internal.NullTime var nullTime internal.NullTime
@ -309,42 +279,31 @@ func tryAssign(source, destination reflect.Value) error {
if nullTime.Valid { if nullTime.Valid {
destination.Set(reflect.ValueOf(nullTime.Time)) destination.Set(reflect.ValueOf(nullTime.Time))
} }
default: default:
return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface()) return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface())
} }
}
return nil return nil
} }
func tryConvert(source, destination reflect.Value) bool {
destinationType := destination.Type()
if source.Type().ConvertibleTo(destinationType) {
source = source.Convert(destinationType)
return assignIfAssignable(source, destination)
}
return false
}
func setZeroValue(value reflect.Value) { func setZeroValue(value reflect.Value) {
if !value.IsZero() { if !value.IsZero() {
value.Set(reflect.Zero(value.Type())) value.Set(reflect.Zero(value.Type()))
} }
} }
func setReflectValue(source, destination reflect.Value) error {
if source.Kind() == reflect.Ptr {
if source.IsNil() {
// source is nil, destination should be its zero value
setZeroValue(destination)
return nil
}
source = source.Elem()
}
if destination.Kind() == reflect.Ptr {
if destination.IsNil() {
initializeValueIfNilPtr(destination)
}
destination = destination.Elem()
}
return tryAssign(source, destination)
}
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool { func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
if len(primaryKeyOverwrites) > 0 { if len(primaryKeyOverwrites) > 0 {
return utils.StringSliceContains(primaryKeyOverwrites, field.Name) return utils.StringSliceContains(primaryKeyOverwrites, field.Name)
@ -398,3 +357,11 @@ func cloneBytes(b []byte) []byte {
copy(c, b) copy(c, b)
return c return c
} }
func concat(stringList ...string) string {
var b strings.Builder
for _, str := range stringList {
b.WriteString(str)
}
return b.String()
}

View file

@ -206,7 +206,7 @@ GROUP BY payment.customer_id;
"RentalID": null, "RentalID": null,
"Amount": 0, "Amount": 0,
"PaymentDate": "0001-01-01T00:00:00Z", "PaymentDate": "0001-01-01T00:00:00Z",
"LastUpdate": "0001-01-01T00:00:00Z", "LastUpdate": null,
"Count": 8, "Count": 8,
"Sum": 38.92, "Sum": 38.92,
"Avg": 4.865, "Avg": 4.865,
@ -964,7 +964,6 @@ func TestRowsScan(t *testing.T) {
rows, err := stmt.Rows(context.Background(), db) rows, err := stmt.Rows(context.Background(), db)
require.NoError(t, err) require.NoError(t, err)
for rows.Next() {
var inventory struct { var inventory struct {
model.Inventory model.Inventory
@ -972,6 +971,7 @@ func TestRowsScan(t *testing.T) {
Store model.Store Store model.Store
} }
for rows.Next() {
err = rows.Scan(&inventory) err = rows.Scan(&inventory)
require.NoError(t, err) require.NoError(t, err)
@ -1056,3 +1056,50 @@ func TestScanNumericToNumber(t *testing.T) {
require.Equal(t, number.Float32, float32(1.234568e+09)) require.Equal(t, number.Float32, float32(1.234568e+09))
require.Equal(t, number.Float64, float64(1.23456789e+09)) require.Equal(t, number.Float64, float64(1.23456789e+09))
} }
// scan into custom base types should be equivalent to the scan into base go types
func TestScanIntoCustomBaseTypes(t *testing.T) {
type MyUint8 uint8
type MyUint16 uint16
type MyUint32 uint32
type MyInt16 int16
type MyFloat32 float32
type MyFloat64 float64
type MyString string
type MyTime = time.Time
type film struct {
FilmID MyUint16 `sql:"primary_key"`
Title MyString
Description *MyString
ReleaseYear *MyInt16
LanguageID MyUint8
OriginalLanguageID *MyUint8
RentalDuration MyUint8
RentalRate MyFloat32
Length *MyUint32
ReplacementCost MyFloat64
Rating *model.FilmRating
SpecialFeatures *MyString
LastUpdate MyTime
}
stmt := SELECT(
Film.AllColumns,
).FROM(
Film,
).ORDER_BY(
Film.FilmID.ASC(),
).LIMIT(3)
var films []model.Film
err := stmt.Query(db, &films)
require.NoError(t, err)
var myFilms []film
err = stmt.Query(db, &myFilms)
require.NoError(t, err)
require.Equal(t, testutils.ToJSON(films), testutils.ToJSON(myFilms))
}

View file

@ -786,6 +786,123 @@ func TestRowsScan(t *testing.T) {
requireQueryLogged(t, stmt, 0) requireQueryLogged(t, stmt, 0)
} }
func TestScanNullColumn(t *testing.T) {
stmt := SELECT(
Address.AllColumns,
).FROM(
Address,
).WHERE(
Address.Address2.IS_NULL(),
)
var dest []model.Address
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
[
{
"AddressID": 1,
"Address": "47 MySakila Drive",
"Address2": null,
"District": "Alberta",
"CityID": 300,
"PostalCode": "",
"Phone": "",
"LastUpdate": "2006-02-15T09:45:30Z"
},
{
"AddressID": 2,
"Address": "28 MySQL Boulevard",
"Address2": null,
"District": "QLD",
"CityID": 576,
"PostalCode": "",
"Phone": "",
"LastUpdate": "2006-02-15T09:45:30Z"
},
{
"AddressID": 3,
"Address": "23 Workhaven Lane",
"Address2": null,
"District": "Alberta",
"CityID": 300,
"PostalCode": "",
"Phone": "14033335568",
"LastUpdate": "2006-02-15T09:45:30Z"
},
{
"AddressID": 4,
"Address": "1411 Lillydale Drive",
"Address2": null,
"District": "QLD",
"CityID": 576,
"PostalCode": "",
"Phone": "6172235589",
"LastUpdate": "2006-02-15T09:45:30Z"
}
]
`)
}
func TestRowsScanSetZeroValue(t *testing.T) {
stmt := SELECT(
Rental.AllColumns,
).FROM(
Rental,
).WHERE(
Rental.RentalID.IN(Int(16049), Int(15966)),
).ORDER_BY(
Rental.RentalID.DESC(),
)
rows, err := stmt.Rows(context.Background(), db)
require.NoError(t, err)
defer rows.Close()
// destination object is used as destination for all rows scan.
// this tests checks that ReturnedDate is set to nil with the second call
// check qrm.setZeroValue
var dest model.Rental
for rows.Next() {
err := rows.Scan(&dest)
require.NoError(t, err)
if dest.RentalID == 16049 {
testutils.AssertJSON(t, dest, `
{
"RentalID": 16049,
"RentalDate": "2005-08-23T22:50:12Z",
"InventoryID": 2666,
"CustomerID": 393,
"ReturnDate": "2005-08-30T01:01:12Z",
"StaffID": 2,
"LastUpdate": "2006-02-16T02:30:53Z"
}
`)
} else {
testutils.AssertJSON(t, dest, `
{
"RentalID": 15966,
"RentalDate": "2006-02-14T15:16:03Z",
"InventoryID": 4472,
"CustomerID": 374,
"ReturnDate": null,
"StaffID": 1,
"LastUpdate": "2006-02-16T02:30:53Z"
}
`)
}
}
err = rows.Close()
require.NoError(t, err)
err = rows.Err()
require.NoError(t, err)
}
func TestScanNumericToFloat(t *testing.T) { func TestScanNumericToFloat(t *testing.T) {
type Number struct { type Number struct {
Float32 float32 Float32 float32
@ -826,6 +943,54 @@ func TestScanNumericToIntegerError(t *testing.T) {
} }
func TestScanIntoCustomBaseTypes(t *testing.T) {
type MyUint8 uint8
type MyUint16 uint16
type MyUint32 uint32
type MyInt16 int16
type MyFloat32 float32
type MyFloat64 float64
type MyString string
type MyTime = time.Time
type film struct {
FilmID MyUint16 `sql:"primary_key"`
Title MyString
Description *MyString
ReleaseYear *MyInt16
LanguageID MyUint8
RentalDuration MyUint8
RentalRate MyFloat32
Length *MyUint32
ReplacementCost MyFloat64
Rating *model.MpaaRating
LastUpdate MyTime
SpecialFeatures *MyString
Fulltext MyString
}
stmt := SELECT(
Film.AllColumns,
).FROM(
Film,
).ORDER_BY(
Film.FilmID.ASC(),
).LIMIT(3)
var films []model.Film
err := stmt.Query(db, &films)
require.NoError(t, err)
var myFilms []film
err = stmt.Query(db, &myFilms)
require.NoError(t, err)
require.Equal(t, testutils.ToJSON(films), testutils.ToJSON(myFilms))
}
// QueryContext panic when the scanned value is nil and the destination is a slice of primitive // QueryContext panic when the scanned value is nil and the destination is a slice of primitive
// https://github.com/go-jet/jet/issues/91 // https://github.com/go-jet/jet/issues/91
func TestScanToPrimitiveElementsSlice(t *testing.T) { func TestScanToPrimitiveElementsSlice(t *testing.T) {

View file

@ -2521,6 +2521,79 @@ func TestRecursionScanNx1(t *testing.T) {
}) })
} }
type StoreInfo struct {
model.Store
Staffs ManagerInfo
}
type ManagerInfo struct {
model.Staff
Store *StoreInfo
}
func TestRecursionScan1x1(t *testing.T) {
stmt := SELECT(
Store.AllColumns,
Staff.AllColumns,
).FROM(
Store.
INNER_JOIN(Staff, Staff.StaffID.EQ(Store.ManagerStaffID)),
).ORDER_BY(
Store.StoreID,
)
var dest []StoreInfo
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
[
{
"StoreID": 1,
"ManagerStaffID": 1,
"AddressID": 1,
"LastUpdate": "2006-02-15T09:57:12Z",
"Staffs": {
"StaffID": 1,
"FirstName": "Mike",
"LastName": "Hillyer",
"AddressID": 3,
"Email": "Mike.Hillyer@sakilastaff.com",
"StoreID": 1,
"Active": true,
"Username": "Mike",
"Password": "8cb2237d0679ca88db6464eac60da96345513964",
"LastUpdate": "2006-05-16T16:13:11.79328Z",
"Picture": "iVBORw0KWgo=",
"Store": null
}
},
{
"StoreID": 2,
"ManagerStaffID": 2,
"AddressID": 2,
"LastUpdate": "2006-02-15T09:57:12Z",
"Staffs": {
"StaffID": 2,
"FirstName": "Jon",
"LastName": "Stephens",
"AddressID": 4,
"Email": "Jon.Stephens@sakilastaff.com",
"StoreID": 2,
"Active": true,
"Username": "Jon",
"Password": "8cb2237d0679ca88db6464eac60da96345513964",
"LastUpdate": "2006-05-16T16:13:11.79328Z",
"Picture": null,
"Store": null
}
}
]
`)
}
// In parameterized statements integer literals, like Int(num), are replaced with a placeholders. For some expressions, // In parameterized statements integer literals, like Int(num), are replaced with a placeholders. For some expressions,
// postgres interpreter will not have enough information to deduce the type. If this is the case postgres returns an error. // postgres interpreter will not have enough information to deduce the type. If this is the case postgres returns an error.
// Int8, Int16, .... functions will add automatic type cast over placeholder, so type deduction is always possible. // Int8, Int16, .... functions will add automatic type cast over placeholder, so type deduction is always possible.

View file

@ -2,7 +2,6 @@ package postgres
import ( import (
"context" "context"
"fmt"
"github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/postgres" . "github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/northwind/model" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/northwind/model"
@ -864,5 +863,4 @@ WHERE orders1."orders.order_id" < $1;
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, dest, 72) require.Len(t, dest, 72)
fmt.Println(len(dest))
} }