diff --git a/generator/metadata/column_info.go b/generator/metadata/column_info.go index 826a33f..a8b01e8 100644 --- a/generator/metadata/column_info.go +++ b/generator/metadata/column_info.go @@ -28,7 +28,7 @@ func (c ColumnInfo) ToGoVarName() string { func (c ColumnInfo) ToGoType() string { typeStr := c.GoBaseType() - if c.IsNullable { + if c.IsNullable || c.TableInfo.IsForeignKey(c.Name) { return "*" + typeStr } @@ -54,6 +54,8 @@ func (c ColumnInfo) GoBaseType() string { return "[]byte" case "text": return "string" + case "numeric", "real": + return "float64" default: return "string" } diff --git a/sqlbuilder/column.go b/sqlbuilder/column.go index 944f773..afe81e9 100644 --- a/sqlbuilder/column.go +++ b/sqlbuilder/column.go @@ -74,16 +74,24 @@ func (c *baseColumn) setTableName(table string) error { } func (c *baseColumn) SerializeSqlForColumnList(out *bytes.Buffer) error { + + c.SerializeSql(out) + + if c.table != "" { + _, _ = out.WriteString(" AS \"" + c.table + "." + c.name + "\"") + } + + return nil +} + +func (c *baseColumn) SerializeSql(out *bytes.Buffer) error { if c.table != "" { _, _ = out.WriteString(c.table) _, _ = out.WriteString(".") } _, _ = out.WriteString(c.name) - return nil -} -func (c *baseColumn) SerializeSql(out *bytes.Buffer) error { - return c.SerializeSqlForColumnList(out) + return nil } type bytesColumn struct { diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index 3a70e66..f3f7784 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/serenize/snaker" "reflect" + "time" ) func Execute(db *sql.DB, query string, destinationPtr interface{}) error { @@ -26,10 +27,14 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error { if err != nil { return err } + defer rows.Close() columnNames, _ := rows.Columns() columnTypes, _ := rows.ColumnTypes() values := createScanValue(columnTypes) + // + //spew.Dump(columnTypes) + //spew.Dump(values) for rows.Next() { err := rows.Scan(values...) @@ -79,30 +84,79 @@ func newElemForSlice(destinationSlicePtr interface{}) interface{} { func mapValuesToStruct(columnNames []string, row []interface{}, destination interface{}) error { structType := reflect.TypeOf(destination).Elem() structValue := reflect.ValueOf(destination).Elem() + structName := structType.Name() for i := 0; i < structType.NumField(); i++ { fieldType := structType.Field(i) + //fieldTypeName := fieldType.Name fieldValue := structValue.Field(i) + //fmt.Println("---------------", fieldTypeName) + //spew.Dump(fieldType.Type) - fieldName := fieldType.Name + if !isDbBaseType(fieldType.Type) { + if fieldType.Type.Kind() == reflect.Struct { + err := mapValuesToStruct(columnNames, row, fieldValue.Addr().Interface()) + if err != nil { + return err + } + } else if fieldType.Type.Kind() == reflect.Ptr { + newStructValue := reflect.New(fieldType.Type.Elem()) + err := mapValuesToStruct(columnNames, row, newStructValue.Interface()) + if err != nil { + return err + } - //columnName := structName + "." + fieldName - columnName := snaker.CamelToSnake(fieldName) + if newStructValue.Elem().Interface() != reflect.New(fieldType.Type.Elem()).Elem().Interface() { + fieldValue.Set(newStructValue) + } + } + } else { + fieldName := fieldType.Name - rowIndex := getIndex(columnNames, columnName) + columnName := snaker.CamelToSnake(structName) + "." + snaker.CamelToSnake(fieldName) + //columnName := snaker.CamelToSnake(fieldName) - if rowIndex < 0 { - continue + //fmt.Println(columnName) + rowIndex := getIndex(columnNames, columnName) + + if rowIndex < 0 { + continue + } + + //spew.Dump(row[rowIndex]) + + rowColumnValue := reflect.ValueOf(row[rowIndex]) + + //spew.Dump(rowColumnValue, fieldValue) + setReflectValue(rowColumnValue, fieldValue) } - - rowColumnValue := reflect.ValueOf(row[rowIndex]) - - setReflectValue(rowColumnValue, fieldValue) } return nil } +var timeType = reflect.TypeOf(time.Now()) +var floatType = reflect.TypeOf(1.0) +var stringType = reflect.TypeOf("str") +var intType = reflect.TypeOf(1) + +func isDbBaseType(objType reflect.Type) bool { + //isBaseType := objType == timeType || floatType == objType || stringType == objType || intType == objType + //isPtrToBaseType := objType.Kind() == reflect.Ptr && (objType.Elem() == timeType || floatType == objType.Elem() || + // stringType == objType.Elem() || intType == objType.Elem()) + typeStr := objType.String() + + switch typeStr { + case "string", "int32", "int16", "float64", "time.Time": + return true + case "*string", "*int32", "*int16", "*float64", "*time.Time": + return true + } + + //return isBaseType || isPtrToBaseType + return false +} + func setReflectValue(source, destination reflect.Value) { if destination.Kind() == reflect.Ptr { if source.Kind() == reflect.Ptr { @@ -133,7 +187,7 @@ func createScanValue(columnTypes []*sql.ColumnType) []interface{} { values := make([]interface{}, len(columnTypes)) for i, sqlColumnType := range columnTypes { - columnType := sqlColumnType.ScanType() + columnType := getScanType(sqlColumnType) columnValue := reflect.New(columnType) @@ -142,3 +196,18 @@ func createScanValue(columnTypes []*sql.ColumnType) []interface{} { return values } + +func getScanType(columnType *sql.ColumnType) reflect.Type { + scanType := columnType.ScanType() + //fmt.Println(scanType.String()) + if scanType.String() != "interface {}" { + return scanType + } + + switch columnType.DatabaseTypeName() { + case "FLOAT4": + return floatType + default: + return stringType + } +} diff --git a/sqlbuilder/expression.go b/sqlbuilder/expression.go index f7ef1f4..b42084b 100644 --- a/sqlbuilder/expression.go +++ b/sqlbuilder/expression.go @@ -443,7 +443,7 @@ func Eq(lhs, rhs Expression) BoolExpression { if ok && sqltypes.Value(lit.value).IsNull() { return newBoolExpression(lhs, rhs, []byte(" IS ")) } - return newBoolExpression(lhs, rhs, []byte("=")) + return newBoolExpression(lhs, rhs, []byte(" = ")) } // Returns a representation of "a=b", where b is a literal diff --git a/tests/generator_test.go b/tests/generator_test.go index 6ffba6b..177ba22 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -3,8 +3,8 @@ package tests import ( "database/sql" "fmt" - "github.com/davecgh/go-spew/spew" "github.com/sub0Zero/go-sqlbuilder/generator" + "github.com/sub0Zero/go-sqlbuilder/sqlbuilder" "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model" . "github.com/sub0Zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table" "gotest.tools/assert" @@ -74,8 +74,7 @@ func TestSelectQuery(t *testing.T) { queryStr, err := query.String() assert.NilError(t, err) - assert.Equal(t, queryStr, "SELECT customer.customer_id,customer.store_id,customer.first_name,customer.last_name,customer.email,customer.address_id,customer.activebool,customer.create_date,customer.last_update,customer.active FROM dvds.customer") - + assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id",customer.store_id AS "customer.store_id",customer.first_name AS "customer.first_name",customer.last_name AS "customer.last_name",customer.email AS "customer.email",customer.address_id AS "customer.address_id",customer.activebool AS "customer.activebool",customer.create_date AS "customer.create_date",customer.last_update AS "customer.last_update",customer.active AS "customer.active" FROM dvds.customer`) //fmt.Println(queryStr) err = query.Execute(db, &customers) @@ -93,9 +92,34 @@ func TestSelectQuery(t *testing.T) { assert.NilError(t, err) - spew.Dump(actor) + //spew.Dump(actor) //time, _ := time.Parse("2006-01-02 15:04:05.00MST", "2013-05-26 14:47:57.62MST") assert.Equal(t, actor.ActorID, int32(1)) assert.Equal(t, actor.FirstName, "Penelope") assert.Equal(t, actor.LastName, "Guiness") } + +func TestJoinQuery(t *testing.T) { + + //filmActor := model.FilmActor{} + allFilmActorColumns := append(append(Actor.All, Film.All...), Language.All...) + query := FilmActor. + InnerJoinOn(Actor, sqlbuilder.Eq(FilmActor.ActorID, Actor.ActorID)). + InnerJoinOn(Film, sqlbuilder.Eq(FilmActor.FilmID, Film.FilmID)). + InnerJoinOn(Language, sqlbuilder.Eq(Film.LanguageID, Language.LanguageID)). + Select(allFilmActorColumns...). + Where(sqlbuilder.Eq(FilmActor.ActorID, sqlbuilder.Literal(1))) + + queryStr, err := query.String() + assert.NilError(t, err) + + fmt.Println(queryStr) + + filmActor := model.FilmActor{} + + err = query.Execute(db, &filmActor) + + assert.NilError(t, err) + + //spew.Dump(filmActor) +}