Skip to content

Commit 7837426

Browse files
committed
superficial: comments & docstrings, thousands of lines of ipynb output cleared, moved core pipeline output calculations into separate functions for clarity
1 parent 94a0529 commit 7837426

File tree

2 files changed

+116
-6471
lines changed

2 files changed

+116
-6471
lines changed

pisa/core/pipeline.py

+90-54
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from pisa.core.param import ParamSet, DerivedParam
3030
from pisa.core.stage import Stage
3131
from pisa.core.container import Container, ContainerSet
32-
from pisa.core.binning import MultiDimBinning, VarBinning
32+
from pisa.core.binning import MultiDimBinning, OneDimBinning, VarBinning
3333
from pisa.utils.config_parser import PISAConfigParser, parse_pipeline_config
3434
from pisa.utils.fileio import mkdir
3535
from pisa.utils.format import format_times
@@ -370,6 +370,84 @@ def get_outputs(self, **get_outputs_kwargs):
370370
outputs = self._get_outputs(**get_outputs_kwargs)
371371
return outputs
372372

373+
def _get_outputs_multdimbinning(self, output_binning, output_key):
374+
"""Logic that produces a single `MapSet` when the pipeline's
375+
output binning is a regular `MultiDimBinning`.
376+
377+
Returns
378+
-------
379+
outputs : MapSet
380+
381+
"""
382+
self.data.representation = output_binning
383+
if isinstance(output_key, tuple):
384+
assert len(output_key) == 2
385+
outputs = self.data.get_mapset(output_key[0], error=output_key[1])
386+
else:
387+
outputs = self.data.get_mapset(output_key)
388+
return outputs
389+
390+
def _get_outputs_varbinning(self, output_binning, output_key):
    """Logic that produces multiple `MapSet`s when the pipeline's
    output binning is a `VarBinning`.

    One `MapSet` is produced per selection in the `VarBinning`; each
    selection gets its own fresh `ContainerSet` holding only the
    events that pass that selection.

    Parameters
    ----------
    output_binning : VarBinning
        Variable binning; `selections` is either a list of cut
        expressions or a `OneDimBinning` whose bin edges define the
        selections, and `binnings[i]` is the `MultiDimBinning` used
        for selection `i`.
    output_key : str or tuple of two str
        Name of the quantity to histogram; if a 2-tuple, the second
        entry names the quantity used as the per-bin error.

    Returns
    -------
    outputs : list of MapSet

    """
    # Event-wise representation is required to apply per-event selections
    # (see `assert_varbinning_compat`).
    assert self.data.representation == "events"
    outputs = []

    selections = output_binning.selections
    for i in range(output_binning.nselections):
        # there will be a new ContainerSet created for each selection
        containers = []
        for c in self.data.containers:
            cc = Container(name=c.name)
            # Find the events that belong to the given selection, depending on
            # type of selection
            if isinstance(selections, list):
                # selection is a cut expression handled by the container
                keep = c.get_keep_mask(selections[i])
            else:
                assert isinstance(selections, OneDimBinning)
                cut_var = c[selections.name]
                # cut on bin edges
                keep = (cut_var >= selections.edge_magnitudes[i]) & (cut_var < selections.edge_magnitudes[i+1])
            for var_name in output_binning.binnings[i].names:
                # Store the selected var_name entries (corresponding to the
                # dimensions in which the data for this selection will be
                # binned) in the fresh Container
                cc[var_name] = c[var_name][keep]
            # store the quantities that will populate each bin
            if isinstance(output_key, tuple):
                assert len(output_key) == 2
                cc[output_key[0]] = c[output_key[0]][keep]
                # NOTE(review): `tranlation_modes` looks like a typo of
                # "translation_modes", but it matches the pre-existing code
                # and presumably the attribute name on `Container` — confirm
                # before renaming.
                cc.tranlation_modes[output_key[0]] = 'sum'
                # Squared weights are summed per bin; the square root below
                # then yields the per-bin uncertainty.
                cc[output_key[1]] = np.square(c[output_key[0]][keep])
                cc.tranlation_modes[output_key[1]] = 'sum'
            else:
                cc[output_key] = c[output_key][keep]
                cc.tranlation_modes[output_key] = 'sum'

            containers.append(cc)

        # Bin the selected events in this selection's own binning
        dat = ContainerSet(
            name=self.data.name,
            containers=containers,
            representation=output_binning.binnings[i],
        )

        if isinstance(output_key, tuple):
            for c in dat.containers:
                # uncertainties: sqrt of the per-bin sum of squared weights
                c[output_key[1]] = np.sqrt(c[output_key[1]])
            outputs.append(dat.get_mapset(output_key[0], error=output_key[1]))
        else:
            outputs.append(dat.get_mapset(output_key))
    return outputs
