// jet/qrm/utill.go (exported from a repository web view: 392 lines, 8.3 KiB, Go;
// blame timestamp 2019-10-11 10:15:36 +02:00).

package qrm
import (
	"database/sql"
	"fmt"
	"reflect"
	"strings"
	"time"

	"github.com/go-jet/jet/v2/internal/utils"
	"github.com/go-jet/jet/v2/qrm/internal"
	"github.com/google/uuid"
)
// scannerInterfaceType is the reflect.Type of the sql.Scanner interface itself
// (not of a pointer to it).
var scannerInterfaceType = reflect.TypeOf((*sql.Scanner)(nil)).Elem()

// implementsScannerType reports whether fieldType, or a pointer to fieldType,
// implements sql.Scanner. The pointer check matters because Scan is usually
// declared on a pointer receiver.
func implementsScannerType(fieldType reflect.Type) bool {
	return fieldType.Implements(scannerInterfaceType) ||
		reflect.New(fieldType).Type().Implements(scannerInterfaceType)
}
// getScanner returns value as an sql.Scanner. It first tries value itself; if
// that does not implement sql.Scanner, it falls back to value's address. The
// fallback requires value to be addressable, and panics (via Addr or the type
// assertion) when neither form is a Scanner.
func getScanner(value reflect.Value) sql.Scanner {
	scanner, isScanner := value.Interface().(sql.Scanner)
	if !isScanner {
		scanner = value.Addr().Interface().(sql.Scanner)
	}
	return scanner
}
func getSliceElemType(slicePtrValue reflect.Value) reflect.Type {
sliceTypePtr := slicePtrValue.Type()
elemType := indirectType(sliceTypePtr).Elem()
2019-10-11 10:15:36 +02:00
return indirectType(elemType)
2019-10-11 10:15:36 +02:00
}
// getSliceElemPtrAt returns a pointer Value for the element at index of the
// slice that slicePtrValue points to. Pointer elements are returned as-is;
// value elements are returned by address, so writes through the result
// update the slice in place.
func getSliceElemPtrAt(slicePtrValue reflect.Value, index int) reflect.Value {
	item := slicePtrValue.Elem().Index(index)
	if item.Kind() != reflect.Ptr {
		item = item.Addr()
	}
	return item
}
func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) error {
2019-10-18 10:09:56 +02:00
utils.MustBeTrue(!slicePtrValue.IsNil(), "jet: internal, slice is nil")
2019-10-11 10:15:36 +02:00
sliceValue := slicePtrValue.Elem()
sliceElemType := sliceValue.Type().Elem()
var newSliceElemValue reflect.Value
2019-10-11 10:15:36 +02:00
if objPtrValue.Type().AssignableTo(sliceElemType) {
newSliceElemValue = objPtrValue
} else if objPtrValue.Elem().Type().AssignableTo(sliceElemType) {
newSliceElemValue = objPtrValue.Elem()
} else {
newSliceElemValue = reflect.New(sliceElemType).Elem()
var err error
if newSliceElemValue.Kind() == reflect.Ptr {
newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem()))
err = tryAssign(objPtrValue.Elem(), newSliceElemValue.Elem())
} else {
err = tryAssign(objPtrValue.Elem(), newSliceElemValue)
}
2019-10-11 10:15:36 +02:00
if err != nil {
return fmt.Errorf("can't append %T to %T slice: %w", objPtrValue.Elem().Interface(), sliceValue.Interface(), err)
}
2019-10-11 10:15:36 +02:00
}
sliceValue.Set(reflect.Append(sliceValue, newSliceElemValue))
2019-10-11 10:15:36 +02:00
return nil
}
// newElemPtrValueForSlice allocates a new zero value matching the element type
// of the slice that slicePtrValue points to, and returns a pointer to it. When
// the element type is itself a pointer (*T), the result is a new *T, not a **T.
func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value {
	sliceType := slicePtrValue.Type().Elem()
	return reflect.New(indirectType(sliceType.Elem()))
}
// getTypeName returns the name to use for structType. It defaults to
// structType.Name(). When parentField has a non-empty `alias` tag, the part of
// the alias before the first '.' is used instead, normalised with
// toCommonIdentifier.
func getTypeName(structType reflect.Type, parentField *reflect.StructField) string {
	var alias string
	if parentField != nil {
		alias = parentField.Tag.Get("alias")
	}

	if alias == "" {
		return structType.Name()
	}

	typeAlias := strings.SplitN(alias, ".", 2)[0]
	return toCommonIdentifier(typeAlias)
}
// getTypeAndFieldName resolves the (type name, field name) pair for field.
//
// The result depends on field's `alias` tag:
//   - no tag: (structType, field.Name);
//   - "name": (structType, normalised name);
//   - "type.name[...]": (normalised type, normalised name), ignoring anything
//     after the second '.'.
//
// Normalisation is done with toCommonIdentifier. structType itself is returned
// unmodified.
func getTypeAndFieldName(structType string, field reflect.StructField) (typeName string, fieldName string) {
	alias := field.Tag.Get("alias")
	if alias == "" {
		return structType, field.Name
	}

	segments := strings.Split(alias, ".")
	switch len(segments) {
	case 1:
		return structType, toCommonIdentifier(segments[0])
	default:
		return toCommonIdentifier(segments[0]), toCommonIdentifier(segments[1])
	}
}
// replacer deletes the separator characters (space, '-', '_') that
// toCommonIdentifier ignores.
var replacer = strings.NewReplacer(" ", "", "-", "", "_", "")

