Skip to content

Commit

Permalink
Merge pull request #286 from KIT-CMS/PR_tut_main
Browse files Browse the repository at this point in the history
Merge some changes of CROWN_tutorial branch with main
  • Loading branch information
nshadskiy authored Jan 21, 2025
2 parents 3120dde + 5bf6c02 commit 637ad13
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 55 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# folders
build/
build/*
build*/
!build/.gitkeep
log/
logs/
generation_logs/
# ignore all analysis in the config folder apart from the unittest and the template
analysis_configurations/*
!analysis_configurations/unittest
Expand Down
1 change: 1 addition & 0 deletions build/.gitkeep
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
build
38 changes: 18 additions & 20 deletions code_generation/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from code_generation.rules import ProducerRule, RemoveProducer
from code_generation.systematics import SystematicShift, SystematicShiftByQuantity
from code_generation.helpers import is_empty

log = logging.getLogger(__name__)
# type aliases
Expand Down Expand Up @@ -257,12 +258,12 @@ def unpack_producergroups(
"""

if isinstance(producers, list):
# we always want to know the toplevel producergroup, so if the parent is None, we set it to the first producer.
# we always want to know the toplevel producergroup, so if the parent evaluates to false, we set it to the first producer.
# If a prent is given, we set it to the parent, since this means we are in a producergroup. This is important if we
# have nested producergroups, this way every producer is assigned to the outermost producergroup, which is important for the
# potential removal of a single producer.
for producer in producers:
if parent is None:
if is_empty(parent):
parent_producer = producer
else:
parent_producer = parent
Expand All @@ -276,7 +277,7 @@ def unpack_producergroups(
if isinstance(producers, ProducerGroup):
log.debug("{} Unpacking ".format(" " * depth))
for sub_producer in producers.producers[scope]:
if parent is None:
if is_empty(parent):
parent_producer = producers
else:
parent_producer = parent
Expand All @@ -287,7 +288,7 @@ def unpack_producergroups(
depth=depth + 1,
)
else:
if parent is None:
if is_empty(parent):
log.debug("{} {}".format(" " * depth, producers))
self.unpacked_producers[scope][producers] = producers
else:
Expand Down Expand Up @@ -333,19 +334,19 @@ def add_shift(
Returns:
None
"""
if exclude_samples is not None and samples is not None:
if not is_empty(exclude_samples) and not is_empty(samples):
raise ConfigurationError(
f"You cannot use samples and exclude_samples at the same time -> Shift {shift}, samples {samples}, exclude_samples {exclude_samples}"
)
if samples is not None:
if not is_empty(samples):
if isinstance(samples, str):
samples = [samples]
for sample in samples:
if sample not in self.available_sample_types:
raise ConfigurationError(
f"Sampletype {sample} is not available -> Shift {shift}, available_sample_types {self.available_sample_types}, sample_types {samples}"
)
if exclude_samples is not None:
if not is_empty(exclude_samples):
if isinstance(exclude_samples, str):
exclude_samples = [exclude_samples]
for excluded_sample in exclude_samples:
Expand All @@ -360,7 +361,7 @@ def add_shift(
raise TypeError("shift must be of type SystematicShift")
if isinstance(samples, str):
samples = [samples]
if samples is None or self.sample in samples:
if is_empty(samples) or self.sample in samples:
scopes_to_shift = [
scope for scope in shift.get_scopes() if scope in self.scopes
]
Expand Down Expand Up @@ -513,9 +514,9 @@ def _remove_empty_scopes(self) -> None:
# we have to use a seperate list, because we cannot modify the list while iterating over it without breaking stuff
scopes_to_test = [scope for scope in self.scopes]
for scope in scopes_to_test:
if (len(self.producers[scope]) == 0) or (
scope not in self.selected_scopes and scope is not self.global_scope
):
if (
len(self.producers[scope]) == 0 or scope not in self.selected_scopes
) and scope is not self.global_scope:
log.warning("Removing unrequested / empty scope {}".format(scope))
self.scopes.remove(scope)
del self.producers[scope]
Expand Down Expand Up @@ -631,12 +632,7 @@ def _remove_empty_configkeys(self, config) -> None:
if isinstance(value, dict):
self._remove_empty_configkeys(value)

elif (
config[key] is None
or config[key] == ""
or config[key] == []
or config[key] == {}
):
elif is_empty(config[key]):
log.info(
"Removing {} since it is an empty configuration parameter".format(
key
Expand Down Expand Up @@ -767,9 +763,11 @@ def report(self) -> None:
total_quantities = [
sum(
[
len(self.config_parameters[scope][output.vec_config])
if isinstance(output, QuantityGroup)
else 1
(
len(self.config_parameters[scope][output.vec_config])
if isinstance(output, QuantityGroup)
else 1
)
for output in self.outputs[scope]
]
)
Expand Down
3 changes: 2 additions & 1 deletion code_generation/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations # needed for type annotations in > python 3.7
from typing import List, Set, Union
from code_generation.quantity import Quantity
from code_generation.helpers import is_empty


class ConfigurationError(Exception):
Expand Down Expand Up @@ -108,7 +109,7 @@ class InvalidShiftError(ConfigurationError):
"""

def __init__(self, shift: str, sample: str, scope: Union[str, None] = None):
if scope is None:
if is_empty(scope):
self.message = "Shift {} is not setup properly or not available for sampletype {}".format(
shift, sample
)
Expand Down
24 changes: 24 additions & 0 deletions code_generation/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations # needed for type annotations in > python 3.7

# File with helper functions for the CROWN code generation


def is_empty(value):
"""
Check if a value is empty.
Args:
value: The value that should be checked.
Returns:
bool: Whether the input value is considered 'empty'
"""
# List of all values that should be considered empty despite not having a length.
empty_values = [None]

try:
length = len(value)
except TypeError:
length = -1
bool_val = value in empty_values or length == 0
return bool_val
5 changes: 3 additions & 2 deletions code_generation/modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
SampleConfigurationError,
EraConfigurationError,
)
from code_generation.helpers import is_empty

ConfigurationParameters = Union[str, int, float, bool]

Expand Down Expand Up @@ -71,7 +72,7 @@ def apply(self, sample: str) -> ModifierResolved:
"""
if sample in self.samples:
return self.modifier_dict[sample]
elif self.default is not None:
elif not is_empty(self.default):
return self.default
else:
raise SampleConfigurationError(sample, self.samples)
Expand Down Expand Up @@ -106,7 +107,7 @@ def apply(self, era: str) -> ModifierResolved:
"""
if era in self.eras:
return self.modifier_dict[era]
elif self.default is not None:
elif not is_empty(self.default):
return self.default
else:
raise EraConfigurationError(era, self.eras)
7 changes: 4 additions & 3 deletions code_generation/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations # needed for type annotations in > python 3.7
from code_generation.quantity import NanoAODQuantity, Quantity
from code_generation.producer import Filter, BaseFilter, Producer, ProducerGroup
from code_generation.helpers import is_empty
from typing import Set, Tuple, Union, List
import logging

Expand Down Expand Up @@ -79,7 +80,7 @@ def get_global_outputs(self) -> List[Quantity]:
"""
outputs: List[Quantity] = []
for producer in self.global_producers:
if producer.get_outputs("global") is not None:
if not is_empty(producer.get_outputs("global")):
outputs.extend(
[
quantity
Expand Down Expand Up @@ -155,7 +156,7 @@ def Optimize(self) -> None:
log.error("Please check, if all needed producers are activated")
raise Exception
wrongProducer, wrong_inputs = self.check_ordering()
if wrongProducer is not None:
if not is_empty(wrongProducer):
producers_to_relocate = self.find_inputs(wrongProducer, wrong_inputs)
# if len(producers_to_relocate) == 0:
# self.optimized = True
Expand Down Expand Up @@ -197,7 +198,7 @@ def check_ordering(
outputs = self.global_outputs
for producer_to_check in self.ordering:
temp_outputs = producer_to_check.get_outputs(self.scope)
if temp_outputs is not None:
if not is_empty(temp_outputs):
outputs.extend(
[
quantity
Expand Down
Loading

0 comments on commit 637ad13

Please sign in to comment.