Skip to content

Commit

Permalink
feat: make SendStream allow writing arbitrary data like io::Write
Browse files Browse the repository at this point in the history
This changes the `quic::SendStream` trait to closer mimic the
`AsyncWrite` trait, but using a `impl Buf` rather than `&[u8]` which
removes the need to allocate and copy the byte slice to store it if
needed.

[s2n-quic::SendStream](https://github.com/aws/s2n-quic/blob/bf20c6dd148153802929a2514b444dcf5dd37fd1/quic/s2n-quic-h3/src/s2n_quic.rs#L364) uses this to enqueue the bytes for sending, which would require allocating if `&[u8]` was used.

Issue hyperium#78 discusses this API change which would remove the need for intermediate buffering.

See: hyperium#78 (comment)
  • Loading branch information
ten3roberts committed Apr 12, 2023
1 parent cc7dea2 commit 7f2bc2d
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 45 deletions.
4 changes: 4 additions & 0 deletions examples/webtransport_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
async fn handle_connection<C>(mut conn: Connection<C, Bytes>) -> Result<()>
where
C: 'static + Send + quic::Connection<Bytes>,
<C::SendStream as h3::quic::SendStream<Bytes>>::Error:
'static + std::error::Error + Send + Sync,
{
// 3. TODO: Conditionally, if the client indicated that this is a webtransport session, we should accept it here, else use regular h3.
// if this is a webtransport session, then h3 needs to stop handing the datagrams, bidirectional streams, and unidirectional streams and give them
Expand Down Expand Up @@ -242,6 +244,7 @@ async fn handle_session_and_echo_all_inbound_messages<C, B>(
) -> anyhow::Result<()>
where
C: 'static + Send + h3::quic::Connection<B>,
<C::SendStream as h3::quic::SendStream<B>>::Error: 'static + std::error::Error + Send + Sync,
B: Buf,
{
loop {
Expand Down Expand Up @@ -278,6 +281,7 @@ where

tracing::info!("Got message: {message:?}");

send.write_all(&mut format!("I got your message: {message:?}").as_bytes()).await.context("Failed to respond")?;
// send.send_data(message.freeze()).context("Failed to send response");
// future::poll_fn(|cx| send.poll_ready(cx)).await?;
}
Expand Down
103 changes: 74 additions & 29 deletions h3-quinn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,18 @@ where
self.send.poll_ready(cx)
}

fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
self.send.send_data(data)
}

fn poll_send<D: Buf>(
&mut self,
cx: &mut task::Context<'_>,
buf: &mut D,
) -> Poll<Result<usize, Self::Error>> {
self.send.poll_send(cx, buf)
}

fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.send.poll_finish(cx)
}
Expand All @@ -328,10 +340,6 @@ where
self.send.reset(reset_code)
}

fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
self.send.send_data(data)
}

