diff --git a/transactron/utils/dependencies.py b/transactron/utils/dependencies.py index 683ff58..73efd57 100644 --- a/transactron/utils/dependencies.py +++ b/transactron/utils/dependencies.py @@ -1,7 +1,7 @@ from collections import defaultdict from abc import abstractmethod, ABC -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Optional, TypeVar __all__ = ["DependencyManager", "DependencyKey", "DependencyContext", "SimpleKey", "ListKey"] @@ -122,15 +122,28 @@ def get_dependency(self, key: DependencyKey[Any, U]) -> U: The way dependencies are interpreted is dependent on the key type. """ - if not key.empty_valid and key not in self.dependencies: + ret = self.get_optional_dependency(key) + + if ret is None: raise KeyError(f"Dependency {key} not provided") - if key in self.cache: - return self.cache[key] + return ret + + def get_optional_dependency(self, key: DependencyKey[Any, U]) -> Optional[U]: + """Gets the dependency for a key, if it exists. + If the dependency is gettable, the return value is the same as in + `get_dependency`. Otherwise, `None` is returned. + """ if key.lock_on_get: self.locked_dependencies.add(key) + if not key.empty_valid and key not in self.dependencies: + return None + + if key in self.cache: + return self.cache[key] + val = key.combine(self.dependencies[key]) if key.cache: