diff --git a/docs/src/test/clauses.md b/docs/src/test/clauses.md index 590ce1f1..e147c6dd 100644 --- a/docs/src/test/clauses.md +++ b/docs/src/test/clauses.md @@ -1131,6 +1131,52 @@ rendered. FROM "person" =# +`GROUP` can accept the grouping mode or a vector of grouping sets. + + c = FROM(:person) |> GROUP(:year_of_birth, sets = :ROLLUP) + #-> (…) |> GROUP(…, sets = :ROLLUP) + + print(render(c |> SELECT(:year_of_birth, AGG(:count)))) + #=> + SELECT + "year_of_birth", + count(*) + FROM "person" + GROUP BY ROLLUP("year_of_birth") + =# + + c = FROM(:person) |> GROUP(:year_of_birth, sets = :CUBE) + #-> (…) |> GROUP(…, sets = :CUBE) + + print(render(c |> SELECT(:year_of_birth, AGG(:count)))) + #=> + SELECT + "year_of_birth", + count(*) + FROM "person" + GROUP BY CUBE("year_of_birth") + =# + + c = FROM(:person) |> GROUP(:year_of_birth, sets = [[1], Int[]]) + #-> (…) |> GROUP(…, sets = [[1], Int64[]]) + + print(render(c |> SELECT(:year_of_birth, AGG(:count)))) + #=> + SELECT + "year_of_birth", + count(*) + FROM "person" + GROUP BY GROUPING SETS(("year_of_birth"), ()) + =# + +`GROUP` raises an error when the vector of grouping sets is out of bounds. + + FROM(:person) |> GROUP(:year_of_birth, sets = [[1, 2], [1], Int[]]) + #=> + ERROR: DomainError with [[1, 2], [1], Int64[]]: + sets are out of bounds + =# + ## `HAVING` Clause diff --git a/docs/src/test/nodes.md b/docs/src/test/nodes.md index 4639d3b9..c2ad9274 100644 --- a/docs/src/test/nodes.md +++ b/docs/src/test/nodes.md @@ -2502,6 +2502,95 @@ downstream. ) AS "person_2" =# +`Group` allows specifying the grouping sets, either with grouping mode +indicators `:cube` or `:rollup`, or by explicit enumeration. + + q = From(person) |> + Group(Get.year_of_birth, sets = :cube) + Define(Agg.count()) + + display(q) + #=> + let person = SQLTable(:person, …), + q1 = From(person), + q2 = q1 |> Group(Get.year_of_birth, sets = :CUBE) + q2 + end + =# + + print(render(q)) + #=> + SELECT "person_1"."year_of_birth" + FROM "person" AS "person_1" + GROUP BY CUBE("person_1"."year_of_birth") + =# + + q = From(person) |> + Group(Get.year_of_birth, sets = [[1], Int[]]) + Define(Agg.count()) + + display(q) + #=> + let person = SQLTable(:person, …), + q1 = From(person), + q2 = q1 |> Group(Get.year_of_birth, sets = [[1], []]) + q2 + end + =# + + print(render(q)) + #=> + SELECT "person_1"."year_of_birth" + FROM "person" AS "person_1" + GROUP BY GROUPING SETS(("person_1"."year_of_birth"), ()) + =# + +`Group` allows specifying grouping sets using names of the grouping keys. + + q = From(person) |> + Group(Get.year_of_birth, Get.gender_concept_id, + sets = ([:year_of_birth], ["gender_concept_id"])) + Define(Agg.count()) + + display(q) + #=> + let person = SQLTable(:person, …), + q1 = From(person), + q2 = q1 |> + Group(Get.year_of_birth, Get.gender_concept_id, sets = [[1], [2]]) + q2 + end + =# + +`Group` will report when a grouping set refers to an unknown key. + + From(person) |> + Group(Get.year_of_birth, sets = [[:gender_concept_id], []]) + #=> + ERROR: FunSQL.InvalidGroupingSetsError: `gender_concept_id` is not a valid key + =# + +`Group` complains about out-of-bound or incomplete grouping sets. + + From(person) |> + Group(Get.year_of_birth, sets = [[1, 2], [1], []]) + #=> + ERROR: FunSQL.InvalidGroupingSetsError: `2` is out of bounds in: + let q1 = Group(Get.year_of_birth, sets = [[1, 2], [1], []]) + q1 + end + =# + + From(person) |> + Group(Get.year_of_birth, Get.gender_concept_id, + sets = [[1], []]) + #=> + ERROR: FunSQL.InvalidGroupingSetsError: missing keys `[:year_of_birth]` in: + let q1 = Group(Get.year_of_birth, Get.gender_concept_id, sets = [[1], []]) + q1 + end + =# + `Group` allows specifying the name of a group field. q = From(person) |> diff --git a/docs/src/test/other.md b/docs/src/test/other.md index 6ca8b172..d81945fc 100644 --- a/docs/src/test/other.md +++ b/docs/src/test/other.md @@ -162,7 +162,7 @@ Any `Dict`-like object can serve as a query cache. #-> SQLCatalog(dialect = SQLDialect(), cache = Dict{Any, Any}()) display(customcache_catalog) - #-> SQLCatalog(dialect = SQLDialect(), cache = Dict{Any, Any}()) + #-> SQLCatalog(dialect = SQLDialect(), cache = (Dict{Any, Any})()) The catalog behaves as a read-only `Dict` object. diff --git a/src/clauses/group.jl b/src/clauses/group.jl index f1d0dd23..c32d905d 100644 --- a/src/clauses/group.jl +++ b/src/clauses/group.jl @@ -1,21 +1,47 @@ # GROUP BY clause. +module GROUPING_MODE + +@enum GroupingMode::UInt8 begin + ROLLUP + CUBE +end + +Base.convert(::Type{GroupingMode}, s::Symbol) = + s in (:rollup, :ROLLUP) ? + ROLLUP : + s in (:cube, :CUBE) ? + CUBE : + throw(DomainError(QuoteNode(s), "expected :rollup or :cube")) + +end + +import .GROUPING_MODE.GroupingMode + mutable struct GroupClause <: AbstractSQLClause over::Union{SQLClause, Nothing} by::Vector{SQLClause} + sets::Union{Vector{Vector{Int}}, GroupingMode, Nothing} - GroupClause(; + function GroupClause(; over = nothing, - by = SQLClause[]) = - new(over, by) + by = SQLClause[], + sets = nothing) + c = new(over, by, sets isa Symbol ? convert(GroupingMode, sets) : sets) + s = c.sets + if s isa Vector{Vector{Int}} && !checkbounds(Bool, c.by, s) + throw(DomainError(s, "sets are out of bounds")) + end + c + end end -GroupClause(by...; over = nothing) = - GroupClause(over = over, by = SQLClause[by...]) +GroupClause(by...; over = nothing, sets = nothing) = + GroupClause(over = over, by = SQLClause[by...], sets = sets) """ - GROUP(; over = nothing, by = []) - GROUP(by...; over = nothing) + GROUP(; over = nothing, by = [], sets = nothing) + GROUP(by...; over = nothing, sets = nothing) A `GROUP BY` clause. @@ -43,6 +69,10 @@ dissect(scr::Symbol, ::typeof(GROUP), pats::Vector{Any}) = function PrettyPrinting.quoteof(c::GroupClause, ctx::QuoteContext) ex = Expr(:call, nameof(GROUP)) append!(ex.args, quoteof(c.by, ctx)) + s = c.sets + if s !== nothing + push!(ex.args, Expr(:kw, :sets, s isa GroupingMode ? QuoteNode(Symbol(s)) : s)) + end if c.over !== nothing ex = Expr(:call, :|>, quoteof(c.over, ctx), ex) end @@ -50,5 +80,5 @@ function PrettyPrinting.quoteof(c::GroupClause, ctx::QuoteContext) end rebase(c::GroupClause, c′) = - GroupClause(over = rebase(c.over, c′), by = c.by) + GroupClause(over = rebase(c.over, c′), by = c.by, sets = c.sets) diff --git a/src/link.jl b/src/link.jl index 7b74ddba..06ce0522 100644 --- a/src/link.jl +++ b/src/link.jl @@ -118,7 +118,7 @@ end function dismantle(n::GroupNode, ctx) over′ = dismantle(n.over, ctx) by′ = dismantle_scalar(n.by, ctx) - Group(over = over′, by = by′, name = n.name, label_map = n.label_map) + Group(over = over′, by = by′, sets = n.sets, name = n.name, label_map = n.label_map) end function dismantle(n::IterateNode, ctx) @@ -290,6 +290,10 @@ function link(n::GroupNode, ctx) # To avoid duplicate SQL, they must be evaluated in a nested subquery. refs = SQLNode[] append!(refs, n.by) + if n.sets !== nothing + # Force evaluation in a nested subquery. + append!(refs, n.by) + end # Ignore `SELECT DISTINCT` case. if has_aggregates ctx′ = LinkContext(ctx, refs = refs) @@ -311,7 +315,7 @@ function link(n::GroupNode, ctx) over = Padding(over = over) end over′ = Linked(refs, 0, over = link(over, ctx, refs)) - Group(over = over′, by = n.by, name = n.name, label_map = n.label_map) + Group(over = over′, by = n.by, sets = n.sets, name = n.name, label_map = n.label_map) end function link(n::IterateNode, ctx) diff --git a/src/nodes.jl b/src/nodes.jl index 58d562ad..c896f795 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -348,7 +348,7 @@ A node that cannot be rebased. struct RebaseError <: FunSQLError path::Vector{SQLNode} - RebaseError(; path = SQLNode[])= + RebaseError(; path = SQLNode[]) = new(path) end @@ -357,6 +357,30 @@ function Base.showerror(io::IO, err::RebaseError) showpath(io, err.path) end +""" +Grouping sets are specified incorrectly. +""" +struct InvalidGroupingSetsError <: FunSQLError + value::Union{Int, Symbol, Vector{Symbol}} + path::Vector{SQLNode} + + InvalidGroupingSetsError(value; path = SQLNode[]) = + new(value, path) +end + +function Base.showerror(io::IO, err::InvalidGroupingSetsError) + print(io, "FunSQL.InvalidGroupingSetsError: ") + value = err.value + if value isa Int + print(io, "`$value` is out of bounds") + elseif value isa Symbol + print(io, "`$value` is not a valid key") + elseif value isa Vector{Symbol} + print(io, "missing keys `$value`") + end + showpath(io, err.path) +end + module REFERENCE_ERROR_TYPE @enum ReferenceErrorType::UInt8 begin diff --git a/src/nodes/group.jl b/src/nodes/group.jl index cac773f8..2c98aae1 100644 --- a/src/nodes/group.jl +++ b/src/nodes/group.jl @@ -1,39 +1,86 @@ # Grouping. +function populate_grouping_sets!(n, sets) + sets′ = Vector{Int}[] + for set in sets + set′ = Int[] + for el in set + push!(set′, _grouping_index(el, n)) + end + push!(sets′, set′) + end + n.sets = sets′ +end + +_grouping_index(el::Integer, n) = + convert(Int, el) + +function _grouping_index(el::Symbol, n) + k = get(n.label_map, el, nothing) + if k == nothing + throw(InvalidGroupingSetsError(el)) + end + k +end + +_grouping_index(el::AbstractString, n) = + _grouping_index(Symbol(el), n) + mutable struct GroupNode <: TabularNode over::Union{SQLNode, Nothing} by::Vector{SQLNode} + sets::Union{Vector{Vector{Int}}, GroupingMode, Nothing} name::Union{Symbol, Nothing} label_map::OrderedDict{Symbol, Int} function GroupNode(; over = nothing, by = SQLNode[], + sets = nothing, name::Union{Symbol, AbstractString, Nothing} = nothing, label_map = nothing) - if label_map !== nothing - new(over, by, name !== nothing ? Symbol(name) : nothing, label_map) - else - n = new(over, by, name !== nothing ? Symbol(name) : nothing, OrderedDict{Symbol, Int}()) + need_to_populate_sets = !(sets isa Union{Vector{Vector{Int}}, GroupingMode, Symbol, Nothing}) + n = new( + over, + by, + need_to_populate_sets ? nothing : sets isa Symbol ? convert(GroupingMode, sets) : sets, + name !== nothing ? Symbol(name) : nothing, + label_map !== nothing ? label_map : OrderedDict{Symbol, Int}()) + if label_map === nothing populate_label_map!(n, n.by, n.label_map, n.name) - n end + if need_to_populate_sets + populate_grouping_sets!(n, sets) + end + if n.sets isa Vector{Vector{Int}} + usage = falses(length(n.by)) + for set in n.sets + for k in set + checkbounds(Bool, n.by, k) || throw(InvalidGroupingSetsError(k, path = [n])) + usage[k] = true + end + end + all(usage) || throw(InvalidGroupingSetsError(collect(keys(n.label_map))[usage], path = [n])) + end + n end end -GroupNode(by...; over = nothing, name = nothing) = - GroupNode(over = over, by = SQLNode[by...], name = name) +GroupNode(by...; over = nothing, sets = nothing, name = nothing) = + GroupNode(over = over, by = SQLNode[by...], sets = sets, name = name) """ - Group(; over; by = [], name = nothing) - Group(by...; over, name = nothing) + Group(; over, by = [], sets = sets, name = nothing) + Group(by...; over, sets = sets, name = nothing) The `Group` node summarizes the input dataset. Specifically, `Group` outputs all unique values of the given grouping key. This key partitions the input rows into disjoint groups that are summarized -by aggregate functions [`Agg`](@ref) applied to the output of `Group`. An -optional parameter `name` specifies the field to hold the group. +by aggregate functions [`Agg`](@ref) applied to the output of `Group`. The +parameter `sets` specifies the grouping sets, either with grouping mode +indicators `:cube` or `:rollup`, or explicitly as `Vector{Vector{Symbol}}`. +An optional parameter `name` specifies the field to hold the group. The `Group` node is translated to a SQL query with a `GROUP BY` clause: ```sql @@ -92,6 +139,23 @@ FROM "person" AS "person_1" GROUP BY "person_1"."year_of_birth" ``` +*Number of patients per year of birth and the total number of patients.* + +```jldoctest +julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth]); + +julia> q = From(:person) |> + Group(Get.year_of_birth, sets = :cube) |> + Select(Get.year_of_birth, Agg.count()); + +julia> print(render(q, tables = [person])) +SELECT + "person_1"."year_of_birth", + count(*) AS "count" +FROM "person" AS "person_1" +GROUP BY CUBE("person_1"."year_of_birth") +``` + *Distinct states across all available locations.* ```jldoctest @@ -115,6 +179,10 @@ dissect(scr::Symbol, ::typeof(Group), pats::Vector{Any}) = function PrettyPrinting.quoteof(n::GroupNode, ctx::QuoteContext) ex = Expr(:call, nameof(Group), quoteof(n.by, ctx)...) + s = n.sets + if s !== nothing + push!(ex.args, Expr(:kw, :sets, s isa GroupingMode ? QuoteNode(Symbol(s)) : s)) + end if n.name !== nothing push!(ex.args, Expr(:kw, :name, QuoteNode(n.name))) end diff --git a/src/resolve.jl b/src/resolve.jl index 32dafdf7..81b76d92 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -324,7 +324,7 @@ function resolve(n::GroupNode, ctx) fields[n.name] = RowType(FieldTypeMap(), group) group = EmptyType() end - n′ = Group(over = over′, by = by′, label_map = n.label_map) + n′ = Group(over = over′, by = by′, sets = n.sets, label_map = n.label_map) Resolved(RowType(fields, group), over = n′) end diff --git a/src/serialize.jl b/src/serialize.jl index 6c60e3b8..4765ae3b 100644 --- a/src/serialize.jl +++ b/src/serialize.jl @@ -538,7 +538,36 @@ function serialize!(c::GroupClause, ctx) !isempty(c.by) || return newline(ctx) print(ctx, "GROUP BY") - serialize_lines!(c.by, ctx) + sets = c.sets + if sets !== nothing + if sets isa GroupingMode + if sets == GROUPING_MODE.ROLLUP + print(ctx, " ROLLUP(") + elseif sets == GROUPING_MODE.CUBE + print(ctx, " CUBE(") + else + throw(DomainError(sets)) + end + serialize!(c.by, ctx) + print(ctx, ')') + else + print(ctx, " GROUPING SETS(") + first = true + for set in sets + if !first + print(ctx, ", ") + else + first = false + end + print(ctx, '(') + serialize!(c.by[set], ctx) + print(ctx, ')') + end + print(ctx, ')') + end + else + serialize_lines!(c.by, ctx) + end end function serialize!(c::HavingClause, ctx) diff --git a/src/translate.jl b/src/translate.jl index 462e25bd..9f9a7ce9 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -618,15 +618,15 @@ function assemble(n::GroupNode, ctx) push!(trns, ref => translate(over, ctx, subs)) end end - if !has_aggregates + if !has_aggregates && n.sets === nothing for name in keys(n.label_map) push!(trns, Get(name = name) => by[n.label_map[name]]) end end repl, cols = make_repl_cols(trns) @assert !isempty(cols) - if has_aggregates - c = GROUP(over = tail, by = by) + if has_aggregates || n.sets !== nothing + c = GROUP(over = tail, by = by, sets = n.sets) else args = complete(cols) c = SELECT(over = tail, distinct = true, args = args)