Model sql tags.

This commit is contained in:
go-jet 2019-06-12 12:47:30 +02:00
parent 367602757f
commit c598978ba6
10 changed files with 110 additions and 67 deletions

View file

@ -81,6 +81,20 @@ func (c ColumnInfo) GoModelType() string {
return typeStr
}
func (c ColumnInfo) GoModelTag(isPrimaryKey bool) string {
tags := []string{}
if isPrimaryKey {
tags = append(tags, "primary_key")
}
if len(tags) > 0 {
return "`sql:\"" + strings.Join(tags, ",") + "\"`"
}
return ""
}
func getColumnInfos(db *sql.DB, dbName, schemaName, tableName string) ([]ColumnInfo, error) {
query := `

View file

@ -81,7 +81,7 @@ import (
type {{camelize .Name}} struct {
{{- range .Columns}}
{{camelize .Name}} {{.GoModelType}} {{if $.IsUnique .Name}}` + "`sql:\"unique\"`" + ` {{end}}
{{camelize .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsUnique .Name)}}" + `
{{- end}}
}
`

View file

@ -3,6 +3,7 @@ package sqlbuilder
import (
"bytes"
"github.com/google/uuid"
"github.com/lib/pq"
"strconv"
"strings"
"time"
@ -223,7 +224,7 @@ func ArgToString(value interface{}) string {
case uuid.UUID:
return stringQuote(bindVal.String())
case time.Time:
return stringQuote(bindVal.String())
return stringQuote(string(pq.FormatTimestamp(bindVal)))
default:
return "[Unknown type]"
}

View file

@ -131,8 +131,8 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
if isGoBaseType(sliceElemType) {
index := 0
if structField != nil {
columnName := getRefTableNameFrom(structField)
index = getIndex(scanContext.columnNames, columnName)
tableName, columnName := getRefTableNameFrom(structField)
index = getIndex(scanContext.columnNames, tableName+"."+columnName)
if index < 0 {
return
@ -192,7 +192,7 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
}
func getGroupKey(scanContext *scanContext, structType reflect.Type, structField *reflect.StructField) string {
tableName := getRefTableNameFrom(structField)
tableName, _ := getRefTableNameFrom(structField)
if tableName == "" {
tableName = snaker.CamelToSnake(structType.Name())
@ -218,7 +218,7 @@ func getGroupKey(scanContext *scanContext, structType reflect.Type, structField
if structGroupKey != "" {
groupKeys = append(groupKeys, structGroupKey)
}
} else if field.Tag == `sql:"unique"` {
} else if tagInfo(field.Tag.Get("sql")).isPrimaryKey {
fieldName := field.Name
columnName := tableName + "." + snaker.CamelToSnake(fieldName)
@ -293,11 +293,7 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value {
destinationSliceType := slicePtrValue.Type().Elem()
elemType := destinationSliceType.Elem()
if elemType.Kind() == reflect.Ptr {
return reflect.New(elemType.Elem())
}
elemType := indirectType(destinationSliceType.Elem())
return reflect.New(elemType)
}
@ -348,51 +344,37 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
return
}
func getRefTableNameFrom(structField *reflect.StructField) string {
func getRefTableNameFrom(structField *reflect.StructField) (table, column string) {
if structField == nil {
return ""
return
}
tagOverwriteName := structField.Tag.Get("sqlbuilder")
sqlTag := structField.Tag.Get("sql")
if tagOverwriteName != "" {
return tagOverwriteName
if sqlTag != "" {
tagInfo := tagInfo(sqlTag)
return tagInfo.table, tagInfo.column
}
if !structField.Anonymous {
return snaker.CamelToSnake(structField.Name)
return snaker.CamelToSnake(structField.Name), ""
}
var elemType string
fieldType := indirectType(structField.Type)
if structField.Type.Kind() == reflect.Ptr {
elem := structField.Type.Elem()
if elem.Kind() == reflect.Struct {
elemType = elem.Name()
} else if elem.Kind() == reflect.Slice {
elemType = elem.Elem().Name()
}
} else {
if structField.Type.Kind() == reflect.Struct {
elemType = structField.Type.Name()
} else {
sliceElem := structField.Type.Elem()
if sliceElem.Kind() == reflect.Ptr {
elemType = sliceElem.Elem().Name()
} else {
elemType = sliceElem.Name()
}
}
if fieldType.Kind() == reflect.Slice {
fieldType = fieldType.Elem()
}
return snaker.CamelToSnake(elemType)
return snaker.CamelToSnake(fieldType.Name()), ""
}
func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
structType := structPtrValue.Type().Elem()
structValue := structPtrValue.Elem()
tableName := getRefTableNameFrom(structField)
tableName, _ := getRefTableNameFrom(structField)
if tableName == "" {
tableName = snaker.CamelToSnake(structType.Name())
@ -689,3 +671,41 @@ func (s *scanContext) rowElemValuePtr(index int) reflect.Value {
newElem.Elem().Set(rowElemValue)
return newElem
}
type sqlTagInfo struct {
isPrimaryKey bool
table string
column string
}
func tagInfo(tag string) sqlTagInfo {
sqlTags := strings.Split(tag, ",")
tagMap := map[string]string{}
for _, tag := range sqlTags {
tagParts := strings.Split(tag, ":")
tagKey := tagParts[0]
tagValue := ""
if len(tagParts) > 1 {
tagValue = tagParts[1]
}
tagMap[tagKey] = tagValue
}
_, isPrimaryKey := tagMap["primary_key"]
return sqlTagInfo{
isPrimaryKey: isPrimaryKey,
table: tagMap["table"],
column: tagMap["column"],
}
}
func indirectType(reflectType reflect.Type) reflect.Type {
if reflectType.Kind() != reflect.Ptr {
return reflectType
}
return reflectType.Elem()
}

Binary file not shown.

View file

@ -142,11 +142,11 @@ VALUES (1, 1, 300, 300, 50000, 5000, 11.44, 11.44, 55.77, 55.77, 99.1, 99.1, 111
DROP TABLE IF EXISTS test_sample.link;
CREATE TABLE IF NOT EXISTS test_sample.link (
ID serial PRIMARY KEY,
url VARCHAR (255) NOT NULL,
name VARCHAR (255) NOT NULL,
description VARCHAR (255),
rel VARCHAR (50)
ID serial PRIMARY KEY,
url VARCHAR (255) NOT NULL,
name VARCHAR (255) NOT NULL,
description VARCHAR (255),
rel VARCHAR (50)
);
@ -158,6 +158,7 @@ CREATE TABLE test_sample.employee (
employee_id INT PRIMARY KEY,
first_name VARCHAR (255) NOT NULL,
last_name VARCHAR (255) NOT NULL,
employment_date timestamp with time zone,
manager_id INT,
FOREIGN KEY (manager_id)
REFERENCES test_sample.employee (employee_id)
@ -167,17 +168,18 @@ INSERT INTO test_sample.employee (
employee_id,
first_name,
last_name,
employment_date,
manager_id
)
VALUES
(1, 'Windy', 'Hays', NULL),
(2, 'Ava', 'Christensen', 1),
(3, 'Hassan', 'Conner', 1),
(4, 'Anna', 'Reeves', 2),
(5, 'Sau', 'Norman', 2),
(6, 'Kelsie', 'Hays', 3),
(7, 'Tory', 'Goff', 3),
(8, 'Salley', 'Lester', 3);
(1, 'Windy', 'Hays', '1999-01-08 04:05:06.100 -8:00', NULL),
(2, 'Ava', 'Christensen', '1999-01-08 04:05:06', 1),
(3, 'Hassan', 'Conner', '1999-01-08 04:05:06', 1),
(4, 'Anna', 'Reeves', '1999-01-08 04:05:06', 2),
(5, 'Sau', 'Norman', '1999-01-08 04:05:06', 2),
(6, 'Kelsie', 'Hays', '1999-01-08 04:05:06', 3),
(7, 'Tory', 'Goff', '1999-01-08 04:05:06', 3),
(8, 'Salley', 'Lester', '1999-01-08 04:05:06', 3);
-- Person table ------------------

View file

@ -36,7 +36,7 @@ func TestGenerateModel(t *testing.T) {
assert.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32")
actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID")
assert.Assert(t, ok)
assert.Equal(t, actorIDField.Tag.Get("sql"), "unique")
assert.Equal(t, actorIDField.Tag.Get("sql"), "primary_key")
assert.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string")
assert.Equal(t, reflect.TypeOf(actor.LastName).String(), "string")
assert.Equal(t, reflect.TypeOf(actor.LastUpdate).String(), "time.Time")
@ -46,12 +46,12 @@ func TestGenerateModel(t *testing.T) {
assert.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int16")
filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID")
assert.Assert(t, ok)
assert.Equal(t, filmIDField.Tag.Get("sql"), "unique")
assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key")
assert.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int16")
actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID")
assert.Assert(t, ok)
assert.Equal(t, filmIDField.Tag.Get("sql"), "unique")
assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key")
staff := model.Staff{}

View file

@ -364,6 +364,8 @@ func TestScanToNestedStruct(t *testing.T) {
SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns, Language.AllColumns).
WHERE(Inventory.InventoryID.EQ(Int(1)))
type Language3 model.Language
dest := struct {
model.Inventory
Film struct {
@ -371,7 +373,7 @@ func TestScanToNestedStruct(t *testing.T) {
Language model.Language
Language2 *model.Language
Language3 *model.Language `sqlbuilder:"language"`
Language3 *Language3 `sql:"table:language"`
Lang struct {
model.Language
}
@ -392,7 +394,7 @@ func TestScanToNestedStruct(t *testing.T) {
assert.DeepEqual(t, dest.Film.Lang.Language, language1)
assert.DeepEqual(t, dest.Film.Lang2.Language, language1)
assert.DeepEqual(t, dest.Film.Language2, (*model.Language)(nil))
assert.DeepEqual(t, dest.Film.Language3, &language1)
assert.DeepEqual(t, model.Language(*dest.Film.Language3), language1)
})
}
@ -443,7 +445,7 @@ func TestScanToSlice(t *testing.T) {
t.Run("struct with slice of ints", func(t *testing.T) {
var dest struct {
model.Film
IDs []int32 `sqlbuilder:"inventory.inventory_id"`
IDs []int32 `sql:"table:inventory,column:inventory_id"`
}
err := query.Query(db, &dest)
@ -456,7 +458,7 @@ func TestScanToSlice(t *testing.T) {
t.Run("slice of structs with slice of ints", func(t *testing.T) {
var dest []struct {
model.Film
IDs []int32 `sqlbuilder:"inventory.inventory_id"`
IDs []int32 `sql:"table:inventory,column:inventory_id"`
}
err := query.Query(db, &dest)
@ -472,7 +474,7 @@ func TestScanToSlice(t *testing.T) {
t.Run("slice of structs with slice of pointer to ints", func(t *testing.T) {
var dest []struct {
model.Film
IDs []*int32 `sqlbuilder:"inventory.inventory_id"`
IDs []*int32 `sql:"table:inventory,column:inventory_id"`
}
err := query.Query(db, &dest)

View file

@ -495,10 +495,12 @@ func TestSelecSelfJoin1(t *testing.T) {
SELECT employee.employee_id AS "employee.employee_id",
employee.first_name AS "employee.first_name",
employee.last_name AS "employee.last_name",
employee.employment_date AS "employee.employment_date",
employee.manager_id AS "employee.manager_id",
manager.employee_id AS "manager.employee_id",
manager.first_name AS "manager.first_name",
manager.last_name AS "manager.last_name",
manager.employment_date AS "manager.employment_date",
manager.manager_id AS "manager.manager_id"
FROM test_sample.employee
LEFT JOIN test_sample.employee AS manager ON (manager.employee_id = employee.manager_id)
@ -524,19 +526,21 @@ ORDER BY employee.employee_id;
assert.NilError(t, err)
assert.Equal(t, len(dest), 8)
assert.DeepEqual(t, dest[0].Employee, model2.Employee{
EmployeeID: 1,
FirstName: "Windy",
LastName: "Hays",
ManagerID: nil,
EmployeeID: 1,
FirstName: "Windy",
LastName: "Hays",
EmploymentDate: timestampWithTimeZone("1999-01-08 13:05:06.1 +0100 CET", 1),
ManagerID: nil,
})
assert.Assert(t, dest[0].Manager == nil)
assert.DeepEqual(t, dest[7].Employee, model2.Employee{
EmployeeID: 8,
FirstName: "Salley",
LastName: "Lester",
ManagerID: int32Ptr(3),
EmployeeID: 8,
FirstName: "Salley",
LastName: "Lester",
EmploymentDate: timestampWithTimeZone("1999-01-08 04:05:06 +0100 CET", 1),
ManagerID: int32Ptr(3),
})
}