From c278aff01f366b55418f52b6efa64b6d13e89b12 Mon Sep 17 00:00:00 2001 From: Zhihong Zhang <100308595+nvidianz@users.noreply.github.com> Date: Fri, 6 Sep 2024 20:29:52 -0400 Subject: [PATCH] Fixing the memoryview issues (#2926) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added handling for buffer overrun * Added task_lock to read() and ignore duplicate chunks * Simplified the wait loop * Fixed a formatting error * Check EOS when appending data --------- Co-authored-by: Chester Chen <512707+chesterxgchen@users.noreply.github.com> Co-authored-by: Yuan-Ting Hsieh (謝沅廷) --- nvflare/fuel/f3/streaming/blob_streamer.py | 13 +++-- nvflare/fuel/f3/streaming/byte_receiver.py | 55 +++++++++++++++------- 2 files changed, 47 insertions(+), 21 deletions(-) diff --git a/nvflare/fuel/f3/streaming/blob_streamer.py b/nvflare/fuel/f3/streaming/blob_streamer.py index e3ba9ab11c..282d8e8050 100644 --- a/nvflare/fuel/f3/streaming/blob_streamer.py +++ b/nvflare/fuel/f3/streaming/blob_streamer.py @@ -111,14 +111,17 @@ def _read_stream(blob_task: BlobTask): length = len(buf) try: if blob_task.pre_allocated: - blob_task.buffer[buf_size : buf_size + length] = buf + remaining = len(blob_task.buffer) - buf_size + if length > remaining: + log.error(f"Buffer overrun: {remaining=} {length=} {buf_size=}") + if remaining > 0: + blob_task.buffer[buf_size : buf_size + remaining] = buf[0:remaining] + else: + blob_task.buffer[buf_size : buf_size + length] = buf else: blob_task.buffer.append(buf) except Exception as ex: - log.error( - f"memory view error: {ex} " - f"Debug info: {length=} {buf_size=} {len(blob_task.pre_allocated)=} {type(buf)=}" - ) + log.error(f"memory view error: {ex} Debug info: {length=} {buf_size=} {type(buf)=}") raise ex buf_size += length diff --git a/nvflare/fuel/f3/streaming/byte_receiver.py b/nvflare/fuel/f3/streaming/byte_receiver.py index 08a815ca38..a9309bc030 100644 --- a/nvflare/fuel/f3/streaming/byte_receiver.py +++ 
b/nvflare/fuel/f3/streaming/byte_receiver.py @@ -14,7 +14,7 @@ import logging import threading from collections import deque -from typing import Callable, Dict, Tuple +from typing import Callable, Dict, Optional, Tuple from nvflare.fuel.f3.cellnet.core_cell import CoreCell from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey @@ -41,6 +41,9 @@ ACK_INTERVAL = 1024 * 1024 * 4 READ_TIMEOUT = 300 COUNTER_NAME_RECEIVED = "received" +RESULT_DATA = 0 +RESULT_WAIT = 1 +RESULT_EOS = 2 class RxTask: @@ -78,30 +81,44 @@ def __init__(self, byte_receiver: "ByteReceiver", task: RxTask): super().__init__(task.size, task.headers) self.byte_receiver = byte_receiver self.task = task + self.timeout = CommConfigurator().get_streaming_read_timeout(READ_TIMEOUT) + self.ack_interval = CommConfigurator().get_streaming_ack_interval(ACK_INTERVAL) def read(self, chunk_size: int) -> bytes: if self.closed: raise StreamError("Read from closed stream") - if (not self.task.buffers) and self.task.eos: - return EOS - - # Block if buffers are empty count = 0 - while not self.task.buffers: + while True: + result_code, result = self._read_chunk(chunk_size) + if result_code == RESULT_EOS: + return EOS + elif result_code == RESULT_DATA: + return result + + # Block if buffers are empty if count > 0: - log.debug(f"Read block is unblocked multiple times: {count}") + log.warning(f"Read block is unblocked multiple times: {count}") self.task.waiter.clear() - timeout = CommConfigurator().get_streaming_read_timeout(READ_TIMEOUT) - if not self.task.waiter.wait(timeout): - error = StreamError(f"{self.task} read timed out after {timeout} seconds") + + if not self.task.waiter.wait(self.timeout): + error = StreamError(f"{self.task} read timed out after {self.timeout} seconds") self.byte_receiver.stop_task(self.task, error) raise error count += 1 + def _read_chunk(self, chunk_size: int) -> Tuple[int, Optional[BytesAlike]]: + with self.task.task_lock: + + if not self.task.buffers: + if self.task.eos: + return 
RESULT_EOS, None + else: + return RESULT_WAIT, None + last_chunk, buf = self.task.buffers.popleft() if buf is None: buf = bytes(0) @@ -117,8 +134,7 @@ def read(self, chunk_size: int) -> bytes: self.task.offset += len(result) - ack_interval = CommConfigurator().get_streaming_ack_interval(ACK_INTERVAL) - if not self.task.last_chunk_received and (self.task.offset - self.task.offset_ack > ack_interval): + if not self.task.last_chunk_received and (self.task.offset - self.task.offset_ack > self.ack_interval): # Send ACK message = Message() message.add_headers( @@ -133,7 +149,7 @@ def read(self, chunk_size: int) -> bytes: self.task.stream_future.set_progress(self.task.offset) - return result + return RESULT_DATA, result def close(self): if not self.task.stream_future.done(): @@ -148,6 +164,7 @@ def __init__(self, cell: CoreCell): self.registry = Registry() self.rx_task_map = {} self.map_lock = threading.Lock() + self.max_out_seq = CommConfigurator().get_streaming_max_out_seq_chunks(MAX_OUT_SEQ_CHUNKS) self.received_stream_counter_pool = StatsPoolManager.add_counter_pool( name="Received_Stream_Counters", @@ -254,6 +271,10 @@ def _data_handler(self, message: Message): if last_chunk: task.last_chunk_received = True + if seq < task.next_seq: + log.warning(f"{task} Duplicate chunk ignored {seq=}") + return + if seq == task.next_seq: self._append(task, (last_chunk, payload)) task.next_seq += 1 @@ -266,8 +287,7 @@ def _data_handler(self, message: Message): else: # Out-of-seq chunk reassembly - max_out_seq = CommConfigurator().get_streaming_max_out_seq_chunks(MAX_OUT_SEQ_CHUNKS) - if len(task.out_seq_buffers) >= max_out_seq: + if len(task.out_seq_buffers) >= self.max_out_seq: self.stop_task(task, StreamError(f"Too many out-of-sequence chunks: {len(task.out_seq_buffers)}")) return else: @@ -294,7 +314,10 @@ def _append(task: RxTask, buf: Tuple[bool, BytesAlike]): if not buf: return - task.buffers.append(buf) + if task.eos: + log.error(f"{task} Data after EOS is ignored") + else: + 
task.buffers.append(buf) # Wake up blocking read() if not task.waiter.is_set():