Add FULL and CROSS JOIN support.

This commit is contained in:
sub0Zero 2019-03-16 14:02:45 +01:00 committed by zer0sub
parent a49c682672
commit 6101e44bdf
3 changed files with 212 additions and 24 deletions

View file

@ -2,6 +2,7 @@ package execution
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"errors" "errors"
"fmt" "fmt"
"github.com/serenize/snaker" "github.com/serenize/snaker"
@ -37,7 +38,6 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
rowData := createScanValue(columnTypes) rowData := createScanValue(columnTypes)
scanContext := &scanContext{ scanContext := &scanContext{
columnNames: columnNames, columnNames: columnNames,
uniqueObjectsMap: make(map[string]interface{}), uniqueObjectsMap: make(map[string]interface{}),
} }
@ -70,6 +70,8 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
return err return err
} }
fmt.Println(strconv.Itoa(scanContext.rowNum) + " ROWS PROCESSED")
return nil return nil
} }
@ -118,9 +120,10 @@ func getGroupKey(scanContext *scanContext, row []interface{}, structType reflect
continue continue
} }
rowValue := reflect.ValueOf(row[index]) cellValue := cellValue(row, index)
groupKey = groupKey + reflectValueToString(cellValue)
groupKey = groupKey + reflectValueToString(rowValue)
} else if !isDbBaseType(fieldType.Type) { } else if !isDbBaseType(fieldType.Type) {
var structType reflect.Type var structType reflect.Type
if fieldType.Type.Kind() == reflect.Struct { if fieldType.Type.Kind() == reflect.Struct {
@ -145,6 +148,30 @@ func getGroupKey(scanContext *scanContext, row []interface{}, structType reflect
return groupKey return groupKey
} }
func cellValue(row []interface{}, index int) interface{} {
//spew.Dump(row[index])
valuer, ok := row[index].(driver.Valuer)
if !ok {
//fmt.Println("____________________")
//spew.Dump(row[index])
panic("Scan value doesn't implement driver.Valuer")
}
//spew.Dump(valuer)
value, err := valuer.Value()
if err != nil {
panic(err)
}
//spew.Dump(value)
return value
}
func getSliceStructType(slicePtr interface{}) reflect.Type { func getSliceStructType(slicePtr interface{}) reflect.Type {
sliceTypePtr := reflect.TypeOf(slicePtr) sliceTypePtr := reflect.TypeOf(slicePtr)
@ -312,26 +339,38 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, columnProcessed [
//columnName := snaker.CamelToSnake(fieldName) //columnName := snaker.CamelToSnake(fieldName)
////fmt.Println(columnName) ////fmt.Println(columnName)
rowIndex := getIndex(scanContext.columnNames, columnName) index := getIndex(scanContext.columnNames, columnName)
if rowIndex < 0 || columnProcessed[rowIndex] { if index < 0 || columnProcessed[index] {
continue continue
} }
////spew.Dump(row[rowIndex]) ////spew.Dump(row[index])
rowColumnValue := reflect.ValueOf(row[rowIndex]) cellValue := cellValue(row, index)
//spew.Dump(cellValue)
//spew.Dump(rowColumnValue, fieldValue) //spew.Dump(rowColumnValue, fieldValue)
setReflectValue(rowColumnValue, fieldValue) if cellValue != nil {
setReflectValue(reflect.ValueOf(cellValue), fieldValue)
}
columnProcessed[rowIndex] = true columnProcessed[index] = true
} }
} }
return nil return nil
} }
func reflectValueToString(value reflect.Value) string { func reflectValueToString(val interface{}) string {
//spew.Dump(val)
if val == nil {
return ""
}
value := reflect.ValueOf(val)
//if !value.IsValid()
var valueInterface interface{} var valueInterface interface{}
if value.Kind() == reflect.Ptr { if value.Kind() == reflect.Ptr {
valueInterface = value.Elem().Interface() valueInterface = value.Elem().Interface()
@ -372,7 +411,9 @@ func setReflectValue(source, destination reflect.Value) {
if source.Kind() == reflect.Ptr { if source.Kind() == reflect.Ptr {
destination.Set(source) destination.Set(source)
} else { } else {
destination.Set(source.Addr()) newDestination := reflect.New(destination.Type().Elem())
newDestination.Elem().Set(source)
destination.Set(newDestination)
} }
} else { } else {
if source.Kind() == reflect.Ptr { if source.Kind() == reflect.Ptr {
@ -397,7 +438,7 @@ func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
values := make([]interface{}, len(columnTypes)) values := make([]interface{}, len(columnTypes))
for i, sqlColumnType := range columnTypes { for i, sqlColumnType := range columnTypes {
columnType := getScanType(sqlColumnType) columnType := newScanType(sqlColumnType)
columnValue := reflect.New(columnType) columnValue := reflect.New(columnType)
@ -407,17 +448,32 @@ func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
return values return values
} }
func getScanType(columnType *sql.ColumnType) reflect.Type { var nullFloatType = reflect.TypeOf(sql.NullFloat64{})
scanType := columnType.ScanType() var nullInt16Type = reflect.TypeOf(NullInt16{})
//////fmt.Println(scanType.String()) var nullInt32Type = reflect.TypeOf(NullInt32{})
if scanType.String() != "interface {}" { var nullInt64Type = reflect.TypeOf(sql.NullInt64{})
return scanType var nullStringType = reflect.TypeOf(sql.NullString{})
} var nullBoolType = reflect.TypeOf(sql.NullBool{})
var nullTimeType = reflect.TypeOf(NullTime{})
func newScanType(columnType *sql.ColumnType) reflect.Type {
//spew.Dump(columnType)
switch columnType.DatabaseTypeName() { switch columnType.DatabaseTypeName() {
case "INT2":
return nullInt16Type
case "INT4":
return nullInt32Type
case "INT8":
return nullInt64Type
case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR":
return nullStringType
case "FLOAT4": case "FLOAT4":
return floatType return nullFloatType
case "BOOL":
return nullBoolType
case "DATE", "TIMESTAMP":
return nullTimeType
default: default:
return stringType panic("Unknown column database type " + columnType.DatabaseTypeName())
} }
} }

View file

@ -0,0 +1,89 @@
package execution
import (
"database/sql/driver"
"time"
)
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}
// Scan implements the Scanner interface.
func (nt *NullTime) Scan(value interface{}) error {
nt.Time, nt.Valid = value.(time.Time)
return nil
}
// Value implements the driver Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
type NullInt32 struct {
Int32 int32
Valid bool // Valid is true if Int64 is not NULL
}
// Scan implements the Scanner interface.
func (n *NullInt32) Scan(value interface{}) error {
switch v := value.(type) {
case int64:
n.Int32, n.Valid = int32(v), true
return nil
case int32:
n.Int32, n.Valid = v, true
return nil
case uint8:
n.Int32, n.Valid = int32(v), true
return nil
}
n.Valid = false
return nil
}
// Value implements the driver Valuer interface.
func (n NullInt32) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int32, nil
}
type NullInt16 struct {
Int16 int16
Valid bool // Valid is true if Int64 is not NULL
}
// Scan implements the Scanner interface.
func (n *NullInt16) Scan(value interface{}) error {
switch v := value.(type) {
case int64:
n.Int16, n.Valid = int16(v), true
return nil
case int16:
n.Int16, n.Valid = v, true
return nil
case uint8:
n.Int16, n.Valid = int16(v), true
return nil
}
n.Valid = false
return nil
}
// Value implements the driver Valuer interface.
func (n NullInt16) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Int16, nil
}

View file

@ -33,6 +33,10 @@ type ReadableTable interface {
// Creates a right join table expression using onCondition. // Creates a right join table expression using onCondition.
RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable
FullJoin(table ReadableTable, col1 Column, col2 Column) ReadableTable
CrossJoin(table ReadableTable) ReadableTable
} }
// The sql table write interface. // The sql table write interface.
@ -196,6 +200,14 @@ func (t *Table) RightJoinOn(
return RightJoinOn(t, table, onCondition) return RightJoinOn(t, table, onCondition)
} }
func (t *Table) FullJoin(table ReadableTable, col1, col2 Column) ReadableTable {
return FullJoin(t, table, col1.Eq(col2))
}
func (t *Table) CrossJoin(table ReadableTable) ReadableTable {
return CrossJoin(t, table)
}
func (t *Table) Insert(columns ...NonAliasColumn) InsertStatement { func (t *Table) Insert(columns ...NonAliasColumn) InsertStatement {
return newInsertStatement(t, columns...) return newInsertStatement(t, columns...)
} }
@ -214,6 +226,8 @@ const (
INNER_JOIN joinType = iota INNER_JOIN joinType = iota
LEFT_JOIN LEFT_JOIN
RIGHT_JOIN RIGHT_JOIN
FULL_JOIN
CROSS_JOIN
) )
// Join expressions are pseudo readable tables. // Join expressions are pseudo readable tables.
@ -262,6 +276,21 @@ func RightJoinOn(
return newJoinTable(lhs, rhs, RIGHT_JOIN, onCondition) return newJoinTable(lhs, rhs, RIGHT_JOIN, onCondition)
} }
func FullJoin(
lhs ReadableTable,
rhs ReadableTable,
onCondition BoolExpression) ReadableTable {
return newJoinTable(lhs, rhs, FULL_JOIN, onCondition)
}
func CrossJoin(
lhs ReadableTable,
rhs ReadableTable) ReadableTable {
return newJoinTable(lhs, rhs, CROSS_JOIN, nil)
}
func (t *joinTable) Columns() []NonAliasColumn { func (t *joinTable) Columns() []NonAliasColumn {
columns := make([]NonAliasColumn, 0) columns := make([]NonAliasColumn, 0)
columns = append(columns, t.lhs.Columns()...) columns = append(columns, t.lhs.Columns()...)
@ -278,7 +307,7 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) {
if t.rhs == nil { if t.rhs == nil {
return errors.Newf("nil rhs. Generated sql: %s", out.String()) return errors.Newf("nil rhs. Generated sql: %s", out.String())
} }
if t.onCondition == nil { if t.onCondition == nil && t.join_type != CROSS_JOIN {
return errors.Newf("nil onCondition. Generated sql: %s", out.String()) return errors.Newf("nil onCondition. Generated sql: %s", out.String())
} }
@ -293,15 +322,21 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) {
_, _ = out.WriteString(" LEFT JOIN ") _, _ = out.WriteString(" LEFT JOIN ")
case RIGHT_JOIN: case RIGHT_JOIN:
_, _ = out.WriteString(" RIGHT JOIN ") _, _ = out.WriteString(" RIGHT JOIN ")
case FULL_JOIN:
out.WriteString(" FULL JOIN ")
case CROSS_JOIN:
out.WriteString(" CROSS JOIN ")
} }
if err = t.rhs.SerializeSql(out); err != nil { if err = t.rhs.SerializeSql(out); err != nil {
return return
} }
_, _ = out.WriteString(" ON ") if t.onCondition != nil {
if err = t.onCondition.SerializeSql(out); err != nil { _, _ = out.WriteString(" ON ")
return if err = t.onCondition.SerializeSql(out); err != nil {
return
}
} }
return nil return nil
@ -333,6 +368,14 @@ func (t *joinTable) LeftJoinOn(
return LeftJoinOn(t, table, onCondition) return LeftJoinOn(t, table, onCondition)
} }
func (t *joinTable) FullJoin(table ReadableTable, col1 Column, col2 Column) ReadableTable {
return FullJoin(t, table, col1.Eq(col2))
}
func (t *joinTable) CrossJoin(table ReadableTable) ReadableTable {
return CrossJoin(t, table)
}
func (t *joinTable) RightJoinOn( func (t *joinTable) RightJoinOn(
table ReadableTable, table ReadableTable,
onCondition BoolExpression) ReadableTable { onCondition BoolExpression) ReadableTable {