From e73490db40e31e590c25ec86999a3fff158ddeca Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Thu, 7 Mar 2024 02:16:21 -0600 Subject: [PATCH 1/4] Support for grouping sets --- docs/src/test/clauses.md | 46 ++++++++++++++++++++++++++++++++++ docs/src/test/nodes.md | 51 ++++++++++++++++++++++++++++++++++++++ src/clauses/group.jl | 46 ++++++++++++++++++++++++++++------ src/link.jl | 8 ++++-- src/nodes/group.jl | 53 +++++++++++++++++++++++++++++++--------- src/resolve.jl | 2 +- src/serialize.jl | 31 ++++++++++++++++++++++- src/translate.jl | 6 ++--- 8 files changed, 217 insertions(+), 26 deletions(-) diff --git a/docs/src/test/clauses.md b/docs/src/test/clauses.md index 590ce1f1..42fad55c 100644 --- a/docs/src/test/clauses.md +++ b/docs/src/test/clauses.md @@ -1131,6 +1131,52 @@ rendered. FROM "person" =# +`GROUP` can accept the grouping mode or a vector of grouping sets. + + c = FROM(:person) |> GROUP(:year_of_birth, grouping_sets = :ROLLUP) + #-> (…) |> GROUP(…, grouping_sets = :ROLLUP) + + print(render(c |> SELECT(:year_of_birth, AGG(:count)))) + #=> + SELECT + "year_of_birth", + count(*) + FROM "person" + GROUP BY ROLLUP("year_of_birth") + =# + + c = FROM(:person) |> GROUP(:year_of_birth, grouping_sets = :CUBE) + #-> (…) |> GROUP(…, grouping_sets = :CUBE) + + print(render(c |> SELECT(:year_of_birth, AGG(:count)))) + #=> + SELECT + "year_of_birth", + count(*) + FROM "person" + GROUP BY CUBE("year_of_birth") + =# + + c = FROM(:person) |> GROUP(:year_of_birth, grouping_sets = [[1], Int[]]) + #-> (…) |> GROUP(…, grouping_sets = [[1], Int64[]]) + + print(render(c |> SELECT(:year_of_birth, AGG(:count)))) + #=> + SELECT + "year_of_birth", + count(*) + FROM "person" + GROUP BY GROUPING SETS(("year_of_birth"), ()) + =# + +`GROUP` raises an error when the vector of grouping sets is out of bounds. 
+ + FROM(:person) |> GROUP(:year_of_birth, grouping_sets = [[1, 2], [1], Int[]]) + #=> + ERROR: DomainError with [[1, 2], [1], Int64[]]: + grouping_sets is out of bounds + =# + ## `HAVING` Clause diff --git a/docs/src/test/nodes.md b/docs/src/test/nodes.md index 4639d3b9..3a539884 100644 --- a/docs/src/test/nodes.md +++ b/docs/src/test/nodes.md @@ -2502,6 +2502,57 @@ downstream. ) AS "person_2" =# +`Group` allows specifying the grouping sets. + + q = From(person) |> + Group(Get.year_of_birth, grouping_sets = :cube) + Define(Agg.count()) + + display(q) + #=> + let person = SQLTable(:person, …), + q1 = From(person), + q2 = q1 |> Group(Get.year_of_birth, grouping_sets = :CUBE) + q2 + end + =# + + print(render(q)) + #=> + SELECT "person_1"."year_of_birth" + FROM "person" AS "person_1" + GROUP BY CUBE("person_1"."year_of_birth") + =# + + q = From(person) |> + Group(Get.year_of_birth, grouping_sets = [[1], Int[]]) + Define(Agg.count()) + + display(q) + #=> + let person = SQLTable(:person, …), + q1 = From(person), + q2 = q1 |> Group(Get.year_of_birth, grouping_sets = [[1], []]) + q2 + end + =# + + print(render(q)) + #=> + SELECT "person_1"."year_of_birth" + FROM "person" AS "person_1" + GROUP BY GROUPING SETS(("person_1"."year_of_birth"), ()) + =# + +`Group` complains about out-of-bound grouping sets. + + From(person) |> + Group(Get.year_of_birth, grouping_sets = [[1, 2], [1], Int[]]) + #=> + ERROR: DomainError with [[1, 2], [1], Int64[]]: + grouping_sets is out of bounds + =# + `Group` allows specifying the name of a group field. q = From(person) |> diff --git a/src/clauses/group.jl b/src/clauses/group.jl index f1d0dd23..ef243909 100644 --- a/src/clauses/group.jl +++ b/src/clauses/group.jl @@ -1,21 +1,47 @@ # GROUP BY clause. +module GROUPING_MODE + +@enum GroupingMode::UInt8 begin + ROLLUP + CUBE +end + +Base.convert(::Type{GroupingMode}, s::Symbol) = + s in (:rollup, :ROLLUP) ? + ROLLUP : + s in (:cube, :CUBE) ? 
+ CUBE : + throw(DomainError(QuoteNode(s), "expected :rollup or :cube")) + +end + +import .GROUPING_MODE.GroupingMode + mutable struct GroupClause <: AbstractSQLClause over::Union{SQLClause, Nothing} by::Vector{SQLClause} + grouping_sets::Union{Vector{Vector{Int}}, GroupingMode, Nothing} - GroupClause(; + function GroupClause(; over = nothing, - by = SQLClause[]) = - new(over, by) + by = SQLClause[], + grouping_sets = nothing) + c = new(over, by, grouping_sets isa Symbol ? convert(GroupingMode, grouping_sets) : grouping_sets) + gs = c.grouping_sets + if gs isa Vector{Vector{Int}} && !checkbounds(Bool, c.by, gs) + throw(DomainError(gs, "grouping_sets is out of bounds")) + end + c + end end -GroupClause(by...; over = nothing) = - GroupClause(over = over, by = SQLClause[by...]) +GroupClause(by...; over = nothing, grouping_sets = nothing) = + GroupClause(over = over, by = SQLClause[by...], grouping_sets = grouping_sets) """ - GROUP(; over = nothing, by = []) - GROUP(by...; over = nothing) + GROUP(; over = nothing, by = [], grouping_sets = nothing) + GROUP(by...; over = nothing, grouping_sets = nothing) A `GROUP BY` clause. @@ -43,6 +69,10 @@ dissect(scr::Symbol, ::typeof(GROUP), pats::Vector{Any}) = function PrettyPrinting.quoteof(c::GroupClause, ctx::QuoteContext) ex = Expr(:call, nameof(GROUP)) append!(ex.args, quoteof(c.by, ctx)) + gs = c.grouping_sets + if gs !== nothing + push!(ex.args, Expr(:kw, :grouping_sets, gs isa GroupingMode ? 
QuoteNode(Symbol(gs)) : gs)) + end if c.over !== nothing ex = Expr(:call, :|>, quoteof(c.over, ctx), ex) end @@ -50,5 +80,5 @@ function PrettyPrinting.quoteof(c::GroupClause, ctx::QuoteContext) end rebase(c::GroupClause, c′) = - GroupClause(over = rebase(c.over, c′), by = c.by) + GroupClause(over = rebase(c.over, c′), by = c.by, grouping_sets = c.grouping_sets) diff --git a/src/link.jl b/src/link.jl index 7b74ddba..9a28585f 100644 --- a/src/link.jl +++ b/src/link.jl @@ -118,7 +118,7 @@ end function dismantle(n::GroupNode, ctx) over′ = dismantle(n.over, ctx) by′ = dismantle_scalar(n.by, ctx) - Group(over = over′, by = by′, name = n.name, label_map = n.label_map) + Group(over = over′, by = by′, grouping_sets = n.grouping_sets, name = n.name, label_map = n.label_map) end function dismantle(n::IterateNode, ctx) @@ -290,6 +290,10 @@ function link(n::GroupNode, ctx) # To avoid duplicate SQL, they must be evaluated in a nested subquery. refs = SQLNode[] append!(refs, n.by) + if n.grouping_sets !== nothing + # Force evaluation in a nested subquery. + append!(refs, n.by) + end # Ignore `SELECT DISTINCT` case. 
if has_aggregates ctx′ = LinkContext(ctx, refs = refs) @@ -311,7 +315,7 @@ function link(n::GroupNode, ctx) over = Padding(over = over) end over′ = Linked(refs, 0, over = link(over, ctx, refs)) - Group(over = over′, by = n.by, name = n.name, label_map = n.label_map) + Group(over = over′, by = n.by, grouping_sets = n.grouping_sets, name = n.name, label_map = n.label_map) end function link(n::IterateNode, ctx) diff --git a/src/nodes/group.jl b/src/nodes/group.jl index cac773f8..618f9df9 100644 --- a/src/nodes/group.jl +++ b/src/nodes/group.jl @@ -3,37 +3,47 @@ mutable struct GroupNode <: TabularNode over::Union{SQLNode, Nothing} by::Vector{SQLNode} + grouping_sets::Union{Vector{Vector{Int}}, GroupingMode, Nothing} name::Union{Symbol, Nothing} label_map::OrderedDict{Symbol, Int} function GroupNode(; over = nothing, by = SQLNode[], + grouping_sets = nothing, name::Union{Symbol, AbstractString, Nothing} = nothing, label_map = nothing) - if label_map !== nothing - new(over, by, name !== nothing ? Symbol(name) : nothing, label_map) - else - n = new(over, by, name !== nothing ? Symbol(name) : nothing, OrderedDict{Symbol, Int}()) + n = new( + over, + by, + grouping_sets isa Symbol ? convert(GroupingMode, grouping_sets) : grouping_sets, + name !== nothing ? Symbol(name) : nothing, + label_map !== nothing ? 
label_map : OrderedDict{Symbol, Int}()) + gs = n.grouping_sets + if gs isa Vector{Vector{Int}} && !checkbounds(Bool, n.by, gs) + throw(DomainError(gs, "grouping_sets is out of bounds")) + end + if label_map === nothing populate_label_map!(n, n.by, n.label_map, n.name) - n end + n end end -GroupNode(by...; over = nothing, name = nothing) = - GroupNode(over = over, by = SQLNode[by...], name = name) +GroupNode(by...; over = nothing, grouping_sets = nothing, name = nothing) = + GroupNode(over = over, by = SQLNode[by...], grouping_sets = grouping_sets, name = name) """ - Group(; over; by = [], name = nothing) - Group(by...; over, name = nothing) + Group(; over, by = [], grouping_sets = grouping_sets, name = nothing) + Group(by...; over, grouping_sets = grouping_sets, name = nothing) The `Group` node summarizes the input dataset. Specifically, `Group` outputs all unique values of the given grouping key. This key partitions the input rows into disjoint groups that are summarized -by aggregate functions [`Agg`](@ref) applied to the output of `Group`. An -optional parameter `name` specifies the field to hold the group. +by aggregate functions [`Agg`](@ref) applied to the output of `Group`. The +parameter `grouping_sets` customizes the grouping sets. An optional parameter +`name` specifies the field to hold the group. 
The `Group` node is translated to a SQL query with a `GROUP BY` clause: ```sql @@ -92,6 +102,23 @@ FROM "person" AS "person_1" GROUP BY "person_1"."year_of_birth" ``` +*Number of patients per year of birth and the total number of patients.* + +```jldoctest +julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth]); + +julia> q = From(:person) |> + Group(Get.year_of_birth, grouping_sets = :cube) |> + Select(Get.year_of_birth, Agg.count()); + +julia> print(render(q, tables = [person])) +SELECT + "person_1"."year_of_birth", + count(*) AS "count" +FROM "person" AS "person_1" +GROUP BY CUBE("person_1"."year_of_birth") +``` + *Distinct states across all available locations.* ```jldoctest @@ -115,6 +142,10 @@ dissect(scr::Symbol, ::typeof(Group), pats::Vector{Any}) = function PrettyPrinting.quoteof(n::GroupNode, ctx::QuoteContext) ex = Expr(:call, nameof(Group), quoteof(n.by, ctx)...) + gs = n.grouping_sets + if gs !== nothing + push!(ex.args, Expr(:kw, :grouping_sets, gs isa GroupingMode ? 
QuoteNode(Symbol(gs)) : gs)) + end if n.name !== nothing push!(ex.args, Expr(:kw, :name, QuoteNode(n.name))) end diff --git a/src/resolve.jl b/src/resolve.jl index 32dafdf7..04259c44 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -324,7 +324,7 @@ function resolve(n::GroupNode, ctx) fields[n.name] = RowType(FieldTypeMap(), group) group = EmptyType() end - n′ = Group(over = over′, by = by′, label_map = n.label_map) + n′ = Group(over = over′, by = by′, grouping_sets = n.grouping_sets, label_map = n.label_map) Resolved(RowType(fields, group), over = n′) end diff --git a/src/serialize.jl b/src/serialize.jl index 6c60e3b8..1469c9f2 100644 --- a/src/serialize.jl +++ b/src/serialize.jl @@ -538,7 +538,36 @@ function serialize!(c::GroupClause, ctx) !isempty(c.by) || return newline(ctx) print(ctx, "GROUP BY") - serialize_lines!(c.by, ctx) + grouping_sets = c.grouping_sets + if grouping_sets !== nothing + if grouping_sets isa GroupingMode + if grouping_sets == GROUPING_MODE.ROLLUP + print(ctx, " ROLLUP(") + elseif grouping_sets == GROUPING_MODE.CUBE + print(ctx, " CUBE(") + else + throw(DomainError(grouping_sets)) + end + serialize!(c.by, ctx) + print(ctx, ')') + else + print(ctx, " GROUPING SETS(") + first = true + for set in grouping_sets + if !first + print(ctx, ", ") + else + first = false + end + print(ctx, '(') + serialize!(c.by[set], ctx) + print(ctx, ')') + end + print(ctx, ')') + end + else + serialize_lines!(c.by, ctx) + end end function serialize!(c::HavingClause, ctx) diff --git a/src/translate.jl b/src/translate.jl index 462e25bd..11ddae39 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -618,15 +618,15 @@ function assemble(n::GroupNode, ctx) push!(trns, ref => translate(over, ctx, subs)) end end - if !has_aggregates + if !has_aggregates && n.grouping_sets === nothing for name in keys(n.label_map) push!(trns, Get(name = name) => by[n.label_map[name]]) end end repl, cols = make_repl_cols(trns) @assert !isempty(cols) - if has_aggregates - c = GROUP(over 
= tail, by = by) + if has_aggregates || n.grouping_sets !== nothing + c = GROUP(over = tail, by = by, grouping_sets = n.grouping_sets) else args = complete(cols) c = SELECT(over = tail, distinct = true, args = args) From 109c22d68db1a76ef0f01e2d65640ec86a52e1a6 Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Fri, 8 Mar 2024 06:56:10 -0600 Subject: [PATCH 2/4] rename parameter grouping_sets -> sets --- docs/src/test/clauses.md | 16 ++++++++-------- docs/src/test/nodes.md | 12 ++++++------ src/clauses/group.jl | 28 ++++++++++++++-------------- src/link.jl | 6 +++--- src/nodes/group.jl | 30 +++++++++++++++--------------- src/resolve.jl | 2 +- src/serialize.jl | 14 +++++++------- src/translate.jl | 6 +++--- 8 files changed, 57 insertions(+), 57 deletions(-) diff --git a/docs/src/test/clauses.md b/docs/src/test/clauses.md index 42fad55c..e147c6dd 100644 --- a/docs/src/test/clauses.md +++ b/docs/src/test/clauses.md @@ -1133,8 +1133,8 @@ rendered. `GROUP` can accept the grouping mode or a vector of grouping sets. - c = FROM(:person) |> GROUP(:year_of_birth, grouping_sets = :ROLLUP) - #-> (…) |> GROUP(…, grouping_sets = :ROLLUP) + c = FROM(:person) |> GROUP(:year_of_birth, sets = :ROLLUP) + #-> (…) |> GROUP(…, sets = :ROLLUP) print(render(c |> SELECT(:year_of_birth, AGG(:count)))) #=> @@ -1145,8 +1145,8 @@ rendered. GROUP BY ROLLUP("year_of_birth") =# - c = FROM(:person) |> GROUP(:year_of_birth, grouping_sets = :CUBE) - #-> (…) |> GROUP(…, grouping_sets = :CUBE) + c = FROM(:person) |> GROUP(:year_of_birth, sets = :CUBE) + #-> (…) |> GROUP(…, sets = :CUBE) print(render(c |> SELECT(:year_of_birth, AGG(:count)))) #=> @@ -1157,8 +1157,8 @@ rendered. 
GROUP BY CUBE("year_of_birth") =# - c = FROM(:person) |> GROUP(:year_of_birth, grouping_sets = [[1], Int[]]) - #-> (…) |> GROUP(…, grouping_sets = [[1], Int64[]]) + c = FROM(:person) |> GROUP(:year_of_birth, sets = [[1], Int[]]) + #-> (…) |> GROUP(…, sets = [[1], Int64[]]) print(render(c |> SELECT(:year_of_birth, AGG(:count)))) #=> @@ -1171,10 +1171,10 @@ rendered. `GROUP` raises an error when the vector of grouping sets is out of bounds. - FROM(:person) |> GROUP(:year_of_birth, grouping_sets = [[1, 2], [1], Int[]]) + FROM(:person) |> GROUP(:year_of_birth, sets = [[1, 2], [1], Int[]]) #=> ERROR: DomainError with [[1, 2], [1], Int64[]]: - grouping_sets is out of bounds + sets are out of bounds =# diff --git a/docs/src/test/nodes.md b/docs/src/test/nodes.md index 3a539884..53f9c7f4 100644 --- a/docs/src/test/nodes.md +++ b/docs/src/test/nodes.md @@ -2505,14 +2505,14 @@ downstream. `Group` allows specifying the grouping sets. q = From(person) |> - Group(Get.year_of_birth, grouping_sets = :cube) + Group(Get.year_of_birth, sets = :cube) Define(Agg.count()) display(q) #=> let person = SQLTable(:person, …), q1 = From(person), - q2 = q1 |> Group(Get.year_of_birth, grouping_sets = :CUBE) + q2 = q1 |> Group(Get.year_of_birth, sets = :CUBE) q2 end =# @@ -2525,14 +2525,14 @@ downstream. =# q = From(person) |> - Group(Get.year_of_birth, grouping_sets = [[1], Int[]]) + Group(Get.year_of_birth, sets = [[1], Int[]]) Define(Agg.count()) display(q) #=> let person = SQLTable(:person, …), q1 = From(person), - q2 = q1 |> Group(Get.year_of_birth, grouping_sets = [[1], []]) + q2 = q1 |> Group(Get.year_of_birth, sets = [[1], []]) q2 end =# @@ -2547,10 +2547,10 @@ downstream. `Group` complains about out-of-bound grouping sets. 
From(person) |> - Group(Get.year_of_birth, grouping_sets = [[1, 2], [1], Int[]]) + Group(Get.year_of_birth, sets = [[1, 2], [1], Int[]]) #=> ERROR: DomainError with [[1, 2], [1], Int64[]]: - grouping_sets is out of bounds + sets are out of bounds =# `Group` allows specifying the name of a group field. diff --git a/src/clauses/group.jl b/src/clauses/group.jl index ef243909..c32d905d 100644 --- a/src/clauses/group.jl +++ b/src/clauses/group.jl @@ -21,27 +21,27 @@ import .GROUPING_MODE.GroupingMode mutable struct GroupClause <: AbstractSQLClause over::Union{SQLClause, Nothing} by::Vector{SQLClause} - grouping_sets::Union{Vector{Vector{Int}}, GroupingMode, Nothing} + sets::Union{Vector{Vector{Int}}, GroupingMode, Nothing} function GroupClause(; over = nothing, by = SQLClause[], - grouping_sets = nothing) - c = new(over, by, grouping_sets isa Symbol ? convert(GroupingMode, grouping_sets) : grouping_sets) - gs = c.grouping_sets - if gs isa Vector{Vector{Int}} && !checkbounds(Bool, c.by, gs) - throw(DomainError(gs, "grouping_sets is out of bounds")) + sets = nothing) + c = new(over, by, sets isa Symbol ? convert(GroupingMode, sets) : sets) + s = c.sets + if s isa Vector{Vector{Int}} && !checkbounds(Bool, c.by, s) + throw(DomainError(s, "sets are out of bounds")) end c end end -GroupClause(by...; over = nothing, grouping_sets = nothing) = - GroupClause(over = over, by = SQLClause[by...], grouping_sets = grouping_sets) +GroupClause(by...; over = nothing, sets = nothing) = + GroupClause(over = over, by = SQLClause[by...], sets = sets) """ - GROUP(; over = nothing, by = [], grouping_sets = nothing) - GROUP(by...; over = nothing, grouping_sets = nothing) + GROUP(; over = nothing, by = [], sets = nothing) + GROUP(by...; over = nothing, sets = nothing) A `GROUP BY` clause. 
@@ -69,9 +69,9 @@ dissect(scr::Symbol, ::typeof(GROUP), pats::Vector{Any}) = function PrettyPrinting.quoteof(c::GroupClause, ctx::QuoteContext) ex = Expr(:call, nameof(GROUP)) append!(ex.args, quoteof(c.by, ctx)) - gs = c.grouping_sets - if gs !== nothing - push!(ex.args, Expr(:kw, :grouping_sets, gs isa GroupingMode ? QuoteNode(Symbol(gs)) : gs)) + s = c.sets + if s !== nothing + push!(ex.args, Expr(:kw, :sets, s isa GroupingMode ? QuoteNode(Symbol(s)) : s)) end if c.over !== nothing ex = Expr(:call, :|>, quoteof(c.over, ctx), ex) @@ -80,5 +80,5 @@ function PrettyPrinting.quoteof(c::GroupClause, ctx::QuoteContext) end rebase(c::GroupClause, c′) = - GroupClause(over = rebase(c.over, c′), by = c.by, grouping_sets = c.grouping_sets) + GroupClause(over = rebase(c.over, c′), by = c.by, sets = c.sets) diff --git a/src/link.jl b/src/link.jl index 9a28585f..06ce0522 100644 --- a/src/link.jl +++ b/src/link.jl @@ -118,7 +118,7 @@ end function dismantle(n::GroupNode, ctx) over′ = dismantle(n.over, ctx) by′ = dismantle_scalar(n.by, ctx) - Group(over = over′, by = by′, grouping_sets = n.grouping_sets, name = n.name, label_map = n.label_map) + Group(over = over′, by = by′, sets = n.sets, name = n.name, label_map = n.label_map) end function dismantle(n::IterateNode, ctx) @@ -290,7 +290,7 @@ function link(n::GroupNode, ctx) # To avoid duplicate SQL, they must be evaluated in a nested subquery. refs = SQLNode[] append!(refs, n.by) - if n.grouping_sets !== nothing + if n.sets !== nothing # Force evaluation in a nested subquery. 
append!(refs, n.by) end @@ -315,7 +315,7 @@ function link(n::GroupNode, ctx) over = Padding(over = over) end over′ = Linked(refs, 0, over = link(over, ctx, refs)) - Group(over = over′, by = n.by, grouping_sets = n.grouping_sets, name = n.name, label_map = n.label_map) + Group(over = over′, by = n.by, sets = n.sets, name = n.name, label_map = n.label_map) end function link(n::IterateNode, ctx) diff --git a/src/nodes/group.jl b/src/nodes/group.jl index 618f9df9..2d330a74 100644 --- a/src/nodes/group.jl +++ b/src/nodes/group.jl @@ -3,25 +3,25 @@ mutable struct GroupNode <: TabularNode over::Union{SQLNode, Nothing} by::Vector{SQLNode} - grouping_sets::Union{Vector{Vector{Int}}, GroupingMode, Nothing} + sets::Union{Vector{Vector{Int}}, GroupingMode, Nothing} name::Union{Symbol, Nothing} label_map::OrderedDict{Symbol, Int} function GroupNode(; over = nothing, by = SQLNode[], - grouping_sets = nothing, + sets = nothing, name::Union{Symbol, AbstractString, Nothing} = nothing, label_map = nothing) n = new( over, by, - grouping_sets isa Symbol ? convert(GroupingMode, grouping_sets) : grouping_sets, + sets isa Symbol ? convert(GroupingMode, sets) : sets, name !== nothing ? Symbol(name) : nothing, label_map !== nothing ? 
label_map : OrderedDict{Symbol, Int}()) - gs = n.grouping_sets - if gs isa Vector{Vector{Int}} && !checkbounds(Bool, n.by, gs) - throw(DomainError(gs, "grouping_sets is out of bounds")) + s = n.sets + if s isa Vector{Vector{Int}} && !checkbounds(Bool, n.by, s) + throw(DomainError(s, "sets are out of bounds")) end if label_map === nothing populate_label_map!(n, n.by, n.label_map, n.name) @@ -30,19 +30,19 @@ mutable struct GroupNode <: TabularNode end end -GroupNode(by...; over = nothing, grouping_sets = nothing, name = nothing) = - GroupNode(over = over, by = SQLNode[by...], grouping_sets = grouping_sets, name = name) +GroupNode(by...; over = nothing, sets = nothing, name = nothing) = + GroupNode(over = over, by = SQLNode[by...], sets = sets, name = name) """ - Group(; over, by = [], grouping_sets = grouping_sets, name = nothing) - Group(by...; over, grouping_sets = grouping_sets, name = nothing) + Group(; over, by = [], sets = nothing, name = nothing) + Group(by...; over, sets = nothing, name = nothing) The `Group` node summarizes the input dataset. Specifically, `Group` outputs all unique values of the given grouping key. This key partitions the input rows into disjoint groups that are summarized by aggregate functions [`Agg`](@ref) applied to the output of `Group`. The -parameter `grouping_sets` customizes the grouping sets. An optional parameter +parameter `sets` customizes the grouping sets. An optional parameter `name` specifies the field to hold the group. 
The `Group` node is translated to a SQL query with a `GROUP BY` clause: @@ -108,7 +108,7 @@ GROUP BY "person_1"."year_of_birth" julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth]); julia> q = From(:person) |> - Group(Get.year_of_birth, grouping_sets = :cube) |> + Group(Get.year_of_birth, sets = :cube) |> Select(Get.year_of_birth, Agg.count()); julia> print(render(q, tables = [person])) @@ -142,9 +142,9 @@ dissect(scr::Symbol, ::typeof(Group), pats::Vector{Any}) = function PrettyPrinting.quoteof(n::GroupNode, ctx::QuoteContext) ex = Expr(:call, nameof(Group), quoteof(n.by, ctx)...) - gs = n.grouping_sets - if gs !== nothing - push!(ex.args, Expr(:kw, :grouping_sets, gs isa GroupingMode ? QuoteNode(Symbol(gs)) : gs)) + s = n.sets + if s !== nothing + push!(ex.args, Expr(:kw, :sets, s isa GroupingMode ? QuoteNode(Symbol(s)) : s)) end if n.name !== nothing push!(ex.args, Expr(:kw, :name, QuoteNode(n.name))) diff --git a/src/resolve.jl b/src/resolve.jl index 04259c44..81b76d92 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -324,7 +324,7 @@ function resolve(n::GroupNode, ctx) fields[n.name] = RowType(FieldTypeMap(), group) group = EmptyType() end - n′ = Group(over = over′, by = by′, grouping_sets = n.grouping_sets, label_map = n.label_map) + n′ = Group(over = over′, by = by′, sets = n.sets, label_map = n.label_map) Resolved(RowType(fields, group), over = n′) end diff --git a/src/serialize.jl b/src/serialize.jl index 1469c9f2..4765ae3b 100644 --- a/src/serialize.jl +++ b/src/serialize.jl @@ -538,22 +538,22 @@ function serialize!(c::GroupClause, ctx) !isempty(c.by) || return newline(ctx) print(ctx, "GROUP BY") - grouping_sets = c.grouping_sets - if grouping_sets !== nothing - if grouping_sets isa GroupingMode - if grouping_sets == GROUPING_MODE.ROLLUP + sets = c.sets + if sets !== nothing + if sets isa GroupingMode + if sets == GROUPING_MODE.ROLLUP print(ctx, " ROLLUP(") - elseif grouping_sets == GROUPING_MODE.CUBE + elseif sets == 
GROUPING_MODE.CUBE print(ctx, " CUBE(") else - throw(DomainError(grouping_sets)) + throw(DomainError(sets)) end serialize!(c.by, ctx) print(ctx, ')') else print(ctx, " GROUPING SETS(") first = true - for set in grouping_sets + for set in sets if !first print(ctx, ", ") else diff --git a/src/translate.jl b/src/translate.jl index 11ddae39..9f9a7ce9 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -618,15 +618,15 @@ function assemble(n::GroupNode, ctx) push!(trns, ref => translate(over, ctx, subs)) end end - if !has_aggregates && n.grouping_sets === nothing + if !has_aggregates && n.sets === nothing for name in keys(n.label_map) push!(trns, Get(name = name) => by[n.label_map[name]]) end end repl, cols = make_repl_cols(trns) @assert !isempty(cols) - if has_aggregates || n.grouping_sets !== nothing - c = GROUP(over = tail, by = by, grouping_sets = n.grouping_sets) + if has_aggregates || n.sets !== nothing + c = GROUP(over = tail, by = by, sets = n.sets) else args = complete(cols) c = SELECT(over = tail, distinct = true, args = args) From 2449bc46f1c71b81bada3e811d2744aa7a1fdb02 Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Fri, 8 Mar 2024 13:34:46 -0600 Subject: [PATCH 3/4] Allow key names for specifying grouping sets --- docs/src/test/nodes.md | 48 ++++++++++++++++++++++++++++++++++----- src/nodes.jl | 26 ++++++++++++++++++++- src/nodes/group.jl | 51 ++++++++++++++++++++++++++++++++++++------ 3 files changed, 112 insertions(+), 13 deletions(-) diff --git a/docs/src/test/nodes.md b/docs/src/test/nodes.md index 53f9c7f4..c2ad9274 100644 --- a/docs/src/test/nodes.md +++ b/docs/src/test/nodes.md @@ -2502,7 +2502,8 @@ downstream. ) AS "person_2" =# -`Group` allows specifying the grouping sets. +`Group` allows specifying the grouping sets, either with grouping mode +indicators `:cube` or `:rollup`, or by explicit enumeration. q = From(person) |> Group(Get.year_of_birth, sets = :cube) @@ -2544,13 +2545,50 @@ downstream. 
GROUP BY GROUPING SETS(("person_1"."year_of_birth"), ()) =# -`Group` complains about out-of-bound grouping sets. +`Group` allows specifying grouping sets using names of the grouping keys. + + q = From(person) |> + Group(Get.year_of_birth, Get.gender_concept_id, + sets = ([:year_of_birth], ["gender_concept_id"])) + Define(Agg.count()) + + display(q) + #=> + let person = SQLTable(:person, …), + q1 = From(person), + q2 = q1 |> + Group(Get.year_of_birth, Get.gender_concept_id, sets = [[1], [2]]) + q2 + end + =# + +`Group` will report when a grouping set refers to an unknown key. From(person) |> - Group(Get.year_of_birth, sets = [[1, 2], [1], Int[]]) + Group(Get.year_of_birth, sets = [[:gender_concept_id], []]) #=> - ERROR: DomainError with [[1, 2], [1], Int64[]]: - sets are out of bounds + ERROR: FunSQL.InvalidGroupingSetsError: `gender_concept_id` is not a valid key + =# + +`Group` complains about out-of-bound or incomplete grouping sets. + + From(person) |> + Group(Get.year_of_birth, sets = [[1, 2], [1], []]) + #=> + ERROR: FunSQL.InvalidGroupingSetsError: `2` is out of bounds in: + let q1 = Group(Get.year_of_birth, sets = [[1, 2], [1], []]) + q1 + end + =# + + From(person) |> + Group(Get.year_of_birth, Get.gender_concept_id, + sets = [[1], []]) + #=> + ERROR: FunSQL.InvalidGroupingSetsError: missing keys `[:gender_concept_id]` in: + let q1 = Group(Get.year_of_birth, Get.gender_concept_id, sets = [[1], []]) + q1 + end =# `Group` allows specifying the name of a group field. diff --git a/src/nodes.jl b/src/nodes.jl index 58d562ad..c896f795 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -348,7 +348,7 @@ A node that cannot be rebased. struct RebaseError <: FunSQLError path::Vector{SQLNode} - RebaseError(; path = SQLNode[])= + RebaseError(; path = SQLNode[]) = new(path) end @@ -357,6 +357,30 @@ function Base.showerror(io::IO, err::RebaseError) showpath(io, err.path) end +""" +Grouping sets are specified incorrectly. 
+""" +struct InvalidGroupingSetsError <: FunSQLError + value::Union{Int, Symbol, Vector{Symbol}} + path::Vector{SQLNode} + + InvalidGroupingSetsError(value; path = SQLNode[]) = + new(value, path) +end + +function Base.showerror(io::IO, err::InvalidGroupingSetsError) + print(io, "FunSQL.InvalidGroupingSetsError: ") + value = err.value + if value isa Int + print(io, "`$value` is out of bounds") + elseif value isa Symbol + print(io, "`$value` is not a valid key") + elseif value isa Vector{Symbol} + print(io, "missing keys `$value`") + end + showpath(io, err.path) +end + module REFERENCE_ERROR_TYPE @enum ReferenceErrorType::UInt8 begin diff --git a/src/nodes/group.jl b/src/nodes/group.jl index 2d330a74..2c98aae1 100644 --- a/src/nodes/group.jl +++ b/src/nodes/group.jl @@ -1,5 +1,31 @@ # Grouping. +function populate_grouping_sets!(n, sets) + sets′ = Vector{Int}[] + for set in sets + set′ = Int[] + for el in set + push!(set′, _grouping_index(el, n)) + end + push!(sets′, set′) + end + n.sets = sets′ +end + +_grouping_index(el::Integer, n) = + convert(Int, el) + +function _grouping_index(el::Symbol, n) + k = get(n.label_map, el, nothing) + if k === nothing + throw(InvalidGroupingSetsError(el)) + end + k +end + +_grouping_index(el::AbstractString, n) = + _grouping_index(Symbol(el), n) + mutable struct GroupNode <: TabularNode over::Union{SQLNode, Nothing} by::Vector{SQLNode} @@ -13,19 +39,29 @@ mutable struct GroupNode <: TabularNode sets = nothing, name::Union{Symbol, AbstractString, Nothing} = nothing, label_map = nothing) + need_to_populate_sets = !(sets isa Union{Vector{Vector{Int}}, GroupingMode, Symbol, Nothing}) n = new( over, by, - sets isa Symbol ? convert(GroupingMode, sets) : sets, + need_to_populate_sets ? nothing : sets isa Symbol ? convert(GroupingMode, sets) : sets, name !== nothing ? Symbol(name) : nothing, label_map !== nothing ? 
label_map : OrderedDict{Symbol, Int}()) - s = n.sets - if s isa Vector{Vector{Int}} && !checkbounds(Bool, n.by, s) - throw(DomainError(s, "sets are out of bounds")) - end if label_map === nothing populate_label_map!(n, n.by, n.label_map, n.name) end + if need_to_populate_sets + populate_grouping_sets!(n, sets) + end + if n.sets isa Vector{Vector{Int}} + usage = falses(length(n.by)) + for set in n.sets + for k in set + checkbounds(Bool, n.by, k) || throw(InvalidGroupingSetsError(k, path = [n])) + usage[k] = true + end + end + all(usage) || throw(InvalidGroupingSetsError(collect(keys(n.label_map))[.!usage], path = [n])) + end n end end @@ -42,8 +78,9 @@ The `Group` node summarizes the input dataset. Specifically, `Group` outputs all unique values of the given grouping key. This key partitions the input rows into disjoint groups that are summarized by aggregate functions [`Agg`](@ref) applied to the output of `Group`. The -parameter `sets` customizes the grouping sets. An optional parameter -`name` specifies the field to hold the group. +parameter `sets` specifies the grouping sets, either with grouping mode +indicators `:cube` or `:rollup`, or explicitly as `Vector{Vector{Symbol}}`. +An optional parameter `name` specifies the field to hold the group. The `Group` node is translated to a SQL query with a `GROUP BY` clause: ```sql From 92b61bba9fa99b168a5c179f10fd11c0235cca87 Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Fri, 8 Mar 2024 13:35:26 -0600 Subject: [PATCH 4/4] Fix output to reflect latest PrettyPrinting --- docs/src/test/other.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/test/other.md b/docs/src/test/other.md index 6ca8b172..d81945fc 100644 --- a/docs/src/test/other.md +++ b/docs/src/test/other.md @@ -162,7 +162,7 @@ Any `Dict`-like object can serve as a query cache. 
#-> SQLCatalog(dialect = SQLDialect(), cache = Dict{Any, Any}()) display(customcache_catalog) - #-> SQLCatalog(dialect = SQLDialect(), cache = Dict{Any, Any}()) + #-> SQLCatalog(dialect = SQLDialect(), cache = (Dict{Any, Any})()) The catalog behaves as a read-only `Dict` object.