Skip to content

Commit

Permalink
refactor: cleanup init
Browse files: browse the repository at this point in the history
  • Loading branch information
Jackmin801 committed Oct 1, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 09cbd7f commit 05527d1
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/zeroband/comms.py
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ def __init__(self):
# Initialize global process group
self.global_pg = FakeProcessGroup(self.world_info.rank, 1)
if self.world_info.global_world_size > 1:
self.global_pg = self._init_global_pg()
self._init_global_pg()

# Initialize local process group
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
@@ -138,7 +138,7 @@ def _get_assigned_global_rank_and_size(self) -> Tuple[int, int]:
"""Get the assigned global rank from the leader."""
return

def _init_global_pg(self) -> dist.ProcessGroup:
def _init_global_pg(self) -> None:
# Each rank gets its own global store with global rank 0 as the master
self._global_leader = self.world_info.global_rank == 0
self.global_store = dist.TCPStore(
@@ -148,11 +148,12 @@ def _init_global_pg(self) -> dist.ProcessGroup:
is_master=self._global_leader,
)

# Initialize store
# Initialize store values
self._init_global_store_and_status()

# Initialize prefix store
if self.global_status == "init": # First time init path
self.mesh_count = 0
self.mesh_count = 0 # TODO: privatize?
prefix_store = dist.PrefixStore("mesh_0", self.global_store)
elif self.global_status == "running": # Join path
# Ask to join and then wait for the status to be "reinit"
@@ -169,17 +170,20 @@ def _init_global_pg(self) -> dist.ProcessGroup:
else:
# TODO: Could be in "reinit" status. We probably just recurse until running in this case
raise RuntimeError(f"Unknown status {self.global_status}")
pg = dist.ProcessGroupGloo(

# Create process group
self.global_pg = dist.ProcessGroupGloo(
prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT
)

# Update global store values
if self._global_leader:
self.global_store.set("status", "running")
self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(self.world_info.global_rank))
self.global_status = "running"
self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(self.world_info.global_rank))

# Setting instance variables
self.leaving = False # TODO: do we need this?
return pg

def _resolve_world(self):
"""Set the new world size and ranks for all nodes."""

0 comments on commit 05527d1

Please sign in to comment.