@@ -444,7 +444,7 @@ class ParametersPlotter(object):
444
444
445
445
def __init__ (self , parameter_id_list , coolest_objects , coolest_directories = None , coolest_names = None ,
446
446
ref_coolest_objects = None , ref_coolest_directories = None , ref_coolest_names = None ,
447
- posterior_bool_list = None , colors = None ,
447
+ posterior_bool_list = None , colors = None , linestyles = None ,
448
448
add_multivariate_margin_samples = False , num_samples_per_model_margin = 5_000 ):
449
449
self .parameter_id_list = parameter_id_list
450
450
self .coolest_objects = coolest_objects
@@ -462,13 +462,14 @@ def __init__(self, parameter_id_list, coolest_objects, coolest_directories=None,
462
462
if colors is None :
463
463
colors = plt .cm .turbo (np .linspace (0.1 , 0.9 , self .num_models ))
464
464
self .colors = colors
465
- self .linestyles = ['--' , ':' , '-.' , '-' ]
466
- self .markers = ['s' , '^' , 'o' , '*' ]
465
+ self .linestyles = linestyles
466
+ self .ref_linestyles = ['--' , ':' , '-.' , '-' ]
467
+ self .ref_markers = ['s' , '^' , 'o' , '*' ]
467
468
468
469
self ._add_margin_samples = add_multivariate_margin_samples
469
470
self ._ns_per_model_margin = num_samples_per_model_margin
470
471
self ._color_margin = 'black'
471
- self ._label_margin = "Combined samples "
472
+ self ._label_margin = "Combined"
472
473
473
474
# self.posterior_bool_list = posterior_bool_list
474
475
# self.param_lens, self.param_source = util.split_lens_source_params(
@@ -579,7 +580,7 @@ def init_getdist(self, shift_sample_list=None, settings_mcsamples=None,
579
580
580
581
self ._mcsamples = mcsamples
581
582
self .ref_values = point_estimates
582
- self .ref_markers = [dict (zip (self .parameter_id_list , values )) for values in self .ref_values ]
583
+ self .ref_values_markers = [dict (zip (self .parameter_id_list , values )) for values in self .ref_values ]
583
584
584
585
def get_mcsamples_getdist (self , with_margin = False ):
585
586
if not self ._add_margin_samples or with_margin :
@@ -593,10 +594,11 @@ def get_margin_mcsamples_getdist(self):
593
594
else :
594
595
return self ._mcsamples [- 1 ]
595
596
596
- def plot_triangle_getdist (self , subplot_size = 1 , filled_contours = True , angles_range = None ,
597
+ def plot_triangle_getdist (self , filled_contours = True , angles_range = None ,
597
598
linewidth_hist = 2 , linewidth_cont = 2 , linewidth_margin = 4 ,
598
599
marker_linewidth = 2 , marker_size = 15 ,
599
- axes_labelsize = 12 , legend_fontsize = 14 ):
600
+ axes_labelsize = 12 , legend_fontsize = 14 ,
601
+ ** subplot_kwargs ):
600
602
"""Corner array of subplots using getdist.triangle_plot method.
601
603
602
604
Parameters
@@ -631,7 +633,7 @@ def plot_triangle_getdist(self, subplot_size=1, filled_contours=True, angles_ran
631
633
# alphas[-1] = 0.7
632
634
633
635
# Make the plot
634
- g = plots .get_subplot_plotter (subplot_size = subplot_size )
636
+ g = plots .get_subplot_plotter (** subplot_kwargs )
635
637
g .settings .legend_fontsize = legend_fontsize
636
638
g .settings .axes_labelsize = axes_labelsize
637
639
g .triangle_plot (
@@ -649,15 +651,15 @@ def plot_triangle_getdist(self, subplot_size=1, filled_contours=True, angles_ran
649
651
650
652
# Add marker lines and points
651
653
for k in range (0 , len (self .ref_values )):
652
- g .add_param_markers (self .ref_markers [k ], color = 'black' , ls = self .linestyles [k ],
654
+ g .add_param_markers (self .ref_values_markers [k ], color = 'black' , ls = self .ref_linestyles [k ],
653
655
lw = marker_linewidth )
654
656
for i in range (0 ,self .num_params ):
655
657
val_x = self .ref_values [k ][i ]
656
658
for j in range (i + 1 ,self .num_params ):
657
659
val_y = self .ref_values [k ][j ]
658
660
if val_x is not None and val_y is not None :
659
661
g .subplots [j ,i ].scatter (val_x , val_y , s = marker_size , facecolors = 'black' ,
660
- color = 'black' , marker = self .markers [k ])
662
+ color = 'black' , marker = self .ref_markers [k ])
661
663
662
664
663
665
# Set default ranges for angles
@@ -684,7 +686,7 @@ def plot_triangle_getdist(self, subplot_size=1, filled_contours=True, angles_ran
684
686
685
687
def plot_rectangle_getdist (self , x_param_ids , y_param_ids , subplot_size = 1 ,
686
688
legend_ncol = None , filled_contours = True , linewidth = 1 ,
687
- marker_size = 15 ):
689
+ marker_size = 15 , ** subplot_kwargs ):
688
690
"""Array of (2D contours) subplots using getdist.rectangle_plot method.
689
691
690
692
Parameters
@@ -709,25 +711,25 @@ def plot_rectangle_getdist(self, x_param_ids, y_param_ids, subplot_size=1,
709
711
if legend_ncol is None :
710
712
legend_ncol = 3
711
713
# Make the plot
712
- g = plots .get_subplot_plotter (subplot_size = subplot_size )
714
+ g = plots .get_subplot_plotter (** subplot_kwargs )
713
715
g .rectangle_plot (x_param_ids , y_param_ids , roots = self ._mcsamples ,
714
716
legend_labels = legend_labels ,
715
717
filled = filled_contours ,
716
718
colors = colors ,
717
719
legend_ncol = legend_ncol ,
718
720
line_args = line_args ,
719
721
contour_colors = self .colors )
720
- for k in range (len (self .ref_markers )):
721
- g .add_param_markers (self .ref_markers [k ], color = 'black' , ls = self .linestyles [k ], lw = linewidth )
722
+ for k in range (len (self .ref_values )):
723
+ g .add_param_markers (self .ref_values_markers [k ], color = 'black' , ls = self .ref_linestyles [k ], lw = linewidth )
722
724
for j , key_x in enumerate (x_param_ids ):
723
- val_x = self .ref_markers [k ][key_x ]
725
+ val_x = self .ref_values_markers [k ][key_x ]
724
726
for i , key_y in enumerate (y_param_ids ):
725
- val_y = self .ref_markers [k ][key_y ]
727
+ val_y = self .ref_values_markers [k ][key_y ]
726
728
if val_x is not None and val_y is not None :
727
- g .subplots [i , j ].scatter (val_x ,val_y ,s = marker_size ,facecolors = 'black' ,color = 'black' ,marker = self .markers [k ])
729
+ g .subplots [i , j ].scatter (val_x ,val_y ,s = marker_size ,facecolors = 'black' ,color = 'black' ,marker = self .ref_markers [k ])
728
730
return g
729
731
730
- def plot_1d_getdist (self , subplot_size = 1 , num_columns = None , legend_ncol = None , linewidth = 1 ):
732
+ def plot_1d_getdist (self , num_columns = None , legend_ncol = None , linewidth = 1 , ** subplot_kwargs ):
731
733
"""Array of 1D histogram subplots using getdist.plots_1d method.
732
734
733
735
Parameters
@@ -755,7 +757,7 @@ def plot_1d_getdist(self, subplot_size=1, num_columns=None, legend_ncol=None, li
755
757
if legend_ncol is None :
756
758
legend_ncol = 3
757
759
# Make the plot
758
- g = plots .get_subplot_plotter (subplot_size = subplot_size )
760
+ g = plots .get_subplot_plotter (** subplot_kwargs )
759
761
g .plots_1d (self ._mcsamples ,
760
762
params = self .parameter_id_list ,
761
763
legend_labels = legend_labels ,
@@ -765,14 +767,14 @@ def plot_1d_getdist(self, subplot_size=1, num_columns=None, legend_ncol=None, li
765
767
nx = num_columns , legend_ncol = legend_ncol ,
766
768
)
767
769
for k in range (len (self .ref_values )):
768
- g .add_param_markers (self .ref_markers [k ], color = 'black' , ls = self .linestyles [k ], lw = linewidth )
770
+ g .add_param_markers (self .ref_values_markers [k ], color = 'black' , ls = self .ref_linestyles [k ], lw = linewidth )
769
771
# for k in range(0, len(self.ref_values)):
770
772
# # Add vertical and horizontal lines
771
773
# for i in range(0, self.num_params):
772
774
# val = self.ref_values[k][i]
773
775
# ax = g.subplots.flatten()[i]
774
776
# if val is not None:
775
- # ax.axvline(val, color='black', ls=self.linestyles [k], alpha=1.0, lw=1)
777
+ # ax.axvline(val, color='black', ls=self.ref_linestyles [k], alpha=1.0, lw=1)
776
778
return g
777
779
778
780
def plot_source (self , idx_file = 0 ):
@@ -860,7 +862,7 @@ def plotting_routine(self, param_dict, idx_file=0):
860
862
def _prepare_getdist_plot (self , lw , lw_cont = None , lw_margin = None ):
861
863
if lw_margin is None :
862
864
lw_margin = lw + 2
863
- line_args = [{'ls' :'-' , 'lw' : lw , 'color' : c } for c in self .colors ]
865
+ line_args = [{'ls' : ls , 'lw' : lw , 'color' : c } for ls , c in zip ( self .linestyles , self . colors ) ]
864
866
lw_conts = [lw_cont ]* self .num_models
865
867
ls_conts = ['-' ]* self .num_models
866
868
legend_labels = copy .deepcopy (self .coolest_names )
0 commit comments