diff --git a/examples/quick-start/quick-start.go b/examples/quick-start/quick-start.go index daaf209..090141f 100644 --- a/examples/quick-start/quick-start.go +++ b/examples/quick-start/quick-start.go @@ -4,11 +4,12 @@ import ( "database/sql" "encoding/json" "fmt" + _ "github.com/lib/pq" - . "github.com/go-jet/jet" // dot import so go code would resemble as much as native SQL - . "github.com/go-jet/jet/examples/quick-start/gen/jetdb/dvds/table" // dot import is not mandatory + . "github.com/go-jet/jet" // dot import so go code would resemble as much as native SQL + . "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/table" // dot import is not mandatory - "github.com/go-jet/jet/examples/quick-start/gen/jetdb/dvds/model" + "github.com/go-jet/jet/examples/quick-start/.gen/jetdb/dvds/model" "github.com/go-jet/jet/tests/dbconfig" ) diff --git a/execution/execution.go b/execution/execution.go index 05b4325..9d79444 100644 --- a/execution/execution.go +++ b/execution/execution.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "github.com/go-jet/jet/execution/internal" - "github.com/go-jet/jet/internal/utils" "reflect" "strconv" "strings" @@ -132,48 +131,33 @@ func queryToSlice(db DB, ctx context.Context, query string, args []interface{}, return nil } -func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { +func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) { sliceElemType := getSliceElemType(slicePtrValue) if isGoBaseType(sliceElemType) { - index := 0 - if structField != nil { - if index = scanContext.aliasColumnIndex(structField.Tag.Get("alias")); index < 0 { - return - } - } - rowElemPtr := scanContext.rowElemValuePtr(index) - - if !rowElemPtr.IsNil() { - updated = true - err = appendElemToSlice(slicePtrValue, rowElemPtr) - if err != nil { - return - } - } - + updated, err = mapRowToBaseTypeSlice(scanContext, slicePtrValue, field) return } if sliceElemType.Kind() != reflect.Struct { - return false, errors.New("jet: Unsupported dest type: " + structField.Name + " " + structField.Type.String()) + return false, errors.New("jet: Unsupported dest type: " + field.Name + " " + field.Type.String()) } - structGroupKey := scanContext.getGroupKey(sliceElemType, structField) + structGroupKey := scanContext.getGroupKey(sliceElemType, field) - groupKey = groupKey + ":" + structGroupKey + groupKey = groupKey + "," + structGroupKey index, ok := scanContext.uniqueDestObjectsMap[groupKey] if ok { structPtrValue := getSliceElemPtrAt(slicePtrValue, index) - return mapRowToStruct(scanContext, groupKey, structPtrValue, structField, true) + return mapRowToStruct(scanContext, groupKey, structPtrValue, field, true) } else { destinationStructPtr := newElemPtrValueForSlice(slicePtrValue) - updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, structField) + updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field) if err != nil { return @@ -192,6 +176,213 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl return } +func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) { + index := 0 + if field != nil { + typeName, columnName := getTypeAndFieldName("", *field) + if index = scanContext.typeToColumnIndex(typeName, columnName); index < 0 { + return + } + } + rowElemPtr := scanContext.rowElemValuePtr(index) + + if !rowElemPtr.IsNil() { + updated = true + err = appendElemToSlice(slicePtrValue, rowElemPtr) + if err != nil { + return + } + } + + return +} + +type typeInfo struct { + fieldMappings []fieldMapping +} + +type fieldMapping struct { + complexType bool + columnIndex int + implementsScanner bool +} + +func (s *scanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo { + + typeMapKey := structType.String() + + if parentField != nil { + typeMapKey += string(parentField.Tag) + } + + if typeInfo, ok := s.typeInfoMap[typeMapKey]; ok { + return typeInfo + } + + typeName := getTypeName(structType, parentField) + + newTypeInfo := typeInfo{} + + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + + newTypeName, fieldName := getTypeAndFieldName(typeName, field) + columnIndex := s.typeToColumnIndex(newTypeName, fieldName) + + fieldMap := fieldMapping{ + columnIndex: columnIndex, + } + + if implementsScannerType(field.Type) { + fieldMap.implementsScanner = true + } else if !isGoBaseType(field.Type) { + fieldMap.complexType = true + } + + newTypeInfo.fieldMappings = append(newTypeInfo.fieldMappings, fieldMap) + } + + s.typeInfoMap[typeMapKey] = newTypeInfo + + return newTypeInfo +} + +func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) { + structType := structPtrValue.Type().Elem() + + typeInf := scanContext.getTypeInfo(structType, parentField) + + structValue := structPtrValue.Elem() + + for i := 0; i < structValue.NumField(); i++ { + field := structType.Field(i) + fieldValue := structValue.Field(i) + + fieldMap := typeInf.fieldMappings[i] + + if fieldMap.complexType { + var changed bool + changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field) + + if err != nil { + return + } + + if changed { + updated = true + } + + } else if len(onlySlices) == 0 { + + if fieldMap.columnIndex == -1 { + continue + } + + if fieldMap.implementsScanner { + + cellValue := scanContext.rowElem(fieldMap.columnIndex) + + if cellValue == nil { + continue + } + + initializeValueIfNilPtr(fieldValue) + + scanner := getScanner(fieldValue) + + err = scanner.Scan(cellValue) + + if err != nil { + err = fmt.Errorf("%s, at struct field: %s %s of type %s. ", err.Error(), field.Name, field.Type.String(), structType.String()) + return + } + updated = true + } else { + cellValue := scanContext.rowElem(fieldMap.columnIndex) + + if cellValue != nil { + updated = true + initializeValueIfNilPtr(fieldValue) + err = setReflectValue(reflect.ValueOf(cellValue), fieldValue) + + if err != nil { + err = fmt.Errorf("%s, at struct field: %s %s of type %s. ", err.Error(), field.Name, field.Type.String(), structType.String()) + return + } + } + } + } + } + + return +} + +func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { + + if destPtrValue.Kind() != reflect.Ptr { + return false, errors.New("jet: Internal error. ") + } + + destValueKind := destPtrValue.Elem().Kind() + + if destValueKind == reflect.Struct { + return mapRowToStruct(scanContext, groupKey, destPtrValue, structField) + } else if destValueKind == reflect.Slice { + return mapRowToSlice(scanContext, groupKey, destPtrValue, structField) + } else { + return false, errors.New("jet: Unsupported dest type: " + structField.Name + " " + structField.Type.String()) + } +} + +func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest reflect.Value, structField *reflect.StructField) (updated bool, err error) { + + var destPtrValue reflect.Value + + if dest.Kind() != reflect.Ptr { + destPtrValue = dest.Addr() + } else if dest.Kind() == reflect.Ptr { + if dest.IsNil() { + destPtrValue = reflect.New(dest.Type().Elem()) + } else { + destPtrValue = dest + } + } else { + return false, errors.New("jet: Internal error. ") + } + + updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField) + + if err != nil { + return + } + + if dest.Kind() == reflect.Ptr && dest.IsNil() && updated { + dest.Set(destPtrValue) + } + + return +} + +var scannerInterfaceType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +func implementsScannerType(fieldType reflect.Type) bool { + if fieldType.Implements(scannerInterfaceType) { + return true + } + + typePtr := reflect.New(fieldType).Type() + + return typePtr.Implements(scannerInterfaceType) +} + +func getScanner(value reflect.Value) sql.Scanner { + if scanner, ok := value.Interface().(sql.Scanner); ok { + return scanner + } + + return value.Addr().Interface().(sql.Scanner) +} + func getSliceElemType(slicePtrValue reflect.Value) reflect.Type { sliceTypePtr := slicePtrValue.Type() @@ -244,120 +435,6 @@ func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value { return reflect.New(elemType) } -func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { - - if destPtrValue.Kind() != reflect.Ptr { - return false, errors.New("jet: Internal error. ") - } - - destValueKind := destPtrValue.Elem().Kind() - - if destValueKind == reflect.Struct { - return mapRowToStruct(scanContext, groupKey, destPtrValue, structField) - } else if destValueKind == reflect.Slice { - return mapRowToSlice(scanContext, groupKey, destPtrValue, structField) - } else { - return false, errors.New("jet: Unsupported dest type: " + structField.Name + " " + structField.Type.String()) - } -} - -func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest reflect.Value, structField *reflect.StructField) (updated bool, err error) { - - var destPtrValue reflect.Value - - if dest.Kind() != reflect.Ptr { - destPtrValue = dest.Addr() - } else if dest.Kind() == reflect.Ptr { - if dest.IsNil() { - destPtrValue = reflect.New(dest.Type().Elem()) - } else { - destPtrValue = dest - } - } else { - return false, errors.New("jet: Internal error. ") - } - - updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField) - - if err != nil { - return - } - - if dest.Kind() == reflect.Ptr && dest.IsNil() && updated { - dest.Set(destPtrValue) - } - - return -} - -func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) { - structType := structPtrValue.Type().Elem() - structValue := structPtrValue.Elem() - - typeName := getTypeName(structType, parentField) - - for i := 0; i < structType.NumField(); i++ { - field := structType.Field(i) - - fieldValue := structValue.Field(i) - fieldName := field.Name - - if scannerValue, ok := implementsScanner(fieldValue); ok { - if len(onlySlices) > 0 { - continue - } - - cellValue := scanContext.getCellValue(typeName, fieldName) - - if cellValue == nil { - continue - } - - initializeValueIfNilPtr(fieldValue) - - scanner := scannerValue.Interface().(sql.Scanner) - - err = scanner.Scan(cellValue) - - if err != nil { - err = fmt.Errorf("%s, at struct field: %s %s of type %s. ", err.Error(), field.Name, field.Type.String(), structType.String()) - return - } - updated = true - } else if isGoBaseType(field.Type) { - if len(onlySlices) > 0 { - continue - } - - cellValue := scanContext.getCellValue(typeName, fieldName) - - if cellValue != nil { - updated = true - initializeValueIfNilPtr(fieldValue) - err = setReflectValue(reflect.ValueOf(cellValue), fieldValue) - - if err != nil { - err = fmt.Errorf("%s, at struct field: %s %s of type %s. ", err.Error(), field.Name, field.Type.String(), structType.String()) - return - } - } - } else { - var changed bool - changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field) - - if err != nil { - return - } - - if changed { - updated = true - } - } - } - - return -} - func getTypeName(structType reflect.Type, parentField *reflect.StructField) string { if parentField == nil { return structType.Name() @@ -371,7 +448,29 @@ func getTypeName(structType reflect.Type, parentField *reflect.StructField) stri aliasParts := strings.Split(aliasTag, ".") - return aliasParts[0] + return toCommonIdentifier(aliasParts[0]) +} + +func getTypeAndFieldName(structType string, field reflect.StructField) (string, string) { + aliasTag := field.Tag.Get("alias") + + if aliasTag == "" { + return structType, field.Name + } + + aliasParts := strings.Split(aliasTag, ".") + + if len(aliasParts) == 1 { + return structType, toCommonIdentifier(aliasParts[0]) + } + + return toCommonIdentifier(aliasParts[0]), toCommonIdentifier(aliasParts[1]) +} + +var replacer = strings.NewReplacer(" ", "", "-", "", "_", "") + +func toCommonIdentifier(name string) string { + return strings.ToLower(replacer.Replace(name)) } func initializeValueIfNilPtr(value reflect.Value) { @@ -384,18 +483,6 @@ func initializeValueIfNilPtr(value reflect.Value) { } } -func implementsScanner(value reflect.Value) (reflect.Value, bool) { - if _, ok := value.Interface().(sql.Scanner); ok { - return value, true - } else if value.CanAddr() { - if _, ok := value.Addr().Interface().(sql.Scanner); ok { - return value.Addr(), true - } - } - - return value, false -} - func valueToString(value reflect.Value) string { if !value.IsValid() { @@ -520,10 +607,11 @@ type scanContext struct { row []interface{} uniqueDestObjectsMap map[string]int - aliasIndexMap map[string]int - goNameMap map[string]int + typeToColumnIndexMap map[string]int groupKeyInfoCache map[string]groupKeyInfo + + typeInfoMap map[string]typeInfo } func newScanContext(rows *sql.Rows) (*scanContext, error) { @@ -539,36 +627,37 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) { return nil, err } - aliasIndexMap := map[string]int{} - - for i, columnName := range aliases { - aliasIndexMap[strings.ToLower(columnName)] = i - } - - goNamesMap := map[string]int{} + typeToIndexMap := map[string]int{} for i, alias := range aliases { names := strings.SplitN(alias, ".", 2) - goName := utils.ToGoIdentifier(names[0]) + goName := toCommonIdentifier(names[0]) if len(names) > 1 { - goName += "." + utils.ToGoIdentifier(names[1]) + goName += "." + toCommonIdentifier(names[1]) } - goNamesMap[strings.ToLower(goName)] = i + typeToIndexMap[strings.ToLower(goName)] = i } return &scanContext{ row: createScanValue(columnTypes), uniqueDestObjectsMap: make(map[string]int), - groupKeyInfoCache: make(map[string]groupKeyInfo), - aliasIndexMap: aliasIndexMap, - goNameMap: goNamesMap, + groupKeyInfoCache: make(map[string]groupKeyInfo), + typeToColumnIndexMap: typeToIndexMap, + + typeInfoMap: make(map[string]typeInfo), }, nil } +type groupKeyInfo struct { + typeName string + indexes []int + subTypes []groupKeyInfo +} + func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect.StructField) string { mapKey := structType.Name() @@ -607,7 +696,7 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType)) } - return "{" + groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")}" + return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")" } func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo { @@ -617,6 +706,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *refl for i := 0; i < structType.NumField(); i++ { field := structType.Field(i) + newTypeName, fieldName := getTypeAndFieldName(typeName, field) if !isGoBaseType(field.Type) { var structType reflect.Type @@ -634,7 +724,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *refl ret.subTypes = append(ret.subTypes, subType) } } else if isPrimaryKey(field) { - index := s.typeColumnIndex(typeName, field.Name) + index := s.typeToColumnIndex(newTypeName, fieldName) if index < 0 { continue @@ -647,23 +737,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *refl return ret } -type groupKeyInfo struct { - typeName string - indexes []int - subTypes []groupKeyInfo -} - -func (s *scanContext) aliasColumnIndex(alias string) int { - index, ok := s.aliasIndexMap[alias] - - if !ok { - return -1 - } - - return index -} - -func (s *scanContext) typeColumnIndex(typeName, fieldName string) int { +func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int { var key string if typeName != "" { @@ -672,7 +746,7 @@ func (s *scanContext) typeColumnIndex(typeName, fieldName string) int { key = strings.ToLower(fieldName) } - index, ok := s.goNameMap[key] + index, ok := s.typeToColumnIndexMap[key] if !ok { return -1 @@ -682,7 +756,7 @@ func (s *scanContext) typeColumnIndex(typeName, fieldName string) int { } func (s *scanContext) getCellValue(typeName, fieldName string) interface{} { - index := s.typeColumnIndex(typeName, fieldName) + index := s.typeToColumnIndex(typeName, fieldName) if index < 0 { return nil diff --git a/tests/all_types_test.go b/tests/all_types_test.go index 7a545d3..0174c4b 100644 --- a/tests/all_types_test.go +++ b/tests/all_types_test.go @@ -18,8 +18,6 @@ func TestAllTypesSelect(t *testing.T) { fmt.Println(err) assert.NilError(t, err) - assert.Equal(t, len(dest), 2) - assert.DeepEqual(t, dest[0], allTypesRow0) assert.DeepEqual(t, dest[1], allTypesRow1) } @@ -98,7 +96,7 @@ func TestExpressionOperators(t *testing.T) { RAW("current_database()"), ) - fmt.Println(query.DebugSql()) + //fmt.Println(query.DebugSql()) err := query.Query(db, &struct{}{}) @@ -159,11 +157,11 @@ func TestStringOperators(t *testing.T) { TO_HEX(AllTypes.IntegerPtr), ) - _, args, _ := query.Sql() + //_, args, _ := query.Sql() - fmt.Println(query.Sql()) - fmt.Println(args[15]) - fmt.Println(query.DebugSql()) + //fmt.Println(query.Sql()) + //fmt.Println(args[15]) + //fmt.Println(query.DebugSql()) err := query.Query(db, &struct{}{}) @@ -190,7 +188,7 @@ func TestBoolOperators(t *testing.T) { AllTypes.Boolean.OR(AllTypes.Boolean).EQ(AllTypes.Boolean.AND(AllTypes.Boolean)), ) - fmt.Println(query.DebugSql()) + //fmt.Println(query.DebugSql()) err := query.Query(db, &struct{}{}) @@ -240,7 +238,7 @@ func TestFloatOperators(t *testing.T) { TRUNC(AllTypes.Decimal, Int(1)), ) - fmt.Println(query.DebugSql()) + //fmt.Println(query.DebugSql()) err := query.Query(db, &struct{}{}) @@ -287,7 +285,7 @@ func TestIntegerOperators(t *testing.T) { CBRT(AllTypes.Integer), ) - fmt.Println(query.DebugSql()) + //fmt.Println(query.DebugSql()) err := query.Query(db, &struct{}{}) @@ -348,7 +346,7 @@ func TestTimeOperators(t *testing.T) { NOW(), ) - fmt.Println(query.DebugSql()) + //fmt.Println(query.DebugSql()) err := query.Query(db, &struct{}{}) @@ -522,7 +520,7 @@ FROM` ). FROM(subQuery) - fmt.Println(stmt2.DebugSql()) + //fmt.Println(stmt2.DebugSql()) assertStatementSql(t, stmt2, expectedSql+expected.sql+";\n", expected.args...) diff --git a/tests/chinook_db_test.go b/tests/chinook_db_test.go index 1576dfc..ad7d738 100644 --- a/tests/chinook_db_test.go +++ b/tests/chinook_db_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/davecgh/go-spew/spew" . "github.com/go-jet/jet" "github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/chinook/table" @@ -105,12 +104,106 @@ func TestJoinEverything(t *testing.T) { err := stmt.Query(db, &dest) assert.NilError(t, err) - //jsonSave(dest) - fmt.Println("Artist count :", len(dest)) assert.Equal(t, len(dest), 275) - assertJson(t, "./testdata/joined_everything.json", dest) + assertJsonFile(t, "./testdata/joined_everything.json", dest) +} + +func TestSelfJoin(t *testing.T) { + var dest []struct { + model.Employee + + Manager *model.Employee `alias:"Manager.*"` + } + + manager := Employee.AS("Manager") + + stmt := Employee. + LEFT_JOIN(manager, Employee.ReportsTo.EQ(manager.EmployeeId)). + SELECT( + Employee.EmployeeId, + Employee.FirstName, + Employee.LastName, + manager.EmployeeId, + manager.FirstName, + manager.LastName, + ). + ORDER_BY(Employee.EmployeeId) + + assertStatementSql(t, stmt, ` +SELECT "Employee"."EmployeeId" AS "Employee.EmployeeId", + "Employee"."FirstName" AS "Employee.FirstName", + "Employee"."LastName" AS "Employee.LastName", + "Manager"."EmployeeId" AS "Manager.EmployeeId", + "Manager"."FirstName" AS "Manager.FirstName", + "Manager"."LastName" AS "Manager.LastName" +FROM chinook."Employee" + LEFT JOIN chinook."Employee" AS "Manager" ON ("Employee"."ReportsTo" = "Manager"."EmployeeId") +ORDER BY "Employee"."EmployeeId"; +`) + + err := stmt.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 8) + assertJson(t, dest[0:2], ` +[ + { + "EmployeeId": 1, + "LastName": "Adams", + "FirstName": "Andrew", + "Title": null, + "ReportsTo": null, + "BirthDate": null, + "HireDate": null, + "Address": null, + "City": null, + "State": null, + "Country": null, + "PostalCode": null, + "Phone": null, + "Fax": null, + "Email": null, + "Manager": null + }, + { + "EmployeeId": 2, + "LastName": "Edwards", + "FirstName": "Nancy", + "Title": null, + "ReportsTo": null, + "BirthDate": null, + "HireDate": null, + "Address": null, + "City": null, + "State": null, + "Country": null, + "PostalCode": null, + "Phone": null, + "Fax": null, + "Email": null, + "Manager": { + "EmployeeId": 1, + "LastName": "Adams", + "FirstName": "Andrew", + "Title": null, + "ReportsTo": null, + "BirthDate": null, + "HireDate": null, + "Address": null, + "City": null, + "State": null, + "Country": null, + "PostalCode": null, + "Phone": null, + "Fax": null, + "Email": null + } + } +] +`) + } func TestUnionForQuotedNames(t *testing.T) { @@ -121,7 +214,7 @@ func TestUnionForQuotedNames(t *testing.T) { ). ORDER_BY(Album.AlbumId) - fmt.Println(stmt.DebugSql()) + //fmt.Println(stmt.DebugSql()) assertStatementSql(t, stmt, ` ( ( @@ -242,10 +335,17 @@ ORDER BY "first10Artist"."Artist.ArtistId"; assert.NilError(t, err) - spew.Dump(dest) + //spew.Dump(dest) } -func assertJson(t *testing.T, jsonFilePath string, data interface{}) { +func assertJson(t *testing.T, data interface{}, expectedJson string) { + jsonData, err := json.MarshalIndent(data, "", "\t") + assert.NilError(t, err) + + assert.Equal(t, "\n"+string(jsonData)+"\n", expectedJson) +} + +func assertJsonFile(t *testing.T, jsonFilePath string, data interface{}) { fileJsonData, err := ioutil.ReadFile(jsonFilePath) assert.NilError(t, err) @@ -253,6 +353,7 @@ func assertJson(t *testing.T, jsonFilePath string, data interface{}) { assert.NilError(t, err) assert.Assert(t, string(fileJsonData) == string(jsonData)) + //assert.Equal(t, string(fileJsonData), string(jsonData)) } func jsonPrint(v interface{}) { diff --git a/tests/insert_test.go b/tests/insert_test.go index 8a4d539..c636c85 100644 --- a/tests/insert_test.go +++ b/tests/insert_test.go @@ -89,6 +89,7 @@ INSERT INTO test_sample.link VALUES } func TestInsertModelObject(t *testing.T) { + cleanUpLinkTable(t) var expectedSql = ` INSERT INTO test_sample.link (url, name) VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); diff --git a/tests/sample_test.go b/tests/sample_test.go index b7663b2..f9f251d 100644 --- a/tests/sample_test.go +++ b/tests/sample_test.go @@ -6,25 +6,29 @@ import ( . "github.com/go-jet/jet" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" + "github.com/google/uuid" "gotest.tools/assert" "testing" ) func TestUUIDType(t *testing.T) { query := AllTypes. - SELECT(AllTypes.AllColumns). + SELECT(AllTypes.UUID, AllTypes.UUIDPtr). WHERE(AllTypes.UUID.EQ(String("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))) - queryStr, args, err := query.Sql() + assertStatementSql(t, query, ` +SELECT all_types.uuid AS "all_types.uuid", + all_types.uuid_ptr AS "all_types.uuid_ptr" +FROM test_sample.all_types +WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'; +`, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") - assert.NilError(t, err) - assert.Equal(t, len(args), 1) - fmt.Println(queryStr) - //assert.Equal(t, queryStr, `SELECT all_types.character AS "all_types.character", all_types.character_varying AS "all_types.character_varying", all_types.text AS "all_types.text", all_types.bytea AS "all_types.bytea", all_types.timestamp_without_time_zone AS "all_types.timestamp_without_time_zone", all_types.timestamp_with_time_zone AS "all_types.timestamp_with_time_zone", all_types.uuid AS "all_types.uuid", all_types.json AS "all_types.json", all_types.jsonb AS "all_types.jsonb" FROM test_sample.all_types WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11`) result := model.AllTypes{} - err = query.Query(db, &result) - spew.Dump(result) + err := query.Query(db, &result) + assert.NilError(t, err) + assert.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) + assert.DeepEqual(t, result.UUIDPtr, uuidPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) } func TestEnumType(t *testing.T) { @@ -120,7 +124,25 @@ ORDER BY employee.employee_id; func TestWierdNamesTable(t *testing.T) { stmt := WeirdNamesTable.SELECT(WeirdNamesTable.AllColumns) - fmt.Println(stmt.DebugSql()) + assertStatementSql(t, stmt, ` +SELECT "WEIRD NAMES TABLE".weird_column_name1 AS "WEIRD NAMES TABLE.weird_column_name1", + "WEIRD NAMES TABLE"."Weird_Column_Name2" AS "WEIRD NAMES TABLE.Weird_Column_Name2", + "WEIRD NAMES TABLE"."wEiRd_cOluMn_nAmE3" AS "WEIRD NAMES TABLE.wEiRd_cOluMn_nAmE3", + "WEIRD NAMES TABLE"."WeIrd_CoLuMN_Name4" AS "WEIRD NAMES TABLE.WeIrd_CoLuMN_Name4", + "WEIRD NAMES TABLE"."WEIRD_COLUMN_NAME5" AS "WEIRD NAMES TABLE.WEIRD_COLUMN_NAME5", + "WEIRD NAMES TABLE"."WeirdColumnName6" AS "WEIRD NAMES TABLE.WeirdColumnName6", + "WEIRD NAMES TABLE"."weirdColumnName7" AS "WEIRD NAMES TABLE.weirdColumnName7", + "WEIRD NAMES TABLE".weirdcolumnname8 AS "WEIRD NAMES TABLE.weirdcolumnname8", + "WEIRD NAMES TABLE"."weird col name9" AS "WEIRD NAMES TABLE.weird col name9", + "WEIRD NAMES TABLE"."wEiRd cOlu nAmE10" AS "WEIRD NAMES TABLE.wEiRd cOlu nAmE10", + "WEIRD NAMES TABLE"."WEIRD COLU NAME11" AS "WEIRD NAMES TABLE.WEIRD COLU NAME11", + "WEIRD NAMES TABLE"."Weird Colu Name12" AS "WEIRD NAMES TABLE.Weird Colu Name12", + "WEIRD NAMES TABLE"."weird-col-name13" AS "WEIRD NAMES TABLE.weird-col-name13", + "WEIRD NAMES TABLE"."wEiRd-cOlu-nAmE14" AS "WEIRD NAMES TABLE.wEiRd-cOlu-nAmE14", + "WEIRD NAMES TABLE"."WEIRD-COLU-NAME15" AS "WEIRD NAMES TABLE.WEIRD-COLU-NAME15", + "WEIRD NAMES TABLE"."Weird-Colu-Name16" AS "WEIRD NAMES TABLE.Weird-Colu-Name16" +FROM test_sample."WEIRD NAMES TABLE"; +`) dest := []model.WeirdNamesTable{} diff --git a/tests/select_test.go b/tests/select_test.go index 3cafc03..b450405 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -1,8 +1,6 @@ package tests import ( - "fmt" - "github.com/davecgh/go-spew/spew" . "github.com/go-jet/jet" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/enum" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" @@ -158,7 +156,6 @@ LIMIT 12; ). LIMIT(12) - fmt.Println(query.Sql()) assertStatementSql(t, query, expectedSql, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) } @@ -253,7 +250,6 @@ LIMIT 1000; assert.Equal(t, len(languageActorFilm[0].Films), 10) assert.Equal(t, len(languageActorFilm[0].Films[0].Actors), 10) } - } func TestJoinQuerySlice(t *testing.T) { @@ -487,6 +483,97 @@ ORDER BY city.city_id, address.address_id, customer.customer_id; assert.Equal(t, *dest[0].Customers[1].LastName, "Vines") } +func TestExecution4(t *testing.T) { + + var dest []struct { + CityID int32 `sql:"primary_key" alias:"city.city_id"` + CityName string `alias:"city.city"` + + Customers []struct { + CustomerID int32 `sql:"primary_key" alias:"customer_id"` + LastName *string `alias:"last_name"` + + Address struct { + AddressID int32 `sql:"primary_key" alias:"AddressId"` + AddressLine string `alias:"address.address"` + } `alias:"address.*"` + } `alias:"customer"` + } + + stmt := City. + INNER_JOIN(Address, Address.CityID.EQ(City.CityID)). + INNER_JOIN(Customer, Customer.AddressID.EQ(Address.AddressID)). + SELECT( + City.CityID, + City.City, + Customer.CustomerID, + Customer.LastName, + Address.AddressID, + Address.Address, + ). + WHERE(City.City.EQ(String("London")).OR(City.City.EQ(String("York")))). + ORDER_BY(City.CityID, Address.AddressID, Customer.CustomerID) + + assertStatementSql(t, stmt, ` +SELECT city.city_id AS "city.city_id", + city.city AS "city.city", + customer.customer_id AS "customer.customer_id", + customer.last_name AS "customer.last_name", + address.address_id AS "address.address_id", + address.address AS "address.address" +FROM dvds.city + INNER JOIN dvds.address ON (address.city_id = city.city_id) + INNER JOIN dvds.customer ON (customer.address_id = address.address_id) +WHERE (city.city = 'London') OR (city.city = 'York') +ORDER BY city.city_id, address.address_id, customer.customer_id; +`, "London", "York") + + err := stmt.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 2) + assertJson(t, dest, ` +[ + { + "CityID": 312, + "CityName": "London", + "Customers": [ + { + "CustomerID": 252, + "LastName": "Hoffman", + "Address": { + "AddressID": 256, + "AddressLine": "1497 Yuzhou Drive" + } + }, + { + "CustomerID": 512, + "LastName": "Vines", + "Address": { + "AddressID": 517, + "AddressLine": "548 Uruapan Street" + } + } + ] + }, + { + "CityID": 589, + "CityName": "York", + "Customers": [ + { + "CustomerID": 497, + "LastName": "Sledge", + "Address": { + "AddressID": 502, + "AddressLine": "1515 Korla Way" + } + } + ] + } +] +`) +} + func TestJoinQuerySliceWithPtrs(t *testing.T) { type FilmsPerLanguage struct { Language model.Language @@ -810,8 +897,6 @@ FROM dvds.actor rRatingFilms.AllColumns(), ) - fmt.Println(query.DebugSql()) - assertStatementSql(t, query, expectedQuery) dest := []model.Actor{} @@ -875,7 +960,6 @@ ORDER BY film.film_id ASC; WHERE(Film.RentalRate.EQ(maxFilmRentalRate)). ORDER_BY(Film.FilmID.ASC()) - fmt.Println(query.Sql()) assertStatementSql(t, query, expectedSql) maxRentalRateFilms := []model.Film{} @@ -1039,7 +1123,36 @@ func TestSelectStaff(t *testing.T) { assert.NilError(t, err) - spew.Dump(staffs) + assertJson(t, staffs, ` +[ + { + "StaffID": 1, + "FirstName": "Mike", + "LastName": "Hillyer", + "AddressID": 3, + "Email": "Mike.Hillyer@sakilastaff.com", + "StoreID": 1, + "Active": true, + "Username": "Mike", + "Password": "8cb2237d0679ca88db6464eac60da96345513964", + "LastUpdate": "2006-05-16T16:13:11.79328Z", + "Picture": "iVBORw0KWgo=" + }, + { + "StaffID": 2, + "FirstName": "Jon", + "LastName": "Stephens", + "AddressID": 4, + "Email": "Jon.Stephens@sakilastaff.com", + "StoreID": 2, + "Active": true, + "Username": "Jon", + "Password": "8cb2237d0679ca88db6464eac60da96345513964", + "LastUpdate": "2006-05-16T16:13:11.79328Z", + "Picture": null + } +] +`) } func TestSelectTimeColumns(t *testing.T) { @@ -1114,9 +1227,6 @@ OFFSET 20; LIMIT(10). OFFSET(20) - queryStr, _, _ := query.Sql() - - fmt.Println("-" + queryStr + "-") assertStatementSql(t, query, expectedQuery, float64(100), float64(200), int64(10), int64(20)) dest := []model.Payment{} @@ -1396,7 +1506,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC; assert.NilError(t, err) //jsonSave("./testdata/quick-start-dest.json", dest) - assertJson(t, "./testdata/quick-start-dest.json", dest) + assertJsonFile(t, "./testdata/quick-start-dest.json", dest) var dest2 []struct { model.Category @@ -1409,7 +1519,7 @@ ORDER BY actor.actor_id ASC, film.film_id ASC; assert.NilError(t, err) //jsonSave("./testdata/quick-start-dest2.json", dest2) - assertJson(t, "./testdata/quick-start-dest2.json", dest2) + assertJsonFile(t, "./testdata/quick-start-dest2.json", dest2) } func TestQuickStartWithSubQueries(t *testing.T) { @@ -1461,7 +1571,7 @@ func TestQuickStartWithSubQueries(t *testing.T) { assert.NilError(t, err) //jsonSave("./testdata/quick-start-dest.json", dest) - assertJson(t, "./testdata/quick-start-dest.json", dest) + assertJsonFile(t, "./testdata/quick-start-dest.json", dest) var dest2 []struct { model.Category @@ -1474,5 +1584,5 @@ func TestQuickStartWithSubQueries(t *testing.T) { assert.NilError(t, err) //jsonSave("./testdata/quick-start-dest2.json", dest2) - assertJson(t, "./testdata/quick-start-dest2.json", dest2) + assertJsonFile(t, "./testdata/quick-start-dest2.json", dest2) } diff --git a/tests/update_test.go b/tests/update_test.go index 0ff0ba6..93487e9 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -24,8 +24,6 @@ UPDATE test_sample.link SET (name, url) = ('Bong', 'http://bong.com') WHERE link.name = 'Bing'; ` - fmt.Println(query.Sql()) - assertStatementSql(t, query, expectedSql, "Bong", "http://bong.com", "Bing") assertExec(t, query, 1) diff --git a/test_utils.go b/utils_test.go similarity index 100% rename from test_utils.go rename to utils_test.go diff --git a/wiki/Scan-to-arbitrary-destination.md b/wiki/Scan-to-arbitrary-destination.md index f4e0d7d..4f70d0e 100644 --- a/wiki/Scan-to-arbitrary-destination.md +++ b/wiki/Scan-to-arbitrary-destination.md @@ -46,7 +46,7 @@ WHERE (city.city = 'London') OR (city.city = 'York') ORDER BY city.city_id, address.address_id, customer.customer_id; ``` -Every column is aliased by default. Format is "`table_name`.`column_name`" +Note that every column is aliased by default. Format is "`table_name`.`column_name`" Above statement will produce following result set: @@ -57,7 +57,7 @@ Above statement will produce following result set: | _3_| 589 | "York" | 502 | "1515 Korla Way" | 497 | "Sledge" | Lets execute statement and scan result set to destination `dest`: - ```sql + ``` var dest []struct { model.City @@ -70,33 +70,29 @@ var dest []struct { err := stmt.Query(db, &dest) ``` - -`Query` uses reflection to introspect destination type structure, and result set column names(aliases), to be able to map result set data to destination object. + Note that camel case of result set column names(aliases) is the same as `model type name`.`field name`. For instance `city.city_id` -> `City.CityID`. This is being used to find appropriate column for each destination model field. -It is not an error if there is not a column for each destination model field. +It is not an error if there is not a column for each destination model field. Table and column names does not have +to be in snake case. + +`Query` uses reflection to introspect destination type structure, and result set column names(aliases), to be able to map result set data to destination object. +Every new destination struct object is cached by his and all the parents primary key. So grouping to work correctly at least table primary keys has to appear in query result set. If there is no primary key in a result set +row number is used as grouping condition(which is always unique). +For instance, after row 1 is processed, two objects are stored to cache: +``` +Key: Object: +(City(312)) -> (*struct { model.City; Customers []struct { model.Customer; Address model.Address } }) +(City(312)),(Customer(252),Address(256)) -> (*struct { model.Customer; Address model.Address }) +``` +After row 2 processing only one object is stored to cache, because city with city_id 312 is already in cache. +``` +Key: Object: +(City(312)) -> pulled from cache +(City(312)),(Customer(512),Address(517)) -> (*struct { model.Customer; Address model.Address }) +``` -Lets see in general how `Query` works row by row: - -- ROW 1: - - dest is slice of structs, so new struct object is initialized and scan proceeds to next step. - - `city.city_id` and `city.city` columns (values `312` and `"London"`) are used to initialize `CityID` and `City` fields of `model.City` object. - - `Customers` is a slice of structs, so new struct object is initialized and scan proceeds to next step. - - `customer.customer_id` and `customer.last_name` is used to initialize fields in `model.Customer` object. - - `address.address_id` and `address.address` is used to initialize fields in `Address model.Address` - - because at least one field of struct is being initialized struct is added to `Customers []struct` and cached by parent and - struct primary key fields([more about primary key fields](TODO)). Primary keys used for caching are `CityID`, `CustomerID` and `AddressID` of `model.City`, `model.Customer` - and `model.Address` - - because at least one field of struct is being initialized struct is added to `var dest []struct` and cached by - struct primary key fields. Primary keys used for caching is only `CityID` from `model.City` -- ROW 2: - - Does not initialize new struct object for `dest []struct` but pulls one from the cache, because `city` with `city_id` of `312` has - already being processed. Following steps are the same as above, new objects are created, stored in slice and cached. -- ROW 3: - - steps would be similar as for the first step. Nothing is pulled from he cache, stored in slice and cached. - - - Lets print `dest` as a json, to visualize `Query` result: +Lets print `dest` as a json, to visualize `Query` result: ``` [ @@ -190,8 +186,10 @@ City of `London` has two customers, which is the product of object reuse in `ROW ### Custom model files -**Destinations are not limited to just model files, any destination will work, as long as camel case of result set column -is equal to `model type name`.`field name`.** +Destinations are not limited to just model files, any destination will work, as long as camel case of result set column +is equal to `model type name`.`field name`. +Custom model type can have field of any type listed in [Mappings of database types to Go types](), +plus any type that implements `sql.Scanner` interface. #### Named types @@ -237,7 +235,7 @@ err := stmt2.Query(db, &dest2) ``` Destination type names and field names are now changed. Every type has 'My' prefix, every primary key column is named `ID`, - `FirstName` is now string pointer etc. + `LastName` is now string pointer etc. Because we are using custom types with changed identifier, every column has to be aliased. For instance: `City.CityID.AS("my_city.id")`, ` City.City.AS("myCity.Name")` etc. **Table names, column names and aliases doesn't have to be in a snake case. CamelCase, PascalCase or some other mixed space is also supported, @@ -286,10 +284,10 @@ Json of new destination is also changed: ] ``` -#### Antonymous types +#### Anonymous custom types There is no need to create new named type for every custom model. -Destination type can be declared inline without naming any type. +Destination type can be declared inline without naming any new type. ``` var dest []struct { @@ -326,6 +324,80 @@ err := stmt.Query(db, &dest) Aliasing is now simplified. Alias contains only (column/field) name. On the other hand, we can not have 3 fields named `ID`, because aliases have to be unique. +### Tagging model files + +Desired mapping can be set the other way around as well, by tagging destination fields and types. + +``` +var dest []struct { + CityID int32 `sql:"primary_key" alias:"city.city_id"` + CityName string `alias:"city.city"` + + Customers []struct { + // because the whole struct is refering to 'customer.*' (see below tag), + // we can just use 'alias:"customer_id"`' instead of 'alias:"customer.customer_id"`' + CustomerID int32 `sql:"primary_key" alias:"customer_id"` + LastName *string `alias:"last_name"` + + Address struct { + AddressID int32 `sql:"primary_key" alias:"AddressId"` // cammel case for alias will work as well + AddressLine string `alias:"address.address"` // full alias will work as well + } `alias:"address.*"` // struct is now refering to all address.* columns + + } `alias:"customer.*"` // struct is now refering to all customer.* columns +} + +stmt := City. + INNER_JOIN(Address, Address.CityID.EQ(City.CityID)). + INNER_JOIN(Customer, Customer.AddressID.EQ(Address.AddressID)). + SELECT( + City.CityID, + City.City, + Customer.CustomerID, + Customer.LastName, + Address.AddressID, + Address.Address, + ). + WHERE(City.City.EQ(String("London")).OR(City.City.EQ(String("York")))). + ORDER_BY(City.CityID, Address.AddressID, Customer.CustomerID) + +err := stmt.Query(db, &dest) +``` + +This kind of mapping is more complicated than in previous examples, and it should avoided and used +only when there is no alternative. Usually this is the case in two scenarios: + +1) Self join + +``` +var dest []struct{ + model.Employee + + Manager *model.Employee `alias:"Manager.*"` //or just `alias:"Manager" +} + +manager := Employee.AS("Manager") + +stmt := Employee. + LEFT_JOIN(manager, Employee.ReportsTo.EQ(manager.EmployeeId)). + SELECT( + Employee.EmployeeId, + Employee.FirstName, + manager.EmployeeId, + manager.FirstName, + ) +``` +_This example could also be written without tag alias, by just introducing of a new type `type Manager model.Employee`._ + +2) Slices of go base types (int32, float64, string, ...) + +``` +var dest struct { + model.Film + + InventoryIDs []int32 `alias:"inventory.inventory_id"` +} +``` ### Combining autogenerated and custom model files @@ -346,4 +418,6 @@ type MyCity struct { Customers []MyCustomer } -``` \ No newline at end of file +``` + +