diff --git a/web-transport-quiche/src/ez/driver.rs b/web-transport-quiche/src/ez/driver.rs index 446ef60..594658a 100644 --- a/web-transport-quiche/src/ez/driver.rs +++ b/web-transport-quiche/src/ez/driver.rs @@ -36,6 +36,9 @@ pub(super) struct DriverState { /// The negotiated ALPN protocol, set after the handshake completes. alpn: Option>, + + /// Wakers waiting for the handshake to complete. + handshake_wakers: Vec, } impl DriverState { @@ -58,6 +61,7 @@ impl DriverState { bi: DriverOpen::new(next_bi), uni: DriverOpen::new(next_uni), alpn: None, + handshake_wakers: Vec::new(), } } @@ -78,6 +82,34 @@ impl DriverState { self.alpn.as_deref() } + /// Poll for handshake completion. + /// Returns Ready once the handshake completes, or if the connection is closed. + pub fn poll_handshake(&mut self, waker: &Waker) -> Poll> { + // Check if already established + if self.alpn.is_some() { + return Poll::Ready(Ok(())); + } + + // Check if connection is closed + if let Poll::Ready(err) = self.local.poll(waker) { + return Poll::Ready(Err(err)); + } + if let Poll::Ready(err) = self.remote.poll(waker) { + return Poll::Ready(Err(err)); + } + + // Wait for handshake + self.handshake_wakers.push(waker.clone()); + Poll::Pending + } + + /// Notify all wakers waiting for handshake completion. + /// Should be called when the handshake completes. + #[must_use = "wake the handshake wakers"] + pub fn complete_handshake(&mut self) -> Vec { + std::mem::take(&mut self.handshake_wakers) + } + #[must_use = "wake the driver"] pub fn send(&mut self, stream_id: StreamId) -> Option { if !self.send.insert(stream_id) { @@ -179,12 +211,17 @@ impl Driver { ) -> Result<(), ConnectionError> { // Capture the negotiated ALPN protocol. let alpn = qconn.application_proto(); - self.state.lock().alpn = if alpn.is_empty() { - None - } else { - Some(alpn.to_vec()) + let wakers = { + let mut state = self.state.lock(); + state.alpn = (!alpn.is_empty()).then(|| alpn.to_vec()); + state.complete_handshake() }; + // Wake all tasks waiting for handshake completion. + for waker in wakers { + waker.wake(); + } + // Run poll once to advance any pending operations. match self.poll(Waker::noop(), qconn) { Poll::Ready(Err(e)) => Err(e), diff --git a/web-transport-quiche/src/ez/server.rs b/web-transport-quiche/src/ez/server.rs index 8ae57f6..dc5c19e 100644 --- a/web-transport-quiche/src/ez/server.rs +++ b/web-transport-quiche/src/ez/server.rs @@ -12,7 +12,9 @@ use rustls_pki_types::{CertificateDer, PrivateKeyDer}; use crate::ez::tls::{DynamicCertHook, StaticCertHook}; use crate::ez::DriverState; -use super::{CertResolver, Connection, DefaultMetrics, Driver, Lock, Metrics, Settings}; +use super::{ + CertResolver, Connection, ConnectionError, DefaultMetrics, Driver, Lock, Metrics, Settings, +}; /// Used with [ServerBuilder] to require specific parameters. #[derive(Default)] @@ -201,6 +203,7 @@ impl ServerBuilder { /// The peer address is available before calling [Incoming::accept]. pub struct Incoming { connection: Connection, + driver: Lock, } impl Incoming { @@ -209,9 +212,17 @@ impl Incoming { self.connection.peer_addr() } - /// Accept the connection, starting the post-handshake driver. - pub fn accept(self) -> Connection { - self.connection + /// Accept the connection, waiting for the TLS handshake to complete. + /// + /// Returns the connection once the handshake is complete, or an error if the connection + /// is closed before the handshake finishes. + pub async fn accept(self) -> Result { + use std::future::poll_fn; + + // Wait for handshake to complete + poll_fn(|cx| self.driver.lock().poll_handshake(cx.waker())).await?; + + Ok(self.connection) } } @@ -263,8 +274,11 @@ impl Server { let session = Driver::new(state.clone(), accept_bi.0, accept_uni.0); let inner = initial.start(session); - let connection = Connection::new(inner, state, accept_bi.1, accept_uni.1); - let incoming = Incoming { connection }; + let connection = Connection::new(inner, state.clone(), accept_bi.1, accept_uni.1); + let incoming = Incoming { + connection, + driver: state, + }; if accept.send(incoming).await.is_err() { return Ok(()); diff --git a/web-transport-quiche/src/server.rs b/web-transport-quiche/src/server.rs index f5e2b14..66f7744 100644 --- a/web-transport-quiche/src/server.rs +++ b/web-transport-quiche/src/server.rs @@ -12,6 +12,9 @@ pub enum ServerError { #[error("io error: {0}")] Io(Arc), + #[error("connection error: {0}")] + Connection(#[from] ez::ConnectionError), + #[error("settings error: {0}")] Settings(#[from] h3::SettingsError), @@ -158,13 +161,15 @@ impl Server { loop { tokio::select! { Some(incoming) = self.inner.accept() => { - let conn = incoming.accept(); - self.accept.push(Box::pin(h3::Request::accept(conn))); + self.accept.push(Box::pin(async move { + let conn = incoming.accept().await?; + h3::Request::accept(conn).await + })); } Some(res) = self.accept.next() => { match res { Ok(session) => return Some(session), - Err(err) => tracing::warn!("ignoring failed HTTP/3 handshake: {}", err), + Err(err) => tracing::warn!("ignoring failed handshake: {}", err), } } else => return None,