Skip to content

Commit

Permalink
Add plotting support for stored likelihood masks
Browse files Browse the repository at this point in the history
  • Loading branch information
aymgal committed Jun 3, 2024
1 parent 9a869f9 commit a28a36f
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 13 deletions.
20 changes: 18 additions & 2 deletions coolest/api/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,11 @@ def plot_model_residuals(self, ax, mask=None,
cmap = self.cmap_res
if norm is None:
norm = Normalize(-6, 6)
ll_mask = self._get_likelihood_mask(mask)
lens_model = ComposableLensModel(self.coolest, self._directory,
kwargs_selection_source=kwargs_source,
kwargs_selection_lens_mass=kwargs_lens_mass)
image, coordinates = lens_model.model_residuals(mask=mask, **model_image_kwargs)
image, coordinates = lens_model.model_residuals(mask=ll_mask, **model_image_kwargs)
extent = coordinates.plt_extent
ax, im = plut.plot_regular_grid(ax, image, extent=extent,
cmap=cmap,
Expand All @@ -193,7 +194,8 @@ def plot_model_residuals(self, ax, mask=None,
if add_scalebar:
plut.scale_bar(ax, scalebar_size, color='black', loc='lower right')
if add_chi2_label is True:
num_constraints = np.size(image) if mask is None else np.sum(mask)
num_constraints = np.size(image) if ll_mask is None else np.sum(ll_mask)
print("num_constraints", num_constraints)
red_chi2 = np.sum(image**2) / num_constraints
ax.text(0.05, 0.05, r'$\chi^2_\nu$='+f'{red_chi2:.2f}', color='black', alpha=1,
fontsize=chi2_fontsize, va='bottom', ha='left', transform=ax.transAxes,
Expand Down Expand Up @@ -336,6 +338,20 @@ def plot_magnification_diff(
plut.scale_bar(ax, scalebar_size, color='black', loc='lower right')
return image

def _get_likelihood_mask(self, user_mask):
# TODO:
if self.coolest.likelihoods is None:
return None
try:
img_ll_idx = self.coolest.likelihoods.index('ImagingDataLikelihood')
except ValueError:
return None
img_ll = self.coolest.likelihoods[img_ll_idx]
mask = img_ll.get_mask_pixels(directory=self._directory)
if mask is None: # then we use the user-provided mask
mask = user_mask
return mask


class MultiModelPlotter(object):
"""Wrapper around a set of ModelPlotter instances to produce panels that
Expand Down
21 changes: 15 additions & 6 deletions coolest/template/classes/likelihood.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
__author__ = 'aymgal'

import numpy as np

from coolest.template.classes.base import APIBaseObject
from coolest.template.classes.grid import PixelatedRegularGrid

Expand Down Expand Up @@ -57,15 +59,22 @@ def __init__(self,
mask: PixelatedRegularGrid = None,
) -> None:
self.mask = mask
super().__init__()

def get_mask_pixels(self, directory=None):
pixels = self.mask.get_pixels(directory=directory)
# checks that it contains only 0 and 1
assert np.all(np.isin(pixels, [0, 1])), "Imaging likelihood is not binary."
return pixels

def check_consistency_with_observation(self, observation):
"""Checks that the data image is consistent with instrument properties"""
width = abs(self.pixels.field_of_view_x[1] - self.pixels.field_of_view_x[0])
height = abs(self.pixels.field_of_view_y[1] - self.pixels.field_of_view_y[0])
num_pix_ra = int(width / observation.pixels.pixel_size)
width = abs(self.mask.field_of_view_x[1] - self.mask.field_of_view_x[0])
height = abs(self.mask.field_of_view_y[1] - self.mask.field_of_view_y[0])
num_pix_ra = round(width / observation.pixels.pixel_size)
error_message_ra = f"Field-of-view along RA is inconsistent (data: {num_pix_ra}, likelihood mask: {self.mask.num_pix_x})."
assert self.pixels.num_pix_x == num_pix_ra, error_message_ra
num_pix_dec = int(height / observation.pixels.pixel_size)
assert self.mask.num_pix_x == num_pix_ra, error_message_ra
num_pix_dec = round(height / observation.pixels.pixel_size)
error_message_dec = f"Field-of-view along Dec is inconsistent (data: {num_pix_dec}, likelihood mask: {self.mask.num_pix_y})."
assert self.pixels.num_pix_y == num_pix_dec, error_message_dec
assert self.mask.num_pix_y == num_pix_dec, error_message_dec
# TODO: check pixel size value?
29 changes: 28 additions & 1 deletion coolest/template/classes/likelihood_list.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
__author__ = 'aymgal'

from copy import deepcopy
from typing import Tuple

from coolest.template.classes.likelihood import Likelihood
from coolest.template.classes.likelihood import Likelihood, SUPPORTED_CHOICES
from coolest.template.classes.base import APIBaseObject

class DataLikelihoodList(list, APIBaseObject):
Expand All @@ -13,5 +14,31 @@ class DataLikelihoodList(list, APIBaseObject):
"""

def __init__(self, *likelihoods: Tuple[Likelihood]) -> None:
# check that there is no duplicate likelihood type
unique_choices = deepcopy(SUPPORTED_CHOICES)
for ll in likelihoods:
if ll.type in unique_choices:
unique_choices.remove(ll.type)
else:
raise ValueError(f"There cannot be more than one likelihood type '{ll.type}'.")
list.__init__(self, likelihoods)
APIBaseObject.__init__(self)

def index(self, likelihood_type):
"""Override the parent class to look for a likelihood from its string type.
Parameters
----------
likelihood_type : string
Type of likelihood
"""
found = False
for i, ll in enumerate(self):
if ll.type == likelihood_type:
return i
raise ValueError(f"Likelihood '{likelihood_type}' is not in the likelihood list.")

def check_consistency_with_observation(self, observation):
for ll in self:
ll.check_consistency_with_observation(observation)

4 changes: 2 additions & 2 deletions coolest/template/classes/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def check_consistency_with_instrument(self, instrument):
"""Checks that the data image is consistent with instrument properties"""
width = abs(self.pixels.field_of_view_x[1] - self.pixels.field_of_view_x[0])
height = abs(self.pixels.field_of_view_y[1] - self.pixels.field_of_view_y[0])
num_pix_ra = int(width / instrument.pixel_size)
num_pix_ra = round(width / instrument.pixel_size)
error_message_ra = f"Field-of-view along RA is inconsistent (data: {self.pixels.num_pix_x}, instrument: {num_pix_ra})."
assert self.pixels.num_pix_x == num_pix_ra, error_message_ra
num_pix_dec = int(height / instrument.pixel_size)
num_pix_dec = round(height / instrument.pixel_size)
error_message_dec = f"Field-of-view along Dec is inconsistent (data: {self.pixels.num_pix_y}, instrument: {num_pix_dec})."
assert self.pixels.num_pix_y == num_pix_dec, error_message_dec
# TODO: check pixel size value?
Expand Down
9 changes: 7 additions & 2 deletions coolest/template/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def load(self, skip_jsonpickle=False, verbose=True):
jsonpickle_path = self.path + self._api_suffix + '.json'
if os.path.exists(jsonpickle_path) and not skip_jsonpickle:
instance = self.load_jsonpickle(jsonpickle_path)
# TODO: the following line is only for backward compatibility and will soon be removed
instance.likelihoods = None
else:
if verbose:
print(f"Template file '{jsonpickle_path}' not found, now trying to read '{json_path}'.")
Expand Down Expand Up @@ -235,8 +237,11 @@ class constructors called during instantiation of the COOLEST object.
if obs_pix_size not in (0, None) and not isclose_bool:
raise ValueError(f"Pixel size of observation ({obs_pix_size}) is inconsistent with "
f"the instrument pixel size ({instru_pix_size})")

# TODO: add extra checks
# INSTANCE METHODS
coolest.observation.check_consistency_with_instrument(coolest.instrument)
if coolest.likelihoods is not None:
coolest.likelihoods.check_consistency_with_observation(coolest.observation)
# TODO: further standardize these checks (using class methods?)

def _setup_instrument(self, instru_in):
psf_settings = instru_in.pop('psf')
Expand Down

0 comments on commit a28a36f

Please sign in to comment.