Model sql tags.

This commit is contained in:
go-jet 2019-06-12 12:47:30 +02:00
parent 367602757f
commit c598978ba6
10 changed files with 110 additions and 67 deletions

View file

@ -81,6 +81,20 @@ func (c ColumnInfo) GoModelType() string {
return typeStr return typeStr
} }
func (c ColumnInfo) GoModelTag(isPrimaryKey bool) string {
tags := []string{}
if isPrimaryKey {
tags = append(tags, "primary_key")
}
if len(tags) > 0 {
return "`sql:\"" + strings.Join(tags, ",") + "\"`"
}
return ""
}
func getColumnInfos(db *sql.DB, dbName, schemaName, tableName string) ([]ColumnInfo, error) { func getColumnInfos(db *sql.DB, dbName, schemaName, tableName string) ([]ColumnInfo, error) {
query := ` query := `

View file

@ -81,7 +81,7 @@ import (
type {{camelize .Name}} struct { type {{camelize .Name}} struct {
{{- range .Columns}} {{- range .Columns}}
{{camelize .Name}} {{.GoModelType}} {{if $.IsUnique .Name}}` + "`sql:\"unique\"`" + ` {{end}} {{camelize .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsUnique .Name)}}" + `
{{- end}} {{- end}}
} }
` `

View file

@ -3,6 +3,7 @@ package sqlbuilder
import ( import (
"bytes" "bytes"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lib/pq"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -223,7 +224,7 @@ func ArgToString(value interface{}) string {
case uuid.UUID: case uuid.UUID:
return stringQuote(bindVal.String()) return stringQuote(bindVal.String())
case time.Time: case time.Time:
return stringQuote(bindVal.String()) return stringQuote(string(pq.FormatTimestamp(bindVal)))
default: default:
return "[Unknown type]" return "[Unknown type]"
} }

View file

@ -131,8 +131,8 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
if isGoBaseType(sliceElemType) { if isGoBaseType(sliceElemType) {
index := 0 index := 0
if structField != nil { if structField != nil {
columnName := getRefTableNameFrom(structField) tableName, columnName := getRefTableNameFrom(structField)
index = getIndex(scanContext.columnNames, columnName) index = getIndex(scanContext.columnNames, tableName+"."+columnName)
if index < 0 { if index < 0 {
return return
@ -192,7 +192,7 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
} }
func getGroupKey(scanContext *scanContext, structType reflect.Type, structField *reflect.StructField) string { func getGroupKey(scanContext *scanContext, structType reflect.Type, structField *reflect.StructField) string {
tableName := getRefTableNameFrom(structField) tableName, _ := getRefTableNameFrom(structField)
if tableName == "" { if tableName == "" {
tableName = snaker.CamelToSnake(structType.Name()) tableName = snaker.CamelToSnake(structType.Name())
@ -218,7 +218,7 @@ func getGroupKey(scanContext *scanContext, structType reflect.Type, structField
if structGroupKey != "" { if structGroupKey != "" {
groupKeys = append(groupKeys, structGroupKey) groupKeys = append(groupKeys, structGroupKey)
} }
} else if field.Tag == `sql:"unique"` { } else if tagInfo(field.Tag.Get("sql")).isPrimaryKey {
fieldName := field.Name fieldName := field.Name
columnName := tableName + "." + snaker.CamelToSnake(fieldName) columnName := tableName + "." + snaker.CamelToSnake(fieldName)
@ -293,11 +293,7 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value { func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value {
destinationSliceType := slicePtrValue.Type().Elem() destinationSliceType := slicePtrValue.Type().Elem()
elemType := destinationSliceType.Elem() elemType := indirectType(destinationSliceType.Elem())
if elemType.Kind() == reflect.Ptr {
return reflect.New(elemType.Elem())
}
return reflect.New(elemType) return reflect.New(elemType)
} }
@ -348,51 +344,37 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
return return
} }
func getRefTableNameFrom(structField *reflect.StructField) string { func getRefTableNameFrom(structField *reflect.StructField) (table, column string) {
if structField == nil { if structField == nil {
return "" return
} }
tagOverwriteName := structField.Tag.Get("sqlbuilder") sqlTag := structField.Tag.Get("sql")
if tagOverwriteName != "" { if sqlTag != "" {
return tagOverwriteName tagInfo := tagInfo(sqlTag)
return tagInfo.table, tagInfo.column
} }
if !structField.Anonymous { if !structField.Anonymous {
return snaker.CamelToSnake(structField.Name) return snaker.CamelToSnake(structField.Name), ""
} }
var elemType string fieldType := indirectType(structField.Type)
if structField.Type.Kind() == reflect.Ptr { if fieldType.Kind() == reflect.Slice {
elem := structField.Type.Elem() fieldType = fieldType.Elem()
if elem.Kind() == reflect.Struct {
elemType = elem.Name()
} else if elem.Kind() == reflect.Slice {
elemType = elem.Elem().Name()
}
} else {
if structField.Type.Kind() == reflect.Struct {
elemType = structField.Type.Name()
} else {
sliceElem := structField.Type.Elem()
if sliceElem.Kind() == reflect.Ptr {
elemType = sliceElem.Elem().Name()
} else {
elemType = sliceElem.Name()
}
}
} }
return snaker.CamelToSnake(elemType) return snaker.CamelToSnake(fieldType.Name()), ""
} }
func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
structType := structPtrValue.Type().Elem() structType := structPtrValue.Type().Elem()
structValue := structPtrValue.Elem() structValue := structPtrValue.Elem()
tableName := getRefTableNameFrom(structField) tableName, _ := getRefTableNameFrom(structField)
if tableName == "" { if tableName == "" {
tableName = snaker.CamelToSnake(structType.Name()) tableName = snaker.CamelToSnake(structType.Name())
@ -689,3 +671,41 @@ func (s *scanContext) rowElemValuePtr(index int) reflect.Value {
newElem.Elem().Set(rowElemValue) newElem.Elem().Set(rowElemValue)
return newElem return newElem
} }
type sqlTagInfo struct {
isPrimaryKey bool
table string
column string
}
func tagInfo(tag string) sqlTagInfo {
sqlTags := strings.Split(tag, ",")
tagMap := map[string]string{}
for _, tag := range sqlTags {
tagParts := strings.Split(tag, ":")
tagKey := tagParts[0]
tagValue := ""
if len(tagParts) > 1 {
tagValue = tagParts[1]
}
tagMap[tagKey] = tagValue
}
_, isPrimaryKey := tagMap["primary_key"]
return sqlTagInfo{
isPrimaryKey: isPrimaryKey,
table: tagMap["table"],
column: tagMap["column"],
}
}
func indirectType(reflectType reflect.Type) reflect.Type {
if reflectType.Kind() != reflect.Ptr {
return reflectType
}
return reflectType.Elem()
}

Binary file not shown.

View file

@ -158,6 +158,7 @@ CREATE TABLE test_sample.employee (
employee_id INT PRIMARY KEY, employee_id INT PRIMARY KEY,
first_name VARCHAR (255) NOT NULL, first_name VARCHAR (255) NOT NULL,
last_name VARCHAR (255) NOT NULL, last_name VARCHAR (255) NOT NULL,
employment_date timestamp with time zone,
manager_id INT, manager_id INT,
FOREIGN KEY (manager_id) FOREIGN KEY (manager_id)
REFERENCES test_sample.employee (employee_id) REFERENCES test_sample.employee (employee_id)
@ -167,17 +168,18 @@ INSERT INTO test_sample.employee (
employee_id, employee_id,
first_name, first_name,
last_name, last_name,
employment_date,
manager_id manager_id
) )
VALUES VALUES
(1, 'Windy', 'Hays', NULL), (1, 'Windy', 'Hays', '1999-01-08 04:05:06.100 -8:00', NULL),
(2, 'Ava', 'Christensen', 1), (2, 'Ava', 'Christensen', '1999-01-08 04:05:06', 1),
(3, 'Hassan', 'Conner', 1), (3, 'Hassan', 'Conner', '1999-01-08 04:05:06', 1),
(4, 'Anna', 'Reeves', 2), (4, 'Anna', 'Reeves', '1999-01-08 04:05:06', 2),
(5, 'Sau', 'Norman', 2), (5, 'Sau', 'Norman', '1999-01-08 04:05:06', 2),
(6, 'Kelsie', 'Hays', 3), (6, 'Kelsie', 'Hays', '1999-01-08 04:05:06', 3),
(7, 'Tory', 'Goff', 3), (7, 'Tory', 'Goff', '1999-01-08 04:05:06', 3),
(8, 'Salley', 'Lester', 3); (8, 'Salley', 'Lester', '1999-01-08 04:05:06', 3);
-- Person table ------------------ -- Person table ------------------

View file

@ -36,7 +36,7 @@ func TestGenerateModel(t *testing.T) {
assert.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32") assert.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32")
actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID") actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID")
assert.Assert(t, ok) assert.Assert(t, ok)
assert.Equal(t, actorIDField.Tag.Get("sql"), "unique") assert.Equal(t, actorIDField.Tag.Get("sql"), "primary_key")
assert.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string") assert.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string")
assert.Equal(t, reflect.TypeOf(actor.LastName).String(), "string") assert.Equal(t, reflect.TypeOf(actor.LastName).String(), "string")
assert.Equal(t, reflect.TypeOf(actor.LastUpdate).String(), "time.Time") assert.Equal(t, reflect.TypeOf(actor.LastUpdate).String(), "time.Time")
@ -46,12 +46,12 @@ func TestGenerateModel(t *testing.T) {
assert.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int16") assert.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int16")
filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID") filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID")
assert.Assert(t, ok) assert.Assert(t, ok)
assert.Equal(t, filmIDField.Tag.Get("sql"), "unique") assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key")
assert.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int16") assert.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int16")
actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID") actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID")
assert.Assert(t, ok) assert.Assert(t, ok)
assert.Equal(t, filmIDField.Tag.Get("sql"), "unique") assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key")
staff := model.Staff{} staff := model.Staff{}

View file

@ -364,6 +364,8 @@ func TestScanToNestedStruct(t *testing.T) {
SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns, Language.AllColumns). SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns, Language.AllColumns).
WHERE(Inventory.InventoryID.EQ(Int(1))) WHERE(Inventory.InventoryID.EQ(Int(1)))
type Language3 model.Language
dest := struct { dest := struct {
model.Inventory model.Inventory
Film struct { Film struct {
@ -371,7 +373,7 @@ func TestScanToNestedStruct(t *testing.T) {
Language model.Language Language model.Language
Language2 *model.Language Language2 *model.Language
Language3 *model.Language `sqlbuilder:"language"` Language3 *Language3 `sql:"table:language"`
Lang struct { Lang struct {
model.Language model.Language
} }
@ -392,7 +394,7 @@ func TestScanToNestedStruct(t *testing.T) {
assert.DeepEqual(t, dest.Film.Lang.Language, language1) assert.DeepEqual(t, dest.Film.Lang.Language, language1)
assert.DeepEqual(t, dest.Film.Lang2.Language, language1) assert.DeepEqual(t, dest.Film.Lang2.Language, language1)
assert.DeepEqual(t, dest.Film.Language2, (*model.Language)(nil)) assert.DeepEqual(t, dest.Film.Language2, (*model.Language)(nil))
assert.DeepEqual(t, dest.Film.Language3, &language1) assert.DeepEqual(t, model.Language(*dest.Film.Language3), language1)
}) })
} }
@ -443,7 +445,7 @@ func TestScanToSlice(t *testing.T) {
t.Run("struct with slice of ints", func(t *testing.T) { t.Run("struct with slice of ints", func(t *testing.T) {
var dest struct { var dest struct {
model.Film model.Film
IDs []int32 `sqlbuilder:"inventory.inventory_id"` IDs []int32 `sql:"table:inventory,column:inventory_id"`
} }
err := query.Query(db, &dest) err := query.Query(db, &dest)
@ -456,7 +458,7 @@ func TestScanToSlice(t *testing.T) {
t.Run("slice of structs with slice of ints", func(t *testing.T) { t.Run("slice of structs with slice of ints", func(t *testing.T) {
var dest []struct { var dest []struct {
model.Film model.Film
IDs []int32 `sqlbuilder:"inventory.inventory_id"` IDs []int32 `sql:"table:inventory,column:inventory_id"`
} }
err := query.Query(db, &dest) err := query.Query(db, &dest)
@ -472,7 +474,7 @@ func TestScanToSlice(t *testing.T) {
t.Run("slice of structs with slice of pointer to ints", func(t *testing.T) { t.Run("slice of structs with slice of pointer to ints", func(t *testing.T) {
var dest []struct { var dest []struct {
model.Film model.Film
IDs []*int32 `sqlbuilder:"inventory.inventory_id"` IDs []*int32 `sql:"table:inventory,column:inventory_id"`
} }
err := query.Query(db, &dest) err := query.Query(db, &dest)

View file

@ -495,10 +495,12 @@ func TestSelecSelfJoin1(t *testing.T) {
SELECT employee.employee_id AS "employee.employee_id", SELECT employee.employee_id AS "employee.employee_id",
employee.first_name AS "employee.first_name", employee.first_name AS "employee.first_name",
employee.last_name AS "employee.last_name", employee.last_name AS "employee.last_name",
employee.employment_date AS "employee.employment_date",
employee.manager_id AS "employee.manager_id", employee.manager_id AS "employee.manager_id",
manager.employee_id AS "manager.employee_id", manager.employee_id AS "manager.employee_id",
manager.first_name AS "manager.first_name", manager.first_name AS "manager.first_name",
manager.last_name AS "manager.last_name", manager.last_name AS "manager.last_name",
manager.employment_date AS "manager.employment_date",
manager.manager_id AS "manager.manager_id" manager.manager_id AS "manager.manager_id"
FROM test_sample.employee FROM test_sample.employee
LEFT JOIN test_sample.employee AS manager ON (manager.employee_id = employee.manager_id) LEFT JOIN test_sample.employee AS manager ON (manager.employee_id = employee.manager_id)
@ -527,6 +529,7 @@ ORDER BY employee.employee_id;
EmployeeID: 1, EmployeeID: 1,
FirstName: "Windy", FirstName: "Windy",
LastName: "Hays", LastName: "Hays",
EmploymentDate: timestampWithTimeZone("1999-01-08 13:05:06.1 +0100 CET", 1),
ManagerID: nil, ManagerID: nil,
}) })
@ -536,6 +539,7 @@ ORDER BY employee.employee_id;
EmployeeID: 8, EmployeeID: 8,
FirstName: "Salley", FirstName: "Salley",
LastName: "Lester", LastName: "Lester",
EmploymentDate: timestampWithTimeZone("1999-01-08 04:05:06 +0100 CET", 1),
ManagerID: int32Ptr(3), ManagerID: int32Ptr(3),
}) })
} }