diff --git a/docs/src/test/nodes.md b/docs/src/test/nodes.md index 278eb910..91c2b410 100644 --- a/docs/src/test/nodes.md +++ b/docs/src/test/nodes.md @@ -3286,6 +3286,116 @@ Nested subqueries that are combined with `Join` may fail to collapse. ) AS "location_2" ON ("person_2"."location_id" = "location_2"."location_id") =# +An outer `Join` does not collapse its branches when doing so may change the +values of unmatched rows. + + join_q(; left = false, right = false) = + From(:cohort1) |> Define(:n_cohort => 1) |> As(:cohort1) |> + Join(From(:cohort2) |> Define(:n_cohort => 1) |> As(:cohort2), + on = Get.cohort1.person_id .== Get.cohort2.person_id, + left = left, + right = right) |> + Select(:person_id => Fun.coalesce(Get.cohort1.person_id, Get.cohort2.person_id), + :n_cohort => Fun.coalesce(Get.cohort1.n_cohort, 0) .+ Fun.coalesce(Get.cohort2.n_cohort, 0)) |> + With(:cohort1 => From(person) |> Where(Get.year_of_birth .> 1970), + :cohort2 => From(person) |> Where(Get.year_of_birth .< 1990)) + + print(render(join_q())) + #=> + WITH "cohort1_1" ("person_id") AS ( + SELECT "person_1"."person_id" + FROM "person" AS "person_1" + WHERE ("person_1"."year_of_birth" > 1970) + ), + "cohort2_1" ("person_id") AS ( + SELECT "person_2"."person_id" + FROM "person" AS "person_2" + WHERE ("person_2"."year_of_birth" < 1990) + ) + SELECT + coalesce("cohort1_2"."person_id", "cohort2_2"."person_id") AS "person_id", + (coalesce(1, 0) + coalesce(1, 0)) AS "n_cohort" + FROM "cohort1_1" AS "cohort1_2" + JOIN "cohort2_1" AS "cohort2_2" ON ("cohort1_2"."person_id" = "cohort2_2"."person_id") + =# + + print(render(join_q(left = true))) + #=> + WITH "cohort1_1" ("person_id") AS ( + SELECT "person_1"."person_id" + FROM "person" AS "person_1" + WHERE ("person_1"."year_of_birth" > 1970) + ), + "cohort2_1" ("person_id") AS ( + SELECT "person_2"."person_id" + FROM "person" AS "person_2" + WHERE ("person_2"."year_of_birth" < 1990) + ) + SELECT + coalesce("cohort1_2"."person_id", "cohort2_3"."person_id") AS "person_id", + (coalesce(1, 0) + coalesce("cohort2_3"."n_cohort", 0)) AS "n_cohort" + FROM "cohort1_1" AS "cohort1_2" + LEFT JOIN ( + SELECT + "cohort2_2"."person_id", + 1 AS "n_cohort" + FROM "cohort2_1" AS "cohort2_2" + ) AS "cohort2_3" ON ("cohort1_2"."person_id" = "cohort2_3"."person_id") + =# + + print(render(join_q(right = true))) + #=> + WITH "cohort1_1" ("person_id") AS ( + SELECT "person_1"."person_id" + FROM "person" AS "person_1" + WHERE ("person_1"."year_of_birth" > 1970) + ), + "cohort2_1" ("person_id") AS ( + SELECT "person_2"."person_id" + FROM "person" AS "person_2" + WHERE ("person_2"."year_of_birth" < 1990) + ) + SELECT + coalesce("cohort1_3"."person_id", "cohort2_2"."person_id") AS "person_id", + (coalesce("cohort1_3"."n_cohort", 0) + coalesce(1, 0)) AS "n_cohort" + FROM ( + SELECT + "cohort1_2"."person_id", + 1 AS "n_cohort" + FROM "cohort1_1" AS "cohort1_2" + ) AS "cohort1_3" + RIGHT JOIN "cohort2_1" AS "cohort2_2" ON ("cohort1_3"."person_id" = "cohort2_2"."person_id") + =# + + print(render(join_q(left = true, right = true))) + #=> + WITH "cohort1_1" ("person_id") AS ( + SELECT "person_1"."person_id" + FROM "person" AS "person_1" + WHERE ("person_1"."year_of_birth" > 1970) + ), + "cohort2_1" ("person_id") AS ( + SELECT "person_2"."person_id" + FROM "person" AS "person_2" + WHERE ("person_2"."year_of_birth" < 1990) + ) + SELECT + coalesce("cohort1_3"."person_id", "cohort2_3"."person_id") AS "person_id", + (coalesce("cohort1_3"."n_cohort", 0) + coalesce("cohort2_3"."n_cohort", 0)) AS "n_cohort" + FROM ( + SELECT + "cohort1_2"."person_id", + 1 AS "n_cohort" + FROM "cohort1_1" AS "cohort1_2" + ) AS "cohort1_3" + FULL JOIN ( + SELECT + "cohort2_2"."person_id", + 1 AS "n_cohort" + FROM "cohort2_1" AS "cohort2_2" + ) AS "cohort2_3" ON ("cohort1_3"."person_id" = "cohort2_3"."person_id") + =# + `Join` can be applied to correlated subqueries. ql(person_id) = diff --git a/src/translate.jl b/src/translate.jl index cfc63187..c53c8b74 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -827,9 +827,12 @@ function assemble(n::PartitionNode, ctx) Assemblage(base.name, c, cols = cols, repl = repl) end +_outer_safe(a::Assemblage) = + all(@dissect(col, (nothing |> ID() |> ID())) for col in values(a.cols)) + function assemble(n::RoutedJoinNode, ctx) left = assemble(n.over, ctx) - if @dissect(left.clause, tail := FROM() || JOIN()) + if @dissect(left.clause, tail := FROM() || JOIN()) && (!n.right || _outer_safe(left)) left_alias = nothing else left_alias = allocate_alias(ctx, left) @@ -842,7 +845,7 @@ function assemble(n::RoutedJoinNode, ctx) else right = assemble(n.joinee, ctx) end - if @dissect(right.clause, (joinee := (ID() || AS())) |> FROM()) + if @dissect(right.clause, (joinee := (ID() || AS())) |> FROM()) && (!n.left || _outer_safe(right)) for (ref, name) in right.repl subs[ref] = right.cols[name] end