diff --git a/Project.toml b/Project.toml index a4bc2e9..871ffeb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "EpithelialDynamics1D" uuid = "ace8a2d7-7779-48a6-a8a4-cf6831a7e55b" authors = ["Daniel VandenHeuvel "] -version = "1.6.0" +version = "1.7.0" [deps] CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" diff --git a/src/statistics.jl b/src/statistics.jl index dc043ad..f6fd3b8 100644 --- a/src/statistics.jl +++ b/src/statistics.jl @@ -79,14 +79,19 @@ function node_densities(cell_positions::AbstractVector{T}; smooth_boundary=true) end """ - get_knots(sol, num_knots = 500; indices = eachindex(sol), use_max=true) + get_knots(sol, num_knots = 500; indices = eachindex(sol), stat=maximum) Computes knots for each time, covering the extremum of the cell positions across all cell simulations. You can restrict the simultaions to consider using the `indices`. -If `use_max` is true, then the knots will be obtained by taking the extreme node positions -for each `time`, otherwise the average is used. +The knots are obtained by applying `stat`, a tuple of functions for the left and right sides, +to the vector of extrema at each time. For example, if `stat=maximum` then, at each time, +the knots range between the smallest position observed and the maximum position +observed across each simulation at that time. """ -function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol), use_extrema=true) +function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol), stat=(minimum, maximum)) + if stat isa Function + stat = (stat, stat) + end @static if VERSION < v"1.7" knots = Vector{LinRange{Float64}}(undef, length(first(sol))) else @@ -95,31 +100,13 @@ function get_knots(sol::EnsembleSolution, num_knots=500; indices=eachindex(sol), times = first(sol).t Base.Threads.@threads for i in eachindex(times) local a, b - if use_extrema - a = Inf - b = -Inf - else - a = 0.0 - b = 0.0 - ctr = 0 - end - for j in indices - _a = sol[j][i][begin] - _b = sol[j][i][end] - if use_extrema - a = min(a, _a) - b = max(b, _b) - else - a += _a - b += _b - ctr += 1 - end - end - if !use_extrema - a /= ctr - b /= ctr + a = zeros(length(indices)) + b = zeros(length(indices)) + for (ℓ, j) in enumerate(indices) + a[ℓ] = sol[j].u[i][begin] + b[ℓ] = sol[j].u[i][end] end - knots[i] = LinRange(a, b, num_knots) + knots[i] = LinRange(stat[1](a), stat[2](b), num_knots) end return knots end @@ -131,7 +118,14 @@ function get_knots(sol::ODESolution, num_knots=500) end """ - node_densities(sol::EnsembleSolution; num_knots=500, knots=get_knots(sol, num_knots), alpha=0.05, interp_fnc=(u, t) -> LinearInterpolation{true}(u, t), smooth_boundary=true) + node_densities(sol::EnsembleSolution; + indices=eachindex(sol), + num_knots=500, + stat=(minimum, maximum), + knots=get_knots(sol, num_knots; indices, stat), + alpha=0.05, + interp_fnc=(u, t) -> LinearInterpolation{true}(u, t), + smooth_boundary=true) Computes summary statistics for the node densities from an `EnsembleSolution` to a [`CellProblem`](@ref). @@ -141,8 +135,8 @@ Computes summary statistics for the node densities from an `EnsembleSolution` to # Keyword Arguments - `indices = eachindex(sol)`: The indices of the cell simulations to consider. - `num_knots::Int = 500`: The number of knots to use for the spline interpolation. -- `use_extrema::Bool = true`: Whether to use the extrema of the cell positions for the knots, or the average. -- `knots::Vector{Vector{Float64}} = get_knots(sol, num_knots; indices, use_extrema)`: The knots to use for the spline interpolation. +- `stat = (minimum, maximum)`: How to summarise the knots for `get_knots`. +- `knots::Vector{Vector{Float64}} = get_knots(sol, num_knots; indices, stat)`: The knots to use for the spline interpolation. - `alpha::Float64 = 0.05`: The significance level for the confidence intervals. - `interp_fnc = (u, t) -> LinearInterpolation{true}(u, t)`: The function to use for constructing the interpolant. - `smooth_boundary::Bool = true`: Whether to use the smooth boundary node densities. @@ -158,8 +152,8 @@ Computes summary statistics for the node densities from an `EnsembleSolution` to function node_densities(sol::EnsembleSolution; indices=eachindex(sol), num_knots=500, - use_extrema=true, - knots=get_knots(sol, num_knots; indices, use_extrema), + stat=(minimum, maximum), + knots=get_knots(sol, num_knots; indices, stat), alpha=0.05, interp_fnc=(u, t) -> LinearInterpolation{true}(u, t), smooth_boundary=true) diff --git a/test/step_function.jl b/test/step_function.jl index c348dfa..1fdff7d 100644 --- a/test/step_function.jl +++ b/test/step_function.jl @@ -344,8 +344,8 @@ end # Using average leading edge _indices = rand(eachindex(sol), 40) - q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, use_extrema=false, smooth_boundary=false) - @inferred node_densities(sol; indices=_indices, use_extrema=false) + q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, stat=mean, smooth_boundary=false) + @inferred node_densities(sol; indices=_indices, stat=mean) @test all(≈(LinRange(0, 30, 500)), knots) for (enum_k, k) in enumerate(_indices) for j in rand(1:length(sol[k]), 40) @@ -543,13 +543,13 @@ end # Test the statistics with a specific interpolation function _indices = rand(eachindex(sol), 20) - q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, interp_fnc=CubicSpline) + q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, interp_fnc=CubicSpline, stat = minimum) @inferred node_densities(sol; indices=_indices, interp_fnc=CubicSpline) for j in eachindex(knots) a = Inf b = -Inf m = minimum(sol[k][j][begin] for k in _indices) - M = maximum(sol[k][j][end] for k in _indices) + M = minimum(sol[k][j][end] for k in _indices) @test knots[j] == LinRange(m, M, 500) end for (enum_k, k) in enumerate(_indices) @@ -582,8 +582,8 @@ end _indices = rand(eachindex(sol), 20) _L = _L[:, _indices] _mL = mean.(eachrow(_L)) - q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, use_extrema=false) - @inferred node_densities(sol; indices=_indices, use_extrema=false) + q, r, means, lowers, uppers, knots = node_densities(sol; indices=_indices, stat=mean) + @inferred node_densities(sol; indices=_indices, stat=mean) for j in eachindex(knots) a = mean(sol[k][j][begin] for k in _indices) b = mean(sol[k][j][end] for k in _indices)