From 8049b2ec01efd51e3d1bc35db108dd998ce57789 Mon Sep 17 00:00:00 2001 From: sub0Zero Date: Fri, 15 Mar 2019 21:55:43 +0100 Subject: [PATCH] Order by column simplified. --- sqlbuilder/column.go | 149 +++++++++++++++++------------- sqlbuilder/execution/execution.go | 19 +++- sqlbuilder/table.go | 12 +-- tests/generator_test.go | 62 +++++++++++++ 4 files changed, 167 insertions(+), 75 deletions(-) diff --git a/sqlbuilder/column.go b/sqlbuilder/column.go index 0c2039a..4f957b6 100644 --- a/sqlbuilder/column.go +++ b/sqlbuilder/column.go @@ -4,7 +4,6 @@ package sqlbuilder import ( "bytes" - "github.com/dropbox/godropbox/database/sqltypes" "regexp" "github.com/dropbox/godropbox/errors" @@ -34,6 +33,9 @@ type Column interface { Lte(rhs Expression) BoolExpression LteLiteral(rhs interface{}) BoolExpression + + Asc() OrderByClause + Desc() OrderByClause } type NullableColumn bool @@ -93,7 +95,7 @@ func (c *baseColumn) SerializeSqlForColumnList(out *bytes.Buffer) error { return nil } -func (c *baseColumn) SerializeSql(out *bytes.Buffer) error { +func (c baseColumn) SerializeSql(out *bytes.Buffer) error { if c.table != "" { _, _ = out.WriteString(c.table) _, _ = out.WriteString(".") @@ -123,6 +125,14 @@ func (c *baseColumn) LteLiteral(literal interface{}) BoolExpression { return Lte(c, Literal(literal)) } +func (c *baseColumn) Asc() OrderByClause { + return Asc(c) +} + +func (c *baseColumn) Desc() OrderByClause { + return Desc(c) +} + type bytesColumn struct { baseColumn isExpression @@ -295,66 +305,75 @@ func validIdentifierName(name string) bool { return validIdentifierRegexp.MatchString(name) } -// Pseudo Column type returned by table.C(name) -type deferredLookupColumn struct { - isProjection - isExpression - table *Table - colName string - - cachedColumn NonAliasColumn -} - -func (c *deferredLookupColumn) Name() string { - return c.colName -} - -func (c *deferredLookupColumn) SerializeSqlForColumnList( - out *bytes.Buffer) error { - - return c.SerializeSql(out) -} - -func (c *deferredLookupColumn) SerializeSql(out *bytes.Buffer) error { - if c.cachedColumn != nil { - return c.cachedColumn.SerializeSql(out) - } - - col, err := c.table.getColumn(c.colName) - if err != nil { - return err - } - - c.cachedColumn = col - return col.SerializeSql(out) -} - -func (c *deferredLookupColumn) setTableName(table string) error { - return errors.Newf( - "Lookup column '%s' should never have setTableName called on it", - c.colName) -} - -func (c *deferredLookupColumn) Eq(rhs Expression) BoolExpression { - lit, ok := rhs.(*literalExpression) - if ok && sqltypes.Value(lit.value).IsNull() { - return newBoolExpression(c, rhs, []byte(" IS ")) - } - return newBoolExpression(c, rhs, []byte(" = ")) -} - -func (c *deferredLookupColumn) Gte(rhs Expression) BoolExpression { - return Gte(c, rhs) -} - -func (c *deferredLookupColumn) GteLiteral(rhs interface{}) BoolExpression { - return Gte(c, Literal(rhs)) -} - -func (c *deferredLookupColumn) Lte(rhs Expression) BoolExpression { - return Lte(c, rhs) -} - -func (c *deferredLookupColumn) LteLiteral(literal interface{}) BoolExpression { - return Lte(c, Literal(literal)) -} +// +//// Pseudo Column type returned by table.C(name) +//type deferredLookupColumn struct { +// isProjection +// isExpression +// table *Table +// colName string +// +// cachedColumn NonAliasColumn +//} +// +//func (c *deferredLookupColumn) Name() string { +// return c.colName +//} +// +//func (c *deferredLookupColumn) SerializeSqlForColumnList( +// out *bytes.Buffer) error { +// +// return c.SerializeSql(out) +//} +// +//func (c *deferredLookupColumn) SerializeSql(out *bytes.Buffer) error { +// if c.cachedColumn != nil { +// return c.cachedColumn.SerializeSql(out) +// } +// +// col, err := c.table.getColumn(c.colName) +// if err != nil { +// return err +// } +// +// c.cachedColumn = col +// return col.SerializeSql(out) +//} +// +//func (c *deferredLookupColumn) setTableName(table string) error { +// return errors.Newf( +// "Lookup column '%s' should never have setTableName called on it", +// c.colName) +//} +// +//func (c *deferredLookupColumn) Eq(rhs Expression) BoolExpression { +// lit, ok := rhs.(*literalExpression) +// if ok && sqltypes.Value(lit.value).IsNull() { +// return newBoolExpression(c, rhs, []byte(" IS ")) +// } +// return newBoolExpression(c, rhs, []byte(" = ")) +//} +// +//func (c *deferredLookupColumn) Gte(rhs Expression) BoolExpression { +// return Gte(c, rhs) +//} +// +//func (c *deferredLookupColumn) GteLiteral(rhs interface{}) BoolExpression { +// return Gte(c, Literal(rhs)) +//} +// +//func (c *deferredLookupColumn) Lte(rhs Expression) BoolExpression { +// return Lte(c, rhs) +//} +// +//func (c *deferredLookupColumn) LteLiteral(literal interface{}) BoolExpression { +// return Lte(c, Literal(literal)) +//} +// +//func (c *deferredLookupColumn) Asc() OrderByClause { +// return sqlbuilder.Asc(c) +//} +// +//func (c *deferredLookupColumn) Desc() OrderByClause { +// return sqlbuilder.Desc(c) +//} diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index 1f67978..52c2307 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/serenize/snaker" "reflect" + "strconv" "strings" "time" ) @@ -36,6 +37,7 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error { rowData := createScanValue(columnTypes) scanContext := &scanContext{ + columnNames: columnNames, uniqueObjectsMap: make(map[string]interface{}), } @@ -47,6 +49,8 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error { return err } + scanContext.rowNum++ + columnProcessed := make([]bool, len(columnTypes)) if destinationType.Elem().Kind() == reflect.Slice { @@ -70,6 +74,7 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error { } type scanContext struct { + rowNum int columnNames []string uniqueObjectsMap map[string]interface{} } @@ -107,13 +112,13 @@ func getGroupKey(scanContext *scanContext, row []interface{}, structType reflect columnName := snaker.CamelToSnake(structName) + "." + snaker.CamelToSnake(fieldName) //fmt.Println(fieldName) - rowIndex := getIndex(scanContext.columnNames, columnName) + index := getIndex(scanContext.columnNames, columnName) - if rowIndex < 0 { + if index < 0 { continue } - rowValue := reflect.ValueOf(row[rowIndex]) + rowValue := reflect.ValueOf(row[index]) groupKey = groupKey + reflectValueToString(rowValue) } else if !isDbBaseType(fieldType.Type) { @@ -161,7 +166,13 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, columnProcessed [] structType := getSliceStructType(destinationPtr) - groupKey = groupKey + ":" + getGroupKey(scanContext, row, structType) + structGroupKey := getGroupKey(scanContext, row, structType) + + if structGroupKey == "" { + structGroupKey = strconv.Itoa(scanContext.rowNum) + } + + groupKey = groupKey + ":" + structGroupKey objPtr, ok := scanContext.uniqueObjectsMap[groupKey] diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index 2a58bfd..d8ccb37 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -97,12 +97,12 @@ func (t *Table) getColumn(name string) (NonAliasColumn, error) { // Returns a pseudo column representation of the column name. Error checking // is deferred to SerializeSql. -func (t *Table) C(name string) NonAliasColumn { - return &deferredLookupColumn{ - table: t, - colName: name, - } -} +//func (t *Table) C(name string) NonAliasColumn { +// return &deferredLookupColumn{ +// table: t, +// colName: name, +// } +//} // Returns all columns for a table as a slice of projections func (t *Table) Projections() []Projection { diff --git a/tests/generator_test.go b/tests/generator_test.go index d877b40..a0491de 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -219,6 +219,68 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) { assert.Equal(t, len(*filmsPerLanguageWithPtrs[0].Films), int(limit)) } +func TestSelect_WithoutUniqueColumnSelected(t *testing.T) { + query := Customer.Select(Customer.FirstName, Customer.LastName, Customer.Email) + + customers := []model.Customer{} + + err := query.Execute(db, &customers) + + assert.NilError(t, err) + + //spew.Dump(customers) + + assert.Equal(t, len(customers), 599) +} + +func TestSelectOrderByAscDesc(t *testing.T) { + customersAsc := []model.Customer{} + + err := Customer.Select(Customer.CustomerID, Customer.FirstName, Customer.LastName). + OrderBy(Customer.FirstName.Asc()). + Execute(db, &customersAsc) + + assert.NilError(t, err) + + firstCustomerAsc := customersAsc[0] + lastCustomerAsc := customersAsc[len(customersAsc)-1] + + customersDesc := []model.Customer{} + err = Customer.Select(Customer.CustomerID, Customer.FirstName, Customer.LastName). + OrderBy(Customer.FirstName.Desc()). + Execute(db, &customersDesc) + + assert.NilError(t, err) + + firstCustomerDesc := customersDesc[0] + lastCustomerDesc := customersDesc[len(customersAsc)-1] + + assert.DeepEqual(t, firstCustomerAsc, lastCustomerDesc) + assert.DeepEqual(t, lastCustomerAsc, firstCustomerDesc) + + customersAscDesc := []model.Customer{} + err = Customer.Select(Customer.CustomerID, Customer.FirstName, Customer.LastName). + OrderBy(Customer.FirstName.Asc(), Customer.LastName.Desc()). + Execute(db, &customersAscDesc) + + assert.NilError(t, err) + + customerAscDesc326 := model.Customer{ + CustomerID: 67, + FirstName: "Kelly", + LastName: "Torres", + } + + customerAscDesc327 := model.Customer{ + CustomerID: 546, + FirstName: "Kelly", + LastName: "Knott", + } + + assert.DeepEqual(t, customerAscDesc326, customersAscDesc[326]) + assert.DeepEqual(t, customerAscDesc327, customersAscDesc[327]) +} + func int32Ptr(i int32) *int32 { return &i }