Add support for database uuid types.

This commit is contained in:
zer0sub 2019-04-04 13:07:21 +02:00
parent 2c7a9f5058
commit 37b0a6445b
6 changed files with 126 additions and 44 deletions

View file

@ -40,14 +40,14 @@ func (c ColumnInfo) ToSqlBuilderColumnType() string {
return "IntegerColumn" return "IntegerColumn"
case "date", "timestamp without time zone", "timestamp with time zone": case "date", "timestamp without time zone", "timestamp with time zone":
return "TimeColumn" return "TimeColumn"
case "text", "character", "character varying", "bytea": case "text", "character", "character varying", "bytea", "uuid":
return "StringColumn" return "StringColumn"
case "real": case "real":
return "NumericColumn" return "NumericColumn"
case "numeric", "double precision": case "numeric", "double precision":
return "NumericColumn" return "NumericColumn"
default: default:
fmt.Println("Unknownl type: " + c.DataType + ", using string instead.") fmt.Println("Unknownl type: " + c.DataType + ", using string column instead.")
return "StringColumn" return "StringColumn"
} }
} }
@ -67,7 +67,7 @@ func (c ColumnInfo) GoBaseType() string {
} else { } else {
switch c.DataType { switch c.DataType {
case "USER-DEFINED": case "USER-DEFINED":
return c.EnumName return snaker.SnakeToCamel(c.EnumName)
case "boolean": case "boolean":
return "bool" return "bool"
case "smallint": case "smallint":
@ -86,8 +86,10 @@ func (c ColumnInfo) GoBaseType() string {
return "float32" return "float32"
case "numeric", "double precision": case "numeric", "double precision":
return "float64" return "float64"
case "uuid":
return "uuid.UUID"
default: default:
fmt.Println("Unknown go map type: " + c.DataType + ", using string instead.") fmt.Println("Unknown go map type: " + c.DataType + ", " + c.EnumName + ", using string instead.")
return "string" return "string"
} }
} }

View file

