From 063b17ca05cfb4d332341cc3c6665be29bf53a60 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sun, 9 May 2021 16:37:16 +0200 Subject: [PATCH] Update lossless decimal tests to use new floats test table and DECIMAL literal constructor. --- go.mod | 8 +- qrm/utill.go | 14 ++-- qrm/utill_test.go | 45 +++++++++++ tests/mysql/alltypes_test.go | 97 +++++++++++++++++++++++ tests/postgres/generator_test.go | 4 +- tests/postgres/sample_test.go | 130 ++++++++++++++++++++----------- tests/testdata | 2 +- 7 files changed, 238 insertions(+), 62 deletions(-) diff --git a/go.mod b/go.mod index 923117c..b6a3951 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,10 @@ go 1.11 require ( github.com/go-sql-driver/mysql v1.5.0 - github.com/google/go-cmp v0.5.0 + github.com/google/go-cmp v0.5.0 //tests github.com/google/uuid v1.1.1 github.com/lib/pq v1.7.0 - github.com/pkg/profile v1.5.0 - github.com/shopspring/decimal v1.2.0 - github.com/stretchr/testify v1.6.1 + github.com/pkg/profile v1.5.0 //tests + github.com/shopspring/decimal v1.2.0 // tests + github.com/stretchr/testify v1.6.1 // tests ) diff --git a/qrm/utill.go b/qrm/utill.go index 2b208eb..1574599 100644 --- a/qrm/utill.go +++ b/qrm/utill.go @@ -7,9 +7,9 @@ import ( "github.com/go-jet/jet/v2/qrm/internal" "github.com/google/uuid" "reflect" + "strconv" "strings" "time" - "strconv" ) var scannerInterfaceType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() @@ -184,11 +184,11 @@ func isIntegerType(value reflect.Type) bool { } func tryAssign(source, destination reflect.Value) bool { - if source.Type().ConvertibleTo(destination.Type()) { - source = source.Convert(destination.Type()) - } - if isIntegerType(source.Type()) && destination.Type() == boolType { + switch { + case source.Type().ConvertibleTo(destination.Type()): + source = source.Convert(destination.Type()) + case isIntegerType(source.Type()) && destination.Type() == boolType: intValue := source.Int() if intValue == 1 { @@ -196,9 +196,7 @@ func tryAssign(source, destination reflect.Value) bool { } else if intValue == 0 { source = reflect.ValueOf(false) } - } - - if source.Type() == stringType && destination.Type() == float64Type { + case source.Type() == stringType && destination.Type() == float64Type: strValue := source.String() f, err := strconv.ParseFloat(strValue, 64) if err != nil { diff --git a/qrm/utill_test.go b/qrm/utill_test.go index 897bb2c..e23fa15 100644 --- a/qrm/utill_test.go +++ b/qrm/utill_test.go @@ -35,3 +35,48 @@ func TestIsSimpleModelType(t *testing.T) { require.Equal(t, isSimpleModelType(reflect.TypeOf([]string{"str"})), false) require.Equal(t, isSimpleModelType(reflect.TypeOf([]int{1, 2})), false) } + +func TestTryAssign(t *testing.T) { + convertible := int16(16) + intBool1 := int32(1) + intBool0 := int32(0) + intBool2 := int32(2) + floatStr := "1.11" + floatErr := "1.abcd2" + str := "some string" + + destination := struct { + Convertible int64 + IntBool1 bool + IntBool0 bool + IntBool2 bool + FloatStr float64 + FloatErr float64 + Str string + }{} + + testValue := reflect.ValueOf(&destination).Elem() + + // convertible + require.True(t, tryAssign(reflect.ValueOf(convertible), testValue.FieldByName("Convertible"))) + require.Equal(t, int64(16), destination.Convertible) + + // 1/0 to bool + require.True(t, tryAssign(reflect.ValueOf(intBool1), testValue.FieldByName("IntBool1"))) + require.Equal(t, true, destination.IntBool1) + require.True(t, tryAssign(reflect.ValueOf(intBool0), testValue.FieldByName("IntBool0"))) + require.Equal(t, false, destination.IntBool0) + + require.False(t, tryAssign(reflect.ValueOf(intBool2), testValue.FieldByName("IntBool2"))) + require.Equal(t, false, destination.IntBool2) + + // string to float + require.True(t, tryAssign(reflect.ValueOf(floatStr), testValue.FieldByName("FloatStr"))) + require.Equal(t, 1.11, destination.FloatStr) + require.False(t, tryAssign(reflect.ValueOf(floatErr), testValue.FieldByName("FloatErr"))) + require.Equal(t, 0.00, destination.FloatErr) + + // string to string + require.True(t, tryAssign(reflect.ValueOf(str), testValue.FieldByName("Str"))) + require.Equal(t, str, destination.Str) +} diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 6fbb136..8c1539a 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -1,6 +1,7 @@ package mysql import ( + "github.com/shopspring/decimal" "github.com/stretchr/testify/require" "strings" "testing" @@ -1278,3 +1279,99 @@ FROM test_sample.user; ] `) } + +func TestExactDecimals(t *testing.T) { + + type floats struct { + model.Floats + Numeric decimal.Decimal + NumericPtr decimal.Decimal + Decimal decimal.Decimal + DecimalPtr decimal.Decimal + } + + t.Run("should query decimal", func(t *testing.T) { + query := SELECT( + Floats.AllColumns, + ).FROM( + Floats, + ).WHERE(Floats.Decimal.EQ(Decimal("1.11111111111111111111"))) + + var result floats + + err := query.Query(db, &result) + require.NoError(t, err) + + require.Equal(t, "1.11111111111111111111", result.Decimal.String()) + require.Equal(t, "0", result.DecimalPtr.String()) // NULL + require.Equal(t, "2.22222222222222222222", result.Numeric.String()) + require.Equal(t, "0", result.NumericPtr.String()) // NULL + + require.Equal(t, 1.1111111111111112, result.Floats.Decimal) // precision loss + require.Equal(t, (*float64)(nil), result.Floats.DecimalPtr) + require.Equal(t, 2.2222222222222223, result.Floats.Numeric) // precision loss + require.Equal(t, (*float64)(nil), result.Floats.NumericPtr) + + // floating point + require.Equal(t, 3.3333333, result.Floats.Float) // precision loss + require.Equal(t, (*float64)(nil), result.Floats.FloatPtr) + require.Equal(t, 4.444444444444445, result.Floats.Double) // precision loss + require.Equal(t, (*float64)(nil), result.Floats.DoublePtr) + require.Equal(t, 5.555555555555555, result.Floats.Real) // precision loss + require.Equal(t, (*float64)(nil), result.Floats.RealPtr) + }) + + t.Run("should insert decimal", func(t *testing.T) { + + insertQuery := Floats.INSERT( + Floats.AllColumns, + ).MODEL( + floats{ + Floats: model.Floats{ + // overwritten by wrapped(floats) scope + Numeric: 0.1, + NumericPtr: testutils.Float64Ptr(0.1), + Decimal: 0.1, + DecimalPtr: testutils.Float64Ptr(0.1), + + // not overwritten + Float: 0.2, + FloatPtr: testutils.Float64Ptr(0.22), + Double: 0.3, + DoublePtr: testutils.Float64Ptr(0.33), + Real: 0.4, + RealPtr: testutils.Float64Ptr(0.44), + }, + Numeric: decimal.RequireFromString("12.35"), + NumericPtr: decimal.RequireFromString("56.79"), + Decimal: decimal.RequireFromString("91.23"), + DecimalPtr: decimal.RequireFromString("45.67"), + }, + ) + + testutils.AssertDebugStatementSql(t, insertQuery, strings.Replace(` +INSERT INTO test_sample.floats (''decimal'', decimal_ptr, ''numeric'', numeric_ptr, ''float'', float_ptr, ''double'', double_ptr, ''real'', real_ptr) +VALUES ('91.23', '45.67', '12.35', '56.79', 0.2, 0.22, 0.3, 0.33, 0.4, 0.44); +`, "''", "`", -1)) + _, err := insertQuery.Exec(db) + require.NoError(t, err) + + var result floats + + err = SELECT(Floats.AllColumns). + FROM(Floats). + WHERE(Floats.Numeric.EQ(Float(12.35))). + Query(db, &result) + require.NoError(t, err) + + require.Equal(t, "12.35", result.Numeric.String()) + require.Equal(t, "56.79", result.NumericPtr.String()) + require.Equal(t, "91.23", result.Decimal.String()) + require.Equal(t, "45.67", result.DecimalPtr.String()) + + require.Equal(t, 12.35, result.Floats.Numeric) + require.Equal(t, 56.79, *result.Floats.NumericPtr) + require.Equal(t, 91.23, result.Floats.Decimal) + require.Equal(t, 45.67, *result.Floats.DecimalPtr) + }) +} diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 833e2a4..0571157 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -346,7 +346,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { require.NoError(t, err) testutils.AssertFileNamesEqual(t, modelFiles, "all_types.go", "all_types_view.go", "employee.go", "link.go", - "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go") + "mood.go", "person.go", "person_phone.go", "weird_names_table.go", "level.go", "user.go", "floats.go") testutils.AssertFileContent(t, modelDir+"all_types.go", allTypesModelContent) @@ -354,7 +354,7 @@ func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { require.NoError(t, err) testutils.AssertFileNamesEqual(t, tableFiles, "all_types.go", "employee.go", "link.go", - "person.go", "person_phone.go", "weird_names_table.go", "user.go") + "person.go", "person_phone.go", "weird_names_table.go", "user.go", "floats.go") testutils.AssertFileContent(t, tableDir+"all_types.go", allTypesTableContent) } diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index d9208b4..bc82c7e 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -36,70 +36,106 @@ WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'; } func TestExactDecimals(t *testing.T) { + + type floats struct { + model.Floats + Numeric decimal.Decimal + NumericPtr decimal.Decimal + Decimal decimal.Decimal + DecimalPtr decimal.Decimal + } + t.Run("should query decimal", func(t *testing.T) { query := SELECT( - AllTypes.Numeric, - AllTypes.NumericPtr, - AllTypes.Decimal, - AllTypes.DecimalPtr, + Floats.AllColumns, ).FROM( - AllTypes, - ).WHERE(AllTypes.UUID.EQ(String("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))) + Floats, + ).WHERE(Floats.Decimal.EQ(Decimal("1.11111111111111111111"))) - type AllTypes struct { - model.AllTypes - Numeric decimal.Decimal - NumericPtr decimal.Decimal - Decimal decimal.Decimal - DecimalPtr decimal.Decimal - } - - var result AllTypes + var result floats err := query.Query(db, &result) require.NoError(t, err) - require.Equal(t, "1.11", result.Decimal.String()) - require.Equal(t, "1.11", result.DecimalPtr.String()) - require.Equal(t, "2.22", result.Numeric.String()) - require.Equal(t, "2.22", result.NumericPtr.String()) + require.Equal(t, "1.11111111111111111111", result.Decimal.String()) + require.Equal(t, "0", result.DecimalPtr.String()) // NULL + require.Equal(t, "2.22222222222222222222", result.Numeric.String()) + require.Equal(t, "0", result.NumericPtr.String()) // NULL + + require.Equal(t, 1.1111111111111112, result.Floats.Decimal) // precision loss + require.Equal(t, (*float64)(nil), result.Floats.DecimalPtr) + require.Equal(t, 2.2222222222222223, result.Floats.Numeric) // precision loss + require.Equal(t, (*float64)(nil), result.Floats.NumericPtr) + + // floating point + require.Equal(t, float32(3.3333333), result.Floats.Real) // precision loss + require.Equal(t, (*float32)(nil), result.Floats.RealPtr) + require.Equal(t, 4.444444444444445, result.Floats.Double) // precision loss + require.Equal(t, (*float64)(nil), result.Floats.DoublePtr) }) t.Run("should insert decimal", func(t *testing.T) { - type allTypes struct { - model.AllTypes - Numeric decimal.Decimal - NumericPtr decimal.Decimal - Decimal decimal.Decimal - DecimalPtr decimal.Decimal - } - m := allTypes{ - AllTypes: allTypesRow0, - Numeric: decimal.RequireFromString("12.345"), - NumericPtr: decimal.RequireFromString("56.789"), - Decimal: decimal.RequireFromString("91.23"), - DecimalPtr: decimal.RequireFromString("45.67"), - } + insertQuery := Floats.INSERT( + Floats.AllColumns, + ).MODEL( + floats{ + Floats: model.Floats{ + // overwritten by wrapped(floats) scope + Numeric: 0.1, + NumericPtr: testutils.Float64Ptr(0.1), + Decimal: 0.1, + DecimalPtr: testutils.Float64Ptr(0.1), - insertQuery := AllTypes.INSERT( - AllTypes.MutableColumns, - ).MODEL(m). - RETURNING( - AllTypes.Numeric, - AllTypes.NumericPtr, - AllTypes.Decimal, - AllTypes.DecimalPtr, - ) + // not overwritten + Real: 0.4, + RealPtr: testutils.Float32Ptr(0.44), + Double: 0.3, + DoublePtr: testutils.Float64Ptr(0.33), + }, + Numeric: decimal.RequireFromString("0.1234567890123456789"), + NumericPtr: decimal.RequireFromString("1.1111111111111111111"), + Decimal: decimal.RequireFromString("2.2222222222222222222"), + DecimalPtr: decimal.RequireFromString("3.3333333333333333333"), + }, + ).RETURNING( + Floats.AllColumns, + ) - var result allTypes + testutils.AssertDebugStatementSql(t, insertQuery, ` +INSERT INTO test_sample.floats (decimal_ptr, decimal, numeric_ptr, numeric, real_ptr, real, double_ptr, double) +VALUES ('3.3333333333333333333', '2.2222222222222222222', '1.1111111111111111111', '0.1234567890123456789', 0.4399999976158142, 0.4000000059604645, 0.33, 0.3) +RETURNING floats.decimal_ptr AS "floats.decimal_ptr", + floats.decimal AS "floats.decimal", + floats.numeric_ptr AS "floats.numeric_ptr", + floats.numeric AS "floats.numeric", + floats.real_ptr AS "floats.real_ptr", + floats.real AS "floats.real", + floats.double_ptr AS "floats.double_ptr", + floats.double AS "floats.double"; +`) + + var result floats err := insertQuery.Query(db, &result) require.NoError(t, err) - require.Equal(t, "12.345", result.Numeric.String()) - require.Equal(t, "56.789", result.NumericPtr.String()) - require.Equal(t, "91.23", result.Decimal.String()) - require.Equal(t, "45.67", result.DecimalPtr.String()) + // exact decimal + require.Equal(t, "0.1234567890123456789", result.Numeric.String()) + require.Equal(t, "1.1111111111111111111", result.NumericPtr.String()) + require.Equal(t, "2.2222222222222222222", result.Decimal.String()) + require.Equal(t, "3.3333333333333333333", result.DecimalPtr.String()) + + // precision loss + require.Equal(t, 0.12345678901234568, result.Floats.Numeric) + require.Equal(t, 1.1111111111111112, *result.Floats.NumericPtr) + require.Equal(t, 2.2222222222222223, result.Floats.Decimal) + require.Equal(t, 3.3333333333333335, *result.Floats.DecimalPtr) + + // floating points numbers + require.Equal(t, float32(0.4), result.Floats.Real) + require.Equal(t, float32(0.44), *result.Floats.RealPtr) + require.Equal(t, 0.3, result.Floats.Double) + require.Equal(t, 0.33, *result.Floats.DoublePtr) }) } diff --git a/tests/testdata b/tests/testdata index 391d936..1c97764 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 391d936515d2f826df073707697de44907a7f67d +Subproject commit 1c977643ceb0df149fc953ad617e2a86c6ecdd65