Skip to content

Commit

Permalink
Merge pull request #109 from rapoliveira/bugfix-improve-example16
Browse files Browse the repository at this point in the history
Bugfix improve example16
  • Loading branch information
rpoleski authored Nov 17, 2023
2 parents 828190b + 6ac184c commit 106f682
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 43 deletions.
15 changes: 8 additions & 7 deletions examples/example_16/ob03235_2_full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,14 @@ plots:
# makes legend in 2 columns
loc: lower center
second Y scale:
# This adds second Y axis on the right-hand side. Only first line below is required.
magnifications: [2, 3, 4, 5, 6, 7, 8, 9]
# If you don't know what will be the range of magnifications on your plot, then make
# a test run with some very small and very large values and a warning will tell you
# what is exact range on the plot.
labels: [a, b, c, d, e, f, g, h]
# If you remove the line above, then magnification values will be printed.
# This adds second Y axis to the right side. Only magnifications key is required.
magnifications: optimal
# magnifications: [2, 3, 4, 5, 6, 7, 8, 9]
# If you want to provide magnification values but don't know what will be the range
# of magnifications on your plot, then make a test with very small and large numbers
# and a warning will tell you the exact range.
# labels: [a, b, c, d, e, f, g, h]
# The list of labels above can not be given if magnifications = "optimal"
label: What is shown on this axis?
color: magenta
trajectory:
Expand Down
147 changes: 111 additions & 36 deletions examples/example_16/ulens_model_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import numpy as np
from scipy.interpolate import interp1d
from matplotlib import pyplot as plt
from matplotlib import gridspec, rc, rcParams, rcParamsDefault
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib import gridspec, rcParams, rcParamsDefault
# from matplotlib.backends.backend_pdf import PdfPages

import_failed = set()
try:
Expand All @@ -38,7 +38,7 @@
except Exception:
raise ImportError('\nYou have to install MulensModel first!\n')

__version__ = '0.33.0'
__version__ = '0.34.0'


class UlensModelFit(object):
Expand Down Expand Up @@ -334,6 +334,8 @@ def _check_MM_version(self):
"""
Check if MulensModel is new enough
"""
# code_version = "{:} and {:}".format(mm.__version__, __version__)
# print('\nMulensModel and script versions:', code_version, end='\n\n')
if int(mm.__version__.split('.')[0]) < 2:
raise RuntimeError(
"ulens_model_fit.py requires MulensModel in version "
Expand Down Expand Up @@ -717,17 +719,27 @@ def _check_plots_parameters_best_model_Y_scale(self):
'Unknown settings for "second Y scale" in '
'"best model": {:}'.format(unknown))
if not isinstance(settings['magnifications'], list):
raise TypeError(
'"best model" -> "second Y scale" -> "magnifications" has to '
'be a list, not ' + str(type(settings['magnifications'])))
for value in settings['magnifications']:
if not isinstance(value, (int, float)):
if settings['magnifications'] != 'optimal':
raise TypeError(
'Wrong value in magnifications: ' + str(value))
'"best model" -> "second Y scale" -> "magnifications" has '
'to be a list or "optimal", not ' +
str(type(settings['magnifications'])))
else:
for value in settings['magnifications']:
if not isinstance(value, (int, float)):
raise TypeError(
'Wrong value in magnifications: ' + str(value))
if 'labels' not in settings:
settings['labels'] = [
str(x) for x in settings['magnifications']]
if settings['magnifications'] != 'optimal':
settings['labels'] = [
str(x) for x in settings['magnifications']]
else:
settings['labels'] = []
else:
if settings['magnifications'] == 'optimal':
raise ValueError(
'In "best model" -> "second Y scale", labels can not be '
'provided if "magnifications" is defined as "optimal"')
if not isinstance(settings['labels'], list):
raise TypeError(
'"best model" -> "second Y scale" -> "labels" has to be '
Expand Down Expand Up @@ -1113,11 +1125,11 @@ def _parse_fitting_parameters_EMCEE(self):
'got: ' + name)
if path.exists(name):
if path.isfile(name):
msg = "Exisiting file " + name + " will be overwritten"
msg = "Existing file " + name + " will be overwritten"
warnings.warn(msg)
else:
raise ValueError("The path provided for posterior (" +
name + ") exsists and is a directory")
name + ") exists and is a directory")
self._posterior_file_name = name[:-4]
self._posterior_file_fluxes = None

