From 4e891d5a5f292dea152462dbbeeab909a322af64 Mon Sep 17 00:00:00 2001
From: Franklin <41602287+fcogidi@users.noreply.github.com>
Date: Thu, 25 Jan 2024 12:56:00 -0500
Subject: [PATCH] Update `torch_distributed.py` with type hints

---
 .../experimental/distributed_backends/torch_distributed.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py b/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py
index 0636e2d31..25b0ea1cd 100644
--- a/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py
+++ b/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py
@@ -10,8 +10,10 @@
 if TYPE_CHECKING:
     import torch
     import torch.distributed as torch_dist
+    from torch import Tensor
 else:
     torch = import_optional_module("torch", error="warn")
+    Tensor = import_optional_module("torch", attribute="Tensor", error="warn")
     torch_dist = import_optional_module("torch.distributed", error="warn")
 
 
@@ -47,13 +49,13 @@ def world_size(self) -> int:
         """Return the world size of the current process group."""
         return torch_dist.get_world_size()
 
-    def _simple_all_gather(self, data: torch.Tensor) -> List[torch.Tensor]:
+    def _simple_all_gather(self, data: Tensor) -> List[Tensor]:
         """Gather tensors of the same shape from all processes."""
         gathered_data = [torch.zeros_like(data) for _ in range(self.world_size)]
         torch_dist.all_gather(gathered_data, data)  # type: ignore[no-untyped-call]
         return gathered_data
 
-    def all_gather(self, data: torch.Tensor) -> List[torch.Tensor]:  # type: ignore[override]
+    def all_gather(self, data: Tensor) -> List[Tensor]:  # type: ignore[override]
         """Gather Arrays from current process and return as a list.
 
         Parameters
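
A note on the `import_optional_module` pattern this patch extends: the helper
is cyclops' own and its real implementation is not shown in the diff. The
sketch below is only a guess at its shape, inferred from the three call sites
above (a dotted module name, an optional `attribute` to pull off the imported
module, and an `error` policy where "warn" emits a warning and returns None
rather than raising). Binding `Tensor` at runtime matters because the new
annotations such as `data: Tensor` are evaluated when the class body executes
(unless the module enables postponed annotation evaluation), so the name must
resolve to something even when torch is not installed.

    import importlib
    import warnings
    from typing import Any, Optional

    def import_optional_module(
        name: str,
        attribute: Optional[str] = None,
        error: str = "raise",
    ) -> Optional[Any]:
        """Hypothetical sketch of cyclops' optional-import helper.

        With error="warn", a missing module produces a warning and the
        function returns None instead of raising, so names like `Tensor`
        still exist at runtime when torch is absent.
        """
        try:
            module = importlib.import_module(name)
        except ImportError as exc:
            if error == "warn":
                warnings.warn(
                    f"Optional dependency '{name}' is not available: {exc}",
                    stacklevel=2,
                )
                return None
            raise
        # Fetch e.g. `Tensor` from `torch` when an attribute is requested.
        return getattr(module, attribute) if attribute is not None else module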
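
For reviewers unfamiliar with the collective being wrapped: `_simple_all_gather`
preallocates one buffer per rank and calls `torch.distributed.all_gather`,
which requires every rank to contribute a tensor of the same shape (the
overridden `all_gather` is the general entry point; its `# type:
ignore[override]` hints that the signature deviates from the base backend
class). Below is a minimal, self-contained demonstration of that same-shape
contract; the worker function, port, and world size are illustrative and not
part of the patched module.

    import torch
    import torch.distributed as torch_dist
    import torch.multiprocessing as mp

    WORLD_SIZE = 2  # illustrative; any number of local workers

    def run_worker(rank: int) -> None:
        # The gloo backend works on CPU-only hosts; address/port are arbitrary.
        torch_dist.init_process_group(
            "gloo",
            init_method="tcp://127.0.0.1:29500",
            rank=rank,
            world_size=WORLD_SIZE,
        )
        data = torch.full((3,), float(rank))  # each rank contributes a distinct tensor
        # Same pattern as `_simple_all_gather`: preallocate, then gather in place.
        gathered = [torch.zeros_like(data) for _ in range(WORLD_SIZE)]
        torch_dist.all_gather(gathered, data)
        # Every rank now holds [tensor([0., 0., 0.]), tensor([1., 1., 1.])].
        torch_dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(run_worker, nprocs=WORLD_SIZE)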