From 7f2bc2db3bad4fb793eac4d4fbe448da2d91a192 Mon Sep 17 00:00:00 2001 From: Tei Roberts Date: Wed, 12 Apr 2023 14:10:48 +0200 Subject: [PATCH] feat: make SendStream allow writing arbitrary data like io::Write This changes the `quic::SendStream` trait to closer mimic the `AsyncWrite` trait, but using a `impl Buf` rather than `&[u8]` which removes the need to allocate and copy the byte slice to store it if needed. [s2n-quic::SendStream](https://github.com/aws/s2n-quic/blob/bf20c6dd148153802929a2514b444dcf5dd37fd1/quic/s2n-quic-h3/src/s2n_quic.rs#L364) uses this to enqueue the bytes for sending, which would require allocating if `&[u8]` was used. Issue #78 discusses this API change which would remove the need for intermediate buffering. See: https://github.com/hyperium/h3/issues/78#issuecomment-1032899617 --- examples/webtransport_server.rs | 4 ++ h3-quinn/src/lib.rs | 103 +++++++++++++++++++++++--------- h3/src/connection.rs | 4 +- h3/src/frame.rs | 18 ++++++ h3/src/proto/frame.rs | 16 +++-- h3/src/quic.rs | 13 ++++ h3/src/stream.rs | 7 ++- h3/src/webtransport/stream.rs | 40 +++++++++++-- 8 files changed, 160 insertions(+), 45 deletions(-) diff --git a/examples/webtransport_server.rs b/examples/webtransport_server.rs index 8a466d3c..e00ce0f7 100644 --- a/examples/webtransport_server.rs +++ b/examples/webtransport_server.rs @@ -188,6 +188,8 @@ async fn main() -> Result<(), Box> { async fn handle_connection(mut conn: Connection) -> Result<()> where C: 'static + Send + quic::Connection, + >::Error: + 'static + std::error::Error + Send + Sync, { // 3. TODO: Conditionally, if the client indicated that this is a webtransport session, we should accept it here, else use regular h3. // if this is a webtransport session, then h3 needs to stop handing the datagrams, bidirectional streams, and unidirectional streams and give them @@ -242,6 +244,7 @@ async fn handle_session_and_echo_all_inbound_messages( ) -> anyhow::Result<()> where C: 'static + Send + h3::quic::Connection, + >::Error: 'static + std::error::Error + Send + Sync, B: Buf, { loop { @@ -278,6 +281,7 @@ where tracing::info!("Got message: {message:?}"); + send.write_all(&mut format!("I got your message: {message:?}").as_bytes()).await.context("Failed to respond")?; // send.send_data(message.freeze()).context("Failed to send response"); // future::poll_fn(|cx| send.poll_ready(cx)).await?; } diff --git a/h3-quinn/src/lib.rs b/h3-quinn/src/lib.rs index 5c4d5e92..dce8bd5e 100644 --- a/h3-quinn/src/lib.rs +++ b/h3-quinn/src/lib.rs @@ -320,6 +320,18 @@ where self.send.poll_ready(cx) } + fn send_data>>(&mut self, data: D) -> Result<(), Self::Error> { + self.send.send_data(data) + } + + fn poll_send( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut D, + ) -> Poll> { + self.send.poll_send(cx, buf) + } + fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { self.send.poll_finish(cx) } @@ -328,10 +340,6 @@ where self.send.reset(reset_code) } - fn send_data>>(&mut self, data: D) -> Result<(), Self::Error> { - self.send.send_data(data) - } - fn send_id(&self) -> StreamId { self.send.send_id() } @@ -447,33 +455,70 @@ where { type Error = SendStreamError; - fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { - if let Some(ref mut data) = self.writing { - while data.has_remaining() { - match ready!(Pin::new(&mut self.stream).poll_write(cx, data.chunk())) { - Ok(cnt) => data.advance(cnt), - Err(err) => { - // We are forced to use AsyncWrite for now because we cannot store - // the result of a call to: - // quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, S>. - // - // This is why we have to unpack the error from io::Error below. This should not - // panic as long as quinn's AsyncWrite impl doesn't change. - return Poll::Ready(Err(SendStreamError::Write( - err.into_inner() - .expect("write stream returned an empty error") - .downcast_ref::() - .expect( - "write stream returned an error which type is not WriteError", - ) - .clone(), - ))); - } - } + fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll> { + unimplemented!() + //if let Some(ref mut data) = self.writing { + // while data.has_remaining() { + // match ready!(Pin::new(&mut self.stream).poll_write(cx, data.chunk())) { + // Ok(cnt) => data.advance(cnt), + // Err(err) => { + // // We are forced to use AsyncWrite for now because we cannot store + // // the result of a call to: + // // quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, S>. + // // + // // This is why we have to unpack the error from io::Error below. This should not + // // panic as long as quinn's AsyncWrite impl doesn't change. + // return Poll::Ready(Err(SendStreamError::Write( + // err.into_inner() + // .expect("write stream returned an empty error") + // .downcast_ref::() + // .expect( + // "write stream returned an error which type is not WriteError", + // ) + // .clone(), + // ))); + // } + // } + // } + //} + //self.writing = None; + //Poll::Ready(Ok(())) + } + + fn poll_send( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut D, + ) -> Poll> { + if self.writing.is_some() { + return Poll::Ready(Err(Self::Error::NotReady)); + } + + let s = Pin::new(&mut self.stream); + + let res = ready!(s.poll_write(cx, buf.chunk())); + match res { + Ok(written) => { + buf.advance(written); + Poll::Ready(Ok(written)) + } + Err(err) => { + // We are forced to use AsyncWrite for now because we cannot store + // the result of a call to: + // quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Result. + // + // This is why we have to unpack the error from io::Error instead of having it + // returned directly. This should not panic as long as quinn's AsyncWrite impl + // doesn't change. + let err = err + .into_inner() + .expect("write stream returned an empty error") + .downcast::() + .expect("write stream returned an error which type is not WriteError"); + + Poll::Ready(Err(SendStreamError::Write(*err))) } } - self.writing = None; - Poll::Ready(Ok(())) } fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { diff --git a/h3/src/connection.rs b/h3/src/connection.rs index 1bf432a2..f940c5e9 100644 --- a/h3/src/connection.rs +++ b/h3/src/connection.rs @@ -830,9 +830,7 @@ where .map_err(|e| self.maybe_conn_err(e))?; self.send_grease_frame = false; } - future::poll_fn(|cx| self.stream.poll_ready(cx)) - .await - .map_err(|e| self.maybe_conn_err(e))?; + future::poll_fn(|cx| self.stream.poll_finish(cx)) .await .map_err(|e| self.maybe_conn_err(e)) diff --git a/h3/src/frame.rs b/h3/src/frame.rs index 135fdd3d..f2aa7049 100644 --- a/h3/src/frame.rs +++ b/h3/src/frame.rs @@ -20,6 +20,7 @@ use crate::{ /// Decodes Frames from the underlying QUIC stream pub struct FrameStream { stream: S, + // Already read data from the stream bufs: BufList, decoder: FrameDecoder, remaining_data: usize, @@ -70,6 +71,10 @@ where self.remaining_data = len; Poll::Ready(Ok(Some(Frame::Data(PayloadLen(len))))) } + frame @ Some(Frame::WebTransportStream(_)) => { + self.remaining_data = usize::MAX; + Poll::Ready(Ok(frame)) + } Some(frame) => Poll::Ready(Ok(Some(frame))), None => match end { // Received a chunk but frame is incomplete, poll until we get `Pending`. @@ -89,6 +94,10 @@ where } } + /// Retrieves the next piece of data in an incoming data packet or webtransport stream + /// + /// + /// WebTransport bidirectional payload has no finite length and is processed until the end of the stream. pub fn poll_data( &mut self, cx: &mut Context<'_>, @@ -115,6 +124,7 @@ where } } + /// Stops the underlying stream with the provided error code pub(crate) fn stop_sending(&mut self, error_code: crate::error::Code) { self.stream.stop_sending(error_code.into()); } @@ -176,6 +186,14 @@ where fn send_id(&self) -> StreamId { self.stream.send_id() } + + fn poll_send( + &mut self, + cx: &mut std::task::Context<'_>, + buf: &mut D, + ) -> Poll> { + self.stream.poll_send(cx, buf) + } } impl FrameStream diff --git a/h3/src/proto/frame.rs b/h3/src/proto/frame.rs index 063079f9..dc3cf805 100644 --- a/h3/src/proto/frame.rs +++ b/h3/src/proto/frame.rs @@ -51,9 +51,13 @@ pub enum Frame { PushPromise(PushPromise), Goaway(VarInt), MaxPushId(PushId), - /// A webtransport frame is a frame without a length. + /// Describes the header for a webtransport stream. /// - /// The data is streaming + /// The payload is sent streaming until the stream is closed + /// + /// Unwrap the framed streamer and read the inner stream until the end. + /// + /// Conversely, when sending, send this frame and unwrap the stream WebTransportStream(SessionId), Grease, } @@ -81,7 +85,7 @@ impl Frame { // Webtransport streams need special handling as they have no length. // // See: https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-4.2 - if ty == FrameType::WEBTRANSPORT_STREAM { + if ty == FrameType::WEBTRANSPORT_BI_STREAM { tracing::trace!("webtransport frame"); return Ok(Frame::WebTransportStream(SessionId::decode(buf)?)); } @@ -111,7 +115,7 @@ impl Frame { | FrameType::H2_PING | FrameType::H2_WINDOW_UPDATE | FrameType::H2_CONTINUATION => Err(FrameError::UnsupportedFrame(ty.0)), - FrameType::WEBTRANSPORT_STREAM | FrameType::DATA => unreachable!(), + FrameType::WEBTRANSPORT_BI_STREAM | FrameType::DATA => unreachable!(), _ => { buf.advance(len as usize); Err(FrameError::UnknownFrame(ty.0)) @@ -212,7 +216,7 @@ impl fmt::Debug for Frame { Frame::Goaway(id) => write!(f, "GoAway({})", id), Frame::MaxPushId(id) => write!(f, "MaxPushId({})", id), Frame::Grease => write!(f, "Grease()"), - Frame::WebTransportStream(_) => todo!(), + Frame::WebTransportStream(session) => write!(f, "WebTransportStream({session:?})"), } } } @@ -286,7 +290,7 @@ frame_types! { H2_CONTINUATION = 0x9, MAX_PUSH_ID = 0xD, // Reserved frame types - WEBTRANSPORT_STREAM = 0x41, + WEBTRANSPORT_BI_STREAM = 0x41, } impl FrameType { diff --git a/h3/src/quic.rs b/h3/src/quic.rs index 1cba9486..1ebc1985 100644 --- a/h3/src/quic.rs +++ b/h3/src/quic.rs @@ -164,6 +164,19 @@ pub trait SendStream { /// Send more data on the stream. fn send_data>>(&mut self, data: T) -> Result<(), Self::Error>; + /// Attempts write data into the stream. + /// + /// Returns the number of bytes written. + /// + /// `buf` is advanced by the number of bytes written. + /// + /// This allows writing arbitrary data to the stream as well as complete encoded frames. + fn poll_send( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut D, + ) -> Poll>; + /// Poll to finish the sending side of the stream. fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll>; diff --git a/h3/src/stream.rs b/h3/src/stream.rs index 9a59bbd6..917515a2 100644 --- a/h3/src/stream.rs +++ b/h3/src/stream.rs @@ -19,14 +19,17 @@ use crate::{ }; #[inline] +/// Transmits data by encoding in wire format. pub(crate) async fn write(stream: &mut S, data: D) -> Result<(), Error> where S: SendStream, D: Into>, B: Buf, { - stream.send_data(data)?; - future::poll_fn(|cx| stream.poll_ready(cx)).await?; + let mut write_buf = data.into(); + while write_buf.has_remaining() { + future::poll_fn(|cx| stream.poll_send(cx, &mut write_buf)).await?; + } Ok(()) } diff --git a/h3/src/webtransport/stream.rs b/h3/src/webtransport/stream.rs index 12913ae8..0d6a2895 100644 --- a/h3/src/webtransport/stream.rs +++ b/h3/src/webtransport/stream.rs @@ -1,12 +1,12 @@ use std::{marker::PhantomData, task::Poll}; use bytes::{Buf, Bytes}; -use futures_util::{ready, AsyncRead}; +use futures_util::{future, ready, AsyncRead}; use crate::{ buf::BufList, proto::varint::UnexpectedEnd, - quic::{self}, + quic::{self, SendStream as QSendStream}, }; use super::SessionId; @@ -73,19 +73,49 @@ impl SendStream { } } -impl quic::SendStream for SendStream +impl SendStream where S: quic::SendStream, B: Buf, +{ + /// Write bytes to the stream. + /// + /// Returns the number of bytes written + pub async fn write(&mut self, buf: &mut impl Buf) -> Result { + future::poll_fn(|cx| quic::SendStream::poll_send(self, cx, buf)).await + } + + /// Writes the entire buffer to the stream + pub async fn write_all(&mut self, mut buf: impl Buf) -> Result<(), S::Error> { + while buf.has_remaining() { + self.write(&mut buf).await?; + } + + Ok(()) + } +} + +impl QSendStream for SendStream +where + S: QSendStream, + B: Buf, { type Error = S::Error; fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { - self.stream.poll_ready(cx) + todo!() } fn send_data>>(&mut self, data: T) -> Result<(), Self::Error> { - self.stream.send_data(data) + todo!() + } + + fn poll_send( + &mut self, + cx: &mut std::task::Context<'_>, + buf: &mut D, + ) -> Poll> { + self.stream.poll_send(cx, buf) } fn poll_finish(&mut self, cx: &mut std::task::Context<'_>) -> Poll> {