Skip to content

Commit

Permalink
STASH smart ctors, canonical measure nesting
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Jul 16, 2023
1 parent 3a77a05 commit 8dad17e
Showing 1 changed file with 34 additions and 22 deletions.
56 changes: 34 additions & 22 deletions src/combinators/smart-constructors.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@

# Canonical measure type nesting, outer to inner:
#
# WeightedMeasure, Dirac, PowerMeasure, ProductMeasure


###############################################################################
# Half

Expand All @@ -19,25 +24,27 @@ Constructs a power of a measure `μ`.
function powermeasure end
export powermeasure

powermeasure(m::AbstractMeasure, ::Tuple{}) = asmeasure(m)
@inline powermeasure(μ, exponent) = _generic_powermeasure_impl(asmeasure(μ), _pm_axes(exponent))

@inline _generic_powermeasure_stage1::AbstractMeasure, ::Tuple{}) = μ

@inline function powermeasure(x::T, sz::Tuple{Vararg{Any,N}}) where {T,N}
PowerMeasure(asmeasure(x), _pm_axes(sz))
@inline function _generic_powermeasure_stage1::AbstractMeasure, exponent::Tuple)
_generic_powermeasure_stage2(μ, exponent)
end

function powermeasure(
μ::WeightedMeasure,
dims::Tuple{<:AbstractArray,Vararg{AbstractArray}},
)
k = mapreduce(length, *, dims) * μ.logweight
return weightedmeasure(k, μ.base^dims)
@inline _generic_powermeasure_stage2::AbstractMeasure, exponent::Tuple) = PowerMeasure(μ, exponent)

@inline function _generic_powermeasure_stage2::Dirac, exponent::Tuple)
Dirac(maybestatic_fill.value, exponent))
end

function powermeasure::WeightedMeasure, dims::NonEmptyTuple)
k = prod(dims) * μ.logweight
return weightedmeasure(k, μ.base^dims)
@inline function _generic_powermeasure_stage2::WeightedMeasure, exponent::Tuple)
ν = μ.base^exponent
k = maybestatic_length(ν) * μ.logweight
return weightedmeasure(k, ν)
end


###############################################################################
# ProductMeasure

Expand All @@ -58,25 +65,30 @@ productmeasure((pushfwd(Mul(scale), StdExponential()) for scale in 0.1:0.2:2))
function productmeasure end
export productmeasure

productmeasure(mar::Fill) = powermeasure(_fill_value(mar), _fill_axes(mar))
@inline productmeasure(mar) = _generic_procuctmeasure_impl(mar)

@inline _generic_procuctmeasure_impl(mar::Fill) = powermeasure(_fill_value(mar), _fill_axes(mar))

productmeasure(mar::Tuple{Vararg{AbstractMeasure}}) = ProductMeasure(mar)
productmeasure(mar::Tuple) = ProductMeasure(map(asmeasure, mar))
@inline _generic_procuctmeasure_impl(mar::Tuple{Vararg{AbstractMeasure}}) = ProductMeasure(mar)
_generic_procuctmeasure_impl(mar::Tuple{Vararg{Dirac}}) = Dirac(map(m -> m.value), mar)
_generic_procuctmeasure_impl(mar::Tuple) = productmeasure(map(asmeasure, mar))

productmeasure(mar::NamedTuple{names,<:Tuple{Vararg{AbstractMeasure}}}) where names = ProductMeasure(mar)
productmeasure(mar::NamedTuple) = ProductMeasure(map(asmeasure, mar))
@inline _generic_procuctmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{AbstractMeasure}}}) where names = ProductMeasure(mar)
_generic_procuctmeasure_impl(mar::NamedTuple{names,<:Tuple{Vararg{Dirac}}}) where names = Dirac(map(m -> m.value), mar)
_generic_procuctmeasure_impl(mar::NamedTuple) = productmeasure(map(asmeasure, mar))

productmeasure(mar::AbstractArray{<:AbstractProductMeasure}) = ProductMeasure(mar)
productmeasure(mar::AbstractArray) = ProductMeasure(asmeasure.(mar))
@inline _generic_procuctmeasure_impl(mar::AbstractArray{<:AbstractProductMeasure}) = ProductMeasure(mar)
_generic_procuctmeasure_impl(mar::AbstractArray{<:Dirac}) = Dirac((m -> m.value).(mar))
_generic_procuctmeasure_impl(mar::AbstractArray) = ProductMeasure(asmeasure.(mar))

function productmeasure(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M}
@inline function _generic_procuctmeasure_impl(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M}
return powermeasure(mar.f.value, axes(mar.data))
end

productmeasure(mar::Base.Generator) = ProductMeasure(mar)
@inline _generic_procuctmeasure_impl(mar::Base.Generator) = ProductMeasure(mar)

# TODO: Make this static when its length is static
@inline function productmeasure(
@inline function _generic_procuctmeasure_impl(
mar::AbstractArray{<:WeightedMeasure{StaticFloat64{W},M}},
) where {W,M}
return weightedmeasure(W * length(mar), productmeasure(map(basemeasure, mar)))
Expand Down

0 comments on commit 8dad17e

Please sign in to comment.