diff --git a/row_type.go b/row_type.go deleted file mode 100644 index 2777373..0000000 --- a/row_type.go +++ /dev/null @@ -1,6 +0,0 @@ -package jet - -type rowsType interface { - clause - projections() []projection -} diff --git a/select_statement.go b/select_statement.go index 2a6670b..2174a7b 100644 --- a/select_statement.go +++ b/select_statement.go @@ -28,6 +28,13 @@ type SelectStatement interface { OFFSET(offset int64) SelectStatement FOR(lock SelectLock) SelectStatement + UNION(rhs SelectStatement) SelectStatement + UNION_ALL(rhs SelectStatement) SelectStatement + INTERSECT(rhs SelectStatement) SelectStatement + INTERSECT_ALL(rhs SelectStatement) SelectStatement + EXCEPT(rhs SelectStatement) SelectStatement + EXCEPT_ALL(rhs SelectStatement) SelectStatement + AsTable(alias string) ExpressionTable projections() []projection @@ -39,6 +46,7 @@ func SELECT(projection1 projection, projections ...projection) SelectStatement { type selectStatementImpl struct { expressionInterfaceImpl + parent SelectStatement table ReadableTable distinct bool @@ -46,8 +54,8 @@ type selectStatementImpl struct { where BoolExpression groupBy []groupByClause having BoolExpression - orderBy []OrderByClause + orderBy []OrderByClause limit, offset int64 lockFor SelectLock @@ -63,13 +71,86 @@ func newSelectStatement(table ReadableTable, projections []projection) SelectSta } newSelect.expressionInterfaceImpl.parent = newSelect + newSelect.parent = newSelect return newSelect } func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement { s.table = table - return s + return s.parent +} + +func (s *selectStatementImpl) AsTable(alias string) ExpressionTable { + return newExpressionTable(s.parent, alias, s.parent.projections()) +} + +func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement { + s.where = expression + return s.parent +} + +func (s *selectStatementImpl) GROUP_BY(groupByClauses ...groupByClause) SelectStatement { + s.groupBy = groupByClauses + return s.parent +} + +func (s *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement { + s.having = expression + return s.parent +} + +func (s *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement { + s.orderBy = clauses + return s.parent +} + +func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement { + s.offset = offset + return s.parent +} + +func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement { + s.limit = limit + return s.parent +} + +func (s *selectStatementImpl) DISTINCT() SelectStatement { + s.distinct = true + return s.parent +} + +func (s *selectStatementImpl) FOR(lock SelectLock) SelectStatement { + s.lockFor = lock + return s.parent +} + +func (s *selectStatementImpl) UNION(rhs SelectStatement) SelectStatement { + return UNION(s.parent, rhs) +} + +func (s *selectStatementImpl) UNION_ALL(rhs SelectStatement) SelectStatement { + return UNION_ALL(s.parent, rhs) +} + +func (s *selectStatementImpl) INTERSECT(rhs SelectStatement) SelectStatement { + return INTERSECT(s.parent, rhs) +} + +func (s *selectStatementImpl) INTERSECT_ALL(rhs SelectStatement) SelectStatement { + return INTERSECT_ALL(s.parent, rhs) +} + +func (s *selectStatementImpl) EXCEPT(rhs SelectStatement) SelectStatement { + return EXCEPT(s.parent, rhs) +} + +func (s *selectStatementImpl) EXCEPT_ALL(rhs SelectStatement) SelectStatement { + return EXCEPT_ALL(s.parent, rhs) +} + +func (s *selectStatementImpl) projections() []projection { + return s.projectionList } func (s *selectStatementImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { @@ -192,56 +273,26 @@ func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error } func (s *selectStatementImpl) DebugSql() (query string, err error) { - return debugSql(s) + return debugSql(s.parent) } -func (s *selectStatementImpl) projections() []projection { - return s.projectionList +func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) error { + return query(s.parent, db, destination) } -func (s *selectStatementImpl) AsTable(alias string) ExpressionTable { - return newExpressionTable(s.parent, alias, s.projectionList) +func (s *selectStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error { + return queryContext(s.parent, db, context, destination) } -func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement { - s.where = expression - return s +func (s *selectStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { + return exec(s.parent, db) } -func (s *selectStatementImpl) GROUP_BY(groupByClauses ...groupByClause) SelectStatement { - s.groupBy = groupByClauses - return s +func (s *selectStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) { + return execContext(s.parent, db, context) } -func (s *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement { - s.having = expression - return s -} - -func (s *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement { - s.orderBy = clauses - return s -} - -func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement { - s.offset = offset - return s -} - -func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement { - s.limit = limit - return s -} - -func (s *selectStatementImpl) DISTINCT() SelectStatement { - s.distinct = true - return s -} - -func (s *selectStatementImpl) FOR(lock SelectLock) SelectStatement { - s.lockFor = lock - return s -} +// SelectLock type SelectLock interface { clause @@ -288,19 +339,3 @@ func (s *selectLockImpl) serialize(statement statementType, out *queryData, opti return nil } - -func (s *selectStatementImpl) Query(db execution.DB, destination interface{}) error { - return query(s, db, destination) -} - -func (s *selectStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error { - return queryContext(s, db, context, destination) -} - -func (s *selectStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { - return exec(s, db) -} - -func (s *selectStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) { - return execContext(s, db, context) -} diff --git a/set_statement.go b/set_statement.go index 8efca36..e28f0ec 100644 --- a/set_statement.go +++ b/set_statement.go @@ -1,23 +1,35 @@ package jet import ( - "context" - "database/sql" "errors" - "github.com/go-jet/jet/execution" ) -type SetStatement interface { - Statement - Expression +func UNION(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { + return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...)) +} - ORDER_BY(clauses ...OrderByClause) SetStatement - LIMIT(limit int64) SetStatement - OFFSET(offset int64) SetStatement +func UNION_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { + return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...)) +} - AsTable(alias string) ExpressionTable +func INTERSECT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { + return newSetStatementImpl(intersect, false, toSelectList(lhs, rhs, selects...)) +} - projections() []projection +func INTERSECT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { + return newSetStatementImpl(intersect, true, toSelectList(lhs, rhs, selects...)) +} + +func EXCEPT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { + return newSetStatementImpl(except, false, toSelectList(lhs, rhs, selects...)) +} + +func EXCEPT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { + return newSetStatementImpl(except, true, toSelectList(lhs, rhs, selects...)) +} + +func toSelectList(lhs, rhs SelectStatement, selects ...SelectStatement) []SelectStatement { + return append([]SelectStatement{lhs, rhs}, selects...) } const ( @@ -26,71 +38,30 @@ const ( except = "EXCEPT" ) -func UNION(selects ...rowsType) SetStatement { - return newSetStatementImpl(union, false, selects...) -} - -func UNION_ALL(selects ...rowsType) SetStatement { - return newSetStatementImpl(union, true, selects...) -} - -func INTERSECT(selects ...rowsType) SetStatement { - return newSetStatementImpl(intersect, false, selects...) -} - -func INTERSECT_ALL(selects ...rowsType) SetStatement { - return newSetStatementImpl(intersect, true, selects...) -} - -func EXCEPT(selects ...rowsType) SetStatement { - return newSetStatementImpl(except, false, selects...) -} - -func EXCEPT_ALL(selects ...rowsType) SetStatement { - return newSetStatementImpl(except, true, selects...) -} - // Similar to selectStatementImpl, but less complete type setStatementImpl struct { - expressionInterfaceImpl + selectStatementImpl - operator string - selects []rowsType - orderBy []OrderByClause - limit, offset int64 - - all bool + operator string + all bool + selects []SelectStatement } -func newSetStatementImpl(operator string, all bool, selects ...rowsType) SetStatement { +func newSetStatementImpl(operator string, all bool, selects []SelectStatement) SelectStatement { setStatement := &setStatementImpl{ operator: operator, - selects: selects, - limit: -1, - offset: -1, all: all, + selects: selects, } - setStatement.expressionInterfaceImpl.parent = setStatement + setStatement.selectStatementImpl.expressionInterfaceImpl.parent = setStatement + setStatement.selectStatementImpl.parent = setStatement + setStatement.limit = -1 + setStatement.offset = -1 return setStatement } -func (s *setStatementImpl) ORDER_BY(orderBy ...OrderByClause) SetStatement { - s.orderBy = orderBy - return s -} - -func (s *setStatementImpl) LIMIT(limit int64) SetStatement { - s.limit = limit - return s -} - -func (s *setStatementImpl) OFFSET(offset int64) SetStatement { - s.offset = offset - return s -} - func (s *setStatementImpl) projections() []projection { if len(s.selects) > 0 { return s.selects[0].projections() @@ -98,10 +69,6 @@ func (s *setStatementImpl) projections() []projection { return []projection{} } -func (s *setStatementImpl) AsTable(alias string) ExpressionTable { - return newExpressionTable(s.parent, alias, s.projections()) -} - func (s *setStatementImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { if s == nil { return errors.New("Set expression is nil. ") @@ -153,6 +120,10 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error { out.newLine() } + if selectStmt == nil { + return errors.New("select statement is nil") + } + err := selectStmt.serialize(set_statement, out) if err != nil { @@ -198,23 +169,3 @@ func (s *setStatementImpl) Sql() (query string, args []interface{}, err error) { query, args = queryData.finalize() return } - -func (s *setStatementImpl) DebugSql() (query string, err error) { - return debugSql(s) -} - -func (s *setStatementImpl) Query(db execution.DB, destination interface{}) error { - return query(s, db, destination) -} - -func (s *setStatementImpl) QueryContext(db execution.DB, context context.Context, destination interface{}) error { - return queryContext(s, db, context, destination) -} - -func (s *setStatementImpl) Exec(db execution.DB) (res sql.Result, err error) { - return exec(s, db) -} - -func (s *setStatementImpl) ExecContext(db execution.DB, context context.Context) (res sql.Result, err error) { - return execContext(s, db, context) -} diff --git a/set_statement_test.go b/set_statement_test.go index 87b57bb..9741c14 100644 --- a/set_statement_test.go +++ b/set_statement_test.go @@ -1,34 +1,13 @@ package jet import ( + "fmt" "gotest.tools/assert" "testing" ) -func TestUnionNoSelect(t *testing.T) { - _, _, err := UNION().Sql() - - assert.Assert(t, err != nil) - //fmt.Println(err.Error()) - //fmt.Print(query, args) -} - -func TestUnionOneSelect(t *testing.T) { - _, _, err := UNION( - table1.SELECT(table1Col1), - ).Sql() - - assert.Assert(t, err != nil) -} - func TestUnionTwoSelect(t *testing.T) { - query, args, err := UNION( - table1.SELECT(table1Col1), - table2.SELECT(table2Col3), - ).Sql() - - assert.NilError(t, err) - assert.Equal(t, query, ` + var expectedSql = ` ( ( SELECT table1.col1 AS "table1.col1" @@ -40,19 +19,71 @@ func TestUnionTwoSelect(t *testing.T) { FROM db.table2 ) ); -`) - assert.Equal(t, len(args), 0) +` + unionStmt1 := table1. + SELECT(table1Col1). + UNION( + table2.SELECT(table2Col3), + ) + + unionStmt2 := UNION(table1.SELECT(table1Col1), table2.SELECT(table2Col3)) + + assertStatement(t, unionStmt1, expectedSql) + assertStatement(t, unionStmt2, expectedSql) } -func TestUnionThreeSelect(t *testing.T) { - query, args, err := UNION( +func TestUnionNilSelect(t *testing.T) { + unionStmt := table1. + SELECT(table1Col1). + UNION(nil) + + assertStatementErr(t, unionStmt, "select statement is nil") +} + +func TestUnionThreeSelect1(t *testing.T) { + + unionStmt1 := table1.SELECT(table1Col1). + UNION( + table2.SELECT(table2Col3), + ). + UNION( + table3.SELECT(table3Col1), + ) + + var expectedSql = ` +( + + ( + ( + SELECT table1.col1 AS "table1.col1" + FROM db.table1 + ) + UNION + ( + SELECT table2.col3 AS "table2.col3" + FROM db.table2 + ) + ) + UNION + ( + SELECT table3.col1 AS "table3.col1" + FROM db.table3 + ) +); +` + + assertStatement(t, unionStmt1, expectedSql) +} + +func TestUnionThreeSelect2(t *testing.T) { + + unionStmt2 := UNION( table1.SELECT(table1Col1), table2.SELECT(table2Col3), table3.SELECT(table3Col1), - ).Sql() + ) - assert.NilError(t, err) - assert.Equal(t, query, ` + var expectedSql = ` ( ( SELECT table1.col1 AS "table1.col1" @@ -69,18 +100,19 @@ func TestUnionThreeSelect(t *testing.T) { FROM db.table3 ) ); -`) - assert.Equal(t, len(args), 0) +` + + assertStatement(t, unionStmt2, expectedSql) } func TestUnionWithOrderBy(t *testing.T) { - query, args, err := UNION( + unionStmt := UNION( table1.SELECT(table1Col1), table2.SELECT(table2Col3), - ).ORDER_BY(table1Col1.ASC()).Sql() + ). + ORDER_BY(table1Col1.ASC()) - assert.NilError(t, err) - assert.Equal(t, query, ` + assertStatement(t, unionStmt, ` ( ( SELECT table1.col1 AS "table1.col1" @@ -94,14 +126,15 @@ func TestUnionWithOrderBy(t *testing.T) { ) ORDER BY "table1.col1" ASC; `) - assert.Equal(t, len(args), 0) } -func TestUnionWithLimit(t *testing.T) { +func TestUnionWithLimitAndOffset(t *testing.T) { query, args, err := UNION( table1.SELECT(table1Col1), table2.SELECT(table2Col3), - ).LIMIT(10).OFFSET(11).Sql() + ). + LIMIT(10). + OFFSET(11).Sql() assert.NilError(t, err) assert.Equal(t, query, ` @@ -150,11 +183,8 @@ func TestUnionInUnion(t *testing.T) { UNION_ALL(table1.SELECT(table1Col1), table2.SELECT(table2Col3)), ) - queryStr, args, err := query.Sql() - - assert.NilError(t, err) - assert.Equal(t, len(args), 0) - assert.Equal(t, queryStr, expectedSql) + fmt.Println(query.Sql()) + assertStatement(t, query, expectedSql) } func TestUnionALL(t *testing.T) { diff --git a/tests/select_test.go b/tests/select_test.go index 4dfc337..cab7198 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -963,6 +963,41 @@ OFFSET 20; }) } +func TestAllSetOperators(t *testing.T) { + + select1 := Payment.SELECT(Payment.AllColumns).WHERE(Payment.PaymentID.GT_EQ(Int(17600)).AND(Payment.PaymentID.LT(Int(17610)))) + select2 := Payment.SELECT(Payment.AllColumns).WHERE(Payment.PaymentID.GT_EQ(Int(17620)).AND(Payment.PaymentID.LT(Int(17630)))) + + type setOperator func(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement + operators := []setOperator{ + UNION, + UNION_ALL, + INTERSECT, + INTERSECT_ALL, + EXCEPT, + EXCEPT_ALL, + } + + expectedDestLen := []int{ + 20, + 20, + 0, + 0, + 10, + 10, + } + + for i, operator := range operators { + query := operator(select1, select2) + + dest := []model.Payment{} + err := query.Query(db, &dest) + + assert.NilError(t, err) + assert.Equal(t, len(dest), expectedDestLen[i]) + } +} + func TestSelectWithCase(t *testing.T) { expectedQuery := ` SELECT (CASE payment.staff_id WHEN 1 THEN 'ONE' WHEN 2 THEN 'TWO' WHEN 3 THEN 'THREE' ELSE 'OTHER' END) AS "staff_id_num"