Skip to content

Commit

Permalink
Merge pull request #63 from fjebaker/fergus/parameter-api
Browse files Browse the repository at this point in the history
Feat!: new parameter API
  • Loading branch information
fjebaker authored Oct 1, 2023
2 parents 41ddea3 + 35d1d7f commit daec873
Show file tree
Hide file tree
Showing 43 changed files with 614 additions and 720 deletions.
4 changes: 4 additions & 0 deletions src/SpectralFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ using DocStringExtensions
abstract type AbstractMission end
struct NoMission <: AbstractMission end

abstract type AbstractStatistic end

# unitful units
include("units.jl")
SpectralUnits.@reexport using .SpectralUnits

include("print-utilities.jl")

include("fitparam.jl")
include("param-cache.jl")
include("abstract-models.jl")

include("ccall-wrapper.jl")
Expand All @@ -66,6 +69,7 @@ include("model-data-io.jl")
include("fitting/result.jl")
include("fitting/cache.jl")
include("fitting/problem.jl")
include("fitting/binding.jl")
include("fitting/multi-cache.jl")
include("fitting/methods.jl")
include("fitting/statistics.jl")
Expand Down
107 changes: 59 additions & 48 deletions src/abstract-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ export AbstractSpectralModel,
invokemodel!,
objective_cache_count,
modelparameters,
freeparameters,
frozenparameters,
updateparameters!
updateparameters!,
make_parameter_cache

"""
abstract type AbstractSpectralModel{T,K}
Expand Down Expand Up @@ -57,7 +56,7 @@ abstract type AbstractSpectralModelKind end
Additive models are effectively the sources of photons, and are the principle building blocks
of composite models. Every additive model has a normalisation parameter which re-scales the
flux by a constant factor `K`.
output by a constant factor `K`.
!!! note
Defining custom additive models requires special care. See [Defining new models](@ref).
Expand All @@ -68,15 +67,15 @@ struct Additive <: AbstractSpectralModelKind end
Multiplicative()
Multiplicative models act on [`Additive`](@ref) models, by element-wise
multiplying the flux in each energy bin of the additive model by a different factor.
multiplying the output in each energy bin of the additive model by a different factor.
"""
struct Multiplicative <: AbstractSpectralModelKind end
"""
Convolutional <: AbstractSpectralModelKind
Convolutional()
Convolutional models act on the flux generated by [`Additive`](@ref) models, similar to
[`Multiplicative`](@ref) models, however may convolve kernels through the flux also.
Convolutional models act on the output generated by [`Additive`](@ref) models, similar to
[`Multiplicative`](@ref) models, however may convolve kernels through the output also.
"""
struct Convolutional <: AbstractSpectralModelKind end

Expand Down Expand Up @@ -125,10 +124,10 @@ ConstructionBase.constructorof(::Type{M}) where {M<:AbstractSpectralModel} = M
# never to be called directly
# favour `invokemodel!` instead
"""
SpectralFitting.invoke!(flux, energy, M::Type{<:AbstractSpectralModel}, params...)
SpectralFitting.invoke!(output, energy, M::Type{<:AbstractSpectralModel}, params...)
Used to define the behaviour of models. Should calculate flux of the model and write in-place
into `flux`.
Used to define the behaviour of models. Should calculate output of the model and write in-place
into `output`.
!!! warning
This function should not be called directly. Use [`invokemodel`](@ref) instead.
Expand All @@ -144,15 +143,15 @@ end
```
would have the arguments passed to `invoke!` as
```julia
function SpectralFitting.invoke!(flux, energy, ::Type{<:MyModel}, p1, p2, p3, ...)
function SpectralFitting.invoke!(output, energy, ::Type{<:MyModel}, p1, p2, p3, ...)
# ...
end
```
The only exception to this are [`Additive`](@ref) models, where the normalisation parameter
`K` is not passed to `invoke!`.
"""
invoke!(flux, energy, M::AbstractSpectralModel) = error("Not defined for $(M).")
invoke!(output, energy, M::AbstractSpectralModel) = error("Not defined for $(M).")

