Improve database to golang name mapping.
This commit is contained in:
parent
3e7277015d
commit
950663dadb
19 changed files with 538 additions and 122 deletions
|
|
@ -7,7 +7,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"github.com/go-jet/jet/execution/internal"
|
||||
"github.com/serenize/snaker"
|
||||
"github.com/go-jet/jet/internal/util"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
|
@ -139,10 +139,7 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
|
|||
if isGoBaseType(sliceElemType) {
|
||||
index := 0
|
||||
if structField != nil {
|
||||
tableName, columnName := getRefAlias(structField)
|
||||
index = scanContext.columnIndex(tableName, columnName)
|
||||
|
||||
if index < 0 {
|
||||
if index = scanContext.aliasColumnIndex(structField.Tag.Get("alias")); index < 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
@ -293,28 +290,24 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
|
|||
return
|
||||
}
|
||||
|
||||
func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, structField *reflect.StructField, onlySlices ...bool) (updated bool, err error) {
|
||||
func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) {
|
||||
structType := structPtrValue.Type().Elem()
|
||||
structValue := structPtrValue.Elem()
|
||||
|
||||
tableName, _ := getRefAlias(structField)
|
||||
|
||||
if tableName == "" {
|
||||
tableName = structType.Name()
|
||||
}
|
||||
typeName := getTypeName(structType, parentField)
|
||||
|
||||
for i := 0; i < structType.NumField(); i++ {
|
||||
field := structType.Field(i)
|
||||
|
||||
fieldValue := structValue.Field(i)
|
||||
columnName := field.Name
|
||||
fieldName := field.Name
|
||||
|
||||
if scannerValue, ok := implementsScanner(fieldValue); ok {
|
||||
if len(onlySlices) > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
cellValue := scanContext.getCellValue(tableName, columnName)
|
||||
cellValue := scanContext.getCellValue(typeName, fieldName)
|
||||
|
||||
if cellValue == nil {
|
||||
continue
|
||||
|
|
@ -336,7 +329,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
|
|||
continue
|
||||
}
|
||||
|
||||
cellValue := scanContext.getCellValue(tableName, columnName)
|
||||
cellValue := scanContext.getCellValue(typeName, fieldName)
|
||||
|
||||
if cellValue != nil {
|
||||
updated = true
|
||||
|
|
@ -365,26 +358,20 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
|
|||
return
|
||||
}
|
||||
|
||||
func getRefAlias(structField *reflect.StructField) (table, column string) {
|
||||
if structField == nil {
|
||||
return
|
||||
func getTypeName(structType reflect.Type, parentField *reflect.StructField) string {
|
||||
if parentField == nil {
|
||||
return structType.Name()
|
||||
}
|
||||
|
||||
aliasTag := structField.Tag.Get("alias")
|
||||
aliasTag := parentField.Tag.Get("alias")
|
||||
|
||||
if aliasTag == "" {
|
||||
return
|
||||
return structType.Name()
|
||||
}
|
||||
|
||||
aliasParts := strings.Split(aliasTag, ".")
|
||||
|
||||
table = aliasParts[0]
|
||||
|
||||
if len(aliasParts) > 1 {
|
||||
column = aliasParts[1]
|
||||
}
|
||||
|
||||
return
|
||||
return aliasParts[0]
|
||||
}
|
||||
|
||||
func initializeValueIfNilPtr(value reflect.Value) {
|
||||
|
|
@ -533,12 +520,14 @@ type scanContext struct {
|
|||
row []interface{}
|
||||
uniqueDestObjectsMap map[string]int
|
||||
|
||||
columnNameIndexMap map[string]int
|
||||
groupKeyInfoCache map[string]groupKeyInfo
|
||||
aliasIndexMap map[string]int
|
||||
goNameMap map[string]int
|
||||
|
||||
groupKeyInfoCache map[string]groupKeyInfo
|
||||
}
|
||||
|
||||
func newScanContext(rows *sql.Rows) (*scanContext, error) {
|
||||
columnNames, err := rows.Columns()
|
||||
aliases, err := rows.Columns()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -550,10 +539,24 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
columnNameIndexMap := map[string]int{}
|
||||
aliasIndexMap := map[string]int{}
|
||||
|
||||
for i, columnName := range columnNames {
|
||||
columnNameIndexMap[strings.ToLower(columnName)] = i
|
||||
for i, columnName := range aliases {
|
||||
aliasIndexMap[strings.ToLower(columnName)] = i
|
||||
}
|
||||
|
||||
goNamesMap := map[string]int{}
|
||||
|
||||
for i, alias := range aliases {
|
||||
names := strings.SplitN(alias, ".", 2)
|
||||
|
||||
goName := util.ToGoIdentifier(names[0])
|
||||
|
||||
if len(names) > 1 {
|
||||
goName += "." + util.ToGoIdentifier(names[1])
|
||||
}
|
||||
|
||||
goNamesMap[strings.ToLower(goName)] = i
|
||||
}
|
||||
|
||||
return &scanContext{
|
||||
|
|
@ -561,8 +564,8 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
|
|||
uniqueDestObjectsMap: make(map[string]int),
|
||||
|
||||
groupKeyInfoCache: make(map[string]groupKeyInfo),
|
||||
|
||||
columnNameIndexMap: columnNameIndexMap,
|
||||
aliasIndexMap: aliasIndexMap,
|
||||
goNameMap: goNamesMap,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -607,12 +610,8 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
|
|||
return "{" + groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")}"
|
||||
}
|
||||
|
||||
func (s *scanContext) getGroupKeyInfo(structType reflect.Type, structField *reflect.StructField) groupKeyInfo {
|
||||
tableName, _ := getRefAlias(structField)
|
||||
|
||||
if tableName == "" {
|
||||
tableName = structType.Name()
|
||||
}
|
||||
func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo {
|
||||
typeName := getTypeName(structType, parentField)
|
||||
|
||||
ret := groupKeyInfo{typeName: structType.Name()}
|
||||
|
||||
|
|
@ -635,7 +634,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, structField *refl
|
|||
ret.subTypes = append(ret.subTypes, subType)
|
||||
}
|
||||
} else if isPrimaryKey(field) {
|
||||
index := s.columnIndex(tableName, field.Name)
|
||||
index := s.typeColumnIndex(typeName, field.Name)
|
||||
|
||||
if index < 0 {
|
||||
continue
|
||||
|
|
@ -654,47 +653,36 @@ type groupKeyInfo struct {
|
|||
subTypes []groupKeyInfo
|
||||
}
|
||||
|
||||
func (s *scanContext) columnIndex(tableName, columnName string) int {
|
||||
if tableName == "" {
|
||||
name := strings.ToLower(columnName)
|
||||
if i, ok := s.columnNameIndexMap[name]; ok {
|
||||
return i
|
||||
}
|
||||
func (s *scanContext) aliasColumnIndex(alias string) int {
|
||||
index, ok := s.aliasIndexMap[alias]
|
||||
|
||||
name = strings.ToLower(snaker.CamelToSnake(columnName))
|
||||
if i, ok := s.columnNameIndexMap[name]; ok {
|
||||
return i
|
||||
}
|
||||
} else {
|
||||
name := strings.ToLower(tableName + "." + columnName)
|
||||
if i, ok := s.columnNameIndexMap[name]; ok {
|
||||
return i
|
||||
}
|
||||
|
||||
snakedTableName := snaker.CamelToSnake(tableName)
|
||||
snakedColumnName := snaker.CamelToSnake(columnName)
|
||||
|
||||
name = strings.ToLower(snakedTableName + "." + snakedColumnName)
|
||||
if i, ok := s.columnNameIndexMap[name]; ok {
|
||||
return i
|
||||
}
|
||||
|
||||
name = strings.ToLower(tableName + "." + snakedColumnName)
|
||||
if i, ok := s.columnNameIndexMap[name]; ok {
|
||||
return i
|
||||
}
|
||||
|
||||
name = strings.ToLower(snakedTableName + "." + columnName)
|
||||
if i, ok := s.columnNameIndexMap[name]; ok {
|
||||
return i
|
||||
}
|
||||
if !ok {
|
||||
return -1
|
||||
}
|
||||
|
||||
return -1
|
||||
return index
|
||||
}
|
||||
|
||||
func (s *scanContext) getCellValue(tableName, fieldName string) interface{} {
|
||||
index := s.columnIndex(tableName, fieldName)
|
||||
func (s *scanContext) typeColumnIndex(typeName, fieldName string) int {
|
||||
var key string
|
||||
|
||||
if typeName != "" {
|
||||
key = strings.ToLower(typeName + "." + fieldName)
|
||||
} else {
|
||||
key = strings.ToLower(fieldName)
|
||||
}
|
||||
|
||||
index, ok := s.goNameMap[key]
|
||||
|
||||
if !ok {
|
||||
return -1
|
||||
}
|
||||
|
||||
return index
|
||||
}
|
||||
|
||||
func (s *scanContext) getCellValue(typeName, fieldName string) interface{} {
|
||||
index := s.typeColumnIndex(typeName, fieldName)
|
||||
|
||||
if index < 0 {
|
||||
return nil
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue