Skip to content

Commit

Permalink
Customization of the SharedMemory without ResourceTracker. (#856)
Browse files Browse the repository at this point in the history
* Customization of the SharedMemory without registering in ResourceTracker.

* Polish the docstring.

* Polish the docstring.

* Fix test cases.
  • Loading branch information
workingloong authored Nov 29, 2023
1 parent 788eeb4 commit 07a8adf
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 0 deletions.
75 changes: 75 additions & 0 deletions dlrover/python/common/shared_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import mmap
import os
import pickle
import queue
Expand All @@ -19,8 +20,11 @@
import time
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Dict

import _posixshmem

from .log import default_logger as logger

TMP_DIR = "/tmp"
Expand Down Expand Up @@ -445,3 +449,74 @@ def get(self):
while not self._shared_queue.empty():
time.sleep(0.1)
return self._dict


class SharedMemory(shared_memory.SharedMemory):
"""
Customization of the SharedMemory is necessary, as the
'resource_tracker.ResourceTracker' in Python will unlink and remove the
file if one process fails. Our objective is to ensure that the training
process does not unlink the shared memory upon failure,
hereby allowing a new training process to commence utilizing
the existing shared memory to load checkpoint.
Note:: We must explicitly unlink the SharedMemory to avoid memory leak.
"""

# Defaults; enables close() and unlink() to run without errors.
_name = None
_fd = -1
_mmap = None
_buf = None
_flags = os.O_RDWR
_mode = 0o600
_prepend_leading_slash = True

def __init__(self, name=None, create=False, size=0):
if not size >= 0:
raise ValueError("'size' must be a positive integer")
if create:
self._flags = os.O_CREAT | os.O_EXCL | os.O_RDWR
if size == 0:
raise ValueError(
"'size' must be a positive number different from zero"
)
if name is None and not self._flags & os.O_EXCL:
raise ValueError("'name' can only be None if create=True")

if name is None:
while True:
name = shared_memory._make_filename()
try:
self._fd = _posixshmem.shm_open(
name, self._flags, mode=self._mode
)
except FileExistsError:
continue
self._name = name
break
else:
name = "/" + name if self._prepend_leading_slash else name
self._fd = _posixshmem.shm_open(name, self._flags, mode=self._mode)
self._name = name
try:
if create and size:
os.ftruncate(self._fd, size)
stats = os.fstat(self._fd)
size = stats.st_size
self._mmap = mmap.mmap(self._fd, size)
except OSError:
self.unlink()
raise

self._size = size
self._buf = memoryview(self._mmap)

def unlink(self):
"""Requests that the underlying shared memory block be destroyed.
In order to ensure proper cleanup of resources, unlink should be
called once (and only once) across all processes which have access
to the shared memory block."""
if self._name:
_posixshmem.shm_unlink(self._name)
16 changes: 16 additions & 0 deletions dlrover/python/tests/test_shared_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dlrover.python.common.shared_obj import (
SharedDict,
SharedLock,
SharedMemory,
SharedQueue,
)

Expand Down Expand Up @@ -61,3 +62,18 @@ def test_shared_dict(self):
write_dict.update(new_dict=new_dict)
d = read_dict.get()
self.assertDictEqual(d, new_dict)


class SharedMemoryTest(unittest.TestCase):
def test_unlink(self):
fanme = "test-shm"
with self.assertRaises(ValueError):
shm = SharedMemory(name=fanme, create=True, size=-1)
with self.assertRaises(ValueError):
shm = SharedMemory(name=fanme, create=True, size=0)
shm = SharedMemory(name=fanme, create=True, size=1024)
shm.buf[0:4] = b"abcd"
shm.close()
shm.unlink()
with self.assertRaises(FileNotFoundError):
shm = SharedMemory(name=fanme, create=False)

0 comments on commit 07a8adf

Please sign in to comment.