Fixing memoryview error (#2929)
* Fixed dup seq 0 bug

* Formatting errors
nvidianz authored Sep 9, 2024
1 parent c278aff commit bff1d69
Showing 2 changed files with 38 additions and 31 deletions.
14 changes: 8 additions & 6 deletions nvflare/fuel/f3/streaming/blob_streamer.py
@@ -77,6 +77,9 @@ def __init__(self, future: StreamFuture, stream: Stream):
         else:
             self.buffer = FastBuffer()
 
+    def __str__(self):
+        return f"Blob[SID:{self.future.get_stream_id()} Size:{self.size}]"
+
 
 class BlobHandler:
     def __init__(self, blob_cb: Callable):
@@ -113,23 +116,22 @@ def _read_stream(blob_task: BlobTask):
                     if blob_task.pre_allocated:
                         remaining = len(blob_task.buffer) - buf_size
                         if length > remaining:
-                            log.error(f"Buffer overrun: {remaining=} {length=} {buf_size=}")
+                            log.error(f"{blob_task} Buffer overrun: {remaining=} {length=} {buf_size=}")
                             if remaining > 0:
                                 blob_task.buffer[buf_size : buf_size + remaining] = buf[0:remaining]
                             break
                         else:
                             blob_task.buffer[buf_size : buf_size + length] = buf
                     else:
                         blob_task.buffer.append(buf)
                 except Exception as ex:
-                    log.error(f"memory view error: {ex} Debug info: {length=} {buf_size=} {type(buf)=}")
+                    log.error(f"{blob_task} memoryview error: {ex} Debug info: {length=} {buf_size=} {type(buf)=}")
                     raise ex
 
                 buf_size += length
 
             if blob_task.size and blob_task.size != buf_size:
-                log.warning(
-                    f"Stream {blob_task.future.get_stream_id()} size doesn't match: " f"{blob_task.size} <> {buf_size}"
-                )
+                log.warning(f"Stream {blob_task} Size doesn't match: " f"{blob_task.size} <> {buf_size}")
 
             if blob_task.pre_allocated:
                 result = blob_task.buffer
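For context on the error being diagnosed here: when the stream is pre-allocated, blob_task.buffer supports slice assignment, and assigning a chunk whose length differs from the target slice raises a ValueError from the buffer protocol. A minimal standalone repro of that failure mode, assuming a memoryview-backed buffer (the exact buffer type used by NVFlare may differ):

    buf = bytearray(8)              # stand-in for a pre-allocated stream buffer
    view = memoryview(buf)
    chunk = b"0123456789"           # incoming chunk: 10 bytes, only 8 fit

    try:
        view[0:8] = chunk           # slice is 8 bytes, source is 10: lengths must match
    except ValueError as ex:
        print(f"memoryview error: {ex}")
        # memoryview assignment: lvalue and rvalue have different structures

The overrun branch above avoids exactly this by copying only remaining bytes (buf[0:remaining]), so both sides of the assignment have equal length, and the improved log lines now identify which stream hit the condition.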
@@ -138,7 +140,7 @@ def _read_stream(blob_task: BlobTask):
 
             blob_task.future.set_result(result)
         except Exception as ex:
-            log.error(f"Stream {blob_task.future.get_stream_id()} read error: {ex}")
+            log.error(f"Stream {blob_task} Read error: {ex}")
             log.error(secure_format_traceback())
             blob_task.future.set_exception(ex)
 
55 changes: 30 additions & 25 deletions nvflare/fuel/f3/streaming/byte_receiver.py
@@ -71,7 +71,7 @@ def __init__(self, sid: int, origin: str):
         self.last_chunk_received = False
 
     def __str__(self):
-        return f"Rx[SID:{self.sid} from {self.origin} for {self.channel}/{self.topic}]"
+        return f"Rx[SID:{self.sid} from {self.origin} for {self.channel}/{self.topic} Size: {self.size}]"
 
 
 class RxStream(Stream):
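Both files now carry a __str__ on their task objects, so every log call can interpolate the task directly and each message identifies the stream and its size. A minimal sketch of the pattern with a simplified stand-in class (fields and values here are illustrative, not the real RxTask):

    class RxTaskSketch:
        """Simplified stand-in for RxTask; not the real class."""

        def __init__(self, sid, origin, channel, topic, size):
            self.sid = sid
            self.origin = origin
            self.channel = channel
            self.topic = topic
            self.size = size

        def __str__(self):
            return f"Rx[SID:{self.sid} from {self.origin} for {self.channel}/{self.topic} Size: {self.size}]"

    task = RxTaskSketch(7, "site-1", "aux", "model", 1048576)
    print(f"{task} Received duplicate chunk 0, ignored")
    # Rx[SID:7 from site-1 for aux/model Size: 1048576] Received duplicate chunk 0, ignored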
@@ -98,9 +98,7 @@ def read(self, chunk_size: int) -> bytes:
 
             # Block if buffers are empty
             if count > 0:
-                log.warning(f"Read block is unblocked multiple times: {count}")
-
-                self.task.waiter.clear()
+                log.warning(f"{self.task} Read block is unblocked multiple times: {count}")
 
             if not self.task.waiter.wait(self.timeout):
                 error = StreamError(f"{self.task} read timed out after {self.timeout} seconds")
@@ -117,6 +115,7 @@ def _read_chunk(self, chunk_size: int) -> Tuple[int, Optional[BytesAlike]]:
             if self.task.eos:
                 return RESULT_EOS, None
             else:
+                self.task.waiter.clear()
                 return RESULT_WAIT, None
 
         last_chunk, buf = self.task.buffers.popleft()
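The two hunks above relocate self.task.waiter.clear() from read() into _read_chunk(), where it now runs under the task lock immediately before returning RESULT_WAIT. In the old placement the clear happened after the emptiness check but outside the lock, so a set() fired by a chunk arriving in between could be wiped out, leaving the reader blocked until timeout. A condensed, runnable sketch of the corrected ordering, with simplified names (an illustration, not the NVFlare classes):

    import threading
    from collections import deque

    class ChunkBuffer:
        """Simplified consumer-side buffer showing the wait/clear ordering."""

        def __init__(self):
            self.lock = threading.Lock()
            self.waiter = threading.Event()
            self.buffers = deque()

        def append(self, chunk: bytes):
            with self.lock:
                self.buffers.append(chunk)
                self.waiter.set()          # wake a blocked reader

        def read_chunk(self):
            with self.lock:
                if not self.buffers:
                    # Clear under the same lock append() takes: a set() can no
                    # longer slip in between the emptiness check and clear().
                    self.waiter.clear()
                    return None
                return self.buffers.popleft()

        def read(self, timeout: float) -> bytes:
            while True:
                chunk = self.read_chunk()
                if chunk is not None:
                    return chunk
                if not self.waiter.wait(timeout):
                    raise TimeoutError(f"read timed out after {timeout} seconds")

Because clear() and the producer's set() now run under the same lock, a reader that observes an empty buffer is guaranteed to see any set() issued after it releases the lock.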
@@ -239,33 +238,39 @@ def _data_handler(self, message: Message):
             self.stop_task(task, StreamError(f"Received error from {origin}: {error}"), notify=False)
             return
 
-        if seq == 0:
-            # Handle new stream
-            task.channel = message.get_header(StreamHeaderKey.CHANNEL)
-            task.topic = message.get_header(StreamHeaderKey.TOPIC)
-            task.headers = message.headers
+        with task.task_lock:
+            if seq == 0:
+                # Handle new stream
+                task.channel = message.get_header(StreamHeaderKey.CHANNEL)
+                task.topic = message.get_header(StreamHeaderKey.TOPIC)
+                task.headers = message.headers
 
-            task.stream_future = StreamFuture(sid, message.headers)
-            task.size = message.get_header(StreamHeaderKey.SIZE, 0)
-            task.stream_future.set_size(task.size)
+                # GRPC may re-send the same request, causing seq 0 delivered more than once
+                if task.stream_future:
+                    log.warning(f"{task} Received duplicate chunk 0, ignored")
+                    return
 
-            # Invoke callback
-            callback = self.registry.find(task.channel, task.topic)
-            if not callback:
-                self.stop_task(task, StreamError(f"No callback is registered for {task.channel}/{task.topic}"))
-                return
+                task.stream_future = StreamFuture(sid, message.headers)
+                task.size = message.get_header(StreamHeaderKey.SIZE, 0)
+                task.stream_future.set_size(task.size)
 
-            self.received_stream_counter_pool.increment(
-                category=stream_stats_category(task.channel, task.topic, "stream"), counter_name=COUNTER_NAME_RECEIVED
-            )
+                # Invoke callback
+                callback = self.registry.find(task.channel, task.topic)
+                if not callback:
+                    self.stop_task(task, StreamError(f"No callback is registered for {task.channel}/{task.topic}"))
+                    return
 
-            self.received_stream_size_pool.record_value(
-                category=stream_stats_category(task.channel, task.topic, "stream"), value=task.size / ONE_MB
-            )
+                self.received_stream_counter_pool.increment(
+                    category=stream_stats_category(task.channel, task.topic, "stream"),
+                    counter_name=COUNTER_NAME_RECEIVED,
+                )
 
-            stream_thread_pool.submit(self._callback_wrapper, task, callback)
+                self.received_stream_size_pool.record_value(
+                    category=stream_stats_category(task.channel, task.topic, "stream"), value=task.size / ONE_MB
+                )
 
-        with task.task_lock:
+                stream_thread_pool.submit(self._callback_wrapper, task, callback)
+
             data_type = message.get_header(StreamHeaderKey.DATA_TYPE)
             last_chunk = data_type == StreamDataType.FINAL
             if last_chunk:
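This hunk is the "dup seq 0 bug" from the commit message: gRPC may re-send the stream's first request, and previously each seq-0 delivery re-created the StreamFuture and re-submitted the callback, with the setup also running outside task_lock and racing against chunk handling. The fix serializes the whole handler body under the lock and drops any seq-0 chunk after the first. A condensed, runnable sketch of the guard (class and helpers are simplified stand-ins, not NVFlare's API):

    import threading

    class StreamTask:
        """Minimal stand-in for RxTask (hypothetical simplified fields)."""

        def __init__(self, sid: int):
            self.sid = sid
            self.task_lock = threading.Lock()
            self.stream_future = None

    def handle_chunk(task: StreamTask, seq: int, headers: dict):
        """Condensed shape of the fixed _data_handler dedup logic."""
        with task.task_lock:                   # one lock around setup AND chunk handling
            if seq == 0:
                if task.stream_future is not None:
                    # gRPC may re-send the request, delivering seq 0 more than once
                    print(f"Task {task.sid}: received duplicate chunk 0, ignored")
                    return
                task.stream_future = {"headers": headers}  # stand-in for StreamFuture(sid, headers)
                # ... register callback, update stats, submit to thread pool ...
            # normal chunk processing continues here under the same lock

    task = StreamTask(7)
    handle_chunk(task, 0, {"size": 1024})
    handle_chunk(task, 0, {"size": 1024})      # duplicate seq 0: ignored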
