Allow set statements to be used as tables and expressions.

This commit is contained in:
zer0sub 2019-05-05 12:37:23 +02:00
parent 5ad213885f
commit 5de001d7e0
8 changed files with 134 additions and 38 deletions

View file

@ -32,7 +32,7 @@ func Example() {
filename := t2.C("filename") filename := t2.C("filename")
in := []int32{1, 2, 3} in := []int32{1, 2, 3}
join := t2.LeftJoinOn(t1, Eq(ns_id1, ns_id2)) join := t2.LEFT_JOIN(t1, Eq(ns_id1, ns_id2))
q := join.Select(ns_id2, sjid, prefix, filename).Where( q := join.Select(ns_id2, sjid, prefix, filename).Where(
And(EqL(ns_id2, 456), In(sjid, in))) And(EqL(ns_id2, 456), In(sjid, in)))
text, _ := q.String() text, _ := q.String()

View file

@ -72,6 +72,11 @@ func Query(db types.Db, query string, args []interface{}, destinationPtr interfa
return err return err
} }
err = rows.Close()
if err != nil {
return err
}
fmt.Println(strconv.Itoa(scanContext.rowNum) + " ROW(S) PROCESSED") fmt.Println(strconv.Itoa(scanContext.rowNum) + " ROW(S) PROCESSED")
return nil return nil

View file

