From 42f57b43cb08db10e1be67df95a39f666bbc4a1a Mon Sep 17 00:00:00 2001 From: Jiahao XU Date: Tue, 22 Aug 2023 03:01:35 +1000 Subject: [PATCH] Fix panic in building h3 client when udp is forbidden (#1945) Fixed #1942 Signed-off-by: Jiahao XU --- src/async_impl/client.rs | 144 +++++++++++++++++----------- src/async_impl/h3_client/connect.rs | 6 +- 2 files changed, 90 insertions(+), 60 deletions(-) diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 23e491734..a6d508032 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -281,6 +281,55 @@ impl ClientBuilder { } let http = HttpConnector::new_with_resolver(DynResolver::new(resolver.clone())); + #[cfg(all(feature = "http3", feature = "__rustls"))] + let build_h3_connector = + |resolver, + tls, + quic_max_idle_timeout: Option, + quic_stream_receive_window, + quic_receive_window, + quic_send_window, + local_address, + http_version_pref: &HttpVersionPref| { + let mut transport_config = TransportConfig::default(); + + if let Some(max_idle_timeout) = quic_max_idle_timeout { + transport_config.max_idle_timeout(Some( + max_idle_timeout.try_into().map_err(error::builder)?, + )); + } + + if let Some(stream_receive_window) = quic_stream_receive_window { + transport_config.stream_receive_window(stream_receive_window); + } + + if let Some(receive_window) = quic_receive_window { + transport_config.receive_window(receive_window); + } + + if let Some(send_window) = quic_send_window { + transport_config.send_window(send_window); + } + + let res = H3Connector::new( + DynResolver::new(resolver), + tls, + local_address, + transport_config, + ); + + match res { + Ok(connector) => Ok(Some(connector)), + Err(err) => { + if let HttpVersionPref::Http3 = http_version_pref { + Err(error::builder(err)) + } else { + Ok(None) + } + } + } + }; + #[cfg(feature = "__tls")] match config.tls { #[cfg(feature = "default-tls")] @@ -367,32 +416,16 @@ impl ClientBuilder { TlsBackend::BuiltRustls(conn) => { #[cfg(feature = "http3")] { - let mut transport_config = TransportConfig::default(); - - if let Some(max_idle_timeout) = config.quic_max_idle_timeout { - transport_config.max_idle_timeout(Some( - max_idle_timeout.try_into().map_err(error::builder)?, - )); - } - - if let Some(stream_receive_window) = config.quic_stream_receive_window { - transport_config.stream_receive_window(stream_receive_window); - } - - if let Some(receive_window) = config.quic_receive_window { - transport_config.receive_window(receive_window); - } - - if let Some(send_window) = config.quic_send_window { - transport_config.send_window(send_window); - } - - h3_connector = Some(H3Connector::new( - DynResolver::new(resolver), + h3_connector = build_h3_connector( + resolver, conn.clone(), + config.quic_max_idle_timeout, + config.quic_stream_receive_window, + config.quic_receive_window, + config.quic_send_window, config.local_address, - transport_config, - )); + &config.http_version_pref, + )?; } Connector::new_rustls_tls( @@ -527,32 +560,16 @@ impl ClientBuilder { { tls.enable_early_data = config.tls_enable_early_data; - let mut transport_config = TransportConfig::default(); - - if let Some(max_idle_timeout) = config.quic_max_idle_timeout { - transport_config.max_idle_timeout(Some( - max_idle_timeout.try_into().map_err(error::builder)?, - )); - } - - if let Some(stream_receive_window) = config.quic_stream_receive_window { - transport_config.stream_receive_window(stream_receive_window); - } - - if let Some(receive_window) = config.quic_receive_window { - transport_config.receive_window(receive_window); - } - - if let Some(send_window) = config.quic_send_window { - transport_config.send_window(send_window); - } - - h3_connector = Some(H3Connector::new( - DynResolver::new(resolver), + h3_connector = build_h3_connector( + resolver, tls.clone(), + config.quic_max_idle_timeout, + config.quic_stream_receive_window, + config.quic_receive_window, + config.quic_send_window, config.local_address, - transport_config, - )); + &config.http_version_pref, + )?; } Connector::new_rustls_tls( @@ -639,11 +656,15 @@ impl ClientBuilder { accepts: config.accepts, #[cfg(feature = "cookies")] cookie_store: config.cookie_store, + // Use match instead of map since config is partially moved + // and it cannot be used in closure #[cfg(feature = "http3")] - h3_client: H3Client::new( - h3_connector.expect("missing HTTP/3 connector"), - config.pool_idle_timeout, - ), + h3_client: match h3_connector { + Some(h3_connector) => { + Some(H3Client::new(h3_connector, config.pool_idle_timeout)) + } + None => None, + }, hyper: builder.build(connector), headers: config.headers, redirect_policy: config.redirect_policy, @@ -1759,10 +1780,10 @@ impl Client { let in_flight = match version { #[cfg(feature = "http3")] - http::Version::HTTP_3 => { + http::Version::HTTP_3 if self.inner.h3_client.is_some() => { let mut req = builder.body(body).expect("valid request parts"); *req.headers_mut() = headers.clone(); - ResponseFuture::H3(self.inner.h3_client.request(req)) + ResponseFuture::H3(self.inner.h3_client.as_ref().unwrap().request(req)) } _ => { let mut req = builder @@ -1986,7 +2007,7 @@ struct ClientRef { headers: HeaderMap, hyper: HyperClient, #[cfg(feature = "http3")] - h3_client: H3Client, + h3_client: Option, redirect_policy: redirect::Policy, referer: bool, request_timeout: Option, @@ -2117,7 +2138,13 @@ impl PendingRequest { .body(body) .expect("valid request parts"); *req.headers_mut() = self.headers.clone(); - ResponseFuture::H3(self.client.h3_client.request(req)) + ResponseFuture::H3( + self.client + .h3_client + .as_ref() + .expect("H3 client must exists, otherwise we can't have a h3 request here") + .request(req), + ) } _ => { let mut req = hyper::Request::builder() @@ -2339,7 +2366,10 @@ impl Future for PendingRequest { .expect("valid request parts"); *req.headers_mut() = headers.clone(); std::mem::swap(self.as_mut().headers(), &mut headers); - ResponseFuture::H3(self.client.h3_client.request(req)) + ResponseFuture::H3(self.client.h3_client + .as_ref() + .expect("H3 client must exists, otherwise we can't have a h3 request here") + .request(req)) } _ => { let mut req = hyper::Request::builder() diff --git a/src/async_impl/h3_client/connect.rs b/src/async_impl/h3_client/connect.rs index daedb6e44..968704713 100644 --- a/src/async_impl/h3_client/connect.rs +++ b/src/async_impl/h3_client/connect.rs @@ -28,7 +28,7 @@ impl H3Connector { tls: rustls::ClientConfig, local_addr: Option, transport_config: TransportConfig, - ) -> H3Connector { + ) -> Result { let mut config = ClientConfig::new(Arc::new(tls)); // FIXME: Replace this when there is a setter. config.transport_config(Arc::new(transport_config)); @@ -38,10 +38,10 @@ impl H3Connector { None => "[::]:0".parse::().unwrap(), }; - let mut endpoint = Endpoint::client(socket_addr).expect("unable to create QUIC endpoint"); + let mut endpoint = Endpoint::client(socket_addr)?; endpoint.set_default_client_config(config); - Self { resolver, endpoint } + Ok(Self { resolver, endpoint }) } pub async fn connect(&mut self, dest: Uri) -> Result {