From d8a5167307af846fd3a595578572601cd7de95da Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Tue, 1 Oct 2024 15:01:11 +0800 Subject: [PATCH] reinit logic --- src/zeroband/comms.py | 64 +++++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py index a2d177ab..fd505a48 100644 --- a/src/zeroband/comms.py +++ b/src/zeroband/comms.py @@ -1,3 +1,4 @@ +import sys import os from torch.distributed.device_mesh import init_device_mesh from zeroband.utils.world_info import get_world_info @@ -53,10 +54,8 @@ def __init__(self): ) self.local_pg = self.mesh.get_group("intranode") - if self.world_info.rank == 0: - self._logger.info(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}") - else: - self._logger.info(f"local pg world : {self.local_pg.size()}") + # Logging + self._logger.info(f"global_pg size : {self.global_pg.size()}, local_pg size: {self.local_pg.size()}") def __del__(self): dist.destroy_process_group() @@ -186,7 +185,7 @@ def _init_global_pg(self) -> None: self.leaving = False # TODO: do we need this? def _resolve_world(self): - """Set the new world size and ranks for all nodes.""" + """Set the new world size and ranks for all nodes if there are joiners or leavers. Else, do nothing.""" # Find joiners and leavers joiners, leavers = self._get_joiners_and_leavers() # If no joiners or leavers, no resolution needed @@ -195,15 +194,15 @@ def _resolve_world(self): # Remap live ranks to smaller world_size caused by leavers leaving_ranks = {int(self.global_store.get(f"rank_{leaver_id}").decode("utf-8")) for leaver_id in leavers} - live_ranks = [i for i in range(0, self.world_size, self.local_world_size) if i not in leaving_ranks] + live_ranks = [i for i in range(self.world_info.global_world_size) if i not in leaving_ranks] for i, rank in enumerate(live_ranks): - self.global_store.set(f"rank_map_{rank}", str(i * self.local_world_size)) - new_world_size = len(live_ranks) * self.local_world_size + self.global_store.set(f"rank_map_{rank}", str(i)) + new_world_size = len(live_ranks) # Give joiners new ranks for joiner_id in joiners: self.global_store.set(f"rank_{joiner_id}", str(new_world_size)) - new_world_size += self.local_world_size + new_world_size += 1 # Update world_size self.global_store.set("world_size", str(new_world_size)) @@ -211,39 +210,46 @@ def _resolve_world(self): # Set status to "reinit" self.global_store.set("status", "reinit") - def maybe_reinit_device_mesh(self): - """Reinitialize the device mesh if there are joiners or leavers.""" - if self.rank == 0: + def maybe_reinit_global_pg(self): + """Reinitialize the global_pg if there are joiners or leavers.""" + if self._global_leader: self._resolve_world() dist.barrier() status = self.global_store.get("status").decode("utf-8") - if status == "running": + if status == "running": # No joiners or leavers return - print("Reinitializing device mesh") - dist.destroy_process_group() - print("Destroyed process group") + # Reinit Path + self._logger.info("Reinitializing global_pg") + if sys.getrefcount(self.global_pg) > 2: + self._logger.warning( + f"Global PG refcount was {sys.getrefcount(self.global_pg)} when 2 is expected during deletion. This may cause a memory leak." + ) + del self.global_pg + self._logger.info("Destroyed process group") if self.leaving: - print("Leaving") + self._logger.info("Leaving") return # Check if we got remapped - prev_uuid_rank = int(self.global_store.get(f"rank_{self.world_info.global_unique_id}").decode("utf-8")) - new_uuid_rank = int(self.global_store.get(f"rank_map_{prev_uuid_rank}").decode("utf-8")) - self.rank = new_uuid_rank + self.local_rank + old_global_rank = self.world_info.global_rank + self.world_info.global_rank = int( + self.global_store.get(f"rank_map_{self.world_info.global_rank}").decode("utf-8") + ) - self.world_size = int(self.global_store.get("world_size").decode("utf-8")) + self.world_info.global_world_size = int(self.global_store.get("world_size").decode("utf-8")) self.mesh_count = int(self.global_store.get("mesh_count").decode("utf-8")) - self.prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", self.global_store) - dist.init_process_group( - backend="cpu:gloo,cuda:nccl", store=self.prefix_store, rank=self.rank, world_size=self.world_size + prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", self.global_store) + + # Create process group + self.global_pg = dist.ProcessGroupGloo( + prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT ) - if self.rank == 0: + if self._global_leader: self._clear_joiners_and_leavers() self.global_store.set("status", "running") + # Update rank if needed (otherwise, the next remap will do the lookup incorrectly) - if self.local_rank == 0 and new_uuid_rank != prev_uuid_rank: - self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(new_uuid_rank)) - # Reinitialize sub process groups - self.world_rank = self.rank // self.local_world_size + if old_global_rank != self.world_info.global_rank: + self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(self.world_info.global_rank))