29
29
from pisa .core .param import ParamSet , DerivedParam
30
30
from pisa .core .stage import Stage
31
31
from pisa .core .container import Container , ContainerSet
32
- from pisa .core .binning import MultiDimBinning , VarBinning
32
+ from pisa .core .binning import MultiDimBinning , OneDimBinning , VarBinning
33
33
from pisa .utils .config_parser import PISAConfigParser , parse_pipeline_config
34
34
from pisa .utils .fileio import mkdir
35
35
from pisa .utils .format import format_times
@@ -370,6 +370,84 @@ def get_outputs(self, **get_outputs_kwargs):
370
370
outputs = self ._get_outputs (** get_outputs_kwargs )
371
371
return outputs
372
372
373
+ def _get_outputs_multdimbinning (self , output_binning , output_key ):
374
+ """Logic that produces a single `MapSet` when the pipeline's
375
+ output binning is a regular `MultiDimBinning`.
376
+
377
+ Returns
378
+ -------
379
+ outputs : MapSet
380
+
381
+ """
382
+ self .data .representation = output_binning
383
+ if isinstance (output_key , tuple ):
384
+ assert len (output_key ) == 2
385
+ outputs = self .data .get_mapset (output_key [0 ], error = output_key [1 ])
386
+ else :
387
+ outputs = self .data .get_mapset (output_key )
388
+ return outputs
389
+
390
def _get_outputs_varbinning(self, output_binning, output_key):
    """Produce one `MapSet` per selection of a `VarBinning`.

    For every selection, each container in `self.data` is reduced to the
    events passing that selection, copied into a fresh `Container`, and
    the resulting containers are collected in a new `ContainerSet` whose
    representation is the binning belonging to that selection.

    Parameters
    ----------
    output_binning : VarBinning
        Provides `selections`, `nselections` and `binnings`.

    output_key : str or tuple
        Either a single key, or a 2-tuple of (key, error key). In the
        latter case the error entry is filled from the squares of the
        selected values under the first key.

    Returns
    -------
    outputs : list of MapSet

    """
    assert self.data.representation == "events"

    with_errors = isinstance(output_key, tuple)
    selections = output_binning.selections
    mapsets = []

    for sel_idx in range(output_binning.nselections):
        sel_binning = output_binning.binnings[sel_idx]

        # One fresh set of containers per selection
        selected = []
        for src in self.data.containers:
            dest = Container(name=src.name)

            # Event mask for this selection; how it is obtained depends
            # on the type of `selections`
            if isinstance(selections, list):
                mask = src.get_keep_mask(selections[sel_idx])
            else:
                assert isinstance(selections, OneDimBinning)
                values = src[selections.name]
                # cut on bin edges: lower edge inclusive, upper exclusive
                lower = selections.edge_magnitudes[sel_idx]
                upper = selections.edge_magnitudes[sel_idx + 1]
                mask = (values >= lower) & (values < upper)

            # Copy the selected entries of the variables in which this
            # selection's data will be binned
            for dim_name in sel_binning.names:
                dest[dim_name] = src[dim_name][mask]

            # Copy the quantities that will populate each bin
            if with_errors:
                assert len(output_key) == 2
                value_key, error_key = output_key
                kept = src[value_key][mask]
                dest[value_key] = kept
                dest.tranlation_modes[value_key] = 'sum'
                # squared values; the square root is taken further below
                dest[error_key] = np.square(kept)
                dest.tranlation_modes[error_key] = 'sum'
            else:
                dest[output_key] = src[output_key][mask]
                dest.tranlation_modes[output_key] = 'sum'

            selected.append(dest)

        binned = ContainerSet(
            name=self.data.name,
            containers=selected,
            representation=sel_binning,
        )

        if with_errors:
            for binned_container in binned.containers:
                # uncertainties
                binned_container[output_key[1]] = np.sqrt(
                    binned_container[output_key[1]]
                )
            mapsets.append(
                binned.get_mapset(output_key[0], error=output_key[1])
            )
        else:
            mapsets.append(binned.get_mapset(output_key))

    return mapsets
