Allow set statements to be used as tables and expressions.

This commit is contained in:
zer0sub 2019-05-05 12:37:23 +02:00
parent 5ad213885f
commit 5de001d7e0
8 changed files with 134 additions and 38 deletions

View file

@ -32,7 +32,7 @@ func Example() {
filename := t2.C("filename")
in := []int32{1, 2, 3}
join := t2.LeftJoinOn(t1, Eq(ns_id1, ns_id2))
join := t2.LEFT_JOIN(t1, Eq(ns_id1, ns_id2))
q := join.Select(ns_id2, sjid, prefix, filename).Where(
And(EqL(ns_id2, 456), In(sjid, in)))
text, _ := q.String()

View file

@ -72,6 +72,11 @@ func Query(db types.Db, query string, args []interface{}, destinationPtr interfa
return err
}
err = rows.Close()
if err != nil {
return err
}
fmt.Println(strconv.Itoa(scanContext.rowNum) + " ROW(S) PROCESSED")
return nil

View file

@ -11,6 +11,7 @@ type SelectStatement interface {
Expression
DISTINCT() SelectStatement
FROM(table ReadableTable) SelectStatement
WHERE(expression BoolExpression) SelectStatement
GROUP_BY(expressions ...Clause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement
@ -21,7 +22,11 @@ type SelectStatement interface {
FOR_UPDATE() SelectStatement
AsTable(alias string) *SelectStatementTable
AsTable(alias string) ExpressionTable
}
var SELECT = func(projection ...Projection) SelectStatement {
return newSelectStatement(nil, projection)
}
// NOTE: SelectStatement purposely does not implement the Table interface since
@ -59,8 +64,7 @@ func defaultProjectionAliasing(projections []Projection) []Projection {
}
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement {
return &selectStatementImpl{
newSelect := &selectStatementImpl{
table: table,
projections: defaultProjectionAliasing(projections),
limit: -1,
@ -68,6 +72,15 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta
forUpdate: false,
distinct: false,
}
newSelect.expressionInterfaceImpl.parent = newSelect
return newSelect
}
func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement {
s.table = table
return s
}
func (s *selectStatementImpl) Serialize(out *queryData, options ...serializeOption) error {
@ -176,8 +189,8 @@ func (q *selectStatementImpl) Sql() (query string, args []interface{}, err error
return queryData.buff.String(), queryData.args, nil
}
func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable {
return &SelectStatementTable{
func (s *selectStatementImpl) AsTable(alias string) ExpressionTable {
return &expressionTableImpl{
statement: s,
alias: alias,
}

View file

@ -1,45 +1,53 @@
package sqlbuilder
type SelectStatementTable struct {
statement SelectStatement
type ExpressionTable interface {
ReadableTable
RefIntColumnName(name string) *IntegerColumn
RefIntColumn(column Column) *IntegerColumn
RefStringColumn(column Column) *StringColumn
}
type expressionTableImpl struct {
statement Expression
columns []Column
alias string
}
// Returns the tableName's name in the database
func (t *SelectStatementTable) SchemaName() string {
func (t *expressionTableImpl) SchemaName() string {
return ""
}
func (s *SelectStatementTable) TableName() string {
func (s *expressionTableImpl) TableName() string {
return s.alias
}
func (s *SelectStatementTable) Columns() []Column {
func (s *expressionTableImpl) Columns() []Column {
return s.columns
}
func (s *SelectStatementTable) RefIntColumnName(name string) Column {
func (s *expressionTableImpl) RefIntColumnName(name string) *IntegerColumn {
intColumn := NewIntegerColumn(name, NotNullable)
intColumn.setTableName(s.alias)
return intColumn
}
func (s *SelectStatementTable) RefIntColumn(column Column) *IntegerColumn {
func (s *expressionTableImpl) RefIntColumn(column Column) *IntegerColumn {
intColumn := NewIntegerColumn(column.TableName()+"."+column.Name(), NotNullable)
intColumn.setTableName(s.alias)
return intColumn
}
func (s *SelectStatementTable) RefStringColumn(column Column) *StringColumn {
func (s *expressionTableImpl) RefStringColumn(column Column) *StringColumn {
strColumn := NewStringColumn(column.Name(), NotNullable)
strColumn.setTableName(column.TableName())
return strColumn
}
func (s *SelectStatementTable) SerializeSql(out *queryData) error {
func (s *expressionTableImpl) SerializeSql(out *queryData) error {
out.WriteString("( ")
err := s.statement.Serialize(out)
@ -54,33 +62,33 @@ func (s *SelectStatementTable) SerializeSql(out *queryData) error {
}
// Generates a select query on the current tableName.
func (s *SelectStatementTable) SELECT(projections ...Projection) SelectStatement {
func (s *expressionTableImpl) SELECT(projections ...Projection) SelectStatement {
return newSelectStatement(s, projections)
}
// Creates a inner join tableName expression using onCondition.
func (s *SelectStatementTable) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
func (s *expressionTableImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return InnerJoinOn(s, table, onCondition)
}
//func (s *SelectStatementTable) InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable {
//func (s *expressionTableImpl) InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable {
// return INNER_JOIN(s, table, col1.Eq(col2))
//}
// Creates a left join tableName expression using onCondition.
func (s *SelectStatementTable) LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable {
func (s *expressionTableImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return LeftJoinOn(s, table, onCondition)
}
// Creates a right join tableName expression using onCondition.
func (s *SelectStatementTable) RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable {
func (s *expressionTableImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return RightJoinOn(s, table, onCondition)
}
func (s *SelectStatementTable) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
func (s *expressionTableImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return FullJoin(s, table, onCondition)
}
func (s *SelectStatementTable) CrossJoin(table ReadableTable) ReadableTable {
func (s *expressionTableImpl) CROSS_JOIN(table ReadableTable) ReadableTable {
return CrossJoin(s, table)
}

View file

@ -14,10 +14,13 @@ const (
type SetStatement interface {
Statement
Expression
ORDER_BY(clauses ...OrderByClause) SetStatement
LIMIT(limit int64) SetStatement
OFFSET(offset int64) SetStatement
AsTable(alias string) ExpressionTable
}
func UNION(selects ...SelectStatement) SetStatement {
@ -46,6 +49,8 @@ func EXCEPT_ALL(selects ...SelectStatement) SetStatement {
// Similar to selectStatementImpl, but less complete
type setStatementImpl struct {
expressionInterfaceImpl
operator string
selects []SelectStatement
orderBy []OrderByClause
@ -54,14 +59,18 @@ type setStatementImpl struct {
all bool
}
func newSetStatementImpl(operator string, all bool, selects ...SelectStatement) *setStatementImpl {
return &setStatementImpl{
func newSetStatementImpl(operator string, all bool, selects ...SelectStatement) SetStatement {
setStatement := &setStatementImpl{
operator: operator,
selects: selects,
limit: -1,
offset: -1,
all: all,
}
setStatement.expressionInterfaceImpl.parent = setStatement
return setStatement
}
func (us *setStatementImpl) ORDER_BY(orderBy ...OrderByClause) SetStatement {
@ -80,7 +89,32 @@ func (us *setStatementImpl) OFFSET(offset int64) SetStatement {
return us
}
func (us *setStatementImpl) AsTable(alias string) ExpressionTable {
return &expressionTableImpl{
statement: us,
alias: alias,
}
}
func (s *setStatementImpl) Serialize(out *queryData, options ...serializeOption) error {
if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 {
out.WriteString("(")
}
err := s.serializeImpl(out)
if err != nil {
return err
}
if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 {
out.WriteString(")")
}
return nil
}
func (s *setStatementImpl) serializeImpl(out *queryData, options ...serializeOption) error {
if len(s.selects) < 2 {
return errors.Newf("UNION statement must have at least two SELECT statements.")
@ -131,7 +165,7 @@ func (s *setStatementImpl) Serialize(out *queryData, options ...serializeOption)
func (us *setStatementImpl) Sql() (query string, args []interface{}, err error) {
queryData := &queryData{}
err = us.Serialize(queryData)
err = us.serializeImpl(queryData)
if err != nil {
return

View file

@ -29,14 +29,14 @@ type ReadableTable interface {
INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a left join tableName expression using onCondition.
LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable
LEFT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a right join tableName expression using onCondition.
RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable
RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
CrossJoin(table ReadableTable) ReadableTable
CROSS_JOIN(table ReadableTable) ReadableTable
}
// The sql tableName write interface.
@ -189,7 +189,7 @@ func (t *Table) INNER_JOIN(
//}
// Creates a left join tableName expression using onCondition.
func (t *Table) LeftJoinOn(
func (t *Table) LEFT_JOIN(
table ReadableTable,
onCondition BoolExpression) ReadableTable {
@ -197,7 +197,7 @@ func (t *Table) LeftJoinOn(
}
// Creates a right join tableName expression using onCondition.
func (t *Table) RightJoinOn(
func (t *Table) RIGHT_JOIN(
table ReadableTable,
onCondition BoolExpression) ReadableTable {
@ -208,7 +208,7 @@ func (t *Table) FULL_JOIN(table ReadableTable, onCondition BoolExpression) Reada
return FullJoin(t, table, onCondition)
}
func (t *Table) CrossJoin(table ReadableTable) ReadableTable {
func (t *Table) CROSS_JOIN(table ReadableTable) ReadableTable {
return CrossJoin(t, table)
}
@ -369,7 +369,7 @@ func (t *joinTable) INNER_JOIN(
return InnerJoinOn(t, table, onCondition)
}
func (t *joinTable) LeftJoinOn(
func (t *joinTable) LEFT_JOIN(
table ReadableTable,
onCondition BoolExpression) ReadableTable {
@ -380,11 +380,11 @@ func (t *joinTable) FULL_JOIN(table ReadableTable, onCondition BoolExpression) R
return FullJoin(t, table, onCondition)
}
func (t *joinTable) CrossJoin(table ReadableTable) ReadableTable {
func (t *joinTable) CROSS_JOIN(table ReadableTable) ReadableTable {
return CrossJoin(t, table)
}
func (t *joinTable) RightJoinOn(
func (t *joinTable) RIGHT_JOIN(
table ReadableTable,
onCondition BoolExpression) ReadableTable {

View file

@ -112,7 +112,7 @@ func (s *TableSuite) TestInnerJoin(c *gc.C) {
}
func (s *TableSuite) TestLeftJoin(c *gc.C) {
join := table1.LeftJoinOn(table2, Eq(table1Col3, table2Col3))
join := table1.LEFT_JOIN(table2, Eq(table1Col3, table2Col3))
buf := &bytes.Buffer{}
@ -128,7 +128,7 @@ func (s *TableSuite) TestLeftJoin(c *gc.C) {
}
func (s *TableSuite) TestRightJoin(c *gc.C) {
join := table1.RightJoinOn(table2, Eq(table1Col3, table2Col3))
join := table1.RIGHT_JOIN(table2, Eq(table1Col3, table2Col3))
buf := &bytes.Buffer{}
@ -144,7 +144,7 @@ func (s *TableSuite) TestRightJoin(c *gc.C) {
}
func (s *TableSuite) TestJoinColumns(c *gc.C) {
join := table1.RightJoinOn(table2, Eq(table1Col3, table2Col3))
join := table1.RIGHT_JOIN(table2, Eq(table1Col3, table2Col3))
cols := join.Columns()
c.Assert(len(cols), gc.Equals, 6)