Add support for database uuid types.
This commit is contained in:
parent
2c7a9f5058
commit
37b0a6445b
6 changed files with 126 additions and 44 deletions
|
|
@ -40,14 +40,14 @@ func (c ColumnInfo) ToSqlBuilderColumnType() string {
|
|||
return "IntegerColumn"
|
||||
case "date", "timestamp without time zone", "timestamp with time zone":
|
||||
return "TimeColumn"
|
||||
case "text", "character", "character varying", "bytea":
|
||||
case "text", "character", "character varying", "bytea", "uuid":
|
||||
return "StringColumn"
|
||||
case "real":
|
||||
return "NumericColumn"
|
||||
case "numeric", "double precision":
|
||||
return "NumericColumn"
|
||||
default:
|
||||
fmt.Println("Unknownl type: " + c.DataType + ", using string instead.")
|
||||
fmt.Println("Unknownl type: " + c.DataType + ", using string column instead.")
|
||||
return "StringColumn"
|
||||
}
|
||||
}
|
||||
|
|
@ -67,7 +67,7 @@ func (c ColumnInfo) GoBaseType() string {
|
|||
} else {
|
||||
switch c.DataType {
|
||||
case "USER-DEFINED":
|
||||
return c.EnumName
|
||||
return snaker.SnakeToCamel(c.EnumName)
|
||||
case "boolean":
|
||||
return "bool"
|
||||
case "smallint":
|
||||
|
|
@ -86,8 +86,10 @@ func (c ColumnInfo) GoBaseType() string {
|
|||
return "float32"
|
||||
case "numeric", "double precision":
|
||||
return "float64"
|
||||
case "uuid":
|
||||
return "uuid.UUID"
|
||||
default:
|
||||
fmt.Println("Unknown go map type: " + c.DataType + ", using string instead.")
|
||||
fmt.Println("Unknown go map type: " + c.DataType + ", " + c.EnumName + ", using string instead.")
|
||||
return "string"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,8 +21,11 @@ func (t TableInfo) GetImports() []string {
|
|||
for _, column := range t.Columns {
|
||||
columnType := column.GoBaseType()
|
||||
|
||||
if columnType == "time.Time" {
|
||||
switch columnType {
|
||||
case "time.Time":
|
||||
imports["time.Time"] = "time"
|
||||
case "uuid.UUID":
|
||||
imports["uuid.UUID"] = "github.com/google/uuid"
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -66,17 +66,17 @@ var EnumModelTemplate = `package model
|
|||
|
||||
import "errors"
|
||||
|
||||
type {{.Name}} string
|
||||
type {{camelize $.Name}} string
|
||||
|
||||
const (
|
||||
{{- range $index, $element := .Values}}
|
||||
{{camelize $.Name}}_{{camelize $element}} {{$.Name}} = "{{$element}}"
|
||||
{{camelize $.Name}}_{{camelize $element}} {{camelize $.Name}} = "{{$element}}"
|
||||
{{- end}}
|
||||
)
|
||||
|
||||
func (e *{{$.Name}}) Scan(value interface{}) error {
|
||||
func (e *{{camelize $.Name}}) Scan(value interface{}) error {
|
||||
if v, ok := value.(string); !ok {
|
||||
return errors.New("Invalid data for {{$.Name}} enum")
|
||||
return errors.New("Invalid data for {{camelize $.Name}} enum")
|
||||
} else {
|
||||
switch string(v) {
|
||||
{{- range $index, $element := .Values}}
|
||||
|
|
@ -84,7 +84,7 @@ func (e *{{$.Name}}) Scan(value interface{}) error {
|
|||
*e = {{camelize $.Name}}_{{camelize $element}}
|
||||
{{- end}}
|
||||
default:
|
||||
return errors.New("Inavlid data " + string(v) + "for {{$.Name}} enum")
|
||||
return errors.New("Inavlid data " + string(v) + "for {{camelize $.Name}} enum")
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -36,9 +36,9 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
|
|||
|
||||
columnNames, _ := rows.Columns()
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
rowData := createScanValue(columnTypes)
|
||||
|
||||
scanContext := &scanContext{
|
||||
row: createScanValue(columnTypes),
|
||||
columnNames: columnNames,
|
||||
uniqueObjectsMap: make(map[string]interface{}),
|
||||
}
|
||||
|
|
@ -46,7 +46,7 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
|
|||
//spew.Dump(columnTypes)
|
||||
|
||||
for rows.Next() {
|
||||
err := rows.Scan(rowData...)
|
||||
err := rows.Scan(scanContext.row...)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -55,13 +55,13 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
|
|||
scanContext.rowNum++
|
||||
|
||||
if destinationType.Elem().Kind() == reflect.Slice {
|
||||
err := mapRowToSlice(scanContext, "", map[string]bool{}, rowData, destinationPtr, nil)
|
||||
err := mapRowToSlice(scanContext, "", map[string]bool{}, destinationPtr, nil)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if destinationType.Elem().Kind() == reflect.Struct {
|
||||
return mapRowToStruct(scanContext, "", map[string]bool{}, rowData, destinationPtr, nil)
|
||||
return mapRowToStruct(scanContext, "", map[string]bool{}, destinationPtr, nil)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -77,8 +77,10 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
|
|||
}
|
||||
|
||||
type scanContext struct {
|
||||
rowNum int
|
||||
columnNames []string
|
||||
rowNum int
|
||||
columnNames []string
|
||||
|
||||
row []interface{}
|
||||
uniqueObjectsMap map[string]interface{}
|
||||
}
|
||||
|
||||
|
|
@ -112,7 +114,7 @@ func getType(reflectType reflect.Type) string {
|
|||
return structType.Name()
|
||||
}
|
||||
|
||||
func getGroupKey(scanContext *scanContext, row []interface{}, typesProcessed map[string]bool, structType reflect.Type, structField *reflect.StructField) string {
|
||||
func getGroupKey(scanContext *scanContext, typesProcessed map[string]bool, structType reflect.Type, structField *reflect.StructField) string {
|
||||
tableName := getTableAlias(structField)
|
||||
|
||||
//fmt.Println("Group: " + tableName)
|
||||
|
|
@ -147,7 +149,7 @@ func getGroupKey(scanContext *scanContext, row []interface{}, typesProcessed map
|
|||
|
||||
//spew.Dump(structType)
|
||||
|
||||
structGroupKey := getGroupKey(scanContext, row, typesProcessed, structType, &field)
|
||||
structGroupKey := getGroupKey(scanContext, typesProcessed, structType, &field)
|
||||
|
||||
//groupKey = strings.Join([]string{structGroupKey, groupKey}, ":")
|
||||
|
||||
|
|
@ -165,7 +167,7 @@ func getGroupKey(scanContext *scanContext, row []interface{}, typesProcessed map
|
|||
continue
|
||||
}
|
||||
|
||||
cellValue := cellValue(row, index)
|
||||
cellValue := cellValue(scanContext.row, index)
|
||||
subKey := reflectValueToString(cellValue)
|
||||
|
||||
if subKey != "" {
|
||||
|
|
@ -227,12 +229,12 @@ func cloneProcessedMap(processedMap map[string]bool) map[string]bool {
|
|||
return newMap
|
||||
}
|
||||
|
||||
func mapRowToSlice(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, row []interface{}, destinationPtr interface{}, structField *reflect.StructField) error {
|
||||
func mapRowToSlice(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, destinationPtr interface{}, structField *reflect.StructField) error {
|
||||
var err error
|
||||
|
||||
structType := getSliceStructType(destinationPtr)
|
||||
|
||||
structGroupKey := getGroupKey(scanContext, row, cloneProcessedMap(typesProcessed), structType, structField)
|
||||
structGroupKey := getGroupKey(scanContext, cloneProcessedMap(typesProcessed), structType, structField)
|
||||
|
||||
if structGroupKey == "" {
|
||||
structGroupKey = "|ROW: " + strconv.Itoa(scanContext.rowNum) + "|"
|
||||
|
|
@ -245,14 +247,14 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, typesProcessed map
|
|||
objPtr, ok := scanContext.uniqueObjectsMap[groupKey]
|
||||
|
||||
if ok {
|
||||
err = mapRowToStruct(scanContext, groupKey, typesProcessed, row, objPtr, structField)
|
||||
err = mapRowToStruct(scanContext, groupKey, typesProcessed, objPtr, structField)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
destinationStructPtr := newElemForSlice(destinationPtr)
|
||||
|
||||
err = mapRowToStruct(scanContext, groupKey, typesProcessed, row, destinationStructPtr, structField)
|
||||
err = mapRowToStruct(scanContext, groupKey, typesProcessed, destinationStructPtr, structField)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -290,14 +292,14 @@ func newElemForSlice(destinationSlicePtr interface{}) interface{} {
|
|||
return reflect.New(elemType).Interface()
|
||||
}
|
||||
|
||||
func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, row []interface{}, dest reflect.Value, structField *reflect.StructField) error {
|
||||
func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, dest reflect.Value, structField *reflect.StructField) error {
|
||||
if dest.Kind() == reflect.Struct {
|
||||
err := mapRowToStruct(scanContext, groupKey, typesProcessed, row, dest.Addr().Interface(), structField)
|
||||
err := mapRowToStruct(scanContext, groupKey, typesProcessed, dest.Addr().Interface(), structField)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if dest.Kind() == reflect.Slice {
|
||||
err := mapRowToSlice(scanContext, groupKey, typesProcessed, row, dest.Addr().Interface(), structField)
|
||||
err := mapRowToSlice(scanContext, groupKey, typesProcessed, dest.Addr().Interface(), structField)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -313,7 +315,7 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesPr
|
|||
return nil
|
||||
}
|
||||
|
||||
err := mapRowToStruct(scanContext, groupKey, typesProcessed, row, structValuePtr.Interface(), structField)
|
||||
err := mapRowToStruct(scanContext, groupKey, typesProcessed, structValuePtr.Interface(), structField)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -331,7 +333,7 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesPr
|
|||
sliceValuePtr = dest
|
||||
}
|
||||
|
||||
err := mapRowToSlice(scanContext, groupKey, typesProcessed, row, sliceValuePtr.Interface(), structField)
|
||||
err := mapRowToSlice(scanContext, groupKey, typesProcessed, sliceValuePtr.Interface(), structField)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -390,7 +392,7 @@ func getTableAlias(structField *reflect.StructField) string {
|
|||
return snaker.CamelToSnake(elemType)
|
||||
}
|
||||
|
||||
func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, row []interface{}, destinationPtr interface{}, structField *reflect.StructField) error {
|
||||
func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, destinationPtr interface{}, structField *reflect.StructField) error {
|
||||
structType := reflect.TypeOf(destinationPtr).Elem()
|
||||
structValue := reflect.ValueOf(destinationPtr).Elem()
|
||||
|
||||
|
|
@ -419,16 +421,20 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
|
|||
|
||||
fieldName := field.Name
|
||||
|
||||
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
||||
cellValue := getCellValue(scanContext, tableName, fieldName, row)
|
||||
if scannerValue, ok := implementsScanner(fieldValue); ok {
|
||||
cellValue := getCellValue(scanContext, tableName, fieldName)
|
||||
|
||||
if cellValue == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
initializeValue(fieldValue)
|
||||
//spew.Dump(scannerValue.Interface())
|
||||
|
||||
scanner := fieldValue.Interface().(sql.Scanner)
|
||||
if scannerValue.IsNil() {
|
||||
initializePtrValue(scannerValue)
|
||||
}
|
||||
|
||||
scanner := scannerValue.Interface().(sql.Scanner)
|
||||
|
||||
err := scanner.Scan(cellValue)
|
||||
|
||||
|
|
@ -437,13 +443,13 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
|
|||
}
|
||||
} else if !isDbBaseType(field.Type) {
|
||||
//var fieldValueInterface interface{}
|
||||
err := mapRowToDestinationValue(scanContext, groupKey, typesProcessed, row, fieldValue, &field)
|
||||
err := mapRowToDestinationValue(scanContext, groupKey, typesProcessed, fieldValue, &field)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
cellValue := getCellValue(scanContext, tableName, fieldName, row)
|
||||
cellValue := getCellValue(scanContext, tableName, fieldName)
|
||||
//spew.Dump(cellValue)
|
||||
|
||||
//spew.Dump(rowColumnValue, fieldValue)
|
||||
|
|
@ -456,16 +462,25 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
|
|||
return nil
|
||||
}
|
||||
|
||||
func initializeValue(value reflect.Value) {
|
||||
newValuePtr := reflect.New(value.Type().Elem())
|
||||
func implementsScanner(value reflect.Value) (reflect.Value, bool) {
|
||||
if _, ok := value.Interface().(sql.Scanner); ok {
|
||||
return value, true
|
||||
} else if value.CanAddr() {
|
||||
if _, ok := value.Addr().Interface().(sql.Scanner); ok {
|
||||
return value.Addr(), true
|
||||
}
|
||||
}
|
||||
|
||||
return value, false
|
||||
}
|
||||
|
||||
func initializePtrValue(value reflect.Value) {
|
||||
if value.Kind() == reflect.Ptr {
|
||||
value.Set(newValuePtr)
|
||||
} else {
|
||||
value.Set(newValuePtr.Elem())
|
||||
value.Set(reflect.New(value.Type().Elem()))
|
||||
}
|
||||
}
|
||||
|
||||
func getCellValue(scanContext *scanContext, tableName, fieldName string, row []interface{}) interface{} {
|
||||
func getCellValue(scanContext *scanContext, tableName, fieldName string) interface{} {
|
||||
columnName := tableName + "." + snaker.CamelToSnake(fieldName)
|
||||
//columnName := snaker.CamelToSnake(fieldName)
|
||||
|
||||
|
|
@ -476,7 +491,7 @@ func getCellValue(scanContext *scanContext, tableName, fieldName string, row []i
|
|||
return nil
|
||||
}
|
||||
|
||||
return cellValue(row, index)
|
||||
return cellValue(scanContext.row, index)
|
||||
}
|
||||
|
||||
func reflectValueToString(val interface{}) string {
|
||||
|
|
@ -585,7 +600,7 @@ func newScanType(columnType *sql.ColumnType) reflect.Type {
|
|||
return nullInt32Type
|
||||
case "INT8":
|
||||
return nullInt64Type
|
||||
case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR":
|
||||
case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "BYTEA", "UUID":
|
||||
return nullStringType
|
||||
case "FLOAT4":
|
||||
return nullFloatType
|
||||
|
|
@ -593,7 +608,7 @@ func newScanType(columnType *sql.ColumnType) reflect.Type {
|
|||
return nullFloat64Type
|
||||
case "BOOL":
|
||||
return nullBoolType
|
||||
case "DATE", "TIMESTAMP":
|
||||
case "DATE", "TIMESTAMP", "TIMESTAMPTZ":
|
||||
return nullTimeType
|
||||
default:
|
||||
panic("Unknown column database type " + columnType.DatabaseTypeName())
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ type StringExpression interface {
|
|||
Eq(expression StringExpression) BoolExpression
|
||||
EqL(value string) BoolExpression
|
||||
NotEq(expression StringExpression) BoolExpression
|
||||
NotEqL(value string) BoolExpression
|
||||
}
|
||||
|
||||
type stringInterfaceImpl struct {
|
||||
|
|
@ -23,3 +24,7 @@ func (b *stringInterfaceImpl) EqL(value string) BoolExpression {
|
|||
func (b *stringInterfaceImpl) NotEq(expression StringExpression) BoolExpression {
|
||||
return newBinaryBoolExpression(b.parent, expression, []byte(" != "))
|
||||
}
|
||||
|
||||
func (b *stringInterfaceImpl) NotEqL(value string) BoolExpression {
|
||||
return newBinaryBoolExpression(b.parent, Literal(value), []byte(" != "))
|
||||
}
|
||||
|
|
|
|||
57
tests/sample_test.go
Normal file
57
tests/sample_test.go
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/model"
|
||||
"github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/table"
|
||||
"gotest.tools/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUUIDType(t *testing.T) {
|
||||
query := table.AllTypes.
|
||||
SELECT(table.AllTypes.AllColumns).
|
||||
Where(table.AllTypes.UUID.EqL("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
|
||||
|
||||
queryStr, err := query.String()
|
||||
|
||||
assert.NilError(t, err)
|
||||
fmt.Println(queryStr)
|
||||
assert.Equal(t, queryStr, `SELECT all_types.character AS "all_types.character", all_types.character_varying AS "all_types.character_varying", all_types.text AS "all_types.text", all_types.bytea AS "all_types.bytea", all_types.timestamp_without_time_zone AS "all_types.timestamp_without_time_zone", all_types.timestamp_with_time_zone AS "all_types.timestamp_with_time_zone", all_types.uuid AS "all_types.uuid" FROM test_sample.all_types WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'`)
|
||||
|
||||
result := model.AllTypes{}
|
||||
|
||||
err = query.Execute(db, &result)
|
||||
spew.Dump(result)
|
||||
}
|
||||
|
||||
func TestEnumType(t *testing.T) {
|
||||
query := table.Person.
|
||||
SELECT(table.Person.AllColumns)
|
||||
|
||||
queryStr, err := query.String()
|
||||
|
||||
assert.NilError(t, err)
|
||||
fmt.Println(queryStr)
|
||||
|
||||
result := []model.Person{}
|
||||
|
||||
err = query.Execute(db, &result)
|
||||
|
||||
assert.NilError(t, err)
|
||||
//spew.Dump(result)
|
||||
|
||||
type Person struct {
|
||||
Name string
|
||||
CurrentMood model.Mood
|
||||
}
|
||||
|
||||
result2 := []Person{}
|
||||
|
||||
err = query.Execute(db, &result2)
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
//spew.Dump(result2)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue