From c2bfd2013d65f994222d5440edc3f36b125be6a8 Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Sat, 25 Nov 2023 20:06:35 +0800 Subject: [PATCH 1/6] Use socket to implement the shared lock between processes. --- dlrover/python/common/socket.py | 157 ++++++++++++++++++++++++++++ dlrover/python/tests/test_socket.py | 31 ++++++ 2 files changed, 188 insertions(+) create mode 100644 dlrover/python/common/socket.py create mode 100644 dlrover/python/tests/test_socket.py diff --git a/dlrover/python/common/socket.py b/dlrover/python/common/socket.py new file mode 100644 index 000000000..c70360f57 --- /dev/null +++ b/dlrover/python/common/socket.py @@ -0,0 +1,157 @@ +import os +import socket +import threading +import time +import pickle +from dataclasses import dataclass +from typing import Dict + + +SOCKER_TEMP_FILE_DIR = "/tmp/checkpoint/" + +SUCCESS_CODE = "OK" +ERROR_CODE = "ERROR" + + +def _create_socket_server(path): + """ + Create a socket server. + + Args: + path (str): a file path. + """ + server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + path_dir = os.path.dirname(path) + os.makedirs(path_dir, exist_ok=True) + if os.path.exists(path): + os.unlink(path) + server.bind(path) + server.listen(0) + return server + + +def _create_socket_client(path): + """ + Create a socket client. + + Args: + path (str): a file path. + + """ + client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + client.connect(path) + return client + + +@dataclass +class SocketRequest(object): + method: str = "" + args: Dict[str, object] = None # type: ignore + + +@dataclass +class SocketResponse(object): + status: str = "" + + +@dataclass +class LockResponse(SocketResponse): + acquired: bool = False + + +class SharedLock(object): + """ + On a single node, processes can share a lock with an identical name + via socket-based communication. + + Args: + name (str): the lock name, processes can share a lock with an + identical name on a single node. + create (bool): If ture, the lock creates a socket server and a lock. + Otherwise, the lock need to create a socket client to access + the lock. + """ + def __init__(self, name="", create=False): + self._name = name + self._lock = threading.Lock() + fname = "lock_" + self._name + ".sock" + self._path = os.path.join(SOCKER_TEMP_FILE_DIR, fname) + self._create = create + self._server = None + self._init_socket() + + def _init_socket(self): + if self._create: + self._server = _create_socket_server(self._path) + t = threading.Thread( + target=self._sync_lock_status, daemon=True, + ) + t.start() + + def _sync_lock_status(self): + while True: + connection, _ = self._server.accept() + try: + recv_data = connection.recv(256) + print(recv_data) + msg: SocketRequest = pickle.loads(recv_data) + response = LockResponse() + if msg.method == "acquire": + response.acquired = self.acquire(**msg.args) + elif msg.method == "release": + self.release() + response.status = SUCCESS_CODE + except Exception as e: + print(e) + response = LockResponse() + response.status = ERROR_CODE + send_data = pickle.dumps(response) + connection.send(send_data) + + def acquire(self, blocking=True): + """ + Acquire a lock shared by multiple process, blocking or non-blocking. + + Args: + blocking (bool): blocking or non-blocking. + """ + if self._create: + return self._lock.acquire(blocking=blocking) + else: + request = SocketRequest( + method="acquire", + args={"blocking": blocking}, + ) + response = self._request(request) + if response: + return response.acquired + return False + + def release(self): + """ + Release a lock shared by multiple processes. + """ + if self._create: + if self._lock.locked(): + self._lock.release() + else: + request = SocketRequest( + method="release", + args={}, + ) + self._request(request) + + def _request(self, request: SocketRequest): + for _ in range(3): + client = _create_socket_client(self._path) + send_data = pickle.dumps(request) + client.send(send_data) + recv_data = client.recv(256) + client.close() + response: LockResponse = pickle.loads(recv_data) + print(response) + if response.status == SUCCESS_CODE: + return response + else: + time.sleep(1) + continue diff --git a/dlrover/python/tests/test_socket.py b/dlrover/python/tests/test_socket.py new file mode 100644 index 000000000..b621bb274 --- /dev/null +++ b/dlrover/python/tests/test_socket.py @@ -0,0 +1,31 @@ +# Copyright 2023 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from dlrover.python.common.socket import SharedLock + + +class SharedLockTest(unittest.TestCase): + def test_shared_lock(self): + name = "test" + server_lock = SharedLock(name, create=True) + client_lock = SharedLock(name, create=False) + acquired = server_lock.acquire() + self.assertTrue(acquired) + acquired = client_lock.acquire(blocking=False) + self.assertFalse(acquired) + server_lock.release() + acquired = client_lock.acquire(blocking=False) + self.assertTrue(acquired) + client_lock.release() From cd113073f41e710c73f98395060764b93d076ce2 Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Mon, 27 Nov 2023 16:55:49 +0800 Subject: [PATCH 2/6] Rename the file and format it. --- .../python/common/{socket.py => process.py} | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) rename dlrover/python/common/{socket.py => process.py} (85%) diff --git a/dlrover/python/common/socket.py b/dlrover/python/common/process.py similarity index 85% rename from dlrover/python/common/socket.py rename to dlrover/python/common/process.py index c70360f57..16c58c991 100644 --- a/dlrover/python/common/socket.py +++ b/dlrover/python/common/process.py @@ -1,12 +1,24 @@ +# Copyright 2023 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os +import pickle import socket import threading import time -import pickle from dataclasses import dataclass from typing import Dict - SOCKER_TEMP_FILE_DIR = "/tmp/checkpoint/" SUCCESS_CODE = "OK" @@ -36,7 +48,7 @@ def _create_socket_client(path): Args: path (str): a file path. - + """ client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) client.connect(path) @@ -71,6 +83,7 @@ class SharedLock(object): Otherwise, the lock need to create a socket client to access the lock. """ + def __init__(self, name="", create=False): self._name = name self._lock = threading.Lock() @@ -84,7 +97,8 @@ def _init_socket(self): if self._create: self._server = _create_socket_server(self._path) t = threading.Thread( - target=self._sync_lock_status, daemon=True, + target=self._sync_lock_status, + daemon=True, ) t.start() @@ -93,7 +107,6 @@ def _sync_lock_status(self): connection, _ = self._server.accept() try: recv_data = connection.recv(256) - print(recv_data) msg: SocketRequest = pickle.loads(recv_data) response = LockResponse() if msg.method == "acquire": @@ -101,8 +114,7 @@ def _sync_lock_status(self): elif msg.method == "release": self.release() response.status = SUCCESS_CODE - except Exception as e: - print(e) + except Exception: response = LockResponse() response.status = ERROR_CODE send_data = pickle.dumps(response) @@ -149,7 +161,6 @@ def _request(self, request: SocketRequest): recv_data = client.recv(256) client.close() response: LockResponse = pickle.loads(recv_data) - print(response) if response.status == SUCCESS_CODE: return response else: From 0015c00230bcd6fd4a6fa50b19f9ba1379e0a717 Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Mon, 27 Nov 2023 17:02:34 +0800 Subject: [PATCH 3/6] Fix the import path. --- dlrover/python/tests/test_socket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlrover/python/tests/test_socket.py b/dlrover/python/tests/test_socket.py index b621bb274..6d27cef6a 100644 --- a/dlrover/python/tests/test_socket.py +++ b/dlrover/python/tests/test_socket.py @@ -13,7 +13,7 @@ import unittest -from dlrover.python.common.socket import SharedLock +from dlrover.python.common.process import SharedLock class SharedLockTest(unittest.TestCase): From b02ff76eeceb07e08dcd31291afe815dbd8168c2 Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Tue, 28 Nov 2023 16:58:39 +0800 Subject: [PATCH 4/6] Implement shared object across local processes. --- dlrover/python/common/process.py | 168 ------- dlrover/python/common/shared_obj.py | 447 ++++++++++++++++++ .../python/elastic_agent/torch/checkpoint.py | 50 ++ .../{test_socket.py => test_shared_obj.py} | 34 +- 4 files changed, 530 insertions(+), 169 deletions(-) delete mode 100644 dlrover/python/common/process.py create mode 100644 dlrover/python/common/shared_obj.py create mode 100644 dlrover/python/elastic_agent/torch/checkpoint.py rename dlrover/python/tests/{test_socket.py => test_shared_obj.py} (51%) diff --git a/dlrover/python/common/process.py b/dlrover/python/common/process.py deleted file mode 100644 index 16c58c991..000000000 --- a/dlrover/python/common/process.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright 2023 The DLRover Authors. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import pickle -import socket -import threading -import time -from dataclasses import dataclass -from typing import Dict - -SOCKER_TEMP_FILE_DIR = "/tmp/checkpoint/" - -SUCCESS_CODE = "OK" -ERROR_CODE = "ERROR" - - -def _create_socket_server(path): - """ - Create a socket server. - - Args: - path (str): a file path. - """ - server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - path_dir = os.path.dirname(path) - os.makedirs(path_dir, exist_ok=True) - if os.path.exists(path): - os.unlink(path) - server.bind(path) - server.listen(0) - return server - - -def _create_socket_client(path): - """ - Create a socket client. - - Args: - path (str): a file path. - - """ - client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - client.connect(path) - return client - - -@dataclass -class SocketRequest(object): - method: str = "" - args: Dict[str, object] = None # type: ignore - - -@dataclass -class SocketResponse(object): - status: str = "" - - -@dataclass -class LockResponse(SocketResponse): - acquired: bool = False - - -class SharedLock(object): - """ - On a single node, processes can share a lock with an identical name - via socket-based communication. - - Args: - name (str): the lock name, processes can share a lock with an - identical name on a single node. - create (bool): If ture, the lock creates a socket server and a lock. - Otherwise, the lock need to create a socket client to access - the lock. - """ - - def __init__(self, name="", create=False): - self._name = name - self._lock = threading.Lock() - fname = "lock_" + self._name + ".sock" - self._path = os.path.join(SOCKER_TEMP_FILE_DIR, fname) - self._create = create - self._server = None - self._init_socket() - - def _init_socket(self): - if self._create: - self._server = _create_socket_server(self._path) - t = threading.Thread( - target=self._sync_lock_status, - daemon=True, - ) - t.start() - - def _sync_lock_status(self): - while True: - connection, _ = self._server.accept() - try: - recv_data = connection.recv(256) - msg: SocketRequest = pickle.loads(recv_data) - response = LockResponse() - if msg.method == "acquire": - response.acquired = self.acquire(**msg.args) - elif msg.method == "release": - self.release() - response.status = SUCCESS_CODE - except Exception: - response = LockResponse() - response.status = ERROR_CODE - send_data = pickle.dumps(response) - connection.send(send_data) - - def acquire(self, blocking=True): - """ - Acquire a lock shared by multiple process, blocking or non-blocking. - - Args: - blocking (bool): blocking or non-blocking. - """ - if self._create: - return self._lock.acquire(blocking=blocking) - else: - request = SocketRequest( - method="acquire", - args={"blocking": blocking}, - ) - response = self._request(request) - if response: - return response.acquired - return False - - def release(self): - """ - Release a lock shared by multiple processes. - """ - if self._create: - if self._lock.locked(): - self._lock.release() - else: - request = SocketRequest( - method="release", - args={}, - ) - self._request(request) - - def _request(self, request: SocketRequest): - for _ in range(3): - client = _create_socket_client(self._path) - send_data = pickle.dumps(request) - client.send(send_data) - recv_data = client.recv(256) - client.close() - response: LockResponse = pickle.loads(recv_data) - if response.status == SUCCESS_CODE: - return response - else: - time.sleep(1) - continue diff --git a/dlrover/python/common/shared_obj.py b/dlrover/python/common/shared_obj.py new file mode 100644 index 000000000..a689d38e9 --- /dev/null +++ b/dlrover/python/common/shared_obj.py @@ -0,0 +1,447 @@ +# Copyright 2023 The DLRover Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pickle +import queue +import socket +import threading +import time +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass +from typing import Dict + +from .log import default_logger as logger + +TMP_DIR = "/tmp" + +SUCCESS_CODE = "OK" +ERROR_CODE = "ERROR" + + +def _create_socket_server(path): + """ + Create a socket server. + + Args: + path (str): a file path. + """ + server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + path_dir = os.path.dirname(path) + os.makedirs(path_dir, exist_ok=True) + if os.path.exists(path): + os.unlink(path) + server.bind(path) + server.listen(0) + return server + + +def _create_socket_client(path): + """ + Create a socket client. + + Args: + path (str): a file path. + + """ + client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + client.connect(path) + return client + + +@dataclass +class SocketRequest(object): + """ + A socket request. + + Attributes: + method (str): the method name to call. + args (dict): the arguments of the method. + """ + + method: str = "" + args: Dict[str, object] = None # type: ignore + + +@dataclass +class SocketResponse(object): + """ + A socket response. + + Attributes: + status (str): the return code which may be "OK" or "ERROR". + """ + + status: str = "" + + +@dataclass +class LockAcquireResponse(SocketResponse): + """ + A response to acquire a shared lock using local socket. + + Attributes: + acquired (bool): Ture if the lock is acquired. + """ + + acquired: bool = False + + +class LocalSocketComm(metaclass=ABCMeta): + """ + Local socket for processes to communicate. + + Args: + name (str): the instance name which must be unique if multiple + process share a common object using the local socket. + create (bool): If ture, the instance creates a socket server + Otherwise, the instance creates a socket client to access + the shared object. + """ + + def __init__(self, name="", create=False): + self._name = name + self._file = self._create_socket_path() + self._create = create + self._server = None + self._init_socket() + + def __del__(self): + os.unlink(self._file) + + def _create_socket_path(self): + """Create a file path for the local socket.""" + fname = self.__class__.__name__.lower() + "_" + self._name + ".sock" + return os.path.join(TMP_DIR, fname) + + def _init_socket(self): + """Initialze a socket server.""" + if self._create: + self._server = _create_socket_server(self._file) + t = threading.Thread( + target=self._sync, + daemon=True, + ) + t.start() + + @abstractmethod + def _sync(self): + """Synchronize the obj between processes.""" + pass + + def _request(self, request: SocketRequest): + """Create a socket client to requet the shared object.""" + client = _create_socket_client(self._file) + send_data = pickle.dumps(request) + client.send(send_data) + recv_data = client.recv(256) + client.close() + response: LockAcquireResponse = pickle.loads(recv_data) + return response + + +class SharedLock(LocalSocketComm): + """ + On a single node, processes can share a lock with an identical name + via socket-based communication. + + Args: + name (str): the lock name, processes can share a lock with an + identical name on a single node. + create (bool): If ture, the lock creates a socket server and a lock. + Otherwise, the lock need to create a socket client to access + the lock. + """ + + def __init__(self, name="", create=False): + super().__init__(name, create) + if self._create: + self._lock = threading.Lock() + else: + self._lock = None + + def _sync(self): + while True: + connection, _ = self._server.accept() + try: + recv_data = connection.recv(256) + msg: SocketRequest = pickle.loads(recv_data) + response = LockAcquireResponse() + if msg.method == "acquire": + response.acquired = self.acquire(**msg.args) + elif msg.method == "release": + self.release() + response.status = SUCCESS_CODE + except Exception: + response = LockAcquireResponse() + response.status = ERROR_CODE + send_data = pickle.dumps(response) + connection.send(send_data) + + def acquire(self, blocking=True): + """ + Acquire a lock shared by multiple process, blocking or non-blocking. + + Args: + blocking (bool): blocking or non-blocking. + """ + if self._server: + return self._lock.acquire(blocking=blocking) + else: + request = SocketRequest( + method="acquire", + args={"blocking": blocking}, + ) + response = self._request(request) + if response: + return response.acquired + return False + + def release(self): + """ + Release a lock shared by multiple processes. + """ + if self._server: + if self._lock.locked(): + self._lock.release() + else: + request = SocketRequest( + method="release", + args={}, + ) + self._request(request) + + +@dataclass +class QueueGetResponse(SocketResponse): + """ + The response to get an obj from a shared queue using local socket. + + Attributes: + obj (object): the return value to get an obj from a shared queue. + """ + + obj: object = None + + +@dataclass +class QueueSizeResponse(SocketResponse): + """ + The response to get the size of a shared queue using local socket. + + Attributes: + size (int): the size of a queue. + """ + + size: int = 0 + + +@dataclass +class QueueEmptyResponse(SocketResponse): + """ + The response to verify a shared queue is empty. + + Attributes: + empty (bool): True if the queue is empty. + """ + + empty: bool = False + + +class SharedQueue(LocalSocketComm): + """ + On a single node, processes can share a queue with an identical name + via local socket communication. + + Args: + name (str): the queue name, processes can share the queue with an + identical name on a single node. + create (bool): If ture, the instance creates a socket server and a + queue. Otherwise, the instance need to create a local socket + client to access the queue. + """ + + def __init__(self, name="", create=False, maxsize=1): + super().__init__(name, create) + if self._create: + self._queue = queue.Queue(maxsize) + else: + self._queue = None + + def _sync(self): + while True: + connection, _ = self._server.accept() + try: + recv_data = connection.recv(256) + msg: SocketRequest = pickle.loads(recv_data) + response = SocketResponse() + if msg.method == "put": + self.put(**msg.args) + elif msg.method == "get": + response = QueueGetResponse() + response.obj = self.get(**msg.args) + elif msg.method == "qsize": + response = QueueSizeResponse() + response.size = self.qsize() + elif msg.method == "empty": + response = QueueEmptyResponse() + response.empty = self.empty() + response.status = SUCCESS_CODE + except Exception: + response = SocketResponse() + response.status = ERROR_CODE + send_data = pickle.dumps(response) + connection.send(send_data) + + def put(self, obj, block=True, timeout=None): + """Put an object into the queue.""" + if self._server: + self._queue.put(obj, block=block, timeout=timeout) + else: + args = {} + args["obj"] = obj + args["block"] = block + args["timeout"] = timeout + request = SocketRequest(method="put", args=args) + self._request(request) + + def get(self, block=True, timeout=None): + """Get an object from the queue.""" + if self._server: + obj = self._queue.get(block=block, timeout=timeout) + return obj + else: + args = {} + args["block"] = block + args["timeout"] = timeout + request = SocketRequest(method="get", args=args) + response: QueueGetResponse = self._request(request) + if response.status == SUCCESS_CODE: + return response.obj + return None + + def qsize(self): + """Get the size of the queue.""" + if self._server: + return self._queue.qsize() + else: + request = SocketRequest(method="qsize", args={}) + response: QueueSizeResponse = self._request(request) + if response.status == SUCCESS_CODE: + return response.size + return -1 + + def empty(self): + """Verify the queue is empty.""" + if self._server: + return self._queue.empty() + else: + request = SocketRequest(method="empty", args={}) + response: QueueEmptyResponse = self._request(request) + if response.status == SUCCESS_CODE: + return response.empty + return False + + +# The process uses FIFO pipe not the local socket to transfer +# the tensor meta dict. Because, the local socket needs buffers +# at both the sending end and receiving end. The FIFO only need +# one buffer. The size of tensor meta dict may be large. Local socket +# may need double memory buffer size to transfer the dict. +class SharedDict(object): + """ + A shared dict is used in two processes. One process updates the dict + and another uses the dict. + + Args: + name (str): the shared dictionary name, one process can update the + dict with the same name of another process by fifo pipe. + create (bool): If ture, the instance reads the dict from the fifo. + Otherwist, the instance writes the dict into the fifo. + """ + + def __init__(self, name="", create=False): + self._name = name + self._create = create + fname = self.__class__.__name__.lower() + "_" + self._name + ".fifo" + self._file = os.path.join(TMP_DIR, fname) + self._fd = None + + if not os.path.exists(self._file): + os.mkfifo(self._file, 0o666) + if self._create: + self._dict = {} + self._shared_queue = SharedQueue( + name=f"shard_dict_{name}", create=self._create + ) + threading.Thread( + target=self._sync, daemon=True, name=f"{name}-receiver" + ).start() + else: + self._dict = None + self._shared_queue = SharedQueue( + name=f"shard_dict_{name}", create=self._create + ) + + def __del__(self): + os.unlink(self._file) + + def _sync(self): + if self._create: + self._fd = os.open(self._file, os.O_RDONLY) + while True: + recv_bytes = os.read(self._fd, 4) + msg_size = int.from_bytes(recv_bytes, "big") + total_bytes = b"" + while True: + buffer_size = 1024 * 1024 + recv_bytes = os.read(self._fd, buffer_size) + total_bytes += recv_bytes + if len(total_bytes) == msg_size: + break + d = pickle.loads(total_bytes) + self._dict.update(d) + self._shared_queue.get() + + def update(self, new_dict): + """ + Update the shared Dict with a new Dict. + + Args: + new_dict (dict): a new dict to update. + """ + if self._create: + self._dict.update(new_dict) + else: + if not self._fd: + self._fd = os.open(self._file, os.O_WRONLY) + bs = pickle.dumps(new_dict) + bs_size = len(bs) + try: + # Firstly send the size of the message. + os.write(self._fd, bs_size.to_bytes(4, "big")) + os.write(self._fd, bs) + self._shared_queue.put(1) + except Exception: + logger.info("The recv processs has breakdown.") + + def get(self): + """ + Returns a Python Dict from the shared Dict. + + If the writing instance sends the dict into the FIFO, the get method + should wait for the sync thread to update the dict. + """ + while not self._shared_queue.empty(): + time.sleep(0.1) + return self._dict diff --git a/dlrover/python/elastic_agent/torch/checkpoint.py b/dlrover/python/elastic_agent/torch/checkpoint.py new file mode 100644 index 000000000..05509d1b4 --- /dev/null +++ b/dlrover/python/elastic_agent/torch/checkpoint.py @@ -0,0 +1,50 @@ +import time +import socket +import os +import pickle +from typing import Dict +from dlrover.python.common.socket import CheckpointMeta + + +class CheckpointBuffer(object): + """ + + Args: + num_proc (int): Number of workers on the node. + """ + def __init__(self, num_proc): + self._num_proc = num_proc + self._rank_ckpt_metas: Dict[int, CheckpointMeta] = dict() + + def _create_socket_server(self): + server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + if os.path.exists("/tmp/checkpoint.sock"): + os.unlink("/tmp/checkpoint.sock") + server.bind("/tmp/checkpoint.sock") + server.listen(0) + while True: + try: + connection, _ = server.accept() + recv = connection.recv(1024) + self._deserialize(recv) + connection.send(b"OK") + except Exception: + connection.close() + break + + def _deserialize(self, buffer): + obj = pickle.loads(buffer) + if isinstance(obj, CheckpointMeta): + self._rank_ckpt_metas[obj.rank] = obj + + def _wait_training_process_init_model(self): + while True: + if len(self._rank_ckpt_metas) == self._num_proc: + break + time.sleep(1) + + def _create_shared_memory_buffer(self): + pass + + def save_to_storage(self): + pass \ No newline at end of file diff --git a/dlrover/python/tests/test_socket.py b/dlrover/python/tests/test_shared_obj.py similarity index 51% rename from dlrover/python/tests/test_socket.py rename to dlrover/python/tests/test_shared_obj.py index 6d27cef6a..c6246897d 100644 --- a/dlrover/python/tests/test_socket.py +++ b/dlrover/python/tests/test_shared_obj.py @@ -13,7 +13,11 @@ import unittest -from dlrover.python.common.process import SharedLock +from dlrover.python.common.shared_obj import ( + SharedDict, + SharedLock, + SharedQueue, +) class SharedLockTest(unittest.TestCase): @@ -29,3 +33,31 @@ def test_shared_lock(self): acquired = client_lock.acquire(blocking=False) self.assertTrue(acquired) client_lock.release() + + def test_shared_queue(self): + name = "test" + server_queue = SharedQueue(name, create=True) + client_queue = SharedQueue(name, create=False) + server_queue.put(2) + qsize = server_queue.qsize() + self.assertEqual(qsize, 1) + value = server_queue.get() + self.assertEqual(value, 2) + client_queue.put(3) + qsize = client_queue.qsize() + self.assertEqual(qsize, 1) + qsize = client_queue.qsize() + self.assertEqual(qsize, 1) + value = client_queue.get() + self.assertEqual(value, 3) + + def test_shared_dict(self): + name = "test" + read_dict = SharedDict(name=name, create=True) + write_dict = SharedDict(name=name, create=False) + new_dict = {"a": 1, "b": 2} + write_dict.update(new_dict=new_dict) + new_dict["a"] = 4 + write_dict.update(new_dict=new_dict) + d = read_dict.get() + self.assertDictEqual(d, new_dict) From 09933d6a286550d1206a5eec78de0fac7880350b Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Tue, 28 Nov 2023 17:04:28 +0800 Subject: [PATCH 5/6] Remove test files. --- .../python/elastic_agent/torch/checkpoint.py | 50 ------------------- 1 file changed, 50 deletions(-) delete mode 100644 dlrover/python/elastic_agent/torch/checkpoint.py diff --git a/dlrover/python/elastic_agent/torch/checkpoint.py b/dlrover/python/elastic_agent/torch/checkpoint.py deleted file mode 100644 index 05509d1b4..000000000 --- a/dlrover/python/elastic_agent/torch/checkpoint.py +++ /dev/null @@ -1,50 +0,0 @@ -import time -import socket -import os -import pickle -from typing import Dict -from dlrover.python.common.socket import CheckpointMeta - - -class CheckpointBuffer(object): - """ - - Args: - num_proc (int): Number of workers on the node. - """ - def __init__(self, num_proc): - self._num_proc = num_proc - self._rank_ckpt_metas: Dict[int, CheckpointMeta] = dict() - - def _create_socket_server(self): - server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - if os.path.exists("/tmp/checkpoint.sock"): - os.unlink("/tmp/checkpoint.sock") - server.bind("/tmp/checkpoint.sock") - server.listen(0) - while True: - try: - connection, _ = server.accept() - recv = connection.recv(1024) - self._deserialize(recv) - connection.send(b"OK") - except Exception: - connection.close() - break - - def _deserialize(self, buffer): - obj = pickle.loads(buffer) - if isinstance(obj, CheckpointMeta): - self._rank_ckpt_metas[obj.rank] = obj - - def _wait_training_process_init_model(self): - while True: - if len(self._rank_ckpt_metas) == self._num_proc: - break - time.sleep(1) - - def _create_shared_memory_buffer(self): - pass - - def save_to_storage(self): - pass \ No newline at end of file From 82f33fbd3aafec85fe2cb9b75127afab555358d4 Mon Sep 17 00:00:00 2001 From: Qinlong Wang Date: Tue, 28 Nov 2023 20:48:25 +0800 Subject: [PATCH 6/6] Put a value in the queue before writing.. --- dlrover/python/common/shared_obj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dlrover/python/common/shared_obj.py b/dlrover/python/common/shared_obj.py index a689d38e9..1bacb1d9d 100644 --- a/dlrover/python/common/shared_obj.py +++ b/dlrover/python/common/shared_obj.py @@ -428,10 +428,10 @@ def update(self, new_dict): bs = pickle.dumps(new_dict) bs_size = len(bs) try: + self._shared_queue.put(1) # Firstly send the size of the message. os.write(self._fd, bs_size.to_bytes(4, "big")) os.write(self._fd, bs) - self._shared_queue.put(1) except Exception: logger.info("The recv processs has breakdown.")