diff --git a/ignite/handlers/time_profilers.py b/ignite/handlers/time_profilers.py index 3aa85e115ec5..d40eb5d2e6b4 100644 --- a/ignite/handlers/time_profilers.py +++ b/ignite/handlers/time_profilers.py @@ -1,6 +1,6 @@ import functools from collections import OrderedDict -from typing import Any, Callable, cast, Dict, List, Mapping, Sequence, Tuple, Union +from typing import Any, Callable, cast, Mapping, Sequence import torch @@ -50,7 +50,7 @@ def __init__(self) -> None: self.dataflow_times = torch.zeros(1) self.processing_times = torch.zeros(1) - self.event_handlers_times: Dict[EventEnum, torch.Tensor] = {} + self.event_handlers_times: dict[EventEnum, torch.Tensor] = {} self._events = [ Events.EPOCH_STARTED, @@ -222,10 +222,10 @@ def attach(self, engine: Engine) -> None: engine._event_handlers[Events.STARTED].insert(0, (self._as_first_started, (engine,), {})) @staticmethod - def _compute_basic_stats(data: torch.Tensor) -> Dict[str, Union[str, float, Tuple[float, float]]]: + def _compute_basic_stats(data: torch.Tensor) -> dict[str, str | float | tuple[float, float]]: # compute on non-zero data: data = data[data > 0] - out: List[Tuple[str, Union[str, float, Tuple[float, float]]]] = [ + out: list[tuple[str, str | float | tuple[float, float]]] = [ ("total", torch.sum(data).item() if len(data) > 0 else "not yet triggered") ] if len(data) > 1: @@ -239,7 +239,7 @@ def _compute_basic_stats(data: torch.Tensor) -> Dict[str, Union[str, float, Tupl ) return OrderedDict(out) - def get_results(self) -> Dict[str, Dict[str, Any]]: + def get_results(self) -> dict[str, dict[str, Any]]: """ Method to fetch the aggregated profiler results after the engine is run @@ -248,7 +248,7 @@ def get_results(self) -> Dict[str, Dict[str, Any]]: results = profiler.get_results() """ - total_eh_time: Union[int, torch.Tensor] = sum( + total_eh_time: int | torch.Tensor = sum( [(self.event_handlers_times[e]).sum() for e in Events if e not in self.events_to_ignore] ) # pyrefly: ignore [no-matching-overload] @@ -354,7 +354,7 @@ def write_results(self, output_path: str) -> None: results_df.to_csv(output_path, index=False) @staticmethod - def print_results(results: Dict) -> str: + def print_results(results: dict) -> str: """ Method to print the aggregated results from the profiler @@ -402,7 +402,7 @@ def print_results(results: Dict) -> str: """ - def to_str(v: Union[str, tuple, int, float]) -> str: + def to_str(v: str | tuple | int | float) -> str: if isinstance(v, str): return v elif isinstance(v, tuple): @@ -498,16 +498,16 @@ def __init__(self) -> None: self._processing_timer = Timer() self._event_handlers_timer = Timer() - self.dataflow_times: List[float] = [] - self.processing_times: List[float] = [] - self.event_handlers_times: Dict[Union[str, EventEnum], Dict[str, List[float]]] = {} + self.dataflow_times: list[float] = [] + self.processing_times: list[float] = [] + self.event_handlers_times: dict[str | EventEnum, dict[str, list[float]]] = {} @staticmethod def _get_callable_name(handler: Callable) -> str: # get name of the callable handler return getattr(handler, "__qualname__", handler.__class__.__name__) - def _create_wrapped_handler(self, handler: Callable, event: Union[str, EventEnum]) -> Callable: + def _create_wrapped_handler(self, handler: Callable, event: str | EventEnum) -> Callable: @functools.wraps(handler) def _timeit_handler(*args: Any, **kwargs: Any) -> None: self._event_handlers_timer.reset() @@ -532,7 +532,7 @@ def _timeit_dataflow(self) -> None: t = self._dataflow_timer.value() self.dataflow_times.append(t) - def _reset(self, event_handlers_names: Mapping[Union[str, EventEnum], List[str]]) -> None: + def _reset(self, event_handlers_names: Mapping[str | EventEnum, list[str]]) -> None: # reset the variables used for profiling self.dataflow_times = [] self.processing_times = [] @@ -593,7 +593,7 @@ def attach(self, engine: Engine) -> None: if not engine.has_event_handler(self._as_first_started): engine._event_handlers[Events.STARTED].insert(0, (self._as_first_started, (engine,), {})) - def get_results(self) -> List[List[Union[str, float, Tuple[Union[str, float], Union[str, float]]]]]: + def get_results(self) -> list[list[str | float | tuple[str | float, str | float]]]: """ Method to fetch the aggregated profiler results after the engine is run @@ -612,16 +612,16 @@ def get_results(self) -> List[List[Union[str, float, Tuple[Union[str, float], Un total_eh_time = round(float(total_eh_time), 5) def compute_basic_stats( - times: Union[Sequence, torch.Tensor] - ) -> List[Union[str, float, Tuple[Union[str, float], Union[str, float]]]]: + times: Sequence | torch.Tensor, + ) -> list[str | float | tuple[str | float, str | float]]: data = torch.as_tensor(times, dtype=torch.float32) # compute on non-zero data: data = data[data > 0] - total: Union[str, float] = round(torch.sum(data).item(), 5) if len(data) > 0 else "not triggered" - min_index: Tuple[Union[str, float], Union[str, float]] = ("None", "None") - max_index: Tuple[Union[str, float], Union[str, float]] = ("None", "None") - mean: Union[str, float] = "None" - std: Union[str, float] = "None" + total: str | float = round(torch.sum(data).item(), 5) if len(data) > 0 else "not triggered" + min_index: tuple[str | float, str | float] = ("None", "None") + max_index: tuple[str | float, str | float] = ("None", "None") + mean: str | float = "None" + std: str | float = "None" if len(data) > 0: min_index = (round(torch.min(data).item(), 5), torch.argmin(data).item()) max_index = (round(torch.max(data).item(), 5), torch.argmax(data).item()) @@ -696,7 +696,7 @@ def write_results(self, output_path: str) -> None: results_df.to_csv(output_path, index=False) @staticmethod - def print_results(results: List[List[Union[str, float]]]) -> None: + def print_results(results: list[list[str | float]]) -> None: """ Method to print the aggregated results from the profiler