Skip to content

Commit

Permalink
Custom stat
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielVandH committed Jul 18, 2023
1 parent 081e2ac commit e84dc6e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EpithelialDynamics1D"
uuid = "ace8a2d7-7779-48a6-a8a4-cf6831a7e55b"
authors = ["Daniel VandenHeuvel <danj.vandenheuvel@gmail.com>"]
version = "1.6.0"
version = "1.7.0"

[deps]
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Expand Down
60 changes: 27 additions & 33 deletions src/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,19 @@ function node_densities(cell_positions::AbstractVector{T}; smooth_boundary=true)
end

"""
get_knots(sol, num_knots = 500; indices = eachindex(sol), use_max=true)
get_knots(sol, num_knots = 500; indices = eachindex(sol), stat=maximum)
Computes knots for each time, covering the extremum of the cell positions across all
cell simulations. You can restrict the simultaions to consider using the `indices`.
If `use_max` is true, then the knots will be obtained by taking the extreme node positions
for each `time`, otherwise the average is used.
The knots are obtained by applying `stat`, a tuple of functions for the left and right sides,
to the vector of extrema at each time. For example, if `stat=maximum` then, at each time,
the knots range between the smallest position observed and the maximum position
observed across each simulation at that time.
"""
function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol), use_extrema=true)
function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol), stat=(minimum, maximum))
if stat isa Function
stat = (stat, stat)
end
@static if VERSION < v"1.7"
knots = Vector{LinRange{Float64}}(undef, length(first(sol)))
else
Expand All @@ -95,31 +100,13 @@ function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol),
times = first(sol).t
Base.Threads.@threads for i in eachindex(times)
local a, b
if use_extrema
a = Inf
b = -Inf
else
a = 0.0
b = 0.0
ctr = 0
end
for j in indices
_a = sol[j][i][begin]
_b = sol[j][i][end]
if use_extrema
a = min(a, _a)
b = max(b, _b)
else
a += _a
b += _b
ctr += 1
end
end
if !use_extrema
a /= ctr
b /= ctr
a = zeros(length(indices))
b = zeros(length(indices))
for (ℓ, j) in enumerate(indices)
a[ℓ] = sol[j].u[i][begin]
b[ℓ] = sol[j].u[i][end]
end
knots[i] = LinRange(a, b, num_knots)
knots[i] = LinRange(stat[1](a), stat[2](b), num_knots)
end
return knots
end
Expand All @@ -131,7 +118,14 @@ function get_knots(sol::ODESolution, num_knots=500)
end

"""
node_densities(sol::EnsembleSolution; num_knots=500, knots=get_knots(sol, num_knots), alpha=0.05, interp_fnc=(u, t) -> LinearInterpolation{true}(u, t), smooth_boundary=true)
node_densities(sol::EnsembleSolution;
indices=eachindex(sol),
num_knots=500,
stat=(minimum, maximum),
knots=get_knots(sol, num_knots; indices, stat),
alpha=0.05,
interp_fnc=(u, t) -> LinearInterpolation{true}(u, t),
smooth_boundary=true)
Computes summary statistics for the node densities from an `EnsembleSolution` to a [`CellProblem`](@ref).
Expand All @@ -141,8 +135,8 @@ Computes summary statistics for the node densities from an `EnsembleSolution` to
# Keyword Arguments
- `indices = eachindex(sol)`: The indices of the cell simulations to consider.
- `num_knots::Int = 500`: The number of knots to use for the spline interpolation.
- `use_extrema::Bool = true`: Whether to use the extrema of the cell positions for the knots, or the average.
- `knots::Vector{Vector{Float64}} = get_knots(sol, num_knots; indices, use_extrema)`: The knots to use for the spline interpolation.
- `stat = (minimum, maximum)`: How to summarise the knots for `get_knots`.
- `knots::Vector{Vector{Float64}} = get_knots(sol, num_knots; indices, stat)`: The knots to use for the spline interpolation.
- `alpha::Float64 = 0.05`: The significance level for the confidence intervals.
- `interp_fnc = (u, t) -> LinearInterpolation{true}(u, t)`: The function to use for constructing the interpolant.
- `smooth_boundary::Bool = true`: Whether to use the smooth boundary node densities.
Expand All @@ -158,8 +152,8 @@ Computes summary statistics for the node densities from an `EnsembleSolution` to
function node_densities(sol::EnsembleSolution;
indices=eachindex(sol),
num_knots=500,
use_extrema=true,
knots=get_knots(sol, num_knots; indices, use_extrema),
stat=(minimum, maximum),
knots=get_knots(sol, num_knots; indices, stat),
alpha=0.05,
interp_fnc=(u, t) -> LinearInterpolation{true}(u, t),
smooth_boundary=true)
Expand Down
12 changes: 6 additions & 6 deletions test/step_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ end

# Using average leading edge
_indices = rand(eachindex(sol), 40)
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, use_extrema=false, smooth_boundary=false)
@inferred node_densities(sol; indices=_indices, use_extrema=false)
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, stat=mean, smooth_boundary=false)
@inferred node_densities(sol; indices=_indices, stat=mean)
@test all((LinRange(0, 30, 500)), knots)
for (enum_k, k) in enumerate(_indices)
for j in rand(1:length(sol[k]), 40)
Expand Down Expand Up @@ -543,13 +543,13 @@ end

# Test the statistics with a specific interpolation function
_indices = rand(eachindex(sol), 20)
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, interp_fnc=CubicSpline)
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, interp_fnc=CubicSpline, stat = minimum)
@inferred node_densities(sol; indices=_indices, interp_fnc=CubicSpline)
for j in eachindex(knots)
a = Inf
b = -Inf
m = minimum(sol[k][j][begin] for k in _indices)
M = maximum(sol[k][j][end] for k in _indices)
M = minimum(sol[k][j][end] for k in _indices)
@test knots[j] == LinRange(m, M, 500)
end
for (enum_k, k) in enumerate(_indices)
Expand Down Expand Up @@ -582,8 +582,8 @@ end
_indices = rand(eachindex(sol), 20)
_L = _L[:, _indices]
_mL = mean.(eachrow(_L))
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, use_extrema=false)
@inferred node_densities(sol; indices=_indices, use_extrema=false)
q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, stat=mean)
@inferred node_densities(sol; indices=_indices, stat=mean)
for j in eachindex(knots)
a = mean(sol[k][j][begin] for k in _indices)
b = mean(sol[k][j][end] for k in _indices)
Expand Down

0 comments on commit e84dc6e

Please sign in to comment.