From d19fdea86d24bd0d5b7a4c7a7caa56f4cb3471c7 Mon Sep 17 00:00:00 2001 From: go-jet Date: Mon, 1 Jun 2020 20:30:09 +0200 Subject: [PATCH] Additional MySQL WITH statement tests. --- internal/jet/with_statement.go | 10 +++- internal/testutils/test_utils.go | 2 +- mysql/with_statement.go | 2 +- postgres/with_statement.go | 2 +- tests/mysql/with_test.go | 95 +++++++++++++++++++++++++++++++- 5 files changed, 103 insertions(+), 8 deletions(-) diff --git a/internal/jet/with_statement.go b/internal/jet/with_statement.go index 6131b35..ab57067 100644 --- a/internal/jet/with_statement.go +++ b/internal/jet/with_statement.go @@ -1,7 +1,7 @@ package jet // WITH function creates new with statement from list of common table expressions for specified dialect -func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statement SerializerStatement) Statement { +func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statement Statement) Statement { newWithImpl := &withImpl{ ctes: cte, serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{ @@ -11,8 +11,12 @@ func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statemen } newWithImpl.parent = newWithImpl - return func(primaryStatement SerializerStatement) Statement { - newWithImpl.primaryStatement = primaryStatement + return func(primaryStatement Statement) Statement { + serializerStatement, ok := primaryStatement.(SerializerStatement) + if !ok { + panic("jet: unsupported main WITH statement.") + } + newWithImpl.primaryStatement = serializerStatement return newWithImpl } } diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 3cff7ab..e5d7a09 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -28,7 +28,7 @@ func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int require.NoError(t, err) if len(rowsAffected) > 0 { - require.Equal(t, rows, rowsAffected[0]) + require.Equal(t, rowsAffected[0], rows) } } diff --git a/mysql/with_statement.go b/mysql/with_statement.go index 5991287..35066f7 100644 --- a/mysql/with_statement.go +++ b/mysql/with_statement.go @@ -9,7 +9,7 @@ type CommonTableExpression struct { } // WITH function creates new WITH statement from list of common table expressions -func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.SerializerStatement) Statement { +func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { return jet.WITH(Dialect, cte...) } diff --git a/postgres/with_statement.go b/postgres/with_statement.go index caa7100..c1f7a7b 100644 --- a/postgres/with_statement.go +++ b/postgres/with_statement.go @@ -9,7 +9,7 @@ type CommonTableExpression struct { } // WITH function creates new WITH statement from list of common table expressions -func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.SerializerStatement) Statement { +func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { return jet.WITH(Dialect, cte...) } diff --git a/tests/mysql/with_test.go b/tests/mysql/with_test.go index 7e3a8dd..fa53fad 100644 --- a/tests/mysql/with_test.go +++ b/tests/mysql/with_test.go @@ -9,7 +9,7 @@ import ( "testing" ) -func TestWITH_SELECT(t *testing.T) { +func TestWITH_And_SELECT(t *testing.T) { salesRep := CTE("sales_rep") salesRepStaffID := Staff.StaffID.From(salesRep) salesRepFullName := StringColumn("sales_rep_full_name").From(salesRep) @@ -56,6 +56,97 @@ SELECT customer_sales_rep.customer_name AS "customer_name", FROM customer_sales_rep; `, "''", "`", -1)) - _, err := stmt.Exec(db) + var dest []struct { + CustomerName string + SalesRepFullName string + } + err := stmt.Query(db, &dest) + + require.Equal(t, len(dest), 599) require.NoError(t, err) } + +//func TestWITH_And_INSERT(t *testing.T) { +// paymentsToInsert := CTE("payments_to_insert") +// +// stmt := WITH( +// paymentsToInsert.AS( +// SELECT(Payment.AllColumns). +// FROM(Payment). +// WHERE(Payment.Amount.LT(Float(0.5))), +// ), +// )( +// Payment.INSERT(Payment.AllColumns). +// QUERY( +// SELECT(paymentsToInsert.AllColumns()). +// FROM(paymentsToInsert), +// ).ON_DUPLICATE_KEY_UPDATE( +// Payment.PaymentID.SET(Payment.PaymentID.ADD(Int(100000))), +// ), +// ) +// +// //fmt.Println(stmt.DebugSql()) +// +// tx, err := db.Begin() +// require.NoError(t, err) +// defer tx.Rollback() +// +// testutils.AssertExec(t, stmt, tx, 24) +//} + +func TestWITH_And_UPDATE(t *testing.T) { + paymentsToUpdate := CTE("payments_to_update") + paymentsToDeleteID := Payment.PaymentID.From(paymentsToUpdate) + + stmt := WITH( + paymentsToUpdate.AS( + SELECT(Payment.AllColumns). + FROM(Payment). + WHERE(Payment.Amount.LT(Float(0.5))), + ), + )( + Payment.UPDATE(). + SET(Payment.Amount.SET(Float(0.0))). + WHERE(Payment.PaymentID.IN( + SELECT(paymentsToDeleteID). + FROM(paymentsToUpdate), + ), + ), + ) + + //fmt.Println(stmt.DebugSql()) + + tx, err := db.Begin() + require.NoError(t, err) + defer tx.Rollback() + + testutils.AssertExec(t, stmt, tx) +} + +func TestWITH_And_DELETE(t *testing.T) { + paymentsToDelete := CTE("payments_to_delete") + paymentsToDeleteID := Payment.PaymentID.From(paymentsToDelete) + + stmt := WITH( + paymentsToDelete.AS( + SELECT(Payment.AllColumns). + FROM(Payment). + WHERE(Payment.Amount.LT(Float(0.5))), + ), + )( + Payment.DELETE(). + WHERE(Payment.PaymentID.IN( + SELECT(paymentsToDeleteID). + FROM(paymentsToDelete), + ), + ), + ) + + //fmt.Println(stmt.DebugSql()) + + tx, err := db.Begin() + require.NoError(t, err) + defer tx.Rollback() + + testutils.AssertExec(t, stmt, tx, 24) +}