Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoiding extra allocation for shared error. #263

Merged
merged 4 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion h3-webtransport/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ where
// Accept the incoming stream
let stream = poll_fn(|cx| {
let mut conn = self.server_conn.lock().unwrap();
conn.poll_accept_request(cx)
conn.poll_accept_request_stream(cx)
})
.await;

Expand Down
2 changes: 1 addition & 1 deletion h3/src/client/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@
Ok(Frame::Settings(_)) => {
#[cfg(feature = "tracing")]
trace!("Got settings");
()

Check warning on line 391 in h3/src/client/connection.rs

View workflow job for this annotation

GitHub Actions / Lint

unneeded unit expression
}

Ok(Frame::Goaway(id)) => {
Expand Down Expand Up @@ -446,7 +446,7 @@
//# receipt of a server-initiated bidirectional stream as a connection
//# error of type H3_STREAM_CREATION_ERROR unless such an extension has
//# been negotiated.
if self.inner.poll_accept_request(cx).is_ready() {
if self.inner.poll_accept_bi(cx).is_ready() {
return Poll::Ready(Err(self.inner.close(
Code::H3_STREAM_CREATION_ERROR,
"client received a bidirectional stream",
Expand Down
21 changes: 4 additions & 17 deletions h3/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,19 +250,6 @@ where
future::poll_fn(|cx| conn.poll_open_send(cx)).await,
);

let control_send =
control_send.map_err(|e| Code::H3_STREAM_CREATION_ERROR.with_transport(e))?;

let qpack_encoder = match qpack_encoder {
Ok(stream) => Some(stream),
Err(_) => None,
};

let qpack_decoder = match qpack_decoder {
Ok(stream) => Some(stream),
Err(_) => None,
};

//= https://www.rfc-editor.org/rfc/rfc9114#section-6.2.1
//= type=implication
//# The
Expand All @@ -271,7 +258,7 @@ where
let mut conn_inner = Self {
shared,
conn,
control_send,
control_send: control_send.map_err(Error::transport_err)?,
control_recv: None,
decoder_recv: None,
encoder_recv: None,
Expand All @@ -280,8 +267,8 @@ where
send_grease_frame: config.send_grease,
config,
accepted_streams: Default::default(),
decoder_send: qpack_decoder,
encoder_send: qpack_encoder,
decoder_send: qpack_decoder.ok(),
encoder_send: qpack_encoder.ok(),
// send grease stream if configured
send_grease_stream_flag: config.send_grease,
// start at first step
Expand Down Expand Up @@ -323,7 +310,7 @@ where

#[allow(missing_docs)]
#[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
pub fn poll_accept_request(
pub fn poll_accept_bi(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<Option<C::BidiStream>, Error>> {
Expand Down
18 changes: 9 additions & 9 deletions h3/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl PartialEq<u64> for Code {
#[derive(Clone)]
pub(crate) struct ErrorImpl {
pub(crate) kind: Kind,
cause: Option<Arc<Cause>>,
cause: Option<Arc<dyn std::error::Error + Send + Sync>>,
}

/// Some errors affect the whole connection, others only one Request or Stream.
Expand Down Expand Up @@ -74,7 +74,7 @@ pub enum Kind {
},
// Error from QUIC layer
#[non_exhaustive]
Transport(Arc<TransportError>),
Transport(Arc<dyn quic::Error>),
// Connection has been closed with `Code::NO_ERROR`
Closed,
// Currently in a graceful shutdown procedure
Expand Down Expand Up @@ -203,10 +203,6 @@ impl Code {
pub(crate) fn with_cause<E: Into<Cause>>(self, cause: E) -> Error {
Error::from(self).with_cause(cause)
}

pub(crate) fn with_transport<E: Into<Box<dyn quic::Error>>>(self, err: E) -> Error {
Error::new(Kind::Transport(Arc::new(err.into())))
}
}

impl From<Code> for u64 {
Expand All @@ -224,6 +220,10 @@ impl Error {
}
}

pub(crate) fn transport_err<E: Into<Box<dyn quic::Error>>>(err: E) -> Error {
Error::new(Kind::Transport(Arc::from(err.into())))
}

/// Returns the error code from the error if available
pub fn try_get_code(&self) -> Option<Code> {
match self.inner.kind {
Expand Down Expand Up @@ -254,7 +254,7 @@ impl Error {
}

pub(crate) fn with_cause<E: Into<Cause>>(mut self, cause: E) -> Self {
self.inner.cause = Some(Arc::new(cause.into()));
self.inner.cause = Some(Arc::from(cause.into()));
self
}

Expand Down Expand Up @@ -360,7 +360,7 @@ impl fmt::Display for Error {

impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
self.inner.cause.as_ref().map(|e| &***e as _)
self.inner.cause.as_ref().map(|e| e as _)
}
}

Expand Down Expand Up @@ -462,7 +462,7 @@ where
reason: None,
level: ErrorLevel::ConnectionError,
}),
None => Error::new(Kind::Transport(Arc::new(quic_error))),
None => Error::new(Kind::Transport(Arc::from(quic_error))),
}
}
}
Expand Down
50 changes: 20 additions & 30 deletions h3/src/server/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,15 @@

use std::{
collections::HashSet,
future::poll_fn,
marker::PhantomData,
option::Option,
result::Result,
sync::Arc,
task::{Context, Poll},
task::{ready, Context, Poll},
};

use bytes::Buf;
use futures_util::{
future::{self},
ready,
};
use http::Request;
use quic::RecvStream;
use quic::StreamId;
Expand Down Expand Up @@ -117,7 +114,7 @@
&mut self,
) -> Result<Option<(Request<()>, RequestStream<C::BidiStream, B>)>, Error> {
// Accept the incoming stream
let mut stream = match future::poll_fn(|cx| self.poll_accept_request(cx)).await {
let mut stream = match poll_fn(|cx| self.poll_accept_request_stream(cx)).await {
Ok(Some(s)) => FrameStream::new(BufRecvStream::new(s)),
Ok(None) => {
// We always send a last GoAway frame to the client, so it knows which was the last
Expand All @@ -132,18 +129,13 @@
code,
reason,
level: ErrorLevel::ConnectionError,
} => {
return Err(self.inner.close(
code,
reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))),
))
}
} => return Err(self.inner.close(code, reason.unwrap_or_default())),
_ => return Err(err),
};
}
};

let frame = future::poll_fn(|cx| stream.poll_next(cx)).await;
let frame = poll_fn(|cx| stream.poll_next(cx)).await;
let req = self.accept_with_frame(stream, frame)?;
if let Some(req) = req {
Ok(Some(req.resolve().await?))
Expand Down Expand Up @@ -305,33 +297,31 @@
/// This could be either a *Request* or a *WebTransportBiStream*, the first frame's type
/// decides.
#[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))]
pub fn poll_accept_request(
pub fn poll_accept_request_stream(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<Option<C::BidiStream>, Error>> {
let _ = self.poll_control(cx)?;
let _ = self.poll_requests_completion(cx);
loop {
match self.inner.poll_accept_request(cx) {
Poll::Ready(Err(x)) => break Poll::Ready(Err(x)),
Poll::Ready(Ok(None)) => {
if self.poll_requests_completion(cx).is_ready() {
break Poll::Ready(Ok(None));
let conn = self.inner.poll_accept_bi(cx)?;
return match conn {
Poll::Ready(None) | Poll::Pending => {
let done = if conn.is_pending() {
self.recv_closing.is_some() && self.poll_requests_completion(cx).is_ready()
} else {
self.poll_requests_completion(cx).is_ready()
};

if done {
Poll::Ready(Ok(None))
} else {
// Wait for all the requests to be finished, request_end_recv will wake
// us on each request completion.
break Poll::Pending;
}
}
Poll::Pending => {
if self.recv_closing.is_some() && self.poll_requests_completion(cx).is_ready() {
// The connection is now idle.
break Poll::Ready(Ok(None));
} else {
return Poll::Pending;
Poll::Pending
}
}
Poll::Ready(Ok(Some(mut s))) => {
Poll::Ready(Some(mut s)) => {
// When the connection is in a graceful shutdown procedure, reject all
// incoming requests not belonging to the grace interval. It's possible that
// some acceptable request streams arrive after rejected requests.
Expand All @@ -347,7 +337,7 @@
}
self.last_accepted_stream = Some(s.send_id());
self.ongoing_streams.insert(s.send_id());
break Poll::Ready(Ok(Some(s)));
Poll::Ready(Ok(Some(s)))
}
};
}
Expand All @@ -370,7 +360,7 @@
Frame::Settings(_setting) => {
#[cfg(feature = "tracing")]
trace!("Got settings > {:?}", _setting);
()

Check warning on line 363 in h3/src/server/connection.rs

View workflow job for this annotation

GitHub Actions / Lint

unneeded unit expression
}
&Frame::Goaway(id) => self.inner.process_goaway(&mut self.recv_closing, id)?,
_frame @ Frame::MaxPushId(_) | _frame @ Frame::CancelPush(_) => {
Expand Down
Loading