Add WITH RECURSIVE statement support
This commit is contained in:
parent
001d64f1dc
commit
038a32b032
17 changed files with 695 additions and 91 deletions
|
|
@ -149,7 +149,26 @@ func TestWITH_And_DELETE(t *testing.T) {
|
|||
),
|
||||
)
|
||||
|
||||
//fmt.Println(stmt.DebugSql())
|
||||
// fmt.Println(stmt.DebugSql())
|
||||
|
||||
testutils.AssertDebugStatementSql(t, stmt, strings.ReplaceAll(`
|
||||
WITH payments_to_delete AS (
|
||||
SELECT payment.payment_id AS "payment.payment_id",
|
||||
payment.customer_id AS "payment.customer_id",
|
||||
payment.staff_id AS "payment.staff_id",
|
||||
payment.rental_id AS "payment.rental_id",
|
||||
payment.amount AS "payment.amount",
|
||||
payment.payment_date AS "payment.payment_date",
|
||||
payment.last_update AS "payment.last_update"
|
||||
FROM dvds.payment
|
||||
WHERE payment.amount < 0.5
|
||||
)
|
||||
DELETE FROM dvds.payment
|
||||
WHERE payment.payment_id IN (
|
||||
SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id"
|
||||
FROM payments_to_delete
|
||||
);
|
||||
`, "''", "`"))
|
||||
|
||||
tx, err := db.Begin()
|
||||
require.NoError(t, err)
|
||||
|
|
@ -157,3 +176,119 @@ func TestWITH_And_DELETE(t *testing.T) {
|
|||
|
||||
testutils.AssertExec(t, stmt, tx, 24)
|
||||
}
|
||||
|
||||
func TestRecursiveWithStatement_Fibonacci(t *testing.T) {
|
||||
// CTE columns are listed as part of CTE definition
|
||||
n1 := IntegerColumn("n1")
|
||||
fibN1 := IntegerColumn("fibN1")
|
||||
nextFibN1 := IntegerColumn("nextFibN1")
|
||||
fibonacci1 := CTE("fibonacci1", n1, fibN1, nextFibN1)
|
||||
|
||||
// CTE columns are columns from non-recursive select
|
||||
fibonacci2 := CTE("fibonacci2")
|
||||
n2 := IntegerColumn("n2").From(fibonacci2)
|
||||
fibN2 := IntegerColumn("fibN2").From(fibonacci2)
|
||||
nextFibN2 := IntegerColumn("nextFibN2").From(fibonacci2)
|
||||
|
||||
stmt := WITH_RECURSIVE(
|
||||
fibonacci1.AS(
|
||||
SELECT(
|
||||
Int32(1), Int32(0), Int32(1),
|
||||
).UNION_ALL(
|
||||
SELECT(
|
||||
n1.ADD(Int(1)), nextFibN1, fibN1.ADD(nextFibN1),
|
||||
).FROM(
|
||||
fibonacci1,
|
||||
).WHERE(
|
||||
n1.LT(Int(20)),
|
||||
),
|
||||
),
|
||||
),
|
||||
fibonacci2.AS(
|
||||
SELECT(
|
||||
Int32(1).AS(n2.Name()),
|
||||
Int32(0).AS(fibN2.Name()),
|
||||
Int32(1).AS(nextFibN2.Name()),
|
||||
).UNION_ALL(
|
||||
SELECT(
|
||||
n2.ADD(Int(1)), nextFibN2, fibN2.ADD(nextFibN2),
|
||||
).FROM(
|
||||
fibonacci2,
|
||||
).WHERE(
|
||||
n2.LT(Int(20)),
|
||||
),
|
||||
),
|
||||
),
|
||||
)(
|
||||
SELECT(
|
||||
fibonacci1.AllColumns(),
|
||||
fibonacci2.AllColumns(),
|
||||
).FROM(
|
||||
fibonacci1.INNER_JOIN(fibonacci2, n1.EQ(n2)),
|
||||
).WHERE(
|
||||
n1.EQ(Int(20)),
|
||||
),
|
||||
)
|
||||
|
||||
// fmt.Println(stmt.Sql())
|
||||
|
||||
testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(`
|
||||
WITH RECURSIVE fibonacci1 (n1, ''fibN1'', ''nextFibN1'') AS (
|
||||
(
|
||||
SELECT ?,
|
||||
?,
|
||||
?
|
||||
)
|
||||
UNION ALL
|
||||
(
|
||||
SELECT fibonacci1.n1 + ?,
|
||||
fibonacci1.''nextFibN1'' AS "nextFibN1",
|
||||
fibonacci1.''fibN1'' + fibonacci1.''nextFibN1''
|
||||
FROM fibonacci1
|
||||
WHERE fibonacci1.n1 < ?
|
||||
)
|
||||
),fibonacci2 AS (
|
||||
(
|
||||
SELECT ? AS "n2",
|
||||
? AS "fibN2",
|
||||
? AS "nextFibN2"
|
||||
)
|
||||
UNION ALL
|
||||
(
|
||||
SELECT fibonacci2.n2 + ?,
|
||||
fibonacci2.''nextFibN2'' AS "nextFibN2",
|
||||
fibonacci2.''fibN2'' + fibonacci2.''nextFibN2''
|
||||
FROM fibonacci2
|
||||
WHERE fibonacci2.n2 < ?
|
||||
)
|
||||
)
|
||||
SELECT fibonacci1.n1 AS "n1",
|
||||
fibonacci1.''fibN1'' AS "fibN1",
|
||||
fibonacci1.''nextFibN1'' AS "nextFibN1",
|
||||
fibonacci2.n2 AS "n2",
|
||||
fibonacci2.''fibN2'' AS "fibN2",
|
||||
fibonacci2.''nextFibN2'' AS "nextFibN2"
|
||||
FROM fibonacci1
|
||||
INNER JOIN fibonacci2 ON (fibonacci1.n1 = fibonacci2.n2)
|
||||
WHERE fibonacci1.n1 = ?;
|
||||
`, "''", "`"))
|
||||
|
||||
var dest struct {
|
||||
N1 int
|
||||
FibN1 int
|
||||
NextFibN1 int
|
||||
|
||||
N2 int
|
||||
FibN2 int
|
||||
NextFibN2 int
|
||||
}
|
||||
|
||||
err := stmt.Query(db, &dest)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, dest.N1, 20)
|
||||
require.Equal(t, dest.FibN1, 4181)
|
||||
require.Equal(t, dest.NextFibN1, 6765)
|
||||
require.Equal(t, dest.N2, 20)
|
||||
require.Equal(t, dest.FibN2, 4181)
|
||||
require.Equal(t, dest.NextFibN2, 6765)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -814,10 +814,10 @@ ORDER BY f1.film_id ASC;
|
|||
type F1 model.Film
|
||||
type F2 model.Film
|
||||
|
||||
theSameLengthFilms := []struct {
|
||||
var theSameLengthFilms []struct {
|
||||
F1 F1
|
||||
F2 F2
|
||||
}{}
|
||||
}
|
||||
|
||||
err := query.Query(db, &theSameLengthFilms)
|
||||
|
||||
|
|
@ -858,7 +858,7 @@ LIMIT 1000;
|
|||
Title2 string
|
||||
Length int16
|
||||
}
|
||||
films := []thesameLengthFilms{}
|
||||
var films []thesameLengthFilms
|
||||
|
||||
err := query.Query(db, &films)
|
||||
|
||||
|
|
|
|||
|
|
@ -144,7 +144,7 @@ func TestWithStatementDeleteAndInsert(t *testing.T) {
|
|||
require.Equal(t, len(updateDiscontinuedPrice.AllColumns()[0].(ProjectionList)), 10)
|
||||
require.Equal(t, len(logDiscontinuedProducts.AllColumns()), 10)
|
||||
|
||||
//fmt.Println(stmt.Sql())
|
||||
// fmt.Println(stmt.Sql())
|
||||
|
||||
testutils.AssertStatementSql(t, stmt, `
|
||||
WITH remove_discontinued_orders AS (
|
||||
|
|
@ -218,7 +218,122 @@ FROM log_discontinued;
|
|||
|
||||
err = stmt.Query(tx, &resp)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestRecursiveWithStatement_Fibonacci(t *testing.T) {
|
||||
// CTE columns are listed as part of CTE definition
|
||||
n1 := IntegerColumn("n1")
|
||||
fibN1 := IntegerColumn("fibN1")
|
||||
nextFibN1 := IntegerColumn("nextFibN1")
|
||||
fibonacci1 := CTE("fibonacci1", n1, fibN1, nextFibN1)
|
||||
|
||||
// CTE columns are columns from non-recursive select
|
||||
fibonacci2 := CTE("fibonacci2")
|
||||
n2 := IntegerColumn("n2").From(fibonacci2)
|
||||
fibN2 := IntegerColumn("fibN2").From(fibonacci2)
|
||||
nextFibN2 := IntegerColumn("nextFibN2").From(fibonacci2)
|
||||
|
||||
stmt := WITH_RECURSIVE(
|
||||
fibonacci1.AS(
|
||||
SELECT(
|
||||
Int32(1), Int32(0), Int32(1),
|
||||
).UNION_ALL(
|
||||
SELECT(
|
||||
n1.ADD(Int(1)), nextFibN1, fibN1.ADD(nextFibN1),
|
||||
).FROM(
|
||||
fibonacci1,
|
||||
).WHERE(
|
||||
n1.LT(Int(20)),
|
||||
),
|
||||
),
|
||||
),
|
||||
fibonacci2.AS(
|
||||
SELECT(
|
||||
Int32(1).AS(n2.Name()),
|
||||
Int32(0).AS(fibN2.Name()),
|
||||
Int32(1).AS(nextFibN2.Name()),
|
||||
).UNION_ALL(
|
||||
SELECT(
|
||||
n2.ADD(Int(1)), nextFibN2, fibN2.ADD(nextFibN2),
|
||||
).FROM(
|
||||
fibonacci2,
|
||||
).WHERE(
|
||||
n2.LT(Int(20)),
|
||||
),
|
||||
),
|
||||
),
|
||||
)(
|
||||
SELECT(
|
||||
fibonacci1.AllColumns(),
|
||||
fibonacci2.AllColumns(),
|
||||
).FROM(
|
||||
fibonacci1.INNER_JOIN(fibonacci2, n1.EQ(n2)),
|
||||
).WHERE(
|
||||
n1.EQ(Int(20)),
|
||||
),
|
||||
)
|
||||
|
||||
//fmt.Println(stmt.Sql())
|
||||
|
||||
testutils.AssertStatementSql(t, stmt, `
|
||||
WITH RECURSIVE fibonacci1 (n1, "fibN1", "nextFibN1") AS (
|
||||
(
|
||||
SELECT $1::integer,
|
||||
$2::integer,
|
||||
$3::integer
|
||||
)
|
||||
UNION ALL
|
||||
(
|
||||
SELECT fibonacci1.n1 + $4,
|
||||
fibonacci1."nextFibN1" AS "nextFibN1",
|
||||
fibonacci1."fibN1" + fibonacci1."nextFibN1"
|
||||
FROM fibonacci1
|
||||
WHERE fibonacci1.n1 < $5
|
||||
)
|
||||
),fibonacci2 AS (
|
||||
(
|
||||
SELECT $6::integer AS "n2",
|
||||
$7::integer AS "fibN2",
|
||||
$8::integer AS "nextFibN2"
|
||||
)
|
||||
UNION ALL
|
||||
(
|
||||
SELECT fibonacci2.n2 + $9,
|
||||
fibonacci2."nextFibN2" AS "nextFibN2",
|
||||
fibonacci2."fibN2" + fibonacci2."nextFibN2"
|
||||
FROM fibonacci2
|
||||
WHERE fibonacci2.n2 < $10
|
||||
)
|
||||
)
|
||||
SELECT fibonacci1.n1 AS "n1",
|
||||
fibonacci1."fibN1" AS "fibN1",
|
||||
fibonacci1."nextFibN1" AS "nextFibN1",
|
||||
fibonacci2.n2 AS "n2",
|
||||
fibonacci2."fibN2" AS "fibN2",
|
||||
fibonacci2."nextFibN2" AS "nextFibN2"
|
||||
FROM fibonacci1
|
||||
INNER JOIN fibonacci2 ON (fibonacci1.n1 = fibonacci2.n2)
|
||||
WHERE fibonacci1.n1 = $11;
|
||||
`)
|
||||
|
||||
var dest struct {
|
||||
N1 int
|
||||
FibN1 int
|
||||
NextFibN1 int
|
||||
|
||||
N2 int
|
||||
FibN2 int
|
||||
NextFibN2 int
|
||||
}
|
||||
|
||||
err := stmt.Query(db, &dest)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, dest.N1, 20)
|
||||
require.Equal(t, dest.FibN1, 4181)
|
||||
require.Equal(t, dest.NextFibN1, 6765)
|
||||
require.Equal(t, dest.N2, 20)
|
||||
require.Equal(t, dest.FibN2, 4181)
|
||||
require.Equal(t, dest.NextFibN2, 6765)
|
||||
}
|
||||
|
||||
// default column aliases from sub-queries are bubbled up to the main query,
|
||||
|
|
@ -298,13 +413,7 @@ FROM cte2;
|
|||
require.Equal(t, dest[0].CustomColumn2, "custom_column_2")
|
||||
}
|
||||
|
||||
type EmployeeWrap struct {
|
||||
model.Employees
|
||||
|
||||
Subordinates []*EmployeeWrap
|
||||
}
|
||||
|
||||
func TestWithRecursive(t *testing.T) {
|
||||
func TestRecursiveWithStatement(t *testing.T) {
|
||||
|
||||
subordinates := CTE("subordinates")
|
||||
|
||||
|
|
@ -333,6 +442,14 @@ func TestWithRecursive(t *testing.T) {
|
|||
),
|
||||
)
|
||||
|
||||
//fmt.Println(stmt.DebugSql())
|
||||
|
||||
type EmployeeWrap struct {
|
||||
model.Employees
|
||||
|
||||
Subordinates []*EmployeeWrap
|
||||
}
|
||||
|
||||
type employeeID = int16
|
||||
employeeMap := make(map[employeeID]*EmployeeWrap)
|
||||
|
||||
|
|
@ -352,7 +469,7 @@ func TestWithRecursive(t *testing.T) {
|
|||
|
||||
employeeMap[employeeModel.EmployeeID] = newEmployeeWrap
|
||||
|
||||
if employeeModel.ReportsTo == nil { // top manager(always first row in the result)
|
||||
if result == nil { // top manager(always first row in the result)
|
||||
result = newEmployeeWrap
|
||||
continue
|
||||
}
|
||||
|
|
@ -559,3 +676,73 @@ func TestWithRecursive(t *testing.T) {
|
|||
}
|
||||
`)
|
||||
}
|
||||
|
||||
var suppliersWithFax = CTE("suppliers_fax").AS(
|
||||
SELECT(
|
||||
Suppliers.SupplierID,
|
||||
Suppliers.ContactName,
|
||||
Suppliers.Country,
|
||||
).FROM(
|
||||
Suppliers,
|
||||
).WHERE(Suppliers.Fax.IS_NOT_NULL()),
|
||||
)
|
||||
|
||||
func SuppliersNotFromUSorAUS(suppliersCTE CommonTableExpression) CommonTableExpression {
|
||||
return CTE("not_from_us_or_aus").AS(
|
||||
SELECT(
|
||||
suppliersCTE.AllColumns(),
|
||||
).FROM(
|
||||
suppliersCTE,
|
||||
).WHERE(
|
||||
Suppliers.Country.From(suppliersCTE).NOT_IN(String("US"), String("Australia")),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func TestCTEReuse(t *testing.T) {
|
||||
suppliersFilteredByCountry := SuppliersNotFromUSorAUS(suppliersWithFax)
|
||||
supplierContactName := Suppliers.ContactName.From(suppliersFilteredByCountry)
|
||||
|
||||
stmt := WITH(
|
||||
suppliersWithFax,
|
||||
suppliersFilteredByCountry,
|
||||
)(
|
||||
SELECT(
|
||||
suppliersFilteredByCountry.AllColumns(),
|
||||
).FROM(
|
||||
suppliersFilteredByCountry,
|
||||
).WHERE(
|
||||
supplierContactName.NOT_EQ(String("John")),
|
||||
),
|
||||
)
|
||||
|
||||
// fmt.Println(stmt.DebugSql())
|
||||
|
||||
testutils.AssertDebugStatementSql(t, stmt, `
|
||||
WITH suppliers_fax AS (
|
||||
SELECT suppliers.supplier_id AS "suppliers.supplier_id",
|
||||
suppliers.contact_name AS "suppliers.contact_name",
|
||||
suppliers.country AS "suppliers.country"
|
||||
FROM northwind.suppliers
|
||||
WHERE suppliers.fax IS NOT NULL
|
||||
),not_from_us_or_aus AS (
|
||||
SELECT suppliers_fax."suppliers.supplier_id" AS "suppliers.supplier_id",
|
||||
suppliers_fax."suppliers.contact_name" AS "suppliers.contact_name",
|
||||
suppliers_fax."suppliers.country" AS "suppliers.country"
|
||||
FROM suppliers_fax
|
||||
WHERE suppliers_fax."suppliers.country" NOT IN ('US', 'Australia')
|
||||
)
|
||||
SELECT not_from_us_or_aus."suppliers.supplier_id" AS "suppliers.supplier_id",
|
||||
not_from_us_or_aus."suppliers.contact_name" AS "suppliers.contact_name",
|
||||
not_from_us_or_aus."suppliers.country" AS "suppliers.country"
|
||||
FROM not_from_us_or_aus
|
||||
WHERE not_from_us_or_aus."suppliers.contact_name" != 'John';
|
||||
`)
|
||||
|
||||
var dest []model.Suppliers
|
||||
|
||||
err := stmt.Query(db, &dest)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, dest, 11)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -232,3 +232,117 @@ FROM payment;
|
|||
err := stmt.Query(db, &dest)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestRecursiveWithStatement_Fibonacci(t *testing.T) {
|
||||
// CTE columns are listed as part of CTE definition
|
||||
n1 := IntegerColumn("n1")
|
||||
fibN1 := IntegerColumn("fibN1")
|
||||
nextFibN1 := IntegerColumn("nextFibN1")
|
||||
fibonacci1 := CTE("fibonacci1", n1, fibN1, nextFibN1)
|
||||
|
||||
// CTE columns are columns from non-recursive select
|
||||
fibonacci2 := CTE("fibonacci2")
|
||||
n2 := IntegerColumn("n2").From(fibonacci2)
|
||||
fibN2 := IntegerColumn("fibN2").From(fibonacci2)
|
||||
nextFibN2 := IntegerColumn("nextFibN2").From(fibonacci2)
|
||||
|
||||
stmt := WITH_RECURSIVE(
|
||||
fibonacci1.AS(
|
||||
SELECT(
|
||||
Int32(1), Int32(0), Int32(1),
|
||||
).UNION_ALL(
|
||||
SELECT(
|
||||
n1.ADD(Int(1)), nextFibN1, fibN1.ADD(nextFibN1),
|
||||
).FROM(
|
||||
fibonacci1,
|
||||
).WHERE(
|
||||
n1.LT(Int(20)),
|
||||
),
|
||||
),
|
||||
),
|
||||
fibonacci2.AS(
|
||||
SELECT(
|
||||
Int32(1).AS(n2.Name()),
|
||||
Int32(0).AS(fibN2.Name()),
|
||||
Int32(1).AS(nextFibN2.Name()),
|
||||
).UNION_ALL(
|
||||
SELECT(
|
||||
n2.ADD(Int(1)), nextFibN2, fibN2.ADD(nextFibN2),
|
||||
).FROM(
|
||||
fibonacci2,
|
||||
).WHERE(
|
||||
n2.LT(Int(20)),
|
||||
),
|
||||
),
|
||||
),
|
||||
)(
|
||||
SELECT(
|
||||
fibonacci1.AllColumns(),
|
||||
fibonacci2.AllColumns(),
|
||||
).FROM(
|
||||
fibonacci1.INNER_JOIN(fibonacci2, n1.EQ(n2)),
|
||||
).WHERE(
|
||||
n1.EQ(Int(20)),
|
||||
),
|
||||
)
|
||||
|
||||
//fmt.Println(stmt.Sql())
|
||||
|
||||
testutils.AssertStatementSql(t, stmt, strings.ReplaceAll(`
|
||||
WITH RECURSIVE fibonacci1 (n1, ''fibN1'', ''nextFibN1'') AS (
|
||||
|
||||
SELECT ?,
|
||||
?,
|
||||
?
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT fibonacci1.n1 + ?,
|
||||
fibonacci1.''nextFibN1'' AS "nextFibN1",
|
||||
fibonacci1.''fibN1'' + fibonacci1.''nextFibN1''
|
||||
FROM fibonacci1
|
||||
WHERE fibonacci1.n1 < ?
|
||||
),fibonacci2 AS (
|
||||
|
||||
SELECT ? AS "n2",
|
||||
? AS "fibN2",
|
||||
? AS "nextFibN2"
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT fibonacci2.n2 + ?,
|
||||
fibonacci2.''nextFibN2'' AS "nextFibN2",
|
||||
fibonacci2.''fibN2'' + fibonacci2.''nextFibN2''
|
||||
FROM fibonacci2
|
||||
WHERE fibonacci2.n2 < ?
|
||||
)
|
||||
SELECT fibonacci1.n1 AS "n1",
|
||||
fibonacci1.''fibN1'' AS "fibN1",
|
||||
fibonacci1.''nextFibN1'' AS "nextFibN1",
|
||||
fibonacci2.n2 AS "n2",
|
||||
fibonacci2.''fibN2'' AS "fibN2",
|
||||
fibonacci2.''nextFibN2'' AS "nextFibN2"
|
||||
FROM fibonacci1
|
||||
INNER JOIN fibonacci2 ON (fibonacci1.n1 = fibonacci2.n2)
|
||||
WHERE fibonacci1.n1 = ?;
|
||||
`, "''", "`"))
|
||||
|
||||
var dest struct {
|
||||
N1 int
|
||||
FibN1 int
|
||||
NextFibN1 int
|
||||
|
||||
N2 int
|
||||
FibN2 int
|
||||
NextFibN2 int
|
||||
}
|
||||
|
||||
err := stmt.Query(db, &dest)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, dest.N1, 20)
|
||||
require.Equal(t, dest.FibN1, 4181)
|
||||
require.Equal(t, dest.NextFibN1, 6765)
|
||||
require.Equal(t, dest.N2, 20)
|
||||
require.Equal(t, dest.FibN2, 4181)
|
||||
require.Equal(t, dest.NextFibN2, 6765)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue