Additional MySQL WITH statement tests.

This commit is contained in:
go-jet 2020-06-01 20:30:09 +02:00
parent e54e8fcabf
commit d19fdea86d
5 changed files with 103 additions and 8 deletions

View file

@ -1,7 +1,7 @@
package jet
// WITH function creates new with statement from list of common table expressions for specified dialect
func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statement SerializerStatement) Statement {
func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statement Statement) Statement {
newWithImpl := &withImpl{
ctes: cte,
serializerStatementInterfaceImpl: serializerStatementInterfaceImpl{
@ -11,8 +11,12 @@ func WITH(dialect Dialect, cte ...CommonTableExpressionDefinition) func(statemen
}
newWithImpl.parent = newWithImpl
return func(primaryStatement SerializerStatement) Statement {
newWithImpl.primaryStatement = primaryStatement
return func(primaryStatement Statement) Statement {
serializerStatement, ok := primaryStatement.(SerializerStatement)
if !ok {
panic("jet: unsupported main WITH statement.")
}
newWithImpl.primaryStatement = serializerStatement
return newWithImpl
}
}

View file

@ -28,7 +28,7 @@ func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int
require.NoError(t, err)
if len(rowsAffected) > 0 {
require.Equal(t, rows, rowsAffected[0])
require.Equal(t, rowsAffected[0], rows)
}
}

View file

@ -9,7 +9,7 @@ type CommonTableExpression struct {
}
// WITH function creates new WITH statement from list of common table expressions
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.SerializerStatement) Statement {
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, cte...)
}

View file

@ -9,7 +9,7 @@ type CommonTableExpression struct {
}
// WITH function creates new WITH statement from list of common table expressions
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.SerializerStatement) Statement {
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, cte...)
}

View file

@ -9,7 +9,7 @@ import (
"testing"
)
func TestWITH_SELECT(t *testing.T) {
func TestWITH_And_SELECT(t *testing.T) {
salesRep := CTE("sales_rep")
salesRepStaffID := Staff.StaffID.From(salesRep)
salesRepFullName := StringColumn("sales_rep_full_name").From(salesRep)
@ -56,6 +56,97 @@ SELECT customer_sales_rep.customer_name AS "customer_name",
FROM customer_sales_rep;
`, "''", "`", -1))
_, err := stmt.Exec(db)
var dest []struct {
CustomerName string
SalesRepFullName string
}
err := stmt.Query(db, &dest)
require.Equal(t, len(dest), 599)
require.NoError(t, err)
}
//func TestWITH_And_INSERT(t *testing.T) {
// paymentsToInsert := CTE("payments_to_insert")
//
// stmt := WITH(
// paymentsToInsert.AS(
// SELECT(Payment.AllColumns).
// FROM(Payment).
// WHERE(Payment.Amount.LT(Float(0.5))),
// ),
// )(
// Payment.INSERT(Payment.AllColumns).
// QUERY(
// SELECT(paymentsToInsert.AllColumns()).
// FROM(paymentsToInsert),
// ).ON_DUPLICATE_KEY_UPDATE(
// Payment.PaymentID.SET(Payment.PaymentID.ADD(Int(100000))),
// ),
// )
//
// //fmt.Println(stmt.DebugSql())
//
// tx, err := db.Begin()
// require.NoError(t, err)
// defer tx.Rollback()
//
// testutils.AssertExec(t, stmt, tx, 24)
//}
func TestWITH_And_UPDATE(t *testing.T) {
paymentsToUpdate := CTE("payments_to_update")
paymentsToDeleteID := Payment.PaymentID.From(paymentsToUpdate)
stmt := WITH(
paymentsToUpdate.AS(
SELECT(Payment.AllColumns).
FROM(Payment).
WHERE(Payment.Amount.LT(Float(0.5))),
),
)(
Payment.UPDATE().
SET(Payment.Amount.SET(Float(0.0))).
WHERE(Payment.PaymentID.IN(
SELECT(paymentsToDeleteID).
FROM(paymentsToUpdate),
),
),
)
//fmt.Println(stmt.DebugSql())
tx, err := db.Begin()
require.NoError(t, err)
defer tx.Rollback()
testutils.AssertExec(t, stmt, tx)
}
func TestWITH_And_DELETE(t *testing.T) {
paymentsToDelete := CTE("payments_to_delete")
paymentsToDeleteID := Payment.PaymentID.From(paymentsToDelete)
stmt := WITH(
paymentsToDelete.AS(
SELECT(Payment.AllColumns).
FROM(Payment).
WHERE(Payment.Amount.LT(Float(0.5))),
),
)(
Payment.DELETE().
WHERE(Payment.PaymentID.IN(
SELECT(paymentsToDeleteID).
FROM(paymentsToDelete),
),
),
)
//fmt.Println(stmt.DebugSql())
tx, err := db.Begin()
require.NoError(t, err)
defer tx.Rollback()
testutils.AssertExec(t, stmt, tx, 24)
}