fn send_id(&self) -> StreamId {
self.send.send_id()
}
Expand Down Expand Up @@ -447,33 +455,70 @@ where
{
type Error = SendStreamError;

fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
if let Some(ref mut data) = self.writing {
while data.has_remaining() {
match ready!(Pin::new(&mut self.stream).poll_write(cx, data.chunk())) {
Ok(cnt) => data.advance(cnt),
Err(err) => {
// We are forced to use AsyncWrite for now because we cannot store
// the result of a call to:
// quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, S>.
//
// This is why we have to unpack the error from io::Error below. This should not
// panic as long as quinn's AsyncWrite impl doesn't change.
return Poll::Ready(Err(SendStreamError::Write(
err.into_inner()
.expect("write stream returned an empty error")
.downcast_ref::<WriteError>()
.expect(
"write stream returned an error which type is not WriteError",
)
.clone(),
)));
}
}
fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
unimplemented!()
//if let Some(ref mut data) = self.writing {
// while data.has_remaining() {
// match ready!(Pin::new(&mut self.stream).poll_write(cx, data.chunk())) {
// Ok(cnt) => data.advance(cnt),
// Err(err) => {
// // We are forced to use AsyncWrite for now because we cannot store
// // the result of a call to:
// // quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, S>.
// //
// // This is why we have to unpack the error from io::Error below. This should not
// // panic as long as quinn's AsyncWrite impl doesn't change.
// return Poll::Ready(Err(SendStreamError::Write(
// err.into_inner()
// .expect("write stream returned an empty error")
// .downcast_ref::<WriteError>()
// .expect(
// "write stream returned an error which type is not WriteError",
// )
// .clone(),
// )));
// }
// }
// }
//}
//self.writing = None;
//Poll::Ready(Ok(()))
}

fn poll_send<D: Buf>(
&mut self,
cx: &mut task::Context<'_>,
buf: &mut D,
) -> Poll<Result<usize, Self::Error>> {
if self.writing.is_some() {
return Poll::Ready(Err(Self::Error::NotReady));
}

let s = Pin::new(&mut self.stream);

let res = ready!(s.poll_write(cx, buf.chunk()));
match res {
Ok(written) => {
buf.advance(written);
Poll::Ready(Ok(written))
}
Err(err) => {
// We are forced to use AsyncWrite for now because we cannot store
// the result of a call to:
// quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Result<usize, WriteError>.
//
// This is why we have to unpack the error from io::Error instead of having it
// returned directly. This should not panic as long as quinn's AsyncWrite impl
// doesn't change.
let err = err
.into_inner()
.expect("write stream returned an empty error")
.downcast::<WriteError>()
.expect("write stream returned an error which type is not WriteError");

Poll::Ready(Err(SendStreamError::Write(*err)))
}
}
self.writing = None;
Poll::Ready(Ok(()))
}

fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Expand Down
4 changes: 1 addition & 3 deletions h3/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -830,9 +830,7 @@ where
.map_err(|e| self.maybe_conn_err(e))?;
self.send_grease_frame = false;
}
future::poll_fn(|cx| self.stream.poll_ready(cx))
.await
.map_err(|e| self.maybe_conn_err(e))?;

future::poll_fn(|cx| self.stream.poll_finish(cx))
.await
.map_err(|e| self.maybe_conn_err(e))
Expand Down
18 changes: 18 additions & 0 deletions h3/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::{
/// Decodes Frames from the underlying QUIC stream
pub struct FrameStream<S, B> {
stream: S,
// Already read data from the stream
bufs: BufList<Bytes>,
decoder: FrameDecoder,
remaining_data: usize,
Expand Down Expand Up @@ -70,6 +71,10 @@ where
self.remaining_data = len;
Poll::Ready(Ok(Some(Frame::Data(PayloadLen(len)))))
}
frame @ Some(Frame::WebTransportStream(_)) => {
self.remaining_data = usize::MAX;
Poll::Ready(Ok(frame))
}
Some(frame) => Poll::Ready(Ok(Some(frame))),
None => match end {
// Received a chunk but frame is incomplete, poll until we get `Pending`.
Expand All @@ -89,6 +94,10 @@ where
}
}

/// Retrieves the next piece of data in an incoming data packet or webtransport stream
///
///
/// WebTransport bidirectional payload has no finite length and is processed until the end of the stream.
pub fn poll_data(
&mut self,
cx: &mut Context<'_>,
Expand All @@ -115,6 +124,7 @@ where
}
}

/// Stops the underlying stream with the provided error code
pub(crate) fn stop_sending(&mut self, error_code: crate::error::Code) {
self.stream.stop_sending(error_code.into());
}
Expand Down Expand Up @@ -176,6 +186,14 @@ where
fn send_id(&self) -> StreamId {
self.stream.send_id()
}

fn poll_send<D: Buf>(
&mut self,
cx: &mut std::task::Context<'_>,
buf: &mut D,
) -> Poll<Result<usize, Self::Error>> {
self.stream.poll_send(cx, buf)
}
}

