Bug fix: Bytea nil values aren't stored as null in database.

This commit is contained in:
go-jet 2019-06-12 16:21:50 +02:00
parent c598978ba6
commit 038a4b9dd0
6 changed files with 29 additions and 15 deletions

View file

@ -74,7 +74,7 @@ func (c ColumnInfo) GoBaseType() string {
func (c ColumnInfo) GoModelType() string { func (c ColumnInfo) GoModelType() string {
typeStr := c.GoBaseType() typeStr := c.GoBaseType()
if c.IsNullable && !strings.HasPrefix(typeStr, "[]") { if c.IsNullable {
return "*" + typeStr return "*" + typeStr
} }

View file

@ -86,6 +86,8 @@ func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{})
groupTime := time.Duration(0) groupTime := time.Duration(0)
slicePtrValue := reflect.ValueOf(slicePtr)
for rows.Next() { for rows.Next() {
err := rows.Scan(scanContext.row...) err := rows.Scan(scanContext.row...)
@ -97,7 +99,7 @@ func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{})
begin := time.Now() begin := time.Now()
_, err = mapRowToSlice(scanContext, "", reflect.ValueOf(slicePtr), nil) _, err = mapRowToSlice(scanContext, "", slicePtrValue, nil)
if err != nil { if err != nil {
return err return err
@ -406,9 +408,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
updated = true updated = true
} else if isGoBaseType(field.Type) { } else if isGoBaseType(field.Type) {
cellValue := getCellValue(scanContext, tableName, fieldName) cellValue := getCellValue(scanContext, tableName, fieldName)
//spew.Dump(rowElem)
//spew.Dump(rowColumnValue, fieldValue)
if cellValue != nil { if cellValue != nil {
updated = true updated = true
initializeValueIfNil(fieldValue) initializeValueIfNil(fieldValue)
@ -441,9 +441,7 @@ func initializeValueIfNil(value reflect.Value) {
return return
} }
if value.Type().Kind() == reflect.Slice && value.IsNil() { if value.Kind() == reflect.Ptr && value.IsNil() {
value.Set(reflect.New(value.Type()).Elem())
} else if value.Kind() == reflect.Ptr && value.IsNil() {
value.Set(reflect.New(value.Type().Elem())) value.Set(reflect.New(value.Type().Elem()))
} }
} }

View file

@ -8,12 +8,18 @@ import (
// NullByteArray // NullByteArray
type NullByteArray struct { type NullByteArray struct {
ByteArray []byte ByteArray []byte
Valid bool // Valid is true if Time is not NULL Valid bool
} }
// Scan implements the Scanner interface. // Scan implements the Scanner interface.
func (nb *NullByteArray) Scan(value interface{}) error { func (nb *NullByteArray) Scan(value interface{}) error {
nb.ByteArray, nb.Valid = value.([]byte) switch v := value.(type) {
case []byte:
nb.ByteArray = append(v[:0:0], v...)
nb.Valid = true
default:
nb.Valid = false
}
return nil return nil
} }

View file

@ -2,6 +2,7 @@ package tests
import ( import (
"fmt" "fmt"
"github.com/davecgh/go-spew/spew"
. "github.com/go-jet/jet/sqlbuilder" . "github.com/go-jet/jet/sqlbuilder"
"github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/model" "github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/model"
. "github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/table" . "github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/table"
@ -28,13 +29,18 @@ func TestAllTypesSelect(t *testing.T) {
func TestAllTypesInsert(t *testing.T) { func TestAllTypesInsert(t *testing.T) {
query := AllTypes.INSERT(AllTypes.AllColumns...). query := AllTypes.INSERT(AllTypes.AllColumns...).
MODEL(allTypesRow0). MODEL(allTypesRow0).
MODEL(&allTypesRow1) MODEL(&allTypesRow1).
RETURNING(AllTypes.AllColumns)
_, err := query.Execute(db) dest := []model.AllTypes{}
err := query.Query(db, &dest)
spew.Dump(dest[0])
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(dest), 2)
fmt.Println(query.DebugSql()) assert.DeepEqual(t, dest[0], allTypesRow0)
assert.DeepEqual(t, dest[1], allTypesRow1)
} }
func TestExpressionOperators(t *testing.T) { func TestExpressionOperators(t *testing.T) {
@ -361,7 +367,7 @@ var allTypesRow0 = model.AllTypes{
Character: "JOHN ", Character: "JOHN ",
TextPtr: stringPtr("Some text"), TextPtr: stringPtr("Some text"),
Text: "Some text", Text: "Some text",
ByteaPtr: []byte("bytea"), ByteaPtr: byteArrayPtr([]byte("bytea")),
Bytea: []byte("bytea"), Bytea: []byte("bytea"),
TimestampzPtr: timestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), TimestampzPtr: timestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0),
Timestampz: *timestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), Timestampz: *timestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0),

View file

@ -56,5 +56,5 @@ func TestGenerateModel(t *testing.T) {
staff := model.Staff{} staff := model.Staff{}
assert.Equal(t, reflect.TypeOf(staff.Email).String(), "*string") assert.Equal(t, reflect.TypeOf(staff.Email).String(), "*string")
assert.Equal(t, reflect.TypeOf(staff.Picture).String(), "[]uint8") assert.Equal(t, reflect.TypeOf(staff.Picture).String(), "*[]uint8")
} }

View file

@ -41,6 +41,10 @@ func stringPtr(s string) *string {
return &s return &s
} }
func byteArrayPtr(arr []byte) *[]byte {
return &arr
}
func float32Ptr(f float32) *float32 { func float32Ptr(f float32) *float32 {
return &f return &f
} }