diff --git a/docs/src/examples/index.md b/docs/src/examples/index.md index 088e1ea4..d05400fe 100644 --- a/docs/src/examples/index.md +++ b/docs/src/examples/index.md @@ -655,45 +655,45 @@ query. ("concept_1"."vocabulary_id" = 'SNOMED') AND ("concept_1"."concept_code" = '22298006') UNION ALL - SELECT "concept_3"."concept_id" - FROM "base_1" AS "concept_2" + SELECT "concept_2"."concept_id" + FROM "base_1" AS "base_2" JOIN ( SELECT "concept_relationship_1"."concept_id_1", "concept_relationship_1"."concept_id_2" FROM "concept_relationship" AS "concept_relationship_1" WHERE ("concept_relationship_1"."relationship_id" = 'Is a') - ) AS "concept_relationship_2" ON ("concept_2"."concept_id" = "concept_relationship_2"."concept_id_2") - JOIN "concept" AS "concept_3" ON ("concept_relationship_2"."concept_id_1" = "concept_3"."concept_id") + ) AS "concept_relationship_2" ON ("base_2"."concept_id" = "concept_relationship_2"."concept_id_2") + JOIN "concept" AS "concept_2" ON ("concept_relationship_2"."concept_id_1" = "concept_2"."concept_id") ), - "base_3" ("concept_id") AS ( - SELECT "concept_4"."concept_id" - FROM "concept" AS "concept_4" + "base_4" ("concept_id") AS ( + SELECT "concept_3"."concept_id" + FROM "concept" AS "concept_3" WHERE - ("concept_4"."vocabulary_id" = 'SNOMED') AND - ("concept_4"."concept_code" = '70422006') + ("concept_3"."vocabulary_id" = 'SNOMED') AND + ("concept_3"."concept_code" = '70422006') UNION ALL - SELECT "concept_6"."concept_id" - FROM "base_3" AS "concept_5" + SELECT "concept_4"."concept_id" + FROM "base_4" AS "base_5" JOIN ( SELECT "concept_relationship_3"."concept_id_1", "concept_relationship_3"."concept_id_2" FROM "concept_relationship" AS "concept_relationship_3" WHERE ("concept_relationship_3"."relationship_id" = 'Is a') - ) AS "concept_relationship_4" ON ("concept_5"."concept_id" = "concept_relationship_4"."concept_id_2") - JOIN "concept" AS "concept_6" ON ("concept_relationship_4"."concept_id_1" = "concept_6"."concept_id") + ) AS "concept_relationship_4" ON ("base_5"."concept_id" = "concept_relationship_4"."concept_id_2") + JOIN "concept" AS "concept_4" ON ("concept_relationship_4"."concept_id_1" = "concept_4"."concept_id") ) SELECT "condition_occurrence_1"."person_id", "condition_occurrence_1"."condition_start_date" FROM "condition_occurrence" AS "condition_occurrence_1" JOIN ( - SELECT "base_2"."concept_id" - FROM "base_1" AS "base_2" - LEFT JOIN "base_3" AS "base_4" ON ("base_2"."concept_id" = "base_4"."concept_id") - WHERE ("base_4"."concept_id" IS NULL) - ) AS "concept_7" ON ("condition_occurrence_1"."condition_concept_id" = "concept_7"."concept_id") + SELECT "base_3"."concept_id" + FROM "base_1" AS "base_3" + LEFT JOIN "base_4" AS "base_6" ON ("base_3"."concept_id" = "base_6"."concept_id") + WHERE ("base_6"."concept_id" IS NULL) + ) AS "base_7" ON ("condition_occurrence_1"."condition_concept_id" = "base_7"."concept_id") ORDER BY "condition_occurrence_1"."condition_occurrence_id" =# @@ -920,34 +920,34 @@ Now we have all the components to construct the final query: ("concept_1"."vocabulary_id" = 'SNOMED') AND ("concept_1"."concept_code" = '22298006') UNION ALL - SELECT "concept_3"."concept_id" - FROM "base_1" AS "concept_2" + SELECT "concept_2"."concept_id" + FROM "base_1" AS "base_2" JOIN ( SELECT "concept_relationship_1"."concept_id_1", "concept_relationship_1"."concept_id_2" FROM "concept_relationship" AS "concept_relationship_1" WHERE ("concept_relationship_1"."relationship_id" = 'Is a') - ) AS "concept_relationship_2" ON ("concept_2"."concept_id" = "concept_relationship_2"."concept_id_2") - JOIN "concept" AS "concept_3" ON ("concept_relationship_2"."concept_id_1" = "concept_3"."concept_id") + ) AS "concept_relationship_2" ON ("base_2"."concept_id" = "concept_relationship_2"."concept_id_2") + JOIN "concept" AS "concept_2" ON ("concept_relationship_2"."concept_id_1" = "concept_2"."concept_id") ), - "base_3" ("concept_id") AS ( - SELECT "concept_4"."concept_id" - FROM "concept" AS "concept_4" + "base_4" ("concept_id") AS ( + SELECT "concept_3"."concept_id" + FROM "concept" AS "concept_3" WHERE - ("concept_4"."vocabulary_id" = 'Visit') AND - ("concept_4"."concept_code" = 'IP') + ("concept_3"."vocabulary_id" = 'Visit') AND + ("concept_3"."concept_code" = 'IP') UNION ALL - SELECT "concept_6"."concept_id" - FROM "base_3" AS "concept_5" + SELECT "concept_4"."concept_id" + FROM "base_4" AS "base_5" JOIN ( SELECT "concept_relationship_3"."concept_id_1", "concept_relationship_3"."concept_id_2" FROM "concept_relationship" AS "concept_relationship_3" WHERE ("concept_relationship_3"."relationship_id" = 'Is a') - ) AS "concept_relationship_4" ON ("concept_5"."concept_id" = "concept_relationship_4"."concept_id_2") - JOIN "concept" AS "concept_6" ON ("concept_relationship_4"."concept_id_1" = "concept_6"."concept_id") + ) AS "concept_relationship_4" ON ("base_5"."concept_id" = "concept_relationship_4"."concept_id_2") + JOIN "concept" AS "concept_4" ON ("concept_relationship_4"."concept_id_1" = "concept_4"."concept_id") ) SELECT "condition_occurrence_3"."person_id", @@ -962,13 +962,13 @@ Now we have all the components to construct the final query: "condition_occurrence_1"."person_id", "condition_occurrence_1"."condition_start_date" FROM "condition_occurrence" AS "condition_occurrence_1" - JOIN "base_1" AS "base_2" ON ("condition_occurrence_1"."condition_concept_id" = "base_2"."concept_id") + JOIN "base_1" AS "base_3" ON ("condition_occurrence_1"."condition_concept_id" = "base_3"."concept_id") ORDER BY "condition_occurrence_1"."condition_occurrence_id" ) AS "condition_occurrence_2" WHERE (EXISTS ( SELECT NULL AS "_" FROM "visit_occurrence" AS "visit_occurrence_1" - JOIN "base_3" AS "base_4" ON ("visit_occurrence_1"."visit_concept_id" = "base_4"."concept_id") + JOIN "base_4" AS "base_6" ON ("visit_occurrence_1"."visit_concept_id" = "base_6"."concept_id") WHERE ("visit_occurrence_1"."person_id" = "condition_occurrence_2"."person_id") AND ("condition_occurrence_2"."condition_start_date" BETWEEN "visit_occurrence_1"."visit_start_date" AND "visit_occurrence_1"."visit_end_date") diff --git a/docs/src/guide/index.md b/docs/src/guide/index.md index a25d4e4e..6a918797 100644 --- a/docs/src/guide/index.md +++ b/docs/src/guide/index.md @@ -1283,13 +1283,13 @@ immediate subtypes as the function: DBInterface.execute(conn, q) |> DataFrame #=> - 2×7 DataFrame + 2×10 DataFrame Row │ concept_id concept_name domain_id vocabulary_id conc ⋯ │ Int64 String String String Stri ⋯ ─────┼────────────────────────────────────────────────────────────────────────── 1 │ 4329847 Myocardial infarction Condition SNOMED Clin ⋯ 2 │ 312327 Acute myocardial infarction Condition SNOMED Clin - 3 columns omitted + 6 columns omitted =# But how can we fetch not just immediate, but all of the subtypes of a concept? @@ -1317,7 +1317,7 @@ This is exactly the action of the [`Iterate`](@ref) node. render(conn, q) |> print #=> - WITH RECURSIVE "concept_2" ("concept_id", "concept_name", "domain_id", "vocabulary_id", "concept_class_id", "standard_concept", "concept_code") AS ( + WITH RECURSIVE "__1" ("concept_id", "concept_name", "domain_id", "vocabulary_id", "concept_class_id", "standard_concept", "concept_code", "valid_start_date", "valid_end_date", "invalid_reason") AS ( SELECT "concept_1"."concept_id", "concept_1"."concept_name", @@ -1325,42 +1325,54 @@ This is exactly the action of the [`Iterate`](@ref) node. "concept_1"."vocabulary_id", "concept_1"."concept_class_id", "concept_1"."standard_concept", - "concept_1"."concept_code" + "concept_1"."concept_code", + "concept_1"."valid_start_date", + "concept_1"."valid_end_date", + "concept_1"."invalid_reason" FROM "concept" AS "concept_1" WHERE ("concept_1"."concept_name" = 'Myocardial infarction') UNION ALL SELECT - "concept_3"."concept_id", - "concept_3"."concept_name", - "concept_3"."domain_id", - "concept_3"."vocabulary_id", - "concept_3"."concept_class_id", - "concept_3"."standard_concept", - "concept_3"."concept_code" - FROM "concept" AS "concept_3" + "concept_2"."concept_id", + "concept_2"."concept_name", + "concept_2"."domain_id", + "concept_2"."vocabulary_id", + "concept_2"."concept_class_id", + "concept_2"."standard_concept", + "concept_2"."concept_code", + "concept_relationship_2"."valid_start_date", + "concept_relationship_2"."valid_end_date", + "concept_relationship_2"."invalid_reason" + FROM "concept" AS "concept_2" JOIN ( SELECT + "concept_relationship_1"."valid_start_date", + "concept_relationship_1"."valid_end_date", + "concept_relationship_1"."invalid_reason", "concept_relationship_1"."concept_id_2", "concept_relationship_1"."concept_id_1" FROM "concept_relationship" AS "concept_relationship_1" WHERE ("concept_relationship_1"."relationship_id" = 'Is a') - ) AS "concept_relationship_2" ON ("concept_3"."concept_id" = "concept_relationship_2"."concept_id_1") - JOIN "concept_2" AS "concept_4" ON ("concept_relationship_2"."concept_id_2" = "concept_4"."concept_id") + ) AS "concept_relationship_2" ON ("concept_2"."concept_id" = "concept_relationship_2"."concept_id_1") + JOIN "__1" AS "__2" ON ("concept_relationship_2"."concept_id_2" = "__2"."concept_id") ) SELECT - "concept_5"."concept_id", - "concept_5"."concept_name", - "concept_5"."domain_id", - "concept_5"."vocabulary_id", - "concept_5"."concept_class_id", - "concept_5"."standard_concept", - "concept_5"."concept_code" - FROM "concept_2" AS "concept_5" + "concept_3"."concept_id", + "concept_3"."concept_name", + "concept_3"."domain_id", + "concept_3"."vocabulary_id", + "concept_3"."concept_class_id", + "concept_3"."standard_concept", + "concept_3"."concept_code", + "concept_3"."valid_start_date", + "concept_3"."valid_end_date", + "concept_3"."invalid_reason" + FROM "__1" AS "concept_3" =# DBInterface.execute(conn, q) |> DataFrame #=> - 6×7 DataFrame + 6×10 DataFrame Row │ concept_id concept_name domain_id vocabulary_id ⋯ │ Int64 String String String ⋯ ─────┼────────────────────────────────────────────────────────────────────────── @@ -1370,6 +1382,6 @@ This is exactly the action of the [`Iterate`](@ref) node. 4 │ 438170 Acute myocardial infarction of i… Condition SNOMED 5 │ 438438 Acute myocardial infarction of a… Condition SNOMED ⋯ 6 │ 444406 Acute subendocardial infarction Condition SNOMED - 3 columns omitted + 6 columns omitted =# diff --git a/docs/src/test/nodes.md b/docs/src/test/nodes.md index d70efccc..cdc776ba 100644 --- a/docs/src/test/nodes.md +++ b/docs/src/test/nodes.md @@ -87,6 +87,12 @@ Ill-formed queries are detected. end =# + q = From(person) |> Fun.current_date() + #=> + ERROR: FunSQL.RebaseError in: + Fun.current_date() + =# + ## `@funsql` @@ -397,11 +403,6 @@ Use backticks to represent a name that is not a valid identifier. end =# - q = q |> Where(Fun.">"(e, 2000)) - - e = Get(over = q, :person_id) - #-> (…) |> Get.person_id - q.person_id #-> (…) |> Get.person_id @@ -414,15 +415,6 @@ Use backticks to represent a name that is not a valid identifier. q["person_id"] #-> (…) |> Get.person_id - q = q |> Select(e) - - print(render(q)) - #=> - SELECT "person_1"."person_id" - FROM "person" AS "person_1" - WHERE ("person_1"."year_of_birth" > 2000) - =# - `Get` is used for dereferencing an alias created with `As`. q = From(person) |> @@ -452,23 +444,6 @@ This is particularly useful when you need to disambiguate the output of `Join`. JOIN "location" AS "location_1" ON ("person_1"."location_id" = "location_1"."location_id") =# -Alternatively, node-bound references could be used for this purpose. - - qₚ = From(person) - qₗ = From(location) - q = qₚ |> - Join(qₗ, on = qₚ.location_id .== qₗ.location_id) |> - Select(qₚ.person_id, qₗ.state) - - print(render(q)) - #=> - SELECT - "person_1"."person_id", - "location_1"."state" - FROM "person" AS "person_1" - JOIN "location" AS "location_1" ON ("person_1"."location_id" = "location_1"."location_id") - =# - When `Get` refers to an unknown attribute, an error is reported. q = Select(Get.person_id) @@ -493,8 +468,8 @@ When `Get` refers to an unknown attribute, an error is reported. end =# -An error is also reported when a `Get` reference cannot be resolved -unambiguously. +An attribute defined in a `Join` shadows any previously defined attributes +with the same name. q = person |> Join(person, true) |> @@ -502,14 +477,9 @@ unambiguously. print(render(q)) #=> - ERROR: FunSQL.ReferenceError: `person_id` is ambiguous in: - let person = SQLTable(:person, …), - q1 = From(person), - q2 = From(person), - q3 = q1 |> Join(q2, true), - q4 = q3 |> Select(Get.person_id) - q4 - end + SELECT "person_2"."person_id" + FROM "person" AS "person_1" + CROSS JOIN "person" AS "person_2" =# An incomplete hierarchical reference, as well as an unexpected hierarchical @@ -542,45 +512,17 @@ reference, will result in an error. end =# -A node-bound reference that is bound to an unrelated node will cause an error. - - q = (qₚ = From(person)) |> - Join(:location => From(location) |> Where(qₚ.year_of_birth .>= 1950), - on = Get.location_id .== Get.location.location_id) - - print(render(q)) - #=> - ERROR: FunSQL.ReferenceError: node-bound reference failed to resolve in: - let person = SQLTable(:person, …), - location = SQLTable(:location, …), - q1 = From(person), - q2 = From(location), - q3 = q2 |> Where(Fun.">="(q1.year_of_birth, 1950)), - q4 = q1 |> - Join(q3 |> As(:location), - Fun."="(Get.location_id, Get.location.location_id)) - q4 - end - =# - -A node-bound reference which cannot be resolved unambiguously will also cause -an error. +A reference bound to any node other than `Get` will cause an error. - q = (qₚ = From(person)) |> - Join(:another => qₚ, - on = Get.person_id .!= Get.another.person_id) |> - Select(qₚ.person_id) + q = (qₚ = From(person)) |> Select(qₚ.person_id) print(render(q)) #=> - ERROR: FunSQL.ReferenceError: node-bound reference is ambiguous in: + ERROR: FunSQL.IllFormedError in: let person = SQLTable(:person, …), q1 = From(person), - q2 = q1 |> - Join(q1 |> As(:another), - Fun."<>"(Get.person_id, Get.another.person_id)), - q3 = q2 |> Select(q1.person_id) - q3 + q2 = q1 |> Select(q1.person_id) + q2 end =# @@ -1632,12 +1574,12 @@ produced by the base query and the iterator query. WITH RECURSIVE "previous_1" ("m") AS ( SELECT 0 AS "m" UNION ALL - SELECT ("union_1"."m" + 1) AS "m" - FROM "previous_1" AS "union_1" - WHERE ("union_1"."m" < 10) + SELECT ("previous_2"."m" + 1) AS "m" + FROM "previous_1" AS "previous_2" + WHERE ("previous_2"."m" < 10) ) - SELECT "previous_2"."m" - FROM "previous_1" AS "previous_2" + SELECT "previous_3"."m" + FROM "previous_1" AS "previous_3" =# `Iterate` aligns the columns of its subqueries. @@ -1723,18 +1665,6 @@ The `=>` shorthand is supported by `@funsql`. FROM "person" AS "person_1" =# -`As` does not block node-bound references. - - q = (qₚ = From(person)) |> - As(:p) |> - Select(qₚ.person_id) - - print(render(q)) - #=> - SELECT "person_1"."person_id" - FROM "person" AS "person_1" - =# - ## `From` @@ -2604,48 +2534,45 @@ It is an error for an aggregate expression to be used without `Group`. end =# -It is also an error when an aggregate expression cannot determine its `Group` -unambiguously. +`Group` in a `Join` expression shadows any previous applications of `Group`. qₚ = From(person) - qᵥ = From(visit_occurrence) |> Group(Get.person_id) - qₘ = From(measurement) |> Group(Get.person_id) + qᵥ = From(visit_occurrence) |> Group(:visit_person_id => Get.person_id) + qₘ = From(measurement) |> Group(:measurement_person_id => Get.person_id) q = qₚ |> - Join(qᵥ, on = qₚ.person_id .== qᵥ.person_id, left = true) |> - Join(qₘ, on = qₚ.person_id .== qₘ.person_id, left = true) |> - Select(qₚ.person_id, :count => Fun.coalesce(Agg.count(), 0)) + Join(qᵥ, on = Get.person_id .== Get.visit_person_id, left = true) |> + Join(qₘ, on = Get.person_id .== Get.measurement_person_id, left = true) |> + Select(Get.person_id, :count => Fun.coalesce(Agg.count(), 0)) print(render(q)) #=> - ERROR: FunSQL.ReferenceError: aggregate expression is ambiguous in: - let person = SQLTable(:person, …), - visit_occurrence = SQLTable(:visit_occurrence, …), - measurement = SQLTable(:measurement, …), - q1 = From(person), - q2 = From(visit_occurrence), - q3 = Get.person_id, - q4 = q2 |> Group(q3), - q5 = q1 |> Join(q4, Fun."="(q1.person_id, q4.person_id), left = true), - q6 = From(measurement), - q7 = Get.person_id, - q8 = q6 |> Group(q7), - q9 = q5 |> Join(q8, Fun."="(q1.person_id, q8.person_id), left = true), - q10 = q9 |> - Select(q1.person_id, Fun.coalesce(Agg.count(), 0) |> As(:count)) - q10 - end + SELECT + "person_1"."person_id", + coalesce("measurement_2"."count", 0) AS "count" + FROM "person" AS "person_1" + LEFT JOIN ( + SELECT DISTINCT "visit_occurrence_1"."person_id" AS "visit_person_id" + FROM "visit_occurrence" AS "visit_occurrence_1" + ) AS "visit_occurrence_2" ON ("person_1"."person_id" = "visit_occurrence_2"."visit_person_id") + LEFT JOIN ( + SELECT + count(*) AS "count", + "measurement_1"."person_id" AS "measurement_person_id" + FROM "measurement" AS "measurement_1" + GROUP BY "measurement_1"."person_id" + ) AS "measurement_2" ON ("person_1"."person_id" = "measurement_2"."measurement_person_id") =# It is still possible to use an aggregate in the context of a Join when the corresponding `Group` could be determined unambiguously. qₚ = From(person) - qᵥ = From(visit_occurrence) |> Group(Get.person_id) + qᵥ = From(visit_occurrence) |> Group(:visit_person_id => Get.person_id) q = qₚ |> - Join(qᵥ, on = qₚ.person_id .== qᵥ.person_id, left = true) |> - Select(qₚ.person_id, :count => Fun.coalesce(Agg.count(), 0)) + Join(qᵥ, on = Get.person_id .== Get.visit_person_id, left = true) |> + Select(Get.person_id, :count => Fun.coalesce(Agg.count(), 0)) print(render(q)) #=> @@ -2656,10 +2583,10 @@ corresponding `Group` could be determined unambiguously. LEFT JOIN ( SELECT count(*) AS "count", - "visit_occurrence_1"."person_id" + "visit_occurrence_1"."person_id" AS "visit_person_id" FROM "visit_occurrence" AS "visit_occurrence_1" GROUP BY "visit_occurrence_1"."person_id" - ) AS "visit_occurrence_2" ON ("person_1"."person_id" = "visit_occurrence_2"."person_id") + ) AS "visit_occurrence_2" ON ("person_1"."person_id" = "visit_occurrence_2"."visit_person_id") =# @@ -3703,79 +3630,48 @@ Consider the following query. :max_visit_start_date => Get.visit_group |> Agg.max(Get.visit_start_date)) -At the first stage of the translation, `render()` augments the query object -with some additional nodes. A `Box` node is inserted in front of each -tabular node and hierarchical `Get` nodes are reversed. +At the first stage of the translation, `render()` resolves table references +and determines node types. #? VERSION >= v"1.7" # https://github.com/JuliaLang/julia/issues/26798 - withenv("JULIA_DEBUG" => "FunSQL.annotate") do - render(q) - end; - #=> - ┌ Debug: FunSQL.annotate - │ let person = SQLTable(:person, …), - │ location = SQLTable(:location, …), - │ visit_occurrence = SQLTable(:visit_occurrence, …), - │ q1 = FromTable(table = person), - │ q2 = q1 |> Box(), - ⋮ - │ q21 = q20 |> - │ Select(Get.person_id, - │ NameBound(over = Agg.max(Get.visit_start_date), - │ name = :visit_group) |> - │ As(:max_visit_start_date)), - │ q22 = q21 |> Box() - │ q22 - │ end - └ @ FunSQL … - =# - -Next, `render()` determines the type of each tabular node and attaches -it to the corresponding `Box` node. - - #? VERSION >= v"1.7" - withenv("JULIA_DEBUG" => "FunSQL.resolve!") do + withenv("JULIA_DEBUG" => "FunSQL.resolve") do render(q) end; #=> - ┌ Debug: FunSQL.resolve! + ┌ Debug: FunSQL.resolve │ let person = SQLTable(:person, …), │ location = SQLTable(:location, …), │ visit_occurrence = SQLTable(:visit_occurrence, …), │ q1 = FromTable(table = person), - │ q2 = q1 |> - │ Box(type = BoxType(:person, - │ :person_id => ScalarType(), - │ :gender_concept_id => ScalarType(), - │ :year_of_birth => ScalarType(), - │ :month_of_birth => ScalarType(), - │ :day_of_birth => ScalarType(), - │ :birth_datetime => ScalarType(), - │ :location_id => ScalarType())), + │ q2 = Resolved(RowType(:person_id => ScalarType(), + │ :gender_concept_id => ScalarType(), + │ :year_of_birth => ScalarType(), + │ :month_of_birth => ScalarType(), + │ :day_of_birth => ScalarType(), + │ :birth_datetime => ScalarType(), + │ :location_id => ScalarType()), + │ over = q1) |> + │ Where(Resolved(ScalarType(), + │ over = Fun."<="(Resolved(ScalarType(), + │ over = Get.year_of_birth), + │ Resolved(ScalarType(), over = 2000)))), ⋮ - │ q21 = q20 |> - │ Select(Get.person_id, - │ NameBound(over = Agg.max(Get.visit_start_date), - │ name = :visit_group) |> - │ As(:max_visit_start_date)), - │ q22 = q21 |> - │ Box(type = BoxType(:person, - │ :person_id => ScalarType(), - │ :max_visit_start_date => ScalarType())) - │ q22 + │ WithContext(over = Resolved(RowType(:person_id => ScalarType(), + │ :max_visit_start_date => ScalarType()), + │ over = q9)) │ end └ @ FunSQL … =# -Next, `render()` validates column references and aggregate functions -and determine the columns to be provided by each tabular query. +Next, `render()` determines, for each tabular node, the data that it must +produce. #? VERSION >= v"1.7" - withenv("JULIA_DEBUG" => "FunSQL.link!") do + withenv("JULIA_DEBUG" => "FunSQL.link") do render(q) end; #=> - ┌ Debug: FunSQL.link! + ┌ Debug: FunSQL.link │ let person = SQLTable(:person, …), │ location = SQLTable(:location, …), │ visit_occurrence = SQLTable(:visit_occurrence, …), @@ -3784,25 +3680,9 @@ and determine the columns to be provided by each tabular query. │ q3 = Get.person_id, │ q4 = Get.location_id, │ q5 = Get.year_of_birth, - │ q6 = q1 |> - │ Box(type = BoxType(:person, - │ :person_id => ScalarType(), - │ :gender_concept_id => ScalarType(), - │ :year_of_birth => ScalarType(), - │ :month_of_birth => ScalarType(), - │ :day_of_birth => ScalarType(), - │ :birth_datetime => ScalarType(), - │ :location_id => ScalarType()), - │ refs = [q2, q3, q4, q5], - │ imm_refs_begin_at = 4), + │ q6 = Linked([q2, q3, q4, q5], 3, over = q1), ⋮ - │ q34 = q33 |> Select(q2, q29 |> As(:max_visit_start_date)), - │ q35 = q34 |> - │ Box(type = BoxType(:person, - │ :person_id => ScalarType(), - │ :max_visit_start_date => ScalarType()), - │ refs = [Get.person_id, Get.max_visit_start_date]) - │ q35 + │ WithContext(over = q33) │ end └ @ FunSQL … =# @@ -3815,68 +3695,73 @@ On the next stage, the query object is converted to a SQL syntax tree. end; #=> ┌ Debug: FunSQL.translate - │ ID(:person) |> - │ AS(:person_1) |> - │ FROM() |> - │ WHERE(FUN("<=", ID(:person_1) |> ID(:year_of_birth), LIT(2000))) |> - │ SELECT(ID(:person_1) |> ID(:person_id), ID(:person_1) |> ID(:location_id)) |> - │ AS(:person_2) |> - │ FROM() |> - │ JOIN(ID(:location) |> - │ AS(:location_1) |> - │ FROM() |> - │ WHERE(FUN("=", ID(:location_1) |> ID(:state), LIT("IL"))) |> - │ SELECT(ID(:location_1) |> ID(:location_id)) |> - │ AS(:location_2), - │ FUN("=", - │ ID(:person_2) |> ID(:location_id), - │ ID(:location_2) |> ID(:location_id))) |> - │ JOIN(ID(:visit_occurrence) |> - │ AS(:visit_occurrence_1) |> - │ FROM() |> - │ GROUP(ID(:visit_occurrence_1) |> ID(:person_id)) |> - │ SELECT(AGG("max", ID(:visit_occurrence_1) |> ID(:visit_start_date)) |> - │ AS(:max), - │ ID(:visit_occurrence_1) |> ID(:person_id)) |> - │ AS(:visit_group_1), - │ FUN("=", - │ ID(:person_2) |> ID(:person_id), - │ ID(:visit_group_1) |> ID(:person_id)), - │ left = true) |> - │ SELECT(ID(:person_2) |> ID(:person_id), - │ ID(:visit_group_1) |> ID(:max) |> AS(:max_visit_start_date)) + │ WITH_CONTEXT( + │ over = ID(:person) |> + │ AS(:person_1) |> + │ FROM() |> + │ WHERE(FUN("<=", ID(:person_1) |> ID(:year_of_birth), LIT(2000))) |> + │ SELECT(ID(:person_1) |> ID(:person_id), + │ ID(:person_1) |> ID(:location_id)) |> + │ AS(:person_2) |> + │ FROM() |> + │ JOIN(ID(:location) |> + │ AS(:location_1) |> + │ FROM() |> + │ WHERE(FUN("=", ID(:location_1) |> ID(:state), LIT("IL"))) |> + │ SELECT(ID(:location_1) |> ID(:location_id)) |> + │ AS(:location_2), + │ FUN("=", + │ ID(:person_2) |> ID(:location_id), + │ ID(:location_2) |> ID(:location_id))) |> + │ JOIN(ID(:visit_occurrence) |> + │ AS(:visit_occurrence_1) |> + │ FROM() |> + │ GROUP(ID(:visit_occurrence_1) |> ID(:person_id)) |> + │ SELECT(AGG("max", + │ ID(:visit_occurrence_1) |> ID(:visit_start_date)) |> + │ AS(:max), + │ ID(:visit_occurrence_1) |> ID(:person_id)) |> + │ AS(:visit_group_1), + │ FUN("=", + │ ID(:person_2) |> ID(:person_id), + │ ID(:visit_group_1) |> ID(:person_id)), + │ left = true) |> + │ SELECT(ID(:person_2) |> ID(:person_id), + │ ID(:visit_group_1) |> ID(:max) |> AS(:max_visit_start_date))) └ @ FunSQL … =# Finally, the SQL tree is serialized into SQL. #? VERSION >= v"1.7" - withenv("JULIA_DEBUG" => "FunSQL.render") do + withenv("JULIA_DEBUG" => "FunSQL.serialize") do render(q) end; #=> - ┌ Debug: FunSQL.render - │ SELECT - │ "person_2"."person_id", - │ "visit_group_1"."max" AS "max_visit_start_date" - │ FROM ( - │ SELECT - │ "person_1"."person_id", - │ "person_1"."location_id" - │ FROM "person" AS "person_1" - │ WHERE ("person_1"."year_of_birth" <= 2000) - │ ) AS "person_2" - │ JOIN ( - │ SELECT "location_1"."location_id" - │ FROM "location" AS "location_1" - │ WHERE ("location_1"."state" = 'IL') - │ ) AS "location_2" ON ("person_2"."location_id" = "location_2"."location_id") - │ LEFT JOIN ( - │ SELECT - │ max("visit_occurrence_1"."visit_start_date") AS "max", - │ "visit_occurrence_1"."person_id" - │ FROM "visit_occurrence" AS "visit_occurrence_1" - │ GROUP BY "visit_occurrence_1"."person_id" - │ ) AS "visit_group_1" ON ("person_2"."person_id" = "visit_group_1"."person_id") + ┌ Debug: FunSQL.serialize + │ SQLString( + │ """ + │ SELECT + │ "person_2"."person_id", + │ "visit_group_1"."max" AS "max_visit_start_date" + │ FROM ( + │ SELECT + │ "person_1"."person_id", + │ "person_1"."location_id" + │ FROM "person" AS "person_1" + │ WHERE ("person_1"."year_of_birth" <= 2000) + │ ) AS "person_2" + │ JOIN ( + │ SELECT "location_1"."location_id" + │ FROM "location" AS "location_1" + │ WHERE ("location_1"."state" = 'IL') + │ ) AS "location_2" ON ("person_2"."location_id" = "location_2"."location_id") + │ LEFT JOIN ( + │ SELECT + │ max("visit_occurrence_1"."visit_start_date") AS "max", + │ "visit_occurrence_1"."person_id" + │ FROM "visit_occurrence" AS "visit_occurrence_1" + │ GROUP BY "visit_occurrence_1"."person_id" + │ ) AS "visit_group_1" ON ("person_2"."person_id" = "visit_group_1"."person_id")""") └ @ FunSQL … =# diff --git a/src/FunSQL.jl b/src/FunSQL.jl index 2b1f24f3..fd2e7e5f 100644 --- a/src/FunSQL.jl +++ b/src/FunSQL.jl @@ -98,12 +98,14 @@ include("dissect.jl") include("quote.jl") include("strings.jl") include("dialects.jl") +include("types.jl") include("catalogs.jl") include("clauses.jl") include("nodes.jl") include("connections.jl") -include("types.jl") -include("annotate.jl") +#include("annotate.jl") +include("resolve.jl") +include("link.jl") include("translate.jl") include("serialize.jl") include("render.jl") diff --git a/src/annotate.jl b/src/annotate.jl deleted file mode 100644 index 0b1097e4..00000000 --- a/src/annotate.jl +++ /dev/null @@ -1,1306 +0,0 @@ -# Rewriting the node graph to prepare it for translation. - - -# Auxiliary nodes. - -# A SQL subquery with an undetermined SELECT args. -mutable struct BoxNode <: TabularNode - over::Union{SQLNode, Nothing} - type::BoxType - handle::Int - refs::Vector{SQLNode} - imm_refs_begin_at::Union{Int, Nothing} - - BoxNode(; over = nothing, type = EMPTY_BOX, handle = 0, refs = SQLNode[], imm_refs_begin_at = nothing) = - new(over, type, handle, refs, imm_refs_begin_at) -end - -Box(args...; kws...) = - BoxNode(args...; kws...) |> SQLNode - -dissect(scr::Symbol, ::typeof(Box), pats::Vector{Any}) = - dissect(scr, BoxNode, pats) - -function PrettyPrinting.quoteof(n::BoxNode, ctx::QuoteContext) - ex = Expr(:call, nameof(Box)) - if !ctx.limit - if n.type !== EMPTY_BOX - push!(ex.args, Expr(:kw, :type, quoteof(n.type))) - end - if n.handle != 0 - push!(ex.args, Expr(:kw, :handle, n.handle)) - end - if !isempty(n.refs) - push!(ex.args, Expr(:kw, :refs, Expr(:vect, quoteof(n.refs, ctx)...))) - end - if n.imm_refs_begin_at !== nothing - push!(ex.args, Expr(:kw, :imm_refs_begin_at, n.imm_refs_begin_at)) - end - else - push!(ex.args, :…) - end - if n.over !== nothing - ex = Expr(:call, :|>, quoteof(n.over, ctx), ex) - end - ex -end - -label(n::BoxNode) = - n.type.name - -rebase(n::BoxNode, n′) = - BoxNode(over = rebase(n.over, n′), - type = n.type, handle = n.handle, refs = n.refs) - -box_type(n::BoxNode) = - n.type - -box_type(n::SQLNode) = - box_type(n[]::BoxNode) - -function reset!(n::BoxNode) - empty!(n.refs) - n.imm_refs_begin_at = nothing -end - -function begin_imm_refs!(n::BoxNode) - n.imm_refs_begin_at = length(n.refs) + 1 -end - -# Get(over = Get(:a), name = :b) => NameBound(over = Get(:b), name = :a) -mutable struct NameBoundNode <: AbstractSQLNode - over::SQLNode - name::Symbol - - NameBoundNode(; over, name) = - new(over, name) -end - -NameBound(args...; kws...) = - NameBoundNode(args...; kws...) |> SQLNode - -dissect(scr::Symbol, ::typeof(NameBound), pats::Vector{Any}) = - dissect(scr, NameBoundNode, pats) - -PrettyPrinting.quoteof(n::NameBoundNode, ctx::QuoteContext) = - Expr(:call, nameof(NameBound), Expr(:kw, :over, quoteof(n.over, ctx)), Expr(:kw, :name, QuoteNode(n.name))) - -# Get(over = q, name = :b) => HandleBound(over = Get(:b), handle = get_handle(q)) -mutable struct HandleBoundNode <: AbstractSQLNode - over::SQLNode - handle::Int - - HandleBoundNode(; over, handle) = - new(over, handle) -end - -HandleBound(args...; kws...) = - HandleBoundNode(args...; kws...) |> SQLNode - -dissect(scr::Symbol, ::typeof(HandleBound), pats::Vector{Any}) = - dissect(scr, HandleBoundNode, pats) - -PrettyPrinting.quoteof(n::HandleBoundNode, ctx::QuoteContext) = - Expr(:call, nameof(NameBound), Expr(:kw, :over, quoteof(n.over, ctx)), Expr(:kw, :handle, n.handle)) - -# A generic From node is specialized to FromNothing, FromTable, -# FromReference, FromSelf, FromValues, or FromFunction. -mutable struct FromNothingNode <: TabularNode -end - -FromNothing(args...; kws...) = - FromNothingNode(args...; kws...) |> SQLNode - -PrettyPrinting.quoteof(::FromNothingNode, ::QuoteContext) = - Expr(:call, nameof(FromNothing)) - -mutable struct FromTableNode <: TabularNode - table::SQLTable - - FromTableNode(; table) = - new(table) -end - -FromTable(args...; kws...) = - FromTableNode(args...; kws...) |> SQLNode - -function PrettyPrinting.quoteof(n::FromTableNode, ctx::QuoteContext) - tex = get(ctx.vars, n.table, nothing) - if tex === nothing - tex = quoteof(n.table, limit = true) - end - Expr(:call, nameof(FromTable), Expr(:kw, :table, tex)) -end - -mutable struct FromReferenceNode <: TabularNode - over::SQLNode - name::Symbol - - FromReferenceNode(; over, name) = - new(over, name) -end - -FromReference(args...; kws...) = - FromReferenceNode(args...; kws...) |> SQLNode - -PrettyPrinting.quoteof(n::FromReferenceNode, ctx::QuoteContext) = - Expr(:call, - nameof(FromReference), - Expr(:kw, :over, quoteof(n.over, ctx)), - Expr(:kw, :name, QuoteNode(n.name))) - -mutable struct FromSelfNode <: TabularNode - over::SQLNode - - FromSelfNode(; over) = - new(over) -end - -FromSelf(args...; kws...) = - FromSelfNode(args...; kws...) |> SQLNode - -PrettyPrinting.quoteof(n::FromSelfNode, ctx::QuoteContext) = - Expr(:call, nameof(FromSelf), Expr(:kw, :over, quoteof(n.over, ctx))) - -mutable struct FromValuesNode <: TabularNode - columns::NamedTuple - - FromValuesNode(; columns) = - new(columns) -end - -FromValues(args...; kws...) = - FromValuesNode(args...; kws...) |> SQLNode - -PrettyPrinting.quoteof(n::FromValuesNode, ctx::QuoteContext) = - Expr(:call, nameof(FromValues), Expr(:kw, :columns, quoteof(n.columns, ctx))) - -mutable struct FromFunctionNode <: TabularNode - over::SQLNode - columns::Vector{Symbol} - - FromFunctionNode(; over, columns) = - new(over, columns) -end - -FromFunction(args...; kws...) = - FromFunctionNode(args...; kws...) |> SQLNode - -PrettyPrinting.quoteof(n::FromFunctionNode, ctx::QuoteContext) = - Expr(:call, - nameof(FromFunction), - Expr(:kw, :over, quoteof(n.over, ctx)), - Expr(:kw, :columns, Expr(:vect, [QuoteNode(col) for col in n.columns]...))) - -# Annotated Bind node. -mutable struct IntBindNode <: AbstractSQLNode - over::Union{SQLNode, Nothing} - args::Vector{SQLNode} - label_map::OrderedDict{Symbol, Int} - owned::Bool # Did we find the outer query for this node? - - function IntBindNode(; over = nothing, args, label_map = nothing, owned = false) - if label_map !== nothing - new(over, args, label_map, owned) - else - n = new(over, args, OrderedDict{Symbol, Int}(), owned) - populate_label_map!(n) - n - end - end -end - -IntBind(args...; kws...) = - IntBindNode(args...; kws...) |> SQLNode - -function PrettyPrinting.quoteof(n::IntBindNode, ctx::QuoteContext) - ex = Expr(:call, nameof(IntBind)) - push!(ex.args, Expr(:kw, :args, Expr(:vect, quoteof(n.args, ctx)...))) - push!(ex.args, Expr(:kw, :owned, n.owned)) - if n.over !== nothing - ex = Expr(:call, :|>, quoteof(n.over, ctx), ex) - end - ex -end - -rebase(n::IntBindNode, n′) = - IntBindNode(over = rebase(n.over, n′), - args = n.args, label_map = n.label_map, owned = n.owned) - -# A recursive UNION ALL node. -mutable struct KnotNode <: TabularNode - over::Union{SQLNode, Nothing} - name::Symbol - box::BoxNode - iterator::SQLNode - iterator_boxes::Vector{BoxNode} - - KnotNode(; over = nothing, iterator, name = label(iterator), iterator_boxes = SQLNode[], box) = - new(over, name, box, iterator, iterator_boxes) -end - -KnotNode(iterator; over = nothing, box) = - KnotNode(over = over, iterator = iterator, box = box) - -Knot(args...; kws...) = - KnotNode(args...; kws...) |> SQLNode - -function PrettyPrinting.quoteof(n::KnotNode, ctx::QuoteContext) - ex = Expr(:call, nameof(Knot)) - if !ctx.limit - push!(ex.args, quoteof(n.iterator, ctx)) - else - push!(ex.args, :…) - end - push!(ex.args, Expr(:kw, :name, QuoteNode(n.name))) - if !ctx.limit - box_ex = Expr(:ref, quoteof(SQLNode(n.box), ctx)) - push!(ex.args, Expr(:kw, :box, box_ex)) - push!(ex.args, Expr(:kw, :iterator, quoteof(n.iterator, ctx))) - iterator_boxes_ex = - Expr(:vect, Any[Expr(:ref, quoteof(SQLNode(iterator_box), ctx)) - for iterator_box in n.iterator_boxes]...) - push!(ex.args, Expr(:kw, :iterator_boxes, iterator_boxes_ex)) - else - push!(ex.args, :…) - end - if n.over !== nothing - ex = Expr(:call, :|>, quoteof(n.over, ctx), ex) - end - ex -end - -label(n::KnotNode) = - n.name - -rebase(n::KnotNode, n′) = - KnotNode(over = rebase(n.over, n′), - name = n.name, box = n.box, iterator = n.iterator, iterator_boxes = n.iterator_boxes) - -# Iterate node is split into Knot and IntIterate. -mutable struct IntIterateNode <: TabularNode - over::Union{SQLNode, Nothing} - name::Symbol - - IntIterateNode(; over = nothing, name) = - new(over, name) -end - -IntIterate(args...; kws...) = - IntIterateNode(args...; kws...) |> SQLNode - -function PrettyPrinting.quoteof(n::IntIterateNode, ctx::QuoteContext) - ex = Expr(:call, nameof(IntIterate)) - push!(ex.args, Expr(:kw, :name, QuoteNode(n.name))) - if n.over !== nothing - ex = Expr(:call, :|>, quoteof(n.over, ctx), ex) - end - ex -end - -label(n::IntIterateNode) = - n.name - -rebase(n::IntIterateNode, n′) = - IntIterateNode(over = rebase(n.over, n′), name = n.name) - -# Annotated Join node. -mutable struct IntJoinNode <: TabularNode - over::Union{SQLNode, Nothing} - joinee::SQLNode - on::SQLNode - left::Bool - right::Bool - skip::Bool - type::BoxType # Type of the product of `over` and `joinee`. - lateral::Vector{SQLNode} # References from `joinee` to `over` for JOIN LATERAL. - - IntJoinNode(; over, joinee, on, left, right, skip, type = EMPTY_BOX, lateral = SQLNode[]) = - new(over, joinee, on, left, right, skip, type, lateral) -end - -IntJoinNode(joinee, on; over = nothing, left = false, right = false, skip = skip, type = EMPTY_BOX, lateral = SQLNode[]) = - IntJoinNode(over = over, joinee = joinee, on = on, left = left, right = right, skip = skip, type = type, lateral = lateral) - -IntJoin(args...; kws...) = - IntJoinNode(args...; kws...) |> SQLNode - -function PrettyPrinting.quoteof(n::IntJoinNode, ctx::QuoteContext) - ex = Expr(:call, nameof(IntJoin)) - if !ctx.limit - push!(ex.args, quoteof(n.joinee, ctx)) - push!(ex.args, quoteof(n.on, ctx)) - if n.left - push!(ex.args, Expr(:kw, :left, n.left)) - end - if n.right - push!(ex.args, Expr(:kw, :right, n.right)) - end - if n.skip - push!(ex.args, Expr(:kw, :skip, n.skip)) - end - if n.type !== EMPTY_BOX - push!(ex.args, Expr(:kw, :type, n.type)) - end - if !isempty(n.lateral) - push!(ex.args, Expr(:kw, :lateral, Expr(:vect, quoteof(n.lateral, ctx)...))) - end - else - push!(ex.args, :…) - end - if n.over !== nothing - ex = Expr(:call, :|>, quoteof(n.over, ctx), ex) - end - ex -end - -rebase(n::IntJoinNode, n′) = - IntJoinNode(over = rebase(n.over, n′), - joinee = n.joinee, on = n.on, left = n.left, right = n.right, skip = n.skip, type = n.type, lateral = n.lateral) - -# Calculates the keys of a Group node. -mutable struct IntAutoDefineNode <: TabularNode - over::Union{SQLNode, Nothing} - - IntAutoDefineNode(; over = nothing) = - new(over) -end - -IntAutoDefine(args...; kws...) = - IntAutoDefineNode(args...; kws...) |> SQLNode - -function PrettyPrinting.quoteof(n::IntAutoDefineNode, ctx::QuoteContext) - ex = Expr(:call, nameof(IntAutoDefine)) - if n.over !== nothing - ex = Expr(:call, :|>, quoteof(n.over, ctx), ex) - end - ex -end - -rebase(n::IntAutoDefineNode, n′) = - IntAutoDefineNode(over = rebase(n.over, n′)) - -label(n::Union{NameBoundNode, HandleBoundNode, IntAutoDefineNode, IntBindNode, IntJoinNode}) = - label(n.over) - - -# Annotation context. - -# Maps a node in the annotated graph to a path in the original graph (for error reporting). -struct PathMap - paths::Vector{Tuple{SQLNode, Int}} - origins::IdDict{Any, Int} - - PathMap() = - new(Tuple{SQLNode, Int}[], IdDict{Any, Int}()) -end - -function get_path(map::PathMap, idx::Int) - path = SQLNode[] - while idx != 0 - n, idx = map.paths[idx] - push!(path, n) - end - path -end - -get_path(map::PathMap, n) = - get_path(map, get(map.origins, n, 0)) - -struct AnnotateContext - catalog::SQLCatalog - path_map::PathMap - current_path::Vector{Int} - handles::Dict{SQLNode, Int} - boxes::Vector{BoxNode} - with_nodes::Dict{Symbol, SQLNode} - knot_node::Union{KnotNode, Nothing} - over_knot::Bool - - AnnotateContext(catalog) = - new(catalog, - PathMap(), - Int[0], - Dict{SQLNode, Int}(), - BoxNode[], - Dict{Symbol, SQLNode}(), - nothing, - false) - - AnnotateContext(ctx::AnnotateContext; - with_nodes = missing, - knot_node = missing, - over_knot = missing) = - new(ctx.catalog, - ctx.path_map, - ctx.current_path, - ctx.handles, - ctx.boxes, - coalesce(with_nodes, ctx.with_nodes), - coalesce(knot_node, ctx.knot_node), - coalesce(over_knot, ctx.over_knot)) -end - -function grow_path!(ctx::AnnotateContext, n::SQLNode) - push!(ctx.path_map.paths, (n, ctx.current_path[end])) - push!(ctx.current_path, length(ctx.path_map.paths)) -end - -function shrink_path!(ctx::AnnotateContext) - pop!(ctx.current_path) -end - -function mark_origin!(ctx::AnnotateContext, n::SQLNode) - ctx.path_map.origins[n] = ctx.current_path[end] -end - -mark_origin!(ctx::AnnotateContext, n::AbstractSQLNode) = - mark_origin!(ctx, convert(SQLNode, n)) - -get_path(ctx::AnnotateContext) = - get_path(ctx.path_map, ctx.current_path[end]) - -get_path(ctx::AnnotateContext, n::SQLNode) = - get_path(ctx.path_map, n) - -function make_handle!(ctx::AnnotateContext, n::SQLNode) - get!(ctx.handles, n) do - length(ctx.handles) + 1 - end -end - -function get_handle(ctx::AnnotateContext, n::SQLNode) - handle = 0 - idx = get(ctx.path_map.origins, n, 0) - if idx > 0 - n = ctx.path_map.paths[idx][1] - handle = get(ctx.handles, n, 0) - end - handle -end - -get_handle(ctx::AnnotateContext, ::Nothing) = - 0 - - -# Rewriting of the node graph. - -function annotate(n::SQLNode, ctx) - grow_path!(ctx, n) - n′ = convert(SQLNode, annotate(n[], ctx)) - mark_origin!(ctx, n′) - box = BoxNode(over = n′) - push!(ctx.boxes, box) - n′ = convert(SQLNode, box) - mark_origin!(ctx, n′) - shrink_path!(ctx) - n′ -end - -function annotate_scalar(n::SQLNode, ctx) - grow_path!(ctx, n) - n′ = convert(SQLNode, annotate_scalar(n[], ctx)) - mark_origin!(ctx, n′) - shrink_path!(ctx) - n′ -end - -annotate(ns::Vector{SQLNode}, ctx) = - SQLNode[annotate(n, ctx) for n in ns] - -annotate_scalar(ns::Vector{SQLNode}, ctx) = - SQLNode[annotate_scalar(n, ctx) for n in ns] - -function annotate(::Nothing, ctx) - if ctx.over_knot - knot_node = ctx.knot_node - @assert knot_node !== nothing - over = convert(SQLNode, FromSelf(over = knot_node.box)) - mark_origin!(ctx, over) - else - over = nothing - end - box = BoxNode(over = over) - push!(ctx.boxes, box) - n′ = convert(SQLNode, box) - mark_origin!(ctx, n′) - n′ -end - -annotate_scalar(::Nothing, ctx) = - nothing - -annotate(n::AbstractSQLNode, ctx) = - throw(IllFormedError(path = get_path(ctx))) - -function annotate_scalar(n::TabularNode, ctx) - ctx′ = AnnotateContext(ctx, over_knot = false) - n′ = convert(SQLNode, annotate(n, ctx′)) - mark_origin!(ctx, n′) - box = BoxNode(over = n′) - push!(ctx.boxes, box) - n′ = convert(SQLNode, box) - n′ -end - -function rebind(node, base, ctx) - while @dissect(node, over |> Get(name = name)) - mark_origin!(ctx, base) - base = NameBound(over = base, name = name) - node = over - end - if node !== nothing - handle = make_handle!(ctx, node) - mark_origin!(ctx, base) - base = HandleBound(over = base, handle = handle) - end - base -end - -function annotate_scalar(n::AggregateNode, ctx) - args′ = annotate_scalar(n.args, ctx) - filter′ = annotate_scalar(n.filter, ctx) - n′ = Agg(name = n.name, args = args′, filter = filter′) - rebind(n.over, n′, ctx) -end - -function annotate(n::AppendNode, ctx) - over = n.over - args = n.args - if over === nothing && !ctx.over_knot - if !isempty(args) - over = args[1] - args = args[2:end] - else - over = Where(false) - end - end - over′ = annotate(over, ctx) - ctx′ = AnnotateContext(ctx, over_knot = false) - args′ = annotate(args, ctx′) - Append(over = over′, args = args′) -end - -function annotate(n::AsNode, ctx) - over′ = annotate(n.over, ctx) - As(over = over′, name = n.name) -end - -function annotate_scalar(n::AsNode, ctx) - over′ = annotate_scalar(n.over, ctx) - As(over = over′, name = n.name) -end - -function annotate(n::BindNode, ctx) - over′ = annotate(n.over, ctx) - args′ = annotate_scalar(n.args, ctx) - IntBind(over = over′, args = args′, label_map = n.label_map) -end - -function annotate_scalar(n::BindNode, ctx) - ctx′ = AnnotateContext(ctx, over_knot = false) - annotate(n, ctx′) -end - -function annotate(n::DefineNode, ctx) - over′ = annotate(n.over, ctx) - args′ = annotate_scalar(n.args, ctx) - Define(over = over′, args = args′, label_map = n.label_map) -end - -function annotate(n::FromNode, ctx) - source = n.source - if source isa SQLTable - FromTable(table = source) - elseif source isa Symbol - over = get(ctx.with_nodes, source, nothing) - if over !== nothing - FromReference(over = over, name = source) - else - table = get(ctx.catalog, source, nothing) - if table !== nothing - FromTable(table = table) - else - throw(ReferenceError(REFERENCE_ERROR_TYPE.UNDEFINED_TABLE_REFERENCE, - name = source, - path = get_path(ctx))) - end - end - elseif source isa SelfSource - knot_node = ctx.knot_node - if knot_node !== nothing - FromSelf(over = knot_node.box) - else - throw(ReferenceError(REFERENCE_ERROR_TYPE.INVALID_SELF_REFERENCE, - path = get_path(ctx))) - end - elseif source isa ValuesSource - FromValues(columns = source.columns) - elseif source isa FunctionSource - FromFunction(over = annotate_scalar(source.node, ctx), - columns = source.columns) - else - FromNothing() - end -end - -function annotate_scalar(n::FunctionNode, ctx) - args′ = annotate_scalar(n.args, ctx) - Fun(name = n.name, args = args′) -end - -function annotate_scalar(n::GetNode, ctx) - if n.over === nothing - return n - end - rebind(n.over, Get(name = n.name), ctx) -end - -function annotate(n::GroupNode, ctx) - over′ = annotate(n.over, ctx) - if !isempty(n.by) - def = IntAutoDefine(over = over′) - mark_origin!(ctx, def) - box = BoxNode(over = def) - push!(ctx.boxes, box) - over′ = convert(SQLNode, box) - mark_origin!(ctx, over′) - end - by′ = annotate_scalar(n.by, ctx) - Group(over = over′, by = by′, name = n.name, label_map = n.label_map) -end - -function annotate(n::HighlightNode, ctx) - over′ = annotate(n.over, ctx) - Highlight(over = over′, color = n.color) -end - -function annotate_scalar(n::HighlightNode, ctx) - over′ = annotate_scalar(n.over, ctx) - Highlight(over = over′, color = n.color) -end - -function annotate(n::IterateNode, ctx) - over′ = annotate(n.over, ctx) - knot_box = BoxNode() - knot = KnotNode(over = over′, iterator = n.iterator, box = knot_box) - mark_origin!(ctx, knot) - knot_box.over = knot - push!(ctx.boxes, knot_box) - over′ = convert(SQLNode, knot_box) - mark_origin!(ctx, over′) - ctx′ = AnnotateContext(ctx, knot_node = knot, over_knot = true) - range_start = length(ctx.boxes) + 1 - iterator′ = annotate(n.iterator, ctx′) - range_stop = length(ctx.boxes) - knot.iterator = iterator′ - knot.iterator_boxes = ctx.boxes[range_start:range_stop] - IntIterateNode(over = over′, name = label(n.over)) -end - -function annotate(n::JoinNode, ctx) - over′ = annotate(n.over, ctx) - ctx′ = AnnotateContext(ctx, over_knot = false) - joinee′ = annotate(n.joinee, ctx′) - on′ = annotate_scalar(n.on, ctx) - IntJoin(over = over′, joinee = joinee′, on = on′, left = n.left, right = n.right, skip = n.optional) -end - -function annotate(n::LimitNode, ctx) - over′ = annotate(n.over, ctx) - Limit(over = over′, offset = n.offset, limit = n.limit) -end - -annotate_scalar(n::LiteralNode, ctx) = - n - -function annotate(n::OrderNode, ctx) - over′ = annotate(n.over, ctx) - by′ = annotate_scalar(n.by, ctx) - Order(over = over′, by = by′) -end - -annotate(n::OverNode, ctx) = - annotate(WithNode(over = n.arg, args = n.over !== nothing ? SQLNode[n.over] : SQLNode[]), ctx) - -function annotate(n::PartitionNode, ctx) - over′ = annotate(n.over, ctx) - by′ = annotate_scalar(n.by, ctx) - order_by′ = annotate_scalar(n.order_by, ctx) - Partition(over = over′, by = by′, order_by = order_by′, frame = n.frame, name = n.name) -end - -function annotate(n::SelectNode, ctx) - over′ = annotate(n.over, ctx) - args′ = annotate_scalar(n.args, ctx) - Select(over = over′, args = args′, label_map = n.label_map) -end - -function annotate_scalar(n::SortNode, ctx) - over′ = annotate_scalar(n.over, ctx) - Sort(over = over′, value = n.value, nulls = n.nulls) -end - -annotate_scalar(n::VariableNode, ctx) = - n - -function annotate(n::WhereNode, ctx) - over′ = annotate(n.over, ctx) - condition′ = annotate_scalar(n.condition, ctx) - Where(over = over′, condition = condition′) -end - -function annotate(n::WithNode, ctx) - ctx′ = AnnotateContext(ctx, over_knot = false) - args′ = annotate(n.args, ctx′) - with_nodes′ = copy(ctx.with_nodes) - for (name, i) in n.label_map - with_nodes′[name] = args′[i] - end - ctx′ = AnnotateContext(ctx, with_nodes = with_nodes′) - over′ = annotate(n.over, ctx′) - With(over = over′, args = args′, materialized = n.materialized, label_map = n.label_map) -end - -function annotate(n::WithExternalNode, ctx) - ctx′ = AnnotateContext(ctx, over_knot = false) - args′ = annotate(n.args, ctx′) - with_nodes′ = copy(ctx.with_nodes) - for (name, i) in n.label_map - with_nodes′[name] = args′[i] - end - ctx′ = AnnotateContext(ctx, with_nodes = with_nodes′) - over′ = annotate(n.over, ctx′) - WithExternal(over = over′, args = args′, qualifiers = n.qualifiers, handler = n.handler, label_map = n.label_map) -end - - -# Type resolution. - -resolve!(ctx::AnnotateContext) = - resolve!(ctx.boxes, ctx) - -function resolve!(boxes::AbstractVector{BoxNode}, ctx) - for box in boxes - over = box.over - if over !== nothing - h = get_handle(ctx, over) - t = resolve(over[], ctx) - t = add_handle(t, h) - box.handle = h - box.type = t - end - end -end - -function resolve(n::AppendNode, ctx) - t = box_type(n.over) - for arg in n.args - t = intersect(t, box_type(arg)) - end - t -end - -function resolve(n::AsNode, ctx) - t = box_type(n.over) - fields = FieldTypeMap(n.name => t.row) - row = RowType(fields) - BoxType(n.name, row, t.handle_map) -end - -function resolve(n::DefineNode, ctx) - t = box_type(n.over) - fields = FieldTypeMap() - for (f, ft) in t.row.fields - if f in keys(n.label_map) - ft = ScalarType() - end - fields[f] = ft - end - for f in keys(n.label_map) - if !haskey(fields, f) - fields[f] = ScalarType() - end - end - row = RowType(fields, t.row.group) - BoxType(t.name, row, t.handle_map) -end - -resolve(n::FromNothingNode, ctx) = - EMPTY_BOX - -function resolve(n::FromFunctionNode, ctx) - fields = FieldTypeMap() - for f in n.columns - fields[f] = ScalarType() - end - row = RowType(fields) - BoxType(label(n.over), row) -end - -function resolve(n::FromReferenceNode, ctx) - t = box_type(n.over) - ft = get(t.row.fields, n.name, nothing) - if !(ft isa RowType) - throw(ReferenceError(REFERENCE_ERROR_TYPE.INVALID_TABLE_REFERENCE, - name = n.name, - path = get_path(ctx, n.over))) - end - BoxType(n.name, ft) -end - -function resolve(n::FromTableNode, ctx) - fields = FieldTypeMap() - for f in n.table.columns - fields[f] = ScalarType() - end - row = RowType(fields) - BoxType(n.table.name, row) -end - -function resolve(n::FromValuesNode, ctx) - columns = fieldnames(typeof(n.columns)) - fields = FieldTypeMap() - for f in columns - fields[f] = ScalarType() - end - row = RowType(fields) - BoxType(:values, row) -end - -function resolve(n::GroupNode, ctx) - t = box_type(n.over) - fields = FieldTypeMap() - for name in keys(n.label_map) - fields[name] = ScalarType() - end - if n.name === nothing - row = RowType(fields, t.row) - else - fields[n.name] = RowType(FieldTypeMap(), t.row) - row = RowType(fields) - end - BoxType(t.name, row) -end - -resolve(n::Union{FromSelfNode, HighlightNode, IntAutoDefineNode, IntBindNode, KnotNode, LimitNode, OrderNode, WhereNode, WithNode, WithExternalNode}, ctx) = - box_type(n.over) - -resolve_knot!(n::SQLNode, ctx) = - resolve_knot!(n[], ctx) - -function resolve_knot!(n::BoxNode, ctx) - knot = n.over[]::KnotNode - iterator_t = box_type(knot.iterator) - while !issubset(n.type.row, iterator_t.row) - n.type = intersect(n.type, iterator_t) - resolve!(knot.iterator_boxes, ctx) - iterator_t = box_type(knot.iterator) - end - n.type = add_handle(n.type, n.handle) -end - -function resolve(n::IntIterateNode, ctx) - resolve_knot!(n.over, ctx) - t = box_type(n.over) - BoxType(n.name, t.row) -end - -function resolve(n::IntJoinNode, ctx) - lt = box_type(n.over) - rt = box_type(n.joinee) - t = union(lt, rt) - n.type = t - t -end - -function resolve(n::PartitionNode, ctx) - t = box_type(n.over) - if n.name === nothing - row = RowType(t.row.fields, t.row) - else - fields = FieldTypeMap() - for (f, ft) in t.row.fields - if f !== n.name - fields[f] = ft - end - end - fields[n.name] = RowType(FieldTypeMap(), t.row) - row = RowType(fields, t.row.group) - end - BoxType(t.name, row, t.handle_map) -end - -function resolve(n::SelectNode, ctx) - t = box_type(n.over) - fields = FieldTypeMap() - for name in keys(n.label_map) - fields[name] = ScalarType() - end - row = RowType(fields) - BoxType(t.name, row) -end - - -# Collecting references. - -gather!(refs::Vector{SQLNode}, n::SQLNode) = - gather!(refs, n[]) - -function gather!(refs::Vector{SQLNode}, ns::Vector{SQLNode}) - for n in ns - gather!(refs, n) - end -end - -gather!(refs::Vector{SQLNode}, ::Union{AbstractSQLNode, Nothing}) = - nothing - -gather!(refs::Vector{SQLNode}, n::Union{AsNode, BoxNode, FromFunctionNode, HighlightNode, SortNode}) = - gather!(refs, n.over) - -function gather!(refs::Vector{SQLNode}, n::IntBindNode) - gather!(refs, n.over) - refs′ = SQLNode[] - gather!(refs′, n.args) - append!(refs, refs′) - # Make sure complex definitions and aggregates are wrapped in a nested - # subquery. - append!(refs, refs′) - n.owned = true -end - -gather!(refs::Vector{SQLNode}, n::FunctionNode) = - gather!(refs, n.args) - -function gather!(refs::Vector{SQLNode}, n::Union{AggregateNode, GetNode, HandleBoundNode, NameBoundNode}) - push!(refs, n) -end - - -# Validating references. - -function validate(t::BoxType, ref::SQLNode, ctx) - if @dissect(ref, over |> HandleBound(handle = handle)) - if handle in keys(t.handle_map) - ht = t.handle_map[handle] - if ht isa AmbiguousType - throw(ReferenceError(REFERENCE_ERROR_TYPE.AMBIGUOUS_HANDLE, - path = get_path(ctx, ref))) - end - validate(ht, over, ctx) - else - throw(ReferenceError(REFERENCE_ERROR_TYPE.UNDEFINED_HANDLE, - path = get_path(ctx, ref))) - end - else - validate(t.row, ref, ctx) - end -end - -function validate(t::RowType, ref::SQLNode, ctx) - while @dissect(ref, over |> NameBound(name = name)) - ft = get(t.fields, name, EmptyType()) - if !(ft isa RowType) - type = - ft isa EmptyType ? REFERENCE_ERROR_TYPE.UNDEFINED_NAME : - ft isa ScalarType ? REFERENCE_ERROR_TYPE.UNEXPECTED_SCALAR_TYPE : - ft isa AmbiguousType ? REFERENCE_ERROR_TYPE.AMBIGUOUS_NAME : error() - throw(ReferenceError(type, name = name, path = get_path(ctx, ref))) - end - t = ft - ref = over - end - if @dissect(ref, nothing |> Get(name = name)) - ft = get(t.fields, name, EmptyType()) - if !(ft isa ScalarType) - type = - ft isa EmptyType ? REFERENCE_ERROR_TYPE.UNDEFINED_NAME : - ft isa RowType ? REFERENCE_ERROR_TYPE.UNEXPECTED_ROW_TYPE : - ft isa AmbiguousType ? REFERENCE_ERROR_TYPE.AMBIGUOUS_NAME : error() - throw(ReferenceError(type, name = name, path = get_path(ctx, ref))) - end - elseif @dissect(ref, nothing |> Agg(name = name)) - if !(t.group isa RowType) - type = - t.group isa EmptyType ? REFERENCE_ERROR_TYPE.UNEXPECTED_AGGREGATE : - t.group isa AmbiguousType ? REFERENCE_ERROR_TYPE.AMBIGUOUS_AGGREGATE : error() - throw(ReferenceError(type, path = get_path(ctx, ref))) - end - else - error() - end -end - -function gather_and_validate!(refs::Vector{SQLNode}, n, t::BoxType, ctx) - start = length(refs) + 1 - gather!(refs, n) - for k in start:length(refs) - validate(t, refs[k], ctx) - end -end - -function route(lt::BoxType, rt::BoxType, ref::SQLNode) - if @dissect(ref, over |> HandleBound(handle = handle)) - if get(lt.handle_map, handle, EmptyType()) isa EmptyType - return 1 - else - return -1 - end - end - return route(lt.row, rt.row, ref) -end - -function route(lt::RowType, rt::RowType, ref::SQLNode) - while @dissect(ref, over |> NameBound(name = name)) - lt′ = get(lt.fields, name, EmptyType()) - if lt′ isa EmptyType - return 1 - end - rt′ = get(rt.fields, name, EmptyType()) - if rt′ isa EmptyType - return -1 - end - @assert lt′ isa RowType && rt′ isa RowType - lt = lt′ - rt = rt′ - ref = over - end - if @dissect(ref, Get(name = name)) - if name in keys(lt.fields) - return -1 - else - return 1 - end - elseif @dissect(ref, over |> Agg(name = name)) - if lt.group isa RowType - return -1 - else - return 1 - end - else - error() - end -end - - -# Linking references through box nodes. - -function link!(ctx::AnnotateContext) - root_box = ctx.boxes[end] - for (f, ft) in root_box.type.row.fields - if ft isa ScalarType - push!(root_box.refs, Get(f)) - end - end - link!(reverse(ctx.boxes), ctx) -end - -function link!(boxes::AbstractVector{BoxNode}, ctx) - for box in boxes - box.over !== nothing || continue - refs′ = SQLNode[] - for ref in box.refs - if @dissect(ref, over |> HandleBound(handle = handle)) && handle == box.handle - push!(refs′, over) - else - push!(refs′, ref) - end - end - link!(box.over[], refs′, ctx) - end -end - -function link!(n::AppendNode, refs::Vector{SQLNode}, ctx) - box = n.over[]::BoxNode - append!(box.refs, refs) - for arg in n.args - box = arg[]::BoxNode - append!(box.refs, refs) - end -end - -function link!(n::AsNode, refs::Vector{SQLNode}, ctx) - box = n.over[]::BoxNode - for ref in refs - if @dissect(ref, over |> NameBound(name = name)) - @assert name == n.name - push!(box.refs, over) - elseif @dissect(ref, HandleBound()) - push!(box.refs, ref) - else - error() - end - end -end - -function link!(n::DefineNode, refs::Vector{SQLNode}, ctx) - box = n.over[]::BoxNode - seen = Set{Symbol}() - imm_refs = SQLNode[] - for ref in refs - if @dissect(ref, nothing |> Get(name = name)) && name in keys(n.label_map) - !(name in seen) || continue - push!(seen, name) - col = n.args[n.label_map[name]] - gather_and_validate!(imm_refs, col, box.type, ctx) - else - push!(box.refs, ref) - end - end - begin_imm_refs!(box) - append!(box.refs, imm_refs) -end - -link!(::Union{FromFunctionNode, FromNothingNode, FromTableNode, FromValuesNode}, ::Vector{SQLNode}, ctx) = - nothing - -function link!(n::FromReferenceNode, refs::Vector{SQLNode}, ctx) - box = n.over[]::BoxNode - for ref in refs - push!(box.refs, NameBound(over = ref, name = n.name)) - end -end - -function link!(n::Union{FromSelfNode, HighlightNode, IntIterateNode, LimitNode, WithNode, WithExternalNode}, refs::Vector{SQLNode}, ctx) - box = n.over[]::BoxNode - append!(box.refs, refs) -end - -function link!(n::GroupNode, refs::Vector{SQLNode}, ctx) - box = n.over[]::BoxNode - begin_imm_refs!(box) - append!(box.refs, n.by) - has_aggregates = any(ref -> @dissect(ref, Agg() || Agg() |> NameBound()), refs) - has_aggregates || return - for ref in refs - if (@dissect(ref, nothing |> Agg(args = args, filter = filter) |> NameBound(name = name)) && name === n.name) || - (@dissect(ref, nothing |> Agg(args = args, filter = filter)) && n.name === nothing) - gather_and_validate!(box.refs, args, box.type, ctx) - if filter !== nothing - gather_and_validate!(box.refs, filter, box.type, ctx) - end - elseif @dissect(ref, nothing |> Get(name = name)) && name in keys(n.label_map) - push!(box.refs, n.by[n.label_map[name]]) - end - end -end - -function link!(n::IntAutoDefineNode, refs::Vector{SQLNode}, ctx) - box = n.over[]::BoxNode - begin_imm_refs!(box) - gather_and_validate!(box.refs, refs, box.type, ctx) -end - -function link!(n::IntBindNode, refs::Vector{SQLNode}, ctx) - if !n.owned - gather_and_validate!(SQLNode[], n.args, EMPTY_BOX, ctx) - end - box = n.over[]::BoxNode - append!(box.refs, refs) -end - -function link!(n::IntJoinNode, refs::Vector{SQLNode}, ctx) - lbox = n.over[]::BoxNode - rbox = n.joinee[]::BoxNode - for ref in refs - turn = route(lbox.type, rbox.type, ref) - if turn < 0 - push!(lbox.refs, ref) - else - push!(rbox.refs, ref) - end - end - if !isempty(rbox.refs) - n.skip = false - end - if n.skip - return - end - begin_imm_refs!(lbox) - begin_imm_refs!(rbox) - gather_and_validate!(n.lateral, n.joinee, lbox.type, ctx) - append!(lbox.refs, n.lateral) - refs′ = SQLNode[] - gather_and_validate!(refs′, n.on, n.type, ctx) - for ref in refs′ - turn = route(lbox.type, rbox.type, ref) - if turn < 0 - push!(lbox.refs, ref) - else - push!(rbox.refs, ref) - end - end -end - -function link!(n::KnotNode, ::Vector{SQLNode}, ctx) - box = n.over[]::BoxNode - iterator_box = n.iterator[]::BoxNode - refs = SQLNode[] - seen = Set{SQLNode}() - while true - repeat = false - for ref in n.box.refs - if !(ref in seen) - push!(refs, ref) - push!(seen, ref) - repeat = true - end - end - reset!(n.box) - append!(n.box.refs, refs) - repeat || break - for ibox in n.iterator_boxes - reset!(ibox) - end - append!(iterator_box.refs, refs) - link!(reverse(n.iterator_boxes), ctx) - end - for ref in refs - push!(box.refs, ref) - end -end - -function link!(n::OrderNode, refs::Vector{SQLNode}, ctx) - box = n.over[]::BoxNode - append!(box.refs, refs) - begin_imm_refs!(box) - gather_and_validate!(box.refs, n.by, box.type, ctx) -end - -function link!(n::PartitionNode, refs::Vector{SQLNode}, ctx) - box = n.over[]::BoxNode - imm_refs = SQLNode[] - has_aggregates = false - for ref in refs - if (@dissect(ref, nothing |> Agg(args = args, filter = filter) |> NameBound(name = name)) && name === n.name) || - (@dissect(ref, nothing |> Agg(args = args, filter = filter)) && n.name === nothing) - gather_and_validate!(imm_refs, args, box.type, ctx) - if filter !== nothing - gather_and_validate!(imm_refs, filter, box.type, ctx) - end - has_aggregates = true - else - push!(box.refs, ref) - end - end - if has_aggregates - gather_and_validate!(imm_refs, n.by, box.type, ctx) - gather_and_validate!(imm_refs, n.order_by, box.type, ctx) - begin_imm_refs!(box) - append!(box.refs, imm_refs) - end -end - -function link!(n::SelectNode, refs::Vector{SQLNode}, ctx) - box = n.over[]::BoxNode - begin_imm_refs!(box) - gather_and_validate!(box.refs, n.args, box.type, ctx) -end - -function link!(n::WhereNode, refs::Vector{SQLNode}, ctx) - box = n.over[]::BoxNode - append!(box.refs, refs) - begin_imm_refs!(box) - gather_and_validate!(box.refs, n.condition, box.type, ctx) -end diff --git a/src/clauses.jl b/src/clauses.jl index 9dd74eaa..53561e4f 100644 --- a/src/clauses.jl +++ b/src/clauses.jl @@ -139,6 +139,7 @@ include("clauses/function.jl") include("clauses/group.jl") include("clauses/having.jl") include("clauses/identifier.jl") +include("clauses/internal.jl") include("clauses/join.jl") include("clauses/limit.jl") include("clauses/literal.jl") diff --git a/src/clauses/internal.jl b/src/clauses/internal.jl new file mode 100644 index 00000000..063f5211 --- /dev/null +++ b/src/clauses/internal.jl @@ -0,0 +1,25 @@ +# Auxiliary clauses. + +# Context holder for the serialize pass. + +mutable struct WithContextClause <: AbstractSQLClause + over::SQLClause + dialect::SQLDialect + + WithContextClause(; over, dialect) = + new(over, dialect) +end + +WITH_CONTEXT(args...; kws...) = + WithContextClause(args...; kws...) |> SQLClause + +dissect(scr::Symbol, ::typeof(WITH_CONTEXT), pats::Vector{Any}) = + dissect(scr, WithContextClause, pats) + +function PrettyPrinting.quoteof(c::WithContextClause, ctx::QuoteContext) + ex = Expr(:call, nameof(WITH_CONTEXT), Expr(:kw, :over, quoteof(c.over, ctx))) + if c.dialect !== default_dialect + push!(ex.args, Expr(:kw, :dialect, quoteof(c.dialect))) + end + ex +end diff --git a/src/link.jl b/src/link.jl new file mode 100644 index 00000000..570c80f6 --- /dev/null +++ b/src/link.jl @@ -0,0 +1,560 @@ +# Find select lists. + +struct LinkContext + dialect::SQLDialect + defs::Vector{SQLNode} + refs::Vector{SQLNode} + cte_refs::Base.ImmutableDict{Tuple{Symbol, Int}, Vector{SQLNode}} + knot_refs::Union{Vector{SQLNode}, Nothing} + + LinkContext(dialect) = + new(dialect, + SQLNode[], + SQLNode[], + Base.ImmutableDict{Tuple{Symbol, Int}, Vector{SQLNode}}(), + nothing) + + LinkContext(ctx::LinkContext; refs = ctx.refs, cte_refs = ctx.cte_refs, knot_refs = ctx.knot_refs) = + new(ctx.dialect, + ctx.defs, + refs, + cte_refs, + knot_refs) +end + +function link(n::SQLNode) + @dissect(n, WithContext(over = over, dialect = dialect)) || throw(ILLFormedError()) + ctx = LinkContext(dialect) + t = row_type(over) + refs = SQLNode[] + for (f, ft) in t.fields + if ft isa ScalarType + push!(refs, Get(f)) + end + end + over′ = Linked(refs, over = link(dismantle(over, ctx), ctx, refs)) + WithContext(over = over′, dialect = dialect, defs = ctx.defs) +end + +function dismantle(n::SQLNode, ctx) + convert(SQLNode, dismantle(n[], ctx)) +end + +function dismantle(ns::Vector{SQLNode}, ctx) + SQLNode[dismantle(n, ctx) for n in ns] +end + +function dismantle_scalar(n::SQLNode, ctx) + convert(SQLNode, dismantle_scalar(n[], ctx)) +end + +function dismantle_scalar(ns::Vector{SQLNode}, ctx) + SQLNode[dismantle_scalar(n, ctx) for n in ns] +end + +function dismantle_scalar(n::TabularNode, ctx) + n′ = dismantle(convert(SQLNode, n), ctx) + push!(ctx.defs, n′) + ref = lastindex(ctx.defs) + Isolated(ref) +end + +function dismantle_scalar(n::AggregateNode, ctx) + args′ = dismantle_scalar(n.args, ctx) + filter′ = n.filter !== nothing ? dismantle_scalar(n.filter, ctx) : nothing + Agg(name = n.name, args = args′, filter = filter′) +end + +function dismantle(n::AppendNode, ctx) + over′ = dismantle(n.over, ctx) + args′ = dismantle(n.args, ctx) + Append(over = over′, args = args′) +end + +function dismantle(n::AsNode, ctx) + over′ = dismantle(n.over, ctx) + As(over = over′, name = n.name) +end + +function dismantle_scalar(n::AsNode, ctx) + over′ = dismantle_scalar(n.over, ctx) + As(over = over′, name = n.name) +end + +function dismantle(n::BindNode, ctx) + over′ = dismantle(n.over, ctx) + args′ = dismantle_scalar(n.args, ctx) + BindNode(over = over′, args = args′, label_map = n.label_map) +end + +function dismantle_scalar(n::BindNode, ctx) + over′ = dismantle_scalar(n.over, ctx) + args′ = dismantle_scalar(n.args, ctx) + BindNode(over = over′, args = args′, label_map = n.label_map) +end + +dismantle_scalar(n::Union{BoundVariableNode, GetNode, LiteralNode, VariableNode}, ctx) = + convert(SQLNode, n) + +function dismantle(n::DefineNode, ctx) + over′ = dismantle(n.over, ctx) + args′ = dismantle_scalar(n.args, ctx) + Define(over = over′, args = args′, label_map = n.label_map) +end + +function dismantle(n::FromFunctionNode, ctx) + over′ = dismantle_scalar(n.over, ctx) + FromFunction(over = over′, columns = n.columns) +end + +dismantle(n::Union{FromIterateNode, FromNothingNode, FromTableExpressionNode, FromTableNode, FromValuesNode}, ctx) = + convert(SQLNode, n) + +function dismantle_scalar(n::FunctionNode, ctx) + args′ = dismantle_scalar(n.args, ctx) + Fun(name = n.name, args = args′) +end + +function dismantle(n::GroupNode, ctx) + over′ = dismantle(n.over, ctx) + by′ = dismantle_scalar(n.by, ctx) + Group(over = over′, by = by′, name = n.name, label_map = n.label_map) +end + +function dismantle(n::IterateNode, ctx) + over′ = dismantle(n.over, ctx) + iterator′ = dismantle(n.iterator, ctx) + Iterate(over = over′, iterator = iterator′) +end + +function dismantle(n::JoinNode, ctx) + rt = row_type(n.joinee) + router = JoinRouter(Set(keys(rt.fields)), !isa(rt.group, EmptyType)) + over′ = dismantle(n.over, ctx) + joinee′ = dismantle(n.joinee, ctx) + on′ = dismantle_scalar(n.on, ctx) + RoutedJoin(over = over′, joinee = joinee′, on = on′, router = router, left = n.left, right = n.right, optional = n.optional) +end + +function dismantle(n::LimitNode, ctx) + over′ = dismantle(n.over, ctx) + Limit(over = over′, offset = n.offset, limit = n.limit) +end + +function dismantle_scalar(n::NestedNode, ctx) + over′ = dismantle_scalar(n.over, ctx) + NestedNode(over = over′, name = n.name) +end + +function dismantle(n::OrderNode, ctx) + over′ = dismantle(n.over, ctx) + by′ = dismantle_scalar(n.by, ctx) + Order(over = over′, by = by′) +end + +function dismantle(n::PaddingNode, ctx) + over′ = dismantle(n.over, ctx) + Padding(over = over′) +end + +function dismantle(n::PartitionNode, ctx) + over′ = dismantle(n.over, ctx) + by′ = dismantle_scalar(n.by, ctx) + order_by′ = dismantle_scalar(n.order_by, ctx) + Partition(over = over′, by = by′, order_by = order_by′, frame = n.frame, name = n.name) +end + +dismantle(n::ResolvedNode, ctx) = + dismantle(n.over, ctx) + +dismantle_scalar(n::ResolvedNode, ctx) = + dismantle_scalar(n.over, ctx) + +function dismantle(n::SelectNode, ctx) + over′ = dismantle(n.over, ctx) + args′ = dismantle_scalar(n.args, ctx) + Select(over = over′, args = args′, label_map = n.label_map) +end + +function dismantle_scalar(n::SortNode, ctx) + over′ = dismantle_scalar(n.over, ctx) + Sort(over = over′, value = n.value, nulls = n.nulls) +end + +function dismantle(n::WhereNode, ctx) + over′ = dismantle(n.over, ctx) + condition′ = dismantle_scalar(n.condition, ctx) + Where(over = over′, condition = condition′) +end + +function dismantle(n::WithNode, ctx) + over′ = dismantle(n.over, ctx) + args′ = dismantle(n.args, ctx) + With(over = over′, args = args′, materialized = n.materialized, label_map = n.label_map) +end + +function dismantle(n::WithExternalNode, ctx) + over′ = dismantle(n.over, ctx) + args′ = dismantle(n.args, ctx) + WithExternal(over = over′, args = args′, qualifiers = n.qualifiers, handler = n.handler, label_map = n.label_map) +end + +function link(n::SQLNode, ctx) + convert(SQLNode, link(n[], ctx)) +end + +function link(ns::Vector{SQLNode}, ctx) + SQLNode[link(n, ctx) for n in ns] +end + +link(n, ctx, refs) = + link(n, LinkContext(ctx, refs = refs)) + +function link(n::AppendNode, ctx) + over′ = Linked(ctx.refs, over = link(n.over, ctx)) + args′ = SQLNode[Linked(ctx.refs, over = link(arg, ctx)) for arg in n.args] + Append(over = over′, args = args′) +end + +function link(n::AsNode, ctx) + refs = SQLNode[] + for ref in ctx.refs + if @dissect(ref, over |> Nested(name = name)) + @assert name == n.name + push!(refs, over) + else + error() + end + end + over′ = link(n.over, ctx, refs) + As(over = over′, name = n.name) +end + +function link(n::BindNode, ctx) + over′ = link(n.over, ctx) + Bind(over = over′, args = n.args, label_map = n.label_map) +end + +function link(n::DefineNode, ctx) + refs = SQLNode[] + seen = Set{Symbol}() + for ref in ctx.refs + if @dissect(ref, nothing |> Get(name = name)) && name in keys(n.label_map) + push!(seen, name) + else + push!(refs, ref) + end + end + if isempty(seen) + return link(n.over, ctx) + end + n_ext_refs = length(refs) + args′ = SQLNode[] + label_map′ = OrderedDict{Symbol, Int}() + for (f, i) in n.label_map + f in seen || continue + arg′ = n.args[i] + gather!(arg′, ctx, refs) + push!(args′, arg′) + label_map′[f] = lastindex(args′) + end + over′ = Linked(refs, n_ext_refs, over = link(n.over, ctx, refs)) + Define( + over = over′, + args = args′, + label_map = label_map′) +end + +link(n::Union{FromFunctionNode, FromNothingNode, FromTableNode, FromValuesNode}, ctx) = + convert(SQLNode, n) + +function link(n::FromIterateNode, ctx) + append!(ctx.knot_refs, ctx.refs) + n +end + +function link(n::FromTableExpressionNode, ctx) + refs = ctx.cte_refs[(n.name, n.depth)] + for ref in ctx.refs + push!(refs, Nested(over = ref, name = n.name)) + end + n +end + +function link(n::GroupNode, ctx) + has_aggregates = any(ref -> @dissect(ref, Agg() || Agg() |> Nested()), ctx.refs) + if !has_aggregates && isempty(n.by) + return link(FromNothing(), ctx) + end + # Some group keys are added both to SELECT and to GROUP BY. + # To avoid duplicate SQL, they must be evaluated in a nested subquery. + refs = SQLNode[] + append!(refs, n.by) + # Ignore `SELECT DISTINCT` case. + if has_aggregates + ctx′ = LinkContext(ctx, refs = refs) + for ref in ctx.refs + if (@dissect(ref, nothing |> Agg(args = args, filter = filter) |> Nested(name = name)) && name === n.name) || + (@dissect(ref, nothing |> Agg(args = args, filter = filter)) && n.name === nothing) + gather!(args, ctx′) + if filter !== nothing + gather!(filter, ctx′) + end + elseif @dissect(ref, nothing |> Get(name = name)) && name in keys(n.label_map) + # Force evaluation in a nested subquery. + push!(refs, n.by[n.label_map[name]]) + end + end + end + over = n.over + if !isempty(n.by) + over = Padding(over = over) + end + over′ = Linked(refs, 0, over = link(over, ctx, refs)) + Group(over = over′, by = n.by, name = n.name, label_map = n.label_map) +end + +function link(n::IterateNode, ctx) + iterator′ = n.iterator + defs = copy(ctx.defs) + cte_refs = [(v, length(v)) for (k, v) in ctx.cte_refs] + refs = SQLNode[] + knot_refs = SQLNode[] + repeat = true + while repeat + refs = copy(ctx.refs) + append!(refs, knot_refs) + knot_refs = SQLNode[] + for (v, l) in cte_refs + resize!(v, l) + end + iterator′ = link(n.iterator, LinkContext(ctx, refs = refs, knot_refs = knot_refs)) + repeat = false + seen = Set(refs) + for ref in knot_refs + if !in(ref, seen) + repeat = true + ctx.defs .= defs + break + end + end + end + iterator′ = Linked(refs, over = iterator′) + over′ = Linked(refs, over = link(n.over, ctx, refs)) + n′ = Linked(refs, over = Iterate(over = over′, iterator = iterator′)) + Padding(over = n′) +end + +function route(r::JoinRouter, ref::SQLNode) + if @dissect(ref, over |> Nested(name = name)) && name in r.label_set + return 1 + end + if @dissect(ref, Get(name = name)) && name in r.label_set + return 1 + end + if @dissect(ref, over |> Agg()) && r.group + return 1 + end + return -1 +end + +function link(n::RoutedJoinNode, ctx) + lrefs = SQLNode[] + rrefs = SQLNode[] + for ref in ctx.refs + turn = route(n.router, ref) + push!(turn < 0 ? lrefs : rrefs, ref) + end + if n.optional && isempty(rrefs) + return link(n.over, ctx) + end + ln_ext_refs = length(lrefs) + rn_ext_refs = length(rrefs) + refs′ = SQLNode[] + lateral_refs = SQLNode[] + gather!(n.joinee, ctx, lateral_refs) + append!(lrefs, lateral_refs) + lateral = !isempty(lateral_refs) + gather!(n.on, ctx, refs′) + for ref in refs′ + turn = route(n.router, ref) + push!(turn < 0 ? lrefs : rrefs, ref) + end + over′ = Linked(lrefs, ln_ext_refs, over = link(n.over, ctx, lrefs)) + joinee′ = Linked(rrefs, rn_ext_refs, over = link(n.joinee, ctx, rrefs)) + RoutedJoinNode( + over = over′, + joinee = joinee′, + on = n.on, + router = n.router, + left = n.left, + right = n.right, + lateral = lateral) +end + +function link(n::LimitNode, ctx) + over′ = Linked(ctx.refs, over = link(n.over, ctx)) + Limit(over = over′, offset = n.offset, limit = n.limit) +end + +function link(n::OrderNode, ctx) + refs = copy(ctx.refs) + n_ext_refs = length(refs) + gather!(n.by, ctx, refs) + over′ = Linked(refs, n_ext_refs, over = link(n.over, ctx, refs)) + Order(over = over′, by = n.by) +end + +function link(n::PaddingNode, ctx) + refs = SQLNode[] + gather!(ctx.refs, ctx, refs) + over′ = Linked(refs, 0, over = link(n.over, ctx, refs)) + Padding(over = over′) +end + +function link(n::PartitionNode, ctx) + refs = SQLNode[] + imm_refs = SQLNode[] + ctx′ = LinkContext(ctx, refs = imm_refs) + has_aggregates = false + for ref in ctx.refs + if (@dissect(ref, nothing |> Agg(args = args, filter = filter) |> Nested(name = name)) && name === n.name) || + (@dissect(ref, nothing |> Agg(args = args, filter = filter)) && n.name === nothing) + gather!(args, ctx′) + if filter !== nothing + gather!(filter, ctx′) + end + has_aggregates = true + else + push!(refs, ref) + end + end + if !has_aggregates + return link(n.over, ctx) + end + gather!(n.by, ctx′) + gather!(n.order_by, ctx′) + n_ext_refs = length(refs) + append!(refs, imm_refs) + over′ = Linked(refs, n_ext_refs, over = link(n.over, ctx, refs)) + Partition(over = over′, by = n.by, order_by = n.order_by, frame = n.frame, name = n.name) +end + +function link(n::SelectNode, ctx) + refs = SQLNode[] + gather!(n.args, ctx, refs) + over′ = Linked(refs, 0, over = link(n.over, ctx, refs)) + Select(over = over′, args = n.args, label_map = n.label_map) +end + +function link(n::WhereNode, ctx) + refs = copy(ctx.refs) + n_ext_refs = length(refs) + gather!(n.condition, ctx, refs) + over′ = Linked(refs, n_ext_refs, over = link(n.over, ctx, refs)) + Where(n.condition, over = over′) +end + +function _cte_depth(dict, name) + for (n, d) in keys(dict) + if n === name + return d + end + end + 0 +end + +function link(n::WithNode, ctx) + cte_refs′ = ctx.cte_refs + refs_map = Vector{SQLNode}[] + for name in keys(n.label_map) + depth = _cte_depth(ctx.cte_refs, name) + 1 + refs = SQLNode[] + cte_refs′ = Base.ImmutableDict(cte_refs′, (name, depth) => refs) + push!(refs_map, refs) + end + ctx′ = LinkContext(ctx, cte_refs = cte_refs′) + over′ = Linked(ctx′.refs, over = link(n.over, ctx′)) + args′ = SQLNode[] + label_map′ = OrderedDict{Symbol, Int}() + for (f, i) in n.label_map + arg = n.args[i] + refs = refs_map[i] + arg′ = Linked(refs, over = link(arg, ctx, refs)) + push!(args′, arg′) + label_map′[f] = lastindex(args′) + end + With(over = over′, args = args′, materialized = n.materialized, label_map = label_map′) +end + +function link(n::WithExternalNode, ctx) + cte_refs′ = ctx.cte_refs + refs_map = Vector{SQLNode}[] + for name in keys(n.label_map) + depth = _cte_depth(ctx.cte_refs, name) + 1 + refs = SQLNode[] + cte_refs′ = Base.ImmutableDict(cte_refs′, (name, depth) => refs) + push!(refs_map, refs) + end + ctx′ = LinkContext(ctx, cte_refs = cte_refs′) + over′ = Linked(ctx′.refs, over = link(n.over, ctx′)) + args′ = SQLNode[] + label_map′ = OrderedDict{Symbol, Int}() + for (f, i) in n.label_map + arg = n.args[i] + refs = refs_map[i] + arg′ = Linked(refs, over = link(arg, ctx, refs)) + push!(args′, arg′) + label_map′[f] = lastindex(args′) + end + WithExternal(over = over′, args = args′, qualifiers = n.qualifiers, handler = n.handler, label_map = label_map′) +end + +function gather!(n::SQLNode, ctx) + gather!(n[], ctx) +end + +function gather!(ns::Vector{SQLNode}, ctx) + for n in ns + gather!(n, ctx) + end +end + +gather!(n::AbstractSQLNode, ctx) = + nothing + +gather!(n, ctx, refs) = + gather!(n, LinkContext(ctx, refs = refs)) + +function gather!(n::Union{AggregateNode, GetNode, NestedNode}, ctx) + push!(ctx.refs, n) + nothing +end + +function gather!(n::Union{AsNode, FromFunctionNode, ResolvedNode, SortNode}, ctx) + gather!(n.over, ctx) +end + +function gather!(n::BindNode, ctx) + gather!(n.over, ctx) + refs′ = SQLNode[] + gather!(n.args, ctx, refs′) + append!(ctx.refs, refs′) + # Force aggregates and other complex definitions to be wrapped + # in a nested subquery. + append!(ctx.refs, refs′) + nothing +end + +function gather!(n::FunctionNode, ctx) + gather!(n.args, ctx) +end + +function gather!(n::IsolatedNode, ctx) + def = ctx.defs[n.idx] + !@dissect(def, Linked()) || return + refs = SQLNode[] + def′ = Linked(refs, over = link(def, ctx, refs)) + ctx.defs[n.idx] = def′ + nothing +end diff --git a/src/nodes.jl b/src/nodes.jl index 118a9dfb..58d562ad 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -55,12 +55,33 @@ Base.convert(::Type{SQLNode}, obj) = label(n::SQLNode) = label(n[])::Symbol -label(::Union{AbstractSQLNode, Nothing}) = +label(::Nothing) = :_ +@generated function label(n::AbstractSQLNode) + if :over in fieldnames(n) + return :(label(n.over)) + else + return :(:_) + end +end + rebase(n::SQLNode, n′) = convert(SQLNode, rebase(n[], n′)) +@generated function rebase(n::AbstractSQLNode, n′) + fs = fieldnames(n) + if !in(:over, fs) + return quote + throw(RebaseError(path = [n])) + end + end + return quote + $n($(Any[Expr(:kw, f, f === :over ? :(rebase(n.$(f), n′)) : :(n.$(f))) + for f in fs]...)) + end +end + Chain(n′, n) = rebase(convert(SQLNode, n), n′) @@ -321,17 +342,28 @@ function Base.showerror(io::IO, err::IllFormedError) showpath(io, err.path) end +""" +A node that cannot be rebased. +""" +struct RebaseError <: FunSQLError + path::Vector{SQLNode} + + RebaseError(; path = SQLNode[])= + new(path) +end + +function Base.showerror(io::IO, err::RebaseError) + print(io, "FunSQL.RebaseError") + showpath(io, err.path) +end + module REFERENCE_ERROR_TYPE @enum ReferenceErrorType::UInt8 begin - UNDEFINED_HANDLE - AMBIGUOUS_HANDLE UNDEFINED_NAME - AMBIGUOUS_NAME UNEXPECTED_ROW_TYPE UNEXPECTED_SCALAR_TYPE UNEXPECTED_AGGREGATE - AMBIGUOUS_AGGREGATE UNDEFINED_TABLE_REFERENCE INVALID_TABLE_REFERENCE INVALID_SELF_REFERENCE @@ -355,22 +387,14 @@ end function Base.showerror(io::IO, err::ReferenceError) print(io, "FunSQL.ReferenceError: ") - if err.type == REFERENCE_ERROR_TYPE.UNDEFINED_HANDLE - print(io, "node-bound reference failed to resolve") - elseif err.type == REFERENCE_ERROR_TYPE.AMBIGUOUS_HANDLE - print(io, "node-bound reference is ambiguous") - elseif err.type == REFERENCE_ERROR_TYPE.UNDEFINED_NAME + if err.type == REFERENCE_ERROR_TYPE.UNDEFINED_NAME print(io, "cannot find `$(err.name)`") - elseif err.type == REFERENCE_ERROR_TYPE.AMBIGUOUS_NAME - print(io, "`$(err.name)` is ambiguous") elseif err.type == REFERENCE_ERROR_TYPE.UNEXPECTED_ROW_TYPE print(io, "incomplete reference `$(err.name)`") elseif err.type == REFERENCE_ERROR_TYPE.UNEXPECTED_SCALAR_TYPE print(io, "unexpected reference after `$(err.name)`") elseif err.type == REFERENCE_ERROR_TYPE.UNEXPECTED_AGGREGATE print(io, "aggregate expression requires Group or Partition") - elseif err.type == REFERENCE_ERROR_TYPE.AMBIGUOUS_AGGREGATE - print(io, "aggregate expression is ambiguous") elseif err.type == REFERENCE_ERROR_TYPE.UNDEFINED_TABLE_REFERENCE print(io, "cannot find `$(err.name)`") elseif err.type == REFERENCE_ERROR_TYPE.INVALID_TABLE_REFERENCE @@ -391,9 +415,9 @@ end function highlight(path::Vector{SQLNode}, color = Base.error_color()) @assert !isempty(path) - n = Highlight(over = path[1], color = color) - for k = 2:lastindex(path) - n = substitute(path[k], path[k-1], n) + n = Highlight(over = path[end], color = color) + for k = lastindex(path):-1:2 + n = substitute(path[k - 1], path[k], n) end n end @@ -417,7 +441,7 @@ function populate_label_map!(n, args = n.args, label_map = n.label_map, group_na for (i, arg) in enumerate(args) name = label(arg) if name === group_name || name in keys(label_map) - err = DuplicateLabelError(name, path = [arg, n]) + err = DuplicateLabelError(name, path = [n, arg]) throw(err) end label_map[name] = i @@ -436,8 +460,8 @@ struct TransliterateContext TransliterateContext(mod::Module, src::LineNumberNode, decl::Bool = false) = new(mod, src, decl) - TransliterateContext(ctx::TransliterateContext; src = missing, decl = missing) = - new(ctx.mod, coalesce(src, ctx.src), coalesce(decl, ctx.decl)) + TransliterateContext(ctx::TransliterateContext; src = ctx.src, decl = ctx.decl) = + new(ctx.mod, src, decl) end """ @@ -641,6 +665,7 @@ include("nodes/function.jl") include("nodes/get.jl") include("nodes/group.jl") include("nodes/highlight.jl") +include("nodes/internal.jl") include("nodes/iterate.jl") include("nodes/join.jl") include("nodes/limit.jl") diff --git a/src/nodes/aggregate.jl b/src/nodes/aggregate.jl index 2586f4d7..87cd4824 100644 --- a/src/nodes/aggregate.jl +++ b/src/nodes/aggregate.jl @@ -145,12 +145,6 @@ end label(n::AggregateNode) = Meta.isidentifier(n.name) ? n.name : :_ -rebase(n::AggregateNode, n′) = - AggregateNode(over = rebase(n.over, n′), - name = n.name, - args = n.args, - filter = n.filter) - # Notation for making aggregate nodes. diff --git a/src/nodes/append.jl b/src/nodes/append.jl index 392c7792..f837169c 100644 --- a/src/nodes/append.jl +++ b/src/nodes/append.jl @@ -85,6 +85,3 @@ function label(n::AppendNode) end lbl end - -rebase(n::AppendNode, n′) = - AppendNode(over = rebase(n.over, n′), args = n.args) diff --git a/src/nodes/as.jl b/src/nodes/as.jl index 3176cd60..ca52308e 100644 --- a/src/nodes/as.jl +++ b/src/nodes/as.jl @@ -78,6 +78,3 @@ end label(n::AsNode) = n.name - -rebase(n::AsNode, n′) = - AsNode(over = rebase(n.over, n′), name = n.name) diff --git a/src/nodes/bind.jl b/src/nodes/bind.jl index 86280f49..a9f49721 100644 --- a/src/nodes/bind.jl +++ b/src/nodes/bind.jl @@ -107,9 +107,3 @@ function PrettyPrinting.quoteof(n::BindNode, ctx::QuoteContext) end ex end - -label(n::BindNode) = - label(n.over) - -rebase(n::BindNode, n′) = - BindNode(over = rebase(n.over, n′), args = n.args) diff --git a/src/nodes/define.jl b/src/nodes/define.jl index 991e5cfb..0b4d2a47 100644 --- a/src/nodes/define.jl +++ b/src/nodes/define.jl @@ -83,9 +83,3 @@ function PrettyPrinting.quoteof(n::DefineNode, ctx::QuoteContext) end ex end - -label(n::DefineNode) = - label(n.over) - -rebase(n::DefineNode, n′) = - DefineNode(over = rebase(n.over, n′), args = n.args, label_map = n.label_map) diff --git a/src/nodes/from.jl b/src/nodes/from.jl index eb3ba2db..9a5b1362 100644 --- a/src/nodes/from.jl +++ b/src/nodes/from.jl @@ -3,7 +3,7 @@ abstract type AbstractSource end -struct SelfSource <: AbstractSource +struct IterateSource <: AbstractSource end struct ValuesSource <: AbstractSource @@ -15,17 +15,17 @@ struct FunctionSource <: AbstractSource columns::Vector{Symbol} end -_from_source(source::Union{SQLTable, SelfSource, FunctionSource, ValuesSource, Nothing}) = +_from_source(source::Union{SQLTable, IterateSource, FunctionSource, ValuesSource, Nothing}) = source _from_source(source::AbstractString) = Symbol(source) _from_source(source::Symbol) = - source === :^ ? SelfSource() : source + source === :^ ? IterateSource() : source _from_source(::typeof(^)) = - SelfSource() + IterateSource() function _from_source(node::AbstractSQLNode; columns::AbstractVector{<:Union{Symbol, AbstractString}}) @@ -51,7 +51,7 @@ function _from_source(source) end mutable struct FromNode <: TabularNode - source::Union{SQLTable, Symbol, SelfSource, ValuesSource, FunctionSource, Nothing} + source::Union{SQLTable, Symbol, IterateSource, ValuesSource, FunctionSource, Nothing} FromNode(; source, kws...) = new(_from_source(source; kws...)) @@ -230,7 +230,7 @@ function PrettyPrinting.quoteof(n::FromNode, ctx::QuoteContext) Expr(:call, nameof(From), tex) elseif source isa Symbol Expr(:call, nameof(From), QuoteNode(source)) - elseif source isa SelfSource + elseif source isa IterateSource Expr(:call, nameof(From), :^) elseif source isa ValuesSource Expr(:call, nameof(From), quoteof(source.columns, ctx)) diff --git a/src/nodes/get.jl b/src/nodes/get.jl index 250656e7..d8dcb45e 100644 --- a/src/nodes/get.jl +++ b/src/nodes/get.jl @@ -25,8 +25,6 @@ A reference to a column of the input dataset. When a column reference is ambiguous (e.g., with [`Join`](@ref)), use [`As`](@ref) to disambiguate the columns, and a chained `Get` node (`Get.a.b.….z`) to refer to a column wrapped with `… |> As(:b) |> As(:a)`. -Alternatively, `Get` could be explicitly bound to the tabular node that -produces the given column. # Examples @@ -55,29 +53,6 @@ julia> q = From(:person) |> on = Get.location_id .== Get.location.location_id) |> Select(Get.person_id, Get.location.state); -julia> print(render(q, tables = [person, location])) -SELECT - "person_1"."person_id", - "location_1"."state" -FROM "person" AS "person_1" -JOIN "location" AS "location_1" ON ("person_1"."location_id" = "location_1"."location_id") -``` - -*Show patients with their state of residence.* - -```jldoctest -julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :location_id]); - -julia> location = SQLTable(:location, columns = [:location_id, :state]); - -julia> qₚ = From(:person); - -julia> qₗ = From(:location); - -julia> q = qₚ |> - Join(qₗ, on = qₚ.location_id .== qₗ.location_id) |> - Select(qₚ.person_id, qₗ.state); - julia> print(render(q, tables = [person, location])) SELECT "person_1"."person_id", @@ -138,6 +113,3 @@ end label(n::GetNode) = n.name - -rebase(n::GetNode, n′) = - GetNode(over = rebase(n.over, n′), name = n.name) diff --git a/src/nodes/group.jl b/src/nodes/group.jl index 6f4b8394..cac773f8 100644 --- a/src/nodes/group.jl +++ b/src/nodes/group.jl @@ -123,9 +123,3 @@ function PrettyPrinting.quoteof(n::GroupNode, ctx::QuoteContext) end ex end - -label(n::GroupNode) = - label(n.over) - -rebase(n::GroupNode, n′) = - GroupNode(over = rebase(n.over, n′), by = n.by, name = n.name, label_map = n.label_map) diff --git a/src/nodes/highlight.jl b/src/nodes/highlight.jl index 57bdf21a..bc22224c 100644 --- a/src/nodes/highlight.jl +++ b/src/nodes/highlight.jl @@ -84,9 +84,3 @@ function PrettyPrinting.quoteof(n::HighlightNode, ctx::QuoteContext) pop!(ctx.colors) EscWrapper(ex, n.color, copy(ctx.colors)) end - -label(n::HighlightNode) = - label(n.over) - -rebase(n::HighlightNode, n′) = - HighlightNode(over = rebase(n.over, n′), color = n.color) diff --git a/src/nodes/internal.jl b/src/nodes/internal.jl new file mode 100644 index 00000000..00b61b84 --- /dev/null +++ b/src/nodes/internal.jl @@ -0,0 +1,303 @@ +# Auxiliary nodes. + +# Preserve context between rendering passes. +mutable struct WithContextNode <: AbstractSQLNode + over::SQLNode + dialect::SQLDialect + tables::Dict{Symbol, SQLTable} + defs::Vector{SQLNode} + + WithContextNode(; over, dialect = default_dialect, tables = Dict{Symbol, SQLTable}(), defs = SQLNode[]) = + new(over, dialect, tables, defs) +end + +WithContext(args...; kws...) = + WithContextNode(args...; kws...) |> SQLNode + +dissect(scr::Symbol, ::typeof(WithContext), pats::Vector{Any}) = + dissect(scr, WithContextNode, pats) + +function PrettyPrinting.quoteof(n::WithContextNode, ctx::QuoteContext) + ex = Expr(:call, nameof(WithContext), Expr(:kw, :over, quoteof(n.over, ctx))) + if n.dialect != default_dialect + push!(ex.args, Expr(:kw, :dialect, quoteof(n.dialect))) + end + if !isempty(n.tables) + push!(ex.args, Expr(:kw, :tables, quoteof(n.tables))) + end + if !isempty(n.defs) + push!(ex.args, Expr(:kw, :defs, Expr(:vect, Any[quoteof(def, ctx) for def in n.defs]...))) + end + ex +end + +# Annotations added by "resolve" pass. +mutable struct ResolvedNode <: AbstractSQLNode + over::SQLNode + type::AbstractSQLType + + ResolvedNode(; over, type) = + new(over, type) +end + +ResolvedNode(type; over) = + ResolvedNode(over = over, type = type) + +Resolved(args...; kws...) = + ResolvedNode(args...; kws...) |> SQLNode + +dissect(scr::Symbol, ::typeof(Resolved), pats::Vector{Any}) = + dissect(scr, ResolvedNode, pats) + +function PrettyPrinting.quoteof(n::ResolvedNode, ctx::QuoteContext) + ex = Expr(:call, nameof(Resolved), quoteof(n.type)) + push!(ex.args, Expr(:kw, :over, quoteof(n.over, ctx))) + ex +end + +# Annotations added by "link" pass. +mutable struct LinkedNode <: TabularNode + over::SQLNode + refs::Vector{SQLNode} + n_ext_refs::Int + + LinkedNode(; over, refs = SQLNode[], n_ext_refs = length(refs)) = + new(over, refs, n_ext_refs) +end + +LinkedNode(refs, n_ext_refs = length(refs); over) = + LinkedNode(over = over, refs = refs, n_ext_refs = n_ext_refs) + +Linked(args...; kws...) = + LinkedNode(args...; kws...) |> SQLNode + +dissect(scr::Symbol, ::typeof(Linked), pats::Vector{Any}) = + dissect(scr, LinkedNode, pats) + +function PrettyPrinting.quoteof(n::LinkedNode, ctx::QuoteContext) + ex = Expr(:call, nameof(Linked)) + if !isempty(n.refs) + push!(ex.args, Expr(:vect, Any[quoteof(ref, ctx) for ref in n.refs]...)) + end + if n.n_ext_refs != length(n.refs) + push!(ex.args, n.n_ext_refs) + end + push!(ex.args, Expr(:kw, :over, quoteof(n.over, ctx))) + ex +end + +# Get(over = Get(:a), name = :b) => Nested(over = Get(:b), name = :a) +mutable struct NestedNode <: AbstractSQLNode + over::SQLNode + name::Symbol + + NestedNode(; over, name) = + new(over, name) +end + +Nested(args...; kws...) = + NestedNode(args...; kws...) |> SQLNode + +dissect(scr::Symbol, ::typeof(Nested), pats::Vector{Any}) = + dissect(scr, NestedNode, pats) + +PrettyPrinting.quoteof(n::NestedNode, ctx::QuoteContext) = + Expr(:call, nameof(Nested), Expr(:kw, :over, quoteof(n.over, ctx)), Expr(:kw, :name, QuoteNode(n.name))) + +# Var() that found the corresponding Bind() +mutable struct BoundVariableNode <: AbstractSQLNode + name::Symbol + depth::Int + + BoundVariableNode(; name, depth) = + new(name, depth) +end + +BoundVariableNode(name, depth) = + BoundVariableNode(name = name, depth = depth) + +BoundVariable(args...; kws...) = + BoundVariableNode(args...; kws...) |> SQLNode + +PrettyPrinting.quoteof(n::BoundVariableNode, ctx::QuoteContext) = + Expr(:call, nameof(BoundVariable), QuoteNode(n.name), n.depth) + +# A generic From node is specialized to FromNothing, FromTable, +# FromTableExpression, FromIterate, FromValues, or FromFunction. +mutable struct FromNothingNode <: TabularNode +end + +FromNothing(args...; kws...) = + FromNothingNode(args...; kws...) |> SQLNode + +PrettyPrinting.quoteof(::FromNothingNode, ::QuoteContext) = + Expr(:call, nameof(FromNothing)) + +mutable struct FromTableNode <: TabularNode + table::SQLTable + + FromTableNode(; table) = + new(table) +end + +FromTable(args...; kws...) = + FromTableNode(args...; kws...) |> SQLNode + +function PrettyPrinting.quoteof(n::FromTableNode, ctx::QuoteContext) + tex = get(ctx.vars, n.table, nothing) + if tex === nothing + tex = quoteof(n.table, limit = true) + end + Expr(:call, nameof(FromTable), Expr(:kw, :table, tex)) +end + +mutable struct FromTableExpressionNode <: TabularNode + name::Symbol + depth::Int + + FromTableExpressionNode(; name, depth) = + new(name, depth) +end + +FromTableExpressionNode(name, depth) = + FromTableExpressionNode(name = name, depth = depth) + +FromTableExpression(args...; kws...) = + FromTableExpressionNode(args...; kws...) |> SQLNode + +PrettyPrinting.quoteof(n::FromTableExpressionNode, ctx::QuoteContext) = + Expr(:call, nameof(FromTableExpression), QuoteNode(n.name), n.depth) + +mutable struct FromIterateNode <: TabularNode +end + +FromIterate(args...; kws...) = + FromIterateNode(args...; kws...) |> SQLNode + +PrettyPrinting.quoteof(n::FromIterateNode, ctx::QuoteContext) = + Expr(:call, nameof(FromIterate)) + +mutable struct FromValuesNode <: TabularNode + columns::NamedTuple + + FromValuesNode(; columns) = + new(columns) +end + +FromValues(args...; kws...) = + FromValuesNode(args...; kws...) |> SQLNode + +PrettyPrinting.quoteof(n::FromValuesNode, ctx::QuoteContext) = + Expr(:call, nameof(FromValues), Expr(:kw, :columns, quoteof(n.columns, ctx))) + +mutable struct FromFunctionNode <: TabularNode + over::SQLNode + columns::Vector{Symbol} + + FromFunctionNode(; over, columns) = + new(over, columns) +end + +FromFunction(args...; kws...) = + FromFunctionNode(args...; kws...) |> SQLNode + +PrettyPrinting.quoteof(n::FromFunctionNode, ctx::QuoteContext) = + Expr(:call, + nameof(FromFunction), + Expr(:kw, :over, quoteof(n.over, ctx)), + Expr(:kw, :columns, Expr(:vect, [QuoteNode(col) for col in n.columns]...))) + +# Annotated Join node. +struct JoinRouter + label_set::Set{Symbol} + group::Bool +end + +PrettyPrinting.quoteof(r::JoinRouter) = + Expr(:call, nameof(JoinRouter), quoteof(r.label_set), quoteof(r.group)) + +mutable struct RoutedJoinNode <: TabularNode + over::Union{SQLNode, Nothing} + joinee::SQLNode + on::SQLNode + router::JoinRouter + left::Bool + right::Bool + lateral::Bool + optional::Bool + + RoutedJoinNode(; over, joinee, on, router, left, right, lateral = false, optional = false) = + new(over, joinee, on, router, left, right, lateral, optional) +end + +RoutedJoinNode(joinee, on; over = nothing, router, left = false, right = false, lateral = false, optional = false) = + RoutedJoinNode(over = over, joinee = joinee, on = on, router, left = left, right = right, lateral = lateral, optional = optional) + +RoutedJoin(args...; kws...) = + RoutedJoinNode(args...; kws...) |> SQLNode + +function PrettyPrinting.quoteof(n::RoutedJoinNode, ctx::QuoteContext) + ex = Expr(:call, nameof(RoutedJoin)) + if !ctx.limit + push!(ex.args, quoteof(n.joinee, ctx)) + push!(ex.args, quoteof(n.on, ctx)) + push!(ex.args, Expr(:kw, :router, quoteof(n.router))) + if n.left + push!(ex.args, Expr(:kw, :left, n.left)) + end + if n.right + push!(ex.args, Expr(:kw, :right, n.right)) + end + if n.lateral + push!(ex.args, Expr(:kw, :lateral, n.lateral)) + end + if n.optional + push!(ex.args, Expr(:kw, :optional, n.optional)) + end + else + push!(ex.args, :…) + end + if n.over !== nothing + ex = Expr(:call, :|>, quoteof(n.over, ctx), ex) + end + ex +end + +# Calculates the keys of a Group node. Also used by Iterate. +mutable struct PaddingNode <: TabularNode + over::Union{SQLNode, Nothing} + + PaddingNode(; over = nothing) = + new(over) +end + +Padding(args...; kws...) = + PaddingNode(args...; kws...) |> SQLNode + +function PrettyPrinting.quoteof(n::PaddingNode, ctx::QuoteContext) + ex = Expr(:call, nameof(Padding)) + if n.over !== nothing + ex = Expr(:call, :|>, quoteof(n.over, ctx), ex) + end + ex +end + +# Isolated subquery. +mutable struct IsolatedNode <: AbstractSQLNode + idx::Int + + IsolatedNode(; idx) = + new(idx) +end + +IsolatedNode(idx) = + IsolatedNode(idx = idx) + +Isolated(args...; kws...) = + IsolatedNode(args...; kws...) |> SQLNode + +PrettyPrinting.quoteof(n::IsolatedNode, ctx::QuoteContext) = + Expr(:call, nameof(Isolated), n.idx) + +dissect(scr::Symbol, ::typeof(Isolated), pats::Vector{Any}) = + dissect(scr, IsolatedNode, pats) diff --git a/src/nodes/iterate.jl b/src/nodes/iterate.jl index 852dce9b..aec63a76 100644 --- a/src/nodes/iterate.jl +++ b/src/nodes/iterate.jl @@ -111,9 +111,3 @@ function PrettyPrinting.quoteof(n::IterateNode, ctx::QuoteContext) end ex end - -label(n::IterateNode) = - label(n.over) - -rebase(n::IterateNode, n′) = - IterateNode(over = rebase(n.over, n′), iterator = n.iterator) diff --git a/src/nodes/join.jl b/src/nodes/join.jl index 4c886e26..b5e56536 100644 --- a/src/nodes/join.jl +++ b/src/nodes/join.jl @@ -115,11 +115,3 @@ function PrettyPrinting.quoteof(n::JoinNode, ctx::QuoteContext) end ex end - -label(n::JoinNode) = - label(n.over) - -rebase(n::JoinNode, n′) = - JoinNode(over = rebase(n.over, n′), - joinee = n.joinee, on = n.on, left = n.left, right = n.right, optional = n.optional) - diff --git a/src/nodes/limit.jl b/src/nodes/limit.jl index af09fca9..143a728c 100644 --- a/src/nodes/limit.jl +++ b/src/nodes/limit.jl @@ -77,9 +77,3 @@ function PrettyPrinting.quoteof(n::LimitNode, ctx::QuoteContext) end ex end - -label(n::LimitNode) = - label(n.over) - -rebase(n::LimitNode, n′) = - LimitNode(over = rebase(n.over, n′), offset = n.offset, limit = n.limit) diff --git a/src/nodes/literal.jl b/src/nodes/literal.jl index 4139ddaf..12ae01b8 100644 --- a/src/nodes/literal.jl +++ b/src/nodes/literal.jl @@ -51,4 +51,3 @@ Base.convert(::Type{AbstractSQLNode}, ref::Base.RefValue) = PrettyPrinting.quoteof(n::LiteralNode, ctx::QuoteContext) = Expr(:call, nameof(Lit), n.val) - diff --git a/src/nodes/order.jl b/src/nodes/order.jl index b5143e2c..37599872 100644 --- a/src/nodes/order.jl +++ b/src/nodes/order.jl @@ -64,10 +64,3 @@ function PrettyPrinting.quoteof(n::OrderNode, ctx::QuoteContext) end ex end - -label(n::OrderNode) = - label(n.over) - -rebase(n::OrderNode, n′) = - OrderNode(over = rebase(n.over, n′), by = n.by) - diff --git a/src/nodes/over.jl b/src/nodes/over.jl index b6297c98..09e66b67 100644 --- a/src/nodes/over.jl +++ b/src/nodes/over.jl @@ -72,6 +72,3 @@ end label(n::OverNode) = label(n.arg) - -rebase(n::OverNode, n′) = - OverNode(over = rebase(n.over, n′), arg = n.arg, materialized = n.materialized) diff --git a/src/nodes/partition.jl b/src/nodes/partition.jl index 157a9538..c0b68b2f 100644 --- a/src/nodes/partition.jl +++ b/src/nodes/partition.jl @@ -118,9 +118,3 @@ function PrettyPrinting.quoteof(n::PartitionNode, ctx::QuoteContext) end ex end - -label(n::PartitionNode) = - label(n.over) - -rebase(n::PartitionNode, n′) = - PartitionNode(over = rebase(n.over, n′), by = n.by, order_by = n.order_by, frame = n.frame, name = n.name) diff --git a/src/nodes/select.jl b/src/nodes/select.jl index 00c8a1a4..dc6f9c16 100644 --- a/src/nodes/select.jl +++ b/src/nodes/select.jl @@ -70,9 +70,3 @@ function PrettyPrinting.quoteof(n::SelectNode, ctx::QuoteContext) end ex end - -label(n::SelectNode) = - label(n.over) - -rebase(n::SelectNode, n′) = - SelectNode(over = rebase(n.over, n′), args = n.args, label_map = n.label_map) diff --git a/src/nodes/sort.jl b/src/nodes/sort.jl index e67f4e9b..4e88a067 100644 --- a/src/nodes/sort.jl +++ b/src/nodes/sort.jl @@ -87,9 +87,3 @@ function PrettyPrinting.quoteof(n::SortNode, ctx::QuoteContext) end ex end - -label(n::SortNode) = - label(n.over) - -rebase(n::SortNode, n′) = - SortNode(over = rebase(n.over, n′), value = n.value, nulls = n.nulls) diff --git a/src/nodes/where.jl b/src/nodes/where.jl index f3db2008..bc1c71bb 100644 --- a/src/nodes/where.jl +++ b/src/nodes/where.jl @@ -55,9 +55,3 @@ function PrettyPrinting.quoteof(n::WhereNode, ctx::QuoteContext) end ex end - -label(n::WhereNode) = - label(n.over) - -rebase(n::WhereNode, n′) = - WhereNode(over = rebase(n.over, n′), condition = n.condition) diff --git a/src/nodes/with.jl b/src/nodes/with.jl index d930fd13..08c24f4c 100644 --- a/src/nodes/with.jl +++ b/src/nodes/with.jl @@ -90,9 +90,3 @@ function PrettyPrinting.quoteof(n::WithNode, ctx::QuoteContext) end ex end - -label(n::WithNode) = - label(n.over) - -rebase(n::WithNode, n′) = - WithNode(over = rebase(n.over, n′), args = n.args, materialized = n.materialized) diff --git a/src/nodes/with_external.jl b/src/nodes/with_external.jl index fd0da5db..9dca4570 100644 --- a/src/nodes/with_external.jl +++ b/src/nodes/with_external.jl @@ -98,10 +98,3 @@ function PrettyPrinting.quoteof(n::WithExternalNode, ctx::QuoteContext) end ex end - -label(n::WithExternalNode) = - label(n.over) - -rebase(n::WithExternalNode, n′) = - WithExternalNode(over = rebase(n.over, n′), args = n.args, qualifiers = n.qualifiers, handler = n.handler) - diff --git a/src/render.jl b/src/render.jl index 55d6fbe0..cfdba5fc 100644 --- a/src/render.jl +++ b/src/render.jl @@ -56,18 +56,15 @@ function render(catalog::SQLCatalog, n::SQLNode) return sql end end - actx = AnnotateContext(catalog) - n′ = annotate(n, actx) - @debug "FunSQL.annotate\n" * sprint(pprint, n′) _group = Symbol("FunSQL.annotate") - resolve!(actx) - @debug "FunSQL.resolve!\n" * sprint(pprint, n′) _group = Symbol("FunSQL.resolve!") - link!(actx) - @debug "FunSQL.link!\n" * sprint(pprint, n′) _group = Symbol("FunSQL.link!") - tctx = TranslateContext(actx) - c = translate_toplevel(n′, tctx) + n = WithContext(over = n, dialect = catalog.dialect, tables = catalog.tables) + n = resolve(n) + @debug "FunSQL.resolve\n" * sprint(pprint, n) _group = Symbol("FunSQL.resolve") + n = link(n) + @debug "FunSQL.link\n" * sprint(pprint, n) _group = Symbol("FunSQL.link") + c = translate(n) @debug "FunSQL.translate\n" * sprint(pprint, c) _group = Symbol("FunSQL.translate") - sql = render(catalog.dialect, c) - @debug "FunSQL.render\n" * sql _group = Symbol("FunSQL.render") + sql = serialize(c) + @debug "FunSQL.serialize\n" * sprint(pprint, sql) _group = Symbol("FunSQL.serialize") if cache !== nothing cache[n] = sql end @@ -87,9 +84,8 @@ render(dialect::SQLDialect, c::AbstractSQLClause) = Serialize the syntax tree of a SQL query. """ function render(dialect::SQLDialect, c::SQLClause) - ctx = SerializeContext(dialect) - serialize!(c, ctx) - raw = String(take!(ctx.io)) - SQLString(raw, vars = ctx.vars) + c = WITH_CONTEXT(over = c, dialect = dialect) + sql = serialize(c) + sql end diff --git a/src/resolve.jl b/src/resolve.jl new file mode 100644 index 00000000..2b1fd945 --- /dev/null +++ b/src/resolve.jl @@ -0,0 +1,532 @@ +# Resolving node types. + +struct ResolveContext + dialect::SQLDialect + tables::Dict{Symbol, SQLTable} + path::Vector{SQLNode} + row_type::RowType + cte_types::Base.ImmutableDict{Symbol, Tuple{Int, RowType}} + var_types::Base.ImmutableDict{Symbol, Tuple{Int, ScalarType}} + knot_type::Union{RowType, Nothing} + implicit_knot::Bool + + ResolveContext(dialect, tables) = + new(dialect, + tables, + SQLNode[], + EMPTY_ROW, + Base.ImmutableDict{Symbol, Tuple{Int, RowType}}(), + Base.ImmutableDict{Symbol, Tuple{Int, ScalarType}}(), + nothing, + false) + + ResolveContext( + ctx::ResolveContext; + row_type = ctx.row_type, + cte_types = ctx.cte_types, + var_types = ctx.var_types, + knot_type = ctx.knot_type, + implicit_knot = ctx.implicit_knot) = + new(ctx.dialect, + ctx.tables, + ctx.path, + row_type, + cte_types, + var_types, + knot_type, + implicit_knot) +end + +get_path(ctx::ResolveContext) = + copy(ctx.path) + +function row_type(n::SQLNode) + @dissect(n, Resolved(type = type::RowType)) || throw(IllFormedError()) + type +end + +function type(n::SQLNode) + @dissect(n, Resolved(type = t)) || throw(IllFormedError()) + t +end + +function resolve(n::SQLNode) + @dissect(n, WithContext(over = n′, dialect = dialect, tables = tables)) || throw(IllFormedError()) + ctx = ResolveContext(dialect, tables) + WithContext(over = resolve(n′, ctx), dialect = dialect) +end + +function resolve(n::SQLNode, ctx) + push!(ctx.path, n) + try + convert(SQLNode, resolve(n[], ctx)) + finally + pop!(ctx.path) + end +end + +resolve(ns::Vector{SQLNode}, ctx) = + SQLNode[resolve(n, ctx) for n in ns] + +function resolve(::Nothing, ctx) + t = ctx.knot_type + if t !== nothing && ctx.implicit_knot + n = FromIterate() + else + n = FromNothing() + t = EMPTY_ROW + end + Resolved(t, over = n) +end + +resolve(n, ctx, t) = + resolve(n, ResolveContext(ctx, row_type = t)) + +resolve(n::AbstractSQLNode, ctx) = + throw(IllFormedError(path = get_path(ctx))) + +function resolve_scalar(n::SQLNode, ctx) + push!(ctx.path, n) + n′ = convert(SQLNode, resolve_scalar(n[], ctx)) + pop!(ctx.path) + n′ +end + +function resolve_scalar(ns::Vector{SQLNode}, ctx) + SQLNode[resolve_scalar(n, ctx) for n in ns] +end + +resolve_scalar(n, ctx, t) = + resolve_scalar(n, ResolveContext(ctx, row_type = t)) + +function resolve_scalar(n::TabularNode, ctx) + n′ = resolve(n, ResolveContext(ctx, implicit_knot = false)) + Resolved(ScalarType(), over = n′) +end + +function unnest(node, base, ctx) + while @dissect(node, over |> Get(name = name)) + base = Nested(over = base, name = name) + node = over + end + if node !== nothing + throw(IllFormedError(path = get_path(ctx))) + end + base +end + +function resolve_scalar(n::AggregateNode, ctx) + if n.over !== nothing + n′ = unnest(n.over, Agg(name = n.name, args = n.args, filter = n.filter), ctx) + return resolve_scalar(n′, ctx) + end + t = ctx.row_type.group + if !(t isa RowType) + error_type = REFERENCE_ERROR_TYPE.UNEXPECTED_AGGREGATE + throw(ReferenceError(error_type, path = get_path(ctx))) + end + ctx′ = ResolveContext(ctx, row_type = t) + args′ = resolve_scalar(n.args, ctx′) + filter′ = nothing + if n.filter !== nothing + filter′ = resolve_scalar(n.filter, ctx′) + end + n′ = Agg(name = n.name, args = args′, filter = filter′) + Resolved(ScalarType(), over = n′) +end + +function resolve(n::AppendNode, ctx) + over = n.over + args = n.args + if over === nothing && !ctx.implicit_knot + if !isempty(args) + over = args[1] + args = args[2:end] + else + over = Where(false) + end + end + over′ = resolve(over, ctx) + args′ = resolve(args, ResolveContext(ctx, implicit_knot = false)) + n′ = Append(over = over′, args = args′) + t = row_type(over′) + for arg in args′ + t = intersect(t, row_type(arg)) + end + Resolved(t, over = n′) +end + +function resolve(n::AsNode, ctx) + over′ = resolve(n.over, ctx) + t = row_type(over′) + n′ = As(name = n.name, over = over′) + Resolved(RowType(FieldTypeMap(n.name => t)), over = n′) +end + +function resolve_scalar(n::AsNode, ctx) + over′ = resolve_scalar(n.over, ctx) + n′ = As(name = n.name, over = over′) + Resolved(type(over′), over = n′) +end + +function resolve(n::BindNode, ctx) + args′ = resolve_scalar(n.args, ctx) + var_types′ = ctx.var_types + for (name, i) in n.label_map + v = get(ctx.var_types, name, nothing) + depth = 1 + (v !== nothing ? v[1] : 0) + t = type(args′[i]) + if !(t isa ScalarType) + throw( + ReferenceError( + REFERENCE_ERROR_TYPE.UNEXPECTED_ROW_TYPE, + name = name, + path = get_path(ctx))) + + end + var_types′ = Base.ImmutableDict(var_types′, name => (depth, t)) + end + over′ = resolve(n.over, ResolveContext(ctx, var_types = var_types′)) + n′ = Bind(over = over′, args = args′, label_map = n.label_map) + Resolved(row_type(over′), over = n′) +end + +function resolve_scalar(n::BindNode, ctx) + args′ = resolve_scalar(n.args, ctx) + var_types′ = ctx.var_types + for (name, i) in n.label_map + v = get(ctx.var_types, name, nothing) + depth = 1 + (v !== nothing ? v[1] : 0) + t = type(args′[i]) + if !(t isa ScalarType) + throw( + ReferenceError( + REFERENCE_ERROR_TYPE.UNEXPECTED_ROW_TYPE, + name = name, + path = get_path(ctx))) + + end + var_types′ = Base.ImmutableDict(var_types′, name => (depth, t)) + end + over′ = resolve_scalar(n.over, ResolveContext(ctx, var_types = var_types′)) + n′ = Bind(over = over′, args = args′, label_map = n.label_map) + Resolved(type(over′), over = n′) +end + +function resolve_scalar(n::NestedNode, ctx) + t = get(ctx.row_type.fields, n.name, EmptyType()) + if !(t isa RowType) + error_type = + t isa EmptyType ? + REFERENCE_ERROR_TYPE.UNDEFINED_NAME : + REFERENCE_ERROR_TYPE.UNEXPECTED_SCALAR_TYPE + throw(ReferenceError(error_type, name = n.name, path = get_path(ctx))) + end + over′ = resolve_scalar(n.over, ctx, t) + n′ = NestedNode(over = over′, name = n.name) + Resolved(type(over′), over = n′) +end + +function resolve(n::DefineNode, ctx) + over′ = resolve(n.over, ctx) + t = row_type(over′) + args′ = resolve_scalar(n.args, ctx, t) + fields = FieldTypeMap() + for (f, ft) in t.fields + i = get(n.label_map, f, nothing) + if i !== nothing + ft = type(args′[i]) + end + fields[f] = ft + end + for (f, i) in n.label_map + if !haskey(fields, f) + fields[f] = type(args′[i]) + end + end + n′ = Define(over = over′, args = args′, label_map = n.label_map) + Resolved(RowType(fields, t.group), over = n′) +end + +function RowType(table::SQLTable) + fields = FieldTypeMap() + for f in table.columns + fields[f] = ScalarType() + end + RowType(fields) +end + +function resolve(n::FromNode, ctx) + source = n.source + if source isa SQLTable + n′ = FromTable(table = source) + t = RowType(source) + elseif source isa Symbol + v = get(ctx.cte_types, source, nothing) + if v !== nothing + (depth, t) = v + n′ = FromTableExpression(source, depth) + else + table = get(ctx.tables, source, nothing) + if table === nothing + throw( + ReferenceError( + REFERENCE_ERROR_TYPE.UNDEFINED_TABLE_REFERENCE, + name = source, + path = get_path(ctx))) + end + n′ = FromTable(table = table) + t = RowType(table) + end + elseif source isa IterateSource + t = ctx.knot_type + if t === nothing + throw( + ReferenceError( + REFERENCE_ERROR_TYPE.INVALID_SELF_REFERENCE, + path = get_path(ctx))) + end + n′ = FromIterate() + elseif source isa ValuesSource + n′ = FromValues(columns = source.columns) + fields = FieldTypeMap() + for f in keys(source.columns) + fields[f] = ScalarType() + end + t = RowType(fields) + elseif source isa FunctionSource + n′ = FromFunction(over = resolve_scalar(source.node, ctx), columns = source.columns) + fields = FieldTypeMap() + for f in source.columns + fields[f] = ScalarType() + end + t = RowType(fields) + elseif source === nothing + n′ = FromNothing() + t = RowType() + else + error() + end + Resolved(t, over = n′) +end + +function resolve_scalar(n::FunctionNode, ctx) + args′ = resolve_scalar(n.args, ctx) + n′ = Fun(name = n.name, args = args′) + Resolved(ScalarType(), over = n′) +end + +function resolve_scalar(n::GetNode, ctx) + if n.over !== nothing + n′ = unnest(n.over, Get(name = n.name), ctx) + return resolve_scalar(n′, ctx) + end + t = get(ctx.row_type.fields, n.name, EmptyType()) + if !(t isa ScalarType) + error_type = + t isa EmptyType ? + REFERENCE_ERROR_TYPE.UNDEFINED_NAME : + REFERENCE_ERROR_TYPE.UNEXPECTED_ROW_TYPE + throw(ReferenceError(error_type, name = n.name, path = get_path(ctx))) + end + Resolved(t, over = n) +end + +function resolve(n::GroupNode, ctx) + over′ = resolve(n.over, ctx) + t = row_type(over′) + by′ = resolve_scalar(n.by, ctx, t) + fields = FieldTypeMap() + for (name, i) in n.label_map + fields[name] = type(by′[i]) + end + group = t + if n.name !== nothing + fields[n.name] = RowType(FieldTypeMap(), group) + group = EmptyType() + end + n′ = Group(over = over′, by = by′, label_map = n.label_map) + Resolved(RowType(fields, group), over = n′) +end + +resolve(n::HighlightNode, ctx) = + resolve(n.over, ctx) + +resolve_scalar(n::HighlightNode, ctx) = + resolve_scalar(n.over, ctx) + +function resolve(n::IterateNode, ctx) + over′ = resolve(n.over, ResolveContext(ctx, knot_type = nothing, implicit_knot = false)) + t = row_type(over′) + iterator′ = resolve(n.iterator, ResolveContext(ctx, knot_type = t, implicit_knot = true)) + iterator_t = row_type(iterator′) + while !issubset(t, iterator_t) + t = intersect(t, iterator_t) + iterator′ = resolve(n.iterator, ResolveContext(ctx, knot_type = t, implicit_knot = true)) + iterator_t = row_type(iterator′) + end + n′ = IterateNode(over = over′, iterator = iterator′) + Resolved(t, over = n′) +end + +function resolve(n::JoinNode, ctx) + over′ = resolve(n.over, ctx) + lt = row_type(over′) + joinee′ = resolve(n.joinee, ResolveContext(ctx, row_type = lt, implicit_knot = false)) + rt = row_type(joinee′) + fields = FieldTypeMap() + for (f, ft) in lt.fields + fields[f] = get(rt.fields, f, ft) + end + for (f, ft) in rt.fields + if !haskey(fields, f) + fields[f] = ft + end + end + group = rt.group isa EmptyType ? lt.group : rt.group + t = RowType(fields, group) + on′ = resolve_scalar(n.on, ctx, t) + n′ = Join(over = over′, joinee = joinee′, on = on′, left = n.left, right = n.right, optional = n.optional) + Resolved(t, over = n′) +end + +function resolve(n::LimitNode, ctx) + over′ = resolve(n.over, ctx) + if n.offset === nothing && n.limit === nothing + return over′ + end + t = row_type(over′) + n′ = Limit(over = over′, offset = n.offset, limit = n.limit) + Resolved(t, over = n′) +end + +function resolve_scalar(n::LiteralNode, ctx) + Resolved(ScalarType(), over = n) +end + +function resolve(n::OrderNode, ctx) + over′ = resolve(n.over, ctx) + if isempty(n.by) + return over′ + end + t = row_type(over′) + by′ = resolve_scalar(n.by, ctx, t) + n′ = Order(over = over′, by = by′) + Resolved(t, over = n′) +end + +resolve(n::OverNode, ctx) = + resolve(With(over = n.arg, args = n.over !== nothing ? SQLNode[n.over] : SQLNode[]), ctx) + +function resolve(n::PartitionNode, ctx) + over′ = resolve(n.over, ctx) + t = row_type(over′) + ctx′ = ResolveContext(ctx, row_type = t) + by′ = resolve_scalar(n.by, ctx′) + order_by′ = resolve_scalar(n.order_by, ctx′) + fields = t.fields + group = t.group + if n.name === nothing + group = t + else + fields = FieldTypeMap() + for (f, ft) in t.fields + if f !== n.name + fields[f] = ft + end + end + fields[n.name] = RowType(FieldTypeMap(), t) + end + n′ = Partition(over = over′, by = by′, order_by = order_by′, frame = n.frame, name = n.name) + Resolved(RowType(fields, group), over = n′) +end + +resolve(n::ResolvedNode, ctx) = + n + +function resolve(n::SelectNode, ctx) + over′ = resolve(n.over, ctx) + t = row_type(over′) + args′ = resolve_scalar(n.args, ctx, t) + fields = FieldTypeMap() + for (name, i) in n.label_map + fields[name] = type(args′[i]) + end + n′ = Select(over = over′, args = args′, label_map = n.label_map) + Resolved(RowType(fields), over = n′) +end + +function resolve_scalar(n::SortNode, ctx) + over′ = resolve_scalar(n.over, ctx) + n′ = Sort(over = over′, value = n.value, nulls = n.nulls) + Resolved(type(over′), over = n′) +end + +function resolve_scalar(n::VariableNode, ctx) + v = get(ctx.var_types, n.name, nothing) + if v !== nothing + depth, t = v + n′ = BoundVariable(n.name, depth) + Resolved(t, over = n′) + else + Resolved(ScalarType(), over = n) + end +end + +function resolve(n::WhereNode, ctx) + over′ = resolve(n.over, ctx) + t = row_type(over′) + condition′ = resolve_scalar(n.condition, ctx, t) + n′ = Where(over = over′, condition = condition′) + Resolved(t, over = n′) +end + +function resolve(n::WithNode, ctx) + ctx′ = ResolveContext(ctx, knot_type = nothing, implicit_knot = false) + args′ = resolve(n.args, ctx′) + cte_types′ = ctx.cte_types + for (name, i) in n.label_map + v = get(ctx.cte_types, name, nothing) + depth = 1 + (v !== nothing ? v[1] : 0) + t = row_type(args′[i]) + cte_t = get(t.fields, name, EmptyType()) + if !(cte_t isa RowType) + throw( + ReferenceError( + REFERENCE_ERROR_TYPE.INVALID_TABLE_REFERENCE, + name = name, + path = get_path(ctx))) + + end + cte_types′ = Base.ImmutableDict(cte_types′, name => (depth, cte_t)) + end + ctx′ = ResolveContext(ctx, cte_types = cte_types′) + over′ = resolve(n.over, ctx′) + n′ = With(over = over′, args = args′, materialized = n.materialized, label_map = n.label_map) + Resolved(row_type(over′), over = n′) +end + +function resolve(n::WithExternalNode, ctx) + ctx′ = ResolveContext(ctx, knot_type = nothing, implicit_knot = false) + args′ = resolve(n.args, ctx′) + cte_types′ = ctx.cte_types + for (name, i) in n.label_map + v = get(ctx.cte_types, name, nothing) + depth = 1 + (v !== nothing ? v[1] : 0) + t = row_type(args′[i]) + cte_t = get(t.fields, name, EmptyType()) + if !(cte_t isa RowType) + throw( + ReferenceError( + REFERENCE_ERROR_TYPE.INVALID_TABLE_REFERENCE, + name = name, + path = get_path(ctx))) + + end + cte_types′ = Base.ImmutableDict(cte_types′, name => (depth, cte_t)) + end + ctx′ = ResolveContext(ctx, cte_types = cte_types′) + over′ = resolve(n.over, ctx′) + n′ = WithExternal(over = over′, args = args′, qualifiers = n.qualifiers, handler = n.handler, label_map = n.label_map) + Resolved(row_type(over′), over = n′) +end diff --git a/src/serialize.jl b/src/serialize.jl index d909a991..6c60e3b8 100644 --- a/src/serialize.jl +++ b/src/serialize.jl @@ -11,6 +11,14 @@ mutable struct SerializeContext <: IO new(dialect, IOBuffer(), 0, false, Symbol[]) end +function serialize(c::SQLClause) + @dissect(c, WITH_CONTEXT(over = c′, dialect = dialect)) || throw(IllFormedError()) + ctx = SerializeContext(dialect) + serialize!(c′, ctx) + raw = String(take!(ctx.io)) + SQLString(raw, vars = ctx.vars) +end + Base.write(ctx::SerializeContext, octet::UInt8) = write(ctx.io, octet) diff --git a/src/translate.jl b/src/translate.jl index 90403816..923b7305 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -3,12 +3,13 @@ # Partially constructed query. struct Assemblage + name::Symbol # Base name for the alias. clause::Union{SQLClause, Nothing} # A SQL subquery (possibly without SELECT clause). cols::OrderedDict{Symbol, SQLClause} # SELECT arguments, if necessary. repl::Dict{SQLNode, Symbol} # Maps a reference node to a column alias. - Assemblage(clause; cols = OrderedDict{Symbol, SQLClause}(), repl = Dict{SQLNode, Symbol}()) = - new(clause, cols, repl) + Assemblage(name, clause; cols = OrderedDict{Symbol, SQLClause}(), repl = Dict{SQLNode, Symbol}()) = + new(name, clause, cols, repl) end # Pack SELECT arguments. @@ -111,6 +112,16 @@ function make_repl_cols(trns::Vector{Pair{SQLNode, SQLClause}})::Tuple{Dict{SQLN (repl, cols) end +function aligned_columns(refs, repl, args) + length(refs) == length(args) || return false + for (ref, arg) in zip(refs, args) + if !(@dissect(arg, ID(name = name) || AS(name = name)) && name === repl[ref]) + return false + end + end + return true +end + struct CTEAssemblage a::Assemblage name::Symbol @@ -127,35 +138,44 @@ end struct TranslateContext dialect::SQLDialect - path_map::PathMap + defs::Vector{SQLNode} aliases::Dict{Symbol, Int} - cte_map::OrderedDict{SQLNode, CTEAssemblage} recursive::Ref{Bool} - vars::Dict{Symbol, SQLClause} + ctes::Vector{CTEAssemblage} + cte_map::Base.ImmutableDict{Tuple{Symbol, Int}, Int} + knot::Int + refs::Vector{SQLNode} + vars::Base.ImmutableDict{Tuple{Symbol, Int}, SQLClause} subs::Dict{SQLNode, SQLClause} - TranslateContext(ctx::AnnotateContext) = - new(ctx.catalog.dialect, - ctx.path_map, + TranslateContext(; dialect, defs) = + new(dialect, + defs, Dict{Symbol, Int}(), - OrderedDict{SQLNode, CTEAssemblage}(), Ref(false), - Dict{Symbol, SQLClause}(), - Dict{SQLNode, SQLClause}()) - - function TranslateContext(ctx::TranslateContext; vars = nothing, subs = nothing) + CTEAssemblage[], + Base.ImmutableDict{Tuple{Symbol, Int}, Int}(), + 0, + SQLNode[], + Base.ImmutableDict{Tuple{Symbol, Int}, SQLClause}(), + Dict{Int, SQLClause}()) + + function TranslateContext(ctx::TranslateContext; cte_map = ctx.cte_map, knot = ctx.knot, refs = ctx.refs, vars = ctx.vars, subs = ctx.subs) new(ctx.dialect, - ctx.path_map, + ctx.defs, ctx.aliases, - ctx.cte_map, ctx.recursive, - something(vars, ctx.vars), - something(subs, ctx.subs)) + ctx.ctes, + cte_map, + knot, + refs, + vars, + subs) end end -allocate_alias(ctx::TranslateContext, n::SQLNode) = - allocate_alias(ctx, (n[]::BoxNode).type.name) +allocate_alias(ctx::TranslateContext, a::Assemblage) = + allocate_alias(ctx, a.name) function allocate_alias(ctx::TranslateContext, alias::Symbol) n = get(ctx.aliases, alias, 0) + 1 @@ -163,10 +183,12 @@ function allocate_alias(ctx::TranslateContext, alias::Symbol) Symbol(alias, '_', n) end -function translate_toplevel(n::SQLNode, ctx) - c = translate(n, ctx) +function translate(n::SQLNode) + @dissect(n, WithContext(over = n′, dialect = dialect, defs = defs)) || throw(IllFormedError()) + ctx = TranslateContext(dialect = dialect, defs = defs) + c = translate(n′, ctx) with_args = SQLClause[] - for cte_a in values(ctx.cte_map) + for cte_a in ctx.ctes !cte_a.external || continue cols = Symbol[name for name in keys(cte_a.a.cols)] if isempty(cols) @@ -183,15 +205,7 @@ function translate_toplevel(n::SQLNode, ctx) if !isempty(with_args) c = WITH(over = c, args = with_args, recursive = ctx.recursive[]) end - c -end - - -# Translating scalar nodes. - -function translate(n, ctx::TranslateContext, subs::Dict{SQLNode, SQLClause}) - ctx′ = TranslateContext(ctx, subs = subs) - translate(n, ctx′) + WITH_CONTEXT(over = c, dialect = ctx.dialect) end function translate(n::SQLNode, ctx) @@ -202,32 +216,44 @@ function translate(n::SQLNode, ctx) c end -translate(ns::Vector{SQLNode}, ctx) = +function translate(ns::Vector{SQLNode}, ctx) SQLClause[translate(n, ctx) for n in ns] +end translate(::Nothing, ctx) = nothing +function translate(n, ctx::TranslateContext, subs::Dict{SQLNode, SQLClause}) + ctx′ = TranslateContext(ctx, subs = subs) + translate(n, ctx′) +end + function translate(n::AggregateNode, ctx) args = translate(n.args, ctx) filter = translate(n.filter, ctx) AGG(n.name, args = args, filter = filter) end -translate(n::Union{AsNode, HighlightNode}, ctx) = +function translate(n::AsNode, ctx) translate(n.over, ctx) +end -function translate(n::IntBindNode, ctx) - vars′ = copy(ctx.vars) +function translate(n::BindNode, ctx) + vars′ = ctx.vars for (name, i) in n.label_map - vars′[name] = translate(n.args[i], ctx) + depth = _cte_depth(ctx.vars, name) + 1 + vars′ = Base.ImmutableDict(vars′, (name, depth) => translate(n.args[i], ctx)) end ctx′ = TranslateContext(ctx, vars = vars′) translate(n.over, ctx′) end -function translate(n::BoxNode, ctx) - base = assemble(n, ctx) +function translate(n::BoundVariableNode, ctx) + ctx.vars[(n.name, n.depth)] +end + +function translate(n::LinkedNode, ctx) + base = assemble(n.over, TranslateContext(ctx, refs = n.refs)) complete(base) end @@ -270,95 +296,39 @@ function translate(n::FunctionNode, ctx) return LIT(!val) end end - FUN(n.name, args = args) + FUN(name = n.name, args = args) end -translate(n::LiteralNode, ctx) = +function translate(n::IsolatedNode, ctx) + base = assemble(ctx.defs[n.idx], ctx) + complete(base) +end + +function translate(n::LiteralNode, ctx) LIT(n.val) +end + +function translate(n::ResolvedNode, ctx) + translate(n.over, ctx) +end translate(n::SortNode, ctx) = SORT(over = translate(n.over, ctx), value = n.value, nulls = n.nulls) function translate(n::VariableNode, ctx) - c = get(ctx.vars, n.name, nothing) - if c === nothing - c = VAR(n.name) - end - c + VAR(n.name) end - -# Translating subquery nodes. - -assemble(n::SQLNode, ctx) = +function assemble(n::SQLNode, ctx) assemble(n[], ctx) - -function assemble(n::BoxNode, ctx) - refs′ = SQLNode[] - for ref in n.refs - if @dissect(ref, over |> HandleBoundNode(handle = handle)) && handle == n.handle - push!(refs′, over) - else - push!(refs′, ref) - end - end - base = assemble(n.over, refs′, ctx) - repl′ = Dict{SQLNode, Symbol}() - for ref in n.refs - if @dissect(ref, over |> HandleBoundNode(handle = handle)) && handle == n.handle - repl′[ref] = base.repl[over] - else - repl′[ref] = base.repl[ref] - end - end - a = Assemblage(base.clause, cols = base.cols, repl = repl′) - inhibit_duplicates(a, n, ctx) -end - -function inhibit_duplicates(a::Assemblage, n::BoxNode, ctx) - imm_refs_begin_at = something(n.imm_refs_begin_at, length(n.refs) + 1) - imm_refs_begin_at <= length(n.refs) || return a - dups = Set{Symbol}() - for k = 1:lastindex(n.refs) - col = a.repl[n.refs[k]] - if col in dups - if k >= imm_refs_begin_at - alias = allocate_alias(ctx, n.type.name) - c = FROM(AS(over = complete(a), name = alias)) - subs = make_subs(a, alias) - trns = Pair{SQLNode, SQLClause}[] - for ref in n.refs - push!(trns, ref => subs[ref]) - end - repl, cols = make_repl_cols(trns) - return Assemblage(c, cols = cols, repl = repl) - end - elseif !@dissect(a.cols[col], (nothing |> ID() || nothing |> ID() |> ID() || VAR() || LIT())) - push!(dups, col) - end - end - a end -assemble(n::SQLNode, refs, ctx) = - assemble(n[], refs, ctx) - -function assemble(::Nothing, refs, ctx) - @assert isempty(refs) - Assemblage(nothing) +function assemble(::Nothing, ctx) + @assert isempty(ctx.refs) + Assemblage(:_, nothing) end -function aligned_columns(refs, repl, args) - length(refs) == length(args) || return false - for (ref, arg) in zip(refs, args) - if !(@dissect(arg, ID(name = name) || AS(name = name)) && name === repl[ref]) - return false - end - end - return true -end - -function assemble(n::AppendNode, refs, ctx) +function assemble(n::AppendNode, ctx) base = assemble(n.over, ctx) branches = [n.over => base] for arg in n.args @@ -366,7 +336,7 @@ function assemble(n::AppendNode, refs, ctx) end dups = Dict{SQLNode, SQLNode}() seen = Dict{Symbol, SQLNode}() - for ref in refs + for ref in ctx.refs name = base.repl[ref] if name in keys(seen) other_ref = seen[name] @@ -379,7 +349,7 @@ function assemble(n::AppendNode, refs, ctx) end end urefs = SQLNode[] - for ref in refs + for ref in ctx.refs if !(ref in keys(dups)) push!(urefs, ref) dups[ref] = ref @@ -389,8 +359,12 @@ function assemble(n::AppendNode, refs, ctx) for (ref, uref) in dups repl[ref] = repl[uref] end + a_name = base.name cs = SQLClause[] for (arg, a) in branches + if a.name !== a_name + a_name = :union + end if @dissect(a.clause, SELECT(args = args)) && aligned_columns(urefs, repl, args) push!(cs, a.clause) continue @@ -398,7 +372,7 @@ function assemble(n::AppendNode, refs, ctx) alias = nothing tail = a.clause else - alias = allocate_alias(ctx, arg) + alias = allocate_alias(ctx, a) tail = FROM(AS(over = complete(a), name = alias)) end subs = make_subs(a, alias) @@ -411,57 +385,71 @@ function assemble(n::AppendNode, refs, ctx) push!(cs, c) end c = UNION(over = cs[1], all = true, args = cs[2:end]) - Assemblage(c, repl = repl, cols = dummy_cols) + Assemblage(a_name, c, repl = repl, cols = dummy_cols) end -function assemble(n::AsNode, refs, ctx) - base = assemble(n.over, ctx) +function assemble(n::AsNode, ctx) + refs′ = SQLNode[] + for ref in ctx.refs + if @dissect(ref, over |> Nested()) + push!(refs′, over) + else + push!(refs′, ref) + end + end + base = assemble(n.over, TranslateContext(ctx, refs = refs′)) repl′ = Dict{SQLNode, Symbol}() - for ref in refs - if @dissect(ref, over |> NameBound()) + for ref in ctx.refs + if @dissect(ref, over |> Nested()) repl′[ref] = base.repl[over] else repl′[ref] = base.repl[ref] end end - Assemblage(base.clause, cols = base.cols, repl = repl′) + Assemblage(n.name, base.clause, cols = base.cols, repl = repl′) end -function assemble(n::DefineNode, refs, ctx) - base = assemble(n.over, ctx) - if !any(ref -> @dissect(ref, Get(name = name)) && name in keys(n.label_map), refs) - return base +function assemble(n::BindNode, ctx) + vars′ = ctx.vars + for (name, i) in n.label_map + depth = _cte_depth(ctx.vars, name) + 1 + vars′ = Base.ImmutableDict(vars′, (name, depth) => translate(n.args[i], ctx)) end + ctx′ = TranslateContext(ctx, vars = vars′) + assemble(n.over, ctx′) +end + +function assemble(n::DefineNode, ctx) + base = assemble(n.over, ctx) if !@dissect(base.clause, SELECT() || UNION()) base_alias = nothing c = base.clause else - base_alias = allocate_alias(ctx, n.over) + base_alias = allocate_alias(ctx, base) c = FROM(AS(over = complete(base), name = base_alias)) end subs = make_subs(base, base_alias) + tr_cache = Dict{Symbol, SQLClause}() + for (f, i) in n.label_map + tr_cache[f] = translate(n.args[i], ctx, subs) + end repl = Dict{SQLNode, Symbol}() trns = Pair{SQLNode, SQLClause}[] - tr_cache = Dict{Symbol, SQLClause}() - for ref in refs - if @dissect(ref, nothing |> Get(name = name)) && name in keys(n.label_map) - col = get!(tr_cache, name) do - def = n.args[n.label_map[name]] - translate(def, ctx, subs) - end - push!(trns, ref => col) + for ref in ctx.refs + if @dissect(ref, nothing |> Get(name = name)) && name in keys(tr_cache) + push!(trns, ref => tr_cache[name]) else push!(trns, ref => subs[ref]) end end repl, cols = make_repl_cols(trns) - Assemblage(c, cols = cols, repl = repl) + Assemblage(base.name, c, cols = cols, repl = repl) end -function assemble(n::FromFunctionNode, refs, ctx) +function assemble(n::FromFunctionNode, ctx) seen = Set{Symbol}() column_set = Set(n.columns) - for ref in refs + for ref in ctx.refs @dissect(ref, nothing |> Get(name = name)) && name in column_set || error() if !(name in seen) push!(seen, name) @@ -476,57 +464,58 @@ function assemble(n::FromFunctionNode, refs, ctx) cols[col] = ID(over = alias, name = col) end repl = Dict{SQLNode, Symbol}() - for ref in refs + for ref in ctx.refs if @dissect(ref, nothing |> Get(name = name)) repl[ref] = name end end - Assemblage(c, cols = cols, repl = repl) + Assemblage(label(n.over), c, cols = cols, repl = repl) end -assemble(::FromNothingNode, refs, ctx) = - assemble(nothing, refs, ctx) +function assemble(n::FromIterateNode, ctx) + cte_a = ctx.ctes[ctx.knot] + name = cte_a.a.name + alias = allocate_alias(ctx, name) + tbl = ID(cte_a.qualifiers, cte_a.name) + c = FROM(AS(over = tbl, name = alias)) + subs = make_subs(cte_a.a, alias) + trns = Pair{SQLNode, SQLClause}[] + for ref in ctx.refs + push!(trns, ref => subs[ref]) + end + repl, cols = make_repl_cols(trns) + return Assemblage(name, c, cols = cols, repl = repl) +end + +assemble(::FromNothingNode, ctx) = + assemble(nothing, ctx) function unwrap_repl(a::Assemblage) repl′ = Dict{SQLNode, Symbol}() for (ref, name) in a.repl - @dissect(ref, over |> NameBound()) || error() + @dissect(ref, over |> Nested()) || error() repl′[over] = name end - Assemblage(a.clause, cols = a.cols, repl = repl′) + Assemblage(a.name, a.clause, cols = a.cols, repl = repl′) end -function assemble(n::FromReferenceNode, refs, ctx) - cte_a = ctx.cte_map[n.over] +function assemble(n::FromTableExpressionNode, ctx) + cte_a = ctx.ctes[ctx.cte_map[(n.name, n.depth)]] alias = allocate_alias(ctx, n.name) tbl = ID(cte_a.qualifiers, cte_a.name) c = FROM(AS(over = tbl, name = alias)) subs = make_subs(unwrap_repl(cte_a.a), alias) trns = Pair{SQLNode, SQLClause}[] - for ref in refs + for ref in ctx.refs push!(trns, ref => subs[ref]) end repl, cols = make_repl_cols(trns) - return Assemblage(c, cols = cols, repl = repl) + return Assemblage(n.name, c, cols = cols, repl = repl) end -function assemble(n::FromSelfNode, refs, ctx) - cte_a = ctx.cte_map[n.over] - alias = allocate_alias(ctx, label(n.over)) - tbl = ID(cte_a.qualifiers, cte_a.name) - c = FROM(AS(over = tbl, name = alias)) - subs = make_subs(cte_a.a, alias) - trns = Pair{SQLNode, SQLClause}[] - for ref in refs - push!(trns, ref => subs[ref]) - end - repl, cols = make_repl_cols(trns) - return Assemblage(c, cols = cols, repl = repl) -end - -function assemble(n::FromTableNode, refs, ctx) +function assemble(n::FromTableNode, ctx) seen = Set{Symbol}() - for ref in refs + for ref in ctx.refs @dissect(ref, nothing |> Get(name = name)) && name in n.table.column_set || error() if !(name in seen) push!(seen, name) @@ -541,19 +530,19 @@ function assemble(n::FromTableNode, refs, ctx) cols[col] = ID(over = alias, name = col) end repl = Dict{SQLNode, Symbol}() - for ref in refs + for ref in ctx.refs if @dissect(ref, nothing |> Get(name = name)) repl[ref] = name end end - Assemblage(c, cols = cols, repl = repl) + Assemblage(n.table.name, c, cols = cols, repl = repl) end -function assemble(n::FromValuesNode, refs, ctx) +function assemble(n::FromValuesNode, ctx) columns = Symbol[fieldnames(typeof(n.columns))...] column_set = Set{Symbol}(columns) seen = Set{Symbol}() - for ref in refs + for ref in ctx.refs @dissect(ref, nothing |> Get(name = name)) && name in column_set || error() if !(name in seen) push!(seen, name) @@ -596,36 +585,36 @@ function assemble(n::FromValuesNode, refs, ctx) end end repl = Dict{SQLNode, Symbol}() - for ref in refs + for ref in ctx.refs if @dissect(ref, nothing |> Get(name = name)) repl[ref] = name end end - Assemblage(c, cols = cols, repl = repl) + Assemblage(:values, c, cols = cols, repl = repl) end -function assemble(n::GroupNode, refs, ctx) - has_aggregates = any(ref -> @dissect(ref, Agg() || Agg() |> NameBound()), refs) +function assemble(n::GroupNode, ctx) + has_aggregates = any(ref -> @dissect(ref, Agg() || Agg() |> Nested()), ctx.refs) if isempty(n.by) && !has_aggregates - return assemble(nothing, refs, ctx) + return assemble(nothing, ctx) end base = assemble(n.over, ctx) if @dissect(base.clause, tail := nothing || FROM() || JOIN() || WHERE()) base_alias = nothing else - base_alias = allocate_alias(ctx, n.over) + base_alias = allocate_alias(ctx, base) tail = FROM(AS(over = complete(base), name = base_alias)) end subs = make_subs(base, base_alias) by = SQLClause[subs[key] for key in n.by] trns = Pair{SQLNode, SQLClause}[] - for ref in refs + for ref in ctx.refs if @dissect(ref, nothing |> Get(name = name)) @assert name in keys(n.label_map) push!(trns, ref => by[n.label_map[name]]) elseif @dissect(ref, nothing |> Agg()) push!(trns, ref => translate(ref, ctx, subs)) - elseif @dissect(ref, (over := nothing |> Agg()) |> NameBound()) + elseif @dissect(ref, (over := nothing |> Agg()) |> Nested()) push!(trns, ref => translate(over, ctx, subs)) end end @@ -643,109 +632,17 @@ function assemble(n::GroupNode, refs, ctx) c = SELECT(over = tail, distinct = true, args = args) cols = OrderedDict{Symbol, SQLClause}([name => ID(name) for name in keys(cols)]) end - return Assemblage(c, cols = cols, repl = repl) + return Assemblage(base.name, c, cols = cols, repl = repl) end -assemble(n::HighlightNode, refs, ctx) = - assemble(n.over, ctx) - -function assemble(n::IntAutoDefineNode, refs, ctx) - base = assemble(n.over, ctx) - if isempty(refs) - return base - end - if !@dissect(base.clause, SELECT() || UNION()) - base_alias = nothing - c = base.clause - else - base_alias = allocate_alias(ctx, n.over) - c = FROM(AS(over = complete(base), name = base_alias)) - end - subs = make_subs(base, base_alias) - repl = Dict{SQLNode, Symbol}() - trns = Pair{SQLNode, SQLClause}[] - for ref in refs - push!(trns, ref => translate(ref, ctx, subs)) - end - repl, cols = make_repl_cols(trns) - Assemblage(c, cols = cols, repl = repl) -end - -function assemble(n::IntBindNode, refs, ctx) - vars′ = copy(ctx.vars) - for (name, i) in n.label_map - vars′[name] = translate(n.args[i], ctx) - end - ctx′ = TranslateContext(ctx, vars = vars′) - assemble(n.over, ctx′) -end - -function assemble(n::IntIterateNode, refs, ctx) - ctx′ = TranslateContext(ctx, vars = Dict{Symbol, SQLClause}()) - base = assemble(n.over, ctx′) - @assert @dissect(base.clause, FROM()) - subs = make_subs(base, nothing) - trns = Pair{SQLNode, SQLClause}[] - for ref in refs - push!(trns, ref => subs[ref]) - end - repl, cols = make_repl_cols(trns) - Assemblage(base.clause, cols = cols, repl = repl) -end - -function assemble(n::IntJoinNode, refs, ctx) - left = assemble(n.over, ctx) - if n.skip - return left - end - if @dissect(left.clause, tail := FROM() || JOIN()) - left_alias = nothing - else - left_alias = allocate_alias(ctx, n.over) - tail = FROM(AS(over = complete(left), name = left_alias)) - end - lateral = !isempty(n.lateral) - subs = make_subs(left, left_alias) - if lateral - right = assemble(n.joinee, TranslateContext(ctx, subs = subs)) - else - right = assemble(n.joinee, ctx) - end - if @dissect(right.clause, (joinee := (nothing || nothing |> ID()) |> ID() |> AS(name = right_alias, columns = nothing)) |> FROM()) || - @dissect(right.clause, (joinee := nothing |> ID(name = right_alias)) |> FROM()) || - @dissect(right.clause, (joinee := FUN() |> AS(name = right_alias)) |> FROM()) - for (ref, name) in right.repl - subs[ref] = right.cols[name] - end - if ctx.dialect.has_implicit_lateral - lateral = false - end - else - right_alias = allocate_alias(ctx, n.joinee) - joinee = AS(over = complete(right), name = right_alias) - right_cache = Dict{Symbol, SQLClause}() - for (ref, name) in right.repl - subs[ref] = get(right_cache, name) do - ID(over = right_alias, name = name) - end - end - end - on = translate(n.on, ctx, subs) - c = JOIN(over = tail, joinee = joinee, on = on, left = n.left, right = n.right, lateral = lateral) - trns = Pair{SQLNode, SQLClause}[] - for ref in refs - push!(trns, ref => subs[ref]) - end - repl, cols = make_repl_cols(trns) - Assemblage(c, cols = cols, repl = repl) -end - -function assemble(n::KnotNode, refs, ctx) +function assemble(n::IterateNode, ctx) + ctx′ = TranslateContext(ctx, vars = Base.ImmutableDict{Tuple{Symbol, Int}, SQLClause}()) left = assemble(n.over, ctx) repl = Dict{SQLNode, Symbol}() dups = Dict{SQLNode, SQLNode}() seen = Dict{Symbol, SQLNode}() - for ref in refs + for ref in ctx.refs + !in(ref, keys(repl)) || continue name = left.repl[ref] repl[ref] = name if name in keys(seen) @@ -754,13 +651,17 @@ function assemble(n::KnotNode, refs, ctx) seen[name] = ref end end - temp_union = Assemblage(left.clause, cols = left.cols, repl = repl) - union_alias = allocate_alias(ctx, n.name) - ctx.cte_map[SQLNode(n.box)] = CTEAssemblage(temp_union, name = union_alias) + temp_union = Assemblage(label(n.iterator), left.clause, cols = left.cols, repl = repl) + union_alias = allocate_alias(ctx, temp_union) + cte = CTEAssemblage(temp_union, name = union_alias) + push!(ctx.ctes, cte) + knot = lastindex(ctx.ctes) + ctx = TranslateContext(ctx, knot = knot) right = assemble(n.iterator, ctx) urefs = SQLNode[] - for ref in refs + for ref in ctx.refs !(ref in keys(dups)) || continue + dups[ref] = ref push!(urefs, ref) end cs = SQLClause[] @@ -772,7 +673,7 @@ function assemble(n::KnotNode, refs, ctx) alias = nothing tail = a.clause else - alias = allocate_alias(ctx, arg) + alias = allocate_alias(ctx, a) tail = FROM(AS(over = complete(a), name = alias)) end subs = make_subs(a, alias) @@ -790,80 +691,111 @@ function assemble(n::KnotNode, refs, ctx) name = left.repl[ref] cols[name] = ID(name) end - union = Assemblage(union_clause, cols = cols, repl = repl) - ctx.cte_map[SQLNode(n.box)] = CTEAssemblage(union, name = union_alias) + union = Assemblage(right.name, union_clause, cols = cols, repl = repl) + ctx.ctes[knot] = CTEAssemblage(union, name = union_alias) ctx.recursive[] = true - alias = allocate_alias(ctx, n.name) + alias = allocate_alias(ctx, union) c = FROM(AS(over = ID(union_alias), name = alias)) subs = make_subs(union, alias) trns = Pair{SQLNode, SQLClause}[] - for ref in refs + for ref in ctx.refs push!(trns, ref => subs[ref]) end repl, cols = make_repl_cols(trns) - return Assemblage(c, cols = cols, repl = repl) + return Assemblage(union.name, c, cols = cols, repl = repl) end -function assemble(n::LimitNode, refs, ctx) +function assemble(n::LimitNode, ctx) base = assemble(n.over, ctx) - if n.offset === nothing && n.limit === nothing - return base - end if @dissect(base.clause, tail := nothing || FROM() || JOIN() || WHERE() || GROUP() || HAVING() || ORDER()) base_alias = nothing else - base_alias = allocate_alias(ctx, n.over) + base_alias = allocate_alias(ctx, base) tail = FROM(AS(over = complete(base), name = base_alias)) end c = LIMIT(over = tail, offset = n.offset, limit = n.limit) subs = make_subs(base, base_alias) trns = Pair{SQLNode, SQLClause}[] - for ref in refs + for ref in ctx.refs push!(trns, ref => subs[ref]) end repl, cols = make_repl_cols(trns) - Assemblage(c, cols = cols, repl = repl) + Assemblage(base.name, c, cols = cols, repl = repl) end -function assemble(n::OrderNode, refs, ctx) - base = assemble(n.over, ctx) - if isempty(n.by) - return base +function assemble(n::LinkedNode, ctx) + a = assemble(n.over, TranslateContext(ctx, refs = n.refs)) + n.n_ext_refs < length(n.refs) || return a + dups = Set{Symbol}() + for (k, ref) in enumerate(n.refs) + col = a.repl[ref] + if col in dups + if k > n.n_ext_refs + alias = allocate_alias(ctx, a) + c = FROM(AS(over = complete(a), name = alias)) + subs = make_subs(a, alias) + trns = Pair{SQLNode, SQLClause}[] + for ref in n.refs + push!(trns, ref => subs[ref]) + end + repl, cols = make_repl_cols(trns) + return Assemblage(a.name, c, cols = cols, repl = repl) + end + elseif !@dissect(a.cols[col], (nothing |> ID() || nothing |> ID() |> ID() || VAR() || LIT())) + push!(dups, col) + end end + a +end + +function assemble(n::OrderNode, ctx) + base = assemble(n.over, ctx) + @assert !isempty(n.by) if @dissect(base.clause, tail := nothing || FROM() || JOIN() || WHERE() || GROUP() || HAVING()) base_alias = nothing else - base_alias = allocate_alias(ctx, n.over) + base_alias = allocate_alias(ctx, base) tail = FROM(AS(over = complete(base), name = base_alias)) end subs = make_subs(base, base_alias) by = translate(n.by, ctx, subs) c = ORDER(over = tail, by = by) trns = Pair{SQLNode, SQLClause}[] - for ref in refs + for ref in ctx.refs push!(trns, ref => subs[ref]) end repl, cols = make_repl_cols(trns) - Assemblage(c, cols = cols, repl = repl) + Assemblage(base.name, c, cols = cols, repl = repl) end -function assemble(n::PartitionNode, refs, ctx) +function assemble(n::PaddingNode, ctx) base = assemble(n.over, ctx) - has_aggregates = false - for ref in refs - if (@dissect(ref, nothing |> Agg() |> NameBound(name = name)) && name === n.name) || - (@dissect(ref, nothing |> Agg()) && n.name === nothing) - has_aggregates = true - break - end - end - if !has_aggregates + if isempty(ctx.refs) return base end + if !@dissect(base.clause, SELECT() || UNION()) + base_alias = nothing + c = base.clause + else + base_alias = allocate_alias(ctx, base) + c = FROM(AS(over = complete(base), name = base_alias)) + end + subs = make_subs(base, base_alias) + repl = Dict{SQLNode, Symbol}() + trns = Pair{SQLNode, SQLClause}[] + for ref in ctx.refs + push!(trns, ref => translate(ref, ctx, subs)) + end + repl, cols = make_repl_cols(trns) + Assemblage(base.name, c, cols = cols, repl = repl) +end + +function assemble(n::PartitionNode, ctx) + base = assemble(n.over, ctx) if @dissect(base.clause, tail := nothing || FROM() || JOIN() || WHERE() || GROUP() || HAVING()) base_alias = nothing else - base_alias = allocate_alias(ctx, n.over) + base_alias = allocate_alias(ctx, base) tail = FROM(AS(over = complete(base), name = base_alias)) end c = WINDOW(over = tail, args = []) @@ -873,26 +805,74 @@ function assemble(n::PartitionNode, refs, ctx) order_by = translate(n.order_by, ctx′) partition = PARTITION(by = by, order_by = order_by, frame = n.frame) trns = Pair{SQLNode, SQLClause}[] - for ref in refs + has_aggregates = false + for ref in ctx.refs if @dissect(ref, nothing |> Agg()) && n.name === nothing push!(trns, ref => partition |> translate(ref, ctx′)) - elseif @dissect(ref, (over := nothing |> Agg()) |> NameBound(name = name)) && name === n.name + has_aggregates = true + elseif @dissect(ref, (over := nothing |> Agg()) |> Nested(name = name)) && name === n.name push!(trns, ref => partition |> translate(over, ctx′)) + has_aggregates = true else push!(trns, ref => subs[ref]) end end + @assert has_aggregates repl, cols = make_repl_cols(trns) - Assemblage(c, cols = cols, repl = repl) + Assemblage(base.name, c, cols = cols, repl = repl) end -function assemble(n::SelectNode, refs, ctx) +function assemble(n::RoutedJoinNode, ctx) + left = assemble(n.over, ctx) + if @dissect(left.clause, tail := FROM() || JOIN()) + left_alias = nothing + else + left_alias = allocate_alias(ctx, left) + tail = FROM(AS(over = complete(left), name = left_alias)) + end + lateral = n.lateral + subs = make_subs(left, left_alias) + if lateral + right = assemble(n.joinee, TranslateContext(ctx, subs = subs)) + else + right = assemble(n.joinee, ctx) + end + if @dissect(right.clause, (joinee := (nothing || nothing |> ID()) |> ID() |> AS(name = right_alias, columns = nothing)) |> FROM()) || + @dissect(right.clause, (joinee := nothing |> ID(name = right_alias)) |> FROM()) || + @dissect(right.clause, (joinee := FUN() |> AS(name = right_alias)) |> FROM()) + for (ref, name) in right.repl + subs[ref] = right.cols[name] + end + if ctx.dialect.has_implicit_lateral + lateral = false + end + else + right_alias = allocate_alias(ctx, right) + joinee = AS(over = complete(right), name = right_alias) + right_cache = Dict{Symbol, SQLClause}() + for (ref, name) in right.repl + subs[ref] = get(right_cache, name) do + ID(over = right_alias, name = name) + end + end + end + on = translate(n.on, ctx, subs) + c = JOIN(over = tail, joinee = joinee, on = on, left = n.left, right = n.right, lateral = lateral) + trns = Pair{SQLNode, SQLClause}[] + for ref in ctx.refs + push!(trns, ref => subs[ref]) + end + repl, cols = make_repl_cols(trns) + Assemblage(left.name, c, cols = cols, repl = repl) +end + +function assemble(n::SelectNode, ctx) base = assemble(n.over, ctx) if !@dissect(base.clause, SELECT() || UNION()) base_alias = nothing tail = base.clause else - base_alias = allocate_alias(ctx, n.over) + base_alias = allocate_alias(ctx, base) tail = FROM(AS(over = complete(base), name = base_alias)) end subs = make_subs(base, base_alias) @@ -904,11 +884,11 @@ function assemble(n::SelectNode, refs, ctx) c = SELECT(over = tail, args = complete(cols)) cols = OrderedDict{Symbol, SQLClause}([name => ID(name) for name in keys(cols)]) repl = Dict{SQLNode, Symbol}() - for ref in refs + for ref in ctx.refs @dissect(ref, nothing |> Get(name = name)) || error() repl[ref] = name end - Assemblage(c, cols = cols, repl = repl) + Assemblage(base.name, c, cols = cols, repl = repl) end function merge_conditions(c1, c2) @@ -925,7 +905,7 @@ function merge_conditions(c1, c2) end end -function assemble(n::WhereNode, refs, ctx) +function assemble(n::WhereNode, ctx) base = assemble(n.over, ctx) if @dissect(base.clause, nothing || FROM() || JOIN() || WHERE() || HAVING()) || @dissect(base.clause, GROUP(by = by)) && !isempty(by) @@ -946,7 +926,7 @@ function assemble(n::WhereNode, refs, ctx) c = WHERE(over = base.clause, condition = condition) end else - base_alias = allocate_alias(ctx, n.over) + base_alias = allocate_alias(ctx, base) tail = FROM(AS(over = complete(base), name = base_alias)) subs = make_subs(base, base_alias) condition = translate(n.condition, ctx, subs) @@ -956,28 +936,34 @@ function assemble(n::WhereNode, refs, ctx) c = WHERE(over = tail, condition = condition) end trns = Pair{SQLNode, SQLClause}[] - for ref in refs + for ref in ctx.refs push!(trns, ref => subs[ref]) end repl, cols = make_repl_cols(trns) - return Assemblage(c, cols = cols, repl = repl) + return Assemblage(base.name, c, cols = cols, repl = repl) end -function assemble(n::WithNode, refs, ctx) - ctx′ = TranslateContext(ctx, vars = Dict{Symbol, SQLClause}()) - for arg in n.args - a = assemble(arg, ctx) - alias = allocate_alias(ctx, arg) - ctx.cte_map[arg] = CTEAssemblage(a, name = alias, materialized = n.materialized) +function assemble(n::WithNode, ctx) + cte_map′ = ctx.cte_map + # FIXME: variable pushed into a CTE + ctx′ = TranslateContext(ctx, vars = Base.ImmutableDict{Tuple{Symbol, Int}, SQLClause}()) + for (name, i) in n.label_map + a = assemble(n.args[i], ctx) + alias = allocate_alias(ctx, a) + cte = CTEAssemblage(a, name = alias, materialized = n.materialized) + push!(ctx.ctes, cte) + depth = _cte_depth(ctx.cte_map, name) + 1 + cte_map′ = Base.ImmutableDict(cte_map′, (name, depth) => lastindex(ctx.ctes)) end - assemble(n.over, ctx) + assemble(n.over, TranslateContext(ctx, cte_map = cte_map′)) end -function assemble(n::WithExternalNode, refs, ctx) - ctx′ = TranslateContext(ctx, vars = Dict{Symbol, SQLClause}()) - for arg in n.args - a = assemble(arg, ctx) - table_name = (arg[]::BoxNode).type.name +function assemble(n::WithExternalNode, ctx) + cte_map′ = ctx.cte_map + ctx′ = TranslateContext(ctx, vars = Base.ImmutableDict{Tuple{Symbol, Int}, SQLClause}()) + for (name, i) in n.label_map + a = assemble(n.args[i], ctx) + table_name = a.name table_columns = Symbol[column_name for column_name in keys(a.cols)] if isempty(table_columns) push!(table_columns, :_) @@ -986,7 +972,10 @@ function assemble(n::WithExternalNode, refs, ctx) if n.handler !== nothing n.handler(table => complete(a)) end - ctx.cte_map[arg] = CTEAssemblage(a, name = table.name, qualifiers = table.qualifiers, external = true) + cte = CTEAssemblage(a, name = table.name, qualifiers = table.qualifiers, external = true) + push!(ctx.ctes, cte) + depth = _cte_depth(ctx.cte_map, name) + 1 + cte_map′ = Base.ImmutableDict(cte_map′, (name, depth) => lastindex(ctx.ctes)) end - assemble(n.over, ctx) + assemble(n.over, TranslateContext(ctx, cte_map = cte_map′)) end diff --git a/src/types.jl b/src/types.jl index 7ac0641e..856821ed 100644 --- a/src/types.jl +++ b/src/types.jl @@ -18,29 +18,22 @@ end PrettyPrinting.quoteof(::ScalarType) = Expr(:call, nameof(ScalarType)) -struct AmbiguousType <: AbstractSQLType -end - -PrettyPrinting.quoteof(::AmbiguousType) = - Expr(:call, nameof(AmbiguousType)) - struct RowType <: AbstractSQLType - fields::OrderedDict{Symbol, Union{ScalarType, AmbiguousType, RowType}} - group::Union{EmptyType, AmbiguousType, RowType} + fields::OrderedDict{Symbol, Union{ScalarType, RowType}} + group::Union{EmptyType, RowType} RowType(fields, group = EmptyType()) = new(fields, group) end -const FieldTypeMap = OrderedDict{Symbol, Union{ScalarType, AmbiguousType, RowType}} -const GroupType = Union{EmptyType, AmbiguousType, RowType} -const HandleTypeMap = Dict{Int, Union{AmbiguousType, RowType}} +const FieldTypeMap = OrderedDict{Symbol, Union{ScalarType, RowType}} +const GroupType = Union{EmptyType, RowType} RowType() = RowType(FieldTypeMap()) RowType(fields::Pair{Symbol, <:AbstractSQLType}...; group = EmptyType()) = - RowType(FieldTypeMap(fields...), group) + RowType(FieldTypeMap(fields), group) function PrettyPrinting.quoteof(t::RowType) ex = Expr(:call, nameof(RowType)) @@ -53,52 +46,7 @@ function PrettyPrinting.quoteof(t::RowType) ex end -struct BoxType <: AbstractSQLType - name::Symbol - row::RowType - handle_map::HandleTypeMap -end - -BoxType(name::Symbol, row::RowType) = - BoxType(name, row, HandleTypeMap()) - -function BoxType(name::Symbol, fields::Pair{<:Union{Symbol, Int}, <:AbstractSQLType}...; group = EmptyType()) - field_map = FieldTypeMap() - handle_map = HandleTypeMap() - for (key, val) in fields - if key isa Symbol - field_map[key] = val - else - handle_map[key] = val - end - end - BoxType(name, RowType(field_map, group), handle_map) -end - -function PrettyPrinting.quoteof(t::BoxType) - ex = Expr(:call, nameof(BoxType), QuoteNode(t.name)) - for (f, ft) in t.row.fields - push!(ex.args, Expr(:call, :(=>), QuoteNode(f), quoteof(ft))) - end - if !(t.row.group isa EmptyType) - push!(ex.args, Expr(:kw, :group, quoteof(t.row.group))) - end - for (h, ht) in sort!(collect(t.handle_map)) - push!(ex.args, Expr(:call, :(=>), h, quoteof(ht))) - end - ex -end - -const EMPTY_BOX = BoxType(:_, RowType(), HandleTypeMap()) - -function add_handle(t::BoxType, handle::Int) - if handle != 0 - handle_map = copy(t.handle_map) - handle_map[handle] = t.row - t = BoxType(t.name, t.row, handle_map) - end - t -end +const EMPTY_ROW = RowType() # Type of `Append` (UNION ALL). @@ -109,9 +57,6 @@ Base.intersect(::AbstractSQLType, ::AbstractSQLType) = Base.intersect(::ScalarType, ::ScalarType) = ScalarType() -Base.intersect(::AmbiguousType, ::AmbiguousType) = - AmbiguousType() - function Base.intersect(t1::RowType, t2::RowType) if t1 === t2 return t1 @@ -129,27 +74,16 @@ function Base.intersect(t1::RowType, t2::RowType) RowType(fields, group) end -function Base.intersect(t1::BoxType, t2::BoxType) - if t1 === t2 - return t1 - end - handle_map = HandleTypeMap() - for h in keys(t1.handle_map) - if h in keys(t2.handle_map) - t = intersect(t1.handle_map[h], t2.handle_map[h]) - if !(t isa EmptyType) - handle_map[h] = t - end - end - end - name = t1.name == t2.name ? t2.name : :union - BoxType(name, intersect(t1.row, t2.row), handle_map) -end + +# Type order. Base.issubset(::AbstractSQLType, ::AbstractSQLType) = false -Base.issubset(::T, ::T) where {T <: AbstractSQLType} = +Base.issubset(::EmptyType, ::AbstractSQLType) = + true + +Base.issubset(::ScalarType, ::ScalarType) = true function Base.issubset(t1::RowType, t2::RowType) @@ -161,82 +95,8 @@ function Base.issubset(t1::RowType, t2::RowType) return false end end - return true -end - -function Base.issubset(t1::BoxType, t2::BoxType) - if t1 === t2 - return true - end - t1.name == t2.name || return false - issubset(t1.row, t2.row) || return false - for h in keys(t1.handle_map) - if !(h in keys(t2.handle_map) && issubset(t1.handle_map[h], t2.handle_map[h])) - return false - end + if !issubset(t1.group, t2.group) + return false end return true end - -# Type of `Join`. - -Base.union(::AbstractSQLType, ::AbstractSQLType) = - AmbiguousType() - -Base.union(::EmptyType, ::EmptyType) = - EmptyType() - -Base.union(::EmptyType, t::AbstractSQLType) = - t - -Base.union(t::AbstractSQLType, ::EmptyType) = - t - -Base.union(::ScalarType, ::ScalarType) = - ScalarType() - -function Base.union(t1::RowType, t2::RowType) - fields = FieldTypeMap() - for (f, t) in t1.fields - if f in keys(t2.fields) - t′ = t2.fields[f] - if t isa RowType && t′ isa RowType - t = union(t, t′) - else - t = AmbiguousType() - end - end - fields[f] = t - end - for (f, t) in t2.fields - if !(f in keys(t1.fields)) - fields[f] = t - end - end - if t1.group isa EmptyType - group = t2.group - elseif t2.group isa EmptyType - group = t1.group - else - group = AmbiguousType() - end - RowType(fields, group) -end - -function Base.union(t1::BoxType, t2::BoxType) - handle_map = HandleTypeMap() - for l in keys(t1.handle_map) - if haskey(t2.handle_map, l) - handle_map[l] = AmbiguousType() - else - handle_map[l] = t1.handle_map[l] - end - end - for l in keys(t2.handle_map) - if !haskey(t1.handle_map, l) - handle_map[l] = t2.handle_map[l] - end - end - BoxType(t1.name, union(t1.row, t2.row), handle_map) -end -