Improve database to golang name mapping.

This commit is contained in:
go-jet 2019-07-03 16:27:14 +02:00
parent 3e7277015d
commit 950663dadb
19 changed files with 538 additions and 122 deletions

View file

@ -7,7 +7,7 @@ import (
"errors"
"fmt"
"github.com/go-jet/jet/execution/internal"
"github.com/serenize/snaker"
"github.com/go-jet/jet/internal/util"
"reflect"
"strconv"
"strings"
@ -139,10 +139,7 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
if isGoBaseType(sliceElemType) {
index := 0
if structField != nil {
tableName, columnName := getRefAlias(structField)
index = scanContext.columnIndex(tableName, columnName)
if index < 0 {
if index = scanContext.aliasColumnIndex(structField.Tag.Get("alias")); index < 0 {
return
}
}
@ -293,28 +290,24 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
return
}
func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, structField *reflect.StructField, onlySlices ...bool) (updated bool, err error) {
func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, parentField *reflect.StructField, onlySlices ...bool) (updated bool, err error) {
structType := structPtrValue.Type().Elem()
structValue := structPtrValue.Elem()
tableName, _ := getRefAlias(structField)
if tableName == "" {
tableName = structType.Name()
}
typeName := getTypeName(structType, parentField)
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
fieldValue := structValue.Field(i)
columnName := field.Name
fieldName := field.Name
if scannerValue, ok := implementsScanner(fieldValue); ok {
if len(onlySlices) > 0 {
continue
}
cellValue := scanContext.getCellValue(tableName, columnName)
cellValue := scanContext.getCellValue(typeName, fieldName)
if cellValue == nil {
continue
@ -336,7 +329,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
continue
}
cellValue := scanContext.getCellValue(tableName, columnName)
cellValue := scanContext.getCellValue(typeName, fieldName)
if cellValue != nil {
updated = true
@ -365,26 +358,20 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re
return
}
func getRefAlias(structField *reflect.StructField) (table, column string) {
if structField == nil {
return
func getTypeName(structType reflect.Type, parentField *reflect.StructField) string {
if parentField == nil {
return structType.Name()
}
aliasTag := structField.Tag.Get("alias")
aliasTag := parentField.Tag.Get("alias")
if aliasTag == "" {
return
return structType.Name()
}
aliasParts := strings.Split(aliasTag, ".")
table = aliasParts[0]
if len(aliasParts) > 1 {
column = aliasParts[1]
}
return
return aliasParts[0]
}
func initializeValueIfNilPtr(value reflect.Value) {
@ -533,12 +520,14 @@ type scanContext struct {
row []interface{}
uniqueDestObjectsMap map[string]int
columnNameIndexMap map[string]int
groupKeyInfoCache map[string]groupKeyInfo
aliasIndexMap map[string]int
goNameMap map[string]int
groupKeyInfoCache map[string]groupKeyInfo
}
func newScanContext(rows *sql.Rows) (*scanContext, error) {
columnNames, err := rows.Columns()
aliases, err := rows.Columns()
if err != nil {
return nil, err
@ -550,10 +539,24 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
return nil, err
}
columnNameIndexMap := map[string]int{}
aliasIndexMap := map[string]int{}
for i, columnName := range columnNames {
columnNameIndexMap[strings.ToLower(columnName)] = i
for i, columnName := range aliases {
aliasIndexMap[strings.ToLower(columnName)] = i
}
goNamesMap := map[string]int{}
for i, alias := range aliases {
names := strings.SplitN(alias, ".", 2)
goName := util.ToGoIdentifier(names[0])
if len(names) > 1 {
goName += "." + util.ToGoIdentifier(names[1])
}
goNamesMap[strings.ToLower(goName)] = i
}
return &scanContext{
@ -561,8 +564,8 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
uniqueDestObjectsMap: make(map[string]int),
groupKeyInfoCache: make(map[string]groupKeyInfo),
columnNameIndexMap: columnNameIndexMap,
aliasIndexMap: aliasIndexMap,
goNameMap: goNamesMap,
}, nil
}
@ -607,12 +610,8 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
return "{" + groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")}"
}
func (s *scanContext) getGroupKeyInfo(structType reflect.Type, structField *reflect.StructField) groupKeyInfo {
tableName, _ := getRefAlias(structField)
if tableName == "" {
tableName = structType.Name()
}
func (s *scanContext) getGroupKeyInfo(structType reflect.Type, parentField *reflect.StructField) groupKeyInfo {
typeName := getTypeName(structType, parentField)
ret := groupKeyInfo{typeName: structType.Name()}
@ -635,7 +634,7 @@ func (s *scanContext) getGroupKeyInfo(structType reflect.Type, structField *refl
ret.subTypes = append(ret.subTypes, subType)
}
} else if isPrimaryKey(field) {
index := s.columnIndex(tableName, field.Name)
index := s.typeColumnIndex(typeName, field.Name)
if index < 0 {
continue
@ -654,47 +653,36 @@ type groupKeyInfo struct {
subTypes []groupKeyInfo
}
func (s *scanContext) columnIndex(tableName, columnName string) int {
if tableName == "" {
name := strings.ToLower(columnName)
if i, ok := s.columnNameIndexMap[name]; ok {
return i
}
func (s *scanContext) aliasColumnIndex(alias string) int {
index, ok := s.aliasIndexMap[alias]
name = strings.ToLower(snaker.CamelToSnake(columnName))
if i, ok := s.columnNameIndexMap[name]; ok {
return i
}
} else {
name := strings.ToLower(tableName + "." + columnName)
if i, ok := s.columnNameIndexMap[name]; ok {
return i
}
snakedTableName := snaker.CamelToSnake(tableName)
snakedColumnName := snaker.CamelToSnake(columnName)
name = strings.ToLower(snakedTableName + "." + snakedColumnName)
if i, ok := s.columnNameIndexMap[name]; ok {
return i
}
name = strings.ToLower(tableName + "." + snakedColumnName)
if i, ok := s.columnNameIndexMap[name]; ok {
return i
}
name = strings.ToLower(snakedTableName + "." + columnName)
if i, ok := s.columnNameIndexMap[name]; ok {
return i
}
if !ok {
return -1
}
return -1
return index
}
func (s *scanContext) getCellValue(tableName, fieldName string) interface{} {
index := s.columnIndex(tableName, fieldName)
func (s *scanContext) typeColumnIndex(typeName, fieldName string) int {
var key string
if typeName != "" {
key = strings.ToLower(typeName + "." + fieldName)
} else {
key = strings.ToLower(fieldName)
}
index, ok := s.goNameMap[key]
if !ok {
return -1
}
return index
}
func (s *scanContext) getCellValue(typeName, fieldName string) interface{} {
index := s.typeColumnIndex(typeName, fieldName)
if index < 0 {
return nil