Update torch_distributed.py with type hints
fcogidi committed Jan 25, 2024
1 parent 6d28c91 commit 4e891d5
Showing 1 changed file with 4 additions and 2 deletions.
@@ -10,8 +10,10 @@
 if TYPE_CHECKING:
     import torch
     import torch.distributed as torch_dist
+    from torch import Tensor
 else:
     torch = import_optional_module("torch", error="warn")
+    Tensor = import_optional_module("torch", attribute="Tensor", error="warn")
     torch_dist = import_optional_module("torch.distributed", error="warn")
 
 
@@ -47,13 +49,13 @@ def world_size(self) -> int:
         """Return the world size of the current process group."""
         return torch_dist.get_world_size()
 
-    def _simple_all_gather(self, data: torch.Tensor) -> List[torch.Tensor]:
+    def _simple_all_gather(self, data: Tensor) -> List[Tensor]:
         """Gather tensors of the same shape from all processes."""
         gathered_data = [torch.zeros_like(data) for _ in range(self.world_size)]
         torch_dist.all_gather(gathered_data, data)  # type: ignore[no-untyped-call]
         return gathered_data
 
-    def all_gather(self, data: torch.Tensor) -> List[torch.Tensor]:  # type: ignore[override]
+    def all_gather(self, data: Tensor) -> List[Tensor]:  # type: ignore[override]
         """Gather Arrays from current proccess and return as a list.
 
         Parameters
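For context, the first hunk leans on two pieces: `typing.TYPE_CHECKING`, so static type checkers resolve the real `torch` names, and the project's `import_optional_module` helper, so the module still imports when PyTorch is not installed. The sketch below shows one plausible shape for such a helper; the name and keyword arguments mirror the call sites in the diff, but the body is an assumption, not the project's actual implementation.

import importlib
import warnings
from typing import Any, Optional


def import_optional_module(
    name: str,
    attribute: Optional[str] = None,
    error: str = "raise",
) -> Optional[Any]:
    """Import an optional module (or one of its attributes).

    Sketch only: mirrors the call sites in the diff above, not the
    project's real helper.
    """
    try:
        module = importlib.import_module(name)
        return getattr(module, attribute) if attribute is not None else module
    except ImportError as exc:
        if error == "raise":
            raise
        if error == "warn":
            warnings.warn(f"Optional module '{name}' not available: {exc}")
        return None  # callers must cope with the missing dependency

With error="warn", a missing dependency leaves the bound name as None; binding `Tensor` in both branches is likely what keeps the annotations in `_simple_all_gather` and `all_gather` evaluable even when `torch` itself is None.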
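The second hunk only changes annotations; the gathering behavior is stock `torch.distributed.all_gather`, which fills a pre-allocated list with one tensor per rank. A minimal, runnable sketch of that pattern follows, assuming a single-machine CPU run with the gloo backend; the address, port, and world size are illustrative, not taken from the repository.

import os

import torch
import torch.distributed as torch_dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    # Single-machine rendezvous; the address and port are arbitrary choices.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    torch_dist.init_process_group("gloo", rank=rank, world_size=world_size)

    data = torch.tensor([float(rank)])
    # Same pattern as _simple_all_gather: pre-allocate one slot per rank,
    # then let all_gather fill the list in place.
    gathered = [torch.zeros_like(data) for _ in range(torch_dist.get_world_size())]
    torch_dist.all_gather(gathered, data)
    if rank == 0:
        print(gathered)  # [tensor([0.]), tensor([1.])]

    torch_dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)

Note that all_gather requires every rank to contribute a tensor of the same shape, which is exactly the precondition `_simple_all_gather` documents.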
