Skip to content

Commit

Permalink
allow to pass live reco addr via env var
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Oct 6, 2024
1 parent 93d2c2f commit 2f3d080
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,4 @@ HF_HUB_ETAG_TIMEOUT=500
| `ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS` | Interval in seconds between heartbeats | `2` |
| `ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS` | Time in seconds after which a node is considered dead if no heartbeat is received | `10` |
| `ZERO_BAND_LIVE_RECO_PORT` | Port number for the live recovery server | `8000` |
| `ZERO_BAND_LIVE_RECO_ADDR` | IP Address for the live recovery server | `localhost` |
4 changes: 3 additions & 1 deletion src/zeroband/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

LIVE_RECO_PORT = int(os.environ.get("ZERO_BAND_LIVE_RECO_PORT", "8000"))

LIVE_RECO_ADDR = os.environ.get("ZERO_BAND_LIVE_RECO_ADDR", "localhost")


class ElasticDeviceMesh:
"""A class to manage the process groups for elastic training without restarts.
Expand Down Expand Up @@ -395,7 +397,7 @@ def init_live_endpoint(self, store: dist.Store):
return
self.store = dist.PrefixStore("live_reco_adress", store)
port = LIVE_RECO_PORT + self.world_info.global_rank
self.store.set(f"adress_{self.world_info.global_unique_id}", f"localhost:{port}")
self.store.set(f"adress_{self.world_info.global_unique_id}", f"{LIVE_RECO_ADDR}:{port}")

def get_adress(self, rank: int) -> str:
"""Get the live recovery adress for a given rank."""
Expand Down

0 comments on commit 2f3d080

Please sign in to comment.