diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index 6e2378cd68..3f461870cc 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -22,6 +22,7 @@ all-features = true [features] default = ["http2", "tokio-macros"] http2 = ["hyper/http2", "hyper-util/http2"] +http3 = ["s2n-quic", "s2n-quic-h3", "tls"] secrets = ["cookie/private", "cookie/key-expansion"] json = ["serde_json"] msgpack = ["rmp-serde"] @@ -76,8 +77,7 @@ futures = { version = "0.3.30", default-features = false, features = ["std"] } state = "0.6" [dependencies.hyper-util] -git = "https://github.com/SergioBenitez/hyper-util.git" -branch = "fix-readversion" +version = "0.1.3" default-features = false features = ["http1", "server", "tokio"] @@ -99,6 +99,16 @@ version = "0.6.0-dev" path = "../http" features = ["serde"] +[dependencies.s2n-quic] +version = "1.32" +default-features = false +features = ["provider-address-token-default", "provider-tls-rustls"] +optional = true + +[dependencies.s2n-quic-h3] +git = "https://github.com/SergioBenitez/s2n-quic-h3.git" +optional = true + [target.'cfg(unix)'.dependencies] libc = "0.2.149" diff --git a/core/lib/src/data/data_stream.rs b/core/lib/src/data/data_stream.rs index 77d033284a..7026b5484a 100644 --- a/core/lib/src/data/data_stream.rs +++ b/core/lib/src/data/data_stream.rs @@ -3,7 +3,7 @@ use std::task::{Context, Poll}; use std::path::Path; use std::io::{self, Cursor}; -use futures::ready; +use futures::{ready, FutureExt}; use futures::stream::Stream; use tokio::fs::File; use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, ReadBuf, Take}; @@ -65,10 +65,15 @@ pub type BaseReader<'r> = Take>, RawReader<'r>>>; /// Direct reader to the underlying data stream. Not limited in any manner. pub type RawReader<'r> = StreamReader, Bytes>; +#[cfg(feature = "http3")] +use s2n_quic_h3::{self as quic, h3}; + /// Raw underlying data stream. pub enum RawStream<'r> { Empty, - Body(&'r mut HyperBody), + Body(HyperBody), + #[cfg(feature = "http3")] + H3Body(h3::server::RequestStream), Multipart(multer::Field<'r>), } @@ -343,7 +348,17 @@ impl Stream for RawStream<'_> { .poll_frame(cx) .map_ok(|frame| frame.into_data().unwrap_or_else(|_| Bytes::new())) .map_err(io::Error::other) - } + }, + #[cfg(feature = "http3")] + RawStream::H3Body(stream) => { + use bytes::Buf; + + match ready!(stream.poll_recv_data(cx)) { + Ok(Some(mut buf)) => Poll::Ready(Some(Ok(buf.copy_to_bytes(buf.remaining())))), + Ok(None) => Poll::Ready(None), + Err(e) => Poll::Ready(Some(Err(io::Error::other(e)))), + } + }, RawStream::Multipart(s) => Pin::new(s).poll_next(cx).map_err(io::Error::other), RawStream::Empty => Poll::Ready(None), } @@ -356,6 +371,8 @@ impl Stream for RawStream<'_> { let (lower, upper) = (hint.lower(), hint.upper()); (lower as usize, upper.map(|x| x as usize)) }, + #[cfg(feature = "http3")] + RawStream::H3Body(_) => (0, Some(0)), RawStream::Multipart(mp) => mp.size_hint(), RawStream::Empty => (0, Some(0)), } @@ -367,17 +384,26 @@ impl std::fmt::Display for RawStream<'_> { match self { RawStream::Empty => f.write_str("empty stream"), RawStream::Body(_) => f.write_str("request body"), + #[cfg(feature = "http3")] + RawStream::H3Body(_) => f.write_str("http3 quic stream"), RawStream::Multipart(_) => f.write_str("multipart form field"), } } } -impl<'r> From<&'r mut HyperBody> for RawStream<'r> { - fn from(value: &'r mut HyperBody) -> Self { +impl<'r> From for RawStream<'r> { + fn from(value: HyperBody) -> Self { Self::Body(value) } } +#[cfg(feature = "http3")] +impl<'r> From> for RawStream<'r> { + fn from(value: h3::server::RequestStream) -> Self { + Self::H3Body(value) + } +} + impl<'r> From> for RawStream<'r> { fn from(value: multer::Field<'r>) -> Self { Self::Multipart(value) diff --git a/core/lib/src/data/mod.rs b/core/lib/src/data/mod.rs index e3eebdd23c..f7c879dc0d 100644 --- a/core/lib/src/data/mod.rs +++ b/core/lib/src/data/mod.rs @@ -18,3 +18,5 @@ pub use self::capped::{N, Capped}; pub use self::io_stream::{IoHandler, IoStream}; pub use ubyte::{ByteUnit, ToByteUnit}; pub use self::transform::{Transform, TransformBuf}; + +pub(crate) use self::data_stream::RawStream; diff --git a/core/lib/src/erased.rs b/core/lib/src/erased.rs index 7b62522c55..8a14f3bdee 100644 --- a/core/lib/src/erased.rs +++ b/core/lib/src/erased.rs @@ -6,7 +6,6 @@ use std::task::{Poll, Context}; use futures::future::BoxFuture; use http::request::Parts; -use hyper::body::Incoming; use tokio::io::{AsyncRead, ReadBuf}; use crate::data::{Data, IoHandler}; @@ -16,9 +15,9 @@ use crate::{Request, Response, Rocket, Orbit}; // TODO: Write safety proofs. macro_rules! static_assert_covariance { - ($T:tt) => ( + ($($T:tt)*) => ( const _: () = { - fn _assert_covariance<'x: 'y, 'y>(x: &'y $T<'x>) -> &'y $T<'y> { x } + fn _assert_covariance<'x: 'y, 'y>(x: &'y $($T)*<'x>) -> &'y $($T)*<'y> { x } }; ) } @@ -40,7 +39,6 @@ pub struct ErasedResponse { // XXX: SAFETY: This (dependent) field must come first due to drop order! response: Response<'static>, _request: Arc, - _incoming: Box, } impl Drop for ErasedResponse { @@ -81,8 +79,7 @@ impl ErasedRequest { pub async fn into_response( self, - incoming: Incoming, - data_builder: impl for<'r> FnOnce(&'r mut Incoming) -> Data<'r>, + data: Data<'static>, preprocess: impl for<'r, 'x> FnOnce( &'r Rocket, &'r mut Request<'x>, @@ -95,13 +92,11 @@ impl ErasedRequest { Data<'r> ) -> BoxFuture<'r, Response<'r>>, ) -> ErasedResponse { - let mut incoming = Box::new(incoming); - let mut data: Data<'_> = { - let incoming: &mut Incoming = &mut *incoming; - let incoming: &'static mut Incoming = unsafe { transmute(incoming) }; - data_builder(incoming) - }; + // FIXME: UNSAFE. This is incorrect. The following fail: + // static_assert_covariance!(Data); + // static_assert_covariance!(crate::data::RawStream); + let mut data: Data<'_> = data; let mut parent = Arc::new(self); let token: T = { let parent: &mut ErasedRequest = Arc::get_mut(&mut parent).unwrap(); @@ -122,7 +117,6 @@ impl ErasedRequest { ErasedResponse { _request: parent, - _incoming: incoming, response: response, } } diff --git a/core/lib/src/error.rs b/core/lib/src/error.rs index 21753b1f1b..b2a6340260 100644 --- a/core/lib/src/error.rs +++ b/core/lib/src/error.rs @@ -7,7 +7,7 @@ use std::error::Error as StdError; use yansi::Paint; use figment::Profile; -use crate::{Rocket, Orbit}; +use crate::{Ignite, Orbit, Rocket}; /// An error that occurs during launch. /// @@ -89,6 +89,11 @@ pub enum ErrorKind { SentinelAborts(Vec), /// The configuration profile is not debug but no secret key is configured. InsecureSecretKey(Profile), + /// Liftoff failed. Contains the Rocket instance that failed to shutdown. + Liftoff( + Result, Arc>>, + Box + ), /// Shutdown failed. Contains the Rocket instance that failed to shutdown. Shutdown(Arc>), } @@ -121,6 +126,12 @@ impl Error { Error { handled: AtomicBool::new(false), kind } } + pub(crate) fn io_other(e: E) -> Error + where E: Into> + { + Error::from(io::Error::other(e)) + } + #[inline(always)] fn was_handled(&self) -> bool { self.handled.load(Ordering::Acquire) @@ -225,6 +236,11 @@ impl Error { "aborting due to sentinel-triggered abort(s)" } + ErrorKind::Liftoff(_, error) => { + error!("Rocket liftoff faield due to panicking liftoff fairing(s)."); + error_!("{error}"); + "aborting due to failed liftoff" + } ErrorKind::Shutdown(_) => { error!("Rocket failed to shutdown gracefully."); "aborting due to failed shutdown" @@ -246,6 +262,7 @@ impl fmt::Display for ErrorKind { ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f), ErrorKind::Config(_) => "failed to extract configuration".fmt(f), ErrorKind::SentinelAborts(_) => "sentinel(s) aborted".fmt(f), + ErrorKind::Liftoff(_, _) => "liftoff failed".fmt(f), ErrorKind::Shutdown(_) => "shutdown failed".fmt(f), } } @@ -300,33 +317,37 @@ pub(crate) fn log_server_error(error: &Box) { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let error = &self.0; if let Some(e) = error.downcast_ref::() { - write!(f, "request processing failed: {e}")?; + write!(f, "request failed: {e}")?; } else if let Some(e) = error.downcast_ref::() { - write!(f, "connection I/O error: ")?; + write!(f, "connection error: ")?; match e.kind() { io::ErrorKind::NotConnected => write!(f, "remote disconnected")?, io::ErrorKind::UnexpectedEof => write!(f, "remote sent early eof")?, io::ErrorKind::ConnectionReset - | io::ErrorKind::ConnectionAborted - | io::ErrorKind::BrokenPipe => write!(f, "terminated by remote")?, + | io::ErrorKind::ConnectionAborted => write!(f, "terminated by remote")?, _ => write!(f, "{e}")?, } } else { write!(f, "http server error: {error}")?; } - if let Some(e) = error.source() { - write!(f, " ({})", ServerError(e))?; - } - Ok(()) } } + let mut error: &(dyn StdError + 'static) = &**error; if error.downcast_ref::().is_some() { - warn!("{}", ServerError(&**error)) + warn!("{}", ServerError(error)); + while let Some(source) = error.source() { + error = source; + warn_!("{}", ServerError(error)); + } } else { - error!("{}", ServerError(&**error)) + error!("{}", ServerError(error)); + while let Some(source) = error.source() { + error = source; + error_!("{}", ServerError(error)); + } } } diff --git a/core/lib/src/listener/cancellable.rs b/core/lib/src/listener/cancellable.rs index fbabfb2c6d..5998f43d8e 100644 --- a/core/lib/src/listener/cancellable.rs +++ b/core/lib/src/listener/cancellable.rs @@ -54,6 +54,7 @@ pin_project! { } } +#[derive(Debug)] enum State { /// I/O has not been cancelled. Proceed as normal. Active, @@ -97,11 +98,11 @@ pub trait CancellableExt: Sized { impl CancellableExt for L { } fn time_out() -> io::Error { - io::Error::new(io::ErrorKind::TimedOut, "Shutdown grace timed out") + io::Error::new(io::ErrorKind::TimedOut, "shutdown grace period elapsed") } fn gone() -> io::Error { - io::Error::new(io::ErrorKind::BrokenPipe, "IO driver has terminated") + io::Error::new(io::ErrorKind::BrokenPipe, "I/O driver terminated") } impl CancellableListener> @@ -162,7 +163,7 @@ impl Listener for CancellableListener } impl CancellableIo { - fn inner(&self) -> Option<&I> { + pub fn inner(&self) -> Option<&I> { self.io.as_ref() } @@ -171,7 +172,7 @@ impl CancellableIo { mut self: Pin<&mut Self>, cx: &mut Context<'_>, do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll>, - ) -> Poll> { + ) -> Poll>> { let mut me = self.as_mut().project(); let io = match me.io.as_pin_mut() { Some(io) => io, @@ -184,14 +185,14 @@ impl CancellableIo { if me.trigger.as_mut().poll(cx).is_ready() { *me.state = State::Grace(Box::pin(sleep(*me.grace))); } else { - return do_io(io, cx); + return do_io(io, cx).map_ok(Some); } } State::Grace(timer) => { if timer.as_mut().poll(cx).is_ready() { *me.state = State::Mercy(Box::pin(sleep(*me.mercy))); } else { - return do_io(io, cx); + return do_io(io, cx).map_ok(Some); } } State::Mercy(timer) => { @@ -218,7 +219,9 @@ impl AsyncRead for CancellableIo { cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf)) + self.as_mut() + .poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf)) + .map_ok(|ok| ok.unwrap_or_default()) } } @@ -228,21 +231,27 @@ impl AsyncWrite for CancellableIo { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf)) + self.as_mut() + .poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf)) + .map_ok(|ok| ok.unwrap_or_default()) } fn poll_flush( mut self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx)) + self.as_mut() + .poll_trigger_then(cx, |io, cx| io.poll_flush(cx)) + .map_ok(|ok| ok.unwrap_or_default()) } fn poll_shutdown( mut self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx)) + self.as_mut() + .poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx)) + .map_ok(|ok| ok.unwrap_or_default()) } fn poll_write_vectored( @@ -250,7 +259,9 @@ impl AsyncWrite for CancellableIo { cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs)) + self.as_mut() + .poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs)) + .map_ok(|ok| ok.unwrap_or_default()) } fn is_write_vectored(&self) -> bool { diff --git a/core/lib/src/listener/endpoint.rs b/core/lib/src/listener/endpoint.rs index 26640d1d1c..54e22e2c81 100644 --- a/core/lib/src/listener/endpoint.rs +++ b/core/lib/src/listener/endpoint.rs @@ -39,6 +39,7 @@ impl Endpoint { pub fn tcp(&self) -> Option { match self { Endpoint::Tcp(addr) => Some(*addr), + Endpoint::Tls(addr, _) => addr.tcp(), _ => None, } } @@ -46,6 +47,7 @@ impl Endpoint { pub fn unix(&self) -> Option<&Path> { match self { Endpoint::Unix(addr) => Some(addr), + Endpoint::Tls(addr, _) => addr.unix(), _ => None, } } diff --git a/core/lib/src/listener/mod.rs b/core/lib/src/listener/mod.rs index 244c36c604..b656c7fdad 100644 --- a/core/lib/src/listener/mod.rs +++ b/core/lib/src/listener/mod.rs @@ -13,6 +13,8 @@ pub mod unix; #[cfg_attr(nightly, doc(cfg(feature = "tls")))] pub mod tls; pub mod tcp; +#[cfg(feature = "http3")] +pub mod quic; pub use endpoint::*; pub use listener::*; diff --git a/core/lib/src/listener/quic.rs b/core/lib/src/listener/quic.rs new file mode 100644 index 0000000000..4c57993351 --- /dev/null +++ b/core/lib/src/listener/quic.rs @@ -0,0 +1,150 @@ +use std::io; +use std::net::SocketAddr; + +use bytes::Bytes; +use s2n_quic as quic; +use s2n_quic_h3 as quic_h3; +use quic_h3::h3::server as h3; +use s2n_quic::provider::tls::rustls::{rustls, DEFAULT_CIPHERSUITES}; +use s2n_quic::provider::tls::rustls::Server as H3TlsServer; + +use tokio::sync::Mutex; + +use crate::listener::{Bindable, Listener}; +use crate::tls::TlsConfig; + +use super::{Connection, Endpoint}; + +pub struct QuicBindable { + address: SocketAddr, + tls: TlsConfig, +} + +pub struct QuicListener { + listener: Mutex, + local_addr: SocketAddr, +} + +pub struct QuicConnection { + pub(crate) handle: quic::connection::Handle, + pub(crate) stream: QuicStream, +} + +pub struct QuicStream { + inner: h3::RequestStream, Bytes>, +} + +impl Bindable for QuicBindable { + type Listener = QuicListener; + + type Error = io::Error; + + async fn bind(self) -> Result { + // FIXME: Remove this as soon as `s2n_quic` is on rustls 0.22. + let cert_chain = crate::tls::util::load_cert_chain(&mut self.tls.certs_reader().unwrap()) + .unwrap() + .into_iter() + .map(|v| v.to_vec()) + .map(rustls::Certificate) + .collect::>(); + + let key = crate::tls::util::load_key(&mut self.tls.key_reader().unwrap()) + .unwrap() + .secret_der() + .to_vec(); + + let mut tls = rustls::server::ServerConfig::builder() + .with_cipher_suites(DEFAULT_CIPHERSUITES) + .with_safe_default_kx_groups() + .with_safe_default_protocol_versions() + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))? + .with_client_cert_verifier(rustls::server::NoClientAuth::boxed()) + .with_single_cert(cert_chain, rustls::PrivateKey(key)) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?; + + tls.alpn_protocols = vec![b"h3".to_vec()]; + tls.ignore_client_order = self.tls.prefer_server_cipher_order; + tls.session_storage = rustls::server::ServerSessionMemoryCache::new(1024); + tls.ticketer = rustls::Ticketer::new() + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS ticketer: {}", e)))?; + + let listener = quic::Server::builder() + .with_tls(H3TlsServer::new(tls)) + .unwrap_or_else(|e| match e { }) + .with_io(self.address)? + .start() + .map_err(io::Error::other)?; + + let local_addr = listener.local_addr()?; + + Ok(QuicListener { listener: Mutex::new(listener), local_addr }) + } +} + +#[derive(Copy, Clone)] +pub struct Void(pub T); + +impl Listener for QuicListener { + type Accept = quic::Connection; + + type Connection = Void; + + async fn accept(&self) -> io::Result { + self.listener + .lock().await + .accept().await + .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "server closed")) + } + + async fn connect(&self, accept: Self::Accept) -> io::Result { + let addr = accept.handle().local_addr()?; + Ok(Void(addr)) + } + + fn socket_addr(&self) -> io::Result { + Ok(self.local_addr.into()) + } +} + +impl + Send + Sync + Unpin> Connection for Void { + fn peer_address(&self) -> io::Result { + Ok(self.0.clone().into()) + } +} + +mod async_traits { + use std::task::{Context, Poll}; + use std::pin::Pin; + + use super::*; + + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + + impl AsyncRead for Void { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &mut ReadBuf<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl AsyncWrite for Void { + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &[u8], + ) -> Poll> { + Poll::Ready(Ok(0)) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } +} diff --git a/core/lib/src/listener/tls.rs b/core/lib/src/listener/tls.rs index ce2b53ffaf..d2fd81f82e 100644 --- a/core/lib/src/listener/tls.rs +++ b/core/lib/src/listener/tls.rs @@ -27,7 +27,7 @@ pub struct TlsBindable { } impl TlsConfig { - pub(crate) fn acceptor(&self) -> Result { + pub(crate) fn server_config(&self) -> Result { let provider = rustls::crypto::CryptoProvider { cipher_suites: self.ciphers().map(|c| c.into()).collect(), ..rustls::crypto::ring::default_provider() @@ -64,7 +64,7 @@ impl TlsConfig { tls_config.alpn_protocols.insert(0, b"h2".to_vec()); } - Ok(TlsAcceptor::from(Arc::new(tls_config))) + Ok(tls_config) } } @@ -75,7 +75,7 @@ impl Bindable for TlsBindable { async fn bind(self) -> Result { Ok(TlsListener { - acceptor: self.tls.acceptor()?, + acceptor: TlsAcceptor::from(Arc::new(self.tls.server_config()?)), listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?, config: self.tls, }) diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index 40570e7b45..563d1e2a60 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -1,12 +1,14 @@ use std::fmt; use std::ops::{Deref, DerefMut}; +use std::sync::Arc; +use std::time::Duration; use yansi::Paint; use either::Either; use figment::{Figment, Provider}; use crate::{Catcher, Config, Route, Shutdown, sentinel, shield::Shield}; -use crate::listener::{Endpoint, Bindable, DefaultListener}; +use crate::listener::{Bindable, DefaultListener, Endpoint, Listener}; use crate::router::Router; use crate::util::TripWire; use crate::fairing::{Fairing, Fairings}; @@ -688,12 +690,78 @@ impl Rocket { } async fn _launch_on(self, bindable: B) -> Result, Error> { - let listener = bindable.bind().await.map_err(|e| ErrorKind::Bind(Box::new(e)))?; - self.serve(listener).await + let listener = bindable.bind().await + .map_err(|e| ErrorKind::Bind(Box::new(e)))?; + + let rocket = Arc::new(self.into_orbit(listener.socket_addr()?)); + if let Err(e) = tokio::spawn(Rocket::liftoff(rocket.clone())).await { + let rocket = rocket.try_wait_shutdown().await; + return Err(ErrorKind::Liftoff(rocket, Box::new(e)).into()); + } + + // #[cfg(not(feature = "http3"))] { + // rocket.clone().serve(listener).await?; + // } + + // #[cfg(feature = "http3")] { + // // let (r1, r2) = tokio::join!( + // // tokio::task::spawn(rocket.clone().serve(listener)), + // // tokio::task::spawn(rocket.clone().serve3(listener)), + // // ); + // } + + rocket.clone().serve(listener).await?; + Ok(rocket.try_wait_shutdown().await.map_err(ErrorKind::Shutdown)?) } } impl Rocket { + /// Rocket wraps all connections in a `CancellableIo` struct, an internal + /// structure that gracefully closes I/O when it receives a signal. That + /// signal is the `shutdown` future. When the future resolves, + /// `CancellableIo` begins to terminate in grace, mercy, and finally force + /// close phases. Since all connections are wrapped in `CancellableIo`, this + /// eventually ends all I/O. + /// + /// At that point, unless a user spawned an infinite, stand-alone task that + /// isn't monitoring `Shutdown`, all tasks should resolve. This means that + /// all instances of the shared `Arc` are dropped and we can return + /// the owned instance of `Rocket`. + /// + /// Unfortunately, the Hyper `server` future resolves as soon as it has + /// finished processing requests without respect for ongoing responses. That + /// is, `server` resolves even when there are running tasks that are + /// generating a response. So, `server` resolving implies little to nothing + /// about the state of connections. As a result, we depend on the timing of + /// grace + mercy + some buffer to determine when all connections should be + /// closed, thus all tasks should be complete, thus all references to + /// `Arc` should be dropped and we can get back a unique reference. + async fn try_wait_shutdown(self: Arc) -> Result, Arc> { + info!("Shutting down. Waiting for shutdown fairings and pending I/O..."); + tokio::spawn({ + let rocket = self.clone(); + async move { rocket.fairings.handle_shutdown(&*rocket).await } + }); + + let config = &self.config.shutdown; + let wait = Duration::from_micros(250); + for period in [wait, config.grace(), wait, config.mercy(), wait * 4] { + if Arc::strong_count(&self) == 1 { break } + tokio::time::sleep(period).await; + } + + match Arc::try_unwrap(self) { + Ok(rocket) => { + info!("Graceful shutdown completed successfully."); + Ok(rocket.into_ignite()) + } + Err(rocket) => { + warn!("Shutdown failed: outstanding background I/O."); + Err(rocket) + } + } + } + pub(crate) fn into_ignite(self) -> Rocket { Rocket(Igniting { router: self.0.router, diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index 3fbe2ae702..4d2bf25bd4 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -7,32 +7,31 @@ use hyper::service::service_fn; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use hyper_util::server::conn::auto::Builder; use futures::{Future, TryFutureExt, future::{select, Either::*}}; -use tokio::time::sleep; -use crate::{Request, Rocket, Orbit, Data, Ignite}; +use crate::{Data, Orbit, Request, Rocket}; use crate::request::ConnectionMeta; use crate::erased::{ErasedRequest, ErasedResponse, ErasedIoHandler}; use crate::listener::{Listener, CancellableExt, BouncedExt}; -use crate::error::{Error, ErrorKind}; -use crate::data::IoStream; +use crate::data::{IoStream, RawStream}; use crate::util::ReaderStream; use crate::http::Status; +type Result = std::result::Result; + impl Rocket { - async fn service( + async fn service>>( self: Arc, - mut req: hyper::Request, + parts: http::request::Parts, + stream: T, + upgrade: Option, connection: ConnectionMeta, ) -> Result>, http::Error> { - let upgrade = hyper::upgrade::on(&mut req); - let (parts, incoming) = req.into_parts(); let request = ErasedRequest::new(self, parts, |rocket, parts| { Request::from_hyp(rocket, parts, connection).unwrap_or_else(|e| e) }); let mut response = request.into_response( - incoming, - |incoming| Data::from(incoming), + Data::from(stream), |rocket, request, data| Box::pin(rocket.preprocess(request, data)), |token, rocket, request, data| Box::pin(async move { if !request.errors.is_empty() { @@ -46,7 +45,7 @@ impl Rocket { ).await; let io_handler = response.to_io_handler(Rocket::extract_io_handler); - if let Some(handler) = io_handler { + if let (Some(handler), Some(upgrade)) = (io_handler, upgrade) { let upgrade = upgrade.map_ok(IoStream::from).map_err(io::Error::other); tokio::task::spawn(io_handler_task(upgrade, handler)); } @@ -83,8 +82,8 @@ async fn io_handler_task(stream: S, mut handler: ErasedIoHandler) } } -impl Rocket { - pub(crate) async fn serve(self, listener: L) -> Result +impl Rocket { + pub(crate) async fn serve(self: Arc, listener: L) -> Result<()> where L: Listener + 'static { let mut builder = Builder::new(TokioExecutor::new()); @@ -107,17 +106,19 @@ impl Rocket { } let listener = listener.bounced().cancellable(self.shutdown(), &self.config.shutdown); - let rocket = Arc::new(self.into_orbit(listener.socket_addr()?)); - let _ = tokio::spawn(Rocket::liftoff(rocket.clone())).await; - let (server, listener) = (Arc::new(builder), Arc::new(listener)); while let Some(accept) = listener.accept_next().await { - let (listener, rocket, server) = (listener.clone(), rocket.clone(), server.clone()); + let (listener, rocket, server) = (listener.clone(), self.clone(), server.clone()); tokio::spawn({ let result = async move { let conn = TokioIo::new(listener.connect(accept).await?); let meta = ConnectionMeta::from(conn.inner()); - let service = service_fn(|req| rocket.clone().service(req, meta.clone())); + let service = service_fn(|mut req| { + let upgrade = hyper::upgrade::on(&mut req); + let (parts, incoming) = req.into_parts(); + rocket.clone().service(parts, incoming, Some(upgrade), meta.clone()) + }); + let serve = pin!(server.serve_connection_with_upgrades(conn, service)); match select(serve, rocket.shutdown()).await { Left((result, _)) => result, @@ -132,49 +133,58 @@ impl Rocket { }); } - // Rocket wraps all connections in a `CancellableIo` struct, an internal - // structure that gracefully closes I/O when it receives a signal. That - // signal is the `shutdown` future. When the future resolves, - // `CancellableIo` begins to terminate in grace, mercy, and finally - // force close phases. Since all connections are wrapped in - // `CancellableIo`, this eventually ends all I/O. - // - // At that point, unless a user spawned an infinite, stand-alone task - // that isn't monitoring `Shutdown`, all tasks should resolve. This - // means that all instances of the shared `Arc` are dropped and - // we can return the owned instance of `Rocket`. - // - // Unfortunately, the Hyper `server` future resolves as soon as it has - // finished processing requests without respect for ongoing responses. - // That is, `server` resolves even when there are running tasks that are - // generating a response. So, `server` resolving implies little to - // nothing about the state of connections. As a result, we depend on the - // timing of grace + mercy + some buffer to determine when all - // connections should be closed, thus all tasks should be complete, thus - // all references to `Arc` should be dropped and we can get back - // a unique reference. - info!("Shutting down. Waiting for shutdown fairings and pending I/O..."); - tokio::spawn({ - let rocket = rocket.clone(); - async move { rocket.fairings.handle_shutdown(&*rocket).await } - }); + Ok(()) + } +} - let config = &rocket.config.shutdown; - let wait = Duration::from_micros(250); - for period in [wait, config.grace(), wait, config.mercy(), wait * 4] { - if Arc::strong_count(&rocket) == 1 { break } - sleep(period).await; - } +#[cfg(feature = "http3")] +impl Rocket { + pub(crate) async fn serve3(self: Arc, listener: L) -> Result<()> + where L: Listener + 'static + { + use crate::listener::quic::Void; + use tokio_stream::StreamExt; + use s2n_quic_h3 as quic_h3; - match Arc::try_unwrap(rocket) { - Ok(rocket) => { - info!("Graceful shutdown completed successfully."); - Ok(rocket.into_ignite()) - } - Err(rocket) => { - warn!("Shutdown failed: outstanding background I/O."); - Err(Error::new(ErrorKind::Shutdown(rocket))) - } + type H3Conn = quic_h3::h3::server::Connection; + + let listener = listener.bounced().cancellable(self.shutdown(), &self.config.shutdown); + let listener = Arc::new(listener); + while let Some(accept) = listener.accept_next().await { + let rocket = self.clone(); + tokio::spawn({ + let result = async move { + let void = Void(accept.handle().local_addr()?); + let quic_conn = quic_h3::Connection::new(accept); + let mut h3 = H3Conn::new(quic_conn).await.map_err(io::Error::other)?; + while let Some((req, stream)) = h3.accept().await.map_err(io::Error::other)? { + let rocket = rocket.clone(); + tokio::spawn(async move { + let (mut tx, rx) = stream.split(); + let (parts, _) = req.into_parts(); + let response = rocket + .service(parts, rx, None, ConnectionMeta::from(&void)).await + .map_err(io::Error::other)?; + + let (r, mut stream) = response.into_parts(); + let response = http::Response::from_parts(r, ()); + tx.send_response(response).await.map_err(io::Error::other)?; + + while let Some(Ok(bytes)) = stream.next().await { + tx.send_data(bytes).await.map_err(io::Error::other)?; + } + + tx.finish().await.map_err(io::Error::other) + }).await.map_err(io::Error::other)??; + } + + Ok(()) + }; + + result.inspect_err(crate::error::log_server_error) + }); } + + Ok(()) } } diff --git a/core/lib/src/tls/config.rs b/core/lib/src/tls/config.rs index 3131e16d5c..533b5793af 100644 --- a/core/lib/src/tls/config.rs +++ b/core/lib/src/tls/config.rs @@ -427,7 +427,7 @@ impl TlsConfig { } pub fn validate(&self) -> Result<(), crate::tls::Error> { - self.acceptor().map(|_| ()) + self.server_config().map(|_| ()) } } diff --git a/examples/tls/Cargo.toml b/examples/tls/Cargo.toml index 9c72493908..a9a2f5727d 100644 --- a/examples/tls/Cargo.toml +++ b/examples/tls/Cargo.toml @@ -6,5 +6,5 @@ edition = "2021" publish = false [dependencies] -rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets"] } +rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets", "http3"] } yansi = "1.0.0-rc.1" diff --git a/examples/tls/src/main.rs b/examples/tls/src/main.rs index 4ce4254c24..31c8d09cd6 100644 --- a/examples/tls/src/main.rs +++ b/examples/tls/src/main.rs @@ -22,5 +22,5 @@ fn rocket() -> _ { // Run `./private/gen_certs.sh` to generate a CA and key pairs. rocket::build() .mount("/", routes![hello, mutual]) - .attach(redirector::Redirector::on(3000)) + // .attach(redirector::Redirector::on(3000)) }