|
# Limits on membership changes handled by a single reinit.
MAX_JOINERS = 100  # Maximum number of nodes that can join in a single reinit
MAX_LEAVERS = 100  # Maximum number of nodes that can leave in a single reinit

# Heartbeat timing, in seconds. The timeout spans several heartbeat intervals.
HEARTBEAT_INTERVAL = 2  # Interval in seconds between heartbeats
HEARTBEAT_TIMEOUT = 10  # Time in seconds after which a node is considered dead if no heartbeat is received
19 | 19 |
|
20 | 20 |
|
21 | 21 | class ElasticDeviceMesh:
|
@@ -212,21 +212,29 @@ def _start_heartbeat(self):
|
212 | 212 |
|
def _stop_heartbeat(self):
    """Stop the heartbeat process and publish a deathrattle for this rank.

    A deathrattle is always published first. If a heartbeat stop event
    exists on this object, the event is set, the heartbeat process is
    joined, and the deathrattle is published once more.
    """
    # Publish immediately instead of waiting for the join below.
    self._send_deathrattle()
    # The guard covers the case where the attribute was never set
    # (presumably because the heartbeat was never started -- confirm).
    if hasattr(self, "_heartbeat_stop_event"):
        self._heartbeat_stop_event.set()
        self._heartbeat_process.join()
        # The heartbeat loop may have written one more timestamp between the
        # deathrattle above and noticing the stop event. Repeat the
        # deathrattle after the join so it is the last value stored.
        self._send_deathrattle()
|
218 | 219 |
|
def _heartbeat_loop(self, stop_event):
    """Continuously send heartbeats until stopped.

    Args:
        stop_event: Event-like object exposing ``is_set()`` and
            ``wait(timeout)``; the loop ends once it is set.

    A deathrattle is sent when the loop exits, whether it ended because the
    event was set or because an exception propagated out of the loop.
    """
    try:
        while not stop_event.is_set():
            self._send_heartbeat()
            # Wait on the event rather than sleeping unconditionally: the
            # pause between heartbeats is still HEARTBEAT_INTERVAL, but the
            # loop wakes as soon as the event is set, so a caller joining
            # this loop is not held up for the rest of the interval.
            stop_event.wait(HEARTBEAT_INTERVAL)
    finally:
        self._send_deathrattle()
224 | 228 |
|
def _send_heartbeat(self):
    """Send a heartbeat to the global store.

    Stores the current ``time.time()`` value, converted to a string, under
    the key ``heartbeat_<global_rank>``.
    """
    # NOTE(review): this is a wall-clock timestamp; if readers run on other
    # hosts, clock skew presumably matters -- confirm against the checker.
    current_time = time.time()
    self.global_store.set(f"heartbeat_{self.world_info.global_rank}", str(current_time))
|
229 | 233 |
|
def _send_deathrattle(self):
    """Send a deathrattle to the global store.

    Overwrites this rank's ``heartbeat_<global_rank>`` entry, the same key
    the regular heartbeat writes, with the sentinel string ``"-100"``.
    """
    # NOTE(review): "-100" presumably reads as a timestamp far in the past so
    # the heartbeat check treats this rank as timed out at once -- confirm.
    self.global_store.set(f"heartbeat_{self.world_info.global_rank}", "-100")
| 237 | + |
230 | 238 | def _check_heartbeats(self) -> List[str]:
|
231 | 239 | """Check heartbeats and return a list of nodes that have missed their heartbeats."""
|
232 | 240 | dead_nodes = []
|
|
0 commit comments