diff --git a/generator/metadata/column_info.go b/generator/metadata/column_info.go index 4197ae8..452ad69 100644 --- a/generator/metadata/column_info.go +++ b/generator/metadata/column_info.go @@ -40,14 +40,14 @@ func (c ColumnInfo) ToSqlBuilderColumnType() string { return "IntegerColumn" case "date", "timestamp without time zone", "timestamp with time zone": return "TimeColumn" - case "text", "character", "character varying", "bytea": + case "text", "character", "character varying", "bytea", "uuid": return "StringColumn" case "real": return "NumericColumn" case "numeric", "double precision": return "NumericColumn" default: - fmt.Println("Unknownl type: " + c.DataType + ", using string instead.") + fmt.Println("Unknownl type: " + c.DataType + ", using string column instead.") return "StringColumn" } } @@ -67,7 +67,7 @@ func (c ColumnInfo) GoBaseType() string { } else { switch c.DataType { case "USER-DEFINED": - return c.EnumName + return snaker.SnakeToCamel(c.EnumName) case "boolean": return "bool" case "smallint": @@ -86,8 +86,10 @@ func (c ColumnInfo) GoBaseType() string { return "float32" case "numeric", "double precision": return "float64" + case "uuid": + return "uuid.UUID" default: - fmt.Println("Unknown go map type: " + c.DataType + ", using string instead.") + fmt.Println("Unknown go map type: " + c.DataType + ", " + c.EnumName + ", using string instead.") return "string" } } diff --git a/generator/metadata/table_info.go b/generator/metadata/table_info.go index e3c6d77..554e835 100644 --- a/generator/metadata/table_info.go +++ b/generator/metadata/table_info.go @@ -21,8 +21,11 @@ func (t TableInfo) GetImports() []string { for _, column := range t.Columns { columnType := column.GoBaseType() - if columnType == "time.Time" { + switch columnType { + case "time.Time": imports["time.Time"] = "time" + case "uuid.UUID": + imports["uuid.UUID"] = "github.com/google/uuid" } } diff --git a/generator/templates.go b/generator/templates.go index d8672dc..eb9e836 100644 --- a/generator/templates.go +++ b/generator/templates.go @@ -66,17 +66,17 @@ var EnumModelTemplate = `package model import "errors" -type {{.Name}} string +type {{camelize $.Name}} string const ( {{- range $index, $element := .Values}} - {{camelize $.Name}}_{{camelize $element}} {{$.Name}} = "{{$element}}" + {{camelize $.Name}}_{{camelize $element}} {{camelize $.Name}} = "{{$element}}" {{- end}} ) -func (e *{{$.Name}}) Scan(value interface{}) error { +func (e *{{camelize $.Name}}) Scan(value interface{}) error { if v, ok := value.(string); !ok { - return errors.New("Invalid data for {{$.Name}} enum") + return errors.New("Invalid data for {{camelize $.Name}} enum") } else { switch string(v) { {{- range $index, $element := .Values}} @@ -84,7 +84,7 @@ func (e *{{$.Name}}) Scan(value interface{}) error { *e = {{camelize $.Name}}_{{camelize $element}} {{- end}} default: - return errors.New("Inavlid data " + string(v) + "for {{$.Name}} enum") + return errors.New("Inavlid data " + string(v) + "for {{camelize $.Name}} enum") } return nil diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index 4bcbbcc..6ebce36 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -36,9 +36,9 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error { columnNames, _ := rows.Columns() columnTypes, _ := rows.ColumnTypes() - rowData := createScanValue(columnTypes) scanContext := &scanContext{ + row: createScanValue(columnTypes), columnNames: columnNames, uniqueObjectsMap: make(map[string]interface{}), } @@ -46,7 +46,7 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error { //spew.Dump(columnTypes) for rows.Next() { - err := rows.Scan(rowData...) + err := rows.Scan(scanContext.row...) if err != nil { return err @@ -55,13 +55,13 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error { scanContext.rowNum++ if destinationType.Elem().Kind() == reflect.Slice { - err := mapRowToSlice(scanContext, "", map[string]bool{}, rowData, destinationPtr, nil) + err := mapRowToSlice(scanContext, "", map[string]bool{}, destinationPtr, nil) if err != nil { return err } } else if destinationType.Elem().Kind() == reflect.Struct { - return mapRowToStruct(scanContext, "", map[string]bool{}, rowData, destinationPtr, nil) + return mapRowToStruct(scanContext, "", map[string]bool{}, destinationPtr, nil) } } @@ -77,8 +77,10 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error { } type scanContext struct { - rowNum int - columnNames []string + rowNum int + columnNames []string + + row []interface{} uniqueObjectsMap map[string]interface{} } @@ -112,7 +114,7 @@ func getType(reflectType reflect.Type) string { return structType.Name() } -func getGroupKey(scanContext *scanContext, row []interface{}, typesProcessed map[string]bool, structType reflect.Type, structField *reflect.StructField) string { +func getGroupKey(scanContext *scanContext, typesProcessed map[string]bool, structType reflect.Type, structField *reflect.StructField) string { tableName := getTableAlias(structField) //fmt.Println("Group: " + tableName) @@ -147,7 +149,7 @@ func getGroupKey(scanContext *scanContext, row []interface{}, typesProcessed map //spew.Dump(structType) - structGroupKey := getGroupKey(scanContext, row, typesProcessed, structType, &field) + structGroupKey := getGroupKey(scanContext, typesProcessed, structType, &field) //groupKey = strings.Join([]string{structGroupKey, groupKey}, ":") @@ -165,7 +167,7 @@ func getGroupKey(scanContext *scanContext, row []interface{}, typesProcessed map continue } - cellValue := cellValue(row, index) + cellValue := cellValue(scanContext.row, index) subKey := reflectValueToString(cellValue) if subKey != "" { @@ -227,12 +229,12 @@ func cloneProcessedMap(processedMap map[string]bool) map[string]bool { return newMap } -func mapRowToSlice(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, row []interface{}, destinationPtr interface{}, structField *reflect.StructField) error { +func mapRowToSlice(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, destinationPtr interface{}, structField *reflect.StructField) error { var err error structType := getSliceStructType(destinationPtr) - structGroupKey := getGroupKey(scanContext, row, cloneProcessedMap(typesProcessed), structType, structField) + structGroupKey := getGroupKey(scanContext, cloneProcessedMap(typesProcessed), structType, structField) if structGroupKey == "" { structGroupKey = "|ROW: " + strconv.Itoa(scanContext.rowNum) + "|" @@ -245,14 +247,14 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, typesProcessed map objPtr, ok := scanContext.uniqueObjectsMap[groupKey] if ok { - err = mapRowToStruct(scanContext, groupKey, typesProcessed, row, objPtr, structField) + err = mapRowToStruct(scanContext, groupKey, typesProcessed, objPtr, structField) if err != nil { return err } } else { destinationStructPtr := newElemForSlice(destinationPtr) - err = mapRowToStruct(scanContext, groupKey, typesProcessed, row, destinationStructPtr, structField) + err = mapRowToStruct(scanContext, groupKey, typesProcessed, destinationStructPtr, structField) if err != nil { return err @@ -290,14 +292,14 @@ func newElemForSlice(destinationSlicePtr interface{}) interface{} { return reflect.New(elemType).Interface() } -func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, row []interface{}, dest reflect.Value, structField *reflect.StructField) error { +func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, dest reflect.Value, structField *reflect.StructField) error { if dest.Kind() == reflect.Struct { - err := mapRowToStruct(scanContext, groupKey, typesProcessed, row, dest.Addr().Interface(), structField) + err := mapRowToStruct(scanContext, groupKey, typesProcessed, dest.Addr().Interface(), structField) if err != nil { return err } } else if dest.Kind() == reflect.Slice { - err := mapRowToSlice(scanContext, groupKey, typesProcessed, row, dest.Addr().Interface(), structField) + err := mapRowToSlice(scanContext, groupKey, typesProcessed, dest.Addr().Interface(), structField) if err != nil { return err } @@ -313,7 +315,7 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesPr return nil } - err := mapRowToStruct(scanContext, groupKey, typesProcessed, row, structValuePtr.Interface(), structField) + err := mapRowToStruct(scanContext, groupKey, typesProcessed, structValuePtr.Interface(), structField) if err != nil { return err } @@ -331,7 +333,7 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesPr sliceValuePtr = dest } - err := mapRowToSlice(scanContext, groupKey, typesProcessed, row, sliceValuePtr.Interface(), structField) + err := mapRowToSlice(scanContext, groupKey, typesProcessed, sliceValuePtr.Interface(), structField) if err != nil { return err } @@ -390,7 +392,7 @@ func getTableAlias(structField *reflect.StructField) string { return snaker.CamelToSnake(elemType) } -func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, row []interface{}, destinationPtr interface{}, structField *reflect.StructField) error { +func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, destinationPtr interface{}, structField *reflect.StructField) error { structType := reflect.TypeOf(destinationPtr).Elem() structValue := reflect.ValueOf(destinationPtr).Elem() @@ -419,16 +421,20 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma fieldName := field.Name - if _, ok := fieldValue.Interface().(sql.Scanner); ok { - cellValue := getCellValue(scanContext, tableName, fieldName, row) + if scannerValue, ok := implementsScanner(fieldValue); ok { + cellValue := getCellValue(scanContext, tableName, fieldName) if cellValue == nil { continue } - initializeValue(fieldValue) + //spew.Dump(scannerValue.Interface()) - scanner := fieldValue.Interface().(sql.Scanner) + if scannerValue.IsNil() { + initializePtrValue(scannerValue) + } + + scanner := scannerValue.Interface().(sql.Scanner) err := scanner.Scan(cellValue) @@ -437,13 +443,13 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma } } else if !isDbBaseType(field.Type) { //var fieldValueInterface interface{} - err := mapRowToDestinationValue(scanContext, groupKey, typesProcessed, row, fieldValue, &field) + err := mapRowToDestinationValue(scanContext, groupKey, typesProcessed, fieldValue, &field) if err != nil { return err } } else { - cellValue := getCellValue(scanContext, tableName, fieldName, row) + cellValue := getCellValue(scanContext, tableName, fieldName) //spew.Dump(cellValue) //spew.Dump(rowColumnValue, fieldValue) @@ -456,16 +462,25 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma return nil } -func initializeValue(value reflect.Value) { - newValuePtr := reflect.New(value.Type().Elem()) +func implementsScanner(value reflect.Value) (reflect.Value, bool) { + if _, ok := value.Interface().(sql.Scanner); ok { + return value, true + } else if value.CanAddr() { + if _, ok := value.Addr().Interface().(sql.Scanner); ok { + return value.Addr(), true + } + } + + return value, false +} + +func initializePtrValue(value reflect.Value) { if value.Kind() == reflect.Ptr { - value.Set(newValuePtr) - } else { - value.Set(newValuePtr.Elem()) + value.Set(reflect.New(value.Type().Elem())) } } -func getCellValue(scanContext *scanContext, tableName, fieldName string, row []interface{}) interface{} { +func getCellValue(scanContext *scanContext, tableName, fieldName string) interface{} { columnName := tableName + "." + snaker.CamelToSnake(fieldName) //columnName := snaker.CamelToSnake(fieldName) @@ -476,7 +491,7 @@ func getCellValue(scanContext *scanContext, tableName, fieldName string, row []i return nil } - return cellValue(row, index) + return cellValue(scanContext.row, index) } func reflectValueToString(val interface{}) string { @@ -585,7 +600,7 @@ func newScanType(columnType *sql.ColumnType) reflect.Type { return nullInt32Type case "INT8": return nullInt64Type - case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR": + case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR", "BYTEA", "UUID": return nullStringType case "FLOAT4": return nullFloatType @@ -593,7 +608,7 @@ func newScanType(columnType *sql.ColumnType) reflect.Type { return nullFloat64Type case "BOOL": return nullBoolType - case "DATE", "TIMESTAMP": + case "DATE", "TIMESTAMP", "TIMESTAMPTZ": return nullTimeType default: panic("Unknown column database type " + columnType.DatabaseTypeName()) diff --git a/sqlbuilder/string_expression.go b/sqlbuilder/string_expression.go index 9cbc4ab..34bbf75 100644 --- a/sqlbuilder/string_expression.go +++ b/sqlbuilder/string_expression.go @@ -6,6 +6,7 @@ type StringExpression interface { Eq(expression StringExpression) BoolExpression EqL(value string) BoolExpression NotEq(expression StringExpression) BoolExpression + NotEqL(value string) BoolExpression } type stringInterfaceImpl struct { @@ -23,3 +24,7 @@ func (b *stringInterfaceImpl) EqL(value string) BoolExpression { func (b *stringInterfaceImpl) NotEq(expression StringExpression) BoolExpression { return newBinaryBoolExpression(b.parent, expression, []byte(" != ")) } + +func (b *stringInterfaceImpl) NotEqL(value string) BoolExpression { + return newBinaryBoolExpression(b.parent, Literal(value), []byte(" != ")) +} diff --git a/tests/sample_test.go b/tests/sample_test.go new file mode 100644 index 0000000..31a8fcc --- /dev/null +++ b/tests/sample_test.go @@ -0,0 +1,57 @@ +package tests + +import ( + "fmt" + "github.com/davecgh/go-spew/spew" + "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/model" + "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/table" + "gotest.tools/assert" + "testing" +) + +func TestUUIDType(t *testing.T) { + query := table.AllTypes. + SELECT(table.AllTypes.AllColumns). + Where(table.AllTypes.UUID.EqL("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) + + queryStr, err := query.String() + + assert.NilError(t, err) + fmt.Println(queryStr) + assert.Equal(t, queryStr, `SELECT all_types.character AS "all_types.character", all_types.character_varying AS "all_types.character_varying", all_types.text AS "all_types.text", all_types.bytea AS "all_types.bytea", all_types.timestamp_without_time_zone AS "all_types.timestamp_without_time_zone", all_types.timestamp_with_time_zone AS "all_types.timestamp_with_time_zone", all_types.uuid AS "all_types.uuid" FROM test_sample.all_types WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'`) + + result := model.AllTypes{} + + err = query.Execute(db, &result) + spew.Dump(result) +} + +func TestEnumType(t *testing.T) { + query := table.Person. + SELECT(table.Person.AllColumns) + + queryStr, err := query.String() + + assert.NilError(t, err) + fmt.Println(queryStr) + + result := []model.Person{} + + err = query.Execute(db, &result) + + assert.NilError(t, err) + //spew.Dump(result) + + type Person struct { + Name string + CurrentMood model.Mood + } + + result2 := []Person{} + + err = query.Execute(db, &result2) + + assert.NilError(t, err) + + //spew.Dump(result2) +}