Merge pull request #22 from go-jet/err-no-rows

QRM returns sql.ErrNoRows when scanning into struct destination and query result set is empty.
This commit is contained in:
go-jet 2019-10-12 18:49:08 +02:00 committed by GitHub
commit 4fb1f52c85
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 739 additions and 672 deletions

View file

@ -11,7 +11,7 @@ Jet is a framework for writing type-safe SQL queries in Go, with ability to easi
convert database query result into desired arbitrary object structure. convert database query result into desired arbitrary object structure.
Jet currently supports `PostgreSQL`, `MySQL` and `MariaDB`. Future releases will add support for additional databases. Jet currently supports `PostgreSQL`, `MySQL` and `MariaDB`. Future releases will add support for additional databases.
![jet](https://github.com/go-jet/jet/wiki/image/jet.png) ![jet](https://github.com/go-jet/jet/wiki/image/jet.png)
Jet is the easiest and the fastest way to write complex SQL queries and map database query result Jet is the easiest and the fastest way to write complex SQL queries and map database query result
into complex object composition. __It is not an ORM.__ into complex object composition. __It is not an ORM.__

View file

@ -15,10 +15,12 @@ type Statement interface {
DebugSql() (query string) DebugSql() (query string)
// Query executes statement over database connection db and stores row result in destination. // Query executes statement over database connection db and stores row result in destination.
// Destination can be arbitrary structure // Destination can be either pointer to struct or pointer to a slice.
// If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows.
Query(db qrm.DB, destination interface{}) error Query(db qrm.DB, destination interface{}) error
// QueryContext executes statement with a context over database connection db and stores row result in destination. // QueryContext executes statement with a context over database connection db and stores row result in destination.
// Destination can be of arbitrary structure // Destination can be either pointer to struct or pointer to a slice.
// If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows.
QueryContext(context context.Context, db qrm.DB, destination interface{}) error QueryContext(context context.Context, db qrm.DB, destination interface{}) error
//Exec executes statement over db connection without returning any rows. //Exec executes statement over db connection without returning any rows.

View file

@ -3,20 +3,14 @@ package qrm
import ( import (
"context" "context"
"database/sql" "database/sql"
"database/sql/driver"
"fmt"
"github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/internal/utils"
"github.com/go-jet/jet/qrm/internal"
"github.com/google/uuid"
"reflect" "reflect"
"strconv"
"strings"
"time"
) )
// Query executes Query Result Mapping (QRM) of `query` with list of parametrized arguments `arg` over database connection `db` // Query executes Query Result Mapping (QRM) of `query` with list of parametrized arguments `arg` over database connection `db`
// using context `ctx` into destination `destPtr`. // using context `ctx` into destination `destPtr`.
// Destination can be either pointer to struct or pointer to slice of structs. // Destination can be either pointer to struct or pointer to slice of structs.
// If destination is pointer to struct and query result set is empty, method returns sql.ErrNoRows.
func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) error { func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr interface{}) error {
utils.MustBeInitializedPtr(db, "jet: db is nil") utils.MustBeInitializedPtr(db, "jet: db is nil")
@ -26,17 +20,23 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
destinationPtrType := reflect.TypeOf(destPtr) destinationPtrType := reflect.TypeOf(destPtr)
if destinationPtrType.Elem().Kind() == reflect.Slice { if destinationPtrType.Elem().Kind() == reflect.Slice {
return queryToSlice(ctx, db, query, args, destPtr) _, err := queryToSlice(ctx, db, query, args, destPtr)
return err
} else if destinationPtrType.Elem().Kind() == reflect.Struct { } else if destinationPtrType.Elem().Kind() == reflect.Struct {
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem() tempSliceValue := tempSlicePtrValue.Elem()
err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface()) rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface())
if err != nil { if err != nil {
return err return err
} }
if rowsProcessed == 0 {
return sql.ErrNoRows
}
// edge case when row result set contains only NULLs.
if tempSliceValue.Len() == 0 { if tempSliceValue.Len() == 0 {
return nil return nil
} }
@ -53,7 +53,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
} }
} }
func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) error { func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) (rowsProcessed int64, err error) {
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
@ -61,27 +61,27 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
rows, err := db.QueryContext(ctx, query, args...) rows, err := db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return err return
} }
defer rows.Close() defer rows.Close()
scanContext, err := newScanContext(rows) scanContext, err := newScanContext(rows)
if err != nil { if err != nil {
return err return
} }
if len(scanContext.row) == 0 { if len(scanContext.row) == 0 {
return nil return
} }
slicePtrValue := reflect.ValueOf(slicePtr) slicePtrValue := reflect.ValueOf(slicePtr)
for rows.Next() { for rows.Next() {
err := rows.Scan(scanContext.row...) err = rows.Scan(scanContext.row...)
if err != nil { if err != nil {
return err return
} }
scanContext.rowNum++ scanContext.rowNum++
@ -89,22 +89,24 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
_, err = mapRowToSlice(scanContext, "", slicePtrValue, nil) _, err = mapRowToSlice(scanContext, "", slicePtrValue, nil)
if err != nil { if err != nil {
return err return
} }
} }
err = rows.Close() err = rows.Close()
if err != nil { if err != nil {
return err return
} }
err = rows.Err() err = rows.Err()
if err != nil { if err != nil {
return err return
} }
return nil rowsProcessed = scanContext.rowNum
return
} }
func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) { func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) {
@ -171,57 +173,8 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value
return return
} }
type typeInfo struct {
fieldMappings []fieldMapping
}
type fieldMapping struct {
complexType bool
columnIndex int
implementsScanner bool
}
func (s *scanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo {
typeMapKey := structType.String()
if parentField != nil {
typeMapKey += string(parentField.Tag)
}
if typeInfo, ok := s.typeInfoMap[typeMapKey]; ok {
return typeInfo
}
typeName := getTypeName(structType, parentField)
newTypeInfo := typeInfo{}
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
newTypeName, fieldName := getTypeAndFieldName(typeName, field)
columnIndex := s.typeToColumnIndex(newTypeName, fieldName)
fieldMap := fieldMapping{
columnIndex: columnIndex,
}
if implementsScannerType(field.Type) {
fieldMap.implementsScanner = true
} else if !isSimpleModelType(field.Type) {
fieldMap.complexType = true
}
newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap)
}
s.typeInfoMap[typeMapKey] = newTypeInfo
return newTypeInfo
}
func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) { func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) {
mapOnlySlices := len(onlySlices) > 0
structType := structPtrValue.Type().Elem() structType := structPtrValue.Type().Elem()
typeInf := scanContext.getTypeInfo(structType, parentField) typeInf := scanContext.getTypeInfo(structType, parentField)
@ -250,22 +203,21 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
updated = true updated = true
} }
} else if len(onlySlices) == 0 { } else {
if mapOnlySlices || fieldMap.columnIndex == -1 {
if fieldMap.columnIndex == -1 {
continue continue
} }
cellValue := scanContext.rowElem(fieldMap.columnIndex)
if cellValue == nil {
continue
}
initializeValueIfNilPtr(fieldValue)
updated = true
if fieldMap.implementsScanner { if fieldMap.implementsScanner {
cellValue := scanContext.rowElem(fieldMap.columnIndex)
if cellValue == nil {
continue
}
initializeValueIfNilPtr(fieldValue)
scanner := getScanner(fieldValue) scanner := getScanner(fieldValue)
err = scanner.Scan(cellValue) err = scanner.Scan(cellValue)
@ -273,15 +225,8 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
if err != nil { if err != nil {
panic("jet: " + err.Error() + ", " + fieldToString(&field) + " of type " + structType.String()) panic("jet: " + err.Error() + ", " + fieldToString(&field) + " of type " + structType.String())
} }
updated = true
} else { } else {
cellValue := scanContext.rowElem(fieldMap.columnIndex) setReflectValue(reflect.ValueOf(cellValue), fieldValue)
if cellValue != nil {
updated = true
initializeValueIfNilPtr(fieldValue)
setReflectValue(reflect.ValueOf(cellValue), fieldValue)
}
} }
} }
} }
@ -289,21 +234,6 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
return return
} }
func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.")
destValueKind := destPtrValue.Elem().Kind()
if destValueKind == reflect.Struct {
return mapRowToStruct(scanContext, groupKey, destPtrValue, structField)
} else if destValueKind == reflect.Slice {
return mapRowToSlice(scanContext, groupKey, destPtrValue, structField)
} else {
panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String())
}
}
func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest reflect.Value, structField *reflect.StructField) (updated bool, err error) { func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest reflect.Value, structField *reflect.StructField) (updated bool, err error) {
var destPtrValue reflect.Value var destPtrValue reflect.Value
@ -331,552 +261,17 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
return return
} }
var scannerInterfaceType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
func implementsScannerType(fieldType reflect.Type) bool { utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.")
if fieldType.Implements(scannerInterfaceType) {
return true
}
typePtr := reflect.New(fieldType).Type() destValueKind := destPtrValue.Elem().Kind()
return typePtr.Implements(scannerInterfaceType) if destValueKind == reflect.Struct {
} return mapRowToStruct(scanContext, groupKey, destPtrValue, structField)
} else if destValueKind == reflect.Slice {
func getScanner(value reflect.Value) sql.Scanner { return mapRowToSlice(scanContext, groupKey, destPtrValue, structField)
if scanner, ok := value.Interface().(sql.Scanner); ok {
return scanner
}
return value.Addr().Interface().(sql.Scanner)
}
func getSliceElemType(slicePtrValue reflect.Value) reflect.Type {
sliceTypePtr := slicePtrValue.Type()
elemType := sliceTypePtr.Elem().Elem()
if elemType.Kind() == reflect.Ptr {
return elemType.Elem()
}
return elemType
}
func getSliceElemPtrAt(slicePtrValue reflect.Value, index int) reflect.Value {
sliceValue := slicePtrValue.Elem()
elem := sliceValue.Index(index)
if elem.Kind() == reflect.Ptr {
return elem
}
return elem.Addr()
}
func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) error {
if slicePtrValue.IsNil() {
panic("jet: internal, slice is nil")
}
sliceValue := slicePtrValue.Elem()
sliceElemType := sliceValue.Type().Elem()
newElemValue := objPtrValue
if sliceElemType.Kind() != reflect.Ptr {
newElemValue = objPtrValue.Elem()
}
if newElemValue.Type().ConvertibleTo(sliceElemType) {
newElemValue = newElemValue.Convert(sliceElemType)
}
if !newElemValue.Type().AssignableTo(sliceElemType) {
panic("jet: can't append " + newElemValue.Type().String() + " to " + sliceValue.Type().String() + " slice")
}
sliceValue.Set(reflect.Append(sliceValue, newElemValue))
return nil
}
func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value {
destinationSliceType := slicePtrValue.Type().Elem()
elemType := indirectType(destinationSliceType.Elem())
return reflect.New(elemType)
}
func getTypeName(structType reflect.Type, parentField *reflect.StructField) string {
if parentField == nil {
return structType.Name()
}
aliasTag := parentField.Tag.Get("alias")
if aliasTag == "" {
return structType.Name()
}
aliasParts := strings.Split(aliasTag, ".")
return toCommonIdentifier(aliasParts[0])
}
func getTypeAndFieldName(structType string, field reflect.StructField) (string, string) {
aliasTag := field.Tag.Get("alias")
if aliasTag == "" {
return structType, field.Name
}
aliasParts := strings.Split(aliasTag, ".")
if len(aliasParts) == 1 {
return structType, toCommonIdentifier(aliasParts[0])
}
return toCommonIdentifier(aliasParts[0]), toCommonIdentifier(aliasParts[1])
}
var replacer = strings.NewReplacer(" ", "", "-", "", "_", "")
func toCommonIdentifier(name string) string {
return strings.ToLower(replacer.Replace(name))
}
func initializeValueIfNilPtr(value reflect.Value) {
if !value.IsValid() || !value.CanSet() {
return
}
if value.Kind() == reflect.Ptr && value.IsNil() {
value.Set(reflect.New(value.Type().Elem()))
}
}
func valueToString(value reflect.Value) string {
if !value.IsValid() {
return "nil"
}
var valueInterface interface{}
if value.Kind() == reflect.Ptr {
if value.IsNil() {
return "nil"
}
valueInterface = value.Elem().Interface()
} else { } else {
valueInterface = value.Interface() panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String())
}
if t, ok := valueInterface.(fmt.Stringer); ok {
return t.String()
}
return fmt.Sprintf("%#v", valueInterface)
}
var timeType = reflect.TypeOf(time.Now())
var uuidType = reflect.TypeOf(uuid.New())
func isSimpleModelType(objType reflect.Type) bool {
objType = indirectType(objType)
switch objType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String,
reflect.Bool:
return true
case reflect.Slice:
return objType.Elem().Kind() == reflect.Uint8 //[]byte
case reflect.Struct:
return objType == timeType || objType == uuidType // time.Time || uuid.UUID
}
return false
}
func isIntegerType(value reflect.Type) bool {
switch value {
case int8Type, unit8Type, int16Type, uint16Type,
int32Type, uint32Type, int64Type, uint64Type:
return true
}
return false
}
func tryAssign(source, destination reflect.Value) bool {
if source.Type().ConvertibleTo(destination.Type()) {
source = source.Convert(destination.Type())
}
if isIntegerType(source.Type()) && destination.Type() == boolType {
intValue := source.Int()
if intValue == 1 {
source = reflect.ValueOf(true)
} else if intValue == 0 {
source = reflect.ValueOf(false)
}
}
if source.Type().AssignableTo(destination.Type()) {
destination.Set(source)
return true
}
return false
}
func setReflectValue(source, destination reflect.Value) {
if tryAssign(source, destination) {
return
}
if destination.Kind() == reflect.Ptr {
if source.Kind() == reflect.Ptr {
if !source.IsNil() {
if destination.IsNil() {
initializeValueIfNilPtr(destination)
}
if tryAssign(source.Elem(), destination.Elem()) {
return
}
} else {
return
}
} else {
if source.CanAddr() {
source = source.Addr()
} else {
sourceCopy := reflect.New(source.Type())
sourceCopy.Elem().Set(source)
source = sourceCopy
}
if tryAssign(source, destination) {
return
}
if tryAssign(source.Elem(), destination.Elem()) {
return
}
}
} else {
if source.Kind() == reflect.Ptr {
if source.IsNil() {
return
}
source = source.Elem()
}
if tryAssign(source, destination) {
return
}
}
panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String())
}
func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
values := make([]interface{}, len(columnTypes))
for i, sqlColumnType := range columnTypes {
columnType := newScanType(sqlColumnType)
columnValue := reflect.New(columnType)
values[i] = columnValue.Interface()
}
return values
}
var boolType = reflect.TypeOf(true)
var int8Type = reflect.TypeOf(int8(1))
var unit8Type = reflect.TypeOf(uint8(1))
var int16Type = reflect.TypeOf(int16(1))
var uint16Type = reflect.TypeOf(uint16(1))
var int32Type = reflect.TypeOf(int32(1))
var uint32Type = reflect.TypeOf(uint32(1))
var int64Type = reflect.TypeOf(int64(1))
var uint64Type = reflect.TypeOf(uint64(1))
var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullInt8Type = reflect.TypeOf(internal.NullInt8{})
var nullInt16Type = reflect.TypeOf(internal.NullInt16{})
var nullInt32Type = reflect.TypeOf(internal.NullInt32{})
var nullInt64Type = reflect.TypeOf(sql.NullInt64{})
var nullFloat32Type = reflect.TypeOf(internal.NullFloat32{})
var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{})
var nullStringType = reflect.TypeOf(sql.NullString{})
var nullTimeType = reflect.TypeOf(internal.NullTime{})
var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{})
func newScanType(columnType *sql.ColumnType) reflect.Type {
switch columnType.DatabaseTypeName() {
case "TINYINT":
return nullInt8Type
case "INT2", "SMALLINT", "YEAR":
return nullInt16Type
case "INT4", "MEDIUMINT", "INT":
return nullInt32Type
case "INT8", "BIGINT":
return nullInt64Type
case "CHAR", "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML":
return nullStringType
case "FLOAT4":
return nullFloat32Type
case "FLOAT8", "NUMERIC", "DECIMAL", "FLOAT", "DOUBLE":
return nullFloat64Type
case "BOOL":
return nullBoolType
case "BYTEA", "BINARY", "VARBINARY", "BLOB":
return nullByteArrayType
case "DATE", "DATETIME", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ":
return nullTimeType
default:
return nullStringType
} }
} }
type scanContext struct {
rowNum int
row []interface{}
uniqueDestObjectsMap map[string]int
typeToColumnIndexMap map[string]int
groupKeyInfoCache map[string]groupKeyInfo
typeInfoMap map[string]typeInfo
}
func newScanContext(rows *sql.Rows) (*scanContext, error) {
aliases, err := rows.Columns()
if err != nil {
return nil, err
}
columnTypes, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
typeToIndexMap := map[string]int{}
for i, alias := range aliases {
names := strings.SplitN(alias, ".", 2)
goName := toCommonIdentifier(names[0])
if len(names) > 1 {
goName += "." + toCommonIdentifier(names[1])
}
typeToIndexMap[strings.ToLower(goName)] = i
}
return &scanContext{
row: createScanValue(columnTypes),
uniqueDestObjectsMap: make(map[string]int),
groupKeyInfoCache: make(map[string]groupKeyInfo),
typeToColumnIndexMap: typeToIndexMap,
typeInfoMap: make(map[string]typeInfo),
}, nil
}
type groupKeyInfo struct {
typeName string
indexes []int
subTypes []groupKeyInfo
}
func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect.StructField) string {
mapKey := structType.Name()
if structField != nil {
mapKey += structField.Type.String()
}
if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok {
return s.constructGroupKey(groupKeyInfo)
}
groupKeyInfo := s.getGroupKeyInfo(structType, structField)
s.groupKeyInfoCache[mapKey] = groupKeyInfo
return s.constructGroupKey(groupKeyInfo)
}
func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
if len(groupKeyInfo.indexes) == 0 && len(groupKeyInfo.subTypes) == 0 {
return "|ROW: " + strconv.Itoa(s.rowNum) + "|"
}
groupKeys := []string{}
for _, index := range groupKeyInfo.indexes {
cellValue := s.rowElem(index)
subKey := valueToString(reflect.ValueOf(cellValue))
groupKeys = append(groupKeys, subKey)
}
subTypesGroupKeys := []string{}
for _, subType := range groupKeyInfo.subTypes {
subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType))
}
return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")"
}
func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo {
ret := groupKeyInfo{typeName: structType.Name()}
typeName := getTypeName(structType, parentField)
primaryKeyOverwrites := parentFieldPrimaryKeyOverwrite(parentField)
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
fieldType := indirectType(field.Type)
if !isSimpleModelType(fieldType) {
if fieldType.Kind() != reflect.Struct {
continue
}
subType := s.getGroupKeyInfo(fieldType, &field)
if len(subType.indexes) != 0 || len(subType.subTypes) != 0 {
ret.subTypes = append(ret.subTypes, subType)
}
} else {
if isPrimaryKey(field, primaryKeyOverwrites) {
newTypeName, fieldName := getTypeAndFieldName(typeName, field)
index := s.typeToColumnIndex(newTypeName, fieldName)
if index < 0 {
continue
}
ret.indexes = append(ret.indexes, index)
}
}
}
return ret
}
func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int {
var key string
if typeName != "" {
key = strings.ToLower(typeName + "." + fieldName)
} else {
key = strings.ToLower(fieldName)
}
index, ok := s.typeToColumnIndexMap[key]
if !ok {
return -1
}
return index
}
func (s *scanContext) rowElem(index int) interface{} {
valuer, ok := s.row[index].(driver.Valuer)
if !ok {
panic("jet: internal error, scan value doesn't implement driver.Valuer")
}
value, err := valuer.Value()
utils.PanicOnError(err)
return value
}
func (s *scanContext) rowElemValuePtr(index int) reflect.Value {
rowElem := s.rowElem(index)
rowElemValue := reflect.ValueOf(rowElem)
if rowElemValue.Kind() == reflect.Ptr {
return rowElemValue
}
if rowElemValue.CanAddr() {
return rowElemValue.Addr()
}
newElem := reflect.New(rowElemValue.Type())
newElem.Elem().Set(rowElemValue)
return newElem
}
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
if len(primaryKeyOverwrites) > 0 {
return utils.StringSliceContains(primaryKeyOverwrites, field.Name)
}
sqlTag := field.Tag.Get("sql")
return sqlTag == "primary_key"
}
func parentFieldPrimaryKeyOverwrite(parentField *reflect.StructField) []string {
if parentField == nil {
return nil
}
sqlTag := parentField.Tag.Get("sql")
if !strings.HasPrefix(sqlTag, "primary_key") {
return nil
}
parts := strings.Split(sqlTag, "=")
if len(parts) < 2 {
return nil
}
return strings.Split(parts[1], ",")
}
func indirectType(reflectType reflect.Type) reflect.Type {
if reflectType.Kind() != reflect.Ptr {
return reflectType
}
return reflectType.Elem()
}
func fieldToString(field *reflect.StructField) string {
if field == nil {
return ""
}
return " at '" + field.Name + " " + field.Type.String() + "'"
}

