From 7930fb23bab50f6207d6fa859576e3f9e2a5192f Mon Sep 17 00:00:00 2001 From: go-jet Date: Thu, 8 Aug 2019 11:44:19 +0200 Subject: [PATCH] Handle unsupported set operators for MySQL. --- internal/jet/set_statement.go | 22 ++++--- mysql/dialect.go | 14 +++++ mysql/statements.go | 4 -- tests/mysql/select_test.go | 106 ++++++++++++++++++++++++++++++++++ 4 files changed, 133 insertions(+), 13 deletions(-) diff --git a/internal/jet/set_statement.go b/internal/jet/set_statement.go index a0add2e..5bdaaae 100644 --- a/internal/jet/set_statement.go +++ b/internal/jet/set_statement.go @@ -7,37 +7,37 @@ import ( // UNION effectively appends the result of sub-queries(select statements) into single query. // It eliminates duplicate rows from its result. func UNION(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { - return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...)) + return newSetStatementImpl(Union, false, toSelectList(lhs, rhs, selects...)) } // UNION_ALL effectively appends the result of sub-queries(select statements) into single query. // It does not eliminates duplicate rows from its result. func UNION_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { - return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...)) + return newSetStatementImpl(Union, true, toSelectList(lhs, rhs, selects...)) } // INTERSECT returns all rows that are in query results. // It eliminates duplicate rows from its result. func INTERSECT(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { - return newSetStatementImpl(intersect, false, toSelectList(lhs, rhs, selects...)) + return newSetStatementImpl(Intersect, false, toSelectList(lhs, rhs, selects...)) } // INTERSECT_ALL returns all rows that are in query results. // It does not eliminates duplicate rows from its result. func INTERSECT_ALL(lhs, rhs SelectStatement, selects ...SelectStatement) SelectStatement { - return newSetStatementImpl(intersect, true, toSelectList(lhs, rhs, selects...)) + return newSetStatementImpl(Intersect, true, toSelectList(lhs, rhs, selects...)) } // EXCEPT returns all rows that are in the result of query lhs but not in the result of query rhs. // It eliminates duplicate rows from its result. func EXCEPT(lhs, rhs SelectStatement) SelectStatement { - return newSetStatementImpl(except, false, toSelectList(lhs, rhs)) + return newSetStatementImpl(Except, false, toSelectList(lhs, rhs)) } // EXCEPT_ALL returns all rows that are in the result of query lhs but not in the result of query rhs. // It does not eliminates duplicate rows from its result. func EXCEPT_ALL(lhs, rhs SelectStatement) SelectStatement { - return newSetStatementImpl(except, true, toSelectList(lhs, rhs)) + return newSetStatementImpl(Except, true, toSelectList(lhs, rhs)) } func toSelectList(lhs, rhs SelectStatement, selects ...SelectStatement) []SelectStatement { @@ -45,9 +45,9 @@ func toSelectList(lhs, rhs SelectStatement, selects ...SelectStatement) []Select } const ( - union = "UNION" - intersect = "INTERSECT" - except = "EXCEPT" + Union = "UNION" + Intersect = "INTERSECT" + Except = "EXCEPT" ) // Similar to selectStatementImpl, but less complete @@ -125,6 +125,10 @@ func (s *setStatementImpl) serializeImpl(out *SqlBuilder) error { return errors.New("jet: UNION Statement must have at least two SELECT statements") } + if setOverride := out.Dialect.SerializeOverride(s.operator); setOverride != nil { + return setOverride()(SelectStatementType, out) + } + out.newLine() out.WriteString("(") out.increaseIdent() diff --git a/mysql/dialect.go b/mysql/dialect.go index a0b3d50..6b05ead 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -15,6 +15,8 @@ func NewDialect() jet.Dialect { serializeOverrides["/"] = mysql_DIVISION serializeOverrides["#"] = mysql_BIT_XOR serializeOverrides[jet.StringConcatOperator] = mysql_CONCAT_operator + serializeOverrides[jet.Except] = mysql_EXCEPT + serializeOverrides[jet.Intersect] = mysql_INTERSECT mySQLDialectParams := jet.DialectParams{ Name: "MySQL", @@ -31,6 +33,18 @@ func NewDialect() jet.Dialect { return jet.NewDialect(mySQLDialectParams) } +func mysql_EXCEPT(expressions ...jet.Expression) jet.SerializeFunc { + return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) error { + panic("jet: MySQL does not support EXCEPT operator.") + } +} + +func mysql_INTERSECT(expressions ...jet.Expression) jet.SerializeFunc { + return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) error { + panic("jet: MySQL does not support INTERSECT operator.") + } +} + func mysql_BIT_XOR(expressions ...jet.Expression) jet.SerializeFunc { return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) error { if len(expressions) != 2 { diff --git a/mysql/statements.go b/mysql/statements.go index 458d029..30cd202 100644 --- a/mysql/statements.go +++ b/mysql/statements.go @@ -15,10 +15,6 @@ var ( var UNION = jet.UNION var UNION_ALL = jet.UNION_ALL -var INTERSECT = jet.INTERSECT -var INTERSECT_ALL = jet.INTERSECT_ALL -var EXCEPT = jet.EXCEPT -var EXCEPT_ALL = jet.EXCEPT_ALL //-----------------literals----------------------// diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 2f8b26d..d9b7ca2 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -215,6 +215,112 @@ LIMIT ?; LIMIT(12) testutils.AssertStatementSql(t, query, expectedSQL, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) +} + +func TestSelectUNION(t *testing.T) { + expectedSQL := ` +( + ( + SELECT payment.payment_id AS "payment.payment_id" + FROM dvds.payment + LIMIT ? + OFFSET ? + ) + UNION + ( + SELECT payment.payment_id AS "payment.payment_id" + FROM dvds.payment + LIMIT ? + OFFSET ? + ) +) +LIMIT ?; +` + query := UNION( + Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10), + Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2), + ).LIMIT(1) + + //fmt.Println(query.Sql()) + + testutils.AssertStatementSql(t, query, expectedSQL, int64(1), int64(10), int64(1), int64(2), int64(1)) + + query2 := Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10). + UNION(Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2)).LIMIT(1) + + testutils.AssertStatementSql(t, query2, expectedSQL, int64(1), int64(10), int64(1), int64(2), int64(1)) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) +} + +func TestSelectINTERSECT(t *testing.T) { + defer func() { + r := recover() + assert.Equal(t, r, "jet: MySQL does not support INTERSECT operator.") + }() + + query := Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10). + INTERSECT(Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2)).LIMIT(1) + + //fmt.Println(query.DebugSql()) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) +} + +func TestSelectEXCEPT(t *testing.T) { + defer func() { + r := recover() + assert.Equal(t, r, "jet: MySQL does not support EXCEPT operator.") + }() + + query := Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10). + EXCEPT(Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2)).LIMIT(1) + + //fmt.Println(query.DebugSql()) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) +} + +func TestSelectUNION_ALL(t *testing.T) { + expectedSQL := ` +( + ( + SELECT payment.payment_id AS "payment.payment_id" + FROM dvds.payment + LIMIT ? + OFFSET ? + ) + UNION ALL + ( + SELECT payment.payment_id AS "payment.payment_id" + FROM dvds.payment + LIMIT ? + OFFSET ? + ) +); +` + query := UNION_ALL( + Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10), + Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2), + ) + + //fmt.Println(query.Sql()) + + testutils.AssertStatementSql(t, query, expectedSQL, int64(1), int64(10), int64(1), int64(2)) + + query2 := Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10). + UNION_ALL(Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2)) + + testutils.AssertStatementSql(t, query2, expectedSQL, int64(1), int64(10), int64(1), int64(2)) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) } func TestJoinQueryStruct(t *testing.T) {