Query group scan refactoring.

This commit is contained in:
zer0sub 2019-05-20 17:37:55 +02:00
parent 5ed7cf2b1c
commit e656fb610c
9 changed files with 1273 additions and 398 deletions

View file

@ -54,7 +54,7 @@ func (c ColumnInfo) ToSqlBuilderColumnType() string {
func (c ColumnInfo) ToGoType() string { func (c ColumnInfo) ToGoType() string {
typeStr := c.GoBaseType() typeStr := c.GoBaseType()
if c.IsNullable || c.TableInfo.IsForeignKey(c.Name) { if c.IsNullable {
return "*" + typeStr return "*" + typeStr
} }
@ -62,47 +62,40 @@ func (c ColumnInfo) ToGoType() string {
} }
func (c ColumnInfo) GoBaseType() string { func (c ColumnInfo) GoBaseType() string {
if forignKeyTable, ok := c.TableInfo.ForeignTableMap[c.Name]; ok { switch c.DataType {
return snaker.SnakeToCamel(forignKeyTable) case "USER-DEFINED":
} else { return snaker.SnakeToCamel(c.EnumName)
switch c.DataType { case "boolean":
case "USER-DEFINED": return "bool"
return snaker.SnakeToCamel(c.EnumName) case "smallint":
case "boolean": return "int16"
return "bool" case "integer":
case "smallint": return "int32"
return "int16" case "bigint":
case "integer": return "int64"
return "int32" case "date", "timestamp without time zone", "timestamp with time zone":
case "bigint": return "time.Time"
return "int64" case "bytea":
case "date", "timestamp without time zone", "timestamp with time zone": return "[]byte"
return "time.Time" case "text", "character", "character varying":
case "bytea": return "string"
return "[]byte" case "real":
case "text", "character", "character varying": return "float32"
return "string" case "numeric", "double precision":
case "real": return "float64"
return "float32" case "uuid":
case "numeric", "double precision": return "uuid.UUID"
return "float64" case "json", "jsonb":
case "uuid": return "types.JSONText"
return "uuid.UUID" default:
case "json", "jsonb": fmt.Println("Unknown go map type: " + c.DataType + ", " + c.EnumName + ", using string instead.")
return "types.JSONText" return "string"
default:
fmt.Println("Unknown go map type: " + c.DataType + ", " + c.EnumName + ", using string instead.")
return "string"
}
} }
} }
func (c ColumnInfo) ToGoDMFieldName() string { func (c ColumnInfo) ToGoDMFieldName() string {
if forignKeyTable, ok := c.TableInfo.ForeignTableMap[c.Name]; ok { return snaker.SnakeToCamel(c.Name)
return snaker.SnakeToCamel(forignKeyTable)
} else {
return snaker.SnakeToCamel(c.Name)
}
} }
func (c ColumnInfo) ToGoFieldName() string { func (c ColumnInfo) ToGoFieldName() string {

9
go-sqlbuilder.iml Normal file
View file

@ -0,0 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output />
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View file

@ -8,24 +8,64 @@ import (
"github.com/serenize/snaker" "github.com/serenize/snaker"
"github.com/sub0zero/go-sqlbuilder/types" "github.com/sub0zero/go-sqlbuilder/types"
"reflect" "reflect"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
) )
func Query(db types.Db, query string, args []interface{}, destinationPtr interface{}) error { func Query(db types.Db, query string, args []interface{}, destinationPtr interface{}) error {
if destinationPtr == nil {
return errors.New("Destination is nil. ")
}
destinationPtrType := reflect.TypeOf(destinationPtr)
if destinationPtrType.Kind() != reflect.Ptr {
return errors.New("Destination has to be a pointer to slice or pointer to struct. ")
}
if destinationPtrType.Elem().Kind() == reflect.Slice {
return queryToSlice(db, query, args, destinationPtr)
} else if destinationPtrType.Elem().Kind() == reflect.Struct {
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem()
err := queryToSlice(db, query, args, tempSlicePtrValue.Interface())
if err != nil {
return err
}
fmt.Println("TEMP SLICE SIZE: ", tempSliceValue.Len())
if tempSliceValue.Len() == 0 {
return nil
}
structValue := reflect.ValueOf(destinationPtr).Elem()
firstTempStruct := tempSliceValue.Index(0).Elem()
if structValue.Type().AssignableTo(firstTempStruct.Type()) {
structValue.Set(tempSliceValue.Index(0).Elem())
}
return nil
} else {
return errors.New("Unsupported destination type. ")
}
}
func queryToSlice(db types.Db, query string, args []interface{}, slicePtr interface{}) error {
if db == nil { if db == nil {
return errors.New("db is nil") return errors.New("db is nil")
} }
if destinationPtr == nil { if slicePtr == nil {
return errors.New("Destination is nil ") return errors.New("Destination is nil. ")
} }
destinationType := reflect.TypeOf(destinationPtr) destinationType := reflect.TypeOf(slicePtr)
if destinationType.Kind() != reflect.Ptr { if destinationType.Kind() != reflect.Ptr && destinationType.Elem().Kind() != reflect.Slice {
return errors.New("Destination has to be a pointer to slice or pointer to struct ") return errors.New("Destination has to be a pointer to slice. ")
} }
rows, err := db.Query(query, args...) rows, err := db.Query(query, args...)
@ -35,16 +75,17 @@ func Query(db types.Db, query string, args []interface{}, destinationPtr interfa
} }
defer rows.Close() defer rows.Close()
columnNames, _ := rows.Columns() scanContext, err := newScanContext(rows)
columnTypes, _ := rows.ColumnTypes()
scanContext := &scanContext{ if err != nil {
row: createScanValue(columnTypes), return err
columnNames: columnNames,
uniqueObjectsMap: make(map[string]interface{}),
} }
//spew.Dump(columnTypes) if len(scanContext.row) == 0 {
return nil
}
groupTime := time.Duration(0)
for rows.Next() { for rows.Next() {
err := rows.Scan(scanContext.row...) err := rows.Scan(scanContext.row...)
@ -55,17 +96,19 @@ func Query(db types.Db, query string, args []interface{}, destinationPtr interfa
scanContext.rowNum++ scanContext.rowNum++
if destinationType.Elem().Kind() == reflect.Slice { begin := time.Now()
err := mapRowToSlice(scanContext, "", map[string]bool{}, destinationPtr, nil)
if err != nil { _, err = mapRowToSlice(scanContext, "", reflect.ValueOf(slicePtr), nil)
return err
} if err != nil {
} else if destinationType.Elem().Kind() == reflect.Struct { return err
return mapRowToStruct(scanContext, "", map[string]bool{}, destinationPtr, nil)
} }
groupTime += time.Now().Sub(begin)
} }
fmt.Println(groupTime.String())
err = rows.Err() err = rows.Err()
if err != nil { if err != nil {
@ -82,68 +125,78 @@ func Query(db types.Db, query string, args []interface{}, destinationPtr interfa
return nil return nil
} }
type scanContext struct { func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
rowNum int
columnNames []string
row []interface{} sliceElemType := getSliceElemType(slicePtrValue)
uniqueObjectsMap map[string]interface{}
}
func getColumnTypeName(columnName string) (string, error) { if isGoBaseType(sliceElemType) {
split := strings.Split(columnName, ".") index := 0
if len(split) != 2 { if structField != nil {
return "", errors.New("Invalid column name") columnName := getRefTableNameFrom(structField)
index = getIndex(scanContext.columnNames, columnName)
if index < 0 {
return
}
}
rowElemPtr := scanContext.rowElemPtr(index)
if !rowElemPtr.IsNil() {
appendElemToSlice(slicePtrValue, rowElemPtr)
}
return
} }
return split[0], nil if sliceElemType.Kind() != reflect.Struct {
} return false, errors.New("Unsupported dest type: " + structField.Name + " " + structField.Type.String())
}
func allProcessed(arr []bool) bool { structGroupKey := getGroupKey(scanContext, sliceElemType, structField)
for _, b := range arr {
if !b { if structGroupKey == "" {
return false structGroupKey = "|ROW: " + strconv.Itoa(scanContext.rowNum) + "|"
}
groupKey = groupKey + ":" + structGroupKey
index, ok := scanContext.uniqueObjectsMap[groupKey]
if ok {
structPtrValue := getSliceElemPtrAt(slicePtrValue, index)
return mapRowToStruct(scanContext, groupKey, structPtrValue, structField)
} else {
destinationStructPtr := newElemPtrValueForSlice(slicePtrValue)
updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, structField)
if err != nil {
return
}
if updated {
scanContext.uniqueObjectsMap[groupKey] = slicePtrValue.Elem().Len()
appendElemToSlice(slicePtrValue, destinationStructPtr)
} }
} }
return true return
} }
func getType(reflectType reflect.Type) string { func getGroupKey(scanContext *scanContext, structType reflect.Type, structField *reflect.StructField) string {
var structType reflect.Type tableName := getRefTableNameFrom(structField)
if reflectType.Kind() == reflect.Struct {
structType = reflectType
} else if reflectType.Kind() == reflect.Ptr && reflectType.Elem().Kind() == reflect.Struct {
structType = reflectType.Elem()
}
return structType.Name()
}
func getGroupKey(scanContext *scanContext, typesProcessed map[string]bool, structType reflect.Type, structField *reflect.StructField) string {
tableName := getTableAlias(structField)
//fmt.Println("Group: " + tableName)
if tableName == "" { if tableName == "" {
tableName = snaker.CamelToSnake(structType.Name()) tableName = snaker.CamelToSnake(structType.Name())
} }
//fmt.Println(tableName)
if typesProcessed[tableName] {
return ""
}
typesProcessed[tableName] = true
groupKeys := []string{} groupKeys := []string{}
for i := 0; i < structType.NumField(); i++ { for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i) field := structType.Field(i)
////fmt.Println(field.Tag) if !isGoBaseType(field.Type) {
if !isDbBaseType(field.Type) {
var structType reflect.Type var structType reflect.Type
if field.Type.Kind() == reflect.Struct { if field.Type.Kind() == reflect.Struct {
structType = field.Type structType = field.Type
@ -153,11 +206,7 @@ func getGroupKey(scanContext *scanContext, typesProcessed map[string]bool, struc
continue continue
} }
//spew.Dump(structType) structGroupKey := getGroupKey(scanContext, structType, &field)
structGroupKey := getGroupKey(scanContext, typesProcessed, structType, &field)
//groupKey = strings.Join([]string{structGroupKey, groupKey}, ":")
if structGroupKey != "" { if structGroupKey != "" {
groupKeys = append(groupKeys, structGroupKey) groupKeys = append(groupKeys, structGroupKey)
@ -166,15 +215,14 @@ func getGroupKey(scanContext *scanContext, typesProcessed map[string]bool, struc
fieldName := field.Name fieldName := field.Name
columnName := tableName + "." + snaker.CamelToSnake(fieldName) columnName := tableName + "." + snaker.CamelToSnake(fieldName)
//fmt.Println(fieldName)
index := getIndex(scanContext.columnNames, columnName) index := getIndex(scanContext.columnNames, columnName)
if index < 0 { if index < 0 {
continue continue
} }
cellValue := cellValue(scanContext.row, index) cellValue := scanContext.rowElem(index)
subKey := reflectValueToString(cellValue) subKey := valueToString(cellValue)
if subKey != "" { if subKey != "" {
groupKeys = append(groupKeys, subKey) groupKeys = append(groupKeys, subKey)
@ -186,35 +234,13 @@ func getGroupKey(scanContext *scanContext, typesProcessed map[string]bool, struc
return "" return ""
} }
return "|" + structType.Name() + "(" + strings.Join(groupKeys, ", ") + ")|" groupKey := "{" + structType.Name() + "(" + strings.Join(groupKeys, ",") + ")}"
return groupKey
} }
func cellValue(row []interface{}, index int) interface{} { func getSliceElemType(slicePtrValue reflect.Value) reflect.Type {
//spew.Dump(row[index]) sliceTypePtr := slicePtrValue.Type()
valuer, ok := row[index].(driver.Valuer)
if !ok {
//fmt.Println("____________________")
//spew.Dump(row[index])
panic("Scan value doesn't implement driver.Valuer")
}
//spew.Dump(valuer)
value, err := valuer.Value()
if err != nil {
panic(err)
}
//spew.Dump(value)
return value
}
func getSliceStructType(slicePtr interface{}) reflect.Type {
sliceTypePtr := reflect.TypeOf(slicePtr)
elemType := sliceTypePtr.Elem().Elem() elemType := sliceTypePtr.Elem().Elem()
@ -225,148 +251,101 @@ func getSliceStructType(slicePtr interface{}) reflect.Type {
return elemType return elemType
} }
func cloneProcessedMap(processedMap map[string]bool) map[string]bool { func getSliceElemPtrAt(slicePtrValue reflect.Value, index int) reflect.Value {
newMap := make(map[string]bool, len(processedMap)) sliceValue := slicePtrValue.Elem()
elem := sliceValue.Index(index)
for k, v := range newMap { if elem.Kind() == reflect.Ptr {
newMap[k] = v return elem
} }
return newMap return elem.Addr()
} }
func mapRowToSlice(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, destinationPtr interface{}, structField *reflect.StructField) error { func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) {
var err error if slicePtrValue.IsNil() {
panic("Slice is nil")
}
sliceValue := slicePtrValue.Elem()
sliceElemType := sliceValue.Type().Elem()
structType := getSliceStructType(destinationPtr) newElemValue := objPtrValue
structGroupKey := getGroupKey(scanContext, cloneProcessedMap(typesProcessed), structType, structField) if sliceElemType.Kind() != reflect.Ptr {
newElemValue = objPtrValue.Elem()
if structGroupKey == "" {
structGroupKey = "|ROW: " + strconv.Itoa(scanContext.rowNum) + "|"
} }
groupKey = groupKey + ":" + structGroupKey if newElemValue.Type().AssignableTo(sliceElemType) {
sliceValue.Set(reflect.Append(sliceValue, newElemValue))
//fmt.Println(groupKey)
objPtr, ok := scanContext.uniqueObjectsMap[groupKey]
if ok {
err = mapRowToStruct(scanContext, groupKey, typesProcessed, objPtr, structField)
if err != nil {
return err
}
} else {
destinationStructPtr := newElemForSlice(destinationPtr)
err = mapRowToStruct(scanContext, groupKey, typesProcessed, destinationStructPtr, structField)
if err != nil {
return err
}
elemPtr := appendElemToSlice(destinationPtr, destinationStructPtr)
scanContext.uniqueObjectsMap[groupKey] = elemPtr
} }
return err
} }
func appendElemToSlice(slice interface{}, objPtr interface{}) interface{} { func newElemPtrValueForSlice(slicePtrValue reflect.Value) reflect.Value {
sliceValue := reflect.ValueOf(slice).Elem() destinationSliceType := slicePtrValue.Type().Elem()
elemType := sliceValue.Type().Elem()
if elemType.Kind() == reflect.Ptr {
sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(objPtr)))
return sliceValue.Index(sliceValue.Len() - 1).Interface()
}
sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(objPtr).Elem()))
return sliceValue.Index(sliceValue.Len() - 1).Addr().Interface()
}
func newElemForSlice(destinationSlicePtr interface{}) interface{} {
destinationSliceType := reflect.TypeOf(destinationSlicePtr).Elem()
elemType := destinationSliceType.Elem() elemType := destinationSliceType.Elem()
if elemType.Kind() == reflect.Ptr { if elemType.Kind() == reflect.Ptr {
return reflect.New(elemType.Elem()).Interface() return reflect.New(elemType.Elem())
} }
return reflect.New(elemType).Interface() return reflect.New(elemType)
} }
func mapRowToDestinationValue(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, dest reflect.Value, structField *reflect.StructField) error { func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
if dest.Kind() == reflect.Struct {
err := mapRowToStruct(scanContext, groupKey, typesProcessed, dest.Addr().Interface(), structField) if destPtrValue.Kind() != reflect.Ptr {
if err != nil { return false, errors.New("Internal error. ")
return err }
}
} else if dest.Kind() == reflect.Slice { destValueKind := destPtrValue.Elem().Kind()
err := mapRowToSlice(scanContext, groupKey, typesProcessed, dest.Addr().Interface(), structField)
if err != nil { if destValueKind == reflect.Struct {
return err return mapRowToStruct(scanContext, groupKey, destPtrValue, structField)
} } else if destValueKind == reflect.Slice {
return mapRowToSlice(scanContext, groupKey, destPtrValue, structField)
} else {
return false, errors.New("Unsupported dest type: " + structField.Name + " " + structField.Type.String())
}
}
func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest reflect.Value, structField *reflect.StructField) (updated bool, err error) {
var destPtrValue reflect.Value
if dest.Kind() != reflect.Ptr {
destPtrValue = dest.Addr()
} else if dest.Kind() == reflect.Ptr { } else if dest.Kind() == reflect.Ptr {
elemType := dest.Type().Elem() if dest.IsNil() {
destPtrValue = reflect.New(dest.Type().Elem())
if elemType.Kind() == reflect.Struct {
var structValuePtr reflect.Value
if dest.IsNil() {
structValuePtr = reflect.New(elemType)
} else {
return nil
}
err := mapRowToStruct(scanContext, groupKey, typesProcessed, structValuePtr.Interface(), structField)
if err != nil {
return err
}
if structValuePtr.Elem().Interface() != reflect.New(elemType).Elem().Interface() {
dest.Set(structValuePtr)
}
} else if elemType.Kind() == reflect.Slice {
var sliceValuePtr reflect.Value
if dest.IsNil() {
sliceValuePtr = reflect.New(elemType)
} else {
sliceValuePtr = dest
}
err := mapRowToSlice(scanContext, groupKey, typesProcessed, sliceValuePtr.Interface(), structField)
if err != nil {
return err
}
if sliceValuePtr.Elem().Len() > 0 {
dest.Set(sliceValuePtr)
}
} else { } else {
return errors.New("Unsuported field type: " + dest.Type().Name()) destPtrValue = dest
} }
} else { } else {
return errors.New("Unsuported field type: " + dest.Type().Name()) return false, errors.New("Internal error. ")
} }
return nil updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField)
if err != nil {
return
}
if dest.Kind() == reflect.Ptr && dest.IsNil() && updated {
dest.Set(destPtrValue)
}
return
} }
func getTableAlias(structField *reflect.StructField) string { func getRefTableNameFrom(structField *reflect.StructField) string {
if structField == nil { if structField == nil {
return "" return ""
} }
re := regexp.MustCompile(`sqlbuilder:"(.*?)"`) tagOverwriteName := structField.Tag.Get("sqlbuilder")
tagMatch := re.FindStringSubmatch(string(structField.Tag))
if tagMatch != nil && len(tagMatch) == 2 && tagMatch[1] != "" { if tagOverwriteName != "" {
return tagMatch[1] return tagOverwriteName
} }
if !structField.Anonymous { if !structField.Anonymous {
@ -398,33 +377,20 @@ func getTableAlias(structField *reflect.StructField) string {
return snaker.CamelToSnake(elemType) return snaker.CamelToSnake(elemType)
} }
func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed map[string]bool, destinationPtr interface{}, structField *reflect.StructField) error { func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) {
structType := reflect.TypeOf(destinationPtr).Elem() structType := structPtrValue.Type().Elem()
structValue := reflect.ValueOf(destinationPtr).Elem() structValue := structPtrValue.Elem()
tableName := getTableAlias(structField) tableName := getRefTableNameFrom(structField)
if tableName == "" { if tableName == "" {
tableName = snaker.CamelToSnake(structType.Name()) tableName = snaker.CamelToSnake(structType.Name())
} }
//fmt.Println("map -", tableName)
if typesProcessed[tableName] {
//fmt.Println("Already processed")
return nil
}
typesProcessed[tableName] = true
for i := 0; i < structType.NumField(); i++ { for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i) field := structType.Field(i)
fieldValue := structValue.Field(i) fieldValue := structValue.Field(i)
//fieldTypeName := field.Name
//fmt.Println("---------------", fieldTypeName,)
//spew.Dump(field.Type)
fieldName := field.Name fieldName := field.Name
if scannerValue, ok := implementsScanner(fieldValue); ok { if scannerValue, ok := implementsScanner(fieldValue); ok {
@ -434,38 +400,53 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, typesProcessed ma
continue continue
} }
//spew.Dump(scannerValue.Interface()) initializeValueIfNil(fieldValue)
if scannerValue.IsNil() {
initializePtrValue(scannerValue)
}
scanner := scannerValue.Interface().(sql.Scanner) scanner := scannerValue.Interface().(sql.Scanner)
err := scanner.Scan(cellValue) err = scanner.Scan(cellValue)
if err != nil { if err != nil {
return err return
} }
} else if !isDbBaseType(field.Type) { updated = true
//var fieldValueInterface interface{} } else if !isGoBaseType(field.Type) {
err := mapRowToDestinationValue(scanContext, groupKey, typesProcessed, fieldValue, &field) var changed bool
changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field)
if err != nil { if err != nil {
return err return
}
if changed {
updated = true
} }
} else { } else {
cellValue := getCellValue(scanContext, tableName, fieldName) cellValue := getCellValue(scanContext, tableName, fieldName)
//spew.Dump(cellValue) //spew.Dump(rowElem)
//spew.Dump(rowColumnValue, fieldValue) //spew.Dump(rowColumnValue, fieldValue)
if cellValue != nil { if cellValue != nil {
updated = true
initializeValueIfNil(fieldValue)
setReflectValue(reflect.ValueOf(cellValue), fieldValue) setReflectValue(reflect.ValueOf(cellValue), fieldValue)
} }
} }
} }
return nil return
}
func initializeValueIfNil(value reflect.Value) {
if !value.IsValid() || !value.CanSet() {
return
}
if value.Type().Kind() == reflect.Slice && value.IsNil() {
value.Set(reflect.New(value.Type()).Elem())
} else if value.Kind() == reflect.Ptr && value.IsNil() {
value.Set(reflect.New(value.Type().Elem()))
}
} }
func implementsScanner(value reflect.Value) (reflect.Value, bool) { func implementsScanner(value reflect.Value) (reflect.Value, bool) {
@ -480,12 +461,6 @@ func implementsScanner(value reflect.Value) (reflect.Value, bool) {
return value, false return value, false
} }
func initializePtrValue(value reflect.Value) {
if value.Kind() == reflect.Ptr {
value.Set(reflect.New(value.Type().Elem()))
}
}
func getCellValue(scanContext *scanContext, tableName, fieldName string) interface{} { func getCellValue(scanContext *scanContext, tableName, fieldName string) interface{} {
columnName := "" columnName := ""
@ -495,28 +470,22 @@ func getCellValue(scanContext *scanContext, tableName, fieldName string) interfa
columnName = tableName + "." + snaker.CamelToSnake(fieldName) columnName = tableName + "." + snaker.CamelToSnake(fieldName)
} }
//columnName := snaker.CamelToSnake(fieldName)
////fmt.Println(columnName)
index := getIndex(scanContext.columnNames, columnName) index := getIndex(scanContext.columnNames, columnName)
if index < 0 { if index < 0 {
return nil return nil
} }
return cellValue(scanContext.row, index) return scanContext.rowElem(index)
} }
func reflectValueToString(val interface{}) string { func valueToString(val interface{}) string {
//spew.Dump(val)
if val == nil { if val == nil {
return "" return ""
} }
value := reflect.ValueOf(val) value := reflect.ValueOf(val)
//if !value.IsValid()
var valueInterface interface{} var valueInterface interface{}
if value.Kind() == reflect.Ptr { if value.Kind() == reflect.Ptr {
valueInterface = value.Elem().Interface() valueInterface = value.Elem().Interface()
@ -536,10 +505,7 @@ var floatType = reflect.TypeOf(1.0)
var stringType = reflect.TypeOf("str") var stringType = reflect.TypeOf("str")
var intType = reflect.TypeOf(1) var intType = reflect.TypeOf(1)
func isDbBaseType(objType reflect.Type) bool { func isGoBaseType(objType reflect.Type) bool {
//isBaseType := objType == timeType || floatType == objType || stringType == objType || intType == objType
//isPtrToBaseType := objType.Kind() == reflect.Ptr && (objType.Elem() == timeType || floatType == objType.Elem() ||
// stringType == objType.Elem() || intType == objType.Elem())
typeStr := objType.String() typeStr := objType.String()
switch typeStr { switch typeStr {
@ -548,7 +514,6 @@ func isDbBaseType(objType reflect.Type) bool {
return true return true
} }
//return isBaseType || isPtrToBaseType
return false return false
} }
@ -604,8 +569,6 @@ var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullTimeType = reflect.TypeOf(NullTime{}) var nullTimeType = reflect.TypeOf(NullTime{})
func newScanType(columnType *sql.ColumnType) reflect.Type { func newScanType(columnType *sql.ColumnType) reflect.Type {
//spew.Dump(columnType)
//fmt.Println(columnType.DatabaseTypeName())
switch columnType.DatabaseTypeName() { switch columnType.DatabaseTypeName() {
case "INT2": case "INT2":
return nullInt16Type return nullInt16Type
@ -627,3 +590,67 @@ func newScanType(columnType *sql.ColumnType) reflect.Type {
panic("Unknown column database type " + columnType.DatabaseTypeName()) panic("Unknown column database type " + columnType.DatabaseTypeName())
} }
} }
type scanContext struct {
rowNum int
columnNames []string
row []interface{}
uniqueObjectsMap map[string]int
groupKeyMap map[string]string
}
func newScanContext(rows *sql.Rows) (*scanContext, error) {
columnNames, err := rows.Columns()
if err != nil {
return nil, err
}
columnTypes, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
return &scanContext{
row: createScanValue(columnTypes),
columnNames: columnNames,
uniqueObjectsMap: make(map[string]int),
groupKeyMap: make(map[string]string),
}, nil
}
func (s *scanContext) rowElem(index int) interface{} {
valuer, ok := s.row[index].(driver.Valuer)
if !ok {
panic("Scan value doesn't implement driver.Valuer")
}
value, err := valuer.Value()
if err != nil {
panic(err)
}
return value
}
func (s *scanContext) rowElemPtr(index int) reflect.Value {
rowElem := s.rowElem(index)
rowElemValue := reflect.ValueOf(rowElem)
if rowElemValue.Kind() == reflect.Ptr {
return rowElemValue
}
if rowElemValue.CanAddr() {
return rowElemValue.Addr()
}
newElem := reflect.New(rowElemValue.Type())
newElem.Elem().Set(rowElemValue)
return newElem
}

View file

@ -58,6 +58,10 @@ func NewNumericFunc(name string, expressions ...expression) numericExpression {
return numericFunc return numericFunc
} }
func COUNT(expression numericExpression) numericExpression {
return NewNumericFunc("COUNT", expression)
}
func MAX(expression numericExpression) numericExpression { func MAX(expression numericExpression) numericExpression {
return NewNumericFunc("MAX", expression) return NewNumericFunc("MAX", expression)
} }

View file

@ -14,6 +14,7 @@ type numericExpression interface {
GtEq(rhs numericExpression) boolExpression GtEq(rhs numericExpression) boolExpression
GtEqL(literal interface{}) boolExpression GtEqL(literal interface{}) boolExpression
Lt(rhs numericExpression) boolExpression
LtEq(rhs numericExpression) boolExpression LtEq(rhs numericExpression) boolExpression
LtEqL(literal interface{}) boolExpression LtEqL(literal interface{}) boolExpression
@ -55,6 +56,10 @@ func (n *numericInterfaceImpl) GtEqL(literal interface{}) boolExpression {
return GtEq(n.parent, Literal(literal)) return GtEq(n.parent, Literal(literal))
} }
func (n *numericInterfaceImpl) Lt(expression numericExpression) boolExpression {
return Lt(n.parent, expression)
}
func (n *numericInterfaceImpl) LtEq(expression numericExpression) boolExpression { func (n *numericInterfaceImpl) LtEq(expression numericExpression) boolExpression {
return LtEq(n.parent, expression) return LtEq(n.parent, expression)
} }

View file

@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/pkg/profile"
"github.com/sub0zero/go-sqlbuilder/generator" "github.com/sub0zero/go-sqlbuilder/generator"
"gotest.tools/assert" "gotest.tools/assert"
"os" "os"
@ -31,6 +32,8 @@ var db *sql.DB
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
fmt.Println("Begin") fmt.Println("Begin")
defer profile.Start().Stop()
var err error var err error
db, err = sql.Open("postgres", connectString) db, err = sql.Open("postgres", connectString)
if err != nil { if err != nil {
@ -66,7 +69,36 @@ CREATE TABLE IF NOT EXISTS test_sample.link (
name VARCHAR (255) NOT NULL, name VARCHAR (255) NOT NULL,
description VARCHAR (255), description VARCHAR (255),
rel VARCHAR (50) rel VARCHAR (50)
);` );
DROP TABLE IF EXISTS test_sample.employee;
CREATE TABLE test_sample.employee (
employee_id INT PRIMARY KEY,
first_name VARCHAR (255) NOT NULL,
last_name VARCHAR (255) NOT NULL,
manager_id INT,
FOREIGN KEY (manager_id)
REFERENCES test_sample.employee (employee_id)
ON DELETE CASCADE
);
INSERT INTO test_sample.employee (
employee_id,
first_name,
last_name,
manager_id
)
VALUES
(1, 'Windy', 'Hays', NULL),
(2, 'Ava', 'Christensen', 1),
(3, 'Hassan', 'Conner', 1),
(4, 'Anna', 'Reeves', 2),
(5, 'Sau', 'Norman', 2),
(6, 'Kelsie', 'Hays', 3),
(7, 'Tory', 'Goff', 3),
(8, 'Salley', 'Lester', 3);
`
result, err := db.Exec(linkTableCreate) result, err := db.Exec(linkTableCreate)
@ -78,6 +110,24 @@ CREATE TABLE IF NOT EXISTS test_sample.link (
} }
func queryAll(t *testing.T, query string, args []interface{}) {
rows, err := db.Query(query, args...)
assert.NilError(t, err)
defer rows.Close()
for rows.Next() {
//err := rows.Scan(scanContext.row...)
//
//assert.NilError(t, err)
}
err = rows.Err()
assert.NilError(t, err)
}
func TestGenerateModel(t *testing.T) { func TestGenerateModel(t *testing.T) {
err := generator.Generate(folderPath, connectString, dbname, schemaName) err := generator.Generate(folderPath, connectString, dbname, schemaName)

660
tests/scan_test.go Normal file
View file

@ -0,0 +1,660 @@
package tests
import (
. "github.com/sub0zero/go-sqlbuilder/sqlbuilder"
"github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model"
. "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table"
"gotest.tools/assert"
"testing"
)
var query = Inventory.
SELECT(Inventory.AllColumns).
LIMIT(1).
ORDER_BY(Inventory.InventoryID)
func TestScanToInvalidDestination(t *testing.T) {
t.Run("nil dest", func(t *testing.T) {
err := query.Query(db, nil)
assert.Error(t, err, "Destination is nil. ")
})
t.Run("struct dest", func(t *testing.T) {
err := query.Query(db, struct{}{})
assert.Error(t, err, "Destination has to be a pointer to slice or pointer to struct. ")
})
t.Run("slice dest", func(t *testing.T) {
err := query.Query(db, []struct{}{})
assert.Error(t, err, "Destination has to be a pointer to slice or pointer to struct. ")
})
t.Run("slice of pointers to pointer dest", func(t *testing.T) {
err := query.Query(db, []**struct{}{})
assert.Error(t, err, "Destination has to be a pointer to slice or pointer to struct. ")
})
t.Run("map dest", func(t *testing.T) {
err := query.Query(db, []map[string]string{})
assert.Error(t, err, "Destination has to be a pointer to slice or pointer to struct. ")
})
}
func TestScanToValidDestination(t *testing.T) {
t.Run("pointer to struct", func(t *testing.T) {
err := query.Query(db, &struct{}{})
assert.NilError(t, err)
})
t.Run("pointer to slice", func(t *testing.T) {
err := query.Query(db, &[]struct{}{})
assert.NilError(t, err)
})
t.Run("pointer to slice of pointer to structs", func(t *testing.T) {
err := query.Query(db, &[]*struct{}{})
assert.NilError(t, err)
})
t.Run("pointer to slice of strings", func(t *testing.T) {
err := query.Query(db, &[]string{})
assert.NilError(t, err)
})
}
func TestScanToStruct(t *testing.T) {
query := Inventory.
SELECT(Inventory.AllColumns).
ORDER_BY(Inventory.InventoryID)
t.Run("one struct", func(t *testing.T) {
dest := model.Inventory{}
err := query.LIMIT(1).Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, inventory1, dest)
})
t.Run("multiple structs, just first one used", func(t *testing.T) {
dest := model.Inventory{}
err := query.LIMIT(10).Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, inventory1, dest)
})
t.Run("one struct", func(t *testing.T) {
dest := struct {
model.Inventory
}{}
err := query.LIMIT(1).Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, inventory1, dest.Inventory)
})
t.Run("one struct", func(t *testing.T) {
dest := struct {
*model.Inventory
}{}
err := query.LIMIT(1).Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, inventory1, *dest.Inventory)
})
t.Run("invalid dest", func(t *testing.T) {
dest := struct {
Inventory **model.Inventory
}{}
err := query.Query(db, &dest)
assert.Error(t, err, "Unsupported dest type: Inventory **model.Inventory")
})
t.Run("invalid dest 2", func(t *testing.T) {
dest := struct {
Inventory ***model.Inventory
}{}
err := query.Query(db, &dest)
assert.Error(t, err, "Unsupported dest type: Inventory ***model.Inventory")
})
}
func TestScanToNestedStruct(t *testing.T) {
query := Inventory.
INNER_JOIN(Film, Inventory.FilmID.Eq(Film.FilmID)).
INNER_JOIN(Store, Inventory.StoreID.Eq(Store.StoreID)).
SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns).
WHERE(Inventory.InventoryID.EqL(1))
t.Run("embedded structs", func(t *testing.T) {
dest := struct {
model.Inventory
model.Film
model.Store
}{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1)
assert.DeepEqual(t, dest.Film, film1)
assert.DeepEqual(t, dest.Store, store1)
})
t.Run("embedded pointer structs", func(t *testing.T) {
dest := struct {
*model.Inventory
*model.Film
*model.Store
}{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, *dest.Inventory, inventory1)
assert.DeepEqual(t, *dest.Film, film1)
assert.DeepEqual(t, *dest.Store, store1)
})
t.Run("embedded unused structs", func(t *testing.T) {
dest := struct {
model.Inventory
model.Actor //unused
}{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1)
assert.DeepEqual(t, dest.Actor, model.Actor{})
})
t.Run("embedded unused pointer structs", func(t *testing.T) {
dest := struct {
model.Inventory
*model.Actor //unused
}{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1)
assert.DeepEqual(t, dest.Actor, (*model.Actor)(nil))
})
t.Run("embedded unused pointer structs", func(t *testing.T) {
dest := struct {
model.Inventory
Actor *model.Actor //unused
}{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1)
assert.DeepEqual(t, dest.Actor, (*model.Actor)(nil))
})
t.Run("embedded pointer to selected column", func(t *testing.T) {
query := Inventory.
INNER_JOIN(Film, Inventory.FilmID.Eq(Film.FilmID)).
INNER_JOIN(Store, Inventory.StoreID.Eq(Store.StoreID)).
SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns, Literal("").AS("actor.first_name")).
WHERE(Inventory.InventoryID.EqL(1))
dest := struct {
model.Inventory
Actor *model.Actor //unused
}{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1)
assert.Assert(t, dest.Actor != nil)
})
t.Run("struct embedded unused pointer", func(t *testing.T) {
dest := struct {
model.Inventory
Actor *struct {
model.Actor
} //unused
}{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1)
assert.DeepEqual(t, dest.Actor, (*struct{ model.Actor })(nil))
})
t.Run("multiple embedded unused pointer", func(t *testing.T) {
dest := struct {
model.Inventory
Actor *struct {
model.Actor //unused
model.Language //unesed
}
}{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1)
assert.DeepEqual(t, dest.Actor, (*struct {
model.Actor
model.Language
})(nil))
})
t.Run("field not nil, embedded selected model", func(t *testing.T) {
dest := struct {
model.Inventory
Actor *struct {
model.Actor //unselected
model.Film //selected
}
}{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1)
assert.Assert(t, dest.Actor != nil)
assert.DeepEqual(t, dest.Actor.Actor, model.Actor{})
assert.DeepEqual(t, dest.Actor.Film, film1)
})
t.Run("field not nil, deeply nested selected model", func(t *testing.T) {
dest := struct {
model.Inventory
Actor *struct {
model.Actor //unselected
Film *struct {
*model.Film //selected
}
}
}{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1)
assert.Assert(t, dest.Actor != nil)
assert.Assert(t, dest.Actor.Film != nil)
assert.DeepEqual(t, dest.Actor.Film.Film, &film1)
})
t.Run("embedded structs", func(t *testing.T) {
query := Inventory.
INNER_JOIN(Film, Inventory.FilmID.Eq(Film.FilmID)).
INNER_JOIN(Store, Inventory.StoreID.Eq(Store.StoreID)).
INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns, Language.AllColumns).
WHERE(Inventory.InventoryID.EqL(1))
dest := struct {
model.Inventory
Film struct {
model.Film
Language model.Language
Language2 *model.Language
Language3 *model.Language `sqlbuilder:"language"`
Lang struct {
model.Language
}
Lang2 *struct {
model.Language
}
}
Store model.Store
}{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.DeepEqual(t, dest.Inventory, inventory1)
assert.DeepEqual(t, dest.Film.Film, film1)
assert.DeepEqual(t, dest.Store, store1)
assert.DeepEqual(t, dest.Film.Language, language1)
assert.DeepEqual(t, dest.Film.Lang.Language, language1)
assert.DeepEqual(t, dest.Film.Lang2.Language, language1)
assert.DeepEqual(t, dest.Film.Language2, (*model.Language)(nil))
assert.DeepEqual(t, dest.Film.Language3, &language1)
})
}
func TestScanToSlice(t *testing.T) {
t.Run("slice of structs", func(t *testing.T) {
query := Inventory.
SELECT(Inventory.AllColumns).
ORDER_BY(Inventory.InventoryID).
LIMIT(10)
dest := []model.Inventory{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 10)
assert.DeepEqual(t, dest[0], inventory1)
assert.DeepEqual(t, dest[1], inventory2)
})
t.Run("slice of complex structs", func(t *testing.T) {
query := Inventory.
INNER_JOIN(Film, Inventory.FilmID.Eq(Film.FilmID)).
INNER_JOIN(Store, Inventory.StoreID.Eq(Store.StoreID)).
SELECT(Inventory.AllColumns, Film.AllColumns, Store.AllColumns).
ORDER_BY(Inventory.InventoryID).
LIMIT(10)
t.Run("complex struct 1", func(t *testing.T) {
dest := []struct {
model.Inventory
model.Film
model.Store
}{}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 10)
assert.DeepEqual(t, dest[0].Inventory, inventory1)
assert.DeepEqual(t, dest[0].Film, film1)
assert.DeepEqual(t, dest[0].Store, store1)
assert.DeepEqual(t, dest[1].Inventory, inventory2)
})
t.Run("complex struct 2", func(t *testing.T) {
var dest []struct {
*model.Inventory
model.Film
*model.Store
}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 10)
assert.DeepEqual(t, dest[0].Inventory, &inventory1)
assert.DeepEqual(t, dest[0].Film, film1)
assert.DeepEqual(t, dest[0].Store, &store1)
assert.DeepEqual(t, dest[1].Inventory, &inventory2)
})
t.Run("complex struct 3", func(t *testing.T) {
var dest []struct {
Inventory model.Inventory
Film *model.Film
Store struct {
*model.Store
}
}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 10)
assert.DeepEqual(t, dest[0].Inventory, inventory1)
assert.DeepEqual(t, dest[0].Film, &film1)
assert.DeepEqual(t, dest[0].Store.Store, &store1)
assert.DeepEqual(t, dest[1].Inventory, inventory2)
})
t.Run("complex struct 4", func(t *testing.T) {
var dest []struct {
model.Film
Inventories []struct {
model.Inventory
model.Store
}
}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 2)
assert.DeepEqual(t, dest[0].Film, film1)
assert.DeepEqual(t, len(dest[0].Inventories), 8)
assert.DeepEqual(t, dest[0].Inventories[0].Inventory, inventory1)
assert.DeepEqual(t, dest[0].Inventories[0].Store, store1)
})
t.Run("complex struct 5", func(t *testing.T) {
var dest []struct {
model.Film
Inventories []struct {
model.Inventory
Rentals *[]model.Rental
Rentals2 []model.Rental
}
}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 2)
assert.DeepEqual(t, dest[0].Film, film1)
assert.Equal(t, len(dest[0].Inventories), 8)
assert.DeepEqual(t, dest[0].Inventories[0].Inventory, inventory1)
assert.Assert(t, dest[0].Inventories[0].Rentals == nil)
assert.Assert(t, dest[0].Inventories[0].Rentals2 == nil)
})
})
t.Run("slice of complex structs 2", func(t *testing.T) {
query := Country.
INNER_JOIN(City, City.CountryID.Eq(Country.CountryID)).
INNER_JOIN(Address, Address.CityID.Eq(City.CityID)).
INNER_JOIN(Customer, Customer.AddressID.Eq(Address.AddressID)).
SELECT(Country.AllColumns, City.AllColumns, Address.AllColumns, Customer.AllColumns).
ORDER_BY(Country.CountryID.ASC(), City.CityID.ASC(), Address.AddressID.ASC(), Customer.CustomerID.ASC()).
LIMIT(1000)
t.Run("dest1", func(t *testing.T) {
var dest []struct {
model.Country
Cities []struct {
model.City
Adresses []struct {
model.Address
Customer model.Customer
}
}
}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 108)
assert.DeepEqual(t, dest[100].Country, countryUk)
assert.Equal(t, len(dest[100].Cities), 8)
assert.DeepEqual(t, dest[100].Cities[2].City, cityLondon)
assert.Equal(t, len(dest[100].Cities[2].Adresses), 2)
assert.DeepEqual(t, dest[100].Cities[2].Adresses[0].Address, address256)
assert.DeepEqual(t, dest[100].Cities[2].Adresses[0].Customer, customer256)
assert.DeepEqual(t, dest[100].Cities[2].Adresses[1].Address, addres517)
assert.DeepEqual(t, dest[100].Cities[2].Adresses[1].Customer, customer512)
})
t.Run("dest1", func(t *testing.T) {
var dest []*struct {
*model.Country
Cities []*struct {
*model.City
Adresses *[]*struct {
*model.Address
Customer *model.Customer
}
}
}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 108)
assert.DeepEqual(t, dest[100].Country, &countryUk)
assert.Equal(t, len(dest[100].Cities), 8)
assert.DeepEqual(t, dest[100].Cities[2].City, &cityLondon)
assert.Equal(t, len(*dest[100].Cities[2].Adresses), 2)
assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[0].Address, &address256)
assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[0].Customer, &customer256)
assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[1].Address, &addres517)
assert.DeepEqual(t, (*dest[100].Cities[2].Adresses)[1].Customer, &customer512)
})
})
t.Run("dest1", func(t *testing.T) {
var dest []*struct {
*model.Country
Cities []**struct {
*model.City
}
}
err := query.Query(db, &dest)
assert.Error(t, err, "Unsupported dest type: Cities []**struct { *model.City }")
})
}
var address256 = model.Address{
AddressID: 256,
Address: "1497 Yuzhou Drive",
Address2: stringPtr(""),
District: "England",
CityID: 312,
PostalCode: stringPtr("3433"),
Phone: "246810237916",
LastUpdate: *timeWithoutTimeZone("2006-02-15 09:45:30", 0),
}
var addres517 = model.Address{
AddressID: 517,
Address: "548 Uruapan Street",
Address2: stringPtr(""),
District: "Ontario",
CityID: 312,
PostalCode: stringPtr("35653"),
Phone: "879347453467",
LastUpdate: *timeWithoutTimeZone("2006-02-15 09:45:30", 0),
}
var customer256 = model.Customer{
CustomerID: 252,
StoreID: 2,
FirstName: "Mattie",
LastName: "Hoffman",
Email: stringPtr("mattie.hoffman@sakilacustomer.org"),
AddressID: 256,
Activebool: true,
CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 0),
Active: int32Ptr(1),
}
var customer512 = model.Customer{
CustomerID: 512,
StoreID: 1,
FirstName: "Cecil",
LastName: "Vines",
Email: stringPtr("cecil.vines@sakilacustomer.org"),
AddressID: 517,
Activebool: true,
CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 0),
Active: int32Ptr(1),
}
var countryUk = model.Country{
CountryID: 102,
Country: "United Kingdom",
LastUpdate: *timeWithoutTimeZone("2006-02-15 09:44:00", 0),
}
var cityLondon = model.City{
CityID: 312,
City: "London",
CountryID: 102,
LastUpdate: *timeWithoutTimeZone("2006-02-15 09:45:25", 0),
}
var inventory1 = model.Inventory{
InventoryID: 1,
FilmID: 1,
StoreID: 1,
LastUpdate: *timeWithoutTimeZone("2006-02-15 10:09:17", 0),
}
var inventory2 = model.Inventory{
InventoryID: 2,
FilmID: 1,
StoreID: 1,
LastUpdate: *timeWithoutTimeZone("2006-02-15 10:09:17", 0),
}
var film1 = model.Film{
FilmID: 1,
Title: "Academy Dinosaur",
Description: stringPtr("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"),
ReleaseYear: int32Ptr(2006),
LanguageID: 1,
RentalDuration: 6,
RentalRate: 0.99,
Length: int16Ptr(86),
ReplacementCost: 20.99,
Rating: &pgRating,
LastUpdate: *timeWithoutTimeZone("2013-05-26 14:50:58.951", 3),
SpecialFeatures: stringPtr("{\"Deleted Scenes\",\"Behind the Scenes\"}"),
Fulltext: "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17",
}
var store1 = model.Store{
StoreID: 1,
ManagerStaffID: 1,
AddressID: 1,
LastUpdate: *timeWithoutTimeZone("2006-02-15 09:57:12", 0),
}
var pgRating = model.MpaaRating_PG
var language1 = model.Language{
LanguageID: 1,
Name: "English ",
LastUpdate: *timeWithoutTimeZone("2006-02-15 10:02:19", 0),
}

