From a28a36fc538d22a0b20d048d8102fc88ac2a2d2c Mon Sep 17 00:00:00 2001 From: Aymeric Galan Date: Tue, 4 Jun 2024 00:33:34 +0200 Subject: [PATCH] Add plotting support for stored likelihood masks --- coolest/api/plotting.py | 20 ++++++++++++-- coolest/template/classes/likelihood.py | 21 ++++++++++----- coolest/template/classes/likelihood_list.py | 29 ++++++++++++++++++++- coolest/template/classes/observation.py | 4 +-- coolest/template/json.py | 9 +++++-- 5 files changed, 70 insertions(+), 13 deletions(-) diff --git a/coolest/api/plotting.py b/coolest/api/plotting.py index 1216034..d52619c 100644 --- a/coolest/api/plotting.py +++ b/coolest/api/plotting.py @@ -178,10 +178,11 @@ def plot_model_residuals(self, ax, mask=None, cmap = self.cmap_res if norm is None: norm = Normalize(-6, 6) + ll_mask = self._get_likelihood_mask(mask) lens_model = ComposableLensModel(self.coolest, self._directory, kwargs_selection_source=kwargs_source, kwargs_selection_lens_mass=kwargs_lens_mass) - image, coordinates = lens_model.model_residuals(mask=mask, **model_image_kwargs) + image, coordinates = lens_model.model_residuals(mask=ll_mask, **model_image_kwargs) extent = coordinates.plt_extent ax, im = plut.plot_regular_grid(ax, image, extent=extent, cmap=cmap, @@ -193,7 +194,8 @@ def plot_model_residuals(self, ax, mask=None, if add_scalebar: plut.scale_bar(ax, scalebar_size, color='black', loc='lower right') if add_chi2_label is True: - num_constraints = np.size(image) if mask is None else np.sum(mask) + num_constraints = np.size(image) if ll_mask is None else np.sum(ll_mask) + print("num_constraints", num_constraints) red_chi2 = np.sum(image**2) / num_constraints ax.text(0.05, 0.05, r'$\chi^2_\nu$='+f'{red_chi2:.2f}', color='black', alpha=1, fontsize=chi2_fontsize, va='bottom', ha='left', transform=ax.transAxes, @@ -336,6 +338,20 @@ def plot_magnification_diff( plut.scale_bar(ax, scalebar_size, color='black', loc='lower right') return image + def _get_likelihood_mask(self, user_mask): + # TODO: + if self.coolest.likelihoods is None: + return None + try: + img_ll_idx = self.coolest.likelihoods.index('ImagingDataLikelihood') + except ValueError: + return None + img_ll = self.coolest.likelihoods[img_ll_idx] + mask = img_ll.get_mask_pixels(directory=self._directory) + if mask is None: # then we use the user-provided mask + mask = user_mask + return mask + class MultiModelPlotter(object): """Wrapper around a set of ModelPlotter instances to produce panels that diff --git a/coolest/template/classes/likelihood.py b/coolest/template/classes/likelihood.py index 129ef83..ce74d0a 100644 --- a/coolest/template/classes/likelihood.py +++ b/coolest/template/classes/likelihood.py @@ -1,5 +1,7 @@ __author__ = 'aymgal' +import numpy as np + from coolest.template.classes.base import APIBaseObject from coolest.template.classes.grid import PixelatedRegularGrid @@ -57,15 +59,22 @@ def __init__(self, mask: PixelatedRegularGrid = None, ) -> None: self.mask = mask + super().__init__() + + def get_mask_pixels(self, directory=None): + pixels = self.mask.get_pixels(directory=directory) + # checks that it contains only 0 and 1 + assert np.all(np.isin(pixels, [0, 1])), "Imaging likelihood is not binary." + return pixels def check_consistency_with_observation(self, observation): """Checks that the data image is consistent with instrument properties""" - width = abs(self.pixels.field_of_view_x[1] - self.pixels.field_of_view_x[0]) - height = abs(self.pixels.field_of_view_y[1] - self.pixels.field_of_view_y[0]) - num_pix_ra = int(width / observation.pixels.pixel_size) + width = abs(self.mask.field_of_view_x[1] - self.mask.field_of_view_x[0]) + height = abs(self.mask.field_of_view_y[1] - self.mask.field_of_view_y[0]) + num_pix_ra = round(width / observation.pixels.pixel_size) error_message_ra = f"Field-of-view along RA is inconsistent (data: {num_pix_ra}, likelihood mask: {self.mask.num_pix_x})." - assert self.pixels.num_pix_x == num_pix_ra, error_message_ra - num_pix_dec = int(height / observation.pixels.pixel_size) + assert self.mask.num_pix_x == num_pix_ra, error_message_ra + num_pix_dec = round(height / observation.pixels.pixel_size) error_message_dec = f"Field-of-view along Dec is inconsistent (data: {num_pix_dec}, likelihood mask: {self.mask.num_pix_y})." - assert self.pixels.num_pix_y == num_pix_dec, error_message_dec + assert self.mask.num_pix_y == num_pix_dec, error_message_dec # TODO: check pixel size value? diff --git a/coolest/template/classes/likelihood_list.py b/coolest/template/classes/likelihood_list.py index cb33c7f..04933e0 100644 --- a/coolest/template/classes/likelihood_list.py +++ b/coolest/template/classes/likelihood_list.py @@ -1,8 +1,9 @@ __author__ = 'aymgal' +from copy import deepcopy from typing import Tuple -from coolest.template.classes.likelihood import Likelihood +from coolest.template.classes.likelihood import Likelihood, SUPPORTED_CHOICES from coolest.template.classes.base import APIBaseObject class DataLikelihoodList(list, APIBaseObject): @@ -13,5 +14,31 @@ class DataLikelihoodList(list, APIBaseObject): """ def __init__(self, *likelihoods: Tuple[Likelihood]) -> None: + # check that there is no duplicate likelihood type + unique_choices = deepcopy(SUPPORTED_CHOICES) + for ll in likelihoods: + if ll.type in unique_choices: + unique_choices.remove(ll.type) + else: + raise ValueError(f"There cannot be more than one likelihood type '{ll.type}'.") list.__init__(self, likelihoods) APIBaseObject.__init__(self) + + def index(self, likelihood_type): + """Override the parent class to look for a likelihood from its string type. + + Parameters + ---------- + likelihood_type : string + Type of likelihood + """ + found = False + for i, ll in enumerate(self): + if ll.type == likelihood_type: + return i + raise ValueError(f"Likelihood '{likelihood_type}' is not in the likelihood list.") + + def check_consistency_with_observation(self, observation): + for ll in self: + ll.check_consistency_with_observation(observation) + \ No newline at end of file diff --git a/coolest/template/classes/observation.py b/coolest/template/classes/observation.py index 81ec64b..cd47bd5 100644 --- a/coolest/template/classes/observation.py +++ b/coolest/template/classes/observation.py @@ -53,10 +53,10 @@ def check_consistency_with_instrument(self, instrument): """Checks that the data image is consistent with instrument properties""" width = abs(self.pixels.field_of_view_x[1] - self.pixels.field_of_view_x[0]) height = abs(self.pixels.field_of_view_y[1] - self.pixels.field_of_view_y[0]) - num_pix_ra = int(width / instrument.pixel_size) + num_pix_ra = round(width / instrument.pixel_size) error_message_ra = f"Field-of-view along RA is inconsistent (data: {self.pixels.num_pix_x}, instrument: {num_pix_ra})." assert self.pixels.num_pix_x == num_pix_ra, error_message_ra - num_pix_dec = int(height / instrument.pixel_size) + num_pix_dec = round(height / instrument.pixel_size) error_message_dec = f"Field-of-view along Dec is inconsistent (data: {self.pixels.num_pix_y}, instrument: {num_pix_dec})." assert self.pixels.num_pix_y == num_pix_dec, error_message_dec # TODO: check pixel size value? diff --git a/coolest/template/json.py b/coolest/template/json.py index 72ac036..558f028 100644 --- a/coolest/template/json.py +++ b/coolest/template/json.py @@ -109,6 +109,8 @@ def load(self, skip_jsonpickle=False, verbose=True): jsonpickle_path = self.path + self._api_suffix + '.json' if os.path.exists(jsonpickle_path) and not skip_jsonpickle: instance = self.load_jsonpickle(jsonpickle_path) + # TODO: the following line is only for backward compatibility and will soon be removed + instance.likelihoods = None else: if verbose: print(f"Template file '{jsonpickle_path}' not found, now trying to read '{json_path}'.") @@ -235,8 +237,11 @@ class constructors called during instantiation of the COOLEST object. if obs_pix_size not in (0, None) and not isclose_bool: raise ValueError(f"Pixel size of observation ({obs_pix_size}) is inconsistent with " f"the instrument pixel size ({instru_pix_size})") - - # TODO: add extra checks + # INSTANCE METHODS + coolest.observation.check_consistency_with_instrument(coolest.instrument) + if coolest.likelihoods is not None: + coolest.likelihoods.check_consistency_with_observation(coolest.observation) + # TODO: further standardize these checks (using class methods?) def _setup_instrument(self, instru_in): psf_settings = instru_in.pop('psf')