Skip to content

Commit

Permalink
add get_basis_names for massunivariate models (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
behinger authored Jun 24, 2024
1 parent 875b04f commit ce2b36f
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/condense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,28 @@ function get_basis_colnames(rhs::AbstractTerm)
return colnames(rhs.basisfunction)
end

"""
get_basisnames(model::UnfoldModel)
Return the basisnames for all predictor terms as a vector.
The returned vector contains the name of the event type/basis, repeated by their actual coefficient number (after StatsModels.apply_schema / timeexpansion).
If a model has more than one event type (e.g. stimulus and fixation), the vectors are concatenated.
"""
@traitfn function get_basis_names(m::T) where {T <: UnfoldModel; !ContinuousTimeTrait{T}}

# Extract the event names from the design
design_keys = first.((Unfold.design(m)))

# Create a list of the basis names corresponding to each model term
basisnames = String[]
for (ix, event) in enumerate(design_keys)
push!(basisnames, repeat([event], size(modelmatrix(m)[ix], 2))...)
end
return basisnames
end


@traitfn get_basis_names(m::T) where {T <: UnfoldModel; ContinuousTimeTrait{T}} =
get_basis_names.(formulas(m))
function get_basis_names(m::FormulaTerm)
Expand Down

0 comments on commit ce2b36f

Please sign in to comment.