View file

@ -5,6 +5,8 @@ import (
. "github.com/sub0zero/go-sqlbuilder/sqlbuilder" . "github.com/sub0zero/go-sqlbuilder/sqlbuilder"
"github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model" "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model"
. "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table" . "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table"
model2 "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/model"
. "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/table"
"gotest.tools/assert" "gotest.tools/assert"
"testing" "testing"
) )
@ -16,14 +18,12 @@ SELECT actor.actor_id AS "actor.actor_id",
actor.last_name AS "actor.last_name", actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update" actor.last_update AS "actor.last_update"
FROM dvds.actor FROM dvds.actor
WHERE actor.actor_id = 1 WHERE actor.actor_id = 1;
ORDER BY actor.actor_id ASC;
` `
query := Actor. query := Actor.
SELECT(Actor.AllColumns). SELECT(Actor.AllColumns).
WHERE(Actor.ActorID.EqL(1)). WHERE(Actor.ActorID.EqL(1))
ORDER_BY(Actor.ActorID.ASC())
assertQuery(t, query, expectedSql, 1) assertQuery(t, query, expectedSql, 1)
@ -79,8 +79,6 @@ LIMIT 30;
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(dest), 30) assert.Equal(t, len(dest), 30)
//spew.Dump(dest)
} }
func TestSelect_ScanToSlice(t *testing.T) { func TestSelect_ScanToSlice(t *testing.T) {
@ -159,30 +157,99 @@ LIMIT 12;
assertQuery(t, query, expectedSql, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) assertQuery(t, query, expectedSql, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12))
} }
//func TestJoinQueryStruct(t *testing.T) { func TestJoinQueryStruct(t *testing.T) {
//
// query := FilmActor. expectedSql := `
// INNER_JOIN(Actor, FilmActor.ActorID.Eq(Actor.ActorID)). SELECT film_actor.actor_id AS "film_actor.actor_id",
// INNER_JOIN(Film, FilmActor.FilmID.Eq(Film.FilmID)). film_actor.film_id AS "film_actor.film_id",
// INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). film_actor.last_update AS "film_actor.last_update",
// SELECT(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns). film.film_id AS "film.film_id",
// WHERE(FilmActor.ActorID.GtEq(1).AND(FilmActor.ActorID.LteLiteral(2))) film.title AS "film.title",
// film.description AS "film.description",
// queryStr, args, err := query.Sql() film.release_year AS "film.release_year",
// assert.NilError(t, err) film.language_id AS "film.language_id",
// assert.Equal(t, queryStr, `SELECT film_actor.actor_id AS "film_actor.actor_id", film_actor.film_id AS "film_actor.film_id", film_actor.last_update AS "film_actor.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext",language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.film_actor JOIN dvds.actor ON film_actor.actor_id = actor.actor_id JOIN dvds.film ON film_actor.film_id = film.film_id JOIN dvds.language ON film.language_id = language.language_id WHERE (film_actor.actor_id>=1 AND film_actor.actor_id<=2)`) film.rental_duration AS "film.rental_duration",
// film.rental_rate AS "film.rental_rate",
// //fmt.Println(queryStr) film.length AS "film.length",
// film.replacement_cost AS "film.replacement_cost",
// filmActor := []model.FilmActor{} film.rating AS "film.rating",
// film.last_update AS "film.last_update",
// err = query.Execute(db, &filmActor) film.special_features AS "film.special_features",
// film.fulltext AS "film.fulltext",
// assert.NilError(t, err) language.language_id AS "language.language_id",
// language.name AS "language.name",
// //fmt.Println("ACTORS: --------------------") language.last_update AS "language.last_update",
// //spew.Dump(filmActor) actor.actor_id AS "actor.actor_id",
//} actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update",
inventory.inventory_id AS "inventory.inventory_id",
inventory.film_id AS "inventory.film_id",
inventory.store_id AS "inventory.store_id",
inventory.last_update AS "inventory.last_update",
rental.rental_id AS "rental.rental_id",
rental.rental_date AS "rental.rental_date",
rental.inventory_id AS "rental.inventory_id",
rental.customer_id AS "rental.customer_id",
rental.return_date AS "rental.return_date",
rental.staff_id AS "rental.staff_id",
rental.last_update AS "rental.last_update"
FROM dvds.film_actor
JOIN dvds.actor ON film_actor.actor_id = actor.actor_id
JOIN dvds.film ON film_actor.film_id = film.film_id
JOIN dvds.language ON film.language_id = language.language_id
JOIN dvds.inventory ON inventory.film_id = film.film_id
JOIN dvds.rental ON rental.inventory_id = inventory.inventory_id
ORDER BY film.film_id ASC
LIMIT 50;
`
for i := 0; i < 1; i++ {
query := FilmActor.
INNER_JOIN(Actor, FilmActor.ActorID.Eq(Actor.ActorID)).
INNER_JOIN(Film, FilmActor.FilmID.Eq(Film.FilmID)).
INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
INNER_JOIN(Inventory, Inventory.FilmID.Eq(Film.FilmID)).
INNER_JOIN(Rental, Rental.InventoryID.Eq(Inventory.InventoryID)).
SELECT(
FilmActor.AllColumns,
Film.AllColumns,
Language.AllColumns,
Actor.AllColumns,
Inventory.AllColumns,
Rental.AllColumns,
).
//WHERE(FilmActor.ActorID.GtEqL(1).AND(FilmActor.ActorID.LtEqL(2))).
ORDER_BY(Film.FilmID.ASC()).
LIMIT(50)
assertQuery(t, query, expectedSql, int64(50))
var languageActorFilm []struct {
model.Language
Films []struct {
model.Film
Actors []struct {
model.Actor
}
Inventory []struct {
model.Inventory
Rental []model.Rental
}
}
}
err := query.Query(db, &languageActorFilm)
assert.NilError(t, err)
assert.Equal(t, len(languageActorFilm), 1)
assert.Equal(t, len(languageActorFilm[0].Films), 1)
assert.Equal(t, len(languageActorFilm[0].Films[0].Actors), 10)
}
}
func TestJoinQuerySlice(t *testing.T) { func TestJoinQuerySlice(t *testing.T) {
expectedSql := ` expectedSql := `
@ -408,7 +475,10 @@ LIMIT 1000;
assertQuery(t, query, expectedSql, int64(1000)) assertQuery(t, query, expectedSql, int64(1000))
customerAddresCrosJoined := []model.Customer{} var customerAddresCrosJoined []struct {
model.Customer
model.Address
}
err := query.Query(db, &customerAddresCrosJoined) err := query.Query(db, &customerAddresCrosJoined)
@ -417,6 +487,57 @@ LIMIT 1000;
assert.NilError(t, err) assert.NilError(t, err)
} }
func TestSelecSelfJoin1(t *testing.T) {
var expectedSql = `
SELECT employee.employee_id AS "employee.employee_id",
employee.first_name AS "employee.first_name",
employee.last_name AS "employee.last_name",
employee.manager_id AS "employee.manager_id",
manager.employee_id AS "manager.employee_id",
manager.first_name AS "manager.first_name",
manager.last_name AS "manager.last_name",
manager.manager_id AS "manager.manager_id"
FROM test_sample.employee
LEFT JOIN test_sample.employee AS manager ON manager.employee_id = employee.manager_id
ORDER BY employee.employee_id;
`
manager := Employee.AS("manager")
query := Employee.
LEFT_JOIN(manager, manager.EmployeeID.Eq(Employee.ManagerID)).
SELECT(Employee.AllColumns, manager.AllColumns).
ORDER_BY(Employee.EmployeeID)
assertQuery(t, query, expectedSql)
var dest []struct {
model2.Employee
Manager *model2.Employee
}
err := query.Query(db, &dest)
assert.NilError(t, err)
assert.Equal(t, len(dest), 8)
assert.DeepEqual(t, dest[0].Employee, model2.Employee{
EmployeeID: 1,
FirstName: "Windy",
LastName: "Hays",
ManagerID: nil,
})
assert.Assert(t, dest[0].Manager == nil)
assert.DeepEqual(t, dest[7].Employee, model2.Employee{
EmployeeID: 8,
FirstName: "Salley",
LastName: "Lester",
ManagerID: int32Ptr(3),
})
}
func TestSelectSelfJoin(t *testing.T) { func TestSelectSelfJoin(t *testing.T) {
expectedSql := ` expectedSql := `
SELECT f1.film_id AS "f1.film_id", SELECT f1.film_id AS "f1.film_id",
@ -446,21 +567,19 @@ SELECT f1.film_id AS "f1.film_id",
f2.special_features AS "f2.special_features", f2.special_features AS "f2.special_features",
f2.fulltext AS "f2.fulltext" f2.fulltext AS "f2.fulltext"
FROM dvds.film AS f1 FROM dvds.film AS f1
JOIN dvds.film AS f2 ON (f1.film_id != f2.film_id AND f1.length = f2.length) JOIN dvds.film AS f2 ON (f1.film_id < f2.film_id AND f1.length = f2.length)
ORDER BY f1.film_id ASC ORDER BY f1.film_id ASC;
LIMIT 100;
` `
f1 := Film.AS("f1") f1 := Film.AS("f1")
f2 := Film.AS("f2") f2 := Film.AS("f2")
query := f1. query := f1.
INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).AND(f1.Length.Eq(f2.Length))). INNER_JOIN(f2, f1.FilmID.Lt(f2.FilmID).AND(f1.Length.Eq(f2.Length))).
SELECT(f1.AllColumns, f2.AllColumns). SELECT(f1.AllColumns, f2.AllColumns).
ORDER_BY(f1.FilmID.ASC()). ORDER_BY(f1.FilmID.ASC())
LIMIT(100)
assertQuery(t, query, expectedSql, int64(100)) assertQuery(t, query, expectedSql)
type F1 model.Film type F1 model.Film
type F2 model.Film type F2 model.Film
@ -474,7 +593,9 @@ LIMIT 100;
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(theSameLengthFilms), 100) //spew.Dump(theSameLengthFilms)
//assert.Equal(t, len(theSameLengthFilms), 100)
} }
func TestSelectAliasColumn(t *testing.T) { func TestSelectAliasColumn(t *testing.T) {
@ -517,61 +638,62 @@ LIMIT 1000;
assert.DeepEqual(t, films[0], thesameLengthFilms{"Alien Center", "Iron Moon", 46}) assert.DeepEqual(t, films[0], thesameLengthFilms{"Alien Center", "Iron Moon", 46})
} }
type Manager staff //
//type Manager staff
type staff struct { //
StaffID int32 `sql:"unique"` //type staff struct {
FirstName string // StaffID int32 `sql:"unique"`
LastName string // FirstName string
//Address *model.Address // LastName string
//Email *string // //Address *model.Address
//StoreID int16 // //Email *string
//Active bool // //StoreID int16
//Username string // //Active bool
//Password *string // //Username string
//LastUpdate time.Time // //Password *string
*Manager //`sqlbuilder:"manager"` // //LastUpdate time.Time
} // *Manager //`sqlbuilder:"manager"`
//}
func TestSelectSelfReferenceType(t *testing.T) { //
//func TestSelectSelfReferenceType(t *testing.T) {
expectedSql := ` //
SELECT DISTINCT staff.staff_id AS "staff.staff_id", // expectedSql := `
staff.first_name AS "staff.first_name", //SELECT DISTINCT staff.staff_id AS "staff.staff_id",
staff.last_name AS "staff.last_name", // staff.first_name AS "staff.first_name",
address.address_id AS "address.address_id", // staff.last_name AS "staff.last_name",
address.address AS "address.address", // address.address_id AS "address.address_id",
address.address2 AS "address.address2", // address.address AS "address.address",
address.district AS "address.district", // address.address2 AS "address.address2",
address.city_id AS "address.city_id", // address.district AS "address.district",
address.postal_code AS "address.postal_code", // address.city_id AS "address.city_id",
address.phone AS "address.phone", // address.postal_code AS "address.postal_code",
address.last_update AS "address.last_update", // address.phone AS "address.phone",
manager.staff_id AS "manager.staff_id", // address.last_update AS "address.last_update",
manager.first_name AS "manager.first_name" // manager.staff_id AS "manager.staff_id",
FROM dvds.staff // manager.first_name AS "manager.first_name"
JOIN dvds.address ON staff.address_id = address.address_id //FROM dvds.staff
JOIN dvds.staff AS manager ON staff.staff_id = manager.staff_id; // JOIN dvds.address ON staff.address_id = address.address_id
` // JOIN dvds.staff AS manager ON staff.staff_id = manager.staff_id;
manager := Staff.AS("manager") //`
// manager := Staff.AS("manager")
query := Staff. //
INNER_JOIN(Address, Staff.AddressID.Eq(Address.AddressID)). // query := Staff.
INNER_JOIN(manager, Staff.StaffID.Eq(manager.StaffID)). // INNER_JOIN(Address, Staff.AddressID.Eq(Address.AddressID)).
SELECT(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName). // INNER_JOIN(manager, Staff.StaffID.Eq(manager.StaffID)).
DISTINCT() // SELECT(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName).
// DISTINCT()
assertQuery(t, query, expectedSql) //
// assertQuery(t, query, expectedSql)
staffs := []staff{} //
// staffs := []staff{}
err := query.Query(db, &staffs) //
// err := query.Query(db, &staffs)
assert.NilError(t, err) //
// assert.NilError(t, err)
fmt.Println(query.DebugSql()) //
//spew.Dump(staffs) // fmt.Println(query.DebugSql())
} // //spew.Dump(staffs)
//}
func TestSubQuery(t *testing.T) { func TestSubQuery(t *testing.T) {
@ -684,7 +806,8 @@ ORDER BY film.film_id ASC;
maxFilmRentalRate := NumExp(Film.SELECT(MAX(Film.RentalRate))) maxFilmRentalRate := NumExp(Film.SELECT(MAX(Film.RentalRate)))
query := Film.SELECT(Film.AllColumns). query := Film.
SELECT(Film.AllColumns).
WHERE(Film.RentalRate.Eq(maxFilmRentalRate)). WHERE(Film.RentalRate.Eq(maxFilmRentalRate)).
ORDER_BY(Film.FilmID.ASC()) ORDER_BY(Film.FilmID.ASC())
@ -705,7 +828,7 @@ ORDER BY film.film_id ASC;
Title: "Ace Goldfinger", Title: "Ace Goldfinger",
Description: stringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), Description: stringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"),
ReleaseYear: int32Ptr(2006), ReleaseYear: int32Ptr(2006),
Language: nil, LanguageID: 1,
RentalRate: 4.99, RentalRate: 4.99,
Length: int16Ptr(48), Length: int16Ptr(48),
ReplacementCost: 12.99, ReplacementCost: 12.99,
@ -810,6 +933,7 @@ ORDER BY customer_payment_sum.amount_sum ASC;
StoreID: 1, StoreID: 1,
FirstName: "Brian", FirstName: "Brian",
LastName: "Wyman", LastName: "Wyman",
AddressID: 323,
Email: stringPtr("brian.wyman@sakilacustomer.org"), Email: stringPtr("brian.wyman@sakilacustomer.org"),
Activebool: true, Activebool: true,
CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0),
@ -851,6 +975,9 @@ ORDER BY payment.payment_date ASC;
assert.Equal(t, len(payments), 9) assert.Equal(t, len(payments), 9)
assert.DeepEqual(t, payments[0], model.Payment{ assert.DeepEqual(t, payments[0], model.Payment{
PaymentID: 17793, PaymentID: 17793,
CustomerID: 416,
StaffID: 2,
RentalID: 1158,
Amount: 2.99, Amount: 2.99,
PaymentDate: *timeWithoutTimeZone("2007-02-14 21:21:59.996577", 6), PaymentDate: *timeWithoutTimeZone("2007-02-14 21:21:59.996577", 6),
}) })