@ -11,6 +11,7 @@ type SelectStatement interface {
Expression Expression
DISTINCT() SelectStatement DISTINCT() SelectStatement
FROM(table ReadableTable) SelectStatement
WHERE(expression BoolExpression) SelectStatement WHERE(expression BoolExpression) SelectStatement
GROUP_BY(expressions ...Clause) SelectStatement GROUP_BY(expressions ...Clause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement
@ -21,7 +22,11 @@ type SelectStatement interface {
FOR_UPDATE() SelectStatement FOR_UPDATE() SelectStatement
AsTable(alias string) *SelectStatementTable AsTable(alias string) ExpressionTable
}
var SELECT = func(projection ...Projection) SelectStatement {
return newSelectStatement(nil, projection)
} }
// NOTE: SelectStatement purposely does not implement the Table interface since // NOTE: SelectStatement purposely does not implement the Table interface since
@ -59,8 +64,7 @@ func defaultProjectionAliasing(projections []Projection) []Projection {
} }
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement {
newSelect := &selectStatementImpl{
return &selectStatementImpl{
table: table, table: table,
projections: defaultProjectionAliasing(projections), projections: defaultProjectionAliasing(projections),
limit: -1, limit: -1,
@ -68,6 +72,15 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta
forUpdate: false, forUpdate: false,
distinct: false, distinct: false,
} }
newSelect.expressionInterfaceImpl.parent = newSelect
return newSelect
}
func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement {
s.table = table
return s
} }
func (s *selectStatementImpl) Serialize(out *queryData, options ...serializeOption) error { func (s *selectStatementImpl) Serialize(out *queryData, options ...serializeOption) error {
@ -176,8 +189,8 @@ func (q *selectStatementImpl) Sql() (query string, args []interface{}, err error
return queryData.buff.String(), queryData.args, nil return queryData.buff.String(), queryData.args, nil
} }
func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable { func (s *selectStatementImpl) AsTable(alias string) ExpressionTable {
return &SelectStatementTable{ return &expressionTableImpl{
statement: s, statement: s,
alias: alias, alias: alias,
} }

View file

@ -1,45 +1,53 @@
package sqlbuilder package sqlbuilder
type SelectStatementTable struct { type ExpressionTable interface {
statement SelectStatement ReadableTable
RefIntColumnName(name string) *IntegerColumn
RefIntColumn(column Column) *IntegerColumn
RefStringColumn(column Column) *StringColumn
}
type expressionTableImpl struct {
statement Expression
columns []Column columns []Column
alias string alias string
} }
// Returns the tableName's name in the database // Returns the tableName's name in the database
func (t *SelectStatementTable) SchemaName() string { func (t *expressionTableImpl) SchemaName() string {
return "" return ""
} }
func (s *SelectStatementTable) TableName() string { func (s *expressionTableImpl) TableName() string {
return s.alias return s.alias
} }
func (s *SelectStatementTable) Columns() []Column { func (s *expressionTableImpl) Columns() []Column {
return s.columns return s.columns
} }
func (s *SelectStatementTable) RefIntColumnName(name string) Column { func (s *expressionTableImpl) RefIntColumnName(name string) *IntegerColumn {
intColumn := NewIntegerColumn(name, NotNullable) intColumn := NewIntegerColumn(name, NotNullable)
intColumn.setTableName(s.alias) intColumn.setTableName(s.alias)
return intColumn return intColumn
} }
func (s *SelectStatementTable) RefIntColumn(column Column) *IntegerColumn { func (s *expressionTableImpl) RefIntColumn(column Column) *IntegerColumn {
intColumn := NewIntegerColumn(column.TableName()+"."+column.Name(), NotNullable) intColumn := NewIntegerColumn(column.TableName()+"."+column.Name(), NotNullable)
intColumn.setTableName(s.alias) intColumn.setTableName(s.alias)
return intColumn return intColumn
} }
func (s *SelectStatementTable) RefStringColumn(column Column) *StringColumn { func (s *expressionTableImpl) RefStringColumn(column Column) *StringColumn {
strColumn := NewStringColumn(column.Name(), NotNullable) strColumn := NewStringColumn(column.Name(), NotNullable)
strColumn.setTableName(column.TableName()) strColumn.setTableName(column.TableName())
return strColumn return strColumn
} }
func (s *SelectStatementTable) SerializeSql(out *queryData) error { func (s *expressionTableImpl) SerializeSql(out *queryData) error {
out.WriteString("( ") out.WriteString("( ")
err := s.statement.Serialize(out) err := s.statement.Serialize(out)
@ -54,33 +62,33 @@ func (s *SelectStatementTable) SerializeSql(out *queryData) error {
} }
// Generates a select query on the current tableName. // Generates a select query on the current tableName.
func (s *SelectStatementTable) SELECT(projections ...Projection) SelectStatement { func (s *expressionTableImpl) SELECT(projections ...Projection) SelectStatement {
return newSelectStatement(s, projections) return newSelectStatement(s, projections)
} }
// Creates a inner join tableName expression using onCondition. // Creates a inner join tableName expression using onCondition.
func (s *SelectStatementTable) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { func (s *expressionTableImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return InnerJoinOn(s, table, onCondition) return InnerJoinOn(s, table, onCondition)
} }
//func (s *SelectStatementTable) InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable { //func (s *expressionTableImpl) InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable {
// return INNER_JOIN(s, table, col1.Eq(col2)) // return INNER_JOIN(s, table, col1.Eq(col2))
//} //}
// Creates a left join tableName expression using onCondition. // Creates a left join tableName expression using onCondition.
func (s *SelectStatementTable) LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable { func (s *expressionTableImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return LeftJoinOn(s, table, onCondition) return LeftJoinOn(s, table, onCondition)
} }
// Creates a right join tableName expression using onCondition. // Creates a right join tableName expression using onCondition.
func (s *SelectStatementTable) RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable { func (s *expressionTableImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return RightJoinOn(s, table, onCondition) return RightJoinOn(s, table, onCondition)
} }
func (s *SelectStatementTable) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { func (s *expressionTableImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return FullJoin(s, table, onCondition) return FullJoin(s, table, onCondition)
} }
func (s *SelectStatementTable) CrossJoin(table ReadableTable) ReadableTable { func (s *expressionTableImpl) CROSS_JOIN(table ReadableTable) ReadableTable {
return CrossJoin(s, table) return CrossJoin(s, table)
} }

View file

@ -14,10 +14,13 @@ const (
type SetStatement interface { type SetStatement interface {
Statement Statement
Expression
ORDER_BY(clauses ...OrderByClause) SetStatement ORDER_BY(clauses ...OrderByClause) SetStatement
LIMIT(limit int64) SetStatement LIMIT(limit int64) SetStatement
OFFSET(offset int64) SetStatement OFFSET(offset int64) SetStatement
AsTable(alias string) ExpressionTable
} }
func UNION(selects ...SelectStatement) SetStatement { func UNION(selects ...SelectStatement) SetStatement {
@ -46,6 +49,8 @@ func EXCEPT_ALL(selects ...SelectStatement) SetStatement {
// Similar to selectStatementImpl, but less complete // Similar to selectStatementImpl, but less complete
type setStatementImpl struct { type setStatementImpl struct {
expressionInterfaceImpl
operator string operator string
selects []SelectStatement selects []SelectStatement
orderBy []OrderByClause orderBy []OrderByClause
@ -54,14 +59,18 @@ type setStatementImpl struct {
all bool all bool
} }
func newSetStatementImpl(operator string, all bool, selects ...SelectStatement) *setStatementImpl { func newSetStatementImpl(operator string, all bool, selects ...SelectStatement) SetStatement {
return &setStatementImpl{ setStatement := &setStatementImpl{
operator: operator, operator: operator,
selects: selects, selects: selects,
limit: -1, limit: -1,
offset: -1, offset: -1,
all: all, all: all,
} }
setStatement.expressionInterfaceImpl.parent = setStatement
return setStatement
} }
func (us *setStatementImpl) ORDER_BY(orderBy ...OrderByClause) SetStatement { func (us *setStatementImpl) ORDER_BY(orderBy ...OrderByClause) SetStatement {
@ -80,7 +89,32 @@ func (us *setStatementImpl) OFFSET(offset int64) SetStatement {
return us return us
} }
func (us *setStatementImpl) AsTable(alias string) ExpressionTable {
return &expressionTableImpl{
statement: us,
alias: alias,
}
}
func (s *setStatementImpl) Serialize(out *queryData, options ...serializeOption) error { func (s *setStatementImpl) Serialize(out *queryData, options ...serializeOption) error {
if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 {
out.WriteString("(")
}
err := s.serializeImpl(out)
if err != nil {
return err
}
if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 {
out.WriteString(")")
}
return nil
}
func (s *setStatementImpl) serializeImpl(out *queryData, options ...serializeOption) error {
if len(s.selects) < 2 { if len(s.selects) < 2 {
return errors.Newf("UNION statement must have at least two SELECT statements.") return errors.Newf("UNION statement must have at least two SELECT statements.")
@ -131,7 +165,7 @@ func (s *setStatementImpl) Serialize(out *queryData, options ...serializeOption)
func (us *setStatementImpl) Sql() (query string, args []interface{}, err error) { func (us *setStatementImpl) Sql() (query string, args []interface{}, err error) {
queryData := &queryData{} queryData := &queryData{}
err = us.Serialize(queryData) err = us.serializeImpl(queryData)
if err != nil { if err != nil {
return return

View file

@ -29,14 +29,14 @@ type ReadableTable interface {
INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a left join tableName expression using onCondition. // Creates a left join tableName expression using onCondition.
LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a right join tableName expression using onCondition. // Creates a right join tableName expression using onCondition.
RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
CrossJoin(table ReadableTable) ReadableTable CROSS_JOIN(table ReadableTable) ReadableTable
} }
// The sql tableName write interface. // The sql tableName write interface.
@ -189,7 +189,7 @@ func (t *Table) INNER_JOIN(
//} //}
// Creates a left join tableName expression using onCondition. // Creates a left join tableName expression using onCondition.
func (t *Table) LeftJoinOn( func (t *Table) LEFT_JOIN(
table ReadableTable, table ReadableTable,
onCondition BoolExpression) ReadableTable { onCondition BoolExpression) ReadableTable {
@ -197,7 +197,7 @@ func (t *Table) LeftJoinOn(
} }
// Creates a right join tableName expression using onCondition. // Creates a right join tableName expression using onCondition.
func (t *Table) RightJoinOn( func (t *Table) RIGHT_JOIN(
table ReadableTable, table ReadableTable,
onCondition BoolExpression) ReadableTable { onCondition BoolExpression) ReadableTable {
@ -208,7 +208,7 @@ func (t *Table) FULL_JOIN(table ReadableTable, onCondition BoolExpression) Reada
return FullJoin(t, table, onCondition) return FullJoin(t, table, onCondition)
} }
func (t *Table) CrossJoin(table ReadableTable) ReadableTable { func (t *Table) CROSS_JOIN(table ReadableTable) ReadableTable {
return CrossJoin(t, table) return CrossJoin(t, table)
} }
@ -369,7 +369,7 @@ func (t *joinTable) INNER_JOIN(
return InnerJoinOn(t, table, onCondition) return InnerJoinOn(t, table, onCondition)
} }
func (t *joinTable) LeftJoinOn( func (t *joinTable) LEFT_JOIN(
table ReadableTable, table ReadableTable,
onCondition BoolExpression) ReadableTable { onCondition BoolExpression) ReadableTable {
@ -380,11 +380,11 @@ func (t *joinTable) FULL_JOIN(table ReadableTable, onCondition BoolExpression) R
return FullJoin(t, table, onCondition) return FullJoin(t, table, onCondition)
} }
func (t *joinTable) CrossJoin(table ReadableTable) ReadableTable { func (t *joinTable) CROSS_JOIN(table ReadableTable) ReadableTable {
return CrossJoin(t, table) return CrossJoin(t, table)
} }
func (t *joinTable) RightJoinOn( func (t *joinTable) RIGHT_JOIN(
table ReadableTable, table ReadableTable,
onCondition BoolExpression) ReadableTable { onCondition BoolExpression) ReadableTable {

View file

@ -112,7 +112,7 @@ func (s *TableSuite) TestInnerJoin(c *gc.C) {
} }
func (s *TableSuite) TestLeftJoin(c *gc.C) { func (s *TableSuite) TestLeftJoin(c *gc.C) {
join := table1.LeftJoinOn(table2, Eq(table1Col3, table2Col3)) join := table1.LEFT_JOIN(table2, Eq(table1Col3, table2Col3))
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
@ -128,7 +128,7 @@ func (s *TableSuite) TestLeftJoin(c *gc.C) {
} }
func (s *TableSuite) TestRightJoin(c *gc.C) { func (s *TableSuite) TestRightJoin(c *gc.C) {
join := table1.RightJoinOn(table2, Eq(table1Col3, table2Col3)) join := table1.RIGHT_JOIN(table2, Eq(table1Col3, table2Col3))
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
@ -144,7 +144,7 @@ func (s *TableSuite) TestRightJoin(c *gc.C) {
} }
func (s *TableSuite) TestJoinColumns(c *gc.C) { func (s *TableSuite) TestJoinColumns(c *gc.C) {
join := table1.RightJoinOn(table2, Eq(table1Col3, table2Col3)) join := table1.RIGHT_JOIN(table2, Eq(table1Col3, table2Col3))
cols := join.Columns() cols := join.Columns()
c.Assert(len(cols), gc.Equals, 6) c.Assert(len(cols), gc.Equals, 6)

View file

@ -36,6 +36,25 @@ func TestSelect_ScanToStruct(t *testing.T) {
assert.DeepEqual(t, actor, expectedActor) assert.DeepEqual(t, actor, expectedActor)
} }
func TestClassicSelect(t *testing.T) {
query := sqlbuilder.SELECT(Payment.AllColumns, Customer.AllColumns).
FROM(Payment.INNER_JOIN(Customer, Payment.CustomerID.Eq(Customer.CustomerID))).
ORDER_BY(Payment.PaymentID.Asc()).
LIMIT(30)
queryStr, args, err := query.Sql()
assert.NilError(t, err)
fmt.Println(queryStr)
fmt.Println(args)
dest := []model.Payment{}
err = query.Query(db, &dest)
assert.NilError(t, err)
}
func TestSelect_ScanToSlice(t *testing.T) { func TestSelect_ScanToSlice(t *testing.T) {
customers := []model.Customer{} customers := []model.Customer{}
@ -58,6 +77,23 @@ func TestSelect_ScanToSlice(t *testing.T) {
assert.DeepEqual(t, lastCustomer, customers[598]) assert.DeepEqual(t, lastCustomer, customers[598])
} }
func TestSelectAndUnionInProjection(t *testing.T) {
query := Payment.
SELECT(
Payment.PaymentID,
Customer.SELECT(Customer.CustomerID).LIMIT(1),
sqlbuilder.UNION(Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10), Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2)).LIMIT(1),
).
LIMIT(12)
queryStr, args, err := query.Sql()
assert.NilError(t, err)
fmt.Println(queryStr)
fmt.Println(args)
}
//func TestJoinQueryStruct(t *testing.T) { //func TestJoinQueryStruct(t *testing.T) {
// //
// query := FilmActor. // query := FilmActor.
@ -253,7 +289,7 @@ func TestSelectFullJoin(t *testing.T) {
func TestSelectFullCrossJoin(t *testing.T) { func TestSelectFullCrossJoin(t *testing.T) {
query := Customer. query := Customer.
CrossJoin(Address). CROSS_JOIN(Address).
SELECT(Customer.AllColumns, Address.AllColumns). SELECT(Customer.AllColumns, Address.AllColumns).
ORDER_BY(Customer.CustomerID.Asc()). ORDER_BY(Customer.CustomerID.Asc()).
LIMIT(1000) LIMIT(1000)