diff --git a/examples/single_threaded.rs b/examples/single_threaded.rs index 40e8d942b2..f297814c92 100644 --- a/examples/single_threaded.rs +++ b/examples/single_threaded.rs @@ -138,7 +138,7 @@ async fn http1_server() -> Result<(), Box> { loop { let (stream, _) = listener.accept().await?; - let io = TokioIo::new(stream); + let io = IOTypeNotSend::new(TokioIo::new(stream)); let cnt = counter.clone(); @@ -166,7 +166,7 @@ async fn http1_client(url: hyper::Uri) -> Result<(), Box> let addr = format!("{}:{}", host, port); let stream = TcpStream::connect(addr).await?; - let io = TokioIo::new(stream); + let io = IOTypeNotSend::new(TokioIo::new(stream)); let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; diff --git a/examples/upgrades.rs b/examples/upgrades.rs index f9754e5d49..6a3afcd482 100644 --- a/examples/upgrades.rs +++ b/examples/upgrades.rs @@ -107,7 +107,8 @@ async fn client_upgrade_request(addr: SocketAddr) -> Result<()> { let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; tokio::task::spawn(async move { - if let Err(err) = conn.await { + // Don't forget to enable upgrades on the connection. + if let Err(err) = conn.with_upgrades().await { println!("Connection failed: {:?}", err); } }); diff --git a/src/client/conn/http1.rs b/src/client/conn/http1.rs index 60db767a42..600752b33f 100644 --- a/src/client/conn/http1.rs +++ b/src/client/conn/http1.rs @@ -14,7 +14,6 @@ use httparse::ParserConfig; use super::super::dispatch; use crate::body::{Body, Incoming as IncomingBody}; use crate::proto; -use crate::upgrade::Upgraded; type Dispatcher = proto::dispatch::Dispatcher, B, T, proto::h1::ClientTransaction>; @@ -51,15 +50,15 @@ pub struct Parts { #[must_use = "futures do nothing unless polled"] pub struct Connection where - T: Read + Write + Send + 'static, + T: Read + Write + 'static, B: Body + 'static, { - inner: Option>, + inner: Dispatcher, } impl Connection where - T: Read + Write + Send + Unpin + 'static, + T: Read + Write + Unpin + 'static, B: Body + 'static, B::Error: Into>, { @@ -67,7 +66,7 @@ where /// /// Only works for HTTP/1 connections. HTTP/2 connections will panic. pub fn into_parts(self) -> Parts { - let (io, read_buf, _) = self.inner.expect("already upgraded").into_inner(); + let (io, read_buf, _) = self.inner.into_inner(); Parts { io, read_buf, @@ -87,10 +86,7 @@ where /// and [`try_ready!`](https://docs.rs/futures/0.1.25/futures/macro.try_ready.html) /// to work with this function; or use the `without_shutdown` wrapper. pub fn poll_without_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner - .as_mut() - .expect("already upgraded") - .poll_without_shutdown(cx) + self.inner.poll_without_shutdown(cx) } } @@ -119,7 +115,7 @@ pub struct Builder { /// See [`client::conn`](crate::client::conn) for more. pub async fn handshake(io: T) -> crate::Result<(SendRequest, Connection)> where - T: Read + Write + Unpin + Send + 'static, + T: Read + Write + Unpin + 'static, B: Body + 'static, B::Data: Send, B::Error: Into>, @@ -240,9 +236,23 @@ impl fmt::Debug for SendRequest { // ===== impl Connection +impl Connection +where + T: Read + Write + Unpin + Send + 'static, + B: Body + 'static, + B::Error: Into>, +{ + /// Enable this connection to support higher-level HTTP upgrades. + /// + /// See [the `upgrade` module](crate::upgrade) for more. + pub fn with_upgrades(self) -> upgrades::UpgradeableConnection { + upgrades::UpgradeableConnection { inner: Some(self) } + } +} + impl fmt::Debug for Connection where - T: Read + Write + fmt::Debug + Send + 'static, + T: Read + Write + fmt::Debug + 'static, B: Body + 'static, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -252,7 +262,7 @@ where impl Future for Connection where - T: Read + Write + Unpin + Send + 'static, + T: Read + Write + Unpin + 'static, B: Body + 'static, B::Data: Send, B::Error: Into>, @@ -260,18 +270,15 @@ where type Output = crate::Result<()>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match ready!(Pin::new(self.inner.as_mut().unwrap()).poll(cx))? { + match ready!(Pin::new(&mut self.inner).poll(cx))? { proto::Dispatched::Shutdown => Poll::Ready(Ok(())), - proto::Dispatched::Upgrade(pending) => match self.inner.take() { - Some(h1) => { - let (io, buf, _) = h1.into_inner(); - pending.fulfill(Upgraded::new(io, buf)); - Poll::Ready(Ok(())) - } - _ => { - drop(pending); - unreachable!("Upgraded twice"); - } + proto::Dispatched::Upgrade(pending) => { + // With no `Send` bound on `I`, we can't try to do + // upgrades here. In case a user was trying to use + // `upgrade` with this API, send a special + // error letting them know about that. + pending.manual(); + Poll::Ready(Ok(())) }, } } @@ -474,7 +481,7 @@ impl Builder { io: T, ) -> impl Future, Connection)>> where - T: Read + Write + Unpin + Send + 'static, + T: Read + Write + Unpin + 'static, B: Body + 'static, B::Data: Send, B::Error: Into>, @@ -520,8 +527,55 @@ impl Builder { Ok(( SendRequest { dispatch: tx }, - Connection { inner: Some(proto) }, + Connection { inner: proto }, )) } } } + +mod upgrades { + use crate::upgrade::Upgraded; + + use super::*; + + // A future binding a connection with a Service with Upgrade support. + // + // This type is unnameable outside the crate. + #[must_use = "futures do nothing unless polled"] + #[allow(missing_debug_implementations)] + pub struct UpgradeableConnection + where + T: Read + Write + Unpin + Send + 'static, + B: Body + 'static, + B::Error: Into>, + { + pub(super) inner: Option>, + } + + + impl Future for UpgradeableConnection + where + I: Read + Write + Unpin + Send + 'static, + B: Body + 'static, + B::Data: Send, + B::Error: Into>, + { + type Output = crate::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match ready!(Pin::new(&mut self.inner.as_mut().unwrap().inner).poll(cx)) { + Ok(proto::Dispatched::Shutdown) => Poll::Ready(Ok(())), + Ok(proto::Dispatched::Upgrade(pending)) => { + let Parts { + io, + read_buf, + _inner, + } = self.inner.take().unwrap().into_parts(); + pending.fulfill(Upgraded::new(io, read_buf)); + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e)), + } + } + } +}