@ -21,8 +21,11 @@ func (t TableInfo) GetImports() []string {
for _, column := range t.Columns { for _, column := range t.Columns {
columnType := column.GoBaseType() columnType := column.GoBaseType()
if columnType == "time.Time" { switch columnType {
case "time.Time":
imports["time.Time"] = "time" imports["time.Time"] = "time"
case "uuid.UUID":
imports["uuid.UUID"] = "github.com/google/uuid"
} }
} }

View file

@ -66,17 +66,17 @@ var EnumModelTemplate = `package model
import "errors" import "errors"
type {{.Name}} string type {{camelize $.Name}} string
const ( const (
{{- range $index, $element := .Values}} {{- range $index, $element := .Values}}
{{camelize $.Name}}_{{camelize $element}} {{$.Name}} = "{{$element}}" {{camelize $.Name}}_{{camelize $element}} {{camelize $.Name}} = "{{$element}}"
{{- end}} {{- end}}
) )
func (e *{{$.Name}}) Scan(value interface{}) error { func (e *{{camelize $.Name}}) Scan(value interface{}) error {
if v, ok := value.(string); !ok { if v, ok := value.(string); !ok {
return errors.New("Invalid data for {{$.Name}} enum") return errors.New("Invalid data for {{camelize $.Name}} enum")
} else { } else {
switch string(v) { switch string(v) {
{{- range $index, $element := .Values}} {{- range $index, $element := .Values}}
@ -84,7 +84,7 @@ func (e *{{$.Name}}) Scan(value interface{}) error {
*e = {{camelize $.Name}}_{{camelize $element}} *e = {{camelize $.Name}}_{{camelize $element}}
{{- end}} {{- end}}
default: default:
return errors.New("Inavlid data " + string(v) + "for {{$.Name}} enum") return errors.New("Inavlid data " + string(v) + "for {{camelize $.Name}} enum")
} }
return nil return nil

View file

@ -36,9 +36,9 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
columnNames, _ := rows.Columns() columnNames, _ := rows.Columns()
columnTypes, _ := rows.ColumnTypes() columnTypes, _ := rows.ColumnTypes()
rowData := createScanValue(columnTypes)
scanContext := &scanContext{ scanContext := &scanContext{
row: createScanValue(columnTypes),
columnNames: columnNames, columnNames: columnNames,
uniqueObjectsMap: make(map[string]interface{}), uniqueObjectsMap: make(map[string]interface{}),
} }
@ -46,7 +46,7 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
//spew.Dump(columnTypes) //spew.Dump(columnTypes)
for rows.Next() { for rows.Next() {
err := rows.Scan(rowData...) err := rows.Scan(scanContext.row...)
if err != nil { if err != nil {
return err return err
@ -55,13 +55,13 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
scanContext.rowNum++ scanContext.rowNum++
if destinationType.Elem().Kind() == reflect.Slice { if destinationType.Elem().Kind() == reflect.Slice {
err := mapRowToSlice(scanContext, "", map[string]bool{}, rowData, destinationPtr, nil) err := mapRowToSlice(scanContext, "", map[string]bool{}, destinationPtr, nil)
if err != nil { if err != nil {
return err return err
} }
} else if destinationType.Elem().Kind() == reflect.Struct { } else if destinationType.Elem().Kind() == reflect.Struct {
return mapRowToStruct(scanContext, "", map[string]bool{}, rowData, destinationPtr, nil) return mapRowToStruct(scanContext, "", map[string]bool{}, destinationPtr, nil)
} }
} }
@ -77,8 +77,10 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
} }
type scanContext struct { type scanContext struct {
rowNum int rowNum int
columnNames []string columnNames []string
row []interface{}
uniqueObjectsMap map[string]interface{} uniqueObjectsMap map[string]interface{}
} }
@ -112,7 +114,7 @@ func getType(reflectType reflect.Type) string {
return structType.Name() return structType.Name()
} }
func getGroupKey(scanContext *scanContext, row []interface{}, typesProcessed map[string]bool, structType reflect.Type, structField *reflect.StructField) string { func getGroupKey(scanContext *scanContext, typesProcessed map[string]bool, structType reflect.Type, structField *reflect.StructField) string {
tableName := getTableAlias(structField) tableName := getTableAlias(structField)
//fmt.Println("Group: " + tableName) //fmt.Println("Group: " + tableName)
@ -147,7 +149,7 @@ func getGroupKey(scanContext *scanContext, row []interface{}, typesProcessed map
//spew.Dump(structType) //spew.Dump(structType)
structGroupKey := getGroupKey(scanContext, row, typesProcessed, structType, &field) structGroupKey := getGroupKey(scanContext, typesProcessed, structType, &field)
//groupKey = strings.Join([]string{structGroupKey, groupKey}, ":") //groupKey = strings.Join([]string{structGroupKey, groupKey}, ":")
@ -165,7 +167,7 @@ func getGroupKey(scanContext *scanContext, row []interface{}, typesProcessed map
continue continue
} }
cellValue := cellValue(row, index) cellValue := cellValue(scanContext.row, index)
subKey := reflectValueToString(cellValue) subKey := reflectValueToString(cellValue)
if subKey != "" { if subKey != "" {
@ -227,12 +229,12 @@ func cloneProcessedMap(processedMap map[string]bool) map[string]bool {
return newMap return newMap
} }
func mapRowToSlice(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, row []interface{}, destinationPtr interface{}, structField *reflect.StructField) error { func mapRowToSlice(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, destinationPtr interface{}, structField *reflect.StructField) error {
var err error var err error
structType := getSliceStructType(destinationPtr) structType := getSliceStructType(destinationPtr)
structGroupKey := getGroupKey(scanContext, row, cloneProcessedMap(typesProcessed), structType, structField) structGroupKey := getGroupKey(scanContext, cloneProcessedMap(typesProcessed), structType, structField)
if structGroupKey == "" { if structGroupKey == "" {
structGroupKey = "|ROW: " + strconv.Itoa(scanContext.rowNum) + "|" structGroupKey = "|ROW: " + strconv.Itoa(scanContext.rowNum) + "|"
@ -245,14 +247,14 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, typesProcessed map
objPtr, ok := scanContext.uniqueObjectsMap[groupKey] objPtr, ok := scanContext.uniqueObjectsMap[groupKey]
if ok { if ok {
err = mapRowToStruct(scanContext, groupKey, typesProcessed, row, objPtr, structField) err = mapRowToStruct(scanContext, groupKey, typesProcessed, objPtr, structField)
if err != nil { if err != nil {
return err return err
} }
} else { } else {
destinationStructPtr := newElemForSlice(destinationPtr) destinationStructPtr := newElemForSlice(destinationPtr)
err = mapRowToStruct(scanContext, groupKey, typesProcessed, row, destinationStructPtr, structField) err = mapRowToStruct(scanContext, groupKey, typesProcessed, destinationStructPtr, structField)
if err != nil { if err != nil {
return err return err
@ -290,14 +292,14 @@ func newElemForSlice(destinationSlicePtr interface{}) interface{} {
return reflect.New(elemType).Interface() return reflect.New(elemType).Interface()
} }
func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, row []interface{}, dest reflect.Value, structField *reflect.StructField) error { func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, dest reflect.Value, structField *reflect.StructField) error {
if dest.Kind() == reflect.Struct { if dest.Kind() == reflect.Struct {
err := mapRowToStruct(scanContext, groupKey, typesProcessed, row, dest.Addr().Interface(), structField) err := mapRowToStruct(scanContext, groupKey, typesProcessed, dest.Addr().Interface(), structField)
if err != nil { if err != nil {
return err return err
} }
} else if dest.Kind() == reflect.Slice { } else if dest.Kind() == reflect.Slice {
err := mapRowToSlice(scanContext, groupKey, typesProcessed, row, dest.Addr().Interface(), structField) err := mapRowToSlice(scanContext, groupKey, typesProcessed, dest.Addr().Interface(), structField)
if err != nil { if err != nil {
return err return err
} }
@ -313,7 +315,7 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesPr
return nil return nil
} }
err := mapRowToStruct(scanContext, groupKey, typesProcessed, row, structValuePtr.Interface(), structField) err := mapRowToStruct(scanContext, groupKey, typesProcessed, structValuePtr.Interface(), structField)
if err != nil { if err != nil {
return err return err
} }
@ -331,7 +333,7 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesPr
sliceValuePtr = dest sliceValuePtr = dest
} }
err := mapRowToSlice(scanContext, groupKey, typesProcessed, row, sliceValuePtr.Interface(), structField) err := mapRowToSlice(scanContext, groupKey, typesProcessed, sliceValuePtr.Interface(), structField)
if err != nil { if err != nil {
return err return err
} }
@ -390,7 +392,7 @@ func getTableAlias(structField *reflect.StructField) string {
return snaker.CamelToSnake(elemType) return snaker.CamelToSnake(elemType)
} }
func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, row []interface{}, destinationPtr interface{}, structField *reflect.StructField) error { func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, destinationPtr interface{}, structField *reflect.StructField) error {
structType := reflect.TypeOf(destinationPtr).Elem() structType := reflect.TypeOf(destinationPtr).Elem()
structValue := reflect.ValueOf(destinationPtr).Elem() structValue := reflect.ValueOf(destinationPtr).Elem()
@ -419,16 +421,20 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
fieldName := field.Name fieldName := field.Name
if _, ok := fieldValue.Interface().(sql.Scanner); ok { if scannerValue, ok := implementsScanner(fieldValue); ok {
cellValue := getCellValue(scanContext, tableName, fieldName, row) cellValue := getCellValue(scanContext, tableName, fieldName)
if cellValue == nil { if cellValue == nil {
continue continue
} }
initializeValue(fieldValue) //spew.Dump(scannerValue.Interface())
scanner := fieldValue.Interface().(sql.Scanner) if scannerValue.IsNil() {
initializePtrValue(scannerValue)
}
scanner := scannerValue.Interface().(sql.Scanner)
err := scanner.Scan(cellValue) err := scanner.Scan(cellValue)
@ -437,13 +443,13 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
} }
} else if !isDbBaseType(field.Type) { } else if !isDbBaseType(field.Type) {
//var fieldValueInterface interface{} //var fieldValueInterface interface{}
err := mapRowToDestinationValue(scanContext, groupKey, typesProcessed, row, fieldValue, &field) err := mapRowToDestinationValue(scanContext, groupKey, typesProcessed, fieldValue, &field)
if err != nil { if err != nil {
return err return err
} }
} else { } else {
cellValue := getCellValue(scanContext, tableName, fieldName, row) cellValue := getCellValue(scanContext, tableName, fieldName)
//spew.Dump(cellValue) //spew.Dump(cellValue)
//spew.Dump(rowColumnValue, fieldValue) //spew.Dump(rowColumnValue, fieldValue)
@ -456,16 +462,25 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
return nil return nil
} }
func initializeValue(value reflect.Value) { func implementsScanner(value reflect.Value) (reflect.Value, bool) {
newValuePtr := reflect.New(value.Type().Elem()) if _, ok := value.Interface().(sql.Scanner); ok {
return value, true
} else if value.CanAddr() {
if _, ok := value.Addr().Interface().(sql.Scanner); ok {
return value.Addr(), true
}
}
return value, false
}
func initializePtrValue(value reflect.Value) {
if value.Kind() == reflect.Ptr { if value.Kind() == reflect.Ptr {
value.Set(newValuePtr) value.Set(reflect.New(value.Type().Elem()))
} else {
value.Set(newValuePtr.Elem())
} }
} }
func getCellValue(scanContext *scanContext, tableName, fieldName string, row []interface{}) interface{} { func getCellValue(scanContext *scanContext, tableName, fieldName string) interface{} {
columnName := tableName + "." + snaker.CamelToSnake(fieldName) columnName := tableName + "." + snaker.CamelToSnake(fieldName)
//columnName := snaker.CamelToSnake(fieldName) //columnName := snaker.CamelToSnake(fieldName)
@ -476,7 +491,7 @@ func getCellValue(scanContext *scanContext, tableName, fieldName string, row []i
return nil return nil
} }
return cellValue(row, index) return cellValue(scanContext.row, index)
} }
func reflectValueToString(val interface{}) string { func reflectValueToString(val interface{}) string {
@ -585,7 +600,7 @@ func newScanType(columnType *sql.ColumnType) reflect.Type {
return nullInt32Type return nullInt32Type
case "INT8": case "INT8":
return nullInt64Type return nullInt64Type
case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR": case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "BYTEA", "UUID":
return nullStringType return nullStringType
case "FLOAT4": case "FLOAT4":
return nullFloatType return nullFloatType
@ -593,7 +608,7 @@ func newScanType(columnType *sql.ColumnType) reflect.Type {
return nullFloat64Type return nullFloat64Type
case "BOOL": case "BOOL":
return nullBoolType return nullBoolType
case "DATE", "TIMESTAMP": case "DATE", "TIMESTAMP", "TIMESTAMPTZ":
return nullTimeType return nullTimeType
default: default:
panic("Unknown column database type " + columnType.DatabaseTypeName()) panic("Unknown column database type " + columnType.DatabaseTypeName())

View file

@ -6,6 +6,7 @@ type StringExpression interface {
Eq(expression StringExpression) BoolExpression Eq(expression StringExpression) BoolExpression
EqL(value string) BoolExpression EqL(value string) BoolExpression
NotEq(expression StringExpression) BoolExpression NotEq(expression StringExpression) BoolExpression
NotEqL(value string) BoolExpression
} }
type stringInterfaceImpl struct { type stringInterfaceImpl struct {
@ -23,3 +24,7 @@ func (b *stringInterfaceImpl) EqL(value string) BoolExpression {
func (b *stringInterfaceImpl) NotEq(expression StringExpression) BoolExpression { func (b *stringInterfaceImpl) NotEq(expression StringExpression) BoolExpression {
return newBinaryBoolExpression(b.parent, expression, []byte(" != ")) return newBinaryBoolExpression(b.parent, expression, []byte(" != "))
} }
func (b *stringInterfaceImpl) NotEqL(value string) BoolExpression {
return newBinaryBoolExpression(b.parent, Literal(value), []byte(" != "))
}

57
tests/sample_test.go Normal file
View file

@ -0,0 +1,57 @@
package tests
import (
"fmt"
"github.com/davecgh/go-spew/spew"
"github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/model"
"github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/table"
"gotest.tools/assert"
"testing"
)
func TestUUIDType(t *testing.T) {
query := table.AllTypes.
SELECT(table.AllTypes.AllColumns).
Where(table.AllTypes.UUID.EqL("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
queryStr, err := query.String()
assert.NilError(t, err)
fmt.Println(queryStr)
assert.Equal(t, queryStr, `SELECT all_types.character AS "all_types.character", all_types.character_varying AS "all_types.character_varying", all_types.text AS "all_types.text", all_types.bytea AS "all_types.bytea", all_types.timestamp_without_time_zone AS "all_types.timestamp_without_time_zone", all_types.timestamp_with_time_zone AS "all_types.timestamp_with_time_zone", all_types.uuid AS "all_types.uuid" FROM test_sample.all_types WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'`)
result := model.AllTypes{}
err = query.Execute(db, &result)
spew.Dump(result)
}
func TestEnumType(t *testing.T) {
query := table.Person.
SELECT(table.Person.AllColumns)
queryStr, err := query.String()
assert.NilError(t, err)
fmt.Println(queryStr)
result := []model.Person{}
err = query.Execute(db, &result)
assert.NilError(t, err)
//spew.Dump(result)
type Person struct {
Name string
CurrentMood model.Mood
}
result2 := []Person{}
err = query.Execute(db, &result2)
assert.NilError(t, err)
//spew.Dump(result2)
}