Commit 4938bb4: ruff lint

Jackmin801 committed Sep 27, 2024
1 parent e8332c4 commit 4938bb4
Showing 3 changed files with 40 additions and 29 deletions.
49 changes: 29 additions & 20 deletions src/zeroband/comms.py
@@ -11,8 +11,9 @@


TCPSTORE_TIMEOUT = timedelta(seconds=10)
MAX_JOINERS = 100 # Maximum number of nodes that can join in a single reinit
MAX_LEAVERS = 100 # Maximum number of nodes that can leave in a single reinit
MAX_JOINERS = 100 # Maximum number of nodes that can join in a single reinit
MAX_LEAVERS = 100 # Maximum number of nodes that can leave in a single reinit


def _wait_for_status(store: dist.Store, status: Optional[str] = None) -> str:
while True:
@@ -26,6 +27,7 @@ def _wait_for_status(store: dist.Store, status: Optional[str] = None) -> str:
raise e
time.sleep(0.1)


def _queue_join(store: dist.Store, unique_id: str):
for i in range(MAX_JOINERS):
joiner_id = store.get(f"joiner_{i}").decode("utf-8")
@@ -36,6 +38,7 @@ def _queue_join(store: dist.Store, unique_id: str):
else:
raise RuntimeError("Too many joiners")


def _queue_leave(store: dist.Store, unique_id: str):
for i in range(MAX_LEAVERS):
leaver_id = store.get(f"leaver_{i}").decode("utf-8")
@@ -46,6 +49,7 @@ def _queue_leave(store: dist.Store, unique_id: str):
else:
raise RuntimeError("Too many leavers")


def _get_joiners_and_leavers(store: dist.Store) -> Tuple[List[str], List[str]]:
joiners = []
leavers = []
@@ -62,17 +66,19 @@ def _get_joiners_and_leavers(store: dist.Store) -> Tuple[List[str], List[str]]:
print(f"Joiners: {joiners}, Leavers: {leavers}")
return joiners, leavers


def _clear_joiners_and_leavers(store: dist.Store):
store.set("joiner_0", "null")
store.set("leaver_0", "null")


class ElasticDeviceMesh:
"""A class to manage the process groups for elastic training without restarts.
The way it works is rank 0 coordinates the joining and leaving of nodes.
Rank 0 manages the status to coordinate the creation and recreation of the process groups.
When a node wants to join, rank 0 will set up the store so that all nodes know the new world size and their respective ranks.
Store keys used:
- status: "init", "running", "reinit"
- world_size: The current world size
@@ -96,26 +102,28 @@ def __init__(self):
self._init_unique_id()
if self.world_info.rank == 0:
self.global_pg = self._init_global_pg()
#from torch.distributed.distributed_c10d import _world
#global_rank = int(os.environ["GLOBAL_RANK"])
#_world.pg_group_ranks[self.global_pg] = {i: global_rank for i in range(self.world_info.world_size)}
#_world.pg_map[self.global_pg] = "gloo", self.global_store
# from torch.distributed.distributed_c10d import _world
# global_rank = int(os.environ["GLOBAL_RANK"])
# _world.pg_group_ranks[self.global_pg] = {i: global_rank for i in range(self.world_info.world_size)}
# _world.pg_map[self.global_pg] = "gloo", self.global_store

# Initialize local process group
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
self._device_mesh = init_device_mesh(
"cuda", (self.world_info.nnodes, self.world_info.local_world_size), mesh_dim_names=("internode", "intranode")
"cuda",
(self.world_info.nnodes, self.world_info.local_world_size),
mesh_dim_names=("internode", "intranode"),
)
self.local_pg = self._device_mesh.get_group("intranode")

if self.world_info.rank == 0:
self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}")
else:
self._logger.debug(f"local pg world : {self.local_pg.size()}")

def __del__(self):
dist.destroy_process_group()

def _init_global_pg(self) -> dist.Store:
global_addr = os.environ["GLOBAL_ADDR"]
global_port = int(os.environ["GLOBAL_PORT"])
@@ -138,7 +146,7 @@ def _init_global_pg(self) -> dist.Store:
status = "init"
else:
status = _wait_for_status(store)

if status == "init":
# First time initialization
self.mesh_count = 0
@@ -178,18 +186,17 @@ def _init_unique_id(self):
return
if self.local_rank == 0:
self.unique_id = str(uuid.uuid4())
with open('/tmp/torch_unique_id', 'w') as f:
with open("/tmp/torch_unique_id", "w") as f:
f.write(self.unique_id)
else:
while True:
try:
with open('/tmp/torch_unique_id', 'r') as f:
with open("/tmp/torch_unique_id", "r") as f:
self.unique_id = f.read()
break
except FileNotFoundError:
time.sleep(0.1)


def _resolve_world(self):
"""Set the new world size and ranks for all nodes."""
# Find joiners and leavers
@@ -204,12 +211,12 @@ def _resolve_world(self):
for i, rank in enumerate(live_ranks):
self.global_store.set(f"rank_map_{rank}", str(i * self.local_world_size))
new_world_size = len(live_ranks) * self.local_world_size

# Give joiners new ranks
for joiner_id in joiners:
self.global_store.set(f"rank_{joiner_id}", str(new_world_size))
new_world_size += self.local_world_size

# Update world_size
self.global_store.set("world_size", str(new_world_size))
self.global_store.set("mesh_count", str(self.mesh_count + 1))
@@ -224,7 +231,7 @@ def maybe_reinit_device_mesh(self):
status = self.global_store.get("status").decode("utf-8")
if status == "running":
return

print("Reinitializing device mesh")
dist.destroy_process_group()
print("Destroyed process group")
@@ -240,8 +247,10 @@ def maybe_reinit_device_mesh(self):
self.world_size = int(self.global_store.get("world_size").decode("utf-8"))
self.mesh_count = int(self.global_store.get("mesh_count").decode("utf-8"))
self.prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", self.global_store)
dist.init_process_group(backend="cpu:gloo,cuda:nccl", store=self.prefix_store, rank=self.rank, world_size=self.world_size)

dist.init_process_group(
backend="cpu:gloo,cuda:nccl", store=self.prefix_store, rank=self.rank, world_size=self.world_size
)

if self.rank == 0:
_clear_joiners_and_leavers(self.global_store)
self.global_store.set("status", "running")
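For orientation, the sketch below shows one way the ElasticDeviceMesh class in this diff might be driven from a training loop. It is an illustrative assumption, not part of this commit: only ElasticDeviceMesh(), its world_info/local_pg/global_pg attributes, and maybe_reinit_device_mesh() come from the code above; the loop, step count, and tensor are made up, and as in this commit the global process group lives on rank 0 only.

# Hypothetical driver for ElasticDeviceMesh (sketch only; assumes GLOBAL_ADDR,
# GLOBAL_PORT, GLOBAL_RANK and the usual torchrun variables are already set).
import torch
import torch.distributed as dist
from zeroband.comms import ElasticDeviceMesh

edm = ElasticDeviceMesh()      # creates the global store/pg on rank 0, local pg everywhere
pseudo_grad = torch.ones(4)    # stand-in for an outer-optimizer pseudo-gradient

for outer_step in range(3):
    # ... inner training steps on the local process group would run here ...
    if edm.world_info.rank == 0:
        # Average across the elastic global (gloo) group held by rank 0.
        dist.all_reduce(pseudo_grad, op=dist.ReduceOp.SUM, group=edm.global_pg)
        pseudo_grad /= edm.global_pg.size()
        # Fold any queued joiners/leavers into a fresh process group.
        edm.maybe_reinit_device_mesh()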
10 changes: 6 additions & 4 deletions src/zeroband/diloco.py
@@ -9,12 +9,12 @@
from torch.distributed.fsdp import ShardingStrategy
import torch.distributed as dist

from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

class DilocoConfig(BaseConfig):
outer_lr: float = 0.7
inner_steps: int


class Diloco:
"""
This class implements the diloco algorithm from https://arxiv.org/abs/2311.08105 and https://arxiv.org/abs/2407.07852.
@@ -106,13 +106,15 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]:

for param_name, param in model.named_parameters():
if param.requires_grad:
storage = torch.UntypedStorage.from_file(f"/dev/shm/zeroband/{unique_id}-{param_name}", True, param.data.untyped_storage().size())
storage = torch.UntypedStorage.from_file(
f"/dev/shm/zeroband/{unique_id}-{param_name}", True, param.data.untyped_storage().size()
)
offloaded_param = torch.tensor(storage, dtype=param.dtype, device="cpu")
offloaded_param.as_strided_(size=param.data.size(), stride=param.data.stride())
if self.world_info.rank == 0:
# TODO: Can we async or split the copy among gpus probs overkill?
offloaded_param.copy_(param.data)
offloaded_param.requires_grad = False # TODO: check if we need to set this to True
offloaded_param.requires_grad = False # TODO: check if we need to set this to True
offloaded_params.append(offloaded_param)

dist.barrier()
@@ -130,6 +132,6 @@ def step(self, model: nn.Module):

dist.barrier()
self.sync_inner_model(model)

def __del__(self):
shutil.rmtree("/dev/shm/zeroband", ignore_errors=True)
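The reformatted from_file call above is the heart of the offload: each trainable parameter gets a CPU copy whose storage is a file under /dev/shm, so every process on the node can map the same bytes. Below is a minimal standalone sketch mirroring that pattern; the directory, file name, and tensor shape are illustrative, not the repo's naming scheme (which keys files by a unique id and parameter name).

# Sketch of the /dev/shm-backed parameter offload used in get_offloaded_param.
import os
import torch

os.makedirs("/dev/shm/zeroband", exist_ok=True)
param = torch.randn(4, 8)                        # stand-in for a model parameter

storage = torch.UntypedStorage.from_file(
    "/dev/shm/zeroband/example-param",           # illustrative path
    True,                                        # shared=True: backed by the file, visible to other processes
    param.data.untyped_storage().size(),         # size in bytes
)
offloaded = torch.tensor(storage, dtype=param.dtype, device="cpu")
offloaded.as_strided_(size=param.data.size(), stride=param.data.stride())
offloaded.copy_(param.data)                      # other local ranks mapping the file now see this data
offloaded.requires_grad = False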
10 changes: 5 additions & 5 deletions src/zeroband/testing.py
@@ -3,6 +3,7 @@

TENSOR_SIG_SAMPLE_SIZE = 100


def get_tensor_signature(a: torch.Tensor | torch.nn.Parameter) -> str:
"""
Get the tensor signature
@@ -14,18 +15,17 @@ def get_tensor_signature(a: torch.Tensor | torch.nn.Parameter) -> str:
else:
step_size = a.numel() // TENSOR_SIG_SAMPLE_SIZE
b = a.as_strided(size=(TENSOR_SIG_SAMPLE_SIZE,), stride=(step_size,))
element_str = ''.join([f'{x:.3e}' for x in b])
element_str = "".join([f"{x:.3e}" for x in b])
element_hash = hashlib.md5(element_str.encode("utf-8")).hexdigest()
return f"{a.dtype}{a.shape}{a.stride()}<{element_hash}>"

def get_module_signature(module: torch.nn.Module, compress: bool=True) -> str:

def get_module_signature(module: torch.nn.Module, compress: bool = True) -> str:
"""
Get the module signature
"""
state_dict_sig = {name: get_tensor_signature(param) for name, param in module.named_parameters()}
if compress:
return hashlib.md5(str(state_dict_sig).encode("utf-8")).hexdigest()
else:
return '\n'.join(f"{name}: {sig}" for name, sig in state_dict_sig.items())


return "\n".join(f"{name}: {sig}" for name, sig in state_dict_sig.items())
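
As a quick illustration of how these helpers are meant to be used (the model below is an arbitrary example and assumes the zeroband package is importable): two modules holding identical weights hash to the same compressed signature, which is what makes the signature handy for cross-rank consistency checks.

import torch
from zeroband.testing import get_module_signature

model_a = torch.nn.Linear(8, 4)
model_b = torch.nn.Linear(8, 4)
model_b.load_state_dict(model_a.state_dict())    # make the weights identical

assert get_module_signature(model_a) == get_module_signature(model_b)
print(get_module_signature(model_a, compress=False))   # one "name: signature" line per parameter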
