Save the shm to storage on the SIGINT signal. (#864)
* Save the shm to storage on SIGINT.

* Register the signal handlers.
workingloong authored Dec 1, 2023
1 parent 84f6c86 commit a0f9d41
Showing 3 changed files with 46 additions and 24 deletions.
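
The change revolves around one pattern: capture whatever SIGINT/SIGTERM handlers are already installed, install new handlers that clean up or flush the checkpoint shared memory, and then delegate to the originals. The sketch below is a minimal, standalone illustration of that chaining pattern, not DLRover code; _DemoSaver is a hypothetical stand-in for the real CheckpointSaver instance.

    import signal


    class _DemoSaver:
        """Hypothetical stand-in exposing only the two calls the handlers need."""

        def save_shm_to_storage(self):
            print("flush the shared-memory checkpoint to storage")

        def close(self):
            print("release the shared memory")


    _saver = _DemoSaver()


    def register_signal_handler():
        # Remember the previously installed handlers so we can delegate to them.
        prev_sigint = signal.getsignal(signal.SIGINT)
        prev_sigterm = signal.getsignal(signal.SIGTERM)

        def _clean_shm(signum, frame):
            # Ctrl-C / SIGINT: only release the shared memory.
            _saver.close()
            if callable(prev_sigint):
                prev_sigint(signum, frame)

        def _save_then_clean(signum, frame):
            # SIGTERM: persist the in-memory checkpoint first, then release it.
            _saver.save_shm_to_storage()
            _saver.close()
            if callable(prev_sigterm):
                prev_sigterm(signum, frame)

        signal.signal(signal.SIGINT, _clean_shm)
        signal.signal(signal.SIGTERM, _save_then_clean)

The callable() checks matter: signal.getsignal() may return signal.SIG_DFL or signal.SIG_IGN, which are integer constants rather than Python callables, so delegating to them blindly would raise a TypeError.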
dlrover/python/elastic_agent/torch/ckpt_saver.py: 36 changes (25 additions, 11 deletions)
@@ -229,6 +229,31 @@ def _save():
    def get_ckpt_saver(cls):
        return cls._saver_instance

    @classmethod
    def register_signal_handler(cls):
        sigint_handler = signal.getsignal(signal.SIGINT)
        sigterm_handler = signal.getsignal(signal.SIGTERM)

        def _clean_shm_handler(signum, frame):
            """Clean the shared memory from ^C and "killall python" etc."""
            if cls._saver_instance:
                cls._saver_instance.close()
            if callable(sigint_handler):
                sigint_handler(signum, frame)

        def _save_shm_before_exiting(signum, frame):
            """Save the state dict from the shared memory into the storage
            before the process exits.
            """
            if cls._saver_instance:
                cls._saver_instance.save_shm_to_storage()
                cls._saver_instance.close()
            if callable(sigterm_handler):
                sigterm_handler(signum, frame)

        signal.signal(signal.SIGINT, _clean_shm_handler)
        signal.signal(signal.SIGTERM, _save_shm_before_exiting)

    @abstractmethod
    def close(self):
        pass
@@ -717,14 +742,3 @@ def save_to_storage(self, state_dict, path, step):
"""
if self._rank == 0:
super().save_to_storage(state_dict, path, step)


def _clean_shm_handler(signum, frame):
"""Clean the shared memory from ^C and "killall python" etc."""
saver: CheckpointSaver = CheckpointSaver.get_ckpt_saver()
if saver:
saver.close()


signal.signal(signal.SIGINT, _clean_shm_handler)
signal.signal(signal.SIGTERM, _clean_shm_handler)
dlrover/python/elastic_agent/torch/training.py: 24 changes (14 additions, 10 deletions)
@@ -487,6 +487,9 @@ def _initialize_workers(self, worker_group):
            if self._config.network_check:
                run_network_check(self._config, self._entrypoint)
            super()._initialize_workers(worker_group)
            # We need to register handler after starting workers because
            # the PContext start_worker will overwrite the handler.
            CheckpointSaver.register_signal_handler()
        except RendezvousOutSyncError:
            logger.info(
                "Exit elastic-training rendezvous when there are "
@@ -541,6 +544,7 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                self._report_failure_to_master(run_result.failures)
                self._save_ckpt_to_storage()
                if self._remaining_failovers > 0:
                    logger.info(
                        f"[{role}] Worker group {state.name}. "
@@ -556,10 +560,20 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
            elif state == WorkerState.HEALTHY:
                # membership changes do not count as retries
                if self._membership_changed(role, rdzv_handler):
                    self._save_ckpt_to_storage()
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")

    def _save_ckpt_to_storage(self):
        """
        The agent can save the checkpointing state dict in the shared
        memory into the storage before restarting training processes.
        """
        saver: CheckpointSaver = CheckpointSaver.get_ckpt_saver()
        if saver:
            saver.save_shm_to_storage()

    def _stop_workers_to_restart(self):
        """
        The agent query from the dlrover job master to check whether to restart
@@ -588,18 +602,8 @@ def _report_failure_to_master(self, failures: Dict[int, ProcessFailure]):
    def _restart_workers(self, worker_group: WorkerGroup):
        self._restart_count += 1
        self._remaining_restarts -= 1
        self._save_ckpt_to_storage()
        super()._restart_workers(worker_group)

    def _save_ckpt_to_storage(self):
        """
        The agent can save the checkpointing state dict in the shared
        memory into the storage before restarting training processes.
        """
        saver: CheckpointSaver = CheckpointSaver.get_ckpt_saver()
        if saver:
            saver.save_shm_to_storage()

    def _membership_changed(self, role, rdzv_handler: RendezvousHandler):
        # Timeout may happen when to query TCPStore.
        if self._config.network_check:
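
The comment added in _initialize_workers above ("the PContext start_worker will overwrite the handler") explains why registration happens only after super()._initialize_workers() returns: whichever code calls signal.signal() last wins. The toy script below illustrates that ordering problem and the fix of re-registering and chaining afterwards; it is an illustration under that assumption, not DLRover or torch code.

    import signal


    def agent_handler(signum, frame):
        print("agent: flush the shared-memory checkpoint to storage")


    def start_workers():
        # Stand-in for the worker launcher: it installs its own SIGTERM handler,
        # silently replacing anything registered earlier.
        signal.signal(signal.SIGTERM, lambda signum, frame: print("worker teardown"))


    signal.signal(signal.SIGTERM, agent_handler)
    start_workers()
    assert signal.getsignal(signal.SIGTERM) is not agent_handler  # agent handler lost

    # Registering again afterwards and chaining the previous handler, as
    # CheckpointSaver.register_signal_handler() does, restores the save-on-SIGTERM
    # behavior without discarding the worker launcher's teardown.
    prev = signal.getsignal(signal.SIGTERM)


    def chained(signum, frame):
        agent_handler(signum, frame)
        if callable(prev):
            prev(signum, frame)


    signal.signal(signal.SIGTERM, chained)
    signal.raise_signal(signal.SIGTERM)  # prints the agent line, then "worker teardown"
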
dlrover/python/tests/test_ckpt_saver.py: 10 changes (7 additions, 3 deletions)
@@ -12,6 +12,7 @@
# limitations under the License.

import os
import signal
import tempfile
import time
import unittest
@@ -28,7 +29,6 @@
    NoShardingCheckpointEngine,
    NoShardingSaver,
    ShardingCheckpointEngine,
    _clean_shm_handler,
    _convert_torch_dtype_to_numpy,
    _traverse_state_dict,
)
@@ -114,11 +114,15 @@ def test_save_to_storage(self):
            self.assertFalse(saving_engine._meta_dict[_WIRTING_SHM])
            saver: CheckpointSaver = CheckpointSaver.get_ckpt_saver()
            saver._tensor_shm = SharedMemory(name=saver._shm_name)
            saver.save_shm_to_storage()
            CheckpointSaver.register_signal_handler()
            handler = signal.getsignal(signal.SIGTERM)
            handler(None, None)
            with self.assertRaises(KeyboardInterrupt):
                handler = signal.getsignal(signal.SIGINT)
                handler(None, None)
            ckpt_files = os.listdir(tmpdir)
            self.assertEqual(len(ckpt_files), 1)
            sq.close()
            _clean_shm_handler(None, None)

    def test_sharding_checkpoint_engine(self):
        os.environ["LOCAL_RANK"] = "1"
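
Two details of the updated test above are easy to miss. It exercises the new handlers by fetching them back with signal.getsignal() and calling them directly instead of delivering real signals, which keeps the test process alive, and it wraps the SIGINT handler in assertRaises(KeyboardInterrupt) because that handler delegates to Python's default SIGINT handler, which raises KeyboardInterrupt. A small sketch of that second point (standard-library behavior, not DLRover code):

    import signal

    # In a fresh interpreter the installed SIGINT handler is Python's default one.
    assert signal.getsignal(signal.SIGINT) is signal.default_int_handler

    try:
        signal.default_int_handler(signal.SIGINT, None)
    except KeyboardInterrupt:
        print("the default SIGINT handler raises KeyboardInterrupt")
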
