Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 41 additions & 4 deletions web-transport-quiche/src/ez/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ pub(super) struct DriverState {

/// The negotiated ALPN protocol, set after the handshake completes.
alpn: Option<Vec<u8>>,

/// Wakers waiting for the handshake to complete.
handshake_wakers: Vec<Waker>,
}

impl DriverState {
Expand All @@ -58,6 +61,7 @@ impl DriverState {
bi: DriverOpen::new(next_bi),
uni: DriverOpen::new(next_uni),
alpn: None,
handshake_wakers: Vec::new(),
}
}

Expand All @@ -78,6 +82,34 @@ impl DriverState {
self.alpn.as_deref()
}

/// Poll for handshake completion.
/// Returns Ready once the handshake completes, or if the connection is closed.
pub fn poll_handshake(&mut self, waker: &Waker) -> Poll<Result<(), ConnectionError>> {
// Check if already established
if self.alpn.is_some() {
return Poll::Ready(Ok(()));
}

// Check if connection is closed
if let Poll::Ready(err) = self.local.poll(waker) {
return Poll::Ready(Err(err));
}
if let Poll::Ready(err) = self.remote.poll(waker) {
return Poll::Ready(Err(err));
}

// Wait for handshake
self.handshake_wakers.push(waker.clone());
Poll::Pending
}

/// Notify all wakers waiting for handshake completion.
/// Should be called when the handshake completes.
#[must_use = "wake the handshake wakers"]
pub fn complete_handshake(&mut self) -> Vec<Waker> {
std::mem::take(&mut self.handshake_wakers)
}

#[must_use = "wake the driver"]
pub fn send(&mut self, stream_id: StreamId) -> Option<Waker> {
if !self.send.insert(stream_id) {
Expand Down Expand Up @@ -179,12 +211,17 @@ impl Driver {
) -> Result<(), ConnectionError> {
// Capture the negotiated ALPN protocol.
let alpn = qconn.application_proto();
self.state.lock().alpn = if alpn.is_empty() {
None
} else {
Some(alpn.to_vec())
let wakers = {
let mut state = self.state.lock();
state.alpn = (!alpn.is_empty()).then(|| alpn.to_vec());
state.complete_handshake()
};

// Wake all tasks waiting for handshake completion.
for waker in wakers {
waker.wake();
}

// Run poll once to advance any pending operations.
match self.poll(Waker::noop(), qconn) {
Poll::Ready(Err(e)) => Err(e),
Expand Down
26 changes: 20 additions & 6 deletions web-transport-quiche/src/ez/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use crate::ez::tls::{DynamicCertHook, StaticCertHook};
use crate::ez::DriverState;

use super::{CertResolver, Connection, DefaultMetrics, Driver, Lock, Metrics, Settings};
use super::{
CertResolver, Connection, ConnectionError, DefaultMetrics, Driver, Lock, Metrics, Settings,
};

/// Used with [ServerBuilder] to require specific parameters.
#[derive(Default)]
Expand Down Expand Up @@ -201,6 +203,7 @@ impl<M: Metrics> ServerBuilder<M, ServerWithListener> {
/// The peer address is available before calling [Incoming::accept].
pub struct Incoming {
connection: Connection,
driver: Lock<DriverState>,
}

impl Incoming {
Expand All @@ -209,9 +212,17 @@ impl Incoming {
self.connection.peer_addr()
}

/// Accept the connection, starting the post-handshake driver.
pub fn accept(self) -> Connection {
self.connection
/// Accept the connection, waiting for the TLS handshake to complete.
///
/// Returns the connection once the handshake is complete, or an error if the connection
/// is closed before the handshake finishes.
pub async fn accept(self) -> Result<Connection, ConnectionError> {
use std::future::poll_fn;

// Wait for handshake to complete
poll_fn(|cx| self.driver.lock().poll_handshake(cx.waker())).await?;

Ok(self.connection)
}
}

Expand Down Expand Up @@ -263,8 +274,11 @@ impl<M: Metrics> Server<M> {
let session = Driver::new(state.clone(), accept_bi.0, accept_uni.0);

let inner = initial.start(session);
let connection = Connection::new(inner, state, accept_bi.1, accept_uni.1);
let incoming = Incoming { connection };
let connection = Connection::new(inner, state.clone(), accept_bi.1, accept_uni.1);
let incoming = Incoming {
connection,
driver: state,
};

if accept.send(incoming).await.is_err() {
return Ok(());
Expand Down
11 changes: 8 additions & 3 deletions web-transport-quiche/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ pub enum ServerError {
#[error("io error: {0}")]
Io(Arc<std::io::Error>),

#[error("connection error: {0}")]
Connection(#[from] ez::ConnectionError),

#[error("settings error: {0}")]
Settings(#[from] h3::SettingsError),

Expand Down Expand Up @@ -158,13 +161,15 @@ impl<M: ez::Metrics> Server<M> {
loop {
tokio::select! {
Some(incoming) = self.inner.accept() => {
let conn = incoming.accept();
self.accept.push(Box::pin(h3::Request::accept(conn)));
self.accept.push(Box::pin(async move {
let conn = incoming.accept().await?;
h3::Request::accept(conn).await
}));
}
Some(res) = self.accept.next() => {
match res {
Ok(session) => return Some(session),
Err(err) => tracing::warn!("ignoring failed HTTP/3 handshake: {}", err),
Err(err) => tracing::warn!("ignoring failed handshake: {}", err),
}
}
else => return None,
Expand Down