241
qrm/scan_context.go Normal file
View file

@ -0,0 +1,241 @@
package qrm
import (
"database/sql"
"database/sql/driver"
"fmt"
"github.com/go-jet/jet/internal/utils"
"reflect"
"strings"
)
type scanContext struct {
rowNum int64
row []interface{}
uniqueDestObjectsMap map[string]int
commonIdentToColumnIndex map[string]int
groupKeyInfoCache map[string]groupKeyInfo
typeInfoMap map[string]typeInfo
}
func newScanContext(rows *sql.Rows) (*scanContext, error) {
aliases, err := rows.Columns()
if err != nil {
return nil, err
}
columnTypes, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
commonIdentToColumnIndex := map[string]int{}
for i, alias := range aliases {
names := strings.SplitN(alias, ".", 2)
commonIdentifier := toCommonIdentifier(names[0])
if len(names) > 1 {
commonIdentifier += "." + toCommonIdentifier(names[1])
}
commonIdentToColumnIndex[commonIdentifier] = i
}
return &scanContext{
row: createScanValue(columnTypes),
uniqueDestObjectsMap: make(map[string]int),
groupKeyInfoCache: make(map[string]groupKeyInfo),
commonIdentToColumnIndex: commonIdentToColumnIndex,
typeInfoMap: make(map[string]typeInfo),
}, nil
}
type typeInfo struct {
fieldMappings []fieldMapping
}
type fieldMapping struct {
complexType bool // slice or struct
columnIndex int
implementsScanner bool
}
func (s *scanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo {
typeMapKey := structType.String()
if parentField != nil {
typeMapKey += string(parentField.Tag)
}
if typeInfo, ok := s.typeInfoMap[typeMapKey]; ok {
return typeInfo
}
typeName := getTypeName(structType, parentField)
newTypeInfo := typeInfo{}
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
newTypeName, fieldName := getTypeAndFieldName(typeName, field)
columnIndex := s.typeToColumnIndex(newTypeName, fieldName)
fieldMap := fieldMapping{
columnIndex: columnIndex,
}
if implementsScannerType(field.Type) {
fieldMap.implementsScanner = true
} else if !isSimpleModelType(field.Type) {
fieldMap.complexType = true
}
newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap)
}
s.typeInfoMap[typeMapKey] = newTypeInfo
return newTypeInfo
}
type groupKeyInfo struct {
typeName string
indexes []int
subTypes []groupKeyInfo
}
func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect.StructField) string {
mapKey := structType.Name()
if structField != nil {
mapKey += structField.Type.String()
}
if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok {
return s.constructGroupKey(groupKeyInfo)
}
groupKeyInfo := s.getGroupKeyInfo(structType, structField)
s.groupKeyInfoCache[mapKey] = groupKeyInfo
return s.constructGroupKey(groupKeyInfo)
}
func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
if len(groupKeyInfo.indexes) == 0 && len(groupKeyInfo.subTypes) == 0 {
return fmt.Sprintf("|ROW:%d|", s.rowNum)
}
groupKeys := []string{}
for _, index := range groupKeyInfo.indexes {
cellValue := s.rowElem(index)
subKey := valueToString(reflect.ValueOf(cellValue))
groupKeys = append(groupKeys, subKey)
}
subTypesGroupKeys := []string{}
for _, subType := range groupKeyInfo.subTypes {
subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType))
}
return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")"
}
func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo {
ret := groupKeyInfo{typeName: structType.Name()}
typeName := getTypeName(structType, parentField)
primaryKeyOverwrites := parentFieldPrimaryKeyOverwrite(parentField)
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
fieldType := indirectType(field.Type)
if !isSimpleModelType(fieldType) {
if fieldType.Kind() != reflect.Struct {
continue
}
subType := s.getGroupKeyInfo(fieldType, &field)
if len(subType.indexes) != 0 || len(subType.subTypes) != 0 {
ret.subTypes = append(ret.subTypes, subType)
}
} else {
if isPrimaryKey(field, primaryKeyOverwrites) {
newTypeName, fieldName := getTypeAndFieldName(typeName, field)
index := s.typeToColumnIndex(newTypeName, fieldName)
if index < 0 {
continue
}
ret.indexes = append(ret.indexes, index)
}
}
}
return ret
}
func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int {
var key string
if typeName != "" {
key = strings.ToLower(typeName + "." + fieldName)
} else {
key = strings.ToLower(fieldName)
}
index, ok := s.commonIdentToColumnIndex[key]
if !ok {
return -1
}
return index
}
func (s *scanContext) rowElem(index int) interface{} {
valuer, ok := s.row[index].(driver.Valuer)
if !ok {
panic("jet: internal error, scan value doesn't implement driver.Valuer")
}
value, err := valuer.Value()
utils.PanicOnError(err)
return value
}
func (s *scanContext) rowElemValuePtr(index int) reflect.Value {
rowElem := s.rowElem(index)
rowElemValue := reflect.ValueOf(rowElem)
if rowElemValue.Kind() == reflect.Ptr {
return rowElemValue
}
if rowElemValue.CanAddr() {
return rowElemValue.Addr()
}
newElem := reflect.New(rowElemValue.Type())
newElem.Elem().Set(rowElemValue)
return newElem
}

