Skip to content

Commit

Permalink
The agent in the main process async saves the state dict to storage. (#…
Browse files Browse the repository at this point in the history
…857)

* Customization of the SharedMemory without registering in ResourceTracker.

* Polish the docstring.

* Polish the docstring.

* Fix test cases.

* The agent in the main process async save to storage.

* Format codes.

* Refactor the socket communication.

* Fix test cases.

* Fix test cases.

* Add test cases.

* Fix the example.

* Add test cases.

* Format codes.
  • Loading branch information
workingloong authored Nov 30, 2023
1 parent 07a8adf commit bb28957
Show file tree
Hide file tree
Showing 9 changed files with 945 additions and 598 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,42 @@ def _create_socket_client(path):
"""
client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
client.connect(path)
connected = False
for _ in range(3):
try:
client.connect(path)
connected = True
break
except FileNotFoundError:
time.sleep(1)
if not connected:
client.connect(path)
return client


def _socket_send(socket: socket.socket, message):
"""
In the protocol, the first 4 bytes is the size of message.
"""
head = len(message).to_bytes(4, "big")
send_data = head + message
socket.send(send_data)


def _socket_recv(socket: socket.socket):
"""
In the protocol, the first 4 bytes is the size of message.
"""
recv_data = socket.recv(1024)
head = recv_data[0:4]
message = recv_data[4:]
message_len = int.from_bytes(head, "big")
while len(message) < message_len:
recv_data = socket.recv(1024)
message += recv_data
return message


@dataclass
class SocketRequest(object):
"""
Expand Down Expand Up @@ -115,13 +147,19 @@ class LocalSocketComm(metaclass=ABCMeta):

def __init__(self, name="", create=False):
self._name = name
self._file = self._create_socket_path()
self._socket_file = self._create_socket_path()
self._create = create
self._server = None
self._init_socket()

def __del__(self):
os.unlink(self._file)
self.close()

def close(self):
    """
    Remove the local socket file of this instance.

    The method is safe to call more than once: a socket file that has
    already been removed is ignored. It is also invoked from ``__del__``,
    which may run on an instance whose ``__init__`` failed before the
    socket path was assigned, so a missing attribute is tolerated too.
    """
    socket_file = getattr(self, "_socket_file", None)
    if not socket_file:
        return
    try:
        os.unlink(socket_file)
    except FileNotFoundError:
        # Another instance or an earlier call has removed the file.
        pass

def _create_socket_path(self):
"""Create a file path for the local socket."""
Expand All @@ -131,7 +169,7 @@ def _create_socket_path(self):
def _init_socket(self):
"""Initialze a socket server."""
if self._create:
self._server = _create_socket_server(self._file)
self._server = _create_socket_server(self._socket_file)
t = threading.Thread(
target=self._sync,
daemon=True,
Expand All @@ -145,10 +183,10 @@ def _sync(self):

def _request(self, request: SocketRequest):
"""Create a socket client to requet the shared object."""
client = _create_socket_client(self._file)
send_data = pickle.dumps(request)
client.send(send_data)
recv_data = client.recv(256)
client = _create_socket_client(self._socket_file)
message = pickle.dumps(request)
_socket_send(client, message)
recv_data = _socket_recv(client)
client.close()
response: LockAcquireResponse = pickle.loads(recv_data)
return response
Expand Down Expand Up @@ -178,7 +216,7 @@ def _sync(self):
while True:
connection, _ = self._server.accept()
try:
recv_data = connection.recv(256)
recv_data = _socket_recv(connection)
msg: SocketRequest = pickle.loads(recv_data)
response = LockAcquireResponse()
if msg.method == "acquire":
Expand All @@ -190,7 +228,7 @@ def _sync(self):
response = LockAcquireResponse()
response.status = ERROR_CODE
send_data = pickle.dumps(response)
connection.send(send_data)
_socket_send(connection, send_data)

def acquire(self, blocking=True):
"""
Expand Down Expand Up @@ -286,7 +324,7 @@ def _sync(self):
while True:
connection, _ = self._server.accept()
try:
recv_data = connection.recv(256)
recv_data = _socket_recv(connection)
msg: SocketRequest = pickle.loads(recv_data)
response = SocketResponse()
if msg.method == "put":
Expand All @@ -304,8 +342,8 @@ def _sync(self):
except Exception:
response = SocketResponse()
response.status = ERROR_CODE
send_data = pickle.dumps(response)
connection.send(send_data)
message = pickle.dumps(response)
_socket_send(connection, message)

def put(self, obj, block=True, timeout=None):
"""Put an object into the queue."""
Expand Down Expand Up @@ -357,65 +395,56 @@ def empty(self):
return False


# The process uses FIFO pipe not the local socket to transfer
# the tensor meta dict. Because, the local socket needs buffers
# at both the sending end and receiving end. The FIFO only need
# one buffer. The size of tensor meta dict may be large. Local socket
# may need double memory buffer size to transfer the dict.
class SharedDict(object):
@dataclass
class DictMessage(SocketResponse):
"""
The response to get the dict of shared dict using local socket.
Attributes:
meta_dict (dict): the dict returned when getting the content of a shared dict.
"""
A shared dict is used in two processes. One process updates the dict
and another uses the dict.

meta_dict: object = None


class SharedDict(LocalSocketComm):
"""
A shared dict between local processes.
Args:
name (str): the shared dictionary name, one process can update the
dict with the same name of another process by fifo pipe.
create (bool): If ture, the instance reads the dict from the fifo.
Otherwist, the instance writes the dict into the fifo.
dict with the same name of another process by local socket.
create (bool): If True, the instance will receive the dict from the
sending process to update its dict.
"""

def __init__(self, name="", create=False):
self._name = name
self._create = create
fname = self.__class__.__name__.lower() + "_" + self._name + ".fifo"
self._file = os.path.join(TMP_DIR, fname)
self._fd = None

if not os.path.exists(self._file):
os.mkfifo(self._file, 0o666)
if self._create:
self._dict = {}
self._shared_queue = SharedQueue(
name=f"shard_dict_{name}", create=self._create
)
threading.Thread(
target=self._sync, daemon=True, name=f"{name}-receiver"
).start()
else:
self._dict = None
self._shared_queue = SharedQueue(
name=f"shard_dict_{name}", create=self._create
)
super().__init__(name, create)

def __del__(self):
os.unlink(self._file)
self._dict = {}
self._shared_queue = SharedQueue(
name=f"shard_dict_{name}", create=self._create
)

def _sync(self):
if self._create:
self._fd = os.open(self._file, os.O_RDONLY)
while True:
recv_bytes = os.read(self._fd, 4)
msg_size = int.from_bytes(recv_bytes, "big")
total_bytes = b""
while True:
buffer_size = 1024 * 1024
recv_bytes = os.read(self._fd, buffer_size)
total_bytes += recv_bytes
if len(total_bytes) == msg_size:
break
d = pickle.loads(total_bytes)
self._dict.update(d)
self._shared_queue.get()
connection, _ = self._server.accept()
try:
recv_data = _socket_recv(connection)
msg: SocketRequest = pickle.loads(recv_data)
response = DictMessage()
if msg.method == "update":
self.update(**msg.args)
self._shared_queue.get(1)
elif msg.method == "get":
response = DictMessage()
response.meta_dict = self.get(**msg.args)
response.status = SUCCESS_CODE
except Exception:
response = SocketResponse()
response.status = ERROR_CODE
message = pickle.dumps(response)
_socket_send(connection, message)

def update(self, new_dict):
"""
Expand All @@ -424,18 +453,13 @@ def update(self, new_dict):
Args:
new_dict (dict): a new dict to update.
"""
if self._create:
self._dict.update(new_dict)
else:
if not self._fd:
self._fd = os.open(self._file, os.O_WRONLY)
bs = pickle.dumps(new_dict)
bs_size = len(bs)
self._dict.update(new_dict)
if not self._server:
args = {"new_dict": new_dict}
request = SocketRequest(method="update", args=args)
try:
self._shared_queue.put(1)
# Firstly send the size of the message.
os.write(self._fd, bs_size.to_bytes(4, "big"))
os.write(self._fd, bs)
self._request(request)
except Exception:
logger.info("The recv processs has breakdown.")

Expand All @@ -446,9 +470,16 @@ def get(self):
If the writing instance sends the dict through the local socket, the get
method should wait for the sync thread to update the dict.
"""
while not self._shared_queue.empty():
time.sleep(0.1)
return self._dict
if self._server:
while not self._shared_queue.empty():
time.sleep(0.1)
return self._dict
else:
request = SocketRequest(method="get", args={})
response: DictMessage = self._request(request)
if response.status == SUCCESS_CODE:
return response.meta_dict
return {}


class SharedMemory(shared_memory.SharedMemory):
Expand Down Expand Up @@ -519,4 +550,7 @@ def unlink(self):
called once (and only once) across all processes which have access
to the shared memory block."""
if self._name:
_posixshmem.shm_unlink(self._name)
try:
_posixshmem.shm_unlink(self._name)
except FileNotFoundError:
pass
Loading

0 comments on commit bb28957

Please sign in to comment.