From 5de001d7e019d278172128a7b4a54202dca937fc Mon Sep 17 00:00:00 2001 From: zer0sub Date: Sun, 5 May 2019 12:37:23 +0200 Subject: [PATCH] Allow set statements to be used as tables and expressions. --- sqlbuilder/example_test.go | 2 +- sqlbuilder/execution/execution.go | 5 ++++ sqlbuilder/select_statement.go | 23 ++++++++++++---- sqlbuilder/select_statement_table.go | 40 +++++++++++++++++----------- sqlbuilder/set_statement.go | 40 +++++++++++++++++++++++++--- sqlbuilder/table.go | 18 ++++++------- sqlbuilder/table_test.go | 6 ++--- tests/select_test.go | 38 +++++++++++++++++++++++++- 8 files changed, 134 insertions(+), 38 deletions(-) diff --git a/sqlbuilder/example_test.go b/sqlbuilder/example_test.go index 11475a9..fc61909 100644 --- a/sqlbuilder/example_test.go +++ b/sqlbuilder/example_test.go @@ -32,7 +32,7 @@ func Example() { filename := t2.C("filename") in := []int32{1, 2, 3} - join := t2.LeftJoinOn(t1, Eq(ns_id1, ns_id2)) + join := t2.LEFT_JOIN(t1, Eq(ns_id1, ns_id2)) q := join.Select(ns_id2, sjid, prefix, filename).Where( And(EqL(ns_id2, 456), In(sjid, in))) text, _ := q.String() diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index c03fb31..e9286a8 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -72,6 +72,11 @@ func Query(db types.Db, query string, args []interface{}, destinationPtr interfa return err } + err = rows.Close() + if err != nil { + return err + } + fmt.Println(strconv.Itoa(scanContext.rowNum) + " ROW(S) PROCESSED") return nil diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 958db90..3cf3c6f 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -11,6 +11,7 @@ type SelectStatement interface { Expression DISTINCT() SelectStatement + FROM(table ReadableTable) SelectStatement WHERE(expression BoolExpression) SelectStatement GROUP_BY(expressions ...Clause) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement @@ -21,7 +22,11 @@ type SelectStatement interface { FOR_UPDATE() SelectStatement - AsTable(alias string) *SelectStatementTable + AsTable(alias string) ExpressionTable +} + +var SELECT = func(projection ...Projection) SelectStatement { + return newSelectStatement(nil, projection) } // NOTE: SelectStatement purposely does not implement the Table interface since @@ -59,8 +64,7 @@ func defaultProjectionAliasing(projections []Projection) []Projection { } func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { - - return &selectStatementImpl{ + newSelect := &selectStatementImpl{ table: table, projections: defaultProjectionAliasing(projections), limit: -1, @@ -68,6 +72,15 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta forUpdate: false, distinct: false, } + + newSelect.expressionInterfaceImpl.parent = newSelect + + return newSelect +} + +func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement { + s.table = table + return s } func (s *selectStatementImpl) Serialize(out *queryData, options ...serializeOption) error { @@ -176,8 +189,8 @@ func (q *selectStatementImpl) Sql() (query string, args []interface{}, err error return queryData.buff.String(), queryData.args, nil } -func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable { - return &SelectStatementTable{ +func (s *selectStatementImpl) AsTable(alias string) ExpressionTable { + return &expressionTableImpl{ statement: s, alias: alias, } diff --git a/sqlbuilder/select_statement_table.go b/sqlbuilder/select_statement_table.go index 5efef4c..11a33cc 100644 --- a/sqlbuilder/select_statement_table.go +++ b/sqlbuilder/select_statement_table.go @@ -1,45 +1,53 @@ package sqlbuilder -type SelectStatementTable struct { - statement SelectStatement +type ExpressionTable interface { + ReadableTable + + RefIntColumnName(name string) *IntegerColumn + RefIntColumn(column Column) *IntegerColumn + RefStringColumn(column Column) *StringColumn +} + +type expressionTableImpl struct { + statement Expression columns []Column alias string } // Returns the tableName's name in the database -func (t *SelectStatementTable) SchemaName() string { +func (t *expressionTableImpl) SchemaName() string { return "" } -func (s *SelectStatementTable) TableName() string { +func (s *expressionTableImpl) TableName() string { return s.alias } -func (s *SelectStatementTable) Columns() []Column { +func (s *expressionTableImpl) Columns() []Column { return s.columns } -func (s *SelectStatementTable) RefIntColumnName(name string) Column { +func (s *expressionTableImpl) RefIntColumnName(name string) *IntegerColumn { intColumn := NewIntegerColumn(name, NotNullable) intColumn.setTableName(s.alias) return intColumn } -func (s *SelectStatementTable) RefIntColumn(column Column) *IntegerColumn { +func (s *expressionTableImpl) RefIntColumn(column Column) *IntegerColumn { intColumn := NewIntegerColumn(column.TableName()+"."+column.Name(), NotNullable) intColumn.setTableName(s.alias) return intColumn } -func (s *SelectStatementTable) RefStringColumn(column Column) *StringColumn { +func (s *expressionTableImpl) RefStringColumn(column Column) *StringColumn { strColumn := NewStringColumn(column.Name(), NotNullable) strColumn.setTableName(column.TableName()) return strColumn } -func (s *SelectStatementTable) SerializeSql(out *queryData) error { +func (s *expressionTableImpl) SerializeSql(out *queryData) error { out.WriteString("( ") err := s.statement.Serialize(out) @@ -54,33 +62,33 @@ func (s *SelectStatementTable) SerializeSql(out *queryData) error { } // Generates a select query on the current tableName. -func (s *SelectStatementTable) SELECT(projections ...Projection) SelectStatement { +func (s *expressionTableImpl) SELECT(projections ...Projection) SelectStatement { return newSelectStatement(s, projections) } // Creates a inner join tableName expression using onCondition. -func (s *SelectStatementTable) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { +func (s *expressionTableImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { return InnerJoinOn(s, table, onCondition) } -//func (s *SelectStatementTable) InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable { +//func (s *expressionTableImpl) InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable { // return INNER_JOIN(s, table, col1.Eq(col2)) //} // Creates a left join tableName expression using onCondition. -func (s *SelectStatementTable) LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable { +func (s *expressionTableImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { return LeftJoinOn(s, table, onCondition) } // Creates a right join tableName expression using onCondition. -func (s *SelectStatementTable) RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable { +func (s *expressionTableImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { return RightJoinOn(s, table, onCondition) } -func (s *SelectStatementTable) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { +func (s *expressionTableImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable { return FullJoin(s, table, onCondition) } -func (s *SelectStatementTable) CrossJoin(table ReadableTable) ReadableTable { +func (s *expressionTableImpl) CROSS_JOIN(table ReadableTable) ReadableTable { return CrossJoin(s, table) } diff --git a/sqlbuilder/set_statement.go b/sqlbuilder/set_statement.go index 0501b81..3034593 100644 --- a/sqlbuilder/set_statement.go +++ b/sqlbuilder/set_statement.go @@ -14,10 +14,13 @@ const ( type SetStatement interface { Statement + Expression ORDER_BY(clauses ...OrderByClause) SetStatement LIMIT(limit int64) SetStatement OFFSET(offset int64) SetStatement + + AsTable(alias string) ExpressionTable } func UNION(selects ...SelectStatement) SetStatement { @@ -46,6 +49,8 @@ func EXCEPT_ALL(selects ...SelectStatement) SetStatement { // Similar to selectStatementImpl, but less complete type setStatementImpl struct { + expressionInterfaceImpl + operator string selects []SelectStatement orderBy []OrderByClause @@ -54,14 +59,18 @@ type setStatementImpl struct { all bool } -func newSetStatementImpl(operator string, all bool, selects ...SelectStatement) *setStatementImpl { - return &setStatementImpl{ +func newSetStatementImpl(operator string, all bool, selects ...SelectStatement) SetStatement { + setStatement := &setStatementImpl{ operator: operator, selects: selects, limit: -1, offset: -1, all: all, } + + setStatement.expressionInterfaceImpl.parent = setStatement + + return setStatement } func (us *setStatementImpl) ORDER_BY(orderBy ...OrderByClause) SetStatement { @@ -80,7 +89,32 @@ func (us *setStatementImpl) OFFSET(offset int64) SetStatement { return us } +func (us *setStatementImpl) AsTable(alias string) ExpressionTable { + return &expressionTableImpl{ + statement: us, + alias: alias, + } +} + func (s *setStatementImpl) Serialize(out *queryData, options ...serializeOption) error { + if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 { + out.WriteString("(") + } + + err := s.serializeImpl(out) + + if err != nil { + return err + } + + if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 { + out.WriteString(")") + } + + return nil +} + +func (s *setStatementImpl) serializeImpl(out *queryData, options ...serializeOption) error { if len(s.selects) < 2 { return errors.Newf("UNION statement must have at least two SELECT statements.") @@ -131,7 +165,7 @@ func (s *setStatementImpl) Serialize(out *queryData, options ...serializeOption) func (us *setStatementImpl) Sql() (query string, args []interface{}, err error) { queryData := &queryData{} - err = us.Serialize(queryData) + err = us.serializeImpl(queryData) if err != nil { return diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index 33fa219..567fe4e 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -29,14 +29,14 @@ type ReadableTable interface { INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable // Creates a left join tableName expression using onCondition. - LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable + LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable // Creates a right join tableName expression using onCondition. - RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable + RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable - CrossJoin(table ReadableTable) ReadableTable + CROSS_JOIN(table ReadableTable) ReadableTable } // The sql tableName write interface. @@ -189,7 +189,7 @@ func (t *Table) INNER_JOIN( //} // Creates a left join tableName expression using onCondition. -func (t *Table) LeftJoinOn( +func (t *Table) LEFT_JOIN( table ReadableTable, onCondition BoolExpression) ReadableTable { @@ -197,7 +197,7 @@ func (t *Table) LeftJoinOn( } // Creates a right join tableName expression using onCondition. -func (t *Table) RightJoinOn( +func (t *Table) RIGHT_JOIN( table ReadableTable, onCondition BoolExpression) ReadableTable { @@ -208,7 +208,7 @@ func (t *Table) FULL_JOIN(table ReadableTable, onCondition BoolExpression) Reada return FullJoin(t, table, onCondition) } -func (t *Table) CrossJoin(table ReadableTable) ReadableTable { +func (t *Table) CROSS_JOIN(table ReadableTable) ReadableTable { return CrossJoin(t, table) } @@ -369,7 +369,7 @@ func (t *joinTable) INNER_JOIN( return InnerJoinOn(t, table, onCondition) } -func (t *joinTable) LeftJoinOn( +func (t *joinTable) LEFT_JOIN( table ReadableTable, onCondition BoolExpression) ReadableTable { @@ -380,11 +380,11 @@ func (t *joinTable) FULL_JOIN(table ReadableTable, onCondition BoolExpression) R return FullJoin(t, table, onCondition) } -func (t *joinTable) CrossJoin(table ReadableTable) ReadableTable { +func (t *joinTable) CROSS_JOIN(table ReadableTable) ReadableTable { return CrossJoin(t, table) } -func (t *joinTable) RightJoinOn( +func (t *joinTable) RIGHT_JOIN( table ReadableTable, onCondition BoolExpression) ReadableTable { diff --git a/sqlbuilder/table_test.go b/sqlbuilder/table_test.go index e280794..5c46ea0 100644 --- a/sqlbuilder/table_test.go +++ b/sqlbuilder/table_test.go @@ -112,7 +112,7 @@ func (s *TableSuite) TestInnerJoin(c *gc.C) { } func (s *TableSuite) TestLeftJoin(c *gc.C) { - join := table1.LeftJoinOn(table2, Eq(table1Col3, table2Col3)) + join := table1.LEFT_JOIN(table2, Eq(table1Col3, table2Col3)) buf := &bytes.Buffer{} @@ -128,7 +128,7 @@ func (s *TableSuite) TestLeftJoin(c *gc.C) { } func (s *TableSuite) TestRightJoin(c *gc.C) { - join := table1.RightJoinOn(table2, Eq(table1Col3, table2Col3)) + join := table1.RIGHT_JOIN(table2, Eq(table1Col3, table2Col3)) buf := &bytes.Buffer{} @@ -144,7 +144,7 @@ func (s *TableSuite) TestRightJoin(c *gc.C) { } func (s *TableSuite) TestJoinColumns(c *gc.C) { - join := table1.RightJoinOn(table2, Eq(table1Col3, table2Col3)) + join := table1.RIGHT_JOIN(table2, Eq(table1Col3, table2Col3)) cols := join.Columns() c.Assert(len(cols), gc.Equals, 6) diff --git a/tests/select_test.go b/tests/select_test.go index 4e44957..f2455cd 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -36,6 +36,25 @@ func TestSelect_ScanToStruct(t *testing.T) { assert.DeepEqual(t, actor, expectedActor) } +func TestClassicSelect(t *testing.T) { + query := sqlbuilder.SELECT(Payment.AllColumns, Customer.AllColumns). + FROM(Payment.INNER_JOIN(Customer, Payment.CustomerID.Eq(Customer.CustomerID))). + ORDER_BY(Payment.PaymentID.Asc()). + LIMIT(30) + + queryStr, args, err := query.Sql() + + assert.NilError(t, err) + fmt.Println(queryStr) + fmt.Println(args) + + dest := []model.Payment{} + + err = query.Query(db, &dest) + + assert.NilError(t, err) +} + func TestSelect_ScanToSlice(t *testing.T) { customers := []model.Customer{} @@ -58,6 +77,23 @@ func TestSelect_ScanToSlice(t *testing.T) { assert.DeepEqual(t, lastCustomer, customers[598]) } +func TestSelectAndUnionInProjection(t *testing.T) { + + query := Payment. + SELECT( + Payment.PaymentID, + Customer.SELECT(Customer.CustomerID).LIMIT(1), + sqlbuilder.UNION(Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10), Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2)).LIMIT(1), + ). + LIMIT(12) + + queryStr, args, err := query.Sql() + + assert.NilError(t, err) + fmt.Println(queryStr) + fmt.Println(args) +} + //func TestJoinQueryStruct(t *testing.T) { // // query := FilmActor. @@ -253,7 +289,7 @@ func TestSelectFullJoin(t *testing.T) { func TestSelectFullCrossJoin(t *testing.T) { query := Customer. - CrossJoin(Address). + CROSS_JOIN(Address). SELECT(Customer.AllColumns, Address.AllColumns). ORDER_BY(Customer.CustomerID.Asc()). LIMIT(1000)