impl<S, B> FrameStream<S, B>
Expand Down
16 changes: 10 additions & 6 deletions h3/src/proto/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,13 @@ pub enum Frame<B> {
PushPromise(PushPromise),
Goaway(VarInt),
MaxPushId(PushId),
/// A webtransport frame is a frame without a length.
/// Describes the header for a webtransport stream.
///
/// The data is streaming
/// The payload is sent streaming until the stream is closed
///
/// Unwrap the framed streamer and read the inner stream until the end.
///
/// Conversely, when sending, send this frame and unwrap the stream
WebTransportStream(SessionId),
Grease,
}
Expand Down Expand Up @@ -81,7 +85,7 @@ impl Frame<PayloadLen> {
// Webtransport streams need special handling as they have no length.
//
// See: https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-4.2
if ty == FrameType::WEBTRANSPORT_STREAM {
if ty == FrameType::WEBTRANSPORT_BI_STREAM {
tracing::trace!("webtransport frame");
return Ok(Frame::WebTransportStream(SessionId::decode(buf)?));
}
Expand Down Expand Up @@ -111,7 +115,7 @@ impl Frame<PayloadLen> {
| FrameType::H2_PING
| FrameType::H2_WINDOW_UPDATE
| FrameType::H2_CONTINUATION => Err(FrameError::UnsupportedFrame(ty.0)),
FrameType::WEBTRANSPORT_STREAM | FrameType::DATA => unreachable!(),
FrameType::WEBTRANSPORT_BI_STREAM | FrameType::DATA => unreachable!(),
_ => {
buf.advance(len as usize);
Err(FrameError::UnknownFrame(ty.0))
Expand Down Expand Up @@ -212,7 +216,7 @@ impl fmt::Debug for Frame<PayloadLen> {
Frame::Goaway(id) => write!(f, "GoAway({})", id),
Frame::MaxPushId(id) => write!(f, "MaxPushId({})", id),
Frame::Grease => write!(f, "Grease()"),
Frame::WebTransportStream(_) => todo!(),
Frame::WebTransportStream(session) => write!(f, "WebTransportStream({session:?})"),
}
}
}
Expand Down Expand Up @@ -286,7 +290,7 @@ frame_types! {
H2_CONTINUATION = 0x9,
MAX_PUSH_ID = 0xD,
// Reserved frame types
WEBTRANSPORT_STREAM = 0x41,
WEBTRANSPORT_BI_STREAM = 0x41,
}

impl FrameType {
Expand Down
13 changes: 13 additions & 0 deletions h3/src/quic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,19 @@ pub trait SendStream<B: Buf> {
/// Send more data on the stream.
fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), Self::Error>;

/// Attempts write data into the stream.
///
/// Returns the number of bytes written.
///
/// `buf` is advanced by the number of bytes written.
///
/// This allows writing arbitrary data to the stream as well as complete encoded frames.
fn poll_send<D: Buf>(
&mut self,
cx: &mut task::Context<'_>,
buf: &mut D,
) -> Poll<Result<usize, Self::Error>>;

/// Poll to finish the sending side of the stream.
fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>;

Expand Down
7 changes: 5 additions & 2 deletions h3/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@ use crate::{
};

#[inline]
/// Transmits data by encoding in wire format.
pub(crate) async fn write<S, D, B>(stream: &mut S, data: D) -> Result<(), Error>
where
S: SendStream<B>,
D: Into<WriteBuf<B>>,
B: Buf,
{
stream.send_data(data)?;
future::poll_fn(|cx| stream.poll_ready(cx)).await?;
let mut write_buf = data.into();
while write_buf.has_remaining() {
future::poll_fn(|cx| stream.poll_send(cx, &mut write_buf)).await?;
}

Ok(())
}
Expand Down
40 changes: 35 additions & 5 deletions h3/src/webtransport/stream.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::{marker::PhantomData, task::Poll};

use bytes::{Buf, Bytes};
use futures_util::{ready, AsyncRead};
use futures_util::{future, ready, AsyncRead};

use crate::{
buf::BufList,
proto::varint::UnexpectedEnd,
quic::{self},
quic::{self, SendStream as QSendStream},
};

use super::SessionId;
Expand Down Expand Up @@ -73,19 +73,49 @@ impl<S, B> SendStream<S, B> {
}
}

impl<S, B> quic::SendStream<B> for SendStream<S, B>
impl<S, B> SendStream<S, B>
where
S: quic::SendStream<B>,
B: Buf,
{
/// Write bytes to the stream.
///
/// Returns the number of bytes written
pub async fn write(&mut self, buf: &mut impl Buf) -> Result<usize, S::Error> {
future::poll_fn(|cx| quic::SendStream::poll_send(self, cx, buf)).await
}

/// Writes the entire buffer to the stream
pub async fn write_all(&mut self, mut buf: impl Buf) -> Result<(), S::Error> {
while buf.has_remaining() {
self.write(&mut buf).await?;
}

Ok(())
}
}

impl<S, B> QSendStream<B> for SendStream<S, B>
where
S: QSendStream<B>,
B: Buf,
{
type Error = S::Error;

fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.stream.poll_ready(cx)
todo!()
}

fn send_data<T: Into<quic::WriteBuf<B>>>(&mut self, data: T) -> Result<(), Self::Error> {
self.stream.send_data(data)
todo!()
}

fn poll_send<D: Buf>(
&mut self,
cx: &mut std::task::Context<'_>,
buf: &mut D,
) -> Poll<Result<usize, Self::Error>> {
self.stream.poll_send(cx, buf)
}

fn poll_finish(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Expand Down

0 comments on commit 7f2bc2d

Please sign in to comment.