diff --git a/generator/postgres-metadata/column_info.go b/generator/postgres-metadata/column_info.go index 7d571b1..2e825a8 100644 --- a/generator/postgres-metadata/column_info.go +++ b/generator/postgres-metadata/column_info.go @@ -74,7 +74,7 @@ func (c ColumnInfo) GoBaseType() string { func (c ColumnInfo) GoModelType() string { typeStr := c.GoBaseType() - if c.IsNullable && !strings.HasPrefix(typeStr, "[]") { + if c.IsNullable { return "*" + typeStr } diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index a3f0de7..47b8fc6 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -86,6 +86,8 @@ func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{}) groupTime := time.Duration(0) + slicePtrValue := reflect.ValueOf(slicePtr) + for rows.Next() { err := rows.Scan(scanContext.row...) @@ -97,7 +99,7 @@ func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{}) begin := time.Now() - _, err = mapRowToSlice(scanContext, "", reflect.ValueOf(slicePtr), nil) + _, err = mapRowToSlice(scanContext, "", slicePtrValue, nil) if err != nil { return err @@ -406,9 +408,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re updated = true } else if isGoBaseType(field.Type) { cellValue := getCellValue(scanContext, tableName, fieldName) - //spew.Dump(rowElem) - //spew.Dump(rowColumnValue, fieldValue) if cellValue != nil { updated = true initializeValueIfNil(fieldValue) @@ -441,9 +441,7 @@ func initializeValueIfNil(value reflect.Value) { return } - if value.Type().Kind() == reflect.Slice && value.IsNil() { - value.Set(reflect.New(value.Type()).Elem()) - } else if value.Kind() == reflect.Ptr && value.IsNil() { + if value.Kind() == reflect.Ptr && value.IsNil() { value.Set(reflect.New(value.Type().Elem())) } } diff --git a/sqlbuilder/execution/null_types.go b/sqlbuilder/execution/null_types.go index b47af8b..0cb57b6 100644 --- a/sqlbuilder/execution/null_types.go +++ b/sqlbuilder/execution/null_types.go @@ -8,12 +8,18 @@ import ( // NullByteArray type NullByteArray struct { ByteArray []byte - Valid bool // Valid is true if Time is not NULL + Valid bool } // Scan implements the Scanner interface. func (nb *NullByteArray) Scan(value interface{}) error { - nb.ByteArray, nb.Valid = value.([]byte) + switch v := value.(type) { + case []byte: + nb.ByteArray = append(v[:0:0], v...) + nb.Valid = true + default: + nb.Valid = false + } return nil } diff --git a/tests/all_types_test.go b/tests/all_types_test.go index ff6a467..4c91385 100644 --- a/tests/all_types_test.go +++ b/tests/all_types_test.go @@ -2,6 +2,7 @@ package tests import ( "fmt" + "github.com/davecgh/go-spew/spew" . "github.com/go-jet/jet/sqlbuilder" "github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/model" . "github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/table" @@ -28,13 +29,18 @@ func TestAllTypesSelect(t *testing.T) { func TestAllTypesInsert(t *testing.T) { query := AllTypes.INSERT(AllTypes.AllColumns...). MODEL(allTypesRow0). - MODEL(&allTypesRow1) + MODEL(&allTypesRow1). + RETURNING(AllTypes.AllColumns) - _, err := query.Execute(db) + dest := []model.AllTypes{} + err := query.Query(db, &dest) + + spew.Dump(dest[0]) assert.NilError(t, err) - - fmt.Println(query.DebugSql()) + assert.Equal(t, len(dest), 2) + assert.DeepEqual(t, dest[0], allTypesRow0) + assert.DeepEqual(t, dest[1], allTypesRow1) } func TestExpressionOperators(t *testing.T) { @@ -361,7 +367,7 @@ var allTypesRow0 = model.AllTypes{ Character: "JOHN ", TextPtr: stringPtr("Some text"), Text: "Some text", - ByteaPtr: []byte("bytea"), + ByteaPtr: byteArrayPtr([]byte("bytea")), Bytea: []byte("bytea"), TimestampzPtr: timestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), Timestampz: *timestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), diff --git a/tests/main_test.go b/tests/main_test.go index 97d1ffa..916550b 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -56,5 +56,5 @@ func TestGenerateModel(t *testing.T) { staff := model.Staff{} assert.Equal(t, reflect.TypeOf(staff.Email).String(), "*string") - assert.Equal(t, reflect.TypeOf(staff.Picture).String(), "[]uint8") + assert.Equal(t, reflect.TypeOf(staff.Picture).String(), "*[]uint8") } diff --git a/tests/test_util.go b/tests/test_util.go index 7409b7b..c94d74e 100644 --- a/tests/test_util.go +++ b/tests/test_util.go @@ -41,6 +41,10 @@ func stringPtr(s string) *string { return &s } +func byteArrayPtr(arr []byte) *[]byte { + return &arr +} + func float32Ptr(f float32) *float32 { return &f }