GroupKeyInfo improvements for parent specified primary keys.

This commit is contained in:
go-jet 2019-09-27 11:34:12 +02:00
parent d0297ca16f
commit 92de03d4b3
2 changed files with 36 additions and 35 deletions

View file

@ -14,7 +14,7 @@ import (
"time"
)
// Query executes query with arguments over database connection with context and stores result into destination.
// Query executes query with list of arguments over database connection using context and stores result into destination.
// Destination can be either pointer to struct or pointer to slice of structs.
func Query(context context.Context, db DB, query string, args []interface{}, destinationPtr interface{}) error {
@ -748,37 +748,37 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
}
func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo {
typeName := getTypeName(structType, parentField)
ret := groupKeyInfo{typeName: structType.Name()}
typeName := getTypeName(structType, parentField)
primaryKeyOverwrites := parentFieldPrimaryKeyOverwrite(parentField)
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
newTypeName, fieldName := getTypeAndFieldName(typeName, field)
fieldType := indirectType(field.Type)
if !isSimpleModelType(field.Type) {
var structType reflect.Type
if field.Type.Kind() == reflect.Struct {
structType = field.Type
} else if field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct {
structType = field.Type.Elem()
} else {
if !isSimpleModelType(fieldType) {
if fieldType.Kind() != reflect.Struct {
continue
}
subType := s.getGroupKeyInfo(structType, &field)
subType := s.getGroupKeyInfo(fieldType, &field)
if len(subType.indexes) != 0 || len(subType.subTypes) != 0 {
ret.subTypes = append(ret.subTypes, subType)
}
} else if isPrimaryKey(field, parentField) {
index := s.typeToColumnIndex(newTypeName, fieldName)
} else {
if isPrimaryKey(field, primaryKeyOverwrites) {
newTypeName, fieldName := getTypeAndFieldName(typeName, field)
if index < 0 {
continue
index := s.typeToColumnIndex(newTypeName, fieldName)
if index < 0 {
continue
}
ret.indexes = append(ret.indexes, index)
}
ret.indexes = append(ret.indexes, index)
}
}
@ -835,10 +835,9 @@ func (s *scanContext) rowElemValuePtr(index int) reflect.Value {
return newElem
}
func isPrimaryKey(field reflect.StructField, parentField *reflect.StructField) bool {
if hasOverwrite, isPrimaryKey := primaryKeyOvewrite(field.Name, parentField); hasOverwrite {
return isPrimaryKey
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
if len(primaryKeyOverwrites) > 0 {
return utils.StringSliceContains(primaryKeyOverwrites, field.Name)
}
sqlTag := field.Tag.Get("sql")
@ -846,32 +845,24 @@ func isPrimaryKey(field reflect.StructField, parentField *reflect.StructField) b
return sqlTag == "primary_key"
}
func primaryKeyOvewrite(columnName string, parentField *reflect.StructField) (hasOverwrite, primaryKey bool) {
func parentFieldPrimaryKeyOverwrite(parentField *reflect.StructField) []string {
if parentField == nil {
return
return nil
}
sqlTag := parentField.Tag.Get("sql")
if !strings.HasPrefix(sqlTag, "primary_key") {
return
return nil
}
parts := strings.Split(sqlTag, "=")
if len(parts) < 2 {
return
return nil
}
primaryKeyColumns := strings.Split(parts[1], ",")
for _, primaryKeyCol := range primaryKeyColumns {
if toCommonIdentifier(columnName) == toCommonIdentifier(primaryKeyCol) {
return true, true
}
}
return true, false
return strings.Split(parts[1], ",")
}
func indirectType(reflectType reflect.Type) reflect.Type {