Model sql tags.
This commit is contained in:
parent
367602757f
commit
c598978ba6
10 changed files with 110 additions and 67 deletions
|
|
@ -81,6 +81,20 @@ func (c ColumnInfo) GoModelType() string {
|
|||
return typeStr
|
||||
}
|
||||
|
||||
func (c ColumnInfo) GoModelTag(isPrimaryKey bool) string {
|
||||
tags := []string{}
|
||||
|
||||
if isPrimaryKey {
|
||||
tags = append(tags, "primary_key")
|
||||
}
|
||||
|
||||
if len(tags) > 0 {
|
||||
return "`sql:\"" + strings.Join(tags, ",") + "\"`"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func getColumnInfos(db *sql.DB, dbName, schemaName, tableName string) ([]ColumnInfo, error) {
|
||||
|
||||
query := `
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ import (
|
|||
|
||||
type {{camelize .Name}} struct {
|
||||
{{- range .Columns}}
|
||||
{{camelize .Name}} {{.GoModelType}} {{if $.IsUnique .Name}}` + "`sql:\"unique\"`" + ` {{end}}
|
||||
{{camelize .Name}} {{.GoModelType}} ` + "{{.GoModelTag ($.IsUnique .Name)}}" + `
|
||||
{{- end}}
|
||||
}
|
||||
`
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package sqlbuilder
|
|||
import (
|
||||
"bytes"
|
||||
"github.com/google/uuid"
|
||||
"github.com/lib/pq"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
|
@ -223,7 +224,7 @@ func ArgToString(value interface{}) string {
|
|||
case uuid.UUID:
|
||||
return stringQuote(bindVal.String())
|
||||
case time.Time:
|
||||
return stringQuote(bindVal.String())
|
||||
return stringQuote(string(pq.FormatTimestamp(bindVal)))
|
||||
default:
|
||||
return "[Unknown type]"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -131,8 +131,8 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
|
|||
if isGoBaseType(sliceElemType) {
|
||||
index := 0
|
||||
if structField != nil {
|
||||
columnName := getRefTableNameFrom(structField)
|
||||
index = getIndex(scanContext.columnNames, columnName)
|
||||
tableName, columnName := getRefTableNameFrom(structField)
|
||||
index = getIndex(scanContext.columnNames, tableName+"."+columnName)
|
||||
|
||||
if index < 0 {
|
||||
return
|
||||
|
|
@ -192,7 +192,7 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl
|
|||
}
|
||||
|
||||
func getGroupKey(scanContext *scanContext, structType reflect.Type, structField *reflect.StructField) string {
|
||||
tableName := getRefTableNameFrom(structField)
|
||||
tableName, _ := getRefTableNameFrom(structField)
|
||||
|
||||
if tableName == "" {
|
||||
tableName = snaker.CamelToSnake(structType.Name())
|
||||
|
|
@ -218,7 +218,7 @@ func getGroupKey(scanContext *scanContext, structType reflect.Type, structField
|
|||
if structGroupKey != "" {
|
||||
groupKeys = append(groupKeys, structGroupKey)
|
||||
}
|
||||
} else if field.Tag == `sql:"unique"` {
|
||||
} else if tagInfo(field.Tag.Get("sql")).isPrimaryKey {
|
||||
fieldName := field.Name
|
||||
columnName := tableName + "." + snaker.CamelToSnake(fieldName)
|
||||
|
||||
|
|
@ -293,11 +293,7 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
|
|||
|
||||
func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value {
|
||||
destinationSliceType := slicePtrValue.Type().Elem()
|
||||
elemType := destinationSliceType.Elem()
|
||||
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
return reflect.New(elemType.Elem())
|
||||
}
|
||||
elemType := indirectType(destinationSliceType.Elem())
|
||||
|
||||
return reflect.New(elemType)
|
||||
}
|
||||
|
|
@ -348,51 +344,37 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re
|
|||
return
|
||||
}
|
||||
|
||||
func getRefTableNameFrom(structField *reflect.StructField) string {
|
||||
func getRefTableNameFrom(structField *reflect.StructField) (table, column string) {
|
||||
if structField == nil {
|
||||
return ""
|
||||
return
|
||||
}
|
||||
|
||||
tagOverwriteName := structField.Tag.Get("sqlbuilder")
|
||||
sqlTag := structField.Tag.Get("sql")
|
||||
|
||||
if tagOverwriteName != "" {
|
||||
return tagOverwriteName
|
||||
if sqlTag != "" {
|
||||
tagInfo := tagInfo(sqlTag)
|
||||
|
||||
return tagInfo.table, tagInfo.column
|
||||
}
|
||||
|
||||
if !structField.Anonymous {
|
||||
return snaker.CamelToSnake(structField.Name)
|
||||
return snaker.CamelToSnake(structField.Name), ""
|
||||
}
|
||||
|
||||
var elemType string
|
||||
fieldType := indirectType(structField.Type)
|
||||
|
||||
if structField.Type.Kind() == reflect.Ptr {
|
||||
elem := structField.Type.Elem()
|
||||
if elem.Kind() == reflect.Struct {
|
||||
elemType = elem.Name()
|
||||
} else if elem.Kind() == reflect.Slice {
|
||||
elemType = elem.Elem().Name()
|
||||
}
|
||||
} else {
|
||||
if structField.Type.Kind() == reflect.Struct {
|
||||
elemType = structField.Type.Name()
|
||||
} else {
|
||||
sliceElem := structField.Type.Elem()
|
||||
if sliceElem.Kind() == reflect.Ptr {
|
||||
elemType = sliceElem.Elem().Name()
|
||||
} else {
|
||||
elemType = sliceElem.Name()
|
||||
}
|
||||
}
|
||||
if fieldType.Kind() == reflect.Slice {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
return snaker.CamelToSnake(elemType)
|
||||
return snaker.CamelToSnake(fieldType.Name()), ""
|
||||
}
|
||||
|
||||
func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
|
||||
structType := structPtrValue.Type().Elem()
|
||||
structValue := structPtrValue.Elem()
|
||||
|
||||
tableName := getRefTableNameFrom(structField)
|
||||
tableName, _ := getRefTableNameFrom(structField)
|
||||
|
||||
if tableName == "" {
|
||||
tableName = snaker.CamelToSnake(structType.Name())
|
||||
|
|
@ -689,3 +671,41 @@ func (s *scanContext) rowElemValuePtr(index int) reflect.Value {
|
|||
newElem.Elem().Set(rowElemValue)
|
||||
return newElem
|
||||
}
|
||||
|
||||
type sqlTagInfo struct {
|
||||
isPrimaryKey bool
|
||||
table string
|
||||
column string
|
||||
}
|
||||
|
||||
func tagInfo(tag string) sqlTagInfo {
|
||||
sqlTags := strings.Split(tag, ",")
|
||||
tagMap := map[string]string{}
|
||||
|
||||
for _, tag := range sqlTags {
|
||||
tagParts := strings.Split(tag, ":")
|
||||
|
||||
tagKey := tagParts[0]
|
||||
tagValue := ""
|
||||
if len(tagParts) > 1 {
|
||||
tagValue = tagParts[1]
|
||||
}
|
||||
|
||||
tagMap[tagKey] = tagValue
|
||||
}
|
||||
|
||||
_, isPrimaryKey := tagMap["primary_key"]
|
||||
|
||||
return sqlTagInfo{
|
||||
isPrimaryKey: isPrimaryKey,
|
||||
table: tagMap["table"],
|
||||
column: tagMap["column"],
|
||||
}
|
||||
}
|
||||
|
||||
func indirectType(reflectType reflect.Type) reflect.Type {
|
||||
if reflectType.Kind() != reflect.Ptr {
|
||||
return reflectType
|
||||
}
|
||||
return reflectType.Elem()
|
||||
}
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -142,11 +142,11 @@ VALUES (1, 1, 300, 300, 50000, 5000, 11.44, 11.44, 55.77, 55.77, 99.1, 99.1, 111
|
|||
DROP TABLE IF EXISTS test_sample.link;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS test_sample.link (
|
||||
ID serial PRIMARY KEY,
|
||||
url VARCHAR (255) NOT NULL,
|
||||
name VARCHAR (255) NOT NULL,
|
||||
description VARCHAR (255),
|
||||
rel VARCHAR (50)
|
||||
ID serial PRIMARY KEY,
|
||||
url VARCHAR (255) NOT NULL,
|
||||
name VARCHAR (255) NOT NULL,
|
||||
description VARCHAR (255),
|
||||
rel VARCHAR (50)
|
||||
);
|
||||
|
||||
|
||||
|
|
@ -158,6 +158,7 @@ CREATE TABLE test_sample.employee (
|
|||
employee_id INT PRIMARY KEY,
|
||||
first_name VARCHAR (255) NOT NULL,
|
||||
last_name VARCHAR (255) NOT NULL,
|
||||
employment_date timestamp with time zone,
|
||||
manager_id INT,
|
||||
FOREIGN KEY (manager_id)
|
||||
REFERENCES test_sample.employee (employee_id)
|
||||
|
|
@ -167,17 +168,18 @@ INSERT INTO test_sample.employee (
|
|||
employee_id,
|
||||
first_name,
|
||||
last_name,
|
||||
employment_date,
|
||||
manager_id
|
||||
)
|
||||
VALUES
|
||||
(1, 'Windy', 'Hays', NULL),
|
||||
(2, 'Ava', 'Christensen', 1),
|
||||
(3, 'Hassan', 'Conner', 1),
|
||||
(4, 'Anna', 'Reeves', 2),
|
||||
(5, 'Sau', 'Norman', 2),
|
||||
(6, 'Kelsie', 'Hays', 3),
|
||||
(7, 'Tory', 'Goff', 3),
|
||||
(8, 'Salley', 'Lester', 3);
|
||||
(1, 'Windy', 'Hays', '1999-01-08 04:05:06.100 -8:00', NULL),
|
||||
(2, 'Ava', 'Christensen', '1999-01-08 04:05:06', 1),
|
||||
(3, 'Hassan', 'Conner', '1999-01-08 04:05:06', 1),
|
||||
(4, 'Anna', 'Reeves', '1999-01-08 04:05:06', 2),
|
||||
(5, 'Sau', 'Norman', '1999-01-08 04:05:06', 2),
|
||||
(6, 'Kelsie', 'Hays', '1999-01-08 04:05:06', 3),
|
||||
(7, 'Tory', 'Goff', '1999-01-08 04:05:06', 3),
|
||||
(8, 'Salley', 'Lester', '1999-01-08 04:05:06', 3);
|
||||
|
||||
|
||||
-- Person table ------------------
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ func TestGenerateModel(t *testing.T) {
|
|||
assert.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32")
|
||||
actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID")
|
||||
assert.Assert(t, ok)
|
||||
assert.Equal(t, actorIDField.Tag.Get("sql"), "unique")
|
||||
assert.Equal(t, actorIDField.Tag.Get("sql"), "primary_key")
|
||||
assert.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string")
|
||||
assert.Equal(t, reflect.TypeOf(actor.LastName).String(), "string")
|
||||
assert.Equal(t, reflect.TypeOf(actor.LastUpdate).String(), "time.Time")
|
||||
|
|
@ -46,12 +46,12 @@ func TestGenerateModel(t *testing.T) {
|
|||
assert.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int16")
|
||||
filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID")
|
||||
assert.Assert(t, ok)
|
||||
assert.Equal(t, filmIDField.Tag.Get("sql"), "unique")
|
||||
assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key")
|
||||
|
||||
assert.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int16")
|
||||
actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID")
|
||||
assert.Assert(t, ok)
|
||||
assert.Equal(t, filmIDField.Tag.Get("sql"), "unique")
|
||||
assert.Equal(t, filmIDField.Tag.Get("sql"), "primary_key")
|
||||
|
||||
staff := model.Staff{}
|
||||
|
||||
|
|
|
|||
|
|
@ -364,6 +364,8 @@ func TestScanToNestedStruct(t *testing.T) {
|
|||
SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns, Language.AllColumns).
|
||||
WHERE(Inventory.InventoryID.EQ(Int(1)))
|
||||
|
||||
type Language3 model.Language
|
||||
|
||||
dest := struct {
|
||||
model.Inventory
|
||||
Film struct {
|
||||
|
|
@ -371,7 +373,7 @@ func TestScanToNestedStruct(t *testing.T) {
|
|||
|
||||
Language model.Language
|
||||
Language2 *model.Language
|
||||
Language3 *model.Language `sqlbuilder:"language"`
|
||||
Language3 *Language3 `sql:"table:language"`
|
||||
Lang struct {
|
||||
model.Language
|
||||
}
|
||||
|
|
@ -392,7 +394,7 @@ func TestScanToNestedStruct(t *testing.T) {
|
|||
assert.DeepEqual(t, dest.Film.Lang.Language, language1)
|
||||
assert.DeepEqual(t, dest.Film.Lang2.Language, language1)
|
||||
assert.DeepEqual(t, dest.Film.Language2, (*model.Language)(nil))
|
||||
assert.DeepEqual(t, dest.Film.Language3, &language1)
|
||||
assert.DeepEqual(t, model.Language(*dest.Film.Language3), language1)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -443,7 +445,7 @@ func TestScanToSlice(t *testing.T) {
|
|||
t.Run("struct with slice of ints", func(t *testing.T) {
|
||||
var dest struct {
|
||||
model.Film
|
||||
IDs []int32 `sqlbuilder:"inventory.inventory_id"`
|
||||
IDs []int32 `sql:"table:inventory,column:inventory_id"`
|
||||
}
|
||||
|
||||
err := query.Query(db, &dest)
|
||||
|
|
@ -456,7 +458,7 @@ func TestScanToSlice(t *testing.T) {
|
|||
t.Run("slice of structs with slice of ints", func(t *testing.T) {
|
||||
var dest []struct {
|
||||
model.Film
|
||||
IDs []int32 `sqlbuilder:"inventory.inventory_id"`
|
||||
IDs []int32 `sql:"table:inventory,column:inventory_id"`
|
||||
}
|
||||
|
||||
err := query.Query(db, &dest)
|
||||
|
|
@ -472,7 +474,7 @@ func TestScanToSlice(t *testing.T) {
|
|||
t.Run("slice of structs with slice of pointer to ints", func(t *testing.T) {
|
||||
var dest []struct {
|
||||
model.Film
|
||||
IDs []*int32 `sqlbuilder:"inventory.inventory_id"`
|
||||
IDs []*int32 `sql:"table:inventory,column:inventory_id"`
|
||||
}
|
||||
|
||||
err := query.Query(db, &dest)
|
||||
|
|
|
|||
|
|
@ -495,10 +495,12 @@ func TestSelecSelfJoin1(t *testing.T) {
|
|||
SELECT employee.employee_id AS "employee.employee_id",
|
||||
employee.first_name AS "employee.first_name",
|
||||
employee.last_name AS "employee.last_name",
|
||||
employee.employment_date AS "employee.employment_date",
|
||||
employee.manager_id AS "employee.manager_id",
|
||||
manager.employee_id AS "manager.employee_id",
|
||||
manager.first_name AS "manager.first_name",
|
||||
manager.last_name AS "manager.last_name",
|
||||
manager.employment_date AS "manager.employment_date",
|
||||
manager.manager_id AS "manager.manager_id"
|
||||
FROM test_sample.employee
|
||||
LEFT JOIN test_sample.employee AS manager ON (manager.employee_id = employee.manager_id)
|
||||
|
|
@ -524,19 +526,21 @@ ORDER BY employee.employee_id;
|
|||
assert.NilError(t, err)
|
||||
assert.Equal(t, len(dest), 8)
|
||||
assert.DeepEqual(t, dest[0].Employee, model2.Employee{
|
||||
EmployeeID: 1,
|
||||
FirstName: "Windy",
|
||||
LastName: "Hays",
|
||||
ManagerID: nil,
|
||||
EmployeeID: 1,
|
||||
FirstName: "Windy",
|
||||
LastName: "Hays",
|
||||
EmploymentDate: timestampWithTimeZone("1999-01-08 13:05:06.1 +0100 CET", 1),
|
||||
ManagerID: nil,
|
||||
})
|
||||
|
||||
assert.Assert(t, dest[0].Manager == nil)
|
||||
|
||||
assert.DeepEqual(t, dest[7].Employee, model2.Employee{
|
||||
EmployeeID: 8,
|
||||
FirstName: "Salley",
|
||||
LastName: "Lester",
|
||||
ManagerID: int32Ptr(3),
|
||||
EmployeeID: 8,
|
||||
FirstName: "Salley",
|
||||
LastName: "Lester",
|
||||
EmploymentDate: timestampWithTimeZone("1999-01-08 04:05:06 +0100 CET", 1),
|
||||
ManagerID: int32Ptr(3),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue