Skip to content

Commit 0b43085

Browse files
committed
Add option to set linestyles and expose subplot getdist settings in ParametersPlotter
1 parent 27f042d commit 0b43085

File tree

1 file changed

+24
-22
lines changed

1 file changed

+24
-22
lines changed

coolest/api/plotting.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ class ParametersPlotter(object):
444444

445445
def __init__(self, parameter_id_list, coolest_objects, coolest_directories=None, coolest_names=None,
446446
ref_coolest_objects=None, ref_coolest_directories=None, ref_coolest_names=None,
447-
posterior_bool_list=None, colors=None,
447+
posterior_bool_list=None, colors=None, linestyles=None,
448448
add_multivariate_margin_samples=False, num_samples_per_model_margin=5_000):
449449
self.parameter_id_list = parameter_id_list
450450
self.coolest_objects = coolest_objects
@@ -462,13 +462,14 @@ def __init__(self, parameter_id_list, coolest_objects, coolest_directories=None,
462462
if colors is None:
463463
colors = plt.cm.turbo(np.linspace(0.1, 0.9, self.num_models))
464464
self.colors = colors
465-
self.linestyles = ['--', ':', '-.', '-']
466-
self.markers = ['s', '^', 'o', '*']
465+
self.linestyles = linestyles
466+
self.ref_linestyles = ['--', ':', '-.', '-']
467+
self.ref_markers = ['s', '^', 'o', '*']
467468

468469
self._add_margin_samples = add_multivariate_margin_samples
469470
self._ns_per_model_margin = num_samples_per_model_margin
470471
self._color_margin = 'black'
471-
self._label_margin = "Combined samples"
472+
self._label_margin = "Combined"
472473

473474
# self.posterior_bool_list = posterior_bool_list
474475
# self.param_lens, self.param_source = util.split_lens_source_params(
@@ -579,7 +580,7 @@ def init_getdist(self, shift_sample_list=None, settings_mcsamples=None,
579580

580581
self._mcsamples = mcsamples
581582
self.ref_values = point_estimates
582-
self.ref_markers = [dict(zip(self.parameter_id_list, values)) for values in self.ref_values]
583+
self.ref_values_markers = [dict(zip(self.parameter_id_list, values)) for values in self.ref_values]
583584

584585
def get_mcsamples_getdist(self, with_margin=False):
585586
if not self._add_margin_samples or with_margin:
@@ -593,10 +594,11 @@ def get_margin_mcsamples_getdist(self):
593594
else:
594595
return self._mcsamples[-1]
595596

596-
def plot_triangle_getdist(self, subplot_size=1, filled_contours=True, angles_range=None,
597+
def plot_triangle_getdist(self, filled_contours=True, angles_range=None,
597598
linewidth_hist=2, linewidth_cont=2, linewidth_margin=4,
598599
marker_linewidth=2, marker_size=15,
599-
axes_labelsize=12, legend_fontsize=14):
600+
axes_labelsize=12, legend_fontsize=14,
601+
**subplot_kwargs):
600602
"""Corner array of subplots using getdist.triangle_plot method.
601603
602604
Parameters
@@ -631,7 +633,7 @@ def plot_triangle_getdist(self, subplot_size=1, filled_contours=True, angles_ran
631633
# alphas[-1] = 0.7
632634

633635
# Make the plot
634-
g = plots.get_subplot_plotter(subplot_size=subplot_size)
636+
g = plots.get_subplot_plotter(**subplot_kwargs)
635637
g.settings.legend_fontsize = legend_fontsize
636638
g.settings.axes_labelsize = axes_labelsize
637639
g.triangle_plot(
@@ -649,15 +651,15 @@ def plot_triangle_getdist(self, subplot_size=1, filled_contours=True, angles_ran
649651

650652
# Add marker lines and points
651653
for k in range(0, len(self.ref_values)):
652-
g.add_param_markers(self.ref_markers[k], color='black', ls=self.linestyles[k],
654+
g.add_param_markers(self.ref_values_markers[k], color='black', ls=self.ref_linestyles[k],
653655
lw=marker_linewidth)
654656
for i in range(0,self.num_params):
655657
val_x = self.ref_values[k][i]
656658
for j in range(i+1,self.num_params):
657659
val_y = self.ref_values[k][j]
658660
if val_x is not None and val_y is not None:
659661
g.subplots[j,i].scatter(val_x, val_y, s=marker_size, facecolors='black',
660-
color='black', marker=self.markers[k])
662+
color='black', marker=self.ref_markers[k])
661663

662664

663665
# Set default ranges for angles
@@ -684,7 +686,7 @@ def plot_triangle_getdist(self, subplot_size=1, filled_contours=True, angles_ran
684686

685687
def plot_rectangle_getdist(self, x_param_ids, y_param_ids, subplot_size=1,
686688
legend_ncol=None, filled_contours=True, linewidth=1,
687-
marker_size=15):
689+
marker_size=15, **subplot_kwargs):
688690
"""Array of (2D contours) subplots using getdist.rectangle_plot method.
689691
690692
Parameters
@@ -709,25 +711,25 @@ def plot_rectangle_getdist(self, x_param_ids, y_param_ids, subplot_size=1,
709711
if legend_ncol is None:
710712
legend_ncol = 3
711713
# Make the plot
712-
g = plots.get_subplot_plotter(subplot_size=subplot_size)
714+
g = plots.get_subplot_plotter(**subplot_kwargs)
713715
g.rectangle_plot(x_param_ids, y_param_ids, roots=self._mcsamples,
714716
legend_labels=legend_labels,
715717
filled=filled_contours,
716718
colors=colors,
717719
legend_ncol=legend_ncol,
718720
line_args=line_args,
719721
contour_colors=self.colors)
720-
for k in range(len(self.ref_markers)):
721-
g.add_param_markers(self.ref_markers[k], color='black', ls=self.linestyles[k], lw=linewidth)
722+
for k in range(len(self.ref_values)):
723+
g.add_param_markers(self.ref_values_markers[k], color='black', ls=self.ref_linestyles[k], lw=linewidth)
722724
for j, key_x in enumerate(x_param_ids):
723-
val_x = self.ref_markers[k][key_x]
725+
val_x = self.ref_values_markers[k][key_x]
724726
for i, key_y in enumerate(y_param_ids):
725-
val_y = self.ref_markers[k][key_y]
727+
val_y = self.ref_values_markers[k][key_y]
726728
if val_x is not None and val_y is not None:
727-
g.subplots[i, j].scatter(val_x,val_y,s=marker_size,facecolors='black',color='black',marker=self.markers[k])
729+
g.subplots[i, j].scatter(val_x,val_y,s=marker_size,facecolors='black',color='black',marker=self.ref_markers[k])
728730
return g
729731

730-
def plot_1d_getdist(self, subplot_size=1, num_columns=None, legend_ncol=None, linewidth=1):
732+
def plot_1d_getdist(self, num_columns=None, legend_ncol=None, linewidth=1, **subplot_kwargs):
731733
"""Array of 1D histogram subplots using getdist.plots_1d method.
732734
733735
Parameters
@@ -755,7 +757,7 @@ def plot_1d_getdist(self, subplot_size=1, num_columns=None, legend_ncol=None, li
755757
if legend_ncol is None:
756758
legend_ncol = 3
757759
# Make the plot
758-
g = plots.get_subplot_plotter(subplot_size=subplot_size)
760+
g = plots.get_subplot_plotter(**subplot_kwargs)
759761
g.plots_1d(self._mcsamples,
760762
params=self.parameter_id_list,
761763
legend_labels=legend_labels,
@@ -765,14 +767,14 @@ def plot_1d_getdist(self, subplot_size=1, num_columns=None, legend_ncol=None, li
765767
nx=num_columns, legend_ncol=legend_ncol,
766768
)
767769
for k in range(len(self.ref_values)):
768-
g.add_param_markers(self.ref_markers[k], color='black', ls=self.linestyles[k], lw=linewidth)
770+
g.add_param_markers(self.ref_values_markers[k], color='black', ls=self.ref_linestyles[k], lw=linewidth)
769771
# for k in range(0, len(self.ref_values)):
770772
# # Add vertical and horizontal lines
771773
# for i in range(0, self.num_params):
772774
# val = self.ref_values[k][i]
773775
# ax = g.subplots.flatten()[i]
774776
# if val is not None:
775-
# ax.axvline(val, color='black', ls=self.linestyles[k], alpha=1.0, lw=1)
777+
# ax.axvline(val, color='black', ls=self.ref_linestyles[k], alpha=1.0, lw=1)
776778
return g
777779

778780
def plot_source(self, idx_file=0):
@@ -860,7 +862,7 @@ def plotting_routine(self, param_dict, idx_file=0):
860862
def _prepare_getdist_plot(self, lw, lw_cont=None, lw_margin=None):
861863
if lw_margin is None:
862864
lw_margin = lw + 2
863-
line_args = [{'ls':'-', 'lw': lw, 'color': c} for c in self.colors]
865+
line_args = [{'ls': ls, 'lw': lw, 'color': c} for ls, c in zip(self.linestyles, self.colors)]
864866
lw_conts = [lw_cont]*self.num_models
865867
ls_conts = ['-']*self.num_models
866868
legend_labels = copy.deepcopy(self.coolest_names)

0 commit comments

Comments
 (0)