449
+
450
+
373
451
def _get_outputs (self , output_binning = None , output_key = None ):
374
452
"""Get MapSet output"""
375
453
@@ -391,55 +469,10 @@ def _get_outputs(self, output_binning=None, output_key=None):
391
469
assert (isinstance (output_binning , (MultiDimBinning , VarBinning )))
392
470
393
471
if isinstance (output_binning , MultiDimBinning ):
394
- self .data .representation = output_binning
395
-
396
- if isinstance (output_key , tuple ):
397
- assert len (output_key ) == 2
398
- outputs = self .data .get_mapset (output_key [0 ], error = output_key [1 ])
399
- else :
400
- outputs = self .data .get_mapset (output_key )
401
-
472
+ outputs = self ._get_outputs_multdimbinning (output_binning , output_key )
402
473
else :
403
474
assert isinstance (output_binning , VarBinning )
404
- assert self .data .representation == "events"
405
- outputs = []
406
-
407
- selections = output_binning .selections
408
- for i in range (len (output_binning .binnings )):
409
- containers = []
410
- for c in self .data .containers :
411
- cc = Container (name = c .name )
412
- if isinstance (selections , list ):
413
- keep = c .get_keep_mask (selections [i ])
414
- else :
415
- cut_var = c [selections .name ]
416
- keep = (cut_var >= selections .edge_magnitudes [i ]) & (cut_var < selections .edge_magnitudes [i + 1 ])
417
- for var_name in output_binning .binnings [i ].names :
418
- cc [var_name ] = c [var_name ][keep ]
419
-
420
- if isinstance (output_key , tuple ):
421
- assert len (output_key ) == 2
422
- cc [output_key [0 ]] = c [output_key [0 ]][keep ]
423
- cc .tranlation_modes [output_key [0 ]] = 'sum'
424
- cc [output_key [1 ]] = np .square (c [output_key [0 ]][keep ])
425
- cc .tranlation_modes [output_key [1 ]] = 'sum'
426
- else :
427
- cc [output_key ] = c [output_key ][keep ]
428
- cc .tranlation_modes [output_key ] = 'sum'
429
-
430
- containers .append (cc )
431
-
432
- dat = ContainerSet (name = self .data .name ,
433
- containers = containers ,
434
- representation = output_binning .binnings [i ],
435
- )
436
-
437
- if isinstance (output_key , tuple ):
438
- for c in dat .containers :
439
- c [output_key [1 ]] = np .sqrt (c [output_key [1 ]])
440
- outputs .append (dat .get_mapset (output_key [0 ], error = output_key [1 ]))
441
- else :
442
- outputs .append (dat .get_mapset (output_key ))
475
+ outputs = self ._get_outputs_varbinning (output_binning , output_key )
443
476
444
477
return outputs
445
478
@@ -636,11 +669,13 @@ def __hash__(self):
636
669
637
670
def assert_varbinning_compat (self ):
638
671
"""Asserts that pipeline setup is compatible with `VarBinning`:
639
- all stages need to apply to events.
672
+ all stages need to apply to events (this precludes use with
673
+ any histogramming service, which requires a binning as apply_mode).
640
674
641
675
Raises
642
676
------
643
- ValueError : if at least one stage has apply_mode!='events'
677
+ ValueError
678
+ if at least one stage has apply_mode!='events'
644
679
645
680
"""
646
681
incompat = []
@@ -667,7 +702,8 @@ def assert_exclusive_varbinning(self, output_binning=None):
667
702
668
703
Raises
669
704
------
670
- ValueError : if a `VarBinning` is tested and at least two selections
705
+ ValueError
706
+ if a `VarBinning` is tested and at least two selections
671
707
(if applicable) are not mutually exclusive
672
708
673
709
"""
@@ -1004,9 +1040,9 @@ def main(return_outputs=False):
1004
1040
pass
1005
1041
if isinstance (stop_idx , str ):
1006
1042
stop_idx = pipeline .index (stop_idx )
1007
- outputs = pipeline .get_outputs (
1043
+ outputs = pipeline .get_outputs ( # pylint: disable=redefined-outer-name
1008
1044
idx = stop_idx
1009
- ) # pylint: disable=redefined-outer-name
1045
+ )
1010
1046
if stop_idx is not None :
1011
1047
stop_idx += 1
1012
1048
indices = slice (0 , stop_idx )
@@ -1102,4 +1138,4 @@ def main(return_outputs=False):
1102
1138
1103
1139
1104
1140
if __name__ == "__main__":
    # Script entry point: call main() with return_outputs=True and bind
    # its two return values to module-level names.
    pipeline, outp = main(return_outputs=True)
0 commit comments