Qrm refactor

- Allow custom types Scan method to read values returned by the driver rather then the value from intermediate Null types. Scan to intermidiate Null types removed.
- Better error handling
This commit is contained in:
go-jet 2021-10-15 17:43:10 +02:00
parent 555ec293fb
commit 0d418890ab
11 changed files with 459 additions and 574 deletions

View file

@ -20,10 +20,17 @@ const (
) )
func (e *MpaaRating) Scan(value interface{}) error { func (e *MpaaRating) Scan(value interface{}) error {
if v, ok := value.(string); !ok { var enumValue string
return errors.New("jet: Invalid data for MpaaRating enum") switch val := value.(type) {
} else { case string:
switch string(v) { enumValue = val
case []byte:
enumValue = string(val)
default:
return errors.New("jet: Invalid scan value for AllTypesEnum enum. Enum value has to be of type string or []byte")
}
switch enumValue {
case "G": case "G":
*e = MpaaRating_G *e = MpaaRating_G
case "PG": case "PG":
@ -35,12 +42,11 @@ func (e *MpaaRating) Scan(value interface{}) error {
case "NC-17": case "NC-17":
*e = MpaaRating_Nc17 *e = MpaaRating_Nc17
default: default:
return errors.New("jet: Inavlid data " + string(v) + "for MpaaRating enum") return errors.New("jet: Invalid scan value '" + enumValue + "' for MpaaRating enum")
} }
return nil return nil
} }
}
func (e MpaaRating) String() string { func (e MpaaRating) String() string {
return string(e) return string(e)

View file

@ -200,21 +200,27 @@ const (
) )
func (e *{{$enumTemplate.TypeName}}) Scan(value interface{}) error { func (e *{{$enumTemplate.TypeName}}) Scan(value interface{}) error {
if v, ok := value.(string); !ok { var enumValue string
return errors.New("jet: Invalid scan value for {{$enumTemplate.TypeName}} enum. Enum value has to be of type string") switch val := value.(type) {
} else { case string:
switch string(v) { enumValue = val
case []byte:
enumValue = string(val)
default:
return errors.New("jet: Invalid scan value for AllTypesEnum enum. Enum value has to be of type string or []byte")
}
switch enumValue {
{{- range $_, $value := .Values}} {{- range $_, $value := .Values}}
case "{{$value}}": case "{{$value}}":
*e = {{valueName $value}} *e = {{valueName $value}}
{{- end}} {{- end}}
default: default:
return errors.New("jet: Invalid scan value '" + string(v) + "' for {{$enumTemplate.TypeName}} enum") return errors.New("jet: Invalid scan value '" + enumValue + "' for {{$enumTemplate.TypeName}} enum")
} }
return nil return nil
} }
}
func (e {{$enumTemplate.TypeName}}) String() string { func (e {{$enumTemplate.TypeName}}) String() string {
return string(e) return string(e)

View file

@ -0,0 +1,9 @@
package min
// Int returns minimum of two int values
func Int(a, b int) int {
if a < b {
return a
}
return b
}

View file

@ -1,263 +1,170 @@
package internal package internal
import ( import (
"database/sql"
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/utils/min"
"reflect"
"strconv" "strconv"
"time" "time"
) )
//===============================================================// // NullBool struct
type NullBool struct {
// NullByteArray struct sql.NullBool
type NullByteArray struct {
ByteArray []byte
Valid bool
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (nb *NullByteArray) Scan(value interface{}) error { func (nb *NullBool) Scan(value interface{}) error {
switch v := value.(type) { switch v := value.(type) {
case nil: case bool:
nb.Valid = false nb.Bool, nb.Valid = v, true
return nil case int8, int16, int32, int64, int:
case []byte: intVal := reflect.ValueOf(v).Int()
nb.ByteArray = append(v[:0:0], v...)
if intVal != 0 && intVal != 1 {
return fmt.Errorf("can't assign %T(%d) to bool", value, value)
}
nb.Bool = intVal == 1
nb.Valid = true
case uint8, uint16, uint32, uint64, uint:
uintVal := reflect.ValueOf(v).Uint()
if uintVal != 0 && uintVal != 1 {
return fmt.Errorf("can't assign %T(%d) to bool", value, value)
}
nb.Bool = uintVal == 1
nb.Valid = true nb.Valid = true
return nil
default: default:
return fmt.Errorf("can't scan []byte from %v", value) return nb.NullBool.Scan(value)
}
} }
// Value implements the driver Valuer interface. return nil
func (nb NullByteArray) Value() (driver.Value, error) {
if !nb.Valid {
return nil, nil
} }
return nb.ByteArray, nil
}
//===============================================================//
// NullTime struct // NullTime struct
type NullTime struct { type NullTime struct {
Time time.Time sql.NullTime
Valid bool // Valid is true if Time is not NULL
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (nt *NullTime) Scan(value interface{}) (err error) { func (nt *NullTime) Scan(value interface{}) error {
switch v := value.(type) { err := nt.NullTime.Scan(value)
case nil:
nt.Valid = false if err == nil {
return return nil
case time.Time: }
nt.Time, nt.Valid = v, true
return // Some of the drivers (pgx, mysql) are not parsing all of the time formats(date, time with time zone,...) and are just forwarding string value.
case []byte: // At this point we try to parse time using some of the predefined formats
nt.Time, nt.Valid = parseTime(string(v)) nt.Time, nt.Valid = tryParseAsTime(value)
return
case string: if !nt.Valid {
nt.Time, nt.Valid = parseTime(v)
return
default:
return fmt.Errorf("can't scan time.Time from %v", value) return fmt.Errorf("can't scan time.Time from %v", value)
} }
return nil
} }
// Value implements the driver Valuer interface. var formats = []string{
func (nt NullTime) Value() (driver.Value, error) { "2006-01-02 15:04:05.999999", // go-sql-driver/mysql
if !nt.Valid { "15:04:05-07", // pgx
return nil, nil "15:04:05.999999", // pgx
}
return nt.Time, nil
} }
const formatTime = "2006-01-02 15:04:05.999999" func tryParseAsTime(value interface{}) (time.Time, bool) {
func parseTime(timeStr string) (t time.Time, valid bool) { var timeStr string
var format string
switch len(timeStr) {
case 8:
format = formatTime[11:19]
case 10, 19, 21, 22, 23, 24, 25, 26:
format = formatTime[:len(timeStr)]
default:
return t, false
}
t, err := time.Parse(format, timeStr)
return t, err == nil
}
//===============================================================//
// NullInt8 struct
type NullInt8 struct {
Int8 int8
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt8) Scan(value interface{}) (err error) {
switch v := value.(type) { switch v := value.(type) {
case nil: case string:
n.Valid = false timeStr = v
return
case int64:
n.Int8, n.Valid = int8(v), true
return
case int8:
n.Int8, n.Valid = v, true
return
case []byte: case []byte:
intV, err := strconv.ParseInt(string(v), 10, 8) timeStr = string(v)
if err == nil {
n.Int8, n.Valid = int8(intV), true
}
return err
default:
return fmt.Errorf("can't scan int8 from %v", value)
}
} }
// Value implements the driver Valuer interface. for _, format := range formats {
func (n NullInt8) Value() (driver.Value, error) { formatLen := min.Int(len(format), len(timeStr))
if !n.Valid { t, err := time.Parse(format[:formatLen], timeStr)
return nil, nil
} if err != nil {
return n.Int8, nil continue
} }
//===============================================================// return t, true
}
// NullInt16 struct return time.Time{}, false
type NullInt16 struct { }
Int16 int16
// NullUInt64 struct
type NullUInt64 struct {
UInt64 uint64
Valid bool Valid bool
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (n *NullInt16) Scan(value interface{}) error { func (n *NullUInt64) Scan(value interface{}) error {
var stringValue string
switch v := value.(type) { switch v := value.(type) {
case nil: case nil:
n.Valid = false n.Valid = false
return nil return nil
case int64: case int64:
n.Int16, n.Valid = int16(v), true n.UInt64, n.Valid = uint64(v), true
return nil return nil
case int16: case uint64:
n.Int16, n.Valid = v, true n.UInt64, n.Valid = v, true
return nil
case int8:
n.Int16, n.Valid = int16(v), true
return nil
case uint8:
n.Int16, n.Valid = int16(v), true
return nil
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 16)
if err == nil {
n.Int16, n.Valid = int16(intV), true
}
return nil
default:
return fmt.Errorf("can't scan int16 from %v", value)
}
}
// Value implements the driver Valuer interface.
func (n NullInt16) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int16, nil
}
//===============================================================//
// NullInt32 struct
type NullInt32 struct {
Int32 int32
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt32) Scan(value interface{}) error {
switch v := value.(type) {
case nil:
n.Valid = false
return nil
case int64:
n.Int32, n.Valid = int32(v), true
return nil return nil
case int32: case int32:
n.Int32, n.Valid = v, true n.UInt64, n.Valid = uint64(v), true
return nil
case uint32:
n.UInt64, n.Valid = uint64(v), true
return nil return nil
case int16: case int16:
n.Int32, n.Valid = int32(v), true n.UInt64, n.Valid = uint64(v), true
return nil return nil
case uint16: case uint16:
n.Int32, n.Valid = int32(v), true n.UInt64, n.Valid = uint64(v), true
return nil return nil
case int8: case int8:
n.Int32, n.Valid = int32(v), true n.UInt64, n.Valid = uint64(v), true
return nil return nil
case uint8: case uint8:
n.Int32, n.Valid = int32(v), true n.UInt64, n.Valid = uint64(v), true
return nil
case int:
n.UInt64, n.Valid = uint64(v), true
return nil
case uint:
n.UInt64, n.Valid = uint64(v), true
return nil return nil
case []byte: case []byte:
intV, err := strconv.ParseInt(string(v), 10, 32) stringValue = string(v)
if err == nil { case string:
n.Int32, n.Valid = int32(intV), true stringValue = v
}
return nil
default: default:
return fmt.Errorf("can't scan int32 from %v", value) return fmt.Errorf("can't scan uint64 from %v", value)
} }
uintV, err := strconv.ParseUint(stringValue, 10, 64)
if err != nil {
return err
}
n.UInt64 = uintV
n.Valid = true
return nil
} }
// Value implements the driver Valuer interface. // Value implements the driver Valuer interface.
func (n NullInt32) Value() (driver.Value, error) { func (n NullUInt64) Value() (driver.Value, error) {
if !n.Valid { if !n.Valid {
return nil, nil return nil, nil
} }
return n.Int32, nil return n.UInt64, nil
}
//===============================================================//
// NullFloat32 struct
type NullFloat32 struct {
Float32 float32
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullFloat32) Scan(value interface{}) error {
switch v := value.(type) {
case nil:
n.Valid = false
return nil
case float64:
n.Float32, n.Valid = float32(v), true
return nil
case float32:
n.Float32, n.Valid = v, true
return nil
default:
return fmt.Errorf("can't scan float32 from %v", value)
}
}
// Value implements the driver Valuer interface.
func (n NullFloat32) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Float32, nil
} }

View file

@ -7,141 +7,85 @@ import (
"time" "time"
) )
func TestNullByteArray(t *testing.T) { func TestNullBool(t *testing.T) {
var array NullByteArray var nullBool NullBool
require.NoError(t, array.Scan(nil)) require.NoError(t, nullBool.Scan(nil))
require.Equal(t, array.Valid, false) require.Equal(t, nullBool.Valid, false)
require.NoError(t, array.Scan([]byte("bytea"))) require.NoError(t, nullBool.Scan(int64(1)))
require.Equal(t, array.Valid, true) require.Equal(t, nullBool.Valid, true)
require.Equal(t, string(array.ByteArray), string([]byte("bytea"))) value, _ := nullBool.Value()
require.Equal(t, value, true)
require.Error(t, array.Scan(12), "can't scan []byte from 12") require.NoError(t, nullBool.Scan(uint32(0)))
require.Equal(t, nullBool.Valid, true)
value, _ = nullBool.Value()
require.Equal(t, value, false)
require.EqualError(t, nullBool.Scan(uint16(22)), "can't assign uint16(22) to bool")
} }
func TestNullTime(t *testing.T) { func TestNullTime(t *testing.T) {
var array NullTime var nullTime NullTime
require.NoError(t, array.Scan(nil)) require.NoError(t, nullTime.Scan(nil))
require.Equal(t, array.Valid, false) require.Equal(t, nullTime.Valid, false)
time := time.Now() time := time.Now()
require.NoError(t, array.Scan(time)) require.NoError(t, nullTime.Scan(time))
require.Equal(t, array.Valid, true) require.Equal(t, nullTime.Valid, true)
value, _ := array.Value() value, _ := nullTime.Value()
require.Equal(t, value, time) require.Equal(t, value, time)
require.NoError(t, array.Scan([]byte("13:10:11"))) require.NoError(t, nullTime.Scan([]byte("13:10:11")))
require.Equal(t, array.Valid, true) require.Equal(t, nullTime.Valid, true)
value, _ = array.Value() value, _ = nullTime.Value()
require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
require.NoError(t, array.Scan("13:10:11")) require.NoError(t, nullTime.Scan("13:10:11"))
require.Equal(t, array.Valid, true) require.Equal(t, nullTime.Valid, true)
value, _ = array.Value() value, _ = nullTime.Value()
require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
require.Error(t, array.Scan(12), "can't scan time.Time from 12") require.Error(t, nullTime.Scan(12), "can't scan time.Time from 12")
} }
func TestNullInt8(t *testing.T) { func TestNullUInt64(t *testing.T) {
var array NullInt8 var nullUInt64 NullUInt64
require.NoError(t, array.Scan(nil)) require.NoError(t, nullUInt64.Scan(nil))
require.Equal(t, array.Valid, false) require.Equal(t, nullUInt64.Valid, false)
require.NoError(t, array.Scan(int64(11))) require.NoError(t, nullUInt64.Scan(int64(11)))
require.Equal(t, array.Valid, true) require.Equal(t, nullUInt64.Valid, true)
value, _ := array.Value() value, _ := nullUInt64.Value()
require.Equal(t, value, int8(11)) require.Equal(t, value, uint64(11))
require.Error(t, array.Scan("text"), "can't scan int8 from text") require.NoError(t, nullUInt64.Scan(int32(32)))
} require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
func TestNullInt16(t *testing.T) { require.Equal(t, value, uint64(32))
var array NullInt16
require.NoError(t, nullUInt64.Scan(int16(20)))
require.NoError(t, array.Scan(nil)) require.Equal(t, nullUInt64.Valid, true)
require.Equal(t, array.Valid, false) value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(20))
require.NoError(t, array.Scan(int64(11)))
require.Equal(t, array.Valid, true) require.NoError(t, nullUInt64.Scan(uint16(16)))
value, _ := array.Value() require.Equal(t, nullUInt64.Valid, true)
require.Equal(t, value, int16(11)) value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(16))
require.NoError(t, array.Scan(int16(20)))
require.Equal(t, array.Valid, true) require.NoError(t, nullUInt64.Scan(int8(30)))
value, _ = array.Value() require.Equal(t, nullUInt64.Valid, true)
require.Equal(t, value, int16(20)) value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(30))
require.NoError(t, array.Scan(int8(30)))
require.Equal(t, array.Valid, true) require.NoError(t, nullUInt64.Scan(uint8(30)))
value, _ = array.Value() require.Equal(t, nullUInt64.Valid, true)
require.Equal(t, value, int16(30)) value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(30))
require.NoError(t, array.Scan(uint8(30)))
require.Equal(t, array.Valid, true) require.Error(t, nullUInt64.Scan("text"), "can't scan int32 from text")
value, _ = array.Value()
require.Equal(t, value, int16(30))
require.Error(t, array.Scan("text"), "can't scan int16 from text")
}
func TestNullInt32(t *testing.T) {
var array NullInt32
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(int64(11)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
require.Equal(t, value, int32(11))
require.NoError(t, array.Scan(int32(32)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(32))
require.NoError(t, array.Scan(int16(20)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(20))
require.NoError(t, array.Scan(uint16(16)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(16))
require.NoError(t, array.Scan(int8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(30))
require.NoError(t, array.Scan(uint8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(30))
require.Error(t, array.Scan("text"), "can't scan int32 from text")
}
func TestNullFloat32(t *testing.T) {
var array NullFloat32
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(float64(64)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
require.Equal(t, value, float32(64))
require.NoError(t, array.Scan(float32(32)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, float32(32))
require.Error(t, array.Scan(12), "can't scan float32 from 12")
} }

View file

@ -27,7 +27,10 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
if destinationPtrType.Elem().Kind() == reflect.Slice { if destinationPtrType.Elem().Kind() == reflect.Slice {
_, err := queryToSlice(ctx, db, query, args, destPtr) _, err := queryToSlice(ctx, db, query, args, destPtr)
return err if err != nil {
return fmt.Errorf("jet: %w", err)
}
return nil
} else if destinationPtrType.Elem().Kind() == reflect.Struct { } else if destinationPtrType.Elem().Kind() == reflect.Struct {
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem() tempSliceValue := tempSlicePtrValue.Elem()
@ -35,7 +38,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface()) rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface())
if err != nil { if err != nil {
return err return fmt.Errorf("jet: %w", err)
} }
if rowsProcessed == 0 { if rowsProcessed == 0 {
@ -275,10 +278,16 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
err = scanner.Scan(cellValue) err = scanner.Scan(cellValue)
if err != nil { if err != nil {
panic("jet: " + err.Error() + ", " + fieldToString(&field) + " of type " + structType.String()) err = fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err)
return
} }
} else { } else {
setReflectValue(reflect.ValueOf(cellValue), fieldValue) err = setReflectValue(reflect.ValueOf(cellValue), fieldValue)
if err != nil {
err = fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err)
return
}
} }
} }
} }

View file

@ -2,10 +2,7 @@ package qrm
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/utils"
"github.com/go-jet/jet/v2/internal/utils/throw"
"reflect" "reflect"
"strings" "strings"
) )
@ -46,7 +43,7 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
} }
return &scanContext{ return &scanContext{
row: createScanValue(columnTypes), row: createScanSlice(len(columnTypes)),
uniqueDestObjectsMap: make(map[string]int), uniqueDestObjectsMap: make(map[string]int),
groupKeyInfoCache: make(map[string]groupKeyInfo), groupKeyInfoCache: make(map[string]groupKeyInfo),
@ -56,6 +53,17 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
}, nil }, nil
} }
func createScanSlice(columnCount int) []interface{} {
scanSlice := make([]interface{}, columnCount)
scanPtrSlice := make([]interface{}, columnCount)
for i := range scanPtrSlice {
scanPtrSlice[i] = &scanSlice[i] // if destination is pointer to interface sql.Scan will just forward driver value
}
return scanPtrSlice
}
type typeInfo struct { type typeInfo struct {
fieldMappings []fieldMapping fieldMappings []fieldMapping
} }
@ -210,16 +218,13 @@ func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int {
} }
func (s *scanContext) rowElem(index int) interface{} { func (s *scanContext) rowElem(index int) interface{} {
cellValue := reflect.ValueOf(s.row[index])
valuer, ok := s.row[index].(driver.Valuer) if cellValue.IsValid() && !cellValue.IsNil() {
return cellValue.Elem().Interface()
}
utils.MustBeTrue(ok, "jet: internal error, scan value doesn't implement driver.Valuer") return nil
value, err := valuer.Value()
throw.OnError(err)
return value
} }
func (s *scanContext) rowElemValuePtr(index int) reflect.Value { func (s *scanContext) rowElemValuePtr(index int) reflect.Value {

View file

@ -7,7 +7,6 @@ import (
"github.com/go-jet/jet/v2/qrm/internal" "github.com/go-jet/jet/v2/qrm/internal"
"github.com/google/uuid" "github.com/google/uuid"
"reflect" "reflect"
"strconv"
"strings" "strings"
"time" "time"
) )
@ -56,21 +55,22 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
sliceValue := slicePtrValue.Elem() sliceValue := slicePtrValue.Elem()
sliceElemType := sliceValue.Type().Elem() sliceElemType := sliceValue.Type().Elem()
newElemValue := objPtrValue newSliceElemValue := reflect.New(sliceElemType).Elem()
if sliceElemType.Kind() != reflect.Ptr { var err error
newElemValue = objPtrValue.Elem()
if newSliceElemValue.Kind() == reflect.Ptr {
newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem()))
err = tryAssign(objPtrValue.Elem(), newSliceElemValue.Elem())
} else {
err = tryAssign(objPtrValue.Elem(), newSliceElemValue)
} }
if newElemValue.Type().ConvertibleTo(sliceElemType) { if err != nil {
newElemValue = newElemValue.Convert(sliceElemType) return fmt.Errorf("can't append %T to %T slice: %w", objPtrValue.Elem().Interface(), sliceValue.Interface(), err)
} }
if !newElemValue.Type().AssignableTo(sliceElemType) { sliceValue.Set(reflect.Append(sliceValue, newSliceElemValue))
panic("jet: can't append " + newElemValue.Type().String() + " to " + sliceValue.Type().String() + " slice")
}
sliceValue.Set(reflect.Append(sliceValue, newElemValue))
return nil return nil
} }
@ -121,7 +121,6 @@ func toCommonIdentifier(name string) string {
} }
func initializeValueIfNilPtr(value reflect.Value) { func initializeValueIfNilPtr(value reflect.Value) {
if !value.IsValid() || !value.CanSet() { if !value.IsValid() || !value.CanSet() {
return return
} }
@ -173,172 +172,147 @@ func isSimpleModelType(objType reflect.Type) bool {
return objType == timeType || objType == uuidType || objType == byteArrayType return objType == timeType || objType == uuidType || objType == byteArrayType
} }
func isIntegerType(value reflect.Type) bool { func isFloatType(value reflect.Type) bool {
switch value { switch value.Kind() {
case int8Type, unit8Type, int16Type, uint16Type, case reflect.Float32, reflect.Float64:
int32Type, uint32Type, int64Type, uint64Type:
return true return true
} }
return false return false
} }
func isNumber(valueType reflect.Type) bool { func tryAssign(source, destination reflect.Value) error {
return isIntegerType(valueType) || valueType == float64Type || valueType == float32Type
}
func tryAssign(source, destination reflect.Value) bool { if source.Type() != destination.Type() &&
!isFloatType(destination.Type()) && // to preserve precision during conversion
source.Type().ConvertibleTo(destination.Type()) {
switch {
case source.Type().ConvertibleTo(destination.Type()):
source = source.Convert(destination.Type()) source = source.Convert(destination.Type())
case isIntegerType(source.Type()) && destination.Type() == boolType:
intValue := source.Int()
if intValue == 1 {
source = reflect.ValueOf(true)
} else if intValue == 0 {
source = reflect.ValueOf(false)
}
case source.Type() == stringType && isNumber(destination.Type()):
// if source is string and destination is a number(int8, int32, float32, ...), we first parse string to float64 number
// and then parsed number is converted into destination type
f, err := strconv.ParseFloat(source.String(), 64)
if err != nil {
return false
}
source = reflect.ValueOf(f)
if source.Type().ConvertibleTo(destination.Type()) {
source = source.Convert(destination.Type())
}
} }
if source.Type().AssignableTo(destination.Type()) { if source.Type().AssignableTo(destination.Type()) {
switch b := source.Interface().(type) {
case []byte:
destination.SetBytes(cloneBytes(b))
default:
destination.Set(source) destination.Set(source)
return true }
return nil
} }
return false sourceInterface := source.Interface()
switch destination.Interface().(type) {
case bool:
var nullBool internal.NullBool
err := nullBool.Scan(sourceInterface)
if err != nil {
return err
} }
func setReflectValue(source, destination reflect.Value) { destination.SetBool(nullBool.Bool)
if tryAssign(source, destination) { case float32, float64:
return var nullFloat sql.NullFloat64
err := nullFloat.Scan(sourceInterface)
if err != nil {
return err
} }
if nullFloat.Valid {
destination.SetFloat(nullFloat.Float64)
}
case int, int8, int16, int32, int64:
var integer sql.NullInt64
err := integer.Scan(sourceInterface)
if err != nil {
return err
}
if integer.Valid {
destination.SetInt(integer.Int64)
}
case uint, uint8, uint16, uint32, uint64:
var uInt internal.NullUInt64
err := uInt.Scan(sourceInterface)
if err != nil {
return err
}
if uInt.Valid {
destination.SetUint(uInt.UInt64)
}
case string:
var str sql.NullString
err := str.Scan(sourceInterface)
if err != nil {
return err
}
if str.Valid {
destination.SetString(str.String)
}
case time.Time:
var nullTime internal.NullTime
err := nullTime.Scan(sourceInterface)
if err != nil {
return err
}
if nullTime.Valid {
destination.Set(reflect.ValueOf(nullTime.Time))
}
default:
return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface())
}
return nil
}
func setReflectValue(source, destination reflect.Value) error {
if destination.Kind() == reflect.Ptr { if destination.Kind() == reflect.Ptr {
if source.Kind() == reflect.Ptr {
if !source.IsNil() {
if destination.IsNil() { if destination.IsNil() {
initializeValueIfNilPtr(destination) initializeValueIfNilPtr(destination)
} }
if tryAssign(source.Elem(), destination.Elem()) {
return
}
} else {
return
}
} else {
if source.CanAddr() {
source = source.Addr()
} else {
sourceCopy := reflect.New(source.Type())
sourceCopy.Elem().Set(source)
source = sourceCopy
}
if tryAssign(source, destination) {
return
}
if tryAssign(source.Elem(), destination.Elem()) {
return
}
}
} else {
if source.Kind() == reflect.Ptr { if source.Kind() == reflect.Ptr {
if source.IsNil() { if source.IsNil() {
return return nil // source is nil, destination should keep its zero value
} }
source = source.Elem() source = source.Elem()
} }
if tryAssign(source, destination) { if err := tryAssign(source, destination.Elem()); err != nil {
return return err
}
} else {
if source.Kind() == reflect.Ptr {
if source.IsNil() {
return nil // source is nil, destination should keep its zero value
}
source = source.Elem()
}
if err := tryAssign(source, destination); err != nil {
return err
} }
} }
panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String()) return nil
}
func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
values := make([]interface{}, len(columnTypes))
for i, sqlColumnType := range columnTypes {
columnType := newScanType(sqlColumnType)
columnValue := reflect.New(columnType)
values[i] = columnValue.Interface()
}
return values
}
var boolType = reflect.TypeOf(true)
var int8Type = reflect.TypeOf(int8(1))
var unit8Type = reflect.TypeOf(uint8(1))
var int16Type = reflect.TypeOf(int16(1))
var uint16Type = reflect.TypeOf(uint16(1))
var int32Type = reflect.TypeOf(int32(1))
var uint32Type = reflect.TypeOf(uint32(1))
var int64Type = reflect.TypeOf(int64(1))
var uint64Type = reflect.TypeOf(uint64(1))
var float32Type = reflect.TypeOf(float32(1))
var float64Type = reflect.TypeOf(float64(1))
var stringType = reflect.TypeOf("")
var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullInt8Type = reflect.TypeOf(internal.NullInt8{})
var nullInt16Type = reflect.TypeOf(internal.NullInt16{})
var nullInt32Type = reflect.TypeOf(internal.NullInt32{})
var nullInt64Type = reflect.TypeOf(sql.NullInt64{})
var nullFloat32Type = reflect.TypeOf(internal.NullFloat32{})
var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{})
var nullStringType = reflect.TypeOf(sql.NullString{})
var nullTimeType = reflect.TypeOf(internal.NullTime{})
var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{})
func newScanType(columnType *sql.ColumnType) reflect.Type {
switch columnType.DatabaseTypeName() {
case "TINYINT":
return nullInt8Type
case "INT2", "SMALLINT", "YEAR":
return nullInt16Type
case "INT4", "MEDIUMINT", "INT":
return nullInt32Type
case "INT8", "BIGINT":
return nullInt64Type
case "CHAR", "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML":
return nullStringType
case "FLOAT4":
return nullFloat32Type
case "FLOAT8", "FLOAT", "DOUBLE":
return nullFloat64Type
case "BOOL":
return nullBoolType
case "BYTEA", "BINARY", "VARBINARY", "BLOB":
return nullByteArrayType
case "DATE", "DATETIME", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ":
return nullTimeType
default:
return nullStringType
}
} }
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool { func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
@ -385,3 +359,12 @@ func fieldToString(field *reflect.StructField) string {
return " at '" + field.Name + " " + field.Type.String() + "'" return " at '" + field.Name + " " + field.Type.String() + "'"
} }
func cloneBytes(b []byte) []byte {
if b == nil {
return nil
}
c := make([]byte, len(b))
copy(c, b)
return c
}

