diff --git a/src/combinators/hierarchical.jl b/src/combinators/hierarchical.jl index 0a2b1909..3256308b 100644 --- a/src/combinators/hierarchical.jl +++ b/src/combinators/hierarchical.jl @@ -1,61 +1,67 @@ export HierarchicalMeasure +""" + struct HierarchicalMeasure{F,M<:AbstractMeasure,G} <: AbstractMeasure -# TODO: Document and use FlattenMode -abstract type FlattenMode end -struct NoFlatten <: FlattenMode end -struct AutoFlatten <: FlattenMode end +Represents a hierarchical measure. - -struct HierarchicalMeasure{F,M<:AbstractMeasure,FM<:FlattenMode} <: AbstractMeasure +User code should not instantiate `HierarchicalMeasure` directly, use +[`hierarchical_measure`](@ref) instead. +""" +struct HierarchicalMeasure{F,M<:AbstractMeasure,G} <: AbstractMeasure f::F m::M - flatten_mode::FM + flatten::G end -# TODO: Document -const HierarchicalProductMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,NoFlatten} -export HierarchicalProductMeasure - -HierarchicalProductMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, NoFlatten()) +""" + hierarchical_measure(f, m::AbstractMeasure, flatten) -# TODO: Document -const FlatHierarchicalMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,AutoFlatten} -export FlatHierarchicalMeasure +Construct a hierarchical measure from a function `f`, measure `m` and +""" +@inline function hierarchical_measure(f, m::AbstractMeasure, flatten) + F, M, G = Core.Typeof(f), Core.Typeof(m), Core.Typeof(flatten) + HierarchicalProductMeasure{F,M,G}(f, m, flatten) +end -FlatHierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, AutoFlatten()) -HierarchicalMeasure(f, m::AbstractMeasure) = FlatHierarchicalMeasure(f, m) +#!!!!!! +const HierarchicalProductMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,::typeof(=>)} +const FlatHierarchicalMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,::typeof(vcat)} -function _split_variate_after(::NoFlatten, μ::AbstractMeasure, x::Tuple{2}) - @assert x isa Tuple{2} - return x[1], x[2] +function _split_variate(::typeof(=>), ::AbstractMeasure, x::Pair) + return x.first, x.second end +function _split_variate(flatten::F, μ_primary::AbstractMeasure, x) where F + test_primary = testvalue(μ_primary) + return _split_variate_byvalue(flatten, test_primary, x) +end -function _split_variate_after(::AutoFlatten, μ::AbstractMeasure, x) - a_test = testvalue(μ) - return _autosplit_variate_after_testvalue(a_test, x) +function _split_variate(::Type{F}, μ::AbstractMeasure, x) where F + test_primary = testvalue(μ) + return _split_variate_byvalue(F, test_primary, x) end -function _autosplit_variate_after_testvalue(::Any, x) + +function _split_variate_byvalue(::Any, x) @assert x isa Tuple{2} return x[1], x[2] end -function _autosplit_variate_after_testvalue(a_test::AbstractVector, x::AbstractVector) - n, m = length(eachindex(a_test)), length(eachindex(x)) +function _split_variate_byvalue(test_primary::AbstractVector, x::AbstractVector) + n, m = length(eachindex(test_primary)), length(eachindex(x)) # TODO: Use getindex or view? return x[begin:n], x[begin+n:m] end -function _autosplit_variate_after_testvalue(::Tuple{N}, x::Tuple{M}) where {N,M} +function _split_variate_byvalue(::Tuple{N}, x::Tuple{M}) where {N,M} return ntuple(i -> x[i], Val(1:N)), ntuple(i -> x[i], Val(N+1:M)) end -@generated function _autosplit_variate_after_testvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names} +@generated function _split_variate_byvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names} # TODO: implement @assert false end @@ -147,7 +153,7 @@ end function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x) dof_μ = getdof(μ) - x_μ, x_rest = _split_variate_after(flatten_mode, μ, x) + x_μ, x_rest = _split_variate(flatten_mode, μ, x) y = transport_to(ν_inner^dof_μ, μ, x_μ) return y, x_rest end