From bb18e4d0f7aa0011a2dfe4b0101261ea1e0d212e Mon Sep 17 00:00:00 2001 From: sub0Zero Date: Thu, 14 Mar 2019 09:18:23 +0100 Subject: [PATCH] Add support for slice model scan. --- sqlbuilder/execution/execution.go | 287 +++++++++++++++++++++++++----- tests/generator_test.go | 172 ++++++++++++++---- 2 files changed, 385 insertions(+), 74 deletions(-) diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index f3f7784..1f67978 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -3,8 +3,10 @@ package execution import ( "database/sql" "errors" + "fmt" "github.com/serenize/snaker" "reflect" + "strings" "time" ) @@ -31,31 +33,30 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error { columnNames, _ := rows.Columns() columnTypes, _ := rows.ColumnTypes() - values := createScanValue(columnTypes) - // - //spew.Dump(columnTypes) - //spew.Dump(values) + rowData := createScanValue(columnTypes) + + scanContext := &scanContext{ + columnNames: columnNames, + uniqueObjectsMap: make(map[string]interface{}), + } for rows.Next() { - err := rows.Scan(values...) + err := rows.Scan(rowData...) if err != nil { return err } + columnProcessed := make([]bool, len(columnTypes)) + if destinationType.Elem().Kind() == reflect.Slice { - - destinationStructPtr := newElemForSlice(destinationPtr) - - err = mapValuesToStruct(columnNames, values, destinationStructPtr) + err := mapRowToSlice(scanContext, "", columnProcessed, rowData, destinationPtr) if err != nil { return err } - - appendElemToSlice(destinationPtr, destinationStructPtr) } else if destinationType.Elem().Kind() == reflect.Struct { - return mapValuesToStruct(columnNames, values, destinationPtr) + return mapRowToStruct(scanContext, "", columnProcessed, rowData, destinationPtr) } } @@ -68,20 +69,213 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error { return nil } -func appendElemToSlice(slice interface{}, obj interface{}) { - //spew.Dump(slice) - sliceValue := reflect.ValueOf(slice).Elem() +type scanContext struct { + columnNames []string + uniqueObjectsMap map[string]interface{} +} - sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(obj).Elem())) +func getColumnTypeName(columnName string) (string, error) { + split := strings.Split(columnName, ".") + if len(split) != 2 { + return "", errors.New("Invalid column name") + } + + return split[0], nil +} + +func allProcessed(arr []bool) bool { + for _, b := range arr { + if !b { + return false + } + } + + return true +} + +func getGroupKey(scanContext *scanContext, row []interface{}, structType reflect.Type) string { + structName := structType.Name() + groupKey := "" + + for i := 0; i < structType.NumField(); i++ { + fieldType := structType.Field(i) + + ////fmt.Println(fieldType.Tag) + + if fieldType.Tag == `sql:"unique"` { + fieldName := fieldType.Name + columnName := snaker.CamelToSnake(structName) + "." + snaker.CamelToSnake(fieldName) + + //fmt.Println(fieldName) + rowIndex := getIndex(scanContext.columnNames, columnName) + + if rowIndex < 0 { + continue + } + + rowValue := reflect.ValueOf(row[rowIndex]) + + groupKey = groupKey + reflectValueToString(rowValue) + } else if !isDbBaseType(fieldType.Type) { + var structType reflect.Type + if fieldType.Type.Kind() == reflect.Struct { + structType = fieldType.Type + } else if fieldType.Type.Kind() == reflect.Ptr && fieldType.Type.Elem().Kind() == reflect.Struct { + structType = fieldType.Type.Elem() + } else { + continue + } + + //spew.Dump(structType) + + structGroupKey := getGroupKey(scanContext, row, structType) + + //groupKey = strings.Join([]string{structGroupKey, groupKey}, ":") + + groupKey = groupKey + structGroupKey + } + } + + //fmt.Println(groupKey) + return groupKey +} + +func getSliceStructType(slicePtr interface{}) reflect.Type { + sliceTypePtr := reflect.TypeOf(slicePtr) + + elemType := sliceTypePtr.Elem().Elem() + + if elemType.Kind() == reflect.Ptr { + return elemType.Elem() + } + + return elemType +} + +func mapRowToSlice(scanContext *scanContext, groupKey string, columnProcessed []bool, row []interface{}, destinationPtr interface{}) error { + if allProcessed(columnProcessed) { + return nil + } + + var err error + + structType := getSliceStructType(destinationPtr) + + groupKey = groupKey + ":" + getGroupKey(scanContext, row, structType) + + objPtr, ok := scanContext.uniqueObjectsMap[groupKey] + + if ok { + err = mapRowToStruct(scanContext, groupKey, columnProcessed, row, objPtr) + if err != nil { + return err + } + } else { + destinationStructPtr := newElemForSlice(destinationPtr) + + err = mapRowToStruct(scanContext, groupKey, columnProcessed, row, destinationStructPtr) + + if err != nil { + return err + } + + elemPtr := appendElemToSlice(destinationPtr, destinationStructPtr) + scanContext.uniqueObjectsMap[groupKey] = elemPtr + } + + return err +} + +func appendElemToSlice(slice interface{}, objPtr interface{}) interface{} { + sliceValue := reflect.ValueOf(slice).Elem() + elemType := sliceValue.Type().Elem() + + if elemType.Kind() == reflect.Ptr { + sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(objPtr))) + return sliceValue.Index(sliceValue.Len() - 1).Interface() + } + + sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(objPtr).Elem())) + + return sliceValue.Index(sliceValue.Len() - 1).Addr().Interface() } func newElemForSlice(destinationSlicePtr interface{}) interface{} { destinationSliceType := reflect.TypeOf(destinationSlicePtr).Elem() + elemType := destinationSliceType.Elem() - return reflect.New(destinationSliceType.Elem()).Interface() + if elemType.Kind() == reflect.Ptr { + return reflect.New(elemType.Elem()).Interface() + } + + return reflect.New(elemType).Interface() } -func mapValuesToStruct(columnNames []string, row []interface{}, destination interface{}) error { +func mapRowToDestinationValue(scanContext *scanContext, groupKey string, columnProcessed []bool, row []interface{}, dest reflect.Value) error { + if dest.Kind() == reflect.Struct { + err := mapRowToStruct(scanContext, groupKey, columnProcessed, row, dest.Addr().Interface()) + if err != nil { + return err + } + } else if dest.Kind() == reflect.Slice { + err := mapRowToSlice(scanContext, groupKey, columnProcessed, row, dest.Addr().Interface()) + if err != nil { + return err + } + } else if dest.Kind() == reflect.Ptr { + elemType := dest.Type().Elem() + + if elemType.Kind() == reflect.Struct { + var structValuePtr reflect.Value + + if dest.IsNil() { + structValuePtr = reflect.New(elemType) + } else { + return nil + } + + err := mapRowToStruct(scanContext, groupKey, columnProcessed, row, structValuePtr.Interface()) + if err != nil { + return err + } + + if structValuePtr.Elem().Interface() != reflect.New(elemType).Elem().Interface() { + dest.Set(structValuePtr) + } + + } else if elemType.Kind() == reflect.Slice { + var sliceValuePtr reflect.Value + + if dest.IsNil() { + sliceValuePtr = reflect.New(elemType) + } else { + sliceValuePtr = dest + } + + err := mapRowToSlice(scanContext, groupKey, columnProcessed, row, sliceValuePtr.Interface()) + if err != nil { + return err + } + + if sliceValuePtr.Elem().Len() > 0 { + dest.Set(sliceValuePtr) + } + + } else { + return errors.New("Unsuported field type: " + dest.Type().Name()) + } + } else { + return errors.New("Unsuported field type: " + dest.Type().Name()) + } + + return nil +} + +func mapRowToStruct(scanContext *scanContext, groupKey string, columnProcessed []bool, row []interface{}, destination interface{}) error { + if allProcessed(columnProcessed) { + return nil + } + structType := reflect.TypeOf(destination).Elem() structValue := reflect.ValueOf(destination).Elem() structName := structType.Name() @@ -91,50 +285,56 @@ func mapValuesToStruct(columnNames []string, row []interface{}, destination inte //fieldTypeName := fieldType.Name fieldValue := structValue.Field(i) //fmt.Println("---------------", fieldTypeName) - //spew.Dump(fieldType.Type) + ////spew.Dump(fieldType.Type) + + fieldName := fieldType.Name if !isDbBaseType(fieldType.Type) { - if fieldType.Type.Kind() == reflect.Struct { - err := mapValuesToStruct(columnNames, row, fieldValue.Addr().Interface()) - if err != nil { - return err - } - } else if fieldType.Type.Kind() == reflect.Ptr { - newStructValue := reflect.New(fieldType.Type.Elem()) - err := mapValuesToStruct(columnNames, row, newStructValue.Interface()) - if err != nil { - return err - } + //var fieldValueInterface interface{} + err := mapRowToDestinationValue(scanContext, groupKey, columnProcessed, row, fieldValue) - if newStructValue.Elem().Interface() != reflect.New(fieldType.Type.Elem()).Elem().Interface() { - fieldValue.Set(newStructValue) - } + if err != nil { + return err } } else { - fieldName := fieldType.Name - columnName := snaker.CamelToSnake(structName) + "." + snaker.CamelToSnake(fieldName) //columnName := snaker.CamelToSnake(fieldName) - //fmt.Println(columnName) - rowIndex := getIndex(columnNames, columnName) + ////fmt.Println(columnName) + rowIndex := getIndex(scanContext.columnNames, columnName) - if rowIndex < 0 { + if rowIndex < 0 || columnProcessed[rowIndex] { continue } - - //spew.Dump(row[rowIndex]) + ////spew.Dump(row[rowIndex]) rowColumnValue := reflect.ValueOf(row[rowIndex]) //spew.Dump(rowColumnValue, fieldValue) setReflectValue(rowColumnValue, fieldValue) + + columnProcessed[rowIndex] = true } } return nil } +func reflectValueToString(value reflect.Value) string { + var valueInterface interface{} + if value.Kind() == reflect.Ptr { + valueInterface = value.Elem().Interface() + } else { + valueInterface = value.Interface() + } + + if t, ok := valueInterface.(time.Time); ok { + return t.String() + } + + return fmt.Sprintf("%#v", valueInterface) +} + var timeType = reflect.TypeOf(time.Now()) var floatType = reflect.TypeOf(1.0) var stringType = reflect.TypeOf("str") @@ -147,9 +347,8 @@ func isDbBaseType(objType reflect.Type) bool { typeStr := objType.String() switch typeStr { - case "string", "int32", "int16", "float64", "time.Time": - return true - case "*string", "*int32", "*int16", "*float64", "*time.Time": + case "string", "int32", "int16", "float64", "time.Time", "bool", + "*string", "*int32", "*int16", "*float64", "*time.Time", "*bool": return true } @@ -199,7 +398,7 @@ func createScanValue(columnTypes []*sql.ColumnType) []interface{} { func getScanType(columnType *sql.ColumnType) reflect.Type { scanType := columnType.ScanType() - //fmt.Println(scanType.String()) + //////fmt.Println(scanType.String()) if scanType.String() != "interface {}" { return scanType } diff --git a/tests/generator_test.go b/tests/generator_test.go index 177ba22..8ec4646 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -10,6 +10,7 @@ import ( "gotest.tools/assert" "os" "testing" + "time" ) const ( @@ -55,18 +56,23 @@ func TestGenerateModel(t *testing.T) { //assert.NilError(t, err) } -func TestSelectQuery(t *testing.T) { - //query := Actor.InnerJoinOn(Store, Eq(Actor.ActorID, Store.StoreID)). - // Select(Store.StoreID, Store.AddressID, Actor.ActorID) - // - //queryStr, err := query.String(schemaName) - // - //assert.NilError(t, err) - // - //assert.Equal(t, queryStr, "SELECT store.store_id,store.address_id,actor.actor_id FROM dvds.actor JOIN dvds.store ON actor.actor_id=store.store_id") - // - //err = query.Execute(db, nil) +func TestSelect_ScanToStruct(t *testing.T) { + actor := model.Actor{} + err := Actor.Select(Actor.All...).Execute(db, &actor) + assert.NilError(t, err) + + expectedActor := model.Actor{ + ActorID: 1, + FirstName: "Penelope", + LastName: "Guiness", + LastUpdate: *timeWithoutTimeZone("2013-05-26 14:47:57.62 +0000"), + } + + assert.DeepEqual(t, actor, expectedActor) +} + +func TestSelect_ScanToSlice(t *testing.T) { customers := []model.Customer{} query := Customer.Select(Customer.All...) @@ -75,51 +81,157 @@ func TestSelectQuery(t *testing.T) { assert.NilError(t, err) assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id",customer.store_id AS "customer.store_id",customer.first_name AS "customer.first_name",customer.last_name AS "customer.last_name",customer.email AS "customer.email",customer.address_id AS "customer.address_id",customer.activebool AS "customer.activebool",customer.create_date AS "customer.create_date",customer.last_update AS "customer.last_update",customer.active AS "customer.active" FROM dvds.customer`) - //fmt.Println(queryStr) err = query.Execute(db, &customers) - //fmt.Println(customers) - // - //spew.Sdump(customers) - assert.NilError(t, err) + customer0 := model.Customer{ + CustomerID: 524, + StoreID: 1, + FirstName: "Jared", + LastName: "Ely", + Email: stringPtr("austin.cintron@sakilacustomer.org"), + Address: nil, + Activebool: true, + CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00 +0000"), + LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738 +0000"), + Active: int32Ptr(1), + } + + customer1 := model.Customer{ + CustomerID: 1, + StoreID: 1, + FirstName: "Mary", + LastName: "Smith", + Email: stringPtr("austin.cintron@sakilacustomer.org"), + Address: nil, + Activebool: true, + CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00 +0000"), + LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738 +0000"), + Active: int32Ptr(1), + } + + lastCustomer := model.Customer{ + CustomerID: 599, + StoreID: 2, + FirstName: "Austin", + LastName: "Cintron", + Email: stringPtr("austin.cintron@sakilacustomer.org"), + Address: nil, + Activebool: true, + CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00 +0000"), + LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738 +0000"), + Active: int32Ptr(1), + } + assert.Equal(t, len(customers), 599) - actor := model.Actor{} - err = Actor.Select(Actor.All...).Execute(db, &actor) - - assert.NilError(t, err) - - //spew.Dump(actor) - //time, _ := time.Parse("2006-01-02 15:04:05.00MST", "2013-05-26 14:47:57.62MST") - assert.Equal(t, actor.ActorID, int32(1)) - assert.Equal(t, actor.FirstName, "Penelope") - assert.Equal(t, actor.LastName, "Guiness") + assert.DeepEqual(t, customer0, customers[0]) + assert.DeepEqual(t, customer1, customers[1]) + assert.DeepEqual(t, lastCustomer, customers[598]) } -func TestJoinQuery(t *testing.T) { +func TestJoinQueryStruct(t *testing.T) { //filmActor := model.FilmActor{} - allFilmActorColumns := append(append(Actor.All, Film.All...), Language.All...) + allFilmActorColumns := append(append(append(FilmActor.All, Film.All...), Language.All...), Actor.All...) query := FilmActor. InnerJoinOn(Actor, sqlbuilder.Eq(FilmActor.ActorID, Actor.ActorID)). InnerJoinOn(Film, sqlbuilder.Eq(FilmActor.FilmID, Film.FilmID)). InnerJoinOn(Language, sqlbuilder.Eq(Film.LanguageID, Language.LanguageID)). Select(allFilmActorColumns...). - Where(sqlbuilder.Eq(FilmActor.ActorID, sqlbuilder.Literal(1))) + Where(sqlbuilder.And(sqlbuilder.Gte(FilmActor.ActorID, sqlbuilder.Literal(1)), sqlbuilder.Lte(FilmActor.ActorID, sqlbuilder.Literal(2)))) queryStr, err := query.String() assert.NilError(t, err) fmt.Println(queryStr) - filmActor := model.FilmActor{} + filmActor := []model.FilmActor{} err = query.Execute(db, &filmActor) assert.NilError(t, err) + //fmt.Println("ACTORS: --------------------") //spew.Dump(filmActor) } + +func TestJoinQuerySlice(t *testing.T) { + type FilmsPerLanguage struct { + Language *model.Language + Films *[]model.Film + } + + filmsPerLanguage := []FilmsPerLanguage{} + limit := 15 + + query := Film.InnerJoinOn(Language, sqlbuilder.Eq(Film.LanguageID, Language.LanguageID)). + Select(append(Language.All, Film.All...)...). + Limit(15) + + queryStr, _ := query.String() + + fmt.Println(queryStr) + + err := query.Execute(db, &filmsPerLanguage) + + assert.NilError(t, err) + + //fmt.Println("--------------- result --------------- ") + //spew.Dump(filmsPerLanguage) + + assert.Equal(t, len(filmsPerLanguage), 1) + assert.Equal(t, len(*filmsPerLanguage[0].Films), limit) + + //spew.Dump(filmsPerLanguage) + + filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} + err = query.Execute(db, &filmsPerLanguageWithPtrs) + + assert.NilError(t, err) + assert.Equal(t, len(filmsPerLanguage), 1) + assert.Equal(t, len(*filmsPerLanguage[0].Films), limit) + +} + +func TestJoinQuerySliceWithPtrs(t *testing.T) { + type FilmsPerLanguage struct { + Language model.Language + Films *[]*model.Film + } + + limit := int64(3) + + query := Film.InnerJoinOn(Language, sqlbuilder.Eq(Film.LanguageID, Language.LanguageID)). + Select(append(Language.All, Film.All...)...). + Limit(limit) + + filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} + err := query.Execute(db, &filmsPerLanguageWithPtrs) + + //spew.Dump(filmsPerLanguageWithPtrs) + + assert.NilError(t, err) + assert.Equal(t, len(filmsPerLanguageWithPtrs), 1) + assert.Equal(t, len(*filmsPerLanguageWithPtrs[0].Films), int(limit)) +} + +func int32Ptr(i int32) *int32 { + return &i +} + +func stringPtr(s string) *string { + return &s +} + +func timeWithoutTimeZone(t string) *time.Time { + time, err := time.Parse("2006-01-02 15:04:05 -0700", t) + + if err != nil { + panic(err) + } + + return &time +}