diff --git a/README.md b/README.md index 12265211..dc3d2045 100644 --- a/README.md +++ b/README.md @@ -104,3 +104,4 @@ HF_HUB_ETAG_TIMEOUT=500 | `ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS` | Interval in seconds between heartbeats | `2` | | `ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS` | Time in seconds after which a node is considered dead if no heartbeat is received | `10` | | `ZERO_BAND_LIVE_RECO_PORT` | Port number for the live recovery server | `8000` | +| `ZERO_BAND_LIVE_RECO_ADDR` | IP address or hostname for the live recovery server | `localhost` | diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py index 2f7434a5..2b7f62c4 100644 --- a/src/zeroband/comms.py +++ b/src/zeroband/comms.py @@ -22,6 +22,8 @@ LIVE_RECO_PORT = int(os.environ.get("ZERO_BAND_LIVE_RECO_PORT", "8000")) +LIVE_RECO_ADDR = os.environ.get("ZERO_BAND_LIVE_RECO_ADDR", "localhost") + class ElasticDeviceMesh: """A class to manage the process groups for elastic training without restarts. @@ -395,7 +397,7 @@ def init_live_endpoint(self, store: dist.Store): return self.store = dist.PrefixStore("live_reco_adress", store) port = LIVE_RECO_PORT + self.world_info.global_rank - self.store.set(f"adress_{self.world_info.global_unique_id}", f"localhost:{port}") + self.store.set(f"adress_{self.world_info.global_unique_id}", f"{LIVE_RECO_ADDR}:{port}") def get_adress(self, rank: int) -> str: """Get the live recovery adress for a given rank."""