Skip to content

Commit 2f3d080

Browse files
committed
allow to pass live reco addr via env var
1 parent 93d2c2f commit 2f3d080

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,4 @@ HF_HUB_ETAG_TIMEOUT=500
104104
| `ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS` | Interval in seconds between heartbeats | `2` |
105105
| `ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS` | Time in seconds after which a node is considered dead if no heartbeat is received | `10` |
106106
| `ZERO_BAND_LIVE_RECO_PORT` | Base port number for the live recovery server; each rank's port is this value plus its global rank | `8000` |
107+
| `ZERO_BAND_LIVE_RECO_ADDR` | Hostname or IP address for the live recovery server | `localhost` |

src/zeroband/comms.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
LIVE_RECO_PORT = int(os.environ.get("ZERO_BAND_LIVE_RECO_PORT", "8000"))
2424

25+
LIVE_RECO_ADDR = os.environ.get("ZERO_BAND_LIVE_RECO_ADDR", "localhost")
26+
2527

2628
class ElasticDeviceMesh:
2729
"""A class to manage the process groups for elastic training without restarts.
@@ -395,7 +397,7 @@ def init_live_endpoint(self, store: dist.Store):
395397
return
396398
self.store = dist.PrefixStore("live_reco_adress", store)
397399
port = LIVE_RECO_PORT + self.world_info.global_rank
398-
self.store.set(f"adress_{self.world_info.global_unique_id}", f"localhost:{port}")
400+
self.store.set(f"adress_{self.world_info.global_unique_id}", f"{LIVE_RECO_ADDR}:{port}")
399401

400402
def get_adress(self, rank: int) -> str:
401403
"""Get the live recovery address for a given rank."""

0 commit comments

Comments
 (0)