371
qrm/utill.go Normal file
View file

@ -0,0 +1,371 @@
package qrm
import (
"database/sql"
"fmt"
"github.com/go-jet/jet/internal/utils"
"github.com/go-jet/jet/qrm/internal"
"github.com/google/uuid"
"reflect"
"strings"
"time"
)
var scannerInterfaceType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
func implementsScannerType(fieldType reflect.Type) bool {
if fieldType.Implements(scannerInterfaceType) {
return true
}
typePtr := reflect.New(fieldType).Type()
return typePtr.Implements(scannerInterfaceType)
}
func getScanner(value reflect.Value) sql.Scanner {
if scanner, ok := value.Interface().(sql.Scanner); ok {
return scanner
}
return value.Addr().Interface().(sql.Scanner)
}
func getSliceElemType(slicePtrValue reflect.Value) reflect.Type {
sliceTypePtr := slicePtrValue.Type()
elemType := indirectType(sliceTypePtr).Elem()
return indirectType(elemType)
}
func getSliceElemPtrAt(slicePtrValue reflect.Value, index int) reflect.Value {
sliceValue := slicePtrValue.Elem()
elem := sliceValue.Index(index)
if elem.Kind() == reflect.Ptr {
return elem
}
return elem.Addr()
}
func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) error {
if slicePtrValue.IsNil() {
panic("jet: internal, slice is nil")
}
sliceValue := slicePtrValue.Elem()
sliceElemType := sliceValue.Type().Elem()
newElemValue := objPtrValue
if sliceElemType.Kind() != reflect.Ptr {
newElemValue = objPtrValue.Elem()
}
if newElemValue.Type().ConvertibleTo(sliceElemType) {
newElemValue = newElemValue.Convert(sliceElemType)
}
if !newElemValue.Type().AssignableTo(sliceElemType) {
panic("jet: can't append " + newElemValue.Type().String() + " to " + sliceValue.Type().String() + " slice")
}
sliceValue.Set(reflect.Append(sliceValue, newElemValue))
return nil
}
func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value {
destinationSliceType := slicePtrValue.Type().Elem()
elemType := indirectType(destinationSliceType.Elem())
return reflect.New(elemType)
}
func getTypeName(structType reflect.Type, parentField *reflect.StructField) string {
if parentField == nil {
return structType.Name()
}
aliasTag := parentField.Tag.Get("alias")
if aliasTag == "" {
return structType.Name()
}
aliasParts := strings.Split(aliasTag, ".")
return toCommonIdentifier(aliasParts[0])
}
func getTypeAndFieldName(structType string, field reflect.StructField) (string, string) {
aliasTag := field.Tag.Get("alias")
if aliasTag == "" {
return structType, field.Name
}
aliasParts := strings.Split(aliasTag, ".")
if len(aliasParts) == 1 {
return structType, toCommonIdentifier(aliasParts[0])
}
return toCommonIdentifier(aliasParts[0]), toCommonIdentifier(aliasParts[1])
}
var replacer = strings.NewReplacer(" ", "", "-", "", "_", "")
func toCommonIdentifier(name string) string {
return strings.ToLower(replacer.Replace(name))
}
func initializeValueIfNilPtr(value reflect.Value) {
if !value.IsValid() || !value.CanSet() {
return
}
if value.Kind() == reflect.Ptr && value.IsNil() {
value.Set(reflect.New(value.Type().Elem()))
}
}
func valueToString(value reflect.Value) string {
if !value.IsValid() {
return "nil"
}
var valueInterface interface{}
if value.Kind() == reflect.Ptr {
if value.IsNil() {
return "nil"
}
valueInterface = value.Elem().Interface()
} else {
valueInterface = value.Interface()
}
if t, ok := valueInterface.(fmt.Stringer); ok {
return t.String()
}
return fmt.Sprintf("%#v", valueInterface)
}
var timeType = reflect.TypeOf(time.Now())
var uuidType = reflect.TypeOf(uuid.New())
func isSimpleModelType(objType reflect.Type) bool {
objType = indirectType(objType)
switch objType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String,
reflect.Bool:
return true
case reflect.Slice:
return objType.Elem().Kind() == reflect.Uint8 //[]byte
case reflect.Struct:
return objType == timeType || objType == uuidType // time.Time || uuid.UUID
}
return false
}
func isIntegerType(value reflect.Type) bool {
switch value {
case int8Type, unit8Type, int16Type, uint16Type,
int32Type, uint32Type, int64Type, uint64Type:
return true
}
return false
}
func tryAssign(source, destination reflect.Value) bool {
if source.Type().ConvertibleTo(destination.Type()) {
source = source.Convert(destination.Type())
}
if isIntegerType(source.Type()) && destination.Type() == boolType {
intValue := source.Int()
if intValue == 1 {
source = reflect.ValueOf(true)
} else if intValue == 0 {
source = reflect.ValueOf(false)
}
}
if source.Type().AssignableTo(destination.Type()) {
destination.Set(source)
return true
}
return false
}
func setReflectValue(source, destination reflect.Value) {
if tryAssign(source, destination) {
return
}
if destination.Kind() == reflect.Ptr {
if source.Kind() == reflect.Ptr {
if !source.IsNil() {
if destination.IsNil() {
initializeValueIfNilPtr(destination)
}
if tryAssign(source.Elem(), destination.Elem()) {
return
}
} else {
return
}
} else {
if source.CanAddr() {
source = source.Addr()
} else {
sourceCopy := reflect.New(source.Type())
sourceCopy.Elem().Set(source)
source = sourceCopy
}
if tryAssign(source, destination) {
return
}
if tryAssign(source.Elem(), destination.Elem()) {
return
}
}
} else {
if source.Kind() == reflect.Ptr {
if source.IsNil() {
return
}
source = source.Elem()
}
if tryAssign(source, destination) {
return
}
}
panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String())
}
func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
values := make([]interface{}, len(columnTypes))
for i, sqlColumnType := range columnTypes {
columnType := newScanType(sqlColumnType)
columnValue := reflect.New(columnType)
values[i] = columnValue.Interface()
}
return values
}
var boolType = reflect.TypeOf(true)
var int8Type = reflect.TypeOf(int8(1))
var unit8Type = reflect.TypeOf(uint8(1))
var int16Type = reflect.TypeOf(int16(1))
var uint16Type = reflect.TypeOf(uint16(1))
var int32Type = reflect.TypeOf(int32(1))
var uint32Type = reflect.TypeOf(uint32(1))
var int64Type = reflect.TypeOf(int64(1))
var uint64Type = reflect.TypeOf(uint64(1))
var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullInt8Type = reflect.TypeOf(internal.NullInt8{})
var nullInt16Type = reflect.TypeOf(internal.NullInt16{})
var nullInt32Type = reflect.TypeOf(internal.NullInt32{})
var nullInt64Type = reflect.TypeOf(sql.NullInt64{})
var nullFloat32Type = reflect.TypeOf(internal.NullFloat32{})
var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{})
var nullStringType = reflect.TypeOf(sql.NullString{})
var nullTimeType = reflect.TypeOf(internal.NullTime{})
var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{})
func newScanType(columnType *sql.ColumnType) reflect.Type {
switch columnType.DatabaseTypeName() {
case "TINYINT":
return nullInt8Type
case "INT2", "SMALLINT", "YEAR":
return nullInt16Type
case "INT4", "MEDIUMINT", "INT":
return nullInt32Type
case "INT8", "BIGINT":
return nullInt64Type
case "CHAR", "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML":
return nullStringType
case "FLOAT4":
return nullFloat32Type
case "FLOAT8", "NUMERIC", "DECIMAL", "FLOAT", "DOUBLE":
return nullFloat64Type
case "BOOL":
return nullBoolType
case "BYTEA", "BINARY", "VARBINARY", "BLOB":
return nullByteArrayType
case "DATE", "DATETIME", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ":
return nullTimeType
default:
return nullStringType
}
}
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
if len(primaryKeyOverwrites) > 0 {
return utils.StringSliceContains(primaryKeyOverwrites, field.Name)
}
sqlTag := field.Tag.Get("sql")
return sqlTag == "primary_key"
}
func parentFieldPrimaryKeyOverwrite(parentField *reflect.StructField) []string {
if parentField == nil {
return nil
}
sqlTag := parentField.Tag.Get("sql")
if !strings.HasPrefix(sqlTag, "primary_key") {
return nil
}
parts := strings.Split(sqlTag, "=")
if len(parts) < 2 {
return nil
}
return strings.Split(parts[1], ",")
}
func indirectType(reflectType reflect.Type) reflect.Type {
if reflectType.Kind() != reflect.Ptr {
return reflectType
}
return reflectType.Elem()
}
func fieldToString(field *reflect.StructField) string {
if field == nil {
return ""
}
return " at '" + field.Name + " " + field.Type.String() + "'"
}

