From bb289574b95811335153c22164239b45fb8d1ecf Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Thu, 30 Nov 2023 15:02:44 +0800 Subject: [PATCH] The agent in the main process async saves the state dict to storage. (#857) * Customization of the SharedMemory without registering in ResourceTracker. * Polish the docstring. * Polish the docstring. * Fix test cases. * The agent in the main process async save to storage. * Format codes. * Refactor the socket communication. * Fix test cases. * Fix test cases. * Add test cases. * Fix the example. * Add test cases. * Format codes. --- .../{shared_obj.py => multi_process.py} | 184 ++-- .../python/elastic_agent/torch/ckpt_saver.py | 206 ++++ .../python/elastic_agent/torch/training.py | 3 + dlrover/python/tests/test_ckpt_saver.py | 81 ++ ...st_shared_obj.py => test_multi_process.py} | 14 +- .../trainer/tests/torch/checkpoint_test.py | 131 +-- dlrover/trainer/torch/elastic/checkpoint.py | 922 ++++++++++-------- examples/pytorch/mnist/cnn_train.py | 1 - examples/pytorch/nanogpt/train.py | 1 - 9 files changed, 945 insertions(+), 598 deletions(-) rename dlrover/python/common/{shared_obj.py => multi_process.py} (77%) create mode 100644 dlrover/python/elastic_agent/torch/ckpt_saver.py create mode 100644 dlrover/python/tests/test_ckpt_saver.py rename dlrover/python/tests/{test_shared_obj.py => test_multi_process.py} (87%) diff --git a/dlrover/python/common/shared_obj.py b/dlrover/python/common/multi_process.py similarity index 77% rename from dlrover/python/common/shared_obj.py rename to dlrover/python/common/multi_process.py index 6985e4ab7..f816c6aef 100644 --- a/dlrover/python/common/shared_obj.py +++ b/dlrover/python/common/multi_process.py @@ -59,10 +59,42 @@ def _create_socket_client(path): """ client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - client.connect(path) + connected = False + for _ in range(3): + try: + client.connect(path) + connected = True + break + except FileNotFoundError: + time.sleep(1) + if not connected: + client.connect(path) return client +def _socket_send(socket: socket.socket, message): + """ + In the protocol, the first 4 bytes is the size of message. + """ + head = len(message).to_bytes(4, "big") + send_data = head + message + socket.send(send_data) + + +def _socket_recv(socket: socket.socket): + """ + In the protocol, the first 4 bytes is the size of message. 
+ """ + recv_data = socket.recv(1024) + head = recv_data[0:4] + message = recv_data[4:] + message_len = int.from_bytes(head, "big") + while len(message) < message_len: + recv_data = socket.recv(1024) + message += recv_data + return message + + @dataclass class SocketRequest(object): """ @@ -115,13 +147,19 @@ class LocalSocketComm(metaclass=ABCMeta): def __init__(self, name="", create=False): self._name = name - self._file = self._create_socket_path() + self._socket_file = self._create_socket_path() self._create = create self._server = None self._init_socket() def __del__(self): - os.unlink(self._file) + self.close() + + def close(self): + try: + os.unlink(self._socket_file) + except FileNotFoundError: + pass def _create_socket_path(self): """Create a file path for the local socket.""" @@ -131,7 +169,7 @@ def _create_socket_path(self): def _init_socket(self): """Initialze a socket server.""" if self._create: - self._server = _create_socket_server(self._file) + self._server = _create_socket_server(self._socket_file) t = threading.Thread( target=self._sync, daemon=True, @@ -145,10 +183,10 @@ def _sync(self): def _request(self, request: SocketRequest): """Create a socket client to requet the shared object.""" - client = _create_socket_client(self._file) - send_data = pickle.dumps(request) - client.send(send_data) - recv_data = client.recv(256) + client = _create_socket_client(self._socket_file) + message = pickle.dumps(request) + _socket_send(client, message) + recv_data = _socket_recv(client) client.close() response: LockAcquireResponse = pickle.loads(recv_data) return response @@ -178,7 +216,7 @@ def _sync(self): while True: connection, _ = self._server.accept() try: - recv_data = connection.recv(256) + recv_data = _socket_recv(connection) msg: SocketRequest = pickle.loads(recv_data) response = LockAcquireResponse() if msg.method == "acquire": @@ -190,7 +228,7 @@ def _sync(self): response = LockAcquireResponse() response.status = ERROR_CODE send_data = pickle.dumps(response) - connection.send(send_data) + _socket_send(connection, send_data) def acquire(self, blocking=True): """ @@ -286,7 +324,7 @@ def _sync(self): while True: connection, _ = self._server.accept() try: - recv_data = connection.recv(256) + recv_data = _socket_recv(connection) msg: SocketRequest = pickle.loads(recv_data) response = SocketResponse() if msg.method == "put": @@ -304,8 +342,8 @@ def _sync(self): except Exception: response = SocketResponse() response.status = ERROR_CODE - send_data = pickle.dumps(response) - connection.send(send_data) + message = pickle.dumps(response) + _socket_send(connection, message) def put(self, obj, block=True, timeout=None): """Put an object into the queue.""" @@ -357,65 +395,56 @@ def empty(self): return False -# The process uses FIFO pipe not the local socket to transfer -# the tensor meta dict. Because, the local socket needs buffers -# at both the sending end and receiving end. The FIFO only need -# one buffer. The size of tensor meta dict may be large. Local socket -# may need double memory buffer size to transfer the dict. -class SharedDict(object): +@dataclass +class DictMessage(SocketResponse): + """ + The response to get the dict of shared dict using local socket. + + Attributes: + meta_dict (dict): the return value to get an obj from a shared queue. """ - A shared dict is used in two processes. One process updates the dict - and another uses the dict. + + meta_dict: object = None + + +class SharedDict(LocalSocketComm): + """ + A shared dict between local processes. 
Args: name (str): the shared dictionary name, one process can update the - dict with the same name of another process by fifo pipe. - create (bool): If ture, the instance reads the dict from the fifo. - Otherwist, the instance writes the dict into the fifo. + dict with the same name of another process by local socket. + create (bool): If ture, the instance will receive the dict from the + sending process to update its dict. """ def __init__(self, name="", create=False): - self._name = name - self._create = create - fname = self.__class__.__name__.lower() + "_" + self._name + ".fifo" - self._file = os.path.join(TMP_DIR, fname) - self._fd = None - - if not os.path.exists(self._file): - os.mkfifo(self._file, 0o666) - if self._create: - self._dict = {} - self._shared_queue = SharedQueue( - name=f"shard_dict_{name}", create=self._create - ) - threading.Thread( - target=self._sync, daemon=True, name=f"{name}-receiver" - ).start() - else: - self._dict = None - self._shared_queue = SharedQueue( - name=f"shard_dict_{name}", create=self._create - ) + super().__init__(name, create) - def __del__(self): - os.unlink(self._file) + self._dict = {} + self._shared_queue = SharedQueue( + name=f"shard_dict_{name}", create=self._create + ) def _sync(self): - if self._create: - self._fd = os.open(self._file, os.O_RDONLY) while True: - recv_bytes = os.read(self._fd, 4) - msg_size = int.from_bytes(recv_bytes, "big") - total_bytes = b"" - while True: - buffer_size = 1024 * 1024 - recv_bytes = os.read(self._fd, buffer_size) - total_bytes += recv_bytes - if len(total_bytes) == msg_size: - break - d = pickle.loads(total_bytes) - self._dict.update(d) - self._shared_queue.get() + connection, _ = self._server.accept() + try: + recv_data = _socket_recv(connection) + msg: SocketRequest = pickle.loads(recv_data) + response = DictMessage() + if msg.method == "update": + self.update(**msg.args) + self._shared_queue.get(1) + elif msg.method == "get": + response = DictMessage() + response.meta_dict = self.get(**msg.args) + response.status = SUCCESS_CODE + except Exception: + response = SocketResponse() + response.status = ERROR_CODE + message = pickle.dumps(response) + _socket_send(connection, message) def update(self, new_dict): """ @@ -424,18 +453,13 @@ def update(self, new_dict): Args: new_dict (dict): a new dict to update. """ - if self._create: - self._dict.update(new_dict) - else: - if not self._fd: - self._fd = os.open(self._file, os.O_WRONLY) - bs = pickle.dumps(new_dict) - bs_size = len(bs) + self._dict.update(new_dict) + if not self._server: + args = {"new_dict": new_dict} + request = SocketRequest(method="update", args=args) try: self._shared_queue.put(1) - # Firstly send the size of the message. - os.write(self._fd, bs_size.to_bytes(4, "big")) - os.write(self._fd, bs) + self._request(request) except Exception: logger.info("The recv processs has breakdown.") @@ -446,9 +470,16 @@ def get(self): If the writing instance sends the dict into the FIFO, the get method should wait for the sync thread to update the dict. 
""" - while not self._shared_queue.empty(): - time.sleep(0.1) - return self._dict + if self._server: + while not self._shared_queue.empty(): + time.sleep(0.1) + return self._dict + else: + request = SocketRequest(method="get", args={}) + response: DictMessage = self._request(request) + if response.status == SUCCESS_CODE: + return response.meta_dict + return {} class SharedMemory(shared_memory.SharedMemory): @@ -519,4 +550,7 @@ def unlink(self): called once (and only once) across all processes which have access to the shared memory block.""" if self._name: - _posixshmem.shm_unlink(self._name) + try: + _posixshmem.shm_unlink(self._name) + except FileNotFoundError: + pass diff --git a/dlrover/python/elastic_agent/torch/ckpt_saver.py b/dlrover/python/elastic_agent/torch/ckpt_saver.py new file mode 100644 index 000000000..6e2c2b13f --- /dev/null +++ b/dlrover/python/elastic_agent/torch/ckpt_saver.py @@ -0,0 +1,206 @@ +# Copyright 2023 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import sys +import threading +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from typing import Callable, List, Mapping, Tuple + +import numpy as np +import torch + +from dlrover.python.common.log import default_logger as logger +from dlrover.python.common.multi_process import ( + SharedDict, + SharedLock, + SharedMemory, + SharedQueue, +) + +CKPT_DIR_PREFIX = "checkpoint-" + +SAVE_STEP_QNAME_PREFIX = "checkpoint_lock_rank_" +CKPT_META_NAME_PREFIX = "checkpoint_meta_local_rank_" +TENSOR_SHM_NAME_PREFIX = "checkpoint_shm_local_rank_" +SHM_LOCK_NAME_PREFIX = "shm_local_rank_" + + +def _init_dir(dir): + if os.path.exists(dir): + shutil.rmtree(dir) + os.makedirs(dir) + + +def convert_torch_dtype_to_numpy(torch_dtype): + dtype_map = { + torch.float32: np.float32, + torch.float: np.float32, + torch.float64: np.float64, + torch.double: np.double, + torch.float16: np.float16, + torch.half: np.half, + torch.uint8: np.uint8, + torch.int8: np.int8, + torch.int16: np.int16, + torch.short: np.short, + torch.int32: np.int32, + torch.int: np.int32, + torch.long: np.int64, + torch.bool: np.dtype("bool"), + } + return dtype_map[torch_dtype] + + +def traverse_state_dict(value: object, visitor: Callable[[object], None]): + """ + Invoke ``visitor`` for each value recursively in ``state_dict``. + """ + if isinstance(value, Mapping): + temp_dict = {} + for k, v in value.items(): + temp_dict[k] = traverse_state_dict(v, visitor) + return temp_dict + elif isinstance(value, List): + temp_list = [] + for _, v in enumerate(value): + temp_list.append(traverse_state_dict(v, visitor)) + return temp_list + else: + return visitor(value) + + +def read_state_dict_from_shm(checkpoint_meta, tensor_shm): + state_dict = traverse_state_dict( + checkpoint_meta, + lambda x: read_tensor_from_buf(x, tensor_shm), + ) + return state_dict + + +def read_tensor_from_buf(value, shm_tensor_buffer): + """ + Read a tensor from the buffer of shared memory. 
+ """ + if isinstance(value, TensorMeta): + data_array = np.frombuffer( + buffer=shm_tensor_buffer.buf, + dtype=value.dtype, + offset=value.offset, + count=value.numel, + ) + value = torch.reshape(torch.tensor(data_array), value.shape) + return value + else: + return value + + +@dataclass +class TensorMeta(object): + shape: Tuple[int] = None # type: ignore + dtype: torch.dtype = None # type: ignore + element_size: int = 0 + numel: int = 0 + offset: int = 0 + + +class SaverFactory(object): + """ + Save the checkpointing state dict from the shared memory + into the storage. + """ + + pass + + +class CheckpointSaver(metaclass=ABCMeta): + @abstractmethod + def _save_shm_to_storage(self): + pass + + @classmethod + def start_async_saving_ckpt(cls): + """ + Start a thread to asynchronously save the checkpoint state dict + from the shared memory into the storage. Firstly, it waits that + the training process notify the saver class to create a saver. + """ + sq = SharedQueue(name="factory", create=True) + + def _save(sq: SharedQueue): + class_name = sq.get() + class_def = getattr(sys.modules[__name__], class_name) + saver: CheckpointSaver = class_def() + saver._save_shm_to_storage() + + threading.Thread( + target=_save, args=(sq,), name="checkpoint-saver", daemon=True + ).start() + + +class NoShardingSaver(CheckpointSaver): + """ + The saver only saves the state dict without sharding + from the shared memory created by local rank 0 to the storage. + """ + + def __init__(self) -> None: + self._checkpoint_dir = "" + self._tensor_shm = None + # Only local rank 0 save the state dict to memory in DDP. + qname = SAVE_STEP_QNAME_PREFIX + str(0) + self._to_save_queue = SharedQueue(name=qname, create=True) + meta_name = CKPT_META_NAME_PREFIX + str(0) + self._shared_ckpt_meta = SharedDict(name=meta_name, create=True) + lock_name = SHM_LOCK_NAME_PREFIX + str(0) + self._shm_lock = SharedLock(name=lock_name, create=True) + self._shm_name = TENSOR_SHM_NAME_PREFIX + str(0) + + def __del__(self): + self.close() + + def close(self): + if self._tensor_shm: + self._tensor_shm.close() + self._tensor_shm.unlink() + self._to_save_queue.close() + self._shared_ckpt_meta.close() + self._shm_lock.close() + + def _save_shm_to_storage(self): + """ + The loop to persist the state dict from the memory + buffer into the storage. + """ + logger.info("Start saving the checkpointing state dict to storage.") + while True: + path = self._to_save_queue.get() + if not self._tensor_shm: + self._tensor_shm = SharedMemory(name=self._shm_name) + self._shm_lock.acquire() + logger.info( + "Save checkpoint from the shared memory " + f"into the storage {path}." 
+ ) + meta_dict = self._shared_ckpt_meta.get() + state_dict = read_state_dict_from_shm(meta_dict, self._tensor_shm) + self._persist_to_storage(state_dict, path) + self._shm_lock.release() + + def _persist_to_storage(self, state_dict, path): + """Persist the checkpoint from CPU memory buffer into the storage.""" + checkpoint_dir = os.path.dirname(path) + _init_dir(checkpoint_dir) + torch.save(state_dict, path) diff --git a/dlrover/python/elastic_agent/torch/training.py b/dlrover/python/elastic_agent/torch/training.py index a341c119b..8723b8a21 100644 --- a/dlrover/python/elastic_agent/torch/training.py +++ b/dlrover/python/elastic_agent/torch/training.py @@ -67,6 +67,7 @@ ) from dlrover.python.elastic_agent.master_client import MasterClient from dlrover.python.elastic_agent.monitor.training import TorchTrainingMonitor +from dlrover.python.elastic_agent.torch.ckpt_saver import CheckpointSaver from dlrover.python.elastic_agent.torch.master_kv_store import MasterKVStore __all__ = ["launch_agent"] @@ -681,6 +682,8 @@ def launch_agent( log_dir=config.log_dir, ) + CheckpointSaver.start_async_saving_ckpt() + shutdown_rdzv = True try: metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg)) diff --git a/dlrover/python/tests/test_ckpt_saver.py b/dlrover/python/tests/test_ckpt_saver.py new file mode 100644 index 000000000..76e7e5870 --- /dev/null +++ b/dlrover/python/tests/test_ckpt_saver.py @@ -0,0 +1,81 @@ +# Copyright 2023 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
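The training.py hunk above adds only a single call, CheckpointSaver.start_async_saving_ckpt(); the saver itself is created lazily once a training process asks for one through the "factory" queue. Below is a minimal sketch of that handshake, collapsed into one process purely for illustration (in a real job the first half runs in the agent and the second half in the worker, mirroring _notify_agent_to_create_saver further down in this patch); the sleep is only there to let the daemon thread finish.

import time

from dlrover.python.common.multi_process import SharedQueue
from dlrover.python.elastic_agent.torch.ckpt_saver import CheckpointSaver

# Agent side: create the "factory" queue and park a daemon thread on it.
CheckpointSaver.start_async_saving_ckpt()

# Worker side: request a saver by class name. The daemon thread resolves the
# name in the ckpt_saver module, builds a NoShardingSaver (which creates the
# shared memory, meta dict, lock and step queue) and enters its
# _save_shm_to_storage() loop.
factory_queue = SharedQueue(name="factory")
factory_queue.put("NoShardingSaver")
factory_queue.close()

time.sleep(1)  # illustrative only: give the daemon thread time to build the saver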
+ +import os +import unittest + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from dlrover.python.common.multi_process import SharedMemory +from dlrover.python.elastic_agent.torch.ckpt_saver import ( + NoShardingSaver, + convert_torch_dtype_to_numpy, + traverse_state_dict, +) + + +def set_torch_dist_env(port): + os.environ["WORLD_SIZE"] = "1" + os.environ["RANK"] = "0" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + + +class SimpleNet(nn.Module): + def __init__(self): + super(SimpleNet, self).__init__() + self.fc1 = nn.Linear(64, 32) + self.fc2 = nn.Linear(32, 10) + self.dropout = nn.Dropout(0.5) + + def forward(self, x): + x = self.fc1(x) + x = F.relu(x) + x = self.dropout(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +class CheckpointSaverTest(unittest.TestCase): + def test_close_saver(self): + saver = NoShardingSaver() + saver._tensor_shm = SharedMemory(name="test", create=True, size=1024) + saver.close() + saver.close() + + def test_traverse_state_dict(self): + def visitor(value): + return value + + model = SimpleNet() + step = 100 + state_dict = dict( + model=model.state_dict(), + step=step, + ) + new_dict = traverse_state_dict(state_dict, visitor) + self.assertEqual(new_dict, state_dict) + + def test_convert_torch_dtype_to_numpy(self): + np_dtype = convert_torch_dtype_to_numpy(torch.float32) + self.assertEqual(np_dtype, np.float32) + + np_dtype = convert_torch_dtype_to_numpy(torch.float) + self.assertEqual(np_dtype, np.float32) + + np_dtype = convert_torch_dtype_to_numpy(torch.int32) + self.assertEqual(np_dtype, np.int32) diff --git a/dlrover/python/tests/test_shared_obj.py b/dlrover/python/tests/test_multi_process.py similarity index 87% rename from dlrover/python/tests/test_shared_obj.py rename to dlrover/python/tests/test_multi_process.py index cf7f426c1..67b7481b8 100644 --- a/dlrover/python/tests/test_shared_obj.py +++ b/dlrover/python/tests/test_multi_process.py @@ -13,7 +13,7 @@ import unittest -from dlrover.python.common.shared_obj import ( +from dlrover.python.common.multi_process import ( SharedDict, SharedLock, SharedMemory, @@ -54,13 +54,15 @@ def test_shared_queue(self): def test_shared_dict(self): name = "test" - read_dict = SharedDict(name=name, create=True) - write_dict = SharedDict(name=name, create=False) + server_dict = SharedDict(name=name, create=True) + client_dict = SharedDict(name=name, create=False) new_dict = {"a": 1, "b": 2} - write_dict.update(new_dict=new_dict) + client_dict.update(new_dict=new_dict) new_dict["a"] = 4 - write_dict.update(new_dict=new_dict) - d = read_dict.get() + client_dict.update(new_dict=new_dict) + d = server_dict.get() + self.assertDictEqual(d, new_dict) + d = client_dict.get() self.assertDictEqual(d, new_dict) diff --git a/dlrover/trainer/tests/torch/checkpoint_test.py b/dlrover/trainer/tests/torch/checkpoint_test.py index 13d5d79e1..3f7b02eaf 100644 --- a/dlrover/trainer/tests/torch/checkpoint_test.py +++ b/dlrover/trainer/tests/torch/checkpoint_test.py @@ -26,10 +26,12 @@ from torch.utils.data import DataLoader, Dataset from dlrover.python.common import grpc -from dlrover.trainer.torch.elastic import checkpoint +from dlrover.python.elastic_agent.torch.ckpt_saver import CheckpointSaver from dlrover.trainer.torch.elastic.checkpoint import ( - AsyncCheckpointEngine, CheckpointManger, + NoShardingCheckpointEngine, + ShardingCheckpointEngine, + _create_shared_memory, _get_latest_checkpoint, ) from 
dlrover.trainer.torch.elastic.sampler import ElasticDistributedSampler @@ -91,37 +93,24 @@ def create_torch_modules(): return model, optimizer, dataloader -class LocalCheckpointManagerTest(unittest.TestCase): - def test_local_save_load(self): - model, optimizer, dataloader = create_torch_modules() - with tempfile.TemporaryDirectory() as tmpdirname: - ckpt_manager = CheckpointManger.init_checkpoint_manager( - model, - optimizer, - dataloader, - tmpdirname, - max_to_keep=2, - ) - ckpt_manager.save(epoch=0, step=10) - ckpt_manager.wait_saving_latest_ckpt() - ckpt_manager.save(epoch=0, step=20) - ckpt_manager.wait_saving_latest_ckpt() - ckpt_manager.save(epoch=0, step=30) - ckpt_manager.wait_saving_latest_ckpt() - ckpt_dirs = os.listdir(tmpdirname) - self.assertEqual(len(ckpt_dirs), 2) +def _wait_async_saving_finished(dir_name, step): + ckpt_path = os.path.join(dir_name, f"checkpoint-{step}/checkpoint.pt") + while True: + if os.path.exists(ckpt_path): + return + time.sleep(0.2) - ckpt_dir = _get_latest_checkpoint(tmpdirname) - expected_dir = os.path.join(tmpdirname, "checkpoint-30") - self.assertEqual(ckpt_dir, expected_dir) - ckpt_manager.load() - self.assertEqual(dataloader.sampler.total_size, 60002) - ckpt_manager._save_engine.close() +class CheckpointManagerTest(unittest.TestCase): + def setUp(self): + CheckpointSaver.start_async_saving_ckpt() + def test_create_shared_memory(self): + shm = _create_shared_memory("test", False) + self.assertIsNone(shm) -class DDPCheckpointManagerTest(unittest.TestCase): def test_ddp_save_load(self): + os.environ["LOCAL_RANK"] = "0" port = grpc.find_free_port() set_torch_dist_env(port) dist.init_process_group(backend="gloo") @@ -135,12 +124,9 @@ def test_ddp_save_load(self): tmpdirname, max_to_keep=2, ) - ckpt_manager.save(epoch=0, step=10) - ckpt_manager.wait_saving_latest_ckpt() - ckpt_manager.save(epoch=0, step=20) - ckpt_manager.wait_saving_latest_ckpt() - ckpt_manager.save(epoch=0, step=30) - ckpt_manager.wait_saving_latest_ckpt() + for step in [10, 20, 30]: + ckpt_manager.save(epoch=0, step=step) + _wait_async_saving_finished(tmpdirname, step) ckpt_dirs = os.listdir(tmpdirname) self.assertEqual(len(ckpt_dirs), 2) @@ -150,26 +136,11 @@ def test_ddp_save_load(self): ckpt_manager.load() self.assertEqual(dataloader.sampler.total_size, 60002) - ckpt_manager._save_engine.close() + ckpt_manager._ckpt_engine.close() dist.destroy_process_group() - -class AsyncCheckpointEngineTest(unittest.TestCase): - def test_traverse_state_dict(self): - def visitor(value): - return value - - model = SimpleNet() - step = 100 - state_dict = dict( - model=model.state_dict(), - step=step, - ) - new_dict = checkpoint.traverse_state_dict(state_dict, visitor) - self.assertEqual(new_dict, state_dict) - def test_create_tensor_meta(self): - engine = AsyncCheckpointEngine("test", 1, 10) + engine = NoShardingCheckpointEngine("test-ckpt") value = torch.rand((10, 10), dtype=torch.float32) meta = engine._create_tensor_meta(value) self.assertEqual(meta.numel, 100) @@ -177,52 +148,9 @@ def test_create_tensor_meta(self): self.assertEqual(meta.offset, 0) self.assertEqual(meta.shape, (10, 10)) self.assertEqual(meta.dtype, np.float32) + engine.close() - def test_convert_torch_dtype_to_numpy(self): - np_dtype = checkpoint._convert_torch_dtype_to_numpy(torch.float32) - self.assertEqual(np_dtype, np.float32) - - np_dtype = checkpoint._convert_torch_dtype_to_numpy(torch.float) - self.assertEqual(np_dtype, np.float32) - - np_dtype = checkpoint._convert_torch_dtype_to_numpy(torch.int32) - 
self.assertEqual(np_dtype, np.int32) - - def test_local_save(self): - model = SimpleNet() - step = 100 - state_dict = dict( - model=model.state_dict(), - step=step, - ) - with self.assertRaises(ValueError): - AsyncCheckpointEngine("test", 0) - with self.assertRaises(ValueError): - AsyncCheckpointEngine("test", 1, 0) - with tempfile.TemporaryDirectory() as tmpdirname: - engine = AsyncCheckpointEngine(tmpdirname, 1, 10) - path = os.path.join(tmpdirname, "checkpoint-10") - engine._persist_to_storage(state_dict, path) - - with tempfile.TemporaryDirectory() as tmpdirname: - engine = AsyncCheckpointEngine(tmpdirname, 1, 10) - engine.save(step, state_dict) - time.sleep(0.2) - restore_state_dict = engine._read_state_dict_from_buf( - engine._shm_tensor_buffer - ) - self.assertEqual(restore_state_dict["step"], 100) - - for key, value in state_dict["model"].items(): - buffer_value = restore_state_dict["model"][key] - self.assertTrue(torch.equal(value, buffer_value)) - self.assertTrue(engine._to_save_step_queue.empty()) - ckpt_dir = _get_latest_checkpoint(tmpdirname) - expected_dir = os.path.join(tmpdirname, "checkpoint-100") - self.assertEqual(ckpt_dir, expected_dir) - engine.close() - - def test_load(self): + def test_load_no_sharding(self): model = SimpleNet() step = 100 state_dict = dict( @@ -231,7 +159,7 @@ def test_load(self): ) with tempfile.TemporaryDirectory() as tmpdirname: - engine = AsyncCheckpointEngine(tmpdirname, 1, 10) + engine = NoShardingCheckpointEngine(tmpdirname) path = os.path.join(tmpdirname, "checkpoint-10/checkpoint.pt") os.makedirs(os.path.dirname(path)) torch.save(state_dict, path) @@ -243,3 +171,12 @@ def test_load(self): for key, value in state_dict["model"].items(): loaded_value = loaded_state_dict["model"][key] self.assertTrue(torch.equal(value, loaded_value)) + engine.close() + + def test_sharding_checkpoint_engine(self): + os.environ["LOCAL_RANK"] = "1" + with tempfile.TemporaryDirectory() as tmpdirname: + engine = ShardingCheckpointEngine(tmpdirname) + self.assertEqual( + engine._shared_ckpt_meta._name, "checkpoint_meta_local_rank_1" + ) diff --git a/dlrover/trainer/torch/elastic/checkpoint.py b/dlrover/trainer/torch/elastic/checkpoint.py index ccb67a4a1..b803c5c0a 100644 --- a/dlrover/trainer/torch/elastic/checkpoint.py +++ b/dlrover/trainer/torch/elastic/checkpoint.py @@ -11,18 +11,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
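Before the engine changes in checkpoint.py below, it helps to see the shared-memory layout they rely on: every tensor of the state dict is copied into one flat SharedMemory segment, and a TensorMeta (dtype, shape, numel, offset) is recorded per tensor so the agent-side saver can rebuild it with read_tensor_from_buf. Here is a minimal round-trip sketch using only the helpers added in ckpt_saver.py above; the segment name and sizes are made up for the example.

import numpy as np
import torch

from dlrover.python.common.multi_process import SharedMemory
from dlrover.python.elastic_agent.torch.ckpt_saver import (
    TensorMeta,
    convert_torch_dtype_to_numpy,
    read_tensor_from_buf,
)

tensor = torch.rand(4, 8)
meta = TensorMeta(
    shape=tuple(tensor.shape),
    dtype=convert_torch_dtype_to_numpy(tensor.dtype),
    element_size=tensor.element_size(),
    numel=tensor.numel(),
    offset=0,  # the first tensor starts at offset 0 of the flat buffer
)

# The name "demo_ckpt_shm" is illustrative; the engine derives the real name
# from the local rank and the size from the accumulated tensor sizes.
shm = SharedMemory(
    name="demo_ckpt_shm", create=True, size=meta.numel * meta.element_size
)

# Writer side (training process): copy the tensor bytes in at the offset.
src = tensor.cpu().numpy()
dst = np.ndarray(src.shape, dtype=src.dtype, buffer=shm.buf, offset=meta.offset)
dst[:] = src[:]

# Reader side (agent process): rebuild the tensor from its meta record.
restored = read_tensor_from_buf(meta, shm)
assert torch.equal(restored, tensor)

shm.close()
shm.unlink()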
-import ctypes -import multiprocessing import os -import random import shutil -import string import time from abc import ABCMeta, abstractmethod -from dataclasses import dataclass from datetime import timedelta -from multiprocessing import shared_memory -from typing import Callable, List, Mapping, Tuple +from typing import List, Mapping import numpy as np import torch @@ -34,17 +28,27 @@ from torch.nn.parallel import DistributedDataParallel as DDP from dlrover.python.common.log import default_logger as logger +from dlrover.python.common.multi_process import ( + SharedDict, + SharedLock, + SharedMemory, + SharedQueue, +) +from dlrover.python.elastic_agent.torch.ckpt_saver import ( + CKPT_META_NAME_PREFIX, + SAVE_STEP_QNAME_PREFIX, + SHM_LOCK_NAME_PREFIX, + TENSOR_SHM_NAME_PREFIX, + TensorMeta, + convert_torch_dtype_to_numpy, + read_state_dict_from_shm, + traverse_state_dict, +) from dlrover.trainer.torch.elastic.sampler import ElasticDistributedSampler CKPT_DIR_PREFIX = "checkpoint-" -def get_random_string(length): - letters = string.ascii_lowercase - result_str = "".join(random.choice(letters) for i in range(length)) - return result_str - - def timer(func): def wrapper(*args, **kwargs): start = time.time() @@ -56,12 +60,6 @@ def wrapper(*args, **kwargs): return wrapper -def _init_dir(dir): - if os.path.exists(dir): - shutil.rmtree(dir) - os.makedirs(dir) - - def _sync(): if dist.is_initialized(): dist.barrier() @@ -84,6 +82,26 @@ def _get_latest_checkpoint(checkpoint_dir): return path +def _create_shared_memory(name, create, size=0): + """ + Create a shared memory. + """ + if not create: + try: + return SharedMemory(name=name) + except FileNotFoundError: + return None + try: + shm = SharedMemory( + name=name, + create=create, + size=size, + ) + except FileExistsError: + shm = SharedMemory(name=name) + return shm + + def _keep_topk_checkpoint(checkpoint_dir, max_to_keep): """Keep top k checkpoints and remove other checkpoints. @@ -107,51 +125,396 @@ def _keep_topk_checkpoint(checkpoint_dir, max_to_keep): shutil.rmtree(dir_name) -def _convert_torch_dtype_to_numpy(torch_dtype): - dtype_map = { - torch.float32: np.float32, - torch.float: np.float32, - torch.float64: np.float64, - torch.double: np.double, - torch.float16: np.float16, - torch.half: np.half, - torch.uint8: np.uint8, - torch.int8: np.int8, - torch.int16: np.int16, - torch.short: np.short, - torch.int32: np.int32, - torch.int: np.int32, - torch.long: np.int64, - torch.bool: np.dtype("bool"), - } - return dtype_map[torch_dtype] - - -def traverse_state_dict(value: object, visitor: Callable[[object], None]): +class CheckpointEngine(metaclass=ABCMeta): """ - Invoke ``visitor`` for each value recursively in ``state_dict``. + The checkpoint engine synchronously writes the state dict into + the shared memory and notify the agent in main process to + asynchronously save the state dict from the shared memory into + the storage. Writing to memory is significantly quicker + than writing to storage. The engine only blocks the training + with a little time. Users can frequently call `save_to_memory` in + the training loop and call `save_to_storage`. + + If the training process fail, the agent in main process can continuely + saves the the state dict from the shared memory into the storage. + + Attributes: + checkpoint_dir (str): the directory to save the temp checkpoint + if the training process fails. 
+ + Examples:: + >>> engine = NoShardingCheckpointEngine( + >>> checkpoint_dir="/tmp/checkpoint/" + >>> ) + >>> for step, data in enumerate(dataloader): + >>> ... + >>> state_dict = model.state_dict() + >>> if step % 5 == 0: + >>> engine.save_to_memory(state_dict, step) + >>> elif step % 100 == 0: + >>> path = f"/tmp/checkpoint/ckpt-{step}.pt" + >>> engine.save_to_storage(state_dict, path, step) + >>> sate_dict = engine.load() """ - if isinstance(value, Mapping): - temp_dict = {} - for k, v in value.items(): - temp_dict[k] = traverse_state_dict(v, visitor) - return temp_dict - elif isinstance(value, List): - temp_list = [] - for _, v in enumerate(value): - temp_list.append(traverse_state_dict(v, visitor)) - return temp_list - else: - return visitor(value) + def __init__(self, checkpoint_dir): + self.checkpoint_dir = checkpoint_dir + if dist.is_initialized(): + self._rank = dist.get_rank() + self._local_rank = int(os.environ["LOCAL_RANK"]) + self._saver_group = dist.new_group( + backend="gloo", timeout=timedelta(seconds=30) + ) + else: + self._rank = 0 + self._local_rank = int(os.getenv("LOCAL_RANK", 0)) + self._saver_group = None + + self._buffer_size = 0 + self._cached_step = 0 + self._meta_dict = dict() + self._shm_name = "" + self._tensor_shm: SharedMemory = None + self._shared_ckpt_meta: SharedDict = None + self._shm_buffer_lock: SharedLock = None + self._to_save_queue: SharedQueue = None + self._notify_agent_to_create_saver() + self._init_shared_objs() + + def __del__(self): + self.close() + + def close(self): + if self._shared_ckpt_meta: + self._shared_ckpt_meta.close() + if self._shm_buffer_lock: + self._shm_buffer_lock.close() + if self._to_save_queue: + self._to_save_queue.close() + if self._tensor_shm: + self._tensor_shm.close() + + @abstractmethod + def _init_shared_objs(self): + """ + Initialize the shared queue, lock and memory to communiate + with the agent in the main process. + """ + pass + + @abstractmethod + def _notify_agent_to_create_saver(self): + """ + Notify the agent in the main process to create a checkpointing + saver to save the state dict from the shared memory into the storage. + """ + pass + + def _create_tensor_meta(self, value: torch.Tensor): + """ + Create a tensor meta of a tensor and compute the total + size of the state dict. + """ + if not torch.is_tensor(value): + return value + dtype = convert_torch_dtype_to_numpy(value.dtype) + meta = TensorMeta( + shape=tuple(value.shape), # type: ignore + dtype=dtype, + element_size=value.element_size(), + numel=value.numel(), + offset=self._buffer_size, + ) + self._buffer_size += value.numel() * value.element_size() + return meta + + def _make_state_dict_buffer(self, state_dict): + """ + Make the shared memory to store the state dict. + """ + self._meta_dict = traverse_state_dict( + state_dict, self._create_tensor_meta + ) + + # Update the meta dict in the main process. + self._shared_ckpt_meta.update(self._meta_dict) + self._tensor_shm = _create_shared_memory( + name=self._shm_name, + create=True, + size=self._buffer_size, + ) + + def _copy_state_dict_to_shm(self, state_dict): + """ + Copy the state dict from CPU memory buffer into the shared memory. 
+ """ + + def _tarverse_copy(value, meta): + if isinstance(value, Mapping): + for k, v in value.items(): + if isinstance(v, (Mapping, List)): + m = meta[k] + _tarverse_copy(v, m) + elif torch.is_tensor(v): + m = meta[k] + self._write_shared_memory(v, m) + else: + meta[k] = v + elif isinstance(value, List): + for i, v in enumerate(value): + if isinstance(v, (Mapping, List)): + m = meta[i] + _tarverse_copy(v, m) + elif torch.is_tensor(v): + m = meta[i] + self._write_shared_memory(v, m) + else: + meta[i] = v + + _tarverse_copy(state_dict, self._meta_dict) + # Update the meta dict in the main process. + self._shared_ckpt_meta.update(self._meta_dict) + + def _write_shared_memory(self, value, meta: TensorMeta): + """ + Write a CPU tensor into the shared memory. + """ + data_array = value.cpu().numpy() + write_array = np.ndarray( + data_array.shape, + dtype=data_array.dtype, + buffer=self._tensor_shm.buf, + offset=meta.offset, + ) + if data_array.shape == (): + write_array.fill(data_array) + else: + write_array[:] = data_array[:] + + @timer + def save_to_memory(self, state_dict, step): + """ + Synchonously Saves the state dict into the shared memory with the main + process. If the agent in the main process is saving the shared memory + into the storage, the method will skip to write the shared memory. + + Args: + state_dict (dict): the state dict of model and optimizer to save. + step (int): the iteration step. + """ + state_dict["step"] = step + if self._tensor_shm is None: + self._make_state_dict_buffer(state_dict) + acquired = self._shm_buffer_lock.acquire(blocking=False) + all_rank_ready = self._check_all_rank_ready(acquired) + if not all_rank_ready: + logger.info( + f"Rank {self._rank} skips the save the checkpoint " + f"in CPU memory since it is saving the latest " + "checkpoint from the CPU memory into the storage." + ) + if acquired: + self._shm_buffer_lock.release() + return + self._copy_state_dict_to_shm(state_dict) + + if acquired: + self._shm_buffer_lock.release() + self._cached_step = step + + def _check_all_rank_ready(self, ready): + """ + Check wether all ranks are ready. + """ + if not self._saver_group: + return ready + value = 0 if ready else 1 + t = torch.tensor([value], dtype=torch.int64) + dist.all_reduce(t, group=self._saver_group) + return t == 0 + + @timer + def save_to_storage(self, state_dict, path, step): + """ + Asynchonously saves the state dict into the storage. It synchonously + saves the state dict into the shared memory and put the path + into a shared queue. The agent in the main process waits for the queue + for save the state dict in the shared memory into the storage. + + Args: + state_dict (dict): the state dict of model and optimizer to save. + path (str): optional, the file path to save the checkpoint. If the + path is not defined, the engine will save the state dict into + the shared memory not the storage. + step (int): the iteration step. + """ + if step > self._cached_step: + self.save_to_memory(state_dict, step) + if path: + self._to_save_queue.put(path) + + def load(self, resume_path=""): + """ + The method firstly try to load the state dict from the shared memory. + If there is no state dict in the shared memory, the method will + load the state dict from the storage. + + Returns: + A dict. + """ + state_dict = self._load_from_shared_memory() + if state_dict: + return state_dict + state_dict = self._load_from_storage(resume_path) + return state_dict + + def _load_from_shared_memory(self): + """ + Load the state dict from the shared memory. 
-@dataclass -class TensorMeta(object): - shape: Tuple[int] = None # type: ignore - dtype: torch.dtype = None # type: ignore - element_size: int = 0 - numel: int = 0 - offset: int = 0 + Returns: + A dict. + """ + if self._tensor_shm is None: + self._tensor_shm = _create_shared_memory( + self._shm_name, + create=False, + ) + if not self._tensor_shm: + return None + meta_dict = self._shared_ckpt_meta.get() + state_dict = read_state_dict_from_shm(meta_dict, self._tensor_shm) + return state_dict + + def _load_from_storage(self, resume_path=""): + """ + Load the state dict from the CPU memory if the state dict is complete + in CPU memory. Otherwise, the function will load the state dict from + the storage. + + Args: + resume_path (str, optional): , If the resume_path is an empty + string, the function will load the latest checkpoint file in + the checkpoint directory. + + Returns: + A dict: + a dictionary containing a whole state of the modules in the + checkpointing file. + """ + if resume_path: + state_dict = torch.load(resume_path) + else: + state_dict = self._load_from_historic_checkpoint() + return state_dict + + def _load_from_historic_checkpoint(self): + """Locd checkpoint from the lastest complete checkpoint.""" + while True: + latest_ckpt_dir = _get_latest_checkpoint(self.checkpoint_dir) + if not latest_ckpt_dir: + return {} + + resume_path = os.path.join(latest_ckpt_dir, "checkpoint.pt") + if not os.path.exists(resume_path): + shutil.rmtree(latest_ckpt_dir) + continue + try: + state_dict = torch.load(resume_path) + logger.info(f"Load checkpoint from {resume_path}") + return state_dict + except Exception: + logger.warning( + f"Fail to load checkpoint from {resume_path}." + " Roll back to the last checkpoint file." + ) + shutil.rmtree(latest_ckpt_dir) + + +class ShardingCheckpointEngine(CheckpointEngine): + """ + The engine to save the sharding model and optimizer state dict + into the memory and storage. We can use it to save the model and optimizer + using FSDP, Zero-3 or Megatron-LM. + """ + + def __init__(self, checkpoint_dir): + super().__init__(checkpoint_dir) + + def _notify_agent_to_create_saver(self): + # TODO: implement the saver in the agent to support saving + # sharding state dict. + pass + + def _init_shared_objs(self): + meta_name = CKPT_META_NAME_PREFIX + str(self._local_rank) + self._shared_ckpt_meta = SharedDict(name=meta_name, create=False) + lock_name = SHM_LOCK_NAME_PREFIX + str(self._local_rank) + self._shm_buffer_lock = SharedLock(name=lock_name, create=False) + qname = SAVE_STEP_QNAME_PREFIX + str(self._local_rank) + self._to_save_queue = SharedQueue(name=qname, create=False) + self._shm_name = TENSOR_SHM_NAME_PREFIX + str(self._local_rank) + + +class NoShardingCheckpointEngine(CheckpointEngine): + """ + The engine saves the model and optimizer state dict without sharding + in a local or DDP job. + """ + + def __init__(self, checkpoint_dir): + super().__init__(checkpoint_dir) + + def _notify_agent_to_create_saver(self): + queue = SharedQueue(name="factory") + queue.put("NoShardingSaver") + queue.close() + + def _init_shared_objs(self): + """ + Initialize the shared object with the main process. + Without model sharding, all ranks share the same shared memory + created by the local rank 0 on a node. 
+ """ + meta_name = CKPT_META_NAME_PREFIX + str(0) + self._shared_ckpt_meta = SharedDict(name=meta_name, create=False) + lock_name = SHM_LOCK_NAME_PREFIX + str(0) + self._shm_buffer_lock = SharedLock(name=lock_name, create=False) + qname = SAVE_STEP_QNAME_PREFIX + str(0) + self._to_save_queue = SharedQueue(name=qname, create=False) + self._shm_name = TENSOR_SHM_NAME_PREFIX + str(0) + + @timer + def save_to_memory(self, state_dict, step): + """ + Synchonously Saves the state dict into the shared memory with the main + process. If the agent in the main process is saving the shared memory + into the storage, the method will skip to write the shared memory. + + Args: + state_dict (dict): the state dict of model and optimizer to save. + step (int): the iteration step. + """ + if self._local_rank == 0: + super().save_to_memory(state_dict, step) + + @timer + def save_to_storage(self, state_dict, path, step): + """ + Asynchonously saves the state dict into the storage. It synchonously + saves the state dict into the shared memory and put the path + into a shared queue. The agent in the main process waits for the queue + for save the state dict in the shared memory into the storage. + + Args: + state_dict (dict): the state dict of model and optimizer to save. + step (int): the iteration step. + path (str): optional, the file path to save the checkpoint. If the + path is not defined, the engine will save the state dict into + the shared memory not the storage. + """ + if self._rank == 0: + super().save_to_storage(state_dict, path, step) class CheckpointManger(metaclass=ABCMeta): @@ -179,7 +542,6 @@ class CheckpointManger(metaclass=ABCMeta): >>> save_storage_interval=5, >>> ) >>> ckpt_manager.save(0, 10) - >>> ckpt_manager.wait_saving_latest_ckpt() >>> ckpt_manger.load() """ @@ -196,20 +558,37 @@ def __init__( self.optimizer = optimizer self.dataloader = dataloader self.checkpoint_dir = checkpoint_dir + self.save_storage_interval = save_storage_interval + self.max_to_keep = max_to_keep if dist.is_initialized(): self._rank = dist.get_rank() + self._local_rank = int(os.environ["LOCAL_RANK"]) else: self._rank = 0 - self._save_engine = AsyncCheckpointEngine( - checkpoint_dir, - save_storage_interval=save_storage_interval, - max_to_keep=max_to_keep, - ) + self._local_rank = int(os.getenv("LOCAL_RANK", 0)) def _log_rank0(self, log): if self._rank == 0: logger.info(log) + def _engine_save(self, engine: CheckpointEngine, state_dict, step): + """ + The each rank has the complete state dict without sharding. Only + the locak rank 0 on each node saves the state dict into the shared + memory and only the rank 0 saves the state dict into the storage. + """ + engine.save_to_memory(state_dict, step) + if step % self.save_storage_interval == 0: + if self._rank == 0: + _keep_topk_checkpoint( + self.checkpoint_dir, self.max_to_keep - 1 + ) + ckpt_dir = os.path.join( + self.checkpoint_dir, f"{CKPT_DIR_PREFIX}{step}" + ) + ckpt_path = os.path.join(ckpt_dir, "checkpoint.pt") + engine.save_to_storage(state_dict, ckpt_path, step) + @abstractmethod def save(self, epoch, step): """ @@ -233,12 +612,6 @@ def load(self, resuming_path=None): """ pass - def wait_saving_latest_ckpt(self): - """ - Wait until the saving process finishes saving the latest checkpoint. 
- """ - self._save_engine.wait() - @classmethod def init_checkpoint_manager( cls, @@ -298,11 +671,16 @@ def __init__( save_storage_interval=1, max_to_keep=1, ): - super().__init__(model, optimizer, dataloader, checkpoint_dir) - self._save_engine = AsyncCheckpointEngine( + super().__init__( + model, + optimizer, + dataloader, + checkpoint_dir, + save_storage_interval, + max_to_keep, + ) + self._ckpt_engine = NoShardingCheckpointEngine( checkpoint_dir, - save_storage_interval=save_storage_interval, - max_to_keep=max_to_keep, ) def save(self, epoch, step): @@ -319,11 +697,19 @@ def save(self, epoch, step): ssd = self.dataloader.sampler.state_dict( step, self.dataloader.batch_size ) - checkpoint = {"model": msd, "optimizer": osd, "sampler": ssd} - self._save_engine.save(step, checkpoint) + checkpoint = { + "model": msd, + "optimizer": osd, + "sampler": ssd, + "epoch": epoch, + } + self._engine_save(self._ckpt_engine, checkpoint, step) def load(self, resuming_path=None): - checkpoint = self._save_engine.load(resuming_path) + """ + Load teh state dict from checkpointing data to the model and optimizer. + """ + checkpoint = self._ckpt_engine.load(resuming_path) if not checkpoint: return sampler = self.dataloader.sampler @@ -335,7 +721,7 @@ def load(self, resuming_path=None): self.optimizer.load_state_dict(optim_state_dict) -class DDPCheckpointManger(CheckpointManger): +class DDPCheckpointManger(LocalCheckpointManger): """ DDPCheckpontManager saves and loads checkpoint states of a DDP model. """ @@ -349,41 +735,20 @@ def __init__( save_storage_interval=1, max_to_keep=1, ): - super().__init__(model, optimizer, dataloader, checkpoint_dir) - self._save_engine = AsyncCheckpointEngine( + super().__init__( + model, + optimizer, + dataloader, checkpoint_dir, - save_storage_interval=save_storage_interval, - max_to_keep=max_to_keep, + save_storage_interval, + max_to_keep, ) - def save(self, epoch, step): + def load(self, resuming_path=None): """ - Save the checkpoint of model, optimizer, dataloader into the directory - `{self.directory}/checkpoint-{step}/checkpoint.pt`. + Load teh state dict from checkpointing data to the model and optimizer. """ - self._log_rank0(f"Save checkpoint of step={step} of epoch={epoch}.") - step = step + epoch * len(self.dataloader) - msd = self.model.state_dict() - osd = self.optimizer.state_dict() - ssd = {} - if isinstance(self.dataloader.sampler, ElasticDistributedSampler): - ssd = self.dataloader.sampler.state_dict( - step, self.dataloader.batch_size - ) - checkpoint = {"model": msd, "optimizer": osd, "sampler": ssd} - self._save_engine.save(step, checkpoint) - - def load(self, resuming_path=None): - checkpoint = self._save_engine.load(resuming_path) - if not checkpoint: - return - sampler = self.dataloader.sampler - if isinstance(sampler, ElasticDistributedSampler): - sampler.load_state_dict(checkpoint.get("sampler", {})) - model_state_dict = checkpoint.get("model", {}) - optim_state_dict = checkpoint.get("optimizer", {}) - self.model.load_state_dict(model_state_dict) - self.optimizer.load_state_dict(optim_state_dict) + super().load(resuming_path=resuming_path) _sync() @@ -392,6 +757,25 @@ class FSDPCheckpointManger(CheckpointManger): DDPCheckpontManager saves and loads checkpoint states of a DDP model. 
""" + def __init__( + self, + model, + optimizer, + dataloader, + checkpoint_dir, + save_storage_interval=1, + max_to_keep=1, + ): + super().__init__( + model, + optimizer, + dataloader, + checkpoint_dir, + save_storage_interval, + max_to_keep, + ) + self._ckpt_engine = NoShardingCheckpointEngine(checkpoint_dir) + def save(self, epoch, step): """ Save the checkpoint of model, optimizer, dataloader into the directory @@ -417,10 +801,18 @@ def save(self, epoch, step): ssd = self.dataloader.sampler.state_dict( step, self.dataloader.batch_size ) - checkpoint = {"model": msd, "optimizer": osd, "sampler": ssd} - self._save_engine.save(step, checkpoint) + checkpoint = { + "model": msd, + "optimizer": osd, + "sampler": ssd, + "epoch": epoch, + } + self._engine_save(self._ckpt_engine, checkpoint, step) def load(self, resuming_path=None): + """ + Load teh state dict from checkpointing data to the model and optimizer. + """ checkpoint = self._save_engine.load(resuming_path) if not checkpoint: return @@ -448,309 +840,3 @@ def load(self, resuming_path=None): self.model.load_state_dict(model_state_dict) self.optimizer.load_state_dict(optim_state_dict) _sync() - - -class AsyncCheckpointEngine(object): - """ - The `save` of the engine only writes the state dict into the shared memory. - A subprocess will asychronously save the state dict into the storage. - Writing to memory is significantly quicker than writing to storage. - The engine.save only block the training with a little time. - - Attributes: - checkpoint_dir: str, the directory to save the checkpoint. - save_storage_interval: int, the interval of iteration steps to save - the model and optimizer states from CPU memory to the storage. - max_to_keep: int, the number of checkpoint files to keep. - - Examples:: - >>> engine = AsyncCheckpointEngine( - >>> checkpoint_dir="/tmp/checkpoint/" - >>> save_storage_interval=5, - >>> max_to_keep=1, - >>> ) - >>> state_dict = model.state_dict() - >>> engine.save(step=100, state_dict=state_dict) - >>> engine.wait() - >>> sate_dict = engine.load() - """ - - def __init__( - self, - checkpoint_dir, - save_storage_interval=1, - max_to_keep=1, - ): - self.checkpoint_dir = checkpoint_dir - self.max_to_keep = max_to_keep - self.save_storage_interval = save_storage_interval - self._manager = multiprocessing.Manager() - self._tensor_meta_buffer = self._manager.dict() - self._shm_tensor_buffer = None - self._shm_buffer_lock = multiprocessing.Lock() - self._buffer_size = 0 - self._latest_step = 0 - self._latest_finish_step = multiprocessing.Value(ctypes.c_int, 0) - self._to_save_step_queue = multiprocessing.Queue(maxsize=1) - if dist.is_initialized(): - self._rank = dist.get_rank() - self._saver_group = dist.new_group( - backend="gloo", timeout=timedelta(seconds=30) - ) - else: - self._rank = 0 - self._saver_group = None - random_name = get_random_string(8) - self._shm_name = f"tensor_buffer_{random_name}_{self._rank}" - self._persist_proc = multiprocessing.Process( - name=f"persist-process-rank-{self._rank}", - target=self._persist_memory_buffer_to_storage, - daemon=True, - ) - self._check_arguments() - self._persist_proc.start() - - def __del__(self): - self.close() - - def close(self): - self._manager.shutdown() - if self._shm_tensor_buffer: - self._shm_tensor_buffer.close() - if self._persist_proc.is_alive(): - self._persist_proc.kill() - - def _check_arguments(self): - if self.max_to_keep == 0: - raise ValueError("max_to_keep cannot be 0.") - if self.save_storage_interval == 0: - raise 
ValueError("save_storage_interval cannot be 0.") - - def _create_tensor_meta(self, value): - """ - Create a tensor meta of a tensor and compute the total - size of the state dict. - """ - if not torch.is_tensor(value): - return value - dtype = _convert_torch_dtype_to_numpy(value.dtype) - meta = TensorMeta( - shape=tuple(value.shape), - dtype=dtype, - element_size=value.element_size(), - numel=value.numel(), - offset=self._buffer_size, - ) - self._buffer_size += value.numel() * value.element_size() - return meta - - def _make_state_dict_buffer(self, state_dict): - """ - Make the shared memory to store the state dict. - """ - meta_dict = traverse_state_dict(state_dict, self._create_tensor_meta) - self._tensor_meta_buffer.update(meta_dict) - self._shm_tensor_buffer = shared_memory.SharedMemory( - create=True, - size=self._buffer_size, - name=self._shm_name, - ) - - def _copy_state_dict_to_shm(self, state_dict): - """ - Copy the state dict from CPU memory buffer into the shared memory. - """ - - def _tarverse_copy(value, meta): - if isinstance(value, Mapping): - for k, v in value.items(): - if isinstance(v, (Mapping, List)): - m = meta[k] - _tarverse_copy(v, m) - elif torch.is_tensor(v): - m = meta[k] - self._write_shared_memory(v, m) - else: - meta[k] = v - elif isinstance(value, List): - for i, v in enumerate(value): - if isinstance(v, (Mapping, List)): - m = meta[i] - _tarverse_copy(v, m) - elif torch.is_tensor(v): - m = meta[i] - self._write_shared_memory(v, m) - else: - meta[i] = v - - _tarverse_copy(state_dict, self._tensor_meta_buffer) - - def _write_shared_memory(self, value, meta: TensorMeta): - """ - Write a CPU tensor into the shared memory. - """ - data_array = value.cpu().numpy() - write_array = np.ndarray( - data_array.shape, - dtype=data_array.dtype, - buffer=self._shm_tensor_buffer.buf, - offset=meta.offset, - ) - if data_array.shape == (): - write_array.fill(data_array) - else: - write_array[:] = data_array[:] - - def _persist_memory_buffer_to_storage(self): - """ - The loop to persist the state dict from the memory - buffer into the storage. - """ - logger.info("Start the process to persist the state dict.") - shm_tensor_buffer = None - while True: - step = self._to_save_step_queue.get() - if not shm_tensor_buffer: - shm_tensor_buffer = shared_memory.SharedMemory( - name=self._shm_name, - ) - with self._shm_buffer_lock: - checkpoint_dir = os.path.join( - self.checkpoint_dir, f"{CKPT_DIR_PREFIX}{step}" - ) - logger.info( - f"Save step-{step} checkpoint from memory " - f"into the storage {checkpoint_dir}." - ) - state_dict = self._read_state_dict_from_buf(shm_tensor_buffer) - self._persist_to_storage(state_dict, checkpoint_dir) - self._latest_finish_step.value = step - - def _read_state_dict_from_buf(self, shm_tensor_buffer): - meta_dict = {} - meta_dict.update(self._tensor_meta_buffer) - state_dict = traverse_state_dict( - meta_dict, - lambda x: self._read_tensor_from_buf(x, shm_tensor_buffer), - ) - return state_dict - - def _read_tensor_from_buf(self, value, shm_tensor_buffer): - """ - Read a tensor from the buffer of shared memory. 
- """ - if isinstance(value, TensorMeta): - data_array = np.frombuffer( - buffer=shm_tensor_buffer.buf, - dtype=value.dtype, - offset=value.offset, - count=value.numel, - ) - value = torch.reshape(torch.tensor(data_array), value.shape) - return value - else: - return value - - def _persist_to_storage(self, state_dict, checkpoint_dir): - """Persist the checkpoint from CPU memory buffer into the storage.""" - if self._rank == 0: - _init_dir(checkpoint_dir) - checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pt") - torch.save(state_dict, checkpoint_path) - _keep_topk_checkpoint(self.checkpoint_dir, self.max_to_keep) - - @timer - def save(self, step, state_dict): - """ - Save the state dict into the CPU memory. If the step is the multiple - of the save_storage_interval, the engine will persist the state dict - from the CPU memory into the storage. - - Args: - step: the iteration step in the training loop. - state_dict: a dictionary. - """ - state_dict["step"] = step - if self._shm_tensor_buffer is None: - self._make_state_dict_buffer(state_dict) - acquired = self._shm_buffer_lock.acquire(block=False) - all_rank_ready = self._check_all_rank_ready(acquired) - if not all_rank_ready: - logger.info( - f"Rank {self._rank} skips the save the checkpoint with " - f"step {step} in CPU memory since it is saving the latest " - "checkpoint from the CPU memory into the storage." - ) - if acquired: - self._shm_buffer_lock.release() - return - self._copy_state_dict_to_shm(state_dict) - if step % self.save_storage_interval == 0: - self._to_save_step_queue.put(step) - self._latest_step = step - if acquired: - self._shm_buffer_lock.release() - - def _check_all_rank_ready(self, ready): - """ - Check wether all ranks are ready. - """ - if not self._saver_group: - return ready - value = 0 if ready else 1 - t = torch.tensor([value], dtype=torch.int64) - dist.all_reduce(t, group=self._saver_group) - return t == 0 - - def load(self, resume_path=""): - """ - Load the state dict from the CPU memory if the state dict is complete - in CPU memory. Otherwise, the function will load the state dict from - the storage. - - Args: - resume_path (str, optional): , If the resume_path is an empty - string, the function will load the latest checkpoint file in - the checkpoint directory. - - Returns: - dict: - a dictionary containing a whole state of the modules in the - checkpointing file. - """ - if resume_path: - state_dict = torch.load(resume_path) - else: - state_dict = self._load_from_historic_checkpoint() - return state_dict - - def _load_from_historic_checkpoint(self): - """Locd checkpoint from the lastest complete checkpoint.""" - while True: - latest_ckpt_dir = _get_latest_checkpoint(self.checkpoint_dir) - if not latest_ckpt_dir: - return {} - - resume_path = os.path.join(latest_ckpt_dir, "checkpoint.pt") - if not os.path.exists(resume_path): - shutil.rmtree(latest_ckpt_dir) - continue - try: - state_dict = torch.load(resume_path) - logger.info(f"Load checkpoint from {resume_path}") - return state_dict - except Exception: - logger.warning( - f"Fail to load checkpoint from {resume_path}." - " Roll back to the last checkpoint file." - ) - shutil.rmtree(latest_ckpt_dir) - - def wait(self): - """ - Wait until the saving process finishes saving the latest checkpoint. 
- """ - while self._latest_step > 0: - if self._latest_finish_step.value == self._latest_step: - break - time.sleep(0.1) diff --git a/examples/pytorch/mnist/cnn_train.py b/examples/pytorch/mnist/cnn_train.py index 267dd4581..8ce9a2683 100644 --- a/examples/pytorch/mnist/cnn_train.py +++ b/examples/pytorch/mnist/cnn_train.py @@ -185,7 +185,6 @@ def train(args): ) log_rank0("Test model after epoch {}".format(epoch)) test(model, device, test_loader) - ckpt_manager.wait_saving_latest_ckpt() if args.save_model: rank = int(os.environ.get("RANK", "0")) save_model(model, args.num_epochs, rank, args.use_fsdp) diff --git a/examples/pytorch/nanogpt/train.py b/examples/pytorch/nanogpt/train.py index 4dbf6119b..33f9c4ba8 100644 --- a/examples/pytorch/nanogpt/train.py +++ b/examples/pytorch/nanogpt/train.py @@ -385,7 +385,6 @@ def train(): if args.save_model: rank = int(os.getenv("RANK", "0")) save_model(model, epoch, rank, args.use_fsdp) - ckpt_manager.wait_saving_latest_ckpt() def save_model(model, epoch, rank, use_fsdp=False):
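Taken together, the patch replaces the old in-process AsyncCheckpointEngine with an agent-side saver plus a training-side engine. Below is a minimal sketch of the intended training-loop usage, assuming the worker was launched by the dlrover elastic agent (so start_async_saving_ckpt() is already waiting on the "factory" queue); the toy model, path and intervals are placeholders, not part of the patch.

import os

import torch
import torch.nn as nn

from dlrover.trainer.torch.elastic.checkpoint import NoShardingCheckpointEngine

checkpoint_dir = "/tmp/checkpoint/"  # illustrative directory
model = nn.Linear(16, 4)             # toy model standing in for the real one
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

engine = NoShardingCheckpointEngine(checkpoint_dir)

# On restart, the engine prefers the shared-memory copy and falls back to the
# latest complete checkpoint file on disk.
restored = engine.load()
if restored:
    model.load_state_dict(restored["model"])
    optimizer.load_state_dict(restored["optimizer"])

for step in range(1, 1001):
    loss = model(torch.rand(8, 16)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    state_dict = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
    if step % 100 == 0:
        # Blocks only for the copy into shared memory; the agent persists the
        # file asynchronously after reading the path from the step queue.
        path = os.path.join(checkpoint_dir, f"checkpoint-{step}/checkpoint.pt")
        engine.save_to_storage(state_dict, path, step)
    elif step % 5 == 0:
        # Cheap in-memory checkpoint between storage checkpoints.
        engine.save_to_memory(state_dict, step)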