diff --git a/examples/quick-start/.gen/jetdb/dvds/model/mpaa_rating.go b/examples/quick-start/.gen/jetdb/dvds/model/mpaa_rating.go index c802aa9..bdb613a 100644 --- a/examples/quick-start/.gen/jetdb/dvds/model/mpaa_rating.go +++ b/examples/quick-start/.gen/jetdb/dvds/model/mpaa_rating.go @@ -20,26 +20,32 @@ const ( ) func (e *MpaaRating) Scan(value interface{}) error { - if v, ok := value.(string); !ok { - return errors.New("jet: Invalid data for MpaaRating enum") - } else { - switch string(v) { - case "G": - *e = MpaaRating_G - case "PG": - *e = MpaaRating_Pg - case "PG-13": - *e = MpaaRating_Pg13 - case "R": - *e = MpaaRating_R - case "NC-17": - *e = MpaaRating_Nc17 - default: - return errors.New("jet: Inavlid data " + string(v) + "for MpaaRating enum") - } - - return nil + var enumValue string + switch val := value.(type) { + case string: + enumValue = val + case []byte: + enumValue = string(val) + default: + return errors.New("jet: Invalid scan value for AllTypesEnum enum. Enum value has to be of type string or []byte") } + + switch enumValue { + case "G": + *e = MpaaRating_G + case "PG": + *e = MpaaRating_Pg + case "PG-13": + *e = MpaaRating_Pg13 + case "R": + *e = MpaaRating_R + case "NC-17": + *e = MpaaRating_Nc17 + default: + return errors.New("jet: Invalid scan value '" + enumValue + "' for MpaaRating enum") + } + + return nil } func (e MpaaRating) String() string { diff --git a/generator/template/file_templates.go b/generator/template/file_templates.go index a738ea6..8b5c5b0 100644 --- a/generator/template/file_templates.go +++ b/generator/template/file_templates.go @@ -200,20 +200,26 @@ const ( ) func (e *{{$enumTemplate.TypeName}}) Scan(value interface{}) error { - if v, ok := value.(string); !ok { - return errors.New("jet: Invalid scan value for {{$enumTemplate.TypeName}} enum. Enum value has to be of type string") - } else { - switch string(v) { -{{- range $_, $value := .Values}} - case "{{$value}}": - *e = {{valueName $value}} -{{- end}} - default: - return errors.New("jet: Invalid scan value '" + string(v) + "' for {{$enumTemplate.TypeName}} enum") - } - - return nil + var enumValue string + switch val := value.(type) { + case string: + enumValue = val + case []byte: + enumValue = string(val) + default: + return errors.New("jet: Invalid scan value for AllTypesEnum enum. Enum value has to be of type string or []byte") } + + switch enumValue { +{{- range $_, $value := .Values}} + case "{{$value}}": + *e = {{valueName $value}} +{{- end}} + default: + return errors.New("jet: Invalid scan value '" + enumValue + "' for {{$enumTemplate.TypeName}} enum") + } + + return nil } func (e {{$enumTemplate.TypeName}}) String() string { diff --git a/internal/utils/min/min.go b/internal/utils/min/min.go new file mode 100644 index 0000000..0e92146 --- /dev/null +++ b/internal/utils/min/min.go @@ -0,0 +1,9 @@ +package min + +// Int returns minimum of two int values +func Int(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/qrm/internal/null_types.go b/qrm/internal/null_types.go index 5a39094..3ec5bdb 100644 --- a/qrm/internal/null_types.go +++ b/qrm/internal/null_types.go @@ -1,263 +1,170 @@ package internal import ( + "database/sql" "database/sql/driver" "fmt" + "github.com/go-jet/jet/v2/internal/utils/min" + "reflect" "strconv" "time" ) -//===============================================================// - -// NullByteArray struct -type NullByteArray struct { - ByteArray []byte - Valid bool +// NullBool struct +type NullBool struct { + sql.NullBool } // Scan implements the Scanner interface. -func (nb *NullByteArray) Scan(value interface{}) error { +func (nb *NullBool) Scan(value interface{}) error { switch v := value.(type) { - case nil: - nb.Valid = false - return nil - case []byte: - nb.ByteArray = append(v[:0:0], v...) + case bool: + nb.Bool, nb.Valid = v, true + case int8, int16, int32, int64, int: + intVal := reflect.ValueOf(v).Int() + + if intVal != 0 && intVal != 1 { + return fmt.Errorf("can't assign %T(%d) to bool", value, value) + } + + nb.Bool = intVal == 1 + nb.Valid = true + case uint8, uint16, uint32, uint64, uint: + uintVal := reflect.ValueOf(v).Uint() + + if uintVal != 0 && uintVal != 1 { + return fmt.Errorf("can't assign %T(%d) to bool", value, value) + } + + nb.Bool = uintVal == 1 nb.Valid = true - return nil default: - return fmt.Errorf("can't scan []byte from %v", value) + return nb.NullBool.Scan(value) } -} -// Value implements the driver Valuer interface. -func (nb NullByteArray) Value() (driver.Value, error) { - if !nb.Valid { - return nil, nil - } - return nb.ByteArray, nil + return nil } -//===============================================================// - // NullTime struct type NullTime struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL + sql.NullTime } // Scan implements the Scanner interface. -func (nt *NullTime) Scan(value interface{}) (err error) { - switch v := value.(type) { - case nil: - nt.Valid = false - return - case time.Time: - nt.Time, nt.Valid = v, true - return - case []byte: - nt.Time, nt.Valid = parseTime(string(v)) - return - case string: - nt.Time, nt.Valid = parseTime(v) - return - default: +func (nt *NullTime) Scan(value interface{}) error { + err := nt.NullTime.Scan(value) + + if err == nil { + return nil + } + + // Some of the drivers (pgx, mysql) are not parsing all of the time formats(date, time with time zone,...) and are just forwarding string value. + // At this point we try to parse time using some of the predefined formats + nt.Time, nt.Valid = tryParseAsTime(value) + + if !nt.Valid { return fmt.Errorf("can't scan time.Time from %v", value) } + + return nil } -// Value implements the driver Valuer interface. -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil +var formats = []string{ + "2006-01-02 15:04:05.999999", // go-sql-driver/mysql + "15:04:05-07", // pgx + "15:04:05.999999", // pgx } -const formatTime = "2006-01-02 15:04:05.999999" +func tryParseAsTime(value interface{}) (time.Time, bool) { -func parseTime(timeStr string) (t time.Time, valid bool) { + var timeStr string - var format string - - switch len(timeStr) { - case 8: - format = formatTime[11:19] - case 10, 19, 21, 22, 23, 24, 25, 26: - format = formatTime[:len(timeStr)] - default: - return t, false - } - - t, err := time.Parse(format, timeStr) - return t, err == nil -} - -//===============================================================// - -// NullInt8 struct -type NullInt8 struct { - Int8 int8 - Valid bool -} - -// Scan implements the Scanner interface. -func (n *NullInt8) Scan(value interface{}) (err error) { switch v := value.(type) { - case nil: - n.Valid = false - return - case int64: - n.Int8, n.Valid = int8(v), true - return - case int8: - n.Int8, n.Valid = v, true - return + case string: + timeStr = v case []byte: - intV, err := strconv.ParseInt(string(v), 10, 8) - if err == nil { - n.Int8, n.Valid = int8(intV), true + timeStr = string(v) + } + + for _, format := range formats { + formatLen := min.Int(len(format), len(timeStr)) + t, err := time.Parse(format[:formatLen], timeStr) + + if err != nil { + continue } - return err - default: - return fmt.Errorf("can't scan int8 from %v", value) + + return t, true } + + return time.Time{}, false } -// Value implements the driver Valuer interface. -func (n NullInt8) Value() (driver.Value, error) { - if !n.Valid { - return nil, nil - } - return n.Int8, nil -} - -//===============================================================// - -// NullInt16 struct -type NullInt16 struct { - Int16 int16 - Valid bool +// NullUInt64 struct +type NullUInt64 struct { + UInt64 uint64 + Valid bool } // Scan implements the Scanner interface. -func (n *NullInt16) Scan(value interface{}) error { - +func (n *NullUInt64) Scan(value interface{}) error { + var stringValue string switch v := value.(type) { case nil: n.Valid = false return nil case int64: - n.Int16, n.Valid = int16(v), true + n.UInt64, n.Valid = uint64(v), true return nil - case int16: - n.Int16, n.Valid = v, true - return nil - case int8: - n.Int16, n.Valid = int16(v), true - return nil - case uint8: - n.Int16, n.Valid = int16(v), true - return nil - case []byte: - intV, err := strconv.ParseInt(string(v), 10, 16) - if err == nil { - n.Int16, n.Valid = int16(intV), true - } - return nil - default: - return fmt.Errorf("can't scan int16 from %v", value) - } -} - -// Value implements the driver Valuer interface. -func (n NullInt16) Value() (driver.Value, error) { - if !n.Valid { - return nil, nil - } - return n.Int16, nil -} - -//===============================================================// - -// NullInt32 struct -type NullInt32 struct { - Int32 int32 - Valid bool -} - -// Scan implements the Scanner interface. -func (n *NullInt32) Scan(value interface{}) error { - switch v := value.(type) { - case nil: - n.Valid = false - return nil - case int64: - n.Int32, n.Valid = int32(v), true + case uint64: + n.UInt64, n.Valid = v, true return nil case int32: - n.Int32, n.Valid = v, true + n.UInt64, n.Valid = uint64(v), true + return nil + case uint32: + n.UInt64, n.Valid = uint64(v), true return nil case int16: - n.Int32, n.Valid = int32(v), true + n.UInt64, n.Valid = uint64(v), true return nil case uint16: - n.Int32, n.Valid = int32(v), true + n.UInt64, n.Valid = uint64(v), true return nil case int8: - n.Int32, n.Valid = int32(v), true + n.UInt64, n.Valid = uint64(v), true return nil case uint8: - n.Int32, n.Valid = int32(v), true + n.UInt64, n.Valid = uint64(v), true + return nil + case int: + n.UInt64, n.Valid = uint64(v), true + return nil + case uint: + n.UInt64, n.Valid = uint64(v), true return nil case []byte: - intV, err := strconv.ParseInt(string(v), 10, 32) - if err == nil { - n.Int32, n.Valid = int32(intV), true - } - return nil + stringValue = string(v) + case string: + stringValue = v default: - return fmt.Errorf("can't scan int32 from %v", value) + return fmt.Errorf("can't scan uint64 from %v", value) } + + uintV, err := strconv.ParseUint(stringValue, 10, 64) + if err != nil { + return err + } + n.UInt64 = uintV + n.Valid = true + + return nil } // Value implements the driver Valuer interface. -func (n NullInt32) Value() (driver.Value, error) { +func (n NullUInt64) Value() (driver.Value, error) { if !n.Valid { return nil, nil } - return n.Int32, nil -} - -//===============================================================// - -// NullFloat32 struct -type NullFloat32 struct { - Float32 float32 - Valid bool -} - -// Scan implements the Scanner interface. -func (n *NullFloat32) Scan(value interface{}) error { - switch v := value.(type) { - case nil: - n.Valid = false - return nil - case float64: - n.Float32, n.Valid = float32(v), true - return nil - case float32: - n.Float32, n.Valid = v, true - return nil - default: - return fmt.Errorf("can't scan float32 from %v", value) - } -} - -// Value implements the driver Valuer interface. -func (n NullFloat32) Value() (driver.Value, error) { - if !n.Valid { - return nil, nil - } - return n.Float32, nil + return n.UInt64, nil } diff --git a/qrm/internal/null_types_test.go b/qrm/internal/null_types_test.go index 8f4adde..a15b104 100644 --- a/qrm/internal/null_types_test.go +++ b/qrm/internal/null_types_test.go @@ -7,141 +7,85 @@ import ( "time" ) -func TestNullByteArray(t *testing.T) { - var array NullByteArray +func TestNullBool(t *testing.T) { + var nullBool NullBool - require.NoError(t, array.Scan(nil)) - require.Equal(t, array.Valid, false) + require.NoError(t, nullBool.Scan(nil)) + require.Equal(t, nullBool.Valid, false) - require.NoError(t, array.Scan([]byte("bytea"))) - require.Equal(t, array.Valid, true) - require.Equal(t, string(array.ByteArray), string([]byte("bytea"))) + require.NoError(t, nullBool.Scan(int64(1))) + require.Equal(t, nullBool.Valid, true) + value, _ := nullBool.Value() + require.Equal(t, value, true) - require.Error(t, array.Scan(12), "can't scan []byte from 12") + require.NoError(t, nullBool.Scan(uint32(0))) + require.Equal(t, nullBool.Valid, true) + value, _ = nullBool.Value() + require.Equal(t, value, false) + + require.EqualError(t, nullBool.Scan(uint16(22)), "can't assign uint16(22) to bool") } func TestNullTime(t *testing.T) { - var array NullTime + var nullTime NullTime - require.NoError(t, array.Scan(nil)) - require.Equal(t, array.Valid, false) + require.NoError(t, nullTime.Scan(nil)) + require.Equal(t, nullTime.Valid, false) time := time.Now() - require.NoError(t, array.Scan(time)) - require.Equal(t, array.Valid, true) - value, _ := array.Value() + require.NoError(t, nullTime.Scan(time)) + require.Equal(t, nullTime.Valid, true) + value, _ := nullTime.Value() require.Equal(t, value, time) - require.NoError(t, array.Scan([]byte("13:10:11"))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() + require.NoError(t, nullTime.Scan([]byte("13:10:11"))) + require.Equal(t, nullTime.Valid, true) + value, _ = nullTime.Value() require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") - require.NoError(t, array.Scan("13:10:11")) - require.Equal(t, array.Valid, true) - value, _ = array.Value() + require.NoError(t, nullTime.Scan("13:10:11")) + require.Equal(t, nullTime.Valid, true) + value, _ = nullTime.Value() require.Equal(t, fmt.Sprintf("%v", value), "0000-01-01 13:10:11 +0000 UTC") - require.Error(t, array.Scan(12), "can't scan time.Time from 12") + require.Error(t, nullTime.Scan(12), "can't scan time.Time from 12") } -func TestNullInt8(t *testing.T) { - var array NullInt8 +func TestNullUInt64(t *testing.T) { + var nullUInt64 NullUInt64 - require.NoError(t, array.Scan(nil)) - require.Equal(t, array.Valid, false) + require.NoError(t, nullUInt64.Scan(nil)) + require.Equal(t, nullUInt64.Valid, false) - require.NoError(t, array.Scan(int64(11))) - require.Equal(t, array.Valid, true) - value, _ := array.Value() - require.Equal(t, value, int8(11)) + require.NoError(t, nullUInt64.Scan(int64(11))) + require.Equal(t, nullUInt64.Valid, true) + value, _ := nullUInt64.Value() + require.Equal(t, value, uint64(11)) - require.Error(t, array.Scan("text"), "can't scan int8 from text") -} - -func TestNullInt16(t *testing.T) { - var array NullInt16 - - require.NoError(t, array.Scan(nil)) - require.Equal(t, array.Valid, false) - - require.NoError(t, array.Scan(int64(11))) - require.Equal(t, array.Valid, true) - value, _ := array.Value() - require.Equal(t, value, int16(11)) - - require.NoError(t, array.Scan(int16(20))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int16(20)) - - require.NoError(t, array.Scan(int8(30))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int16(30)) - - require.NoError(t, array.Scan(uint8(30))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int16(30)) - - require.Error(t, array.Scan("text"), "can't scan int16 from text") -} - -func TestNullInt32(t *testing.T) { - var array NullInt32 - - require.NoError(t, array.Scan(nil)) - require.Equal(t, array.Valid, false) - - require.NoError(t, array.Scan(int64(11))) - require.Equal(t, array.Valid, true) - value, _ := array.Value() - require.Equal(t, value, int32(11)) - - require.NoError(t, array.Scan(int32(32))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int32(32)) - - require.NoError(t, array.Scan(int16(20))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int32(20)) - - require.NoError(t, array.Scan(uint16(16))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int32(16)) - - require.NoError(t, array.Scan(int8(30))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int32(30)) - - require.NoError(t, array.Scan(uint8(30))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, int32(30)) - - require.Error(t, array.Scan("text"), "can't scan int32 from text") -} - -func TestNullFloat32(t *testing.T) { - var array NullFloat32 - - require.NoError(t, array.Scan(nil)) - require.Equal(t, array.Valid, false) - - require.NoError(t, array.Scan(float64(64))) - require.Equal(t, array.Valid, true) - value, _ := array.Value() - require.Equal(t, value, float32(64)) - - require.NoError(t, array.Scan(float32(32))) - require.Equal(t, array.Valid, true) - value, _ = array.Value() - require.Equal(t, value, float32(32)) - - require.Error(t, array.Scan(12), "can't scan float32 from 12") + require.NoError(t, nullUInt64.Scan(int32(32))) + require.Equal(t, nullUInt64.Valid, true) + value, _ = nullUInt64.Value() + require.Equal(t, value, uint64(32)) + + require.NoError(t, nullUInt64.Scan(int16(20))) + require.Equal(t, nullUInt64.Valid, true) + value, _ = nullUInt64.Value() + require.Equal(t, value, uint64(20)) + + require.NoError(t, nullUInt64.Scan(uint16(16))) + require.Equal(t, nullUInt64.Valid, true) + value, _ = nullUInt64.Value() + require.Equal(t, value, uint64(16)) + + require.NoError(t, nullUInt64.Scan(int8(30))) + require.Equal(t, nullUInt64.Valid, true) + value, _ = nullUInt64.Value() + require.Equal(t, value, uint64(30)) + + require.NoError(t, nullUInt64.Scan(uint8(30))) + require.Equal(t, nullUInt64.Valid, true) + value, _ = nullUInt64.Value() + require.Equal(t, value, uint64(30)) + + require.Error(t, nullUInt64.Scan("text"), "can't scan int32 from text") } diff --git a/qrm/qrm.go b/qrm/qrm.go index c21ce1f..4502402 100644 --- a/qrm/qrm.go +++ b/qrm/qrm.go @@ -27,7 +27,10 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr if destinationPtrType.Elem().Kind() == reflect.Slice { _, err := queryToSlice(ctx, db, query, args, destPtr) - return err + if err != nil { + return fmt.Errorf("jet: %w", err) + } + return nil } else if destinationPtrType.Elem().Kind() == reflect.Struct { tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSliceValue := tempSlicePtrValue.Elem() @@ -35,7 +38,7 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr rowsProcessed, err := queryToSlice(ctx, db, query, args, tempSlicePtrValue.Interface()) if err != nil { - return err + return fmt.Errorf("jet: %w", err) } if rowsProcessed == 0 { @@ -275,10 +278,16 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re err = scanner.Scan(cellValue) if err != nil { - panic("jet: " + err.Error() + ", " + fieldToString(&field) + " of type " + structType.String()) + err = fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err) + return } } else { - setReflectValue(reflect.ValueOf(cellValue), fieldValue) + err = setReflectValue(reflect.ValueOf(cellValue), fieldValue) + + if err != nil { + err = fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err) + return + } } } } diff --git a/qrm/scan_context.go b/qrm/scan_context.go index 5f26e8d..dbc4b87 100644 --- a/qrm/scan_context.go +++ b/qrm/scan_context.go @@ -2,10 +2,7 @@ package qrm import ( "database/sql" - "database/sql/driver" "fmt" - "github.com/go-jet/jet/v2/internal/utils" - "github.com/go-jet/jet/v2/internal/utils/throw" "reflect" "strings" ) @@ -46,7 +43,7 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) { } return &scanContext{ - row: createScanValue(columnTypes), + row: createScanSlice(len(columnTypes)), uniqueDestObjectsMap: make(map[string]int), groupKeyInfoCache: make(map[string]groupKeyInfo), @@ -56,6 +53,17 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) { }, nil } +func createScanSlice(columnCount int) []interface{} { + scanSlice := make([]interface{}, columnCount) + scanPtrSlice := make([]interface{}, columnCount) + + for i := range scanPtrSlice { + scanPtrSlice[i] = &scanSlice[i] // if destination is pointer to interface sql.Scan will just forward driver value + } + + return scanPtrSlice +} + type typeInfo struct { fieldMappings []fieldMapping } @@ -210,16 +218,13 @@ func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int { } func (s *scanContext) rowElem(index int) interface{} { + cellValue := reflect.ValueOf(s.row[index]) - valuer, ok := s.row[index].(driver.Valuer) + if cellValue.IsValid() && !cellValue.IsNil() { + return cellValue.Elem().Interface() + } - utils.MustBeTrue(ok, "jet: internal error, scan value doesn't implement driver.Valuer") - - value, err := valuer.Value() - - throw.OnError(err) - - return value + return nil } func (s *scanContext) rowElemValuePtr(index int) reflect.Value { diff --git a/qrm/utill.go b/qrm/utill.go index f485797..fa1435a 100644 --- a/qrm/utill.go +++ b/qrm/utill.go @@ -7,7 +7,6 @@ import ( "github.com/go-jet/jet/v2/qrm/internal" "github.com/google/uuid" "reflect" - "strconv" "strings" "time" ) @@ -56,21 +55,22 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e sliceValue := slicePtrValue.Elem() sliceElemType := sliceValue.Type().Elem() - newElemValue := objPtrValue + newSliceElemValue := reflect.New(sliceElemType).Elem() - if sliceElemType.Kind() != reflect.Ptr { - newElemValue = objPtrValue.Elem() + var err error + + if newSliceElemValue.Kind() == reflect.Ptr { + newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem())) + err = tryAssign(objPtrValue.Elem(), newSliceElemValue.Elem()) + } else { + err = tryAssign(objPtrValue.Elem(), newSliceElemValue) } - if newElemValue.Type().ConvertibleTo(sliceElemType) { - newElemValue = newElemValue.Convert(sliceElemType) + if err != nil { + return fmt.Errorf("can't append %T to %T slice: %w", objPtrValue.Elem().Interface(), sliceValue.Interface(), err) } - if !newElemValue.Type().AssignableTo(sliceElemType) { - panic("jet: can't append " + newElemValue.Type().String() + " to " + sliceValue.Type().String() + " slice") - } - - sliceValue.Set(reflect.Append(sliceValue, newElemValue)) + sliceValue.Set(reflect.Append(sliceValue, newSliceElemValue)) return nil } @@ -121,7 +121,6 @@ func toCommonIdentifier(name string) string { } func initializeValueIfNilPtr(value reflect.Value) { - if !value.IsValid() || !value.CanSet() { return } @@ -173,172 +172,147 @@ func isSimpleModelType(objType reflect.Type) bool { return objType == timeType || objType == uuidType || objType == byteArrayType } -func isIntegerType(value reflect.Type) bool { - switch value { - case int8Type, unit8Type, int16Type, uint16Type, - int32Type, uint32Type, int64Type, uint64Type: +func isFloatType(value reflect.Type) bool { + switch value.Kind() { + case reflect.Float32, reflect.Float64: return true } return false } -func isNumber(valueType reflect.Type) bool { - return isIntegerType(valueType) || valueType == float64Type || valueType == float32Type -} +func tryAssign(source, destination reflect.Value) error { -func tryAssign(source, destination reflect.Value) bool { + if source.Type() != destination.Type() && + !isFloatType(destination.Type()) && // to preserve precision during conversion + source.Type().ConvertibleTo(destination.Type()) { - switch { - case source.Type().ConvertibleTo(destination.Type()): source = source.Convert(destination.Type()) - case isIntegerType(source.Type()) && destination.Type() == boolType: - intValue := source.Int() - - if intValue == 1 { - source = reflect.ValueOf(true) - } else if intValue == 0 { - source = reflect.ValueOf(false) - } - case source.Type() == stringType && isNumber(destination.Type()): - // if source is string and destination is a number(int8, int32, float32, ...), we first parse string to float64 number - // and then parsed number is converted into destination type - f, err := strconv.ParseFloat(source.String(), 64) - if err != nil { - return false - } - source = reflect.ValueOf(f) - - if source.Type().ConvertibleTo(destination.Type()) { - source = source.Convert(destination.Type()) - } } if source.Type().AssignableTo(destination.Type()) { - destination.Set(source) - return true + switch b := source.Interface().(type) { + case []byte: + destination.SetBytes(cloneBytes(b)) + default: + destination.Set(source) + } + return nil } - return false + sourceInterface := source.Interface() + + switch destination.Interface().(type) { + case bool: + var nullBool internal.NullBool + + err := nullBool.Scan(sourceInterface) + + if err != nil { + return err + } + + destination.SetBool(nullBool.Bool) + + case float32, float64: + var nullFloat sql.NullFloat64 + + err := nullFloat.Scan(sourceInterface) + if err != nil { + return err + } + + if nullFloat.Valid { + destination.SetFloat(nullFloat.Float64) + } + case int, int8, int16, int32, int64: + var integer sql.NullInt64 + + err := integer.Scan(sourceInterface) + if err != nil { + return err + } + + if integer.Valid { + destination.SetInt(integer.Int64) + } + + case uint, uint8, uint16, uint32, uint64: + var uInt internal.NullUInt64 + + err := uInt.Scan(sourceInterface) + + if err != nil { + return err + } + + if uInt.Valid { + destination.SetUint(uInt.UInt64) + } + + case string: + var str sql.NullString + + err := str.Scan(sourceInterface) + if err != nil { + return err + } + + if str.Valid { + destination.SetString(str.String) + } + + case time.Time: + var nullTime internal.NullTime + + err := nullTime.Scan(sourceInterface) + if err != nil { + return err + } + + if nullTime.Valid { + destination.Set(reflect.ValueOf(nullTime.Time)) + } + + default: + return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface()) + } + + return nil } -func setReflectValue(source, destination reflect.Value) { - - if tryAssign(source, destination) { - return - } +func setReflectValue(source, destination reflect.Value) error { if destination.Kind() == reflect.Ptr { - if source.Kind() == reflect.Ptr { - if !source.IsNil() { - if destination.IsNil() { - initializeValueIfNilPtr(destination) - } - - if tryAssign(source.Elem(), destination.Elem()) { - return - } - } else { - return - } - } else { - if source.CanAddr() { - source = source.Addr() - } else { - sourceCopy := reflect.New(source.Type()) - sourceCopy.Elem().Set(source) - - source = sourceCopy - } - - if tryAssign(source, destination) { - return - } - - if tryAssign(source.Elem(), destination.Elem()) { - return - } + if destination.IsNil() { + initializeValueIfNilPtr(destination) } - } else { + if source.Kind() == reflect.Ptr { if source.IsNil() { - return + return nil // source is nil, destination should keep its zero value } source = source.Elem() } - if tryAssign(source, destination) { - return + if err := tryAssign(source, destination.Elem()); err != nil { + return err + } + + } else { + if source.Kind() == reflect.Ptr { + if source.IsNil() { + return nil // source is nil, destination should keep its zero value + } + source = source.Elem() + } + + if err := tryAssign(source, destination); err != nil { + return err } } - panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String()) -} - -func createScanValue(columnTypes []*sql.ColumnType) []interface{} { - values := make([]interface{}, len(columnTypes)) - - for i, sqlColumnType := range columnTypes { - columnType := newScanType(sqlColumnType) - - columnValue := reflect.New(columnType) - - values[i] = columnValue.Interface() - } - - return values -} - -var boolType = reflect.TypeOf(true) -var int8Type = reflect.TypeOf(int8(1)) -var unit8Type = reflect.TypeOf(uint8(1)) -var int16Type = reflect.TypeOf(int16(1)) -var uint16Type = reflect.TypeOf(uint16(1)) -var int32Type = reflect.TypeOf(int32(1)) -var uint32Type = reflect.TypeOf(uint32(1)) -var int64Type = reflect.TypeOf(int64(1)) -var uint64Type = reflect.TypeOf(uint64(1)) -var float32Type = reflect.TypeOf(float32(1)) -var float64Type = reflect.TypeOf(float64(1)) -var stringType = reflect.TypeOf("") - -var nullBoolType = reflect.TypeOf(sql.NullBool{}) -var nullInt8Type = reflect.TypeOf(internal.NullInt8{}) -var nullInt16Type = reflect.TypeOf(internal.NullInt16{}) -var nullInt32Type = reflect.TypeOf(internal.NullInt32{}) -var nullInt64Type = reflect.TypeOf(sql.NullInt64{}) -var nullFloat32Type = reflect.TypeOf(internal.NullFloat32{}) -var nullFloat64Type = reflect.TypeOf(sql.NullFloat64{}) -var nullStringType = reflect.TypeOf(sql.NullString{}) -var nullTimeType = reflect.TypeOf(internal.NullTime{}) -var nullByteArrayType = reflect.TypeOf(internal.NullByteArray{}) - -func newScanType(columnType *sql.ColumnType) reflect.Type { - - switch columnType.DatabaseTypeName() { - case "TINYINT": - return nullInt8Type - case "INT2", "SMALLINT", "YEAR": - return nullInt16Type - case "INT4", "MEDIUMINT", "INT": - return nullInt32Type - case "INT8", "BIGINT": - return nullInt64Type - case "CHAR", "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "UUID", "JSON", "JSONB", "INTERVAL", "POINT", "BIT", "VARBIT", "XML": - return nullStringType - case "FLOAT4": - return nullFloat32Type - case "FLOAT8", "FLOAT", "DOUBLE": - return nullFloat64Type - case "BOOL": - return nullBoolType - case "BYTEA", "BINARY", "VARBINARY", "BLOB": - return nullByteArrayType - case "DATE", "DATETIME", "TIMESTAMP", "TIMESTAMPTZ", "TIME", "TIMETZ": - return nullTimeType - default: - return nullStringType - } + return nil } func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool { @@ -385,3 +359,12 @@ func fieldToString(field *reflect.StructField) string { return " at '" + field.Name + " " + field.Type.String() + "'" } + +func cloneBytes(b []byte) []byte { + if b == nil { + return nil + } + c := make([]byte, len(b)) + copy(c, b) + return c +} diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 29986da..ea66519 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -17,11 +17,12 @@ import ( ) func TestAllTypesSelect(t *testing.T) { - skipForPgxDriver(t) // pgx driver returns time with time zone as string - dest := []model.AllTypes{} - err := AllTypes.SELECT(AllTypes.AllColumns).Query(db, &dest) + err := AllTypes.SELECT( + AllTypes.AllColumns, + ).LIMIT(2). + Query(db, &dest) require.NoError(t, err) testutils.AssertDeepEqual(t, dest[0], allTypesRow0) @@ -29,8 +30,6 @@ func TestAllTypesSelect(t *testing.T) { } func TestAllTypesViewSelect(t *testing.T) { - skipForPgxDriver(t) // pgx driver returns time with time zone as string - type AllTypesView model.AllTypes dest := []AllTypesView{} @@ -43,7 +42,7 @@ func TestAllTypesViewSelect(t *testing.T) { } func TestAllTypesInsertModel(t *testing.T) { - skipForPgxDriver(t) // pgx driver does not handle well time with time zone + skipForPgxDriver(t) // pgx driver bug ERROR: date/time field value out of range: "0000-01-01 12:05:06Z" (SQLSTATE 22008) query := AllTypes.INSERT(AllTypes.AllColumns). MODEL(allTypesRow0). @@ -60,8 +59,6 @@ func TestAllTypesInsertModel(t *testing.T) { } func TestAllTypesInsertQuery(t *testing.T) { - skipForPgxDriver(t) // pgx driver does not handle well time with time zone - query := AllTypes.INSERT(AllTypes.AllColumns). QUERY( AllTypes. @@ -80,8 +77,6 @@ func TestAllTypesInsertQuery(t *testing.T) { } func TestAllTypesFromSubQuery(t *testing.T) { - skipForPgxDriver(t) - subQuery := SELECT(AllTypes.AllColumns). FROM(AllTypes). AsTable("allTypesSubQuery") @@ -302,10 +297,10 @@ LIMIT $11; func TestExpressionCast(t *testing.T) { - skipForPgxDriver(t) // for some reason, pgx driver, 150:char(12) returns as int value + skipForPgxDriver(t) // pgx driver bug 'cannot convert 151 to Text' query := AllTypes.SELECT( - CAST(Int(150)).AS_CHAR(12).AS("char12"), + CAST(Int(151)).AS_CHAR(12).AS("char12"), CAST(String("TRUE")).AS_BOOL(), CAST(String("111")).AS_SMALLINT(), CAST(String("111")).AS_INTEGER(), @@ -349,7 +344,7 @@ func TestExpressionCast(t *testing.T) { } func TestStringOperators(t *testing.T) { - skipForPgxDriver(t) // pgx driver returns text column as int value + skipForPgxDriver(t) // pgx driver bug 'cannot convert 11 to Text' query := AllTypes.SELECT( AllTypes.Text.EQ(AllTypes.Char), @@ -866,8 +861,6 @@ func TestInterval(t *testing.T) { } func TestSubQueryColumnReference(t *testing.T) { - skipForPgxDriver(t) // pgx driver returns time with time zone as string value - type expected struct { sql string args []interface{} @@ -1044,8 +1037,6 @@ FROM` } func TestTimeLiterals(t *testing.T) { - skipForPgxDriver(t) // pgx driver returns time with time zone as string - loc, err := time.LoadLocation("Europe/Berlin") require.NoError(t, err) @@ -1060,8 +1051,6 @@ func TestTimeLiterals(t *testing.T) { ).FROM(AllTypes). LIMIT(1) - //fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, ` SELECT $1::date AS "date", $2::time without time zone AS "time", @@ -1073,25 +1062,29 @@ LIMIT $6; `) var dest struct { - Date time.Time - Time time.Time - Timez time.Time - Timestamp time.Time - //Timestampz time.Time + Date time.Time + Time time.Time + Timez time.Time + Timestamp time.Time + Timestampz time.Time } err = query.Query(db, &dest) require.NoError(t, err) - //testutils.PrintJson(dest) + // pq driver will return time with time zone in local timezone, + // while pgx driver will return time in UTC time zone + dest.Timez = dest.Timez.UTC() + dest.Timestampz = dest.Timestampz.UTC() testutils.AssertJSON(t, dest, ` { "Date": "2009-11-17T00:00:00Z", "Time": "0000-01-01T20:34:58.651387Z", - "Timez": "0000-01-01T20:34:58.651387+01:00", - "Timestamp": "2009-11-17T20:34:58.651387Z" + "Timez": "0000-01-01T19:34:58.651387Z", + "Timestamp": "2009-11-17T20:34:58.651387Z", + "Timestampz": "2009-11-17T19:34:58.651387Z" } `) requireLogged(t, query) diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index e06c985..541747c 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -31,7 +31,7 @@ func TestMain(m *testing.M) { setTestRoot() - for _, driverName := range []string{"postgres", "pgx"} { + for _, driverName := range []string{"pgx", "postgres"} { fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName) func() { @@ -81,8 +81,16 @@ func requireLogged(t *testing.T, statement postgres.Statement) { } func skipForPgxDriver(t *testing.T) { - switch db.Driver().(type) { - case *stdlib.Driver: + if isPgxDriver() { t.SkipNow() } } + +func isPgxDriver() bool { + switch db.Driver().(type) { + case *stdlib.Driver: + return true + } + + return false +} diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index b34e17d..ce3cc46 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -78,16 +78,31 @@ func TestScanToValidDestination(t *testing.T) { require.NoError(t, err) }) - t.Run("pointer to slice of strings", func(t *testing.T) { - err := oneInventoryQuery.Query(db, &[]int32{}) + t.Run("pointer to slice of integers", func(t *testing.T) { + var dest []int32 + err := oneInventoryQuery.Query(db, &dest) require.NoError(t, err) + require.Equal(t, dest[0], int32(1)) }) - t.Run("pointer to slice of strings", func(t *testing.T) { - err := oneInventoryQuery.Query(db, &[]*int32{}) + t.Run("pointer to slice integer pointers", func(t *testing.T) { + var dest []*int32 + err := oneInventoryQuery.Query(db, &dest) require.NoError(t, err) + require.Equal(t, dest[0], testutils.Int32Ptr(1)) + }) + + t.Run("NULL to integer", func(t *testing.T) { + var dest struct { + Int64 int64 + UInt64 uint64 + } + err := SELECT(NULL.AS("int64"), NULL.AS("uint64")).Query(db, &dest) + require.NoError(t, err) + require.Equal(t, dest.Int64, int64(0)) + require.Equal(t, dest.UInt64, uint64(0)) }) } @@ -189,7 +204,9 @@ func TestScanToStruct(t *testing.T) { dest := Inventory{} - testutils.AssertQueryPanicErr(t, query, db, &dest, `jet: Scan: unable to scan type int32 into UUID, at 'InventoryID uuid.UUID' of type postgres.Inventory`) + err := query.Query(db, &dest) + require.Error(t, err) + require.EqualError(t, err, "jet: can't scan int64('\\x01') to 'InventoryID uuid.UUID': Scan: unable to scan type int64 into UUID") }) t.Run("type mismatch base type", func(t *testing.T) { @@ -200,7 +217,9 @@ func TestScanToStruct(t *testing.T) { dest := []Inventory{} - testutils.AssertQueryPanicErr(t, query.OFFSET(10), db, &dest, `jet: can't set int16 to bool`) + err := query.OFFSET(10).Query(db, &dest) + require.Error(t, err) + require.EqualError(t, err, "jet: can't assign int64('\\x02') to 'FilmID bool': can't assign int64(2) to bool") }) } @@ -451,8 +470,9 @@ func TestScanToSlice(t *testing.T) { t.Run("slice type mismatch", func(t *testing.T) { var dest []bool - testutils.AssertQueryPanicErr(t, query, db, &dest, `jet: can't append int32 to []bool slice`) - //require.Error(t, err, `jet: can't append int32 to []bool slice `) + err := query.Query(db, &dest) + require.Error(t, err) + require.EqualError(t, err, `jet: can't append int64 to []bool slice: can't assign int64(2) to bool`) }) }) @@ -764,16 +784,8 @@ func TestRowsScan(t *testing.T) { requireLogged(t, stmt) } -func TestScanNumericToNumber(t *testing.T) { +func TestScanNumericToFloat(t *testing.T) { type Number struct { - Int8 int8 - UInt8 uint8 - Int16 int16 - UInt16 uint16 - Int32 int32 - UInt32 uint32 - Int64 int64 - UInt64 uint64 Float32 float32 Float64 float64 } @@ -781,14 +793,6 @@ func TestScanNumericToNumber(t *testing.T) { numeric := CAST(Decimal("1234567890.111")).AS_NUMERIC() stmt := SELECT( - numeric.AS("number.int8"), - numeric.AS("number.uint8"), - numeric.AS("number.int16"), - numeric.AS("number.uint16"), - numeric.AS("number.int32"), - numeric.AS("number.uint32"), - numeric.AS("number.int64"), - numeric.AS("number.uint64"), numeric.AS("number.float32"), numeric.AS("number.float64"), ) @@ -796,19 +800,30 @@ func TestScanNumericToNumber(t *testing.T) { var number Number err := stmt.Query(db, &number) require.NoError(t, err) - - require.Equal(t, number.Int8, int8(-46)) // overflow - require.Equal(t, number.UInt8, uint8(210)) // overflow - require.Equal(t, number.Int16, int16(722)) // overflow - require.Equal(t, number.UInt16, uint16(722)) // overflow - require.Equal(t, number.Int32, int32(1234567890)) - require.Equal(t, number.UInt32, uint32(1234567890)) - require.Equal(t, number.Int64, int64(1234567890)) - require.Equal(t, number.UInt64, uint64(1234567890)) require.Equal(t, number.Float32, float32(1.234568e+09)) require.Equal(t, number.Float64, float64(1.234567890111e+09)) } +func TestScanNumericToIntegerError(t *testing.T) { + + var dest struct { + Integer int32 + } + + err := SELECT( + CAST(Decimal("1234567890.111")).AS_NUMERIC().AS("integer"), + ).Query(db, &dest) + + require.Error(t, err) + + if isPgxDriver() { + require.Contains(t, err.Error(), `jet: can't assign string("1234567890.111") to 'Integer int32': converting driver.Value type string ("1234567890.111") to a int64: invalid syntax`) + } else { + require.Contains(t, err.Error(), `jet: can't assign []uint8("1234567890.111") to 'Integer int32': converting driver.Value type []uint8 ("1234567890.111") to a int64: invalid syntax`) + } + +} + // QueryContext panic when the scanned value is nil and the destination is a slice of primitive // https://github.com/go-jet/jet/issues/91 func TestScanToPrimitiveElementsSlice(t *testing.T) {