diff --git a/generator/postgres-metadata/column_info.go b/generator/postgres-metadata/column_info.go index d23a745..7d571b1 100644 --- a/generator/postgres-metadata/column_info.go +++ b/generator/postgres-metadata/column_info.go @@ -81,6 +81,20 @@ func (c ColumnInfo) GoModelType() string { return typeStr } +func (c ColumnInfo) GoModelTag(isPrimaryKey bool) string { + tags := []string{} + + if isPrimaryKey { + tags = append(tags, "primary_key") + } + + if len(tags) > 0 { + return "`sql:\"" + strings.Join(tags, ",") + "\"`" + } + + return "" +} + func getColumnInfos(db *sql.DB, dbName, schemaName, tableName string) ([]ColumnInfo, error) { query := ` diff --git a/generator/templates.go b/generator/templates.go index 8b181be..38eaaa9 100644 --- a/generator/templates.go +++ b/generator/templates.go @@ -81,7 +81,7 @@ import ( type {{camelize .Name}} struct { {{- range .Columns}} - {{camelize .Name}} {{.GoModelType}} {{if $.IsUnique .Name}}` + "`sql:\"unique\"`" + ` {{end}} + {{camelize .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsUnique .Name)}}" + ` {{- end}} } ` diff --git a/sqlbuilder/clause.go b/sqlbuilder/clause.go index a858a3c..3733ee6 100644 --- a/sqlbuilder/clause.go +++ b/sqlbuilder/clause.go @@ -3,6 +3,7 @@ package sqlbuilder import ( "bytes" "github.com/google/uuid" + "github.com/lib/pq" "strconv" "strings" "time" @@ -223,7 +224,7 @@ func ArgToString(value interface{}) string { case uuid.UUID: return stringQuote(bindVal.String()) case time.Time: - return stringQuote(bindVal.String()) + return stringQuote(string(pq.FormatTimestamp(bindVal))) default: return "[Unknown type]" } diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index ef09e7d..a3f0de7 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -131,8 +131,8 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl if isGoBaseType(sliceElemType) { index := 0 if structField != nil { - columnName := getRefTableNameFrom(structField) - index = getIndex(scanContext.columnNames, columnName) + tableName, columnName := getRefTableNameFrom(structField) + index = getIndex(scanContext.columnNames, tableName+"."+columnName) if index < 0 { return @@ -192,7 +192,7 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl } func getGroupKey(scanContext *scanContext, structType reflect.Type, structField *reflect.StructField) string { - tableName := getRefTableNameFrom(structField) + tableName, _ := getRefTableNameFrom(structField) if tableName == "" { tableName = snaker.CamelToSnake(structType.Name()) @@ -218,7 +218,7 @@ func getGroupKey(scanContext *scanContext, structType reflect.Type, structField if structGroupKey != "" { groupKeys = append(groupKeys, structGroupKey) } - } else if field.Tag == `sql:"unique"` { + } else if tagInfo(field.Tag.Get("sql")).isPrimaryKey { fieldName := field.Name columnName := tableName + "." + snaker.CamelToSnake(fieldName) @@ -293,11 +293,7 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value { destinationSliceType := slicePtrValue.Type().Elem() - elemType := destinationSliceType.Elem() - - if elemType.Kind() == reflect.Ptr { - return reflect.New(elemType.Elem()) - } + elemType := indirectType(destinationSliceType.Elem()) return reflect.New(elemType) } @@ -348,51 +344,37 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re return } -func getRefTableNameFrom(structField *reflect.StructField) string { +func getRefTableNameFrom(structField *reflect.StructField) (table, column string) { if structField == nil { - return "" + return } - tagOverwriteName := structField.Tag.Get("sqlbuilder") + sqlTag := structField.Tag.Get("sql") - if tagOverwriteName != "" { - return tagOverwriteName + if sqlTag != "" { + tagInfo := tagInfo(sqlTag) + + return tagInfo.table, tagInfo.column } if !structField.Anonymous { - return snaker.CamelToSnake(structField.Name) + return snaker.CamelToSnake(structField.Name), "" } - var elemType string + fieldType := indirectType(structField.Type) - if structField.Type.Kind() == reflect.Ptr { - elem := structField.Type.Elem() - if elem.Kind() == reflect.Struct { - elemType = elem.Name() - } else if elem.Kind() == reflect.Slice { - elemType = elem.Elem().Name() - } - } else { - if structField.Type.Kind() == reflect.Struct { - elemType = structField.Type.Name() - } else { - sliceElem := structField.Type.Elem() - if sliceElem.Kind() == reflect.Ptr { - elemType = sliceElem.Elem().Name() - } else { - elemType = sliceElem.Name() - } - } + if fieldType.Kind() == reflect.Slice { + fieldType = fieldType.Elem() } - return snaker.CamelToSnake(elemType) + return snaker.CamelToSnake(fieldType.Name()), "" } func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { structType := structPtrValue.Type().Elem() structValue := structPtrValue.Elem() - tableName := getRefTableNameFrom(structField) + tableName, _ := getRefTableNameFrom(structField) if tableName == "" { tableName = snaker.CamelToSnake(structType.Name()) @@ -689,3 +671,41 @@ func (s *scanContext) rowElemValuePtr(index int) reflect.Value { newElem.Elem().Set(rowElemValue) return newElem } + +type sqlTagInfo struct { + isPrimaryKey bool + table string + column string +} + +func tagInfo(tag string) sqlTagInfo { + sqlTags := strings.Split(tag, ",") + tagMap := map[string]string{} + + for _, tag := range sqlTags { + tagParts := strings.Split(tag, ":") + + tagKey := tagParts[0] + tagValue := "" + if len(tagParts) > 1 { + tagValue = tagParts[1] + } + + tagMap[tagKey] = tagValue + } + + _, isPrimaryKey := tagMap["primary_key"] + + return sqlTagInfo{ + isPrimaryKey: isPrimaryKey, + table: tagMap["table"], + column: tagMap["column"], + } +} + +func indirectType(reflectType reflect.Type) reflect.Type { + if reflectType.Kind() != reflect.Ptr { + return reflectType + } + return reflectType.Elem() +} diff --git a/tests/types_test.go b/tests/all_types_test.go similarity index 100% rename from tests/types_test.go rename to tests/all_types_test.go diff --git a/tests/dvdrental.tar b/tests/dvdrental.tar deleted file mode 100644 index ebcef4f..0000000 Binary files a/tests/dvdrental.tar and /dev/null differ diff --git a/tests/init/data/test_sample.sql b/tests/init/data/test_sample.sql index 54fd241..221122d 100644 --- a/tests/init/data/test_sample.sql +++ b/tests/init/data/test_sample.sql @@ -142,11 +142,11 @@ VALUES (1, 1, 300, 300, 50000, 5000, 11.44, 11.44, 55.77, 55.77, 99.1, 99.1, 111 DROP TABLE IF EXISTS test_sample.link; CREATE TABLE IF NOT EXISTS test_sample.link ( - ID serial PRIMARY KEY, - url VARCHAR (255) NOT NULL, - name VARCHAR (255) NOT NULL, - description VARCHAR (255), - rel VARCHAR (50) + ID serial PRIMARY KEY, + url VARCHAR (255) NOT NULL, + name VARCHAR (255) NOT NULL, + description VARCHAR (255), + rel VARCHAR (50) ); @@ -158,6 +158,7 @@ CREATE TABLE test_sample.employee ( employee_id INT PRIMARY KEY, first_name VARCHAR (255) NOT NULL, last_name VARCHAR (255) NOT NULL, + employment_date timestamp with time zone, manager_id INT, FOREIGN KEY (manager_id) REFERENCES test_sample.employee (employee_id) @@ -167,17 +168,18 @@ INSERT INTO test_sample.employee ( employee_id, first_name, last_name, + employment_date, manager_id ) VALUES -(1, 'Windy', 'Hays', NULL), -(2, 'Ava', 'Christensen', 1), -(3, 'Hassan', 'Conner', 1), -(4, 'Anna', 'Reeves', 2), -(5, 'Sau', 'Norman', 2), -(6, 'Kelsie', 'Hays', 3), -(7, 'Tory', 'Goff', 3), -(8, 'Salley', 'Lester', 3); +(1, 'Windy', 'Hays', '1999-01-08 04:05:06.100 -8:00', NULL), +(2, 'Ava', 'Christensen', '1999-01-08 04:05:06', 1), +(3, 'Hassan', 'Conner', '1999-01-08 04:05:06', 1), +(4, 'Anna', 'Reeves', '1999-01-08 04:05:06', 2), +(5, 'Sau', 'Norman', '1999-01-08 04:05:06', 2), +(6, 'Kelsie', 'Hays', '1999-01-08 04:05:06', 3), +(7, 'Tory', 'Goff', '1999-01-08 04:05:06', 3), +(8, 'Salley', 'Lester', '1999-01-08 04:05:06', 3); -- Person table ------------------ diff --git a/tests/main_test.go b/tests/main_test.go index 25335e3..97d1ffa 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -36,7 +36,7 @@ func TestGenerateModel(t *testing.T) { assert.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32") actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID") assert.Assert(t, ok) - assert.Equal(t, actorIDField.Tag.Get("sql"), "unique") + assert.Equal(t, actorIDField.Tag.Get("sql"), "primary_key") assert.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string") assert.Equal(t, reflect.TypeOf(actor.LastName).String(), "string") assert.Equal(t, reflect.TypeOf(actor.LastUpdate).String(), "time.Time") @@ -46,12 +46,12 @@ func TestGenerateModel(t *testing.T) { assert.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int16") filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID") assert.Assert(t, ok) - assert.Equal(t, filmIDField.Tag.Get("sql"), "unique") + assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key") assert.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int16") actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID") assert.Assert(t, ok) - assert.Equal(t, filmIDField.Tag.Get("sql"), "unique") + assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key") staff := model.Staff{} diff --git a/tests/scan_test.go b/tests/scan_test.go index aa15b59..a57018f 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -364,6 +364,8 @@ func TestScanToNestedStruct(t *testing.T) { SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns, Language.AllColumns). WHERE(Inventory.InventoryID.EQ(Int(1))) + type Language3 model.Language + dest := struct { model.Inventory Film struct { @@ -371,7 +373,7 @@ func TestScanToNestedStruct(t *testing.T) { Language model.Language Language2 *model.Language - Language3 *model.Language `sqlbuilder:"language"` + Language3 *Language3 `sql:"table:language"` Lang struct { model.Language } @@ -392,7 +394,7 @@ func TestScanToNestedStruct(t *testing.T) { assert.DeepEqual(t, dest.Film.Lang.Language, language1) assert.DeepEqual(t, dest.Film.Lang2.Language, language1) assert.DeepEqual(t, dest.Film.Language2, (*model.Language)(nil)) - assert.DeepEqual(t, dest.Film.Language3, &language1) + assert.DeepEqual(t, model.Language(*dest.Film.Language3), language1) }) } @@ -443,7 +445,7 @@ func TestScanToSlice(t *testing.T) { t.Run("struct with slice of ints", func(t *testing.T) { var dest struct { model.Film - IDs []int32 `sqlbuilder:"inventory.inventory_id"` + IDs []int32 `sql:"table:inventory,column:inventory_id"` } err := query.Query(db, &dest) @@ -456,7 +458,7 @@ func TestScanToSlice(t *testing.T) { t.Run("slice of structs with slice of ints", func(t *testing.T) { var dest []struct { model.Film - IDs []int32 `sqlbuilder:"inventory.inventory_id"` + IDs []int32 `sql:"table:inventory,column:inventory_id"` } err := query.Query(db, &dest) @@ -472,7 +474,7 @@ func TestScanToSlice(t *testing.T) { t.Run("slice of structs with slice of pointer to ints", func(t *testing.T) { var dest []struct { model.Film - IDs []*int32 `sqlbuilder:"inventory.inventory_id"` + IDs []*int32 `sql:"table:inventory,column:inventory_id"` } err := query.Query(db, &dest) diff --git a/tests/select_test.go b/tests/select_test.go index f5d4a9d..3d9f95b 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -495,10 +495,12 @@ func TestSelecSelfJoin1(t *testing.T) { SELECT employee.employee_id AS "employee.employee_id", employee.first_name AS "employee.first_name", employee.last_name AS "employee.last_name", + employee.employment_date AS "employee.employment_date", employee.manager_id AS "employee.manager_id", manager.employee_id AS "manager.employee_id", manager.first_name AS "manager.first_name", manager.last_name AS "manager.last_name", + manager.employment_date AS "manager.employment_date", manager.manager_id AS "manager.manager_id" FROM test_sample.employee LEFT JOIN test_sample.employee AS manager ON (manager.employee_id = employee.manager_id) @@ -524,19 +526,21 @@ ORDER BY employee.employee_id; assert.NilError(t, err) assert.Equal(t, len(dest), 8) assert.DeepEqual(t, dest[0].Employee, model2.Employee{ - EmployeeID: 1, - FirstName: "Windy", - LastName: "Hays", - ManagerID: nil, + EmployeeID: 1, + FirstName: "Windy", + LastName: "Hays", + EmploymentDate: timestampWithTimeZone("1999-01-08 13:05:06.1 +0100 CET", 1), + ManagerID: nil, }) assert.Assert(t, dest[0].Manager == nil) assert.DeepEqual(t, dest[7].Employee, model2.Employee{ - EmployeeID: 8, - FirstName: "Salley", - LastName: "Lester", - ManagerID: int32Ptr(3), + EmployeeID: 8, + FirstName: "Salley", + LastName: "Lester", + EmploymentDate: timestampWithTimeZone("1999-01-08 04:05:06 +0100 CET", 1), + ManagerID: int32Ptr(3), }) }