Expand Down Expand Up @@ -1284,7 +1296,7 @@ def _check_output_files_MultiNest(self):
existing.append(file_name)

if len(existing) > 0:
message = "\n\n Exisiting files will be overwritten "
message = "\n\n Existing files will be overwritten "
message += "(unless you kill this process)!!!\n"
warnings.warn(message + str(existing) + "\n")

Expand Down Expand Up @@ -2309,6 +2321,11 @@ def _parse_results_EMCEE(self):
This version works with EMCEE version 2.X and 3.0.
"""
if self._yaml_results:
lst = [mm.__version__, __version__]
code_version = "MulensModel and script versions: {:}".format(lst)
print(code_version, **self._yaml_kwargs)

accept_rate = np.mean(self._sampler.acceptance_fraction)
out = "Mean acceptance fraction: {0:.3f}".format(accept_rate)
print(out)
Expand All @@ -2323,6 +2340,11 @@ def _parse_results_EMCEE(self):
print(out, **self._yaml_kwargs)
self._extract_posterior_samples_EMCEE()

if self._yaml_results and isinstance(self._fixed_parameters, dict):
print("Fixed parameters:", **self._yaml_kwargs)
for (key, value) in self._fixed_parameters.items():
print(" {:} : {:}".format(key, value), **self._yaml_kwargs)

print("Fitted parameters:")
self._print_results(self._samples_flat)
if self._yaml_results:
Expand All @@ -2343,6 +2365,10 @@ def _parse_results_EMCEE(self):
if self._yaml_results:
self._print_yaml_best_model()

if self._shift_t_0 and self._yaml_results:
print("Plots shift_t_0 : {:}".format(self._shift_t_0_val),
**self._yaml_kwargs)

def _extract_posterior_samples_EMCEE(self):
"""
set self._samples_flat and self._samples for EMCEE
Expand Down Expand Up @@ -2486,20 +2512,20 @@ def _shift_t_0_in_samples(self):
if name in self._fit_parameters:
index = self._fit_parameters.index(name)
values = self._samples_flat[:, index]
mean = np.mean(values)
self._shift_t_0_val = int(np.mean(values))
try:
self._samples_flat[:, index] -= int(mean)
self._samples_flat[:, index] -= self._shift_t_0_val
if 'trace' in self._plots:
self._samples[:, :, index] -= int(mean)
self._samples[:, :, index] -= self._shift_t_0_val
except TypeError:
fmt = ("Warning: extremely wide range of posterior {:}: "
"from {:} to {:}")
warnings.warn(
fmt.format(name, np.min(values), np.max(values)))
self._samples_flat[:, index] = values - int(mean)
self._samples_flat[:, index] = values - self._shift_t_0_val
if 'trace' in self._plots:
self._samples[:, :, index] = (
self._samples[:, :, index] - int(mean))
self._samples[:, :, index] - self._shift_t_0_val)

