1
1
"""Interface to convert pyhf likelihoods to simplified likelihood framework"""
2
2
import copy
3
+ import logging
3
4
import warnings
4
- from typing import Callable , List , Optional , Text , Union , Literal
5
+ from contextlib import contextmanager
6
+ from typing import Callable , List , Literal , Optional , Text , Union
5
7
6
8
import numpy as np
7
9
import spey
@@ -18,6 +20,11 @@ def __dir__():
18
20
return []
19
21
20
22
23
+ # pylint: disable=W1203, R0903
24
+
25
+ log = logging .getLogger ("Spey" )
26
+
27
+
21
28
class ConversionError (Exception ):
22
29
"""Conversion error class"""
23
30
@@ -41,6 +48,22 @@ def func(vector: np.ndarray) -> float:
41
48
return func
42
49
43
50
51
@contextmanager
def _disable_logging(highest_level: int = logging.CRITICAL):
    """
    Temporarily disable logging at and below ``highest_level``.

    The disable threshold that was active on entry is restored on exit, even
    if the managed block raises. This should eventually move into Spey.

    Args:
        highest_level (``int``, default ``logging.CRITICAL``): highest level to be
          silenced while the context is active.
    """
    # ``logging.disable`` writes this manager attribute, so reading it gives the
    # threshold that has to be restored afterwards.
    previous_level = logging.root.manager.disable
    # Never loosen a stricter threshold that is already in place: asking to
    # silence WARNING while CRITICAL is already disabled must not re-enable
    # ERROR/CRITICAL records inside the context.
    logging.disable(max(previous_level, highest_level))
    try:
        yield
    finally:
        logging.disable(previous_level)
