Qrm refactor

- Allow custom types Scan method to read values returned by the driver rather then the value from intermediate Null types. Scan to intermidiate Null types removed.
- Better error handling
This commit is contained in:
go-jet 2021-10-15 17:43:10 +02:00
parent 555ec293fb
commit 0d418890ab
11 changed files with 459 additions and 574 deletions

View file

@ -7,7 +7,6 @@ import (
"github.com/go-jet/jet/v2/qrm/internal"
"github.com/google/uuid"
"reflect"
"strconv"
"strings"
"time"
)
@ -56,21 +55,22 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
sliceValue := slicePtrValue.Elem()
sliceElemType := sliceValue.Type().Elem()
newElemValue := objPtrValue
newSliceElemValue := reflect.New(sliceElemType).Elem()
if sliceElemType.Kind() != reflect.Ptr {
newElemValue = objPtrValue.Elem()
var err error
if newSliceElemValue.Kind() == reflect.Ptr {
newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem()))
err = tryAssign(objPtrValue.Elem(), newSliceElemValue.Elem())
} else {
err = tryAssign(objPtrValue.Elem(), newSliceElemValue)
}
if newElemValue.Type().ConvertibleTo(sliceElemType) {
newElemValue = newElemValue.Convert(sliceElemType)
if err != nil {
return fmt.Errorf("can't append %T to %T slice: %w", objPtrValue.Elem().Interface(), sliceValue.Interface(), err)
}
if !newElemValue.Type().AssignableTo(sliceElemType) {
panic("jet: can't append " + newElemValue.Type().String() + " to " + sliceValue.Type().String() + " slice")
}
sliceValue.Set(reflect.Append(sliceValue, newElemValue))
sliceValue.Set(reflect.Append(sliceValue, newSliceElemValue))
return nil
}
@ -121,7 +121,6 @@ func toCommonIdentifier(name string) string {
}
func initializeValueIfNilPtr(value reflect.Value) {
if !value.IsValid() || !value.CanSet() {
return
}
@ -173,172 +172,147 @@ func isSimpleModelType(objType reflect.Type) bool {
return objType == timeType || objType == uuidType || objType == byteArrayType
}
func isIntegerType(value reflect.Type) bool {
switch value {
case int8Type, unit8Type, int16Type, uint16Type,
int32Type, uint32Type, int64Type, uint64Type:
func isFloatType(value reflect.Type) bool {
switch value.Kind() {
case reflect.Float32, reflect.Float64:
return true
}
return false
}
func isNumber(valueType reflect.Type) bool {
return isIntegerType(valueType) || valueType == float64Type || valueType == float32Type
}
func tryAssign(source, destination reflect.Value) error {
func tryAssign(source, destination reflect.Value) bool {
if source.Type() != destination.Type() &&
!isFloatType(destination.Type()) && // to preserve precision during conversion
source.Type().ConvertibleTo(destination.Type()) {
switch {
case source.Type().ConvertibleTo(destination.Type()):
source = source.Convert(destination.Type())
case isIntegerType(source.Type()) && destination.Type() == boolType:
intValue := source.Int()
if intValue == 1 {
source = reflect.ValueOf(true)
} else if intValue == 0 {
source = reflect.ValueOf(false)
}
case source.Type() == stringType && isNumber(destination.Type()):
// if source is string and destination is a number(int8, int32, float32, ...), we first parse string to float64 number
// and then parsed number is converted into destination type
f, err := strconv.ParseFloat(source.String(), 64)
if err != nil {
return false
}
source = reflect.ValueOf(f)
if source.Type().ConvertibleTo(destination.Type()) {
source = source.Convert(destination.Type())
}
}
if source.Type().AssignableTo(destination.Type()) {
destination.Set(source)
return true
switch b := source.Interface().(type) {
case []byte:
destination.SetBytes(cloneBytes(b))
default:
destination.Set(source)
}
return nil
}
return false
sourceInterface := source.Interface()
switch destination.Interface().(type) {
case bool:
var nullBool internal.NullBool
err := nullBool.Scan(sourceInterface)
if err != nil {
return err
}
destination.SetBool(nullBool.Bool)
case float32, float64:
var nullFloat sql.NullFloat64
err := nullFloat.Scan(sourceInterface)
if err != nil {
return err
}
if nullFloat.Valid {
destination.SetFloat(nullFloat.Float64)
}
case int, int8, int16, int32, int64:
var integer sql.NullInt64
err := integer.Scan(sourceInterface)
if err != nil {
return err
}
if integer.Valid {
destination.SetInt(integer.Int64)
}
case uint, uint8, uint16, uint32, uint64:
var uInt internal.NullUInt64
err := uInt.Scan(sourceInterface)
if err != nil {
return err
}
if uInt.Valid {
destination.SetUint(uInt.UInt64)
}
case string:
var str sql.NullString
err := str.Scan(sourceInterface)
if err != nil {
return err
}
if str.Valid {
destination.SetString(str.String)
}
case time.Time:
var nullTime internal.NullTime
err := nullTime.Scan(sourceInterface)
if err != nil {
return err
}
if nullTime.Valid {
destination.Set(reflect.ValueOf(nullTime.Time))
}
default:
return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface())
}
return nil
}
func setReflectValue(source, destination reflect.Value) {
if tryAssign(source, destination) {
return
}
func setReflectValue(source, destination reflect.Value) error {
if destination.Kind() == reflect.Ptr {
if source.Kind() == reflect.Ptr {
if !source.IsNil() {
if destination.IsNil() {
initializeValueIfNilPtr(destination)
}
if tryAssign(source.Elem(), destination.Elem()) {
return
}
} else {
return
}
} else {
if source.CanAddr() {
source = source.Addr()
} else {
sourceCopy := reflect.New(source.Type())
sourceCopy.Elem().Set(source)
source = sourceCopy
}
if tryAssign(source, destination) {
return
}
if tryAssign(source.Elem(), destination.Elem()) {
return
}
if destination.IsNil() {
initializeValueIfNilPtr(destination)
}
} else {
if source.Kind() == reflect.Ptr {
if source.IsNil() {
return
return nil // source is nil, destination should keep its zero value
}
source = source.Elem()
}
if tryAssign(source, destination) {
return
if err := tryAssign(source, destination.Elem()); err != nil {
return err
}
} else {
if source.Kind() == reflect.Ptr {
if source.IsNil() {
return nil // source is nil, destination should keep its zero value
}
source = source.Elem()
}
if err := tryAssign(source, destination); err != nil {
return err
}
}
panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String())
}
func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
values := make([]interface{}, len(columnTypes))
for i, sqlColumnType := range columnTypes {
columnType := newScanType(sqlColumnType)
columnValue := reflect.New(columnType)
values[i] = columnValue.Interface()
}
return values
}
var boolType = reflect.TypeOf(true)
var int8Type = reflect.TypeOf(int8(1))
var unit8Type = reflect.TypeOf(uint8(1))
var int16Type = reflect.TypeOf(int16(1))
var uint16Type = reflect.TypeOf(uint16(1))
var int32Type = reflect.TypeOf(int32(1))
var uint32Type = reflect.TypeOf(uint32(1))
var int64Type = reflect.TypeOf(int64(1))
var uint64Type = reflect.TypeOf(uint64(1))
var float32Type = reflect.TypeOf(float32(1))
var float64Type = reflect.TypeOf(float64(1))
var stringType = reflect.TypeOf("")
var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullInt8Type = reflect.TypeOf(internal.NullInt8{})
var nullInt16Type = reflect.TypeOf(internal.NullInt16{})
var nullInt32Type = reflect.TypeOf(internal.NullInt32{})
var nullInt64Type = reflect.TypeOf(sql.NullInt64{})
var nullFloat32Type = reflect.TypeOf(internal.NullFloat32{})
var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{})
var nullStringType = reflect.TypeOf(sql.NullString{})
var nullTimeType = reflect.TypeOf(internal.NullTime{})
var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{})
func newScanType(columnType *sql.ColumnType) reflect.Type {
switch columnType.DatabaseTypeName() {
case "TINYINT":
return nullInt8Type
case "INT2", "SMALLINT", "YEAR":
return nullInt16Type
case "INT4", "MEDIUMINT", "INT":
return nullInt32Type
case "INT8", "BIGINT":
return nullInt64Type
case "CHAR", "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML":
return nullStringType
case "FLOAT4":
return nullFloat32Type
case "FLOAT8", "FLOAT", "DOUBLE":
return nullFloat64Type
case "BOOL":
return nullBoolType
case "BYTEA", "BINARY", "VARBINARY", "BLOB":
return nullByteArrayType
case "DATE", "DATETIME", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ":
return nullTimeType
default:
return nullStringType
}
return nil
}
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
@ -385,3 +359,12 @@ func fieldToString(field *reflect.StructField) string {
return " at '" + field.Name + " " + field.Type.String() + "'"
}
func cloneBytes(b []byte) []byte {
if b == nil {
return nil
}
c := make([]byte, len(b))
copy(c, b)
return c
}