Skip to content

Commit

Permalink
Allow multi object rate plots
Browse files Browse the repository at this point in the history
  • Loading branch information
artlbv committed Oct 14, 2024
1 parent 44e97ee commit 1b4e50a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
24 changes: 18 additions & 6 deletions menu_tools/rate_plots/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,13 @@ def _style_plot(self, fig, ax0, ax1=None, legend_loc="upper right"):
ax0.set_yscale("log")
ax0.grid()
ax0.tick_params(direction="in")
xlabel = rf"{self._online_offline} $p_T$ [GeV]"

if ax1:
ax1.set_xlabel(rf"{self._online_offline} $p_T$ [GeV]")
ax1.set_xlabel(xlabel)
ax1.grid()
else:
ax0.set_xlabel(rf"{self._online_offline} $p_T$ [GeV]")
ax0.set_xlabel(xlabel)
fig.tight_layout()

def _plot_single_version_rate_curves(self):
Expand All @@ -85,14 +87,19 @@ def _plot_single_version_rate_curves(self):
xvals = list(rate_values.keys())
yvals = list(rate_values.values())
label = f"{obj_instances[version].plot_label}"
xlabel = rf"{self._online_offline} $p_T$ [GeV]"

obj_spec_split = obj_specifier.split(":")
if len(obj_spec_split) == 3:
label += f", {obj_spec_split[2]}"

plot_dict[obj_specifier] = {
"x_values": xvals,
"y_values": yvals,
"object": obj_instances[version].plot_label,
"label": label,
"version": version,
"xlabel": rf"{self._online_offline} $p_T$ [GeV]",
"xlabel": xlabel,
}

ax.plot(
Expand Down Expand Up @@ -238,7 +245,7 @@ def _load_cached_arrays(self):

return arr

def compute_rate(self, thresholds: np.ndarray) -> dict:
def compute_rate(self, thresholds: np.ndarray, nObj = 1) -> dict:
"""Computes rate at threholds after application of all object cuts.
threshold: pt threshold for which to compute rate
Expand All @@ -253,7 +260,9 @@ def compute_rate(self, thresholds: np.ndarray) -> dict:
pt_field = "offline_pt" if self.apply_offline_conversion else "pt"

if (max_pt_obj := self.arrays[obj_mask][pt_field]).ndim > 1:
max_pt_obj = ak.max(max_pt_obj, axis=1)
# max_pt_obj = ak.max(max_pt_obj, axis=1)
max_pt_obj = max_pt_obj[ak.argsort(max_pt_obj, axis=1, ascending=False)][:,nObj-1:]
max_pt_obj = ak.fill_none(ak.firsts(max_pt_obj), -1)

cumsum = np.cumsum(
np.histogram(max_pt_obj, bins=[-1] + list(thresholds) + [1e5])[0]
Expand Down Expand Up @@ -307,7 +316,7 @@ def _compute_rates(
apply_offline_conversion,
)

rate_data[version] = rate_computer.compute_rate(self.get_bins(plot_config))
rate_data[version] = rate_computer.compute_rate(self.get_bins(plot_config), nObj = plot_config.nObjects)

return rate_data

Expand All @@ -328,6 +337,9 @@ def run(self, apply_offline_conversion: bool = False) -> None:
plot_config = RatePlotConfig(cfg_plot, plot_name)
rate_plot_data = {}

if plot_config.nObjects > 1:
print(f"## Warning! Making rates for {plot_config.nObjects} objects!")

# Iterate over test objects in plot
for (
obj_specifier,
Expand Down
7 changes: 7 additions & 0 deletions menu_tools/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ def version(self) -> str:
except KeyError:
raise KeyError(f"No version configured for {self.plot_name}!")

@property
def nObjects(self) -> int:
if "nObjects" in self._cfg:
return int(self._cfg["nObjects"])
else:
return 1

@property
def bin_width(self) -> float:
return float(self._cfg["binning"]["step"])
Expand Down

0 comments on commit 1b4e50a

Please sign in to comment.