"""
invokemodel(energy, model)
Expand All @@ -169,7 +168,7 @@ any normalisation or post-processing tasks that a specific model kind may requir
Users should always call models using [`invokemodel`](@ref) or [`invokemodel!`](@ref) to ensure
normalisations and closures are accounted for.
`invokemodel` allocates the needed flux arrays based on the element type of `free_params` to allow
`invokemodel` allocates the needed output arrays based on the element type of `free_params` to allow
automatic differentation libraries to calculate parameter gradients.
In-place non-allocating variants are the [`invokemodel!`](@ref) functions.
Expand All @@ -186,29 +185,23 @@ invokemodel(energy, model, p0)
```
"""
function invokemodel(e, m::AbstractSpectralModel)
flux = construct_objective_cache(m, e) |> vec
invokemodel!(flux, e, m)
flux
end
function invokemodel(e, m::AbstractSpectralModel, free_params)
model = remake_with_free(m, free_params)
flux = construct_objective_cache(eltype(free_params), m, e) |> vec
invokemodel!(flux, e, model)
flux
output = construct_objective_cache(m, e) |> vec
invokemodel!(output, e, m)
output
end

"""
invokemodel!(flux, energy, model)
invokemodel!(flux, energy, model, free_params)
invokemodel!(flux, energy, model, free_params, frozen_params)
invokemodel!(output, energy, model)
invokemodel!(output, energy, model, free_params)
invokemodel!(output, energy, model, free_params, frozen_params)
In-place variant of [`invokemodel`](@ref), calculating the flux of an [`AbstractSpectralModel`](@ref)
In-place variant of [`invokemodel`](@ref), calculating the output of an [`AbstractSpectralModel`](@ref)
given by `model`, optionally overriding the free and/or frozen parameter values. These arguments
may be a vector or tuple with element type [`FitParam`](@ref) or `Number`.
The number of fluxes to allocate for a model may change if using any [`CompositeModel`](@ref)
as the `model`. It is generally recommended to use [`objective_cache_count`](@ref) to ensure the correct number
of flux arrays are allocated with [`construct_objective_cache`](@ref) when using composite models.
of output arrays are allocated with [`construct_objective_cache`](@ref) when using composite models.
Single spectral model components should use [`make_flux`](@ref) instead.
Expand All @@ -217,23 +210,26 @@ Single spectral model components should use [`make_flux`](@ref) instead.
```julia
model = XS_PowerLaw()
energy = collect(range(0.1, 20.0, 100))
flux = make_flux(model, energy)
invokemodel!(flux, energy, model)
output = make_flux(model, energy)
invokemodel!(output, energy, model)
p0 = [0.1, 2.0] # change K and a
invokemodel!(flux, energy, model, p0)
invokemodel!(output, energy, model, p0)
```
"""
@inline function invokemodel!(f, e, m::AbstractSpectralModel, free_params)
# update only the free parameters
model = remake_with_free(m, free_params)
invokemodel!(view(f, :, 1), e, model)
end

@inline function invokemodel!(f, e, m::AbstractSpectralModel{<:FitParam})
# need to extract the parameter values
model = remake_with_number_type(m)
invokemodel!(view(f, :, 1), e, model)
end
@inline function invokemodel!(f, e, m::AbstractSpectralModel, cache::ParameterCache)
invokemodel!(f, e, m, cache.parameters)
end
@inline function invokemodel!(f, e, m::AbstractSpectralModel, parameters::AbstractArray)
invokemodel!(view(f, :, 1), e, remake_with_parameters(m, parameters))
end

