Qrm refactor
- Allow custom types Scan method to read values returned by the driver rather then the value from intermediate Null types. Scan to intermidiate Null types removed. - Better error handling
This commit is contained in:
parent
555ec293fb
commit
0d418890ab
11 changed files with 459 additions and 574 deletions
|
|
@ -20,10 +20,17 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (e *MpaaRating) Scan(value interface{}) error {
|
func (e *MpaaRating) Scan(value interface{}) error {
|
||||||
if v, ok := value.(string); !ok {
|
var enumValue string
|
||||||
return errors.New("jet: Invalid data for MpaaRating enum")
|
switch val := value.(type) {
|
||||||
} else {
|
case string:
|
||||||
switch string(v) {
|
enumValue = val
|
||||||
|
case []byte:
|
||||||
|
enumValue = string(val)
|
||||||
|
default:
|
||||||
|
return errors.New("jet: Invalid scan value for AllTypesEnum enum. Enum value has to be of type string or []byte")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch enumValue {
|
||||||
case "G":
|
case "G":
|
||||||
*e = MpaaRating_G
|
*e = MpaaRating_G
|
||||||
case "PG":
|
case "PG":
|
||||||
|
|
@ -35,11 +42,10 @@ func (e *MpaaRating) Scan(value interface{}) error {
|
||||||
case "NC-17":
|
case "NC-17":
|
||||||
*e = MpaaRating_Nc17
|
*e = MpaaRating_Nc17
|
||||||
default:
|
default:
|
||||||
return errors.New("jet: Inavlid data " + string(v) + "for MpaaRating enum")
|
return errors.New("jet: Invalid scan value '" + enumValue + "' for MpaaRating enum")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e MpaaRating) String() string {
|
func (e MpaaRating) String() string {
|
||||||
|
|
|
||||||
|
|
@ -200,20 +200,26 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (e *{{$enumTemplate.TypeName}}) Scan(value interface{}) error {
|
func (e *{{$enumTemplate.TypeName}}) Scan(value interface{}) error {
|
||||||
if v, ok := value.(string); !ok {
|
var enumValue string
|
||||||
return errors.New("jet: Invalid scan value for {{$enumTemplate.TypeName}} enum. Enum value has to be of type string")
|
switch val := value.(type) {
|
||||||
} else {
|
case string:
|
||||||
switch string(v) {
|
enumValue = val
|
||||||
|
case []byte:
|
||||||
|
enumValue = string(val)
|
||||||
|
default:
|
||||||
|
return errors.New("jet: Invalid scan value for AllTypesEnum enum. Enum value has to be of type string or []byte")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch enumValue {
|
||||||
{{- range $_, $value := .Values}}
|
{{- range $_, $value := .Values}}
|
||||||
case "{{$value}}":
|
case "{{$value}}":
|
||||||
*e = {{valueName $value}}
|
*e = {{valueName $value}}
|
||||||
{{- end}}
|
{{- end}}
|
||||||
default:
|
default:
|
||||||
return errors.New("jet: Invalid scan value '" + string(v) + "' for {{$enumTemplate.TypeName}} enum")
|
return errors.New("jet: Invalid scan value '" + enumValue + "' for {{$enumTemplate.TypeName}} enum")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e {{$enumTemplate.TypeName}}) String() string {
|
func (e {{$enumTemplate.TypeName}}) String() string {
|
||||||
|
|
|
||||||
9
internal/utils/min/min.go
Normal file
9
internal/utils/min/min.go
Normal file
|
|
@ -0,0 +1,9 @@
|
||||||
|
package min
|
||||||
|
|
||||||
|
// Int returns minimum of two int values
|
||||||
|
func Int(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
@ -1,263 +1,170 @@
|
||||||
package internal
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/go-jet/jet/v2/internal/utils/min"
|
||||||
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
//===============================================================//
|
// NullBool struct
|
||||||
|
type NullBool struct {
|
||||||
// NullByteArray struct
|
sql.NullBool
|
||||||
type NullByteArray struct {
|
|
||||||
ByteArray []byte
|
|
||||||
Valid bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the Scanner interface.
|
// Scan implements the Scanner interface.
|
||||||
func (nb *NullByteArray) Scan(value interface{}) error {
|
func (nb *NullBool) Scan(value interface{}) error {
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case nil:
|
case bool:
|
||||||
nb.Valid = false
|
nb.Bool, nb.Valid = v, true
|
||||||
return nil
|
case int8, int16, int32, int64, int:
|
||||||
case []byte:
|
intVal := reflect.ValueOf(v).Int()
|
||||||
nb.ByteArray = append(v[:0:0], v...)
|
|
||||||
|
if intVal != 0 && intVal != 1 {
|
||||||
|
return fmt.Errorf("can't assign %T(%d) to bool", value, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
nb.Bool = intVal == 1
|
||||||
|
nb.Valid = true
|
||||||
|
case uint8, uint16, uint32, uint64, uint:
|
||||||
|
uintVal := reflect.ValueOf(v).Uint()
|
||||||
|
|
||||||
|
if uintVal != 0 && uintVal != 1 {
|
||||||
|
return fmt.Errorf("can't assign %T(%d) to bool", value, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
nb.Bool = uintVal == 1
|
||||||
nb.Valid = true
|
nb.Valid = true
|
||||||
return nil
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("can't scan []byte from %v", value)
|
return nb.NullBool.Scan(value)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Value implements the driver Valuer interface.
|
return nil
|
||||||
func (nb NullByteArray) Value() (driver.Value, error) {
|
|
||||||
if !nb.Valid {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return nb.ByteArray, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===============================================================//
|
|
||||||
|
|
||||||
// NullTime struct
|
// NullTime struct
|
||||||
type NullTime struct {
|
type NullTime struct {
|
||||||
Time time.Time
|
sql.NullTime
|
||||||
Valid bool // Valid is true if Time is not NULL
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the Scanner interface.
|
// Scan implements the Scanner interface.
|
||||||
func (nt *NullTime) Scan(value interface{}) (err error) {
|
func (nt *NullTime) Scan(value interface{}) error {
|
||||||
switch v := value.(type) {
|
err := nt.NullTime.Scan(value)
|
||||||
case nil:
|
|
||||||
nt.Valid = false
|
if err == nil {
|
||||||
return
|
return nil
|
||||||
case time.Time:
|
}
|
||||||
nt.Time, nt.Valid = v, true
|
|
||||||
return
|
// Some of the drivers (pgx, mysql) are not parsing all of the time formats(date, time with time zone,...) and are just forwarding string value.
|
||||||
case []byte:
|
// At this point we try to parse time using some of the predefined formats
|
||||||
nt.Time, nt.Valid = parseTime(string(v))
|
nt.Time, nt.Valid = tryParseAsTime(value)
|
||||||
return
|
|
||||||
case string:
|
if !nt.Valid {
|
||||||
nt.Time, nt.Valid = parseTime(v)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("can't scan time.Time from %v", value)
|
return fmt.Errorf("can't scan time.Time from %v", value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the driver Valuer interface.
|
var formats = []string{
|
||||||
func (nt NullTime) Value() (driver.Value, error) {
|
"2006-01-02 15:04:05.999999", // go-sql-driver/mysql
|
||||||
if !nt.Valid {
|
"15:04:05-07", // pgx
|
||||||
return nil, nil
|
"15:04:05.999999", // pgx
|
||||||
}
|
|
||||||
return nt.Time, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const formatTime = "2006-01-02 15:04:05.999999"
|
func tryParseAsTime(value interface{}) (time.Time, bool) {
|
||||||
|
|
||||||
func parseTime(timeStr string) (t time.Time, valid bool) {
|
var timeStr string
|
||||||
|
|
||||||
var format string
|
|
||||||
|
|
||||||
switch len(timeStr) {
|
|
||||||
case 8:
|
|
||||||
format = formatTime[11:19]
|
|
||||||
case 10, 19, 21, 22, 23, 24, 25, 26:
|
|
||||||
format = formatTime[:len(timeStr)]
|
|
||||||
default:
|
|
||||||
return t, false
|
|
||||||
}
|
|
||||||
|
|
||||||
t, err := time.Parse(format, timeStr)
|
|
||||||
return t, err == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//===============================================================//
|
|
||||||
|
|
||||||
// NullInt8 struct
|
|
||||||
type NullInt8 struct {
|
|
||||||
Int8 int8
|
|
||||||
Valid bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan implements the Scanner interface.
|
|
||||||
func (n *NullInt8) Scan(value interface{}) (err error) {
|
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case nil:
|
case string:
|
||||||
n.Valid = false
|
timeStr = v
|
||||||
return
|
|
||||||
case int64:
|
|
||||||
n.Int8, n.Valid = int8(v), true
|
|
||||||
return
|
|
||||||
case int8:
|
|
||||||
n.Int8, n.Valid = v, true
|
|
||||||
return
|
|
||||||
case []byte:
|
case []byte:
|
||||||
intV, err := strconv.ParseInt(string(v), 10, 8)
|
timeStr = string(v)
|
||||||
if err == nil {
|
|
||||||
n.Int8, n.Valid = int8(intV), true
|
|
||||||
}
|
}
|
||||||
return err
|
|
||||||
default:
|
for _, format := range formats {
|
||||||
return fmt.Errorf("can't scan int8 from %v", value)
|
formatLen := min.Int(len(format), len(timeStr))
|
||||||
|
t, err := time.Parse(format[:formatLen], timeStr)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return t, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Time{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the driver Valuer interface.
|
// NullUInt64 struct
|
||||||
func (n NullInt8) Value() (driver.Value, error) {
|
type NullUInt64 struct {
|
||||||
if !n.Valid {
|
UInt64 uint64
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return n.Int8, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//===============================================================//
|
|
||||||
|
|
||||||
// NullInt16 struct
|
|
||||||
type NullInt16 struct {
|
|
||||||
Int16 int16
|
|
||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the Scanner interface.
|
// Scan implements the Scanner interface.
|
||||||
func (n *NullInt16) Scan(value interface{}) error {
|
func (n *NullUInt64) Scan(value interface{}) error {
|
||||||
|
var stringValue string
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case nil:
|
case nil:
|
||||||
n.Valid = false
|
n.Valid = false
|
||||||
return nil
|
return nil
|
||||||
case int64:
|
case int64:
|
||||||
n.Int16, n.Valid = int16(v), true
|
n.UInt64, n.Valid = uint64(v), true
|
||||||
return nil
|
return nil
|
||||||
case int16:
|
case uint64:
|
||||||
n.Int16, n.Valid = v, true
|
n.UInt64, n.Valid = v, true
|
||||||
return nil
|
|
||||||
case int8:
|
|
||||||
n.Int16, n.Valid = int16(v), true
|
|
||||||
return nil
|
|
||||||
case uint8:
|
|
||||||
n.Int16, n.Valid = int16(v), true
|
|
||||||
return nil
|
|
||||||
case []byte:
|
|
||||||
intV, err := strconv.ParseInt(string(v), 10, 16)
|
|
||||||
if err == nil {
|
|
||||||
n.Int16, n.Valid = int16(intV), true
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("can't scan int16 from %v", value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Value implements the driver Valuer interface.
|
|
||||||
func (n NullInt16) Value() (driver.Value, error) {
|
|
||||||
if !n.Valid {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return n.Int16, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//===============================================================//
|
|
||||||
|
|
||||||
// NullInt32 struct
|
|
||||||
type NullInt32 struct {
|
|
||||||
Int32 int32
|
|
||||||
Valid bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan implements the Scanner interface.
|
|
||||||
func (n *NullInt32) Scan(value interface{}) error {
|
|
||||||
switch v := value.(type) {
|
|
||||||
case nil:
|
|
||||||
n.Valid = false
|
|
||||||
return nil
|
|
||||||
case int64:
|
|
||||||
n.Int32, n.Valid = int32(v), true
|
|
||||||
return nil
|
return nil
|
||||||
case int32:
|
case int32:
|
||||||
n.Int32, n.Valid = v, true
|
n.UInt64, n.Valid = uint64(v), true
|
||||||
|
return nil
|
||||||
|
case uint32:
|
||||||
|
n.UInt64, n.Valid = uint64(v), true
|
||||||
return nil
|
return nil
|
||||||
case int16:
|
case int16:
|
||||||
n.Int32, n.Valid = int32(v), true
|
n.UInt64, n.Valid = uint64(v), true
|
||||||
return nil
|
return nil
|
||||||
case uint16:
|
case uint16:
|
||||||
n.Int32, n.Valid = int32(v), true
|
n.UInt64, n.Valid = uint64(v), true
|
||||||
return nil
|
return nil
|
||||||
case int8:
|
case int8:
|
||||||
n.Int32, n.Valid = int32(v), true
|
n.UInt64, n.Valid = uint64(v), true
|
||||||
return nil
|
return nil
|
||||||
case uint8:
|
case uint8:
|
||||||
n.Int32, n.Valid = int32(v), true
|
n.UInt64, n.Valid = uint64(v), true
|
||||||
|
return nil
|
||||||
|
case int:
|
||||||
|
n.UInt64, n.Valid = uint64(v), true
|
||||||
|
return nil
|
||||||
|
case uint:
|
||||||
|
n.UInt64, n.Valid = uint64(v), true
|
||||||
return nil
|
return nil
|
||||||
case []byte:
|
case []byte:
|
||||||
intV, err := strconv.ParseInt(string(v), 10, 32)
|
stringValue = string(v)
|
||||||
if err == nil {
|
case string:
|
||||||
n.Int32, n.Valid = int32(intV), true
|
stringValue = v
|
||||||
}
|
|
||||||
return nil
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("can't scan int32 from %v", value)
|
return fmt.Errorf("can't scan uint64 from %v", value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uintV, err := strconv.ParseUint(stringValue, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
n.UInt64 = uintV
|
||||||
|
n.Valid = true
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the driver Valuer interface.
|
// Value implements the driver Valuer interface.
|
||||||
func (n NullInt32) Value() (driver.Value, error) {
|
func (n NullUInt64) Value() (driver.Value, error) {
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return n.Int32, nil
|
return n.UInt64, nil
|
||||||
}
|
|
||||||
|
|
||||||
//===============================================================//
|
|
||||||
|
|
||||||
// NullFloat32 struct
|
|
||||||
type NullFloat32 struct {
|
|
||||||
Float32 float32
|
|
||||||
Valid bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan implements the Scanner interface.
|
|
||||||
func (n *NullFloat32) Scan(value interface{}) error {
|
|
||||||
switch v := value.(type) {
|
|
||||||
case nil:
|
|
||||||
n.Valid = false
|
|
||||||
return nil
|
|
||||||
case float64:
|
|
||||||
n.Float32, n.Valid = float32(v), true
|
|
||||||
return nil
|
|
||||||
case float32:
|
|
||||||
n.Float32, n.Valid = v, true
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("can't scan float32 from %v", value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Value implements the driver Valuer interface.
|
|
||||||
func (n NullFloat32) Value() (driver.Value, error) {
|
|
||||||
if !n.Valid {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return n.Float32, nil
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,141 +7,85 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNullByteArray(t *testing.T) {
|
func TestNullBool(t *testing.T) {
|
||||||
var array NullByteArray
|
var nullBool NullBool
|
||||||
|
|
||||||
require.NoError(t, array.Scan(nil))
|
require.NoError(t, nullBool.Scan(nil))
|
||||||
require.Equal(t, array.Valid, false)
|
require.Equal(t, nullBool.Valid, false)
|
||||||
|
|
||||||
require.NoError(t, array.Scan([]byte("bytea")))
|
require.NoError(t, nullBool.Scan(int64(1)))
|
||||||
require.Equal(t, array.Valid, true)
|
require.Equal(t, nullBool.Valid, true)
|
||||||
require.Equal(t, string(array.ByteArray), string([]byte("bytea")))
|
value, _ := nullBool.Value()
|
||||||
|
require.Equal(t, value, true)
|
||||||
|
|
||||||
require.Error(t, array.Scan(12), "can't scan []byte from 12")
|
require.NoError(t, nullBool.Scan(uint32(0)))
|
||||||
|
require.Equal(t, nullBool.Valid, true)
|
||||||
|
value, _ = nullBool.Value()
|
||||||
|
require.Equal(t, value, false)
|
||||||
|
|
||||||
|
require.EqualError(t, nullBool.Scan(uint16(22)), "can't assign uint16(22) to bool")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNullTime(t *testing.T) {
|
func TestNullTime(t *testing.T) {
|
||||||
var array NullTime
|
var nullTime NullTime
|
||||||
|
|
||||||
require.NoError(t, array.Scan(nil))
|
require.NoError(t, nullTime.Scan(nil))
|
||||||
require.Equal(t, array.Valid, false)
|
require.Equal(t, nullTime.Valid, false)
|
||||||
|
|
||||||
time := time.Now()
|
time := time.Now()
|
||||||
require.NoError(t, array.Scan(time))
|
require.NoError(t, nullTime.Scan(time))
|
||||||
require.Equal(t, array.Valid, true)
|
require.Equal(t, nullTime.Valid, true)
|
||||||
value, _ := array.Value()
|
value, _ := nullTime.Value()
|
||||||
require.Equal(t, value, time)
|
require.Equal(t, value, time)
|
||||||
|
|
||||||
require.NoError(t, array.Scan([]byte("13:10:11")))
|
require.NoError(t, nullTime.Scan([]byte("13:10:11")))
|
||||||
require.Equal(t, array.Valid, true)
|
require.Equal(t, nullTime.Valid, true)
|
||||||
value, _ = array.Value()
|
value, _ = nullTime.Value()
|
||||||
require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
|
require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
|
||||||
|
|
||||||
require.NoError(t, array.Scan("13:10:11"))
|
require.NoError(t, nullTime.Scan("13:10:11"))
|
||||||
require.Equal(t, array.Valid, true)
|
require.Equal(t, nullTime.Valid, true)
|
||||||
value, _ = array.Value()
|
value, _ = nullTime.Value()
|
||||||
require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
|
require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC")
|
||||||
|
|
||||||
require.Error(t, array.Scan(12), "can't scan time.Time from 12")
|
require.Error(t, nullTime.Scan(12), "can't scan time.Time from 12")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNullInt8(t *testing.T) {
|
func TestNullUInt64(t *testing.T) {
|
||||||
var array NullInt8
|
var nullUInt64 NullUInt64
|
||||||
|
|
||||||
require.NoError(t, array.Scan(nil))
|
require.NoError(t, nullUInt64.Scan(nil))
|
||||||
require.Equal(t, array.Valid, false)
|
require.Equal(t, nullUInt64.Valid, false)
|
||||||
|
|
||||||
require.NoError(t, array.Scan(int64(11)))
|
require.NoError(t, nullUInt64.Scan(int64(11)))
|
||||||
require.Equal(t, array.Valid, true)
|
require.Equal(t, nullUInt64.Valid, true)
|
||||||
value, _ := array.Value()
|
value, _ := nullUInt64.Value()
|
||||||
require.Equal(t, value, int8(11))
|
require.Equal(t, value, uint64(11))
|
||||||
|
|
||||||
require.Error(t, array.Scan("text"), "can't scan int8 from text")
|
require.NoError(t, nullUInt64.Scan(int32(32)))
|
||||||
}
|
require.Equal(t, nullUInt64.Valid, true)
|
||||||
|
value, _ = nullUInt64.Value()
|
||||||
func TestNullInt16(t *testing.T) {
|
require.Equal(t, value, uint64(32))
|
||||||
var array NullInt16
|
|
||||||
|
require.NoError(t, nullUInt64.Scan(int16(20)))
|
||||||
require.NoError(t, array.Scan(nil))
|
require.Equal(t, nullUInt64.Valid, true)
|
||||||
require.Equal(t, array.Valid, false)
|
value, _ = nullUInt64.Value()
|
||||||
|
require.Equal(t, value, uint64(20))
|
||||||
require.NoError(t, array.Scan(int64(11)))
|
|
||||||
require.Equal(t, array.Valid, true)
|
require.NoError(t, nullUInt64.Scan(uint16(16)))
|
||||||
value, _ := array.Value()
|
require.Equal(t, nullUInt64.Valid, true)
|
||||||
require.Equal(t, value, int16(11))
|
value, _ = nullUInt64.Value()
|
||||||
|
require.Equal(t, value, uint64(16))
|
||||||
require.NoError(t, array.Scan(int16(20)))
|
|
||||||
require.Equal(t, array.Valid, true)
|
require.NoError(t, nullUInt64.Scan(int8(30)))
|
||||||
value, _ = array.Value()
|
require.Equal(t, nullUInt64.Valid, true)
|
||||||
require.Equal(t, value, int16(20))
|
value, _ = nullUInt64.Value()
|
||||||
|
require.Equal(t, value, uint64(30))
|
||||||
require.NoError(t, array.Scan(int8(30)))
|
|
||||||
require.Equal(t, array.Valid, true)
|
require.NoError(t, nullUInt64.Scan(uint8(30)))
|
||||||
value, _ = array.Value()
|
require.Equal(t, nullUInt64.Valid, true)
|
||||||
require.Equal(t, value, int16(30))
|
value, _ = nullUInt64.Value()
|
||||||
|
require.Equal(t, value, uint64(30))
|
||||||
require.NoError(t, array.Scan(uint8(30)))
|
|
||||||
require.Equal(t, array.Valid, true)
|
require.Error(t, nullUInt64.Scan("text"), "can't scan int32 from text")
|
||||||
value, _ = array.Value()
|
|
||||||
require.Equal(t, value, int16(30))
|
|
||||||
|
|
||||||
require.Error(t, array.Scan("text"), "can't scan int16 from text")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNullInt32(t *testing.T) {
|
|
||||||
var array NullInt32
|
|
||||||
|
|
||||||
require.NoError(t, array.Scan(nil))
|
|
||||||
require.Equal(t, array.Valid, false)
|
|
||||||
|
|
||||||
require.NoError(t, array.Scan(int64(11)))
|
|
||||||
require.Equal(t, array.Valid, true)
|
|
||||||
value, _ := array.Value()
|
|
||||||
require.Equal(t, value, int32(11))
|
|
||||||
|
|
||||||
require.NoError(t, array.Scan(int32(32)))
|
|
||||||
require.Equal(t, array.Valid, true)
|
|
||||||
value, _ = array.Value()
|
|
||||||
require.Equal(t, value, int32(32))
|
|
||||||
|
|
||||||
require.NoError(t, array.Scan(int16(20)))
|
|
||||||
require.Equal(t, array.Valid, true)
|
|
||||||
value, _ = array.Value()
|
|
||||||
require.Equal(t, value, int32(20))
|
|
||||||
|
|
||||||
require.NoError(t, array.Scan(uint16(16)))
|
|
||||||
require.Equal(t, array.Valid, true)
|
|
||||||
value, _ = array.Value()
|
|
||||||
require.Equal(t, value, int32(16))
|
|
||||||
|
|
||||||
require.NoError(t, array.Scan(int8(30)))
|
|
||||||
require.Equal(t, array.Valid, true)
|
|
||||||
value, _ = array.Value()
|
|
||||||
require.Equal(t, value, int32(30))
|
|
||||||
|
|
||||||
require.NoError(t, array.Scan(uint8(30)))
|
|
||||||
require.Equal(t, array.Valid, true)
|
|
||||||
value, _ = array.Value()
|
|
||||||
require.Equal(t, value, int32(30))
|
|
||||||
|
|
||||||
require.Error(t, array.Scan("text"), "can't scan int32 from text")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNullFloat32(t *testing.T) {
|
|
||||||
var array NullFloat32
|
|
||||||
|
|
||||||
require.NoError(t, array.Scan(nil))
|
|
||||||
require.Equal(t, array.Valid, false)
|
|
||||||
|
|
||||||
require.NoError(t, array.Scan(float64(64)))
|
|
||||||
require.Equal(t, array.Valid, true)
|
|
||||||
value, _ := array.Value()
|
|
||||||
require.Equal(t, value, float32(64))
|
|
||||||
|
|
||||||
require.NoError(t, array.Scan(float32(32)))
|
|
||||||
require.Equal(t, array.Valid, true)
|
|
||||||
value, _ = array.Value()
|
|
||||||
require.Equal(t, value, float32(32))
|
|
||||||
|
|
||||||
require.Error(t, array.Scan(12), "can't scan float32 from 12")
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
17
qrm/qrm.go
17
qrm/qrm.go
|
|
@ -27,7 +27,10 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
|
||||||
|
|
||||||
if destinationPtrType.Elem().Kind() == reflect.Slice {
|
if destinationPtrType.Elem().Kind() == reflect.Slice {
|
||||||
_, err := queryToSlice(ctx, db, query, args, destPtr)
|
_, err := queryToSlice(ctx, db, query, args, destPtr)
|
||||||
return err
|
if err != nil {
|
||||||
|
return fmt.Errorf("jet: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
} else if destinationPtrType.Elem().Kind() == reflect.Struct {
|
} else if destinationPtrType.Elem().Kind() == reflect.Struct {
|
||||||
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
|
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
|
||||||
tempSliceValue := tempSlicePtrValue.Elem()
|
tempSliceValue := tempSlicePtrValue.Elem()
|
||||||
|
|
@ -35,7 +38,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
|
||||||
rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface())
|
rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("jet: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if rowsProcessed == 0 {
|
if rowsProcessed == 0 {
|
||||||
|
|
@ -275,10 +278,16 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
|
||||||
err = scanner.Scan(cellValue)
|
err = scanner.Scan(cellValue)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic("jet: " + err.Error() + ", " + fieldToString(&field) + " of type " + structType.String())
|
err = fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
setReflectValue(reflect.ValueOf(cellValue), fieldValue)
|
err = setReflectValue(reflect.ValueOf(cellValue), fieldValue)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,7 @@ package qrm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/go-jet/jet/v2/internal/utils"
|
|
||||||
"github.com/go-jet/jet/v2/internal/utils/throw"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
@ -46,7 +43,7 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return &scanContext{
|
return &scanContext{
|
||||||
row: createScanValue(columnTypes),
|
row: createScanSlice(len(columnTypes)),
|
||||||
uniqueDestObjectsMap: make(map[string]int),
|
uniqueDestObjectsMap: make(map[string]int),
|
||||||
|
|
||||||
groupKeyInfoCache: make(map[string]groupKeyInfo),
|
groupKeyInfoCache: make(map[string]groupKeyInfo),
|
||||||
|
|
@ -56,6 +53,17 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createScanSlice(columnCount int) []interface{} {
|
||||||
|
scanSlice := make([]interface{}, columnCount)
|
||||||
|
scanPtrSlice := make([]interface{}, columnCount)
|
||||||
|
|
||||||
|
for i := range scanPtrSlice {
|
||||||
|
scanPtrSlice[i] = &scanSlice[i] // if destination is pointer to interface sql.Scan will just forward driver value
|
||||||
|
}
|
||||||
|
|
||||||
|
return scanPtrSlice
|
||||||
|
}
|
||||||
|
|
||||||
type typeInfo struct {
|
type typeInfo struct {
|
||||||
fieldMappings []fieldMapping
|
fieldMappings []fieldMapping
|
||||||
}
|
}
|
||||||
|
|
@ -210,16 +218,13 @@ func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *scanContext) rowElem(index int) interface{} {
|
func (s *scanContext) rowElem(index int) interface{} {
|
||||||
|
cellValue := reflect.ValueOf(s.row[index])
|
||||||
|
|
||||||
valuer, ok := s.row[index].(driver.Valuer)
|
if cellValue.IsValid() && !cellValue.IsNil() {
|
||||||
|
return cellValue.Elem().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
utils.MustBeTrue(ok, "jet: internal error, scan value doesn't implement driver.Valuer")
|
return nil
|
||||||
|
|
||||||
value, err := valuer.Value()
|
|
||||||
|
|
||||||
throw.OnError(err)
|
|
||||||
|
|
||||||
return value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *scanContext) rowElemValuePtr(index int) reflect.Value {
|
func (s *scanContext) rowElemValuePtr(index int) reflect.Value {
|
||||||
|
|
|
||||||
275
qrm/utill.go
275
qrm/utill.go
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"github.com/go-jet/jet/v2/qrm/internal"
|
"github.com/go-jet/jet/v2/qrm/internal"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
@ -56,21 +55,22 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
|
||||||
sliceValue := slicePtrValue.Elem()
|
sliceValue := slicePtrValue.Elem()
|
||||||
sliceElemType := sliceValue.Type().Elem()
|
sliceElemType := sliceValue.Type().Elem()
|
||||||
|
|
||||||
newElemValue := objPtrValue
|
newSliceElemValue := reflect.New(sliceElemType).Elem()
|
||||||
|
|
||||||
if sliceElemType.Kind() != reflect.Ptr {
|
var err error
|
||||||
newElemValue = objPtrValue.Elem()
|
|
||||||
|
if newSliceElemValue.Kind() == reflect.Ptr {
|
||||||
|
newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem()))
|
||||||
|
err = tryAssign(objPtrValue.Elem(), newSliceElemValue.Elem())
|
||||||
|
} else {
|
||||||
|
err = tryAssign(objPtrValue.Elem(), newSliceElemValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
if newElemValue.Type().ConvertibleTo(sliceElemType) {
|
if err != nil {
|
||||||
newElemValue = newElemValue.Convert(sliceElemType)
|
return fmt.Errorf("can't append %T to %T slice: %w", objPtrValue.Elem().Interface(), sliceValue.Interface(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !newElemValue.Type().AssignableTo(sliceElemType) {
|
sliceValue.Set(reflect.Append(sliceValue, newSliceElemValue))
|
||||||
panic("jet: can't append " + newElemValue.Type().String() + " to " + sliceValue.Type().String() + " slice")
|
|
||||||
}
|
|
||||||
|
|
||||||
sliceValue.Set(reflect.Append(sliceValue, newElemValue))
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -121,7 +121,6 @@ func toCommonIdentifier(name string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func initializeValueIfNilPtr(value reflect.Value) {
|
func initializeValueIfNilPtr(value reflect.Value) {
|
||||||
|
|
||||||
if !value.IsValid() || !value.CanSet() {
|
if !value.IsValid() || !value.CanSet() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -173,172 +172,147 @@ func isSimpleModelType(objType reflect.Type) bool {
|
||||||
return objType == timeType || objType == uuidType || objType == byteArrayType
|
return objType == timeType || objType == uuidType || objType == byteArrayType
|
||||||
}
|
}
|
||||||
|
|
||||||
func isIntegerType(value reflect.Type) bool {
|
func isFloatType(value reflect.Type) bool {
|
||||||
switch value {
|
switch value.Kind() {
|
||||||
case int8Type, unit8Type, int16Type, uint16Type,
|
case reflect.Float32, reflect.Float64:
|
||||||
int32Type, uint32Type, int64Type, uint64Type:
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func isNumber(valueType reflect.Type) bool {
|
func tryAssign(source, destination reflect.Value) error {
|
||||||
return isIntegerType(valueType) || valueType == float64Type || valueType == float32Type
|
|
||||||
}
|
|
||||||
|
|
||||||
func tryAssign(source, destination reflect.Value) bool {
|
if source.Type() != destination.Type() &&
|
||||||
|
!isFloatType(destination.Type()) && // to preserve precision during conversion
|
||||||
|
source.Type().ConvertibleTo(destination.Type()) {
|
||||||
|
|
||||||
switch {
|
|
||||||
case source.Type().ConvertibleTo(destination.Type()):
|
|
||||||
source = source.Convert(destination.Type())
|
source = source.Convert(destination.Type())
|
||||||
case isIntegerType(source.Type()) && destination.Type() == boolType:
|
|
||||||
intValue := source.Int()
|
|
||||||
|
|
||||||
if intValue == 1 {
|
|
||||||
source = reflect.ValueOf(true)
|
|
||||||
} else if intValue == 0 {
|
|
||||||
source = reflect.ValueOf(false)
|
|
||||||
}
|
|
||||||
case source.Type() == stringType && isNumber(destination.Type()):
|
|
||||||
// if source is string and destination is a number(int8, int32, float32, ...), we first parse string to float64 number
|
|
||||||
// and then parsed number is converted into destination type
|
|
||||||
f, err := strconv.ParseFloat(source.String(), 64)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
source = reflect.ValueOf(f)
|
|
||||||
|
|
||||||
if source.Type().ConvertibleTo(destination.Type()) {
|
|
||||||
source = source.Convert(destination.Type())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if source.Type().AssignableTo(destination.Type()) {
|
if source.Type().AssignableTo(destination.Type()) {
|
||||||
|
switch b := source.Interface().(type) {
|
||||||
|
case []byte:
|
||||||
|
destination.SetBytes(cloneBytes(b))
|
||||||
|
default:
|
||||||
destination.Set(source)
|
destination.Set(source)
|
||||||
return true
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
sourceInterface := source.Interface()
|
||||||
|
|
||||||
|
switch destination.Interface().(type) {
|
||||||
|
case bool:
|
||||||
|
var nullBool internal.NullBool
|
||||||
|
|
||||||
|
err := nullBool.Scan(sourceInterface)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
destination.SetBool(nullBool.Bool)
|
||||||
|
|
||||||
|
case float32, float64:
|
||||||
|
var nullFloat sql.NullFloat64
|
||||||
|
|
||||||
|
err := nullFloat.Scan(sourceInterface)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullFloat.Valid {
|
||||||
|
destination.SetFloat(nullFloat.Float64)
|
||||||
|
}
|
||||||
|
case int, int8, int16, int32, int64:
|
||||||
|
var integer sql.NullInt64
|
||||||
|
|
||||||
|
err := integer.Scan(sourceInterface)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if integer.Valid {
|
||||||
|
destination.SetInt(integer.Int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
case uint, uint8, uint16, uint32, uint64:
|
||||||
|
var uInt internal.NullUInt64
|
||||||
|
|
||||||
|
err := uInt.Scan(sourceInterface)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if uInt.Valid {
|
||||||
|
destination.SetUint(uInt.UInt64)
|
||||||
|
}
|
||||||
|
|
||||||
|
case string:
|
||||||
|
var str sql.NullString
|
||||||
|
|
||||||
|
err := str.Scan(sourceInterface)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if str.Valid {
|
||||||
|
destination.SetString(str.String)
|
||||||
|
}
|
||||||
|
|
||||||
|
case time.Time:
|
||||||
|
var nullTime internal.NullTime
|
||||||
|
|
||||||
|
err := nullTime.Scan(sourceInterface)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if nullTime.Valid {
|
||||||
|
destination.Set(reflect.ValueOf(nullTime.Time))
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setReflectValue(source, destination reflect.Value) {
|
func setReflectValue(source, destination reflect.Value) error {
|
||||||
|
|
||||||
if tryAssign(source, destination) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if destination.Kind() == reflect.Ptr {
|
if destination.Kind() == reflect.Ptr {
|
||||||
if source.Kind() == reflect.Ptr {
|
|
||||||
if !source.IsNil() {
|
|
||||||
if destination.IsNil() {
|
if destination.IsNil() {
|
||||||
initializeValueIfNilPtr(destination)
|
initializeValueIfNilPtr(destination)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tryAssign(source.Elem(), destination.Elem()) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if source.CanAddr() {
|
|
||||||
source = source.Addr()
|
|
||||||
} else {
|
|
||||||
sourceCopy := reflect.New(source.Type())
|
|
||||||
sourceCopy.Elem().Set(source)
|
|
||||||
|
|
||||||
source = sourceCopy
|
|
||||||
}
|
|
||||||
|
|
||||||
if tryAssign(source, destination) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if tryAssign(source.Elem(), destination.Elem()) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if source.Kind() == reflect.Ptr {
|
if source.Kind() == reflect.Ptr {
|
||||||
if source.IsNil() {
|
if source.IsNil() {
|
||||||
return
|
return nil // source is nil, destination should keep its zero value
|
||||||
}
|
}
|
||||||
source = source.Elem()
|
source = source.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if tryAssign(source, destination) {
|
if err := tryAssign(source, destination.Elem()); err != nil {
|
||||||
return
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
if source.Kind() == reflect.Ptr {
|
||||||
|
if source.IsNil() {
|
||||||
|
return nil // source is nil, destination should keep its zero value
|
||||||
|
}
|
||||||
|
source = source.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tryAssign(source, destination); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String())
|
return nil
|
||||||
}
|
|
||||||
|
|
||||||
func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
|
|
||||||
values := make([]interface{}, len(columnTypes))
|
|
||||||
|
|
||||||
for i, sqlColumnType := range columnTypes {
|
|
||||||
columnType := newScanType(sqlColumnType)
|
|
||||||
|
|
||||||
columnValue := reflect.New(columnType)
|
|
||||||
|
|
||||||
values[i] = columnValue.Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
return values
|
|
||||||
}
|
|
||||||
|
|
||||||
var boolType = reflect.TypeOf(true)
|
|
||||||
var int8Type = reflect.TypeOf(int8(1))
|
|
||||||
var unit8Type = reflect.TypeOf(uint8(1))
|
|
||||||
var int16Type = reflect.TypeOf(int16(1))
|
|
||||||
var uint16Type = reflect.TypeOf(uint16(1))
|
|
||||||
var int32Type = reflect.TypeOf(int32(1))
|
|
||||||
var uint32Type = reflect.TypeOf(uint32(1))
|
|
||||||
var int64Type = reflect.TypeOf(int64(1))
|
|
||||||
var uint64Type = reflect.TypeOf(uint64(1))
|
|
||||||
var float32Type = reflect.TypeOf(float32(1))
|
|
||||||
var float64Type = reflect.TypeOf(float64(1))
|
|
||||||
var stringType = reflect.TypeOf("")
|
|
||||||
|
|
||||||
var nullBoolType = reflect.TypeOf(sql.NullBool{})
|
|
||||||
var nullInt8Type = reflect.TypeOf(internal.NullInt8{})
|
|
||||||
var nullInt16Type = reflect.TypeOf(internal.NullInt16{})
|
|
||||||
var nullInt32Type = reflect.TypeOf(internal.NullInt32{})
|
|
||||||
var nullInt64Type = reflect.TypeOf(sql.NullInt64{})
|
|
||||||
var nullFloat32Type = reflect.TypeOf(internal.NullFloat32{})
|
|
||||||
var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{})
|
|
||||||
var nullStringType = reflect.TypeOf(sql.NullString{})
|
|
||||||
var nullTimeType = reflect.TypeOf(internal.NullTime{})
|
|
||||||
var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{})
|
|
||||||
|
|
||||||
func newScanType(columnType *sql.ColumnType) reflect.Type {
|
|
||||||
|
|
||||||
switch columnType.DatabaseTypeName() {
|
|
||||||
case "TINYINT":
|
|
||||||
return nullInt8Type
|
|
||||||
case "INT2", "SMALLINT", "YEAR":
|
|
||||||
return nullInt16Type
|
|
||||||
case "INT4", "MEDIUMINT", "INT":
|
|
||||||
return nullInt32Type
|
|
||||||
case "INT8", "BIGINT":
|
|
||||||
return nullInt64Type
|
|
||||||
case "CHAR", "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML":
|
|
||||||
return nullStringType
|
|
||||||
case "FLOAT4":
|
|
||||||
return nullFloat32Type
|
|
||||||
case "FLOAT8", "FLOAT", "DOUBLE":
|
|
||||||
return nullFloat64Type
|
|
||||||
case "BOOL":
|
|
||||||
return nullBoolType
|
|
||||||
case "BYTEA", "BINARY", "VARBINARY", "BLOB":
|
|
||||||
return nullByteArrayType
|
|
||||||
case "DATE", "DATETIME", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ":
|
|
||||||
return nullTimeType
|
|
||||||
default:
|
|
||||||
return nullStringType
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
|
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
|
||||||
|
|
@ -385,3 +359,12 @@ func fieldToString(field *reflect.StructField) string {
|
||||||
|
|
||||||
return " at '" + field.Name + " " + field.Type.String() + "'"
|
return " at '" + field.Name + " " + field.Type.String() + "'"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cloneBytes(b []byte) []byte {
|
||||||
|
if b == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c := make([]byte, len(b))
|
||||||
|
copy(c, b)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -17,11 +17,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAllTypesSelect(t *testing.T) {
|
func TestAllTypesSelect(t *testing.T) {
|
||||||
skipForPgxDriver(t) // pgx driver returns time with time zone as string
|
|
||||||
|
|
||||||
dest := []model.AllTypes{}
|
dest := []model.AllTypes{}
|
||||||
|
|
||||||
err := AllTypes.SELECT(AllTypes.AllColumns).Query(db, &dest)
|
err := AllTypes.SELECT(
|
||||||
|
AllTypes.AllColumns,
|
||||||
|
).LIMIT(2).
|
||||||
|
Query(db, &dest)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testutils.AssertDeepEqual(t, dest[0], allTypesRow0)
|
testutils.AssertDeepEqual(t, dest[0], allTypesRow0)
|
||||||
|
|
@ -29,8 +30,6 @@ func TestAllTypesSelect(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAllTypesViewSelect(t *testing.T) {
|
func TestAllTypesViewSelect(t *testing.T) {
|
||||||
skipForPgxDriver(t) // pgx driver returns time with time zone as string
|
|
||||||
|
|
||||||
type AllTypesView model.AllTypes
|
type AllTypesView model.AllTypes
|
||||||
|
|
||||||
dest := []AllTypesView{}
|
dest := []AllTypesView{}
|
||||||
|
|
@ -43,7 +42,7 @@ func TestAllTypesViewSelect(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAllTypesInsertModel(t *testing.T) {
|
func TestAllTypesInsertModel(t *testing.T) {
|
||||||
skipForPgxDriver(t) // pgx driver does not handle well time with time zone
|
skipForPgxDriver(t) // pgx driver bug ERROR: date/time field value out of range: "0000-01-01 12:05:06Z" (SQLSTATE 22008)
|
||||||
|
|
||||||
query := AllTypes.INSERT(AllTypes.AllColumns).
|
query := AllTypes.INSERT(AllTypes.AllColumns).
|
||||||
MODEL(allTypesRow0).
|
MODEL(allTypesRow0).
|
||||||
|
|
@ -60,8 +59,6 @@ func TestAllTypesInsertModel(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAllTypesInsertQuery(t *testing.T) {
|
func TestAllTypesInsertQuery(t *testing.T) {
|
||||||
skipForPgxDriver(t) // pgx driver does not handle well time with time zone
|
|
||||||
|
|
||||||
query := AllTypes.INSERT(AllTypes.AllColumns).
|
query := AllTypes.INSERT(AllTypes.AllColumns).
|
||||||
QUERY(
|
QUERY(
|
||||||
AllTypes.
|
AllTypes.
|
||||||
|
|
@ -80,8 +77,6 @@ func TestAllTypesInsertQuery(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAllTypesFromSubQuery(t *testing.T) {
|
func TestAllTypesFromSubQuery(t *testing.T) {
|
||||||
skipForPgxDriver(t)
|
|
||||||
|
|
||||||
subQuery := SELECT(AllTypes.AllColumns).
|
subQuery := SELECT(AllTypes.AllColumns).
|
||||||
FROM(AllTypes).
|
FROM(AllTypes).
|
||||||
AsTable("allTypesSubQuery")
|
AsTable("allTypesSubQuery")
|
||||||
|
|
@ -302,10 +297,10 @@ LIMIT $11;
|
||||||
|
|
||||||
func TestExpressionCast(t *testing.T) {
|
func TestExpressionCast(t *testing.T) {
|
||||||
|
|
||||||
skipForPgxDriver(t) // for some reason, pgx driver, 150:char(12) returns as int value
|
skipForPgxDriver(t) // pgx driver bug 'cannot convert 151 to Text'
|
||||||
|
|
||||||
query := AllTypes.SELECT(
|
query := AllTypes.SELECT(
|
||||||
CAST(Int(150)).AS_CHAR(12).AS("char12"),
|
CAST(Int(151)).AS_CHAR(12).AS("char12"),
|
||||||
CAST(String("TRUE")).AS_BOOL(),
|
CAST(String("TRUE")).AS_BOOL(),
|
||||||
CAST(String("111")).AS_SMALLINT(),
|
CAST(String("111")).AS_SMALLINT(),
|
||||||
CAST(String("111")).AS_INTEGER(),
|
CAST(String("111")).AS_INTEGER(),
|
||||||
|
|
@ -349,7 +344,7 @@ func TestExpressionCast(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStringOperators(t *testing.T) {
|
func TestStringOperators(t *testing.T) {
|
||||||
skipForPgxDriver(t) // pgx driver returns text column as int value
|
skipForPgxDriver(t) // pgx driver bug 'cannot convert 11 to Text'
|
||||||
|
|
||||||
query := AllTypes.SELECT(
|
query := AllTypes.SELECT(
|
||||||
AllTypes.Text.EQ(AllTypes.Char),
|
AllTypes.Text.EQ(AllTypes.Char),
|
||||||
|
|
@ -866,8 +861,6 @@ func TestInterval(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSubQueryColumnReference(t *testing.T) {
|
func TestSubQueryColumnReference(t *testing.T) {
|
||||||
skipForPgxDriver(t) // pgx driver returns time with time zone as string value
|
|
||||||
|
|
||||||
type expected struct {
|
type expected struct {
|
||||||
sql string
|
sql string
|
||||||
args []interface{}
|
args []interface{}
|
||||||
|
|
@ -1044,8 +1037,6 @@ FROM`
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTimeLiterals(t *testing.T) {
|
func TestTimeLiterals(t *testing.T) {
|
||||||
skipForPgxDriver(t) // pgx driver returns time with time zone as string
|
|
||||||
|
|
||||||
loc, err := time.LoadLocation("Europe/Berlin")
|
loc, err := time.LoadLocation("Europe/Berlin")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
@ -1060,8 +1051,6 @@ func TestTimeLiterals(t *testing.T) {
|
||||||
).FROM(AllTypes).
|
).FROM(AllTypes).
|
||||||
LIMIT(1)
|
LIMIT(1)
|
||||||
|
|
||||||
//fmt.Println(query.Sql())
|
|
||||||
|
|
||||||
testutils.AssertStatementSql(t, query, `
|
testutils.AssertStatementSql(t, query, `
|
||||||
SELECT $1::date AS "date",
|
SELECT $1::date AS "date",
|
||||||
$2::time without time zone AS "time",
|
$2::time without time zone AS "time",
|
||||||
|
|
@ -1077,21 +1066,25 @@ LIMIT $6;
|
||||||
Time time.Time
|
Time time.Time
|
||||||
Timez time.Time
|
Timez time.Time
|
||||||
Timestamp time.Time
|
Timestamp time.Time
|
||||||
//Timestampz time.Time
|
Timestampz time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
err = query.Query(db, &dest)
|
err = query.Query(db, &dest)
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
//testutils.PrintJson(dest)
|
// pq driver will return time with time zone in local timezone,
|
||||||
|
// while pgx driver will return time in UTC time zone
|
||||||
|
dest.Timez = dest.Timez.UTC()
|
||||||
|
dest.Timestampz = dest.Timestampz.UTC()
|
||||||
|
|
||||||
testutils.AssertJSON(t, dest, `
|
testutils.AssertJSON(t, dest, `
|
||||||
{
|
{
|
||||||
"Date": "2009-11-17T00:00:00Z",
|
"Date": "2009-11-17T00:00:00Z",
|
||||||
"Time": "0000-01-01T20:34:58.651387Z",
|
"Time": "0000-01-01T20:34:58.651387Z",
|
||||||
"Timez": "0000-01-01T20:34:58.651387+01:00",
|
"Timez": "0000-01-01T19:34:58.651387Z",
|
||||||
"Timestamp": "2009-11-17T20:34:58.651387Z"
|
"Timestamp": "2009-11-17T20:34:58.651387Z",
|
||||||
|
"Timestampz": "2009-11-17T19:34:58.651387Z"
|
||||||
}
|
}
|
||||||
`)
|
`)
|
||||||
requireLogged(t, query)
|
requireLogged(t, query)
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ func TestMain(m *testing.M) {
|
||||||
|
|
||||||
setTestRoot()
|
setTestRoot()
|
||||||
|
|
||||||
for _, driverName := range []string{"postgres", "pgx"} {
|
for _, driverName := range []string{"pgx", "postgres"} {
|
||||||
fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName)
|
fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName)
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
|
|
@ -81,8 +81,16 @@ func requireLogged(t *testing.T, statement postgres.Statement) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func skipForPgxDriver(t *testing.T) {
|
func skipForPgxDriver(t *testing.T) {
|
||||||
switch db.Driver().(type) {
|
if isPgxDriver() {
|
||||||
case *stdlib.Driver:
|
|
||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isPgxDriver() bool {
|
||||||
|
switch db.Driver().(type) {
|
||||||
|
case *stdlib.Driver:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -78,16 +78,31 @@ func TestScanToValidDestination(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("pointer to slice of strings", func(t *testing.T) {
|
t.Run("pointer to slice of integers", func(t *testing.T) {
|
||||||
err := oneInventoryQuery.Query(db, &[]int32{})
|
var dest []int32
|
||||||
|
|
||||||
|
err := oneInventoryQuery.Query(db, &dest)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, dest[0], int32(1))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("pointer to slice of strings", func(t *testing.T) {
|
t.Run("pointer to slice integer pointers", func(t *testing.T) {
|
||||||
err := oneInventoryQuery.Query(db, &[]*int32{})
|
var dest []*int32
|
||||||
|
|
||||||
|
err := oneInventoryQuery.Query(db, &dest)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, dest[0], testutils.Int32Ptr(1))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NULL to integer", func(t *testing.T) {
|
||||||
|
var dest struct {
|
||||||
|
Int64 int64
|
||||||
|
UInt64 uint64
|
||||||
|
}
|
||||||
|
err := SELECT(NULL.AS("int64"), NULL.AS("uint64")).Query(db, &dest)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, dest.Int64, int64(0))
|
||||||
|
require.Equal(t, dest.UInt64, uint64(0))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -189,7 +204,9 @@ func TestScanToStruct(t *testing.T) {
|
||||||
|
|
||||||
dest := Inventory{}
|
dest := Inventory{}
|
||||||
|
|
||||||
testutils.AssertQueryPanicErr(t, query, db, &dest, `jet: Scan: unable to scan type int32 into UUID, at 'InventoryID uuid.UUID' of type postgres.Inventory`)
|
err := query.Query(db, &dest)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.EqualError(t, err, "jet: can't scan int64('\\x01') to 'InventoryID uuid.UUID': Scan: unable to scan type int64 into UUID")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("type mismatch base type", func(t *testing.T) {
|
t.Run("type mismatch base type", func(t *testing.T) {
|
||||||
|
|
@ -200,7 +217,9 @@ func TestScanToStruct(t *testing.T) {
|
||||||
|
|
||||||
dest := []Inventory{}
|
dest := []Inventory{}
|
||||||
|
|
||||||
testutils.AssertQueryPanicErr(t, query.OFFSET(10), db, &dest, `jet: can't set int16 to bool`)
|
err := query.OFFSET(10).Query(db, &dest)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.EqualError(t, err, "jet: can't assign int64('\\x02') to 'FilmID bool': can't assign int64(2) to bool")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -451,8 +470,9 @@ func TestScanToSlice(t *testing.T) {
|
||||||
t.Run("slice type mismatch", func(t *testing.T) {
|
t.Run("slice type mismatch", func(t *testing.T) {
|
||||||
var dest []bool
|
var dest []bool
|
||||||
|
|
||||||
testutils.AssertQueryPanicErr(t, query, db, &dest, `jet: can't append int32 to []bool slice`)
|
err := query.Query(db, &dest)
|
||||||
//require.Error(t, err, `jet: can't append int32 to []bool slice `)
|
require.Error(t, err)
|
||||||
|
require.EqualError(t, err, `jet: can't append int64 to []bool slice: can't assign int64(2) to bool`)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
@ -764,16 +784,8 @@ func TestRowsScan(t *testing.T) {
|
||||||
requireLogged(t, stmt)
|
requireLogged(t, stmt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestScanNumericToNumber(t *testing.T) {
|
func TestScanNumericToFloat(t *testing.T) {
|
||||||
type Number struct {
|
type Number struct {
|
||||||
Int8 int8
|
|
||||||
UInt8 uint8
|
|
||||||
Int16 int16
|
|
||||||
UInt16 uint16
|
|
||||||
Int32 int32
|
|
||||||
UInt32 uint32
|
|
||||||
Int64 int64
|
|
||||||
UInt64 uint64
|
|
||||||
Float32 float32
|
Float32 float32
|
||||||
Float64 float64
|
Float64 float64
|
||||||
}
|
}
|
||||||
|
|
@ -781,14 +793,6 @@ func TestScanNumericToNumber(t *testing.T) {
|
||||||
numeric := CAST(Decimal("1234567890.111")).AS_NUMERIC()
|
numeric := CAST(Decimal("1234567890.111")).AS_NUMERIC()
|
||||||
|
|
||||||
stmt := SELECT(
|
stmt := SELECT(
|
||||||
numeric.AS("number.int8"),
|
|
||||||
numeric.AS("number.uint8"),
|
|
||||||
numeric.AS("number.int16"),
|
|
||||||
numeric.AS("number.uint16"),
|
|
||||||
numeric.AS("number.int32"),
|
|
||||||
numeric.AS("number.uint32"),
|
|
||||||
numeric.AS("number.int64"),
|
|
||||||
numeric.AS("number.uint64"),
|
|
||||||
numeric.AS("number.float32"),
|
numeric.AS("number.float32"),
|
||||||
numeric.AS("number.float64"),
|
numeric.AS("number.float64"),
|
||||||
)
|
)
|
||||||
|
|
@ -796,19 +800,30 @@ func TestScanNumericToNumber(t *testing.T) {
|
||||||
var number Number
|
var number Number
|
||||||
err := stmt.Query(db, &number)
|
err := stmt.Query(db, &number)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Equal(t, number.Int8, int8(-46)) // overflow
|
|
||||||
require.Equal(t, number.UInt8, uint8(210)) // overflow
|
|
||||||
require.Equal(t, number.Int16, int16(722)) // overflow
|
|
||||||
require.Equal(t, number.UInt16, uint16(722)) // overflow
|
|
||||||
require.Equal(t, number.Int32, int32(1234567890))
|
|
||||||
require.Equal(t, number.UInt32, uint32(1234567890))
|
|
||||||
require.Equal(t, number.Int64, int64(1234567890))
|
|
||||||
require.Equal(t, number.UInt64, uint64(1234567890))
|
|
||||||
require.Equal(t, number.Float32, float32(1.234568e+09))
|
require.Equal(t, number.Float32, float32(1.234568e+09))
|
||||||
require.Equal(t, number.Float64, float64(1.234567890111e+09))
|
require.Equal(t, number.Float64, float64(1.234567890111e+09))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScanNumericToIntegerError(t *testing.T) {
|
||||||
|
|
||||||
|
var dest struct {
|
||||||
|
Integer int32
|
||||||
|
}
|
||||||
|
|
||||||
|
err := SELECT(
|
||||||
|
CAST(Decimal("1234567890.111")).AS_NUMERIC().AS("integer"),
|
||||||
|
).Query(db, &dest)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
if isPgxDriver() {
|
||||||
|
require.Contains(t, err.Error(), `jet: can't assign string("1234567890.111") to 'Integer int32': converting driver.Value type string ("1234567890.111") to a int64: invalid syntax`)
|
||||||
|
} else {
|
||||||
|
require.Contains(t, err.Error(), `jet: can't assign []uint8("1234567890.111") to 'Integer int32': converting driver.Value type []uint8 ("1234567890.111") to a int64: invalid syntax`)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// QueryContext panic when the scanned value is nil and the destination is a slice of primitive
|
// QueryContext panic when the scanned value is nil and the destination is a slice of primitive
|
||||||
// https://github.com/go-jet/jet/issues/91
|
// https://github.com/go-jet/jet/issues/91
|
||||||
func TestScanToPrimitiveElementsSlice(t *testing.T) {
|
func TestScanToPrimitiveElementsSlice(t *testing.T) {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue