Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions ignite/handlers/time_profilers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools
from collections import OrderedDict
from typing import Any, Callable, cast, Dict, List, Mapping, Sequence, Tuple, Union
from typing import Any, Callable, cast, Mapping, Sequence
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also replace Mapping with dict in this file


import torch

Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(self) -> None:

self.dataflow_times = torch.zeros(1)
self.processing_times = torch.zeros(1)
self.event_handlers_times: Dict[EventEnum, torch.Tensor] = {}
self.event_handlers_times: dict[EventEnum, torch.Tensor] = {}

self._events = [
Events.EPOCH_STARTED,
Expand Down Expand Up @@ -222,10 +222,10 @@ def attach(self, engine: Engine) -> None:
engine._event_handlers[Events.STARTED].insert(0, (self._as_first_started, (engine,), {}))

@staticmethod
def _compute_basic_stats(data: torch.Tensor) -> Dict[str, Union[str, float, Tuple[float, float]]]:
def _compute_basic_stats(data: torch.Tensor) -> dict[str, str | float | tuple[float, float]]:
# compute on non-zero data:
data = data[data > 0]
out: List[Tuple[str, Union[str, float, Tuple[float, float]]]] = [
out: list[tuple[str, str | float | tuple[float, float]]] = [
("total", torch.sum(data).item() if len(data) > 0 else "not yet triggered")
]
if len(data) > 1:
Expand All @@ -239,7 +239,7 @@ def _compute_basic_stats(data: torch.Tensor) -> Dict[str, Union[str, float, Tupl
)
return OrderedDict(out)

def get_results(self) -> Dict[str, Dict[str, Any]]:
def get_results(self) -> dict[str, dict[str, Any]]:
"""
Method to fetch the aggregated profiler results after the engine is run

Expand All @@ -248,7 +248,7 @@ def get_results(self) -> Dict[str, Dict[str, Any]]:
results = profiler.get_results()

"""
total_eh_time: Union[int, torch.Tensor] = sum(
total_eh_time: int | torch.Tensor = sum(
[(self.event_handlers_times[e]).sum() for e in Events if e not in self.events_to_ignore]
)
# pyrefly: ignore [no-matching-overload]
Expand Down Expand Up @@ -354,7 +354,7 @@ def write_results(self, output_path: str) -> None:
results_df.to_csv(output_path, index=False)

@staticmethod
def print_results(results: Dict) -> str:
def print_results(results: dict) -> str:
"""
Method to print the aggregated results from the profiler

Expand Down Expand Up @@ -402,7 +402,7 @@ def print_results(results: Dict) -> str:

"""

def to_str(v: Union[str, tuple, int, float]) -> str:
def to_str(v: str | tuple | int | float) -> str:
if isinstance(v, str):
return v
elif isinstance(v, tuple):
Expand Down Expand Up @@ -498,16 +498,16 @@ def __init__(self) -> None:
self._processing_timer = Timer()
self._event_handlers_timer = Timer()

self.dataflow_times: List[float] = []
self.processing_times: List[float] = []
self.event_handlers_times: Dict[Union[str, EventEnum], Dict[str, List[float]]] = {}
self.dataflow_times: list[float] = []
self.processing_times: list[float] = []
self.event_handlers_times: dict[str | EventEnum, dict[str, list[float]]] = {}

@staticmethod
def _get_callable_name(handler: Callable) -> str:
# get name of the callable handler
return getattr(handler, "__qualname__", handler.__class__.__name__)

def _create_wrapped_handler(self, handler: Callable, event: Union[str, EventEnum]) -> Callable:
def _create_wrapped_handler(self, handler: Callable, event: str | EventEnum) -> Callable:
@functools.wraps(handler)
def _timeit_handler(*args: Any, **kwargs: Any) -> None:
self._event_handlers_timer.reset()
Expand All @@ -532,7 +532,7 @@ def _timeit_dataflow(self) -> None:
t = self._dataflow_timer.value()
self.dataflow_times.append(t)

def _reset(self, event_handlers_names: Mapping[Union[str, EventEnum], List[str]]) -> None:
def _reset(self, event_handlers_names: Mapping[str | EventEnum, list[str]]) -> None:
# reset the variables used for profiling
self.dataflow_times = []
self.processing_times = []
Expand Down Expand Up @@ -593,7 +593,7 @@ def attach(self, engine: Engine) -> None:
if not engine.has_event_handler(self._as_first_started):
engine._event_handlers[Events.STARTED].insert(0, (self._as_first_started, (engine,), {}))

def get_results(self) -> List[List[Union[str, float, Tuple[Union[str, float], Union[str, float]]]]]:
def get_results(self) -> list[list[str | float | tuple[str | float, str | float]]]:
"""
Method to fetch the aggregated profiler results after the engine is run

Expand All @@ -612,16 +612,16 @@ def get_results(self) -> List[List[Union[str, float, Tuple[Union[str, float], Un
total_eh_time = round(float(total_eh_time), 5)

def compute_basic_stats(
times: Union[Sequence, torch.Tensor]
) -> List[Union[str, float, Tuple[Union[str, float], Union[str, float]]]]:
times: Sequence | torch.Tensor,
) -> list[str | float | tuple[str | float, str | float]]:
data = torch.as_tensor(times, dtype=torch.float32)
# compute on non-zero data:
data = data[data > 0]
total: Union[str, float] = round(torch.sum(data).item(), 5) if len(data) > 0 else "not triggered"
min_index: Tuple[Union[str, float], Union[str, float]] = ("None", "None")
max_index: Tuple[Union[str, float], Union[str, float]] = ("None", "None")
mean: Union[str, float] = "None"
std: Union[str, float] = "None"
total: str | float = round(torch.sum(data).item(), 5) if len(data) > 0 else "not triggered"
min_index: tuple[str | float, str | float] = ("None", "None")
max_index: tuple[str | float, str | float] = ("None", "None")
mean: str | float = "None"
std: str | float = "None"
if len(data) > 0:
min_index = (round(torch.min(data).item(), 5), torch.argmin(data).item())
max_index = (round(torch.max(data).item(), 5), torch.argmax(data).item())
Expand Down Expand Up @@ -696,7 +696,7 @@ def write_results(self, output_path: str) -> None:
results_df.to_csv(output_path, index=False)

@staticmethod
def print_results(results: List[List[Union[str, float]]]) -> None:
def print_results(results: list[list[str | float]]) -> None:
"""
Method to print the aggregated results from the profiler

Expand Down