From 36fb7d60e78ced71ad41a264295e28ce17498044 Mon Sep 17 00:00:00 2001
From: "Jack Y. Araz" <jackaraz@gmail.com>
Date: Fri, 10 Jan 2025 17:08:32 -0500
Subject: [PATCH] update

---
 madanalysis/misc/statistical_models.py | 50 ++++++++++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/madanalysis/misc/statistical_models.py b/madanalysis/misc/statistical_models.py
index b2ba08fa..f2601bf1 100644
--- a/madanalysis/misc/statistical_models.py
+++ b/madanalysis/misc/statistical_models.py
@@ -1,6 +1,13 @@
 import spey
+import logging
 from .histfactory_reader import HF_Background, HF_Signal
 
+APRIORI = spey.ExpectationType.apriori
+APOSTERIORI = spey.ExpectationType.aposteriori
+OBSERVED = spey.ExpectationType.observed
+
+logger = logging.getLogger("MA5")
+
 
 def initialise_statistical_models(
     regiondata: dict,
@@ -84,3 +91,46 @@ def initialise_statistical_models(
         "simplified_likelihoods": simplified_likelihoods,
         "full_likelihoods": full_likelihoods,
     }
+
+
+def compute_poi_upper_limits(
+    regiondata: dict,
+    stat_models: dict,
+    xsection: float,
+    is_extrapolated: bool,
+    record_to: str = None,
+) -> dict:  # pylint: disable=too-many-arguments
+    """
+    Compute upper limit on cross section.
+
+    Args:
+        regiondata (``dict``): data for each region
+        regions (``list[str]``): list of regions
+        xsection (``float``): cross section
+        lumi (``float``): luminosity
+        is_extrapolated (``bool``): extrapolated luminosity
+        record_to (``str``): record to a specific section in regiondata
+
+    Returns:
+        ``dict``:
+        regiondata
+    """
+    logger.debug("Compute signal CL...")
+    if record_to is not None and record_to not in regiondata.keys():
+        regiondata[record_to] = {}
+    tags = (
+        [[APRIORI], ["exp"]]
+        if is_extrapolated
+        else [[APOSTERIORI, OBSERVED], ["exp", "obs"]]
+    )
+
+    for tag, label in zip(*tags):
+        for reg, stat_model in stat_models.items():
+            s95 = stat_model.poi_upper_limit(expected=tag) * xsection
+            if record_to is None:
+                logger.debug(f"region {reg} s95{label} = {s95:.5f} pb")
+                regiondata[reg]["s95" + label] = "%-20.7f" % s95
+            else:
+                logger.debug(f"{record_to}:: region {reg} s95{label} = {s95:.5f} pb")
+                regiondata[record_to][reg]["s95" + label] = "%-20.7f" % s95
+    return regiondata