Skip to content

Commit

Permalink
Fix issue with no specific linestyles was passed to ParametersPlotter
Browse files Browse the repository at this point in the history
  • Loading branch information
aymgal committed Mar 15, 2024
1 parent 0b43085 commit 2f7d846
Showing 1 changed file with 31 additions and 14 deletions.
45 changes: 31 additions & 14 deletions coolest/api/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,8 @@ class ParametersPlotter(object):
List of bool to toggle errorbars on point-estimate values
colors : list, optional
List of pyplot color names to associate to each coolest model.
linestyles : list, optional
List of pyplot linesyles to associate to each coolest model.
add_multivariate_margin_samples : bool, optional
If True, will append to the list of compared models
a new chain that is resampled from the multi-variate normal distribution,
Expand Down Expand Up @@ -462,6 +464,8 @@ def __init__(self, parameter_id_list, coolest_objects, coolest_directories=None,
if colors is None:
colors = plt.cm.turbo(np.linspace(0.1, 0.9, self.num_models))
self.colors = colors
if linestyles is None:
linestyles = ['-']*self.num_models
self.linestyles = linestyles
self.ref_linestyles = ['--', ':', '-.', '-']
self.ref_markers = ['s', '^', 'o', '*']
Expand Down Expand Up @@ -597,7 +601,7 @@ def get_margin_mcsamples_getdist(self):
def plot_triangle_getdist(self, filled_contours=True, angles_range=None,
linewidth_hist=2, linewidth_cont=2, linewidth_margin=4,
marker_linewidth=2, marker_size=15,
axes_labelsize=12, legend_fontsize=14,
axes_labelsize=None, legend_fontsize=None,
**subplot_kwargs):
"""Corner array of subplots using getdist.triangle_plot method.
Expand Down Expand Up @@ -634,8 +638,10 @@ def plot_triangle_getdist(self, filled_contours=True, angles_range=None,

# Make the plot
g = plots.get_subplot_plotter(**subplot_kwargs)
g.settings.legend_fontsize = legend_fontsize
g.settings.axes_labelsize = axes_labelsize
if legend_fontsize is not None:
g.settings.legend_fontsize = legend_fontsize
if axes_labelsize is not None:
g.settings.axes_labelsize = axes_labelsize
g.triangle_plot(
self._mcsamples,
params=self.parameter_id_list,
Expand Down Expand Up @@ -685,8 +691,9 @@ def plot_triangle_getdist(self, filled_contours=True, angles_range=None,
return g

def plot_rectangle_getdist(self, x_param_ids, y_param_ids, subplot_size=1,
legend_ncol=None, filled_contours=True, linewidth=1,
marker_size=15, **subplot_kwargs):
legend_ncol=None, legend_fontsize=None,
filled_contours=True, linewidth=1,
marker_size=15, axes_labelsize=None, **subplot_kwargs):
"""Array of (2D contours) subplots using getdist.rectangle_plot method.
Parameters
Expand All @@ -711,14 +718,18 @@ def plot_rectangle_getdist(self, x_param_ids, y_param_ids, subplot_size=1,
if legend_ncol is None:
legend_ncol = 3
# Make the plot
g = plots.get_subplot_plotter(**subplot_kwargs)
g = plots.get_subplot_plotter(**subplot_kwargs)
if legend_fontsize is not None:
g.settings.legend_fontsize = legend_fontsize
if axes_labelsize is not None:
g.settings.axes_labelsize = axes_labelsize
g.rectangle_plot(x_param_ids, y_param_ids, roots=self._mcsamples,
legend_labels=legend_labels,
filled=filled_contours,
colors=colors,
legend_ncol=legend_ncol,
line_args=line_args,
contour_colors=self.colors)
filled=filled_contours,
colors=colors,
legend_ncol=legend_ncol,
legend_labels=legend_labels,
line_args=line_args,
contour_colors=self.colors)
for k in range(len(self.ref_values)):
g.add_param_markers(self.ref_values_markers[k], color='black', ls=self.ref_linestyles[k], lw=linewidth)
for j, key_x in enumerate(x_param_ids):
Expand All @@ -729,7 +740,9 @@ def plot_rectangle_getdist(self, x_param_ids, y_param_ids, subplot_size=1,
g.subplots[i, j].scatter(val_x,val_y,s=marker_size,facecolors='black',color='black',marker=self.ref_markers[k])
return g

def plot_1d_getdist(self, num_columns=None, legend_ncol=None, linewidth=1, **subplot_kwargs):
def plot_1d_getdist(self, num_columns=None, legend_ncol=None,
legend_fontsize=None, axes_labelsize=None,
linewidth=1, **subplot_kwargs):
"""Array of 1D histogram subplots using getdist.plots_1d method.
Parameters
Expand Down Expand Up @@ -757,7 +770,11 @@ def plot_1d_getdist(self, num_columns=None, legend_ncol=None, linewidth=1, **sub
if legend_ncol is None:
legend_ncol = 3
# Make the plot
g = plots.get_subplot_plotter(**subplot_kwargs)
g = plots.get_subplot_plotter(**subplot_kwargs)
if legend_fontsize is not None:
g.settings.legend_fontsize = legend_fontsize
if axes_labelsize is not None:
g.settings.axes_labelsize = axes_labelsize
g.plots_1d(self._mcsamples,
params=self.parameter_id_list,
legend_labels=legend_labels,
Expand Down

0 comments on commit 2f7d846

Please sign in to comment.