diff --git a/crates/host/src/wasmbus/ctl.rs b/crates/host/src/wasmbus/ctl.rs index ff45db4387..7bff653744 100644 --- a/crates/host/src/wasmbus/ctl.rs +++ b/crates/host/src/wasmbus/ctl.rs @@ -492,7 +492,7 @@ impl ControlInterfaceServer for Host { let provider_ref = request.provider_ref(); let annotations = request.annotations(); - if let Err(err) = self + if let Err(err) = Arc::clone(&self) .handle_start_provider_task( config, provider_id, @@ -538,9 +538,16 @@ impl ControlInterfaceServer for Host { return Ok(CtlResponse::error("provider with that ID is not running")); }; let Provider { - ref annotations, .. + ref annotations, + mut tasks, + shutdown, + .. } = entry.remove(); + // Set the shutdown flag to true to stop health checks and config updates. Also + // prevents restarting the provider but does not stop the provider process. + shutdown.store(true, Ordering::Relaxed); + // Send a request to the provider, requesting a graceful shutdown let req = serde_json::to_vec(&json!({ "host_id": host_id })) .context("failed to encode provider stop request")?; @@ -566,7 +573,13 @@ impl ControlInterfaceServer for Host { provider_id, "provider did not gracefully shut down in time, shutting down forcefully" ); + // NOTE: The provider child process is spawned with [tokio::process::Command::kill_on_drop], + // so dropping the task will send a SIGKILL to the provider process. } + + // Stop the provider and health check / config changes tasks + tasks.abort_all(); + info!(provider_id, "provider stopped"); self.publish_event( "provider_stopped", diff --git a/crates/host/src/wasmbus/mod.rs b/crates/host/src/wasmbus/mod.rs index 052fc63da5..5998e99027 100644 --- a/crates/host/src/wasmbus/mod.rs +++ b/crates/host/src/wasmbus/mod.rs @@ -1054,39 +1054,41 @@ impl Host { #[instrument(level = "debug", skip_all)] async fn inventory(&self) -> HostInventory { trace!("generating host inventory"); - let components = self.components.read().await; - let components: Vec<_> = stream::iter(components.iter()) - .filter_map(|(id, component)| async move { - let mut description = ComponentDescription::builder() - .id(id.into()) - .image_ref(component.image_reference.to_string()) - .annotations(component.annotations.clone().into_iter().collect()) - .max_instances(component.max_instances.get().try_into().unwrap_or(u32::MAX)) - .revision( - component - .claims() - .and_then(|claims| claims.metadata.as_ref()) - .and_then(|jwt::Component { rev, .. }| *rev) - .unwrap_or_default(), - ); - // Add name if present - if let Some(name) = component - .claims() - .and_then(|claims| claims.metadata.as_ref()) - .and_then(|metadata| metadata.name.as_ref()) - .cloned() - { - description = description.name(name); - }; + let components: Vec<_> = { + let components = self.components.read().await; + stream::iter(components.iter()) + .filter_map(|(id, component)| async move { + let mut description = ComponentDescription::builder() + .id(id.into()) + .image_ref(component.image_reference.to_string()) + .annotations(component.annotations.clone().into_iter().collect()) + .max_instances(component.max_instances.get().try_into().unwrap_or(u32::MAX)) + .revision( + component + .claims() + .and_then(|claims| claims.metadata.as_ref()) + .and_then(|jwt::Component { rev, .. }| *rev) + .unwrap_or_default(), + ); + // Add name if present + if let Some(name) = component + .claims() + .and_then(|claims| claims.metadata.as_ref()) + .and_then(|metadata| metadata.name.as_ref()) + .cloned() + { + description = description.name(name); + }; - Some( - description - .build() - .expect("failed to build component description: {e}"), - ) - }) - .collect() - .await; + Some( + description + .build() + .expect("failed to build component description: {e}"), + ) + }) + .collect() + .await + }; let providers: Vec<_> = self .providers diff --git a/crates/host/src/wasmbus/providers/mod.rs b/crates/host/src/wasmbus/providers/mod.rs index 73ce8f292c..28222ba819 100644 --- a/crates/host/src/wasmbus/providers/mod.rs +++ b/crates/host/src/wasmbus/providers/mod.rs @@ -19,9 +19,9 @@ use cloudevents::EventBuilderV10; use futures::{stream, Future, StreamExt}; use nkeys::XKey; use tokio::io::AsyncWriteExt; +use tokio::process; use tokio::sync::RwLock; use tokio::task::JoinSet; -use tokio::{process, select}; use tracing::{error, instrument, trace, warn}; use uuid::Uuid; use wascap::jwt::{CapabilityProvider, Token}; @@ -242,7 +242,6 @@ impl Host { Arc::clone(&self.host_config.lattice), self.host_key.public_key(), provider_id.to_string(), - shutdown.clone(), )); Ok(tasks) @@ -282,7 +281,6 @@ impl Host { Arc::clone(&config_bundle), Arc::clone(&lattice), provider_id.clone(), - shutdown.clone(), )); loop { let mut child = child.write().await; @@ -344,7 +342,6 @@ impl Host { new_config_bundle, Arc::clone(&lattice), provider_id.clone(), - shutdown.clone(), )); // Restart the provider by attempting to re-execute the binary with the same @@ -428,7 +425,6 @@ fn check_health( lattice: Arc, host_id: String, provider_id: String, - shutdown: Arc, ) -> impl Future { let health_subject = async_nats::Subject::from(format!("wasmbus.rpc.{lattice}.{provider_id}.health")); @@ -440,90 +436,87 @@ fn check_health( health_check.reset_after(Duration::from_secs(5)); async move { loop { - select! { - _ = health_check.tick() => { - trace!(?provider_id, "performing provider health check"); - let request = async_nats::Request::new() - .payload(Bytes::new()) - .headers(injector_to_headers(&TraceContextInjector::default_with_span())); - if let Ok(async_nats::Message { payload, ..}) = rpc_nats.send_request( - health_subject.clone(), - request, - ).await { - match (serde_json::from_slice::(&payload), previous_healthy) { - (Ok(HealthCheckResponse { healthy: true, ..}), false) => { - trace!(?provider_id, "provider health check succeeded"); - previous_healthy = true; - if let Err(e) = event::publish( - &event_builder, - &ctl_nats, - &lattice, - "health_check_passed", - event::provider_health_check( - &host_id, - &provider_id, - ) - ).await { - warn!( - ?e, - ?provider_id, - "failed to publish provider health check succeeded event", - ); - } - }, - (Ok(HealthCheckResponse { healthy: false, ..}), true) => { - trace!(?provider_id, "provider health check failed"); - previous_healthy = false; - if let Err(e) = event::publish( - &event_builder, - &ctl_nats, - &lattice, - "health_check_failed", - event::provider_health_check( - &host_id, - &provider_id, - ) - ).await { - warn!( - ?e, - ?provider_id, - "failed to publish provider health check failed event", - ); - } - } - // If the provider health status didn't change, we simply publish a health check status event - (Ok(_), _) => { - if let Err(e) = event::publish( - &event_builder, - &ctl_nats, - &lattice, - "health_check_status", - event::provider_health_check( - &host_id, - &provider_id, - ) - ).await { - warn!( - ?e, - ?provider_id, - "failed to publish provider health check status event", - ); - } - }, - _ => warn!( - ?provider_id, - "failed to deserialize provider health check response" - ), - } + let _ = health_check.tick().await; + trace!(?provider_id, "performing provider health check"); + let request = + async_nats::Request::new() + .payload(Bytes::new()) + .headers(injector_to_headers( + &TraceContextInjector::default_with_span(), + )); + if let Ok(async_nats::Message { payload, .. }) = + rpc_nats.send_request(health_subject.clone(), request).await + { + match ( + serde_json::from_slice::(&payload), + previous_healthy, + ) { + (Ok(HealthCheckResponse { healthy: true, .. }), false) => { + trace!(?provider_id, "provider health check succeeded"); + previous_healthy = true; + if let Err(e) = event::publish( + &event_builder, + &ctl_nats, + &lattice, + "health_check_passed", + event::provider_health_check(&host_id, &provider_id), + ) + .await + { + warn!( + ?e, + ?provider_id, + "failed to publish provider health check succeeded event", + ); } - else { - warn!(?provider_id, "failed to request provider health, retrying in 30 seconds"); + } + (Ok(HealthCheckResponse { healthy: false, .. }), true) => { + trace!(?provider_id, "provider health check failed"); + previous_healthy = false; + if let Err(e) = event::publish( + &event_builder, + &ctl_nats, + &lattice, + "health_check_failed", + event::provider_health_check(&host_id, &provider_id), + ) + .await + { + warn!( + ?e, + ?provider_id, + "failed to publish provider health check failed event", + ); } + } + // If the provider health status didn't change, we simply publish a health check status event + (Ok(_), _) => { + if let Err(e) = event::publish( + &event_builder, + &ctl_nats, + &lattice, + "health_check_status", + event::provider_health_check(&host_id, &provider_id), + ) + .await + { + warn!( + ?e, + ?provider_id, + "failed to publish provider health check status event", + ); + } + } + _ => warn!( + ?provider_id, + "failed to deserialize provider health check response" + ), } - true = async { shutdown.load(Ordering::Relaxed) } => { - trace!(?provider_id, "received shutdown signal, stopping health check task"); - break; - } + } else { + warn!( + ?provider_id, + "failed to request provider health, retrying in 30 seconds" + ); } } } @@ -538,38 +531,28 @@ fn watch_config( config: Arc>, lattice: Arc, provider_id: String, - shutdown: Arc, ) -> impl Future { let subject = provider_config_update_subject(&lattice, &provider_id); trace!(?provider_id, "starting config update listener"); async move { loop { let mut config = config.write().await; - select! { - maybe_update = config.changed() => { - let Ok(update) = maybe_update else { - // TODO: shouldn't this be continue? - break; - }; - trace!(?provider_id, "provider config bundle changed"); - let bytes = match serde_json::to_vec(&*update) { - Ok(bytes) => bytes, - Err(err) => { - error!(%err, ?provider_id, ?lattice, "failed to serialize configuration update "); - continue; - } - }; - trace!(?provider_id, subject, "publishing config bundle bytes"); - if let Err(err) = rpc_nats.publish(subject.clone(), Bytes::from(bytes)).await { - error!(%err, ?provider_id, ?lattice, "failed to publish configuration update bytes to component"); + if let Ok(update) = config.changed().await { + trace!(?provider_id, "provider config bundle changed"); + let bytes = match serde_json::to_vec(&*update) { + Ok(bytes) => bytes, + Err(err) => { + error!(%err, ?provider_id, ?lattice, "failed to serialize configuration update "); + continue; } + }; + trace!(?provider_id, subject, "publishing config bundle bytes"); + if let Err(err) = rpc_nats.publish(subject.clone(), Bytes::from(bytes)).await { + error!(%err, ?provider_id, ?lattice, "failed to publish configuration update bytes to component"); } - true = async { shutdown.load(Ordering::Relaxed) } => { - trace!(?provider_id, "received shutdown signal, stopping config update listener"); - // TODO: shouldn't this be return? - break; - } - } + } else { + break; + }; } } }