Add FULL and CROSS JOIN support.
This commit is contained in:
parent
a49c682672
commit
6101e44bdf
3 changed files with 212 additions and 24 deletions
|
|
@ -2,6 +2,7 @@ package execution
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/serenize/snaker"
|
"github.com/serenize/snaker"
|
||||||
|
|
@ -37,7 +38,6 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
|
||||||
rowData := createScanValue(columnTypes)
|
rowData := createScanValue(columnTypes)
|
||||||
|
|
||||||
scanContext := &scanContext{
|
scanContext := &scanContext{
|
||||||
|
|
||||||
columnNames: columnNames,
|
columnNames: columnNames,
|
||||||
uniqueObjectsMap: make(map[string]interface{}),
|
uniqueObjectsMap: make(map[string]interface{}),
|
||||||
}
|
}
|
||||||
|
|
@ -70,6 +70,8 @@ func Execute(db *sql.DB, query string, destinationPtr interface{}) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println(strconv.Itoa(scanContext.rowNum) + " ROWS PROCESSED")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -118,9 +120,10 @@ func getGroupKey(scanContext *scanContext, row []interface{}, structType reflect
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
rowValue := reflect.ValueOf(row[index])
|
cellValue := cellValue(row, index)
|
||||||
|
|
||||||
|
groupKey = groupKey + reflectValueToString(cellValue)
|
||||||
|
|
||||||
groupKey = groupKey + reflectValueToString(rowValue)
|
|
||||||
} else if !isDbBaseType(fieldType.Type) {
|
} else if !isDbBaseType(fieldType.Type) {
|
||||||
var structType reflect.Type
|
var structType reflect.Type
|
||||||
if fieldType.Type.Kind() == reflect.Struct {
|
if fieldType.Type.Kind() == reflect.Struct {
|
||||||
|
|
@ -145,6 +148,30 @@ func getGroupKey(scanContext *scanContext, row []interface{}, structType reflect
|
||||||
return groupKey
|
return groupKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cellValue(row []interface{}, index int) interface{} {
|
||||||
|
//spew.Dump(row[index])
|
||||||
|
|
||||||
|
valuer, ok := row[index].(driver.Valuer)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
//fmt.Println("____________________")
|
||||||
|
//spew.Dump(row[index])
|
||||||
|
panic("Scan value doesn't implement driver.Valuer")
|
||||||
|
}
|
||||||
|
|
||||||
|
//spew.Dump(valuer)
|
||||||
|
|
||||||
|
value, err := valuer.Value()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//spew.Dump(value)
|
||||||
|
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
func getSliceStructType(slicePtr interface{}) reflect.Type {
|
func getSliceStructType(slicePtr interface{}) reflect.Type {
|
||||||
sliceTypePtr := reflect.TypeOf(slicePtr)
|
sliceTypePtr := reflect.TypeOf(slicePtr)
|
||||||
|
|
||||||
|
|
@ -312,26 +339,38 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, columnProcessed [
|
||||||
//columnName := snaker.CamelToSnake(fieldName)
|
//columnName := snaker.CamelToSnake(fieldName)
|
||||||
|
|
||||||
////fmt.Println(columnName)
|
////fmt.Println(columnName)
|
||||||
rowIndex := getIndex(scanContext.columnNames, columnName)
|
index := getIndex(scanContext.columnNames, columnName)
|
||||||
|
|
||||||
if rowIndex < 0 || columnProcessed[rowIndex] {
|
if index < 0 || columnProcessed[index] {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
////spew.Dump(row[rowIndex])
|
////spew.Dump(row[index])
|
||||||
|
|
||||||
rowColumnValue := reflect.ValueOf(row[rowIndex])
|
cellValue := cellValue(row, index)
|
||||||
|
//spew.Dump(cellValue)
|
||||||
|
|
||||||
//spew.Dump(rowColumnValue, fieldValue)
|
//spew.Dump(rowColumnValue, fieldValue)
|
||||||
setReflectValue(rowColumnValue, fieldValue)
|
if cellValue != nil {
|
||||||
|
setReflectValue(reflect.ValueOf(cellValue), fieldValue)
|
||||||
|
}
|
||||||
|
|
||||||
columnProcessed[rowIndex] = true
|
columnProcessed[index] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func reflectValueToString(value reflect.Value) string {
|
func reflectValueToString(val interface{}) string {
|
||||||
|
//spew.Dump(val)
|
||||||
|
|
||||||
|
if val == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
value := reflect.ValueOf(val)
|
||||||
|
|
||||||
|
//if !value.IsValid()
|
||||||
var valueInterface interface{}
|
var valueInterface interface{}
|
||||||
if value.Kind() == reflect.Ptr {
|
if value.Kind() == reflect.Ptr {
|
||||||
valueInterface = value.Elem().Interface()
|
valueInterface = value.Elem().Interface()
|
||||||
|
|
@ -372,7 +411,9 @@ func setReflectValue(source, destination reflect.Value) {
|
||||||
if source.Kind() == reflect.Ptr {
|
if source.Kind() == reflect.Ptr {
|
||||||
destination.Set(source)
|
destination.Set(source)
|
||||||
} else {
|
} else {
|
||||||
destination.Set(source.Addr())
|
newDestination := reflect.New(destination.Type().Elem())
|
||||||
|
newDestination.Elem().Set(source)
|
||||||
|
destination.Set(newDestination)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if source.Kind() == reflect.Ptr {
|
if source.Kind() == reflect.Ptr {
|
||||||
|
|
@ -397,7 +438,7 @@ func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
|
||||||
values := make([]interface{}, len(columnTypes))
|
values := make([]interface{}, len(columnTypes))
|
||||||
|
|
||||||
for i, sqlColumnType := range columnTypes {
|
for i, sqlColumnType := range columnTypes {
|
||||||
columnType := getScanType(sqlColumnType)
|
columnType := newScanType(sqlColumnType)
|
||||||
|
|
||||||
columnValue := reflect.New(columnType)
|
columnValue := reflect.New(columnType)
|
||||||
|
|
||||||
|
|
@ -407,17 +448,32 @@ func createScanValue(columnTypes []*sql.ColumnType) []interface{} {
|
||||||
return values
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
func getScanType(columnType *sql.ColumnType) reflect.Type {
|
var nullFloatType = reflect.TypeOf(sql.NullFloat64{})
|
||||||
scanType := columnType.ScanType()
|
var nullInt16Type = reflect.TypeOf(NullInt16{})
|
||||||
//////fmt.Println(scanType.String())
|
var nullInt32Type = reflect.TypeOf(NullInt32{})
|
||||||
if scanType.String() != "interface {}" {
|
var nullInt64Type = reflect.TypeOf(sql.NullInt64{})
|
||||||
return scanType
|
var nullStringType = reflect.TypeOf(sql.NullString{})
|
||||||
}
|
var nullBoolType = reflect.TypeOf(sql.NullBool{})
|
||||||
|
var nullTimeType = reflect.TypeOf(NullTime{})
|
||||||
|
|
||||||
|
func newScanType(columnType *sql.ColumnType) reflect.Type {
|
||||||
|
//spew.Dump(columnType)
|
||||||
switch columnType.DatabaseTypeName() {
|
switch columnType.DatabaseTypeName() {
|
||||||
|
case "INT2":
|
||||||
|
return nullInt16Type
|
||||||
|
case "INT4":
|
||||||
|
return nullInt32Type
|
||||||
|
case "INT8":
|
||||||
|
return nullInt64Type
|
||||||
|
case "VARCHAR", "TEXT", "", "_TEXT", "TSVECTOR", "BPCHAR":
|
||||||
|
return nullStringType
|
||||||
case "FLOAT4":
|
case "FLOAT4":
|
||||||
return floatType
|
return nullFloatType
|
||||||
|
case "BOOL":
|
||||||
|
return nullBoolType
|
||||||
|
case "DATE", "TIMESTAMP":
|
||||||
|
return nullTimeType
|
||||||
default:
|
default:
|
||||||
return stringType
|
panic("Unknown column database type " + columnType.DatabaseTypeName())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
89
sqlbuilder/execution/null_types.go
Normal file
89
sqlbuilder/execution/null_types.go
Normal file
|
|
@ -0,0 +1,89 @@
|
||||||
|
package execution
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NullTime struct {
|
||||||
|
Time time.Time
|
||||||
|
Valid bool // Valid is true if Time is not NULL
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the Scanner interface.
|
||||||
|
func (nt *NullTime) Scan(value interface{}) error {
|
||||||
|
nt.Time, nt.Valid = value.(time.Time)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the driver Valuer interface.
|
||||||
|
func (nt NullTime) Value() (driver.Value, error) {
|
||||||
|
if !nt.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nt.Time, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type NullInt32 struct {
|
||||||
|
Int32 int32
|
||||||
|
Valid bool // Valid is true if Int64 is not NULL
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the Scanner interface.
|
||||||
|
func (n *NullInt32) Scan(value interface{}) error {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case int64:
|
||||||
|
n.Int32, n.Valid = int32(v), true
|
||||||
|
return nil
|
||||||
|
case int32:
|
||||||
|
n.Int32, n.Valid = v, true
|
||||||
|
return nil
|
||||||
|
case uint8:
|
||||||
|
n.Int32, n.Valid = int32(v), true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
n.Valid = false
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the driver Valuer interface.
|
||||||
|
func (n NullInt32) Value() (driver.Value, error) {
|
||||||
|
if !n.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return n.Int32, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type NullInt16 struct {
|
||||||
|
Int16 int16
|
||||||
|
Valid bool // Valid is true if Int64 is not NULL
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the Scanner interface.
|
||||||
|
func (n *NullInt16) Scan(value interface{}) error {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case int64:
|
||||||
|
n.Int16, n.Valid = int16(v), true
|
||||||
|
return nil
|
||||||
|
case int16:
|
||||||
|
n.Int16, n.Valid = v, true
|
||||||
|
return nil
|
||||||
|
case uint8:
|
||||||
|
n.Int16, n.Valid = int16(v), true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
n.Valid = false
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the driver Valuer interface.
|
||||||
|
func (n NullInt16) Value() (driver.Value, error) {
|
||||||
|
if !n.Valid {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return n.Int16, nil
|
||||||
|
}
|
||||||
|
|
@ -33,6 +33,10 @@ type ReadableTable interface {
|
||||||
|
|
||||||
// Creates a right join table expression using onCondition.
|
// Creates a right join table expression using onCondition.
|
||||||
RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable
|
RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable
|
||||||
|
|
||||||
|
FullJoin(table ReadableTable, col1 Column, col2 Column) ReadableTable
|
||||||
|
|
||||||
|
CrossJoin(table ReadableTable) ReadableTable
|
||||||
}
|
}
|
||||||
|
|
||||||
// The sql table write interface.
|
// The sql table write interface.
|
||||||
|
|
@ -196,6 +200,14 @@ func (t *Table) RightJoinOn(
|
||||||
return RightJoinOn(t, table, onCondition)
|
return RightJoinOn(t, table, onCondition)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Table) FullJoin(table ReadableTable, col1, col2 Column) ReadableTable {
|
||||||
|
return FullJoin(t, table, col1.Eq(col2))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Table) CrossJoin(table ReadableTable) ReadableTable {
|
||||||
|
return CrossJoin(t, table)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Table) Insert(columns ...NonAliasColumn) InsertStatement {
|
func (t *Table) Insert(columns ...NonAliasColumn) InsertStatement {
|
||||||
return newInsertStatement(t, columns...)
|
return newInsertStatement(t, columns...)
|
||||||
}
|
}
|
||||||
|
|
@ -214,6 +226,8 @@ const (
|
||||||
INNER_JOIN joinType = iota
|
INNER_JOIN joinType = iota
|
||||||
LEFT_JOIN
|
LEFT_JOIN
|
||||||
RIGHT_JOIN
|
RIGHT_JOIN
|
||||||
|
FULL_JOIN
|
||||||
|
CROSS_JOIN
|
||||||
)
|
)
|
||||||
|
|
||||||
// Join expressions are pseudo readable tables.
|
// Join expressions are pseudo readable tables.
|
||||||
|
|
@ -262,6 +276,21 @@ func RightJoinOn(
|
||||||
return newJoinTable(lhs, rhs, RIGHT_JOIN, onCondition)
|
return newJoinTable(lhs, rhs, RIGHT_JOIN, onCondition)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func FullJoin(
|
||||||
|
lhs ReadableTable,
|
||||||
|
rhs ReadableTable,
|
||||||
|
onCondition BoolExpression) ReadableTable {
|
||||||
|
|
||||||
|
return newJoinTable(lhs, rhs, FULL_JOIN, onCondition)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CrossJoin(
|
||||||
|
lhs ReadableTable,
|
||||||
|
rhs ReadableTable) ReadableTable {
|
||||||
|
|
||||||
|
return newJoinTable(lhs, rhs, CROSS_JOIN, nil)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *joinTable) Columns() []NonAliasColumn {
|
func (t *joinTable) Columns() []NonAliasColumn {
|
||||||
columns := make([]NonAliasColumn, 0)
|
columns := make([]NonAliasColumn, 0)
|
||||||
columns = append(columns, t.lhs.Columns()...)
|
columns = append(columns, t.lhs.Columns()...)
|
||||||
|
|
@ -278,7 +307,7 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) {
|
||||||
if t.rhs == nil {
|
if t.rhs == nil {
|
||||||
return errors.Newf("nil rhs. Generated sql: %s", out.String())
|
return errors.Newf("nil rhs. Generated sql: %s", out.String())
|
||||||
}
|
}
|
||||||
if t.onCondition == nil {
|
if t.onCondition == nil && t.join_type != CROSS_JOIN {
|
||||||
return errors.Newf("nil onCondition. Generated sql: %s", out.String())
|
return errors.Newf("nil onCondition. Generated sql: %s", out.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -293,15 +322,21 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) {
|
||||||
_, _ = out.WriteString(" LEFT JOIN ")
|
_, _ = out.WriteString(" LEFT JOIN ")
|
||||||
case RIGHT_JOIN:
|
case RIGHT_JOIN:
|
||||||
_, _ = out.WriteString(" RIGHT JOIN ")
|
_, _ = out.WriteString(" RIGHT JOIN ")
|
||||||
|
case FULL_JOIN:
|
||||||
|
out.WriteString(" FULL JOIN ")
|
||||||
|
case CROSS_JOIN:
|
||||||
|
out.WriteString(" CROSS JOIN ")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = t.rhs.SerializeSql(out); err != nil {
|
if err = t.rhs.SerializeSql(out); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _ = out.WriteString(" ON ")
|
if t.onCondition != nil {
|
||||||
if err = t.onCondition.SerializeSql(out); err != nil {
|
_, _ = out.WriteString(" ON ")
|
||||||
return
|
if err = t.onCondition.SerializeSql(out); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -333,6 +368,14 @@ func (t *joinTable) LeftJoinOn(
|
||||||
return LeftJoinOn(t, table, onCondition)
|
return LeftJoinOn(t, table, onCondition)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *joinTable) FullJoin(table ReadableTable, col1 Column, col2 Column) ReadableTable {
|
||||||
|
return FullJoin(t, table, col1.Eq(col2))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *joinTable) CrossJoin(table ReadableTable) ReadableTable {
|
||||||
|
return CrossJoin(t, table)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *joinTable) RightJoinOn(
|
func (t *joinTable) RightJoinOn(
|
||||||
table ReadableTable,
|
table ReadableTable,
|
||||||
onCondition BoolExpression) ReadableTable {
|
onCondition BoolExpression) ReadableTable {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue