From 174908a081b07e0c13932abdbbb59c75ee7391ee Mon Sep 17 00:00:00 2001 From: Marko Date: Mon, 9 Feb 2026 13:18:09 +0100 Subject: [PATCH] Modernize_type_hints_to_Python_3.10+_in_engine/deterministic.py --- ignite/engine/deterministic.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ignite/engine/deterministic.py b/ignite/engine/deterministic.py index 8a38e035f193..dd620c32c9cd 100644 --- a/ignite/engine/deterministic.py +++ b/ignite/engine/deterministic.py @@ -2,7 +2,7 @@ import warnings from collections import OrderedDict from functools import wraps -from typing import Any, Callable, Generator, Iterator, List, Optional +from typing import Any, Callable, Generator, Iterator import torch from torch.utils.data import DataLoader @@ -56,11 +56,11 @@ class ReproducibleBatchSampler(BatchSampler): """ - def __init__(self, batch_sampler: BatchSampler, start_iteration: Optional[int] = None): + def __init__(self, batch_sampler: BatchSampler, start_iteration: int | None = None): if not isinstance(batch_sampler, BatchSampler): raise TypeError("Argument batch_sampler should be torch.utils.data.sampler.BatchSampler") - self.batch_indices: List = [] + self.batch_indices: list = [] self.batch_sampler = batch_sampler self.start_iteration = start_iteration self.sampler = self.batch_sampler.sampler @@ -84,8 +84,8 @@ def __len__(self) -> int: return len(self.batch_sampler) -def _get_rng_states() -> List[Any]: - output: List[Any] = [random.getstate(), torch.get_rng_state()] +def _get_rng_states() -> list[Any]: + output: list[Any] = [random.getstate(), torch.get_rng_state()] try: import numpy as np @@ -96,7 +96,7 @@ def _get_rng_states() -> List[Any]: return output -def _set_rng_states(rng_states: List[Any]) -> None: +def _set_rng_states(rng_states: list[Any]) -> None: random.setstate(rng_states[0]) if "cpu" not in rng_states[1].device.type: @@ -111,7 +111,7 @@ def _set_rng_states(rng_states: List[Any]) -> None: pass -def _repr_rng_state(rng_states: List[Any]) -> str: +def _repr_rng_state(rng_states: list[Any]) -> str: from hashlib import md5 out = " ".join([md5(str(list(s)).encode("utf-8")).hexdigest() for s in rng_states]) @@ -281,7 +281,7 @@ def _from_iteration(self, iteration: int) -> Iterator: return data_iter - def _setup_seed(self, _: Any = None, iter_counter: Optional[int] = None, iteration: Optional[int] = None) -> None: + def _setup_seed(self, _: Any = None, iter_counter: int | None = None, iteration: int | None = None) -> None: if iter_counter is None: le = self._dataloader_len if self._dataloader_len is not None else 1 elif not iter_counter > 0: