diff --git a/qrm/utill.go b/qrm/utill.go index 1574599..f485797 100644 --- a/qrm/utill.go +++ b/qrm/utill.go @@ -183,6 +183,10 @@ func isIntegerType(value reflect.Type) bool { return false } +func isNumber(valueType reflect.Type) bool { + return isIntegerType(valueType) || valueType == float64Type || valueType == float32Type +} + func tryAssign(source, destination reflect.Value) bool { switch { @@ -196,13 +200,18 @@ func tryAssign(source, destination reflect.Value) bool { } else if intValue == 0 { source = reflect.ValueOf(false) } - case source.Type() == stringType && destination.Type() == float64Type: - strValue := source.String() - f, err := strconv.ParseFloat(strValue, 64) + case source.Type() == stringType && isNumber(destination.Type()): + // if source is string and destination is a number(int8, int32, float32, ...), we first parse string to float64 number + // and then parsed number is converted into destination type + f, err := strconv.ParseFloat(source.String(), 64) if err != nil { return false } source = reflect.ValueOf(f) + + if source.Type().ConvertibleTo(destination.Type()) { + source = source.Convert(destination.Type()) + } } if source.Type().AssignableTo(destination.Type()) { @@ -289,6 +298,7 @@ var int32Type = reflect.TypeOf(int32(1)) var uint32Type = reflect.TypeOf(uint32(1)) var int64Type = reflect.TypeOf(int64(1)) var uint64Type = reflect.TypeOf(uint64(1)) +var float32Type = reflect.TypeOf(float32(1)) var float64Type = reflect.TypeOf(float64(1)) var stringType = reflect.TypeOf("") diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 6bbc211..5a88acf 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -927,3 +927,48 @@ func TestRowsScan(t *testing.T) { requireLogged(t, stmt) } + +func TestScanNumericToNumber(t *testing.T) { + type Number struct { + Int8 int8 + UInt8 uint8 + Int16 int16 + UInt16 uint16 + Int32 int32 + UInt32 uint32 + Int64 int64 + UInt64 uint64 + Float32 float32 + Float64 float64 + } + + numeric := CAST(Decimal("1234567890.111")).AS_DECIMAL() + + stmt := SELECT( + numeric.AS("number.int8"), + numeric.AS("number.uint8"), + numeric.AS("number.int16"), + numeric.AS("number.uint16"), + numeric.AS("number.int32"), + numeric.AS("number.uint32"), + numeric.AS("number.int64"), + numeric.AS("number.uint64"), + numeric.AS("number.float32"), + numeric.AS("number.float64"), + ) + + var number Number + err := stmt.Query(db, &number) + require.NoError(t, err) + + require.Equal(t, number.Int8, int8(-46)) // overflow + require.Equal(t, number.UInt8, uint8(210)) // overflow + require.Equal(t, number.Int16, int16(722)) // overflow + require.Equal(t, number.UInt16, uint16(722)) // overflow + require.Equal(t, number.Int32, int32(1234567890)) + require.Equal(t, number.UInt32, uint32(1234567890)) + require.Equal(t, number.Int64, int64(1234567890)) + require.Equal(t, number.UInt64, uint64(1234567890)) + require.Equal(t, number.Float32, float32(1.234568e+09)) + require.Equal(t, number.Float64, float64(1.23456789e+09)) +} diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index dacdf88..4a80ba0 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -764,6 +764,51 @@ func TestRowsScan(t *testing.T) { requireLogged(t, stmt) } +func TestScanNumericToNumber(t *testing.T) { + type Number struct { + Int8 int8 + UInt8 uint8 + Int16 int16 + UInt16 uint16 + Int32 int32 + UInt32 uint32 + Int64 int64 + UInt64 uint64 + Float32 float32 + Float64 float64 + } + + numeric := CAST(Decimal("1234567890.111")).AS_NUMERIC() + + stmt := SELECT( + numeric.AS("number.int8"), + numeric.AS("number.uint8"), + numeric.AS("number.int16"), + numeric.AS("number.uint16"), + numeric.AS("number.int32"), + numeric.AS("number.uint32"), + numeric.AS("number.int64"), + numeric.AS("number.uint64"), + numeric.AS("number.float32"), + numeric.AS("number.float64"), + ) + + var number Number + err := stmt.Query(db, &number) + require.NoError(t, err) + + require.Equal(t, number.Int8, int8(-46)) // overflow + require.Equal(t, number.UInt8, uint8(210)) // overflow + require.Equal(t, number.Int16, int16(722)) // overflow + require.Equal(t, number.UInt16, uint16(722)) // overflow + require.Equal(t, number.Int32, int32(1234567890)) + require.Equal(t, number.UInt32, uint32(1234567890)) + require.Equal(t, number.Int64, int64(1234567890)) + require.Equal(t, number.UInt64, uint64(1234567890)) + require.Equal(t, number.Float32, float32(1.234568e+09)) + require.Equal(t, number.Float64, float64(1.234567890111e+09)) +} + var address256 = model.Address{ AddressID: 256, Address: "1497 Yuzhou Drive",