From a0f9d413653682a5ec4ac891390b133abdcc3c41 Mon Sep 17 00:00:00 2001
From: Qinlong Wang
Date: Fri, 1 Dec 2023 12:19:55 +0800
Subject: [PATCH] Save the shm to storage at the SIGINT signal. (#864)

* Save the shm to storage when SIGINT is received.

* Register the signal handlers.
---
 .../python/elastic_agent/torch/ckpt_saver.py | 36 +++++++++++++------
 .../python/elastic_agent/torch/training.py   | 24 +++++++------
 dlrover/python/tests/test_ckpt_saver.py      | 10 ++++--
 3 files changed, 46 insertions(+), 24 deletions(-)

diff --git a/dlrover/python/elastic_agent/torch/ckpt_saver.py b/dlrover/python/elastic_agent/torch/ckpt_saver.py
index d0b673a0d..983e28866 100644
--- a/dlrover/python/elastic_agent/torch/ckpt_saver.py
+++ b/dlrover/python/elastic_agent/torch/ckpt_saver.py
@@ -229,6 +229,31 @@ def _save():
     def get_ckpt_saver(cls):
         return cls._saver_instance
 
+    @classmethod
+    def register_signal_handler(cls):
+        sigint_handler = signal.getsignal(signal.SIGINT)
+        sigterm_handler = signal.getsignal(signal.SIGTERM)
+
+        def _clean_shm_handler(signum, frame):
+            """Clean the shared memory from ^C and "killall python" etc."""
+            if cls._saver_instance:
+                cls._saver_instance.close()
+            if callable(sigint_handler):
+                sigint_handler(signum, frame)
+
+        def _save_shm_before_exiting(signum, frame):
+            """Save the state dict from the shared memory into the storage
+            before the process exits.
+            """
+            if cls._saver_instance:
+                cls._saver_instance.save_shm_to_storage()
+                cls._saver_instance.close()
+            if callable(sigterm_handler):
+                sigterm_handler(signum, frame)
+
+        signal.signal(signal.SIGINT, _clean_shm_handler)
+        signal.signal(signal.SIGTERM, _save_shm_before_exiting)
+
     @abstractmethod
     def close(self):
         pass
@@ -717,14 +742,3 @@ def save_to_storage(self, state_dict, path, step):
         """
         if self._rank == 0:
             super().save_to_storage(state_dict, path, step)
-
-
-def _clean_shm_handler(signum, frame):
-    """Clean the shared memory from ^C and "killall python" etc."""
-    saver: CheckpointSaver = CheckpointSaver.get_ckpt_saver()
-    if saver:
-        saver.close()
-
-
-signal.signal(signal.SIGINT, _clean_shm_handler)
-signal.signal(signal.SIGTERM, _clean_shm_handler)
diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py
index 7a500b4b1..cf5d8cb3d 100644
--- a/dlrover/python/elastic_agent/torch/training.py
+++ b/dlrover/python/elastic_agent/torch/training.py
@@ -487,6 +487,9 @@ def _initialize_workers(self, worker_group):
                 if self._config.network_check:
                     run_network_check(self._config, self._entrypoint)
                 super()._initialize_workers(worker_group)
+                # We need to register handler after starting workers because
+                # the PContext start_worker will overwrite the handler.
+                CheckpointSaver.register_signal_handler()
             except RendezvousOutSyncError:
                 logger.info(
                     "Exit elastic-training rendezvous when there are "
@@ -541,6 +544,7 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
                 return run_result
             elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                 self._report_failure_to_master(run_result.failures)
+                self._save_ckpt_to_storage()
                 if self._remaining_failovers > 0:
                     logger.info(
                         f"[{role}] Worker group {state.name}. "
@@ -556,10 +560,20 @@ def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
             elif state == WorkerState.HEALTHY:
                 # membership changes do not count as retries
                 if self._membership_changed(role, rdzv_handler):
+                    self._save_ckpt_to_storage()
                     self._restart_workers(self._worker_group)
             else:
                 raise Exception(f"[{role}] Worker group in {state.name} state")
 
+    def _save_ckpt_to_storage(self):
+        """
+        The agent can save the checkpointing state dict in the shared
+        memory into the storage before restarting training processes.
+        """
+        saver: CheckpointSaver = CheckpointSaver.get_ckpt_saver()
+        if saver:
+            saver.save_shm_to_storage()
+
     def _stop_workers_to_restart(self):
         """
         The agent query from the dlrover job master to check whether to restart
@@ -588,18 +602,8 @@ def _report_failure_to_master(self, failures: Dict[int, ProcessFailure]):
     def _restart_workers(self, worker_group: WorkerGroup):
         self._restart_count += 1
         self._remaining_restarts -= 1
-        self._save_ckpt_to_storage()
         super()._restart_workers(worker_group)
 
-    def _save_ckpt_to_storage(self):
-        """
-        The agent can save the checkpointing state dict in the shared
-        memory into the storage before restarting training processes.
-        """
-        saver: CheckpointSaver = CheckpointSaver.get_ckpt_saver()
-        if saver:
-            saver.save_shm_to_storage()
-
     def _membership_changed(self, role, rdzv_handler: RendezvousHandler):
         # Timeout may happen when to query TCPStore.
         if self._config.network_check:
diff --git a/dlrover/python/tests/test_ckpt_saver.py b/dlrover/python/tests/test_ckpt_saver.py
index c8ae3e834..c73edea82 100644
--- a/dlrover/python/tests/test_ckpt_saver.py
+++ b/dlrover/python/tests/test_ckpt_saver.py
@@ -12,6 +12,7 @@
 # limitations under the License.
 
 import os
+import signal
 import tempfile
 import time
 import unittest
@@ -28,7 +29,6 @@
     NoShardingCheckpointEngine,
     NoShardingSaver,
     ShardingCheckpointEngine,
-    _clean_shm_handler,
     _convert_torch_dtype_to_numpy,
     _traverse_state_dict,
 )
@@ -114,11 +114,15 @@ def test_save_to_storage(self):
         self.assertFalse(saving_engine._meta_dict[_WIRTING_SHM])
         saver: CheckpointSaver = CheckpointSaver.get_ckpt_saver()
         saver._tensor_shm = SharedMemory(name=saver._shm_name)
-        saver.save_shm_to_storage()
+        CheckpointSaver.register_signal_handler()
+        handler = signal.getsignal(signal.SIGTERM)
+        handler(None, None)
+        with self.assertRaises(KeyboardInterrupt):
+            handler = signal.getsignal(signal.SIGINT)
+            handler(None, None)
         ckpt_files = os.listdir(tmpdir)
         self.assertEqual(len(ckpt_files), 1)
         sq.close()
-        _clean_shm_handler(None, None)
 
     def test_sharding_checkpoint_engine(self):
         os.environ["LOCAL_RANK"] = "1"