Model sql tags.

This commit is contained in:
go-jet 2019-06-12 12:47:30 +02:00
parent 367602757f
commit c598978ba6
10 changed files with 110 additions and 67 deletions

View file

@ -81,6 +81,20 @@ func (c ColumnInfo) GoModelType() string {
return typeStr return typeStr
} }
func (c ColumnInfo) GoModelTag(isPrimaryKey bool) string {
tags := []string{}
if isPrimaryKey {
tags = append(tags, "primary_key")
}
if len(tags) > 0 {
return "`sql:\"" + strings.Join(tags, ",") + "\"`"
}
return ""
}
func getColumnInfos(db *sql.DB, dbName, schemaName, tableName string) ([]ColumnInfo, error) { func getColumnInfos(db *sql.DB, dbName, schemaName, tableName string) ([]ColumnInfo, error) {
query := ` query := `

View file

@ -81,7 +81,7 @@ import (
type {{camelize .Name}} struct { type {{camelize .Name}} struct {
{{- range .Columns}} {{- range .Columns}}
{{camelize .Name}} {{.GoModelType}} {{if $.IsUnique .Name}}` + "`sql:\"unique\"`" + ` {{end}} {{camelize .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsUnique .Name)}}" + `
{{- end}} {{- end}}
} }
` `

View file

@ -3,6 +3,7 @@ package sqlbuilder
import ( import (
"bytes" "bytes"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lib/pq"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -223,7 +224,7 @@ func ArgToString(value interface{}) string {
case uuid.UUID: case uuid.UUID:
return stringQuote(bindVal.String()) return stringQuote(bindVal.String())
case time.Time: case time.Time:
return stringQuote(bindVal.String()) return stringQuote(string(pq.FormatTimestamp(bindVal)))
default: default:
return "[Unknown type]" return "[Unknown type]"
} }

View file

@ -131,8 +131,8 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
if isGoBaseType(sliceElemType) { if isGoBaseType(sliceElemType) {
index := 0 index := 0
if structField != nil { if structField != nil {
columnName := getRefTableNameFrom(structField) tableName, columnName := getRefTableNameFrom(structField)
index = getIndex(scanContext.columnNames, columnName) index = getIndex(scanContext.columnNames, tableName+"."+columnName)
if index < 0 { if index < 0 {
return return
@ -192,7 +192,7 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
} }
func getGroupKey(scanContext *scanContext, structType reflect.Type, structField *reflect.StructField) string { func getGroupKey(scanContext *scanContext, structType reflect.Type, structField *reflect.StructField) string {
tableName := getRefTableNameFrom(structField) tableName, _ := getRefTableNameFrom(structField)
if tableName == "" { if tableName == "" {
tableName = snaker.CamelToSnake(structType.Name()) tableName = snaker.CamelToSnake(structType.Name())
@ -218,7 +218,7 @@ func getGroupKey(scanContext *scanContext, structType reflect.Type, structField
if structGroupKey != "" { if structGroupKey != "" {
groupKeys = append(groupKeys, structGroupKey) groupKeys = append(groupKeys, structGroupKey)
} }
} else if field.Tag == `sql:"unique"` { } else if tagInfo(field.Tag.Get("sql")).isPrimaryKey {
fieldName := field.Name fieldName := field.Name
columnName := tableName + "." + snaker.CamelToSnake(fieldName) columnName := tableName + "." + snaker.CamelToSnake(fieldName)
@ -293,11 +293,7 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value { func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value {
destinationSliceType := slicePtrValue.Type().Elem() destinationSliceType := slicePtrValue.Type().Elem()
elemType := destinationSliceType.Elem() elemType := indirectType(destinationSliceType.Elem())
if elemType.Kind() == reflect.Ptr {
return reflect.New(elemType.Elem())
}
return reflect.New(elemType) return reflect.New(elemType)
} }
@ -348,51 +344,37 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
return return
} }
func getRefTableNameFrom(structField *reflect.StructField) string { func getRefTableNameFrom(structField *reflect.StructField) (table, column string) {
if structField == nil { if structField == nil {
return "" return
} }
tagOverwriteName := structField.Tag.Get("sqlbuilder") sqlTag := structField.Tag.Get("sql")
if tagOverwriteName != "" { if sqlTag != "" {
return tagOverwriteName tagInfo := tagInfo(sqlTag)
return tagInfo.table, tagInfo.column
} }
if !structField.Anonymous { if !structField.Anonymous {
return snaker.CamelToSnake(structField.Name) return snaker.CamelToSnake(structField.Name), ""
} }
var elemType string fieldType := indirectType(structField.Type)
if structField.Type.Kind() == reflect.Ptr { if fieldType.Kind() == reflect.Slice {
elem := structField.Type.Elem() fieldType = fieldType.Elem()
if elem.Kind() == reflect.Struct {
elemType = elem.Name()
} else if elem.Kind() == reflect.Slice {
elemType = elem.Elem().Name()
}
} else {
if structField.Type.Kind() == reflect.Struct {
elemType = structField.Type.Name()
} else {
sliceElem := structField.Type.Elem()
if sliceElem.Kind() == reflect.Ptr {
elemType = sliceElem.Elem().Name()
} else {
elemType = sliceElem.Name()
}
}
} }
return snaker.CamelToSnake(elemType) return snaker.CamelToSnake(fieldType.Name()), ""
} }
func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
structType := structPtrValue.Type().Elem() structType := structPtrValue.Type().Elem()
structValue := structPtrValue.Elem() structValue := structPtrValue.Elem()
tableName := getRefTableNameFrom(structField) tableName, _ := getRefTableNameFrom(structField)
if tableName == "" { if tableName == "" {
tableName = snaker.CamelToSnake(structType.Name()) tableName = snaker.CamelToSnake(structType.Name())
@ -689,3 +671,41 @@ func (s *scanContext) rowElemValuePtr(index int) reflect.Value {
newElem.Elem().Set(rowElemValue) newElem.Elem().Set(rowElemValue)
return newElem return newElem
} }
type sqlTagInfo struct {
isPrimaryKey bool
table string
column string
}
func tagInfo(tag string) sqlTagInfo {
sqlTags := strings.Split(tag, ",")
tagMap := map[string]string{}
for _, tag := range sqlTags {
tagParts := strings.Split(tag, ":")
tagKey := tagParts[0]
tagValue := ""
if len(tagParts) > 1 {
tagValue = tagParts[1]
}
tagMap[tagKey] = tagValue
}
_, isPrimaryKey := tagMap["primary_key"]
return sqlTagInfo{
isPrimaryKey: isPrimaryKey,
table: tagMap["table"],
column: tagMap["column"],
}
}
func indirectType(reflectType reflect.Type) reflect.Type {
if reflectType.Kind() != reflect.Ptr {
return reflectType
}
return reflectType.Elem()
}

Binary file not shown.

View file

@ -142,11 +142,11 @@ VALUES (1, 1, 300, 300, 50000, 5000, 11.44, 11.44, 55.77, 55.77, 99.1, 99.1, 111
DROP TABLE IF EXISTS test_sample.link; DROP TABLE IF EXISTS test_sample.link;
CREATE TABLE IF NOT EXISTS test_sample.link ( CREATE TABLE IF NOT EXISTS test_sample.link (
ID serial PRIMARY KEY, ID serial PRIMARY KEY,
url VARCHAR (255) NOT NULL, url VARCHAR (255) NOT NULL,
name VARCHAR (255) NOT NULL, name VARCHAR (255) NOT NULL,
description VARCHAR (255), description VARCHAR (255),
rel VARCHAR (50) rel VARCHAR (50)
); );
@ -158,6 +158,7 @@ CREATE TABLE test_sample.employee (
employee_id INT PRIMARY KEY, employee_id INT PRIMARY KEY,
first_name VARCHAR (255) NOT NULL, first_name VARCHAR (255) NOT NULL,
last_name VARCHAR (255) NOT NULL, last_name VARCHAR (255) NOT NULL,
employment_date timestamp with time zone,
manager_id INT, manager_id INT,
FOREIGN KEY (manager_id) FOREIGN KEY (manager_id)
REFERENCES test_sample.employee (employee_id) REFERENCES test_sample.employee (employee_id)
@ -167,17 +168,18 @@ INSERT INTO test_sample.employee (
employee_id, employee_id,
first_name, first_name,
last_name, last_name,
employment_date,
manager_id manager_id
) )
VALUES VALUES
(1, 'Windy', 'Hays', NULL), (1, 'Windy', 'Hays', '1999-01-08 04:05:06.100 -8:00', NULL),
(2, 'Ava', 'Christensen', 1), (2, 'Ava', 'Christensen', '1999-01-08 04:05:06', 1),
(3, 'Hassan', 'Conner', 1), (3, 'Hassan', 'Conner', '1999-01-08 04:05:06', 1),
(4, 'Anna', 'Reeves', 2), (4, 'Anna', 'Reeves', '1999-01-08 04:05:06', 2),
(5, 'Sau', 'Norman', 2), (5, 'Sau', 'Norman', '1999-01-08 04:05:06', 2),
(6, 'Kelsie', 'Hays', 3), (6, 'Kelsie', 'Hays', '1999-01-08 04:05:06', 3),
(7, 'Tory', 'Goff', 3), (7, 'Tory', 'Goff', '1999-01-08 04:05:06', 3),
(8, 'Salley', 'Lester', 3); (8, 'Salley', 'Lester', '1999-01-08 04:05:06', 3);
-- Person table ------------------ -- Person table ------------------

View file

@ -36,7 +36,7 @@ func TestGenerateModel(t *testing.T) {
assert.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32") assert.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32")
actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID") actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID")
assert.Assert(t, ok) assert.Assert(t, ok)
assert.Equal(t, actorIDField.Tag.Get("sql"), "unique") assert.Equal(t, actorIDField.Tag.Get("sql"), "primary_key")
assert.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string") assert.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string")
assert.Equal(t, reflect.TypeOf(actor.LastName).String(), "string") assert.Equal(t, reflect.TypeOf(actor.LastName).String(), "string")
assert.Equal(t, reflect.TypeOf(actor.LastUpdate).String(), "time.Time") assert.Equal(t, reflect.TypeOf(actor.LastUpdate).String(), "time.Time")
@ -46,12 +46,12 @@ func TestGenerateModel(t *testing.T) {
assert.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int16") assert.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int16")
filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID") filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID")
assert.Assert(t, ok) assert.Assert(t, ok)
assert.Equal(t, filmIDField.Tag.Get("sql"), "unique") assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key")
assert.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int16") assert.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int16")
actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID") actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID")
assert.Assert(t, ok) assert.Assert(t, ok)
assert.Equal(t, filmIDField.Tag.Get("sql"), "unique") assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key")
staff := model.Staff{} staff := model.Staff{}

View file

@ -364,6 +364,8 @@ func TestScanToNestedStruct(t *testing.T) {
SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns, Language.AllColumns). SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns, Language.AllColumns).
WHERE(Inventory.InventoryID.EQ(Int(1))) WHERE(Inventory.InventoryID.EQ(Int(1)))
type Language3 model.Language
dest := struct { dest := struct {
model.Inventory model.Inventory
Film struct { Film struct {
@ -371,7 +373,7 @@ func TestScanToNestedStruct(t *testing.T) {
Language model.Language Language model.Language
Language2 *model.Language Language2 *model.Language
Language3 *model.Language `sqlbuilder:"language"` Language3 *Language3 `sql:"table:language"`
Lang struct { Lang struct {
model.Language model.Language
} }
@ -392,7 +394,7 @@ func TestScanToNestedStruct(t *testing.T) {
assert.DeepEqual(t, dest.Film.Lang.Language, language1) assert.DeepEqual(t, dest.Film.Lang.Language, language1)
assert.DeepEqual(t, dest.Film.Lang2.Language, language1) assert.DeepEqual(t, dest.Film.Lang2.Language, language1)
assert.DeepEqual(t, dest.Film.Language2, (*model.Language)(nil)) assert.DeepEqual(t, dest.Film.Language2, (*model.Language)(nil))
assert.DeepEqual(t, dest.Film.Language3, &language1) assert.DeepEqual(t, model.Language(*dest.Film.Language3), language1)
}) })
} }
@ -443,7 +445,7 @@ func TestScanToSlice(t *testing.T) {
t.Run("struct with slice of ints", func(t *testing.T) { t.Run("struct with slice of ints", func(t *testing.T) {
var dest struct { var dest struct {
model.Film model.Film
IDs []int32 `sqlbuilder:"inventory.inventory_id"` IDs []int32 `sql:"table:inventory,column:inventory_id"`
} }
err := query.Query(db, &dest) err := query.Query(db, &dest)
@ -456,7 +458,7 @@ func TestScanToSlice(t *testing.T) {
t.Run("slice of structs with slice of ints", func(t *testing.T) { t.Run("slice of structs with slice of ints", func(t *testing.T) {
var dest []struct { var dest []struct {
model.Film model.Film
IDs []int32 `sqlbuilder:"inventory.inventory_id"` IDs []int32 `sql:"table:inventory,column:inventory_id"`
} }
err := query.Query(db, &dest) err := query.Query(db, &dest)
@ -472,7 +474,7 @@ func TestScanToSlice(t *testing.T) {
t.Run("slice of structs with slice of pointer to ints", func(t *testing.T) { t.Run("slice of structs with slice of pointer to ints", func(t *testing.T) {
var dest []struct { var dest []struct {
model.Film model.Film
IDs []*int32 `sqlbuilder:"inventory.inventory_id"` IDs []*int32 `sql:"table:inventory,column:inventory_id"`
} }
err := query.Query(db, &dest) err := query.Query(db, &dest)

View file

@ -495,10 +495,12 @@ func TestSelecSelfJoin1(t *testing.T) {
SELECT employee.employee_id AS "employee.employee_id", SELECT employee.employee_id AS "employee.employee_id",
employee.first_name AS "employee.first_name", employee.first_name AS "employee.first_name",
employee.last_name AS "employee.last_name", employee.last_name AS "employee.last_name",
employee.employment_date AS "employee.employment_date",
employee.manager_id AS "employee.manager_id", employee.manager_id AS "employee.manager_id",
manager.employee_id AS "manager.employee_id", manager.employee_id AS "manager.employee_id",
manager.first_name AS "manager.first_name", manager.first_name AS "manager.first_name",
manager.last_name AS "manager.last_name", manager.last_name AS "manager.last_name",
manager.employment_date AS "manager.employment_date",
manager.manager_id AS "manager.manager_id" manager.manager_id AS "manager.manager_id"
FROM test_sample.employee FROM test_sample.employee
LEFT JOIN test_sample.employee AS manager ON (manager.employee_id = employee.manager_id) LEFT JOIN test_sample.employee AS manager ON (manager.employee_id = employee.manager_id)
@ -524,19 +526,21 @@ ORDER BY employee.employee_id;
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(dest), 8) assert.Equal(t, len(dest), 8)
assert.DeepEqual(t, dest[0].Employee, model2.Employee{ assert.DeepEqual(t, dest[0].Employee, model2.Employee{
EmployeeID: 1, EmployeeID: 1,
FirstName: "Windy", FirstName: "Windy",
LastName: "Hays", LastName: "Hays",
ManagerID: nil, EmploymentDate: timestampWithTimeZone("1999-01-08 13:05:06.1 +0100 CET", 1),
ManagerID: nil,
}) })
assert.Assert(t, dest[0].Manager == nil) assert.Assert(t, dest[0].Manager == nil)
assert.DeepEqual(t, dest[7].Employee, model2.Employee{ assert.DeepEqual(t, dest[7].Employee, model2.Employee{
EmployeeID: 8, EmployeeID: 8,
FirstName: "Salley", FirstName: "Salley",
LastName: "Lester", LastName: "Lester",
ManagerID: int32Ptr(3), EmploymentDate: timestampWithTimeZone("1999-01-08 04:05:06 +0100 CET", 1),
ManagerID: int32Ptr(3),
}) })
} }