From 9f5c59ddcd5e17a6202f29946f0d66dbf83ca775 Mon Sep 17 00:00:00 2001 From: Nur Date: Tue, 3 Dec 2024 12:30:58 +0600 Subject: [PATCH 1/2] Avoiding extra allocation for shared error. --- h3/src/error.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/h3/src/error.rs b/h3/src/error.rs index 053ab611..f4b3d53a 100644 --- a/h3/src/error.rs +++ b/h3/src/error.rs @@ -42,7 +42,7 @@ impl PartialEq for Code { #[derive(Clone)] pub(crate) struct ErrorImpl { pub(crate) kind: Kind, - cause: Option>, + cause: Option>, } /// Some errors affect the whole connection, others only one Request or Stream. @@ -74,7 +74,7 @@ pub enum Kind { }, // Error from QUIC layer #[non_exhaustive] - Transport(Arc), + Transport(Arc), // Connection has been closed with `Code::NO_ERROR` Closed, // Currently in a graceful shutdown procedure @@ -205,7 +205,7 @@ impl Code { } pub(crate) fn with_transport>>(self, err: E) -> Error { - Error::new(Kind::Transport(Arc::new(err.into()))) + Error::new(Kind::Transport(Arc::from(err.into()))) } } @@ -254,7 +254,7 @@ impl Error { } pub(crate) fn with_cause>(mut self, cause: E) -> Self { - self.inner.cause = Some(Arc::new(cause.into())); + self.inner.cause = Some(Arc::from(cause.into())); self } @@ -360,7 +360,7 @@ impl fmt::Display for Error { impl std::error::Error for Error { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.inner.cause.as_ref().map(|e| &***e as _) + self.inner.cause.as_ref().map(|e| e as _) } } @@ -462,7 +462,7 @@ where reason: None, level: ErrorLevel::ConnectionError, }), - None => Error::new(Kind::Transport(Arc::new(quic_error))), + None => Error::new(Kind::Transport(Arc::from(quic_error))), } } } From 598397e855a363c4107518439b6cf14ee78ff239 Mon Sep 17 00:00:00 2001 From: Nur Date: Wed, 11 Dec 2024 16:42:47 +0600 Subject: [PATCH 2/2] Some cleanup --- h3-webtransport/src/server.rs | 2 +- h3/src/client/connection.rs | 2 +- h3/src/connection.rs | 21 +++------------ h3/src/error.rs | 8 +++--- h3/src/server/connection.rs | 50 ++++++++++++++--------------------- 5 files changed, 30 insertions(+), 53 deletions(-) diff --git a/h3-webtransport/src/server.rs b/h3-webtransport/src/server.rs index 05970b61..5ae8886b 100644 --- a/h3-webtransport/src/server.rs +++ b/h3-webtransport/src/server.rs @@ -164,7 +164,7 @@ where // Accept the incoming stream let stream = poll_fn(|cx| { let mut conn = self.server_conn.lock().unwrap(); - conn.poll_accept_request(cx) + conn.poll_accept_request_stream(cx) }) .await; diff --git a/h3/src/client/connection.rs b/h3/src/client/connection.rs index 49e9ed2f..66b5cc6d 100644 --- a/h3/src/client/connection.rs +++ b/h3/src/client/connection.rs @@ -446,7 +446,7 @@ where //# receipt of a server-initiated bidirectional stream as a connection //# error of type H3_STREAM_CREATION_ERROR unless such an extension has //# been negotiated. - if self.inner.poll_accept_request(cx).is_ready() { + if self.inner.poll_accept_bi(cx).is_ready() { return Poll::Ready(Err(self.inner.close( Code::H3_STREAM_CREATION_ERROR, "client received a bidirectional stream", diff --git a/h3/src/connection.rs b/h3/src/connection.rs index a2fcf6ce..0160898d 100644 --- a/h3/src/connection.rs +++ b/h3/src/connection.rs @@ -250,19 +250,6 @@ where future::poll_fn(|cx| conn.poll_open_send(cx)).await, ); - let control_send = - control_send.map_err(|e| Code::H3_STREAM_CREATION_ERROR.with_transport(e))?; - - let qpack_encoder = match qpack_encoder { - Ok(stream) => Some(stream), - Err(_) => None, - }; - - let qpack_decoder = match qpack_decoder { - Ok(stream) => Some(stream), - Err(_) => None, - }; - //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2.1 //= type=implication //# The @@ -271,7 +258,7 @@ where let mut conn_inner = Self { shared, conn, - control_send, + control_send: control_send.map_err(Error::transport_err)?, control_recv: None, decoder_recv: None, encoder_recv: None, @@ -280,8 +267,8 @@ where send_grease_frame: config.send_grease, config, accepted_streams: Default::default(), - decoder_send: qpack_decoder, - encoder_send: qpack_encoder, + decoder_send: qpack_decoder.ok(), + encoder_send: qpack_encoder.ok(), // send grease stream if configured send_grease_stream_flag: config.send_grease, // start at first step @@ -323,7 +310,7 @@ where #[allow(missing_docs)] #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))] - pub fn poll_accept_request( + pub fn poll_accept_bi( &mut self, cx: &mut Context<'_>, ) -> Poll, Error>> { diff --git a/h3/src/error.rs b/h3/src/error.rs index f4b3d53a..94f193f1 100644 --- a/h3/src/error.rs +++ b/h3/src/error.rs @@ -203,10 +203,6 @@ impl Code { pub(crate) fn with_cause>(self, cause: E) -> Error { Error::from(self).with_cause(cause) } - - pub(crate) fn with_transport>>(self, err: E) -> Error { - Error::new(Kind::Transport(Arc::from(err.into()))) - } } impl From for u64 { @@ -224,6 +220,10 @@ impl Error { } } + pub(crate) fn transport_err>>(err: E) -> Error { + Error::new(Kind::Transport(Arc::from(err.into()))) + } + /// Returns the error code from the error if available pub fn try_get_code(&self) -> Option { match self.inner.kind { diff --git a/h3/src/server/connection.rs b/h3/src/server/connection.rs index c4e5687d..9d9721a8 100644 --- a/h3/src/server/connection.rs +++ b/h3/src/server/connection.rs @@ -4,18 +4,15 @@ use std::{ collections::HashSet, + future::poll_fn, marker::PhantomData, option::Option, result::Result, sync::Arc, - task::{Context, Poll}, + task::{ready, Context, Poll}, }; use bytes::Buf; -use futures_util::{ - future::{self}, - ready, -}; use http::Request; use quic::RecvStream; use quic::StreamId; @@ -117,7 +114,7 @@ where &mut self, ) -> Result, RequestStream)>, Error> { // Accept the incoming stream - let mut stream = match future::poll_fn(|cx| self.poll_accept_request(cx)).await { + let mut stream = match poll_fn(|cx| self.poll_accept_request_stream(cx)).await { Ok(Some(s)) => FrameStream::new(BufRecvStream::new(s)), Ok(None) => { // We always send a last GoAway frame to the client, so it knows which was the last @@ -132,18 +129,13 @@ where code, reason, level: ErrorLevel::ConnectionError, - } => { - return Err(self.inner.close( - code, - reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))), - )) - } + } => return Err(self.inner.close(code, reason.unwrap_or_default())), _ => return Err(err), }; } }; - let frame = future::poll_fn(|cx| stream.poll_next(cx)).await; + let frame = poll_fn(|cx| stream.poll_next(cx)).await; let req = self.accept_with_frame(stream, frame)?; if let Some(req) = req { Ok(Some(req.resolve().await?)) @@ -305,33 +297,31 @@ where /// This could be either a *Request* or a *WebTransportBiStream*, the first frame's type /// decides. #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))] - pub fn poll_accept_request( + pub fn poll_accept_request_stream( &mut self, cx: &mut Context<'_>, ) -> Poll, Error>> { let _ = self.poll_control(cx)?; let _ = self.poll_requests_completion(cx); loop { - match self.inner.poll_accept_request(cx) { - Poll::Ready(Err(x)) => break Poll::Ready(Err(x)), - Poll::Ready(Ok(None)) => { - if self.poll_requests_completion(cx).is_ready() { - break Poll::Ready(Ok(None)); + let conn = self.inner.poll_accept_bi(cx)?; + return match conn { + Poll::Ready(None) | Poll::Pending => { + let done = if conn.is_pending() { + self.recv_closing.is_some() && self.poll_requests_completion(cx).is_ready() + } else { + self.poll_requests_completion(cx).is_ready() + }; + + if done { + Poll::Ready(Ok(None)) } else { // Wait for all the requests to be finished, request_end_recv will wake // us on each request completion. - break Poll::Pending; - } - } - Poll::Pending => { - if self.recv_closing.is_some() && self.poll_requests_completion(cx).is_ready() { - // The connection is now idle. - break Poll::Ready(Ok(None)); - } else { - return Poll::Pending; + Poll::Pending } } - Poll::Ready(Ok(Some(mut s))) => { + Poll::Ready(Some(mut s)) => { // When the connection is in a graceful shutdown procedure, reject all // incoming requests not belonging to the grace interval. It's possible that // some acceptable request streams arrive after rejected requests. @@ -347,7 +337,7 @@ where } self.last_accepted_stream = Some(s.send_id()); self.ongoing_streams.insert(s.send_id()); - break Poll::Ready(Ok(Some(s))); + Poll::Ready(Ok(Some(s))) } }; }