Skip to content

Commit 3b94a44

Browse files
committed
add env var port live reco
add env var port live reco
1 parent 0dcca60 commit 3b94a44

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,4 @@ HF_HUB_ETAG_TIMEOUT=500
103103
| `ZERO_BAND_GLOBAL_STORE_POLLING_INTERVAL_SECONDS` | Number of seconds between polls to the store when waiting for values | `0.1` |
104104
| `ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS` | Interval in seconds between heartbeats | `2` |
105105
| `ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS` | Time in seconds after which a node is considered dead if no heartbeat is received | `10` |
106+
| `ZERO_BAND_LIVE_RECO_PORT` | Port number for the live recovery server | `8000` |

src/zeroband/checkpoint.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434

3535
SHM_PATH = "/dev/shm/zeroband"
3636

37+
ZERO_BAND_LIVE_RECO_PORT = int(os.environ.get("ZERO_BAND_LIVE_RECO_PORT", "8000"))
38+
3739

3840
@dataclass
3941
class TrainingProgress(Stateful):
@@ -135,7 +137,9 @@ def __init__(
135137
self.async_save_process: list[multiprocessing.Process] = []
136138

137139
if live_ckpt_server:
138-
self.live_server = CkptLiveServer(port=8000 + self.world_info.global_rank, ckpt_path=SHM_PATH)
140+
self.live_server = CkptLiveServer(
141+
port=ZERO_BAND_LIVE_RECO_PORT + self.world_info.global_rank, ckpt_path=SHM_PATH
142+
)
139143

140144
def _init_state(self):
141145
# states can only be stateful object, hence we need to wrap Model and Optimizer
@@ -288,11 +292,11 @@ def download_and_load_ckpt_from_peers(self):
288292
if self.world_info.local_rank == 0:
289293
# only local rank download the ckpt
290294
wget(
291-
source=f"http://localhost:{8000+dest_rank}/latest/diloco_{dest_rank}",
295+
source=f"http://localhost:{ZERO_BAND_LIVE_RECO_PORT+dest_rank}/latest/diloco_{dest_rank}",
292296
destination=path,
293297
)
294298
wget(
295-
source=f"http://localhost:{8000+dest_rank}/latest/diloco_{dest_rank}/.metadata",
299+
source=f"http://localhost:{ZERO_BAND_LIVE_RECO_PORT+dest_rank}/latest/diloco_{dest_rank}/.metadata",
296300
destination=path,
297301
)
298302
dist.barrier()

0 commit comments

Comments
 (0)