diff --git a/generator/generator.go b/generator/generator.go index 53b7450..d69448f 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -16,6 +16,7 @@ type DbConnectInfo struct { } func Generate(folderPath string, connectString string, databaseName, schemaName string) error { + err := cleanUpGeneratedFiles(path.Join(folderPath, databaseName, schemaName)) if err != nil { diff --git a/generator/metadata/column_info.go b/generator/metadata/column_info.go index ba83d9f..7d54b90 100644 --- a/generator/metadata/column_info.go +++ b/generator/metadata/column_info.go @@ -2,6 +2,7 @@ package metadata import ( "database/sql" + "fmt" "github.com/serenize/snaker" ) @@ -37,16 +38,15 @@ func (c ColumnInfo) ToSqlBuilderColumnType() string { case "bigint": return "IntegerColumn" case "date", "timestamp without time zone", "timestamp with time zone": - return "StringColumn" - case "bytea": - return "StringColumn" - case "text": + return "TimeColumn" + case "text", "character", "character varying", "bytea": return "StringColumn" case "real": return "NumericColumn" case "numeric", "double precision": return "NumericColumn" default: + fmt.Println("Unknownl type: " + c.DataType + ", using string instead.") return "StringColumn" } } @@ -77,13 +77,14 @@ func (c ColumnInfo) GoBaseType() string { return "time.Time" case "bytea": return "[]byte" - case "text": + case "text", "character", "character varying": return "string" case "real": return "float32" case "numeric", "double precision": return "float64" default: + fmt.Println("Unknown go map type: " + c.DataType + ", using string instead.") return "string" } } diff --git a/sqlbuilder/column_types.go b/sqlbuilder/column_types.go index 724f6a2..5b2fa44 100644 --- a/sqlbuilder/column_types.go +++ b/sqlbuilder/column_types.go @@ -76,10 +76,28 @@ func NewStringColumn(name string, nullable NullableColumn) *StringColumn { stringColumn := &StringColumn{} - stringColumn.stringInterfaceImpl.parent = stringColumn stringColumn.stringInterfaceImpl.parent = stringColumn stringColumn.baseColumn = newBaseColumn(name, nullable, "", stringColumn) return stringColumn } + +//------------------------------------------------------// +type TimeColumn struct { + timeInterfaceImpl + + baseColumn +} + +// Representation of any integer column +// This function will panic if name is not valid +func NewTimeColumn(name string, nullable NullableColumn) *TimeColumn { + stringColumn := &TimeColumn{} + + stringColumn.timeInterfaceImpl.parent = stringColumn + + stringColumn.baseColumn = newBaseColumn(name, nullable, "", stringColumn) + + return stringColumn +} diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index 6463501..5a6db84 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -411,10 +411,11 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma for i := 0; i < structType.NumField(); i++ { field := structType.Field(i) - //fieldTypeName := field.Name + fieldValue := structValue.Field(i) - //fmt.Println("---------------", fieldTypeName) - ////spew.Dump(field.Type) + //fieldTypeName := field.Name + //fmt.Println("---------------", fieldTypeName,) + //spew.Dump(field.Type) fieldName := field.Name @@ -486,8 +487,8 @@ func isDbBaseType(objType reflect.Type) bool { typeStr := objType.String() switch typeStr { - case "string", "int32", "int16", "float32", "float64", "time.Time", "bool", - "*string", "*int32", "*int16", "*float32", "*float64", "*time.Time", "*bool": + case "string", "int32", "int16", "float32", "float64", "time.Time", "bool", "[]byte", "[]uint8", + "*string", "*int32", "*int16", "*float32", "*float64", "*time.Time", "*bool", "*[]byte", "*[]uint8": return true } diff --git a/sqlbuilder/time_expression.go b/sqlbuilder/time_expression.go new file mode 100644 index 0000000..c9b22d7 --- /dev/null +++ b/sqlbuilder/time_expression.go @@ -0,0 +1,50 @@ +package sqlbuilder + +type TimeExpression interface { + Expression + + Eq(expression TimeExpression) BoolExpression + EqL(literal string) BoolExpression + NotEq(expression TimeExpression) BoolExpression + NotEqL(literal string) BoolExpression + GtEq(rhs TimeExpression) BoolExpression + GtEqL(literal string) BoolExpression + LtEq(rhs TimeExpression) BoolExpression + LtEqL(literal string) BoolExpression +} + +type timeInterfaceImpl struct { + parent TimeExpression +} + +func (t *timeInterfaceImpl) Eq(expression TimeExpression) BoolExpression { + return Eq(t.parent, expression) +} + +func (t *timeInterfaceImpl) EqL(literal string) BoolExpression { + return Eq(t.parent, Literal(literal)) +} + +func (t *timeInterfaceImpl) NotEq(expression TimeExpression) BoolExpression { + return NotEq(t.parent, expression) +} + +func (t *timeInterfaceImpl) NotEqL(literal string) BoolExpression { + return NotEq(t.parent, Literal(literal)) +} + +func (t *timeInterfaceImpl) GtEq(expression TimeExpression) BoolExpression { + return GtEq(t.parent, expression) +} + +func (t *timeInterfaceImpl) GtEqL(literal string) BoolExpression { + return GtEq(t.parent, Literal(literal)) +} + +func (t *timeInterfaceImpl) LtEq(expression TimeExpression) BoolExpression { + return LtEq(t.parent, expression) +} + +func (t *timeInterfaceImpl) LtEqL(literal string) BoolExpression { + return LtEq(t.parent, Literal(literal)) +} diff --git a/tests/generator_test.go b/tests/generator_test.go index c3e3cae..63ca2e3 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -9,6 +9,7 @@ import ( . "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table" "gotest.tools/assert" "os" + "strings" "testing" "time" ) @@ -27,6 +28,7 @@ var connectString = fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=% var db *sql.DB //go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema dvds -path .test_files +//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema test_sample -path .test_files func TestMain(m *testing.M) { fmt.Println("Begin") @@ -74,7 +76,7 @@ func TestSelect_ScanToStruct(t *testing.T) { ActorID: 1, FirstName: "Penelope", LastName: "Guiness", - LastUpdate: *timeWithoutTimeZone("2013-05-26 14:47:57.62 +0000"), + LastUpdate: *timeWithoutTimeZone("2013-05-26 14:47:57.62", 2), } assert.DeepEqual(t, actor, expectedActor) @@ -501,7 +503,7 @@ func TestSelectQueryScalar(t *testing.T) { ReplacementCost: 12.99, Rating: stringPtr("G"), RentalDuration: 3, - LastUpdate: *timeWithoutTimeZone("2013-05-26 14:50:58.951 +0000"), + LastUpdate: *timeWithoutTimeZone("2013-05-26 14:50:58.951", 3), SpecialFeatures: stringPtr("{Trailers,\"Deleted Scenes\"}"), Fulltext: "'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14", }) @@ -582,13 +584,39 @@ func TestSelectGroupBy2(t *testing.T) { LastName: "Wyman", Email: stringPtr("brian.wyman@sakilacustomer.org"), Activebool: true, - CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00 +0000"), - LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738 +0000"), + CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), Active: int32Ptr(1), }) assert.Equal(t, customersWithAmounts[0].AmountSum, 27.93) +} +func TestSelectTimeColumns(t *testing.T) { + query := Payment.SELECT(Payment.AllColumns). + Where(Payment.PaymentDate.LtEqL("2007-02-14 22:16:01")). + OrderBy(Payment.PaymentDate.Asc()) + + queryStr, err := query.String() + + assert.NilError(t, err) + + fmt.Println(queryStr) + + payments := []model.Payment{} + + err = query.Execute(db, &payments) + + assert.NilError(t, err) + + //spew.Dump(payments) + + assert.Equal(t, len(payments), 9) + assert.DeepEqual(t, payments[0], model.Payment{ + PaymentID: 17793, + Amount: 2.99, + PaymentDate: *timeWithoutTimeZone("2007-02-14 21:21:59.996577", 6), + }) } func int16Ptr(i int16) *int16 { @@ -603,8 +631,15 @@ func stringPtr(s string) *string { return &s } -func timeWithoutTimeZone(t string) *time.Time { - time, err := time.Parse("2006-01-02 15:04:05 -0700", t) +func timeWithoutTimeZone(t string, precision int) *time.Time { + + precisionStr := "" + + if precision > 0 { + precisionStr = "." + strings.Repeat("9", precision) + } + + time, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000") if err != nil { panic(err) @@ -621,8 +656,8 @@ var customer0 = model.Customer{ Email: stringPtr("mary.smith@sakilacustomer.org"), Address: nil, Activebool: true, - CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00 +0000"), - LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738 +0000"), + CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), Active: int32Ptr(1), } @@ -634,8 +669,8 @@ var customer1 = model.Customer{ Email: stringPtr("patricia.johnson@sakilacustomer.org"), Address: nil, Activebool: true, - CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00 +0000"), - LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738 +0000"), + CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), Active: int32Ptr(1), } @@ -647,7 +682,7 @@ var lastCustomer = model.Customer{ Email: stringPtr("austin.cintron@sakilacustomer.org"), Address: nil, Activebool: true, - CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00 +0000"), - LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738 +0000"), + CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), Active: int32Ptr(1), }