450+
373451
def _get_outputs(self, output_binning=None, output_key=None):
374452
"""Get MapSet output"""
375453

@@ -391,55 +469,10 @@ def _get_outputs(self, output_binning=None, output_key=None):
391469
assert(isinstance(output_binning, (MultiDimBinning, VarBinning)))
392470

393471
if isinstance(output_binning, MultiDimBinning):
394-
self.data.representation = output_binning
395-
396-
if isinstance(output_key, tuple):
397-
assert len(output_key) == 2
398-
outputs = self.data.get_mapset(output_key[0], error=output_key[1])
399-
else:
400-
outputs = self.data.get_mapset(output_key)
401-
472+
outputs = self._get_outputs_multdimbinning(output_binning, output_key)
402473
else:
403474
assert isinstance(output_binning, VarBinning)
404-
assert self.data.representation == "events"
405-
outputs = []
406-
407-
selections = output_binning.selections
408-
for i in range(len(output_binning.binnings)):
409-
containers = []
410-
for c in self.data.containers:
411-
cc = Container(name=c.name)
412-
if isinstance(selections, list):
413-
keep = c.get_keep_mask(selections[i])
414-
else:
415-
cut_var = c[selections.name]
416-
keep = (cut_var >= selections.edge_magnitudes[i]) & (cut_var < selections.edge_magnitudes[i+1])
417-
for var_name in output_binning.binnings[i].names:
418-
cc[var_name] = c[var_name][keep]
419-
420-
if isinstance(output_key, tuple):
421-
assert len(output_key) == 2
422-
cc[output_key[0]] = c[output_key[0]][keep]
423-
cc.tranlation_modes[output_key[0]] = 'sum'
424-
cc[output_key[1]] = np.square(c[output_key[0]][keep])
425-
cc.tranlation_modes[output_key[1]] = 'sum'
426-
else:
427-
cc[output_key] = c[output_key][keep]
428-
cc.tranlation_modes[output_key] = 'sum'
429-
430-
containers.append(cc)
431-
432-
dat = ContainerSet(name=self.data.name,
433-
containers=containers,
434-
representation=output_binning.binnings[i],
435-
)
436-
437-
if isinstance(output_key, tuple):
438-
for c in dat.containers:
439-
c[output_key[1]] = np.sqrt(c[output_key[1]])
440-
outputs.append(dat.get_mapset(output_key[0], error=output_key[1]))
441-
else:
442-
outputs.append(dat.get_mapset(output_key))
475+
outputs = self._get_outputs_varbinning(output_binning, output_key)
443476

444477
return outputs
445478

@@ -636,11 +669,13 @@ def __hash__(self):
636669

637670
def assert_varbinning_compat(self):
638671
"""Asserts that pipeline setup is compatible with `VarBinning`:
639-
all stages need to apply to events.
672+
all stages need to apply to events (this precludes use with
673+
any histogramming service, which requires a binning as apply_mode).
640674
641675
Raises
642676
------
643-
ValueError : if at least one stage has apply_mode!='events'
677+
ValueError
678+
if at least one stage has apply_mode!='events'
644679
645680
"""
646681
incompat = []
@@ -667,7 +702,8 @@ def assert_exclusive_varbinning(self, output_binning=None):
667702
668703
Raises
669704
------
670-
ValueError : if a `VarBinning` is tested and at least two selections
705+
ValueError
706+
if a `VarBinning` is tested and at least two selections
671707
(if applicable) are not mutually exclusive
672708
673709
"""
@@ -1004,9 +1040,9 @@ def main(return_outputs=False):
10041040
pass
10051041
if isinstance(stop_idx, str):
10061042
stop_idx = pipeline.index(stop_idx)
1007-
outputs = pipeline.get_outputs(
1043+
outputs = pipeline.get_outputs( # pylint: disable=redefined-outer-name
10081044
idx=stop_idx
1009-
) # pylint: disable=redefined-outer-name
1045+
)
10101046
if stop_idx is not None:
10111047
stop_idx += 1
10121048
indices = slice(0, stop_idx)
@@ -1102,4 +1138,4 @@ def main(return_outputs=False):
11021138

11031139

11041140
if __name__ == "__main__":
1105-
pipeline, outputs = main(return_outputs=True) # pylint: disable=invalid-name
1141+
pipeline, outp = main(return_outputs=True)

0 commit comments

Comments
 (0)