diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index 52c2307..2f2b03f 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -2,6 +2,7 @@ package execution import ( "database/sql" + "database/sql/driver" "errors" "fmt" "github.com/serenize/snaker" @@ -37,7 +38,6 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error { rowData := createScanValue(columnTypes) scanContext := &scanContext{ - columnNames: columnNames, uniqueObjectsMap: make(map[string]interface{}), } @@ -70,6 +70,8 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error { return err } + fmt.Println(strconv.Itoa(scanContext.rowNum) + " ROWS PROCESSED") + return nil } @@ -118,9 +120,10 @@ func getGroupKey(scanContext *scanContext, row []interface{}, structType reflect continue } - rowValue := reflect.ValueOf(row[index]) + cellValue := cellValue(row, index) + + groupKey = groupKey + reflectValueToString(cellValue) - groupKey = groupKey + reflectValueToString(rowValue) } else if !isDbBaseType(fieldType.Type) { var structType reflect.Type if fieldType.Type.Kind() == reflect.Struct { @@ -145,6 +148,30 @@ func getGroupKey(scanContext *scanContext, row []interface{}, structType reflect return groupKey } +func cellValue(row []interface{}, index int) interface{} { + //spew.Dump(row[index]) + + valuer, ok := row[index].(driver.Valuer) + + if !ok { + //fmt.Println("____________________") + //spew.Dump(row[index]) + panic("Scan value doesn't implement driver.Valuer") + } + + //spew.Dump(valuer) + + value, err := valuer.Value() + + if err != nil { + panic(err) + } + + //spew.Dump(value) + + return value +} + func getSliceStructType(slicePtr interface{}) reflect.Type { sliceTypePtr := reflect.TypeOf(slicePtr) @@ -312,26 +339,38 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, columnProcessed [ //columnName := snaker.CamelToSnake(fieldName) ////fmt.Println(columnName) - rowIndex := getIndex(scanContext.columnNames, columnName) + index := getIndex(scanContext.columnNames, columnName) - if rowIndex < 0 || columnProcessed[rowIndex] { + if index < 0 || columnProcessed[index] { continue } - ////spew.Dump(row[rowIndex]) + ////spew.Dump(row[index]) - rowColumnValue := reflect.ValueOf(row[rowIndex]) + cellValue := cellValue(row, index) + //spew.Dump(cellValue) //spew.Dump(rowColumnValue, fieldValue) - setReflectValue(rowColumnValue, fieldValue) + if cellValue != nil { + setReflectValue(reflect.ValueOf(cellValue), fieldValue) + } - columnProcessed[rowIndex] = true + columnProcessed[index] = true } } return nil } -func reflectValueToString(value reflect.Value) string { +func reflectValueToString(val interface{}) string { + //spew.Dump(val) + + if val == nil { + return "" + } + + value := reflect.ValueOf(val) + + //if !value.IsValid() var valueInterface interface{} if value.Kind() == reflect.Ptr { valueInterface = value.Elem().Interface() @@ -372,7 +411,9 @@ func setReflectValue(source, destination reflect.Value) { if source.Kind() == reflect.Ptr { destination.Set(source) } else { - destination.Set(source.Addr()) + newDestination := reflect.New(destination.Type().Elem()) + newDestination.Elem().Set(source) + destination.Set(newDestination) } } else { if source.Kind() == reflect.Ptr { @@ -397,7 +438,7 @@ func createScanValue(columnTypes []*sql.ColumnType) []interface{} { values := make([]interface{}, len(columnTypes)) for i, sqlColumnType := range columnTypes { - columnType := getScanType(sqlColumnType) + columnType := newScanType(sqlColumnType) columnValue := reflect.New(columnType) @@ -407,17 +448,32 @@ func createScanValue(columnTypes []*sql.ColumnType) []interface{} { return values } -func getScanType(columnType *sql.ColumnType) reflect.Type { - scanType := columnType.ScanType() - //////fmt.Println(scanType.String()) - if scanType.String() != "interface {}" { - return scanType - } +var nullFloatType = reflect.TypeOf(sql.NullFloat64{}) +var nullInt16Type = reflect.TypeOf(NullInt16{}) +var nullInt32Type = reflect.TypeOf(NullInt32{}) +var nullInt64Type = reflect.TypeOf(sql.NullInt64{}) +var nullStringType = reflect.TypeOf(sql.NullString{}) +var nullBoolType = reflect.TypeOf(sql.NullBool{}) +var nullTimeType = reflect.TypeOf(NullTime{}) +func newScanType(columnType *sql.ColumnType) reflect.Type { + //spew.Dump(columnType) switch columnType.DatabaseTypeName() { + case "INT2": + return nullInt16Type + case "INT4": + return nullInt32Type + case "INT8": + return nullInt64Type + case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR": + return nullStringType case "FLOAT4": - return floatType + return nullFloatType + case "BOOL": + return nullBoolType + case "DATE", "TIMESTAMP": + return nullTimeType default: - return stringType + panic("Unknown column database type " + columnType.DatabaseTypeName()) } } diff --git a/sqlbuilder/execution/null_types.go b/sqlbuilder/execution/null_types.go new file mode 100644 index 0000000..b74ee8a --- /dev/null +++ b/sqlbuilder/execution/null_types.go @@ -0,0 +1,89 @@ +package execution + +import ( + "database/sql/driver" + "time" +) + +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements the Scanner interface. +func (nt *NullTime) Scan(value interface{}) error { + nt.Time, nt.Valid = value.(time.Time) + return nil +} + +// Value implements the driver Valuer interface. +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} + +type NullInt32 struct { + Int32 int32 + Valid bool // Valid is true if Int64 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullInt32) Scan(value interface{}) error { + switch v := value.(type) { + case int64: + n.Int32, n.Valid = int32(v), true + return nil + case int32: + n.Int32, n.Valid = v, true + return nil + case uint8: + n.Int32, n.Valid = int32(v), true + return nil + } + + n.Valid = false + + return nil +} + +// Value implements the driver Valuer interface. +func (n NullInt32) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Int32, nil +} + +type NullInt16 struct { + Int16 int16 + Valid bool // Valid is true if Int64 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullInt16) Scan(value interface{}) error { + switch v := value.(type) { + case int64: + n.Int16, n.Valid = int16(v), true + return nil + case int16: + n.Int16, n.Valid = v, true + return nil + case uint8: + n.Int16, n.Valid = int16(v), true + return nil + } + + n.Valid = false + + return nil +} + +// Value implements the driver Valuer interface. +func (n NullInt16) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Int16, nil +} diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index d8ccb37..54b8290 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -33,6 +33,10 @@ type ReadableTable interface { // Creates a right join table expression using onCondition. RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable + + FullJoin(table ReadableTable, col1 Column, col2 Column) ReadableTable + + CrossJoin(table ReadableTable) ReadableTable } // The sql table write interface. @@ -196,6 +200,14 @@ func (t *Table) RightJoinOn( return RightJoinOn(t, table, onCondition) } +func (t *Table) FullJoin(table ReadableTable, col1, col2 Column) ReadableTable { + return FullJoin(t, table, col1.Eq(col2)) +} + +func (t *Table) CrossJoin(table ReadableTable) ReadableTable { + return CrossJoin(t, table) +} + func (t *Table) Insert(columns ...NonAliasColumn) InsertStatement { return newInsertStatement(t, columns...) } @@ -214,6 +226,8 @@ const ( INNER_JOIN joinType = iota LEFT_JOIN RIGHT_JOIN + FULL_JOIN + CROSS_JOIN ) // Join expressions are pseudo readable tables. @@ -262,6 +276,21 @@ func RightJoinOn( return newJoinTable(lhs, rhs, RIGHT_JOIN, onCondition) } +func FullJoin( + lhs ReadableTable, + rhs ReadableTable, + onCondition BoolExpression) ReadableTable { + + return newJoinTable(lhs, rhs, FULL_JOIN, onCondition) +} + +func CrossJoin( + lhs ReadableTable, + rhs ReadableTable) ReadableTable { + + return newJoinTable(lhs, rhs, CROSS_JOIN, nil) +} + func (t *joinTable) Columns() []NonAliasColumn { columns := make([]NonAliasColumn, 0) columns = append(columns, t.lhs.Columns()...) @@ -278,7 +307,7 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) { if t.rhs == nil { return errors.Newf("nil rhs. Generated sql: %s", out.String()) } - if t.onCondition == nil { + if t.onCondition == nil && t.join_type != CROSS_JOIN { return errors.Newf("nil onCondition. Generated sql: %s", out.String()) } @@ -293,15 +322,21 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) { _, _ = out.WriteString(" LEFT JOIN ") case RIGHT_JOIN: _, _ = out.WriteString(" RIGHT JOIN ") + case FULL_JOIN: + out.WriteString(" FULL JOIN ") + case CROSS_JOIN: + out.WriteString(" CROSS JOIN ") } if err = t.rhs.SerializeSql(out); err != nil { return } - _, _ = out.WriteString(" ON ") - if err = t.onCondition.SerializeSql(out); err != nil { - return + if t.onCondition != nil { + _, _ = out.WriteString(" ON ") + if err = t.onCondition.SerializeSql(out); err != nil { + return + } } return nil @@ -333,6 +368,14 @@ func (t *joinTable) LeftJoinOn( return LeftJoinOn(t, table, onCondition) } +func (t *joinTable) FullJoin(table ReadableTable, col1 Column, col2 Column) ReadableTable { + return FullJoin(t, table, col1.Eq(col2)) +} + +func (t *joinTable) CrossJoin(table ReadableTable) ReadableTable { + return CrossJoin(t, table) +} + func (t *joinTable) RightJoinOn( table ReadableTable, onCondition BoolExpression) ReadableTable {