From c68a919c2eed7aa40431acc89424d207303a826a Mon Sep 17 00:00:00 2001
From: John Howard
Date: Fri, 5 Jul 2024 12:36:18 -0700
Subject: [PATCH] Abstract and fix draining (#1176)

* Abstract and fix draining

* Centralize draining logic in one helper function
* Fix inbound draining (HBONE). Before, we did not shut down the listener upon
  draining. This meant new connections would go to the old ztunnel on a
  ztunnel restart.
* Simplify inbound draining; do not re-create the force shutdown logic, and
  instead let the common abstraction do it (which does it slightly better)
* socks5: add proper draining with force shutdown. Remove double-spawn, which
  adds some complexity around proxy_to_cancellable.
  This is primarily tested in https://github.com/istio/istio/pull/51710, which
  sends a large stream of requests and restarts ztunnel and the backend app
  (2 different tests). With this change, these tests pass. It would be good to
  get more isolated tests in this repo in the future as well.
* Refactor out into its own package
* Add tests for draining
* unclean but forceful shutdown
* fmt
* Fix flakes
* fix flake
---
 Cargo.lock                       |  10 --
 Cargo.toml                       |   1 -
 fuzz/Cargo.lock                  |  10 --
 src/admin.rs                     |   4 +-
 src/app.rs                       |  16 ++-
 src/dns/server.rs                |  20 +--
 src/drain.rs                     | 220 +++++++++++++++++++++++++++++++
 src/hyper_util.rs                |  16 ++-
 src/inpod/protocol.rs            |   8 +-
 src/inpod/statemanager.rs        |  33 +++--
 src/inpod/test_helpers.rs        |   8 +-
 src/inpod/workloadmanager.rs     |  13 +-
 src/lib.rs                       |   1 +
 src/metrics/server.rs            |   4 +-
 src/proxy.rs                     |  11 +-
 src/proxy/connection_manager.rs  |  38 +++---
 src/proxy/h2/server.rs           |  20 +--
 src/proxy/inbound.rs             |  96 ++++++++------
 src/proxy/inbound_passthrough.rs | 111 +++++++---------
 src/proxy/outbound.rs            | 157 ++++++++--------------
 src/proxy/pool.rs                |  14 +-
 src/proxy/socks5.rs              | 126 +++++++++---------
 src/proxyfactory.rs              |   8 +-
 src/readiness/server.rs          |   4 +-
 src/signal.rs                    |   2 +-
 src/test_helpers/dns.rs          |   7 +-
 src/test_helpers/inpod.rs        |   5 +-
 src/test_helpers/linux.rs        |  22 +++-
 src/test_helpers/netns.rs        |  16 +++
 tests/direct.rs                  |   2 -
 tests/namespaced.rs              | 132 +++++++++++++++
 31 files changed, 736 insertions(+), 399 deletions(-)
 create mode 100644 src/drain.rs

diff --git a/Cargo.lock b/Cargo.lock
index 997d3b3bd..561b03c34 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -723,15 +723,6 @@ dependencies = [
  "syn 2.0.60",
 ]
 
-[[package]]
-name = "drain"
-version = "0.1.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9d105028bd2b5dfcb33318fd79a445001ead36004dd8dffef1bdd7e493d8bc1e"
-dependencies = [
- "tokio",
-]
-
 [[package]]
 name = "dtoa"
 version = "1.0.9"
@@ -3769,7 +3760,6 @@ dependencies = [
  "criterion",
  "ctor",
  "diff",
- "drain",
diff --git a/Cargo.toml b/Cargo.toml
index e1036493a..77c106baf 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -40,7 +40,6 @@ base64 = "0.22"
 byteorder = "1.5"
 bytes = { version = "1.5", features = ["serde"] }
 chrono = "0.4"
-drain = "0.1"
 duration-str = "0.7"
 futures = "0.3"
 futures-core = "0.3"
diff --git a/fuzz/Cargo.lock b/fuzz/Cargo.lock
index 7e2f9e278..ac200cf33 100644
--- a/fuzz/Cargo.lock
+++ b/fuzz/Cargo.lock
@@ -559,15 +559,6 @@ dependencies = [
  "syn",
 ]
 
-[[package]]
-name = "drain"
-version = "0.1.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2f1a0abf3fcefad9b4dd0e414207a7408e12b68414a01e6bb19b897d5bd7632d"
-dependencies = [
- "tokio",
-]
-
 [[package]]
 name = "dtoa"
 version = "1.0.9"
@@ -3304,7 +3295,6 @@ dependencies = [
 "byteorder",
 "bytes",
 "chrono",
- "drain",
"duration-str", "flurry", "futures", diff --git a/src/admin.rs b/src/admin.rs index 8ff7c148e..57f667de5 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -23,7 +23,6 @@ use crate::{signal, telemetry}; use base64::engine::general_purpose::STANDARD; use bytes::Bytes; -use drain::Watch; use http_body_util::Full; use hyper::body::Incoming; use hyper::{header::HeaderValue, header::CONTENT_TYPE, Request, Response}; @@ -36,6 +35,7 @@ use std::sync::Arc; use std::time::SystemTime; use std::{net::SocketAddr, time::Duration}; +use crate::drain::DrainWatcher; use tokio::time; use tracing::{error, info, warn}; use tracing_subscriber::filter; @@ -106,7 +106,7 @@ impl Service { config: Arc, proxy_state: DemandProxyState, shutdown_trigger: signal::ShutdownTrigger, - drain_rx: Watch, + drain_rx: DrainWatcher, cert_manager: Arc, ) -> anyhow::Result { Server::::bind( diff --git a/src/app.rs b/src/app.rs index ae4f17b59..46633bf8b 100644 --- a/src/app.rs +++ b/src/app.rs @@ -16,14 +16,14 @@ use std::future::Future; use crate::proxyfactory::ProxyFactory; +use crate::drain; +use anyhow::Context; +use prometheus_client::registry::Registry; use std::net::SocketAddr; use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{mpsc, Arc}; use std::thread; - -use anyhow::Context; -use prometheus_client::registry::Registry; use tokio::task::JoinSet; use tracing::{warn, Instrument}; @@ -45,7 +45,7 @@ pub async fn build_with_cert( // Any component which wants time to gracefully exit should take in a drain_rx clone, // await drain_rx.signaled(), then cleanup. // Note: there is still a hard timeout if the draining takes too long - let (drain_tx, drain_rx) = drain::channel(); + let (drain_tx, drain_rx) = drain::new(); // Register readiness tasks. let ready = readiness::Ready::new(); @@ -320,7 +320,7 @@ fn init_inpod_proxy_mgr( config: &config::Config, proxy_gen: ProxyFactory, ready: readiness::Ready, - drain_rx: drain::Watch, + drain_rx: drain::DrainWatcher, ) -> anyhow::Result + Send + Sync>>> { let metrics = Arc::new(crate::inpod::metrics::Metrics::new( registry.sub_registry_with_prefix("workload_manager"), @@ -349,7 +349,7 @@ pub struct Bound { pub udp_dns_proxy_address: Option, pub shutdown: signal::Shutdown, - drain_tx: drain::Signal, + drain_tx: drain::DrainTrigger, } impl Bound { @@ -359,7 +359,9 @@ impl Bound { // Start a drain; this will attempt to end all connections // or itself be interrupted by a stronger TERM signal, whichever comes first. - self.drain_tx.drain().await; + self.drain_tx + .start_drain_and_wait(drain::DrainMode::Graceful) + .await; Ok(()) } diff --git a/src/dns/server.rs b/src/dns/server.rs index c4bd8133a..fe2877880 100644 --- a/src/dns/server.rs +++ b/src/dns/server.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use drain::Watch; use hickory_proto::error::ProtoErrorKind; use hickory_proto::op::ResponseCode; use hickory_proto::rr::rdata::{A, AAAA, CNAME}; @@ -43,6 +42,7 @@ use crate::dns::metrics::{ }; use crate::dns::name_util::{has_domain, trim_domain}; use crate::dns::resolver::{Answer, Resolver}; +use crate::drain::{DrainMode, DrainWatcher}; use crate::metrics::{DeferRecorder, IncrementRecorder, Recorder}; use crate::proxy::Error; use crate::socket::to_canonical; @@ -65,7 +65,7 @@ pub struct Server { tcp_addr: SocketAddr, udp_addr: SocketAddr, server: ServerFuture, - drain: Watch, + drain: DrainWatcher, } impl Server { @@ -85,7 +85,7 @@ impl Server { state: DemandProxyState, forwarder: Arc, metrics: Arc, - drain: Watch, + drain: DrainWatcher, socket_factory: &(dyn SocketFactory + Send + Sync), allow_unknown_source: bool, ) -> Result { @@ -171,9 +171,11 @@ impl Server { } } } - _ = self.drain.signaled() => { + res = self.drain.wait_for_drain() => { info!("shutting down the DNS server"); - let _ = self.server.shutdown_gracefully().await; + if res.mode() == DrainMode::Graceful { + let _ = self.server.shutdown_gracefully().await; + } } } info!("dns server drained"); @@ -875,7 +877,6 @@ mod tests { use prometheus_client::registry::Registry; use super::*; - use crate::strng; use crate::test_helpers::dns::{ a, aaaa, cname, ip, ipv4, ipv6, n, new_message, new_tcp_client, new_udp_client, run_dns, send_request, server_request, @@ -887,6 +888,7 @@ mod tests { use crate::xds::istio::workload::Service as XdsService; use crate::xds::istio::workload::Workload as XdsWorkload; use crate::xds::istio::workload::{IpFamilies, NetworkAddress as XdsNetworkAddress}; + use crate::{drain, strng}; use crate::{metrics, test_helpers}; const NS1: &str = "ns1"; @@ -1308,7 +1310,7 @@ mod tests { let domain = "cluster.local".to_string(); let state = state(); let forwarder = forwarder(); - let (_signal, drain) = drain::channel(); + let (_signal, drain) = drain::new(); let factory = crate::proxy::DefaultSocketFactory; let proxy = Server::new( domain, @@ -1426,7 +1428,7 @@ mod tests { .await .unwrap(), ); - let (_signal, drain) = drain::channel(); + let (_signal, drain) = drain::new(); let factory = crate::proxy::DefaultSocketFactory; let server = Server::new( domain, @@ -1503,7 +1505,7 @@ mod tests { ips: HashMap::from([(n("large.com."), new_large_response())]), }); let domain = "cluster.local".to_string(); - let (_signal, drain) = drain::channel(); + let (_signal, drain) = drain::new(); let factory = crate::proxy::DefaultSocketFactory; let server = Server::new( domain, diff --git a/src/drain.rs b/src/drain.rs new file mode 100644 index 000000000..23bc49ba0 --- /dev/null +++ b/src/drain.rs @@ -0,0 +1,220 @@ +// Copyright Istio Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::future::Future; +use std::time::Duration; +use tokio::sync::watch; +use tracing::{debug, info, warn}; + +// #[derive(Debug)] +// pub struct DrainTrigger(internal::Signal); + +// impl DrainTrigger { +// /// start_drain_and_wait initiates a draining sequence. The future will not complete until the drain +// /// is complete (all outstanding DrainWatchers are dropped). +// pub async fn start_drain_and_wait(self) { +// self.0.drain().await +// } +// } + +// #[derive(Clone, Debug)] +pub use internal::DrainMode; +pub use internal::ReleaseShutdown as DrainBlocker; +pub use internal::Signal as DrainTrigger; +pub use internal::Watch as DrainWatcher; +// pub struct DrainWatcher(internal::Watch); +// +// impl DrainWatcher { +// /// wait_for_drain will return once a drain has been initiated. +// /// The drain will not complete until the returned DrainBlocker is dropped +// pub async fn wait_for_drain(self) -> DrainBlocker { +// DrainBlocker(self.0.signaled().await) +// } +// } + +// #[allow(dead_code)] +/// DrainBlocker provides a token that must be dropped to unblock the drain. +// pub struct DrainBlocker(internal::ReleaseShutdown); + +/// New constructs a new pair for draining +/// * DrainTrigger can be used to start a draining sequence and wait for it to complete. +/// * DrainWatcher should be held by anything that wants to participate in the draining. This can be cloned, +/// and a drain will not complete until all outstanding DrainWatchers are dropped. +pub fn new() -> (DrainTrigger, DrainWatcher) { + let (tx, rx) = internal::channel(); + (tx, rx) +} + +/// run_with_drain provides a wrapper to run a future with graceful shutdown/draining support. +/// A caller should construct a future with takes two arguments: +/// * drain: while holding onto this, the future is marked as active, which will block the server from shutting down. +/// Additionally, it can be watched (with drain.signaled()) to see when to start a graceful shutdown. +/// * force_shutdown: when this is triggered, the future must forcefully shutdown any ongoing work ASAP. +/// This means the graceful drain exceeded the hard deadline, and all work must terminate now. +/// This is only required for spawned() tasks; otherwise, the future is dropped entirely, canceling all work. +pub async fn run_with_drain( + component: String, + drain: DrainWatcher, + deadline: Duration, + make_future: F, +) where + F: FnOnce(DrainWatcher, watch::Receiver<()>) -> Fut, + Fut: Future, + O: Send + 'static, +{ + let (sub_drain_signal, sub_drain) = new(); + let (trigger_force_shutdown, force_shutdown) = watch::channel(()); + // Stop accepting once we drain. + // We will then allow connections up to `deadline` to terminate on their own. + // After that, they will be forcefully terminated. + let fut = make_future(sub_drain, force_shutdown); + tokio::select! { + _res = fut => {} + res = drain.wait_for_drain() => { + if res.mode() == DrainMode::Graceful { + debug!(component, "drain started, waiting {:?} for any connections to complete", deadline); + if tokio::time::timeout(deadline, sub_drain_signal.start_drain_and_wait(DrainMode::Graceful)).await.is_err() { + // Not all connections completed within time, we will force shut them down + warn!(component, "drain duration expired with pending connections, forcefully shutting down"); + } + } else { + debug!(component, "terminating"); + } + // Trigger force shutdown. In theory, this is only needed in the timeout case. However, + // it doesn't hurt to always trigger it. 
+ let _ = trigger_force_shutdown.send(()); + + info!(component, "shutdown complete"); + drop(res); + } + }; +} + +mod internal { + use tokio::sync::{mpsc, watch}; + + /// Creates a drain channel. + /// + /// The `Signal` is used to start a drain, and the `Watch` will be notified + /// when a drain is signaled. + pub fn channel() -> (Signal, Watch) { + let (signal_tx, signal_rx) = watch::channel(None); + let (drained_tx, drained_rx) = mpsc::channel(1); + + let signal = Signal { + drained_rx, + signal_tx, + }; + let watch = Watch { + drained_tx, + signal_rx, + }; + (signal, watch) + } + + enum Never {} + + #[derive(Debug, Clone, Copy, PartialEq)] + pub enum DrainMode { + Immediate, + Graceful, + } + + /// Send a drain command to all watchers. + pub struct Signal { + drained_rx: mpsc::Receiver, + signal_tx: watch::Sender>, + } + + /// Watch for a drain command. + /// + /// All `Watch` instances must be dropped for a `Signal::signal` call to + /// complete. + #[derive(Clone)] + pub struct Watch { + drained_tx: mpsc::Sender, + signal_rx: watch::Receiver>, + } + + #[must_use = "ReleaseShutdown should be dropped explicitly to release the runtime"] + #[derive(Clone)] + #[allow(dead_code)] + pub struct ReleaseShutdown(mpsc::Sender, DrainMode); + + impl ReleaseShutdown { + pub fn mode(&self) -> DrainMode { + self.1 + } + } + + impl Signal { + /// Waits for all [`Watch`] instances to be dropped. + pub async fn closed(&mut self) { + self.signal_tx.closed().await; + } + + /// Asynchronously signals all watchers to begin draining gracefully and waits for all + /// handles to be dropped. + pub async fn start_drain_and_wait(mut self, mode: DrainMode) { + // Update the state of the signal watch so that all watchers are observe + // the change. + let _ = self.signal_tx.send(Some(mode)); + + // Wait for all watchers to release their drain handle. + match self.drained_rx.recv().await { + None => {} + Some(n) => match n {}, + } + } + } + + impl Watch { + /// Returns a `ReleaseShutdown` handle after the drain has been signaled. The + /// handle must be dropped when a shutdown action has been completed to + /// unblock graceful shutdown. + pub async fn wait_for_drain(mut self) -> ReleaseShutdown { + // This future completes once `Signal::signal` has been invoked so that + // the channel's state is updated. + let mode = self + .signal_rx + .wait_for(Option::is_some) + .await + .map(|mode| mode.expect("already asserted it is_some")) + // If we got an error, then the signal was dropped entirely. Presumably this means a graceful shutdown is not required. + .unwrap_or(DrainMode::Immediate); + + // Return a handle that holds the drain channel, so that the signal task + // is only notified when all handles have been dropped. 
+ ReleaseShutdown(self.drained_tx, mode) + } + } + + impl std::fmt::Debug for Signal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Signal").finish_non_exhaustive() + } + } + + impl std::fmt::Debug for Watch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Watch").finish_non_exhaustive() + } + } + + impl std::fmt::Debug for ReleaseShutdown { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ReleaseShutdown").finish_non_exhaustive() + } + } +} diff --git a/src/hyper_util.rs b/src/hyper_util.rs index 768beb7f2..45cda2d75 100644 --- a/src/hyper_util.rs +++ b/src/hyper_util.rs @@ -22,9 +22,9 @@ use std::{ time::{Duration, Instant}, }; -use crate::config; +use crate::drain::DrainWatcher; +use crate::{config, proxy}; use bytes::Bytes; -use drain::Watch; use futures_util::TryFutureExt; use http_body_util::Full; use hyper::client; @@ -46,6 +46,9 @@ pub fn tls_server( tls_listener::builder(crate::tls::InboundAcceptor::new(cert_provider)) .listen(listener) + .take_while(|item| { + !matches!(item, Err(tls_listener::Error::ListenerError(e)) if proxy::util::is_runtime_shutdown(e)) + }) .filter_map(|conn| { // Avoid 'By default, if a client fails the TLS handshake, that is treated as an error, and the TlsListener will return an Err' match conn { @@ -181,7 +184,7 @@ pub fn plaintext_response(code: hyper::StatusCode, body: String) -> Response { name: String, binds: Vec, - drain_rx: Watch, + drain_rx: DrainWatcher, state: S, } @@ -189,7 +192,7 @@ impl Server { pub async fn bind( name: &str, addrs: config::Address, - drain_rx: Watch, + drain_rx: DrainWatcher, s: S, ) -> anyhow::Result { let mut binds = vec![]; @@ -240,7 +243,7 @@ impl Server { let f = f.clone(); tokio::spawn(async move { let stream = tokio_stream::wrappers::TcpListenerStream::new(bind); - let mut stream = stream.take_until(Box::pin(drain_stream.signaled())); + let mut stream = stream.take_until(Box::pin(drain_stream.wait_for_drain())); while let Some(Ok(socket)) = stream.next().await { socket.set_nodelay(true).unwrap(); let drain = drain_connections.clone(); @@ -267,7 +270,8 @@ impl Server { }), ); // Wait for drain to signal or connection serving to complete - match futures_util::future::select(Box::pin(drain.signaled()), serve).await + match futures_util::future::select(Box::pin(drain.wait_for_drain()), serve) + .await { // We got a shutdown request. Start gracful shutdown and wait for the pending requests to complete. futures_util::future::Either::Left((_shutdown, mut serve)) => { diff --git a/src/inpod/protocol.rs b/src/inpod/protocol.rs index 18fa7dcea..698a5889f 100644 --- a/src/inpod/protocol.rs +++ b/src/inpod/protocol.rs @@ -14,7 +14,7 @@ use super::istio::zds::{self, Ack, Version, WorkloadRequest, WorkloadResponse, ZdsHello}; use super::{WorkloadData, WorkloadMessage}; -use drain::Watch; +use crate::drain::DrainWatcher; use nix::sys::socket::{recvmsg, sendmsg, ControlMessageOwned, MsgFlags}; use prost::Message; use std::io::{IoSlice, IoSliceMut}; @@ -28,12 +28,12 @@ use zds::workload_request::Payload; #[allow(dead_code)] pub struct WorkloadStreamProcessor { stream: UnixStream, - drain: Watch, + drain: DrainWatcher, } #[allow(dead_code)] impl WorkloadStreamProcessor { - pub fn new(stream: UnixStream, drain: Watch) -> Self { + pub fn new(stream: UnixStream, drain: DrainWatcher) -> Self { WorkloadStreamProcessor { stream, drain } } @@ -91,7 +91,7 @@ impl WorkloadStreamProcessor { let res = loop { tokio::select! 
{ biased; // check drain first, so we don't read from the socket if we are draining. - _ = self.drain.clone().signaled() => { + _ = self.drain.clone().wait_for_drain() => { info!("workload proxy manager: drain requested"); return Ok(None); } diff --git a/src/inpod/statemanager.rs b/src/inpod/statemanager.rs index 9cf42ffac..7bb18f5f7 100644 --- a/src/inpod/statemanager.rs +++ b/src/inpod/statemanager.rs @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use drain::Signal; +use crate::drain; +use crate::drain::DrainTrigger; use std::sync::Arc; use tracing::{debug, info, Instrument}; @@ -28,7 +29,7 @@ use super::WorkloadUid; // Note: we can't drain on drop, as drain is async (it waits for the drain to finish). pub(super) struct WorkloadState { - drain: Signal, + drain: DrainTrigger, workload_netns_inode: libc::ino_t, } @@ -38,8 +39,13 @@ struct DrainingTasks { } impl DrainingTasks { - fn drain_workload(&mut self, workload_state: WorkloadState) { - let handle = tokio::spawn(workload_state.drain.drain()); + fn shutdown_workload(&mut self, workload_state: WorkloadState) { + // Workload is gone, so no need to gracefully clean it up + let handle = tokio::spawn( + workload_state + .drain + .start_drain_and_wait(drain::DrainMode::Immediate), + ); // before we push to draining, try to clear done entries, so the vector doesn't grow too much self.draining.retain(|x| !x.is_finished()); // add deleted pod to draining. we do this so we make sure to wait for it incase we @@ -151,7 +157,10 @@ impl WorkloadProxyManagerState { Ok(()) } WorkloadMessage::DelWorkload(workload_uid) => { - info!(uid = workload_uid.0, "pod delete request, draining proxy"); + info!( + uid = workload_uid.0, + "pod delete request, shutting down proxy" + ); if !self.snapshot_received { // TODO: consider if this is an error. if not, do this instead: // self.snapshot_names.remove(&workload_uid) @@ -183,17 +192,17 @@ impl WorkloadProxyManagerState { .workload_states .extract_if(|uid, _| !self.snapshot_names.contains(uid)) { - self.draining.drain_workload(workload_state); + self.draining.shutdown_workload(workload_state); } self.snapshot_names.clear(); self.update_proxy_count_metrics(); } pub async fn drain(self) { - let drain_futures = self - .workload_states - .into_iter() - .map(|(_, v)| v.drain.drain() /* do not .await here!!! */); + let drain_futures = + self.workload_states.into_iter().map(|(_, v)| { + v.drain.start_drain_and_wait(drain::DrainMode::Graceful) + } /* do not .await here!!! */); // join these first, as we need to drive these to completion futures::future::join_all(drain_futures).await; // these are join handles that are driven by tokio, we just need to wait for them, so join these @@ -258,7 +267,7 @@ impl WorkloadProxyManagerState { // We create a per workload drain here. 
If the main loop in WorkloadProxyManager::run drains, // we drain all these per-workload drains before exiting the loop let workload_netns_inode = netns.workload_inode(); - let (drain_tx, drain_rx) = drain::channel(); + let (drain_tx, drain_rx) = drain::new(); let proxies = self .proxy_gen @@ -347,7 +356,7 @@ impl WorkloadProxyManagerState { self.update_proxy_count_metrics(); - self.draining.drain_workload(workload_state); + self.draining.shutdown_workload(workload_state); } fn update_proxy_count_metrics(&self) { diff --git a/src/inpod/test_helpers.rs b/src/inpod/test_helpers.rs index d14f2082d..eddecd12c 100644 --- a/src/inpod/test_helpers.rs +++ b/src/inpod/test_helpers.rs @@ -30,6 +30,8 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use super::istio::zds::{WorkloadRequest, WorkloadResponse, ZdsHello}; +use crate::drain; +use crate::drain::{DrainTrigger, DrainWatcher}; use once_cell::sync::Lazy; use std::os::fd::{AsRawFd, OwnedFd}; use tracing::debug; @@ -42,8 +44,8 @@ pub struct Fixture { pub proxy_factory: ProxyFactory, pub ipc: InPodConfig, pub inpod_metrics: Arc, - pub drain_tx: drain::Signal, - pub drain_rx: drain::Watch, + pub drain_tx: DrainTrigger, + pub drain_rx: DrainWatcher, } // Ensure that the `tracing` stack is only initialised once using `once_cell` static UNSHARE: Lazy<()> = Lazy::new(|| { @@ -70,7 +72,7 @@ impl Default for Fixture { let cert_manager: Arc = crate::identity::mock::new_secret_manager(std::time::Duration::from_secs(10)); let metrics = Arc::new(crate::proxy::Metrics::new(&mut registry)); - let (drain_tx, drain_rx) = drain::channel(); + let (drain_tx, drain_rx) = drain::new(); let dstate = DemandProxyState::new( state.clone(), diff --git a/src/inpod/workloadmanager.rs b/src/inpod/workloadmanager.rs index 492f98b6f..831809fce 100644 --- a/src/inpod/workloadmanager.rs +++ b/src/inpod/workloadmanager.rs @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +use crate::drain::DrainWatcher; use crate::readiness; use backoff::{backoff::Backoff, ExponentialBackoff}; -use drain::Watch; use std::path::PathBuf; use std::time::Duration; use tokio::net::UnixStream; @@ -158,7 +158,7 @@ impl WorkloadProxyManager { Ok(mgr) } - pub async fn run(mut self, drain: Watch) -> Result<(), anyhow::Error> { + pub async fn run(mut self, drain: DrainWatcher) -> Result<(), anyhow::Error> { self.run_internal(drain).await?; // We broke the loop, this can only happen when drain was signaled @@ -174,7 +174,7 @@ impl WorkloadProxyManager { // - we have a ProtocolError (we have a serious version mismatch) // We should never _have_ a protocol error as the gRPC proto should be forwards+backwards compatible, // so this is mostly a safeguard - async fn run_internal(&mut self, drain: Watch) -> Result<(), anyhow::Error> { + async fn run_internal(&mut self, drain: DrainWatcher) -> Result<(), anyhow::Error> { // for now just drop block_ready, until we support knowing that our state is in sync. debug!("workload proxy manager is running"); // hold the release shutdown until we are done with `state.drain` below. @@ -183,7 +183,7 @@ impl WorkloadProxyManager { // Accept a connection let stream = tokio::select! 
{ biased; // check the drain first - rs = drain.clone().signaled() => { + rs = drain.clone().wait_for_drain() => { info!("drain requested"); break rs; } @@ -390,6 +390,7 @@ pub(crate) mod tests { send_workload_added, send_workload_del, uid, }; + use crate::drain::DrainTrigger; use std::{collections::HashSet, sync::Arc}; fn assert_end_stream(res: Result<(), Error>) { @@ -412,8 +413,8 @@ pub(crate) mod tests { struct Fixture { state: WorkloadProxyManagerState, inpod_metrics: Arc, - drain_rx: drain::Watch, - _drain_tx: drain::Signal, + drain_rx: DrainWatcher, + _drain_tx: DrainTrigger, } macro_rules! fixture { diff --git a/src/lib.rs b/src/lib.rs index 7d21ed277..377edab51 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ pub mod cert_fetcher; pub mod config; pub mod copy; pub mod dns; +pub mod drain; pub mod hyper_util; pub mod identity; #[cfg(target_os = "linux")] diff --git a/src/metrics/server.rs b/src/metrics/server.rs index b14daee46..bd992360a 100644 --- a/src/metrics/server.rs +++ b/src/metrics/server.rs @@ -16,7 +16,6 @@ use bytes::Bytes; use std::sync::Mutex; use std::{net::SocketAddr, sync::Arc}; -use drain::Watch; use http_body_util::Full; use hyper::body::Incoming; use hyper::{Request, Response}; @@ -24,6 +23,7 @@ use prometheus_client::encoding::text::encode; use prometheus_client::registry::Registry; use crate::config::Config; +use crate::drain::DrainWatcher; use crate::hyper_util; pub struct Server { @@ -33,7 +33,7 @@ pub struct Server { impl Server { pub async fn new( config: Arc, - drain_rx: Watch, + drain_rx: DrainWatcher, registry: Registry, ) -> anyhow::Result { hyper_util::Server::>::bind( diff --git a/src/proxy.rs b/src/proxy.rs index b48f95b24..23be4821e 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -20,7 +20,6 @@ use std::sync::Arc; use std::time::Duration; use std::{fmt, io}; -use drain::Watch; use hickory_proto::error::ProtoError; use rand::Rng; @@ -35,6 +34,7 @@ pub use metrics::*; use crate::identity::{Identity, SecretManager}; use crate::dns::resolver::Resolver; +use crate::drain::DrainWatcher; use crate::proxy::connection_manager::{ConnectionManager, PolicyWatcher}; use crate::proxy::inbound_passthrough::InboundPassthrough; use crate::proxy::outbound::Outbound; @@ -55,7 +55,7 @@ pub mod metrics; mod outbound; pub mod pool; mod socks5; -mod util; +pub mod util; pub trait SocketFactory { fn new_tcp_v4(&self) -> std::io::Result; @@ -204,7 +204,7 @@ impl Proxy { state: DemandProxyState, cert_manager: Arc, metrics: Metrics, - drain: Watch, + drain: DrainWatcher, resolver: Option>, ) -> Result { let metrics = Arc::new(metrics); @@ -224,7 +224,10 @@ impl Proxy { } #[allow(unused_mut)] - pub(super) async fn from_inputs(mut pi: Arc, drain: Watch) -> Result { + pub(super) async fn from_inputs( + mut pi: Arc, + drain: DrainWatcher, + ) -> Result { // We setup all the listeners first so we can capture any errors that should block startup let inbound = Inbound::new(pi.clone(), drain.clone()).await?; diff --git a/src/proxy/connection_manager.rs b/src/proxy/connection_manager.rs index a88949192..a70357b6d 100644 --- a/src/proxy/connection_manager.rs +++ b/src/proxy/connection_manager.rs @@ -16,7 +16,6 @@ use crate::proxy::Error; use crate::state::DemandProxyState; use crate::state::ProxyRbacContext; -use drain; use serde::{Serialize, Serializer}; use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; @@ -24,6 +23,8 @@ use std::fmt::Formatter; use std::future::Future; use std::net::SocketAddr; +use crate::drain; +use 
crate::drain::{DrainTrigger, DrainWatcher}; use std::sync::Arc; use std::sync::RwLock; use tracing::{debug, error, info, warn}; @@ -32,14 +33,14 @@ struct ConnectionDrain { // TODO: this should almost certainly be changed to a type which has counted references exposed. // tokio::sync::watch can be subscribed without taking a write lock and exposes references // and also a receiver_count method - tx: drain::Signal, - rx: drain::Watch, + tx: DrainTrigger, + rx: DrainWatcher, count: usize, } impl ConnectionDrain { fn new() -> Self { - let (tx, rx) = drain::channel(); + let (tx, rx) = drain::new(); ConnectionDrain { tx, rx, count: 1 } } @@ -47,8 +48,10 @@ impl ConnectionDrain { // always inline, this is for convenience so that we don't forget to drop the rx but there's really no reason it needs to grow the stack #[inline(always)] async fn drain(self) { - drop(self.rx); // very important, drain cannont complete if there are outstand rx - self.tx.drain().await; + drop(self.rx); // very important, drain cannot complete if there are outstand rx + self.tx + .start_drain_and_wait(drain::DrainMode::Immediate) + .await; } } @@ -76,7 +79,7 @@ impl Default for ConnectionManager { pub struct ConnectionGuard { cm: ConnectionManager, conn: InboundConnection, - watch: Option, + watch: Option, } impl ConnectionGuard { @@ -90,7 +93,7 @@ impl ConnectionGuard { self.cm.release(&self.conn); res } - _signaled = watch.signaled() => Err(Error::AuthorizationPolicyLateRejection) + _signaled = watch.wait_for_drain() => Err(Error::AuthorizationPolicyLateRejection) } } } @@ -194,7 +197,7 @@ impl ConnectionManager { // this must be done before a connection can be tracked // allows policy to be asserted against the connection // even no tasks have a receiver channel yet - fn register(&self, c: &InboundConnection) -> Option { + fn register(&self, c: &InboundConnection) -> Option { match self.drains.write().expect("mutex").entry(c.clone()) { Entry::Occupied(mut cd) => { cd.get_mut().count += 1; @@ -282,14 +285,14 @@ impl Serialize for ConnectionManager { pub struct PolicyWatcher { state: DemandProxyState, - stop: drain::Watch, + stop: DrainWatcher, connection_manager: ConnectionManager, } impl PolicyWatcher { pub fn new( state: DemandProxyState, - stop: drain::Watch, + stop: DrainWatcher, connection_manager: ConnectionManager, ) -> Self { PolicyWatcher { @@ -303,7 +306,7 @@ impl PolicyWatcher { let mut policies_changed = self.state.read().policies.subscribe(); loop { tokio::select! 
{ - _ = self.stop.clone().signaled() => { + _ = self.stop.clone().wait_for_drain() => { break; } _ = policies_changed.changed() => { @@ -322,7 +325,8 @@ impl PolicyWatcher { #[cfg(test)] mod tests { - use drain::Watch; + use crate::drain; + use crate::drain::DrainWatcher; use hickory_resolver::config::{ResolverConfig, ResolverOpts}; use prometheus_client::registry::Registry; use std::net::{Ipv4Addr, SocketAddrV4}; @@ -559,7 +563,7 @@ mod tests { metrics, ); let connection_manager = ConnectionManager::default(); - let (tx, stop) = drain::channel(); + let (tx, stop) = drain::new(); let state_mutator = ProxyStateUpdateMutator::new_no_fetch(); // clones to move into spawned task @@ -620,12 +624,12 @@ mod tests { } // release lock // send the signal which stops policy watcher - tx.drain().await; + tx.start_drain_and_wait(drain::DrainMode::Immediate).await; } // small helper to assert that the Watches are working in a timely manner - async fn assert_close(c: Watch) { - let result = tokio::time::timeout(Duration::from_secs(1), c.signaled()).await; + async fn assert_close(c: DrainWatcher) { + let result = tokio::time::timeout(Duration::from_secs(1), c.wait_for_drain()).await; assert!(result.is_ok()) } } diff --git a/src/proxy/h2/server.rs b/src/proxy/h2/server.rs index 52c49ff56..9acd9fe62 100644 --- a/src/proxy/h2/server.rs +++ b/src/proxy/h2/server.rs @@ -13,6 +13,7 @@ // limitations under the License. use crate::config; +use crate::drain::DrainWatcher; use crate::proxy::Error; use bytes::Bytes; use futures_util::FutureExt; @@ -22,8 +23,7 @@ use std::future::Future; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::net::TcpStream; -use tokio::sync::oneshot; -use tokio::time::timeout; +use tokio::sync::{oneshot, watch}; use tracing::{debug, warn}; pub struct H2Request { @@ -75,7 +75,8 @@ impl H2Request { pub async fn serve_connection( cfg: Arc, s: tokio_rustls::server::TlsStream, - drain: drain::Watch, + drain: DrainWatcher, + mut force_shutdown: watch::Receiver<()>, handler: F, ) -> Result<(), Error> where @@ -83,7 +84,6 @@ where Fut: Future + Send + 'static, { let mut builder = h2::server::Builder::new(); - let drain_deadline = cfg.self_termination_deadline; let mut conn = builder .initial_window_size(cfg.window_size) .initial_connection_window_size(cfg.connection_window_size) @@ -138,7 +138,7 @@ where conn.abrupt_shutdown(h2::Reason::NO_ERROR); break } - _shutdown = drain.signaled() => { + _shutdown = drain.wait_for_drain() => { debug!("starting graceful drain..."); conn.graceful_shutdown(); break; @@ -148,9 +148,13 @@ where // Signal to the ping_pong it should also stop. dropped.store(true, Ordering::Relaxed); let poll_closed = futures_util::future::poll_fn(move |cx| conn.poll_closed(cx)); - timeout(drain_deadline, poll_closed) - .await - .map_err(|_| Error::DrainTimeOut)??; + tokio::select! 
{ + _ = force_shutdown.changed() => { + return Err(Error::DrainTimeOut) + } + _ = poll_closed => {} + } + // Mark we are done with the connection drop(drain); Ok(()) } diff --git a/src/proxy/inbound.rs b/src/proxy/inbound.rs index 32d9993e5..d51224227 100644 --- a/src/proxy/inbound.rs +++ b/src/proxy/inbound.rs @@ -16,12 +16,12 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Instant; -use drain::Watch; use futures::stream::StreamExt; use http::{Method, Response, StatusCode}; use tokio::net::TcpStream; +use tokio::sync::watch; use tracing::{debug, info, instrument, trace_span, Instrument}; @@ -29,6 +29,7 @@ use super::{Error, ScopedSecretManager}; use crate::baggage::parse_baggage_header; use crate::identity::Identity; +use crate::drain::DrainWatcher; use crate::proxy::h2::server::H2Request; use crate::proxy::metrics::{ConnectionOpen, Reporter}; use crate::proxy::{metrics, ProxyInputs, TraceParent, BAGGAGE_HEADER, TRACEPARENT_HEADER}; @@ -39,6 +40,7 @@ use crate::state::workload::address::Address; use crate::state::workload::application_tunnel::Protocol as AppProtocol; use crate::{assertions, copy, proxy, socket, strng, tls}; +use crate::drain::run_with_drain; use crate::proxy::h2; use crate::state::workload::{self, NetworkAddress, Workload}; use crate::state::DemandProxyState; @@ -47,13 +49,13 @@ use crate::tls::TlsError; pub(super) struct Inbound { listener: socket::Listener, - drain: Watch, + drain: DrainWatcher, pi: Arc, enable_orig_src: bool, } impl Inbound { - pub(super) async fn new(pi: Arc, drain: Watch) -> Result { + pub(super) async fn new(pi: Arc, drain: DrainWatcher) -> Result { let listener = pi .socket_factory .tcp_bind(pi.cfg.inbound_addr) @@ -79,6 +81,7 @@ impl Inbound { } pub(super) async fn run(self) { + let pi = self.pi.clone(); let acceptor = InboundCertProvider { state: self.pi.state.clone(), cert_manager: self.pi.cert_manager.clone(), @@ -87,46 +90,53 @@ impl Inbound { // Safety: we set nodelay directly in tls_server, so it is safe to convert to a normal listener. // Although, that is *after* the TLS handshake; in theory we may get some benefits to setting it earlier. 
- let stream = crate::hyper_util::tls_server(acceptor, self.listener.inner()); - let mut stream = stream.take_until(Box::pin(self.drain.signaled())); - - let (sub_drain_signal, sub_drain) = drain::channel(); - - while let Some(tls) = stream.next().await { - let pi = self.pi.clone(); - let (raw_socket, ssl) = tls.get_ref(); - let src_identity: Option = tls::identity_from_connection(ssl); - let dst = crate::socket::orig_dst_addr_or_default(raw_socket); - let src = to_canonical(raw_socket.peer_addr().expect("peer_addr available")); - let drain = sub_drain.clone(); - let network = pi.cfg.network.clone(); - let serve_client = async move { - let conn = Connection { - src_identity, - src, - dst_network: strng::new(&network), // inbound request must be on our network - dst, - }; - debug!(%conn, "accepted connection"); - let cfg = pi.cfg.clone(); - let request_handler = move |req| { - Self::serve_connect(pi.clone(), conn.clone(), self.enable_orig_src, req) - }; - let serve = Box::pin(h2::server::serve_connection( - cfg, - tls, - drain, - request_handler, - )); - serve.await - }; - assertions::size_between_ref(1000, 1500, &serve_client); - tokio::task::spawn(serve_client.in_current_span()); - } - info!("draining connections"); - drop(sub_drain); // sub_drain_signal.drain() will never resolve while sub_drain is valid, will deadlock if not dropped - sub_drain_signal.drain().await; - info!("all inbound connections drained"); + let mut stream = crate::hyper_util::tls_server(acceptor, self.listener.inner()); + + let accept = |drain: DrainWatcher, force_shutdown: watch::Receiver<()>| { + async move { + while let Some(tls) = stream.next().await { + let pi = self.pi.clone(); + let (raw_socket, ssl) = tls.get_ref(); + let src_identity: Option = tls::identity_from_connection(ssl); + let dst = crate::socket::orig_dst_addr_or_default(raw_socket); + let src = to_canonical(raw_socket.peer_addr().expect("peer_addr available")); + let drain = drain.clone(); + let force_shutdown = force_shutdown.clone(); + let network = pi.cfg.network.clone(); + let serve_client = async move { + let conn = Connection { + src_identity, + src, + dst_network: strng::new(&network), // inbound request must be on our network + dst, + }; + debug!(%conn, "accepted connection"); + let cfg = pi.cfg.clone(); + let request_handler = move |req| { + Self::serve_connect(pi.clone(), conn.clone(), self.enable_orig_src, req) + }; + let serve = Box::pin(h2::server::serve_connection( + cfg, + tls, + drain, + force_shutdown, + request_handler, + )); + serve.await + }; + assertions::size_between_ref(1000, 1500, &serve_client); + tokio::task::spawn(serve_client.in_current_span()); + } + } + }; + + run_with_drain( + "inbound".to_string(), + self.drain, + pi.cfg.self_termination_deadline, + accept, + ) + .await } fn extract_traceparent(req: &H2Request) -> TraceParent { diff --git a/src/proxy/inbound_passthrough.rs b/src/proxy/inbound_passthrough.rs index 5d297c45d..f5f8144b3 100644 --- a/src/proxy/inbound_passthrough.rs +++ b/src/proxy/inbound_passthrough.rs @@ -16,15 +16,15 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Instant; -use drain::Watch; use tokio::net::TcpStream; use tokio::sync::watch; -use tokio::time::timeout; -use tracing::{debug, error, info, trace, warn, Instrument}; +use tracing::{debug, error, info, trace, Instrument}; use crate::config::ProxyMode; +use crate::drain::run_with_drain; +use crate::drain::DrainWatcher; use crate::proxy::metrics::Reporter; use crate::proxy::Error; use crate::proxy::{metrics, util, 
ProxyInputs}; @@ -35,14 +35,14 @@ use crate::{proxy, socket}; pub(super) struct InboundPassthrough { listener: socket::Listener, pi: Arc, - drain: Watch, + drain: DrainWatcher, enable_orig_src: bool, } impl InboundPassthrough { pub(super) async fn new( pi: Arc, - drain: Watch, + drain: DrainWatcher, ) -> Result { let listener = pi .socket_factory @@ -66,72 +66,55 @@ impl InboundPassthrough { } pub(super) async fn run(self) { - let (sub_drain_signal, sub_drain) = drain::channel(); - let deadline = self.pi.cfg.self_termination_deadline; - let (trigger_force_shutdown, force_shutdown) = watch::channel(()); - let accept = async move { - loop { - // Asynchronously wait for an inbound socket. - let socket = self.listener.accept().await; - let start = Instant::now(); - let mut force_shutdown = force_shutdown.clone(); - let drain = sub_drain.clone(); - let pi = self.pi.clone(); - match socket { - Ok((stream, remote)) => { - let serve_client = async move { - debug!(dur=?start.elapsed(), "inbound passthrough connection started"); - // Since this task is spawned, make sure we are guaranteed to terminate - tokio::select! { - _ = force_shutdown.changed() => { - debug!("inbound passthrough connection forcefully terminated signaled"); + let pi = self.pi.clone(); + let accept = |drain: DrainWatcher, force_shutdown: watch::Receiver<()>| { + async move { + loop { + // Asynchronously wait for an inbound socket. + let socket = self.listener.accept().await; + let start = Instant::now(); + let mut force_shutdown = force_shutdown.clone(); + let drain = drain.clone(); + let pi = self.pi.clone(); + match socket { + Ok((stream, remote)) => { + let serve_client = async move { + debug!(component="inbound passthrough", "connection started"); + // Since this task is spawned, make sure we are guaranteed to terminate + tokio::select! { + _ = force_shutdown.changed() => { + debug!(component="inbound passthrough", "connection forcefully terminated"); + } + _ = Self::proxy_inbound_plaintext(pi, socket::to_canonical(remote), stream, self.enable_orig_src) => { + } } - _ = Self::proxy_inbound_plaintext( - pi, // pi cloned above; OK to move - socket::to_canonical(remote), - stream, - self.enable_orig_src, - ) => {} + // Mark we are done with the connection, so drain can complete + drop(drain); + debug!(component="inbound passthrough", dur=?start.elapsed(), "connection completed"); } - // Mark we are done with the connection, so drain can complete - drop(drain); - debug!(dur=?start.elapsed(), "inbound passthrough connection completed"); - } - .in_current_span(); + .in_current_span(); - assertions::size_between_ref(1500, 3000, &serve_client); - tokio::spawn(serve_client); - } - Err(e) => { - if util::is_runtime_shutdown(&e) { - return; + assertions::size_between_ref(1500, 3000, &serve_client); + tokio::spawn(serve_client); + } + Err(e) => { + if util::is_runtime_shutdown(&e) { + return; + } + error!("Failed TCP handshake {}", e); } - error!("Failed TCP handshake {}", e); } } - } - } - .in_current_span(); - - // Stop accepting once we drain. - // We will then allow connections up to `deadline` to terminate on their own. - // After that, they will be forcefully terminated. - tokio::select! 
{ - res = accept => { res } - res = self.drain.signaled() => { - debug!("inbound passthrough drained, waiting {:?} for any outbound connections to close", deadline); - if let Err(e) = timeout(deadline, sub_drain_signal.drain()).await { - // Not all connections completed within time, we will force shut them down - warn!("drain duration expired with pending connections, forcefully shutting down: {e:?}"); - } - // Trigger force shutdown. In theory, this is only needed in the timeout case. However, - // it doesn't hurt to always trigger it. - let _ = trigger_force_shutdown.send(()); + }.in_current_span() + }; - info!("outbound drain complete"); - drop(res); - } - } + run_with_drain( + "inbound passthrough".to_string(), + self.drain, + pi.cfg.self_termination_deadline, + accept, + ) + .await } async fn proxy_inbound_plaintext( diff --git a/src/proxy/outbound.rs b/src/proxy/outbound.rs index 2a2124e13..d265d7c73 100644 --- a/src/proxy/outbound.rs +++ b/src/proxy/outbound.rs @@ -17,15 +17,12 @@ use std::sync::Arc; use std::time::{Duration, Instant}; -use drain::Watch; - use hyper::header::FORWARDED; use tokio::net::TcpStream; use tokio::sync::watch; -use tokio::time::timeout; -use tracing::{debug, error, info, info_span, trace_span, warn, Instrument}; +use tracing::{debug, error, info, info_span, trace_span, Instrument}; use crate::config::ProxyMode; use crate::identity::Identity; @@ -34,6 +31,8 @@ use crate::proxy::metrics::Reporter; use crate::proxy::{metrics, pool, ConnectionOpen, ConnectionResult, DerivedWorkload}; use crate::proxy::{util, Error, ProxyInputs, TraceParent, BAGGAGE_HEADER, TRACEPARENT_HEADER}; +use crate::drain::run_with_drain; +use crate::drain::DrainWatcher; use crate::proxy::h2::H2Stream; use crate::state::service::ServiceDescription; use crate::state::workload::{address::Address, NetworkAddress, Protocol, Workload}; @@ -42,13 +41,13 @@ use crate::{assertions, copy, proxy, socket}; pub struct Outbound { pi: Arc, - drain: Watch, + drain: DrainWatcher, listener: socket::Listener, enable_orig_src: bool, } impl Outbound { - pub(super) async fn new(pi: Arc, drain: Watch) -> Result { + pub(super) async fn new(pi: Arc, drain: DrainWatcher) -> Result { let listener = pi .socket_factory .tcp_bind(pi.cfg.outbound_addr) @@ -76,86 +75,67 @@ impl Outbound { } pub(super) async fn run(self) { - // Since we are spawning autonomous tasks to handle outbound connections for a single workload, - // we can have situations where the workload is deleted, but a task is still "stuck" - // waiting for a server response stream on a HTTP/2 connection or whatnot. - // - // So use a drain to nuke tasks that might be stuck sending. - let (sub_drain_signal, sub_drain) = drain::channel(); let pool = proxy::pool::WorkloadHBONEPool::new( self.pi.cfg.clone(), self.enable_orig_src, self.pi.socket_factory.clone(), self.pi.cert_manager.clone(), ); - let deadline = self.pi.cfg.self_termination_deadline; - let (trigger_force_shutdown, force_shutdown) = watch::channel(()); - let accept = async move { - loop { - // Asynchronously wait for an inbound socket. 
- let socket = self.listener.accept().await; - let start_outbound_instant = Instant::now(); - let drain = sub_drain.clone(); - let mut force_shutdown = force_shutdown.clone(); - match socket { - Ok((stream, _remote)) => { - let mut oc = OutboundConnection { - pi: self.pi.clone(), - id: TraceParent::new(), - pool: pool.clone(), - enable_orig_src: self.enable_orig_src, - hbone_port: self.pi.cfg.inbound_addr.port(), - }; - stream.set_nodelay(true).unwrap(); - let span = info_span!("outbound", id=%oc.id); - let serve_outbound_connection = (async move { - debug!(dur=?start_outbound_instant.elapsed(), "outbound connection started"); - // Since this task is spawned, make sure we are guaranteed to terminate - tokio::select! { - _ = force_shutdown.changed() => { - debug!("outbound connection forcefully terminated signaled"); + let pi = self.pi.clone(); + let accept = |drain: DrainWatcher, force_shutdown: watch::Receiver<()>| { + async move { + loop { + // Asynchronously wait for an inbound socket. + let socket = self.listener.accept().await; + let start = Instant::now(); + let drain = drain.clone(); + let mut force_shutdown = force_shutdown.clone(); + match socket { + Ok((stream, _remote)) => { + let mut oc = OutboundConnection { + pi: self.pi.clone(), + id: TraceParent::new(), + pool: pool.clone(), + enable_orig_src: self.enable_orig_src, + hbone_port: self.pi.cfg.inbound_addr.port(), + }; + let span = info_span!("outbound", id=%oc.id); + let serve_outbound_connection = (async move { + debug!(component="outbound", "connection started"); + // Since this task is spawned, make sure we are guaranteed to terminate + tokio::select! { + _ = force_shutdown.changed() => { + debug!(component="outbound", "connection forcefully terminated"); + } + _ = oc.proxy(stream) => {} } - _ = oc.proxy(stream) => {} - } - // Mark we are done with the connection, so drain can complete - drop(drain); - debug!(dur=?start_outbound_instant.elapsed(), "outbound connection completed"); - }) - .instrument(span); + // Mark we are done with the connection, so drain can complete + drop(drain); + debug!(component="outbound", dur=?start.elapsed(), "connection completed"); + }).instrument(span); - assertions::size_between_ref(1000, 1750, &serve_outbound_connection); - tokio::spawn(serve_outbound_connection); - } - Err(e) => { - if util::is_runtime_shutdown(&e) { - return; + assertions::size_between_ref(1000, 1750, &serve_outbound_connection); + tokio::spawn(serve_outbound_connection); + } + Err(e) => { + if util::is_runtime_shutdown(&e) { + return; + } + error!("Failed TCP handshake {}", e); } - error!("Failed TCP handshake {}", e); } } } - } - .in_current_span(); - - // Stop accepting once we drain. - // We will then allow connections up to `deadline` to terminate on their own. - // After that, they will be forcefully terminated. - tokio::select! { - res = accept => { res } - res = self.drain.signaled() => { - debug!("outbound drained, waiting {:?} for any outbound connections to close", deadline); - if let Err(e) = timeout(deadline, sub_drain_signal.drain()).await { - // Not all connections completed within time, we will force shut them down - warn!("drain duration expired with pending connections, forcefully shutting down: {e:?}"); - } - // Trigger force shutdown. In theory, this is only needed in the timeout case. However, - // it doesn't hurt to always trigger it. 
- let _ = trigger_force_shutdown.send(()); + .in_current_span() + }; - info!("outbound drain complete"); - drop(res); - } - } + run_with_drain( + "outbound".to_string(), + self.drain, + pi.cfg.self_termination_deadline, + accept, + ) + .await } } @@ -175,36 +155,7 @@ impl OutboundConnection { self.proxy_to(source_stream, source_addr, dst_addr).await; } - // this is a cancellable outbound proxy. If `out_drain` is a Watch drain, will resolve - // when the drain is signaled, or the outbound stream is completed, no matter what. - // - // If `out_drain` is none, will only resolve when the outbound stream is terminated. - // - // If using `proxy_to` in `tokio::spawn` tasks, it is recommended to use a drain, to guarantee termination - // and prevent "zombie" outbound tasks. - pub async fn proxy_to_cancellable( - &mut self, - stream: TcpStream, - remote_addr: SocketAddr, - orig_dst_addr: SocketAddr, - out_drain: Option, - ) { - match out_drain { - Some(drain) => { - tokio::select! { - _ = drain.signaled() => { - info!("drain signaled"); - } - res = self.proxy_to(stream, remote_addr, orig_dst_addr) => res - } - } - None => { - self.proxy_to(stream, remote_addr, orig_dst_addr).await; - } - } - } - - async fn proxy_to( + pub async fn proxy_to( &mut self, source_stream: TcpStream, source_addr: SocketAddr, diff --git a/src/proxy/pool.rs b/src/proxy/pool.rs index ed5bde88a..977a9b1c2 100644 --- a/src/proxy/pool.rs +++ b/src/proxy/pool.rs @@ -564,9 +564,8 @@ mod test { use std::net::SocketAddr; use std::time::Instant; - use crate::{identity, proxy}; + use crate::{drain, identity, proxy}; - use drain::Watch; use futures_util::{future, StreamExt}; use hyper::body::Incoming; @@ -584,6 +583,7 @@ mod test { use crate::test_helpers::helpers::initialize_telemetry; + use crate::drain::DrainWatcher; use ztunnel::test_helpers::*; use super::*; @@ -765,7 +765,7 @@ mod test { let (pool, mut srv) = setup_test_with_idle(4, Duration::from_millis(100)).await; let key = key(&srv, 1); - let (client_stop_signal, client_stop) = drain::channel(); + let (client_stop_signal, client_stop) = drain::new(); // Spin up 1 connection spawn_persistent_client(pool.clone(), key.clone(), srv.addr, client_stop).await; spawn_clients_concurrently(pool.clone(), key.clone(), srv.addr, 2).await; @@ -776,7 +776,9 @@ mod test { assert_opens_drops!(srv, 2, 1); // Trigger the persistent client to stop, we should evict that connection as well - client_stop_signal.drain().await; + client_stop_signal + .start_drain_and_wait(drain::DrainMode::Immediate) + .await; assert_opens_drops!(srv, 2, 1); } @@ -847,7 +849,7 @@ mod test { mut pool: WorkloadHBONEPool, key: WorkloadKey, remote_addr: SocketAddr, - stop: Watch, + stop: DrainWatcher, ) { let req = || { http::Request::builder() @@ -866,7 +868,7 @@ mod test { start.elapsed().as_millis() ); tokio::spawn(async move { - let _ = stop.signaled().await; + let _ = stop.wait_for_drain().await; debug!("persistent client stop"); // Close our connection drop(c1); diff --git a/src/proxy/socks5.rs b/src/proxy/socks5.rs index de9c09bf7..5e0f210b0 100644 --- a/src/proxy/socks5.rs +++ b/src/proxy/socks5.rs @@ -14,7 +14,6 @@ use anyhow::Result; use byteorder::{BigEndian, ByteOrder}; -use drain::Watch; use hickory_proto::op::{Message, MessageType, Query}; use hickory_proto::rr::{Name, RecordType}; @@ -23,26 +22,30 @@ use hickory_server::authority::MessageRequest; use hickory_server::server::{Protocol, Request}; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; +use std::time::Instant; use 
crate::dns::resolver::Resolver; use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::net::TcpStream; -use tracing::{debug, error, info}; +use tokio::sync::watch; +use tracing::{debug, error, info, info_span, Instrument}; +use crate::drain::run_with_drain; +use crate::drain::DrainWatcher; use crate::proxy::outbound::OutboundConnection; use crate::proxy::{util, Error, ProxyInputs, TraceParent}; -use crate::socket; +use crate::{assertions, socket}; pub(super) struct Socks5 { pi: Arc, listener: socket::Listener, - drain: Watch, + drain: DrainWatcher, enable_orig_src: bool, } impl Socks5 { - pub(super) async fn new(pi: Arc, drain: Watch) -> Result { + pub(super) async fn new(pi: Arc, drain: DrainWatcher) -> Result { let listener = pi .socket_factory .tcp_bind(pi.cfg.socks5_addr.unwrap()) @@ -72,54 +75,66 @@ impl Socks5 { } pub async fn run(self) { - let inner_drain = self.drain.clone(); - let inpod = self.pi.cfg.inpod_enabled; - let accept = async move { - loop { - // Asynchronously wait for an inbound socket. - let socket = self.listener.accept().await; - let stream_drain = inner_drain.clone(); - // TODO creating a new HBONE pool for SOCKS5 here may not be ideal, - // but ProxyInfo is overloaded and only `outbound` should ever use the pool. - let pool = crate::proxy::pool::WorkloadHBONEPool::new( - self.pi.cfg.clone(), - self.enable_orig_src, - self.pi.socket_factory.clone(), - self.pi.cert_manager.clone(), - ); - match socket { - Ok((stream, remote)) => { - debug!("accepted outbound connection from {}", remote); - let oc = OutboundConnection { - pi: self.pi.clone(), - id: TraceParent::new(), - pool, - enable_orig_src: self.enable_orig_src, - hbone_port: self.pi.cfg.inbound_addr.port(), - }; - tokio::spawn(async move { - if let Err(err) = handle(oc, stream, stream_drain, inpod).await { - log::error!("handshake error: {}", err); + let pi = self.pi.clone(); + let pool = crate::proxy::pool::WorkloadHBONEPool::new( + self.pi.cfg.clone(), + self.enable_orig_src, + self.pi.socket_factory.clone(), + self.pi.cert_manager.clone(), + ); + let accept = |drain: DrainWatcher, force_shutdown: watch::Receiver<()>| { + async move { + loop { + // Asynchronously wait for an inbound socket. + let socket = self.listener.accept().await; + let start = Instant::now(); + let drain = drain.clone(); + let mut force_shutdown = force_shutdown.clone(); + match socket { + Ok((stream, _remote)) => { + let oc = OutboundConnection { + pi: self.pi.clone(), + id: TraceParent::new(), + pool: pool.clone(), + enable_orig_src: self.enable_orig_src, + hbone_port: self.pi.cfg.inbound_addr.port(), + }; + let span = info_span!("socks5", id=%oc.id); + let serve = (async move { + debug!(component="socks5", "connection started"); + // Since this task is spawned, make sure we are guaranteed to terminate + tokio::select! { + _ = force_shutdown.changed() => { + debug!(component="socks5", "connection forcefully terminated"); + } + _ = handle(oc, stream) => {} + } + // Mark we are done with the connection, so drain can complete + drop(drain); + debug!(component="socks5", dur=?start.elapsed(), "connection completed"); + }).instrument(span); + + assertions::size_between_ref(1000, 2000, &serve); + tokio::spawn(serve); + } + Err(e) => { + if util::is_runtime_shutdown(&e) { + return; } - }); - } - Err(e) => { - if util::is_runtime_shutdown(&e) { - return; + error!("Failed TCP handshake {}", e); } - error!("Failed TCP handshake {}", e); } } } }; - tokio::select! 
{ - res = accept => { res } - _ = self.drain.signaled() => { - // out_drain_signal.drain().await; - info!("socks5 drained"); - } - } + run_with_drain( + "socks5".to_string(), + self.drain, + pi.cfg.self_termination_deadline, + accept, + ) + .await } } @@ -127,12 +142,7 @@ impl Socks5 { // sufficient to integrate with common clients: // - only unauthenticated requests // - only CONNECT, with IPv4 or IPv6 -async fn handle( - mut oc: OutboundConnection, - mut stream: TcpStream, - out_drain: Watch, - is_inpod: bool, -) -> Result<(), anyhow::Error> { +async fn handle(mut oc: OutboundConnection, mut stream: TcpStream) -> Result<(), anyhow::Error> { let remote_addr = socket::to_canonical(stream.peer_addr().expect("must receive peer addr")); // Version(5), Number of auth methods @@ -238,17 +248,7 @@ async fn handle( stream.write_all(&buf).await?; debug!("accepted connection from {remote_addr} to {host}"); - // For inpod, we want this `spawn` to guaranteed-terminate when we drain - the workload is gone. - // For non-inpod (shared instance for all workloads), let the spawned task run until the proxy process - // itself is killed, or the connection terminates normally. - tokio::spawn(async move { - let drain = match is_inpod { - true => Some(out_drain), - false => None, - }; - oc.proxy_to_cancellable(stream, remote_addr, host, drain) - .await; - }); + oc.proxy_to(stream, remote_addr, host).await; Ok(()) } diff --git a/src/proxyfactory.rs b/src/proxyfactory.rs index 4f3c1da66..803153e56 100644 --- a/src/proxyfactory.rs +++ b/src/proxyfactory.rs @@ -15,11 +15,11 @@ use crate::config; use crate::identity::SecretManager; use crate::state::{DemandProxyState, WorkloadInfo}; -use drain::Watch; use std::sync::Arc; use tracing::error; use crate::dns; +use crate::drain::DrainWatcher; use crate::proxy::connection_manager::ConnectionManager; use crate::proxy::{Error, Metrics}; @@ -34,7 +34,7 @@ pub struct ProxyFactory { cert_manager: Arc, proxy_metrics: Arc, dns_metrics: Option>, - drain: Watch, + drain: DrainWatcher, } impl ProxyFactory { @@ -44,7 +44,7 @@ impl ProxyFactory { cert_manager: Arc, proxy_metrics: Arc, dns_metrics: Option, - drain: Watch, + drain: DrainWatcher, ) -> std::io::Result { let dns_metrics = match dns_metrics { Some(metrics) => Some(Arc::new(metrics)), @@ -73,7 +73,7 @@ impl ProxyFactory { pub async fn new_proxies_from_factory( &self, - proxy_drain: Option, + proxy_drain: Option, proxy_workload_info: Option, socket_factory: Arc, ) -> Result { diff --git a/src/readiness/server.rs b/src/readiness/server.rs index c9f25d790..3dc4f4416 100644 --- a/src/readiness/server.rs +++ b/src/readiness/server.rs @@ -16,12 +16,12 @@ use std::net::SocketAddr; use std::sync::Arc; use bytes::Bytes; -use drain::Watch; use http_body_util::Full; use hyper::body::Incoming; use hyper::{Request, Response}; use itertools::Itertools; +use crate::drain::DrainWatcher; use crate::hyper_util; use crate::{config, readiness}; @@ -33,7 +33,7 @@ pub struct Server { impl Server { pub async fn new( config: Arc, - drain_rx: Watch, + drain_rx: DrainWatcher, ready: readiness::Ready, ) -> anyhow::Result { hyper_util::Server::::bind( diff --git a/src/signal.rs b/src/signal.rs index fcf4eb4cb..57fb66541 100644 --- a/src/signal.rs +++ b/src/signal.rs @@ -59,7 +59,7 @@ pub struct ShutdownTrigger { impl ShutdownTrigger { pub async fn shutdown_now(&self) { - self.shutdown_tx.send(()).await.unwrap(); + let _ = self.shutdown_tx.send(()).await; } } diff --git a/src/test_helpers/dns.rs b/src/test_helpers/dns.rs index 516684002..a1a285a7d 
100644 --- a/src/test_helpers/dns.rs +++ b/src/test_helpers/dns.rs @@ -15,10 +15,11 @@ use crate::config::Address; use crate::dns::resolver::{Answer, Resolver}; use crate::dns::Metrics; +use crate::drain::DrainTrigger; use crate::proxy::Error; use crate::state::workload::Workload; use crate::test_helpers::new_proxy_state; -use crate::{dns, metrics}; +use crate::{dns, drain, metrics}; use futures_util::ready; use futures_util::stream::{Stream, StreamExt}; use hickory_client::client::{AsyncClient, ClientHandle}; @@ -213,7 +214,7 @@ pub struct TestDnsServer { tcp: SocketAddr, udp: SocketAddr, resolver: Arc, - _drain: drain::Signal, + _drain: DrainTrigger, } impl TestDnsServer { @@ -249,7 +250,7 @@ pub async fn run_dns(responses: HashMap>) -> anyhow::Result + Send + 'static>( info!("ack received, len {}", read_amount); // Now await for FDs while let Some((uid, fd)) = rx.recv().await { + let orig_uid = uid.clone(); let uid = crate::inpod::WorkloadUid::new(uid); if fd >= 0 { send_workload_added(&mut ztun_sock, uid, fd).await; @@ -83,8 +84,8 @@ pub fn start_ztunnel_server + Send + 'static>( }; // receive ack from ztunnel - let ack = read_msg(&mut ztun_sock).await; - info!("ack received, len {:?}", ack); + let _ = read_msg(&mut ztun_sock).await; + info!(uid=orig_uid, %fd, "ack received"); rx.ack().await.expect("ack failed"); } }); diff --git a/src/test_helpers/linux.rs b/src/test_helpers/linux.rs index c97453a4c..6057a0b1b 100644 --- a/src/test_helpers/linux.rs +++ b/src/test_helpers/linux.rs @@ -22,6 +22,9 @@ use crate::test_helpers::*; use crate::xds::{LocalConfig, LocalWorkload}; use crate::{config, identity, proxy, strng}; +use crate::signal::ShutdownTrigger; +use crate::test_helpers::inpod::start_ztunnel_server; +use crate::test_helpers::linux::TestMode::InPod; use itertools::Itertools; use nix::unistd::mkdtemp; use std::net::IpAddr; @@ -29,10 +32,6 @@ use std::os::fd::AsRawFd; use std::path::PathBuf; use std::thread; use std::time::Duration; - -use crate::signal::ShutdownTrigger; -use crate::test_helpers::inpod::start_ztunnel_server; -use crate::test_helpers::linux::TestMode::InPod; use tokio::sync::Mutex; use tracing::info; @@ -219,7 +218,20 @@ impl WorkloadManager { } pub async fn delete_workload(&mut self, name: &str) -> anyhow::Result<()> { - self.workloads.retain(|w| w.workload.name != name); + let mut workloads = vec![]; + std::mem::swap(&mut self.workloads, &mut workloads); + let (keep, drop) = workloads.into_iter().partition(|w| w.workload.name != name); + self.workloads = keep; + for d in drop { + if let Some(zt) = self.ztunnels.get_mut(&d.workload.node.to_string()).as_mut() { + zt.fd_sender + .as_mut() + .unwrap() + .send_and_wait((d.workload.uid.to_string(), -1)) // Test server handles -1 as del + .await + .unwrap(); + } + } self.refresh_config().await?; Ok(()) } diff --git a/src/test_helpers/netns.rs b/src/test_helpers/netns.rs index 4982feaf2..3dd403b5e 100644 --- a/src/test_helpers/netns.rs +++ b/src/test_helpers/netns.rs @@ -99,6 +99,22 @@ impl Namespace { format!("veth{}", self.id) } + // A small helper around run_ready that marks as "ready" immediately and waits for completion + pub fn run_and_wait(&self, f: F) -> anyhow::Result + where + F: FnOnce() -> Fut + Send + 'static, + Fut: Future>, + R: Send + 'static, + { + self.run_ready(|ready| async move { + ready.set_ready(); + f().await + }) + .unwrap() + .join() + .unwrap() + } + // A small helper around run_ready that marks as "ready" immediately. 
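+ // Unlike run_and_wait above, this variant returns the JoinHandle without waiting; the caller
+ // drives readiness and joins the handle itself (see run_long_running_tcp_client in tests/namespaced.rs).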
pub fn run(&self, f: F) -> anyhow::Result>> where diff --git a/tests/direct.rs b/tests/direct.rs index 458b0db04..17992e5d0 100644 --- a/tests/direct.rs +++ b/tests/direct.rs @@ -213,7 +213,6 @@ async fn run_requests_test( // Test a round trip outbound call (via socks5) let echo = tcp::TestServer::new(tcp::Mode::ReadWrite, 0).await; let echo_addr = echo.address(); - let dns_drain: Option = None; let mut cfg = config::Config { local_node: (!node.is_empty()).then(|| node.to_string()), ..test_config_with_port(echo_addr.port()) @@ -247,7 +246,6 @@ async fn run_requests_test( } }) .await; - drop(dns_drain); } #[tokio::test] diff --git a/tests/namespaced.rs b/tests/namespaced.rs index 68c18b18a..7e7c8c21f 100644 --- a/tests/namespaced.rs +++ b/tests/namespaced.rs @@ -22,6 +22,8 @@ mod namespaced { use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; + use std::sync::{Arc, Mutex}; + use std::thread::JoinHandle; use std::time::Duration; use ztunnel::rbac::{Authorization, RbacMatch, StringMatch}; @@ -471,6 +473,136 @@ mod namespaced { Ok(()) } + #[tokio::test] + async fn test_ztunnel_shutdown() -> anyhow::Result<()> { + let mut manager = setup_netns_test!(InPod); + let local = manager.deploy_ztunnel(DEFAULT_NODE).await?; + let server = manager + .workload_builder("server", DEFAULT_NODE) + .register() + .await?; + run_tcp_server(server)?; + + let client = manager + .workload_builder("client", DEFAULT_NODE) + .register() + .await?; + let (mut tx, rx) = mpsc_ack(1); + let srv = resolve_target(manager.resolver(), "server"); + + // Run a client which will send some traffic when signaled to do so + let cjh = run_long_running_tcp_client(&client, rx, srv).unwrap(); + + // First, send the initial request and wait for it + tx.send_and_wait(()).await?; + // Now start shutdown. Ztunnel should keep things working since we have pending open connections + local.shutdown.shutdown_now().await; + // Requests should still succeed... + tx.send_and_wait(()).await?; + // Close the connection + drop(tx); + + cjh.join().unwrap()?; + + assert_eventually( + Duration::from_secs(2), + || async { + client + .run_and_wait(move || async move { Ok(TcpStream::connect(srv).await?) }) + .is_err() + }, + true, + ) + .await; + // let res = client.run_and_wait(move || async move { Ok(TcpStream::connect(srv).await?) }); + // assert!(res.is_err(), "requests should fail after shutdown"); + Ok(()) + } + + #[tokio::test] + async fn test_server_shutdown() -> anyhow::Result<()> { + let mut manager = setup_netns_test!(InPod); + manager.deploy_ztunnel(DEFAULT_NODE).await?; + let server = manager + .workload_builder("server", DEFAULT_NODE) + .register() + .await?; + run_tcp_server(server)?; + + let client = manager + .workload_builder("client", DEFAULT_NODE) + .register() + .await?; + let (mut tx, rx) = mpsc_ack(1); + let srv = resolve_target(manager.resolver(), "server"); + + // Run a client which will send some traffic when signaled to do so + let cjh = run_long_running_tcp_client(&client, rx, srv).unwrap(); + + // First, send the initial request and wait for it + tx.send_and_wait(()).await?; + // Now shutdown the server. In real world, the server app would shutdown, then ztunnel would remove itself. + // In this test, we will leave the server app running, but shutdown ztunnel. 
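+ // delete_workload sends (uid, -1) to the in-pod test server (the test server treats -1 as a removal),
+ // so ztunnel drains the proxy for "server" while the server app itself keeps running.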
+ manager.delete_workload("server").await.unwrap(); + // Request should fail now + let tx = Arc::new(Mutex::new(tx)); + #[allow(clippy::await_holding_lock)] + assert_eventually( + Duration::from_secs(2), + || async { tx.lock().unwrap().send_and_wait(()).await.is_err() }, + true, + ) + .await; + // Close the connection + drop(tx); + + // Should fail as the last request fails + assert!(cjh.join().unwrap().is_err()); + + // Now try to connect and make sure it fails + client + .run_and_wait(move || async move { + let mut stream = TcpStream::connect(srv).await.unwrap(); + // We should be able to connect (since client is running), but not send a request + let send = timeout( + Duration::from_millis(50), + double_read_write_stream(&mut stream), + ) + .await; + assert!(send.is_err()); + Ok(()) + }) + .unwrap(); + Ok(()) + } + + fn run_long_running_tcp_client( + client: &Namespace, + mut rx: MpscAckReceiver<()>, + srv: SocketAddr, + ) -> anyhow::Result>> { + async fn double_read_write_stream(stream: &mut TcpStream) -> anyhow::Result { + const BODY: &[u8] = b"hello world"; + stream.write_all(BODY).await?; + let mut buf = [0; BODY.len() * 2]; + stream.read_exact(&mut buf).await?; + assert_eq!(b"hello worldhello world", &buf); + Ok(BODY.len() * 2) + } + client.run(move || async move { + let mut stream = timeout(Duration::from_secs(5), TcpStream::connect(srv)).await??; + while let Some(()) = rx.recv().await { + timeout( + Duration::from_secs(5), + double_read_write_stream(&mut stream), + ) + .await??; + rx.ack().await.unwrap(); + } + Ok(()) + }) + } + #[tokio::test] async fn test_policy() -> anyhow::Result<()> { let mut manager = setup_netns_test!(InPod);