Add support for database uuid types.

This commit is contained in:
zer0sub 2019-04-04 13:07:21 +02:00
parent 2c7a9f5058
commit 37b0a6445b
6 changed files with 126 additions and 44 deletions

View file

@ -40,14 +40,14 @@ func (c ColumnInfo) ToSqlBuilderColumnType() string {
return "IntegerColumn"
case "date", "timestamp without time zone", "timestamp with time zone":
return "TimeColumn"
case "text", "character", "character varying", "bytea":
case "text", "character", "character varying", "bytea", "uuid":
return "StringColumn"
case "real":
return "NumericColumn"
case "numeric", "double precision":
return "NumericColumn"
default:
fmt.Println("Unknownl type: " + c.DataType + ", using string instead.")
fmt.Println("Unknownl type: " + c.DataType + ", using string column instead.")
return "StringColumn"
}
}
@ -67,7 +67,7 @@ func (c ColumnInfo) GoBaseType() string {
} else {
switch c.DataType {
case "USER-DEFINED":
return c.EnumName
return snaker.SnakeToCamel(c.EnumName)
case "boolean":
return "bool"
case "smallint":
@ -86,8 +86,10 @@ func (c ColumnInfo) GoBaseType() string {
return "float32"
case "numeric", "double precision":
return "float64"
case "uuid":
return "uuid.UUID"
default:
fmt.Println("Unknown go map type: " + c.DataType + ", using string instead.")
fmt.Println("Unknown go map type: " + c.DataType + ", " + c.EnumName + ", using string instead.")
return "string"
}
}

View file

@ -21,8 +21,11 @@ func (t TableInfo) GetImports() []string {
for _, column := range t.Columns {
columnType := column.GoBaseType()
if columnType == "time.Time" {
switch columnType {
case "time.Time":
imports["time.Time"] = "time"
case "uuid.UUID":
imports["uuid.UUID"] = "github.com/google/uuid"
}
}

View file

@ -66,17 +66,17 @@ var EnumModelTemplate = `package model
import "errors"
type {{.Name}} string
type {{camelize $.Name}} string
const (
{{- range $index, $element := .Values}}
{{camelize $.Name}}_{{camelize $element}} {{$.Name}} = "{{$element}}"
{{camelize $.Name}}_{{camelize $element}} {{camelize $.Name}} = "{{$element}}"
{{- end}}
)
func (e *{{$.Name}}) Scan(value interface{}) error {
func (e *{{camelize $.Name}}) Scan(value interface{}) error {
if v, ok := value.(string); !ok {
return errors.New("Invalid data for {{$.Name}} enum")
return errors.New("Invalid data for {{camelize $.Name}} enum")
} else {
switch string(v) {
{{- range $index, $element := .Values}}
@ -84,7 +84,7 @@ func (e *{{$.Name}}) Scan(value interface{}) error {
*e = {{camelize $.Name}}_{{camelize $element}}
{{- end}}
default:
return errors.New("Inavlid data " + string(v) + "for {{$.Name}} enum")
return errors.New("Inavlid data " + string(v) + "for {{camelize $.Name}} enum")
}
return nil

View file

@ -36,9 +36,9 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
columnNames, _ := rows.Columns()
columnTypes, _ := rows.ColumnTypes()
rowData := createScanValue(columnTypes)
scanContext := &scanContext{
row: createScanValue(columnTypes),
columnNames: columnNames,
uniqueObjectsMap: make(map[string]interface{}),
}
@ -46,7 +46,7 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
//spew.Dump(columnTypes)
for rows.Next() {
err := rows.Scan(rowData...)
err := rows.Scan(scanContext.row...)
if err != nil {
return err
@ -55,13 +55,13 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
scanContext.rowNum++
if destinationType.Elem().Kind() == reflect.Slice {
err := mapRowToSlice(scanContext, "", map[string]bool{}, rowData, destinationPtr, nil)
err := mapRowToSlice(scanContext, "", map[string]bool{}, destinationPtr, nil)
if err != nil {
return err
}
} else if destinationType.Elem().Kind() == reflect.Struct {
return mapRowToStruct(scanContext, "", map[string]bool{}, rowData, destinationPtr, nil)
return mapRowToStruct(scanContext, "", map[string]bool{}, destinationPtr, nil)
}
}
@ -77,8 +77,10 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
}
type scanContext struct {
rowNum int
columnNames []string
rowNum int
columnNames []string
row []interface{}
uniqueObjectsMap map[string]interface{}
}
@ -112,7 +114,7 @@ func getType(reflectType reflect.Type) string {
return structType.Name()
}
func getGroupKey(scanContext *scanContext, row []interface{}, typesProcessed map[string]bool, structType reflect.Type, structField *reflect.StructField) string {
func getGroupKey(scanContext *scanContext, typesProcessed map[string]bool, structType reflect.Type, structField *reflect.StructField) string {
tableName := getTableAlias(structField)
//fmt.Println("Group: " + tableName)
@ -147,7 +149,7 @@ func getGroupKey(scanContext *scanContext, row []interface{}, typesProcessed map
//spew.Dump(structType)
structGroupKey := getGroupKey(scanContext, row, typesProcessed, structType, &field)
structGroupKey := getGroupKey(scanContext, typesProcessed, structType, &field)
//groupKey = strings.Join([]string{structGroupKey, groupKey}, ":")
@ -165,7 +167,7 @@ func getGroupKey(scanContext *scanContext, row []interface{}, typesProcessed map
continue
}
cellValue := cellValue(row, index)
cellValue := cellValue(scanContext.row, index)
subKey := reflectValueToString(cellValue)
if subKey != "" {
@ -227,12 +229,12 @@ func cloneProcessedMap(processedMap map[string]bool) map[string]bool {
return newMap
}
func mapRowToSlice(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, row []interface{}, destinationPtr interface{}, structField *reflect.StructField) error {
func mapRowToSlice(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, destinationPtr interface{}, structField *reflect.StructField) error {
var err error
structType := getSliceStructType(destinationPtr)
structGroupKey := getGroupKey(scanContext, row, cloneProcessedMap(typesProcessed), structType, structField)
structGroupKey := getGroupKey(scanContext, cloneProcessedMap(typesProcessed), structType, structField)
if structGroupKey == "" {
structGroupKey = "|ROW: " + strconv.Itoa(scanContext.rowNum) + "|"
@ -245,14 +247,14 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, typesProcessed map
objPtr, ok := scanContext.uniqueObjectsMap[groupKey]
if ok {
err = mapRowToStruct(scanContext, groupKey, typesProcessed, row, objPtr, structField)
err = mapRowToStruct(scanContext, groupKey, typesProcessed, objPtr, structField)
if err != nil {
return err
}
} else {
destinationStructPtr := newElemForSlice(destinationPtr)
err = mapRowToStruct(scanContext, groupKey, typesProcessed, row, destinationStructPtr, structField)
err = mapRowToStruct(scanContext, groupKey, typesProcessed, destinationStructPtr, structField)
if err != nil {
return err
@ -290,14 +292,14 @@ func newElemForSlice(destinationSlicePtr interface{}) interface{} {
return reflect.New(elemType).Interface()
}
func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, row []interface{}, dest reflect.Value, structField *reflect.StructField) error {
func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, dest reflect.Value, structField *reflect.StructField) error {
if dest.Kind() == reflect.Struct {
err := mapRowToStruct(scanContext, groupKey, typesProcessed, row, dest.Addr().Interface(), structField)
err := mapRowToStruct(scanContext, groupKey, typesProcessed, dest.Addr().Interface(), structField)
if err != nil {
return err
}
} else if dest.Kind() == reflect.Slice {
err := mapRowToSlice(scanContext, groupKey, typesProcessed, row, dest.Addr().Interface(), structField)
err := mapRowToSlice(scanContext, groupKey, typesProcessed, dest.Addr().Interface(), structField)
if err != nil {
return err
}
@ -313,7 +315,7 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesPr
return nil
}
err := mapRowToStruct(scanContext, groupKey, typesProcessed, row, structValuePtr.Interface(), structField)
err := mapRowToStruct(scanContext, groupKey, typesProcessed, structValuePtr.Interface(), structField)
if err != nil {
return err
}
@ -331,7 +333,7 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesPr
sliceValuePtr = dest
}
err := mapRowToSlice(scanContext, groupKey, typesProcessed, row, sliceValuePtr.Interface(), structField)
err := mapRowToSlice(scanContext, groupKey, typesProcessed, sliceValuePtr.Interface(), structField)
if err != nil {
return err
}
@ -390,7 +392,7 @@ func getTableAlias(structField *reflect.StructField) string {
return snaker.CamelToSnake(elemType)
}
func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, row []interface{}, destinationPtr interface{}, structField *reflect.StructField) error {
func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, destinationPtr interface{}, structField *reflect.StructField) error {
structType := reflect.TypeOf(destinationPtr).Elem()
structValue := reflect.ValueOf(destinationPtr).Elem()
@ -419,16 +421,20 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
fieldName := field.Name
if _, ok := fieldValue.Interface().(sql.Scanner); ok {
cellValue := getCellValue(scanContext, tableName, fieldName, row)
if scannerValue, ok := implementsScanner(fieldValue); ok {
cellValue := getCellValue(scanContext, tableName, fieldName)
if cellValue == nil {
continue
}
initializeValue(fieldValue)
//spew.Dump(scannerValue.Interface())
scanner := fieldValue.Interface().(sql.Scanner)
if scannerValue.IsNil() {
initializePtrValue(scannerValue)
}
scanner := scannerValue.Interface().(sql.Scanner)
err := scanner.Scan(cellValue)
@ -437,13 +443,13 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
}
} else if !isDbBaseType(field.Type) {
//var fieldValueInterface interface{}
err := mapRowToDestinationValue(scanContext, groupKey, typesProcessed, row, fieldValue, &field)
err := mapRowToDestinationValue(scanContext, groupKey, typesProcessed, fieldValue, &field)
if err != nil {
return err
}
} else {
cellValue := getCellValue(scanContext, tableName, fieldName, row)
cellValue := getCellValue(scanContext, tableName, fieldName)
//spew.Dump(cellValue)
//spew.Dump(rowColumnValue, fieldValue)
@ -456,16 +462,25 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
return nil
}
func initializeValue(value reflect.Value) {
newValuePtr := reflect.New(value.Type().Elem())
func implementsScanner(value reflect.Value) (reflect.Value, bool) {
if _, ok := value.Interface().(sql.Scanner); ok {
return value, true
} else if value.CanAddr() {
if _, ok := value.Addr().Interface().(sql.Scanner); ok {
return value.Addr(), true
}
}
return value, false
}
func initializePtrValue(value reflect.Value) {
if value.Kind() == reflect.Ptr {
value.Set(newValuePtr)
} else {
value.Set(newValuePtr.Elem())
value.Set(reflect.New(value.Type().Elem()))
}
}
func getCellValue(scanContext *scanContext, tableName, fieldName string, row []interface{}) interface{} {
func getCellValue(scanContext *scanContext, tableName, fieldName string) interface{} {
columnName := tableName + "." + snaker.CamelToSnake(fieldName)
//columnName := snaker.CamelToSnake(fieldName)
@ -476,7 +491,7 @@ func getCellValue(scanContext *scanContext, tableName, fieldName string, row []i
return nil
}
return cellValue(row, index)
return cellValue(scanContext.row, index)
}
func reflectValueToString(val interface{}) string {
@ -585,7 +600,7 @@ func newScanType(columnType *sql.ColumnType) reflect.Type {
return nullInt32Type
case "INT8":
return nullInt64Type
case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR":
case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "BYTEA", "UUID":
return nullStringType
case "FLOAT4":
return nullFloatType
@ -593,7 +608,7 @@ func newScanType(columnType *sql.ColumnType) reflect.Type {
return nullFloat64Type
case "BOOL":
return nullBoolType
case "DATE", "TIMESTAMP":
case "DATE", "TIMESTAMP", "TIMESTAMPTZ":
return nullTimeType
default:
panic("Unknown column database type " + columnType.DatabaseTypeName())

View file

@ -6,6 +6,7 @@ type StringExpression interface {
Eq(expression StringExpression) BoolExpression
EqL(value string) BoolExpression
NotEq(expression StringExpression) BoolExpression
NotEqL(value string) BoolExpression
}
type stringInterfaceImpl struct {
@ -23,3 +24,7 @@ func (b *stringInterfaceImpl) EqL(value string) BoolExpression {
func (b *stringInterfaceImpl) NotEq(expression StringExpression) BoolExpression {
return newBinaryBoolExpression(b.parent, expression, []byte(" != "))
}
func (b *stringInterfaceImpl) NotEqL(value string) BoolExpression {
return newBinaryBoolExpression(b.parent, Literal(value), []byte(" != "))
}

57
tests/sample_test.go Normal file
View file

@ -0,0 +1,57 @@
package tests
import (
"fmt"
"github.com/davecgh/go-spew/spew"
"github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/model"
"github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/table"
"gotest.tools/assert"
"testing"
)
func TestUUIDType(t *testing.T) {
query := table.AllTypes.
SELECT(table.AllTypes.AllColumns).
Where(table.AllTypes.UUID.EqL("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
queryStr, err := query.String()
assert.NilError(t, err)
fmt.Println(queryStr)
assert.Equal(t, queryStr, `SELECT all_types.character AS "all_types.character", all_types.character_varying AS "all_types.character_varying", all_types.text AS "all_types.text", all_types.bytea AS "all_types.bytea", all_types.timestamp_without_time_zone AS "all_types.timestamp_without_time_zone", all_types.timestamp_with_time_zone AS "all_types.timestamp_with_time_zone", all_types.uuid AS "all_types.uuid" FROM test_sample.all_types WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'`)
result := model.AllTypes{}
err = query.Execute(db, &result)
spew.Dump(result)
}
func TestEnumType(t *testing.T) {
query := table.Person.
SELECT(table.Person.AllColumns)
queryStr, err := query.String()
assert.NilError(t, err)
fmt.Println(queryStr)
result := []model.Person{}
err = query.Execute(db, &result)
assert.NilError(t, err)
//spew.Dump(result)
type Person struct {
Name string
CurrentMood model.Mood
}
result2 := []Person{}
err = query.Execute(db, &result2)
assert.NilError(t, err)
//spew.Dump(result2)
}