65
+
66
+
44
67
class Simplify (spey .ConverterBase ):
45
68
r"""
46
69
An interface to convert pyhf full statistical model prescription into simplified likelihood
@@ -175,9 +198,10 @@ def __call__(
175
198
}[fittype ]
176
199
177
200
interpreter = WorkspaceInterpreter (bkgonly_model )
201
+ bin_map = interpreter .bin_map
178
202
179
203
# configure signal patch map with respect to channel names
180
- signal_patch_map = interpreter .patch_to_map (signal_patch )
204
+ signal_patch_map , signal_modifiers_map = interpreter .patch_to_map (signal_patch )
181
205
182
206
# Prepare a JSON patch to separate control and validation regions
183
207
# These regions are generally marked as CR and VR
@@ -190,25 +214,26 @@ def __call__(
190
214
)
191
215
192
216
for channel in interpreter .get_channels (control_region_indices ):
193
- interpreter .inject_signal (
194
- channel ,
195
- [0.0 ] * len (signal_patch_map [channel ]["data" ]),
196
- signal_patch_map [channel ]["modifiers" ]
197
- if include_modifiers_in_control_model
198
- else None ,
199
- )
217
+ if channel in signal_patch_map and channel in signal_modifiers_map :
218
+ interpreter .inject_signal (
219
+ channel ,
220
+ [0.0 ] * bin_map [channel ],
221
+ signal_modifiers_map [channel ]
222
+ if include_modifiers_in_control_model
223
+ else None ,
224
+ )
200
225
201
226
pdf_wrapper = spey .get_backend ("pyhf" )
202
- control_model = pdf_wrapper (
203
- background_only_model = bkgonly_model , signal_patch = interpreter .make_patch ()
204
- )
227
+ with _disable_logging ():
228
+ control_model = pdf_wrapper (
229
+ background_only_model = bkgonly_model , signal_patch = interpreter .make_patch ()
230
+ )
205
231
206
232
# Extract the nuisance parameters that maximise the likelihood at mu=0
207
233
fit_opts = control_model .prepare_for_fit (expected = expected )
208
234
_ , fit_param = fit (
209
235
** fit_opts ,
210
236
initial_parameters = None ,
211
- bounds = None ,
212
237
fixed_poi_value = 0.0 ,
213
238
)
214
239
@@ -234,13 +259,33 @@ def __call__(
234
259
)
235
260
236
261
# Retrieve pyhf models and compare parameter maps
237
- stat_model_pyhf = statistical_model .backend .model ()[1 ]
262
+ if include_modifiers_in_control_model :
263
+ stat_model_pyhf = statistical_model .backend .model ()[1 ]
264
+ else :
265
+ # Remove the nuisance parameters from the signal patch
266
+ # Note that even if the signal yields are zero, nuisance parameters
267
+ # do contribute to the statistical model and some models may be highly
268
+ # sensitive to the shape and size of the nuisance parameters.
269
+ with _disable_logging ():
270
+ tmp_interpreter = copy .deepcopy (interpreter )
271
+ for channel , data in signal_patch_map .items ():
272
+ tmp_interpreter .inject_signal (channel = channel , data = data )
273
+ tmp_model = spey .get_backend ("pyhf" )(
274
+ background_only_model = bkgonly_model ,
275
+ signal_patch = tmp_interpreter .make_patch (),
276
+ )
277
+ stat_model_pyhf = tmp_model .backend .model ()[1 ]
278
+ del tmp_model , tmp_interpreter
238
279
control_model_pyhf = control_model .backend .model ()[1 ]
239
280
is_nuisance_map_different = (
240
281
stat_model_pyhf .config .par_map != control_model_pyhf .config .par_map
241
282
)
242
283
fit_opts = statistical_model .prepare_for_fit (expected = expected )
243
284
suggested_fixed = fit_opts ["model_configuration" ].suggested_fixed
285
+ log .debug (
286
+ "Number of parameters to be fitted during the scan: "
287
+ f"{ fit_opts ['model_configuration' ].npar - len (fit_param )} "
288
+ )
244
289
245
290
samples = []
246
291
warnings_list = []
@@ -290,7 +335,9 @@ def __call__(
290
335
_ , new_params = fit (
291
336
** current_fit_opts ,
292
337
initial_parameters = init_params .tolist (),
293
- bounds = None ,
338
+ bounds = current_fit_opts [
339
+ "model_configuration"
340
+ ].suggested_bounds ,
294
341
)
295
342
warnings_list += w
296
343
@@ -304,13 +351,16 @@ def __call__(
304
351
# Some of the samples can lead to problems while sampling from a poisson distribution.
305
352
# e.g. poisson requires positive lambda values to sample from. If sample leads to a negative
306
353
# lambda value continue sampling to avoid that point.
354
+ log .debug ("Problem with the sample generation" )
355
+ log .debug (
356
+ f"Nuisance parameters: { current_nui_params if new_params is None else new_params } "
357
+ )
307
358
continue
308
359
309
360
if len (warnings_list ) > 0 :
310
- warnings .warn (
311
- message = f"{ len (warnings_list )} warning(s) generated during sampling."
312
- " This might be due to edge cases in nuisance parameter sampling." ,
313
- category = RuntimeWarning ,
361
+ log .warning (
362
+ f"{ len (warnings_list )} warning(s) generated during sampling."
363
+ " This might be due to edge cases in nuisance parameter sampling."
314
364
)
315
365
316
366
samples = np .vstack (samples )
@@ -323,9 +373,19 @@ def __call__(
323
373
324
374
# NOTE: model spec might be modified within the pyhf workspace, thus
325
375
# yields need to be reordered properly before constructing the simplified likelihood
326
- signal_yields = []
376
+ signal_yields , missing_channels = [], []
327
377
for channel_name in stat_model_pyhf .config .channels :
328
- signal_yields += signal_patch_map [channel_name ]["data" ]
378
+ try :
379
+ signal_yields += signal_patch_map [channel_name ]
380
+ except KeyError :
381
+ missing_channels .append (channel_name )
382
+ signal_yields += [0.0 ] * bin_map [channel_name ]
383
+ if len (missing_channels ) > 0 :
384
+ log .warning (
385
+ "Following channels are not in the signal patch,"
386
+ f" will be set to zero: { ', ' .join (missing_channels )} "
387
+ )
388
+
329
389
# NOTE background yields are first moments in simplified framework not the yield values
330
390
# in the full statistical model!
331
391
background_yields = np .mean (samples , axis = 0 )
0 commit comments