From 1b4e50ac1502fdcffd5d7e6b179eca45586620de Mon Sep 17 00:00:00 2001 From: Artur Lobanov Date: Mon, 14 Oct 2024 16:48:50 +0200 Subject: [PATCH] Allow multi object rate plots --- menu_tools/rate_plots/plotter.py | 24 ++++++++++++++++++------ menu_tools/utils/config.py | 7 +++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/menu_tools/rate_plots/plotter.py b/menu_tools/rate_plots/plotter.py index 1d23d8c8..0189f360 100644 --- a/menu_tools/rate_plots/plotter.py +++ b/menu_tools/rate_plots/plotter.py @@ -60,11 +60,13 @@ def _style_plot(self, fig, ax0, ax1=None, legend_loc="upper right"): ax0.set_yscale("log") ax0.grid() ax0.tick_params(direction="in") + xlabel = rf"{self._online_offline} $p_T$ [GeV]" + if ax1: - ax1.set_xlabel(rf"{self._online_offline} $p_T$ [GeV]") + ax1.set_xlabel(xlabel) ax1.grid() else: - ax0.set_xlabel(rf"{self._online_offline} $p_T$ [GeV]") + ax0.set_xlabel(xlabel) fig.tight_layout() def _plot_single_version_rate_curves(self): @@ -85,6 +87,11 @@ def _plot_single_version_rate_curves(self): xvals = list(rate_values.keys()) yvals = list(rate_values.values()) label = f"{obj_instances[version].plot_label}" + xlabel = rf"{self._online_offline} $p_T$ [GeV]" + + obj_spec_split = obj_specifier.split(":") + if len(obj_spec_split) == 3: + label += f", {obj_spec_split[2]}" plot_dict[obj_specifier] = { "x_values": xvals, @@ -92,7 +99,7 @@ def _plot_single_version_rate_curves(self): "object": obj_instances[version].plot_label, "label": label, "version": version, - "xlabel": rf"{self._online_offline} $p_T$ [GeV]", + "xlabel": xlabel, } ax.plot( @@ -238,7 +245,7 @@ def _load_cached_arrays(self): return arr - def compute_rate(self, thresholds: np.ndarray) -> dict: + def compute_rate(self, thresholds: np.ndarray, nObj = 1) -> dict: """Computes rate at threholds after application of all object cuts. threshold: pt threshold for which to compute rate @@ -253,7 +260,9 @@ def compute_rate(self, thresholds: np.ndarray) -> dict: pt_field = "offline_pt" if self.apply_offline_conversion else "pt" if (max_pt_obj := self.arrays[obj_mask][pt_field]).ndim > 1: - max_pt_obj = ak.max(max_pt_obj, axis=1) + # max_pt_obj = ak.max(max_pt_obj, axis=1) + max_pt_obj = max_pt_obj[ak.argsort(max_pt_obj, axis=1, ascending=False)][:,nObj-1:] + max_pt_obj = ak.fill_none(ak.firsts(max_pt_obj), -1) cumsum = np.cumsum( np.histogram(max_pt_obj, bins=[-1] + list(thresholds) + [1e5])[0] @@ -307,7 +316,7 @@ def _compute_rates( apply_offline_conversion, ) - rate_data[version] = rate_computer.compute_rate(self.get_bins(plot_config)) + rate_data[version] = rate_computer.compute_rate(self.get_bins(plot_config), nObj = plot_config.nObjects) return rate_data @@ -328,6 +337,9 @@ def run(self, apply_offline_conversion: bool = False) -> None: plot_config = RatePlotConfig(cfg_plot, plot_name) rate_plot_data = {} + if plot_config.nObjects > 1: + print(f"## Warning! Making rates for {plot_config.nObjects} objects!") + # Iterate over test objects in plot for ( obj_specifier, diff --git a/menu_tools/utils/config.py b/menu_tools/utils/config.py index 05419827..ff6e31aa 100644 --- a/menu_tools/utils/config.py +++ b/menu_tools/utils/config.py @@ -33,6 +33,13 @@ def version(self) -> str: except KeyError: raise KeyError(f"No version configured for {self.plot_name}!") + @property + def nObjects(self) -> int: + if "nObjects" in self._cfg: + return int(self._cfg["nObjects"]) + else: + return 1 + @property def bin_width(self) -> float: return float(self._cfg["binning"]["step"])