View file

@ -17,11 +17,12 @@ import (
) )
func TestAllTypesSelect(t *testing.T) { func TestAllTypesSelect(t *testing.T) {
skipForPgxDriver(t) // pgx driver returns time with time zone as string
dest := []model.AllTypes{} dest := []model.AllTypes{}
err := AllTypes.SELECT(AllTypes.AllColumns).Query(db, &dest) err := AllTypes.SELECT(
AllTypes.AllColumns,
).LIMIT(2).
Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
testutils.AssertDeepEqual(t, dest[0], allTypesRow0) testutils.AssertDeepEqual(t, dest[0], allTypesRow0)
@ -29,8 +30,6 @@ func TestAllTypesSelect(t *testing.T) {
} }
func TestAllTypesViewSelect(t *testing.T) { func TestAllTypesViewSelect(t *testing.T) {
skipForPgxDriver(t) // pgx driver returns time with time zone as string
type AllTypesView model.AllTypes type AllTypesView model.AllTypes
dest := []AllTypesView{} dest := []AllTypesView{}
@ -43,7 +42,7 @@ func TestAllTypesViewSelect(t *testing.T) {
} }
func TestAllTypesInsertModel(t *testing.T) { func TestAllTypesInsertModel(t *testing.T) {
skipForPgxDriver(t) // pgx driver does not handle well time with time zone skipForPgxDriver(t) // pgx driver bug ERROR: date/time field value out of range: "0000-01-01 12:05:06Z" (SQLSTATE 22008)
query := AllTypes.INSERT(AllTypes.AllColumns). query := AllTypes.INSERT(AllTypes.AllColumns).
MODEL(allTypesRow0). MODEL(allTypesRow0).
@ -60,8 +59,6 @@ func TestAllTypesInsertModel(t *testing.T) {
} }
func TestAllTypesInsertQuery(t *testing.T) { func TestAllTypesInsertQuery(t *testing.T) {
skipForPgxDriver(t) // pgx driver does not handle well time with time zone
query := AllTypes.INSERT(AllTypes.AllColumns). query := AllTypes.INSERT(AllTypes.AllColumns).
QUERY( QUERY(
AllTypes. AllTypes.
@ -80,8 +77,6 @@ func TestAllTypesInsertQuery(t *testing.T) {
} }
func TestAllTypesFromSubQuery(t *testing.T) { func TestAllTypesFromSubQuery(t *testing.T) {
skipForPgxDriver(t)
subQuery := SELECT(AllTypes.AllColumns). subQuery := SELECT(AllTypes.AllColumns).
FROM(AllTypes). FROM(AllTypes).
AsTable("allTypesSubQuery") AsTable("allTypesSubQuery")
@ -302,10 +297,10 @@ LIMIT $11;
func TestExpressionCast(t *testing.T) { func TestExpressionCast(t *testing.T) {
skipForPgxDriver(t) // for some reason, pgx driver, 150:char(12) returns as int value skipForPgxDriver(t) // pgx driver bug 'cannot convert 151 to Text'
query := AllTypes.SELECT( query := AllTypes.SELECT(
CAST(Int(150)).AS_CHAR(12).AS("char12"), CAST(Int(151)).AS_CHAR(12).AS("char12"),
CAST(String("TRUE")).AS_BOOL(), CAST(String("TRUE")).AS_BOOL(),
CAST(String("111")).AS_SMALLINT(), CAST(String("111")).AS_SMALLINT(),
CAST(String("111")).AS_INTEGER(), CAST(String("111")).AS_INTEGER(),
@ -349,7 +344,7 @@ func TestExpressionCast(t *testing.T) {
} }
func TestStringOperators(t *testing.T) { func TestStringOperators(t *testing.T) {
skipForPgxDriver(t) // pgx driver returns text column as int value skipForPgxDriver(t) // pgx driver bug 'cannot convert 11 to Text'
query := AllTypes.SELECT( query := AllTypes.SELECT(
AllTypes.Text.EQ(AllTypes.Char), AllTypes.Text.EQ(AllTypes.Char),
@ -866,8 +861,6 @@ func TestInterval(t *testing.T) {
} }
func TestSubQueryColumnReference(t *testing.T) { func TestSubQueryColumnReference(t *testing.T) {
skipForPgxDriver(t) // pgx driver returns time with time zone as string value
type expected struct { type expected struct {
sql string sql string
args []interface{} args []interface{}
@ -1044,8 +1037,6 @@ FROM`
} }
func TestTimeLiterals(t *testing.T) { func TestTimeLiterals(t *testing.T) {
skipForPgxDriver(t) // pgx driver returns time with time zone as string
loc, err := time.LoadLocation("Europe/Berlin") loc, err := time.LoadLocation("Europe/Berlin")
require.NoError(t, err) require.NoError(t, err)
@ -1060,8 +1051,6 @@ func TestTimeLiterals(t *testing.T) {
).FROM(AllTypes). ).FROM(AllTypes).
LIMIT(1) LIMIT(1)
//fmt.Println(query.Sql())
testutils.AssertStatementSql(t, query, ` testutils.AssertStatementSql(t, query, `
SELECT $1::date AS "date", SELECT $1::date AS "date",
$2::time without time zone AS "time", $2::time without time zone AS "time",
@ -1077,21 +1066,25 @@ LIMIT $6;
Time time.Time Time time.Time
Timez time.Time Timez time.Time
Timestamp time.Time Timestamp time.Time
//Timestampz time.Time Timestampz time.Time
} }
err = query.Query(db, &dest) err = query.Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
//testutils.PrintJson(dest) // pq driver will return time with time zone in local timezone,
// while pgx driver will return time in UTC time zone
dest.Timez = dest.Timez.UTC()
dest.Timestampz = dest.Timestampz.UTC()
testutils.AssertJSON(t, dest, ` testutils.AssertJSON(t, dest, `
{ {
"Date": "2009-11-17T00:00:00Z", "Date": "2009-11-17T00:00:00Z",
"Time": "0000-01-01T20:34:58.651387Z", "Time": "0000-01-01T20:34:58.651387Z",
"Timez": "0000-01-01T20:34:58.651387+01:00", "Timez": "0000-01-01T19:34:58.651387Z",
"Timestamp": "2009-11-17T20:34:58.651387Z" "Timestamp": "2009-11-17T20:34:58.651387Z",
"Timestampz": "2009-11-17T19:34:58.651387Z"
} }
`) `)
requireLogged(t, query) requireLogged(t, query)

View file

@ -31,7 +31,7 @@ func TestMain(m *testing.M) {
setTestRoot() setTestRoot()
for _, driverName := range []string{"postgres", "pgx"} { for _, driverName := range []string{"pgx", "postgres"} {
fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName) fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName)
func() { func() {
@ -81,8 +81,16 @@ func requireLogged(t *testing.T, statement postgres.Statement) {
} }
func skipForPgxDriver(t *testing.T) { func skipForPgxDriver(t *testing.T) {
switch db.Driver().(type) { if isPgxDriver() {
case *stdlib.Driver:
t.SkipNow() t.SkipNow()
} }
} }
func isPgxDriver() bool {
switch db.Driver().(type) {
case *stdlib.Driver:
return true
}
return false
}

View file

@ -78,16 +78,31 @@ func TestScanToValidDestination(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}) })
t.Run("pointer to slice of strings", func(t *testing.T) { t.Run("pointer to slice of integers", func(t *testing.T) {
err := oneInventoryQuery.Query(db, &[]int32{}) var dest []int32
err := oneInventoryQuery.Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, dest[0], int32(1))
}) })
t.Run("pointer to slice of strings", func(t *testing.T) { t.Run("pointer to slice integer pointers", func(t *testing.T) {
err := oneInventoryQuery.Query(db, &[]*int32{}) var dest []*int32
err := oneInventoryQuery.Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, dest[0], testutils.Int32Ptr(1))
})
t.Run("NULL to integer", func(t *testing.T) {
var dest struct {
Int64 int64
UInt64 uint64
}
err := SELECT(NULL.AS("int64"), NULL.AS("uint64")).Query(db, &dest)
require.NoError(t, err)
require.Equal(t, dest.Int64, int64(0))
require.Equal(t, dest.UInt64, uint64(0))
}) })
} }
@ -189,7 +204,9 @@ func TestScanToStruct(t *testing.T) {
dest := Inventory{} dest := Inventory{}
testutils.AssertQueryPanicErr(t, query, db, &dest, `jet: Scan: unable to scan type int32 into UUID, at 'InventoryID uuid.UUID' of type postgres.Inventory`) err := query.Query(db, &dest)
require.Error(t, err)
require.EqualError(t, err, "jet: can't scan int64('\\x01') to 'InventoryID uuid.UUID': Scan: unable to scan type int64 into UUID")
}) })
t.Run("type mismatch base type", func(t *testing.T) { t.Run("type mismatch base type", func(t *testing.T) {
@ -200,7 +217,9 @@ func TestScanToStruct(t *testing.T) {
dest := []Inventory{} dest := []Inventory{}
testutils.AssertQueryPanicErr(t, query.OFFSET(10), db, &dest, `jet: can't set int16 to bool`) err := query.OFFSET(10).Query(db, &dest)
require.Error(t, err)
require.EqualError(t, err, "jet: can't assign int64('\\x02') to 'FilmID bool': can't assign int64(2) to bool")
}) })
} }
@ -451,8 +470,9 @@ func TestScanToSlice(t *testing.T) {
t.Run("slice type mismatch", func(t *testing.T) { t.Run("slice type mismatch", func(t *testing.T) {
var dest []bool var dest []bool
testutils.AssertQueryPanicErr(t, query, db, &dest, `jet: can't append int32 to []bool slice`) err := query.Query(db, &dest)
//require.Error(t, err, `jet: can't append int32 to []bool slice `) require.Error(t, err)
require.EqualError(t, err, `jet: can't append int64 to []bool slice: can't assign int64(2) to bool`)
}) })
}) })
@ -764,16 +784,8 @@ func TestRowsScan(t *testing.T) {
requireLogged(t, stmt) requireLogged(t, stmt)
} }
func TestScanNumericToNumber(t *testing.T) { func TestScanNumericToFloat(t *testing.T) {
type Number struct { type Number struct {
Int8 int8
UInt8 uint8
Int16 int16
UInt16 uint16
Int32 int32
UInt32 uint32
Int64 int64
UInt64 uint64
Float32 float32 Float32 float32
Float64 float64 Float64 float64
} }
@ -781,14 +793,6 @@ func TestScanNumericToNumber(t *testing.T) {
numeric := CAST(Decimal("1234567890.111")).AS_NUMERIC() numeric := CAST(Decimal("1234567890.111")).AS_NUMERIC()
stmt := SELECT( stmt := SELECT(
numeric.AS("number.int8"),
numeric.AS("number.uint8"),
numeric.AS("number.int16"),
numeric.AS("number.uint16"),
numeric.AS("number.int32"),
numeric.AS("number.uint32"),
numeric.AS("number.int64"),
numeric.AS("number.uint64"),
numeric.AS("number.float32"), numeric.AS("number.float32"),
numeric.AS("number.float64"), numeric.AS("number.float64"),
) )
@ -796,19 +800,30 @@ func TestScanNumericToNumber(t *testing.T) {
var number Number var number Number
err := stmt.Query(db, &number) err := stmt.Query(db, &number)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, number.Int8, int8(-46)) // overflow
require.Equal(t, number.UInt8, uint8(210)) // overflow
require.Equal(t, number.Int16, int16(722)) // overflow
require.Equal(t, number.UInt16, uint16(722)) // overflow
require.Equal(t, number.Int32, int32(1234567890))
require.Equal(t, number.UInt32, uint32(1234567890))
require.Equal(t, number.Int64, int64(1234567890))
require.Equal(t, number.UInt64, uint64(1234567890))
require.Equal(t, number.Float32, float32(1.234568e+09)) require.Equal(t, number.Float32, float32(1.234568e+09))
require.Equal(t, number.Float64, float64(1.234567890111e+09)) require.Equal(t, number.Float64, float64(1.234567890111e+09))
} }
func TestScanNumericToIntegerError(t *testing.T) {
var dest struct {
Integer int32
}
err := SELECT(
CAST(Decimal("1234567890.111")).AS_NUMERIC().AS("integer"),
).Query(db, &dest)
require.Error(t, err)
if isPgxDriver() {
require.Contains(t, err.Error(), `jet: can't assign string("1234567890.111") to 'Integer int32': converting driver.Value type string ("1234567890.111") to a int64: invalid syntax`)
} else {
require.Contains(t, err.Error(), `jet: can't assign []uint8("1234567890.111") to 'Integer int32': converting driver.Value type []uint8 ("1234567890.111") to a int64: invalid syntax`)
}
}
// QueryContext panic when the scanned value is nil and the destination is a slice of primitive // QueryContext panic when the scanned value is nil and the destination is a slice of primitive
// https://github.com/go-jet/jet/issues/91 // https://github.com/go-jet/jet/issues/91
func TestScanToPrimitiveElementsSlice(t *testing.T) { func TestScanToPrimitiveElementsSlice(t *testing.T) {