diff --git a/src/condense.jl b/src/condense.jl index 73939b86..4dd3afd2 100644 --- a/src/condense.jl +++ b/src/condense.jl @@ -175,6 +175,28 @@ function get_basis_colnames(rhs::AbstractTerm) return colnames(rhs.basisfunction) end +""" + get_basisnames(model::UnfoldModel) + +Return the basisnames for all predictor terms as a vector. + +The returned vector contains the name of the event type/basis, repeated by their actual coefficient number (after StatsModels.apply_schema / timeexpansion). +If a model has more than one event type (e.g. stimulus and fixation), the vectors are concatenated. +""" +@traitfn function get_basis_names(m::T) where {T <: UnfoldModel; !ContinuousTimeTrait{T}} + + # Extract the event names from the design + design_keys = first.((Unfold.design(m))) + + # Create a list of the basis names corresponding to each model term + basisnames = String[] + for (ix, event) in enumerate(design_keys) + push!(basisnames, repeat([event], size(modelmatrix(m)[ix], 2))...) + end + return basisnames +end + + @traitfn get_basis_names(m::T) where {T <: UnfoldModel; ContinuousTimeTrait{T}} = get_basis_names.(formulas(m)) function get_basis_names(m::FormulaTerm)