diff --git a/execution/execution.go b/execution/execution.go index 3281b03..38bb01a 100644 --- a/execution/execution.go +++ b/execution/execution.go @@ -14,7 +14,7 @@ import ( "time" ) -// Query executes query with arguments over database connection with context and stores result into destination. +// Query executes query with list of arguments over database connection using context and stores result into destination. // Destination can be either pointer to struct or pointer to slice of structs. func Query(context context.Context, db DB, query string, args []interface{}, destinationPtr interface{}) error { @@ -748,37 +748,37 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { } func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo { - typeName := getTypeName(structType, parentField) - ret := groupKeyInfo{typeName: structType.Name()} + typeName := getTypeName(structType, parentField) + primaryKeyOverwrites := parentFieldPrimaryKeyOverwrite(parentField) + for i := 0; i < structType.NumField(); i++ { field := structType.Field(i) - newTypeName, fieldName := getTypeAndFieldName(typeName, field) + fieldType := indirectType(field.Type) - if !isSimpleModelType(field.Type) { - var structType reflect.Type - if field.Type.Kind() == reflect.Struct { - structType = field.Type - } else if field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct { - structType = field.Type.Elem() - } else { + if !isSimpleModelType(fieldType) { + if fieldType.Kind() != reflect.Struct { continue } - subType := s.getGroupKeyInfo(structType, &field) + subType := s.getGroupKeyInfo(fieldType, &field) if len(subType.indexes) != 0 || len(subType.subTypes) != 0 { ret.subTypes = append(ret.subTypes, subType) } - } else if isPrimaryKey(field, parentField) { - index := s.typeToColumnIndex(newTypeName, fieldName) + } else { + if isPrimaryKey(field, primaryKeyOverwrites) { + newTypeName, fieldName := getTypeAndFieldName(typeName, field) - if index < 0 { - continue + index := s.typeToColumnIndex(newTypeName, fieldName) + + if index < 0 { + continue + } + + ret.indexes = append(ret.indexes, index) } - - ret.indexes = append(ret.indexes, index) } } @@ -835,10 +835,9 @@ func (s *scanContext) rowElemValuePtr(index int) reflect.Value { return newElem } -func isPrimaryKey(field reflect.StructField, parentField *reflect.StructField) bool { - - if hasOverwrite, isPrimaryKey := primaryKeyOvewrite(field.Name, parentField); hasOverwrite { - return isPrimaryKey +func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool { + if len(primaryKeyOverwrites) > 0 { + return utils.StringSliceContains(primaryKeyOverwrites, field.Name) } sqlTag := field.Tag.Get("sql") @@ -846,32 +845,24 @@ func isPrimaryKey(field reflect.StructField, parentField *reflect.StructField) b return sqlTag == "primary_key" } -func primaryKeyOvewrite(columnName string, parentField *reflect.StructField) (hasOverwrite, primaryKey bool) { +func parentFieldPrimaryKeyOverwrite(parentField *reflect.StructField) []string { if parentField == nil { - return + return nil } sqlTag := parentField.Tag.Get("sql") if !strings.HasPrefix(sqlTag, "primary_key") { - return + return nil } parts := strings.Split(sqlTag, "=") if len(parts) < 2 { - return + return nil } - primaryKeyColumns := strings.Split(parts[1], ",") - - for _, primaryKeyCol := range primaryKeyColumns { - if toCommonIdentifier(columnName) == toCommonIdentifier(primaryKeyCol) { - return true, true - } - } - - return true, false + return strings.Split(parts[1], ",") } func indirectType(reflectType reflect.Type) reflect.Type { diff --git a/internal/utils/utils.go b/internal/utils/utils.go index a091973..9be5310 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -164,3 +164,13 @@ func ErrorCatch(err *error) { *err = fmt.Errorf("%v", recovered) } } + +func StringSliceContains(strings []string, contains string) bool { + for _, str := range strings { + if str == contains { + return true + } + } + + return false +}