View file

@ -17,7 +17,7 @@ func assertQuery(t *testing.T, query sqlbuilder.Statement, expectedQuery string,
debuqSql, err := query.DebugSql() debuqSql, err := query.DebugSql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, debuqSql, expectedQuery, args) assert.Equal(t, debuqSql, expectedQuery)
} }
func int16Ptr(i int16) *int16 { func int16Ptr(i int16) *int16 {
@ -55,7 +55,7 @@ var customer0 = model.Customer{
FirstName: "Mary", FirstName: "Mary",
LastName: "Smith", LastName: "Smith",
Email: stringPtr("mary.smith@sakilacustomer.org"), Email: stringPtr("mary.smith@sakilacustomer.org"),
Address: nil, AddressID: 5,
Activebool: true, Activebool: true,
CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3),
@ -68,7 +68,7 @@ var customer1 = model.Customer{
FirstName: "Patricia", FirstName: "Patricia",
LastName: "Johnson", LastName: "Johnson",
Email: stringPtr("patricia.johnson@sakilacustomer.org"), Email: stringPtr("patricia.johnson@sakilacustomer.org"),
Address: nil, AddressID: 6,
Activebool: true, Activebool: true,
CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3),
@ -81,7 +81,7 @@ var lastCustomer = model.Customer{
FirstName: "Austin", FirstName: "Austin",
LastName: "Cintron", LastName: "Cintron",
Email: stringPtr("austin.cintron@sakilacustomer.org"), Email: stringPtr("austin.cintron@sakilacustomer.org"),
Address: nil, AddressID: 605,
Activebool: true, Activebool: true,
CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3),