Skip to content

Commit e60f7ac

Browse files
authored
Merge pull request #13 from SpeysideHEP/simplify-update
Missing channel handling in full llhd simplification
2 parents 33dae30 + 61fd5a8 commit e60f7ac

File tree

6 files changed

+130
-40
lines changed

6 files changed

+130
-40
lines changed

.zenodo.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
{
22
"description": "pyhf plug-in for spey package",
33
"license": "MIT",
4-
"title": "SpeysideHEP/spey-pyhf: v0.1.5",
5-
"version": "v0.1.5",
4+
"title": "SpeysideHEP/spey-pyhf: v0.1.6",
5+
"version": "v0.1.6",
66
"upload_type": "software",
77
"creators": [
88
{
@@ -29,7 +29,7 @@
2929
},
3030
{
3131
"scheme": "url",
32-
"identifier": "https://github.com/SpeysideHEP/spey-pyhf/tree/v0.1.5",
32+
"identifier": "https://github.com/SpeysideHEP/spey-pyhf/tree/v0.1.6",
3333
"relation": "isSupplementTo"
3434
},
3535
{

docs/releases/changelog-v0.1.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@
2424
* Improve undefined channel handling in the patchset
2525
([#12](https://github.com/SpeysideHEP/spey-pyhf/pull/12))
2626

27+
* Improve undefined channel handling in the patchset for full likelihood simplification.
28+
([#13](https://github.com/SpeysideHEP/spey-pyhf/pull/13))
29+
30+
* Add modifier check to signal injection.
31+
([#13](https://github.com/SpeysideHEP/spey-pyhf/pull/13))
32+
2733
## Bug fixes
2834

2935
* Bugfix in `simplify` module, where the signal injector was not initialised properly.

setup.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
url="https://github.com/SpeysideHEP/spey-pyhf",
3131
project_urls={
3232
"Bug Tracker": "https://github.com/SpeysideHEP/spey-pyhf/issues",
33+
"Documentation": "https://spey-pyhf.readthedocs.io",
34+
"Repository": "https://github.com/SpeysideHEP/spey-pyhf",
35+
"Homepage": "https://github.com/SpeysideHEP/spey-pyhf",
36+
"Download": f"https://github.com/SpeysideHEP/spey-pyhf/archive/refs/tags/v{version}.tar.gz",
3337
},
3438
download_url=f"https://github.com/SpeysideHEP/spey-pyhf/archive/refs/tags/v{version}.tar.gz",
3539
author="Jack Y. Araz",
@@ -50,8 +54,14 @@
5054
"Intended Audience :: Science/Research",
5155
"License :: OSI Approved :: MIT License",
5256
"Operating System :: OS Independent",
53-
"Programming Language :: Python :: 3",
5457
"Topic :: Scientific/Engineering :: Physics",
58+
"Programming Language :: Python",
59+
"Programming Language :: Python :: 3",
60+
"Programming Language :: Python :: 3.8",
61+
"Programming Language :: Python :: 3.9",
62+
"Programming Language :: Python :: 3.10",
63+
"Programming Language :: Python :: 3.11",
64+
"Programming Language :: Python :: 3.12",
5565
],
5666
extras_require={
5767
"dev": ["pytest>=7.1.2", "pytest-cov>=3.0.0", "twine>=3.7.1", "wheel>=0.37.1"],

src/spey_pyhf/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Version of the spey - pyhf plugin"""
22

3-
__version__ = "0.1.5"
3+
__version__ = "0.1.6"

src/spey_pyhf/helper_functions.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ def __dir__():
99
return __all__
1010

1111

12+
# pylint: disable=W1203, W1201, C0103
13+
1214
log = logging.getLogger("Spey")
1315

1416

@@ -168,6 +170,7 @@ def inject_signal(
168170
Args:
169171
channel (``Text``): channel name
170172
data (``List[float]``): signal yields
173+
modifiers (``List[Dict]``): uncertainties. If None, default modifiers will be added.
171174
172175
Raises:
173176
``ValueError``: If channel does not exist or number of yields does not match
@@ -184,9 +187,20 @@ def inject_signal(
184187
f"{self.bin_map[channel]} expected, {len(data)} received."
185188
)
186189

190+
default_modifiers = _default_modifiers(self.poi_name[0][1])
191+
if modifiers is not None:
192+
for mod in default_modifiers:
193+
if mod not in modifiers:
194+
log.warning(
195+
f"Modifier `{mod['name']}` with type `{mod['type']}` is missing"
196+
f" from the input. Adding `{mod['name']}`"
197+
)
198+
log.debug(f"Adding modifier: {mod}")
199+
modifiers.append(mod)
200+
187201
self._signal_dict[channel] = data
188202
self._signal_modifiers[channel] = (
189-
_default_modifiers(self.poi_name[0][1]) if modifiers is None else modifiers
203+
default_modifiers if modifiers is None else modifiers
190204
)
191205

192206
@property
@@ -237,7 +251,7 @@ def reset_signal(self) -> None:
237251

238252
def add_patch(self, signal_patch: List[Dict]) -> None:
239253
"""Inject signal patch"""
240-
self._signal_dict, self._to_remove = self.patch_to_map(
254+
self._signal_dict, self._signal_modifiers, self._to_remove = self.patch_to_map(
241255
signal_patch=signal_patch, return_remove_list=True
242256
)
243257

@@ -272,7 +286,10 @@ def remove_list(self) -> List[Text]:
272286

273287
def patch_to_map(
274288
self, signal_patch: List[Dict], return_remove_list: bool = False
275-
) -> Union[Tuple[Dict[Text, Dict], List[Text]], Dict[Text, Dict]]:
289+
) -> Union[
290+
Tuple[Dict[Text, Dict], Dict[Text, Dict], List[Text]],
291+
Tuple[Dict[Text, Dict], Dict[Text, Dict]],
292+
]:
276293
"""
277294
Convert JSONPatch into signal map
278295
@@ -288,23 +305,20 @@ def patch_to_map(
288305
.. versionadded:: 0.1.5
289306
290307
Returns:
291-
``Tuple[Dict[Text, Dict], List[Text]]`` or ``Dict[Text, Dict]``:
308+
``Tuple[Dict[Text, Dict], Dict[Text, Dict], List[Text]]`` or ``Tuple[Dict[Text, Dict], Dict[Text, Dict]]``:
292309
signal map (channel name to signal yields), modifier map (channel name to modifiers) and, if ``return_remove_list`` is set, the list of channels to be removed.
293310
"""
294-
signal_map = {}
295-
to_remove = []
311+
signal_map, modifier_map, to_remove = {}, {}, []
296312
for item in signal_patch:
297313
path = int(item["path"].split("/")[2])
298314
channel_name = self["channels"][path]["name"]
299315
if item["op"] == "add":
300-
signal_map[channel_name] = {
301-
"data": item["value"]["data"],
302-
"modifiers": item["value"].get(
303-
"modifiers", _default_modifiers(poi_name=self.poi_name[0][1])
304-
),
305-
}
316+
signal_map[channel_name] = item["value"]["data"]
317+
modifier_map[channel_name] = item["value"].get(
318+
"modifiers", _default_modifiers(poi_name=self.poi_name[0][1])
319+
)
306320
elif item["op"] == "remove":
307321
to_remove.append(channel_name)
308322
if return_remove_list:
309-
return signal_map, to_remove
310-
return signal_map
323+
return signal_map, modifier_map, to_remove
324+
return signal_map, modifier_map

src/spey_pyhf/simplify.py

Lines changed: 81 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Interface to convert pyhf likelihoods to simplified likelihood framework"""
22
import copy
3+
import logging
34
import warnings
4-
from typing import Callable, List, Optional, Text, Union, Literal
5+
from contextlib import contextmanager
6+
from typing import Callable, List, Literal, Optional, Text, Union
57

68
import numpy as np
79
import spey
@@ -18,6 +20,11 @@ def __dir__():
1820
return []
1921

2022

23+
# pylint: disable=W1203, R0903
24+
25+
log = logging.getLogger("Spey")
26+
27+
2128
class ConversionError(Exception):
2229
"""Conversion error class"""
2330

@@ -41,6 +48,22 @@ def func(vector: np.ndarray) -> float:
4148
return func
4249

4350

51+
@contextmanager
52+
def _disable_logging(highest_level: int = logging.CRITICAL):
53+
"""
54+
Temporarily disable logging. This implementation should move into Spey.
55+
56+
Args:
57+
highest_level (``int``, default ``logging.CRITICAL``): highest level to be set in logging
58+
"""
59+
previous_level = logging.root.manager.disable
60+
logging.disable(highest_level)
61+
try:
62+
yield
63+
finally:
64+
logging.disable(previous_level)
65+
66+
4467
class Simplify(spey.ConverterBase):
4568
r"""
4669
An interface to convert pyhf full statistical model prescription into simplified likelihood
@@ -175,9 +198,10 @@ def __call__(
175198
}[fittype]
176199

177200
interpreter = WorkspaceInterpreter(bkgonly_model)
201+
bin_map = interpreter.bin_map
178202

179203
# configure signal patch map with respect to channel names
180-
signal_patch_map = interpreter.patch_to_map(signal_patch)
204+
signal_patch_map, signal_modifiers_map = interpreter.patch_to_map(signal_patch)
181205

182206
# Prepare a JSON patch to separate control and validation regions
183207
# These regions are generally marked as CR and VR
@@ -190,25 +214,26 @@ def __call__(
190214
)
191215

192216
for channel in interpreter.get_channels(control_region_indices):
193-
interpreter.inject_signal(
194-
channel,
195-
[0.0] * len(signal_patch_map[channel]["data"]),
196-
signal_patch_map[channel]["modifiers"]
197-
if include_modifiers_in_control_model
198-
else None,
199-
)
217+
if channel in signal_patch_map and channel in signal_modifiers_map:
218+
interpreter.inject_signal(
219+
channel,
220+
[0.0] * bin_map[channel],
221+
signal_modifiers_map[channel]
222+
if include_modifiers_in_control_model
223+
else None,
224+
)
200225

201226
pdf_wrapper = spey.get_backend("pyhf")
202-
control_model = pdf_wrapper(
203-
background_only_model=bkgonly_model, signal_patch=interpreter.make_patch()
204-
)
227+
with _disable_logging():
228+
control_model = pdf_wrapper(
229+
background_only_model=bkgonly_model, signal_patch=interpreter.make_patch()
230+
)
205231

206232
# Extract the nuisance parameters that maximises the likelihood at mu=0
207233
fit_opts = control_model.prepare_for_fit(expected=expected)
208234
_, fit_param = fit(
209235
**fit_opts,
210236
initial_parameters=None,
211-
bounds=None,
212237
fixed_poi_value=0.0,
213238
)
214239

@@ -234,13 +259,33 @@ def __call__(
234259
)
235260

236261
# Retrieve pyhf models and compare parameter maps
237-
stat_model_pyhf = statistical_model.backend.model()[1]
262+
if include_modifiers_in_control_model:
263+
stat_model_pyhf = statistical_model.backend.model()[1]
264+
else:
265+
# Remove the nuisance parameters from the signal patch
266+
# Note that even if the signal yields are zero, nuisance parameters
267+
# do contribute to the statistical model and some models may be highly
268+
# sensitive to the shape and size of the nuisance parameters.
269+
with _disable_logging():
270+
tmp_interpreter = copy.deepcopy(interpreter)
271+
for channel, data in signal_patch_map.items():
272+
tmp_interpreter.inject_signal(channel=channel, data=data)
273+
tmp_model = spey.get_backend("pyhf")(
274+
background_only_model=bkgonly_model,
275+
signal_patch=tmp_interpreter.make_patch(),
276+
)
277+
stat_model_pyhf = tmp_model.backend.model()[1]
278+
del tmp_model, tmp_interpreter
238279
control_model_pyhf = control_model.backend.model()[1]
239280
is_nuisance_map_different = (
240281
stat_model_pyhf.config.par_map != control_model_pyhf.config.par_map
241282
)
242283
fit_opts = statistical_model.prepare_for_fit(expected=expected)
243284
suggested_fixed = fit_opts["model_configuration"].suggested_fixed
285+
log.debug(
286+
"Number of parameters to be fitted during the scan: "
287+
f"{fit_opts['model_configuration'].npar - len(fit_param)}"
288+
)
244289

245290
samples = []
246291
warnings_list = []
@@ -290,7 +335,9 @@ def __call__(
290335
_, new_params = fit(
291336
**current_fit_opts,
292337
initial_parameters=init_params.tolist(),
293-
bounds=None,
338+
bounds=current_fit_opts[
339+
"model_configuration"
340+
].suggested_bounds,
294341
)
295342
warnings_list += w
296343

@@ -304,13 +351,16 @@ def __call__(
304351
# Some of the samples can lead to problems while sampling from a Poisson distribution,
305352
# e.g. Poisson requires positive lambda values to sample from. If a sample leads to a negative
306353
# lambda value, continue sampling to avoid that point.
354+
log.debug("Problem with the sample generation")
355+
log.debug(
356+
f"Nuisance parameters: {current_nui_params if new_params is None else new_params}"
357+
)
307358
continue
308359

309360
if len(warnings_list) > 0:
310-
warnings.warn(
311-
message=f"{len(warnings_list)} warning(s) generated during sampling."
312-
" This might be due to edge cases in nuisance parameter sampling.",
313-
category=RuntimeWarning,
361+
log.warning(
362+
f"{len(warnings_list)} warning(s) generated during sampling."
363+
" This might be due to edge cases in nuisance parameter sampling."
314364
)
315365

316366
samples = np.vstack(samples)
@@ -323,9 +373,19 @@ def __call__(
323373

324374
# NOTE: model spec might be modified within the pyhf workspace, thus
325375
# yields need to be reordered properly before constructing the simplified likelihood
326-
signal_yields = []
376+
signal_yields, missing_channels = [], []
327377
for channel_name in stat_model_pyhf.config.channels:
328-
signal_yields += signal_patch_map[channel_name]["data"]
378+
try:
379+
signal_yields += signal_patch_map[channel_name]
380+
except KeyError:
381+
missing_channels.append(channel_name)
382+
signal_yields += [0.0] * bin_map[channel_name]
383+
if len(missing_channels) > 0:
384+
log.warning(
385+
"Following channels are not in the signal patch,"
386+
f" will be set to zero: {', '.join(missing_channels)}"
387+
)
388+
329389
# NOTE background yields are first moments in simplified framework not the yield values
330390
# in the full statistical model!
331391
background_yields = np.mean(samples, axis=0)

0 commit comments

Comments
 (0)