Skip to content

Commit

Permalink
Merge pull request #12 from SpeysideHEP/incomplete_patch
Browse files Browse the repository at this point in the history
Incomplete signal patch treatment
  • Loading branch information
jackaraz committed Jul 12, 2024
2 parents 49499d9 + 125d62e commit 33dae30
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 33 deletions.
6 changes: 3 additions & 3 deletions .zenodo.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"description": "pyhf plug-in for spey package",
"license": "MIT",
"title": "SpeysideHEP/spey-pyhf: v0.1.4",
"version": "v0.1.4",
"title": "SpeysideHEP/spey-pyhf: v0.1.5",
"version": "v0.1.5",
"upload_type": "software",
"creators": [
{
Expand All @@ -29,7 +29,7 @@
},
{
"scheme": "url",
"identifier": "https://github.com/SpeysideHEP/spey-pyhf/tree/v0.1.4",
"identifier": "https://github.com/SpeysideHEP/spey-pyhf/tree/v0.1.5",
"relation": "isSupplementTo"
},
{
Expand Down
3 changes: 3 additions & 0 deletions docs/releases/changelog-v0.1.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
* Model loading has been improved for prefit and postfit scenarios.
([#10](https://github.com/SpeysideHEP/spey-pyhf/pull/10))

* Improve undefined channel handling in the patchset
([#12](https://github.com/SpeysideHEP/spey-pyhf/pull/12))

## Bug fixes

* Bugfix in `simplify` module, where signal injector was not initiated properly.
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ we can inject signal to any channel we like
````{margin}
```{admonition} Attention!
:class: attention
Notice that the rest of the channels will be removed. If some of the channels are needed during the inference, simply remove the ones with `"op": "remove"` tag from the patch set. Patch set can be generated via `interpreter.make_patch()` function.
Notice that the rest of the channels will be added without any signal yields. If some of these channels need to be removed from the patch set, they can be added to the remove list via the ``remove_channel()`` function. **Note:** This behaviour has been updated in ``v0.1.5``. In the older versions, the channels that were not declared were removed.
```
````

```{code-cell} ipython3
interpreter.inject_signal('SRHMEM_mct2', [5.0, 12.0, 4.0])
```

Notice that I only added 3 inputs since the `"SRHMEM_mct2"` region has only 3 bins. One can inject signals to as many channels as one wants, but for simplicity, we will use only one channel. Now we are ready to export this signal patch and compute the exclusion limit
Notice that we only added 3 inputs since the `"SRHMEM_mct2"` region has only 3 bins. One can inject signals to as many channels as one wants, but for simplicity, we will use only one channel. Now we are ready to export this signal patch and compute the exclusion limit

```{code-cell} ipython3
pdf_wrapper = spey.get_backend("pyhf")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
with open("src/spey_pyhf/_version.py", mode="r", encoding="UTF-8") as f:
version = f.readlines()[-1].split()[-1].strip("\"'")

requirements = ["pyhf==0.7.6", "spey>=0.1.5"]
requirements = ["pyhf==0.7.6", "spey>=0.1.9"]

docs = [
"sphinx==6.2.1",
Expand Down
2 changes: 1 addition & 1 deletion src/spey_pyhf/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Version of the spey - pyhf plugin"""

__version__ = "0.1.4"
__version__ = "0.1.5"
46 changes: 33 additions & 13 deletions src/spey_pyhf/data.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import Optional, List, Tuple, Dict, Text, Union, Iterator

from dataclasses import dataclass
import copy
import json
import os
from abc import ABC, abstractmethod
import json, copy, os
import numpy as np
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Text, Tuple, Union

from spey.base import ModelConfig
import numpy as np
from spey import ExpectationType
from spey.base import ModelConfig
from spey.system.exceptions import InvalidInput

from . import manager, WorkspaceInterpreter
from . import WorkspaceInterpreter, manager


class Base(ABC):
Expand Down Expand Up @@ -289,13 +291,31 @@ def __call__(
)

if expected == ExpectationType.apriori:
data = sum(
(
self.expected_background_yields[ch]
try:
data = sum(
(
self.expected_background_yields[ch]
for ch in self._model.config.channels
),
[],
)
except KeyError as err:
# provide a useful error message to guide the user to the solution
missing_channels = [
ch
for ch in self._model.config.channels
),
[],
)
if ch not in self.expected_background_yields
]
raise InvalidInput(
"Unable to construct expected data. "
+ (len(missing_channels) > 0)
* (
"\nThis is likely due to missing channels in the signal patch. "
+ "The missing channels are: "
+ ", ".join(missing_channels)
+ "\nPlease provide appropriate action for the missing channels to continue."
)
) from err
if include_aux:
data += self._model.config.auxdata
else:
Expand Down
92 changes: 80 additions & 12 deletions src/spey_pyhf/helper_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Helper function for creating and interpreting pyhf inputs"""
from typing import Dict, Iterator, List, Text, Union, Optional
import logging
from typing import Dict, Iterator, List, Optional, Text, Tuple, Union

__all__ = ["WorkspaceInterpreter"]

Expand All @@ -8,15 +9,18 @@ def __dir__():
return __all__


def remove_from_json(idx: int) -> Dict:
log = logging.getLogger("Spey")


def remove_from_json(idx: int) -> Dict[Text, Text]:
"""
Remove channel from the json file
Args:
idx (``int``): index of the channel
Returns:
``Dict``:
``Dict[Text, Text]``:
JSON patch
"""
return {"op": "remove", "path": f"/channels/{idx}"}
Expand Down Expand Up @@ -59,13 +63,19 @@ class WorkspaceInterpreter:
background_only_model (``Dict``): descrioption for the background only statistical model
"""

__slots__ = ["background_only_model", "_signal_dict", "_signal_modifiers"]
__slots__ = [
"background_only_model",
"_signal_dict",
"_signal_modifiers",
"_to_remove",
]

def __init__(self, background_only_model: Dict):
self.background_only_model = background_only_model
"""Background only statistical model description"""
self._signal_dict = {}
self._signal_modifiers = {}
self._to_remove = []

def __getitem__(self, item):
return self.background_only_model[item]
Expand All @@ -89,15 +99,29 @@ def bin_map(self) -> Dict[Text, int]:
def expected_background_yields(self) -> Dict[Text, List[float]]:
"""Retreive expected background yields with respect to signal injection"""
yields = {}
undefined_channels = []
for channel in self["channels"]:
if channel["name"] in self._signal_dict:
if channel["name"] not in self.remove_list:
yields[channel["name"]] = []
for smp in channel["samples"]:
if len(yields[channel["name"]]) == 0:
yields[channel["name"]] = [0.0] * len(smp["data"])
yields[channel["name"]] = [
ch + dt for ch, dt in zip(yields[channel["name"]], smp["data"])
]
if channel["name"] not in self._signal_dict:
undefined_channels.append(channel["name"])
if len(undefined_channels) > 0:
log.warning(
"Some of the channels are not defined in the patch set, "
"these channels will be kept in the statistical model. "
)
log.warning(
"If these channels are meant to be removed, please indicate them in the patch set."
)
log.warning(
"Please check the following channel(s): " + ", ".join(undefined_channels)
)
return yields

def guess_channel_type(self, channel_name: Text) -> Text:
Expand Down Expand Up @@ -197,8 +221,10 @@ def make_patch(self) -> List[Dict]:
ich, self._signal_dict[channel], self._signal_modifiers[channel]
)
)
else:
elif channel in self._to_remove:
to_remove.append(remove_from_json(ich))
else:
log.warning(f"Undefined channel in the patch set: {channel}")

to_remove.sort(key=lambda p: p["path"].split("/")[-1], reverse=True)

Expand All @@ -207,12 +233,46 @@ def make_patch(self) -> List[Dict]:
def reset_signal(self) -> None:
"""Clear the signal map"""
self._signal_dict = {}
self._to_remove = []

def add_patch(self, signal_patch: List[Dict]) -> None:
"""Inject signal patch"""
self._signal_dict = self.patch_to_map(signal_patch=signal_patch)
self._signal_dict, self._to_remove = self.patch_to_map(
signal_patch=signal_patch, return_remove_list=True
)

def patch_to_map(self, signal_patch: List[Dict]) -> Dict[Text, Dict]:
def remove_channel(self, channel_name: Text) -> None:
"""
Remove channel from the likelihood
.. versionadded:: 0.1.5
Args:
channel_name (``Text``): name of the channel to be removed
"""
if channel_name in self.channels:
if channel_name not in self._to_remove:
self._to_remove.append(channel_name)
else:
log.error(
f"Channel {channel_name} does not exist in the background only model. "
+ "The available channels are "
+ ", ".join(list(self.channels))
)

@property
def remove_list(self) -> List[Text]:
"""
Channels to be removed from the model
.. versionadded:: 0.1.5
"""
return self._to_remove

def patch_to_map(
self, signal_patch: List[Dict], return_remove_list: bool = False
) -> Union[Tuple[Dict[Text, Dict], List[Text]], Dict[Text, Dict]]:
"""
Convert JSONPatch into signal map
Expand All @@ -223,20 +283,28 @@ def patch_to_map(self, signal_patch: List[Dict]) -> Dict[Text, Dict]:
Args:
signal_patch (``List[Dict]``): JSONPatch for the signal
return_remove_list (``bool``, default ``False``): Inclure channels to be removed in the output
.. versionadded:: 0.1.5
Returns:
``Dict[Text, Dict]``:
signal map including the data and modifiers
``Tuple[Dict[Text, Dict], List[Text]]`` or ``Dict[Text, Dict]``:
signal map including the data and modifiers and the list of channels to be removed.
"""
signal_map = {}
to_remove = []
for item in signal_patch:
path = int(item["path"].split("/")[2])
channel_name = self["channels"][path]["name"]
if item["op"] == "add":
path = int(item["path"].split("/")[2])
channel_name = self["channels"][path]["name"]
signal_map[channel_name] = {
"data": item["value"]["data"],
"modifiers": item["value"].get(
"modifiers", _default_modifiers(poi_name=self.poi_name[0][1])
),
}
elif item["op"] == "remove":
to_remove.append(channel_name)
if return_remove_list:
return signal_map, to_remove
return signal_map
2 changes: 1 addition & 1 deletion src/spey_pyhf/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class PyhfInterface(BackendBase):
"""Version of the backend"""
author: Text = "SpeysideHEP"
"""Author of the backend"""
spey_requires: Text = ">=0.1.5,<0.2.0"
spey_requires: Text = ">=0.1.9,<0.2.0"
"""Spey version required for the backend"""
doi: List[Text] = ["10.5281/zenodo.1169739", "10.21105/joss.02823"]
"""Citable DOI for the backend"""
Expand Down

0 comments on commit 33dae30

Please sign in to comment.