Skip to content

Commit

Permalink
Better apprach using class and cached properites
Browse files Browse the repository at this point in the history
  • Loading branch information
samaloney committed Jan 30, 2025
1 parent fb1d3ba commit d1eccdb
Showing 1 changed file with 90 additions and 52 deletions.
142 changes: 90 additions & 52 deletions sunkit_spex/models/physical/thermal.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,36 @@ def setup_default_abundances(filename=None):
return load_xray_abundances()


# Read line, continuum and abundance data into global variables.
CONTINUUM_GRID = None
LINE_GRID = None
DEFAULT_ABUNDANCES = None
DEFAULT_ABUNDANCE_TYPE = "sun_coronal_ext"
class ThermalParameters:
_continuum_grid = None
_line_grid = None
_default_abundances = None
_default_abundance_type = "sun_coronal_ext"

@property
def continuum_grid(self):
if not self._continuum_grid:
self._continuum_grid = setup_continuum_parameters()
return self._continuum_grid

@property
def line_grid(self):
if not self._line_grid:
self._line_grid = setup_line_parameters()
return self._line_grid

@property
def default_abundances(self):
if not self._default_abundances:
self._default_abundances = setup_default_abundances()
return self._default_abundances

@property
def default_abundance_type(self):
return self._default_abundance_type


THERMAL_PARAMETERS = ThermalParameters()


@u.quantity_input(
Expand All @@ -204,7 +229,7 @@ def thermal_emission(
energy_edges,
temperature,
emission_measure,
abundance_type=DEFAULT_ABUNDANCE_TYPE,
abundance_type=THERMAL_PARAMETERS.default_abundance_type,
relative_abundances=None,
observer_distance=(1 * u.AU).to(u.cm),
):
Expand All @@ -218,19 +243,26 @@ def thermal_emission(
{doc_string_params}"""
# Convert inputs to known units and confirm they are within range.
energy_edges_keV, temperature_K = _sanitize_inputs(energy_edges, temperature)
global CONTINUUM_GRID, LINE_GRID
if CONTINUUM_GRID is None:
CONTINUUM_GRID = setup_continuum_parameters()
if LINE_GRID is None:
LINE_GRID = setup_line_parameters()
energy_range = (
min(CONTINUUM_GRID["energy range keV"][0], LINE_GRID["energy range keV"][0]),
max(CONTINUUM_GRID["energy range keV"][1], LINE_GRID["energy range keV"][1]),
min(
THERMAL_PARAMETERS.continuum_grid["energy range keV"][0],
THERMAL_PARAMETERS.line_grid["energy range keV"][0],
),
max(
THERMAL_PARAMETERS.continuum_grid["energy range keV"][1],
THERMAL_PARAMETERS.line_grid["energy range keV"][1],
),
)
_error_if_input_outside_valid_range(energy_edges_keV, energy_range, "energy", "keV")
temp_range = (
min(CONTINUUM_GRID["temperature range K"][0], LINE_GRID["temperature range K"][0]),
max(CONTINUUM_GRID["temperature range K"][1], LINE_GRID["temperature range K"][1]),
min(
THERMAL_PARAMETERS.continuum_grid["temperature range K"][0],
THERMAL_PARAMETERS.line_grid["temperature range K"][0],
),
max(
THERMAL_PARAMETERS.continuum_grid["temperature range K"][1],
THERMAL_PARAMETERS.line_grid["temperature range K"][1],
),
)
_error_if_input_outside_valid_range(temperature_K, temp_range, "temperature", "K")
# Calculate abundances
Expand All @@ -251,7 +283,7 @@ def continuum_emission(
energy_edges,
temperature,
emission_measure,
abundance_type=DEFAULT_ABUNDANCE_TYPE,
abundance_type=THERMAL_PARAMETERS.default_abundance_type,
relative_abundances=None,
observer_distance=(1 * u.AU).to(u.cm),
):
Expand All @@ -264,12 +296,13 @@ def continuum_emission(
{doc_string_params}"""
# Convert inputs to known units and confirm they are within range.
global CONTINUUM_GRID
if CONTINUUM_GRID is None:
CONTINUUM_GRID = setup_continuum_parameters()
energy_edges_keV, temperature_K = _sanitize_inputs(energy_edges, temperature)
_error_if_input_outside_valid_range(energy_edges_keV, CONTINUUM_GRID["energy range keV"], "energy", "keV")
_error_if_input_outside_valid_range(temperature_K, CONTINUUM_GRID["temperature range K"], "temperature", "K")
_error_if_input_outside_valid_range(
energy_edges_keV, THERMAL_PARAMETERS.continuum_grid["energy range keV"], "energy", "keV"
)
_error_if_input_outside_valid_range(
temperature_K, THERMAL_PARAMETERS.continuum_grid["temperature range K"], "temperature", "K"
)
# Calculate abundances
abundances = _calculate_abundances(abundance_type, relative_abundances)
# Calculate flux.
Expand All @@ -287,7 +320,7 @@ def line_emission(
energy_edges,
temperature,
emission_measure,
abundance_type=DEFAULT_ABUNDANCE_TYPE,
abundance_type=THERMAL_PARAMETERS.default_abundance_type,
relative_abundances=None,
observer_distance=(1 * u.AU).to(u.cm),
):
Expand All @@ -297,11 +330,12 @@ def line_emission(
{docstring_params}"""
# Convert inputs to known units and confirm they are within range.
energy_edges_keV, temperature_K = _sanitize_inputs(energy_edges, temperature)
global LINE_GRID
if LINE_GRID is None:
LINE_GRID = setup_line_parameters()
_warn_if_input_outside_valid_range(energy_edges_keV, LINE_GRID["energy range keV"], "energy", "keV")
_error_if_input_outside_valid_range(temperature_K, LINE_GRID["temperature range K"], "temperature", "K")
_warn_if_input_outside_valid_range(
energy_edges_keV, THERMAL_PARAMETERS.line_grid["energy range keV"], "energy", "keV"
)
_error_if_input_outside_valid_range(
temperature_K, THERMAL_PARAMETERS.line_grid["temperature range K"], "temperature", "K"
)
# Calculate abundances
abundances = _calculate_abundances(abundance_type, relative_abundances)

Expand Down Expand Up @@ -341,7 +375,7 @@ def _continuum_emission(energy_edges_keV, temperature_K, abundances):

# Mask Unwanted Abundances
abundance_mask = np.zeros(len(abundances))
abundance_mask[CONTINUUM_GRID["abundance index"]] = 1.0
abundance_mask[THERMAL_PARAMETERS.continuum_grid["abundance index"]] = 1.0
abundances *= abundance_mask

# Calculate Continuum Intensity Summed Over All Elements
Expand All @@ -352,23 +386,24 @@ def _continuum_emission(energy_edges_keV, temperature_K, abundances):
# temperatures here. If only a few temperatures are input, do this step only
# when looping over input temperatures. This minimizes computation.
n_tband = 3
n_t_grid = len(CONTINUUM_GRID["log10T"])
n_t_grid = len(THERMAL_PARAMETERS.continuum_grid["log10T"])
n_temperature_K = len(temperature_K)
n_thresh = n_temperature_K * n_tband
if n_thresh >= n_t_grid:
intensity_per_em_at_source_allT = np.zeros(CONTINUUM_GRID["intensity"].shape[1:])
intensity_per_em_at_source_allT = np.zeros(THERMAL_PARAMETERS.continuum_grid["intensity"].shape[1:])
for i in range(0, n_t_grid):
intensity_per_em_at_source_allT[i] = np.matmul(
abundances[CONTINUUM_GRID["sorted abundance index"]], CONTINUUM_GRID["intensity"][:, i]
abundances[THERMAL_PARAMETERS.continuum_grid["sorted abundance index"]],
THERMAL_PARAMETERS.continuum_grid["intensity"][:, i],
)
# 2. Add dummy axes to energy and temperature grid arrays for later vectorized operations.
repeat_E_grid = CONTINUUM_GRID["E_keV"][np.newaxis, :]
repeat_T_grid = CONTINUUM_GRID["T_keV"][:, np.newaxis]
dE_grid_keV = CONTINUUM_GRID["energy bin widths keV"][np.newaxis, :]
repeat_E_grid = THERMAL_PARAMETERS.continuum_grid["E_keV"][np.newaxis, :]
repeat_T_grid = THERMAL_PARAMETERS.continuum_grid["T_keV"][:, np.newaxis]
dE_grid_keV = THERMAL_PARAMETERS.continuum_grid["energy bin widths keV"][np.newaxis, :]
# 3. Identify the indices of the temperature bins containing each input temperature and
# the bins above and below them. For each input temperature, these three bins will
# act as a temperature band over which we'll interpolate the continuum emission.
selt = np.digitize(log10T_in, CONTINUUM_GRID["log10T"]) - 1
selt = np.digitize(log10T_in, THERMAL_PARAMETERS.continuum_grid["log10T"]) - 1
tband_idx = selt[:, np.newaxis] + np.arange(n_tband)[np.newaxis, :]

# Finally, loop over input temperatures and calculate continuum emission for each.
Expand All @@ -377,11 +412,12 @@ def _continuum_emission(energy_edges_keV, temperature_K, abundances):
# If not already done above, calculate continuum intensity summed over
# all elements as a function of energy/wavelength over the temperature band.
if n_thresh < n_t_grid:
element_intensities_per_em_at_source = CONTINUUM_GRID["intensity"][:, tband_idx[j]]
element_intensities_per_em_at_source = THERMAL_PARAMETERS.continuum_grid["intensity"][:, tband_idx[j]]
intensity_per_em_at_source = np.zeros(element_intensities_per_em_at_source.shape[1:])
for i in range(0, n_tband):
intensity_per_em_at_source[i] = np.matmul(
abundances[CONTINUUM_GRID["sorted abundance index"]], element_intensities_per_em_at_source[:, i]
abundances[THERMAL_PARAMETERS.continuum_grid["sorted abundance index"]],
element_intensities_per_em_at_source[:, i],
)
else:
intensity_per_em_at_source = intensity_per_em_at_source_allT[tband_idx[j]]
Expand All @@ -396,13 +432,17 @@ def _continuum_emission(energy_edges_keV, temperature_K, abundances):
# Interpolate the normalized temperature component of the intensity grid the the
# input temperature.
flux[j] = _interpolate_continuum_intensities(
gaunt, CONTINUUM_GRID["log10T"][tband_idx[j]], CONTINUUM_GRID["E_keV"], energy_gmean_keV, logt
gaunt,
THERMAL_PARAMETERS.continuum_grid["log10T"][tband_idx[j]],
THERMAL_PARAMETERS.continuum_grid["E_keV"],
energy_gmean_keV,
logt,
)
# Rescale the interpolated intensity.
flux = flux * np.exp(-(energy_gmean_keV[np.newaxis, :] / T_in_keV[:, np.newaxis]))

# Put intensity into correct units.
return flux * CONTINUUM_GRID["intensity unit"]
return flux * THERMAL_PARAMETERS.continuum_grid["intensity unit"]


def _line_emission(energy_edges_keV, temperature_K, abundances):
Expand All @@ -428,27 +468,28 @@ def _line_emission(energy_edges_keV, temperature_K, abundances):

# Find indices of lines within user input energy range.
energy_roi_indices = np.logical_and(
LINE_GRID["line peaks keV"] >= energy_edges_keV.min(), LINE_GRID["line peaks keV"] <= energy_edges_keV.max()
THERMAL_PARAMETERS.line_grid["line peaks keV"] >= energy_edges_keV.min(),
THERMAL_PARAMETERS.line_grid["line peaks keV"] <= energy_edges_keV.max(),
)
n_energy_roi_indices = energy_roi_indices.sum()
# If there are emission lines within the energy range of interest, compile spectrum.
if n_energy_roi_indices > 0:
# Mask Unwanted Abundances
abundance_mask = np.zeros(len(abundances))
abundance_mask[LINE_GRID["abundance index"]] = 1.0
abundance_mask[THERMAL_PARAMETERS.line_grid["abundance index"]] = 1.0
abundances *= abundance_mask
# Extract only the lines within the energy range of interest.
line_abundances = abundances[LINE_GRID["line atomic numbers"][energy_roi_indices] - 2]
line_abundances = abundances[THERMAL_PARAMETERS.line_grid["line atomic numbers"][energy_roi_indices] - 2]
# Above magic number of of -2 is comprised of:
# a -1 to account for the fact that index is atomic number -1, and
# another -1 because abundance index is offset from abundance index by 1.

##### Calculate Line Intensities within the Input Energy Range #####
# Calculate abundance-normalized intensity of each line in energy range of
# interest as a function of energy and temperature.
line_intensity_grid = LINE_GRID["intensity"][energy_roi_indices]
line_intensity_grid = THERMAL_PARAMETERS.line_grid["intensity"][energy_roi_indices]
line_intensities = _calculate_abundance_normalized_line_intensities(
np.log10(temperature_K), line_intensity_grid, LINE_GRID["log10T"]
np.log10(temperature_K), line_intensity_grid, THERMAL_PARAMETERS.line_grid["log10T"]
)
# Scale line intensities by abundances to get true line intensities.
line_intensities *= line_abundances
Expand All @@ -459,7 +500,7 @@ def _line_emission(energy_edges_keV, temperature_K, abundances):
# when averaged over neighboring bins.
# This has the effect of appearing to double the number of lines as regards
# the dimensionality of the line_intensities array.
line_peaks_keV = LINE_GRID["line peaks keV"][energy_roi_indices]
line_peaks_keV = THERMAL_PARAMETERS.line_grid["line peaks keV"][energy_roi_indices]
split_line_intensities, line_spectrum_bins = _weight_emission_bins_to_line_centroid(
line_peaks_keV, energy_edges_keV, line_intensities
)
Expand All @@ -477,7 +518,7 @@ def _line_emission(energy_edges_keV, temperature_K, abundances):
# Scale flux by observer distance, emission measure and spectral bin width
# and put into correct units.
energy_bin_widths = (energy_edges_keV[1:] - energy_edges_keV[:-1]) * u.keV
return flux * LINE_GRID["intensity unit"] / energy_bin_widths
return flux * THERMAL_PARAMETERS.line_grid["intensity unit"] / energy_bin_widths


def _interpolate_continuum_intensities(data_grid, log10T_grid, energy_grid_keV, energy_keV, log10T):
Expand Down Expand Up @@ -764,18 +805,15 @@ def _warn_if_input_outside_valid_range(input_values, grid_range, param_name, par


def _calculate_abundances(abundance_type, relative_abundances):
global DEFAULT_ABUNDANCES
if DEFAULT_ABUNDANCES is None:
DEFAULT_ABUNDANCES = setup_default_abundances()
abundances = DEFAULT_ABUNDANCES[abundance_type].data
abundances = THERMAL_PARAMETERS.default_abundances[abundance_type].data
if relative_abundances:
# Convert input relative abundances to array where
# first axis is atomic number, i.e == index + 1
# Second axis is relative abundance value.
rel_abund_array = np.array(relative_abundances).T
# Confirm relative abundances are for valid elements and positive.
min_abundance_z = DEFAULT_ABUNDANCES["atomic number"].min()
max_abundance_z = DEFAULT_ABUNDANCES["atomic number"].max()
min_abundance_z = THERMAL_PARAMETERS.default_abundances["atomic number"].min()
max_abundance_z = THERMAL_PARAMETERS.default_abundances["atomic number"].max()
if rel_abund_array[0].min() < min_abundance_z or rel_abund_array[0].max() > max_abundance_z:
raise ValueError(
"Relative abundances can only be set for elements with "
Expand Down

0 comments on commit d1eccdb

Please sign in to comment.