GroupKeyInfo improvements for parent specified primary keys.

This commit is contained in:
go-jet 2019-09-27 11:34:12 +02:00
parent d0297ca16f
commit 92de03d4b3
2 changed files with 36 additions and 35 deletions

View file

@ -14,7 +14,7 @@ import (
"time" "time"
) )
// Query executes query with arguments over database connection with context and stores result into destination. // Query executes query with list of arguments over database connection using context and stores result into destination.
// Destination can be either pointer to struct or pointer to slice of structs. // Destination can be either pointer to struct or pointer to slice of structs.
func Query(context context.Context, db DB, query string, args []interface{}, destinationPtr interface{}) error { func Query(context context.Context, db DB, query string, args []interface{}, destinationPtr interface{}) error {
@ -748,37 +748,37 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
} }
func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo { func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo {
typeName := getTypeName(structType, parentField)
ret := groupKeyInfo{typeName: structType.Name()} ret := groupKeyInfo{typeName: structType.Name()}
typeName := getTypeName(structType, parentField)
primaryKeyOverwrites := parentFieldPrimaryKeyOverwrite(parentField)
for i := 0; i < structType.NumField(); i++ { for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i) field := structType.Field(i)
newTypeName, fieldName := getTypeAndFieldName(typeName, field) fieldType := indirectType(field.Type)
if !isSimpleModelType(field.Type) { if !isSimpleModelType(fieldType) {
var structType reflect.Type if fieldType.Kind() != reflect.Struct {
if field.Type.Kind() == reflect.Struct {
structType = field.Type
} else if field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct {
structType = field.Type.Elem()
} else {
continue continue
} }
subType := s.getGroupKeyInfo(structType, &field) subType := s.getGroupKeyInfo(fieldType, &field)
if len(subType.indexes) != 0 || len(subType.subTypes) != 0 { if len(subType.indexes) != 0 || len(subType.subTypes) != 0 {
ret.subTypes = append(ret.subTypes, subType) ret.subTypes = append(ret.subTypes, subType)
} }
} else if isPrimaryKey(field, parentField) { } else {
index := s.typeToColumnIndex(newTypeName, fieldName) if isPrimaryKey(field, primaryKeyOverwrites) {
newTypeName, fieldName := getTypeAndFieldName(typeName, field)
if index < 0 { index := s.typeToColumnIndex(newTypeName, fieldName)
continue
if index < 0 {
continue
}
ret.indexes = append(ret.indexes, index)
} }
ret.indexes = append(ret.indexes, index)
} }
} }
@ -835,10 +835,9 @@ func (s *scanContext) rowElemValuePtr(index int) reflect.Value {
return newElem return newElem
} }
func isPrimaryKey(field reflect.StructField, parentField *reflect.StructField) bool { func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
if len(primaryKeyOverwrites) > 0 {
if hasOverwrite, isPrimaryKey := primaryKeyOvewrite(field.Name, parentField); hasOverwrite { return utils.StringSliceContains(primaryKeyOverwrites, field.Name)
return isPrimaryKey
} }
sqlTag := field.Tag.Get("sql") sqlTag := field.Tag.Get("sql")
@ -846,32 +845,24 @@ func isPrimaryKey(field reflect.StructField, parentField *reflect.StructField) b
return sqlTag == "primary_key" return sqlTag == "primary_key"
} }
func primaryKeyOvewrite(columnName string, parentField *reflect.StructField) (hasOverwrite, primaryKey bool) { func parentFieldPrimaryKeyOverwrite(parentField *reflect.StructField) []string {
if parentField == nil { if parentField == nil {
return return nil
} }
sqlTag := parentField.Tag.Get("sql") sqlTag := parentField.Tag.Get("sql")
if !strings.HasPrefix(sqlTag, "primary_key") { if !strings.HasPrefix(sqlTag, "primary_key") {
return return nil
} }
parts := strings.Split(sqlTag, "=") parts := strings.Split(sqlTag, "=")
if len(parts) < 2 { if len(parts) < 2 {
return return nil
} }
primaryKeyColumns := strings.Split(parts[1], ",") return strings.Split(parts[1], ",")
for _, primaryKeyCol := range primaryKeyColumns {
if toCommonIdentifier(columnName) == toCommonIdentifier(primaryKeyCol) {
return true, true
}
}
return true, false
} }
func indirectType(reflectType reflect.Type) reflect.Type { func indirectType(reflectType reflect.Type) reflect.Type {

View file

@ -164,3 +164,13 @@ func ErrorCatch(err *error) {
*err = fmt.Errorf("%v", recovered) *err = fmt.Errorf("%v", recovered)
} }
} }
func StringSliceContains(strings []string, contains string) bool {
for _, str := range strings {
if str == contains {
return true
}
}
return false
}