diff --git a/Project.toml b/Project.toml index 661aaf6e..d5ff8381 100644 --- a/Project.toml +++ b/Project.toml @@ -26,12 +26,13 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" [compat] ChainRulesCore = "1" -ChangesOfVariables = "0.1.3" +ChangesOfVariables = "0.1" Compat = "3.35, 4" ConstructionBase = "1.3" DensityInterface = "0.4" @@ -52,6 +53,7 @@ Reexport = "1" SpecialFunctions = "2" Static = "0.8, 1" Statistics = "1" +StatsBase = "0.34" Test = "1" Tricks = "0.1" julia = "1.6" diff --git a/src/rand.jl b/src/rand.jl index f92cb16a..db5a31cd 100644 --- a/src/rand.jl +++ b/src/rand.jl @@ -8,6 +8,24 @@ Base.rand(rng::AbstractRNG, d::AbstractMeasure) = rand(rng, Float64, d) @inline Random.rand!(d::AbstractMeasure, args...) = rand!(GLOBAL_RNG, d, args...) +@inline function Base.rand( + rng::AbstractRNG, + ::Type{T}, + d::ProductMeasure{A}, +) where {T,A<:AbstractArray} + mar = marginals(d) + + # Distributions doens't (yet) have the three-argument form + elT = typeof(rand(rng, T, first(mar))) + + sz = size(mar) + x = Array{elT,length(sz)}(undef, sz) + @inbounds @simd for j in eachindex(mar) + x[j] = rand(rng, T, mar[j]) + end + x +end + # TODO: Make this work # function Base.rand(rng::AbstractRNG, ::Type{T}, d::AbstractMeasure) where {T} # x = testvalue(d) diff --git a/src/utils.jl b/src/utils.jl index d0f85481..2d294b37 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -164,3 +164,12 @@ using InverseFunctions: FunctionWithInverse unwrap(f) = f unwrap(f::FunctionWithInverse) = f.f + +import Statistics +import StatsBase + +StatsBase.entropy(m::AbstractMeasure, b::Real) = entropy(proxy(m), b) +Statistics.mean(m::AbstractMeasure) = mean(proxy(m)) +Statistics.std(m::AbstractMeasure) = std(proxy(m)) +Statistics.var(m::AbstractMeasure) = var(proxy(m)) +Statistics.quantile(m::AbstractMeasure, q) = quantile(proxy(m), q) \ No newline at end of file