Qrm refactor

- Allow custom types Scan method to read values returned by the driver rather then the value from intermediate Null types. Scan to intermidiate Null types removed.
- Better error handling
This commit is contained in:
go-jet 2021-10-15 17:43:10 +02:00
parent 555ec293fb
commit 0d418890ab
11 changed files with 459 additions and 574 deletions

View file

@ -20,26 +20,32 @@ const (
)
func (e *MpaaRating) Scan(value interface{}) error {
if v, ok := value.(string); !ok {
return errors.New("jet: Invalid data for MpaaRating enum")
} else {
switch string(v) {
case "G":
*e = MpaaRating_G
case "PG":
*e = MpaaRating_Pg
case "PG-13":
*e = MpaaRating_Pg13
case "R":
*e = MpaaRating_R
case "NC-17":
*e = MpaaRating_Nc17
default:
return errors.New("jet: Inavlid data " + string(v) + "for MpaaRating enum")
}
return nil
var enumValue string
switch val := value.(type) {
case string:
enumValue = val
case []byte:
enumValue = string(val)
default:
return errors.New("jet: Invalid scan value for AllTypesEnum enum. Enum value has to be of type string or []byte")
}
switch enumValue {
case "G":
*e = MpaaRating_G
case "PG":
*e = MpaaRating_Pg
case "PG-13":
*e = MpaaRating_Pg13
case "R":
*e = MpaaRating_R
case "NC-17":
*e = MpaaRating_Nc17
default:
return errors.New("jet: Invalid scan value '" + enumValue + "' for MpaaRating enum")
}
return nil
}
func (e MpaaRating) String() string {

View file

@ -200,20 +200,26 @@ const (
)
func (e *{{$enumTemplate.TypeName}}) Scan(value interface{}) error {
if v, ok := value.(string); !ok {
return errors.New("jet: Invalid scan value for {{$enumTemplate.TypeName}} enum. Enum value has to be of type string")
} else {
switch string(v) {
{{- range $_, $value := .Values}}
case "{{$value}}":
*e = {{valueName $value}}
{{- end}}
default:
return errors.New("jet: Invalid scan value '" + string(v) + "' for {{$enumTemplate.TypeName}} enum")
}
return nil
var enumValue string
switch val := value.(type) {
case string:
enumValue = val
case []byte:
enumValue = string(val)
default:
return errors.New("jet: Invalid scan value for AllTypesEnum enum. Enum value has to be of type string or []byte")
}
switch enumValue {
{{- range $_, $value := .Values}}
case "{{$value}}":
*e = {{valueName $value}}
{{- end}}
default:
return errors.New("jet: Invalid scan value '" + enumValue + "' for {{$enumTemplate.TypeName}} enum")
}
return nil
}
func (e {{$enumTemplate.TypeName}}) String() string {

View file

@ -0,0 +1,9 @@
package min
// Int returns minimum of two int values
func Int(a, b int) int {
if a < b {
return a
}
return b
}

View file

@ -1,263 +1,170 @@
package internal
import (
"database/sql"
"database/sql/driver"
"fmt"
"github.com/go-jet/jet/v2/internal/utils/min"
"reflect"
"strconv"
"time"
)
//===============================================================//
// NullByteArray struct
type NullByteArray struct {
ByteArray []byte
Valid bool
// NullBool struct
type NullBool struct {
sql.NullBool
}
// Scan implements the Scanner interface.
func (nb *NullByteArray) Scan(value interface{}) error {
func (nb *NullBool) Scan(value interface{}) error {
switch v := value.(type) {
case nil:
nb.Valid = false
return nil
case []byte:
nb.ByteArray = append(v[:0:0], v...)
case bool:
nb.Bool, nb.Valid = v, true
case int8, int16, int32, int64, int:
intVal := reflect.ValueOf(v).Int()
if intVal != 0 && intVal != 1 {
return fmt.Errorf("can't assign %T(%d) to bool", value, value)
}
nb.Bool = intVal == 1
nb.Valid = true
case uint8, uint16, uint32, uint64, uint:
uintVal := reflect.ValueOf(v).Uint()
if uintVal != 0 && uintVal != 1 {
return fmt.Errorf("can't assign %T(%d) to bool", value, value)
}
nb.Bool = uintVal == 1
nb.Valid = true
return nil
default:
return fmt.Errorf("can't scan []byte from %v", value)
return nb.NullBool.Scan(value)
}
}
// Value implements the driver Valuer interface.
func (nb NullByteArray) Value() (driver.Value, error) {
if !nb.Valid {
return nil, nil
}
return nb.ByteArray, nil
return nil
}
//===============================================================//
// NullTime struct
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
sql.NullTime
}
// Scan implements the Scanner interface.
func (nt *NullTime) Scan(value interface{}) (err error) {
switch v := value.(type) {
case nil:
nt.Valid = false
return
case time.Time:
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, nt.Valid = parseTime(string(v))
return
case string:
nt.Time, nt.Valid = parseTime(v)
return
default:
func (nt *NullTime) Scan(value interface{}) error {
err := nt.NullTime.Scan(value)
if err == nil {
return nil
}
// Some of the drivers (pgx, mysql) are not parsing all of the time formats(date, time with time zone,...) and are just forwarding string value.
// At this point we try to parse time using some of the predefined formats
nt.Time, nt.Valid = tryParseAsTime(value)
if !nt.Valid {
return fmt.Errorf("can't scan time.Time from %v", value)
}
return nil
}
// Value implements the driver Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
var formats = []string{
"2006-01-02 15:04:05.999999", // go-sql-driver/mysql
"15:04:05-07", // pgx
"15:04:05.999999", // pgx
}
const formatTime = "2006-01-02 15:04:05.999999"
func tryParseAsTime(value interface{}) (time.Time, bool) {
func parseTime(timeStr string) (t time.Time, valid bool) {
var timeStr string
var format string
switch len(timeStr) {
case 8:
format = formatTime[11:19]
case 10, 19, 21, 22, 23, 24, 25, 26:
format = formatTime[:len(timeStr)]
default:
return t, false
}
t, err := time.Parse(format, timeStr)
return t, err == nil
}
//===============================================================//
// NullInt8 struct
type NullInt8 struct {
Int8 int8
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt8) Scan(value interface{}) (err error) {
switch v := value.(type) {
case nil:
n.Valid = false
return
case int64:
n.Int8, n.Valid = int8(v), true
return
case int8:
n.Int8, n.Valid = v, true
return
case string:
timeStr = v
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 8)
if err == nil {
n.Int8, n.Valid = int8(intV), true
timeStr = string(v)
}
for _, format := range formats {
formatLen := min.Int(len(format), len(timeStr))
t, err := time.Parse(format[:formatLen], timeStr)
if err != nil {
continue
}
return err
default:
return fmt.Errorf("can't scan int8 from %v", value)
return t, true
}
return time.Time{}, false
}
// Value implements the driver Valuer interface.
func (n NullInt8) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int8, nil
}
//===============================================================//
// NullInt16 struct
type NullInt16 struct {
Int16 int16
Valid bool
// NullUInt64 struct
type NullUInt64 struct {
UInt64 uint64
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt16) Scan(value interface{}) error {
func (n *NullUInt64) Scan(value interface{}) error {
var stringValue string
switch v := value.(type) {
case nil:
n.Valid = false
return nil
case int64:
n.Int16, n.Valid = int16(v), true
n.UInt64, n.Valid = uint64(v), true
return nil
case int16:
n.Int16, n.Valid = v, true
return nil
case int8:
n.Int16, n.Valid = int16(v), true
return nil
case uint8:
n.Int16, n.Valid = int16(v), true
return nil
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 16)
if err == nil {
n.Int16, n.Valid = int16(intV), true
}
return nil
default:
return fmt.Errorf("can't scan int16 from %v", value)
}
}
// Value implements the driver Valuer interface.
func (n NullInt16) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int16, nil
}
//===============================================================//
// NullInt32 struct
type NullInt32 struct {
Int32 int32
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt32) Scan(value interface{}) error {
switch v := value.(type) {
case nil:
n.Valid = false
return nil
case int64:
n.Int32, n.Valid = int32(v), true
case uint64:
n.UInt64, n.Valid = v, true
return nil
case int32:
n.Int32, n.Valid = v, true
n.UInt64, n.Valid = uint64(v), true
return nil
case uint32:
n.UInt64, n.Valid = uint64(v), true
return nil
case int16:
n.Int32, n.Valid = int32(v), true
n.UInt64, n.Valid = uint64(v), true
return nil
case uint16:
n.Int32, n.Valid = int32(v), true
n.UInt64, n.Valid = uint64(v), true
return nil
case int8:
n.Int32, n.Valid = int32(v), true
n.UInt64, n.Valid = uint64(v), true
return nil
case uint8:
n.Int32, n.Valid = int32(v), true
n.UInt64, n.Valid = uint64(v), true
return nil
case int:
n.UInt64, n.Valid = uint64(v), true
return nil
case uint:
n.UInt64, n.Valid = uint64(v), true
return nil
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 32)
if err == nil {
n.Int32, n.Valid = int32(intV), true
}
return nil
stringValue = string(v)
case string:
stringValue = v
default:
return fmt.Errorf("can't scan int32 from %v", value)
return fmt.Errorf("can't scan uint64 from %v", value)
}
uintV, err := strconv.ParseUint(stringValue, 10, 64)
if err != nil {
return err
}
n.UInt64 = uintV
n.Valid = true
return nil
}
// Value implements the driver Valuer interface.
func (n NullInt32) Value() (driver.Value, error) {
func (n NullUInt64) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int32, nil
}
//===============================================================//
// NullFloat32 struct
type NullFloat32 struct {
Float32 float32
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullFloat32) Scan(value interface{}) error {
switch v := value.(type) {
case nil:
n.Valid = false
return nil
case float64:
n.Float32, n.Valid = float32(v), true
return nil
case float32:
n.Float32, n.Valid = v, true
return nil
default:
return fmt.Errorf("can't scan float32 from %v", value)
}
}
// Value implements the driver Valuer interface.
func (n NullFloat32) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Float32, nil
return n.UInt64, nil
}

View file

@ -7,141 +7,85 @@ import (
"time"
)
func TestNullByteArray(t *testing.T) {
var array NullByteArray
func TestNullBool(t *testing.T) {
var nullBool NullBool
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, nullBool.Scan(nil))
require.Equal(t, nullBool.Valid, false)
require.NoError(t, array.Scan([]byte("bytea")))
require.Equal(t, array.Valid, true)
require.Equal(t, string(array.ByteArray), string([]byte("bytea")))
require.NoError(t, nullBool.Scan(int64(1)))
require.Equal(t, nullBool.Valid, true)
value, _ := nullBool.Value()
require.Equal(t, value, true)
require.Error(t, array.Scan(12), "can't scan []byte from 12")
require.NoError(t, nullBool.Scan(uint32(0)))
require.Equal(t, nullBool.Valid, true)
value, _ = nullBool.Value()
require.Equal(t, value, false)
require.EqualError(t, nullBool.Scan(uint16(22)), "can't assign uint16(22) to bool")
}
func TestNullTime(t *testing.T) {
var array NullTime
var nullTime NullTime
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, nullTime.Scan(nil))
require.Equal(t, nullTime.Valid, false)
time := time.Now()
require.NoError(t, array.Scan(time))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
require.NoError(t, nullTime.Scan(time))
require.Equal(t, nullTime.Valid, true)
value, _ := nullTime.Value()
require.Equal(t, value, time)
require.NoError(t, array.Scan([]byte("13:10:11")))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.NoError(t, nullTime.Scan([]byte("13:10:11")))
require.Equal(t, nullTime.Valid, true)
value, _ = nullTime.Value()
require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
require.NoError(t, array.Scan("13:10:11"))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.NoError(t, nullTime.Scan("13:10:11"))
require.Equal(t, nullTime.Valid, true)
value, _ = nullTime.Value()
require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
require.Error(t, array.Scan(12), "can't scan time.Time from 12")
require.Error(t, nullTime.Scan(12), "can't scan time.Time from 12")
}
func TestNullInt8(t *testing.T) {
var array NullInt8
func TestNullUInt64(t *testing.T) {
var nullUInt64 NullUInt64
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, nullUInt64.Scan(nil))
require.Equal(t, nullUInt64.Valid, false)
require.NoError(t, array.Scan(int64(11)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
require.Equal(t, value, int8(11))
require.NoError(t, nullUInt64.Scan(int64(11)))
require.Equal(t, nullUInt64.Valid, true)
value, _ := nullUInt64.Value()
require.Equal(t, value, uint64(11))
require.Error(t, array.Scan("text"), "can't scan int8 from text")
}
func TestNullInt16(t *testing.T) {
var array NullInt16
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(int64(11)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
require.Equal(t, value, int16(11))
require.NoError(t, array.Scan(int16(20)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int16(20))
require.NoError(t, array.Scan(int8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int16(30))
require.NoError(t, array.Scan(uint8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int16(30))
require.Error(t, array.Scan("text"), "can't scan int16 from text")
}
func TestNullInt32(t *testing.T) {
var array NullInt32
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(int64(11)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
require.Equal(t, value, int32(11))
require.NoError(t, array.Scan(int32(32)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(32))
require.NoError(t, array.Scan(int16(20)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(20))
require.NoError(t, array.Scan(uint16(16)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(16))
require.NoError(t, array.Scan(int8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(30))
require.NoError(t, array.Scan(uint8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(30))
require.Error(t, array.Scan("text"), "can't scan int32 from text")
}
func TestNullFloat32(t *testing.T) {
var array NullFloat32
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(float64(64)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
require.Equal(t, value, float32(64))
require.NoError(t, array.Scan(float32(32)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, float32(32))
require.Error(t, array.Scan(12), "can't scan float32 from 12")
require.NoError(t, nullUInt64.Scan(int32(32)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(32))
require.NoError(t, nullUInt64.Scan(int16(20)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(20))
require.NoError(t, nullUInt64.Scan(uint16(16)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(16))
require.NoError(t, nullUInt64.Scan(int8(30)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(30))
require.NoError(t, nullUInt64.Scan(uint8(30)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(30))
require.Error(t, nullUInt64.Scan("text"), "can't scan int32 from text")
}

View file

@ -27,7 +27,10 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
if destinationPtrType.Elem().Kind() == reflect.Slice {
_, err := queryToSlice(ctx, db, query, args, destPtr)
return err
if err != nil {
return fmt.Errorf("jet: %w", err)
}
return nil
} else if destinationPtrType.Elem().Kind() == reflect.Struct {
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem()
@ -35,7 +38,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface())
if err != nil {
return err
return fmt.Errorf("jet: %w", err)
}
if rowsProcessed == 0 {
@ -275,10 +278,16 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
err = scanner.Scan(cellValue)
if err != nil {
panic("jet: " + err.Error() + ", " + fieldToString(&field) + " of type " + structType.String())
err = fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err)
return
}
} else {
setReflectValue(reflect.ValueOf(cellValue), fieldValue)
err = setReflectValue(reflect.ValueOf(cellValue), fieldValue)
if err != nil {
err = fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err)
return
}
}
}
}

View file

@ -2,10 +2,7 @@ package qrm
import (
"database/sql"
"database/sql/driver"
"fmt"
"github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
"reflect"
"strings"
)
@ -46,7 +43,7 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
}
return &scanContext{
row: createScanValue(columnTypes),
row: createScanSlice(len(columnTypes)),
uniqueDestObjectsMap: make(map[string]int),
groupKeyInfoCache: make(map[string]groupKeyInfo),
@ -56,6 +53,17 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
}, nil
}
func createScanSlice(columnCount int) []interface{} {
scanSlice := make([]interface{}, columnCount)
scanPtrSlice := make([]interface{}, columnCount)
for i := range scanPtrSlice {
scanPtrSlice[i] = &scanSlice[i] // if destination is pointer to interface sql.Scan will just forward driver value
}
return scanPtrSlice
}
type typeInfo struct {
fieldMappings []fieldMapping
}
@ -210,16 +218,13 @@ func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int {
}
func (s *scanContext) rowElem(index int) interface{} {
cellValue := reflect.ValueOf(s.row[index])
valuer, ok := s.row[index].(driver.Valuer)
if cellValue.IsValid() && !cellValue.IsNil() {
return cellValue.Elem().Interface()
}
utils.MustBeTrue(ok, "jet: internal error, scan value doesn't implement driver.Valuer")
value, err := valuer.Value()
throw.OnError(err)
return value
return nil
}
func (s *scanContext) rowElemValuePtr(index int) reflect.Value {

View file

@ -7,7 +7,6 @@ import (
"github.com/go-jet/jet/v2/qrm/internal"
"github.com/google/uuid"
"reflect"
"strconv"
"strings"
"time"
)
@ -56,21 +55,22 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
sliceValue := slicePtrValue.Elem()
sliceElemType := sliceValue.Type().Elem()
newElemValue := objPtrValue
newSliceElemValue := reflect.New(sliceElemType).Elem()
if sliceElemType.Kind() != reflect.Ptr {
newElemValue = objPtrValue.Elem()
var err error
if newSliceElemValue.Kind() == reflect.Ptr {
newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem()))
err = tryAssign(objPtrValue.Elem(), newSliceElemValue.Elem())
} else {
err = tryAssign(objPtrValue.Elem(), newSliceElemValue)
}
if newElemValue.Type().ConvertibleTo(sliceElemType) {
newElemValue = newElemValue.Convert(sliceElemType)
if err != nil {
return fmt.Errorf("can't append %T to %T slice: %w", objPtrValue.Elem().Interface(), sliceValue.Interface(), err)
}
if !newElemValue.Type().AssignableTo(sliceElemType) {
panic("jet: can't append " + newElemValue.Type().String() + " to " + sliceValue.Type().String() + " slice")
}
sliceValue.Set(reflect.Append(sliceValue, newElemValue))
sliceValue.Set(reflect.Append(sliceValue, newSliceElemValue))
return nil
}
@ -121,7 +121,6 @@ func toCommonIdentifier(name string) string {
}
func initializeValueIfNilPtr(value reflect.Value) {
if !value.IsValid() || !value.CanSet() {
return
}
@ -173,172 +172,147 @@ func isSimpleModelType(objType reflect.Type) bool {
return objType == timeType || objType == uuidType || objType == byteArrayType
}
func isIntegerType(value reflect.Type) bool {
switch value {
case int8Type, unit8Type, int16Type, uint16Type,
int32Type, uint32Type, int64Type, uint64Type:
func isFloatType(value reflect.Type) bool {
switch value.Kind() {
case reflect.Float32, reflect.Float64:
return true
}
return false
}
func isNumber(valueType reflect.Type) bool {
return isIntegerType(valueType) || valueType == float64Type || valueType == float32Type
}
func tryAssign(source, destination reflect.Value) error {
func tryAssign(source, destination reflect.Value) bool {
if source.Type() != destination.Type() &&
!isFloatType(destination.Type()) && // to preserve precision during conversion
source.Type().ConvertibleTo(destination.Type()) {
switch {
case source.Type().ConvertibleTo(destination.Type()):
source = source.Convert(destination.Type())
case isIntegerType(source.Type()) && destination.Type() == boolType:
intValue := source.Int()
if intValue == 1 {
source = reflect.ValueOf(true)
} else if intValue == 0 {
source = reflect.ValueOf(false)
}
case source.Type() == stringType && isNumber(destination.Type()):
// if source is string and destination is a number(int8, int32, float32, ...), we first parse string to float64 number
// and then parsed number is converted into destination type
f, err := strconv.ParseFloat(source.String(), 64)
if err != nil {
return false
}
source = reflect.ValueOf(f)
if source.Type().ConvertibleTo(destination.Type()) {
source = source.Convert(destination.Type())
}
}
if source.Type().AssignableTo(destination.Type()) {
destination.Set(source)
return true
switch b := source.Interface().(type) {
case []byte:
destination.SetBytes(cloneBytes(b))
default:
destination.Set(source)
}
return nil
}
return false
sourceInterface := source.Interface()
switch destination.Interface().(type) {
case bool:
var nullBool internal.NullBool
err := nullBool.Scan(sourceInterface)
if err != nil {
return err
}
destination.SetBool(nullBool.Bool)
case float32, float64:
var nullFloat sql.NullFloat64
err := nullFloat.Scan(sourceInterface)
if err != nil {
return err
}
if nullFloat.Valid {
destination.SetFloat(nullFloat.Float64)
}
case int, int8, int16, int32, int64:
var integer sql.NullInt64
err := integer.Scan(sourceInterface)
if err != nil {
return err
}
if integer.Valid {
destination.SetInt(integer.Int64)
}
case uint, uint8, uint16, uint32, uint64:
var uInt internal.NullUInt64
err := uInt.Scan(sourceInterface)
if err != nil {
return err
}
if uInt.Valid {
destination.SetUint(uInt.UInt64)
}
case string:
var str sql.NullString
err := str.Scan(sourceInterface)
if err != nil {
return err
}
if str.Valid {
destination.SetString(str.String)
}
case time.Time:
var nullTime internal.NullTime
err := nullTime.Scan(sourceInterface)
if err != nil {
return err
}
if nullTime.Valid {
destination.Set(reflect.ValueOf(nullTime.Time))
}
default:
return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface())
}
return nil
}
func setReflectValue(source, destination reflect.Value) {
if tryAssign(source, destination) {
return
}
func setReflectValue(source, destination reflect.Value) error {
if destination.Kind() == reflect.Ptr {
if source.Kind() == reflect.Ptr {
if !source.IsNil() {
if destination.IsNil() {
initializeValueIfNilPtr(destination)
}
if tryAssign(source.Elem(), destination.Elem()) {
return
}
} else {
return
}
} else {
if source.CanAddr() {
source = source.Addr()
} else {
sourceCopy := reflect.New(source.Type())
sourceCopy.Elem().Set(source)
source = sourceCopy
}
if tryAssign(source, destination) {
return
}
if tryAssign(source.Elem(), destination.Elem()) {
return
}
if destination.IsNil() {
initializeValueIfNilPtr(destination)
}
} else {
if source.Kind() == reflect.Ptr {
if source.IsNil() {
return
return nil // source is nil, destination should keep its zero value
}
source = source.Elem()
}
if tryAssign(source, destination) {
return
if err := tryAssign(source, destination.Elem()); err != nil {
return err
}
} else {
if source.Kind() == reflect.Ptr {
if source.IsNil() {
return nil // source is nil, destination should keep its zero value
}
source = source.Elem()
}
if err := tryAssign(source, destination); err != nil {
return err
}
}
panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String())
}
func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
values := make([]interface{}, len(columnTypes))
for i, sqlColumnType := range columnTypes {
columnType := newScanType(sqlColumnType)
columnValue := reflect.New(columnType)
values[i] = columnValue.Interface()
}
return values
}
var boolType = reflect.TypeOf(true)
var int8Type = reflect.TypeOf(int8(1))
var unit8Type = reflect.TypeOf(uint8(1))
var int16Type = reflect.TypeOf(int16(1))
var uint16Type = reflect.TypeOf(uint16(1))
var int32Type = reflect.TypeOf(int32(1))
var uint32Type = reflect.TypeOf(uint32(1))
var int64Type = reflect.TypeOf(int64(1))
var uint64Type = reflect.TypeOf(uint64(1))
var float32Type = reflect.TypeOf(float32(1))
var float64Type = reflect.TypeOf(float64(1))
var stringType = reflect.TypeOf("")
var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullInt8Type = reflect.TypeOf(internal.NullInt8{})
var nullInt16Type = reflect.TypeOf(internal.NullInt16{})
var nullInt32Type = reflect.TypeOf(internal.NullInt32{})
var nullInt64Type = reflect.TypeOf(sql.NullInt64{})
var nullFloat32Type = reflect.TypeOf(internal.NullFloat32{})
var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{})
var nullStringType = reflect.TypeOf(sql.NullString{})
var nullTimeType = reflect.TypeOf(internal.NullTime{})
var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{})
func newScanType(columnType *sql.ColumnType) reflect.Type {
switch columnType.DatabaseTypeName() {
case "TINYINT":
return nullInt8Type
case "INT2", "SMALLINT", "YEAR":
return nullInt16Type
case "INT4", "MEDIUMINT", "INT":
return nullInt32Type
case "INT8", "BIGINT":
return nullInt64Type
case "CHAR", "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML":
return nullStringType
case "FLOAT4":
return nullFloat32Type
case "FLOAT8", "FLOAT", "DOUBLE":
return nullFloat64Type
case "BOOL":
return nullBoolType
case "BYTEA", "BINARY", "VARBINARY", "BLOB":
return nullByteArrayType
case "DATE", "DATETIME", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ":
return nullTimeType
default:
return nullStringType
}
return nil
}
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
@ -385,3 +359,12 @@ func fieldToString(field *reflect.StructField) string {
return " at '" + field.Name + " " + field.Type.String() + "'"
}
func cloneBytes(b []byte) []byte {
if b == nil {
return nil
}
c := make([]byte, len(b))
copy(c, b)
return c
}

View file

@ -17,11 +17,12 @@ import (
)
func TestAllTypesSelect(t *testing.T) {
skipForPgxDriver(t) // pgx driver returns time with time zone as string
dest := []model.AllTypes{}
err := AllTypes.SELECT(AllTypes.AllColumns).Query(db, &dest)
err := AllTypes.SELECT(
AllTypes.AllColumns,
).LIMIT(2).
Query(db, &dest)
require.NoError(t, err)
testutils.AssertDeepEqual(t, dest[0], allTypesRow0)
@ -29,8 +30,6 @@ func TestAllTypesSelect(t *testing.T) {
}
func TestAllTypesViewSelect(t *testing.T) {
skipForPgxDriver(t) // pgx driver returns time with time zone as string
type AllTypesView model.AllTypes
dest := []AllTypesView{}
@ -43,7 +42,7 @@ func TestAllTypesViewSelect(t *testing.T) {
}
func TestAllTypesInsertModel(t *testing.T) {
skipForPgxDriver(t) // pgx driver does not handle well time with time zone
skipForPgxDriver(t) // pgx driver bug ERROR: date/time field value out of range: "0000-01-01 12:05:06Z" (SQLSTATE 22008)
query := AllTypes.INSERT(AllTypes.AllColumns).
MODEL(allTypesRow0).
@ -60,8 +59,6 @@ func TestAllTypesInsertModel(t *testing.T) {
}
func TestAllTypesInsertQuery(t *testing.T) {
skipForPgxDriver(t) // pgx driver does not handle well time with time zone
query := AllTypes.INSERT(AllTypes.AllColumns).
QUERY(
AllTypes.
@ -80,8 +77,6 @@ func TestAllTypesInsertQuery(t *testing.T) {
}
func TestAllTypesFromSubQuery(t *testing.T) {
skipForPgxDriver(t)
subQuery := SELECT(AllTypes.AllColumns).
FROM(AllTypes).
AsTable("allTypesSubQuery")
@ -302,10 +297,10 @@ LIMIT $11;
func TestExpressionCast(t *testing.T) {
skipForPgxDriver(t) // for some reason, pgx driver, 150:char(12) returns as int value
skipForPgxDriver(t) // pgx driver bug 'cannot convert 151 to Text'
query := AllTypes.SELECT(
CAST(Int(150)).AS_CHAR(12).AS("char12"),
CAST(Int(151)).AS_CHAR(12).AS("char12"),
CAST(String("TRUE")).AS_BOOL(),
CAST(String("111")).AS_SMALLINT(),
CAST(String("111")).AS_INTEGER(),
@ -349,7 +344,7 @@ func TestExpressionCast(t *testing.T) {
}
func TestStringOperators(t *testing.T) {
skipForPgxDriver(t) // pgx driver returns text column as int value
skipForPgxDriver(t) // pgx driver bug 'cannot convert 11 to Text'
query := AllTypes.SELECT(
AllTypes.Text.EQ(AllTypes.Char),
@ -866,8 +861,6 @@ func TestInterval(t *testing.T) {
}
func TestSubQueryColumnReference(t *testing.T) {
skipForPgxDriver(t) // pgx driver returns time with time zone as string value
type expected struct {
sql string
args []interface{}
@ -1044,8 +1037,6 @@ FROM`
}
func TestTimeLiterals(t *testing.T) {
skipForPgxDriver(t) // pgx driver returns time with time zone as string
loc, err := time.LoadLocation("Europe/Berlin")
require.NoError(t, err)
@ -1060,8 +1051,6 @@ func TestTimeLiterals(t *testing.T) {
).FROM(AllTypes).
LIMIT(1)
//fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, `
SELECT $1::date AS "date",
$2::time without time zone AS "time",
@ -1073,25 +1062,29 @@ LIMIT $6;
`)
var dest struct {
Date time.Time
Time time.Time
Timez time.Time
Timestamp time.Time
//Timestampz time.Time
Date time.Time
Time time.Time
Timez time.Time
Timestamp time.Time
Timestampz time.Time
}
err = query.Query(db, &dest)
require.NoError(t, err)
//testutils.PrintJson(dest)
// pq driver will return time with time zone in local timezone,
// while pgx driver will return time in UTC time zone
dest.Timez = dest.Timez.UTC()
dest.Timestampz = dest.Timestampz.UTC()
testutils.AssertJSON(t, dest, `
{
"Date": "2009-11-17T00:00:00Z",
"Time": "0000-01-01T20:34:58.651387Z",
"Timez": "0000-01-01T20:34:58.651387+01:00",
"Timestamp": "2009-11-17T20:34:58.651387Z"
"Timez": "0000-01-01T19:34:58.651387Z",
"Timestamp": "2009-11-17T20:34:58.651387Z",
"Timestampz": "2009-11-17T19:34:58.651387Z"
}
`)
requireLogged(t, query)

View file

@ -31,7 +31,7 @@ func TestMain(m *testing.M) {
setTestRoot()
for _, driverName := range []string{"postgres", "pgx"} {
for _, driverName := range []string{"pgx", "postgres"} {
fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName)
func() {
@ -81,8 +81,16 @@ func requireLogged(t *testing.T, statement postgres.Statement) {
}
func skipForPgxDriver(t *testing.T) {
switch db.Driver().(type) {
case *stdlib.Driver:
if isPgxDriver() {
t.SkipNow()
}
}
func isPgxDriver() bool {
switch db.Driver().(type) {
case *stdlib.Driver:
return true
}
return false
}

View file

@ -78,16 +78,31 @@ func TestScanToValidDestination(t *testing.T) {
require.NoError(t, err)
})
t.Run("pointer to slice of strings", func(t *testing.T) {
err := oneInventoryQuery.Query(db, &[]int32{})
t.Run("pointer to slice of integers", func(t *testing.T) {
var dest []int32
err := oneInventoryQuery.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest[0], int32(1))
})
t.Run("pointer to slice of strings", func(t *testing.T) {
err := oneInventoryQuery.Query(db, &[]*int32{})
t.Run("pointer to slice integer pointers", func(t *testing.T) {
var dest []*int32
err := oneInventoryQuery.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest[0], testutils.Int32Ptr(1))
})
t.Run("NULL to integer", func(t *testing.T) {
var dest struct {
Int64 int64
UInt64 uint64
}
err := SELECT(NULL.AS("int64"), NULL.AS("uint64")).Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest.Int64, int64(0))
require.Equal(t, dest.UInt64, uint64(0))
})
}
@ -189,7 +204,9 @@ func TestScanToStruct(t *testing.T) {
dest := Inventory{}
testutils.AssertQueryPanicErr(t, query, db, &dest, `jet: Scan: unable to scan type int32 into UUID, at 'InventoryID uuid.UUID' of type postgres.Inventory`)
err := query.Query(db, &dest)
require.Error(t, err)
require.EqualError(t, err, "jet: can't scan int64('\\x01') to 'InventoryID uuid.UUID': Scan: unable to scan type int64 into UUID")
})
t.Run("type mismatch base type", func(t *testing.T) {
@ -200,7 +217,9 @@ func TestScanToStruct(t *testing.T) {
dest := []Inventory{}
testutils.AssertQueryPanicErr(t, query.OFFSET(10), db, &dest, `jet: can't set int16 to bool`)
err := query.OFFSET(10).Query(db, &dest)
require.Error(t, err)
require.EqualError(t, err, "jet: can't assign int64('\\x02') to 'FilmID bool': can't assign int64(2) to bool")
})
}
@ -451,8 +470,9 @@ func TestScanToSlice(t *testing.T) {
t.Run("slice type mismatch", func(t *testing.T) {
var dest []bool
testutils.AssertQueryPanicErr(t, query, db, &dest, `jet: can't append int32 to []bool slice`)
//require.Error(t, err, `jet: can't append int32 to []bool slice `)
err := query.Query(db, &dest)
require.Error(t, err)
require.EqualError(t, err, `jet: can't append int64 to []bool slice: can't assign int64(2) to bool`)
})
})
@ -764,16 +784,8 @@ func TestRowsScan(t *testing.T) {
requireLogged(t, stmt)
}
func TestScanNumericToNumber(t *testing.T) {
func TestScanNumericToFloat(t *testing.T) {
type Number struct {
Int8 int8
UInt8 uint8
Int16 int16
UInt16 uint16
Int32 int32
UInt32 uint32
Int64 int64
UInt64 uint64
Float32 float32
Float64 float64
}
@ -781,14 +793,6 @@ func TestScanNumericToNumber(t *testing.T) {
numeric := CAST(Decimal("1234567890.111")).AS_NUMERIC()
stmt := SELECT(
numeric.AS("number.int8"),
numeric.AS("number.uint8"),
numeric.AS("number.int16"),
numeric.AS("number.uint16"),
numeric.AS("number.int32"),
numeric.AS("number.uint32"),
numeric.AS("number.int64"),
numeric.AS("number.uint64"),
numeric.AS("number.float32"),
numeric.AS("number.float64"),
)
@ -796,19 +800,30 @@ func TestScanNumericToNumber(t *testing.T) {
var number Number
err := stmt.Query(db, &number)
require.NoError(t, err)
require.Equal(t, number.Int8, int8(-46)) // overflow
require.Equal(t, number.UInt8, uint8(210)) // overflow
require.Equal(t, number.Int16, int16(722)) // overflow
require.Equal(t, number.UInt16, uint16(722)) // overflow
require.Equal(t, number.Int32, int32(1234567890))
require.Equal(t, number.UInt32, uint32(1234567890))
require.Equal(t, number.Int64, int64(1234567890))
require.Equal(t, number.UInt64, uint64(1234567890))
require.Equal(t, number.Float32, float32(1.234568e+09))
require.Equal(t, number.Float64, float64(1.234567890111e+09))
}
func TestScanNumericToIntegerError(t *testing.T) {
var dest struct {
Integer int32
}
err := SELECT(
CAST(Decimal("1234567890.111")).AS_NUMERIC().AS("integer"),
).Query(db, &dest)
require.Error(t, err)
if isPgxDriver() {
require.Contains(t, err.Error(), `jet: can't assign string("1234567890.111") to 'Integer int32': converting driver.Value type string ("1234567890.111") to a int64: invalid syntax`)
} else {
require.Contains(t, err.Error(), `jet: can't assign []uint8("1234567890.111") to 'Integer int32': converting driver.Value type []uint8 ("1234567890.111") to a int64: invalid syntax`)
}
}
// QueryContext panic when the scanned value is nil and the destination is a slice of primitive
// https://github.com/go-jet/jet/issues/91
func TestScanToPrimitiveElementsSlice(t *testing.T) {