Skip to content

Commit

Permalink
fix: channel specific response selection
Browse files Browse the repository at this point in the history
No longer delete rows from the response matrix when dropping bad
bins, instead we just select the channels we want at the time of
fitting.

Formatted.
  • Loading branch information
fjebaker committed Oct 1, 2023
1 parent 0331da0 commit 088e835
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 40 deletions.
4 changes: 3 additions & 1 deletion src/datasets/binning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ function augmented_energy_channels(channels, other_channels, bins_low, bins_high
for (i, c) in enumerate(channels)
index = findnext(==(c), other_channels, i)
if isnothing(index)
error("Failed to find channel in response corresponding to channel $c in spectrum.")
error(
"Failed to find channel in response corresponding to channel $c in spectrum.",
)
end
if index > lastindex(bins_low)
break
Expand Down
8 changes: 4 additions & 4 deletions src/datasets/ogip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ end
function read_filename(header, entry, parent, exts...)
data_directory = Base.dirname(parent)
parent_name = basename(parent)
if haskey(header, entry)
if haskey(header, entry)
path::String = strip(header[entry])
name = find_file(data_directory, path, parent_name, exts)
if !ismissing(name)
Expand All @@ -313,8 +313,8 @@ end
function find_file(dir, name, parent, extensions)
if length(name) == 0
return missing
elseif match(r"%match%", name) !== nothing
base = splitext(parent)[1]
elseif match(r"%match%", name) !== nothing
base = splitext(parent)[1]
for ext in extensions
testfile = joinpath(dir, base * ext)
if isfile(testfile)
Expand All @@ -340,4 +340,4 @@ function read_fits_header(path; hdu = 2)
end
end

export read_fits_header
export read_fits_header
2 changes: 1 addition & 1 deletion src/datasets/ogipdataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ function _printinfo(io, data::OGIPDataset{T}) where {T}
_printinfo(io, data.data)
end

export OGIPDataset
export OGIPDataset
18 changes: 15 additions & 3 deletions src/datasets/response.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,22 @@ function _printinfo(io, resp::AncillaryResponse{T}) where {T}
print(io, descr)
end

fold_ancillary(response::ResponseMatrix, ancillary::AncillaryResponse) =
ancillary.effective_area' .* response.matrix
function fold_ancillary(
channels::AbstractVector{<:Int},
response::ResponseMatrix,
ancillary::AncillaryResponse,
)
@views ancillary.effective_area' .* response.matrix[channels, :]
end

function fold_ancillary(
channels::AbstractVector{<:Int},
response::ResponseMatrix,
::Missing,
)
@views response.matrix[channels, :]
end

fold_ancillary(response::ResponseMatrix, ::Missing) = response.matrix

function Base.show(
io::IO,
Expand Down
70 changes: 41 additions & 29 deletions src/datasets/spectraldata.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ function Base.show(io::IO, ::MIME"text/plain", @nospecialize(paths::SpectralData
print(io, descr)
end

function SpectralDataPaths(; spectrum = missing, background = missing, response = missing, ancillary = missing)
SpectralDataPaths(
spectrum,
background,
response,
ancillary,
)
function SpectralDataPaths(;
spectrum = missing,
background = missing,
response = missing,
ancillary = missing,
)
SpectralDataPaths(spectrum, background, response, ancillary)
end

function SpectralDataPaths(spec_path)
Expand Down Expand Up @@ -64,7 +64,16 @@ function SpectralData(
domain = _make_domain_vector(spectrum, response)
energy_low, energy_high = _make_energy_vector(spectrum, response)
data_mask = BitVector(fill(true, size(spectrum.data)))
SpectralData(spectrum, response, background, ancillary, energy_low, energy_high, domain, data_mask)
SpectralData(
spectrum,
response,
background,
ancillary,
energy_low,
energy_high,
domain,
data_mask,
)
end

supports_contiguosly_binned(::Type{<:SpectralData}) = true
Expand Down Expand Up @@ -107,7 +116,10 @@ function objective_transformer(
layout::ContiguouslyBinned,
dataset::SpectralData{T},
) where {T}
R = fold_ancillary(dataset.response, dataset.ancillary)[dataset.data_mask, :]
R = fold_ancillary(dataset.spectrum.channels, dataset.response, dataset.ancillary)[
dataset.data_mask,
:,
]
ΔE = bin_widths(dataset)
E = response_energy(dataset.response)
cache = DiffCache(construct_objective_cache(layout, T, length(E), 1))
Expand All @@ -124,7 +136,8 @@ function objective_transformer(
_transformer!!
end

bin_widths(dataset::SpectralData) = (dataset.energy_high .- dataset.energy_low)[dataset.data_mask]
bin_widths(dataset::SpectralData) =
(dataset.energy_high.-dataset.energy_low)[dataset.data_mask]
has_background(dataset::SpectralData) = !ismissing(dataset.background)
has_ancillary(dataset::SpectralData) = !ismissing(dataset.ancillary)

Expand All @@ -140,7 +153,6 @@ end

function drop_channels!(dataset::SpectralData, indices)
drop_channels!(dataset.spectrum, indices)
drop_channels!(dataset.response, indices)
if has_background(dataset)
drop_channels!(dataset.background, indices)
end
Expand All @@ -151,7 +163,7 @@ function drop_channels!(dataset::SpectralData, indices)
end

spectrum_energy(dataset::SpectralData) =
((dataset.energy_low .+ dataset.energy_high) ./ 2)[dataset.data_mask]
((dataset.energy_low.+dataset.energy_high)./2)[dataset.data_mask]

function regroup!(dataset::SpectralData, grouping; safety_copy = false)
grp::typeof(grouping) = if safety_copy
Expand Down Expand Up @@ -298,47 +310,47 @@ macro _forward_SpectralData_api(args)
SpectralFitting.make_model_domain(
layout::SpectralFitting.AbstractDataLayout,
t::$(T),
) = SpectralFitting.make_model_domain(layout, getproperty(t, $(field)))
) = SpectralFitting.make_model_domain(layout, getfield(t, $(field)))
SpectralFitting.make_domain_variance(
layout::SpectralFitting.AbstractDataLayout,
t::$(T),
) = SpectralFitting.make_domain_variance(layout, getproperty(t, $(field)))
) = SpectralFitting.make_domain_variance(layout, getfield(t, $(field)))
SpectralFitting.make_objective(
layout::SpectralFitting.AbstractDataLayout,
t::$(T),
) = SpectralFitting.make_objective(layout, getproperty(t, $(field)))
) = SpectralFitting.make_objective(layout, getfield(t, $(field)))
SpectralFitting.make_objective_variance(
layout::SpectralFitting.AbstractDataLayout,
t::$(T),
) = SpectralFitting.make_objective_variance(layout, getproperty(t, $(field)))
) = SpectralFitting.make_objective_variance(layout, getfield(t, $(field)))
SpectralFitting.objective_transformer(
layout::SpectralFitting.AbstractDataLayout,
t::$(T),
) = SpectralFitting.objective_transformer(layout, getproperty(t, $(field)))
) = SpectralFitting.objective_transformer(layout, getfield(t, $(field)))
SpectralFitting.regroup!(t::$(T), args...) =
SpectralFitting.regroup!(getproperty(t, $(field)), args...)
SpectralFitting.regroup!(getfield(t, $(field)), args...)
SpectralFitting.restrict_domain!(t::$(T), args...) =
SpectralFitting.restrict_domain!(getproperty(t, $(field)), args...)
SpectralFitting.restrict_domain!(getfield(t, $(field)), args...)
SpectralFitting.mask_energies!(t::$(T), args...) =
SpectralFitting.mask_energies!(getproperty(t, $(field)), args...)
SpectralFitting.mask_energies!(getfield(t, $(field)), args...)
SpectralFitting.drop_channels!(t::$(T), args...) =
SpectralFitting.drop_channels!(getproperty(t, $(field)), args...)
SpectralFitting.drop_channels!(getfield(t, $(field)), args...)
SpectralFitting.drop_bad_channels!(t::$(T)) =
SpectralFitting.drop_bad_channels!(getproperty(t, $(field)))
SpectralFitting.drop_bad_channels!(getfield(t, $(field)))
SpectralFitting.drop_negative_channels!(t::$(T)) =
SpectralFitting.drop_negative_channels!(getproperty(t, $(field)))
SpectralFitting.drop_negative_channels!(getfield(t, $(field)))
SpectralFitting.normalize!(t::$(T)) =
SpectralFitting.normalize!(getproperty(t, $(field)))
SpectralFitting.normalize!(getfield(t, $(field)))
SpectralFitting.objective_units(t::$(T)) =
SpectralFitting.objective_units(getproperty(t, $(field)))
SpectralFitting.objective_units(getfield(t, $(field)))
SpectralFitting.spectrum_energy(t::$(T)) =
SpectralFitting.spectrum_energy(getproperty(t, $(field)))
SpectralFitting.spectrum_energy(getfield(t, $(field)))
SpectralFitting.bin_widths(t::$(T)) =
SpectralFitting.bin_widths(getproperty(t, $(field)))
SpectralFitting.bin_widths(getfield(t, $(field)))
SpectralFitting.subtract_background!(t::$(T), args...) =
SpectralFitting.subtract_background!(getproperty(t, $(field)), args...)
SpectralFitting.subtract_background!(getfield(t, $(field)), args...)
SpectralFitting.set_domain!(t::$(T), args...) =
SpectralFitting.set_domain!(getproperty(t, $(field)), args...)
SpectralFitting.set_domain!(getfield(t, $(field)), args...)
end |> esc
end

Expand Down
6 changes: 4 additions & 2 deletions src/plotting-recipes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ plotting_domain(dataset::InjectiveData) = dataset.domain
y = _f_objective(result.config)(result.config.domain, result.u)
x = plotting_domain(dataset)
if length(y) != length(x)
error("Domain mismatch. Are you sure you're plotting the result with the right dataset?")
error(
"Domain mismatch. Are you sure you're plotting the result with the right dataset?",
)
end
x, y
end
Expand Down Expand Up @@ -84,7 +86,7 @@ end

data = r.args[1]
x = plotting_domain(data)
result= r.args[2] isa FittingResult ? r.args[2][1] : r.args[2]
result = r.args[2] isa FittingResult ? r.args[2][1] : r.args[2]
y = invoke_result(result, result.u)

y_ratio = @. result.objective / y
Expand Down

0 comments on commit 088e835

Please sign in to comment.