def _get_fluxes_to_print_EMCEE(self):
"""
Expand Down Expand Up @@ -2823,6 +2849,9 @@ def _save_figure(self, file_name, figure=None, dpi=None):
kwargs = dict()
if dpi is not None:
kwargs = {'dpi': dpi}
if path.isfile(file_name):
msg = "Existing file " + file_name + " will be overwritten"
warnings.warn(msg)
caller.savefig(file_name, **kwargs)
plt.close()

Expand Down Expand Up @@ -3071,18 +3100,74 @@ def _mark_second_Y_axis_in_best_plot(self):
magnifications = settings['magnifications']
color = settings.get("color", "red")
label = settings.get("label", "magnification")
labels = settings['labels']
labels = settings.get("labels")

ylim = plt.ylim()
ax2 = plt.gca().twinx()
(A_min, A_max, sb_fluxes) = self._second_Y_axis_get_fluxes(ylim)
out1, out2 = False, False
if magnifications == "optimal":
(magnifications, labels, out1) = self._second_Y_axis_optimal(
ax2, A_min, A_max)
flux = sb_fluxes[0] * magnifications + sb_fluxes[1]
out2 = self._second_Y_axis_warnings(flux, labels, magnifications,
A_min, A_max)
if out1 or out2:
ax2.get_yaxis().set_visible(False)
return

ticks = mm.Utils.get_mag_from_flux(flux)
ax2.set_ylabel(label).set_color(color)
ax2.spines['right'].set_color(color)
ax2.set_ylim(ylim[0], ylim[1])
ax2.tick_params(axis='y', colors=color)
plt.yticks(ticks, labels, color=color)

def _second_Y_axis_get_fluxes(self, ylim):
"""
Get fluxes and limiting magnification values for the second Y axis
"""
flux_min = mm.Utils.get_flux_from_mag(ylim[0])
flux_max = mm.Utils.get_flux_from_mag(ylim[1])

(source_flux, blend_flux) = self._event.get_ref_fluxes()

if self._model.n_sources == 1:
total_source_flux = source_flux
else:
total_source_flux = sum(source_flux)
flux = total_source_flux * magnifications + blend_flux
A_min = (flux_min - blend_flux) / total_source_flux
A_max = (flux_max - blend_flux) / total_source_flux

return (A_min, A_max, [total_source_flux, blend_flux])

def _second_Y_axis_optimal(self, ax2, A_min, A_max):
"""
Get optimal values of magnifications and labels
"""
ax2.set_ylim(A_min, A_max)
A_values = ax2.yaxis.get_ticklocs().round(7)
A_values = A_values[(A_values >= 1.) & (A_values < A_max)]
is_integer = [mag.is_integer() for mag in A_values]
if all(is_integer):
labels = [f"{int(x):d}" for x in A_values]
return (A_values.tolist(), labels, False)

fnum = np.array([str(x)[::-1].find(".") for x in A_values])
labels = np.array([f"%0.{max(fnum)}f" % x for x in A_values])
if max(fnum) >= 4 and len(fnum[fnum < 4]) < 3:
msg = ("The computed magnifications for the second Y scale cover"
" a range too small to be shown: {:}")
warnings.warn(msg.format(A_values))
return (A_values.tolist(), labels.tolist(), True)
if max(fnum) >= 4:
labels = np.array([f"{x:0.3f}" for x in A_values])

return (A_values[fnum < 4].tolist(), labels[fnum < 4].tolist(), False)

def _second_Y_axis_warnings(self, flux, labels, A_values, A_min, A_max):
"""
Issue warnings for negative flux or bad range of magnificaitons
"""
if np.any(flux < 0.):
mask = (flux > 0.)
flux = flux[mask]
Expand All @@ -3091,28 +3176,18 @@ def _mark_second_Y_axis_in_best_plot(self):
"because they correspond to negative flux which cannot "
"be translated to magnitudes.")
warnings.warn(msg.format(np.sum(np.logical_not(mask))))
A_min = (flux_min - blend_flux) / total_source_flux
A_max = (flux_max - blend_flux) / total_source_flux

if (np.min(magnifications) < A_min or np.max(magnifications) > A_max or
if (np.min(A_values) < A_min or np.max(A_values) > A_max or
np.any(flux < 0.)):
msg = ("Provided magnifications for the second (i.e., right-hand "
"side) Y-axis scale are from {:} to {:},\nbut the range "
"of plotted magnifications is from {:} to {:}, hence, "
"the second scale is not plotted")
args = [min(magnifications), max(magnifications),
A_min[0], A_max[0]]
args = [min(A_values), max(A_values), A_min[0], A_max[0]]
warnings.warn(msg.format(*args))
return
return True

ticks = mm.Utils.get_mag_from_flux(flux)

ax2 = plt.gca().twinx()
ax2.set_ylabel(label).set_color(color)
ax2.spines['right'].set_color(color)
ax2.set_ylim(ylim[0], ylim[1])
ax2.tick_params(axis='y', colors=color)
plt.yticks(ticks, labels, color=color)
return False

def _make_trajectory_plot(self):
"""
Expand Down

0 comments on commit 106f682

Please sign in to comment.