Skip to content

Commit

Permalink
Fix unsafe branch collapsing of an outer JOIN (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
xitology authored Jul 3, 2024
1 parent 3abed0b commit a6a31bd
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 2 deletions.
110 changes: 110 additions & 0 deletions docs/src/test/nodes.md
Original file line number Diff line number Diff line change
Expand Up @@ -3286,6 +3286,116 @@ Nested subqueries that are combined with `Join` may fail to collapse.
) AS "location_2" ON ("person_2"."location_id" = "location_2"."location_id")
=#

An outer `Join` does not collapse its branches when doing so may change the
values of unmatched rows.

join_q(; left = false, right = false) =
From(:cohort1) |> Define(:n_cohort => 1) |> As(:cohort1) |>
Join(From(:cohort2) |> Define(:n_cohort => 1) |> As(:cohort2),
on = Get.cohort1.person_id .== Get.cohort2.person_id,
left = left,
right = right) |>
Select(:person_id => Fun.coalesce(Get.cohort1.person_id, Get.cohort2.person_id),
:n_cohort => Fun.coalesce(Get.cohort1.n_cohort, 0) .+ Fun.coalesce(Get.cohort2.n_cohort, 0)) |>
With(:cohort1 => From(person) |> Where(Get.year_of_birth .> 1970),
:cohort2 => From(person) |> Where(Get.year_of_birth .< 1990))

print(render(join_q()))
#=>
WITH "cohort1_1" ("person_id") AS (
SELECT "person_1"."person_id"
FROM "person" AS "person_1"
WHERE ("person_1"."year_of_birth" > 1970)
),
"cohort2_1" ("person_id") AS (
SELECT "person_2"."person_id"
FROM "person" AS "person_2"
WHERE ("person_2"."year_of_birth" < 1990)
)
SELECT
coalesce("cohort1_2"."person_id", "cohort2_2"."person_id") AS "person_id",
(coalesce(1, 0) + coalesce(1, 0)) AS "n_cohort"
FROM "cohort1_1" AS "cohort1_2"
JOIN "cohort2_1" AS "cohort2_2" ON ("cohort1_2"."person_id" = "cohort2_2"."person_id")
=#

print(render(join_q(left = true)))
#=>
WITH "cohort1_1" ("person_id") AS (
SELECT "person_1"."person_id"
FROM "person" AS "person_1"
WHERE ("person_1"."year_of_birth" > 1970)
),
"cohort2_1" ("person_id") AS (
SELECT "person_2"."person_id"
FROM "person" AS "person_2"
WHERE ("person_2"."year_of_birth" < 1990)
)
SELECT
coalesce("cohort1_2"."person_id", "cohort2_3"."person_id") AS "person_id",
(coalesce(1, 0) + coalesce("cohort2_3"."n_cohort", 0)) AS "n_cohort"
FROM "cohort1_1" AS "cohort1_2"
LEFT JOIN (
SELECT
"cohort2_2"."person_id",
1 AS "n_cohort"
FROM "cohort2_1" AS "cohort2_2"
) AS "cohort2_3" ON ("cohort1_2"."person_id" = "cohort2_3"."person_id")
=#

print(render(join_q(right = true)))
#=>
WITH "cohort1_1" ("person_id") AS (
SELECT "person_1"."person_id"
FROM "person" AS "person_1"
WHERE ("person_1"."year_of_birth" > 1970)
),
"cohort2_1" ("person_id") AS (
SELECT "person_2"."person_id"
FROM "person" AS "person_2"
WHERE ("person_2"."year_of_birth" < 1990)
)
SELECT
coalesce("cohort1_3"."person_id", "cohort2_2"."person_id") AS "person_id",
(coalesce("cohort1_3"."n_cohort", 0) + coalesce(1, 0)) AS "n_cohort"
FROM (
SELECT
"cohort1_2"."person_id",
1 AS "n_cohort"
FROM "cohort1_1" AS "cohort1_2"
) AS "cohort1_3"
RIGHT JOIN "cohort2_1" AS "cohort2_2" ON ("cohort1_3"."person_id" = "cohort2_2"."person_id")
=#

print(render(join_q(left = true, right = true)))
#=>
WITH "cohort1_1" ("person_id") AS (
SELECT "person_1"."person_id"
FROM "person" AS "person_1"
WHERE ("person_1"."year_of_birth" > 1970)
),
"cohort2_1" ("person_id") AS (
SELECT "person_2"."person_id"
FROM "person" AS "person_2"
WHERE ("person_2"."year_of_birth" < 1990)
)
SELECT
coalesce("cohort1_3"."person_id", "cohort2_3"."person_id") AS "person_id",
(coalesce("cohort1_3"."n_cohort", 0) + coalesce("cohort2_3"."n_cohort", 0)) AS "n_cohort"
FROM (
SELECT
"cohort1_2"."person_id",
1 AS "n_cohort"
FROM "cohort1_1" AS "cohort1_2"
) AS "cohort1_3"
FULL JOIN (
SELECT
"cohort2_2"."person_id",
1 AS "n_cohort"
FROM "cohort2_1" AS "cohort2_2"
) AS "cohort2_3" ON ("cohort1_3"."person_id" = "cohort2_3"."person_id")
=#

`Join` can be applied to correlated subqueries.

ql(person_id) =
Expand Down
7 changes: 5 additions & 2 deletions src/translate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -827,9 +827,12 @@ function assemble(n::PartitionNode, ctx)
Assemblage(base.name, c, cols = cols, repl = repl)
end

_outer_safe(a::Assemblage) =
all(@dissect(col, (nothing |> ID() |> ID())) for col in values(a.cols))

function assemble(n::RoutedJoinNode, ctx)
left = assemble(n.over, ctx)
if @dissect(left.clause, tail := FROM() || JOIN())
if @dissect(left.clause, tail := FROM() || JOIN()) && (!n.right || _outer_safe(left))
left_alias = nothing
else
left_alias = allocate_alias(ctx, left)
Expand All @@ -842,7 +845,7 @@ function assemble(n::RoutedJoinNode, ctx)
else
right = assemble(n.joinee, ctx)
end
if @dissect(right.clause, (joinee := (ID() || AS())) |> FROM())
if @dissect(right.clause, (joinee := (ID() || AS())) |> FROM()) && (!n.left || _outer_safe(right))
for (ref, name) in right.repl
subs[ref] = right.cols[name]
end
Expand Down

0 comments on commit a6a31bd

Please sign in to comment.