View file

@ -516,7 +516,8 @@ func TestStringOperators(t *testing.T) {
fmt.Println(query.DebugSql()) fmt.Println(query.DebugSql())
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -586,7 +587,8 @@ FROM test_sample.all_types;
`, "20:34:58", "23:06:06", "22:06:06.011", "21:06:06.011111", "20:16:06", `, "20:34:58", "23:06:06", "22:06:06.011", "21:06:06.011111", "20:16:06",
"19:26:06", "18:36:06", "17:46:06", "16:56:56", "15:16:46", "14:26:36") "19:26:06", "18:36:06", "17:46:06", "16:56:56", "15:16:46", "14:26:36")
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -646,7 +648,8 @@ SELECT CAST(? AS DATE),
FROM test_sample.all_types; FROM test_sample.all_types;
`) `)
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -708,7 +711,8 @@ SELECT all_types.date_time = all_types.date_time,
FROM test_sample.all_types; FROM test_sample.all_types;
`) `)
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -769,7 +773,8 @@ SELECT all_types.timestamp = all_types.timestamp,
CURRENT_TIMESTAMP(2) CURRENT_TIMESTAMP(2)
FROM test_sample.all_types; FROM test_sample.all_types;
`) `)
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }

View file

@ -228,7 +228,8 @@ LIMIT ?;
testutils.AssertStatementSql(t, query, expectedSQL, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) testutils.AssertStatementSql(t, query, expectedSQL, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12))
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -263,7 +264,8 @@ LIMIT ?;
testutils.AssertStatementSql(t, query2, expectedSQL, int64(1), int64(10), int64(1), int64(2), int64(1)) testutils.AssertStatementSql(t, query2, expectedSQL, int64(1), int64(10), int64(1), int64(2), int64(1))
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -305,7 +307,8 @@ OFFSET ?;
testutils.AssertStatementSql(t, query2, expectedSQL, int64(1), int64(10), int64(1), int64(2), int64(4), int64(3)) testutils.AssertStatementSql(t, query2, expectedSQL, int64(1), int64(10), int64(1), int64(2), int64(4), int64(3))
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -510,7 +513,8 @@ SELECT true,
'date'; 'date';
`) `)
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -530,7 +534,8 @@ LOCK IN SHARE MODE;
testutils.AssertDebugStatementSql(t, query, expectedSQL) testutils.AssertDebugStatementSql(t, query, expectedSQL)
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -606,7 +611,8 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date;
testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10))
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -641,7 +647,8 @@ ORDER BY payment.customer_id;
testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) testutils.AssertStatementSql(t, query, expectedSQL, int64(10))
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }

View file

@ -134,7 +134,7 @@ LIMIT $5;
func TestExpressionCast(t *testing.T) { func TestExpressionCast(t *testing.T) {
query := AllTypes.SELECT( query := AllTypes.SELECT(
postgres.CAST(Int(150)).AS_CHAR(12), postgres.CAST(Int(150)).AS_CHAR(12).AS("char12"),
postgres.CAST(String("TRUE")).AS_BOOL(), postgres.CAST(String("TRUE")).AS_BOOL(),
postgres.CAST(String("111")).AS_SMALLINT(), postgres.CAST(String("111")).AS_SMALLINT(),
postgres.CAST(String("111")).AS_INTEGER(), postgres.CAST(String("111")).AS_INTEGER(),
@ -170,7 +170,8 @@ func TestExpressionCast(t *testing.T) {
//fmt.Println(query.DebugSql()) //fmt.Println(query.DebugSql())
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -249,7 +250,8 @@ func TestStringOperators(t *testing.T) {
//fmt.Println(query.DebugSql()) //fmt.Println(query.DebugSql())
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -618,7 +620,8 @@ func TestTimeExpression(t *testing.T) {
//fmt.Println(query.DebugSql()) //fmt.Println(query.DebugSql())
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }

View file

@ -1,6 +1,7 @@
package postgres package postgres
import ( import (
"database/sql"
"fmt" "fmt"
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres"
@ -50,14 +51,16 @@ func TestScanToInvalidDestination(t *testing.T) {
func TestScanToValidDestination(t *testing.T) { func TestScanToValidDestination(t *testing.T) {
t.Run("pointer to struct", func(t *testing.T) { t.Run("pointer to struct", func(t *testing.T) {
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
}) })
t.Run("global query function scan", func(t *testing.T) { t.Run("global query function scan", func(t *testing.T) {
queryStr, args := query.Sql() queryStr, args := query.Sql()
err := qrm.Query(nil, db, queryStr, args, &struct{}{}) dest := []struct{}{}
err := qrm.Query(nil, db, queryStr, args, &dest)
assert.NilError(t, err) assert.NilError(t, err)
}) })
@ -692,6 +695,35 @@ func TestScanToSlice(t *testing.T) {
}) })
} }
func TestStructScanErrNoRows(t *testing.T) {
query := SELECT(Customer.AllColumns).
FROM(Customer).
WHERE(Customer.CustomerID.EQ(Int(-1)))
customer := model.Customer{}
err := query.Query(db, &customer)
assert.Error(t, err, sql.ErrNoRows.Error())
}
func TestStructScanAllNull(t *testing.T) {
query := SELECT(NULL.AS("null1"), NULL.AS("null2"))
dest := struct {
Null1 *int
Null2 *int
}{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, dest, struct {
Null1 *int
Null2 *int
}{})
}
var address256 = model.Address{ var address256 = model.Address{
AddressID: 256, AddressID: 256,
Address: "1497 Yuzhou Drive", Address: "1497 Yuzhou Drive",

View file

@ -162,7 +162,8 @@ LIMIT 12;
testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12))
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -1617,7 +1618,8 @@ SELECT true,
'date'; 'date';
`) `)
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -1688,7 +1690,8 @@ GROUP BY payment.amount, payment.customer_id, payment.payment_date;
testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10))
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
@ -1723,12 +1726,14 @@ ORDER BY payment.customer_id;
testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) testutils.AssertStatementSql(t, query, expectedSQL, int64(10))
err := query.Query(db, &struct{}{}) dest := []struct{}{}
err := query.Query(db, &dest)
assert.NilError(t, err) assert.NilError(t, err)
} }
func TestSimpleView(t *testing.T) { func TestSimpleView(t *testing.T) {
query := SELECT( query := SELECT(
view.ActorInfo.AllColumns, view.ActorInfo.AllColumns,
). ).
@ -1743,6 +1748,12 @@ func TestSimpleView(t *testing.T) {
FilmInfo string FilmInfo string
} }
//sql, args := query.Sql()
//
//row := db.QueryRow(sql, args...)
//
//row.Scan()
var dest []ActorInfo var dest []ActorInfo
err := query.Query(db, &dest) err := query.Query(db, &dest)