From e656fb610cd097f6455ef5f4ea3116bb788a12a7 Mon Sep 17 00:00:00 2001 From: zer0sub Date: Mon, 20 May 2019 17:37:55 +0200 Subject: [PATCH] Query group scan refactoring. --- generator/metadata/column_info.go | 69 ++-- go-sqlbuilder.iml | 9 + sqlbuilder/execution/execution.go | 545 ++++++++++++------------ sqlbuilder/func_expression.go | 4 + sqlbuilder/numeric_expression.go | 5 + tests/main_test.go | 52 ++- tests/scan_test.go | 660 ++++++++++++++++++++++++++++++ tests/select_test.go | 319 ++++++++++----- tests/test_util.go | 8 +- 9 files changed, 1273 insertions(+), 398 deletions(-) create mode 100644 go-sqlbuilder.iml create mode 100644 tests/scan_test.go diff --git a/generator/metadata/column_info.go b/generator/metadata/column_info.go index 1a8463f..19522fd 100644 --- a/generator/metadata/column_info.go +++ b/generator/metadata/column_info.go @@ -54,7 +54,7 @@ func (c ColumnInfo) ToSqlBuilderColumnType() string { func (c ColumnInfo) ToGoType() string { typeStr := c.GoBaseType() - if c.IsNullable || c.TableInfo.IsForeignKey(c.Name) { + if c.IsNullable { return "*" + typeStr } @@ -62,47 +62,40 @@ func (c ColumnInfo) ToGoType() string { } func (c ColumnInfo) GoBaseType() string { - if forignKeyTable, ok := c.TableInfo.ForeignTableMap[c.Name]; ok { - return snaker.SnakeToCamel(forignKeyTable) - } else { - switch c.DataType { - case "USER-DEFINED": - return snaker.SnakeToCamel(c.EnumName) - case "boolean": - return "bool" - case "smallint": - return "int16" - case "integer": - return "int32" - case "bigint": - return "int64" - case "date", "timestamp without time zone", "timestamp with time zone": - return "time.Time" - case "bytea": - return "[]byte" - case "text", "character", "character varying": - return "string" - case "real": - return "float32" - case "numeric", "double precision": - return "float64" - case "uuid": - return "uuid.UUID" - case "json", "jsonb": - return "types.JSONText" - default: - fmt.Println("Unknown go map type: " + c.DataType + ", " + c.EnumName + ", using string instead.") - return "string" - } + switch c.DataType { + case "USER-DEFINED": + return snaker.SnakeToCamel(c.EnumName) + case "boolean": + return "bool" + case "smallint": + return "int16" + case "integer": + return "int32" + case "bigint": + return "int64" + case "date", "timestamp without time zone", "timestamp with time zone": + return "time.Time" + case "bytea": + return "[]byte" + case "text", "character", "character varying": + return "string" + case "real": + return "float32" + case "numeric", "double precision": + return "float64" + case "uuid": + return "uuid.UUID" + case "json", "jsonb": + return "types.JSONText" + default: + fmt.Println("Unknown go map type: " + c.DataType + ", " + c.EnumName + ", using string instead.") + return "string" } } func (c ColumnInfo) ToGoDMFieldName() string { - if forignKeyTable, ok := c.TableInfo.ForeignTableMap[c.Name]; ok { - return snaker.SnakeToCamel(forignKeyTable) - } else { - return snaker.SnakeToCamel(c.Name) - } + return snaker.SnakeToCamel(c.Name) + } func (c ColumnInfo) ToGoFieldName() string { diff --git a/go-sqlbuilder.iml b/go-sqlbuilder.iml new file mode 100644 index 0000000..8021953 --- /dev/null +++ b/go-sqlbuilder.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index 84880b9..d2c3a33 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -8,24 +8,64 @@ import ( "github.com/serenize/snaker" "github.com/sub0zero/go-sqlbuilder/types" "reflect" - "regexp" "strconv" "strings" "time" ) func Query(db types.Db, query string, args []interface{}, destinationPtr interface{}) error { + + if destinationPtr == nil { + return errors.New("Destination is nil. ") + } + + destinationPtrType := reflect.TypeOf(destinationPtr) + if destinationPtrType.Kind() != reflect.Ptr { + return errors.New("Destination has to be a pointer to slice or pointer to struct. ") + } + + if destinationPtrType.Elem().Kind() == reflect.Slice { + return queryToSlice(db, query, args, destinationPtr) + } else if destinationPtrType.Elem().Kind() == reflect.Struct { + tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) + tempSliceValue := tempSlicePtrValue.Elem() + + err := queryToSlice(db, query, args, tempSlicePtrValue.Interface()) + + if err != nil { + return err + } + + fmt.Println("TEMP SLICE SIZE: ", tempSliceValue.Len()) + + if tempSliceValue.Len() == 0 { + return nil + } + + structValue := reflect.ValueOf(destinationPtr).Elem() + firstTempStruct := tempSliceValue.Index(0).Elem() + + if structValue.Type().AssignableTo(firstTempStruct.Type()) { + structValue.Set(tempSliceValue.Index(0).Elem()) + } + return nil + } else { + return errors.New("Unsupported destination type. ") + } +} + +func queryToSlice(db types.Db, query string, args []interface{}, slicePtr interface{}) error { if db == nil { return errors.New("db is nil") } - if destinationPtr == nil { - return errors.New("Destination is nil ") + if slicePtr == nil { + return errors.New("Destination is nil. ") } - destinationType := reflect.TypeOf(destinationPtr) - if destinationType.Kind() != reflect.Ptr { - return errors.New("Destination has to be a pointer to slice or pointer to struct ") + destinationType := reflect.TypeOf(slicePtr) + if destinationType.Kind() != reflect.Ptr && destinationType.Elem().Kind() != reflect.Slice { + return errors.New("Destination has to be a pointer to slice. ") } rows, err := db.Query(query, args...) @@ -35,16 +75,17 @@ func Query(db types.Db, query string, args []interface{}, destinationPtr interfa } defer rows.Close() - columnNames, _ := rows.Columns() - columnTypes, _ := rows.ColumnTypes() + scanContext, err := newScanContext(rows) - scanContext := &scanContext{ - row: createScanValue(columnTypes), - columnNames: columnNames, - uniqueObjectsMap: make(map[string]interface{}), + if err != nil { + return err } - //spew.Dump(columnTypes) + if len(scanContext.row) == 0 { + return nil + } + + groupTime := time.Duration(0) for rows.Next() { err := rows.Scan(scanContext.row...) @@ -55,17 +96,19 @@ func Query(db types.Db, query string, args []interface{}, destinationPtr interfa scanContext.rowNum++ - if destinationType.Elem().Kind() == reflect.Slice { - err := mapRowToSlice(scanContext, "", map[string]bool{}, destinationPtr, nil) + begin := time.Now() - if err != nil { - return err - } - } else if destinationType.Elem().Kind() == reflect.Struct { - return mapRowToStruct(scanContext, "", map[string]bool{}, destinationPtr, nil) + _, err = mapRowToSlice(scanContext, "", reflect.ValueOf(slicePtr), nil) + + if err != nil { + return err } + + groupTime += time.Now().Sub(begin) } + fmt.Println(groupTime.String()) + err = rows.Err() if err != nil { @@ -82,68 +125,78 @@ func Query(db types.Db, query string, args []interface{}, destinationPtr interfa return nil } -type scanContext struct { - rowNum int - columnNames []string +func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { - row []interface{} - uniqueObjectsMap map[string]interface{} -} + sliceElemType := getSliceElemType(slicePtrValue) -func getColumnTypeName(columnName string) (string, error) { - split := strings.Split(columnName, ".") - if len(split) != 2 { - return "", errors.New("Invalid column name") + if isGoBaseType(sliceElemType) { + index := 0 + if structField != nil { + columnName := getRefTableNameFrom(structField) + index = getIndex(scanContext.columnNames, columnName) + + if index < 0 { + return + } + } + rowElemPtr := scanContext.rowElemPtr(index) + + if !rowElemPtr.IsNil() { + appendElemToSlice(slicePtrValue, rowElemPtr) + } + + return } - return split[0], nil -} + if sliceElemType.Kind() != reflect.Struct { + return false, errors.New("Unsupported dest type: " + structField.Name + " " + structField.Type.String()) + } -func allProcessed(arr []bool) bool { - for _, b := range arr { - if !b { - return false + structGroupKey := getGroupKey(scanContext, sliceElemType, structField) + + if structGroupKey == "" { + structGroupKey = "|ROW: " + strconv.Itoa(scanContext.rowNum) + "|" + } + + groupKey = groupKey + ":" + structGroupKey + + index, ok := scanContext.uniqueObjectsMap[groupKey] + + if ok { + structPtrValue := getSliceElemPtrAt(slicePtrValue, index) + + return mapRowToStruct(scanContext, groupKey, structPtrValue, structField) + } else { + destinationStructPtr := newElemPtrValueForSlice(slicePtrValue) + + updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, structField) + + if err != nil { + return + } + + if updated { + scanContext.uniqueObjectsMap[groupKey] = slicePtrValue.Elem().Len() + appendElemToSlice(slicePtrValue, destinationStructPtr) } } - return true + return } -func getType(reflectType reflect.Type) string { - var structType reflect.Type - if reflectType.Kind() == reflect.Struct { - structType = reflectType - } else if reflectType.Kind() == reflect.Ptr && reflectType.Elem().Kind() == reflect.Struct { - structType = reflectType.Elem() - } - - return structType.Name() -} - -func getGroupKey(scanContext *scanContext, typesProcessed map[string]bool, structType reflect.Type, structField *reflect.StructField) string { - tableName := getTableAlias(structField) - - //fmt.Println("Group: " + tableName) +func getGroupKey(scanContext *scanContext, structType reflect.Type, structField *reflect.StructField) string { + tableName := getRefTableNameFrom(structField) if tableName == "" { tableName = snaker.CamelToSnake(structType.Name()) } - //fmt.Println(tableName) - - if typesProcessed[tableName] { - return "" - } - - typesProcessed[tableName] = true - groupKeys := []string{} for i := 0; i < structType.NumField(); i++ { field := structType.Field(i) - ////fmt.Println(field.Tag) - if !isDbBaseType(field.Type) { + if !isGoBaseType(field.Type) { var structType reflect.Type if field.Type.Kind() == reflect.Struct { structType = field.Type @@ -153,11 +206,7 @@ func getGroupKey(scanContext *scanContext, typesProcessed map[string]bool, struc continue } - //spew.Dump(structType) - - structGroupKey := getGroupKey(scanContext, typesProcessed, structType, &field) - - //groupKey = strings.Join([]string{structGroupKey, groupKey}, ":") + structGroupKey := getGroupKey(scanContext, structType, &field) if structGroupKey != "" { groupKeys = append(groupKeys, structGroupKey) @@ -166,15 +215,14 @@ func getGroupKey(scanContext *scanContext, typesProcessed map[string]bool, struc fieldName := field.Name columnName := tableName + "." + snaker.CamelToSnake(fieldName) - //fmt.Println(fieldName) index := getIndex(scanContext.columnNames, columnName) if index < 0 { continue } - cellValue := cellValue(scanContext.row, index) - subKey := reflectValueToString(cellValue) + cellValue := scanContext.rowElem(index) + subKey := valueToString(cellValue) if subKey != "" { groupKeys = append(groupKeys, subKey) @@ -186,35 +234,13 @@ func getGroupKey(scanContext *scanContext, typesProcessed map[string]bool, struc return "" } - return "|" + structType.Name() + "(" + strings.Join(groupKeys, ", ") + ")|" + groupKey := "{" + structType.Name() + "(" + strings.Join(groupKeys, ",") + ")}" + + return groupKey } -func cellValue(row []interface{}, index int) interface{} { - //spew.Dump(row[index]) - - valuer, ok := row[index].(driver.Valuer) - - if !ok { - //fmt.Println("____________________") - //spew.Dump(row[index]) - panic("Scan value doesn't implement driver.Valuer") - } - - //spew.Dump(valuer) - - value, err := valuer.Value() - - if err != nil { - panic(err) - } - - //spew.Dump(value) - - return value -} - -func getSliceStructType(slicePtr interface{}) reflect.Type { - sliceTypePtr := reflect.TypeOf(slicePtr) +func getSliceElemType(slicePtrValue reflect.Value) reflect.Type { + sliceTypePtr := slicePtrValue.Type() elemType := sliceTypePtr.Elem().Elem() @@ -225,148 +251,101 @@ func getSliceStructType(slicePtr interface{}) reflect.Type { return elemType } -func cloneProcessedMap(processedMap map[string]bool) map[string]bool { - newMap := make(map[string]bool, len(processedMap)) +func getSliceElemPtrAt(slicePtrValue reflect.Value, index int) reflect.Value { + sliceValue := slicePtrValue.Elem() + elem := sliceValue.Index(index) - for k, v := range newMap { - newMap[k] = v + if elem.Kind() == reflect.Ptr { + return elem } - return newMap + return elem.Addr() } -func mapRowToSlice(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, destinationPtr interface{}, structField *reflect.StructField) error { - var err error +func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) { + if slicePtrValue.IsNil() { + panic("Slice is nil") + } + sliceValue := slicePtrValue.Elem() + sliceElemType := sliceValue.Type().Elem() - structType := getSliceStructType(destinationPtr) + newElemValue := objPtrValue - structGroupKey := getGroupKey(scanContext, cloneProcessedMap(typesProcessed), structType, structField) - - if structGroupKey == "" { - structGroupKey = "|ROW: " + strconv.Itoa(scanContext.rowNum) + "|" + if sliceElemType.Kind() != reflect.Ptr { + newElemValue = objPtrValue.Elem() } - groupKey = groupKey + ":" + structGroupKey - - //fmt.Println(groupKey) - - objPtr, ok := scanContext.uniqueObjectsMap[groupKey] - - if ok { - err = mapRowToStruct(scanContext, groupKey, typesProcessed, objPtr, structField) - if err != nil { - return err - } - } else { - destinationStructPtr := newElemForSlice(destinationPtr) - - err = mapRowToStruct(scanContext, groupKey, typesProcessed, destinationStructPtr, structField) - - if err != nil { - return err - } - - elemPtr := appendElemToSlice(destinationPtr, destinationStructPtr) - scanContext.uniqueObjectsMap[groupKey] = elemPtr + if newElemValue.Type().AssignableTo(sliceElemType) { + sliceValue.Set(reflect.Append(sliceValue, newElemValue)) } - - return err } -func appendElemToSlice(slice interface{}, objPtr interface{}) interface{} { - sliceValue := reflect.ValueOf(slice).Elem() - elemType := sliceValue.Type().Elem() - - if elemType.Kind() == reflect.Ptr { - sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(objPtr))) - return sliceValue.Index(sliceValue.Len() - 1).Interface() - } - - sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(objPtr).Elem())) - - return sliceValue.Index(sliceValue.Len() - 1).Addr().Interface() -} - -func newElemForSlice(destinationSlicePtr interface{}) interface{} { - destinationSliceType := reflect.TypeOf(destinationSlicePtr).Elem() +func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value { + destinationSliceType := slicePtrValue.Type().Elem() elemType := destinationSliceType.Elem() if elemType.Kind() == reflect.Ptr { - return reflect.New(elemType.Elem()).Interface() + return reflect.New(elemType.Elem()) } - return reflect.New(elemType).Interface() + return reflect.New(elemType) } -func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, dest reflect.Value, structField *reflect.StructField) error { - if dest.Kind() == reflect.Struct { - err := mapRowToStruct(scanContext, groupKey, typesProcessed, dest.Addr().Interface(), structField) - if err != nil { - return err - } - } else if dest.Kind() == reflect.Slice { - err := mapRowToSlice(scanContext, groupKey, typesProcessed, dest.Addr().Interface(), structField) - if err != nil { - return err - } +func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { + + if destPtrValue.Kind() != reflect.Ptr { + return false, errors.New("Internal error. ") + } + + destValueKind := destPtrValue.Elem().Kind() + + if destValueKind == reflect.Struct { + return mapRowToStruct(scanContext, groupKey, destPtrValue, structField) + } else if destValueKind == reflect.Slice { + return mapRowToSlice(scanContext, groupKey, destPtrValue, structField) + } else { + return false, errors.New("Unsupported dest type: " + structField.Name + " " + structField.Type.String()) + } +} + +func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest reflect.Value, structField *reflect.StructField) (updated bool, err error) { + + var destPtrValue reflect.Value + + if dest.Kind() != reflect.Ptr { + destPtrValue = dest.Addr() } else if dest.Kind() == reflect.Ptr { - elemType := dest.Type().Elem() - - if elemType.Kind() == reflect.Struct { - var structValuePtr reflect.Value - - if dest.IsNil() { - structValuePtr = reflect.New(elemType) - } else { - return nil - } - - err := mapRowToStruct(scanContext, groupKey, typesProcessed, structValuePtr.Interface(), structField) - if err != nil { - return err - } - - if structValuePtr.Elem().Interface() != reflect.New(elemType).Elem().Interface() { - dest.Set(structValuePtr) - } - - } else if elemType.Kind() == reflect.Slice { - var sliceValuePtr reflect.Value - - if dest.IsNil() { - sliceValuePtr = reflect.New(elemType) - } else { - sliceValuePtr = dest - } - - err := mapRowToSlice(scanContext, groupKey, typesProcessed, sliceValuePtr.Interface(), structField) - if err != nil { - return err - } - - if sliceValuePtr.Elem().Len() > 0 { - dest.Set(sliceValuePtr) - } - + if dest.IsNil() { + destPtrValue = reflect.New(dest.Type().Elem()) } else { - return errors.New("Unsuported field type: " + dest.Type().Name()) + destPtrValue = dest } } else { - return errors.New("Unsuported field type: " + dest.Type().Name()) + return false, errors.New("Internal error. ") } - return nil + updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField) + + if err != nil { + return + } + + if dest.Kind() == reflect.Ptr && dest.IsNil() && updated { + dest.Set(destPtrValue) + } + + return } -func getTableAlias(structField *reflect.StructField) string { +func getRefTableNameFrom(structField *reflect.StructField) string { if structField == nil { return "" } - re := regexp.MustCompile(`sqlbuilder:"(.*?)"`) - tagMatch := re.FindStringSubmatch(string(structField.Tag)) - if tagMatch != nil && len(tagMatch) == 2 && tagMatch[1] != "" { - return tagMatch[1] + tagOverwriteName := structField.Tag.Get("sqlbuilder") + + if tagOverwriteName != "" { + return tagOverwriteName } if !structField.Anonymous { @@ -398,33 +377,20 @@ func getTableAlias(structField *reflect.StructField) string { return snaker.CamelToSnake(elemType) } -func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, destinationPtr interface{}, structField *reflect.StructField) error { - structType := reflect.TypeOf(destinationPtr).Elem() - structValue := reflect.ValueOf(destinationPtr).Elem() +func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { + structType := structPtrValue.Type().Elem() + structValue := structPtrValue.Elem() - tableName := getTableAlias(structField) + tableName := getRefTableNameFrom(structField) if tableName == "" { tableName = snaker.CamelToSnake(structType.Name()) } - //fmt.Println("map -", tableName) - - if typesProcessed[tableName] { - //fmt.Println("Already processed") - return nil - } - - typesProcessed[tableName] = true - for i := 0; i < structType.NumField(); i++ { field := structType.Field(i) fieldValue := structValue.Field(i) - //fieldTypeName := field.Name - //fmt.Println("---------------", fieldTypeName,) - //spew.Dump(field.Type) - fieldName := field.Name if scannerValue, ok := implementsScanner(fieldValue); ok { @@ -434,38 +400,53 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma continue } - //spew.Dump(scannerValue.Interface()) - - if scannerValue.IsNil() { - initializePtrValue(scannerValue) - } + initializeValueIfNil(fieldValue) scanner := scannerValue.Interface().(sql.Scanner) - err := scanner.Scan(cellValue) + err = scanner.Scan(cellValue) if err != nil { - return err + return } - } else if !isDbBaseType(field.Type) { - //var fieldValueInterface interface{} - err := mapRowToDestinationValue(scanContext, groupKey, typesProcessed, fieldValue, &field) + updated = true + } else if !isGoBaseType(field.Type) { + var changed bool + changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field) if err != nil { - return err + return + } + + if changed { + updated = true } } else { cellValue := getCellValue(scanContext, tableName, fieldName) - //spew.Dump(cellValue) + //spew.Dump(rowElem) //spew.Dump(rowColumnValue, fieldValue) if cellValue != nil { + updated = true + initializeValueIfNil(fieldValue) setReflectValue(reflect.ValueOf(cellValue), fieldValue) } } } - return nil + return +} + +func initializeValueIfNil(value reflect.Value) { + if !value.IsValid() || !value.CanSet() { + return + } + + if value.Type().Kind() == reflect.Slice && value.IsNil() { + value.Set(reflect.New(value.Type()).Elem()) + } else if value.Kind() == reflect.Ptr && value.IsNil() { + value.Set(reflect.New(value.Type().Elem())) + } } func implementsScanner(value reflect.Value) (reflect.Value, bool) { @@ -480,12 +461,6 @@ func implementsScanner(value reflect.Value) (reflect.Value, bool) { return value, false } -func initializePtrValue(value reflect.Value) { - if value.Kind() == reflect.Ptr { - value.Set(reflect.New(value.Type().Elem())) - } -} - func getCellValue(scanContext *scanContext, tableName, fieldName string) interface{} { columnName := "" @@ -495,28 +470,22 @@ func getCellValue(scanContext *scanContext, tableName, fieldName string) interfa columnName = tableName + "." + snaker.CamelToSnake(fieldName) } - //columnName := snaker.CamelToSnake(fieldName) - - ////fmt.Println(columnName) index := getIndex(scanContext.columnNames, columnName) if index < 0 { return nil } - return cellValue(scanContext.row, index) + return scanContext.rowElem(index) } -func reflectValueToString(val interface{}) string { - //spew.Dump(val) - +func valueToString(val interface{}) string { if val == nil { return "" } value := reflect.ValueOf(val) - //if !value.IsValid() var valueInterface interface{} if value.Kind() == reflect.Ptr { valueInterface = value.Elem().Interface() @@ -536,10 +505,7 @@ var floatType = reflect.TypeOf(1.0) var stringType = reflect.TypeOf("str") var intType = reflect.TypeOf(1) -func isDbBaseType(objType reflect.Type) bool { - //isBaseType := objType == timeType || floatType == objType || stringType == objType || intType == objType - //isPtrToBaseType := objType.Kind() == reflect.Ptr && (objType.Elem() == timeType || floatType == objType.Elem() || - // stringType == objType.Elem() || intType == objType.Elem()) +func isGoBaseType(objType reflect.Type) bool { typeStr := objType.String() switch typeStr { @@ -548,7 +514,6 @@ func isDbBaseType(objType reflect.Type) bool { return true } - //return isBaseType || isPtrToBaseType return false } @@ -604,8 +569,6 @@ var nullBoolType = reflect.TypeOf(sql.NullBool{}) var nullTimeType = reflect.TypeOf(NullTime{}) func newScanType(columnType *sql.ColumnType) reflect.Type { - //spew.Dump(columnType) - //fmt.Println(columnType.DatabaseTypeName()) switch columnType.DatabaseTypeName() { case "INT2": return nullInt16Type @@ -627,3 +590,67 @@ func newScanType(columnType *sql.ColumnType) reflect.Type { panic("Unknown column database type " + columnType.DatabaseTypeName()) } } + +type scanContext struct { + rowNum int + columnNames []string + + row []interface{} + uniqueObjectsMap map[string]int + groupKeyMap map[string]string +} + +func newScanContext(rows *sql.Rows) (*scanContext, error) { + columnNames, err := rows.Columns() + + if err != nil { + return nil, err + } + + columnTypes, err := rows.ColumnTypes() + + if err != nil { + return nil, err + } + + return &scanContext{ + row: createScanValue(columnTypes), + columnNames: columnNames, + uniqueObjectsMap: make(map[string]int), + groupKeyMap: make(map[string]string), + }, nil +} + +func (s *scanContext) rowElem(index int) interface{} { + + valuer, ok := s.row[index].(driver.Valuer) + + if !ok { + panic("Scan value doesn't implement driver.Valuer") + } + + value, err := valuer.Value() + + if err != nil { + panic(err) + } + + return value +} + +func (s *scanContext) rowElemPtr(index int) reflect.Value { + rowElem := s.rowElem(index) + rowElemValue := reflect.ValueOf(rowElem) + + if rowElemValue.Kind() == reflect.Ptr { + return rowElemValue + } + + if rowElemValue.CanAddr() { + return rowElemValue.Addr() + } + + newElem := reflect.New(rowElemValue.Type()) + newElem.Elem().Set(rowElemValue) + return newElem +} diff --git a/sqlbuilder/func_expression.go b/sqlbuilder/func_expression.go index 5695294..a4cc0ea 100644 --- a/sqlbuilder/func_expression.go +++ b/sqlbuilder/func_expression.go @@ -58,6 +58,10 @@ func NewNumericFunc(name string, expressions ...expression) numericExpression { return numericFunc } +func COUNT(expression numericExpression) numericExpression { + return NewNumericFunc("COUNT", expression) +} + func MAX(expression numericExpression) numericExpression { return NewNumericFunc("MAX", expression) } diff --git a/sqlbuilder/numeric_expression.go b/sqlbuilder/numeric_expression.go index 05dc258..7829a89 100644 --- a/sqlbuilder/numeric_expression.go +++ b/sqlbuilder/numeric_expression.go @@ -14,6 +14,7 @@ type numericExpression interface { GtEq(rhs numericExpression) boolExpression GtEqL(literal interface{}) boolExpression + Lt(rhs numericExpression) boolExpression LtEq(rhs numericExpression) boolExpression LtEqL(literal interface{}) boolExpression @@ -55,6 +56,10 @@ func (n *numericInterfaceImpl) GtEqL(literal interface{}) boolExpression { return GtEq(n.parent, Literal(literal)) } +func (n *numericInterfaceImpl) Lt(expression numericExpression) boolExpression { + return Lt(n.parent, expression) +} + func (n *numericInterfaceImpl) LtEq(expression numericExpression) boolExpression { return LtEq(n.parent, expression) } diff --git a/tests/main_test.go b/tests/main_test.go index e4b6e62..a271353 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" _ "github.com/lib/pq" + "github.com/pkg/profile" "github.com/sub0zero/go-sqlbuilder/generator" "gotest.tools/assert" "os" @@ -31,6 +32,8 @@ var db *sql.DB func TestMain(m *testing.M) { fmt.Println("Begin") + defer profile.Start().Stop() + var err error db, err = sql.Open("postgres", connectString) if err != nil { @@ -66,7 +69,36 @@ CREATE TABLE IF NOT EXISTS test_sample.link ( name VARCHAR (255) NOT NULL, description VARCHAR (255), rel VARCHAR (50) -);` +); + +DROP TABLE IF EXISTS test_sample.employee; + +CREATE TABLE test_sample.employee ( + employee_id INT PRIMARY KEY, + first_name VARCHAR (255) NOT NULL, + last_name VARCHAR (255) NOT NULL, + manager_id INT, + FOREIGN KEY (manager_id) + REFERENCES test_sample.employee (employee_id) + ON DELETE CASCADE +); +INSERT INTO test_sample.employee ( + employee_id, + first_name, + last_name, + manager_id +) +VALUES + (1, 'Windy', 'Hays', NULL), + (2, 'Ava', 'Christensen', 1), + (3, 'Hassan', 'Conner', 1), + (4, 'Anna', 'Reeves', 2), + (5, 'Sau', 'Norman', 2), + (6, 'Kelsie', 'Hays', 3), + (7, 'Tory', 'Goff', 3), + (8, 'Salley', 'Lester', 3); + +` result, err := db.Exec(linkTableCreate) @@ -78,6 +110,24 @@ CREATE TABLE IF NOT EXISTS test_sample.link ( } +func queryAll(t *testing.T, query string, args []interface{}) { + rows, err := db.Query(query, args...) + + assert.NilError(t, err) + + defer rows.Close() + + for rows.Next() { + //err := rows.Scan(scanContext.row...) + // + //assert.NilError(t, err) + } + + err = rows.Err() + + assert.NilError(t, err) +} + func TestGenerateModel(t *testing.T) { err := generator.Generate(folderPath, connectString, dbname, schemaName) diff --git a/tests/scan_test.go b/tests/scan_test.go new file mode 100644 index 0000000..9e13949 --- /dev/null +++ b/tests/scan_test.go @@ -0,0 +1,660 @@ +package tests + +import ( + . "github.com/sub0zero/go-sqlbuilder/sqlbuilder" + "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model" + . "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table" + "gotest.tools/assert" + "testing" +) + +var query = Inventory. + SELECT(Inventory.AllColumns). + LIMIT(1). + ORDER_BY(Inventory.InventoryID) + +func TestScanToInvalidDestination(t *testing.T) { + + t.Run("nil dest", func(t *testing.T) { + err := query.Query(db, nil) + + assert.Error(t, err, "Destination is nil. ") + }) + + t.Run("struct dest", func(t *testing.T) { + err := query.Query(db, struct{}{}) + + assert.Error(t, err, "Destination has to be a pointer to slice or pointer to struct. ") + }) + + t.Run("slice dest", func(t *testing.T) { + err := query.Query(db, []struct{}{}) + + assert.Error(t, err, "Destination has to be a pointer to slice or pointer to struct. ") + }) + + t.Run("slice of pointers to pointer dest", func(t *testing.T) { + err := query.Query(db, []**struct{}{}) + + assert.Error(t, err, "Destination has to be a pointer to slice or pointer to struct. ") + }) + + t.Run("map dest", func(t *testing.T) { + err := query.Query(db, []map[string]string{}) + + assert.Error(t, err, "Destination has to be a pointer to slice or pointer to struct. ") + }) +} + +func TestScanToValidDestination(t *testing.T) { + t.Run("pointer to struct", func(t *testing.T) { + err := query.Query(db, &struct{}{}) + + assert.NilError(t, err) + }) + + t.Run("pointer to slice", func(t *testing.T) { + err := query.Query(db, &[]struct{}{}) + + assert.NilError(t, err) + }) + + t.Run("pointer to slice of pointer to structs", func(t *testing.T) { + err := query.Query(db, &[]*struct{}{}) + + assert.NilError(t, err) + }) + + t.Run("pointer to slice of strings", func(t *testing.T) { + err := query.Query(db, &[]string{}) + + assert.NilError(t, err) + }) +} + +func TestScanToStruct(t *testing.T) { + query := Inventory. + SELECT(Inventory.AllColumns). + ORDER_BY(Inventory.InventoryID) + + t.Run("one struct", func(t *testing.T) { + dest := model.Inventory{} + err := query.LIMIT(1).Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, inventory1, dest) + }) + + t.Run("multiple structs, just first one used", func(t *testing.T) { + dest := model.Inventory{} + err := query.LIMIT(10).Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, inventory1, dest) + }) + + t.Run("one struct", func(t *testing.T) { + dest := struct { + model.Inventory + }{} + err := query.LIMIT(1).Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, inventory1, dest.Inventory) + }) + + t.Run("one struct", func(t *testing.T) { + dest := struct { + *model.Inventory + }{} + err := query.LIMIT(1).Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, inventory1, *dest.Inventory) + }) + + t.Run("invalid dest", func(t *testing.T) { + dest := struct { + Inventory **model.Inventory + }{} + + err := query.Query(db, &dest) + + assert.Error(t, err, "Unsupported dest type: Inventory **model.Inventory") + }) + + t.Run("invalid dest 2", func(t *testing.T) { + dest := struct { + Inventory ***model.Inventory + }{} + + err := query.Query(db, &dest) + + assert.Error(t, err, "Unsupported dest type: Inventory ***model.Inventory") + }) + +} + +func TestScanToNestedStruct(t *testing.T) { + query := Inventory. + INNER_JOIN(Film, Inventory.FilmID.Eq(Film.FilmID)). + INNER_JOIN(Store, Inventory.StoreID.Eq(Store.StoreID)). + SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns). + WHERE(Inventory.InventoryID.EqL(1)) + + t.Run("embedded structs", func(t *testing.T) { + dest := struct { + model.Inventory + model.Film + model.Store + }{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, dest.Inventory, inventory1) + assert.DeepEqual(t, dest.Film, film1) + assert.DeepEqual(t, dest.Store, store1) + }) + + t.Run("embedded pointer structs", func(t *testing.T) { + dest := struct { + *model.Inventory + *model.Film + *model.Store + }{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, *dest.Inventory, inventory1) + assert.DeepEqual(t, *dest.Film, film1) + assert.DeepEqual(t, *dest.Store, store1) + }) + + t.Run("embedded unused structs", func(t *testing.T) { + dest := struct { + model.Inventory + model.Actor //unused + }{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, dest.Inventory, inventory1) + assert.DeepEqual(t, dest.Actor, model.Actor{}) + }) + + t.Run("embedded unused pointer structs", func(t *testing.T) { + dest := struct { + model.Inventory + *model.Actor //unused + }{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, dest.Inventory, inventory1) + assert.DeepEqual(t, dest.Actor, (*model.Actor)(nil)) + }) + + t.Run("embedded unused pointer structs", func(t *testing.T) { + dest := struct { + model.Inventory + Actor *model.Actor //unused + }{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, dest.Inventory, inventory1) + assert.DeepEqual(t, dest.Actor, (*model.Actor)(nil)) + }) + + t.Run("embedded pointer to selected column", func(t *testing.T) { + query := Inventory. + INNER_JOIN(Film, Inventory.FilmID.Eq(Film.FilmID)). + INNER_JOIN(Store, Inventory.StoreID.Eq(Store.StoreID)). + SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns, Literal("").AS("actor.first_name")). + WHERE(Inventory.InventoryID.EqL(1)) + + dest := struct { + model.Inventory + Actor *model.Actor //unused + }{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, dest.Inventory, inventory1) + assert.Assert(t, dest.Actor != nil) + }) + + t.Run("struct embedded unused pointer", func(t *testing.T) { + dest := struct { + model.Inventory + Actor *struct { + model.Actor + } //unused + }{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, dest.Inventory, inventory1) + assert.DeepEqual(t, dest.Actor, (*struct{ model.Actor })(nil)) + }) + + t.Run("multiple embedded unused pointer", func(t *testing.T) { + dest := struct { + model.Inventory + Actor *struct { + model.Actor //unused + model.Language //unesed + } + }{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, dest.Inventory, inventory1) + assert.DeepEqual(t, dest.Actor, (*struct { + model.Actor + model.Language + })(nil)) + }) + + t.Run("field not nil, embedded selected model", func(t *testing.T) { + dest := struct { + model.Inventory + Actor *struct { + model.Actor //unselected + model.Film //selected + } + }{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, dest.Inventory, inventory1) + assert.Assert(t, dest.Actor != nil) + assert.DeepEqual(t, dest.Actor.Actor, model.Actor{}) + assert.DeepEqual(t, dest.Actor.Film, film1) + }) + + t.Run("field not nil, deeply nested selected model", func(t *testing.T) { + dest := struct { + model.Inventory + Actor *struct { + model.Actor //unselected + Film *struct { + *model.Film //selected + } + } + }{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, dest.Inventory, inventory1) + assert.Assert(t, dest.Actor != nil) + assert.Assert(t, dest.Actor.Film != nil) + assert.DeepEqual(t, dest.Actor.Film.Film, &film1) + }) + + t.Run("embedded structs", func(t *testing.T) { + query := Inventory. + INNER_JOIN(Film, Inventory.FilmID.Eq(Film.FilmID)). + INNER_JOIN(Store, Inventory.StoreID.Eq(Store.StoreID)). + INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). + SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns, Language.AllColumns). + WHERE(Inventory.InventoryID.EqL(1)) + + dest := struct { + model.Inventory + Film struct { + model.Film + + Language model.Language + Language2 *model.Language + Language3 *model.Language `sqlbuilder:"language"` + Lang struct { + model.Language + } + Lang2 *struct { + model.Language + } + } + Store model.Store + }{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.DeepEqual(t, dest.Inventory, inventory1) + assert.DeepEqual(t, dest.Film.Film, film1) + assert.DeepEqual(t, dest.Store, store1) + assert.DeepEqual(t, dest.Film.Language, language1) + assert.DeepEqual(t, dest.Film.Lang.Language, language1) + assert.DeepEqual(t, dest.Film.Lang2.Language, language1) + assert.DeepEqual(t, dest.Film.Language2, (*model.Language)(nil)) + assert.DeepEqual(t, dest.Film.Language3, &language1) + }) +} + +func TestScanToSlice(t *testing.T) { + + t.Run("slice of structs", func(t *testing.T) { + query := Inventory. + SELECT(Inventory.AllColumns). + ORDER_BY(Inventory.InventoryID). + LIMIT(10) + + dest := []model.Inventory{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 10) + assert.DeepEqual(t, dest[0], inventory1) + assert.DeepEqual(t, dest[1], inventory2) + }) + + t.Run("slice of complex structs", func(t *testing.T) { + query := Inventory. + INNER_JOIN(Film, Inventory.FilmID.Eq(Film.FilmID)). + INNER_JOIN(Store, Inventory.StoreID.Eq(Store.StoreID)). + SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns). + ORDER_BY(Inventory.InventoryID). + LIMIT(10) + + t.Run("complex struct 1", func(t *testing.T) { + dest := []struct { + model.Inventory + model.Film + model.Store + }{} + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 10) + assert.DeepEqual(t, dest[0].Inventory, inventory1) + assert.DeepEqual(t, dest[0].Film, film1) + assert.DeepEqual(t, dest[0].Store, store1) + + assert.DeepEqual(t, dest[1].Inventory, inventory2) + }) + + t.Run("complex struct 2", func(t *testing.T) { + var dest []struct { + *model.Inventory + model.Film + *model.Store + } + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 10) + assert.DeepEqual(t, dest[0].Inventory, &inventory1) + assert.DeepEqual(t, dest[0].Film, film1) + assert.DeepEqual(t, dest[0].Store, &store1) + + assert.DeepEqual(t, dest[1].Inventory, &inventory2) + }) + + t.Run("complex struct 3", func(t *testing.T) { + var dest []struct { + Inventory model.Inventory + Film *model.Film + Store struct { + *model.Store + } + } + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 10) + assert.DeepEqual(t, dest[0].Inventory, inventory1) + assert.DeepEqual(t, dest[0].Film, &film1) + assert.DeepEqual(t, dest[0].Store.Store, &store1) + + assert.DeepEqual(t, dest[1].Inventory, inventory2) + }) + + t.Run("complex struct 4", func(t *testing.T) { + var dest []struct { + model.Film + + Inventories []struct { + model.Inventory + model.Store + } + } + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 2) + assert.DeepEqual(t, dest[0].Film, film1) + assert.DeepEqual(t, len(dest[0].Inventories), 8) + assert.DeepEqual(t, dest[0].Inventories[0].Inventory, inventory1) + assert.DeepEqual(t, dest[0].Inventories[0].Store, store1) + }) + + t.Run("complex struct 5", func(t *testing.T) { + var dest []struct { + model.Film + + Inventories []struct { + model.Inventory + + Rentals *[]model.Rental + Rentals2 []model.Rental + } + } + + err := query.Query(db, &dest) + + assert.NilError(t, err) + + assert.Equal(t, len(dest), 2) + assert.DeepEqual(t, dest[0].Film, film1) + assert.Equal(t, len(dest[0].Inventories), 8) + assert.DeepEqual(t, dest[0].Inventories[0].Inventory, inventory1) + assert.Assert(t, dest[0].Inventories[0].Rentals == nil) + assert.Assert(t, dest[0].Inventories[0].Rentals2 == nil) + }) + }) + + t.Run("slice of complex structs 2", func(t *testing.T) { + query := Country. + INNER_JOIN(City, City.CountryID.Eq(Country.CountryID)). + INNER_JOIN(Address, Address.CityID.Eq(City.CityID)). + INNER_JOIN(Customer, Customer.AddressID.Eq(Address.AddressID)). + SELECT(Country.AllColumns, City.AllColumns, Address.AllColumns, Customer.AllColumns). + ORDER_BY(Country.CountryID.ASC(), City.CityID.ASC(), Address.AddressID.ASC(), Customer.CustomerID.ASC()). + LIMIT(1000) + + t.Run("dest1", func(t *testing.T) { + var dest []struct { + model.Country + + Cities []struct { + model.City + + Adresses []struct { + model.Address + + Customer model.Customer + } + } + } + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 108) + assert.DeepEqual(t, dest[100].Country, countryUk) + assert.Equal(t, len(dest[100].Cities), 8) + assert.DeepEqual(t, dest[100].Cities[2].City, cityLondon) + assert.Equal(t, len(dest[100].Cities[2].Adresses), 2) + assert.DeepEqual(t, dest[100].Cities[2].Adresses[0].Address, address256) + assert.DeepEqual(t, dest[100].Cities[2].Adresses[0].Customer, customer256) + assert.DeepEqual(t, dest[100].Cities[2].Adresses[1].Address, addres517) + assert.DeepEqual(t, dest[100].Cities[2].Adresses[1].Customer, customer512) + }) + + t.Run("dest1", func(t *testing.T) { + var dest []*struct { + *model.Country + + Cities []*struct { + *model.City + + Adresses *[]*struct { + *model.Address + + Customer *model.Customer + } + } + } + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 108) + assert.DeepEqual(t, dest[100].Country, &countryUk) + assert.Equal(t, len(dest[100].Cities), 8) + assert.DeepEqual(t, dest[100].Cities[2].City, &cityLondon) + assert.Equal(t, len(*dest[100].Cities[2].Adresses), 2) + assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[0].Address, &address256) + assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[0].Customer, &customer256) + assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[1].Address, &addres517) + assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[1].Customer, &customer512) + }) + + }) + + t.Run("dest1", func(t *testing.T) { + var dest []*struct { + *model.Country + + Cities []**struct { + *model.City + } + } + + err := query.Query(db, &dest) + + assert.Error(t, err, "Unsupported dest type: Cities []**struct { *model.City }") + }) +} + +var address256 = model.Address{ + AddressID: 256, + Address: "1497 Yuzhou Drive", + Address2: stringPtr(""), + District: "England", + CityID: 312, + PostalCode: stringPtr("3433"), + Phone: "246810237916", + LastUpdate: *timeWithoutTimeZone("2006-02-15 09:45:30", 0), +} + +var addres517 = model.Address{ + AddressID: 517, + Address: "548 Uruapan Street", + Address2: stringPtr(""), + District: "Ontario", + CityID: 312, + PostalCode: stringPtr("35653"), + Phone: "879347453467", + LastUpdate: *timeWithoutTimeZone("2006-02-15 09:45:30", 0), +} + +var customer256 = model.Customer{ + CustomerID: 252, + StoreID: 2, + FirstName: "Mattie", + LastName: "Hoffman", + Email: stringPtr("mattie.hoffman@sakilacustomer.org"), + AddressID: 256, + Activebool: true, + CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 0), + Active: int32Ptr(1), +} + +var customer512 = model.Customer{ + CustomerID: 512, + StoreID: 1, + FirstName: "Cecil", + LastName: "Vines", + Email: stringPtr("cecil.vines@sakilacustomer.org"), + AddressID: 517, + Activebool: true, + CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 0), + Active: int32Ptr(1), +} + +var countryUk = model.Country{ + CountryID: 102, + Country: "United Kingdom", + LastUpdate: *timeWithoutTimeZone("2006-02-15 09:44:00", 0), +} + +var cityLondon = model.City{ + CityID: 312, + City: "London", + CountryID: 102, + LastUpdate: *timeWithoutTimeZone("2006-02-15 09:45:25", 0), +} + +var inventory1 = model.Inventory{ + InventoryID: 1, + FilmID: 1, + StoreID: 1, + LastUpdate: *timeWithoutTimeZone("2006-02-15 10:09:17", 0), +} + +var inventory2 = model.Inventory{ + InventoryID: 2, + FilmID: 1, + StoreID: 1, + LastUpdate: *timeWithoutTimeZone("2006-02-15 10:09:17", 0), +} + +var film1 = model.Film{ + FilmID: 1, + Title: "Academy Dinosaur", + Description: stringPtr("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"), + ReleaseYear: int32Ptr(2006), + LanguageID: 1, + RentalDuration: 6, + RentalRate: 0.99, + Length: int16Ptr(86), + ReplacementCost: 20.99, + Rating: &pgRating, + LastUpdate: *timeWithoutTimeZone("2013-05-26 14:50:58.951", 3), + SpecialFeatures: stringPtr("{\"Deleted Scenes\",\"Behind the Scenes\"}"), + Fulltext: "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", +} + +var store1 = model.Store{ + StoreID: 1, + ManagerStaffID: 1, + AddressID: 1, + LastUpdate: *timeWithoutTimeZone("2006-02-15 09:57:12", 0), +} + +var pgRating = model.MpaaRating_PG + +var language1 = model.Language{ + LanguageID: 1, + Name: "English ", + LastUpdate: *timeWithoutTimeZone("2006-02-15 10:02:19", 0), +} diff --git a/tests/select_test.go b/tests/select_test.go index db341c8..fd3ae37 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -5,6 +5,8 @@ import ( . "github.com/sub0zero/go-sqlbuilder/sqlbuilder" "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model" . "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table" + model2 "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/model" + . "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/table" "gotest.tools/assert" "testing" ) @@ -16,14 +18,12 @@ SELECT actor.actor_id AS "actor.actor_id", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor -WHERE actor.actor_id = 1 -ORDER BY actor.actor_id ASC; +WHERE actor.actor_id = 1; ` query := Actor. SELECT(Actor.AllColumns). - WHERE(Actor.ActorID.EqL(1)). - ORDER_BY(Actor.ActorID.ASC()) + WHERE(Actor.ActorID.EqL(1)) assertQuery(t, query, expectedSql, 1) @@ -79,8 +79,6 @@ LIMIT 30; assert.NilError(t, err) assert.Equal(t, len(dest), 30) - - //spew.Dump(dest) } func TestSelect_ScanToSlice(t *testing.T) { @@ -159,30 +157,99 @@ LIMIT 12; assertQuery(t, query, expectedSql, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) } -//func TestJoinQueryStruct(t *testing.T) { -// -// query := FilmActor. -// INNER_JOIN(Actor, FilmActor.ActorID.Eq(Actor.ActorID)). -// INNER_JOIN(Film, FilmActor.FilmID.Eq(Film.FilmID)). -// INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). -// SELECT(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns). -// WHERE(FilmActor.ActorID.GtEq(1).AND(FilmActor.ActorID.LteLiteral(2))) -// -// queryStr, args, err := query.Sql() -// assert.NilError(t, err) -// assert.Equal(t, queryStr, `SELECT film_actor.actor_id AS "film_actor.actor_id", film_actor.film_id AS "film_actor.film_id", film_actor.last_update AS "film_actor.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext",language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.film_actor JOIN dvds.actor ON film_actor.actor_id = actor.actor_id JOIN dvds.film ON film_actor.film_id = film.film_id JOIN dvds.language ON film.language_id = language.language_id WHERE (film_actor.actor_id>=1 AND film_actor.actor_id<=2)`) -// -// //fmt.Println(queryStr) -// -// filmActor := []model.FilmActor{} -// -// err = query.Execute(db, &filmActor) -// -// assert.NilError(t, err) -// -// //fmt.Println("ACTORS: --------------------") -// //spew.Dump(filmActor) -//} +func TestJoinQueryStruct(t *testing.T) { + + expectedSql := ` +SELECT film_actor.actor_id AS "film_actor.actor_id", + film_actor.film_id AS "film_actor.film_id", + film_actor.last_update AS "film_actor.last_update", + film.film_id AS "film.film_id", + film.title AS "film.title", + film.description AS "film.description", + film.release_year AS "film.release_year", + film.language_id AS "film.language_id", + film.rental_duration AS "film.rental_duration", + film.rental_rate AS "film.rental_rate", + film.length AS "film.length", + film.replacement_cost AS "film.replacement_cost", + film.rating AS "film.rating", + film.last_update AS "film.last_update", + film.special_features AS "film.special_features", + film.fulltext AS "film.fulltext", + language.language_id AS "language.language_id", + language.name AS "language.name", + language.last_update AS "language.last_update", + actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update", + inventory.inventory_id AS "inventory.inventory_id", + inventory.film_id AS "inventory.film_id", + inventory.store_id AS "inventory.store_id", + inventory.last_update AS "inventory.last_update", + rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM dvds.film_actor + JOIN dvds.actor ON film_actor.actor_id = actor.actor_id + JOIN dvds.film ON film_actor.film_id = film.film_id + JOIN dvds.language ON film.language_id = language.language_id + JOIN dvds.inventory ON inventory.film_id = film.film_id + JOIN dvds.rental ON rental.inventory_id = inventory.inventory_id +ORDER BY film.film_id ASC +LIMIT 50; +` + for i := 0; i < 1; i++ { + query := FilmActor. + INNER_JOIN(Actor, FilmActor.ActorID.Eq(Actor.ActorID)). + INNER_JOIN(Film, FilmActor.FilmID.Eq(Film.FilmID)). + INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). + INNER_JOIN(Inventory, Inventory.FilmID.Eq(Film.FilmID)). + INNER_JOIN(Rental, Rental.InventoryID.Eq(Inventory.InventoryID)). + SELECT( + FilmActor.AllColumns, + Film.AllColumns, + Language.AllColumns, + Actor.AllColumns, + Inventory.AllColumns, + Rental.AllColumns, + ). + //WHERE(FilmActor.ActorID.GtEqL(1).AND(FilmActor.ActorID.LtEqL(2))). + ORDER_BY(Film.FilmID.ASC()). + LIMIT(50) + + assertQuery(t, query, expectedSql, int64(50)) + + var languageActorFilm []struct { + model.Language + + Films []struct { + model.Film + Actors []struct { + model.Actor + } + + Inventory []struct { + model.Inventory + + Rental []model.Rental + } + } + } + + err := query.Query(db, &languageActorFilm) + + assert.NilError(t, err) + assert.Equal(t, len(languageActorFilm), 1) + assert.Equal(t, len(languageActorFilm[0].Films), 1) + assert.Equal(t, len(languageActorFilm[0].Films[0].Actors), 10) + } + +} func TestJoinQuerySlice(t *testing.T) { expectedSql := ` @@ -408,7 +475,10 @@ LIMIT 1000; assertQuery(t, query, expectedSql, int64(1000)) - customerAddresCrosJoined := []model.Customer{} + var customerAddresCrosJoined []struct { + model.Customer + model.Address + } err := query.Query(db, &customerAddresCrosJoined) @@ -417,6 +487,57 @@ LIMIT 1000; assert.NilError(t, err) } +func TestSelecSelfJoin1(t *testing.T) { + + var expectedSql = ` +SELECT employee.employee_id AS "employee.employee_id", + employee.first_name AS "employee.first_name", + employee.last_name AS "employee.last_name", + employee.manager_id AS "employee.manager_id", + manager.employee_id AS "manager.employee_id", + manager.first_name AS "manager.first_name", + manager.last_name AS "manager.last_name", + manager.manager_id AS "manager.manager_id" +FROM test_sample.employee + LEFT JOIN test_sample.employee AS manager ON manager.employee_id = employee.manager_id +ORDER BY employee.employee_id; +` + + manager := Employee.AS("manager") + query := Employee. + LEFT_JOIN(manager, manager.EmployeeID.Eq(Employee.ManagerID)). + SELECT(Employee.AllColumns, manager.AllColumns). + ORDER_BY(Employee.EmployeeID) + + assertQuery(t, query, expectedSql) + + var dest []struct { + model2.Employee + + Manager *model2.Employee + } + + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), 8) + assert.DeepEqual(t, dest[0].Employee, model2.Employee{ + EmployeeID: 1, + FirstName: "Windy", + LastName: "Hays", + ManagerID: nil, + }) + + assert.Assert(t, dest[0].Manager == nil) + + assert.DeepEqual(t, dest[7].Employee, model2.Employee{ + EmployeeID: 8, + FirstName: "Salley", + LastName: "Lester", + ManagerID: int32Ptr(3), + }) +} + func TestSelectSelfJoin(t *testing.T) { expectedSql := ` SELECT f1.film_id AS "f1.film_id", @@ -446,21 +567,19 @@ SELECT f1.film_id AS "f1.film_id", f2.special_features AS "f2.special_features", f2.fulltext AS "f2.fulltext" FROM dvds.film AS f1 - JOIN dvds.film AS f2 ON (f1.film_id != f2.film_id AND f1.length = f2.length) -ORDER BY f1.film_id ASC -LIMIT 100; + JOIN dvds.film AS f2 ON (f1.film_id < f2.film_id AND f1.length = f2.length) +ORDER BY f1.film_id ASC; ` f1 := Film.AS("f1") f2 := Film.AS("f2") query := f1. - INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).AND(f1.Length.Eq(f2.Length))). + INNER_JOIN(f2, f1.FilmID.Lt(f2.FilmID).AND(f1.Length.Eq(f2.Length))). SELECT(f1.AllColumns, f2.AllColumns). - ORDER_BY(f1.FilmID.ASC()). - LIMIT(100) + ORDER_BY(f1.FilmID.ASC()) - assertQuery(t, query, expectedSql, int64(100)) + assertQuery(t, query, expectedSql) type F1 model.Film type F2 model.Film @@ -474,7 +593,9 @@ LIMIT 100; assert.NilError(t, err) - assert.Equal(t, len(theSameLengthFilms), 100) + //spew.Dump(theSameLengthFilms) + + //assert.Equal(t, len(theSameLengthFilms), 100) } func TestSelectAliasColumn(t *testing.T) { @@ -517,61 +638,62 @@ LIMIT 1000; assert.DeepEqual(t, films[0], thesameLengthFilms{"Alien Center", "Iron Moon", 46}) } -type Manager staff - -type staff struct { - StaffID int32 `sql:"unique"` - FirstName string - LastName string - //Address *model.Address - //Email *string - //StoreID int16 - //Active bool - //Username string - //Password *string - //LastUpdate time.Time - *Manager //`sqlbuilder:"manager"` -} - -func TestSelectSelfReferenceType(t *testing.T) { - - expectedSql := ` -SELECT DISTINCT staff.staff_id AS "staff.staff_id", - staff.first_name AS "staff.first_name", - staff.last_name AS "staff.last_name", - address.address_id AS "address.address_id", - address.address AS "address.address", - address.address2 AS "address.address2", - address.district AS "address.district", - address.city_id AS "address.city_id", - address.postal_code AS "address.postal_code", - address.phone AS "address.phone", - address.last_update AS "address.last_update", - manager.staff_id AS "manager.staff_id", - manager.first_name AS "manager.first_name" -FROM dvds.staff - JOIN dvds.address ON staff.address_id = address.address_id - JOIN dvds.staff AS manager ON staff.staff_id = manager.staff_id; -` - manager := Staff.AS("manager") - - query := Staff. - INNER_JOIN(Address, Staff.AddressID.Eq(Address.AddressID)). - INNER_JOIN(manager, Staff.StaffID.Eq(manager.StaffID)). - SELECT(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName). - DISTINCT() - - assertQuery(t, query, expectedSql) - - staffs := []staff{} - - err := query.Query(db, &staffs) - - assert.NilError(t, err) - - fmt.Println(query.DebugSql()) - //spew.Dump(staffs) -} +// +//type Manager staff +// +//type staff struct { +// StaffID int32 `sql:"unique"` +// FirstName string +// LastName string +// //Address *model.Address +// //Email *string +// //StoreID int16 +// //Active bool +// //Username string +// //Password *string +// //LastUpdate time.Time +// *Manager //`sqlbuilder:"manager"` +//} +// +//func TestSelectSelfReferenceType(t *testing.T) { +// +// expectedSql := ` +//SELECT DISTINCT staff.staff_id AS "staff.staff_id", +// staff.first_name AS "staff.first_name", +// staff.last_name AS "staff.last_name", +// address.address_id AS "address.address_id", +// address.address AS "address.address", +// address.address2 AS "address.address2", +// address.district AS "address.district", +// address.city_id AS "address.city_id", +// address.postal_code AS "address.postal_code", +// address.phone AS "address.phone", +// address.last_update AS "address.last_update", +// manager.staff_id AS "manager.staff_id", +// manager.first_name AS "manager.first_name" +//FROM dvds.staff +// JOIN dvds.address ON staff.address_id = address.address_id +// JOIN dvds.staff AS manager ON staff.staff_id = manager.staff_id; +//` +// manager := Staff.AS("manager") +// +// query := Staff. +// INNER_JOIN(Address, Staff.AddressID.Eq(Address.AddressID)). +// INNER_JOIN(manager, Staff.StaffID.Eq(manager.StaffID)). +// SELECT(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName). +// DISTINCT() +// +// assertQuery(t, query, expectedSql) +// +// staffs := []staff{} +// +// err := query.Query(db, &staffs) +// +// assert.NilError(t, err) +// +// fmt.Println(query.DebugSql()) +// //spew.Dump(staffs) +//} func TestSubQuery(t *testing.T) { @@ -684,7 +806,8 @@ ORDER BY film.film_id ASC; maxFilmRentalRate := NumExp(Film.SELECT(MAX(Film.RentalRate))) - query := Film.SELECT(Film.AllColumns). + query := Film. + SELECT(Film.AllColumns). WHERE(Film.RentalRate.Eq(maxFilmRentalRate)). ORDER_BY(Film.FilmID.ASC()) @@ -705,7 +828,7 @@ ORDER BY film.film_id ASC; Title: "Ace Goldfinger", Description: stringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), ReleaseYear: int32Ptr(2006), - Language: nil, + LanguageID: 1, RentalRate: 4.99, Length: int16Ptr(48), ReplacementCost: 12.99, @@ -810,6 +933,7 @@ ORDER BY customer_payment_sum.amount_sum ASC; StoreID: 1, FirstName: "Brian", LastName: "Wyman", + AddressID: 323, Email: stringPtr("brian.wyman@sakilacustomer.org"), Activebool: true, CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), @@ -851,6 +975,9 @@ ORDER BY payment.payment_date ASC; assert.Equal(t, len(payments), 9) assert.DeepEqual(t, payments[0], model.Payment{ PaymentID: 17793, + CustomerID: 416, + StaffID: 2, + RentalID: 1158, Amount: 2.99, PaymentDate: *timeWithoutTimeZone("2007-02-14 21:21:59.996577", 6), }) diff --git a/tests/test_util.go b/tests/test_util.go index 2c445ec..f963de1 100644 --- a/tests/test_util.go +++ b/tests/test_util.go @@ -17,7 +17,7 @@ func assertQuery(t *testing.T, query sqlbuilder.Statement, expectedQuery string, debuqSql, err := query.DebugSql() assert.NilError(t, err) - assert.Equal(t, debuqSql, expectedQuery, args) + assert.Equal(t, debuqSql, expectedQuery) } func int16Ptr(i int16) *int16 { @@ -55,7 +55,7 @@ var customer0 = model.Customer{ FirstName: "Mary", LastName: "Smith", Email: stringPtr("mary.smith@sakilacustomer.org"), - Address: nil, + AddressID: 5, Activebool: true, CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), @@ -68,7 +68,7 @@ var customer1 = model.Customer{ FirstName: "Patricia", LastName: "Johnson", Email: stringPtr("patricia.johnson@sakilacustomer.org"), - Address: nil, + AddressID: 6, Activebool: true, CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), @@ -81,7 +81,7 @@ var lastCustomer = model.Customer{ FirstName: "Austin", LastName: "Cintron", Email: stringPtr("austin.cintron@sakilacustomer.org"), - Address: nil, + AddressID: 605, Activebool: true, CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3),