Qrm refactor

- Allow custom types Scan method to read values returned by the driver rather then the value from intermediate Null types. Scan to intermidiate Null types removed.
- Better error handling
This commit is contained in:
go-jet 2021-10-15 17:43:10 +02:00
parent 555ec293fb
commit 0d418890ab
11 changed files with 459 additions and 574 deletions

View file

@ -1,263 +1,170 @@
package internal
import (
"database/sql"
"database/sql/driver"
"fmt"
"github.com/go-jet/jet/v2/internal/utils/min"
"reflect"
"strconv"
"time"
)
//===============================================================//
// NullByteArray struct
type NullByteArray struct {
ByteArray []byte
Valid bool
// NullBool struct
type NullBool struct {
sql.NullBool
}
// Scan implements the Scanner interface.
func (nb *NullByteArray) Scan(value interface{}) error {
func (nb *NullBool) Scan(value interface{}) error {
switch v := value.(type) {
case nil:
nb.Valid = false
return nil
case []byte:
nb.ByteArray = append(v[:0:0], v...)
case bool:
nb.Bool, nb.Valid = v, true
case int8, int16, int32, int64, int:
intVal := reflect.ValueOf(v).Int()
if intVal != 0 && intVal != 1 {
return fmt.Errorf("can't assign %T(%d) to bool", value, value)
}
nb.Bool = intVal == 1
nb.Valid = true
case uint8, uint16, uint32, uint64, uint:
uintVal := reflect.ValueOf(v).Uint()
if uintVal != 0 && uintVal != 1 {
return fmt.Errorf("can't assign %T(%d) to bool", value, value)
}
nb.Bool = uintVal == 1
nb.Valid = true
return nil
default:
return fmt.Errorf("can't scan []byte from %v", value)
return nb.NullBool.Scan(value)
}
}
// Value implements the driver Valuer interface.
func (nb NullByteArray) Value() (driver.Value, error) {
if !nb.Valid {
return nil, nil
}
return nb.ByteArray, nil
return nil
}
//===============================================================//
// NullTime struct
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
sql.NullTime
}
// Scan implements the Scanner interface.
func (nt *NullTime) Scan(value interface{}) (err error) {
switch v := value.(type) {
case nil:
nt.Valid = false
return
case time.Time:
nt.Time, nt.Valid = v, true
return
case []byte:
nt.Time, nt.Valid = parseTime(string(v))
return
case string:
nt.Time, nt.Valid = parseTime(v)
return
default:
func (nt *NullTime) Scan(value interface{}) error {
err := nt.NullTime.Scan(value)
if err == nil {
return nil
}
// Some of the drivers (pgx, mysql) are not parsing all of the time formats(date, time with time zone,...) and are just forwarding string value.
// At this point we try to parse time using some of the predefined formats
nt.Time, nt.Valid = tryParseAsTime(value)
if !nt.Valid {
return fmt.Errorf("can't scan time.Time from %v", value)
}
return nil
}
// Value implements the driver Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
var formats = []string{
"2006-01-02 15:04:05.999999", // go-sql-driver/mysql
"15:04:05-07", // pgx
"15:04:05.999999", // pgx
}
const formatTime = "2006-01-02 15:04:05.999999"
func tryParseAsTime(value interface{}) (time.Time, bool) {
func parseTime(timeStr string) (t time.Time, valid bool) {
var timeStr string
var format string
switch len(timeStr) {
case 8:
format = formatTime[11:19]
case 10, 19, 21, 22, 23, 24, 25, 26:
format = formatTime[:len(timeStr)]
default:
return t, false
}
t, err := time.Parse(format, timeStr)
return t, err == nil
}
//===============================================================//
// NullInt8 struct
type NullInt8 struct {
Int8 int8
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt8) Scan(value interface{}) (err error) {
switch v := value.(type) {
case nil:
n.Valid = false
return
case int64:
n.Int8, n.Valid = int8(v), true
return
case int8:
n.Int8, n.Valid = v, true
return
case string:
timeStr = v
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 8)
if err == nil {
n.Int8, n.Valid = int8(intV), true
timeStr = string(v)
}
for _, format := range formats {
formatLen := min.Int(len(format), len(timeStr))
t, err := time.Parse(format[:formatLen], timeStr)
if err != nil {
continue
}
return err
default:
return fmt.Errorf("can't scan int8 from %v", value)
return t, true
}
return time.Time{}, false
}
// Value implements the driver Valuer interface.
func (n NullInt8) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int8, nil
}
//===============================================================//
// NullInt16 struct
type NullInt16 struct {
Int16 int16
Valid bool
// NullUInt64 struct
type NullUInt64 struct {
UInt64 uint64
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt16) Scan(value interface{}) error {
func (n *NullUInt64) Scan(value interface{}) error {
var stringValue string
switch v := value.(type) {
case nil:
n.Valid = false
return nil
case int64:
n.Int16, n.Valid = int16(v), true
n.UInt64, n.Valid = uint64(v), true
return nil
case int16:
n.Int16, n.Valid = v, true
return nil
case int8:
n.Int16, n.Valid = int16(v), true
return nil
case uint8:
n.Int16, n.Valid = int16(v), true
return nil
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 16)
if err == nil {
n.Int16, n.Valid = int16(intV), true
}
return nil
default:
return fmt.Errorf("can't scan int16 from %v", value)
}
}
// Value implements the driver Valuer interface.
func (n NullInt16) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int16, nil
}
//===============================================================//
// NullInt32 struct
type NullInt32 struct {
Int32 int32
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullInt32) Scan(value interface{}) error {
switch v := value.(type) {
case nil:
n.Valid = false
return nil
case int64:
n.Int32, n.Valid = int32(v), true
case uint64:
n.UInt64, n.Valid = v, true
return nil
case int32:
n.Int32, n.Valid = v, true
n.UInt64, n.Valid = uint64(v), true
return nil
case uint32:
n.UInt64, n.Valid = uint64(v), true
return nil
case int16:
n.Int32, n.Valid = int32(v), true
n.UInt64, n.Valid = uint64(v), true
return nil
case uint16:
n.Int32, n.Valid = int32(v), true
n.UInt64, n.Valid = uint64(v), true
return nil
case int8:
n.Int32, n.Valid = int32(v), true
n.UInt64, n.Valid = uint64(v), true
return nil
case uint8:
n.Int32, n.Valid = int32(v), true
n.UInt64, n.Valid = uint64(v), true
return nil
case int:
n.UInt64, n.Valid = uint64(v), true
return nil
case uint:
n.UInt64, n.Valid = uint64(v), true
return nil
case []byte:
intV, err := strconv.ParseInt(string(v), 10, 32)
if err == nil {
n.Int32, n.Valid = int32(intV), true
}
return nil
stringValue = string(v)
case string:
stringValue = v
default:
return fmt.Errorf("can't scan int32 from %v", value)
return fmt.Errorf("can't scan uint64 from %v", value)
}
uintV, err := strconv.ParseUint(stringValue, 10, 64)
if err != nil {
return err
}
n.UInt64 = uintV
n.Valid = true
return nil
}
// Value implements the driver Valuer interface.
func (n NullInt32) Value() (driver.Value, error) {
func (n NullUInt64) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int32, nil
}
//===============================================================//
// NullFloat32 struct
type NullFloat32 struct {
Float32 float32
Valid bool
}
// Scan implements the Scanner interface.
func (n *NullFloat32) Scan(value interface{}) error {
switch v := value.(type) {
case nil:
n.Valid = false
return nil
case float64:
n.Float32, n.Valid = float32(v), true
return nil
case float32:
n.Float32, n.Valid = v, true
return nil
default:
return fmt.Errorf("can't scan float32 from %v", value)
}
}
// Value implements the driver Valuer interface.
func (n NullFloat32) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Float32, nil
return n.UInt64, nil
}

View file

@ -7,141 +7,85 @@ import (
"time"
)
func TestNullByteArray(t *testing.T) {
var array NullByteArray
func TestNullBool(t *testing.T) {
var nullBool NullBool
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, nullBool.Scan(nil))
require.Equal(t, nullBool.Valid, false)
require.NoError(t, array.Scan([]byte("bytea")))
require.Equal(t, array.Valid, true)
require.Equal(t, string(array.ByteArray), string([]byte("bytea")))
require.NoError(t, nullBool.Scan(int64(1)))
require.Equal(t, nullBool.Valid, true)
value, _ := nullBool.Value()
require.Equal(t, value, true)
require.Error(t, array.Scan(12), "can't scan []byte from 12")
require.NoError(t, nullBool.Scan(uint32(0)))
require.Equal(t, nullBool.Valid, true)
value, _ = nullBool.Value()
require.Equal(t, value, false)
require.EqualError(t, nullBool.Scan(uint16(22)), "can't assign uint16(22) to bool")
}
func TestNullTime(t *testing.T) {
var array NullTime
var nullTime NullTime
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, nullTime.Scan(nil))
require.Equal(t, nullTime.Valid, false)
time := time.Now()
require.NoError(t, array.Scan(time))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
require.NoError(t, nullTime.Scan(time))
require.Equal(t, nullTime.Valid, true)
value, _ := nullTime.Value()
require.Equal(t, value, time)
require.NoError(t, array.Scan([]byte("13:10:11")))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.NoError(t, nullTime.Scan([]byte("13:10:11")))
require.Equal(t, nullTime.Valid, true)
value, _ = nullTime.Value()
require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
require.NoError(t, array.Scan("13:10:11"))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.NoError(t, nullTime.Scan("13:10:11"))
require.Equal(t, nullTime.Valid, true)
value, _ = nullTime.Value()
require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
require.Error(t, array.Scan(12), "can't scan time.Time from 12")
require.Error(t, nullTime.Scan(12), "can't scan time.Time from 12")
}
func TestNullInt8(t *testing.T) {
var array NullInt8
func TestNullUInt64(t *testing.T) {
var nullUInt64 NullUInt64
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, nullUInt64.Scan(nil))
require.Equal(t, nullUInt64.Valid, false)
require.NoError(t, array.Scan(int64(11)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
require.Equal(t, value, int8(11))
require.NoError(t, nullUInt64.Scan(int64(11)))
require.Equal(t, nullUInt64.Valid, true)
value, _ := nullUInt64.Value()
require.Equal(t, value, uint64(11))
require.Error(t, array.Scan("text"), "can't scan int8 from text")
}
func TestNullInt16(t *testing.T) {
var array NullInt16
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(int64(11)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
require.Equal(t, value, int16(11))
require.NoError(t, array.Scan(int16(20)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int16(20))
require.NoError(t, array.Scan(int8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int16(30))
require.NoError(t, array.Scan(uint8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int16(30))
require.Error(t, array.Scan("text"), "can't scan int16 from text")
}
func TestNullInt32(t *testing.T) {
var array NullInt32
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(int64(11)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
require.Equal(t, value, int32(11))
require.NoError(t, array.Scan(int32(32)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(32))
require.NoError(t, array.Scan(int16(20)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(20))
require.NoError(t, array.Scan(uint16(16)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(16))
require.NoError(t, array.Scan(int8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(30))
require.NoError(t, array.Scan(uint8(30)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, int32(30))
require.Error(t, array.Scan("text"), "can't scan int32 from text")
}
func TestNullFloat32(t *testing.T) {
var array NullFloat32
require.NoError(t, array.Scan(nil))
require.Equal(t, array.Valid, false)
require.NoError(t, array.Scan(float64(64)))
require.Equal(t, array.Valid, true)
value, _ := array.Value()
require.Equal(t, value, float32(64))
require.NoError(t, array.Scan(float32(32)))
require.Equal(t, array.Valid, true)
value, _ = array.Value()
require.Equal(t, value, float32(32))
require.Error(t, array.Scan(12), "can't scan float32 from 12")
require.NoError(t, nullUInt64.Scan(int32(32)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(32))
require.NoError(t, nullUInt64.Scan(int16(20)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(20))
require.NoError(t, nullUInt64.Scan(uint16(16)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(16))
require.NoError(t, nullUInt64.Scan(int8(30)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(30))
require.NoError(t, nullUInt64.Scan(uint8(30)))
require.Equal(t, nullUInt64.Valid, true)
value, _ = nullUInt64.Value()
require.Equal(t, value, uint64(30))
require.Error(t, nullUInt64.Scan("text"), "can't scan int32 from text")
}