// toCommonIdentifier normalises name for case- and separator-insensitive
// comparison. It removes spaces, hyphens and underscores, then lower-cases the
// result. For example, "User_ID" becomes "userid".
func toCommonIdentifier(name string) string {
	stripped := replacer.Replace(name)
	return strings.ToLower(stripped)
}
// initializeValueIfNilPtr points value at a newly allocated zero value of its
// element type, but only when value is a valid, settable, nil pointer.
// Any other value is left untouched.
func initializeValueIfNilPtr(value reflect.Value) {
	needsAlloc := value.IsValid() &&
		value.CanSet() &&
		value.Kind() == reflect.Ptr &&
		value.IsNil()

	if needsAlloc {
		value.Set(reflect.New(value.Type().Elem()))
	}
}
// valueToString renders value for human-readable output.
//
// An invalid value or a nil pointer renders as "nil". A non-nil pointer is
// dereferenced once. The result uses the value's String method when it
// implements fmt.Stringer, and the Go-syntax %#v representation otherwise.
func valueToString(value reflect.Value) string {
	if !value.IsValid() || (value.Kind() == reflect.Ptr && value.IsNil()) {
		return "nil"
	}

	target := value
	if target.Kind() == reflect.Ptr {
		target = target.Elem()
	}

	concrete := target.Interface()
	if stringer, isStringer := concrete.(fmt.Stringer); isStringer {
		return stringer.String()
	}

	return fmt.Sprintf("%#v", concrete)
}
var timeType = reflect.TypeOf(time.Now())
var uuidType = reflect.TypeOf(uuid.New())
2019-10-18 09:56:38 +02:00
var byteArrayType = reflect.TypeOf([]byte(""))
2019-10-11 10:15:36 +02:00
func isSimpleModelType(objType reflect.Type) bool {
objType = indirectType(objType)
switch objType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64,
reflect.String,
reflect.Bool:
return true
}
2019-10-18 09:56:38 +02:00
return objType == timeType || objType == uuidType || objType == byteArrayType
2019-10-11 10:15:36 +02:00
}
// isIntegerType reports whether objType, after stripping one pointer level,
// has a signed or unsigned integer kind.
func isIntegerType(objType reflect.Type) bool {
	switch indirectType(objType).Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return true
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return true
	default:
		return false
	}
}
// isFloatType reports whether the type value (a reflect.Type, despite the
// parameter name) has kind float32 or float64. Unlike isIntegerType, it does
// not dereference pointer types: *float64 reports false.
func isFloatType(value reflect.Type) bool {
	switch value.Kind() {
	case reflect.Float32, reflect.Float64:
		return true
	default:
		return false
	}
}
func tryAssign(source, destination reflect.Value) error {
if source.Type() != destination.Type() &&
!isFloatType(destination.Type()) && // to preserve precision during conversion
!(isIntegerType(source.Type()) && destination.Kind() == reflect.String) && // default conversion will convert int to 1 rune string
source.Type().ConvertibleTo(destination.Type()) {
2019-10-11 10:15:36 +02:00
source = source.Convert(destination.Type())
}
2019-10-11 10:15:36 +02:00
if source.Type().AssignableTo(destination.Type()) {
switch b := source.Interface().(type) {
case []byte:
destination.SetBytes(cloneBytes(b))
default:
destination.Set(source)
2019-10-11 10:15:36 +02:00
}
return nil
}
sourceInterface := source.Interface()
switch destination.Interface().(type) {
case bool:
var nullBool internal.NullBool
err := nullBool.Scan(sourceInterface)
if err != nil {
return err
}
destination.SetBool(nullBool.Bool)
case float32, float64:
var nullFloat sql.NullFloat64
err := nullFloat.Scan(sourceInterface)
if err != nil {
return err
}
if nullFloat.Valid {
destination.SetFloat(nullFloat.Float64)
}
case int, int8, int16, int32, int64:
var integer sql.NullInt64
2019-10-11 10:15:36 +02:00
err := integer.Scan(sourceInterface)
if err != nil {
return err
}
2019-10-11 10:15:36 +02:00
if integer.Valid {
destination.SetInt(integer.Int64)
}
2019-10-11 10:15:36 +02:00
case uint, uint8, uint16, uint32, uint64:
var uInt internal.NullUInt64
2019-10-11 10:15:36 +02:00
err := uInt.Scan(sourceInterface)
2019-10-11 10:15:36 +02:00
if err != nil {
return err
}
2019-10-11 10:15:36 +02:00
if uInt.Valid {
destination.SetUint(uInt.UInt64)
2019-10-11 10:15:36 +02:00
}
case string:
var str sql.NullString
err := str.Scan(sourceInterface)
if err != nil {
return err
2019-10-11 10:15:36 +02:00
}
if str.Valid {
destination.SetString(str.String)
2019-10-11 10:15:36 +02:00
}
case time.Time:
var nullTime internal.NullTime
err := nullTime.Scan(sourceInterface)
if err != nil {
return err
}
if nullTime.Valid {
destination.Set(reflect.ValueOf(nullTime.Time))
}
default:
return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface())
2019-10-11 10:15:36 +02:00
}
return nil
2019-10-11 10:15:36 +02:00
}
func setReflectValue(source, destination reflect.Value) error {
2019-10-11 10:15:36 +02:00
if destination.Kind() == reflect.Ptr {
if destination.IsNil() {
initializeValueIfNilPtr(destination)
}
2019-10-11 10:15:36 +02:00
if source.Kind() == reflect.Ptr {
if source.IsNil() {
return nil // source is nil, destination should keep its zero value
}
source = source.Elem()
}
2019-10-11 10:15:36 +02:00
if err := tryAssign(source, destination.Elem()); err != nil {
return err
}
2019-10-11 10:15:36 +02:00
} else {
if source.Kind() == reflect.Ptr {
if source.IsNil() {
return nil // source is nil, destination should keep its zero value
}
source = source.Elem()
}
2019-10-11 10:15:36 +02:00
if err := tryAssign(source, destination); err != nil {
return err
}
2019-10-11 10:15:36 +02:00
}
return nil
2019-10-11 10:15:36 +02:00
}
// isPrimaryKey reports whether field is a primary-key field.
//
// When primaryKeyOverwrites is non-empty, it is authoritative: field is a
// primary key exactly when its Go name appears in the list. Otherwise the
// field's own tag must be exactly `sql:"primary_key"`.
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
	if len(primaryKeyOverwrites) == 0 {
		return field.Tag.Get("sql") == "primary_key"
	}

	return utils.StringSliceContains(primaryKeyOverwrites, field.Name)
}
// parentFieldPrimaryKeyOverwrite returns the primary-key field names listed in
// parentField's `sql:"primary_key=Name1,Name2"` tag.
//
// It returns nil in three cases: parentField is nil, its sql tag does not start
// with "primary_key", or the tag has no "=" list. Whitespace around each name
// is trimmed, so `primary_key=ID, Name` yields ["ID", "Name"]; untrimmed
// " Name" could never match a Go field name.
func parentFieldPrimaryKeyOverwrite(parentField *reflect.StructField) []string {
	if parentField == nil {
		return nil
	}

	sqlTag := parentField.Tag.Get("sql")
	if !strings.HasPrefix(sqlTag, "primary_key") {
		return nil
	}

	parts := strings.Split(sqlTag, "=")
	if len(parts) < 2 {
		return nil
	}

	names := strings.Split(parts[1], ",")
	for i, name := range names {
		names[i] = strings.TrimSpace(name)
	}

	return names
}
// indirectType strips exactly one level of pointer from reflectType: *T
// becomes T, **T becomes *T, and non-pointer types are returned unchanged.
func indirectType(reflectType reflect.Type) reflect.Type {
	if reflectType.Kind() == reflect.Ptr {
		return reflectType.Elem()
	}
	return reflectType
}
// fieldToString formats field as a location suffix for error messages, e.g.
// " at 'Name string'". It returns "" for a nil field.
func fieldToString(field *reflect.StructField) string {
	if field != nil {
		return fmt.Sprintf(" at '%s %s'", field.Name, field.Type.String())
	}
	return ""
}
// cloneBytes returns an independent copy of b. It preserves nil-ness: nil
// stays nil, and an empty non-nil slice yields an empty non-nil slice.
func cloneBytes(b []byte) []byte {
	if b == nil {
		return nil
	}
	return append(make([]byte, 0, len(b)), b...)
}