class SharedMemory(shared_memory.SharedMemory):
    """
    POSIX shared memory block that is only removed by an explicit unlink().

    Customization of the SharedMemory is necessary, as the
    'resource_tracker.ResourceTracker' in Python will unlink and remove the
    file if one process fails. Our objective is to ensure that the training
    process does not unlink the shared memory upon failure,
    thereby allowing a new training process to commence utilizing
    the existing shared memory to load checkpoint.

    The base-class ``__init__`` is intentionally not called; the segment is
    opened here directly with ``_posixshmem.shm_open``.

    Note:: We must explicitly unlink the SharedMemory to avoid memory leak.
    """

    # Defaults; enables close() and unlink() to run without errors.
    _name = None
    _fd = -1
    _mmap = None
    _buf = None
    _flags = os.O_RDWR
    _mode = 0o600
    _prepend_leading_slash = True

    def __init__(self, name=None, create=False, size=0):
        """Create a new shared memory block or attach to an existing one.

        Args:
            name: name of the block. May be None only when ``create`` is
                True, in which case a unique name is generated.
            create: True to create a new block (fails if the name already
                exists), False to attach to an existing block.
            size: requested size in bytes when creating; must be non-zero
                if ``create`` is True. When attaching, the size reported by
                ``os.fstat`` on the segment is used instead.

        Raises:
            ValueError: if ``size`` is negative, if ``size`` is 0 while
                creating, or if ``name`` is None while attaching.
            FileExistsError: if creating a named block that already exists.
            FileNotFoundError: if attaching to a block that does not exist.
            OSError: if sizing or mapping the segment fails.
        """
        if not size >= 0:
            raise ValueError("'size' must be a positive integer")
        if create:
            self._flags = os.O_CREAT | os.O_EXCL | os.O_RDWR
            if size == 0:
                raise ValueError(
                    "'size' must be a positive number different from zero"
                )
        if name is None and not self._flags & os.O_EXCL:
            raise ValueError("'name' can only be None if create=True")

        if name is None:
            # Keep generating candidate names until one is not taken yet;
            # O_EXCL makes shm_open fail on a name collision.
            while True:
                name = shared_memory._make_filename()
                try:
                    self._fd = _posixshmem.shm_open(
                        name, self._flags, mode=self._mode
                    )
                except FileExistsError:
                    continue
                self._name = name
                break
        else:
            name = "/" + name if self._prepend_leading_slash else name
            self._fd = _posixshmem.shm_open(name, self._flags, mode=self._mode)
            self._name = name

        try:
            if create and size:
                os.ftruncate(self._fd, size)
            size = os.fstat(self._fd).st_size
            self._mmap = mmap.mmap(self._fd, size)
        except OSError:
            # Release the descriptor right away instead of waiting for
            # garbage collection.
            if self._fd >= 0:
                os.close(self._fd)
                self._fd = -1
            # Only remove the segment if this call created it. A failed
            # attach must leave an existing segment in place so that
            # another process can still use its contents.
            if create:
                self.unlink()
            raise

        self._size = size
        self._buf = memoryview(self._mmap)

    def unlink(self):
        """Requests that the underlying shared memory block be destroyed.

        In order to ensure proper cleanup of resources, unlink should be
        called once (and only once) across all processes which have access
        to the shared memory block."""
        if self._name:
            _posixshmem.shm_unlink(self._name)