invokemodel!(
f::AbstractVector,
e::AbstractVector,
Expand Down Expand Up @@ -284,32 +280,21 @@ end

modelparameters(model::AbstractSpectralModel{T}) where {T} =
T[model_parameters_tuple(model)...]
freeparameters(model::AbstractSpectralModel{T}) where {T} =
T[free_parameters_tuple(model)...]
frozenparameters(model::AbstractSpectralModel{T}) where {T} =
T[frozen_parameters_tuple(model)...]

# todo: this function could be cleaned up with some generated hackery
function remake_with_number_type(model::AbstractSpectralModel{P}, T::Type) where {P}
M = typeof(model).name.wrapper
params = modelparameters(model)
params = model_parameters_tuple(model)
new_params = if P <: FitParam
convert.(T, get_value.(params))
else
convert.(T, param)
end
M{T,FreeParameters{free_parameter_symbols(model)}}(new_params...)
M{T}(new_params...)
end
remake_with_number_type(model::AbstractSpectralModel{FitParam{T}}) where {T} =
remake_with_number_type(model, T)

remake_with_free(
model::AbstractSpectralModel{T},
free_params::AbstractVector{T},
) where {T<:Number} = updatefree(model, free_params)
remake_with_free(model::AbstractSpectralModel, free_params) =
updatefree(remake_with_number_type(model, eltype(free_params)), free_params)

"""
updatemodel(model::AbstractSpectralModel; kwargs...)
updatemodel(model::AbstractSpectralModel, patch::NamedTuple)
Expand Down Expand Up @@ -357,3 +342,29 @@ function updateparameters!(model::AbstractSpectralModel{<:FitParam}; params...)
end
model
end

_allocate_free_parameters(model::AbstractSpectralModel) =
filter(isfree, modelparameters(model))

function make_parameter_cache(model::AbstractSpectralModel)
parameters = modelparameters(model)
ParameterCache(parameters)
end

function make_diff_parameter_cache(
model::AbstractSpectralModel;
param_diff_cache_size = nothing,
)
parameters = modelparameters(model)
free_mask = _make_free_mask(parameters)

vals = map(get_value, parameters)
N = isnothing(param_diff_cache_size) ? length(vals) : param_diff_cache_size
diffcache = DiffCache(vals, ForwardDiff.pickchunksize(N))

# embed current parameter values inside of the dual cache
# else all frozens will be zero
get_tmp(diffcache, ForwardDiff.Dual(1.0)) .= vals

ParameterCache(free_mask, diffcache)
end
4 changes: 2 additions & 2 deletions src/ccall-wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ If the callsite is not specified, the user must implement [`_unsafe_ffi_invoke!`
# Examples
```julia
@xspecmodel :C_powerlaw struct XS_PowerLaw{T,F} <: AbstractSpectralModel{T, Additive}
@xspecmodel :C_powerlaw struct XS_PowerLaw{T} <: AbstractSpectralModel{T, Additive}
"Normalisation."
K::T
"Photon index."
Expand All @@ -130,7 +130,7 @@ end
# constructor has default values
function XS_PowerLaw(; K = FitParam(1.0), a = FitParam(1.0))
XS_PowerLaw{typeof(K), SpectralFitting.FreeParameters{(:K, :a)}}(K, a)
XS_PowerLaw{typeof(K)}(K, a)
end
```
Expand Down
61 changes: 33 additions & 28 deletions src/composite-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,37 +101,22 @@ function closurekind(::Type{<:CompositeModel{M1,M2}}) where {M1,M2}
end

# invocation wrappers
function invokemodel(e, m::CompositeModel)
fluxes = construct_objective_cache(m, e)
invokemodel!(fluxes, e, m)
view(fluxes, :, 1)
end

function invokemodel!(f, e, model::CompositeModel)
@assert size(f, 2) == objective_cache_count(model) "Too few flux arrays allocated for this model."
generated_model_call!(f, e, model, model_parameters_tuple(model))
end
function invokemodel!(f, e, model::CompositeModel, free_params, frozen_params)
@assert size(f, 2) == objective_cache_count(model) "Too few flux arrays allocated for this model."
generated_model_call!(f, e, model, free_params, frozen_params)
end
function invokemodel!(f, e, model::CompositeModel, free_params)

function invokemodel!(f, e, model::CompositeModel, parameters::AbstractArray)
@assert size(f, 2) == objective_cache_count(model) "Too few flux arrays allocated for this model."
frozen_params = convert.(eltype(free_params), frozenparameters(model))
invokemodel!(f, e, model, free_params, frozen_params)
generated_model_call!(f, e, model, parameters)
end

function invokemodel(e, m::CompositeModel)
fluxes = construct_objective_cache(m, e)
invokemodel!(fluxes, e, m)
view(fluxes, :, 1)
end
function invokemodel(e, m::CompositeModel, free_params)
if eltype(free_params) <: Number
# for compatability with AD
fluxes = construct_objective_cache(eltype(free_params), m, e)
invokemodel!(fluxes, e, m, free_params)
else
p0 = get_value.(free_params)
fluxes = construct_objective_cache(eltype(p0), m, e)
invokemodel!(fluxes, e, m, p0)
end
view(fluxes, :, 1)
end

# algebra grammar
add_models(_, _, ::M1, ::M2) where {M1,M2} =
Expand Down Expand Up @@ -159,7 +144,7 @@ conv_models(m1::M1, m2::M2) where {M1,M2} =

function Base.show(io::IO, @nospecialize(model::CompositeModel))
expr, infos = _destructure_for_printing(model)
for (symbol, (m, _, _)) in zip(keys(infos), infos)
for (symbol, (m, _)) in zip(keys(infos), infos)
expr =
replace(expr, "$(symbol)" => "$(FunctionGeneration.model_base_name(typeof(m)))")
end
Expand Down Expand Up @@ -221,11 +206,11 @@ function _printinfo(io::IO, model::CompositeModel{M1,M2}) where {M1,M2}

println(io, "Model key and parameters:")
sym_buffer = 5
param_name_offset = sym_buffer + maximum(infos) do (_, syms, _)
param_name_offset = sym_buffer + maximum(infos) do (_, syms)
maximum(length(s) for s in syms)
end
buff = IOBuffer()
for (symbol, (m, param_symbols, states)) in zip(keys(infos), infos)
for (symbol, (m, param_symbols)) in zip(keys(infos), infos)
M = typeof(m)
basename = FunctionGeneration.model_base_name(M)
println(
Expand All @@ -239,7 +224,8 @@ function _printinfo(io::IO, model::CompositeModel{M1,M2}) where {M1,M2}
Crayons.Crayon(reset = true),
)

for (val, s::String, free::Bool) in zip(modelparameters(m), param_symbols, states)
for (val, s::String) in zip(modelparameters(m), param_symbols)
free = !isfrozen(val)
_print_param(buff, free, s, val, param_name_offset, q1, q2, q3, q4)
end
end
Expand All @@ -255,5 +241,24 @@ ConstructionBase.setproperties(::CompositeModel, ::NamedTuple) =
ConstructionBase.constructorof(::Type{<:CompositeModel}) =
throw("Cannot be used with `CompositeModel`.")

function Base.propertynames(model::CompositeModel, private::Bool = false)
all_parameter_symbols(model)
end

function Base.getproperty(model::CompositeModel, symb::Symbol)
lookup = all_parameters_to_named_tuple(model)
lookup[symb]
end

function Base.setproperty!(model::CompositeModel, symb::Symbol, value::FitParam)
set!(getproperty(model, symb), value)
end

function Base.setproperty!(model::CompositeModel, symb::Symbol, x)
error(
"Only `FitParam` may be directly set with another `FitParam`. Use `set_value!` and related API otherwise.",
)
end

# function ConstructionBase.setproperties(m::CompositeModel, patch::NamedTuple)
# end
21 changes: 0 additions & 21 deletions src/datasets/response.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,6 @@ function normalise_rows!(matrix)
end
end

function build_response_matrix!(
R,
f_chan::Matrix,
n_chan::Matrix,
matrix_rows::Vector,
first_channel,
)
for (i, (F, N)) in enumerate(zip(eachcol(f_chan), eachcol(n_chan)))
M = matrix_rows
index = 1
for (first, len) in zip(F, N)
if len == 0
break
end
first -= first_channel
@views R[first+1:first+len, i] .= M[index:index+len-1]
index += len
end
end
end

function Base.show(
io::IO,
::MIME{Symbol("text/plain")},
Expand Down
Loading

0 comments on commit daec873

Please sign in to comment.