From 75f8e0dfec5a66f186253eada8ceab28214f148e Mon Sep 17 00:00:00 2001 From: sub0Zero Date: Tue, 5 Mar 2019 18:55:47 +0100 Subject: [PATCH] Select statement execution and mapping to struct or slice added. --- generator/metadata/column_info.go | 5 +- generator/metadata/table_info.go | 20 +++++ generator/templates.go | 8 ++ sqlbuilder/execution/execution.go | 144 ++++++++++++++++++++++++++++++ sqlbuilder/statement.go | 48 ++++++++++ tests/generator_test.go | 79 +++++++++++++--- 6 files changed, 291 insertions(+), 13 deletions(-) create mode 100644 sqlbuilder/execution/execution.go diff --git a/generator/metadata/column_info.go b/generator/metadata/column_info.go index 84dd754..826a33f 100644 --- a/generator/metadata/column_info.go +++ b/generator/metadata/column_info.go @@ -45,10 +45,11 @@ func (c ColumnInfo) GoBaseType() string { case "smallint": return "int16" case "integer": - return "int" + return "int32" case "bigint": return "int64" - //case "date" : return "time.Time" + case "date", "timestamp without time zone", "timestamp with time zone": + return "time.Time" case "bytea": return "[]byte" case "text": diff --git a/generator/metadata/table_info.go b/generator/metadata/table_info.go index b667948..e3c6d77 100644 --- a/generator/metadata/table_info.go +++ b/generator/metadata/table_info.go @@ -15,6 +15,26 @@ type TableInfo struct { DatabaseInfo *DatabaseInfo } +func (t TableInfo) GetImports() []string { + imports := map[string]string{} + + for _, column := range t.Columns { + columnType := column.GoBaseType() + + if columnType == "time.Time" { + imports["time.Time"] = "time" + } + } + + ret := []string{} + + for _, packageImport := range imports { + ret = append(ret, packageImport) + } + + return ret +} + func (t TableInfo) IsForeignKey(columnName string) bool { _, exist := t.ForeignTableMap[columnName] diff --git a/generator/templates.go b/generator/templates.go index 626f1d8..a6c6c30 100644 --- a/generator/templates.go +++ b/generator/templates.go @@ -11,6 +11,8 @@ type {{.ToGoStructName}} struct { {{- range .Columns}} {{.ToGoFieldName}} sqlbuilder.NonAliasColumn {{- end}} + + All []sqlbuilder.Projection } var {{.ToGoVarName}} = &{{.ToGoStructName}}{ @@ -20,6 +22,8 @@ var {{.ToGoVarName}} = &{{.ToGoStructName}}{ {{- range .Columns}} {{.ToGoFieldName}}: {{.ToGoVarName}}, {{- end}} + + All: []sqlbuilder.Projection{ {{.ToGoColumnFieldList ", "}} }, } var ( @@ -31,6 +35,10 @@ var ( var DataModelTemplate = `package model +{{range .GetImports}} + import "{{.}}" +{{end}} + type {{.ToGoModelStructName}} struct { {{- range .Columns}} {{.ToGoDMFieldName}} {{.ToGoType}} {{if .IsUnique}}` + "`sql:\"unique\"`" + ` {{end}} diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go new file mode 100644 index 0000000..3a70e66 --- /dev/null +++ b/sqlbuilder/execution/execution.go @@ -0,0 +1,144 @@ +package execution + +import ( + "database/sql" + "errors" + "github.com/serenize/snaker" + "reflect" +) + +func Execute(db *sql.DB, query string, destinationPtr interface{}) error { + if db == nil { + return errors.New("db is nil") + } + + if destinationPtr == nil { + return errors.New("Destination is nil ") + } + + destinationType := reflect.TypeOf(destinationPtr) + if destinationType.Kind() != reflect.Ptr { + return errors.New("Destination has to be a pointer to slice or pointer to struct ") + } + + rows, err := db.Query(query) + + if err != nil { + return err + } + + columnNames, _ := rows.Columns() + columnTypes, _ := rows.ColumnTypes() + values := createScanValue(columnTypes) + + for rows.Next() { + err := rows.Scan(values...) + + if err != nil { + return err + } + + if destinationType.Elem().Kind() == reflect.Slice { + + destinationStructPtr := newElemForSlice(destinationPtr) + + err = mapValuesToStruct(columnNames, values, destinationStructPtr) + + if err != nil { + return err + } + + appendElemToSlice(destinationPtr, destinationStructPtr) + } else if destinationType.Elem().Kind() == reflect.Struct { + return mapValuesToStruct(columnNames, values, destinationPtr) + } + } + + err = rows.Err() + + if err != nil { + return err + } + + return nil +} + +func appendElemToSlice(slice interface{}, obj interface{}) { + //spew.Dump(slice) + sliceValue := reflect.ValueOf(slice).Elem() + + sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(obj).Elem())) +} + +func newElemForSlice(destinationSlicePtr interface{}) interface{} { + destinationSliceType := reflect.TypeOf(destinationSlicePtr).Elem() + + return reflect.New(destinationSliceType.Elem()).Interface() +} + +func mapValuesToStruct(columnNames []string, row []interface{}, destination interface{}) error { + structType := reflect.TypeOf(destination).Elem() + structValue := reflect.ValueOf(destination).Elem() + + for i := 0; i < structType.NumField(); i++ { + fieldType := structType.Field(i) + fieldValue := structValue.Field(i) + + fieldName := fieldType.Name + + //columnName := structName + "." + fieldName + columnName := snaker.CamelToSnake(fieldName) + + rowIndex := getIndex(columnNames, columnName) + + if rowIndex < 0 { + continue + } + + rowColumnValue := reflect.ValueOf(row[rowIndex]) + + setReflectValue(rowColumnValue, fieldValue) + } + + return nil +} + +func setReflectValue(source, destination reflect.Value) { + if destination.Kind() == reflect.Ptr { + if source.Kind() == reflect.Ptr { + destination.Set(source) + } else { + destination.Set(source.Addr()) + } + } else { + if source.Kind() == reflect.Ptr { + destination.Set(source.Elem()) + } else { + destination.Set(source) + } + } +} + +func getIndex(list []string, text string) int { + for i, str := range list { + if str == text { + return i + } + } + + return -1 +} + +func createScanValue(columnTypes []*sql.ColumnType) []interface{} { + values := make([]interface{}, len(columnTypes)) + + for i, sqlColumnType := range columnTypes { + columnType := sqlColumnType.ScanType() + + columnValue := reflect.New(columnType) + + values[i] = columnValue.Interface() + } + + return values +} diff --git a/sqlbuilder/statement.go b/sqlbuilder/statement.go index 3536c77..df6ca0d 100644 --- a/sqlbuilder/statement.go +++ b/sqlbuilder/statement.go @@ -2,7 +2,10 @@ package sqlbuilder import ( "bytes" + "database/sql" "fmt" + "github.com/sub0Zero/go-sqlbuilder/sqlbuilder/execution" + "reflect" "regexp" "github.com/dropbox/godropbox/errors" @@ -11,6 +14,7 @@ import ( type Statement interface { // String returns generated SQL as string. String(database string) (sql string, err error) + Execute(db *sql.DB, destination interface{}) error } type SelectStatement interface { @@ -133,6 +137,10 @@ type unionStatementImpl struct { unique bool } +func (us *unionStatementImpl) Execute(db *sql.DB, data interface{}) error { + return nil +} + func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement { us.where = expression return us @@ -305,6 +313,22 @@ type selectStatementImpl struct { distinct bool } +func (s *selectStatementImpl) Execute(db *sql.DB, destination interface{}) error { + destinationType := reflect.TypeOf(destination) + + if destinationType.Kind() == reflect.Ptr && destinationType.Elem().Kind() == reflect.Struct { + s.Limit(1) + } + + query, err := s.String("dvds") + + if err != nil { + return err + } + + return execution.Execute(db, query, destination) +} + func (s *selectStatementImpl) Copy() SelectStatement { ret := *s return &ret @@ -495,6 +519,10 @@ type insertStatementImpl struct { ignore bool } +func (i *insertStatementImpl) Execute(db *sql.DB, data interface{}) error { + return nil +} + func (s *insertStatementImpl) Add( row ...Expression) InsertStatement { @@ -665,6 +693,10 @@ type updateStatementImpl struct { comment string } +func (u *updateStatementImpl) Execute(db *sql.DB, data interface{}) error { + return nil +} + func (u *updateStatementImpl) Set( column NonAliasColumn, expression Expression) UpdateStatement { @@ -808,6 +840,10 @@ type deleteStatementImpl struct { comment string } +func (d *deleteStatementImpl) Execute(db *sql.DB, data interface{}) error { + return nil +} + func (d *deleteStatementImpl) Where(expression BoolExpression) DeleteStatement { d.where = expression return d @@ -895,6 +931,10 @@ type tableLock struct { w bool } +func (l *lockStatementImpl) Execute(db *sql.DB, data interface{}) error { + return nil +} + // AddReadLock takes read lock on the table. func (s *lockStatementImpl) AddReadLock(t *Table) LockStatement { s.locks = append(s.locks, tableLock{t: t, w: false}) @@ -951,6 +991,10 @@ func NewUnlockStatement() UnlockStatement { type unlockStatementImpl struct { } +func (u *unlockStatementImpl) Execute(db *sql.DB, data interface{}) error { + return nil +} + func (s *unlockStatementImpl) String(database string) (sql string, err error) { return "UNLOCK TABLES", nil } @@ -968,6 +1012,10 @@ type gtidNextStatementImpl struct { gno uint64 } +func (g *gtidNextStatementImpl) Execute(db *sql.DB, data interface{}) error { + return nil +} + func (s *gtidNextStatementImpl) String(database string) (sql string, err error) { // This statement sets a session local variable defining what the next transaction ID is. It // does not interact with other MySQL sessions. It is neither a DDL nor DML statement, so we diff --git a/tests/generator_test.go b/tests/generator_test.go index 3bf31a7..1c40faa 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -1,15 +1,18 @@ package tests import ( + "database/sql" "fmt" - . "github.com/sub0Zero/.test_files/dvd_rental/dvds/table" + "github.com/davecgh/go-spew/spew" "github.com/sub0Zero/go-sqlbuilder/generator" - . "github.com/sub0Zero/go-sqlbuilder/sqlbuilder" + "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model" + . "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table" "gotest.tools/assert" + "os" "testing" ) -var ( +const ( folderPath = ".test_files/" host = "localhost" port = 5432 @@ -19,26 +22,80 @@ var ( schemaName = "dvds" ) +var connectString = fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", host, port, user, password, dbname) +var db *sql.DB + //go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema dvds -path .test_files +func TestMain(m *testing.M) { + fmt.Println("Begin") + var err error + db, err = sql.Open("postgres", connectString) + if err != nil { + panic("Failed to connect to test db") + } + defer db.Close() + + ret := m.Run() + + db.Close() + fmt.Println("END") + + os.Exit(ret) +} + func TestGenerateModel(t *testing.T) { - connectString := fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", - host, port, user, password, dbname) err := generator.Generate(folderPath, connectString, dbname, schemaName) assert.NilError(t, err) - err = generator.Generate(folderPath, connectString, dbname, "sport") - - assert.NilError(t, err) + //err = generator.Generate(folderPath, connectString, dbname, "sport") + // + //assert.NilError(t, err) } func TestSelectQuery(t *testing.T) { - query, err := Actor.InnerJoinOn(Store, Eq(Actor.ActorID, Store.StoreID)). - Select(Store.StoreID, Store.AddressID, Actor.ActorID).String(schemaName) + //query := Actor.InnerJoinOn(Store, Eq(Actor.ActorID, Store.StoreID)). + // Select(Store.StoreID, Store.AddressID, Actor.ActorID) + // + //queryStr, err := query.String(schemaName) + // + //assert.NilError(t, err) + // + //assert.Equal(t, queryStr, "SELECT store.store_id,store.address_id,actor.actor_id FROM dvds.actor JOIN dvds.store ON actor.actor_id=store.store_id") + // + //err = query.Execute(db, nil) + + customers := []model.Customer{} + + query := Customer.Select(Customer.All...) + + queryStr, err := query.String(schemaName) + + assert.NilError(t, err) + assert.Equal(t, queryStr, "SELECT customer.customer_id,customer.store_id,customer.first_name,customer.last_name,customer.email,customer.address_id,customer.activebool,customer.create_date,customer.last_update,customer.active FROM dvds.customer") + + //fmt.Println(queryStr) + + err = query.Execute(db, &customers) + + //fmt.Println(customers) + // + //spew.Sdump(customers) assert.NilError(t, err) - assert.Equal(t, query, "SELECT store.store_id,store.address_id,actor.actor_id FROM dvds.actor JOIN dvds.store ON actor.actor_id=store.store_id") + assert.Equal(t, len(customers), 599) + + actor := model.Actor{} + err = Actor.Select(Actor.All...).Execute(db, &actor) + + assert.NilError(t, err) + + spew.Dump(actor) + //time, _ := time.Parse("2006-01-02 15:04:05.00MST", "2013-05-26 14:47:57.62MST") + assert.Equal(t, actor.ActorID, int32(1)) + assert.Equal(t, actor.FirstName, "Penelope") + assert.Equal(t, actor.LastName, "Guiness") }